from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

from medrax.llava.model import LlavaMistralForCausalLM
from medrax.llava.constants import (
    DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)


def load_pretrained_model(
    model_path,
    model_base,
    model_name,
    load_in_8bit=False,
    load_in_4bit=True,
    device="cuda",
    cache_dir: str = "/model-weights",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
):
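    """Load a pretrained LLaVA-Mistral or plain causal language model, optionally quantized.

    Returns a ``(tokenizer, model, image_processor, context_len)`` tuple;
    ``image_processor`` is ``None`` for models without a vision tower.
    """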
    kwargs = {}
    if device != "cuda":
        kwargs["device_map"] = {"": device}
    # else:
    #     kwargs["device_map"] = "auto"

    if load_in_8bit:
        kwargs["load_in_8bit"] = True
    elif load_in_4bit:
        # kwargs["load_in_4bit"] = True
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    # else:
    #     kwargs["torch_dtype"] = torch_dtype

if "llava" in model_name.lower():
# Load LLaVA model
if "mistral" in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=cache_dir)
model = LlavaMistralForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=low_cpu_mem_usage,
use_flash_attention_2=False,
cache_dir=cache_dir,
torch_dtype=torch_dtype,
**kwargs,
)
    else:
        # Load language model
        if model_base is not None:
            # PEFT model: load the base model, then apply and merge LoRA weights
            from peft import PeftModel

            tokenizer = AutoTokenizer.from_pretrained(
                model_base, use_fast=False, cache_dir=cache_dir
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_base,
                low_cpu_mem_usage=True,
                cache_dir=cache_dir,
                torch_dtype=torch_dtype,
                **kwargs,
            )
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging weights")
            model = model.merge_and_unload()
            print(f"Converting merged model to {torch_dtype}...")
            model.to(torch_dtype)
        else:
            use_fast = False
            if "mpt" in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, use_fast=True, cache_dir=cache_dir
                )
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    trust_remote_code=True,
                    cache_dir=cache_dir,
                    torch_dtype=torch_dtype,
                    **kwargs,
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, use_fast=False, cache_dir=cache_dir
                )
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    cache_dir=cache_dir,
                    torch_dtype=torch_dtype,
                    **kwargs,
                )

    image_processor = None

    if "llava" in model_name.lower():  # or 'mistral' in model_name.lower():
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )
        model.resize_token_embeddings(len(tokenizer))

        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        vision_tower.to(device=device, dtype=torch_dtype)
        model.model.mm_projector.to(device=device, dtype=torch_dtype)
        if not (load_in_4bit or load_in_8bit):
            model.to(device=device, dtype=torch_dtype)
        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
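
# Example usage (illustrative sketch; the checkpoint path and model name below
# are assumptions, not values shipped with this module):
#
#     tokenizer, model, image_processor, context_len = load_pretrained_model(
#         model_path="path/to/llava-mistral-checkpoint",
#         model_base=None,
#         model_name="llava-mistral",
#         load_in_4bit=True,
#         device="cuda",
#     )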