import os
import re

import torch
import streamlit as st
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
from peft import PeftModel

# Set the cache directory to persistent storage
os.environ["HF_HOME"] = "/data/.cache/huggingface"


#-----------------------------------------
# Quantization Config
#-----------------------------------------
def get_bnb_config():
    """Returns the 4-bit NF4 quantization config shared by both model loaders."""
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.float16,
    )


#-----------------------------------------
# Base Model Loader
#-----------------------------------------
@st.cache_resource
def load_base_model(base_model_path: str):
    """
    Loads a base LLM with 4-bit quantization and its tokenizer.

    Args:
        base_model_path (str): HF model path

    Returns:
        model (AutoModelForCausalLM)
        tokenizer (PreTrainedTokenizerFast)
    """
    bnb_config = get_bnb_config()

    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path)

    model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16,
    )

    return model, tokenizer


#-----------------------------------------
# Fine-Tuned Model Loader
#-----------------------------------------
@st.cache_resource
def load_fine_tuned_model(adapter_path: str, base_model_path: str):
    """
    Loads the fine-tuned model by applying a LoRA adapter to a base model.

    Args:
        adapter_path (str): Local or HF adapter path
        base_model_path (str): Base LLM model path

    Returns:
        fine_tuned_model (PeftModel)
        tokenizer (PreTrainedTokenizerFast)
    """
    bnb_config = get_bnb_config()

    tokenizer = PreTrainedTokenizerFast.from_pretrained(base_model_path)

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16,
    )

    fine_tuned_model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        device_map="auto",
    )

    return fine_tuned_model, tokenizer


#-----------------------------------------
# Inference Function
#-----------------------------------------
@torch.no_grad()
def generate_response(
    model: AutoModelForCausalLM,
    tokenizer: PreTrainedTokenizerFast,
    messages: list,
    tokenizer_max_length: int = 500,
    do_sample: bool = False,
    temperature: float = 0.1,
    top_k: int = 50,
    top_p: float = 0.95,
    num_beams: int = 1,
    max_new_tokens: int = 700,
) -> str:
    """
    Runs inference on an LLM model.

    Args:
        model (AutoModelForCausalLM)
        tokenizer (PreTrainedTokenizerFast)
        messages (list): List of dicts containing 'role' and 'content'

    Returns:
        str: Model response
    """
    # Ensure a pad token exists
    tokenizer.pad_token = "<|reserved_special_token_5|>"

    # Create chat prompt
    input_text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    # Tokenize input
    inputs = tokenizer(
        input_text,
        max_length=tokenizer_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(model.device)

    # Store the number of input tokens for reference
    input_token_length = inputs.input_ids.shape[1]

    # Sampling parameters only apply when do_sample=True; beam search only when sampling is off
    generation_params = {
        "do_sample": do_sample,
        "temperature": temperature if do_sample else None,
        "top_k": top_k if do_sample else None,
        "top_p": top_p if do_sample else None,
        "num_beams": num_beams if not do_sample else 1,
        "max_new_tokens": max_new_tokens,
    }

    output = model.generate(**inputs, **generation_params)

    # Extract only the newly generated tokens
    new_tokens = output[0][input_token_length:]

    # Decode only the new tokens
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Drop everything up to and including a leftover 'assistant' header, if present
    if 'assistant' in response:
        response = response.split('assistant')[1].strip()

    # In case there's still an assistant prefix, leading whitespace, or an "answer:" label, clean it up
    response = re.sub(r"^assistant[:\s]*|^[\s]*", "", response, flags=re.IGNORECASE)
    response = re.sub(r'^\s*(?:answer\s*)+:?\s*', '', response, flags=re.IGNORECASE)

    return response
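

#-----------------------------------------
# Example Usage (illustrative sketch)
#-----------------------------------------
# A minimal sketch of how these helpers could be wired into a Streamlit page,
# run via `streamlit run <this_file>.py`. The base-model id and adapter path
# below are placeholders (assumptions, not part of this module); point them at
# the checkpoint and LoRA adapter used in your deployment.
if __name__ == "__main__":
    BASE_MODEL_PATH = "your-org/your-base-model"        # hypothetical HF model id
    ADAPTER_PATH = "/data/adapters/your-lora-adapter"   # hypothetical adapter path

    model, tokenizer = load_fine_tuned_model(ADAPTER_PATH, BASE_MODEL_PATH)

    st.title("Fine-Tuned LLM Demo")
    user_input = st.text_area("Enter your prompt")

    if st.button("Generate") and user_input:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_input},
        ]
        with st.spinner("Generating..."):
            answer = generate_response(model, tokenizer, messages)
        st.markdown(answer)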