Bomme committed
Commit af70a95 · 1 Parent(s): 993b9ab

bring back the config file

Files changed (2):
  1. app.py  +15 -18
  2. configs/inference.yml  +61 -0
app.py CHANGED
@@ -8,26 +8,14 @@ import gradio as gr
 import spaces
 import torch
 
+from NatureLM.config import Config
 from NatureLM.models.NatureLM import NatureLM
 from NatureLM.utils import generate_sample_batches, prepare_sample_waveforms
 
-CONFIG = None
+CONFIG: Config = None
 MODEL: NatureLM = None
 
 
-class DummyConfig:
-    def __init__(self):
-        self.generate = {
-            "max_new_tokens": 300,
-            "num_beams": 2,
-            "do_sample": False,
-            "min_length": 1,
-            "temperature": 0.1,
-            "repetition_penalty": 1.0,
-            "length_penalty": 1.0,
-        }
-
-
 @spaces.GPU
 def prompt_lm(audios: list[str], messages: list[dict[str, str]]):
     cuda_enabled = torch.cuda.is_available()
@@ -277,8 +265,12 @@ def _long_recording_tab():
     )
 
 
-def main(assets_dir: Path, device: str = "cuda"):
-    cfg = DummyConfig()
+def main(
+    assets_dir: Path,
+    cfg_path: str | Path,
+    device: str = "cuda",
+):
+    cfg = Config.from_sources(yaml_file=cfg_path)
     model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
     model.to(device)
     model.eval()
@@ -335,7 +327,12 @@ if __name__ == "__main__":
         default=Path(__file__).parent / "assets",
         help="Directory containing the assets (favicon, examples, etc.)",
    )
-
+    parser.add_argument(
+        "--cfg-path",
+        type=str,
+        default=Path(__file__).parent / "configs/inference.yml",
+        help="Path to the config file",
+    )
     args = parser.parse_args()
 
-    main(args.assets_dir)
+    main(args.assets_dir, args.cfg_path)
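This commit swaps the hard-coded DummyConfig for NatureLM.config.Config, loaded from the YAML file given by the new --cfg-path flag (so the Space can be pointed at another config with: python app.py --cfg-path path/to/other.yml). For orientation only, a minimal sketch of a YAML-backed, attribute-style config loader in the same spirit is shown below; SketchConfig is an illustrative stand-in, not the real NatureLM.config.Config, whose internals are not part of this diff.

# Illustrative sketch only: a minimal YAML-backed config with attribute access,
# approximating what Config.from_sources(yaml_file=...) provides to app.py.
# NOT the actual NatureLM.config.Config implementation.
from pathlib import Path

import yaml  # assumes PyYAML is available


class SketchConfig:
    """Wrap a nested dict so sections read as attributes, e.g. cfg.generate."""

    def __init__(self, data: dict):
        self._data = data

    def __getattr__(self, name: str):
        try:
            value = self._data[name]
        except KeyError as exc:
            raise AttributeError(name) from exc
        # Nested mappings are wrapped again so cfg.model.lora_rank also works.
        return SketchConfig(value) if isinstance(value, dict) else value

    @classmethod
    def from_sources(cls, yaml_file: str | Path) -> "SketchConfig":
        with open(yaml_file) as f:
            return cls(yaml.safe_load(f))


# Usage analogous to main() in app.py:
# cfg = SketchConfig.from_sources(yaml_file="configs/inference.yml")
# cfg.generate.max_new_tokens  -> 300

Keeping the generation parameters in configs/inference.yml rather than in a DummyConfig dict means they can be retuned without editing app.py.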
configs/inference.yml ADDED
@@ -0,0 +1,61 @@
+model:
+  llama_path: "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
+  freeze_beats: True
+
+  use_audio_Qformer: True
+  max_pooling: False
+  downsample_factor: 8
+  freeze_audio_QFormer: False
+  window_level_Qformer: True
+  num_audio_query_token: 1
+  second_per_window: 0.333333
+  second_stride: 0.333333
+
+  audio_llama_proj_model: ""
+  freeze_audio_llama_proj: False
+
+  lora: True
+  lora_rank: 32
+  lora_alpha: 32
+  lora_dropout: 0.1
+
+  prompt_template: "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+  max_txt_len: 160
+  end_sym: <|end_of_text|>
+
+  beats_cfg:
+    input_patch_size: 16
+    embed_dim: 512
+    conv_bias: False
+    encoder_layers: 12
+    encoder_embed_dim: 768
+    encoder_ffn_embed_dim: 3072
+    encoder_attention_heads: 12
+    activation_fn: "gelu"
+    layer_wise_gradient_decay_ratio: 0.6
+    layer_norm_first: False
+    deep_norm: True
+    dropout: 0.0
+    attention_dropout: 0.0
+    activation_dropout: 0.0
+    encoder_layerdrop: 0.05
+    dropout_input: 0.0
+    conv_pos: 128
+    conv_pos_groups: 16
+    relative_position_embedding: True
+    num_buckets: 320
+    max_distance: 800
+    gru_rel_pos: True
+    finetuned_model: True
+    predictor_dropout: 0.0
+    predictor_class: 527
+
+generate:
+  max_new_tokens: 300
+  num_beams: 2
+  do_sample: False
+  min_length: 1
+  temperature: 0.1
+  repetition_penalty: 1.0
+  length_penalty: 1.0
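The generate: block carries the same values the removed DummyConfig hard-coded (beam search with 2 beams, sampling disabled, temperature 0.1), expressed as standard generation keyword arguments. A minimal sketch of reading this section and forwarding it to a generate() call follows; the model and input wiring at the end is an assumption for illustration, since this commit only adds the config file and does not show how NatureLM consumes it downstream.

# Sketch: load configs/inference.yml and pull out the generation settings.
# The generate() call in the final comment is an assumed, Hugging Face-style
# usage, shown only to illustrate how the kwargs line up.
import yaml

with open("configs/inference.yml") as f:
    cfg = yaml.safe_load(f)

gen_kwargs = cfg["generate"]
print(gen_kwargs)
# {'max_new_tokens': 300, 'num_beams': 2, 'do_sample': False, 'min_length': 1,
#  'temperature': 0.1, 'repetition_penalty': 1.0, 'length_penalty': 1.0}

# With a causal LM and tokenized inputs this would typically be forwarded as:
# output_ids = model.generate(**inputs, **gen_kwargs)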