//--------------------------------------------------------------------------------------
// Shader resources and constant buffers
//--------------------------------------------------------------------------------------
// Texture and sampler bindings (t0, s0-s3). None of these are sampled by the code
// in this file (the scene is generated procedurally by ray marching); presumably
// they are declared to match a binding layout shared with other shaders — TODO confirm.
// NOTE(review): filter/address modes are implied by the names only; the actual
// sampler states are configured by the host.
Texture2D tex0 : register(t0);
SamplerState sampLinearClamp : register(s0);
SamplerState sampNearestClamp : register(s1);
SamplerState sampLinearWrap : register(s2);
SamplerState sampNearestWrap : register(s3);

// Camera / transform constants bound to slot b3.
//   g_matWorld      - not used for rendering by the code visible here.
//   g_matView       - PS uses its upper-left 3x3 part to orient the camera rays.
//   g_matProjection - not referenced by the code visible here.
//   g_vecEyePoint   - PS uses the negation of .xyz as the ray origin.
// NOTE(review): the b2 buffer below is also named GlobalRenderData. Duplicate cbuffer
// names may be rejected by some HLSL compilers and are ambiguous for name-based
// reflection; if the host binds purely by slot, consider renaming one of them.
cbuffer GlobalRenderData : register(b3)
{
    matrix g_matWorld;
    matrix g_matView;
    matrix g_matProjection;
    float4 g_vecEyePoint;
};

// Generic per-effect parameters bound to slot b2. Their meaning is defined by the
// host; only two are read by the code visible here:
//   g_fFloat1 - overall brightness multiplier applied to the final colour in PS.
//   g_fFloat2 - strength of the two directional lights in PS.
// g_fTime, g_fFloat3..g_fFloat6 and g_fSeed are not referenced here.
// NOTE(review): this buffer shares the name GlobalRenderData with the b3 buffer
// above; see the note there about duplicate cbuffer names.
cbuffer GlobalRenderData : register(b2)
{
    float g_fTime;
    float g_fFloat1;
    float g_fFloat2; // Light strength
    float g_fFloat3;
    float g_fFloat4;
    float g_fFloat5;
    float g_fFloat6;
    float g_fSeed;
};


//--------------------------------------------------------------------------------------
// Per-vertex input to VS. VS ignores vecPosition and derives the clip-space
// position entirely from vecTexCoord (assumed to span [0, 1] — TODO confirm).
struct VS_INPUT
{
    float3 vecPosition : POSITION;
    float2 vecTexCoord : TEXCOORD0;
};

// VS -> PS interpolants.
//   vecPosition - clip-space position written by VS.
//   vecTexCoord - screen coordinate remapped by VS to [-1, 1] with Y flipped;
//                 PS uses it to build the camera ray for the pixel.
struct PS_INPUT
{
    float4 vecPosition : SV_POSITION;
    float2 vecTexCoord : TEXCOORD0;
};


//--------------------------------------------------------------------------------------
// Ray-march tuning (used by rayMarch).
// Each step advances by this fraction of |field()|; < 1 under-steps for safety.
#define RM_STEP_MULTIPLIER  0.75
// A ray counts as a hit once |field()| drops below this value.
#define RM_TOLERANCE        0.01
// Step budget; a ray that has not hit after this many steps is reported as a miss.
#define RM_MAX_STEPS        100

// Offset used for the central-difference gradient in getNormal.
#define NORMAL_EPSILON      0.05

// Ambient occlusion (getAmbientOcclusion): number of samples along the normal and
// the spacing between consecutive samples.
#define AO_ITERATIONS       10
#define AO_STEP             0.3


// Signed distance from pos to a sphere of radius `size` centred at the origin:
// negative inside, zero on the surface, positive outside.
float sdSphere(float3 pos, float size)
{
    const float distFromCentre = length(pos);
    return distFromCentre - size;
}


// Signed distance from pos to an axis-aligned box centred at the origin with
// half-extents `size`: negative inside, positive outside.
float sdBox(float3 pos, float3 size)
{
    // Per-axis distance past each face (negative components are inside that slab).
    const float3 q = abs(pos) - size;

    const float outsideDist = length(max(q, 0.0));               // > 0 only outside
    const float insideDist  = min(max(q.x, max(q.y, q.z)), 0.0); // <= 0 only inside
    return insideDist + outsideDist;
}


// Signed distance from p to a capped cylinder centred at the origin whose axis is
// the z axis: h.x is the radius (in xy) and h.y the half-height (along z).
float sdCappedCylinder(float3 p, float2 h)
{
    const float radialDist = length(p.xy);
    const float2 excess = abs(float2(radialDist, p.z)) - h;

    const float outsideDist = length(max(excess, 0.0));
    const float insideDist  = min(max(excess.x, excess.y), 0.0);
    return insideDist + outsideDist;
}


// Domain repetition: folds pos into a cell of size `range` centred on the origin,
// so a primitive evaluated on the result repeats every `range` units.
// abs() is applied before the (truncating) % so the wrapped value stays in
// [0, range); as a side effect, cells on the negative side of each axis are
// mirror images of those on the positive side.
float3 repeat(float3 pos, float3 range)
{
    const float3 halfRange = 0.5 * range;
    const float3 wrapped = abs(pos) % range;
    return wrapped - halfRange;
}


// Scene distance field. It combines several repeated box lattices and then carves
// out a central cavity (a sphere at z = 900 plus a long z-aligned cylinder leading
// into it). The combined CSG value is returned negated. rayMarch only uses its
// absolute value, and getNormal uses its gradient.
float field(float3 pos)
{
    // Four box lattices with different repetition periods.
    const float latticeA = sdBox(repeat(pos, float3(4.0, 4.0, 4.0)), float3(1.7, 1.7, 1.7));
    const float latticeB = sdBox(repeat(pos, float3(3.0, 3.0, 3.0)), float3(1.3, 1.3, 1.4));
    const float latticeC = sdBox(repeat(pos, float3(1.0, 2.1, 1.2)), float3(0.3, 0.2, 0.3));
    const float latticeD = sdBox(repeat(pos, float3(6.0, 6.3, 6.0)), float3(2.5, 2.2, 2.2));

    // (B minus A) unioned with (C intersect D).
    const float bMinusA = max(-latticeA, latticeB);
    const float cAndD = max(latticeC, latticeD);
    const float structure = min(bMinusA, cAndD);

    // Cavity: sphere of radius 30 at z = 900, plus a radius-10 cylinder with
    // half-height 1000 centred at z = -100 (so it spans z = -1100 .. 900).
    const float sphere = sdSphere(pos - float3(0, 0, 900), 30.0);
    const float shaft = sdCappedCylinder(pos - float3(0, 0, -100), float2(10, 1000));
    const float cavity = min(sphere, shaft);

    // Subtract the cavity from the structure, then negate.
    return -max(-cavity, structure);
}


// Surface normal at pos: the normalised central-difference gradient of field(),
// sampled NORMAL_EPSILON either side of pos on each axis (six field evaluations).
float3 getNormal(float3 pos)
{
    const float2 offset = float2(NORMAL_EPSILON, 0);

    float3 gradient;
    gradient.x = field(pos + offset.xyy) - field(pos - offset.xyy);
    gradient.y = field(pos + offset.yxy) - field(pos - offset.yxy);
    gradient.z = field(pos + offset.yyx) - field(pos - offset.yyx);
    return normalize(gradient);
}


// Approximate ambient occlusion at pos.
// The field is sampled at AO_ITERATIONS points along `norm`, spaced AO_STEP apart.
// At each sample, the shortfall between the sample's offset and |field| is
// subtracted from 1, weighted by 1 / 2^i. The result is not clamped.
// Samples are taken from the farthest (i = AO_ITERATIONS) down to the nearest (i = 1).
float getAmbientOcclusion(float3 pos, float3 norm)
{
    float occlusion = 1.0;
    for (int k = AO_ITERATIONS; k >= 1; --k)
    {
        const float i = float(k);
        const float expectedDist = i * AO_STEP;
        const float actualDist = abs(field(pos + norm * i * AO_STEP));
        occlusion -= (expectedDist - actualDist) / pow(2.0, i);
    }
    return occlusion;
}


// Sphere-traces field() from pos along rayDir (the PS caller passes a normalised
// direction; the step length scales with |rayDir|).
// Each step advances by RM_STEP_MULTIPLIER * |field|. The position is advanced
// before the hit test, so the reported hit point includes that final step.
// Returns:
//   x - steps used / RM_MAX_STEPS on a hit (always < 1), or 1 on a miss
//   y - distance from pos to the final position (0 on a miss)
//   z - |field| at the hit (0 on a miss)
float3 rayMarch(float3 pos, float3 rayDir)
{
    float3 p = pos;
    for (int step = 0; step < RM_MAX_STEPS; ++step)
    {
        const float d = abs(field(p));
        p += d * RM_STEP_MULTIPLIER * rayDir;

        if (d < RM_TOLERANCE)
        {
            const float progress = float(step) / float(RM_MAX_STEPS);
            const float travelled = length(pos - p);
            return float3(progress, travelled, d);
        }
    }

    // Step budget exhausted without a hit.
    return float3(1, 0, 0);
}



//--------------------------------------------------------------------------------------
// Vertex Shader
//--------------------------------------------------------------------------------------
PS_INPUT VS( VS_INPUT input )
{
    // Remap the UV (assumed [0, 1] — TODO confirm) to [-1, 1], flipping Y so that
    // v = 0 maps to the top of the screen.
    const float2 screenCoord = input.vecTexCoord.xy * float2(2, -2) + float2(-1, 1);

    // Stretch the quad vertically by 1280/720. After clipping, the visible Y range
    // of screenCoord is +/-720/1280, which gives PS aspect-correct ray directions.
    // NOTE(review): hard-coded for a 16:9 render target.
    const float2 aspectStretch = float2(1, 1280.0/720.0);

    PS_INPUT result;
    result.vecTexCoord = screenCoord;
    result.vecPosition = float4(screenCoord * aspectStretch, 0, 1);
    return result;
}


//--------------------------------------------------------------------------------------
// Pixel Shader
//--------------------------------------------------------------------------------------
// Ray-marches the scene for this pixel and shades any hit with ambient light,
// two directional lights (scaled by g_fFloat2) and ambient occlusion.
// The final colour is scaled by g_fFloat1 and faded linearly to black as the hit
// distance approaches 300 units.
float4 PS(PS_INPUT input) : SV_Target
{
    // Camera ray. The eye sits one unit behind the image plane (z = 0), so the
    // eye-to-plane distance controls the field of view.
    const float3 origin = -g_vecEyePoint.xyz;
    const float3 target = float3(-input.vecTexCoord.xy, 0);
    const float3 eye = float3(0, 0, -1.00); // This controls the field-of-view

    // Orient the ray with the rotational (upper-left 3x3) part of the view matrix.
    // The explicit cast spells out the truncation HLSL would otherwise apply
    // implicitly (with a warning) when multiplying a 4x4 matrix by a float3.
    const float3 ray = mul((float3x3)g_matView, normalize(target - eye));

    // res = (steps used / RM_MAX_STEPS, distance travelled, final |field|);
    // res.x == 1 means the ray missed.
    const float3 res = rayMarch(origin, ray);

    // Start with some ambient light colour
    float3 colour = float3(0.1, 0.1, 0.1);
    if (res.x < 1.0)
    {
        // Hit point and normal are only needed for hits. Computing them here avoids
        // getNormal's six field() evaluations on missed pixels.
        const float3 pos = origin + ray * res.y;
        const float3 normal = getNormal(pos);

        // Add white-ish light shining down and forwards
        {
            const float3 lightNormal = normalize(float3(0.0, -1.0, -0.5));
            const float v = max(0.0, dot(normal, -lightNormal));
            colour += v * float3(0.8, 0.8, 0.9) * g_fFloat2;
        }

        // Add dark red-ish light shining straight up (it lights downward-facing surfaces)
        {
            const float3 lightNormal = normalize(float3(0.0, 1.0, 0.0));
            const float v = max(0.0, dot(normal, -lightNormal));
            colour += v * float3(0.5, 0.4, 0.4) * g_fFloat2;
        }

        // Multiply by ambient occlusion to darken the shadowed areas
        colour *= getAmbientOcclusion(pos, normal);
    }

    // Fade to black with distance travelled (fully black at 300 units).
    // NOTE(review): a miss reports res.y == 0, so missed pixels keep the ambient
    // colour at full strength instead of fading like distant hits — confirm intended.
    const float distFade = 1 - clamp(res.y / 300, 0, 1);
    return float4(colour.rgb * g_fFloat1 * distFade, 1.0);
}
