"""Voice-assistant HTTP service: speech-to-text -> text-to-text LLM -> text-to-speech.

POST an audio file to ``/process_audio`` (multipart field ``audio``) and receive
the spoken answer as 16 kHz mono WAV. ``GET /health`` reports the configured
model names.
"""

import os
import io
import uuid
import time
import json
import logging
import tempfile
import threading

from flask import Flask, request, jsonify, send_file
from transformers import pipeline
from gtts import gTTS
from pydub import AudioSegment

# ================= CONFIG =================
TEMP_AUDIO_DIR = "/tmp/audio"
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)

STT_MODEL = "openai/whisper-tiny"
LLM_MODEL = "google/flan-t5-base"

# NOTE(review): not referenced anywhere else in this module, so the length of
# uploaded audio is not actually limited yet.
MAX_AUDIO_SECONDS = 10
MAX_TEXT_LEN = 200  # max characters of the answer passed to TTS
CLEANUP_INTERVAL = 300  # seconds between sweeps of TEMP_AUDIO_DIR
FILE_EXPIRE_TIME = 600  # seconds; files in TEMP_AUDIO_DIR older than this are deleted

# ================= LOG =================
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
)
logger = logging.getLogger(__name__)

# ================= APP =================
app = Flask(__name__)
app.config["TEMP_AUDIO_DIR"] = TEMP_AUDIO_DIR

# ================= LOAD MODELS =================
# Both pipelines are built once at import time and shared by every request.
logger.info("Loading STT model...")
stt_pipeline = pipeline(
    "automatic-speech-recognition",
    model=STT_MODEL,
    device="cpu",
)

logger.info("Loading LLM model...")
llm_pipeline = pipeline(
    "text2text-generation",
    model=LLM_MODEL,
    device="cpu",
)

logger.info("Models loaded successfully")


# ================= UTILS =================
def generate_tts_audio(text: str) -> bytes:
    """Synthesize *text* to speech and return it as 16 kHz mono WAV bytes.

    Newlines are flattened to spaces, empty text falls back to
    ``"I understand."``, and the text is truncated to ``MAX_TEXT_LEN``
    characters. The MP3 produced by gTTS is converted to WAV entirely in
    memory, so no temporary files are created and none can be left behind
    when a step fails.

    Returns:
        The WAV payload, or ``b""`` if any step fails (the error is logged).
    """
    try:
        text = text.replace("\n", " ").strip()
        if not text:
            text = "I understand."
        text = text[:MAX_TEXT_LEN]
        logger.info("TTS: %s", text)

        mp3_buf = io.BytesIO()
        gTTS(text=text, lang="en").write_to_fp(mp3_buf)
        mp3_buf.seek(0)

        # An in-memory stream has no filename to infer the container from,
        # so state the input format explicitly.
        audio = AudioSegment.from_file(mp3_buf, format="mp3")
        audio = audio.set_frame_rate(16000).set_channels(1)

        wav_buf = io.BytesIO()
        audio.export(wav_buf, format="wav")
        return wav_buf.getvalue()
    except Exception as e:
        logger.error("TTS error: %s", e, exc_info=True)
        return b""


def _sweep_expired_files():
    """Delete regular files in TEMP_AUDIO_DIR older than FILE_EXPIRE_TIME."""
    now = time.time()
    for filename in os.listdir(TEMP_AUDIO_DIR):
        path = os.path.join(TEMP_AUDIO_DIR, filename)
        # Per-file handling: a file that vanishes or can't be removed must not
        # abort the rest of the sweep.
        try:
            if os.path.isfile(path) and now - os.path.getmtime(path) > FILE_EXPIRE_TIME:
                os.remove(path)
        except OSError as e:
            logger.warning("Cleanup error: %s", e)


def cleanup_temp_files():
    """Sweep TEMP_AUDIO_DIR every CLEANUP_INTERVAL seconds, forever.

    Started on a daemon thread at startup. Any error in a sweep is logged and
    the loop continues, so the background thread never dies.
    """
    while True:
        try:
            _sweep_expired_files()
        except Exception as e:
            logger.warning("Cleanup error: %s", e)
        time.sleep(CLEANUP_INTERVAL)


# ================= ROUTES =================
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe; also reports which STT and LLM models are configured."""
    return jsonify({
        "status": "ok",
        "stt": STT_MODEL,
        "llm": LLM_MODEL,
    })


@app.route("/process_audio", methods=["POST"])
def process_audio():
    """Turn an uploaded utterance into a spoken answer.

    Expects a multipart upload with an ``audio`` file field. The audio is
    transcribed (STT), the transcript is sent to the text-to-text LLM, and the
    answer is synthesized back to speech (TTS).

    Returns:
        200 with the ``audio/wav`` answer on success; 400 JSON error if the
        ``audio`` field is missing or the payload is under 1000 bytes; 500 JSON
        error if TTS produced nothing or any step raised.
    """
    try:
        if "audio" not in request.files:
            return jsonify({"error": "No audio file"}), 400

        raw_audio = request.files["audio"].read()
        # Cheap size heuristic to reject near-empty uploads before running models.
        if len(raw_audio) < 1000:
            return jsonify({"error": "Audio too short"}), 400

        # ================= STT =================
        logger.info("Running STT...")
        # Raw file bytes are decoded and resampled by the pipeline itself.
        # `sampling_rate` is not a call-time parameter of the ASR pipeline; it
        # only applies inside a {"raw": ..., "sampling_rate": ...} dict input.
        stt_result = stt_pipeline(raw_audio)

        user_text = stt_result.get("text", "").strip()
        logger.info("User said: %s", user_text)
        if not user_text:
            # Fall back to a greeting so the LLM always receives a prompt.
            user_text = "Hello"

        # ================= LLM =================
        logger.info("Running LLM...")
        llm_result = llm_pipeline(
            user_text,
            max_new_tokens=64,
            do_sample=False,  # no sampling: deterministic decoding
        )
        answer = llm_result[0]["generated_text"]
        logger.info("Answer: %s", answer)

        # ================= TTS =================
        audio_response = generate_tts_audio(answer)
        if not audio_response:
            return jsonify({"error": "TTS failed"}), 500

        # Serve straight from memory. Nothing is written to TEMP_AUDIO_DIR, so
        # responses cannot accumulate on disk even when the cleanup thread
        # isn't running (it is only started under __main__).
        return send_file(
            io.BytesIO(audio_response),
            mimetype="audio/wav",
            as_attachment=False,
            download_name="response.wav",
        )

    except Exception as e:
        logger.error("Processing error: %s", e, exc_info=True)
        return jsonify({"error": "Internal error"}), 500


# ================= STARTUP =================
if __name__ == "__main__":
    threading.Thread(target=cleanup_temp_files, daemon=True).start()

    app.run(
        host="0.0.0.0",
        port=7860,
        threaded=True,
    )