import torch
import torchaudio
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from livekit import rtc
import asyncio
import os
hf_token = os.getenv("HF_AUTH_TOKEN")
if not hf_token:
raise ValueError("HF_AUTH_TOKEN not found in environment variables.")
print(f"Token fetched successfully: {hf_token[:10]}...") # Print the first 10 chars
class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the Orpheus TTS model and tokenizer from the given path (Hub repository).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # from_pretrained() expects a Hub repo id or a local directory, not a GitHub URL.
        # Fall back to the public Orpheus release if no path is supplied (adjust to your deployment).
        if not path:
            path = "canopylabs/orpheus-3b-0.1-ft"
        self.tokenizer = AutoTokenizer.from_pretrained(path, token=hf_token)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            token=hf_token,
        )
        self.model.to(self.device)
        self.model.eval()
    def __call__(self, data: dict) -> list:
        # Extract input text and optional voice and LiveKit parameters.
        text_input = data.get("inputs") or data.get("text") or ""
        if not isinstance(text_input, str) or text_input.strip() == "":
            raise ValueError("No text input provided for TTS")
        voice = data.get("voice", "tara")  # default voice (e.g., "tara")
        # Format prompt with voice name (Orpheus expects prompts like "voice: text").
        prompt = f"{voice}: {text_input}"
        # Encode prompt and generate output tokens with the TTS model.
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        generate_kwargs = {
            "max_new_tokens": 1024,  # allow sufficient tokens for audio output
            "do_sample": True,
            "temperature": 0.8,
            "top_p": 0.95,
            "repetition_penalty": 1.1,  # >=1.1 for stable speech generation
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        output_ids = self.model.generate(input_ids, **generate_kwargs)
        # The generated sequence includes the prompt; isolate newly generated tokens:
        generated_tokens = output_ids[0, input_ids.size(1):]
        output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)
        # Extract audio token IDs from the decoded text. This is still a placeholder:
        # it only keeps whitespace-separated integers so it cannot crash on other
        # tokens. Replace it with the real Orpheus token-extraction logic.
        audio_token_ids = [int(tok) for tok in output_text.split() if tok.lstrip("-").isdigit()]
        # Convert the audio token IDs to waveform data (placeholder implementation below).
        waveform = self.generate_waveform_from_tokens(audio_token_ids)
        # Save the waveform as a 24 kHz WAV file.
        torchaudio.save("output_audio.wav", waveform, 24000)
        # For real-time streaming, use LiveKit to stream the audio if credentials are provided.
        lk_url = data.get("livekit_url")
        lk_token = data.get("livekit_token")
        room_name = data.get("livekit_room", "default-room")
        if lk_url and lk_token:
            asyncio.run(self.stream_audio(lk_url, lk_token, room_name, waveform))
        return [{"status": "success"}]
    def generate_waveform_from_tokens(self, audio_token_ids):
        """
        Convert audio tokens into a waveform (demonstration only).
        Replace this with a proper decoder that turns Orpheus audio tokens into
        real audio; see the hedged SNAC-based sketch in _decode_with_snac below.
        """
        # Simulate a waveform by generating random data sized from the token count.
        # Replace this logic with actual audio generation.
        num_samples = len(audio_token_ids) * 100  # rough estimate of samples per token
        waveform = torch.randn(1, num_samples)  # placeholder random audio waveform
        return waveform
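
    # --- Hedged sketch: decoding Orpheus audio tokens with SNAC ----------------
    # The public Orpheus-TTS reference implementation decodes its generated audio
    # tokens with the SNAC neural codec ("hubertsiuzdak/snac_24khz"). The method
    # below is an illustrative sketch under that assumption; the frame layout and
    # the per-position offset of 4096 must be verified against the Orpheus release
    # you actually deploy. Requires `pip install snac`.
    def _decode_with_snac(self, audio_token_ids):
        from snac import SNAC  # assumed dependency, not used by the placeholder above
        snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(self.device)
        # Assumption: tokens arrive in frames of 7 codec indices, distributed over
        # the three SNAC codebook layers with an offset of 4096 per frame position.
        frames = [audio_token_ids[i:i + 7] for i in range(0, len(audio_token_ids) - 6, 7)]
        layer_1, layer_2, layer_3 = [], [], []
        for f in frames:
            layer_1.append(f[0])
            layer_2.append(f[1] - 4096)
            layer_3.append(f[2] - 2 * 4096)
            layer_3.append(f[3] - 3 * 4096)
            layer_2.append(f[4] - 4 * 4096)
            layer_3.append(f[5] - 5 * 4096)
            layer_3.append(f[6] - 6 * 4096)
        codes = [
            torch.tensor(layer_1, device=self.device).unsqueeze(0),
            torch.tensor(layer_2, device=self.device).unsqueeze(0),
            torch.tensor(layer_3, device=self.device).unsqueeze(0),
        ]
        with torch.inference_mode():
            audio = snac_model.decode(codes)  # shape: (1, 1, num_samples) at 24 kHz
        return audio.squeeze(0).cpu()  # (1, num_samples), ready for torchaudio.save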
    async def stream_audio(self, lk_url, lk_token, room_name, waveform):
        room = rtc.Room()
        try:
            await room.connect(lk_url, lk_token, options=rtc.RoomOptions(auto_subscribe=True))
        except Exception as e:
            return f"Failed to connect to LiveKit: {e}"
        # Create an audio track for streaming the TTS output.
        source = rtc.AudioSource(sample_rate=24000, num_channels=1)
        track = rtc.LocalAudioTrack.create_audio_track("tts-audio", source)
        await room.local_participant.publish_track(
            track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
        )
        # Convert the float waveform (expected in [-1, 1]) to mono 16-bit PCM samples.
        pcm = (waveform[0].clamp(-1.0, 1.0).numpy() * 32767).astype(np.int16)
        # Stream the PCM data in chunks for real-time playback.
        frame_duration = 0.05  # 50 ms per frame
        frame_samples = int(24000 * frame_duration)  # 50 ms of audio at 24 kHz sample rate
        total_samples = pcm.shape[0]
        for start in range(0, total_samples, frame_samples):
            end = min(start + frame_samples, total_samples)
            chunk = pcm[start:end]
            # Create an AudioFrame sized to this chunk and copy the samples into its buffer.
            audio_frame = rtc.AudioFrame.create(24000, 1, len(chunk))
            np.copyto(np.frombuffer(audio_frame.data, dtype=np.int16), chunk)
            await source.capture_frame(audio_frame)
            # Sleep to maintain real-time pace (synchronize with frame duration).
            await asyncio.sleep(frame_duration)
        # Disconnect from the room after streaming is finished.
        await room.disconnect()
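
# --- Usage sketch (hedged) ----------------------------------------------------
# Illustrative local invocation only: it assumes the model weights are reachable,
# that HF_AUTH_TOKEN is set, and that the placeholder LiveKit URL/token below are
# replaced with real credentials. None of these values ship with this repository.
if __name__ == "__main__":
    handler = EndpointHandler()  # downloads the model on first run
    payload = {
        "inputs": "Hello from Orpheus, streaming over LiveKit.",
        "voice": "tara",
        "livekit_url": "wss://your-livekit-host",       # placeholder
        "livekit_token": "<participant-access-token>",  # placeholder
        "livekit_room": "tts-demo",
    }
    print(handler(payload))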