Athspi committed
Commit 7a742e9 · verified · 1 Parent(s): 5eaee65

Update app.py

Files changed (1):
  1. app.py +72 -65
app.py CHANGED
@@ -8,29 +8,28 @@ from dotenv import load_dotenv
 # Load environment variables
 load_dotenv()
 
+# Set number of threads (adjust based on your CPU cores)
+torch.set_num_threads(4)
+
 # Device and torch dtype selection
 device = "cuda" if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
 
-# Define a no-op decorator for CPU if needed
+# No-op decorator for CPU mode (if you had GPU-specific decorators)
 def gpu_decorator(func):
     return func
 
-# If you are on GPU and have the spaces module, you could replace gpu_decorator with spaces.GPU
-# For CPU usage we simply use a no-op
-# Example: from snac import spaces; gpu_decorator = spaces.GPU()
-
 # Import SNAC after setting device
 from snac import SNAC
 
 print("Loading SNAC model...")
 snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
 snac_model = snac_model.to(device)
-snac_model.eval() # set SNAC to eval mode
+snac_model.eval() # Set SNAC to eval mode
 
 model_name = "canopylabs/orpheus-3b-0.1-ft"
 
-# Download only model config and safetensors files
+# Download only necessary files for the Orpheus model
 snapshot_download(
     repo_id=model_name,
     allow_patterns=[
@@ -55,23 +54,30 @@ snapshot_download(
 print("Loading Orpheus model...")
 model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype)
 model.to(device)
-model.eval() # set Orpheus to eval mode
+model.eval() # Set the model to evaluation mode
+
+# Optionally compile the model for PyTorch 2.0+ on CPU (if available)
+if hasattr(torch, "compile") and device == "cpu":
+    try:
+        model = torch.compile(model)
+        print("Model compiled with torch.compile")
+    except Exception as e:
+        print("torch.compile not supported:", e)
+
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 print(f"Orpheus model loaded to {device}")
 
-# Process text prompt into tokens with start/end markers
 def process_prompt(prompt, voice, tokenizer, device):
     prompt = f"{voice}: {prompt}"
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
 
-    start_token = torch.tensor([[128259]], dtype=torch.int64) # Start token
-    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64) # End tokens
+    start_token = torch.tensor([[128259]], dtype=torch.int64)
+    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
 
     modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
     attention_mask = torch.ones_like(modified_input_ids)
     return modified_input_ids.to(device), attention_mask.to(device)
 
-# Parse output tokens to extract audio codes
 def parse_output(generated_ids):
     token_to_find = 128257
    token_to_remove = 128258
@@ -96,9 +102,8 @@ def parse_output(generated_ids):
         trimmed_row = [t - 128266 for t in trimmed_row]
         code_lists.append(trimmed_row)
 
-    return code_lists[0] # Return first sample
+    return code_lists[0]
 
-# Redistribute codes for audio generation using SNAC
 def redistribute_codes(code_list, snac_model):
     snac_device = next(snac_model.parameters()).device
     layer_1, layer_2, layer_3 = [], [], []
@@ -116,21 +121,17 @@ def redistribute_codes(code_list, snac_model):
         torch.tensor(layer_2, device=snac_device).unsqueeze(0),
         torch.tensor(layer_3, device=snac_device).unsqueeze(0)
     ]
-
     audio_hat = snac_model.decode(codes)
     return audio_hat.detach().squeeze().cpu().numpy()
 
-# Main generation function with CPU optimizations
 @gpu_decorator
 def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
     if not text.strip():
         return None
-
     try:
-        progress(0.1, "Processing text...")
+        progress(0.05, "Processing text...")
         input_ids, attention_mask = process_prompt(text, voice, tokenizer, device)
-
-        progress(0.3, "Generating speech tokens...")
+        progress(0.2, "Generating tokens...")
         with torch.inference_mode():
             generated_ids = model.generate(
                 input_ids=input_ids,
@@ -143,71 +144,73 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
                 num_return_sequences=1,
                 eos_token_id=128258,
             )
-
-        progress(0.6, "Processing speech tokens...")
+        progress(0.4, "Parsing tokens...")
         code_list = parse_output(generated_ids)
-
-        progress(0.8, "Converting tokens to audio...")
+        progress(0.7, "Generating audio...")
         audio_samples = redistribute_codes(code_list, snac_model)
-
-        return (24000, audio_samples) # Return sample rate and numpy array audio
+        progress(1.0, "Done")
+        return (24000, audio_samples)
     except Exception as e:
         print(f"Error generating speech: {e}")
         return None
 
-# Example inputs for the Gradio UI
+def convert_model_to_onnx():
+    """
+    Converts the Orpheus model to ONNX format using a dummy prompt.
+    The exported file will be saved as 'orpheus_model.onnx' in the working directory.
+    """
+    dummy_prompt = "tara: Hello"
+    dummy_input = tokenizer(dummy_prompt, return_tensors="pt").input_ids.to(device)
+    file_path = "orpheus_model.onnx"
+    try:
+        # Export the model to ONNX format
+        torch.onnx.export(
+            model,
+            dummy_input,
+            file_path,
+            export_params=True,
+            opset_version=14,
+            input_names=["input_ids"],
+            output_names=["logits"],
+            dynamic_axes={
+                "input_ids": {0: "batch_size", 1: "sequence_length"},
+                "logits": {0: "batch_size", 1: "sequence_length"}
+            },
+        )
+        return f"Model converted to ONNX and saved as '{file_path}'."
+    except Exception as e:
+        return f"Error during ONNX conversion: {e}"
+
+# UI examples and voice choices
 examples = [
     ["Hey there my name is Tara, <chuckle> and I'm a speech generation model that can sound like a person.", "tara", 0.6, 0.95, 1.1, 1200],
     ["I've also been taught to understand and produce paralinguistic things like sighing, or chuckling, or yawning!", "dan", 0.7, 0.95, 1.1, 1200],
     ["I live in San Francisco, and have, uhm let's see, 3 billion 7 hundred ... well, let's just say a lot of parameters.", "emma", 0.6, 0.9, 1.2, 1200]
 ]
-
 VOICES = ["tara", "dan", "josh", "emma"]
 
-# Create Gradio interface
 with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
     gr.Markdown("""
     # 🎵 Orpheus Text-to-Speech
-    Enter your text below and hear it converted to natural-sounding speech.
+    Enter text to hear it converted to natural-sounding speech.
 
-    **Tips for better prompts:**
-    - Include paralinguistic elements like `<chuckle>`, `<sigh>`, or `uhm` for more human-like speech.
-    - Longer prompts often produce more natural results.
-    - Adjust the temperature slider to control variation in speech patterns.
+    **Tips:**
+    - Use paralinguistic cues like `<chuckle>` or `<sigh>`.
+    - Longer text can produce more natural results.
     """)
     with gr.Row():
         with gr.Column(scale=3):
-            text_input = gr.Textbox(
-                label="Text to speak",
-                placeholder="Enter your text here...",
-                lines=5
-            )
-            voice = gr.Dropdown(
-                choices=VOICES,
-                value="tara",
-                label="Voice"
-            )
+            text_input = gr.Textbox(label="Text to speak", placeholder="Enter your text...", lines=5)
+            voice = gr.Dropdown(choices=VOICES, value="tara", label="Voice")
             with gr.Accordion("Advanced Settings", open=False):
-                temperature = gr.Slider(
-                    minimum=0.1, maximum=1.5, value=0.6, step=0.05,
-                    label="Temperature",
-                    info="Higher values (0.7-1.0) create more expressive but less stable speech"
-                )
-                top_p = gr.Slider(
-                    minimum=0.1, maximum=1.0, value=0.95, step=0.05,
-                    label="Top P",
-                    info="Nucleus sampling threshold"
-                )
-                repetition_penalty = gr.Slider(
-                    minimum=1.0, maximum=2.0, value=1.1, step=0.05,
-                    label="Repetition Penalty",
-                    info="Higher values discourage repetitive patterns"
-                )
-                max_new_tokens = gr.Slider(
-                    minimum=100, maximum=2000, value=1200, step=100,
-                    label="Max Length",
-                    info="Maximum length of generated audio (in tokens)"
-                )
+                temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.6, step=0.05, label="Temperature",
+                                        info="Higher values produce more varied speech")
+                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top P",
+                                  info="Nucleus sampling threshold")
+                repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty",
+                                               info="Discourage repetition")
+                max_new_tokens = gr.Slider(minimum=100, maximum=2000, value=1200, step=100, label="Max Length",
+                                           info="Maximum generated tokens")
            with gr.Row():
                submit_btn = gr.Button("Generate Speech", variant="primary")
                clear_btn = gr.Button("Clear")
@@ -232,7 +235,11 @@ with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
         inputs=[],
         outputs=[text_input, audio_output]
     )
+
+    gr.Markdown("## ONNX Conversion")
+    onnx_btn = gr.Button("Convert Model to ONNX")
+    onnx_output = gr.Textbox(label="Conversion Output")
+    onnx_btn.click(fn=convert_model_to_onnx, inputs=[], outputs=onnx_output)
 
-# Launch the Gradio app
 if __name__ == "__main__":
     demo.queue().launch(share=False, ssr_mode=False)
 
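A note on the redistribute_codes interface above: snac_model.decode takes a list of three code tensors laid out coarse to fine. A minimal shape-level sketch, assuming the 1:2:4 layer ratio and 4096-entry codebooks of hubertsiuzdak/snac_24khz (the dummy values are illustrative; only the decode call itself appears in the diff):

# Shape-level sketch of SNAC decoding; random codes produce noise,
# this only demonstrates the expected tensor layout (an assumption).
import torch
from snac import SNAC

snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()

n = 8  # number of coarse frames (arbitrary for the sketch)
codes = [
    torch.randint(0, 4096, (1, n)),      # layer 1 (coarsest)
    torch.randint(0, 4096, (1, 2 * n)),  # layer 2
    torch.randint(0, 4096, (1, 4 * n)),  # layer 3 (finest)
]
with torch.inference_mode():
    audio = snac_model.decode(codes)  # waveform tensor at 24 kHz
print(audio.squeeze().shape)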
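The new convert_model_to_onnx exports a single forward pass (input_ids in, logits out), not the full generate loop, so the artifact is a building block rather than a drop-in TTS model. A minimal sanity-check sketch, assuming onnxruntime is installed (the InferenceSession and run calls are standard onnxruntime API; nothing below is part of the commit):

# Load the exported graph and run one forward pass on a voice-prefixed prompt.
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("canopylabs/orpheus-3b-0.1-ft")
session = ort.InferenceSession("orpheus_model.onnx")

input_ids = tokenizer("tara: Hello", return_tensors="np").input_ids.astype(np.int64)
(logits,) = session.run(["logits"], {"input_ids": input_ids})
print(logits.shape)  # one logit vector per input position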