import spaces
import os
import re
import librosa
import gradio as gr
import torch
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

# Model path and configuration
model_path = "./model"  # defined but not used below; the base model is loaded directly
base_model_id = "Qwen/Qwen2-Audio-7B-Instruct"

# Load the model and processor
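# The processor comes from the base checkpoint; the weights are loaded in bfloat16
# with device_map="auto" so they are placed on the available GPU(s) automatically.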
def load_model():
    # Load the processor from the base model
    processor = AutoProcessor.from_pretrained(
        base_model_id,
        trust_remote_code=True,
    )
    
    # Load the base model
    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        base_model_id,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    )
    
    model.eval()
    
    return model, processor

# Initialize model and processor
model, processor = load_model()

# Function to extract components from model output
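# The model is expected to wrap its reasoning and results in <think>...</think>,
# <semantic_elements>...</semantic_elements>, and <answer>...</answer> tags;
# each block is pulled out with a non-greedy regex and returned separately.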
def extract_components(text):
    thinking = ""
    semantic = ""
    answer = ""
    
    # Extract thinking
    think_match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    if think_match:
        thinking = think_match.group(1).strip()
    
    # Extract semantic elements
    semantic_match = re.search(r"<semantic_elements>(.*?)</semantic_elements>", text, re.DOTALL)
    if semantic_match:
        semantic = semantic_match.group(1).strip()
    
    # Extract answer
    answer_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if answer_match:
        answer = answer_match.group(1).strip()
    
    return thinking, semantic, answer



# @spaces.GPU requests GPU time for this call on Hugging Face ZeroGPU Spaces;
# it has no effect when the app runs outside of Spaces.
@spaces.GPU
def process_audio(audio_file):
    if audio_file is None:
        return "", "", "Please upload an audio file."

    # Load the audio as a waveform at the sampling rate the feature extractor expects
    sampling_rate = processor.feature_extractor.sampling_rate
    waveform, _ = librosa.load(audio_file, sr=sampling_rate)
    
    # Create conversation format
    conversation = [
        {"role": "user", "content": [
            {"type": "audio", "audio": audio_file},
            {"type": "text", "text": "Describe the audio in detail."}
        ]}
    ]
    
    # Format the chat
    chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    
    # Process the inputs (the processor expects raw waveforms, not file paths)
    inputs = processor(
        text=chat_text,
        audios=[waveform],
        return_tensors="pt",
        sampling_rate=sampling_rate,
    ).to(model.device)
    
    # Generate the output
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=768,
            do_sample=False,
        )
    
    # Decode the output
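    # generate() returns the prompt plus the new tokens, so decode the full
    # sequence and keep only the assistant turn between the chat markers.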
    generated_text = processor.tokenizer.decode(outputs[0], skip_special_tokens=False)
    assistant_text = generated_text.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0].strip()
    
    # Extract components
    thinking, semantic, answer = extract_components(assistant_text)
    
    return thinking, semantic, answer

# Create Gradio interface
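# The three output textboxes map one-to-one onto the (thinking, semantic, answer)
# tuple returned by process_audio.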
demo = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio"),
    outputs=[
        gr.Textbox(label="Thinking Process", lines=10),
        gr.Textbox(label="Semantic Elements", lines=5),
        gr.Textbox(label="Answer", lines=5)
    ],
    title="Qwen2-Audio Audio Description Demo",
    description="Upload an audio file and the model will provide detailed analysis and description.",
    examples=[],  # Add example files here if available
    cache_examples=False,
)

# Launch the app
if __name__ == "__main__":
    demo.launch()