import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig # Example for LLM
# from diffusers import StableDiffusionPipeline # Example for Diffusion
from peft import PeftModel
import accelerate  # The accelerate package must be installed for device_map="auto"; the import itself is optional
import os
import time # For basic timing/feedback
# --- Global Placeholder (Alternative: Use gr.State for cleaner state management) ---
# We will use gr.State in the Blocks interface, which is generally preferred.
# loaded_model = None
# loaded_tokenizer = None
# --- Model Loading Function ---
def load_models(base_model_id, lora_model_id, progress=gr.Progress(track_tqdm=True)):
    """Loads the base model and applies the LoRA adapter."""
    # global loaded_model, loaded_tokenizer  # Only needed if not using gr.State
    model = None
    tokenizer = None
    status = "Starting model loading..."
    progress(0, desc=status)
    print(status)
    if not base_model_id or not lora_model_id:
        return None, None, "Error: Base Model ID and LoRA Model ID cannot be empty."
    try:
        # --- Load Base Model Tokenizer (for LLMs) ---
        status = f"Loading tokenizer for {base_model_id}..."
        progress(0.1, desc=status)
        print(status)
        tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
        if tokenizer.pad_token is None:
            print("Setting pad_token to eos_token")
            tokenizer.pad_token = tokenizer.eos_token
        # --- Load Base Model ---
        # Add quantization or other configs if needed (see the commented sketch below)
        status = f"Loading base model: {base_model_id}..."
        progress(0.3, desc=status)
        print(status)
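        # Optional 4-bit quantization (a hedged sketch, assuming the `bitsandbytes`
        # package is installed in the Space). To use it, uncomment the config below
        # and pass `quantization_config=quant_config` to the from_pretrained call
        # that follows:
        # quant_config = BitsAndBytesConfig(
        #     load_in_4bit=True,
        #     bnb_4bit_quant_type="nf4",
        #     bnb_4bit_compute_dtype=torch.bfloat16,
        # )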
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16,  # Or float16
            device_map="auto",
            trust_remote_code=True
        )
        progress(0.7, desc="Base model loaded.")
        print("Base model loaded.")
        # --- Load LoRA Adapter ---
        status = f"Loading LoRA adapter: {lora_model_id}..."
        progress(0.8, desc=status)
        print(status)
        model = PeftModel.from_pretrained(
            base_model,
            lora_model_id,
        )
        progress(0.95, desc="LoRA adapter applied.")
        print("PEFT LoRA model loaded.")
        model.eval()  # Set model to evaluation mode
        status = "Models loaded successfully!"
        progress(1.0, desc=status)
        print(status)
        # Return the loaded model and tokenizer to be stored in gr.State
        return model, tokenizer, status
    except Exception as e:
        error_msg = f"Error loading models: {str(e)}"
        print(error_msg)
        # Ensure we return None for model/tokenizer on error
        return None, None, error_msg

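# For diffusion models instead of LLMs, the same load-then-apply-LoRA pattern could
# look roughly like the commented sketch below. This is an assumption based on the
# diffusers LoRA-loading API (it matches the StableDiffusionPipeline import commented
# out at the top of this file), not something the original app wires up:
# def load_diffusion_pipeline(base_model_id, lora_model_id):
#     pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
#     pipe.load_lora_weights(lora_model_id)
#     return pipe.to("cuda")
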
# --- Inference Function ---
def generate_text(
    state_model, state_tokenizer,  # Receive model/tokenizer from gr.State
    prompt, max_new_tokens, temperature,
    progress=gr.Progress(track_tqdm=True)
):
    """Generates text using the loaded model."""
    if state_model is None or state_tokenizer is None:
        return "Error: Models not loaded. Please load models first."
    status = "Tokenizing prompt..."
    progress(0.1, desc=status)
    print(status)
    try:
        inputs = state_tokenizer(prompt, return_tensors="pt").to(state_model.device)
        status = "Generating text..."
        progress(0.3, desc=status)
        print(status)
        with torch.no_grad():
            outputs = state_model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),  # Ensure it's int
                temperature=float(temperature),      # Ensure it's float
                do_sample=True,  # Temperature only takes effect when sampling is enabled
                pad_token_id=state_tokenizer.pad_token_id
                # Add other parameters like top_k, top_p if desired
            )
status = "Decoding output..."
progress(0.9, desc=status)
print(status)
result = state_tokenizer.decode(outputs[0], skip_special_tokens=True)
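        # Note: the decode above returns the prompt followed by the continuation. A
        # commented sketch (an assumption, not part of the original app) for returning
        # only the newly generated tokens would slice off the prompt length first:
        # new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        # result = state_tokenizer.decode(new_tokens, skip_special_tokens=True)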
        progress(1.0, desc="Generation complete.")
        print("Generation complete.")
        return result
    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        print(error_msg)
        return error_msg

# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Using gr.State to hold the loaded model and tokenizer objects.
    # This state persists within the user's session.
    model_state = gr.State(None)
    tokenizer_state = gr.State(None)
    gr.Markdown("# 🎛️ Dynamic LoRA Model Loader & Generator (Gradio)")
    gr.Markdown(
        "Enter the Hugging Face IDs for the base model and your LoRA adapter repository. "
        "Then load the models and generate text."
        "\n**Note:** Ensure the LoRA repo contains a standard PEFT adapter "
        "(`adapter_config.json` plus `adapter_model.safetensors`) and that your Space "
        "has adequate hardware (a GPU is strongly recommended)."
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Configuration")
            base_model_input = gr.Textbox(
                label="Base Model ID (Hugging Face)",
                placeholder="e.g., meta-llama/Meta-Llama-3-8B",
                value="meta-llama/Meta-Llama-3-8B"  # Example default (gated model; requires approved access and an HF token)
            )
            lora_model_input = gr.Textbox(
                label="LoRA Adapter ID (Hugging Face repo containing the adapter)",
                placeholder="e.g., YourUsername/YourLoraRepo"
            )
            load_button = gr.Button("Load Models", variant="primary")
            status_output = gr.Textbox(label="Loading Status", interactive=False)
        with gr.Column(scale=2):
            gr.Markdown("## Inference")
            prompt_input = gr.Textbox(label="Enter Prompt:", lines=5, placeholder="Once upon a time...")
            with gr.Row():
                max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=10, maximum=1024, value=200, step=10)
                temp_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, value=0.7, step=0.05)
            generate_button = gr.Button("Generate Text", variant="primary")
            generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)
    # --- Connect Actions ---
    load_button.click(
        fn=load_models,
        inputs=[base_model_input, lora_model_input],
        # Outputs: model state, tokenizer state, status message textbox
        outputs=[model_state, tokenizer_state, status_output],
        show_progress="full"  # Show progress bar
    )
    generate_button.click(
        fn=generate_text,
        # Inputs: model state, tokenizer state, prompt, sliders
        inputs=[model_state, tokenizer_state, prompt_input, max_tokens_slider, temp_slider],
        outputs=[generated_output],  # Output: generated text textbox
        show_progress="full"  # Show progress bar
    )

# --- Launch the Gradio App ---
# HF Spaces automatically runs this when deploying app.py
if __name__ == "__main__":
    demo.launch()  # Use share=True for a public link if running locally
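
# Programmatic usage (hedged sketch, not part of the app): the Space could be called
# with gradio_client roughly as below. The api_name values and the Space ID are
# assumptions; check the Space's "Use via API" page for the real endpoint names.
# from gradio_client import Client
# client = Client("YourUsername/YourSpaceName")  # hypothetical Space ID
# client.predict("meta-llama/Meta-Llama-3-8B", "YourUsername/YourLoraRepo", api_name="/load_models")
# print(client.predict("Once upon a time...", 200, 0.7, api_name="/generate_text"))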