File size: 5,082 Bytes
b6d8181
39d1328
 
 
 
fbe7912
 
b770eaa
 
 
39d1328
b95f6f3
 
 
39d1328
 
b95f6f3
 
 
39d1328
b95f6f3
 
 
 
 
39d1328
 
 
 
 
 
b95f6f3
f754b4d
b770eaa
39d1328
 
 
 
 
04a31f5
 
b95f6f3
 
39d1328
b95f6f3
39d1328
b95f6f3
 
b770eaa
d1d89ce
b770eaa
 
 
 
 
 
 
 
 
 
 
 
b95f6f3
b770eaa
 
b95f6f3
b770eaa
 
 
b95f6f3
 
 
b770eaa
d1d89ce
b770eaa
b6d8181
b95f6f3
 
ce70672
b95f6f3
 
b770eaa
 
fbe7912
 
 
b770eaa
fbe7912
 
 
b770eaa
 
 
 
 
fbe7912
 
 
39d1328
 
 
 
fbe7912
39d1328
 
 
 
 
 
 
 
 
 
fbe7912
39d1328
 
 
 
b770eaa
 
 
 
 
 
 
 
39d1328
b770eaa
39d1328
b770eaa
39d1328
 
 
b770eaa
 
 
 
 
 
 
 
39d1328
b770eaa
 
 
b95f6f3
 
ce70672
b95f6f3
 
b770eaa
b95f6f3
 
39d1328
b770eaa
39d1328
 
b770eaa
39d1328
b770eaa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import spaces
import os
import re
import gradio as gr
import torch
import librosa
import numpy as np
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, TextIteratorStreamer
import torchaudio
from threading import Thread

# Identifier of the base model; load_model() passes it to AutoProcessor.from_pretrained
base_model_id = "Qwen/Qwen2-Audio-7B-Instruct"

# Checkpoint locations of the two selectable models
# (process_audio_streaming uses model_path_1 for the "Think" choice, model_path_2 otherwise)
model_path_1 = "./model"
model_path_2 = "./model2"

# Cache filled by load_model(): model path -> (model, processor)
loaded_models = {}

# Load the model and processor
def load_model(model_path):
    """Return the ``(model, processor)`` pair for ``model_path``, loading it on first use.

    Loaded pairs are kept in the module-level ``loaded_models`` dict keyed by
    ``model_path``, so each path is only loaded once per process.

    Args:
        model_path: Location handed to
            ``Qwen2AudioForConditionalGeneration.from_pretrained``.

    Returns:
        A ``(model, processor)`` tuple. The processor is always created from
        ``base_model_id``; only the model weights depend on ``model_path``.
    """
    try:
        return loaded_models[model_path]
    except KeyError:
        pass  # first request for this path: fall through and load it

    # The processor comes from the base model, not from model_path
    processor = AutoProcessor.from_pretrained(base_model_id, trust_remote_code=True)

    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        model_path,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    model.eval()

    # Remember the pair so later calls take the fast path above
    pair = (model, processor)
    loaded_models[model_path] = pair
    return pair

# Load the first model (model_path_1) at import time; load_model() stores the
# pair in loaded_models, so later load_model(model_path_1) calls reuse it.
# NOTE(review): process_audio_streaming binds its own local model/processor via
# load_model(), so these module-level names look unused in the visible code.
model, processor = load_model(model_path_1)


def process_output(output):
    """Normalise one streamed text chunk around the model's structural tags.

    At most one tag is handled per chunk, checked in this priority order:
    ``<think>``, ``<semantic_elements>``, ``<answer>``, then the closing tags
    ``</think>``, ``</semantic_elements>``, ``</answer>``.

    * Opening tag: the chunk becomes the tag, a newline, and the text found
      between the first and second occurrence of the tag (text before the
      tag is dropped).
    * Closing tag: the chunk becomes the text before the tag, a newline, the
      tag, and a trailer (two newlines, or one for ``</answer>``); text
      after the tag is dropped.

    Afterwards literal ``\\n`` sequences and any remaining backslashes are
    turned into newlines, and a newline directly before ``-`` is removed.

    Args:
        output: The text chunk to clean up.

    Returns:
        The cleaned chunk.
    """
    opening_tags = ("<think>", "<semantic_elements>", "<answer>")
    closing_tags = (
        ("</think>", "\n\n"),
        ("</semantic_elements>", "\n\n"),
        ("</answer>", "\n"),
    )

    for tag in opening_tags:
        if tag in output:
            output = tag + "\n" + output.split(tag)[1]
            break
    else:
        # No opening tag matched; only now are closing tags considered
        for tag, trailer in closing_tags:
            if tag in output:
                output = output.split(tag)[0] + "\n" + tag + trailer
                break

    # Order matters: escaped newlines first, then any leftover backslashes
    for old, new in (("\\n", "\n"), ("\\", "\n"), ("\n-", "-")):
        output = output.replace(old, new)
    return output

# Keep only the process_audio_streaming function that's actually used in the Gradio interface
@spaces.GPU
def process_audio_streaming(audio_file, model_choice):
    """Stream a detailed description of an audio file from the selected model.

    The audio is loaded with torchaudio, resampled to 16 kHz, down-mixed to
    mono and sent to the model together with the fixed prompt
    "Describe the audio in detail.". Generation runs in a background thread
    and the text is streamed back as it is produced.

    Args:
        audio_file: Path of the audio file to describe. May be ``None`` or
            empty when the interface fires without an uploaded file.
        model_choice: ``"Think"`` selects ``model_path_1``; any other value
            selects ``model_path_2``.

    Yields:
        The accumulated output text so far, after each streamed chunk has
        been passed through ``process_output``. Yields a single empty string
        when no audio file is given.
    """
    # The interface runs in live mode, so this function can be invoked while
    # the audio input is still empty (e.g. only the radio button changed).
    # Return an empty result instead of crashing inside torchaudio.load().
    if not audio_file:
        yield ""
        return

    # Load the selected model
    model_path = model_path_1 if model_choice == "Think" else model_path_2
    model, processor = load_model(model_path)

    # Sampling rate used both for resampling and for the processor call
    target_sr = 16000

    # Load the audio and bring it to target_sr
    waveform, sr = torchaudio.load(audio_file)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)

    # Down-mix multi-channel audio to a single channel
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Drop only the channel axis so the result stays 1-D even for a
    # one-sample clip (a bare squeeze() would produce a 0-d array there)
    y = waveform.squeeze(0).numpy()

    # Create conversation format
    conversation = [
        {"role": "user", "content": [
            {"type": "audio", "audio": y},
            {"type": "text", "text": "Describe the audio in detail."}
        ]}
    ]

    # Format the chat
    chat_text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)

    # Process the inputs
    inputs = processor(
        text=chat_text,
        audios=[y],
        return_tensors="pt",
        sampling_rate=target_sr,
    ).to(model.device)

    # The streamer hands decoded text from the generation thread to this one
    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=768,
        do_sample=False,
    )

    def _generate():
        # Grad mode is thread-local: no_grad() has to be entered inside the
        # worker thread to have any effect on generate().
        with torch.no_grad():
            model.generate(**generate_kwargs)

    worker = Thread(target=_generate)
    worker.start()

    # Yield the growing text after every streamed chunk
    accumulated_output = ""
    for output in streamer:
        accumulated_output += process_output(output)
        yield accumulated_output

    # The streamer is exhausted once generation has ended; reap the thread
    worker.join()

# Gradio UI: an audio upload and a model selector in, one text box out
audio_demo = gr.Interface(
    title="SemThink",
    description="Upload an audio file and the model will provide detailed analysis and description. Choose between different model versions.",
    fn=process_audio_streaming,
    inputs=[
        gr.Audio(label="Upload Audio", type="filepath"),
        gr.Radio(["Think", "Think + Semantics"], value="Think + Semantics", label="Select Model"),
    ],
    outputs=gr.Textbox(lines=30, label="Generated Output"),
    # The example row uses the same model choice as the radio's default value
    examples=[["examples/1.wav", "Think + Semantics"]],
    cache_examples=False,
    live=True,  # live updates enabled
)

# Start the app only when this file is run as a script
if __name__ == "__main__":
    audio_demo.launch()