import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig  # Example for LLM
# from diffusers import StableDiffusionPipeline  # Example for Diffusion
from peft import PeftModel
import accelerate  # Often needed for device_map='auto'
import os
import time  # For basic timing/feedback
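# Assumption: when deploying on HF Spaces, these packages must be listed in
# requirements.txt (e.g. gradio, torch, transformers, peft, accelerate, and
# bitsandbytes if you enable quantization below) or the Space will fail to start.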
# --- Global Placeholder (Alternative: Use gr.State for cleaner state management) ---
# We will use gr.State in the Blocks interface, which is generally preferred.
# loaded_model = None
# loaded_tokenizer = None

# --- Model Loading Function ---
def load_models(base_model_id, lora_model_id, progress=gr.Progress(track_tqdm=True)):
    """Loads the base model and applies the LoRA adapter."""
    # No `global` statement is needed: the model and tokenizer are returned
    # and stored in gr.State instead.
    model = None
    tokenizer = None
    status = "Starting model loading..."
    progress(0, desc=status)
    print(status)

    if not base_model_id or not lora_model_id:
        return None, None, "Error: Base Model ID and LoRA Model ID cannot be empty."
    try:
        # --- Load Base Model Tokenizer (for LLMs) ---
        status = f"Loading tokenizer for {base_model_id}..."
        progress(0.1, desc=status)
        print(status)
        tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
        if tokenizer.pad_token is None:
            print("Setting pad_token to eos_token")
            tokenizer.pad_token = tokenizer.eos_token

        # --- Load Base Model ---
        # Add quantization or other configs if needed.
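        # Example (assumption: bitsandbytes is installed and the Space has a GPU):
        # 4-bit quantization cuts memory use substantially; pass
        # quantization_config=bnb_config to from_pretrained() below.
        # bnb_config = BitsAndBytesConfig(
        #     load_in_4bit=True,
        #     bnb_4bit_compute_dtype=torch.bfloat16,
        # )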
status = f"Loading base model: {base_model_id}..." | |
progress(0.3, desc=status) | |
print(status) | |
base_model = AutoModelForCausalLM.from_pretrained( | |
base_model_id, | |
torch_dtype=torch.bfloat16, # Or float16 | |
device_map="auto", | |
trust_remote_code=True | |
) | |
progress(0.7, desc="Base model loaded.") | |
print("Base model loaded.") | |
        # --- Load LoRA Adapter ---
        status = f"Loading LoRA adapter: {lora_model_id}..."
        progress(0.8, desc=status)
        print(status)
        model = PeftModel.from_pretrained(
            base_model,
            lora_model_id,
        )
        progress(0.95, desc="LoRA adapter applied.")
        print("PEFT LoRA model loaded.")
        model.eval()  # Set model to evaluation mode
        status = "Models loaded successfully!"
        progress(1.0, desc=status)
        print(status)

        # Return the loaded model and tokenizer to be stored in gr.State
        return model, tokenizer, status

    except Exception as e:
        error_msg = f"Error loading models: {str(e)}"
        print(error_msg)
        # Ensure we return None for model/tokenizer on error
        return None, None, error_msg

# --- Inference Function ---
def generate_text(
    state_model, state_tokenizer,  # Receive model/tokenizer from gr.State
    prompt, max_new_tokens, temperature,
    progress=gr.Progress(track_tqdm=True)
):
    """Generates text using the loaded model."""
    if state_model is None or state_tokenizer is None:
        return "Error: Models not loaded. Please load models first."

    status = "Tokenizing prompt..."
    progress(0.1, desc=status)
    print(status)
    try:
        inputs = state_tokenizer(prompt, return_tensors="pt").to(state_model.device)
        status = "Generating text..."
        progress(0.3, desc=status)
        print(status)
        with torch.no_grad():
            outputs = state_model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),  # Ensure it's an int
                temperature=float(temperature),  # Ensure it's a float
                do_sample=True,  # Sampling must be enabled for temperature to take effect
                pad_token_id=state_tokenizer.pad_token_id,
                # Add other parameters like top_k, top_p if desired
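                # (assumed example values, tune per model):
                # top_p=0.9,
                # top_k=50,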
            )
        status = "Decoding output..."
        progress(0.9, desc=status)
        print(status)
        result = state_tokenizer.decode(outputs[0], skip_special_tokens=True)
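        # Note: outputs[0] includes the prompt tokens, so the prompt is echoed
        # back in the result. To return only the newly generated text, decode
        # outputs[0][inputs["input_ids"].shape[1]:] instead.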
progress(1.0, desc="Generation complete.") | |
print("Generation complete.") | |
return result | |
except Exception as e: | |
error_msg = f"Error during generation: {str(e)}" | |
print(error_msg) | |
return error_msg | |
# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Using gr.State to hold the loaded model and tokenizer objects.
    # This state persists within the user's session.
    model_state = gr.State(None)
    tokenizer_state = gr.State(None)

    gr.Markdown("# 🎛️ Dynamic LoRA Model Loader & Generator (Gradio)")
    gr.Markdown(
        "Enter the Hugging Face IDs for the base model and your LoRA adapter repository. "
        "Then load the models and generate text."
        "\n**Note:** Ensure your LoRA adapter is named appropriately (e.g., `adapter_model.safetensors`, "
        "or pass the filename explicitly if the loader supports it) and that your Space has "
        "adequate hardware (a GPU is recommended)."
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Configuration")
            base_model_input = gr.Textbox(
                label="Base Model ID (Hugging Face)",
                placeholder="e.g., meta-llama/Meta-Llama-3-8B",
                value="meta-llama/Meta-Llama-3-8B"  # Example default
            )
            lora_model_input = gr.Textbox(
                label="LoRA Model ID (Hugging Face repo containing the adapter)",
                placeholder="e.g., YourUsername/YourLoraRepo"
            )
            load_button = gr.Button("Load Models", variant="primary")
            status_output = gr.Textbox(label="Loading Status", interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("## Inference")
            prompt_input = gr.Textbox(label="Enter Prompt:", lines=5, placeholder="Once upon a time...")
            with gr.Row():
                max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=10, maximum=1024, value=200, step=10)
                temp_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, value=0.7, step=0.05)
            generate_button = gr.Button("Generate Text", variant="primary")
            generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)
    # --- Connect Actions ---
    load_button.click(
        fn=load_models,
        inputs=[base_model_input, lora_model_input],
        # Outputs: model state, tokenizer state, status message textbox
        outputs=[model_state, tokenizer_state, status_output],
        show_progress="full"  # Show progress bar
    )
    generate_button.click(
        fn=generate_text,
        # Inputs: model state, tokenizer state, prompt, sliders
        inputs=[model_state, tokenizer_state, prompt_input, max_tokens_slider, temp_slider],
        outputs=[generated_output],  # Output: generated text textbox
        show_progress="full"  # Show progress bar
    )

# --- Launch the Gradio App ---
# HF Spaces automatically runs this when deploying app.py
if __name__ == "__main__":
    demo.launch()  # Use share=True for a public link if running locally