chenjoya committed on
Commit
8f86fe7
·
verified ·
1 Parent(s): 5d6f2f7

Update demo/infer.py

Browse files
Files changed (1) hide show
  1. demo/infer.py +6 -6
demo/infer.py CHANGED
@@ -56,7 +56,7 @@ class LiveCCDemoInfer:
56
  self,
57
  message: str,
58
  state: dict,
59
- max_pixels: int = 256 * 28 * 28,
60
  default_query: str = 'Please describe the video.',
61
  do_sample: bool = False,
62
  repetition_penalty: float = 1.05,
@@ -122,20 +122,20 @@ class LiveCCDemoInfer:
122
  # 5. make conversation and send to model
123
  for clip, timestamps in zip(interleave_clips, interleave_timestamps):
124
  start_timestamp, stop_timestamp = timestamps[0].item(), timestamps[-1].item() + self.frame_time_interval
125
- message = {
126
  "role": "user",
127
  "content": [
128
  {"type": "text", "text": f'Time={start_timestamp:.1f}-{stop_timestamp:.1f}s'},
129
  {"type": "video", "video": clip}
130
  ]
131
- }
132
  if not message and not state.get('message', None):
133
  message = default_query
134
  logger.warning(f'No query provided, use default_query={default_query}')
135
  if message and state.get('message', None) != message:
136
- message['content'].append({"type": "text", "text": message})
137
  state['message'] = message
138
- texts = self.processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True, return_tensors='pt')
139
  past_ids = state.get('past_ids', None)
140
  if past_ids is not None:
141
  texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
@@ -146,7 +146,6 @@ class LiveCCDemoInfer:
146
  return_tensors="pt",
147
  return_attention_mask=False
148
  )
149
- print(texts)
150
  inputs.to(self.model.device)
151
  if past_ids is not None:
152
  inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
@@ -159,6 +158,7 @@ class LiveCCDemoInfer:
159
  return_dict_in_generate=True, do_sample=do_sample,
160
  repetition_penalty=repetition_penalty,
161
  logits_processor=logits_processor,
 
162
  )
163
  state['past_key_values'] = outputs.past_key_values
164
  state['past_ids'] = outputs.sequences[:, :-1]
 
56
  self,
57
  message: str,
58
  state: dict,
59
+ max_pixels: int = 384 * 28 * 28,
60
  default_query: str = 'Please describe the video.',
61
  do_sample: bool = False,
62
  repetition_penalty: float = 1.05,
 
122
  # 5. make conversation and send to model
123
  for clip, timestamps in zip(interleave_clips, interleave_timestamps):
124
  start_timestamp, stop_timestamp = timestamps[0].item(), timestamps[-1].item() + self.frame_time_interval
125
+ conversation = [{
126
  "role": "user",
127
  "content": [
128
  {"type": "text", "text": f'Time={start_timestamp:.1f}-{stop_timestamp:.1f}s'},
129
  {"type": "video", "video": clip}
130
  ]
131
+ }]
132
  if not message and not state.get('message', None):
133
  message = default_query
134
  logger.warning(f'No query provided, use default_query={default_query}')
135
  if message and state.get('message', None) != message:
136
+ conversation[0]['content'].append({"type": "text", "text": message})
137
  state['message'] = message
138
+ texts = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, return_tensors='pt')
139
  past_ids = state.get('past_ids', None)
140
  if past_ids is not None:
141
  texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
 
146
  return_tensors="pt",
147
  return_attention_mask=False
148
  )
 
149
  inputs.to(self.model.device)
150
  if past_ids is not None:
151
  inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
 
158
  return_dict_in_generate=True, do_sample=do_sample,
159
  repetition_penalty=repetition_penalty,
160
  logits_processor=logits_processor,
161
+ max_new_tokens=16,
162
  )
163
  state['past_key_values'] = outputs.past_key_values
164
  state['past_ids'] = outputs.sequences[:, :-1]