import torch
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
from peft import PeftModel

#-----------------------------------------
# Quantization Config
#-----------------------------------------
def get_bnb_config():
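    # 4-bit NF4 quantization with nested (double) quantization; compute is done in float16.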
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.float16
    )

#-----------------------------------------
# Base Model Loader
#-----------------------------------------
def load_base_model(base_model_path: str):
    """
    Loads a base causal LM with 4-bit quantization, along with its tokenizer.
    
    Args:
        base_model_path (str): HF model path
    
    Returns:
        model (AutoModelForCausalLM)
        tokenizer (PreTrainedTokenizerFast)
    """
    bnb_config = get_bnb_config()

    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path)

    model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16
    )

    return model, tokenizer
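
# Illustrative usage (placeholder model id, not necessarily the checkpoint this
# project targets; any causal-LM checkpoint on the Hub should work the same way):
#   base_model, base_tokenizer = load_base_model("meta-llama/Meta-Llama-3-8B-Instruct")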

#-----------------------------------------
# Fine-Tuned Model Loader
#-----------------------------------------
def load_fine_tuned_model(adapter_path: str, base_model_path: str):
    """
    Loads the fine-tuned model by applying a LoRA adapter to the quantized base model.
    
    Args:
        adapter_path (str): Local or HF adapter path
        base_model_path (str): Base LLM model path
    
    Returns:
        fine_tuned_model (PeftModel)
        tokenizer (PreTrainedTokenizerFast)
    """
    bnb_config = get_bnb_config()

    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path)

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16
    )

    fine_tuned_model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        device_map="auto"
    )

    return fine_tuned_model, tokenizer
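
# Illustrative usage (placeholder paths; the adapter must have been trained on
# the same base model it is applied to):
#   ft_model, ft_tokenizer = load_fine_tuned_model("username/my-lora-adapter",
#                                                  "meta-llama/Meta-Llama-3-8B-Instruct")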

#-----------------------------------------
# Inference Function
#-----------------------------------------
@torch.no_grad()
def generate_response(
    model: AutoModelForCausalLM,
    tokenizer: PreTrainedTokenizerFast,
    messages: list,
    do_sample: bool = False,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.95,
    num_beams: int = 1,
    max_new_tokens: int = 500
) -> str:
    """
    Runs chat-style inference on a causal LM (base or LoRA-adapted).

    Args:
        model (AutoModelForCausalLM)
        tokenizer (PreTrainedTokenizerFast)
        messages (list): List of dicts with 'role' and 'content' keys
        do_sample (bool): If True, sample with temperature/top_k/top_p;
            otherwise decode greedily (or with beam search via num_beams)
    
    Returns:
        str: Model response
    """
    # Ensure a pad token exists (Llama-3-style reserved token used as a fallback)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = "<|reserved_special_token_5|>"

    # Create chat prompt
    input_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False
    )

    # Tokenize input
    inputs = tokenizer(
        input_text,
        max_length=500,
        truncation=True,
        return_tensors="pt"
    ).to(model.device)

    # Pass sampling parameters only when sampling is enabled; otherwise use
    # greedy/beam search so generate() does not receive unused sampling flags.
    generation_params = {
        "do_sample": do_sample,
        "max_new_tokens": max_new_tokens
    }
    if do_sample:
        generation_params.update(temperature=temperature, top_k=top_k, top_p=top_p)
    else:
        generation_params["num_beams"] = num_beams

    output = model.generate(**inputs, **generation_params)

    # Decode only the newly generated tokens so the prompt itself (and any
    # literal occurrence of the word "assistant" in it) is excluded.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

    return response.strip()
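
#-----------------------------------------
# Example Usage
#-----------------------------------------
# Minimal end-to-end sketch. The paths below are placeholders (not checkpoints
# shipped with this module); substitute your own base model and LoRA adapter.
if __name__ == "__main__":
    BASE_MODEL_PATH = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder base model
    ADAPTER_PATH = "username/my-lora-adapter"                # placeholder adapter

    ft_model, ft_tokenizer = load_fine_tuned_model(ADAPTER_PATH, BASE_MODEL_PATH)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain LoRA fine-tuning in one sentence."},
    ]

    print(generate_response(ft_model, ft_tokenizer, messages, do_sample=True))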