Spaces:
Paused
Paused
Added presets
Browse files
stable_audio_tools/interface/gradio.py
CHANGED
@@ -18,10 +18,60 @@ from ..models.utils import load_ckpt_state_dict
|
|
18 |
from ..inference.utils import prepare_audio
|
19 |
from ..training.utils import copy_state_dict
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
model = None
|
22 |
sample_rate = 44100
|
23 |
sample_size = 524288
|
24 |
|
|
|
25 |
def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
|
26 |
global model, sample_rate, sample_size
|
27 |
|
@@ -426,6 +476,8 @@ def create_sampling_ui(model_config):
|
|
426 |
cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=4.0, label="CFG scale")
|
427 |
|
428 |
with gr.Accordion("Climate and location", open=True):
|
|
|
|
|
429 |
latitude_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "latitude"), None)
|
430 |
if latitude_config:
|
431 |
latitude_slider = create_conditioning_slider(
|
@@ -542,6 +594,20 @@ def create_sampling_ui(model_config):
|
|
542 |
],
|
543 |
api_name="generate")
|
544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
|
546 |
def create_txt2audio_ui(model_config):
|
547 |
with gr.Blocks() as ui:
|
|
|
18 |
from ..inference.utils import prepare_audio
|
19 |
from ..training.utils import copy_state_dict
|
20 |
|
21 |
+
# Define preset values
|
22 |
+
presets = {
|
23 |
+
"Pied Currawong": {
|
24 |
+
"latitude": -33.6467,
|
25 |
+
"longitude": 150.3246,
|
26 |
+
"temperature": 12.43,
|
27 |
+
"humidity": 86,
|
28 |
+
"wind_speed": 0.66,
|
29 |
+
"pressure": 1013,
|
30 |
+
"minutes_of_day": 369,
|
31 |
+
"day_of_year": 297,
|
32 |
+
},
|
33 |
+
"Yellow-tailed Black Cockatoo": {
|
34 |
+
"latitude": -32.8334,
|
35 |
+
"longitude": 150.2001,
|
36 |
+
"temperature": 23.23,
|
37 |
+
"humidity": 45,
|
38 |
+
"wind_speed": 1.37,
|
39 |
+
"pressure": 1009,
|
40 |
+
"minutes_of_day": 986,
|
41 |
+
"day_of_year": 78,
|
42 |
+
},
|
43 |
+
"Australian Magpie": {
|
44 |
+
"latitude": -38.522,
|
45 |
+
"longitude": 145.3365,
|
46 |
+
"temperature": 18.75,
|
47 |
+
"humidity": 67,
|
48 |
+
"wind_speed": 1.5,
|
49 |
+
"pressure": 1023,
|
50 |
+
"minutes_of_day": 940,
|
51 |
+
"day_of_year": 307,
|
52 |
+
},
|
53 |
+
"Laughing Kookaburra": {
|
54 |
+
"latitude": -27.2685099,
|
55 |
+
"longitude": 152.8587437,
|
56 |
+
"temperature": 9.02,
|
57 |
+
"humidity": 94,
|
58 |
+
"wind_speed": 1.5,
|
59 |
+
"pressure": 1025,
|
60 |
+
"minutes_of_day": 320,
|
61 |
+
"day_of_year": 236,
|
62 |
+
}
|
63 |
+
}
|
64 |
+
|
65 |
+
def update_sliders(preset_name):
|
66 |
+
preset = presets[preset_name]
|
67 |
+
return (preset["latitude"], preset["longitude"], preset["temperature"], preset["humidity"], preset["wind_speed"], preset["pressure"], preset["minutes_of_day"], preset["day_of_year"])
|
68 |
+
|
69 |
+
|
70 |
model = None
|
71 |
sample_rate = 44100
|
72 |
sample_size = 524288
|
73 |
|
74 |
+
|
75 |
def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
|
76 |
global model, sample_rate, sample_size
|
77 |
|
|
|
476 |
cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=4.0, label="CFG scale")
|
477 |
|
478 |
with gr.Accordion("Climate and location", open=True):
|
479 |
+
preset_dropdown = gr.Dropdown(choices=list(presets.keys()), label="Select Preset")
|
480 |
+
|
481 |
latitude_config = next((item for item in model_conditioning_config["configs"] if item["id"] == "latitude"), None)
|
482 |
if latitude_config:
|
483 |
latitude_slider = create_conditioning_slider(
|
|
|
594 |
],
|
595 |
api_name="generate")
|
596 |
|
597 |
+
preset_dropdown.change(
|
598 |
+
fn=update_sliders,
|
599 |
+
inputs=[preset_dropdown],
|
600 |
+
outputs=[
|
601 |
+
latitude_slider,
|
602 |
+
longitude_slider,
|
603 |
+
temperature_slider,
|
604 |
+
humidity_slider,
|
605 |
+
wind_speed_slider,
|
606 |
+
pressure_slider,
|
607 |
+
minutes_of_day_slider,
|
608 |
+
day_of_year_slider
|
609 |
+
]
|
610 |
+
)
|
611 |
|
612 |
def create_txt2audio_ui(model_config):
|
613 |
with gr.Blocks() as ui:
|