bluenevus committed on
Commit
c594756
·
verified ·
1 Parent(s): 1f13bd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -215
app.py CHANGED
@@ -1,50 +1,132 @@
1
- import spaces
 
 
 
 
 
2
  from snac import SNAC
3
  import torch
4
- import gradio as gr
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
- from huggingface_hub import snapshot_download
7
  import google.generativeai as genai
8
  import re
9
  import logging
10
  import numpy as np
11
  from pydub import AudioSegment
12
- import io
13
  from docx import Document
14
  import PyPDF2
15
 
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
 
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
 
21
  print("Loading SNAC model...")
22
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
23
  snac_model = snac_model.to(device)
24
 
25
  model_name = "canopylabs/orpheus-3b-0.1-ft"
26
-
27
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
28
  model.to(device)
29
  tokenizer = AutoTokenizer.from_pretrained(model_name)
30
  print(f"Orpheus model loaded to {device}")
31
 
32
- # Available voices
33
  VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
34
-
35
- # Available Emotive Tags
36
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
37
 
38
- @spaces.GPU()
39
- def generate_podcast_script(api_key, host1_name, host2_name, podcast_name, podcast_topic, prompt, uploaded_file, duration, num_hosts):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  try:
 
 
 
 
 
41
  genai.configure(api_key=api_key)
42
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
43
 
44
  combined_content = prompt or ""
45
 
46
- if uploaded_file is not None:
47
- file_bytes = io.BytesIO(uploaded_file)
 
 
48
 
49
  # Try to detect the file type based on content
50
  file_bytes.seek(0)
@@ -105,99 +187,26 @@ def generate_podcast_script(api_key, host1_name, host2_name, podcast_name, podca
105
  return re.sub(r'[^a-zA-Z0-9\s.,?!<>]', '', response.text)
106
  except Exception as e:
107
  logger.error(f"Error generating podcast script: {str(e)}")
108
- raise
109
-
110
def process_prompt(prompt, voice, tokenizer, device):
    """Build the model input tensors for one piece of text spoken by ``voice``.

    The text is prefixed with ``"<voice>: "``, tokenized, and then wrapped
    with the fixed special token ids used by this file: ``128259`` in front
    and ``128009, 128260`` behind.

    Args:
        prompt: Text to be spoken.
        voice: Voice name placed in front of the text.
        tokenizer: Callable returning an object whose ``input_ids`` attribute
            is a 2-D tensor when called with ``return_tensors="pt"``.
        device: Torch device the returned tensors are moved to.

    Returns:
        Tuple ``(input_ids, attention_mask)``, both on ``device``; the mask
        is all ones and has the same shape as ``input_ids``.
    """
    voiced_text = f"{voice}: {prompt}"
    text_ids = tokenizer(voiced_text, return_tensors="pt").input_ids

    prefix = torch.tensor([[128259]], dtype=torch.int64)
    suffix = torch.tensor([[128009, 128260]], dtype=torch.int64)

    wrapped_ids = torch.cat([prefix, text_ids, suffix], dim=1)
    mask = torch.ones_like(wrapped_ids)

    return wrapped_ids.to(device), mask.to(device)
121
-
122
def parse_output(generated_ids):
    """Extract the audio code list for the first sequence of a generation.

    Steps, in order:
      1. Drop everything up to and including the last occurrence of token
         ``128257`` (if that token is absent the tensor is used as is).
      2. Remove every ``128258`` token from the first row.
      3. Truncate the row to a multiple of 7 tokens.
      4. Subtract ``128266`` from every remaining token.

    Only the first row is ever returned, so only that row is post-processed,
    and the arithmetic is done on the whole tensor at once instead of one
    element at a time.

    Args:
        generated_ids: 2-D integer tensor of generated token ids
            (rows are sequences).

    Returns:
        List of plain Python ints whose length is a multiple of 7.

    Raises:
        IndexError: If ``generated_ids`` has no rows.
    """
    token_to_find = 128257
    token_to_remove = 128258
    code_offset = 128266

    # Column indices of every marker token, in row-major order.
    marker_columns = (generated_ids == token_to_find).nonzero(as_tuple=True)[1]

    if marker_columns.numel() > 0:
        last_occurrence_idx = marker_columns[-1].item()
        cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
    else:
        cropped_tensor = generated_ids

    first_row = cropped_tensor[0]
    first_row = first_row[first_row != token_to_remove]

    # Codes are consumed in groups of 7; discard any incomplete tail group.
    usable_length = (first_row.size(0) // 7) * 7
    return (first_row[:usable_length] - code_offset).tolist()
148
-
149
def redistribute_codes(code_list, snac_model):
    """Split a flat code list into three layers and decode it to audio.

    The flat list is read in frames of 7 codes. Within each frame, position
    ``k`` carries an offset of ``k * 4096`` that is subtracted, and the codes
    are routed as follows: position 0 -> layer 1, positions 1 and 4 ->
    layer 2, positions 2, 3, 5 and 6 -> layer 3.

    The frame count is ``len(code_list) // 7``. The previous
    ``(len + 1) // 7`` over-counted by one frame when six trailing codes were
    present and then indexed past the end of the list.

    Args:
        code_list: Flat sequence of integer codes (ints or 0-d tensors).
        snac_model: Model exposing ``parameters()`` and ``decode(codes)``.

    Returns:
        The decoded audio as a NumPy array on the CPU, with singleton
        dimensions squeezed out.

    Raises:
        ValueError: If ``code_list`` does not contain one complete frame.
    """
    device = next(snac_model.parameters()).device  # Decode on the model's device

    frame_count = len(code_list) // 7
    if frame_count == 0:
        raise ValueError("code_list must contain at least one complete 7-code frame")

    layer_1 = []
    layer_2 = []
    layer_3 = []
    for frame_index in range(frame_count):
        frame = code_list[7 * frame_index:7 * frame_index + 7]
        layer_1.append(frame[0])
        layer_2.append(frame[1] - 4096)
        layer_3.append(frame[2] - (2 * 4096))
        layer_3.append(frame[3] - (3 * 4096))
        layer_2.append(frame[4] - (4 * 4096))
        layer_3.append(frame[5] - (5 * 4096))
        layer_3.append(frame[6] - (6 * 4096))

    codes = [
        torch.tensor(layer, device=device).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]

    audio_hat = snac_model.decode(codes)
    return audio_hat.detach().squeeze().cpu().numpy()  # Always return CPU numpy array
172
-
173
def detect_silence(audio, threshold=0.005, min_silence_duration=1.3, sample_rate=24000):
    """Find long runs of near-silent samples in an audio signal.

    A sample is silent when its absolute value is below ``threshold``.
    Consecutive silent samples form a run; a run is reported when
    ``(end - start) / sample_rate`` is at least ``min_silence_duration``.

    Run boundaries are located with vectorized NumPy operations instead of a
    Python loop over every silent sample.

    Args:
        audio: 1-D array-like of samples.
        threshold: Absolute amplitude below which a sample is silent.
        min_silence_duration: Minimum run length, in seconds, to report.
        sample_rate: Samples per second used to convert run lengths to
            seconds. Defaults to the 24000 previously hard-coded here.

    Returns:
        List of ``(start, end)`` sample-index tuples (both inclusive), in
        ascending order. Empty when no run is long enough.
    """
    silent_idx = np.nonzero(np.abs(np.asarray(audio)) < threshold)[0]
    if silent_idx.size == 0:
        return []

    # A gap larger than 1 between neighbouring silent indices ends one run
    # and starts the next.
    breaks = np.nonzero(np.diff(silent_idx) > 1)[0]
    run_starts = np.concatenate(([silent_idx[0]], silent_idx[breaks + 1]))
    run_ends = np.concatenate((silent_idx[breaks], [silent_idx[-1]]))

    return [
        (int(start), int(end))
        for start, end in zip(run_starts, run_ends)
        if (end - start) / sample_rate >= min_silence_duration
    ]
193
 
194
- @spaces.GPU()
195
- def generate_speech(text, voice1, voice2, temperature, top_p, repetition_penalty, max_new_tokens, num_hosts, progress=gr.Progress()):
196
- if not text.strip():
197
- return None
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  try:
200
- progress(0.1, "Processing text...")
201
  paragraphs = text.split('\n\n') # Split by double newline
202
  audio_samples = []
203
 
@@ -209,7 +218,6 @@ def generate_speech(text, voice1, voice2, temperature, top_p, repetition_penalty
209
 
210
  input_ids, attention_mask = process_prompt(paragraph, voice, tokenizer, device)
211
 
212
- progress(0.3, f"Generating speech tokens for paragraph {i+1}...")
213
  with torch.no_grad():
214
  generated_ids = model.generate(
215
  input_ids,
@@ -223,130 +231,50 @@ def generate_speech(text, voice1, voice2, temperature, top_p, repetition_penalty
223
  eos_token_id=128258,
224
  )
225
 
226
- progress(0.6, f"Processing speech tokens for paragraph {i+1}...")
227
  code_list = parse_output(generated_ids)
228
-
229
- progress(0.8, f"Converting paragraph {i+1} to audio...")
230
  paragraph_audio = redistribute_codes(code_list, snac_model)
231
 
232
- # Add silence detection here
233
  silences = detect_silence(paragraph_audio)
234
  if silences:
235
- # Trim the audio at the last detected silence
236
  paragraph_audio = paragraph_audio[:silences[-1][1]]
237
 
238
  audio_samples.append(paragraph_audio)
239
 
240
  final_audio = np.concatenate(audio_samples)
241
-
242
- # Normalize the audio
243
  final_audio = np.int16(final_audio / np.max(np.abs(final_audio)) * 32767)
244
-
245
- return (24000, final_audio)
 
 
 
 
246
  except Exception as e:
247
- print(f"Error generating speech: {e}")
248
- return None
249
-
250
- with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
251
- with gr.Row():
252
- def get_field_value(field, default=""):
253
- return field.value if field.value and not field.value.isspace() else default
254
- with gr.Column(scale=1):
255
- gemini_api_key = gr.Textbox(label="Gemini API Key", type="password")
256
- host1_name = gr.Textbox(label="Name of Podcast Host 1", placeholder="Enter name of first host")
257
- host2_name = gr.Textbox(label="Name of Podcast Host 2", placeholder="Enter name of second host")
258
- podcast_name = gr.Textbox(label="Name of Podcast", placeholder="Enter podcast name")
259
- podcast_topic = gr.Textbox(label="Podcast Topic", placeholder="Enter podcast topic")
260
- prompt = gr.Textbox(
261
- label="Prompt",
262
- placeholder="Enter your text here...",
263
- lines=5,
264
- max_lines=30,
265
- show_label=True,
266
- interactive=True,
267
- container=True
268
- )
269
-
270
- with gr.Column(scale=2):
271
- uploaded_file = gr.File(label="Upload File", type="binary")
272
- duration = gr.Slider(minimum=1, maximum=60, value=5, step=1, label="Duration (minutes)")
273
- num_hosts = gr.Radio(["1", "2"], label="Number of Hosts", value="1")
274
- script_output = gr.Textbox(label="Generated Script", lines=10)
275
- generate_script_btn = gr.Button("Generate Podcast Script") # Add this line
276
- generate_script_btn.click(
277
- fn=generate_podcast_script,
278
- inputs=[
279
- gemini_api_key,
280
- host1_name,
281
- host2_name,
282
- podcast_name,
283
- podcast_topic,
284
- prompt,
285
- uploaded_file,
286
- duration,
287
- num_hosts
288
- ],
289
- outputs=script_output
290
- )
291
 
292
- with gr.Column(scale=2):
293
- voice1 = gr.Dropdown(
294
- choices=VOICES,
295
- value="tara",
296
- label="Voice 1",
297
- info="Select the first voice for speech generation"
298
- )
299
- voice2 = gr.Dropdown(
300
- choices=VOICES,
301
- value="zac",
302
- label="Voice 2",
303
- info="Select the second voice for speech generation"
304
- )
305
 
306
- with gr.Accordion("Advanced Settings", open=False):
307
- temperature = gr.Slider(
308
- minimum=0.1, maximum=1.5, value=0.6, step=0.05,
309
- label="Temperature",
310
- info="Higher values (0.7-1.0) create more expressive but less stable speech"
311
- )
312
- top_p = gr.Slider(
313
- minimum=0.1, maximum=1.0, value=0.9, step=0.05,
314
- label="Top P",
315
- info="Higher values produce more diverse outputs"
316
- )
317
- repetition_penalty = gr.Slider(
318
- minimum=1.0, maximum=2.0, value=1.2, step=0.1,
319
- label="Repetition Penalty",
320
- info="Higher values discourage repetitive patterns"
321
- )
322
- max_new_tokens = gr.Slider(
323
- minimum=100, maximum=16384, value=4096, step=100,
324
- label="Max Length",
325
- info="Maximum length of generated audio (in tokens)"
326
- )
327
-
328
- audio_output = gr.Audio(label="Generated Audio", type="numpy")
329
- with gr.Row():
330
- submit_btn = gr.Button("Generate Audio", variant="primary")
331
- clear_btn = gr.Button("Clear")
332
-
333
- generate_script_btn.click(
334
- fn=generate_podcast_script,
335
- inputs=[gemini_api_key, prompt, uploaded_file, duration, num_hosts],
336
- outputs=script_output
337
- )
338
-
339
- submit_btn.click(
340
- fn=generate_speech,
341
- inputs=[script_output, voice1, voice2, temperature, top_p, repetition_penalty, max_new_tokens, num_hosts],
342
- outputs=audio_output
343
- )
344
-
345
- clear_btn.click(
346
- fn=lambda: (None, None, None),
347
- inputs=[],
348
- outputs=[prompt, script_output, audio_output]
349
- )
350
 
351
- if __name__ == "__main__":
352
- demo.queue().launch(share=False, ssr_mode=False)
 
 
 
 
1
+ import dash
2
+ from dash import dcc, html, Input, Output, State, callback
3
+ import dash_bootstrap_components as dbc
4
+ import base64
5
+ import io
6
+ import os
7
  from snac import SNAC
8
  import torch
 
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
10
  import google.generativeai as genai
11
  import re
12
  import logging
13
  import numpy as np
14
  from pydub import AudioSegment
 
15
  from docx import Document
16
  import PyPDF2
17
 
18
+ # Initialize logging
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
22
+ # Initialize device
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
 
25
+ # Load models
26
  print("Loading SNAC model...")
27
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
28
  snac_model = snac_model.to(device)
29
 
30
  model_name = "canopylabs/orpheus-3b-0.1-ft"
 
31
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
32
  model.to(device)
33
  tokenizer = AutoTokenizer.from_pretrained(model_name)
34
  print(f"Orpheus model loaded to {device}")
35
 
36
+ # Available voices and emotive tags
37
  VOICES = ["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"]
 
 
38
  EMOTIVE_TAGS = ["`<laugh>`", "`<chuckle>`", "`<sigh>`", "`<cough>`", "`<sniffle>`", "`<groan>`", "`<yawn>`", "`<gasp>`"]
39
 
40
+ # Initialize Dash app
41
+ app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
42
+
43
+ # Layout
44
+ app.layout = dbc.Container([
45
+ dbc.Row([
46
+ dbc.Col([
47
+ html.H1("Orpheus Text-to-Speech", className="mb-4"),
48
+ dbc.Input(id="host1-name", placeholder="Enter name of first host", className="mb-2"),
49
+ dbc.Input(id="host2-name", placeholder="Enter name of second host", className="mb-2"),
50
+ dbc.Input(id="podcast-name", placeholder="Enter podcast name", className="mb-2"),
51
+ dbc.Input(id="podcast-topic", placeholder="Enter podcast topic", className="mb-2"),
52
+ dbc.Textarea(id="prompt", placeholder="Enter your text here...", rows=5, className="mb-2"),
53
+ dcc.Upload(
54
+ id='upload-file',
55
+ children=html.Div(['Drag and Drop or ', html.A('Select a File')]),
56
+ style={
57
+ 'width': '100%',
58
+ 'height': '60px',
59
+ 'lineHeight': '60px',
60
+ 'borderWidth': '1px',
61
+ 'borderStyle': 'dashed',
62
+ 'borderRadius': '5px',
63
+ 'textAlign': 'center',
64
+ 'margin': '10px 0'
65
+ },
66
+ ),
67
+ dcc.Slider(id="duration", min=1, max=60, value=5, step=1, marks={1: '1', 30: '30', 60: '60'}, className="mb-2"),
68
+ dbc.RadioItems(
69
+ id="num-hosts",
70
+ options=[{"label": i, "value": i} for i in ["1", "2"]],
71
+ value="1",
72
+ inline=True,
73
+ className="mb-2"
74
+ ),
75
+ dbc.Button("Generate Podcast Script", id="generate-script-btn", color="primary", className="mb-2"),
76
+ ], width=6),
77
+ dbc.Col([
78
+ dbc.Textarea(id="script-output", placeholder="Generated script will appear here...", rows=10, className="mb-2"),
79
+ dcc.Dropdown(id="voice1", options=[{"label": v, "value": v} for v in VOICES], value="tara", className="mb-2"),
80
+ dcc.Dropdown(id="voice2", options=[{"label": v, "value": v} for v in VOICES], value="zac", className="mb-2"),
81
+ dbc.Button("Generate Audio", id="generate-audio-btn", color="success", className="mb-2"),
82
+ html.Div(id="audio-output"),
83
+ dbc.Button("Clear", id="clear-btn", color="secondary", className="mb-2"),
84
+ dbc.Collapse([
85
+ dcc.Slider(id="temperature", min=0.1, max=1.5, value=0.6, step=0.05, marks={0.1: '0.1', 0.8: '0.8', 1.5: '1.5'}, className="mb-2"),
86
+ dcc.Slider(id="top-p", min=0.1, max=1.0, value=0.9, step=0.05, marks={0.1: '0.1', 0.5: '0.5', 1.0: '1.0'}, className="mb-2"),
87
+ dcc.Slider(id="repetition-penalty", min=1.0, max=2.0, value=1.2, step=0.1, marks={1.0: '1.0', 1.5: '1.5', 2.0: '2.0'}, className="mb-2"),
88
+ dcc.Slider(id="max-new-tokens", min=100, max=16384, value=4096, step=100, marks={100: '100', 8192: '8192', 16384: '16384'}, className="mb-2"),
89
+ ], id="advanced-settings", is_open=False),
90
+ dbc.Button("Advanced Settings", id="advanced-settings-toggle", color="info", className="mb-2"),
91
+ ], width=6),
92
+ ]),
93
+ dcc.Store(id='generated-script'),
94
+ dcc.Store(id='generated-audio'),
95
+ ])
96
+
97
+ # Callbacks
98
+ @callback(
99
+ Output("script-output", "value"),
100
+ Input("generate-script-btn", "n_clicks"),
101
+ State("host1-name", "value"),
102
+ State("host2-name", "value"),
103
+ State("podcast-name", "value"),
104
+ State("podcast-topic", "value"),
105
+ State("prompt", "value"),
106
+ State("upload-file", "contents"),
107
+ State("duration", "value"),
108
+ State("num-hosts", "value"),
109
+ prevent_initial_call=True
110
+ )
111
+ def generate_podcast_script(n_clicks, host1_name, host2_name, podcast_name, podcast_topic, prompt, uploaded_file, duration, num_hosts):
112
+ if n_clicks is None:
113
+ return ""
114
+
115
  try:
116
+ # Get the Gemini API key from Hugging Face secrets
117
+ api_key = os.environ.get("GEMINI_API_KEY")
118
+ if not api_key:
119
+ raise ValueError("Gemini API key not found in environment variables")
120
+
121
  genai.configure(api_key=api_key)
122
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
123
 
124
  combined_content = prompt or ""
125
 
126
+ if uploaded_file:
127
+ content_type, content_string = uploaded_file.split(',')
128
+ decoded = base64.b64decode(content_string)
129
+ file_bytes = io.BytesIO(decoded)
130
 
131
  # Try to detect the file type based on content
132
  file_bytes.seek(0)
 
187
  return re.sub(r'[^a-zA-Z0-9\s.,?!<>]', '', response.text)
188
  except Exception as e:
189
  logger.error(f"Error generating podcast script: {str(e)}")
190
+ return f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
+ @callback(
193
+ Output("audio-output", "children"),
194
+ Input("generate-audio-btn", "n_clicks"),
195
+ State("script-output", "value"),
196
+ State("voice1", "value"),
197
+ State("voice2", "value"),
198
+ State("temperature", "value"),
199
+ State("top-p", "value"),
200
+ State("repetition-penalty", "value"),
201
+ State("max-new-tokens", "value"),
202
+ State("num-hosts", "value"),
203
+ prevent_initial_call=True
204
+ )
205
+ def generate_speech(n_clicks, text, voice1, voice2, temperature, top_p, repetition_penalty, max_new_tokens, num_hosts):
206
+ if n_clicks is None or not text.strip():
207
+ return html.Div("No audio generated yet.")
208
 
209
  try:
 
210
  paragraphs = text.split('\n\n') # Split by double newline
211
  audio_samples = []
212
 
 
218
 
219
  input_ids, attention_mask = process_prompt(paragraph, voice, tokenizer, device)
220
 
 
221
  with torch.no_grad():
222
  generated_ids = model.generate(
223
  input_ids,
 
231
  eos_token_id=128258,
232
  )
233
 
 
234
  code_list = parse_output(generated_ids)
 
 
235
  paragraph_audio = redistribute_codes(code_list, snac_model)
236
 
 
237
  silences = detect_silence(paragraph_audio)
238
  if silences:
 
239
  paragraph_audio = paragraph_audio[:silences[-1][1]]
240
 
241
  audio_samples.append(paragraph_audio)
242
 
243
  final_audio = np.concatenate(audio_samples)
 
 
244
  final_audio = np.int16(final_audio / np.max(np.abs(final_audio)) * 32767)
245
+
246
+ # Convert to base64 for audio playback
247
+ audio_base64 = base64.b64encode(final_audio.tobytes()).decode('utf-8')
248
+ src = f"data:audio/wav;base64,{audio_base64}"
249
+
250
+ return html.Audio(src=src, controls=True)
251
  except Exception as e:
252
+ logger.error(f"Error generating speech: {str(e)}")
253
+ return html.Div(f"Error generating audio: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
@callback(
    Output("advanced-settings", "is_open"),
    Input("advanced-settings-toggle", "n_clicks"),
    State("advanced-settings", "is_open"),
)
def toggle_advanced_settings(n_clicks, is_open):
    """Flip the open state of the advanced-settings panel on a button click.

    Args:
        n_clicks: Click count of the toggle button (falsy before any click).
        is_open: Current open state of the collapse component.

    Returns:
        The inverted state after a click, otherwise the state unchanged.
    """
    return (not is_open) if n_clicks else is_open
 
 
 
 
264
 
265
@callback(
    Output("prompt", "value"),
    # "script-output.value" and "audio-output.children" are also written by
    # the script- and audio-generation callbacks. Dash rejects a second
    # callback targeting the same output unless it is explicitly marked as a
    # duplicate, which in turn requires prevent_initial_call.
    Output("script-output", "value", allow_duplicate=True),
    Output("audio-output", "children", allow_duplicate=True),
    Input("clear-btn", "n_clicks"),
    prevent_initial_call=True,
)
def clear_outputs(n_clicks):
    """Reset the prompt, the generated script and the audio area.

    Args:
        n_clicks: Click count of the "Clear" button (falsy before any click).

    Returns:
        A 3-tuple for (prompt value, script value, audio children): empty
        strings and a placeholder ``html.Div`` after a click, otherwise
        ``dash.no_update`` for every output.
    """
    if not n_clicks:
        return dash.no_update, dash.no_update, dash.no_update
    return "", "", html.Div("No audio generated yet.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
# Run the app
if __name__ == '__main__':
    print("Starting the Dash application...")
    # The server listens on all interfaces, so the interactive debugger
    # (which lets a client run arbitrary code) must not be on by default.
    # Debug mode is opt-in via the DASH_DEBUG environment variable.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860)
    print("Dash application has finished running.")