#include <iostream>
#include <fstream>
#include <sstream>
#include <gl/glew.h>
#include "CShaderGenerator.h"

// GLSL argument lists for texture-coordinate sources, indexed by a pass's TexCoordSource.
// CreateVertexShader() splices the selected entry into a vec3(...) constructor, or into
// vec4(<entry>, 1.0) when the pass has a texture matrix applied.
const std::string gkCoordSrc[] = {
    "RawPosition.xyz",  // 0: vertex position attribute
    "RawNormal.xyz",    // 1: vertex normal attribute
    "0.0, 0.0, 0.0",    // 2: constant zero - NOTE(review): presumably an unsupported source slot; confirm
    "0.0, 0.0, 0.0",    // 3: constant zero - NOTE(review): presumably an unsupported source slot; confirm
    "RawTex0.xy, 1.0",  // 4-11: UV attributes 0-7, padded to three components with 1.0
    "RawTex1.xy, 1.0",
    "RawTex2.xy, 1.0",
    "RawTex3.xy, 1.0",
    "RawTex4.xy, 1.0",
    "RawTex5.xy, 1.0",
    "RawTex6.xy, 1.0",
    "RawTex7.xy, 1.0"
};

// GLSL RGB expressions for the per-stage constant color, indexed by a pass's KonstColorSel.
// CreatePixelShader() uses the selected entry as the first three components of
// "Konst = vec4(<color>, <alpha>)".
const std::string gkKonstColor[] = {
    // 0-7: fixed grey levels from 1.0 down to 0.125 in steps of 0.125
    "1.0, 1.0, 1.0",
    "0.875, 0.875, 0.875",
    "0.75, 0.75, 0.75",
    "0.625, 0.625, 0.625",
    "0.5, 0.5, 0.5",
    "0.375, 0.375, 0.375",
    "0.25, 0.25, 0.25",
    "0.125, 0.125, 0.125",
    // 8-11: empty. Selecting one of these would emit "vec4(, ...)", which is not valid GLSL.
    // NOTE(review): presumably reserved selector values that materials never use; confirm
    "",
    "",
    "",
    "",
    // 12-15: full RGB of each KonstColors uniform
    "KonstColors[0].rgb",
    "KonstColors[1].rgb",
    "KonstColors[2].rgb",
    "KonstColors[3].rgb",
    // 16-31: a single channel of each KonstColors uniform replicated to RGB (r, g, b, then a)
    "KonstColors[0].rrr",
    "KonstColors[1].rrr",
    "KonstColors[2].rrr",
    "KonstColors[3].rrr",
    "KonstColors[0].ggg",
    "KonstColors[1].ggg",
    "KonstColors[2].ggg",
    "KonstColors[3].ggg",
    "KonstColors[0].bbb",
    "KonstColors[1].bbb",
    "KonstColors[2].bbb",
    "KonstColors[3].bbb",
    "KonstColors[0].aaa",
    "KonstColors[1].aaa",
    "KonstColors[2].aaa",
    "KonstColors[3].aaa"
};

// GLSL scalar expressions for the per-stage constant alpha, indexed by a pass's KonstAlphaSel.
// CreatePixelShader() uses the selected entry as the fourth component of
// "Konst = vec4(<color>, <alpha>)".
const std::string gkKonstAlpha[] = {
    // 0-7: fixed values from 1.0 down to 0.125 in steps of 0.125
    "1.0",
    "0.875",
    "0.75",
    "0.625",
    "0.5",
    "0.375",
    "0.25",
    "0.125",
    // 8-15: empty. Selecting one of these would emit "vec4(..., )", which is not valid GLSL.
    // NOTE(review): presumably reserved selector values that materials never use; confirm
    "",
    "",
    "",
    "",
    "",
    "",
    "",
    "",
    // 16-31: a single channel of each KonstColors uniform (r, g, b, then a)
    "KonstColors[0].r",
    "KonstColors[1].r",
    "KonstColors[2].r",
    "KonstColors[3].r",
    "KonstColors[0].g",
    "KonstColors[1].g",
    "KonstColors[2].g",
    "KonstColors[3].g",
    "KonstColors[0].b",
    "KonstColors[1].b",
    "KonstColors[2].b",
    "KonstColors[3].b",
    "KonstColors[0].a",
    "KonstColors[1].a",
    "KonstColors[2].a",
    "KonstColors[3].a"
};

// GLSL RGB expressions for TEV color inputs. CreatePixelShader() indexes this table with the
// low four bits of CMaterial::GetTevColorIn() and uses the entry as the RGB part of a
// TevIn<A-D> vec4. The names refer to vec4 locals declared in the generated main().
const std::string gkTevColor[] = {
    "Prev.rgb",       // 0
    "Prev.aaa",       // 1
    "C0.rgb",         // 2
    "C0.aaa",         // 3
    "C1.rgb",         // 4
    "C1.aaa",         // 5
    "C2.rgb",         // 6
    "C2.aaa",         // 7
    "Tex.rgb",        // 8: texture sample of the current stage
    "Tex.aaa",        // 9
    "Ras.rgb",        // 10: rasterized (per-vertex lit) color
    "Ras.aaa",        // 11
    "1.0, 1.0, 1.0",  // 12: constant one
    "0.5, 0.5, 0.5",  // 13: constant half
    "Konst.rgb",      // 14: per-stage constant color
    "0, 0, 0"         // 15: constant zero
};

// GLSL scalar expressions for TEV alpha inputs. CreatePixelShader() indexes this table with
// the low three bits of CMaterial::GetTevAlphaIn() and uses the entry as the alpha part of a
// TevIn<A-D> vec4.
const std::string gkTevAlpha[] = {
    "Prev.a",   // 0
    "C0.a",     // 1
    "C1.a",     // 2
    "C2.a",     // 3
    "Tex.a",    // 4: texture sample of the current stage
    "Ras.a",    // 5: rasterized (per-vertex lit) alpha
    "Konst.a",  // 6: per-stage constant alpha
    "0"         // 7: constant zero
};

// Names of the vec4 registers a TEV stage can write its result to. Indexed by a pass's
// ColorOutputRegister / AlphaOutputRegister in CreatePixelShader(); "Prev" is also the
// register the generated shader finally outputs as the pixel color.
const std::string gkTevRigid[] = {
    "Prev",
    "C0",
    "C1",
    "C2"
};

// Constructs an idle generator. mShader starts out null instead of indeterminate;
// GenerateShader() assigns the CShader that the Create*Shader() methods compile into.
CShaderGenerator::CShaderGenerator()
    : mShader(nullptr)
{
}

// Nothing to release here. mShader is intentionally not deleted: GenerateShader() returns
// that pointer to its caller after the generator goes out of scope.
CShaderGenerator::~CShaderGenerator() = default;

// Generates the GLSL 3.30 vertex shader for Mat and compiles it into mShader.
//
// The generated source:
//   - declares one input per attribute flagged in the material's vertex description,
//   - forwards the normal and vertex colors to the Normal / Color0 / Color1 outputs,
//   - evaluates per-vertex lighting into COLOR0A0 / COLOR1A1,
//   - emits one vec3 texture coordinate (Tex<N>) for every pass whose TexCoordSource is
//     not 0xFF, optionally transformed by TexMtx / PostMtx.
// All generated math multiplies vectors on the left of matrices (vector * matrix).
//
// Mat: material whose vertex description, channel control flags and passes drive generation.
// Returns the result of CShader::CompileVertexSource() for the generated source.
bool CShaderGenerator::CreateVertexShader(const CMaterial& Mat)
{
    std::stringstream ShaderCode;

    ShaderCode << "#version 330 core\n"
               << "\n";

    // Input
    // Attribute locations are fixed: 0 position, 1 normal, 2-3 colors, 4+ UV sets.
    // NOTE(review): gkCoordSrc also references RawTex7, but no input is declared for it here;
    // a pass using that source would generate a shader that fails to compile. Presumably an
    // eTex7 flag at location 11 is missing - confirm against EVertexDescription.
    ShaderCode << "// Input\n";
    EVertexDescription VtxDesc = Mat.GetVtxDesc();
    if (VtxDesc & ePosition) ShaderCode << "layout(location = 0) in vec3 RawPosition;\n";
    if (VtxDesc & eNormal)   ShaderCode << "layout(location = 1) in vec3 RawNormal;\n";
    if (VtxDesc & eColor0)   ShaderCode << "layout(location = 2) in vec4 RawColor0;\n";
    if (VtxDesc & eColor1)   ShaderCode << "layout(location = 3) in vec4 RawColor1;\n";
    if (VtxDesc & eTex0)     ShaderCode << "layout(location = 4) in vec2 RawTex0;\n";
    if (VtxDesc & eTex1)     ShaderCode << "layout(location = 5) in vec2 RawTex1;\n";
    if (VtxDesc & eTex2)     ShaderCode << "layout(location = 6) in vec2 RawTex2;\n";
    if (VtxDesc & eTex3)     ShaderCode << "layout(location = 7) in vec2 RawTex3;\n";
    if (VtxDesc & eTex4)     ShaderCode << "layout(location = 8) in vec2 RawTex4;\n";
    if (VtxDesc & eTex5)     ShaderCode << "layout(location = 9) in vec2 RawTex5;\n";
    if (VtxDesc & eTex6)     ShaderCode << "layout(location = 10) in vec2 RawTex6;\n";
    ShaderCode << "\n";

    // Output
    // These names must match the inputs declared by CreatePixelShader()
    ShaderCode << "// Output\n";
    if (VtxDesc & eNormal) ShaderCode << "out vec3 Normal;\n";
    if (VtxDesc & eColor0) ShaderCode << "out vec4 Color0;\n";
    if (VtxDesc & eColor1) ShaderCode << "out vec4 Color1;\n";

    for (u32 iPass = 0; iPass < Mat.mPasses.size(); iPass++)
        if (Mat.mPasses[iPass].TexCoordSource != 0xFF)
            ShaderCode << "out vec3 Tex" << iPass << ";\n";

    ShaderCode << "out vec4 COLOR0A0;\n"
               << "out vec4 COLOR1A1;\n";
    ShaderCode << "\n";

    // Uniforms
    ShaderCode << "// Uniforms\n"
               << "layout(std140) uniform MVPBlock\n"
               << "{\n"
               << "    mat4 ModelMtx;\n"
               << "    mat4 ViewMtx;\n"
               << "    mat4 ProjMtx;\n"
               << "};\n"
               << "\n"
               << "layout(std140) uniform VertexBlock\n"
               << "{\n"
               << "    mat4 TexMtx[10];\n"
               << "    mat4 PostMtx[20];\n"
               << "    vec4 COLOR0_Amb;\n"
               << "    vec4 COLOR0_Mat;\n"
               << "    vec4 COLOR1_Amb;\n"
               << "    vec4 COLOR1_Mat;\n"
               << "};\n"
               << "\n"
               << "struct GXLight\n"
               << "{\n"
               << "    vec4 Position;\n"
               << "    vec4 Direction;\n"
               << "    vec4 Color;\n"
               << "    vec4 DistAtten;\n"
               << "    vec4 AngleAtten;\n"
               << "};\n"
               << "layout(std140) uniform LightBlock {\n"
               << "    GXLight Lights[8];\n"
               << "};\n"
               << "uniform int NumLights;\n"
               << "\n";

    // Main
    ShaderCode << "// Main\n"
               << "void main()\n"
               << "{\n"
               << "    mat4 MVP = ModelMtx * ViewMtx * ProjMtx;\n"
               << "    mat4 MV = ModelMtx * ViewMtx;\n";

    // Each assignment targets the output declared above for the same vertex-description flag.
    // (Previously these wrote to Color1/Color2, which do not match the declared outputs.)
    if (VtxDesc & ePosition) ShaderCode << "    gl_Position = vec4(RawPosition, 1) * MVP;\n";
    if (VtxDesc & eNormal)   ShaderCode << "    Normal = normalize(RawNormal.xyz * inverse(transpose(mat3(MV))));\n";
    if (VtxDesc & eColor0)   ShaderCode << "    Color0 = RawColor0;\n";
    if (VtxDesc & eColor1)   ShaderCode << "    Color1 = RawColor1;\n";

    // Per-vertex lighting
    ShaderCode << "\n"
               << "    // Dynamic Lighting\n";

    // The 0x1 bit on the flag determines whether lighting is enabled for COLOR0
    if (Mat.mChanCtrlFlags & 0x1)
    {
        // Bits 11-12 of the channel control flags select the diffuse function:
        // 2 = clamped N.L, 1 = signed N.L, anything else = no diffuse attenuation
        u8 DiffuseFunction = (Mat.mChanCtrlFlags >> 11) & 0x3;

        if (Mat.mChanCount > 0)
        {
            // Accumulate every active light in view space, applying angle and distance attenuation.
            // NOTE(review): diffuse functions 1 and 2 read the Normal output, which is only
            // declared when the vertex description includes eNormal.
            ShaderCode << "    vec4 Illum = vec4(0.0);\n"
                       << "    vec3 PositionMV = vec3(vec4(RawPosition, 1.0) * MV);\n"
                       << "    \n"
                       << "    for (int iLight = 0; iLight < NumLights; iLight++)\n"
                       << "    {\n"
                       << "        vec3 LightPosMV = vec3(Lights[iLight].Position * ViewMtx);\n"
                       << "        vec3 LightDirMV = normalize(Lights[iLight].Direction.xyz * inverse(transpose(mat3(ViewMtx))));\n"
                       << "        vec3 LightDist = LightPosMV.xyz - PositionMV.xyz;\n"
                       << "        float DistSquared = dot(LightDist, LightDist);\n"
                       << "        float Dist = sqrt(DistSquared);\n"
                       << "        LightDist /= Dist;\n"
                       << "        vec3 AngleAtten = Lights[iLight].AngleAtten.xyz;\n"
                       << "        AngleAtten = vec3(AngleAtten.x, AngleAtten.y, AngleAtten.z);\n"
                       << "        float Atten = max(0, dot(LightDist, LightDirMV.xyz));\n"
                       << "        Atten = max(0, dot(AngleAtten, vec3(1.0, Atten, Atten * Atten))) / dot(Lights[iLight].DistAtten.xyz, vec3(1.0, Dist, DistSquared));\n";

                 if (DiffuseFunction == 2) ShaderCode << "        float DiffuseAtten = max(0, dot(Normal, LightDist));\n";
            else if (DiffuseFunction == 1) ShaderCode << "        float DiffuseAtten = dot(Normal, LightDist);\n";
            else                           ShaderCode << "        float DiffuseAtten = 1.0;\n";

            ShaderCode << "        Illum += (Atten * DiffuseAtten * Lights[iLight].Color);\n"
                       << "    }\n"
                       << "    COLOR0A0 = COLOR0_Mat * (Illum + COLOR0_Amb);\n"
                       << "    COLOR1A1 = COLOR1_Mat * (Illum + COLOR1_Amb);\n"
                       << "    \n";
        }
        else
        {
            // Lighting flagged on but the material has no color channels: pass material color through
            ShaderCode << "    COLOR0A0 = COLOR0_Mat;\n"
                       << "    COLOR1A1 = COLOR1_Mat;\n";
        }
    }

    else
    {
        // Lighting disabled: pass material color through
        ShaderCode << "    COLOR0A0 = COLOR0_Mat;\n"
                   << "    COLOR1A1 = COLOR1_Mat;\n"
                   << "\n";
    }

    // Texture coordinate generation
    ShaderCode << "    \n"
               << "    // TexGen\n";

    for (u32 iCoord = 0; iCoord < Mat.mPasses.size(); iCoord++)
    {
        // 0xFF marks a pass without a texture coordinate; no Tex<N> output exists for it
        if (Mat.mPasses[iCoord].TexCoordSource == 0xFF) continue;

        s32 AnimType = Mat.mPasses[iCoord].AnimMode;

        // Texture Matrix
        if (AnimType == -1) // No animation
            ShaderCode << "    Tex" << iCoord << " = vec3(" << gkCoordSrc[Mat.mPasses[iCoord].TexCoordSource] << ");\n";

        else // Animation used - texture matrix at least, possibly normalization/post-transform
        {
            // Texture Matrix
            ShaderCode << "    Tex" << iCoord << " = vec3(vec4(" << gkCoordSrc[Mat.mPasses[iCoord].TexCoordSource] << ", 1.0) * TexMtx[" << iCoord << "]).xyz;\n";

            // Animation modes 2 through 5 skip the normalize + post-transform step
            if ((AnimType < 2) || (AnimType > 5))
            {
                // Normalization + Post-Transform
                ShaderCode << "    Tex" << iCoord << " = normalize(Tex" << iCoord << ");\n";
                ShaderCode << "    Tex" << iCoord << " = vec3(vec4(Tex" << iCoord << ", 1.0) * PostMtx[" << iCoord << "]).xyz;\n";
            }
        }

        ShaderCode << "\n";
    }
    ShaderCode << "}\n\n";


    // Done!
    return mShader->CompileVertexSource(ShaderCode.str().c_str());
}

bool CShaderGenerator::CreatePixelShader(const CMaterial& Mat)
{
    std::stringstream ShaderCode;
    ShaderCode << "#version 330 core\n"
               << "\n"
               << "#extension GL_ARB_shading_language_420pack : enable\n" // Needed to set texture binding layouts
               << "\n";

    EVertexDescription VtxDesc = Mat.GetVtxDesc();
    if (VtxDesc & ePosition) ShaderCode << "in vec3 Position;\n";
    if (VtxDesc & eNormal)   ShaderCode << "in vec3 Normal;\n";
    if (VtxDesc & eColor0)   ShaderCode << "in vec4 Color0;\n";
    if (VtxDesc & eColor1)   ShaderCode << "in vec4 Color1;\n";

    for (u32 iPass = 0; iPass < Mat.mPasses.size(); iPass++)
        if (Mat.mPasses[iPass].TexCoordSource != 0xFF)
            ShaderCode << "in vec3 Tex" << iPass << ";\n";

    ShaderCode << "in vec4 COLOR0A0;\n"
               << "in vec4 COLOR1A1;\n"
               << "\n"
               << "out vec4 PixelColor;\n"
               << "\n"
               << "layout(std140) uniform PixelBlock {\n"
               << "    vec4 KonstColors[4];\n"
               << "    vec4 TevColor;\n"
               << "    vec4 TintColor;\n"
               << "};\n\n";

    for (u32 iTex = 0; iTex < Mat.mPasses.size(); iTex++)
        if (Mat.mPasses[iTex].pTexture != nullptr)
            ShaderCode << "layout(binding = " << iTex << ") uniform sampler2D Texture" << iTex << ";\n";

    ShaderCode <<"\n";

    ShaderCode << "void main()\n"
               << "{\n"
               << "    vec4 TevInA = vec4(0, 0, 0, 0), TevInB = vec4(0, 0, 0, 0), TevInC = vec4(0, 0, 0, 0), TevInD = vec4(0, 0, 0, 0);\n"
               << "    vec4 Prev = vec4(0, 0, 0, 0), C0 = TevColor, C1 = C0, C2 = C0;\n"
               << "    vec4 Ras = vec4(0, 0, 0, 1), Tex = vec4(0, 0, 0, 0);\n"
               << "    vec4 Konst = vec4(1, 1, 1, 1);\n";

    ShaderCode << "    vec2 TevCoord = vec2(0, 0);\n"
               << "    \n";

    for (u32 iPass = 0; iPass < Mat.mPasses.size(); iPass++)
    {
        ShaderCode << "    // TEV Stage " << iPass << "\n";
        const CMaterial::SPass *pPass = &Mat.mPasses[iPass];

        if (pPass->Hidden)
        {
            ShaderCode << "    // Pass is hidden\n\n";
            continue;
        }

        if (pPass->TexCoordSource != 0xFF)
            ShaderCode << "    TevCoord = (Tex" << iPass << ".z == 0.0 ? Tex" << iPass << ".xy : Tex" << iPass << ".xy / Tex" << iPass << ".z);\n";

        if (pPass->pTexture != nullptr)
            ShaderCode << "    Tex = texture(Texture" << iPass << ", TevCoord);\n";

        ShaderCode << "    Konst = vec4(" << gkKonstColor[pPass->KonstColorSel] << ", " << gkKonstAlpha[pPass->KonstAlphaSel] << ");\n";

        if (pPass->RasSel != 0xFF)
        {
                 if (pPass->RasSel == 0x0) ShaderCode << "    Ras = vec4(COLOR0A0.xyz, 1.0);\n";
            else if (pPass->RasSel == 0x1) ShaderCode << "    Ras = vec4(COLOR1A1.xyz, 1.0);\n";
            else if (pPass->RasSel == 0x2) ShaderCode << "    Ras = vec4(0.0, 0.0, 0.0, COLOR0A0.w);\n";
            else if (pPass->RasSel == 0x3) ShaderCode << "    Ras = vec4(0.0, 0.0, 0.0, COLOR1A1.w);\n";
            else if (pPass->RasSel == 0x4) ShaderCode << "    Ras = COLOR0A0;\n";
            else if (pPass->RasSel == 0x5) ShaderCode << "    Ras = COLOR1A1;\n";
            else if (pPass->RasSel == 0x6) ShaderCode << "    Ras = vec4(0.0, 0.0, 0.0, 0.0);\n";
        }

        for (u8 iInput = 0; iInput < 4; iInput++)
        {
            u8 TevCharacter = iInput + 0x41; // the current stage number represented as an ASCII letter; eg 0 is 'A'

            ShaderCode << "    TevIn" << TevCharacter << " = vec4("
                       << gkTevColor[Mat.GetTevColorIn(iPass, iInput) & 0xF]
                       << ", "
                       << gkTevAlpha[Mat.GetTevAlphaIn(iPass, iInput) & 0x7]
                       << ");\n";
        }

        // Applying TRAN and BLOL (opacity and bloom maps) in Corruption require accessing specific color channels
        // This feels hacky and might not be the best way to implement this
        if (pPass->Type == "TRAN")
        {
            ShaderCode << "    // TRAN Combine\n"
                       << "    Prev.a = 1.0 - Tex.r;\n\n";
        }
        /*else if (pPass->Type == "BLOL")
        {
            ShaderCode << "    // BLOL Combine\n"
                       << "    C0.rgb += vec3(Tex.g, Tex.g, Tex.g);\n\n";
        }*/

        else
        {
            ShaderCode << "    // RGB Combine\n"
                       << "    "
                       << gkTevRigid[pPass->ColorOutputRegister]
                       << ".rgb = ";

            ShaderCode << "clamp(vec3(TevInD.rgb + ((1.0 - TevInC.rgb) * TevInA.rgb + TevInC.rgb * TevInB.rgb)), vec3(0, 0, 0), vec3(1.0, 1.0, 1.0));\n";

            ShaderCode << "    // Alpha Combine\n"
                       << "    "
                       << gkTevRigid[pPass->AlphaOutputRegister]
                       << ".a = ";

            ShaderCode << "clamp(TevInD.a + ((1.0 - TevInC.a) * TevInA.a + TevInC.a * TevInB.a), 0.0, 1.0);\n\n";
        }
    }

    if (Mat.GetOptions() & ePunchthrough) {
        ShaderCode << "    if (Prev.a <= 0.25) discard;\n"
                   << "    else Prev.a = 1.0;\n";
    }

    ShaderCode << "    PixelColor = Prev.rgba * TintColor;\n"
               << "}\n\n";

    // Done!
    return mShader->CompilePixelSource(ShaderCode.str().c_str());
}

// Builds a complete shader program for Mat: allocates a CShader, compiles the generated
// vertex and pixel stages into it, links it and returns it.
//
// The returned object is allocated with new and is not freed by the generator.
// NOTE(review): presumably the caller takes ownership of it - confirm at the call sites.
CShader* CShaderGenerator::GenerateShader(const CMaterial& Mat)
{
    CShaderGenerator Generator;
    CShader *pShader = new CShader();
    Generator.mShader = pShader;

    // The pixel stage is only generated once the vertex stage has compiled
    if (Generator.CreateVertexShader(Mat))
        Generator.CreatePixelShader(Mat);

    // Linking is requested regardless of whether both stages compiled
    pShader->LinkShaders();
    return pShader;
}