import { useRequiredContext } from "@redotech/react-util/context";
import { lazy } from "@redotech/util/cache";
import { createContext, useMemo, useState } from "react";

/** Model identifier handed to the transformers `feature-extraction` pipeline. */
const MODEL_ID = "Xenova/all-mpnet-base-v2";

/** Value exposed to consumers of the search-model context. */
type SearchModel = {
  /** Resolves to the embedding vector computed for `text`. */
  getEmbedding: (text: string) => Promise<number[]>;
};

/**
 * Context carrying the search model. Its default is `undefined`, which is what
 * consumers see when no provider is mounted above them.
 */
const SearchModelContext = createContext<SearchModel | undefined>(undefined);

/**
 * Supplies descendants with a `getEmbedding` function backed by a
 * `feature-extraction` pipeline for `MODEL_ID` from `@xenova/transformers`.
 *
 * The transformers package is imported dynamically inside the loader, and the
 * loader is only invoked from `getEmbedding`. `lazy` is assumed to defer the
 * factory until its first call and to cache the result; confirm against
 * `@redotech/util/cache`.
 */
export const SearchModelProvider: React.FC<{ children: React.ReactNode }> = ({
  children,
}) => {
  // A useState initializer runs once per mounted provider and its result is
  // never discarded. useMemo is only a cache hint, so React may drop it,
  // which would create a new loader and load the model again.
  // NOTE(review): if `lazy` caches a rejected promise, a failed model load
  // will not be retried for the lifetime of this provider. Confirm.
  const [extractorFn] = useState(() =>
    lazy(async () => {
      const { pipeline } = await import("@xenova/transformers");
      return await pipeline("feature-extraction", MODEL_ID, {
        quantized: true,
      });
    }),
  );

  // Memoize the context value so its identity only changes when the loader
  // does. Otherwise a fresh object on every render re-renders all consumers.
  const value = useMemo(
    () => ({
      /**
       * Returns the mean-pooled, normalized embedding of `text` as a plain
       * array. Errors from loading or running the model are logged, then
       * rethrown to the caller.
       */
      getEmbedding: async (text: string): Promise<number[]> => {
        try {
          const extractor = await extractorFn();
          const output = await extractor(text, {
            pooling: "mean",
            normalize: true,
          });
          return Array.from(output.data);
        } catch (error) {
          console.error("Embedding generation error:", error);
          throw error;
        }
      },
    }),
    [extractorFn],
  );

  return (
    <SearchModelContext.Provider value={value}>
      {children}
    </SearchModelContext.Provider>
  );
};

/**
 * Reads the search model supplied by the nearest `SearchModelProvider`.
 * Delegates to `useRequiredContext`, which presumably throws when no provider
 * is mounted. Confirm against `@redotech/react-util/context`.
 */
export function useSearchModel() {
  return useRequiredContext(SearchModelContext);
}
