#include "VISIRenderer.hpp"
#include "athena/FileReader.hpp"
#include "zeus/CAABox.hpp"
#include "VISIBuilder.hpp"
#include "zeus/CFrustum.hpp"
#include "logvisor/logvisor.hpp"

// Log channel for this file; every message reported below goes through the
// logvisor module named "visigen".
static logvisor::Module Log("visigen");

// Vertex shader source (GLSL 330). Transforms the xyz of attribute 0 by the
// `xf` matrix from UniformBlock (w is forced to 1.0) and forwards attribute 1
// unchanged as the vertex color.
static const char* VS =
    "#version 330\n"
    "layout(location=0) in vec4 posIn;\n"
    "layout(location=1) in vec4 colorIn;\n"
    "\n"
    "uniform UniformBlock\n"
    "{\n"
    "    mat4 xf;\n"
    "};\n"
    "\n"
    "struct VertToFrag\n"
    "{\n"
    "    vec4 color;\n"
    "};\n"
    "\n"
    "out VertToFrag vtf;\n"
    "void main()\n"
    "{\n"
    "    vtf.color = colorIn;\n"
    "    gl_Position = xf * vec4(posIn.xyz, 1.0);\n"
    "}\n";

// Fragment shader source (GLSL 330). Writes the color received from the vertex
// stage straight to color output 0.
static const char* FS =
    "#version 330\n"
    "struct VertToFrag\n"
    "{\n"
    "    vec4 color;\n"
    "};\n"
    "\n"
    "in VertToFrag vtf;\n"
    "layout(location=0) out vec4 colorOut;\n"
    "void main()\n"
    "{\n"
    "    colorOut = vtf.color;\n"
    "}\n";

// 20 indices into the 8 corner vertices produced by AABBToVerts(). Uploaded to
// m_aabbIBO in SetupShaders() and drawn as a single GL_TRIANGLE_STRIP per
// entity in RenderPVSEntitiesAndLights().
static const uint32_t AABBIdxs[20] = {0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 1, 7, 3, 5, 5, 0, 0, 2, 6, 4};

bool VISIRenderer::SetupShaders() {
  m_vtxShader = glCreateShader(GL_VERTEX_SHADER);
  m_fragShader = glCreateShader(GL_FRAGMENT_SHADER);
  m_program = glCreateProgram();

  glShaderSource(m_vtxShader, 1, &VS, nullptr);
  glCompileShader(m_vtxShader);
  GLint status;
  glGetShaderiv(m_vtxShader, GL_COMPILE_STATUS, &status);
  if (status != GL_TRUE) {
    GLint logLen;
    glGetShaderiv(m_vtxShader, GL_INFO_LOG_LENGTH, &logLen);
    char* log = (char*)malloc(logLen);
    glGetShaderInfoLog(m_vtxShader, logLen, nullptr, log);
    Log.report(logvisor::Error, fmt("unable to compile vert source\n{}\n{}\n"), log, VS);
    free(log);
    return false;
  }

  glShaderSource(m_fragShader, 1, &FS, nullptr);
  glCompileShader(m_fragShader);
  glGetShaderiv(m_fragShader, GL_COMPILE_STATUS, &status);
  if (status != GL_TRUE) {
    GLint logLen;
    glGetShaderiv(m_fragShader, GL_INFO_LOG_LENGTH, &logLen);
    char* log = (char*)malloc(logLen);
    glGetShaderInfoLog(m_fragShader, logLen, nullptr, log);
    Log.report(logvisor::Error, fmt("unable to compile frag source\n{}\n{}\n"), log, FS);
    free(log);
    return false;
  }

  glAttachShader(m_program, m_vtxShader);
  glAttachShader(m_program, m_fragShader);

  glLinkProgram(m_program);
  glGetProgramiv(m_program, GL_LINK_STATUS, &status);
  if (status != GL_TRUE) {
    GLint logLen;
    glGetProgramiv(m_program, GL_INFO_LOG_LENGTH, &logLen);
    char* log = (char*)malloc(logLen);
    glGetProgramInfoLog(m_program, logLen, nullptr, log);
    Log.report(logvisor::Error, fmt("unable to link shader program\n{}\n"), log);
    free(log);
    return false;
  }

  glUseProgram(m_program);
  m_uniLoc = glGetUniformBlockIndex(m_program, "UniformBlock");

  glGenBuffers(1, &m_uniformBufferGL);
  glBindBuffer(GL_UNIFORM_BUFFER, m_uniformBufferGL);
  glBufferData(GL_UNIFORM_BUFFER, sizeof(UniformBuffer), nullptr, GL_DYNAMIC_DRAW);

  glGenBuffers(1, &m_aabbIBO);
  glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, m_aabbIBO);
  glBufferData(GL_ELEMENT_ARRAY_BUFFER, 20 * 4, AABBIdxs, GL_STATIC_DRAW);

  glEnable(GL_PRIMITIVE_RESTART);
  glPrimitiveRestartIndex(0xffffffff);

  return true;
}

// Expands `aabb` into its eight corner vertices, all tinted with `color`.
// The corner order is fixed and is what AABBIdxs indexes into.
std::vector<VISIRenderer::Model::Vert> VISIRenderer::AABBToVerts(const zeus::CAABox& aabb, const zeus::CColor& color) {
  const auto& lo = aabb.min;
  const auto& hi = aabb.max;

  std::vector<Model::Vert> corners;
  corners.resize(8);

  corners[0].pos = lo;
  corners[1].pos = {hi.x(), lo.y(), lo.z()};
  corners[2].pos = {lo.x(), lo.y(), hi.z()};
  corners[3].pos = {hi.x(), lo.y(), hi.z()};
  corners[4].pos = {lo.x(), hi.y(), hi.z()};
  corners[5].pos = hi;
  corners[6].pos = {lo.x(), hi.y(), lo.z()};
  corners[7].pos = {hi.x(), hi.y(), lo.z()};

  for (Model::Vert& corner : corners)
    corner.color = color;

  return corners;
}

// Encodes object index `i` as an opaque color: (i + 1) is split into its three
// low bytes, which become the red, green and blue channels (low byte in red),
// each scaled to [0, 1]. Index 0 therefore maps to a red byte of 1, never to
// pure black.
static zeus::CColor ColorForIndex(int i) {
  const int id = i + 1;
  const float red = (id & 0xff) / 255.f;
  const float green = ((id >> 8) & 0xff) / 255.f;
  const float blue = ((id >> 16) & 0xff) / 255.f;
  return zeus::CColor(red, green, blue, 1.f);
}

bool VISIRenderer::SetupVertexBuffersAndFormats() {
  for (Model& model : m_models) {
    glGenVertexArrays(1, &model.vao);
    glGenBuffers(1, &model.vbo);
    glGenBuffers(1, &model.ibo);

    glBindVertexArray(model.vao);

    glBindBuffer(GL_ARRAY_BUFFER, model.vbo);
    glBufferData(GL_ARRAY_BUFFER, model.verts.size() * sizeof(Model::Vert), model.verts.data(), GL_STATIC_DRAW);
    glEnableVertexAttribArray(0);
    glEnableVertexAttribArray(1);
    glVertexAttribPointer(0, 4, GL_FLOAT, GL_FALSE, sizeof(Model::Vert), 0);
    glVertexAttribPointer(1, 4, GL_FLOAT, GL_FALSE, sizeof(Model::Vert), (void*)16);

    glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, model.ibo);
    glBufferData(GL_ELEMENT_ARRAY_BUFFER, model.idxs.size() * 4, model.idxs.data(), GL_STATIC_DRAW);
  }

  int idx = m_models.size();
  for (Entity& ent : m_entities) {
    glGenVertexArrays(1, &ent.vao);
    glGenBuffers(1, &ent.vbo);

    glBindVertexArray(ent.vao);

    auto verts = AABBToVerts(ent.aabb, ColorForIndex(idx++));
    glBindBuffer(GL_ARRAY_BUFFER, ent.vbo);
    glBufferData(GL_ARRAY_BUFFER, verts.size() * sizeof(Model::Vert), verts.data(), GL_STATIC_DRAW);
    glEnableVertexAttribArray(0);
    glEnableVertexAttribArray(1);
    glVertexAttribPointer(0, 4, GL_FLOAT, GL_FALSE, sizeof(Model::Vert), 0);
    glVertexAttribPointer(1, 4, GL_FLOAT, GL_FALSE, sizeof(Model::Vert), (void*)16);

    glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, m_aabbIBO);
  }

  for (Light& light : m_lights) {
    glGenVertexArrays(1, &light.vao);
    glGenBuffers(1, &light.vbo);

    glBindVertexArray(light.vao);

    Model::Vert vert;
    vert.pos = light.point;
    vert.color = ColorForIndex(idx++);
    glBindBuffer(GL_ARRAY_BUFFER, light.vbo);
    glBufferData(GL_ARRAY_BUFFER, sizeof(Model::Vert), &vert, GL_STATIC_DRAW);
    glEnableVertexAttribArray(0);
    glEnableVertexAttribArray(1);
    glVertexAttribPointer(0, 4, GL_FLOAT, GL_FALSE, sizeof(Model::Vert), 0);
    glVertexAttribPointer(1, 4, GL_FLOAT, GL_FALSE, sizeof(Model::Vert), (void*)16);
  }

  m_queryCount = m_models.size() + m_entities.size() + m_lights.size();
  m_queries.reset(new GLuint[m_queryCount]);
  m_queryBools.reset(new bool[m_queryCount]);
  glGenQueries(GLsizei(m_queryCount), m_queries.get());

  return true;
}

// Projection matrix shared by all render passes. Written by
// CalculateProjMatrix(), which Run() calls before any rendering takes place.
static zeus::CMatrix4f g_Proj;

// Fills g_Proj with a perspective projection for a symmetric frustum: 90 degree
// field of view both horizontally and vertically, near plane at 0.2 and far
// plane at 1000.
static void CalculateProjMatrix() {
  const float nearZ = 0.2f;
  const float farZ = 1000.f;

  // Half-extents of the near plane; equal on both axes.
  const float tanHalfFov = std::tan(zeus::degToRad(90.f * 0.5f));
  const float right = nearZ * tanHalfFov;
  const float left = -right;
  const float top = nearZ * tanHalfFov;
  const float bottom = -top;

  const float width = right - left;
  const float xSum = right + left;
  const float height = top - bottom;
  const float ySum = top + bottom;
  const float zSum = farZ + nearZ;
  const float depth = farZ - nearZ;

  g_Proj = zeus::CMatrix4f(2.f * nearZ / width, 0.f, xSum / width, 0.f,
                           0.f, 2.f * nearZ / height, ySum / height, 0.f,
                           0.f, 0.f, -zSum / depth, -2.f * farZ * nearZ / depth,
                           0.f, 0.f, -1.f, 0.f);
}

// One view rotation per cube face, indexed by the face counter `j` of the
// render passes. Each pass multiplies the entry with a translation by -pos to
// build its model-view matrix, and draws face j into tile (j % 3, j / 3) of the
// 3x2 grid of 256x256 viewports.
static const zeus::CMatrix4f LookMATs[] = {
    {// Forward
     1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, -1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f},
    {// Backward
     -1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f},
    {// Up
     1.f, 0.f, 0.f, 0.f, 0.f, -1.f, 0.f, 0.f, 0.f, 0.f, -1.f, 0.f, 0.f, 0.f, 0.f, 1.f},
    {// Down
     1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f},
    {// Left
     0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f},
    {// Right
     0.f, -1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, -1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f},
};

// Renders the non-transparent surfaces of every frustum-visible model as seen
// from `pos`, once per cube face, into a 3x2 grid of 256x256 viewports, then
// reads the resulting 768x512 color buffer back into `bufOut`.
//
// bufOut           Destination of the glReadPixels readback; must have room for
//                  768 * 512 pixels of GL_RGBA / GL_UNSIGNED_BYTE data.
// pos              Viewpoint the six faces are rendered from.
// needTransparent  Set to true when at least one transparent surface was
//                  skipped on a frustum-visible model, i.e. when a transparent
//                  pass has something to test. Never cleared here, so the
//                  caller is expected to initialise it.
void VISIRenderer::RenderPVSOpaque(RGBA8* bufOut, const zeus::CVector3f& pos, bool& needTransparent) {
  // Clear the whole 768x512 target before drawing the individual faces.
  glViewport(0, 0, 768, 512);
  glEnable(GL_CULL_FACE);
  glDepthMask(GL_TRUE);
  glDepthFunc(GL_LEQUAL);
  glEnable(GL_DEPTH_TEST);
  glClearColor(0.f, 0.f, 0.f, 1.f);
  glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

  for (int j = 0; j < 6; ++j) {
    GLint x = (j % 3) * 256;
    GLint y = (j / 3) * 256;
    glViewport(x, y, 256, 256);

    // Upload this face's combined projection * view matrix.
    zeus::CMatrix4f mv = LookMATs[j] * zeus::CTransform::Translate(-pos).toMatrix4f();
    m_uniformBuffer.m_xf = g_Proj * mv;
    glBindBuffer(GL_UNIFORM_BUFFER, m_uniformBufferGL);
    glBufferData(GL_UNIFORM_BUFFER, sizeof(UniformBuffer), &m_uniformBuffer, GL_DYNAMIC_DRAW);

    glUniformBlockBinding(m_program, m_uniLoc, 0);
    glBindBufferRange(GL_UNIFORM_BUFFER, 0, m_uniformBufferGL, 0, sizeof(UniformBuffer));

    zeus::CFrustum frustum;
    frustum.updatePlanes(mv, g_Proj);

    // Draw frontfaces
    glCullFace(GL_BACK);
    glColorMask(GL_TRUE, GL_TRUE, GL_TRUE, GL_TRUE);

    for (const Model& model : m_models) {
      if (!frustum.aabbFrustumTest(model.aabb))
        continue;
      glBindVertexArray(model.vao);
      for (const Model::Surface& surf : model.surfaces) {
        // Non-transparents first
        if (!surf.transparent) {
          // surf.first is an index count; indices are 4 bytes each.
          glDrawElements(model.topology, surf.count, GL_UNSIGNED_INT,
                         reinterpret_cast<void*>(uintptr_t(surf.first * 4)));
        } else {
          // Report the skipped surface so the out-parameter is actually
          // written; previously it was never touched.
          needTransparent = true;
        }
      }
    }
  }

  // m_swapFunc();
  // Wait for all drawing to complete before reading the color buffer back.
  glFinish();
  glReadPixels(0, 0, 768, 512, GL_RGBA, GL_UNSIGNED_BYTE, (GLvoid*)bufOut);
}

// Issues one GL_ANY_SAMPLES_PASSED_CONSERVATIVE query per frustum-visible model
// around the draw of its transparent surfaces, for each of the six cube faces
// seen from `pos`. Depth writes are disabled for the whole pass.
//
// passFunc  Called with the model index each time a model's query reports a
//           non-zero result on a face; a model visible on several faces is
//           reported once per face.
// pos       Viewpoint the six faces are rendered from.
//
// NOTE(review): depth test / cull state is not set here; presumably this runs
// after RenderPVSOpaque() has configured it - confirm against the caller.
void VISIRenderer::RenderPVSTransparent(const std::function<void(int)>& passFunc, const zeus::CVector3f& pos) {
  glDepthMask(GL_FALSE);

  for (int face = 0; face < 6; ++face) {
    // Faces are laid out as a 3x2 grid of 256x256 tiles.
    const GLint tileX = (face % 3) * 256;
    const GLint tileY = (face / 3) * 256;
    glViewport(tileX, tileY, 256, 256);

    // Upload this face's combined projection * view matrix.
    zeus::CMatrix4f view = LookMATs[face] * zeus::CTransform::Translate(-pos).toMatrix4f();
    m_uniformBuffer.m_xf = g_Proj * view;
    glBindBuffer(GL_UNIFORM_BUFFER, m_uniformBufferGL);
    glBufferData(GL_UNIFORM_BUFFER, sizeof(UniformBuffer), &m_uniformBuffer, GL_DYNAMIC_DRAW);

    glUniformBlockBinding(m_program, m_uniLoc, 0);
    glBindBufferRange(GL_UNIFORM_BUFFER, 0, m_uniformBufferGL, 0, sizeof(UniformBuffer));

    zeus::CFrustum frustum;
    frustum.updatePlanes(view, g_Proj);

    // No query has been issued yet for this face.
    memset(m_queryBools.get(), 0, m_queryCount);

    // Model i uses query slot i.
    const int modelCount = int(m_models.size());
    for (int slot = 0; slot < modelCount; ++slot) {
      const Model& mdl = m_models[slot];
      if (!frustum.aabbFrustumTest(mdl.aabb))
        continue;

      glBindVertexArray(mdl.vao);
      glBeginQuery(GL_ANY_SAMPLES_PASSED_CONSERVATIVE, m_queries[slot]);
      m_queryBools[slot] = true;
      for (const Model::Surface& surf : mdl.surfaces) {
        // transparents
        if (!surf.transparent)
          continue;
        glDrawElements(mdl.topology, surf.count, GL_UNSIGNED_INT,
                       reinterpret_cast<void*>(uintptr_t(surf.first * 4)));
      }
      glEndQuery(GL_ANY_SAMPLES_PASSED_CONSERVATIVE);
    }

    // Collect results only for the queries that were actually issued.
    for (int slot = 0; slot < modelCount; ++slot) {
      if (!m_queryBools[slot])
        continue;
      GLint anyPassed;
      glGetQueryObjectiv(m_queries[slot], GL_QUERY_RESULT, &anyPassed);
      if (anyPassed)
        passFunc(slot);
    }
  }
}

// Issues one GL_ANY_SAMPLES_PASSED_CONSERVATIVE query per frustum-visible
// entity box and light point for each of the six cube faces seen from `pos`,
// then reports the results. Depth writes are disabled for the whole pass.
//
// Query slots follow the layout models, entities, lights, so entity e uses slot
// (model count + e) and light l uses slot (model count + entity count + l).
//
// passFunc       Called with the entity's query slot (not its position in
//                m_entities) whenever its query reports a non-zero result.
// lightPassFunc  Called with the light's position in m_lights and a state for
//                every light whose query was issued on the face: OutOfBounds
//                when the point lies outside m_totalAABB, otherwise NodeFound
//                when the query result is non-zero and EndOfTree when it is
//                zero.
// pos            Viewpoint the six faces are rendered from.
void VISIRenderer::RenderPVSEntitiesAndLights(const std::function<void(int)>& passFunc,
                                              const std::function<void(int, EPVSVisSetState)>& lightPassFunc,
                                              const zeus::CVector3f& pos) {
  glDepthMask(GL_FALSE);

  const int entityBase = int(m_models.size());
  const int entityCount = int(m_entities.size());
  const int lightBase = entityBase + entityCount;
  const int lightCount = int(m_lights.size());

  for (int face = 0; face < 6; ++face) {
    // Faces are laid out as a 3x2 grid of 256x256 tiles.
    const GLint tileX = (face % 3) * 256;
    const GLint tileY = (face / 3) * 256;
    glViewport(tileX, tileY, 256, 256);

    // Upload this face's combined projection * view matrix.
    zeus::CMatrix4f view = LookMATs[face] * zeus::CTransform::Translate(-pos).toMatrix4f();
    m_uniformBuffer.m_xf = g_Proj * view;
    glBindBuffer(GL_UNIFORM_BUFFER, m_uniformBufferGL);
    glBufferData(GL_UNIFORM_BUFFER, sizeof(UniformBuffer), &m_uniformBuffer, GL_DYNAMIC_DRAW);

    glUniformBlockBinding(m_program, m_uniLoc, 0);
    glBindBufferRange(GL_UNIFORM_BUFFER, 0, m_uniformBufferGL, 0, sizeof(UniformBuffer));

    zeus::CFrustum frustum;
    frustum.updatePlanes(view, g_Proj);

    // No query has been issued yet for this face.
    memset(m_queryBools.get(), 0, m_queryCount);

    // Issue the entity queries: each box is one 20-index triangle strip.
    for (int e = 0; e < entityCount; ++e) {
      const Entity& entity = m_entities[e];
      if (!frustum.aabbFrustumTest(entity.aabb))
        continue;

      const int slot = entityBase + e;
      glBindVertexArray(entity.vao);
      m_queryBools[slot] = true;
      glBeginQuery(GL_ANY_SAMPLES_PASSED_CONSERVATIVE, m_queries[slot]);
      glDrawElements(GL_TRIANGLE_STRIP, 20, GL_UNSIGNED_INT, 0);
      glEndQuery(GL_ANY_SAMPLES_PASSED_CONSERVATIVE);
    }

    // Issue the light queries: each light is a single point.
    for (int l = 0; l < lightCount; ++l) {
      const Light& lamp = m_lights[l];
      if (!frustum.pointFrustumTest(lamp.point))
        continue;

      const int slot = lightBase + l;
      glBindVertexArray(lamp.vao);
      m_queryBools[slot] = true;
      glBeginQuery(GL_ANY_SAMPLES_PASSED_CONSERVATIVE, m_queries[slot]);
      glDrawArrays(GL_POINTS, 0, 1);
      glEndQuery(GL_ANY_SAMPLES_PASSED_CONSERVATIVE);
    }

    // Collect the entity results.
    for (int e = 0; e < entityCount; ++e) {
      const int slot = entityBase + e;
      if (!m_queryBools[slot])
        continue;
      GLint anyPassed;
      glGetQueryObjectiv(m_queries[slot], GL_QUERY_RESULT, &anyPassed);
      if (anyPassed)
        passFunc(slot);
    }

    // Collect the light results; every issued light is reported, with a state
    // that depends on both the query result and the total bounds.
    for (int l = 0; l < lightCount; ++l) {
      const int slot = lightBase + l;
      if (!m_queryBools[slot])
        continue;
      GLint anyPassed;
      glGetQueryObjectiv(m_queries[slot], GL_QUERY_RESULT, &anyPassed);

      EPVSVisSetState state = EPVSVisSetState::OutOfBounds;
      if (m_totalAABB.pointInside(m_lights[l].point))
        state = anyPassed ? EPVSVisSetState::NodeFound : EPVSVisSetState::EndOfTree;
      lightPassFunc(l, state);
    }
  }
}

// Entry point of the generator: initialises GL, loads the scene description
// from m_argv[1], builds the visibility data and writes it to m_argv[2].
//
// updatePercent  Progress callback; stored in m_updatePercent and handed to the
//                builder.
//
// Arguments read: m_argv[1] is the input file, m_argv[2] the output file and
// the optional m_argv[4] a parent process id parsed as a base-10 integer and
// forwarded to the builder.
//
// Every failure path sets m_return to 1 before returning; m_return is not
// modified on success.
//
// Input file layout, as read below (all multi-byte values big-endian):
//   u32 modelCount, then per model:
//     u32 topology (non-zero selects GL_TRIANGLE_STRIP, zero GL_TRIANGLES)
//     u32 vertCount, then vertCount vec3f positions
//     u32 surfCount, then per surface: u32 count, count u32 indices, bool transparent
//   u32 entityCount, then per entity: u32 entityId, vec3f aabb.min, vec3f aabb.max
//   u32 lightCount, u32 layer2LightCount, then lightCount vec3f points
void VISIRenderer::Run(FPercent updatePercent) {
  m_updatePercent = updatePercent;
  CalculateProjMatrix();

  if (glewInit() != GLEW_OK) {
    Log.report(logvisor::Error, fmt("unable to init glew"));
    m_return = 1;
    return;
  }

  if (!GLEW_ARB_occlusion_query2) {
    Log.report(logvisor::Error, fmt("GL_ARB_occlusion_query2 extension not present"));
    m_return = 1;
    return;
  }

  if (!SetupShaders()) {
    m_return = 1;
    return;
  }

  if (m_argc < 3) {
    Log.report(logvisor::Error, fmt("Missing input/output file args"));
    m_return = 1;
    return;
  }

  ProcessType parentPid = 0;
  if (m_argc > 4) {
#ifdef _WIN32
    parentPid = ProcessType(wcstoull(m_argv[4], nullptr, 10));
#else
    parentPid = ProcessType(strtoull(m_argv[4], nullptr, 10));
#endif
  }

  uint32_t layer2LightCount = 0;
  {
    athena::io::FileReader r(m_argv[1]);
    if (r.hasError()) {
      // Report the failure like every other error path does; previously this
      // returned without touching m_return.
      Log.report(logvisor::Error, fmt("unable to open input file"));
      m_return = 1;
      return;
    }

    uint32_t modelCount = r.readUint32Big();
    m_models.resize(modelCount);
    for (uint32_t i = 0; i < modelCount; ++i) {
      // Every vertex of a model carries the color that encodes the model index.
      zeus::CColor color = ColorForIndex(i);
      Model& model = m_models[i];
      uint32_t topology = r.readUint32Big();
      model.topology = topology ? GL_TRIANGLE_STRIP : GL_TRIANGLES;

      uint32_t vertCount = r.readUint32Big();
      model.verts.reserve(vertCount);
      for (uint32_t j = 0; j < vertCount; ++j) {
        model.verts.emplace_back();
        Model::Vert& vert = model.verts.back();
        vert.pos = r.readVec3fBig();
        vert.color = color;
        // Grow both the scene-wide and the per-model bounds.
        m_totalAABB.accumulateBounds(vert.pos);
        model.aabb.accumulateBounds(vert.pos);
      }

      uint32_t surfCount = r.readUint32Big();
      model.surfaces.resize(surfCount);
      // Running offset of the current surface within the model's index list.
      uint32_t curIdx = 0;
      for (uint32_t j = 0; j < surfCount; ++j) {
        Model::Surface& surf = model.surfaces[j];
        surf.first = curIdx;
        surf.count = r.readUint32Big();
        curIdx += surf.count;

        for (uint32_t k = 0; k < surf.count; ++k) {
          uint32_t idx = r.readUint32Big();
          model.idxs.push_back(idx);
        }

        surf.transparent = r.readBool();
      }
    }

    uint32_t entityCount = r.readUint32Big();
    m_entities.resize(entityCount);
    for (uint32_t i = 0; i < entityCount; ++i) {
      Entity& ent = m_entities[i];
      ent.entityId = r.readUint32Big();
      ent.aabb.min = r.readVec3fBig();
      ent.aabb.max = r.readVec3fBig();
    }

    uint32_t lightCount = r.readUint32Big();
    layer2LightCount = r.readUint32Big();
    m_lights.resize(lightCount);
    for (uint32_t i = 0; i < lightCount; ++i) {
      Light& light = m_lights[i];
      light.point = r.readVec3fBig();
    }
  }

  if (!SetupVertexBuffersAndFormats()) {
    m_return = 1;
    return;
  }

  VISIBuilder builder(*this);
  std::vector<uint8_t> dataOut =
      builder.build(m_totalAABB, m_models.size(), m_entities, m_lights, layer2LightCount, m_updatePercent, parentPid);
  // An empty result is treated as failure and nothing is written.
  if (dataOut.empty()) {
    m_return = 1;
    return;
  }

  athena::io::FileWriter w(m_argv[2]);
  w.writeUBytes(dataOut.data(), dataOut.size());
}

// Raises the m_terminate flag. The flag is not read anywhere in this file's
// visible code; presumably it is polled elsewhere to stop the run early -
// confirm against the header / builder.
void VISIRenderer::Terminate() {
  m_terminate = true;
}