Spaces: Running on Zero
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Load the Qwen2 0.5B base model and its tokenizer once, at import time, so every
# request handled by the UI reuses the same weights.
model_id = "Qwen/Qwen2-0.5B"

# float16 only makes sense on a CUDA device. Many CPU kernels either lack
# half-precision support or run it far slower than float32, so fall back to
# float32 when no GPU is visible (e.g. a Space running on CPU).
_model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# NOTE(review): recent transformers releases presumably support Qwen2 natively;
# trust_remote_code=True lets the Hub repo execute its own Python code, so
# consider dropping it once the installed transformers version is confirmed.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=_model_dtype,
    device_map="auto",  # automatic weight placement (needs `accelerate` installed)
    trust_remote_code=True,
)
def generate_response(prompt, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a continuation of ``prompt`` with the Qwen2 model.

    The prompt is fed to the model as-is (no chat template is applied) and
    sampled with nucleus sampling.

    Args:
        prompt: Text to continue.
        max_length: Maximum number of *new* tokens to generate. Coerced to
            ``int`` because UI sliders may deliver float values.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        Only the newly generated text, with surrounding whitespace stripped.
        An empty string is returned when ``prompt`` is empty or ``None``.
    """
    if not prompt:
        # An empty prompt would give generate() no tokens to continue from
        # (assuming the tokenizer adds no BOS token), so there is nothing to do.
        return ""

    # Tokenize the prompt and move the tensors to wherever the model lives.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_length),
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )

    # For a decoder-only model the returned sequence starts with the prompt
    # tokens. Drop them by position rather than by string matching: decoding
    # does not always reproduce the original prompt text exactly, and a
    # substring check could match somewhere other than the start.
    new_token_ids = outputs[0][prompt_token_count:]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
def process_input(
    raw_prompt,
    game_stats_template,
    template_type,
    max_length,
    temperature,
    top_p,
):
    """Build the final prompt from the UI inputs and run it through the model.

    Args:
        raw_prompt: The user's question or instruction. ``None`` is treated
            as an empty string.
        game_stats_template: Free-form game statistics text that is prepended
            in the template-based modes. ``None`` is treated as an empty string.
        template_type: One of ``"Raw Prompt Only"``, ``"Template + Prompt"``
            or ``"Custom Format"``.
        max_length: Maximum number of new tokens. Coerced to ``int``.
        temperature: Sampling temperature forwarded to ``generate_response``.
        top_p: Nucleus-sampling mass forwarded to ``generate_response``.

    Returns:
        A ``(final_prompt, response)`` tuple: the exact text sent to the model
        and the model's reply. When the final prompt is empty, the model is
        not called and the reply is ``""``.

    Raises:
        ValueError: If ``template_type`` is not one of the supported choices.
    """
    raw_prompt = raw_prompt or ""
    game_stats_template = game_stats_template or ""

    if template_type == "Raw Prompt Only":
        final_prompt = raw_prompt
    elif template_type == "Template + Prompt":
        final_prompt = f"{game_stats_template}\n\n{raw_prompt}"
    elif template_type == "Custom Format":
        final_prompt = f"{game_stats_template}\n\nBased on the game statistics above, {raw_prompt}"
    else:
        # Fail loudly instead of silently sending an empty prompt to the model.
        raise ValueError(f"Unknown template type: {template_type!r}")

    if not final_prompt:
        # Nothing to generate from; skip the model call entirely.
        return final_prompt, ""

    response = generate_response(
        final_prompt,
        max_length=int(max_length),
        temperature=temperature,
        top_p=top_p,
    )
    return final_prompt, response
# Build the Gradio UI. The left column collects the prompt pieces and sampling
# controls; the right column shows what was sent to the model and its reply.
with gr.Blocks() as demo:
    gr.Markdown("# Qwen2 0.5B Game Analysis Tester")
    gr.Markdown(
        "Use this interface to test how the Qwen2 0.5B model responds to "
        "different prompts about your game statistics."
    )

    with gr.Row():
        # --- Inputs ---------------------------------------------------------
        with gr.Column():
            template_type = gr.Radio(
                choices=["Raw Prompt Only", "Template + Prompt", "Custom Format"],
                value="Template + Prompt",
                label="Prompt Template Type",
            )
            game_stats_template = gr.Textbox(
                lines=10,
                label="Game Statistics Template",
                placeholder="Enter your game statistics here (scores, round history, etc.)",
            )
            raw_prompt = gr.Textbox(
                lines=3,
                label="Prompt",
                placeholder="What do you want the model to analyze or respond to?",
            )

            # Sampling controls, laid out side by side.
            with gr.Row():
                max_length = gr.Slider(
                    minimum=50, maximum=1024, step=1, value=256,
                    label="Max Response Length",
                )
                temperature = gr.Slider(
                    minimum=0.1, maximum=1.5, step=0.1, value=0.7,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, step=0.1, value=0.9,
                    label="Top P",
                )

            submit_btn = gr.Button("Generate Response")

        # --- Outputs --------------------------------------------------------
        with gr.Column():
            final_prompt_display = gr.Textbox(lines=10, label="Final Prompt Sent to Model")
            response_display = gr.Textbox(lines=15, label="Model Response")

    # The inputs list is passed positionally, so its order must match
    # process_input's parameter order.
    submit_btn.click(
        fn=process_input,
        inputs=[raw_prompt, game_stats_template, template_type, max_length, temperature, top_p],
        outputs=[final_prompt_display, response_display],
    )

    gr.Markdown("""
## Tips for Testing
1. Start with simple prompts to gauge the model's basic understanding
2. Gradually increase complexity to find the model's limitations
3. Try different prompt formats to see which works best
4. Experiment with temperature and top_p to find optimal settings
5. Document which prompts work well as candidates for fine-tuning
""")
# Launch the demo only when this file is run as a script (e.g. `python app.py`),
# so that importing the module (tests, tooling) does not start a web server.
if __name__ == "__main__":
    demo.launch()