import {wgsl} from "../../wgsl-preprocessor/wgsl-preprocessor";
import vertexTextureQuad from "../post/quad.vert.wgsl";
import { UniformGroup } from "./UniformGroup.ts";
import pack_depth from "../chunks/pack_depth.wgsl";

// Fallback fragment shader: visualizes the interpolated UV coordinates by writing
// them to the red/green channels (blue = 0, alpha = 1). It declares no bindings.
// ShaderPassWebGPU.setFragmentShader() falls back to it when no shader code is passed in.
export const DummyFragmentShader = wgsl`
    @fragment
    fn main(
        @builtin(position) coord : vec4f,
        @location(0) @interpolate(linear) vUv: vec2f
    ) -> @location(0) vec4f {
        var outColor = vec4f(vUv, 0.0, 1.0);
        return outColor;
    }
`;

// Trivial fragment shader that samples a source texture at the interpolated UV
// and writes all four channels back unchanged.
// Expects the texture at @group(0) @binding(1) and its sampler at @group(0) @binding(2),
// which matches the default arguments of ShaderPassWebGPU.setSourceTextureUniform().
export const CopyFragmentShader = wgsl`
    @group(0) @binding(1) var texture: texture_2d<f32>;
    @group(0) @binding(2) var texSampler: sampler;

    @fragment
    fn main(
      @builtin(position) coord: vec4f,
      @location(0) @interpolate(linear) vUv: vec2f
    ) -> @location(0) vec4f {
        var outColor = textureSample(texture, texSampler, vUv);
        return outColor;
    }
`;

// Converts an RGB10-packed depth texture into a plain depth value:
// samples the packed texel, passes its rgb channels to unpackDepth10() and writes
// the result to the red channel of the output (green = blue = 0, alpha = 1).
// unpackDepth10 is not defined in this file; presumably it is provided by the
// included pack_depth chunk.
// Expects the texture at @group(0) @binding(1) and its sampler at @group(0) @binding(2).
export const UnpackRGB10DepthShader = wgsl`
    ${pack_depth}

    @group(0) @binding(1) var texture: texture_2d<f32>;
    @group(0) @binding(2) var texSampler: sampler;

    @fragment
    fn main(
        @builtin(position) coord: vec4f,
        @location(0) @interpolate(linear) vUv: vec2f
    ) -> @location(0) vec4f {

        var packedDepth = textureSample(texture, texSampler, vUv);
        var depth       = unpackDepth10(packedDepth.rgb);
        return vec4f(depth, 0, 0, 1);
    }
`;

// Converts an RGB10-packed normals texture into unpacked normals:
// samples the packed texel, passes its xyz channels to unpackNormal10() and writes
// the resulting normal to the first three output channels, with the fourth set to 1.0.
// (Unlike the depth variant, the output is NOT single-channel.)
// unpackNormal10 is not defined in this file; presumably it is provided by the
// included pack_depth chunk.
// Expects the texture at @group(0) @binding(1) and its sampler at @group(0) @binding(2).
export const UnpackRGB10NormalsShader = wgsl`
    // Note that pack_normals chunk is only needed for int-packed vertex-normals.
    // The rgb10-packed normals need unpackNormals10 instead.
    ${pack_depth}

    @group(0) @binding(1) var texture: texture_2d<f32>;
    @group(0) @binding(2) var texSampler: sampler;

    @fragment
    fn main(
        @builtin(position) coord: vec4f,
        @location(0) @interpolate(linear) vUv: vec2f
    ) -> @location(0) vec4f {

        var packedNormal = textureSample(texture, texSampler, vUv);
        var normal       = unpackNormal10(packedNormal.xyz);
          return vec4f(normal, 1.0f);
    }
`;

// Bind-group index that is used whenever a caller does not specify one.
// By default, all uniforms and textures of a pass live in bindGroup 0.
const DefaultBindGroup = 0;

// type BindGroupIndex = number;
// type BindingIndex   = number;

// Helper to render a fullscreen pass into a given target texture.
//
// Lifecycle:
//  - Configure shaders (setFragmentShader/setVertexShader) and uniforms (setUniforms, setSourceTextureUniform).
//  - Call run(dstTexture, srcTexture). GPU resources (shader modules, pipeline, target view)
//    are created lazily by init() and are reused as long as configuration and target don't change.
export class ShaderPassWebGPU {

    #device;         // GPUDevice;

    #pipeline;       // GPURenderPipeline; null/undefined until init() has built it
    #passDescriptor; // GPURenderPassDescriptor;
    #vShader;        // GPUShaderModule;
    #fShader;        // GPUShaderModule;

    // target format for which pipeline was initialized
    #format;        // GPUTextureFormat;
    #targetView;    // GPUTextureView;
    #targetTexture; // GPUTexture

    // [BindGroupBuilder] - indexed by bindGroup.
    // May be sparse: indices below the highest used bindGroup may have no entry.
    #uniforms = [];

    // Dummy shader is used if this is empty
    #fragmentShaderCode;
    #fragmentShaderEntryPoint;
    #vertexShaderCode;
    #vertexShaderEntryPoint;

    // If run() is called with a srcTexture with input from a previous pass,
    // we need binding and bindGroup of the uniform where this texture is bound to.
    // This is mandatory, e.g., when using the pass as a userFinalPass.
    #srcTextureBinding   = 0; // BindingIndex
    #srcSamplerBinding   = 0; // BindingIndex
    #srcTextureBindGroup = 0; // BindGroupIndex used for sampler and texture

    // Used to allow tolerance for null-texture uniforms.
    // Only stored here; this class itself never reads it.
    #defaultTexture; // GPUTexture;

    // Placeholders for bindGroup indices that have no uniforms assigned (holes in #uniforms).
    // Created lazily, only if such a hole actually exists.
    #emptyBindGroupLayout; // GPUBindGroupLayout
    #emptyBindGroup;       // GPUBindGroup

    /*
     * @param device         - GPUDevice used to create all GPU resources
     * @param defaultTexture - GPUTexture kept as fallback for null-texture uniforms
     * @param bindGroup      - BindGroupIndex at which the initial (empty) uniform group is created
     */
    constructor(
        device,
        defaultTexture,
        bindGroup = DefaultBindGroup // by default, we create uniforms at bindGroup 0
    ) {
        this.#device = device;
        this.#defaultTexture = defaultTexture;

        // Set default shaders
        this.setFragmentShader();
        this.setVertexShader();

        // Init default uniforms
        const uniform = new UniformGroup();
        this.setUniforms(uniform, bindGroup);
    }

    // The default shader just colorizes the UV coordinates.
    // Changing the shader discards the compiled module and the pipeline; both are rebuilt by the next init().
    setFragmentShader(fragmentShaderCode = DummyFragmentShader, entryPoint = "main") {
        this.#fragmentShaderCode       = fragmentShaderCode;
        this.#fragmentShaderEntryPoint = entryPoint;

        // discard prior shader & pipeline
        this.#fShader  = null;
        this.#pipeline = null;
    }

    // Default shader just renders a fullscreen triangle and y-flipped uv coords.
    //
    // Why yFlip?
    //    By default, we include an y-flip of uv coords in the vertex shader.
    //    This is necessary to preserve the image orientation when rendering from one texture into another one.
    //    Reason is that in WebGPU...
    //     - The y-axis in uv-space points down, but
    //     - In clip space, y-axis points up.
    setVertexShader(vertexShaderCode = vertexTextureQuad, entryPoint = "mainFlipY") {
        this.#vertexShaderCode       = vertexShaderCode;
        this.#vertexShaderEntryPoint = entryPoint;

        // discard prior shader & pipeline
        this.#vShader  = null;
        this.#pipeline = null;
    }

    /*
     * Only relevant if a srcTexture is passed in the run(...) method. This happens, e.g., when using this
     * pass as a post-processing step (userFinalPass).
     * In this case, this function specifies the bindings where the srcTexture and its sampler are bound to.
     *
     * Missing pieces are added automatically: the uniform group of the bindGroup (if there is none yet)
     * as well as the texture and sampler entries (if the bindings are not assigned yet).
     *
     * @param {BindingIndex}   [textureBinding=1] - We leave 0 free for a UniformBuffer
     * @param {BindingIndex}   [samplerBinding=2]
     * @param {BindGroupIndex} [bindGroup=0]
     */
    setSourceTextureUniform(textureBinding = 1, samplerBinding = 2, bindGroup = 0) {

        this.#srcTextureBinding   = textureBinding;
        this.#srcSamplerBinding   = samplerBinding;
        this.#srcTextureBindGroup = bindGroup;

        // Create the uniform group on demand, so that any bindGroup index can be used here.
        let group = this.getUniforms(bindGroup);
        if (!group) {
            group = new UniformGroup();
            this.setUniforms(group, bindGroup);
        }

        // Add uniforms automatically if not existing already
        let bindingsAdded = false;
        if (!group.isAssigned(textureBinding)) {
            group.addTexture(textureBinding);
            bindingsAdded = true;
        }

        // Same for sampler
        if (!group.isAssigned(samplerBinding)) {
            group.addSampler(samplerBinding);
            bindingsAdded = true;
        }

        // New entries change the bind group contents, so a pipeline built before is outdated.
        if (bindingsAdded) {
            this.#pipeline = null;
        }
    }

    /*
     * Assigns the uniform group for a bindGroup index.
     * Replacing a group discards the pipeline, because its layout is derived from the uniform groups.
     *
     * @param uniforms:  BindGroupBuilder
     * @param bindGroup: BindGroupIndex
     */
    setUniforms(uniforms, bindGroup = DefaultBindGroup) {
        if (this.#uniforms[bindGroup] !== uniforms) {
            this.#uniforms[bindGroup] = uniforms;
            this.#pipeline = null;
        }
    }

    // Returns the uniform group of the given bindGroup index (undefined if none was assigned).
    getUniforms(bindGroup = DefaultBindGroup) {
        return this.#uniforms[bindGroup];
    }

    // convenience for the default case that UniformBuffer is on bindGroup 0
    get uniforms() {
        return this.getUniforms();
    }

    // Empty layout used for bindGroup indices without uniforms. Created on first use.
    #getEmptyBindGroupLayout() {
        if (!this.#emptyBindGroupLayout) {
            this.#emptyBindGroupLayout = this.#device.createBindGroupLayout({ entries: [] });
        }
        return this.#emptyBindGroupLayout;
    }

    // Empty bind group matching #getEmptyBindGroupLayout(). Created on first use.
    #getEmptyBindGroup() {
        if (!this.#emptyBindGroup) {
            this.#emptyBindGroup = this.#device.createBindGroup({
                layout:  this.#getEmptyBindGroupLayout(),
                entries: []
            });
        }
        return this.#emptyBindGroup;
    }

    // (Re)creates the render pipeline for the current shaders, uniforms and target format.
    // Expects that shader modules and #format were set before, as done by init().
    createPipeline() {

        // create pipeline layout
        let layout = "auto";

        // Collect bindGroup layouts from uniforms
        if (this.#uniforms.length > 0) {
            // Array.from() also visits holes (as undefined). Those indices get an empty layout,
            // so that the layout array has a valid entry for every index.
            const bindGroupLayouts = Array.from(
                this.#uniforms,
                (uniforms) => (uniforms ? uniforms.layout : this.#getEmptyBindGroupLayout())
            );

            layout = this.#device.createPipelineLayout({ bindGroupLayouts });
        }

        this.#pipeline = this.#device.createRenderPipeline({
            layout,
            vertex: {
                module: this.#vShader,
                entryPoint: this.#vertexShaderEntryPoint
            },
            fragment: {
                module: this.#fShader,
                entryPoint: this.#fragmentShaderEntryPoint,
                targets: [
                    {
                        format: this.#format
                    }
                ]
            },
            primitive: {
                topology: 'triangle-list',
                cullMode: 'back',
            }
        });
    }

    /*
     * Makes sure that all GPU resources are ready for rendering into dstTexture.
     * Does nothing if no target is known yet, or if the pass is already initialized for this target.
     *
     * @param dstTexture - GPUTexture to render into. If omitted, the previously used target is reused.
     */
    init(dstTexture) {

        // Refresh pipeline with old texture if no new one is specified
        const target = dstTexture || this.#targetTexture;

        // Without knowing the format, we cannot create the pipeline yet.
        if (!target) {
            return;
        }

        // Already initialized
        if (this.#pipeline && this.#targetTexture === target) {
            return;
        }

        // Make sure layout and bindGroup are ready for all uniforms (skipping unassigned indices).
        for (const uniforms of this.#uniforms) {
            if (uniforms) {
                uniforms.update(this.#device);
            }
        }

        // texture & view
        this.#targetView    = target.createView();
        this.#targetTexture = target;

        // create shaders
        if (!this.#vShader) {
            this.#vShader = this.#device.createShaderModule({ code: this.#vertexShaderCode });
        }

        if (!this.#fShader) {
            this.#fShader = this.#device.createShaderModule({ code: this.#fragmentShaderCode });
        }

        // make sure pipeline is setup for the target format
        if (!this.#pipeline || this.#format !== target.format) {
            this.#format = target.format;
            this.createPipeline();
        }

        if (!this.#passDescriptor) {
            this.#passDescriptor = {
                colorAttachments: [
                    {
                        // view is acquired and set in render loop.
                        view: undefined,

                        // by default, we use full-screen pass without blending, so the clearValue isn't relevant
                        clearValue: { r: 0.0, g: 0.0, b: 0.0, a: 1.0 },

                        loadOp:  'clear',
                        storeOp: 'store',
                    },
                ],
            };
        }
    }

    /*
     * Renders the pass into the given target and submits the work to the device queue.
     * Does nothing if the pass has no device.
     *
     * @param dstTexture - GPUTexture to render into. If omitted, the target of the previous run/init is reused.
     * @param srcTexture - Optional GPUTexture that is bound as source texture
     *                     (see setSourceTextureUniform for the binding that is used).
     * @throws {Error} If there is no target to render into, or if srcTexture is given but the
     *                 source bindGroup has no uniforms.
     */
    run(dstTexture, srcTexture) {

        if (!this.#device) {
            return;
        }

        // Set srcTexture uniform if changed.
        if (srcTexture) {
            const uniforms = this.getUniforms(this.#srcTextureBindGroup);
            if (!uniforms) {
                throw new Error(
                    "ShaderPassWebGPU.run(): no uniforms at bindGroup " + this.#srcTextureBindGroup +
                    " to bind srcTexture to. Call setSourceTextureUniform() first."
                );
            }
            uniforms.setTexture(srcTexture, this.#srcTextureBinding);
        }

        // NOTE(review): init() returns early if pipeline and target are unchanged, so uniforms.update()
        // is not called on such runs. Presumably the uniform group keeps its bindGroup up-to-date
        // on its own after setTexture() - confirm against UniformGroup.
        this.init(dstTexture);

        // init() cannot build a pipeline as long as no target texture is known.
        if (!this.#pipeline) {
            throw new Error(
                "ShaderPassWebGPU.run(): no target texture to render into. " +
                "Pass a dstTexture or call init(dstTexture) first."
            );
        }

        const commandEncoder = this.#device.createCommandEncoder();

        this.#passDescriptor.colorAttachments[0].view = this.#targetView;

        const pass = commandEncoder.beginRenderPass(this.#passDescriptor);

        // Activate uniforms (if any). Indices without uniforms get the empty bind group
        // that matches the empty layout used by createPipeline().
        for (let bg = 0; bg < this.#uniforms.length; bg++) {
            const uniforms = this.#uniforms[bg];
            pass.setBindGroup(bg, uniforms ? uniforms.bindGroup : this.#getEmptyBindGroup());
        }

        pass.setPipeline(this.#pipeline);
        pass.draw(3); // single fullscreen triangle
        pass.end();

        this.#device.queue.submit([commandEncoder.finish()]);
    }
}

// One-shot conversion helper. Builds a temporary ShaderPassWebGPU that
//  - binds srcTexture as source texture (using the default texture/sampler bindings),
//  - runs fragmentShaderCode on it,
//  - renders the result into dstTexture (width and height must match with srcTexture).
export function convertTexture(device, srcTexture, dstTexture, fragmentShaderCode) {
    const conversionPass = new ShaderPassWebGPU(device);

    // Register the source texture + sampler slots, then install the conversion shader.
    conversionPass.setSourceTextureUniform();
    conversionPass.setFragmentShader(fragmentShaderCode);

    // Note the argument order of run(): target first, source second.
    conversionPass.run(dstTexture, srcTexture);
}

// Copies a texture by rendering it into a newly created one with the copy shader.
// Can help to read textures without 'COPY_SRC' flag: the returned texture has the same
// width, height and format as srcTexture, but is created with the given usage flags.
export function copyGPUTexture(
    device,
    srcTexture, // GPUTexture
    usage = GPUTextureUsage.RENDER_ATTACHMENT | GPUTextureUsage.TEXTURE_BINDING | GPUTextureUsage.COPY_SRC
) {
    const { width, height, format } = srcTexture;

    // Target with identical size/format, but with the requested usage.
    const copy = device.createTexture({ size: [width, height], format, usage });

    // Use shader pass to transfer the texels
    convertTexture(device, srcTexture, copy, CopyFragmentShader);

    return copy;
}
