Gijs Wijngaard committed
Commit b770eaa · 1 Parent(s): 5ee12ec

Files changed (2)
  1. app.py +96 -40
  2. examples/1.wav +0 -0
app.py CHANGED
@@ -1,11 +1,13 @@
-import spaces
+# import spaces
 import os
 import re
 import gradio as gr
 import torch
 import librosa
 import numpy as np
-from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
+from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, TextIteratorStreamer
+import torchaudio
+from threading import Thread
 
 # Model path and configuration
 model_path = "./model"
@@ -21,7 +23,7 @@ def load_model():
 
     # Load the base model
     model = Qwen2AudioForConditionalGeneration.from_pretrained(
-        base_model_id,
+        model_path,
         torch_dtype=torch.bfloat16,
         trust_remote_code=True,
         device_map="auto",
@@ -57,21 +59,78 @@ def extract_components(text):
 
     return thinking, semantic, answer
 
+# Function to handle chat messages
+def chat(message, history):
+    chat = []
+    for item in history:
+        chat.append({"role": "user", "content": item[0]})
+        if item[1] is not None:
+            chat.append({"role": "assistant", "content": item[1]})
+    chat.append({"role": "user", "content": message})
+    messages = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+    # Tokenize the messages string
+    model_inputs = processor([messages], return_tensors="pt").to(model.device)
+    streamer = TextIteratorStreamer(
+        processor.tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=1024,
+        do_sample=True,
+        top_p=0.95,
+        top_k=1000,
+        temperature=0.75,
+        num_beams=1,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    # Initialize an empty string to store the generated text
+    partial_text = ""
+    for new_text in streamer:
+        # print(new_text)
+        partial_text += new_text
+        # Yield an empty string to cleanup the message textbox and the updated conversation history
+        yield partial_text
+
 
+def process_output(output):
+    if "<think>" in output:
+        rest = output.split("<think>")[1]
+        output = "<think>\n" + rest
+    elif "<semantic_elements>" in output:
+        rest = output.split("<semantic_elements>")[1]
+        output = "<semantic_elements>\n" + rest
+    elif "<answer>" in output:
+        rest = output.split("<answer>")[1]
+        output = "<answer>\n" + rest
+    elif "</think>" in output:
+        rest = output.split("</think>")[0]
+        output = rest + "\n</think>\n"
+    elif "</semantic_elements>" in output:
+        rest = output.split("</semantic_elements>")[0]
+        output = rest + "\n</semantic_elements>\n"
+    elif "</answer>" in output:
+        rest = output.split("</answer>")[0]
+        output = rest + "\n</answer>\n"
+    return output
 
-@spaces.GPU
-def process_audio(audio_file):
-    # Load and process the audio with librosa
-    y, sr = librosa.load(audio_file, sr=None) # Load audio file
+# Keep only the process_audio_streaming function that's actually used in the Gradio interface
+def process_audio_streaming(audio_file):
+    # Load and process the audio with torchaudio
+    waveform, sr = torchaudio.load(audio_file)
 
     # Resample to 16kHz if needed
     if sr != 16000:
-        y = librosa.resample(y, orig_sr=sr, target_sr=16000)
+        waveform = torchaudio.functional.resample(waveform, sr, 16000)
         sr = 16000
 
     # Convert to mono if stereo
-    if len(y.shape) > 1 and y.shape[1] > 1:
-        y = librosa.to_mono(y)
+    if waveform.shape[0] > 1:
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+    # Get the audio data as numpy array
+    y = waveform.squeeze().numpy()
 
     # Set sampling rate for the processor
     sampling_rate = 16000
@@ -95,45 +154,42 @@ def process_audio(audio_file):
         sampling_rate=sampling_rate,
     ).to(model.device)
 
-    # Generate the output
+    # Create a streamer instance
+    streamer = TextIteratorStreamer(
+        processor.tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+
+    # Initialize an empty string to store the generated text
+    accumulated_output = ""
+
+    # Generate the output with streaming
     with torch.no_grad():
-        outputs = model.generate(
+        generate_kwargs = dict(
             **inputs,
+            streamer=streamer,
             max_new_tokens=768,
             do_sample=False,
         )
-
-    # Decode the output
-    generated_text = processor.tokenizer.decode(outputs[0], skip_special_tokens=False)
-    assistant_text = generated_text.split("\nassistant\n")[1]
-
-    # Extract sections from the response
-    # Add newlines before XML tags if they exist
-    if "<think>" in assistant_text:
-        assistant_text = assistant_text.replace("<think>", "\n<think>")
-
-    if "<semantic_elements>" in assistant_text:
-        assistant_text = assistant_text.replace("<semantic_elements>", "\n<semantic_elements>")
-
-    if "<answer>" in assistant_text:
-        assistant_text = assistant_text.replace("<answer>", "\n<answer>")
-
-
-    # Combine all components into a single output
-
-    return assistant_text
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        # Yield the final outputs
+        for output in streamer:
+            output = process_output(output)
+            accumulated_output += output # Append new output to the accumulated string
+            yield accumulated_output # Yield the accumulated output
 
-# Create Gradio interface
-demo = gr.Interface(
-    fn=process_audio,
+# Create Gradio interface for audio processing
+audio_demo = gr.Interface(
+    fn=process_audio_streaming,
     inputs=gr.Audio(type="filepath", label="Upload Audio"),
-    outputs=gr.Textbox(label="Analysis Result", lines=20),
-    title="Qwen2Audio Audio Description Demo",
+    outputs=gr.Textbox(label="Generated Output", lines=24),
+    title="SemThink",
     description="Upload an audio file and the model will provide detailed analysis and description.",
-    examples=[], # Add example files here if available
+    examples=["examples/1.wav"], # Add example files here if available
     cache_examples=False,
+    live=True # Enable live updates
 )
 
-# Launch the app
+# Launch the apps
 if __name__ == "__main__":
-    demo.launch()
+    audio_demo.launch()
examples/1.wav ADDED
Binary file (163 kB)
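For reference, the substantive change in app.py is the switch from a single blocking model.generate() call to token-level streaming: a TextIteratorStreamer is attached to generate(), generation runs in a background Thread, and the Gradio handler yields the accumulated text as chunks arrive, with audio preprocessing moved from librosa to torchaudio. The sketch below is a minimal, self-contained illustration of that pattern, not code from the repository; the "./model" path mirrors app.py, while the prompt handling, the processor call, and the stream_description name are illustrative assumptions.

# Minimal sketch of the streaming pattern this commit adopts: torchaudio
# preprocessing plus a TextIteratorStreamer fed by model.generate() running
# in a background thread. Names marked as assumptions are not repository code.
from threading import Thread

import torch
import torchaudio
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, TextIteratorStreamer

model_path = "./model"  # local checkpoint path, as in app.py
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
model = Qwen2AudioForConditionalGeneration.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)

def stream_description(audio_file, prompt_text):
    # Load, resample to 16 kHz, and downmix to mono (mirrors process_audio_streaming)
    waveform, sr = torchaudio.load(audio_file)
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    audio = waveform.squeeze().numpy()

    # prompt_text is assumed to already follow the model's chat template,
    # including the audio placeholder token the processor expects.
    inputs = processor(
        text=prompt_text, audios=[audio], sampling_rate=16000, return_tensors="pt"
    ).to(model.device)

    # The streamer yields decoded text chunks while generate() is still running.
    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    thread = Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=768, do_sample=False),
    )
    thread.start()

    accumulated = ""
    for chunk in streamer:
        accumulated += chunk
        yield accumulated  # a Gradio generator fn yields partial output like this

Passing a generator like this as fn to gr.Interface is what lets the output Textbox update incrementally; the live=True flag added in the commit additionally re-runs the function whenever the input changes.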