Hyggge committed
Commit 1bcb6c4 · 1 Parent(s): c99a13d

feat: support batch infer and optimize processor
modeling_valley.py CHANGED
@@ -17,7 +17,7 @@ import numpy as np
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Dict, Any
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2ForCausalLM, Qwen2Model
 
@@ -39,7 +39,7 @@ class ValleyMetaModel:
         else:
             self.vision_tower = build_vision_tower(config, delay_load=False)
         # Build Projector
-        if hasattr(config, "mm_projector_type"):
+        if hasattr(config, "mm_projector_type") and not getattr(config, "only_navit", False):
             self.mm_projector = build_vision_projector(config)
 
     def get_vision_tower(self):
@@ -114,6 +114,15 @@ class ValleyMetaForCausalLM(ABC):
 
         return image_features
 
+    def get_padding_method(self):
+        right_padding = getattr(self, 'right_padding', None)
+        # If the right_padding flag is set, it overrides the training flag.
+        if right_padding is not None:
+            method = 'right' if right_padding else 'left'
+        # Otherwise, the training flag determines the padding method.
+        else:
+            method = 'right' if self.training else 'left'
+        return method
 
     def prepare_inputs_labels_for_multimodal(
         self, input_ids, position_ids, attention_mask, past_key_values, labels, images,
@@ -128,7 +137,6 @@ class ValleyMetaForCausalLM(ABC):
                     dtype=attention_mask.dtype,
                     device=attention_mask.device
                 )), dim=1)
-                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                 return input_ids, position_ids, attention_mask, past_key_values, None, labels
 
         # Step1: Get image embedings
@@ -355,8 +363,7 @@ class ValleyMetaForCausalLM(ABC):
 
         for i, (cur_new_embed, cur_new_labels, cur_attention_mask) in enumerate(zip(new_input_embeds, new_labels, new_attention_mask)):
             cur_len = cur_new_embed.shape[0]
-            # Right padding when inferencing
-            if not self.training and not getattr(self, "right_padding", None):
+            if self.get_padding_method() == 'left':
                 new_input_embeds_padded.append(torch.cat((
                     torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                     cur_new_embed
@@ -366,7 +373,6 @@ class ValleyMetaForCausalLM(ABC):
                 new_attention_mask_padded[i, -cur_len:] = cur_attention_mask
                 position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
 
-            # Left padding while training
             else:
                 new_input_embeds_padded.append(torch.cat((
                     cur_new_embed,
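The two padding hunks above swap the hard-coded `self.training` / `right_padding` checks for the new `get_padding_method()` helper: right padding while training, left padding at inference, with an explicit `right_padding` attribute overriding both. A minimal standalone sketch of that policy (a re-implementation for illustration, not the model class itself):

# Standalone sketch of the padding policy; mirrors get_padding_method above.
class PaddingPolicyDemo:
    def __init__(self, training, right_padding=None):
        self.training = training
        if right_padding is not None:
            self.right_padding = right_padding

    def get_padding_method(self):
        right_padding = getattr(self, "right_padding", None)
        # An explicit right_padding flag overrides the training flag.
        if right_padding is not None:
            return "right" if right_padding else "left"
        # Otherwise: right padding while training, left padding at inference.
        return "right" if self.training else "left"

assert PaddingPolicyDemo(training=True).get_padding_method() == "right"
assert PaddingPolicyDemo(training=False).get_padding_method() == "left"
assert PaddingPolicyDemo(training=False, right_padding=True).get_padding_method() == "right"

Left padding is what makes batched generation line up: every sequence in the batch ends at the same position, so each newly generated token attaches directly after the last real token of its prompt.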
 
@@ -404,6 +410,33 @@ class ValleyQwen2ForCausalLM(Qwen2ForCausalLM, ValleyMetaForCausalLM):
     def get_model(self):
         return self.model
 
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: CausalLMOutputWithPast,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        num_new_tokens: int = 1,
+    ) -> Dict[str, Any]:
+        new_model_kwargs = super()._update_model_kwargs_for_generation(
+            outputs,
+            model_kwargs,
+            is_encoder_decoder,
+            num_new_tokens
+        )
+        """
+        Set model_kwargs["attention_mask"] to the expanded `attention_mask` in
+        the `prepare_inputs_labels_for_multimodal` function to ensure the
+        correctness of the generate behavior when `use_cache` is enabled.
+        """
+        if not is_encoder_decoder:
+            if "attention_mask" in new_model_kwargs:
+                attention_mask = outputs.attention_mask
+                new_model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                )
+        return new_model_kwargs
+
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -481,7 +514,7 @@ class ValleyQwen2ForCausalLM(Qwen2ForCausalLM, ValleyMetaForCausalLM):
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
 
-        return CausalLMOutputWithPast(
+        res = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
@@ -489,6 +522,9 @@ class ValleyQwen2ForCausalLM(Qwen2ForCausalLM, ValleyMetaForCausalLM):
            attentions=outputs.attentions,
        )
 
+        res.attention_mask = attention_mask
+        return res
+
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
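The three hunks above work as a unit: `forward` now stashes the expanded multimodal attention mask on the returned object (`res.attention_mask`), and the overridden `_update_model_kwargs_for_generation` reads it back so that, with `use_cache` enabled, the mask grows by one column per generated token starting from the expanded length rather than the original text length. A small sketch of just that mask bookkeeping, with made-up sizes:

import torch

# Hypothetical sizes: a batch of 2 prompts whose 5 text tokens expand to 12
# positions once image features are spliced in by
# prepare_inputs_labels_for_multimodal.
expanded_mask = torch.ones(2, 12, dtype=torch.long)

# What the override does on each decoding step: append a column of ones so
# the mask stays in sync with the key/value cache length.
for _ in range(3):  # three generated tokens
    expanded_mask = torch.cat(
        [expanded_mask, expanded_mask.new_ones((expanded_mask.shape[0], 1))], dim=-1
    )
print(expanded_mask.shape)  # torch.Size([2, 15])

Without the override, `generate()` would keep extending the original 5-column text mask, which no longer matches the cache once image tokens are inserted.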
preprocessor_config.json CHANGED
@@ -2,25 +2,5 @@
   "processor_class": "ValleyProcessor",
   "auto_map": {
     "AutoProcessor": "processing_valley.ValleyProcessor"
-  },
-  "min_pixels": 1,
-  "qwen2vl_processor_config": {
-    "min_pixels": 3136,
-    "max_pixels": 12845056,
-    "patch_size": 14,
-    "temporal_patch_size": 2,
-    "merge_size": 2,
-    "image_mean": [
-      0.48145466,
-      0.4578275,
-      0.40821073
-    ],
-    "image_std": [
-      0.26862954,
-      0.26130258,
-      0.27577711
-    ],
-    "image_processor_type": "Qwen2VLImageProcessor",
-    "processor_class": "Qwen2VLProcessor"
   }
 }
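With the `qwen2vl_processor_config` block and the stray top-level `min_pixels` removed, `preprocessor_config.json` now only registers the processor class; the pixel bounds become runtime overrides (see the `processing_valley.py` hunk below). Since the diff shows `max_pixels` and `min_pixels` are plain attributes on the wrapped image processor, they can also be set after loading. A sketch, with a placeholder checkpoint path:

from transformers import AutoProcessor

# Placeholder path: any checkpoint that ships this custom processor code.
processor = AutoProcessor.from_pretrained("/path/to/valley-checkpoint", trust_remote_code=True)

# The bounds are ordinary attributes on the wrapped Qwen2VLImageProcessor;
# the values below are the old defaults this commit removed.
processor.qwen2vl_image_processor.max_pixels = 1280 * 28 * 28
processor.qwen2vl_image_processor.min_pixels = 4 * 28 * 28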
processing_valley.py CHANGED
@@ -88,10 +88,15 @@ class ValleyProcessor(ProcessorMixin):
         self.siglip_image_processor = SiglipImageProcessor.from_dict(siglip_processor_config)
         self.qwen2vl_image_processor = Qwen2VLImageProcessor.from_dict(
             qwen2vl_processor_config,
-            max_pixels=kwargs.get("max_pixels", 1280*28*28),
-            min_pixels=kwargs.get("min_pixels", 4*28*28)
         )
-
+
+        max_pixels = kwargs.get("max_pixels", None)
+        min_pixels = kwargs.get("min_pixels", None)
+        if max_pixels:
+            self.qwen2vl_image_processor.max_pixels = max_pixels
+        if min_pixels:
+            self.qwen2vl_image_processor.min_pixels = min_pixels
+
         self.anyres = kwargs.get("anyres", True)
         self.grid_pinpoints = kwargs.get("grid_pinpoints", "(1x1),...,(3x3)")
         self.only_crop_single_image = kwargs.get("only_crop_single_image", True)
@@ -259,7 +264,7 @@ class ValleyProcessor(ProcessorMixin):
         return input_ids
 
 
-    def __call__(self, messages, inference=True) -> BatchFeature:
+    def __call__(self, messages, inference=True, **kwargs) -> BatchFeature:
         # Deal with images
         if "images" not in messages or not messages["images"] or not messages["images"][0]:
             images = [self.black_img]
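Finally, `__call__` now accepts extra keyword arguments, so batch-driving code can pass options through without breaking. A hedged usage sketch: only the `messages["images"]` key and the black-image fallback are visible in this hunk, so the conversation field below is an assumed schema rather than part of the diff:

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("/path/to/valley-checkpoint", trust_remote_code=True)  # placeholder path

messages = {
    "images": [Image.open("example.jpg")],  # empty or missing -> black-image fallback
    "conversations": [{"from": "human", "value": "Describe the image."}],  # assumed schema
}
batch_feature = processor(messages, inference=True)  # returns a BatchFeature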