import json, array, ctypes
import pylase as ol
from utils import *

# Handle to the raster_fx shared library, loaded at import time through a
# path relative to the current working directory.  ZBuffer.render_triangle
# calls its render_triangle() entry point.
raster_fx = ctypes.cdll.LoadLibrary("./raster_fx.so")

class Vector(object):
    """Three-component vector stored as a tuple in ``self.v``.

    Operator summary:

    * ``a * b`` with two vectors is the cross product.
    * ``a * s`` / ``s * a`` with any other operand scales every component.
    * ``a + b`` adds component-wise; ``a + s`` / ``s + a`` adds ``s`` to
      every component.
    * ``a - b`` subtracts component-wise (vector operands only).
    * ``a.dot(b)`` is the dot product.
    """

    def __init__(self, v):
        # Copy into a tuple so the stored components have a fixed length and
        # do not alias the caller's sequence.
        self.v = tuple(v)

    @property
    def x(self):
        """First component."""
        return self.v[0]

    @property
    def y(self):
        """Second component."""
        return self.v[1]

    @property
    def z(self):
        """Third component."""
        return self.v[2]

    def emul(self, other):
        """Return the component-wise (element by element) product."""
        x0, y0, z0 = self.v
        x1, y1, z1 = other.v
        return Vector((x0 * x1, y0 * y1, z0 * z1))

    def dot(self, other):
        """Return the scalar dot product of this vector and ``other``."""
        x0, y0, z0 = self.v
        x1, y1, z1 = other.v
        return x0 * x1 + y0 * y1 + z0 * z1

    def cross(self, other):
        """Return the dot product of this vector and ``other``.

        Despite its name this method has always computed the *dot* product;
        the real cross product is ``self * other``.  The behaviour is kept
        so existing callers keep working; new code should call ``dot``.
        """
        return self.dot(other)

    def __mul__(self, other):
        """Cross product for a Vector operand, uniform scaling otherwise."""
        x0, y0, z0 = self.v
        if isinstance(other, Vector):
            x1, y1, z1 = other.v
            return Vector((
                y0 * z1 - z0 * y1,
                z0 * x1 - x0 * z1,
                x0 * y1 - y0 * x1
            ))
        else:
            return Vector((other * x0, other * y0, other * z0))

    def __rmul__(self, other):
        """Scale by a non-Vector left operand (``2 * v``)."""
        # Only reached when the left operand is not a Vector, so this is
        # always the scaling case of __mul__.
        x0, y0, z0 = self.v
        return Vector((other * x0, other * y0, other * z0))

    def __sub__(self, other):
        """Return the component-wise difference ``self - other``."""
        x0, y0, z0 = self.v
        x1, y1, z1 = other.v
        return Vector((x0 - x1, y0 - y1, z0 - z1))

    def __add__(self, other):
        """Add a Vector component-wise, or add ``other`` to each component."""
        x0, y0, z0 = self.v
        if isinstance(other, Vector):
            x1, y1, z1 = other.v
            return Vector((x0 + x1, y0 + y1, z0 + z1))
        else:
            return Vector((x0 + other, y0 + other, z0 + other))

    def __radd__(self, other):
        """Add a non-Vector left operand to each component (``1 + v``)."""
        x0, y0, z0 = self.v
        return Vector((other + x0, other + y0, other + z0))

    def __neg__(self):
        """Return the vector with every component negated."""
        x0, y0, z0 = self.v
        return Vector((-x0, -y0, -z0))

    def __repr__(self):
        return "Vector(%f, %f, %f)" % self.v

class ZBuffer(object):
    """Square ``size`` x ``size`` buffer of C floats used for depth testing.

    Triangles are rasterised into the buffer by the native ``raster_fx``
    library; ``ztest`` is then installed as a pixel shader that blanks
    points failing the comparison against the buffer.
    NOTE(review): the values passed to the rasteriser are ``1/z`` of the
    transformed vertices, so the cells presumably hold inverse depth --
    confirm against raster_fx.
    """

    # Value every cell is reset to.
    FAR = 1.0

    def __init__(self, size):
        """Create a buffer with ``size`` cells per side, all set to FAR."""
        self.size = size
        # Base tolerance of the depth comparison in ztest().
        self.thresh = 0.000025
        self.buf = self._blank()

    def _blank(self):
        """Return a fresh array of size*size C floats filled with FAR."""
        return array.array('f', [self.FAR]) * (self.size * self.size)

    def clear(self):
        """Reset every cell to FAR by replacing the buffer with a new one."""
        self.buf = self._blank()

    def ztest(self, vertex, color):
        """Pixel shader: return ``(vertex, color)`` or ``(vertex, 0)``.

        The colour is replaced by 0 when the point lies outside
        ``[-1, 1)`` on x or y, or when ``1/z`` exceeds the reciprocal of
        the buffer cell under the point by more than a z-dependent
        tolerance.  The vertex itself is always returned unchanged.

        Raises ZeroDivisionError if ``z`` or the buffer cell is zero.
        """
        x, y, z = vertex
        if x < -1 or x >= 1 or y < -1 or y >= 1:
            return vertex, 0
        # Map [-1, 1) onto cell indices [0, size).
        dx = int((x + 1) * 0.5 * self.size)
        dy = int((y + 1) * 0.5 * self.size)
        bv = self.buf[dx + dy * self.size]
        # hax: the tolerance grows linearly with z (250x slope around z == 1)
        t = self.thresh * ((z - 1) * 250 + 1)
        if (1.0 / z - 1.0 / bv) > t:
            return vertex, 0
        else:
            return vertex, color

    def enable(self):
        """Install ztest as the pylase 3D pixel shader."""
        ol.setPixel3Shader(self.ztest)

    def disable(self):
        """Remove the 3D pixel shader."""
        ol.setPixel3Shader(None)

    def render_triangle(self, a, b, c):
        """Rasterise triangle ``a``, ``b``, ``c`` into the buffer natively.

        Each argument must expose ``x``, ``y`` and ``z``.  x and y are
        mapped from [-1, 1] to integer cell coordinates and ``1/z`` is
        passed as a C float.
        """
        addr = self.buf.buffer_info()[0]
        half = self.size * 0.5

        raster_fx.render_triangle(
            # The array has typecode 'f', so the matching pointer type is
            # float*, consistent with the c_float depths passed below.
            ctypes.cast(addr, ctypes.POINTER(ctypes.c_float)),
            self.size,
            self.size,
            int((a.x + 1.0) * half),
            int((a.y + 1.0) * half), ctypes.c_float(1.0 / a.z),
            int((b.x + 1.0) * half),
            int((b.y + 1.0) * half), ctypes.c_float(1.0 / b.z),
            int((c.x + 1.0) * half),
            int((c.y + 1.0) * half), ctypes.c_float(1.0 / c.z))

class Mesh(object):
    """Triangle mesh with an explicit edge list, drawn as line art.

    ``data`` is a mapping with three keys:

    * ``"vertices"`` -- sequence of 3-component coordinates.
    * ``"edges"`` -- sequence of ``((v0, v1), hard)`` entries, where v0/v1
      are vertex indices and ``hard`` marks edges that are always drawn.
    * ``"polys"`` -- sequence of faces, each a triple of vertex indices.
    """

    def __init__(self, data):
        # Build a real list: a lazy map object would be exhausted after the
        # first apply_xform() pass on Python 3.
        self.vertices = [Vector(v) for v in data["vertices"]]
        self.edges = data["edges"]
        self.faces = data["polys"]
        # Maps a directed vertex-index pair to the faces touching that edge.
        # Both orientations are stored so lookups work either way round.
        self.edgemap = {}
        for face in self.faces:
            assert len(face) == 3
            a, b, c = face
            for edge in ((a, b), (b, c), (c, a),
                         (b, a), (c, b), (a, c)):
                self.edgemap.setdefault(edge, []).append(face)

    def apply_xform(self):
        """Transform every vertex with the current pylase transform.

        Fills ``self.xformed_vertices`` with one entry per vertex: a
        transformed Vector, or None when the z returned by
        ``ol.transformVertex4`` is below 0.1.
        NOTE(review): presumably that cut-off rejects points at or behind
        the near plane -- confirm.
        """
        self.xformed_vertices = []
        for v in self.vertices:
            x, y, z = v.v
            _, _, zz, _ = ol.transformVertex4(x, y, z, 1.0)
            if zz < 0.1:
                self.xformed_vertices.append(None)
            else:
                self.xformed_vertices.append(Vector(ol.transformVertex3(*v.v)))

    def is_border(self, v0, v1):
        """Return True when edge (v0, v1) separates faces of opposite facing.

        Only edges shared by exactly two faces qualify, and only when all
        vertices of both faces survived apply_xform().  The facing of a
        face is the sign of the z component of the cross product of two of
        its transformed edge vectors.
        """
        s = (v0, v1)
        if s not in self.edgemap:
            return False
        faces = self.edgemap[s]
        if len(faces) != 2:
            return False
        facing = []
        for face in faces:
            a, b, c = [self.xformed_vertices[i] for i in face]
            if a is None or b is None or c is None:
                return False
            # Vector * Vector is the cross product.
            normal = (b - a) * (c - a)
            facing.append(normal.z > 0)
        return facing[0] != facing[1]

    def draw_edges(self, color):
        """Draw every hard edge and every border edge as a two-point strip.

        Edges with an endpoint that apply_xform() rejected are skipped.
        """
        for (i0, i1), hard in self.edges:
            if not hard and not self.is_border(i0, i1):
                continue
            p0 = self.xformed_vertices[i0]
            p1 = self.xformed_vertices[i1]
            if p0 is None or p1 is None:
                continue
            ol.begin(ol.LINESTRIP)
            ol.vertex2Z((p0.x, p0.y, 1.0 / p0.z), color)
            ol.vertex2Z((p1.x, p1.y, 1.0 / p1.z), color)
            ol.end()

    def render_z(self, zbuffer):
        """Rasterise all front-facing, fully transformed faces into zbuffer."""
        for face in self.faces:
            assert len(face) == 3
            a, b, c = [self.xformed_vertices[i] for i in face]
            if a is None or b is None or c is None:
                continue
            normal = (b - a) * (c - a)
            # back-face cull
            if normal.z < 0:
                continue
            zbuffer.render_triangle(a, b, c)

class MeshCollection(object):
    """Named meshes loaded from a JSON file.

    The file must contain a JSON object mapping mesh names to mesh
    descriptions; each description is turned into a Mesh.  Meshes can be
    fetched either as ``collection["name"]`` or ``collection.name``.
    """

    def __init__(self, filename):
        # Close the file as soon as it is parsed instead of leaking the handle.
        with open(filename) as f:
            j = json.load(f)

        self.meshes = {}
        # items() works on both Python 2 and 3 (iteritems() is 2-only).
        for k, v in j.items():
            self.meshes[k] = Mesh(v)

    def __getitem__(self, i):
        """Return the mesh named ``i``; raises KeyError if it is unknown."""
        return self.meshes[i]

    def __getattr__(self, i):
        """Return the mesh named ``i`` for attribute-style access.

        Only invoked when normal attribute lookup fails.  Raises
        AttributeError for unknown names, as the attribute protocol
        requires, so getattr() with a default and hasattr() behave
        correctly.
        """
        try:
            # Go through __dict__ so a missing 'meshes' attribute (e.g. on
            # an instance created without __init__) cannot recurse back
            # into __getattr__.
            return self.__dict__["meshes"][i]
        except KeyError:
            raise AttributeError(
                "%r object has no attribute %r" % (type(self).__name__, i))

def load_mesh(f):
    """Build and return a MeshCollection from the file name ``f``."""
    collection = MeshCollection(f)
    return collection
