from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

from medrax.llava.model import LlavaMistralForCausalLM
from medrax.llava.constants import (
    DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)


def load_pretrained_model(
    model_path,
    model_base,
    model_name,
    load_in_8bit=False,
    load_in_4bit=True,
    device="cuda",
    cache_dir: str = "/model-weights",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
):
    """Load a tokenizer, model, and (for LLaVA checkpoints) an image processor.

    `model_name` selects the loading path: a Mistral-based LLaVA multimodal
    checkpoint, or a plain language model, optionally merging LoRA weights from
    `model_path` onto `model_base`. By default the weights are loaded with
    4-bit NF4 quantization.

    Returns:
        Tuple of (tokenizer, model, image_processor, context_len); the image
        processor is None for non-LLaVA models.
    """
    kwargs = {}

    if device != "cuda":
        kwargs["device_map"] = {"": device}

    if load_in_8bit:
        kwargs["load_in_8bit"] = True
    elif load_in_4bit:
        # Default path: 4-bit NF4 quantization with double quantization,
        # computing in `torch_dtype`.
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    if "llava" in model_name.lower():
        # Multimodal LLaVA checkpoint; the Mistral-based variant is the only
        # one handled here.
        if "mistral" in model_name.lower():
            tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=cache_dir)
            model = LlavaMistralForCausalLM.from_pretrained(
                model_path,
                low_cpu_mem_usage=low_cpu_mem_usage,
                use_flash_attention_2=False,
                cache_dir=cache_dir,
                torch_dtype=torch_dtype,
                **kwargs,
            )
    else:
        # Plain language model. If `model_base` is given, `model_path` is
        # treated as a LoRA adapter and merged into the base model.
        if model_base is not None:
            from peft import PeftModel

            tokenizer = AutoTokenizer.from_pretrained(
                model_base, use_fast=False, cache_dir=cache_dir
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_base,
                low_cpu_mem_usage=low_cpu_mem_usage,
                cache_dir=cache_dir,
                torch_dtype=torch_dtype,
                **kwargs,
            )
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging weights")
            model = model.merge_and_unload()
            print(f"Converting merged model to {torch_dtype}...")
            model.to(torch_dtype)
        else:
            if "mpt" in model_name.lower():
                # MPT checkpoints ship custom modeling code and use a fast tokenizer.
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, use_fast=True, cache_dir=cache_dir
                )
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=low_cpu_mem_usage,
                    trust_remote_code=True,
                    cache_dir=cache_dir,
                    torch_dtype=torch_dtype,
                    **kwargs,
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, use_fast=False, cache_dir=cache_dir
                )
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=low_cpu_mem_usage,
                    cache_dir=cache_dir,
                    torch_dtype=torch_dtype,
                    **kwargs,
                )

    image_processor = None

    if "llava" in model_name.lower():
        # Register the extra image tokens the multimodal model expects and
        # resize the embedding table to match.
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )
        model.resize_token_embeddings(len(tokenizer))

        # Load the vision tower and move it and the multimodal projector to
        # the target device and dtype.
        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()

        vision_tower.to(device=device, dtype=torch_dtype)
        model.model.mm_projector.to(device=device, dtype=torch_dtype)

        # Quantized models are already placed by bitsandbytes; only move
        # full-precision models onto the target device.
        if not (load_in_4bit or load_in_8bit):
            model.to(device=device, dtype=torch_dtype)

        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
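

# A minimal usage sketch (not part of the loader itself): it assumes the
# LLaVA-Med checkpoint "microsoft/llava-med-v1.5-mistral-7b" and a CUDA device;
# swap in whatever checkpoint, cache directory, and device your setup uses.
if __name__ == "__main__":
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path="microsoft/llava-med-v1.5-mistral-7b",  # assumed checkpoint ID
        model_base=None,
        model_name="llava-med-v1.5-mistral-7b",
        load_in_4bit=True,  # default 4-bit NF4 quantization
        device="cuda",
    )
    print(f"Context length: {context_len}")
    print(f"Image processor: {type(image_processor).__name__}")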