Swapped back to qwen3-tts

2026-05-05 16:42:49 +10:00
parent e90d2b1ec2
commit 109084e8e4
3 changed files with 100 additions and 78 deletions

Dockerfile

@@ -1,37 +1,33 @@
 FROM nvidia/cuda:12.4.1-devel-ubuntu22.04
 # Prevent interactive prompts
 ENV DEBIAN_FRONTEND=noninteractive
 ENV NVIDIA_VISIBLE_DEVICES=all
 ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
-# 1. Install Python 3.12 and SoX dependencies
-RUN apt-get update && apt-get install -y software-properties-common && \
-    add-apt-repository ppa:deadsnakes/ppa -y && \
-    apt-get update && apt-get install -y \
-    python3.12 \
-    python3.12-dev \
-    curl \
-    git \
-    libsndfile1 \
-    ffmpeg \
-    sox \
-    libsox-dev && \
+# No deadsnakes PPA needed. Native Python 3.10 works perfectly.
+RUN apt-get update && apt-get install -y \
+    python3 python3-dev python3-pip curl git libsndfile1 ffmpeg sox libsox-dev ninja-build && \
     rm -rf /var/lib/apt/lists/*
-# 2. Use the official bootstrap to install a clean Pip for 3.12
-RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12
 WORKDIR /app
-# 3. Explicitly install BOTH torch and torchaudio from the cu124 index
-RUN python3.12 -m pip install --no-cache-dir torch==2.6.0 torchaudio --index-url https://download.pytorch.org/whl/cu124
-RUN python3.12 -m pip install --no-cache-dir fastapi uvicorn numpy soundfile
+# Install Torch and core packages
+RUN python3 -m pip install --no-cache-dir torch==2.6.0 torchaudio --index-url https://download.pytorch.org/whl/cu124
-# 4. Install the local Qwen3-TTS requirements
-RUN python3.12 -m pip install --no-cache-dir faster-qwen3-tts
+# 1. Install foundational build tools and numpy first
+RUN python3 -m pip install --no-cache-dir numpy setuptools wheel ninja packaging psutil
-COPY tts-server.py .
+# 2. Install the rest of the stack that relies on numpy being present
+RUN python3 -m pip install --no-cache-dir fastapi uvicorn soundfile librosa transformers==4.57.3 accelerate sox onnxruntime
+# Force ABI compatibility for the C++ compiler
+ENV _GLIBCXX_USE_CXX11_ABI=0
+ENV MAX_JOBS=8
+# Install flash-attn
+RUN python3 -m pip install --no-cache-dir flash-attn --no-build-isolation
+COPY swarm-control/indra-tts-server/tts-server.py .
 EXPOSE 8002
-CMD ["python3.12", "tts-server.py"]
+CMD ["python3", "tts-server.py"]

swarm-control/indra-tts-server/tts-server.py

@@ -1,74 +1,98 @@
 import os
 import torch
 import numpy as np
 import io
 import wave
 import soundfile as sf
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import Response
 from pydantic import BaseModel
-from faster_qwen3_tts import FasterQwen3TTS
+from qwen_tts import Qwen3TTSModel
-app = FastAPI(title="Indra tts")
 # --- 1. HARDWARE OPTIMIZATIONS ---
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 torch.set_float32_matmul_precision('high')
 if not torch.cuda.is_available():
     raise RuntimeError("Mouth cannot find CUDA. Check nvidia-container-toolkit.")
-print(f"Loading model on: {torch.cuda.get_device_name(0)}")
+# NEW: Confine PyTorch CPU threads. On massive servers, PyTorch tries to use
+# all available cores for background tasks, causing severe context-switching lag.
+torch.set_num_threads(4)
-# Load the Base model for high-fidelity mimicry
-model = FasterQwen3TTS.from_pretrained(
+app = FastAPI(title="Indra Mouth - Qwen3-TTS Official (Optimized)")
+prompt_cache = {}
+print(f"Loading Official Qwen3-TTS on: {torch.cuda.get_device_name(0)}")
+model = Qwen3TTSModel.from_pretrained(
     "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
-    device="cuda:0",  # Targets GPU 7
-    dtype=torch.bfloat16
+    device_map="cuda:0",
+    dtype=torch.bfloat16,
+    attn_implementation="sdpa"
 )
+model.eval()
-# --- 2. THE SURGICAL COMPILE ---
-# We bypass the wrapper class and strictly compile the heavy LLM engine
+# --- 2. THE SURGICAL COMPILE ---
+import torch._dynamo
+import torch._inductor.config
+torch._dynamo.config.suppress_errors = True
+# Expand the cache limit so it doesn't thrash when sequence lengths vary
+torch._dynamo.config.cache_size_limit = 128
+# Force inductor to use the fastest possible memory layouts for L40S
+torch._inductor.config.coordinate_descent_tuning = True
+print("Compiling Autoregressive Engine (Dynamic Shapes)...")
+model.talker = torch.compile(
+    model.talker,
+    mode="max-autotune",  # The most aggressive optimization level available
+    dynamic=True  # CRITICAL: Tells the compiler the audio length will grow
+)
 class TTSRequest(BaseModel):
     model: str = "tts-1"  # ignored by backend, here to satisfy modelix router
     input: str
     voice: str = "oni"
     response_format: str = "wav"
     seed: int = 42
 @app.post("/v1/audio/speech")
 async def generate_speech(request: TTSRequest):
     try:
-        voice_file = f"{request.voice}.wav"
-        base_path = "/mnt/nvme3n1/swarm/voice-samples"
-        ref_path = os.path.join(base_path, voice_file)
-        txt_path = os.path.splitext(ref_path)[0] + ".txt"
+        voice_key = request.voice
+        if voice_key not in prompt_cache:
+            base_path = "/mnt/nvme3n1/swarm/voice-samples"
+            ref_path = os.path.join(base_path, f"{voice_key}.wav")
+            txt_path = os.path.join(base_path, f"{voice_key}.txt")
+            if not os.path.exists(ref_path):
+                raise FileNotFoundError(f"Voice sample {ref_path} not found.")
-        ref_text = None
-        if os.path.exists(txt_path):
-            with open(txt_path, "r") as f:
-                ref_text = f.read().strip()
+            ref_text = None
+            if os.path.exists(txt_path):
+                with open(txt_path, "r") as f:
+                    ref_text = f.read().strip()
+            prompt_cache[voice_key] = model.create_voice_clone_prompt(
+                ref_audio=ref_path,
+                ref_text=ref_text,
+                x_vector_only_mode=(ref_text is None)
+            )
+            print(f"--- Cached pristine voice prompt: {voice_key} ---")
         # Fix the seed for the persona identity
         torch.manual_seed(request.seed)
-        full_audio = []
-        # Non-streaming call is fine here since it takes <1s on your L40S
-        audio_data, sample_rate = model.generate_voice_clone(
-            text=request.input,
-            language="English",
-            ref_audio=ref_path,
-            ref_text=ref_text,
-            xvec_only=(ref_text is None)
-        )
-        audio_data = np.array(audio_data)
-        audio_data = audio_data.flatten()
-        # Convert Float32 to Int16 for standard WAV compatibility
-        audio_int16 = (audio_data * 32767).astype(np.int16)
+        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+            wavs, sr = model.generate_voice_clone(
+                text=[request.input],
+                language=["English"],
+                voice_clone_prompt=prompt_cache[voice_key]
+            )
         wav_io = io.BytesIO()
-        with wave.open(wav_io, 'wb') as wav_file:
-            wav_file.setnchannels(1)
-            wav_file.setsampwidth(2)
-            wav_file.setframerate(sample_rate)
-            wav_file.writeframes(audio_int16.tobytes())
+        sf.write(wav_io, wavs[0], sr, format='WAV', subtype='FLOAT')
         wav_io.seek(0)
         return Response(content=wav_io.getvalue(), media_type="audio/wav")
     except Exception as e:
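
A minimal client call against the endpoint above, for reference. The host and port follow EXPOSE 8002; the JSON fields mirror the TTSRequest schema in this diff; the input text and output filename are illustrative only.

    # model is ignored by the backend (kept for the modelix router); voice selects a cached clone prompt
    curl -s http://localhost:8002/v1/audio/speech \
      -H 'Content-Type: application/json' \
      -d '{"model": "tts-1", "input": "Hello from Indra.", "voice": "oni", "response_format": "wav", "seed": 42}' \
      -o speech.wav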