import { getMipmapsShaders } from "../chunks/mipmap_shader";

// Based on https://github.com/toji/web-texture-tool/blob/main/src/webgpu-mipmap-generator.js
/**
 * Generates the mip chain of a WebGPU texture on the GPU by rendering each
 * mip level, sampled through the texture's sampler, into the next level.
 *
 * Render pipelines are cached per target texture format. The shader module,
 * bind group layout and pipeline layout do not depend on the format, so they
 * are created once (on first use) and shared by every cached pipeline.
 */
export class MipmapPipeline {
  #device;
  #shaderModule;
  #bindGroupLayout;
  #pipelineLayout;

  /** Render pipelines keyed by the GPUTextureFormat they render into. */
  #pipelines = new Map();

  /**
   * @param {GPUDevice} device - Device used to create all GPU objects and to submit work.
   */
  constructor(device) {
    this.#device = device;
  }

  /**
   * Creates the format-independent objects shared by all pipelines
   * (shader module, bind group layout, pipeline layout) if not created yet.
   */
  #ensureSharedResources() {
    if (this.#pipelineLayout) {
      return;
    }

    this.#shaderModule = this.#device.createShaderModule({
      label: 'Mipmap Generator',
      code: getMipmapsShaders()
    });

    // binding 0: sampler, binding 1: source mip level; both read in the fragment stage.
    this.#bindGroupLayout = this.#device.createBindGroupLayout({
      label: 'Mipmap Generator',
      entries: [{
        binding: 0,
        visibility: GPUShaderStage.FRAGMENT,
        sampler: {}
      }, {
        binding: 1,
        visibility: GPUShaderStage.FRAGMENT,
        texture: {}
      }]
    });

    this.#pipelineLayout = this.#device.createPipelineLayout({
      label: 'Mipmap Generator',
      bindGroupLayouts: [this.#bindGroupLayout]
    });
  }

  /**
   * Returns the cached render pipeline for `format`, creating it on first use.
   *
   * @param {string} format - GPUTextureFormat of the mip levels rendered to.
   * @returns {GPURenderPipeline}
   */
  #createPipeline(format) {
    const cached = this.#pipelines.get(format);
    if (cached) {
      return cached;
    }

    this.#ensureSharedResources();

    const pipeline = this.#device.createRenderPipeline({
      label: `Mipmap Generator (${format})`,
      layout: this.#pipelineLayout,
      vertex: {
        module: this.#shaderModule,
        entryPoint: 'vertexMain'
      },
      fragment: {
        module: this.#shaderModule,
        entryPoint: 'fragmentMain',
        targets: [{ format }]
      },
      primitive: {
        topology: 'triangle-strip',
        stripIndexFormat: 'uint32'
      }
    });

    this.#pipelines.set(format, pipeline);
    return pipeline;
  }

  /**
   * Fills mip levels 1..mipLevelCount-1 of `texture` starting from level 0.
   *
   * Level i is rendered by sampling level i-1, so the chain is built
   * progressively. All passes are recorded into a single command encoder
   * and submitted to the device queue once. Does nothing when there is no
   * level beyond the base level.
   *
   * @param {object} texture - Wrapper exposing `__gpuTexture` (GPUTexture) and `__gpuSampler` (GPUSampler).
   * @param {object} descriptor - Provides `format` and `mipLevelCount`; presumably the descriptor `texture` was created with.
   */
  generateMipmaps(texture, descriptor) {
    const { format, mipLevelCount } = descriptor;

    // Written as !(x > 1) so a missing/undefined mipLevelCount also bails out.
    if (!(mipLevelCount > 1)) {
      return;
    }

    const gpuTexture = texture.__gpuTexture;
    const pipeline = this.#createPipeline(format);
    const commandEncoder = this.#device.createCommandEncoder({ label: 'Mipmap Generator' });

    let srcView = gpuTexture.createView({
      baseMipLevel: 0,
      mipLevelCount: 1
    });

    for (let level = 1; level < mipLevelCount; ++level) {
      const dstView = gpuTexture.createView({
        baseMipLevel: level,
        mipLevelCount: 1
      });

      // A separate bind group per level ensures only the previous level is sampled.
      const bindGroup = this.#device.createBindGroup({
        layout: this.#bindGroupLayout,
        entries: [{
          binding: 0,
          resource: texture.__gpuSampler
        }, {
          binding: 1,
          resource: srcView
        }]
      });

      const passEncoder = commandEncoder.beginRenderPass({
        colorAttachments: [{
          view: dstView, // the level being generated is the render attachment
          // With loadOp 'clear' the attachment is cleared to `clearValue`;
          // the legacy `loadValue` member is no longer read by WebGPU.
          clearValue: [0, 0, 0, 0],
          loadOp: 'clear',
          storeOp: 'store'
        }]
      });

      passEncoder.setPipeline(pipeline);
      passEncoder.setBindGroup(0, bindGroup);
      // NOTE(review): 3 vertices with no vertex buffers — presumably the shader
      // builds one viewport-covering triangle from vertex_index; confirm in mipmap_shader.
      passEncoder.draw(3, 1, 0, 0);
      passEncoder.end();

      // The level just written is the source for the next one.
      srcView = dstView;
    }

    this.#device.queue.submit([commandEncoder.finish()]);
  }
}