from transformers.trainer import Trainer, TRAINING_ARGS_NAME
import torch.distributed as dist
from typing import Optional
import os
import torch
from src.loss import SimpleContrastiveLoss, DistributedContrastiveLoss, HardNegativeContrastiveLoss, DistributedHardNegativeContrastiveLoss
from itertools import repeat
from grad_cache.grad_cache import GradCache


MAX_INPUT_ID = int(1e9)
LLAVA_IMAGE_TOKEN_ID = 32000

class MMEBTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_ddp = dist.is_initialized()
        self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1

    def compute_loss(self, model, inputs, *args, **kwargs):
        # inputs arrives as a (qry, tgt) tuple, plus hard negatives when enabled
        if self.args.hard_neg:
            qry_inputs, tgt_inputs, neg_inputs = inputs
            return model(qry=qry_inputs, tgt=tgt_inputs, neg=neg_inputs)

        qry_inputs, tgt_inputs = inputs
        return model(qry=qry_inputs, tgt=tgt_inputs)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        os.makedirs(output_dir, exist_ok=True)

        if state_dict is None:
            state_dict = self.model.state_dict()
        prefix = 'encoder.'
        assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        self.model.encoder.save_pretrained(
            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
        )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))


def split_dense_inputs(model_input: dict, chunk_size: int):
    assert len(model_input) == 1
    arg_key = list(model_input.keys())[0]
    arg_val = model_input[arg_key]

    keys = list(arg_val.keys())
    chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in keys]
    chunked_arg_val = [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))]

    return [{arg_key: c} for c in chunked_arg_val]
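
# A minimal sketch of what split_dense_inputs produces (hypothetical shapes,
# for illustration only):
#
#   batch = {"qry": {"input_ids": torch.zeros(4, 8, dtype=torch.long),
#                    "attention_mask": torch.ones(4, 8, dtype=torch.long)}}
#   chunks = split_dense_inputs(batch, chunk_size=2)
#   # -> two sub-batches, each {"qry": {"input_ids": (2, 8), "attention_mask": (2, 8)}}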


def split_vlm_inputs(model_input: dict, chunk_size: int):
    assert len(model_input) == 1
    arg_key = list(model_input.keys())[0]
    arg_val = model_input[arg_key]
    keys = list(arg_val.keys())

    # input_ids and attention_mask are row-aligned, so split them directly
    chunked_tensors = [arg_val[k].split(chunk_size, dim=0) for k in ["input_ids", "attention_mask"]]

    # pixel_values, image_sizes, and other image-level fields have one entry per
    # image rather than per row, so they must be split according to which rows
    # actually contain images
    if "image_mask" in keys:
        # indices of the rows in input_ids that contain images
        row_contain_image = torch.nonzero(arg_val["image_mask"], as_tuple=False).squeeze()
        keys.remove("image_mask")
        num_chunks = len(chunked_tensors[0])
        chunk_image_count = []
        for chunk_idx in range(num_chunks):
            chunk_image_count.append(torch.sum(
                (row_contain_image >= chunk_idx * chunk_size) & (row_contain_image < (chunk_idx + 1) * chunk_size)).item())
        
        if "pixel_values" in keys:
            pixel_values = arg_val["pixel_values"]
            chunked_tensors.append(torch.split(pixel_values, chunk_image_count))
        if "image_sizes" in keys:
            image_sizes = arg_val["image_sizes"]
            chunked_tensors.append(torch.split(image_sizes, chunk_image_count))
        if "image_grid_thw" in keys:
            image_grid_thw = arg_val["image_grid_thw"]
            chunked_tensors.append(torch.split(image_grid_thw, chunk_image_count))
        
        if "image_flags" in keys:
            image_flags = arg_val["image_flags"]
            chunked_tensors.append(torch.split(image_flags, chunk_size))
            keys.remove("image_flags") 
        
    
    chunked_arg_val = []
    for kk, tt in zip(repeat(keys), zip(*chunked_tensors)):
        chunk_dict = {}
        # add the base fields first; tt follows the order of chunked_tensors,
        # so tt[2] is this chunk's pixel_values slice when images are present
        if "pixel_values" in keys and tt[2].numel() == 0:  # this chunk contains no image
            chunk_dict.update(dict(zip(kk[:2], tt[:2])))  # keep only input_ids and attention_mask
        else:
            chunk_dict.update(dict(zip(kk, tt)))

        # if image_flags is present, attach this chunk's slice (it was appended
        # last to chunked_tensors and removed from keys above)
        if "image_flags" in arg_val:
            chunk_idx = len(chunked_arg_val)
            chunk_dict["image_flags"] = chunked_tensors[-1][chunk_idx]

        chunked_arg_val.append(chunk_dict)

    return [{arg_key: c} for c in chunked_arg_val]
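
# A minimal sketch of split_vlm_inputs (hypothetical shapes; pixel_values has
# one entry per image, image_mask flags which rows carry an image):
#
#   batch = {"qry": {
#       "input_ids":      torch.zeros(4, 8, dtype=torch.long),
#       "attention_mask": torch.ones(4, 8, dtype=torch.long),
#       "pixel_values":   torch.zeros(2, 3, 336, 336),   # 2 images in the batch
#       "image_mask":     torch.tensor([1, 0, 0, 1]),    # rows 0 and 3 have images
#   }}
#   chunks = split_vlm_inputs(batch, chunk_size=2)
#   # -> chunks[0]["qry"]["pixel_values"].shape == (1, 3, 336, 336)  (row 0's image)
#   #    chunks[1]["qry"]["pixel_values"].shape == (1, 3, 336, 336)  (row 3's image)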


def get_dense_rep(x):
    """
    Get either qry_reps or tgt_reps.
    """
    if x["qry_reps"] is None:
        return x["tgt_reps"]
    else:
        return x["qry_reps"]


class GradCacheTrainer(Trainer):
    """
    Adapted from gradcache repo.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_ddp = dist.is_initialized()
        self._dist_loss_scale_factor = dist.get_world_size() if self.is_ddp else 1
        # use the hard-negative-aware losses instead of the plain contrastive ones:
        # loss_fn_cls = DistributedContrastiveLoss if self.is_ddp else SimpleContrastiveLoss
        loss_fn_cls = DistributedHardNegativeContrastiveLoss if self.is_ddp else HardNegativeContrastiveLoss
        loss_fn = loss_fn_cls(temperature=self.model.temperature)

        self.gc = GradCache(
            models=[self.model, self.model],
            chunk_sizes=[self.args.gc_q_chunk_size, self.args.gc_p_chunk_size],
            loss_fn=loss_fn,
            split_input_fn=split_vlm_inputs,
            get_rep_fn=get_dense_rep,
            fp16=self.args.fp16,
            scaler=self.scaler if self.args.fp16 else None
        )

    def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor:
        model.train()
        
        if self.args.hard_neg:
            queries, passages, negatives = inputs
            queries, passages, negatives = {'qry': queries}, {'tgt': passages}, {'neg': negatives}
            
            if self.args.local_rank == 0:
                print(f"qry.shape={queries['qry']['input_ids'].shape}")
                print(f"tgt.shape={passages['tgt']['input_ids'].shape}")
                print(f"neg.shape={negatives['neg']['input_ids'].shape}")
                if 'pixel_values' in queries['qry']:
                    print(f"qry_img.shape={queries['qry']['pixel_values'].shape}")
                if 'pixel_values' in passages['tgt']:
                    print(f"tgt_img.shape={passages['tgt']['pixel_values'].shape}")
                if 'pixel_values' in negatives['neg']:
                    print(f"neg_img.shape={negatives['neg']['pixel_values'].shape}")
            
            _distributed = self.args.local_rank > -1
            # three input streams (query / target / negative) need three model entries
            self.gc.models = [model, model, model]
            loss = self.gc(queries, passages, negatives, no_sync_except_last=_distributed)
        else:
            queries, passages = inputs
            queries, passages = {'qry': queries}, {'tgt': passages}
            
            if self.args.local_rank == 0:
                print(f"qry.shape={queries['qry']['input_ids'].shape}")
                print(f"tgt.shape={passages['tgt']['input_ids'].shape}")
                if 'pixel_values' in queries['qry']:
                    print(f"qry_img.shape={queries['qry']['pixel_values'].shape}")
                if 'pixel_values' in passages['tgt']:
                    print(f"tgt_img.shape={passages['tgt']['pixel_values'].shape}")
            
            _distributed = self.args.local_rank > -1
            self.gc.models = [model, model]
            loss = self.gc(queries, passages, no_sync_except_last=_distributed)

        return loss / self._dist_loss_scale_factor

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        print(f"Saving model to {output_dir}")
        os.makedirs(output_dir, exist_ok=True)

        if state_dict is None:
            state_dict = self.model.state_dict()
        prefix = 'encoder.'
        assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        self.model.encoder.save_pretrained(
            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
        )

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        self.model.encoder.config.to_json_file(os.path.join(output_dir, 'config.json'))
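

if __name__ == "__main__":
    # Smoke-test sketch for the split helper (hypothetical shapes; assumes the
    # src.loss / grad_cache imports above resolve). Only row 0 of a 4-row batch
    # carries an image, so the second chunk exercises the image-less path.
    batch = {"qry": {
        "input_ids": torch.zeros(4, 8, dtype=torch.long),
        "attention_mask": torch.ones(4, 8, dtype=torch.long),
        "pixel_values": torch.zeros(1, 3, 336, 336),
        "image_mask": torch.tensor([1, 0, 0, 0]),
    }}
    chunks = split_vlm_inputs(batch, chunk_size=2)
    assert chunks[0]["qry"]["pixel_values"].shape[0] == 1
    # a chunk with no images has its image fields dropped entirely
    assert "pixel_values" not in chunks[1]["qry"]
    print("split_vlm_inputs smoke test passed")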