import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig # Example for LLM
# from diffusers import StableDiffusionPipeline # Example for Diffusion
from peft import PeftModel
import accelerate # Must be installed for device_map="auto"; the explicit import is optional
import os
import time # For basic timing/feedback
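# The imports above assume the Space's requirements.txt provides roughly:
# gradio, torch, transformers, peft, accelerate (and bitsandbytes if quantizing).
# Exact version pins are not specified here; adjust to your environment.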

# --- Global Placeholder (Alternative: Use gr.State for cleaner state management) ---
# We will use gr.State in the Blocks interface, which is generally preferred.
# loaded_model = None
# loaded_tokenizer = None

# --- Model Loading Function ---
def load_models(base_model_id, lora_model_id, progress=gr.Progress(track_tqdm=True)):
    """Loads the base model and applies the LoRA adapter."""
    # No module-level globals needed here; the model/tokenizer are returned into gr.State.
    model = None
    tokenizer = None
    status = "Starting model loading..."
    progress(0, desc=status)
    print(status)

    if not base_model_id or not lora_model_id:
        return None, None, "Error: Base Model ID and LoRA Model ID cannot be empty."

    try:
        # --- Load Base Model Tokenizer (for LLMs) ---
        status = f"Loading tokenizer for {base_model_id}..."
        progress(0.1, desc=status)
        print(status)
        tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
        if tokenizer.pad_token is None:
            print("Setting pad_token to eos_token")
            tokenizer.pad_token = tokenizer.eos_token

        # --- Load Base Model ---
        # Add quantization or other configs if needed
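        # A minimal 4-bit quantization sketch (assumes bitsandbytes is installed and a
        # CUDA GPU is available): build a config here and pass it to from_pretrained
        # below as quantization_config=bnb_config.
        # bnb_config = BitsAndBytesConfig(
        #     load_in_4bit=True,
        #     bnb_4bit_quant_type="nf4",
        #     bnb_4bit_compute_dtype=torch.bfloat16,
        # )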
        status = f"Loading base model: {base_model_id}..."
        progress(0.3, desc=status)
        print(status)
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16, # Or float16
            device_map="auto",
            trust_remote_code=True
        )
        progress(0.7, desc="Base model loaded.")
        print("Base model loaded.")

        # --- Load LoRA Adapter ---
        status = f"Loading LoRA adapter: {lora_model_id}..."
        progress(0.8, desc=status)
        print(status)
        model = PeftModel.from_pretrained(
            base_model,
            lora_model_id,
        )
        progress(0.95, desc="LoRA adapter applied.")
        print("PEFT LoRA model loaded.")
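        # Optional: for inference-only use with full-precision base weights, the adapter
        # can be merged into the base model to reduce per-forward overhead (skip this if
        # the base model was loaded quantized):
        # model = model.merge_and_unload()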

        model.eval() # Set model to evaluation mode
        status = "Models loaded successfully!"
        progress(1.0, desc=status)
        print(status)
        # Return the loaded model and tokenizer to be stored in gr.State
        return model, tokenizer, status

    except Exception as e:
        error_msg = f"Error loading models: {str(e)}"
        print(error_msg)
        # Ensure we return None for model/tokenizer on error
        return None, None, error_msg

# --- Inference Function ---
def generate_text(
    state_model, state_tokenizer, # Receive model/tokenizer from gr.State
    prompt, max_new_tokens, temperature,
    progress=gr.Progress(track_tqdm=True)
):
    """Generates text using the loaded model."""
    if state_model is None or state_tokenizer is None:
        return "Error: Models not loaded. Please load models first."

    status = "Tokenizing prompt..."
    progress(0.1, desc=status)
    print(status)
    try:
        inputs = state_tokenizer(prompt, return_tensors="pt").to(state_model.device)

        status = "Generating text..."
        progress(0.3, desc=status)
        print(status)
        with torch.no_grad():
            outputs = state_model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens), # Ensure it's int
                do_sample=True,                   # Enable sampling so temperature takes effect
                temperature=float(temperature),   # Ensure it's float
                pad_token_id=state_tokenizer.pad_token_id
                # Add other parameters like top_k, top_p if desired
            )

        status = "Decoding output..."
        progress(0.9, desc=status)
        print(status)
        result = state_tokenizer.decode(outputs[0], skip_special_tokens=True)
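        # Note: the decoded result includes the prompt itself. To return only the newly
        # generated tokens, one could slice before decoding, e.g.:
        # new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        # result = state_tokenizer.decode(new_tokens, skip_special_tokens=True)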
        progress(1.0, desc="Generation complete.")
        print("Generation complete.")
        return result

    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        print(error_msg)
        return error_msg
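# Streaming variant (sketch, not wired into the UI below): transformers'
# TextIteratorStreamer can be combined with a background thread so the output
# textbox fills in incrementally. Roughly:
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(state_tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=state_model.generate, kwargs=dict(**inputs, streamer=streamer,
#          max_new_tokens=int(max_new_tokens))).start()
#   partial = ""
#   for chunk in streamer:
#       partial += chunk
#       yield partial
#
# The click handler would then need to be a generator function.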


# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Using gr.State to hold the loaded model and tokenizer objects
    # This state persists within the user's session
    model_state = gr.State(None)
    tokenizer_state = gr.State(None)

    gr.Markdown("# 🎛️ Dynamic LoRA Model Loader & Generator (Gradio)")
    gr.Markdown(
        "Enter the Hugging Face IDs for the base model and your LoRA adapter repository. "
        "Then load the models and generate text."
        "\n**Note:** The LoRA repository should contain a standard PEFT adapter "
        "(e.g., `adapter_model.safetensors` plus `adapter_config.json`), and the Space needs "
        "adequate hardware (a GPU is recommended)."
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Configuration")
            base_model_input = gr.Textbox(
                label="Base Model ID (Hugging Face)",
                placeholder="e.g., meta-llama/Meta-Llama-3-8B",
                value="meta-llama/Meta-Llama-3-8B" # Example default
            )
            lora_model_input = gr.Textbox(
                label="LoRA Adapter ID (Hugging Face repo containing the adapter files)",
                placeholder="e.g., YourUsername/YourLoraRepo"
            )
            load_button = gr.Button("Load Models", variant="primary")
            status_output = gr.Textbox(label="Loading Status", interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("## Inference")
            prompt_input = gr.Textbox(label="Enter Prompt:", lines=5, placeholder="Once upon a time...")
            with gr.Row():
                max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=10, maximum=1024, value=200, step=10)
                temp_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, value=0.7, step=0.05)
            generate_button = gr.Button("Generate Text", variant="primary")
            generated_output = gr.Textbox(label="Generated Output", lines=10, interactive=False)

    # --- Connect Actions ---
    load_button.click(
        fn=load_models,
        inputs=[base_model_input, lora_model_input],
        # Outputs: model state, tokenizer state, status message textbox
        outputs=[model_state, tokenizer_state, status_output],
        show_progress="full" # Show progress bar
    )

    generate_button.click(
        fn=generate_text,
        # Inputs: model state, tokenizer state, prompt, sliders
        inputs=[model_state, tokenizer_state, prompt_input, max_tokens_slider, temp_slider],
        outputs=[generated_output], # Output: generated text textbox
        show_progress="full" # Show progress bar
    )

# --- Launch the Gradio App ---
# HF Spaces automatically runs this when deploying app.py
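# For long-running loads/generations, enabling Gradio's request queue is often
# useful (optional): demo.queue().launch()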
if __name__ == "__main__":
    demo.launch() # Use share=True for a public link if running locally