import { useCallback, useRef } from "react";

import * as tmImage from "@teachablemachine/image";
import { getModelJSONForModelArtifacts, getModelArtifactsInfoForJSON } from "@tensorflow/tfjs-core/dist/io/io_utils";
import type { WeightsManifestConfig } from "@tensorflow/tfjs-core/dist/io/types";

import Template, {
  ClassData,
  ModelImportHandler,
  ModelExportHandler,
  PredictionRequestHandler,
  TrainingRequestHandler
} from "./template";
import _ from "lodash";

/** Hyperparameters handed to the model's `train` call; frozen so they cannot be mutated at runtime. */
const TRAIN_PARAMS = Object.freeze({
  batchSize: 16,
  denseUnits: 100,
  epochs: 50,
  learningRate: 0.001
});

/** Common prefix of every file name written into an exported model archive. */
const EXPORT_NAME = "image-model-export-";

/**
 * Image-classification page.
 *
 * Wires four handlers (train, predict, export, import) built on
 * `@teachablemachine/image` into the shared `Template` component. The current
 * model lives in a ref so that replacing it never triggers a re-render and the
 * memoized handlers always see the latest instance.
 */
export default function ImageClassification() {
  const model = useRef<tmImage.TeachableMobileNet | tmImage.CustomMobileNet>();

  /**
   * Trains a fresh teachable model from the supplied class data.
   *
   * Every image is center-cropped to a square and scaled to 224x224 on an
   * off-screen canvas before being added as an example. `onProgress` receives
   * the completed fraction after each epoch, then `(1, true)` once training has
   * finished. The new model replaces `model.current` only after training
   * succeeds.
   *
   * @throws Error when a 2D canvas context cannot be obtained.
   */
  const trainModel: TrainingRequestHandler = useCallback(async (data, onProgress) => {
    const inputSize = 224;
    const imgEl = document.createElement("img");
    const canvas = document.createElement("canvas");
    const ctx = canvas.getContext("2d");
    if (!ctx) {
      throw new Error("2D canvas context is not available");
    }
    canvas.width = canvas.height = inputSize;

    const classes = [...data.entries()];
    const m = await tmImage.createTeachable({}, { version: 2 });
    m.setLabels(classes.map(([key]) => key));
    // One (initially empty) example bucket per class.
    m.examples = classes.map(() => []);

    for (const [classIdx, [, classData]] of classes.entries()) {
      for (const img of classData.imgs) {
        imgEl.src = img;
        await imgEl.decode();
        // Center-crop the largest square that fits, then scale it onto the canvas.
        const side = Math.min(imgEl.width, imgEl.height);
        const sx = (imgEl.width - side) / 2;
        const sy = (imgEl.height - side) / 2;
        ctx.drawImage(imgEl, sx, sy, side, side, 0, 0, inputSize, inputSize);
        await m.addExample(classIdx, canvas);
      }
    }

    await m.train(TRAIN_PARAMS, {
      // The epoch index passed to onEpochEnd is zero-based, so the number of
      // completed epochs is `epoch + 1`.
      onEpochEnd: (epoch: number) => onProgress((epoch + 1) / TRAIN_PARAMS.epochs, false)
    });
    onProgress(1, true);
    model.current = m;
  }, []);

  /**
   * Classifies the given element with the current model.
   *
   * @returns A map from class label to probability, covering every class.
   * @throws Error when no model has been trained or imported yet.
   */
  const predict: PredictionRequestHandler = useCallback(async videoEl => {
    const current = model.current;
    if (!current) {
      throw new Error("Model is not trained yet");
    }

    // Ask for as many results as there are classes so no label is omitted.
    const classCount =
      current instanceof tmImage.TeachableMobileNet ? current.getLabels().length : current.getClassLabels().length;
    const predictions = await current.predictTopK(videoEl, classCount);
    return new Map(predictions.map(({ className, probability }) => [className, probability]));
  }, []);

  /**
   * Writes the trained model into `zip` as three files sharing a timestamped
   * base name: `-model.json`, `-metadata.json` and `.weights.bin`.
   *
   * Only a model trained in this session (a `TeachableMobileNet`) can be
   * exported; anything else is rejected.
   *
   * @returns The same `zip` instance, with the model files added.
   * @throws Error when there is no trained model, when the artifacts carry no
   *   weights, or when the topology is binary rather than JSON.
   */
  const saveModel: ModelExportHandler = useCallback(async (classes, zip) => {
    // Capture the model once so the save handler below cannot observe a
    // different instance if the ref is replaced while saving.
    const current = model.current;
    if (!(current instanceof tmImage.TeachableMobileNet)) {
      throw new Error("Model is not trained yet");
    }

    const baseName = `${EXPORT_NAME}${Date.now()}`;
    const weightsFileName = `${baseName}.weights.bin`;

    // Referenced from https://github.com/tensorflow/tfjs/blob/master/tfjs-core/src/io/browser_files.ts#L67
    await current.save({
      save: async artifacts => {
        if (!artifacts.weightData || !artifacts.weightSpecs) {
          throw new Error("No weight data");
        }

        if (artifacts.modelTopology instanceof ArrayBuffer) {
          throw new Error("Model topology is not JSON");
        }

        const weights = new Blob([artifacts.weightData], { type: "application/octet-stream" });
        const weightsManifest: WeightsManifestConfig = [
          {
            paths: [`./${weightsFileName}`],
            weights: artifacts.weightSpecs
          }
        ];
        const modelJson = getModelJSONForModelArtifacts(artifacts, weightsManifest);
        const metadataJson = {
          modelName: "image-model",
          labels: current.getLabels(),
          // Display names, in the iteration order of `classes`.
          namedLabels: [...classes.values()].map(({ name }) => name)
        };

        zip
          .file(`${baseName}-model.json`, JSON.stringify(modelJson))
          .file(`${baseName}-metadata.json`, JSON.stringify(metadataJson))
          .file(weightsFileName, weights);

        return { modelArtifactsInfo: getModelArtifactsInfoForJSON(artifacts) };
      }
    });

    return zip;
  }, []);

  /**
   * Loads a previously exported model from `zip`.
   *
   * Files are recognized by suffix (`weights.bin`, `model.json`,
   * `metadata.json`). The metadata must contain a `namedLabels` string array
   * with exactly one entry per class label of the loaded model. The loaded
   * model replaces `model.current` only after all checks pass.
   *
   * @returns A map from class label to class data (display name, no images).
   * @throws Error when a file is missing, `namedLabels` is absent or malformed,
   *   or its length does not match the model's class labels.
   */
  const importModel: ModelImportHandler = useCallback(async zip => {
    let modelFile: File | undefined;
    let weightsFile: File | undefined;
    let metadataFile: File | undefined;
    let namedLabels: string[] | undefined;

    for (const [name, file] of Object.entries(zip.files)) {
      if (name.endsWith("weights.bin")) {
        weightsFile = new File([await file.async("blob")], name);
      } else if (name.endsWith("model.json")) {
        modelFile = new File([await file.async("blob")], name);
      } else if (name.endsWith("metadata.json")) {
        const metadataStr = await file.async("string");
        // The archive is external input: accept namedLabels only if it really
        // is an array of strings.
        const parsed: unknown = JSON.parse(metadataStr);
        const candidate = (parsed as { namedLabels?: unknown } | null)?.namedLabels;
        if (Array.isArray(candidate) && candidate.every((l): l is string => typeof l === "string")) {
          namedLabels = candidate;
        }
        metadataFile = new File([metadataStr], name);
      }
    }

    if (!modelFile || !weightsFile || !metadataFile || !namedLabels) {
      throw new Error("Missing data");
    }

    const loaded = await tmImage.loadFromFiles(modelFile, weightsFile, metadataFile);
    const labels = loaded.getClassLabels();
    if (labels.length !== namedLabels.length) {
      throw new Error("Label count does not match named label count");
    }

    model.current = loaded;
    // Lengths are equal (checked above), so every zipped pair is fully defined.
    return new Map(
      _.zip(
        labels,
        namedLabels.map(l => ({ name: l, imgs: [] }))
      ) as [string, ClassData][]
    );
  }, []);

  return (
    <Template
      storageBucketPath="image-models"
      prefix="image"
      onTrainingRequested={trainModel}
      onPredictionRequested={predict}
      onModelExport={saveModel}
      onModelImport={importModel}
    />
  );
}
