"""Utilities for loading LLaVA-Mistral multimodal and plain causal LM checkpoints."""

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from medrax.llava.model import LlavaMistralForCausalLM
from medrax.llava.constants import (
    DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)


def load_pretrained_model(
    model_path,
    model_base,
    model_name,
    load_in_8bit=False,
    load_in_4bit=True,
    device="cuda",
    cache_dir: str = "/model-weights",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
):
    """Load a tokenizer and model, optionally quantized with bitsandbytes.

    Returns a ``(tokenizer, model, image_processor, context_len)`` tuple, where
    ``image_processor`` is ``None`` for plain (non-LLaVA) language models.
    """
    kwargs = {}

    if device != "cuda":
        # Pin every module to the requested device; on CUDA the default placement is used.
        kwargs["device_map"] = {"": device}

    if load_in_8bit:
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    elif load_in_4bit:
        # 4-bit NF4 quantization with double quantization, computing in torch_dtype
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    if "llava" in model_name.lower():
        # Load LLaVA model
        if "mistral" in model_name.lower():
            tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=cache_dir)
            model = LlavaMistralForCausalLM.from_pretrained(
                model_path,
                low_cpu_mem_usage=low_cpu_mem_usage,
                use_flash_attention_2=False,
                cache_dir=cache_dir,
                torch_dtype=torch_dtype,
                **kwargs,
            )

    else:
        # Plain causal language model. When model_base is given, model_path is treated
        # as a LoRA adapter that is merged into the base weights below.
        if model_base is not None:
            from peft import PeftModel

            tokenizer = AutoTokenizer.from_pretrained(
                model_base, use_fast=False, cache_dir=cache_dir
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_base,
                low_cpu_mem_usage=True,
                cache_dir=cache_dir,
                torch_dtype=torch_dtype,
                **kwargs,
            )
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging weights")
            model = model.merge_and_unload()
            print("Convert to FP16...")
            model.to(torch_dtype)
        else:
            if "mpt" in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, use_fast=True, cache_dir=cache_dir
                )
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    trust_remote_code=True,
                    cache_dir=cache_dir,
                    torch_dtype=torch_dtype,
                    **kwargs,
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, use_fast=False, cache_dir=cache_dir
                )
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    cache_dir=cache_dir,
                    torch_dtype=torch_dtype,
                    **kwargs,
                )

    image_processor = None

    if "llava" in model_name.lower():  # or 'mistral' in model_name.lower():
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )
        model.resize_token_embeddings(len(tokenizer))

        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()

        vision_tower.to(device=device, dtype=torch_dtype)
        model.model.mm_projector.to(device=device, dtype=torch_dtype)

        if not (load_in_4bit or load_in_8bit):
            model.to(device=device, dtype=torch_dtype)

        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
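

# Example usage (a minimal sketch): the checkpoint path and model name below are
# illustrative placeholders rather than values shipped with this module; point them
# at whatever LLaVA-Mistral weights you actually use.
if __name__ == "__main__":
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path="path/to/llava-mistral-checkpoint",  # hypothetical path, adjust as needed
        model_base=None,
        model_name="llava-mistral",
        load_in_4bit=True,
        device="cuda",
    )
    print(f"Loaded model with context length {context_len}")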