Files
Swarm/swarm-control/indra-tts-server/tts-server.py

81 lines
2.4 KiB
Python

import os
import torch
import numpy as np
import io
import wave
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel
from faster_qwen3_tts import FasterQwen3TTS
# ASGI application object; the route handler in this file registers on it.
app = FastAPI(title="Indra tts")

# Refuse to start without a CUDA device: the model below is loaded straight
# onto the GPU, so there is no CPU fallback.
if torch.cuda.is_available():
    print(f"Loading model on: {torch.cuda.get_device_name(0)}")
else:
    raise RuntimeError("Mouth cannot find CUDA. Check nvidia-container-toolkit.")

# Base checkpoint, loaded once at import time and shared by every request.
# "cuda:0" is the first CUDA device visible to this process.
# NOTE(review): the original comment said this "Targets GPU 7" -- presumably
# the container/CUDA_VISIBLE_DEVICES remaps physical GPU 7 to index 0; confirm.
model = FasterQwen3TTS.from_pretrained(
    "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    device="cuda:0",
    dtype=torch.bfloat16,
)
class TTSRequest(BaseModel):
    """JSON body accepted by ``POST /v1/audio/speech``.

    Only ``input`` is required; every other field has a default.
    """

    # Ignored by the backend; present only to satisfy the modelix router.
    model: str = "tts-1"

    # Text to be synthesized.
    input: str

    # Name of the reference voice; the handler looks for "<voice>.wav"
    # (plus an optional "<voice>.txt" transcript) in the voice-sample directory.
    voice: str = "oni"

    # Requested output format. Defaults to "wav".
    response_format: str = "wav"

    # Value handed to torch.manual_seed() before generation.
    seed: int = 42
# Directory holding "<voice>.wav" reference clips and optional "<voice>.txt"
# transcripts of those clips.
_VOICE_SAMPLE_DIR = "/mnt/nvme3n1/swarm/voice-samples"


def _resolve_voice_paths(voice: str, base_dir: str = _VOICE_SAMPLE_DIR) -> tuple:
    """Map a client-supplied voice name to its reference-audio and transcript paths.

    The name comes straight from the request body, so it is treated as
    untrusted: ``..`` segments and absolute paths that would escape
    ``base_dir`` are rejected before anything touches the filesystem.

    Args:
        voice: Voice name from the request (without the ``.wav`` extension).
        base_dir: Directory that contains the voice samples.

    Returns:
        ``(ref_path, txt_path)`` -- the existing ``.wav`` file and the path
        where its ``.txt`` transcript would live (which may not exist).

    Raises:
        HTTPException: 400 if the name is empty, contains a NUL byte, or
            resolves outside ``base_dir``; 404 if the ``.wav`` file is missing.
    """
    if not voice or "\x00" in voice:
        raise HTTPException(status_code=400, detail="Invalid voice name")

    base_dir = os.path.abspath(base_dir)
    # abspath() collapses ".." segments; os.path.join() also discards base_dir
    # entirely when `voice` is absolute, so the containment check is required.
    ref_path = os.path.abspath(os.path.join(base_dir, f"{voice}.wav"))
    if os.path.commonpath([base_dir, ref_path]) != base_dir:
        raise HTTPException(status_code=400, detail="Invalid voice name")

    if not os.path.isfile(ref_path):
        raise HTTPException(status_code=404, detail=f"Unknown voice: {voice}")

    txt_path = os.path.splitext(ref_path)[0] + ".txt"
    return ref_path, txt_path


def _encode_wav_pcm16(samples, sample_rate) -> bytes:
    """Encode float audio samples as a mono 16-bit PCM WAV file in memory.

    Args:
        samples: Array-like of float samples; flattened to one dimension.
            Presumably in [-1.0, 1.0] (the original code scaled by 32767
            without checking) -- out-of-range values are clipped here.
        sample_rate: Frame rate written to the WAV header.

    Returns:
        The complete WAV file as bytes.
    """
    pcm = np.asarray(samples, dtype=np.float32).reshape(-1)
    # Without these two steps NaN/inf or slightly hot samples wrap around when
    # cast to int16, which is audible as loud clicks.
    pcm = np.nan_to_num(pcm, nan=0.0, posinf=1.0, neginf=-1.0)
    pcm = np.clip(pcm, -1.0, 1.0)
    pcm16 = (pcm * 32767).astype(np.int16)

    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm16.tobytes())
    return buffer.getvalue()


@app.post("/v1/audio/speech")
async def generate_speech(request: TTSRequest):
    """Synthesize ``request.input`` in the cloned voice named by ``request.voice``.

    The reference clip ``<voice>.wav`` is required. If a non-empty
    ``<voice>.txt`` transcript sits next to it, the transcript is passed to
    the model as ``ref_text``; otherwise the model is called with
    ``xvec_only=True``.

    ``request.model`` and ``request.response_format`` are not consulted: the
    response is always a mono 16-bit PCM WAV.

    Returns:
        A ``Response`` with media type ``audio/wav``.

    Raises:
        HTTPException: 400 for an invalid voice name, 404 for an unknown
            voice, 500 for any failure during generation or encoding.
    """
    # Validate outside the try block so 400/404 are not rewrapped as 500.
    ref_path, txt_path = _resolve_voice_paths(request.voice)

    try:
        ref_text = None
        if os.path.exists(txt_path):
            with open(txt_path, "r", encoding="utf-8") as f:
                # Treat an empty transcript the same as a missing one.
                ref_text = f.read().strip() or None

        # Fix the seed for the persona identity.
        torch.manual_seed(request.seed)

        # Non-streaming call; the original author notes it takes <1s on
        # their L40S.
        # NOTE(review): this synchronous call runs inside an async handler,
        # so it blocks the event loop while generating -- confirm intended.
        audio_data, sample_rate = model.generate_voice_clone(
            text=request.input,
            language="English",
            ref_audio=ref_path,
            ref_text=ref_text,
            xvec_only=(ref_text is None),
        )
        wav_bytes = _encode_wav_pcm16(audio_data, sample_rate)
    except Exception as e:
        print(f"Indra Mouth Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return Response(content=wav_bytes, media_type="audio/wav")
if __name__ == "__main__":
    # uvicorn is imported here because it is only needed when this file is
    # executed directly as a script.
    from uvicorn import run as serve

    # Listen on all interfaces, port 8002.
    serve(app, host="0.0.0.0", port=8002)