# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import bisect
import copy
import logging
from collections import defaultdict
from typing import List, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, get_worker_info
from transformers.trainer_pt_utils import LabelSmoother

from .constants import IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN

IGNORE_TOKEN_ID = LabelSmoother.ignore_index
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def is_dist_avail_and_initialized():
    """Return True when torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the distributed world size, or 1 outside a process group."""
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """Return the distributed rank, or 0 outside a process group."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


class PackedDataset(IterableDataset):
    """Iterable dataset that packs samples from several datasets together.

    Samples are drawn from the wrapped datasets according to
    ``dataset_weight`` and concatenated into buffers holding at most
    ``num_images_expected`` images and ``max_packed_tokens`` tokens.
    """

    def __init__(
        self,
        tokenizer,
        data_rank,
        data_world_size,
        datasets: List,
        dataset_weight: List[int] = None,
        num_images_expected: int = 6,
        max_packed_tokens: int = 32768,
        max_buffer_size: int = 100,
        log_freq: int = 1000000,
        strict_mode: bool = False,
        debug_mode: bool = False,
        replacement: bool = True,
        allow_overflow: bool = True,
        allow_empty_data: bool = False,
        allow_deduplicated_ds_name: bool = False,
    ):
        """Store the packing configuration and clamp per-dataset limits.

        Each dataset's ``max_num_images`` / ``max_tokens`` is lowered to
        ``num_images_expected`` / ``max_packed_tokens`` when it exceeds them.
        Unless ``allow_deduplicated_ds_name`` is set, dataset names must be
        unique because they are used as keys of the internal state dict.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.data_rank = data_rank
        self.data_world_size = data_world_size
        self.datasets = datasets
        self.num_images_expected = num_images_expected
        self.max_buffer_size = max_buffer_size
        self.log_freq = log_freq
        self.strict_mode = strict_mode
        self.debug_mode = debug_mode
        self.replacement = replacement
        self.allow_overflow = allow_overflow
        self.allow_empty_data = allow_empty_data

        self.max_packed_tokens = max_packed_tokens

        self.img_start_token_id = self.tokenizer.convert_tokens_to_ids(IMG_START_TOKEN)
        self.img_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)

        # The image marker tokens must be known to the tokenizer.
        assert self.img_start_token_id != self.tokenizer.unk_token_id
        assert self.img_token_id != self.tokenizer.unk_token_id
        assert self.img_end_token_id != self.tokenizer.unk_token_id

        if dataset_weight is None:
            dataset_weight = [1] * len(datasets)
        self.dataset_type = [d.dataset_type for d in self.datasets]

        # Keep the original lists untouched: the working copies shrink when a
        # dataset is exhausted and are rebuilt at the start of every __iter__.
        total_weight = sum(dataset_weight)
        self.datasets_orig = datasets
        self.dataset_weight_orig = [w / total_weight for w in dataset_weight]

        self.datasets = list(self.datasets_orig)
        self.dataset_weight = list(self.dataset_weight_orig)

        # lazy init
        self.worker_id = None
        self.worker_state_key = None
        self.dataset_iter_list = None
        self._state_dict = {
            'sample_info': {d.ds_name: 0 for d in self.datasets},
        }
        self.worker_custom_infos = None

        ds_name_list = [d.ds_name for d in self.datasets]
        if not allow_deduplicated_ds_name:
            assert len(ds_name_list) == len(set(ds_name_list)), f'deduplicated ds_name: {ds_name_list}'

        for ds in self.datasets:
            if ds.max_num_images > self.num_images_expected:
                logger.warning(f'{ds.max_num_images=} of {ds.ds_name} is larger than {self.num_images_expected=}')
                ds.max_num_images = num_images_expected

            if ds.max_tokens > self.max_packed_tokens:
                logger.warning(f'{ds.max_tokens=} of {ds.ds_name} is larger than {self.max_packed_tokens=}')
                ds.max_tokens = self.max_packed_tokens

            self._state_dict[ds.ds_name] = {}

        if get_rank() == 0:
            logger.info(
                f'Loaded dataset to pack: {ds_name_list}, '
                f'{self.num_images_expected=}, {self.max_packed_tokens=}, '
                f'{self.replacement=}, {self.allow_overflow=}',
            )

            temp = '\n'.join(
                f'{ds.ds_name:<25}: {ds_w*100:.2f}%'
                for ds, ds_w in zip(self.datasets, self.dataset_weight)
            )
            logger.info(
                f'Sampling prob for each dataset:\n{temp}'
            )

        if self.allow_empty_data:
            logger.warning('allow_empty_data is enabled, note that empty data may be generated!')

    def load_state_dict(self, state_dict, custom_infos=None):
        """Merge ``state_dict`` into the internal state and push it to datasets.

        Args:
            state_dict: mapping merged into the internal state dict.
            custom_infos: optional per-worker infos consumed by ``__iter__``.
        """
        self.worker_custom_infos = custom_infos

        self._state_dict.update(state_dict)
        for ds in self.datasets:
            # __init__ registers every ds_name in self._state_dict, so the
            # lookup below always succeeds.
            ds.load_state_dict(self._state_dict[ds.ds_name])
            # Report "resumed" only when the incoming state actually carried
            # an entry for this dataset.
            if ds.ds_name in state_dict:
                logger.info(f'{ds.ds_name=} is resumed.')
            else:
                logger.warning(f'{ds.ds_name=} is not resumed.')

    def _should_log(self):
        """Return True only for worker 0 of distributed rank 0."""
        worker_info = get_worker_info()
        if worker_info is None:
            local_worker_id, num_workers = 0, 1
        else:
            local_worker_id, num_workers = worker_info.id, worker_info.num_workers
        return num_workers * get_rank() + local_worker_id == 0

    def _drop_dataset(self, dataset_idx):
        """Remove an exhausted dataset and return a random replacement index.

        Raises:
            StopIteration: when no dataset is left after the removal.
        """
        self.datasets.pop(dataset_idx)
        self.dataset_iter_list.pop(dataset_idx)
        self.dataset_weight.pop(dataset_idx)

        if len(self.datasets) == 0:
            raise StopIteration
        return np.random.choice(len(self.datasets))

    def next_data(self, current_dataset_idx):
        """Fetch the next sample, starting from dataset ``current_dataset_idx``.

        An exhausted dataset is restarted when ``self.replacement`` is set;
        if it cannot be restarted (or replacement is off) it is dropped and
        another dataset is picked at random. The returned sample gets a
        ``type_ids`` tensor filled with the index of the dataset it came from,
        and its ``meta_info`` (if any) is moved into the state dict.

        Raises:
            StopIteration: when every dataset has been dropped.
        """
        while True:
            try:
                current_sample = next(self.dataset_iter_list[current_dataset_idx])
                break  # Exit loop if successful
            except StopIteration:
                if self.replacement:
                    try:
                        self.dataset_iter_list[current_dataset_idx] = iter(self.datasets[current_dataset_idx])
                        current_sample = next(self.dataset_iter_list[current_dataset_idx])
                        break
                    except Exception:
                        # The dataset yields nothing even after a restart.
                        current_dataset_idx = self._drop_dataset(current_dataset_idx)
                else:
                    current_dataset_idx = self._drop_dataset(current_dataset_idx)
            except Exception:
                # Best effort: keep the traceback and retry with another dataset.
                logger.exception('Unexpected error!')
                if len(self.datasets) == 0:
                    raise StopIteration
                current_dataset_idx = np.random.choice(len(self.datasets))

        current_ds_name = self.datasets[current_dataset_idx].ds_name
        current_sample['type_ids'] = torch.zeros_like(current_sample['input_ids']) + current_dataset_idx

        if self.worker_state_key not in self._state_dict[current_ds_name]:
            self._state_dict[current_ds_name][self.worker_state_key] = {}

        meta_info = current_sample.pop('meta_info', {})
        self._state_dict[current_ds_name][self.worker_state_key].update(**meta_info)
        self._state_dict['sample_info'][current_ds_name] += 1
        return current_sample

    def find_buffer(self, buffer_list, new_sample):
        """Pop and return a buffer that ``new_sample`` can be merged into.

        The first buffer satisfying both the image and the token budget wins.
        When ``allow_overflow`` is set and the buffer list holds at least
        ``max_buffer_size // 2`` entries, the last buffer satisfying only the
        image budget is used as a fallback. Returns None if nothing matches.
        """
        # NOTE: use `bisect` to search might be faster

        find = False
        find_idx = -1
        num_images_current = new_sample['pixel_values'].size(0)
        for buffer_idx, buffer in enumerate(buffer_list):
            num_images_buffer = buffer['pixel_values'].size(0)
            if num_images_buffer + num_images_current <= self.num_images_expected:
                num_merged_tokens = new_sample['input_ids'].size(0) + buffer['input_ids'].size(0)

                if num_merged_tokens <= self.max_packed_tokens:
                    find = True
                    find_idx = buffer_idx
                    break

                # No break here: keep scanning for a buffer that fits fully.
                if self.allow_overflow and len(buffer_list) >= self.max_buffer_size // 2:
                    find = True
                    find_idx = buffer_idx

        if find:
            return buffer_list.pop(find_idx)
        return None

    def update_buffer(self, buffer, new_sample):
        """Append ``new_sample`` to ``buffer`` and return the merged buffer.

        ``data_index`` marks which packed sample each token belongs to: 0 for
        the first sample of a buffer, previous last index + 1 afterwards.
        When ``buffer`` is None the sample itself becomes the buffer.
        """
        if buffer is None:
            new_sample['data_index'] = torch.zeros_like(new_sample['input_ids'])
            return new_sample

        new_sample['data_index'] = torch.ones_like(new_sample['input_ids']) + buffer['data_index'][-1].item()

        assert buffer.keys() == new_sample.keys()
        for k in buffer:
            buffer[k] = torch.cat([buffer[k], new_sample[k]])
        return buffer

    @staticmethod
    def check_valid(sample_to_check, min_active_tokens_ratio=1/256):
        """Return whether the share of non-ignored labels exceeds the ratio."""
        num_ignore_tokens = (sample_to_check['labels'] == IGNORE_TOKEN_ID).sum()
        num_tokens = sample_to_check['labels'].numel()
        return (1 - num_ignore_tokens / num_tokens) > min_active_tokens_ratio

    @staticmethod
    def split_buffer(buffer, max_tokens, img_start_token_id, img_token_id, img_end_token_id):
        """Split a buffer longer than ``max_tokens`` into smaller buffers.

        Cuts never fall inside an image span (start/context/end tokens).
        Pieces without images or failing ``check_valid`` are discarded.

        Returns:
            List of buffers; ``[buffer]`` when no split is needed.
        """
        if buffer['input_ids'].size(0) <= max_tokens:
            return [buffer]

        def _image_is_splitted(input_ids, cut_idx):
            # True when the token at the cut position belongs to an image span.
            is_image_start = input_ids[cut_idx].item() == img_start_token_id
            is_image_token = input_ids[cut_idx].item() == img_token_id
            is_image_end = input_ids[cut_idx].item() == img_end_token_id
            return is_image_start or is_image_token or is_image_end

        def _split(sample_to_split, left_idx, right_idx, left_img_idx, right_img_idx):
            # Token-level keys are sliced by token index, image-level keys by
            # image index; the right part is omitted when right_idx is None.
            assert (right_idx is None) == (right_img_idx is None)

            left_sample = {}
            right_sample = {} if right_idx is not None else None
            for k in sample_to_split:
                if k in ['input_ids', 'labels', 'attention_mask', 'position_ids', 'data_index', 'type_ids']:
                    left_sample[k] = sample_to_split[k][:left_idx]
                    if right_sample is not None:
                        right_sample[k] = sample_to_split[k][right_idx:]
                elif k in ['pixel_values', 'image_flags']:
                    left_sample[k] = sample_to_split[k][:left_img_idx]
                    if right_sample is not None:
                        right_sample[k] = sample_to_split[k][right_img_idx:]
                else:
                    raise NotImplementedError(f'find unsupported keys: {k} from {sample_to_split.keys()}')
            return left_sample, right_sample

        splitted_buffer = []
        while buffer['input_ids'].size(0) > max_tokens:
            img_start_idx_list = (buffer['input_ids'] == img_start_token_id).nonzero().squeeze(1).tolist()
            img_end_idx_list = (buffer['input_ids'] == img_end_token_id).nonzero().squeeze(1).tolist()
            assert len(img_start_idx_list) == len(img_end_idx_list)

            if _image_is_splitted(buffer['input_ids'], max_tokens):
                # The cut would land inside an image: move it back to the
                # start of that image so the image goes to the right part.
                cut_idx = bisect.bisect_left(img_start_idx_list, max_tokens)
                if buffer['input_ids'][max_tokens] == img_start_token_id:
                    assert max_tokens == img_start_idx_list[cut_idx]
                    cut_left_idx = img_start_idx_list[cut_idx]
                    cut_left_img_idx = cut_idx
                else:
                    cut_left_idx = img_start_idx_list[cut_idx - 1]
                    cut_left_img_idx = cut_idx - 1
                cut_right_idx = cut_left_idx
                cut_right_img_idx = cut_left_img_idx
            else:
                # The cut lands in text: the left part ends at max_tokens and
                # the right part starts at the next image (if there is one).
                cut_img_idx = bisect.bisect(img_start_idx_list, max_tokens)
                if cut_img_idx < len(img_start_idx_list):
                    cut_right_idx = img_start_idx_list[cut_img_idx]
                    cut_right_img_idx = cut_img_idx
                else:
                    cut_right_idx = None
                    cut_right_img_idx = None

                cut_left_idx = max_tokens
                cut_left_img_idx = cut_right_img_idx if cut_right_img_idx is not None else buffer['pixel_values'].size(0)

            left, right = _split(
                sample_to_split=buffer,
                left_idx=cut_left_idx,
                left_img_idx=cut_left_img_idx,
                right_idx=cut_right_idx,
                right_img_idx=cut_right_img_idx,
            )

            assert (left['input_ids'] == img_end_token_id).sum() == (left['input_ids'] == img_start_token_id).sum() == left['pixel_values'].size(0)
            if right is not None:
                assert (right['input_ids'] == img_end_token_id).sum() == (right['input_ids'] == img_start_token_id).sum() == right['pixel_values'].size(0)

            if left['pixel_values'].size(0) >= 1 and PackedDataset.check_valid(left):
                splitted_buffer.append(left)

            if right is None or right['pixel_values'].size(0) == 0:
                break

            buffer = right
            if buffer['input_ids'].size(0) <= max_tokens and PackedDataset.check_valid(buffer):
                splitted_buffer.append(buffer)
                break

        logger.debug(
            f'split a sample into {len(splitted_buffer)} samples, '
            f'current max_tokens={max_tokens}'
        )
        return splitted_buffer

    def update_buffer_list(self, buffer_list, buffer_max_len_list, buffer):
        """Split ``buffer`` if needed and file the pieces into the two lists.

        Pieces with exactly ``max_packed_tokens`` tokens go to
        ``buffer_max_len_list``; the others are inserted into ``buffer_list``,
        which is kept sorted by image count in descending order. Pieces with
        more than ``num_images_expected`` images are dropped with an error log.
        Both lists are modified in place and also returned.
        """
        # NOTE: in-place operation

        splitted_buffer = PackedDataset.split_buffer(
            buffer=buffer,
            max_tokens=self.max_packed_tokens,
            img_start_token_id=self.img_start_token_id,
            img_token_id=self.img_token_id,
            img_end_token_id=self.img_end_token_id,
        )

        for each_buffer in splitted_buffer:
            if each_buffer['pixel_values'].size(0) > self.num_images_expected:
                logger.error(
                    f"Find a sample with {each_buffer['pixel_values'].size(0)} images, "
                    f'which exceeds {self.num_images_expected}'
                )
                continue

            if each_buffer['input_ids'].size(0) >= self.max_packed_tokens:
                assert each_buffer['input_ids'].size(0) == self.max_packed_tokens
                buffer_max_len_list.append(each_buffer)
                continue

            # Insert before the first buffer holding fewer images.
            find_idx = len(buffer_list)
            num_images_new_sample = each_buffer['pixel_values'].size(0)
            for buffer_idx, existing in enumerate(buffer_list):
                if existing['pixel_values'].size(0) < num_images_new_sample:
                    find_idx = buffer_idx
                    break
            buffer_list.insert(find_idx, each_buffer)

        for i in range(1, len(buffer_list)):
            assert buffer_list[i-1]['pixel_values'].size(0) >= buffer_list[i]['pixel_values'].size(0)

        return buffer_list, buffer_max_len_list

    def pad_buffer(self, buffer):
        """Pad ``buffer`` with zero images up to ``num_images_expected``.

        Padded images get ``image_flags`` of 0. The buffer is updated in
        place and returned.
        """
        pixel_values = buffer['pixel_values']
        if pixel_values.size(0) == self.num_images_expected:
            return buffer

        num_pad_images = self.num_images_expected - pixel_values.size(0)
        # Build the padding from the trailing shape so that a buffer holding
        # zero images can be padded too (indexing element 0 would fail).
        pad_images = pixel_values.new_zeros((num_pad_images, *pixel_values.shape[1:]))
        pad_image_flags = torch.zeros(num_pad_images, dtype=torch.long)

        buffer['pixel_values'] = torch.cat([pixel_values, pad_images])
        buffer['image_flags'] = torch.cat([buffer['image_flags'], pad_image_flags])

        return buffer

    def postprocess_buffer(self, buffer, custom_infos=None):
        """Attach the worker state (and a deep copy of ``custom_infos``)."""
        buffer['worker_state_key'] = self.worker_state_key
        buffer['worker_state_dict'] = self._state_dict
        if custom_infos is not None:
            buffer['custom_infos'] = {self.worker_state_key: copy.deepcopy(custom_infos)}
        return buffer

    def print_log(self, iter_idx, buffer_list):
        """Log sampling statistics every ``log_freq`` iterations."""
        if iter_idx % self.log_freq != 0:
            return

        if self._should_log():
            logger.info(
                f"{iter_idx=}, {len(buffer_list)=}, {self._state_dict['sample_info']}"
            )

    def __iter__(self):
        """Yield packed buffers until every dataset is exhausted.

        The per-worker dataset state is reset on each call. Once no more data
        can be fetched, the remaining buffers are flushed (padded when
        ``strict_mode`` is set) and the generator returns.
        """
        iter_idx = 0
        buffer_list = []
        buffer_max_len_list = []

        if self._should_log():
            logger.info(f'Begin to iter, {len(buffer_list)=}')

        worker_info = get_worker_info()
        if worker_info is None:
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers

        # Global worker id / count across all data-parallel ranks.
        worker_id = num_workers * self.data_rank + worker_id
        num_workers = num_workers * self.data_world_size

        rng = np.random.default_rng(seed=worker_id)

        # reset states of each dataset
        self.worker_id = worker_id
        self.worker_state_key = f'work_state_{self.worker_id}'
        self.datasets = list(self.datasets_orig)
        self.dataset_weight = list(self.dataset_weight_orig)
        self.dataset_iter_list = [iter(d) for d in self.datasets]

        for ds in self.datasets:
            ds.worker_id = worker_id
            ds.worker_state_key = f'work_state_{self.worker_id}'
            ds.num_workers = num_workers
            if self._should_log() and worker_id == 0:
                logger.info(f'set worker_id and num_workers of {ds.__class__.__name__} {ds.ds_name}')

        if self.worker_custom_infos is not None and self.worker_state_key in self.worker_custom_infos:
            custom_infos = self.worker_custom_infos[self.worker_state_key]
            # buffer list
            if 'buffer_list' in custom_infos and isinstance(custom_infos['buffer_list'], list):
                buffer_list = custom_infos['buffer_list']
                if self._should_log() and worker_id == 0:
                    logger.info(f'[{self.worker_state_key}] load buffer list --> {len(buffer_list)=}')
            # other infos

            # reset
            self.worker_custom_infos = None

        logger.debug(
            f'{self.__class__.__name__} Rank {self.data_rank} '
            f'Worker {worker_id} begin to load data'
        )

        while True:
            # Re-normalize: next_data may have dropped exhausted datasets.
            weight_sum = sum(self.dataset_weight)
            self.dataset_weight = [w / weight_sum for w in self.dataset_weight]
            current_dataset_idx = rng.choice(len(self.dataset_iter_list), p=self.dataset_weight)

            try:
                current_sample = self.next_data(current_dataset_idx)
            except Exception as err:
                # StopIteration is the regular "no data left" signal. Any
                # other error still ends the stream (best effort), but is
                # logged instead of being swallowed silently.
                if not isinstance(err, StopIteration):
                    logger.exception('Failed to fetch the next sample')
                logger.info(f'All datasets are exhausted, begin to empty the buffer_list ({len(buffer_list)=})')
                while len(buffer_list) > 0:
                    remaining = buffer_list.pop(0)
                    if self.strict_mode:
                        remaining = self.pad_buffer(remaining)
                    yield self.postprocess_buffer(remaining)
                logger.info(f'buffer_list is empty! ({len(buffer_list)=})')
                return

            buffer = self.find_buffer(buffer_list, current_sample)
            buffer = self.update_buffer(buffer, current_sample)
            buffer_list, buffer_max_len_list = self.update_buffer_list(buffer_list, buffer_max_len_list, buffer)

            # Buffers that reached the token budget are emitted right away.
            while len(buffer_max_len_list) > 0:
                num_images = buffer_max_len_list[0]['pixel_values'].size(0)
                # Compare the image count against the image budget (it used to
                # be compared against the token budget, which never matches).
                if num_images != self.num_images_expected:
                    logger.debug(
                        f'num tokens of a buffer exceed {self.max_packed_tokens=}, '
                        f"yield a sample with {buffer_max_len_list[0]['pixel_values'].size(0)} images"
                    )
                full_buffer = buffer_max_len_list.pop(0)
                if self.strict_mode and num_images != self.num_images_expected:
                    full_buffer = self.pad_buffer(full_buffer)
                yield self.postprocess_buffer(full_buffer, {'buffer_list': buffer_list})

            # Buffers holding too many images are discarded.
            while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) > self.num_images_expected:
                logger.error(
                    f"num images of a buffer ({buffer_list[0]['pixel_values'].size(0)}) "
                    f'is larger than num_images_expected({self.num_images_expected})'
                )
                buffer_list.pop(0)

            # Buffers that reached the image budget are emitted.
            while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) == self.num_images_expected:
                if self.debug_mode:
                    # Debug mode repeats the first complete buffer forever.
                    debug_data = self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})
                    while True:
                        yield debug_data.copy()

                yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})

            # Keep the number of pending buffers bounded.
            while len(buffer_list) > self.max_buffer_size:
                logger.debug(
                    f'Failed to pack data to exactly {self.num_images_expected} images, '
                    f"yield a data sample with {buffer_list[0]['pixel_values'].size(0)} images."
                )
                partial_buffer = buffer_list.pop(0)
                if self.strict_mode:
                    partial_buffer = self.pad_buffer(partial_buffer)
                yield self.postprocess_buffer(partial_buffer, {'buffer_list': buffer_list})

            self.print_log(iter_idx=iter_idx, buffer_list=buffer_list)
            iter_idx += 1

    @staticmethod
    def get_cu_seqlens_and_indexes(
        data_index: torch.LongTensor,  # (seq_len,)
        input_ids: torch.LongTensor,   # (seq_len,)
        labels: torch.LongTensor,      # (seq_len,)
        len2weight: callable,
    ):
        """Derive per-sample boundaries and loss weights of a packed sequence.

        Args:
            data_index: id of the packed sample each token belongs to; ids
                must be contiguous and appear in ascending order.
            input_ids: token ids (not used by the computation).
            labels: labels, used to count the non-ignored tokens per sample.
            len2weight: maps a sample's number of non-ignored tokens to the
                weight assigned to each of its tokens.

        Returns:
            Tuple ``(cu_seqlens, indexes, loss_weight)``: cumulative sample
            lengths starting with 0, per-token position inside its sample,
            and a float32 tensor of per-token loss weights.
        """
        indexes = []
        cu_seqlens = [0]
        loss_weight = []

        start = int(data_index.min())
        end = int(data_index.max()) + 1
        for i in range(start, end):
            num_tokens = (data_index == i).sum().item()
            indexes.extend(range(num_tokens))
            cu_seqlens.append(cu_seqlens[-1] + num_tokens)
            assert num_tokens > 0

            seq_begin, seq_end = cu_seqlens[-2], cu_seqlens[-2] + num_tokens
            curr_data_index = data_index[seq_begin:seq_end]
            assert (curr_data_index == i).all(), data_index

            curr_labels = labels[seq_begin:seq_end]
            num_effective_tokens = (curr_labels != IGNORE_TOKEN_ID).sum().item()
            loss_weight.extend([len2weight(num_effective_tokens)] * num_tokens)

        assert len(indexes) == data_index.size(0), f'{len(indexes)=}, {data_index.size(0)=}'

        loss_weight = torch.tensor(loss_weight, dtype=torch.float32)
        return cu_seqlens, indexes, loss_weight


WARNING_CNT = defaultdict(int)


def packed_collate_fn(
    features,
    data_collator,
    len2weight: callable,
    max_item_length: int,
    micro_num: int = 1,
    loss_reduction_all_gather: bool = False,
    pad_id: int = 0,
):
    """Collate packed buffers into a batch of exactly ``micro_num`` items.

    Missing items are filled with copies of the first feature whose labels
    are all ignored. In the returned batch ``attention_mask`` holds the
    stacked cumulative sequence lengths, ``loss_weight`` is a nested list and
    ``statistics`` is ``[num_samples, num_padding_tokens, image_flags.numel()
    - image_flags.sum()]``.

    Raises:
        NotImplementedError: when more than ``micro_num`` features are given.
    """
    if not isinstance(features, list):
        features = [features]

    if len(features) > micro_num:
        raise NotImplementedError(f'{len(features)=} > {micro_num=}')

    if len(features) < micro_num and WARNING_CNT['micro_num_warning'] < 5:
        logger.warning(
            f'{len(features)=} < {micro_num=}, '
            f'the features will be padded to satisfy micro_num requirement'
        )
        WARNING_CNT['micro_num_warning'] += 1

    # ensure that the len(features) is equal to the required micro_num
    num_features = len(features)
    while len(features) < micro_num:
        features.append(copy.deepcopy(features[0]))
        features[-1]['labels'] = torch.full_like(features[-1]['labels'], IGNORE_TOKEN_ID)

    cu_seqlens = []

    batch_lens = [feat['input_ids'].shape for feat in features]
    max_item_length = max_item_length or max(batch_lens)[0]

    num_samples = 0
    num_padding_tokens = 0
    for feat_idx, feat in enumerate(features):
        data_index = feat.pop('data_index')
        curr_cu_seqlens, _, curr_loss_weight = PackedDataset.get_cu_seqlens_and_indexes(
            data_index=data_index,
            input_ids=feat['input_ids'],
            labels=feat['labels'],
            len2weight=len2weight,
        )

        feat['loss_weight'] = curr_loss_weight

        # Only real features count as samples, not the fill-in copies.
        if feat_idx < num_features:
            num_samples += len(curr_cu_seqlens) - 1

        # The padded tail is treated as one extra sequence.
        if curr_cu_seqlens[-1] < max_item_length:
            curr_cu_seqlens.append(max_item_length)

        cu_seqlens.append(torch.tensor(curr_cu_seqlens, dtype=torch.int32))

        # Remove the worker bookkeeping entries before collating.
        feat.pop('worker_state_key')
        feat.pop('worker_state_dict')
        feat.pop('custom_infos', None)

        num_padding_tokens += (max_item_length - feat['input_ids'].size(0))

    batch = data_collator(features=features, max_item_length=max_item_length, pad_id=pad_id)
    # convert it to list in case it is converted into bf16
    batch['loss_weight'] = torch.where(batch['labels'] == IGNORE_TOKEN_ID, 0, batch['loss_weight']).tolist()
    batch['attention_mask'] = torch.stack(cu_seqlens)
    batch['loss_reduction_all_gather'] = loss_reduction_all_gather
    batch['statistics'] = torch.tensor(
        [
            num_samples,
            num_padding_tokens,
            batch['image_flags'].numel() - batch['image_flags'].sum().item(),
        ],
        dtype=torch.long,
    )
    batch.pop('type_ids')
    return batch