# import spaces  # uncomment together with @spaces.GPU below when running on ZeroGPU
import gradio as gr
import torch
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, TextIteratorStreamer
import torchaudio
from threading import Thread

# Model paths and configuration
model_path_1 = "./model"   # "Think" checkpoint
model_path_2 = "./model2"  # "Think + Semantics" checkpoint
base_model_id = "Qwen/Qwen2-Audio-7B-Instruct"

# Dictionary to store loaded (model, processor) pairs, keyed by model path
loaded_models = {}

# Load the model and processor
def load_model(model_path):
    # Check if model is already loaded
    if model_path in loaded_models:
        return loaded_models[model_path]

    # Load the processor from the base model
    processor = AutoProcessor.from_pretrained(
        base_model_id,
        trust_remote_code=True,
    )

    # Load the model
    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    )
    model.eval()

    # Store in cache
    loaded_models[model_path] = (model, processor)
    return model, processor
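
# Note: load_model caches each checkpoint, so switching between "Think" and
# "Think + Semantics" in the UI keeps both models resident instead of
# reloading from disk. This assumes enough accelerator memory for two 7B
# checkpoints; otherwise device_map="auto" spreads layers across devices.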
# Initialize the first model and processor
model, processor = load_model(model_path_1)


# Re-insert line breaks around the section tags the model streams out
# (<think>, <semantic_elements>, <answer>) so each section renders as its own block
def process_output(output):
    if "<think>" in output:
        rest = output.split("<think>")[1]
        output = "<think>\n" + rest
    elif "<semantic_elements>" in output:
        rest = output.split("<semantic_elements>")[1]
        output = "<semantic_elements>\n" + rest
    elif "<answer>" in output:
        rest = output.split("<answer>")[1]
        output = "<answer>\n" + rest
    elif "</think>" in output:
        rest = output.split("</think>")[0]
        output = rest + "\n</think>\n\n"
    elif "</semantic_elements>" in output:
        rest = output.split("</semantic_elements>")[0]
        output = rest + "\n</semantic_elements>\n\n"
    elif "</answer>" in output:
        rest = output.split("</answer>")[0]
        output = rest + "\n</answer>\n"
    output = output.replace("\\n", "\n")  # unescape literal "\n" sequences
    output = output.replace("\\", "\n")   # turn stray backslashes into line breaks
    output = output.replace("\n-", "-")   # drop the line break before list dashes
    return output
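
# Illustrative only: a streamed chunk like "<think>A dog barks" becomes
# "<think>\nA dog barks", and a later chunk "fades away</think>" becomes
# "fades away\n</think>\n\n", so each tagged section ends up on its own
# lines in the accumulated transcript.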
# Streaming inference entry point used by the Gradio interface below
# @spaces.GPU
def process_audio_streaming(audio_file, model_choice):
    # Load the selected model
    model_path = model_path_1 if model_choice == "Think" else model_path_2
    model, processor = load_model(model_path)

    # Load the audio with torchaudio
    waveform, sr = torchaudio.load(audio_file)

    # Resample to 16 kHz if needed
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
        sr = 16000

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Get the audio data as a numpy array
    y = waveform.squeeze().numpy()

    # Sampling rate expected by the processor
    sampling_rate = 16000

    # Create the conversation in chat format
    conversation = [
        {"role": "user", "content": [
            {"type": "audio", "audio": y},
            {"type": "text", "text": "Describe the audio in detail."}
        ]}
    ]

    # Render the chat template into a prompt string
    chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)

    # Tokenize the prompt and featurize the audio
    inputs = processor(
        text=chat_text,
        audios=[y],
        return_tensors="pt",
        sampling_rate=sampling_rate,
    ).to(model.device)

    # Streamer that yields decoded text as soon as tokens are generated
    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    # Accumulates the text generated so far
    accumulated_output = ""

    # Run generation in a background thread so this generator can consume
    # the streamer and yield partial results to Gradio as they arrive
    with torch.no_grad():
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=768,
            do_sample=False,
        )
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

    # Yield the accumulated transcript after each streamed chunk
    for output in streamer:
        output = process_output(output)
        accumulated_output += output
        yield accumulated_output
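
# Hypothetical smoke test outside Gradio (the path is illustrative): the
# function is a generator, so the streamed transcript can be printed directly.
#
# for chunk in process_audio_streaming("examples/1.wav", "Think + Semantics"):
#     print(chunk)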

# Gradio interface for streaming audio analysis
audio_demo = gr.Interface(
    fn=process_audio_streaming,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio"),
        gr.Radio(["Think", "Think + Semantics"], label="Select Model", value="Think + Semantics"),
    ],
    outputs=gr.Textbox(label="Generated Output", lines=30),
    title="SemThink",
    description="Upload an audio file and the model will produce a detailed analysis and description. Choose between model versions.",
    examples=[["examples/1.wav", "Think + Semantics"]],
    cache_examples=False,
    live=True,  # Stream partial outputs into the textbox as they arrive
)

# Launch the app
if __name__ == "__main__":
    audio_demo.launch()
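    # Optional: audio_demo.launch(share=True) would also expose a temporary
    # public link when running outside Hugging Face Spaces.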