Commit c98fc82: Update process_input function in app.py to handle audio generation output more robustly, introducing a fallback mechanism for text generation in case of unexpected output formats. Improve error handling during audio and text generation processes. Additionally, update requirements.txt to include flash-attn for enhanced performance.
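
The commit message mentions updating requirements.txt to include flash-attn, but that file is not shown on this page. A plausible requirements.txt for this Space, inferred from the imports and model-loading options in app.py below (the package list and the absence of version pins are assumptions, not the Space's actual file), might be:

gradio
torch
transformers
accelerate
soundfile
qwen-omni-utils
spaces
flash-attn

Note that flash-attn builds CUDA extensions against the installed torch at install time, so it is a common source of Space build failures unless a matching prebuilt wheel is available.

app.py: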
import gradio as gr
import torch
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor, TextStreamer
from qwen_omni_utils import process_mm_info
import soundfile as sf
import tempfile
import spaces
import gc

# Initialize the model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16


def get_model():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
    model = Qwen2_5OmniModel.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B",
        torch_dtype=torch_dtype,
        device_map="auto",
        enable_audio_output=True,
        low_cpu_mem_usage=True,
        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
    )
    return model


model = get_model()
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")

# System prompt
SYSTEM_PROMPT = {
    "role": "system",
    "content": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."
}

# Voice options
VOICE_OPTIONS = {
    "Chelsie (Female)": "Chelsie",
    "Ethan (Male)": "Ethan"
}


def process_input(image, audio, video, text, chat_history, voice_type, enable_audio_output):
    # Build the display message first so that it is always defined,
    # even if an error is raised early in the processing below
    user_message_for_display = str(text) if text is not None else ""
    if image is not None:
        user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Image]"
    if audio is not None:
        user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Audio]"
    if video is not None:
        user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Video]"
    # If empty, provide a default message
    if not user_message_for_display.strip():
        user_message_for_display = "Multimodal input"

    try:
        # Clear GPU memory before processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

        # Combine multimodal inputs
        user_input = {
            "text": text,
            "image": image,
            "audio": audio,
            "video": video
        }

        # Prepare conversation history for model processing
        conversation = [SYSTEM_PROMPT]

        # Add previous chat history
        if isinstance(chat_history, list):
            for message in chat_history:
                if isinstance(message, dict) and "role" in message and "content" in message:
                    # Messages are already in the correct format
                    conversation.append(message)
                elif isinstance(message, list) and len(message) == 2:
                    # Convert old format to new format
                    user_msg, bot_msg = message
                    if bot_msg is not None:  # Only add complete message pairs
                        # Convert display format back to processable format
                        processed_msg = user_msg
                        if "[Image]" in user_msg:
                            processed_msg = {"type": "text", "text": user_msg.replace("[Image]", "").strip()}
                        if "[Audio]" in user_msg:
                            processed_msg = {"type": "text", "text": user_msg.replace("[Audio]", "").strip()}
                        if "[Video]" in user_msg:
                            processed_msg = {"type": "text", "text": user_msg.replace("[Video]", "").strip()}
                        conversation.append({"role": "user", "content": processed_msg})
                        conversation.append({"role": "assistant", "content": bot_msg})

        # Add current user input
        conversation.append({"role": "user", "content": user_input_to_content(user_input)})

        # Prepare for inference
        model_input = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
        try:
            audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)  # Default to no audio in video
        except Exception as e:
            print(f"Error processing multimedia: {str(e)}")
            audios, images, videos = [], [], []  # Fallback to empty lists

        inputs = processor(
            text=model_input,
            audios=audios,
            images=images,
            videos=videos,
            return_tensors="pt",
            padding=True
        )
        # Move inputs to the model device; cast only floating-point tensors so that
        # integer tensors such as input_ids and attention_mask keep their dtype
        inputs = {
            k: (v.to(device=model.device, dtype=model.dtype) if v.is_floating_point()
                else v.to(model.device)) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }

        # Generate response with streaming
        try:
            text_ids = None
            audio_path = None
            generation_output = None

            if enable_audio_output:
                voice_type_value = VOICE_OPTIONS.get(voice_type, "Chelsie")
                try:
                    generation_output = model.generate(
                        **inputs,
                        use_audio_in_video=False,
                        return_audio=True,
                        spk=voice_type_value,
                        max_new_tokens=512,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.9,
                        streamer=TextStreamer(processor, skip_prompt=True)
                    )
                    if generation_output is not None and isinstance(generation_output, tuple) and len(generation_output) == 2:
                        text_ids, audio = generation_output
                        if audio is not None:
                            # Save audio to temporary file
                            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                                sf.write(
                                    tmp_file.name,
                                    audio.reshape(-1).detach().cpu().numpy(),
                                    samplerate=24000,
                                )
                                audio_path = tmp_file.name
                    else:
                        print("Warning: Unexpected generation output format")
                        # Fall back to text-only generation
                        text_ids = model.generate(
                            **inputs,
                            use_audio_in_video=False,
                            return_audio=False,
                            max_new_tokens=512,
                            do_sample=True,
                            temperature=0.7,
                            top_p=0.9,
                            streamer=TextStreamer(processor, skip_prompt=True)
                        )
                except Exception as e:
                    print(f"Error during audio generation: {str(e)}")
                    # Fall back to text-only generation
                    try:
                        text_ids = model.generate(
                            **inputs,
                            use_audio_in_video=False,
                            return_audio=False,
                            max_new_tokens=512,
                            do_sample=True,
                            temperature=0.7,
                            top_p=0.9,
                            streamer=TextStreamer(processor, skip_prompt=True)
                        )
                    except Exception as e:
                        print(f"Error during fallback text generation: {str(e)}")
                        text_ids = None
            else:
                try:
                    text_ids = model.generate(
                        **inputs,
                        use_audio_in_video=False,
                        return_audio=False,
                        max_new_tokens=512,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.9,
                        streamer=TextStreamer(processor, skip_prompt=True)
                    )
                except Exception as e:
                    print(f"Error during text generation: {str(e)}")
                    text_ids = None

            # Process the response
            if text_ids is not None and len(text_ids) > 0:
                try:
                    text_response = processor.batch_decode(
                        text_ids,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False
                    )[0]
                    # Clean up text response
                    text_response = text_response.strip()
                    if "<|im_start|>assistant" in text_response:
                        text_response = text_response.split("<|im_start|>assistant")[-1]
                        text_response = text_response.replace("<|im_end|>", "").replace("<|im_start|>", "")
                    if text_response.startswith(":"):
                        text_response = text_response[1:].strip()
                except Exception as e:
                    print(f"Error during text decoding: {str(e)}")
                    text_response = "I apologize, but I encountered an error processing the response."
            else:
                text_response = "I apologize, but I encountered an error generating a response."

            # Update chat history with properly formatted entries
            if not isinstance(chat_history, list):
                chat_history = []

            # Convert the current messages to the proper format
            user_message = {"role": "user", "content": user_message_for_display}
            assistant_message = {"role": "assistant", "content": text_response}

            # Find the last incomplete message pair if it exists
            if chat_history and isinstance(chat_history[-1], dict) and chat_history[-1]["role"] == "user":
                chat_history.append(assistant_message)
            else:
                chat_history.extend([user_message, assistant_message])

            # Clear GPU memory after processing
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()

            # Prepare output
            if enable_audio_output and audio_path:
                return chat_history, text_response, audio_path
            else:
                return chat_history, text_response, None
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            error_msg = "I apologize, but I encountered an error processing your request. Please try again."
            chat_history.append(
                {"role": "assistant", "content": error_msg}
            )
            return chat_history, error_msg, None
    except Exception as e:
        print(f"Error in process_input: {str(e)}")
        if not isinstance(chat_history, list):
            chat_history = []
        error_msg = "I apologize, but I encountered an error processing your request. Please try again."
        chat_history.extend([
            {"role": "user", "content": user_message_for_display},
            {"role": "assistant", "content": error_msg}
        ])
        return chat_history, error_msg, None


def user_input_to_content(user_input):
    if isinstance(user_input, str):
        return user_input
    elif isinstance(user_input, dict):
        # Handle file uploads
        content = []
        if "text" in user_input and user_input["text"]:
            content.append({"type": "text", "text": user_input["text"]})
        if "image" in user_input and user_input["image"]:
            content.append({"type": "image", "image": user_input["image"]})
        if "audio" in user_input and user_input["audio"]:
            content.append({"type": "audio", "audio": user_input["audio"]})
        if "video" in user_input and user_input["video"]:
            content.append({"type": "video", "video": user_input["video"]})
        return content
    return user_input


def create_demo():
    with gr.Blocks(title="Qwen2.5-Omni Chat Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Qwen2.5-Omni Multimodal Chat Demo")
        gr.Markdown("Experience the omni-modal capabilities of Qwen2.5-Omni through text, images, audio, and video interactions.")

        # Hidden placeholder components for text-only input
        placeholder_image = gr.Image(type="filepath", visible=False)
        placeholder_audio = gr.Audio(type="filepath", visible=False)
        placeholder_video = gr.Video(visible=False)

        # Chat interface
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    height=600,
                    show_label=False,
                    avatar_images=["user.png", "assistant.png"],
                    type="messages"
                )
                with gr.Accordion("Advanced Options", open=False):
                    voice_type = gr.Dropdown(
                        choices=list(VOICE_OPTIONS.keys()),
                        value="Chelsie (Female)",
                        label="Voice Type"
                    )
                    enable_audio_output = gr.Checkbox(
                        value=True,
                        label="Enable Audio Output"
                    )

                # Multimodal input components
                with gr.Tabs():
                    with gr.TabItem("Text Input"):
                        text_input = gr.Textbox(
                            placeholder="Type your message here...",
                            label="Text Input",
                            autofocus=True,
                            container=False,
                        )
                        text_submit = gr.Button("Send Text", variant="primary")
                    with gr.TabItem("Multimodal Input"):
                        with gr.Row():
                            image_input = gr.Image(
                                type="filepath",
                                label="Upload Image"
                            )
                            audio_input = gr.Audio(
                                type="filepath",
                                label="Upload Audio"
                            )
                        with gr.Row():
                            video_input = gr.Video(
                                label="Upload Video"
                            )
                            additional_text = gr.Textbox(
                                placeholder="Additional text message...",
                                label="Additional Text",
                                container=False,
                            )
                        multimodal_submit = gr.Button("Send Multimodal Input", variant="primary")

                clear_button = gr.Button("Clear Chat")

            with gr.Column(scale=1):
                gr.Markdown("## Model Capabilities")
                gr.Markdown("""
                **Qwen2.5-Omni can:**
                - Process and understand text
                - Analyze images and answer questions about them
                - Transcribe and understand audio
                - Analyze video content (with or without audio)
                - Generate natural speech responses
                """)
                gr.Markdown("### Example Prompts")
                gr.Examples(
                    examples=[
                        ["Describe what you see in this image", "image"],
                        ["What is being said in this audio clip?", "audio"],
                        ["What's happening in this video?", "video"],
                        ["Explain quantum computing in simple terms", "text"],
                        ["Generate a short story about a robot learning to paint", "text"]
                    ],
                    inputs=[text_input, gr.Textbox(visible=False)],
                    label="Text Examples"
                )

                audio_output = gr.Audio(
                    label="Model Speech Output",
                    visible=True,
                    autoplay=True
                )
                text_output = gr.Textbox(
                    label="Model Text Response",
                    interactive=False
                )

        # Text input handling (append the new user message to the existing history)
        text_submit.click(
            fn=lambda text, history: (history or []) + [{"role": "user", "content": text if text is not None else ""}],
            inputs=[text_input, chatbot],
            outputs=[chatbot],
            queue=False
        ).then(
            fn=process_input,
            inputs=[placeholder_image, placeholder_audio, placeholder_video, text_input, chatbot, voice_type, enable_audio_output],
            outputs=[chatbot, text_output, audio_output]
        ).then(
            fn=lambda: "",  # Clear input after submission
            outputs=text_input
        )

        # Multimodal input handling
        def prepare_multimodal_input(image, audio, video, text, history):
            # Create a display message that indicates what was uploaded
            display_message = str(text) if text is not None else ""
            if image is not None:
                display_message = (display_message + " " if display_message.strip() else "") + "[Image]"
            if audio is not None:
                display_message = (display_message + " " if display_message.strip() else "") + "[Audio]"
            if video is not None:
                display_message = (display_message + " " if display_message.strip() else "") + "[Video]"
            if not display_message.strip():
                display_message = "Multimodal content"
            # Append to the existing history instead of replacing it
            return (history or []) + [{"role": "user", "content": display_message}]

        multimodal_submit.click(
            fn=prepare_multimodal_input,
            inputs=[image_input, audio_input, video_input, additional_text, chatbot],
            outputs=[chatbot],
            queue=False
        ).then(
            fn=process_input,
            inputs=[image_input, audio_input, video_input, additional_text,
                    chatbot, voice_type, enable_audio_output],
            outputs=[chatbot, text_output, audio_output]
        ).then(
            fn=lambda: (None, None, None, ""),  # Clear inputs after submission
            outputs=[image_input, audio_input, video_input, additional_text]
        )

        # Clear chat
        def clear_chat():
            return [], None, None

        clear_button.click(
            fn=clear_chat,
            outputs=[chatbot, text_output, audio_output]
        )

        # Update audio output visibility
        def toggle_audio_output(enable_audio):
            return gr.Audio(visible=enable_audio)

        enable_audio_output.change(
            fn=toggle_audio_output,
            inputs=enable_audio_output,
            outputs=audio_output
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860)