"""phi2-grpo/app.py: Gradio demo comparing base Phi-2 with a GRPO + QLoRA fine-tuned variant."""
import gradio as gr
import torch
from rich.console import Console
from transformers import AutoModelForCausalLM, AutoTokenizer
# Initialize rich console for better logging
console = Console()
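# Pick the compute dtype up front: float16 halves memory on GPU, but half
# precision is slow (and some ops unsupported) on CPU, so fall back to float32.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32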
# Load the model and tokenizer with the same configuration as training
console.print("[bold green]Loading model and tokenizer...[/bold green]")
# Load model with memory optimizations
model = AutoModelForCausalLM.from_pretrained(
"./fine-tuned-model",
device_map="auto",
trust_remote_code=True,
    torch_dtype=dtype,  # float16 on GPU, float32 on CPU (selected above)
    low_cpu_mem_usage=True,  # Reduce peak CPU RAM while loading the checkpoint
)
tokenizer = AutoTokenizer.from_pretrained("./fine-tuned-model")
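# Phi-2 ships without a pad token, so reuse EOS for padding; left padding keeps
# prompts flush against the generated tokens, as decoder-only models expect.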
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
# Load base model for before/after comparison
console.print("[bold green]Loading base model for comparison...[/bold green]")
base_model = AutoModelForCausalLM.from_pretrained(
"microsoft/phi-2",
device_map="auto",
trust_remote_code=True,
    torch_dtype=dtype,  # float16 on GPU, float32 on CPU (selected above)
    low_cpu_mem_usage=True,  # Reduce peak CPU RAM while loading the checkpoint
)
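# The tokenizer loaded above is shared with the base model: both stem from
# microsoft/phi-2, so vocabularies and special tokens match.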
def generate_response(
prompt,
    max_length=128,  # Passed to generate() as max_new_tokens (matches training)
temperature=0.7,
top_p=0.9,
num_generations=2, # Match training num_generations
repetition_penalty=1.1,
do_sample=True,
    show_comparison=True,  # Toggle the base-vs-fine-tuned comparison
):
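    """Generate responses from the fine-tuned model (and optionally the base
    model for comparison); returns a Markdown-formatted string."""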
try:
        # device_map="auto" may shard weights; send inputs to the first parameter's device
device = next(model.parameters()).device
# Tokenize the input
inputs = tokenizer(prompt, return_tensors="pt", padding=True)
# Move inputs to the same device as the model
inputs = {k: v.to(device) for k, v in inputs.items()}
        # Generate response from the fine-tuned model
        with torch.no_grad():  # Disable gradient computation for inference
            # Greedy search rejects num_return_sequences > 1 and warns on
            # sampling knobs, so gate both on do_sample.
            sampling_kwargs = {"temperature": temperature, "top_p": top_p} if do_sample else {}
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                do_sample=do_sample,
                num_return_sequences=num_generations if do_sample else 1,
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                **sampling_kwargs,
            )
# Decode and return the responses
responses = []
for output in outputs:
response = tokenizer.decode(output, skip_special_tokens=True)
responses.append(response)
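        # Note: decode() returns the prompt plus the continuation; slicing
        # output[inputs["input_ids"].shape[1]:] before decoding would keep
        # only the newly generated tokens.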
fine_tuned_response = "\n\n---\n\n".join(responses)
if show_comparison:
            # Generate response from the base model for comparison
            with torch.no_grad():
                base_outputs = base_model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    do_sample=do_sample,
                    num_return_sequences=1,  # One response is enough for comparison
                    repetition_penalty=repetition_penalty,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    **sampling_kwargs,
                )
base_response = tokenizer.decode(base_outputs[0], skip_special_tokens=True)
return f"""
### Before Fine-tuning (Base Model)
{base_response}
### After Fine-tuning
{fine_tuned_response}
"""
else:
return fine_tuned_response
except Exception as e:
console.print(f"[bold red]Error during generation: {str(e)}[/bold red]")
return f"Error: {str(e)}"
# Create custom CSS for better UI
custom_css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.container {
max-width: 800px;
margin: auto;
padding: 20px;
}
.title {
text-align: center;
color: #2c3e50;
margin-bottom: 20px;
}
.description {
color: #34495e;
line-height: 1.6;
margin-bottom: 20px;
}
.comparison {
background-color: #f8f9fa;
padding: 15px;
border-radius: 8px;
margin: 10px 0;
}
.prompt-box {
background-color: #ffffff;
border: 2px solid #3498db;
border-radius: 8px;
padding: 15px;
margin-bottom: 20px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.prompt-box label {
font-size: 1.1em;
font-weight: bold;
color: #2c3e50;
margin-bottom: 10px;
display: block;
}
.prompt-box textarea {
width: 100%;
min-height: 100px;
padding: 10px;
border: 1px solid #bdc3c7;
border-radius: 4px;
font-size: 1em;
line-height: 1.5;
}
.output-box {
background-color: #ffffff;
border: 2px solid #2ecc71;
border-radius: 8px;
padding: 20px;
margin-top: 20px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.output-box label {
font-size: 1.1em;
font-weight: bold;
color: #2c3e50;
margin-bottom: 15px;
display: block;
}
.output-box .markdown {
background-color: #f8f9fa;
padding: 15px;
border-radius: 6px;
border: 1px solid #e9ecef;
}
.output-box h3 {
color: #2c3e50;
border-bottom: 2px solid #3498db;
padding-bottom: 8px;
margin-top: 20px;
}
.output-box p {
line-height: 1.6;
color: #34495e;
margin: 10px 0;
}
.loading {
display: flex;
align-items: center;
justify-content: center;
padding: 20px;
background-color: #f8f9fa;
border-radius: 8px;
margin: 10px 0;
}
.loading-spinner {
width: 40px;
height: 40px;
border: 4px solid #f3f3f3;
border-top: 4px solid #3498db;
border-radius: 50%;
animation: spin 1s linear infinite;
margin-right: 15px;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.loading-text {
color: #2c3e50;
font-size: 1.1em;
font-weight: 500;
}
"""
# Create the Gradio interface with enhanced UI
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
        # Phi-2 Fine-tuned with GRPO and QLoRA
        This model has been fine-tuned using GRPO (Group Relative Policy Optimization) with QLoRA (quantized low-rank adaptation) for memory-efficient training.
        Try it out with different prompts and generation parameters!
""",
elem_classes="title"
)
with gr.Row():
with gr.Column(scale=2):
with gr.Column(elem_classes="prompt-box"):
prompt = gr.Textbox(
label="Enter Your Prompt Here",
placeholder="Type your prompt here... (e.g., 'What is machine learning?' or 'Write a story about a robot learning to paint')",
lines=5,
show_label=True,
)
with gr.Row():
with gr.Column():
max_length = gr.Slider(
minimum=32,
maximum=256,
value=128,
step=32,
label="Max Length",
info="Maximum number of tokens to generate"
)
temperature = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
step=0.1,
label="Temperature",
info="Higher values make output more random, lower values more deterministic"
)
with gr.Column():
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.1,
label="Top-p",
info="Nucleus sampling parameter"
)
num_generations = gr.Slider(
minimum=1,
maximum=4,
value=2,
step=1,
label="Number of Generations",
info="Number of different responses to generate"
)
with gr.Row():
with gr.Column():
repetition_penalty = gr.Slider(
minimum=1.0,
maximum=2.0,
value=1.1,
step=0.1,
label="Repetition Penalty",
info="Higher values prevent repetition"
)
with gr.Column():
do_sample = gr.Checkbox(
value=True,
label="Enable Sampling",
info="Enable/disable sampling for deterministic output"
)
show_comparison = gr.Checkbox(
value=True,
label="Show Before/After Comparison",
info="Toggle to show responses from both base and fine-tuned models"
)
generate_btn = gr.Button("Generate", variant="primary", size="large")
with gr.Column(scale=3):
with gr.Column(elem_classes="output-box"):
output = gr.Markdown(
label="Generated Response(s)",
show_label=True,
value="Your generated responses will appear here...", # Add default value
)
loading_status = gr.Markdown(
value="",
show_label=False,
elem_classes="loading"
)
gr.Markdown(
"""
### Example Prompts
Try these example prompts to test the model:
1. **Technical Questions**:
- "What is machine learning?"
- "What is deep learning?"
- "What is the difference between supervised and unsupervised learning?"
2. **Creative Writing**:
- "Write a short story about a robot learning to paint."
- "Write a story about a time-traveling smartphone."
- "Write a fairy tale about a computer learning to dream."
- "Create a story about an AI becoming an artist."
3. **Technical Explanations**:
- "How does neural network training work?"
- "Explain quantum computing in simple terms."
- "What is transfer learning?"
4. **Creative Tasks**:
- "Write a poem about artificial intelligence."
- "Write a poem about the future of technology."
- "Create a story about a robot learning to dream."
""",
elem_classes="description"
)
    def generate_with_status(*args):
        # Assigning to loading_status.value inside a handler does not refresh
        # the UI; yield intermediate outputs instead so the spinner shows.
        yield gr.update(), """
        <div class="loading">
            <div class="loading-spinner"></div>
            <div class="loading-text">Generating responses... Please wait...</div>
        </div>
        """
        result = generate_response(*args)
        # Final update: the responses, with the loading indicator cleared
        yield result, ""
# Connect the interface
generate_btn.click(
fn=generate_with_status,
inputs=[
prompt,
max_length,
temperature,
top_p,
num_generations,
repetition_penalty,
do_sample,
show_comparison
],
        outputs=[output, loading_status]
)
if __name__ == "__main__":
    console.print("[bold green]Starting Gradio interface...[/bold green]")
    demo.queue()  # Generator (streaming) event handlers require the queue
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,  # Creates a public link when run locally; ignored on Spaces
    )