import { MipmapPipeline } from "./MipmapPipeline.js";
import { initTexture, createSamplerDesc } from "../Texture.js";
import { getBufferLayout, getPipelineHash } from '../Pipelines.js';
import { UniformGroup } from '../utils/UniformGroup.ts';
import { $wgsl } from "../../wgsl-preprocessor/wgsl-preprocessor.js";
import { FrameBindGroup } from "./FrameBindGroup.js";
import { getObjectUniformsDeclaration } from "./uniforms/ObjectUniforms.js";
import { sideToCullMode } from "../compat.js";
import { DepthFormat } from "../CommonRenderTargets.js";
import { BufferGeometryUtils } from "../../scene/BufferGeometry.js";


/*
How to extend a THREE.ShaderMaterial for WebGPU

MainPass will catch the material and render it with CustomShaderPipeline.
The pipeline always binds the frame bind group (group 0) and the object uniforms (group 1).
Custom materials use bind group 2.
All uniforms of simple types (i, f, v2, v3, v4, c, m3, m4) are packed into a single uniform buffer at binding 0.
The uniform buffer struct in the shader must list these uniforms in the same order as they appear in material.uniforms.
The shader entry points are fixed: vsmain (vertex stage) and psmain (fragment stage).

Before the material is added to the scene, attach the following information to it.
Here is the expected shape, with types:

Use samplerBinding in a textureBindings entry to create a sampler from the settings of the texture uniform.
Use samplers to define your own samplers.
Each textureBindings entry binds the texture stored in the 't' uniform named by uniformKey.
Defaults for omitted fields are inserted in UniformGroup.ts.

material.webGPU = {
    labelPrefix: string,
    shader: string,
    uniformBufferVisibility?: GPUShaderStageFlags,
    uniformBufferBindingLayout?: GPUBufferBindingLayout,
    textureBindings?: [
        {
            uniformKey: string,
            binding: number,
            textureVisibility?: GPUShaderStageFlags,
            sampleType?: GPUTextureSampleType,
            samplerBinding?: number,
            samplerVisibility?: GPUShaderStageFlags
        }
    ],
    samplers?: [
        {
            binding: number,
            visibility?: GPUShaderStageFlags,
            descriptor?: GPUSamplerDescriptor
        }
    ]
}
*/


// Bind group index reserved for the per-material UniformGroup. Groups 0 and 1 hold the
// frame and object uniform declarations injected by getCustomShader below.
const CustomMaterialBindGroup = 2;


/**
 * Expands the material's WGSL source with $wgsl. The expansion provides the frame bind
 * group declaration (group 0) and the object uniforms declaration (group 1).
 * @param material - Material carrying the WGSL source in `webGPU.shader`.
 * @returns The $wgsl result, used as the shader module code.
 */
function getCustomShader(material) {
    const { shader } = material.webGPU;
    const declarations = {
        frameBindGroup: FrameBindGroup.getDeclaration(0),
        objectUniforms: getObjectUniformsDeclaration(1),
    };
    return $wgsl(shader, declarations);
}


function getWebGPUUniformType(threeMaterialUniformType) {
    switch (threeMaterialUniformType) {
        case 'i': return 'i32';
        case 'f': return 'f32';
        case 'v2': return 'vec2f';
        case 'v3': return 'vec3f';
        case 'v4': return 'vec4f';
        case 'c': return 'c';
        case 'm3': return 'mat3x3f';
        case 'm4': return 'mat4x4f';
    }
    return null;
}


/**
 * Collects the buffer-backed uniforms of a THREE.ShaderMaterial in the format UniformGroup expects.
 * Uniforms whose type has no buffer representation (e.g. 't' textures) are left out. The
 * remaining keys keep the iteration order of `material.uniforms`, which the shader's
 * uniform struct must match.
 * @param {THREE.ShaderMaterial} material - The material to extract the uniform buffer info from.
 * @returns {{uniformBufferLayout: Object<string, string>, uniformBufferValues: Object<string, *>}}
 *   The type of each kept uniform, keyed by uniform name, and its current `value` (objects by
 *   reference, primitives by value).
 */
function extractUniformBufferInfo(material) {
    const uniformBufferLayout = {};
    const uniformBufferValues = {};
    const uniforms = material.uniforms;

    for (const name in uniforms) {
        const uniform = uniforms[name];
        const gpuType = getWebGPUUniformType(uniform.type);
        if (!gpuType) {
            continue;
        }
        uniformBufferLayout[name] = gpuType;
        uniformBufferValues[name] = uniform.value;
    }

    return { uniformBufferLayout, uniformBufferValues };
}


/**
 * Draws objects whose material supplies its own WGSL shader through `material.webGPU`.
 *
 * - Render pipelines are created lazily. They are cached by the key returned from the
 *   pipeline hash function given to drawOne().
 * - Each material gets its own UniformGroup, stored on `material.__gpuCustomMaterialUniforms`.
 * - That UniformGroup is bound at bind group CustomMaterialBindGroup.
 */
export class CustomShaderPipeline {

    #renderer;
    #device;

    #pipelines = new Map();
    #mipmapPipelines = new Map();

    #activePipeline;
    #currentMaterial;
    #activeBindGroupLayout;
    #activeTargetsList;
    #vb;

    /**
     * @param renderer - Must provide getDevice(), getVB() and getPlaceholderTexture().
     */
    constructor(renderer) {
        this.#renderer = renderer;
        this.#device = renderer.getDevice();
        this.#vb = this.#renderer.getVB();
    }

    /**
     * Creates the render pipeline for a geometry/material pair and caches it under `key`.
     * Requires reset() to have been called. Also requires the material's UniformGroup to
     * exist, because the pipeline layout is built from both.
     */
    #createPipeline(key, geometry, material) {
        // TODO:
        //  - Currently, we just add everything available to the vertex layout.
        //    So, we might bind more than needed.
        //  - VertexLayout is only considering basic types (like position/normal/uv), but should be more flexible.
        const attributes = geometry.attributes;
        const includeVC = attributes.color;
        const includeNormals = attributes.normal;
        const hasUV = attributes.uv;

        const labelPrefix = material.webGPU.labelPrefix || `Material-${key}`;

        const shader = this.#device.createShaderModule({
            label: `${labelPrefix} Shader`,
            code: getCustomShader(material),
        });

        // UberShader uses a single bind group layout for all materials. Here the layout may
        // vary per material, so we replace the material uniform bind group layout at slot 2
        // with the individual bindGroupLayout of the current material.
        const bindGroupLayouts = this.#activeBindGroupLayout.slice();
        bindGroupLayouts[CustomMaterialBindGroup] = material.__gpuCustomMaterialUniforms.layout;

        const pipeline = this.#device.createRenderPipeline({
            label: `${labelPrefix} Pipeline`,
            layout: this.#device.createPipelineLayout({
                label: `${labelPrefix} Pipeline Layout`,
                bindGroupLayouts,
            }),
            vertex: {
                module: shader,
                entryPoint: 'vsmain',
                buffers: getBufferLayout(geometry, includeNormals, hasUV, includeVC),
            },
            fragment: {
                module: shader,
                entryPoint: 'psmain',
                targets: this.#activeTargetsList,
            },
            primitive: {
                topology: 'triangle-list',
                cullMode: sideToCullMode(material.side),
            },

            depthStencil: {
                depthWriteEnabled: material.depthWrite,
                // Without depth testing every fragment passes. Otherwise use the material's
                // compare function, or 'less-equal' when none is set.
                depthCompare: material.depthTest ? (material.depthFunc || 'less-equal') : "always",
                format: DepthFormat,
                depthBias: 0,
                depthBiasSlopeScale: 0,
            },
        });

        // NOTE(review): the cache key is derived from geometry and material only. It does not
        // include the layouts/targets passed to reset(). If those change between passes while
        // the key stays the same, a pipeline built for the old targets is reused. Confirm the
        // hash function accounts for this.
        this.#pipelines.set(key, pipeline);

        return pipeline;
    }


    /**
     * Forgets the currently bound pipeline and material. Sets the bind group layouts and
     * color targets used for pipelines created from now on.
     * @param layouts - Bind group layouts. The entry at CustomMaterialBindGroup is replaced per material.
     * @param targets - Fragment color target states.
     */
    reset(layouts, targets) {
        this.#activePipeline = null;
        this.#currentMaterial = null;
        this.#activeBindGroupLayout = layouts;
        this.#activeTargetsList = targets;
    }


    /**
     * Ensures `texture` has a GPU texture and returns it. initTexture is run with one
     * MipmapPipeline per texture id.
     */
    #getTextureFromTextureUniform(texture) {
        const key = texture.id;
        if (!this.#mipmapPipelines.has(key)) {
            this.#mipmapPipelines.set(key, new MipmapPipeline(this.#device));
        }
        initTexture(this.#device, this.#mipmapPipelines.get(key), texture);
        return texture.__gpuTexture;
    }

    /**
     * Builds the UniformGroup for a material from its `uniforms` and `webGPU` description.
     * The group contains:
     * - the uniform buffer at binding 0;
     * - the textures (and optional derived samplers) listed in `textureBindings`;
     * - the explicit `samplers`.
     */
    #getUniformGroupFromCustomMaterial(material) {
        const { webGPU } = material;
        const customMaterialUniforms = new UniformGroup(webGPU.labelPrefix);
        customMaterialUniforms.setDefaultTexture(this.#renderer.getPlaceholderTexture().texture);

        const { uniformBufferLayout, uniformBufferValues } = extractUniformBufferInfo(material);
        customMaterialUniforms.addUniformBuffer(0, uniformBufferLayout,
            webGPU.uniformBufferVisibility, webGPU.uniformBufferBindingLayout);
        customMaterialUniforms.setUniformBufferValues(uniformBufferValues);

        // textureBindings is optional. Previously, omitting it threw a TypeError here.
        for (const textureBinding of webGPU.textureBindings ?? []) {
            // Binding 0 holds the uniform buffer, so a falsy binding is treated as "not set".
            if (!textureBinding.uniformKey || !textureBinding.binding) {
                continue;
            }
            const textureUniform = material.uniforms?.[textureBinding.uniformKey];
            // Skip keys that do not name a texture uniform, and textures not assigned yet.
            if (textureUniform?.type !== 't' || textureUniform.value == null) {
                continue;
            }
            const texture = this.#getTextureFromTextureUniform(textureUniform.value);
            customMaterialUniforms.addTexture(textureBinding.binding, texture,
                textureBinding.textureVisibility, textureBinding.sampleType);

            if (textureBinding.samplerBinding) {
                const samplerDesc = createSamplerDesc(textureUniform.value);
                // `sampleVisibility` is accepted as an alias (spelling used in older docs).
                const samplerVisibility = textureBinding.samplerVisibility ?? textureBinding.sampleVisibility;
                customMaterialUniforms.addSampler(textureBinding.samplerBinding, samplerVisibility, samplerDesc);
            }
        }

        if (webGPU.samplers) {
            // `const` is required: the loop previously assigned an undeclared variable, which
            // throws a ReferenceError in strict-mode (ES module) code.
            for (const sampler of webGPU.samplers) {
                // `samplerDesc` is accepted as an alias (name used in older docs).
                const samplerDesc = sampler.descriptor ?? sampler.samplerDesc;
                customMaterialUniforms.addSampler(sampler.binding, sampler.visibility, samplerDesc);
            }
        }

        return customMaterialUniforms;
    }

    /**
     * Re-reads the current uniform values from `material.uniforms` and uploads them.
     * The values must be re-extracted because primitive uniforms ('i', 'f') were copied by
     * value when the UniformGroup was created. Without this, edits to
     * `material.uniforms[key].value` would never reach the GPU.
     * NOTE(review): assumes UniformGroup.setUniformBufferValues may be called again to
     * overwrite values — confirm against UniformGroup.ts.
     * TODO: texture uniform changes are not picked up (the bind group is not rebuilt).
     */
    #updateCustomMaterialUniforms(material) {
        const gpuUniforms = material.__gpuCustomMaterialUniforms;
        const { uniformBufferValues } = extractUniformBufferInfo(material);
        gpuUniforms.setUniformBufferValues(uniformBufferValues);
        gpuUniforms.update(this.#device);
    }

    /**
     * Lazily creates and uploads the material's UniformGroup. On later calls, it re-uploads
     * the uniforms when `material.uniformsNeedUpdate` is set.
     */
    #initMaterialBindings(material) {
        // TODO: Handle modifications, disposal etc.
        const gpuUniforms = material.__gpuCustomMaterialUniforms;

        if (!gpuUniforms) {
            const customMaterialUniforms = this.#getUniformGroupFromCustomMaterial(material);
            material.__gpuCustomMaterialUniforms = customMaterialUniforms;
            customMaterialUniforms.update(this.#device);
        } else if (material.uniformsNeedUpdate) {
            // Uniform values changed
            this.#updateCustomMaterialUniforms(material);
            material.uniformsNeedUpdate = false;
        }
    }

    /**
     * Ensures the material's bindings exist. Binds its group only when the material differs
     * from the one bound last.
     */
    #activateMaterialBindings(passEncoder, material) {
        this.#initMaterialBindings(material);

        if (this.#currentMaterial !== material) {
            passEncoder.setBindGroup(CustomMaterialBindGroup, material.__gpuCustomMaterialUniforms.bindGroup);
            this.#currentMaterial = material;
        }
    }


    /**
     * Draws one object with its custom material.
     * @param passEncoder - Render pass encoder to record into.
     * @param objectIndex - Index forwarded to the vertex buffer helper's draw().
     * @param geometry - Geometry to draw. It is (re)interleaved when it has no `vb` or `vbNeedsUpdate` is set.
     * @param material - Material with a `webGPU` description.
     * @param getPipelineHashCustom - (geometry, material) => cache key for the render pipeline.
     *   Setting `material.needsUpdate` forces the pipeline to be rebuilt.
     */
    drawOne(passEncoder, objectIndex, geometry, material, getPipelineHashCustom = getPipelineHash) {
        if (!geometry.vb || geometry.vbNeedsUpdate) {
            BufferGeometryUtils.interleaveGeometry(geometry, true);
        }

        // Must run before pipeline creation: #createPipeline reads the material's UniformGroup layout.
        this.#activateMaterialBindings(passEncoder, material);

        // ShaderMaterials require custom shaders, so we just brutally create/store a pipeline per material for now.
        // TODO: The default pipeline hash function might not be sufficient. We might need different
        // pipelines depending on other attributes and settings. This should be further generalized.
        const key = getPipelineHashCustom(geometry, material);
        let pipeline = this.#pipelines.get(key);
        if (!pipeline || material.needsUpdate) {
            pipeline = this.#createPipeline(key, geometry, material);
            material.needsUpdate = false;
        }

        if (pipeline !== this.#activePipeline) {
            passEncoder.setPipeline(pipeline);
            this.#activePipeline = pipeline;
        }

        this.#vb.draw(passEncoder, geometry, objectIndex);
    }

}
