import torch
import torchaudio
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from livekit import rtc
import asyncio
import os

hf_token = os.getenv("HF_AUTH_TOKEN")
if not hf_token:
    raise ValueError("HF_AUTH_TOKEN not found in environment variables.")
print(f"Token fetched successfully: {hf_token[:10]}...")  # Print the first 10 chars

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the Orpheus TTS model and tokenizer from the given path
        # (a local directory or a Hugging Face Hub repository ID).
        # Note: a GitHub URL such as https://github.com/atharva-create/Orpheus-TTS
        # cannot be loaded by from_pretrained; pass a repo ID or local path instead.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if not path:
            raise ValueError("A model path or Hub repository ID must be provided.")
        self.tokenizer = AutoTokenizer.from_pretrained(path, token=hf_token)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            token=hf_token,
        )
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: dict) -> list:
        # Extract input text and optional voice and LiveKit parameters.
        text_input = data.get("inputs") or data.get("text") or ""
        if not isinstance(text_input, str) or text_input.strip() == "":
            raise ValueError("No text input provided for TTS")
        voice = data.get("voice", "tara")  # default voice (e.g., "tara")
        
        # Format prompt with voice name (Orpheus expects prompts like "voice: text").
        prompt = f"{voice}: {text_input}"
        
        # Encode prompt and generate output tokens with the TTS model.
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        generate_kwargs = {
            "max_new_tokens": 1024,               # allow sufficient tokens for audio output
            "do_sample": True,
            "temperature": 0.8,
            "top_p": 0.95,
            "repetition_penalty": 1.1,           # >=1.1 for stable speech generation
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        output_ids = self.model.generate(input_ids, **generate_kwargs)
        # The generated sequence includes the prompt; isolate newly generated tokens:
        generated_tokens = output_ids[0, input_ids.size(1):]
        output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)
        
        # Extract audio token IDs from the decoded text.
        # Placeholder logic: keep only whitespace-separated integers so the parse
        # cannot crash on non-numeric tokens; replace with the real Orpheus
        # audio-token extraction.
        audio_token_ids = [int(m) for m in output_text.split() if m.lstrip("-").isdigit()]

        # Example: convert the audio token IDs to waveform data
        waveform = self.generate_waveform_from_tokens(audio_token_ids)

        # Save or stream waveform
        torchaudio.save("output_audio.wav", waveform, 24000)  # Save as a 24 kHz audio file

        # For real-time playback, stream the audio to a LiveKit room if credentials were supplied.
        lk_url = data.get("livekit_url")
        lk_token = data.get("livekit_token")
        room_name = data.get("livekit_room", "default-room")

        if lk_url and lk_token:
            asyncio.run(self.stream_audio(lk_url, lk_token, room_name, waveform))

        return [{"status": "success"}]

    def generate_waveform_from_tokens(self, audio_token_ids):
        """
        Convert audio tokens into a waveform (this part is for demonstration).
        You should implement a proper method to decode tokens to actual audio.
        """
        # Simulate a waveform from the token count with random samples clamped to
        # [-1, 1] so the downstream 16-bit PCM conversion behaves sensibly.
        # Replace this with the model's actual audio decoding.
        num_samples = len(audio_token_ids) * 100  # rough sample-count estimate from the tokens
        waveform = torch.randn(1, num_samples).clamp(-1.0, 1.0)  # simulated audio waveform
        return waveform

    async def stream_audio(self, lk_url, lk_token, room_name, waveform):
        room = rtc.Room()
        try:
            await room.connect(lk_url, lk_token, options=rtc.RoomOptions(auto_subscribe=True))
        except Exception as e:
            return f"Failed to connect to LiveKit: {e}"

        # Create an audio track for streaming the TTS output
        source = rtc.AudioSource(sample_rate=24000, num_channels=1)
        track = rtc.LocalAudioTrack.create_audio_track("tts-audio", source)
        await room.local_participant.publish_track(
            track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
        )

        # Stream the waveform in short chunks for real-time playback.
        frame_duration = 0.05  # 50 ms per frame
        frame_samples = int(24000 * frame_duration)  # samples per 50 ms at a 24 kHz sample rate
        total_samples = waveform.size(1)
        for start in range(0, total_samples, frame_samples):
            end = min(start + frame_samples, total_samples)
            # Scale the float waveform (expected in [-1, 1]) to 16-bit PCM samples.
            chunk = (waveform[0, start:end].clamp(-1.0, 1.0) * 32767).to(torch.int16).numpy()

            # Create an AudioFrame sized to this chunk and copy the samples into it.
            audio_frame = rtc.AudioFrame.create(24000, 1, len(chunk))
            np.frombuffer(audio_frame.data, dtype=np.int16)[: len(chunk)] = chunk
            await source.capture_frame(audio_frame)

            # Sleep to maintain real-time pace (synchronize with frame duration)
            await asyncio.sleep(frame_duration)

        # Disconnect from the room after streaming is finished
        await room.disconnect()
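

# ---------------------------------------------------------------------------
# Local smoke test (illustrative sketch only). The model path, LiveKit URL,
# access token, and room name below are placeholders, not real values; supply
# your own before running. It simply exercises the request payload shape that
# __call__ expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path="<hub-repo-id-or-local-path>")  # placeholder path
    response = handler({
        "inputs": "Hello from Orpheus TTS over LiveKit.",
        "voice": "tara",
        "livekit_url": "wss://your-livekit-server.example.com",  # placeholder URL
        "livekit_token": "<livekit-access-token>",                # placeholder token
        "livekit_room": "tts-demo",
    })
    print(response)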