import os
import torch
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
from peft import PeftModel
import re
import streamlit as st

# Set the cache directory to persistent storage
os.environ["HF_HOME"] = "/data/.cache/huggingface"

#-----------------------------------------
# Quantization Config
#-----------------------------------------
def get_bnb_config():
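    """
    Returns a 4-bit NF4 quantization config (double quantization, float16 compute)
    shared by the base and fine-tuned model loaders below.
    """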
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.float16
    )

#-----------------------------------------
# Base Model Loader
#-----------------------------------------
@st.cache_resource
def load_base_model(base_model_path: str):
    """
    Loads a base LLM with 4-bit quantization, along with its tokenizer.
    
    Args:
        base_model_path (str): HF model path
    
    Returns:
        model (AutoModelForCausalLM)
        tokenizer (PreTrainedTokenizerFast)
    """
    bnb_config = get_bnb_config()

    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path)

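    # Eager (vanilla PyTorch) attention; avoids depending on SDPA / flash-attn kernels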
    model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16
    )

    return model, tokenizer

#-----------------------------------------
# Fine-Tuned Model Loader
#-----------------------------------------
@st.cache_resource
def load_fine_tuned_model(adapter_path: str, base_model_path: str):
    """
    Loads the fine-tuned model by applying a LoRA adapter to the base model.
    
    Args:
        adapter_path (str): Local or HF adapter path
        base_model_path (str): Base LLM model path
    
    Returns:
        fine_tuned_model (PeftModel)
        tokenizer (PreTrainedTokenizerFast)
    """
    bnb_config = get_bnb_config()

    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path)

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16
    )

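    # Attach the LoRA adapter weights on top of the quantized base model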
    fine_tuned_model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        device_map="auto"
    )

    return fine_tuned_model, tokenizer

#-----------------------------------------
# Inference Function
#-----------------------------------------
@torch.no_grad()
def generate_response(
    model: AutoModelForCausalLM,
    tokenizer: PreTrainedTokenizerFast,
    messages: list,
    tokenizer_max_length: int = 500,
    do_sample: bool = False,
    temperature: float = 0.1,
    top_k: int = 50,
    top_p: float = 0.95,
    num_beams: int = 1,
    max_new_tokens: int = 700
) -> str:
    """
    Runs inference on an LLM model.
    Args:
        model (AutoModelForCausalLM)
        tokenizer (PreTrainedTokenizerFast)
        messages (list): List of dicts containing 'role' and 'content'
    
    Returns:
        str: Model response
    """
    # Assign a reserved special token as the pad token (the base tokenizer may not define one)
    tokenizer.pad_token = "<|reserved_special_token_5|>"

    # Create chat prompt
    input_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False
    )

    # Tokenize input
    inputs = tokenizer(
        input_text,
        max_length=tokenizer_max_length,
        truncation=True,
        return_tensors="pt"
    ).to(model.device)

    # Number of prompt tokens, used below to slice out only the generated continuation
    input_token_length = inputs.input_ids.shape[1]

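    # Sampling parameters are nulled out when do_sample is False so they are not
    # applied (or warned about) during greedy / beam-search decoding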
    generation_params = {
        "do_sample": do_sample,
        "temperature": temperature if do_sample else None,
        "top_k": top_k if do_sample else None,
        "top_p": top_p if do_sample else None,
        "num_beams": num_beams if not do_sample else 1,
        "max_new_tokens": max_new_tokens
    }

    output = model.generate(**inputs, **generation_params)

    # Extract only the newly generated tokens
    new_tokens = output[0][input_token_length:]
     
    # Decode only the new tokens
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Strip a leading "assistant" role marker if the chat template leaked it into the
    # output. (Splitting on every occurrence of the word would mangle answers that
    # legitimately contain "assistant", so only a leading marker is removed.)
    response = re.sub(r"^\s*(?:<assistant>|assistant)[:\s]*", "", response, flags=re.IGNORECASE)
    
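    # Drop a leading "Answer:" label that some prompts elicit before the actual answer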
    response = re.sub(r'^\s*(?:answer\s*)+:?\s*', '', response, flags=re.IGNORECASE)
    
    return response
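
#-----------------------------------------
# Example Usage (illustrative sketch)
#-----------------------------------------
# Minimal sketch of how the helpers above fit together. The paths below are
# placeholders, not the values used by this app; substitute your own base model
# and LoRA adapter. Run inside Streamlit for the resource caching to take effect.
if __name__ == "__main__":
    BASE_MODEL_PATH = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder base model
    ADAPTER_PATH = "path/to/lora-adapter"                    # placeholder adapter

    model, tokenizer = load_fine_tuned_model(ADAPTER_PATH, BASE_MODEL_PATH)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain NF4 quantization in one sentence."},
    ]
    print(generate_response(model, tokenizer, messages))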