chenjoya committed
Commit b9706c3 · 1 Parent(s): 64d3228
Files changed (1)
  1. demo/infer.py +54 -48
demo/infer.py CHANGED
@@ -156,57 +156,63 @@ class LiveCCDemoInfer:
             state['past_ids'] = outputs.sequences[:, :-1]
             yield (start_timestamp, stop_timestamp), self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True), state
 
+    @torch.inference_mode()
     def video_qa(
-        model,
-        processor,
-        video_path: str,
-        query: str,
-        answer_prefix: str = '',
-        video_start: float = None,
-        video_end: float = None,
-        strict_fps: bool = False,
-        strict_abcd_ids: list[int] = None,
-        do_sample: bool = False,
-        max_new_tokens: int = 128
-    ):
-        if strict_fps:
-            video_inputs, _ = _read_video_decord_plus({'video': video_path, 'video_start': video_start, 'video_end': video_end}, strict_fps=True, drop_last=False)
-            video_inputs = _spatial_resize_video(video_inputs)
-            conversation = [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "video", "video": video_inputs},
-                        {"type": "text", "text": query},
-                    ],
-                }
-            ]
-            image_inputs = None
+        self,
+        query: str,
+        state: dict,
+        default_query: str = 'Please describe the video.',
+        do_sample: bool = False,
+        repetition_penalty: float = 1.05,
+        **kwargs,
+    ):
+        """
+        state: dict, (maybe) with keys:
+            video_path: str, video path
+            video_timestamp: float, current video timestamp
+            last_timestamp: float, last processed video timestamp
+            last_video_pts_index: int, last processed video frame index
+            video_pts: np.ndarray, video pts
+            last_history: list, last processed history
+        """
+        video_path = state.get('video_path', None)
+        if video_path:
+            message = {
+                "role": "user",
+                "content": [
+                    {"type": "video", "video": video_path},
+                    {"type": "text", "text": query if query else default_query},
+                ],
+            }
+
         else:
-            conversation = [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "video", "video": video_path, "video_start": video_start, "video_end": video_end},
-                        {"type": "text", "text": query},
-                    ],
-                }
-            ]
-            image_inputs, video_inputs = process_vision_info(conversation)
-        text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) + answer_prefix
-        inputs = processor(
-            text=[text],
+            message = {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": query if query else default_query},
+                ],
+            }
+        image_inputs, video_inputs = process_vision_info([message])
+        texts = self.processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True, return_tensors='pt')
+        past_ids = state.get('past_ids', None)
+        if past_ids is not None:
+            texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
+        inputs = self.processor(
+            text=texts,
             images=image_inputs,
             videos=video_inputs,
             return_tensors="pt",
+            return_attention_mask=False
         )
-        print(text)
-        inputs = inputs.to("cuda")
-        if not strict_abcd_ids:
-            generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
-            output_text = processor.decode(generated_ids[0, inputs.input_ids.size(1):], clean_up_tokenization_spaces=False)
-        else:
-            outputs = model.generate(**inputs, do_sample=do_sample, top_p=None, temperature=None, top_k=None, max_new_tokens=1, return_dict_in_generate=True, output_scores=True, repetition_penalty=1)
-            print(outputs.scores[0][0, strict_abcd_ids])
-            output_text = ['A', 'B', 'C', 'D'][outputs.scores[0][0, strict_abcd_ids].argmax()]
-        return output_text
+        inputs.to('cuda')
+        if past_ids is not None:
+            inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
+        outputs = self.model.generate(
+            **inputs, past_key_values=state.get('past_key_values', None),
+            return_dict_in_generate=True, do_sample=do_sample,
+            repetition_penalty=repetition_penalty,
+            max_new_tokens=512,
+        )
+        state['past_key_values'] = outputs.past_key_values
+        state['past_ids'] = outputs.sequences[:, :-1]
+        return self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True), state
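After this commit, video_qa is no longer a standalone helper taking model and processor; it is a stateful method of LiveCCDemoInfer that threads past_ids and past_key_values through an explicit state dict, so a follow-up question resumes from the cached context rather than re-encoding the whole conversation. A minimal usage sketch follows; how LiveCCDemoInfer is constructed is not shown in this hunk, so the constructor call and the video path below are placeholders.

# Sketch: two-turn QA with the new stateful video_qa (constructor args assumed).
infer = LiveCCDemoInfer(...)  # construct as elsewhere in demo/infer.py

state = {'video_path': 'path/to/video.mp4'}  # placeholder video
response, state = infer.video_qa(query='What is happening in the video?', state=state)
print(response)

# The returned state now carries past_ids and past_key_values, so the next
# turn continues from the cached prefix; an empty query falls back to default_query.
response, state = infer.video_qa(query='What happens next?', state=state)
print(response)

Keeping the cache in a caller-owned state dict rather than on the instance also means a single LiveCCDemoInfer can serve several independent sessions, each with its own state.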