# SemThink / app.py
# Author: Gijs Wijngaard
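"""Gradio app that streams detailed audio descriptions from two local
Qwen2-Audio fine-tunes ("Think" and "Think + Semantics"). The models emit
<think>, <semantic_elements>, and <answer> sections, which process_output()
reflows for display as tokens stream in."""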
import spaces
import torch
import torchaudio
import gradio as gr
from threading import Thread
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, TextIteratorStreamer
# Model paths and configuration
model_path_1 = "./model"
model_path_2 = "./model2"
base_model_id = "Qwen/Qwen2-Audio-7B-Instruct"
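# model_path_1 backs the "Think" option in the UI below and model_path_2
# backs "Think + Semantics" (see process_audio_streaming).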
# Dictionary to store loaded models and processors
loaded_models = {}

# Load the model and processor, caching the pair for reuse
def load_model(model_path):
    # Return the cached pair if this model is already loaded
    if model_path in loaded_models:
        return loaded_models[model_path]

    # Load the processor from the base model
    processor = AutoProcessor.from_pretrained(
        base_model_id,
        trust_remote_code=True,
    )

    # Load the model
    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    )
    model.eval()

    # Store in cache
    loaded_models[model_path] = (model, processor)
    return model, processor
# Initialize first model and processor
model, processor = load_model(model_path_1)
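
# process_output() reformats each streamed chunk so the model's tagged
# sections start and end on their own lines. The branches assume a single
# chunk contains at most one tag, which holds for token-level streaming.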
def process_output(output):
    if "<think>" in output:
        rest = output.split("<think>")[1]
        output = "<think>\n" + rest
    elif "<semantic_elements>" in output:
        rest = output.split("<semantic_elements>")[1]
        output = "<semantic_elements>\n" + rest
    elif "<answer>" in output:
        rest = output.split("<answer>")[1]
        output = "<answer>\n" + rest
    elif "</think>" in output:
        rest = output.split("</think>")[0]
        output = rest + "\n</think>\n\n"
    elif "</semantic_elements>" in output:
        rest = output.split("</semantic_elements>")[0]
        output = rest + "\n</semantic_elements>\n\n"
    elif "</answer>" in output:
        rest = output.split("</answer>")[0]
        output = rest + "\n</answer>\n"
    # Turn escaped "\n" sequences and stray backslashes into real newlines,
    # then re-attach hyphens that would otherwise start a new line
    output = output.replace("\\n", "\n")
    output = output.replace("\\", "\n")
    output = output.replace("\n-", "-")
    return output
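
# Illustrative examples of the reflow above (not executed):
#   process_output("<think>rain")   -> "<think>\nrain"
#   process_output("rain</think>")  -> "rain\n</think>\n\n"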

# Process an uploaded audio file and stream the model's description
@spaces.GPU
def process_audio_streaming(audio_file, model_choice):
    # Load the selected model
    model_path = model_path_1 if model_choice == "Think" else model_path_2
    model, processor = load_model(model_path)
    # Load the audio with torchaudio
    waveform, sr = torchaudio.load(audio_file)

    # Resample to 16 kHz if needed
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
        sr = 16000

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Get the audio data as a 1-D numpy array for the processor
    y = waveform.squeeze().numpy()
    sampling_rate = 16000
    # Create conversation format
    conversation = [
        {"role": "user", "content": [
            {"type": "audio", "audio": y},
            {"type": "text", "text": "Describe the audio in detail."}
        ]}
    ]

    # Format the chat
    chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    # Process the inputs (tokenize the text and extract audio features)
    inputs = processor(
        text=chat_text,
        audios=[y],
        return_tensors="pt",
        sampling_rate=sampling_rate,
    ).to(model.device)
    # Create a streamer instance so tokens can be read as they are generated
    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    # Initialize an empty string to store the generated text
    accumulated_output = ""
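    # Note: TextIteratorStreamer raises queue.Empty if generate() produces no
    # new text within `timeout` seconds; 10 s is assumed to be enough for the
    # first tokens on a warm GPU.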
    # Generate the output with streaming: run generate() on a background
    # thread and consume the streamer on this one
    with torch.no_grad():
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=768,
            do_sample=False,
        )
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

    # Yield the accumulated output as each chunk arrives
    for output in streamer:
        output = process_output(output)
        accumulated_output += output
        yield accumulated_output

# Create the Gradio interface for audio processing
audio_demo = gr.Interface(
    fn=process_audio_streaming,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio"),
        gr.Radio(["Think", "Think + Semantics"], label="Select Model", value="Think + Semantics")
    ],
    outputs=gr.Textbox(label="Generated Output", lines=30),
    title="SemThink",
    description="Upload an audio file and the model will provide a detailed analysis and description. Choose between the two model versions.",
    examples=[["examples/1.wav", "Think + Semantics"]],
    cache_examples=False,
    live=True,  # Enable live updates
)
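
# With live=True, Gradio re-runs the function whenever an input changes;
# because process_audio_streaming is a generator, each yield refreshes the
# textbox, giving the streaming effect.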

# Launch the app
if __name__ == "__main__":
    audio_demo.launch()