fred-dev commited on
Commit
110dfa4
·
verified ·
1 Parent(s): 64f1d20

fixed interface

Browse files
Files changed (1) hide show
  1. stable_audio_tools/interface/gradio.py +34 -102
stable_audio_tools/interface/gradio.py CHANGED
@@ -56,8 +56,6 @@ def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pr
56
  return model, model_config
57
 
58
  def generate_cond(
59
- prompt,
60
- negative_prompt=None,
61
  seconds_start=0,
62
  seconds_total=30,
63
  latitude = 0.0,
@@ -94,7 +92,6 @@ def generate_cond(
94
  torch.cuda.empty_cache()
95
  gc.collect()
96
 
97
- print(f"Prompt: {prompt}")
98
 
99
  global preview_images
100
  preview_images = []
@@ -102,12 +99,7 @@ def generate_cond(
102
  preview_every = None
103
 
104
  # Return fake stereo audio
105
- conditioning = [{"prompt": prompt, "latitude": -latitude, "longitude": longitude, "temperature": temperature, "humidity": humidity, "wind_speed": wind_speed, "pressure": pressure, "minutes_of_day": minutes_of_day,"day_of_year": day_of_year, "seconds_start":seconds_start, "seconds_total": seconds_total }] * batch_size
106
-
107
- if negative_prompt:
108
- negative_conditioning = [{"prompt": negative_prompt, "latitude": -latitude, "longitude": longitude, "temperature": temperature, "humidity": humidity, "wind_speed": wind_speed, "pressure": pressure, "minutes_of_day": minutes_of_day,"day_of_year": day_of_year, "seconds_start":seconds_start, "seconds_total": seconds_total}] * batch_size
109
- else:
110
- negative_conditioning = None
111
 
112
  #Get the device from the model
113
  device = next(model.parameters()).device
@@ -175,7 +167,6 @@ def generate_cond(
175
  audio = generate_diffusion_cond(
176
  model,
177
  conditioning=conditioning,
178
- negative_conditioning=negative_conditioning,
179
  steps=steps,
180
  cfg_scale=cfg_scale,
181
  batch_size=batch_size,
@@ -399,11 +390,9 @@ def create_conditioning_slider(min_val, max_val, label):
399
  print(f"Creating slider for {label} with min_val={min_val}, max_val={max_val}, step={step}, default_val={default_val}")
400
  return gr.Slider(minimum=min_val, maximum=max_val, step=step, value=default_val, label=label)
401
 
402
- def create_sampling_ui(model_config, inpainting=False):
403
  with gr.Row():
404
- with gr.Column(scale=6):
405
- prompt = gr.Textbox(show_label=False, placeholder="Prompt")
406
- negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt")
407
  generate_button = gr.Button("Generate", variant='primary', scale=1)
408
 
409
  model_conditioning_config = model_config["model"].get("conditioning", None)
@@ -428,13 +417,13 @@ def create_sampling_ui(model_config, inpainting=False):
428
 
429
  with gr.Row():
430
  # Steps slider
431
- steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
432
 
433
  # Preview Every slider
434
  preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every")
435
 
436
  # CFG scale
437
- cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=0.2, label="CFG scale")
438
 
439
  with gr.Accordion("Climate and location", open=True):
440
  latitude_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "latitude"), None)
@@ -505,94 +494,37 @@ def create_sampling_ui(model_config, inpainting=False):
505
  sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=80, label="Sigma max")
506
  cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.2, label="CFG rescale amount")
507
 
508
- if inpainting:
509
- # Inpainting Tab
510
- with gr.Accordion("Inpainting", open=False):
511
- sigma_max_slider.maximum=1000
512
-
513
- init_audio_checkbox = gr.Checkbox(label="Do inpainting")
514
- init_audio_input = gr.Audio(label="Init audio")
515
- init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.1, value=80, label="Init audio noise level", visible=False) # hide this
516
-
517
- mask_cropfrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Crop From %")
518
- mask_pastefrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Paste From %")
519
- mask_pasteto_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Paste To %")
520
-
521
- mask_maskstart_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=50, label="Mask Start %")
522
- mask_maskend_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Mask End %")
523
- mask_softnessL_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Left Crossfade Length %")
524
- mask_softnessR_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Right Crossfade Length %")
525
- mask_marination_slider = gr.Slider(minimum=0.0, maximum=1, step=0.0001, value=0, label="Marination level", visible=False) # still working on the usefulness of this
526
-
527
- inputs = [prompt,
528
- negative_prompt,
529
- seconds_start_slider,
530
- seconds_total_slider,
531
- latitude_slider,
532
- longitude_slider,
533
- temperature_slider,
534
- humidity_slider,
535
- wind_speed_slider,
536
- pressure_slider,
537
- minutes_of_day_slider,
538
- day_of_year_slider,
539
- cfg_scale_slider,
540
- steps_slider,
541
- preview_every_slider,
542
- seed_textbox,
543
- sampler_type_dropdown,
544
- sigma_min_slider,
545
- sigma_max_slider,
546
- cfg_rescale_slider,
547
- init_audio_checkbox,
548
- init_audio_input,
549
- init_noise_level_slider,
550
- mask_cropfrom_slider,
551
- mask_pastefrom_slider,
552
- mask_pasteto_slider,
553
- mask_maskstart_slider,
554
- mask_maskend_slider,
555
- mask_softnessL_slider,
556
- mask_softnessR_slider,
557
- mask_marination_slider
558
- ]
559
- else:
560
- # Default generation tab
561
- with gr.Accordion("Init audio", open=False):
562
- init_audio_checkbox = gr.Checkbox(label="Use init audio")
563
- init_audio_input = gr.Audio(label="Init audio")
564
- init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init noise level")
565
-
566
- inputs = [prompt,
567
- negative_prompt,
568
- seconds_start_slider,
569
- seconds_total_slider,
570
- latitude_slider,
571
- longitude_slider,
572
- temperature_slider,
573
- humidity_slider,
574
- wind_speed_slider,
575
- pressure_slider,
576
- minutes_of_day_slider,
577
- day_of_year_slider,
578
- cfg_scale_slider,
579
- steps_slider,
580
- preview_every_slider,
581
- seed_textbox,
582
- sampler_type_dropdown,
583
- sigma_min_slider,
584
- sigma_max_slider,
585
- cfg_rescale_slider,
586
- init_audio_checkbox,
587
- init_audio_input,
588
- init_noise_level_slider
589
- ]
590
 
591
  with gr.Column():
592
  audio_output = gr.Audio(label="Output audio", interactive=False)
593
  audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
594
- send_to_init_button = gr.Button("Send to init audio", scale=1)
595
- send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])
596
 
597
  generate_button.click(fn=generate_cond,
598
  inputs=inputs,
@@ -607,8 +539,8 @@ def create_txt2audio_ui(model_config):
607
  with gr.Blocks() as ui:
608
  with gr.Tab("Generation"):
609
  create_sampling_ui(model_config)
610
- with gr.Tab("Inpainting"):
611
- create_sampling_ui(model_config, inpainting=True)
612
  return ui
613
 
614
  def create_diffusion_uncond_ui(model_config):
 
56
  return model, model_config
57
 
58
  def generate_cond(
 
 
59
  seconds_start=0,
60
  seconds_total=30,
61
  latitude = 0.0,
 
92
  torch.cuda.empty_cache()
93
  gc.collect()
94
 
 
95
 
96
  global preview_images
97
  preview_images = []
 
99
  preview_every = None
100
 
101
  # Return fake stereo audio
102
+ conditioning = [{"latitude": -latitude, "longitude": longitude, "temperature": temperature, "humidity": humidity, "wind_speed": wind_speed, "pressure": pressure, "minutes_of_day": minutes_of_day,"day_of_year": day_of_year, "seconds_start":seconds_start, "seconds_total": seconds_total }] * batch_size
 
 
 
 
 
103
 
104
  #Get the device from the model
105
  device = next(model.parameters()).device
 
167
  audio = generate_diffusion_cond(
168
  model,
169
  conditioning=conditioning,
 
170
  steps=steps,
171
  cfg_scale=cfg_scale,
172
  batch_size=batch_size,
 
390
  print(f"Creating slider for {label} with min_val={min_val}, max_val={max_val}, step={step}, default_val={default_val}")
391
  return gr.Slider(minimum=min_val, maximum=max_val, step=step, value=default_val, label=label)
392
 
393
+ def create_sampling_ui(model_config):
394
  with gr.Row():
395
+
 
 
396
  generate_button = gr.Button("Generate", variant='primary', scale=1)
397
 
398
  model_conditioning_config = model_config["model"].get("conditioning", None)
 
417
 
418
  with gr.Row():
419
  # Steps slider
420
+ steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=200, label="Steps")
421
 
422
  # Preview Every slider
423
  preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every")
424
 
425
  # CFG scale
426
+ cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=4.0, label="CFG scale")
427
 
428
  with gr.Accordion("Climate and location", open=True):
429
  latitude_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "latitude"), None)
 
494
  sigma_max_slider = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=80, label="Sigma max")
495
  cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.2, label="CFG rescale amount")
496
 
497
+
498
+ # Default generation tab
499
+ with gr.Accordion("Init audio", open=False):
500
+ init_audio_input = gr.Audio(label="Init audio")
501
+ init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init noise level")
502
+
503
+ inputs = [
504
+ seconds_start_slider,
505
+ seconds_total_slider,
506
+ latitude_slider,
507
+ longitude_slider,
508
+ temperature_slider,
509
+ humidity_slider,
510
+ wind_speed_slider,
511
+ pressure_slider,
512
+ minutes_of_day_slider,
513
+ day_of_year_slider,
514
+ cfg_scale_slider,
515
+ steps_slider,
516
+ preview_every_slider,
517
+ seed_textbox,
518
+ sampler_type_dropdown,
519
+ sigma_min_slider,
520
+ sigma_max_slider,
521
+ cfg_rescale_slider,
522
+ init_noise_level_slider
523
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
 
525
  with gr.Column():
526
  audio_output = gr.Audio(label="Output audio", interactive=False)
527
  audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
 
 
528
 
529
  generate_button.click(fn=generate_cond,
530
  inputs=inputs,
 
539
  with gr.Blocks() as ui:
540
  with gr.Tab("Generation"):
541
  create_sampling_ui(model_config)
542
+ # with gr.Tab("Inpainting"):
543
+ # create_sampling_ui(model_config, inpainting=True)
544
  return ui
545
 
546
  def create_diffusion_uncond_ui(model_config):