fred-dev commited on
Commit
6a00aab
·
verified ·
1 Parent(s): e3dd36d

Added presets

Browse files
stable_audio_tools/interface/gradio.py CHANGED
@@ -18,10 +18,60 @@ from ..models.utils import load_ckpt_state_dict
18
  from ..inference.utils import prepare_audio
19
  from ..training.utils import copy_state_dict
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  model = None
22
  sample_rate = 44100
23
  sample_size = 524288
24
 
 
25
  def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
26
  global model, sample_rate, sample_size
27
 
@@ -426,6 +476,8 @@ def create_sampling_ui(model_config):
426
  cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=4.0, label="CFG scale")
427
 
428
  with gr.Accordion("Climate and location", open=True):
 
 
429
  latitude_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "latitude"), None)
430
  if latitude_config:
431
  latitude_slider = create_conditioning_slider(
@@ -542,6 +594,20 @@ def create_sampling_ui(model_config):
542
  ],
543
  api_name="generate")
544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
 
546
  def create_txt2audio_ui(model_config):
547
  with gr.Blocks() as ui:
 
18
  from ..inference.utils import prepare_audio
19
  from ..training.utils import copy_state_dict
20
 
21
+ # Define preset values
22
+ presets = {
23
+ "Pied Currawong": {
24
+ "latitude": -33.6467,
25
+ "longitude": 150.3246,
26
+ "temperature": 12.43,
27
+ "humidity": 86,
28
+ "wind_speed": 0.66,
29
+ "pressure": 1013,
30
+ "minutes_of_day": 369,
31
+ "day_of_year": 297,
32
+ },
33
+ "Yellow-tailed Black Cockatoo": {
34
+ "latitude": -32.8334,
35
+ "longitude": 150.2001,
36
+ "temperature": 23.23,
37
+ "humidity": 45,
38
+ "wind_speed": 1.37,
39
+ "pressure": 1009,
40
+ "minutes_of_day": 986,
41
+ "day_of_year": 78,
42
+ },
43
+ "Australian Magpie": {
44
+ "latitude": -38.522,
45
+ "longitude": 145.3365,
46
+ "temperature": 18.75,
47
+ "humidity": 67,
48
+ "wind_speed": 1.5,
49
+ "pressure": 1023,
50
+ "minutes_of_day": 940,
51
+ "day_of_year": 307,
52
+ },
53
+ "Laughing Kookaburra": {
54
+ "latitude": -27.2685099,
55
+ "longitude": 152.8587437,
56
+ "temperature": 9.02,
57
+ "humidity": 94,
58
+ "wind_speed": 1.5,
59
+ "pressure": 1025,
60
+ "minutes_of_day": 320,
61
+ "day_of_year": 236,
62
+ }
63
+ }
64
+
65
+ def update_sliders(preset_name):
66
+ preset = presets[preset_name]
67
+ return (preset["latitude"], preset["longitude"], preset["temperature"], preset["humidity"], preset["wind_speed"], preset["pressure"], preset["minutes_of_day"], preset["day_of_year"])
68
+
69
+
70
  model = None
71
  sample_rate = 44100
72
  sample_size = 524288
73
 
74
+
75
  def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
76
  global model, sample_rate, sample_size
77
 
 
476
  cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=4.0, label="CFG scale")
477
 
478
  with gr.Accordion("Climate and location", open=True):
479
+ preset_dropdown = gr.Dropdown(choices=list(presets.keys()), label="Select Preset")
480
+
481
  latitude_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "latitude"), None)
482
  if latitude_config:
483
  latitude_slider = create_conditioning_slider(
 
594
  ],
595
  api_name="generate")
596
 
597
+ preset_dropdown.change(
598
+ fn=update_sliders,
599
+ inputs=[preset_dropdown],
600
+ outputs=[
601
+ latitude_slider,
602
+ longitude_slider,
603
+ temperature_slider,
604
+ humidity_slider,
605
+ wind_speed_slider,
606
+ pressure_slider,
607
+ minutes_of_day_slider,
608
+ day_of_year_slider
609
+ ]
610
+ )
611
 
612
  def create_txt2audio_ui(model_config):
613
  with gr.Blocks() as ui: