# finite-element-method/utils/llama_utils.py
import os
import torch
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
from peft import PeftModel
import re
import streamlit as st
# Set the cache directory to persistent storage
os.environ["HF_HOME"] = "/data/.cache/huggingface"
#-----------------------------------------
# Quantization Config
#-----------------------------------------
def get_bnb_config():
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.float16
    )

#-----------------------------------------
# Base Model Loader
#-----------------------------------------
@st.cache_resource
def load_base_model(base_model_path: str):
"""
Loads a base LLM model with 4-bit quantization and tokenizer.
Args:
base_model_path (str): HF model path
Returns:
model (AutoModelForCausalLM)
tokenizer (PreTrainedTokenizerFast)
"""
bnb_config = get_bnb_config()
tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path, return_tensors="pt")
model = AutoModelForCausalLM.from_pretrained(
base_model_path,
quantization_config=bnb_config,
trust_remote_code=True,
attn_implementation="eager",
torch_dtype=torch.float16
)
return model, tokenizer
#-----------------------------------------
# Fine-Tuned Model Loader
#-----------------------------------------
@st.cache_resource
def load_fine_tuned_model(adapter_path: str, base_model_path: str):
"""
Loads the fine-tuned model by applying LoRA adapter to a base model.
Args:
adapter_path (str): Local or HF adapter path
base_model_path (str): Base LLM model path
Returns:
fine_tuned_model (PeftModel)
tokenizer (PreTrainedTokenizerFast)
"""
bnb_config = get_bnb_config()
tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path, return_tensors="pt")
base_model = AutoModelForCausalLM.from_pretrained(
base_model_path,
quantization_config=bnb_config,
trust_remote_code=True,
attn_implementation="eager",
torch_dtype=torch.float16
)
fine_tuned_model = PeftModel.from_pretrained(
base_model,
adapter_path,
device_map="auto"
)
return fine_tuned_model, tokenizer
#-----------------------------------------
# Inference Function
#-----------------------------------------
@torch.no_grad()
def generate_response(
    model: AutoModelForCausalLM,
    tokenizer: PreTrainedTokenizerFast,
    messages: list,
    do_sample: bool = False,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.95,
    num_beams: int = 1,
    max_new_tokens: int = 500
) -> str:
"""
Runs inference on an LLM model.
Args:
model (AutoModelForCausalLM)
tokenizer (PreTrainedTokenizerFast)
messages (list): List of dicts containing 'role' and 'content'
Returns:
str: Model response
"""
# Ensure pad token exists
tokenizer.pad_token = "<|reserved_special_token_5|>"
# Create chat prompt
input_text = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=False
)
# Tokenize input
inputs = tokenizer(
input_text,
max_length=500,
truncation=True,
return_tensors="pt"
).to(model.device)
# Store the number of input tokens for reference
input_token_length = inputs.input_ids.shape[1]
generation_params = {
"do_sample": do_sample,
"temperature": temperature if do_sample else None,
"top_k": top_k if do_sample else None,
"top_p": top_p if do_sample else None,
"num_beams": num_beams if not do_sample else 1,
"max_new_tokens": max_new_tokens
}
output = model.generate(**inputs, **generation_params)
# Extract only the newly generated tokens
new_tokens = output[0][input_token_length:]
# Decode only the new tokens
response = tokenizer.decode(new_tokens, skip_special_tokens=True)
if 'assistant' in response:
response = response.split('assistant')[1].strip()
# In case there's still any assistant prefix, clean it up
if response.startswith("assistant") or response.startswith("<assistant>"):
response = re.sub(r"^assistant[:\s]*|^<assistant>[\s]*", "", response, flags=re.IGNORECASE)
response = re.sub(r'^\s*(?:answer\s*)+:?\s*', '', response, flags=re.IGNORECASE)
return response
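
#-----------------------------------------
# Example Usage (sketch)
#-----------------------------------------
# Minimal usage sketch, assuming a Llama-3-style chat model, a CUDA GPU, and
# bitsandbytes installed. The model id below is a placeholder, not a path
# shipped with this repo.
if __name__ == "__main__":
    BASE_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder model id
    model, tokenizer = load_base_model(BASE_MODEL)
    messages = [
        {"role": "system", "content": "You are a helpful finite element method tutor."},
        {"role": "user", "content": "What is a shape function?"},
    ]
    print(generate_response(model, tokenizer, messages, max_new_tokens=200))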