diff --git a/frontend/server/utils/whisper.ts b/frontend/server/utils/whisper.ts index ffbb903..2a3dd61 100644 --- a/frontend/server/utils/whisper.ts +++ b/frontend/server/utils/whisper.ts @@ -3,11 +3,9 @@ type WhisperTranscribeInput = { sampleRate: number; language?: string; }; -let whisperPipelinePromise: Promise | null = null; -let transformersPromise: Promise | null = null; -function getWhisperModelId() { - return (process.env.CF_WHISPER_MODEL ?? "Xenova/whisper-small").trim() || "Xenova/whisper-small"; +function getWhisperUrl() { + return (process.env.WHISPER_URL ?? "http://whisper:9000").replace(/\/+$/, ""); } function getWhisperLanguage() { @@ -15,39 +13,57 @@ function getWhisperLanguage() { return value || "ru"; } -async function getWhisperPipeline() { - if (!transformersPromise) { - transformersPromise = import("@xenova/transformers"); +function pcmToWav(samples: Float32Array, sampleRate: number): Buffer { + const numChannels = 1; + const bitsPerSample = 16; + const byteRate = sampleRate * numChannels * (bitsPerSample / 8); + const blockAlign = numChannels * (bitsPerSample / 8); + const dataSize = samples.length * (bitsPerSample / 8); + const headerSize = 44; + const buffer = Buffer.alloc(headerSize + dataSize); + + // RIFF header + buffer.write("RIFF", 0); + buffer.writeUInt32LE(36 + dataSize, 4); + buffer.write("WAVE", 8); + + // fmt chunk + buffer.write("fmt ", 12); + buffer.writeUInt32LE(16, 16); + buffer.writeUInt16LE(1, 20); // PCM + buffer.writeUInt16LE(numChannels, 22); + buffer.writeUInt32LE(sampleRate, 24); + buffer.writeUInt32LE(byteRate, 28); + buffer.writeUInt16LE(blockAlign, 32); + buffer.writeUInt16LE(bitsPerSample, 34); + + // data chunk + buffer.write("data", 36); + buffer.writeUInt32LE(dataSize, 40); + + for (let i = 0; i < samples.length; i += 1) { + const s = Math.max(-1, Math.min(1, samples[i] ?? 0)); + const val = s < 0 ? s * 0x8000 : s * 0x7fff; + buffer.writeInt16LE(Math.round(val), headerSize + i * 2); } - const { env, pipeline } = await transformersPromise; - - if (!whisperPipelinePromise) { - env.allowRemoteModels = true; - env.allowLocalModels = true; - env.cacheDir = "/app/.data/transformers"; - - const modelId = getWhisperModelId(); - whisperPipelinePromise = pipeline("automatic-speech-recognition", modelId); - } - - return whisperPipelinePromise; + return buffer; } export async function transcribeWithWhisper(input: WhisperTranscribeInput) { - const transcriber = (await getWhisperPipeline()) as any; - const result = await transcriber( - input.samples, - { - sampling_rate: input.sampleRate, - language: (input.language ?? getWhisperLanguage()) || "ru", - task: "transcribe", - chunk_length_s: 20, - stride_length_s: 5, - return_timestamps: false, - }, - ); + const wav = pcmToWav(input.samples, input.sampleRate); + const language = (input.language ?? getWhisperLanguage()) || "ru"; + const url = `${getWhisperUrl()}/asr?task=transcribe&language=${language}&output=json`; - const text = String((result as any)?.text ?? "").trim(); - return text; + const formData = new FormData(); + formData.append("audio_file", new Blob([wav], { type: "audio/wav" }), "audio.wav"); + + const response = await fetch(url, { method: "POST", body: formData }); + if (!response.ok) { + const detail = await response.text().catch(() => ""); + throw new Error(`Whisper service error ${response.status}: ${detail}`); + } + + const result = (await response.json()) as { text?: string }; + return String(result?.text ?? "").trim(); }