import os
import re

import torch
import streamlit as st
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
from peft import PeftModel

# Set the cache directory to persistent storage
os.environ["HF_HOME"] = "/data/.cache/huggingface"
#-----------------------------------------
# Quantization Config
#-----------------------------------------
def get_bnb_config():
    """Returns the 4-bit NF4 quantization config shared by both model loaders."""
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.float16,
    )

#-----------------------------------------
# Base Model Loader
#-----------------------------------------
@st.cache_resource
def load_base_model(base_model_path: str):
    """
    Loads a base LLM with 4-bit quantization and its tokenizer.

    Args:
        base_model_path (str): Hugging Face model path.

    Returns:
        model (AutoModelForCausalLM)
        tokenizer (PreTrainedTokenizerFast)
    """
    bnb_config = get_bnb_config()
    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path)
    model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16,
    )
    return model, tokenizer

#-----------------------------------------
# Fine-Tuned Model Loader
#-----------------------------------------
@st.cache_resource
def load_fine_tuned_model(adapter_path: str, base_model_path: str):
    """
    Loads the fine-tuned model by applying a LoRA adapter to a base model.

    Args:
        adapter_path (str): Local or Hugging Face adapter path.
        base_model_path (str): Base LLM model path.

    Returns:
        fine_tuned_model (PeftModel)
        tokenizer (PreTrainedTokenizerFast)
    """
    bnb_config = get_bnb_config()
    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16,
    )
    fine_tuned_model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        device_map="auto",
    )
    return fine_tuned_model, tokenizer

#-----------------------------------------
# Inference Function
#-----------------------------------------
@torch.no_grad()
def generate_response(
    model: AutoModelForCausalLM,
    tokenizer: PreTrainedTokenizerFast,
    messages: list,
    tokenizer_max_length: int = 500,
    do_sample: bool = False,
    temperature: float = 0.1,
    top_k: int = 50,
    top_p: float = 0.95,
    num_beams: int = 1,
    max_new_tokens: int = 700
) -> str:
    """
    Runs inference on an LLM.

    Args:
        model (AutoModelForCausalLM)
        tokenizer (PreTrainedTokenizerFast)
        messages (list): List of dicts containing 'role' and 'content'.

    Returns:
        str: Model response.
    """
    # Ensure a pad token exists (this reserved token is specific to the Llama 3 vocabulary)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = "<|reserved_special_token_5|>"

    # Build the chat prompt from the message list
    input_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False
    )

    # Tokenize the prompt
    inputs = tokenizer(
        input_text,
        max_length=tokenizer_max_length,
        truncation=True,
        return_tensors="pt"
    ).to(model.device)

    # Store the number of input tokens so the prompt can be stripped from the output
    input_token_length = inputs.input_ids.shape[1]

    generation_params = {
        "do_sample": do_sample,
        "temperature": temperature if do_sample else None,
        "top_k": top_k if do_sample else None,
        "top_p": top_p if do_sample else None,
        "num_beams": num_beams if not do_sample else 1,
        "max_new_tokens": max_new_tokens
    }
    # Drop unused sampling parameters so generate() does not warn about them
    generation_params = {k: v for k, v in generation_params.items() if v is not None}

    output = model.generate(**inputs, **generation_params)

    # Keep only the newly generated tokens and decode them
    new_tokens = output[0][input_token_length:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Strip any leftover role or "answer" prefixes from the decoded text
    if "assistant" in response:
        response = response.split("assistant")[1].strip()
    if response.startswith("assistant") or response.startswith("<assistant>"):
        response = re.sub(r"^assistant[:\s]*|^<assistant>[\s]*", "", response, flags=re.IGNORECASE)
    response = re.sub(r'^\s*(?:answer\s*)+:?\s*', '', response, flags=re.IGNORECASE)
    return response
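
#-----------------------------------------
# Example Usage (illustrative sketch)
#-----------------------------------------
# A minimal sketch of how these helpers could be wired into a Streamlit UI.
# The model and adapter identifiers below are placeholders, not the
# repositories actually used by this Space; replace them with real paths.
if __name__ == "__main__":
    BASE_MODEL_PATH = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder base model
    ADAPTER_PATH = "username/lora-adapter"                   # placeholder LoRA adapter

    model, tokenizer = load_fine_tuned_model(ADAPTER_PATH, BASE_MODEL_PATH)

    user_prompt = st.text_area("Prompt", "Explain what a LoRA adapter is.")
    if st.button("Generate"):
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_prompt},
        ]
        st.write(generate_response(model, tokenizer, messages))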