import io
import os

import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel

from qwen_tts import Qwen3TTSModel

# --- 1. HARDWARE OPTIMIZATIONS ---
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

# Confine PyTorch's CPU threads. On large multi-core servers, PyTorch defaults
# to claiming every core for intra-op parallelism, and the resulting context
# switching adds severe latency to a GPU-bound workload like this one.
torch.set_num_threads(4)

app = FastAPI(title="Indra Mouth - Qwen3-TTS Official (Optimized)")
# Cache of prepared voice-clone prompts, keyed by voice name.
prompt_cache = {}

print(f"Loading Official Qwen3-TTS on: {torch.cuda.get_device_name(0)}")

model = Qwen3TTSModel.from_pretrained(
    "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    device_map="cuda:0",
    dtype=torch.bfloat16,
    attn_implementation="sdpa"
)
model.eval()

# --- 2. THE SURGICAL COMPILE ---
# Bypass the wrapper class and compile only the heavy autoregressive LLM engine.
import torch._dynamo
import torch._inductor.config

# Fall back to eager execution on graph breaks instead of crashing.
torch._dynamo.config.suppress_errors = True
# Expand the cache limit so it doesn't thrash when sequence lengths vary.
torch._dynamo.config.cache_size_limit = 128
# Let inductor search harder (coordinate-descent autotuning) for the fastest
# kernel configurations on the L40S.
torch._inductor.config.coordinate_descent_tuning = True

print("Compiling Autoregressive Engine (Dynamic Shapes)...")
model.talker = torch.compile(
    model.talker,
    mode="max-autotune",  # The most aggressive optimization level available
    dynamic=True,         # CRITICAL: tells the compiler the audio length will grow
)
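
# NOTE: torch.compile is lazy, so the first request pays the full max-autotune
# compilation cost (which can take minutes), and new sequence-length buckets may
# trigger recompiles. A possible warmup pass at startup -- a sketch only, which
# assumes a reference sample for the default "oni" voice already exists at the
# path used by the endpoint below:
#
#   _warm_prompt = model.create_voice_clone_prompt(
#       ref_audio="/mnt/nvme3n1/swarm/voice-samples/oni.wav",
#       ref_text=None,
#       x_vector_only_mode=True,
#   )
#   with torch.inference_mode():
#       model.generate_voice_clone(
#           text=["Warmup."], language=["English"], voice_clone_prompt=_warm_prompt
#       )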

class TTSRequest(BaseModel):
    input: str
    voice: str = "oni"
    seed: int = 42
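
# Note: the route and field names below loosely mirror the OpenAI speech
# endpoint (POST /v1/audio/speech with "input" and "voice"), so OpenAI-style
# clients can usually be pointed at this server with minimal changes.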

@app.post("/v1/audio/speech")
async def generate_speech(request: TTSRequest):
    try:
        voice_key = request.voice

        # Build the voice-clone prompt once per voice and cache it.
        if voice_key not in prompt_cache:
            base_path = "/mnt/nvme3n1/swarm/voice-samples"
            ref_path = os.path.join(base_path, f"{voice_key}.wav")
            txt_path = os.path.join(base_path, f"{voice_key}.txt")

            if not os.path.exists(ref_path):
                raise FileNotFoundError(f"Voice sample {ref_path} not found.")

            # Use the reference transcript when one is available; otherwise
            # fall back to x-vector-only cloning.
            ref_text = None
            if os.path.exists(txt_path):
                with open(txt_path, "r") as f:
                    ref_text = f.read().strip()

            prompt_cache[voice_key] = model.create_voice_clone_prompt(
                ref_audio=ref_path,
                ref_text=ref_text,
                x_vector_only_mode=(ref_text is None)
            )
            print(f"--- Cached pristine voice prompt: {voice_key} ---")

        # Fixed seed keeps generations reproducible for a given request.
        torch.manual_seed(request.seed)

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            wavs, sr = model.generate_voice_clone(
                text=[request.input],
                language=["English"],
                voice_clone_prompt=prompt_cache[voice_key]
            )

        # Serialize the waveform to an in-memory WAV and return it directly.
        wav_io = io.BytesIO()
        sf.write(wav_io, wavs[0], sr, format='WAV', subtype='FLOAT')
        wav_io.seek(0)

        return Response(content=wav_io.getvalue(), media_type="audio/wav")

    except Exception as e:
        print(f"Indra Mouth Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
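
# Example client call (illustrative sketch, assuming the server below is
# reachable on localhost:8002):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8002/v1/audio/speech",
#       json={"input": "Hello from Indra.", "voice": "oni", "seed": 42},
#   )
#   with open("out.wav", "wb") as f:
#       f.write(resp.content)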

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8002)