Files
2026-05-05 16:42:49 +10:00

105 lines
3.4 KiB
Python

import os
import torch
import io
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel
from qwen_tts import Qwen3TTSModel
# --- 1. HARDWARE OPTIMIZATIONS ---
# Trade a little matmul precision for speed: enable TF32 paths for CUDA
# matmul/cuDNN and relax float32 matmul precision globally.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
# NEW: Confine PyTorch CPU threads. On massive servers, PyTorch tries to use
# all available cores for background tasks, causing severe context-switching lag.
torch.set_num_threads(4)

app = FastAPI(title="Indra Mouth - Qwen3-TTS Official (Optimized)")
# voice name -> prebuilt voice-clone prompt; filled lazily on first request
# for each voice (see generate_speech).
prompt_cache = {}

# NOTE(review): assumes at least one CUDA device is present — this line
# raises at import time on a CPU-only host.
print(f"Loading Official Qwen3-TTS on: {torch.cuda.get_device_name(0)}")
# Load the official Qwen3-TTS checkpoint onto the first GPU in bfloat16,
# using PyTorch's scaled-dot-product attention implementation.
model = Qwen3TTSModel.from_pretrained(
    "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    device_map="cuda:0",           # pin the whole model to GPU 0
    dtype=torch.bfloat16,          # bf16 weights/activations
    attn_implementation="sdpa"     # SDPA attention kernel
)
# Inference-only server: disable dropout/batch-norm training behavior.
model.eval()
# --- 2. THE SURGICAL COMPILE ---
# We bypass the wrapper class and strictly compile only the heavy
# autoregressive LLM engine (`model.talker`), leaving the rest eager.
# (Fix: the section-header comment was duplicated in the original.)
import torch._dynamo
import torch._inductor.config

# Don't let a graph break abort serving — fall back to eager on errors.
torch._dynamo.config.suppress_errors = True
# Expand the cache limit so it doesn't thrash when sequence lengths vary.
torch._dynamo.config.cache_size_limit = 128
# Force inductor to use the fastest possible memory layouts for L40S.
torch._inductor.config.coordinate_descent_tuning = True

print("Compiling Autoregressive Engine (Dynamic Shapes)...")
model.talker = torch.compile(
    model.talker,
    mode="max-autotune",  # The most aggressive optimization level available
    dynamic=True          # CRITICAL: tells the compiler the audio length will grow
)
class TTSRequest(BaseModel):
    """Request body for POST /v1/audio/speech (OpenAI-compatible shape)."""
    input: str          # text to synthesize
    voice: str = "oni"  # basename of a sample under the voice-samples dir
    seed: int = 42      # RNG seed for reproducible generation
@app.post("/v1/audio/speech")
async def generate_speech(request: TTSRequest):
    """Synthesize speech for `request.input` with the requested cloned voice.

    The voice-clone prompt for each voice is built once from
    `<voice>.wav` (plus optional `<voice>.txt` transcript) and cached in
    `prompt_cache`. Returns a float32 WAV as an `audio/wav` response.

    Raises:
        HTTPException(404): the voice's reference sample is missing.
        HTTPException(500): any other generation failure.
    """
    try:
        voice_key = request.voice
        if voice_key not in prompt_cache:
            base_path = "/mnt/nvme3n1/swarm/voice-samples"
            ref_path = os.path.join(base_path, f"{voice_key}.wav")
            txt_path = os.path.join(base_path, f"{voice_key}.txt")
            if not os.path.exists(ref_path):
                # FIX: a missing voice sample is a client error — report 404
                # instead of letting the catch-all turn it into a 500.
                raise HTTPException(
                    status_code=404,
                    detail=f"Voice sample {ref_path} not found."
                )
            ref_text = None
            if os.path.exists(txt_path):
                # FIX: pin the encoding so the transcript decodes identically
                # regardless of the host locale.
                with open(txt_path, "r", encoding="utf-8") as f:
                    ref_text = f.read().strip()
            # Without a transcript, fall back to x-vector-only cloning.
            prompt_cache[voice_key] = model.create_voice_clone_prompt(
                ref_audio=ref_path,
                ref_text=ref_text,
                x_vector_only_mode=(ref_text is None)
            )
            print(f"--- Cached pristine voice prompt: {voice_key} ---")
        # Seed per request so identical (text, voice, seed) reproduce audio.
        torch.manual_seed(request.seed)
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            wavs, sr = model.generate_voice_clone(
                text=[request.input],
                language=["English"],
                voice_clone_prompt=prompt_cache[voice_key]
            )
        wav_io = io.BytesIO()
        sf.write(wav_io, wavs[0], sr, format='WAV', subtype='FLOAT')
        wav_io.seek(0)
        return Response(content=wav_io.getvalue(), media_type="audio/wav")
    except HTTPException:
        # FIX: don't let the generic handler rewrap deliberate HTTP errors.
        raise
    except Exception as e:
        print(f"Indra Mouth Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn
    # Serve on all interfaces; 8002 is this service's published port.
    uvicorn.run(app, host="0.0.0.0", port=8002)