diff --git a/.gitignore b/.gitignore
index d5e4f15f8c02ac458b0019ffbb99d969fdc719db..4dd603546f0981218010459c7b5ae63b865d352e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,4 @@
-**/__pycache__
\ No newline at end of file
+**/__pycache__
+runs/
+.vscode/
+
diff --git a/README.md b/README.md
index a0d8744e3f90e263b8b55461fc8678ac3ceb9cc4..d14e548be9b335ad261877d599e4179db058d1b8 100644
--- a/README.md
+++ b/README.md
@@ -47,6 +47,83 @@ For more details, please refer to the [paper](https://arxiv.org/abs/2308.00692).
 ```
 pip install -r requirements.txt
 ```
+
+## Training
+### Training Data Preparation
+The training data consists of 4 types of data:
+
+1. Semantic segmentation datasets: [ADE20K](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip), [COCO-Stuff](https://github.com/nightrome/cocostuff#downloads), [Mapillary](https://www.mapillary.com/dataset/vistas), [PACO-LVIS](https://github.com/facebookresearch/paco/tree/main#dataset-setup), [PASCAL-Part](http://roozbehm.info/pascal-parts/pascal-parts.html)
+
+2. Referring segmentation datasets: refCOCO, refCOCO+, refCOCOg [\[Download\]](https://github.com/lichengunc/refer#download)
+
+3. Visual Question Answering dataset: [LLaVA-Instruct-150k](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_instruct_150k.json)
+
+4. Reasoning segmentation dataset: [ReasonSeg](https://github.com/dvlab-research/LISA#dataset)
+
+Download them from the above links, and organize them as follows.
+
+```
+├── dataset
+│   ├── ade20k
+│   │   ├── annotations
+│   │   └── images
+│   ├── coco
+│   │   └── train2017
+│   ├── cocostuff
+│   │   ├── annotations
+│   │   └── train2017
+│   ├── llava_dataset
+│   │   └── llava_instruct_150k.json
+│   ├── mapillary
+│   │   ├── config_v2.0.json
+│   │   ├── testing
+│   │   ├── training
+│   │   └── validation
+│   ├── reason_seg
+│   │   └── ReasonSeg
+│   │       ├── train
+│   │       ├── val
+│   │       └── explanatory
+│   ├── refer_seg
+│   │   ├── images
+│   │   │   ├── saiapr_tc-12
+│   │   │   └── mscoco
+│   │   │       └── images
+│   │   │           └── train2014
+│   │   ├── refclef
+│   │   ├── refcoco
+│   │   ├── refcoco+
+│   │   └── refcocog
+│   └── vlpart
+│       ├── paco
+│       │   └── annotations
+│       └── pascal_part
+│           ├── train.json
+│           └── VOCdevkit
+```
+
+### Pre-trained weights
+
+#### LLaVA
+To train LISA-7B or 13B, you need to follow the [instruction](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md) to merge the LLaVA delta weights. Typically, we use the final weights `LLaVA-Lightning-7B-v1-1` and `LLaVA-13B-v1-1` merged from `liuhaotian/LLaVA-Lightning-7B-delta-v1-1` and `liuhaotian/LLaVA-13b-delta-v1-1`, respectively. For Llama2, we can directly use the LLaVA full weights `liuhaotian/llava-llama-2-13b-chat-lightning-preview`.
+
+#### SAM ViT-H weights
+Download SAM ViT-H pre-trained weights from the [link](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth).
+
+### Training
+```
+deepspeed --master_port=24999 train_ds.py --version="PATH_TO_LLaVA_Weights" --dataset_dir='./dataset' --vision_pretrained="PATH_TO_SAM_Weights" --exp_name="lisa-7b"
+```
+When training is finished, to get the full model weight:
+```
+cd ./runs/lisa-7b/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin
+```
+
+### Validation
+```
+deepspeed --master_port=24999 train_ds.py --version="PATH_TO_LLaVA_Weights" --dataset_dir='./dataset' --vision_pretrained="PATH_TO_SAM_Weights" --exp_name="lisa-7b" --weight='PATH_TO_pytorch_model.bin' --eval_only
+```
+
 ## Inference
 To chat with [LISA-13B-llama2-v0](https://huggingface.co/xinlai/LISA-13B-llama2-v0) or [LISA-13B-llama2-v0-explainatory](https://huggingface.co/xinlai/LISA-13B-llama2-v0-explainatory): (Note that LISA-13B-llama2-v0 currently does not support explanatory answers.)
@@ -93,9 +170,9 @@ Important keys contained in JSON files:
 The elements of the "shapes" exhibit two categories, namely **"target"** and **"ignore"**. The former category is indispensable for evaluation, while the latter category denotes the ambiguous region and is hence disregarded during the evaluation process.
-We provide a **script** that demonstrates how to process the annotations: 
+We provide a **script** that demonstrates how to process the annotations:
 ```
-python3 utils/data_proc_demo.py
+python3 utils/data_processing.py
 ```
 Besides, we leveraged GPT-3.5 for rephrasing instructions, so images in the training set may have **more than one instruction (but fewer than six)** in the "text" field. During training, users may randomly select one as the text query to obtain a better model.
diff --git a/chat.py b/chat.py
index 44353cfa4af892702a162a04695cd8deba37f4de..34715d7913d4daefcd9a58740de47d8efe35e085 100755
--- a/chat.py
+++ b/chat.py
@@ -1,38 +1,48 @@
-import sys
+import argparse
 import os
+import sys
+
 import cv2
-import argparse
-import torch
-import transformers
 import numpy as np
+import torch
 import torch.nn.functional as F
-
+import transformers
 from transformers import AutoTokenizer, CLIPImageProcessor
 
 from model.LISA import LISA
-from utils.conversation import get_default_conv_template
 from model.segment_anything.utils.transforms import ResizeLongestSide
+from utils.conversation import get_default_conv_template
+
 
 def parse_args(args):
-    parser = argparse.ArgumentParser(description='LISA chat')
-    parser.add_argument('--version', default='xinlai/LISA-13B-llama2-v0')
-    parser.add_argument('--vis_save_path', default='./vis_output', type=str)
-    parser.add_argument('--precision', default='bf16', type=str, choices=['fp32', 'bf16', 'fp16'], help="precision for inference")
-    parser.add_argument('--image-size', default=1024, type=int, help='image size')
-    parser.add_argument('--model-max-length', default=512, type=int)
-    parser.add_argument('--lora-r', default=-1, type=int)
-    parser.add_argument('--vision-tower', default='openai/clip-vit-large-patch14', type=str)
-    parser.add_argument('--local-rank', default=0, type=int, help='node rank')
-    parser.add_argument('--load_in_8bit', action='store_true', default=False)
-    parser.add_argument('--load_in_4bit', action='store_true', default=False)
-    return parser.parse_args(args)
-
-
-def preprocess(x,
-               pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
+    parser = argparse.ArgumentParser(description="LISA chat")
+    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v0")
+    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
+    parser.add_argument(
+        "--precision",
+        default="bf16",
+        type=str,
+        choices=["fp32", "bf16", "fp16"],
+        help="precision for inference",
+    )
+    parser.add_argument("--image-size", default=1024, type=int, help="image size")
+    parser.add_argument("--model-max-length", default=512, type=int)
+
parser.add_argument("--lora-r", default=-1, type=int) + parser.add_argument( + "--vision-tower", default="openai/clip-vit-large-patch14", type=str + ) + parser.add_argument("--local-rank", default=0, type=int, help="node rank") + parser.add_argument("--load_in_8bit", action="store_true", default=False) + parser.add_argument("--load_in_4bit", action="store_true", default=False) + return parser.parse_args(args) + + +def preprocess( + x, + pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1), pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1), - img_size=1024 - ) -> torch.Tensor: + img_size=1024, +) -> torch.Tensor: """Normalize pixel values and pad to a square input.""" # Normalize colors x = (x - pixel_mean) / pixel_std @@ -45,125 +55,185 @@ def preprocess(x, def main(args): - args = parse_args(args) - os.makedirs(args.vis_save_path, exist_ok=True) - - # Create model - tokenizer = transformers.AutoTokenizer.from_pretrained( - args.version, - cache_dir=None, - model_max_length=args.model_max_length, - padding_side="right", - use_fast=False, - ) - tokenizer.pad_token = tokenizer.unk_token - num_added_tokens = tokenizer.add_tokens('[SEG]') - ret_token_idx = tokenizer('[SEG]', add_special_tokens=False).input_ids - args.seg_token_idx = ret_token_idx[0] - - model = LISA( - args.local_rank, - args.seg_token_idx, - tokenizer, - args.version, - args.lora_r, - args.precision, - load_in_8bit=args.load_in_8bit, - load_in_4bit=args.load_in_4bit, - ) - - weight = {} - visual_model_weight = torch.load(os.path.join(args.version, "pytorch_model-visual_model.bin")) - text_hidden_fcs_weight = torch.load(os.path.join(args.version, "pytorch_model-text_hidden_fcs.bin")) - weight.update(visual_model_weight) - weight.update(text_hidden_fcs_weight) - missing_keys, unexpected_keys = model.load_state_dict(weight, strict=False) - - if args.precision == 'bf16': - model = model.bfloat16().cuda() - elif args.precision == 'fp16': - import deepspeed - model_engine = deepspeed.init_inference(model=model, - dtype=torch.half, - replace_with_kernel_inject=True, - replace_method="auto", + args = parse_args(args) + os.makedirs(args.vis_save_path, exist_ok=True) + + # Create model + tokenizer = transformers.AutoTokenizer.from_pretrained( + args.version, + cache_dir=None, + model_max_length=args.model_max_length, + padding_side="right", + use_fast=False, ) - model = model_engine.module - else: - model = model.float().cuda() - - DEFAULT_IMAGE_TOKEN = "" - DEFAULT_IMAGE_PATCH_TOKEN = "" - DEFAULT_IM_START_TOKEN = "" - DEFAULT_IM_END_TOKEN = "" - image_token_len = 256 - - clip_image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower) - transform = ResizeLongestSide(args.image_size) - - while True: - - conv = get_default_conv_template("vicuna").copy() - conv.messages = [] - - prompt = input("Please input your prompt: ") - prompt = DEFAULT_IMAGE_TOKEN + " " + prompt - replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len - replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN - prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) - - conv.append_message(conv.roles[0], prompt) - conv.append_message(conv.roles[1], "") - prompt = conv.get_prompt() - - image_path = input("Please input the image path: ") - if not os.path.exists(image_path): - print("File not found in {}".format(image_path)) - continue - - image = cv2.imread(image_path) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - original_size_list = [image.shape[:2]] - if args.precision == 'bf16': - 
images_clip = clip_image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0].unsqueeze(0).cuda().bfloat16() - elif args.precision == 'fp16': - images_clip = clip_image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0].unsqueeze(0).cuda().half() - else: - images_clip = clip_image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0].unsqueeze(0).cuda().float() - images = transform.apply_image(image) - resize_list = [images.shape[:2]] - if args.precision == 'bf16': - images = preprocess(torch.from_numpy(images).permute(2,0,1).contiguous()).unsqueeze(0).cuda().bfloat16() - elif args.precision == 'fp16': - images = preprocess(torch.from_numpy(images).permute(2,0,1).contiguous()).unsqueeze(0).cuda().half() + tokenizer.pad_token = tokenizer.unk_token + num_added_tokens = tokenizer.add_tokens("[SEG]") + ret_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids + args.seg_token_idx = ret_token_idx[0] + + model = LISA( + args.local_rank, + args.seg_token_idx, + tokenizer, + args.version, + args.lora_r, + args.precision, + load_in_8bit=args.load_in_8bit, + load_in_4bit=args.load_in_4bit, + ) + + weight = {} + visual_model_weight = torch.load( + os.path.join(args.version, "pytorch_model-visual_model.bin") + ) + text_hidden_fcs_weight = torch.load( + os.path.join(args.version, "pytorch_model-text_hidden_fcs.bin") + ) + weight.update(visual_model_weight) + weight.update(text_hidden_fcs_weight) + missing_keys, unexpected_keys = model.load_state_dict(weight, strict=False) + + if args.precision == "bf16": + model = model.bfloat16().cuda() + elif args.precision == "fp16": + import deepspeed + + model_engine = deepspeed.init_inference( + model=model, + dtype=torch.half, + replace_with_kernel_inject=True, + replace_method="auto", + ) + model = model_engine.module else: - images = preprocess(torch.from_numpy(images).permute(2,0,1).contiguous()).unsqueeze(0).cuda().float() - - input_ids = tokenizer(prompt).input_ids - input_ids = torch.LongTensor(input_ids).unsqueeze(0).cuda() - output_ids, pred_masks = model.evaluate(images_clip, images, input_ids, resize_list, original_size_list, max_new_tokens=512, tokenizer=tokenizer) - text_output = tokenizer.decode(output_ids[0], skip_special_tokens=False) - text_output = text_output.replace(DEFAULT_IMAGE_PATCH_TOKEN, "").replace("\n", "").replace(" ", "") - - print("text_output: ", text_output) - for i, pred_mask in enumerate(pred_masks): - - if pred_mask.shape[0] == 0: - continue - - pred_mask = pred_mask.detach().cpu().numpy()[0] - pred_mask = (pred_mask > 0) - - save_path = "{}/{}_mask_{}.jpg".format(args.vis_save_path, image_path.split("/")[-1].split(".")[0], i) - cv2.imwrite(save_path, pred_mask * 100) - print("{} has been saved.".format(save_path)) - - save_path = "{}/{}_masked_img_{}.jpg".format(args.vis_save_path, image_path.split("/")[-1].split(".")[0], i) - save_img = image.copy() - save_img[pred_mask] = (image * 0.5 + pred_mask[:,:,None].astype(np.uint8) * np.array([255,0,0]) * 0.5)[pred_mask] - save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR) - cv2.imwrite(save_path, save_img) - print("{} has been saved.".format(save_path)) - -if __name__ == '__main__': - main(sys.argv[1:]) + model = model.float().cuda() + + DEFAULT_IMAGE_TOKEN = "" + DEFAULT_IMAGE_PATCH_TOKEN = "" + DEFAULT_IM_START_TOKEN = "" + DEFAULT_IM_END_TOKEN = "" + image_token_len = 256 + + clip_image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower) + transform = ResizeLongestSide(args.image_size) + + while True: + 
conv = get_default_conv_template("vicuna").copy() + conv.messages = [] + + prompt = input("Please input your prompt: ") + prompt = DEFAULT_IMAGE_TOKEN + " " + prompt + replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) + + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], "") + prompt = conv.get_prompt() + + image_path = input("Please input the image path: ") + if not os.path.exists(image_path): + print("File not found in {}".format(image_path)) + continue + + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + original_size_list = [image.shape[:2]] + if args.precision == "bf16": + images_clip = ( + clip_image_processor.preprocess(image, return_tensors="pt")[ + "pixel_values" + ][0] + .unsqueeze(0) + .cuda() + .bfloat16() + ) + elif args.precision == "fp16": + images_clip = ( + clip_image_processor.preprocess(image, return_tensors="pt")[ + "pixel_values" + ][0] + .unsqueeze(0) + .cuda() + .half() + ) + else: + images_clip = ( + clip_image_processor.preprocess(image, return_tensors="pt")[ + "pixel_values" + ][0] + .unsqueeze(0) + .cuda() + .float() + ) + images = transform.apply_image(image) + resize_list = [images.shape[:2]] + if args.precision == "bf16": + images = ( + preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous()) + .unsqueeze(0) + .cuda() + .bfloat16() + ) + elif args.precision == "fp16": + images = ( + preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous()) + .unsqueeze(0) + .cuda() + .half() + ) + else: + images = ( + preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous()) + .unsqueeze(0) + .cuda() + .float() + ) + + input_ids = tokenizer(prompt).input_ids + input_ids = torch.LongTensor(input_ids).unsqueeze(0).cuda() + output_ids, pred_masks = model.evaluate( + images_clip, + images, + input_ids, + resize_list, + original_size_list, + max_new_tokens=512, + tokenizer=tokenizer, + ) + text_output = tokenizer.decode(output_ids[0], skip_special_tokens=False) + text_output = ( + text_output.replace(DEFAULT_IMAGE_PATCH_TOKEN, "") + .replace("\n", "") + .replace(" ", "") + ) + + print("text_output: ", text_output) + for i, pred_mask in enumerate(pred_masks): + if pred_mask.shape[0] == 0: + continue + + pred_mask = pred_mask.detach().cpu().numpy()[0] + pred_mask = pred_mask > 0 + + save_path = "{}/{}_mask_{}.jpg".format( + args.vis_save_path, image_path.split("/")[-1].split(".")[0], i + ) + cv2.imwrite(save_path, pred_mask * 100) + print("{} has been saved.".format(save_path)) + + save_path = "{}/{}_masked_img_{}.jpg".format( + args.vis_save_path, image_path.split("/")[-1].split(".")[0], i + ) + save_img = image.copy() + save_img[pred_mask] = ( + image * 0.5 + + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5 + )[pred_mask] + save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR) + cv2.imwrite(save_path, save_img) + print("{} has been saved.".format(save_path)) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/model/LISA.py b/model/LISA.py index 66697951da211645c4a6417e2dd01472420bbe20..f3f6c01e68a89979558b67e6b16327a71e6e1813 100755 --- a/model/LISA.py +++ b/model/LISA.py @@ -1,213 +1,490 @@ -from typing import Callable, List, Optional, Tuple, Union -import json -import glob -import math -import numpy as np -import os +from typing import List + import torch import torch.nn as nn import torch.nn.functional as 
F -import transformers - -from transformers import LlamaForCausalLM, CLIPVisionModel, BitsAndBytesConfig -from peft import ( - LoraConfig, - get_peft_model, - get_peft_model_state_dict, - prepare_model_for_int8_training, - set_peft_model_state_dict, -) -from .llava.model.llava import LlavaLlamaForCausalLM -from .segment_anything import build_sam_vit_l, build_sam_vit_h +from peft import (LoraConfig, get_peft_model) +from transformers import BitsAndBytesConfig, CLIPVisionModel -DEFAULT_IMAGE_TOKEN = "" -DEFAULT_IMAGE_PATCH_TOKEN = "" -DEFAULT_IM_START_TOKEN = "" -DEFAULT_IM_END_TOKEN = "" +from .llava.model.llava import LlavaLlamaForCausalLM +from .segment_anything import build_sam_vit_h +from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, + DEFAULT_IMAGE_PATCH_TOKEN) -def find_all_linear_names(model): - cls = torch.nn.Linear - lora_module_names = set() - for name, module in model.named_modules(): - if isinstance(module, cls): - names = name.split('.') - lora_module_names.add(names[0] if len(names) == 1 else names[-1]) +def dice_loss( + inputs: torch.Tensor, + targets: torch.Tensor, + num_masks: float, + scale=1000, # 100000.0, + eps=1e-6, +): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1, 2) + targets = targets.flatten(1, 2) + numerator = 2 * (inputs / scale * targets).sum(-1) + denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1) + loss = 1 - (numerator + eps) / (denominator + eps) + loss = loss.sum() / (num_masks + 1e-8) + return loss - if 'lm_head' in lora_module_names: # needed for 16-bit - lora_module_names.remove('lm_head') - if 'mm_projector' in lora_module_names: - lora_module_names.remove('mm_projector') +def sigmoid_ce_loss( + inputs: torch.Tensor, + targets: torch.Tensor, + num_masks: float, +): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). 
+ Returns: + Loss tensor + """ + loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + loss = loss.flatten(1, 2).mean(1).sum() / (num_masks + 1e-8) + return loss - return sorted(list(lora_module_names)) class LISA(nn.Module): - def __init__(self, - local_rank, - seg_token_idx, - tokenizer, - llm_version, - lora_r, - precision, - load_in_4bit=False, - load_in_8bit=False, - lora_target_modules=['q_proj', 'v_proj'], - lora_alpha=16, - lora_dropout=0.05, - vision_tower='openai/clip-vit-large-patch14', - mm_vision_select_layer=-2, - freeze_lm=True, - train_mask_decoder=True, - out_dim=256, - ): - - super().__init__() - self.tokenizer = tokenizer - self.image_token = tokenizer.cls_token_id - self.precision = precision - - # LLaVA - tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) - num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) - if precision == "bf16": - self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, torch_dtype=torch.bfloat16, cache_dir=None, low_cpu_mem_usage=True) - elif precision == "fp16": - if load_in_4bit: - self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, load_in_4bit=True, cache_dir=None, low_cpu_mem_usage=True, device_map='auto', - quantization_config=BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=torch.float16, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type='nf4' - ) - ) - elif load_in_8bit: - self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, load_in_8bit=True, cache_dir=None, low_cpu_mem_usage=True, device_map='auto') - else: - self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, torch_dtype=torch.half, cache_dir=None, low_cpu_mem_usage=True) - else: - self.lm = LlavaLlamaForCausalLM.from_pretrained(llm_version, torch_dtype=torch.float32, cache_dir=None, low_cpu_mem_usage=True) - - self.lm.enable_input_require_grads() - self.lm.gradient_checkpointing_enable() - self.lm.config.use_cache = False - model_vision_dict = self.lm.get_model().initialize_vision_modules(vision_tower=vision_tower, mm_vision_select_layer=mm_vision_select_layer, precision=precision) - vision_config = model_vision_dict['vision_config'] - vision_tower = self.lm.get_model().vision_tower[0] - self.lm.model.config.eos_token_id = tokenizer.eos_token_id - self.lm.model.config.bos_token_id = tokenizer.bos_token_id - self.lm.model.config.pad_token_id = tokenizer.pad_token_id - - if vision_tower.device.type == 'meta': - if precision == 'bf16': - vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).cuda(local_rank) - elif precision == 'fp16': - vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.half, low_cpu_mem_usage=True).cuda(local_rank) - else: - vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float32, low_cpu_mem_usage=True).cuda(local_rank) - self.lm.get_model().vision_tower[0] = vision_tower - else: + def __init__( + self, + local_rank, + seg_token_idx, + tokenizer, + llm_version, + lora_r, + precision, + load_in_4bit=False, + load_in_8bit=False, + lora_target_modules=["q_proj", "v_proj"], + lora_alpha=16, + lora_dropout=0.05, + vision_tower="openai/clip-vit-large-patch14", + mm_vision_select_layer=-2, + freeze_lm=True, + train_mask_decoder=True, + out_dim=256, + ce_loss_weight=1.0, + dice_loss_weight=0.5, + bce_loss_weight=2.0, + vision_pretrained=None, + ): + 
super().__init__() + self.local_rank = local_rank + self.tokenizer = tokenizer + self.image_token = tokenizer.cls_token_id + self.precision = precision + self.ce_loss_weight = ce_loss_weight + self.dice_loss_weight = dice_loss_weight + self.bce_loss_weight = bce_loss_weight + # LLaVA + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + num_new_tokens = tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) if precision == "bf16": - vision_tower.to(device='cuda', dtype=torch.bfloat16) + self.lm = LlavaLlamaForCausalLM.from_pretrained( + llm_version, + torch_dtype=torch.bfloat16, + cache_dir=None, + low_cpu_mem_usage=True, + ) elif precision == "fp16": - vision_tower.to(device='cuda', dtype=torch.half) + if load_in_4bit: + self.lm = LlavaLlamaForCausalLM.from_pretrained( + llm_version, + load_in_4bit=True, + cache_dir=None, + low_cpu_mem_usage=True, + device_map="auto", + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ), + ) + elif load_in_8bit: + self.lm = LlavaLlamaForCausalLM.from_pretrained( + llm_version, + load_in_8bit=True, + cache_dir=None, + low_cpu_mem_usage=True, + device_map="auto", + ) + else: + self.lm = LlavaLlamaForCausalLM.from_pretrained( + llm_version, + torch_dtype=torch.half, + cache_dir=None, + low_cpu_mem_usage=True, + ) else: - vision_tower.to(device='cuda', dtype=torch.float32) - - self.lm.config.tune_mm_mlp_adapter = False - self.lm.config.freeze_mm_mlp_adapter = False - self.lm.config.mm_use_im_start_end = True - vision_config.use_im_start_end = True - self.lm.config.sep_image_conv_front = False - - self.lm.initialize_vision_tokenizer(mm_use_im_start_end=True, tokenizer=tokenizer, num_new_tokens=num_new_tokens, device=local_rank, tune_mm_mlp_adapter=False) - if freeze_lm: - for n, param in self.lm.named_parameters(): - param.requires_grad = False - - self.llm_version = llm_version - - self.seg_token_idx = seg_token_idx - self.lm.resize_token_embeddings(len(tokenizer)) - - for n, p in self.lm.named_parameters(): - if any([x in n for x in ['lm_head', 'embed_tokens']]) and p.shape[0] == len(tokenizer): - p.requires_grad = True - - # SAM - self.visual_model = build_sam_vit_h(None) - for param in self.visual_model.parameters(): - param.requires_grad = False - if train_mask_decoder: - self.visual_model.mask_decoder.train() - for param in self.visual_model.mask_decoder.parameters(): - param.requires_grad = True - - # Projection layer - in_dim = self.lm.config.hidden_size - text_fc = [nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True), nn.Linear(in_dim, out_dim), nn.Dropout(0.0)] - self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)]) - - def get_visual_embs(self, pixel_values: torch.FloatTensor): - image_embeddings = self.visual_model.image_encoder(pixel_values) - return image_embeddings - - def evaluate(self, images_clip, images, input_ids, resize_list, original_size_list, max_new_tokens=32, tokenizer=None): - - with torch.no_grad(): - outputs = self.lm.generate(images=images_clip, input_ids=input_ids, max_new_tokens=max_new_tokens, num_beams=1, output_hidden_states=True, return_dict_in_generate=True) - output_hidden_states = outputs.hidden_states[-1] - output_ids = outputs.sequences - - seg_token_mask = (output_ids[:, 1:] == self.seg_token_idx) - - last_embedding = None - last_output_logit = None - hidden_states = [] - - assert len(self.text_hidden_fcs) == 1 - 
hidden_states.append(self.text_hidden_fcs[0](output_hidden_states)) - - last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1) - pred_embeddings = last_hidden_state[seg_token_mask] - - seg_token_counts = seg_token_mask.int().sum(-1) #[bs, ] - seg_token_offset = seg_token_counts.cumsum(-1) - seg_token_offset = torch.cat([torch.zeros(1).long().cuda(), seg_token_offset], dim=0) - - pred_embeddings_ = [] - for i in range(len(seg_token_offset)-1): - start_i, end_i = seg_token_offset[i], seg_token_offset[i+1] - pred_embeddings_.append(pred_embeddings[start_i: end_i]) - pred_embeddings = pred_embeddings_ - - image_embeddings = self.get_visual_embs(images) - - multimask_output = False - pred_masks = [] - for i in range(len(pred_embeddings)): - sparse_embeddings, dense_embeddings = self.visual_model.prompt_encoder( - points=None, - boxes=None, - masks=None, - text_embeds=pred_embeddings[i].unsqueeze(1), + self.lm = LlavaLlamaForCausalLM.from_pretrained( + llm_version, + torch_dtype=torch.float32, + cache_dir=None, + low_cpu_mem_usage=True, + ) + + self.lm.enable_input_require_grads() + self.lm.gradient_checkpointing_enable() + self.lm.config.use_cache = False + model_vision_dict = self.lm.get_model().initialize_vision_modules( + vision_tower=vision_tower, + mm_vision_select_layer=mm_vision_select_layer, + precision=precision, ) + vision_config = model_vision_dict["vision_config"] + vision_tower = self.lm.get_model().vision_tower[0] + self.lm.model.config.eos_token_id = tokenizer.eos_token_id + self.lm.model.config.bos_token_id = tokenizer.bos_token_id + self.lm.model.config.pad_token_id = tokenizer.pad_token_id - sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype) - low_res_masks, iou_predictions = self.visual_model.mask_decoder( - image_embeddings=image_embeddings[i].unsqueeze(0), - image_pe=self.visual_model.prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=multimask_output, + if vision_tower.device.type == "meta": + if precision == "bf16": + vision_tower = CLIPVisionModel.from_pretrained( + vision_tower.config._name_or_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ).cuda(local_rank) + elif precision == "fp16": + vision_tower = CLIPVisionModel.from_pretrained( + vision_tower.config._name_or_path, + torch_dtype=torch.half, + low_cpu_mem_usage=True, + ).cuda(local_rank) + else: + vision_tower = CLIPVisionModel.from_pretrained( + vision_tower.config._name_or_path, + torch_dtype=torch.float32, + low_cpu_mem_usage=True, + ).cuda(local_rank) + self.lm.get_model().vision_tower[0] = vision_tower + else: + if precision == "bf16": + vision_tower.to(device="cuda", dtype=torch.bfloat16) + elif precision == "fp16": + vision_tower.to(device="cuda", dtype=torch.half) + else: + vision_tower.to(device="cuda", dtype=torch.float32) + + self.lm.config.tune_mm_mlp_adapter = False + self.lm.config.freeze_mm_mlp_adapter = False + self.lm.config.mm_use_im_start_end = True + vision_config.use_im_start_end = True + self.lm.config.sep_image_conv_front = False + + self.lm.initialize_vision_tokenizer( + mm_use_im_start_end=True, + tokenizer=tokenizer, + num_new_tokens=num_new_tokens, + device=local_rank, + tune_mm_mlp_adapter=False, ) + if freeze_lm: + for n, param in self.lm.named_parameters(): + param.requires_grad = False + + # LoRA + if lora_r > 0: + config = LoraConfig( + r=lora_r, + lora_alpha=lora_alpha, + target_modules=lora_target_modules, + lora_dropout=lora_dropout, + bias="none", 
+ task_type="CAUSAL_LM", + ) + self.lm = get_peft_model(self.lm, config) + self.lm.print_trainable_parameters() + + self.llm_version = llm_version + + self.seg_token_idx = seg_token_idx + self.lm.resize_token_embeddings(len(tokenizer)) + + for n, p in self.lm.named_parameters(): + if any([x in n for x in ["lm_head", "embed_tokens"]]) and p.shape[0] == len(tokenizer): + p.requires_grad = True - pred_mask = self.visual_model.postprocess_masks( - low_res_masks, - input_size=resize_list[i], - original_size=original_size_list[i], + # SAM + self.visual_model = build_sam_vit_h(vision_pretrained) + for param in self.visual_model.parameters(): + param.requires_grad = False + if train_mask_decoder: + self.visual_model.mask_decoder.train() + for param in self.visual_model.mask_decoder.parameters(): + param.requires_grad = True + + # Projection layer + in_dim = self.lm.config.hidden_size + text_fc = [ + nn.Linear(in_dim, in_dim), + nn.ReLU(inplace=True), + nn.Linear(in_dim, out_dim), + nn.Dropout(0.0), + ] + self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)]) + + def get_visual_embs(self, pixel_values: torch.FloatTensor): + with torch.no_grad(): + image_embeddings = self.visual_model.image_encoder(pixel_values) + return image_embeddings + + def forward( + self, + images: torch.FloatTensor, + images_clip: torch.FloatTensor, + input_ids: torch.LongTensor, + labels: torch.LongTensor, + attention_masks: torch.LongTensor, + offset: torch.LongTensor, + masks_list: List[torch.FloatTensor], + label_list: List[torch.Tensor], + resize_list: List[tuple], + inference: bool = False, + **kwargs, + ): + image_embeddings = self.get_visual_embs(images) + batch_size = image_embeddings.shape[0] + assert batch_size == len(offset) - 1 + + seg_token_mask = input_ids[:, 1:] == self.seg_token_idx + seg_token_mask = torch.cat( + [ + seg_token_mask, + torch.zeros((seg_token_mask.shape[0], 1)).bool().cuda(self.local_rank), + ], + dim=1, ) - pred_masks.append(pred_mask[:, 0]) - - return output_ids, pred_masks + + if inference: + n_batch = 1 + length = input_ids.shape[0] + assert images_clip.shape[0] == 1 + images_clip_extend = images_clip.expand(length, -1, -1, -1).contiguous() + + output_hidden_states = [] + for i in range(n_batch): + start_i, end_i = i * length, min((i + 1) * length, input_ids.shape[0]) + output_i = self.lm( + images=images_clip_extend[: end_i - start_i], + attention_mask=attention_masks[start_i:end_i], + input_ids=input_ids[start_i:end_i], + output_hidden_states=True, + ) + output_hidden_states.append(output_i.hidden_states) + torch.cuda.empty_cache() + + output_hidden_states_list = [] + output_hidden_states_level = torch.cat(output_hidden_states, dim=0) + output_hidden_states_list.append(output_hidden_states_level) + output_hidden_states = output_hidden_states_list + output = None + + else: + images_clip_list = [] + for i in range(len(offset) - 1): + start_i, end_i = offset[i], offset[i + 1] + images_clip_i = ( + images_clip[i] + .unsqueeze(0) + .expand(end_i - start_i, -1, -1, -1) + .contiguous() + ) + images_clip_list.append(images_clip_i) + images_clip = torch.cat(images_clip_list, dim=0) + + output = self.lm( + images=images_clip, + attention_mask=attention_masks, + input_ids=input_ids, + labels=labels, + output_hidden_states=True, + ) + output_hidden_states = output.hidden_states + + hidden_states = [] + + assert len(self.text_hidden_fcs) == 1 + hidden_states.append(self.text_hidden_fcs[0](output_hidden_states[-1])) + + last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1) + + 
pred_embeddings = last_hidden_state[seg_token_mask] + seg_token_counts = seg_token_mask.int().sum(-1) # [bs, ] + + seg_token_offset = seg_token_counts.cumsum(-1) + seg_token_offset = torch.cat( + [torch.zeros(1).long().cuda(), seg_token_offset], dim=0 + ) + + seg_token_offset = seg_token_offset[offset] + + pred_embeddings_ = [] + for i in range(len(seg_token_offset) - 1): + start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1] + pred_embeddings_.append(pred_embeddings[start_i:end_i]) + pred_embeddings = pred_embeddings_ + + multimask_output = False + pred_masks = [] + for i in range(len(pred_embeddings)): + sparse_embeddings, dense_embeddings = self.visual_model.prompt_encoder( + points=None, + boxes=None, + masks=None, + text_embeds=pred_embeddings[i].unsqueeze(1), + ) + sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype) + low_res_masks, iou_predictions = self.visual_model.mask_decoder( + image_embeddings=image_embeddings[i].unsqueeze(0), + image_pe=self.visual_model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + pred_mask = self.visual_model.postprocess_masks( + low_res_masks, + input_size=resize_list[i], + original_size=label_list[i].shape, + ) + pred_masks.append(pred_mask[:, 0]) + + model_output = output + gt_masks = masks_list + + if inference: + return { + "pred_masks": pred_masks, + "gt_masks": gt_masks, + } + + output = model_output.logits + + ce_loss = model_output.loss + ce_loss = ce_loss * self.ce_loss_weight + loss = ce_loss + mask_bce_loss = 0 + mask_dice_loss = 0 + num_masks = 0 + for batch_idx in range(len(pred_masks)): + gt_mask = gt_masks[batch_idx] + pred_mask = pred_masks[batch_idx] + + assert ( + gt_mask.shape[0] == pred_mask.shape[0] + ), "gt_mask.shape: {}, pred_mask.shape: {}".format( + gt_mask.shape, pred_mask.shape + ) + mask_bce_loss += ( + sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0]) + * gt_mask.shape[0] + ) + mask_dice_loss += ( + dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0]) + * gt_mask.shape[0] + ) + num_masks += gt_mask.shape[0] + + mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8) + mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8) + mask_loss = mask_bce_loss + mask_dice_loss + + loss += mask_loss + + return { + "loss": loss, + "ce_loss": ce_loss, + "mask_bce_loss": mask_bce_loss, + "mask_dice_loss": mask_dice_loss, + "mask_loss": mask_loss, + } + + def evaluate( + self, + images_clip, + images, + input_ids, + resize_list, + original_size_list, + max_new_tokens=32, + tokenizer=None, + ): + with torch.no_grad(): + outputs = self.lm.generate( + images=images_clip, + input_ids=input_ids, + max_new_tokens=max_new_tokens, + num_beams=1, + output_hidden_states=True, + return_dict_in_generate=True, + ) + output_hidden_states = outputs.hidden_states[-1] + output_ids = outputs.sequences + + seg_token_mask = output_ids[:, 1:] == self.seg_token_idx + + hidden_states = [] + + assert len(self.text_hidden_fcs) == 1 + hidden_states.append(self.text_hidden_fcs[0](output_hidden_states)) + + last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1) + pred_embeddings = last_hidden_state[seg_token_mask] + + seg_token_counts = seg_token_mask.int().sum(-1) # [bs, ] + seg_token_offset = seg_token_counts.cumsum(-1) + seg_token_offset = torch.cat( + [torch.zeros(1).long().cuda(), seg_token_offset], dim=0 + ) + + pred_embeddings_ = [] + for i in 
range(len(seg_token_offset) - 1): + start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1] + pred_embeddings_.append(pred_embeddings[start_i:end_i]) + pred_embeddings = pred_embeddings_ + + image_embeddings = self.get_visual_embs(images) + + multimask_output = False + pred_masks = [] + for i in range(len(pred_embeddings)): + sparse_embeddings, dense_embeddings = self.visual_model.prompt_encoder( + points=None, + boxes=None, + masks=None, + text_embeds=pred_embeddings[i].unsqueeze(1), + ) + + sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype) + low_res_masks, iou_predictions = self.visual_model.mask_decoder( + image_embeddings=image_embeddings[i].unsqueeze(0), + image_pe=self.visual_model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + pred_mask = self.visual_model.postprocess_masks( + low_res_masks, + input_size=resize_list[i], + original_size=original_size_list[i], + ) + pred_masks.append(pred_mask[:, 0]) + + return output_ids, pred_masks diff --git a/model/llava/conversation.py b/model/llava/conversation.py index 6ff9a66065505d5241693058b649e9f8125735b7..d29387b13b358be670de39378329889092714693 100644 --- a/model/llava/conversation.py +++ b/model/llava/conversation.py @@ -1,10 +1,11 @@ import dataclasses -from enum import auto, Enum +from enum import Enum, auto from typing import List, Tuple class SeparatorStyle(Enum): """Different separator style.""" + SINGLE = auto() TWO = auto() MPT = auto() @@ -13,6 +14,7 @@ class SeparatorStyle(Enum): @dataclasses.dataclass class Conversation: """A class that keeps all conversation history.""" + system: str roles: List[str] messages: List[List[str]] @@ -64,33 +66,43 @@ class Conversation: def get_images(self, return_pil=False): images = [] - for i, (role, msg) in enumerate(self.messages[self.offset:]): + for i, (role, msg) in enumerate(self.messages[self.offset :]): if i % 2 == 0: if type(msg) is tuple: import base64 from io import BytesIO + from PIL import Image + msg, image, image_process_mode = msg if image_process_mode == "Pad": + def expand2square(pil_img, background_color=(122, 116, 104)): width, height = pil_img.size if width == height: return pil_img elif width > height: - result = Image.new(pil_img.mode, (width, width), background_color) + result = Image.new( + pil_img.mode, (width, width), background_color + ) result.paste(pil_img, (0, (width - height) // 2)) return result else: - result = Image.new(pil_img.mode, (height, height), background_color) + result = Image.new( + pil_img.mode, (height, height), background_color + ) result.paste(pil_img, ((height - width) // 2, 0)) return result + image = expand2square(image) elif image_process_mode == "Crop": pass elif image_process_mode == "Resize": image = image.resize((224, 224)) else: - raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + raise ValueError( + f"Invalid image_process_mode: {image_process_mode}" + ) max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw max_len, min_len = 800, 400 @@ -113,11 +125,12 @@ class Conversation: def to_gradio_chatbot(self): ret = [] - for i, (role, msg) in enumerate(self.messages[self.offset:]): + for i, (role, msg) in enumerate(self.messages[self.offset :]): if i % 2 == 0: if type(msg) is tuple: import base64 from io import BytesIO + msg, image, image_process_mode = msg max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw @@ -135,7 +148,7 @@ 
class Conversation: image.save(buffered, format="JPEG") img_b64_str = base64.b64encode(buffered.getvalue()).decode() img_str = f'user upload image' - msg = msg.replace('', img_str) + msg = msg.replace("", img_str) ret.append([msg, None]) else: ret[-1][-1] = msg @@ -149,14 +162,17 @@ class Conversation: offset=self.offset, sep_style=self.sep_style, sep=self.sep, - sep2=self.sep2) + sep2=self.sep2, + ) def dict(self): if len(self.get_images()) > 0: return { "system": self.system, "roles": self.roles, - "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "messages": [ + [x, y[0] if type(y) is tuple else y] for x, y in self.messages + ], "offset": self.offset, "sep": self.sep, "sep2": self.sep2, @@ -173,11 +189,12 @@ class Conversation: conv_v1 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), messages=( ("Human", "Give three tips for staying healthy."), - ("Assistant", + ( + "Assistant", "Sure, here are three tips for staying healthy:\n" "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. " "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, " @@ -191,7 +208,8 @@ conv_v1 = Conversation( "3. Get enough sleep: Getting enough quality sleep is essential for your physical " "and mental health. Adults should aim for seven to nine hours of sleep per night. " "Establish a regular sleep schedule and try to create a relaxing bedtime routine to " - "help improve the quality of your sleep.") + "help improve the quality of your sleep.", + ), ), offset=2, sep_style=SeparatorStyle.SINGLE, @@ -200,11 +218,15 @@ conv_v1 = Conversation( conv_v1_2 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions.", + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), messages=( - ("Human", "What are the key differences between renewable and non-renewable energy sources?"), - ("Assistant", + ( + "Human", + "What are the key differences between renewable and non-renewable energy sources?", + ), + ( + "Assistant", "Renewable energy sources are those that can be replenished naturally in a relatively " "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " "Non-renewable energy sources, on the other hand, are finite and will eventually be " @@ -222,7 +244,8 @@ conv_v1_2 = Conversation( "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " "situations and needs, while non-renewable sources are more rigid and inflexible.\n" "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " - "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") + "non-renewable sources are not, and their depletion can lead to economic and social instability.\n", + ), ), offset=2, sep_style=SeparatorStyle.SINGLE, @@ -280,12 +303,12 @@ conv_bair_v1 = Conversation( simple_conv = Conversation( system="You are LLaVA, a large language model trained by UW Madison WAIV Lab, based on LLaMA architecture." 
- "You are designed to assist human with a variety of tasks using natural language." - "Follow the instructions carefully.", + "You are designed to assist human with a variety of tasks using natural language." + "Follow the instructions carefully.", roles=("Human", "Assistant"), messages=( ("Human", "Hi!"), - ("Assistant", "Hi there! How can I help you today?\n") + ("Assistant", "Hi there! How can I help you today?\n"), ), offset=2, sep_style=SeparatorStyle.SINGLE, @@ -294,12 +317,12 @@ simple_conv = Conversation( simple_conv_multimodal = Conversation( system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab." - "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." - "Follow the instructions carefully and explain your answers in detail.", + "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail.", roles=("Human", "Assistant"), messages=( ("Human", "Hi!"), - ("Assistant", "Hi there! How can I help you today?\n") + ("Assistant", "Hi there! How can I help you today?\n"), ), offset=2, sep_style=SeparatorStyle.SINGLE, @@ -321,12 +344,12 @@ simple_conv_mpt_multimodal = Conversation( simple_conv_legacy = Conversation( system="You are LLaVA, a large language model trained by UW Madison WAIV Lab." - "You are designed to assist human with a variety of tasks using natural language." - "Follow the instructions carefully.", + "You are designed to assist human with a variety of tasks using natural language." + "Follow the instructions carefully.", roles=("Human", "Assistant"), messages=( ("Human", "Hi!\n\n### Response:"), - ("Assistant", "Hi there! How can I help you today?\n") + ("Assistant", "Hi there! How can I help you today?\n"), ), offset=2, sep_style=SeparatorStyle.SINGLE, @@ -335,8 +358,8 @@ simple_conv_legacy = Conversation( conv_llava_v1 = Conversation( system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab." - "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." - "Follow the instructions carefully and explain your answers in detail.", + "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." + "Follow the instructions carefully and explain your answers in detail.", roles=("USER", "ASSISTANT"), version="v1", messages=(), @@ -354,7 +377,6 @@ conv_templates = { "multimodal": simple_conv_multimodal, "mpt_multimodal": simple_conv_mpt_multimodal, "llava_v1": conv_llava_v1, - # fastchat "v1": conv_v1_2, "bair_v1": conv_bair_v1, diff --git a/model/llava/eval/eval_gpt_review.py b/model/llava/eval/eval_gpt_review.py index 1552d2de0609fcc5044ae300cf8e5ddf457a7581..d55b5bec1aab652a4e03065bfbddd2d30e438d14 100644 --- a/model/llava/eval/eval_gpt_review.py +++ b/model/llava/eval/eval_gpt_review.py @@ -1,25 +1,29 @@ import argparse import json import os +import time import openai -import tqdm import ray -import time +import tqdm + @ray.remote(num_cpus=4) def get_eval(content: str, max_tokens: int): while True: try: response = openai.ChatCompletion.create( - model='gpt-4', - messages=[{ - 'role': 'system', - 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 
- }, { - 'role': 'user', - 'content': content, - }], + model="gpt-4", + messages=[ + { + "role": "system", + "content": "You are a helpful and precise assistant for checking the quality of the answer.", + }, + { + "role": "user", + "content": content, + }, + ], temperature=0.2, # TODO: figure out which temperature is best for evaluation max_tokens=max_tokens, ) @@ -30,34 +34,39 @@ def get_eval(content: str, max_tokens: int): print(e) time.sleep(1) - print('success!') - return response['choices'][0]['message']['content'] + print("success!") + return response["choices"][0]["message"]["content"] def parse_score(review): try: - score_pair = review.split('\n')[0] - score_pair = score_pair.replace(',', ' ') - sp = score_pair.split(' ') + score_pair = review.split("\n")[0] + score_pair = score_pair.replace(",", " ") + sp = score_pair.split(" ") if len(sp) == 2: return [float(sp[0]), float(sp[1])] else: - print('error', review) + print("error", review) return [-1, -1] except Exception as e: print(e) - print('error', review) + print("error", review) return [-1, -1] -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') - parser.add_argument('-q', '--question') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ChatGPT-based QA evaluation.") + parser.add_argument("-q", "--question") # parser.add_argument('-a', '--answer') - parser.add_argument('-a', '--answer-list', nargs='+', default=[]) - parser.add_argument('-r', '--rule') - parser.add_argument('-o', '--output') - parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') + parser.add_argument("-a", "--answer-list", nargs="+", default=[]) + parser.add_argument("-r", "--rule") + parser.add_argument("-o", "--output") + parser.add_argument( + "--max-tokens", + type=int, + default=1024, + help="maximum number of tokens produced in the output", + ) args = parser.parse_args() ray.init() @@ -65,9 +74,9 @@ if __name__ == '__main__': f_q = open(os.path.expanduser(args.question)) f_ans1 = open(os.path.expanduser(args.answer_list[0])) f_ans2 = open(os.path.expanduser(args.answer_list[1])) - rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) + rule_dict = json.load(open(os.path.expanduser(args.rule), "r")) - review_file = open(f'{args.output}', 'w') + review_file = open(f"{args.output}", "w") js_list = [] handles = [] @@ -80,23 +89,28 @@ if __name__ == '__main__': ans1 = json.loads(ans1_js) ans2 = json.loads(ans2_js) - category = json.loads(ques_js)['category'] + category = json.loads(ques_js)["category"] if category in rule_dict: rule = rule_dict[category] else: - rule = rule_dict['default'] - prompt = rule['prompt'] - role = rule['role'] - content = (f'[Question]\n{ques["text"]}\n\n' - f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' - f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' - f'[System]\n{prompt}\n\n') - js_list.append({ - 'id': idx+1, - 'question_id': ques['question_id'], - 'answer1_id': ans1['answer_id'], - 'answer2_id': ans2['answer_id'], - 'category': category}) + rule = rule_dict["default"] + prompt = rule["prompt"] + role = rule["role"] + content = ( + f'[Question]\n{ques["text"]}\n\n' + f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' + f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' + f"[System]\n{prompt}\n\n" + ) + js_list.append( + { + "id": idx + 1, + "question_id": ques["question_id"], + "answer1_id": ans1["answer_id"], + "answer2_id": ans2["answer_id"], + 
"category": category, + } + ) idx += 1 handles.append(get_eval.remote(content, args.max_tokens)) # To avoid the rate limit set by OpenAI @@ -105,7 +119,7 @@ if __name__ == '__main__': reviews = ray.get(handles) for idx, review in enumerate(reviews): scores = parse_score(review) - js_list[idx]['content'] = review - js_list[idx]['tuple'] = scores - review_file.write(json.dumps(js_list[idx]) + '\n') + js_list[idx]["content"] = review + js_list[idx]["tuple"] = scores + review_file.write(json.dumps(js_list[idx]) + "\n") review_file.close() diff --git a/model/llava/eval/eval_gpt_review_visual.py b/model/llava/eval/eval_gpt_review_visual.py index 58699fd4fce2587039c860cb69c433fe544d7159..db0c21862a5573465ec326cd46426bff0ad53e92 100644 --- a/model/llava/eval/eval_gpt_review_visual.py +++ b/model/llava/eval/eval_gpt_review_visual.py @@ -1,25 +1,29 @@ import argparse import json import os +import time import openai -import tqdm import ray -import time +import tqdm + @ray.remote(num_cpus=4) def get_eval(content: str, max_tokens: int): while True: try: response = openai.ChatCompletion.create( - model='gpt-4', - messages=[{ - 'role': 'system', - 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' - }, { - 'role': 'user', - 'content': content, - }], + model="gpt-4", + messages=[ + { + "role": "system", + "content": "You are a helpful and precise assistant for checking the quality of the answer.", + }, + { + "role": "user", + "content": content, + }, + ], temperature=0.2, # TODO: figure out which temperature is best for evaluation max_tokens=max_tokens, ) @@ -30,34 +34,39 @@ def get_eval(content: str, max_tokens: int): print(e) time.sleep(1) - print('success!') - return response['choices'][0]['message']['content'] + print("success!") + return response["choices"][0]["message"]["content"] def parse_score(review): try: - score_pair = review.split('\n')[0] - score_pair = score_pair.replace(',', ' ') - sp = score_pair.split(' ') + score_pair = review.split("\n")[0] + score_pair = score_pair.replace(",", " ") + sp = score_pair.split(" ") if len(sp) == 2: return [float(sp[0]), float(sp[1])] else: - print('error', review) + print("error", review) return [-1, -1] except Exception as e: print(e) - print('error', review) + print("error", review) return [-1, -1] -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') - parser.add_argument('-q', '--question') - parser.add_argument('-c', '--context') - parser.add_argument('-a', '--answer-list', nargs='+', default=[]) - parser.add_argument('-r', '--rule') - parser.add_argument('-o', '--output') - parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ChatGPT-based QA evaluation.") + parser.add_argument("-q", "--question") + parser.add_argument("-c", "--context") + parser.add_argument("-a", "--answer-list", nargs="+", default=[]) + parser.add_argument("-r", "--rule") + parser.add_argument("-o", "--output") + parser.add_argument( + "--max-tokens", + type=int, + default=1024, + help="maximum number of tokens produced in the output", + ) args = parser.parse_args() ray.init() @@ -65,12 +74,12 @@ if __name__ == '__main__': f_q = open(os.path.expanduser(args.question)) f_ans1 = open(os.path.expanduser(args.answer_list[0])) f_ans2 = open(os.path.expanduser(args.answer_list[1])) - rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) + 
rule_dict = json.load(open(os.path.expanduser(args.rule), "r")) - review_file = open(f'{args.output}', 'w') + review_file = open(f"{args.output}", "w") context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] - image_to_context = {context['image']: context for context in context_list} + image_to_context = {context["image"]: context for context in context_list} js_list = [] handles = [] @@ -80,28 +89,38 @@ if __name__ == '__main__': ans1 = json.loads(ans1_js) ans2 = json.loads(ans2_js) - inst = image_to_context[ques['image']] - cap_str = '\n'.join(inst['captions']) - box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) + inst = image_to_context[ques["image"]] + cap_str = "\n".join(inst["captions"]) + box_str = "\n".join( + [ + f'{instance["category"]}: {instance["bbox"]}' + for instance in inst["instances"] + ] + ) - category = json.loads(ques_js)['category'] + category = json.loads(ques_js)["category"] if category in rule_dict: rule = rule_dict[category] else: assert False, f"Visual QA category not found in rule file: {category}." - prompt = rule['prompt'] - role = rule['role'] - content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' - f'[Question]\n{ques["text"]}\n\n' - f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' - f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' - f'[System]\n{prompt}\n\n') - js_list.append({ - 'id': idx+1, - 'question_id': ques['question_id'], - 'answer1_id': ans1.get('answer_id', ans1['question_id']), - 'answer2_id': ans2.get('answer_id', ans2['answer_id']), - 'category': category}) + prompt = rule["prompt"] + role = rule["role"] + content = ( + f"[Context]\n{cap_str}\n\n{box_str}\n\n" + f'[Question]\n{ques["text"]}\n\n' + f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' + f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' + f"[System]\n{prompt}\n\n" + ) + js_list.append( + { + "id": idx + 1, + "question_id": ques["question_id"], + "answer1_id": ans1.get("answer_id", ans1["question_id"]), + "answer2_id": ans2.get("answer_id", ans2["answer_id"]), + "category": category, + } + ) idx += 1 handles.append(get_eval.remote(content, args.max_tokens)) # To avoid the rate limit set by OpenAI @@ -110,7 +129,7 @@ if __name__ == '__main__': reviews = ray.get(handles) for idx, review in enumerate(reviews): scores = parse_score(review) - js_list[idx]['content'] = review - js_list[idx]['tuple'] = scores - review_file.write(json.dumps(js_list[idx]) + '\n') + js_list[idx]["content"] = review + js_list[idx]["tuple"] = scores + review_file.write(json.dumps(js_list[idx]) + "\n") review_file.close() diff --git a/model/llava/eval/eval_science_qa.py b/model/llava/eval/eval_science_qa.py index e1b3ce52fd6d922f247cc0c48409e88d5af3f204..ccadd319f666e732375654656815f737197c6dd0 100644 --- a/model/llava/eval/eval_science_qa.py +++ b/model/llava/eval/eval_science_qa.py @@ -1,26 +1,26 @@ import argparse import json import os -import re import random +import re def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--base-dir', type=str) - parser.add_argument('--result-file', type=str) - parser.add_argument('--output-file', type=str) - parser.add_argument('--output-result', type=str) - parser.add_argument('--split', type=str, default='test') - parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) + parser.add_argument("--base-dir", type=str) + parser.add_argument("--result-file", type=str) + parser.add_argument("--output-file", type=str) + 
parser.add_argument("--output-result", type=str) + parser.add_argument("--split", type=str, default="test") + parser.add_argument("--options", type=list, default=["A", "B", "C", "D", "E"]) return parser.parse_args() def convert_caps(results): fakecaps = [] for result in results: - image_id = result['question_id'] - caption = result['text'] + image_id = result["question_id"] + caption = result["text"] fakecaps.append({"image_id": int(image_id), "caption": caption}) return fakecaps @@ -29,7 +29,7 @@ def get_pred_idx(prediction, choices, options): """ Get the index (e.g. 2) from the prediction (e.g. 'C') """ - if prediction in options[:len(choices)]: + if prediction in options[: len(choices)]: return options.index(prediction) else: return random.choice(range(len(choices))) @@ -39,61 +39,65 @@ if __name__ == "__main__": args = get_args() base_dir = args.base_dir - split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] + split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[ + args.split + ] problems = json.load(open(os.path.join(base_dir, "problems.json"))) predictions = [json.loads(line) for line in open(args.result_file)] - predictions = {pred['question_id']: pred for pred in predictions} + predictions = {pred["question_id"]: pred for pred in predictions} split_problems = {idx: problems[idx] for idx in split_indices} - results = {'correct': [], 'incorrect': []} + results = {"correct": [], "incorrect": []} sqa_results = {} - sqa_results['acc'] = None - sqa_results['correct'] = None - sqa_results['count'] = None - sqa_results['results'] = {} - sqa_results['outputs'] = {} + sqa_results["acc"] = None + sqa_results["correct"] = None + sqa_results["count"] = None + sqa_results["results"] = {} + sqa_results["outputs"] = {} for prob_id, prob in split_problems.items(): if prob_id not in predictions: continue pred = predictions[prob_id] - pred_text = pred['text'] + pred_text = pred["text"] - pattern = re.compile(r'The answer is ([A-Z]).') + pattern = re.compile(r"The answer is ([A-Z]).") res = pattern.findall(pred_text) if len(res) == 1: answer = res[0] # 'A', 'B', ... 
else: answer = "FAILED" - pred_idx = get_pred_idx(answer, prob['choices'], args.options) + pred_idx = get_pred_idx(answer, prob["choices"], args.options) analysis = { - 'question_id': prob_id, - 'parsed_ans': answer, - 'ground_truth': args.options[prob['answer']], - 'question': pred['prompt'], - 'pred': pred_text, - 'is_multimodal': '' in pred['prompt'], + "question_id": prob_id, + "parsed_ans": answer, + "ground_truth": args.options[prob["answer"]], + "question": pred["prompt"], + "pred": pred_text, + "is_multimodal": "" in pred["prompt"], } - sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) - sqa_results['outputs'][prob_id] = pred_text + sqa_results["results"][prob_id] = get_pred_idx( + answer, prob["choices"], args.options + ) + sqa_results["outputs"][prob_id] = pred_text - if pred_idx == prob['answer']: - results['correct'].append(analysis) + if pred_idx == prob["answer"]: + results["correct"].append(analysis) else: - results['incorrect'].append(analysis) + results["incorrect"].append(analysis) - correct = len(results['correct']) - total = len(results['correct']) + len(results['incorrect']) - print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') + correct = len(results["correct"]) + total = len(results["correct"]) + len(results["incorrect"]) + print(f"Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%") - sqa_results['acc'] = correct / total * 100 - sqa_results['correct'] = correct - sqa_results['count'] = total + sqa_results["acc"] = correct / total * 100 + sqa_results["correct"] = correct + sqa_results["count"] = total - with open(args.output_file, 'w') as f: + with open(args.output_file, "w") as f: json.dump(results, f, indent=2) - with open(args.output_result, 'w') as f: + with open(args.output_result, "w") as f: json.dump(sqa_results, f, indent=2) diff --git a/model/llava/eval/eval_science_qa_gpt4.py b/model/llava/eval/eval_science_qa_gpt4.py index c2ff17c915481fb556aba6ec816a9e08f519c515..662285994c79681298c4ede01e938ba10181bc73 100644 --- a/model/llava/eval/eval_science_qa_gpt4.py +++ b/model/llava/eval/eval_science_qa_gpt4.py @@ -1,26 +1,26 @@ import argparse import json import os -import re import random +import re from collections import defaultdict def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--base-dir', type=str) - parser.add_argument('--gpt4-result', type=str) - parser.add_argument('--our-result', type=str) - parser.add_argument('--split', type=str, default='test') - parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) + parser.add_argument("--base-dir", type=str) + parser.add_argument("--gpt4-result", type=str) + parser.add_argument("--our-result", type=str) + parser.add_argument("--split", type=str, default="test") + parser.add_argument("--options", type=list, default=["A", "B", "C", "D", "E"]) return parser.parse_args() def convert_caps(results): fakecaps = [] for result in results: - image_id = result['question_id'] - caption = result['text'] + image_id = result["question_id"] + caption = result["text"] fakecaps.append({"image_id": int(image_id), "caption": caption}) return fakecaps @@ -29,7 +29,7 @@ def get_pred_idx(prediction, choices, options): """ Get the index (e.g. 2) from the prediction (e.g. 
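# The tail of eval_science_qa.py aggregates the per-question records into a
# correct/incorrect breakdown plus a summary dict with accuracy, counts, and
# per-problem predictions, written out as indented JSON. Compact sketch of that
# bookkeeping over toy records (the records themselves are placeholders):

import json

records = [
    {"question_id": "q1", "pred_idx": 1, "answer": 1},
    {"question_id": "q2", "pred_idx": 0, "answer": 2},
]

results = {"correct": [], "incorrect": []}
sqa_results = {"acc": None, "correct": None, "count": None, "results": {}, "outputs": {}}

for rec in records:
    bucket = "correct" if rec["pred_idx"] == rec["answer"] else "incorrect"
    results[bucket].append(rec)
    sqa_results["results"][rec["question_id"]] = rec["pred_idx"]

correct = len(results["correct"])
total = correct + len(results["incorrect"])
sqa_results["acc"] = correct / total * 100
sqa_results["correct"] = correct
sqa_results["count"] = total
print(f"Total: {total}, Correct: {correct}, Accuracy: {sqa_results['acc']:.2f}%")
print(json.dumps(sqa_results, indent=2))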
'C') """ - if prediction in options[:len(choices)]: + if prediction in options[: len(choices)]: return options.index(prediction) else: return random.choice(range(len(choices))) @@ -39,13 +39,15 @@ if __name__ == "__main__": args = get_args() base_dir = args.base_dir - split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] + split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[ + args.split + ] problems = json.load(open(os.path.join(base_dir, "problems.json"))) our_predictions = [json.loads(line) for line in open(args.our_result)] - our_predictions = {pred['question_id']: pred for pred in our_predictions} + our_predictions = {pred["question_id"]: pred for pred in our_predictions} split_problems = {idx: problems[idx] for idx in split_indices} - gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] + gpt4_predictions = json.load(open(args.gpt4_result))["outputs"] results = defaultdict(lambda: 0) @@ -54,10 +56,10 @@ if __name__ == "__main__": continue if prob_id not in gpt4_predictions: continue - our_pred = our_predictions[prob_id]['text'] + our_pred = our_predictions[prob_id]["text"] gpt4_pred = gpt4_predictions[prob_id] - pattern = re.compile(r'The answer is ([A-Z]).') + pattern = re.compile(r"The answer is ([A-Z]).") our_res = pattern.findall(our_pred) if len(our_res) == 1: our_answer = our_res[0] # 'A', 'B', ... @@ -69,11 +71,11 @@ if __name__ == "__main__": else: gpt4_answer = "FAILED" - our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) - gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) + our_pred_idx = get_pred_idx(our_answer, prob["choices"], args.options) + gpt4_pred_idx = get_pred_idx(gpt4_answer, prob["choices"], args.options) - if gpt4_answer == 'FAILED': - results['gpt4_failed'] += 1 + if gpt4_answer == "FAILED": + results["gpt4_failed"] += 1 # continue gpt4_pred_idx = our_pred_idx # if our_pred_idx != prob['answer']: @@ -87,18 +89,20 @@ if __name__ == "__main__": pass # gpt4_pred_idx = our_pred_idx - if gpt4_pred_idx == prob['answer']: - results['correct'] += 1 + if gpt4_pred_idx == prob["answer"]: + results["correct"] += 1 else: - results['incorrect'] += 1 - - - if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: - results['correct_upperbound'] += 1 - - correct = results['correct'] - total = results['correct'] + results['incorrect'] - print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') - print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') - print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') - + results["incorrect"] += 1 + + if gpt4_pred_idx == prob["answer"] or our_pred_idx == prob["answer"]: + results["correct_upperbound"] += 1 + + correct = results["correct"] + total = results["correct"] + results["incorrect"] + print(f"Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%") + print( + f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%' + ) + print( + f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%' + ) diff --git a/model/llava/eval/eval_science_qa_gpt4_requery.py b/model/llava/eval/eval_science_qa_gpt4_requery.py index 
698546e995d365d1ccc2c25a87e6c5cd681e6eb6..f74fea2d68dc63af13b2efee8c887091dc4ead04 100644 --- a/model/llava/eval/eval_science_qa_gpt4_requery.py +++ b/model/llava/eval/eval_science_qa_gpt4_requery.py @@ -1,28 +1,28 @@ import argparse import json import os -import re import random +import re from collections import defaultdict def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--base-dir', type=str) - parser.add_argument('--gpt4-result', type=str) - parser.add_argument('--requery-result', type=str) - parser.add_argument('--our-result', type=str) - parser.add_argument('--output-result', type=str) - parser.add_argument('--split', type=str, default='test') - parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) + parser.add_argument("--base-dir", type=str) + parser.add_argument("--gpt4-result", type=str) + parser.add_argument("--requery-result", type=str) + parser.add_argument("--our-result", type=str) + parser.add_argument("--output-result", type=str) + parser.add_argument("--split", type=str, default="test") + parser.add_argument("--options", type=list, default=["A", "B", "C", "D", "E"]) return parser.parse_args() def convert_caps(results): fakecaps = [] for result in results: - image_id = result['question_id'] - caption = result['text'] + image_id = result["question_id"] + caption = result["text"] fakecaps.append({"image_id": int(image_id), "caption": caption}) return fakecaps @@ -31,7 +31,7 @@ def get_pred_idx(prediction, choices, options): """ Get the index (e.g. 2) from the prediction (e.g. 'C') """ - if prediction in options[:len(choices)]: + if prediction in options[: len(choices)]: return options.index(prediction) else: return random.choice(range(len(choices))) @@ -41,40 +41,42 @@ if __name__ == "__main__": args = get_args() base_dir = args.base_dir - split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] + split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[ + args.split + ] problems = json.load(open(os.path.join(base_dir, "problems.json"))) our_predictions = [json.loads(line) for line in open(args.our_result)] - our_predictions = {pred['question_id']: pred for pred in our_predictions} + our_predictions = {pred["question_id"]: pred for pred in our_predictions} split_problems = {idx: problems[idx] for idx in split_indices} requery_predictions = [json.loads(line) for line in open(args.requery_result)] - requery_predictions = {pred['question_id']: pred for pred in requery_predictions} + requery_predictions = {pred["question_id"]: pred for pred in requery_predictions} - gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] + gpt4_predictions = json.load(open(args.gpt4_result))["outputs"] results = defaultdict(lambda: 0) sqa_results = {} - sqa_results['acc'] = None - sqa_results['correct'] = None - sqa_results['count'] = None - sqa_results['results'] = {} - sqa_results['outputs'] = {} + sqa_results["acc"] = None + sqa_results["correct"] = None + sqa_results["count"] = None + sqa_results["results"] = {} + sqa_results["outputs"] = {} for prob_id, prob in split_problems.items(): if prob_id not in our_predictions: assert False if prob_id not in gpt4_predictions: assert False - our_pred = our_predictions[prob_id]['text'] + our_pred = our_predictions[prob_id]["text"] gpt4_pred = gpt4_predictions[prob_id] if prob_id not in requery_predictions: - results['missing_requery'] += 1 + results["missing_requery"] += 1 requery_pred = "MISSING" else: - requery_pred = 
requery_predictions[prob_id]['text'] + requery_pred = requery_predictions[prob_id]["text"] - pattern = re.compile(r'The answer is ([A-Z]).') + pattern = re.compile(r"The answer is ([A-Z]).") our_res = pattern.findall(our_pred) if len(our_res) == 1: our_answer = our_res[0] # 'A', 'B', ... @@ -93,57 +95,70 @@ if __name__ == "__main__": else: gpt4_answer = "FAILED" - our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) - gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) - requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) - - results['total'] += 1 - - if gpt4_answer == 'FAILED': - results['gpt4_failed'] += 1 - if gpt4_pred_idx == prob['answer']: - results['gpt4_correct'] += 1 - if our_pred_idx == prob['answer']: - results['gpt4_ourvisual_correct'] += 1 - elif gpt4_pred_idx == prob['answer']: - results['gpt4_correct'] += 1 - results['gpt4_ourvisual_correct'] += 1 - - if our_pred_idx == prob['answer']: - results['our_correct'] += 1 - - if requery_answer == 'FAILED': - sqa_results['results'][prob_id] = our_pred_idx - if our_pred_idx == prob['answer']: - results['requery_correct'] += 1 + our_pred_idx = get_pred_idx(our_answer, prob["choices"], args.options) + gpt4_pred_idx = get_pred_idx(gpt4_answer, prob["choices"], args.options) + requery_pred_idx = get_pred_idx(requery_answer, prob["choices"], args.options) + + results["total"] += 1 + + if gpt4_answer == "FAILED": + results["gpt4_failed"] += 1 + if gpt4_pred_idx == prob["answer"]: + results["gpt4_correct"] += 1 + if our_pred_idx == prob["answer"]: + results["gpt4_ourvisual_correct"] += 1 + elif gpt4_pred_idx == prob["answer"]: + results["gpt4_correct"] += 1 + results["gpt4_ourvisual_correct"] += 1 + + if our_pred_idx == prob["answer"]: + results["our_correct"] += 1 + + if requery_answer == "FAILED": + sqa_results["results"][prob_id] = our_pred_idx + if our_pred_idx == prob["answer"]: + results["requery_correct"] += 1 else: - sqa_results['results'][prob_id] = requery_pred_idx - if requery_pred_idx == prob['answer']: - results['requery_correct'] += 1 + sqa_results["results"][prob_id] = requery_pred_idx + if requery_pred_idx == prob["answer"]: + results["requery_correct"] += 1 else: - print(f""" + print( + f""" Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} Our ({our_answer}): {our_pred} GPT-4 ({gpt4_answer}): {gpt4_pred} Requery ({requery_answer}): {requery_pred} print("=====================================") -""") - - if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: - results['correct_upperbound'] += 1 - - total = results['total'] - print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') - print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') - print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') - print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') - print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') - print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') - - sqa_results['acc'] = results["requery_correct"] / total * 100 - sqa_results['correct'] = 
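# The requery variant prefers the re-queried answer and only falls back to the
# original model prediction when requerying produced no parsable option letter.
# Minimal sketch of that selection rule (names illustrative):

def select_final_idx(requery_answer: str, requery_idx: int, our_idx: int) -> int:
    """Use the requery prediction unless its answer letter could not be parsed."""
    return our_idx if requery_answer == "FAILED" else requery_idx

assert select_final_idx("FAILED", requery_idx=0, our_idx=2) == 2
assert select_final_idx("C", requery_idx=2, our_idx=0) == 2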
results["requery_correct"] - sqa_results['count'] = total - - with open(args.output_result, 'w') as f: +""" + ) + + if gpt4_pred_idx == prob["answer"] or our_pred_idx == prob["answer"]: + results["correct_upperbound"] += 1 + + total = results["total"] + print( + f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%' + ) + print( + f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%' + ) + print( + f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%' + ) + print( + f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%' + ) + print( + f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%' + ) + print( + f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%' + ) + + sqa_results["acc"] = results["requery_correct"] / total * 100 + sqa_results["correct"] = results["requery_correct"] + sqa_results["count"] = total + + with open(args.output_result, "w") as f: json.dump(sqa_results, f, indent=2) - diff --git a/model/llava/eval/generate_webpage_data_from_table.py b/model/llava/eval/generate_webpage_data_from_table.py index 92602258ccd953a1d7137056aaf15c8de8166e21..d24e266ecbaf69a44460e08df9f081dc50808db5 100644 --- a/model/llava/eval/generate_webpage_data_from_table.py +++ b/model/llava/eval/generate_webpage_data_from_table.py @@ -4,10 +4,10 @@ import os import re # models = ['llama', 'alpaca', 'gpt35', 'bard'] -models = ['vicuna'] +models = ["vicuna"] -def read_jsonl(path: str, key: str=None): +def read_jsonl(path: str, key: str = None): data = [] with open(os.path.expanduser(path)) as f: for line in f: @@ -23,21 +23,27 @@ def read_jsonl(path: str, key: str=None): def trim_hanging_lines(s: str, n: int) -> str: s = s.strip() for _ in range(n): - s = s.split('\n', 1)[1].strip() + s = s.split("\n", 1)[1].strip() return s -if __name__ == '__main__': - questions = read_jsonl('table/question.jsonl', key='question_id') +if __name__ == "__main__": + questions = read_jsonl("table/question.jsonl", key="question_id") # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') - vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') - ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id') + vicuna_answers = read_jsonl( + "table/answer/answer_vicuna-13b.jsonl", key="question_id" + ) + ours_answers = read_jsonl( + "table/results/llama-13b-hf-alpaca.jsonl", key="question_id" + ) - review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') + review_vicuna = read_jsonl( + "table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl", key="question_id" + ) # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id') # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', 
key='question_id') @@ -46,26 +52,26 @@ if __name__ == '__main__': records = [] for qid in questions.keys(): r = { - 'id': qid, - 'category': questions[qid]['category'], - 'question': questions[qid]['text'], - 'answers': { + "id": qid, + "category": questions[qid]["category"], + "question": questions[qid]["text"], + "answers": { # 'alpaca': alpaca_answers[qid]['text'], # 'llama': llama_answers[qid]['text'], # 'bard': bard_answers[qid]['text'], # 'gpt35': gpt35_answers[qid]['text'], - 'vicuna': vicuna_answers[qid]['text'], - 'ours': ours_answers[qid]['text'], + "vicuna": vicuna_answers[qid]["text"], + "ours": ours_answers[qid]["text"], }, - 'evaluations': { + "evaluations": { # 'alpaca': review_alpaca[qid]['text'], # 'llama': review_llama[qid]['text'], # 'bard': review_bard[qid]['text'], - 'vicuna': review_vicuna[qid]['content'], + "vicuna": review_vicuna[qid]["content"], # 'gpt35': review_gpt35[qid]['text'], }, - 'scores': { - 'vicuna': review_vicuna[qid]['tuple'], + "scores": { + "vicuna": review_vicuna[qid]["tuple"], # 'alpaca': review_alpaca[qid]['score'], # 'llama': review_llama[qid]['score'], # 'bard': review_bard[qid]['score'], @@ -75,37 +81,39 @@ if __name__ == '__main__': # cleanup data cleaned_evals = {} - for k, v in r['evaluations'].items(): + for k, v in r["evaluations"].items(): v = v.strip() - lines = v.split('\n') + lines = v.split("\n") # trim the first line if it's a pair of numbers - if re.match(r'\d+[, ]+\d+', lines[0]): + if re.match(r"\d+[, ]+\d+", lines[0]): lines = lines[1:] - v = '\n'.join(lines) - cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') + v = "\n".join(lines) + cleaned_evals[k] = v.replace("Assistant 1", "**Assistant 1**").replace( + "Assistant 2", "**Assistant 2**" + ) - r['evaluations'] = cleaned_evals + r["evaluations"] = cleaned_evals records.append(r) # Reorder the records, this is optional for r in records: - if r['id'] <= 20: - r['id'] += 60 + if r["id"] <= 20: + r["id"] += 60 else: - r['id'] -= 20 + r["id"] -= 20 for r in records: - if r['id'] <= 50: - r['id'] += 10 - elif 50 < r['id'] <= 60: - r['id'] -= 50 + if r["id"] <= 50: + r["id"] += 10 + elif 50 < r["id"] <= 60: + r["id"] -= 50 for r in records: - if r['id'] == 7: - r['id'] = 1 - elif r['id'] < 7: - r['id'] += 1 + if r["id"] == 7: + r["id"] = 1 + elif r["id"] < 7: + r["id"] += 1 - records.sort(key=lambda x: x['id']) + records.sort(key=lambda x: x["id"]) # Write to file - with open('webpage/data.json', 'w') as f: - json.dump({'questions': records, 'models': models}, f, indent=2) + with open("webpage/data.json", "w") as f: + json.dump({"questions": records, "models": models}, f, indent=2) diff --git a/model/llava/eval/model_qa.py b/model/llava/eval/model_qa.py index 1dab4f2e9e742fa855da3d516c9020fa6a5d5042..a1c2cee07fb7ab3aa182d0e69612f2a2ebb5c761 100644 --- a/model/llava/eval/model_qa.py +++ b/model/llava/eval/model_qa.py @@ -1,13 +1,13 @@ import argparse -from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria -import torch -import os import json -from tqdm import tqdm -import shortuuid +import os +import shortuuid +import torch from llava.conversation import default_conversation from llava.utils import disable_torch_init +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria # new stopping implementation @@ -18,11 +18,15 @@ class KeywordsStoppingCriteria(StoppingCriteria): self.start_len = None self.input_ids = input_ids - def __call__(self, output_ids: 
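# generate_webpage_data_from_table.py leans on a read_jsonl helper that returns a
# list of records or, when `key` is given, a dict indexed by that field. The body
# is largely untouched by this diff, so the version below is a sketch of the
# helper as it is used (keyed by question_id), plus a tiny round trip with a
# placeholder file name:

import json

def read_jsonl(path: str, key: str = None):
    data = []
    with open(path) as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    if key is not None:
        data = {item[key]: item for item in data}
    return data

with open("demo_questions.jsonl", "w") as f:  # placeholder path for the sketch
    f.write(json.dumps({"question_id": 1, "text": "What is shown?"}) + "\n")
    f.write(json.dumps({"question_id": 2, "text": "Describe the scene."}) + "\n")

by_id = read_jsonl("demo_questions.jsonl", key="question_id")
assert by_id[2]["text"] == "Describe the scene."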
torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + def __call__( + self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: if self.start_len is None: self.start_len = self.input_ids.shape[1] else: - outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] + outputs = self.tokenizer.batch_decode( + output_ids[:, self.start_len :], skip_special_tokens=True + )[0] for keyword in self.keywords: if keyword in outputs: return True @@ -35,9 +39,9 @@ def eval_model(model_name, questions_file, answers_file): disable_torch_init() model_name = os.path.expanduser(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(model_name, - torch_dtype=torch.float16).cuda() - + model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float16 + ).cuda() ques_file = open(os.path.expanduser(questions_file), "r") ans_file = open(os.path.expanduser(answers_file), "w") @@ -56,7 +60,8 @@ def eval_model(model_name, questions_file, answers_file): do_sample=True, temperature=0.7, max_new_tokens=1024, - stopping_criteria=[stopping_criteria]) + stopping_criteria=[stopping_criteria], + ) outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] try: index = outputs.index(conv.sep, len(prompt)) @@ -64,16 +69,24 @@ def eval_model(model_name, questions_file, answers_file): outputs += conv.sep index = outputs.index(conv.sep, len(prompt)) - outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() + outputs = outputs[len(prompt) + len(conv.roles[1]) + 2 : index].strip() ans_id = shortuuid.uuid() - ans_file.write(json.dumps({"question_id": idx, - "text": outputs, - "answer_id": ans_id, - "model_id": model_name, - "metadata": {}}) + "\n") + ans_file.write( + json.dumps( + { + "question_id": idx, + "text": outputs, + "answer_id": ans_id, + "model_id": model_name, + "metadata": {}, + } + ) + + "\n" + ) ans_file.flush() ans_file.close() + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-name", type=str, default="facebook/opt-350m") diff --git a/model/llava/eval/model_vqa.py b/model/llava/eval/model_vqa.py index 880e433c5def6842adbbb878368867b088447701..d82fb1487dbc803b318c9f8b9662c0cab1f455b2 100644 --- a/model/llava/eval/model_vqa.py +++ b/model/llava/eval/model_vqa.py @@ -1,25 +1,25 @@ import argparse -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig -import torch -import os import json -from tqdm import tqdm -import shortuuid +import math +import os +import random +import shortuuid +import torch from llava import LlavaLlamaForCausalLM from llava.conversation import conv_templates from llava.utils import disable_torch_init -from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria - from PIL import Image -import random -import math +from tqdm import tqdm +from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, + CLIPImageProcessor, CLIPVisionModel, + StoppingCriteria) def split_list(lst, n): """Split a list into n (roughly) equal-sized chunks""" chunk_size = math.ceil(len(lst) / n) # integer division - return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] def get_chunk(lst, n, k): @@ -37,12 +37,14 @@ def patch_config(config): patch_dict = { "use_mm_proj": True, "mm_vision_tower": "openai/clip-vit-large-patch14", - "mm_hidden_size": 1024 + 
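# model_qa.py stops generation as soon as a keyword (here "###") appears in the
# text decoded after the prompt. A condensed, self-contained restatement of that
# criteria; it mirrors the logic shown in the hunk above but is a sketch, not the
# repository class itself:

import torch
from transformers import StoppingCriteria

class KeywordsStoppingCriteriaSketch(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        self.start_len = None  # prompt length, recorded on the first call
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
            return False
        # Decode only the newly generated continuation and look for any keyword.
        new_text = self.tokenizer.batch_decode(
            output_ids[:, self.start_len :], skip_special_tokens=True
        )[0]
        return any(keyword in new_text for keyword in self.keywords)

# Typical use, as in the scripts above:
#   stopping_criteria=[KeywordsStoppingCriteriaSketch(["###"], tokenizer, input_ids)]
# passed to model.generate(...).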
"mm_hidden_size": 1024, } cfg = AutoConfig.from_pretrained(config) if not hasattr(cfg, "mm_vision_tower"): - print(f'`mm_vision_tower` not found in `{config}`, applying patch and save to disk.') + print( + f"`mm_vision_tower` not found in `{config}`, applying patch and save to disk." + ) for k, v in patch_dict.items(): setattr(cfg, k, v) cfg.save_pretrained(config) @@ -55,50 +57,84 @@ def eval_model(args): tokenizer = AutoTokenizer.from_pretrained(model_name) if args.mm_projector is None: patch_config(model_name) - model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda() - image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16) + model = LlavaLlamaForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float16 + ).cuda() + image_processor = CLIPImageProcessor.from_pretrained( + model.config.mm_vision_tower, torch_dtype=torch.float16 + ) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: - tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) vision_tower = model.model.vision_tower[0] - vision_tower.to(device='cuda', dtype=torch.float16) + vision_tower.to(device="cuda", dtype=torch.float16) vision_config = vision_tower.config - vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: - vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + ( + vision_config.im_start_token, + vision_config.im_end_token, + ) = tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 else: # in case of using a pretrained model with only a MLP projector weights - model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda() + model = LlavaLlamaForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float16 + ).cuda() - vision_tower = CLIPVisionModel.from_pretrained(args.vision_tower, torch_dtype=torch.float16).cuda() - image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower, torch_dtype=torch.float16) + vision_tower = CLIPVisionModel.from_pretrained( + args.vision_tower, torch_dtype=torch.float16 + ).cuda() + image_processor = CLIPImageProcessor.from_pretrained( + args.vision_tower, torch_dtype=torch.float16 + ) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: - tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) vision_config = vision_tower.config - vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: - vision_config.im_start_token, 
vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + ( + vision_config.im_start_token, + vision_config.im_end_token, + ) = tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 - mm_projector = torch.nn.Linear(vision_config.hidden_size, model.config.hidden_size) - mm_projector_weights = torch.load(args.mm_projector, map_location='cpu') - mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()}) + mm_projector = torch.nn.Linear( + vision_config.hidden_size, model.config.hidden_size + ) + mm_projector_weights = torch.load(args.mm_projector, map_location="cpu") + mm_projector.load_state_dict( + {k.split(".")[-1]: v for k, v in mm_projector_weights.items()} + ) model.model.mm_projector = mm_projector.cuda().half() model.model.vision_tower = [vision_tower] - questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] + questions = [ + json.loads(q) for q in open(os.path.expanduser(args.question_file), "r") + ] questions = get_chunk(questions, args.num_chunks, args.chunk_idx) answers_file = os.path.expanduser(args.answers_file) os.makedirs(os.path.dirname(answers_file), exist_ok=True) @@ -109,12 +145,18 @@ def eval_model(args): qs = line["text"] cur_prompt = qs if mm_use_im_start_end: - qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + qs = ( + qs + + "\n" + + DEFAULT_IM_START_TOKEN + + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + + DEFAULT_IM_END_TOKEN + ) else: - qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len - if args.conv_mode == 'simple_legacy': - qs += '\n\n### Response:' + if args.conv_mode == "simple_legacy": + qs += "\n\n### Response:" # conv = default_conversation.copy() conv = conv_templates[args.conv_mode].copy() conv.append_message(conv.roles[0], qs) @@ -123,7 +165,9 @@ def eval_model(args): image = Image.open(os.path.join(args.image_folder, image_file)) # image.save(os.path.join(save_image_folder, image_file)) - image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + image_tensor = image_processor.preprocess(image, return_tensors="pt")[ + "pixel_values" + ][0] input_ids = torch.as_tensor(inputs.input_ids).cuda() @@ -135,17 +179,21 @@ def eval_model(args): self.start_len = None self.input_ids = input_ids - def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + def __call__( + self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: if self.start_len is None: self.start_len = self.input_ids.shape[1] else: - outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] + outputs = self.tokenizer.batch_decode( + output_ids[:, self.start_len :], skip_special_tokens=True + )[0] for keyword in self.keywords: if keyword in outputs: return True return False - keywords = ['###'] + keywords = ["###"] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) with torch.inference_mode(): @@ -155,21 +203,28 @@ def eval_model(args): do_sample=True, temperature=0.7, max_new_tokens=1024, - stopping_criteria=[stopping_criteria]) + stopping_criteria=[stopping_criteria], + ) input_token_len = input_ids.shape[1] - n_diff_input_output = (input_ids != output_ids[:, 
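# When only a pretrained MLP projector is supplied via --mm_projector, model_vqa.py
# builds a bare Linear layer from the CLIP hidden size to the LM hidden size and
# loads the saved weights after stripping their key prefixes, as in the hunk above.
# Minimal sketch of that load path; the sizes (ViT-L/14 = 1024, LLaMA-7B = 4096)
# and the fake state dict stand in for the real checkpoint:

import torch

clip_hidden_size, lm_hidden_size = 1024, 4096
mm_projector = torch.nn.Linear(clip_hidden_size, lm_hidden_size)

# Stand-in for torch.load("<projector checkpoint>", map_location="cpu"):
state = {
    "model.mm_projector.weight": torch.zeros(lm_hidden_size, clip_hidden_size),
    "model.mm_projector.bias": torch.zeros(lm_hidden_size),
}
mm_projector.load_state_dict({k.split(".")[-1]: v for k, v in state.items()})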
:input_token_len]).sum().item() + n_diff_input_output = ( + (input_ids != output_ids[:, :input_token_len]).sum().item() + ) if n_diff_input_output > 0: - print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids') - outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] - - if args.conv_mode == 'simple_legacy' or args.conv_mode == 'simple': + print( + f"[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids" + ) + outputs = tokenizer.batch_decode( + output_ids[:, input_token_len:], skip_special_tokens=True + )[0] + + if args.conv_mode == "simple_legacy" or args.conv_mode == "simple": while True: cur_len = len(outputs) outputs = outputs.strip() - for pattern in ['###', 'Assistant:', 'Response:']: + for pattern in ["###", "Assistant:", "Response:"]: if outputs.startswith(pattern): - outputs = outputs[len(pattern):].strip() + outputs = outputs[len(pattern) :].strip() if len(outputs) == cur_len: break @@ -182,15 +237,23 @@ def eval_model(args): outputs = outputs[:index].strip() ans_id = shortuuid.uuid() - ans_file.write(json.dumps({"question_id": idx, - "prompt": cur_prompt, - "text": outputs, - "answer_id": ans_id, - "model_id": model_name, - "metadata": {}}) + "\n") + ans_file.write( + json.dumps( + { + "question_id": idx, + "prompt": cur_prompt, + "text": outputs, + "answer_id": ans_id, + "model_id": model_name, + "metadata": {}, + } + ) + + "\n" + ) ans_file.flush() ans_file.close() + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-name", type=str, default="facebook/opt-350m") diff --git a/model/llava/eval/model_vqa_science.py b/model/llava/eval/model_vqa_science.py index 52d60521af6fe681cb92f1e70a9c107ac8ccaec9..6623e72ca60aff8fcda6eeb885f5568f2d17a07c 100644 --- a/model/llava/eval/model_vqa_science.py +++ b/model/llava/eval/model_vqa_science.py @@ -1,25 +1,25 @@ import argparse -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig -import torch -import os import json -from tqdm import tqdm -import shortuuid +import math +import os +import random +import shortuuid +import torch from llava import LlavaLlamaForCausalLM from llava.conversation import conv_templates from llava.utils import disable_torch_init -from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria - from PIL import Image -import random -import math +from tqdm import tqdm +from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, + CLIPImageProcessor, CLIPVisionModel, + StoppingCriteria) def split_list(lst, n): """Split a list into n (roughly) equal-sized chunks""" chunk_size = math.ceil(len(lst) / n) # integer division - return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] def get_chunk(lst, n, k): @@ -33,8 +33,6 @@ DEFAULT_IM_START_TOKEN = "" DEFAULT_IM_END_TOKEN = "" - - detail_describe_instructions = [ "Describe the following image in detail.", "Provide a detailed description of the given image.", @@ -70,19 +68,21 @@ concise_describe_instructions = [ prompt_pool = detail_describe_instructions + concise_describe_instructions -prompt_pool = [ "Describe the following image in detail."] +prompt_pool = ["Describe the following image in detail."] def patch_config(config): patch_dict = { "use_mm_proj": True, "mm_vision_tower": "openai/clip-vit-large-patch14", - "mm_hidden_size": 1024 + "mm_hidden_size": 1024, } cfg = 
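# Both model_vqa.py and model_vqa_science.py shard the question list across
# workers with split_list/get_chunk: ceil-sized chunks selected by --chunk-idx.
# Standalone sketch of that sharding (get_chunk's one-line body is not shown in
# the hunk above, so it is reconstructed here as commonly written):

import math

def split_list(lst, n):
    """Split lst into n roughly equal chunks; the last chunk may be shorter."""
    chunk_size = math.ceil(len(lst) / n)
    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]

def get_chunk(lst, n, k):
    return split_list(lst, n)[k]

questions = list(range(10))
assert split_list(questions, 3) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
assert get_chunk(questions, 3, 2) == [8, 9]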
AutoConfig.from_pretrained(config) if not hasattr(cfg, "mm_vision_tower"): - print(f'`mm_vision_tower` not found in `{config}`, applying patch and save to disk.') + print( + f"`mm_vision_tower` not found in `{config}`, applying patch and save to disk." + ) for k, v in patch_dict.items(): setattr(cfg, k, v) cfg.save_pretrained(config) @@ -96,11 +96,15 @@ class KeywordsStoppingCriteria(StoppingCriteria): self.start_len = None self.input_ids = input_ids - def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + def __call__( + self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: if self.start_len is None: self.start_len = self.input_ids.shape[1] else: - outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] + outputs = self.tokenizer.batch_decode( + output_ids[:, self.start_len :], skip_special_tokens=True + )[0] for keyword in self.keywords: if keyword in outputs: return True @@ -114,45 +118,77 @@ def eval_model(args): tokenizer = AutoTokenizer.from_pretrained(model_name) if args.mm_projector is None: patch_config(model_name) - model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda() - image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16) + model = LlavaLlamaForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float16, use_cache=True + ).cuda() + image_processor = CLIPImageProcessor.from_pretrained( + model.config.mm_vision_tower, torch_dtype=torch.float16 + ) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: - tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) vision_tower = model.model.vision_tower[0] - vision_tower.to(device='cuda', dtype=torch.float16) + vision_tower.to(device="cuda", dtype=torch.float16) vision_config = vision_tower.config - vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: - vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + ( + vision_config.im_start_token, + vision_config.im_end_token, + ) = tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 else: # in case of using a pretrained model with only a MLP projector weights - model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda() + model = LlavaLlamaForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float16, use_cache=True + ).cuda() mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: - tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) - vision_tower = CLIPVisionModel.from_pretrained(args.vision_tower, 
torch_dtype=torch.float16).cuda() - image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower, torch_dtype=torch.float16) + vision_tower = CLIPVisionModel.from_pretrained( + args.vision_tower, torch_dtype=torch.float16 + ).cuda() + image_processor = CLIPImageProcessor.from_pretrained( + args.vision_tower, torch_dtype=torch.float16 + ) vision_config = vision_tower.config - vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: - vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + ( + vision_config.im_start_token, + vision_config.im_end_token, + ) = tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 - mm_projector = torch.nn.Linear(vision_config.hidden_size, model.config.hidden_size) - mm_projector_weights = torch.load(args.mm_projector, map_location='cpu') - mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()}) + mm_projector = torch.nn.Linear( + vision_config.hidden_size, model.config.hidden_size + ) + mm_projector_weights = torch.load(args.mm_projector, map_location="cpu") + mm_projector.load_state_dict( + {k.split(".")[-1]: v for k, v in mm_projector_weights.items()} + ) model.model.mm_projector = mm_projector.cuda().half() model.model.vision_tower = [vision_tower] @@ -163,33 +199,43 @@ def eval_model(args): os.makedirs(os.path.dirname(answers_file), exist_ok=True) os.makedirs(os.path.join(os.path.dirname(answers_file), "images"), exist_ok=True) ans_file = open(answers_file, "w") - save_image_folder = os.path.join(os.path.dirname(os.path.expanduser(args.answers_file)), "images") + save_image_folder = os.path.join( + os.path.dirname(os.path.expanduser(args.answers_file)), "images" + ) for i, line in enumerate(tqdm(questions)): idx = line["id"] - question = line['conversations'][0] + question = line["conversations"][0] gt_ans = line["conversations"][1] - - qs = question['value'] - qs = qs.replace('', '').strip() + qs = question["value"] + + qs = qs.replace("", "").strip() cur_prompt = qs - if 'image' in line: + if "image" in line: image_file = line["image"] image = Image.open(os.path.join(args.image_folder, image_file)) - image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + image_tensor = image_processor.preprocess(image, return_tensors="pt")[ + "pixel_values" + ][0] images = image_tensor.unsqueeze(0).half().cuda() - if getattr(model.config, 'mm_use_im_start_end', False): - qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + if getattr(model.config, "mm_use_im_start_end", False): + qs = ( + qs + + "\n" + + DEFAULT_IM_START_TOKEN + + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + + DEFAULT_IM_END_TOKEN + ) else: - qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len - cur_prompt = cur_prompt + '\n' + '' + qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + cur_prompt = cur_prompt + "\n" + "" else: images = None - if args.conv_mode == 'simple_legacy': - qs += '\n\n### Response:' - assert gt_ans['from'] == 'gpt' + if args.conv_mode == "simple_legacy": + qs += "\n\n### Response:" + assert gt_ans["from"] == "gpt" # 
conv = default_conversation.copy() conv = conv_templates[args.conv_mode].copy() conv.append_message(conv.roles[0], qs) @@ -198,7 +244,7 @@ def eval_model(args): input_ids = torch.as_tensor(inputs.input_ids).cuda() - keywords = ['###'] + keywords = ["###"] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) with torch.inference_mode(): @@ -208,22 +254,29 @@ def eval_model(args): do_sample=True, temperature=0.7, max_new_tokens=1024, - stopping_criteria=[stopping_criteria]) + stopping_criteria=[stopping_criteria], + ) # TODO: new implementation input_token_len = input_ids.shape[1] - n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() + n_diff_input_output = ( + (input_ids != output_ids[:, :input_token_len]).sum().item() + ) if n_diff_input_output > 0: - print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids') - outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] - - if args.conv_mode == 'simple_legacy': + print( + f"[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids" + ) + outputs = tokenizer.batch_decode( + output_ids[:, input_token_len:], skip_special_tokens=True + )[0] + + if args.conv_mode == "simple_legacy": while True: cur_len = len(outputs) outputs = outputs.strip() - for pattern in ['###', 'Assistant:', 'Response:']: + for pattern in ["###", "Assistant:", "Response:"]: if outputs.startswith(pattern): - outputs = outputs[len(pattern):].strip() + outputs = outputs[len(pattern) :].strip() if len(outputs) == cur_len: break @@ -238,11 +291,11 @@ def eval_model(args): # prompt for answer if args.answer_prompter: outputs_reasoning = outputs - inputs = tokenizer([prompt + outputs_reasoning + ' ###\nANSWER:']) + inputs = tokenizer([prompt + outputs_reasoning + " ###\nANSWER:"]) input_ids = torch.as_tensor(inputs.input_ids).cuda() - keywords = ['###'] + keywords = ["###"] stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) with torch.inference_mode(): @@ -252,13 +305,20 @@ def eval_model(args): do_sample=True, temperature=0.7, max_new_tokens=64, - stopping_criteria=[stopping_criteria]) + stopping_criteria=[stopping_criteria], + ) input_token_len = input_ids.shape[1] - n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() + n_diff_input_output = ( + (input_ids != output_ids[:, :input_token_len]).sum().item() + ) if n_diff_input_output > 0: - print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids') - outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] + print( + f"[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids" + ) + outputs = tokenizer.batch_decode( + output_ids[:, input_token_len:], skip_special_tokens=True + )[0] try: index = outputs.index(conv.sep) @@ -267,7 +327,7 @@ def eval_model(args): index = outputs.index(conv.sep) outputs = outputs[:index].strip() - outputs = outputs_reasoning + '\n The answer is ' + outputs + outputs = outputs_reasoning + "\n The answer is " + outputs # new implementation ends @@ -281,17 +341,24 @@ def eval_model(args): # outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() - ans_id = shortuuid.uuid() - ans_file.write(json.dumps({"question_id": idx, - "prompt": cur_prompt, - "text": outputs, - "answer_id": ans_id, - "model_id": model_name, - "metadata": {}}) + "\n") + ans_file.write( + json.dumps( + 
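# With --answer_prompter, model_vqa_science.py generates free-form reasoning
# first, then re-prompts with " ###\nANSWER:" under a small token budget to force
# a short final answer, and concatenates the two. Sketch of that two-pass flow
# with a stand-in generate_fn (a real run would call model.generate as in the
# script above):

def two_pass_answer(prompt: str, generate_fn) -> str:
    reasoning = generate_fn(prompt, max_new_tokens=1024)
    answer = generate_fn(prompt + reasoning + " ###\nANSWER:", max_new_tokens=64)
    return reasoning + "\n The answer is " + answer

def fake_generate(prompt: str, max_new_tokens: int) -> str:
    # Canned stand-in for the LLM call, for illustration only.
    return "B" if "ANSWER:" in prompt else "The second option matches the description."

print(two_pass_answer("Question: which option ...?", fake_generate))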
{ + "question_id": idx, + "prompt": cur_prompt, + "text": outputs, + "answer_id": ans_id, + "model_id": model_name, + "metadata": {}, + } + ) + + "\n" + ) ans_file.flush() ans_file.close() + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-name", type=str, default="facebook/opt-350m") diff --git a/model/llava/eval/qa_baseline_gpt35.py b/model/llava/eval/qa_baseline_gpt35.py index babab6e12b4bb8cfa74a7edfa5e56cd1b3e2bf6c..3c93d8db598b3e5830226da4dd21181744c40a66 100644 --- a/model/llava/eval/qa_baseline_gpt35.py +++ b/model/llava/eval/qa_baseline_gpt35.py @@ -1,51 +1,57 @@ """Generate answers with GPT-3.5""" # Note: you need to be using OpenAI Python v0.27.0 for the code below to work import argparse +import concurrent.futures import json import os import time -import concurrent.futures import openai -import tqdm import shortuuid +import tqdm + +MODEL = "gpt-3.5-turbo" +MODEL_ID = "gpt-3.5-turbo:20230327" -MODEL = 'gpt-3.5-turbo' -MODEL_ID = 'gpt-3.5-turbo:20230327' def get_answer(question_id: int, question: str, max_tokens: int): ans = { - 'answer_id': shortuuid.uuid(), - 'question_id': question_id, - 'model_id': MODEL_ID, + "answer_id": shortuuid.uuid(), + "question_id": question_id, + "model_id": MODEL_ID, } for _ in range(3): try: response = openai.ChatCompletion.create( model=MODEL, - messages=[{ - 'role': 'system', - 'content': 'You are a helpful assistant.' - }, { - 'role': 'user', - 'content': question, - }], + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": question, + }, + ], max_tokens=max_tokens, ) - ans['text'] = response['choices'][0]['message']['content'] + ans["text"] = response["choices"][0]["message"]["content"] return ans except Exception as e: - print('[ERROR]', e) - ans['text'] = '#ERROR#' + print("[ERROR]", e) + ans["text"] = "#ERROR#" time.sleep(1) return ans -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='ChatGPT answer generation.') - parser.add_argument('-q', '--question') - parser.add_argument('-o', '--output') - parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ChatGPT answer generation.") + parser.add_argument("-q", "--question") + parser.add_argument("-o", "--output") + parser.add_argument( + "--max-tokens", + type=int, + default=1024, + help="maximum number of tokens produced in the output", + ) args = parser.parse_args() questions_dict = {} @@ -54,7 +60,7 @@ if __name__ == '__main__': if not line: continue q = json.loads(line) - questions_dict[q['question_id']] = q['text'] + questions_dict[q["question_id"]] = q["text"] answers = [] @@ -64,11 +70,13 @@ if __name__ == '__main__': future = executor.submit(get_answer, qid, question, args.max_tokens) futures.append(future) - for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): + for future in tqdm.tqdm( + concurrent.futures.as_completed(futures), total=len(futures) + ): answers.append(future.result()) - answers.sort(key=lambda x: x['question_id']) + answers.sort(key=lambda x: x["question_id"]) - with open(os.path.expanduser(args.output), 'w') as f: + with open(os.path.expanduser(args.output), "w") as f: table = [json.dumps(ans) for ans in answers] - f.write('\n'.join(table)) + f.write("\n".join(table)) diff --git a/model/llava/eval/run_llava.py b/model/llava/eval/run_llava.py index 
67ff876b47e0e32595ec0c4e31d52306c3f92e10..a9d09255445d0bb1543ed9777d6b22647181e059 100644 --- a/model/llava/eval/run_llava.py +++ b/model/llava/eval/run_llava.py @@ -1,20 +1,17 @@ import argparse -from transformers import AutoTokenizer, AutoModelForCausalLM -import torch import os -from llava.conversation import conv_templates, SeparatorStyle -from llava.utils import disable_torch_init -from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria -from llava.model import * -from llava.model.utils import KeywordsStoppingCriteria - -from PIL import Image +from io import BytesIO -import os import requests +import torch +from llava.conversation import SeparatorStyle, conv_templates +from llava.model import * +from llava.model.utils import KeywordsStoppingCriteria +from llava.utils import disable_torch_init from PIL import Image -from io import BytesIO - +from transformers import (AutoModelForCausalLM, AutoTokenizer, + CLIPImageProcessor, CLIPVisionModel, + StoppingCriteria) DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" @@ -23,11 +20,11 @@ DEFAULT_IM_END_TOKEN = "" def load_image(image_file): - if image_file.startswith('http') or image_file.startswith('https'): + if image_file.startswith("http") or image_file.startswith("https"): response = requests.get(image_file) - image = Image.open(BytesIO(response.content)).convert('RGB') + image = Image.open(BytesIO(response.content)).convert("RGB") else: - image = Image.open(image_file).convert('RGB') + image = Image.open(image_file).convert("RGB") return image @@ -38,35 +35,63 @@ def eval_model(args): tokenizer = AutoTokenizer.from_pretrained(model_name) if "mpt" in model_name.lower(): - model = LlavaMPTForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda() + model = LlavaMPTForCausalLM.from_pretrained( + model_name, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + use_cache=True, + ).cuda() else: # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda() - model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')#.cuda() - image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16) + model = LlavaLlamaForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float16, device_map="auto" + ) # .cuda() + image_processor = CLIPImageProcessor.from_pretrained( + model.config.mm_vision_tower, torch_dtype=torch.float16 + ) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: - tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) vision_tower = model.get_model().vision_tower[0] - if vision_tower.device.type == 'meta': - vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda() + if vision_tower.device.type == "meta": + vision_tower = CLIPVisionModel.from_pretrained( + vision_tower.config._name_or_path, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).cuda() model.get_model().vision_tower[0] = vision_tower else: - vision_tower.to(device='cuda', dtype=torch.float16) + vision_tower.to(device="cuda", dtype=torch.float16) vision_config = vision_tower.config - 
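# run_llava.py resolves the --image-file argument either over HTTP or from local
# disk, always converting to RGB. A close standalone sketch of that helper (the
# URL check here requires a full scheme, and the example path in the comment is a
# placeholder):

from io import BytesIO

import requests
from PIL import Image

def load_image(image_file: str) -> Image.Image:
    if image_file.startswith("http://") or image_file.startswith("https://"):
        response = requests.get(image_file)
        return Image.open(BytesIO(response.content)).convert("RGB")
    return Image.open(image_file).convert("RGB")

# e.g. img = load_image("path/to/local_image.jpg")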
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: - vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + ( + vision_config.im_start_token, + vision_config.im_end_token, + ) = tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 qs = args.query if mm_use_im_start_end: - qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + qs = ( + qs + + "\n" + + DEFAULT_IM_START_TOKEN + + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + + DEFAULT_IM_END_TOKEN + ) else: - qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len if "v1" in model_name.lower(): conv_mode = "llava_v1" @@ -76,7 +101,11 @@ def eval_model(args): conv_mode = "multimodal" if args.conv_mode is not None and conv_mode != args.conv_mode: - print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) + print( + "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( + conv_mode, args.conv_mode, args.conv_mode + ) + ) else: args.conv_mode = conv_mode @@ -87,7 +116,9 @@ def eval_model(args): inputs = tokenizer([prompt]) image = load_image(args.image_file) - image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + image_tensor = image_processor.preprocess(image, return_tensors="pt")[ + "pixel_values" + ][0] input_ids = torch.as_tensor(inputs.input_ids).cuda() @@ -96,20 +127,34 @@ def eval_model(args): stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) with torch.inference_mode(): - output_ids = model.generate(input_ids, images=image_tensor.unsqueeze(0).half().cuda(), do_sample=True, temperature=0.2, max_new_tokens=1024, stopping_criteria=[stopping_criteria]) + output_ids = model.generate( + input_ids, + images=image_tensor.unsqueeze(0).half().cuda(), + do_sample=True, + temperature=0.2, + max_new_tokens=1024, + stopping_criteria=[stopping_criteria], + ) input_token_len = input_ids.shape[1] n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() if n_diff_input_output > 0: - print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') - outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] + print( + f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids" + ) + outputs = tokenizer.batch_decode( + output_ids[:, input_token_len:], skip_special_tokens=True + )[0] outputs = outputs.strip() if outputs.endswith(stop_str): - outputs = outputs[:-len(stop_str)] + outputs = outputs[: -len(stop_str)] outputs = outputs.strip() print(outputs) - import pdb; pdb.set_trace() + import pdb + + pdb.set_trace() + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/model/llava/eval/run_llava_batch.py b/model/llava/eval/run_llava_batch.py index 3d186fed0bf698379ae8a0a616318a44d4d36cd5..2ae1f0d7601636a2b434e3c4be7ec93fe2dc33fb 100644 --- a/model/llava/eval/run_llava_batch.py +++ 
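# After generation, run_llava.py drops the echoed prompt tokens, warns when the
# sampler did not return them verbatim, and trims the conversation stop string
# from the decoded text. Sketch of that post-processing with toy id tensors and a
# toy decoder (both illustrative, not repository code):

import torch

def postprocess(output_ids: torch.LongTensor, input_ids: torch.LongTensor,
                decode_fn, stop_str: str) -> str:
    input_token_len = input_ids.shape[1]
    n_diff = (input_ids != output_ids[:, :input_token_len]).sum().item()
    if n_diff > 0:
        print(f"[Warning] {n_diff} output_ids are not the same as the input_ids")
    text = decode_fn(output_ids[:, input_token_len:]).strip()
    if text.endswith(stop_str):
        text = text[: -len(stop_str)].strip()
    return text

def decode(ids):
    # Toy decoder mapping ids to text fragments.
    vocab = {0: "A", 1: " cat", 2: ".", 3: "###"}
    return "".join(vocab[int(i)] for i in ids[0])

print(postprocess(torch.tensor([[0, 1, 2, 3]]), torch.tensor([[0]]), decode, "###"))  # -> "cat."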
b/model/llava/eval/run_llava_batch.py @@ -1,24 +1,21 @@ import argparse -from transformers import AutoTokenizer, AutoModelForCausalLM -import torch -import os -from llava.conversation import conv_templates, SeparatorStyle -from llava.utils import disable_torch_init -from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria -from llava.model import * -from llava.model.utils import KeywordsStoppingCriteria - -from PIL import Image - +import glob +import json import os -import requests -from PIL import Image from io import BytesIO -import glob import numpy as np -import json +import requests +import torch import tqdm +from llava.conversation import SeparatorStyle, conv_templates +from llava.model import * +from llava.model.utils import KeywordsStoppingCriteria +from llava.utils import disable_torch_init +from PIL import Image +from transformers import (AutoModelForCausalLM, AutoTokenizer, + CLIPImageProcessor, CLIPVisionModel, + StoppingCriteria) DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" @@ -27,42 +24,167 @@ DEFAULT_IM_END_TOKEN = "" def load_image(image_file): - if image_file.startswith('http') or image_file.startswith('https'): + if image_file.startswith("http") or image_file.startswith("https"): response = requests.get(image_file) - image = Image.open(BytesIO(response.content)).convert('RGB') + image = Image.open(BytesIO(response.content)).convert("RGB") else: - image = Image.open(image_file).convert('RGB') + image = Image.open(image_file).convert("RGB") return image -classes = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', - 'bed', 'windowpane', 'grass', 'cabinet', 'sidewalk', - 'person', 'earth', 'door', 'table', 'mountain', 'plant', - 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', - 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', - 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', - 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', - 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', - 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', - 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', - 'screen door', 'stairway', 'river', 'bridge', 'bookcase', - 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', - 'bench', 'countertop', 'stove', 'palm', 'kitchen island', - 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', - 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', - 'chandelier', 'awning', 'streetlight', 'booth', - 'television receiver', 'airplane', 'dirt track', 'apparel', - 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', - 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', - 'conveyer belt', 'canopy', 'washer', 'plaything', - 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', - 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', - 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', - 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', - 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', - 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate', - 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', - 'clock', 'flag'] +classes = [ + "wall", + "building", + "sky", + "floor", + "tree", + "ceiling", + "road", + "bed", + "windowpane", + "grass", + "cabinet", + "sidewalk", + "person", + "earth", + "door", + "table", + "mountain", + "plant", + "curtain", + "chair", + "car", + "water", + "painting", + "sofa", + "shelf", + "house", + "sea", + "mirror", + "rug", + "field", + "armchair", + "seat", + "fence", + 
"desk", + "rock", + "wardrobe", + "lamp", + "bathtub", + "railing", + "cushion", + "base", + "box", + "column", + "signboard", + "chest of drawers", + "counter", + "sand", + "sink", + "skyscraper", + "fireplace", + "refrigerator", + "grandstand", + "path", + "stairs", + "runway", + "case", + "pool table", + "pillow", + "screen door", + "stairway", + "river", + "bridge", + "bookcase", + "blind", + "coffee table", + "toilet", + "flower", + "book", + "hill", + "bench", + "countertop", + "stove", + "palm", + "kitchen island", + "computer", + "swivel chair", + "boat", + "bar", + "arcade machine", + "hovel", + "bus", + "towel", + "light", + "truck", + "tower", + "chandelier", + "awning", + "streetlight", + "booth", + "television receiver", + "airplane", + "dirt track", + "apparel", + "pole", + "land", + "bannister", + "escalator", + "ottoman", + "bottle", + "buffet", + "poster", + "stage", + "van", + "ship", + "fountain", + "conveyer belt", + "canopy", + "washer", + "plaything", + "swimming pool", + "stool", + "barrel", + "basket", + "waterfall", + "tent", + "bag", + "minibike", + "cradle", + "oven", + "ball", + "food", + "step", + "tank", + "trade name", + "microwave", + "pot", + "animal", + "bicycle", + "lake", + "dishwasher", + "screen", + "blanket", + "sculpture", + "hood", + "sconce", + "vase", + "traffic light", + "tray", + "ashcan", + "fan", + "pier", + "crt screen", + "plate", + "monitor", + "bulletin board", + "shower", + "radiator", + "glass", + "clock", + "flag", +] + def eval_model(args): # Model @@ -71,35 +193,58 @@ def eval_model(args): tokenizer = AutoTokenizer.from_pretrained(model_name) if "mpt" in model_name.lower(): - model = LlavaMPTForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda() + model = LlavaMPTForCausalLM.from_pretrained( + model_name, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + use_cache=True, + ).cuda() else: # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda() - model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')#.cuda() - image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16) + model = LlavaLlamaForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float16, device_map="auto" + ) # .cuda() + image_processor = CLIPImageProcessor.from_pretrained( + model.config.mm_vision_tower, torch_dtype=torch.float16 + ) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: - tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) vision_tower = model.get_model().vision_tower[0] - if vision_tower.device.type == 'meta': - vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda() + if vision_tower.device.type == "meta": + vision_tower = CLIPVisionModel.from_pretrained( + vision_tower.config._name_or_path, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).cuda() model.get_model().vision_tower[0] = vision_tower else: - vision_tower.to(device='cuda', dtype=torch.float16) + vision_tower.to(device="cuda", dtype=torch.float16) vision_config = vision_tower.config - vision_config.im_patch_token 
= tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: - vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + ( + vision_config.im_start_token, + vision_config.im_end_token, + ) = tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 # paths for all images - images = sorted(glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg")) + images = sorted( + glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg") + ) results = [] for i, image_file in enumerate(tqdm.tqdm(images)): - # if i == 2: # break @@ -109,7 +254,9 @@ def eval_model(args): print("i: {}, len(images): {}".format(i, len(images))) image = load_image(image_file) - image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + image_tensor = image_processor.preprocess(image, return_tensors="pt")[ + "pixel_values" + ][0] image_tensor = image_tensor.unsqueeze(0).half().cuda() label_file = image_file.replace("images", "annotations").replace(".jpg", ".png") @@ -126,9 +273,15 @@ def eval_model(args): # qs = args.query if mm_use_im_start_end: - qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + qs = ( + qs + + "\n" + + DEFAULT_IM_START_TOKEN + + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + + DEFAULT_IM_END_TOKEN + ) else: - qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len if "v1" in model_name.lower(): conv_mode = "llava_v1" @@ -138,7 +291,11 @@ def eval_model(args): conv_mode = "multimodal" if args.conv_mode is not None and conv_mode != args.conv_mode: - print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) + print( + "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( + conv_mode, args.conv_mode, args.conv_mode + ) + ) else: args.conv_mode = conv_mode @@ -164,27 +321,41 @@ def eval_model(args): images=image_tensor, do_sample=True, temperature=0.2, - max_new_tokens=512, #1024, - stopping_criteria=[stopping_criteria]) + max_new_tokens=512, # 1024, + stopping_criteria=[stopping_criteria], + ) input_token_len = input_ids.shape[1] - n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() + n_diff_input_output = ( + (input_ids != output_ids[:, :input_token_len]).sum().item() + ) if n_diff_input_output > 0: - print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') - outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] + print( + f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids" + ) + outputs = tokenizer.batch_decode( + output_ids[:, input_token_len:], skip_special_tokens=True + )[0] outputs = outputs.strip() if outputs.endswith(stop_str): - outputs = outputs[:-len(stop_str)] + outputs = outputs[: -len(stop_str)] outputs = outputs.strip() print("qs: {}, output: {}, image_file: {}".format(qs, outputs, image_file)) - results.append({'image_id': image_file.split("/")[-1], 'input': input_conv, 'output': outputs}) + 
results.append( + { + "image_id": image_file.split("/")[-1], + "input": input_conv, + "output": outputs, + } + ) with open("/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json", "w+") as f: json.dump(results, f) - # print(outputs) + # print(outputs) + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/model/llava/eval/run_llava_batch_v2.py b/model/llava/eval/run_llava_batch_v2.py index 01d1d60d4f22b96b60c369cc8f4aa6b969ec3844..0700e30cd72c7d00b245ee099a089ffd598a5351 100644 --- a/model/llava/eval/run_llava_batch_v2.py +++ b/model/llava/eval/run_llava_batch_v2.py @@ -1,24 +1,21 @@ import argparse -from transformers import AutoTokenizer, AutoModelForCausalLM -import torch -import os -from llava.conversation import conv_templates, SeparatorStyle -from llava.utils import disable_torch_init -from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria -from llava.model import * -from llava.model.utils import KeywordsStoppingCriteria - -from PIL import Image - +import glob +import json import os -import requests -from PIL import Image from io import BytesIO -import glob import numpy as np -import json +import requests +import torch import tqdm +from llava.conversation import SeparatorStyle, conv_templates +from llava.model import * +from llava.model.utils import KeywordsStoppingCriteria +from llava.utils import disable_torch_init +from PIL import Image +from transformers import (AutoModelForCausalLM, AutoTokenizer, + CLIPImageProcessor, CLIPVisionModel, + StoppingCriteria) DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" @@ -27,42 +24,167 @@ DEFAULT_IM_END_TOKEN = "" def load_image(image_file): - if image_file.startswith('http') or image_file.startswith('https'): + if image_file.startswith("http") or image_file.startswith("https"): response = requests.get(image_file) - image = Image.open(BytesIO(response.content)).convert('RGB') + image = Image.open(BytesIO(response.content)).convert("RGB") else: - image = Image.open(image_file).convert('RGB') + image = Image.open(image_file).convert("RGB") return image -classes = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', - 'bed', 'windowpane', 'grass', 'cabinet', 'sidewalk', - 'person', 'earth', 'door', 'table', 'mountain', 'plant', - 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', - 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', - 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', - 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', - 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', - 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', - 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', - 'screen door', 'stairway', 'river', 'bridge', 'bookcase', - 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', - 'bench', 'countertop', 'stove', 'palm', 'kitchen island', - 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', - 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', - 'chandelier', 'awning', 'streetlight', 'booth', - 'television receiver', 'airplane', 'dirt track', 'apparel', - 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', - 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', - 'conveyer belt', 'canopy', 'washer', 'plaything', - 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', - 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', - 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', - 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', - 'sculpture', 'hood', 
'sconce', 'vase', 'traffic light', - 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate', - 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', - 'clock', 'flag'] +classes = [ + "wall", + "building", + "sky", + "floor", + "tree", + "ceiling", + "road", + "bed", + "windowpane", + "grass", + "cabinet", + "sidewalk", + "person", + "earth", + "door", + "table", + "mountain", + "plant", + "curtain", + "chair", + "car", + "water", + "painting", + "sofa", + "shelf", + "house", + "sea", + "mirror", + "rug", + "field", + "armchair", + "seat", + "fence", + "desk", + "rock", + "wardrobe", + "lamp", + "bathtub", + "railing", + "cushion", + "base", + "box", + "column", + "signboard", + "chest of drawers", + "counter", + "sand", + "sink", + "skyscraper", + "fireplace", + "refrigerator", + "grandstand", + "path", + "stairs", + "runway", + "case", + "pool table", + "pillow", + "screen door", + "stairway", + "river", + "bridge", + "bookcase", + "blind", + "coffee table", + "toilet", + "flower", + "book", + "hill", + "bench", + "countertop", + "stove", + "palm", + "kitchen island", + "computer", + "swivel chair", + "boat", + "bar", + "arcade machine", + "hovel", + "bus", + "towel", + "light", + "truck", + "tower", + "chandelier", + "awning", + "streetlight", + "booth", + "television receiver", + "airplane", + "dirt track", + "apparel", + "pole", + "land", + "bannister", + "escalator", + "ottoman", + "bottle", + "buffet", + "poster", + "stage", + "van", + "ship", + "fountain", + "conveyer belt", + "canopy", + "washer", + "plaything", + "swimming pool", + "stool", + "barrel", + "basket", + "waterfall", + "tent", + "bag", + "minibike", + "cradle", + "oven", + "ball", + "food", + "step", + "tank", + "trade name", + "microwave", + "pot", + "animal", + "bicycle", + "lake", + "dishwasher", + "screen", + "blanket", + "sculpture", + "hood", + "sconce", + "vase", + "traffic light", + "tray", + "ashcan", + "fan", + "pier", + "crt screen", + "plate", + "monitor", + "bulletin board", + "shower", + "radiator", + "glass", + "clock", + "flag", +] + def eval_model(args): # Model @@ -71,35 +193,58 @@ def eval_model(args): tokenizer = AutoTokenizer.from_pretrained(model_name) if "mpt" in model_name.lower(): - model = LlavaMPTForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda() + model = LlavaMPTForCausalLM.from_pretrained( + model_name, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + use_cache=True, + ).cuda() else: # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda() - model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')#.cuda() - image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16) + model = LlavaLlamaForCausalLM.from_pretrained( + model_name, torch_dtype=torch.float16, device_map="auto" + ) # .cuda() + image_processor = CLIPImageProcessor.from_pretrained( + model.config.mm_vision_tower, torch_dtype=torch.float16 + ) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: - tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) vision_tower = model.get_model().vision_tower[0] - if vision_tower.device.type == 'meta': - 
vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda() + if vision_tower.device.type == "meta": + vision_tower = CLIPVisionModel.from_pretrained( + vision_tower.config._name_or_path, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).cuda() model.get_model().vision_tower[0] = vision_tower # else: - # vision_tower.to(device='cuda', dtype=torch.float16) + # vision_tower.to(device='cuda', dtype=torch.float16) vision_config = vision_tower.config - vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: - vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + ( + vision_config.im_start_token, + vision_config.im_end_token, + ) = tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 # paths for all images - images = sorted(glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg")) + images = sorted( + glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg") + ) results = [] for i, image_file in enumerate(tqdm.tqdm(images)): - # if i == 2: # break @@ -115,7 +260,9 @@ def eval_model(args): label_unique = np.unique(label) image = load_image(image_file) - image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + image_tensor = image_processor.preprocess(image, return_tensors="pt")[ + "pixel_values" + ][0] image_tensor = image_tensor.unsqueeze(0).half().cuda() for label in label_unique: @@ -128,9 +275,15 @@ def eval_model(args): # qs = args.query if mm_use_im_start_end: - qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + qs = ( + qs + + "\n" + + DEFAULT_IM_START_TOKEN + + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + + DEFAULT_IM_END_TOKEN + ) else: - qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len if "v1" in model_name.lower(): conv_mode = "llava_v1" @@ -140,7 +293,11 @@ def eval_model(args): conv_mode = "multimodal" if args.conv_mode is not None and conv_mode != args.conv_mode: - print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) + print( + "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( + conv_mode, args.conv_mode, args.conv_mode + ) + ) else: args.conv_mode = conv_mode @@ -173,32 +330,46 @@ def eval_model(args): images=image_tensor, # do_sample=True, # temperature=0.2, - max_new_tokens=512, #1024, - stopping_criteria=[stopping_criteria]) + max_new_tokens=512, # 1024, + stopping_criteria=[stopping_criteria], + ) input_token_len = input_ids.shape[1] - n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() + n_diff_input_output = ( + (input_ids != output_ids[:, :input_token_len]).sum().item() + ) if n_diff_input_output > 0: - print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') + print( + f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids" + ) outputs_list = [] for output_id in 
output_ids: - outputs = tokenizer.batch_decode(output_id[:, input_token_len:], skip_special_tokens=True)[0] + outputs = tokenizer.batch_decode( + output_id[:, input_token_len:], skip_special_tokens=True + )[0] outputs = outputs.strip() if outputs.endswith(stop_str): - outputs = outputs[:-len(stop_str)] + outputs = outputs[: -len(stop_str)] outputs = outputs.strip() outputs_list.append(outputs) for qs, outputs in zip(prompt_list, outputs_list): print("qs: {}, output: {}, image_file: {}".format(qs, outputs, image_file)) - results.append({'image_id': image_file.split("/")[-1], 'input': prompt_list, 'output': outputs_list}) + results.append( + { + "image_id": image_file.split("/")[-1], + "input": prompt_list, + "output": outputs_list, + } + ) with open("/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json", "w+") as f: json.dump(results, f) - # print(outputs) + # print(outputs) + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/model/llava/eval/run_llava_batch_v3.py b/model/llava/eval/run_llava_batch_v3.py index 2db02cecc796622d644d56ad344e197d182232e8..c89d007f1ac369f1380d7b0cd291f535b430d3f4 100644 --- a/model/llava/eval/run_llava_batch_v3.py +++ b/model/llava/eval/run_llava_batch_v3.py @@ -1,24 +1,21 @@ import argparse -from transformers import AutoTokenizer, AutoModelForCausalLM -import torch -import os -from llava.conversation import conv_templates, SeparatorStyle -from llava.utils import disable_torch_init -from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria -from llava.model import * -from llava.model.utils import KeywordsStoppingCriteria - -from PIL import Image - +import glob +import json import os -import requests -from PIL import Image from io import BytesIO -import glob import numpy as np -import json +import requests +import torch import tqdm +from llava.conversation import SeparatorStyle, conv_templates +from llava.model import * +from llava.model.utils import KeywordsStoppingCriteria +from llava.utils import disable_torch_init +from PIL import Image +from transformers import (AutoModelForCausalLM, AutoTokenizer, + CLIPImageProcessor, CLIPVisionModel, + StoppingCriteria) DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" @@ -27,42 +24,167 @@ DEFAULT_IM_END_TOKEN = "" def load_image(image_file): - if image_file.startswith('http') or image_file.startswith('https'): + if image_file.startswith("http") or image_file.startswith("https"): response = requests.get(image_file) - image = Image.open(BytesIO(response.content)).convert('RGB') + image = Image.open(BytesIO(response.content)).convert("RGB") else: - image = Image.open(image_file).convert('RGB') + image = Image.open(image_file).convert("RGB") return image -classes = ['wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', - 'bed', 'windowpane', 'grass', 'cabinet', 'sidewalk', - 'person', 'earth', 'door', 'table', 'mountain', 'plant', - 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', - 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', - 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', - 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', - 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', - 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', - 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', - 'screen door', 'stairway', 'river', 'bridge', 'bookcase', - 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', - 'bench', 'countertop', 'stove', 'palm', 'kitchen island', - 'computer', 'swivel chair', 'boat', 
'bar', 'arcade machine', - 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', - 'chandelier', 'awning', 'streetlight', 'booth', - 'television receiver', 'airplane', 'dirt track', 'apparel', - 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', - 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', - 'conveyer belt', 'canopy', 'washer', 'plaything', - 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', - 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', - 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', - 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', - 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', - 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate', - 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', - 'clock', 'flag'] +classes = [ + "wall", + "building", + "sky", + "floor", + "tree", + "ceiling", + "road", + "bed", + "windowpane", + "grass", + "cabinet", + "sidewalk", + "person", + "earth", + "door", + "table", + "mountain", + "plant", + "curtain", + "chair", + "car", + "water", + "painting", + "sofa", + "shelf", + "house", + "sea", + "mirror", + "rug", + "field", + "armchair", + "seat", + "fence", + "desk", + "rock", + "wardrobe", + "lamp", + "bathtub", + "railing", + "cushion", + "base", + "box", + "column", + "signboard", + "chest of drawers", + "counter", + "sand", + "sink", + "skyscraper", + "fireplace", + "refrigerator", + "grandstand", + "path", + "stairs", + "runway", + "case", + "pool table", + "pillow", + "screen door", + "stairway", + "river", + "bridge", + "bookcase", + "blind", + "coffee table", + "toilet", + "flower", + "book", + "hill", + "bench", + "countertop", + "stove", + "palm", + "kitchen island", + "computer", + "swivel chair", + "boat", + "bar", + "arcade machine", + "hovel", + "bus", + "towel", + "light", + "truck", + "tower", + "chandelier", + "awning", + "streetlight", + "booth", + "television receiver", + "airplane", + "dirt track", + "apparel", + "pole", + "land", + "bannister", + "escalator", + "ottoman", + "bottle", + "buffet", + "poster", + "stage", + "van", + "ship", + "fountain", + "conveyer belt", + "canopy", + "washer", + "plaything", + "swimming pool", + "stool", + "barrel", + "basket", + "waterfall", + "tent", + "bag", + "minibike", + "cradle", + "oven", + "ball", + "food", + "step", + "tank", + "trade name", + "microwave", + "pot", + "animal", + "bicycle", + "lake", + "dishwasher", + "screen", + "blanket", + "sculpture", + "hood", + "sconce", + "vase", + "traffic light", + "tray", + "ashcan", + "fan", + "pier", + "crt screen", + "plate", + "monitor", + "bulletin board", + "shower", + "radiator", + "glass", + "clock", + "flag", +] + def eval_model(args): # Model @@ -71,38 +193,61 @@ def eval_model(args): tokenizer = AutoTokenizer.from_pretrained(model_name) if "mpt" in model_name.lower(): - model = LlavaMPTForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda() + model = LlavaMPTForCausalLM.from_pretrained( + model_name, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + use_cache=True, + ).cuda() else: # model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda() - model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')#.cuda() - image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16) + model = LlavaLlamaForCausalLM.from_pretrained( + 
model_name, torch_dtype=torch.float16, device_map="auto" + ) # .cuda() + image_processor = CLIPImageProcessor.from_pretrained( + model.config.mm_vision_tower, torch_dtype=torch.float16 + ) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) if mm_use_im_start_end: - tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) vision_tower = model.get_model().vision_tower[0] - if vision_tower.device.type == 'meta': - vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda() + if vision_tower.device.type == "meta": + vision_tower = CLIPVisionModel.from_pretrained( + vision_tower.config._name_or_path, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).cuda() model.get_model().vision_tower[0] = vision_tower else: - vision_tower.to(device='cuda', dtype=torch.float16) + vision_tower.to(device="cuda", dtype=torch.float16) vision_config = vision_tower.config - vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: - vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + ( + vision_config.im_start_token, + vision_config.im_end_token, + ) = tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 # paths for all images - images = sorted(glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg")) + images = sorted( + glob.glob("/mnt/proj74/xinlai/dataset/ade20k/images/training/*.jpg") + ) start, end = args.range.split(",") start, end = int(start), int(end) images = images[start:end] results = [] for i, image_file in enumerate(tqdm.tqdm(images)): - # if i == 2: # break @@ -112,7 +257,9 @@ def eval_model(args): print("i: {}, len(images): {}".format(i, len(images))) image = load_image(image_file) - image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + image_tensor = image_processor.preprocess(image, return_tensors="pt")[ + "pixel_values" + ][0] image_tensor = image_tensor.unsqueeze(0).half().cuda() prompt_list = [] @@ -133,9 +280,15 @@ def eval_model(args): # qs = args.query if mm_use_im_start_end: - qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN + qs = ( + qs + + "\n" + + DEFAULT_IM_START_TOKEN + + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + + DEFAULT_IM_END_TOKEN + ) else: - qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + qs = qs + "\n" + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len if "v1" in model_name.lower(): conv_mode = "llava_v1" @@ -145,7 +298,11 @@ def eval_model(args): conv_mode = "multimodal" if args.conv_mode is not None and conv_mode != args.conv_mode: - print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) + print( + "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( + conv_mode, args.conv_mode, args.conv_mode + ) + ) 
else: args.conv_mode = conv_mode @@ -171,17 +328,24 @@ def eval_model(args): images=image_tensor, do_sample=True, temperature=0.2, - max_new_tokens=512, #1024, - stopping_criteria=[stopping_criteria]) + max_new_tokens=512, # 1024, + stopping_criteria=[stopping_criteria], + ) input_token_len = input_ids.shape[1] - n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() + n_diff_input_output = ( + (input_ids != output_ids[:, :input_token_len]).sum().item() + ) if n_diff_input_output > 0: - print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') - outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] + print( + f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids" + ) + outputs = tokenizer.batch_decode( + output_ids[:, input_token_len:], skip_special_tokens=True + )[0] outputs = outputs.strip() if outputs.endswith(stop_str): - outputs = outputs[:-len(stop_str)] + outputs = outputs[: -len(stop_str)] outputs = outputs.strip() # print("qs: {}, output: {}, image_file: {}".format(qs, outputs, image_file)) @@ -190,13 +354,23 @@ def eval_model(args): # results.append({'image_id': image_id, 'input': input_conv, 'output': outputs}) output_list.append(outputs) image_id = image_file.split("/")[-1].split(".")[0] - with open("/mnt/proj74/xinlai/LLM/LLaVA/generated/{}.json".format(image_id), "w+") as f: - json.dump({'image_id': image_id, 'input_list': prompt_list, 'output_list': output_list}, f) + with open( + "/mnt/proj74/xinlai/LLM/LLaVA/generated/{}.json".format(image_id), "w+" + ) as f: + json.dump( + { + "image_id": image_id, + "input_list": prompt_list, + "output_list": output_list, + }, + f, + ) # with open("/mnt/proj74/xinlai/LLM/LLaVA/ade20k_conversations.json", "w+") as f: # json.dump(results, f) - # print(outputs) + # print(outputs) + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/model/llava/eval/summarize_gpt_review.py b/model/llava/eval/summarize_gpt_review.py index f871f7f820868a1271b2466832768bcda10c38cc..b60cb1baed0e1ea4850dd768cc7d1ae088eed62b 100644 --- a/model/llava/eval/summarize_gpt_review.py +++ b/model/llava/eval/summarize_gpt_review.py @@ -4,23 +4,25 @@ from collections import defaultdict import numpy as np - -if __name__ == '__main__': +if __name__ == "__main__": base_dir = "vqa/reviews/coco2014_val80" - review_files = [x for x in os.listdir(base_dir) if x.endswith('.jsonl') and x.startswith('gpt4_text')] + review_files = [ + x + for x in os.listdir(base_dir) + if x.endswith(".jsonl") and x.startswith("gpt4_text") + ] for review_file in sorted(review_files): - config = review_file.replace('gpt4_text_', '').replace('.jsonl', '') + config = review_file.replace("gpt4_text_", "").replace(".jsonl", "") scores = defaultdict(list) - print(f'GPT-4 vs. {config}') + print(f"GPT-4 vs. 
{config}") with open(os.path.join(base_dir, review_file)) as f: for review_str in f: review = json.loads(review_str) - scores[review['category']].append(review['tuple']) - scores['all'].append(review['tuple']) + scores[review["category"]].append(review["tuple"]) + scores["all"].append(review["tuple"]) for k, v in scores.items(): stats = np.asarray(v).mean(0).tolist() stats = [round(x, 3) for x in stats] - print(k, stats, round(stats[1]/stats[0]*100, 1)) - print('=================================') - + print(k, stats, round(stats[1] / stats[0] * 100, 1)) + print("=================================") diff --git a/model/llava/model/__init__.py b/model/llava/model/__init__.py index ceb04c4ccfbc6b405412b9d514d8e66d93e06913..153b079602a72de7793bb7f809ffbb5acdf846ee 100644 --- a/model/llava/model/__init__.py +++ b/model/llava/model/__init__.py @@ -1,2 +1,2 @@ -from .llava import LlavaLlamaForCausalLM, LlavaConfig -from .llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig +from .llava import LlavaConfig, LlavaLlamaForCausalLM +from .llava_mpt import LlavaMPTConfig, LlavaMPTForCausalLM diff --git a/model/llava/model/apply_delta.py b/model/llava/model/apply_delta.py index 666dd9691bde7d54ddf2871e311d6f621e29f099..2f73809262b001a1f16ca3302cd75ab30893486a 100644 --- a/model/llava/model/apply_delta.py +++ b/model/llava/model/apply_delta.py @@ -5,32 +5,40 @@ python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~ import argparse import torch -from tqdm import tqdm -from transformers import AutoTokenizer, AutoModelForCausalLM from llava import LlavaLlamaForCausalLM +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer def apply_delta(base_model_path, target_model_path, delta_path): print("Loading base model") base = AutoModelForCausalLM.from_pretrained( - base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) print("Loading delta") - delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + delta = LlavaLlamaForCausalLM.from_pretrained( + delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) print("Applying delta") for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): if name not in base.state_dict(): - assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' + assert name in [ + "model.mm_projector.weight", + "model.mm_projector.bias", + ], f"{name} not in base model" continue if param.data.shape == base.state_dict()[name].shape: param.data += base.state_dict()[name] else: - assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ - f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' + assert name in [ + "model.embed_tokens.weight", + "lm_head.weight", + ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" bparam = base.state_dict()[name] - param.data[:bparam.shape[0], :bparam.shape[1]] += bparam + param.data[: bparam.shape[0], : bparam.shape[1]] += bparam print("Saving target model") delta.save_pretrained(target_model_path) diff --git a/model/llava/model/consolidate.py b/model/llava/model/consolidate.py index ee93cf6981e849c7ea28dfc0201e9b5368ae6803..e49d9796e069343c51465a25335f266d654e03db 100644 --- a/model/llava/model/consolidate.py +++ b/model/llava/model/consolidate.py @@ -5,15 +5,17 @@ python3 -m 
llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_ import argparse import torch -from transformers import AutoTokenizer, AutoModelForCausalLM from llava.model import * from llava.model.utils import auto_upgrade +from transformers import AutoModelForCausalLM, AutoTokenizer def consolidate_ckpt(src_path, dst_path): print("Loading model") auto_upgrade(src_path) - src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + src_model = AutoModelForCausalLM.from_pretrained( + src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) src_tokenizer = AutoTokenizer.from_pretrained(src_path) src_model.save_pretrained(dst_path) src_tokenizer.save_pretrained(dst_path) diff --git a/model/llava/model/llava.py b/model/llava/model/llava.py index 5fe2e2dec333a99fa1bab58c4657a0b45799c5ec..665b284e3e59cb396b14ad7cb5f21b611a54223d 100644 --- a/model/llava/model/llava.py +++ b/model/llava/model/llava.py @@ -19,13 +19,11 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import CrossEntropyLoss - -from transformers import AutoConfig, AutoModelForCausalLM, \ - LlamaConfig, LlamaModel, LlamaForCausalLM, \ - CLIPVisionModel, CLIPImageProcessor - -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast - +from transformers import (AutoConfig, AutoModelForCausalLM, CLIPImageProcessor, + CLIPVisionModel, LlamaConfig, LlamaForCausalLM, + LlamaModel) +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" @@ -45,25 +43,33 @@ class LlavaLlamaModel(LlamaModel): if hasattr(config, "mm_vision_tower"): # HACK: for FSDP - self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)] + self.vision_tower = [ + CLIPVisionModel.from_pretrained(config.mm_vision_tower) + ] if hasattr(config, "use_mm_proj"): self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size) - def initialize_vision_modules(self, vision_tower, mm_vision_select_layer, - pretrain_mm_mlp_adapter=None, tune_mm_mlp_adapter=False, precision='bf16'): + def initialize_vision_modules( + self, + vision_tower, + mm_vision_select_layer, + pretrain_mm_mlp_adapter=None, + tune_mm_mlp_adapter=False, + precision="bf16", + ): self.config.mm_vision_tower = vision_tower image_processor = CLIPImageProcessor.from_pretrained(vision_tower) - if not hasattr(self, 'vision_tower'): + if not hasattr(self, "vision_tower"): vision_tower = CLIPVisionModel.from_pretrained(vision_tower) else: vision_tower = self.vision_tower[0] vision_tower.requires_grad_(False) - if precision == 'bf16': + if precision == "bf16": vision_tower = vision_tower.to(torch.bfloat16) - elif precision == 'fp16': + elif precision == "fp16": vision_tower = vision_tower.to(torch.half) else: vision_tower = vision_tower.to(torch.float32) @@ -77,17 +83,23 @@ class LlavaLlamaModel(LlamaModel): self.config.mm_hidden_size = vision_config.hidden_size self.config.mm_vision_select_layer = mm_vision_select_layer - if not hasattr(self, 'mm_projector'): - self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size) + if not hasattr(self, "mm_projector"): + self.mm_projector = nn.Linear( + vision_config.hidden_size, self.config.hidden_size + ) if pretrain_mm_mlp_adapter is not None: - mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') - self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in 
mm_projector_weights.items()}) + mm_projector_weights = torch.load( + pretrain_mm_mlp_adapter, map_location="cpu" + ) + self.mm_projector.load_state_dict( + {k.split(".")[-1]: v for k, v in mm_projector_weights.items()} + ) return dict( image_processor=image_processor, image_token_len=num_patches, - vision_config=vision_config + vision_config=vision_config, ) def forward( @@ -102,15 +114,18 @@ class LlavaLlamaModel(LlamaModel): images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - # HACK: replace back original embeddings for LLaVA pretraining - orig_embeds_params = getattr(self, 'orig_embeds_params', None) - + orig_embeds_params = getattr(self, "orig_embeds_params", None) + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - vision_tower = getattr(self, 'vision_tower', None) - if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None: + vision_tower = getattr(self, "vision_tower", None) + if ( + vision_tower is not None + and (input_ids.shape[1] != 1 or self.training) + and images is not None + ): # TODO: this is a modified multimodal LLM -- Haotian Liu vision_tower = vision_tower[0] # HACK: for FSDP with torch.no_grad(): @@ -118,26 +133,41 @@ class LlavaLlamaModel(LlamaModel): # variable length images image_features = [] for image in images: - image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True) - select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1) - select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer] + image_forward_out = vision_tower( + image.unsqueeze(0), output_hidden_states=True + ) + select_hidden_state_layer = getattr( + self.config, "mm_vision_select_layer", -1 + ) + select_hidden_state = image_forward_out.hidden_states[ + select_hidden_state_layer + ] image_feature = select_hidden_state[:, 1:] image_feature = image_feature.contiguous() image_features.append(image_feature) torch.cuda.empty_cache() else: image_forward_outs = vision_tower(images, output_hidden_states=True) - select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1) - select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer] + select_hidden_state_layer = getattr( + self.config, "mm_vision_select_layer", -1 + ) + select_hidden_state = image_forward_outs.hidden_states[ + select_hidden_state_layer + ] image_features = select_hidden_state[:, 1:] image_features = image_features.contiguous() torch.cuda.empty_cache() if type(images) is list: - image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features] + image_features = [ + self.mm_projector(image_feature)[0] + for image_feature in image_features + ] else: image_features = self.mm_projector(image_features) - dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) + dummy_image_features = torch.zeros( + 256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) dummy_image_features = self.mm_projector(dummy_image_features) new_input_embeds = [] @@ -145,48 +175,128 @@ class LlavaLlamaModel(LlamaModel): for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0: # multimodal LLM, but the current sample is not multimodal - cur_input_embeds = cur_input_embeds + (0. 
* dummy_image_features).sum() + cur_input_embeds = ( + cur_input_embeds + (0.0 * dummy_image_features).sum() + ) new_input_embeds.append(cur_input_embeds) cur_image_idx += 1 continue if vision_tower.config.use_im_start_end: cur_image_features = image_features[cur_image_idx] num_patches = cur_image_features.shape[0] - if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum(): - raise ValueError("The number of image start tokens and image end tokens should be the same.") - image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0] + if (cur_input_ids == vision_tower.config.im_start_token).sum() != ( + cur_input_ids == vision_tower.config.im_end_token + ).sum(): + raise ValueError( + "The number of image start tokens and image end tokens should be the same." + ) + image_start_tokens = torch.where( + cur_input_ids == vision_tower.config.im_start_token + )[0] for image_start_token_pos in image_start_tokens: - cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device) + cur_image_features = image_features[cur_image_idx].to( + device=cur_input_embeds.device + ) num_patches = cur_image_features.shape[0] - if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token: - raise ValueError("The image end token should follow the image start token.") + if ( + cur_input_ids[image_start_token_pos + num_patches + 1] + != vision_tower.config.im_end_token + ): + raise ValueError( + "The image end token should follow the image start token." + ) if orig_embeds_params is not None: - cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0) + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:image_start_token_pos].detach(), + cur_input_embeds[ + image_start_token_pos : image_start_token_pos + + 1 + ], + cur_image_features, + cur_input_embeds[ + image_start_token_pos + + num_patches + + 1 : image_start_token_pos + + num_patches + + 2 + ], + cur_input_embeds[ + image_start_token_pos + num_patches + 2 : + ].detach(), + ), + dim=0, + ) else: - cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0) + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[: image_start_token_pos + 1], + cur_image_features, + cur_input_embeds[ + image_start_token_pos + num_patches + 1 : + ], + ), + dim=0, + ) cur_image_idx += 1 new_input_embeds.append(cur_new_input_embeds) else: cur_image_features = image_features[cur_image_idx] num_patches = cur_image_features.shape[0] - if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches: - raise ValueError("The number of image patch tokens should be the same as the number of image patches.") - masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0] + if ( + cur_input_ids == vision_tower.config.im_patch_token + ).sum() != num_patches: + raise ValueError( + "The number of image patch tokens should be the same as the number of image patches." 
+ ) + masked_indices = torch.where( + cur_input_ids == vision_tower.config.im_patch_token + )[0] mask_index_start = masked_indices[0] - if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any(): - raise ValueError("The image patch tokens should be consecutive.") + if ( + masked_indices + != torch.arange( + mask_index_start, + mask_index_start + num_patches, + device=masked_indices.device, + dtype=masked_indices.dtype, + ) + ).any(): + raise ValueError( + "The image patch tokens should be consecutive." + ) if orig_embeds_params is not None: - cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0) + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:mask_index_start].detach(), + cur_image_features, + cur_input_embeds[ + mask_index_start + num_patches : + ].detach(), + ), + dim=0, + ) else: - cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0) + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:mask_index_start], + cur_image_features, + cur_input_embeds[mask_index_start + num_patches :], + ), + dim=0, + ) new_input_embeds.append(cur_new_input_embeds) cur_image_idx += 1 inputs_embeds = torch.stack(new_input_embeds, dim=0) return super(LlavaLlamaModel, self).forward( - input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, - inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict + input_ids=None, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) @@ -218,11 +328,19 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM): images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -234,7 +352,7 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - images=images + images=images, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) @@ -269,7 +387,12 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM): ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, ): if past_key_values: input_ids = input_ids[:, -1:] @@ -291,16 
+414,28 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM): ) return model_inputs - def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, num_new_tokens, device, - tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None): + def initialize_vision_tokenizer( + self, + mm_use_im_start_end, + tokenizer, + num_new_tokens, + device, + tune_mm_mlp_adapter=False, + pretrain_mm_mlp_adapter=None, + ): vision_config = self.get_model().vision_tower[0].config vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: # num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) - + # self.resize_token_embeddings(len(tokenizer)) - vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + ( + vision_config.im_start_token, + vision_config.im_end_token, + ) = tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) # if num_new_tokens > 0: # input_embeddings = self.get_input_embeddings().weight.data @@ -315,24 +450,35 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM): # output_embeddings[-num_new_tokens:] = output_embeddings_avg if tune_mm_mlp_adapter: - self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)] + self.get_model().orig_embeds_params = [ + self.get_input_embeddings().weight.data.clone().to(device=device) + ] for p in self.get_input_embeddings().parameters(): p.requires_grad = True for p in self.get_output_embeddings().parameters(): p.requires_grad = False if pretrain_mm_mlp_adapter: - mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') - embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] + mm_projector_weights = torch.load( + pretrain_mm_mlp_adapter, map_location="cpu" + ) + embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] assert num_new_tokens == 2 if input_embeddings.shape == embed_tokens_weight.shape: - input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + input_embeddings[-num_new_tokens:] = embed_tokens_weight[ + -num_new_tokens: + ] elif embed_tokens_weight.shape[0] == num_new_tokens: input_embeddings[-num_new_tokens:] = embed_tokens_weight else: - raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") + raise ValueError( + f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}." + ) + + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] - vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] AutoConfig.register("llava", LlavaConfig) AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) diff --git a/model/llava/model/llava_mpt.py b/model/llava/model/llava_mpt.py index 4b2f982108f03451934d2d1efdf4e459bca71958..4a638b68933f5930036d6f4b5fa46684e2e5cc5c 100644 --- a/model/llava/model/llava_mpt.py +++ b/model/llava/model/llava_mpt.py @@ -13,24 +13,21 @@ # limitations under the License. 
-from typing import List, Optional, Tuple, Union +import math import warnings +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import CrossEntropyLoss - -import math - -from transformers import AutoConfig, AutoModelForCausalLM, \ - CLIPVisionModel, CLIPImageProcessor - -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers import (AutoConfig, AutoModelForCausalLM, CLIPImageProcessor, + CLIPVisionModel) +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel - DEFAULT_IMAGE_TOKEN = "" DEFAULT_IMAGE_PATCH_TOKEN = "" DEFAULT_IM_START_TOKEN = "" @@ -49,19 +46,26 @@ class LlavaMPTModel(MPTModel): if hasattr(config, "mm_vision_tower"): # HACK: for FSDP - self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)] + self.vision_tower = [ + CLIPVisionModel.from_pretrained(config.mm_vision_tower) + ] # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower) if hasattr(config, "use_mm_proj"): self.mm_projector = nn.Linear(config.mm_hidden_size, config.d_model) - def initialize_vision_modules(self, vision_tower, mm_vision_select_layer, - pretrain_mm_mlp_adapter=None, tune_mm_mlp_adapter=False): + def initialize_vision_modules( + self, + vision_tower, + mm_vision_select_layer, + pretrain_mm_mlp_adapter=None, + tune_mm_mlp_adapter=False, + ): self.config.mm_vision_tower = vision_tower image_processor = CLIPImageProcessor.from_pretrained(vision_tower) - if not hasattr(self, 'vision_tower'): + if not hasattr(self, "vision_tower"): vision_tower = CLIPVisionModel.from_pretrained(vision_tower) else: vision_tower = self.vision_tower[0] @@ -76,23 +80,44 @@ class LlavaMPTModel(MPTModel): self.config.mm_hidden_size = vision_config.hidden_size self.config.mm_vision_select_layer = mm_vision_select_layer - if not hasattr(self, 'mm_projector'): - self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.d_model) + if not hasattr(self, "mm_projector"): + self.mm_projector = nn.Linear( + vision_config.hidden_size, self.config.d_model + ) if pretrain_mm_mlp_adapter is not None: - mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') - self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items() if 'mm_projector' in k}) + mm_projector_weights = torch.load( + pretrain_mm_mlp_adapter, map_location="cpu" + ) + self.mm_projector.load_state_dict( + { + k.split(".")[-1]: v + for k, v in mm_projector_weights.items() + if "mm_projector" in k + } + ) return dict( image_processor=image_processor, image_token_len=num_patches, - vision_config=vision_config + vision_config=vision_config, ) - def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): - + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + prefix_mask: Optional[torch.ByteTensor] = None, + sequence_id: Optional[torch.LongTensor] = None, + return_dict: 
Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + use_cache: Optional[bool] = None, + images=None, + ): # HACK: replace back original embeddings for LLaVA pretraining - orig_embeds_params = getattr(self, 'orig_embeds_params', None) + orig_embeds_params = getattr(self, "orig_embeds_params", None) # if orig_embeds_params is not None: # orig_embeds_params = orig_embeds_params[0] # with torch.no_grad(): @@ -100,8 +125,12 @@ class LlavaMPTModel(MPTModel): inputs_embeds = self.wte(input_ids) - vision_tower = getattr(self, 'vision_tower', None) - if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None: + vision_tower = getattr(self, "vision_tower", None) + if ( + vision_tower is not None + and (input_ids.shape[1] != 1 or self.training) + and images is not None + ): # TODO: this is a modified multimodal LLM -- Haotian Liu vision_tower = vision_tower[0] # HACK: for FSDP with torch.no_grad(): @@ -109,21 +138,36 @@ class LlavaMPTModel(MPTModel): # variable length images image_features = [] for image in images: - image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True) - select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1) - select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer] + image_forward_out = vision_tower( + image.unsqueeze(0), output_hidden_states=True + ) + select_hidden_state_layer = getattr( + self.config, "mm_vision_select_layer", -1 + ) + select_hidden_state = image_forward_out.hidden_states[ + select_hidden_state_layer + ] image_feature = select_hidden_state[:, 1:] image_features.append(image_feature) else: image_forward_outs = vision_tower(images, output_hidden_states=True) - select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1) - select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer] + select_hidden_state_layer = getattr( + self.config, "mm_vision_select_layer", -1 + ) + select_hidden_state = image_forward_outs.hidden_states[ + select_hidden_state_layer + ] image_features = select_hidden_state[:, 1:] if type(images) is list: - image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features] + image_features = [ + self.mm_projector(image_feature)[0] + for image_feature in image_features + ] else: image_features = self.mm_projector(image_features) - dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) + dummy_image_features = torch.zeros( + 256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) dummy_image_features = self.mm_projector(dummy_image_features) new_input_embeds = [] @@ -131,43 +175,130 @@ class LlavaMPTModel(MPTModel): for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0: # multimodal LLM, but the current sample is not multimodal - cur_input_embeds = cur_input_embeds + (0. 
* dummy_image_features).sum() + cur_input_embeds = ( + cur_input_embeds + (0.0 * dummy_image_features).sum() + ) new_input_embeds.append(cur_input_embeds) continue if vision_tower.config.use_im_start_end: cur_image_features = image_features[cur_image_idx] num_patches = cur_image_features.shape[0] - if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum(): - raise ValueError("The number of image start tokens and image end tokens should be the same.") - image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0] + if (cur_input_ids == vision_tower.config.im_start_token).sum() != ( + cur_input_ids == vision_tower.config.im_end_token + ).sum(): + raise ValueError( + "The number of image start tokens and image end tokens should be the same." + ) + image_start_tokens = torch.where( + cur_input_ids == vision_tower.config.im_start_token + )[0] for image_start_token_pos in image_start_tokens: - cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device) + cur_image_features = image_features[cur_image_idx].to( + device=cur_input_embeds.device + ) num_patches = cur_image_features.shape[0] - if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token: - raise ValueError("The image end token should follow the image start token.") + if ( + cur_input_ids[image_start_token_pos + num_patches + 1] + != vision_tower.config.im_end_token + ): + raise ValueError( + "The image end token should follow the image start token." + ) if orig_embeds_params is not None: - cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0) + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:image_start_token_pos].detach(), + cur_input_embeds[ + image_start_token_pos : image_start_token_pos + + 1 + ], + cur_image_features, + cur_input_embeds[ + image_start_token_pos + + num_patches + + 1 : image_start_token_pos + + num_patches + + 2 + ], + cur_input_embeds[ + image_start_token_pos + num_patches + 2 : + ].detach(), + ), + dim=0, + ) else: - cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0) + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[: image_start_token_pos + 1], + cur_image_features, + cur_input_embeds[ + image_start_token_pos + num_patches + 1 : + ], + ), + dim=0, + ) cur_image_idx += 1 new_input_embeds.append(cur_new_input_embeds) else: cur_image_features = image_features[cur_image_idx] num_patches = cur_image_features.shape[0] - if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches: - raise ValueError("The number of image patch tokens should be the same as the number of image patches.") - masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0] + if ( + cur_input_ids == vision_tower.config.im_patch_token + ).sum() != num_patches: + raise ValueError( + "The number of image patch tokens should be the same as the number of image patches." 
+ ) + masked_indices = torch.where( + cur_input_ids == vision_tower.config.im_patch_token + )[0] mask_index_start = masked_indices[0] - if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any(): - raise ValueError("The image patch tokens should be consecutive.") + if ( + masked_indices + != torch.arange( + mask_index_start, + mask_index_start + num_patches, + device=masked_indices.device, + dtype=masked_indices.dtype, + ) + ).any(): + raise ValueError( + "The image patch tokens should be consecutive." + ) if orig_embeds_params is not None: - cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0) + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:mask_index_start].detach(), + cur_image_features, + cur_input_embeds[ + mask_index_start + num_patches : + ].detach(), + ), + dim=0, + ) else: - cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0) + cur_new_input_embeds = torch.cat( + ( + cur_input_embeds[:mask_index_start], + cur_image_features, + cur_input_embeds[mask_index_start + num_patches :], + ), + dim=0, + ) new_input_embeds.append(cur_new_input_embeds) inputs_embeds = torch.stack(new_input_embeds, dim=0) - return super(LlavaMPTModel, self).forward(input_ids=None, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, tok_emb=inputs_embeds) + return super(LlavaMPTModel, self).forward( + input_ids=None, + past_key_values=past_key_values, + attention_mask=attention_mask, + prefix_mask=prefix_mask, + sequence_id=sequence_id, + return_dict=return_dict, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + tok_emb=inputs_embeds, + ) class LlavaMPTForCausalLM(MPTForCausalLM): @@ -178,16 +309,18 @@ class LlavaMPTForCausalLM(MPTForCausalLM): super(MPTForCausalLM, self).__init__(config) if not config.tie_word_embeddings: - raise ValueError('MPTForCausalLM only supports tied word embeddings') + raise ValueError("MPTForCausalLM only supports tied word embeddings") self.transformer = LlavaMPTModel(config) self.logit_scale = None if config.logit_scale is not None: logit_scale = config.logit_scale if isinstance(logit_scale, str): - if logit_scale == 'inv_sqrt_d_model': + if logit_scale == "inv_sqrt_d_model": logit_scale = 1 / math.sqrt(config.d_model) else: - raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") + raise ValueError( + f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." 
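The nested `torch.cat` calls in this hunk all implement one operation: splice the projected image features into the token-embedding sequence in place of the contiguous run of image-patch placeholder tokens. A self-contained sketch of that splice for a single sample (shapes and the placeholder id are invented):
```
# Toy version of the embedding splice performed above.
import torch

def splice_image_features(input_ids, inputs_embeds, image_features, im_patch_token):
    """input_ids: (seq,), inputs_embeds: (seq, d), image_features: (num_patches, d)."""
    num_patches = image_features.shape[0]
    patch_positions = torch.where(input_ids == im_patch_token)[0]
    assert patch_positions.numel() == num_patches, "one placeholder token per patch"
    start = int(patch_positions[0])
    # the placeholders must be consecutive, exactly as the check above enforces
    assert torch.equal(patch_positions, torch.arange(start, start + num_patches))
    return torch.cat(
        (inputs_embeds[:start], image_features, inputs_embeds[start + num_patches:]),
        dim=0,
    )

seq, d, num_patches, im_patch_token = 10, 8, 4, 99
ids = torch.ones(seq, dtype=torch.long)
ids[3:3 + num_patches] = im_patch_token
out = splice_image_features(ids, torch.randn(seq, d), torch.randn(num_patches, d), im_patch_token)
print(out.shape)  # torch.Size([10, 8])
```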
+ ) self.logit_scale = logit_scale def get_model(self): @@ -197,28 +330,67 @@ class LlavaMPTForCausalLM(MPTForCausalLM): if isinstance(module, LlavaMPTModel): module.gradient_checkpointing = value - def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): - return_dict = return_dict if return_dict is not None else self.config.return_dict + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + prefix_mask: Optional[torch.ByteTensor] = None, + sequence_id: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + use_cache: Optional[bool] = None, + images=None, + ): + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) use_cache = use_cache if use_cache is not None else self.config.use_cache - outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, images=images) + outputs = self.transformer( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + prefix_mask=prefix_mask, + sequence_id=sequence_id, + return_dict=return_dict, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + images=images, + ) logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight) if self.logit_scale is not None: if self.logit_scale == 0: - warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.') + warnings.warn( + f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs." 
+ ) logits *= self.logit_scale loss = None if labels is not None: labels = torch.roll(labels, shifts=-1) labels[:, -1] = -100 - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) - return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) + ) + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + ) - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs + ): if inputs_embeds is not None: - raise NotImplementedError('inputs_embeds is not implemented for MPT yet') - attention_mask = kwargs['attention_mask'].bool() + raise NotImplementedError("inputs_embeds is not implemented for MPT yet") + attention_mask = kwargs["attention_mask"].bool() if attention_mask[:, -1].sum() != attention_mask.shape[0]: - raise NotImplementedError('MPT does not support generation with right padding.') + raise NotImplementedError( + "MPT does not support generation with right padding." + ) if self.transformer.attn_uses_sequence_id and self.training: sequence_id = torch.zeros_like(input_ids[:1]) else: @@ -227,55 +399,91 @@ class LlavaMPTForCausalLM(MPTForCausalLM): input_ids = input_ids[:, -1].unsqueeze(-1) if self.transformer.prefix_lm: prefix_mask = torch.ones_like(attention_mask) - if kwargs.get('use_cache') == False: - raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') + if kwargs.get("use_cache") == False: + raise NotImplementedError( + "MPT with prefix_lm=True does not support use_cache=False." 
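The loss in `LlavaMPTForCausalLM.forward` is ordinary next-token cross-entropy, written with `torch.roll` rather than the more common logits/labels shift. A toy check that the two formulations agree once the final position is masked with `-100` (shapes invented):
```
import torch
import torch.nn.functional as F

batch, seq, vocab = 2, 5, 11
logits = torch.randn(batch, seq, vocab)
labels = torch.randint(0, vocab, (batch, seq))

# roll the targets left by one and drop the final position,
# relying on F.cross_entropy's default ignore_index of -100
rolled = torch.roll(labels, shifts=-1, dims=1)
rolled[:, -1] = -100
loss_roll = F.cross_entropy(logits.view(-1, vocab), rolled.view(-1))

# equivalent "shift logits / shift labels" formulation
loss_shift = F.cross_entropy(logits[:, :-1].reshape(-1, vocab), labels[:, 1:].reshape(-1))
print(torch.allclose(loss_roll, loss_shift))  # True
```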
+ ) else: prefix_mask = None - return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)} - - def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device, - tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None): + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "prefix_mask": prefix_mask, + "sequence_id": sequence_id, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache", True), + "images": kwargs.get("images", None), + } + + def initialize_vision_tokenizer( + self, + mm_use_im_start_end, + tokenizer, + device, + tune_mm_mlp_adapter=False, + pretrain_mm_mlp_adapter=None, + ): vision_config = self.get_model().vision_tower[0].config vision_config.use_im_start_end = mm_use_im_start_end tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) self.resize_token_embeddings(len(tokenizer)) if mm_use_im_start_end: - num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + num_new_tokens = tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) self.resize_token_embeddings(len(tokenizer)) - vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + ( + vision_config.im_start_token, + vision_config.im_end_token, + ) = tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) if num_new_tokens > 0: input_embeddings = self.get_input_embeddings().weight.data output_embeddings = self.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True) + dim=0, keepdim=True + ) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True) + dim=0, keepdim=True + ) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg if tune_mm_mlp_adapter: - self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)] + self.get_model().orig_embeds_params = [ + self.get_input_embeddings().weight.data.clone().to(device=device) + ] for p in self.get_input_embeddings().parameters(): p.requires_grad = True for p in self.get_output_embeddings().parameters(): p.requires_grad = False if pretrain_mm_mlp_adapter: - mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') - embed_tokens_weight = mm_projector_weights['transformer.wte.weight'] + mm_projector_weights = torch.load( + pretrain_mm_mlp_adapter, map_location="cpu" + ) + embed_tokens_weight = mm_projector_weights["transformer.wte.weight"] assert num_new_tokens == 2 if input_embeddings.shape == embed_tokens_weight.shape: - input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + input_embeddings[-num_new_tokens:] = embed_tokens_weight[ + -num_new_tokens: + ] elif embed_tokens_weight.shape[0] == num_new_tokens: input_embeddings[-num_new_tokens:] = embed_tokens_weight else: - raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") + raise ValueError( + f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. 
Numer of new tokens: {num_new_tokens}." + ) + + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] - vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] AutoConfig.register("llava_mpt", LlavaMPTConfig) AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) diff --git a/model/llava/model/make_delta.py b/model/llava/model/make_delta.py index 4ae55d59c2c8bab80299272314a41bbeb959d8ed..26d73d2474e2da7f62955c6685c8812d6d94f6ad 100644 --- a/model/llava/model/make_delta.py +++ b/model/llava/model/make_delta.py @@ -5,31 +5,40 @@ python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/mod import argparse import torch -from tqdm import tqdm -from transformers import AutoTokenizer, AutoModelForCausalLM from llava.model.utils import auto_upgrade +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): print("Loading base model") base = AutoModelForCausalLM.from_pretrained( - base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) print("Loading target model") auto_upgrade(target_model_path) - target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + target = AutoModelForCausalLM.from_pretrained( + target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) print("Calculating delta") for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): if name not in base.state_dict(): - assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' + assert name in [ + "model.mm_projector.weight", + "model.mm_projector.bias", + ], f"{name} not in base model" continue if param.data.shape == base.state_dict()[name].shape: param.data -= base.state_dict()[name] else: - assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' + assert name in [ + "model.embed_tokens.weight", + "lm_head.weight", + ], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}" bparam = base.state_dict()[name] - param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam + param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam print("Saving delta") if hub_repo_id: @@ -49,4 +58,6 @@ if __name__ == "__main__": parser.add_argument("--hub-repo-id", type=str, default=None) args = parser.parse_args() - make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) + make_delta( + args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id + ) diff --git a/model/llava/model/mpt/adapt_tokenizer.py b/model/llava/model/mpt/adapt_tokenizer.py index e640c157e8f5581953c518df0611a423225ef598..b6c2acaca8bd5bab095bad9f45208f7961297057 100644 --- a/model/llava/model/mpt/adapt_tokenizer.py +++ b/model/llava/model/mpt/adapt_tokenizer.py @@ -1,8 +1,12 @@ from typing import Union -from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) + Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] NUM_SENTINEL_TOKENS: int = 100 + def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): """Adds sentinel tokens and padding token (if 
missing). @@ -12,16 +16,17 @@ def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): All added tokens are added as special tokens. No tokens are added if sentinel tokens and padding token already exist. """ - sentinels_to_add = [f'' for i in range(NUM_SENTINEL_TOKENS)] + sentinels_to_add = [f"" for i in range(NUM_SENTINEL_TOKENS)] tokenizer.add_tokens(sentinels_to_add, special_tokens=True) if tokenizer.pad_token is None: - tokenizer.add_tokens('', special_tokens=True) - tokenizer.pad_token = '' + tokenizer.add_tokens("", special_tokens=True) + tokenizer.pad_token = "" assert tokenizer.pad_token_id is not None - sentinels = ''.join([f'' for i in range(NUM_SENTINEL_TOKENS)]) + sentinels = "".join([f"" for i in range(NUM_SENTINEL_TOKENS)]) _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids tokenizer.sentinel_token_ids = _sentinel_token_ids + class AutoTokenizerForMOD(AutoTokenizer): """AutoTokenizer + Adaptation for MOD. @@ -38,4 +43,4 @@ class AutoTokenizerForMOD(AutoTokenizer): """See `AutoTokenizer.from_pretrained` docstring.""" tokenizer = super().from_pretrained(*args, **kwargs) adapt_tokenizer_for_denoising(tokenizer) - return tokenizer \ No newline at end of file + return tokenizer diff --git a/model/llava/model/mpt/attention.py b/model/llava/model/mpt/attention.py index 2ca1069cd14ca055d918fa623d7da5efb4c5fd89..b4dad928098484ef7c287b5d7da7d95d5ff5ffee 100644 --- a/model/llava/model/mpt/attention.py +++ b/model/llava/model/mpt/attention.py @@ -2,24 +2,45 @@ import math import warnings from typing import Optional + import torch import torch.nn as nn from einops import rearrange from torch import nn + from .norm import LPLayerNorm -def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool): + +def _reset_is_causal( + num_query_tokens: int, num_key_tokens: int, original_is_causal: bool +): if original_is_causal and num_query_tokens != num_key_tokens: if num_query_tokens != 1: - raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.') + raise NotImplementedError( + "MPT does not support query and key with different number of tokens, unless number of query tokens is 1." 
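For `adapt_tokenizer_for_denoising` above, the literal sentinel and padding token strings do not survive in this patch; the sketch below assumes T5-style `<extra_id_{i}>` sentinels and a `<pad>` token, and the tokenizer name is likewise only an example:
```
from transformers import AutoTokenizer

NUM_SENTINEL_TOKENS = 100

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.add_tokens(
    [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)], special_tokens=True
)
if tokenizer.pad_token is None:
    tokenizer.add_tokens("<pad>", special_tokens=True)
    tokenizer.pad_token = "<pad>"

# cache the sentinel ids on the tokenizer for later use by a denoising collator
sentinels = "".join(f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS))
tokenizer.sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
```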
+ ) else: return False return original_is_causal -def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False): - q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) - k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads) - v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads) + +def scaled_multihead_dot_product_attention( + query, + key, + value, + n_heads, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, +): + q = rearrange(query, "b s (h d) -> b h s d", h=n_heads) + k = rearrange(key, "b s (h d) -> b h d s", h=1 if multiquery else n_heads) + v = rearrange(value, "b s (h d) -> b h s d", h=1 if multiquery else n_heads) min_val = torch.finfo(q.dtype).min (b, _, s_q, d) = q.shape s_k = k.size(-1) @@ -27,13 +48,27 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s softmax_scale = 1 / math.sqrt(d) attn_weight = q.matmul(k) * softmax_scale if attn_bias is not None: - if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q): - raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.') + if ( + attn_bias.size(-1) != 1 + and attn_bias.size(-1) != s_k + or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q) + ): + raise RuntimeError( + f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}." + ) attn_weight = attn_weight + attn_bias if key_padding_mask is not None: if attn_bias is not None: - warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') - attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val) + warnings.warn( + "Propogating key_padding_mask to the attention module " + + "and applying it within the attention module can cause " + + "unneccessary computation/memory usage. Consider integrating " + + "into attn_bias once and passing that to each attention " + + "module instead." 
+ ) + attn_weight = attn_weight.masked_fill( + ~key_padding_mask.view((b, 1, 1, s_k)), min_val + ) if is_causal: s = max(s_q, s_k) causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) @@ -44,74 +79,146 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) attn_weight = torch.softmax(attn_weight, dim=-1) if dropout_p: - attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True) + attn_weight = torch.nn.functional.dropout( + attn_weight, p=dropout_p, training=training, inplace=True + ) out = attn_weight.matmul(v) - out = rearrange(out, 'b h s d -> b s (h d)') + out = rearrange(out, "b h s d -> b s (h d)") if needs_weights: return (out, attn_weight) return (out, None) + def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): for tensor in tensors: if tensor.dtype not in valid_dtypes: - raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.') + raise TypeError( + f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}." + ) if not tensor.is_cuda: - raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).') + raise TypeError( + f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})." + ) + -def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False): +def flash_attn_fn( + query, + key, + value, + n_heads, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, +): try: from flash_attn import bert_padding, flash_attn_interface except: - raise RuntimeError('Please install flash-attn==1.0.3.post0') + raise RuntimeError("Please install flash-attn==1.0.3.post0") check_valid_inputs(query, key, value) if attn_bias is not None: - raise NotImplementedError(f'attn_bias not implemented for flash attn.') + raise NotImplementedError(f"attn_bias not implemented for flash attn.") (batch_size, seqlen) = query.shape[:2] if key_padding_mask is None: key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) - query_padding_mask = key_padding_mask[:, -query.size(1):] - (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask) - query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) - (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask) - key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads) + query_padding_mask = key_padding_mask[:, -query.size(1) :] + (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( + query, query_padding_mask + ) + query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads) + (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( + key, key_padding_mask + ) + key_unpad = rearrange( + key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads + ) (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) - value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads) + value_unpad = rearrange( + value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads + ) if multiquery: key_unpad = key_unpad.expand(key_unpad.size(0), 
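Stripped of the padding mask, dropout, and multi-query handling, the `torch` attention path above (`scaled_multihead_dot_product_attention`) reduces to the following reference sketch (shapes invented):
```
import math
import torch

def simple_attention(q, k, v, attn_bias=None, is_causal=False):
    """q, k, v: (batch, heads, seq, head_dim); attn_bias broadcastable to the scores."""
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale        # (b, h, s_q, s_k)
    if attn_bias is not None:
        scores = scores + attn_bias                              # e.g. ALiBi / padding bias
    if is_causal:
        s_q, s_k = scores.shape[-2:]
        causal = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
        scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, v)

q = k = v = torch.randn(1, 4, 6, 16)
print(simple_attention(q, k, v, is_causal=True).shape)  # torch.Size([1, 4, 6, 16])
```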
n_heads, key_unpad.size(-1)) - value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1)) + value_unpad = value_unpad.expand( + value_unpad.size(0), n_heads, value_unpad.size(-1) + ) dropout_p = dropout_p if training else 0.0 reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) - output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) - output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen) + output_unpad = flash_attn_interface.flash_attn_unpadded_func( + query_unpad, + key_unpad, + value_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale=softmax_scale, + causal=reset_is_causal, + return_attn_probs=needs_weights, + ) + output = bert_padding.pad_input( + rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen + ) return (output, None) -def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False): + +def triton_flash_attn_fn( + query, + key, + value, + n_heads, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, +): try: from flash_attn import flash_attn_triton except: - raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202') + raise RuntimeError( + "Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202" + ) check_valid_inputs(query, key, value) if dropout_p: - raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.') + raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.") if needs_weights: - raise NotImplementedError(f'attn_impl: triton cannot return attn weights.') + raise NotImplementedError(f"attn_impl: triton cannot return attn weights.") if key_padding_mask is not None: - warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') + warnings.warn( + "Propagating key_padding_mask to the attention module " + + "and applying it within the attention module can cause " + + "unnecessary computation/memory usage. Consider integrating " + + "into attn_bias once and passing that to each attention " + + "module instead." 
+ ) (b_size, s_k) = key_padding_mask.shape[:2] if attn_bias is None: attn_bias = query.new_zeros(b_size, 1, 1, s_k) - attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min) - query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads) - key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads) - value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads) + attn_bias = attn_bias.masked_fill( + ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min + ) + query = rearrange(query, "b s (h d) -> b s h d", h=n_heads) + key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) + value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) if multiquery: key = key.expand(*key.shape[:2], n_heads, key.size(-1)) value = value.expand(*value.shape[:2], n_heads, value.size(-1)) reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) - attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale) + attn_output = flash_attn_triton.flash_attn_func( + query, key, value, attn_bias, reset_is_causal, softmax_scale + ) output = attn_output.view(*attn_output.shape[:2], -1) return (output, None) + class MultiheadAttention(nn.Module): """Multi-head self attention. @@ -119,7 +226,18 @@ class MultiheadAttention(nn.Module): additive bias. """ - def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None): + def __init__( + self, + d_model: int, + n_heads: int, + attn_impl: str = "triton", + clip_qkv: Optional[float] = None, + qk_ln: bool = False, + softmax_scale: Optional[float] = None, + attn_pdrop: float = 0.0, + low_precision_layernorm: bool = False, + device: Optional[str] = None, + ): super().__init__() self.attn_impl = attn_impl self.clip_qkv = clip_qkv @@ -137,21 +255,38 @@ class MultiheadAttention(nn.Module): layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm self.q_ln = layernorm_class(self.d_model, device=device) self.k_ln = layernorm_class(self.d_model, device=device) - if self.attn_impl == 'flash': + if self.attn_impl == "flash": self.attn_fn = flash_attn_fn - elif self.attn_impl == 'triton': + elif self.attn_impl == "triton": self.attn_fn = triton_flash_attn_fn - warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.') - elif self.attn_impl == 'torch': + warnings.warn( + "While `attn_impl: triton` can be faster than `attn_impl: flash` " + + "it uses more memory. When training larger models this can trigger " + + "alloc retries which hurts performance. If encountered, we recommend " + + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`." + ) + elif self.attn_impl == "torch": self.attn_fn = scaled_multihead_dot_product_attention if torch.cuda.is_available(): - warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.') + warnings.warn( + "Using `attn_impl: torch`. 
If your model does not use `alibi` or " + + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " + + "we recommend using `attn_impl: triton`." + ) else: - raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") self.out_proj = nn.Linear(self.d_model, self.d_model, device=device) self.out_proj._is_residual = True - def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False): + def forward( + self, + x, + past_key_value=None, + attn_bias=None, + attention_mask=None, + is_causal=True, + needs_weights=False, + ): qkv = self.Wqkv(x) if self.clip_qkv: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) @@ -167,10 +302,23 @@ class MultiheadAttention(nn.Module): value = torch.cat([past_key_value[1], value], dim=1) past_key_value = (key, value) if attn_bias is not None: - attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):] - (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights) + attn_bias = attn_bias[:, :, -query.size(1) :, -key.size(1) :] + (context, attn_weights) = self.attn_fn( + query, + key, + value, + self.n_heads, + softmax_scale=self.softmax_scale, + attn_bias=attn_bias, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + dropout_p=self.attn_dropout_p, + training=self.training, + needs_weights=needs_weights, + ) return (self.out_proj(context), attn_weights, past_key_value) + class MultiQueryAttention(nn.Module): """Multi-Query self attention. @@ -178,7 +326,18 @@ class MultiQueryAttention(nn.Module): additive bias. """ - def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None): + def __init__( + self, + d_model: int, + n_heads: int, + attn_impl: str = "triton", + clip_qkv: Optional[float] = None, + qk_ln: bool = False, + softmax_scale: Optional[float] = None, + attn_pdrop: float = 0.0, + low_precision_layernorm: bool = False, + device: Optional[str] = None, + ): super().__init__() self.attn_impl = attn_impl self.clip_qkv = clip_qkv @@ -197,25 +356,44 @@ class MultiQueryAttention(nn.Module): layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm self.q_ln = layernorm_class(d_model, device=device) self.k_ln = layernorm_class(self.head_dim, device=device) - if self.attn_impl == 'flash': + if self.attn_impl == "flash": self.attn_fn = flash_attn_fn - elif self.attn_impl == 'triton': + elif self.attn_impl == "triton": self.attn_fn = triton_flash_attn_fn - warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.') - elif self.attn_impl == 'torch': + warnings.warn( + "While `attn_impl: triton` can be faster than `attn_impl: flash` " + + "it uses more memory. When training larger models this can trigger " + + "alloc retries which hurts performance. If encountered, we recommend " + + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`." 
+ ) + elif self.attn_impl == "torch": self.attn_fn = scaled_multihead_dot_product_attention if torch.cuda.is_available(): - warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.') + warnings.warn( + "Using `attn_impl: torch`. If your model does not use `alibi` or " + + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " + + "we recommend using `attn_impl: triton`." + ) else: - raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") self.out_proj = nn.Linear(self.d_model, self.d_model, device=device) self.out_proj._is_residual = True - def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False): + def forward( + self, + x, + past_key_value=None, + attn_bias=None, + attention_mask=None, + is_causal=True, + needs_weights=False, + ): qkv = self.Wqkv(x) if self.clip_qkv: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2) + (query, key, value) = qkv.split( + [self.d_model, self.head_dim, self.head_dim], dim=2 + ) key_padding_mask = attention_mask if self.qk_ln: dtype = query.dtype @@ -227,14 +405,30 @@ class MultiQueryAttention(nn.Module): value = torch.cat([past_key_value[1], value], dim=1) past_key_value = (key, value) if attn_bias is not None: - attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):] - (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True) + attn_bias = attn_bias[:, :, -query.size(1) :, -key.size(1) :] + (context, attn_weights) = self.attn_fn( + query, + key, + value, + self.n_heads, + softmax_scale=self.softmax_scale, + attn_bias=attn_bias, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + dropout_p=self.attn_dropout_p, + training=self.training, + needs_weights=needs_weights, + multiquery=True, + ) return (self.out_proj(context), attn_weights, past_key_value) -def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id): - if attn_impl == 'flash': + +def attn_bias_shape( + attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id +): + if attn_impl == "flash": return None - elif attn_impl in ['torch', 'triton']: + elif attn_impl in ["torch", "triton"]: if alibi: if (prefix_lm or not causal) or use_sequence_id: return (1, n_heads, seq_len, seq_len) @@ -243,18 +437,31 @@ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_s return (1, 1, seq_len, seq_len) return None else: - raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") -def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8): - if attn_impl == 'flash': + +def build_attn_bias( + attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8 +): + if attn_impl == "flash": return None - elif attn_impl in ['torch', 'triton']: + elif attn_impl in ["torch", "triton"]: if alibi: (device, dtype) = (attn_bias.device, attn_bias.dtype) - attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, 
full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype)) + attn_bias = attn_bias.add( + build_alibi_bias( + n_heads, + seq_len, + full=not causal, + alibi_bias_max=alibi_bias_max, + device=device, + dtype=dtype, + ) + ) return attn_bias else: - raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") + def gen_slopes(n_heads, alibi_bias_max=8, device=None): _n_heads = 2 ** math.ceil(math.log2(n_heads)) @@ -265,12 +472,24 @@ def gen_slopes(n_heads, alibi_bias_max=8, device=None): slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] return slopes.view(1, n_heads, 1, 1) -def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None): - alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len) + +def build_alibi_bias( + n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None +): + alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view( + 1, 1, 1, seq_len + ) if full: - alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1) + alibi_bias = alibi_bias - torch.arange( + 1 - seq_len, 1, dtype=torch.int32, device=device + ).view(1, 1, seq_len, 1) alibi_bias = alibi_bias.abs().mul(-1) slopes = gen_slopes(n_heads, alibi_bias_max, device=device) alibi_bias = alibi_bias * slopes return alibi_bias.to(dtype=dtype) -ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention} \ No newline at end of file + + +ATTN_CLASS_REGISTRY = { + "multihead_attention": MultiheadAttention, + "multiquery_attention": MultiQueryAttention, +} diff --git a/model/llava/model/mpt/blocks.py b/model/llava/model/mpt/blocks.py index 04493aa4c03ef1b14ec539c9af8e9c38e8befc8b..1511a225455aaf0a5134cf6d275993e7de57b0e1 100644 --- a/model/llava/model/mpt/blocks.py +++ b/model/llava/model/mpt/blocks.py @@ -1,41 +1,90 @@ """GPT Blocks used for the GPT Model.""" from typing import Dict, Optional, Tuple + import torch import torch.nn as nn + from .attention import ATTN_CLASS_REGISTRY from .norm import NORM_CLASS_REGISTRY -class MPTMLP(nn.Module): - def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): +class MPTMLP(nn.Module): + def __init__( + self, d_model: int, expansion_ratio: int, device: Optional[str] = None + ): super().__init__() self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) - self.act = nn.GELU(approximate='none') + self.act = nn.GELU(approximate="none") self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) self.down_proj._is_residual = True def forward(self, x): return self.down_proj(self.act(self.up_proj(x))) -class MPTBlock(nn.Module): - def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs): +class MPTBlock(nn.Module): + def __init__( + self, + d_model: int, + n_heads: int, + expansion_ratio: int, + attn_config: Dict = { + "attn_type": "multihead_attention", + "attn_pdrop": 0.0, + "attn_impl": "triton", + "qk_ln": False, + "clip_qkv": None, + "softmax_scale": None, + 
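`gen_slopes` and `build_alibi_bias` above implement ALiBi: rather than positional embeddings, each head adds a linear penalty proportional to the distance of the attended key. A compact sketch of the resulting bias (restricted to power-of-two head counts so the slope computation stays one line):
```
import torch

def alibi_bias(n_heads, seq_len, alibi_bias_max=8):
    # geometric per-head slopes, as in gen_slopes (power-of-two n_heads assumed)
    exponents = torch.arange(1, n_heads + 1) * (alibi_bias_max / n_heads)
    slopes = (1.0 / 2 ** exponents).view(1, n_heads, 1, 1)
    # distance of each key from the current (last) position, as a negative penalty
    distance = torch.arange(1 - seq_len, 1).view(1, 1, 1, seq_len).abs().mul(-1)
    return distance * slopes          # (1, n_heads, 1, seq_len), added to the scores

bias = alibi_bias(n_heads=4, seq_len=6)
print(bias[0, 0, 0])  # head 0: zero at the current token, growing penalty with distance
```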
"prefix_lm": False, + "attn_uses_sequence_id": False, + "alibi": False, + "alibi_bias_max": 8, + }, + resid_pdrop: float = 0.0, + norm_type: str = "low_precision_layernorm", + device: Optional[str] = None, + **kwargs + ): del kwargs super().__init__() norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] - attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] + attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]] self.norm_1 = norm_class(d_model, device=device) - self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device) + self.attn = attn_class( + attn_impl=attn_config["attn_impl"], + clip_qkv=attn_config["clip_qkv"], + qk_ln=attn_config["qk_ln"], + softmax_scale=attn_config["softmax_scale"], + attn_pdrop=attn_config["attn_pdrop"], + d_model=d_model, + n_heads=n_heads, + device=device, + ) self.norm_2 = norm_class(d_model, device=device) - self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) + self.ffn = MPTMLP( + d_model=d_model, expansion_ratio=expansion_ratio, device=device + ) self.resid_attn_dropout = nn.Dropout(resid_pdrop) self.resid_ffn_dropout = nn.Dropout(resid_pdrop) - def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + def forward( + self, + x: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_bias: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.ByteTensor] = None, + is_causal: bool = True, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: a = self.norm_1(x) - (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) + (b, _, past_key_value) = self.attn( + a, + past_key_value=past_key_value, + attn_bias=attn_bias, + attention_mask=attention_mask, + is_causal=is_causal, + ) x = x + self.resid_attn_dropout(b) m = self.norm_2(x) n = self.ffn(m) x = x + self.resid_ffn_dropout(n) - return (x, past_key_value) \ No newline at end of file + return (x, past_key_value) diff --git a/model/llava/model/mpt/configuration_mpt.py b/model/llava/model/mpt/configuration_mpt.py index 35d1269cd4b599799d6df7953a8d0c30b33d1e65..f5b96e2a41b16a372b5050769a8c897816ada529 100644 --- a/model/llava/model/mpt/configuration_mpt.py +++ b/model/llava/model/mpt/configuration_mpt.py @@ -1,13 +1,52 @@ """A HuggingFace-style model configuration.""" from typing import Dict, Optional, Union + from transformers import PretrainedConfig -attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8} -init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'} + +attn_config_defaults: Dict = { + "attn_type": "multihead_attention", + "attn_pdrop": 0.0, + "attn_impl": "triton", + "qk_ln": False, + "clip_qkv": None, + "softmax_scale": None, + "prefix_lm": False, + "attn_uses_sequence_id": False, + "alibi": False, + "alibi_bias_max": 8, +} +init_config_defaults: Dict = { + "name": "kaiming_normal_", + "fan_mode": "fan_in", + 
"init_nonlinearity": "relu", +} + class MPTConfig(PretrainedConfig): - model_type = 'mpt' + model_type = "mpt" - def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs): + def __init__( + self, + d_model: int = 2048, + n_heads: int = 16, + n_layers: int = 24, + expansion_ratio: int = 4, + max_seq_len: int = 2048, + vocab_size: int = 50368, + resid_pdrop: float = 0.0, + emb_pdrop: float = 0.0, + learned_pos_emb: bool = True, + attn_config: Dict = attn_config_defaults, + init_device: str = "cpu", + logit_scale: Optional[Union[float, str]] = None, + no_bias: bool = False, + verbose: int = 0, + embedding_fraction: float = 1.0, + norm_type: str = "low_precision_layernorm", + use_cache: bool = False, + init_config: Dict = init_config_defaults, + **kwargs, + ): """The MPT configuration class. Args: @@ -80,39 +119,76 @@ class MPTConfig(PretrainedConfig): self.norm_type = norm_type self.use_cache = use_cache self.init_config = init_config - if 'name' in kwargs: - del kwargs['name'] - if 'loss_fn' in kwargs: - del kwargs['loss_fn'] + if "name" in kwargs: + del kwargs["name"] + if "loss_fn" in kwargs: + del kwargs["loss_fn"] super().__init__(**kwargs) self._validate_config() def _set_config_defaults(self, config, config_defaults): - for (k, v) in config_defaults.items(): + for k, v in config_defaults.items(): if k not in config: config[k] = v return config def _validate_config(self): - self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults) - self.init_config = self._set_config_defaults(self.init_config, init_config_defaults) + self.attn_config = self._set_config_defaults( + self.attn_config, attn_config_defaults + ) + self.init_config = self._set_config_defaults( + self.init_config, init_config_defaults + ) if self.d_model % self.n_heads != 0: - raise ValueError('d_model must be divisible by n_heads') - if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])): - raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1") - if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: + raise ValueError("d_model must be divisible by n_heads") + if any( + ( + prob < 0 or prob > 1 + for prob in [ + self.attn_config["attn_pdrop"], + self.resid_pdrop, + self.emb_pdrop, + ] + ) + ): + raise ValueError( + "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1" + ) + if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]: raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") - if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: - raise NotImplementedError('prefix_lm only implemented with torch and triton attention.') - if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: - raise NotImplementedError('alibi only implemented with torch and triton attention.') - if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 
'triton']: - raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.') + if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in [ + "torch", + "triton", + ]: + raise NotImplementedError( + "prefix_lm only implemented with torch and triton attention." + ) + if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in [ + "torch", + "triton", + ]: + raise NotImplementedError( + "alibi only implemented with torch and triton attention." + ) + if self.attn_config["attn_uses_sequence_id"] and self.attn_config[ + "attn_impl" + ] not in ["torch", "triton"]: + raise NotImplementedError( + "attn_uses_sequence_id only implemented with torch and triton attention." + ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: - raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!') - if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model': - raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") - if self.init_config.get('name', None) is None: - raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.") - if not self.learned_pos_emb and (not self.attn_config['alibi']): - raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.') \ No newline at end of file + raise ValueError( + "model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!" + ) + if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model": + raise ValueError( + f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." + ) + if self.init_config.get("name", None) is None: + raise ValueError( + f"self.init_config={self.init_config!r} 'name' needs to be set." + ) + if not self.learned_pos_emb and (not self.attn_config["alibi"]): + raise ValueError( + f"Positional information must be provided to the model using either learned_pos_emb or alibi." 
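In practice the validation above runs at construction time, so inconsistent settings fail fast. A short usage sketch (the import path assumes this repository layout):
```
from model.llava.model.mpt.configuration_mpt import MPTConfig

cfg = MPTConfig(d_model=2048, n_heads=16, n_layers=24, max_seq_len=2048)
print(cfg.attn_config["attn_impl"])       # 'triton' by default

try:
    MPTConfig(d_model=2050, n_heads=16)   # 2050 is not divisible by 16
except ValueError as err:
    print(err)                            # d_model must be divisible by n_heads
```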
+ ) diff --git a/model/llava/model/mpt/hf_prefixlm_converter.py b/model/llava/model/mpt/hf_prefixlm_converter.py index 8c1a6487202a6400a7116a6bd68b493892ef0d14..427d3878185431f3e657d1a93c5db5a55f04300f 100644 --- a/model/llava/model/mpt/hf_prefixlm_converter.py +++ b/model/llava/model/mpt/hf_prefixlm_converter.py @@ -10,21 +10,37 @@ import math import warnings from types import MethodType from typing import Any, Dict, List, Optional, Tuple, Union + import torch -from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss -from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom -from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom +from transformers.models.bloom.modeling_bloom import ( + BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, + CausalLMOutputWithCrossAttentions, CrossEntropyLoss) +from transformers.models.bloom.modeling_bloom import \ + _expand_mask as _expand_mask_bloom +from transformers.models.bloom.modeling_bloom import \ + _make_causal_mask as _make_causal_mask_bloom from transformers.models.bloom.modeling_bloom import logging from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM from transformers.models.gptj.modeling_gptj import GPTJForCausalLM from transformers.models.opt.modeling_opt import OPTForCausalLM -from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt -from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt +from transformers.models.opt.modeling_opt import \ + _expand_mask as _expand_mask_opt +from transformers.models.opt.modeling_opt import \ + _make_causal_mask as _make_causal_mask_opt + logger = logging.get_logger(__name__) -_SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM) -CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM] +_SUPPORTED_GPT_MODELS = ( + GPT2LMHeadModel, + GPTJForCausalLM, + GPTNeoForCausalLM, + GPTNeoXForCausalLM, +) +CAUSAL_GPT_TYPES = Union[ + GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM +] + def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES: """Converts a GPT-style Causal LM to a Prefix LM. @@ -37,10 +53,12 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T See `convert_hf_causal_lm_to_prefix_lm` for more details. """ - if hasattr(model, '_prefix_lm_converted'): + if hasattr(model, "_prefix_lm_converted"): return model assert isinstance(model, _SUPPORTED_GPT_MODELS) - assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models' + assert ( + model.config.add_cross_attention == False + ), "Only supports GPT-style decoder-only models" def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]: """Helper that gets a list of the model's attention modules. 
@@ -56,7 +74,7 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T blocks = model.transformer.h for block in blocks: if isinstance(model, GPTNeoForCausalLM): - if block.attn.attention_type != 'global': + if block.attn.attention_type != "global": continue attn_module = block.attn.attention elif isinstance(model, GPTNeoXForCausalLM): @@ -65,17 +83,58 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T attn_module = block.attn attn_modules.append(attn_module) return attn_modules - setattr(model, '_original_forward', getattr(model, 'forward')) - setattr(model, '_original_generate', getattr(model, 'generate')) - def forward(self: CAUSAL_GPT_TYPES, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]=None, attention_mask: Optional[torch.FloatTensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, token_type_ids: Optional[torch.LongTensor]=None, position_ids: Optional[torch.LongTensor]=None, head_mask: Optional[torch.FloatTensor]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None): + setattr(model, "_original_forward", getattr(model, "forward")) + setattr(model, "_original_generate", getattr(model, "generate")) + + def forward( + self: CAUSAL_GPT_TYPES, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + bidirectional_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): """Wraps original forward to enable PrefixLM attention.""" def call_og_forward(): if isinstance(self, GPTNeoXForCausalLM): - return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict) + return self._original_forward( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) else: - return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict) + return self._original_forward( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + if bidirectional_mask is None: return call_og_forward() assert isinstance(bidirectional_mask, torch.Tensor) @@ -83,14 +142,23 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T (b, s) = bidirectional_mask.shape max_length = attn_modules[0].bias.shape[-1] if s > max_length: - raise ValueError(f'bidirectional_mask sequence length (={s}) exceeds the ' + f'max length allowed by the model ({max_length}).') + raise ValueError( + f"bidirectional_mask sequence length (={s}) exceeds the " + + f"max length allowed by the model ({max_length})." + ) assert s <= max_length if s < max_length: - pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device) + pad = torch.zeros( + (int(b), int(max_length - s)), + dtype=bidirectional_mask.dtype, + device=bidirectional_mask.device, + ) bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1) bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1) for attn_module in attn_modules: - attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional) + attn_module.bias.data = torch.logical_or( + attn_module.bias.data, bidirectional + ) output = call_og_forward() for attn_module in attn_modules: attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None] @@ -105,11 +173,13 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T for attn_module in attn_modules: attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None] return output - setattr(model, 'forward', MethodType(forward, model)) - setattr(model, 'generate', MethodType(generate, model)) - setattr(model, '_prefix_lm_converted', True) + + setattr(model, "forward", MethodType(forward, model)) + setattr(model, "generate", MethodType(generate, model)) + setattr(model, "_prefix_lm_converted", True) return model + def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM: """Converts a BLOOM Causal LM to a Prefix LM. @@ -118,62 +188,137 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa See `convert_hf_causal_lm_to_prefix_lm` for more details. 
""" - if hasattr(model, '_prefix_lm_converted'): + if hasattr(model, "_prefix_lm_converted"): return model assert isinstance(model, BloomForCausalLM) - assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models' - - def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor: + assert ( + model.config.add_cross_attention == False + ), "Only supports BLOOM decoder-only models" + + def _prepare_attn_mask( + self: BloomModel, + attention_mask: torch.Tensor, + bidirectional_mask: Optional[torch.Tensor], + input_shape: Tuple[int, int], + past_key_values_length: int, + ) -> torch.BoolTensor: combined_attention_mask = None device = attention_mask.device (_, src_length) = input_shape if src_length > 1: - combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length) + combined_attention_mask = _make_causal_mask_bloom( + input_shape, + device=device, + past_key_values_length=past_key_values_length, + ) if bidirectional_mask is not None: assert attention_mask.shape == bidirectional_mask.shape - expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length) - combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask) + expanded_bidirectional_mask = _expand_mask_bloom( + bidirectional_mask, tgt_length=src_length + ) + combined_attention_mask = torch.logical_and( + combined_attention_mask, expanded_bidirectional_mask + ) expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length) - combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask | combined_attention_mask + ) return combined_attention_mask - def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + def _build_alibi_tensor( + self: BloomModel, + batch_size: int, + query_length: int, + key_length: int, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: num_heads = self.config.n_head closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) - base = torch.tensor(2 ** (-2 ** (-(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32) - powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32) + base = torch.tensor( + 2 ** (-(2 ** (-(math.log2(closest_power_of_2) - 3)))), + device=device, + dtype=torch.float32, + ) + powers = torch.arange( + 1, 1 + closest_power_of_2, device=device, dtype=torch.int32 + ) slopes = torch.pow(base, powers) if closest_power_of_2 != num_heads: - extra_base = torch.tensor(2 ** (-2 ** (-(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32) - num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) - extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32) + extra_base = torch.tensor( + 2 ** (-(2 ** (-(math.log2(2 * closest_power_of_2) - 3)))), + device=device, + dtype=torch.float32, + ) + num_remaining_heads = min( + closest_power_of_2, num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + 1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32 + ) slopes = 
torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1) ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1) diffs = qa - ka + key_length - query_length diffs = -diffs.abs() - alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length) - alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length) + alibi = slopes.view(1, num_heads, 1, 1) * diffs.view( + 1, 1, query_length, key_length + ) + alibi = alibi.expand(batch_size, -1, -1, -1).reshape( + -1, query_length, key_length + ) return alibi.to(dtype) + KeyValueT = Tuple[torch.Tensor, torch.Tensor] - def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: - if deprecated_arguments.pop('position_ids', False) is not False: - warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning) + def forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[KeyValueT, ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + bidirectional_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + if deprecated_arguments.pop("position_ids", False) is not False: + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. 
" + + "You can safely ignore passing `position_ids`.", + FutureWarning, + ) if len(deprecated_arguments) > 0: - raise ValueError(f'Got unexpected arguments: {deprecated_arguments}') - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) if input_ids is not None and inputs_embeds is not None: - raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time') + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) elif input_ids is not None: (batch_size, seq_length) = input_ids.shape elif inputs_embeds is not None: (batch_size, seq_length, _) = inputs_embeds.shape else: - raise ValueError('You have to specify either input_ids or inputs_embeds') + raise ValueError("You have to specify either input_ids or inputs_embeds") if past_key_values is None: past_key_values = tuple([None] * len(self.h)) head_mask = self.get_head_mask(head_mask, self.config.n_layer) @@ -190,28 +335,62 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa past_key_values_length = tmp.shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + attention_mask = torch.ones( + (batch_size, seq_length_with_past), device=hidden_states.device + ) else: attention_mask = attention_mask.to(hidden_states.device) - alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device) - causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length) - for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)): + alibi = self._build_alibi_tensor( + batch_size=batch_size, + query_length=seq_length, + key_length=seq_length_with_past, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + causal_mask = self._prepare_attn_mask( + attention_mask, + bidirectional_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: hst = (hidden_states,) all_hidden_states = all_hidden_states + hst if self.gradient_checkpointing and self.training: if use_cache: - logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...') + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) use_cache = False def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + return module( + *inputs, + use_cache=use_cache, + output_attentions=output_attentions, + ) + return custom_forward - outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i]) + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + head_mask[i], + ) else: - outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi) + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) hidden_states = outputs[0] if use_cache is True: presents = presents + (outputs[1],) @@ -223,21 +402,77 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa hst = (hidden_states,) all_hidden_states = all_hidden_states + hst if not return_dict: - return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)) - return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions) - setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer)) - setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer)) - setattr(model.transformer, 'forward', MethodType(forward, model.transformer)) + return tuple( + ( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + setattr( + model.transformer, + "_prepare_attn_mask", + MethodType(_prepare_attn_mask, model.transformer), + ) + setattr( + model.transformer, + "_build_alibi_tensor", + MethodType(_build_alibi_tensor, model.transformer), + ) + setattr(model.transformer, "forward", MethodType(forward, model.transformer)) KeyValueT = Tuple[torch.Tensor, torch.Tensor] - def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + def forward( + self: BloomForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[KeyValueT, ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + bidirectional_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: 
Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: """Replacement forward method for BloomCausalLM.""" - if deprecated_arguments.pop('position_ids', False) is not False: - warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning) + if deprecated_arguments.pop("position_ids", False) is not False: + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed " + + "in v5.0.0. You can safely ignore passing `position_ids`.", + FutureWarning, + ) if len(deprecated_arguments) > 0: - raise ValueError(f'Got unexpected arguments: {deprecated_arguments}') - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict) + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + bidirectional_mask=bidirectional_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) loss = None @@ -246,13 +481,28 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa shift_labels = labels[..., 1:].contiguous() (batch_size, seq_length, vocab_size) = shift_logits.shape loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)) + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), + shift_labels.view(batch_size * seq_length), + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return (loss,) + output if loss is not None else output - return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions) - - def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs) -> dict: + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self: BloomForCausalLM, + input_ids: torch.LongTensor, + past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> dict: if past: input_ids = input_ids[:, -1].unsqueeze(-1) bidirectional_mask = None @@ -260,12 +510,24 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa past = self._convert_to_bloom_cache(past) else: bidirectional_mask = torch.ones_like(input_ids) - 
return {'input_ids': input_ids, 'past_key_values': past, 'use_cache': True, 'attention_mask': attention_mask, 'bidirectional_mask': bidirectional_mask} - setattr(model, 'forward', MethodType(forward, model)) - setattr(model, 'prepare_inputs_for_generation', MethodType(prepare_inputs_for_generation, model)) - setattr(model, '_prefix_lm_converted', True) + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": True, + "attention_mask": attention_mask, + "bidirectional_mask": bidirectional_mask, + } + + setattr(model, "forward", MethodType(forward, model)) + setattr( + model, + "prepare_inputs_for_generation", + MethodType(prepare_inputs_for_generation, model), + ) + setattr(model, "_prefix_lm_converted", True) return model + def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM: """Converts an OPT Causal LM to a Prefix LM. @@ -274,36 +536,89 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM See `convert_hf_causal_lm_to_prefix_lm` for more details. """ - if hasattr(model, '_prefix_lm_converted'): + if hasattr(model, "_prefix_lm_converted"): return model assert isinstance(model, OPTForCausalLM) - assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models' - setattr(model, '_original_forward', getattr(model, 'forward')) - setattr(model, '_original_generate', getattr(model, 'generate')) + assert ( + model.config.add_cross_attention == False + ), "Only supports OPT decoder-only models" + setattr(model, "_original_forward", getattr(model, "forward")) + setattr(model, "_original_generate", getattr(model, "generate")) model.model.decoder.bidirectional_mask = None - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): combined_attention_mask = None if input_shape[-1] > 1: - if self.bidirectional_mask == 'g': + if self.bidirectional_mask == "g": (bsz, src_length) = input_shape - combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device) + combined_attention_mask = torch.zeros( + (bsz, 1, src_length, src_length + past_key_values_length), + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) else: - combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device) + combined_attention_mask = _make_causal_mask_opt( + input_shape, + inputs_embeds.dtype, + past_key_values_length=past_key_values_length, + ).to(inputs_embeds.device) if self.bidirectional_mask is not None: assert attention_mask.shape == self.bidirectional_mask.shape - expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask) + expanded_bidirectional_mask = _expand_mask_opt( + self.bidirectional_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ).to(inputs_embeds.device) + combined_attention_mask = torch.maximum( + expanded_bidirectional_mask, combined_attention_mask + ) if attention_mask is not None: - expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = expanded_attn_mask if 
combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + expanded_attn_mask = _expand_mask_opt( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) return combined_attention_mask - setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder)) - - def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None): + setattr( + model.model.decoder, + "_prepare_decoder_attention_mask", + MethodType(_prepare_decoder_attention_mask, model.model.decoder), + ) + + def forward( + self: OPTForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + bidirectional_mask: Optional[torch.ByteTensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): def call_og_forward(): - return self._original_forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict) + return self._original_forward( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + if bidirectional_mask is None: return call_og_forward() self.model.decoder.bidirectional_mask = bidirectional_mask @@ -317,7 +632,7 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]): """Wraps original generate to enable PrefixLM-style attention.""" - self.model.decoder.bidirectional_mask = 'g' + self.model.decoder.bidirectional_mask = "g" try: output = self._original_generate(*args, **kwargs) except: @@ -325,12 +640,23 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM raise self.model.decoder.bidirectional_mask = None return output - setattr(model, 'forward', MethodType(forward, model)) - setattr(model, 'generate', MethodType(generate, model)) - setattr(model, '_prefix_lm_converted', True) + + setattr(model, "forward", MethodType(forward, model)) + setattr(model, "generate", MethodType(generate, model)) + setattr(model, "_prefix_lm_converted", True) return model + + _SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM) -CAUSAL_LM_TYPES = 
Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM] +CAUSAL_LM_TYPES = Union[ + GPT2LMHeadModel, + GPTJForCausalLM, + GPTNeoForCausalLM, + GPTNeoXForCausalLM, + BloomForCausalLM, + OPTForCausalLM, +] + def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES: """Converts a HuggingFace Causal LM to a Prefix LM. @@ -396,7 +722,12 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES elif isinstance(model, OPTForCausalLM): return _convert_opt_causal_lm_to_prefix_lm(model) else: - raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}') + raise TypeError( + f"Cannot convert model to Prefix LM. " + + f"Model does not belong to set of supported HF models:" + + f"\n{_SUPPORTED_HF_MODELS}" + ) + def add_bidirectional_mask_if_missing(batch: Dict[str, Any]): """Attempts to add bidirectional_mask to batch if missing. @@ -404,12 +735,16 @@ def add_bidirectional_mask_if_missing(batch: Dict[str, Any]): Raises: KeyError if bidirectional_mask is missing and can't be inferred """ - if 'bidirectional_mask' not in batch: - if batch.get('mode', None) == 'icl_task': - batch['bidirectional_mask'] = batch['attention_mask'].clone() - for (i, continuation_indices) in enumerate(batch['continuation_indices']): - batch['bidirectional_mask'][i, continuation_indices] = 0 - elif 'labels' in batch and 'attention_mask' in batch: - batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], -100)).type_as(batch['attention_mask']) + if "bidirectional_mask" not in batch: + if batch.get("mode", None) == "icl_task": + batch["bidirectional_mask"] = batch["attention_mask"].clone() + for i, continuation_indices in enumerate(batch["continuation_indices"]): + batch["bidirectional_mask"][i, continuation_indices] = 0 + elif "labels" in batch and "attention_mask" in batch: + batch["bidirectional_mask"] = torch.logical_and( + torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100) + ).type_as(batch["attention_mask"]) else: - raise KeyError('No bidirectional_mask in batch and not sure how to construct one.') \ No newline at end of file + raise KeyError( + "No bidirectional_mask in batch and not sure how to construct one." + ) diff --git a/model/llava/model/mpt/meta_init_context.py b/model/llava/model/mpt/meta_init_context.py index 6cba6fff0fe21fe222c7ab38eae44a9784c0be9c..208ab255cedb65e5c444b1c5fa5abf72cbdb1512 100644 --- a/model/llava/model/mpt/meta_init_context.py +++ b/model/llava/model/mpt/meta_init_context.py @@ -1,9 +1,11 @@ from contextlib import contextmanager + import torch import torch.nn as nn + @contextmanager -def init_empty_weights(include_buffers: bool=False): +def init_empty_weights(include_buffers: bool = False): """Meta initialization context manager. A context manager under which models are initialized with all parameters @@ -30,11 +32,12 @@ def init_empty_weights(include_buffers: bool=False): """ - with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: + with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f: yield f + @contextmanager -def init_on_device(device: torch.device, include_buffers: bool=False): +def init_on_device(device: torch.device, include_buffers: bool = False): """Device initialization context manager. 
A context manager under which models are initialized with all parameters @@ -62,33 +65,47 @@ def init_on_device(device: torch.device, include_buffers: bool=False): if param is not None: param_cls = type(module._parameters[name]) kwargs = module._parameters[name].__dict__ - module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + module._parameters[name] = param_cls( + module._parameters[name].to(device), **kwargs + ) def register_empty_buffer(module, name, buffer): old_register_buffer(module, name, buffer) if buffer is not None: module._buffers[name] = module._buffers[name].to(device) + if include_buffers: - tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} + tensor_constructors_to_patch = { + torch_function_name: getattr(torch, torch_function_name) + for torch_function_name in ["empty", "zeros", "ones", "full"] + } else: tensor_constructors_to_patch = {} def patch_tensor_constructor(fn): - def wrapper(*args, **kwargs): - kwargs['device'] = device + kwargs["device"] = device return fn(*args, **kwargs) + return wrapper + try: nn.Module.register_parameter = register_empty_parameter if include_buffers: nn.Module.register_buffer = register_empty_buffer for torch_function_name in tensor_constructors_to_patch.keys(): - setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) + setattr( + torch, + torch_function_name, + patch_tensor_constructor(getattr(torch, torch_function_name)), + ) yield finally: nn.Module.register_parameter = old_register_parameter if include_buffers: nn.Module.register_buffer = old_register_buffer - for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): - setattr(torch, torch_function_name, old_torch_function) \ No newline at end of file + for ( + torch_function_name, + old_torch_function, + ) in tensor_constructors_to_patch.items(): + setattr(torch, torch_function_name, old_torch_function) diff --git a/model/llava/model/mpt/modeling_mpt.py b/model/llava/model/mpt/modeling_mpt.py index 5c3144a9872b7cf8df3bcab58e2f12ecc292d5c0..070c151e292f0a360bc468113602fcab1f8e594a 100644 --- a/model/llava/model/mpt/modeling_mpt.py +++ b/model/llava/model/mpt/modeling_mpt.py @@ -5,68 +5,95 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py import math import warnings from typing import List, Optional, Tuple, Union + import torch import torch.nn as nn import torch.nn.functional as F -from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers import (PreTrainedModel, PreTrainedTokenizer, + PreTrainedTokenizerFast) +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) + +from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising from .attention import attn_bias_shape, build_attn_bias from .blocks import MPTBlock -from .norm import NORM_CLASS_REGISTRY from .configuration_mpt import MPTConfig -from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising -from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm +from .hf_prefixlm_converter import (add_bidirectional_mask_if_missing, + convert_hf_causal_lm_to_prefix_lm) from .meta_init_context import init_empty_weights +from .norm import NORM_CLASS_REGISTRY 
from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_ + Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] from transformers.utils import logging + logger = logging.get_logger(__name__) + class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig - base_model_prefix = 'model' + base_model_prefix = "model" -class MPTModel(MPTPreTrainedModel): +class MPTModel(MPTPreTrainedModel): def __init__(self, config: MPTConfig): config._validate_config() super().__init__(config) - self.attn_impl = config.attn_config['attn_impl'] - self.prefix_lm = config.attn_config['prefix_lm'] - self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id'] - self.alibi = config.attn_config['alibi'] - self.alibi_bias_max = config.attn_config['alibi_bias_max'] + self.attn_impl = config.attn_config["attn_impl"] + self.prefix_lm = config.attn_config["prefix_lm"] + self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"] + self.alibi = config.attn_config["alibi"] + self.alibi_bias_max = config.attn_config["alibi_bias_max"] if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): - norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys()) - raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).') + norm_options = " | ".join(NORM_CLASS_REGISTRY.keys()) + raise NotImplementedError( + f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})." + ) norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()] self.embedding_fraction = config.embedding_fraction - self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device) + self.wte = nn.Embedding( + config.vocab_size, config.d_model, device=config.init_device + ) if not self.alibi: - self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device) + self.wpe = nn.Embedding( + config.max_seq_len, config.d_model, device=config.init_device + ) self.emb_drop = nn.Dropout(config.emb_pdrop) - self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)]) + self.blocks = nn.ModuleList( + [ + MPTBlock(device=config.init_device, **config.to_dict()) + for _ in range(config.n_layers) + ] + ) self.norm_f = norm_class(config.d_model, device=config.init_device) - if config.init_device != 'meta': + if config.init_device != "meta": self.apply(self.param_init_fn) self.is_causal = not self.prefix_lm self._attn_bias_initialized = False self.attn_bias = None - self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id) + self.attn_bias_shape = attn_bias_shape( + self.attn_impl, + config.n_heads, + config.max_seq_len, + self.alibi, + prefix_lm=self.prefix_lm, + causal=self.is_causal, + use_sequence_id=self.attn_uses_sequence_id, + ) if config.no_bias: for module in self.modules(): - if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter): + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): if config.verbose: - warnings.warn(f'Removing bias ({module.bias}) from {module}.') - module.register_parameter('bias', None) + warnings.warn(f"Removing bias ({module.bias}) from {module}.") + module.register_parameter("bias", None) if config.verbose and config.verbose > 2: print(self) - if 'verbose' not in self.config.init_config: - 
self.config.init_config['verbose'] = self.config.verbose
-        if self.config.init_config['verbose'] > 1:
-            init_fn_name = self.config.init_config['name']
-            warnings.warn(f'Using {init_fn_name} initialization.')
+        if "verbose" not in self.config.init_config:
+            self.config.init_config["verbose"] = self.config.verbose
+        if self.config.init_config["verbose"] > 1:
+            init_fn_name = self.config.init_config["name"]
+            warnings.warn(f"Using {init_fn_name} initialization.")
         self.gradient_checkpointing = False

     def get_input_embeddings(self):
@@ -76,13 +103,30 @@ class MPTModel(MPTPreTrainedModel):
         self.wte = value

     @torch.no_grad()
-    def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None):
+    def _attn_bias(
+        self,
+        device,
+        dtype,
+        attention_mask: Optional[torch.ByteTensor] = None,
+        prefix_mask: Optional[torch.ByteTensor] = None,
+        sequence_id: Optional[torch.LongTensor] = None,
+    ):
         if not self._attn_bias_initialized:
             if self.attn_bias_shape:
-                self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
-                self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max)
+                self.attn_bias = torch.zeros(
+                    self.attn_bias_shape, device=device, dtype=dtype
+                )
+                self.attn_bias = build_attn_bias(
+                    self.attn_impl,
+                    self.attn_bias,
+                    self.config.n_heads,
+                    self.config.max_seq_len,
+                    causal=self.is_causal,
+                    alibi=self.alibi,
+                    alibi_bias_max=self.alibi_bias_max,
+                )
             self._attn_bias_initialized = True
-        if self.attn_impl == 'flash':
+        if self.attn_impl == "flash":
             return (self.attn_bias, attention_mask)
         if self.attn_bias is not None:
             self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
@@ -101,38 +145,71 @@ class MPTModel(MPTPreTrainedModel):
         else:
             attn_bias = attn_bias[:, :, :, -s_k:]
         if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
-            raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
+            raise ValueError(
+                f"attention_mask shape={attention_mask.shape} "
+                + f"and prefix_mask shape={prefix_mask.shape} are not equal."
+            )
         min_val = torch.finfo(attn_bias.dtype).min
-        attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
+        attn_bias = attn_bias.masked_fill(
+            ~attention_mask.view(-1, 1, 1, s_k), min_val
+        )
         return (attn_bias, None)

     def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
         (s_k, s_q) = attn_bias.shape[-2:]
         if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
-            raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.')
+            raise ValueError(
+                "attn_bias does not match the expected shape. "
+                + f"The last two dimensions should both be {self.config.max_seq_len} "
+                + f"but are {s_k} and {s_q}."
+ ) seq_len = prefix_mask.shape[-1] if seq_len > self.config.max_seq_len: - raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}') + raise ValueError( + f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}" + ) attn_bias = attn_bias[..., :seq_len, :seq_len] - causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len) + causal = torch.tril( + torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device) + ).view(1, 1, seq_len, seq_len) prefix = prefix_mask.view(-1, 1, 1, seq_len) cannot_attend = ~torch.logical_or(causal, prefix.bool()) min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill(cannot_attend, min_val) return attn_bias - def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor): + def _apply_sequence_id( + self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor + ): seq_len = sequence_id.shape[-1] if seq_len > self.config.max_seq_len: - raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}') + raise ValueError( + f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}" + ) attn_bias = attn_bias[..., :seq_len, :seq_len] - cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1) + cannot_attend = torch.logical_not( + torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len)) + ).unsqueeze(1) min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill(cannot_attend, min_val) return attn_bias - def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, tok_emb: Optional[torch.FloatTensor]=None): - return_dict = return_dict if return_dict is not None else self.config.return_dict + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + prefix_mask: Optional[torch.ByteTensor] = None, + sequence_id: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + use_cache: Optional[bool] = None, + tok_emb: Optional[torch.FloatTensor] = None, + ): + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) use_cache = use_cache if use_cache is not None else self.config.use_cache if self.gradient_checkpointing and self.training: @@ -146,21 +223,41 @@ class MPTModel(MPTPreTrainedModel): if prefix_mask is not None: prefix_mask = prefix_mask.bool() if not return_dict: - raise NotImplementedError('return_dict False is not implemented yet for MPT') + raise NotImplementedError( + "return_dict False is not implemented yet for MPT" + ) if output_attentions: - raise NotImplementedError('output_attentions is not implemented yet for MPT') - if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training: - raise NotImplementedError('MPT does not support training with left padding.') + raise 
NotImplementedError( + "output_attentions is not implemented yet for MPT" + ) + if ( + attention_mask is not None + and attention_mask[:, 0].sum() != attention_mask.shape[0] + and self.training + ): + raise NotImplementedError( + "MPT does not support training with left padding." + ) if self.prefix_lm and prefix_mask is None: - raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.') + raise ValueError( + "prefix_mask is a required argument when MPT is configured with prefix_lm=True." + ) if self.training: if self.attn_uses_sequence_id and sequence_id is None: - raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.') + raise ValueError( + "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " + + "and the model is in train mode." + ) elif self.attn_uses_sequence_id is False and sequence_id is not None: - warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.') + warnings.warn( + "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. " + + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True." + ) if input_ids is not None: S = input_ids.size(1) - assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' + assert ( + S <= self.config.max_seq_len + ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}" tok_emb = self.wte(input_ids) else: assert tok_emb is not None @@ -171,45 +268,85 @@ class MPTModel(MPTPreTrainedModel): past_position = 0 if past_key_values is not None: if len(past_key_values) != self.config.n_layers: - raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).') + raise ValueError( + f"past_key_values must provide a past_key_value for each attention " + + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})." + ) past_position = past_key_values[0][0].size(1) if S + past_position > self.config.max_seq_len: - raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.') - pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0) + raise ValueError( + f"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}." 
+ ) + pos = torch.arange( + past_position, + S + past_position, + dtype=torch.long, + device=input_ids.device, + ).unsqueeze(0) if attention_mask is not None: - pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0) + pos = torch.clamp( + pos + - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[ + :, past_position: + ], + min=0, + ) pos_emb = self.wpe(pos) x = tok_emb + pos_emb if self.embedding_fraction == 1: x = self.emb_drop(x) else: - x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction) + x_shrunk = x * self.embedding_fraction + x.detach() * ( + 1 - self.embedding_fraction + ) assert isinstance(self.emb_drop, nn.Module) x = self.emb_drop(x_shrunk) - (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id) + (attn_bias, attention_mask) = self._attn_bias( + device=x.device, + dtype=x.dtype, + attention_mask=attention_mask, + prefix_mask=prefix_mask, + sequence_id=sequence_id, + ) if use_cache and past_key_values is None: past_key_values = [() for _ in range(self.config.n_layers)] all_hidden_states = () if output_hidden_states else None - for (b_idx, block) in enumerate(self.blocks): + for b_idx, block in enumerate(self.blocks): if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) - past_key_value = past_key_values[b_idx] if past_key_values is not None else None + past_key_value = ( + past_key_values[b_idx] if past_key_values is not None else None + ) if self.gradient_checkpointing and self.training: (x, past_key_value) = torch.utils.checkpoint.checkpoint( - block, - x, past_key_value, attn_bias, attention_mask, self.is_causal + block, x, past_key_value, attn_bias, attention_mask, self.is_causal ) else: - (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal) + (x, past_key_value) = block( + x, + past_key_value=past_key_value, + attn_bias=attn_bias, + attention_mask=attention_mask, + is_causal=self.is_causal, + ) if past_key_values is not None: past_key_values[b_idx] = past_key_value x = self.norm_f(x) - return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=x, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + ) def param_init_fn(self, module): - init_fn_name = self.config.init_config['name'] - MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config) + init_fn_name = self.config.init_config["name"] + MODEL_INIT_REGISTRY[init_fn_name]( + module=module, + n_layers=self.config.n_layers, + d_model=self.config.d_model, + **self.config.init_config, + ) def fsdp_wrap_fn(self, module): return isinstance(module, MPTBlock) @@ -217,21 +354,23 @@ class MPTModel(MPTPreTrainedModel): def activation_checkpointing_fn(self, module): return isinstance(module, MPTBlock) -class MPTForCausalLM(MPTPreTrainedModel): +class MPTForCausalLM(MPTPreTrainedModel): def __init__(self, config: MPTConfig): super().__init__(config) if not config.tie_word_embeddings: - raise ValueError('MPTForCausalLM only supports tied word embeddings') + raise ValueError("MPTForCausalLM only supports tied word embeddings") self.transformer = MPTModel(config) self.logit_scale = None if 
config.logit_scale is not None: logit_scale = config.logit_scale if isinstance(logit_scale, str): - if logit_scale == 'inv_sqrt_d_model': + if logit_scale == "inv_sqrt_d_model": logit_scale = 1 / math.sqrt(config.d_model) else: - raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") + raise ValueError( + f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." + ) self.logit_scale = logit_scale def get_input_embeddings(self): @@ -252,25 +391,63 @@ class MPTForCausalLM(MPTPreTrainedModel): def get_decoder(self): return self.transformer - def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None): - return_dict = return_dict if return_dict is not None else self.config.return_dict + def forward( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.ByteTensor] = None, + prefix_mask: Optional[torch.ByteTensor] = None, + sequence_id: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + use_cache: Optional[bool] = None, + ): + return_dict = ( + return_dict if return_dict is not None else self.config.return_dict + ) use_cache = use_cache if use_cache is not None else self.config.use_cache - outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) + outputs = self.transformer( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + prefix_mask=prefix_mask, + sequence_id=sequence_id, + return_dict=return_dict, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + ) logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight) if self.logit_scale is not None: if self.logit_scale == 0: - warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.') + warnings.warn( + f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs." 
+ ) logits *= self.logit_scale loss = None if labels is not None: labels = torch.roll(labels, shifts=-1) labels[:, -1] = -100 - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) - return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) + ) + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + ) def param_init_fn(self, module): - init_fn_name = self.config.init_config['name'] - MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config) + init_fn_name = self.config.init_config["name"] + MODEL_INIT_REGISTRY[init_fn_name]( + module=module, + n_layers=self.config.n_layers, + d_model=self.config.d_model, + **self.config.init_config, + ) def fsdp_wrap_fn(self, module): return isinstance(module, MPTBlock) @@ -278,12 +455,16 @@ class MPTForCausalLM(MPTPreTrainedModel): def activation_checkpointing_fn(self, module): return isinstance(module, MPTBlock) - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs + ): if inputs_embeds is not None: - raise NotImplementedError('inputs_embeds is not implemented for MPT yet') - attention_mask = kwargs['attention_mask'].bool() + raise NotImplementedError("inputs_embeds is not implemented for MPT yet") + attention_mask = kwargs["attention_mask"].bool() if attention_mask[:, -1].sum() != attention_mask.shape[0]: - raise NotImplementedError('MPT does not support generation with right padding.') + raise NotImplementedError( + "MPT does not support generation with right padding." + ) if self.transformer.attn_uses_sequence_id and self.training: sequence_id = torch.zeros_like(input_ids[:1]) else: @@ -292,11 +473,20 @@ class MPTForCausalLM(MPTPreTrainedModel): input_ids = input_ids[:, -1].unsqueeze(-1) if self.transformer.prefix_lm: prefix_mask = torch.ones_like(attention_mask) - if kwargs.get('use_cache') == False: - raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') + if kwargs.get("use_cache") == False: + raise NotImplementedError( + "MPT with prefix_lm=True does not support use_cache=False." 
+ ) else: prefix_mask = None - return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)} + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "prefix_mask": prefix_mask, + "sequence_id": sequence_id, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache", True), + } @staticmethod def _reorder_cache(past_key_values, beam_idx): @@ -307,5 +497,9 @@ class MPTForCausalLM(MPTPreTrainedModel): """ reordered_past = [] for layer_past in past_key_values: - reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))] - return reordered_past \ No newline at end of file + reordered_past += [ + tuple( + (past_state.index_select(0, beam_idx) for past_state in layer_past) + ) + ] + return reordered_past diff --git a/model/llava/model/mpt/norm.py b/model/llava/model/mpt/norm.py index bec4a4ca3304c2188312387743a49b75015542be..42fa6d9c84a3c3cf8190a86dc5ca86b7412763b7 100644 --- a/model/llava/model/mpt/norm.py +++ b/model/llava/model/mpt/norm.py @@ -1,28 +1,55 @@ import torch + def _cast_if_autocast_enabled(tensor): if torch.is_autocast_enabled(): - if tensor.device.type == 'cuda': + if tensor.device.type == "cuda": dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == 'cpu': + elif tensor.device.type == "cpu": dtype = torch.get_autocast_cpu_dtype() else: raise NotImplementedError() return tensor.to(dtype=dtype) return tensor -class LPLayerNorm(torch.nn.LayerNorm): - def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): - super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) +class LPLayerNorm(torch.nn.LayerNorm): + def __init__( + self, + normalized_shape, + eps=1e-05, + elementwise_affine=True, + device=None, + dtype=None, + ): + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, + device=device, + dtype=dtype, + ) def forward(self, x): module_device = x.device downcast_x = _cast_if_autocast_enabled(x) - downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight - downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + downcast_weight = ( + _cast_if_autocast_enabled(self.weight) + if self.weight is not None + else self.weight + ) + downcast_bias = ( + _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + ) with torch.autocast(enabled=False, device_type=module_device.type): - return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) + return torch.nn.functional.layer_norm( + downcast_x, + self.normalized_shape, + downcast_weight, + downcast_bias, + self.eps, + ) + def rms_norm(x, weight=None, eps=1e-05): output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) @@ -30,27 +57,50 @@ def rms_norm(x, weight=None, eps=1e-05): if weight is not None: return output * weight return output -class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): +class RMSNorm(torch.nn.Module): + def __init__( + self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None + ): super().__init__() self.eps = eps if weight: - self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) +
self.weight = torch.nn.Parameter( + torch.ones(normalized_shape, dtype=dtype, device=device) + ) else: - self.register_parameter('weight', None) + self.register_parameter("weight", None) def forward(self, x): return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) -class LPRMSNorm(RMSNorm): - def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): - super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) +class LPRMSNorm(RMSNorm): + def __init__( + self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None + ): + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + weight=weight, + dtype=dtype, + device=device, + ) def forward(self, x): downcast_x = _cast_if_autocast_enabled(x) - downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight + downcast_weight = ( + _cast_if_autocast_enabled(self.weight) + if self.weight is not None + else self.weight + ) with torch.autocast(enabled=False, device_type=x.device.type): return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) -NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} \ No newline at end of file + + +NORM_CLASS_REGISTRY = { + "layernorm": torch.nn.LayerNorm, + "low_precision_layernorm": LPLayerNorm, + "rmsnorm": RMSNorm, + "low_precision_rmsnorm": LPRMSNorm, +} diff --git a/model/llava/model/mpt/param_init_fns.py b/model/llava/model/mpt/param_init_fns.py index 418b83ca2363288046f4b48b1d706c5607341fb5..5c1d17a22a62e4411a537e2d7c0c96422e4a4174 100644 --- a/model/llava/model/mpt/param_init_fns.py +++ b/model/llava/model/mpt/param_init_fns.py @@ -3,101 +3,139 @@ import warnings from collections.abc import Sequence from functools import partial from typing import Optional, Tuple, Union + import torch from torch import nn + from .norm import NORM_CLASS_REGISTRY -def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs): + +def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs): del kwargs if verbose > 1: warnings.warn(f"Initializing network using module's reset_parameters attribute") - if hasattr(module, 'reset_parameters'): + if hasattr(module, "reset_parameters"): module.reset_parameters() + def fused_init_helper_(module: nn.Module, init_fn_): - _fused = getattr(module, '_fused', None) + _fused = getattr(module, "_fused", None) if _fused is None: - raise RuntimeError(f'Internal logic error') + raise RuntimeError(f"Internal logic error") (dim, splits) = _fused splits = (0, *splits, module.weight.size(dim)) - for (s, e) in zip(splits[:-1], splits[1:]): + for s, e in zip(splits[:-1], splits[1:]): slice_indices = [slice(None)] * module.weight.ndim slice_indices[dim] = slice(s, e) init_fn_(module.weight[slice_indices]) -def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): + +def generic_param_init_fn_( + module: nn.Module, + init_fn_, + n_layers: int, + d_model: Optional[int] = None, + init_div_is_residual: Union[int, float, str, bool] = True, + emb_init_std: Optional[float] = None, + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + verbose: int = 0, + **kwargs, +): del 
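For reference, `rms_norm` in `norm.py` rescales each vector by the reciprocal of its root-mean-square (`torch.rsqrt(x.pow(2).mean(-1) + eps)`), and `weight`, when present, acts as a learned elementwise gain. A quick standalone check, assuming nothing beyond PyTorch, that the normalized output has unit RMS:

```python
# Minimal check of the RMS normalization implemented in norm.py.
import torch

def rms_norm(x, weight=None, eps=1e-05):
    # Scale by 1 / RMS(x) computed over the last dimension.
    output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    if weight is not None:
        return output * weight
    return output

x = 3.0 * torch.randn(4, 16)
y = rms_norm(x)
print(y.pow(2).mean(-1).sqrt())  # each entry is ~1.0
```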
kwargs if verbose > 1: - warnings.warn(f'If model has bias parameters they are initialized to 0.') + warnings.warn(f"If model has bias parameters they are initialized to 0.") init_div_is_residual = init_div_is_residual if init_div_is_residual is False: div_is_residual = 1.0 elif init_div_is_residual is True: div_is_residual = math.sqrt(2 * n_layers) - elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int): + elif isinstance(init_div_is_residual, float) or isinstance( + init_div_is_residual, int + ): div_is_residual = init_div_is_residual elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric(): div_is_residual = float(init_div_is_residual) else: div_is_residual = 1.0 - raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}') + raise ValueError( + f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}" + ) if init_div_is_residual is not False: if verbose > 1: - warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.') + warnings.warn( + f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. " + + f"Set `init_div_is_residual: false` in init config to disable this." + ) if isinstance(module, nn.Linear): - if hasattr(module, '_fused'): + if hasattr(module, "_fused"): fused_init_helper_(module, init_fn_) else: init_fn_(module.weight) if module.bias is not None: torch.nn.init.zeros_(module.bias) - if init_div_is_residual is not False and getattr(module, '_is_residual', False): + if init_div_is_residual is not False and getattr(module, "_is_residual", False): with torch.no_grad(): module.weight.div_(div_is_residual) elif isinstance(module, nn.Embedding): if emb_init_std is not None: std = emb_init_std if std == 0: - warnings.warn(f'Embedding layer initialized to 0.') + warnings.warn(f"Embedding layer initialized to 0.") emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std) if verbose > 1: - warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.') + warnings.warn( + f"Embedding layer initialized using normal distribution with mean=0 and std={std!r}." + ) elif emb_init_uniform_lim is not None: lim = emb_init_uniform_lim if isinstance(lim, Sequence): if len(lim) > 2: - raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.') + raise ValueError( + f"Uniform init requires a min and a max limit. User input: {lim}." + ) if lim[0] == lim[1]: - warnings.warn(f'Embedding layer initialized to {lim[0]}.') + warnings.warn(f"Embedding layer initialized to {lim[0]}.") else: if lim == 0: - warnings.warn(f'Embedding layer initialized to 0.') + warnings.warn(f"Embedding layer initialized to 0.") lim = [-lim, lim] (a, b) = lim emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b) if verbose > 1: - warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.') + warnings.warn( + f"Embedding layer initialized using uniform distribution in range {lim}." + ) else: emb_init_fn_ = init_fn_ emb_init_fn_(module.weight) elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))): if verbose > 1: - warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.') - if hasattr(module, 'weight') and module.weight is not None: + warnings.warn( + f"Norm weights are set to 1. If norm layer has a bias it is initialized to 0." 
+ ) + if hasattr(module, "weight") and module.weight is not None: torch.nn.init.ones_(module.weight) - if hasattr(module, 'bias') and module.bias is not None: + if hasattr(module, "bias") and module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.MultiheadAttention): if module._qkv_same_embed_dim: assert module.in_proj_weight is not None - assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None) + assert ( + module.q_proj_weight is None + and module.k_proj_weight is None + and (module.v_proj_weight is None) + ) assert d_model is not None _d = d_model splits = (0, _d, 2 * _d, 3 * _d) - for (s, e) in zip(splits[:-1], splits[1:]): + for s, e in zip(splits[:-1], splits[1:]): init_fn_(module.in_proj_weight[s:e]) else: - assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None) + assert ( + module.q_proj_weight is not None + and module.k_proj_weight is not None + and (module.v_proj_weight is not None) + ) assert module.in_proj_weight is None init_fn_(module.q_proj_weight) init_fn_(module.k_proj_weight) @@ -109,37 +147,112 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: if module.bias_v is not None: torch.nn.init.zeros_(module.bias_v) init_fn_(module.out_proj.weight) - if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False): + if init_div_is_residual is not False and getattr( + module.out_proj, "_is_residual", False + ): with torch.no_grad(): module.out_proj.weight.div_(div_is_residual) if module.out_proj.bias is not None: torch.nn.init.zeros_(module.out_proj.bias) else: for _ in module.parameters(recurse=False): - raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.') + raise NotImplementedError( + f"{module.__class__.__name__} parameters are not initialized by param_init_fn." 
+ ) + def _normal_init_(std, mean=0.0): return partial(torch.nn.init.normal_, mean=mean, std=std) -def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): + +def _normal_param_init_fn_( + module: nn.Module, + std: float, + n_layers: int, + d_model: Optional[int] = None, + init_div_is_residual: Union[int, float, str, bool] = True, + emb_init_std: Optional[float] = None, + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + verbose: int = 0, + **kwargs, +): del kwargs init_fn_ = _normal_init_(std=std) if verbose > 1: - warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}') - generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + warnings.warn(f"Using torch.nn.init.normal_ init fn mean=0.0, std={std}") + generic_param_init_fn_( + module=module, + init_fn_=init_fn_, + d_model=d_model, + n_layers=n_layers, + init_div_is_residual=init_div_is_residual, + emb_init_std=emb_init_std, + emb_init_uniform_lim=emb_init_uniform_lim, + verbose=verbose, + ) + -def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): +def baseline_param_init_fn_( + module: nn.Module, + init_std: float, + n_layers: int, + d_model: Optional[int] = None, + init_div_is_residual: Union[int, float, str, bool] = True, + emb_init_std: Optional[float] = None, + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + verbose: int = 0, + **kwargs, +): del kwargs if init_std is None: - raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.") - _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + raise ValueError( + "You must set model.init_config['init_std'] to a float value to use the default initialization scheme." 
+ ) + _normal_param_init_fn_( + module=module, + std=init_std, + d_model=d_model, + n_layers=n_layers, + init_div_is_residual=init_div_is_residual, + emb_init_std=emb_init_std, + emb_init_uniform_lim=emb_init_uniform_lim, + verbose=verbose, + ) -def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): + +def small_param_init_fn_( + module: nn.Module, + n_layers: int, + d_model: int, + init_div_is_residual: Union[int, float, str, bool] = True, + emb_init_std: Optional[float] = None, + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + verbose: int = 0, + **kwargs, +): del kwargs std = math.sqrt(2 / (5 * d_model)) - _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + _normal_param_init_fn_( + module=module, + std=std, + d_model=d_model, + n_layers=n_layers, + init_div_is_residual=init_div_is_residual, + emb_init_std=emb_init_std, + emb_init_uniform_lim=emb_init_uniform_lim, + verbose=verbose, + ) + -def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): +def neox_param_init_fn_( + module: nn.Module, + n_layers: int, + d_model: int, + emb_init_std: Optional[float] = None, + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + verbose: int = 0, + **kwargs, +): """From section 2.3.1 of GPT-NeoX-20B: An Open-Source AutoregressiveLanguage Model — Black et. al. 
(2022) @@ -149,33 +262,158 @@ def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init del kwargs residual_div = n_layers / math.sqrt(10) if verbose > 1: - warnings.warn(f'setting init_div_is_residual to {residual_div}') - small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + warnings.warn(f"setting init_div_is_residual to {residual_div}") + small_param_init_fn_( + module=module, + d_model=d_model, + n_layers=n_layers, + init_div_is_residual=residual_div, + emb_init_std=emb_init_std, + emb_init_uniform_lim=emb_init_uniform_lim, + verbose=verbose, + ) -def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): + +def kaiming_uniform_param_init_fn_( + module: nn.Module, + n_layers: int, + d_model: Optional[int] = None, + init_div_is_residual: Union[int, float, str, bool] = True, + emb_init_std: Optional[float] = None, + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + init_gain: float = 0, + fan_mode: str = "fan_in", + init_nonlinearity: str = "leaky_relu", + verbose: int = 0, + **kwargs, +): del kwargs if verbose > 1: - warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}') - kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) - generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + warnings.warn( + f"Using nn.init.kaiming_uniform_ init fn with parameters: " + + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}" + ) + kaiming_uniform_ = partial( + nn.init.kaiming_uniform_, + a=init_gain, + mode=fan_mode, + nonlinearity=init_nonlinearity, + ) + generic_param_init_fn_( + module=module, + init_fn_=kaiming_uniform_, + d_model=d_model, + n_layers=n_layers, + init_div_is_residual=init_div_is_residual, + emb_init_std=emb_init_std, + emb_init_uniform_lim=emb_init_uniform_lim, + verbose=verbose, + ) + -def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): +def kaiming_normal_param_init_fn_( + module: nn.Module, + n_layers: int, + d_model: Optional[int] = None, + init_div_is_residual: Union[int, float, str, bool] = True, + emb_init_std: Optional[float] = None, + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + init_gain: float = 0, + fan_mode: str = "fan_in", + init_nonlinearity: str = "leaky_relu", + verbose: int = 0, + **kwargs, +): del kwargs if verbose > 1: - warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}') - kaiming_normal_ = 
partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) - generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + warnings.warn( + f"Using nn.init.kaiming_normal_ init fn with parameters: " + + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}" + ) + kaiming_normal_ = partial( + torch.nn.init.kaiming_normal_, + a=init_gain, + mode=fan_mode, + nonlinearity=init_nonlinearity, + ) + generic_param_init_fn_( + module=module, + init_fn_=kaiming_normal_, + d_model=d_model, + n_layers=n_layers, + init_div_is_residual=init_div_is_residual, + emb_init_std=emb_init_std, + emb_init_uniform_lim=emb_init_uniform_lim, + verbose=verbose, + ) -def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs): + +def xavier_uniform_param_init_fn_( + module: nn.Module, + n_layers: int, + d_model: Optional[int] = None, + init_div_is_residual: Union[int, float, str, bool] = True, + emb_init_std: Optional[float] = None, + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + init_gain: float = 0, + verbose: int = 0, + **kwargs, +): del kwargs xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain) if verbose > 1: - warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}') - generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) + warnings.warn( + f"Using torch.nn.init.xavier_uniform_ init fn with parameters: " + + f"gain={init_gain}" + ) + generic_param_init_fn_( + module=module, + init_fn_=xavier_uniform_, + d_model=d_model, + n_layers=n_layers, + init_div_is_residual=init_div_is_residual, + emb_init_std=emb_init_std, + emb_init_uniform_lim=emb_init_uniform_lim, + verbose=verbose, + ) + -def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs): +def xavier_normal_param_init_fn_( + module: nn.Module, + n_layers: int, + d_model: Optional[int] = None, + init_div_is_residual: Union[int, float, str, bool] = True, + emb_init_std: Optional[float] = None, + emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, + init_gain: float = 0, + verbose: int = 0, + **kwargs, +): xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain) if verbose > 1: - warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}') - generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) -MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': 
kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_} \ No newline at end of file + warnings.warn( + f"Using torch.nn.init.xavier_normal_ init fn with parameters: " + + f"gain={init_gain}" + ) + generic_param_init_fn_( + module=module, + init_fn_=xavier_normal_, + d_model=d_model, + n_layers=n_layers, + init_div_is_residual=init_div_is_residual, + emb_init_std=emb_init_std, + emb_init_uniform_lim=emb_init_uniform_lim, + verbose=verbose, + ) + + +MODEL_INIT_REGISTRY = { + "default_": torch_default_param_init_fn_, + "baseline_": baseline_param_init_fn_, + "kaiming_uniform_": kaiming_uniform_param_init_fn_, + "kaiming_normal_": kaiming_normal_param_init_fn_, + "neox_init_": neox_param_init_fn_, + "small_init_": small_param_init_fn_, + "xavier_uniform_": xavier_uniform_param_init_fn_, + "xavier_normal_": xavier_normal_param_init_fn_, +} diff --git a/model/llava/model/utils.py b/model/llava/model/utils.py index b732e869aa3ad9d08a5909d6ddf6c7631bb87805..976f0190bdf130627a289b0bbc765f75b0e7a669 100644 --- a/model/llava/model/utils.py +++ b/model/llava/model/utils.py @@ -5,16 +5,20 @@ from transformers import AutoConfig, StoppingCriteria def auto_upgrade(config): cfg = AutoConfig.from_pretrained(config) - if 'llava' in config and 'llava' not in cfg.model_type: - assert cfg.model_type == 'llama' - print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") - print("You must upgrade the checkpoint to the new code base (this can be done automatically).") + if "llava" in config and "llava" not in cfg.model_type: + assert cfg.model_type == "llama" + print( + "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base." + ) + print( + "You must upgrade the checkpoint to the new code base (this can be done automatically)." + ) confirm = input("Please confirm that you want to upgrade the checkpoint. 
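`MODEL_INIT_REGISTRY` maps the `init_config['name']` string to one of the `*_param_init_fn_` functions above, and `MPTForCausalLM.param_init_fn` calls the selected function once per module. A rough sketch of applying a registry entry to a toy model with `nn.Module.apply`; the toy model is made up and the import assumes the repo root is on `PYTHONPATH`, so treat it as illustration rather than repo code:

```python
# Hypothetical usage sketch of MODEL_INIT_REGISTRY; the toy model is made up.
import math
from functools import partial

from torch import nn

from model.llava.model.mpt.param_init_fns import MODEL_INIT_REGISTRY  # assumed path

n_layers, d_model = 4, 64
toy = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, d_model))

init_config = {"name": "small_init_"}  # uses std = sqrt(2 / (5 * d_model))
init_fn = MODEL_INIT_REGISTRY[init_config["name"]]
toy.apply(partial(init_fn, n_layers=n_layers, d_model=d_model))

print(toy[0].weight.std().item(), "vs expected", math.sqrt(2 / (5 * d_model)))
```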
[Y/N]") if confirm.lower() in ["y", "yes"]: print("Upgrading checkpoint...") assert len(cfg.architectures) == 1 setattr(cfg.__class__, "model_type", "llava") - cfg.architectures[0] = 'LlavaLlamaForCausalLM' + cfg.architectures[0] = "LlavaLlamaForCausalLM" cfg.save_pretrained(config) print("Checkpoint upgraded.") else: @@ -22,24 +26,31 @@ def auto_upgrade(config): exit(1) - class KeywordsStoppingCriteria(StoppingCriteria): def __init__(self, keywords, tokenizer, input_ids): self.keywords = keywords self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords] - self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1] + self.keyword_ids = [ + keyword_id[0] + for keyword_id in self.keyword_ids + if type(keyword_id) is list and len(keyword_id) == 1 + ] self.tokenizer = tokenizer self.start_len = None self.input_ids = input_ids - def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + def __call__( + self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs + ) -> bool: if self.start_len is None: self.start_len = self.input_ids.shape[1] else: for keyword_id in self.keyword_ids: if output_ids[0, -1] == keyword_id: return True - outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] + outputs = self.tokenizer.batch_decode( + output_ids[:, self.start_len :], skip_special_tokens=True + )[0] for keyword in self.keywords: if keyword in outputs: return True @@ -50,7 +61,7 @@ class KeywordsStoppingCriteria(StoppingCriteria): # # if output_ids[0, -1] == keyword_id: # # return True - + # print("output_ids.shape: {}, self.start_len: {}".format(output_ids.shape, self.start_len)) # print("output_ids[:, self.start_len:]: ", output_ids[:, self.start_len:]) diff --git a/model/llava/serve/cli.py b/model/llava/serve/cli.py index a385727b5cc7ad7c013c01d704297ec1af5d5686..044be2c8c162f1f6e8a3acbf25022778a17f47c4 100644 --- a/model/llava/serve/cli.py +++ b/model/llava/serve/cli.py @@ -6,14 +6,14 @@ import argparse import time import torch -from transformers import AutoTokenizer, AutoModelForCausalLM - -from llava.conversation import conv_templates, SeparatorStyle +from llava.conversation import SeparatorStyle, conv_templates +from transformers import AutoModelForCausalLM, AutoTokenizer @torch.inference_mode() -def generate_stream(tokenizer, model, params, device, - context_len=2048, stream_interval=2): +def generate_stream( + tokenizer, model, params, device, context_len=2048, stream_interval=2 +): """Adapted from fastchat/serve/model_worker.py::generate_stream""" prompt = params["prompt"] @@ -30,17 +30,19 @@ def generate_stream(tokenizer, model, params, device, for i in range(max_new_tokens): if i == 0: - out = model( - torch.as_tensor([input_ids], device=device), use_cache=True) + out = model(torch.as_tensor([input_ids], device=device), use_cache=True) logits = out.logits past_key_values = out.past_key_values else: attention_mask = torch.ones( - 1, past_key_values[0][0].shape[-2] + 1, device=device) - out = model(input_ids=torch.as_tensor([[token]], device=device), - use_cache=True, - attention_mask=attention_mask, - past_key_values=past_key_values) + 1, past_key_values[0][0].shape[-2] + 1, device=device + ) + out = model( + input_ids=torch.as_tensor([[token]], device=device), + use_cache=True, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) logits = out.logits past_key_values = out.past_key_values @@ -84,18 +86,21 
@@ def main(args): else: num_gpus = int(num_gpus) if num_gpus != 1: - kwargs.update({ - "device_map": "auto", - "max_memory": {i: "13GiB" for i in range(num_gpus)}, - }) + kwargs.update( + { + "device_map": "auto", + "max_memory": {i: "13GiB" for i in range(num_gpus)}, + } + ) elif args.device == "cpu": kwargs = {} else: raise ValueError(f"Invalid device: {args.device}") tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(model_name, - low_cpu_mem_usage=True, **kwargs) + model = AutoModelForCausalLM.from_pretrained( + model_name, low_cpu_mem_usage=True, **kwargs + ) if args.device == "cuda" and num_gpus == 1: model.cuda() @@ -126,11 +131,11 @@ def main(args): print(f"{conv.roles[1]}: ", end="", flush=True) pre = 0 for outputs in generate_stream(tokenizer, model, params, args.device): - outputs = outputs[len(prompt) + 1:].strip() + outputs = outputs[len(prompt) + 1 :].strip() outputs = outputs.split(" ") now = len(outputs) if now - 1 > pre: - print(" ".join(outputs[pre:now-1]), end=" ", flush=True) + print(" ".join(outputs[pre : now - 1]), end=" ", flush=True) pre = now - 1 print(" ".join(outputs[pre:]), flush=True) diff --git a/model/llava/serve/controller.py b/model/llava/serve/controller.py index b61fca6ea9fe8aa37acd143784a3d76e90a58b9f..1cb67c27a657d71420f6dddb4e40e40e7645488e 100644 --- a/model/llava/serve/controller.py +++ b/model/llava/serve/controller.py @@ -5,23 +5,21 @@ It sends worker addresses to clients. import argparse import asyncio import dataclasses -from enum import Enum, auto import json import logging +import threading import time +from enum import Enum, auto from typing import List, Union -import threading -from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse import numpy as np import requests import uvicorn - +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION from llava.utils import build_logger, server_error_msg - logger = build_logger("controller", "controller.log") @@ -61,13 +59,15 @@ class Controller: self.dispatch_method = DispatchMethod.from_str(dispatch_method) self.heart_beat_thread = threading.Thread( - target=heart_beat_controller, args=(self,)) + target=heart_beat_controller, args=(self,) + ) self.heart_beat_thread.start() logger.info("Init controller") - def register_worker(self, worker_name: str, check_heart_beat: bool, - worker_status: dict): + def register_worker( + self, worker_name: str, check_heart_beat: bool, worker_status: dict + ): if worker_name not in self.worker_info: logger.info(f"Register a new worker: {worker_name}") else: @@ -79,8 +79,12 @@ class Controller: return False self.worker_info[worker_name] = WorkerInfo( - worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], - check_heart_beat, time.time()) + worker_status["model_names"], + worker_status["speed"], + worker_status["queue_length"], + check_heart_beat, + time.time(), + ) logger.info(f"Register done: {worker_name}, {worker_status}") return True @@ -131,15 +135,13 @@ class Controller: return "" worker_speeds = worker_speeds / norm if True: # Directly return address - pt = np.random.choice(np.arange(len(worker_names)), - p=worker_speeds) + pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) worker_name = worker_names[pt] return worker_name # Check status before returning while True: - pt = np.random.choice(np.arange(len(worker_names)), - p=worker_speeds) + pt = 
np.random.choice(np.arange(len(worker_names)), p=worker_speeds) worker_name = worker_names[pt] if self.get_worker_status(worker_name): @@ -165,7 +167,9 @@ class Controller: min_index = np.argmin(worker_qlen) w_name = worker_names[min_index] self.worker_info[w_name].queue_length += 1 - logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") + logger.info( + f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}" + ) return w_name else: raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") @@ -201,8 +205,12 @@ class Controller: yield json.dumps(ret).encode() + b"\0" try: - response = requests.post(worker_addr + "/worker_generate_stream", - json=params, stream=True, timeout=5) + response = requests.post( + worker_addr + "/worker_generate_stream", + json=params, + stream=True, + timeout=5, + ) for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: yield chunk + b"\0" @@ -214,7 +222,6 @@ class Controller: } yield json.dumps(ret).encode() + b"\0" - # Let the controller act as a worker to achieve hierarchical # management. This can be used to connect isolated sub networks. def worker_api_get_status(self): @@ -243,8 +250,8 @@ app = FastAPI() async def register_worker(request: Request): data = await request.json() controller.register_worker( - data["worker_name"], data["check_heart_beat"], - data.get("worker_status", None)) + data["worker_name"], data["check_heart_beat"], data.get("worker_status", None) + ) @app.post("/refresh_all_workers") @@ -268,8 +275,7 @@ async def get_worker_address(request: Request): @app.post("/receive_heart_beat") async def receive_heart_beat(request: Request): data = await request.json() - exist = controller.receive_heart_beat( - data["worker_name"], data["queue_length"]) + exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"]) return {"exist": exist} @@ -289,8 +295,12 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=21001) - parser.add_argument("--dispatch-method", type=str, choices=[ - "lottery", "shortest_queue"], default="shortest_queue") + parser.add_argument( + "--dispatch-method", + type=str, + choices=["lottery", "shortest_queue"], + default="shortest_queue", + ) args = parser.parse_args() logger.info(f"args: {args}") diff --git a/model/llava/serve/gradio_css.py b/model/llava/serve/gradio_css.py index 55454130b3becdf786c1545a2f79028068389e7c..71d79b4a4b5a7ad84b8822d99e1740e77bc1f7a8 100644 --- a/model/llava/serve/gradio_css.py +++ b/model/llava/serve/gradio_css.py @@ -1,5 +1,4 @@ -code_highlight_css = ( -""" +code_highlight_css = """ #chatbot .hll { background-color: #ffffcc } #chatbot .c { color: #408080; font-style: italic } #chatbot .err { border: 1px solid #FF0000 } @@ -68,6 +67,5 @@ code_highlight_css = ( #chatbot .vi { color: #19177C } #chatbot .vm { color: #19177C } #chatbot .il { color: #666666 } -""") -#.highlight { background: #f8f8f8; } - +""" +# .highlight { background: #f8f8f8; } diff --git a/model/llava/serve/gradio_patch.py b/model/llava/serve/gradio_patch.py index 07e5909e2d6b10fc75178daa54f45c01dcbb42cb..cb3b4838fe14c9df20b7f22eef03617e1e4b088a 100644 --- a/model/llava/serve/gradio_patch.py +++ b/model/llava/serve/gradio_patch.py @@ -50,7 +50,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): warnings.warn( "The 'color_map' parameter has been deprecated.", ) - #self.md = utils.get_markdown_parser() 
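The controller's two dispatch modes are both visible above: `lottery` samples a worker with probability proportional to its reported speed, while `shortest_queue` picks the worker whose queue metric is smallest (via `np.argmin`) and then increments its queue length. A standalone sketch with made-up worker names and stats:

```python
# Illustrative dispatch sketch; worker names and stats are made up.
import numpy as np

worker_names = ["http://worker-a:40000", "http://worker-b:40000"]
worker_speeds = np.array([1.0, 2.0], dtype=np.float32)
worker_qlen = [3.0, 4.0]

# "lottery": speed-weighted random choice, as in Controller.get_worker_address.
p = worker_speeds / worker_speeds.sum()
lottery_pick = worker_names[np.random.choice(np.arange(len(worker_names)), p=p)]

# "shortest_queue": argmin over the queue metric, then bump the chosen queue.
min_index = int(np.argmin(worker_qlen))
shortest_pick = worker_names[min_index]
worker_qlen[min_index] += 1

print(lottery_pick, shortest_pick)
```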
+ # self.md = utils.get_markdown_parser() self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"]) self.select: EventListenerMethod """ @@ -113,7 +113,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): ): # This happens for previously processed messages return chat_message elif isinstance(chat_message, str): - #return self.md.render(chat_message) + # return self.md.render(chat_message) return str(self.md.convert(chat_message)) else: raise ValueError(f"Invalid message for Chatbot component: {chat_message}") @@ -142,9 +142,10 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}" processed_messages.append( ( - #self._process_chat_messages(message_pair[0]), - '
' +
-                    message_pair[0] + "
", + # self._process_chat_messages(message_pair[0]), + '
'
+                    + message_pair[0]
+                    + "
", self._process_chat_messages(message_pair[1]), ) ) @@ -164,5 +165,3 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): **kwargs, ) return self - - diff --git a/model/llava/serve/gradio_web_server.py b/model/llava/serve/gradio_web_server.py index c6407730e2956ea0ea65dc7b11873f7b5bef126c..976972e690e317e173f57d43f4556c81dc94e4a5 100644 --- a/model/llava/serve/gradio_web_server.py +++ b/model/llava/serve/gradio_web_server.py @@ -1,22 +1,20 @@ import argparse -from collections import defaultdict import datetime +import hashlib import json import os import time +from collections import defaultdict import gradio as gr import requests - -from llava.conversation import (default_conversation, conv_templates, - SeparatorStyle) from llava.constants import LOGDIR -from llava.utils import (build_logger, server_error_msg, - violates_moderation, moderation_msg) -from llava.serve.gradio_patch import Chatbot as grChatbot +from llava.conversation import (SeparatorStyle, conv_templates, + default_conversation) from llava.serve.gradio_css import code_highlight_css -import hashlib - +from llava.serve.gradio_patch import Chatbot as grChatbot +from llava.utils import (build_logger, moderation_msg, server_error_msg, + violates_moderation) logger = build_logger("gradio_web_server", "gradio_web_server.log") @@ -65,31 +63,33 @@ def load_demo(url_params, request: gr.Request): if "model" in url_params: model = url_params["model"] if model in models: - dropdown_update = gr.Dropdown.update( - value=model, visible=True) + dropdown_update = gr.Dropdown.update(value=model, visible=True) state = default_conversation.copy() - return (state, - dropdown_update, - gr.Chatbot.update(visible=True), - gr.Textbox.update(visible=True), - gr.Button.update(visible=True), - gr.Row.update(visible=True), - gr.Accordion.update(visible=True)) + return ( + state, + dropdown_update, + gr.Chatbot.update(visible=True), + gr.Textbox.update(visible=True), + gr.Button.update(visible=True), + gr.Row.update(visible=True), + gr.Accordion.update(visible=True), + ) def load_demo_refresh_model_list(request: gr.Request): logger.info(f"load_demo. 
ip: {request.client.host}") models = get_model_list() state = default_conversation.copy() - return (state, gr.Dropdown.update( - choices=models, - value=models[0] if len(models) > 0 else ""), - gr.Chatbot.update(visible=True), - gr.Textbox.update(visible=True), - gr.Button.update(visible=True), - gr.Row.update(visible=True), - gr.Accordion.update(visible=True)) + return ( + state, + gr.Dropdown.update(choices=models, value=models[0] if len(models) > 0 else ""), + gr.Chatbot.update(visible=True), + gr.Textbox.update(visible=True), + gr.Button.update(visible=True), + gr.Row.update(visible=True), + gr.Accordion.update(visible=True), + ) def vote_last_response(state, vote_type, model_selector, request: gr.Request): @@ -148,13 +148,14 @@ def add_text(state, text, image, image_process_mode, request: gr.Request): if flagged: state.skip_next = True return (state, state.to_gradio_chatbot(), moderation_msg, None) + ( - no_change_btn,) * 5 + no_change_btn, + ) * 5 text = text[:1536] # Hard cut-off if image is not None: text = text[:1200] # Hard cut-off for images - if '' not in text: - text = text + '\n' + if "" not in text: + text = text + "\n" text = (text, image, image_process_mode) state = default_conversation.copy() state.append_message(state.roles[0], text) @@ -195,9 +196,9 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req template_name = "multimodal" elif "mpt" in model_name: template_name = "mpt_text" - elif "koala" in model_name: # Hardcode the condition + elif "koala" in model_name: # Hardcode the condition template_name = "bair_v1" - elif "v1" in model_name: # vicuna v1_1/v1_2 + elif "v1" in model_name: # vicuna v1_1/v1_2 template_name = "vicuna_v1_1" else: template_name = "v1" @@ -208,15 +209,24 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req # Query worker address controller_url = args.controller_url - ret = requests.post(controller_url + "/get_worker_address", - json={"model": model_name}) + ret = requests.post( + controller_url + "/get_worker_address", json={"model": model_name} + ) worker_addr = ret.json()["address"] logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") # No available worker if worker_addr == "": state.messages[-1][-1] = server_error_msg - yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + yield ( + state, + state.to_gradio_chatbot(), + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) return # Construct prompt @@ -226,7 +236,9 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images] for image, hash in zip(all_images, all_image_hash): t = datetime.datetime.now() - filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg") + filename = os.path.join( + LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg" + ) if not os.path.isfile(filename): os.makedirs(os.path.dirname(filename), exist_ok=True) image.save(filename) @@ -237,37 +249,56 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req "prompt": prompt, "temperature": float(temperature), "max_new_tokens": min(int(max_new_tokens), 1536), - "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2, - "images": f'List of {len(state.get_images())} images: {all_image_hash}', + "stop": state.sep + if state.sep_style in 
[SeparatorStyle.SINGLE, SeparatorStyle.MPT] + else state.sep2, + "images": f"List of {len(state.get_images())} images: {all_image_hash}", } logger.info(f"==== request ====\n{pload}") - pload['images'] = state.get_images() + pload["images"] = state.get_images() state.messages[-1][-1] = "▌" yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 try: # Stream output - response = requests.post(worker_addr + "/worker_generate_stream", - headers=headers, json=pload, stream=True, timeout=10) + response = requests.post( + worker_addr + "/worker_generate_stream", + headers=headers, + json=pload, + stream=True, + timeout=10, + ) for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: data = json.loads(chunk.decode()) if data["error_code"] == 0: - output = data["text"][len(prompt):].strip() + output = data["text"][len(prompt) :].strip() output = post_process_code(output) state.messages[-1][-1] = output + "▌" yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 else: output = data["text"] + f" (error_code: {data['error_code']})" state.messages[-1][-1] = output - yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) return time.sleep(0.03) except requests.exceptions.RequestException as e: state.messages[-1][-1] = server_error_msg - yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + yield (state, state.to_gradio_chatbot()) + ( + disable_btn, + disable_btn, + disable_btn, + enable_btn, + enable_btn, + ) return state.messages[-1][-1] = state.messages[-1][-1][:-1] @@ -289,27 +320,30 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req } fout.write(json.dumps(data) + "\n") -title_markdown = (""" + +title_markdown = """ # 🌋 LLaVA: Large Language and Vision Assistant [[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485) [[Code]](https://github.com/haotian-liu/LLaVA) [[Model]](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0) -""") +""" -tos_markdown = (""" +tos_markdown = """ ### Terms of use By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. -""") +""" -learn_more_markdown = (""" +learn_more_markdown = """ ### License The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. 
-""") +""" -css = code_highlight_css + """ +css = ( + code_highlight_css + + """ pre { white-space: pre-wrap; /* Since CSS 2.1 */ white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ @@ -318,11 +352,13 @@ pre { word-wrap: break-word; /* Internet Explorer 5.5+ */ } """ +) def build_demo(embed_mode): - textbox = gr.Textbox(show_label=False, - placeholder="Enter text and press ENTER", visible=False).style(container=False) + textbox = gr.Textbox( + show_label=False, placeholder="Enter text and press ENTER", visible=False + ).style(container=False) with gr.Blocks(title="LLaVA", theme=gr.themes.Base(), css=css) as demo: state = gr.State() @@ -336,26 +372,55 @@ def build_demo(embed_mode): choices=models, value=models[0] if len(models) > 0 else "", interactive=True, - show_label=False).style(container=False) + show_label=False, + ).style(container=False) imagebox = gr.Image(type="pil") image_process_mode = gr.Radio( ["Crop", "Resize", "Pad"], value="Crop", - label="Preprocess for non-square image") + label="Preprocess for non-square image", + ) cur_dir = os.path.dirname(os.path.abspath(__file__)) - gr.Examples(examples=[ - [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"], - [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"], - ], inputs=[imagebox, textbox]) - - with gr.Accordion("Parameters", open=False, visible=False) as parameter_row: - temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",) - max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",) + gr.Examples( + examples=[ + [ + f"{cur_dir}/examples/extreme_ironing.jpg", + "What is unusual about this image?", + ], + [ + f"{cur_dir}/examples/waterview.jpg", + "What are the things I should be cautious about when I visit here?", + ], + ], + inputs=[imagebox, textbox], + ) + + with gr.Accordion( + "Parameters", open=False, visible=False + ) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.2, + step=0.1, + interactive=True, + label="Temperature", + ) + max_output_tokens = gr.Slider( + minimum=0, + maximum=1024, + value=512, + step=64, + interactive=True, + label="Max output tokens", + ) with gr.Column(scale=6): - chatbot = grChatbot(elem_id="chatbot", label="LLaVA Chatbot", visible=False).style(height=550) + chatbot = grChatbot( + elem_id="chatbot", label="LLaVA Chatbot", visible=False + ).style(height=550) with gr.Row(): with gr.Column(scale=8): textbox.render() @@ -365,7 +430,7 @@ def build_demo(embed_mode): upvote_btn = gr.Button(value="👍 Upvote", interactive=False) downvote_btn = gr.Button(value="👎 Downvote", interactive=False) flag_btn = gr.Button(value="⚠️ Flag", interactive=False) - #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) + # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) clear_btn = gr.Button(value="🗑️ Clear history", interactive=False) @@ -376,32 +441,82 @@ def build_demo(embed_mode): # Register listeners btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] - upvote_btn.click(upvote_last_response, - [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn]) - downvote_btn.click(downvote_last_response, - [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn]) - flag_btn.click(flag_last_response, - [state, model_selector], [textbox, 
upvote_btn, downvote_btn, flag_btn]) - regenerate_btn.click(regenerate, [state, image_process_mode], - [state, chatbot, textbox, imagebox] + btn_list).then( - http_bot, [state, model_selector, temperature, max_output_tokens], - [state, chatbot] + btn_list) - clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list) - - textbox.submit(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list - ).then(http_bot, [state, model_selector, temperature, max_output_tokens], - [state, chatbot] + btn_list) - submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list - ).then(http_bot, [state, model_selector, temperature, max_output_tokens], - [state, chatbot] + btn_list) + upvote_btn.click( + upvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + downvote_btn.click( + downvote_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + flag_btn.click( + flag_last_response, + [state, model_selector], + [textbox, upvote_btn, downvote_btn, flag_btn], + ) + regenerate_btn.click( + regenerate, + [state, image_process_mode], + [state, chatbot, textbox, imagebox] + btn_list, + ).then( + http_bot, + [state, model_selector, temperature, max_output_tokens], + [state, chatbot] + btn_list, + ) + clear_btn.click( + clear_history, None, [state, chatbot, textbox, imagebox] + btn_list + ) + + textbox.submit( + add_text, + [state, textbox, imagebox, image_process_mode], + [state, chatbot, textbox, imagebox] + btn_list, + ).then( + http_bot, + [state, model_selector, temperature, max_output_tokens], + [state, chatbot] + btn_list, + ) + submit_btn.click( + add_text, + [state, textbox, imagebox, image_process_mode], + [state, chatbot, textbox, imagebox] + btn_list, + ).then( + http_bot, + [state, model_selector, temperature, max_output_tokens], + [state, chatbot] + btn_list, + ) if args.model_list_mode == "once": - demo.load(load_demo, [url_params], [state, model_selector, - chatbot, textbox, submit_btn, button_row, parameter_row], - _js=get_window_url_params) + demo.load( + load_demo, + [url_params], + [ + state, + model_selector, + chatbot, + textbox, + submit_btn, + button_row, + parameter_row, + ], + _js=get_window_url_params, + ) elif args.model_list_mode == "reload": - demo.load(load_demo_refresh_model_list, None, [state, model_selector, - chatbot, textbox, submit_btn, button_row, parameter_row]) + demo.load( + load_demo_refresh_model_list, + None, + [ + state, + model_selector, + chatbot, + textbox, + submit_btn, + button_row, + parameter_row, + ], + ) else: raise ValueError(f"Unknown model list mode: {args.model_list_mode}") @@ -414,8 +529,9 @@ if __name__ == "__main__": parser.add_argument("--port", type=int) parser.add_argument("--controller-url", type=str, default="http://localhost:21001") parser.add_argument("--concurrency-count", type=int, default=8) - parser.add_argument("--model-list-mode", type=str, default="once", - choices=["once", "reload"]) + parser.add_argument( + "--model-list-mode", type=str, default="once", choices=["once", "reload"] + ) parser.add_argument("--share", action="store_true") parser.add_argument("--moderate", action="store_true") parser.add_argument("--embed", action="store_true") @@ -426,6 +542,6 @@ if __name__ == "__main__": logger.info(args) demo = build_demo(args.embed) - demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10, - 
api_open=False).launch( - server_name=args.host, server_port=args.port, share=args.share) + demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch(server_name=args.host, server_port=args.port, share=args.share) diff --git a/model/llava/serve/model_worker.py b/model/llava/serve/model_worker.py index a4ef900d42a823b5ae7cb41becf1a73b56c6565c..0095bb61b056e69e26d243534796385707f743b0 100644 --- a/model/llava/serve/model_worker.py +++ b/model/llava/serve/model_worker.py @@ -4,25 +4,23 @@ A model worker executes the model. import argparse import asyncio import dataclasses -import logging import json -import time -from typing import List, Union +import logging import threading +import time import uuid +from functools import partial +from typing import List, Union -from fastapi import FastAPI, Request, BackgroundTasks -from fastapi.responses import StreamingResponse import requests -from transformers import AutoTokenizer, AutoModelForCausalLM import torch import uvicorn -from functools import partial - +from fastapi import BackgroundTasks, FastAPI, Request +from fastapi.responses import StreamingResponse from llava.constants import WORKER_HEART_BEAT_INTERVAL -from llava.utils import (build_logger, server_error_msg, - pretty_print_semaphore) from llava.model import * +from llava.utils import build_logger, pretty_print_semaphore, server_error_msg +from transformers import AutoModelForCausalLM, AutoTokenizer GB = 1 << 30 @@ -40,7 +38,6 @@ DEFAULT_IM_END_TOKEN = "" def heart_beat_worker(controller): - while True: time.sleep(WORKER_HEART_BEAT_INTERVAL) controller.send_heart_beat() @@ -56,38 +53,66 @@ def load_model(model_path, model_name, num_gpus): } tokenizer = AutoTokenizer.from_pretrained(model_path) - if 'llava' in model_name.lower(): - if 'mpt' in model_name.lower(): - model = LlavaMPTForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs) + if "llava" in model_name.lower(): + if "mpt" in model_name.lower(): + model = LlavaMPTForCausalLM.from_pretrained( + model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs + ) else: - model = LlavaLlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs) - elif 'mpt' in model_name.lower(): - model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) + model = LlavaLlamaForCausalLM.from_pretrained( + model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs + ) + elif "mpt" in model_name.lower(): + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + trust_remote_code=True, + **kwargs, + ) else: - model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs) + model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs + ) image_processor = None - if 'llava' in model_name.lower(): + if "llava" in model_name.lower(): from transformers import CLIPImageProcessor, CLIPVisionModel - image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16) + + image_processor = CLIPImageProcessor.from_pretrained( + model.config.mm_vision_tower, torch_dtype=torch.float16 + ) mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], 
special_tokens=True) if mm_use_im_start_end: - tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True + ) vision_tower = model.get_model().vision_tower[0] - if vision_tower.device.type == 'meta': - vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda() + if vision_tower.device.type == "meta": + vision_tower = CLIPVisionModel.from_pretrained( + vision_tower.config._name_or_path, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).cuda() model.get_model().vision_tower[0] = vision_tower else: - vision_tower.to(device='cuda', dtype=torch.float16) + vision_tower.to(device="cuda", dtype=torch.float16) vision_config = vision_tower.config - vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IMAGE_PATCH_TOKEN] + )[0] vision_config.use_im_start_end = mm_use_im_start_end if mm_use_im_start_end: - vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + ( + vision_config.im_start_token, + vision_config.im_end_token, + ) = tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN] + ) if num_gpus == 1: model.cuda() @@ -101,11 +126,17 @@ def load_model(model_path, model_name, num_gpus): class ModelWorker: - def __init__(self, controller_addr, worker_addr, - worker_id, no_register, - model_path, model_name, - keep_aspect_ratio, - num_gpus): + def __init__( + self, + controller_addr, + worker_addr, + worker_id, + no_register, + model_path, + model_name, + keep_aspect_ratio, + num_gpus, + ): self.controller_addr = controller_addr self.worker_addr = worker_addr self.worker_id = worker_id @@ -113,7 +144,7 @@ class ModelWorker: model_path = model_path[:-1] if model_name is None: model_paths = model_path.split("/") - if model_paths[-1].startswith('checkpoint-'): + if model_paths[-1].startswith("checkpoint-"): self.model_name = model_paths[-2] + "_" + model_paths[-1] else: self.model_name = model_paths[-1] @@ -123,13 +154,15 @@ class ModelWorker: logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") self.keep_aspect_ratio = keep_aspect_ratio self.tokenizer, self.model, self.image_processor, self.context_len = load_model( - model_path, self.model_name, num_gpus) - self.is_multimodal = 'llava' in model_path.lower() + model_path, self.model_name, num_gpus + ) + self.is_multimodal = "llava" in model_path.lower() if not no_register: self.register_to_controller() self.heart_beat_thread = threading.Thread( - target=heart_beat_worker, args=(self,)) + target=heart_beat_worker, args=(self,) + ) self.heart_beat_thread.start() def register_to_controller(self): @@ -139,23 +172,30 @@ class ModelWorker: data = { "worker_name": self.worker_addr, "check_heart_beat": True, - "worker_status": self.get_status() + "worker_status": self.get_status(), } r = requests.post(url, json=data) assert r.status_code == 200 def send_heart_beat(self): - logger.info(f"Send heart beat. Models: {[self.model_name]}. " - f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " - f"global_counter: {global_counter}") + logger.info( + f"Send heart beat. Models: {[self.model_name]}. " + f"Semaphore: {pretty_print_semaphore(model_semaphore)}. 
" + f"global_counter: {global_counter}" + ) url = self.controller_addr + "/receive_heart_beat" while True: try: - ret = requests.post(url, json={ - "worker_name": self.worker_addr, - "queue_length": self.get_queue_length()}, timeout=5) + ret = requests.post( + url, + json={ + "worker_name": self.worker_addr, + "queue_length": self.get_queue_length(), + }, + timeout=5, + ) exist = ret.json()["exist"] break except requests.exceptions.RequestException as e: @@ -169,8 +209,15 @@ class ModelWorker: if model_semaphore is None: return 0 else: - return args.limit_model_concurrency - model_semaphore._value + (len( - model_semaphore._waiters) if model_semaphore._waiters is not None else 0) + return ( + args.limit_model_concurrency + - model_semaphore._value + + ( + len(model_semaphore._waiters) + if model_semaphore._waiters is not None + else 0 + ) + ) def get_status(self): return { @@ -181,20 +228,30 @@ class ModelWorker: @torch.inference_mode() def generate_stream(self, params): - tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor + tokenizer, model, image_processor = ( + self.tokenizer, + self.model, + self.image_processor, + ) prompt = params["prompt"] ori_prompt = prompt images = params.get("images", None) if images is not None and len(images) > 0 and self.is_multimodal: - from PIL import Image - from io import BytesIO import base64 + from io import BytesIO + + from PIL import Image + assert type(images) is list if len(images) > 0: # assert len(images) == 1, "Only support one image for now" - images = [Image.open(BytesIO(base64.b64decode(image))) for image in images] - assert len(images) == prompt.count(DEFAULT_IMAGE_TOKEN), "Number of images does not match number of tokens in prompt" + images = [ + Image.open(BytesIO(base64.b64decode(image))) for image in images + ] + assert len(images) == prompt.count( + DEFAULT_IMAGE_TOKEN + ), "Number of images does not match number of tokens in prompt" if self.keep_aspect_ratio: new_images = [] @@ -203,21 +260,40 @@ class ModelWorker: aspect_ratio = max_hw / min_hw max_len, min_len = 448, 224 shortest_edge = int(min(max_len / aspect_ratio, min_len)) - image = image_processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0] - new_images.append(image.to(self.model.device, dtype=torch.float16)) + image = image_processor.preprocess( + image, + return_tensors="pt", + do_center_crop=False, + size={"shortest_edge": shortest_edge}, + )["pixel_values"][0] + new_images.append( + image.to(self.model.device, dtype=torch.float16) + ) # replace the image token with the image patch token in the prompt (each occurrence) - cur_token_len = (image.shape[1]//14) * (image.shape[2]//14) + cur_token_len = (image.shape[1] // 14) * (image.shape[2] // 14) replace_token = DEFAULT_IMAGE_PATCH_TOKEN * cur_token_len - if getattr(self.model.config, 'mm_use_im_start_end', False): - replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + if getattr(self.model.config, "mm_use_im_start_end", False): + replace_token = ( + DEFAULT_IM_START_TOKEN + + replace_token + + DEFAULT_IM_END_TOKEN + ) prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token, 1) images = new_images else: - images = image_processor(images, return_tensors='pt')['pixel_values'] + images = image_processor(images, return_tensors="pt")[ + "pixel_values" + ] images = images.to(self.model.device, dtype=torch.float16) - replace_token = DEFAULT_IMAGE_PATCH_TOKEN * 256 # HACK: 256 is the max image token 
length hacked - if getattr(self.model.config, 'mm_use_im_start_end', False): - replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + replace_token = ( + DEFAULT_IMAGE_PATCH_TOKEN * 256 + ) # HACK: 256 is the max image token length hacked + if getattr(self.model.config, "mm_use_im_start_end", False): + replace_token = ( + DEFAULT_IM_START_TOKEN + + replace_token + + DEFAULT_IM_END_TOKEN + ) prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) else: images = None @@ -249,18 +325,20 @@ class ModelWorker: for i in range(max_new_tokens): if i == 0: out = model( - torch.as_tensor([input_ids]).cuda(), - use_cache=True, - **image_args) + torch.as_tensor([input_ids]).cuda(), use_cache=True, **image_args + ) logits = out.logits past_key_values = out.past_key_values else: attention_mask = torch.ones( - 1, past_key_values[0][0].shape[-2] + 1, device="cuda") - out = model(input_ids=torch.as_tensor([[token]], device="cuda"), - use_cache=True, - attention_mask=attention_mask, - past_key_values=past_key_values) + 1, past_key_values[0][0].shape[-2] + 1, device="cuda" + ) + out = model( + input_ids=torch.as_tensor([[token]], device="cuda"), + use_cache=True, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) logits = out.logits past_key_values = out.past_key_values @@ -342,7 +420,9 @@ async def generate_stream(request: Request): worker.send_heart_beat() generator = worker.generate_stream_gate(params) background_tasks = BackgroundTasks() - background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) + background_tasks.add_task( + partial(release_model_semaphore, fn=worker.send_heart_beat) + ) return StreamingResponse(generator, background=background_tasks) @@ -355,13 +435,17 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=21002) - parser.add_argument("--worker-address", type=str, - default="http://localhost:21002") - parser.add_argument("--controller-address", type=str, - default="http://localhost:21001") + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) parser.add_argument("--model-path", type=str, default="facebook/opt-350m") parser.add_argument("--model-name", type=str) - parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") + parser.add_argument( + "--multi-modal", + action="store_true", + help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.", + ) parser.add_argument("--keep-aspect-ratio", action="store_true") parser.add_argument("--num-gpus", type=int, default=1) parser.add_argument("--limit-model-concurrency", type=int, default=5) @@ -371,14 +455,18 @@ if __name__ == "__main__": logger.info(f"args: {args}") if args.multi_modal: - logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") - - worker = ModelWorker(args.controller_address, - args.worker_address, - worker_id, - args.no_register, - args.model_path, - args.model_name, - args.keep_aspect_ratio, - args.num_gpus) + logger.warning( + "Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path." 
+ ) + + worker = ModelWorker( + args.controller_address, + args.worker_address, + worker_id, + args.no_register, + args.model_path, + args.model_name, + args.keep_aspect_ratio, + args.num_gpus, + ) uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/model/llava/serve/test_message.py b/model/llava/serve/test_message.py index 6b090faed0e630b03b2294545050f1f4f5032cad..3d5c2576d943624db791332bda8427ef6f70778e 100644 --- a/model/llava/serve/test_message.py +++ b/model/llava/serve/test_message.py @@ -2,7 +2,6 @@ import argparse import json import requests - from llava.conversation import default_conversation @@ -17,8 +16,9 @@ def main(): models.sort() print(f"Models: {models}") - ret = requests.post(controller_addr + "/get_worker_address", - json={"model": args.model_name}) + ret = requests.post( + controller_addr + "/get_worker_address", json={"model": args.model_name} + ) worker_addr = ret.json()["address"] print(f"worker_addr: {worker_addr}") @@ -37,11 +37,17 @@ def main(): "temperature": 0.7, "stop": conv.sep, } - response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, - json=pload, stream=True) + response = requests.post( + worker_addr + "/worker_generate_stream", + headers=headers, + json=pload, + stream=True, + ) print(prompt.replace(conv.sep, "\n"), end="") - for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): + for chunk in response.iter_lines( + chunk_size=8192, decode_unicode=False, delimiter=b"\0" + ): if chunk: data = json.loads(chunk.decode("utf-8")) output = data["text"].split(conv.sep)[-1] @@ -51,12 +57,15 @@ def main(): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--controller-address", type=str, default="http://localhost:21001") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) parser.add_argument("--worker-address", type=str) parser.add_argument("--model-name", type=str, default="facebook/opt-350m") parser.add_argument("--max-new-tokens", type=int, default=32) - parser.add_argument("--message", type=str, default= - "Tell me a story with more than 1000 words.") + parser.add_argument( + "--message", type=str, default="Tell me a story with more than 1000 words." 
+ ) args = parser.parse_args() main() diff --git a/model/llava/train/llama_flash_attn_monkey_patch.py b/model/llava/train/llama_flash_attn_monkey_patch.py index 89f9c3b56fce9b6c8c8be334772686a15c9454d4..66f1f7ab8a2b286f44327d7759ca6b082c4a9d9a 100644 --- a/model/llava/train/llama_flash_attn_monkey_patch.py +++ b/model/llava/train/llama_flash_attn_monkey_patch.py @@ -2,15 +2,13 @@ from typing import List, Optional, Tuple import torch -from torch import nn - import transformers -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb - from einops import rearrange - +from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func -from flash_attn.bert_padding import unpad_input, pad_input +from torch import nn +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + def forward( self, @@ -19,20 +17,28 @@ def forward( attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, use_cache: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel - + attention_mask: [bsz, q_len] """ bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states).view( - bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view( - bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view( - bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] @@ -42,11 +48,9 @@ def forward( offset = past_key_value[0].shape[-2] kv_seq_len += offset cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, - key_states, - cos, - sin, - offset=offset) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, offset=offset + ) # [bsz, nh, t, hd] assert not output_attentions, "output_attentions is not supported" assert not use_cache, "use_cache is not supported" @@ -56,47 +60,55 @@ def forward( # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py # transform the data into the format required by flash attention - qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] - qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] + qkv = torch.stack( + [query_states, key_states, value_states], dim=2 + ) # [bsz, nh, 3, q_len, hd] + qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] # We have disabled _prepare_decoder_attention_mask in LlamaModel # the attention_mask should be the same as the key_padding_mask key_padding_mask = attention_mask - if key_padding_mask is None: - qkv = rearrange(qkv, 'b s ... -> (b s) ...') + qkv = rearrange(qkv, "b s ... 
-> (b s) ...") max_s = q_len - cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, - device=qkv.device) + cu_q_lens = torch.arange( + 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device + ) output = flash_attn_unpadded_qkvpacked_func( - qkv, cu_q_lens, max_s, 0.0, - softmax_scale=None, causal=True + qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True ) - output = rearrange(output, '(b s) ... -> b s ...', b=bsz) + output = rearrange(output, "(b s) ... -> b s ...", b=bsz) else: nheads = qkv.shape[-2] - x = rearrange(qkv, 'b s three h d -> b s (three h d)') + x = rearrange(qkv, "b s three h d -> b s (three h d)") x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) - x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) + x_unpad = rearrange( + x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads + ) output_unpad = flash_attn_unpadded_qkvpacked_func( - x_unpad, cu_q_lens, max_s, 0.0, - softmax_scale=None, causal=True + x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output = rearrange( + pad_input( + rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len + ), + "b s (h d) -> b s h d", + h=nheads, ) - output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), - indices, bsz, q_len), - 'b s (h d) -> b s h d', h=nheads) - return self.o_proj(rearrange(output, - 'b s h d -> b s (h d)')), None, None + return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask -def _prepare_decoder_attention_mask(self, attention_mask, input_shape, - inputs_embeds, past_key_values_length): +def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length +): # [bsz, seq_len] return attention_mask def replace_llama_attn_with_flash_attn(): - transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( + _prepare_decoder_attention_mask + ) transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/model/llava/train/llava_trainer.py b/model/llava/train/llava_trainer.py index 2824f25e92d8893103ad8f32848b749167c5630d..864ad016a65c83b22756be67a43bf6a24a5e50cd 100644 --- a/model/llava/train/llava_trainer.py +++ b/model/llava/train/llava_trainer.py @@ -1,9 +1,9 @@ import os +from typing import Dict, Optional, Sequence + import torch import torch.nn as nn - from transformers import Trainer -from typing import Dict, Optional, Sequence def unwrap_model(model: nn.Module) -> nn.Module: @@ -21,9 +21,8 @@ def unwrap_model(model: nn.Module) -> nn.Module: class LLaVATrainer(Trainer): - def _save(self, output_dir: Optional[str] = None, state_dict=None): - if getattr(self.args, 'tune_mm_mlp_adapter', False): + if getattr(self.args, "tune_mm_mlp_adapter", False): # Save the model _state_dict = state_dict if _state_dict is None: @@ -32,18 +31,23 @@ class LLaVATrainer(Trainer): _state_dict = model_to_save.state_dict() weight_to_save = {} - keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in'] + keys_to_match = ["mm_projector", "embed_tokens", "embed_in"] for k, v in _state_dict.items(): if any(key_match in k for key_match in keys_to_match): weight_to_save[k] = v - current_folder = 
output_dir.split('/')[-1] + current_folder = output_dir.split("/")[-1] parent_folder = os.path.dirname(output_dir) - if current_folder.startswith('checkpoint-'): + if current_folder.startswith("checkpoint-"): mm_projector_folder = os.path.join(parent_folder, "mm_projector") os.makedirs(mm_projector_folder, exist_ok=True) - torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) + torch.save( + weight_to_save, + os.path.join(mm_projector_folder, f"{current_folder}.bin"), + ) else: - torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) + torch.save( + weight_to_save, os.path.join(output_dir, f"mm_projector.bin") + ) super(LLaVATrainer, self)._save(output_dir, state_dict) diff --git a/model/llava/train/train.py b/model/llava/train/train.py index 49f7a0d5e33c7c082aa7e9857344a29a203dce13..f76872b3acd023cb5e69e32c239bd0ae5503443d 100644 --- a/model/llava/train/train.py +++ b/model/llava/train/train.py @@ -14,25 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import copy -from dataclasses import dataclass, field import json import logging +import os import pathlib +from dataclasses import dataclass, field from typing import Dict, Optional, Sequence import torch - +import torch.nn as nn import transformers -from torch.utils.data import Dataset -from llava.train.llava_trainer import LLaVATrainer - from llava import conversation as conversation_lib from llava.model import * - +from llava.train.llava_trainer import LLaVATrainer from PIL import Image -import torch.nn as nn +from torch.utils.data import Dataset # TODO: import and use code from ../data/dataset.py @@ -54,21 +51,24 @@ class ModelArguments: freeze_backbone: bool = field(default=False) tune_mm_mlp_adapter: bool = field(default=False) vision_tower: Optional[str] = field(default=None) - mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer + mm_vision_select_layer: Optional[int] = field( + default=-1 + ) # default to the last layer pretrain_mm_mlp_adapter: Optional[str] = field(default=None) mm_use_im_start_end: bool = field(default=False) @dataclass class DataArguments: - data_path: str = field(default=None, - metadata={"help": "Path to the training data."}) + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) lazy_preprocess: bool = False is_multimodal: bool = False sep_image_conv_front: bool = False image_token_len: int = 0 image_folder: Optional[str] = field(default=None) - image_aspect_ratio: str = 'square' + image_aspect_ratio: str = "square" @dataclass @@ -81,21 +81,16 @@ class TrainingArguments(transformers.TrainingArguments): model_max_length: int = field( default=512, metadata={ - "help": - "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 
}, ) -def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, - output_dir: str): +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): """Collects the state dict and dump to disk.""" state_dict = trainer.model.state_dict() if trainer.args.should_save: - cpu_state_dict = { - key: value.cpu() - for key, value in state_dict.items() - } + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} del state_dict trainer._save(output_dir, state_dict=cpu_state_dict) # noqa @@ -117,16 +112,19 @@ def smart_tokenizer_and_embedding_resize( output_embeddings = model.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True) + dim=0, keepdim=True + ) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( - dim=0, keepdim=True) + dim=0, keepdim=True + ) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg -def _tokenize_fn(strings: Sequence[str], - tokenizer: transformers.PreTrainedTokenizer) -> Dict: +def _tokenize_fn( + strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer +) -> Dict: """Tokenize a list of strings.""" tokenized_list = [ tokenizer( @@ -135,11 +133,10 @@ def _tokenize_fn(strings: Sequence[str], padding="longest", max_length=tokenizer.model_max_length, truncation=True, - ) for text in strings - ] - input_ids = labels = [ - tokenized.input_ids[0] for tokenized in tokenized_list + ) + for text in strings ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] input_ids_lens = labels_lens = [ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list @@ -159,7 +156,7 @@ def _mask_targets(target, tokenized_lens, speakers): target[:cur_idx] = IGNORE_INDEX for tokenized_len, speaker in zip(tokenized_lens, speakers): if speaker == "human": - target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX + target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX cur_idx += tokenized_len @@ -175,9 +172,10 @@ def _add_speaker_and_signal(header, source, get_conversation=True): elif from_str.lower() == "gpt": from_str = conversation_lib.default_conversation.roles[1] else: - from_str = 'unknown' - sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + - sentence["value"] + END_SIGNAL) + from_str = "unknown" + sentence["value"] = ( + BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL + ) if get_conversation: conversation += sentence["value"] conversation += BEGIN_SIGNAL @@ -189,22 +187,34 @@ def preprocess_multimodal( multimodal_cfg: dict, cur_token_len: int, ) -> Dict: - is_multimodal = multimodal_cfg['is_multimodal'] + is_multimodal = multimodal_cfg["is_multimodal"] # image_token_len = multimodal_cfg['image_token_len'] image_token_len = cur_token_len if not is_multimodal: return sources for source in sources: - if multimodal_cfg['sep_image_conv_front']: - assert DEFAULT_IMAGE_TOKEN in source[0]['value'] - source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() - source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value'] + if multimodal_cfg["sep_image_conv_front"]: + assert DEFAULT_IMAGE_TOKEN in source[0]["value"] + source[0]["value"] = ( + source[0]["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip() + ) + source[0]["value"] = ( + DEFAULT_IMAGE_TOKEN + + conversation_lib.default_conversation.sep + + 
conversation_lib.default_conversation.roles[0] + + ": " + + source[0]["value"] + ) for sentence in source: replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len - if multimodal_cfg['use_im_start_end']: - replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN - sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + if multimodal_cfg["use_im_start_end"]: + replace_token = ( + DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + ) + sentence["value"] = sentence["value"].replace( + DEFAULT_IMAGE_TOKEN, replace_token + ) return sources @@ -279,6 +289,7 @@ def preprocess_v1( labels=targets, ) + def preprocess_mpt( sources, tokenizer: transformers.PreTrainedTokenizer, @@ -317,9 +328,11 @@ def preprocess_mpt( total_len = int(target.ne(tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep) - re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt + re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt for conv_idx in range(3, len(rounds), 2): - re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt + re_rounds.append( + conv.sep.join(rounds[conv_idx : conv_idx + 2]) + ) # user + gpt cur_len = 0 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(re_rounds): @@ -330,7 +343,9 @@ def preprocess_mpt( if len(parts) != 2: break parts[0] += sep - round_len = len(tokenizer(rou).input_ids) + len(tokenizer(conv.sep).input_ids) + round_len = len(tokenizer(rou).input_ids) + len( + tokenizer(conv.sep).input_ids + ) instruction_len = len(tokenizer(parts[0]).input_ids) target[cur_len : cur_len + instruction_len] = IGNORE_INDEX @@ -377,8 +392,9 @@ def preprocess( input_ids = conversations_tokenized["input_ids"] targets = copy.deepcopy(input_ids) for target, source in zip(targets, sources): - tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], - tokenizer)["input_ids_lens"] + tokenized_lens = _tokenize_fn( + [header] + [s["value"] for s in source], tokenizer + )["input_ids_lens"] speakers = [sentence["from"] for sentence in source] _mask_targets(target, tokenized_lens, speakers) @@ -388,8 +404,7 @@ def preprocess( class SupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" - def __init__(self, data_path: str, - tokenizer: transformers.PreTrainedTokenizer): + def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer): super(SupervisedDataset, self).__init__() logging.warning("Loading data...") list_data_dict = json.load(open(data_path, "r")) @@ -411,9 +426,12 @@ class SupervisedDataset(Dataset): class LazySupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" - def __init__(self, data_path: str, - tokenizer: transformers.PreTrainedTokenizer, - multimodal_cfg: dict): + def __init__( + self, + data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + multimodal_cfg: dict, + ): super(LazySupervisedDataset, self).__init__() logging.warning("Loading data...") list_data_dict = json.load(open(data_path, "r")) @@ -431,54 +449,74 @@ class LazySupervisedDataset(Dataset): if isinstance(i, int): sources = [sources] assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME - if 'image' in sources[0]: - image_file = self.list_data_dict[i]['image'] - image_folder = self.multimodal_cfg['image_folder'] - processor = self.multimodal_cfg['image_processor'] - image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') - if self.multimodal_cfg['image_aspect_ratio'] == 'keep': + if "image" in sources[0]: + 
image_file = self.list_data_dict[i]["image"] + image_folder = self.multimodal_cfg["image_folder"] + processor = self.multimodal_cfg["image_processor"] + image = Image.open(os.path.join(image_folder, image_file)).convert("RGB") + if self.multimodal_cfg["image_aspect_ratio"] == "keep": max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw max_len, min_len = 448, 224 shortest_edge = int(min(max_len / aspect_ratio, min_len)) - image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0] - elif self.multimodal_cfg['image_aspect_ratio'] == 'pad': + image = processor.preprocess( + image, + return_tensors="pt", + do_center_crop=False, + size={"shortest_edge": shortest_edge}, + )["pixel_values"][0] + elif self.multimodal_cfg["image_aspect_ratio"] == "pad": + def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: - result = Image.new(pil_img.mode, (width, width), background_color) + result = Image.new( + pil_img.mode, (width, width), background_color + ) result.paste(pil_img, (0, (width - height) // 2)) return result else: - result = Image.new(pil_img.mode, (height, height), background_color) + result = Image.new( + pil_img.mode, (height, height), background_color + ) result.paste(pil_img, ((height - width) // 2, 0)) return result - image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) - image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + + image = expand2square( + image, tuple(int(x * 255) for x in processor.image_mean) + ) + image = processor.preprocess(image, return_tensors="pt")[ + "pixel_values" + ][0] else: - image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] - cur_token_len = (image.shape[1]//14) * (image.shape[2]//14) # FIXME: 14 is hardcoded patch size + image = processor.preprocess(image, return_tensors="pt")[ + "pixel_values" + ][0] + cur_token_len = (image.shape[1] // 14) * ( + image.shape[2] // 14 + ) # FIXME: 14 is hardcoded patch size sources = preprocess_multimodal( copy.deepcopy([e["conversations"] for e in sources]), - self.multimodal_cfg, cur_token_len) + self.multimodal_cfg, + cur_token_len, + ) else: sources = copy.deepcopy([e["conversations"] for e in sources]) - data_dict = preprocess( - sources, - self.tokenizer) + data_dict = preprocess(sources, self.tokenizer) if isinstance(i, int): - data_dict = dict(input_ids=data_dict["input_ids"][0], - labels=data_dict["labels"][0]) + data_dict = dict( + input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0] + ) # image exist in the data - if 'image' in self.list_data_dict[i]: - data_dict['image'] = image - elif self.multimodal_cfg['is_multimodal']: + if "image" in self.list_data_dict[i]: + data_dict["image"] = image + elif self.multimodal_cfg["is_multimodal"]: # image does not exist in the data, but the model is multimodal - crop_size = self.multimodal_cfg['image_processor'].crop_size - data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) + crop_size = self.multimodal_cfg["image_processor"].crop_size + data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"]) return data_dict @@ -489,59 +527,65 @@ class DataCollatorForSupervisedDataset(object): tokenizer: transformers.PreTrainedTokenizer def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: - input_ids, labels = tuple([instance[key] for instance in instances] - 
for key in ("input_ids", "labels")) + input_ids, labels = tuple( + [instance[key] for instance in instances] for key in ("input_ids", "labels") + ) input_ids = torch.nn.utils.rnn.pad_sequence( - input_ids, - batch_first=True, - padding_value=self.tokenizer.pad_token_id) - labels = torch.nn.utils.rnn.pad_sequence(labels, - batch_first=True, - padding_value=IGNORE_INDEX) + input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id + ) + labels = torch.nn.utils.rnn.pad_sequence( + labels, batch_first=True, padding_value=IGNORE_INDEX + ) batch = dict( input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ) - if 'image' in instances[0]: - images = [instance['image'] for instance in instances] + if "image" in instances[0]: + images = [instance["image"] for instance in instances] if all(x is not None and x.shape == images[0].shape for x in images): - batch['images'] = torch.stack(images) + batch["images"] = torch.stack(images) else: - batch['images'] = images + batch["images"] = images return batch -def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, - data_args) -> Dict: +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args +) -> Dict: """Make dataset and collator for supervised fine-tuning.""" - dataset_cls = (LazySupervisedDataset - if data_args.lazy_preprocess else SupervisedDataset) - train_dataset = dataset_cls(tokenizer=tokenizer, - data_path=data_args.data_path, - multimodal_cfg=dict( - is_multimodal=data_args.is_multimodal, - sep_image_conv_front=data_args.sep_image_conv_front, - image_token_len=data_args.image_token_len, - image_folder=data_args.image_folder, - image_aspect_ratio=data_args.image_aspect_ratio, - use_im_start_end=getattr(data_args, 'mm_use_im_start_end', False), - image_processor=getattr(data_args, 'image_processor', None))) + dataset_cls = ( + LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset + ) + train_dataset = dataset_cls( + tokenizer=tokenizer, + data_path=data_args.data_path, + multimodal_cfg=dict( + is_multimodal=data_args.is_multimodal, + sep_image_conv_front=data_args.sep_image_conv_front, + image_token_len=data_args.image_token_len, + image_folder=data_args.image_folder, + image_aspect_ratio=data_args.image_aspect_ratio, + use_im_start_end=getattr(data_args, "mm_use_im_start_end", False), + image_processor=getattr(data_args, "image_processor", None), + ), + ) data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) - return dict(train_dataset=train_dataset, - eval_dataset=None, - data_collator=data_collator) + return dict( + train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator + ) def train(): parser = transformers.HfArgumentParser( - (ModelArguments, DataArguments, TrainingArguments)) + (ModelArguments, DataArguments, TrainingArguments) + ) model_args, data_args, training_args = parser.parse_args_into_dataclasses() if model_args.vision_tower is not None: - if 'mpt' in model_args.model_name_or_path: + if "mpt" in model_args.model_name_or_path: model = LlavaMPTForCausalLM.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, @@ -561,12 +605,12 @@ def train(): if model_args.freeze_backbone: model.model.requires_grad_(False) - if 'mpt' in model_args.model_name_or_path: + if "mpt" in model_args.model_name_or_path: tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=training_args.cache_dir, 
model_max_length=training_args.model_max_length, - padding_side="right" + padding_side="right", ) else: tokenizer = transformers.AutoTokenizer.from_pretrained( @@ -585,23 +629,29 @@ def train(): model=model, ) if "llama" in model_args.model_name_or_path: - tokenizer.add_special_tokens({ - "eos_token": DEFAULT_EOS_TOKEN, - "bos_token": DEFAULT_BOS_TOKEN, - "unk_token": DEFAULT_UNK_TOKEN, - }) + tokenizer.add_special_tokens( + { + "eos_token": DEFAULT_EOS_TOKEN, + "bos_token": DEFAULT_BOS_TOKEN, + "unk_token": DEFAULT_UNK_TOKEN, + } + ) else: tokenizer.pad_token = tokenizer.unk_token if "mpt" in model_args.model_name_or_path: - conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"] + conversation_lib.default_conversation = conversation_lib.conv_templates[ + "mpt" + ] else: - conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1_1"] + conversation_lib.default_conversation = conversation_lib.conv_templates[ + "vicuna_v1_1" + ] if model_args.vision_tower is not None: model_vision_dict = model.get_model().initialize_vision_modules( vision_tower=model_args.vision_tower, mm_vision_select_layer=model_args.mm_vision_select_layer, - pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter + pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter, ) dtype = torch.float32 if training_args.fp16: @@ -609,13 +659,15 @@ def train(): if training_args.bf16: dtype = torch.bfloat16 model.get_model().vision_tower[0].to(dtype=dtype, device=training_args.device) - vision_config = model_vision_dict['vision_config'] + vision_config = model_vision_dict["vision_config"] - data_args.image_token_len = model_vision_dict['image_token_len'] - data_args.image_processor = model_vision_dict['image_processor'] + data_args.image_token_len = model_vision_dict["image_token_len"] + data_args.image_processor = model_vision_dict["image_processor"] data_args.is_multimodal = True - model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter + model.config.tune_mm_mlp_adapter = ( + training_args.tune_mm_mlp_adapter + ) = model_args.tune_mm_mlp_adapter if model_args.tune_mm_mlp_adapter: model.requires_grad_(False) for p in model.get_model().mm_projector.parameters(): @@ -626,45 +678,66 @@ def train(): for p in model.get_model().mm_projector.parameters(): p.requires_grad = False - model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end - vision_config.use_im_start_end = training_args.use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_use_im_start_end = ( + data_args.mm_use_im_start_end + ) = model_args.mm_use_im_start_end + vision_config.use_im_start_end = ( + training_args.use_im_start_end + ) = model_args.mm_use_im_start_end model.config.sep_image_conv_front = data_args.sep_image_conv_front - model.initialize_vision_tokenizer(mm_use_im_start_end=model_args.mm_use_im_start_end, tokenizer=tokenizer, device=training_args.device, - tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter, pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter) + model.initialize_vision_tokenizer( + mm_use_im_start_end=model_args.mm_use_im_start_end, + tokenizer=tokenizer, + device=training_args.device, + tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter, + pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter, + ) params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad] if len(params_no_grad) > 0: if training_args.fsdp is not None and len(training_args.fsdp) > 0: if 
len(params_no_grad) < 10: - print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad)) + print( + "[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}".format( + len(params_no_grad), params_no_grad + ) + ) else: - print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10]))) - print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.") - print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining") + print( + "[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)".format( + len(params_no_grad), ", ".join(params_no_grad[:10]) + ) + ) + print( + "[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental." + ) + print( + "[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining" + ) + + from torch.distributed.fsdp.fully_sharded_data_parallel import \ + FullyShardedDataParallel as FSDP - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP def patch_FSDP_use_orig_params(func): def wrap_func(*args, **kwargs): - use_orig_params = kwargs.pop('use_orig_params', True) + use_orig_params = kwargs.pop("use_orig_params", True) return func(*args, **kwargs, use_orig_params=use_orig_params) + return wrap_func FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__) - data_module = make_supervised_data_module(tokenizer=tokenizer, - data_args=data_args) - trainer = LLaVATrainer(model=model, - tokenizer=tokenizer, - args=training_args, - **data_module) + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + trainer = LLaVATrainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): trainer.train(resume_from_checkpoint=True) else: trainer.train() trainer.save_state() - safe_save_model_for_hf_trainer(trainer=trainer, - output_dir=training_args.output_dir) + safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) if __name__ == "__main__": diff --git a/model/llava/train/train_mem.py b/model/llava/train/train_mem.py index 2487d317855b27d5b07a755ee0389667e4964f02..f3940cf7fea248d055a9cb333a08ebca0f782885 100644 --- a/model/llava/train/train_mem.py +++ b/model/llava/train/train_mem.py @@ -3,7 +3,8 @@ # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. # Need to call this before importing transformers. -from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn +from llava.train.llama_flash_attn_monkey_patch import \ + replace_llama_attn_with_flash_attn replace_llama_attn_with_flash_attn() diff --git a/model/llava/utils.py b/model/llava/utils.py index 8f7163c0ba1d9a81d81a950bce61e0f0db06066e..0a2d5fd533ded77352f5548a0ed027b700365ea4 100644 --- a/model/llava/utils.py +++ b/model/llava/utils.py @@ -5,11 +5,14 @@ import os import sys import requests - from llava.constants import LOGDIR -server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. 
PLEASE REGENERATE OR REFRESH THIS PAGE.**" -moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." +server_error_msg = ( + "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +) +moderation_msg = ( + "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." +) handler = None @@ -47,7 +50,8 @@ def build_logger(logger_name, logger_filename): os.makedirs(LOGDIR, exist_ok=True) filename = os.path.join(LOGDIR, logger_filename) handler = logging.handlers.TimedRotatingFileHandler( - filename, when='D', utc=True) + filename, when="D", utc=True + ) handler.setFormatter(formatter) for name, item in logging.root.manager.loggerDict.items(): @@ -61,33 +65,34 @@ class StreamToLogger(object): """ Fake file-like stream object that redirects writes to a logger instance. """ + def __init__(self, logger, log_level=logging.INFO): self.terminal = sys.stdout self.logger = logger self.log_level = log_level - self.linebuf = '' + self.linebuf = "" def __getattr__(self, attr): return getattr(self.terminal, attr) def write(self, buf): temp_linebuf = self.linebuf + buf - self.linebuf = '' + self.linebuf = "" for line in temp_linebuf.splitlines(True): # From the io.TextIOWrapper docs: # On output, if newline is None, any '\n' characters written # are translated to the system default line separator. # By default sys.stdout.write() expects '\n' newlines and then # translates them so this is still cross platform. - if line[-1] == '\n': + if line[-1] == "\n": self.logger.log(self.log_level, line.rstrip()) else: self.linebuf += line def flush(self): - if self.linebuf != '': + if self.linebuf != "": self.logger.log(self.log_level, self.linebuf.rstrip()) - self.linebuf = '' + self.linebuf = "" def disable_torch_init(): @@ -95,6 +100,7 @@ def disable_torch_init(): Disable the redundant torch default initialization to accelerate model creation. """ import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) @@ -104,8 +110,10 @@ def violates_moderation(text): Check whether the text violates OpenAI moderation API. """ url = "https://api.openai.com/v1/moderations" - headers = {"Content-Type": "application/json", - "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"], + } text = text.replace("\n", "") data = "{" + '"input": ' + f'"{text}"' + "}" data = data.encode("utf-8") diff --git a/model/segment_anything/__init__.py b/model/segment_anything/__init__.py index 34383d83f5e76bc801f31b20e5651e383be348b6..e66218b2edd8754f1546ad1dca8b604ce891c365 100755 --- a/model/segment_anything/__init__.py +++ b/model/segment_anything/__init__.py @@ -4,12 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-from .build_sam import ( - build_sam, - build_sam_vit_h, - build_sam_vit_l, - build_sam_vit_b, - sam_model_registry, -) -from .predictor import SamPredictor from .automatic_mask_generator import SamAutomaticMaskGenerator +from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h, + build_sam_vit_l, sam_model_registry) +from .predictor import SamPredictor diff --git a/model/segment_anything/automatic_mask_generator.py b/model/segment_anything/automatic_mask_generator.py index d5a8c969207f119feff7087f94e044403acdff00..aa4bc4f0324cf7f91ded55a0993b51deeec41537 100755 --- a/model/segment_anything/automatic_mask_generator.py +++ b/model/segment_anything/automatic_mask_generator.py @@ -4,32 +4,21 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from typing import Any, Dict, List, Optional, Tuple + import numpy as np import torch from torchvision.ops.boxes import batched_nms, box_area # type: ignore -from typing import Any, Dict, List, Optional, Tuple - from .modeling import Sam from .predictor import SamPredictor -from .utils.amg import ( - MaskData, - area_from_rle, - batch_iterator, - batched_mask_to_box, - box_xyxy_to_xywh, - build_all_layer_point_grids, - calculate_stability_score, - coco_encode_rle, - generate_crop_boxes, - is_box_near_crop_edge, - mask_to_rle_pytorch, - remove_small_regions, - rle_to_mask, - uncrop_boxes_xyxy, - uncrop_masks, - uncrop_points, -) +from .utils.amg import (MaskData, area_from_rle, batch_iterator, + batched_mask_to_box, box_xyxy_to_xywh, + build_all_layer_point_grids, calculate_stability_score, + coco_encode_rle, generate_crop_boxes, + is_box_near_crop_edge, mask_to_rle_pytorch, + remove_small_regions, rle_to_mask, uncrop_boxes_xyxy, + uncrop_masks, uncrop_points) class SamAutomaticMaskGenerator: @@ -115,7 +104,8 @@ class SamAutomaticMaskGenerator: "coco_rle", ], f"Unknown output_mode {output_mode}." 
if output_mode == "coco_rle": - from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + from pycocotools import \ + mask as mask_utils # type: ignore # noqa: F401 if min_mask_region_area > 0: import cv2 # type: ignore # noqa: F401 @@ -172,7 +162,9 @@ class SamAutomaticMaskGenerator: # Encode masks if self.output_mode == "coco_rle": - mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + mask_data["segmentations"] = [ + coco_encode_rle(rle) for rle in mask_data["rles"] + ] elif self.output_mode == "binary_mask": mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] else: @@ -242,7 +234,9 @@ class SamAutomaticMaskGenerator: # Generate masks for this crop in batches data = MaskData() for (points,) in batch_iterator(self.points_per_batch, points_for_image): - batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + batch_data = self._process_batch( + points, cropped_im_size, crop_box, orig_size + ) data.cat(batch_data) del batch_data self.predictor.reset_image() @@ -275,7 +269,9 @@ class SamAutomaticMaskGenerator: # Run model on this batch transformed_points = self.predictor.transform.apply_coords(points, im_size) in_points = torch.as_tensor(transformed_points, device=self.predictor.device) - in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + in_labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) masks, iou_preds, _ = self.predictor.predict_torch( in_points[:, None, :], in_labels[:, None], @@ -298,7 +294,9 @@ class SamAutomaticMaskGenerator: # Calculate stability score data["stability_score"] = calculate_stability_score( - data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + data["masks"], + self.predictor.model.mask_threshold, + self.stability_score_offset, ) if self.stability_score_thresh > 0.0: keep_mask = data["stability_score"] >= self.stability_score_thresh @@ -309,7 +307,9 @@ class SamAutomaticMaskGenerator: data["boxes"] = batched_mask_to_box(data["masks"]) # Filter boxes that touch crop boundaries - keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + keep_mask = ~is_box_near_crop_edge( + data["boxes"], crop_box, [0, 0, orig_w, orig_h] + ) if not torch.all(keep_mask): data.filter(keep_mask) diff --git a/model/segment_anything/build_sam.py b/model/segment_anything/build_sam.py index 2f85cebcc30a0c410453ee257c1f1b7091872b0a..788d25ad5a6fd32c112201301b320f5884d6e8e8 100755 --- a/model/segment_anything/build_sam.py +++ b/model/segment_anything/build_sam.py @@ -4,11 +4,12 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import torch - from functools import partial -from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer +import torch + +from .modeling import (ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, + TwoWayTransformer) def build_sam_vit_h(checkpoint=None): diff --git a/model/segment_anything/modeling/__init__.py b/model/segment_anything/modeling/__init__.py index 38e906243d898d7fc071c0fe218338c5cace3ea1..088af386e5b45d14e99d11dec132821ddba5df39 100755 --- a/model/segment_anything/modeling/__init__.py +++ b/model/segment_anything/modeling/__init__.py @@ -4,8 +4,8 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-from .sam import Sam from .image_encoder import ImageEncoderViT from .mask_decoder import MaskDecoder from .prompt_encoder import PromptEncoder +from .sam import Sam from .transformer import TwoWayTransformer diff --git a/model/segment_anything/modeling/common.py b/model/segment_anything/modeling/common.py index 2bf15236a3eb24d8526073bc4fa2b274cccb3f96..e8727816d4861a2d0c7c367879951d1d4fa791fb 100755 --- a/model/segment_anything/modeling/common.py +++ b/model/segment_anything/modeling/common.py @@ -4,11 +4,11 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from typing import Type + import torch import torch.nn as nn -from typing import Type - class MLPBlock(nn.Module): def __init__( diff --git a/model/segment_anything/modeling/image_encoder.py b/model/segment_anything/modeling/image_encoder.py index b34026cc61c09550549a6f3e6d932a1e19e308c6..b472a3d6b7a609134afe18d7f8740e0c01a56842 100755 --- a/model/segment_anything/modeling/image_encoder.py +++ b/model/segment_anything/modeling/image_encoder.py @@ -4,12 +4,12 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from typing import Optional, Tuple, Type + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Optional, Tuple, Type - from .common import LayerNorm2d, MLPBlock @@ -68,7 +68,9 @@ class ImageEncoderViT(nn.Module): if use_abs_pos: # Initialize absolute positional embedding with pretrain image size. self.pos_embed = nn.Parameter( - torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + torch.zeros( + 1, img_size // patch_size, img_size // patch_size, embed_dim + ) ) self.blocks = nn.ModuleList() @@ -106,7 +108,6 @@ class ImageEncoderViT(nn.Module): ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.patch_embed(x) if self.pos_embed is not None: x = x + self.pos_embed @@ -115,8 +116,8 @@ class ImageEncoderViT(nn.Module): x = blk(x) dtype = x.dtype - if dtype == torch.float16: # prevent overflow - with torch.autocast(device_type='cuda', dtype=torch.float32): + if dtype == torch.float16: # prevent overflow + with torch.autocast(device_type="cuda", dtype=torch.float32): x = self.neck(x.permute(0, 3, 1, 2)) x = x.to(dtype) else: @@ -167,7 +168,9 @@ class Block(nn.Module): ) self.norm2 = norm_layer(dim) - self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + self.mlp = MLPBlock( + embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer + ) self.window_size = window_size @@ -232,23 +235,34 @@ class Attention(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: B, H, W, _ = x.shape # qkv with shape (3, B, nHead, H * W, C) - qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + qkv = ( + self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + ) # q, k, v with shape (B * nHead, H * W, C) q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) attn = (q * self.scale) @ k.transpose(-2, -1) if self.use_rel_pos: - attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + attn = add_decomposed_rel_pos( + attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W) + ) attn = attn.softmax(dim=-1) - x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = ( + (attn @ v) + .view(B, self.num_heads, H, W, -1) + .permute(0, 2, 3, 1, 4) + .reshape(B, H, W, -1) + ) x 
= self.proj(x) return x -def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: +def window_partition( + x: torch.Tensor, window_size: int +) -> Tuple[torch.Tensor, Tuple[int, int]]: """ Partition into non-overlapping windows with padding if needed. Args: @@ -268,12 +282,17 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T Hp, Wp = H + pad_h, W + pad_w x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) return windows, (Hp, Wp) def window_unpartition( - windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] + windows: torch.Tensor, + window_size: int, + pad_hw: Tuple[int, int], + hw: Tuple[int, int], ) -> torch.Tensor: """ Window unpartition into original sequences and removing padding. @@ -289,7 +308,9 @@ def window_unpartition( Hp, Wp = pad_hw H, W = hw B = windows.shape[0] // (Hp * Wp // window_size // window_size) - x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) if Hp > H or Wp > W: @@ -363,7 +384,9 @@ def add_decomposed_rel_pos( rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) attn = ( - attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + attn.view(B, q_h, q_w, k_h, k_w) + + rel_h[:, :, :, :, None] + + rel_w[:, :, :, None, :] ).view(B, q_h * q_w, k_h * k_w) return attn diff --git a/model/segment_anything/modeling/mask_decoder.py b/model/segment_anything/modeling/mask_decoder.py index f7c0f1be2ce3dbee6b0f32656a51fe9c48c353e3..105bc9206e0fc2b1ceef69a31f4a16ae07e37a94 100755 --- a/model/segment_anything/modeling/mask_decoder.py +++ b/model/segment_anything/modeling/mask_decoder.py @@ -4,12 +4,12 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from typing import List, Tuple, Type + import torch from torch import nn from torch.nn import functional as F -from typing import List, Tuple, Type - from .common import LayerNorm2d @@ -51,10 +51,14 @@ class MaskDecoder(nn.Module): self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) self.output_upscaling = nn.Sequential( - nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), LayerNorm2d(transformer_dim // 4), activation(), - nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), activation(), ) self.output_hypernetworks_mlps = nn.ModuleList( @@ -118,9 +122,13 @@ class MaskDecoder(nn.Module): ) -> Tuple[torch.Tensor, torch.Tensor]: """Predicts masks. 
See 'forward' for more details.""" # Concatenate output tokens - output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) - output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) - + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + # sparse_prompt_embeddings = sparse_prompt_embeddings.half() tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) @@ -143,10 +151,14 @@ class MaskDecoder(nn.Module): upscaled_embedding = self.output_upscaling(src) hyper_in_list: List[torch.Tensor] = [] for i in range(self.num_mask_tokens): - hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) hyper_in = torch.stack(hyper_in_list, dim=1) b, c, h, w = upscaled_embedding.shape - masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, self.num_mask_tokens, h, w) + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view( + b, self.num_mask_tokens, h, w + ) # Generate mask quality predictions iou_pred = self.iou_prediction_head(iou_token_out) diff --git a/model/segment_anything/modeling/prompt_encoder.py b/model/segment_anything/modeling/prompt_encoder.py index c08726b353e94e6b324759655bea4ab11238628d..16bc3a45e75f154453ed0724c70ce8daa0324c81 100755 --- a/model/segment_anything/modeling/prompt_encoder.py +++ b/model/segment_anything/modeling/prompt_encoder.py @@ -4,12 +4,12 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from typing import Any, Optional, Tuple, Type + import numpy as np import torch from torch import nn -from typing import Any, Optional, Tuple, Type - from .common import LayerNorm2d @@ -43,11 +43,16 @@ class PromptEncoder(nn.Module): self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners - point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] self.point_embeddings = nn.ModuleList(point_embeddings) self.not_a_point_embed = nn.Embedding(1, embed_dim) - self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) self.mask_downscaling = nn.Sequential( nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), LayerNorm2d(mask_in_chans // 4), @@ -83,7 +88,9 @@ class PromptEncoder(nn.Module): padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) points = torch.cat([points, padding_point], dim=1) labels = torch.cat([labels, padding_label], dim=1) - point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) point_embedding[labels == -1] = 0.0 point_embedding[labels == -1] += self.not_a_point_embed.weight point_embedding[labels == 0] += self.point_embeddings[0].weight @@ -94,7 +101,9 @@ class PromptEncoder(nn.Module): """Embeds box prompts.""" boxes = boxes + 0.5 # Shift to center of pixel coords = boxes.reshape(-1, 2, 2) - corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding = 
self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) corner_embedding[:, 0, :] += self.point_embeddings[2].weight corner_embedding[:, 1, :] += self.point_embeddings[3].weight return corner_embedding @@ -153,7 +162,9 @@ class PromptEncoder(nn.Module): Bx(embed_dim)x(embed_H)x(embed_W) """ bs = self._get_batch_size(points, boxes, masks, text_embeds) - sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) if points is not None: coords, labels = points point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) @@ -206,7 +217,9 @@ class PositionEmbeddingRandom(nn.Module): """Generate positional encoding for a grid of the specified size.""" h, w = size device: Any = self.positional_encoding_gaussian_matrix.device - grid = torch.ones((h, w), device=device, dtype=self.positional_encoding_gaussian_matrix.dtype) + grid = torch.ones( + (h, w), device=device, dtype=self.positional_encoding_gaussian_matrix.dtype + ) y_embed = grid.cumsum(dim=0) - 0.5 x_embed = grid.cumsum(dim=1) - 0.5 y_embed = y_embed / h diff --git a/model/segment_anything/modeling/sam.py b/model/segment_anything/modeling/sam.py index c857fd5aaad4e696b56aa2fa7c1b23ddf0ca569d..f1d82cac3cc1deea45171fd9360dfd7fa25e457a 100755 --- a/model/segment_anything/modeling/sam.py +++ b/model/segment_anything/modeling/sam.py @@ -4,12 +4,12 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from typing import Any, Dict, List, Tuple + import torch from torch import nn from torch.nn import functional as F -from typing import Any, Dict, List, Tuple - from .image_encoder import ImageEncoderViT from .mask_decoder import MaskDecoder from .prompt_encoder import PromptEncoder @@ -43,7 +43,9 @@ class Sam(nn.Module): self.image_encoder = image_encoder self.prompt_encoder = prompt_encoder self.mask_decoder = mask_decoder - self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer( + "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False + ) self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) @property @@ -94,7 +96,9 @@ class Sam(nn.Module): shape BxCxHxW, where H=W=256. Can be passed as mask input to subsequent iterations of prediction. """ - input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + input_images = torch.stack( + [self.preprocess(x["image"]) for x in batched_input], dim=0 + ) image_embeddings = self.image_encoder(input_images) outputs = [] @@ -162,7 +166,9 @@ class Sam(nn.Module): ) # masks = masks.to(dtype) masks = masks[..., : input_size[0], : input_size[1]] - masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + masks = F.interpolate( + masks, original_size, mode="bilinear", align_corners=False + ) return masks def preprocess(self, x: torch.Tensor) -> torch.Tensor: diff --git a/model/segment_anything/modeling/transformer.py b/model/segment_anything/modeling/transformer.py index 28fafea52288603fea275f3a100790471825c34a..8c511e4ff35cc91132b09edd788c96f9a5768161 100755 --- a/model/segment_anything/modeling/transformer.py +++ b/model/segment_anything/modeling/transformer.py @@ -4,12 +4,12 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-import torch -from torch import Tensor, nn - import math from typing import Tuple, Type +import torch +from torch import Tensor, nn + from .common import MLPBlock @@ -198,7 +198,9 @@ class Attention(nn.Module): self.embedding_dim = embedding_dim self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads - assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." self.q_proj = nn.Linear(embedding_dim, self.internal_dim) self.k_proj = nn.Linear(embedding_dim, self.internal_dim) diff --git a/model/segment_anything/predictor.py b/model/segment_anything/predictor.py index a3820fb7de8647e5d6adf229debc498b33caad62..bf52d81c2ef2e81b87e574fc935e88749ae3ebf6 100755 --- a/model/segment_anything/predictor.py +++ b/model/segment_anything/predictor.py @@ -4,13 +4,12 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from typing import Optional, Tuple + import numpy as np import torch from .modeling import Sam - -from typing import Optional, Tuple - from .utils.transforms import ResizeLongestSide @@ -55,7 +54,9 @@ class SamPredictor: # Transform the image to the form expected by the model input_image = self.transform.apply_image(image) input_image_torch = torch.as_tensor(input_image, device=self.device) - input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[ + None, :, :, : + ] self.set_torch_image(input_image_torch, image.shape[:2]) @@ -131,7 +132,9 @@ class SamPredictor: a subsequent iteration as mask input. """ if not self.is_image_set: - raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) # Transform input prompts coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None @@ -140,15 +143,21 @@ class SamPredictor: point_labels is not None ), "point_labels must be supplied if point_coords is supplied." point_coords = self.transform.apply_coords(point_coords, self.original_size) - coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) - labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + labels_torch = torch.as_tensor( + point_labels, dtype=torch.int, device=self.device + ) coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] if box is not None: box = self.transform.apply_boxes(box, self.original_size) box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) box_torch = box_torch[None, :] if mask_input is not None: - mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = torch.as_tensor( + mask_input, dtype=torch.float, device=self.device + ) mask_input_torch = mask_input_torch[None, :, :, :] masks, iou_predictions, low_res_masks = self.predict_torch( @@ -211,7 +220,9 @@ class SamPredictor: a subsequent iteration as mask input. """ if not self.is_image_set: - raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." 
+ ) if point_coords is not None: points = (point_coords, point_labels) @@ -235,7 +246,9 @@ class SamPredictor: ) # Upscale the masks to the original image resolution - masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + masks = self.model.postprocess_masks( + low_res_masks, self.input_size, self.original_size + ) if not return_logits: masks = masks > self.model.mask_threshold @@ -252,7 +265,9 @@ class SamPredictor: raise RuntimeError( "An image must be set with .set_image(...) to generate an embedding." ) - assert self.features is not None, "Features must exist if an image has been set." + assert ( + self.features is not None + ), "Features must exist if an image has been set." return self.features @property diff --git a/model/segment_anything/utils/amg.py b/model/segment_anything/utils/amg.py index be064071ef399fea96c673ad173689656c23534a..5c3bc5d789049076a2404b1b2477110cebc32fb2 100755 --- a/model/segment_anything/utils/amg.py +++ b/model/segment_anything/utils/amg.py @@ -4,14 +4,14 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import numpy as np -import torch - import math from copy import deepcopy from itertools import product from typing import Any, Dict, Generator, ItemsView, List, Tuple +import numpy as np +import torch + class MaskData: """ diff --git a/model/segment_anything/utils/onnx.py b/model/segment_anything/utils/onnx.py index 3196bdf4b782e6eeb3da4ad66ef3c7b1741535fe..3521208f620aeef707707037d027c0156d940cdf 100755 --- a/model/segment_anything/utils/onnx.py +++ b/model/segment_anything/utils/onnx.py @@ -4,12 +4,12 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+from typing import Tuple + import torch import torch.nn as nn from torch.nn import functional as F -from typing import Tuple - from ..modeling import Sam from .amg import calculate_stability_score @@ -48,32 +48,43 @@ class SamOnnxModel(nn.Module): transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) return transformed_size - def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + def _embed_points( + self, point_coords: torch.Tensor, point_labels: torch.Tensor + ) -> torch.Tensor: point_coords = point_coords + 0.5 point_coords = point_coords / self.img_size point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) point_embedding = point_embedding * (point_labels != -1) - point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( - point_labels == -1 + point_embedding = ( + point_embedding + + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1) ) for i in range(self.model.prompt_encoder.num_point_embeddings): - point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ - i - ].weight * (point_labels == i) + point_embedding = ( + point_embedding + + self.model.prompt_encoder.point_embeddings[i].weight + * (point_labels == i) + ) return point_embedding - def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: - mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + def _embed_masks( + self, input_mask: torch.Tensor, has_mask_input: torch.Tensor + ) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling( + input_mask + ) mask_embedding = mask_embedding + ( 1 - has_mask_input ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) return mask_embedding - def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + def mask_postprocessing( + self, masks: torch.Tensor, orig_im_size: torch.Tensor + ) -> torch.Tensor: masks = F.interpolate( masks, size=(self.img_size, self.img_size), @@ -81,7 +92,9 @@ class SamOnnxModel(nn.Module): align_corners=False, ) - prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to( + torch.int64 + ) masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore orig_im_size = orig_im_size.to(torch.int64) diff --git a/model/segment_anything/utils/transforms.py b/model/segment_anything/utils/transforms.py index 97a682a28ed0fb1481a27a6134d44a98d41d78f3..4232d84252ea4983b194b2ebe8796741d252ef87 100755 --- a/model/segment_anything/utils/transforms.py +++ b/model/segment_anything/utils/transforms.py @@ -4,13 +4,14 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from copy import deepcopy +from typing import Tuple + import numpy as np import torch from torch.nn import functional as F -from torchvision.transforms.functional import resize, to_pil_image # type: ignore - -from copy import deepcopy -from typing import Tuple +from torchvision.transforms.functional import resize # type: ignore +from torchvision.transforms.functional import to_pil_image class ResizeLongestSide: @@ -27,10 +28,14 @@ class ResizeLongestSide: """ Expects a numpy array with shape HxWxC in uint8 format. 
""" - target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + target_size = self.get_preprocess_shape( + image.shape[0], image.shape[1], self.target_length + ) return np.array(resize(to_pil_image(image), target_size)) - def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + def apply_coords( + self, coords: np.ndarray, original_size: Tuple[int, ...] + ) -> np.ndarray: """ Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format. @@ -44,7 +49,9 @@ class ResizeLongestSide: coords[..., 1] = coords[..., 1] * (new_h / old_h) return coords - def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + def apply_boxes( + self, boxes: np.ndarray, original_size: Tuple[int, ...] + ) -> np.ndarray: """ Expects a numpy array shape Bx4. Requires the original image size in (H, W) format. @@ -59,7 +66,9 @@ class ResizeLongestSide: the transformation expected by the model. """ # Expects an image in BCHW format. May not exactly match apply_image. - target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + target_size = self.get_preprocess_shape( + image.shape[0], image.shape[1], self.target_length + ) return F.interpolate( image, target_size, mode="bilinear", align_corners=False, antialias=True ) @@ -91,7 +100,9 @@ class ResizeLongestSide: return boxes.reshape(-1, 4) @staticmethod - def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + def get_preprocess_shape( + oldh: int, oldw: int, long_side_length: int + ) -> Tuple[int, int]: """ Compute the output size given input size and target long side length. """ diff --git a/train_ds.py b/train_ds.py new file mode 100755 index 0000000000000000000000000000000000000000..ea98d250e09667c11c2b0d1296765258c8c20df7 --- /dev/null +++ b/train_ds.py @@ -0,0 +1,455 @@ +import argparse +import os +import shutil +import sys +import time +from functools import partial + +import deepspeed +import numpy as np +import torch +import tqdm +import transformers +from torch.utils.tensorboard import SummaryWriter + +from model.LISA import LISA +from utils.dataset import HybridDataset, ValDataset, collate_fn +from utils.utils import (AverageMeter, ProgressMeter, Summary, dict_to_cuda, + intersectionAndUnionGPU) + + +def parse_args(args): + parser = argparse.ArgumentParser(description="LISA Model Training") + parser.add_argument("--local_rank", default=0, type=int, help="node rank") + parser.add_argument( + "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview" + ) + parser.add_argument("--vis_save_path", default="./vis_output", type=str) + parser.add_argument( + "--precision", + default="bf16", + type=str, + choices=["fp32", "bf16", "fp16"], + help="precision for inference", + ) + parser.add_argument("--image_size", default=1024, type=int, help="image size") + parser.add_argument("--model_max_length", default=512, type=int) + parser.add_argument("--lora_r", default=8, type=int) + parser.add_argument( + "--vision-tower", default="openai/clip-vit-large-patch14", type=str + ) + parser.add_argument("--load_in_8bit", action="store_true", default=False) + parser.add_argument("--load_in_4bit", action="store_true", default=False) + + parser.add_argument( + "--dataset", default="sem_seg||refer_seg||vqa||reason_seg", type=str + ) + parser.add_argument( + "--sem_seg_data", + default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary", + type=str, + 
) + parser.add_argument( + "--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str + ) + parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str) + parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str) + parser.add_argument( + "--val_dataset", default="ReasonSeg|val", type=str + ) + parser.add_argument("--dataset_dir", default="./dataset", type=str) + parser.add_argument("--log_base_dir", default="./runs", type=str) + parser.add_argument("--exp_name", default="lisa", type=str) + parser.add_argument("--epochs", default=20, type=int) + parser.add_argument("--steps_per_epoch", default=500, type=int) + parser.add_argument( + "--batch_size", default=2, type=int, help="batch size per device per step" + ) + parser.add_argument( + "--grad_accumulation_steps", + default=10, + type=int, + help="batch size per device per step", + ) + parser.add_argument("--val_batch_size", default=1, type=int) + parser.add_argument("--workers", default=4, type=int) + parser.add_argument("--lr", default=0.0003, type=float) + parser.add_argument("--ce_loss_weight", default=1.0, type=float) + parser.add_argument("--dice_loss_weight", default=0.5, type=float) + parser.add_argument("--bce_loss_weight", default=2.0, type=float) + parser.add_argument("--lora_alpha", default=16, type=int) + parser.add_argument("--lora_dropout", default=0.05, type=float) + parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str) + parser.add_argument("--explanatory", default=0.1, type=float) + parser.add_argument("--beta1", default=0.9, type=float) + parser.add_argument("--beta2", default=0.95, type=float) + parser.add_argument("--num_classes_per_sample", default=3, type=int) + parser.add_argument("--exclude_val", action="store_true", default=False) + parser.add_argument("--no_eval", action="store_true", default=False) + parser.add_argument("--eval_only", action="store_true", default=False) + parser.add_argument("--vision_pretrained", default="PATH TO SAM ViT-H Pre-trained Wegiht", type=str) + parser.add_argument("--weight", default="", type=str) + parser.add_argument("--print_freq", default=1, type=int) + parser.add_argument("--start_epoch", default=0, type=int) + return parser.parse_args(args) + + +def main(args): + args = parse_args(args) + args.log_dir = os.path.join(args.log_base_dir, args.exp_name) + if args.local_rank == 0: + os.makedirs(args.log_dir, exist_ok=True) + writer = SummaryWriter(args.log_dir) + else: + writer = None + + # Create model + tokenizer = transformers.AutoTokenizer.from_pretrained( + args.version, + cache_dir=None, + model_max_length=args.model_max_length, + padding_side="right", + use_fast=False, + ) + tokenizer.pad_token = tokenizer.unk_token + num_added_tokens = tokenizer.add_tokens("[SEG]") + ret_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids + args.seg_token_idx = ret_token_idx[0] + + model = LISA( + args.local_rank, + args.seg_token_idx, + tokenizer, + args.version, + args.lora_r, + args.precision, + vision_tower=args.vision_tower, + load_in_8bit=args.load_in_8bit, + load_in_4bit=args.load_in_4bit, + ce_loss_weight=args.ce_loss_weight, + dice_loss_weight=args.dice_loss_weight, + bce_loss_weight=args.bce_loss_weight, + vision_pretrained=args.vision_pretrained, + ) + + if args.weight: + state_dict = torch.load(args.weight, map_location='cpu') + model.load_state_dict(state_dict, strict=True) + + world_size = torch.cuda.device_count() + args.distributed = world_size > 1 + train_dataset = HybridDataset( + 
args.dataset_dir, + tokenizer, + args.vision_tower, + samples_per_epoch=args.batch_size * args.grad_accumulation_steps * args.steps_per_epoch * world_size, + precision=args.precision, + image_size=args.image_size, + num_classes_per_sample=args.num_classes_per_sample, + exclude_val=args.exclude_val, + dataset=args.dataset, + sem_seg_data=args.sem_seg_data, + refer_seg_data=args.refer_seg_data, + vqa_data=args.vqa_data, + reason_seg_data=args.reason_seg_data, + explanatory=args.explanatory, + ) + + if args.no_eval == False: + val_dataset = ValDataset( + args.dataset_dir, + tokenizer, + args.vision_tower, + args.val_dataset, + args.image_size, + ) + print(f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples.") + else: + val_dataset = None + print(f"Training with {len(train_dataset)} examples.") + + ds_config = { + "train_micro_batch_size_per_gpu": args.batch_size, + "gradient_accumulation_steps": args.grad_accumulation_steps, + "optimizer": { + "type": "AdamW", + "params": { + "lr": args.lr, + "weight_decay": 0.0, + "betas": (args.beta1, args.beta2), + }, + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "total_num_steps": args.epochs * args.steps_per_epoch, + "warmup_min_lr": 0, + "warmup_max_lr": args.lr, + "warmup_num_steps": 100, + "warmup_type": "linear", + }, + }, + "fp16": { + "enabled": args.precision == "fp16", + }, + "bf16": { + "enabled": args.precision == "bf16", + }, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 2, + "contiguous_gradients": True, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 5e8, + "allgather_bucket_size": 5e8, + }, + } + model_engine, optimizer, train_loader, scheduler = deepspeed.initialize( + model=model, + model_parameters=model.parameters(), + training_data=train_dataset, + collate_fn=partial(collate_fn, tokenizer=tokenizer), + config=ds_config, + ) + + if val_dataset is not None: + assert args.val_batch_size == 1 + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=False) + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=args.val_batch_size, + shuffle=False, + num_workers=args.workers, + pin_memory=False, + sampler=val_sampler, + collate_fn=partial(collate_fn, tokenizer=tokenizer), + ) + + train_iter = iter(train_loader) + best_score, cur_ciou = 0.0, 0.0 + + if args.eval_only: + giou, ciou = validate( + val_loader, model_engine, 0, writer, args + ) + exit() + + for epoch in range(args.start_epoch, args.epochs): + + # train for one epoch + train_iter = train( + train_loader, + model_engine, + epoch, + scheduler, + writer, + train_iter, + args, + ) + + if args.no_eval == False: + giou, ciou = validate( + val_loader, model_engine, epoch, writer, args + ) + is_best = giou > best_score + best_score = max(giou, best_score) + cur_ciou = ciou if is_best else cur_ciou + + if args.no_eval or is_best: + save_dir = os.path.join(args.log_dir, "ckpt_model") + if args.local_rank == 0: + torch.save( + {"epoch": epoch}, + os.path.join( + args.log_dir, + "meta_log_giou{:.3f}_ciou{:.3f}.pth".format( + best_score, cur_ciou + ), + ), + ) + if os.path.exists(save_dir): + shutil.rmtree(save_dir) + torch.distributed.barrier() + model_engine.save_checkpoint(save_dir) + + +def train( + train_loader, + model, + epoch, + scheduler, + writer, + train_iter, + args, +): + """Main training loop.""" + batch_time = AverageMeter("Time", ":6.3f") + data_time = AverageMeter("Data", ":6.3f") + losses = AverageMeter("Loss", 
":.4f") + ce_losses = AverageMeter("CeLoss", ":.4f") + mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f") + mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f") + mask_losses = AverageMeter("MaskLoss", ":.4f") + + progress = ProgressMeter( + args.steps_per_epoch, + [ + batch_time, + losses, + ce_losses, + mask_losses, + mask_bce_losses, + mask_dice_losses, + ], + prefix="Epoch: [{}]".format(epoch), + ) + + # switch to train mode + model.train() + end = time.time() + for global_step in range(args.steps_per_epoch): + for i in range(args.grad_accumulation_steps): + try: + input_dict = next(train_iter) + except: + train_iter = iter(train_loader) + input_dict = next(train_iter) + + data_time.update(time.time() - end) + input_dict = dict_to_cuda(input_dict) + + if args.precision == "fp16": + input_dict["images"] = input_dict["images"].half() + input_dict["images_clip"] = input_dict["images_clip"].half() + elif args.precision == "bf16": + input_dict["images"] = input_dict["images"].bfloat16() + input_dict["images_clip"] = input_dict["images_clip"].bfloat16() + else: + input_dict["images"] = input_dict["images"].float() + input_dict["images_clip"] = input_dict["images_clip"].float() + + output_dict = model(**input_dict) + + loss = output_dict["loss"] + ce_loss = output_dict["ce_loss"] + mask_bce_loss = output_dict["mask_bce_loss"] + mask_dice_loss = output_dict["mask_dice_loss"] + mask_loss = output_dict["mask_loss"] + + losses.update(loss.item(), input_dict["images"].size(0)) + ce_losses.update(ce_loss.item(), input_dict["images"].size(0)) + mask_bce_losses.update(mask_bce_loss.item(), input_dict["images"].size(0)) + mask_dice_losses.update(mask_dice_loss.item(), input_dict["images"].size(0)) + mask_losses.update(mask_loss.item(), input_dict["images"].size(0)) + model.backward(loss) + model.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if global_step % args.print_freq == 0: + if args.distributed: + batch_time.all_reduce() + data_time.all_reduce() + + losses.all_reduce() + ce_losses.all_reduce() + mask_bce_losses.all_reduce() + mask_dice_losses.all_reduce() + mask_losses.all_reduce() + + if args.local_rank == 0: + progress.display(global_step + 1) + writer.add_scalar("train/loss", losses.avg, global_step) + writer.add_scalar("train/ce_loss", ce_losses.avg, global_step) + writer.add_scalar( + "train/mask_bce_loss", mask_bce_losses.avg, global_step + ) + writer.add_scalar( + "train/mask_dice_loss", mask_dice_losses.avg, global_step + ) + writer.add_scalar("train/mask_loss", mask_losses.avg, global_step) + writer.add_scalar( + "metrics/total_secs_per_batch", batch_time.avg, global_step + ) + writer.add_scalar( + "metrics/data_secs_per_batch", data_time.avg, global_step + ) + + batch_time.reset() + data_time.reset() + losses.reset() + ce_losses.reset() + mask_bce_losses.reset() + mask_dice_losses.reset() + mask_losses.reset() + + if global_step != 0: + curr_lr = scheduler.get_last_lr() + if args.local_rank == 0: + writer.add_scalar("train/lr", curr_lr[0], global_step) + + return train_iter + + +def validate(val_loader, model_engine, epoch, writer, args): + intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM) + union_meter = AverageMeter("Union", ":6.3f", Summary.SUM) + acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM) + + model_engine.eval() + + for input_dict in tqdm.tqdm(val_loader): + input_dict = dict_to_cuda(input_dict) + if args.precision == "fp16": + input_dict["images"] = input_dict["images"].half() + 
input_dict["images_clip"] = input_dict["images_clip"].half() + elif args.precision == "bf16": + input_dict["images"] = input_dict["images"].bfloat16() + input_dict["images_clip"] = input_dict["images_clip"].bfloat16() + else: + input_dict["images"] = input_dict["images"].float() + input_dict["images_clip"] = input_dict["images_clip"].float() + + output_dict = model_engine(**input_dict) + + pred_masks = output_dict["pred_masks"] + masks_list = output_dict["gt_masks"][0].int() + output_list = (pred_masks[0] > 0).int() + assert len(pred_masks) == 1 + + intersection, union, acc_iou = 0.0, 0.0, 0.0 + for mask_i, output_i in zip(masks_list, output_list): + intersection_i, union_i, _ = intersectionAndUnionGPU( + output_i.contiguous().clone(), mask_i.contiguous(), 2, ignore_index=255 + ) + intersection += intersection_i + union += union_i + acc_iou += intersection_i / (union_i + 1e-5) + acc_iou[union_i == 0] += 1.0 # no-object target + intersection, union = intersection.cpu().numpy(), union.cpu().numpy() + acc_iou = acc_iou.cpu().numpy() / masks_list.shape[0] + intersection_meter.update(intersection), union_meter.update( + union + ), acc_iou_meter.update(acc_iou, n=masks_list.shape[0]) + + intersection_meter.all_reduce() + union_meter.all_reduce() + acc_iou_meter.all_reduce() + + iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) + ciou = iou_class[1] + giou = acc_iou_meter.avg[1] + + if args.local_rank == 0: + writer.add_scalar("val/giou", giou, epoch) + writer.add_scalar("val/giou", ciou, epoch) + print("giou: {:.4f}, ciou: {:.4f}".format(giou, ciou)) + + return giou, ciou + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/utils/ade20k_classes.json b/utils/ade20k_classes.json new file mode 100644 index 0000000000000000000000000000000000000000..1f96e616bc3fd2f8c0ec4caea975d77c680f44bb --- /dev/null +++ b/utils/ade20k_classes.json @@ -0,0 +1,30 @@ +[ + "wall", "building", "sky", "floor", "tree", "ceiling", "road", + "bed", "windowpane", "grass", "cabinet", "sidewalk", + "person", "earth", "door", "table", "mountain", "plant", + "curtain", "chair", "car", "water", "painting", "sofa", + "shelf", "house", "sea", "mirror", "rug", "field", "armchair", + "seat", "fence", "desk", "rock", "wardrobe", "lamp", + "bathtub", "railing", "cushion", "base", "box", "column", + "signboard", "chest of drawers", "counter", "sand", "sink", + "skyscraper", "fireplace", "refrigerator", "grandstand", + "path", "stairs", "runway", "case", "pool table", "pillow", + "screen door", "stairway", "river", "bridge", "bookcase", + "blind", "coffee table", "toilet", "flower", "book", "hill", + "bench", "countertop", "stove", "palm", "kitchen island", + "computer", "swivel chair", "boat", "bar", "arcade machine", + "hovel", "bus", "towel", "light", "truck", "tower", + "chandelier", "awning", "streetlight", "booth", + "television receiver", "airplane", "dirt track", "apparel", + "pole", "land", "bannister", "escalator", "ottoman", "bottle", + "buffet", "poster", "stage", "van", "ship", "fountain", + "conveyer belt", "canopy", "washer", "plaything", + "swimming pool", "stool", "barrel", "basket", "waterfall", + "tent", "bag", "minibike", "cradle", "oven", "ball", "food", + "step", "tank", "trade name", "microwave", "pot", "animal", + "bicycle", "lake", "dishwasher", "screen", "blanket", + "sculpture", "hood", "sconce", "vase", "traffic light", + "tray", "ashcan", "fan", "pier", "crt screen", "plate", + "monitor", "bulletin board", "shower", "radiator", "glass", + "clock", "flag" +] \ No newline at end of 
file diff --git a/utils/cocostuff_classes.txt b/utils/cocostuff_classes.txt new file mode 100755 index 0000000000000000000000000000000000000000..1d5a692b83ac8eead2bfffa805e1115cef737bae --- /dev/null +++ b/utils/cocostuff_classes.txt @@ -0,0 +1,183 @@ +0: unlabeled +1: person +2: bicycle +3: car +4: motorcycle +5: airplane +6: bus +7: train +8: truck +9: boat +10: traffic light +11: fire hydrant +12: street sign +13: stop sign +14: parking meter +15: bench +16: bird +17: cat +18: dog +19: horse +20: sheep +21: cow +22: elephant +23: bear +24: zebra +25: giraffe +26: hat +27: backpack +28: umbrella +29: shoe +30: eye glasses +31: handbag +32: tie +33: suitcase +34: frisbee +35: skis +36: snowboard +37: sports ball +38: kite +39: baseball bat +40: baseball glove +41: skateboard +42: surfboard +43: tennis racket +44: bottle +45: plate +46: wine glass +47: cup +48: fork +49: knife +50: spoon +51: bowl +52: banana +53: apple +54: sandwich +55: orange +56: broccoli +57: carrot +58: hot dog +59: pizza +60: donut +61: cake +62: chair +63: couch +64: potted plant +65: bed +66: mirror +67: dining table +68: window +69: desk +70: toilet +71: door +72: tv +73: laptop +74: mouse +75: remote +76: keyboard +77: cell phone +78: microwave +79: oven +80: toaster +81: sink +82: refrigerator +83: blender +84: book +85: clock +86: vase +87: scissors +88: teddy bear +89: hair drier +90: toothbrush +91: hair brush +92: banner +93: blanket +94: branch +95: bridge +96: building-other +97: bush +98: cabinet +99: cage +100: cardboard +101: carpet +102: ceiling-other +103: ceiling-tile +104: cloth +105: clothes +106: clouds +107: counter +108: cupboard +109: curtain +110: desk-stuff +111: dirt +112: door-stuff +113: fence +114: floor-marble +115: floor-other +116: floor-stone +117: floor-tile +118: floor-wood +119: flower +120: fog +121: food-other +122: fruit +123: furniture-other +124: grass +125: gravel +126: ground-other +127: hill +128: house +129: leaves +130: light +131: mat +132: metal +133: mirror-stuff +134: moss +135: mountain +136: mud +137: napkin +138: net +139: paper +140: pavement +141: pillow +142: plant-other +143: plastic +144: platform +145: playingfield +146: railing +147: railroad +148: river +149: road +150: rock +151: roof +152: rug +153: salad +154: sand +155: sea +156: shelf +157: sky +158: skyscraper +159: snow +160: solid-other +161: stairs +162: stone +163: straw +164: structural-other +165: table +166: tent +167: textile-other +168: towel +169: tree +170: vegetable +171: wall-brick +172: wall-concrete +173: wall-other +174: wall-panel +175: wall-stone +176: wall-tile +177: wall-wood +178: water-other +179: waterdrops +180: window-blind +181: window-other +182: wood diff --git a/utils/conversation.py b/utils/conversation.py index 0cf11c4096391485b332e2006fd88aa80e6b783e..65ea31ff2e1ba6f93c5942d096162576284fff61 100644 --- a/utils/conversation.py +++ b/utils/conversation.py @@ -3,8 +3,8 @@ Conversation prompt templates. 
""" import dataclasses -from enum import auto, Enum -from typing import List, Tuple, Any +from enum import Enum, auto +from typing import Any, List class SeparatorStyle(Enum): diff --git a/utils/data_proc_demo.py b/utils/data_proc_demo.py deleted file mode 100644 index 88eb9b6696c7e65e8384846182c1129083f239db..0000000000000000000000000000000000000000 --- a/utils/data_proc_demo.py +++ /dev/null @@ -1,83 +0,0 @@ -import os -import numpy as np -import json -import cv2 -import glob - -def get_mask_from_json(json_path, img): - try: - with open(json_path, 'r') as r: - anno = json.loads(r.read()) - except: - with open(json_path, 'r', encoding="cp1252") as r: - anno = json.loads(r.read()) - - inform = anno['shapes'] - comments = anno['text'] - is_sentence = anno['is_sentence'] - - height, width = img.shape[:2] - - ### sort polies by area - area_list = [] - valid_poly_list = [] - for i in inform: - label_id = i['label'] - points = i['points'] - if 'flag' == label_id.lower(): ## meaningless deprecated annotations - continue - - tmp_mask = np.zeros((height, width), dtype=np.uint8) - cv2.polylines(tmp_mask, np.array([points], dtype=np.int32), True, 1, 1) - cv2.fillPoly(tmp_mask, np.array([points], dtype=np.int32), 1) - tmp_area = tmp_mask.sum() - - area_list.append(tmp_area) - valid_poly_list.append(i) - - ### ground-truth mask - sort_index = np.argsort(area_list)[::-1].astype(np.int32) - sort_index = list(sort_index) - sort_inform = [] - for s_idx in sort_index: - sort_inform.append(valid_poly_list[s_idx]) - - mask = np.zeros((height, width), dtype=np.uint8) - for i in sort_inform: - label_id = i['label'] - points = i['points'] - - if 'ignore' in label_id.lower(): - label_value = 255 # ignored during evaluation - else: - label_value = 1 # target - - cv2.polylines(mask, np.array([points], dtype=np.int32), True, label_value, 1) - cv2.fillPoly(mask, np.array([points], dtype=np.int32), label_value) - - return mask, comments, is_sentence - - -if __name__ == '__main__': - data_dir = './train' - vis_dir = './vis' - - if not os.path.exists(vis_dir): - os.makedirs(vis_dir) - - json_path_list = sorted(glob.glob(data_dir + '/*.json')) - for json_path in json_path_list: - img_path = json_path.replace('.json', '.jpg') - img = cv2.imread(img_path)[:,:,::-1] - - # In generated mask, value 1 denotes valid target region, and value 255 stands for region ignored during evaluaiton. - mask, comments, is_sentence = get_mask_from_json(json_path, img) - - ## visualization. Green for target, and red for ignore. 
- valid_mask = (mask == 1).astype(np.float32)[:,:,None] - ignore_mask = (mask == 255).astype(np.float32)[:,:,None] - vis_img = img * (1 - valid_mask) * (1 - ignore_mask) + ((np.array([0,255,0]) * 0.6 + img * 0.4) * valid_mask + (np.array([255,0,0]) * 0.6 + img * 0.4) * ignore_mask) - vis_img = np.concatenate([img, vis_img], 1) - vis_path = os.path.join(vis_dir, json_path.split('/')[-1].replace('.json', '.jpg')) - cv2.imwrite(vis_path, vis_img[:,:,::-1]) - print('Visualization has been saved to: ', vis_path) \ No newline at end of file diff --git a/utils/data_processing.py b/utils/data_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..d47a80f0111019c97ccb2ce198f37495ee037471 --- /dev/null +++ b/utils/data_processing.py @@ -0,0 +1,90 @@ +import glob +import json +import os + +import cv2 +import numpy as np + + +def get_mask_from_json(json_path, img): + try: + with open(json_path, "r") as r: + anno = json.loads(r.read()) + except: + with open(json_path, "r", encoding="cp1252") as r: + anno = json.loads(r.read()) + + inform = anno["shapes"] + comments = anno["text"] + is_sentence = anno["is_sentence"] + + height, width = img.shape[:2] + + ### sort polies by area + area_list = [] + valid_poly_list = [] + for i in inform: + label_id = i["label"] + points = i["points"] + if "flag" == label_id.lower(): ## meaningless deprecated annotations + continue + + tmp_mask = np.zeros((height, width), dtype=np.uint8) + cv2.polylines(tmp_mask, np.array([points], dtype=np.int32), True, 1, 1) + cv2.fillPoly(tmp_mask, np.array([points], dtype=np.int32), 1) + tmp_area = tmp_mask.sum() + + area_list.append(tmp_area) + valid_poly_list.append(i) + + ### ground-truth mask + sort_index = np.argsort(area_list)[::-1].astype(np.int32) + sort_index = list(sort_index) + sort_inform = [] + for s_idx in sort_index: + sort_inform.append(valid_poly_list[s_idx]) + + mask = np.zeros((height, width), dtype=np.uint8) + for i in sort_inform: + label_id = i["label"] + points = i["points"] + + if "ignore" in label_id.lower(): + label_value = 255 # ignored during evaluation + else: + label_value = 1 # target + + cv2.polylines(mask, np.array([points], dtype=np.int32), True, label_value, 1) + cv2.fillPoly(mask, np.array([points], dtype=np.int32), label_value) + + return mask, comments, is_sentence + + +if __name__ == "__main__": + data_dir = "./train" + vis_dir = "./vis" + + if not os.path.exists(vis_dir): + os.makedirs(vis_dir) + + json_path_list = sorted(glob.glob(data_dir + "/*.json")) + for json_path in json_path_list: + img_path = json_path.replace(".json", ".jpg") + img = cv2.imread(img_path)[:, :, ::-1] + + # In generated mask, value 1 denotes valid target region, and value 255 stands for region ignored during evaluaiton. + mask, comments, is_sentence = get_mask_from_json(json_path, img) + + ## visualization. Green for target, and red for ignore. 
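+        # The overlay below alpha-blends each labeled region with its color
+        # (0.6 * color + 0.4 * image): green for target pixels (mask == 1),
+        # red for ignored pixels (mask == 255); all other pixels keep the
+        # original RGB values, and the result is stacked beside the input.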
+ valid_mask = (mask == 1).astype(np.float32)[:, :, None] + ignore_mask = (mask == 255).astype(np.float32)[:, :, None] + vis_img = img * (1 - valid_mask) * (1 - ignore_mask) + ( + (np.array([0, 255, 0]) * 0.6 + img * 0.4) * valid_mask + + (np.array([255, 0, 0]) * 0.6 + img * 0.4) * ignore_mask + ) + vis_img = np.concatenate([img, vis_img], 1) + vis_path = os.path.join( + vis_dir, json_path.split("/")[-1].replace(".json", ".jpg") + ) + cv2.imwrite(vis_path, vis_img[:, :, ::-1]) + print("Visualization has been saved to: ", vis_path) diff --git a/utils/dataset.py b/utils/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3499046ff89c54bf1dcffa7f1f8c2b564a1d0ffe --- /dev/null +++ b/utils/dataset.py @@ -0,0 +1,450 @@ +import glob +import os +import random + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from pycocotools import mask +from transformers import CLIPImageProcessor + +from model.segment_anything.utils.transforms import ResizeLongestSide + +from .conversation import get_default_conv_template +from .data_processing import get_mask_from_json +from .reason_seg_dataset import ReasonSegDataset +from .refer import REFER +from .refer_seg_dataset import ReferSegDataset +from .sem_seg_dataset import SemSegDataset +from .utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, + DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN) +from .vqa_dataset import VQADataset + + +def collate_fn(batch, tokenizer=None): + image_path_list = [] + images_list = [] + images_clip_list = [] + conversation_list = [] + masks_list = [] + label_list = [] + resize_list = [] + questions_list = [] + sampled_classes_list = [] + offset_list = [0] + cnt = 0 + inferences = [] + for ( + image_path, + images, + images_clip, + conversations, + masks, + label, + resize, + questions, + sampled_classes, + inference, + ) in batch: + image_path_list.append(image_path) + images_list.append(images) + images_clip_list.append(images_clip) + conversation_list.extend(conversations) + label_list.append(label) + masks_list.append(masks.float()) + resize_list.append(resize) + questions_list.append(questions) + sampled_classes_list.append(sampled_classes) + cnt += len(conversations) + offset_list.append(cnt) + inferences.append(inference) + + tokenize_data = tokenizer( + conversation_list, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) + + input_ids = tokenize_data.input_ids + attention_masks = tokenize_data.attention_mask + + IGNORE_TOKEN_ID = -100 + conv = get_default_conv_template("vicuna").copy() + targets = input_ids.clone() + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversation_list, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + # if len(parts) != 2: + # break + assert len(parts) == 2, (len(parts), rou) + parts[0] += sep + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID + + cur_len += round_len + target[cur_len:] = IGNORE_TOKEN_ID + + if False: + # if True: + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + # rank0_print(tokenizer.decode(z)) + print( + "conversation: ", + conversation, + "tokenizer.decode(z): ", + 
tokenizer.decode(z), + ) + + if cur_len < tokenizer.model_max_length: + assert cur_len == total_len + + return { + "image_paths": image_path_list, + "images": torch.stack(images_list, dim=0), + "images_clip": torch.stack(images_clip_list, dim=0), + "input_ids": input_ids, + "labels": targets, + "attention_masks": attention_masks, + "masks_list": masks_list, + "label_list": label_list, + "resize_list": resize_list, + "offset": torch.LongTensor(offset_list), + "questions_list": questions_list, + "sampled_classes_list": sampled_classes_list, + "inference": inferences[0], + "conversation_list": conversation_list, + } + + +class HybridDataset(torch.utils.data.Dataset): + pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) + pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) + img_size = 1024 + ignore_label = 255 + + def __init__( + self, + base_image_dir, + tokenizer, + vision_tower, + samples_per_epoch=500 * 8 * 2 * 10, + precision: str = "fp32", + image_size: int = 224, + num_classes_per_sample: int = 3, + exclude_val=False, + dataset="sem_seg||refer_seg||vqa||reason_seg", + sem_seg_data="ade20k||cocostuff||partimagenet||pascal_part||paco_lvis||mapillary", + refer_seg_data="refclef||refcoco||refcoco+||refcocog", + vqa_data="llava_instruct_150k", + reason_seg_data="ReasonSeg|train", + explanatory=0.1, + ): + self.exclude_val = exclude_val + self.dataset = dataset + self.samples_per_epoch = samples_per_epoch + self.explanatory = explanatory + self.num_classes_per_sample = num_classes_per_sample + + self.base_image_dir = base_image_dir + self.image_size = image_size + self.tokenizer = tokenizer + self.precision = precision + + self.datasets = dataset.split("||") + + self.all_datasets = [] + for dataset in self.datasets: + if dataset == "sem_seg": + self.all_datasets.append( + SemSegDataset( + base_image_dir, + tokenizer, + vision_tower, + samples_per_epoch, + precision, + image_size, + num_classes_per_sample, + exclude_val, + sem_seg_data, + ) + ) + elif dataset == "refer_seg": + self.all_datasets.append( + ReferSegDataset( + base_image_dir, + tokenizer, + vision_tower, + samples_per_epoch, + precision, + image_size, + num_classes_per_sample, + exclude_val, + refer_seg_data, + ) + ) + elif dataset == "vqa": + self.all_datasets.append( + VQADataset( + base_image_dir, + tokenizer, + vision_tower, + samples_per_epoch, + precision, + image_size, + num_classes_per_sample, + exclude_val, + vqa_data, + ) + ) + elif dataset == "reason_seg": + self.all_datasets.append( + ReasonSegDataset( + base_image_dir, + tokenizer, + vision_tower, + samples_per_epoch, + precision, + image_size, + num_classes_per_sample, + exclude_val, + reason_seg_data, + explanatory, + ) + ) + + def __len__(self): + return self.samples_per_epoch + + def __getitem__(self, idx): + ind = (random.randint(0, 2023) * (idx + 1)) % len( + self.datasets + ) # random.randint(0, len(self.datasets)-1) + data = self.all_datasets[ind] + inference = False + return *data[0], inference + + +class ValDataset(torch.utils.data.Dataset): + pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) + pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) + img_size = 1024 + ignore_label = 255 + + def __init__( + self, + base_image_dir, + tokenizer, + vision_tower, + val_dataset, + image_size=1024, + ): + self.base_image_dir = base_image_dir + splits = val_dataset.split("|") + if len(splits) == 2: + ds, split = splits + images = glob.glob( + os.path.join(self.base_image_dir, "reason_seg", ds, split, 
"*.jpg") + ) + self.images = images + self.data_type = 'reason_seg' + elif len(splits) == 3: + ds, splitBy, split = splits + refer_api = REFER(self.base_image_dir, ds, splitBy) + ref_ids_val = refer_api.getRefIds(split=split) + images_ids_val = refer_api.getImgIds(ref_ids=ref_ids_val) + refs_val = refer_api.loadRefs(ref_ids=ref_ids_val) + refer_seg_ds = {} + refer_seg_ds["images"] = [] + loaded_images = refer_api.loadImgs(image_ids=images_ids_val) + for item in loaded_images: + item = item.copy() + if ds == "refclef": + item["file_name"] = os.path.join( + base_image_dir, "images/saiapr_tc-12", item["file_name"] + ) + elif ds in ["refcoco", "refcoco+", "refcocog", "grefcoco"]: + item["file_name"] = os.path.join( + base_image_dir, + "images/mscoco/images/train2014", + item["file_name"], + ) + refer_seg_ds["images"].append(item) + refer_seg_ds["annotations"] = refer_api.Anns # anns_val + + img2refs = {} + for ref in refs_val: + image_id = ref["image_id"] + img2refs[image_id] = img2refs.get(image_id, []) + [ + ref, + ] + refer_seg_ds["img2refs"] = img2refs + self.refer_seg_ds = refer_seg_ds + self.data_type = 'refer_seg' + + self.ds = ds + self.image_size = image_size + self.tokenizer = tokenizer + self.transform = ResizeLongestSide(image_size) + self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower) + + def __len__(self): + if self.data_type == 'refer_seg': + return len(self.refer_seg_ds["images"]) + else: + return len(self.images) + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.img_size - h + padw = self.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x + + def __getitem__(self, idx): + if self.data_type == 'refer_seg': + refer_seg_ds = self.refer_seg_ds + images = refer_seg_ds["images"] + annotations = refer_seg_ds["annotations"] + img2refs = refer_seg_ds["img2refs"] + + image = images[idx] + image_path = image["file_name"] + image_id = image["id"] + + refs = img2refs[image_id] + if len(refs) == 0: + raise ValueError("image {} has no refs".format(image_id)) + + sents = [] + ann_ids = [] + for ref in refs: + for sent in ref["sentences"]: + sents.append(sent["sent"].strip().lower()) + ann_ids.append(ref["ann_id"]) + + sampled_sents = sents + sampled_ann_ids = ann_ids + img = cv2.imread(image_path) + images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + is_sentence = False + else: + image_path = self.images[idx] + img = cv2.imread(image_path) + images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + json_path = image_path.replace(".jpg", ".json") + mask_json, sampled_sents, is_sentence = get_mask_from_json(json_path, img) + sampled_sents = [sampled_sents[0]] + + conversations = [] + conv = get_default_conv_template("vicuna").copy() + i = 0 + while i < len(sampled_sents): + conv.messages = [] + text = sampled_sents[i].strip() + if is_sentence: + conv.append_message( + conv.roles[0], + DEFAULT_IMAGE_TOKEN + + " {} Please output segmentation mask.".format(text), + ) + conv.append_message(conv.roles[1], "[SEG].") + else: + conv.append_message( + conv.roles[0], + DEFAULT_IMAGE_TOKEN + + " What is {} in this image? 
Please output segmentation mask.".format( + text + ), + ) + conv.append_message(conv.roles[1], "[SEG].") + conversations.append(conv.get_prompt()) + i += 1 + + # replace token + image_token_len = 256 + for i in range(len(conversations)): + replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + replace_token = ( + DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + ) + conversations[i] = conversations[i].replace( + DEFAULT_IMAGE_TOKEN, replace_token + ) + + # preprocess images for clip + images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")[ + "pixel_values" + ][0] + image_token_len = (images_clip.shape[1] // 14) * ( + images_clip.shape[2] // 14 + ) # FIXME: 14 is hardcoded patch size + + # preprocess images for sam + images = self.transform.apply_image(images) + + resize = images.shape[:2] + + images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous()) + + if self.data_type == 'refer_seg': + masks = [] + for i, ann_id in enumerate(sampled_ann_ids): + ann = annotations[ann_id] + if len(ann["segmentation"]) == 0 and sampled_sents[i] != "": + m = np.zeros((image["height"], image["width"], 1)) + else: + if type(ann["segmentation"][0]) == list: # polygon + rle = mask.frPyObjects( + ann["segmentation"], image["height"], image["width"] + ) + else: + rle = ann["segmentation"] + for i in range(len(rle)): + if not isinstance(rle[i]["counts"], bytes): + rle[i]["counts"] = rle[i]["counts"].encode() + m = mask.decode(rle) + m = np.sum( + m, axis=2 + ) # sometimes there are multiple binary map (corresponding to multiple segs) + m = m.astype(np.uint8) # convert to np.uint8 + masks.append(m) + else: + masks = [mask_json] + + masks = np.stack(masks, axis=0) + masks = torch.from_numpy(masks) + labels = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label + inference = True + + return ( + image_path, + images, + images_clip, + conversations, + masks, + labels, + resize, + None, + None, + inference, + ) diff --git a/utils/reason_seg_dataset.py b/utils/reason_seg_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6b911c7f36ed3a81fa6995cb2d9c2d5d846f50a8 --- /dev/null +++ b/utils/reason_seg_dataset.py @@ -0,0 +1,247 @@ +import glob +import json +import os +import random + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor + +from model.segment_anything.utils.transforms import ResizeLongestSide + +from .conversation import get_default_conv_template +from .data_processing import get_mask_from_json +from .utils import (ANSWER_LIST, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, + DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN, + EXPLANATORY_QUESTION_LIST, LONG_QUESTION_LIST, + SHORT_QUESTION_LIST) + + +class ReasonSegDataset(torch.utils.data.Dataset): + pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) + pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) + img_size = 1024 + ignore_label = 255 + + def __init__( + self, + base_image_dir, + tokenizer, + vision_tower, + samples_per_epoch=500 * 8 * 2 * 10, + precision: str = "fp32", + image_size: int = 224, + num_classes_per_sample: int = 3, + exclude_val=False, + reason_seg_data="ReasonSeg|train", + explanatory=0.1, + ): + self.exclude_val = exclude_val + self.reason_seg_data = reason_seg_data + self.samples_per_epoch = samples_per_epoch + self.explanatory = explanatory + self.num_classes_per_sample = num_classes_per_sample + + self.base_image_dir = 
base_image_dir + self.image_size = image_size + self.tokenizer = tokenizer + self.precision = precision + self.transform = ResizeLongestSide(image_size) + self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower) + + self.short_question_list = SHORT_QUESTION_LIST + self.long_question_list = LONG_QUESTION_LIST + self.answer_list = ANSWER_LIST + + if explanatory != -1: + self.explanatory_question_list = EXPLANATORY_QUESTION_LIST + + if explanatory != -1: + self.img_to_why = {} + for sub_data in [ + "20230711_2000_0_processed_masked_finished_masked.json", + "20230711_2000_0_processed_masked_partial_masked.json", + ]: + with open( + os.path.join(base_image_dir, "reason_seg", "explanatory", sub_data) + ) as f: + items = json.load(f) + for item in items: + img_name = item["image_path"].split("/")[-1] + self.img_to_why[img_name] = { + "query": item["query"], + "outputs": item["outputs"], + } + + reason_seg_data, splits = reason_seg_data.split("|") + splits = splits.split("_") + images = [] + for split in splits: + images_split = glob.glob( + os.path.join( + base_image_dir, "reason_seg", reason_seg_data, split, "*.jpg" + ) + ) + images.extend(images_split) + jsons = [path.replace(".jpg", ".json") for path in images] + self.reason_seg_data = (images, jsons) + + def __len__(self): + return self.samples_per_epoch + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.img_size - h + padw = self.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x + + def __getitem__(self, idx): + images, jsons = self.reason_seg_data + idx = random.randint(0, len(images) - 1) + image_path = images[idx] + json_path = jsons[idx] + + img = cv2.imread(image_path) + images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + ori_size = images.shape[:2] + # preprocess images for clip + images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")[ + "pixel_values" + ][0] + image_token_len = (images_clip.shape[1] // 14) * ( + images_clip.shape[2] // 14 + ) # FIXME: 14 is hardcoded patch size + images = self.transform.apply_image(images) # preprocess images for sam + resize = images.shape[:2] + + mask, sents, is_sentence = get_mask_from_json(json_path, img) + if len(sents) >= self.num_classes_per_sample: + sampled_inds = np.random.choice( + list(range(len(sents))), size=self.num_classes_per_sample, replace=False + ) + else: + sampled_inds = list(range(len(sents))) + sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist() + sampled_masks = [ + (mask == 1).astype(np.float32) for _ in range(len(sampled_inds)) + ] + + image_name = image_path.split("/")[-1] + if ( + self.explanatory != -1 and image_name in self.img_to_why + ): # ds in ['20230711_2000_0_processed_masked_partial_masked', '20230711_2000_0_processed_masked_finished_masked', 'trainval_rephrased_20230730_checked_final_masked', 'rephrased_20230730_checked_final_masked']: + if random.random() < self.explanatory: + choice = 2 + else: + choice = random.randint(0, 1) + + questions = [] + answers = [] + class_ids = [] + for text in sampled_sents: + if is_sentence: + question_template = random.choice(self.long_question_list) + questions.append(question_template.format(sent=text)) + else: + question_template = random.choice(self.short_question_list) + questions.append(question_template.format(class_name=text.lower())) + + img_name = image_path.split("/")[-1] + if 
self.explanatory != -1 and img_name in self.img_to_why: + # choice = random.randint(0, 2) + if choice == 0: # [SEG] token + answers.append(random.choice(self.answer_list)) + elif choice == 1: # [SEG] token + text answer + image_name = image_path.split("/")[-1] + answer = self.img_to_why[image_name]["outputs"] + answer = random.choice(self.answer_list) + " {}".format(answer) + questions[-1] = ( + DEFAULT_IMAGE_TOKEN + + " " + + text + + " {}".format(random.choice(self.explanatory_question_list)) + ) + answers.append(answer) + elif choice == 2: # vanilla text answer + image_name = image_path.split("/")[-1] + answer = self.img_to_why[image_name]["outputs"] + questions[-1] = DEFAULT_IMAGE_TOKEN + " " + text + answers.append(answer) + else: + raise ValueError("Not implemented yet.") + else: + answers.append(random.choice(self.answer_list)) + + conversations = [] + conv = get_default_conv_template("vicuna").copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + i = 0 + while i < len(questions): + conv.messages = [] + conv.append_message(conv.roles[0], questions[i]) + conv.append_message(conv.roles[1], answers[i]) + conversations.append(conv.get_prompt()) + i += 1 + + # ============================== + # replace token + for i in range(len(conversations)): + replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + replace_token = ( + DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + ) + conversations[i] = conversations[i].replace( + DEFAULT_IMAGE_TOKEN, replace_token + ) + # ============================== + + images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous()) + + image_name = image_path.split("/")[-1] + if self.explanatory != -1 and image_name in self.img_to_why and choice == 2: + # print("e1") + + masks = torch.rand(0, *ori_size) + label = torch.ones(ori_size) * self.ignore_label + else: + # print("e2") + + masks = np.stack(sampled_masks, axis=0) + masks = torch.from_numpy(masks) + label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label + + # print("reason_seg: {}".format(conversations)) + + # # debug + # if masks.shape[0] != 0: + # save_dir = "./debug/{}".format(image_path.split("/")[-1].split(".")[0]) + # os.makedirs(save_dir, exist_ok=True) + # print("masks.shape: ", masks.shape) + # for i in range(masks.shape[0]): + # cv2.imwrite("{}/mask_{}.jpg".format(save_dir, i), masks[i].numpy().astype(np.uint8)*100) + # assert len(conversations) == masks.shape[0] + # with open("{}/conversations.txt".format(save_dir), "w+") as f: + # for i in range(len(conversations)): + # f.write("{}. ".format(i) + conversations[i] + "\n") + # shutil.copy(image_path, save_dir) + + return ( + image_path, + images, + images_clip, + conversations, + masks, + label, + resize, + questions, + sampled_sents, + ) diff --git a/utils/refer.py b/utils/refer.py new file mode 100644 index 0000000000000000000000000000000000000000..3b4cea716e40e73d0b5aa118143eb076392f5eb1 --- /dev/null +++ b/utils/refer.py @@ -0,0 +1,391 @@ +__author__ = "licheng" + +""" +This interface provides access to four datasets: +1) refclef +2) refcoco +3) refcoco+ +4) refcocog +split by unc and google + +The following API functions are defined: +REFER - REFER api class +getRefIds - get ref ids that satisfy given filter conditions. +getAnnIds - get ann ids that satisfy given filter conditions. +getImgIds - get image ids that satisfy given filter conditions. +getCatIds - get category ids that satisfy given filter conditions. +loadRefs - load refs with the specified ref ids. 
+loadAnns - load anns with the specified ann ids. +loadImgs - load images with the specified image ids. +loadCats - load category names with the specified category ids. +getRefBox - get ref's bounding box [x, y, w, h] given the ref_id +showRef - show image, segmentation or box of the referred object with the ref +getMask - get mask and area of the referred object given ref +showMask - show mask of the referred object given ref +""" + +import itertools +import json +import os.path as osp +import pickle +import sys +import time +from pprint import pprint + +import matplotlib.pyplot as plt +import numpy as np +import skimage.io as io +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from pycocotools import mask + + +class REFER: + def __init__(self, data_root, dataset="refcoco", splitBy="unc"): + # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog + # also provide dataset name and splitBy information + # e.g., dataset = 'refcoco', splitBy = 'unc' + print("loading dataset %s into memory..." % dataset) + self.ROOT_DIR = osp.abspath(osp.dirname(__file__)) + self.DATA_DIR = osp.join(data_root, dataset) + if dataset in ["refcoco", "refcoco+", "refcocog"]: + self.IMAGE_DIR = osp.join(data_root, "images/mscoco/images/train2014") + elif dataset == "refclef": + self.IMAGE_DIR = osp.join(data_root, "images/saiapr_tc-12") + else: + print("No refer dataset is called [%s]" % dataset) + sys.exit() + + self.dataset = dataset + + # load refs from data/dataset/refs(dataset).json + tic = time.time() + + ref_file = osp.join(self.DATA_DIR, "refs(" + splitBy + ").p") + print("ref_file: ", ref_file) + self.data = {} + self.data["dataset"] = dataset + self.data["refs"] = pickle.load(open(ref_file, "rb")) + + # load annotations from data/dataset/instances.json + instances_file = osp.join(self.DATA_DIR, "instances.json") + instances = json.load(open(instances_file, "rb")) + self.data["images"] = instances["images"] + self.data["annotations"] = instances["annotations"] + self.data["categories"] = instances["categories"] + + # create index + self.createIndex() + print("DONE (t=%.2fs)" % (time.time() - tic)) + + def createIndex(self): + # create sets of mapping + # 1) Refs: {ref_id: ref} + # 2) Anns: {ann_id: ann} + # 3) Imgs: {image_id: image} + # 4) Cats: {category_id: category_name} + # 5) Sents: {sent_id: sent} + # 6) imgToRefs: {image_id: refs} + # 7) imgToAnns: {image_id: anns} + # 8) refToAnn: {ref_id: ann} + # 9) annToRef: {ann_id: ref} + # 10) catToRefs: {category_id: refs} + # 11) sentToRef: {sent_id: ref} + # 12) sentToTokens: {sent_id: tokens} + print("creating index...") + # fetch info from instances + Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} + for ann in self.data["annotations"]: + Anns[ann["id"]] = ann + imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann] + for img in self.data["images"]: + Imgs[img["id"]] = img + for cat in self.data["categories"]: + Cats[cat["id"]] = cat["name"] + + # fetch info from refs + Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} + Sents, sentToRef, sentToTokens = {}, {}, {} + for ref in self.data["refs"]: + # ids + ref_id = ref["ref_id"] + ann_id = ref["ann_id"] + category_id = ref["category_id"] + image_id = ref["image_id"] + + # add mapping related to ref + Refs[ref_id] = ref + imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] + catToRefs[category_id] = catToRefs.get(category_id, []) + [ref] + refToAnn[ref_id] = Anns[ann_id] + 
annToRef[ann_id] = ref + + # add mapping of sent + for sent in ref["sentences"]: + Sents[sent["sent_id"]] = sent + sentToRef[sent["sent_id"]] = ref + sentToTokens[sent["sent_id"]] = sent["tokens"] + + # create class members + self.Refs = Refs + self.Anns = Anns + self.Imgs = Imgs + self.Cats = Cats + self.Sents = Sents + self.imgToRefs = imgToRefs + self.imgToAnns = imgToAnns + self.refToAnn = refToAnn + self.annToRef = annToRef + self.catToRefs = catToRefs + self.sentToRef = sentToRef + self.sentToTokens = sentToTokens + print("index created.") + + def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""): + image_ids = image_ids if type(image_ids) == list else [image_ids] + cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] + ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] + + if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0: + refs = self.data["refs"] + else: + if not len(image_ids) == 0: + refs = [self.imgToRefs[image_id] for image_id in image_ids] + else: + refs = self.data["refs"] + if not len(cat_ids) == 0: + refs = [ref for ref in refs if ref["category_id"] in cat_ids] + if not len(ref_ids) == 0: + refs = [ref for ref in refs if ref["ref_id"] in ref_ids] + if not len(split) == 0: + if split in ["testA", "testB", "testC"]: + refs = [ + ref for ref in refs if split[-1] in ref["split"] + ] # we also consider testAB, testBC, ... + elif split in ["testAB", "testBC", "testAC"]: + refs = [ + ref for ref in refs if ref["split"] == split + ] # rarely used I guess... + elif split == "test": + refs = [ref for ref in refs if "test" in ref["split"]] + elif split == "train" or split == "val": + refs = [ref for ref in refs if ref["split"] == split] + else: + print("No such split [%s]" % split) + sys.exit() + ref_ids = [ref["ref_id"] for ref in refs] + return ref_ids + + def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]): + image_ids = image_ids if type(image_ids) == list else [image_ids] + cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] + ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] + + if len(image_ids) == len(cat_ids) == len(ref_ids) == 0: + ann_ids = [ann["id"] for ann in self.data["annotations"]] + else: + if not len(image_ids) == 0: + lists = [ + self.imgToAnns[image_id] + for image_id in image_ids + if image_id in self.imgToAnns + ] # list of [anns] + anns = list(itertools.chain.from_iterable(lists)) + else: + anns = self.data["annotations"] + if not len(cat_ids) == 0: + anns = [ann for ann in anns if ann["category_id"] in cat_ids] + ann_ids = [ann["id"] for ann in anns] + if not len(ref_ids) == 0: + ids = set(ann_ids).intersection( + set([self.Refs[ref_id]["ann_id"] for ref_id in ref_ids]) + ) + return ann_ids + + def getImgIds(self, ref_ids=[]): + ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] + + if not len(ref_ids) == 0: + image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids])) + else: + image_ids = self.Imgs.keys() + return image_ids + + def getCatIds(self): + return self.Cats.keys() + + def loadRefs(self, ref_ids=[]): + if type(ref_ids) == list: + return [self.Refs[ref_id] for ref_id in ref_ids] + elif type(ref_ids) == int: + return [self.Refs[ref_ids]] + + def loadAnns(self, ann_ids=[]): + if type(ann_ids) == list: + return [self.Anns[ann_id] for ann_id in ann_ids] + elif type(ann_ids) == int or type(ann_ids) == unicode: + return [self.Anns[ann_ids]] + + def loadImgs(self, image_ids=[]): + if type(image_ids) == list: + return [self.Imgs[image_id] for image_id in image_ids] + 
elif type(image_ids) == int:
+            return [self.Imgs[image_ids]]
+
+    def loadCats(self, cat_ids=[]):
+        if type(cat_ids) == list:
+            return [self.Cats[cat_id] for cat_id in cat_ids]
+        elif type(cat_ids) == int:
+            return [self.Cats[cat_ids]]
+
+    def getRefBox(self, ref_id):
+        ref = self.Refs[ref_id]
+        ann = self.refToAnn[ref_id]
+        return ann["bbox"]  # [x, y, w, h]
+
+    def showRef(self, ref, seg_box="seg"):
+        ax = plt.gca()
+        # show image
+        image = self.Imgs[ref["image_id"]]
+        I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"]))
+        ax.imshow(I)
+        # show refer expression
+        for sid, sent in enumerate(ref["sentences"]):
+            print("%s. %s" % (sid + 1, sent["sent"]))
+        # show segmentations
+        if seg_box == "seg":
+            ann_id = ref["ann_id"]
+            ann = self.Anns[ann_id]
+            polygons = []
+            color = []
+            c = "none"
+            if type(ann["segmentation"][0]) == list:
+                # polygon used for refcoco*
+                for seg in ann["segmentation"]:
+                    # integer division: reshape requires int dimensions on Python 3
+                    poly = np.array(seg).reshape((len(seg) // 2, 2))
+                    polygons.append(Polygon(poly, True, alpha=0.4))
+                    color.append(c)
+                p = PatchCollection(
+                    polygons,
+                    facecolors=color,
+                    edgecolors=(1, 1, 0, 0),
+                    linewidths=3,
+                    alpha=1,
+                )
+                ax.add_collection(p)  # thick yellow polygon
+                p = PatchCollection(
+                    polygons,
+                    facecolors=color,
+                    edgecolors=(1, 0, 0, 0),
+                    linewidths=1,
+                    alpha=1,
+                )
+                ax.add_collection(p)  # thin red polygon
+            else:
+                # mask used for refclef
+                rle = ann["segmentation"]
+                m = mask.decode(rle)
+                img = np.ones((m.shape[0], m.shape[1], 3))
+                color_mask = np.array([2.0, 166.0, 101.0]) / 255
+                for i in range(3):
+                    img[:, :, i] = color_mask[i]
+                ax.imshow(np.dstack((img, m * 0.5)))
+        # show bounding-box
+        elif seg_box == "box":
+            ann_id = ref["ann_id"]
+            ann = self.Anns[ann_id]
+            bbox = self.getRefBox(ref["ref_id"])
+            box_plot = Rectangle(
+                (bbox[0], bbox[1]),
+                bbox[2],
+                bbox[3],
+                fill=False,
+                edgecolor="green",
+                linewidth=3,
+            )
+            ax.add_patch(box_plot)
+
+    def getMask(self, ref):
+        # return mask, area and mask-center
+        ann = self.refToAnn[ref["ref_id"]]
+        image = self.Imgs[ref["image_id"]]
+        if type(ann["segmentation"][0]) == list:  # polygon
+            rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"])
+        else:
+            rle = ann["segmentation"]
+        m = mask.decode(rle)
+        m = np.sum(
+            m, axis=2
+        )  # sometimes there are multiple binary map (corresponding to multiple segs)
+        m = m.astype(np.uint8)  # convert to np.uint8
+        # compute area
+        area = sum(mask.area(rle))  # should be close to ann['area']
+        return {"mask": m, "area": area}
+        # # position
+        # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
+        # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
+        # # mass position (if there were multiple regions, we use the largest one.)
+ # label_m = label(m, connectivity=m.ndim) + # regions = regionprops(label_m) + # if len(regions) > 0: + # largest_id = np.argmax(np.array([props.filled_area for props in regions])) + # largest_props = regions[largest_id] + # mass_y, mass_x = largest_props.centroid + # else: + # mass_x, mass_y = position_x, position_y + # # if centroid is not in mask, we find the closest point to it from mask + # if m[mass_y, mass_x] != 1: + # print('Finding closes mask point ...') + # kernel = np.ones((10, 10),np.uint8) + # me = cv2.erode(m, kernel, iterations = 1) + # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style + # points = np.array(points) + # dist = np.sum((points - (mass_y, mass_x))**2, axis=1) + # id = np.argsort(dist)[0] + # mass_y, mass_x = points[id] + # # return + # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y} + # # show image and mask + # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name'])) + # plt.figure() + # plt.imshow(I) + # ax = plt.gca() + # img = np.ones( (m.shape[0], m.shape[1], 3) ) + # color_mask = np.array([2.0,166.0,101.0])/255 + # for i in range(3): + # img[:,:,i] = color_mask[i] + # ax.imshow(np.dstack( (img, m*0.5) )) + # plt.show() + + def showMask(self, ref): + M = self.getMask(ref) + msk = M["mask"] + ax = plt.gca() + ax.imshow(msk) + + +if __name__ == "__main__": + refer = REFER(dataset="refcocog", splitBy="google") + ref_ids = refer.getRefIds() + print(len(ref_ids)) + + print(len(refer.Imgs)) + print(len(refer.imgToRefs)) + + ref_ids = refer.getRefIds(split="train") + print("There are %s training referred objects." % len(ref_ids)) + + for ref_id in ref_ids: + ref = refer.loadRefs(ref_id)[0] + if len(ref["sentences"]) < 2: + continue + + pprint(ref) + print("The label is %s." 
% refer.Cats[ref["category_id"]]) + plt.figure() + refer.showRef(ref, seg_box="box") + plt.show() + + # plt.figure() + # refer.showMask(ref) + # plt.show() diff --git a/utils/refer_seg_dataset.py b/utils/refer_seg_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9c54a07b8ab670b22c168f8e13439a4d9ec4aa0b --- /dev/null +++ b/utils/refer_seg_dataset.py @@ -0,0 +1,272 @@ +import os +import random + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from pycocotools import mask +from transformers import CLIPImageProcessor + +from model.segment_anything.utils.transforms import ResizeLongestSide + +from .conversation import get_default_conv_template +from .refer import REFER +from .utils import (ANSWER_LIST, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, + DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN, + SHORT_QUESTION_LIST) + + +class ReferSegDataset(torch.utils.data.Dataset): + pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) + pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) + img_size = 1024 + ignore_label = 255 + + def __init__( + self, + base_image_dir, + tokenizer, + vision_tower, + samples_per_epoch=500 * 8 * 2 * 10, + precision: str = "fp32", + image_size: int = 224, + num_classes_per_sample: int = 3, + exclude_val=False, + refer_seg_data="refclef||refcoco||refcoco+||refcocog", + ): + self.exclude_val = exclude_val + self.samples_per_epoch = samples_per_epoch + self.num_classes_per_sample = num_classes_per_sample + + self.base_image_dir = base_image_dir + self.image_size = image_size + self.tokenizer = tokenizer + self.precision = precision + self.transform = ResizeLongestSide(image_size) + self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower) + + self.short_question_list = SHORT_QUESTION_LIST + self.answer_list = ANSWER_LIST + + DATA_DIR = os.path.join(base_image_dir, "refer_seg") + self.refer_seg_ds_list = refer_seg_data.split( + "||" + ) # ['refclef', 'refcoco', 'refcoco+', 'refcocog', ''] + self.refer_seg_data = {} + for ds in self.refer_seg_ds_list: + if ds == "refcocog": + splitBy = "umd" + else: + splitBy = "unc" + refer_api = REFER(DATA_DIR, ds, splitBy) + ref_ids_train = refer_api.getRefIds(split="train") + images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train) + refs_train = refer_api.loadRefs(ref_ids=ref_ids_train) + ref_file = os.path.join(DATA_DIR, ds, "refs(" + splitBy + ").p") + + refer_seg_ds = {} + refer_seg_ds["images"] = [] + loaded_images = refer_api.loadImgs(image_ids=images_ids_train) + + for item in loaded_images: + item = item.copy() + if ds == "refclef": + item["file_name"] = os.path.join( + DATA_DIR, "images/saiapr_tc-12", item["file_name"] + ) + else: + item["file_name"] = os.path.join( + DATA_DIR, "images/mscoco/images/train2014", item["file_name"] + ) + refer_seg_ds["images"].append(item) + refer_seg_ds["annotations"] = refer_api.Anns # anns_train + + print( + "dataset {} (refs {}) (train split) has {} images and {} annotations (before excluding: {} images)".format( + ds, + splitBy, + len(refer_seg_ds["images"]), + len(refer_seg_ds["annotations"]), + len(loaded_images), + ) + ) + + img2refs = {} + for ref in refs_train: + image_id = ref["image_id"] + img2refs[image_id] = img2refs.get(image_id, []) + [ + ref, + ] + refer_seg_ds["img2refs"] = img2refs + self.refer_seg_data[ds] = refer_seg_ds + + def __len__(self): + return self.samples_per_epoch + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad 
to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.img_size - h + padw = self.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x + + def __getitem__(self, idx): + ds = random.randint(0, len(self.refer_seg_ds_list) - 1) + ds = self.refer_seg_ds_list[ds] + refer_seg_ds = self.refer_seg_data[ds] + images = refer_seg_ds["images"] + annotations = refer_seg_ds["annotations"] + img2refs = refer_seg_ds["img2refs"] + idx = random.randint(0, len(images) - 1) + image = images[idx] + image_path = image["file_name"] + image_id = image["id"] + refs = img2refs[image_id] + if len(refs) == 0: + return self.__getitem__(0) + + sents = [] + ann_ids = [] + for ref in refs: + for sent in ref["sentences"]: + text = sent["sent"] + sents.append(text) + ann_ids.append(ref["ann_id"]) + if len(sents) >= self.num_classes_per_sample: + sampled_inds = np.random.choice( + list(range(len(sents))), size=self.num_classes_per_sample, replace=False + ) + else: + sampled_inds = list(range(len(sents))) + sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist() + sampled_ann_ids = np.vectorize(ann_ids.__getitem__)(sampled_inds).tolist() + sampled_classes = sampled_sents + img = cv2.imread(image_path) + images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + ori_size = images.shape[:2] + + # preprocess images for clip + images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")[ + "pixel_values" + ][0] + image_token_len = (images_clip.shape[1] // 14) * ( + images_clip.shape[2] // 14 + ) # FIXME: 14 is hardcoded patch size + images = self.transform.apply_image(images) # preprocess images for sam + resize = images.shape[:2] + + questions = [] + answers = [] + class_ids = [] + for text in sampled_classes: + text = text.strip() + assert len(text.split("||")) == 1 + question_template = random.choice(self.short_question_list) + questions.append(question_template.format(class_name=text.lower())) + answers.append(random.choice(self.answer_list)) + + conversations = [] + conv = get_default_conv_template("vicuna").copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + i = 0 + while i < len(questions): + conv.messages = [] + conv.append_message(conv.roles[0], questions[i]) + conv.append_message(conv.roles[1], answers[i]) + conversations.append(conv.get_prompt()) + i += 1 + + # ============================== + # replace token + for i in range(len(conversations)): + replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + replace_token = ( + DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + ) + conversations[i] = conversations[i].replace( + DEFAULT_IMAGE_TOKEN, replace_token + ) + # ============================== + + images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous()) + + masks = [] + for ann_id in sampled_ann_ids: + ann = annotations[ann_id] + + if len(ann["segmentation"]) == 0: + m = np.zeros((image["height"], image["width"])).astype(np.uint8) + masks.append(m) + continue + + if type(ann["segmentation"][0]) == list: # polygon + rle = mask.frPyObjects( + ann["segmentation"], image["height"], image["width"] + ) + else: + rle = ann["segmentation"] + for i in range(len(rle)): + if not isinstance(rle[i]["counts"], bytes): + rle[i]["counts"] = rle[i]["counts"].encode() + m = mask.decode(rle) + m = np.sum( + m, axis=2 + ) # sometimes there are multiple binary map (corresponding to multiple segs) + m = m.astype(np.uint8) # convert to np.uint8 + masks.append(m) + + masks = 
np.stack(masks, axis=0) + + # debug + # print("masks.shape: ", masks.shape) + # for i in range(masks.shape[0]): + # cv2.imwrite("debug/{}_mask_{}.png".format(image_path.split("refer_seg/images")[-1].replace("/", "-").split(".")[0], sampled_sents[i]), masks[i]*100) + + # debug + # if ds.endswith("masked"): + # save_dir = "./debug/{}".format(image_path.split("/")[-1].split(".")[0]) + # os.makedirs(save_dir, exist_ok=True) + # print("masks.shape: ", masks.shape) + # for i in range(masks.shape[0]): + # cv2.imwrite("{}/mask_{}.jpg".format(save_dir, i), masks[i]*100) + # assert len(conversations) == masks.shape[0] + # with open("{}/conversations.txt".format(save_dir), "w+") as f: + # for i in range(len(conversations)): + # f.write("{}. ".format(i) + conversations[i] + "\n") + # shutil.copy(image_path, save_dir) + + masks = torch.from_numpy(masks) + label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label + + # print("refer_seg: {}".format(conversations)) + + # # debug + # save_dir = "./debug/{}".format(image_path.split("/")[-1].split(".")[0]) + # os.makedirs(save_dir, exist_ok=True) + # print("masks.shape: ", masks.shape) + # for i in range(masks.shape[0]): + # cv2.imwrite("{}/mask_{}_{}.jpg".format(save_dir, i, sampled_classes[i]), masks[i].numpy().astype(np.uint8)*100) + # assert len(conversations) == masks.shape[0] + # with open("{}/conversations.txt".format(save_dir), "w+") as f: + # for i in range(len(conversations)): + # f.write("{}. ".format(i) + conversations[i] + "\n") + # shutil.copy(image_path, save_dir) + + return ( + image_path, + images, + images_clip, + conversations, + masks, + label, + resize, + questions, + sampled_classes, + ) diff --git a/utils/sem_seg_dataset.py b/utils/sem_seg_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d83aa1b389a6ed3ad5e19b88aad06438a7e61164 --- /dev/null +++ b/utils/sem_seg_dataset.py @@ -0,0 +1,359 @@ +import glob +import json +import os +import random + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from pycocotools.coco import COCO +from transformers import CLIPImageProcessor + +from model.segment_anything.utils.transforms import ResizeLongestSide + +from .conversation import get_default_conv_template +from .utils import (ANSWER_LIST, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, + DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN, + SHORT_QUESTION_LIST) + +def init_mapillary(base_image_dir): + mapillary_data_root = os.path.join(base_image_dir, "mapillary") + with open(os.path.join(mapillary_data_root, "config_v2.0.json")) as f: + mapillary_classes = json.load(f)["labels"] + mapillary_classes = [x["readable"].lower() for x in mapillary_classes] + mapillary_classes = np.array(mapillary_classes) + mapillary_labels = sorted( + glob.glob( + os.path.join(mapillary_data_root, "training", "v2.0", "labels", "*.png") + ) + ) + mapillary_images = [ + x.replace(".png", ".jpg").replace("v2.0/labels", "images") + for x in mapillary_labels + ] + print("mapillary: ", len(mapillary_images)) + return mapillary_classes, mapillary_images, mapillary_labels + + +def init_ade20k(base_image_dir): + with open("utils/ade20k_classes.json", "r") as f: + ade20k_classes = json.load(f) + ade20k_classes = np.array(ade20k_classes) + image_ids = sorted( + os.listdir(os.path.join(base_image_dir, "ade20k/images", "training")) + ) + ade20k_image_ids = [] + for x in image_ids: + if x.endswith(".jpg"): + ade20k_image_ids.append(x[:-4]) + ade20k_images = [] + for image_id in 
ade20k_image_ids: # self.descriptions: + ade20k_images.append( + os.path.join( + base_image_dir, + "ade20k", + "images", + "training", + "{}.jpg".format(image_id), + ) + ) + ade20k_labels = [ + x.replace(".jpg", ".png").replace("images", "annotations") + for x in ade20k_images + ] + print("ade20k: ", len(ade20k_images)) + return ade20k_classes, ade20k_images, ade20k_labels + + +def init_cocostuff(base_image_dir): + cocostuff_classes = [] + with open("utils/cocostuff_classes.txt") as f: + for line in f.readlines()[1:]: + cocostuff_classes.append(line.strip().split(": ")[-1]) + cocostuff_classes = np.array(cocostuff_classes) + cocostuff_images = [] + cocostuff_image_dir = glob.glob( + os.path.join(base_image_dir, "cocostuff", "train2017", "*.jpg") + ) + for image_id in cocostuff_image_dir: + cocostuff_images.append(image_id) + cocostuff_labels = [ + x.replace(".jpg", ".png").replace("images", "annotations") + for x in cocostuff_images + ] + print("cocostuff: ", len(cocostuff_images)) + return cocostuff_classes, cocostuff_images, cocostuff_labels + + +def init_paco_lvis(base_image_dir): + coco_api_paco_lvis = COCO( + os.path.join( + base_image_dir, "vlpart", "paco", "annotations", "paco_lvis_v1_train.json" + ) + ) + all_classes = coco_api_paco_lvis.loadCats(coco_api_paco_lvis.getCatIds()) + class_map_paco_lvis = {} + for cat in all_classes: + cat_split = cat["name"].strip().split(":") + if len(cat_split) == 1: + name = cat_split[0].split("_(")[0] + else: + assert len(cat_split) == 2 + obj, part = cat_split + obj = obj.split("_(")[0] + part = part.split("_(")[0] + # if random.random() < 0.5: + # name = obj + " " + part + # else: + # name = "the {} of the {}".format(part, obj) + name = (obj, part) + class_map_paco_lvis[cat["id"]] = name + img_ids = coco_api_paco_lvis.getImgIds() + print("paco_lvis: ", len(img_ids)) + return class_map_paco_lvis, img_ids, coco_api_paco_lvis + + +def init_pascal_part(base_image_dir): + coco_api_pascal_part = COCO( + os.path.join(base_image_dir, "vlpart", "pascal_part", "train.json") + ) + all_classes = coco_api_pascal_part.loadCats(coco_api_pascal_part.getCatIds()) + class_map_pascal_part = {} + for cat in all_classes: + cat_main, cat_part = cat["name"].strip().split(":") + name = (cat_main, cat_part) + class_map_pascal_part[cat["id"]] = name + img_ids = coco_api_pascal_part.getImgIds() + print("pascal_part: ", len(img_ids)) + return class_map_pascal_part, img_ids, coco_api_pascal_part + + +class SemSegDataset(torch.utils.data.Dataset): + pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) + pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) + img_size = 1024 + ignore_label = 255 + + def __init__( + self, + base_image_dir, + tokenizer, + vision_tower, + samples_per_epoch=500 * 8 * 2 * 10, + precision: str = "fp32", + image_size: int = 224, + num_classes_per_sample: int = 3, + exclude_val=False, + sem_seg_data="ade20k||cocostuff||partimagenet||pascal_part||paco_lvis||mapillary", + ): + self.exclude_val = exclude_val + self.samples_per_epoch = samples_per_epoch + self.num_classes_per_sample = num_classes_per_sample + + self.base_image_dir = base_image_dir + self.image_size = image_size + self.tokenizer = tokenizer + self.precision = precision + self.transform = ResizeLongestSide(image_size) + self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower) + + self.short_question_list = SHORT_QUESTION_LIST + self.answer_list = ANSWER_LIST + + self.data2list = {} + self.data2classes = {} + + self.sem_seg_datas = 
sem_seg_data.split("||") + for ds in self.sem_seg_datas: + classes, images, labels = eval("init_{}".format(ds))(base_image_dir) + self.data2list[ds] = (images, labels) + self.data2classes[ds] = classes + + if "cocostuff" in self.sem_seg_datas: + self.cocostuff_class2index = { + c: i for i, c in enumerate(self.data2classes["cocostuff"]) + } + + def __len__(self): + return self.samples_per_epoch + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.img_size - h + padw = self.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x + + def __getitem__(self, idx): + ds = random.randint(0, len(self.sem_seg_datas) - 1) + ds = self.sem_seg_datas[ds] + + if ds in ["paco_lvis", "pascal_part"]: + class_map = self.data2classes[ds] + img_ids, coco_api = self.data2list[ds] + idx = random.randint(0, len(img_ids) - 1) + img_id = img_ids[idx] + image = coco_api.loadImgs([img_id])[0] + file_name = image["file_name"] + if ds == "pascal_part": + file_name = os.path.join( + "VOCdevkit", "VOC2010", "JPEGImages", file_name + ) + image_path = os.path.join(self.base_image_dir, "vlpart", ds, file_name) + elif ds == "paco_lvis": + image_path = os.path.join(self.base_image_dir, "coco", file_name) + img = cv2.imread(image_path) + images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # preprocess images for clip + images_clip = self.clip_image_processor.preprocess( + images, return_tensors="pt" + )["pixel_values"][0] + image_token_len = (images_clip.shape[1] // 14) * ( + images_clip.shape[2] // 14 + ) # FIXME: 14 is hardcoded patch size + + images = self.transform.apply_image(images) # preprocess images for sam + resize = images.shape[:2] + annIds = coco_api.getAnnIds(imgIds=image["id"]) + anns = coco_api.loadAnns(annIds) + if len(anns) == 0: + return self.__getitem__(0) + if len(anns) >= self.num_classes_per_sample: + sampled_anns = np.random.choice( + anns, size=self.num_classes_per_sample, replace=False + ).tolist() + else: + sampled_anns = anns + sampled_classes = [] + for ann in sampled_anns: + sampled_cls = class_map[ann["category_id"]] + if isinstance(sampled_cls, tuple): + obj, part = sampled_cls + if random.random() < 0.5: + name = obj + " " + part + else: + name = "the {} of the {}".format(part, obj) + else: + name = sampled_cls + sampled_classes.append(name) + + elif ds in ["ade20k", "cocostuff", "mapillary"]: + images, labels = self.data2list[ds] + idx = random.randint(0, len(images) - 1) + image_path = images[idx] + label_path = labels[idx] + label = Image.open(label_path) + label = np.array(label) + if ds == "ade20k": + label[label == 0] = 255 + label -= 1 + label[label == 254] = 255 + elif ds == "cocostuff": + for c, i in self.cocostuff_class2index.items(): + if "-" in c: + label[label == i] = 255 + img = cv2.imread(image_path) + images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + # preprocess images for clip + images_clip = self.clip_image_processor.preprocess( + images, return_tensors="pt" + )["pixel_values"][0] + image_token_len = (images_clip.shape[1] // 14) * ( + images_clip.shape[2] // 14 + ) # FIXME: 14 is hardcoded patch size + images = self.transform.apply_image(images) # preprocess images for sam + resize = images.shape[:2] + unique_label = np.unique(label).tolist() + if 255 in unique_label: + unique_label.remove(255) + if len(unique_label) == 0: + return self.__getitem__(0) + + classes = [self.data2classes[ds][class_id] for 
class_id in unique_label] + if len(classes) >= self.num_classes_per_sample: + sampled_classes = np.random.choice( + classes, size=self.num_classes_per_sample, replace=False + ).tolist() + else: + sampled_classes = classes + + questions = [] + answers = [] + class_ids = [] + for sampled_cls in sampled_classes: + text = sampled_cls + + assert len(text.split("||")) == 1 + question_template = random.choice(self.short_question_list) + questions.append(question_template.format(class_name=text.lower())) + + answers.append(random.choice(self.answer_list)) + + if ds in ["paco_lvis", "pascal_part"]: + continue + + class_id = self.data2classes[ds].tolist().index(sampled_cls) + class_ids.append(class_id) + + conversations = [] + conv = get_default_conv_template("vicuna").copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + i = 0 + while i < len(questions): + conv.messages = [] + conv.append_message(conv.roles[0], questions[i]) + conv.append_message(conv.roles[1], answers[i]) + conversations.append(conv.get_prompt()) + i += 1 + + # replace token + for i in range(len(conversations)): + replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + replace_token = ( + DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + ) + conversations[i] = conversations[i].replace( + DEFAULT_IMAGE_TOKEN, replace_token + ) + + images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous()) + + if ds in ["paco_lvis", "pascal_part"]: + masks = [] + for ann in sampled_anns: + try: + masks.append(coco_api.annToMask(ann)) + except Exception as e: + print(e) + return self.__getitem__(0) + + masks = np.stack(masks, axis=0) + masks = torch.from_numpy(masks) + label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label + + else: + label = torch.from_numpy(label).long() + masks = [] + for class_id in class_ids: + masks.append(label == class_id) + masks = torch.stack(masks, dim=0) + return ( + image_path, + images, + images_clip, + conversations, + masks, + label, + resize, + questions, + sampled_classes, + ) diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..275aa832daeb5136ef10d1774cdfaa9bd1ae5bae --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,156 @@ +from enum import Enum + +import numpy as np +import torch +import torch.distributed as dist + +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" + +SHORT_QUESTION_LIST = [ + DEFAULT_IMAGE_TOKEN + " " + "Can you segment the {class_name} in this image?", + DEFAULT_IMAGE_TOKEN + " " + "Please segment the {class_name} in this image.", + DEFAULT_IMAGE_TOKEN + " " + "What is {class_name} in this image? Please respond with segmentation mask.", + DEFAULT_IMAGE_TOKEN + " " + "What is {class_name} in this image? 
Please output segmentation mask.", +] + +LONG_QUESTION_LIST = [ + DEFAULT_IMAGE_TOKEN + " " + "{sent} Please respond with segmentation mask.", + DEFAULT_IMAGE_TOKEN + " " + "{sent} Please output segmentation mask.", +] + +EXPLANATORY_QUESTION_LIST = [ + "Please output segmentation mask and explain why.", + "Please output segmentation mask and explain the reason.", + "Please output segmentation mask and give some explaination.", +] + +ANSWER_LIST = [ + "It is [SEG].", + "Sure, [SEG].", + "Sure, it is [SEG].", + "Sure, the segmentation result is [SEG].", + "[SEG].", +] + + +class Summary(Enum): + NONE = 0 + AVERAGE = 1 + SUM = 2 + COUNT = 3 + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): + self.name = name + self.fmt = fmt + self.summary_type = summary_type + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def all_reduce(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(self.sum, np.ndarray): + total = torch.tensor( + self.sum.tolist() + + [ + self.count, + ], + dtype=torch.float32, + device=device, + ) + else: + total = torch.tensor( + [self.sum, self.count], dtype=torch.float32, device=device + ) + + dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False) + if total.shape[0] > 2: + self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item() + else: + self.sum, self.count = total.tolist() + self.avg = self.sum / (self.count + 1e-5) + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + def summary(self): + fmtstr = "" + if self.summary_type is Summary.NONE: + fmtstr = "" + elif self.summary_type is Summary.AVERAGE: + fmtstr = "{name} {avg:.3f}" + elif self.summary_type is Summary.SUM: + fmtstr = "{name} {sum:.3f}" + elif self.summary_type is Summary.COUNT: + fmtstr = "{name} {count:.3f}" + else: + raise ValueError("invalid summary type %r" % self.summary_type) + + return fmtstr.format(**self.__dict__) + + +def intersectionAndUnionGPU(output, target, K, ignore_index=255): + # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 
+ assert output.dim() in [1, 2, 3] + assert output.shape == target.shape + output = output.view(-1) + target = target.view(-1) + output[target == ignore_index] = ignore_index + intersection = output[output == target] + area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1) + area_output = torch.histc(output, bins=K, min=0, max=K - 1) + area_target = torch.histc(target, bins=K, min=0, max=K - 1) + area_union = area_output + area_target - area_intersection + return area_intersection, area_union, area_target + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print("\t".join(entries)) + + def display_summary(self): + entries = [" *"] + entries += [meter.summary() for meter in self.meters] + print(" ".join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = "{:" + str(num_digits) + "d}" + return "[" + fmt + "/" + fmt.format(num_batches) + "]" + + +def dict_to_cuda(input_dict): + for k, v in input_dict.items(): + if isinstance(input_dict[k], torch.Tensor): + input_dict[k] = v.cuda(non_blocking=True) + elif ( + isinstance(input_dict[k], list) + and len(input_dict[k]) > 0 + and isinstance(input_dict[k][0], torch.Tensor) + ): + input_dict[k] = [ele.cuda(non_blocking=True) for ele in v] + return input_dict diff --git a/utils/vqa_dataset.py b/utils/vqa_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9aeb322a9f4fdbf5bd6659d133fcd000a737d48d --- /dev/null +++ b/utils/vqa_dataset.py @@ -0,0 +1,126 @@ +import json +import os +import random + +import cv2 +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor + +from model.segment_anything.utils.transforms import ResizeLongestSide + +from .conversation import get_default_conv_template +from .utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, + DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN) + +class VQADataset(torch.utils.data.Dataset): + pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) + pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) + img_size = 1024 + ignore_label = 255 + + def __init__( + self, + base_image_dir, + tokenizer, + vision_tower, + samples_per_epoch=500 * 8 * 2 * 10, + precision: str = "fp32", + image_size: int = 224, + num_classes_per_sample: int = 3, + exclude_val=False, + vqa_data="llava_instruct_150k", + ): + self.exclude_val = exclude_val + self.samples_per_epoch = samples_per_epoch + self.num_classes_per_sample = num_classes_per_sample + + self.base_image_dir = base_image_dir + self.image_size = image_size + self.tokenizer = tokenizer + self.precision = precision + self.transform = ResizeLongestSide(image_size) + self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower) + + DATA_DIR = os.path.join(base_image_dir, "llava_dataset") + self.vqa_image_root = os.path.join(base_image_dir, "coco/train2017") + with open(os.path.join(DATA_DIR, "{}.json".format(vqa_data))) as f: + vqa_data = json.load(f) + self.vqa_data = vqa_data + + print("vqa_data: ", len(self.vqa_data)) + + def __len__(self): + return self.samples_per_epoch + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = 
(x - self.pixel_mean) / self.pixel_std
+
+        # Pad
+        h, w = x.shape[-2:]
+        padh = self.img_size - h
+        padw = self.img_size - w
+        x = F.pad(x, (0, padw, 0, padh))
+        return x
+
+    def __getitem__(self, idx):
+        idx = random.randint(0, len(self.vqa_data) - 1)
+        item = self.vqa_data[idx]
+        image_path = os.path.join(self.vqa_image_root, item["image"])
+        img = cv2.imread(image_path)
+        images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        ori_size = images.shape[:2]
+        images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")["pixel_values"][0]  # preprocess images for clip
+        image_token_len = (images_clip.shape[1] // 14) * (
+            images_clip.shape[2] // 14
+        )  # FIXME: 14 is hardcoded patch size
+
+        images = self.transform.apply_image(images)  # preprocess images for sam
+        resize = images.shape[:2]
+        source = item["conversations"]
+        conv = get_default_conv_template(
+            "vicuna"
+        ).copy()  # conversation_lib.default_conversation.copy()
+        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+        conversations = []
+        if roles[source[0]["from"]] != conv.roles[0]:
+            # Skip the first one if it is not from human
+            source = source[1:]
+        conv.messages = []
+        for j, sentence in enumerate(source):
+            role = roles[sentence["from"]]
+            assert role == conv.roles[j % 2], f"{j}"  # report the offending turn index
+            conv.append_message(role, sentence["value"])
+        conversations.append(conv.get_prompt())
+
+        questions = conversations
+        sampled_classes = conversations
+
+        # replace token
+        for i in range(len(conversations)):
+            replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
+            replace_token = (
+                DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+            )
+            conversations[i] = conversations[i].replace(
+                DEFAULT_IMAGE_TOKEN, replace_token
+            )
+
+        images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
+
+        masks = torch.rand(0, *ori_size)
+        label = torch.ones(ori_size) * self.ignore_label
+
+        return (
+            image_path,
+            images,
+            images_clip,
+            conversations,
+            masks,
+            label,
+            resize,
+            questions,
+            sampled_classes,
+        )
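
The dataset classes added above (reason_seg, refer_seg, sem_seg, vqa) all share the same SAM-style image pipeline: resize the longest side to `img_size` (1024), normalize with fixed pixel statistics, then zero-pad to a square. The sketch below is illustrative only and is not part of the patch; it copies the `pixel_mean`/`pixel_std`/`img_size` constants from those classes but substitutes a plain `cv2.resize` helper for `ResizeLongestSide.apply_image`, so treat the helper name and the dummy input as assumptions.

```
# Minimal, self-contained sketch of the shared SAM-style preprocessing
# (not part of the patch; `resize_longest_side` stands in for
# model.segment_anything.utils.transforms.ResizeLongestSide.apply_image).
import cv2
import numpy as np
import torch
import torch.nn.functional as F

PIXEL_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
PIXEL_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
IMG_SIZE = 1024  # SAM input resolution, matching `img_size` in the dataset classes


def resize_longest_side(image: np.ndarray, target: int = IMG_SIZE) -> np.ndarray:
    """Scale the image so its longer side equals `target`, keeping aspect ratio."""
    h, w = image.shape[:2]
    scale = target / max(h, w)
    return cv2.resize(image, (int(round(w * scale)), int(round(h * scale))))


def preprocess_for_sam(image_rgb: np.ndarray) -> torch.Tensor:
    """Resize, normalize, and zero-pad to a square IMG_SIZE x IMG_SIZE tensor."""
    resized = resize_longest_side(image_rgb)
    x = torch.from_numpy(resized).permute(2, 0, 1).contiguous().float()
    x = (x - PIXEL_MEAN) / PIXEL_STD  # normalize colors
    h, w = x.shape[-2:]
    x = F.pad(x, (0, IMG_SIZE - w, 0, IMG_SIZE - h))  # pad right/bottom to square
    return x


if __name__ == "__main__":
    dummy = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)  # stand-in image
    print(preprocess_for_sam(dummy).shape)  # torch.Size([3, 1024, 1024])
```

Zero-padding after normalization leaves the padded border at the mean pixel value, which is the same convention followed by the `preprocess` methods defined in each dataset class above, so the padded tensor and the recorded `resize` shape stay consistent for later un-padding.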