import { VertexEnumerator } from "./VertexEnumerator";
import { LmvBox3 as Box3 } from "./LmvBox3";
import { LmvMatrix4 } from "./LmvMatrix4";
import { BVHBuilder } from "./BVHBuilder";
import { LmvVector3 } from "./LmvVector3";

// Module-level scratch objects. They are overwritten by the methods below on every
// use, so their contents are only meaningful between a write and the next read.

// Scratch objects used while collecting triangles and as a generic result holder.
const _tmpMtx = new LmvMatrix4();
const _tmpBox = new Box3();
const _tmpVec = new LmvVector3();

// Scratch vectors for the ray/triangle test: the two triangle edges leaving the
// first vertex, their cross product (face normal), and the ray origin relative
// to the first vertex.
const edge1 = new LmvVector3();
const edge2 = new LmvVector3();
const normal = new LmvVector3();
const diff = new LmvVector3();

// Scratch vector for the component-wise reciprocal of the ray direction.
const _invDir = new LmvVector3();

/**
 * Implements the FragInfo interface for BVHBuilder on top of a flat array of
 * per-triangle bounding boxes, so that every triangle is treated as its own entry.
 */
class ConsolidatedFragInfo {

  /**
   * @param {number[]} boxes - flat array of boxes, `boxStride` (6) numbers per entry:
   *   min x, min y, min z, max x, max y, max z.
   */
  constructor(boxes) {
    this.boxes = boxes;
    this.boxStride = 6;
    this.wantSort = false;
  }

  /**
   * @returns {number} number of entries. Derived from `boxStride` rather than a
   *   literal so the two can never disagree.
   */
  getCount() {
    return this.boxes.length / this.boxStride;
  }

  /** @returns {boolean} always false: every entry is treated as opaque. */
  isTransparent() {
    return false;
  }

  /** @returns {number} always 1: each entry is a single triangle. */
  getPolygonCount() {
    return 1;
  }

}


/**
 * Ray-picking acceleration structure over the room triangles of a model.
 *
 * build() gathers the room triangles into flat buffers (positions, normals, indices and
 * per-triangle ids) and builds a BVH over their bounding boxes with BVHBuilder.
 * rayIntersect() then traverses that BVH to find the nearest hit on a visible node.
 */
export class ConsolidatedBVH {

  // Per-triangle boxes, 6 numbers each (min xyz, max xyz); only kept until the BVH is built.
  #primBoxes;
  // Vertex data, 4 floats per vertex: x, y, z, u (u is 0 when absent).
  #primVerts;
  // Vertex data, 4 floats per vertex: normal x, y, z, v (all four are 0 when there is no normal).
  #primNorms;
  // Per-triangle id pairs: dbId, model id.
  #primIds;
  // Per-triangle vertex indices (3 per triangle) into #primVerts / #primNorms.
  #primInds;
  // Node array produced by BVHBuilder, and its float view holding the node boxes.
  #bvhNodes;
  #bvhNodesF;
  // Triangle indices referenced by the BVH leaves (leaf prim ranges index into this array).
  #bvhChildren;
  // Reusable traversal stack; a plain array, so it grows if more than 32 entries are needed.
  #nodeStack = new Array(32);

  #model;

  /** @param {DtModel} model */
  constructor(model) {
    this.#model = model;
    this.#primBoxes = [];
    this.#primVerts = [];
    this.#primNorms = [];
    this.#primInds = [];
    this.#primIds = [];
  }

  /**
   * Collects the room triangles of the model and builds the BVH over them.
   *
   * Afterwards the per-triangle data is compacted into typed arrays and the temporary
   * per-triangle boxes are released. Calling build() again rebuilds from scratch.
   */
  build() {

    // Start from empty accumulators: after a previous build() they are typed arrays
    // (or null), which cannot be appended to.
    this.#primBoxes = [];
    this.#primVerts = [];
    this.#primNorms = [];
    this.#primInds = [];
    this.#primIds = [];

    this.#collectRoomPolygons();

    const builder = new BVHBuilder(null, null, new ConsolidatedFragInfo(this.#primBoxes));

    builder.build({ useSlimNodes: true });

    this.#bvhNodes = builder.nodes;
    this.#bvhNodesF = this.#bvhNodes.nodesF;

    this.#bvhChildren = builder.primitives;

    // The per-triangle boxes are only needed by the builder.
    this.#primBoxes = null;

    // Compact the growable arrays into typed arrays (the constructors copy the contents).
    this.#primVerts = new Float32Array(this.#primVerts);
    this.#primNorms = new Float32Array(this.#primNorms);
    this.#primInds = new Int32Array(this.#primInds);
    this.#primIds = new Int32Array(this.#primIds);
  }

  /**
   * Finds the nearest triangle hit by a ray whose dbId is neither hidden nor off.
   *
   * Must be called after build().
   *
   * @param {Object} ray - object with `origin` (x/y/z) and `direction` (x/y/z, with `dot()`)
   *   vectors. The returned distance is the ray parameter t, so it is a metric distance
   *   only if `direction` is normalized.
   * @param {Object} [options]
   * @param {number} [options.maxDistance=Infinity] - hits at or beyond this parameter are ignored.
   * @returns {Array} `[dbId, distance]` of the nearest qualifying hit, or
   *   `[-1, maxDistance]` if there is none.
   */
  rayIntersect(ray, options) {
    // TODO prevent ray-casting against fragments hidden by cut-planes.
    const it = this.#model.getInstanceTree();
    const isHiddenOrOff = (dbId) => it.isNodeHidden(dbId) || it.isNodeOff(dbId);
    const nodes = this.#bvhNodes;
    const nodeStack = this.#nodeStack;

    // No transparent geometry is collected, so the transparent root (index 1) does not
    // need to be traversed; start from the opaque root (index 0) only.
    nodeStack[0] = 0;
    let stackTop = 1;

    let distance = options?.maxDistance ?? Infinity; // ray parameter of the nearest hit so far
    let dbId = -1;

    _invDir.set(1.0 / ray.direction.x, 1.0 / ray.direction.y, 1.0 / ray.direction.z);

    while (stackTop) {
      const nodeIdx = nodeStack[--stackTop];

      // Skip the node if the ray misses its box, or if the ray only enters the box
      // beyond the nearest hit found so far (nothing inside it can be closer).
      const xBoxPt = this.#rayIntersectBox(ray, _invDir, nodeIdx, _tmpVec);
      if (xBoxPt === null || xBoxPt.x > distance) continue;

      const primCount = nodes.getPrimCount(nodeIdx);

      if (primCount === 0) {
        // Inner node -- push both children (stored consecutively) on the stack.
        const child = nodes.getLeftChild(nodeIdx);
        if (child !== -1) {
          nodeStack[stackTop++] = child;
          nodeStack[stackTop++] = child + 1;
        }

        continue;
      }

      // Leaf node -- intersect its triangles with the ray.
      const primStart = nodes.getPrimStart(nodeIdx);
      for (let i = primStart, iEnd = primStart + primCount; i < iEnd; i++) {

        const primId = this.#bvhChildren[i];
        const xPt = this.#rayIntersectTriangle(ray, primId, false, _tmpVec);

        if (xPt && xPt.x < distance) {
          const hitDbId = this.#primIds[2 * primId];
          if (!isHiddenOrOff(hitDbId)) {
            distance = xPt.x;
            dbId = hitDbId;
          }
        }
      }
    }

    return [dbId, distance];
  }

  /**
   * Appends the triangles of fragment `fragId` to the triangle buffers.
   *
   * Line geometry is skipped; a fragment without geometry is skipped with a warning.
   * Vertices are enumerated with the fragment's world matrix. Each triangle also gets
   * an axis-aligned box (input for the BVH builder) and a `[dbId, model.id]` id pair.
   *
   * @param {DtModel} model - model owning the fragment.
   * @param {number} fragId - fragment whose geometry is added.
   * @param {number} dbId - id recorded for every triangle of this fragment.
   */
  #addGeomToBvh(model, fragId, dbId) {

    const fl = model.getFragmentList();
    const geom = fl.getGeometry(fragId);

    if (!geom) {
      console.warn("room without geom");
      return;
    }

    if (geom.isLines) {
      return;
    }

    fl.getWorldMatrix(fragId, _tmpMtx);

    // The enumerated indices are offset by the number of vertices already collected,
    // presumably because the enumerator reports indices local to this geometry.
    const baseIndex = this.#primVerts.length / 4;

    VertexEnumerator.enumMeshVertices(geom, (p, n, uv) => {

      this.#primVerts.push(p.x, p.y, p.z, uv?.x || 0.0);

      if (n) {
        this.#primNorms.push(n.x, n.y, n.z, uv?.y || 0.0);
      } else {
        this.#primNorms.push(0, 0, 0, 0);
      }

    }, _tmpMtx);

    VertexEnumerator.enumMeshIndices(geom, (a, b, c) => {

      const i0 = a + baseIndex;
      const i1 = b + baseIndex;
      const i2 = c + baseIndex;

      this.#primInds.push(i0, i1, i2);

      // Axis-aligned box of the triangle's three vertices.
      _tmpBox.makeEmpty();

      _tmpVec.set(this.#primVerts[4 * i0], this.#primVerts[4 * i0 + 1], this.#primVerts[4 * i0 + 2]);
      _tmpBox.expandByPoint(_tmpVec);

      _tmpVec.set(this.#primVerts[4 * i1], this.#primVerts[4 * i1 + 1], this.#primVerts[4 * i1 + 2]);
      _tmpBox.expandByPoint(_tmpVec);

      _tmpVec.set(this.#primVerts[4 * i2], this.#primVerts[4 * i2 + 1], this.#primVerts[4 * i2 + 2]);
      _tmpBox.expandByPoint(_tmpVec);

      this.#primBoxes.push(_tmpBox.min.x, _tmpBox.min.y, _tmpBox.min.z, _tmpBox.max.x, _tmpBox.max.y, _tmpBox.max.z);

      // Remembering the model ID per triangle is currently redundant, because ConsolidatedBVH is used per model.
      // However, the code is capable of combining triangles from multiple models, which I expect we will do soon.
      this.#primIds.push(dbId, model.id);

    });
  }

  /**
   * Collects the triangles of every room of the model into the triangle buffers.
   * This collector is specific to rooms; similar logic can gather the polygons of any
   * subset of the model.
   */
  #collectRoomPolygons() {

    const roomMap = this.#model.getRooms();
    const it = this.#model.getInstanceTree();
    const seen = new Set();

    for (const dbIdStr in roomMap) {
      const roomEntry = roomMap[dbIdStr];

      // Each room occurs twice in the roomMap, under its dbId and long row ID.
      if (seen.has(roomEntry.dbId)) {
        continue;
      }
      seen.add(roomEntry.dbId);

      // Skip duplicate room geoms -- rooms from Revit come with two fragments each, with
      // duplicate but reverse-facing geometry. Only the first fragment is used, assuming no split meshes.
      let firstId;
      it.enumNodeFragments(roomEntry.dbId, (fragId) => {
        firstId = fragId;
        return true; // Early exit
      }, false);

      // A room without any fragment has no geometry to add.
      if (firstId === undefined) {
        continue;
      }

      this.#addGeomToBvh(this.#model, firstId, roomEntry.dbId);
    }

  }

  /**
   * Slab test of a ray against the box of BVH node `nodeIdx` (adapted from THREE.Ray.intersectBox).
   *
   * Node boxes are read from the float view of the node array: 8 floats per node, with
   * min x/y/z at offsets 0-2 and max x/y/z at offsets 3-5.
   *
   * @param {Object} ray - object with `origin` and `direction` vectors.
   * @param {LmvVector3} invDir - component-wise reciprocal of `ray.direction`.
   * @param {number} nodeIdx - BVH node index.
   * @param {LmvVector3} res - receives the entry parameter in `res.x`.
   * @returns {LmvVector3|null} `res` with `res.x` set to the ray parameter at which the ray
   *   enters the box, clamped to 0 when the origin is inside the box. Returns null if the ray
   *   misses the box or the box lies entirely behind the origin.
   */
  #rayIntersectBox(ray, invDir, nodeIdx, res) {

    let tmin, tmax, tymin, tymax, tzmin, tzmax;

    const boxPtr = this.#bvhNodesF;
    const boff = nodeIdx << 3;

    const invdirx = invDir.x;
    const invdiry = invDir.y;
    const invdirz = invDir.z;

    const origin = ray.origin;

    if (invdirx >= 0) {
      tmin = (boxPtr[boff] - origin.x) * invdirx;
      tmax = (boxPtr[boff + 3] - origin.x) * invdirx;
    } else {
      tmin = (boxPtr[boff + 3] - origin.x) * invdirx;
      tmax = (boxPtr[boff] - origin.x) * invdirx;
    }

    if (invdiry >= 0) {
      tymin = (boxPtr[boff + 1] - origin.y) * invdiry;
      tymax = (boxPtr[boff + 4] - origin.y) * invdiry;
    } else {
      tymin = (boxPtr[boff + 4] - origin.y) * invdiry;
      tymax = (boxPtr[boff + 1] - origin.y) * invdiry;
    }

    if (tmin > tymax || tymin > tmax) return null;

    // These lines also handle the case where tmin or tmax is NaN
    // (result of 0 * Infinity). x !== x returns true if x is NaN.

    if (tymin > tmin || tmin !== tmin) tmin = tymin;

    if (tymax < tmax || tmax !== tmax) tmax = tymax;

    if (invdirz >= 0) {
      tzmin = (boxPtr[boff + 2] - origin.z) * invdirz;
      tzmax = (boxPtr[boff + 5] - origin.z) * invdirz;
    } else {
      tzmin = (boxPtr[boff + 5] - origin.z) * invdirz;
      tzmax = (boxPtr[boff + 2] - origin.z) * invdirz;
    }

    if (tmin > tzmax || tzmin > tmax) return null;

    if (tzmin > tmin || tmin !== tmin) tmin = tzmin;

    if (tzmax < tmax || tmax !== tmax) tmax = tzmax;

    // Box entirely behind the ray origin.
    if (tmax < 0) return null;

    // The caller prunes nodes with this value, so it must be a lower bound on the parameter of
    // any hit inside the box. THREE returns the exit parameter (tmax) when the origin is inside
    // the box; that would wrongly prune boxes containing hits nearer than their exit. Use 0
    // instead. A NaN tmin also falls through to 0, which never prunes.
    res.x = tmin >= 0 ? tmin : 0;
    return res;
  }

  /**
   * Ray/triangle test for triangle `iPrim` (adapted from THREE.Ray.intersectTriangle).
   *
   * Overwrites the module-level scratch vectors edge1, edge2, normal and diff.
   *
   * @param {Object} ray - object with `origin` and `direction` vectors.
   * @param {number} iPrim - triangle index into the per-triangle index buffer.
   * @param {boolean} backfaceCulling - if true, triangles whose normal (edge1 x edge2)
   *   points along the ray direction are ignored.
   * @param {LmvVector3} result - receives the hit's ray parameter t in `result.x`.
   * @returns {LmvVector3|null} `result`, or null if the ray does not hit the triangle
   *   at a parameter t >= 0.
   */
  #rayIntersectTriangle(ray, iPrim, backfaceCulling, result) {

    const i0 = this.#primInds[3 * iPrim];
    const i1 = this.#primInds[3 * iPrim + 1];
    const i2 = this.#primInds[3 * iPrim + 2];

    // vB - vA
    edge1.set(
      this.#primVerts[4 * i1] - this.#primVerts[4 * i0],
      this.#primVerts[4 * i1 + 1] - this.#primVerts[4 * i0 + 1],
      this.#primVerts[4 * i1 + 2] - this.#primVerts[4 * i0 + 2]
    );

    // vC - vA
    edge2.set(
      this.#primVerts[4 * i2] - this.#primVerts[4 * i0],
      this.#primVerts[4 * i2 + 1] - this.#primVerts[4 * i0 + 1],
      this.#primVerts[4 * i2 + 2] - this.#primVerts[4 * i0 + 2]
    );

    normal.crossVectors(edge1, edge2);

    // Solve Q + t*D = b1*E1 + b2*E2 (Q = kDiff, D = ray direction,
    // E1 = kEdge1, E2 = kEdge2, N = Cross(E1,E2)) by
    //   |Dot(D,N)|*b1 = sign(Dot(D,N))*Dot(D,Cross(Q,E2))
    //   |Dot(D,N)|*b2 = sign(Dot(D,N))*Dot(D,Cross(E1,Q))
    //   |Dot(D,N)|*t = -sign(Dot(D,N))*Dot(Q,N)
    let DdN = ray.direction.dot(normal);
    let sign;

    if (DdN > 0) {
      // Back face: the normal points along the ray direction.
      if (backfaceCulling) return null;
      sign = 1;
    } else if (DdN < 0) {
      sign = -1;
      DdN = -DdN;
    } else {
      // Ray parallel to the triangle plane (or degenerate triangle).
      return null;
    }

    // ray.origin - vA
    diff.set(
      ray.origin.x - this.#primVerts[4 * i0],
      ray.origin.y - this.#primVerts[4 * i0 + 1],
      ray.origin.z - this.#primVerts[4 * i0 + 2]
    );

    const DdQxE2 = sign * ray.direction.dot(edge2.crossVectors(diff, edge2));

    // b1 < 0, no intersection
    if (DdQxE2 < 0) {
      return null;
    }

    const DdE1xQ = sign * ray.direction.dot(edge1.cross(diff));

    // b2 < 0, no intersection
    if (DdE1xQ < 0) {
      return null;
    }

    // b1+b2 > 1, no intersection
    if (DdQxE2 + DdE1xQ > DdN) {
      return null;
    }

    // Line intersects triangle, check if ray does.
    const QdN = -sign * diff.dot(normal);

    // t < 0, no intersection
    if (QdN < 0) {
      return null;
    }

    // Ray intersects triangle; only the ray parameter of the hit is reported.
    result.x = QdN / DdN;
    return result;
  }
}