SemThink / app.py
Gijs Wijngaard
Add eval
04a31f5
raw
history blame
3.53 kB
import os
import re
import gradio as gr
import torch
from transformers import AutoProcessor
from qwen import Qwen2AudioForConditionalGeneration
from peft import PeftModel, PeftConfig
# Model path and configuration.
# Both values can be overridden with environment variables (MODEL_PATH,
# BASE_MODEL_ID) so the app can point at another adapter directory or base
# checkpoint without editing this file; the defaults are the original values.
model_path = os.environ.get("MODEL_PATH", "./model")  # LoRA adapter directory
base_model_id = os.environ.get("BASE_MODEL_ID", "Qwen/Qwen2-Audio-7B-Instruct")
# Load the model and processor
def load_model(base_id=None, adapter_path=None):
    """Load the processor, the base Qwen2-Audio model and the LoRA adapter.

    Args:
        base_id: Hub id of the base model. ``None`` (the default) uses the
            module-level ``base_model_id``.
        adapter_path: Directory of the PEFT/LoRA adapter. ``None`` (the
            default) uses the module-level ``model_path``.

    Returns:
        ``(model, processor)``: the adapter-wrapped model switched to eval
        mode, and the processor loaded from the base model.
    """
    # Resolve defaults at call time so the module constants are read lazily.
    base_id = base_model_id if base_id is None else base_id
    adapter_path = model_path if adapter_path is None else adapter_path

    # The processor comes from the base model; the adapter only changes weights.
    processor = AutoProcessor.from_pretrained(
        base_id,
        trust_remote_code=True,
    )
    base_model = Qwen2AudioForConditionalGeneration.from_pretrained(
        base_id,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    )
    # Wrap the base model with the LoRA adapter weights.
    model = PeftModel.from_pretrained(base_model, adapter_path)
    model.eval()
    return model, processor


# Initialize model and processor once at import time; every request reuses them.
model, processor = load_model()
# Function to extract components from model output
# Opening tag of any section the model is asked to produce; used to bound a
# section whose closing tag is missing and to detect tag-free replies.
_ANY_OPEN_TAG_RE = re.compile(r"<(?:think|semantic_elements|answer)>")


def _extract_section(text, tag):
    """Return the stripped contents of ``<tag>...</tag>`` in ``text``, or ``""``.

    If the opening tag is present but its closing tag is not (e.g. generation
    stopped at ``max_new_tokens``), the section runs up to the next known
    opening tag or to the end of ``text`` instead of being dropped.
    """
    escaped = re.escape(tag)
    closed = re.search(rf"<{escaped}>(.*?)</{escaped}>", text, re.DOTALL)
    if closed:
        return closed.group(1).strip()

    opening = f"<{tag}>"
    start = text.find(opening)
    if start == -1:
        return ""
    rest = text[start + len(opening):]
    next_tag = _ANY_OPEN_TAG_RE.search(rest)
    if next_tag:
        rest = rest[:next_tag.start()]
    return rest.strip()


def extract_components(text):
    """Split a model reply into its thinking, semantic-elements and answer parts.

    Args:
        text: The assistant's decoded reply.

    Returns:
        ``(thinking, semantic, answer)`` strings taken from the ``<think>``,
        ``<semantic_elements>`` and ``<answer>`` sections; a missing section
        yields ``""``. If the reply contains none of these tags, the whole
        stripped reply is returned as ``answer`` so it is still visible.
    """
    thinking = _extract_section(text, "think")
    semantic = _extract_section(text, "semantic_elements")
    answer = _extract_section(text, "answer")
    if not _ANY_OPEN_TAG_RE.search(text):
        # The model ignored the tag format; show its raw reply rather than nothing.
        answer = text.strip()
    return thinking, semantic, answer
# Function to process audio and return components
def _read_wav(path, target_sr):
    """Decode a PCM WAV file with the standard library and resample it.

    Fallback for when transformers' audio loader is unavailable. Returns a
    1-D float32 array scaled to roughly [-1, 1] at ``target_sr`` Hz; multiple
    channels are averaged to mono. ``wave.open`` raises ``wave.Error`` if the
    file is not a PCM WAV.
    """
    import wave

    import numpy as np

    with wave.open(path, "rb") as wav:
        n_channels = wav.getnchannels()
        sample_width = wav.getsampwidth()
        source_sr = wav.getframerate()
        frames = wav.readframes(wav.getnframes())

    if sample_width == 1:  # 8-bit PCM WAV is unsigned, centred on 128
        data = (np.frombuffer(frames, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
    elif sample_width == 2:
        data = np.frombuffer(frames, dtype="<i2").astype(np.float32) / 32768.0
    elif sample_width == 4:
        data = np.frombuffer(frames, dtype="<i4").astype(np.float32) / 2147483648.0
    else:
        raise ValueError(f"Unsupported WAV sample width: {sample_width} bytes")

    if n_channels > 1:
        data = data.reshape(-1, n_channels).mean(axis=1)

    if source_sr != target_sr and data.size:
        # Linear interpolation: cruder than a polyphase resampler, but numpy-only.
        n_out = max(1, round(data.size * target_sr / source_sr))
        src_times = np.arange(data.size) / source_sr
        dst_times = np.arange(n_out) / target_sr
        data = np.interp(dst_times, src_times, data)
    return data.astype(np.float32)


def _load_waveform(path, target_sr):
    """Load the audio file at ``path`` as a waveform array at ``target_sr`` Hz."""
    try:
        # Present in recent transformers releases (librosa-backed).
        from transformers.audio_utils import load_audio

        return load_audio(path, sampling_rate=target_sr)
    except ImportError:
        # Older transformers without load_audio, or librosa not installed.
        return _read_wav(path, target_sr)


def process_audio(audio_file):
    """Describe one uploaded audio file and split the reply into sections.

    Args:
        audio_file: Path to the uploaded file, as delivered by the
            ``gr.Audio(type="filepath")`` input; ``None`` if nothing was uploaded.

    Returns:
        ``(thinking, semantic, answer)`` strings from ``extract_components``.

    Raises:
        gr.Error: If no audio file was provided.
    """
    if audio_file is None:
        raise gr.Error("Please upload an audio file.")

    sampling_rate = processor.feature_extractor.sampling_rate
    # The processor's feature extractor needs a waveform array, not a file path.
    waveform = _load_waveform(audio_file, sampling_rate)

    # Create conversation format
    conversation = [
        {"role": "user", "content": [
            {"type": "audio", "audio": audio_file},
            {"type": "text", "text": "Describe the audio in detail."},
        ]}
    ]
    chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)

    inputs = processor(
        text=chat_text,
        audios=[waveform],
        return_tensors="pt",
        sampling_rate=sampling_rate,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=768,
            do_sample=False,
        )

    # generate() returns prompt + continuation; decode only the new tokens.
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = processor.tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=False)
    # Special tokens are kept, so cut the reply at the end-of-turn marker.
    assistant_text = generated_text.split("<|im_end|>")[0].strip()

    thinking, semantic, answer = extract_components(assistant_text)
    return thinking, semantic, answer
# Gradio UI: a single audio upload in, and three text boxes out that receive
# process_audio's (thinking, semantic, answer) tuple in that order.
_audio_input = gr.Audio(type="filepath", label="Upload Audio")
_output_boxes = [
    gr.Textbox(label="Thinking Process", lines=10),
    gr.Textbox(label="Semantic Elements", lines=5),
    gr.Textbox(label="Answer", lines=5),
]

demo = gr.Interface(
    fn=process_audio,
    inputs=_audio_input,
    outputs=_output_boxes,
    title="Qwen2Audio Audio Description Demo",
    description="Upload an audio file and the model will provide detailed analysis and description.",
    examples=[],  # no bundled example clips yet; list file paths here to add some
    cache_examples=False,
)


def _main():
    """Start the Gradio server; only runs when this file is executed directly."""
    demo.launch()


if __name__ == "__main__":
    _main()