# Gradio-Lora / app.py
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig # Example for LLM
# from diffusers import StableDiffusionPipeline # Example for Diffusion
from peft import PeftModel
import accelerate # Often needed for device_map='auto'
import os
import time # For basic timing/feedback
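
# The commented diffusers import above hints at a diffusion variant of this Space.
# A minimal sketch of that path (assumption, not part of this app): diffusers
# pipelines can attach a LoRA adapter directly via load_lora_weights().
# pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
# pipe.load_lora_weights(lora_model_id)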
# --- Global Placeholder (Alternative: Use gr.State for cleaner state management) ---
# We will use gr.State in the Blocks interface, which is generally preferred.
# loaded_model = None
# loaded_tokenizer = None
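
# Deployment assumption (not shown in this snapshot): the Space's requirements.txt
# provides gradio, torch, transformers, peft, and accelerate (plus bitsandbytes if
# quantization is enabled below).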

# --- Model Loading Function ---
def load_models(base_model_id, lora_model_id, progress=gr.Progress(track_tqdm=True)):
    """Loads the base model and applies the LoRA adapter."""
    # global loaded_model, loaded_tokenizer  # Only needed if not using gr.State
    model = None
    tokenizer = None
    status = "Starting model loading..."
    progress(0, desc=status)
    print(status)

    if not base_model_id or not lora_model_id:
        return None, None, "Error: Base Model ID and LoRA Model ID cannot be empty."
    try:
        # --- Load Base Model Tokenizer (for LLMs) ---
        status = f"Loading tokenizer for {base_model_id}..."
        progress(0.1, desc=status)
        print(status)
        tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
        if tokenizer.pad_token is None:
            print("Setting pad_token to eos_token")
            tokenizer.pad_token = tokenizer.eos_token

        # --- Load Base Model ---
        # Add quantization or other configs if needed
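        # Minimal 4-bit quantization sketch (assumption: bitsandbytes is available on
        # the Space). Uncomment and pass quantization_config=bnb_config to
        # from_pretrained() below to cut GPU memory use.
        # bnb_config = BitsAndBytesConfig(
        #     load_in_4bit=True,
        #     bnb_4bit_quant_type="nf4",
        #     bnb_4bit_compute_dtype=torch.bfloat16,
        # )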
status = f"Loading base model: {base_model_id}..."
progress(0.3, desc=status)
print(status)
base_model = AutoModelForCausalLM.from_pretrained(
base_model_id,
torch_dtype=torch.bfloat16, # Or float16
device_map="auto",
trust_remote_code=True
)
progress(0.7, desc="Base model loaded.")
print("Base model loaded.")
# --- Load LoRA Adapter ---
status = f"Loading LoRA adapter: {lora_model_id}..."
progress(0.8, desc=status)
print(status)
model = PeftModel.from_pretrained(
base_model,
lora_model_id,
)
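        # Optional (assumption, not in the original flow): fold the LoRA weights into
        # the base model for slightly faster inference once loaded.
        # model = model.merge_and_unload()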
        progress(0.95, desc="LoRA adapter applied.")
        print("PEFT LoRA model loaded.")
        model.eval()  # Set model to evaluation mode

        status = "Models loaded successfully!"
        progress(1.0, desc=status)
        print(status)

        # Return the loaded model and tokenizer to be stored in gr.State
        return model, tokenizer, status

    except Exception as e:
        error_msg = f"Error loading models: {str(e)}"
        print(error_msg)
        # Ensure we return None for model/tokenizer on error
        return None, None, error_msg


# --- Inference Function ---
def generate_text(
    state_model, state_tokenizer,  # Receive model/tokenizer from gr.State
    prompt, max_new_tokens, temperature,
    progress=gr.Progress(track_tqdm=True)
):
    """Generates text using the loaded model."""
    if state_model is None or state_tokenizer is None:
        return "Error: Models not loaded. Please load models first."

    status = "Tokenizing prompt..."
    progress(0.1, desc=status)
    print(status)

    try:
        inputs = state_tokenizer(prompt, return_tensors="pt").to(state_model.device)

        status = "Generating text..."
        progress(0.3, desc=status)
        print(status)
        with torch.no_grad():
            outputs = state_model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),  # Ensure it's an int
                temperature=float(temperature),      # Ensure it's a float
                do_sample=True,  # Sampling must be enabled for temperature to take effect
                pad_token_id=state_tokenizer.pad_token_id
                # Add other parameters like top_k, top_p if desired
            )
status = "Decoding output..."
progress(0.9, desc=status)
print(status)
result = state_tokenizer.decode(outputs[0], skip_special_tokens=True)
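        # Alternative sketch (assumption): decode only the newly generated tokens so
        # the prompt is not echoed back in the output box.
        # new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        # result = state_tokenizer.decode(new_tokens, skip_special_tokens=True)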
        progress(1.0, desc="Generation complete.")
        print("Generation complete.")
        return result

    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        print(error_msg)
        return error_msg


# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Using gr.State to hold the loaded model and tokenizer objects.
    # This state persists within the user's session.
    model_state = gr.State(None)
    tokenizer_state = gr.State(None)

    gr.Markdown("# 🎛️ Dynamic LoRA Model Loader & Generator (Gradio)")
    gr.Markdown(
        "Enter the Hugging Face IDs for the base model and your LoRA adapter repository. "
        "Then load the models and generate text."
        "\n**Note:** Ensure the LoRA repo contains a standard PEFT adapter "
        "(`adapter_config.json` plus `adapter_model.safetensors`) and that your Space "
        "has adequate hardware (GPU recommended)."
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Configuration")
            base_model_input = gr.Textbox(
                label="Base Model ID (Hugging Face)",
                placeholder="e.g., meta-llama/Meta-Llama-3-8B",
                value="meta-llama/Meta-Llama-3-8B"  # Example default
            )
            lora_model_input = gr.Textbox(
                label="LoRA Model ID (Hugging Face repo containing the PEFT adapter)",
                placeholder="e.g., YourUsername/YourLoraRepo"
            )
            load_button = gr.Button("Load Models", variant="primary")
            status_output = gr.Textbox(label="Loading Status", interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("## Inference")
            prompt_input = gr.Textbox(label="Enter Prompt:", lines=5, placeholder="Once upon a time...")
            with gr.Row():
                max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=10, maximum=1024, value=200, step=10)
                temp_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, value=0.7, step=0.05)
            generate_button = gr.Button("Generate Text", variant="primary")
            generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)

    # --- Connect Actions ---
    load_button.click(
        fn=load_models,
        inputs=[base_model_input, lora_model_input],
        # Outputs: model state, tokenizer state, status message textbox
        outputs=[model_state, tokenizer_state, status_output],
        show_progress="full"  # Show progress bar
    )
    generate_button.click(
        fn=generate_text,
        # Inputs: model state, tokenizer state, prompt, sliders
        inputs=[model_state, tokenizer_state, prompt_input, max_tokens_slider, temp_slider],
        outputs=[generated_output],  # Output: generated text textbox
        show_progress="full"  # Show progress bar
    )

# --- Launch the Gradio App ---
# HF Spaces automatically runs this when deploying app.py
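# Optional (assumption): queue long-running load/generate requests so concurrent
# users do not time out, e.g. demo.queue(max_size=8) before launch().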
if __name__ == "__main__":
    demo.launch()  # Use share=True for a public link if running locally