import { universalUniformParams, vertexShaderNoMatrix, UNIVERSAL_UNIFORMS } from '../ShaderHelpers.js';
import { Vec2 } from 'curtainsjs';

const fragmentShader = `#version 300 es
precision highp float;

// ---- Varyings from the vertex stage ----
in vec3 vVertexPosition;
in vec2 vTextureCoord;

// ---- Samplers ----
uniform sampler2D uTexture;
// Sampled by main() as the previous frame of this ping-pong pass.
uniform sampler2D uPingPongTexture;

// ---- Effect parameters ----
// NOTE(review): uAmount and uDissipate are declared but not read anywhere in this shader.
uniform float uRadius;
uniform float uAmount;
uniform float uDissipate;
uniform float uTurbulence;
uniform float uBloom;
uniform float uDecay;

// ---- Animation / input state ----
uniform float uTime;
uniform vec2 uPreviousMousePos;
// Shared uniforms; presumably these declare uResolution and uMousePos, which are used below
// but not declared here - confirm in ShaderHelpers.
${UNIVERSAL_UNIFORMS}

const float PI = 3.1415926;
const float TWOPI = 2.0 * PI;

out vec4 fragColor;

// Per-channel hyperbolic-tangent tone curve. The input is clamped to [-40, 40]
// first so that exp() stays far from float overflow.
// NOTE(review): not called anywhere in this shader.
vec3 Tonemap_tanh(vec3 x) {
    vec3 v = clamp(x, -40.0, 40.0);
    vec3 ePos = exp(v);
    vec3 eNeg = exp(-v);
    return (ePos - eNeg) / (ePos + eNeg);
}

// HSV -> RGB, branch-free. Hue wraps via fract(); each RGB channel is a clamped
// triangular ramp of the hue, then blended toward white by saturation and scaled by value.
vec3 hsv2rgb(vec3 c) {
    const vec4 k = vec4(1.0, 2.0 / 3.0, 1.0 / 3.0, 3.0);
    vec3 ramp = abs(fract(c.xxx + k.xyz) * 6.0 - k.www);
    vec3 pureHue = clamp(ramp - k.xxx, 0.0, 1.0);
    return c.z * mix(k.xxx, pureHue, c.y);
}

// RGB -> HSV, branch-free. Returns (hue, saturation, value) with hue expressed
// as a fraction of a full turn.
vec3 rgb2hsv(vec3 c) {
    const vec4 k = vec4(0.0, -1.0 / 3.0, 2.0 / 3.0, -1.0);
    // Two step()-driven swaps: order g/b first, then compare the larger one with r,
    // leaving the maximum channel in m.x.
    vec4 gb = mix(vec4(c.bg, k.wz), vec4(c.gb, k.xy), step(c.b, c.g));
    vec4 m = mix(vec4(gb.xyw, c.r), vec4(c.r, gb.yzx), step(gb.x, c.r));

    float chroma = m.x - min(m.w, m.y);
    // Tiny bias keeps the divisions finite for greys / black.
    const float eps = 1.0e-10;
    float hue = abs(m.z + (m.w - m.y) / (6.0 * chroma + eps));
    float saturation = chroma / (m.x + eps);
    return vec3(hue, saturation, m.x);
}

// 2x2 rotation matrix for angle a in radians. GLSL mat2 constructors are
// column-major, so the columns are (cos a, -sin a) and (sin a, cos a).
mat2 rot(float a) {
    float c = cos(a);
    float s = sin(a);
    return mat2(c, -s, s, c);
}

// Unit direction vector for an angle given in turns (1.0 = one full revolution).
vec2 angleToDir(float turns) {
    float theta = turns * 2.0 * PI;
    return vec2(cos(theta), sin(theta));
}

// Domain-warps st with five passes of rotate-then-offset using sine/cosine waves
// whose phase advances with uTime, scaled per axis by dir. The warp runs in
// aspect-corrected space (x scaled by width/height) and is mapped back at the end.
vec2 liquify(vec2 st, vec2 dir) {
    float aspect = uResolution.x / uResolution.y;
    vec2 p = vec2(st.x * aspect, st.y);
    const float amp = 0.0025;
    const float frequency = 6.0;
    for (int n = 1; n <= 5; n++) {
        float k = float(n);
        p = p * rot(k / 5.0 * PI * 2.0);
        p += amp * vec2(
            cos(k * frequency * p.y + uTime * 0.02 * dir.x),
            sin(k * frequency * p.x + uTime * 0.02 * dir.y)
        );
    }
    p.x /= aspect;
    return p;
}

// Colour contribution of one trail segment (prevMousePos -> mousePos) at this fragment.
// Hue encodes the segment's direction (angle / TWOPI); intensity falls off with the
// fragment's distance to the segment and is shaped by uBloom.
//   mousePos, prevMousePos : segment end / start in texture-coordinate space
//   uv                     : unused; kept so existing call sites keep compiling
//   correctedUv            : fragment position with x scaled by aspectRatio
//   radius                 : brush radius controlling the falloff width
vec3 calculateTrailContribution(vec2 mousePos, vec2 prevMousePos, vec2 uv, vec2 correctedUv, float aspectRatio, float radius) {
    // Work entirely in aspect-corrected space so the projection below is a true
    // orthogonal projection (mixing corrected and uncorrected vectors skews the
    // closest point on non-square canvases).
    vec2 aspectScale = vec2(aspectRatio, 1.0);
    vec2 segStart = prevMousePos * aspectScale;
    vec2 seg = mousePos * aspectScale - segStart;
    float segLen = length(seg);

    // atan(0, 0) is undefined in GLSL ES; zero-length segments (e.g. the first
    // iteration in main, or a stationary mouse) get angle 0 explicitly.
    float angle = segLen > 0.0 ? atan(seg.y, seg.x) : 0.0;
    if (angle < 0.0) angle += TWOPI;

    // Closest point on the segment: project onto the unit direction, clamp to [0, segLen].
    vec2 segDir = segLen > 0.0 ? seg / segLen : vec2(0.0);
    float projection = clamp(dot(correctedUv - segStart, segDir), 0.0, segLen);
    vec2 closestPoint = segStart + segDir * projection;

    // Distance to the segment rather than to a single point.
    float distanceToLine = distance(correctedUv, closestPoint);
    float s = (1.0 + radius) / (distanceToLine + radius) * radius;

    vec3 pointColor = hsv2rgb(vec3(angle / TWOPI, 1.0, 1.0));
    pointColor = pow(pointColor, vec3(2.2));

    // Exponent is 10 * (1.1 - uBloom): larger uBloom -> flatter falloff -> wider glow.
    float intensity = pow(s, 10.0 * (1. - uBloom + 0.1));
    return pointColor * intensity;
}

void main() {
    float aspectRatio = uResolution.x / uResolution.y;
    vec2 uv = vTextureCoord;
    vec2 correctedUv = uv * vec2(aspectRatio, 1.0);

    // Decode the trail state from the previous frame (linearised with gamma 2.2):
    // hue is read back as a direction in turns, value as the trail strength.
    vec3 lastFrameColor = texture(uPingPongTexture, uv).rgb;
    lastFrameColor = pow(lastFrameColor, vec3(2.2));

    vec3 hsv = rgb2hsv(lastFrameColor);
    vec2 prevDir = angleToDir(hsv.x);
    float prevStrength = hsv.z;

    // Mouse movement this frame, in aspect-corrected units.
    vec2 dir = (uMousePos - uPreviousMousePos) * vec2(aspectRatio, 1.0);
    float dist = length(dir);

    // Shift the lookup back along the stored direction (further where the trail is
    // strong) and mix in turbulence where the trail is weak.
    float blurAmount = 0.03 * prevStrength;
    uv = uv - prevDir * blurAmount;
    uv = mix(uv, liquify(uv - prevDir * 0.005, prevDir), (1. - prevStrength) * uTurbulence);
    lastFrameColor = texture(uPingPongTexture, uv).rgb;
    lastFrameColor = pow(lastFrameColor, vec3(2.2));

    // Faster movement gives a slightly larger brush.
    float speedFactor = clamp(dist, 0.7, 1.3);
    float radius = mix(0.1, 0.7, uRadius * speedFactor);

    // Split this frame's mouse path into 12..24 segments (more when moving fast).
    // Interpolation is normalised by the segment count actually walked, so the trail
    // always reaches uMousePos even when the count is capped, and the sum is averaged
    // over exactly the number of contributions added (segments + 1, the first being
    // the zero-length segment at uPreviousMousePos).
    int segments = int(clamp(dist * 24.0, 12.0, 24.0));
    vec3 trailColor = vec3(0.0);
    vec2 segStart = uPreviousMousePos;
    for (int i = 0; i <= segments; i++) {
        vec2 segEnd = mix(uPreviousMousePos, uMousePos, float(i) / float(segments));
        trailColor += calculateTrailContribution(segEnd, segStart, uv, correctedUv, aspectRatio, radius);
        segStart = segEnd;
    }
    trailColor /= float(segments + 1);

    // Blend toward the new trail in proportion to trail brightness times movement;
    // zero when either is zero.
    float clampedDist = clamp(length(trailColor) * dist, 0.0, 1.0);

    // Equal-weight 5-tap cross blur of the shifted previous frame.
    float blurRadius = 0.005;
    vec3 blurredLastFrame = vec3(0.0);
    blurredLastFrame += pow(texture(uPingPongTexture, uv + vec2(blurRadius, 0.0)).rgb, vec3(2.2)) * 0.2;
    blurredLastFrame += pow(texture(uPingPongTexture, uv + vec2(-blurRadius, 0.0)).rgb, vec3(2.2)) * 0.2;
    blurredLastFrame += pow(texture(uPingPongTexture, uv + vec2(0.0, blurRadius)).rgb, vec3(2.2)) * 0.2;
    blurredLastFrame += pow(texture(uPingPongTexture, uv + vec2(0.0, -blurRadius)).rgb, vec3(2.2)) * 0.2;
    blurredLastFrame += lastFrameColor * 0.2;

    vec3 draw = mix(blurredLastFrame, trailColor, clampedDist);

    // Scale by uDecay^0.2 (a per-frame fade when uDecay < 1), then re-encode gamma.
    draw *= pow(uDecay, 0.2);
    draw = pow(draw, vec3(1.0 / 2.2));

    fragColor = vec4(draw, 1.0);
}`;

// Uniform descriptor factories producing { name, type, value } entries.
const floatUniform = (name, value) => ({ name, type: '1f', value });
const vec2Uniform = (name, value) => ({ name, type: '2f', value });

/**
 * Parameter object for the mouse-trail ping-pong pass (presumably handed to a
 * curtainsjs pass constructor — confirm at the call site): shader sources,
 * texture options and default uniform values.
 * `universalUniformParams` is spread in last, so any key it shares with the
 * entries below overrides them.
 */
export const MOUSE_DRAW_PINGPONG = {
  fragmentShader,
  vertexShader: vertexShaderNoMatrix,
  crossorigin: 'Anonymous',
  texturesOptions: {
    floatingPoint: 'half-float',
    premultiplyAlpha: true,
  },
  uniforms: {
    // NOTE(review): the fragment shader never reads uPos (it uses uMousePos); confirm this entry is still needed.
    pos: vec2Uniform('uPos', new Vec2(0.5)),
    previousMousePos: vec2Uniform('uPreviousMousePos', new Vec2(0.5)),
    radius: floatUniform('uRadius', 0.5),
    amount: floatUniform('uAmount', 0.5),
    turbulence: floatUniform('uTurbulence', 0.25),
    time: floatUniform('uTime', 0),
    bloom: floatUniform('uBloom', 0.25),
    decay: floatUniform('uDecay', 0.5),
    ...universalUniformParams,
  },
};
