import io
import os
from argparse import ArgumentParser

import ffmpeg
import numpy as np
import gradio as gr
import soundfile as sf
import modelscope_studio.components.base as ms
import modelscope_studio.components.antd as antd
import gradio.processing_utils as processing_utils
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor
from gradio_client import utils as client_utils
from qwen_omni_utils import process_mm_info
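# Load the Qwen2.5-Omni model and its paired processor from a checkpoint,
# on CPU when --cpu-only is set and with automatic device placement otherwise.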
def _load_model_processor(args):
    if args.cpu_only:
        device_map = 'cpu'
    else:
        device_map = 'auto'

    # Check if flash-attn2 flag is enabled and load model accordingly
    if args.flash_attn2:
        model = Qwen2_5OmniModel.from_pretrained(args.checkpoint_path,
                                                 torch_dtype='auto',
                                                 attn_implementation='flash_attention_2',
                                                 device_map=device_map)
    else:
        model = Qwen2_5OmniModel.from_pretrained(args.checkpoint_path, device_map=device_map)

    processor = Qwen2_5OmniProcessor.from_pretrained(args.checkpoint_path)
    return model, processor
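# Build and launch the Gradio UI around the loaded model: an "Online" tab for
# live microphone/webcam recording and an "Offline" tab for file uploads.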
def _launch_demo(args, model, processor):
    # Voice settings
    VOICE_LIST = ['Chelsie', 'Ethan']
    DEFAULT_VOICE = 'Chelsie'

    default_system_prompt = 'You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.'

    language = args.ui_language

    def get_text(text: str, cn_text: str):
        if language == 'en':
            return text
        if language == 'zh':
            return cn_text
        return text
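    # Remux a browser-recorded WebM clip into MP4 with 16 kHz AAC audio so the
    # processor can consume it; conversion errors are printed, not raised.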
    def convert_webm_to_mp4(input_file, output_file):
        try:
            (
                ffmpeg
                .input(input_file)
                .output(output_file, acodec='aac', ar='16000', audio_bitrate='192k')
                .run(quiet=True, overwrite_output=True)
            )
            print(f"Conversion successful: {output_file}")
        except ffmpeg.Error as e:
            print("An error occurred during conversion.")
            print(e.stderr.decode('utf-8'))
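    # Convert Gradio chat history into the message schema expected by
    # processor.apply_chat_template: plain strings pass through unchanged, and
    # user file tuples become image/video/audio content items by MIME type.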
    def format_history(history: list, system_prompt: str):
        messages = []
        messages.append({"role": "system", "content": system_prompt})
        for item in history:
            if isinstance(item["content"], str):
                messages.append({"role": item['role'], "content": item['content']})
            elif item["role"] == "user" and isinstance(item["content"], (list, tuple)):
                file_path = item["content"][0]
                mime_type = client_utils.get_mimetype(file_path)
                if mime_type.startswith("image"):
                    messages.append({
                        "role": item['role'],
                        "content": [{"type": "image", "image": file_path}]
                    })
                elif mime_type.startswith("video"):
                    messages.append({
                        "role": item['role'],
                        "content": [{"type": "video", "video": file_path}]
                    })
                elif mime_type.startswith("audio"):
                    messages.append({
                        "role": item['role'],
                        "content": [{"type": "audio", "audio": file_path}]
                    })
        return messages
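    # Run one generation pass: build the chat prompt, gather the multimodal
    # inputs, then yield the text reply followed by the synthesized speech.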
    def predict(messages, voice=DEFAULT_VOICE):
        print('predict history: ', messages)

        text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        audios, images, videos = process_mm_info(messages, use_audio_in_video=True)
        inputs = processor(text=text, audios=audios, images=images, videos=videos,
                           return_tensors="pt", padding=True)
        inputs = inputs.to(model.device).to(model.dtype)

        text_ids, audio = model.generate(**inputs, spk=voice, use_audio_in_video=True)
        response = processor.batch_decode(text_ids, skip_special_tokens=True,
                                          clean_up_tokenization_spaces=False)
        response = response[0].split("\n")[-1]
        yield {"type": "text", "data": response}

        # Move the waveform tensor to the CPU and scale it to 16-bit PCM.
        audio = (audio.reshape(-1).detach().cpu().numpy() * 32767).astype(np.int16)
        wav_io = io.BytesIO()
        sf.write(wav_io, audio, samplerate=24000, format="WAV")
        wav_io.seek(0)
        wav_bytes = wav_io.getvalue()
        audio_path = processing_utils.save_bytes_to_cache(
            wav_bytes, "audio.wav", cache_dir=demo.GRADIO_CACHE)
        yield {"type": "audio", "data": audio_path}
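    # Streaming handler for the "Online" tab: appends the microphone/webcam
    # recordings to the chat history and streams the model's reply, toggling
    # the submit/stop buttons while generation is in flight.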
    def media_predict(audio, video, history, system_prompt, voice_choice):
        # First yield
        yield (
            None,  # microphone
            None,  # webcam
            history,  # media_chatbot
            gr.update(visible=False),  # submit_btn
            gr.update(visible=True),  # stop_btn
        )

        if video is not None:
            convert_webm_to_mp4(video, video.replace('.webm', '.mp4'))
            video = video.replace(".webm", ".mp4")

        files = [audio, video]
        for f in files:
            if f:
                history.append({"role": "user", "content": (f, )})

        formatted_history = format_history(history=history,
                                           system_prompt=system_prompt)

        history.append({"role": "assistant", "content": ""})

        for chunk in predict(formatted_history, voice_choice):
            if chunk["type"] == "text":
                history[-1]["content"] = chunk["data"]
                yield (
                    None,  # microphone
                    None,  # webcam
                    history,  # media_chatbot
                    gr.update(visible=False),  # submit_btn
                    gr.update(visible=True),  # stop_btn
                )
            if chunk["type"] == "audio":
                history.append({
                    "role": "assistant",
                    "content": gr.Audio(chunk["data"])
                })

        # Final yield
        yield (
            None,  # microphone
            None,  # webcam
            history,  # media_chatbot
            gr.update(visible=True),  # submit_btn
            gr.update(visible=False),  # stop_btn
        )
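    # Streaming handler for the "Offline" tab: collects the text box plus any
    # uploaded audio/image/video, clears the inputs, and streams the reply.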
    def chat_predict(text, audio, image, video, history, system_prompt, voice_choice):
        # Process text input
        if text:
            history.append({"role": "user", "content": text})

        # Process audio input
        if audio:
            history.append({"role": "user", "content": (audio, )})

        # Process image input
        if image:
            history.append({"role": "user", "content": (image, )})

        # Process video input
        if video:
            history.append({"role": "user", "content": (video, )})

        formatted_history = format_history(history=history,
                                           system_prompt=system_prompt)

        yield None, None, None, None, history

        history.append({"role": "assistant", "content": ""})

        for chunk in predict(formatted_history, voice_choice):
            if chunk["type"] == "text":
                history[-1]["content"] = chunk["data"]
                yield gr.skip(), gr.skip(), gr.skip(), gr.skip(), history
            if chunk["type"] == "audio":
                history.append({
                    "role": "assistant",
                    "content": gr.Audio(chunk["data"])
                })

        yield gr.skip(), gr.skip(), gr.skip(), gr.skip(), history
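    # Assemble the UI: a collapsed sidebar holding the system prompt, a header
    # with usage instructions, a voice selector, and the two interaction tabs.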
    with gr.Blocks() as demo, ms.Application(), antd.ConfigProvider():
        with gr.Sidebar(open=False):
            system_prompt_textbox = gr.Textbox(label="System Prompt",
                                               value=default_system_prompt)

        with antd.Flex(gap="small", justify="center", align="center"):
            with antd.Flex(vertical=True, gap="small", align="center"):
                antd.Typography.Title("Qwen2.5-Omni Demo",
                                      level=1,
                                      elem_style=dict(margin=0, fontSize=28))
                with antd.Flex(vertical=True, gap="small"):
                    antd.Typography.Text(get_text("🎯 Instructions for use:",
                                                  "🎯 使用说明:"),
                                         strong=True)
                    antd.Typography.Text(
                        get_text(
                            "1️⃣ Click the Audio Record button or the Camera Record button.",
                            "1️⃣ 点击音频录制按钮,或摄像头-录制按钮"))
                    antd.Typography.Text(
                        get_text("2️⃣ Input audio or video.", "2️⃣ 输入音频或者视频"))
                    antd.Typography.Text(
                        get_text(
                            "3️⃣ Click the submit button and wait for the model's response.",
                            "3️⃣ 点击提交并等待模型的回答"))

        voice_choice = gr.Dropdown(label="Voice Choice",
                                   choices=VOICE_LIST,
                                   value=DEFAULT_VOICE)
        with gr.Tabs():
            with gr.Tab("Online"):
                with gr.Row():
                    with gr.Column(scale=1):
                        microphone = gr.Audio(sources=['microphone'],
                                              type="filepath")
                        webcam = gr.Video(sources=['webcam'],
                                          height=400,
                                          include_audio=True)
                        submit_btn = gr.Button(get_text("Submit", "提交"),
                                               variant="primary")
                        stop_btn = gr.Button(get_text("Stop", "停止"),
                                             visible=False)
                        clear_btn = gr.Button(get_text("Clear History", "清除历史"))
                    with gr.Column(scale=2):
                        media_chatbot = gr.Chatbot(height=650, type="messages")

                def clear_history():
                    return [], gr.update(value=None), gr.update(value=None)

                submit_event = submit_btn.click(fn=media_predict,
                                                inputs=[
                                                    microphone, webcam,
                                                    media_chatbot,
                                                    system_prompt_textbox,
                                                    voice_choice
                                                ],
                                                outputs=[
                                                    microphone, webcam,
                                                    media_chatbot, submit_btn,
                                                    stop_btn
                                                ])
                stop_btn.click(
                    fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
                    inputs=None,
                    outputs=[submit_btn, stop_btn],
                    cancels=[submit_event],
                    queue=False)
                clear_btn.click(fn=clear_history,
                                inputs=None,
                                outputs=[media_chatbot, microphone, webcam])
with gr.Tab("Offline"): | |
chatbot = gr.Chatbot(type="messages", height=650) | |
# Media upload section in one row | |
with gr.Row(equal_height=True): | |
audio_input = gr.Audio(sources=["upload"], | |
type="filepath", | |
label="Upload Audio", | |
elem_classes="media-upload", | |
scale=1) | |
image_input = gr.Image(sources=["upload"], | |
type="filepath", | |
label="Upload Image", | |
elem_classes="media-upload", | |
scale=1) | |
video_input = gr.Video(sources=["upload"], | |
label="Upload Video", | |
elem_classes="media-upload", | |
scale=1) | |
# Text input section | |
text_input = gr.Textbox(show_label=False, | |
placeholder="Enter text here...") | |
# Control buttons | |
with gr.Row(): | |
submit_btn = gr.Button(get_text("Submit", "提交"), | |
variant="primary", | |
size="lg") | |
stop_btn = gr.Button(get_text("Stop", "停止"), | |
visible=False, | |
size="lg") | |
clear_btn = gr.Button(get_text("Clear History", "清除历史"), | |
size="lg") | |
def clear_chat_history(): | |
return [], gr.update(value=None), gr.update( | |
value=None), gr.update(value=None), gr.update(value=None) | |
submit_event = gr.on( | |
triggers=[submit_btn.click, text_input.submit], | |
fn=chat_predict, | |
inputs=[ | |
text_input, audio_input, image_input, video_input, chatbot, | |
system_prompt_textbox, voice_choice | |
], | |
outputs=[ | |
text_input, audio_input, image_input, video_input, chatbot | |
]) | |
stop_btn.click(fn=lambda: | |
(gr.update(visible=True), gr.update(visible=False)), | |
inputs=None, | |
outputs=[submit_btn, stop_btn], | |
cancels=[submit_event], | |
queue=False) | |
clear_btn.click(fn=clear_chat_history, | |
inputs=None, | |
outputs=[ | |
chatbot, text_input, audio_input, image_input, | |
video_input | |
]) | |
                # Add some custom CSS to improve the layout
                gr.HTML("""
                <style>
                    .media-upload {
                        margin: 10px;
                        min-height: 160px;
                        /* Make upload areas equal width */
                        flex: 1;
                        min-width: 0;
                    }
                    .media-upload > .wrap {
                        border: 2px dashed #ccc;
                        border-radius: 8px;
                        padding: 10px;
                        height: 100%;
                    }
                    .media-upload:hover > .wrap {
                        border-color: #666;
                    }
                </style>
                """)
    demo.queue(default_concurrency_limit=100, max_size=100).launch(max_threads=100,
                                                                   ssr_mode=False,
                                                                   share=args.share,
                                                                   inbrowser=args.inbrowser,
                                                                   server_port=args.server_port,
                                                                   server_name=args.server_name)
DEFAULT_CKPT_PATH = "Qwen/Qwen2.5-Omni-7B"
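# Parse the command-line options controlling the checkpoint, device placement,
# attention implementation, link sharing, and server binding.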
def _get_args():
    parser = ArgumentParser()

    parser.add_argument('-c',
                        '--checkpoint-path',
                        type=str,
                        default=DEFAULT_CKPT_PATH,
                        help='Checkpoint name or path, default to %(default)r')
    parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
    parser.add_argument('--flash-attn2',
                        action='store_true',
                        default=False,
                        help='Enable flash_attention_2 when loading the model.')
    parser.add_argument('--share',
                        action='store_true',
                        default=False,
                        help='Create a publicly shareable link for the interface.')
    parser.add_argument('--inbrowser',
                        action='store_true',
                        default=False,
                        help='Automatically launch the interface in a new tab on the default browser.')
    parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
    parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Demo server name.')
    parser.add_argument('--ui-language',
                        type=str,
                        choices=['en', 'zh'],
                        default='en',
                        help='Display language for the UI.')

    args = parser.parse_args()
    return args
if __name__ == "__main__":
    args = _get_args()
    # Force a public share link regardless of the --share flag.
    args.share = True
    model, processor = _load_model_processor(args)
    _launch_demo(args, model, processor)