import torch
import os
import argparse
import numpy as np
import copy
import gradio as gr
import re
import torchaudio
import io
import cv2
import time
import math
from numba import jit
import spaces
from huggingface_hub import snapshot_download
from vita.constants import (
    DEFAULT_AUDIO_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_VIDEO_TOKEN,
    IGNORE_INDEX,
    IMAGE_TOKEN_INDEX,
    MAX_IMAGE_LENGTH,
    MIN_IMAGE_LENGTH,
)
from vita.conversation import conv_templates, SeparatorStyle
from vita.model.builder import load_pretrained_model
from vita.util.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    tokenizer_image_token,
    tokenizer_image_audio_token,
)
from vita.util.utils import disable_torch_init
from PIL import Image
from decord import VideoReader, cpu
from vita.model.vita_tts.decoder.llm2tts import llm2TTS
from vita.model.language_model.vita_qwen2 import VITAQwen2Config, VITAQwen2ForCausalLM
from vita.util.data_utils_video_audio_neg_patch import dynamic_preprocess
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoFeatureExtractor

decoder_topk = 2
codec_chunk_size = 40
codec_padding_size = 10

PUNCTUATION = "！？。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

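# Convert a float waveform to int16 PCM; the scale factor is derived from the peak amplitude so clipping is avoided.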
def float_to_int16(audio: np.ndarray) -> np.ndarray:
    am = int(math.ceil(float(np.abs(audio).max())) * 32768)
    am = 32767 * 32768 // am
    return np.multiply(audio, am).astype(np.int16)

def remove_special_characters(input_str):
    # Remove special tokens
    special_tokens = ['☞', '☟', '☜', '<unk>', '<|im_end|>']
    for token in special_tokens:
        input_str = input_str.replace(token, '')
    return input_str

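# Rewrite math notation into words the TTS can read aloud (function names in English, operators and
# powers in Chinese), e.g. "a^2" becomes "a的2次方".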
def replace_equation(sentence):
    special_notations = {
        "sin": " sine ",
        "cos": " cosine ",
        "tan": " tangent ",
        "cot": " cotangent ",
        "sec": " secant ",
        "csc": " cosecant ",
        "log": " logarithm ",
        "exp": "e^",
        "sqrt": "根号 ",
        "abs": "绝对值 ",
    }

    # Multi-character operators are listed before their single-character prefixes so that,
    # e.g., "!=" is replaced as a whole instead of being split by the "=" rule.
    special_operators = {
        '!=': '不等于',
        '>=': '大于等于',
        '<=': '小于等于',
        "=": "等于",
        '>': '大于',
        '<': '小于',
        "+": "加",
        "-": "减",
        "*": "乘",
        "/": "除",
    }

    greek_letters = {
        "α": "alpha ",
        "β": "beta ",
        "γ": "gamma ",
        "δ": "delta ",
        "ε": "epsilon ",
        "ζ": "zeta ",
        "η": "eta ",
        "θ": "theta ",
        "ι": "iota ",
        "κ": "kappa ",
        "λ": "lambda ",
        "μ": "mu ",
        "ν": "nu ",
        "ξ": "xi ",
        "ο": "omicron ",
        "π": "派 ",
        "ρ": "rho ",
        "σ": "sigma ",
        "τ": "tau ",
        "υ": "upsilon ",
        "φ": "phi ",
        "χ": "chi ",
        "ψ": "psi ",
        "ω": "omega ",
    }

    sentence = sentence.replace('**', ' ')
    sentence = re.sub(r'(?<![\d)])-(\d+)', r'负\1', sentence)

    for key in special_notations:
        sentence = sentence.replace(key, special_notations[key])
    for key in special_operators:
        sentence = sentence.replace(key, special_operators[key])
    for key in greek_letters:
        sentence = sentence.replace(key, greek_letters[key])

    sentence = re.sub(r'\(?(\d+)\)?\((\d+)\)', r'\1乘\2', sentence)
    sentence = re.sub(r'\(?(\w+)\)?\^\(?(\w+)\)?', r'\1的\2次方', sentence)
    return sentence

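# Extension-based checks used to route uploaded files to the image / video / audio branches of predict().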
def is_video(file_path):
    video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'}
    _, ext = os.path.splitext(file_path)
    return ext.lower() in video_extensions


def is_image(file_path):
    image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff'}
    _, ext = os.path.splitext(file_path)
    return ext.lower() in image_extensions


def is_wav(file_path):
    wav_extensions = {'.wav'}
    _, ext = os.path.splitext(file_path)
    return ext.lower() in wav_extensions

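# Load only the LLM input embedding table from the checkpoint; stream_audio_output() uses it to embed
# response text for the TTS decoder, and the full model is deleted right after to save memory.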
def load_model_embedding(model_path):
    config_path = os.path.join(model_path, 'config.json')
    config = VITAQwen2Config.from_pretrained(config_path)
    model = VITAQwen2ForCausalLM.from_pretrained(model_path, config=config, low_cpu_mem_usage=True)
    embedding = model.get_input_embeddings()
    del model
    return embedding

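# Split generated text on Chinese/English punctuation so TTS can be streamed sentence by sentence.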
def split_into_sentences(text):
    sentence_endings = re.compile(r'[，。？\n！？、,?.!]')
    sentences = sentence_endings.split(text)
    return [sentence.strip() for sentence in sentences if sentence.strip()]

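# Re-encode a webcam .webm recording to .mp4 with OpenCV at a fixed 20 fps (presumably so the
# decord-based video loader can read it).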
def convert_webm_to_mp4(input_file, output_file):
    try:
        cap = cv2.VideoCapture(input_file)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_file, fourcc, 20.0, (int(cap.get(3)), int(cap.get(4))))
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            out.write(frame)
        cap.release()
        out.release()
    except Exception as e:
        print(f"Error: {e}")
        raise

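# Decode a video with decord, uniformly sample frames at ~video_framerate fps (clamped to
# [min_frames, max_frames]), and preprocess them for the vision tower. Returns (frames, num_frames).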
def _get_rawvideo_dec(
    video_path,
    image_processor=None,
    max_frames=MAX_IMAGE_LENGTH,
    min_frames=MIN_IMAGE_LENGTH,
    image_resolution=384,
    video_framerate=1,
    s=None,
    e=None,
    image_aspect_ratio="pad",
):
    # Speed up video decoding via decord.
    if s is None:
        start_time, end_time = None, None
    else:
        start_time = int(s)
        end_time = int(e)
        start_time = start_time if start_time >= 0.0 else 0.0
        end_time = end_time if end_time >= 0.0 else 0.0
        if start_time > end_time:
            start_time, end_time = end_time, start_time
        elif start_time == end_time:
            end_time = start_time + 1

    if os.path.exists(video_path):
        vreader = VideoReader(video_path, ctx=cpu(0))
    else:
        raise FileNotFoundError

    fps = vreader.get_avg_fps()
    f_start = 0 if start_time is None else int(start_time * fps)
    f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
    num_frames = f_end - f_start + 1
    if num_frames > 0:
        # T x 3 x H x W
        sample_fps = int(video_framerate)
        t_stride = int(round(float(fps) / sample_fps))

        all_pos = list(range(f_start, f_end + 1, t_stride))
        if len(all_pos) > max_frames:
            sample_pos = [
                all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)
            ]
        elif len(all_pos) < min_frames:
            sample_pos = [
                all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)
            ]
        else:
            sample_pos = all_pos

        patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]

        if image_aspect_ratio == "pad":

            def expand2square(pil_img, background_color):
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result

            patch_images = [
                expand2square(i, tuple(int(x * 255) for x in image_processor.image_mean))
                for i in patch_images
            ]
            patch_images = [
                image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                for i in patch_images
            ]
        else:
            patch_images = [
                image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                for i in patch_images
            ]

        patch_images = torch.stack(patch_images)
        slice_len = patch_images.shape[0]
        return patch_images, slice_len
    else:
        print(f"video path: {video_path} error.")

def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        else:
            if i > 0 and count % 2 == 1:
                line = line.replace("`", r"\`")
                line = line.replace("<", "&lt;")
                line = line.replace(">", "&gt;")
                line = line.replace(" ", "&nbsp;")
                line = line.replace("*", "&ast;")
                line = line.replace("_", "&lowbar;")
                line = line.replace("-", "&#45;")
                line = line.replace(".", "&#46;")
                line = line.replace("!", "&#33;")
                line = line.replace("(", "&#40;")
                line = line.replace(")", "&#41;")
                line = line.replace("$", "&#36;")
            lines[i] = "<br>" + line
    return "".join(lines)

MODEL_NAME = "VITA-MLLM/VITA-1.5"
model_path = snapshot_download(MODEL_NAME, local_dir="VITA_ckpt")

model_type = "qwen2p5_instruct"
tokenizer, model, feature_extractor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name="VITA-1.5", model_type="qwen2p5_instruct"
)
model.resize_token_embeddings(len(tokenizer))

vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
image_processor = vision_tower.image_processor

audio_encoder = model.get_audio_encoder()
audio_encoder.to(dtype=torch.float16)
audio_processor = audio_encoder.audio_processor

model.eval()

tts = llm2TTS(os.path.join(model_path, 'vita_tts_ckpt/'))
llm_embedding = load_model_embedding(model_path).to(device)

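# Build a VITA conversation from the full chat history (text, image, video, and audio turns),
# run generation on the multimodal model, and yield the updated chatbot history.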
def predict(_chatbot, task_history):
    chat_query = task_history[-1][0]
    print(task_history)

    conv_mode = "qwen2p5_instruct"
    conv = conv_templates[conv_mode].copy()

    all_audio_path = []
    all_visual_tensor = []

    qs = ''
    input_mode = 'lang'
    for i, (q, a) in enumerate(task_history):
        if isinstance(q, (tuple, list)):
            if is_image(q[0]):
                image = Image.open(q[0]).convert("RGB")
                image, p_num = dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=True)
                assert len(p_num) == 1
                image_tensor = model.process_images(image, model.config).to(
                    dtype=model.dtype, device="cuda"
                )
                all_visual_tensor.append(image_tensor)
                input_mode = 'image'
                qs += DEFAULT_IMAGE_TOKEN * p_num[0] + '\n'
            elif is_video(q[0]):
                video_frames, slice_len = _get_rawvideo_dec(
                    q[0],
                    image_processor,
                    max_frames=MAX_IMAGE_LENGTH,
                    video_framerate=1,
                    image_aspect_ratio=getattr(model.config, "image_aspect_ratio", None),
                )
                image_tensor = video_frames.half().cuda()
                all_visual_tensor.append(image_tensor)
                input_mode = 'video'
                qs += DEFAULT_IMAGE_TOKEN * slice_len + '\n'
            elif is_wav(q[0]):
                if a is not None and a.startswith('☜'):
                    continue
                else:
                    all_audio_path.append(q[0])
                    new_q = qs + DEFAULT_AUDIO_TOKEN
                    qs = ''
                    conv.append_message(conv.roles[0], new_q)
                    conv.append_message(conv.roles[1], a)
        else:
            new_q = qs + q
            qs = ''
            conv.append_message(conv.roles[0], new_q)
            conv.append_message(conv.roles[1], a)

    if qs:
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt(input_mode)
    if all_audio_path:
        # Process every recorded audio clip and merge them into one feature sequence.
        all_audio_features = []
        all_audio_lengths = []
        all_audio_for_llm_lens = []
        for audio_path in all_audio_path:
            audio, audio_for_llm_lens = audio_processor.process(os.path.join(audio_path))
            all_audio_features.append(audio)
            all_audio_lengths.append(audio.shape[0])
            all_audio_for_llm_lens.append(audio_for_llm_lens)

        # Concatenate the audio features.
        combined_audio = torch.cat(all_audio_features, dim=0)
        combined_audio = torch.unsqueeze(combined_audio, dim=0)

        # Combine the feature lengths.
        combined_length = torch.tensor(sum(all_audio_lengths))
        combined_length = torch.unsqueeze(combined_length, dim=0)

        # Combine the lengths as seen by the LLM.
        combined_for_llm_lens = torch.tensor(sum(all_audio_for_llm_lens))
        combined_for_llm_lens = torch.unsqueeze(combined_for_llm_lens, dim=0)

        audios = dict()
        audios["audios"] = combined_audio.half().cuda()
        audios["lengths"] = combined_length.half().cuda()
        audios["lengths_for_llm"] = combined_for_llm_lens.cuda()

        input_ids = (
            tokenizer_image_audio_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .cuda()
        )
    else:
        # No audio input: feed a dummy all-zero audio feature.
        audio = torch.zeros(400, 80)
        audio_length = audio.shape[0]
        audio_for_llm_lens = 60
        audio = torch.unsqueeze(audio, dim=0)
        audio_length = torch.unsqueeze(torch.tensor(audio_length), dim=0)
        audio_for_llm_lens = torch.unsqueeze(torch.tensor(audio_for_llm_lens), dim=0)
        audios = dict()
        audios["audios"] = audio.half().cuda()
        audios["lengths"] = audio_length.half().cuda()
        audios["lengths_for_llm"] = audio_for_llm_lens.cuda()

        input_ids = (
            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .cuda()
        )

    if len(all_visual_tensor) > 0:
        all_visual_tensor = torch.cat(all_visual_tensor, dim=0)
    else:
        all_visual_tensor = torch.zeros((1, 3, 448, 448)).to(dtype=model.dtype, device="cuda")

    if type(all_visual_tensor) is list:
        print("all_visual_tensor is a list: ", len(all_visual_tensor))
    if type(all_visual_tensor) is torch.Tensor:
        print("all_visual_tensor is a tensor: ", all_visual_tensor.shape)
    # Stopping criteria.
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    # Generate the response text.
    start_time = time.time()
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=all_visual_tensor,
            audios=audios,
            do_sample=False,
            temperature=0.01,
            top_p=None,
            num_beams=1,
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=1024,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            shared_v_pid_stride=None,
        )
    infer_time = time.time() - start_time

    output_ids = output_ids.sequences
    input_token_len = input_ids.shape[1]
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0]
    outputs = outputs.strip()
    if outputs.endswith(stop_str):
        outputs = outputs[: -len(stop_str)]
    outputs = outputs.strip()

    print(f"Generated output: {outputs}")
    print(f"Time consumed: {infer_time}")

    task_history[-1] = (chat_query, outputs)
    remove_special_characters_output = remove_special_characters(outputs)
    _chatbot[-1] = (chat_query, _parse_text(remove_special_characters_output))
    print("query", chat_query)
    print("task_history", task_history)
    print(_chatbot)
    print("answer: ", outputs)
    yield _chatbot

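# Gradio callbacks: each one appends the new user input both to the displayed history (chatbot)
# and to the raw task_history that predict() consumes.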
def add_text(history, task_history, text):
    task_text = text
    if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
        task_text = text[:-1]
    history = history + [(_parse_text(text), None)]
    task_history = task_history + [(task_text, None)]
    return history, task_history, ""


def add_file(history, task_history, file):
    history = history + [((file.name,), None)]
    task_history = task_history + [((file.name,), None)]
    return history, task_history


def add_audio(history, task_history, file):
    print(file)
    if file is None:
        return history, task_history
    history = history + [((file,), None)]
    task_history = task_history + [((file,), None)]
    return history, task_history

def add_video(history, task_history, file):
    print(file)
    if file is None:
        return history, task_history
    new_file_name = file.replace(".webm", ".mp4")
    if file.endswith(".webm"):
        convert_webm_to_mp4(file, new_file_name)
    history = history + [((new_file_name,), None)]
    task_history = task_history + [((new_file_name,), None)]
    print("add_video", history, task_history)
    return history, task_history


def reset_user_input():
    return gr.update(value="")


def reset_state(task_history):
    task_history.clear()
    return []

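# Stream TTS for the latest answer: split it into sentences, embed each sentence with the LLM
# embedding table, and yield 24 kHz int16 audio chunks from the codec decoder.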
def stream_audio_output(history, task_history):
    print("stream_audio_output", history, task_history)
    text = history[-1][-1]
    text = text.replace("<br>", "")
    print("text", text)
    if not text:
        # import pdb; pdb.set_trace()
        yield None, None
        return
    llm_response = replace_equation(remove_special_characters(text))
    # print('tts_text', llm_response)
    for idx, text in enumerate(split_into_sentences(llm_response)):
        embeddings = llm_embedding(torch.tensor(tokenizer.encode(text)).cuda())
        for seg in tts.run(embeddings.reshape(-1, 896).unsqueeze(0), decoder_topk,
                           None,
                           codec_chunk_size, codec_padding_size):
            if idx == 0:
                try:
                    split_idx = torch.nonzero(seg.abs() > 0.03, as_tuple=True)[-1][0]
                    seg = seg[:, :, split_idx:]
                except Exception:
                    print('Do not need to split')
            if seg is not None and len(seg) > 0:
                seg = seg.to(torch.float32).cpu().numpy()
                yield 24000, float_to_int16(seg).T

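# Gradio UI: a chatbot plus text / audio / video / file inputs. Submitting text or audio runs
# predict() and then streams the synthesized speech to the output audio player.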
with gr.Blocks(title="VideoMLLM") as demo:
    gr.Markdown("""<center><font size=8>VITA</font></center>""")
    chatbot = gr.Chatbot(label='VITA', elem_classes="control-height", height=500)
    query = gr.Textbox(lines=2, label='Text Input')
    task_history = gr.State([])

    with gr.Row():
        add_text_button = gr.Button("Submit Text (提交文本)")
        add_audio_button = gr.Button("Submit Audio (提交音频)")

    with gr.Row():
        with gr.Column(scale=2):
            addfile_btn = gr.UploadButton("📁 Upload (上传文件[视频,图片])", file_types=["video", "image"])
            video_input = gr.Video(sources=["webcam"], height=400, width=700, container=True, interactive=True, show_download_button=True, label="📹 Video Recording (视频录制)")
        with gr.Column(scale=1):
            empty_bin = gr.Button("🧹 Clear History (清除历史)")
            record_btn = gr.Audio(sources=["microphone", "upload"], type="filepath", label="🎤 Record or Upload Audio (录音或上传音频)", show_download_button=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
            audio_output = gr.Audio(
                label="Output Audio",
                value=None,
                format="wav",
                autoplay=True,
                streaming=True,
                interactive=False,
                show_label=True,
                waveform_options=gr.WaveformOptions(
                    sample_rate=24000,
                ),
            )

    add_text_button.click(add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True).then(
        reset_user_input, [], [query]
    ).then(
        predict, [chatbot, task_history], [chatbot], show_progress=True
    ).then(
        stream_audio_output, [chatbot, task_history], [audio_output],
    )

    video_input.stop_recording(add_video, [chatbot, task_history, video_input], [chatbot, task_history], show_progress=True)
    empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
    addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

    add_audio_button.click(add_audio, [chatbot, task_history, record_btn], [chatbot, task_history], show_progress=True).then(
        predict, [chatbot, task_history], [chatbot], show_progress=True
    ).then(
        stream_audio_output, [chatbot, task_history], [audio_output],
    )

demo.launch()