import React, { useState, useRef, useCallback } from "react";
import {
  AutoProcessor,
  AutoModelForImageTextToText,
  RawImage,
  TextStreamer,
} from "@huggingface/transformers";
import type {
  Tensor,
  PixtralProcessor,
  Ministral3ForCausalLM,
  ProgressInfo,
} from "@huggingface/transformers";
import { VLMContext } from "./VLMContext";

const MODEL_ID = "mistralai/Ministral-3-3B-Instruct-2512-ONNX";
const MAX_NEW_TOKENS = 512;
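
// Context provider that owns the vision-language model lifecycle: it loads the
// processor and ONNX model onto WebGPU once, then exposes loading state and a
// per-frame inference function to consumers of VLMContext.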
export const VLMProvider: React.FC<React.PropsWithChildren> = ({
  children,
}) => {
  const [isLoaded, setIsLoaded] = useState(false);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const [imageSize, setImageSize] = useState(480);

  const processorRef = useRef<PixtralProcessor | null>(null);
  const modelRef = useRef<Ministral3ForCausalLM | null>(null);
  const loadPromiseRef = useRef<Promise<void> | null>(null);
  const inferenceLock = useRef(false);
  const canvasRef = useRef<HTMLCanvasElement | null>(null);
  const imageSizeRef = useRef(480);
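
  // Update the requested image size and, if the processor is already loaded,
  // apply it immediately so subsequent frames are resized to the new longest edge.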
  const updateImageSize = useCallback((size: number) => {
    setImageSize(size);
    imageSizeRef.current = size;
    if (processorRef.current?.image_processor) {
      processorRef.current.image_processor.size = { longest_edge: size };
    }
  }, []);
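
  // Download and initialize the processor and model. Repeated calls are cheap:
  // an already-loaded model returns immediately, and concurrent callers share
  // the same in-flight promise.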
  const loadModel = useCallback(
    async (onProgress?: (msg: string, percentage: number) => void) => {
      if (isLoaded) {
        onProgress?.("Model already loaded!", 100);
        return;
      }

      if (loadPromiseRef.current) {
        return loadPromiseRef.current;
      }

      setIsLoading(true);
      setError(null);

      loadPromiseRef.current = (async () => {
        try {
          onProgress?.("Loading processor...", 0);
          processorRef.current = await AutoProcessor.from_pretrained(MODEL_ID);
          processorRef.current.image_processor!.size = {
            longest_edge: imageSizeRef.current,
          };
          onProgress?.("Processor loaded. Loading model...", 0);
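
          // Aggregate download progress across weight files; the percentage
          // assumes three .onnx_data shards (hence the division by 3 below).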
          const progressMap = new Map<string, number>();
          const progressCallback = (info: ProgressInfo) => {
            if (
              info.status === "progress" &&
              info.file.endsWith(".onnx_data")
            ) {
              progressMap.set(info.file, info.loaded / info.total);
              const total = Array.from(progressMap.values()).reduce(
                (a, b) => a + b,
                0,
              );
              const percentage = (total / 3) * 100;
              onProgress?.("Downloading model...", percentage);
            }
          };
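
          // Load the ONNX weights onto WebGPU with per-module quantization.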
          modelRef.current = await AutoModelForImageTextToText.from_pretrained(
            MODEL_ID,
            {
              dtype: {
                embed_tokens: "fp16",
                vision_encoder: "q4",
                decoder_model_merged: "q4f16",
              },
              device: "webgpu",
              progress_callback: progressCallback,
            },
          );
          onProgress?.("Model loaded successfully!", 100);
          setIsLoaded(true);
        } catch (e) {
          const errorMessage = e instanceof Error ? e.message : String(e);
          setError(errorMessage);
          console.error("Error loading model:", e);
          throw e;
        } finally {
          setIsLoading(false);
          loadPromiseRef.current = null;
        }
      })();

      return loadPromiseRef.current;
    },
    [isLoaded],
  );
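
  // Capture the current video frame and run one round of image + text
  // generation, streaming partial text and decoding statistics through the
  // optional callbacks. Resolves with the final decoded response.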
  const runInference = useCallback(
    async (
      video: HTMLVideoElement,
      instruction: string,
      onTextUpdate?: (text: string) => void,
      onStatsUpdate?: (stats: { tps?: number; ttft?: number }) => void,
    ): Promise<string> => {
      // Drop this frame if a previous inference is still in flight.
      if (inferenceLock.current) {
        return "";
      }

      if (!processorRef.current || !modelRef.current) {
        throw new Error("Model/processor not loaded");
      }

      // Take the lock only after the guards, and release it in the finally
      // block below so a failed generation cannot leave inference locked.
      inferenceLock.current = true;
      try {
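        // Draw the current video frame onto a reused offscreen canvas and wrap
        // the RGBA pixels in a RawImage for the processor.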
        if (!canvasRef.current) {
          canvasRef.current = document.createElement("canvas");
        }
        const canvas = canvasRef.current;

        canvas.width = video.videoWidth;
        canvas.height = video.videoHeight;

        const ctx = canvas.getContext("2d", { willReadFrequently: true });
        if (!ctx) throw new Error("Could not get canvas context");

        ctx.drawImage(video, 0, 0);

        const frame = ctx.getImageData(0, 0, canvas.width, canvas.height);
        const rawImg = new RawImage(frame.data, frame.width, frame.height, 4);
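
        // Build the chat prompt; the [IMG] placeholder marks where the image is
        // injected, and the processor turns the frame + prompt into model inputs.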
        const messages = [
          {
            role: "system",
            content: `You are a helpful visual AI assistant. Respond concisely and accurately to the user's query in one sentence.`,
          },
          { role: "user", content: `[IMG]${instruction}` },
        ];
        const prompt = processorRef.current.apply_chat_template(messages);
        const inputs = await processorRef.current(rawImg, prompt, {
          add_special_tokens: false,
        });
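
        // Stream tokens as they are generated: the first text chunk gives
        // time-to-first-token, and each subsequent token updates tokens/sec.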
        let streamed = "";
        const start = performance.now();
        let decodeStart: number | undefined;
        let numTokens = 0;
        const streamer = new TextStreamer(processorRef.current.tokenizer!, {
          skip_prompt: true,
          skip_special_tokens: true,
          callback_function: (t: string) => {
            if (streamed.length === 0) {
              const latency = performance.now() - start;
              onStatsUpdate?.({ ttft: latency });
            }
            streamed += t;
            onTextUpdate?.(streamed.trim());
          },
          token_callback_function: () => {
            decodeStart ??= performance.now();
            numTokens++;
            const elapsed = (performance.now() - decodeStart) / 1000;
            if (elapsed > 0) {
              onStatsUpdate?.({ tps: numTokens / elapsed });
            }
          },
        });
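
        // Greedy decoding with a mild repetition penalty; the streamer above
        // reports partial output while generation runs.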
        const outputs = (await modelRef.current.generate({
          ...inputs,
          max_new_tokens: MAX_NEW_TOKENS,
          do_sample: false,
          streamer,
          repetition_penalty: 1.2,
        })) as Tensor;
        // Keep only the newly generated tokens (drop the prompt portion).
        const generated = outputs.slice(null, [
          inputs.input_ids.dims.at(-1),
          null,
        ]);
        const decodeEnd = performance.now();
        if (decodeStart) {
          // Final throughput over the whole decode, from the generated length.
          const generatedTokens = generated.dims[1];
          const tps = generatedTokens / ((decodeEnd - decodeStart) / 1000);
          onStatsUpdate?.({ tps });
        }

        const decoded = processorRef.current.batch_decode(generated, {
          skip_special_tokens: true,
        });
        return decoded[0].trim();
      } finally {
        inferenceLock.current = false;
      }
    },
    [],
  );
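
  // Expose loading state, the load/inference actions, and the image-size
  // setting to the rest of the app.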
  return (
    <VLMContext.Provider
      value={{
        isLoaded,
        isLoading,
        error,
        loadModel,
        runInference,
        imageSize,
        setImageSize: updateImageSize,
      }}
    >
      {children}
    </VLMContext.Provider>
  );
};

export default VLMProvider;