aimeri committed
Commit e4a9a7a · 1 Parent(s): 1638860

Enhance process_input and create_demo functions in app.py to improve multimodal input handling, including better formatting for user messages and integration of TextStreamer for text response generation.
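For reference, transformers' TextStreamer prints decoded tokens to stdout as model.generate() produces them, which is what this commit wires into both generate() calls. A minimal standalone sketch of that API (the "gpt2" checkpoint and prompt below are placeholders for illustration, not the Qwen2.5-Omni model this app loads):

# Minimal TextStreamer sketch; checkpoint and prompt are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Streaming output demo:", return_tensors="pt")
# skip_prompt=True stops the streamer from echoing the prompt tokens.
streamer = TextStreamer(tokenizer, skip_prompt=True)
model.generate(**inputs, max_new_tokens=32, streamer=streamer)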

Files changed (1)
  app.py +44 -14
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import torch
-from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor
+from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor, TextStreamer
 from qwen_omni_utils import process_mm_info
 import soundfile as sf
 import tempfile
@@ -51,7 +51,16 @@ def process_input(image, audio, video, text, chat_history, voice_type, enable_au
             if isinstance(item, list) and len(item) == 2:
                 user_msg, bot_msg = item
                 if bot_msg is not None:  # Only add complete message pairs
-                    conversation.append({"role": "user", "content": user_input_to_content(user_msg)})
+                    # Convert display format back to processable format
+                    processed_msg = user_msg
+                    if "[Image]" in user_msg:
+                        processed_msg = {"type": "text", "text": user_msg.replace("[Image]", "").strip()}
+                    if "[Audio]" in user_msg:
+                        processed_msg = {"type": "text", "text": user_msg.replace("[Audio]", "").strip()}
+                    if "[Video]" in user_msg:
+                        processed_msg = {"type": "text", "text": user_msg.replace("[Video]", "").strip()}
+
+                    conversation.append({"role": "user", "content": processed_msg})
                     conversation.append({"role": "assistant", "content": bot_msg})
             else:
                 # Initialize chat history if it's not a list
@@ -78,14 +87,19 @@ def process_input(image, audio, video, text, chat_history, voice_type, enable_au
     )
     inputs = inputs.to(model.device).to(model.dtype)

-    # Generate response
+    # Generate response with streaming
     if enable_audio_output:
         voice_type_value = VOICE_OPTIONS.get(voice_type, "Chelsie")
         text_ids, audio = model.generate(
             **inputs,
             use_audio_in_video=False,  # Set to False to avoid audio processing issues
             return_audio=True,
-            spk=voice_type_value
+            spk=voice_type_value,
+            max_new_tokens=512,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9,
+            streamer=TextStreamer(processor, skip_prompt=True)
         )

         # Save audio to temporary file
@@ -100,7 +114,12 @@ def process_input(image, audio, video, text, chat_history, voice_type, enable_au
         text_ids = model.generate(
             **inputs,
             use_audio_in_video=False,  # Set to False to avoid audio processing issues
-            return_audio=False
+            return_audio=False,
+            max_new_tokens=512,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9,
+            streamer=TextStreamer(processor, skip_prompt=True)
         )
         audio_path = None

@@ -111,17 +130,20 @@ def process_input(image, audio, video, text, chat_history, voice_type, enable_au
         clean_up_tokenization_spaces=False
     )[0]

-    # Clean up text response
+    # Clean up text response by removing system/user messages
     text_response = text_response.strip()
+    text_response = text_response.split("assistant")[-1].strip()
+    if text_response.startswith(":"):
+        text_response = text_response[1:].strip()

     # Format user message for chat history display
     user_message_for_display = str(text) if text is not None else ""
     if image is not None:
-        user_message_for_display = (user_message_for_display or "Image uploaded") + " [Image]"
+        user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Image]"
     if audio is not None:
-        user_message_for_display = (user_message_for_display or "Audio uploaded") + " [Audio]"
+        user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Audio]"
     if video is not None:
-        user_message_for_display = (user_message_for_display or "Video uploaded") + " [Video]"
+        user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Video]"

     # If empty, provide a default message
     if not user_message_for_display.strip():
@@ -168,7 +190,12 @@ def create_demo():
         # Chat interface
         with gr.Row():
             with gr.Column(scale=3):
-                chatbot = gr.Chatbot(height=600)
+                chatbot = gr.Chatbot(
+                    height=600,
+                    show_label=False,
+                    avatar_images=["👤", "🤖"],
+                    bubble_full_width=False,
+                )
                 with gr.Accordion("Advanced Options", open=False):
                     voice_type = gr.Dropdown(
                         choices=list(VOICE_OPTIONS.keys()),
@@ -185,9 +212,11 @@
             with gr.TabItem("Text Input"):
                 text_input = gr.Textbox(
                     placeholder="Type your message here...",
-                    label="Text Input"
+                    label="Text Input",
+                    autofocus=True,
+                    container=False,
                 )
-                text_submit = gr.Button("Send Text")
+                text_submit = gr.Button("Send Text", variant="primary")

             with gr.TabItem("Multimodal Input"):
                 with gr.Row():
@@ -205,9 +234,10 @@
                 )
                 additional_text = gr.Textbox(
                     placeholder="Additional text message...",
-                    label="Additional Text"
+                    label="Additional Text",
+                    container=False,
                 )
-                multimodal_submit = gr.Button("Send Multimodal Input")
+                multimodal_submit = gr.Button("Send Multimodal Input", variant="primary")

             clear_button = gr.Button("Clear Chat")
 