Gijs Wijngaard committed
Commit · b770eaa
Parent(s): 5ee12ec

Finished

Files changed:
- app.py +96 -40
- examples/1.wav +0 -0

app.py CHANGED
@@ -1,11 +1,13 @@
-import spaces
+# import spaces
import os
import re
import gradio as gr
import torch
import librosa
import numpy as np
-from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
+from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, TextIteratorStreamer
+import torchaudio
+from threading import Thread

# Model path and configuration
model_path = "./model"
@@ -21,7 +23,7 @@ def load_model():

    # Load the base model
    model = Qwen2AudioForConditionalGeneration.from_pretrained(
-
+        model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
@@ -57,21 +59,78 @@ def extract_components(text):

    return thinking, semantic, answer

+# Function to handle chat messages
+def chat(message, history):
+    chat = []
+    for item in history:
+        chat.append({"role": "user", "content": item[0]})
+        if item[1] is not None:
+            chat.append({"role": "assistant", "content": item[1]})
+    chat.append({"role": "user", "content": message})
+    messages = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+    # Tokenize the messages string
+    model_inputs = processor([messages], return_tensors="pt").to(model.device)
+    streamer = TextIteratorStreamer(
+        processor.tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=1024,
+        do_sample=True,
+        top_p=0.95,
+        top_k=1000,
+        temperature=0.75,
+        num_beams=1,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    # Initialize an empty string to store the generated text
+    partial_text = ""
+    for new_text in streamer:
+        # print(new_text)
+        partial_text += new_text
+        # Yield an empty string to cleanup the message textbox and the updated conversation history
+        yield partial_text
+

+def process_output(output):
+    if "<think>" in output:
+        rest = output.split("<think>")[1]
+        output = "<think>\n" + rest
+    elif "<semantic_elements>" in output:
+        rest = output.split("<semantic_elements>")[1]
+        output = "<semantic_elements>\n" + rest
+    elif "<answer>" in output:
+        rest = output.split("<answer>")[1]
+        output = "<answer>\n" + rest
+    elif "</think>" in output:
+        rest = output.split("</think>")[0]
+        output = rest + "\n</think>\n"
+    elif "</semantic_elements>" in output:
+        rest = output.split("</semantic_elements>")[0]
+        output = rest + "\n</semantic_elements>\n"
+    elif "</answer>" in output:
+        rest = output.split("</answer>")[0]
+        output = rest + "\n</answer>\n"
+    return output

-
-def process_audio(audio_file):
-    # Load and process the audio with
-
+# Keep only the process_audio_streaming function that's actually used in the Gradio interface
+def process_audio_streaming(audio_file):
+    # Load and process the audio with torchaudio
+    waveform, sr = torchaudio.load(audio_file)

    # Resample to 16kHz if needed
    if sr != 16000:
-
+        waveform = torchaudio.functional.resample(waveform, sr, 16000)
        sr = 16000

    # Convert to mono if stereo
-    if
-
+    if waveform.shape[0] > 1:
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+    # Get the audio data as numpy array
+    y = waveform.squeeze().numpy()

    # Set sampling rate for the processor
    sampling_rate = 16000
@@ -95,45 +154,42 @@ def process_audio(audio_file):
        sampling_rate=sampling_rate,
    ).to(model.device)

-    #
+    # Create a streamer instance
+    streamer = TextIteratorStreamer(
+        processor.tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+
+    # Initialize an empty string to store the generated text
+    accumulated_output = ""
+
+    # Generate the output with streaming
    with torch.no_grad():
-
+        generate_kwargs = dict(
            **inputs,
+            streamer=streamer,
            max_new_tokens=768,
            do_sample=False,
        )
-
-
-
-
-
-
-
-
-        assistant_text = assistant_text.replace("<think>", "\n<think>")
-
-    if "<semantic_elements>" in assistant_text:
-        assistant_text = assistant_text.replace("<semantic_elements>", "\n<semantic_elements>")
-
-    if "<answer>" in assistant_text:
-        assistant_text = assistant_text.replace("<answer>", "\n<answer>")
-
-
-    # Combine all components into a single output
-
-    return assistant_text
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        # Yield the final outputs
+        for output in streamer:
+            output = process_output(output)
+            accumulated_output += output # Append new output to the accumulated string
+            yield accumulated_output # Yield the accumulated output

-# Create Gradio interface
-
-    fn=
+# Create Gradio interface for audio processing
+audio_demo = gr.Interface(
+    fn=process_audio_streaming,
    inputs=gr.Audio(type="filepath", label="Upload Audio"),
-    outputs=gr.Textbox(label="
-    title="
+    outputs=gr.Textbox(label="Generated Output", lines=24),
+    title="SemThink",
    description="Upload an audio file and the model will provide detailed analysis and description.",
-    examples=[], # Add example files here if available
+    examples=["examples/1.wav"], # Add example files here if available
    cache_examples=False,
+    live=True # Enable live updates
)

-# Launch the
+# Launch the apps
if __name__ == "__main__":
-
+    audio_demo.launch()
examples/1.wav ADDED
Binary file (163 kB).
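
For reference, the preprocessing that process_audio_streaming now applies to every upload (load with torchaudio, resample to 16 kHz, downmix to mono, convert to a NumPy array) can be run standalone against this example file. This is only an illustrative sketch of those steps, not part of the commit:

# Illustrative sketch of the torchaudio preprocessing introduced in process_audio_streaming.
import torch
import torchaudio

waveform, sr = torchaudio.load("examples/1.wav")  # shape: (channels, samples)

# Resample to the 16 kHz rate the processor expects
if sr != 16000:
    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
    sr = 16000

# Downmix stereo to mono by averaging channels
if waveform.shape[0] > 1:
    waveform = torch.mean(waveform, dim=0, keepdim=True)

# 1-D NumPy array passed to the processor as raw audio
y = waveform.squeeze().numpy()
print(y.shape, sr)  # shape depends on the file
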
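On the Gradio side, process_audio_streaming being a generator is what makes the textbox update incrementally: gr.Interface re-renders the output on every yield. A toy interface with the same wiring, where slow_echo is made up purely for illustration:

# Toy Gradio interface mirroring the streaming wiring; slow_echo is illustrative only.
import time

import gradio as gr

def slow_echo(text):
    partial = ""
    for ch in text:
        time.sleep(0.05)
        partial += ch
        yield partial  # each yield re-renders the output textbox

demo = gr.Interface(
    fn=slow_echo,
    inputs=gr.Textbox(label="Input"),
    outputs=gr.Textbox(label="Streamed Output", lines=4),
    live=True,
)

if __name__ == "__main__":
    demo.launch()  # older Gradio 3.x releases may need demo.queue() for generator outputs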