import asyncio
import os

import numpy as np
import torch
import torchaudio
from livekit import rtc
from transformers import AutoTokenizer, AutoModelForCausalLM

# Fail fast if the Hugging Face token is missing; never print the secret itself.
hf_token = os.getenv("HF_AUTH_TOKEN")
if not hf_token:
    raise ValueError("HF_AUTH_TOKEN not found in environment variables.")
print("HF_AUTH_TOKEN loaded from environment.")


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # from_pretrained needs a Hub model ID or local path; the project's
        # GitHub URL (https://github.com/atharva-create/Orpheus-TTS) cannot be
        # loaded directly. The default below is an assumption -- point it at
        # the checkpoint you actually deploy.
        if not path:
            path = "canopylabs/orpheus-3b-0.1-ft"
        self.tokenizer = AutoTokenizer.from_pretrained(path, token=hf_token)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            token=hf_token,
        )
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: dict) -> list:
        # Accept either the standard "inputs" key or a plain "text" key.
        text_input = data.get("inputs") or data.get("text") or ""
        if not isinstance(text_input, str) or text_input.strip() == "":
            raise ValueError("No text input provided for TTS")
        voice = data.get("voice", "tara")

        # Orpheus-style prompts prefix the text with the speaker name.
        prompt = f"{voice}: {text_input}"

        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        generate_kwargs = {
            "max_new_tokens": 1024,
            "do_sample": True,
            "temperature": 0.8,
            "top_p": 0.95,
            "repetition_penalty": 1.1,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        with torch.no_grad():
            output_ids = self.model.generate(input_ids, **generate_kwargs)

        # Keep only the newly generated tokens, dropping the echoed prompt.
        generated_tokens = output_ids[0, input_ids.size(1):]
        output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)

        # Naive parse: keep only whitespace-separated integers so a stray
        # non-numeric token cannot crash int(). A real Orpheus pipeline would
        # extract codec IDs from the model's special audio tokens instead.
        audio_token_ids = [int(m) for m in output_text.split() if m.lstrip("-").isdigit()]

        waveform = self.generate_waveform_from_tokens(audio_token_ids)

        # Keep a local copy of the synthesized audio for debugging.
        torchaudio.save("output_audio.wav", waveform, 24000)

        lk_url = data.get("livekit_url")
        lk_token = data.get("livekit_token")
        room_name = data.get("livekit_room", "default-room")
        if not lk_url or not lk_token:
            raise ValueError("livekit_url and livekit_token are required to stream audio")

        asyncio.run(self.stream_audio(lk_url, lk_token, room_name, waveform))

        return [{"status": "success"}]

    def generate_waveform_from_tokens(self, audio_token_ids):
        """
        Convert audio tokens into a waveform (placeholder for demonstration
        only: it emits noise). Replace it with a real codec decoder; a hedged
        SNAC-based sketch follows this method.
        """
        # Noise sized proportionally to the token count, so longer outputs
        # produce longer audio.
        num_samples = len(audio_token_ids) * 100
        waveform = torch.randn(1, num_samples)
        return waveform
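
    # A possible real decoder, sketched under explicit assumptions: the model
    # emits SNAC codec tokens in the 7-tokens-per-frame layout used by the
    # upstream Orpheus-TTS project, the `snac` package is installed
    # (pip install snac), and the "hubertsiuzdak/snac_24khz" checkpoint plus
    # the 4096-per-slot ID offsets match the deployed checkpoint. Verify all
    # of these before relying on it; this is a sketch, not the project's
    # confirmed method.
    def decode_tokens_with_snac(self, audio_token_ids):
        from snac import SNAC  # assumed extra dependency

        snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
        num_frames = len(audio_token_ids) // 7
        if num_frames == 0:
            return torch.zeros(1, 0)
        layer_1, layer_2, layer_3 = [], [], []
        for i in range(num_frames):
            frame = audio_token_ids[7 * i : 7 * i + 7]
            # Redistribute the 7 interleaved tokens across SNAC's 3 codebooks,
            # subtracting the 4096 * slot offset applied at generation time.
            layer_1.append(frame[0])
            layer_2.append(frame[1] - 4096)
            layer_3.append(frame[2] - 2 * 4096)
            layer_3.append(frame[3] - 3 * 4096)
            layer_2.append(frame[4] - 4 * 4096)
            layer_3.append(frame[5] - 5 * 4096)
            layer_3.append(frame[6] - 6 * 4096)
        codes = [
            torch.tensor(layer_1).unsqueeze(0),
            torch.tensor(layer_2).unsqueeze(0),
            torch.tensor(layer_3).unsqueeze(0),
        ]
        with torch.no_grad():
            audio = snac_model.decode(codes)  # shape (1, 1, num_samples)
        return audio.squeeze(0)  # (1, num_samples), matching the placeholder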

    async def stream_audio(self, lk_url, lk_token, room_name, waveform):
        # Note: the room to join is encoded in the LiveKit access token;
        # room_name is kept for logging/reference only.
        room = rtc.Room()
        try:
            await room.connect(lk_url, lk_token, options=rtc.RoomOptions(auto_subscribe=True))
        except Exception as e:
            return f"Failed to connect to LiveKit: {e}"

        # Publish a mono 24 kHz audio track fed from a local AudioSource.
        source = rtc.AudioSource(sample_rate=24000, num_channels=1)
        track = rtc.LocalAudioTrack.create_audio_track("tts-audio", source)
        await room.local_participant.publish_track(track, rtc.TrackPublishOptions(name="TTS Audio"))

        # Convert the float waveform (assumed roughly in [-1, 1]) to 16-bit
        # PCM once, up front; a bare astype(np.int16) would truncate the
        # float samples to zero.
        pcm = (waveform.clamp(-1.0, 1.0) * 32767).to(torch.int16).numpy()

        # Stream in 50 ms frames, sleeping between frames to pace the track
        # at roughly real time.
        frame_duration = 0.05
        frame_samples = int(24000 * frame_duration)
        total_samples = pcm.shape[1]
        for start in range(0, total_samples, frame_samples):
            end = min(start + frame_samples, total_samples)
            chunk = pcm[:, start:end]
            # samples_per_channel is the chunk length; len(chunk) would be
            # the channel count of the 2-D array, not the sample count.
            samples_per_channel = chunk.shape[1]

            audio_frame = rtc.AudioFrame.create(24000, 1, samples_per_channel)
            np.copyto(np.frombuffer(audio_frame.data, dtype=np.int16), chunk.reshape(-1))
            await source.capture_frame(audio_frame)

            await asyncio.sleep(frame_duration)

        await room.disconnect()
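

# A minimal local smoke test, assuming the request schema handled above. The
# LiveKit URL and token below are placeholders, not working credentials.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler(
        {
            "inputs": "Hello from Orpheus!",
            "voice": "tara",
            "livekit_url": "wss://example.livekit.cloud",  # placeholder
            "livekit_token": "<livekit-access-token>",  # placeholder
            "livekit_room": "default-room",
        }
    )
    print(result)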