ylacombe committed
Commit cf600c8 · Parent: e45093f

Update app.py

Files changed (1):
app.py +47 -16
app.py CHANGED
@@ -47,9 +47,21 @@ text_client = InferenceClient(
 image_client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/ffe2bn2dk/")
 image_negative_prompt = "ultrarealistic, soft lighting, 8k, ugly, text, blurry"
 image_positive_prompt = ""
-image_seed = 9
+image_seed = 6
 
 processor = AutoProcessor.from_pretrained("suno/bark")
+
+def format_speaker_key(key):
+    key = key.replace("v2/", "").split("_")
+
+    return f"Speaker {key[2]} ({key[0]})"
+
+
+voice_presets = [key for key in processor.speaker_embeddings.keys() if "v2/en" in key]
+voice_presets_dict = {
+    format_speaker_key(key): key for key in voice_presets
+}
+
 model = BarkModel.from_pretrained("suno/bark", torch_dtype=torch.float16).to(device)
 sampling_rate = model.generation_config.sample_rate
 silence = np.zeros(int(0.25 * sampling_rate)) # quarter second of silence
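The hunk above turns Bark's raw preset keys into user-facing labels. A minimal standalone sketch of that mapping, assuming the standard `v2/en_speaker_N` naming exposed by `processor.speaker_embeddings` (sample keys are hardcoded here so it runs without downloading the processor):

```python
# Sketch of the mapping introduced above; the app derives the keys from
# processor.speaker_embeddings, while here they are hardcoded for illustration.
def format_speaker_key(key):
    parts = key.replace("v2/", "").split("_")  # "v2/en_speaker_6" -> ["en", "speaker", "6"]
    return f"Speaker {parts[2]} ({parts[0]})"

sample_keys = [f"v2/en_speaker_{i}" for i in range(10)]  # assumed: v2/en_speaker_0..9
voice_presets_dict = {format_speaker_key(k): k for k in sample_keys}

print(voice_presets_dict["Speaker 6 (en)"])  # -> "v2/en_speaker_6", the new default
```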
@@ -125,7 +137,7 @@ def generate_story(
     return output
 
 
-def generate_audio_and_image(story_prompt, voice_preset=voice_preset):
+def generate_audio_and_image(story_prompt, voice_preset="Speaker 6 (en)"):
 
 
    story = generate_story(story_prompt)
@@ -153,7 +165,7 @@ def generate_audio_and_image(story_prompt, voice_preset=voice_preset):
         inputs = model_input[BATCH_SIZE*i:min(BATCH_SIZE*(i+1), len(model_input))]
 
         if len(inputs) != 0:
-            inputs = processor(inputs, voice_preset=voice_preset)
+            inputs = processor(inputs, voice_preset=voice_presets_dict[voice_preset])
 
             speech_output, output_lengths = model.generate(**inputs.to(device), return_output_lengths=True, min_eos_p=0.2)
 
@@ -163,9 +175,12 @@ def generate_audio_and_image(story_prompt, voice_preset=voice_preset):
             pieces += [*speech_output, silence.copy()]
 
     print("Calling image")
-
-    # TODO: if error catch it
-    img = job_img.result()
+    try:
+        img = job_img.result()
+    except Exception as e:
+        print("Unhandled Exception: ", str(e))
+        gr.Warning("Unfortunately there was an issue when generating the image with SDXL.")
+        img = None
 
     return story, (sampling_rate, np.concatenate(pieces)), img
 
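This hunk replaces the `# TODO: if error catch it` note with a real fallback: if the SDXL Space errors, the story and audio are still returned and `img` degrades to `None`. A small sketch of the same pattern, where `job` stands in for the handle presumably created by an earlier `image_client.submit(...)` call outside this diff's context:

```python
import gradio as gr

def result_or_none(job):
    # Mirrors the committed fallback: log the error, keep the demo alive.
    try:
        return job.result()
    except Exception as e:
        print("Unhandled Exception: ", str(e))
        # gr.Warning only surfaces in the UI when called inside a Gradio event handler.
        gr.Warning("Unfortunately there was an issue when generating the image with SDXL.")
        return None
```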
 
@@ -175,16 +190,16 @@ def generate_audio_and_image(story_prompt, voice_preset=voice_preset):
 # Gradio blocks demo
 with gr.Blocks() as demo_blocks:
     gr.Markdown("""<h1 align="center">🐶Children story</h1>""")
-    gr.HTML("""<h3 style="text-align:center;">📢Audio Streaming powered by Gradio (v3.40.0 onwards)🦾! </h3>""")
+    gr.HTML("""<h3 style="text-align:center;">Let Mistral tell you a story</h3>""")
     with gr.Group():
         with gr.Row():
             inp_text = gr.Textbox(label="Story prompt", info="Enter text here")
-            #dd = gr.Dropdown(
-            #    speaker_embeddings,
-            #    value=None,
-            #    label="Available voice presets",
-            #    info="Defaults to no speaker embeddings!"
-            #    )
+            with gr.Accordion("Advanced settings", open=False):
+                voice_preset = gr.Dropdown(
+                    voice_presets_dict,
+                    value="Speaker 6 (en)",
+                    label="Available speakers",
+                )
 
 
     with gr.Row():
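Feeding `voice_presets_dict` straight to `gr.Dropdown` works because iterating a dict yields its keys, so the dropdown lists the friendly labels while the handler resolves the raw Bark key via `voice_presets_dict[voice_preset]`. A toy round-trip reusing the names from the diff (one entry stands in for all ten):

```python
# Round-trip of the dropdown wiring above: label shown to the user,
# raw preset key passed on to the Bark processor.
voice_presets_dict = {"Speaker 6 (en)": "v2/en_speaker_6"}

choices = [str(k) for k in voice_presets_dict]   # dict iteration yields keys
assert choices == ["Speaker 6 (en)"]             # what the dropdown displays

raw_key = voice_presets_dict[choices[0]]         # lookup done inside the handler
assert raw_key == "v2/en_speaker_6"              # what Bark actually receives
```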
@@ -197,8 +212,24 @@ with gr.Blocks() as demo_blocks:
         out_audio = gr.Audio(
             streaming=False, autoplay=True) # needed to stream output audio
         out_text = gr.Text()
-    btn.click(generate_audio_and_image, [inp_text], [out_text, out_audio, image_output]) #[out_audio]) #, out_count])
+    btn.click(generate_audio_and_image, [inp_text, voice_preset], [out_text, out_audio, image_output]) #[out_audio]) #, out_count])
 
-
-
+    with gr.Row():
+        gr.Examples(
+            [
+                "A panda going on an adventure with a caterpillar. This is a story teaching a wonderful life lesson.",
+                "A princess breaks free from a dragon's grip. This evokes women's empowerment and freedom.",
+                "Tell me about the wonders of the world.",
+            ],
+            [inp_text],
+            [out_text, out_audio, image_output],
+            generate_audio_and_image,
+            cache_examples=True,
+        )
+
+    gr.Markdown(
+        """
+        This Space uses **[Bark](https://huggingface.co/docs/transformers/main/en/model_doc/bark)**, [Mistral-7b-instruct](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [Fast SD-XL](https://huggingface.co/spaces/openskyml/fast-sdxl-stable-diffusion-xl)!
+        """
+    )
 demo_blocks.queue().launch(debug=True)
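With `cache_examples=True`, the three example prompts are run through `generate_audio_and_image` once when the Space builds, and clicking an example replays the stored story, audio, and image rather than re-generating them. A self-contained toy showing the same wiring, with a hypothetical `shout` function in place of the real handler:

```python
import gradio as gr

def shout(text):
    return text.upper()

with gr.Blocks() as demo:
    inp = gr.Textbox()
    out = gr.Textbox()
    # Positional args mirror the diff above: examples, inputs, outputs, fn.
    # With cache_examples=True, Gradio calls shout("hello") at build time
    # and serves the cached output whenever the example is clicked.
    gr.Examples(["hello"], [inp], [out], shout, cache_examples=True)
```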
 