import torch import os import argparse import numpy as np import copy import gradio as gr import re import torchaudio import io import cv2 import time import math from numba import jit import spaces from huggingface_hub import snapshot_download from vita.constants import ( DEFAULT_AUDIO_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX, MAX_IMAGE_LENGTH, MIN_IMAGE_LENGTH, ) from vita.conversation import conv_templates, SeparatorStyle from vita.model.builder import load_pretrained_model from vita.util.mm_utils import ( KeywordsStoppingCriteria, get_model_name_from_path, tokenizer_image_token, tokenizer_image_audio_token, ) from vita.util.utils import disable_torch_init from PIL import Image from decord import VideoReader, cpu from vita.model.vita_tts.decoder.llm2tts import llm2TTS from vita.model.language_model.vita_qwen2 import VITAQwen2Config, VITAQwen2ForCausalLM from vita.util.data_utils_video_audio_neg_patch import dynamic_preprocess from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoFeatureExtractor decoder_topk = 2 codec_chunk_size = 40 codec_padding_size = 10 PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏." device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @jit def float_to_int16(audio: np.ndarray) -> np.ndarray: am = int(math.ceil(float(np.abs(audio).max())) * 32768) am = 32767 * 32768 // am return np.multiply(audio, am).astype(np.int16) def remove_special_characters(input_str): # Remove special tokens special_tokens = ['☞', '☟', '☜', '', '<|im_end|>'] for token in special_tokens: input_str = input_str.replace(token, '') return input_str def replace_equation(sentence): special_notations = { "sin": " sine ", "cos": " cosine ", "tan": " tangent ", "cot": " cotangent ", "sec": " secant ", "csc": " cosecant ", "log": " logarithm ", "exp": "e^", "sqrt": "根号 ", "abs": "绝对值 ", } special_operators = { "+": "加", "-": "减", "*": "乘", "/": "除", "=": "等于", '!=': '不等于', '>': '大于', '<': '小于', '>=': '大于等于', '<=': '小于等于', } greek_letters = { "α": "alpha ", "β": "beta ", "γ": "gamma ", "δ": "delta ", "ε": "epsilon ", "ζ": "zeta ", "η": "eta ", "θ": "theta ", "ι": "iota ", "κ": "kappa ", "λ": "lambda ", "μ": "mu ", "ν": "nu ", "ξ": "xi ", "ο": "omicron ", "π": "派 ", "ρ": "rho ", "σ": "sigma ", "τ": "tau ", "υ": "upsilon ", "φ": "phi ", "χ": "chi ", "ψ": "psi ", "ω": "omega " } sentence = sentence.replace('**', ' ') sentence = re.sub(r'(?= 0.0 else 0.0 end_time = end_time if end_time >= 0.0 else 0.0 if start_time > end_time: start_time, end_time = end_time, start_time elif start_time == end_time: end_time = start_time + 1 if os.path.exists(video_path): vreader = VideoReader(video_path, ctx=cpu(0)) else: raise FileNotFoundError fps = vreader.get_avg_fps() f_start = 0 if start_time is None else int(start_time * fps) f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1)) num_frames = f_end - f_start + 1 if num_frames > 0: # T x 3 x H x W sample_fps = int(video_framerate) t_stride = int(round(float(fps) / sample_fps)) all_pos = list(range(f_start, f_end + 1, t_stride)) if len(all_pos) > max_frames: sample_pos = [ all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int) ] elif len(all_pos) < min_frames: sample_pos = [ all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int) ] else: sample_pos = all_pos patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()] if image_aspect_ratio == "pad": def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result patch_images = [ expand2square(i, tuple(int(x * 255) for x in image_processor.image_mean)) for i in patch_images ] patch_images = [ image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0] for i in patch_images ] else: patch_images = [ image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0] for i in patch_images ] patch_images = torch.stack(patch_images) slice_len = patch_images.shape[0] return patch_images, slice_len else: print(f"video path: {video_path} error.") def _parse_text(text): lines = text.split("\n") lines = [line for line in lines if line != ""] count = 0 for i, line in enumerate(lines): if "```" in line: count += 1 items = line.split("`") if count % 2 == 1: lines[i] = f'
'
            else:
                lines[i] = "
" else: if i > 0 and count % 2 == 1: line = line.replace("`", r"\`") line = line.replace("<", "<") line = line.replace(">", ">") line = line.replace(" ", " ") line = line.replace("*", "*") line = line.replace("_", "_") line = line.replace("-", "-") line = line.replace(".", ".") line = line.replace("!", "!") line = line.replace("(", "(") line = line.replace(")", ")") line = line.replace("$", "$") lines[i] = "
" + line return "".join(lines) MODEL_NAME = "VITA-MLLM/VITA-1.5" model_path = snapshot_download(MODEL_NAME, local_dir="VITA_ckpt") model_type = "qwen2p5_instruct" tokenizer, model, feature_extractor, context_len = load_pretrained_model( model_path, model_base=None, model_name="VITA-1.5", model_type="qwen2p5_instruct" ) model.resize_token_embeddings(len(tokenizer)) vision_tower = model.get_vision_tower() if not vision_tower.is_loaded: vision_tower.load_model() image_processor = vision_tower.image_processor audio_encoder = model.get_audio_encoder() audio_encoder.to(dtype=torch.float16) audio_processor = audio_encoder.audio_processor model.eval() tts = llm2TTS(os.path.join(model_path, 'vita_tts_ckpt/')) llm_embedding = load_model_embemding(model_path).to(device) @spaces.GPU def predict(_chatbot, task_history): chat_query = task_history[-1][0] print(task_history) conv_mode = "qwen2p5_instruct" conv = conv_templates[conv_mode].copy() all_audio_path = [] all_visual_tensor = [] qs = '' input_mode = 'lang' for i, (q, a) in enumerate(task_history): if isinstance(q, (tuple, list)): if is_image(q[0]): image = Image.open(q[0]).convert("RGB") image, p_num = dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=True) assert len(p_num) == 1 image_tensor = model.process_images(image, model.config).to( dtype=model.dtype, device="cuda" ) all_visual_tensor.append(image_tensor) input_mode = 'image' qs += DEFAULT_IMAGE_TOKEN * p_num[0] + '\n' elif is_video(q[0]): video_frames, slice_len = _get_rawvideo_dec( q[0], image_processor, max_frames=MAX_IMAGE_LENGTH, video_framerate=1, image_aspect_ratio=getattr(model.config, "image_aspect_ratio", None), ) image_tensor = video_frames.half().cuda() all_visual_tensor.append(image_tensor) input_mode = 'video' qs += DEFAULT_IMAGE_TOKEN * slice_len + '\n' elif is_wav(q[0]): if a is not None and a.startswith('☜'): continue else: all_audio_path.append(q[0]) new_q = qs + DEFAULT_AUDIO_TOKEN qs = '' conv.append_message(conv.roles[0], new_q) conv.append_message(conv.roles[1], a) else: new_q = qs + q qs = '' conv.append_message(conv.roles[0], new_q) conv.append_message(conv.roles[1], a) if qs: conv.append_message(conv.roles[0], qs) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt(input_mode) if all_audio_path: # 处理多个音频并合并 all_audio_features = [] all_audio_lengths = [] all_audio_for_llm_lens = [] for audio_path in all_audio_path: audio, audio_for_llm_lens = audio_processor.process(os.path.join(audio_path)) all_audio_features.append(audio) all_audio_lengths.append(audio.shape[0]) all_audio_for_llm_lens.append(audio_for_llm_lens) # 合并音频特征 combined_audio = torch.cat(all_audio_features, dim=0) combined_audio = torch.unsqueeze(combined_audio, dim=0) # 合并长度信息 combined_length = torch.tensor(sum(all_audio_lengths)) combined_length = torch.unsqueeze(combined_length, dim=0) # 合并LLM长度 combined_for_llm_lens = torch.tensor(sum(all_audio_for_llm_lens)) combined_for_llm_lens = torch.unsqueeze(combined_for_llm_lens, dim=0) audios = dict() audios["audios"] = combined_audio.half().cuda() audios["lengths"] = combined_length.half().cuda() audios["lengths_for_llm"] = combined_for_llm_lens.cuda() input_ids = ( tokenizer_image_audio_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") .unsqueeze(0) .cuda() ) else: # 空音频处理 audio = torch.zeros(400, 80) audio_length = audio.shape[0] audio_for_llm_lens = 60 audio = torch.unsqueeze(audio, dim=0) audio_length = torch.unsqueeze(torch.tensor(audio_length), dim=0) audio_for_llm_lens = torch.unsqueeze(torch.tensor(audio_for_llm_lens), dim=0) audios = dict() audios["audios"] = audio.half().cuda() audios["lengths"] = audio_length.half().cuda() audios["lengths_for_llm"] = audio_for_llm_lens.cuda() input_ids = ( tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") .unsqueeze(0) .cuda() ) if len(all_visual_tensor) > 0: all_visual_tensor = torch.cat(all_visual_tensor, dim=0) else: all_visual_tensor = torch.zeros((1, 3, 448, 448)).to(dtype=model.dtype, device="cuda") if type(all_visual_tensor) is list: print("all_visual_tensor is a list: ", len(all_visual_tensor)) if type(all_visual_tensor) is torch.Tensor: print("all_visual_tensor is a tensor: ", all_visual_tensor.shape) # 停止条件设置 stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) # 生成文本 start_time = time.time() with torch.inference_mode(): output_ids = model.generate( input_ids, images=all_visual_tensor, audios=audios, do_sample=False, temperature=0.01, top_p=None, num_beams=1, output_scores=True, return_dict_in_generate=True, max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria], shared_v_pid_stride=None, ) infer_time = time.time() - start_time output_ids = output_ids.sequences input_token_len = input_ids.shape[1] outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0] outputs = outputs.strip() if outputs.endswith(stop_str): outputs = outputs[: -len(stop_str)] outputs = outputs.strip() print(f"Generated output: {outputs}") print(f"Time consumed: {infer_time}") task_history[-1] = (chat_query, outputs) remove_special_characters_output = remove_special_characters(outputs) _chatbot[-1] = (chat_query, _parse_text(remove_special_characters_output)) print("query",chat_query) print("task_history",task_history) print(_chatbot) print("answer: ",outputs) yield _chatbot def add_text(history, task_history, text): task_text = text if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION: task_text = text[:-1] history = history + [(_parse_text(text), None)] task_history = task_history + [(task_text, None)] return history, task_history, "" def add_file(history, task_history, file): history = history + [((file.name,), None)] task_history = task_history + [((file.name,), None)] return history, task_history def add_audio(history, task_history, file): print(file) if file is None: return history, task_history history = history + [((file,), None)] task_history = task_history + [((file,), None)] return history, task_history def add_video(history, task_history, file): print(file) if file is None: return history, task_history new_file_name = file.replace(".webm",".mp4") if file.endswith(".webm"): convert_webm_to_mp4(file, new_file_name) history = history + [((new_file_name,), None)] task_history = task_history + [((new_file_name,), None)] print("add_video", history, task_history) return history, task_history def reset_user_input(): return gr.update(value="") def reset_state(task_history): task_history.clear() return [] @spaces.GPU def stream_audio_output(history, task_history): print("stream_audio_output", history, task_history) text = history[-1][-1] text = text.replace("
", "") print("text", text) if not text: # import pdb;pdb.set_trace() yield None, None return llm_resounse = replace_equation(remove_special_characters(text)) #print('tts_text', llm_resounse) for idx, text in enumerate(split_into_sentences(llm_resounse)): embeddings = llm_embedding(torch.tensor(tokenizer.encode(text)).cuda()) for seg in tts.run(embeddings.reshape(-1, 896).unsqueeze(0), decoder_topk, None, codec_chunk_size, codec_padding_size): if idx == 0: try: split_idx = torch.nonzero(seg.abs() > 0.03, as_tuple=True)[-1][0] seg = seg[:, :, split_idx:] except: print('Do not need to split') pass if seg is not None and len(seg) > 0: seg = seg.to(torch.float32).cpu().numpy() yield 24000, float_to_int16(seg).T with gr.Blocks(title="VideoMLLM") as demo: gr.Markdown("""
VITA
""") chatbot = gr.Chatbot(label='VITA', elem_classes="control-height", height=500) query = gr.Textbox(lines=2, label='Text Input') task_history = gr.State([]) with gr.Row(): add_text_button = gr.Button("Submit Text (提交文本)") add_audio_button = gr.Button("Submit Audio (提交音频)") with gr.Row(): with gr.Column(scale=2): addfile_btn = gr.UploadButton("📁 Upload (上传文件[视频,图片])", file_types=["video", "image"]) video_input = gr.Video(sources=[ "webcam"], height=400, width=700, container=True, interactive=True, show_download_button=True, label="📹 Video Recording (视频录制)") with gr.Column(scale=1): empty_bin = gr.Button("🧹 Clear History (清除历史)") record_btn = gr.Audio(sources=[ "microphone","upload"], type="filepath", label="🎤 Record or Upload Audio (录音或上传音频)", show_download_button=True, waveform_options=gr.WaveformOptions(sample_rate=16000)) audio_output = gr.Audio( label="Output Audio", value=None, format= "wav", autoplay=True, streaming=True, interactive=False, show_label=True, waveform_options=gr.WaveformOptions( sample_rate=24000, ), ) add_text_button.click(add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True).then( reset_user_input, [], [query] ).then( predict, [chatbot, task_history], [chatbot], show_progress=True ).then( stream_audio_output,[chatbot, task_history], [audio_output], ) video_input.stop_recording(add_video, [chatbot, task_history, video_input], [chatbot, task_history], show_progress=True) empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True) addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True) add_audio_button.click(add_audio, [chatbot, task_history,record_btn], [chatbot, task_history], show_progress=True).then( predict, [chatbot, task_history], [chatbot], show_progress=True ).then( stream_audio_output,[chatbot, task_history], [audio_output], ) demo.launch()