Swapped back to qwen3-tts

This commit is contained in:
2026-05-05 16:42:49 +10:00
parent e90d2b1ec2
commit 109084e8e4
3 changed files with 100 additions and 78 deletions

View File

@@ -1,74 +1,98 @@
import os
import torch
import numpy as np
import io
import wave
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel
from faster_qwen3_tts import FasterQwen3TTS
from qwen_tts import Qwen3TTSModel
# --- 1. HARDWARE OPTIMIZATIONS ---
# TF32 matmuls/convs trade a little precision for large throughput gains on
# Ampere+ GPUs (the L40S referenced below supports them).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

if not torch.cuda.is_available():
    raise RuntimeError("Mouth cannot find CUDA. Check nvidia-container-toolkit.")

# Confine PyTorch CPU threads. On massive servers, PyTorch tries to use
# all available cores for background tasks, causing severe context-switching lag.
torch.set_num_threads(4)

app = FastAPI(title="Indra Mouth - Qwen3-TTS Official (Optimized)")

# voice name -> pre-built voice-clone prompt; populated lazily per request
# so each reference sample is only processed once.
prompt_cache = {}

print(f"Loading Official Qwen3-TTS on: {torch.cuda.get_device_name(0)}")
model = Qwen3TTSModel.from_pretrained(
    "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    device_map="cuda:0",
    dtype=torch.bfloat16,
    attn_implementation="sdpa"
)
model.eval()

# --- 2. THE SURGICAL COMPILE ---
# We bypass the wrapper class and strictly compile the heavy LLM engine.
import torch._dynamo
import torch._inductor.config

# Fall back to eager on graph breaks instead of crashing the server.
torch._dynamo.config.suppress_errors = True
# Expand the cache limit so it doesn't thrash when sequence lengths vary.
torch._dynamo.config.cache_size_limit = 128
# Force inductor to use the fastest possible memory layouts for L40S.
torch._inductor.config.coordinate_descent_tuning = True

print("Compiling Autoregressive Engine (Dynamic Shapes)...")
model.talker = torch.compile(
    model.talker,
    mode="max-autotune",  # the most aggressive optimization level available
    dynamic=True          # CRITICAL: tells the compiler the audio length will grow
)
class TTSRequest(BaseModel):
    """Request body for POST /v1/audio/speech (OpenAI speech-API compatible)."""
    model: str = "tts-1"  # ignored by backend, here to satisfy modelix router
    # Text to synthesize.
    input: str
    # Voice sample name; resolved against /mnt/nvme3n1/swarm/voice-samples/<voice>.wav.
    voice: str = "oni"
    # NOTE(review): only WAV is ever produced by the endpoint — other values
    # appear to be accepted but ignored; confirm before documenting otherwise.
    response_format: str = "wav"
    # Seed fed to torch.manual_seed so the persona voice identity is stable.
    seed: int = 42
@app.post("/v1/audio/speech")
async def generate_speech(request: TTSRequest):
try:
voice_file = f"{request.voice}.wav"
base_path = "/mnt/nvme3n1/swarm/voice-samples"
ref_path = os.path.join(base_path, voice_file)
txt_path = os.path.splitext(ref_path)[0] + ".txt"
voice_key = request.voice
if voice_key not in prompt_cache:
base_path = "/mnt/nvme3n1/swarm/voice-samples"
ref_path = os.path.join(base_path, f"{voice_key}.wav")
txt_path = os.path.join(base_path, f"{voice_key}.txt")
if not os.path.exists(ref_path):
raise FileNotFoundError(f"Voice sample {ref_path} not found.")
ref_text = None
if os.path.exists(txt_path):
with open(txt_path, "r") as f:
ref_text = f.read().strip()
ref_text = None
if os.path.exists(txt_path):
with open(txt_path, "r") as f:
ref_text = f.read().strip()
prompt_cache[voice_key] = model.create_voice_clone_prompt(
ref_audio=ref_path,
ref_text=ref_text,
x_vector_only_mode=(ref_text is None)
)
print(f"--- Cached pristine voice prompt: {voice_key} ---")
# Fix the seed for the persona identity
torch.manual_seed(request.seed)
full_audio = []
# Non-streaming call is fine here since it takes <1s on your L40S
audio_data, sample_rate = model.generate_voice_clone(
text=request.input,
language="English",
ref_audio=ref_path,
ref_text=ref_text,
xvec_only=(ref_text is None)
)
audio_data = np.array(audio_data)
audio_data = audio_data.flatten()
# Convert Float32 to Int16 for standard WAV compatibility
audio_int16 = (audio_data * 32767).astype(np.int16)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
wavs, sr = model.generate_voice_clone(
text=[request.input],
language=["English"],
voice_clone_prompt=prompt_cache[voice_key]
)
wav_io = io.BytesIO()
with wave.open(wav_io, 'wb') as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(sample_rate)
wav_file.writeframes(audio_int16.tobytes())
sf.write(wav_io, wavs[0], sr, format='WAV', subtype='FLOAT')
wav_io.seek(0)
return Response(content=wav_io.getvalue(), media_type="audio/wav")
except Exception as e: