import re

import gradio as gr
import librosa
import torch
from peft import PeftModel
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

# Model path and configuration
model_path = "./model"
base_model_id = "Qwen/Qwen2-Audio-7B-Instruct"

# Load the model and processor
def load_model():
    # Load the processor from the base model
    processor = AutoProcessor.from_pretrained(
        base_model_id,
        trust_remote_code=True,
    )
    
    # Load the base model
    base_model = Qwen2AudioForConditionalGeneration.from_pretrained(
        base_model_id,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    )
    
    # Load the LoRA adapter
    model = PeftModel.from_pretrained(base_model, model_path)
    
    model.eval()
    
    return model, processor
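
# Note: if the adapter is only needed for inference, the LoRA weights could
# optionally be folded into the base model with PEFT's standard
# merge_and_unload() call (an optional optimization, not required here):
#     model = model.merge_and_unload()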

# Initialize model and processor
model, processor = load_model()

# Function to extract components from model output
def extract_components(text):
    thinking = ""
    semantic = ""
    answer = ""
    
    # Extract thinking
    think_match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    if think_match:
        thinking = think_match.group(1).strip()
    
    # Extract semantic elements
    semantic_match = re.search(r"<semantic_elements>(.*?)</semantic_elements>", text, re.DOTALL)
    if semantic_match:
        semantic = semantic_match.group(1).strip()
    
    # Extract answer
    answer_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if answer_match:
        answer = answer_match.group(1).strip()
    
    return thinking, semantic, answer
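
# Illustrative example of the tag format the fine-tuned model is expected
# to emit (made-up text, not a real model output):
#
#   >>> extract_components(
#   ...     "<think>low rumble, then a horn</think>"
#   ...     "<semantic_elements>train, horn</semantic_elements>"
#   ...     "<answer>A train passes and sounds its horn.</answer>")
#   ('low rumble, then a horn', 'train, horn', 'A train passes and sounds its horn.')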

# Function to process audio and return components
def process_audio(audio_file):
    # Load the audio as a waveform, resampled to the rate the feature
    # extractor expects; the processor takes raw audio arrays, not file paths
    sampling_rate = processor.feature_extractor.sampling_rate
    audio, _ = librosa.load(audio_file, sr=sampling_rate)
    
    # Build the conversation in the Qwen2-Audio chat format
    conversation = [
        {"role": "user", "content": [
            {"type": "audio", "audio": audio_file},
            {"type": "text", "text": "Describe the audio in detail."}
        ]}
    ]
    
    # Render the chat template
    chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    
    # Tokenize the prompt and extract audio features
    inputs = processor(
        text=chat_text,
        audios=[audio],
        return_tensors="pt",
        sampling_rate=sampling_rate,
    ).to(model.device)
    
    # Generate the output
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=768,
            do_sample=False,
        )
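    # (Greedy decoding keeps outputs deterministic; for more varied
    # descriptions, generate()'s standard sampling options such as
    # do_sample=True and temperature could be used instead.)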
    
    # Decode the output
    generated_text = processor.tokenizer.decode(outputs[0], skip_special_tokens=False)
    assistant_text = generated_text.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0].strip()
    
    # Extract components
    thinking, semantic, answer = extract_components(assistant_text)
    
    return thinking, semantic, answer

# Create Gradio interface
demo = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio"),
    outputs=[
        gr.Textbox(label="Thinking Process", lines=10),
        gr.Textbox(label="Semantic Elements", lines=5),
        gr.Textbox(label="Answer", lines=5)
    ],
    title="Qwen2-Audio Audio Description Demo",
    description="Upload an audio file; the model returns its reasoning trace, the semantic elements it identified, and a detailed description.",
    examples=[],  # Add example files here if available
    cache_examples=False,
)

# Launch the app
if __name__ == "__main__":
    demo.launch()
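
# Note: launch() also accepts standard Gradio options, e.g.
# demo.launch(server_name="0.0.0.0", share=True) to make the app reachable
# beyond localhost.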