"""Indra Mouth: FastAPI server exposing the official Qwen3-TTS voice-clone model."""

import io
import os

import soundfile as sf
import torch
import torch._dynamo
import torch._inductor.config
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel

from qwen_tts import Qwen3TTSModel

# --- 1. HARDWARE OPTIMIZATIONS ---
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

# Confine PyTorch CPU threads. On massive servers, PyTorch tries to use
# all available cores for background tasks, causing severe context-switching lag.
torch.set_num_threads(4)

app = FastAPI(title="Indra Mouth - Qwen3-TTS Official (Optimized)")

# voice name -> precomputed voice-clone prompt; populated lazily on first use.
prompt_cache = {}

print(f"Loading Official Qwen3-TTS on: {torch.cuda.get_device_name(0)}")
model = Qwen3TTSModel.from_pretrained(
    "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    device_map="cuda:0",
    dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
model.eval()

# --- 2. THE SURGICAL COMPILE ---
# Bypass the wrapper class and compile only the heavy autoregressive LLM engine.
torch._dynamo.config.suppress_errors = True
# Expand the cache limit so it doesn't thrash when sequence lengths vary.
torch._dynamo.config.cache_size_limit = 128
# Force inductor to use the fastest possible memory layouts for L40S.
torch._inductor.config.coordinate_descent_tuning = True

print("Compiling Autoregressive Engine (Dynamic Shapes)...")
model.talker = torch.compile(
    model.talker,
    mode="max-autotune",  # The most aggressive optimization level available
    dynamic=True,         # CRITICAL: Tells the compiler the audio length will grow
)


class TTSRequest(BaseModel):
    """Request body for POST /v1/audio/speech."""

    input: str          # text to synthesize
    voice: str = "oni"  # basename of the voice sample under the samples directory
    seed: int = 42      # RNG seed for reproducible generation


@app.post("/v1/audio/speech")
async def generate_speech(request: TTSRequest):
    """Synthesize ``request.input`` in the cloned voice ``request.voice``.

    Returns a 32-bit-float WAV response. The voice-clone prompt for each
    voice is built once from the on-disk sample and memoized in
    ``prompt_cache``; subsequent requests reuse it.

    Raises:
        HTTPException 404: the requested voice sample does not exist.
        HTTPException 500: any other failure during generation.
    """
    try:
        voice_key = request.voice
        if voice_key not in prompt_cache:
            base_path = "/mnt/nvme3n1/swarm/voice-samples"
            ref_path = os.path.join(base_path, f"{voice_key}.wav")
            txt_path = os.path.join(base_path, f"{voice_key}.txt")
            if not os.path.exists(ref_path):
                # An unknown voice is a client error (404), not a server fault.
                raise HTTPException(
                    status_code=404,
                    detail=f"Voice sample {ref_path} not found.",
                )
            ref_text = None
            if os.path.exists(txt_path):
                with open(txt_path, "r", encoding="utf-8") as f:
                    ref_text = f.read().strip()
            prompt_cache[voice_key] = model.create_voice_clone_prompt(
                ref_audio=ref_path,
                ref_text=ref_text,
                # Without a transcript, fall back to x-vector-only cloning.
                x_vector_only_mode=(ref_text is None),
            )
            print(f"--- Cached pristine voice prompt: {voice_key} ---")

        torch.manual_seed(request.seed)
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            wavs, sr = model.generate_voice_clone(
                text=[request.input],
                language=["English"],
                voice_clone_prompt=prompt_cache[voice_key],
            )

        wav_io = io.BytesIO()
        sf.write(wav_io, wavs[0], sr, format='WAV', subtype='FLOAT')
        wav_io.seek(0)
        return Response(content=wav_io.getvalue(), media_type="audio/wav")
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 404 above) instead of
        # collapsing them into a generic 500 below.
        raise
    except Exception as e:
        print(f"Indra Mouth Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8002)