lxysl committed
Commit fbe8f3a
1 Parent(s): 8301b1c

achieve normal interaction

Files changed (3)
  1. app.py +215 -94
  2. requirements.txt +2 -2
  3. vita/model/vita_arch.py +8 -0
app.py CHANGED
@@ -8,33 +8,42 @@ import re
 import torchaudio
 import io
 import cv2
+import time
 import math
-import spaces
 from numba import jit
+import spaces
 from huggingface_hub import snapshot_download
-
-from vita.constants import DEFAULT_AUDIO_TOKEN, DEFAULT_IMAGE_TOKEN, MAX_IMAGE_LENGTH, MIN_IMAGE_LENGTH, IMAGE_TOKEN_INDEX, AUDIO_TOKEN_INDEX
+from vita.constants import (
+    DEFAULT_AUDIO_TOKEN,
+    DEFAULT_IMAGE_TOKEN,
+    DEFAULT_VIDEO_TOKEN,
+    IGNORE_INDEX,
+    IMAGE_TOKEN_INDEX,
+    MAX_IMAGE_LENGTH,
+    MIN_IMAGE_LENGTH,
+)
 from vita.conversation import conv_templates, SeparatorStyle
-from vita.util.mm_utils import tokenizer_image_token, tokenizer_image_audio_token
+from vita.model.builder import load_pretrained_model
+from vita.util.mm_utils import (
+    KeywordsStoppingCriteria,
+    get_model_name_from_path,
+    tokenizer_image_token,
+    tokenizer_image_audio_token,
+)
+from vita.util.utils import disable_torch_init
 from PIL import Image
 from decord import VideoReader, cpu
-from vita.model.builder import load_pretrained_model
 from vita.model.vita_tts.decoder.llm2tts import llm2TTS
 from vita.model.language_model.vita_qwen2 import VITAQwen2Config, VITAQwen2ForCausalLM
-
+from vita.util.data_utils_video_audio_neg_patch import dynamic_preprocess
+from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoFeatureExtractor
 decoder_topk = 2
 codec_chunk_size = 40
 codec_padding_size = 10
 
-PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘'‛""„‟…‧﹏."
 
-MODEL_NAME = "VITA-MLLM/VITA-1.5"
-model_path = snapshot_download(MODEL_NAME, local_dir="VITA_ckpt")
-tokenizer, model, feature_extractor, context_len = load_pretrained_model(
-    model_path, model_base=None, model_name="VITA-1.5", model_type="qwen2p5_instruct"
-)
-llm_embedding = model.get_input_embeddings().cuda()
-tts = llm2TTS(os.path.join(model_path, 'vita_tts_ckpt/'))
+PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 @jit
 def float_to_int16(audio: np.ndarray) -> np.ndarray:
@@ -42,7 +51,6 @@ def float_to_int16(audio: np.ndarray) -> np.ndarray:
     am = 32767 * 32768 // am
     return np.multiply(audio, am).astype(np.int16)
 
-
 def remove_special_characters(input_str):
     # Remove special tokens
     special_tokens = ['☞', '☟', '☜', '<unk>', '<|im_end|>']
@@ -50,7 +58,6 @@ def remove_special_characters(input_str):
         input_str = input_str.replace(token, '')
     return input_str
 
-
 def replace_equation(sentence):
     special_notations = {
         "sin": " sine ",
@@ -139,7 +146,7 @@ def is_wav(file_path):
     return ext.lower() in wav_extensions
 
 def load_model_embemding(model_path):
-    config_path = os.path.join(model_path, 'origin_config.json')
+    config_path = os.path.join(model_path, 'config.json')
     config = VITAQwen2Config.from_pretrained(config_path)
     model = VITAQwen2ForCausalLM.from_pretrained(model_path, config=config, low_cpu_mem_usage=True)
     embedding = model.get_input_embeddings()
@@ -170,14 +177,26 @@ def convert_webm_to_mp4(input_file, output_file):
         raise
 
 
-def _get_rawvideo_dec(video_path, max_frames=MAX_IMAGE_LENGTH, min_frames=MIN_IMAGE_LENGTH, video_framerate=1, s=None, e=None):
-    if s is None or e is None:
+def _get_rawvideo_dec(
+    video_path,
+    image_processor=None,
+    max_frames=MAX_IMAGE_LENGTH,
+    min_frames=MIN_IMAGE_LENGTH,
+    image_resolution=384,
+    video_framerate=1,
+    s=None,
+    e=None,
+    image_aspect_ratio="pad",
+):
+    # speed up video decode via decord.
+
+    if s is None:
         start_time, end_time = None, None
     else:
         start_time = int(s)
         end_time = int(e)
-        start_time = max(start_time, 0)
-        end_time = max(end_time, 0)
+        start_time = start_time if start_time >= 0.0 else 0.0
+        end_time = end_time if end_time >= 0.0 else 0.0
         if start_time > end_time:
             start_time, end_time = end_time, start_time
         elif start_time == end_time:
@@ -192,21 +211,58 @@ def _get_rawvideo_dec(video_path, max_frames=MAX_IMAGE_LENGTH, min_frames=MIN_IM
     f_start = 0 if start_time is None else int(start_time * fps)
     f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
     num_frames = f_end - f_start + 1
-
     if num_frames > 0:
+        # T x 3 x H x W
         sample_fps = int(video_framerate)
         t_stride = int(round(float(fps) / sample_fps))
-        all_pos = list(range(f_start, f_end + 1, t_stride))
 
+        all_pos = list(range(f_start, f_end + 1, t_stride))
         if len(all_pos) > max_frames:
-            sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
+            sample_pos = [
+                all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)
+            ]
         elif len(all_pos) < min_frames:
-            sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)]
+            sample_pos = [
+                all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)
+            ]
         else:
             sample_pos = all_pos
 
-        patch_images = [Image.fromarray(f).convert("RGB") for f in vreader.get_batch(sample_pos).asnumpy()]
-        return patch_images, len(patch_images)
+        patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
+
+        if image_aspect_ratio == "pad":
+
+            def expand2square(pil_img, background_color):
+                width, height = pil_img.size
+                if width == height:
+                    return pil_img
+                elif width > height:
+                    result = Image.new(pil_img.mode, (width, width), background_color)
+                    result.paste(pil_img, (0, (width - height) // 2))
+                    return result
+                else:
+                    result = Image.new(pil_img.mode, (height, height), background_color)
+                    result.paste(pil_img, ((height - width) // 2, 0))
+                    return result
+
+            patch_images = [
+                expand2square(i, tuple(int(x * 255) for x in image_processor.image_mean))
+                for i in patch_images
+            ]
+            patch_images = [
+                image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
+                for i in patch_images
+            ]
+        else:
+            patch_images = [
+                image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
+                for i in patch_images
+            ]
+
+        patch_images = torch.stack(patch_images)
+        slice_len = patch_images.shape[0]
+
+        return patch_images, slice_len
     else:
         print(f"video path: {video_path} error.")
 
@@ -241,6 +297,27 @@ def _parse_text(text):
 
     return "".join(lines)
 
+MODEL_NAME = "VITA-MLLM/VITA-1.5"
+model_path = snapshot_download(MODEL_NAME, local_dir="VITA_ckpt")
+model_type = "qwen2p5_instruct"
+tokenizer, model, feature_extractor, context_len = load_pretrained_model(
+    model_path, model_base=None, model_name="VITA-1.5", model_type="qwen2p5_instruct"
+)
+model.resize_token_embeddings(len(tokenizer))
+
+vision_tower = model.get_vision_tower()
+if not vision_tower.is_loaded:
+    vision_tower.load_model()
+image_processor = vision_tower.image_processor
+
+audio_encoder = model.get_audio_encoder()
+audio_encoder.to(dtype=torch.float16)
+audio_processor = audio_encoder.audio_processor
+
+model.eval()
+
+tts = llm2TTS(os.path.join(model_path, 'vita_tts_ckpt/'))
+llm_embedding = load_model_embemding(model_path).to(device)
 
 @spaces.GPU
 def predict(_chatbot, task_history):
@@ -258,13 +335,25 @@ def predict(_chatbot, task_history):
     for i, (q, a) in enumerate(task_history):
         if isinstance(q, (tuple, list)):
             if is_image(q[0]):
-                images = [Image.open(q[0]).convert("RGB")]
-                all_visual_tensor.extend(images)
+                image = Image.open(q[0]).convert("RGB")
+                image, p_num = dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=True)
+                assert len(p_num) == 1
+                image_tensor = model.process_images(image, model.config).to(
+                    dtype=model.dtype, device="cuda"
+                )
+                all_visual_tensor.append(image_tensor)
                 input_mode = 'image'
-                qs += DEFAULT_IMAGE_TOKEN * len(images) + '\n'
+                qs += DEFAULT_IMAGE_TOKEN * p_num[0] + '\n'
             elif is_video(q[0]):
-                video_frames, slice_len = _get_rawvideo_dec(q[0])
-                all_visual_tensor.extend(video_frames)
+                video_frames, slice_len = _get_rawvideo_dec(
+                    q[0],
+                    image_processor,
+                    max_frames=MAX_IMAGE_LENGTH,
+                    video_framerate=1,
+                    image_aspect_ratio=getattr(model.config, "image_aspect_ratio", None),
+                )
+                image_tensor = video_frames.half().cuda()
+                all_visual_tensor.append(image_tensor)
                 input_mode = 'video'
                 qs += DEFAULT_IMAGE_TOKEN * slice_len + '\n'
             elif is_wav(q[0]):
@@ -282,66 +371,85 @@ def predict(_chatbot, task_history):
         conv.append_message(conv.roles[0], new_q)
         conv.append_message(conv.roles[1], a)
 
+    if qs:
+        conv.append_message(conv.roles[0], qs)
+        conv.append_message(conv.roles[1], None)
+
     prompt = conv.get_prompt(input_mode)
 
-    if all_audio_path != []:
-        input_ids = tokenizer_image_audio_token(
-            prompt, tokenizer,
-            image_token_index=IMAGE_TOKEN_INDEX,
-            audio_token_index=AUDIO_TOKEN_INDEX
+    if all_audio_path:
+        # Process multiple audio clips and merge them
+        all_audio_features = []
+        all_audio_lengths = []
+        all_audio_for_llm_lens = []
+
+        for audio_path in all_audio_path:
+            audio, audio_for_llm_lens = audio_processor.process(os.path.join(audio_path))
+            all_audio_features.append(audio)
+            all_audio_lengths.append(audio.shape[0])
+            all_audio_for_llm_lens.append(audio_for_llm_lens)
+
+        # Merge audio features
+        combined_audio = torch.cat(all_audio_features, dim=0)
+        combined_audio = torch.unsqueeze(combined_audio, dim=0)
+
+        # Merge length information
+        combined_length = torch.tensor(sum(all_audio_lengths))
+        combined_length = torch.unsqueeze(combined_length, dim=0)
+
+        # Merge lengths for the LLM
+        combined_for_llm_lens = torch.tensor(sum(all_audio_for_llm_lens))
+        combined_for_llm_lens = torch.unsqueeze(combined_for_llm_lens, dim=0)
+
+        audios = dict()
+        audios["audios"] = combined_audio.half().cuda()
+        audios["lengths"] = combined_length.half().cuda()
+        audios["lengths_for_llm"] = combined_for_llm_lens.cuda()
+
+        input_ids = (
+            tokenizer_image_audio_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+            .unsqueeze(0)
+            .cuda()
         )
-        audio_list = []
-        for single_audio_path in all_audio_path:
-            try:
-                audio, original_sr = torchaudio.load(single_audio_path)
-                target_sr = 16000
-                if original_sr != target_sr:
-                    resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
-                    audio = resampler(audio)
-                audio_features = feature_extractor(audio, sampling_rate=target_sr, return_tensors="pt")["input_features"]
-                audio_list.append(audio_features.squeeze(0))
-            except Exception as e:
-                print(f"Error processing {single_audio_path}: {e}")
     else:
-        input_ids = tokenizer_image_token(
-            prompt, tokenizer,
-            image_token_index=IMAGE_TOKEN_INDEX
+        # Handle the empty-audio case
+        audio = torch.zeros(400, 80)
+        audio_length = audio.shape[0]
+        audio_for_llm_lens = 60
+        audio = torch.unsqueeze(audio, dim=0)
+        audio_length = torch.unsqueeze(torch.tensor(audio_length), dim=0)
+        audio_for_llm_lens = torch.unsqueeze(torch.tensor(audio_for_llm_lens), dim=0)
+        audios = dict()
+        audios["audios"] = audio.half().cuda()
+        audios["lengths"] = audio_length.half().cuda()
+        audios["lengths_for_llm"] = audio_for_llm_lens.cuda()
+
+        input_ids = (
+            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+            .unsqueeze(0)
+            .cuda()
         )
-
-    if all_visual_tensor == [] and all_audio_path == []:
-        datapromt = {
-            "prompt_token_ids": input_ids,
-        }
-    elif all_visual_tensor != [] and all_audio_path == []:
-        datapromt = {
-            "prompt_token_ids": input_ids,
-            "multi_modal_data": {
-                "image": all_visual_tensor
-            },
-        }
-    elif all_visual_tensor == [] and all_audio_path != []:
-        datapromt = {
-            "prompt_token_ids": input_ids,
-            "multi_modal_data": {
-                "audio": audio_list
-            },
-        }
+
+    if len(all_visual_tensor) > 0:
+        all_visual_tensor = torch.cat(all_visual_tensor, dim=0)
     else:
-        datapromt = {
-            "prompt_token_ids": input_ids,
-            "multi_modal_data": {
-                "image": all_visual_tensor,
-                "audio": audio_list
-            },
-        }
-
-    print(datapromt)
-
+        all_visual_tensor = torch.zeros((1, 3, 448, 448)).to(dtype=model.dtype, device="cuda")
+    if type(all_visual_tensor) is list:
+        print("all_visual_tensor is a list: ", len(all_visual_tensor))
+    if type(all_visual_tensor) is torch.Tensor:
+        print("all_visual_tensor is a tensor: ", all_visual_tensor.shape)
+    # Set up stopping criteria
+    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+    keywords = [stop_str]
+    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+
+    # Generate text
+    start_time = time.time()
     with torch.inference_mode():
         output_ids = model.generate(
             input_ids,
-            images=all_visual_tensor if all_visual_tensor else None,
-            audios=audio_list if audio_list else None,
+            images=all_visual_tensor,
+            audios=audios,
             do_sample=False,
             temperature=0.01,
             top_p=None,
@@ -350,18 +458,30 @@ def predict(_chatbot, task_history):
             return_dict_in_generate=True,
            max_new_tokens=1024,
            use_cache=True,
+            stopping_criteria=[stopping_criteria],
+            shared_v_pid_stride=None,
         )
+    infer_time = time.time() - start_time
 
+    output_ids = output_ids.sequences
+    input_token_len = input_ids.shape[1]
     outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0]
+
+    outputs = outputs.strip()
+    if outputs.endswith(stop_str):
+        outputs = outputs[: -len(stop_str)]
     outputs = outputs.strip()
+
+    print(f"Generated output: {outputs}")
+    print(f"Time consumed: {infer_time}")
 
     task_history[-1] = (chat_query, outputs)
     remove_special_characters_output = remove_special_characters(outputs)
     _chatbot[-1] = (chat_query, _parse_text(remove_special_characters_output))
-    print("query", chat_query)
-    print("task_history", task_history)
+    print("query",chat_query)
+    print("task_history",task_history)
     print(_chatbot)
-    print("answer: ", outputs)
+    print("answer: ",outputs)
     yield _chatbot
 
 
@@ -393,6 +513,7 @@ def add_video(history, task_history, file):
     new_file_name = file.replace(".webm",".mp4")
     if file.endswith(".webm"):
         convert_webm_to_mp4(file, new_file_name)
+    history = history + [((new_file_name,), None)]
     task_history = task_history + [((new_file_name,), None)]
     return history, task_history
 
@@ -406,10 +527,14 @@ def reset_state(task_history):
 
 @spaces.GPU
 def stream_audio_output(history, task_history):
-    text = task_history[-1][-1]
+    print("stream_audio_output", history, task_history)
+    text = history[-1][-1]
+    print("text", text)
     if not text:
         # import pdb;pdb.set_trace()
-        yield None,None
+        yield None, None
+        return
+
     llm_resounse = replace_equation(remove_special_characters(text))
     #print('tts_text', llm_resounse)
     for idx, text in enumerate(split_into_sentences(llm_resounse)):
@@ -459,24 +584,20 @@ with gr.Blocks(title="VideoMLLM") as demo:
         ),
     )
 
-
    add_text_button.click(add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True).then(
        reset_user_input, [], [query]
    ).then(
-        predict, [chatbot, task_history], [chatbot], show_progress=True
+        predict, [chatbot, task_history], [chatbot], show_progress=True
    ).then(
        stream_audio_output,[chatbot, task_history], [audio_output],
    )
 
-
    video_input.stop_recording(add_video, [chatbot, task_history, video_input], [chatbot, task_history], show_progress=True)
    empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
    addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
 
-
-
    add_audio_button.click(add_audio, [chatbot, task_history,record_btn], [chatbot, task_history], show_progress=True).then(
-        predict, [chatbot, task_history], [chatbot], show_progress=True
+        predict, [chatbot, task_history], [chatbot], show_progress=True
    ).then(
        stream_audio_output,[chatbot, task_history], [audio_output],
    )
requirements.txt CHANGED
@@ -114,14 +114,14 @@ starlette==0.41.3
 sympy==1.13.1
 threadpoolctl==3.5.0
 timm==1.0.15
-tokenizers==0.21.0
+tokenizers==0.20.3
 tomlkit==0.13.2
 torch==2.4.0
 torchaudio==2.4.0
 torchvision==0.19.0
 tqdm==4.67.1
 traitlets==5.14.3
-transformers==4.49.0
+transformers==4.46.3
 triton==3.0.0
 typer==0.15.1
 typing_extensions==4.12.2
vita/model/vita_arch.py CHANGED
@@ -388,6 +388,14 @@ class VITAMetaForCausalLM(ABC):
         v_start_end = []
         cur_image_idx = 0
         cur_audio_idx = 0
+        print("sum1",sum([(cur == IMAGE_TOKEN_INDEX).sum() for cur in input_ids]))
+        print("sum2",sum([(IMAGE_TOKEN_INDEX not in cur) for cur in input_ids]))
+        print("len",len(image_features))
+        if type(image_features) is list:
+            print("image_features is a list: ", len(image_features))
+            print("image_features[0] is a tensor: ", image_features[0].shape)
+        if type(image_features) is torch.Tensor:
+            print("image_features is a tensor: ", image_features.shape)
         assert (
             sum([(cur == IMAGE_TOKEN_INDEX).sum() for cur in input_ids])
             + sum([(IMAGE_TOKEN_INDEX not in cur) for cur in input_ids])