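"""Gradio demo for Qwen2.5-Omni-7B: multimodal chat over text, images, audio, and video, with optional speech output."""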
import gradio as gr
import torch
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor, TextStreamer
from qwen_omni_utils import process_mm_info
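# process_mm_info (from the qwen-omni-utils helper package) gathers the audio/image/video inputs out of a chat-format conversation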
import soundfile as sf
import tempfile
import spaces
import gc
# Initialize the model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
# Use bfloat16 on GPU; fall back to float32 on CPU, where half-precision inference is poorly supported
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
def get_model():
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
model = Qwen2_5OmniModel.from_pretrained(
"Qwen/Qwen2.5-Omni-7B",
torch_dtype=torch_dtype,
device_map="auto",
enable_audio_output=True,
low_cpu_mem_usage=True,
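        # flash_attention_2 requires the flash-attn package (listed in requirements.txt); None selects the default attention implementation on CPU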
attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
)
return model
model = get_model()
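# The processor bundles the tokenizer together with the image/audio/video feature extractors the model expects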
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
# System prompt
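# Qwen's documentation recommends using this exact system prompt when speech output is enabled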
SYSTEM_PROMPT = {
"role": "system",
"content": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."
}
# Voice options
VOICE_OPTIONS = {
"Chelsie (Female)": "Chelsie",
"Ethan (Male)": "Ethan"
}
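# @spaces.GPU requests a GPU for the duration of the call when running on Hugging Face ZeroGPU Spaces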
@spaces.GPU
def process_input(image, audio, video, text, chat_history, voice_type, enable_audio_output):
    # Initialized before the try block so the outer exception handler can always reference it
    user_message_for_display = "Multimodal input"
    try:
# Clear GPU memory before processing
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
        # Build a display string describing the current user input
        user_message_for_display = str(text) if text is not None else ""
if image is not None:
user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Image]"
if audio is not None:
user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Audio]"
if video is not None:
user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Video]"
# If empty, provide a default message
if not user_message_for_display.strip():
user_message_for_display = "Multimodal input"
# Combine multimodal inputs
user_input = {
"text": text,
"image": image if image is not None else None,
"audio": audio if audio is not None else None,
"video": video if video is not None else None
}
# Prepare conversation history for model processing
conversation = [SYSTEM_PROMPT]
# Add previous chat history
if isinstance(chat_history, list):
for message in chat_history:
if isinstance(message, dict) and "role" in message and "content" in message:
# Messages are already in the correct format
conversation.append(message)
elif isinstance(message, list) and len(message) == 2:
# Convert old format to new format
user_msg, bot_msg = message
if bot_msg is not None: # Only add complete message pairs
# Convert display format back to processable format
processed_msg = user_msg
if "[Image]" in user_msg:
processed_msg = {"type": "text", "text": user_msg.replace("[Image]", "").strip()}
if "[Audio]" in user_msg:
processed_msg = {"type": "text", "text": user_msg.replace("[Audio]", "").strip()}
if "[Video]" in user_msg:
processed_msg = {"type": "text", "text": user_msg.replace("[Video]", "").strip()}
conversation.append({"role": "user", "content": processed_msg})
conversation.append({"role": "assistant", "content": bot_msg})
# Add current user input
conversation.append({"role": "user", "content": user_input_to_content(user_input)})
# Prepare for inference
model_input = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
try:
audios, images, videos = process_mm_info(conversation, use_audio_in_video=False) # Default to no audio in video
except Exception as e:
print(f"Error processing multimedia: {str(e)}")
            audios, images, videos = None, None, None  # Fall back to text-only inputs; None is the processor default for absent modalities
inputs = processor(
text=model_input,
audios=audios,
images=images,
videos=videos,
return_tensors="pt",
padding=True
)
        # Move inputs to the model device; only cast floating-point tensors to the model dtype
        # (input_ids and attention_mask must stay integer tensors)
        inputs = {
            k: (v.to(device=model.device, dtype=model.dtype) if v.dtype.is_floating_point else v.to(device=model.device))
            if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }
# Generate response with streaming
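        # TextStreamer prints decoded tokens to stdout as they are generated (useful for server-side logs);
        # the Gradio UI only receives the completed response returned below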
try:
text_ids = None
audio_path = None
generation_output = None
if enable_audio_output:
voice_type_value = VOICE_OPTIONS.get(voice_type, "Chelsie")
try:
generation_output = model.generate(
**inputs,
use_audio_in_video=False,
return_audio=True,
spk=voice_type_value,
max_new_tokens=512,
do_sample=True,
temperature=0.7,
top_p=0.9,
streamer=TextStreamer(processor, skip_prompt=True)
)
if generation_output is not None and isinstance(generation_output, tuple) and len(generation_output) == 2:
text_ids, audio = generation_output
if audio is not None:
# Save audio to temporary file
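                            # 24 kHz matches the output sample rate used in the official Qwen2.5-Omni examples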
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
sf.write(
tmp_file.name,
audio.reshape(-1).detach().cpu().numpy(),
samplerate=24000,
)
audio_path = tmp_file.name
else:
print("Warning: Unexpected generation output format")
# Fall back to text-only generation
text_ids = model.generate(
**inputs,
use_audio_in_video=False,
return_audio=False,
max_new_tokens=512,
do_sample=True,
temperature=0.7,
top_p=0.9,
streamer=TextStreamer(processor, skip_prompt=True)
)
except Exception as e:
print(f"Error during audio generation: {str(e)}")
# Fall back to text-only generation
try:
text_ids = model.generate(
**inputs,
use_audio_in_video=False,
return_audio=False,
max_new_tokens=512,
do_sample=True,
temperature=0.7,
top_p=0.9,
streamer=TextStreamer(processor, skip_prompt=True)
)
except Exception as e:
print(f"Error during fallback text generation: {str(e)}")
text_ids = None
else:
try:
text_ids = model.generate(
**inputs,
use_audio_in_video=False,
return_audio=False,
max_new_tokens=512,
do_sample=True,
temperature=0.7,
top_p=0.9,
streamer=TextStreamer(processor, skip_prompt=True)
)
except Exception as e:
print(f"Error during text generation: {str(e)}")
text_ids = None
# Process the response
if text_ids is not None and len(text_ids) > 0:
try:
text_response = processor.batch_decode(
text_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)[0]
# Clean up text response
text_response = text_response.strip()
if "<|im_start|>assistant" in text_response:
text_response = text_response.split("<|im_start|>assistant")[-1]
text_response = text_response.replace("<|im_end|>", "").replace("<|im_start|>", "")
if text_response.startswith(":"):
text_response = text_response[1:].strip()
except Exception as e:
print(f"Error during text decoding: {str(e)}")
text_response = "I apologize, but I encountered an error processing the response."
else:
text_response = "I apologize, but I encountered an error generating a response."
# Update chat history with properly formatted entries
if not isinstance(chat_history, list):
chat_history = []
# Convert the current messages to the proper format
user_message = {"role": "user", "content": user_message_for_display}
assistant_message = {"role": "assistant", "content": text_response}
# Find the last incomplete message pair if it exists
            if chat_history and isinstance(chat_history[-1], dict) and chat_history[-1].get("role") == "user":
chat_history.append(assistant_message)
else:
chat_history.extend([user_message, assistant_message])
# Clear GPU memory after processing
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Prepare output
if enable_audio_output and audio_path:
return chat_history, text_response, audio_path
else:
return chat_history, text_response, None
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            error_msg = "I apologize, but I encountered an error processing your request. Please try again."
            if not isinstance(chat_history, list):
                chat_history = []
            chat_history.append(
                {"role": "assistant", "content": error_msg}
            )
            return chat_history, error_msg, None
except Exception as e:
print(f"Error in process_input: {str(e)}")
if not isinstance(chat_history, list):
chat_history = []
error_msg = "I apologize, but I encountered an error processing your request. Please try again."
chat_history.extend([
{"role": "user", "content": user_message_for_display},
{"role": "assistant", "content": error_msg}
])
return chat_history, error_msg, None
def user_input_to_content(user_input):
if isinstance(user_input, str):
return user_input
elif isinstance(user_input, dict):
# Handle file uploads
content = []
if "text" in user_input and user_input["text"]:
content.append({"type": "text", "text": user_input["text"]})
if "image" in user_input and user_input["image"]:
content.append({"type": "image", "image": user_input["image"]})
if "audio" in user_input and user_input["audio"]:
content.append({"type": "audio", "audio": user_input["audio"]})
if "video" in user_input and user_input["video"]:
content.append({"type": "video", "video": user_input["video"]})
return content
return user_input
def create_demo():
with gr.Blocks(title="Qwen2.5-Omni Chat Demo", theme=gr.themes.Soft()) as demo:
gr.Markdown("# Qwen2.5-Omni Multimodal Chat Demo")
gr.Markdown("Experience the omni-modal capabilities of Qwen2.5-Omni through text, images, audio, and video interactions.")
# Hidden placeholder components for text-only input
placeholder_image = gr.Image(type="filepath", visible=False)
placeholder_audio = gr.Audio(type="filepath", visible=False)
placeholder_video = gr.Video(visible=False)
# Chat interface
with gr.Row():
with gr.Column(scale=3):
chatbot = gr.Chatbot(
height=600,
show_label=False,
avatar_images=["user.png", "assistant.png"],
type="messages"
)
with gr.Accordion("Advanced Options", open=False):
voice_type = gr.Dropdown(
choices=list(VOICE_OPTIONS.keys()),
value="Chelsie (Female)",
label="Voice Type"
)
enable_audio_output = gr.Checkbox(
value=True,
label="Enable Audio Output"
)
# Multimodal input components
with gr.Tabs():
with gr.TabItem("Text Input"):
text_input = gr.Textbox(
placeholder="Type your message here...",
label="Text Input",
autofocus=True,
container=False,
)
text_submit = gr.Button("Send Text", variant="primary")
with gr.TabItem("Multimodal Input"):
with gr.Row():
image_input = gr.Image(
type="filepath",
label="Upload Image"
)
audio_input = gr.Audio(
type="filepath",
label="Upload Audio"
)
with gr.Row():
video_input = gr.Video(
label="Upload Video"
)
additional_text = gr.Textbox(
placeholder="Additional text message...",
label="Additional Text",
container=False,
)
multimodal_submit = gr.Button("Send Multimodal Input", variant="primary")
clear_button = gr.Button("Clear Chat")
with gr.Column(scale=1):
gr.Markdown("## Model Capabilities")
gr.Markdown("""
**Qwen2.5-Omni can:**
- Process and understand text
- Analyze images and answer questions about them
- Transcribe and understand audio
- Analyze video content (with or without audio)
- Generate natural speech responses
""")
gr.Markdown("### Example Prompts")
gr.Examples(
examples=[
["Describe what you see in this image", "image"],
["What is being said in this audio clip?", "audio"],
["What's happening in this video?", "video"],
["Explain quantum computing in simple terms", "text"],
["Generate a short story about a robot learning to paint", "text"]
],
inputs=[text_input, gr.Textbox(visible=False)],
label="Text Examples"
)
audio_output = gr.Audio(
label="Model Speech Output",
visible=True,
autoplay=True
)
text_output = gr.Textbox(
label="Model Text Response",
interactive=False
)
# Text input handling
text_submit.click(
            # Append the new user message to the existing history instead of replacing it
            fn=lambda text, history: (history if isinstance(history, list) else []) + [{"role": "user", "content": text if text is not None else ""}],
            inputs=[text_input, chatbot],
outputs=[chatbot],
queue=False
).then(
fn=process_input,
inputs=[placeholder_image, placeholder_audio, placeholder_video, text_input, chatbot, voice_type, enable_audio_output],
outputs=[chatbot, text_output, audio_output]
).then(
fn=lambda: "", # Clear input after submission
outputs=text_input
)
# Multimodal input handling
        def prepare_multimodal_input(image, audio, video, text, history):
            # Create a display message that indicates what was uploaded
            display_message = str(text) if text is not None else ""
            if image is not None:
                display_message = (display_message + " " if display_message.strip() else "") + "[Image]"
            if audio is not None:
                display_message = (display_message + " " if display_message.strip() else "") + "[Audio]"
            if video is not None:
                display_message = (display_message + " " if display_message.strip() else "") + "[Video]"
            if not display_message.strip():
                display_message = "Multimodal content"
            # Append the new user message to the existing history instead of replacing it
            return (history if isinstance(history, list) else []) + [{"role": "user", "content": display_message}]
multimodal_submit.click(
fn=prepare_multimodal_input,
            inputs=[image_input, audio_input, video_input, additional_text, chatbot],
outputs=[chatbot],
queue=False
).then(
fn=process_input,
inputs=[image_input, audio_input, video_input, additional_text,
chatbot, voice_type, enable_audio_output],
outputs=[chatbot, text_output, audio_output]
).then(
fn=lambda: (None, None, None, ""), # Clear inputs after submission
outputs=[image_input, audio_input, video_input, additional_text]
)
# Clear chat
def clear_chat():
return [], None, None
clear_button.click(
fn=clear_chat,
outputs=[chatbot, text_output, audio_output]
)
# Update audio output visibility
def toggle_audio_output(enable_audio):
return gr.Audio(visible=enable_audio)
enable_audio_output.change(
fn=toggle_audio_output,
inputs=enable_audio_output,
outputs=audio_output
)
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch(server_name="0.0.0.0", server_port=7860)