import { useCallback, useEffect, useRef } from "react";
import * as faceapi from "@vladmandic/face-api/";
import Template, {
  ClassData,
  ModelImportHandler,
  ModelExportHandler,
  PredictionRequestHandler,
  TrainingRequestHandler
} from "./template";
import _ from "lodash";

/** Base URL that the face-api network weights are loaded from (passed to each net's `loadFromUri`). */
const MODEL_PATH = "https://cdn.jsdelivr.net/npm/@vladmandic/face-api/model/";
/** File-name prefix of the JSON written on export; the export handler appends a timestamp and `.json`. */
const EXPORT_NAME = "face-model-export-";

export default function FaceDetection() {
  // Canvas passed to <Template>; face landmarks are drawn onto it during prediction.
  const canvasRef = useRef<HTMLCanvasElement>(null);
  // 2D context of that canvas, obtained lazily the first time a detection is drawn.
  const ctx = useRef<CanvasRenderingContext2D | null>(null);
  // Promise for the model download started on mount; begins as an already-resolved placeholder.
  const modelLoadPromise = useRef(Promise.resolve());
  // Matcher built by training or restored by import; null until one of those succeeds.
  const faceMatcher = useRef<faceapi.FaceMatcher | null>(null);

  // Start downloading the three networks (detection, landmarks, descriptors)
  // once, on mount. The promise is kept in a ref so handlers can wait for the
  // weights before running inference.
  useEffect(() => {
    const loading = Promise.all([
      faceapi.nets.ssdMobilenetv1.loadFromUri(MODEL_PATH),
      faceapi.nets.faceLandmark68Net.loadFromUri(MODEL_PATH),
      faceapi.nets.faceRecognitionNet.loadFromUri(MODEL_PATH)
    ]).then(() => console.log("Models loaded"));

    // Report a failed download immediately and with context. Attaching this
    // handler also stops the failure from surfacing as an unhandled rejection
    // when nothing has awaited the promise yet; the promise stored in the ref
    // still rejects, so awaiters see the original error.
    loading.catch((err: unknown) => console.error("Failed to load face-api models", err));

    modelLoadPromise.current = loading;
  }, []);

  /**
   * Wipes the whole canvas. Does nothing when the canvas is not mounted or
   * its 2D context has not been obtained yet.
   */
  const clearCanvas = useCallback(() => {
    const canvas = canvasRef.current;
    const context = ctx.current;

    if (!canvas || !context) {
      return;
    }

    context.clearRect(0, 0, canvas.width, canvas.height);
  }, []);

  /**
   * Training handler: extracts a face descriptor from every image of every
   * class and builds the `FaceMatcher` used by prediction.
   *
   * Classes are processed concurrently; images within a class sequentially.
   * An image that cannot be decoded or processed is skipped with a warning,
   * and a class that yields no descriptor is left out of the matcher.
   *
   * @param data - classes to train on; each map key becomes the matcher label
   *   and each value carries a display `name` and its image sources `imgs`.
   * @param onProgress - invoked after each processed image with the mean
   *   progress over all classes and a flag telling whether every class is done.
   * @throws Error when fewer than two classes produced at least one descriptor.
   */
  const genereateFaceMatcher: TrainingRequestHandler = useCallback(async (data, onProgress) => {
    await modelLoadPromise.current;
    const progresses = _.map(Array(data.size), () => ({ value: 0, done: false }));

    const updateProgress = (idx: number, value: number, done: boolean) => {
      progresses[idx] = { value, done };
      onProgress(
        _.meanBy(progresses, p => p.value),
        progresses.every(p => p.done)
      );
    };

    const faceResults = await Promise.allSettled(
      [...data.entries()].map(async ([key, { name, imgs }], idx) => {
        const imgEl = document.createElement("img");
        const descriptors: Float32Array[] = [];

        if (imgs.length === 0) {
          // The loop below never runs for an empty class, so mark it finished
          // here; otherwise the overall "done" flag could never become true.
          updateProgress(idx, 1, true);
        }

        for (let i = 0; i < imgs.length; i++) {
          try {
            imgEl.src = imgs[i];
            await imgEl.decode();
            const detection = await faceapi.detectSingleFace(imgEl).withFaceLandmarks().withFaceDescriptor();

            if (detection) {
              descriptors.push(detection.descriptor);
            }
          } catch (err: unknown) {
            // One bad image must not discard the rest of the class nor stall
            // progress reporting; treat it like an image without a face.
            console.warn(`Skipping image ${i + 1} of class ${name}`, err);
          }

          updateProgress(idx, (i + 1) / imgs.length, i === imgs.length - 1);
        }

        if (!descriptors.length) {
          throw new Error(`No faces detected for class ${name}`);
        }

        return new faceapi.LabeledFaceDescriptors(key, descriptors);
      })
    );

    const faces: faceapi.LabeledFaceDescriptors[] = [];
    for (const result of faceResults) {
      if (result.status === "fulfilled") {
        faces.push(result.value);
      } else {
        // Make it visible which classes were left out and why.
        console.warn(result.reason);
      }
    }

    if (faces.length < 2) {
      throw new Error("Could not detect faces from dataset");
    }

    // Second argument is the matcher's distance threshold.
    faceMatcher.current = new faceapi.FaceMatcher(faces, 1);
  }, []);

  /**
   * Prediction handler: detects a single face in the given element, matches
   * it against the current matcher and draws its landmarks on the canvas.
   *
   * @param videoEl - media element handed to face-api for detection.
   * @returns a map with one entry, best-match label -> `1 - distance`, or an
   *   empty map when no face is detected.
   * @throws Error when no matcher has been trained or imported yet.
   */
  const predict: PredictionRequestHandler = useCallback(
    async videoEl => {
      const matcher = faceMatcher.current;

      if (!matcher) {
        throw new Error("Face matcher not trained yet");
      }

      // A matcher restored through import can be in place before the networks
      // have finished downloading, so wait for them before running inference.
      await modelLoadPromise.current;

      const detection = await faceapi.detectSingleFace(videoEl).withFaceLandmarks().withFaceDescriptor();

      if (!detection) {
        // Remove landmarks left over from a previous frame.
        clearCanvas();
        return new Map<string, number>();
      }

      const match = matcher.findBestMatch(detection.descriptor);
      const canvas = canvasRef.current;

      if (canvas) {
        if (!ctx.current) {
          ctx.current = canvas.getContext("2d");
        }

        clearCanvas();
        faceapi.draw.drawFaceLandmarks(
          canvas,
          faceapi.resizeResults(detection, { width: canvas.width, height: canvas.height })
        );
      }

      return new Map([[match.label, 1 - match.distance]]);
    },
    [clearCanvas]
  );

  /**
   * Export handler: writes the serialized matcher, plus a `namedLabels` array
   * of display names, into the zip as `<EXPORT_NAME><timestamp>.json`.
   *
   * `namedLabels[i]` is the display name for the matcher's i-th labeled
   * descriptor. The names are derived from the matcher's own labels rather
   * than from every class, because classes that produced no descriptor during
   * training are absent from the matcher and would shift the index pairing.
   *
   * @param classes - current classes; looked up by matcher label.
   *   NOTE(review): assumes a Map keyed by the same keys used as labels at
   *   training time — confirm against the handler type.
   * @param zip - archive the JSON file is added to.
   * @returns whatever `zip.file` returns.
   * @throws Error when no matcher has been trained or imported yet.
   */
  const saveModel: ModelExportHandler = useCallback((classes, zip) => {
    const matcher = faceMatcher.current;

    if (!matcher) {
      throw new Error("Model not trained yet");
    }

    // Fall back to the raw label when the class is no longer present.
    const namedLabels = matcher.labeledDescriptors.map(({ label }) => classes.get(label)?.name ?? label);

    return zip.file(`${EXPORT_NAME}${Date.now()}.json`, JSON.stringify({ ...matcher.toJSON(), namedLabels }));
  }, []);

  /**
   * Import handler: restores a matcher from the first `.json` file found in
   * the zip and rebuilds the class list from it.
   *
   * @param zip - archive to read from.
   * @returns a map from matcher label to class data (`name` taken from the
   *   file's `namedLabels` at the same index, falling back to the label
   *   itself; `imgs` empty).
   * @throws Error when the zip holds no `.json` file, or when the file cannot
   *   be parsed or deserialized.
   */
  const importModel: ModelImportHandler = useCallback(async zip => {
    const jsonFile = zip.filter((_path, file) => file.name.endsWith(".json"))[0];

    if (!jsonFile) {
      throw new Error("Invalid model file. Please select a valid zip file.");
    }

    try {
      // Parsing happens inside the try so malformed JSON is reported through
      // the same "Failed to import model" error as a bad matcher payload.
      const json = JSON.parse(await jsonFile.async("string"));
      const matcher = faceapi.FaceMatcher.fromJSON(json);
      const namedLabels: unknown[] = Array.isArray(json?.namedLabels) ? json.namedLabels : [];

      // Pair by index, driven by the matcher's descriptors so a missing or
      // short `namedLabels` can never produce undefined keys or values.
      const entries = matcher.labeledDescriptors.map(({ label }, i) => {
        const name = namedLabels[i];
        return [label, { name: typeof name === "string" ? name : label, imgs: [] }];
      });
      const classes = new Map(entries as [string, ClassData][]);

      // Replace the active matcher only once the whole import has succeeded.
      faceMatcher.current = matcher;
      return classes;
    } catch (e: unknown) {
      console.error(e);
      const reason = e instanceof Error ? e.message : String(e);
      throw new Error(`Failed to import model: ${reason}`);
    }
  }, []);

  // Render the shared <Template>, wiring in the face-specific handlers and the
  // canvas ref that landmarks are drawn on. Stopping the prediction stream
  // clears the canvas.
  return (
    <Template
      storageBucketPath="face-models"
      prefix="face"
      onTrainingRequested={genereateFaceMatcher}
      onPredictionRequested={predict}
      onPredictionRequestStreamStop={clearCanvas}
      onModelExport={saveModel}
      onModelImport={importModel}
      canvasRef={canvasRef}
    />
  );
}
