import os

# Set the Hugging Face cache directory to persistent storage.
# This must be set before transformers/peft are imported, since the
# Hub cache location is resolved when those libraries are first imported.
os.environ["HF_HOME"] = "/data/.cache/huggingface"

import re

import torch
import streamlit as st
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
from peft import PeftModel
#-----------------------------------------
# Quantization Config
#-----------------------------------------
def get_bnb_config():
    """Returns a 4-bit NF4 bitsandbytes quantization config with double quantization."""
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.float16,
    )
#-----------------------------------------
# Base Model Loader
#-----------------------------------------
def load_base_model(base_model_path: str):
    """
    Loads a base LLM with 4-bit quantization and its tokenizer.

    Args:
        base_model_path (str): Hugging Face model path

    Returns:
        model (AutoModelForCausalLM)
        tokenizer (PreTrainedTokenizerFast)
    """
    bnb_config = get_bnb_config()
    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path)
    model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16,
    )
    return model, tokenizer
#-----------------------------------------
# Fine-Tuned Model Loader
#-----------------------------------------
def load_fine_tuned_model(adapter_path: str, base_model_path: str):
    """
    Loads the fine-tuned model by applying a LoRA adapter to the quantized base model.

    Args:
        adapter_path (str): Local or Hugging Face adapter path
        base_model_path (str): Base LLM model path

    Returns:
        fine_tuned_model (PeftModel)
        tokenizer (PreTrainedTokenizerFast)
    """
    bnb_config = get_bnb_config()
    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16,
    )
    fine_tuned_model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        device_map="auto",
    )
    return fine_tuned_model, tokenizer
#-----------------------------------------
# Inference Function
#-----------------------------------------
def generate_response(
    model: AutoModelForCausalLM,
    tokenizer: PreTrainedTokenizerFast,
    messages: list,
    tokenizer_max_length: int = 500,
    do_sample: bool = False,
    temperature: float = 0.1,
    top_k: int = 50,
    top_p: float = 0.95,
    num_beams: int = 1,
    max_new_tokens: int = 700
) -> str:
    """
    Runs inference on an LLM and returns only the newly generated text.

    Args:
        model (AutoModelForCausalLM): Base or fine-tuned model
        tokenizer (PreTrainedTokenizerFast): Matching tokenizer
        messages (list): List of dicts with 'role' and 'content' keys
        tokenizer_max_length (int): Max number of prompt tokens kept after truncation
        do_sample (bool): If True, sample with temperature/top_k/top_p; otherwise use greedy or beam decoding
        temperature (float): Sampling temperature (used only when do_sample=True)
        top_k (int): Top-k sampling cutoff (used only when do_sample=True)
        top_p (float): Nucleus sampling cutoff (used only when do_sample=True)
        num_beams (int): Number of beams for beam search (used only when do_sample=False)
        max_new_tokens (int): Maximum number of tokens to generate

    Returns:
        str: Model response
    """
    # Ensure a pad token exists (Llama 3 style reserved token)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = "<|reserved_special_token_5|>"

    # Build the chat prompt from the message list
    input_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False
    )

    # Tokenize the prompt and move it to the model's device
    inputs = tokenizer(
        input_text,
        max_length=tokenizer_max_length,
        truncation=True,
        return_tensors="pt"
    ).to(model.device)

    # Store the number of input tokens so the prompt can be stripped later
    input_token_length = inputs.input_ids.shape[1]

    # Only pass sampling parameters when sampling is enabled, so greedy/beam
    # decoding does not receive unused sampling arguments
    generation_params = {
        "do_sample": do_sample,
        "max_new_tokens": max_new_tokens,
        "num_beams": 1 if do_sample else num_beams,
    }
    if do_sample:
        generation_params.update(
            {"temperature": temperature, "top_k": top_k, "top_p": top_p}
        )
    output = model.generate(**inputs, **generation_params)

    # Keep and decode only the newly generated tokens (everything after the prompt)
    new_tokens = output[0][input_token_length:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Strip any leftover "assistant" header that survives decoding
    if 'assistant' in response:
        response = response.split('assistant')[1].strip()
    if response.startswith("assistant") or response.startswith("<assistant>"):
        response = re.sub(r"^assistant[:\s]*|^<assistant>[\s]*", "", response, flags=re.IGNORECASE)
    # Drop a leading "Answer:" prefix if the model emits one
    response = re.sub(r'^\s*(?:answer\s*)+:?\s*', '', response, flags=re.IGNORECASE)
    return response
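

#-----------------------------------------
# Example usage (illustrative sketch)
#-----------------------------------------
# Minimal sketch showing how the loaders and generate_response fit together.
# The model and adapter IDs below are placeholders, not part of the original
# Space; replace them with the actual Hub or local paths the app uses.
if __name__ == "__main__":
    BASE_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"  # assumed base model ID
    ADAPTER = "path/to/lora-adapter"                    # assumed LoRA adapter path

    model, tokenizer = load_fine_tuned_model(ADAPTER, BASE_MODEL)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize what 4-bit NF4 quantization does."},
    ]
    print(generate_response(model, tokenizer, messages, do_sample=True, temperature=0.7))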