# finite-element-method/utils/llama_utils.py
import os
import re

# Set the Hugging Face cache directory to persistent storage.
# This must run before transformers/huggingface_hub are imported,
# since they read HF_HOME at import time.
os.environ["HF_HOME"] = "/data/.cache/huggingface"

import torch
import streamlit as st
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
from peft import PeftModel


#-----------------------------------------
# Quantization Config
#-----------------------------------------
def get_bnb_config():
    """Return the shared 4-bit NF4 quantization config used by both model loaders."""
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.float16,
    )
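
# Rough memory arithmetic for these settings (assuming a 3B-parameter checkpoint,
# which is an assumption about the model this app targets, not something read
# from the config): NF4 packs weights into ~4 bits each, so the quantized weights
# need about 3e9 * 0.5 bytes ≈ 1.5 GB plus per-block quantization constants,
# versus about 3e9 * 2 bytes ≈ 6 GB for a plain float16 load. Compute still runs
# in float16 (bnb_4bit_compute_dtype).
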
#-----------------------------------------
# Base Model Loader
#-----------------------------------------
@st.cache_resource
def load_base_model(base_model_path: str):
    """
    Load a base LLM with 4-bit quantization, together with its tokenizer.

    Args:
        base_model_path (str): Hugging Face model path.

    Returns:
        model (AutoModelForCausalLM)
        tokenizer (PreTrainedTokenizerFast)
    """
    bnb_config = get_bnb_config()
    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path, return_tensors="pt")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16
    )
    return model, tokenizer
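
# Illustrative usage from a Streamlit page (the checkpoint name below is an
# assumption, not necessarily the one this app is configured with):
#
#     model, tokenizer = load_base_model("meta-llama/Llama-3.2-3B-Instruct")
#
# Because of @st.cache_resource, the quantized model is loaded once per process
# and reused across Streamlit reruns rather than being reloaded on every interaction.
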
#-----------------------------------------
# Fine-Tuned Model Loader
#-----------------------------------------
@st.cache_resource
def load_fine_tuned_model(adapter_path: str, base_model_path: str):
    """
    Load the fine-tuned model by applying a LoRA adapter to a quantized base model.

    Args:
        adapter_path (str): Local or Hugging Face adapter path.
        base_model_path (str): Base LLM model path.

    Returns:
        fine_tuned_model (PeftModel)
        tokenizer (PreTrainedTokenizerFast)
    """
    bnb_config = get_bnb_config()
    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path, return_tensors="pt")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16
    )
    fine_tuned_model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        device_map="auto"
    )
    return fine_tuned_model, tokenizer
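
# Illustrative usage (both paths are placeholders; the adapter must have been
# trained against the same base checkpoint it is applied to):
#
#     model, tokenizer = load_fine_tuned_model(
#         adapter_path="your-org/fem-llama-lora",              # hypothetical adapter repo
#         base_model_path="meta-llama/Llama-3.2-3B-Instruct",  # assumed base checkpoint
#     )
#
# PeftModel.from_pretrained leaves the 4-bit base weights untouched and loads the
# small LoRA matrices on top, so memory stays close to the base model alone.
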
#-----------------------------------------
# Inference Function
#-----------------------------------------
@torch.no_grad()
def generate_response(
    model: AutoModelForCausalLM,
    tokenizer: PreTrainedTokenizerFast,
    messages: list,
    tokenizer_max_length: int = 500,
    do_sample: bool = False,
    temperature: float = 0.1,
    top_k: int = 50,
    top_p: float = 0.95,
    num_beams: int = 1,
    max_new_tokens: int = 700
) -> str:
    """
    Run inference on an LLM and return only the newly generated text.

    Args:
        model (AutoModelForCausalLM)
        tokenizer (PreTrainedTokenizerFast)
        messages (list): List of dicts containing 'role' and 'content'.

    Returns:
        str: Model response.
    """
    # Ensure a pad token exists (uses one of the model's reserved special tokens)
    tokenizer.pad_token = "<|reserved_special_token_5|>"

    # Create the chat prompt from the message history
    input_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False
    )

    # Tokenize the prompt and move it to the model's device
    inputs = tokenizer(
        input_text,
        max_length=tokenizer_max_length,
        truncation=True,
        return_tensors="pt"
    ).to(model.device)

    # Store the number of input tokens so the prompt can be sliced off later
    input_token_length = inputs.input_ids.shape[1]

    # Sampling knobs apply only when do_sample=True; num_beams only when not sampling
    generation_params = {
        "do_sample": do_sample,
        "temperature": temperature if do_sample else None,
        "top_k": top_k if do_sample else None,
        "top_p": top_p if do_sample else None,
        "num_beams": num_beams if not do_sample else 1,
        "max_new_tokens": max_new_tokens
    }

    output = model.generate(**inputs, **generation_params)

    # Extract only the newly generated tokens
    new_tokens = output[0][input_token_length:]

    # Decode only the new tokens
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Strip any leftover 'assistant' role markers from the decoded text
    if 'assistant' in response:
        response = response.split('assistant')[1].strip()
    # In case there's still an assistant prefix, clean it up
    if response.startswith("assistant") or response.startswith("<assistant>"):
        response = re.sub(r"^assistant[:\s]*|^<assistant>[\s]*", "", response, flags=re.IGNORECASE)
    response = re.sub(r'^\s*(?:answer\s*)+:?\s*', '', response, flags=re.IGNORECASE)

    return response
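
if __name__ == "__main__":
    # Minimal smoke test, not used by the Streamlit app. This is an illustrative
    # sketch: the checkpoint name is an assumption, and running it requires a CUDA
    # GPU with bitsandbytes installed (@st.cache_resource typically just logs a
    # warning when called outside a Streamlit session).
    demo_base = "meta-llama/Llama-3.2-3B-Instruct"  # assumed base checkpoint
    demo_model, demo_tokenizer = load_base_model(demo_base)
    reply = generate_response(
        demo_model,
        demo_tokenizer,
        messages=[{"role": "user", "content": "In one sentence, what is the finite element method?"}],
        max_new_tokens=200,
    )
    print(reply)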