none committed on
Commit
1b58092
·
1 Parent(s): 8551eaf
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. app.py +65 -18
  2. datasets/Grid/README.md +0 -1
  3. datasets/V2C/README.md +0 -1
  4. datasets/V2C/V2C_Setting2.txt +0 -0
  5. datasets/V2C/V2C_Setting3.txt +0 -0
  6. src/third_party/InternVL/internvl_chat/README.md +635 -0
  7. src/third_party/InternVL/internvl_chat/evaluate.sh +726 -0
  8. src/third_party/InternVL/internvl_chat/internvl/conversation.py +402 -0
  9. src/third_party/InternVL/internvl_chat/internvl/dist_utils.py +104 -0
  10. src/third_party/InternVL/internvl_chat/internvl/model/__init__.py +51 -0
  11. src/third_party/InternVL/internvl_chat/internvl/model/internlm2/configuration_internlm2.py +150 -0
  12. src/third_party/InternVL/internvl_chat/internvl/model/internlm2/modeling_internlm2.py +1429 -0
  13. src/third_party/InternVL/internvl_chat/internvl/model/internlm2/tokenization_internlm2.py +235 -0
  14. src/third_party/InternVL/internvl_chat/internvl/model/internlm2/tokenization_internlm2_fast.py +211 -0
  15. src/third_party/InternVL/internvl_chat/internvl/model/internvl_chat/__init__.py +13 -0
  16. src/third_party/InternVL/internvl_chat/internvl/model/internvl_chat/configuration_intern_vit.py +120 -0
  17. src/third_party/InternVL/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py +109 -0
  18. src/third_party/InternVL/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py +450 -0
  19. src/third_party/InternVL/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py +477 -0
  20. src/third_party/InternVL/internvl_chat/internvl/model/phi3/configuration_phi3.py +211 -0
  21. src/third_party/InternVL/internvl_chat/internvl/model/phi3/modeling_phi3.py +1610 -0
  22. src/third_party/InternVL/internvl_chat/internvl/patch/__init__.py +34 -0
  23. src/third_party/InternVL/internvl_chat/internvl/patch/internlm2_packed_training_patch.py +74 -0
  24. src/third_party/InternVL/internvl_chat/internvl/patch/internvit_liger_monkey_patch.py +13 -0
  25. src/third_party/InternVL/internvl_chat/internvl/patch/llama2_flash_attn_monkey_patch.py +238 -0
  26. src/third_party/InternVL/internvl_chat/internvl/patch/llama_flash_attn_monkey_patch.py +222 -0
  27. src/third_party/InternVL/internvl_chat/internvl/patch/llama_packed_training_patch.py +106 -0
  28. src/third_party/InternVL/internvl_chat/internvl/patch/llama_rmsnorm_monkey_patch.py +23 -0
  29. src/third_party/InternVL/internvl_chat/internvl/patch/pad_data_collator.py +155 -0
  30. src/third_party/InternVL/internvl_chat/internvl/patch/phi3_packed_training_patch.py +105 -0
  31. src/third_party/InternVL/internvl_chat/internvl/patch/qwen2_packed_training_patch.py +106 -0
  32. src/third_party/InternVL/internvl_chat/internvl/patch/train_dataloader_patch.py +53 -0
  33. src/third_party/InternVL/internvl_chat/internvl/patch/train_sampler_patch.py +125 -0
  34. src/third_party/InternVL/internvl_chat/internvl/train/__init__.py +5 -0
  35. src/third_party/InternVL/internvl_chat/internvl/train/constants.py +21 -0
  36. src/third_party/InternVL/internvl_chat/internvl/train/dataset.py +866 -0
  37. src/third_party/InternVL/internvl_chat/internvl/train/dataset_packed.py +634 -0
  38. src/third_party/InternVL/internvl_chat/internvl/train/internvl_chat_dpo.py +1056 -0
  39. src/third_party/InternVL/internvl_chat/internvl/train/internvl_chat_finetune.py +1072 -0
  40. src/third_party/InternVL/internvl_chat/internvl/train/internvl_chat_pretrain.py +1116 -0
  41. src/third_party/InternVL/internvl_chat/internvl/train/trainer_dpo.py +302 -0
  42. src/third_party/InternVL/internvl_chat/pyproject.toml +33 -0
  43. src/third_party/InternVL/internvl_chat/tools/convert_to_int8.py +16 -0
  44. src/third_party/InternVL/internvl_chat/tools/extract_mlp.py +19 -0
  45. src/third_party/InternVL/internvl_chat/tools/extract_video_frames.py +120 -0
  46. src/third_party/InternVL/internvl_chat/tools/extract_vit.py +16 -0
  47. src/third_party/InternVL/internvl_chat/tools/images_stitching.py +79 -0
  48. src/third_party/InternVL/internvl_chat/tools/json2jsonl.py +20 -0
  49. src/third_party/InternVL/internvl_chat/tools/jsonl2jsonl.py +22 -0
  50. src/third_party/InternVL/internvl_chat/tools/merge_lora.py +31 -0
app.py CHANGED
@@ -8,10 +8,12 @@ import soundfile
 import torch
 import torch.nn.functional as F
 import torchaudio
+from huggingface_hub import hf_hub_download
 from moviepy import VideoFileClip
 from pydub import AudioSegment
-from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, AutoTokenizer, pipeline
 
+from src.internvl.eval import load_video
 from src.moviedubber.infer.utils_infer import (
     cfg_strength,
     chunk_text,
@@ -23,8 +25,11 @@ from src.moviedubber.infer_with_mmlm_result import concat_movie_with_audio, get_
 from src.moviedubber.model.utils import convert_char_to_pinyin
 
 
+sys.path.insert(0, "src/third_party")
 sys.path.append("src/third_party/BigVGAN")
 
+from InternVL.internvl_chat.internvl.model.internvl_chat.modeling_internvl_chat import InternVLChatModel  # type: ignore
+
 
 def load_asr_model(model_id="openai/whisper-large-v3-turbo"):
     torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
@@ -45,7 +50,22 @@ def load_asr_model(model_id="openai/whisper-large-v3-turbo"):
     return pipe
 
 
-device = "cpu"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+mmlm_path = hf_hub_download(repo_id="woak-oa/DeepDubber-V1", filename="mmlm")
+
+mmlm = InternVLChatModel.from_pretrained(
+    mmlm_path,
+    torch_dtype=torch.bfloat16,
+    low_cpu_mem_usage=True,
+    use_flash_attn=False,
+)
+mmlm = mmlm.eval().to(device)
+
+tokenizer = AutoTokenizer.from_pretrained(mmlm_path, trust_remote_code=True, use_fast=False)
+
+generation_config = dict(max_new_tokens=1024, do_sample=False)
+
 
 ema_model, vocoder, ort_session = load_models(device=device)
 asr_pipe = load_asr_model()
@@ -54,6 +74,29 @@ videofeature_extractor = VideoFeatureExtractor(device=device)
 
 
 def deepdubber(video_path: str, subtitle_text: str, audio_path: str = None) -> str:
+    pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
+    pixel_values = pixel_values.to(torch.bfloat16).to(device)
+    video_prefix = "".join([f"Frame{i + 1}: <image>\n" for i in range(len(num_patches_list))])
+    question = (
+        video_prefix
+        + "What is the voice-over category for this video? Options: A. dialogue, B. monologue, C. narration."
+    )
+    response = mmlm.chat(
+        tokenizer,
+        pixel_values,
+        question,
+        generation_config,
+        num_patches_list=num_patches_list,
+        history=None,
+        return_history=False,
+    )
+
+    try:
+        response = response.split("<CONCLUSION>")[1].split("</CONCLUSION>")[0].strip()
+    except Exception as e:
+        print(f"Error: {e}, response: {response}")
+        response = response.strip()[0]
+
     print(f"Starting deepdubber with video_path: {video_path} and subtitle_text: {subtitle_text}")
     gen_clip = videofeature_extractor.extract_features(video_path)
     gen_text = subtitle_text
@@ -143,29 +186,29 @@ def deepdubber(video_path: str, subtitle_text: str, audio_path: str = None) -> s
     os.remove(temp_wav_path)
 
     print(f"Deepdubber completed successfully, output path: {concated_video}")
-    return concated_video
+    return response, concated_video
 
 
 def process_video_dubbing(video_path: str, subtitle_text: str, audio_path: str = None) -> str:
-    try:
-        print(f"Processing video: {video_path}")
-        if not os.path.exists(video_path):
-            raise ValueError("Video file does not exist")
+    # try:
+    print(f"Processing video: {video_path}")
+    if not os.path.exists(video_path):
+        raise ValueError("Video file does not exist")
 
-        if not subtitle_text.strip():
-            raise ValueError("Subtitle text cannot be empty")
+    if not subtitle_text.strip():
+        raise ValueError("Subtitle text cannot be empty")
 
-        if audio_path is None:
-            audio_path = "datasets/CoTMovieDubbing/GT.wav"
+    if audio_path is None:
+        audio_path = "datasets/CoTMovieDubbing/GT.wav"
 
-        output_path = deepdubber(video_path, subtitle_text, audio_path)
+    res, output_path = deepdubber(video_path, subtitle_text, audio_path)
 
-        return output_path
+    return res, output_path
 
-    except Exception as e:
-        print(f"Error in process_video_dubbing: {e}")
+    # except Exception as e:
+    #     print(f"Error in process_video_dubbing: {e}")
 
-        return None
+    # return None, None
 
 
 def create_ui():
@@ -181,10 +224,14 @@ def create_ui():
 
     process_btn = gr.Button("Start Dubbing")
 
-    output_video = gr.Video(label="Dubbed Video")
+    with gr.Row():
+        output_response = gr.Textbox(label="Response", placeholder="Response from MMLM", lines=5)
+        output_video = gr.Video(label="Dubbed Video")
 
     process_btn.click(
-        fn=process_video_dubbing, inputs=[video_input, subtitle_input, audio_input], outputs=output_video
+        fn=process_video_dubbing,
+        inputs=[video_input, subtitle_input, audio_input],
+        outputs=[output_response, output_video],
     )
 
     return app
datasets/Grid/README.md DELETED
@@ -1 +0,0 @@
-Refer to: [Grid](https://paperswithcode.com/dataset/grid)

datasets/V2C/README.md DELETED
@@ -1 +0,0 @@
-Refer to: [V2C](https://github.com/chenqi008/V2C)
datasets/V2C/V2C_Setting2.txt DELETED
The diff for this file is too large to render. See raw diff
 
datasets/V2C/V2C_Setting3.txt DELETED
The diff for this file is too large to render. See raw diff
 
src/third_party/InternVL/internvl_chat/README.md ADDED
@@ -0,0 +1,635 @@
1
+ # InternVL-Chat
2
+
3
+ This folder contains the implementation of InternVL-Chat.
4
+
5
+ ## 📖 Documents
6
+
7
+ ### 🌟 **Get Started**
8
+
9
+ - **Installation**: 🌱 [Installation Guide](https://internvl.readthedocs.io/en/latest/get_started/installation.html) | 📄 [requirements.txt](./requirements.txt)
10
+ - **Chat Data Format**: 📝 [Meta File](https://internvl.readthedocs.io/en/latest/get_started/chat_data_format.html#meta-file) | ✏️ [Text](https://internvl.readthedocs.io/en/latest/get_started/chat_data_format.html#pure-text-data) | 🖼️ [Single-Image](https://internvl.readthedocs.io/en/latest/get_started/chat_data_format.html#single-image-data) | 🖼️🖼️ [Multi-Image](https://internvl.readthedocs.io/en/latest/get_started/chat_data_format.html#multi-image-data) | 🎥 [Video](https://internvl.readthedocs.io/en/latest/get_started/chat_data_format.html#video-data)
11
+ - **Local Chat Demo**: 🤖 [Streamlit Demo](https://internvl.readthedocs.io/en/latest/get_started/local_chat_demo.html#streamlit-demo)
12
+ - **InternVL-Chat API**: 🌐 [InternVL2-Pro](https://internvl.readthedocs.io/en/latest/get_started/internvl_chat_api.html#official-api-of-internvl2-pro)
13
+ - **Tutorials**: 🚀 [Enhancing InternVL2 on COCO Caption Using LoRA Fine-Tuning](https://internvl.readthedocs.io/en/latest/tutorials/coco_caption_finetune.html)
14
+
15
+ ### 🏆 **InternVL Family**
16
+
17
+ - **InternVL 2.5**: 📖 [Introduction](https://internvl.readthedocs.io/en/latest/internvl2.5/introduction.html) | ⚡ [Quick Start](https://internvl.readthedocs.io/en/latest/internvl2.5/quick_start.html) | ✨ [Finetune](https://internvl.readthedocs.io/en/latest/internvl2.5/finetune.html) | 📊 [Evaluation](https://internvl.readthedocs.io/en/latest/internvl2.5/evaluation.html) | 📦 [Deployment](https://internvl.readthedocs.io/en/latest/internvl2.5/deployment.html) | 🎯 [Preference Optimization](https://internvl.readthedocs.io/en/latest/internvl2.5/preference_optimization.html)
18
+ - **InternVL 2.0**: 📖 [Introduction](https://internvl.readthedocs.io/en/latest/internvl2.0/introduction.html) | ⚡ [Quick Start](https://internvl.readthedocs.io/en/latest/internvl2.0/quick_start.html) | ✨ [Finetune](https://internvl.readthedocs.io/en/latest/internvl2.0/finetune.html) | 📊 [Evaluation](https://internvl.readthedocs.io/en/latest/internvl2.0/evaluation.html) | 📦 [Deployment](https://internvl.readthedocs.io/en/latest/internvl2.0/deployment.html) | 🎯 [Preference Optimization](https://internvl.readthedocs.io/en/latest/internvl2.0/preference_optimization.html)
19
+ - **InternVL 1.5**: 📖 [Introduction](https://internvl.readthedocs.io/en/latest/internvl1.5/introduction.html) | ⚡ [Quick Start](https://internvl.readthedocs.io/en/latest/internvl1.5/quick_start.html) | ✨ [Finetune](https://internvl.readthedocs.io/en/latest/internvl1.5/finetune.html) | 📊 [Evaluation](https://internvl.readthedocs.io/en/latest/internvl1.5/evaluation.html) | 📦 [Deployment](https://internvl.readthedocs.io/en/latest/internvl1.5/deployment.html)
20
+ - **InternVL 1.2**: 📖 [Introduction](https://internvl.readthedocs.io/en/latest/internvl1.2/introduction.html) | ⚡ [Quick Start](https://internvl.readthedocs.io/en/latest/internvl1.2/quick_start.html) | ✨ [Finetune](https://internvl.readthedocs.io/en/latest/internvl1.2/finetune.html) | 📊 [Evaluation](https://internvl.readthedocs.io/en/latest/internvl1.2/evaluation.html)
21
+ - **InternVL 1.1**: 📖 [Introduction](https://internvl.readthedocs.io/en/latest/internvl1.1/introduction.html) | ⚡ [Quick Start](https://internvl.readthedocs.io/en/latest/internvl1.1/quick_start.html) | 📊 [Evaluation](https://internvl.readthedocs.io/en/latest/internvl1.1/evaluation.html)
22
+
23
+ # Introduction
24
+
25
+ We are excited to introduce **InternVL 2.5**, an advanced multimodal large language model (MLLM) series that builds upon InternVL 2.0, maintaining its core model architecture while introducing significant enhancements in training and testing strategies as well as data quality.
26
+
27
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/5HDAGOQOZvS1EtI107Ac-.png)
28
+
29
+ ## InternVL 2.5 Family
30
+
31
+ In the following table, we provide an overview of the InternVL 2.5 series.
32
+
33
+ | Model Name | Vision Part | Language Part | HF Link |
34
+ | :-------------: | :-------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------: | :---------------------------------------------------------: |
35
+ | InternVL2_5-1B | [InternViT-300M-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-300M-448px-V2_5) | [Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) | [🤗 link](https://huggingface.co/OpenGVLab/InternVL2_5-1B) |
36
+ | InternVL2_5-2B | [InternViT-300M-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-300M-448px-V2_5) | [internlm2_5-1_8b-chat](https://huggingface.co/internlm/internlm2_5-1_8b-chat) | [🤗 link](https://huggingface.co/OpenGVLab/InternVL2_5-2B) |
37
+ | InternVL2_5-4B | [InternViT-300M-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-300M-448px-V2_5) | [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct) | [🤗 link](https://huggingface.co/OpenGVLab/InternVL2_5-4B) |
38
+ | InternVL2_5-8B | [InternViT-300M-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-300M-448px-V2_5) | [internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) | [🤗 link](https://huggingface.co/OpenGVLab/InternVL2_5-8B) |
39
+ | InternVL2_5-26B | [InternViT-6B-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V2_5) | [internlm2_5-20b-chat](https://huggingface.co/internlm/internlm2_5-20b-chat) | [🤗 link](https://huggingface.co/OpenGVLab/InternVL2_5-26B) |
40
+ | InternVL2_5-38B | [InternViT-6B-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V2_5) | [Qwen2.5-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-32B-Instruct) | [🤗 link](https://huggingface.co/OpenGVLab/InternVL2_5-38B) |
41
+ | InternVL2_5-78B | [InternViT-6B-448px-V2_5](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V2_5) | [Qwen2.5-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct) | [🤗 link](https://huggingface.co/OpenGVLab/InternVL2_5-78B) |
42
+
43
+ ## Model Architecture
44
+
45
+ As shown in the following figure, InternVL 2.5 retains the same model architecture as its predecessors, InternVL 1.5 and 2.0, following the "ViT-MLP-LLM" paradigm. In this new version, we integrate a newly incrementally pre-trained InternViT with various pre-trained LLMs, including InternLM 2.5 and Qwen 2.5, using a randomly initialized MLP projector.
46
+
47
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/BiiyXN6NOk0p-3rl3ueyL.png)
48
+
49
+ As in the previous version, we applied a pixel unshuffle operation, reducing the number of visual tokens to one-quarter of the original. Besides, we adopted a similar dynamic resolution strategy as InternVL 1.5, dividing images into tiles of 448×448 pixels. The key difference, starting from InternVL 2.0, is that we additionally introduced support for multi-image and video data.
50
+
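The pixel unshuffle step can be pictured as folding each 2×2 block of visual tokens into the channel dimension, which is how the token count drops to one quarter. The snippet below is a minimal, self-contained illustration of that reshaping under the usual InternViT setup (448×448 tiles, patch size 14); it is a sketch, not code taken from the InternVL repository.

```python
import torch


def pixel_unshuffle(x: torch.Tensor, factor: int = 2) -> torch.Tensor:
    """Fold factor x factor spatial blocks into channels: (B, H, W, C) -> (B, H/f, W/f, C*f*f)."""
    b, h, w, c = x.shape
    x = x.view(b, h // factor, factor, w // factor, factor, c)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(b, h // factor, w // factor, c * factor * factor)


# A 448x448 tile gives a 32x32 grid of patch features (patch size 14);
# unshuffling by 2 leaves 16x16 = 256 visual tokens, one quarter of 1024.
tokens = torch.randn(1, 32, 32, 1024)
print(pixel_unshuffle(tokens).shape)  # torch.Size([1, 16, 16, 4096])
```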
51
+ ## Training Strategy
52
+
53
+ ### Dynamic High-Resolution for Multimodal Data
54
+
55
+ In InternVL 2.0 and 2.5, we extend the dynamic high-resolution training approach, enhancing its capabilities to handle multi-image and video datasets.
56
+
57
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/xoMY6rwRrNxbAGYPNyU8g.png)
58
+
59
+ - For single-image datasets, the total number of tiles `n_max` is allocated to a single image for maximum resolution. Visual tokens are enclosed in `<img>` and `</img>` tags.
60
+
61
+ - For multi-image datasets, the total number of tiles `n_max` is distributed across all images in a sample. Each image is labeled with auxiliary tags like `Image-1` and enclosed in `<img>` and `</img>` tags.
62
+
63
+ - For videos, each frame is resized to 448×448. Frames are labeled with tags like `Frame-1` and enclosed in `<img>` and `</img>` tags, similar to images.
64
+
65
+ ### Single Model Training Pipeline
66
+
67
+ The training pipeline for a single model in InternVL 2.5 is structured across three stages, designed to enhance the model's visual perception and multimodal capabilities.
68
+
69
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/5NduZeCPLgPJTFr0RGTq3.png)
70
+
71
+ - **Stage 1: MLP Warmup.** In this stage, only the MLP projector is trained while the vision encoder and language model are frozen. A dynamic high-resolution training strategy is applied for better performance, despite increased cost. This phase ensures robust cross-modal alignment and prepares the model for stable multimodal training (a minimal freezing sketch follows this list).
72
+
73
+ - **Stage 1.5: ViT Incremental Learning (Optional).** This stage allows incremental training of the vision encoder and MLP projector using the same data as Stage 1. It enhances the encoder’s ability to handle rare domains like multilingual OCR and mathematical charts. Once trained, the encoder can be reused across LLMs without retraining, making this stage optional unless new domains are introduced.
74
+
75
+ - **Stage 2: Full Model Instruction Tuning.** The entire model is trained on high-quality multimodal instruction datasets. Strict data quality controls are enforced to prevent degradation of the LLM, as noisy data can cause issues like repetitive or incorrect outputs. After this stage, the training process is complete.
76
+
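As a rough sketch of what the Stage 1 setup means in code: freeze the ViT and the LLM and leave only the projector trainable. This assumes the composite model exposes `vision_model`, `language_model`, and `mlp1` submodules (the same names used in the device-map example later in this README); the actual training entry points live under `internvl/train/`.

```python
def configure_stage1(model) -> None:
    """Stage 1 (MLP warmup): train only the MLP projector; keep ViT and LLM frozen."""
    for p in model.vision_model.parameters():
        p.requires_grad = False
    for p in model.language_model.parameters():
        p.requires_grad = False
    for p in model.mlp1.parameters():
        p.requires_grad = True  # only the projector receives gradients

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"trainable parameters: {trainable}")
```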
77
+ ### Progressive Scaling Strategy
78
+
79
+ We introduce a progressive scaling strategy to align the vision encoder with LLMs efficiently. This approach trains with smaller LLMs first (e.g., 20B) to optimize foundational visual capabilities and cross-modal alignment before transferring the vision encoder to larger LLMs (e.g., 72B) without retraining. This reuse skips intermediate stages for larger models.
80
+
81
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64006c09330a45b03605bba3/UoNUyS7ctN5pBxNv9KnzH.png)
82
+
83
+ Compared to Qwen2-VL's 1.4 trillion tokens, InternVL2.5-78B uses only 120 billion tokens—less than one-tenth. This strategy minimizes redundancy, maximizes pre-trained component reuse, and enables efficient training for complex vision-language tasks.
84
+
85
+ ### Training Enhancements
86
+
87
+ To improve real-world adaptability and performance, we introduce two key techniques, sketched in code after this list:
88
+
89
+ - **Random JPEG Compression**: Random JPEG compression with quality levels between 75 and 100 is applied as a data augmentation technique. This simulates image degradation from internet sources, enhancing the model's robustness to noisy images.
90
+
91
+ - **Loss Reweighting**: To balance the NTP loss across responses of different lengths, we use a reweighting strategy called **square averaging**. This method balances contributions from responses of varying lengths, mitigating biases toward longer or shorter responses.
92
+
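A minimal sketch of the two techniques, under stated assumptions: the JPEG augmentation simply re-encodes a PIL image at a random quality in [75, 100], and the reweighting shows one plausible reading of "square averaging", namely scaling each response's summed token loss by one over the square root of its token count; the exact weighting used in training is defined in the technical report.

```python
import io
import random

import torch
from PIL import Image


def random_jpeg_compression(img: Image.Image, q_min: int = 75, q_max: int = 100) -> Image.Image:
    """Re-encode the image as JPEG at a random quality to simulate web-style degradation."""
    buf = io.BytesIO()
    img.convert("RGB").save(buf, format="JPEG", quality=random.randint(q_min, q_max))
    buf.seek(0)
    return Image.open(buf)


def reweighted_ntp_loss(token_losses: list[torch.Tensor]) -> torch.Tensor:
    """One reading of 'square averaging': weight each response by 1/sqrt(#response tokens)."""
    weighted = [losses.sum() / losses.numel() ** 0.5 for losses in token_losses]
    return torch.stack(weighted).mean()
```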
93
+ ## Data Organization
94
+
95
+ ### Dataset Configuration
96
+
97
+ In InternVL 2.0 and 2.5, the organization of the training data is controlled by several key parameters to optimize the balance and distribution of datasets during training.
98
+
99
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/2LJe24b1ua3gjI9gDitVl.png)
100
+
101
+ - **Data Augmentation:** JPEG compression is applied conditionally: enabled for image datasets to enhance robustness and disabled for video datasets to maintain consistent frame quality.
102
+
103
+ - **Maximum Tile Number:** The parameter `n_max` controls the maximum tiles per dataset. For example, higher values (24–36) are used for multi-image or high-resolution data, lower values (6–12) for standard images, and 1 for videos.
104
+
105
+ - **Repeat Factor:** The repeat factor `r` adjusts dataset sampling frequency. Values below 1 reduce a dataset's weight, while values above 1 increase it. This ensures balanced training across tasks and prevents overfitting or underfitting. A minimal sampling sketch follows this list.
106
+
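A minimal sketch of how a repeat factor can be turned into an epoch-level sampling list, assuming each dataset is just a range of sample indices and handling fractional parts probabilistically. This is an illustration only, not the sampler used in the training code.

```python
import random


def build_sampling_list(dataset_sizes: dict[str, int], repeat_factors: dict[str, float]) -> list[tuple[str, int]]:
    """Repeat each dataset's indices floor(r) times, plus one extra pass kept with probability frac(r)."""
    samples = []
    for name, size in dataset_sizes.items():
        r = repeat_factors.get(name, 1.0)
        whole, frac = int(r), r - int(r)
        for idx in range(size):
            copies = whole + (1 if random.random() < frac else 0)
            samples.extend([(name, idx)] * copies)
    random.shuffle(samples)
    return samples


# r < 1 down-samples a dataset, r > 1 over-samples it.
epoch = build_sampling_list({"charts": 1000, "captions": 5000}, {"charts": 2.5, "captions": 0.5})
```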
107
+ ### Data Filtering Pipeline
108
+
109
+ During development, we found that LLMs are highly sensitive to data noise, with even small anomalies—like outliers or repetitive data—causing abnormal behavior during inference. Repetitive generation, especially in long-form or CoT reasoning tasks, proved particularly harmful.
110
+
111
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/aka8ZRiKF3ajdyZBnNFZI.png)
112
+
113
+ To address this challenge and support future research, we designed an efficient data filtering pipeline to remove low-quality samples.
114
+
115
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/70l1UxnX-Arn0NoOGwpth.png)
116
+
117
+ The pipeline includes two modules. For **pure-text data**, three key strategies are used (a minimal score-threshold sketch follows the two lists):
118
+
119
+ 1. **LLM-Based Quality Scoring**: Each sample is scored (0–10) using a pre-trained LLM with domain-specific prompts. Samples scoring below a threshold (e.g., 7) are removed to ensure high-quality data.
120
+ 2. **Repetition Detection**: Repetitive samples are flagged using LLM-based prompts and manually reviewed. Samples scoring below a stricter threshold (e.g., 3) are excluded to avoid repetitive patterns.
121
+ 3. **Heuristic Rule-Based Filtering**: Anomalies like abnormal sentence lengths or duplicate lines are detected using rules. Flagged samples undergo manual verification to ensure accuracy before removal.
122
+
123
+ For **multimodal data**, two strategies are used:
124
+
125
+ 1. **Repetition Detection**: Repetitive samples in non-academic datasets are flagged and manually reviewed to prevent pattern loops. High-quality datasets are exempt from this process.
126
+ 2. **Heuristic Rule-Based Filtering**: Similar rules are applied to detect visual anomalies, with flagged data verified manually to maintain integrity.
127
+
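As a minimal sketch of the scoring-and-threshold step, assuming each sample has already been given an LLM quality score and a repetition score on a 0-10 scale; the field names here are placeholders, not the actual pipeline's schema.

```python
def filter_samples(samples: list[dict], quality_threshold: float = 7.0, repetition_threshold: float = 3.0) -> list[dict]:
    """Keep samples whose quality score clears the threshold and that are not flagged as repetitive."""
    kept = []
    for s in samples:
        if s["quality_score"] < quality_threshold:
            continue  # low-quality text
        if s.get("repetition_score", 10.0) < repetition_threshold:
            continue  # flagged as repetitive
        kept.append(s)
    return kept
```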
128
+ ### Training Data
129
+
130
+ As shown in the following figure, from InternVL 1.5 to 2.0 and then to 2.5, the fine-tuning data mixture has undergone iterative improvements in scale, quality, and diversity. For more information about the training data, please refer to our technical report.
131
+
132
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/GaTY9Lde02YzclASMthDa.png)
133
+
134
+ ## Evaluation on Multimodal Capability
135
+
136
+ ### Multimodal Reasoning and Mathematics
137
+
138
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/ihFWMRHbF0lpFTkLqnnj1.png)
139
+
140
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/Nrzq0kjlitjp_jrJCqtwX.png)
141
+
142
+ ### OCR, Chart, and Document Understanding
143
+
144
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/3yCMoLjlbsqY7ZJViGzih.png)
145
+
146
+ ### Multi-Image & Real-World Comprehension
147
+
148
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/DSnalmEyhDVQ9GE0GPCla.png)
149
+
150
+ ### Comprehensive Multimodal & Hallucination Evaluation
151
+
152
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/Z7Raj3TGDiV1H81pDHtoG.png)
153
+
154
+ ### Visual Grounding
155
+
156
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/lPcIrng8MPSg_PM1hpDPt.png)
157
+
158
+ ### Multimodal Multilingual Understanding
159
+
160
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/BPpbAOX36RV8RTnm3j-gs.png)
161
+
162
+ ### Video Understanding
163
+
164
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64006c09330a45b03605bba3/tcwH-i1qc8H16En-7AZ5M.png)
165
+
166
+ ## Evaluation on Language Capability
167
+
168
+ Training InternVL 2.0 models led to a decline in pure language capabilities. InternVL 2.5 addresses this by collecting more high-quality open-source data and filtering out low-quality data, achieving better preservation of pure language performance.
169
+
170
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/mxuSKvSY-kfI8zePpXj6y.png)
171
+
172
+ ## Quick Start
173
+
174
+ We provide example code to run `InternVL2_5-8B` using `transformers`.
175
+
176
+ > Please use transformers>=4.37.2 to ensure the model works normally.
177
+
178
+ ### Model Loading
179
+
180
+ #### 16-bit (bf16 / fp16)
181
+
182
+ ```python
183
+ import torch
184
+ from transformers import AutoTokenizer, AutoModel
185
+ path = "OpenGVLab/InternVL2_5-8B"
186
+ model = AutoModel.from_pretrained(
187
+ path,
188
+ torch_dtype=torch.bfloat16,
189
+ low_cpu_mem_usage=True,
190
+ use_flash_attn=True,
191
+ trust_remote_code=True).eval().cuda()
192
+ ```
193
+
194
+ #### BNB 8-bit Quantization
195
+
196
+ ```python
197
+ import torch
198
+ from transformers import AutoTokenizer, AutoModel
199
+ path = "OpenGVLab/InternVL2_5-8B"
200
+ model = AutoModel.from_pretrained(
201
+ path,
202
+ torch_dtype=torch.bfloat16,
203
+ load_in_8bit=True,
204
+ low_cpu_mem_usage=True,
205
+ use_flash_attn=True,
206
+ trust_remote_code=True).eval()
207
+ ```
208
+
209
+ #### Multiple GPUs
210
+
211
+ The reason for writing the code this way is to avoid errors that occur during multi-GPU inference due to tensors not being on the same device. By ensuring that the first and last layers of the large language model (LLM) are on the same device, we prevent such errors.
212
+
213
+ ```python
214
+ import math
215
+ import torch
216
+ from transformers import AutoTokenizer, AutoModel
217
+
218
+ def split_model(model_name):
219
+ device_map = {}
220
+ world_size = torch.cuda.device_count()
221
+ num_layers = {
222
+ 'InternVL2_5-1B': 24, 'InternVL2_5-2B': 24, 'InternVL2_5-4B': 36, 'InternVL2_5-8B': 32,
223
+ 'InternVL2_5-26B': 48, 'InternVL2_5-38B': 64, 'InternVL2_5-78B': 80}[model_name]
224
+ # Since the first GPU will be used for ViT, treat it as half a GPU.
225
+ num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
226
+ num_layers_per_gpu = [num_layers_per_gpu] * world_size
227
+ num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
228
+ layer_cnt = 0
229
+ for i, num_layer in enumerate(num_layers_per_gpu):
230
+ for j in range(num_layer):
231
+ device_map[f'language_model.model.layers.{layer_cnt}'] = i
232
+ layer_cnt += 1
233
+ device_map['vision_model'] = 0
234
+ device_map['mlp1'] = 0
235
+ device_map['language_model.model.tok_embeddings'] = 0
236
+ device_map['language_model.model.embed_tokens'] = 0
237
+ device_map['language_model.output'] = 0
238
+ device_map['language_model.model.norm'] = 0
239
+ device_map['language_model.lm_head'] = 0
240
+ device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
241
+
242
+ return device_map
243
+
244
+ path = "OpenGVLab/InternVL2_5-8B"
245
+ device_map = split_model('InternVL2_5-8B')
246
+ model = AutoModel.from_pretrained(
247
+ path,
248
+ torch_dtype=torch.bfloat16,
249
+ low_cpu_mem_usage=True,
250
+ use_flash_attn=True,
251
+ trust_remote_code=True,
252
+ device_map=device_map).eval()
253
+ ```
254
+
255
+ ### Inference with Transformers
256
+
257
+ ```python
258
+ import numpy as np
259
+ import torch
260
+ import torchvision.transforms as T
261
+ from decord import VideoReader, cpu
262
+ from PIL import Image
263
+ from torchvision.transforms.functional import InterpolationMode
264
+ from transformers import AutoModel, AutoTokenizer
265
+
266
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
267
+ IMAGENET_STD = (0.229, 0.224, 0.225)
268
+
269
+ def build_transform(input_size):
270
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
271
+ transform = T.Compose([
272
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
273
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
274
+ T.ToTensor(),
275
+ T.Normalize(mean=MEAN, std=STD)
276
+ ])
277
+ return transform
278
+
279
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
280
+ best_ratio_diff = float('inf')
281
+ best_ratio = (1, 1)
282
+ area = width * height
283
+ for ratio in target_ratios:
284
+ target_aspect_ratio = ratio[0] / ratio[1]
285
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
286
+ if ratio_diff < best_ratio_diff:
287
+ best_ratio_diff = ratio_diff
288
+ best_ratio = ratio
289
+ elif ratio_diff == best_ratio_diff:
290
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
291
+ best_ratio = ratio
292
+ return best_ratio
293
+
294
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
295
+ orig_width, orig_height = image.size
296
+ aspect_ratio = orig_width / orig_height
297
+
298
+ # calculate the existing image aspect ratio
299
+ target_ratios = set(
300
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
301
+ i * j <= max_num and i * j >= min_num)
302
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
303
+
304
+ # find the closest aspect ratio to the target
305
+ target_aspect_ratio = find_closest_aspect_ratio(
306
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
307
+
308
+ # calculate the target width and height
309
+ target_width = image_size * target_aspect_ratio[0]
310
+ target_height = image_size * target_aspect_ratio[1]
311
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
312
+
313
+ # resize the image
314
+ resized_img = image.resize((target_width, target_height))
315
+ processed_images = []
316
+ for i in range(blocks):
317
+ box = (
318
+ (i % (target_width // image_size)) * image_size,
319
+ (i // (target_width // image_size)) * image_size,
320
+ ((i % (target_width // image_size)) + 1) * image_size,
321
+ ((i // (target_width // image_size)) + 1) * image_size
322
+ )
323
+ # split the image
324
+ split_img = resized_img.crop(box)
325
+ processed_images.append(split_img)
326
+ assert len(processed_images) == blocks
327
+ if use_thumbnail and len(processed_images) != 1:
328
+ thumbnail_img = image.resize((image_size, image_size))
329
+ processed_images.append(thumbnail_img)
330
+ return processed_images
331
+
332
+ def load_image(image_file, input_size=448, max_num=12):
333
+ image = Image.open(image_file).convert('RGB')
334
+ transform = build_transform(input_size=input_size)
335
+ images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
336
+ pixel_values = [transform(image) for image in images]
337
+ pixel_values = torch.stack(pixel_values)
338
+ return pixel_values
339
+
340
+ # If you want to load a model using multiple GPUs, please refer to the `Multiple GPUs` section.
341
+ path = 'OpenGVLab/InternVL2_5-8B'
342
+ model = AutoModel.from_pretrained(
343
+ path,
344
+ torch_dtype=torch.bfloat16,
345
+ low_cpu_mem_usage=True,
346
+ use_flash_attn=True,
347
+ trust_remote_code=True).eval().cuda()
348
+ tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
349
+
350
+ # set the max number of tiles in `max_num`
351
+ pixel_values = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
352
+ generation_config = dict(max_new_tokens=1024, do_sample=True)
353
+
354
+ # pure-text conversation (纯文本对话)
355
+ question = 'Hello, who are you?'
356
+ response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
357
+ print(f'User: {question}\nAssistant: {response}')
358
+
359
+ question = 'Can you tell me a story?'
360
+ response, history = model.chat(tokenizer, None, question, generation_config, history=history, return_history=True)
361
+ print(f'User: {question}\nAssistant: {response}')
362
+
363
+ # single-image single-round conversation (单图单轮对话)
364
+ question = '<image>\nPlease describe the image shortly.'
365
+ response = model.chat(tokenizer, pixel_values, question, generation_config)
366
+ print(f'User: {question}\nAssistant: {response}')
367
+
368
+ # single-image multi-round conversation (单图多轮对话)
369
+ question = '<image>\nPlease describe the image in detail.'
370
+ response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
371
+ print(f'User: {question}\nAssistant: {response}')
372
+
373
+ question = 'Please write a poem according to the image.'
374
+ response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=history, return_history=True)
375
+ print(f'User: {question}\nAssistant: {response}')
376
+
377
+ # multi-image multi-round conversation, combined images (多图多轮对话,拼接图像)
378
+ pixel_values1 = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
379
+ pixel_values2 = load_image('./examples/image2.jpg', max_num=12).to(torch.bfloat16).cuda()
380
+ pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)
381
+
382
+ question = '<image>\nDescribe the two images in detail.'
383
+ response, history = model.chat(tokenizer, pixel_values, question, generation_config,
384
+ history=None, return_history=True)
385
+ print(f'User: {question}\nAssistant: {response}')
386
+
387
+ question = 'What are the similarities and differences between these two images.'
388
+ response, history = model.chat(tokenizer, pixel_values, question, generation_config,
389
+ history=history, return_history=True)
390
+ print(f'User: {question}\nAssistant: {response}')
391
+
392
+ # multi-image multi-round conversation, separate images (多图多轮对话,独立图像)
393
+ pixel_values1 = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
394
+ pixel_values2 = load_image('./examples/image2.jpg', max_num=12).to(torch.bfloat16).cuda()
395
+ pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)
396
+ num_patches_list = [pixel_values1.size(0), pixel_values2.size(0)]
397
+
398
+ question = 'Image-1: <image>\nImage-2: <image>\nDescribe the two images in detail.'
399
+ response, history = model.chat(tokenizer, pixel_values, question, generation_config,
400
+ num_patches_list=num_patches_list,
401
+ history=None, return_history=True)
402
+ print(f'User: {question}\nAssistant: {response}')
403
+
404
+ question = 'What are the similarities and differences between these two images.'
405
+ response, history = model.chat(tokenizer, pixel_values, question, generation_config,
406
+ num_patches_list=num_patches_list,
407
+ history=history, return_history=True)
408
+ print(f'User: {question}\nAssistant: {response}')
409
+
410
+ # batch inference, single image per sample (单图批处理)
411
+ pixel_values1 = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
412
+ pixel_values2 = load_image('./examples/image2.jpg', max_num=12).to(torch.bfloat16).cuda()
413
+ num_patches_list = [pixel_values1.size(0), pixel_values2.size(0)]
414
+ pixel_values = torch.cat((pixel_values1, pixel_values2), dim=0)
415
+
416
+ questions = ['<image>\nDescribe the image in detail.'] * len(num_patches_list)
417
+ responses = model.batch_chat(tokenizer, pixel_values,
418
+ num_patches_list=num_patches_list,
419
+ questions=questions,
420
+ generation_config=generation_config)
421
+ for question, response in zip(questions, responses):
422
+ print(f'User: {question}\nAssistant: {response}')
423
+
424
+ # video multi-round conversation (视频多轮对话)
425
+ def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
426
+ if bound:
427
+ start, end = bound[0], bound[1]
428
+ else:
429
+ start, end = -100000, 100000
430
+ start_idx = max(first_idx, round(start * fps))
431
+ end_idx = min(round(end * fps), max_frame)
432
+ seg_size = float(end_idx - start_idx) / num_segments
433
+ frame_indices = np.array([
434
+ int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
435
+ for idx in range(num_segments)
436
+ ])
437
+ return frame_indices
438
+
439
+ def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
440
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
441
+ max_frame = len(vr) - 1
442
+ fps = float(vr.get_avg_fps())
443
+
444
+ pixel_values_list, num_patches_list = [], []
445
+ transform = build_transform(input_size=input_size)
446
+ frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
447
+ for frame_index in frame_indices:
448
+ img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB')
449
+ img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
450
+ pixel_values = [transform(tile) for tile in img]
451
+ pixel_values = torch.stack(pixel_values)
452
+ num_patches_list.append(pixel_values.shape[0])
453
+ pixel_values_list.append(pixel_values)
454
+ pixel_values = torch.cat(pixel_values_list)
455
+ return pixel_values, num_patches_list
456
+
457
+ video_path = './examples/red-panda.mp4'
458
+ pixel_values, num_patches_list = load_video(video_path, num_segments=8, max_num=1)
459
+ pixel_values = pixel_values.to(torch.bfloat16).cuda()
460
+ video_prefix = ''.join([f'Frame{i+1}: <image>\n' for i in range(len(num_patches_list))])
461
+ question = video_prefix + 'What is the red panda doing?'
462
+ # Frame1: <image>\nFrame2: <image>\n...\nFrame8: <image>\n{question}
463
+ response, history = model.chat(tokenizer, pixel_values, question, generation_config,
464
+ num_patches_list=num_patches_list, history=None, return_history=True)
465
+ print(f'User: {question}\nAssistant: {response}')
466
+
467
+ question = 'Describe this video in detail.'
468
+ response, history = model.chat(tokenizer, pixel_values, question, generation_config,
469
+ num_patches_list=num_patches_list, history=history, return_history=True)
470
+ print(f'User: {question}\nAssistant: {response}')
471
+ ```
472
+
473
+ #### Streaming Output
474
+
475
+ Besides this method, you can also use the following code to get streamed output.
476
+
477
+ ```python
478
+ from transformers import TextIteratorStreamer
479
+ from threading import Thread
480
+
481
+ # Initialize the streamer
482
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10)
483
+ # Define the generation configuration
484
+ generation_config = dict(max_new_tokens=1024, do_sample=False, streamer=streamer)
485
+ # Start the model chat in a separate thread
486
+ thread = Thread(target=model.chat, kwargs=dict(
487
+ tokenizer=tokenizer, pixel_values=pixel_values, question=question,
488
+ history=None, return_history=False, generation_config=generation_config,
489
+ ))
490
+ thread.start()
491
+
492
+ # Initialize an empty string to store the generated text
493
+ generated_text = ''
494
+ # Loop through the streamer to get the new text as it is generated
495
+ for new_text in streamer:
496
+ if new_text == model.conv_template.sep:
497
+ break
498
+ generated_text += new_text
499
+ print(new_text, end='', flush=True) # Print each new chunk of generated text on the same line
500
+ ```
501
+
502
+ ## Finetune
503
+
504
+ Many repositories now support fine-tuning of the InternVL series models, including [InternVL](https://github.com/OpenGVLab/InternVL), [SWIFT](https://github.com/modelscope/ms-swift), [XTuner](https://github.com/InternLM/xtuner), and others. Please refer to their documentation for more details on fine-tuning.
505
+
506
+ ## Deployment
507
+
508
+ ### LMDeploy
509
+
510
+ LMDeploy is a toolkit for compressing, deploying, and serving LLMs & VLMs.
511
+
512
+ ```sh
513
+ pip install lmdeploy>=0.6.4 --no-deps
514
+ ```
515
+
516
+ LMDeploy abstracts the complex inference process of multi-modal Vision-Language Models (VLM) into an easy-to-use pipeline, similar to the Large Language Model (LLM) inference pipeline.
517
+
518
+ #### A 'Hello, world' Example
519
+
520
+ ```python
521
+ from lmdeploy import pipeline, TurbomindEngineConfig
522
+ from lmdeploy.vl import load_image
523
+
524
+ model = 'OpenGVLab/InternVL2_5-8B'
525
+ image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
526
+ pipe = pipeline(model, backend_config=TurbomindEngineConfig(session_len=8192))
527
+ response = pipe(('describe this image', image))
528
+ print(response.text)
529
+ ```
530
+
531
+ If an `ImportError` occurs while running this example, please install the required dependency packages as prompted.
532
+
533
+ #### Multi-images Inference
534
+
535
+ When dealing with multiple images, you can put them all in one list. Keep in mind that multiple images will lead to a higher number of input tokens, and as a result, the size of the context window typically needs to be increased.
536
+
537
+ ```python
538
+ from lmdeploy import pipeline, TurbomindEngineConfig
539
+ from lmdeploy.vl import load_image
540
+ from lmdeploy.vl.constants import IMAGE_TOKEN
541
+
542
+ model = 'OpenGVLab/InternVL2_5-8B'
543
+ pipe = pipeline(model, backend_config=TurbomindEngineConfig(session_len=8192))
544
+
545
+ image_urls=[
546
+ 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg',
547
+ 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg'
548
+ ]
549
+
550
+ images = [load_image(img_url) for img_url in image_urls]
551
+ # Numbering images improves multi-image conversations
552
+ response = pipe((f'Image-1: {IMAGE_TOKEN}\nImage-2: {IMAGE_TOKEN}\ndescribe these two images', images))
553
+ print(response.text)
554
+ ```
555
+
556
+ #### Batch Prompts Inference
557
+
558
+ Conducting inference with batch prompts is quite straightforward; just place them within a list structure:
559
+
560
+ ```python
561
+ from lmdeploy import pipeline, TurbomindEngineConfig
562
+ from lmdeploy.vl import load_image
563
+
564
+ model = 'OpenGVLab/InternVL2_5-8B'
565
+ pipe = pipeline(model, backend_config=TurbomindEngineConfig(session_len=8192))
566
+
567
+ image_urls=[
568
+ "https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg",
569
+ "https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/det.jpg"
570
+ ]
571
+ prompts = [('describe this image', load_image(img_url)) for img_url in image_urls]
572
+ response = pipe(prompts)
573
+ print(response)
574
+ ```
575
+
576
+ #### Multi-turn Conversation
577
+
578
+ There are two ways to do multi-turn conversations with the pipeline. One is to construct messages in the OpenAI format and use the method introduced above; the other is to use the `pipeline.chat` interface.
579
+
580
+ ```python
581
+ from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
582
+ from lmdeploy.vl import load_image
583
+
584
+ model = 'OpenGVLab/InternVL2_5-8B'
585
+ pipe = pipeline(model, backend_config=TurbomindEngineConfig(session_len=8192))
586
+
587
+ image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg')
588
+ gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.8)
589
+ sess = pipe.chat(('describe this image', image), gen_config=gen_config)
590
+ print(sess.response.text)
591
+ sess = pipe.chat('What is the woman doing?', session=sess, gen_config=gen_config)
592
+ print(sess.response.text)
593
+ ```
594
+
595
+ #### Service
596
+
597
+ LMDeploy's `api_server` enables models to be easily packed into services with a single command. The provided RESTful APIs are compatible with OpenAI's interfaces. Below is an example of service startup:
598
+
599
+ ```shell
600
+ lmdeploy serve api_server OpenGVLab/InternVL2_5-8B --server-port 23333
601
+ ```
602
+
603
+ To use the OpenAI-style interface, you need to install OpenAI:
604
+
605
+ ```shell
606
+ pip install openai
607
+ ```
608
+
609
+ Then, use the code below to make the API call:
610
+
611
+ ```python
612
+ from openai import OpenAI
613
+
614
+ client = OpenAI(api_key='YOUR_API_KEY', base_url='http://0.0.0.0:23333/v1')
615
+ model_name = client.models.list().data[0].id
616
+ response = client.chat.completions.create(
617
+ model=model_name,
618
+ messages=[{
619
+ 'role':
620
+ 'user',
621
+ 'content': [{
622
+ 'type': 'text',
623
+ 'text': 'describe this image',
624
+ }, {
625
+ 'type': 'image_url',
626
+ 'image_url': {
627
+ 'url':
628
+ 'https://modelscope.oss-cn-beijing.aliyuncs.com/resource/tiger.jpeg',
629
+ },
630
+ }],
631
+ }],
632
+ temperature=0.8,
633
+ top_p=0.8)
634
+ print(response)
635
+ ```
src/third_party/InternVL/internvl_chat/evaluate.sh ADDED
@@ -0,0 +1,726 @@
1
+ set -x
2
+
3
+ CHECKPOINT=${1}
4
+ DATASET=${2}
5
+ CHECKPOINT="$(pwd)/${CHECKPOINT}"
6
+ export PYTHONPATH="$(pwd):${PYTHONPATH}"
7
+ echo "CHECKPOINT: ${CHECKPOINT}"
8
+
9
+ MASTER_PORT=${MASTER_PORT:-63669}
10
+ PORT=${PORT:-63665}
11
+ GPUS=${GPUS:-8}
12
+ GPUS_PER_NODE=${GPUS_PER_NODE:-8}
13
+ NODES=$((GPUS / GPUS_PER_NODE))
14
+ export MASTER_PORT=${MASTER_PORT}
15
+ export PORT=${PORT}
16
+
17
+ # Save original arguments
18
+ ARGS=("$@")
19
+
20
+ # Parse options
21
+ while [[ $# -gt 0 ]]; do
22
+ case "$1" in
23
+ --auto)
24
+ GPUS=1
25
+ shift
26
+ ;;
27
+ *)
28
+ shift
29
+ ;;
30
+ esac
31
+ done
32
+ echo "GPUS: ${GPUS}"
33
+
34
+ if [ ${DATASET} == "mme" ]; then
35
+ cd eval/mme/
36
+ DIRNAME=`basename ${CHECKPOINT}`
37
+ python eval.py --checkpoint ${CHECKPOINT} "${ARGS[@]:2}"
38
+ python calculation.py --results_dir ${DIRNAME}
39
+ cd ../../
40
+ fi
41
+
42
+ if [ ${DATASET} == "caption" ]; then
43
+ torchrun \
44
+ --nnodes=1 \
45
+ --node_rank=0 \
46
+ --master_addr=127.0.0.1 \
47
+ --nproc_per_node=${GPUS} \
48
+ --master_port=${MASTER_PORT} \
49
+ eval/caption/evaluate_caption.py --checkpoint ${CHECKPOINT} "${ARGS[@]:2}"
50
+ fi
51
+
52
+ if [ ${DATASET} == "caption-coco" ]; then
53
+ torchrun \
54
+ --nnodes=1 \
55
+ --node_rank=0 \
56
+ --master_addr=127.0.0.1 \
57
+ --nproc_per_node=${GPUS} \
58
+ --master_port=${MASTER_PORT} \
59
+ eval/caption/evaluate_caption.py --checkpoint ${CHECKPOINT} --datasets coco "${ARGS[@]:2}"
60
+ fi
61
+
62
+ if [ ${DATASET} == "caption-flickr30k" ]; then
63
+ torchrun \
64
+ --nnodes=1 \
65
+ --node_rank=0 \
66
+ --master_addr=127.0.0.1 \
67
+ --nproc_per_node=${GPUS} \
68
+ --master_port=${MASTER_PORT} \
69
+ eval/caption/evaluate_caption.py --checkpoint ${CHECKPOINT} --datasets flickr30k "${ARGS[@]:2}"
70
+ fi
71
+
72
+ if [ ${DATASET} == "caption-nocaps" ]; then
73
+ torchrun \
74
+ --nnodes=1 \
75
+ --node_rank=0 \
76
+ --master_addr=127.0.0.1 \
77
+ --nproc_per_node=${GPUS} \
78
+ --master_port=${MASTER_PORT} \
79
+ eval/caption/evaluate_caption.py --checkpoint ${CHECKPOINT} --datasets nocaps "${ARGS[@]:2}"
80
+ fi
81
+
82
+ if [ ${DATASET} == "vqa" ]; then
83
+ torchrun \
84
+ --nnodes=1 \
85
+ --node_rank=0 \
86
+ --master_addr=127.0.0.1 \
87
+ --nproc_per_node=${GPUS} \
88
+ --master_port=${MASTER_PORT} \
89
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} "${ARGS[@]:2}"
90
+ fi
91
+
92
+ if [ ${DATASET} == "vqa-okvqa-val" ]; then
93
+ torchrun \
94
+ --nnodes=1 \
95
+ --node_rank=0 \
96
+ --master_addr=127.0.0.1 \
97
+ --nproc_per_node=${GPUS} \
98
+ --master_port=${MASTER_PORT} \
99
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets okvqa_val "${ARGS[@]:2}"
100
+ fi
101
+
102
+ if [ ${DATASET} == "vqa-textvqa-val" ]; then
103
+ torchrun \
104
+ --nnodes=1 \
105
+ --node_rank=0 \
106
+ --master_addr=127.0.0.1 \
107
+ --nproc_per_node=${GPUS} \
108
+ --master_port=${MASTER_PORT} \
109
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets textvqa_val "${ARGS[@]:2}"
110
+ fi
111
+
112
+ if [ ${DATASET} == "vqa-textvqa-val-ocr" ]; then
113
+ torchrun \
114
+ --nnodes=1 \
115
+ --node_rank=0 \
116
+ --master_addr=127.0.0.1 \
117
+ --nproc_per_node=${GPUS} \
118
+ --master_port=${MASTER_PORT} \
119
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets textvqa_val_ocr "${ARGS[@]:2}"
120
+ fi
121
+
122
+ if [ ${DATASET} == "vqa-vizwiz-val" ]; then
123
+ torchrun \
124
+ --nnodes=1 \
125
+ --node_rank=0 \
126
+ --master_addr=127.0.0.1 \
127
+ --nproc_per_node=${GPUS} \
128
+ --master_port=${MASTER_PORT} \
129
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets vizwiz_val "${ARGS[@]:2}"
130
+ fi
131
+
132
+ if [ ${DATASET} == "vqa-vizwiz-test" ]; then
133
+ torchrun \
134
+ --nnodes=1 \
135
+ --node_rank=0 \
136
+ --master_addr=127.0.0.1 \
137
+ --nproc_per_node=${GPUS} \
138
+ --master_port=${MASTER_PORT} \
139
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets vizwiz_test "${ARGS[@]:2}"
140
+ fi
141
+
142
+ if [ ${DATASET} == "vqa-vqav2-testdev" ]; then
143
+ torchrun \
144
+ --nnodes=1 \
145
+ --node_rank=0 \
146
+ --master_addr=127.0.0.1 \
147
+ --nproc_per_node=${GPUS} \
148
+ --master_port=${MASTER_PORT} \
149
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets vqav2_testdev "${ARGS[@]:2}"
150
+ fi
151
+
152
+ if [ ${DATASET} == "vqa-ai2d-test" ]; then
153
+ torchrun \
154
+ --nnodes=1 \
155
+ --node_rank=0 \
156
+ --master_addr=127.0.0.1 \
157
+ --nproc_per_node=${GPUS} \
158
+ --master_port=${MASTER_PORT} \
159
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets ai2diagram_test "${ARGS[@]:2}"
160
+ fi
161
+
162
+ if [ ${DATASET} == "vqa-vqav2-val" ]; then
163
+ torchrun \
164
+ --nnodes=1 \
165
+ --node_rank=0 \
166
+ --master_addr=127.0.0.1 \
167
+ --nproc_per_node=${GPUS} \
168
+ --master_port=${MASTER_PORT} \
169
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets vqav2_val "${ARGS[@]:2}"
170
+ fi
171
+
172
+ if [ ${DATASET} == "vqa-gqa-testdev" ]; then
173
+ torchrun \
174
+ --nnodes=1 \
175
+ --node_rank=0 \
176
+ --master_addr=127.0.0.1 \
177
+ --nproc_per_node=${GPUS} \
178
+ --master_port=${MASTER_PORT} \
179
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets gqa_testdev_llava "${ARGS[@]:2}"
180
+ fi
181
+
182
+ if [ ${DATASET} == "vqa-docvqa-val" ]; then
183
+ torchrun \
184
+ --nnodes=1 \
185
+ --node_rank=0 \
186
+ --master_addr=127.0.0.1 \
187
+ --nproc_per_node=${GPUS} \
188
+ --master_port=${MASTER_PORT} \
189
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets docvqa_val "${ARGS[@]:2}"
190
+ fi
191
+
192
+ if [ ${DATASET} == "vqa-docvqa-test" ]; then
193
+ torchrun \
194
+ --nnodes=1 \
195
+ --node_rank=0 \
196
+ --master_addr=127.0.0.1 \
197
+ --nproc_per_node=${GPUS} \
198
+ --master_port=${MASTER_PORT} \
199
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets docvqa_test "${ARGS[@]:2}"
200
+ fi
201
+
202
+ if [ ${DATASET} == "vqa-mpdocvqa-val" ]; then
203
+ torchrun \
204
+ --nnodes=1 \
205
+ --node_rank=0 \
206
+ --master_addr=127.0.0.1 \
207
+ --nproc_per_node=${GPUS} \
208
+ --master_port=${MASTER_PORT} \
209
+ eval/mpdocvqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets mpdocvqa_val "${ARGS[@]:2}"
210
+ fi
211
+
212
+ if [ ${DATASET} == "vqa-mpdocvqa-test" ]; then
213
+ torchrun \
214
+ --nnodes=1 \
215
+ --node_rank=0 \
216
+ --master_addr=127.0.0.1 \
217
+ --nproc_per_node=${GPUS} \
218
+ --master_port=${MASTER_PORT} \
219
+ eval/mpdocvqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets mpdocvqa_test "${ARGS[@]:2}"
220
+ fi
221
+
222
+ if [ ${DATASET} == "vqa-chartqa-test" ]; then
223
+ torchrun \
224
+ --nnodes=1 \
225
+ --node_rank=0 \
226
+ --master_addr=127.0.0.1 \
227
+ --nproc_per_node=${GPUS} \
228
+ --master_port=${MASTER_PORT} \
229
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets chartqa_test_human,chartqa_test_augmented "${ARGS[@]:2}"
230
+ fi
231
+
232
+ if [ ${DATASET} == "vqa-infovqa-val" ]; then
233
+ torchrun \
234
+ --nnodes=1 \
235
+ --node_rank=0 \
236
+ --master_addr=127.0.0.1 \
237
+ --nproc_per_node=${GPUS} \
238
+ --master_port=${MASTER_PORT} \
239
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets infographicsvqa_val "${ARGS[@]:2}"
240
+ fi
241
+
242
+ if [ ${DATASET} == "vqa-infovqa-test" ]; then
243
+ torchrun \
244
+ --nnodes=1 \
245
+ --node_rank=0 \
246
+ --master_addr=127.0.0.1 \
247
+ --nproc_per_node=${GPUS} \
248
+ --master_port=${MASTER_PORT} \
249
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets infographicsvqa_test "${ARGS[@]:2}"
250
+ fi
251
+
252
+ if [ ${DATASET} == "vqa-chartqa-test-human" ]; then
253
+ torchrun \
254
+ --nnodes=1 \
255
+ --node_rank=0 \
256
+ --master_addr=127.0.0.1 \
257
+ --nproc_per_node=${GPUS} \
258
+ --master_port=${MASTER_PORT} \
259
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets chartqa_test_human "${ARGS[@]:2}"
260
+ fi
261
+
262
+ if [ ${DATASET} == "vqa-chartqa-test-augmented" ]; then
263
+ torchrun \
264
+ --nnodes=1 \
265
+ --node_rank=0 \
266
+ --master_addr=127.0.0.1 \
267
+ --nproc_per_node=${GPUS} \
268
+ --master_port=${MASTER_PORT} \
269
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets chartqa_test_augmented "${ARGS[@]:2}"
270
+ fi
271
+
272
+ if [ ${DATASET} == "vqa-ocrvqa-val" ]; then
273
+ torchrun \
274
+ --nnodes=1 \
275
+ --node_rank=0 \
276
+ --master_addr=127.0.0.1 \
277
+ --nproc_per_node=${GPUS} \
278
+ --master_port=${MASTER_PORT} \
279
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets ocrvqa_val "${ARGS[@]:2}"
280
+ fi
281
+
282
+ if [ ${DATASET} == "vqa-ocrvqa-test" ]; then
283
+ torchrun \
284
+ --nnodes=1 \
285
+ --node_rank=0 \
286
+ --master_addr=127.0.0.1 \
287
+ --nproc_per_node=${GPUS} \
288
+ --master_port=${MASTER_PORT} \
289
+ eval/vqa/evaluate_vqa.py --checkpoint ${CHECKPOINT} --datasets ocrvqa_test "${ARGS[@]:2}"
290
+ fi
291
+
292
+ if [ ${DATASET} == "refcoco" ]; then
293
+ torchrun \
294
+ --nnodes=1 \
295
+ --node_rank=0 \
296
+ --master_addr=127.0.0.1 \
297
+ --nproc_per_node=${GPUS} \
298
+ --master_port=${MASTER_PORT} \
299
+ eval/refcoco/evaluate_grounding.py --checkpoint ${CHECKPOINT} "${ARGS[@]:2}"
300
+ fi
301
+
302
+ if [ ${DATASET} == "refcoco-val" ]; then
303
+ torchrun \
304
+ --nnodes=1 \
305
+ --node_rank=0 \
306
+ --master_addr=127.0.0.1 \
307
+ --nproc_per_node=${GPUS} \
308
+ --master_port=${MASTER_PORT} \
309
+ eval/refcoco/evaluate_grounding.py --checkpoint ${CHECKPOINT} --datasets refcoco_val "${ARGS[@]:2}"
310
+ fi
311
+
312
+ if [ ${DATASET} == "refcoco-testA" ]; then
313
+ torchrun \
314
+ --nnodes=1 \
315
+ --node_rank=0 \
316
+ --master_addr=127.0.0.1 \
317
+ --nproc_per_node=${GPUS} \
318
+ --master_port=${MASTER_PORT} \
319
+ eval/refcoco/evaluate_grounding.py --checkpoint ${CHECKPOINT} --datasets refcoco_testA "${ARGS[@]:2}"
320
+ fi
321
+
322
+ if [ ${DATASET} == "refcoco-testB" ]; then
323
+ torchrun \
324
+ --nnodes=1 \
325
+ --node_rank=0 \
326
+ --master_addr=127.0.0.1 \
327
+ --nproc_per_node=${GPUS} \
328
+ --master_port=${MASTER_PORT} \
329
+ eval/refcoco/evaluate_grounding.py --checkpoint ${CHECKPOINT} --datasets refcoco_testB "${ARGS[@]:2}"
330
+ fi
331
+
332
+ if [ ${DATASET} == "refcoco+-val" ]; then
333
+ torchrun \
334
+ --nnodes=1 \
335
+ --node_rank=0 \
336
+ --master_addr=127.0.0.1 \
337
+ --nproc_per_node=${GPUS} \
338
+ --master_port=${MASTER_PORT} \
339
+ eval/refcoco/evaluate_grounding.py --checkpoint ${CHECKPOINT} --datasets refcoco+_val "${ARGS[@]:2}"
340
+ fi
341
+
342
+ if [ ${DATASET} == "refcoco+-testA" ]; then
343
+ torchrun \
344
+ --nnodes=1 \
345
+ --node_rank=0 \
346
+ --master_addr=127.0.0.1 \
347
+ --nproc_per_node=${GPUS} \
348
+ --master_port=${MASTER_PORT} \
349
+ eval/refcoco/evaluate_grounding.py --checkpoint ${CHECKPOINT} --datasets refcoco+_testA "${ARGS[@]:2}"
350
+ fi
351
+
352
+ if [ ${DATASET} == "refcoco+-testB" ]; then
353
+ torchrun \
354
+ --nnodes=1 \
355
+ --node_rank=0 \
356
+ --master_addr=127.0.0.1 \
357
+ --nproc_per_node=${GPUS} \
358
+ --master_port=${MASTER_PORT} \
359
+ eval/refcoco/evaluate_grounding.py --checkpoint ${CHECKPOINT} --datasets refcoco+_testB "${ARGS[@]:2}"
360
+ fi
361
+
362
+ if [ ${DATASET} == "refcocog-val" ]; then
363
+ torchrun \
364
+ --nnodes=1 \
365
+ --node_rank=0 \
366
+ --master_addr=127.0.0.1 \
367
+ --nproc_per_node=${GPUS} \
368
+ --master_port=${MASTER_PORT} \
369
+ eval/refcoco/evaluate_grounding.py --checkpoint ${CHECKPOINT} --datasets refcocog_val "${ARGS[@]:2}"
370
+ fi
371
+
372
+ if [ ${DATASET} == "refcocog-test" ]; then
373
+ torchrun \
374
+ --nnodes=1 \
375
+ --node_rank=0 \
376
+ --master_addr=127.0.0.1 \
377
+ --nproc_per_node=${GPUS} \
378
+ --master_port=${MASTER_PORT} \
379
+ eval/refcoco/evaluate_grounding.py --checkpoint ${CHECKPOINT} --datasets refcocog_test "${ARGS[@]:2}"
380
+ fi
381
+
382
+ if [ ${DATASET} == "llava-bench" ]; then
383
+ rm -rf results/llava_bench_results_review.jsonl
384
+ python eval/llava_bench/evaluate_llava_bench.py --checkpoint ${CHECKPOINT} "${ARGS[@]:2}"
385
+ python -u eval/llava_bench/eval_gpt_review_bench.py \
386
+ --question data/llava-bench-in-the-wild/questions.jsonl \
387
+ --context data/llava-bench-in-the-wild/context.jsonl \
388
+ --rule eval/llava_bench/rule.json \
389
+ --answer-list \
390
+ data/llava-bench-in-the-wild/answers_gpt4.jsonl \
391
+ results/llava_bench_results.jsonl \
392
+ --output \
393
+ results/llava_bench_results_review.jsonl
394
+ python -u eval/llava_bench/summarize_gpt_review.py -f results/llava_bench_results_review.jsonl
395
+ fi
396
+
397
+ if [ ${DATASET} == "pope" ]; then
398
+ torchrun \
399
+ --nnodes=1 \
400
+ --node_rank=0 \
401
+ --master_addr=127.0.0.1 \
402
+ --nproc_per_node=${GPUS} \
403
+ --master_port=${MASTER_PORT} \
404
+ eval/pope/evaluate_pope.py --checkpoint ${CHECKPOINT} --datasets pope "${ARGS[@]:2}"
405
+ fi
406
+
407
+ if [ ${DATASET} == "tiny_lvlm" ]; then
408
+ torchrun \
409
+ --nnodes=1 \
410
+ --node_rank=0 \
411
+ --master_addr=127.0.0.1 \
412
+ --nproc_per_node=${GPUS} \
413
+ --master_port=${MASTER_PORT} \
414
+ eval/tiny_lvlm/evaluate_lvlm.py --checkpoint ${CHECKPOINT} --datasets updated_datasets "${ARGS[@]:2}"
415
+ fi
416
+
417
+ if [ ${DATASET} == "mmvet" ]; then
418
+ python eval/mmvet/evaluate_mmvet.py --checkpoint ${CHECKPOINT} --datasets mmvet "${ARGS[@]:2}"
419
+ fi
420
+
421
+ if [ ${DATASET} == "mmvetv2" ]; then
422
+ torchrun \
423
+ --nnodes=1 \
424
+ --node_rank=0 \
425
+ --master_addr=127.0.0.1 \
426
+ --nproc_per_node=${GPUS} \
427
+ --master_port=${MASTER_PORT} \
428
+ eval/mmvetv2/evaluate_mmvet_v2.py --checkpoint ${CHECKPOINT} --datasets mmvet-v2 "${ARGS[@]:2}"
429
+ fi
430
+
431
+ if [ ${DATASET} == "mmbench-dev-en" ]; then
432
+ torchrun \
433
+ --nnodes=1 \
434
+ --node_rank=0 \
435
+ --master_addr=127.0.0.1 \
436
+ --nproc_per_node=${GPUS} \
437
+ --master_port=${MASTER_PORT} \
438
+ eval/mmbench/evaluate_mmbench.py --checkpoint ${CHECKPOINT} --datasets mmbench_dev_20230712 "${ARGS[@]:2}"
439
+ fi
440
+
441
+ if [ ${DATASET} == "mmbench-dev-cn" ]; then
442
+ torchrun \
443
+ --nnodes=1 \
444
+ --node_rank=0 \
445
+ --master_addr=127.0.0.1 \
446
+ --nproc_per_node=${GPUS} \
447
+ --master_port=${MASTER_PORT} \
448
+ eval/mmbench/evaluate_mmbench.py --checkpoint ${CHECKPOINT} --datasets mmbench_dev_cn_20231003 "${ARGS[@]:2}"
449
+ fi
450
+
451
+ if [ ${DATASET} == "mmbench-test-en" ]; then
452
+ torchrun \
453
+ --nnodes=1 \
454
+ --node_rank=0 \
455
+ --master_addr=127.0.0.1 \
456
+ --nproc_per_node=${GPUS} \
457
+ --master_port=${MASTER_PORT} \
458
+ eval/mmbench/evaluate_mmbench.py --checkpoint ${CHECKPOINT} --datasets mmbench_test_en_20231003 "${ARGS[@]:2}"
459
+ fi
460
+
461
+ if [ ${DATASET} == "mmbench-test-cn" ]; then
462
+ torchrun \
463
+ --nnodes=1 \
464
+ --node_rank=0 \
465
+ --master_addr=127.0.0.1 \
466
+ --nproc_per_node=${GPUS} \
467
+ --master_port=${MASTER_PORT} \
468
+ eval/mmbench/evaluate_mmbench.py --checkpoint ${CHECKPOINT} --datasets mmbench_test_cn_20231003 "${ARGS[@]:2}"
469
+ fi
470
+
471
+ if [ ${DATASET} == "ccbench-dev" ]; then
472
+ torchrun \
473
+ --nnodes=1 \
474
+ --node_rank=0 \
475
+ --master_addr=127.0.0.1 \
476
+ --nproc_per_node=${GPUS} \
477
+ --master_port=${MASTER_PORT} \
478
+ eval/mmbench/evaluate_mmbench.py --checkpoint ${CHECKPOINT} --datasets ccbench_dev_cn "${ARGS[@]:2}"
479
+ fi
480
+
481
+ if [ ${DATASET} == "scienceqa" ]; then
482
+ torchrun \
483
+ --nnodes=1 \
484
+ --node_rank=0 \
485
+ --master_addr=127.0.0.1 \
486
+ --nproc_per_node=${GPUS} \
487
+ --master_port=${MASTER_PORT} \
488
+ eval/scienceqa/evaluate_scienceqa.py --checkpoint ${CHECKPOINT} --datasets sqa_test "${ARGS[@]:2}"
489
+ fi
490
+
491
+ if [ ${DATASET} == "mantis" ]; then
492
+ torchrun \
493
+ --nnodes=1 \
494
+ --node_rank=0 \
495
+ --master_addr=127.0.0.1 \
496
+ --nproc_per_node=${GPUS} \
497
+ --master_port=${MASTER_PORT} \
498
+ eval/mantis_eval/evaluate_mantis.py --checkpoint ${CHECKPOINT} --datasets Mantis-Eval "${ARGS[@]:2}"
499
+ fi
500
+
501
+ if [ ${DATASET} == "mirb" ]; then
502
+ torchrun \
503
+ --nnodes=1 \
504
+ --node_rank=0 \
505
+ --master_addr=127.0.0.1 \
506
+ --nproc_per_node=${GPUS} \
507
+ --master_port=${MASTER_PORT} \
508
+ eval/mirb/evaluate_mirb.py --checkpoint ${CHECKPOINT} --datasets MIRB "${ARGS[@]:2}"
509
+ fi
510
+
511
+ if [ ${DATASET} == "m3cot" ]; then
512
+ torchrun \
513
+ --nnodes=1 \
514
+ --node_rank=0 \
515
+ --master_addr=127.0.0.1 \
516
+ --nproc_per_node=${GPUS} \
517
+ --master_port=${MASTER_PORT} \
518
+ eval/scienceqa/evaluate_scienceqa.py --checkpoint ${CHECKPOINT} --datasets m3cot_test "${ARGS[@]:2}"
519
+ fi
520
+
521
+ if [ ${DATASET} == "mmmu-dev" ]; then
522
+ torchrun \
523
+ --nnodes=1 \
524
+ --node_rank=0 \
525
+ --master_addr=127.0.0.1 \
526
+ --nproc_per_node=${GPUS} \
527
+ --master_port=${MASTER_PORT} \
528
+ eval/mmmu/evaluate_mmmu.py --checkpoint ${CHECKPOINT} --datasets MMMU_dev "${ARGS[@]:2}"
529
+ fi
530
+
531
+ if [ ${DATASET} == "mmmu-val" ]; then
532
+ torchrun \
533
+ --nnodes=1 \
534
+ --node_rank=0 \
535
+ --master_addr=127.0.0.1 \
536
+ --nproc_per_node=${GPUS} \
537
+ --master_port=${MASTER_PORT} \
538
+ eval/mmmu/evaluate_mmmu.py --checkpoint ${CHECKPOINT} --datasets MMMU_validation "${ARGS[@]:2}"
539
+ fi
540
+
541
+ if [ ${DATASET} == "mmmu-test" ]; then
542
+ torchrun \
543
+ --nnodes=1 \
544
+ --node_rank=0 \
545
+ --master_addr=127.0.0.1 \
546
+ --nproc_per_node=${GPUS} \
547
+ --master_port=${MASTER_PORT} \
548
+ eval/mmmu/evaluate_mmmu.py --checkpoint ${CHECKPOINT} --datasets MMMU_test "${ARGS[@]:2}"
549
+ fi
550
+
551
+ if [ ${DATASET} == "mmmu-dev-cot" ]; then
552
+ torchrun \
553
+ --nnodes=1 \
554
+ --node_rank=0 \
555
+ --master_addr=127.0.0.1 \
556
+ --nproc_per_node=${GPUS} \
557
+ --master_port=${MASTER_PORT} \
558
+ eval/mmmu/evaluate_mmmu_cot.py --checkpoint ${CHECKPOINT} --datasets MMMU_dev "${ARGS[@]:2}"
559
+ fi
560
+
561
+ if [ ${DATASET} == "mmmu-val-cot" ]; then
562
+ torchrun \
563
+ --nnodes=1 \
564
+ --node_rank=0 \
565
+ --master_addr=127.0.0.1 \
566
+ --nproc_per_node=${GPUS} \
567
+ --master_port=${MASTER_PORT} \
568
+ eval/mmmu/evaluate_mmmu_cot.py --checkpoint ${CHECKPOINT} --datasets MMMU_validation "${ARGS[@]:2}"
569
+ fi
570
+
571
+ if [ ${DATASET} == "mmmu-test-cot" ]; then
572
+ torchrun \
573
+ --nnodes=1 \
574
+ --node_rank=0 \
575
+ --master_addr=127.0.0.1 \
576
+ --nproc_per_node=${GPUS} \
577
+ --master_port=${MASTER_PORT} \
578
+ eval/mmmu/evaluate_mmmu_cot.py --checkpoint ${CHECKPOINT} --datasets MMMU_test "${ARGS[@]:2}"
579
+ fi
580
+
581
+ if [ ${DATASET} == "mmvp" ]; then
582
+ torchrun \
583
+ --nnodes=1 \
584
+ --node_rank=0 \
585
+ --master_addr=127.0.0.1 \
586
+ --nproc_per_node=${GPUS} \
587
+ --master_port=${MASTER_PORT} \
588
+ eval/mmvp/evaluate_mmvp.py --checkpoint ${CHECKPOINT} --datasets MMVP "${ARGS[@]:2}"
589
+ fi
590
+
591
+ if [ ${DATASET} == "mathvista-testmini" ]; then
592
+ torchrun \
593
+ --nnodes=1 \
594
+ --node_rank=0 \
595
+ --master_addr=127.0.0.1 \
596
+ --nproc_per_node=${GPUS} \
597
+ --master_port=${MASTER_PORT} \
598
+ eval/mathvista/evaluate_mathvista.py --checkpoint ${CHECKPOINT} --datasets MathVista_testmini "${ARGS[@]:2}"
599
+ fi
600
+
601
+ if [ ${DATASET} == "mathvista-test" ]; then
602
+ torchrun \
603
+ --nnodes=1 \
604
+ --node_rank=0 \
605
+ --master_addr=127.0.0.1 \
606
+ --nproc_per_node=${GPUS} \
607
+ --master_port=${MASTER_PORT} \
608
+ eval/mathvista/evaluate_mathvista.py --checkpoint ${CHECKPOINT} --datasets MathVista_test "${ARGS[@]:2}"
609
+ fi
610
+
611
+ if [ ${DATASET} == "seed" ]; then
612
+ torchrun \
613
+ --nnodes=1 \
614
+ --node_rank=0 \
615
+ --master_addr=127.0.0.1 \
616
+ --nproc_per_node=${GPUS} \
617
+ --master_port=${MASTER_PORT} \
618
+ eval/seed/evaluate_seed.py --checkpoint ${CHECKPOINT} --datasets SEEDv1 "${ARGS[@]:2}"
619
+ fi
620
+
621
+ if [ ${DATASET} == "mvbench" ]; then
622
+ torchrun \
623
+ --nnodes=1 \
624
+ --node_rank=0 \
625
+ --master_addr=127.0.0.1 \
626
+ --nproc_per_node=${GPUS} \
627
+ --master_port=${MASTER_PORT} \
628
+ eval/mvbench/evaluate_mvbench.py --checkpoint ${CHECKPOINT} --num_segments 16 "${ARGS[@]:2}"
629
+ fi
630
+
631
+ if [ ${DATASET} == "mmiu" ]; then
632
+ torchrun \
633
+ --nnodes=1 \
634
+ --node_rank=0 \
635
+ --master_addr=127.0.0.1 \
636
+ --nproc_per_node=${GPUS} \
637
+ --master_port=${MASTER_PORT} \
638
+ eval/mmiu/evaluate_mmiu.py --checkpoint ${CHECKPOINT} "${ARGS[@]:2}"
639
+ fi
640
+
641
+ if [ ${DATASET} == "mmhal" ]; then
642
+ torchrun \
643
+ --nnodes=1 \
644
+ --node_rank=0 \
645
+ --master_addr=127.0.0.1 \
646
+ --nproc_per_node=${GPUS} \
647
+ --master_port=${MASTER_PORT} \
648
+ eval/mmhal/evaluate_mmhal.py --checkpoint ${CHECKPOINT} "${ARGS[@]:2}"
649
+ fi
650
+
651
+ if [ ${DATASET} == "mmmu-pro" ]; then
652
+ python -u eval/mmmu_pro/evaluate_mmmu_pro.py --model ${CHECKPOINT} --mode direct --setting "standard (10 options)" "${ARGS[@]:2}"
653
+ python -u eval/mmmu_pro/evaluate_mmmu_pro.py --model ${CHECKPOINT} --mode cot --setting "standard (10 options)" "${ARGS[@]:2}"
654
+ python -u eval/mmmu_pro/evaluate_mmmu_pro.py --model ${CHECKPOINT} --mode direct --setting vision "${ARGS[@]:2}"
655
+ python -u eval/mmmu_pro/evaluate_mmmu_pro.py --model ${CHECKPOINT} --mode cot --setting vision "${ARGS[@]:2}"
656
+ fi
657
+
658
+ if [ ${DATASET} == "mmmu-pro-std10" ]; then
659
+ python -u eval/mmmu_pro/evaluate_mmmu_pro.py --model ${CHECKPOINT} --mode direct --setting "standard (10 options)" "${ARGS[@]:2}"
660
+ python -u eval/mmmu_pro/evaluate_mmmu_pro.py --model ${CHECKPOINT} --mode cot --setting "standard (10 options)" "${ARGS[@]:2}"
661
+ fi
662
+
663
+ if [ ${DATASET} == "mmmu-pro-vision" ]; then
664
+ python -u eval/mmmu_pro/evaluate_mmmu_pro.py --model ${CHECKPOINT} --mode direct --setting vision "${ARGS[@]:2}"
665
+ python -u eval/mmmu_pro/evaluate_mmmu_pro.py --model ${CHECKPOINT} --mode cot --setting vision "${ARGS[@]:2}"
666
+ fi
667
+
668
+ if [ ${DATASET} == "drivelm" ]; then
669
+ torchrun \
670
+ --nnodes=1 \
671
+ --node_rank=0 \
672
+ --master_addr=127.0.0.1 \
673
+ --nproc_per_node=${GPUS} \
674
+ --master_port=${MASTER_PORT} \
675
+ eval/domain_specific/drivelm/evaluate.py --checkpoint ${CHECKPOINT} --datasets DriveLM_val --dynamic --max-num 12
676
+ fi
677
+
678
+ if [ ${DATASET} == "mme-realworld" ]; then
679
+ torchrun \
680
+ --nnodes=1 \
681
+ --node_rank=0 \
682
+ --master_addr=127.0.0.1 \
683
+ --nproc_per_node=${GPUS} \
684
+ --master_port=${MASTER_PORT} \
685
+ eval/domain_specific/mme_rw/evaluate.py --checkpoint ${CHECKPOINT} --datasets MME_RealWorld "${ARGS[@]:2}"
686
+ fi
687
+
688
+ if [ ${DATASET} == "dior-rsvg" ]; then
689
+ torchrun \
690
+ --nnodes=1 \
691
+ --node_rank=0 \
692
+ --master_addr=127.0.0.1 \
693
+ --nproc_per_node=${GPUS} \
694
+ --master_port=${MASTER_PORT} \
695
+ eval/domain_specific/rs_det/evaluate.py --checkpoint ${CHECKPOINT} --datasets DIOR_RSVG "${ARGS[@]:2}"
696
+ fi
697
+
698
+ if [ ${DATASET} == "rsvqa-lr" ]; then
699
+ torchrun \
700
+ --nnodes=1 \
701
+ --node_rank=0 \
702
+ --master_addr=127.0.0.1 \
703
+ --nproc_per_node=${GPUS} \
704
+ --master_port=${MASTER_PORT} \
705
+ eval/domain_specific/rs_vqa/evaluate.py --checkpoint ${CHECKPOINT} --datasets RSVQA_L "${ARGS[@]:2}"
706
+ fi
707
+
708
+ if [ ${DATASET} == "rsvqa-hr-test1" ]; then
709
+ torchrun \
710
+ --nnodes=1 \
711
+ --node_rank=0 \
712
+ --master_addr=127.0.0.1 \
713
+ --nproc_per_node=${GPUS} \
714
+ --master_port=${MASTER_PORT} \
715
+ eval/domain_specific/rs_vqa/evaluate.py --checkpoint ${CHECKPOINT} --datasets RSVQA_H_TEST1 "${ARGS[@]:2}"
716
+ fi
717
+
718
+ if [ ${DATASET} == "rsvqa-hr-test2" ]; then
719
+ torchrun \
720
+ --nnodes=1 \
721
+ --node_rank=0 \
722
+ --master_addr=127.0.0.1 \
723
+ --nproc_per_node=${GPUS} \
724
+ --master_port=${MASTER_PORT} \
725
+ eval/domain_specific/rs_vqa/evaluate.py --checkpoint ${CHECKPOINT} --datasets RSVQA_H_TEST2 "${ARGS[@]:2}"
726
+ fi
src/third_party/InternVL/internvl_chat/internvl/conversation.py ADDED
@@ -0,0 +1,402 @@
1
+ """
2
+ Conversation prompt templates.
3
+
4
+ We kindly request that you import fastchat instead of copying this file if you wish to use it.
5
+ If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
6
+ """
7
+
8
+ import dataclasses
9
+ from enum import IntEnum, auto
10
+ from typing import Any, Dict, List, Tuple, Union
11
+
12
+
13
+ class SeparatorStyle(IntEnum):
14
+ """Separator styles."""
15
+
16
+ ADD_COLON_SINGLE = auto()
17
+ ADD_COLON_TWO = auto()
18
+ ADD_COLON_SPACE_SINGLE = auto()
19
+ NO_COLON_SINGLE = auto()
20
+ NO_COLON_TWO = auto()
21
+ ADD_NEW_LINE_SINGLE = auto()
22
+ LLAMA2 = auto()
23
+ CHATGLM = auto()
24
+ CHATML = auto()
25
+ CHATINTERN = auto()
26
+ DOLLY = auto()
27
+ RWKV = auto()
28
+ PHOENIX = auto()
29
+ ROBIN = auto()
30
+ FALCON_CHAT = auto()
31
+ CHATGLM3 = auto()
32
+ INTERNVL_ZH = auto()
33
+ MPT = auto()
34
+
35
+
36
+ @dataclasses.dataclass
37
+ class Conversation:
38
+ """A class that manages prompt templates and keeps all conversation history."""
39
+
40
+ # The name of this template
41
+ name: str
42
+ # The template of the system prompt
43
+ system_template: str = '{system_message}'
44
+ # The system message
45
+ system_message: str = ''
46
+ # The names of two roles
47
+ roles: Tuple[str] = ('USER', 'ASSISTANT')
48
+ # All messages. Each item is (role, message).
49
+ messages: List[List[str]] = ()
50
+ # The number of few shot examples
51
+ offset: int = 0
52
+ # The separator style and configurations
53
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
54
+ sep: str = '\n'
55
+ sep2: str = None
56
+ # Stop criteria (the default one is EOS token)
57
+ stop_str: Union[str, List[str]] = None
58
+ # Stops generation if meeting any token in this list
59
+ stop_token_ids: List[int] = None
60
+
61
+ def get_prompt(self) -> str:
62
+ """Get the prompt for generation."""
63
+ system_prompt = self.system_template.format(system_message=self.system_message)
64
+ if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
65
+ ret = system_prompt + self.sep
66
+ for role, message in self.messages:
67
+ if message:
68
+ ret += role + ': ' + message + self.sep
69
+ else:
70
+ ret += role + ':'
71
+ return ret
72
+ elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
73
+ seps = [self.sep, self.sep2]
74
+ ret = system_prompt + seps[0]
75
+ for i, (role, message) in enumerate(self.messages):
76
+ if message:
77
+ ret += role + ': ' + message + seps[i % 2]
78
+ else:
79
+ ret += role + ':'
80
+ return ret
81
+ elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
82
+ ret = system_prompt + self.sep
83
+ for role, message in self.messages:
84
+ if message:
85
+ ret += role + ': ' + message + self.sep
86
+ else:
87
+ ret += role + ': ' # must end with a space
88
+ return ret
89
+ elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
90
+ ret = '' if system_prompt == '' else system_prompt + self.sep
91
+ for role, message in self.messages:
92
+ if message:
93
+ ret += role + '\n' + message + self.sep
94
+ else:
95
+ ret += role + '\n'
96
+ return ret
97
+ elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
98
+ ret = system_prompt
99
+ for role, message in self.messages:
100
+ if message:
101
+ ret += role + message + self.sep
102
+ else:
103
+ ret += role
104
+ return ret
105
+ elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
106
+ seps = [self.sep, self.sep2]
107
+ ret = system_prompt
108
+ for i, (role, message) in enumerate(self.messages):
109
+ if message:
110
+ ret += role + message + seps[i % 2]
111
+ else:
112
+ ret += role
113
+ return ret
114
+ elif self.sep_style == SeparatorStyle.RWKV:
115
+ ret = system_prompt
116
+ for i, (role, message) in enumerate(self.messages):
117
+ if message:
118
+ ret += (
119
+ role
120
+ + ': '
121
+ + message.replace('\r\n', '\n').replace('\n\n', '\n')
122
+ )
123
+ ret += '\n\n'
124
+ else:
125
+ ret += role + ':'
126
+ return ret
127
+ elif self.sep_style == SeparatorStyle.LLAMA2:
128
+ seps = [self.sep, self.sep2]
129
+ if self.system_message:
130
+ ret = system_prompt
131
+ else:
132
+ ret = '[INST] '
133
+ for i, (role, message) in enumerate(self.messages):
134
+ tag = self.roles[i % 2]
135
+ if message:
136
+ if i == 0:
137
+ ret += message + ' '
138
+ else:
139
+ ret += tag + ' ' + message + seps[i % 2]
140
+ else:
141
+ ret += tag
142
+ return ret
143
+ elif self.sep_style == SeparatorStyle.CHATGLM:
144
+ # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
145
+ # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
146
+ round_add_n = 1 if self.name == 'chatglm2' else 0
147
+ if system_prompt:
148
+ ret = system_prompt + self.sep
149
+ else:
150
+ ret = ''
151
+
152
+ for i, (role, message) in enumerate(self.messages):
153
+ if i % 2 == 0:
154
+ ret += f'[Round {i//2 + round_add_n}]{self.sep}'
155
+
156
+ if message:
157
+ ret += f'{role}:{message}{self.sep}'
158
+ else:
159
+ ret += f'{role}:'
160
+ return ret
161
+ elif self.sep_style == SeparatorStyle.CHATML:
162
+ ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
163
+ for role, message in self.messages:
164
+ if message:
165
+ ret += role + '\n' + message + self.sep + '\n'
166
+ else:
167
+ ret += role + '\n'
168
+ return ret
169
+ elif self.sep_style == SeparatorStyle.CHATGLM3:
170
+ ret = ''
171
+ if self.system_message:
172
+ ret += system_prompt
173
+ for role, message in self.messages:
174
+ if message:
175
+ ret += role + '\n' + ' ' + message
176
+ else:
177
+ ret += role
178
+ return ret
179
+ elif self.sep_style == SeparatorStyle.CHATINTERN:
180
+ # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
181
+ seps = [self.sep, self.sep2]
182
+ ret = system_prompt
183
+ for i, (role, message) in enumerate(self.messages):
184
+ # if i % 2 == 0:
185
+ # ret += "<s>"
186
+ if message:
187
+ ret += role + ':' + message + seps[i % 2] + '\n'
188
+ else:
189
+ ret += role + ':'
190
+ return ret
191
+ elif self.sep_style == SeparatorStyle.DOLLY:
192
+ seps = [self.sep, self.sep2]
193
+ ret = system_prompt
194
+ for i, (role, message) in enumerate(self.messages):
195
+ if message:
196
+ ret += role + ':\n' + message + seps[i % 2]
197
+ if i % 2 == 1:
198
+ ret += '\n\n'
199
+ else:
200
+ ret += role + ':\n'
201
+ return ret
202
+ elif self.sep_style == SeparatorStyle.PHOENIX:
203
+ ret = system_prompt
204
+ for role, message in self.messages:
205
+ if message:
206
+ ret += role + ': ' + '<s>' + message + '</s>'
207
+ else:
208
+ ret += role + ': ' + '<s>'
209
+ return ret
210
+ elif self.sep_style == SeparatorStyle.ROBIN:
211
+ ret = system_prompt + self.sep
212
+ for role, message in self.messages:
213
+ if message:
214
+ ret += role + ':\n' + message + self.sep
215
+ else:
216
+ ret += role + ':\n'
217
+ return ret
218
+ elif self.sep_style == SeparatorStyle.FALCON_CHAT:
219
+ ret = ''
220
+ if self.system_message:
221
+ ret += system_prompt + self.sep
222
+ for role, message in self.messages:
223
+ if message:
224
+ ret += role + ': ' + message + self.sep
225
+ else:
226
+ ret += role + ':'
227
+
228
+ return ret
229
+ elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
230
+ seps = [self.sep2, self.sep]
231
+ ret = self.system_message + seps[0]
232
+ for i, (role, message) in enumerate(self.messages):
233
+ if message:
234
+ ret += role + ': ' + message + seps[i % 2]
235
+ else:
236
+ ret += role + ':'
237
+ return ret
238
+ elif self.sep_style == SeparatorStyle.MPT:
239
+ ret = system_prompt + self.sep
240
+ for role, message in self.messages:
241
+ if message:
242
+ if type(message) is tuple:
243
+ message, _, _ = message
244
+ ret += role + message + self.sep
245
+ else:
246
+ ret += role
247
+ return ret
248
+ else:
249
+ raise ValueError(f'Invalid style: {self.sep_style}')
250
+
251
+ def set_system_message(self, system_message: str):
252
+ """Set the system message."""
253
+ self.system_message = system_message
254
+
255
+ def append_message(self, role: str, message: str):
256
+ """Append a new message."""
257
+ self.messages.append([role, message])
258
+
259
+ def update_last_message(self, message: str):
260
+ """Update the last output.
261
+
262
+ The last message is typically set to be None when constructing the prompt,
263
+ so we need to update it in-place after getting the response from a model.
264
+ """
265
+ self.messages[-1][1] = message
266
+
267
+ def to_gradio_chatbot(self):
268
+ """Convert the conversation to gradio chatbot format."""
269
+ ret = []
270
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
271
+ if i % 2 == 0:
272
+ ret.append([msg, None])
273
+ else:
274
+ ret[-1][-1] = msg
275
+ return ret
276
+
277
+ def to_openai_api_messages(self):
278
+ """Convert the conversation to OpenAI chat completion format."""
279
+ ret = [{'role': 'system', 'content': self.system_message}]
280
+
281
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
282
+ if i % 2 == 0:
283
+ ret.append({'role': 'user', 'content': msg})
284
+ else:
285
+ if msg is not None:
286
+ ret.append({'role': 'assistant', 'content': msg})
287
+ return ret
288
+
289
+ def copy(self):
290
+ return Conversation(
291
+ name=self.name,
292
+ system_template=self.system_template,
293
+ system_message=self.system_message,
294
+ roles=self.roles,
295
+ messages=[[x, y] for x, y in self.messages],
296
+ offset=self.offset,
297
+ sep_style=self.sep_style,
298
+ sep=self.sep,
299
+ sep2=self.sep2,
300
+ stop_str=self.stop_str,
301
+ stop_token_ids=self.stop_token_ids,
302
+ )
303
+
304
+ def dict(self):
305
+ return {
306
+ 'template_name': self.name,
307
+ 'system_message': self.system_message,
308
+ 'roles': self.roles,
309
+ 'messages': self.messages,
310
+ 'offset': self.offset,
311
+ }
312
+
313
+
314
+ # A global registry for all conversation templates
315
+ conv_templates: Dict[str, Conversation] = {}
316
+
317
+
318
+ def register_conv_template(template: Conversation, override: bool = False):
319
+ """Register a new conversation template."""
320
+ if not override:
321
+ assert (
322
+ template.name not in conv_templates
323
+ ), f'{template.name} has been registered.'
324
+
325
+ conv_templates[template.name] = template
326
+
327
+
328
+ def get_conv_template(name: str) -> Conversation:
329
+ """Get a conversation template."""
330
+ return conv_templates[name].copy()
331
+
332
+
333
+ # InternVL-Chat-V1-1 template
334
+ register_conv_template(
335
+ Conversation(
336
+ name='internvl_zh',
337
+ system_template='',
338
+ roles=('<human>', '<bot>'),
339
+ sep_style=SeparatorStyle.INTERNVL_ZH,
340
+ sep='</s>',
341
+ sep2=' ',
342
+ )
343
+ )
344
+
345
+
346
+ # Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference
347
+ # is that during training, the preprocessing function for the Hermes-2 template doesn't add
348
+ # <s> at the beginning of the tokenized sequence, while the internlm2-chat template does.
349
+ # Therefore, they are completely equivalent during inference.
350
+ register_conv_template(
351
+ Conversation(
352
+ name='Hermes-2',
353
+ system_template='<|im_start|>system\n{system_message}',
354
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
355
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
356
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
357
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
358
+ sep_style=SeparatorStyle.MPT,
359
+ sep='<|im_end|>',
360
+ stop_str='<|endoftext|>',
361
+ )
362
+ )
363
+
364
+
365
+ register_conv_template(
366
+ Conversation(
367
+ name='internlm2-chat',
368
+ system_template='<|im_start|>system\n{system_message}',
369
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
370
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
371
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
372
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
373
+ sep_style=SeparatorStyle.MPT,
374
+ sep='<|im_end|>',
375
+ )
376
+ )
377
+
378
+
379
+ register_conv_template(
380
+ Conversation(
381
+ name='phi3-chat',
382
+ system_template='<|system|>\n{system_message}',
383
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
384
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
385
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
386
+ roles=('<|user|>\n', '<|assistant|>\n'),
387
+ sep_style=SeparatorStyle.MPT,
388
+ sep='<|end|>',
389
+ )
390
+ )
391
+
392
+
393
+ register_conv_template(
394
+ Conversation(
395
+ name='internvl2_5',
396
+ system_template='<|im_start|>system\n{system_message}',
397
+ system_message='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
398
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
399
+ sep_style=SeparatorStyle.MPT,
400
+ sep='<|im_end|>\n',
401
+ )
402
+ )
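
The registry above is what the training and inference code pulls prompts from: `get_conv_template` hands back a copy, so appending turns never mutates the registered template. A minimal sketch of building a prompt with the `internvl2_5` template defined above, assuming the `internvl_chat` directory is on `PYTHONPATH` (the question string is only an illustration):

```python
from internvl.conversation import get_conv_template

# get_conv_template returns a copy, so the global registry stays untouched.
template = get_conv_template('internvl2_5')
template.append_message(template.roles[0], 'Describe the image in detail.')  # user turn
template.append_message(template.roles[1], None)  # leave the assistant slot open for generation
prompt = template.get_prompt()

# With SeparatorStyle.MPT and sep='<|im_end|>\n', `prompt` is:
#   <|im_start|>system\n{system_message}<|im_end|>\n
#   <|im_start|>user\nDescribe the image in detail.<|im_end|>\n
#   <|im_start|>assistant\n
print(prompt)
```
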
src/third_party/InternVL/internvl_chat/internvl/dist_utils.py ADDED
@@ -0,0 +1,104 @@
1
+ import os
2
+ import socket
3
+ import subprocess
4
+ from datetime import timedelta
5
+
6
+ import deepspeed
7
+ import torch
8
+ import torch.multiprocessing as mp
9
+ from torch import distributed as dist
10
+
11
+ timeout = timedelta(minutes=60)
12
+
13
+
14
+ def _find_free_port():
15
+ # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
16
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
17
+ # Binding to port 0 will cause the OS to find an available port for us
18
+ sock.bind(('', 0))
19
+ port = sock.getsockname()[1]
20
+ sock.close()
21
+ # NOTE: there is still a chance the port could be taken by other processes.
22
+ return port
23
+
24
+
25
+ def _is_free_port(port):
26
+ ips = socket.gethostbyname_ex(socket.gethostname())[-1]
27
+ ips.append('localhost')
28
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
29
+ return all(s.connect_ex((ip, port)) != 0 for ip in ips)
30
+
31
+
32
+ def init_dist(launcher, backend='nccl', **kwargs):
33
+ if mp.get_start_method(allow_none=True) is None:
34
+ mp.set_start_method('spawn')
35
+ if launcher == 'pytorch':
36
+ _init_dist_pytorch(backend, **kwargs)
37
+ elif launcher == 'mpi':
38
+ _init_dist_mpi(backend, **kwargs)
39
+ elif launcher == 'slurm':
40
+ _init_dist_slurm(backend, **kwargs)
41
+ else:
42
+ raise ValueError(f'Invalid launcher type: {launcher}')
43
+
44
+
45
+ def _init_dist_pytorch(backend, **kwargs):
46
+ # TODO: use local_rank instead of rank % num_gpus
47
+ rank = int(os.environ['RANK'])
48
+ num_gpus = torch.cuda.device_count()
49
+ torch.cuda.set_device(rank % num_gpus)
50
+ # dist.init_process_group(backend=backend, **kwargs)
51
+ deepspeed.init_distributed(dist_backend=backend)
52
+
53
+
54
+ def _init_dist_mpi(backend, **kwargs):
55
+ local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
56
+ torch.cuda.set_device(local_rank)
57
+ if 'MASTER_PORT' not in os.environ:
58
+ # 29500 is torch.distributed default port
59
+ os.environ['MASTER_PORT'] = '29500'
60
+ if 'MASTER_ADDR' not in os.environ:
61
+ raise KeyError('The environment variable MASTER_ADDR is not set')
62
+ os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
63
+ os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
64
+ dist.init_process_group(backend=backend, **kwargs)
65
+
66
+
67
+ def _init_dist_slurm(backend, port=None):
68
+ """Initialize slurm distributed training environment.
69
+
70
+ If argument ``port`` is not specified, the master port will be taken from the
71
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not set either,
72
+ a default port ``29500`` will be used.
73
+
74
+ Args:
75
+ backend (str): Backend of torch.distributed.
76
+ port (int, optional): Master port. Defaults to None.
77
+ """
78
+ proc_id = int(os.environ['SLURM_PROCID'])
79
+ ntasks = int(os.environ['SLURM_NTASKS'])
80
+ node_list = os.environ['SLURM_NODELIST']
81
+ num_gpus = torch.cuda.device_count()
82
+ torch.cuda.set_device(proc_id % num_gpus)
83
+ addr = subprocess.getoutput(
84
+ f'scontrol show hostname {node_list} | head -n1')
85
+ # specify master port
86
+ if port is not None:
87
+ os.environ['MASTER_PORT'] = str(port)
88
+ elif 'MASTER_PORT' in os.environ:
89
+ pass # use MASTER_PORT in the environment variable
90
+ else:
91
+ # if torch.distributed default port(29500) is available
92
+ # then use it, else find a free port
93
+ if _is_free_port(29500):
94
+ os.environ['MASTER_PORT'] = '29500'
95
+ else:
96
+ os.environ['MASTER_PORT'] = str(_find_free_port())
97
+ # use MASTER_ADDR in the environment variable if it already exists
98
+ if 'MASTER_ADDR' not in os.environ:
99
+ os.environ['MASTER_ADDR'] = addr
100
+ os.environ['WORLD_SIZE'] = str(ntasks)
101
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
102
+ os.environ['RANK'] = str(proc_id)
103
+ # dist.init_process_group(backend=backend, timeout=timeout)
104
+ deepspeed.init_distributed(dist_backend=backend)
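
The Slurm branch above resolves the rendezvous port in a fixed order: an explicit `port` argument, then an existing `MASTER_PORT`, then the torch.distributed default 29500 if it is still free, and finally a random free port. A self-contained sketch of that same fallback order, using plain `socket` calls in place of the module's private helpers (`resolve_master_port` and `_port_is_free` are illustrative names, not part of the module):

```python
import os
import socket


def _port_is_free(port: int) -> bool:
    # Same idea as _is_free_port above: the port counts as free only if
    # no local address accepts a connection on it.
    ips = socket.gethostbyname_ex(socket.gethostname())[-1] + ['localhost']
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return all(s.connect_ex((ip, port)) != 0 for ip in ips)


def resolve_master_port(explicit_port=None) -> str:
    if explicit_port is not None:          # 1. caller-supplied port wins
        return str(explicit_port)
    if 'MASTER_PORT' in os.environ:        # 2. respect an existing MASTER_PORT
        return os.environ['MASTER_PORT']
    if _port_is_free(29500):               # 3. torch.distributed default
        return '29500'
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(('', 0))                     # 4. let the OS pick a free port
    port = sock.getsockname()[1]
    sock.close()
    return str(port)


os.environ.setdefault('MASTER_PORT', resolve_master_port())
print(os.environ['MASTER_PORT'])
```
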
src/third_party/InternVL/internvl_chat/internvl/model/__init__.py ADDED
@@ -0,0 +1,51 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import math
8
+
9
+ import torch
10
+ from internvl.model.internvl_chat import InternVLChatConfig, InternVLChatModel
11
+ from transformers import AutoTokenizer
12
+
13
+
14
+ def split_model(num_layers, vit_alpha=0.5):
15
+ device_map = {}
16
+ world_size = torch.cuda.device_count()
17
+ # Since the first GPU will be used for ViT, treat it as half a GPU.
18
+ num_layers_per_gpu = math.ceil(num_layers / (world_size - vit_alpha))
19
+ num_layers_per_gpu = [num_layers_per_gpu] * world_size
20
+ num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * (1 - vit_alpha))
21
+ layer_cnt = 0
22
+ for i, num_layer in enumerate(num_layers_per_gpu):
23
+ for j in range(num_layer):
24
+ device_map[f'language_model.model.layers.{layer_cnt}'] = i
25
+ layer_cnt += 1
26
+ device_map['vision_model'] = 0
27
+ device_map['mlp1'] = 0
28
+ device_map['language_model.model.tok_embeddings'] = 0
29
+ device_map['language_model.model.embed_tokens'] = 0
30
+ device_map['language_model.output'] = 0
31
+ device_map['language_model.model.norm'] = 0
32
+ device_map['language_model.lm_head'] = 0
33
+ device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
34
+ device_map['language_model.model.rotary_emb'] = 0
35
+
36
+ return device_map
37
+
38
+
39
+ def load_model_and_tokenizer(args):
40
+ if args.auto:
41
+ config = InternVLChatConfig.from_pretrained(args.checkpoint)
42
+ num_hidden_layers = config.llm_config.num_hidden_layers
43
+ device_map = split_model(num_hidden_layers)
44
+ kwargs = {'device_map': device_map} if args.auto else {}
45
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
46
+ model = InternVLChatModel.from_pretrained(
47
+ args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
48
+ load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit, **kwargs).eval()
49
+ if not args.load_in_8bit and not args.load_in_4bit and not args.auto:
50
+ model = model.cuda()
51
+ return model, tokenizer
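
The arithmetic inside `split_model` is easy to check offline: GPU 0 is charged `vit_alpha` of a device for the vision tower, so it receives roughly half as many LLM layers as the other GPUs. A small sketch reproducing that layer split without touching CUDA (`world_size=4` and 32 layers are illustrative values, not a released configuration):

```python
import math


def sketch_layer_split(num_layers=32, world_size=4, vit_alpha=0.5):
    # Same computation as split_model, with world_size passed in explicitly
    # instead of read from torch.cuda.device_count().
    per_gpu = math.ceil(num_layers / (world_size - vit_alpha))
    layers_per_gpu = [per_gpu] * world_size
    layers_per_gpu[0] = math.ceil(per_gpu * (1 - vit_alpha))  # GPU 0 also hosts the ViT

    device_map, layer = {}, 0
    for device, count in enumerate(layers_per_gpu):
        for _ in range(count):
            device_map[f'language_model.model.layers.{layer}'] = device
            layer += 1
    return layers_per_gpu, device_map


counts, device_map = sketch_layer_split()
print(counts)                                        # [5, 10, 10, 10]
print(device_map['language_model.model.layers.0'])   # 0
print(device_map['language_model.model.layers.31'])  # 3
```
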
src/third_party/InternVL/internvl_chat/internvl/model/internlm2/configuration_internlm2.py ADDED
@@ -0,0 +1,150 @@
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ InternLM2 model configuration"""
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
24
+
25
+
26
+ # Modified from transformers.model.llama.configuration_llama.LlamaConfig
27
+ class InternLM2Config(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
30
+ an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
31
+ configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+
36
+
37
+ Args:
38
+ vocab_size (`int`, *optional*, defaults to 103168):
39
+ Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
40
+ `input_ids` passed when calling [`InternLM2Model`]
41
+ hidden_size (`int`, *optional*, defaults to 4096):
42
+ Dimension of the hidden representations.
43
+ intermediate_size (`int`, *optional*, defaults to 11008):
44
+ Dimension of the MLP representations.
45
+ num_hidden_layers (`int`, *optional*, defaults to 32):
46
+ Number of hidden layers in the Transformer encoder.
47
+ num_attention_heads (`int`, *optional*, defaults to 32):
48
+ Number of attention heads for each attention layer in the Transformer encoder.
49
+ num_key_value_heads (`int`, *optional*):
50
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
51
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
52
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise, GQA is used. When
53
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
54
+ by mean-pooling all the original heads within that group. For more details, check out [this
55
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
56
+ `num_attention_heads`.
57
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
58
+ The non-linear activation function (function or string) in the decoder.
59
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
60
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
61
+ just in case (e.g., 512 or 1024 or 2048).
62
+ initializer_range (`float`, *optional*, defaults to 0.02):
63
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
+ rms_norm_eps (`float`, *optional*, defaults to 1e-6):
65
+ The epsilon used by the rms normalization layers.
66
+ use_cache (`bool`, *optional*, defaults to `True`):
67
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
68
+ relevant if `config.is_decoder=True`.
69
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
70
+ Whether to tie weight embeddings
71
+ Example:
72
+
73
+ """
74
+ model_type = 'internlm2'
75
+ _auto_class = 'AutoConfig'
76
+
77
+ def __init__( # pylint: disable=W0102
78
+ self,
79
+ vocab_size=103168,
80
+ hidden_size=4096,
81
+ intermediate_size=11008,
82
+ num_hidden_layers=32,
83
+ num_attention_heads=32,
84
+ num_key_value_heads=None,
85
+ hidden_act='silu',
86
+ max_position_embeddings=2048,
87
+ initializer_range=0.02,
88
+ rms_norm_eps=1e-6,
89
+ use_cache=True,
90
+ pad_token_id=0,
91
+ bos_token_id=1,
92
+ eos_token_id=2,
93
+ tie_word_embeddings=False,
94
+ bias=True,
95
+ rope_theta=10000,
96
+ rope_scaling=None,
97
+ attn_implementation='eager',
98
+ **kwargs,
99
+ ):
100
+ self.vocab_size = vocab_size
101
+ self.max_position_embeddings = max_position_embeddings
102
+ self.hidden_size = hidden_size
103
+ self.intermediate_size = intermediate_size
104
+ self.num_hidden_layers = num_hidden_layers
105
+ self.num_attention_heads = num_attention_heads
106
+ self.bias = bias
107
+
108
+ if num_key_value_heads is None:
109
+ num_key_value_heads = num_attention_heads
110
+ self.num_key_value_heads = num_key_value_heads
111
+
112
+ self.hidden_act = hidden_act
113
+ self.initializer_range = initializer_range
114
+ self.rms_norm_eps = rms_norm_eps
115
+ self.use_cache = use_cache
116
+ self.rope_theta = rope_theta
117
+ self.rope_scaling = rope_scaling
118
+ self._rope_scaling_validation()
119
+
120
+ self.attn_implementation = attn_implementation
121
+ if self.attn_implementation is None:
122
+ self.attn_implementation = 'eager'
123
+ super().__init__(
124
+ pad_token_id=pad_token_id,
125
+ bos_token_id=bos_token_id,
126
+ eos_token_id=eos_token_id,
127
+ tie_word_embeddings=tie_word_embeddings,
128
+ **kwargs,
129
+ )
130
+
131
+ def _rope_scaling_validation(self):
132
+ """
133
+ Validate the `rope_scaling` configuration.
134
+ """
135
+ if self.rope_scaling is None:
136
+ return
137
+
138
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
139
+ raise ValueError(
140
+ '`rope_scaling` must be a dictionary with two fields, `type` and `factor`, '
141
+ f'got {self.rope_scaling}'
142
+ )
143
+ rope_scaling_type = self.rope_scaling.get('type', None)
144
+ rope_scaling_factor = self.rope_scaling.get('factor', None)
145
+ if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']:
146
+ raise ValueError(
147
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
148
+ )
149
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
150
+ raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
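
As a quick sanity check on the validation above, `rope_scaling` must be a two-field dict whose `type` is `linear` or `dynamic` and whose `factor` is a float of at least 1.0. A short sketch, assuming `transformers` is installed and the `internvl_chat` directory is on `PYTHONPATH` (the scaling values are illustrative, not from a released checkpoint):

```python
from internvl.model.internlm2.configuration_internlm2 import InternLM2Config

# Accepted: linear RoPE scaling by 2x on top of the default settings.
cfg = InternLM2Config(rope_scaling={'type': 'linear', 'factor': 2.0})
print(cfg.rope_theta, cfg.rope_scaling)  # 10000 {'type': 'linear', 'factor': 2.0}

# Rejected: an integer factor fails the isinstance(factor, float) check.
try:
    InternLM2Config(rope_scaling={'type': 'linear', 'factor': 2})
except ValueError as err:
    print(err)  # `rope_scaling`'s factor field must be a float >= 1, got 2
```
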
src/third_party/InternVL/internvl_chat/internvl/model/internlm2/modeling_internlm2.py ADDED
@@ -0,0 +1,1429 @@
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ PyTorch InternLM2 model."""
17
+ import math
18
+ import queue
19
+ import threading
20
+ import warnings
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from einops import rearrange
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
31
+ CausalLMOutputWithPast,
32
+ SequenceClassifierOutputWithPast)
33
+ from transformers.modeling_utils import PreTrainedModel
34
+ from transformers.utils import (add_start_docstrings,
35
+ add_start_docstrings_to_model_forward, logging,
36
+ replace_return_docstrings)
37
+
38
+ try:
39
+ from transformers.generation.streamers import BaseStreamer
40
+ except: # noqa # pylint: disable=bare-except
41
+ BaseStreamer = None
42
+
43
+ from .configuration_internlm2 import InternLM2Config
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ _CONFIG_FOR_DOC = 'InternLM2Config'
48
+
49
+ flash_attn_func, flash_attn_varlen_func = None, None
50
+ pad_input, index_first_axis, unpad_input = None, None, None
51
+ try:
52
+ from flash_attn import flash_attn_func as _flash_attn_func
53
+ from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
54
+ from flash_attn.bert_padding import index_first_axis as _index_first_axis
55
+ from flash_attn.bert_padding import pad_input as _pad_input
56
+ from flash_attn.bert_padding import unpad_input as _unpad_input
57
+
58
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
59
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
60
+ has_flash_attn = True
61
+ except:
62
+ has_flash_attn = False
63
+
64
+
65
+ def _import_flash_attn():
66
+ global flash_attn_func, flash_attn_varlen_func
67
+ global pad_input, index_first_axis, unpad_input
68
+ try:
69
+ from flash_attn import flash_attn_func as _flash_attn_func
70
+ from flash_attn import \
71
+ flash_attn_varlen_func as _flash_attn_varlen_func
72
+ from flash_attn.bert_padding import \
73
+ index_first_axis as _index_first_axis
74
+ from flash_attn.bert_padding import pad_input as _pad_input
75
+ from flash_attn.bert_padding import unpad_input as _unpad_input
76
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
77
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
78
+ except ImportError:
79
+ raise ImportError('flash_attn is not installed.')
80
+
81
+
82
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
83
+ def _get_unpad_data(attention_mask):
84
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
85
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
86
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
87
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
88
+ return (
89
+ indices,
90
+ cu_seqlens,
91
+ max_seqlen_in_batch,
92
+ )
93
+
94
+
95
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
96
+ def _make_causal_mask(
97
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
98
+ ):
99
+ """
100
+ Make causal mask used for bi-directional self-attention.
101
+ """
102
+ bsz, tgt_len = input_ids_shape
103
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
104
+ mask_cond = torch.arange(mask.size(-1), device=device)
105
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
106
+ mask = mask.to(dtype)
107
+
108
+ if past_key_values_length > 0:
109
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
110
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
111
+
112
+
113
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
114
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
115
+ """
116
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
117
+ """
118
+ bsz, src_len = mask.size()
119
+ tgt_len = tgt_len if tgt_len is not None else src_len
120
+
121
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
122
+
123
+ inverted_mask = 1.0 - expanded_mask
124
+
125
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
126
+
127
+
128
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
129
+ class InternLM2RMSNorm(nn.Module):
130
+ def __init__(self, hidden_size, eps=1e-6):
131
+ """
132
+ InternLM2RMSNorm is equivalent to T5LayerNorm
133
+ """
134
+ super().__init__()
135
+ self.weight = nn.Parameter(torch.ones(hidden_size))
136
+ self.variance_epsilon = eps
137
+
138
+ def forward(self, hidden_states):
139
+ input_dtype = hidden_states.dtype
140
+ hidden_states = hidden_states.to(torch.float32)
141
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
142
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
143
+ return self.weight * hidden_states.to(input_dtype)
144
+
145
+
146
+ try:
147
+ from functools import partial
148
+
149
+ from apex.normalization import FusedRMSNorm
150
+ InternLM2RMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa
151
+ print('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternLM2RMSNorm')
152
+ except ImportError:
153
+ # using the normal LlamaRMSNorm
154
+ pass
155
+ except Exception:
156
+ print('discovered apex but it failed to load, falling back to InternLM2RMSNorm')
157
+ pass
158
+
159
+
160
+ # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
161
+ class InternLM2RotaryEmbedding(nn.Module):
162
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
163
+ super().__init__()
164
+
165
+ self.dim = dim
166
+ self.max_position_embeddings = max_position_embeddings
167
+ self.base = base
168
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
169
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
170
+
171
+ # Build here to make `torch.jit.trace` work.
172
+ self._set_cos_sin_cache(
173
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
174
+ )
175
+
176
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
177
+ self.max_seq_len_cached = seq_len
178
+ t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)
179
+
180
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
181
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
182
+ emb = torch.cat((freqs, freqs), dim=-1)
183
+ self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
184
+ self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
185
+
186
+ def forward(self, x, seq_len=None):
187
+ # x: [bs, num_attention_heads, seq_len, head_size]
188
+ if seq_len > self.max_seq_len_cached:
189
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
190
+
191
+ return (
192
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
193
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
194
+ )
195
+
196
+
197
+ # Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
198
+ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
199
+ """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
200
+
201
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
202
+ self.scaling_factor = scaling_factor
203
+ super().__init__(dim, max_position_embeddings, base, device)
204
+
205
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
206
+ self.max_seq_len_cached = seq_len
207
+ t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)
208
+ t = t / self.scaling_factor
209
+
210
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
211
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
212
+ emb = torch.cat((freqs, freqs), dim=-1)
213
+ self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
214
+ self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
215
+
216
+
217
+ # Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
218
+ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
219
+ """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
220
+ Credits to the Reddit users /u/bloc97 and /u/emozilla.
221
+ """
222
+
223
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
224
+ self.scaling_factor = scaling_factor
225
+ super().__init__(dim, max_position_embeddings, base, device)
226
+
227
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
228
+ self.max_seq_len_cached = seq_len
229
+
230
+ if seq_len > self.max_position_embeddings:
231
+ base = self.base * (
232
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
233
+ ) ** (self.dim / (self.dim - 2))
234
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
235
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
236
+
237
+ t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)
238
+
239
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
240
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
241
+ emb = torch.cat((freqs, freqs), dim=-1)
242
+ self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
243
+ self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
244
+
245
+
246
+ # Copied from transformers.model.llama.modeling_llama.rotate_half
247
+ def rotate_half(x):
248
+ """Rotates half the hidden dims of the input."""
249
+ x1 = x[..., : x.shape[-1] // 2]
250
+ x2 = x[..., x.shape[-1] // 2:]
251
+ return torch.cat((-x2, x1), dim=-1)
252
+
253
+
254
+ # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
255
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
256
+ """Applies Rotary Position Embedding to the query and key tensors."""
257
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
258
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
259
+ q_embed = (q * cos) + (rotate_half(q) * sin)
260
+ k_embed = (k * cos) + (rotate_half(k) * sin)
261
+ return q_embed, k_embed
262
+
263
+
264
+ class InternLM2MLP(nn.Module):
265
+ def __init__(self, config):
266
+ super().__init__()
267
+ self.config = config
268
+ self.hidden_size = config.hidden_size
269
+ self.intermediate_size = config.intermediate_size
270
+ self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
271
+ self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
272
+ self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
273
+ self.act_fn = ACT2FN[config.hidden_act]
274
+
275
+ def forward(self, x):
276
+ down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
277
+
278
+ return down_proj
279
+
280
+
281
+ # Copied from transformers.model.llama.modeling_llama.repeat_kv
282
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
283
+ """
284
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
285
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
286
+ """
287
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
288
+ if n_rep == 1:
289
+ return hidden_states
290
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
291
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
292
+
293
+
294
+ # Modified from transformers.model.llama.modeling_llama.LlamaAttention
295
+ class InternLM2Attention(nn.Module):
296
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
297
+
298
+ def __init__(self, config: InternLM2Config):
299
+ super().__init__()
300
+ self.config = config
301
+ self.hidden_size = config.hidden_size
302
+ self.num_heads = config.num_attention_heads
303
+ self.head_dim = self.hidden_size // self.num_heads
304
+ self.num_key_value_heads = config.num_key_value_heads
305
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
306
+ self.max_position_embeddings = config.max_position_embeddings
307
+ self.is_causal = True
308
+
309
+ if (self.head_dim * self.num_heads) != self.hidden_size:
310
+ raise ValueError(
311
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
312
+ f' and `num_heads`: {self.num_heads}).'
313
+ )
314
+
315
+ self.wqkv = nn.Linear(
316
+ self.hidden_size,
317
+ (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
318
+ bias=config.bias,
319
+ )
320
+
321
+ self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
322
+ self._init_rope()
323
+
324
+ def _init_rope(self):
325
+ if self.config.rope_scaling is None:
326
+ self.rotary_emb = InternLM2RotaryEmbedding(
327
+ self.head_dim,
328
+ max_position_embeddings=self.max_position_embeddings,
329
+ base=self.config.rope_theta,
330
+ )
331
+ else:
332
+ scaling_type = self.config.rope_scaling['type']
333
+ scaling_factor = self.config.rope_scaling['factor']
334
+ if scaling_type == 'dynamic':
335
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
336
+ self.head_dim,
337
+ max_position_embeddings=self.max_position_embeddings,
338
+ base=self.config.rope_theta,
339
+ scaling_factor=scaling_factor,
340
+ )
341
+ elif scaling_type == 'linear':
342
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
343
+ self.head_dim,
344
+ max_position_embeddings=self.max_position_embeddings,
345
+ base=self.config.rope_theta,
346
+ scaling_factor=scaling_factor,
347
+ )
348
+ else:
349
+ raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
350
+ return self.rotary_emb
351
+
352
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
353
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
354
+
355
+ def forward(
356
+ self,
357
+ hidden_states: torch.Tensor,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ position_ids: Optional[torch.LongTensor] = None,
360
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
361
+ output_attentions: bool = False,
362
+ use_cache: bool = False,
363
+ **kwargs,
364
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
365
+ if 'padding_mask' in kwargs:
366
+ warnings.warn(
367
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. '
368
+ 'Please make sure to use `attention_mask` instead.'
369
+ )
370
+
371
+ bsz, q_len, _ = hidden_states.size()
372
+
373
+ qkv_states = self.wqkv(hidden_states)
374
+
375
+ qkv_states = rearrange(
376
+ qkv_states,
377
+ 'b q (h gs d) -> b q h gs d',
378
+ gs=2 + self.num_key_value_groups,
379
+ d=self.head_dim,
380
+ )
381
+
382
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
383
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
384
+ key_states = qkv_states[..., -2, :]
385
+ value_states = qkv_states[..., -1, :]
386
+
387
+ query_states = query_states.transpose(1, 2)
388
+ key_states = key_states.transpose(1, 2)
389
+ value_states = value_states.transpose(1, 2)
390
+
391
+ kv_seq_len = key_states.shape[-2]
392
+ if past_key_value is not None:
393
+ kv_seq_len += past_key_value[0].shape[-2]
394
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
395
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
396
+
397
+ if past_key_value is not None:
398
+ # reuse k, v, self_attention
399
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
400
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
401
+
402
+ past_key_value = (key_states, value_states) if use_cache else None
403
+
404
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
405
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
406
+
407
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
408
+
409
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
410
+ raise ValueError(
411
+ f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
412
+ f' {attn_weights.size()}'
413
+ )
414
+
415
+ if attention_mask is not None:
416
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
417
+ raise ValueError(
418
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
419
+ )
420
+ attn_weights = attn_weights + attention_mask
421
+
422
+ # upcast attention to fp32
423
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
424
+ attn_output = torch.matmul(attn_weights, value_states)
425
+
426
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
427
+ raise ValueError(
428
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
429
+ f' {attn_output.size()}'
430
+ )
431
+
432
+ attn_output = attn_output.transpose(1, 2).contiguous()
433
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
434
+
435
+ attn_output = self.wo(attn_output)
436
+
437
+ if not output_attentions:
438
+ attn_weights = None
439
+
440
+ return attn_output, attn_weights, past_key_value
441
+
442
+
443
+ # Modified from transformers.model.llama.modeling_llama.LlamaFlashAttention2
444
+ class InternLM2FlashAttention2(InternLM2Attention):
445
+ """
446
+ InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stay
447
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
448
+ flash attention and deal with padding tokens in case the input contains any of them.
449
+ """
450
+
451
+ def forward(
452
+ self,
453
+ hidden_states: torch.Tensor,
454
+ attention_mask: Optional[torch.LongTensor] = None,
455
+ position_ids: Optional[torch.LongTensor] = None,
456
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
457
+ output_attentions: bool = False,
458
+ use_cache: bool = False,
459
+ **kwargs,
460
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
461
+ # InternLM2FlashAttention2 attention does not support output_attentions
462
+ if 'padding_mask' in kwargs:
463
+ warnings.warn(
464
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. '
465
+ 'Please make sure to use `attention_mask` instead.'
466
+ )
467
+
468
+ # overwrite attention_mask with padding_mask
469
+ attention_mask = kwargs.pop('padding_mask')
470
+
471
+ output_attentions = False
472
+
473
+ bsz, q_len, _ = hidden_states.size()
474
+
475
+ qkv_states = self.wqkv(hidden_states)
476
+
477
+ qkv_states = rearrange(
478
+ qkv_states,
479
+ 'b q (h gs d) -> b q h gs d',
480
+ gs=2 + self.num_key_value_groups,
481
+ d=self.head_dim,
482
+ )
483
+
484
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
485
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
486
+ key_states = qkv_states[..., -2, :]
487
+ value_states = qkv_states[..., -1, :]
488
+
489
+ query_states = query_states.transpose(1, 2)
490
+ key_states = key_states.transpose(1, 2)
491
+ value_states = value_states.transpose(1, 2)
492
+
493
+ kv_seq_len = key_states.shape[-2]
494
+ if past_key_value is not None:
495
+ kv_seq_len += past_key_value[0].shape[-2]
496
+
497
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
498
+
499
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
500
+
501
+ if past_key_value is not None:
502
+ # reuse k, v, self_attention
503
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
504
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
505
+
506
+ past_key_value = (key_states, value_states) if use_cache else None
507
+
508
+ query_states = query_states.transpose(1, 2)
509
+ key_states = key_states.transpose(1, 2)
510
+ value_states = value_states.transpose(1, 2)
511
+
512
+ attn_output = self._flash_attention_forward(
513
+ query_states, key_states, value_states, attention_mask, q_len
514
+ )
515
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
516
+ attn_output = self.wo(attn_output)
517
+
518
+ if not output_attentions:
519
+ attn_weights = None
520
+
521
+ return attn_output, attn_weights, past_key_value
522
+
523
+ def _flash_attention_forward(
524
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
525
+ ):
526
+ """
527
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
528
+ first unpad the input, then computes the attention scores and pad the final attention scores.
529
+
530
+ Args:
531
+ query_states (`torch.Tensor`):
532
+ Input query states to be passed to Flash Attention API
533
+ key_states (`torch.Tensor`):
534
+ Input key states to be passed to Flash Attention API
535
+ value_states (`torch.Tensor`):
536
+ Input value states to be passed to Flash Attention API
537
+ attention_mask (`torch.Tensor`):
538
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
539
+ position of padding tokens and 1 for the position of non-padding tokens.
540
+ dropout (`float`, *optional*):
541
+ Attention dropout
542
+ softmax_scale (`float`, *optional*):
543
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
544
+ """
545
+ # Contains at least one padding token in the sequence
546
+ causal = self.is_causal and query_length != 1
547
+ if attention_mask is not None:
548
+ batch_size = query_states.shape[0]
549
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
550
+ query_states, key_states, value_states, attention_mask, query_length
551
+ )
552
+
553
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
554
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
555
+
556
+ attn_output_unpad = flash_attn_varlen_func(
557
+ query_states,
558
+ key_states,
559
+ value_states,
560
+ cu_seqlens_q=cu_seqlens_q,
561
+ cu_seqlens_k=cu_seqlens_k,
562
+ max_seqlen_q=max_seqlen_in_batch_q,
563
+ max_seqlen_k=max_seqlen_in_batch_k,
564
+ dropout_p=dropout,
565
+ softmax_scale=softmax_scale,
566
+ causal=causal,
567
+ )
568
+
569
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
570
+ else:
571
+ attn_output = flash_attn_func(
572
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
573
+ )
574
+
575
+ return attn_output
576
+
577
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
578
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
579
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
580
+
581
+ key_layer = index_first_axis(
582
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
583
+ )
584
+ value_layer = index_first_axis(
585
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
586
+ )
587
+
588
+ if query_length == kv_seq_len:
589
+ query_layer = index_first_axis(
590
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
591
+ )
592
+ cu_seqlens_q = cu_seqlens_k
593
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
594
+ indices_q = indices_k
595
+ elif query_length == 1:
596
+ max_seqlen_in_batch_q = 1
597
+ cu_seqlens_q = torch.arange(
598
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
599
+ ) # There is a memcpy here, that is very bad.
600
+ indices_q = cu_seqlens_q[:-1]
601
+ query_layer = query_layer.squeeze(1)
602
+ else:
603
+ # The -q_len: slice assumes left padding.
604
+ attention_mask = attention_mask[:, -query_length:]
605
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
606
+
607
+ return (
608
+ query_layer,
609
+ key_layer,
610
+ value_layer,
611
+ indices_q.to(torch.int64),
612
+ (cu_seqlens_q, cu_seqlens_k),
613
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
614
+ )
615
+
616
+
617
+ INTERNLM2_ATTENTION_CLASSES = {
618
+ 'eager': InternLM2Attention,
619
+ 'flash_attention_2': InternLM2FlashAttention2,
620
+ }
621
+
622
+
623
+ # Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
624
+ class InternLM2DecoderLayer(nn.Module):
625
+ def __init__(self, config: InternLM2Config):
626
+ super().__init__()
627
+ self.hidden_size = config.hidden_size
628
+
629
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
630
+
631
+ self.feed_forward = InternLM2MLP(config)
632
+ self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
633
+ self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
634
+
635
+ def forward(
636
+ self,
637
+ hidden_states: torch.Tensor,
638
+ attention_mask: Optional[torch.Tensor] = None,
639
+ position_ids: Optional[torch.LongTensor] = None,
640
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
641
+ output_attentions: Optional[bool] = False,
642
+ use_cache: Optional[bool] = False,
643
+ **kwargs,
644
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
645
+ """
646
+ Args:
647
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
648
+ attention_mask (`torch.FloatTensor`, *optional*):
649
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
650
+ query_sequence_length, key_sequence_length)` if default attention is used.
651
+ output_attentions (`bool`, *optional*):
652
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
653
+ returned tensors for more detail.
654
+ use_cache (`bool`, *optional*):
655
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
656
+ (see `past_key_values`).
657
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
658
+ """
659
+ if 'padding_mask' in kwargs:
660
+ warnings.warn(
661
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. '
662
+ 'Please make sure to use `attention_mask` instead.'
663
+ )
664
+
665
+ residual = hidden_states
666
+
667
+ hidden_states = self.attention_norm(hidden_states)
668
+
669
+ # Self Attention
670
+ hidden_states, self_attn_weights, present_key_value = self.attention(
671
+ hidden_states=hidden_states,
672
+ attention_mask=attention_mask,
673
+ position_ids=position_ids,
674
+ past_key_value=past_key_value,
675
+ output_attentions=output_attentions,
676
+ use_cache=use_cache,
677
+ **kwargs,
678
+ )
679
+ hidden_states = residual + hidden_states
680
+
681
+ # Fully Connected
682
+ residual = hidden_states
683
+ hidden_states = self.ffn_norm(hidden_states)
684
+ hidden_states = self.feed_forward(hidden_states)
685
+ hidden_states = residual + hidden_states
686
+
687
+ outputs = (hidden_states,)
688
+
689
+ if output_attentions:
690
+ outputs += (self_attn_weights,)
691
+
692
+ if use_cache:
693
+ outputs += (present_key_value,)
694
+
695
+ return outputs
696
+
697
+
698
+ InternLM2_START_DOCSTRING = r"""
699
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
700
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
701
+ etc.)
702
+
703
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
704
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
705
+ and behavior.
706
+
707
+ Parameters:
708
+ config ([`InternLM2Config`]):
709
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
710
+ load the weights associated with the model, only the configuration. Check out the
711
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
712
+ """
713
+
714
+
715
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
716
+ @add_start_docstrings(
717
+ 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
718
+ InternLM2_START_DOCSTRING,
719
+ )
720
+ class InternLM2PreTrainedModel(PreTrainedModel):
721
+ config_class = InternLM2Config
722
+ base_model_prefix = 'model'
723
+ supports_gradient_checkpointing = True
724
+ _no_split_modules = ['InternLM2DecoderLayer']
725
+ _skip_keys_device_placement = 'past_key_values'
726
+ _supports_flash_attn_2 = True
727
+
728
+ def _init_weights(self, module):
729
+ std = self.config.initializer_range
730
+ if isinstance(module, nn.Linear):
731
+ module.weight.data.normal_(mean=0.0, std=std)
732
+ if module.bias is not None:
733
+ module.bias.data.zero_()
734
+ elif isinstance(module, nn.Embedding):
735
+ module.weight.data.normal_(mean=0.0, std=std)
736
+ if module.padding_idx is not None:
737
+ module.weight.data[module.padding_idx].zero_()
738
+
739
+
740
+ InternLM2_INPUTS_DOCSTRING = r"""
741
+ Args:
742
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
743
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
744
+ it.
745
+
746
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
747
+ [`PreTrainedTokenizer.__call__`] for details.
748
+
749
+ [What are input IDs?](../glossary#input-ids)
750
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
751
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
752
+
753
+ - 1 for tokens that are **not masked**,
754
+ - 0 for tokens that are **masked**.
755
+
756
+ [What are attention masks?](../glossary#attention-mask)
757
+
758
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
759
+ [`PreTrainedTokenizer.__call__`] for details.
760
+
761
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
762
+ `past_key_values`).
763
+
764
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
765
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
766
+ information on the default strategy.
767
+
768
+ - 1 indicates the head is **not masked**,
769
+ - 0 indicates the head is **masked**.
770
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
771
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
772
+ config.n_positions - 1]`.
773
+
774
+ [What are position IDs?](../glossary#position-ids)
775
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
776
+ when `config.use_cache=True`):
777
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
778
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
779
+ `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
780
+
781
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
782
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
783
+
784
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
785
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
786
+ of shape `(batch_size, sequence_length)`.
787
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
788
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
789
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
790
+ model's internal embedding lookup matrix.
791
+ use_cache (`bool`, *optional*):
792
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
793
+ `past_key_values`).
794
+ output_attentions (`bool`, *optional*):
795
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
796
+ tensors for more detail.
797
+ output_hidden_states (`bool`, *optional*):
798
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
799
+ more detail.
800
+ return_dict (`bool`, *optional*):
801
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
802
+ """
803
+
804
+
805
+ # Modified from transformers.model.llama.modeling_llama.LlamaModel
806
+ @add_start_docstrings(
807
+ 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
808
+ InternLM2_START_DOCSTRING,
809
+ )
810
+ class InternLM2Model(InternLM2PreTrainedModel):
811
+ """
812
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
813
+
814
+ Args:
815
+ config: InternLM2Config
816
+ """
817
+
818
+ _auto_class = 'AutoModel'
819
+
820
+ def __init__(self, config: InternLM2Config):
821
+ super().__init__(config)
822
+ self.padding_idx = config.pad_token_id
823
+ self.vocab_size = config.vocab_size
824
+ self.config = config
825
+ if not has_flash_attn:
826
+ self.config.attn_implementation = 'eager'
827
+ print('Warning: Flash attention is not available, using eager attention instead.')
828
+
829
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
830
+
831
+ self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
832
+ self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
833
+
834
+ self.gradient_checkpointing = False
835
+ # Initialize weights and apply final processing
836
+ self.post_init()
837
+
838
+ def get_input_embeddings(self):
839
+ return self.tok_embeddings
840
+
841
+ def set_input_embeddings(self, value):
842
+ self.tok_embeddings = value
843
+
844
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
845
+ # create causal mask
846
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
847
+ combined_attention_mask = None
848
+ if input_shape[-1] > 1:
849
+ combined_attention_mask = _make_causal_mask(
850
+ input_shape,
851
+ inputs_embeds.dtype,
852
+ device=inputs_embeds.device,
853
+ past_key_values_length=past_key_values_length,
854
+ )
855
+
856
+ if attention_mask is not None:
857
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
858
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
859
+ inputs_embeds.device
860
+ )
861
+ combined_attention_mask = (
862
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
863
+ )
864
+
865
+ return combined_attention_mask
866
+
867
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
868
+ def forward(
869
+ self,
870
+ input_ids: torch.LongTensor = None,
871
+ attention_mask: Optional[torch.Tensor] = None,
872
+ position_ids: Optional[torch.LongTensor] = None,
873
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
874
+ inputs_embeds: Optional[torch.FloatTensor] = None,
875
+ use_cache: Optional[bool] = None,
876
+ output_attentions: Optional[bool] = None,
877
+ output_hidden_states: Optional[bool] = None,
878
+ return_dict: Optional[bool] = None,
879
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
880
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
881
+ output_hidden_states = (
882
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
883
+ )
884
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
885
+
886
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
887
+
888
+ if self.config.attn_implementation == 'flash_attention_2':
889
+ _import_flash_attn()
890
+
891
+ # retrieve input_ids and inputs_embeds
892
+ if input_ids is not None and inputs_embeds is not None:
893
+ raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
894
+ elif input_ids is not None:
895
+ batch_size, seq_length = input_ids.shape[:2]
896
+ elif inputs_embeds is not None:
897
+ batch_size, seq_length = inputs_embeds.shape[:2]
898
+ else:
899
+ raise ValueError('You have to specify either input_ids or inputs_embeds')
900
+
901
+ seq_length_with_past = seq_length
902
+ past_key_values_length = 0
903
+ if past_key_values is not None:
904
+ past_key_values_length = past_key_values[0][0].shape[2]
905
+ seq_length_with_past = seq_length_with_past + past_key_values_length
906
+
907
+ if position_ids is None:
908
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
909
+ position_ids = torch.arange(
910
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
911
+ )
912
+ position_ids = position_ids.unsqueeze(0)
913
+
914
+ if inputs_embeds is None:
915
+ inputs_embeds = self.tok_embeddings(input_ids)
916
+
917
+ if self.config.attn_implementation == 'flash_attention_2':
918
+ # 2d mask is passed through the layers
919
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
920
+ else:
921
+ if attention_mask is None:
922
+ attention_mask = torch.ones(
923
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
924
+ )
925
+ attention_mask = self._prepare_decoder_attention_mask(
926
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
927
+ )
928
+
929
+ # embed positions
930
+ hidden_states = inputs_embeds
931
+
932
+ if self.gradient_checkpointing and self.training:
933
+ if use_cache:
934
+ logger.warning_once(
935
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
936
+ )
937
+ use_cache = False
938
+
939
+ # decoder layers
940
+ all_hidden_states = () if output_hidden_states else None
941
+ all_self_attns = () if output_attentions else None
942
+ next_decoder_cache = () if use_cache else None
943
+
944
+ for idx, decoder_layer in enumerate(self.layers):
945
+ if output_hidden_states:
946
+ all_hidden_states += (hidden_states,)
947
+
948
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
949
+
950
+ if self.gradient_checkpointing and self.training:
951
+
952
+ def create_custom_forward(module):
953
+ def custom_forward(*inputs):
954
+ # None for past_key_value
955
+ return module(*inputs, output_attentions, None)
956
+
957
+ return custom_forward
958
+
959
+ layer_outputs = torch.utils.checkpoint.checkpoint(
960
+ create_custom_forward(decoder_layer),
961
+ hidden_states,
962
+ attention_mask,
963
+ position_ids,
964
+ None,
965
+ )
966
+ else:
967
+ layer_outputs = decoder_layer(
968
+ hidden_states,
969
+ attention_mask=attention_mask,
970
+ position_ids=position_ids,
971
+ past_key_value=past_key_value,
972
+ output_attentions=output_attentions,
973
+ use_cache=use_cache,
974
+ )
975
+
976
+ hidden_states = layer_outputs[0]
977
+
978
+ if use_cache:
979
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
980
+
981
+ if output_attentions:
982
+ all_self_attns += (layer_outputs[1],)
983
+
984
+ hidden_states = self.norm(hidden_states)
985
+
986
+ # add hidden states from the last decoder layer
987
+ if output_hidden_states:
988
+ all_hidden_states += (hidden_states,)
989
+
990
+ next_cache = next_decoder_cache if use_cache else None
991
+ if not return_dict:
992
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
993
+ return BaseModelOutputWithPast(
994
+ last_hidden_state=hidden_states,
995
+ past_key_values=next_cache,
996
+ hidden_states=all_hidden_states,
997
+ attentions=all_self_attns,
998
+ )
999
+
1000
+
1001
+ # Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
1002
+ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1003
+ _auto_class = 'AutoModelForCausalLM'
1004
+
1005
+ _tied_weights_keys = ['output.weight']
1006
+
1007
+ def __init__(self, config):
1008
+ super().__init__(config)
1009
+ self.model = InternLM2Model(config)
1010
+ self.vocab_size = config.vocab_size
1011
+ self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1012
+
1013
+ # Initialize weights and apply final processing
1014
+ self.post_init()
1015
+
1016
+ def get_input_embeddings(self):
1017
+ return self.model.tok_embeddings
1018
+
1019
+ def set_input_embeddings(self, value):
1020
+ self.model.tok_embeddings = value
1021
+
1022
+ def get_output_embeddings(self):
1023
+ return self.output
1024
+
1025
+ def set_output_embeddings(self, new_embeddings):
1026
+ self.output = new_embeddings
1027
+
1028
+ def set_decoder(self, decoder):
1029
+ self.model = decoder
1030
+
1031
+ def get_decoder(self):
1032
+ return self.model
1033
+
1034
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1035
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1036
+ def forward(
1037
+ self,
1038
+ input_ids: torch.LongTensor = None,
1039
+ attention_mask: Optional[torch.Tensor] = None,
1040
+ position_ids: Optional[torch.LongTensor] = None,
1041
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1042
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1043
+ labels: Optional[torch.LongTensor] = None,
1044
+ use_cache: Optional[bool] = None,
1045
+ output_attentions: Optional[bool] = None,
1046
+ output_hidden_states: Optional[bool] = None,
1047
+ return_dict: Optional[bool] = None,
1048
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1049
+ r"""
1050
+ Args:
1051
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1052
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1053
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1054
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1055
+
1056
+ Returns:
1057
+
1058
+ Example:
1059
+
1060
+ ```python
1061
+ >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
1062
+
1063
+ >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1064
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1065
+
1066
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1067
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1068
+
1069
+ >>> # Generate
1070
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1071
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1072
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1073
+ ```"""
1074
+
1075
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1076
+ output_hidden_states = (
1077
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1078
+ )
1079
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1080
+
1081
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1082
+ outputs = self.model(
1083
+ input_ids=input_ids,
1084
+ attention_mask=attention_mask,
1085
+ position_ids=position_ids,
1086
+ past_key_values=past_key_values,
1087
+ inputs_embeds=inputs_embeds,
1088
+ use_cache=use_cache,
1089
+ output_attentions=output_attentions,
1090
+ output_hidden_states=output_hidden_states,
1091
+ return_dict=return_dict,
1092
+ )
1093
+
1094
+ hidden_states = outputs[0]
1095
+ logits = self.output(hidden_states)
1096
+ logits = logits.float()
1097
+
1098
+ loss = None
1099
+ if labels is not None:
1100
+ # Shift so that tokens < n predict n
1101
+ shift_logits = logits[..., :-1, :].contiguous()
1102
+ shift_labels = labels[..., 1:].contiguous()
1103
+ # Flatten the tokens
1104
+ loss_fct = CrossEntropyLoss()
1105
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1106
+ shift_labels = shift_labels.view(-1)
1107
+ # Enable model parallelism
1108
+ shift_labels = shift_labels.to(shift_logits.device)
1109
+ loss = loss_fct(shift_logits, shift_labels)
1110
+
1111
+ if not return_dict:
1112
+ output = (logits,) + outputs[1:]
1113
+ return (loss,) + output if loss is not None else output
1114
+
1115
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1116
+ output = CausalLMOutputWithPast(
1117
+ loss=loss,
1118
+ logits=logits,
1119
+ past_key_values=outputs.past_key_values,
1120
+ hidden_states=outputs.hidden_states,
1121
+ attentions=outputs.attentions,
1122
+ )
1123
+ output['logits'] = output['logits'].to(device)
1124
+ return output
1125
+
1126
+ def prepare_inputs_for_generation(
1127
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1128
+ ):
1129
+ if past_key_values is not None:
1130
+ past_length = past_key_values[0][0].shape[2]
1131
+
1132
+ # Some generation methods already pass only the last input ID
1133
+ if input_ids.shape[1] > past_length:
1134
+ remove_prefix_length = past_length
1135
+ else:
1136
+ # Default to old behavior: keep only final ID
1137
+ remove_prefix_length = input_ids.shape[1] - 1
1138
+
1139
+ input_ids = input_ids[:, remove_prefix_length:]
1140
+
1141
+ position_ids = kwargs.get('position_ids', None)
1142
+ if attention_mask is not None and position_ids is None:
1143
+ # create position_ids on the fly for batch generation
1144
+ position_ids = attention_mask.long().cumsum(-1) - 1
1145
+ position_ids.masked_fill_(attention_mask == 0, 1)
1146
+ if past_key_values:
1147
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1148
+
1149
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1150
+ if inputs_embeds is not None and past_key_values is None:
1151
+ model_inputs = {'inputs_embeds': inputs_embeds}
1152
+ else:
1153
+ model_inputs = {'input_ids': input_ids}
1154
+
1155
+ model_inputs.update(
1156
+ {
1157
+ 'position_ids': position_ids,
1158
+ 'past_key_values': past_key_values,
1159
+ 'use_cache': kwargs.get('use_cache'),
1160
+ 'attention_mask': attention_mask,
1161
+ }
1162
+ )
1163
+ return model_inputs
1164
+
1165
+ @staticmethod
1166
+ def _reorder_cache(past_key_values, beam_idx):
1167
+ reordered_past = ()
1168
+ for layer_past in past_key_values:
1169
+ reordered_past += (
1170
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1171
+ )
1172
+ return reordered_past
1173
+
1174
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=''):
1175
+ if tokenizer.add_bos_token:
1176
+ prompt = ''
1177
+ else:
1178
+ prompt = tokenizer.bos_token
1179
+ if meta_instruction:
1180
+ prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
1181
+ for record in history:
1182
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
1183
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
1184
+ return tokenizer([prompt], return_tensors='pt')
1185
+
1186
+ @torch.no_grad()
1187
+ def chat(
1188
+ self,
1189
+ tokenizer,
1190
+ query: str,
1191
+ history: List[Tuple[str, str]] = [],
1192
+ streamer: Optional[BaseStreamer] = None,
1193
+ max_new_tokens: int = 1024,
1194
+ do_sample: bool = True,
1195
+ temperature: float = 0.8,
1196
+ top_p: float = 0.8,
1197
+ meta_instruction: str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n'
1198
+ '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n'
1199
+ '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.',
1200
+ **kwargs,
1201
+ ):
1202
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1203
+ inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
1204
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
1205
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0]]
1206
+ outputs = self.generate(
1207
+ **inputs,
1208
+ streamer=streamer,
1209
+ max_new_tokens=max_new_tokens,
1210
+ do_sample=do_sample,
1211
+ temperature=temperature,
1212
+ top_p=top_p,
1213
+ eos_token_id=eos_token_id,
1214
+ **kwargs,
1215
+ )
1216
+ outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
1217
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
1218
+ response = response.split('<|im_end|>')[0]
1219
+ history = history + [(query, response)]
1220
+ return response, history
1221
+
1222
+ @torch.no_grad()
1223
+ def stream_chat(
1224
+ self,
1225
+ tokenizer,
1226
+ query: str,
1227
+ history: List[Tuple[str, str]] = [],
1228
+ max_new_tokens: int = 1024,
1229
+ do_sample: bool = True,
1230
+ temperature: float = 0.8,
1231
+ top_p: float = 0.8,
1232
+ **kwargs,
1233
+ ):
1234
+ """
1235
+ Return a generator in format: (response, history)
1236
+ Eg.
1237
+ ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
1238
+ ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
1239
+ """
1240
+ if BaseStreamer is None:
1241
+ raise ModuleNotFoundError(
1242
+ 'The version of `transformers` is too low. Please make sure '
1243
+ 'that you have installed `transformers>=4.28.0`.'
1244
+ )
1245
+
1246
+ response_queue = queue.Queue(maxsize=20)
1247
+
1248
+ class ChatStreamer(BaseStreamer):
1249
+ def __init__(self, tokenizer) -> None:
1250
+ super().__init__()
1251
+ self.tokenizer = tokenizer
1252
+ self.queue = response_queue
1253
+ self.query = query
1254
+ self.history = history
1255
+ self.response = ''
1256
+ self.cache = []
1257
+ self.received_inputs = False
1258
+ self.queue.put((self.response, history + [(self.query, self.response)]))
1259
+
1260
+ def put(self, value):
1261
+ if len(value.shape) > 1 and value.shape[0] > 1:
1262
+ raise ValueError('ChatStreamer only supports batch size 1')
1263
+ elif len(value.shape) > 1:
1264
+ value = value[0]
1265
+
1266
+ if not self.received_inputs:
1267
+ # The first received value is input_ids, ignore here
1268
+ self.received_inputs = True
1269
+ return
1270
+
1271
+ self.cache.extend(value.tolist())
1272
+ token = self.tokenizer.decode(self.cache, skip_special_tokens=True)
1273
+ if token.strip() != '<|im_end|>':
1274
+ self.response = self.response + token
1275
+ history = self.history + [(self.query, self.response)]
1276
+ self.queue.put((self.response, history))
1277
+ self.cache = []
1278
+ else:
1279
+ self.end()
1280
+
1281
+ def end(self):
1282
+ self.queue.put(None)
1283
+
1284
+ def stream_producer():
1285
+ return self.chat(
1286
+ tokenizer=tokenizer,
1287
+ query=query,
1288
+ streamer=ChatStreamer(tokenizer=tokenizer),
1289
+ history=history,
1290
+ max_new_tokens=max_new_tokens,
1291
+ do_sample=do_sample,
1292
+ temperature=temperature,
1293
+ top_p=top_p,
1294
+ **kwargs,
1295
+ )
1296
+
1297
+ def consumer():
1298
+ producer = threading.Thread(target=stream_producer)
1299
+ producer.start()
1300
+ while True:
1301
+ res = response_queue.get()
1302
+ if res is None:
1303
+ return
1304
+ yield res
1305
+
1306
+ return consumer()
1307
+
1308
+
1309
+ # Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1310
+ @add_start_docstrings(
1311
+ """
1312
+ The InternLM2 Model transformer with a sequence classification head on top (linear layer).
1313
+
1314
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,
1315
+ as other causal models (e.g. GPT-2) do.
1316
+
1317
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1318
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1319
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1320
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1321
+ each row of the batch).
1322
+ """,
1323
+ InternLM2_START_DOCSTRING,
1324
+ )
1325
+ class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
1326
+ def __init__(self, config):
1327
+ super().__init__(config)
1328
+ self.num_labels = config.num_labels
1329
+ self.model = InternLM2Model(config)
1330
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1331
+
1332
+ # Initialize weights and apply final processing
1333
+ self.post_init()
1334
+
1335
+ def get_input_embeddings(self):
1336
+ return self.model.tok_embeddings
1337
+
1338
+ def set_input_embeddings(self, value):
1339
+ self.model.tok_embeddings = value
1340
+
1341
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1342
+ def forward(
1343
+ self,
1344
+ input_ids: torch.LongTensor = None,
1345
+ attention_mask: Optional[torch.Tensor] = None,
1346
+ position_ids: Optional[torch.LongTensor] = None,
1347
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1348
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1349
+ labels: Optional[torch.LongTensor] = None,
1350
+ use_cache: Optional[bool] = None,
1351
+ output_attentions: Optional[bool] = None,
1352
+ output_hidden_states: Optional[bool] = None,
1353
+ return_dict: Optional[bool] = None,
1354
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1355
+ r"""
1356
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1357
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1358
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1359
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1360
+ """
1361
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1362
+
1363
+ transformer_outputs = self.model(
1364
+ input_ids,
1365
+ attention_mask=attention_mask,
1366
+ position_ids=position_ids,
1367
+ past_key_values=past_key_values,
1368
+ inputs_embeds=inputs_embeds,
1369
+ use_cache=use_cache,
1370
+ output_attentions=output_attentions,
1371
+ output_hidden_states=output_hidden_states,
1372
+ return_dict=return_dict,
1373
+ )
1374
+ hidden_states = transformer_outputs[0]
1375
+ logits = self.score(hidden_states)
1376
+
1377
+ if input_ids is not None:
1378
+ batch_size = input_ids.shape[0]
1379
+ else:
1380
+ batch_size = inputs_embeds.shape[0]
1381
+
1382
+ if self.config.pad_token_id is None and batch_size != 1:
1383
+ raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
1384
+ if self.config.pad_token_id is None:
1385
+ sequence_lengths = -1
1386
+ else:
1387
+ if input_ids is not None:
1388
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1389
+ logits.device
1390
+ )
1391
+ else:
1392
+ sequence_lengths = -1
1393
+
1394
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1395
+
1396
+ loss = None
1397
+ if labels is not None:
1398
+ labels = labels.to(logits.device)
1399
+ if self.config.problem_type is None:
1400
+ if self.num_labels == 1:
1401
+ self.config.problem_type = 'regression'
1402
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1403
+ self.config.problem_type = 'single_label_classification'
1404
+ else:
1405
+ self.config.problem_type = 'multi_label_classification'
1406
+
1407
+ if self.config.problem_type == 'regression':
1408
+ loss_fct = MSELoss()
1409
+ if self.num_labels == 1:
1410
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1411
+ else:
1412
+ loss = loss_fct(pooled_logits, labels)
1413
+ elif self.config.problem_type == 'single_label_classification':
1414
+ loss_fct = CrossEntropyLoss()
1415
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1416
+ elif self.config.problem_type == 'multi_label_classification':
1417
+ loss_fct = BCEWithLogitsLoss()
1418
+ loss = loss_fct(pooled_logits, labels)
1419
+ if not return_dict:
1420
+ output = (pooled_logits,) + transformer_outputs[1:]
1421
+ return ((loss,) + output) if loss is not None else output
1422
+
1423
+ return SequenceClassifierOutputWithPast(
1424
+ loss=loss,
1425
+ logits=pooled_logits,
1426
+ past_key_values=transformer_outputs.past_key_values,
1427
+ hidden_states=transformer_outputs.hidden_states,
1428
+ attentions=transformer_outputs.attentions,
1429
+ )
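
The `InternLM2ForCausalLM.chat()` helper defined above wraps prompt construction (`build_inputs`) and `generate()` into a single call. A minimal usage sketch follows; the checkpoint name is a placeholder, and loading with `trust_remote_code=True` is assumed since this file is shipped as remote code rather than as a built-in `transformers` model.

```python
# Minimal sketch of driving InternLM2ForCausalLM.chat(); the checkpoint path is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = 'internlm/internlm2-chat-7b'  # assumed checkpoint; any repo shipping this modeling file works
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    path, torch_dtype=torch.bfloat16, trust_remote_code=True
).cuda().eval()

# chat() builds the <|im_start|>/<|im_end|> prompt, runs generate(), strips the
# prompt tokens, and returns (response, updated_history) as defined above.
response, history = model.chat(tokenizer, 'Hello, who are you?', history=[])
print(response)
```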
src/third_party/InternVL/internvl_chat/internvl/model/internlm2/tokenization_internlm2.py ADDED
@@ -0,0 +1,235 @@
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """Tokenization classes for InternLM."""
18
+ import os
19
+ from shutil import copyfile
20
+ from typing import Any, Dict, List, Optional, Tuple
21
+
22
+ import sentencepiece as spm
23
+ from transformers.tokenization_utils import PreTrainedTokenizer
24
+ from transformers.utils import logging
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'}
29
+
30
+ PRETRAINED_VOCAB_FILES_MAP = {}
31
+
32
+
33
+ # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
34
+ class InternLM2Tokenizer(PreTrainedTokenizer):
35
+ """
36
+ Construct an InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
37
+
38
+ Args:
39
+ vocab_file (`str`):
40
+ Path to the vocabulary file.
41
+ """
42
+
43
+ vocab_files_names = VOCAB_FILES_NAMES
44
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
45
+ model_input_names = ['input_ids', 'attention_mask']
46
+ _auto_class = 'AutoTokenizer'
47
+
48
+ def __init__(
49
+ self,
50
+ vocab_file,
51
+ unk_token='<unk>',
52
+ bos_token='<s>',
53
+ eos_token='</s>',
54
+ pad_token='</s>',
55
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
56
+ add_bos_token=True,
57
+ add_eos_token=False,
58
+ decode_with_prefix_space=False,
59
+ clean_up_tokenization_spaces=False,
60
+ **kwargs,
61
+ ):
62
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
63
+ self.vocab_file = vocab_file
64
+ self.add_bos_token = add_bos_token
65
+ self.add_eos_token = add_eos_token
66
+ self.decode_with_prefix_space = decode_with_prefix_space
67
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
68
+ self.sp_model.Load(vocab_file)
69
+ self._no_prefix_space_tokens = None
70
+ super().__init__(
71
+ bos_token=bos_token,
72
+ eos_token=eos_token,
73
+ unk_token=unk_token,
74
+ pad_token=pad_token,
75
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
76
+ **kwargs,
77
+ )
78
+
79
+ @property
80
+ def no_prefix_space_tokens(self):
81
+ if self._no_prefix_space_tokens is None:
82
+ vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
83
+ self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith('▁')}
84
+ return self._no_prefix_space_tokens
85
+
86
+ @property
87
+ def vocab_size(self):
88
+ """Returns vocab size"""
89
+ return self.sp_model.get_piece_size()
90
+
91
+ @property
92
+ def bos_token_id(self) -> Optional[int]:
93
+ return self.sp_model.bos_id()
94
+
95
+ @property
96
+ def eos_token_id(self) -> Optional[int]:
97
+ return self.sp_model.eos_id()
98
+
99
+ def get_vocab(self):
100
+ """Returns vocab as a dict"""
101
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
102
+ vocab.update(self.added_tokens_encoder)
103
+ return vocab
104
+
105
+ def _tokenize(self, text):
106
+ """Returns a tokenized string."""
107
+ return self.sp_model.encode(text, out_type=str)
108
+
109
+ def _convert_token_to_id(self, token):
110
+ """Converts a token (str) in an id using the vocab."""
111
+ return self.sp_model.piece_to_id(token)
112
+
113
+ def _convert_id_to_token(self, index):
114
+ """Converts an index (integer) in a token (str) using the vocab."""
115
+ token = self.sp_model.IdToPiece(index)
116
+ return token
117
+
118
+ def _maybe_add_prefix_space(self, tokens, decoded):
119
+ if tokens and tokens[0] not in self.no_prefix_space_tokens:
120
+ return ' ' + decoded
121
+ else:
122
+ return decoded
123
+
124
+ def convert_tokens_to_string(self, tokens):
125
+ """Converts a sequence of tokens (string) in a single string."""
126
+ current_sub_tokens = []
127
+ out_string = ''
128
+ prev_is_special = False
129
+ for token in tokens:
130
+ # make sure that special tokens are not decoded using sentencepiece model
131
+ if token in self.all_special_tokens:
132
+ if not prev_is_special:
133
+ out_string += ' '
134
+ out_string += self.sp_model.decode(current_sub_tokens) + token
135
+ prev_is_special = True
136
+ current_sub_tokens = []
137
+ else:
138
+ current_sub_tokens.append(token)
139
+ prev_is_special = False
140
+ out_string += self.sp_model.decode(current_sub_tokens)
141
+ out_string = self.clean_up_tokenization(out_string)
142
+ out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
143
+ return out_string[1:]
144
+
145
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
146
+ """
147
+ Save the vocabulary and special tokens file to a directory.
148
+
149
+ Args:
150
+ save_directory (`str`):
151
+ The directory in which to save the vocabulary.
152
+
153
+ Returns:
154
+ `Tuple(str)`: Paths to the files saved.
155
+ """
156
+ if not os.path.isdir(save_directory):
157
+ logger.error(f'Vocabulary path ({save_directory}) should be a directory')
158
+ return
159
+ out_vocab_file = os.path.join(
160
+ save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
161
+ )
162
+
163
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
164
+ copyfile(self.vocab_file, out_vocab_file)
165
+ elif not os.path.isfile(self.vocab_file):
166
+ with open(out_vocab_file, 'wb') as fi:
167
+ content_spiece_model = self.sp_model.serialized_model_proto()
168
+ fi.write(content_spiece_model)
169
+
170
+ return (out_vocab_file,)
171
+
172
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
173
+ if self.add_bos_token:
174
+ bos_token_ids = [self.bos_token_id]
175
+ else:
176
+ bos_token_ids = []
177
+
178
+ output = bos_token_ids + token_ids_0
179
+
180
+ if token_ids_1 is not None:
181
+ output = output + token_ids_1
182
+
183
+ if self.add_eos_token:
184
+ output = output + [self.eos_token_id]
185
+
186
+ return output
187
+
188
+ def get_special_tokens_mask(
189
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
190
+ ) -> List[int]:
191
+ """
192
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
193
+ special tokens using the tokenizer `prepare_for_model` method.
194
+
195
+ Args:
196
+ token_ids_0 (`List[int]`):
197
+ List of IDs.
198
+ token_ids_1 (`List[int]`, *optional*):
199
+ Optional second list of IDs for sequence pairs.
200
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
201
+ Whether or not the token list is already formatted with special tokens for the model.
202
+
203
+ Returns:
204
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
205
+ """
206
+ if already_has_special_tokens:
207
+ return super().get_special_tokens_mask(
208
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
209
+ )
210
+
211
+ if token_ids_1 is None:
212
+ return [1] + ([0] * len(token_ids_0)) + [1]
213
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
214
+
215
+ def create_token_type_ids_from_sequences(
216
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
217
+ ) -> List[int]:
218
+ """
219
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. InternLM2 does not make
220
+ use of token type ids, therefore a list of zeros is returned.
221
+
222
+ Args:
223
+ token_ids_0 (`List[int]`):
224
+ List of IDs.
225
+ token_ids_1 (`List[int]`, *optional*):
226
+ Optional second list of IDs for sequence pairs.
227
+
228
+ Returns:
229
+ `List[int]`: List of zeros.
230
+ """
231
+ eos = [self.eos_token_id]
232
+
233
+ if token_ids_1 is None:
234
+ return len(token_ids_0 + eos) * [0]
235
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
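A minimal usage sketch for the slow tokenizer defined above. The `./tokenizer.model` path is a placeholder for a real InternLM2 SentencePiece file; `encode` and `decode` come from the standard `PreTrainedTokenizer` interface.

```python
# Sketch only: assumes sentencepiece is installed and ./tokenizer.model exists locally.
from internvl.model.internlm2.tokenization_internlm2 import InternLM2Tokenizer

tokenizer = InternLM2Tokenizer(vocab_file='./tokenizer.model')

# add_bos_token=True and add_eos_token=False by default, so only <s> is prepended.
ids = tokenizer.encode('Hello, InternLM2!')
text = tokenizer.decode(ids, skip_special_tokens=True)
print(ids[:3], text)
```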
src/third_party/InternVL/internvl_chat/internvl/model/internlm2/tokenization_internlm2_fast.py ADDED
@@ -0,0 +1,211 @@
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """Tokenization Fast class for InternLM."""
18
+ import os
19
+ from shutil import copyfile
20
+ from typing import Any, Dict, Optional, Tuple
21
+
22
+ from tokenizers import Tokenizer, decoders, normalizers, processors
23
+ from tokenizers.models import BPE
24
+ from transformers.convert_slow_tokenizer import (SLOW_TO_FAST_CONVERTERS,
25
+ SentencePieceExtractor,
26
+ SpmConverter)
27
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
28
+ from transformers.utils import logging
29
+
30
+ from .tokenization_internlm2 import InternLM2Tokenizer
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'}
35
+
36
+
37
+ # Modified from transformers.convert_slow_tokenizer.LlamaConverter
38
+ class InternLM2Converter(SpmConverter):
39
+ handle_byte_fallback = True
40
+
41
+ def vocab(self, proto):
42
+ vocab = [
43
+ ('<unk>', 0.0),
44
+ ('<s>', 0.0),
45
+ ('</s>', 0.0),
46
+ ]
47
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
48
+ return vocab
49
+
50
+ def unk_id(self, proto):
51
+ unk_id = 0
52
+ return unk_id
53
+
54
+ def decoder(self, replacement, add_prefix_space):
55
+ return decoders.Sequence(
56
+ [
57
+ decoders.Replace('▁', ' '),
58
+ decoders.ByteFallback(),
59
+ decoders.Fuse(),
60
+ decoders.Strip(content=' ', left=1),
61
+ ]
62
+ )
63
+
64
+ def tokenizer(self, proto):
65
+ model_type = proto.trainer_spec.model_type
66
+ vocab_scores = self.vocab(proto)
67
+ # special tokens
68
+ added_tokens = self.original_tokenizer.added_tokens_decoder
69
+ for i in range(len(vocab_scores)):
70
+ piece, score = vocab_scores[i]
71
+ if i in added_tokens:
72
+ vocab_scores[i] = (added_tokens[i].content, score)
73
+ if model_type == 1:
74
+ raise RuntimeError('InternLM2 is supposed to be a BPE model!')
75
+
76
+ elif model_type == 2:
77
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
78
+ bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
79
+ tokenizer = Tokenizer(
80
+ BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
81
+ )
82
+ tokenizer.add_special_tokens(
83
+ [ added_token for index, added_token in added_tokens.items()]
84
+ )
85
+ else:
86
+ raise Exception(
87
+ "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
88
+ )
89
+
90
+ return tokenizer
91
+
92
+ def normalizer(self, proto):
93
+ normalizers_list = []
94
+ if proto.normalizer_spec.add_dummy_prefix:
95
+ normalizers_list.append(normalizers.Prepend(prepend='▁'))
96
+ normalizers_list.append(normalizers.Replace(pattern=' ', content='▁'))
97
+ return normalizers.Sequence(normalizers_list)
98
+
99
+ def pre_tokenizer(self, replacement, add_prefix_space):
100
+ return None
101
+
102
+
103
+ SLOW_TO_FAST_CONVERTERS['InternLM2Tokenizer'] = InternLM2Converter
104
+
105
+
106
+ # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
107
+ class InternLM2TokenizerFast(PreTrainedTokenizerFast):
108
+ vocab_files_names = VOCAB_FILES_NAMES
109
+ slow_tokenizer_class = InternLM2Tokenizer
110
+ padding_side = 'left'
111
+ model_input_names = ['input_ids', 'attention_mask']
112
+ _auto_class = 'AutoTokenizer'
113
+
114
+ def __init__(
115
+ self,
116
+ vocab_file,
117
+ unk_token='<unk>',
118
+ bos_token='<s>',
119
+ eos_token='</s>',
120
+ pad_token='</s>',
121
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
122
+ add_bos_token=True,
123
+ add_eos_token=False,
124
+ decode_with_prefix_space=False,
125
+ clean_up_tokenization_spaces=False,
126
+ **kwargs,
127
+ ):
128
+ super().__init__(
129
+ vocab_file=vocab_file,
130
+ unk_token=unk_token,
131
+ bos_token=bos_token,
132
+ eos_token=eos_token,
133
+ pad_token=pad_token,
134
+ sp_model_kwargs=sp_model_kwargs,
135
+ add_bos_token=add_bos_token,
136
+ add_eos_token=add_eos_token,
137
+ decode_with_prefix_space=decode_with_prefix_space,
138
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
139
+ **kwargs,
140
+ )
141
+ self._add_bos_token = add_bos_token
142
+ self._add_eos_token = add_eos_token
143
+ self.update_post_processor()
144
+ self.vocab_file = vocab_file
145
+
146
+ @property
147
+ def can_save_slow_tokenizer(self) -> bool:
148
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
149
+
150
+ def update_post_processor(self):
151
+ """
152
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
153
+ """
154
+ bos = self.bos_token
155
+ bos_token_id = self.bos_token_id
156
+ if bos is None and self.add_bos_token:
157
+ raise ValueError('add_bos_token = True but bos_token = None')
158
+
159
+ eos = self.eos_token
160
+ eos_token_id = self.eos_token_id
161
+ if eos is None and self.add_eos_token:
162
+ raise ValueError('add_eos_token = True but eos_token = None')
163
+
164
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
165
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
166
+
167
+ special_tokens = []
168
+ if self.add_bos_token:
169
+ special_tokens.append((bos, bos_token_id))
170
+ if self.add_eos_token:
171
+ special_tokens.append((eos, eos_token_id))
172
+ self._tokenizer.post_processor = processors.TemplateProcessing(
173
+ single=single, pair=pair, special_tokens=special_tokens
174
+ )
175
+
176
+ @property
177
+ def add_eos_token(self):
178
+ return self._add_eos_token
179
+
180
+ @property
181
+ def add_bos_token(self):
182
+ return self._add_bos_token
183
+
184
+ @add_eos_token.setter
185
+ def add_eos_token(self, value):
186
+ self._add_eos_token = value
187
+ self.update_post_processor()
188
+
189
+ @add_bos_token.setter
190
+ def add_bos_token(self, value):
191
+ self._add_bos_token = value
192
+ self.update_post_processor()
193
+
194
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
195
+ if not self.can_save_slow_tokenizer:
196
+ raise ValueError(
197
+ 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
198
+ 'tokenizer.'
199
+ )
200
+
201
+ if not os.path.isdir(save_directory):
202
+ logger.error(f'Vocabulary path ({save_directory}) should be a directory')
203
+ return
204
+ out_vocab_file = os.path.join(
205
+ save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
206
+ )
207
+
208
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
209
+ copyfile(self.vocab_file, out_vocab_file)
210
+
211
+ return (out_vocab_file,)
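A short sketch of the fast tokenizer above, showing how the `add_bos_token` / `add_eos_token` setters rebuild the `TemplateProcessing` post-processor via `update_post_processor`. The vocabulary path is again a placeholder.

```python
# Sketch only: ./tokenizer.model is a placeholder for a real InternLM2 SentencePiece file.
from internvl.model.internlm2.tokenization_internlm2_fast import InternLM2TokenizerFast

tok = InternLM2TokenizerFast(vocab_file='./tokenizer.model')  # BOS on, EOS off by default
print(tok('hi')['input_ids'])        # starts with the <s> id

tok.add_eos_token = True             # setter triggers update_post_processor()
print(tok('hi')['input_ids'])        # now also ends with the </s> id
```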
src/third_party/InternVL/internvl_chat/internvl/model/internvl_chat/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ from .configuration_intern_vit import InternVisionConfig
8
+ from .configuration_internvl_chat import InternVLChatConfig
9
+ from .modeling_intern_vit import InternVisionModel
10
+ from .modeling_internvl_chat import InternVLChatModel
11
+
12
+ __all__ = ['InternVisionConfig', 'InternVisionModel',
13
+ 'InternVLChatConfig', 'InternVLChatModel']
src/third_party/InternVL/internvl_chat/internvl/model/internvl_chat/configuration_intern_vit.py ADDED
@@ -0,0 +1,120 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import os
8
+ from typing import Union
9
+
10
+ from transformers.configuration_utils import PretrainedConfig
11
+ from transformers.utils import logging
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+
16
+ class InternVisionConfig(PretrainedConfig):
17
+ r"""
18
+ This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
19
+ instantiate a vision encoder according to the specified arguments, defining the model architecture.
20
+
21
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22
+ documentation from [`PretrainedConfig`] for more information.
23
+
24
+ Args:
25
+ num_channels (`int`, *optional*, defaults to 3):
26
+ Number of color channels in the input images (e.g., 3 for RGB).
27
+ patch_size (`int`, *optional*, defaults to 14):
28
+ The size (resolution) of each patch.
29
+ image_size (`int`, *optional*, defaults to 224):
30
+ The size (resolution) of each image.
31
+ qkv_bias (`bool`, *optional*, defaults to `False`):
32
+ Whether to add a bias to the queries, keys and values in the self-attention layers.
33
+ hidden_size (`int`, *optional*, defaults to 3200):
34
+ Dimensionality of the encoder layers and the pooler layer.
35
+ num_attention_heads (`int`, *optional*, defaults to 25):
36
+ Number of attention heads for each attention layer in the Transformer encoder.
37
+ intermediate_size (`int`, *optional*, defaults to 12800):
38
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
39
+ qk_normalization (`bool`, *optional*, defaults to `True`):
40
+ Whether to normalize the queries and keys in the self-attention layers.
41
+ num_hidden_layers (`int`, *optional*, defaults to 48):
42
+ Number of hidden layers in the Transformer encoder.
43
+ use_flash_attn (`bool`, *optional*, defaults to `True`):
44
+ Whether to use flash attention mechanism.
45
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
46
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
47
+ `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
48
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
49
+ The epsilon used by the layer normalization layers.
50
+ dropout (`float`, *optional*, defaults to 0.0):
51
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
52
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
53
+ Dropout rate for stochastic depth.
54
+ attention_dropout (`float`, *optional*, defaults to 0.0):
55
+ The dropout ratio for the attention probabilities.
56
+ initializer_range (`float`, *optional*, defaults to 0.02):
57
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
58
+ initializer_factor (`float`, *optional*, defaults to 0.1):
59
+ A factor for layer scale.
60
+ """
61
+
62
+ model_type = 'intern_vit_6b'
63
+
64
+ def __init__(
65
+ self,
66
+ num_channels=3,
67
+ patch_size=14,
68
+ image_size=224,
69
+ qkv_bias=False,
70
+ hidden_size=3200,
71
+ num_attention_heads=25,
72
+ intermediate_size=12800,
73
+ qk_normalization=True,
74
+ num_hidden_layers=48,
75
+ use_flash_attn=True,
76
+ hidden_act='gelu',
77
+ norm_type='rms_norm',
78
+ layer_norm_eps=1e-6,
79
+ dropout=0.0,
80
+ drop_path_rate=0.0,
81
+ attention_dropout=0.0,
82
+ initializer_range=0.02,
83
+ initializer_factor=0.1,
84
+ **kwargs,
85
+ ):
86
+ super().__init__(**kwargs)
87
+
88
+ self.hidden_size = hidden_size
89
+ self.intermediate_size = intermediate_size
90
+ self.dropout = dropout
91
+ self.drop_path_rate = drop_path_rate
92
+ self.num_hidden_layers = num_hidden_layers
93
+ self.num_attention_heads = num_attention_heads
94
+ self.num_channels = num_channels
95
+ self.patch_size = patch_size
96
+ self.image_size = image_size
97
+ self.initializer_range = initializer_range
98
+ self.initializer_factor = initializer_factor
99
+ self.attention_dropout = attention_dropout
100
+ self.layer_norm_eps = layer_norm_eps
101
+ self.hidden_act = hidden_act
102
+ self.norm_type = norm_type
103
+ self.qkv_bias = qkv_bias
104
+ self.qk_normalization = qk_normalization
105
+ self.use_flash_attn = use_flash_attn
106
+
107
+ @classmethod
108
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
109
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
110
+
111
+ if 'vision_config' in config_dict:
112
+ config_dict = config_dict['vision_config']
113
+
114
+ if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
115
+ logger.warning(
116
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
117
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
118
+ )
119
+
120
+ return cls.from_dict(config_dict, **kwargs)
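A brief sketch of constructing the vision configuration documented above. The defaults describe the 6B InternViT tower; the override values below are purely illustrative.

```python
from internvl.model.internvl_chat.configuration_intern_vit import InternVisionConfig

cfg = InternVisionConfig()   # defaults: 48 layers, hidden_size 3200, 224px images, 14px patches

tiny_cfg = InternVisionConfig(image_size=448, patch_size=14,
                              num_hidden_layers=2, hidden_size=64,
                              num_attention_heads=4, intermediate_size=256,
                              use_flash_attn=False)
print(tiny_cfg.image_size // tiny_cfg.patch_size)  # 32 patches per side
```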
src/third_party/InternVL/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py ADDED
@@ -0,0 +1,109 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import copy
8
+
9
+ from internvl.model.internlm2.configuration_internlm2 import InternLM2Config
10
+ from internvl.model.phi3.configuration_phi3 import Phi3Config
11
+ from transformers import AutoConfig, LlamaConfig, Qwen2Config
12
+ from transformers.configuration_utils import PretrainedConfig
13
+ from transformers.utils import logging
14
+
15
+ from .configuration_intern_vit import InternVisionConfig
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+
20
+ class InternVLChatConfig(PretrainedConfig):
21
+ model_type = 'internvl_chat'
22
+ is_composition = True
23
+
24
+ def __init__(
25
+ self,
26
+ vision_config=None,
27
+ llm_config=None,
28
+ use_backbone_lora=0,
29
+ use_llm_lora=0,
30
+ pad2square=False,
31
+ select_layer=-1,
32
+ force_image_size=None,
33
+ downsample_ratio=0.5,
34
+ template=None,
35
+ dynamic_image_size=False,
36
+ use_thumbnail=False,
37
+ ps_version='v1',
38
+ min_dynamic_patch=1,
39
+ max_dynamic_patch=6,
40
+ **kwargs):
41
+ super().__init__(**kwargs)
42
+
43
+ if vision_config is None:
44
+ vision_config = {'architectures': ['InternVisionModel']}
45
+ logger.info('vision_config is None. Initializing the InternVisionConfig with default values.')
46
+
47
+ if llm_config is None:
48
+ # TODO: There might still be a bug in transformers version 4.44 and above.
49
+ llm_config = {'architectures': ['']}
50
+ logger.info('llm_config is None. Initializing the llm_config with default values (`LlamaConfig`).')
51
+
52
+ self.vision_config = InternVisionConfig(**vision_config)
53
+ if llm_config['architectures'][0] == 'LlamaForCausalLM':
54
+ self.llm_config = LlamaConfig(**llm_config)
55
+ elif llm_config['architectures'][0] == 'InternLM2ForCausalLM':
56
+ self.llm_config = InternLM2Config(**llm_config)
57
+ elif llm_config['architectures'][0] == 'Phi3ForCausalLM':
58
+ self.llm_config = Phi3Config(**llm_config)
59
+ elif llm_config['architectures'][0] == 'Qwen2ForCausalLM':
60
+ self.llm_config = Qwen2Config(**llm_config)
61
+ else:
62
+ raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0]))
63
+ self.use_backbone_lora = use_backbone_lora
64
+ self.use_llm_lora = use_llm_lora
65
+ self.pad2square = pad2square
66
+ self.select_layer = select_layer
67
+ self.force_image_size = force_image_size
68
+ self.downsample_ratio = downsample_ratio
69
+ self.template = template
70
+ self.dynamic_image_size = dynamic_image_size
71
+ self.use_thumbnail = use_thumbnail
72
+ self.ps_version = ps_version # pixel shuffle version
73
+ self.min_dynamic_patch = min_dynamic_patch
74
+ self.max_dynamic_patch = max_dynamic_patch
75
+
76
+ self.hidden_size = self.llm_config.hidden_size
77
+ # By default, we use tie_word_embeddings=False for models of all sizes.
78
+ self.tie_word_embeddings = False
79
+ self.llm_config.tie_word_embeddings = self.tie_word_embeddings
80
+
81
+ logger.info(f'vision_select_layer: {self.select_layer}')
82
+ logger.info(f'ps_version: {self.ps_version}')
83
+ logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
84
+ logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')
85
+
86
+ def to_dict(self):
87
+ """
88
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
89
+
90
+ Returns:
91
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
92
+ """
93
+ output = copy.deepcopy(self.__dict__)
94
+ output['vision_config'] = self.vision_config.to_dict()
95
+ output['llm_config'] = self.llm_config.to_dict()
96
+ output['model_type'] = self.__class__.model_type
97
+ output['use_backbone_lora'] = self.use_backbone_lora
98
+ output['use_llm_lora'] = self.use_llm_lora
99
+ output['select_layer'] = self.select_layer
100
+ output['force_image_size'] = self.force_image_size
101
+ output['downsample_ratio'] = self.downsample_ratio
102
+ output['template'] = self.template
103
+ output['dynamic_image_size'] = self.dynamic_image_size
104
+ output['use_thumbnail'] = self.use_thumbnail
105
+ output['ps_version'] = self.ps_version
106
+ output['min_dynamic_patch'] = self.min_dynamic_patch
107
+ output['max_dynamic_patch'] = self.max_dynamic_patch
108
+
109
+ return output
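A hedged sketch of how `InternVLChatConfig` above dispatches on `llm_config['architectures'][0]`. The nested dicts are minimal placeholders rather than a real checkpoint configuration, and the template name is illustrative.

```python
from internvl.model.internvl_chat.configuration_internvl_chat import InternVLChatConfig

config = InternVLChatConfig(
    vision_config={'architectures': ['InternVisionModel'], 'image_size': 448, 'patch_size': 14},
    llm_config={'architectures': ['LlamaForCausalLM'], 'hidden_size': 4096},  # -> LlamaConfig branch
    template='internlm2-chat',        # illustrative template name
    dynamic_image_size=True,
    max_dynamic_patch=12,
)
print(type(config.llm_config).__name__)        # LlamaConfig
print(config.to_dict()['max_dynamic_patch'])   # 12
```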
src/third_party/InternVL/internvl_chat/internvl/model/internvl_chat/modeling_intern_vit.py ADDED
@@ -0,0 +1,450 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from einops import rearrange
13
+ from timm.models.layers import DropPath
14
+ from torch import nn
15
+ from transformers.activations import ACT2FN
16
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import logging
19
+
20
+ from .configuration_intern_vit import InternVisionConfig
21
+
22
+
23
+ try:
24
+ from flash_attn.bert_padding import pad_input, unpad_input
25
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
26
+
27
+ has_flash_attn = True
28
+ except:
29
+ print("FlashAttention2 is not installed.")
30
+ has_flash_attn = False
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class FlashAttention(nn.Module):
36
+ """Implement the scaled dot product attention with softmax.
37
+ Arguments
38
+ ---------
39
+ softmax_scale: The temperature to use for the softmax attention.
40
+ (default: 1/sqrt(d_keys) where d_keys is computed at
41
+ runtime)
42
+ attention_dropout: The dropout rate to apply to the attention
43
+ (default: 0.0)
44
+ """
45
+
46
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
47
+ super().__init__()
48
+ self.softmax_scale = softmax_scale
49
+ self.dropout_p = attention_dropout
50
+
51
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, max_s=None, need_weights=False):
52
+ """Implements the multihead softmax attention.
53
+ Arguments
54
+ ---------
55
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
56
+ if unpadded: (nnz, 3, h, d)
57
+ key_padding_mask: a bool tensor of shape (B, S)
58
+ """
59
+ assert not need_weights
60
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
61
+ assert qkv.is_cuda
62
+
63
+ if cu_seqlens is None:
64
+ batch_size = qkv.shape[0]
65
+ seqlen = qkv.shape[1]
66
+ if key_padding_mask is None:
67
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
68
+ max_s = seqlen
69
+ cu_seqlens = torch.arange(
70
+ 0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
71
+ )
72
+ output = flash_attn_varlen_qkvpacked_func(
73
+ qkv,
74
+ cu_seqlens,
75
+ max_s,
76
+ self.dropout_p if self.training else 0.0,
77
+ softmax_scale=self.softmax_scale,
78
+ causal=causal,
79
+ )
80
+ output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
81
+ else:
82
+ nheads = qkv.shape[-2]
83
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
84
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
85
+ x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
86
+ output_unpad = flash_attn_varlen_qkvpacked_func(
87
+ x_unpad,
88
+ cu_seqlens,
89
+ max_s,
90
+ self.dropout_p if self.training else 0.0,
91
+ softmax_scale=self.softmax_scale,
92
+ causal=causal,
93
+ )
94
+ output = rearrange(
95
+ pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen),
96
+ "b s (h d) -> b s h d",
97
+ h=nheads,
98
+ )
99
+ else:
100
+ assert max_s is not None
101
+ output = flash_attn_varlen_qkvpacked_func(
102
+ qkv,
103
+ cu_seqlens,
104
+ max_s,
105
+ self.dropout_p if self.training else 0.0,
106
+ softmax_scale=self.softmax_scale,
107
+ causal=causal,
108
+ )
109
+
110
+ return output, None
111
+
112
+
113
+ class InternRMSNorm(nn.Module):
114
+ def __init__(self, hidden_size, eps=1e-6):
115
+ super().__init__()
116
+ self.weight = nn.Parameter(torch.ones(hidden_size))
117
+ self.variance_epsilon = eps
118
+
119
+ def forward(self, hidden_states):
120
+ input_dtype = hidden_states.dtype
121
+ hidden_states = hidden_states.to(torch.float32)
122
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
123
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
124
+ return self.weight * hidden_states.to(input_dtype)
125
+
126
+
127
+ try:
128
+ from apex.normalization import FusedRMSNorm
129
+
130
+ InternRMSNorm = FusedRMSNorm # noqa
131
+
132
+ logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm")
133
+ except ImportError:
134
+ # using the normal InternRMSNorm
135
+ pass
136
+ except Exception:
137
+ logger.warning("discovered apex but it failed to load, falling back to InternRMSNorm")
138
+ pass
139
+
140
+
141
+ NORM2FN = {
142
+ "rms_norm": InternRMSNorm,
143
+ "layer_norm": nn.LayerNorm,
144
+ }
145
+
146
+
147
+ class InternVisionEmbeddings(nn.Module):
148
+ def __init__(self, config: InternVisionConfig):
149
+ super().__init__()
150
+ self.config = config
151
+ self.embed_dim = config.hidden_size
152
+ self.image_size = config.image_size
153
+ self.patch_size = config.patch_size
154
+
155
+ self.class_embedding = nn.Parameter(
156
+ torch.randn(1, 1, self.embed_dim),
157
+ )
158
+
159
+ self.patch_embedding = nn.Conv2d(
160
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
161
+ )
162
+
163
+ self.num_patches = (self.image_size // self.patch_size) ** 2
164
+ self.num_positions = self.num_patches + 1
165
+
166
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
167
+
168
+ def _get_pos_embed(self, pos_embed, H, W):
169
+ target_dtype = pos_embed.dtype
170
+ pos_embed = (
171
+ pos_embed.float()
172
+ .reshape(1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1)
173
+ .permute(0, 3, 1, 2)
174
+ )
175
+ pos_embed = (
176
+ F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
177
+ .reshape(1, -1, H * W)
178
+ .permute(0, 2, 1)
179
+ .to(target_dtype)
180
+ )
181
+ return pos_embed
182
+
183
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
184
+ target_dtype = self.patch_embedding.weight.dtype
185
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height]
186
+ batch_size, _, height, width = patch_embeds.shape
187
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
188
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
189
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
190
+ position_embedding = torch.cat(
191
+ [self.position_embedding[:, :1, :], self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)],
192
+ dim=1,
193
+ )
194
+ embeddings = embeddings + position_embedding.to(target_dtype)
195
+ return embeddings
196
+
197
+
198
+ class InternAttention(nn.Module):
199
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
200
+
201
+ def __init__(self, config: InternVisionConfig):
202
+ super().__init__()
203
+ self.config = config
204
+ self.embed_dim = config.hidden_size
205
+ self.num_heads = config.num_attention_heads
206
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
207
+ if config.use_flash_attn and not has_flash_attn:
208
+ print("Warning: Flash Attention is not available, use_flash_attn is set to False.")
209
+ self.head_dim = self.embed_dim // self.num_heads
210
+ if self.head_dim * self.num_heads != self.embed_dim:
211
+ raise ValueError(
212
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
213
+ f" {self.num_heads})."
214
+ )
215
+
216
+ self.scale = self.head_dim**-0.5
217
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
218
+ self.attn_drop = nn.Dropout(config.attention_dropout)
219
+ self.proj_drop = nn.Dropout(config.dropout)
220
+
221
+ self.qk_normalization = config.qk_normalization
222
+
223
+ if self.qk_normalization:
224
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
225
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
226
+
227
+ if self.use_flash_attn:
228
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
229
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
230
+
231
+ def _naive_attn(self, x):
232
+ B, N, C = x.shape
233
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
234
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
235
+
236
+ if self.qk_normalization:
237
+ B_, H_, N_, D_ = q.shape
238
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
239
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
240
+
241
+ attn = (q * self.scale) @ k.transpose(-2, -1)
242
+ attn = attn.softmax(dim=-1)
243
+ attn = self.attn_drop(attn)
244
+
245
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
246
+ x = self.proj(x)
247
+ x = self.proj_drop(x)
248
+ return x
249
+
250
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
251
+ qkv = self.qkv(x)
252
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
253
+
254
+ if self.qk_normalization:
255
+ q, k, v = qkv.unbind(2)
256
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
257
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
258
+ qkv = torch.stack([q, k, v], dim=2)
259
+
260
+ context, _ = self.inner_attn(qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False)
261
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
262
+ outs = self.proj_drop(outs)
263
+ return outs
264
+
265
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
266
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
267
+ return x
268
+
269
+
270
+ class InternMLP(nn.Module):
271
+ def __init__(self, config: InternVisionConfig):
272
+ super().__init__()
273
+ self.config = config
274
+ self.act = ACT2FN[config.hidden_act]
275
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
276
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
277
+
278
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
279
+ hidden_states = self.fc1(hidden_states)
280
+ hidden_states = self.act(hidden_states)
281
+ hidden_states = self.fc2(hidden_states)
282
+ return hidden_states
283
+
284
+
285
+ class InternVisionEncoderLayer(nn.Module):
286
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
287
+ super().__init__()
288
+ self.embed_dim = config.hidden_size
289
+ self.intermediate_size = config.intermediate_size
290
+ self.norm_type = config.norm_type
291
+
292
+ self.attn = InternAttention(config)
293
+ self.mlp = InternMLP(config)
294
+ self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
295
+ self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
296
+
297
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
298
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
299
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
300
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
301
+
302
+ def forward(
303
+ self,
304
+ hidden_states: torch.Tensor,
305
+ ) -> torch.FloatTensor:
306
+ """
307
+ Args:
308
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
309
+ """
310
+ hidden_states = hidden_states + self.drop_path1(
311
+ self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
312
+ )
313
+
314
+ hidden_states = hidden_states + self.drop_path2(
315
+ self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2
316
+ )
317
+
318
+ return hidden_states
319
+
320
+
321
+ class InternVisionEncoder(nn.Module):
322
+ """
323
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
324
+ [`InternEncoderLayer`].
325
+
326
+ Args:
327
+ config (`InternConfig`):
328
+ The corresponding vision configuration for the `InternEncoder`.
329
+ """
330
+
331
+ def __init__(self, config: InternVisionConfig):
332
+ super().__init__()
333
+ self.config = config
334
+ # stochastic depth decay rule
335
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
336
+ self.layers = nn.ModuleList(
337
+ [InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)]
338
+ )
339
+ self.gradient_checkpointing = True
340
+
341
+ def forward(
342
+ self,
343
+ inputs_embeds,
344
+ output_hidden_states: Optional[bool] = None,
345
+ return_dict: Optional[bool] = None,
346
+ ) -> Union[Tuple, BaseModelOutput]:
347
+ r"""
348
+ Args:
349
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
350
+ Embedded representation of the inputs. Should be float, not int tokens.
351
+ output_hidden_states (`bool`, *optional*):
352
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
353
+ for more detail.
354
+ return_dict (`bool`, *optional*):
355
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
356
+ """
357
+ output_hidden_states = (
358
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
359
+ )
360
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
361
+
362
+ encoder_states = () if output_hidden_states else None
363
+ hidden_states = inputs_embeds
364
+
365
+ for idx, encoder_layer in enumerate(self.layers):
366
+ if output_hidden_states:
367
+ encoder_states = encoder_states + (hidden_states,)
368
+ if self.gradient_checkpointing and self.training:
369
+ layer_outputs = torch.utils.checkpoint.checkpoint(encoder_layer, hidden_states)
370
+ else:
371
+ layer_outputs = encoder_layer(
372
+ hidden_states,
373
+ )
374
+ hidden_states = layer_outputs
375
+
376
+ if output_hidden_states:
377
+ encoder_states = encoder_states + (hidden_states,)
378
+
379
+ if not return_dict:
380
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
381
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)
382
+
383
+
384
+ class InternVisionModel(PreTrainedModel):
385
+ main_input_name = "pixel_values"
386
+ _supports_flash_attn_2 = True
387
+ config_class = InternVisionConfig
388
+ _no_split_modules = ["InternVisionEncoderLayer"]
389
+
390
+ def __init__(self, config: InternVisionConfig):
391
+ super().__init__(config)
392
+ self.config = config
393
+
394
+ self.embeddings = InternVisionEmbeddings(config)
395
+ self.encoder = InternVisionEncoder(config)
396
+
397
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
398
+ pos_emb = self.embeddings.position_embedding
399
+ _, num_positions, embed_dim = pos_emb.shape
400
+ cls_emb = pos_emb[:, :1, :]
401
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
402
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode="bicubic", align_corners=False)
403
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
404
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
405
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
406
+ self.embeddings.image_size = new_size
407
+ logger.info("Resized position embeddings from {} to {}".format(old_size, new_size))
408
+
409
+ def get_input_embeddings(self):
410
+ return self.embeddings
411
+
412
+ def forward(
413
+ self,
414
+ pixel_values: Optional[torch.FloatTensor] = None,
415
+ output_hidden_states: Optional[bool] = None,
416
+ return_dict: Optional[bool] = None,
417
+ pixel_embeds: Optional[torch.FloatTensor] = None,
418
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
419
+ output_hidden_states = (
420
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
421
+ )
422
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
423
+
424
+ if pixel_values is None and pixel_embeds is None:
425
+ raise ValueError("You have to specify pixel_values or pixel_embeds")
426
+
427
+ if pixel_embeds is not None:
428
+ hidden_states = pixel_embeds
429
+ else:
430
+ if len(pixel_values.shape) == 4:
431
+ hidden_states = self.embeddings(pixel_values)
432
+ else:
433
+ raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
434
+ encoder_outputs = self.encoder(
435
+ inputs_embeds=hidden_states,
436
+ output_hidden_states=output_hidden_states,
437
+ return_dict=return_dict,
438
+ )
439
+ last_hidden_state = encoder_outputs.last_hidden_state
440
+ pooled_output = last_hidden_state[:, 0, :]
441
+
442
+ if not return_dict:
443
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
444
+
445
+ return BaseModelOutputWithPooling(
446
+ last_hidden_state=last_hidden_state,
447
+ pooler_output=pooled_output,
448
+ hidden_states=encoder_outputs.hidden_states,
449
+ attentions=encoder_outputs.attentions,
450
+ )
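A minimal forward-pass sketch for `InternVisionModel` above, using a deliberately tiny, hypothetical configuration on CPU with flash attention disabled.

```python
import torch
from internvl.model.internvl_chat import InternVisionConfig, InternVisionModel

cfg = InternVisionConfig(image_size=224, patch_size=14, num_hidden_layers=2,
                         hidden_size=64, num_attention_heads=4,
                         intermediate_size=256, use_flash_attn=False)
model = InternVisionModel(cfg).eval()

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = model(pixel_values=pixel_values)

print(out.last_hidden_state.shape)  # [1, 257, 64]: CLS token + 16*16 patch tokens
print(out.pooler_output.shape)      # [1, 64]: the CLS token embedding
```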
src/third_party/InternVL/internvl_chat/internvl/model/internvl_chat/modeling_internvl_chat.py ADDED
@@ -0,0 +1,477 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import warnings
8
+ from typing import List, Optional, Tuple, Union
9
+
10
+ import torch.distributed as dist
11
+ import torch.utils.checkpoint
12
+ import transformers
13
+ from peft import LoraConfig, get_peft_model
14
+ from torch import nn
15
+ from torch.nn import CrossEntropyLoss
16
+ from transformers import GenerationConfig, LlamaForCausalLM, Qwen2ForCausalLM
17
+ from transformers.modeling_outputs import CausalLMOutputWithPast
18
+ from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.utils import logging
20
+
21
+ from internvl.conversation import get_conv_template
22
+ from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
23
+ from internvl.model.phi3.modeling_phi3 import Phi3ForCausalLM
24
+
25
+ from .configuration_internvl_chat import InternVLChatConfig
26
+ from .modeling_intern_vit import InternVisionModel, has_flash_attn
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ def version_cmp(v1, v2, op="eq"):
33
+ import operator
34
+
35
+ from packaging import version
36
+
37
+ op_func = getattr(operator, op)
38
+ return op_func(version.parse(v1), version.parse(v2))
39
+
40
+
41
+ class InternVLChatModel(PreTrainedModel):
42
+ config_class = InternVLChatConfig
43
+ main_input_name = "pixel_values"
44
+ base_model_prefix = "language_model"
45
+ _no_split_modules = [
46
+ "InternVisionModel",
47
+ "LlamaDecoderLayer",
48
+ "InternLM2DecoderLayer",
49
+ "Phi3DecoderLayer",
50
+ "Qwen2DecoderLayer",
51
+ ]
52
+ _supports_flash_attn_2 = True
53
+ supports_gradient_checkpointing = True
54
+
55
+ def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
56
+ super().__init__(config)
57
+
58
+ assert version_cmp(transformers.__version__, "4.37.0", "ge")
59
+ image_size = config.force_image_size or config.vision_config.image_size
60
+ patch_size = config.vision_config.patch_size
61
+ self.patch_size = patch_size
62
+ self.select_layer = config.select_layer
63
+ self.template = config.template
64
+ self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))
65
+ self.downsample_ratio = config.downsample_ratio
66
+ self.ps_version = config.ps_version
67
+ self.llm_arch_name = config.llm_config.architectures[0]
68
+ # Enable Flash Attention if supported, otherwise fall back to eager attention.
69
+ use_flash_attn = use_flash_attn if has_flash_attn else False
70
+ config.vision_config.use_flash_attn = True if use_flash_attn else False
71
+ config.llm_config.attn_implementation = "flash_attention_2" if use_flash_attn else "eager"
72
+
73
+ logger.info(f"num_image_token: {self.num_image_token}")
74
+ logger.info(f"ps_version: {self.ps_version}")
75
+ if vision_model is not None:
76
+ self.vision_model = vision_model
77
+ else:
78
+ self.vision_model = InternVisionModel(config.vision_config)
79
+ if language_model is not None:
80
+ self.language_model = language_model
81
+ else:
82
+ if config.llm_config.architectures[0] == "LlamaForCausalLM":
83
+ self.language_model = LlamaForCausalLM(config.llm_config)
84
+ elif config.llm_config.architectures[0] == "InternLM2ForCausalLM":
85
+ self.language_model = InternLM2ForCausalLM(config.llm_config)
86
+ elif config.llm_config.architectures[0] == "Phi3ForCausalLM":
87
+ self.language_model = Phi3ForCausalLM(config.llm_config)
88
+ elif config.llm_config.architectures[0] == "Qwen2ForCausalLM":
89
+ self.language_model = Qwen2ForCausalLM(config.llm_config)
90
+ else:
91
+ raise NotImplementedError(f"{config.llm_config.architectures[0]} is not implemented.")
92
+
93
+ vit_hidden_size = config.vision_config.hidden_size
94
+ llm_hidden_size = config.llm_config.hidden_size
95
+
96
+ self.mlp1 = nn.Sequential(
97
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
98
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
99
+ nn.GELU(),
100
+ nn.Linear(llm_hidden_size, llm_hidden_size),
101
+ )
102
+
103
+ self.img_context_token_id = None
104
+ self.conv_template = get_conv_template(self.template)
105
+ if hasattr(config, "system_message"):
106
+ self.system_message = config.system_message
107
+ else:
108
+ self.system_message = self.conv_template.system_message
109
+ self.num_samples = 0
110
+
111
+ if config.use_backbone_lora:
112
+ self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
113
+
114
+ if config.use_llm_lora:
115
+ self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
116
+
117
+ def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
118
+ lora_config = LoraConfig(
119
+ r=r,
120
+ target_modules=["attn.qkv", "attn.proj", "mlp.fc1", "mlp.fc2"],
121
+ lora_alpha=lora_alpha,
122
+ lora_dropout=lora_dropout,
123
+ )
124
+ self.vision_model = get_peft_model(self.vision_model, lora_config)
125
+ self.vision_model.print_trainable_parameters()
126
+
127
+ def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
128
+ # Determine the target modules based on the architecture of the language model
129
+ if self.llm_arch_name == "InternLM2ForCausalLM":
130
+ target_modules = ["attention.wqkv", "attention.wo", "feed_forward.w1", "feed_forward.w2", "feed_forward.w3"]
131
+ elif self.llm_arch_name == "Phi3ForCausalLM":
132
+ target_modules = ["mlp.down_proj", "mlp.gate_up_proj", "self_attn.o_proj", "self_attn.qkv_proj"]
133
+ elif self.llm_arch_name in ["Qwen2ForCausalLM", "LlamaForCausalLM"]:
134
+ target_modules = [
135
+ "self_attn.q_proj",
136
+ "self_attn.k_proj",
137
+ "self_attn.v_proj",
138
+ "self_attn.o_proj",
139
+ "mlp.gate_proj",
140
+ "mlp.down_proj",
141
+ "mlp.up_proj",
142
+ ]
143
+ else:
144
+ raise NotImplementedError(f'Unsupported LLM architecture for LoRA: {self.llm_arch_name}')
145
+ lora_config = LoraConfig(
146
+ r=r, target_modules=target_modules, lora_alpha=lora_alpha, lora_dropout=lora_dropout, task_type="CAUSAL_LM"
147
+ )
148
+ self.language_model = get_peft_model(self.language_model, lora_config)
149
+ self.language_model.enable_input_require_grads()
150
+ self.language_model.print_trainable_parameters()
151
+
152
+ def forward(
153
+ self,
154
+ pixel_values: torch.FloatTensor,
155
+ input_ids: torch.LongTensor = None,
156
+ attention_mask: Optional[torch.Tensor] = None,
157
+ position_ids: Optional[torch.LongTensor] = None,
158
+ image_flags: Optional[torch.LongTensor] = None,
159
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
160
+ labels: Optional[torch.LongTensor] = None,
161
+ use_cache: Optional[bool] = None,
162
+ output_attentions: Optional[bool] = None,
163
+ output_hidden_states: Optional[bool] = None,
164
+ return_dict: Optional[bool] = None,
165
+ statistics: Optional[torch.LongTensor] = None,
166
+ loss_weight: Optional[List] = None,
167
+ loss_reduction_all_gather: Optional[bool] = False,
168
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
169
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
170
+
171
+ image_flags = image_flags.squeeze(-1)
172
+ input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
173
+
174
+ vit_embeds = self.extract_feature(pixel_values)
175
+ vit_embeds = vit_embeds[image_flags == 1]
176
+ vit_batch_size = pixel_values.shape[0]
177
+
178
+ B, N, C = input_embeds.shape
179
+ input_embeds = input_embeds.reshape(B * N, C)
180
+
181
+ if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
182
+ print(
183
+ f"dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}"
184
+ )
185
+ if statistics is not None:
186
+ num_samples, num_padding_tokens, num_padding_images = statistics.tolist()
187
+ self.num_samples += num_samples
188
+ print(f"total_samples={self.num_samples}, {num_samples=}, {num_padding_tokens=}, {num_padding_images=}")
189
+
190
+ input_ids = input_ids.reshape(B * N)
191
+ selected = input_ids == self.img_context_token_id
192
+ try:
193
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
194
+ ignore_flag = False
195
+ except Exception as e:
196
+ vit_embeds = vit_embeds.reshape(-1, C)
197
+ print(
198
+ f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
199
+ f"vit_embeds.shape={vit_embeds.shape}"
200
+ )
201
+ n_token = selected.sum()
202
+ input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
203
+ ignore_flag = True
204
+
205
+ input_embeds = input_embeds.reshape(B, N, C)
206
+
207
+ outputs = self.language_model(
208
+ inputs_embeds=input_embeds,
209
+ attention_mask=attention_mask,
210
+ position_ids=position_ids,
211
+ past_key_values=past_key_values,
212
+ use_cache=use_cache,
213
+ output_attentions=output_attentions,
214
+ output_hidden_states=output_hidden_states,
215
+ return_dict=return_dict,
216
+ )
217
+ logits = outputs.logits
218
+
219
+ loss = None
220
+ if labels is not None and loss_weight is not None:
221
+ loss_weight = torch.tensor(loss_weight, dtype=torch.float32, device=labels.device)
222
+ # Shift so that tokens < n predict n
223
+ shift_logits = logits[..., :-1, :].contiguous()
224
+ shift_labels = labels[..., 1:].contiguous()
225
+ shift_weights = loss_weight[..., 1:].contiguous()
226
+ # Flatten the tokens
227
+ loss_fct = CrossEntropyLoss(reduction="none")
228
+ shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
229
+ shift_labels = shift_labels.view(-1)
230
+ shift_weights = shift_weights.view(-1)
231
+ # Enable model parallelism
232
+ shift_labels = shift_labels.to(shift_logits.device)
233
+ shift_weights = shift_weights.to(shift_logits.device)
234
+ loss = loss_fct(shift_logits, shift_labels)
235
+
236
+ shift_weights_sum = shift_weights.sum()
237
+ if loss_reduction_all_gather:
238
+ dist.all_reduce(shift_weights_sum, op=dist.ReduceOp.AVG)
239
+
240
+ loss = loss * shift_weights
241
+ loss = loss.sum() / shift_weights_sum
242
+ if ignore_flag:
243
+ loss = loss * 0.0
244
+ elif labels is not None:
245
+ # Shift so that tokens < n predict n
246
+ shift_logits = logits[..., :-1, :].contiguous()
247
+ shift_labels = labels[..., 1:].contiguous()
248
+ # Flatten the tokens
249
+ loss_fct = CrossEntropyLoss()
250
+ shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
251
+ shift_labels = shift_labels.view(-1)
252
+ # Enable model parallelism
253
+ shift_labels = shift_labels.to(shift_logits.device)
254
+ loss = loss_fct(shift_logits, shift_labels)
255
+ if ignore_flag:
256
+ loss = loss * 0.0
257
+
258
+ if not return_dict:
259
+ output = (logits,) + outputs[1:]
260
+ return (loss,) + output if loss is not None else output
261
+
262
+ return CausalLMOutputWithPast(
263
+ loss=loss,
264
+ logits=logits,
265
+ past_key_values=outputs.past_key_values,
266
+ hidden_states=outputs.hidden_states,
267
+ attentions=outputs.attentions,
268
+ )
269
+
270
+ def pixel_shuffle(self, x, scale_factor=0.5):
271
+ n, w, h, c = x.size()
272
+ # N, W, H, C --> N, W, H * scale, C // scale
273
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
274
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
275
+ x = x.permute(0, 2, 1, 3).contiguous()
276
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
277
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))
278
+ if self.ps_version == "v1":
279
+ warnings.warn(
280
+ "In ps_version 'v1', the height and width have not been swapped back, "
281
+ "which results in a transposed image."
282
+ )
283
+ else:
284
+ x = x.permute(0, 2, 1, 3).contiguous()
285
+ return x
286
+
287
+ def extract_feature(self, pixel_values):
288
+ if self.select_layer == -1:
289
+ vit_embeds = self.vision_model(
290
+ pixel_values=pixel_values, output_hidden_states=False, return_dict=True
291
+ ).last_hidden_state
292
+ else:
293
+ vit_embeds = self.vision_model(
294
+ pixel_values=pixel_values, output_hidden_states=True, return_dict=True
295
+ ).hidden_states[self.select_layer]
296
+ vit_embeds = vit_embeds[:, 1:, :]
297
+
298
+ h = w = int(vit_embeds.shape[1] ** 0.5)
299
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
300
+ vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
301
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
302
+ vit_embeds = self.mlp1(vit_embeds)
303
+ return vit_embeds
304
+
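For orientation, a small worked example of the token arithmetic behind `num_image_token`, `pixel_shuffle` and `mlp1` above, assuming a 448x448 input, 14x14 patches and `downsample_ratio=0.5` (illustrative values only):

```python
image_size, patch_size, downsample_ratio = 448, 14, 0.5   # illustrative values

patches_per_side = image_size // patch_size               # 32
vit_tokens = patches_per_side ** 2                        # 1024 patch tokens (CLS is dropped)

# pixel_shuffle trades spatial resolution for channels: H and W shrink by the ratio,
# the channel dimension grows by (1 / ratio) ** 2, so the token count drops by ratio ** 2.
llm_tokens = int(vit_tokens * downsample_ratio ** 2)      # 256 == num_image_token
channel_factor = int(1 / downsample_ratio) ** 2           # 4, matching mlp1's input width
print(llm_tokens, channel_factor)
```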
305
+ def batch_chat(
306
+ self,
307
+ tokenizer,
308
+ pixel_values,
309
+ questions,
310
+ generation_config,
311
+ num_patches_list=None,
312
+ history=None,
313
+ return_history=False,
314
+ IMG_START_TOKEN="<img>",
315
+ IMG_END_TOKEN="</img>",
316
+ IMG_CONTEXT_TOKEN="<IMG_CONTEXT>",
317
+ verbose=False,
318
+ image_counts=None,
319
+ ):
320
+ if history is not None or return_history:
321
+ print("Now multi-turn chat is not supported in batch_chat.")
322
+ raise NotImplementedError
323
+
324
+ if image_counts is not None:
325
+ num_patches_list = image_counts
326
+ print("Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.")
327
+
328
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
329
+ self.img_context_token_id = img_context_token_id
330
+
331
+ if verbose and pixel_values is not None:
332
+ image_bs = pixel_values.shape[0]
333
+ print(f"dynamic ViT batch size: {image_bs}")
334
+
335
+ queries = []
336
+ for idx, num_patches in enumerate(num_patches_list):
337
+ question = questions[idx]
338
+ if pixel_values is not None and "<image>" not in question:
339
+ question = "<image>\n" + question
340
+ template = get_conv_template(self.template)
341
+ template.system_message = self.system_message
342
+ template.append_message(template.roles[0], question)
343
+ template.append_message(template.roles[1], None)
344
+ query = template.get_prompt()
345
+
346
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
347
+ query = query.replace("<image>", image_tokens, 1)
348
+ queries.append(query)
349
+
350
+ tokenizer.padding_side = "left"
351
+ model_inputs = tokenizer(queries, return_tensors="pt", padding=True)
352
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
353
+ input_ids = model_inputs["input_ids"].to(device)
354
+ attention_mask = model_inputs["attention_mask"].to(device)
355
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
356
+ generation_config["eos_token_id"] = eos_token_id
357
+ generation_output = self.generate(
358
+ pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, **generation_config
359
+ )
360
+ responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
361
+ responses = [response.split(template.sep.strip())[0].strip() for response in responses]
362
+ return responses
363
+
364
+ def chat(
365
+ self,
366
+ tokenizer,
367
+ pixel_values,
368
+ question,
369
+ generation_config,
370
+ history=None,
371
+ return_history=False,
372
+ num_patches_list=None,
373
+ IMG_START_TOKEN="<img>",
374
+ IMG_END_TOKEN="</img>",
375
+ IMG_CONTEXT_TOKEN="<IMG_CONTEXT>",
376
+ verbose=False,
377
+ ):
378
+ if history is None and pixel_values is not None and "<image>" not in question:
379
+ question = "<image>\n" + question
380
+
381
+ if num_patches_list is None:
382
+ num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
383
+ assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
384
+
385
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
386
+ self.img_context_token_id = img_context_token_id
387
+
388
+ template = get_conv_template(self.template)
389
+ template.system_message = self.system_message
390
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
391
+
392
+ history = [] if history is None else history
393
+ for old_question, old_answer in history:
394
+ template.append_message(template.roles[0], old_question)
395
+ template.append_message(template.roles[1], old_answer)
396
+ template.append_message(template.roles[0], question)
397
+ template.append_message(template.roles[1], None)
398
+ query = template.get_prompt()
399
+
400
+ if verbose and pixel_values is not None:
401
+ image_bs = pixel_values.shape[0]
402
+ print(f"dynamic ViT batch size: {image_bs}")
403
+
404
+ for num_patches in num_patches_list:
405
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
406
+ query = query.replace("<image>", image_tokens, 1)
407
+
408
+ model_inputs = tokenizer(query, return_tensors="pt")
409
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
410
+ input_ids = model_inputs["input_ids"].to(device)
411
+ attention_mask = model_inputs["attention_mask"].to(device)
412
+ generation_config["eos_token_id"] = eos_token_id
413
+ generation_output = self.generate(
414
+ pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, **generation_config
415
+ )
416
+ response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
417
+ response = response.split(template.sep.strip())[0].strip()
418
+ history.append((question, response))
419
+ if return_history:
420
+ return response, history
421
+ else:
422
+ query_to_print = query.replace(IMG_CONTEXT_TOKEN, "")
423
+ query_to_print = query_to_print.replace(f"{IMG_START_TOKEN}{IMG_END_TOKEN}", "<image>")
424
+ if verbose:
425
+ print(query_to_print, response)
426
+ return response
427
+
428
+ @torch.no_grad()
429
+ def generate(
430
+ self,
431
+ pixel_values: Optional[torch.FloatTensor] = None,
432
+ input_ids: Optional[torch.FloatTensor] = None,
433
+ attention_mask: Optional[torch.LongTensor] = None,
434
+ visual_features: Optional[torch.FloatTensor] = None,
435
+ generation_config: Optional[GenerationConfig] = None,
436
+ output_hidden_states: Optional[bool] = None,
437
+ **generate_kwargs,
438
+ ) -> torch.LongTensor:
439
+ assert self.img_context_token_id is not None
440
+ if pixel_values is not None:
441
+ if visual_features is not None:
442
+ vit_embeds = visual_features
443
+ else:
444
+ vit_embeds = self.extract_feature(pixel_values)
445
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
446
+ B, N, C = input_embeds.shape
447
+ input_embeds = input_embeds.reshape(B * N, C)
448
+
449
+ input_ids = input_ids.reshape(B * N)
450
+ selected = input_ids == self.img_context_token_id
451
+ assert selected.sum() != 0
452
+ input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
453
+
454
+ input_embeds = input_embeds.reshape(B, N, C)
455
+ else:
456
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
457
+
458
+ outputs = self.language_model.generate(
459
+ inputs_embeds=input_embeds,
460
+ attention_mask=attention_mask,
461
+ generation_config=generation_config,
462
+ output_hidden_states=output_hidden_states,
463
+ use_cache=True,
464
+ **generate_kwargs,
465
+ )
466
+
467
+ return outputs
468
+
469
+ @property
470
+ def lm_head(self):
471
+ return self.language_model.get_output_embeddings()
472
+
473
+ def get_input_embeddings(self):
474
+ return self.language_model.get_input_embeddings()
475
+
476
+ def get_output_embeddings(self):
477
+ return self.language_model.get_output_embeddings()
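A minimal usage sketch for the `chat` helper above; the checkpoint name and the random `pixel_values` stand-in are placeholders, and real inputs would come from the image tiling/preprocessing pipeline rather than `torch.randn`:

```python
import torch
from transformers import AutoModel, AutoTokenizer

path = "OpenGVLab/InternVL2-8B"  # placeholder checkpoint name
model = AutoModel.from_pretrained(path, torch_dtype=torch.bfloat16, trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

pixel_values = torch.randn(7, 3, 448, 448, dtype=torch.bfloat16).cuda()  # stand-in for 7 image tiles
generation_config = dict(max_new_tokens=128, do_sample=False)

# Single turn; pass return_history=True to keep a (question, answer) history for follow-ups.
response = model.chat(tokenizer, pixel_values, "<image>\nDescribe the image.", generation_config)
print(response)
```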
src/third_party/InternVL/internvl_chat/internvl/model/phi3/configuration_phi3.py ADDED
@@ -0,0 +1,211 @@
1
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """ Phi-3 model configuration"""
16
+
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
+ 'microsoft/Phi-3-mini-4k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json',
25
+ 'microsoft/Phi-3-mini-128k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json',
26
+ }
27
+
28
+
29
+ class Phi3Config(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
32
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
33
+ defaults will yield a similar configuration to that of the
34
+ [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
35
+
36
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37
+ documentation from [`PretrainedConfig`] for more information.
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 32064):
41
+ Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`Phi3Model`].
43
+ hidden_size (`int`, *optional*, defaults to 3072):
44
+ Dimension of the hidden representations.
45
+ intermediate_size (`int`, *optional*, defaults to 8192):
46
+ Dimension of the MLP representations.
47
+ num_hidden_layers (`int`, *optional*, defaults to 32):
48
+ Number of hidden layers in the Transformer decoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 32):
50
+ Number of attention heads for each attention layer in the Transformer decoder.
51
+ num_key_value_heads (`int`, *optional*):
52
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54
+ `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
55
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
+ by meanpooling all the original heads within that group. For more details checkout [this
57
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
58
+ `num_attention_heads`.
59
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
60
+ Dropout probability for mlp outputs.
61
+ embd_pdrop (`int`, *optional*, defaults to 0.0):
62
+ The dropout ratio for the embeddings.
63
+ attention_dropout (`float`, *optional*, defaults to 0.0):
64
+ The dropout ratio after computing the attention scores.
65
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
66
+ The non-linear activation function (function or string) in the decoder.
67
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
68
+ The maximum sequence length that this model might ever be used with.
69
+ original_max_position_embeddings (`int`, *optional*, defaults to 4096):
70
+ The maximum sequence length that this model was trained with. This is used to determine the size of the
71
+ original RoPE embeddings when using long scaling.
72
+ initializer_range (`float`, *optional*, defaults to 0.02):
73
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
74
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
75
+ The epsilon value used for the RMSNorm.
76
+ use_cache (`bool`, *optional*, defaults to `True`):
77
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
78
+ relevant if `config.is_decoder=True`.
79
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
80
+ Whether to tie weight embeddings
81
+ rope_theta (`float`, *optional*, defaults to 10000.0):
82
+ The base period of the RoPE embeddings.
83
+ rope_scaling (`dict`, *optional*):
84
+ The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
85
+ contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
86
+ the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
87
+ divided by the number of attention heads divided by 2.
88
+ bos_token_id (`int`, *optional*, defaults to 1):
89
+ The id of the "beginning-of-sequence" token.
90
+ eos_token_id (`int`, *optional*, defaults to 32000):
91
+ The id of the "end-of-sequence" token.
92
+ pad_token_id (`int`, *optional*, defaults to 32000):
93
+ The id of the padding token.
94
+ sliding_window (`int`, *optional*):
95
+ Sliding window attention window size. If `None`, no sliding window is applied.
96
+
97
+ Example:
98
+
99
+ ```python
100
+ >>> from transformers import Phi3Model, Phi3Config
101
+
102
+ >>> # Initializing a Phi-3 style configuration
103
+ >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
104
+
105
+ >>> # Initializing a model from the configuration
106
+ >>> model = Phi3Model(configuration)
107
+
108
+ >>> # Accessing the model configuration
109
+ >>> configuration = model.config
110
+ ```"""
111
+
112
+ model_type = 'phi3'
113
+ keys_to_ignore_at_inference = ['past_key_values']
114
+
115
+ def __init__(
116
+ self,
117
+ vocab_size=32064,
118
+ hidden_size=3072,
119
+ intermediate_size=8192,
120
+ num_hidden_layers=32,
121
+ num_attention_heads=32,
122
+ num_key_value_heads=None,
123
+ resid_pdrop=0.0,
124
+ embd_pdrop=0.0,
125
+ attention_dropout=0.0,
126
+ hidden_act='silu',
127
+ max_position_embeddings=4096,
128
+ original_max_position_embeddings=4096,
129
+ initializer_range=0.02,
130
+ rms_norm_eps=1e-5,
131
+ use_cache=True,
132
+ tie_word_embeddings=False,
133
+ rope_theta=10000.0,
134
+ rope_scaling=None,
135
+ bos_token_id=1,
136
+ eos_token_id=32000,
137
+ pad_token_id=32000,
138
+ sliding_window=None,
139
+ **kwargs,
140
+ ):
141
+ self.vocab_size = vocab_size
142
+ self.hidden_size = hidden_size
143
+ self.intermediate_size = intermediate_size
144
+ self.num_hidden_layers = num_hidden_layers
145
+ self.num_attention_heads = num_attention_heads
146
+
147
+ if num_key_value_heads is None:
148
+ num_key_value_heads = num_attention_heads
149
+
150
+ self.num_key_value_heads = num_key_value_heads
151
+ self.resid_pdrop = resid_pdrop
152
+ self.embd_pdrop = embd_pdrop
153
+ self.attention_dropout = attention_dropout
154
+ self.hidden_act = hidden_act
155
+ self.max_position_embeddings = max_position_embeddings
156
+ self.original_max_position_embeddings = original_max_position_embeddings
157
+ self.initializer_range = initializer_range
158
+ self.rms_norm_eps = rms_norm_eps
159
+ self.use_cache = use_cache
160
+ self.rope_theta = rope_theta
161
+ self.rope_scaling = rope_scaling
162
+ self._rope_scaling_validation()
163
+ self.sliding_window = sliding_window
164
+
165
+ super().__init__(
166
+ bos_token_id=bos_token_id,
167
+ eos_token_id=eos_token_id,
168
+ pad_token_id=pad_token_id,
169
+ tie_word_embeddings=tie_word_embeddings,
170
+ **kwargs,
171
+ )
172
+
173
+ def _rope_scaling_validation(self):
174
+ """
175
+ Validate the `rope_scaling` configuration.
176
+ """
177
+ if self.rope_scaling is None:
178
+ return
179
+
180
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
181
+ raise ValueError(
182
+ '`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, '
183
+ f'got {self.rope_scaling}'
184
+ )
185
+ rope_scaling_type = self.rope_scaling.get('type', None)
186
+ rope_scaling_short_factor = self.rope_scaling.get('short_factor', None)
187
+ rope_scaling_long_factor = self.rope_scaling.get('long_factor', None)
188
+ if rope_scaling_type is None or rope_scaling_type not in ['su', 'yarn']:
189
+ raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
190
+ if not (
191
+ isinstance(rope_scaling_short_factor, list)
192
+ and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
193
+ ):
194
+ raise ValueError(
195
+ f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
196
+ )
197
+ if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
198
+ raise ValueError(
199
+ f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
200
+ )
201
+ if not (
202
+ isinstance(rope_scaling_long_factor, list)
203
+ and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
204
+ ):
205
+ raise ValueError(
206
+ f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
207
+ )
208
+ if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
209
+ raise ValueError(
210
+ f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
211
+ )
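A minimal sketch of what `_rope_scaling_validation` enforces: both factor lists must have `hidden_size // num_attention_heads // 2` entries (48 with the defaults above). The import path assumes this repository's package layout:

```python
from internvl.model.phi3.configuration_phi3 import Phi3Config

factors = [1.0] * (3072 // 32 // 2)   # 48 per-frequency scaling factors
config = Phi3Config(
    max_position_embeddings=131072,
    rope_scaling={"type": "su", "short_factor": factors, "long_factor": factors},
)
print(config.rope_scaling["type"], len(config.rope_scaling["long_factor"]))  # su 48
```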
src/third_party/InternVL/internvl_chat/internvl/model/phi3/modeling_phi3.py ADDED
@@ -0,0 +1,1610 @@
1
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """ PyTorch Phi-3 model."""
16
+
17
+ import inspect
18
+ import math
19
+ import warnings
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+ from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache
29
+ from transformers.modeling_attn_mask_utils import \
30
+ _prepare_4d_causal_attention_mask
31
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
32
+ CausalLMOutputWithPast,
33
+ SequenceClassifierOutputWithPast,
34
+ TokenClassifierOutput)
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import (add_code_sample_docstrings,
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ is_flash_attn_2_available,
40
+ is_flash_attn_greater_or_equal_2_10, logging,
41
+ replace_return_docstrings)
42
+
43
+ from .configuration_phi3 import Phi3Config
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ # Transformers scans dependencies in the modeling file, which breaks conditional loading: the regex only ignores try/except blocks, not if statements
48
+ # if is_flash_attn_2_available():
49
+ _flash_supports_window_size = False
50
+ try:
51
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
52
+ from flash_attn.bert_padding import (index_first_axis, pad_input, # noqa
53
+ unpad_input)
54
+
55
+ _flash_supports_window_size = 'window_size' in list(inspect.signature(flash_attn_func).parameters)
56
+ has_flash_attn = True
57
+ except ImportError as error:
58
+ logger.warning(
59
+ f'`flash-attention` package not found, consider installing for better performance: {error}.'
60
+ )
61
+ if not _flash_supports_window_size:
62
+ logger.warning(
63
+ "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
64
+ )
65
+ has_flash_attn = False
66
+
67
+ _CHECKPOINT_FOR_DOC = 'microsoft/Phi-3-mini-4k-instruct'
68
+ _CONFIG_FOR_DOC = 'Phi3Config'
69
+
70
+ PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [
71
+ 'microsoft/Phi-3-mini-4k-instruct',
72
+ 'microsoft/Phi-3-mini-128k-instruct',
73
+ # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3
74
+ ]
75
+
76
+
77
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
78
+ class Phi3RMSNorm(nn.Module):
79
+ def __init__(self, hidden_size, eps=1e-6):
80
+ """
81
+ Phi3RMSNorm is equivalent to T5LayerNorm
82
+ """
83
+ super().__init__()
84
+ self.weight = nn.Parameter(torch.ones(hidden_size))
85
+ self.variance_epsilon = eps
86
+
87
+ def forward(self, hidden_states):
88
+ input_dtype = hidden_states.dtype
89
+ hidden_states = hidden_states.to(torch.float32)
90
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
91
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
92
+ return self.weight * hidden_states.to(input_dtype)
93
+
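A standalone sketch of the normalization above: each hidden vector is divided by the root mean square of its components (computed in float32) and scaled by a learned per-channel weight:

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return weight * (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)

x = torch.randn(2, 5, 3072)
out = rms_norm(x, torch.ones(3072))
print(out.shape, out.pow(2).mean(-1).sqrt().mean().item())  # per-token RMS is ~1 after normalization
```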
94
+
95
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
96
+ def _get_unpad_data(attention_mask):
97
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
98
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
99
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
100
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
101
+ return (
102
+ indices,
103
+ cu_seqlens,
104
+ max_seqlen_in_batch,
105
+ )
106
+
107
+
108
+ # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
109
+ class Phi3RotaryEmbedding(nn.Module):
110
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
111
+ super().__init__()
112
+
113
+ self.dim = dim
114
+ self.max_position_embeddings = max_position_embeddings
115
+ self.base = base
116
+ self.register_buffer('inv_freq', None, persistent=False)
117
+
118
+ @torch.no_grad()
119
+ def forward(self, x, position_ids, seq_len=None):
120
+ # x: [bs, num_attention_heads, seq_len, head_size]
121
+ if self.inv_freq is None:
122
+ self.inv_freq = 1.0 / (
123
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
124
+ )
125
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
126
+ position_ids_expanded = position_ids[:, None, :].float()
127
+ # Force float32 since bfloat16 loses precision on long contexts
128
+ # See https://github.com/huggingface/transformers/pull/29285
129
+ device_type = x.device.type
130
+ device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
131
+ with torch.autocast(device_type=device_type, enabled=False):
132
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
133
+ emb = torch.cat((freqs, freqs), dim=-1)
134
+ cos = emb.cos()
135
+ sin = emb.sin()
136
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
137
+
138
+
139
+ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
140
+ def __init__(self, dim, config, device=None):
141
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
142
+
143
+ self.short_factor = config.rope_scaling['short_factor']
144
+ self.long_factor = config.rope_scaling['long_factor']
145
+ self.original_max_position_embeddings = config.original_max_position_embeddings
146
+
147
+ @torch.no_grad()
148
+ def forward(self, x, position_ids, seq_len=None):
149
+ seq_len = torch.max(position_ids) + 1
150
+ if seq_len > self.original_max_position_embeddings:
151
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
152
+ else:
153
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
154
+
155
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
156
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
157
+
158
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
159
+ position_ids_expanded = position_ids[:, None, :].float()
160
+
161
+ # Force float32 since bfloat16 loses precision on long contexts
162
+ # See https://github.com/huggingface/transformers/pull/29285
163
+ device_type = x.device.type
164
+ device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
165
+ with torch.autocast(device_type=device_type, enabled=False):
166
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
167
+ emb = torch.cat((freqs, freqs), dim=-1)
168
+
169
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
170
+ if scale <= 1.0:
171
+ scaling_factor = 1.0
172
+ else:
173
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
174
+
175
+ cos = emb.cos() * scaling_factor
176
+ sin = emb.sin() * scaling_factor
177
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
178
+
179
+
180
+ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
181
+ def __init__(self, dim, config, device=None):
182
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
183
+
184
+ self.short_factor = config.rope_scaling['short_factor']
185
+ self.long_factor = config.rope_scaling['long_factor']
186
+ self.original_max_position_embeddings = config.original_max_position_embeddings
187
+
188
+ @torch.no_grad()
189
+ def forward(self, x, position_ids, seq_len=None):
190
+ seq_len = torch.max(position_ids) + 1
191
+ if seq_len > self.original_max_position_embeddings:
192
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
193
+ else:
194
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
195
+
196
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
197
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
198
+
199
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
200
+ position_ids_expanded = position_ids[:, None, :].float()
201
+
202
+ # Force float32 since bfloat16 loses precision on long contexts
203
+ # See https://github.com/huggingface/transformers/pull/29285
204
+ device_type = x.device.type
205
+ device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
206
+ with torch.autocast(device_type=device_type, enabled=False):
207
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
208
+ emb = torch.cat((freqs, freqs), dim=-1)
209
+
210
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
211
+ if scale <= 1.0:
212
+ scaling_factor = 1.0
213
+ else:
214
+ scaling_factor = 0.1 * math.log(scale) + 1.0
215
+
216
+ cos = emb.cos() * scaling_factor
217
+ sin = emb.sin() * scaling_factor
218
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
219
+
220
+
221
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
222
+ def rotate_half(x):
223
+ """Rotates half the hidden dims of the input."""
224
+ x1 = x[..., : x.shape[-1] // 2]
225
+ x2 = x[..., x.shape[-1] // 2 :]
226
+ return torch.cat((-x2, x1), dim=-1)
227
+
228
+
229
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
230
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
231
+ """Applies Rotary Position Embedding to the query and key tensors.
232
+
233
+ Args:
234
+ q (`torch.Tensor`): The query tensor.
235
+ k (`torch.Tensor`): The key tensor.
236
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
237
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
238
+ position_ids (`torch.Tensor`, *optional*):
239
+ Deprecated and unused.
240
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
241
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
242
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
243
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
244
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
245
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
246
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
247
+ Returns:
248
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
249
+ """
250
+ cos = cos.unsqueeze(unsqueeze_dim)
251
+ sin = sin.unsqueeze(unsqueeze_dim)
252
+ q_embed = (q * cos) + (rotate_half(q) * sin)
253
+ k_embed = (k * cos) + (rotate_half(k) * sin)
254
+ return q_embed, k_embed
255
+
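A standalone sketch of the rotary formula used by `apply_rotary_pos_emb`: each channel pair is rotated by a position-dependent angle, so norms are preserved and q·k depends on relative positions. Shapes follow the modules above:

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq_len = 96, 8
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)    # (seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

q = torch.randn(1, 2, seq_len, head_dim)                        # (batch, heads, seq_len, head_dim)
q_rot = q * cos + rotate_half(q) * sin                          # same formula as apply_rotary_pos_emb
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True: rotations preserve norms
```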
256
+
257
+ class Phi3MLP(nn.Module):
258
+ def __init__(self, config):
259
+ super().__init__()
260
+
261
+ self.config = config
262
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
263
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
264
+
265
+ self.activation_fn = ACT2FN[config.hidden_act]
266
+
267
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
268
+ up_states = self.gate_up_proj(hidden_states)
269
+
270
+ gate, up_states = up_states.chunk(2, dim=-1)
271
+ up_states = up_states * self.activation_fn(gate)
272
+
273
+ return self.down_proj(up_states)
274
+
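A standalone sketch showing that `Phi3MLP` is a SwiGLU block with the gate and up projections fused into a single matmul; chunking the fused weight recovers the familiar `down(silu(gate(x)) * up(x))` form:

```python
import torch
import torch.nn.functional as F

hidden, inter = 64, 128
gate_up = torch.nn.Linear(hidden, 2 * inter, bias=False)
down = torch.nn.Linear(inter, hidden, bias=False)
x = torch.randn(2, 5, hidden)

gate, up = gate_up(x).chunk(2, dim=-1)
fused = down(up * F.silu(gate))

w_gate, w_up = gate_up.weight.chunk(2, dim=0)                   # same weights, split in two
unfused = down(F.linear(x, w_up) * F.silu(F.linear(x, w_gate)))
print(torch.allclose(fused, unfused, atol=1e-6))                # True
```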
275
+
276
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
277
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
278
+ """
279
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
280
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
281
+ """
282
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
283
+ if n_rep == 1:
284
+ return hidden_states
285
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
286
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
287
+
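A standalone check that `repeat_kv` matches `torch.repeat_interleave` along the head dimension, which is what lets grouped-query attention reuse the plain multi-head attention math below:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

kv = torch.randn(1, 8, 4, 96)                                   # 8 key/value heads
expanded = repeat_kv(kv, n_rep=4)                               # enough copies for 32 query heads
print(expanded.shape, torch.equal(expanded, kv.repeat_interleave(4, dim=1)))  # (1, 32, 4, 96) True
```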
288
+
289
+ class Phi3Attention(nn.Module):
290
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
291
+
292
+ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
293
+ super().__init__()
294
+ self.config = config
295
+ self.layer_idx = layer_idx
296
+ if layer_idx is None:
297
+ logger.warning_once(
298
+ f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will '
299
+ 'lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` '
300
+ 'when creating this class.'
301
+ )
302
+
303
+ self.attention_dropout = config.attention_dropout
304
+ self.hidden_size = config.hidden_size
305
+ self.num_heads = config.num_attention_heads
306
+ self.head_dim = self.hidden_size // self.num_heads
307
+ self.num_key_value_heads = config.num_key_value_heads
308
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
309
+ self.max_position_embeddings = config.max_position_embeddings
310
+ self.original_max_position_embeddings = config.original_max_position_embeddings
311
+ self.rope_theta = config.rope_theta
312
+ self.rope_scaling = config.rope_scaling
313
+ self.is_causal = True
314
+
315
+ if (self.head_dim * self.num_heads) != self.hidden_size:
316
+ raise ValueError(
317
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
318
+ f' and `num_heads`: {self.num_heads}).'
319
+ )
320
+
321
+ op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
322
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
323
+ self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
324
+ self._init_rope()
325
+
326
+ def _init_rope(self):
327
+ if self.rope_scaling is None:
328
+ self.rotary_emb = Phi3RotaryEmbedding(
329
+ self.head_dim,
330
+ max_position_embeddings=self.max_position_embeddings,
331
+ base=self.rope_theta,
332
+ )
333
+ else:
334
+ scaling_type = self.config.rope_scaling['type']
335
+ if scaling_type == 'su':
336
+ self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
337
+ elif scaling_type == 'yarn':
338
+ self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
339
+ else:
340
+ raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
341
+
342
+ def forward(
343
+ self,
344
+ hidden_states: torch.Tensor,
345
+ attention_mask: Optional[torch.Tensor] = None,
346
+ position_ids: Optional[torch.LongTensor] = None,
347
+ past_key_value: Optional[Cache] = None,
348
+ output_attentions: bool = False,
349
+ use_cache: bool = False,
350
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
351
+ logger.warning_once('You are not running the flash-attention implementation, expect numerical differences.')
352
+
353
+ bsz, q_len, _ = hidden_states.size()
354
+
355
+ qkv = self.qkv_proj(hidden_states)
356
+ query_pos = self.num_heads * self.head_dim
357
+ query_states = qkv[..., :query_pos]
358
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
359
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
360
+
361
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
362
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
363
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
364
+
365
+ kv_seq_len = key_states.shape[-2]
366
+ if past_key_value is not None:
367
+ if self.layer_idx is None:
368
+ raise ValueError(
369
+ f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
370
+ 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
371
+ 'with a layer index.'
372
+ )
373
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
374
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
375
+
376
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
377
+
378
+ if past_key_value is not None:
379
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
380
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
381
+
382
+ # repeat k/v heads if n_kv_heads < n_heads
383
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
384
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
385
+
386
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
387
+
388
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
389
+ raise ValueError(
390
+ f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
391
+ f' {attn_weights.size()}'
392
+ )
393
+
394
+ if attention_mask is not None:
395
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
396
+ raise ValueError(
397
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
398
+ )
399
+ attn_weights = attn_weights + attention_mask
400
+
401
+ # upcast attention to fp32
402
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
403
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
404
+
405
+ attn_output = torch.matmul(attn_weights, value_states)
406
+
407
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
408
+ raise ValueError(
409
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
410
+ f' {attn_output.size()}'
411
+ )
412
+
413
+ attn_output = attn_output.transpose(1, 2).contiguous()
414
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
415
+
416
+ attn_output = self.o_proj(attn_output)
417
+
418
+ if not output_attentions:
419
+ attn_weights = None
420
+
421
+ return attn_output, attn_weights, past_key_value
422
+
423
+
424
+ class Phi3FlashAttention2(Phi3Attention):
425
+ """
426
+ Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stay
427
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
428
+ flash attention and deal with padding tokens in case the input contains any of them.
429
+ """
430
+
431
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
432
+ def __init__(self, *args, **kwargs):
433
+ super().__init__(*args, **kwargs)
434
+
435
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
436
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
437
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
438
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
439
+
440
+ def forward(
441
+ self,
442
+ hidden_states: torch.Tensor,
443
+ attention_mask: Optional[torch.LongTensor] = None,
444
+ position_ids: Optional[torch.LongTensor] = None,
445
+ past_key_value: Optional[Cache] = None,
446
+ output_attentions: bool = False,
447
+ use_cache: bool = False,
448
+ **kwargs,
449
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
450
+ # Phi3FlashAttention2 attention does not support output_attentions
451
+
452
+ if not _flash_supports_window_size:
453
+ logger.warning_once(
454
+ "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library."
455
+ )
456
+ raise ValueError('The current flash attention version does not support sliding window attention.')
457
+
458
+ output_attentions = False
459
+
460
+ if 'padding_mask' in kwargs:
461
+ warnings.warn(
462
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`'
463
+ )
464
+
465
+ # overwrite attention_mask with padding_mask
466
+ attention_mask = kwargs.pop('padding_mask')
467
+
468
+ bsz, q_len, _ = hidden_states.size()
469
+
470
+ qkv = self.qkv_proj(hidden_states)
471
+ query_pos = self.num_heads * self.head_dim
472
+ query_states = qkv[..., :query_pos]
473
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
474
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
475
+
476
+ # Flash attention requires the input to have the shape
477
+ # batch_size x seq_length x head_dim x hidden_dim
478
+ # therefore we just need to keep the original shape
479
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
480
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
481
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
482
+
483
+ kv_seq_len = key_states.shape[-2]
484
+ if past_key_value is not None:
485
+ if self.layer_idx is None:
486
+ raise ValueError(
487
+ f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
488
+ 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
489
+ 'with a layer index.'
490
+ )
491
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
492
+
493
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
494
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
495
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
496
+
497
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
498
+
499
+ use_sliding_windows = (
500
+ _flash_supports_window_size
501
+ and getattr(self.config, 'sliding_window', None) is not None
502
+ and kv_seq_len > self.config.sliding_window
503
+ )
504
+
505
+ if past_key_value is not None:
506
+ # Activate cache slicing only if the config has a `sliding_window` attribute set
507
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
508
+ if (
509
+ getattr(self.config, 'sliding_window', None) is not None
510
+ and kv_seq_len > self.config.sliding_window
511
+ and cache_has_contents
512
+ ):
513
+ slicing_tokens = 1 - self.config.sliding_window
514
+
515
+ past_key = past_key_value[self.layer_idx][0]
516
+ past_value = past_key_value[self.layer_idx][1]
517
+
518
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
519
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
520
+
521
+ if past_key.shape[-2] != self.config.sliding_window - 1:
522
+ raise ValueError(
523
+ f'past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got'
524
+ f' {past_key.shape}'
525
+ )
526
+
527
+ if attention_mask is not None:
528
+ attention_mask = attention_mask[:, slicing_tokens:]
529
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
530
+
531
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
532
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
533
+
534
+ # repeat k/v heads if n_kv_heads < n_heads
535
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
536
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
537
+
538
+ attn_dropout = self.attention_dropout if self.training else 0.0
539
+
540
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
541
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
542
+ # cast them back to the correct dtype just to be sure everything works as expected.
543
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
544
+ # in fp32.
545
+
546
+ if query_states.dtype == torch.float32:
547
+ if torch.is_autocast_enabled():
548
+ target_dtype = torch.get_autocast_gpu_dtype()
549
+ # Handle the case where the model is quantized
550
+ elif hasattr(self.config, '_pre_quantization_dtype'):
551
+ target_dtype = self.config._pre_quantization_dtype
552
+ else:
553
+ target_dtype = self.qkv_proj.weight.dtype
554
+
555
+ logger.warning_once(
556
+ f'The input hidden states seem to have been silently cast to float32; this might be related to'
557
+ f' the fact you have upcast embedding or layer norm layers to float32. We will cast the input back to'
558
+ f' {target_dtype}.'
559
+ )
560
+
561
+ query_states = query_states.to(target_dtype)
562
+ key_states = key_states.to(target_dtype)
563
+ value_states = value_states.to(target_dtype)
564
+
565
+ # Reshape to the expected shape for Flash Attention
566
+ query_states = query_states.transpose(1, 2)
567
+ key_states = key_states.transpose(1, 2)
568
+ value_states = value_states.transpose(1, 2)
569
+
570
+ attn_output = self._flash_attention_forward(
571
+ query_states,
572
+ key_states,
573
+ value_states,
574
+ attention_mask,
575
+ q_len,
576
+ dropout=attn_dropout,
577
+ use_sliding_windows=use_sliding_windows,
578
+ )
579
+
580
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
581
+ attn_output = self.o_proj(attn_output)
582
+
583
+ if not output_attentions:
584
+ attn_weights = None
585
+
586
+ return attn_output, attn_weights, past_key_value
587
+
588
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
589
+ def _flash_attention_forward(
590
+ self,
591
+ query_states,
592
+ key_states,
593
+ value_states,
594
+ attention_mask,
595
+ query_length,
596
+ dropout=0.0,
597
+ softmax_scale=None,
598
+ use_sliding_windows=False,
599
+ ):
600
+ """
601
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
602
+ first unpad the input, then computes the attention scores and pad the final attention scores.
603
+
604
+ Args:
605
+ query_states (`torch.Tensor`):
606
+ Input query states to be passed to Flash Attention API
607
+ key_states (`torch.Tensor`):
608
+ Input key states to be passed to Flash Attention API
609
+ value_states (`torch.Tensor`):
610
+ Input value states to be passed to Flash Attention API
611
+ attention_mask (`torch.Tensor`):
612
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
613
+ position of padding tokens and 1 for the position of non-padding tokens.
614
+ dropout (`float`):
615
+ Attention dropout
616
+ softmax_scale (`float`, *optional*):
617
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
618
+ use_sliding_windows (`bool`, *optional*):
619
+ Whether to activate sliding window attention.
620
+ """
621
+ if not self._flash_attn_uses_top_left_mask:
622
+ causal = self.is_causal
623
+ else:
624
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
625
+ causal = self.is_causal and query_length != 1
626
+
627
+ # Contains at least one padding token in the sequence
628
+ if attention_mask is not None:
629
+ batch_size = query_states.shape[0]
630
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
631
+ query_states, key_states, value_states, attention_mask, query_length
632
+ )
633
+
634
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
635
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
636
+
637
+ if not use_sliding_windows:
638
+ attn_output_unpad = flash_attn_varlen_func(
639
+ query_states,
640
+ key_states,
641
+ value_states,
642
+ cu_seqlens_q=cu_seqlens_q,
643
+ cu_seqlens_k=cu_seqlens_k,
644
+ max_seqlen_q=max_seqlen_in_batch_q,
645
+ max_seqlen_k=max_seqlen_in_batch_k,
646
+ dropout_p=dropout,
647
+ softmax_scale=softmax_scale,
648
+ causal=causal,
649
+ )
650
+ else:
651
+ attn_output_unpad = flash_attn_varlen_func(
652
+ query_states,
653
+ key_states,
654
+ value_states,
655
+ cu_seqlens_q=cu_seqlens_q,
656
+ cu_seqlens_k=cu_seqlens_k,
657
+ max_seqlen_q=max_seqlen_in_batch_q,
658
+ max_seqlen_k=max_seqlen_in_batch_k,
659
+ dropout_p=dropout,
660
+ softmax_scale=softmax_scale,
661
+ causal=causal,
662
+ window_size=(self.config.sliding_window, self.config.sliding_window),
663
+ )
664
+
665
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
666
+ else:
667
+ if not use_sliding_windows:
668
+ attn_output = flash_attn_func(
669
+ query_states,
670
+ key_states,
671
+ value_states,
672
+ dropout,
673
+ softmax_scale=softmax_scale,
674
+ causal=causal,
675
+ )
676
+ else:
677
+ attn_output = flash_attn_func(
678
+ query_states,
679
+ key_states,
680
+ value_states,
681
+ dropout,
682
+ softmax_scale=softmax_scale,
683
+ causal=causal,
684
+ window_size=(self.config.sliding_window, self.config.sliding_window),
685
+ )
686
+
687
+ return attn_output
688
+
689
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
690
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
691
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
692
+
693
+ # On the first iteration we need to properly re-create the padding mask
694
+ # by slicing it on the proper place
695
+ if kv_seq_len != attention_mask.shape[-1]:
696
+ attention_mask_num_tokens = attention_mask.shape[-1]
697
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
698
+
699
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
700
+
701
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
702
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
703
+
704
+ if query_length == kv_seq_len:
705
+ query_layer = index_first_axis(
706
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
707
+ )
708
+ cu_seqlens_q = cu_seqlens_k
709
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
710
+ indices_q = indices_k
711
+ elif query_length == 1:
712
+ max_seqlen_in_batch_q = 1
713
+ cu_seqlens_q = torch.arange(
714
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
715
+ ) # There is a memcpy here, that is very bad.
716
+ indices_q = cu_seqlens_q[:-1]
717
+ query_layer = query_layer.squeeze(1)
718
+ else:
719
+ # The -q_len: slice assumes left padding.
720
+ attention_mask = attention_mask[:, -query_length:]
721
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
722
+
723
+ return (
724
+ query_layer,
725
+ key_layer,
726
+ value_layer,
727
+ indices_q,
728
+ (cu_seqlens_q, cu_seqlens_k),
729
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
730
+ )
731
+
732
+
733
+ # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
734
+ # TODO @Arthur no longer copied from LLama after static cache
735
+ class Phi3SdpaAttention(Phi3Attention):
736
+ """
737
+ Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
738
+ `Phi3Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
739
+ SDPA API.
740
+ """
741
+
742
+ # Adapted from Phi3Attention.forward
743
+ def forward(
744
+ self,
745
+ hidden_states: torch.Tensor,
746
+ attention_mask: Optional[torch.Tensor] = None,
747
+ position_ids: Optional[torch.LongTensor] = None,
748
+ past_key_value: Optional[Cache] = None,
749
+ output_attentions: bool = False,
750
+ use_cache: bool = False,
751
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
752
+ if output_attentions:
753
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
754
+ logger.warning_once(
755
+ 'Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, '
756
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
757
+ )
758
+ return super().forward(
759
+ hidden_states=hidden_states,
760
+ attention_mask=attention_mask,
761
+ position_ids=position_ids,
762
+ past_key_value=past_key_value,
763
+ output_attentions=output_attentions,
764
+ use_cache=use_cache,
765
+ )
766
+
767
+ bsz, q_len, _ = hidden_states.size()
768
+
769
+ qkv = self.qkv_proj(hidden_states)
770
+ query_pos = self.num_heads * self.head_dim
771
+ query_states = qkv[..., :query_pos]
772
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
773
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
774
+
775
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
776
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
777
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
778
+
779
+ kv_seq_len = key_states.shape[-2]
780
+ if past_key_value is not None:
781
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
782
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
783
+
784
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
785
+
786
+ if past_key_value is not None:
787
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
788
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
789
+
790
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
791
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
792
+
793
+ if attention_mask is not None:
794
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
795
+ raise ValueError(
796
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
797
+ )
798
+
799
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
800
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
801
+ if query_states.device.type == 'cuda' and attention_mask is not None:
802
+ query_states = query_states.contiguous()
803
+ key_states = key_states.contiguous()
804
+ value_states = value_states.contiguous()
805
+
806
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
807
+ query_states,
808
+ key_states,
809
+ value_states,
810
+ attn_mask=attention_mask,
811
+ dropout_p=self.attention_dropout if self.training else 0.0,
812
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
813
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
814
+ )
815
+
816
+ attn_output = attn_output.transpose(1, 2).contiguous()
817
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
818
+
819
+ attn_output = self.o_proj(attn_output)
820
+
821
+ return attn_output, None, past_key_value
822
+
823
+
824
+ PHI3_ATTENTION_CLASSES = {
825
+ 'eager': Phi3Attention,
826
+ 'flash_attention_2': Phi3FlashAttention2,
827
+ 'sdpa': Phi3SdpaAttention,
828
+ }
829
+
830
+
831
+ class Phi3DecoderLayer(nn.Module):
832
+ def __init__(self, config: Phi3Config, layer_idx: int):
833
+ super().__init__()
834
+
835
+ self.config = config
836
+ self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
837
+
838
+ self.mlp = Phi3MLP(config)
839
+ self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
840
+
841
+ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
842
+ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
843
+ self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
844
+
845
+ def forward(
846
+ self,
847
+ hidden_states: torch.Tensor,
848
+ attention_mask: Optional[torch.Tensor] = None,
849
+ position_ids: Optional[torch.LongTensor] = None,
850
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
851
+ output_attentions: Optional[bool] = False,
852
+ use_cache: Optional[bool] = False,
853
+ **kwargs,
854
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
855
+ if 'padding_mask' in kwargs:
856
+ warnings.warn(
857
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
858
+ )
859
+ """
860
+ Args:
861
+ hidden_states (`torch.FloatTensor`):
862
+ input to the layer of shape `(batch, seq_len, embed_dim)`
863
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
864
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
865
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
866
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
867
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
868
+ output_attentions (`bool`, *optional*):
869
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
870
+ returned tensors for more detail.
871
+ use_cache (`bool`, *optional*):
872
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
873
+ (see `past_key_values`).
874
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
875
+ """
876
+
877
+ residual = hidden_states
878
+
879
+ hidden_states = self.input_layernorm(hidden_states)
880
+
881
+ # Self Attention
882
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
883
+ hidden_states=hidden_states,
884
+ attention_mask=attention_mask,
885
+ position_ids=position_ids,
886
+ past_key_value=past_key_value,
887
+ output_attentions=output_attentions,
888
+ use_cache=use_cache,
889
+ )
890
+
891
+ hidden_states = residual + self.resid_attn_dropout(attn_outputs)
892
+
893
+ residual = hidden_states
894
+ hidden_states = self.post_attention_layernorm(hidden_states)
895
+ hidden_states = self.mlp(hidden_states)
896
+ hidden_states = residual + self.resid_mlp_dropout(hidden_states)
897
+
898
+ outputs = (hidden_states,)
899
+
900
+ if output_attentions:
901
+ outputs += (self_attn_weights,)
902
+
903
+ if use_cache:
904
+ outputs += (present_key_value,)
905
+
906
+ return outputs
907
+
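For orientation, the decoder layer above follows the usual pre-norm residual layout: normalize, run self-attention, add the (dropped-out) result back to the residual, then repeat the same pattern for the MLP. A stripped-down sketch of just that control flow; the `attn` and `mlp` callables are placeholders, not the real Phi-3 modules:

```python
import torch
from torch import nn

class PreNormBlock(nn.Module):
    """Minimal pre-norm residual block mirroring Phi3DecoderLayer's control flow."""
    def __init__(self, hidden_size: int, attn: nn.Module, mlp: nn.Module, resid_pdrop: float = 0.0):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)         # Phi-3 uses RMSNorm here
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.attn, self.mlp = attn, mlp
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_mlp_dropout = nn.Dropout(resid_pdrop)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.attn(self.input_layernorm(hidden_states))
        hidden_states = residual + self.resid_attn_dropout(hidden_states)

        residual = hidden_states
        hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))
        return residual + self.resid_mlp_dropout(hidden_states)
```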
908
+
909
+ PHI3_START_DOCSTRING = r"""
910
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
911
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
912
+ etc.)
913
+
914
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
915
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
916
+ and behavior.
917
+
918
+ Parameters:
919
+ config ([`Phi3Config`]):
920
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
921
+ load the weights associated with the model, only the configuration. Check out the
922
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
923
+ """
924
+
925
+
926
+ @add_start_docstrings(
927
+ 'The bare Phi-3 model outputting raw hidden-states without any specific head on top.',
928
+ PHI3_START_DOCSTRING,
929
+ )
930
+ class Phi3PreTrainedModel(PreTrainedModel):
931
+ config_class = Phi3Config
932
+ base_model_prefix = 'model'
933
+ supports_gradient_checkpointing = True
934
+ _no_split_modules = ['Phi3DecoderLayer']
935
+ _skip_keys_device_placement = 'past_key_values'
936
+ _supports_flash_attn_2 = True
937
+ _supports_sdpa = False
938
+ _supports_cache_class = True
939
+
940
+ _version = '0.0.5'
941
+
942
+ def __init__(self, config: Phi3Config):
943
+ if not has_flash_attn:
944
+ config._attn_implementation = 'eager'
945
+ print('Warning: Flash attention is not available, using eager attention instead.')
946
+ super().__init__(config)
947
+
948
+ def _init_weights(self, module):
949
+ std = self.config.initializer_range
950
+ if isinstance(module, nn.Linear):
951
+ module.weight.data.normal_(mean=0.0, std=std)
952
+ if module.bias is not None:
953
+ module.bias.data.zero_()
954
+ elif isinstance(module, nn.Embedding):
955
+ module.weight.data.normal_(mean=0.0, std=std)
956
+ if module.padding_idx is not None:
957
+ module.weight.data[module.padding_idx].zero_()
958
+
959
+
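The `_init_weights` hook above is the plain normal(0, `initializer_range`) scheme, with the embedding row for the padding index zeroed afterwards. A compact illustration, assuming a typical `std` of 0.02 (the sizes are made up):

```python
import torch
from torch import nn

std, padding_idx = 0.02, 0
emb = nn.Embedding(100, 16, padding_idx=padding_idx)
lin = nn.Linear(16, 16)

for module in (emb, lin):
    module.weight.data.normal_(mean=0.0, std=std)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
    if isinstance(module, nn.Embedding) and module.padding_idx is not None:
        module.weight.data[module.padding_idx].zero_()

print(emb.weight.data[padding_idx].abs().sum())   # tensor(0.)
```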
960
+ PHI3_INPUTS_DOCSTRING = r"""
961
+ Args:
962
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
963
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
964
+ it.
965
+
966
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
967
+ [`PreTrainedTokenizer.__call__`] for details.
968
+
969
+ [What are input IDs?](../glossary#input-ids)
970
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
971
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
972
+
973
+ - 1 for tokens that are **not masked**,
974
+ - 0 for tokens that are **masked**.
975
+
976
+ [What are attention masks?](../glossary#attention-mask)
977
+
978
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
979
+ [`PreTrainedTokenizer.__call__`] for details.
980
+
981
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
982
+ `past_key_values`).
983
+
984
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
985
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
986
+ information on the default strategy.
987
+
988
+ - 1 indicates the head is **not masked**,
989
+ - 0 indicates the head is **masked**.
990
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
991
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
992
+ config.n_positions - 1]`.
993
+
994
+ [What are position IDs?](../glossary#position-ids)
995
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
996
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
997
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
998
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
999
+
1000
+ Two formats are allowed:
1001
+ - a [`~cache_utils.Cache`] instance;
1002
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1003
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1004
+ cache format.
1005
+
1006
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1007
+ legacy cache format will be returned.
1008
+
1009
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1010
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1011
+ of shape `(batch_size, sequence_length)`.
1012
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1013
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1014
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1015
+ model's internal embedding lookup matrix.
1016
+ use_cache (`bool`, *optional*):
1017
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1018
+ `past_key_values`).
1019
+ output_attentions (`bool`, *optional*):
1020
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1021
+ tensors for more detail.
1022
+ output_hidden_states (`bool`, *optional*):
1023
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1024
+ more detail.
1025
+ return_dict (`bool`, *optional*):
1026
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1027
+ """
1028
+
1029
+
1030
+ @add_start_docstrings(
1031
+ 'The bare Phi-3 model outputting raw hidden-states without any specific head on top.',
1032
+ PHI3_START_DOCSTRING,
1033
+ )
1034
+ class Phi3Model(Phi3PreTrainedModel):
1035
+ """
1036
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
1037
+
1038
+ Args:
1039
+ config: Phi3Config
1040
+ """
1041
+
1042
+ def __init__(self, config: Phi3Config):
1043
+ super().__init__(config)
1044
+ self.padding_idx = config.pad_token_id
1045
+ self.vocab_size = config.vocab_size
1046
+
1047
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1048
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
1049
+ self.layers = nn.ModuleList(
1050
+ [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1051
+ )
1052
+ self._attn_implementation = config._attn_implementation
1053
+
1054
+ self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1055
+
1056
+ self.gradient_checkpointing = False
1057
+ # Initialize weights and apply final processing
1058
+ self.post_init()
1059
+
1060
+ def get_input_embeddings(self):
1061
+ return self.embed_tokens
1062
+
1063
+ def set_input_embeddings(self, value):
1064
+ self.embed_tokens = value
1065
+
1066
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1067
+ def forward(
1068
+ self,
1069
+ input_ids: torch.LongTensor = None,
1070
+ attention_mask: Optional[torch.Tensor] = None,
1071
+ position_ids: Optional[torch.LongTensor] = None,
1072
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1073
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1074
+ use_cache: Optional[bool] = None,
1075
+ output_attentions: Optional[bool] = None,
1076
+ output_hidden_states: Optional[bool] = None,
1077
+ return_dict: Optional[bool] = None,
1078
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1079
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1080
+ output_hidden_states = (
1081
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1082
+ )
1083
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1084
+
1085
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1086
+
1087
+ # retrieve input_ids and inputs_embeds
1088
+ if input_ids is not None and inputs_embeds is not None:
1089
+ raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
1090
+ elif input_ids is not None:
1091
+ batch_size, seq_length = input_ids.shape[:2]
1092
+ elif inputs_embeds is not None:
1093
+ batch_size, seq_length = inputs_embeds.shape[:2]
1094
+ else:
1095
+ raise ValueError('You have to specify either input_ids or inputs_embeds')
1096
+
1097
+ past_key_values_length = 0
1098
+
1099
+ if self.gradient_checkpointing and self.training:
1100
+ if use_cache:
1101
+ logger.warning_once(
1102
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
1103
+ )
1104
+ use_cache = False
1105
+
1106
+ if use_cache:
1107
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1108
+ if use_legacy_cache:
1109
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1110
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1111
+
1112
+ if position_ids is None:
1113
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1114
+ position_ids = torch.arange(
1115
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1116
+ )
1117
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1118
+ else:
1119
+ position_ids = position_ids.view(-1, seq_length).long()
1120
+
1121
+ if inputs_embeds is None:
1122
+ inputs_embeds = self.embed_tokens(input_ids)
1123
+
1124
+ if attention_mask is not None and self._attn_implementation == 'flash_attention_2' and use_cache:
1125
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1126
+ if is_padding_right:
1127
+ raise ValueError(
1128
+ "You are attempting to perform batched generation with padding_side='right'"
1129
+ ' this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to '
1130
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1131
+ )
1132
+
1133
+ if self._attn_implementation == 'flash_attention_2':
1134
+ # 2d mask is passed through the layers
1135
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1136
+ else:
1137
+ # 4d mask is passed through the layers
1138
+ attention_mask = _prepare_4d_causal_attention_mask(
1139
+ attention_mask,
1140
+ (batch_size, seq_length),
1141
+ inputs_embeds,
1142
+ past_key_values_length,
1143
+ sliding_window=self.config.sliding_window,
1144
+ )
1145
+
1146
+ hidden_states = inputs_embeds
1147
+
1148
+ # decoder layers
1149
+ all_hidden_states = () if output_hidden_states else None
1150
+ all_self_attns = () if output_attentions else None
1151
+ next_decoder_cache = None
1152
+
1153
+ for decoder_layer in self.layers:
1154
+ if output_hidden_states:
1155
+ all_hidden_states += (hidden_states,)
1156
+
1157
+ if self.gradient_checkpointing and self.training:
1158
+ layer_outputs = self._gradient_checkpointing_func(
1159
+ decoder_layer.__call__,
1160
+ hidden_states,
1161
+ attention_mask,
1162
+ position_ids,
1163
+ past_key_values,
1164
+ output_attentions,
1165
+ use_cache,
1166
+ )
1167
+ else:
1168
+ layer_outputs = decoder_layer(
1169
+ hidden_states,
1170
+ attention_mask=attention_mask,
1171
+ position_ids=position_ids,
1172
+ past_key_value=past_key_values,
1173
+ output_attentions=output_attentions,
1174
+ use_cache=use_cache,
1175
+ )
1176
+
1177
+ hidden_states = layer_outputs[0]
1178
+
1179
+ if use_cache:
1180
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1181
+
1182
+ if output_attentions:
1183
+ all_self_attns += (layer_outputs[1],)
1184
+
1185
+ hidden_states = self.norm(hidden_states)
1186
+
1187
+ # add hidden states from the last decoder layer
1188
+ if output_hidden_states:
1189
+ all_hidden_states += (hidden_states,)
1190
+
1191
+ next_cache = None
1192
+ if use_cache:
1193
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1194
+ if not return_dict:
1195
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1196
+ return BaseModelOutputWithPast(
1197
+ last_hidden_state=hidden_states,
1198
+ past_key_values=next_cache,
1199
+ hidden_states=all_hidden_states,
1200
+ attentions=all_self_attns,
1201
+ )
1202
+
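Two details of the forward pass above are worth isolating: legacy tuple caches are converted into a `DynamicCache`, and the `position_ids` for the new tokens start at the number of tokens already held in the cache. A hedged sketch of just those two steps, using the same Cache API the file relies on (method names can differ across transformers versions):

```python
import torch
from transformers import DynamicCache

legacy_cache = None          # or a tuple of per-layer (key, value) tensors
seq_length = 4               # number of new tokens in this forward pass

past_key_values = DynamicCache.from_legacy_cache(legacy_cache)
past_len = past_key_values.get_usable_length(seq_length)   # 0 for an empty cache

# New tokens are positioned after everything already in the cache.
position_ids = torch.arange(past_len, past_len + seq_length).unsqueeze(0)
print(position_ids)          # tensor([[0, 1, 2, 3]]) when the cache is empty
```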
1203
+
1204
+ class Phi3ForCausalLM(Phi3PreTrainedModel):
1205
+ _tied_weights_keys = ['lm_head.weight']
1206
+
1207
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
1208
+ def __init__(self, config):
1209
+ super().__init__(config)
1210
+ self.model = Phi3Model(config)
1211
+ self.vocab_size = config.vocab_size
1212
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1213
+
1214
+ # Initialize weights and apply final processing
1215
+ self.post_init()
1216
+
1217
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1218
+ def get_input_embeddings(self):
1219
+ return self.model.embed_tokens
1220
+
1221
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1222
+ def set_input_embeddings(self, value):
1223
+ self.model.embed_tokens = value
1224
+
1225
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1226
+ def get_output_embeddings(self):
1227
+ return self.lm_head
1228
+
1229
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1230
+ def set_output_embeddings(self, new_embeddings):
1231
+ self.lm_head = new_embeddings
1232
+
1233
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1234
+ def set_decoder(self, decoder):
1235
+ self.model = decoder
1236
+
1237
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1238
+ def get_decoder(self):
1239
+ return self.model
1240
+
1241
+ # Ignore copy
1242
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1243
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1244
+ def forward(
1245
+ self,
1246
+ input_ids: torch.LongTensor = None,
1247
+ attention_mask: Optional[torch.Tensor] = None,
1248
+ position_ids: Optional[torch.LongTensor] = None,
1249
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1250
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1251
+ labels: Optional[torch.LongTensor] = None,
1252
+ use_cache: Optional[bool] = None,
1253
+ output_attentions: Optional[bool] = None,
1254
+ output_hidden_states: Optional[bool] = None,
1255
+ return_dict: Optional[bool] = None,
1256
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1257
+ r"""
1258
+ Args:
1259
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1260
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1261
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1262
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1263
+
1264
+ Returns:
1265
+
1266
+ Example:
1267
+
1268
+ ```python
1269
+ >>> from transformers import AutoTokenizer, Phi3ForCausalLM
1270
+
1271
+ >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1272
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1273
+
1274
+ >>> prompt = "This is an example script ."
1275
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1276
+
1277
+ >>> # Generate
1278
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1279
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1280
+ 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
1281
+ ```"""
1282
+
1283
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1284
+ output_hidden_states = (
1285
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1286
+ )
1287
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1288
+
1289
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1290
+ outputs = self.model(
1291
+ input_ids=input_ids,
1292
+ attention_mask=attention_mask,
1293
+ position_ids=position_ids,
1294
+ past_key_values=past_key_values,
1295
+ inputs_embeds=inputs_embeds,
1296
+ use_cache=use_cache,
1297
+ output_attentions=output_attentions,
1298
+ output_hidden_states=output_hidden_states,
1299
+ return_dict=return_dict,
1300
+ )
1301
+
1302
+ hidden_states = outputs[0]
1303
+ logits = self.lm_head(hidden_states)
1304
+ logits = logits.float()
1305
+
1306
+ loss = None
1307
+ if labels is not None:
1308
+ # Shift so that tokens < n predict n
1309
+ shift_logits = logits[..., :-1, :].contiguous()
1310
+ shift_labels = labels[..., 1:].contiguous()
1311
+ # Flatten the tokens
1312
+ loss_fct = CrossEntropyLoss()
1313
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1314
+ shift_labels = shift_labels.view(-1)
1315
+ # Enable model parallelism
1316
+ shift_labels = shift_labels.to(shift_logits.device)
1317
+ loss = loss_fct(shift_logits, shift_labels)
1318
+
1319
+ if not return_dict:
1320
+ output = (logits,) + outputs[1:]
1321
+ return (loss,) + output if loss is not None else output
1322
+
1323
+ return CausalLMOutputWithPast(
1324
+ loss=loss,
1325
+ logits=logits,
1326
+ past_key_values=outputs.past_key_values,
1327
+ hidden_states=outputs.hidden_states,
1328
+ attentions=outputs.attentions,
1329
+ )
1330
+
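The loss computed above is the standard causal-LM objective: the logits at position t are scored against the token at position t+1, and labels set to -100 are ignored. A small self-contained illustration with dummy tensors:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(2, 5, vocab_size)          # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 5))
labels[0, :2] = -100                            # e.g. prompt tokens masked out

# Shift so that the logits at position t predict the label at position t + 1.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = CrossEntropyLoss()(shift_logits, shift_labels)   # ignore_index defaults to -100
print(loss.item())
```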
1331
+ # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
1332
+ def prepare_inputs_for_generation(
1333
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1334
+ ):
1335
+ if past_key_values is not None:
1336
+ if isinstance(past_key_values, Cache):
1337
+ cache_length = past_key_values.get_seq_length()
1338
+ past_length = past_key_values.seen_tokens
1339
+ max_cache_length = past_key_values.get_max_length()
1340
+ else:
1341
+ cache_length = past_length = past_key_values[0][0].shape[2]
1342
+ max_cache_length = None
1343
+
1344
+ # Keep only the unprocessed tokens:
1345
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1346
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1347
+ # input)
1348
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1349
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1350
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1351
+ # input_ids based on the past_length.
1352
+ elif past_length < input_ids.shape[1]:
1353
+ input_ids = input_ids[:, past_length:]
1354
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1355
+
1356
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1357
+ if (
1358
+ max_cache_length is not None
1359
+ and attention_mask is not None
1360
+ and cache_length + input_ids.shape[1] > max_cache_length
1361
+ ):
1362
+ attention_mask = attention_mask[:, -max_cache_length:]
1363
+
1364
+ position_ids = kwargs.get('position_ids', None)
1365
+ if attention_mask is not None and position_ids is None:
1366
+ # create position_ids on the fly for batch generation
1367
+ position_ids = attention_mask.long().cumsum(-1) - 1
1368
+ position_ids.masked_fill_(attention_mask == 0, 1)
1369
+ if past_key_values:
1370
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1371
+
1372
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1373
+ if (inputs_embeds is not None and past_key_values is None) or (inputs_embeds is not None and len(past_key_values) == 0):
1374
+ model_inputs = {'inputs_embeds': inputs_embeds}
1375
+ else:
1376
+ model_inputs = {'input_ids': input_ids}
1377
+
1378
+ model_inputs.update(
1379
+ {
1380
+ 'position_ids': position_ids,
1381
+ 'past_key_values': past_key_values,
1382
+ 'use_cache': kwargs.get('use_cache'),
1383
+ 'attention_mask': attention_mask,
1384
+ }
1385
+ )
1386
+ return model_inputs
1387
+
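During batched generation, the helper above derives `position_ids` from the attention mask rather than assuming every row starts at position 0, which keeps left-padded sequences aligned with their real tokens. A minimal reproduction of that cumsum trick:

```python
import torch

# Two left-padded sequences: 0 marks padding, 1 marks real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)   # padding positions get a dummy value
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```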
1388
+ @staticmethod
1389
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1390
+ def _reorder_cache(past_key_values, beam_idx):
1391
+ reordered_past = ()
1392
+ for layer_past in past_key_values:
1393
+ reordered_past += (
1394
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1395
+ )
1396
+ return reordered_past
1397
+
1398
+
1399
+ @add_start_docstrings(
1400
+ """
1401
+ The [`Phi3Model`] with a sequence classification head on top (linear layer).
1402
+
1403
+ [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1404
+ (e.g. GPT-2) do.
1405
+
1406
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1407
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1408
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1409
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1410
+ each row of the batch).
1411
+ """,
1412
+ PHI3_START_DOCSTRING,
1413
+ )
1414
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
1415
+ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
1416
+ def __init__(self, config):
1417
+ super().__init__(config)
1418
+ self.num_labels = config.num_labels
1419
+ self.model = Phi3Model(config)
1420
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1421
+
1422
+ # Initialize weights and apply final processing
1423
+ self.post_init()
1424
+
1425
+ def get_input_embeddings(self):
1426
+ return self.model.embed_tokens
1427
+
1428
+ def set_input_embeddings(self, value):
1429
+ self.model.embed_tokens = value
1430
+
1431
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1432
+ def forward(
1433
+ self,
1434
+ input_ids: torch.LongTensor = None,
1435
+ attention_mask: Optional[torch.Tensor] = None,
1436
+ position_ids: Optional[torch.LongTensor] = None,
1437
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1438
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1439
+ labels: Optional[torch.LongTensor] = None,
1440
+ use_cache: Optional[bool] = None,
1441
+ output_attentions: Optional[bool] = None,
1442
+ output_hidden_states: Optional[bool] = None,
1443
+ return_dict: Optional[bool] = None,
1444
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1445
+ r"""
1446
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1447
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1448
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1449
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1450
+ """
1451
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1452
+
1453
+ model_outputs = self.model(
1454
+ input_ids,
1455
+ attention_mask=attention_mask,
1456
+ position_ids=position_ids,
1457
+ past_key_values=past_key_values,
1458
+ inputs_embeds=inputs_embeds,
1459
+ use_cache=use_cache,
1460
+ output_attentions=output_attentions,
1461
+ output_hidden_states=output_hidden_states,
1462
+ return_dict=return_dict,
1463
+ )
1464
+ hidden_states = model_outputs[0]
1465
+ logits = self.score(hidden_states)
1466
+
1467
+ if input_ids is not None:
1468
+ batch_size = input_ids.shape[0]
1469
+ else:
1470
+ batch_size = inputs_embeds.shape[0]
1471
+
1472
+ if self.config.pad_token_id is None and batch_size != 1:
1473
+ raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
1474
+ if self.config.pad_token_id is None:
1475
+ sequence_lengths = -1
1476
+ else:
1477
+ if input_ids is not None:
1478
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1479
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1480
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1481
+ sequence_lengths = sequence_lengths.to(logits.device)
1482
+ else:
1483
+ sequence_lengths = -1
1484
+
1485
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1486
+
1487
+ loss = None
1488
+ if labels is not None:
1489
+ labels = labels.to(logits.device)
1490
+ if self.config.problem_type is None:
1491
+ if self.num_labels == 1:
1492
+ self.config.problem_type = 'regression'
1493
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1494
+ self.config.problem_type = 'single_label_classification'
1495
+ else:
1496
+ self.config.problem_type = 'multi_label_classification'
1497
+
1498
+ if self.config.problem_type == 'regression':
1499
+ loss_fct = MSELoss()
1500
+ if self.num_labels == 1:
1501
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1502
+ else:
1503
+ loss = loss_fct(pooled_logits, labels)
1504
+ elif self.config.problem_type == 'single_label_classification':
1505
+ loss_fct = CrossEntropyLoss()
1506
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1507
+ elif self.config.problem_type == 'multi_label_classification':
1508
+ loss_fct = BCEWithLogitsLoss()
1509
+ loss = loss_fct(pooled_logits, labels)
1510
+ if not return_dict:
1511
+ output = (pooled_logits,) + model_outputs[1:]
1512
+ return ((loss,) + output) if loss is not None else output
1513
+
1514
+ return SequenceClassifierOutputWithPast(
1515
+ loss=loss,
1516
+ logits=pooled_logits,
1517
+ past_key_values=model_outputs.past_key_values,
1518
+ hidden_states=model_outputs.hidden_states,
1519
+ attentions=model_outputs.attentions,
1520
+ )
1521
+
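The pooling step above picks, for every row, the hidden state of the last non-padding token, using an `argmax` over padding-token positions so the index arithmetic stays ONNX-friendly. A small demonstration of that indexing (the `pad_token_id` is chosen arbitrarily):

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],    # right-padded
                          [3, 4, 6, 8, 2]])   # no padding at all

# Index of the first pad token minus one; rows with no padding wrap around to the
# last position via the modulo, matching the "take the last value" fallback.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)   # tensor([2, 4])

hidden = torch.randn(2, 5, 8)                 # (batch, seq, hidden)
pooled = hidden[torch.arange(2), sequence_lengths]
print(pooled.shape)       # torch.Size([2, 8])
```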
1522
+
1523
+ @add_start_docstrings(
1524
+ """
1525
+ [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1526
+ Named-Entity-Recognition (NER) tasks.
1527
+ """,
1528
+ PHI3_START_DOCSTRING,
1529
+ )
1530
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
1531
+ class Phi3ForTokenClassification(Phi3PreTrainedModel):
1532
+ def __init__(self, config: Phi3Config):
1533
+ super().__init__(config)
1534
+ self.num_labels = config.num_labels
1535
+
1536
+ self.model = Phi3Model(config)
1537
+ if hasattr(config, 'classifier_dropout') and config.classifier_dropout is not None:
1538
+ classifier_dropout = config.classifier_dropout
1539
+ elif hasattr(config, 'hidden_dropout') and config.hidden_dropout is not None:
1540
+ classifier_dropout = config.hidden_dropout
1541
+ else:
1542
+ classifier_dropout = 0.1
1543
+ self.dropout = nn.Dropout(classifier_dropout)
1544
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1545
+
1546
+ # Initialize weights and apply final processing
1547
+ self.post_init()
1548
+
1549
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1550
+ @add_code_sample_docstrings(
1551
+ checkpoint=_CHECKPOINT_FOR_DOC,
1552
+ output_type=TokenClassifierOutput,
1553
+ config_class=_CONFIG_FOR_DOC,
1554
+ )
1555
+ def forward(
1556
+ self,
1557
+ input_ids: Optional[torch.LongTensor] = None,
1558
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1559
+ attention_mask: Optional[torch.Tensor] = None,
1560
+ inputs_embeds: Optional[torch.Tensor] = None,
1561
+ labels: Optional[torch.Tensor] = None,
1562
+ use_cache: Optional[bool] = None,
1563
+ output_attentions: Optional[bool] = None,
1564
+ output_hidden_states: Optional[bool] = None,
1565
+ return_dict: Optional[bool] = None,
1566
+ **deprecated_arguments,
1567
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1568
+ r"""
1569
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1570
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1571
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1572
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1573
+ """
1574
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1575
+
1576
+ model_outputs = self.model(
1577
+ input_ids,
1578
+ past_key_values=past_key_values,
1579
+ attention_mask=attention_mask,
1580
+ inputs_embeds=inputs_embeds,
1581
+ use_cache=use_cache,
1582
+ output_attentions=output_attentions,
1583
+ output_hidden_states=output_hidden_states,
1584
+ return_dict=return_dict,
1585
+ )
1586
+
1587
+ hidden_states = model_outputs[0]
1588
+ hidden_states = self.dropout(hidden_states)
1589
+ logits = self.classifier(hidden_states)
1590
+
1591
+ loss = None
1592
+ if labels is not None:
1593
+ # move labels to correct device to enable model parallelism
1594
+ labels = labels.to(logits.device)
1595
+ batch_size, seq_length = labels.shape
1596
+ loss_fct = CrossEntropyLoss()
1597
+ loss = loss_fct(
1598
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1599
+ )
1600
+
1601
+ if not return_dict:
1602
+ output = (logits,) + model_outputs[2:]
1603
+ return ((loss,) + output) if loss is not None else output
1604
+
1605
+ return TokenClassifierOutput(
1606
+ loss=loss,
1607
+ logits=logits,
1608
+ hidden_states=model_outputs.hidden_states,
1609
+ attentions=model_outputs.attentions,
1610
+ )
src/third_party/InternVL/internvl_chat/internvl/patch/__init__.py ADDED
@@ -0,0 +1,34 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ from .internlm2_packed_training_patch import replace_internlm2_attention_class
8
+ from .internvit_liger_monkey_patch import apply_liger_kernel_to_internvit
9
+ from .llama2_flash_attn_monkey_patch import replace_llama2_attn_with_flash_attn
10
+ from .llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
11
+ from .llama_packed_training_patch import replace_llama_attention_class
12
+ from .llama_rmsnorm_monkey_patch import \
13
+ replace_llama_rmsnorm_with_fused_rmsnorm
14
+ from .pad_data_collator import (concat_pad_data_collator,
15
+ dpo_concat_pad_data_collator,
16
+ pad_data_collator)
17
+ from .phi3_packed_training_patch import replace_phi3_attention_class
18
+ from .qwen2_packed_training_patch import replace_qwen2_attention_class
19
+ from .train_dataloader_patch import replace_train_dataloader
20
+ from .train_sampler_patch import replace_train_sampler
21
+
22
+ __all__ = ['replace_llama_attn_with_flash_attn',
23
+ 'replace_llama_rmsnorm_with_fused_rmsnorm',
24
+ 'replace_llama2_attn_with_flash_attn',
25
+ 'replace_train_sampler',
26
+ 'replace_train_dataloader',
27
+ 'replace_internlm2_attention_class',
28
+ 'replace_qwen2_attention_class',
29
+ 'replace_phi3_attention_class',
30
+ 'replace_llama_attention_class',
31
+ 'pad_data_collator',
32
+ 'dpo_concat_pad_data_collator',
33
+ 'concat_pad_data_collator',
34
+ 'apply_liger_kernel_to_internvit']
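The `__all__` list above is the public surface of the patch package; the training entry points import these helpers and call them once at startup, before the model is built, so the monkey patches are in place when the classes are instantiated. A hedged usage sketch follows; which patches a run actually needs depends on the backbone and on whether packed training is enabled, and the comments only paraphrase each patch's target based on its module name:

```python
# Hypothetical startup snippet: apply the monkey patches before constructing the model/trainer.
from internvl.patch import (
    concat_pad_data_collator,
    replace_internlm2_attention_class,
    replace_llama_rmsnorm_with_fused_rmsnorm,
    replace_train_dataloader,
    replace_train_sampler,
)

replace_llama_rmsnorm_with_fused_rmsnorm()   # fused RMSNorm kernel for LLaMA-style backbones
replace_train_sampler()                      # custom train sampler patch
replace_train_dataloader()                   # custom train dataloader patch
replace_internlm2_attention_class()          # only needed for packed training with InternLM2

data_collator = concat_pad_data_collator     # e.g. passed to the HF Trainer as `data_collator`
```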
src/third_party/InternVL/internvl_chat/internvl/patch/internlm2_packed_training_patch.py ADDED
@@ -0,0 +1,74 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import torch
8
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
9
+ from internvl.model.internlm2.modeling_internlm2 import (
10
+ INTERNLM2_ATTENTION_CLASSES, InternLM2FlashAttention2,
11
+ apply_rotary_pos_emb)
12
+
13
+
14
+ # Modified from internvl.model.internlm2.modeling_internlm2.InternLM2FlashAttention2
15
+ class InternLM2FlashAttention2ForPackedTraining(InternLM2FlashAttention2):
16
+
17
+ def _flash_attention_forward(
18
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
19
+ ):
20
+ """
21
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
22
+ first unpad the input, then computes the attention scores and pad the final attention scores.
23
+
24
+ Args:
25
+ query_states (`torch.Tensor`):
26
+ Input query states to be passed to Flash Attention API
27
+ key_states (`torch.Tensor`):
28
+ Input key states to be passed to Flash Attention API
29
+ value_states (`torch.Tensor`):
30
+ Input value states to be passed to Flash Attention API
31
+ attention_mask (`torch.Tensor`):
32
+ renamed from cu_seqlens to keep compatibility - (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
33
+ of the sequences in the batch.
34
+ dropout (`int`, *optional*):
35
+ Attention dropout
36
+ softmax_scale (`float`, *optional*):
37
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
38
+ """
39
+ assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
40
+ query_states = query_states.squeeze(0)
41
+ key_states = key_states.squeeze(0)
42
+ value_states = value_states.squeeze(0)
43
+ cu_seqlens = attention_mask.squeeze(0)
44
+
45
+ with torch.no_grad():
46
+ max_seqlen = max([
47
+ cu_seqlens[idx+1] - cu_seqlens[idx]
48
+ for idx in range(cu_seqlens.size(0) - 1)
49
+ ]).item()
50
+
51
+ # Contains at least one padding token in the sequence
52
+ causal = self.is_causal and query_length != 1
53
+ attn_output = flash_attn_varlen_func(
54
+ q=query_states,
55
+ k=key_states,
56
+ v=value_states,
57
+ cu_seqlens_q=cu_seqlens,
58
+ cu_seqlens_k=cu_seqlens,
59
+ max_seqlen_q=max_seqlen,
60
+ max_seqlen_k=max_seqlen,
61
+ dropout_p=dropout,
62
+ softmax_scale=softmax_scale,
63
+ causal=causal,
64
+ )
65
+
66
+ query_states = query_states.unsqueeze(0)
67
+ key_states = key_states.unsqueeze(0)
68
+ value_states = value_states.unsqueeze(0)
69
+ return attn_output
70
+
71
+
72
+ def replace_internlm2_attention_class():
73
+ INTERNLM2_ATTENTION_CLASSES['flash_attention_2'] = InternLM2FlashAttention2ForPackedTraining
74
+ print('Replace INTERNLM2_ATTENTION_CLASSES to support packed training!!')
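The packed-training variant above repurposes `attention_mask` as a `cu_seqlens` vector: the cumulative token offsets of the individual samples that were concatenated into one row. A small sketch of how such a vector relates to the packed sequence, and of a vectorized equivalent of the `max_seqlen` loop in `_flash_attention_forward` (the sample lengths here are made up):

```python
import torch

# Three samples of length 5, 3 and 8 packed back-to-back into one 16-token sequence.
lengths = torch.tensor([5, 3, 8], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)                       # tensor([ 0,  5,  8, 16], dtype=torch.int32)

# Same result as the per-index max() loop above, but without the Python loop.
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
print(max_seqlen)                       # 8

# flash_attn_varlen_func then attends only within each [cu_seqlens[i], cu_seqlens[i+1]) slice,
# so tokens from different packed samples never see each other.
```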
src/third_party/InternVL/internvl_chat/internvl/patch/internvit_liger_monkey_patch.py ADDED
@@ -0,0 +1,13 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ def apply_liger_kernel_to_internvit() -> None:
8
+ from internvl.model.internvl_chat import modeling_intern_vit
9
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
10
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm
11
+ modeling_intern_vit.NORM2FN['rms_norm'] = LigerRMSNorm
12
+ modeling_intern_vit.NORM2FN['layer_norm'] = LigerLayerNorm
13
+ print('Liger kernel applied to InternViT')
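Because the patch swaps entries in `NORM2FN`, it only affects modules created afterwards, so it has to run before the ViT is instantiated. A tiny, hedged usage sketch:

```python
# Hypothetical call order: patch first, then build the model so the new norm classes are used.
from internvl.patch import apply_liger_kernel_to_internvit

apply_liger_kernel_to_internvit()   # NORM2FN now maps to LigerRMSNorm / LigerLayerNorm
# ... construct the InternVL chat model here; its InternViT blocks pick up the Liger norms.
```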
src/third_party/InternVL/internvl_chat/internvl/patch/llama2_flash_attn_monkey_patch.py ADDED
@@ -0,0 +1,238 @@
1
+ """
2
+ This file is copied from: https://github.com/lm-sys/FastChat
3
+ """
4
+
5
+ import warnings
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+ from flash_attn import __version__ as flash_attn_version
10
+ from flash_attn.bert_padding import pad_input, unpad_input
11
+ from flash_attn.flash_attn_interface import (flash_attn_func,
12
+ flash_attn_varlen_kvpacked_func)
13
+ from transformers.models.llama.modeling_llama import (LlamaAttention,
14
+ LlamaModel, rotate_half)
15
+
16
+
17
+ def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
18
+ gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
19
+ gather_indices = gather_indices.repeat(
20
+ 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
21
+ )
22
+ bsz = gather_indices.shape[0]
23
+ cos, sin = (
24
+ torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
25
+ for x in cos_sin
26
+ )
27
+ q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
28
+ return q, k
29
+
30
+
31
+ def forward(
32
+ self,
33
+ hidden_states: torch.Tensor,
34
+ attention_mask: Optional[torch.Tensor] = None,
35
+ position_ids: Optional[torch.Tensor] = None,
36
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
37
+ output_attentions: bool = False,
38
+ use_cache: bool = False,
39
+ padding_mask: Optional[torch.Tensor] = None,
40
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
41
+ if output_attentions:
42
+ warnings.warn(
43
+ 'Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.'
44
+ )
45
+
46
+ bsz, q_len, _ = hidden_states.size()
47
+ kv_heads = getattr(self, 'num_key_value_heads', self.num_heads)
48
+
49
+ q, k, v = (
50
+ op(hidden_states).view(bsz, q_len, nh, self.head_dim)
51
+ for op, nh in (
52
+ (self.q_proj, self.num_heads),
53
+ (self.k_proj, kv_heads),
54
+ (self.v_proj, kv_heads),
55
+ )
56
+ )
57
+ # shape: (b, s, num_heads, head_dim)
58
+
59
+ kv_seq_len = k.shape[1]
60
+ past_kv_len = 0
61
+ if past_key_value is not None:
62
+ past_kv_len = past_key_value[0].shape[2]
63
+ kv_seq_len += past_kv_len
64
+
65
+ cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
66
+ q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
67
+
68
+ if past_key_value is not None:
69
+ assert (
70
+ flash_attn_version >= '2.1.0'
71
+ ), 'past_key_value support requires flash-attn >= 2.1.0'
72
+ # reuse k, v
73
+ k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
74
+ v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
75
+
76
+ past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
77
+
78
+ if attention_mask is None:
79
+ output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
80
+ bsz, q_len, -1
81
+ )
82
+ else:
83
+ q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
84
+ # We can skip concat and call unpad twice but seems better to call unpad only once.
85
+ kv, _, cu_k_lens, max_k = unpad_input(
86
+ torch.stack((k, v), dim=2), attention_mask
87
+ )
88
+ output_unpad = flash_attn_varlen_kvpacked_func(
89
+ q,
90
+ kv,
91
+ cu_q_lens,
92
+ cu_k_lens,
93
+ max_s,
94
+ max_k,
95
+ 0.0,
96
+ softmax_scale=None,
97
+ causal=True,
98
+ )
99
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
100
+ output = pad_input(output_unpad, indices, bsz, q_len)
101
+
102
+ return self.o_proj(output), None, past_key_value
103
+
104
+
105
+ # Disable the transformation of the attention mask in LlamaModel as flash attention
106
+ # takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
107
+ def _prepare_decoder_attention_mask(
108
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
109
+ ):
110
+ # [bsz, seq_len]
111
+ if past_key_values_length > 0 and attention_mask is not None:
112
+ attention_mask = torch.cat(
113
+ (
114
+ torch.full(
115
+ (input_shape[0], past_key_values_length),
116
+ True,
117
+ dtype=attention_mask.dtype,
118
+ device=attention_mask.device,
119
+ ),
120
+ attention_mask,
121
+ ),
122
+ dim=-1,
123
+ )
124
+
125
+ if attention_mask is not None and torch.all(attention_mask):
126
+ return None # This uses the faster call when training with full samples
127
+
128
+ return attention_mask
129
+
130
+
131
+ def replace_llama2_attn_with_flash_attn():
132
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
133
+ if cuda_major < 8:
134
+ warnings.warn(
135
+ 'Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. '
136
+ 'ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593'
137
+ )
138
+
139
+ LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
140
+ LlamaAttention.forward = forward
141
+
142
+
143
+ def test():
144
+ from fastchat.train.llama_flash_attn_monkey_patch import \
145
+ forward as fastchat_forward
146
+ from transformers.models.llama.configuration_llama import LlamaConfig
147
+
148
+ config = LlamaConfig(
149
+ hidden_size=1024,
150
+ intermediate_size=128,
151
+ num_hidden_layers=1,
152
+ num_attention_heads=8,
153
+ max_position_embeddings=16,
154
+ )
155
+ device = torch.device('cuda')
156
+ model = LlamaModel(config)
157
+ attn = LlamaAttention(config).to(device).half()
158
+ bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings
159
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view(
160
+ -1, seqlen
161
+ )
162
+
163
+ mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
164
+ for i in range(4):
165
+ hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
166
+ if i:
167
+ mask[0, -i:] = False
168
+ mask[1, :i] = False
169
+
170
+ lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0)
171
+ ref, _, _ = attn.forward(
172
+ hidden, attention_mask=lmask, position_ids=position_ids
173
+ )
174
+
175
+ fast, _, _ = fastchat_forward(
176
+ attn, hidden, attention_mask=mask, position_ids=position_ids
177
+ )
178
+
179
+ lmask = _prepare_decoder_attention_mask(
180
+ model, mask, hidden.shape[:2], hidden, 0
181
+ )
182
+ test, _, _ = forward(
183
+ attn, hidden, attention_mask=lmask, position_ids=position_ids
184
+ )
185
+
186
+ print(f'Mean(abs(ref)) = {torch.mean(torch.abs(ref))}')
187
+ print(f'Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}')
188
+ print(f'Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}')
189
+ print(f'Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}')
190
+ print(f'allclose(fast, test) = {torch.allclose(fast, test)}')
191
+
192
+ with torch.no_grad():
193
+ # Also check that past_kv is handled properly
194
+ hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
195
+ part_len = seqlen // 4
196
+ assert part_len * 4 == seqlen
197
+ mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
198
+ mask[0, -2:] = False
199
+ lmask = _prepare_decoder_attention_mask(
200
+ model, mask, hidden.shape[:2], hidden, 0
201
+ )
202
+ oneshot, _, _ = forward(
203
+ attn, hidden, attention_mask=lmask, position_ids=position_ids
204
+ )
205
+ parts = []
206
+ past_kv, past_kv_len = None, 0
207
+ for i in range(4):
208
+ start = part_len * i
209
+ end = start + part_len
210
+ hidden_part = hidden[:, start:end, ...]
211
+ lmask = _prepare_decoder_attention_mask(
212
+ model,
213
+ mask[:, start:end],
214
+ hidden_part.shape[:2],
215
+ hidden_part,
216
+ past_kv_len,
217
+ )
218
+ part, _, past_kv = forward(
219
+ attn,
220
+ hidden_part.clone(),
221
+ attention_mask=lmask,
222
+ position_ids=position_ids[:, start:end],
223
+ past_key_value=past_kv,
224
+ use_cache=True,
225
+ )
226
+ parts.append(part)
227
+ past_kv_len = past_kv[0].shape[2]
228
+
229
+ print(
230
+ f'allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}'
231
+ )
232
+ print(
233
+ f'allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}'
234
+ )
235
+
236
+
237
+ if __name__ == '__main__':
238
+ test()
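Once the cos/sin tables have been gathered per position, the `apply_rotary_pos_emb` above reduces to the standard rotation `x * cos + rotate_half(x) * sin`. A minimal illustration of that core identity with toy tensors; the shapes are arbitrary and broadcasting handles the per-position pairing, in the same `(batch, seq, heads, head_dim)` layout the patched forward uses:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Same helper as transformers' LLaMA code: swap and negate the two halves of the last dim.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

seq_len, head_dim = 6, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)        # (seq, head_dim/2)
emb = torch.cat((angles, angles), dim=-1)                             # (seq, head_dim)
cos, sin = emb.cos(), emb.sin()

q = torch.randn(1, seq_len, 1, head_dim)                              # (b, s, heads, head_dim)
q_rot = q * cos[None, :, None, :] + rotate_half(q) * sin[None, :, None, :]
print(q_rot.shape)                                                    # torch.Size([1, 6, 1, 8])
```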
src/third_party/InternVL/internvl_chat/internvl/patch/llama_flash_attn_monkey_patch.py ADDED
@@ -0,0 +1,222 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import math
8
+ from typing import Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import transformers
13
+ from torch import nn
14
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
15
+
16
+
17
+ def forward(
18
+ self,
19
+ hidden_states: torch.Tensor,
20
+ attention_mask: Optional[torch.Tensor] = None,
21
+ position_ids: Optional[torch.Tensor] = None,
22
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
23
+ output_attentions: bool = False,
24
+ use_cache: bool = False,
25
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
26
+ """Input shape: Batch x Time x Channel
27
+
28
+ attention_mask: [bsz, q_len]
29
+ """
30
+ from einops import rearrange
31
+ try: # v1
32
+ from flash_attn.flash_attn_interface import \
33
+ flash_attn_unpadded_qkvpacked_func
34
+ except: # v2
35
+ from flash_attn.flash_attn_interface import \
36
+ flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
37
+ from flash_attn.bert_padding import pad_input, unpad_input
38
+
39
+ bsz, q_len, _ = hidden_states.size()
40
+
41
+ query_states = (
42
+ self.q_proj(hidden_states)
43
+ .view(bsz, q_len, self.num_heads, self.head_dim)
44
+ .transpose(1, 2)
45
+ )
46
+ key_states = (
47
+ self.k_proj(hidden_states)
48
+ .view(bsz, q_len, self.num_heads, self.head_dim)
49
+ .transpose(1, 2)
50
+ )
51
+ value_states = (
52
+ self.v_proj(hidden_states)
53
+ .view(bsz, q_len, self.num_heads, self.head_dim)
54
+ .transpose(1, 2)
55
+ )
56
+ # [bsz, q_len, nh, hd]
57
+ # [bsz, nh, q_len, hd]
58
+
59
+ kv_seq_len = key_states.shape[-2]
60
+ assert past_key_value is None, 'past_key_value is not supported'
61
+
62
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
63
+ query_states, key_states = apply_rotary_pos_emb(
64
+ query_states, key_states, cos, sin, position_ids
65
+ )
66
+ # [bsz, nh, t, hd]
67
+ assert not output_attentions, 'output_attentions is not supported'
68
+ assert not use_cache, 'use_cache is not supported'
69
+
70
+ # Flash attention codes from
71
+ # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
72
+
73
+ # transform the data into the format required by flash attention
74
+ qkv = torch.stack(
75
+ [query_states, key_states, value_states], dim=2
76
+ ) # [bsz, nh, 3, q_len, hd]
77
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
78
+ # We have disabled _prepare_decoder_attention_mask in LlamaModel
79
+ # the attention_mask should be the same as the key_padding_mask
80
+ key_padding_mask = attention_mask
81
+
82
+ if key_padding_mask is None:
83
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
84
+ max_s = q_len
85
+ cu_q_lens = torch.arange(
86
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
87
+ )
88
+ output = flash_attn_unpadded_qkvpacked_func(
89
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
90
+ )
91
+ output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
92
+ else:
93
+ nheads = qkv.shape[-2]
94
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
95
+ x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
96
+ x_unpad = rearrange(
97
+ x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads
98
+ )
99
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
100
+ x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
101
+ )
102
+ output = rearrange(
103
+ pad_input(
104
+ rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices, bsz, q_len
105
+ ),
106
+ 'b s (h d) -> b s h d',
107
+ h=nheads,
108
+ )
109
+ return self.o_proj(rearrange(output, 'b s h d -> b s (h d)')), None, None
110
+
111
+
112
+ # Disable the transformation of the attention mask in LlamaModel as the flash attention
113
+ # requires the attention mask to be the same as the key_padding_mask
114
+ def _prepare_decoder_attention_mask(
115
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
116
+ ):
117
+ # [bsz, seq_len]
118
+ return attention_mask
119
+
120
+
121
+ def forward_2(
122
+ self,
123
+ hidden_states: torch.Tensor,
124
+ attention_mask: Optional[torch.Tensor] = None,
125
+ position_ids: Optional[torch.LongTensor] = None,
126
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
127
+ output_attentions: bool = False,
128
+ use_cache: bool = False,
129
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
130
+ bsz, q_len, _ = hidden_states.size()
131
+
132
+ query_states = (
133
+ self.q_proj(hidden_states)
134
+ .view(bsz, q_len, self.num_heads, self.head_dim)
135
+ .transpose(1, 2)
136
+ )
137
+ key_states = (
138
+ self.k_proj(hidden_states)
139
+ .view(bsz, q_len, self.num_heads, self.head_dim)
140
+ .transpose(1, 2)
141
+ )
142
+ value_states = (
143
+ self.v_proj(hidden_states)
144
+ .view(bsz, q_len, self.num_heads, self.head_dim)
145
+ .transpose(1, 2)
146
+ )
147
+
148
+ kv_seq_len = key_states.shape[-2]
149
+ if past_key_value is not None:
150
+ kv_seq_len += past_key_value[0].shape[-2]
151
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
152
+ query_states, key_states = apply_rotary_pos_emb(
153
+ query_states, key_states, cos, sin, position_ids
154
+ )
155
+
156
+ assert not output_attentions, 'output_attentions is not supported'
157
+ assert not use_cache, 'use_cache is not supported'
158
+ assert past_key_value is None, 'past_key_value is not supported'
159
+
160
+ if past_key_value is not None:
161
+ # reuse k, v, self_attention
162
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
163
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
164
+
165
+ past_key_value = (key_states, value_states) if use_cache else None
166
+ if self.training:
167
+ attn_output = F.scaled_dot_product_attention(
168
+ query_states, key_states, value_states, dropout_p=0.0, is_causal=True
169
+ )
170
+ attn_weights = None
171
+ else:
172
+ attn_weights = torch.matmul(
173
+ query_states, key_states.transpose(2, 3)
174
+ ) / math.sqrt(self.head_dim)
175
+
176
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
177
+ raise ValueError(
178
+ f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
179
+ f' {attn_weights.size()}'
180
+ )
181
+
182
+ if attention_mask is not None:
183
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
184
+ raise ValueError(
185
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
186
+ )
187
+ attn_weights = attn_weights + attention_mask
188
+ attn_weights = torch.max(
189
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
190
+ )
191
+
192
+ # upcast attention to fp32
193
+ attn_weights = nn.functional.softmax(
194
+ attn_weights, dim=-1, dtype=torch.float32
195
+ ).to(query_states.dtype)
196
+ attn_output = torch.matmul(attn_weights, value_states)
197
+
198
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
199
+ raise ValueError(
200
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
201
+ f' {attn_output.size()}'
202
+ )
203
+
204
+ attn_output = attn_output.transpose(1, 2)
205
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
206
+
207
+ attn_output = self.o_proj(attn_output)
208
+
209
+ if not output_attentions:
210
+ attn_weights = None
211
+
212
+ return attn_output, attn_weights, past_key_value
213
+
214
+
215
+ def replace_llama_attn_with_flash_attn():
216
+ if hasattr(F, 'scaled_dot_product_attention'):
217
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_2
218
+ else:
219
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
220
+ _prepare_decoder_attention_mask
221
+ )
222
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
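The padded branch of the patched forward relies on the unpad → varlen-attention → pad round trip from `flash_attn.bert_padding`: variable-length rows are flattened into one packed "nnz" dimension, attention runs over the packed tokens, and the result is scattered back to `(batch, seq)`. A hedged sketch of just that bookkeeping, assuming a flash_attn version with the four-value `unpad_input` signature used above (the package must be installed):

```python
import torch
from flash_attn.bert_padding import pad_input, unpad_input

bsz, seqlen, hidden = 2, 6, 16
x = torch.randn(bsz, seqlen, hidden)
mask = torch.ones(bsz, seqlen, dtype=torch.bool)
mask[0, 4:] = False                      # first row only has 4 real tokens

x_unpad, indices, cu_seqlens, max_s = unpad_input(x, mask)
print(x_unpad.shape)                     # torch.Size([10, 16])  -> 4 + 6 packed tokens
print(cu_seqlens)                        # tensor([ 0,  4, 10], dtype=torch.int32)

# ... run a flash_attn_varlen_* kernel on x_unpad here ...

x_repad = pad_input(x_unpad, indices, bsz, seqlen)
print(x_repad.shape)                     # torch.Size([2, 6, 16])
```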
src/third_party/InternVL/internvl_chat/internvl/patch/llama_packed_training_patch.py ADDED
@@ -0,0 +1,106 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import torch
8
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
9
+ from transformers.models.llama.modeling_llama import (LLAMA_ATTENTION_CLASSES,
10
+ LlamaFlashAttention2)
11
+
12
+
13
+ # Modified from transformers.models.llama.modeling_llama.LlamaFlashAttention2
14
+ class LlamaFlashAttention2ForPackedTraining(LlamaFlashAttention2):
15
+
16
+ def _flash_attention_forward(
17
+ self,
18
+ query_states,
19
+ key_states,
20
+ value_states,
21
+ attention_mask,
22
+ query_length,
23
+ dropout=0.0,
24
+ softmax_scale=None,
25
+ use_sliding_windows=False,
26
+ ):
27
+ """
28
+ Calls the Flash Attention varlen kernel on a packed batch of sequences. Sample boundaries are
29
+ taken from the cumulative sequence lengths passed in through `attention_mask`.
30
+
31
+ Args:
32
+ query_states (`torch.Tensor`):
33
+ Input query states to be passed to Flash Attention API
34
+ key_states (`torch.Tensor`):
35
+ Input key states to be passed to Flash Attention API
36
+ value_states (`torch.Tensor`):
37
+ Input value states to be passed to Flash Attention API
38
+ attention_mask (`torch.Tensor`):
39
+ Cumulative sequence lengths (cu_seqlens) of the packed samples, shape `(1, num_samples + 1)`,
40
+ rather than a 0/1 padding mask; it is squeezed and handed to `flash_attn_varlen_func`.
41
+ dropout (`float`, *optional*):
42
+ Attention dropout
43
+ softmax_scale (`float`, *optional*):
44
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
45
+ use_sliding_windows (`bool`, *optional*):
46
+ Whether to activate sliding window attention.
47
+ """
48
+ assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
49
+ query_states = query_states.squeeze(0)
50
+ key_states = key_states.squeeze(0)
51
+ value_states = value_states.squeeze(0)
52
+ cu_seqlens = attention_mask.squeeze(0)
53
+
54
+ with torch.no_grad():
55
+ max_seqlen = max([
56
+ cu_seqlens[idx+1] - cu_seqlens[idx]
57
+ for idx in range(cu_seqlens.size(0) - 1)
58
+ ]).item()
59
+
60
+ if not self._flash_attn_uses_top_left_mask:
61
+ causal = self.is_causal
62
+ else:
63
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
64
+ causal = self.is_causal and query_length != 1
65
+
66
+ # Decide whether to use SWA or not by layer index.
67
+ if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
68
+ use_sliding_windows = False
69
+
70
+ if not use_sliding_windows:
71
+ attn_output = flash_attn_varlen_func(
72
+ q=query_states,
73
+ k=key_states,
74
+ v=value_states,
75
+ cu_seqlens_q=cu_seqlens,
76
+ cu_seqlens_k=cu_seqlens,
77
+ max_seqlen_q=max_seqlen,
78
+ max_seqlen_k=max_seqlen,
79
+ dropout_p=dropout,
80
+ softmax_scale=softmax_scale,
81
+ causal=causal,
82
+ )
83
+ else:
84
+ attn_output = flash_attn_varlen_func(
85
+ q=query_states,
86
+ k=key_states,
87
+ v=value_states,
88
+ cu_seqlens_q=cu_seqlens,
89
+ cu_seqlens_k=cu_seqlens,
90
+ max_seqlen_q=max_seqlen,
91
+ max_seqlen_k=max_seqlen,
92
+ dropout_p=dropout,
93
+ softmax_scale=softmax_scale,
94
+ causal=causal,
95
+ window_size=(self.config.sliding_window, self.config.sliding_window),
96
+ )
97
+
98
+ query_states = query_states.unsqueeze(0)
99
+ key_states = key_states.unsqueeze(0)
100
+ value_states = value_states.unsqueeze(0)
101
+ return attn_output
102
+
103
+
104
+ def replace_llama_attention_class():
105
+ LLAMA_ATTENTION_CLASSES['flash_attention_2'] = LlamaFlashAttention2ForPackedTraining
106
+ print('Replace LLAMA_ATTENTION_CLASSES to support packed training!!')
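
For this packed variant the data pipeline has to pass cumulative sequence lengths through `attention_mask` instead of a 0/1 mask. A rough sketch of how such a batch could be assembled; the variable names and lengths are illustrative, not taken from the repo:

```python
import torch

from internvl.patch.llama_packed_training_patch import \
    replace_llama_attention_class

replace_llama_attention_class()  # registers the class under 'flash_attention_2'

# Three samples of lengths 5, 3 and 4 packed into a single row of 12 tokens.
sample_lengths = [5, 3, 4]
cu_seqlens = torch.cumsum(
    torch.tensor([0] + sample_lengths), dim=0
).to(torch.int32)                # tensor([ 0,  5,  8, 12], dtype=torch.int32)

# The dataloader would then hand the model input_ids of shape (1, 12) together with
# attention_mask = cu_seqlens.unsqueeze(0); _flash_attention_forward squeezes it back
# and forwards it to flash_attn_varlen_func as cu_seqlens_q / cu_seqlens_k.
```
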
src/third_party/InternVL/internvl_chat/internvl/patch/llama_rmsnorm_monkey_patch.py ADDED
@@ -0,0 +1,23 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import transformers
8
+
9
+
10
+ def replace_llama_rmsnorm_with_fused_rmsnorm():
11
+ try:
12
+ from functools import partial
13
+
14
+ from apex.normalization import FusedRMSNorm
15
+ LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa
16
+ transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
17
+ print('Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm')
18
+ except ImportError:
19
+ # using the normal LlamaRMSNorm
20
+ pass
21
+ except Exception:
22
+ print('discovered apex but it failed to load, falling back to LlamaRMSNorm')
23
+ pass
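
As with the attention patches, this swap only matters if it runs before the model is constructed; without apex the call quietly falls back to the stock `LlamaRMSNorm`. A one-line usage sketch (the call site is an assumption):

```python
from internvl.patch.llama_rmsnorm_monkey_patch import \
    replace_llama_rmsnorm_with_fused_rmsnorm

# No-op without apex; otherwise every LlamaRMSNorm created from here on is a FusedRMSNorm.
replace_llama_rmsnorm_with_fused_rmsnorm()
```
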
src/third_party/InternVL/internvl_chat/internvl/patch/pad_data_collator.py ADDED
@@ -0,0 +1,155 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ IGNORE_INDEX = -100
11
+
12
+
13
+ def pad_data_collator(features, pad_id=0):
14
+
15
+ first = features[0]
16
+ batch = {}
17
+
18
+ batch_lens = [feat['input_ids'].shape for feat in features]
19
+ max_item_length = max(batch_lens)[0]
20
+ for idx in range(len(features)):
21
+ feat = features[idx]
22
+ temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
23
+ temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
24
+ feat['input_ids'] = temp_input_ids
25
+ temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
26
+ temp_labels[:feat['labels'].shape[0]] = feat['labels']
27
+ feat['labels'] = temp_labels
28
+ feat['attention_mask'] = feat['input_ids'].ne(pad_id)
29
+
30
+ # Special handling for labels.
31
+ # Ensure that tensor is created with the correct type
32
+ # (it should be automatically the case, but let's make sure of it.)
33
+ if 'label' in first and first['label'] is not None:
34
+ label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
35
+ dtype = torch.long if isinstance(label, int) else torch.float
36
+ batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
37
+ elif 'label_ids' in first and first['label_ids'] is not None:
38
+ if isinstance(first['label_ids'], torch.Tensor):
39
+ batch['labels'] = torch.stack([f['label_ids'] for f in features])
40
+ else:
41
+ dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
42
+ batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
43
+
44
+ # Handling of all other possible keys.
45
+ # Again, we will use the first element to figure out which key/values are not None for this model.
46
+ for k, v in first.items():
47
+ if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str):
48
+ if isinstance(v, torch.Tensor):
49
+ batch[k] = torch.stack([f[k] for f in features])
50
+ elif isinstance(v, np.ndarray):
51
+ batch[k] = torch.tensor(np.stack([f[k] for f in features]))
52
+ else:
53
+ batch[k] = torch.tensor([f[k] for f in features])
54
+ return batch
55
+
56
+
57
+ def concat_pad_data_collator(features, max_item_length=None, pad_id=0):
58
+
59
+ first = features[0]
60
+ batch = {}
61
+
62
+ batch_lens = [feat['input_ids'].shape for feat in features]
63
+ max_item_length = max_item_length or max(batch_lens)[0]
64
+ for idx in range(len(features)):
65
+ feat = features[idx]
66
+ temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
67
+ temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
68
+ feat['input_ids'] = temp_input_ids
69
+ temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
70
+ temp_labels[:feat['labels'].shape[0]] = feat['labels']
71
+ feat['labels'] = temp_labels
72
+ feat['attention_mask'] = feat['input_ids'].ne(pad_id)
73
+
74
+ if 'position_ids' in feat:
75
+ temp_position_ids = torch.LongTensor([pad_id] * max_item_length)
76
+ temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids']
77
+ feat['position_ids'] = temp_position_ids
78
+
79
+ if 'loss_weight' in feat:
80
+ temp_loss_weight = torch.FloatTensor([pad_id] * max_item_length)
81
+ temp_loss_weight[:feat['loss_weight'].shape[0]] = feat['loss_weight']
82
+ feat['loss_weight'] = temp_loss_weight
83
+
84
+ # Special handling for labels.
85
+ # Ensure that tensor is created with the correct type
86
+ # (it should be automatically the case, but let's make sure of it.)
87
+ if 'label' in first and first['label'] is not None:
88
+ label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
89
+ dtype = torch.long if isinstance(label, int) else torch.float
90
+ batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
91
+ elif 'label_ids' in first and first['label_ids'] is not None:
92
+ if isinstance(first['label_ids'], torch.Tensor):
93
+ batch['labels'] = torch.stack([f['label_ids'] for f in features])
94
+ else:
95
+ dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
96
+ batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
97
+
98
+ # Handling of all other possible keys.
99
+ # Again, we will use the first element to figure out which key/values are not None for this model.
100
+ for k, v in first.items():
101
+ if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \
102
+ v is not None and not isinstance(v, str):
103
+ if isinstance(v, torch.Tensor):
104
+ batch[k] = torch.stack([f[k] for f in features])
105
+ elif isinstance(v, np.ndarray):
106
+ batch[k] = torch.tensor(np.stack([f[k] for f in features]))
107
+ else:
108
+ batch[k] = torch.tensor([f[k] for f in features])
109
+ if k in ('pixel_values', 'image_flags'):
110
+ if isinstance(v, torch.Tensor):
111
+ batch[k] = torch.concat([f[k] for f in features])
112
+ elif isinstance(v, np.ndarray):
113
+ batch[k] = torch.concat(np.stack([f[k] for f in features]))
114
+ else:
115
+ batch[k] = torch.concat([f[k] for f in features])
116
+ return batch
117
+
118
+
119
+ def dpo_concat_pad_data_collator(features, pad_id=0):
120
+
121
+ first = features[0]
122
+ batch = {}
123
+
124
+ for prefix in ['chosen_', 'rejected_']:
125
+ batch_lens = [feat[f'{prefix}input_ids'].shape[0] for feat in features]
126
+ max_item_length = max(batch_lens)
127
+ for idx in range(len(features)):
128
+ feat = features[idx]
129
+ temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
130
+ temp_input_ids[:feat[f'{prefix}input_ids'].shape[0]] = feat[f'{prefix}input_ids']
131
+ feat[f'{prefix}input_ids'] = temp_input_ids
132
+ temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
133
+ temp_labels[:feat[f'{prefix}labels'].shape[0]] = feat[f'{prefix}labels']
134
+ feat[f'{prefix}labels'] = temp_labels
135
+ feat[f'{prefix}attention_mask'] = feat[f'{prefix}input_ids'].ne(pad_id)
136
+
137
+ # Handling of all other possible keys.
138
+ # Again, we will use the first element to figure out which key/values are not None for this model.
139
+ for k, v in first.items():
140
+ if k not in ('pixel_values', 'image_flags') and \
141
+ v is not None and not isinstance(v, str):
142
+ if isinstance(v, torch.Tensor):
143
+ batch[k] = torch.stack([f[k] for f in features])
144
+ elif isinstance(v, np.ndarray):
145
+ batch[k] = torch.tensor(np.stack([f[k] for f in features]))
146
+ else:
147
+ batch[k] = torch.tensor([f[k] for f in features])
148
+ if k in ('pixel_values', 'image_flags'):
149
+ if isinstance(v, torch.Tensor):
150
+ batch[k] = torch.concat([f[k] for f in features])
151
+ elif isinstance(v, np.ndarray):
152
+ batch[k] = torch.concat(np.stack([f[k] for f in features]))
153
+ else:
154
+ batch[k] = torch.concat([f[k] for f in features])
155
+ return batch
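
A small, self-contained example of what `concat_pad_data_collator` produces; the tile counts and tensor shapes below are made up purely for illustration:

```python
import torch

from internvl.patch.pad_data_collator import concat_pad_data_collator

features = [
    {
        'input_ids': torch.tensor([1, 2, 3]),
        'labels': torch.tensor([1, 2, 3]),
        'pixel_values': torch.randn(2, 3, 448, 448),   # two image tiles
        'image_flags': torch.tensor([1, 1]),
    },
    {
        'input_ids': torch.tensor([4, 5]),
        'labels': torch.tensor([4, 5]),
        'pixel_values': torch.randn(1, 3, 448, 448),   # one image tile
        'image_flags': torch.tensor([1]),
    },
]

batch = concat_pad_data_collator(features, pad_id=0)
# input_ids / labels are right-padded to the longest sample -> shape (2, 3), with the
# padded label positions set to IGNORE_INDEX (-100); pixel_values and image_flags are
# concatenated along dim 0 instead of stacked -> shapes (3, 3, 448, 448) and (3,).
```
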
src/third_party/InternVL/internvl_chat/internvl/patch/phi3_packed_training_patch.py ADDED
@@ -0,0 +1,105 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import torch
8
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
9
+ from internvl.model.phi3.modeling_phi3 import (PHI3_ATTENTION_CLASSES,
10
+ Phi3FlashAttention2)
11
+
12
+
13
+ class Phi3FlashAttention2ForPackedTraining(Phi3FlashAttention2):
14
+
15
+ def _flash_attention_forward(
16
+ self,
17
+ query_states,
18
+ key_states,
19
+ value_states,
20
+ attention_mask,
21
+ query_length,
22
+ dropout=0.0,
23
+ softmax_scale=None,
24
+ use_sliding_windows=False,
25
+ ):
26
+ """
27
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
28
+ first unpad the input, then computes the attention scores and pad the final attention scores.
29
+
30
+ Args:
31
+ query_states (`torch.Tensor`):
32
+ Input query states to be passed to Flash Attention API
33
+ key_states (`torch.Tensor`):
34
+ Input key states to be passed to Flash Attention API
35
+ value_states (`torch.Tensor`):
36
+ Input value states to be passed to Flash Attention API
37
+ attention_mask (`torch.Tensor`):
38
+ Cumulative sequence lengths (cu_seqlens) of the packed samples, shape `(1, num_samples + 1)`,
39
+ rather than a 0/1 padding mask; it is squeezed and handed to `flash_attn_varlen_func`.
40
+ dropout (`float`):
41
+ Attention dropout
42
+ softmax_scale (`float`, *optional*):
43
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
44
+ use_sliding_windows (`bool`, *optional*):
45
+ Whether to activate sliding window attention.
46
+ """
47
+ assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
48
+ query_states = query_states.squeeze(0)
49
+ key_states = key_states.squeeze(0)
50
+ value_states = value_states.squeeze(0)
51
+ cu_seqlens = attention_mask.squeeze(0)
52
+
53
+ with torch.no_grad():
54
+ max_seqlen = max([
55
+ cu_seqlens[idx+1] - cu_seqlens[idx]
56
+ for idx in range(cu_seqlens.size(0) - 1)
57
+ ]).item()
58
+
59
+ if not self._flash_attn_uses_top_left_mask:
60
+ causal = self.is_causal
61
+ else:
62
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
63
+ causal = self.is_causal and query_length != 1
64
+
65
+ # Decide whether to use SWA or not by layer index.
66
+ if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
67
+ use_sliding_windows = False
68
+
69
+ if not use_sliding_windows:
70
+ attn_output = flash_attn_varlen_func(
71
+ q=query_states,
72
+ k=key_states,
73
+ v=value_states,
74
+ cu_seqlens_q=cu_seqlens,
75
+ cu_seqlens_k=cu_seqlens,
76
+ max_seqlen_q=max_seqlen,
77
+ max_seqlen_k=max_seqlen,
78
+ dropout_p=dropout,
79
+ softmax_scale=softmax_scale,
80
+ causal=causal,
81
+ )
82
+ else:
83
+ attn_output = flash_attn_varlen_func(
84
+ q=query_states,
85
+ k=key_states,
86
+ v=value_states,
87
+ cu_seqlens_q=cu_seqlens,
88
+ cu_seqlens_k=cu_seqlens,
89
+ max_seqlen_q=max_seqlen,
90
+ max_seqlen_k=max_seqlen,
91
+ dropout_p=dropout,
92
+ softmax_scale=softmax_scale,
93
+ causal=causal,
94
+ window_size=(self.config.sliding_window, self.config.sliding_window),
95
+ )
96
+
97
+ query_states = query_states.unsqueeze(0)
98
+ key_states = key_states.unsqueeze(0)
99
+ value_states = value_states.unsqueeze(0)
100
+ return attn_output
101
+
102
+
103
+ def replace_phi3_attention_class():
104
+ PHI3_ATTENTION_CLASSES['flash_attention_2'] = Phi3FlashAttention2ForPackedTraining
105
+ print('Replace PHI3_ATTENTION_CLASSES to support packed training!!')
src/third_party/InternVL/internvl_chat/internvl/patch/qwen2_packed_training_patch.py ADDED
@@ -0,0 +1,106 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import torch
8
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
9
+ from transformers.models.qwen2.modeling_qwen2 import (QWEN2_ATTENTION_CLASSES,
10
+ Qwen2FlashAttention2)
11
+
12
+
13
+ # Modified from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2
14
+ class Qwen2FlashAttention2ForPackedTraining(Qwen2FlashAttention2):
15
+
16
+ def _flash_attention_forward(
17
+ self,
18
+ query_states,
19
+ key_states,
20
+ value_states,
21
+ attention_mask,
22
+ query_length,
23
+ dropout=0.0,
24
+ softmax_scale=None,
25
+ use_sliding_windows=False,
26
+ ):
27
+ """
28
+ Calls the Flash Attention varlen kernel on a packed batch of sequences. Sample boundaries are
29
+ taken from the cumulative sequence lengths passed in through `attention_mask`.
30
+
31
+ Args:
32
+ query_states (`torch.Tensor`):
33
+ Input query states to be passed to Flash Attention API
34
+ key_states (`torch.Tensor`):
35
+ Input key states to be passed to Flash Attention API
36
+ value_states (`torch.Tensor`):
37
+ Input value states to be passed to Flash Attention API
38
+ attention_mask (`torch.Tensor`):
39
+ Cumulative sequence lengths (cu_seqlens) of the packed samples, shape `(1, num_samples + 1)`,
40
+ rather than a 0/1 padding mask; it is squeezed and handed to `flash_attn_varlen_func`.
41
+ dropout (`float`, *optional*):
42
+ Attention dropout
43
+ softmax_scale (`float`, *optional*):
44
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
45
+ use_sliding_windows (`bool`, *optional*):
46
+ Whether to activate sliding window attention.
47
+ """
48
+ assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
49
+ query_states = query_states.squeeze(0)
50
+ key_states = key_states.squeeze(0)
51
+ value_states = value_states.squeeze(0)
52
+ cu_seqlens = attention_mask.squeeze(0)
53
+
54
+ with torch.no_grad():
55
+ max_seqlen = max([
56
+ cu_seqlens[idx+1] - cu_seqlens[idx]
57
+ for idx in range(cu_seqlens.size(0) - 1)
58
+ ]).item()
59
+
60
+ if not self._flash_attn_uses_top_left_mask:
61
+ causal = self.is_causal
62
+ else:
63
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
64
+ causal = self.is_causal and query_length != 1
65
+
66
+ # Decide whether to use SWA or not by layer index.
67
+ if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
68
+ use_sliding_windows = False
69
+
70
+ if not use_sliding_windows:
71
+ attn_output = flash_attn_varlen_func(
72
+ q=query_states,
73
+ k=key_states,
74
+ v=value_states,
75
+ cu_seqlens_q=cu_seqlens,
76
+ cu_seqlens_k=cu_seqlens,
77
+ max_seqlen_q=max_seqlen,
78
+ max_seqlen_k=max_seqlen,
79
+ dropout_p=dropout,
80
+ softmax_scale=softmax_scale,
81
+ causal=causal,
82
+ )
83
+ else:
84
+ attn_output = flash_attn_varlen_func(
85
+ q=query_states,
86
+ k=key_states,
87
+ v=value_states,
88
+ cu_seqlens_q=cu_seqlens,
89
+ cu_seqlens_k=cu_seqlens,
90
+ max_seqlen_q=max_seqlen,
91
+ max_seqlen_k=max_seqlen,
92
+ dropout_p=dropout,
93
+ softmax_scale=softmax_scale,
94
+ causal=causal,
95
+ window_size=(self.config.sliding_window, self.config.sliding_window),
96
+ )
97
+
98
+ query_states = query_states.unsqueeze(0)
99
+ key_states = key_states.unsqueeze(0)
100
+ value_states = value_states.unsqueeze(0)
101
+ return attn_output
102
+
103
+
104
+ def replace_qwen2_attention_class():
105
+ QWEN2_ATTENTION_CLASSES['flash_attention_2'] = Qwen2FlashAttention2ForPackedTraining
106
+ print('Replace QWEN2_ATTENTION_CLASSES to support packed training!!')
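
The registration only changes which class `transformers` instantiates for the `'flash_attention_2'` implementation, so it must happen before the model is loaded. A hedged sketch with an illustrative checkpoint name:

```python
from transformers import AutoModelForCausalLM

from internvl.patch.qwen2_packed_training_patch import replace_qwen2_attention_class

replace_qwen2_attention_class()

model = AutoModelForCausalLM.from_pretrained(
    'Qwen/Qwen2-7B-Instruct',                # illustrative checkpoint
    attn_implementation='flash_attention_2',
    torch_dtype='auto',
)
# Every attention layer is now a Qwen2FlashAttention2ForPackedTraining, so the data
# pipeline must feed cu_seqlens through `attention_mask` as described for the Llama patch.
```
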
src/third_party/InternVL/internvl_chat/internvl/patch/train_dataloader_patch.py ADDED
@@ -0,0 +1,53 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import datasets
8
+ import torch
9
+ import transformers
10
+ from torch.utils.data import DataLoader
11
+ from transformers.trainer import is_datasets_available, seed_worker
12
+
13
+
14
+ def get_train_dataloader(self) -> DataLoader:
15
+ """
16
+ Returns the training [`~torch.utils.data.DataLoader`].
17
+
18
+ Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
19
+ training if necessary) otherwise.
20
+
21
+ Subclass and override this method if you want to inject some custom behavior.
22
+ """
23
+ if self.train_dataset is None:
24
+ raise ValueError('Trainer: training requires a train_dataset.')
25
+
26
+ train_dataset = self.train_dataset
27
+ data_collator = self.data_collator
28
+ if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
29
+ train_dataset = self._remove_unused_columns(train_dataset, description='training')
30
+ else:
31
+ data_collator = self._get_collator_with_removed_columns(data_collator, description='training')
32
+
33
+ dataloader_params = {
34
+ 'batch_size': self._train_batch_size,
35
+ 'collate_fn': data_collator,
36
+ 'num_workers': self.args.dataloader_num_workers,
37
+ 'pin_memory': self.args.dataloader_pin_memory,
38
+ 'persistent_workers': self.args.dataloader_persistent_workers,
39
+ }
40
+
41
+ if not isinstance(train_dataset, torch.utils.data.IterableDataset):
42
+ dataloader_params['sampler'] = self._get_train_sampler()
43
+ dataloader_params['drop_last'] = self.args.dataloader_drop_last
44
+ dataloader_params['worker_init_fn'] = seed_worker
45
+
46
+ if self.args.use_packed_ds:
47
+ return DataLoader(train_dataset, **dataloader_params)
48
+ return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
49
+
50
+
51
+ def replace_train_dataloader():
52
+ transformers.Trainer.get_train_dataloader = get_train_dataloader
53
+ # print('Replace train dataloader!!')
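
The override keeps the stock `Trainer` behaviour except for one branch: when `args.use_packed_ds` is set, the raw `DataLoader` is returned without `accelerator.prepare`, leaving sharding to the packed dataset itself. A minimal sketch of the intended patch order, which is an assumption rather than an excerpt from the training entrypoints:

```python
from internvl.patch.train_dataloader_patch import replace_train_dataloader

# Apply before the Trainer is constructed; `use_packed_ds` is assumed to be an extra
# field carried by the training arguments.
replace_train_dataloader()
```
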
src/third_party/InternVL/internvl_chat/internvl/patch/train_sampler_patch.py ADDED
@@ -0,0 +1,125 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ import transformers
11
+ from torch.utils.data import Dataset, Sampler
12
+ from transformers.tokenization_utils_base import BatchEncoding
13
+ from transformers.trainer import (LengthGroupedSampler, RandomSampler,
14
+ has_length)
15
+ from transformers.trainer_pt_utils import logger
16
+
17
+
18
+ # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L38
19
+ def split_to_even_chunks(indices, lengths, num_chunks):
20
+ """
21
+ Split a list of indices into `chunks` chunks of roughly equal lengths.
22
+ """
23
+
24
+ if len(indices) % num_chunks != 0:
25
+ return [indices[i::num_chunks] for i in range(num_chunks)]
26
+
27
+ num_indices_per_chunk = len(indices) // num_chunks
28
+
29
+ chunks = [[] for _ in range(num_chunks)]
30
+ chunks_lengths = [0 for _ in range(num_chunks)]
31
+ for index in indices:
32
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
33
+ chunks[shortest_chunk].append(index)
34
+ chunks_lengths[shortest_chunk] += lengths[index]
35
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
36
+ chunks_lengths[shortest_chunk] = float('inf')
37
+
38
+ return chunks
39
+
40
+
41
+ # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L88
42
+ def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
43
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
44
+ indices = torch.randperm(len(lengths), generator=generator)
45
+ megabatch_size = world_size * batch_size
46
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
47
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
48
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
49
+
50
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
51
+
52
+
53
+ # modified from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L99
54
+ class LengthGroupedSampler(Sampler):
55
+ r"""
56
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
57
+ keeping a bit of randomness.
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ batch_size: int,
63
+ world_size: int,
64
+ dataset: Optional[Dataset] = None,
65
+ lengths: Optional[List[int]] = None,
66
+ model_input_name: Optional[str] = None,
67
+ generator=None,
68
+ ):
69
+ if dataset is None and lengths is None:
70
+ raise ValueError('One of dataset and lengths must be provided.')
71
+
72
+ self.batch_size = batch_size
73
+ if lengths is None:
74
+ model_input_name = model_input_name if model_input_name is not None else 'input_ids'
75
+ if (
76
+ not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
77
+ or model_input_name not in dataset[0]
78
+ ):
79
+ raise ValueError(
80
+ 'Can only automatically infer lengths for datasets whose items are dictionaries with an '
81
+ f"'{model_input_name}' key."
82
+ )
83
+ lengths = [len(feature[model_input_name]) for feature in dataset]
84
+ elif isinstance(lengths, torch.Tensor):
85
+ logger.info(
86
+ 'If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]...'
87
+ )
88
+ lengths = lengths.tolist()
89
+ self.world_size = world_size
90
+ self.lengths = lengths
91
+ self.generator = generator
92
+
93
+ def __len__(self):
94
+ return len(self.lengths)
95
+
96
+ def __iter__(self):
97
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
98
+ return iter(indices)
99
+
100
+
101
+ # patch trainer
102
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
103
+ if self.train_dataset is None or not has_length(self.train_dataset):
104
+ return None
105
+ # Build the sampler.
106
+ if self.args.group_by_length:
107
+ lengths = []
108
+ for dataset in self.train_dataset.datasets:
109
+ lengths = lengths + dataset.length
110
+ model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
111
+ return LengthGroupedSampler(
112
+ self.args.train_batch_size,
113
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps,
114
+ # self.args.train_batch_size * self.args.gradient_accumulation_steps,
115
+ dataset=self.train_dataset,
116
+ lengths=lengths,
117
+ model_input_name=model_input_name,
118
+ )
119
+ else:
120
+ return RandomSampler(self.train_dataset)
121
+
122
+
123
+ def replace_train_sampler():
124
+ transformers.Trainer._get_train_sampler = _get_train_sampler
125
+ # print('Replace train sampler!!')
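
A toy run of the grouping logic: shuffle indices, cut them into megabatches of `world_size * batch_size`, sort each megabatch by length, and balance it across ranks with `split_to_even_chunks`. All numbers below are made up:

```python
import torch

from internvl.patch.train_sampler_patch import get_length_grouped_indices

lengths = [10, 2, 8, 4, 6, 12, 3, 9]      # token lengths of eight samples
indices = get_length_grouped_indices(
    lengths, batch_size=2, world_size=2,
    generator=torch.Generator().manual_seed(0),
)
# Each megabatch of 4 samples is sorted by length and split into 2 chunks of roughly
# equal total length, so per-rank batches mix samples of similar size and the amount
# of padding added by the collator stays small.
```
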
src/third_party/InternVL/internvl_chat/internvl/train/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
src/third_party/InternVL/internvl_chat/internvl/train/constants.py ADDED
@@ -0,0 +1,21 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
8
+ IMG_START_TOKEN = '<img>'
9
+ IMG_END_TOKEN = '</img>'
10
+ QUAD_START_TOKEN = '<quad>'
11
+ QUAD_END_TOKEN = '</quad>'
12
+ REF_START_TOKEN = '<ref>'
13
+ REF_END_TOKEN = '</ref>'
14
+ BOX_START_TOKEN = '<box>'
15
+ BOX_END_TOKEN = '</box>'
16
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
17
+ IMAGENET_STD = (0.229, 0.224, 0.225)
18
+ CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073)
19
+ CLIP_STD = (0.2686295, 0.2613025, 0.2757711)
20
+ SIGLIP_MEAN = (0.5, 0.5, 0.5)
21
+ SIGLIP_STD = (0.5, 0.5, 0.5)
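
Downstream, the preprocess functions in `dataset.py` expand every `<image>` placeholder into a start token, a run of context tokens (one per visual token of the tile), and an end token. A small illustration, where 256 context tokens is only an example value:

```python
from internvl.train.constants import (IMG_CONTEXT_TOKEN, IMG_END_TOKEN,
                                      IMG_START_TOKEN)

num_image_token = 256  # example value; the real count depends on the ViT / tiling setup
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token}{IMG_END_TOKEN}'
prompt = 'Describe the image.\n<image>'.replace('<image>', image_tokens, 1)
# -> 'Describe the image.\n<img><IMG_CONTEXT>...<IMG_CONTEXT></img>'
```
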
src/third_party/InternVL/internvl_chat/internvl/train/dataset.py ADDED
@@ -0,0 +1,866 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import io
8
+
9
+ from transformers.trainer_pt_utils import LabelSmoother
10
+
11
+ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
12
+ import os
13
+ import random
14
+ import re
15
+ from collections import Counter
16
+ from typing import Dict
17
+
18
+ import cv2
19
+ import imageio
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torchvision.transforms as T
24
+ import transformers
25
+ from decord import VideoReader
26
+ from internvl.conversation import get_conv_template
27
+ from PIL import Image
28
+ from torch.utils.data import ConcatDataset, WeightedRandomSampler
29
+ from torchvision.transforms.functional import InterpolationMode
30
+
31
+ from .constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD,
32
+ IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN,
33
+ SIGLIP_MEAN, SIGLIP_STD)
34
+
35
+ try:
36
+ from petrel_client.client import Client
37
+ from petrel_client.common.config import Config
38
+ except ImportError as E:
39
+ print('petrel_client is not installed. If you read data locally instead of from ceph, ignore it.')
40
+ import sys
41
+
42
+
43
+ def calculate_ngram_repetition(text, n):
44
+ words = text.split()
45
+ ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)]
46
+ ngram_counts = Counter(ngrams)
47
+ total_ngrams = len(ngrams)
48
+ repeated_ngrams = sum(1 for count in ngram_counts.values() if count > 1)
49
+ return repeated_ngrams / total_ngrams if total_ngrams > 0 else 0
50
+
51
+
52
+ def check_conversations_repetition(conversations, repeat_threshold=0.4, ngram=10):
53
+ for conversation in conversations:
54
+ if conversation['from'] == 'gpt':
55
+ model_answer = conversation['value']
56
+ repeat_ratio = calculate_ngram_repetition(model_answer, ngram)
57
+ if repeat_ratio > repeat_threshold:
58
+ raise Exception
59
+
60
+
61
+ def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
62
+ if sample in ['rand', 'middle']: # uniform sampling
63
+ acc_samples = min(num_frames, vlen)
64
+ # split the video into `acc_samples` intervals, and sample from each interval.
65
+ intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
66
+ ranges = []
67
+ for idx, interv in enumerate(intervals[:-1]):
68
+ ranges.append((interv, intervals[idx + 1] - 1))
69
+ if sample == 'rand':
70
+ try:
71
+ frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
72
+ except Exception:
73
+ frame_indices = np.random.permutation(vlen)[:acc_samples]
74
+ frame_indices.sort()
75
+ frame_indices = list(frame_indices)
76
+ elif fix_start is not None:
77
+ frame_indices = [x[0] + fix_start for x in ranges]
78
+ elif sample == 'middle':
79
+ frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
80
+ else:
81
+ raise NotImplementedError
82
+
83
+ if len(frame_indices) < num_frames: # padded with last frame
84
+ padded_frame_indices = [frame_indices[-1]] * num_frames
85
+ padded_frame_indices[:len(frame_indices)] = frame_indices
86
+ frame_indices = padded_frame_indices
87
+ elif 'fps' in sample: # fps0.5, sequentially sample frames at 0.5 fps
88
+ output_fps = float(sample[3:])
89
+ duration = float(vlen) / input_fps
90
+ delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
91
+ frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
92
+ frame_indices = np.around(frame_seconds * input_fps).astype(int)
93
+ frame_indices = [e for e in frame_indices if e < vlen]
94
+ if max_num_frames > 0 and len(frame_indices) > max_num_frames:
95
+ frame_indices = frame_indices[:max_num_frames]
96
+ # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
97
+ else:
98
+ raise ValueError
99
+ return frame_indices
100
+
101
+
102
+ def read_frames_gif(
103
+ video_path, num_frames, sample='rand', fix_start=None,
104
+ client=None, min_num_frames=4
105
+ ):
106
+ if 's3://' in video_path:
107
+ video_bytes = client.get(video_path)
108
+ gif = imageio.get_reader(io.BytesIO(video_bytes))
109
+ else:
110
+ gif = imageio.get_reader(video_path)
111
+ vlen = len(gif)
112
+
113
+ t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
114
+ frame_indices = get_frame_indices(
115
+ t_num_frames, vlen, sample=sample, fix_start=fix_start
116
+ )
117
+ frames = []
118
+ for index, frame in enumerate(gif):
119
+ if index in frame_indices:
120
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB).astype(np.uint8)
121
+ frame = Image.fromarray(frame)
122
+ frames.append(frame)
123
+ return frames
124
+
125
+
126
+ def read_frames_decord(
127
+ video_path, num_frames, sample='rand', fix_start=None,
128
+ client=None, clip=None, min_num_frames=4
129
+ ):
130
+ if 's3://' in video_path:
131
+ video_bytes = client.get(video_path)
132
+ video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
133
+ else:
134
+ video_reader = VideoReader(video_path, num_threads=1)
135
+ vlen = len(video_reader)
136
+ fps = video_reader.get_avg_fps()
137
+ duration = vlen / float(fps)
138
+ if clip:
139
+ start, end = clip
140
+ duration = end - start
141
+ vlen = int(duration * fps)
142
+ start_index = int(start * fps)
143
+
144
+ # t_num_frames = min(max(int(duration * sample_fps), min_num_frames), num_frames)
145
+ t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
146
+
147
+ frame_indices = get_frame_indices(
148
+ t_num_frames, vlen, sample=sample, fix_start=fix_start,
149
+ input_fps=fps
150
+ )
151
+ if clip:
152
+ frame_indices = [f + start_index for f in frame_indices]
153
+ frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8
154
+ frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
155
+ return frames
156
+
157
+
158
+ def extract_frame_number(filename):
159
+ # Extract the numeric part from the filename using regular expressions
160
+ match = re.search(r'_(\d+).jpg$', filename)
161
+ return int(match.group(1)) if match else -1
162
+
163
+
164
+ def sort_frames(frame_paths):
165
+ # Extract filenames from each path and sort by their numeric part
166
+ return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))
167
+
168
+
169
+ def read_frames_folder(
170
+ video_path, num_frames, sample='rand', fix_start=None,
171
+ client=None, clip=None, min_num_frames=4
172
+ ):
173
+ if 's3://' in video_path:
174
+ image_list = sort_frames(client.list(video_path))
175
+ frames = []
176
+ for image in image_list:
177
+ fp = os.path.join(video_path, image)
178
+ frame = Image.open(io.BytesIO(client.get(fp)))
179
+ frames.append(frame)
180
+ else:
181
+ image_list = sort_frames(list(os.listdir(video_path)))
182
+ frames = []
183
+ for image in image_list:
184
+ fp = os.path.join(video_path, image)
185
+ frame = Image.open(fp).convert('RGB')
186
+ frames.append(frame)
187
+ vlen = len(frames)
188
+
189
+ t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
190
+
191
+ if vlen > t_num_frames:
192
+ frame_indices = get_frame_indices(
193
+ t_num_frames, vlen, sample=sample, fix_start=fix_start
194
+ )
195
+ frames = [frames[i] for i in frame_indices]
196
+ return frames
197
+
198
+
199
+ class WeightedConcatDataset(ConcatDataset):
200
+ def __init__(self, datasets, weights):
201
+ super().__init__(datasets)
202
+ self.weights = torch.DoubleTensor(weights)
203
+ self.total_size = sum(len(d) for d in datasets)
204
+ self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True)
205
+
206
+ def __iter__(self):
207
+ return iter(self.sampler)
208
+
209
+ def __len__(self):
210
+ return self.total_size
211
+
212
+
213
+ def pil_loader(img_str):
214
+ buff = io.BytesIO(img_str)
215
+ img = Image.open(buff)
216
+ return img.convert('RGB')
217
+
218
+
219
+ class TCSLoader(object):
220
+
221
+ def __init__(self, conf_path, sc_config_key='sensecore'):
222
+ print(f'[TCSLoader] config_path: {conf_path}')
223
+ print('--> before Client(conf_path)')
224
+ self.client = Client(conf_path)
225
+ self.sc_config_key = sc_config_key
226
+ print('--> after Client(conf_path)')
227
+
228
+ def __call__(self, fn, image_type='image', max_num_frames=-1, min_num_frames=8, sample='rand', clip=None):
229
+ if image_type == 'image':
230
+ img_value_str = self.client.get(fn)
231
+ img = pil_loader(img_value_str)
232
+ return img
233
+
234
+ elif image_type == 'video':
235
+ if fn.endswith('/'):
236
+ frames = read_frames_folder(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
237
+ client=self.client, sample=sample)
238
+ elif fn.endswith('.gif'):
239
+ frames = read_frames_gif(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
240
+ client=self.client, sample=sample)
241
+ else:
242
+ frames = read_frames_decord(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
243
+ client=self.client, sample=sample, clip=clip)
244
+ return frames
245
+
246
+
247
+ def expand2square(pil_img, background_color):
248
+ width, height = pil_img.size
249
+ if width == height:
250
+ return pil_img
251
+ elif width > height:
252
+ result = Image.new(pil_img.mode, (width, width), background_color)
253
+ result.paste(pil_img, (0, (width - height) // 2))
254
+ return result
255
+ else:
256
+ result = Image.new(pil_img.mode, (height, height), background_color)
257
+ result.paste(pil_img, ((height - width) // 2, 0))
258
+ return result
259
+
260
+
261
+ def simulate_jpeg_degradation(quality):
262
+ def jpeg_degrade(img):
263
+ with io.BytesIO() as output:
264
+ img.convert('RGB').save(output, format='JPEG', quality=quality)
265
+ output.seek(0) # Move the reading cursor to the start of the stream
266
+ img_jpeg = Image.open(output).copy() # Use .copy() to make sure the image is loaded in memory
267
+ return img_jpeg
268
+ return jpeg_degrade
269
+
270
+
271
+ # Define the JPEG compression quality range, pre-create all JPEG compression functions
272
+ qualities = list(range(75, 101))
273
+ jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities}
274
+
275
+
276
+ def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'):
277
+ if normalize_type == 'imagenet':
278
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
279
+ elif normalize_type == 'clip':
280
+ MEAN, STD = CLIP_MEAN, CLIP_STD
281
+ elif normalize_type == 'siglip':
282
+ MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
283
+ else:
284
+ raise NotImplementedError
285
+ if is_train: # use data augumentation
286
+ transform = T.Compose([
287
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
288
+ T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]),
289
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
290
+ T.ToTensor(),
291
+ T.Normalize(mean=MEAN, std=STD)
292
+ ])
293
+ else:
294
+ if pad2square is False: # now we use this transform function by default
295
+ transform = T.Compose([
296
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
297
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
298
+ T.ToTensor(),
299
+ T.Normalize(mean=MEAN, std=STD)
300
+ ])
301
+ else:
302
+ transform = T.Compose([
303
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
304
+ T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))),
305
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
306
+ T.ToTensor(),
307
+ T.Normalize(mean=MEAN, std=STD)
308
+ ])
309
+
310
+ return transform
311
+
312
+
313
+ def preprocess(
314
+ template_name,
315
+ sources,
316
+ tokenizer: transformers.PreTrainedTokenizer,
317
+ num_image_token_list: list,
318
+ text_only: bool = False,
319
+ group_by_length: bool = False,
320
+ use_packed_ds: bool = False,
321
+ ds_name: str = None,
322
+ num_image: int = 1
323
+ ) -> Dict:
324
+ conv = get_conv_template(template_name)
325
+ roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
326
+
327
+ # Apply prompt templates
328
+ conversations = []
329
+ for i, source in enumerate(sources):
330
+ if roles[source[0]['from']] != conv.roles[0]:
331
+ # Skip the first one if it is not from human
332
+ source = source[1:]
333
+
334
+ conv.messages = []
335
+ for j, sentence in enumerate(source):
336
+ role = roles[sentence['from']]
337
+ assert role == conv.roles[j % 2], f'{i}'
338
+ conv.append_message(role, sentence['value'])
339
+ conversations.append(conv.get_prompt())
340
+
341
+ if not text_only:
342
+ new_conversations = []
343
+ for conversation in conversations:
344
+ for i in range(num_image):
345
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
346
+ conversation = conversation.replace('<image>', image_tokens, 1)
347
+ new_conversations.append(conversation)
348
+ conversations = new_conversations
349
+
350
+ # Tokenize conversations
351
+ input_ids = tokenizer(
352
+ conversations,
353
+ return_tensors='pt',
354
+ padding=False if group_by_length or use_packed_ds else 'max_length',
355
+ max_length=tokenizer.model_max_length,
356
+ truncation=True,
357
+ ).input_ids
358
+ targets = input_ids.clone()
359
+
360
+ # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
361
+
362
+ # Mask targets. Only compute loss on the assistant outputs.
363
+ sep = conv.sep + conv.roles[1] + ': '
364
+ for conversation, target in zip(conversations, targets):
365
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
366
+
367
+ turns = conversation.split(conv.sep2)
368
+ cur_len = 1
369
+ target[:cur_len] = IGNORE_TOKEN_ID
370
+ for i, turn in enumerate(turns):
371
+ if turn == '':
372
+ break
373
+ turn_len = len(tokenizer(turn).input_ids)
374
+
375
+ parts = turn.split(sep)
376
+ if len(parts) != 2:
377
+ break
378
+ parts[0] += sep
379
+ # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
380
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
381
+
382
+ if i != 0 and not tokenizer.legacy:
383
+ # The legacy and non-legacy modes handle special tokens differently
384
+ instruction_len -= 1
385
+
386
+ # Ignore the user instructions
387
+ target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
388
+ cur_len += turn_len
389
+
390
+ if i != 0 and not tokenizer.legacy:
391
+ # The legacy and non-legacy modes handle special tokens differently
392
+ cur_len -= 1
393
+
394
+ target[cur_len:] = IGNORE_TOKEN_ID
395
+
396
+ if False: # Inspect and check the correctness of masking
397
+ z = target.clone()
398
+ z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
399
+ logger.info(tokenizer.decode(z))
400
+ exit()
401
+
402
+ if cur_len < tokenizer.model_max_length:
403
+ if cur_len != total_len:
404
+ target[:] = IGNORE_TOKEN_ID
405
+ print(
406
+ f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
407
+ f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
408
+ )
409
+ sys.stdout.flush()
410
+
411
+ return dict(
412
+ input_ids=input_ids,
413
+ labels=targets,
414
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
415
+ )
416
+
417
+
418
+ def preprocess_mpt(
419
+ template_name,
420
+ sources,
421
+ tokenizer: transformers.PreTrainedTokenizer,
422
+ num_image_token_list: list,
423
+ text_only: bool = False,
424
+ group_by_length: bool = False,
425
+ use_packed_ds: bool = False,
426
+ ds_name: str = None,
427
+ num_image: int = 1
428
+ ) -> Dict:
429
+ conv = get_conv_template(template_name)
430
+ roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
431
+
432
+ # Apply prompt templates
433
+ conversations = []
434
+ for i, source in enumerate(sources):
435
+ if roles[source[0]['from']] != conv.roles[0]:
436
+ # Skip the first one if it is not from human
437
+ source = source[1:]
438
+
439
+ conv.messages = []
440
+ for j, sentence in enumerate(source):
441
+ role = roles[sentence['from']]
442
+ assert role == conv.roles[j % 2], f'{i}'
443
+ conv.append_message(role, sentence['value'])
444
+ conversations.append(conv.get_prompt())
445
+
446
+ if not text_only:
447
+ new_conversations = []
448
+ for conversation in conversations:
449
+ for i in range(num_image):
450
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
451
+ conversation = conversation.replace('<image>', image_tokens, 1)
452
+ new_conversations.append(conversation)
453
+ conversations = new_conversations
454
+
455
+ # Tokenize conversations
456
+ input_ids = tokenizer(
457
+ conversations,
458
+ return_tensors='pt',
459
+ padding=False if group_by_length or use_packed_ds else 'max_length',
460
+ max_length=tokenizer.model_max_length,
461
+ truncation=True,
462
+ ).input_ids
463
+ targets = input_ids.clone()
464
+
465
+ # Mask targets. Only compute loss on the assistant outputs.
466
+ sep = conv.sep + conv.roles[1] # <|im_end|><|im_start|>assistant\n
467
+ for conversation, target in zip(conversations, targets):
468
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
469
+
470
+ turns = conversation.split(conv.sep)
471
+ re_turns = [conv.sep.join(turns[:3])] # system + user + gpt
472
+ for conv_idx in range(3, len(turns), 2):
473
+ re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt
474
+ cur_len = 0
475
+ target[:cur_len] = IGNORE_TOKEN_ID
476
+ for i, turn in enumerate(re_turns):
477
+ if turn == '':
478
+ break
479
+ turn_len = len(tokenizer(turn).input_ids) + 1
480
+
481
+ parts = turn.split(sep)
482
+ if len(parts) != 2:
483
+ break
484
+ parts[0] += sep
485
+ instruction_len = len(tokenizer(parts[0]).input_ids)
486
+
487
+ # Ignore the user instructions
488
+ target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
489
+ # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
490
+ # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
491
+ # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
492
+ cur_len += turn_len
493
+
494
+ target[cur_len:] = IGNORE_TOKEN_ID
495
+
496
+ if cur_len < tokenizer.model_max_length:
497
+ if cur_len != total_len:
498
+ target[:] = IGNORE_TOKEN_ID
499
+ print(
500
+ f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
501
+ f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
502
+ )
503
+ sys.stdout.flush()
504
+
505
+ return dict(
506
+ input_ids=input_ids,
507
+ labels=targets,
508
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
509
+ )
510
+
511
+
512
+ def preprocess_phi3(
513
+ template_name,
514
+ sources,
515
+ tokenizer: transformers.PreTrainedTokenizer,
516
+ num_image_token_list: list,
517
+ text_only: bool = False,
518
+ group_by_length: bool = False,
519
+ use_packed_ds: bool = False,
520
+ ds_name: str = None,
521
+ num_image: int = 1
522
+ ) -> Dict:
523
+ conv = get_conv_template(template_name)
524
+ roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
525
+
526
+ # Apply prompt templates
527
+ conversations = []
528
+ for i, source in enumerate(sources):
529
+ if roles[source[0]['from']] != conv.roles[0]:
530
+ # Skip the first one if it is not from human
531
+ source = source[1:]
532
+
533
+ conv.messages = []
534
+ for j, sentence in enumerate(source):
535
+ role = roles[sentence['from']]
536
+ assert role == conv.roles[j % 2], f'{i}'
537
+ conv.append_message(role, sentence['value'])
538
+ conversations.append(conv.get_prompt())
539
+
540
+ if not text_only:
541
+ new_conversations = []
542
+ for conversation in conversations:
543
+ for i in range(num_image):
544
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
545
+ conversation = conversation.replace('<image>', image_tokens, 1)
546
+ new_conversations.append(conversation)
547
+ conversations = new_conversations
548
+
549
+ # Tokenize conversations
550
+ tokenizer.padding_side = 'right'
551
+ input_ids = tokenizer(
552
+ conversations,
553
+ return_tensors='pt',
554
+ padding=False if group_by_length or use_packed_ds else 'max_length',
555
+ max_length=tokenizer.model_max_length,
556
+ truncation=True,
557
+ ).input_ids
558
+ targets = input_ids.clone()
559
+
560
+ # Mask targets. Only compute loss on the assistant outputs.
561
+ sep = conv.sep + conv.roles[1] # <|end|>\n<|assistant|>
562
+ for conversation, target in zip(conversations, targets):
563
+ total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())
564
+
565
+ turns = conversation.split(conv.sep)
566
+ re_turns = [conv.sep.join(turns[:3])] # system + user + gpt
567
+ for conv_idx in range(3, len(turns), 2):
568
+ re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt
569
+ cur_len = 1
570
+ target[:cur_len] = IGNORE_TOKEN_ID
571
+ endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
572
+ target[target == endoftext_id] = IGNORE_TOKEN_ID
573
+
574
+ for i, turn in enumerate(re_turns):
575
+ if turn == '':
576
+ break
577
+ if i == 0:
578
+ turn_len = len(tokenizer(turn).input_ids)
579
+ else:
580
+ turn_len = len(tokenizer(turn).input_ids) - 1
581
+ parts = turn.split(sep)
582
+ if len(parts) != 2:
583
+ break
584
+ parts[0] += sep
585
+
586
+ if i == 0:
587
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
588
+ else:
589
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
590
+
591
+ # Ignore the user instructions
592
+ target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
593
+ # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
594
+ # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
595
+ # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
596
+ cur_len += turn_len
597
+
598
+ target[cur_len:] = IGNORE_TOKEN_ID
599
+
600
+ if False: # Inspect and check the correctness of masking
601
+ z = target.clone()
602
+ z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
603
+ print(repr(tokenizer.decode(z)))
604
+
605
+ if cur_len < tokenizer.model_max_length:
606
+ if cur_len != total_len:
607
+ target[:] = IGNORE_TOKEN_ID
608
+ print(
609
+ f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
610
+ f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
611
+ )
612
+ sys.stdout.flush()
613
+
614
+ return dict(
615
+ input_ids=input_ids,
616
+ labels=targets,
617
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
618
+ )
619
+
620
+
621
+ def preprocess_internlm(
622
+ template_name,
623
+ sources,
624
+ tokenizer: transformers.PreTrainedTokenizer,
625
+ num_image_token_list: list,
626
+ text_only: bool = False,
627
+ group_by_length: bool = False,
628
+ use_packed_ds: bool = False,
629
+ ds_name: str = None,
630
+ num_image: int = 1
631
+ ) -> Dict:
632
+ conv = get_conv_template(template_name)
633
+ roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
634
+
635
+ # Apply prompt templates
636
+ conversations = []
637
+ for i, source in enumerate(sources):
638
+ if roles[source[0]['from']] != conv.roles[0]:
639
+ # Skip the first one if it is not from human
640
+ source = source[1:]
641
+
642
+ conv.messages = []
643
+ for j, sentence in enumerate(source):
644
+ role = roles[sentence['from']]
645
+ assert role == conv.roles[j % 2], f'{i}'
646
+ sentence['value'] = sentence['value'].strip()
647
+ conv.append_message(role, sentence['value'])
648
+ conversations.append(conv.get_prompt())
649
+
650
+ if not text_only:
651
+ new_conversations = []
652
+ for conversation in conversations:
653
+ for i in range(num_image):
654
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
655
+ conversation = conversation.replace('<image>', image_tokens, 1)
656
+ new_conversations.append(conversation)
657
+ conversations = new_conversations
658
+
659
+ # Tokenize conversations
660
+ input_ids = tokenizer(
661
+ conversations,
662
+ return_tensors='pt',
663
+ padding=False if group_by_length or use_packed_ds else 'max_length',
664
+ max_length=tokenizer.model_max_length,
665
+ truncation=True,
666
+ ).input_ids
667
+ targets = input_ids.clone()
668
+
669
+ for conversation, target in zip(conversations, targets):
670
+ total_len = int(target.ne(tokenizer.pad_token_id).sum()) # for InternLM, pad_token_id == eos_token_id
671
+ cur_len = 1
672
+ target[:cur_len] = IGNORE_TOKEN_ID # <s>
673
+ parts = conversation.split(conv.roles[1]) # [UNUSED_TOKEN_146]assistant\n
674
+ info = parts[0] + conv.roles[1]
675
+ temp_len = len(tokenizer(info).input_ids) - 1 # drop the tokenizer's leading <s>
676
+ target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
677
+ cur_len = cur_len + temp_len
678
+
679
+ for index in range(1, len(parts) - 1):
680
+ info = parts[index]
681
+ part1, part2 = info.split(conv.roles[0])
682
+ temp_len = len(tokenizer(part1).input_ids) - 1
683
+ cur_len = cur_len + temp_len
684
+ part = conv.roles[0] + part2 + conv.roles[1]
685
+ temp_len = len(tokenizer(part).input_ids) - 1
686
+ target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
687
+ cur_len = cur_len + temp_len
688
+ last_info = parts[-1]
689
+ temp_len = len(tokenizer(last_info).input_ids) - 1
690
+ cur_len = cur_len + temp_len
691
+
692
+ target[cur_len:] = IGNORE_TOKEN_ID
693
+ if False: # Inspect and check the correctness of masking
694
+ z = target.clone()
695
+ z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
696
+ print(repr(tokenizer.decode(z)))
697
+
698
+ if cur_len < tokenizer.model_max_length:
699
+ if cur_len != total_len:
700
+ target[:] = IGNORE_TOKEN_ID
701
+ print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
702
+ sys.stdout.flush()
703
+
704
+ return dict(
705
+ input_ids=input_ids,
706
+ labels=targets,
707
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
708
+ )
709
+
710
+
711
+ def preprocess_internvl2_5(
712
+ template_name,
713
+ sources,
714
+ tokenizer: transformers.PreTrainedTokenizer,
715
+ num_image_token_list: list,
716
+ text_only: bool = False,
717
+ group_by_length: bool = False,
718
+ use_packed_ds: bool = False,
719
+ ds_name: str = None,
720
+ num_image: int = 1
721
+ ) -> Dict:
722
+ assert len(sources) == 1, 'process only the first conversations'
723
+ conversations = sources[0]
724
+
725
+ if conversations[0]['from'] == 'system':
726
+ system_prompt = conversations[0]['value']
727
+ conversations = conversations[1:] # remove system prompt
728
+ else:
729
+ conv = get_conv_template(template_name)
730
+ system_prompt = conv.system_message
731
+ # system_prompt = None
732
+
733
+ if not text_only:
734
+ new_conversations = []
735
+ current_image_idx = 0
736
+ for conversation in conversations:
737
+ if conversation['from'] == 'human':
738
+ image_cnt = conversation['value'].count('<image>')
739
+ for i in range(image_cnt):
740
+ if current_image_idx == num_image:
741
+ break
742
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[current_image_idx]}{IMG_END_TOKEN}'
743
+ conversation['value'] = conversation['value'].replace('<image>', image_tokens, 1)
744
+ current_image_idx += 1
745
+ new_conversations.append(conversation)
746
+ conversations = new_conversations
747
+ assert current_image_idx == num_image, f'{current_image_idx} != {num_image}'
748
+
749
+ batches, roles = [], []
750
+ if system_prompt is not None:
751
+ batches.append(f'<|im_start|>system\n{system_prompt}<|im_end|>\n')
752
+ roles.append('system')
753
+ for conversation in conversations:
754
+ if conversation['from'] == 'human':
755
+ batches.append(f'<|im_start|>user\n{conversation["value"]}<|im_end|>\n')
756
+ roles.append('human')
757
+ elif conversation['from'] == 'gpt':
758
+ batches.append(f'<|im_start|>assistant\n{conversation["value"]}<|im_end|>\n')
759
+ roles.append('gpt')
760
+ else:
761
+ raise NotImplementedError
762
+
763
+ add_bos_token = getattr(tokenizer, 'add_bos_token', False)
764
+ if add_bos_token: # for InternLM series
765
+ batches[0] = tokenizer.bos_token + batches[0]
766
+
767
+ # Tokenize conversations
768
+ input_ids = tokenizer(
769
+ batches,
770
+ return_tensors='np',
771
+ padding=False,
772
+ max_length=tokenizer.model_max_length,
773
+ truncation=False,
774
+ ).input_ids
775
+
776
+ if add_bos_token: # for InternLM series
777
+ input_ids = [item[1:] for item in input_ids]
778
+
779
+ final_input_ids, final_targets = [], []
780
+ ignore_ids = tokenizer('<|im_start|>assistant\n', return_tensors='np').input_ids[0]
781
+ ignore_len = ignore_ids.shape[0] - 1 if add_bos_token else ignore_ids.shape[0]
782
+ for role, input_id in zip(roles, input_ids):
783
+ final_input_ids.append(input_id)
784
+ if role == 'system' or role == 'human':
785
+ final_targets.append(np.full(input_id.shape, IGNORE_TOKEN_ID)) # ignore
786
+ elif role == 'gpt':
787
+ target = input_id.copy()
788
+ target[:ignore_len] = IGNORE_TOKEN_ID # ignore loss for `<|im_start|>assistant\n`
789
+ target[-1:] = IGNORE_TOKEN_ID # ignore loss for `\n`
790
+ final_targets.append(target)
791
+ else:
792
+ raise NotImplementedError
793
+ input_ids = torch.tensor(np.concatenate(final_input_ids))[:tokenizer.model_max_length]
794
+ targets = torch.tensor(np.concatenate(final_targets))[:tokenizer.model_max_length]
795
+
796
+ padding = False if group_by_length or use_packed_ds else True
797
+ if padding:
798
+ current_length = input_ids.size(0)
799
+ padding_length = tokenizer.model_max_length - current_length
800
+ input_ids = F.pad(input_ids, (0, padding_length), value=tokenizer.pad_token_id)
801
+ targets = F.pad(targets, (0, padding_length), value=IGNORE_TOKEN_ID)
802
+
803
+ input_ids = input_ids.unsqueeze(0)
804
+ targets = targets.unsqueeze(0)
805
+
806
+ return dict(
807
+ input_ids=input_ids,
808
+ labels=targets,
809
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
810
+ )
811
+
812
+
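Illustrative aside (not part of the diff): `preprocess_internvl2_5` above computes loss only on the assistant reply body, masking the `<|im_start|>assistant\n` prefix and the trailing newline. A minimal NumPy sketch of that masking rule, using made-up token ids and a hypothetical prefix length:

import numpy as np

IGNORE_TOKEN_ID = -100   # LabelSmoother.ignore_index, as in the module above
ignore_len = 3           # hypothetical length of the '<|im_start|>assistant\n' prefix

# one assistant turn: [im_start, 'assistant', '\n', tok, tok, tok, im_end, '\n']
input_id = np.array([11, 12, 13, 500, 501, 502, 14, 13])
target = input_id.copy()
target[:ignore_len] = IGNORE_TOKEN_ID   # no loss on the role prefix
target[-1:] = IGNORE_TOKEN_ID           # no loss on the trailing '\n'
# loss remains on 500, 501, 502 and the <|im_end|> token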
813
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
814
+ best_ratio_diff = float('inf')
815
+ best_ratio = (1, 1)
816
+ area = width * height
817
+ for ratio in target_ratios:
818
+ target_aspect_ratio = ratio[0] / ratio[1]
819
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
820
+ if ratio_diff < best_ratio_diff:
821
+ best_ratio_diff = ratio_diff
822
+ best_ratio = ratio
823
+ elif ratio_diff == best_ratio_diff:
824
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
825
+ best_ratio = ratio
826
+ # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
827
+ return best_ratio
828
+
829
+
830
+ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
831
+ orig_width, orig_height = image.size
832
+ aspect_ratio = orig_width / orig_height
833
+
834
+ # calculate the existing image aspect ratio
835
+ target_ratios = set(
836
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
837
+ i * j <= max_num and i * j >= min_num)
838
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
839
+
840
+ # find the closest aspect ratio to the target
841
+ target_aspect_ratio = find_closest_aspect_ratio(
842
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
843
+
844
+ # calculate the target width and height
845
+ target_width = image_size * target_aspect_ratio[0]
846
+ target_height = image_size * target_aspect_ratio[1]
847
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
848
+
849
+ # resize the image
850
+ resized_img = image.resize((target_width, target_height))
851
+ processed_images = []
852
+ for i in range(blocks):
853
+ box = (
854
+ (i % (target_width // image_size)) * image_size,
855
+ (i // (target_width // image_size)) * image_size,
856
+ ((i % (target_width // image_size)) + 1) * image_size,
857
+ ((i // (target_width // image_size)) + 1) * image_size
858
+ )
859
+ # split the image
860
+ split_img = resized_img.crop(box)
861
+ processed_images.append(split_img)
862
+ assert len(processed_images) == blocks
863
+ if use_thumbnail and len(processed_images) != 1:
864
+ thumbnail_img = image.resize((image_size, image_size))
865
+ processed_images.append(thumbnail_img)
866
+ return processed_images
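Usage sketch (illustrative, not part of the diff), assuming `dynamic_preprocess` above is importable from this module; the input image is synthetic:

from PIL import Image

img = Image.new('RGB', (896, 448), (128, 128, 128))   # synthetic 2:1 image
tiles = dynamic_preprocess(img, min_num=1, max_num=12,
                           image_size=448, use_thumbnail=True)
# a 2:1 image is matched to a 2x1 grid: two 448x448 tiles plus one thumbnail
print(len(tiles), tiles[0].size)   # -> 3 (448, 448)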
src/third_party/InternVL/internvl_chat/internvl/train/dataset_packed.py ADDED
@@ -0,0 +1,634 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import bisect
8
+ import copy
9
+ import logging
10
+ from collections import defaultdict
11
+ from typing import List, Union
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.distributed as dist
16
+ from torch.utils.data import IterableDataset, get_worker_info
17
+ from transformers.trainer_pt_utils import LabelSmoother
18
+
19
+ from .constants import IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN
20
+
21
+ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
22
+ logger = logging.getLogger(__name__)
23
+ logger.setLevel(logging.INFO)
24
+
25
+
26
+ def is_dist_avail_and_initialized():
27
+ if not dist.is_available():
28
+ return False
29
+ if not dist.is_initialized():
30
+ return False
31
+ return True
32
+
33
+
34
+ def get_world_size():
35
+ if not is_dist_avail_and_initialized():
36
+ return 1
37
+ return dist.get_world_size()
38
+
39
+
40
+ def get_rank():
41
+ if not is_dist_avail_and_initialized():
42
+ return 0
43
+ return dist.get_rank()
44
+
45
+
46
+ class PackedDataset(IterableDataset):
47
+ def __init__(
48
+ self,
49
+ tokenizer,
50
+ data_rank,
51
+ data_world_size,
52
+ datasets: List,
53
+ dataset_weight: List[int] = None,
54
+ num_images_expected: int = 6,
55
+ max_packed_tokens: int = 32768,
56
+ max_buffer_size: int = 100,
57
+ log_freq: int = 1000000,
58
+ strict_mode: bool = False,
59
+ debug_mode: bool = False,
60
+ replacement: bool = True,
61
+ allow_overflow: bool = True,
62
+ allow_empty_data: bool = False,
63
+ allow_deduplicated_ds_name: bool = False,
64
+ ):
65
+ super().__init__()
66
+ self.tokenizer = tokenizer
67
+ self.data_rank = data_rank
68
+ self.data_world_size = data_world_size
69
+ self.datasets = datasets
70
+ self.num_images_expected = num_images_expected
71
+ self.max_buffer_size = max_buffer_size
72
+ self.log_freq = log_freq
73
+ self.strict_mode = strict_mode
74
+ self.debug_mode = debug_mode
75
+ self.replacement = replacement
76
+ self.allow_overflow = allow_overflow
77
+ self.allow_empty_data = allow_empty_data
78
+
79
+ self.max_packed_tokens = max_packed_tokens
80
+
81
+ self.img_start_token_id = self.tokenizer.convert_tokens_to_ids(IMG_START_TOKEN)
82
+ self.img_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
83
+ self.img_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
84
+
85
+ assert self.img_start_token_id != self.tokenizer.unk_token_id
86
+ assert self.img_token_id != self.tokenizer.unk_token_id
87
+ assert self.img_end_token_id != self.tokenizer.unk_token_id
88
+
89
+ if dataset_weight is None:
90
+ dataset_weight = [1] * len(datasets)
91
+ self.dataset_type = [d.dataset_type for d in self.datasets]
92
+
93
+ self.datasets_orig = datasets
94
+ self.dataset_weight_orig = [w / sum(dataset_weight) for w in dataset_weight]
95
+
96
+ self.datasets = [ds for ds in self.datasets_orig]
97
+ self.dataset_weight = [w for w in self.dataset_weight_orig]
98
+
99
+ # lazy init
100
+ self.worker_id = None
101
+ self.worker_state_key = None
102
+ self.dataset_iter_list = None
103
+ self._state_dict = {
104
+ 'sample_info': {d.ds_name:0 for d in self.datasets},
105
+ }
106
+
107
+ self.worker_custom_infos = None
108
+
109
+ ds_name_list = [d.ds_name for d in self.datasets]
110
+ if not allow_deduplicated_ds_name:
111
+ assert len(ds_name_list) == len(set(ds_name_list)), f'deduplicated ds_name: {ds_name_list}'
112
+
113
+ for ds in self.datasets:
114
+ if ds.max_num_images > self.num_images_expected:
115
+ logger.warning(f'{ds.max_num_images=} of {ds.ds_name} is larger than {self.num_images_expected=}')
116
+ ds.max_num_images = num_images_expected
117
+
118
+ if ds.max_tokens > self.max_packed_tokens:
119
+ logger.warning(f'{ds.max_tokens=} of {ds.ds_name} is larger than {self.max_packed_tokens=}')
120
+ ds.max_tokens = self.max_packed_tokens
121
+
122
+ self._state_dict[ds.ds_name] = {}
123
+
124
+ if get_rank() == 0:
125
+ logger.info(
126
+ f'Loaded dataset to pack: {ds_name_list}, '
127
+ f'{self.num_images_expected=}, {self.max_packed_tokens=}, '
128
+ f'{self.replacement=}, {self.allow_overflow=}',
129
+ )
130
+
131
+ temp = []
132
+ for ds, ds_w in zip(self.datasets, self.dataset_weight):
133
+ temp.append(f'{ds.ds_name:<25}: {ds_w*100:.2f}%')
134
+ temp = '\n'.join(temp)
135
+ logger.info(
136
+ f'Sampling prob for each dataset:\n{temp}'
137
+ )
138
+
139
+ if self.allow_empty_data:
140
+ logger.warning('allow_empty_data is enabled, note that empty data may be generated!')
141
+
142
+ def load_state_dict(self, state_dict, custom_infos=None):
143
+
144
+ self.worker_custom_infos = custom_infos
145
+
146
+ self._state_dict.update(state_dict)
147
+ for ds in self.datasets:
148
+ if ds.ds_name in self._state_dict:
149
+ ds.load_state_dict(self._state_dict[ds.ds_name])
150
+ logger.info(f'{ds.ds_name=} is resumed.')
151
+ else:
152
+ logger.warning(f'{ds.ds_name=} is not resumed.')
153
+
154
+ def _should_log(self):
155
+ worker_id = 0 if get_worker_info() is None else get_worker_info().id
156
+ num_workers = 1 if get_worker_info() is None else get_worker_info().num_workers
157
+
158
+ worker_id = num_workers * get_rank() + worker_id
159
+ num_workers = num_workers * get_world_size()
160
+
161
+ return worker_id == 0
162
+
163
+ def next_data(self, current_dataset_idx):
164
+ while True:
165
+ try:
166
+ current_sample = next(self.dataset_iter_list[current_dataset_idx])
167
+ break # Exit loop if successful
168
+ except StopIteration:
169
+ if self.replacement:
170
+ # logger.info(f'[Worker id {self.worker_id}] Dataset {self.datasets[current_dataset_idx].ds_name} is exhausted, restart it.')
171
+ try:
172
+ self.dataset_iter_list[current_dataset_idx] = iter(self.datasets[current_dataset_idx])
173
+ current_sample = next(self.dataset_iter_list[current_dataset_idx])
174
+ break
175
+ except:
176
+ # logger.error(f'{self.worker_id=} Fail to get any data from {self.datasets[current_dataset_idx].ds_name}! length={len(self.datasets)}')
177
+ self.datasets.pop(current_dataset_idx)
178
+ self.dataset_iter_list.pop(current_dataset_idx)
179
+ self.dataset_weight.pop(current_dataset_idx)
180
+
181
+ if len(self.datasets) == 0:
182
+ raise StopIteration
183
+ current_dataset_idx = np.random.choice(len(self.datasets))
184
+ else:
185
+ # logger.error(f'{self.worker_id=} Fail to get any data from {self.datasets[current_dataset_idx].ds_name}! length={len(self.datasets)}')
186
+ self.datasets.pop(current_dataset_idx)
187
+ self.dataset_iter_list.pop(current_dataset_idx)
188
+ self.dataset_weight.pop(current_dataset_idx)
189
+
190
+ if len(self.datasets) == 0:
191
+ raise StopIteration
192
+ current_dataset_idx = np.random.choice(len(self.datasets))
193
+ except:
194
+ logger.error('Unexpected error!')
195
+ if len(self.datasets) == 0:
196
+ raise StopIteration
197
+ current_dataset_idx = np.random.choice(len(self.datasets))
198
+
199
+ current_ds_name = self.datasets[current_dataset_idx].ds_name
200
+ current_sample['type_ids'] = torch.zeros_like(current_sample['input_ids']) + current_dataset_idx
201
+
202
+ if self.worker_state_key not in self._state_dict[current_ds_name]:
203
+ self._state_dict[current_ds_name][self.worker_state_key] = {}
204
+
205
+ meta_info = current_sample.pop('meta_info', {})
206
+ self._state_dict[current_ds_name][self.worker_state_key].update(**meta_info)
207
+ self._state_dict['sample_info'][self.datasets[current_dataset_idx].ds_name] += 1
208
+ return current_sample
209
+
210
+ def find_buffer(self, buffer_list, new_sample):
211
+ # NOTE: use `bisect` to search might be faster
212
+
213
+ find = False
214
+ find_idx = -1
215
+ num_images_current = new_sample['pixel_values'].size(0)
216
+ for buffer_idx, buffer in enumerate(buffer_list):
217
+ num_images_buffer = buffer['pixel_values'].size(0)
218
+ if num_images_buffer + num_images_current <= self.num_images_expected:
219
+ num_merged_tokens = new_sample['input_ids'].size(0) + buffer['input_ids'].size(0)
220
+
221
+ if num_merged_tokens <= self.max_packed_tokens:
222
+ find = True
223
+ find_idx = buffer_idx
224
+ break
225
+
226
+ if self.allow_overflow and len(buffer_list) >= self.max_buffer_size // 2:
227
+ find = True
228
+ find_idx = buffer_idx
229
+
230
+ if find:
231
+ return buffer_list.pop(find_idx)
232
+ return None
233
+
234
+ def update_buffer(self, buffer, new_sample):
235
+ if buffer is None:
236
+ new_sample['data_index'] = torch.zeros_like(new_sample['input_ids'])
237
+ return new_sample
238
+
239
+ new_sample['data_index'] = torch.ones_like(new_sample['input_ids']) + buffer['data_index'][-1].item()
240
+
241
+ assert buffer.keys() == new_sample.keys()
242
+ for k in buffer:
243
+ buffer[k] = torch.cat([buffer[k], new_sample[k]])
244
+ return buffer
245
+
246
+ @staticmethod
247
+ def check_valid(sample_to_check, min_active_tokens_ratio=1/256):
248
+ num_ignore_tokens = (sample_to_check['labels'] == IGNORE_TOKEN_ID).sum()
249
+ num_tokens = sample_to_check['labels'].numel()
250
+ return (1 - num_ignore_tokens / num_tokens) > min_active_tokens_ratio
251
+
252
+ @staticmethod
253
+ def split_buffer(buffer, max_tokens, img_start_token_id, img_token_id, img_end_token_id):
254
+ if buffer['input_ids'].size(0) <= max_tokens:
255
+ return [buffer]
256
+
257
+ def _image_is_splitted(input_ids, cut_idx):
258
+ is_image_start = input_ids[cut_idx].item() == img_start_token_id
259
+ is_image_token = input_ids[cut_idx].item() == img_token_id
260
+ is_image_end = input_ids[cut_idx].item() == img_end_token_id
261
+ return is_image_start or is_image_token or is_image_end
262
+
263
+ def _split(sample_to_split, left_idx, right_idx, left_img_idx, right_img_idx):
264
+ assert (right_idx is None) == (right_img_idx is None)
265
+
266
+ left_sample = {}
267
+ right_sample = {} if right_idx is not None else None
268
+ for k in sample_to_split:
269
+ if k in ['input_ids', 'labels', 'attention_mask', 'position_ids', 'data_index', 'type_ids']:
270
+ left_sample[k] = sample_to_split[k][:left_idx]
271
+ if right_sample is not None:
272
+ right_sample[k] = sample_to_split[k][right_idx:]
273
+ elif k in ['pixel_values', 'image_flags']:
274
+ left_sample[k] = sample_to_split[k][:left_img_idx]
275
+ if right_sample is not None:
276
+ right_sample[k] = sample_to_split[k][right_img_idx:]
277
+ else:
278
+ raise NotImplementedError(f'find unsupported keys: {k} from {sample_to_split.keys()}')
279
+ return left_sample, right_sample
280
+
281
+ splitted_buffer = []
282
+ while buffer['input_ids'].size(0) > max_tokens:
283
+ img_start_idx_list = (buffer['input_ids'] == img_start_token_id).nonzero().squeeze(1).tolist()
284
+ img_end_idx_list = (buffer['input_ids'] == img_end_token_id).nonzero().squeeze(1).tolist()
285
+ assert len(img_start_idx_list) == len(img_end_idx_list)
286
+
287
+ if _image_is_splitted(buffer['input_ids'], max_tokens):
288
+ cut_idx = bisect.bisect_left(img_start_idx_list, max_tokens)
289
+ if buffer['input_ids'][max_tokens] == img_start_token_id:
290
+ assert max_tokens == img_start_idx_list[cut_idx]
291
+ cut_left_idx = img_start_idx_list[cut_idx]
292
+ cut_left_img_idx = cut_idx
293
+ else:
294
+ cut_left_idx = img_start_idx_list[cut_idx - 1]
295
+ cut_left_img_idx = cut_idx - 1
296
+ cut_right_idx = cut_left_idx
297
+ cut_right_img_idx = cut_left_img_idx
298
+ else:
299
+ cut_img_idx = bisect.bisect(img_start_idx_list, max_tokens)
300
+ if cut_img_idx < len(img_start_idx_list):
301
+ cut_right_idx = img_start_idx_list[cut_img_idx]
302
+ cut_right_img_idx = cut_img_idx
303
+ else:
304
+ cut_right_idx = None
305
+ cut_right_img_idx = None
306
+
307
+ cut_left_idx = max_tokens
308
+ cut_left_img_idx = cut_right_img_idx if cut_right_img_idx is not None else buffer['pixel_values'].size(0)
309
+
310
+ left, right = _split(
311
+ sample_to_split=buffer,
312
+ left_idx=cut_left_idx,
313
+ left_img_idx=cut_left_img_idx,
314
+ right_idx=cut_right_idx,
315
+ right_img_idx=cut_right_img_idx,
316
+ )
317
+
318
+ assert (left['input_ids'] == img_end_token_id).sum() == (left['input_ids'] == img_start_token_id).sum() == left['pixel_values'].size(0)
319
+ if right is not None:
320
+ assert (right['input_ids'] == img_end_token_id).sum() == (right['input_ids'] == img_start_token_id).sum() == right['pixel_values'].size(0)
321
+
322
+ if left['pixel_values'].size(0) >= 1 and PackedDataset.check_valid(left):
323
+ splitted_buffer.append(left)
324
+
325
+ if right is None or right['pixel_values'].size(0) == 0:
326
+ break
327
+
328
+ buffer = right
329
+ if buffer['input_ids'].size(0) <= max_tokens and PackedDataset.check_valid(buffer):
330
+ splitted_buffer.append(buffer)
331
+ break
332
+
333
+ logger.debug(
334
+ f'split a sample into {len(splitted_buffer)} samples, '
335
+ f'current max_tokens={max_tokens}'
336
+ )
337
+ return splitted_buffer
338
+
339
+ def update_buffer_list(self, buffer_list, buffer_max_len_list, buffer):
340
+ # NOTE: in-place operation
341
+
342
+ splitted_buffer = PackedDataset.split_buffer(
343
+ buffer=buffer,
344
+ max_tokens=self.max_packed_tokens,
345
+ img_start_token_id=self.img_start_token_id,
346
+ img_token_id=self.img_token_id,
347
+ img_end_token_id=self.img_end_token_id,
348
+ )
349
+
350
+ for each_buffer in splitted_buffer:
351
+ if each_buffer['pixel_values'].size(0) > self.num_images_expected:
352
+ logger.error(
353
+ f"Find a sample with {each_buffer['pixel_values'].size(0)} images, "
354
+ f'which exceeds {self.num_images_expected}'
355
+ )
356
+ continue
357
+
358
+ if each_buffer['input_ids'].size(0) >= self.max_packed_tokens:
359
+ assert each_buffer['input_ids'].size(0) == self.max_packed_tokens
360
+ buffer_max_len_list.append(each_buffer)
361
+ continue
362
+
363
+ find_idx = len(buffer_list)
364
+ num_images_new_sample = each_buffer['pixel_values'].size(0)
365
+ for buffer_idx in range(len(buffer_list)):
366
+ if buffer_list[buffer_idx]['pixel_values'].size(0) < num_images_new_sample:
367
+ find_idx = buffer_idx
368
+ break
369
+ buffer_list.insert(find_idx, each_buffer)
370
+
371
+ for i in range(1, len(buffer_list)):
372
+ assert buffer_list[i-1]['pixel_values'].size(0) >= buffer_list[i]['pixel_values'].size(0)
373
+
374
+ return buffer_list, buffer_max_len_list
375
+
376
+ def pad_buffer(self, buffer):
377
+ if buffer['pixel_values'].size(0) == self.num_images_expected:
378
+ return buffer
379
+
380
+ num_pad_images = self.num_images_expected - buffer['pixel_values'].size(0)
381
+ pad_images = torch.stack([
382
+ torch.zeros_like(buffer['pixel_values'][0])
383
+ for _ in range(num_pad_images)
384
+ ])
385
+ pad_image_flags = torch.tensor([0] * num_pad_images, dtype=torch.long)
386
+
387
+ buffer['pixel_values'] = torch.cat([buffer['pixel_values'], pad_images])
388
+ buffer['image_flags'] = torch.cat([buffer['image_flags'], pad_image_flags])
389
+
390
+ return buffer
391
+
392
+ def postprocess_buffer(self, buffer, custom_infos=None):
393
+ buffer['worker_state_key'] = self.worker_state_key
394
+ buffer['worker_state_dict'] = self._state_dict
395
+ if custom_infos is not None:
396
+ buffer['custom_infos'] = {self.worker_state_key: copy.deepcopy(custom_infos)}
397
+ return buffer
398
+
399
+ def print_log(self, iter_idx, buffer_list):
400
+ if iter_idx % self.log_freq != 0:
401
+ return
402
+
403
+ if self._should_log():
404
+ logger.info(
405
+ f"{iter_idx=}, {len(buffer_list)=}, {self._state_dict['sample_info']}"
406
+ )
407
+
408
+ def __iter__(self):
409
+ iter_idx = 0
410
+ buffer_list = []
411
+ buffer_max_len_list = []
412
+
413
+ if self._should_log():
414
+ logger.info(f'Begin to iter, {len(buffer_list)=}')
415
+
416
+ worker_id = 0 if get_worker_info() is None else get_worker_info().id
417
+ num_workers = 1 if get_worker_info() is None else get_worker_info().num_workers
418
+
419
+ worker_id = num_workers * self.data_rank + worker_id
420
+ num_workers = num_workers * self.data_world_size
421
+
422
+ rng = np.random.default_rng(seed=worker_id)
423
+
424
+ # reset states of each dataset
425
+ self.worker_id = worker_id
426
+ self.worker_state_key = f'work_state_{self.worker_id}'
427
+ self.datasets = [d for d in self.datasets_orig]
428
+ self.dataset_weight = [w for w in self.dataset_weight_orig]
429
+ self.dataset_iter_list = [iter(d) for d in self.datasets]
430
+
431
+ for ds in self.datasets:
432
+ # if not isinstance(ds, (ImageTextPairDataset, InterleavedDataset)):
433
+ ds.worker_id = worker_id
434
+ ds.worker_state_key = f'work_state_{self.worker_id}'
435
+ ds.num_workers = num_workers
436
+ if self._should_log() and worker_id == 0:
437
+ logger.info(f'set worker_id and num_workers of {ds.__class__.__name__} {ds.ds_name}')
438
+
439
+ if self.worker_custom_infos is not None and self.worker_state_key in self.worker_custom_infos:
440
+ custom_infos = self.worker_custom_infos[self.worker_state_key]
441
+ # buffer list
442
+ if 'buffer_list' in custom_infos and isinstance(custom_infos['buffer_list'], list):
443
+ buffer_list = custom_infos['buffer_list']
444
+ if self._should_log() and worker_id == 0:
445
+ logger.info(f'[{self.worker_state_key}] load buffer list --> {len(buffer_list)=}')
446
+ # other infos
447
+
448
+ # reset
449
+ self.worker_custom_infos = None
450
+
451
+ logger.debug(
452
+ f'{self.__class__.__name__} Rank {self.data_rank} '
453
+ f'Worker {worker_id} begin to load data'
454
+ )
455
+
456
+ while True:
457
+ self.dataset_weight = [w / sum(self.dataset_weight) for w in self.dataset_weight]
458
+ current_dataset_idx = rng.choice(len(self.dataset_iter_list), p=self.dataset_weight)
459
+
460
+ try:
461
+ current_sample = self.next_data(current_dataset_idx)
462
+ except:
463
+ logger.info(f'All datasets are exhausted, begin to empty the buffer_list ({len(buffer_list)=})')
464
+ while len(buffer_list) > 0:
465
+ if self.strict_mode:
466
+ yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)))
467
+ else:
468
+ yield self.postprocess_buffer(buffer_list.pop(0))
469
+ logger.info(f'buffer_list is empty! ({len(buffer_list)=})')
470
+ return
471
+
472
+ buffer = self.find_buffer(buffer_list, current_sample)
473
+ buffer = self.update_buffer(buffer, current_sample)
474
+ buffer_list, buffer_max_len_list = self.update_buffer_list(buffer_list, buffer_max_len_list, buffer)
475
+
476
+ while len(buffer_max_len_list) > 0:
477
+ if buffer_max_len_list[0]['pixel_values'].size(0) != self.max_packed_tokens:
478
+ logger.debug(
479
+ f'num tokens of a buffer exceed {self.max_packed_tokens=}, '
480
+ f"yield a sample with {buffer_max_len_list[0]['pixel_values'].size(0)} images"
481
+ )
482
+ if self.strict_mode and buffer_max_len_list[0]['pixel_values'].size(0) != self.num_images_expected:
483
+ # buffer_max_len_list.pop(0)
484
+ yield self.postprocess_buffer(self.pad_buffer(buffer_max_len_list.pop(0)), {'buffer_list': buffer_list})
485
+ else:
486
+ yield self.postprocess_buffer(buffer_max_len_list.pop(0), {'buffer_list': buffer_list})
487
+
488
+ while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) > self.num_images_expected:
489
+ logger.error(
490
+ f"num images of a buffer ({buffer_list[0]['pixel_values'].size(0)}) "
491
+ f'is larger than num_images_expected({self.num_images_expected})'
492
+ )
493
+ buffer_list.pop(0)
494
+
495
+ while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) == self.num_images_expected:
496
+ if self.debug_mode:
497
+ debug_data = self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})
498
+ while True:
499
+ yield debug_data.copy()
500
+
501
+ yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})
502
+
503
+ while len(buffer_list) > self.max_buffer_size:
504
+ logger.debug(
505
+ f'Failed to pack data to exactly {self.num_images_expected} images, '
506
+ f"yield a data sample with {buffer_list[0]['pixel_values'].size(0)} images."
507
+ )
508
+ if self.strict_mode:
509
+ yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)), {'buffer_list': buffer_list})
510
+ else:
511
+ yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})
512
+
513
+ self.print_log(iter_idx=iter_idx, buffer_list=buffer_list)
514
+ iter_idx += 1
515
+
516
+ @staticmethod
517
+ def get_cu_seqlens_and_indexes(
518
+ data_index: torch.LongTensor, # (seq_len,)
519
+ input_ids: torch.LongTensor, # (seq_len,)
520
+ labels: torch.LongTensor, # (seq_len,)
521
+ len2weight: callable,
522
+ ):
523
+ indexes = []
524
+ cu_seqlens = [0]
525
+ loss_weight = []
526
+
527
+ start = data_index.min()
528
+ end = data_index.max() + 1
529
+ for i in range(start, end):
530
+ num_tokens = (data_index == i).sum().item()
531
+ indexes.extend(list(range(num_tokens)))
532
+ cu_seqlens.append(cu_seqlens[-1] + num_tokens)
533
+ assert num_tokens > 0
534
+
535
+ curr_data_index = data_index[cu_seqlens[-2]:cu_seqlens[-2]+num_tokens]
536
+ assert (curr_data_index == i).all(), data_index
537
+
538
+ curr_labels = labels[cu_seqlens[-2]:cu_seqlens[-2]+num_tokens]
539
+ num_effective_tokens = (curr_labels != IGNORE_TOKEN_ID).sum().item()
540
+ loss_weight.extend([len2weight(num_effective_tokens)] * num_tokens)
541
+
542
+ assert len(indexes) == data_index.size(0), f'{len(indexes)=}, {data_index.size(0)=}'
543
+
544
+ loss_weight = torch.tensor(loss_weight, dtype=torch.float32)
545
+ return cu_seqlens, indexes, loss_weight
546
+
547
+
548
+ WARNING_CNT = defaultdict(int)
549
+
550
+
551
+ def packed_collate_fn(
552
+ features,
553
+ data_collator,
554
+ len2weight: callable,
555
+ max_item_length: int,
556
+ micro_num: int = 1,
557
+ loss_reduction_all_gather: bool = False,
558
+ pad_id: int = 0,
559
+ ):
560
+ if not isinstance(features, list):
561
+ features = [features]
562
+
563
+ if len(features) > micro_num:
564
+ raise NotImplementedError(f'{len(features)=} > {micro_num=}')
565
+
566
+ if len(features) < micro_num and WARNING_CNT['micro_num_warning'] < 5:
567
+ logger.warning(
568
+ f'{len(features)=} > {micro_num=}, '
569
+ f'the features will be padded to satisfy micro_num requirement'
570
+ )
571
+ WARNING_CNT['micro_num_warning'] += 1
572
+
573
+ # ensure that the len(features) is equal to the required micro_num
574
+ num_features = len(features)
575
+ while len(features) < micro_num:
576
+ features.append(copy.deepcopy(features[0]))
577
+ features[-1]['labels'] = torch.full_like(features[-1]['labels'], IGNORE_TOKEN_ID)
578
+
579
+ indexes = []
580
+ cu_seqlens = []
581
+ cu_num_images_list = [0]
582
+
583
+ worker_state_key_list = []
584
+ worker_state_dict_list = []
585
+ worker_state_custom_infos_list = []
586
+
587
+ batch_lens = [feat['input_ids'].shape for feat in features]
588
+ max_item_length = max_item_length or max(batch_lens)[0]
589
+
590
+ num_samples = 0
591
+ num_padding_tokens = 0
592
+ for feat_idx, feat in enumerate(features):
593
+ data_index = feat.pop('data_index')
594
+ curr_cu_seqlens, curr_indexes, curr_loss_weight = PackedDataset.get_cu_seqlens_and_indexes(
595
+ data_index=data_index,
596
+ input_ids=feat['input_ids'],
597
+ labels=feat['labels'],
598
+ len2weight=len2weight,
599
+ )
600
+
601
+ feat['loss_weight'] = curr_loss_weight
602
+
603
+ if feat_idx < num_features:
604
+ num_samples += len(curr_cu_seqlens) - 1
605
+
606
+ if curr_cu_seqlens[-1] < max_item_length:
607
+ curr_cu_seqlens.append(max_item_length)
608
+ curr_indexes.extend(list(range(max_item_length - curr_cu_seqlens[-2])))
609
+
610
+ indexes.append(torch.tensor(curr_indexes, dtype=torch.long))
611
+ cu_seqlens.append(torch.tensor(curr_cu_seqlens, dtype=torch.int32))
612
+
613
+ worker_state_key_list.append(feat.pop('worker_state_key'))
614
+ worker_state_dict_list.append(feat.pop('worker_state_dict'))
615
+ worker_state_custom_infos_list.append(feat.pop('custom_infos', None))
616
+
617
+ num_padding_tokens += (max_item_length - feat['input_ids'].size(0))
618
+ cu_num_images_list.append(cu_num_images_list[-1] + feat['pixel_values'].size(0))
619
+
620
+ batch = data_collator(features=features, max_item_length=max_item_length, pad_id=pad_id)
621
+ # convert it to list in case it is converted into bf16
622
+ batch['loss_weight'] = torch.where(batch['labels'] == IGNORE_TOKEN_ID, 0, batch['loss_weight']).tolist()
623
+ batch['attention_mask'] = torch.stack(cu_seqlens)
624
+ batch['loss_reduction_all_gather'] = loss_reduction_all_gather
625
+ batch['statistics'] = torch.tensor(
626
+ [
627
+ num_samples,
628
+ num_padding_tokens,
629
+ batch['image_flags'].numel() - batch['image_flags'].sum().item(),
630
+ ],
631
+ dtype=torch.long,
632
+ )
633
+ batch.pop('type_ids')
634
+ return batch
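Illustrative sketch (not part of the diff) of the micro_num padding rule used in `packed_collate_fn` above: when fewer packed samples arrive than `micro_num`, the first feature is duplicated with fully-ignored labels so the copy contributes no loss.

import copy
import torch

IGNORE_TOKEN_ID = -100
micro_num = 2
features = [{'input_ids': torch.tensor([1, 2, 3]),   # toy packed sample
             'labels':    torch.tensor([1, 2, 3])}]

while len(features) < micro_num:                      # same rule as above
    features.append(copy.deepcopy(features[0]))
    features[-1]['labels'] = torch.full_like(features[-1]['labels'], IGNORE_TOKEN_ID)

print(len(features), features[1]['labels'])           # 2 tensor([-100, -100, -100])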
src/third_party/InternVL/internvl_chat/internvl/train/internvl_chat_dpo.py ADDED
@@ -0,0 +1,1056 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import warnings
8
+
9
+ warnings.filterwarnings('ignore', category=FutureWarning)
10
+ warnings.filterwarnings('ignore', category=DeprecationWarning)
11
+
12
+ import logging
13
+ import math
14
+ import os
15
+ import random
16
+ import shutil
17
+ import sys
18
+ import traceback
19
+ from copy import deepcopy
20
+ from dataclasses import dataclass, field
21
+ from typing import Dict, Literal, Optional
22
+
23
+ import numpy as np
24
+
25
+ try:
26
+ import orjson as json
27
+ except:
28
+ import json
29
+
30
+ import torch
31
+ import torch.distributed as dist
32
+ import transformers
33
+ from internvl.dist_utils import init_dist
34
+ from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
35
+ from internvl.model.internvl_chat import (InternVisionConfig,
36
+ InternVisionModel,
37
+ InternVLChatConfig,
38
+ InternVLChatModel)
39
+ from internvl.patch import (concat_pad_data_collator,
40
+ dpo_concat_pad_data_collator,
41
+ replace_llama_rmsnorm_with_fused_rmsnorm,
42
+ replace_train_sampler)
43
+ from internvl.train.constants import (BOX_END_TOKEN, BOX_START_TOKEN,
44
+ IMG_CONTEXT_TOKEN, IMG_END_TOKEN,
45
+ IMG_START_TOKEN, QUAD_END_TOKEN,
46
+ QUAD_START_TOKEN, REF_END_TOKEN,
47
+ REF_START_TOKEN)
48
+ from internvl.train.dataset import (ConcatDataset, TCSLoader,
49
+ WeightedConcatDataset, build_transform,
50
+ dynamic_preprocess, preprocess,
51
+ preprocess_internlm,
52
+ preprocess_internvl2_5, preprocess_mpt,
53
+ preprocess_phi3)
54
+ from internvl.train.trainer_dpo import MultimodalDPOTrainer
55
+ from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError
56
+ from torch.utils.data import Dataset
57
+ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
58
+ HfArgumentParser, Trainer, TrainingArguments,
59
+ set_seed)
60
+ from transformers.trainer_utils import get_last_checkpoint
61
+ from transformers.utils.logging import (enable_default_handler,
62
+ enable_explicit_format, set_verbosity)
63
+ from trl import DPOConfig as DPOConfigTRL
64
+
65
+ # Try to import petrel_client for image loading, fallback to PIL if unavailable
66
+ try:
67
+ from petrel_client.client import Client
68
+ from petrel_client.common.config import Config
69
+ has_tcs_loader = True
70
+ except ImportError as E:
71
+ print('petrel_client is not installed. Using PIL to load images.')
72
+ has_tcs_loader = False
73
+
74
+ # Set constants for image processing and logging
75
+ IGNORE_INDEX = -100
76
+ Image.MAX_IMAGE_PIXELS = None
77
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
78
+ MaximumDecompressedSize = 1024
79
+ MegaByte = 2 ** 20
80
+ PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
81
+
82
+ warnings.filterwarnings('ignore')
83
+ logger = logging.getLogger(__name__)
84
+
85
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
86
+
87
+
88
+ @dataclass
89
+ class ModelArguments:
90
+ """
91
+ Arguments for specifying model, tokenizer, and configurations.
92
+ """
93
+ model_name_or_path: Optional[str] = field(
94
+ default=None,
95
+ metadata={'help': 'Path to a pretrained model (local or from huggingface.co/models).'}
96
+ )
97
+ vision_path: Optional[str] = field(
98
+ default=None,
99
+ metadata={'help': 'Path to a pretrained model (local or from huggingface.co/models).'}
100
+ )
101
+ llm_path: Optional[str] = field(
102
+ default=None,
103
+ metadata={'help': 'Path to a pretrained model (local or from huggingface.co/models).'}
104
+ )
105
+ mlp_path: Optional[str] = field(
106
+ default=None,
107
+ metadata={'help': 'Path to a pretrained model (local or from huggingface.co/models).'}
108
+ )
109
+ freeze_llm: bool = field(
110
+ default=False,
111
+ metadata={'help': 'Set to True to freeze the LLM. Default is False.'},
112
+ )
113
+ freeze_backbone: bool = field(
114
+ default=False,
115
+ metadata={'help': 'Set to True to freeze the ViT. Default is False.'},
116
+ )
117
+ freeze_mlp: bool = field(
118
+ default=False,
119
+ metadata={'help': 'Set to True to freeze the MLP. Default is False.'},
120
+ )
121
+ unfreeze_vit_layers: int = field(
122
+ default=0,
123
+ metadata={'help': 'Specify the number of ViT layers to unfreeze. Default is 0.'},
124
+ )
125
+ vision_select_layer: int = field(
126
+ default=-1,
127
+ metadata={'help': 'Specify the layer of ViT feature map to use. Default is -1 for the last layer.'},
128
+ )
129
+ use_backbone_lora: int = field(
130
+ default=0,
131
+ metadata={'help': 'Set the LoRA adapter rank for the ViT. Default is 0.'}
132
+ )
133
+ use_llm_lora: int = field(
134
+ default=0,
135
+ metadata={'help': 'Set the LoRA adapter rank for the LLM. Default is 0.'}
136
+ )
137
+ unfreeze_lm_head: bool = field(
138
+ default=False,
139
+ metadata={'help': 'Set to True to unfreeze the head of LLM. Default is False.'},
140
+ )
141
+ grad_checkpoint: bool = field(
142
+ default=True,
143
+ metadata={'help': 'Set to True to use gradient checkpointing. Default is True.'},
144
+ )
145
+ drop_path_rate: float = field(
146
+ default=0.0,
147
+ metadata={'help': 'Set the drop path rate for the ViT. Default is 0.'},
148
+ )
149
+ ps_version: Literal['v1', 'v2'] = field(
150
+ default='v2',
151
+ metadata={'help': 'Specify the version of pixel shuffle implementation. Default is v2.'}
152
+ )
153
+ use_fast_tokenizer: bool = field(
154
+ default=False,
155
+ metadata={'help': 'Set to True to use the fast mode of the tokenizer.'}
156
+ )
157
+ use_liger: bool = field(
158
+ default=False,
159
+ metadata={'help': 'Set to True to use the liger kernel.'}
160
+ )
161
+
162
+
163
+ @dataclass
164
+ class DataTrainingArguments:
165
+ """
166
+ Arguments for specifying data input for training and evaluation.
167
+ """
168
+ max_seq_length: int = field(
169
+ default=8192,
170
+ metadata={
171
+ 'help': (
172
+ 'The maximum total input sequence length after tokenization. Sequences longer '
173
+ 'than this will be truncated, sequences shorter will be padded.'
174
+ )
175
+ },
176
+ )
177
+ force_image_size: int = field(
178
+ default=448,
179
+ metadata={'help': 'Set the desired size for the image. Default is 448.'},
180
+ )
181
+ down_sample_ratio: float = field(
182
+ default=0.5,
183
+ metadata={'help': 'Set the desired down-sampling ratio for the image. Default is 0.5.'},
184
+ )
185
+ pad2square: bool = field(
186
+ default=False,
187
+ metadata={'help': 'Pad the image to a square shape if set to True. Default is False.'},
188
+ )
189
+ conv_style: str = field(
190
+ default='internlm2-chat', metadata={'help': 'Prompt style for a conversation.'}
191
+ )
192
+ meta_path: str = field(
193
+ default=None,
194
+ metadata={'help': 'The path of the meta file of datasets.'},
195
+ )
196
+ use_data_resampling: bool = field(
197
+ default=False,
198
+ metadata={'help': 'Set to True to use data resampling. Default is False.'},
199
+ )
200
+ dynamic_image_size: bool = field(
201
+ default=False,
202
+ metadata={'help': 'Set to True to use dynamic high resolution strategy. Default is False.'},
203
+ )
204
+ use_thumbnail: bool = field(
205
+ default=False,
206
+ metadata={'help': 'Set to True to add a thumbnail image. Default is False.'},
207
+ )
208
+ min_dynamic_patch: int = field(
209
+ default=1,
210
+ metadata={'help': 'The minimum number of dynamic patches. Default is 1.'},
211
+ )
212
+ max_dynamic_patch: int = field(
213
+ default=12,
214
+ metadata={'help': 'The maximum number of dynamic patches. Default is 12.'},
215
+ )
216
+ min_num_frame: int = field(
217
+ default=8,
218
+ metadata={'help': 'The minimum number of frames for video data. Default is 8.'},
219
+ )
220
+ max_num_frame: int = field(
221
+ default=32,
222
+ metadata={'help': 'The maximum number of frames for video data. Default is 32.'},
223
+ )
224
+ normalize_type: Literal['imagenet', 'clip', 'siglip'] = field(
225
+ default='imagenet',
226
+ metadata={'help': 'The normalization type for the image. Default is imagenet.'},
227
+ )
228
+ sigmoid_loss_weight: float = field(
229
+ default=1.0,
230
+ metadata={'help': 'Loss weight for DPO loss. Default is 1.0'},
231
+ )
232
+ bco_pair_loss_weight: float = field(
233
+ default=1.0,
234
+ metadata={'help': 'Loss weight for BCO loss. Default is 1.0'},
235
+ )
236
+
237
+
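Illustrative sketch (not part of the diff) of how the two dataclasses above are typically consumed via HfArgumentParser; all values below are hypothetical:

from transformers import HfArgumentParser

parser = HfArgumentParser((ModelArguments, DataTrainingArguments))
model_args, data_args = parser.parse_args_into_dataclasses(
    args=['--model_name_or_path', 'OpenGVLab/InternVL2_5-8B',   # hypothetical checkpoint
          '--conv_style', 'internvl2_5',
          '--max_dynamic_patch', '6'])
print(model_args.model_name_or_path, data_args.max_dynamic_patch)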
238
+ class DPOConfig(DPOConfigTRL):
239
+ loss_type: Literal[
240
+ 'sigmoid', 'hinge', 'ipo', 'bco_pair', 'sppo_hard', 'nca_pair', 'robust', 'aot', 'aot_pair', 'exo_pair',
241
+ 'sigmoid,bco_pair',
242
+ ] = 'sigmoid'
243
+
244
+
245
+ class LazySupervisedDataset(Dataset):
246
+ """Dataset for supervised fine-tuning."""
247
+
248
+ def __init__(
249
+ self,
250
+ template_name,
251
+ meta,
252
+ tokenizer,
253
+ tcs_loader,
254
+ ds_name,
255
+ num_image_token,
256
+ image_size=448,
257
+ is_train=True,
258
+ pad2square=False,
259
+ group_by_length=False,
260
+ dynamic_image_size=False,
261
+ use_thumbnail=False,
262
+ min_dynamic_patch=1,
263
+ max_dynamic_patch=12,
264
+ min_num_frame=8, # for video data
265
+ max_num_frame=32, # for video data
266
+ sampling_method='rand', # for video data
267
+ repeat_time=1,
268
+ normalize_type='imagenet',
269
+ random_seed=0,
270
+ ):
271
+ super(LazySupervisedDataset, self).__init__()
272
+ self.ds_name = ds_name
273
+ self.tokenizer = tokenizer
274
+ self.template_name = template_name
275
+ self.num_image_token = num_image_token
276
+ logger.info(f'[Dataset] num_image_token: {num_image_token}')
277
+ logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}')
278
+ logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}')
279
+ logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}')
280
+
281
+ self.image_size = image_size
282
+ self.is_train = is_train
283
+ self.pad2square = pad2square
284
+ self.max_num_frame = max_num_frame
285
+ self.min_num_frame = min_num_frame
286
+ self.sampling_method = sampling_method
287
+
288
+ logger.info('Formatting inputs...Skip in lazy mode')
289
+ assert meta['annotation'].endswith('jsonl'), f'annotation must be jsonl, but got {meta["annotation"]}'
290
+
291
+ with open(meta['annotation'], 'r') as f:
292
+ self.raw_data = f.readlines()
293
+ if repeat_time < 1:
294
+ # If repeat_time is less than 1, select a portion of the data
295
+ self.raw_data = random.sample(self.raw_data, k=int(len(self.raw_data) * repeat_time))
296
+ if repeat_time > 1:
297
+ repeat_time = int(repeat_time)
298
+ assert isinstance(repeat_time, int)
299
+ # Repeat the list if repeat_time is greater than 1
300
+ self.raw_data = self.raw_data * repeat_time
301
+
302
+ self.rng = np.random.default_rng(seed=random_seed)
303
+ self.rng.shuffle(self.raw_data)
304
+
305
+ self.root = meta['root']
306
+ self.cached_data_dict = {}
307
+ self.tcs_loader = tcs_loader
308
+ self.group_by_length = group_by_length
309
+ self.dynamic_image_size = dynamic_image_size
310
+ self.use_thumbnail = use_thumbnail
311
+ self.min_dynamic_patch = min_dynamic_patch
312
+ self.max_dynamic_patch = max_dynamic_patch
313
+ self.normalize_type = normalize_type
314
+
315
+ # If the precomputed length does not exist, roughly estimate the length of
316
+ # each sample to improve the efficiency of group_by_length.
317
+ if self.group_by_length:
318
+ self.conv2length = {} # Using a dictionary to speed up token length calculation
319
+ self.length = []
320
+ for data_item in self.raw_data:
321
+ data_item = json.loads(data_item)
322
+ if 'length' in data_item:
323
+ token_length = data_item['length'] # Use precomputed length if available
324
+ else:
325
+ # Compute token length using the tokenizer
326
+ conversations = '\n'.join([temp['value'] for temp in data_item['conversations']])
327
+ str_length = len(conversations)
328
+ if str_length not in self.conv2length:
329
+ token_length = tokenizer(
330
+ conversations, return_tensors='pt', padding=False, truncation=False,
331
+ ).input_ids.size(1)
332
+ self.conv2length[str_length] = token_length + num_image_token * (
333
+ max_dynamic_patch + use_thumbnail)
334
+ else:
335
+ token_length = self.conv2length[str_length]
336
+ self.length.append(token_length)
337
+
338
+ def __len__(self):
339
+ return len(self.raw_data)
340
+
341
+ def get_preprocess_function(self):
342
+ # Select the appropriate preprocessing function based on the template name
343
+ if self.template_name == 'Hermes-2':
344
+ preprocess_function = preprocess_mpt
345
+ elif self.template_name == 'internlm2-chat':
346
+ preprocess_function = preprocess_internlm
347
+ elif self.template_name == 'phi3-chat':
348
+ preprocess_function = preprocess_phi3
349
+ elif self.template_name == 'internvl2_5':
350
+ preprocess_function = preprocess_internvl2_5
351
+ else:
352
+ preprocess_function = preprocess
353
+ return preprocess_function
354
+
355
+ def load_image(self, image_path):
356
+ # Load the image using tcs_loader if available, otherwise use PIL
357
+ if self.tcs_loader is not None and 's3://' in image_path:
358
+ return self.tcs_loader(image_path)
359
+ return Image.open(image_path).convert('RGB')
360
+
361
+ def get_image_path(self, image_path):
362
+ if image_path.startswith('s3://'): # for ceph
363
+ image_path = self.root + image_path
364
+ else: # for local image
365
+ image_path = os.path.join(self.root, image_path)
366
+ return image_path
367
+
368
+ def get_transform(self):
369
+ # Build transformation function
370
+ transform = build_transform(is_train=self.is_train, input_size=self.image_size,
371
+ pad2square=self.pad2square, normalize_type=self.normalize_type)
372
+ return transform
373
+
374
+ @staticmethod
375
+ def get_longest_common_prefix_index(tensor1, tensor2):
376
+ min_len = min(len(tensor1), len(tensor2))
377
+
378
+ for i in range(min_len):
379
+ if tensor1[i] != tensor2[i]:
380
+ return i
381
+
382
+ return min_len
383
+
384
+ def multi_modal_get_item(self, data_item):
385
+ # Build transformation function
386
+ transform = self.get_transform()
387
+
388
+ # Ensure the first conversation contains an image placeholder
389
+ if '<image>' not in data_item['question']:
390
+ data_item['question'] = '<image>\n' + data_item['question']
391
+
392
+ # Merge the image path
393
+ image_path = self.get_image_path(data_item['image'])
394
+
395
+ # Load the image using tcs_loader if available, otherwise use PIL
396
+ image = self.load_image(image_path)
397
+
398
+ if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically
399
+ images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch,
400
+ image_size=self.image_size, use_thumbnail=self.use_thumbnail)
401
+ else: # Otherwise, use the original image as a single patch
402
+ images = [image]
403
+
404
+ # Apply the transformation to each image and stack the results into a tensor
405
+ pixel_values = [transform(image) for image in images]
406
+ pixel_values = torch.stack(pixel_values)
407
+
408
+ # Ensure that there is only one patch if dynamic image size is not enabled
409
+ num_patches = pixel_values.size(0)
410
+ if not self.dynamic_image_size:
411
+ assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
412
+
413
+ # Select the appropriate preprocessing function based on the template name
414
+ preprocess_function = self.get_preprocess_function()
415
+
416
+ # Preprocess the conversations and generate the return dictionary
417
+ chosen_conversations = [
418
+ {'from': 'human', 'value': data_item['question']},
419
+ {'from': 'gpt', 'value': data_item['chosen']},
420
+ ]
421
+ chosen_ret = preprocess_function(
422
+ self.template_name,
423
+ [deepcopy(chosen_conversations)],
424
+ self.tokenizer,
425
+ [self.num_image_token * num_patches],
426
+ group_by_length=True,
427
+ ds_name=self.ds_name,
428
+ )
429
+
430
+ rejected_conversations = [
431
+ {'from': 'human', 'value': data_item['question']},
432
+ {'from': 'gpt', 'value': data_item['rejected']},
433
+ ]
434
+ rejected_ret = preprocess_function(
435
+ self.template_name,
436
+ [deepcopy(rejected_conversations)],
437
+ self.tokenizer,
438
+ [self.num_image_token * num_patches],
439
+ group_by_length=True,
440
+ ds_name=self.ds_name,
441
+ )
442
+
443
+ # Create the final return dictionary
444
+ ret = dict(
445
+ chosen_input_ids=chosen_ret['input_ids'][0],
446
+ chosen_labels=chosen_ret['labels'][0],
447
+ chosen_attention_mask=chosen_ret['attention_mask'][0],
448
+ rejected_input_ids=rejected_ret['input_ids'][0],
449
+ rejected_labels=rejected_ret['labels'][0],
450
+ rejected_attention_mask=rejected_ret['attention_mask'][0],
451
+ pixel_values=pixel_values,
452
+ image_flags=torch.tensor([1] * num_patches, dtype=torch.long),
453
+ )
454
+ return ret
455
+
456
+ def multi_modal_multi_image_get_item(self, data_item):
457
+ # Build transformation function
458
+ transform = self.get_transform()
459
+
460
+ images, num_tiles = [], []
461
+ num_image = len(data_item['image'])
462
+ for image_path in data_item['image']:
463
+ # Merge the image path
464
+ image_path = self.get_image_path(image_path)
465
+ # Load the image using tcs_loader if available, otherwise use PIL
466
+ image = self.load_image(image_path)
467
+ if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically
468
+ image = dynamic_preprocess(image, min_num=self.min_dynamic_patch,
469
+ max_num=max(1, self.max_dynamic_patch // num_image),
470
+ image_size=self.image_size, use_thumbnail=self.use_thumbnail)
471
+ images += image
472
+ num_tiles.append(len(image))
473
+ else: # Otherwise, use the original image as a single patch
474
+ images.append(image)
475
+ num_tiles.append(1)
476
+ pixel_values = [transform(image) for image in images]
477
+ pixel_values = torch.stack(pixel_values)
478
+ num_patches = pixel_values.size(0)
479
+
480
+ # Select the appropriate preprocessing function based on the template name
481
+ preprocess_function = self.get_preprocess_function()
482
+
483
+ # Preprocess the conversations and generate the return dictionary
484
+ num_image_tokens = [self.num_image_token * num_tile for num_tile in num_tiles]
485
+
486
+ chosen_conversations = [
487
+ {'from': 'human', 'value': data_item['question']},
488
+ {'from': 'gpt', 'value': data_item['chosen']},
489
+ ]
490
+ chosen_ret = preprocess_function(
491
+ self.template_name,
492
+ [deepcopy(chosen_conversations)],
493
+ self.tokenizer,
494
+ num_image_tokens,
495
+ group_by_length=self.group_by_length,
496
+ ds_name=self.ds_name,
497
+ num_image=num_image,
498
+ )
499
+
500
+ rejected_conversations = [
501
+ {'from': 'human', 'value': data_item['question']},
502
+ {'from': 'gpt', 'value': data_item['rejected']},
503
+ ]
504
+ rejected_ret = preprocess_function(
505
+ self.template_name,
506
+ [deepcopy(rejected_conversations)],
507
+ self.tokenizer,
508
+ num_image_tokens,
509
+ group_by_length=self.group_by_length,
510
+ ds_name=self.ds_name,
511
+ num_image=num_image,
512
+ )
513
+
514
+ # Create the final return dictionary
515
+ ret = dict(
516
+ chosen_input_ids=chosen_ret['input_ids'][0],
517
+ chosen_labels=chosen_ret['labels'][0],
518
+ chosen_attention_mask=chosen_ret['attention_mask'][0],
519
+ rejected_input_ids=rejected_ret['input_ids'][0],
520
+ rejected_labels=rejected_ret['labels'][0],
521
+ rejected_attention_mask=rejected_ret['attention_mask'][0],
522
+ pixel_values=pixel_values,
523
+ image_flags=torch.tensor([1] * num_patches, dtype=torch.long),
524
+ )
525
+ return ret
526
+
527
+ def video_get_item(self, data_item):
528
+ # Build transformation function
529
+ transform = self.get_transform()
530
+
531
+ # Ensure the first conversation contains a video placeholder
532
+ if '<video>' not in data_item['question']:
533
+ data_item['question'] = '<video>\n' + data_item['question']
534
+
535
+ # Get the video file path
536
+ video_file = data_item['video']
537
+ video_path = os.path.join(self.root, video_file)
538
+
539
+ # Load the video frames using tcs_loader
540
+ # TODO: Load videos without using tcsloader.
541
+ image_list = self.tcs_loader(
542
+ video_path,
543
+ image_type='video',
544
+ max_num_frames=self.max_num_frame,
545
+ min_num_frames=self.min_num_frame,
546
+ sample=self.sampling_method,
547
+ clip=data_item.get('clip', None))
548
+
549
+ # Generate special tokens for each video frame
550
+ special_tokens = '\n'.join(['Frame{}: <image>'.format(i + 1) for i in range(len(image_list))])
551
+ data_item['question'] = data_item['question'].replace('<video>\n', special_tokens)
552
+
553
+ # Transform each frame image and stack them into a tensor
554
+ pixel_values = [transform(image) for image in image_list]
555
+ pixel_values = torch.stack(pixel_values)
556
+ num_patches = pixel_values.size(0)
557
+
558
+ # Select the appropriate preprocessing function based on the template name
559
+ preprocess_function = self.get_preprocess_function()
560
+
561
+ # Preprocess the conversations and generate the return dictionary
562
+ num_image_tokens = [self.num_image_token] * num_patches
563
+
564
+ chosen_conversations = [
565
+ {'from': 'human', 'value': data_item['question']},
566
+ {'from': 'gpt', 'value': data_item['chosen']},
567
+ ]
568
+ chosen_ret = preprocess_function(
569
+ self.template_name,
570
+ [deepcopy(chosen_conversations)],
571
+ self.tokenizer,
572
+ num_image_tokens,
573
+ group_by_length=True,
574
+ use_packed_ds=self.use_packed_ds,
575
+ ds_name=self.ds_name,
576
+ num_image=num_patches,
577
+ )
578
+
579
+ rejected_conversations = [
580
+ {'from': 'human', 'value': data_item['question']},
581
+ {'from': 'gpt', 'value': data_item['rejected']},
582
+ ]
583
+ rejected_ret = preprocess_function(
584
+ self.template_name,
585
+ [deepcopy(rejected_conversations)],
586
+ self.tokenizer,
587
+ num_image_tokens,
588
+ group_by_length=True,
589
+ use_packed_ds=self.use_packed_ds,
590
+ ds_name=self.ds_name,
591
+ num_image=num_patches,
592
+ )
593
+
594
+ ret = dict(
595
+ chosen_input_ids=chosen_ret['input_ids'][0],
596
+ chosen_labels=chosen_ret['labels'][0],
597
+ chosen_attention_mask=chosen_ret['attention_mask'][0],
598
+ rejected_input_ids=rejected_ret['input_ids'][0],
599
+ rejected_labels=rejected_ret['labels'][0],
600
+ rejected_attention_mask=rejected_ret['attention_mask'][0],
601
+ pixel_values=pixel_values,
602
+ image_flags=torch.tensor([1] * num_patches, dtype=torch.long),
603
+ )
604
+ return ret
605
+
606
+ def pure_text_get_item(self, data_item):
607
+ # Build transformation function
608
+ transform = self.get_transform()
609
+
610
+ # Create a blank white image
611
+ image = Image.new('RGB', (224, 224), (255, 255, 255))
612
+
613
+ # Dynamically preprocess the image to generate patches
614
+ images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=1,
615
+ image_size=self.image_size, use_thumbnail=self.use_thumbnail)
616
+
617
+ # Apply the transformation to each image patch and stack them into a tensor
618
+ pixel_values = [transform(image) for image in images]
619
+ pixel_values = torch.stack(pixel_values)
620
+ num_patches = pixel_values.size(0)
621
+
622
+ # Ensure there is only one patch
623
+ assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
624
+
625
+ # Select the appropriate preprocessing function based on the template name
626
+ preprocess_function = self.get_preprocess_function()
627
+
628
+ # Preprocess the conversations and generate the return dictionary
629
+ chosen_conversations = [
630
+ {'from': 'human', 'value': data_item['question']},
631
+ {'from': 'gpt', 'value': data_item['chosen']},
632
+ ]
633
+ chosen_ret = preprocess_function(
634
+ self.template_name,
635
+ [deepcopy(chosen_conversations)],
636
+ self.tokenizer,
637
+ [self.num_image_token * num_patches],
638
+ text_only=True,
639
+ group_by_length=True,
640
+ ds_name=self.ds_name,
641
+ )
642
+
643
+ rejected_conversations = [
644
+ {'from': 'human', 'value': data_item['question']},
645
+ {'from': 'gpt', 'value': data_item['rejected']},
646
+ ]
647
+ rejected_ret = preprocess_function(
648
+ self.template_name,
649
+ [deepcopy(rejected_conversations)],
650
+ self.tokenizer,
651
+ [self.num_image_token * num_patches],
652
+ text_only=True,
653
+ group_by_length=True,
654
+ ds_name=self.ds_name,
655
+ )
656
+
657
+ # Create the final return dictionary
658
+ ret = dict(
659
+ chosen_input_ids=chosen_ret['input_ids'][0],
660
+ chosen_labels=chosen_ret['labels'][0],
661
+ chosen_attention_mask=chosen_ret['attention_mask'][0],
662
+ rejected_input_ids=rejected_ret['input_ids'][0],
663
+ rejected_labels=rejected_ret['labels'][0],
664
+ rejected_attention_mask=rejected_ret['attention_mask'][0],
665
+ pixel_values=pixel_values,
666
+ image_flags=torch.tensor([0] * num_patches, dtype=torch.long),
667
+ )
668
+ return ret
669
+
670
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
671
+ i = i % len(self.raw_data)
672
+
673
+ try_cnt, max_try = 0, 10
674
+ while True:
675
+ if try_cnt > max_try:
676
+ raise StopIteration
677
+ try:
678
+ data_item = json.loads(self.raw_data[i])
679
+ if 'image' in data_item and len(data_item['image']) != 0:
680
+ if type(data_item['image']) == list:
681
+ ret = self.multi_modal_multi_image_get_item(data_item)
682
+ else:
683
+ ret = self.multi_modal_get_item(data_item)
684
+ elif 'video' in data_item and data_item['video'] is not None and data_item['video'] != '':
685
+ ret = self.video_get_item(data_item)
686
+ else:
687
+ ret = self.pure_text_get_item(data_item)
688
+ break
689
+ except Exception as e:
690
+ try_cnt += 1
691
+ print(e, self.ds_name, flush=True)
692
+ if not isinstance(e, (UnidentifiedImageError, FileNotFoundError)):
693
+ traceback.print_exc()
694
+ data_item = json.loads(self.raw_data[i])
695
+ if 'image' in data_item:
696
+ if type(data_item['image']) == list:
697
+ images = [self.root + item for item in data_item['image']]
698
+ print(f'Failed to load image: {images}, the dataset is: {self.ds_name}')
699
+ else:
700
+ if data_item['image'].startswith('s3://'):
701
+ data_path = self.root + data_item['image']
702
+ else:
703
+ data_path = os.path.join(self.root, data_item['image'])
704
+ print(f'Failed to load image: {data_path}, the dataset is: {self.ds_name}')
705
+ elif 'video' in data_item:
706
+ data_path = os.path.join(self.root, data_item['video'])
707
+ print(f'Failed to load video: {data_path}, the dataset is: {self.ds_name}')
708
+ i = random.randint(0, len(self.raw_data) - 1)
709
+ return ret
710
+
711
+
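# A minimal sketch of how the chosen_*/rejected_* fields built above could be padded
# into a training batch. The collator actually used below, `dpo_concat_pad_data_collator`
# from internvl.patch, may differ; the padding values here are assumptions.
import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_INDEX = -100  # assumed label-padding value, following the usual HF convention

def pad_dpo_pairs(features, pad_token_id):
    """Pad a list of per-sample dicts into one batch dict (illustrative only)."""
    batch = {}
    for prefix in ('chosen', 'rejected'):
        batch[f'{prefix}_input_ids'] = pad_sequence(
            [f[f'{prefix}_input_ids'] for f in features],
            batch_first=True, padding_value=pad_token_id)
        batch[f'{prefix}_labels'] = pad_sequence(
            [f[f'{prefix}_labels'] for f in features],
            batch_first=True, padding_value=IGNORE_INDEX)
        batch[f'{prefix}_attention_mask'] = pad_sequence(
            [f[f'{prefix}_attention_mask'] for f in features],
            batch_first=True, padding_value=0)
    # Image tensors are concatenated along the patch dimension.
    batch['pixel_values'] = torch.cat([f['pixel_values'] for f in features], dim=0)
    batch['image_flags'] = torch.cat([f['image_flags'] for f in features], dim=0)
    return batch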
712
+ def build_datasets(
713
+ data_args,
714
+ tokenizer,
715
+ tcs_loader,
716
+ model,
717
+ group_by_length=False,
718
+ dynamic_image_size=False,
719
+ use_thumbnail=False,
720
+ min_dynamic_patch=1,
721
+ max_dynamic_patch=12,
722
+ min_num_frame=8,
723
+ max_num_frame=32,
724
+ normalize_type='imagenet',
725
+ ):
726
+ datasets = []
727
+ lengths = []
728
+ ds_collections = json.loads(open(data_args.meta_path).read())
729
+ for ds_idx, ds_name in enumerate(ds_collections.keys()):
730
+ repeat_time = ds_collections[ds_name]['repeat_time']
731
+ if 'max_dynamic_patch' in ds_collections[ds_name]:
732
+ max_num = ds_collections[ds_name]['max_dynamic_patch']
733
+ logger.info(f'max_dynamic_patch is set to {max_num} according to the meta file')
734
+ else:
735
+ max_num = max_dynamic_patch
736
+ dataset = LazySupervisedDataset(
737
+ data_args.conv_style, ds_collections[ds_name],
738
+ tokenizer,
739
+ tcs_loader,
740
+ ds_name=ds_name,
741
+ num_image_token=model.num_image_token,
742
+ image_size=data_args.force_image_size,
743
+ is_train=ds_collections[ds_name].get('data_augment', False),
744
+ pad2square=data_args.pad2square,
745
+ group_by_length=group_by_length,
746
+ dynamic_image_size=dynamic_image_size,
747
+ use_thumbnail=use_thumbnail,
748
+ min_dynamic_patch=min_dynamic_patch,
749
+ max_dynamic_patch=max_num,
750
+ min_num_frame=min_num_frame,
751
+ max_num_frame=max_num_frame,
752
+ repeat_time=repeat_time,
753
+ normalize_type=normalize_type,
754
+ random_seed=ds_idx,
755
+ )
756
+ logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}')
757
+ datasets.append(dataset)
758
+ if data_args.use_data_resampling:
759
+ lengths.append(math.sqrt(len(dataset)))
760
+ else:
761
+ lengths.append(len(dataset))
762
+
763
+ if data_args.use_data_resampling:
764
+ total_length = sum(lengths)
765
+ weights = [l / total_length for l in lengths]
766
+ train_dataset = WeightedConcatDataset(datasets, weights)
767
+ else:
768
+ train_dataset = ConcatDataset(datasets)
769
+ return train_dataset
770
+
771
+
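# A worked example of the square-root resampling branch above (numbers are illustrative):
# two datasets with 10_000 and 90_000 samples have sqrt-lengths 100 and 300, so the
# sampling weights become 0.25 and 0.75 instead of 0.1 and 0.9 without resampling.
import math
lengths = [math.sqrt(10_000), math.sqrt(90_000)]   # [100.0, 300.0]
weights = [l / sum(lengths) for l in lengths]      # [0.25, 0.75]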
772
+ def main():
773
+ # Apply necessary patches for the transformers library
774
+ replace_llama_rmsnorm_with_fused_rmsnorm()
775
+ replace_train_sampler()
776
+
777
+ # Parse input arguments
778
+ # See all possible arguments in src/transformers/training_args.py
779
+ # If using DeepSpeed ZeRO-3, init_dist must be called before HfArgumentParser
780
+ launcher = os.environ.get('LAUNCHER', 'slurm')
781
+ init_dist(launcher=launcher, backend='nccl')
782
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, DPOConfig))
783
+ if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
784
+ # If we pass only one argument to the script, and it's the path to a json file,
785
+ # let's parse it to get our arguments.
786
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
787
+ else:
788
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
789
+
790
+ training_args.remove_unused_columns = False
791
+ training_args.gradient_checkpointing = model_args.grad_checkpoint
792
+ training_args.sigmoid_loss_weight = data_args.sigmoid_loss_weight
793
+ training_args.bco_pair_loss_weight = data_args.bco_pair_loss_weight
794
+
795
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
796
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
797
+ # send_example_telemetry('InternV-Chat', model_args, data_args)
798
+
799
+ # Setup logging
800
+ logging.basicConfig(
801
+ format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
802
+ datefmt='%m/%d/%Y %H:%M:%S',
803
+ handlers=[logging.StreamHandler(sys.stdout)],
804
+ )
805
+
806
+ if training_args.should_log:
807
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
808
+ transformers.utils.logging.set_verbosity_info()
809
+
810
+ log_level = training_args.get_process_log_level()
811
+ logger.setLevel(log_level)
812
+ set_verbosity(log_level)
813
+ enable_default_handler()
814
+ enable_explicit_format()
815
+
816
+ # Log on each process the small summary:
817
+ logger.warning(
818
+ f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}'
819
+ + f', distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
820
+ )
821
+ logger.info(f'Training/evaluation parameters {training_args}')
822
+
823
+ # Detect the last checkpoint and, if one exists, resume training from it.
824
+ last_checkpoint = None
825
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
826
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
827
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
828
+ raise ValueError(
829
+ f'Output directory ({training_args.output_dir}) already exists and is not empty. '
830
+ 'Use --overwrite_output_dir to overcome.'
831
+ )
832
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
833
+ logger.info(
834
+ f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change '
835
+ 'the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
836
+ )
837
+ # Set seed before initializing model.
838
+ set_seed(training_args.seed)
839
+
840
+ # Load pretrained model, tokenizer, and image processor
841
+ tokenizer_path = model_args.model_name_or_path or model_args.llm_path
842
+ logger.info(f'Loading Tokenizer: {tokenizer_path}')
843
+ tokenizer = AutoTokenizer.from_pretrained(
844
+ tokenizer_path, add_eos_token=False, trust_remote_code=True, use_fast=model_args.use_fast_tokenizer)
845
+ tokenizer.tokenizer_path = tokenizer_path
846
+ tokenizer.model_max_length = data_args.max_seq_length
847
+ token_list = [IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN,
848
+ QUAD_START_TOKEN, QUAD_END_TOKEN, REF_START_TOKEN,
849
+ REF_END_TOKEN, BOX_START_TOKEN, BOX_END_TOKEN]
850
+ num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
851
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
852
+ tcs_loader = TCSLoader('~/petreloss.conf') if has_tcs_loader else None
853
+
854
+ if model_args.use_liger:
855
+ from internvl.patch import apply_liger_kernel_to_internvit
856
+ from liger_kernel.transformers import (apply_liger_kernel_to_llama,
857
+ apply_liger_kernel_to_qwen2)
858
+ apply_liger_kernel_to_llama()
859
+ apply_liger_kernel_to_qwen2()
860
+ # apply_liger_kernel_to_internvit()
861
+
862
+ if model_args.model_name_or_path is not None:
863
+ logger.info('Loading InternVLChatModel...')
864
+ config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
865
+ config.vision_config.drop_path_rate = model_args.drop_path_rate
866
+ if config.llm_config.model_type == 'internlm2':
867
+ config.llm_config.attn_implementation = 'flash_attention_2' # for InternLM
868
+ logger.info('Using flash_attention_2 for InternLM')
869
+ else:
870
+ config.llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
871
+ logger.info('Using flash_attention_2 for LLaMA')
872
+ config.template = data_args.conv_style
873
+ config.select_layer = model_args.vision_select_layer
874
+ config.dynamic_image_size = data_args.dynamic_image_size
875
+ config.use_thumbnail = data_args.use_thumbnail
876
+ config.ps_version = model_args.ps_version
877
+ config.min_dynamic_patch = data_args.min_dynamic_patch
878
+ config.max_dynamic_patch = data_args.max_dynamic_patch
879
+ model = InternVLChatModel.from_pretrained(
880
+ model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config)
881
+ ref_model = InternVLChatModel.from_pretrained(
882
+ model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config)
883
+ else:
884
+ logger.info('Loading ViT-6B...')
885
+ vision_config = InternVisionConfig.from_pretrained(model_args.vision_path)
886
+ vision_config.drop_path_rate = model_args.drop_path_rate
887
+ vision_model = InternVisionModel.from_pretrained(
888
+ model_args.vision_path, torch_dtype=torch.bfloat16, config=vision_config)
889
+ logger.info('Loading LLaMA...')
890
+ llm_config = AutoConfig.from_pretrained(model_args.llm_path, trust_remote_code=True)
891
+ if llm_config.model_type == 'internlm2':
892
+ model_type = InternLM2ForCausalLM
893
+ llm_config.attn_implementation = 'flash_attention_2' # for InternLM
894
+ logger.info('Using flash_attention_2 for InternLM')
895
+ else:
896
+ model_type = AutoModelForCausalLM
897
+ llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
898
+ logger.info('Using flash_attention_2 for LLaMA')
899
+ llm = model_type.from_pretrained(
900
+ model_args.llm_path, torch_dtype=torch.bfloat16,
901
+ config=llm_config, trust_remote_code=True)
902
+ logger.info('Building InternVLChatConfig...')
903
+ internvl_chat_config = InternVLChatConfig(
904
+ vision_config.to_dict(), llm_config.to_dict(), downsample_ratio=data_args.down_sample_ratio,
905
+ pad2square=data_args.pad2square, template=data_args.conv_style,
906
+ select_layer=model_args.vision_select_layer, dynamic_image_size=data_args.dynamic_image_size,
907
+ use_thumbnail=data_args.use_thumbnail, ps_version=model_args.ps_version,
908
+ min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch)
909
+ internvl_chat_config.force_image_size = data_args.force_image_size
910
+ logger.info('Building InternVLChatModel...')
911
+ model = InternVLChatModel(internvl_chat_config, vision_model, llm)
912
+ model.img_context_token_id = img_context_token_id
913
+ ref_model.img_context_token_id = img_context_token_id
914
+
915
+ assert model.config.downsample_ratio == data_args.down_sample_ratio
916
+ assert ref_model.config.downsample_ratio == data_args.down_sample_ratio
917
+
918
+ if model_args.mlp_path is not None:
919
+ logger.info('Loading pretrained MLP projector...')
920
+ state_dict = torch.load(model_args.mlp_path, map_location='cpu')
921
+ message = model.mlp1.load_state_dict(state_dict)
922
+ logger.info(message)
923
+ logger.info('Finished')
924
+
925
+ patch_size = model.config.vision_config.patch_size
926
+ logger.info(f'model.config.force_image_size: {model.config.force_image_size}')
927
+ logger.info(f'data_args.force_image_size: {data_args.force_image_size}')
928
+ logger.info(f'model.config.vision_config.image_size: {model.config.vision_config.image_size}')
929
+ if model.config.vision_config.image_size != data_args.force_image_size:
930
+ logger.info(f'Resizing position embedding from '
931
+ f'{model.config.vision_config.image_size} '
932
+ f'to {data_args.force_image_size}...')
933
+ model.vision_model.resize_pos_embeddings(old_size=model.config.vision_config.image_size,
934
+ new_size=data_args.force_image_size,
935
+ patch_size=patch_size)
936
+ model.config.vision_config.image_size = data_args.force_image_size
937
+ model.config.force_image_size = data_args.force_image_size
938
+ model.num_image_token = int((data_args.force_image_size // patch_size) ** 2 * (data_args.down_sample_ratio ** 2))
939
+
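# Worked example of the image-token count above, assuming the common InternVL settings
# (force_image_size=448, patch_size=14, down_sample_ratio=0.5):
#   (448 // 14) ** 2 * 0.5 ** 2 = 32 ** 2 * 0.25 = 1024 * 0.25 = 256 image tokens per tile.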
940
+ ref_model.config.force_image_size = model.config.force_image_size
941
+ ref_model.num_image_token = model.num_image_token
942
+
943
+ if num_new_tokens > 0:
944
+ model.language_model.resize_token_embeddings(len(tokenizer))
945
+ output_embeddings = model.language_model.get_output_embeddings().weight.data
946
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
947
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
948
+
949
+ model.config.llm_config.vocab_size = len(tokenizer)
950
+ model.language_model.config.vocab_size = len(tokenizer)
951
+
952
+ model.language_model.config.use_cache = False
953
+ model.vision_model.gradient_checkpointing = True
954
+ model.vision_model.encoder.gradient_checkpointing = True
955
+ if model_args.grad_checkpoint:
956
+ model.language_model._set_gradient_checkpointing()
957
+
958
+ train_dataset = build_datasets(
959
+ data_args, tokenizer, tcs_loader, model, group_by_length=training_args.group_by_length,
960
+ dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail,
961
+ min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch,
962
+ normalize_type=data_args.normalize_type, min_num_frame=data_args.min_num_frame,
963
+ max_num_frame=data_args.max_num_frame)
964
+
965
+ def _freeze_params(module):
966
+ for param in module.parameters():
967
+ param.requires_grad = False
968
+
969
+ ref_model.eval()
970
+ # _freeze_params(ref_model)
971
+
972
+ if model_args.freeze_backbone:
973
+ # model.vision_model = model.vision_model.eval()
974
+ _freeze_params(model.vision_model)
975
+
976
+ if model_args.freeze_llm:
977
+ model.language_model = model.language_model.eval()
978
+ _freeze_params(model.language_model)
979
+
980
+ if model_args.unfreeze_lm_head:
981
+ model.language_model.lm_head.requires_grad_(True)  # unfreeze every parameter of the LM head
982
+
983
+ if model_args.use_backbone_lora:
984
+ model.wrap_backbone_lora(r=model_args.use_backbone_lora, lora_alpha=2 * model_args.use_backbone_lora)
985
+ model.config.use_backbone_lora = model_args.use_backbone_lora
986
+
987
+ if model_args.use_llm_lora:
988
+ model.wrap_llm_lora(r=model_args.use_llm_lora, lora_alpha=2 * model_args.use_llm_lora)
989
+ model.config.use_llm_lora = model_args.use_llm_lora
990
+
991
+ if model_args.freeze_mlp:
992
+ _freeze_params(model.mlp1)
993
+
994
+ if model_args.unfreeze_vit_layers != 0:
995
+ layers = model.vision_model.encoder.layers[model_args.unfreeze_vit_layers:]
996
+ for k, v in layers.named_parameters():
997
+ logger.info(f'Unfreezing ViT layer: {k}')
998
+ v.requires_grad = True
999
+
1000
+ # print trainable parameters
1001
+ if dist.get_rank() == 0:
1002
+ for name, param in model.named_parameters():
1003
+ if param.requires_grad:
1004
+ logger.info(name)
1005
+
1006
+ # set seed for torch dataloaders
1007
+ set_seed(training_args.seed)
1008
+
1009
+ trainer = MultimodalDPOTrainer(
1010
+ model=model,
1011
+ ref_model=ref_model,
1012
+ args=training_args,
1013
+ train_dataset=train_dataset if training_args.do_train else None,
1014
+ eval_dataset=None,
1015
+ tokenizer=tokenizer,
1016
+ data_collator=dpo_concat_pad_data_collator,
1017
+ )
1018
+
1019
+ # Training
1020
+ if training_args.do_train:
1021
+ checkpoint = None
1022
+ if training_args.resume_from_checkpoint is not None:
1023
+ checkpoint = training_args.resume_from_checkpoint
1024
+ elif last_checkpoint is not None:
1025
+ checkpoint = last_checkpoint
1026
+ print(f'[Memory Usage before training] {torch.cuda.memory_allocated()/1024/1024/1024:.2f}GB')
1027
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
1028
+ trainer.save_model() # Saves the tokenizer too for easy upload
1029
+
1030
+ metrics = train_result.metrics
1031
+ try:
1032
+ metrics['train_samples'] = len(train_dataset)
1033
+ except Exception:
1034
+ metrics['train_samples'] = -1
1035
+
1036
+ trainer.log_metrics('train', metrics)
1037
+ trainer.save_metrics('train', metrics)
1038
+ trainer.save_state()
1039
+
1040
+ model_dir = model_args.model_name_or_path
1041
+ output_dir = training_args.output_dir
1042
+ for filename in [
1043
+ 'conversation.py',
1044
+ 'modeling_internvl_chat.py',
1045
+ 'modeling_intern_vit.py',
1046
+ 'modeling_internlm2.py',
1047
+ 'configuration_internvl_chat.py',
1048
+ 'configuration_intern_vit.py',
1049
+ 'configuration_internlm2.py',
1050
+ ]:
1051
+ if os.path.exists(os.path.join(model_dir, filename)):
1052
+ shutil.copy(os.path.join(model_dir, filename), output_dir)
1053
+
1054
+
1055
+ if __name__ == '__main__':
1056
+ main()
src/third_party/InternVL/internvl_chat/internvl/train/internvl_chat_finetune.py ADDED
@@ -0,0 +1,1072 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import logging
8
+ import math
9
+ import os
10
+ import random
11
+ import sys
12
+ import traceback
13
+ import warnings
14
+ from copy import deepcopy
15
+ from dataclasses import dataclass, field
16
+ from functools import partial
17
+ from typing import Dict, Literal, Optional
18
+
19
+ import numpy as np
20
+
21
+ try:
22
+ import orjson as json
23
+ except:
24
+ import json
25
+
26
+ import torch
27
+ import torch.distributed as dist
28
+ import transformers
29
+ from internvl.dist_utils import init_dist
30
+ from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
31
+ from internvl.model.internvl_chat import (InternVisionConfig,
32
+ InternVisionModel,
33
+ InternVLChatConfig,
34
+ InternVLChatModel)
35
+ from internvl.patch import (concat_pad_data_collator,
36
+ replace_internlm2_attention_class,
37
+ replace_llama_attention_class,
38
+ replace_llama_rmsnorm_with_fused_rmsnorm,
39
+ replace_phi3_attention_class,
40
+ replace_qwen2_attention_class,
41
+ replace_train_dataloader, replace_train_sampler)
42
+ from internvl.train.constants import (BOX_END_TOKEN, BOX_START_TOKEN,
43
+ IMG_CONTEXT_TOKEN, IMG_END_TOKEN,
44
+ IMG_START_TOKEN, QUAD_END_TOKEN,
45
+ QUAD_START_TOKEN, REF_END_TOKEN,
46
+ REF_START_TOKEN)
47
+ from internvl.train.dataset import (ConcatDataset, TCSLoader,
48
+ WeightedConcatDataset, build_transform,
49
+ check_conversations_repetition,
50
+ dynamic_preprocess, preprocess,
51
+ preprocess_internlm,
52
+ preprocess_internvl2_5, preprocess_mpt,
53
+ preprocess_phi3)
54
+ from internvl.train.dataset_packed import PackedDataset, packed_collate_fn
55
+ from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError
56
+ from torch.utils.data import Dataset
57
+ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
58
+ HfArgumentParser, Trainer, TrainingArguments,
59
+ set_seed)
60
+ from transformers.trainer_utils import get_last_checkpoint
61
+ from transformers.utils.logging import (enable_default_handler,
62
+ enable_explicit_format, set_verbosity)
63
+
64
+ # Try to import petrel_client for image loading, fallback to PIL if unavailable
65
+ try:
66
+ from petrel_client.client import Client
67
+ from petrel_client.common.config import Config
68
+ has_tcs_loader = True
69
+ except ImportError:
70
+ print('petrel_client is not installed. Using PIL to load images.')
71
+ has_tcs_loader = False
72
+
73
+ # Set constants for image processing and logging
74
+ IGNORE_INDEX = -100
75
+ Image.MAX_IMAGE_PIXELS = None
76
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
77
+ MaximumDecompressedSize = 1024
78
+ MegaByte = 2 ** 20
79
+ PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
80
+
81
+ warnings.filterwarnings('ignore')
82
+ logger = logging.getLogger(__name__)
83
+
84
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
85
+
86
+
87
+ @dataclass
88
+ class ModelArguments:
89
+ """
90
+ Arguments for specifying model, tokenizer, and configurations.
91
+ """
92
+ model_name_or_path: Optional[str] = field(
93
+ default=None,
94
+ metadata={'help': 'Path to a pretrained model (local or from huggingface.co/models).'}
95
+ )
96
+ vision_path: Optional[str] = field(
97
+ default=None,
98
+ metadata={'help': 'Path to a pretrained model (local or from huggingface.co/models).'}
99
+ )
100
+ llm_path: Optional[str] = field(
101
+ default=None,
102
+ metadata={'help': 'Path to a pretrained model (local or from huggingface.co/models).'}
103
+ )
104
+ mlp_path: Optional[str] = field(
105
+ default=None,
106
+ metadata={'help': 'Path to a pretrained model (local or from huggingface.co/models).'}
107
+ )
108
+ freeze_llm: bool = field(
109
+ default=False,
110
+ metadata={'help': 'Set to True to freeze the LLM. Default is False.'},
111
+ )
112
+ freeze_backbone: bool = field(
113
+ default=False,
114
+ metadata={'help': 'Set to True to freeze the ViT. Default is False.'},
115
+ )
116
+ freeze_mlp: bool = field(
117
+ default=False,
118
+ metadata={'help': 'Set to True to freeze the MLP. Default is False.'},
119
+ )
120
+ unfreeze_vit_layers: int = field(
121
+ default=0,
122
+ metadata={'help': 'Specify the number of ViT layers to unfreeze. Default is 0.'},
123
+ )
124
+ vision_select_layer: int = field(
125
+ default=-1,
126
+ metadata={'help': 'Specify the layer of ViT feature map to use. Default is -1 for the last layer.'},
127
+ )
128
+ use_backbone_lora: int = field(
129
+ default=0,
130
+ metadata={'help': 'Set the LoRA adapter rank for the ViT. Default is 0.'}
131
+ )
132
+ use_llm_lora: int = field(
133
+ default=0,
134
+ metadata={'help': 'Set the LoRA adapter rank for the LLM. Default is 0.'}
135
+ )
136
+ unfreeze_lm_head: bool = field(
137
+ default=False,
138
+ metadata={'help': 'Set to True to unfreeze the head of LLM. Default is False.'},
139
+ )
140
+ grad_checkpoint: bool = field(
141
+ default=True,
142
+ metadata={'help': 'Set to True to use gradient checkpointing. Default is True.'},
143
+ )
144
+ drop_path_rate: float = field(
145
+ default=0.0,
146
+ metadata={'help': 'Set the drop path rate for the ViT. Default is 0.'},
147
+ )
148
+ ps_version: Literal['v1', 'v2'] = field(
149
+ default='v2',
150
+ metadata={'help': 'Specify the version of pixel shuffle implementation. Default is v2.'}
151
+ )
152
+ use_fast_tokenizer: bool = field(
153
+ default=False,
154
+ metadata={'help': 'Set to True to use the fast mode of the tokenizer.'}
155
+ )
156
+ use_liger: bool = field(
157
+ default=False,
158
+ metadata={'help': 'Set to True to use the liger kernel.'}
159
+ )
160
+
161
+
162
+ @dataclass
163
+ class DataTrainingArguments:
164
+ """
165
+ Arguments for specifying data input for training and evaluation.
166
+ """
167
+ max_seq_length: int = field(
168
+ default=8192,
169
+ metadata={
170
+ 'help': (
171
+ 'The maximum total input sequence length after tokenization. Sequences longer '
172
+ 'than this will be truncated, sequences shorter will be padded.'
173
+ )
174
+ },
175
+ )
176
+ force_image_size: int = field(
177
+ default=448,
178
+ metadata={'help': 'Set the desired size for the image. Default is 448.'},
179
+ )
180
+ down_sample_ratio: float = field(
181
+ default=0.5,
182
+ metadata={'help': 'Set the desired down-sampling ratio for the image. Default is 0.5.'},
183
+ )
184
+ pad2square: bool = field(
185
+ default=False,
186
+ metadata={'help': 'Pad the image to a square shape if set to True. Default is False.'},
187
+ )
188
+ conv_style: str = field(
189
+ default='internlm2-chat', metadata={'help': 'Prompt style for a conversation.'}
190
+ )
191
+ meta_path: str = field(
192
+ default=None,
193
+ metadata={'help': 'The path of the meta file of datasets.'},
194
+ )
195
+ use_data_resampling: bool = field(
196
+ default=False,
197
+ metadata={'help': 'Set to True to use data resampling. Default is False.'},
198
+ )
199
+ dynamic_image_size: bool = field(
200
+ default=False,
201
+ metadata={'help': 'Set to True to use dynamic high resolution strategy. Default is False.'},
202
+ )
203
+ use_thumbnail: bool = field(
204
+ default=False,
205
+ metadata={'help': 'Set to True to add a thumbnail image. Default is False.'},
206
+ )
207
+ min_dynamic_patch: int = field(
208
+ default=1,
209
+ metadata={'help': 'The minimum number of dynamic patches. Default is 1.'},
210
+ )
211
+ max_dynamic_patch: int = field(
212
+ default=12,
213
+ metadata={'help': 'The maximum number of dynamic patches. Default is 12.'},
214
+ )
215
+ min_num_frame: int = field(
216
+ default=8,
217
+ metadata={'help': 'The minimum number of frames for video data. Default is 8.'},
218
+ )
219
+ max_num_frame: int = field(
220
+ default=32,
221
+ metadata={'help': 'The maximum number of frames for video data. Default is 32.'},
222
+ )
223
+ normalize_type: Literal['imagenet', 'clip', 'siglip'] = field(
224
+ default='imagenet',
225
+ metadata={'help': 'The normalization type for the image. Default is imagenet.'},
226
+ )
227
+ use_packed_ds: bool = field(
228
+ default=False,
229
+ metadata={'help': 'Whether to use packed dataset for efficient training. Default is False.'},
230
+ )
231
+ num_images_expected: int = field(
232
+ default=40,
233
+ metadata={'help': 'The maximum number of images per packed sample. Default is 40.'},
234
+ )
235
+ max_packed_tokens: int = field(
236
+ default=8192,
237
+ metadata={'help': 'The required token length per packed sample. Default is 8192.'},
238
+ )
239
+ max_buffer_size: int = field(
240
+ default=20,
241
+ metadata={'help': 'The buffer size of the packed dataset. Default is 20.'},
242
+ )
243
+ log_freq: int = field(
244
+ default=1000,
245
+ metadata={'help': 'The log frequency of the packed dataset. Default is 1000.'},
246
+ )
247
+ strict_mode: bool = field(
248
+ default=True,
249
+ metadata={'help': 'Whether to pad the number of images to satisfy num_images_expected. Default is True.'},
250
+ )
251
+ replacement: bool = field(
252
+ default=False,
253
+ metadata={'help': 'Whether to restart the dataset after it is exhausted. Default is False.'},
254
+ )
255
+ allow_overflow: bool = field(
256
+ default=False,
257
+ metadata={'help': 'Whether to drop the sample over the specified max_packed_tokens. Default is False.'},
258
+ )
259
+ loss_reduction: str = field(
260
+ default='token',
261
+ metadata={'help': 'Loss reduction method. Default is token.'},
262
+ )
263
+ loss_reduction_all_gather: bool = field(
264
+ default=False,
265
+ metadata={'help': 'Whether to gather all during loss reduction. Default is False.'},
266
+ )
267
+
268
+
269
+ class LazySupervisedDataset(Dataset):
270
+ """Dataset for supervised fine-tuning."""
271
+
272
+ def __init__(
273
+ self,
274
+ template_name,
275
+ meta,
276
+ tokenizer,
277
+ tcs_loader,
278
+ ds_name,
279
+ num_image_token,
280
+ image_size=448,
281
+ is_train=True,
282
+ pad2square=False,
283
+ group_by_length=False,
284
+ dynamic_image_size=False,
285
+ use_thumbnail=False,
286
+ min_dynamic_patch=1,
287
+ max_dynamic_patch=12,
288
+ min_num_frame=8, # for video data
289
+ max_num_frame=32, # for video data
290
+ sampling_method='rand', # for video data
291
+ repeat_time=1,
292
+ normalize_type='imagenet',
293
+ # hyperparameters for packed training
294
+ use_packed_ds=False,
295
+ data_rank=0,
296
+ data_world_size=1,
297
+ distributed_mode=False,
298
+ force_shuffle=False,
299
+ random_seed=0,
300
+ ):
301
+ super(LazySupervisedDataset, self).__init__()
302
+ self.ds_name = ds_name
303
+ self.tokenizer = tokenizer
304
+ self.template_name = template_name
305
+ self.num_image_token = num_image_token
306
+ logger.info(f'[Dataset] num_image_token: {num_image_token}')
307
+ logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}')
308
+ logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}')
309
+ logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}')
310
+
311
+ self.image_size = image_size
312
+ self.is_train = is_train
313
+ self.pad2square = pad2square
314
+ self.max_num_frame = max_num_frame
315
+ self.min_num_frame = min_num_frame
316
+ self.sampling_method = sampling_method
317
+
318
+ # hyperparameters for distributed training
319
+ self.use_packed_ds = use_packed_ds
320
+ self.data_rank = data_rank
321
+ self.data_world_size = data_world_size
322
+ self.worker_id = None
323
+ self.worker_state_key = None
324
+ self.worker_distributed = False
325
+ self.distributed_mode = distributed_mode
326
+ # hyperparameters for packed dataset
327
+ self.dataset_type = 'pair'
328
+ self.max_num_images = 1
329
+ self.max_tokens = tokenizer.model_max_length
330
+ self.force_shuffle = force_shuffle
331
+ # TODO: quick resume
332
+ self._state_dict = {}
333
+
334
+ logger.info('Formatting inputs...Skip in lazy mode')
335
+ assert meta['annotation'].endswith('jsonl'), f'annotation must be jsonl, but got {meta["annotation"]}'
336
+
337
+ with open(meta['annotation'], 'r') as f:
338
+ self.raw_data = f.readlines()
339
+ if repeat_time < 1:
340
+ # If repeat_time is less than 1, select a portion of the data
341
+ self.raw_data = self.raw_data[:int(len(self.raw_data) * repeat_time)]
342
+ if repeat_time > 1:
343
+ assert isinstance(repeat_time, int)
344
+ # Repeat the list if repeat_time is greater than 1
345
+ self.raw_data = self.raw_data * repeat_time
346
+
347
+ self.rng = np.random.default_rng(seed=random_seed)
348
+ if self.force_shuffle:
349
+ self.rng.shuffle(self.raw_data)
350
+
351
+ self.root = meta['root']
352
+ self.cached_data_dict = {}
353
+ self.tcs_loader = tcs_loader
354
+ self.group_by_length = group_by_length
355
+ self.dynamic_image_size = dynamic_image_size
356
+ self.use_thumbnail = use_thumbnail
357
+ self.min_dynamic_patch = min_dynamic_patch
358
+ self.max_dynamic_patch = max_dynamic_patch
359
+ self.normalize_type = normalize_type
360
+
361
+ # If the precomputed length does not exist, roughly estimate the length of
362
+ # each sample to improve the efficiency of group_by_length.
363
+ if self.group_by_length:
364
+ self.conv2length = {} # Using a dictionary to speed up token length calculation
365
+ self.length = []
366
+ for data_item in self.raw_data:
367
+ data_item = json.loads(data_item)
368
+ if 'length' in data_item:
369
+ token_length = data_item['length'] # Use precomputed length if available
370
+ else:
371
+ # Compute token length using the tokenizer
372
+ conversations = '\n'.join([temp['value'] for temp in data_item['conversations']])
373
+ str_length = len(conversations)
374
+ if str_length not in self.conv2length:
375
+ token_length = tokenizer(
376
+ conversations, return_tensors='pt', padding=False, truncation=False,
377
+ ).input_ids.size(1)
378
+ self.conv2length[str_length] = token_length + num_image_token * (
379
+ max_dynamic_patch + use_thumbnail)
380
+ else:
381
+ token_length = self.conv2length[str_length]
382
+ self.length.append(token_length)
383
+
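# Worked example of the rough length estimate above (illustrative numbers): a conversation
# that tokenizes to ~300 tokens, with num_image_token=256, max_dynamic_patch=12 and
# use_thumbnail=True, is recorded as 300 + 256 * (12 + 1) = 3628 tokens for bucketing.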
384
+ def __len__(self):
385
+ return len(self.raw_data)
386
+
387
+ def get_preprocess_function(self):
388
+ # Select the appropriate preprocessing function based on the template name
389
+ if self.template_name == 'Hermes-2':
390
+ preprocess_function = preprocess_mpt
391
+ elif self.template_name == 'internlm2-chat':
392
+ preprocess_function = preprocess_internlm
393
+ elif self.template_name == 'phi3-chat':
394
+ preprocess_function = preprocess_phi3
395
+ elif self.template_name == 'internvl2_5':
396
+ preprocess_function = preprocess_internvl2_5
397
+ else:
398
+ preprocess_function = preprocess
399
+ return preprocess_function
400
+
401
+ def load_image(self, image_path):
402
+ # Load the image using tcs_loader if available, otherwise use PIL
403
+ if self.tcs_loader is not None and 's3://' in image_path:
404
+ return self.tcs_loader(image_path)
405
+ return Image.open(image_path).convert('RGB')
406
+
407
+ def get_image_path(self, image_path):
408
+ if image_path.startswith('s3://'): # for ceph
409
+ image_path = self.root + image_path
410
+ else: # for local image
411
+ image_path = os.path.join(self.root, image_path)
412
+ return image_path
413
+
414
+ def get_transform(self):
415
+ # Build transformation function
416
+ transform = build_transform(is_train=self.is_train, input_size=self.image_size,
417
+ pad2square=self.pad2square, normalize_type=self.normalize_type)
418
+ return transform
419
+
420
+ def multi_modal_get_item(self, data_item):
421
+ # Build transformation function
422
+ transform = self.get_transform()
423
+
424
+ # Ensure the first conversation contains an image placeholder
425
+ if '<image>' not in data_item['conversations'][0]['value']:
426
+ data_item['conversations'][0]['value'] = '<image>\n' + data_item['conversations'][0]['value']
427
+
428
+ # Merge the image path
429
+ image_path = self.get_image_path(data_item['image'])
430
+
431
+ # Load the image using tcs_loader if available, otherwise use PIL
432
+ image = self.load_image(image_path)
433
+
434
+ if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically
435
+ images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch,
436
+ image_size=self.image_size, use_thumbnail=self.use_thumbnail)
437
+ else: # Otherwise, use the original image as a single patch
438
+ images = [image]
439
+
440
+ # Apply the transformation to each image and stack the results into a tensor
441
+ pixel_values = [transform(image) for image in images]
442
+ pixel_values = torch.stack(pixel_values)
443
+
444
+ # Ensure that there is only one patch if dynamic image size is not enabled
445
+ num_patches = pixel_values.size(0)
446
+ if not self.dynamic_image_size:
447
+ assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
448
+
449
+ # Select the appropriate preprocessing function based on the template name
450
+ preprocess_function = self.get_preprocess_function()
451
+
452
+ # Preprocess the conversations and generate the return dictionary
453
+ ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
454
+ self.tokenizer, [self.num_image_token * num_patches],
455
+ group_by_length=self.group_by_length,
456
+ use_packed_ds=self.use_packed_ds, ds_name=self.ds_name)
457
+
458
+ # Calculate position_ids for packed dataset
459
+ position_ids = ret['attention_mask'].long().cumsum(-1) - 1
460
+ position_ids.masked_fill_(ret['attention_mask'] == 0, 1)
461
+ image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
462
+ assert (ret['input_ids'][0] == image_end_token_id).sum() == 1, f'image tokens are truncated, this dataset is {self.ds_name}'
463
+
464
+ # Create the final return dictionary
465
+ ret = dict(
466
+ input_ids=ret['input_ids'][0],
467
+ labels=ret['labels'][0],
468
+ attention_mask=ret['attention_mask'][0],
469
+ position_ids=position_ids[0],
470
+ pixel_values=pixel_values,
471
+ image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
472
+ )
473
+ return ret
474
+
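# Small worked example of the position_ids computation above (values are illustrative):
#   attention_mask            = [1, 1, 1, 0, 0]
#   cumsum(-1) - 1            = [0, 1, 2, 2, 2]
#   masked_fill_(mask==0, 1)  = [0, 1, 2, 1, 1]   # padded positions are pinned to 1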
475
+ def multi_modal_multi_image_get_item(self, data_item):
476
+ # Build transformation function
477
+ transform = self.get_transform()
478
+
479
+ images, num_tiles = [], []
480
+ num_image = len(data_item['image'])
481
+ for image_path in data_item['image']:
482
+ # Merge the image path
483
+ image_path = self.get_image_path(image_path)
484
+ # Load the image using tcs_loader if available, otherwise use PIL
485
+ image = self.load_image(image_path)
486
+ if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically
487
+ image = dynamic_preprocess(image, min_num=self.min_dynamic_patch,
488
+ max_num=max(1, self.max_dynamic_patch // num_image),
489
+ image_size=self.image_size, use_thumbnail=self.use_thumbnail)
490
+ images += image
491
+ num_tiles.append(len(image))
492
+ else: # Otherwise, use the original image as a single patch
493
+ images.append(image)
494
+ num_tiles.append(1)
495
+ pixel_values = [transform(image) for image in images]
496
+ pixel_values = torch.stack(pixel_values)
497
+ num_patches = pixel_values.size(0)
498
+
499
+ # Select the appropriate preprocessing function based on the template name
500
+ preprocess_function = self.get_preprocess_function()
501
+
502
+ # Preprocess the conversations and generate the return dictionary
503
+ num_image_tokens = [self.num_image_token * num_tile for num_tile in num_tiles]
504
+ ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
505
+ self.tokenizer, num_image_tokens, group_by_length=self.group_by_length,
506
+ use_packed_ds=self.use_packed_ds, ds_name=self.ds_name, num_image=num_image)
507
+
508
+ # Calculate position_ids for packed dataset
509
+ position_ids = ret['attention_mask'].long().cumsum(-1) - 1
510
+ position_ids.masked_fill_(ret['attention_mask'] == 0, 1)
511
+ image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
512
+ assert (ret['input_ids'][0] == image_end_token_id).sum() == num_image, f'image tokens are truncated, this dataset is {self.ds_name}'
513
+
514
+ # Create the final return dictionary
515
+ ret = dict(
516
+ input_ids=ret['input_ids'][0],
517
+ labels=ret['labels'][0],
518
+ attention_mask=ret['attention_mask'][0],
519
+ position_ids=position_ids[0],
520
+ pixel_values=pixel_values,
521
+ image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
522
+ )
523
+ return ret
524
+
525
+ def video_get_item(self, data_item):
526
+ # Build transformation function
527
+ transform = self.get_transform()
528
+
529
+ # Ensure the first conversation contains a video placeholder
530
+ if '<video>' not in data_item['conversations'][0]['value']:
531
+ data_item['conversations'][0]['value'] = '<video>\n' + data_item['conversations'][0]['value']
532
+
533
+ # Get the video file path
534
+ video_file = data_item['video']
535
+ video_path = os.path.join(self.root, video_file)
536
+
537
+ # Load the video frames using tcs_loader
538
+ # TODO: Load videos without using tcsloader.
539
+ image_list = self.tcs_loader(
540
+ video_path,
541
+ image_type='video',
542
+ max_num_frames=self.max_num_frame,
543
+ min_num_frames=self.min_num_frame,
544
+ sample=self.sampling_method,
545
+ clip=data_item.get('clip', None))
546
+
547
+ # Generate special tokens for each video frame
548
+ special_tokens = '\n'.join(['Frame-{}: <image>'.format(i + 1) for i in range(len(image_list))])
549
+ data_item['conversations'][0]['value'] = data_item['conversations'][0]['value'].replace(
550
+ '<video>\n', special_tokens + '\n')
551
+
552
+ # Transform each frame image and stack them into a tensor
553
+ pixel_values = [transform(image) for image in image_list]
554
+ pixel_values = torch.stack(pixel_values)
555
+ num_patches = pixel_values.size(0)
556
+
557
+ # Select the appropriate preprocessing function based on the template name
558
+ preprocess_function = self.get_preprocess_function()
559
+
560
+ # Preprocess the conversations and generate the return dictionary
561
+ num_image_tokens = [self.num_image_token] * num_patches
562
+ ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
563
+ self.tokenizer, num_image_tokens, group_by_length=self.group_by_length,
564
+ use_packed_ds=self.use_packed_ds, ds_name=self.ds_name, num_image=num_patches)
565
+
566
+ # Calculate position_ids for packed dataset
567
+ position_ids = ret['attention_mask'].long().cumsum(-1) - 1
568
+ position_ids.masked_fill_(ret['attention_mask'] == 0, 1)
569
+
570
+ # Create the final return dictionary
571
+ ret = dict(
572
+ input_ids=ret['input_ids'][0],
573
+ labels=ret['labels'][0],
574
+ attention_mask=ret['attention_mask'][0],
575
+ position_ids=position_ids[0],
576
+ pixel_values=pixel_values,
577
+ image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
578
+ )
579
+ return ret
580
+
581
+ def pure_text_get_item(self, data_item):
582
+ # Build transformation function
583
+ transform = self.get_transform()
584
+
585
+ # Create a blank white image
586
+ image = Image.new('RGB', (224, 224), (255, 255, 255))
587
+
588
+ # Dynamically preprocess the image to generate patches
589
+ images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=1,
590
+ image_size=self.image_size, use_thumbnail=self.use_thumbnail)
591
+
592
+ # Apply the transformation to each image patch and stack them into a tensor
593
+ pixel_values = [transform(image) for image in images]
594
+ pixel_values = torch.stack(pixel_values)
595
+ num_patches = pixel_values.size(0)
596
+
597
+ # Ensure there is only one patch
598
+ assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
599
+
600
+ # Select the appropriate preprocessing function based on the template name
601
+ preprocess_function = self.get_preprocess_function()
602
+
603
+ # Preprocess the conversations and generate the return dictionary
604
+ ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
605
+ self.tokenizer, [self.num_image_token * num_patches], text_only=True,
606
+ group_by_length=self.group_by_length, use_packed_ds=self.use_packed_ds,
607
+ ds_name=self.ds_name)
608
+
609
+ # Calculate position_ids for packed dataset
610
+ position_ids = ret['attention_mask'].long().cumsum(-1) - 1
611
+ position_ids.masked_fill_(ret['attention_mask'] == 0, 1)
612
+
613
+ # Create the final return dictionary
614
+ ret = dict(
615
+ input_ids=ret['input_ids'][0],
616
+ labels=ret['labels'][0],
617
+ attention_mask=ret['attention_mask'][0],
618
+ position_ids=position_ids[0],
619
+ pixel_values=pixel_values,
620
+ image_flags=torch.tensor([0] * num_patches, dtype=torch.long)
621
+ )
622
+ return ret
623
+
624
+ def _enable_worker_distributed(self):
625
+ if (
626
+ self.distributed_mode
627
+ and not self.worker_distributed
628
+ and self.worker_id is not None
629
+ ):
630
+ self.worker_distributed = True
631
+ self.raw_data = self.raw_data[self.worker_id::self.num_workers]
632
+ logger.info(f'worker_distributed is enabled, {self.num_workers=}, {len(self.raw_data)=}')
633
+
634
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
635
+ if i >= len(self.raw_data):
636
+ if self.use_packed_ds:
637
+ raise NotImplementedError
638
+ else:
639
+ i = i % len(self.raw_data)
640
+
641
+ try_cnt, max_try = 0, 10
642
+ while True:
643
+ if try_cnt > max_try:
644
+ raise StopIteration
645
+ try:
646
+ data_item = json.loads(self.raw_data[i])
647
+ # conversations = data_item['conversations']
648
+ # check_conversations_repetition(conversations, repeat_threshold=0.4, ngram=10)
649
+ if 'image' in data_item and len(data_item['image']) != 0:
650
+ if type(data_item['image']) == list:
651
+ ret = self.multi_modal_multi_image_get_item(data_item)
652
+ else:
653
+ ret = self.multi_modal_get_item(data_item)
654
+ elif 'video' in data_item and data_item['video'] is not None and data_item['video'] != '':
655
+ ret = self.video_get_item(data_item)
656
+ else:
657
+ ret = self.pure_text_get_item(data_item)
658
+ break
659
+ except Exception as e:
660
+ try_cnt += 1
661
+ print(e, self.ds_name, flush=True)
662
+ if not isinstance(e, (UnidentifiedImageError, FileNotFoundError)):
663
+ traceback.print_exc()
664
+ data_item = json.loads(self.raw_data[i])
665
+ if 'image' in data_item:
666
+ if type(data_item['image']) == list:
667
+ images = [self.root + item for item in data_item['image']]
668
+ print(f'Failed to load image: {images}, the dataset is: {self.ds_name}')
669
+ else:
670
+ if data_item['image'].startswith('s3://'):
671
+ data_path = self.root + data_item['image']
672
+ else:
673
+ data_path = os.path.join(self.root, data_item['image'])
674
+ print(f'Failed to load image: {data_path}, the dataset is: {self.ds_name}')
675
+ elif 'video' in data_item:
676
+ data_path = os.path.join(self.root, data_item['video'])
677
+ print(f'Failed to load video: {data_path}, the dataset is: {self.ds_name}')
678
+ i = random.randint(0, len(self.raw_data) - 1)
679
+ return ret
680
+
681
+ def __iter__(self):
682
+ self._enable_worker_distributed()
683
+ start_idx = 0
684
+
685
+ assert self.worker_state_key is not None
686
+ if self.worker_state_key in self._state_dict and len(self._state_dict[self.worker_state_key]) > 0:
687
+ start_idx = self._state_dict[self.worker_state_key]['current_idx']
688
+
689
+ self._state_dict.pop(self.worker_state_key)
690
+
691
+ if self.worker_id == 0:
692
+ logger.info(
693
+ f'[{self.ds_name}] [Worker id {self.worker_id}] '
694
+ f'begin to iter with {start_idx=}'
695
+ )
696
+
697
+ for i in range(start_idx, len(self)):
698
+ yield self[i]
699
+
700
+
701
+ def build_datasets(
702
+ data_args,
703
+ tokenizer,
704
+ tcs_loader,
705
+ model,
706
+ group_by_length=False,
707
+ dynamic_image_size=False,
708
+ use_thumbnail=False,
709
+ min_dynamic_patch=1,
710
+ max_dynamic_patch=12,
711
+ min_num_frame=8,
712
+ max_num_frame=32,
713
+ normalize_type='imagenet',
714
+ ):
715
+ datasets = []
716
+ lengths = []
717
+ data_rank = dist.get_rank()
718
+ data_world_size = dist.get_world_size()
719
+ ds_collections = json.loads(open(data_args.meta_path).read())
720
+ for ds_idx, ds_name in enumerate(ds_collections.keys()):
721
+ repeat_time = ds_collections[ds_name]['repeat_time']
722
+ if 'max_dynamic_patch' in ds_collections[ds_name]:
723
+ max_num = ds_collections[ds_name]['max_dynamic_patch']
724
+ logger.info(f'max_dynamic_patch is set to {max_num} according to the meta file')
725
+ else:
726
+ max_num = max_dynamic_patch
727
+ dataset = LazySupervisedDataset(
728
+ data_args.conv_style, ds_collections[ds_name],
729
+ tokenizer,
730
+ tcs_loader,
731
+ ds_name=ds_name,
732
+ num_image_token=model.num_image_token,
733
+ image_size=data_args.force_image_size,
734
+ is_train=ds_collections[ds_name]['data_augment'],
735
+ pad2square=data_args.pad2square,
736
+ group_by_length=group_by_length and not data_args.use_packed_ds,
737
+ dynamic_image_size=dynamic_image_size,
738
+ use_thumbnail=use_thumbnail,
739
+ min_dynamic_patch=min_dynamic_patch,
740
+ max_dynamic_patch=max_num,
741
+ min_num_frame=min_num_frame,
742
+ max_num_frame=max_num_frame,
743
+ repeat_time=repeat_time,
744
+ normalize_type=normalize_type,
745
+ # hyperparameters for packed training
746
+ use_packed_ds=data_args.use_packed_ds,
747
+ data_rank=data_rank,
748
+ data_world_size=data_world_size,
749
+ distributed_mode=data_args.use_packed_ds,
750
+ force_shuffle=data_args.use_packed_ds,
751
+ random_seed=ds_idx,
752
+ )
753
+ logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}')
754
+ datasets.append(dataset)
755
+ if data_args.use_data_resampling:
756
+ lengths.append(math.sqrt(len(dataset)))
757
+ else:
758
+ lengths.append(len(dataset))
759
+
760
+ if data_args.use_packed_ds:
761
+ total_length = sum(lengths)
762
+ train_dataset = PackedDataset(
763
+ tokenizer=tokenizer,
764
+ data_rank=data_rank,
765
+ data_world_size=data_world_size,
766
+ datasets=datasets,
767
+ dataset_weight=[l / total_length for l in lengths],
768
+ num_images_expected=data_args.num_images_expected,
769
+ max_packed_tokens=data_args.max_packed_tokens,
770
+ max_buffer_size=data_args.max_buffer_size,
771
+ log_freq=data_args.log_freq,
772
+ strict_mode=data_args.strict_mode,
773
+ replacement=data_args.replacement,
774
+ allow_overflow=data_args.allow_overflow,
775
+ allow_deduplicated_ds_name=False,
776
+ )
777
+ elif data_args.use_data_resampling:
778
+ total_length = sum(lengths)
779
+ weights = [l / total_length for l in lengths]
780
+ train_dataset = WeightedConcatDataset(datasets, weights)
781
+ else:
782
+ train_dataset = ConcatDataset(datasets)
783
+ return train_dataset
784
+
785
+
786
+ def len2weight(x, loss_reduction):
787
+ if x == 0:
788
+ return x
789
+ if loss_reduction == 'token':
790
+ return 1
791
+ if loss_reduction == 'sample':
792
+ return 1 / x
793
+ if loss_reduction == 'square':
794
+ return 1 / (x ** 0.5)
795
+ raise NotImplementedError(loss_reduction)
796
+
797
+
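# Worked example of len2weight for a sample with x = 16 supervised tokens:
#   loss_reduction='token'  -> 1              (every token weighted equally)
#   loss_reduction='sample' -> 1/16           (each sample contributes equally)
#   loss_reduction='square' -> 1/sqrt(16) = 0.25  (a compromise between the two)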
798
+ def main():
799
+ # Apply necessary patches for the transformers library
800
+ replace_llama_rmsnorm_with_fused_rmsnorm()
801
+ replace_train_sampler()
802
+ replace_train_dataloader()
803
+
804
+ # Parse input arguments
805
+ # See all possible arguments in src/transformers/training_args.py
806
+ # If using DeepSpeed ZeRO-3, init_dist must be called before HfArgumentParser
807
+ launcher = os.environ.get('LAUNCHER', 'slurm')
808
+ init_dist(launcher=launcher, backend='nccl')
809
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
810
+ if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
811
+ # If we pass only one argument to the script, and it's the path to a json file,
812
+ # let's parse it to get our arguments.
813
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
814
+ else:
815
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
816
+
817
+ training_args.use_packed_ds = data_args.use_packed_ds
818
+
819
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
820
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
821
+ # send_example_telemetry('InternV-Chat', model_args, data_args)
822
+
823
+ # Setup logging
824
+ logging.basicConfig(
825
+ format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
826
+ datefmt='%m/%d/%Y %H:%M:%S',
827
+ handlers=[logging.StreamHandler(sys.stdout)],
828
+ )
829
+
830
+ if training_args.should_log:
831
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
832
+ transformers.utils.logging.set_verbosity_info()
833
+
834
+ log_level = training_args.get_process_log_level()
835
+ logger.setLevel(log_level)
836
+ set_verbosity(log_level)
837
+ enable_default_handler()
838
+ enable_explicit_format()
839
+
840
+ # Log on each process the small summary:
841
+ logger.warning(
842
+ f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}'
843
+ + f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
844
+ )
845
+ logger.info(f'Training/evaluation parameters {training_args}')
846
+
847
+ # Detect the last checkpoint and, if one exists, resume training from it.
848
+ last_checkpoint = None
849
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
850
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
851
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
852
+ raise ValueError(
853
+ f'Output directory ({training_args.output_dir}) already exists and is not empty. '
854
+ 'Use --overwrite_output_dir to overcome.'
855
+ )
856
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
857
+ logger.info(
858
+ f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change '
859
+ 'the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
860
+ )
861
+ # Set seed before initializing model.
862
+ set_seed(training_args.seed)
863
+
864
+ # Load pretrained model, tokenizer, and image processor
865
+ tokenizer_path = model_args.model_name_or_path or model_args.llm_path
866
+ logger.info(f'Loading Tokenizer: {tokenizer_path}')
867
+ tokenizer = AutoTokenizer.from_pretrained(
868
+ tokenizer_path, add_eos_token=False, trust_remote_code=True, use_fast=model_args.use_fast_tokenizer)
869
+ tokenizer.tokenizer_path = tokenizer_path
870
+ tokenizer.model_max_length = data_args.max_seq_length
871
+ token_list = [IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN,
872
+ QUAD_START_TOKEN, QUAD_END_TOKEN, REF_START_TOKEN,
873
+ REF_END_TOKEN, BOX_START_TOKEN, BOX_END_TOKEN]
874
+ num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
875
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
876
+ tcs_loader = TCSLoader('~/petreloss.conf') if has_tcs_loader else None
877
+
878
+ if data_args.use_packed_ds:
879
+ replace_internlm2_attention_class()
880
+ replace_qwen2_attention_class()
881
+ replace_phi3_attention_class()
882
+ replace_llama_attention_class()
883
+
884
+ if model_args.use_liger:
885
+ from internvl.patch import apply_liger_kernel_to_internvit
886
+ from liger_kernel.transformers import (apply_liger_kernel_to_llama,
887
+ apply_liger_kernel_to_qwen2)
888
+ apply_liger_kernel_to_llama()
889
+ apply_liger_kernel_to_qwen2()
890
+ # apply_liger_kernel_to_internvit()
891
+
892
+ if model_args.model_name_or_path is not None:
893
+ logger.info('Loading InternVLChatModel...')
894
+ config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
895
+ config.vision_config.drop_path_rate = model_args.drop_path_rate
896
+ if config.llm_config.model_type == 'internlm2':
897
+ config.llm_config.attn_implementation = 'flash_attention_2' # for InternLM
898
+ logger.info('Using flash_attention_2 for InternLM')
899
+ else:
900
+ config.llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
901
+ logger.info('Using flash_attention_2 for LLaMA')
902
+ config.template = data_args.conv_style
903
+ config.select_layer = model_args.vision_select_layer
904
+ config.dynamic_image_size = data_args.dynamic_image_size
905
+ config.use_thumbnail = data_args.use_thumbnail
906
+ config.ps_version = model_args.ps_version
907
+ config.min_dynamic_patch = data_args.min_dynamic_patch
908
+ config.max_dynamic_patch = data_args.max_dynamic_patch
909
+ model = InternVLChatModel.from_pretrained(
910
+ model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config)
911
+ else:
912
+ logger.info('Loading ViT-6B...')
913
+ vision_config = InternVisionConfig.from_pretrained(model_args.vision_path)
914
+ vision_config.drop_path_rate = model_args.drop_path_rate
915
+ vision_model = InternVisionModel.from_pretrained(
916
+ model_args.vision_path, torch_dtype=torch.bfloat16, config=vision_config)
917
+ logger.info('Loading LLaMA...')
918
+ llm_config = AutoConfig.from_pretrained(model_args.llm_path, trust_remote_code=True)
919
+ if llm_config.model_type == 'internlm2':
920
+ model_type = InternLM2ForCausalLM
921
+ llm_config.attn_implementation = 'flash_attention_2' # for InternLM
922
+ logger.info('Using flash_attention_2 for InternLM')
923
+ else:
924
+ model_type = AutoModelForCausalLM
925
+ llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
926
+ logger.info('Using flash_attention_2 for LLaMA')
927
+ llm = model_type.from_pretrained(
928
+ model_args.llm_path, torch_dtype=torch.bfloat16,
929
+ config=llm_config, trust_remote_code=True)
930
+ logger.info('Building InternVLChatConfig...')
931
+ internvl_chat_config = InternVLChatConfig(
932
+ vision_config.to_dict(), llm_config.to_dict(), downsample_ratio=data_args.down_sample_ratio,
933
+ pad2square=data_args.pad2square, template=data_args.conv_style,
934
+ select_layer=model_args.vision_select_layer, dynamic_image_size=data_args.dynamic_image_size,
935
+ use_thumbnail=data_args.use_thumbnail, ps_version=model_args.ps_version,
936
+ min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch)
937
+ internvl_chat_config.force_image_size = data_args.force_image_size
938
+ logger.info('Building InternVLChatModel...')
939
+ model = InternVLChatModel(internvl_chat_config, vision_model, llm)
940
+ model.img_context_token_id = img_context_token_id
941
+
942
+ assert model.config.downsample_ratio == data_args.down_sample_ratio
943
+
944
+ if model_args.mlp_path is not None:
945
+ logger.info('Loading pretrained MLP projector...')
946
+ state_dict = torch.load(model_args.mlp_path, map_location='cpu')
947
+ message = model.mlp1.load_state_dict(state_dict)
948
+ logger.info(message)
949
+ logger.info('Finished')
950
+
951
+ patch_size = model.config.vision_config.patch_size
952
+ logger.info(f'model.config.force_image_size: {model.config.force_image_size}')
953
+ logger.info(f'data_args.force_image_size: {data_args.force_image_size}')
954
+ logger.info(f'model.config.vision_config.image_size: {model.config.vision_config.image_size}')
955
+ if model.config.vision_config.image_size != data_args.force_image_size:
956
+ logger.info(f'Resizing position embedding from '
957
+ f'{model.config.vision_config.image_size} '
958
+ f'to {data_args.force_image_size}...')
959
+ model.vision_model.resize_pos_embeddings(old_size=model.config.vision_config.image_size,
960
+ new_size=data_args.force_image_size,
961
+ patch_size=patch_size)
962
+ model.config.vision_config.image_size = data_args.force_image_size
963
+ model.config.force_image_size = data_args.force_image_size
964
+ model.num_image_token = int((data_args.force_image_size // patch_size) ** 2 * (data_args.down_sample_ratio ** 2))
965
+
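# Worked example of the token count above (assuming the InternViT patch size
# of 14): with force_image_size=448 and down_sample_ratio=0.5, each 448x448
# tile yields (448 // 14) ** 2 = 1024 patch features, and pixel-shuffle
# downsampling keeps 1024 * 0.5 ** 2 = 256 visual tokens per tile.
assert int((448 // 14) ** 2 * 0.5 ** 2) == 256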
966
+ if num_new_tokens > 0:
967
+ model.language_model.resize_token_embeddings(len(tokenizer))
968
+ output_embeddings = model.language_model.get_output_embeddings().weight.data
969
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
970
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
971
+
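# Minimal sketch of the initialization above: embedding rows for the newly
# added special tokens are set to the mean of the pre-existing rows, so new
# tokens start from an "average token" rather than arbitrary values.
import torch
old = torch.randn(8, 4)                                    # toy embedding matrix
num_new = 2
resized = torch.cat([old, torch.empty(num_new, 4)], dim=0)
resized[-num_new:] = resized[:-num_new].mean(dim=0, keepdim=True)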
972
+ model.config.llm_config.vocab_size = len(tokenizer)
973
+ model.language_model.config.vocab_size = len(tokenizer)
974
+
975
+ model.language_model.config.use_cache = False
976
+ model.vision_model.gradient_checkpointing = True
977
+ model.vision_model.encoder.gradient_checkpointing = True
978
+ if model_args.grad_checkpoint:
979
+ model.language_model._set_gradient_checkpointing()
980
+
981
+ train_dataset = build_datasets(
982
+ data_args, tokenizer, tcs_loader, model, group_by_length=training_args.group_by_length,
983
+ dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail,
984
+ min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch,
985
+ normalize_type=data_args.normalize_type, min_num_frame=data_args.min_num_frame,
986
+ max_num_frame=data_args.max_num_frame)
987
+
988
+ def _freeze_params(module):
989
+ for param in module.parameters():
990
+ param.requires_grad = False
991
+
992
+ if model_args.freeze_backbone:
993
+ # model.vision_model = model.vision_model.eval()
994
+ _freeze_params(model.vision_model)
995
+
996
+ if model_args.freeze_llm:
997
+ model.language_model = model.language_model.eval()
998
+ _freeze_params(model.language_model)
999
+
1000
+ if model_args.unfreeze_lm_head:
1001
+ model.language_model.lm_head.requires_grad = True
1002
+
1003
+ if model_args.use_backbone_lora:
1004
+ model.wrap_backbone_lora(r=model_args.use_backbone_lora, lora_alpha=2 * model_args.use_backbone_lora)
1005
+ model.config.use_backbone_lora = model_args.use_backbone_lora
1006
+
1007
+ if model_args.use_llm_lora:
1008
+ model.wrap_llm_lora(r=model_args.use_llm_lora, lora_alpha=2 * model_args.use_llm_lora)
1009
+ model.config.use_llm_lora = model_args.use_llm_lora
1010
+
1011
+ if model_args.freeze_mlp:
1012
+ _freeze_params(model.mlp1)
1013
+
1014
+ if model_args.unfreeze_vit_layers != 0:
1015
+ layers = model.vision_model.encoder.layers[model_args.unfreeze_vit_layers:]
1016
+ for k, v in layers.named_parameters():
1017
+ logger.info(f'Unfreezing ViT layer: {k}')
1018
+ v.requires_grad = True
1019
+
1020
+ # print trainable parameters
1021
+ if dist.get_rank() == 0:
1022
+ for name, param in model.named_parameters():
1023
+ if param.requires_grad:
1024
+ logger.info(name)
1025
+
1026
+ # set seed for torch dataloaders
1027
+ set_seed(training_args.seed)
1028
+
1029
+ if data_args.use_packed_ds:
1030
+ collator = partial(
1031
+ packed_collate_fn,
1032
+ data_collator=concat_pad_data_collator,
1033
+ max_item_length=data_args.max_packed_tokens if data_args.strict_mode else 0,
1034
+ micro_num=training_args.train_batch_size,
1035
+ len2weight=partial(len2weight, loss_reduction=data_args.loss_reduction),
1036
+ loss_reduction_all_gather=data_args.loss_reduction_all_gather,
1037
+ )
1038
+ else:
1039
+ collator = concat_pad_data_collator
1040
+
1041
+ trainer = Trainer(
1042
+ model=model,
1043
+ args=training_args,
1044
+ train_dataset=train_dataset if training_args.do_train else None,
1045
+ eval_dataset=None,
1046
+ tokenizer=tokenizer,
1047
+ data_collator=collator,
1048
+ )
1049
+
1050
+ # Training
1051
+ if training_args.do_train:
1052
+ checkpoint = None
1053
+ if training_args.resume_from_checkpoint is not None:
1054
+ checkpoint = training_args.resume_from_checkpoint
1055
+ elif last_checkpoint is not None:
1056
+ checkpoint = last_checkpoint
1057
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
1058
+ trainer.save_model() # Saves the tokenizer too for easy upload
1059
+
1060
+ metrics = train_result.metrics
1061
+ try:
1062
+ metrics['train_samples'] = len(train_dataset)
1063
+ except Exception:
1064
+ metrics['train_samples'] = -1
1065
+
1066
+ trainer.log_metrics('train', metrics)
1067
+ trainer.save_metrics('train', metrics)
1068
+ trainer.save_state()
1069
+
1070
+
1071
+ if __name__ == '__main__':
1072
+ main()
src/third_party/InternVL/internvl_chat/internvl/train/internvl_chat_pretrain.py ADDED
@@ -0,0 +1,1116 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import logging
8
+ import math
9
+ import os
10
+ import random
11
+ import sys
12
+ import traceback
13
+ import warnings
14
+ from copy import deepcopy
15
+ from dataclasses import dataclass, field
16
+ from functools import partial
17
+ from typing import Dict, Literal, Optional
18
+
19
+ import numpy as np
20
+
21
+ try:
22
+ import orjson as json
23
+ except ImportError:
24
+ import json
25
+
26
+ import torch
27
+ import torch.distributed as dist
28
+ import transformers
29
+ from internvl.dist_utils import init_dist
30
+ from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
31
+ from internvl.model.internvl_chat import (InternVisionConfig,
32
+ InternVisionModel,
33
+ InternVLChatConfig,
34
+ InternVLChatModel)
35
+ from internvl.patch import (concat_pad_data_collator,
36
+ replace_internlm2_attention_class,
37
+ replace_llama_attention_class,
38
+ replace_llama_rmsnorm_with_fused_rmsnorm,
39
+ replace_phi3_attention_class,
40
+ replace_qwen2_attention_class,
41
+ replace_train_dataloader, replace_train_sampler)
42
+ from internvl.train.constants import (BOX_END_TOKEN, BOX_START_TOKEN,
43
+ IMG_CONTEXT_TOKEN, IMG_END_TOKEN,
44
+ IMG_START_TOKEN, QUAD_END_TOKEN,
45
+ QUAD_START_TOKEN, REF_END_TOKEN,
46
+ REF_START_TOKEN)
47
+ from internvl.train.dataset import (ConcatDataset, TCSLoader,
48
+ WeightedConcatDataset, build_transform,
49
+ check_conversations_repetition,
50
+ dynamic_preprocess, preprocess,
51
+ preprocess_internlm,
52
+ preprocess_internvl2_5, preprocess_mpt,
53
+ preprocess_phi3)
54
+ from internvl.train.dataset_packed import PackedDataset, packed_collate_fn
55
+ from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError
56
+ from torch.utils.data import Dataset
57
+ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
58
+ HfArgumentParser, Trainer, TrainingArguments,
59
+ set_seed)
60
+ from transformers.trainer_utils import get_last_checkpoint
61
+ from transformers.utils.logging import (enable_default_handler,
62
+ enable_explicit_format, set_verbosity)
63
+
64
+ # Try to import petrel_client for image loading, fallback to PIL if unavailable
65
+ try:
66
+ from petrel_client.client import Client
67
+ from petrel_client.common.config import Config
68
+ has_tcs_loader = True
69
+ except ImportError:
70
+ print('petrel_client is not installed. Using PIL to load images.')
71
+ has_tcs_loader = False
72
+
73
+ # Set constants for image processing and logging
74
+ IGNORE_INDEX = -100
75
+ Image.MAX_IMAGE_PIXELS = None
76
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
77
+ MaximumDecompressedSize = 1024
78
+ MegaByte = 2 ** 20
79
+ PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
80
+
81
+ warnings.filterwarnings('ignore')
82
+ logger = logging.getLogger(__name__)
83
+
84
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
85
+
86
+
87
+ @dataclass
88
+ class ModelArguments:
89
+ """
90
+ Arguments for specifying model, tokenizer, and configurations.
91
+ """
92
+ model_name_or_path: Optional[str] = field(
93
+ default=None,
94
+ metadata={'help': 'Path to a pretrained model (local or from huggingface.co/models).'}
95
+ )
96
+ vision_path: Optional[str] = field(
97
+ default=None,
98
+ metadata={'help': 'Path to a pretrained vision backbone (local or from huggingface.co/models).'}
99
+ )
100
+ llm_path: Optional[str] = field(
101
+ default=None,
102
+ metadata={'help': 'Path to a pretrained LLM (local or from huggingface.co/models).'}
103
+ )
104
+ mlp_path: Optional[str] = field(
105
+ default=None,
106
+ metadata={'help': 'Path to a pretrained MLP projector checkpoint (loaded with torch.load).'}
107
+ )
108
+ freeze_llm: bool = field(
109
+ default=False,
110
+ metadata={'help': 'Set to True to freeze the LLM. Default is False.'},
111
+ )
112
+ freeze_backbone: bool = field(
113
+ default=False,
114
+ metadata={'help': 'Set to True to freeze the ViT. Default is False.'},
115
+ )
116
+ freeze_mlp: bool = field(
117
+ default=False,
118
+ metadata={'help': 'Set to True to freeze the MLP. Default is False.'},
119
+ )
120
+ unfreeze_vit_layers: int = field(
121
+ default=0,
122
+ metadata={'help': 'Specify the number of ViT layers to unfreeze. Default is 0.'},
123
+ )
124
+ vision_select_layer: int = field(
125
+ default=-1,
126
+ metadata={'help': 'Specify the layer of ViT feature map to use. Default is -1 for the last layer.'},
127
+ )
128
+ use_backbone_lora: int = field(
129
+ default=0,
130
+ metadata={'help': 'Set the LoRA adapter rank for the ViT. Default is 0.'}
131
+ )
132
+ use_llm_lora: int = field(
133
+ default=0,
134
+ metadata={'help': 'Set the LoRA adapter rank for the LLM. Default is 0.'}
135
+ )
136
+ unfreeze_lm_head: bool = field(
137
+ default=False,
138
+ metadata={'help': 'Set to True to unfreeze the LM head of the LLM. Default is False.'},
139
+ )
140
+ grad_checkpoint: bool = field(
141
+ default=True,
142
+ metadata={'help': 'Set to True to use gradient checkpointing. Default is True.'},
143
+ )
144
+ drop_path_rate: float = field(
145
+ default=0.0,
146
+ metadata={'help': 'Set the drop path rate for the ViT. Default is 0.'},
147
+ )
148
+ ps_version: Literal['v1', 'v2'] = field(
149
+ default='v2',
150
+ metadata={'help': 'Specify the version of pixel shuffle implementation. Default is v2.'}
151
+ )
152
+ use_fast_tokenizer: bool = field(
153
+ default=False,
154
+ metadata={'help': 'Set to True to use the fast mode of the tokenizer.'}
155
+ )
156
+ use_liger: bool = field(
157
+ default=False,
158
+ metadata={'help': 'Set to True to use the liger kernel.'}
159
+ )
160
+
161
+
162
+ @dataclass
163
+ class DataTrainingArguments:
164
+ """
165
+ Arguments for specifying data input for training and evaluation.
166
+ """
167
+ max_seq_length: int = field(
168
+ default=8192,
169
+ metadata={
170
+ 'help': (
171
+ 'The maximum total input sequence length after tokenization. Sequences longer '
172
+ 'than this will be truncated, sequences shorter will be padded.'
173
+ )
174
+ },
175
+ )
176
+ force_image_size: int = field(
177
+ default=448,
178
+ metadata={'help': 'Set the desired size for the image. Default is 448.'},
179
+ )
180
+ down_sample_ratio: float = field(
181
+ default=0.5,
182
+ metadata={'help': 'Set the desired down-sampling ratio for the image. Default is 0.5.'},
183
+ )
184
+ pad2square: bool = field(
185
+ default=False,
186
+ metadata={'help': 'Pad the image to a square shape if set to True. Default is False.'},
187
+ )
188
+ conv_style: str = field(
189
+ default='internlm2-chat', metadata={'help': 'Prompt style for a conversation.'}
190
+ )
191
+ meta_path: str = field(
192
+ default=None,
193
+ metadata={'help': 'The path of the meta file of datasets.'},
194
+ )
195
+ use_data_resampling: bool = field(
196
+ default=False,
197
+ metadata={'help': 'Set to True to use data resampling. Default is False.'},
198
+ )
199
+ dynamic_image_size: bool = field(
200
+ default=False,
201
+ metadata={'help': 'Set to True to use dynamic high resolution strategy. Default is False.'},
202
+ )
203
+ use_thumbnail: bool = field(
204
+ default=False,
205
+ metadata={'help': 'Set to True to add a thumbnail image. Default is False.'},
206
+ )
207
+ min_dynamic_patch: int = field(
208
+ default=1,
209
+ metadata={'help': 'The minimum number of dynamic patches. Default is 1.'},
210
+ )
211
+ max_dynamic_patch: int = field(
212
+ default=12,
213
+ metadata={'help': 'The maximum number of dynamic patches. Default is 12.'},
214
+ )
215
+ min_num_frame: int = field(
216
+ default=8,
217
+ metadata={'help': 'The minimum number of frames for video data. Default is 8.'},
218
+ )
219
+ max_num_frame: int = field(
220
+ default=32,
221
+ metadata={'help': 'The maximum number of frames for video data. Default is 32.'},
222
+ )
223
+ normalize_type: Literal['imagenet', 'clip', 'siglip'] = field(
224
+ default='imagenet',
225
+ metadata={'help': 'The normalization type for the image. Default is imagenet.'},
226
+ )
227
+ use_packed_ds: bool = field(
228
+ default=False,
229
+ metadata={'help': 'Whether to use packed dataset for efficient training. Default is False.'},
230
+ )
231
+ num_images_expected: int = field(
232
+ default=40,
233
+ metadata={'help': 'The maximum number of images per packed sample. Default is 40.'},
234
+ )
235
+ max_packed_tokens: int = field(
236
+ default=8192,
237
+ metadata={'help': 'The required token length per packed sample. Default is 8192.'},
238
+ )
239
+ max_buffer_size: int = field(
240
+ default=20,
241
+ metadata={'help': 'The buffer size of the packed dataset. Default is 20.'},
242
+ )
243
+ log_freq: int = field(
244
+ default=1000,
245
+ metadata={'help': 'The log frequency of the packed dataset. Default is 1000.'},
246
+ )
247
+ strict_mode: bool = field(
248
+ default=True,
249
+ metadata={'help': 'Whether to pad the number of images to satisfy num_images_expected. Default is True.'},
250
+ )
251
+ replacement: bool = field(
252
+ default=False,
253
+ metadata={'help': 'Whether to restart the dataset after it is exhausted. Default is False.'},
254
+ )
255
+ allow_overflow: bool = field(
256
+ default=False,
257
+ metadata={'help': 'Whether to drop samples that exceed the specified max_packed_tokens. Default is False.'},
258
+ )
259
+ loss_reduction: str = field(
260
+ default='token',
261
+ metadata={'help': 'Loss reduction method. Default is token.'},
262
+ )
263
+ loss_reduction_all_gather: bool = field(
264
+ default=False,
265
+ metadata={'help': 'Whether to all-gather across ranks during loss reduction. Default is False.'},
266
+ )
267
+
268
+
269
+ class LazySupervisedDataset(Dataset):
270
+ """Dataset for supervised fine-tuning."""
271
+
272
+ def __init__(
273
+ self,
274
+ template_name,
275
+ meta,
276
+ tokenizer,
277
+ tcs_loader,
278
+ ds_name,
279
+ num_image_token,
280
+ image_size=448,
281
+ is_train=True,
282
+ pad2square=False,
283
+ group_by_length=False,
284
+ dynamic_image_size=False,
285
+ use_thumbnail=False,
286
+ min_dynamic_patch=1,
287
+ max_dynamic_patch=12,
288
+ min_num_frame=8, # for video data
289
+ max_num_frame=32, # for video data
290
+ sampling_method='rand', # for video data
291
+ repeat_time=1,
292
+ normalize_type='imagenet',
293
+ # hyperparameters for packed training
294
+ use_packed_ds=False,
295
+ data_rank=0,
296
+ data_world_size=1,
297
+ distributed_mode=False,
298
+ force_shuffle=False,
299
+ random_seed=0,
300
+ ):
301
+ super(LazySupervisedDataset, self).__init__()
302
+ self.ds_name = ds_name
303
+ self.tokenizer = tokenizer
304
+ self.template_name = template_name
305
+ self.num_image_token = num_image_token
306
+ logger.info(f'[Dataset] num_image_token: {num_image_token}')
307
+ logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}')
308
+ logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}')
309
+ logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}')
310
+
311
+ self.image_size = image_size
312
+ self.is_train = is_train
313
+ self.pad2square = pad2square
314
+ self.max_num_frame = max_num_frame
315
+ self.min_num_frame = min_num_frame
316
+ self.sampling_method = sampling_method
317
+
318
+ # hyperparameters for distributed training
319
+ self.use_packed_ds = use_packed_ds
320
+ self.data_rank = data_rank
321
+ self.data_world_size = data_world_size
322
+ self.worker_id = None
323
+ self.worker_state_key = None
324
+ self.worker_distributed = False
325
+ self.distributed_mode = distributed_mode
326
+ # hyperparameters for packed dataset
327
+ self.dataset_type = 'pair'
328
+ self.max_num_images = 1
329
+ self.max_tokens = tokenizer.model_max_length
330
+ self.force_shuffle = force_shuffle
331
+ # TODO: quick resume
332
+ self._state_dict = {}
333
+
334
+ logger.info('Formatting inputs... skipped in lazy mode')
335
+ assert meta['annotation'].endswith('jsonl'), f'annotation must be jsonl, but got {meta["annotation"]}'
336
+
337
+ total_ranks = torch.distributed.get_world_size()
338
+ self.total_ranks = total_ranks
339
+ current_rank = torch.distributed.get_rank()
340
+
341
+ """
342
+ This block reads annotation files that can contain hundreds of millions of entries.
343
+ By caching a per-rank shard of the data, it keeps reading fast and avoids
344
+ running out of memory.
345
+ """
346
+ # Create a cache directory path
347
+ basename = os.path.basename(meta['annotation']).replace('.jsonl', '')
348
+ data_dir = os.path.join(os.path.dirname(meta['annotation']), f'{basename}_temp')
349
+ os.makedirs(data_dir, exist_ok=True) # Create the cache directory if it does not exist
350
+ # Create a temporary path for the current rank
351
+ temp_path = os.path.join(data_dir, f'{basename}_{current_rank}_of_{total_ranks}.jsonl')
352
+
353
+ # Check if the temporary file for the current rank already exists
354
+ if os.path.exists(temp_path):
355
+ # If it exists, read the raw data from the file
356
+ with open(temp_path, 'r') as f:
357
+ self.raw_data = f.readlines()
358
+ else:
359
+ # If it does not exist, read the raw data from the original annotation file
360
+ with open(meta['annotation'], 'r') as f:
361
+ self.raw_data = f.readlines()
362
+
363
+ # Adjust the raw data based on the repeat_time parameter
364
+ if repeat_time < 1:
365
+ self.raw_data = self.raw_data[:int(len(self.raw_data) * repeat_time)]
366
+ else:
367
+ self.raw_data = self.raw_data * int(repeat_time)
368
+
369
+ # Calculate the total number of lines and distribute lines to each rank
370
+ total_lines = len(self.raw_data)
371
+ logger.info(f'total_ranks: {total_ranks}, current_rank: {current_rank}, total_lines: {total_lines}')
372
+ lines_per_rank = total_lines // total_ranks # Number of lines each rank should process
373
+ lines_per_rank = max(1, lines_per_rank)
374
+
375
+ # Calculate the start and end line numbers for the current rank
376
+ start_line = lines_per_rank * current_rank # Starting line for the current rank
377
+ end_line = start_line + lines_per_rank # Ending line for the current rank
378
+
379
+ # Assign the appropriate lines to the current rank
380
+ self.raw_data = self.raw_data[start_line:end_line]
381
+
382
+ # Write the raw data for the current rank to the temporary file
383
+ with open(temp_path, 'w') as f:
384
+ f.writelines(self.raw_data)
385
+
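# Illustrative sketch of the per-rank split above, assuming 10 annotation
# lines and 4 ranks: each rank keeps floor(10 / 4) = 2 consecutive lines
# ([0:2], [2:4], [4:6], [6:8]), and the trailing lines 8 and 9 are dropped
# by the integer division.
total_lines, total_ranks = 10, 4
lines_per_rank = max(1, total_lines // total_ranks)
for rank in range(total_ranks):
    print(rank, list(range(total_lines))[rank * lines_per_rank:(rank + 1) * lines_per_rank])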
386
+ self.rng = np.random.default_rng(seed=random_seed)
387
+ if self.force_shuffle:
388
+ self.rng.shuffle(self.raw_data)
389
+
390
+ self.root = meta['root']
391
+ self.cached_data_dict = {}
392
+ self.tcs_loader = tcs_loader
393
+ self.group_by_length = group_by_length
394
+ self.dynamic_image_size = dynamic_image_size
395
+ self.use_thumbnail = use_thumbnail
396
+ self.min_dynamic_patch = min_dynamic_patch
397
+ self.max_dynamic_patch = max_dynamic_patch
398
+ self.normalize_type = normalize_type
399
+
400
+ assert not group_by_length
401
+ # If the precomputed length does not exist, roughly estimate the length of
402
+ # each sample to improve the efficiency of group_by_length.
403
+ if self.group_by_length:
404
+ self.conv2length = {} # Using a dictionary to speed up token length calculation
405
+ self.length = []
406
+ for data_item in self.raw_data:
407
+ data_item = json.loads(data_item)
408
+ if 'length' in data_item:
409
+ token_length = data_item['length'] # Use precomputed length if available
410
+ else:
411
+ # Compute token length using the tokenizer
412
+ conversations = '\n'.join([temp['value'] for temp in data_item['conversations']])
413
+ str_length = len(conversations)
414
+ if str_length not in self.conv2length:
415
+ token_length = tokenizer(
416
+ conversations, return_tensors='pt', padding=False, truncation=False,
417
+ ).input_ids.size(1)
418
+ self.conv2length[str_length] = token_length + num_image_token * (
419
+ max_dynamic_patch + use_thumbnail)
420
+ else:
421
+ token_length = self.conv2length[str_length]
422
+ self.length.append(token_length)
423
+
424
+ def __len__(self):
425
+ if not self.use_packed_ds:
426
+ return len(self.raw_data) * self.total_ranks
427
+ else:
428
+ return len(self.raw_data)
429
+
430
+ def get_preprocess_function(self):
431
+ # Select the appropriate preprocessing function based on the template name
432
+ if self.template_name == 'Hermes-2':
433
+ preprocess_function = preprocess_mpt
434
+ elif self.template_name == 'internlm2-chat':
435
+ preprocess_function = preprocess_internlm
436
+ elif self.template_name == 'phi3-chat':
437
+ preprocess_function = preprocess_phi3
438
+ elif self.template_name == 'internvl2_5':
439
+ preprocess_function = preprocess_internvl2_5
440
+ else:
441
+ preprocess_function = preprocess
442
+ return preprocess_function
443
+
444
+ def load_image(self, image_path):
445
+ # Load the image using tcs_loader if available, otherwise use PIL
446
+ if self.tcs_loader is not None and 's3://' in image_path:
447
+ return self.tcs_loader(image_path)
448
+ return Image.open(image_path).convert('RGB')
449
+
450
+ def get_image_path(self, image_path):
451
+ if image_path.startswith('s3://'): # for ceph
452
+ image_path = self.root + image_path
453
+ else: # for local image
454
+ image_path = os.path.join(self.root, image_path)
455
+ return image_path
456
+
457
+ def get_transform(self):
458
+ # Build transformation function
459
+ transform = build_transform(is_train=self.is_train, input_size=self.image_size,
460
+ pad2square=self.pad2square, normalize_type=self.normalize_type)
461
+ return transform
462
+
463
+ def multi_modal_get_item(self, data_item):
464
+ # Build transformation function
465
+ transform = self.get_transform()
466
+
467
+ # Ensure the first conversation contains an image placeholder
468
+ if '<image>' not in data_item['conversations'][0]['value']:
469
+ data_item['conversations'][0]['value'] = '<image>\n' + data_item['conversations'][0]['value']
470
+
471
+ # Merge the image path
472
+ image_path = self.get_image_path(data_item['image'])
473
+
474
+ # Load the image using tcs_loader if available, otherwise use PIL
475
+ image = self.load_image(image_path)
476
+
477
+ if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically
478
+ images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch,
479
+ image_size=self.image_size, use_thumbnail=self.use_thumbnail)
480
+ else: # Otherwise, use the original image as a single patch
481
+ images = [image]
482
+
483
+ # Apply the transformation to each image and stack the results into a tensor
484
+ pixel_values = [transform(image) for image in images]
485
+ pixel_values = torch.stack(pixel_values)
486
+
487
+ # Ensure that there is only one patch if dynamic image size is not enabled
488
+ num_patches = pixel_values.size(0)
489
+ if not self.dynamic_image_size:
490
+ assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
491
+
492
+ # Select the appropriate preprocessing function based on the template name
493
+ preprocess_function = self.get_preprocess_function()
494
+
495
+ # Preprocess the conversations and generate the return dictionary
496
+ ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
497
+ self.tokenizer, [self.num_image_token * num_patches],
498
+ group_by_length=self.group_by_length,
499
+ use_packed_ds=self.use_packed_ds, ds_name=self.ds_name)
500
+
501
+ # Calculate position_ids for packed dataset
502
+ position_ids = ret['attention_mask'].long().cumsum(-1) - 1
503
+ position_ids.masked_fill_(ret['attention_mask'] == 0, 1)
504
+ image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
505
+ assert (ret['input_ids'][0] == image_end_token_id).sum() == 1, f'image tokens are truncated, this dataset is {self.ds_name}'
506
+
507
+ # Create the final return dictionary
508
+ ret = dict(
509
+ input_ids=ret['input_ids'][0],
510
+ labels=ret['labels'][0],
511
+ attention_mask=ret['attention_mask'][0],
512
+ position_ids=position_ids[0],
513
+ pixel_values=pixel_values,
514
+ image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
515
+ )
516
+ return ret
517
+
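# Small worked example of the position_ids computation used in the getters
# above: positions advance only over attended tokens, and padded positions
# are overwritten with 1 (they are masked out by the attention mask anyway).
import torch
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
position_ids = attention_mask.long().cumsum(-1) - 1  # tensor([[0, 1, 2, 2, 2]])
position_ids.masked_fill_(attention_mask == 0, 1)    # tensor([[0, 1, 2, 1, 1]])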
518
+ def multi_modal_multi_image_get_item(self, data_item):
519
+ # Build transformation function
520
+ transform = self.get_transform()
521
+
522
+ images, num_tiles = [], []
523
+ num_image = len(data_item['image'])
524
+ for image_path in data_item['image']:
525
+ # Merge the image path
526
+ image_path = self.get_image_path(image_path)
527
+ # Load the image using tcs_loader if available, otherwise use PIL
528
+ image = self.load_image(image_path)
529
+ if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically
530
+ image = dynamic_preprocess(image, min_num=self.min_dynamic_patch,
531
+ max_num=max(1, self.max_dynamic_patch // num_image),
532
+ image_size=self.image_size, use_thumbnail=self.use_thumbnail)
533
+ images += image
534
+ num_tiles.append(len(image))
535
+ else: # Otherwise, use the original image as a single patch
536
+ images.append(image)
537
+ num_tiles.append(1)
538
+ pixel_values = [transform(image) for image in images]
539
+ pixel_values = torch.stack(pixel_values)
540
+ num_patches = pixel_values.size(0)
541
+
542
+ # Select the appropriate preprocessing function based on the template name
543
+ preprocess_function = self.get_preprocess_function()
544
+
545
+ # Preprocess the conversations and generate the return dictionary
546
+ num_image_tokens = [self.num_image_token * num_tile for num_tile in num_tiles]
547
+ ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
548
+ self.tokenizer, num_image_tokens, group_by_length=self.group_by_length,
549
+ use_packed_ds=self.use_packed_ds, ds_name=self.ds_name, num_image=num_image)
550
+
551
+ # Calculate position_ids for packed dataset
552
+ position_ids = ret['attention_mask'].long().cumsum(-1) - 1
553
+ position_ids.masked_fill_(ret['attention_mask'] == 0, 1)
554
+ image_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
555
+ assert (ret['input_ids'][0] == image_end_token_id).sum() == num_image, f'image tokens are truncated, this dataset is {self.ds_name}'
556
+
557
+ # Create the final return dictionary
558
+ ret = dict(
559
+ input_ids=ret['input_ids'][0],
560
+ labels=ret['labels'][0],
561
+ attention_mask=ret['attention_mask'][0],
562
+ position_ids=position_ids[0],
563
+ pixel_values=pixel_values,
564
+ image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
565
+ )
566
+ return ret
567
+
568
+ def video_get_item(self, data_item):
569
+ # Build transformation function
570
+ transform = self.get_transform()
571
+
572
+ # Ensure the first conversation contains a video placeholder
573
+ if '<video>' not in data_item['conversations'][0]['value']:
574
+ data_item['conversations'][0]['value'] = '<video>\n' + data_item['conversations'][0]['value']
575
+
576
+ # Get the video file path
577
+ video_file = data_item['video']
578
+ video_path = os.path.join(self.root, video_file)
579
+
580
+ # Load the video frames using tcs_loader
581
+ # TODO: Load videos without using tcsloader.
582
+ image_list = self.tcs_loader(
583
+ video_path,
584
+ image_type='video',
585
+ max_num_frames=self.max_num_frame,
586
+ min_num_frames=self.min_num_frame,
587
+ sample=self.sampling_method,
588
+ clip=data_item.get('clip', None))
589
+
590
+ # Generate special tokens for each video frame
591
+ special_tokens = '\n'.join(['Frame-{}: <image>'.format(i + 1) for i in range(len(image_list))])
592
+ data_item['conversations'][0]['value'] = data_item['conversations'][0]['value'].replace(
593
+ '<video>\n', special_tokens + '\n')
594
+
595
+ # Transform each frame image and stack them into a tensor
596
+ pixel_values = [transform(image) for image in image_list]
597
+ pixel_values = torch.stack(pixel_values)
598
+ num_patches = pixel_values.size(0)
599
+
600
+ # Select the appropriate preprocessing function based on the template name
601
+ preprocess_function = self.get_preprocess_function()
602
+
603
+ # Preprocess the conversations and generate the return dictionary
604
+ num_image_tokens = [self.num_image_token] * num_patches
605
+ ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
606
+ self.tokenizer, num_image_tokens, group_by_length=self.group_by_length,
607
+ use_packed_ds=self.use_packed_ds, ds_name=self.ds_name, num_image=num_patches)
608
+
609
+ # Calculate position_ids for packed dataset
610
+ position_ids = ret['attention_mask'].long().cumsum(-1) - 1
611
+ position_ids.masked_fill_(ret['attention_mask'] == 0, 1)
612
+
613
+ # Create the final return dictionary
614
+ ret = dict(
615
+ input_ids=ret['input_ids'][0],
616
+ labels=ret['labels'][0],
617
+ attention_mask=ret['attention_mask'][0],
618
+ position_ids=position_ids[0],
619
+ pixel_values=pixel_values,
620
+ image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
621
+ )
622
+ return ret
623
+
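# For reference, the replacement above expands a single '<video>' placeholder
# into one '<image>' placeholder per sampled frame, e.g. for three frames:
#   'Frame-1: <image>\nFrame-2: <image>\nFrame-3: <image>'
assert '\n'.join('Frame-{}: <image>'.format(i + 1) for i in range(3)) == \
    'Frame-1: <image>\nFrame-2: <image>\nFrame-3: <image>'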
624
+ def pure_text_get_item(self, data_item):
625
+ # Build transformation function
626
+ transform = self.get_transform()
627
+
628
+ # Create a blank white image
629
+ image = Image.new('RGB', (224, 224), (255, 255, 255))
630
+
631
+ # Dynamically preprocess the image to generate patches
632
+ images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=1,
633
+ image_size=self.image_size, use_thumbnail=self.use_thumbnail)
634
+
635
+ # Apply the transformation to each image patch and stack them into a tensor
636
+ pixel_values = [transform(image) for image in images]
637
+ pixel_values = torch.stack(pixel_values)
638
+ num_patches = pixel_values.size(0)
639
+
640
+ # Ensure there is only one patch
641
+ assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
642
+
643
+ # Select the appropriate preprocessing function based on the template name
644
+ preprocess_function = self.get_preprocess_function()
645
+
646
+ # Preprocess the conversations and generate the return dictionary
647
+ ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
648
+ self.tokenizer, [self.num_image_token * num_patches], text_only=True,
649
+ group_by_length=self.group_by_length, use_packed_ds=self.use_packed_ds,
650
+ ds_name=self.ds_name)
651
+
652
+ # Calculate position_ids for packed dataset
653
+ position_ids = ret['attention_mask'].long().cumsum(-1) - 1
654
+ position_ids.masked_fill_(ret['attention_mask'] == 0, 1)
655
+
656
+ # Create the final return dictionary
657
+ ret = dict(
658
+ input_ids=ret['input_ids'][0],
659
+ labels=ret['labels'][0],
660
+ attention_mask=ret['attention_mask'][0],
661
+ position_ids=position_ids[0],
662
+ pixel_values=pixel_values,
663
+ image_flags=torch.tensor([0] * num_patches, dtype=torch.long)
664
+ )
665
+ return ret
666
+
667
+ def _enable_worker_distributed(self):
668
+ if (
669
+ self.distributed_mode
670
+ and not self.worker_distributed
671
+ and self.worker_id is not None
672
+ ):
673
+ self.worker_distributed = True
674
+ num_worker_per_rank = self.num_workers // self.total_ranks
675
+ self.raw_data = self.raw_data[self.worker_id % num_worker_per_rank::num_worker_per_rank]
676
+ logger.info(f'worker_distributed is enabled, {self.num_workers=}, {len(self.raw_data)=}')
677
+
678
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
679
+ if i >= len(self.raw_data):
680
+ if self.use_packed_ds:
681
+ raise NotImplementedError
682
+ else:
683
+ i = i % len(self.raw_data)
684
+
685
+ try_cnt, max_try = 0, 10
686
+ while True:
687
+ if try_cnt > max_try:
688
+ raise StopIteration
689
+ try:
690
+ data_item = json.loads(self.raw_data[i])
691
+ # conversations = data_item['conversations']
692
+ # check_conversations_repetition(conversations, repeat_threshold=0.4, ngram=10)
693
+ if 'image' in data_item and len(data_item['image']) != 0:
694
+ if isinstance(data_item['image'], list):
695
+ ret = self.multi_modal_multi_image_get_item(data_item)
696
+ else:
697
+ ret = self.multi_modal_get_item(data_item)
698
+ elif 'video' in data_item and data_item['video'] is not None and data_item['video'] != '':
699
+ ret = self.video_get_item(data_item)
700
+ else:
701
+ ret = self.pure_text_get_item(data_item)
702
+ break
703
+ except Exception as e:
704
+ try_cnt += 1
705
+ print(e, self.ds_name, flush=True)
706
+ if not isinstance(e, (UnidentifiedImageError, FileNotFoundError)):
707
+ traceback.print_exc()
708
+ data_item = json.loads(self.raw_data[i])
709
+ if 'image' in data_item:
710
+ if isinstance(data_item['image'], list):
711
+ images = [self.root + item for item in data_item['image']]
712
+ print(f'Failed to load image: {images}, the dataset is: {self.ds_name}')
713
+ else:
714
+ if data_item['image'].startswith('s3://'):
715
+ data_path = self.root + data_item['image']
716
+ else:
717
+ data_path = os.path.join(self.root, data_item['image'])
718
+ print(f'Failed to load image: {data_path}, the dataset is: {self.ds_name}')
719
+ elif 'video' in data_item:
720
+ data_path = os.path.join(self.root, data_item['video'])
721
+ print(f'Failed to load video: {data_path}, the dataset is: {self.ds_name}')
722
+ i = random.randint(0, len(self.raw_data) - 1)
723
+ return ret
724
+
725
+ def __iter__(self):
726
+ self._enable_worker_distributed()
727
+ start_idx = 0
728
+
729
+ assert self.worker_state_key is not None
730
+ if self.worker_state_key in self._state_dict and len(self._state_dict[self.worker_state_key]) > 0:
731
+ start_idx = self._state_dict[self.worker_state_key]['current_idx']
732
+
733
+ self._state_dict.pop(self.worker_state_key)
734
+
735
+ if self.worker_id == 0:
736
+ logger.info(
737
+ f'[{self.ds_name}] [Worker id {self.worker_id}] '
738
+ f'begin to iter with {start_idx=}'
739
+ )
740
+
741
+ for i in range(start_idx, len(self)):
742
+ yield self[i]
743
+
744
+
745
+ def build_datasets(
746
+ data_args,
747
+ tokenizer,
748
+ tcs_loader,
749
+ model,
750
+ group_by_length=False,
751
+ dynamic_image_size=False,
752
+ use_thumbnail=False,
753
+ min_dynamic_patch=1,
754
+ max_dynamic_patch=12,
755
+ min_num_frame=8,
756
+ max_num_frame=32,
757
+ normalize_type='imagenet',
758
+ ):
759
+ datasets = []
760
+ lengths = []
761
+ data_rank = dist.get_rank()
762
+ data_world_size = dist.get_world_size()
763
+ ds_collections = json.loads(open(data_args.meta_path).read())
764
+ for ds_idx, ds_name in enumerate(ds_collections.keys()):
765
+ repeat_time = ds_collections[ds_name]['repeat_time']
766
+ if 'max_dynamic_patch' in ds_collections[ds_name]:
767
+ max_num = ds_collections[ds_name]['max_dynamic_patch']
768
+ logger.info(f'max_dynamic_patch is set to {max_num} according to the meta file')
769
+ else:
770
+ max_num = max_dynamic_patch
771
+ dataset = LazySupervisedDataset(
772
+ data_args.conv_style, ds_collections[ds_name],
773
+ tokenizer,
774
+ tcs_loader,
775
+ ds_name=ds_name,
776
+ num_image_token=model.num_image_token,
777
+ image_size=data_args.force_image_size,
778
+ is_train=ds_collections[ds_name]['data_augment'],
779
+ pad2square=data_args.pad2square,
780
+ group_by_length=group_by_length and not data_args.use_packed_ds,
781
+ dynamic_image_size=dynamic_image_size,
782
+ use_thumbnail=use_thumbnail,
783
+ min_dynamic_patch=min_dynamic_patch,
784
+ max_dynamic_patch=max_num,
785
+ min_num_frame=min_num_frame,
786
+ max_num_frame=max_num_frame,
787
+ repeat_time=repeat_time,
788
+ normalize_type=normalize_type,
789
+ # hyperparameters for packed training
790
+ use_packed_ds=data_args.use_packed_ds,
791
+ data_rank=data_rank,
792
+ data_world_size=data_world_size,
793
+ distributed_mode=data_args.use_packed_ds,
794
+ force_shuffle=data_args.use_packed_ds,
795
+ random_seed=ds_idx,
796
+ )
797
+ logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}')
798
+ datasets.append(dataset)
799
+ if data_args.use_data_resampling:
800
+ lengths.append(math.sqrt(len(dataset)))
801
+ else:
802
+ lengths.append(len(dataset))
803
+
804
+ if data_args.use_packed_ds:
805
+ total_length = sum(lengths)
806
+ train_dataset = PackedDataset(
807
+ tokenizer=tokenizer,
808
+ data_rank=data_rank,
809
+ data_world_size=data_world_size,
810
+ datasets=datasets,
811
+ dataset_weight=[l / total_length for l in lengths],
812
+ num_images_expected=data_args.num_images_expected,
813
+ max_packed_tokens=data_args.max_packed_tokens,
814
+ max_buffer_size=data_args.max_buffer_size,
815
+ log_freq=data_args.log_freq,
816
+ strict_mode=data_args.strict_mode,
817
+ replacement=data_args.replacement,
818
+ allow_overflow=data_args.allow_overflow,
819
+ allow_deduplicated_ds_name=False,
820
+ )
821
+ elif data_args.use_data_resampling:
822
+ total_length = sum(lengths)
823
+ weights = [l / total_length for l in lengths]
824
+ train_dataset = WeightedConcatDataset(datasets, weights)
825
+ else:
826
+ train_dataset = ConcatDataset(datasets)
827
+ return train_dataset
828
+
829
+
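# Sketch of the square-root resampling above, assuming two datasets with
# 10_000 and 100 samples: plain concatenation would visit them roughly 100:1,
# while sqrt re-weighting narrows that to 10:1, so smaller datasets are
# sampled more often relative to their size.
import math
lengths = [math.sqrt(n) for n in (10_000, 100)]  # [100.0, 10.0]
weights = [l / sum(lengths) for l in lengths]    # approximately [0.909, 0.091]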
830
+ def len2weight(x, loss_reduction):
831
+ if x == 0:
832
+ return x
833
+ if loss_reduction == 'token':
834
+ return 1
835
+ if loss_reduction == 'sample':
836
+ return 1 / x
837
+ if loss_reduction == 'square':
838
+ return 1 / (x ** 0.5)
839
+ raise NotImplementedError(loss_reduction)
840
+
841
+
842
+ def main():
843
+ # Apply necessary patches for the transformers library
844
+ replace_llama_rmsnorm_with_fused_rmsnorm()
845
+ replace_train_sampler()
846
+ replace_train_dataloader()
847
+
848
+ # Parse input arguments
849
+ # See all possible arguments in src/transformers/training_args.py
850
+ # If using DeepSpeed ZeRO-3, init_dist must be called before HfArgumentParser
851
+ launcher = os.environ.get('LAUNCHER', 'slurm')
852
+ init_dist(launcher=launcher, backend='nccl')
853
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
854
+ if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
855
+ # If we pass only one argument to the script, and it's the path to a json file,
856
+ # let's parse it to get our arguments.
857
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
858
+ else:
859
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
860
+
861
+ training_args.use_packed_ds = data_args.use_packed_ds
862
+
863
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
864
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
865
+ # send_example_telemetry('InternVL-Chat', model_args, data_args)
866
+
867
+ # Setup logging
868
+ logging.basicConfig(
869
+ format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
870
+ datefmt='%m/%d/%Y %H:%M:%S',
871
+ handlers=[logging.StreamHandler(sys.stdout)],
872
+ )
873
+
874
+ if training_args.should_log:
875
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
876
+ transformers.utils.logging.set_verbosity_info()
877
+
878
+ log_level = training_args.get_process_log_level()
879
+ logger.setLevel(log_level)
880
+ set_verbosity(log_level)
881
+ enable_default_handler()
882
+ enable_explicit_format()
883
+
884
+ # Log on each process the small summary:
885
+ logger.warning(
886
+ f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, '
887
+ + f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
888
+ )
889
+ logger.info(f'Training/evaluation parameters {training_args}')
890
+
891
+ # Detecting last checkpoint and eventually continue from last checkpoint.
892
+ last_checkpoint = None
893
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
894
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
895
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
896
+ raise ValueError(
897
+ f'Output directory ({training_args.output_dir}) already exists and is not empty. '
898
+ 'Use --overwrite_output_dir to overcome.'
899
+ )
900
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
901
+ logger.info(
902
+ f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change '
903
+ 'the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
904
+ )
905
+ # Set seed before initializing model.
906
+ set_seed(training_args.seed)
907
+
908
+ # Load pretrained model, tokenizer, and image processor
909
+ tokenizer_path = model_args.model_name_or_path or model_args.llm_path
910
+ logger.info(f'Loading Tokenizer: {tokenizer_path}')
911
+ tokenizer = AutoTokenizer.from_pretrained(
912
+ tokenizer_path, add_eos_token=False, trust_remote_code=True, use_fast=model_args.use_fast_tokenizer)
913
+ tokenizer.tokenizer_path = tokenizer_path
914
+ tokenizer.model_max_length = data_args.max_seq_length
915
+ token_list = [IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN,
916
+ QUAD_START_TOKEN, QUAD_END_TOKEN, REF_START_TOKEN,
917
+ REF_END_TOKEN, BOX_START_TOKEN, BOX_END_TOKEN]
918
+ num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
919
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
920
+ tcs_loader = TCSLoader('~/petreloss.conf') if has_tcs_loader else None
921
+
922
+ if data_args.use_packed_ds:
923
+ replace_internlm2_attention_class()
924
+ replace_qwen2_attention_class()
925
+ replace_phi3_attention_class()
926
+ replace_llama_attention_class()
927
+
928
+ if model_args.use_liger:
929
+ from internvl.patch import apply_liger_kernel_to_internvit
930
+ from liger_kernel.transformers import (apply_liger_kernel_to_llama,
931
+ apply_liger_kernel_to_qwen2)
932
+ apply_liger_kernel_to_llama()
933
+ apply_liger_kernel_to_qwen2()
934
+ # apply_liger_kernel_to_internvit()
935
+
936
+ if model_args.model_name_or_path is not None:
937
+ logger.info('Loading InternVLChatModel...')
938
+ config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
939
+ config.vision_config.drop_path_rate = model_args.drop_path_rate
940
+ if config.llm_config.model_type == 'internlm2':
941
+ config.llm_config.attn_implementation = 'flash_attention_2' # for InternLM
942
+ logger.info('Using flash_attention_2 for InternLM')
943
+ else:
944
+ config.llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
945
+ logger.info('Using flash_attention_2 for LLaMA')
946
+ config.template = data_args.conv_style
947
+ config.select_layer = model_args.vision_select_layer
948
+ config.dynamic_image_size = data_args.dynamic_image_size
949
+ config.use_thumbnail = data_args.use_thumbnail
950
+ config.ps_version = model_args.ps_version
951
+ config.min_dynamic_patch = data_args.min_dynamic_patch
952
+ config.max_dynamic_patch = data_args.max_dynamic_patch
953
+ model = InternVLChatModel.from_pretrained(
954
+ model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config)
955
+ else:
956
+ logger.info('Loading ViT-6B...')
957
+ vision_config = InternVisionConfig.from_pretrained(model_args.vision_path)
958
+ vision_config.drop_path_rate = model_args.drop_path_rate
959
+ vision_model = InternVisionModel.from_pretrained(
960
+ model_args.vision_path, torch_dtype=torch.bfloat16, config=vision_config)
961
+ logger.info('Loading LLaMA...')
962
+ llm_config = AutoConfig.from_pretrained(model_args.llm_path, trust_remote_code=True)
963
+ if llm_config.model_type == 'internlm2':
964
+ model_type = InternLM2ForCausalLM
965
+ llm_config.attn_implementation = 'flash_attention_2' # for InternLM
966
+ logger.info('Using flash_attention_2 for InternLM')
967
+ else:
968
+ model_type = AutoModelForCausalLM
969
+ llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
970
+ logger.info('Using flash_attention_2 for LLaMA')
971
+ llm = model_type.from_pretrained(
972
+ model_args.llm_path, torch_dtype=torch.bfloat16,
973
+ config=llm_config, trust_remote_code=True)
974
+ logger.info('Building InternVLChatConfig...')
975
+ internvl_chat_config = InternVLChatConfig(
976
+ vision_config.to_dict(), llm_config.to_dict(), downsample_ratio=data_args.down_sample_ratio,
977
+ pad2square=data_args.pad2square, template=data_args.conv_style,
978
+ select_layer=model_args.vision_select_layer, dynamic_image_size=data_args.dynamic_image_size,
979
+ use_thumbnail=data_args.use_thumbnail, ps_version=model_args.ps_version,
980
+ min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch)
981
+ internvl_chat_config.force_image_size = data_args.force_image_size
982
+ logger.info('Building InternVLChatModel...')
983
+ model = InternVLChatModel(internvl_chat_config, vision_model, llm)
984
+ model.img_context_token_id = img_context_token_id
985
+
986
+ assert model.config.downsample_ratio == data_args.down_sample_ratio
987
+
988
+ if model_args.mlp_path is not None:
989
+ logger.info('Loading pretrained MLP projector...')
990
+        state_dict = torch.load(model_args.mlp_path, map_location='cpu')
+        message = model.mlp1.load_state_dict(state_dict)
+        logger.info(message)
+        logger.info('Finished')
+
+    patch_size = model.config.vision_config.patch_size
+    logger.info(f'model.config.force_image_size: {model.config.force_image_size}')
+    logger.info(f'data_args.force_image_size: {data_args.force_image_size}')
+    logger.info(f'model.config.vision_config.image_size: {model.config.vision_config.image_size}')
+    if model.config.vision_config.image_size != data_args.force_image_size:
+        logger.info(f'Resizing position embedding from '
+                    f'{model.config.vision_config.image_size} '
+                    f'to {data_args.force_image_size}...')
+        model.vision_model.resize_pos_embeddings(old_size=model.config.vision_config.image_size,
+                                                 new_size=data_args.force_image_size,
+                                                 patch_size=patch_size)
+        model.config.vision_config.image_size = data_args.force_image_size
+    model.config.force_image_size = data_args.force_image_size
+    model.num_image_token = int((data_args.force_image_size // patch_size) ** 2 * (data_args.down_sample_ratio ** 2))
+
+    if num_new_tokens > 0:
+        model.language_model.resize_token_embeddings(len(tokenizer))
+        output_embeddings = model.language_model.get_output_embeddings().weight.data
+        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+        output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+        model.config.llm_config.vocab_size = len(tokenizer)
+        model.language_model.config.vocab_size = len(tokenizer)
+
+    model.language_model.config.use_cache = False
+    model.vision_model.gradient_checkpointing = True
+    model.vision_model.encoder.gradient_checkpointing = True
+    if model_args.grad_checkpoint:
+        model.language_model._set_gradient_checkpointing()
+
+    train_dataset = build_datasets(
+        data_args, tokenizer, tcs_loader, model, group_by_length=training_args.group_by_length,
+        dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail,
+        min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch,
+        normalize_type=data_args.normalize_type, min_num_frame=data_args.min_num_frame,
+        max_num_frame=data_args.max_num_frame)
+
+    def _freeze_params(module):
+        for param in module.parameters():
+            param.requires_grad = False
+
+    if model_args.freeze_backbone:
+        # model.vision_model = model.vision_model.eval()
+        _freeze_params(model.vision_model)
+
+    if model_args.freeze_llm:
+        model.language_model = model.language_model.eval()
+        _freeze_params(model.language_model)
+
+    if model_args.unfreeze_lm_head:
+        model.language_model.lm_head.requires_grad = True
+
+    if model_args.use_backbone_lora:
+        model.wrap_backbone_lora(r=model_args.use_backbone_lora, lora_alpha=2 * model_args.use_backbone_lora)
+        model.config.use_backbone_lora = model_args.use_backbone_lora
+
+    if model_args.use_llm_lora:
+        model.wrap_llm_lora(r=model_args.use_llm_lora, lora_alpha=2 * model_args.use_llm_lora)
+        model.config.use_llm_lora = model_args.use_llm_lora
+
+    if model_args.freeze_mlp:
+        _freeze_params(model.mlp1)
+
+    if model_args.unfreeze_vit_layers != 0:
+        layers = model.vision_model.encoder.layers[model_args.unfreeze_vit_layers:]
+        for k, v in layers.named_parameters():
+            logger.info(f'Unfreezing ViT layer: {k}')
+            v.requires_grad = True
+
+    # print trainable parameters
+    if dist.get_rank() == 0:
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                logger.info(name)
+
+    # set seed for torch dataloaders
+    set_seed(training_args.seed)
+
+    if data_args.use_packed_ds:
+        collator = partial(
+            packed_collate_fn,
+            data_collator=concat_pad_data_collator,
+            max_item_length=data_args.max_packed_tokens if data_args.strict_mode else 0,
+            micro_num=training_args.train_batch_size,
+            len2weight=partial(len2weight, loss_reduction=data_args.loss_reduction),
+            loss_reduction_all_gather=data_args.loss_reduction_all_gather,
+        )
+    else:
+        collator = concat_pad_data_collator
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset if training_args.do_train else None,
+        eval_dataset=None,
+        tokenizer=tokenizer,
+        data_collator=collator,
+    )
+
+    # Training
+    if training_args.do_train:
+        checkpoint = None
+        if training_args.resume_from_checkpoint is not None:
+            checkpoint = training_args.resume_from_checkpoint
+        elif last_checkpoint is not None:
+            checkpoint = last_checkpoint
+        train_result = trainer.train(resume_from_checkpoint=checkpoint)
+        trainer.save_model()  # Saves the tokenizer too for easy upload
+
+        metrics = train_result.metrics
+        try:
+            metrics['train_samples'] = len(train_dataset)
+        except:
+            metrics['train_samples'] = -1
+
+        trainer.log_metrics('train', metrics)
+        trainer.save_metrics('train', metrics)
+        trainer.save_state()
+
+
+if __name__ == '__main__':
+    main()
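
The resize_token_embeddings block above initializes the LM head rows for any newly added special tokens to the mean of the existing rows, so new tokens start with average-looking logits rather than random ones. A minimal, self-contained sketch of that initialization on toy tensors (illustrative only, not part of this commit):

import torch

# Toy stand-ins for the real LM head weight and the number of newly added tokens.
vocab_size, hidden_size, num_new_tokens = 8, 4, 2
output_embeddings = torch.randn(vocab_size, hidden_size)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings[-num_new_tokens:] = output_embeddings_avg  # new rows start at the mean of the old rows
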
src/third_party/InternVL/internvl_chat/internvl/train/trainer_dpo.py ADDED
@@ -0,0 +1,302 @@
+ # --------------------------------------------------------
+ # InternVL
+ # Copyright (c) 2024 OpenGVLab
+ # Licensed under The MIT License [see LICENSE for details]
+ # --------------------------------------------------------
+
+ from copy import deepcopy
+ from typing import Dict, List, Literal, Optional, Tuple, Union
+
+ import deepspeed
+ import torch
+ from torch import nn
+ from torch.utils.data import ConcatDataset
+ from trl import DPOTrainer
+ from trl.trainer.utils import RunningMoments, pad_to_length
+
+
+ def _map(self, *args, **kwargs):
+     return self
+
+
+ ConcatDataset.map = _map
+
+
+ class MultimodalDPOTrainer(DPOTrainer):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         if self.loss_type != 'bco_pair' and 'bco_pair' in self.loss_type:
+             self.running = RunningMoments(self.accelerator)
+
+     @staticmethod
+     def concatenated_inputs(
+         batch: Dict[str, Union[List, torch.LongTensor]],
+         is_encoder_decoder: bool = False,
+         is_vision_model: bool = False,
+         label_pad_token_id: int = -100,
+         padding_value: int = 0,
+         device: Optional[torch.device] = None,
+     ) -> Dict[str, torch.LongTensor]:
+         """Concatenate the chosen and rejected inputs into a single tensor.
+
+         Args:
+             batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
+             is_encoder_decoder: Whether the model is an encoder-decoder model.
+             label_pad_token_id: The label pad token id.
+             padding_value: The padding value to use for the concatenated inputs_ids.
+             device: The device for the concatenated inputs.
+
+         Returns:
+             A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
+         """
+         concatenated_batch = {}
+
+         if is_encoder_decoder:
+             max_length = max(batch['chosen_labels'].shape[1], batch['rejected_labels'].shape[1])
+         else:
+             max_length = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1])
+
+         for k in batch:
+             if k.startswith('chosen') and isinstance(batch[k], torch.Tensor):
+                 if 'labels' in k or is_encoder_decoder:
+                     pad_value = label_pad_token_id
+                 elif k.endswith('_input_ids'):
+                     pad_value = padding_value
+                 elif k.endswith('_attention_mask'):
+                     pad_value = 0
+                 concatenated_key = k.replace('chosen', 'concatenated')
+                 concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
+         for k in batch:
+             if k.startswith('rejected') and isinstance(batch[k], torch.Tensor):
+                 if 'labels' in k or is_encoder_decoder:
+                     pad_value = label_pad_token_id
+                 elif k.endswith('_input_ids'):
+                     pad_value = padding_value
+                 elif k.endswith('_attention_mask'):
+                     pad_value = 0
+                 concatenated_key = k.replace('rejected', 'concatenated')
+                 concatenated_batch[concatenated_key] = torch.cat(
+                     (
+                         concatenated_batch[concatenated_key],
+                         pad_to_length(batch[k], max_length, pad_value=pad_value),
+                     ),
+                     dim=0,
+                 ).to(device=device)
+
+         if is_encoder_decoder:
+             concatenated_batch['concatenated_input_ids'] = batch['prompt_input_ids'].repeat(2, 1).to(device=device)
+             concatenated_batch['concatenated_attention_mask'] = (
+                 batch['prompt_attention_mask'].repeat(2, 1).to(device=device)
+             )
+
+         if 'pixel_values' in batch:
+             concatenated_batch['pixel_values'] = batch['pixel_values'].repeat(2, 1, 1, 1)
+             concatenated_batch['image_flags'] = batch['image_flags'].repeat(2)
+
+         return concatenated_batch
+
+     def concatenated_forward(
+         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
+     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+         """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+
+         We do this to avoid doing two forward passes, because it's faster for FSDP.
+         """
+         concatenated_batch = self.concatenated_inputs(
+             batch,
+             is_encoder_decoder=self.is_encoder_decoder,
+             is_vision_model=self.is_vision_model,
+             label_pad_token_id=self.label_pad_token_id,
+             padding_value=self.padding_value,
+             device=self.accelerator.device,
+         )
+         len_chosen = batch['chosen_labels'].shape[0]
+
+         model_kwargs = {}
+
+         if self.is_encoder_decoder:
+             model_kwargs['labels'] = concatenated_batch['concatenated_labels']
+             model_kwargs['decoder_input_ids'] = concatenated_batch.pop('concatenated_decoder_input_ids', None)
+
+         if self.is_vision_model:
+             model_kwargs['pixel_values'] = concatenated_batch['pixel_values']
+             model_kwargs['pixel_attention_mask'] = concatenated_batch['pixel_attention_mask']
+
+         if self.aux_loss_enabled:
+             model_kwargs['output_router_logits'] = True
+
+         outputs = model(
+             input_ids=concatenated_batch['concatenated_input_ids'],
+             attention_mask=concatenated_batch['concatenated_attention_mask'],
+             pixel_values=concatenated_batch['pixel_values'],
+             image_flags=concatenated_batch['image_flags'],
+             use_cache=False,
+             **model_kwargs,
+         )
+         all_logits = outputs.logits
+
+         all_logps, size_completion = self.get_batch_logps(
+             all_logits,
+             concatenated_batch['concatenated_labels'],
+             # average_log_prob=self.loss_type == "ipo",
+             is_encoder_decoder=self.is_encoder_decoder,
+             label_pad_token_id=self.label_pad_token_id,
+         )
+
+         def cross_entropy_loss(logits, labels):
+             if not self.is_encoder_decoder:
+                 # Shift so that tokens < n predict n
+                 logits = logits[..., :-1, :].contiguous()
+                 labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = nn.CrossEntropyLoss()
+             logits = logits.view(-1, logits.shape[-1])
+             labels = labels.view(-1)
+             # Enable model parallelism
+             labels = labels.to(logits.device)
+             loss = loss_fct(logits, labels)
+             return loss
+
+         labels = concatenated_batch['concatenated_labels'].clone()
+         nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
+
+         if self.loss_type == 'ipo':
+             all_logps = all_logps / size_completion
+
+         chosen_logps = all_logps[:len_chosen]
+         rejected_logps = all_logps[len_chosen:]
+
+         chosen_logits = all_logits[:len_chosen]
+         rejected_logits = all_logits[len_chosen:]
+
+         if self.aux_loss_enabled:
+             return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
+
+         return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
+
+     def _prepare_deepspeed_orig(self, model):
+         # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
+         deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+         config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
+
+         # If ZeRO-3 is used, we shard both the active and reference model.
+         # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
+         if config_kwargs['zero_optimization']['stage'] != 3:
+             config_kwargs['zero_optimization']['stage'] = 0
+         model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
+         model.eval()
+         return model
+
+     def _prepare_deepspeed(self, model):
+         deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+         config_kwargs = deepspeed_plugin.deepspeed_config
+         if config_kwargs['zero_optimization']['stage'] == 3:
+             print('Enable DPOTrainer._prepare_deepspeed')
+             return self._prepare_deepspeed_orig(model)
+
+         print('Disable DPOTrainer._prepare_deepspeed')
+         for param in model.parameters():
+             param.requires_grad = False
+
+         model.eval()
+         model = model.to(self.accelerator.device)
+         return model
+
+     def get_batch_loss_metrics(
+         self,
+         model,
+         batch: Dict[str, Union[List, torch.LongTensor]],
+         train_eval: Literal['train', 'eval'] = 'train',
+     ):
+         """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
+         metrics = {}
+
+         forward_output = self.concatenated_forward(model, batch)
+         (
+             policy_chosen_logps,
+             policy_rejected_logps,
+             policy_chosen_logits,
+             policy_rejected_logits,
+             policy_nll_loss,
+         ) = forward_output[:5]
+         if self.aux_loss_enabled:
+             aux_loss = forward_output[5]
+
+         # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
+         if (
+             'reference_chosen_logps' in batch
+             and 'reference_rejected_logps' in batch
+             and self.args.rpo_alpha is not None
+         ):
+             reference_chosen_logps = batch['reference_chosen_logps']
+             reference_rejected_logps = batch['reference_rejected_logps']
+         else:
+             with torch.no_grad():
+                 if self.ref_model is None:
+                     with self.null_ref_context():
+                         (
+                             reference_chosen_logps,
+                             reference_rejected_logps,
+                             _,
+                             _,
+                             _,
+                         ) = self.concatenated_forward(self.model, batch)
+                 else:
+                     (
+                         reference_chosen_logps,
+                         reference_rejected_logps,
+                         _,
+                         _,
+                         _,
+                     ) = self.concatenated_forward(self.ref_model, batch)
+
+         if ',' in self.loss_type:
+             loss_type = self.loss_type
+             loss_type_list = loss_type.split(',')
+
+             losses, chosen_rewards, rejected_rewards = 0, 0, 0
+             for curr_type in loss_type_list:
+                 self.loss_type = curr_type
+                 curr_losses, curr_chosen_rewards, curr_rejected_rewards = self.dpo_loss(
+                     policy_chosen_logps,
+                     policy_rejected_logps,
+                     reference_chosen_logps,
+                     reference_rejected_logps,
+                 )
+                 curr_weight = getattr(self.args, f'{curr_type}_loss_weight')
+                 losses = losses + curr_losses * curr_weight
+                 chosen_rewards = chosen_rewards + curr_chosen_rewards * curr_weight
+                 rejected_rewards = rejected_rewards + curr_rejected_rewards * curr_weight
+
+             self.loss_type = loss_type
+         else:
+             losses, chosen_rewards, rejected_rewards = self.dpo_loss(
+                 policy_chosen_logps,
+                 policy_rejected_logps,
+                 reference_chosen_logps,
+                 reference_rejected_logps,
+             )
+
+         reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+         if self.args.rpo_alpha is not None:
+             # losses = losses * self.args.rpo_alpha + policy_nll_loss
+             losses = losses + policy_nll_loss * self.args.rpo_alpha
+
+         prefix = 'eval_' if train_eval == 'eval' else ''
+         metrics[f'{prefix}rewards/chosen'] = chosen_rewards.mean().cpu()
+         metrics[f'{prefix}rewards/rejected'] = rejected_rewards.mean().cpu()
+         metrics[f'{prefix}rewards/accuracies'] = reward_accuracies.mean().cpu()
+         metrics[f'{prefix}rewards/margins'] = (chosen_rewards - rejected_rewards).mean().cpu()
+         metrics[f'{prefix}logps/rejected'] = policy_rejected_logps.detach().mean().cpu()
+         metrics[f'{prefix}logps/chosen'] = policy_chosen_logps.detach().mean().cpu()
+         metrics[f'{prefix}logits/rejected'] = policy_rejected_logits.detach().mean().cpu()
+         metrics[f'{prefix}logits/chosen'] = policy_chosen_logits.detach().mean().cpu()
+         if self.args.rpo_alpha is not None:
+             metrics[f'{prefix}nll_loss'] = policy_nll_loss.detach().mean().cpu()
+
+         if self.aux_loss_enabled:
+             return losses.mean() + getattr(model.config, 'router_aux_loss_coef', 0.0) * aux_loss, metrics
+
+         return losses.mean(), metrics
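
MultimodalDPOTrainer.concatenated_inputs pads the chosen and rejected sequences to a common length and stacks them along the batch dimension, so a single forward pass scores both completions (pixel_values and image_flags are simply repeated twice to match). A toy sketch of that padding-and-stacking idea in plain torch; the pad_to helper below is a stand-in for trl's pad_to_length and is not part of this commit:

import torch

def pad_to(t, length, value):
    # Right-pad a (batch, seq) tensor to `length` with `value`.
    if t.shape[1] >= length:
        return t
    pad = t.new_full((t.shape[0], length - t.shape[1]), value)
    return torch.cat([t, pad], dim=1)

chosen = torch.tensor([[1, 2, 3]])
rejected = torch.tensor([[4, 5]])
max_len = max(chosen.shape[1], rejected.shape[1])
concatenated = torch.cat([pad_to(chosen, max_len, 0), pad_to(rejected, max_len, 0)], dim=0)
print(concatenated)  # tensor([[1, 2, 3], [4, 5, 0]]) -- chosen rows first, rejected rows after
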
src/third_party/InternVL/internvl_chat/pyproject.toml ADDED
@@ -0,0 +1,33 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "internvl_chat"
+ version = "2.0.0"
+ description = "Scaling up Vision Foundation Models and Aligning for Generic Visual-Linguistic Tasks."
+ readme = "README.md"
+ requires-python = ">=3.8"
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "License :: OSI Approved :: Apache Software License",
+ ]
+ dependencies = [
+     "torch>=2", "torchvision>=0.15",
+     "transformers==4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid",
+     "accelerate", "peft>=0.4.0", "bitsandbytes==0.41.0",
+     "pydantic", "markdown2[all]", "numpy", "scikit-learn>=1.2.2",
+     "gradio==3.35.2", "gradio_client==0.2.9",
+     "requests", "httpx==0.24.0", "uvicorn", "fastapi",
+     "deepspeed==0.13.5", "einops", "einops-exts", "timm==0.9.12",
+ ]
+
+ [project.urls]
+ "Homepage" = "https://github.com/OpenGVLab/InternVL"
+ "Bug Tracker" = "https://github.com/OpenGVLab/InternVL/issues"
+
+ [tool.setuptools.packages.find]
+ exclude = ["data*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "shell*"]
+
+ [tool.wheel]
+ exclude = ["data*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "shell*"]
src/third_party/InternVL/internvl_chat/tools/convert_to_int8.py ADDED
@@ -0,0 +1,16 @@
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ path = 'OpenGVLab/InternVL-Chat-V1-5'
+ model = AutoModel.from_pretrained(
+     path,
+     torch_dtype=torch.bfloat16,
+     low_cpu_mem_usage=True,
+     trust_remote_code=True,
+     load_in_8bit=True).eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+
+ model.save_pretrained('release/InternVL-Chat-V1-5-Int8')
+ tokenizer.save_pretrained('release/InternVL-Chat-V1-5-Int8')
+ print('finished')
src/third_party/InternVL/internvl_chat/tools/extract_mlp.py ADDED
@@ -0,0 +1,19 @@
+ import argparse
+ import os.path
+
+ import torch
+ from internvl.model.internvl_chat import InternVLChatModel
+
+ argparse = argparse.ArgumentParser()
+ argparse.add_argument('model_path', type=str, default='')
+ argparse.add_argument('output_path', type=str, default='')
+
+ args = argparse.parse_args()
+
+ model = InternVLChatModel.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
+ model = model.mlp1.to(torch.bfloat16)
+
+ ckpt = model.state_dict()
+ output_path = os.path.join(args.output_path, 'mlp_projector.pth')
+ torch.save(ckpt, output_path)
+ print('finished')
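
The mlp_projector.pth written here is the kind of checkpoint that the pretraining script loads back into model.mlp1 via model_args.mlp_path. A hedged round-trip sketch; the path and the commented-out `model` variable below are illustrative, not part of this commit:

import torch

state_dict = torch.load('work_dirs/mlp_projector.pth', map_location='cpu')  # illustrative path
print(sorted(state_dict.keys()))
# message = model.mlp1.load_state_dict(state_dict)  # assuming `model` is an InternVLChatModel
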
src/third_party/InternVL/internvl_chat/tools/extract_video_frames.py ADDED
@@ -0,0 +1,120 @@
+ import concurrent.futures
+ import json
+ import os
+
+ import av
+ import numpy as np
+ import torch
+ from decord import VideoReader, cpu
+ from PIL import Image
+ from tqdm.auto import tqdm
+
+ num_segments = 1
+
+ # root directory of evaluation dimension 10
+ dimension10_dir = './videos/20bn-something-something-v2'
+ # root directory of evaluation dimension 11
+ dimension11_dir = './videos/EPIC-KITCHENS'
+ # root directory of evaluation dimension 12
+ dimension12_dir = './videos/BreakfastII_15fps_qvga_sync'
+
+
+ def transform_video(buffer):
+     try:
+         buffer = buffer.numpy()
+     except AttributeError:
+         try:
+             buffer = buffer.asnumpy()
+         except AttributeError:
+             print('Both buffer.numpy() and buffer.asnumpy() failed.')
+             buffer = None
+     images_group = list()
+     for fid in range(len(buffer)):
+         images_group.append(Image.fromarray(buffer[fid]))
+     return images_group
+
+
+ def get_index(num_frames, num_segments):
+     if num_segments > num_frames:
+         offsets = np.array([
+             idx for idx in range(num_frames)
+         ])
+     else:
+         # uniform sampling
+         seg_size = float(num_frames - 1) / num_segments
+         start = int(seg_size / 2)
+         offsets = np.array([
+             start + int(np.round(seg_size * idx)) for idx in range(num_segments)
+         ])
+     return offsets
+
+
+ def fetch_images(qa_item):
+     use_pyav = False
+     segment = None
+     if qa_item['question_type_id'] == 10:
+         data_path = os.path.join(dimension10_dir, qa_item['data_id'])
+         start = 0.0
+         end = 0.0
+     elif qa_item['question_type_id'] == 11:
+         data_path = os.path.join(dimension11_dir, qa_item['data_id'].split('/')[-1])
+         segment = qa_item['segment']
+         start, end = segment[0], segment[1]
+     elif qa_item['question_type_id'] == 12:
+         data_path = os.path.join(dimension12_dir, qa_item['data_id'])
+         segment = qa_item['segment']
+         start, end = segment[0], segment[1]
+         use_pyav = True
+
+     if use_pyav:
+         # using pyav for decoding videos in evaluation dimension 12
+         reader = av.open(data_path)
+         frames = [torch.from_numpy(f.to_rgb().to_ndarray()) for f in reader.decode(video=0)]
+         video_len = len(frames)
+         start_frame, end_frame = start, end
+         end_frame = min(end_frame, video_len)
+         offset = get_index(end_frame - start_frame, num_segments)
+         frame_indices = offset + start_frame
+         buffer = torch.stack([frames[idx] for idx in frame_indices])
+     else:
+         # using decord for decoding videos in evaluation dimension 10-11
+         vr = VideoReader(data_path, num_threads=1, ctx=cpu(0))
+         video_len = len(vr)
+         fps = vr.get_avg_fps()
+         if segment is not None:
+             # obtain start and end frame for the video segment in evaluation dimension 11
+             start_frame = int(min(max(start * fps, 0), video_len - 1))
+             end_frame = int(min(max(end * fps, 0), video_len - 1))
+             tot_frames = int(end_frame - start_frame)
+             offset = get_index(tot_frames, num_segments)
+             frame_indices = offset + start_frame
+         else:
+             # sample frames of the video in evaluation dimension 10
+             frame_indices = get_index(video_len - 1, num_segments)
+         vr.seek(0)
+         buffer = vr.get_batch(frame_indices)
+     return transform_video(buffer)
+
+
+ def fetch_images_parallel(qa_item):
+     return qa_item, fetch_images(qa_item)
+
+
+ if __name__ == '__main__':
+     data = json.load(open('SEED-Bench.json'))
+     video_img_dir = 'SEED-Bench-video-image'
+     ques_type_id_to_name = {id: n for n, id in data['question_type'].items()}
+
+     video_data = [x for x in data['questions'] if x['data_type'] == 'video']
+
+     # the original file handle on an undefined `output` path was unused; the executor alone is sufficient
+     with concurrent.futures.ThreadPoolExecutor() as executor:
+         future_to_images = {executor.submit(fetch_images_parallel, qa_item): qa_item for qa_item in video_data}
+         for future in tqdm(concurrent.futures.as_completed(future_to_images), total=len(future_to_images)):
+             qa_item = future_to_images[future]
+             try:
+                 qa_item, images = future.result()
+             except Exception as exc:
+                 print(f'{qa_item} generated an exception: {exc}')
+             else:
+                 img_file = f"{qa_item['question_type_id']}_{qa_item['question_id']}.png"
+                 images[0].save(os.path.join(video_img_dir, img_file))
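
get_index above performs uniform temporal sampling: it splits the frame range into num_segments equal segments and picks one index near the middle of each. A quick standalone illustration (not part of this commit):

import numpy as np

def get_index(num_frames, num_segments):
    # Same uniform-sampling branch as in extract_video_frames.py.
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    return np.array([start + int(np.round(seg_size * idx)) for idx in range(num_segments)])

print(get_index(100, 8))  # [ 6 18 31 43 56 68 80 93]
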
src/third_party/InternVL/internvl_chat/tools/extract_vit.py ADDED
@@ -0,0 +1,16 @@
+ import argparse
+
+ import torch
+ from internvl.model.internvl_chat import InternVLChatModel
+
+ argparse = argparse.ArgumentParser()
+ argparse.add_argument('model_path', type=str, default='')
+ argparse.add_argument('output_path', type=str, default='')
+
+ args = argparse.parse_args()
+
+ model = InternVLChatModel.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
+ model = model.vision_model.to(torch.bfloat16)
+
+ model.save_pretrained(args.output_path)
+ print('finished')
src/third_party/InternVL/internvl_chat/tools/images_stitching.py ADDED
@@ -0,0 +1,79 @@
+ import argparse
+ import json
+ import os
+
+ from PIL import Image, ImageDraw, ImageFont
+ from tqdm import tqdm
+
+ FOOT = ImageFont.truetype('/usr/share/fonts/dejavu/DejaVuSans-Bold.ttf', 50)
+
+
+ def custom_image(img_paths, save_path, image_size=448):
+     captions = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT']
+
+     width = image_size * 2
+     height = image_size
+     # count = 0
+     all_images = {}
+     for image_id, image_files in tqdm(img_paths.items()):
+         all_images[image_id] = dict()
+         all_images[image_id]['images_path'] = image_files
+         all_images[image_id]['images_size'] = {k: (0, 0) for k in image_files.keys()}
+         imgs = {}
+         for caption, image_file in image_files.items():
+             image_path = os.path.join(args.data_root, image_file.replace('../nuscenes/samples/', '/nuscenes/samples/'))
+             img = Image.open(image_path).convert('RGB')
+             old_wide, old_height = img.size
+             all_images[image_id]['images_size'][caption] = (old_wide, old_height)
+             img = img.resize((width, height))
+
+             draw = ImageDraw.Draw(img)
+             text = caption
+             draw.text((0, 0), text, fill=(255, 0, 255), font=FOOT)
+             imgs[caption] = img
+
+         result_width = width * 3
+         result_height = height * 2
+         result_img = Image.new('RGB', (result_width, result_height))
+
+         imgs = [imgs[caption] for caption in captions]
+         for i in range(len(imgs)):
+             row = i // 3
+             col = i % 3
+
+             left = col * width
+             top = row * height
+             right = left + width
+             bottom = top + height
+             result_img.paste(imgs[i], (left, top))
+
+         result_path = os.path.join(save_path, image_id + '.jpg')
+         result_img.save(result_path)
+
+
+ def get_images(ann_file):
+     with open(ann_file, 'r') as f:  # , \
+         train_file = json.load(f)
+
+     images = {}
+     for scene_id in train_file.keys():
+         scene_data = train_file[scene_id]['key_frames']
+         for frame_id in scene_data.keys():
+             image_id = scene_id + '_' + frame_id
+             if image_id not in images:
+                 images[image_id] = scene_data[frame_id]['image_paths']
+             else:
+                 print(image_id)
+
+     return images
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--data-root', type=str, default='InternVL-Domain-Adaptation-Data/images/drivelm')
+     parser.add_argument('--ann-file', type=str, default='path/to/v1_1_val_nus_q_only.json')
+     args = parser.parse_args()
+     images = get_images(args.ann_file)
+     save_path = os.path.join(args.data_root, 'stitch')
+     os.makedirs(save_path, exist_ok=True)
+     custom_image(img_paths=images, save_path=save_path)
src/third_party/InternVL/internvl_chat/tools/json2jsonl.py ADDED
@@ -0,0 +1,20 @@
+ import argparse
+ import json
+
+ argparse = argparse.ArgumentParser()
+ argparse.add_argument('path', type=str)
+
+ args = argparse.parse_args()
+
+ assert args.path.endswith('.json')
+
+ data = json.load(open(args.path))
+ writer = open(args.path.replace('.json', '.jsonl'), 'w')
+ for idx, item in enumerate(data):
+     conversations = item['conversations']
+     if conversations[0]['from'] == 'system':
+         item['conversations'] = item['conversations'][1:]
+     item['id'] = idx
+     writer.write(json.dumps(item, ensure_ascii=False) + '\n')
+
+ writer.close()
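
json2jsonl.py turns a JSON list into JSON Lines, dropping a leading system turn and re-indexing each record. A toy before/after on a hypothetical record (not from any real dataset):

import json

item = {'conversations': [{'from': 'system', 'value': 'You are a helpful assistant.'},
                          {'from': 'human', 'value': 'Hi'},
                          {'from': 'gpt', 'value': 'Hello!'}]}
if item['conversations'][0]['from'] == 'system':
    item['conversations'] = item['conversations'][1:]  # drop the system turn
item['id'] = 0
print(json.dumps(item, ensure_ascii=False))  # only the human/gpt turns remain, plus an integer id
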
src/third_party/InternVL/internvl_chat/tools/jsonl2jsonl.py ADDED
@@ -0,0 +1,22 @@
+ import argparse
+ import json
+ import os
+
+ argparse = argparse.ArgumentParser()
+ argparse.add_argument('path', type=str)
+
+ args = argparse.parse_args()
+
+ assert args.path.endswith('.jsonl')
+
+ f = open(args.path)
+ data = [json.loads(line) for line in f.readlines()]
+ writer = open(args.path.replace('.jsonl', '_new.jsonl'), 'w')
+ for idx, item in enumerate(data):
+     item['id'] = idx
+     conversations = item['conversations']
+     if conversations[0]['from'] == 'system':
+         item['conversations'] = item['conversations'][1:]
+     writer.write(json.dumps(item, ensure_ascii=False) + '\n')
+
+ writer.close()
src/third_party/InternVL/internvl_chat/tools/merge_lora.py ADDED
@@ -0,0 +1,31 @@
+ import argparse
+
+ import torch
+ from internvl.model.internvl_chat import InternVLChatModel
+ from transformers import AutoTokenizer
+
+ argparse = argparse.ArgumentParser()
+ argparse.add_argument('input_path', type=str, help='Path to the input model')
+ argparse.add_argument('output_path', type=str, help='Path to the output model')
+ args = argparse.parse_args()
+
+ print('Loading model...')
+ model = InternVLChatModel.from_pretrained(
+     args.input_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).eval()
+ print('Loading tokenizer...')
+ tokenizer = AutoTokenizer.from_pretrained(args.input_path, trust_remote_code=True)
+
+ if model.config.use_backbone_lora:
+     model.vision_model.merge_and_unload()
+     model.vision_model = model.vision_model.model
+     model.config.use_backbone_lora = 0
+ if model.config.use_llm_lora:
+     model.language_model.merge_and_unload()
+     model.language_model = model.language_model.model
+     model.config.use_llm_lora = 0
+
+ print('Saving model...')
+ model.save_pretrained(args.output_path)
+ print('Saving tokenizer...')
+ tokenizer.save_pretrained(args.output_path)
+ print('Done!')
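
merge_and_unload folds each LoRA update back into the base weights, which is why the merged checkpoint loads without PEFT and why use_backbone_lora/use_llm_lora are reset to 0. A toy illustration of the arithmetic being folded in, W_merged = W + (alpha / r) * B @ A (illustrative only, not part of this commit):

import torch

d, r, alpha = 16, 4, 8
W = torch.randn(d, d)          # frozen base weight
A = torch.randn(r, d) * 0.01   # LoRA down-projection
B = torch.zeros(d, r)          # LoRA up-projection starts at zero
W_merged = W + (alpha / r) * (B @ A)
print(torch.allclose(W, W_merged))  # True until B is trained away from zero
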