import torch
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
from peft import PeftModel
#-----------------------------------------
# Quantization Config
#-----------------------------------------
def get_bnb_config():
    """Return a 4-bit NF4 quantization config for loading models with bitsandbytes."""
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.float16
    )
#-----------------------------------------
# Base Model Loader
#-----------------------------------------
def load_base_model(base_model_path: str):
    """
    Load a base LLM with 4-bit quantization and its tokenizer.

    Args:
        base_model_path (str): Hugging Face Hub ID or local path of the base model.

    Returns:
        model (AutoModelForCausalLM): Quantized base model.
        tokenizer (PreTrainedTokenizerFast): Matching tokenizer.
    """
    bnb_config = get_bnb_config()
    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path)
    model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16
    )
    return model, tokenizer
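# Example (sketch): loading only the quantized base model. The path below is a
# placeholder assuming a Llama-3-style instruct checkpoint, not this repo's model.
#   model, tokenizer = load_base_model("meta-llama/Meta-Llama-3-8B-Instruct")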
#-----------------------------------------
# Fine-Tuned Model Loader
#-----------------------------------------
def load_fine_tuned_model(adapter_path: str, base_model_path: str):
    """
    Load the fine-tuned model by applying a LoRA adapter to a quantized base model.

    Args:
        adapter_path (str): Local or Hugging Face Hub path of the LoRA adapter.
        base_model_path (str): Hub ID or local path of the base model.

    Returns:
        fine_tuned_model (PeftModel): Base model with the adapter applied.
        tokenizer (PreTrainedTokenizerFast): Matching tokenizer.
    """
    bnb_config = get_bnb_config()
    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16
    )
    fine_tuned_model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        device_map="auto"
    )
    return fine_tuned_model, tokenizer
#-----------------------------------------
# Inference Function
#-----------------------------------------
@torch.no_grad()
def generate_response(
    model: AutoModelForCausalLM,
    tokenizer: PreTrainedTokenizerFast,
    messages: list,
    do_sample: bool = False,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.95,
    num_beams: int = 1,
    max_new_tokens: int = 500
) -> str:
    """
    Run chat-style inference on an LLM.

    Args:
        model (AutoModelForCausalLM): Model to generate with.
        tokenizer (PreTrainedTokenizerFast): Tokenizer providing the chat template.
        messages (list): List of dicts with 'role' and 'content' keys.
        do_sample (bool): Sample from the output distribution instead of greedy/beam search.
        temperature (float): Sampling temperature (used only when do_sample is True).
        top_k (int): Top-k cutoff for sampling.
        top_p (float): Nucleus (top-p) cutoff for sampling.
        num_beams (int): Number of beams for beam search (ignored when sampling).
        max_new_tokens (int): Maximum number of new tokens to generate.

    Returns:
        str: Model response text.
    """
    # Ensure a pad token exists (reserved special token from the Llama 3 vocabulary)
    tokenizer.pad_token = "<|reserved_special_token_5|>"

    # Build the chat prompt from the message list
    input_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False
    )

    # Tokenize the prompt, truncating it to at most 500 tokens
    inputs = tokenizer(
        input_text,
        max_length=500,
        truncation=True,
        return_tensors="pt"
    ).to(model.device)

    # Sampling parameters apply only when do_sample is True; otherwise greedy/beam search is used
    generation_params = {
        "do_sample": do_sample,
        "temperature": temperature if do_sample else None,
        "top_k": top_k if do_sample else None,
        "top_p": top_p if do_sample else None,
        "num_beams": num_beams if not do_sample else 1,
        "max_new_tokens": max_new_tokens
    }

    output = model.generate(**inputs, **generation_params)

    # Decode, then keep only the text after the last 'assistant' marker (the newly generated turn)
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    if 'assistant' in response:
        response = response.split('assistant')[-1].strip()
    return response
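#-----------------------------------------
# Example Usage (sketch)
#-----------------------------------------
# A minimal usage sketch: the model and adapter paths below are placeholders
# (not this repo's actual checkpoints), assuming a Llama-3-style base model and a
# LoRA adapter on the Hugging Face Hub. Requires a GPU with bitsandbytes installed.
if __name__ == "__main__":
    base_model_path = "meta-llama/Meta-Llama-3-8B-Instruct"  # hypothetical base model
    adapter_path = "your-username/your-lora-adapter"          # hypothetical adapter repo
    model, tokenizer = load_fine_tuned_model(adapter_path, base_model_path)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Briefly explain what a LoRA adapter does."},
    ]
    print(generate_response(model, tokenizer, messages, do_sample=True, temperature=0.7))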