import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig  # Example for LLM
# from diffusers import StableDiffusionPipeline  # Example for Diffusion
from peft import PeftModel
import accelerate  # Only needs to be installed for device_map="auto"; the import itself is optional
import os
import time  # For basic timing/feedback

# --- Global Placeholder (Alternative: Use gr.State for cleaner state management) ---
# We will use gr.State in the Blocks interface, which is generally preferred.
# loaded_model = None
# loaded_tokenizer = None


# --- Model Loading Function ---
def load_models(base_model_id, lora_model_id, progress=gr.Progress(track_tqdm=True)):
    """Loads the base model and applies the LoRA adapter."""
    model = None
    tokenizer = None
    status = "Starting model loading..."
    progress(0, desc=status)
    print(status)

    if not base_model_id or not lora_model_id:
        return None, None, "Error: Base Model ID and LoRA Model ID cannot be empty."

    try:
        # --- Load Base Model Tokenizer (for LLMs) ---
        status = f"Loading tokenizer for {base_model_id}..."
        progress(0.1, desc=status)
        print(status)
        tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
        if tokenizer.pad_token is None:
            print("Setting pad_token to eos_token")
            tokenizer.pad_token = tokenizer.eos_token

        # --- Load Base Model ---
        # Add quantization or other configs if needed
        status = f"Loading base model: {base_model_id}..."
        progress(0.3, desc=status)
        print(status)
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16,  # Or float16
            device_map="auto",
            trust_remote_code=True,
        )
        progress(0.7, desc="Base model loaded.")
        print("Base model loaded.")

        # --- Load LoRA Adapter ---
        status = f"Loading LoRA adapter: {lora_model_id}..."
        progress(0.8, desc=status)
        print(status)
        model = PeftModel.from_pretrained(
            base_model,
            lora_model_id,
        )
        progress(0.95, desc="LoRA adapter applied.")
        print("PEFT LoRA model loaded.")

        model.eval()  # Set model to evaluation mode

        status = "Models loaded successfully!"
        progress(1.0, desc=status)
        print(status)

        # Return the loaded model and tokenizer to be stored in gr.State
        return model, tokenizer, status

    except Exception as e:
        error_msg = f"Error loading models: {str(e)}"
        print(error_msg)
        # Ensure we return None for model/tokenizer on error
        return None, None, error_msg


# --- Inference Function ---
def generate_text(
    state_model,
    state_tokenizer,  # Receive model/tokenizer from gr.State
    prompt,
    max_new_tokens,
    temperature,
    progress=gr.Progress(track_tqdm=True),
):
    """Generates text using the loaded model."""
    if state_model is None or state_tokenizer is None:
        return "Error: Models not loaded. Please load models first."

    status = "Tokenizing prompt..."
    progress(0.1, desc=status)
    print(status)

    try:
        inputs = state_tokenizer(prompt, return_tensors="pt").to(state_model.device)

        status = "Generating text..."
        progress(0.3, desc=status)
        print(status)
        with torch.no_grad():
            outputs = state_model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),  # Ensure it's int
                temperature=float(temperature),  # Ensure it's float
                do_sample=True,  # Enable sampling so the temperature setting actually takes effect
                pad_token_id=state_tokenizer.pad_token_id,
                # Add other parameters like top_k, top_p if desired
            )

        status = "Decoding output..."
        progress(0.9, desc=status)
        print(status)
        result = state_tokenizer.decode(outputs[0], skip_special_tokens=True)
        progress(1.0, desc="Generation complete.")
        print("Generation complete.")
        return result

    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        print(error_msg)
        return error_msg


# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Using gr.State to hold the loaded model and tokenizer objects.
    # This state persists within the user's session.
    model_state = gr.State(None)
    tokenizer_state = gr.State(None)

    gr.Markdown("# 🎛️ Dynamic LoRA Model Loader & Generator (Gradio)")
    gr.Markdown(
        "Enter the Hugging Face IDs for the base model and your LoRA adapter repository. "
        "Then, load the models and generate text."
        "\n**Note:** Ensure your LoRA file is named appropriately (e.g., `adapter_model.safetensors`, or specify the filename if the loader supports it) and your Space has adequate hardware (GPU recommended)."
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Configuration")
            base_model_input = gr.Textbox(
                label="Base Model ID (Hugging Face)",
                placeholder="e.g., meta-llama/Meta-Llama-3-8B",
                value="meta-llama/Meta-Llama-3-8B",  # Example default
            )
            lora_model_input = gr.Textbox(
                label="LoRA Adapter ID (Hugging Face repo containing the adapter weights)",
                placeholder="e.g., YourUsername/YourLoraRepo",
            )
            load_button = gr.Button("Load Models", variant="primary")
            status_output = gr.Textbox(label="Loading Status", interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("## Inference")
            prompt_input = gr.Textbox(label="Enter Prompt:", lines=5, placeholder="Once upon a time...")
            with gr.Row():
                max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=10, maximum=1024, value=200, step=10)
                temp_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, value=0.7, step=0.05)
            generate_button = gr.Button("Generate Text", variant="primary")
            generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)

    # --- Connect Actions ---
    load_button.click(
        fn=load_models,
        inputs=[base_model_input, lora_model_input],
        # Outputs: model state, tokenizer state, status message textbox
        outputs=[model_state, tokenizer_state, status_output],
        show_progress="full",  # Show progress bar
    )

    generate_button.click(
        fn=generate_text,
        # Inputs: model state, tokenizer state, prompt, sliders
        inputs=[model_state, tokenizer_state, prompt_input, max_tokens_slider, temp_slider],
        outputs=[generated_output],  # Output: generated text textbox
        show_progress="full",  # Show progress bar
    )

# --- Launch the Gradio App ---
# HF Spaces automatically runs this when deploying app.py
if __name__ == "__main__":
    demo.launch()  # Use share=True for a public link if running locally
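

# --- Optional: 4-bit quantized loading (sketch, not wired into the UI above) ---
# BitsAndBytesConfig is imported above but unused in the default bf16 path, per the
# "Add quantization or other configs if needed" note in load_models. The helper below
# is a minimal sketch of one way to load the base model in 4-bit NF4 on a
# memory-constrained Space GPU; it assumes the `bitsandbytes` package is installed,
# and the function name is purely illustrative.
def load_base_model_4bit(base_model_id):
    """Sketch: load the base model with 4-bit NF4 quantization instead of bf16."""
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=quant_config,  # Replaces torch_dtype in the bf16 path
        device_map="auto",
        trust_remote_code=True,
    )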