
#version 330

// Trace pass for the cross-dichroic prism scene: main() reads one ray per
// fragment, intersects it with the scene and writes the updated ray state.

#define PI      3.1415926536
#define PI_HALF 1.5707963268

// Per-ray state, read in main(): xy = ray position, zw = ray direction.
uniform sampler2D PosData;
// Per-ray random number generator state (four lanes, advanced by rand()).
uniform sampler2D RngData;
// rgb is passed to sample() as the throughput, w as the wavelength.
uniform sampler2D RgbData;

// Used by crossDichroicPrismIntersect() to rotate the prism corners.
// NOTE(review): presumably (cos, sin) of the prism angle -- confirm with the host code.
uniform vec2 PrismRotation;
// Half-extents of the prism rectangle before rotation.
uniform vec2 PrismScale;

// Texture coordinate used to fetch this fragment's ray from the state textures.
in vec2 vTexCoord;

// Updated ray state, written in the same layout as the matching inputs.
layout(location=0) out vec4 posOut;
layout(location=1) out vec4 rngOut;
layout(location=2) out vec4 rgbOut;


// Advances four independent linear congruential generators (one per lane of
// `state`) and folds them into a single value in [0, 1).
// The constants satisfy modulus = mult*quot + rem for every lane, so the
// product mult*state is reduced without ever forming the full product.
float rand(inout vec4 state)
{
    const vec4 quot    = vec4(   1225.0,    1585.0,    2457.0,    2098.0);
    const vec4 rem     = vec4(   1112.0,     367.0,      92.0,     265.0);
    const vec4 mult    = vec4(   3423.0,    2646.0,    1707.0,    1999.0);
    const vec4 modulus = vec4(4194287.0, 4194277.0, 4194191.0, 4194167.0);

    vec4 hi   = floor(state/quot);
    vec4 next = mult*(state - hi*quot) - hi*rem;
    // Lanes that went negative are wrapped back by adding the modulus.
    vec4 wrap = (1.0 - sign(next))*0.5*modulus;
    state = next + wrap;

    // Combine the lanes with alternating signs and keep the fractional part.
    return fract(dot(state/modulus, vec4(1.0, -1.0, 1.0, -1.0)));
}

// A 2D ray as produced by unpackRay().
struct Ray {
    vec2 pos;     // ray origin
    vec2 dir;     // normalized direction
    vec2 invDir;  // component-wise reciprocal of the direction, used by the slab tests
    vec2 dirSign; // component-wise sign of the direction
};
// Closest-hit record that the *Intersect() routines refine in place.
struct Intersection {
    float tMin; // hits closer than this are ignored
    float tMax; // distance of the closest hit found so far
    vec2  n;    // normal at the closest hit
    float mat;  // material id at the closest hit (interpreted by sample())
};

// Two-term Cauchy dispersion model: n(lambda) = A + B / lambda^2.
//   ior    - refractive index at the reference wavelength
//   abbe   - Abbe number; smaller values give stronger dispersion
//   lambda - wavelength; divided by 1000 before use
//            (NOTE(review): presumably nanometres -> micrometres, confirm)
// A is chosen so that the result equals `ior` where (lambda/1000)^2 == 0.34522792.
float cauchy_ior( float ior, float abbe, float lambda ) {
    float dispersion = (ior - 1.0) / abbe * 0.52345;
    float base = ior - dispersion / 0.34522792;
    float scaledLambda = lambda / 1000.0;
    return base + dispersion / (scaledLambda * scaledLambda);
}

// Three-term Sellmeier equation as written here:
//   1 + sum_i b_i*l^2 / (l^2 - c_i),  with l = lambda*1e-3.
float sellmeierIor(vec3 b, vec3 c, float lambda)
{
    float scaled = lambda*1e-3;
    float lSq = scaled*scaled;
    vec3 terms = (b*lSq)/(lSq - c);
    return 1.0 + dot(terms, vec3(1.0));
}

// Hyperbolic tangent with saturation for large arguments.
float tanh(float x)
{
    bool saturated = abs(x) > 10.0;
    if (!saturated) {
        float e = exp(-2.0*x);
        return (1.0 - e)/(1.0 + e);
    }
    /* Beyond |x| = 10 return +/-1 directly to prevent overflow problems */
    return sign(x);
}

// Inverse hyperbolic tangent: 0.5 * ln((1 + x) / (1 - x)).
// No range check is performed on x.
float atanh(float x)
{
    float ratio = (1.0 + x)/(1.0 - x);
    return 0.5*log(ratio);
}

// Unpolarized Fresnel reflectance at a dielectric interface.
//   eta       - ratio of refractive indices used in Snell's law
//   cosThetaI - cosine of the incident angle (callers pass a non-negative value)
//   cosThetaT - receives the cosine of the transmitted angle, or 0 on
//               total internal reflection
// Returns the reflectance; 1.0 on total internal reflection.
float dielectricReflectance(float eta, float cosThetaI, out float cosThetaT)
{
    float sinThetaTSq = eta*eta*(1.0 - cosThetaI*cosThetaI);
    if (!(sinThetaTSq > 1.0)) {
        cosThetaT = sqrt(1.0 - sinThetaTSq);

        float etaCosI = eta*cosThetaI;
        float etaCosT = eta*cosThetaT;
        float rs = (etaCosI - cosThetaT)/(etaCosI + cosThetaT);
        float rp = (etaCosT - cosThetaI)/(etaCosT + cosThetaI);
        return (rs*rs + rp*rp)*0.5;
    }

    // Total internal reflection: nothing is transmitted.
    cosThetaT = 0.0;
    return 1.0;
}

// Samples a diffuse outgoing direction in the local frame (y along the normal).
// The tangential component is drawn uniformly from [-1, 1); the normal
// component completes the unit vector and takes the sign of wi.y, so the
// result lies on the same side of the surface as wi.
vec2 sampleDiffuse(inout vec4 state, vec2 wi)
{
    float tangential = rand(state)*2.0 - 1.0;
    float normalMag = sqrt(1.0 - tangential*tangential);
    return vec2(tangential, normalMag*sign(wi.y));
}
// Perfect mirror reflection in the local frame: the tangential (x) component
// is negated and the normal (y) component is kept.
vec2 sampleMirror(vec2 wi) {
    vec2 wo = wi;
    wo.x = -wo.x;
    return wo;
}

// wi points toward the ray origin.
// With probability `reflectance` the ray is mirror-reflected; otherwise it
// continues straight through the surface (-wi).
vec2 sampleReflectance(inout vec4 state, vec2 wi, float reflectance)
{
    bool reflected = rand( state ) < reflectance;
    if( reflected )
        return sampleMirror( wi );
    return -wi;
}

// Samples reflection or refraction at a smooth dielectric interface.
// The sign of wi.y selects which index ratio is used (ior when wi.y < 0,
// 1/ior otherwise); the Fresnel reflectance is used as the probability of
// choosing the reflected direction.
vec2 sampleDielectric(inout vec4 state, vec2 wi, float ior)
{
    float eta = 1.0/ior;
    if (wi.y < 0.0)
        eta = ior;

    float cosThetaT;
    float fresnel = dielectricReflectance(eta, abs(wi.y), cosThetaT);

    bool reflected = rand(state) < fresnel;
    if (reflected)
        return vec2(-wi.x, wi.y);

    // Refracted direction: scaled tangential part, opposite side of the surface.
    return vec2(-wi.x*eta, -cosThetaT*sign(wi.y));
}


// Samples a microfacet normal angle restricted to [theta0, theta1].
// The CDF used is tanh(theta / (2*sigma^2)); xi in [0, 1] interpolates between
// the CDF values at the two bounds and the result is mapped back through the
// inverse CDF.
float sampleVisibleNormal(float sigma, float xi, float theta0, float theta1)
{
    float variance = sigma*sigma;
    float invVariance = 1.0/variance;

    float cdfLo = tanh(theta0*0.5*invVariance);
    float cdfHi = tanh(theta1*0.5*invVariance);
    float cdfSample = cdfLo + (cdfHi - cdfLo)*xi;

    return 2.0*variance*atanh(cdfSample);
}
// Rough mirror: reflects wi about a sampled microfacet normal.
// The throughput is zeroed when the reflected direction ends up below the
// surface (wo.y < 0).
vec2 sampleRoughMirror(inout vec4 state, vec2 wi, inout vec3 throughput, float sigma)
{
    float thetaI = asin(clamp(wi.x, -1.0, 1.0));
    float lower = max(thetaI - PI_HALF, -PI_HALF);
    float upper = min(thetaI + PI_HALF,  PI_HALF);

    float thetaM = sampleVisibleNormal(sigma, rand(state), lower, upper);
    vec2 microNormal = vec2(sin(thetaM), cos(thetaM));

    vec2 wo = microNormal*(dot(wi, microNormal)*2.0) - wi;
    if (wo.y < 0.0)
        throughput = vec3(0.0);
    return wo;
}
// Rough dielectric: samples a microfacet normal, then reflects or refracts wi
// about it, using the Fresnel reflectance as the reflection probability.
vec2 sampleRoughDielectric(inout vec4 state, vec2 wi, float sigma, float ior)
{
    float thetaI = asin(min(abs(wi.x), 1.0));
    float lower = max(thetaI - PI_HALF, -PI_HALF);
    float upper = min(thetaI + PI_HALF,  PI_HALF);

    float thetaM = sampleVisibleNormal(sigma, rand(state), lower, upper);
    vec2 microNormal = vec2(sin(thetaM), cos(thetaM));

    float cosI = dot(wi, microNormal);
    bool backFacing = cosI < 0.0;

    float cosThetaT;
    float etaM = backFacing ? ior : 1.0/ior;
    float fresnel = dielectricReflectance(etaM, abs(cosI), cosThetaT);
    // Keep the transmitted cosine on the side matching the incident cosine.
    if (backFacing)
        cosThetaT = -cosThetaT;

    bool reflected = rand(state) < fresnel;
    if (reflected)
        return 2.0*cosI*microNormal - wi;

    return (etaM*cosI - cosThetaT)*microNormal - etaM*wi;
}

// Slab test of the ray against an axis-aligned box.
//   center, radius - box centre and half-extents
//   matId          - material id stored in isect on a hit
//   isect          - closest-hit record; updated only if the box is hit
//                    within [isect.tMin, isect.tMax]
// When the clipped entry distance equals isect.tMin (the ray starts inside the
// box) the exit point is reported, otherwise the entry point.
// The stored normal is the outward normal of the face that was hit.
void bboxIntersect(Ray ray, vec2 center, vec2 radius, float matId, inout Intersection isect)
{
    vec2 pos = ray.pos - center;
    float tx1 = (-radius.x - pos.x)*ray.invDir.x;
    float tx2 = ( radius.x - pos.x)*ray.invDir.x;
    float ty1 = (-radius.y - pos.y)*ray.invDir.y;
    float ty2 = ( radius.y - pos.y)*ray.invDir.y;
    
    float minX = min(tx1, tx2), maxX = max(tx1, tx2);
    float minY = min(ty1, ty2), maxY = max(ty1, ty2);
 
    float tmin = max(isect.tMin, max(minX, minY));
    float tmax = min(isect.tMax, min(maxX, maxY));
 
    if (tmax >= tmin) {
        isect.tMax = (tmin == isect.tMin) ? tmax : tmin;

        // The hit distance is bit-identical to one of the four slab distances,
        // so exact comparison identifies the face. tx1/ty1 belong to the
        // -x/-y faces, tx2/ty2 to the +x/+y faces.
        if (isect.tMax == tx1)
            isect.n = vec2(-1.0,  0.0);
        else if (isect.tMax == tx2)
            isect.n = vec2( 1.0,  0.0);
        else if (isect.tMax == ty1)
            isect.n = vec2( 0.0, -1.0); // was (0, 1): the -y face must point down
        else
            isect.n = vec2( 0.0,  1.0);

        isect.mat = matId;
    }
}
// Intersects the ray with a circle and records the nearest root that lies
// strictly inside (isect.tMin, isect.tMax). The near root is tried first and
// the far root is used as a fallback.
void sphereIntersect(Ray ray, vec2 center, float radius, float matId, inout Intersection isect)
{
    vec2 rel = ray.pos - center;
    float halfB = dot(rel, ray.dir);
    float c = dot(rel, rel) - radius*radius;
    float disc = halfB*halfB - c;
    if (!(disc >= 0.0))
        return; // no real roots: the ray misses the circle

    float root = sqrt(disc);
    float t = -halfB - root;
    if (t <= isect.tMin || t >= isect.tMax)
        t = -halfB + root;

    if (!(t > isect.tMin && t < isect.tMax))
        return;

    isect.tMax = t;
    isect.n = normalize(rel + ray.dir*t);
    isect.mat = matId;
}
// Intersects the ray with the line segment a-b.
//   matId - material id stored in isect on a hit
//   isect - closest-hit record; updated only when the hit distance lies in
//           [isect.tMin, isect.tMax) and the hit point lies on the segment
// The stored normal is the segment direction rotated by +90 degrees, normalized.
void lineIntersect(Ray ray, vec2 a, vec2 b, float matId, inout Intersection isect)
{
    vec2 sT = b - a;
    vec2 sN = vec2(-sT.y, sT.x);

    // A zero denominator means the ray is parallel to the segment (or the
    // segment is degenerate). Dividing would yield inf/NaN, and a NaN would
    // slip past every rejection test below, so bail out early.
    float denom = dot(sN, ray.dir);
    if (denom == 0.0)
        return;

    float t = dot(sN, a - ray.pos)/denom;
    float u = dot(sT, ray.pos + ray.dir*t - a);

    // Written as a negated acceptance test so that any NaN is rejected.
    if (!(t >= isect.tMin && t < isect.tMax && u >= 0.0 && u <= dot(sT, sT)))
        return;
    
    isect.tMax = t;
    isect.n = normalize(sN);
    isect.mat = matId;
}
// Triangular prism built from three line segments around `center`; `radius`
// scales the corner offsets.
void prismIntersect(Ray ray, vec2 center, float radius, float matId, inout Intersection isect)
{
    vec2 top   = center + vec2(   0.0,  1.0)*radius;
    vec2 right = center + vec2( 0.866, -0.5)*radius;
    vec2 left  = center + vec2(-0.866, -0.5)*radius;

    lineIntersect(ray, top,   right, matId, isect);
    lineIntersect(ray, right, left,  matId, isect);
    lineIntersect(ray, left,  top,   matId, isect);
}
// Right-angle triangular prism with its right-angle corner at `center` and
// legs of length `radius` along +x and +y.
void crossPrismIntersect(Ray ray, vec2 center, float radius, float matId, inout Intersection isect)
{
    vec2 upper  = center + vec2( 0., 1.)*radius;
    vec2 side   = center + vec2( 1., 0.)*radius;
    vec2 corner = center + vec2( 0., 0.)*radius;

    lineIntersect(ray, upper,  side,   matId, isect);
    lineIntersect(ray, side,   corner, matId, isect);
    lineIntersect(ray, corner, upper,  matId, isect);
}
// Cross-dichroic prism: a rotated rectangle plus its two diagonals.
// The rectangle's half-extents come from the PrismScale uniform and its
// orientation from PrismRotation; the `radius` argument is not used.
// The four outer edges report `matId`; the diagonals report matId+1
// (corner 1 -> corner 3) and matId+2 (corner 2 -> corner 4).
void crossDichroicPrismIntersect(Ray ray, vec2 center, float radius, float matId, inout Intersection isect)
{
    // Second row of the 2x2 transform whose first row is PrismRotation.
    vec2 perp = vec2( -PrismRotation.y, PrismRotation.x );

    // Unrotated corners of the rectangle.
    vec2 local1 = vec2( -PrismScale.x,  PrismScale.y );
    vec2 local2 = vec2(  PrismScale.x,  PrismScale.y );
    vec2 local3 = vec2(  PrismScale.x, -PrismScale.y );
    vec2 local4 = vec2( -PrismScale.x, -PrismScale.y );

    // Rotated corners in scene space.
    vec2 p1 = center + vec2( dot( PrismRotation, local1 ), dot( perp, local1 ) );
    vec2 p2 = center + vec2( dot( PrismRotation, local2 ), dot( perp, local2 ) );
    vec2 p3 = center + vec2( dot( PrismRotation, local3 ), dot( perp, local3 ) );
    vec2 p4 = center + vec2( dot( PrismRotation, local4 ), dot( perp, local4 ) );

    // Outer faces.
    lineIntersect(ray, p1, p2, matId, isect);
    lineIntersect(ray, p2, p3, matId, isect);
    lineIntersect(ray, p3, p4, matId, isect);
    lineIntersect(ray, p4, p1, matId, isect);

    // Diagonals, each with its own material id.
    lineIntersect(ray, p1, p3, matId+1, isect);
    lineIntersect(ray, p2, p4, matId+2, isect);
}

// Interval of ray distances covered by a shape, with the normals at both ends.
// A segment with tNear > tFar is treated as empty by segmentCollapse().
struct Segment {
    float tNear, tFar; // entry and exit distances along the ray
    vec2  nNear, nFar; // normals at the entry and exit points
};

// Overlap of two segments: the later entry and the earlier exit, each with the
// normal of the segment that supplied it (b wins ties).
Segment segmentIntersection(Segment a, Segment b) {
    Segment overlap;
    overlap.tNear = max(a.tNear, b.tNear);
    overlap.tFar  = min(a.tFar,  b.tFar);

    if (a.tNear > b.tNear)
        overlap.nNear = a.nNear;
    else
        overlap.nNear = b.nNear;

    if (a.tFar < b.tFar)
        overlap.nFar = a.nFar;
    else
        overlap.nFar = b.nFar;

    return overlap;
}
// Removes segment b from segment a and returns one remaining piece.
// Subtraction can leave a piece in front of b and a piece behind it; when both
// exist, the front piece is returned if it ends at or beyond tMin, otherwise
// the back piece. Normals taken from b are flipped.
Segment segmentSubtraction(Segment a, Segment b, float tMin) {
    // Nothing to remove when either span is empty or the two do not overlap.
    bool aEmpty   = a.tNear >= a.tFar;
    bool bEmpty   = b.tNear >= b.tFar;
    bool disjoint = a.tFar <= b.tNear || a.tNear >= b.tFar;
    if (aEmpty || bEmpty || disjoint)
        return a;
    
    Segment front = Segment(a.tNear, b.tNear, a.nNear, -b.nNear);
    Segment back  = Segment(b.tFar,  a.tFar, -b.nFar,   a.nFar);
    bool frontValid = front.tNear <= front.tFar;
    bool backValid  = back.tNear  <= back.tFar;
    
    if (frontValid && (!backValid || front.tFar >= tMin))
        return front;
    return back;
}
// Converts a segment into a closest-hit update.
// The segment is clipped to [isect.tMin, isect.tMax]. If its entry point lies
// beyond tMin the entry is the hit; otherwise, if its exit lies before tMax,
// the exit is the hit. An empty clipped segment leaves isect untouched.
void segmentCollapse(Segment segment, float matId, inout Intersection isect) {
    float tEnter = max(segment.tNear, isect.tMin);
    float tExit  = min(segment.tFar,  isect.tMax);
    
    if (!(tEnter <= tExit))
        return;

    if (tEnter > isect.tMin) {
        isect.tMax = tEnter;
        isect.n = segment.nNear;
        isect.mat = matId;
    } else if (tExit < isect.tMax) {
        isect.tMax = tExit;
        isect.n = segment.nFar;
        isect.mat = matId;
    }
}

// Segment of the ray inside the horizontal band |py - y| <= radius.
// The normals point along -/+ the ray's vertical travel direction.
Segment horzSpanIntersect(Ray ray, float y, float radius) {
    float tCenter = (y - ray.pos.y)*ray.invDir.y;
    float tHalf   = ray.dirSign.y*radius*ray.invDir.y;

    vec2 entryNormal = vec2(0.0, -ray.dirSign.y);
    vec2 exitNormal  = vec2(0.0,  ray.dirSign.y);
    return Segment(tCenter - tHalf, tCenter + tHalf, entryNormal, exitNormal);
}
// Segment of the ray inside the vertical band |px - x| <= radius.
// The normals point along -/+ the ray's horizontal travel direction.
Segment vertSpanIntersect(Ray ray, float x, float radius) {
    float tCenter = (x - ray.pos.x)*ray.invDir.x;
    float tHalf   = ray.dirSign.x*radius*ray.invDir.x;

    vec2 entryNormal = vec2(-ray.dirSign.x, 0.0);
    vec2 exitNormal  = vec2( ray.dirSign.x, 0.0);
    return Segment(tCenter - tHalf, tCenter + tHalf, entryNormal, exitNormal);
}
// Segment of the ray inside an axis-aligned box, formed as the overlap of a
// horizontal and a vertical band.
Segment boxSegmentIntersect(Ray ray, vec2 center, vec2 radius) {
    Segment rows = horzSpanIntersect(ray, center.y, radius.y);
    Segment cols = vertSpanIntersect(ray, center.x, radius.x);
    return segmentIntersection(rows, cols);
}
// Segment of the ray inside a circle.
// On a hit, tNear/tFar are the two roots and nNear/nFar the corresponding
// hit offsets from the centre divided by `radius`. On a miss an empty segment
// (tNear = 1e30, tFar = -1e30) with zero normals is returned.
Segment sphereSegmentIntersect(Ray ray, vec2 center, float radius) {
    // Start from a fully initialised empty segment so that no undefined
    // normal is ever copied around on the miss path.
    Segment result = Segment(1e30, -1e30, vec2(0.0), vec2(0.0));
    
    vec2 p = ray.pos - center;
    float B = dot(p, ray.dir);
    float C = dot(p, p) - radius*radius;
    float detSq = B*B - C;
    if (detSq >= 0.0) {
        float det = sqrt(detSq);
        result.tNear = -B - det;
        result.tFar  = -B + det;
        result.nNear = (p + ray.dir*result.tNear)*(1.0/radius);
        result.nFar  = (p + ray.dir*result.tFar)*(1.0/radius);
    }
    
    return result;
}

// Biconvex lens: overlap of a horizontal band of half-height h with two
// circles of radii r1 and r2 whose centres are offset so that each surface
// crosses the axis at distance d from `center`.
void biconvexLensIntersect(Ray ray, vec2 center, float h, float d, float r1, float r2, float matId, inout Intersection isect) {
    Segment band   = horzSpanIntersect(ray, center.y, h);
    Segment first  = sphereSegmentIntersect(ray, center + vec2(r1 - d, 0.0), r1);
    Segment second = sphereSegmentIntersect(ray, center - vec2(r2 - d, 0.0), r2);

    Segment lens = segmentIntersection(segmentIntersection(band, first), second);
    segmentCollapse(lens, matId, isect);
}
// Biconcave lens: a box (horizontal band of half-height h overlapped with a
// vertical band) from which two circles of radii r2 and r1 are carved out.
void biconcaveLensIntersect(Ray ray, vec2 center, float h, float d, float r1, float r2, float matId, inout Intersection isect) {
    Segment band = horzSpanIntersect(ray, center.y, h);
    Segment slab = vertSpanIntersect(ray, center.x + 0.5*(r2 - r1), 0.5*(abs(r1) + abs(r2)) + d);
    Segment body = segmentIntersection(band, slab);

    Segment cutA = sphereSegmentIntersect(ray, center + vec2(r2 + d, 0.0), r2);
    Segment cutB = sphereSegmentIntersect(ray, center - vec2(r1 + d, 0.0), r1);

    Segment lens = segmentSubtraction(body, cutA, isect.tMin);
    lens = segmentSubtraction(lens, cutB, isect.tMin);
    segmentCollapse(lens, matId, isect);
}
// Meniscus lens: a box overlapped with a circle of radius |r1|, from which a
// circle of radius |r2| is carved out. The signs of r1 and r2 select on which
// side of `center` each circle is placed.
void meniscusLensIntersect(Ray ray, vec2 center, float h, float d, float r1, float r2, float matId, inout Intersection isect) {
    Segment band = horzSpanIntersect(ray, center.y, h);
    Segment slab = vertSpanIntersect(ray, center.x + 0.5*r2, 0.5*abs(r2) + d);
    Segment body = segmentIntersection(band, slab);

    Segment keep = sphereSegmentIntersect(ray, center + vec2(r1 - sign(r1)*d, 0.0), abs(r1));
    Segment cut  = sphereSegmentIntersect(ray, center + vec2(r2 + sign(r2)*d, 0.0), abs(r2));

    Segment lens = segmentSubtraction(segmentIntersection(body, keep), cut, isect.tMin);
    segmentCollapse(lens, matId, isect);
}
// Plano-convex lens: overlap of a box with half-extents (d, h) and a circle
// of radius |r| centred at center + (r - d, 0).
void planoConvexLensIntersect(Ray ray, vec2 center, float h, float d, float r, float matId, inout Intersection isect) {
    Segment body  = boxSegmentIntersect(ray, center, vec2(d, h));
    Segment curve = sphereSegmentIntersect(ray, center + vec2(r - d, 0.0), abs(r));

    segmentCollapse(segmentIntersection(body, curve), matId, isect);
}
// Plano-concave lens: a box (horizontal band of half-height h overlapped with
// a vertical band) from which a circle of radius |r| centred at
// center - (r + d, 0) is carved out.
void planoConcaveLensIntersect(Ray ray, vec2 center, float h, float d, float r, float matId, inout Intersection isect) {
    Segment band = horzSpanIntersect(ray, center.y, h);
    Segment slab = vertSpanIntersect(ray, center.x - 0.5*r, 0.5*abs(r) + d);
    Segment body = segmentIntersection(band, slab);

    Segment cut = sphereSegmentIntersect(ray, center - vec2(r + d, 0.0), abs(r));

    segmentCollapse(segmentSubtraction(body, cut, isect.tMin), matId, isect);
}


// Scene declaration.

// Finds the closest hit in the scene: an enclosing box with half-extents
// (1.78, 1.0) using material 0, and the cross-dichroic prism at the origin
// using materials 1 (outer faces), 2 and 3 (diagonals).
void intersect(Ray ray, inout Intersection isect) 
{
    vec2 origin = vec2(0.0);
    vec2 sceneHalfExtents = vec2(1.78, 1.0);

    bboxIntersect(ray, origin, sceneHalfExtents, 0.0, isect);
    crossDichroicPrismIntersect(ray, origin, 0.5, 1.0, isect);
}

// Step-function reflectance of the blue dichroic coating:
// 1 for wavelengths below 480, 0 otherwise.
float blueDichroicReflectance( float lambda )
{
    if (lambda < 480.0)
        return 1.0;
    return 0.0;
}
// Step-function reflectance of the red dichroic coating:
// 1 for wavelengths above 580, 0 otherwise.
float redDichroicReflectance( float lambda )
{
    if (lambda > 580.0)
        return 1.0;
    return 0.0;
}

// Chooses the outgoing local-frame direction for the material that was hit.
//   mat 1 - smooth dielectric with a wavelength-dependent index
//   mat 2 - blue dichroic surface (mirror or pass-through)
//   mat 3 - red dichroic surface (mirror or pass-through)
//   other - diffuse surface; the throughput is halved
vec2 sample(inout vec4 state, Intersection isect, float lambda, vec2 wiLocal, inout vec3 throughput) 
{
    float ior = cauchy_ior(1.33, 10.0, lambda);

    if (isect.mat == 1.0)
        return sampleDielectric(state, wiLocal, ior);

    if (isect.mat == 2.0)
        return sampleReflectance( state, wiLocal, blueDichroicReflectance( lambda ) );

    if (isect.mat == 3.0)
        return sampleReflectance( state, wiLocal, redDichroicReflectance( lambda ) );

    throughput *= vec3(0.5);
    return sampleDiffuse(state, wiLocal);
}

// Builds a Ray from packed state: xy = position, zw = direction.
// Direction components with magnitude below 1e-5 are replaced by 1e-5 so that
// the reciprocal direction is always finite.
// The reciprocal is taken from the *normalized* direction, so distances
// computed with invDir (slab tests) agree with ray.pos + ray.dir*t even when
// the stored direction is not exactly unit length.
Ray unpackRay(vec4 posDir)
{
    vec2 pos = posDir.xy;
    vec2 dir = posDir.zw;
    dir.x = abs(dir.x) < 1e-5 ? 1e-5 : dir.x; /* The nuclear option to fix NaN issues on some platforms */
    dir.y = abs(dir.y) < 1e-5 ? 1e-5 : dir.y;
    dir = normalize(dir);
    return Ray(pos, dir, 1.0/dir, sign(dir));
}

// Advances one ray by one bounce.
// Reads the ray, RNG state and throughput/wavelength for this fragment, finds
// the closest scene hit, samples the outgoing direction for the hit material
// and writes the updated state. A ray that hits nothing keeps its position and
// direction and has its throughput set to zero.
void main()
{
    // texture() replaces texture2D(), which is not part of the GLSL 3.30
    // core profile selected by "#version 330".
    vec4 posDir    = texture(PosData, vTexCoord);
    vec4 state     = texture(RngData, vTexCoord);
    vec4 rgbLambda = texture(RgbData, vTexCoord);

    Ray ray = unpackRay(posDir);
    Intersection isect;
    isect.tMin = 1e-4;
    isect.tMax = 1e30;
    isect.n    = vec2(0);
    isect.mat  = 0.;
    intersect(ray, isect);

    // Local shading frame: t is the tangent, isect.n the normal. wiLocal is
    // the reversed ray direction expressed in that frame.
    vec2 t = vec2(-isect.n.y, isect.n.x);
    vec2 wiLocal = -vec2(dot(t, ray.dir), dot(isect.n, ray.dir));
    vec2 woLocal = sample(state, isect, rgbLambda.w, wiLocal, rgbLambda.rgb);

    if (isect.tMax == 1e30) {
        // tMax was never lowered: nothing was hit.
        rgbLambda.rgb = vec3(0.0);
    } else {
        posDir.xy = ray.pos + ray.dir*isect.tMax;
        posDir.zw = woLocal.y*isect.n + woLocal.x*t;
    }

    posOut = posDir;
    rngOut = state;
    rgbOut = rgbLambda;
}
