import { Vec2 } from 'curtainsjs';
import {
  vertexShader,
  UNIVERSAL_UNIFORMS,
  BLEND,
  FLOATING_POINT,
  universalUniformParams,
  computeFragColor,
  GAUSSIAN_WEIGHTS_24
} from '../ShaderHelpers.js';

// GLSL ES 3.00 fragment shader for the Posterize effect.
//   pass 0 (uPass == 0): edge-aware blur along x.
//   pass 1 (uPass == 1): edge-aware blur along y, then HSL posterize + dither.
// Keep "#version 300 es" as the very first text of the string so the source
// is compiled as GLSL ES 3.00.
const fragmentShader = `#version 300 es
precision mediump float;

// UV coordinate passed from vertexShader (ShaderHelpers.js).
in vec2 vTextureCoord;
// Source image sampled by blur() (presumably the previous pass's output in pass 1 -- confirm).
uniform sampler2D uTexture;
// Declared but not referenced anywhere in this shader.
uniform sampler2D uBgTexture;
// Selects the branch in getColor(): 0 = horizontal blur, 1 = vertical blur + posterize.
uniform int uPass;
// 0..1 (Smoothness property range); widens both the blur radius and the luma edge threshold.
uniform float uSmoothness;
// Number of posterize levels; posterize() raises it to at least 2.
uniform float uLevels;
// Declared but not referenced anywhere in this shader.
uniform float uThreshold;
// Shared uniforms from ShaderHelpers.js (presumably including uResolution,
// which getColor() reads) -- confirm in ShaderHelpers.
${UNIVERSAL_UNIFORMS}

// Not referenced anywhere in this shader.
const float PI = 3.1415926;
// Presumably defines getGaussianWeight(int), used by blur() -- confirm in ShaderHelpers.
${GAUSSIAN_WEIGHTS_24}

out vec4 fragColor;

// Brightness of an RGB colour as a weighted sum using the Rec. 601 luma coefficients.
float luma(vec3 color) {
  const vec3 kRec601 = vec3(0.299, 0.587, 0.114);
  return dot(color, kRec601);
}

// Cheap deterministic hash of a 2D coordinate to a pseudo-random value in [0, 1).
// The same p always yields the same value.
float random(vec2 p) {
  vec2 h = fract(p * vec2(123.45, 678.90));
  h += dot(h, h + 45.32);
  return fract(h.x * h.y);
}

// Per-channel hyperbolic-tangent tone curve; maps any input into [-1, 1].
// Written as sign(x) * (1 - e^(-2|x|)) / (1 + e^(-2|x|)) so the only exponential
// evaluated has a non-positive argument and can never overflow. The previous
// form computed exp(+-40), far beyond the range guaranteed for mediump floats
// (|v| < 2^14), which can overflow to inf and yield inf / inf = NaN.
// Note: this function is not called anywhere in this shader.
vec3 Tonemap_tanh(vec3 x) {
  vec3 e = exp(-2.0 * abs(x));
  return sign(x) * (1.0 - e) / (1.0 + e);
}

// Computes one RGB channel of the HSL-to-RGB conversion.
// p and q are the intermediate values computed in hslToRgb; t is the hue
// shifted for the channel being evaluated, wrapped back into [0, 1] first.
float hueToRgb(float p, float q, float t) {
    float hue = t;
    if (hue < 0.0) {
        hue += 1.0;
    }
    if (hue > 1.0) {
        hue -= 1.0;
    }

    float channel = p;
    if (hue < 1.0 / 6.0) {
        channel = p + (q - p) * 6.0 * hue;
    } else if (hue < 1.0 / 2.0) {
        channel = q;
    } else if (hue < 2.0 / 3.0) {
        channel = p + (q - p) * (2.0 / 3.0 - hue) * 6.0;
    }
    return channel;
}

// Converts an (hue, saturation, lightness) triple back to RGB.
// A saturation of exactly 0 yields a grey of the given lightness.
vec3 hslToRgb(vec3 hsl) {
    float hue = hsl.x;
    float sat = hsl.y;
    float light = hsl.z;

    if (sat == 0.0) {
        return vec3(light);
    }

    float q = light < 0.5 ? light * (1.0 + sat) : light + sat - light * sat;
    float p = 2.0 * light - q;
    return vec3(
        hueToRgb(p, q, hue + 1.0 / 3.0),
        hueToRgb(p, q, hue),
        hueToRgb(p, q, hue - 1.0 / 3.0)
    );
}

// Convert from RGB to HSL.
// Returns vec3(hue, saturation, lightness). Hue is a fraction of a full turn
// (0 = red, 1/3 = green, 2/3 = blue); for RGB inputs in [0, 1] every component
// is in [0, 1]. Achromatic inputs (all channels equal) give hue = sat = 0.
vec3 rgbToHsl(vec3 rgb) {
    // Named cmax/cmin instead of max/min so the built-in max()/min()
    // functions are not hidden for the rest of the function body.
    float cmax = max(max(rgb.r, rgb.g), rgb.b);
    float cmin = min(min(rgb.r, rgb.g), rgb.b);
    float l = (cmax + cmin) / 2.0;
    // Explicitly initialised: the achromatic case leaves both at 0.
    float h = 0.0;
    float s = 0.0;

    if (cmax != cmin) {
        float d = cmax - cmin;
        s = l > 0.5 ? d / (2.0 - cmax - cmin) : d / (cmax + cmin);
        if (cmax == rgb.r) {
            h = (rgb.g - rgb.b) / d + (rgb.g < rgb.b ? 6.0 : 0.0);
        } else if (cmax == rgb.g) {
            h = (rgb.b - rgb.r) / d + 2.0;
        } else {
            // cmax is one of the three channels, so this is the blue case.
            h = (rgb.r - rgb.g) / d + 4.0;
        }
        h /= 6.0;
    }

    return vec3(h, s, l);
}

// Edge-aware 1D blur of uTexture along dir, centred on uv.
// Takes TAP_COUNT samples on each side of uv, weighted by getGaussianWeight(i)
// (presumably a Gaussian kernel from GAUSSIAN_WEIGHTS_24). Each tap's weight
// is scaled down as its luma departs from the centre luma, so neighbours across
// a strong edge contribute little. uSmoothness widens both the sampling radius
// and the luma tolerance. Returns the weight-normalised colour.
vec4 blur(vec2 uv, vec2 dir) {
  const int TAP_COUNT = 9;

  vec4 center = texture(uTexture, uv);
  float center_weight = getGaussianWeight(0);
  vec4 color = center * center_weight;
  float total_weight = center_weight;

  float center_luma = luma(center.rgb);
  // Luma difference at which a neighbour stops contributing entirely.
  float edge_threshold = mix(0.1, 0.2, uSmoothness);
  // Offset of the outermost tap; loop-invariant, so computed once.
  float max_offset = mix(0.001, 0.02, uSmoothness);

  for (int i = 1; i <= TAP_COUNT; i++) {
    float weight = getGaussianWeight(i);
    float offset = max_offset * float(i) / float(TAP_COUNT);

    vec4 sample1 = texture(uTexture, uv + offset * dir);
    vec4 sample2 = texture(uTexture, uv - offset * dir);

    // Compare luma rather than full colour vectors.
    float diff1 = abs(center_luma - luma(sample1.rgb));
    float diff2 = abs(center_luma - luma(sample2.rgb));

    // Falls smoothly from 1 (same luma) to 0 (diff >= edge_threshold).
    // smoothstep() is undefined when edge0 >= edge1, so the reversed call
    // smoothstep(edge_threshold, 0.0, diff) is expressed as its
    // mathematically identical, well-defined complement.
    float edge_factor1 = 1.0 - smoothstep(0.0, edge_threshold, diff1);
    float edge_factor2 = 1.0 - smoothstep(0.0, edge_threshold, diff2);

    float modified_weight1 = weight * edge_factor1;
    float modified_weight2 = weight * edge_factor2;

    color += sample1 * modified_weight1 + sample2 * modified_weight2;
    total_weight += modified_weight1 + modified_weight2;
  }

  // Ensure we never divide by zero.
  total_weight = max(total_weight, 0.001);
  return color / total_weight;
}

// Posterizes a colour in HSL space; alpha passes through unchanged.
// levels = max(2.0, uLevels). Hue is quantized into 2 * levels bands;
// saturation and lightness into exactly levels evenly spaced values
// spanning [0, 1] (0, 1/(levels-1), ..., 1).
vec4 posterize(vec4 color) {
  vec3 hsl = rgbToHsl(color.rgb);

  float levels = max(2.0, uLevels);
  float top_band = levels - 1.0;

  // Hue keeps twice as many bands to maintain colour variety.
  hsl.x = floor(hsl.x * levels * 2.0) / (levels * 2.0);

  // floor(x * levels) reaches levels itself when x == 1.0 (fully saturated or
  // white input), which used to map to levels / (levels - 1) > 1.0 (1.25 at
  // the default 5 levels). Clamping the band index to [0, levels - 1] keeps
  // the result in [0, 1]; values that already quantized into [0, 1] are unchanged.
  hsl.y = clamp(floor(hsl.y * levels), 0.0, top_band) / top_band;
  hsl.z = clamp(floor(hsl.z * levels), 0.0, top_band) / top_band;

  return vec4(hslToRgb(hsl), color.a);
}

// Second pass: blur along dir, posterize the result, then add a per-pixel
// offset in [-0.5/255, 0.5/255) to the RGB channels as dither.
vec4 finalize(vec2 uv, vec2 dir) {
  vec4 posterized = posterize(blur(uv, dir));
  float noise = random(gl_FragCoord.xy) - 0.5;
  posterized.rgb += noise / 255.0;
  return posterized;
}

// Dispatches on uPass: 0 = horizontal blur, 1 = vertical blur + posterize,
// anything else = transparent black.
vec4 getColor(vec2 uv) {
  if (uPass == 0) {
    // The x step is scaled by 1/aspect so it spans the same pixel distance
    // as the vertical pass (assuming uResolution is the target size in pixels).
    float aspect = uResolution.x / uResolution.y;
    return blur(uv, vec2(1.0 / aspect, 0.0));
  }
  if (uPass == 1) {
    return finalize(uv, vec2(0.0, 1.0));
  }
  return vec4(0.0);
}

// Entry point: evaluates the pass selected by uPass at this fragment's UV and
// hands the colour to the shared output code.
void main() {
    vec2 uv = vTextureCoord;

    vec4 color = getColor(uv);

    // Code generated by computeFragColor('color') from ShaderHelpers.js;
    // presumably writes fragColor from color -- confirm in ShaderHelpers.
    ${computeFragColor('color')}
}`;

/**
 * Builds a uniform descriptor in the { name, type, value } shape used by
 * `params.uniforms`.
 * @param {string} name - GLSL uniform name.
 * @param {string} type - Uniform type tag (e.g. '1f', '1i').
 * @param {number} value - Initial value.
 * @returns {{name: string, type: string, value: number}}
 */
const makeUniform = (name, type, value) => ({ name, type, value });

/**
 * Shader parameters for the Posterize effect: shader sources, texture options,
 * and the uniform descriptors the fragment shader declares (followed by the
 * shared universal uniforms).
 */
export const params = {
  fragmentShader,
  vertexShader,
  crossorigin: 'anonymous',
  depth: false,
  texturesOptions: {
    floatingPoint: FLOATING_POINT,
    premultiplyAlpha: true,
  },

  uniforms: {
    pass: makeUniform('uPass', '1i', 0),
    smoothness: makeUniform('uSmoothness', '1f', 0.5),
    threshold: makeUniform('uThreshold', '1f', 0.1),
    levels: makeUniform('uLevels', '1f', 5.0),
    ...universalUniformParams,
  },
};

/**
 * Effect definition for "Posterize": runs the shader twice, setting the
 * 'pass' uniform to 0 and then 1 (only the second pass includes the
 * background), and exposes the editable Smoothness and Levels properties.
 */
export const POSTERIZE = {
  id: 'posterize',
  label: 'Posterize',
  params,
  passes: [0, 1].map((passIndex) => ({
    prop: 'pass',
    value: passIndex,
    includeBg: passIndex === 1,
  })),
  properties: {
    smoothness: {
      label: 'Smoothness',
      value: 0.5,
      min: 0,
      max: 1,
      step: 0.01,
      output: 'percent',
    },
    levels: {
      label: 'Levels',
      value: 5,
      min: 2,
      max: 20,
      step: 1,
      output: 'int',
    },
  },
};
