rui3000 committed on
Commit
23a7862
·
verified ·
1 Parent(s): 0f898b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -111
app.py CHANGED
@@ -2,154 +2,152 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
# Load model and tokenizer
model_name = "Qwen/Qwen2-0.5B"

# Half precision saves memory on a GPU, but float16 inference on CPU is slow
# or unsupported for some operations depending on the PyTorch build, so it is
# only selected when a CUDA device is actually available.
_model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=_model_dtype,
    # NOTE(review): trust_remote_code executes code shipped with the checkpoint;
    # confirm it is still required for this model before keeping it.
    trust_remote_code=True,
)
14
 
15
def generate_response(prompt, max_length=300, temperature=0.7):
    """Sample a continuation of ``prompt`` from the module-level model.

    Args:
        prompt: Text to feed to the model.
        max_length: Upper bound on newly generated tokens (forwarded as
            ``max_new_tokens``).
        temperature: Sampling temperature.

    Returns:
        The decoded text with the leading prompt removed (when the decoded
        text starts with it) and surrounding whitespace stripped.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    sampling_options = {
        "max_new_tokens": max_length,
        "do_sample": True,
        "temperature": temperature,
        "top_p": 0.9,
    }

    # Inference only: no gradients are needed.
    with torch.no_grad():
        generated = model.generate(**encoded, **sampling_options)

    text = tokenizer.decode(generated[0], skip_special_tokens=True)

    # Drop the echoed prompt so only the continuation is returned.
    continuation = text[len(prompt):] if text.startswith(prompt) else text
    return continuation.strip()
33
 
34
# Prompt templates offered in the UI. Each one is a str.format() template
# whose placeholders match the keys produced for it by create_sample_data().
_BASIC_HISTORY_TEMPLATE = """
Game history: {game_history}
Player score: {player_score}
AI score: {ai_score}
Last move: {last_move}

Based on this information, analyze the game and recommend a next move.
"""

_STATISTICS_TEMPLATE = """
Game history: {game_history}
Player's move frequencies: Rock ({rock_freq}%), Paper ({paper_freq}%), Scissors ({scissors_freq}%)
Player's patterns:
- After playing Rock, chooses Paper: {rock_to_paper}%
- After playing Paper, chooses Scissors: {paper_to_scissors}%
- After playing Scissors, chooses Rock: {scissors_to_rock}%

What should be the AI's next move?
"""

_SIMPLIFIED_TEMPLATE = """
Recent moves: {recent_moves}
Based on this pattern, the player is likely to play {likely_next} next.
To counter {likely_next}, the AI should play:
"""

# Insertion order is the order in which the templates are listed.
test_templates = {
    "Basic Game History": _BASIC_HISTORY_TEMPLATE,
    "With Pre-calculated Statistics": _STATISTICS_TEMPLATE,
    "Simplified Decision": _SIMPLIFIED_TEMPLATE,
}
62
-
63
def create_sample_data(template_key):
    """Return example placeholder values for the named template.

    The keys of the returned dict are the placeholder names used by the
    matching template, in the order the fields are presented. An unrecognised
    ``template_key`` yields an empty dict.
    """
    full_history = "R,P,S,R,P,S,S,R,P,R"

    # (template name, sample values) pairs; built per call so callers always
    # receive a fresh dict they are free to mutate.
    samples = (
        (
            "Basic Game History",
            {
                "game_history": full_history,
                "player_score": "5",
                "ai_score": "3",
                "last_move": "P",
            },
        ),
        (
            "With Pre-calculated Statistics",
            {
                "game_history": full_history,
                "rock_freq": "40",
                "paper_freq": "30",
                "scissors_freq": "30",
                "rock_to_paper": "75",
                "paper_to_scissors": "67",
                "scissors_to_rock": "50",
            },
        ),
        (
            "Simplified Decision",
            {
                "recent_moves": "R,P,S,R,P",
                "likely_next": "S",
            },
        ),
    )

    for name, values in samples:
        if template_key == name:
            return values
    return {}
88
-
89
def format_prompt(template_key, **kwargs):
    """Fill the template registered under ``template_key`` with ``kwargs``.

    Raises KeyError if the template name is unknown or a placeholder has no
    matching keyword argument.
    """
    return test_templates[template_key].format(**kwargs)
93
-
94
def update_template_inputs(template_name):
    """Build one pre-filled textbox per sample field of the chosen template.

    Returns a list of ``gr.Textbox`` components, labelled with the field name
    and initialised with that field's sample value.
    """
    return [
        gr.Textbox(value=sample_value, label=field_name)
        for field_name, sample_value in create_sample_data(template_name).items()
    ]
103
-
104
def test_model(template_name, *args):
    """Format the chosen template with the given field values and query the model.

    ``args`` are matched positionally against the field names that
    ``create_sample_data`` returns for ``template_name``; extra values on
    either side are ignored by ``zip``.

    Returns:
        A ``(prompt, response)`` tuple.
    """
    # Iterating the sample dict yields its keys, i.e. the field names in order.
    field_names = create_sample_data(template_name)
    field_values = dict(zip(field_names, args))

    prompt = format_prompt(template_name, **field_values)
    return prompt, generate_response(prompt)
113
 
114
# Define the interface
#
# Each template needs a different number of input fields, while Gradio event
# listeners are wired to a fixed list of components. So we create enough
# textboxes for the largest template once, and show, hide and relabel them
# when the selected template changes.
_DEFAULT_TEMPLATE = "Basic Game History"
_MAX_FIELDS = max(len(create_sample_data(name)) for name in test_templates)


def update_inputs(template_name):
    """Return one component update per input textbox for ``template_name``.

    The first ``len(sample)`` textboxes are made visible and given the
    template's field label and sample value; the remaining ones are cleared
    and hidden.
    """
    sample_items = list(create_sample_data(template_name).items())
    updates = []
    for index in range(_MAX_FIELDS):
        if index < len(sample_items):
            key, value = sample_items[index]
            updates.append(gr.update(value=value, label=key, visible=True))
        else:
            updates.append(gr.update(value="", visible=False))
    return updates


with gr.Blocks() as demo:
    gr.Markdown("# Qwen2 0.5B Testing for Rock-Paper-Scissors Game Analysis")

    with gr.Row():
        with gr.Column():
            template_dropdown = gr.Dropdown(
                choices=list(test_templates.keys()),
                value=_DEFAULT_TEMPLATE,
                label="Select Template",
            )

            # Components are attached to a layout by being created inside its
            # `with` block, so the textboxes are built here rather than being
            # added to the column afterwards.
            with gr.Column() as input_container:
                sample_data = create_sample_data(_DEFAULT_TEMPLATE)
                default_items = list(sample_data.items())
                input_fields = []
                for index in range(_MAX_FIELDS):
                    if index < len(default_items):
                        key, value = default_items[index]
                        input_fields.append(gr.Textbox(value=value, label=key))
                    else:
                        input_fields.append(gr.Textbox(visible=False))

            test_button = gr.Button("Test Model")

        with gr.Column():
            prompt_output = gr.Textbox(label="Formatted Prompt")
            response_output = gr.Textbox(label="Model Response")

    # One update is returned per textbox, so the outputs are the textboxes
    # themselves rather than their container.
    template_dropdown.change(
        fn=update_inputs,
        inputs=template_dropdown,
        outputs=input_fields,
    )

    # test_model zips the field names with the values it receives, so the
    # values of hidden surplus textboxes are ignored.
    test_button.click(
        fn=test_model,
        inputs=[template_dropdown] + input_fields,
        outputs=[prompt_output, response_output],
    )

demo.launch()
 
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
# Load the Qwen2 0.5B model
model_id = "Qwen/Qwen2-0.5B"

# Half precision saves memory on a GPU, but float16 inference on CPU is slow
# or unsupported for some operations depending on the PyTorch build, so it is
# only selected when a CUDA device is actually available.
_model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=_model_dtype,
    device_map="auto",
    # NOTE(review): trust_remote_code executes code shipped with the checkpoint;
    # confirm it is still required for this model before keeping it.
    trust_remote_code=True,
)
14
 
15
def generate_response(prompt, max_length=512, temperature=0.7, top_p=0.9):
    """Generate a response from the Qwen2 model based on the input prompt.

    Args:
        prompt: Text to condition the generation on.
        max_length: Maximum number of new tokens to generate (forwarded as
            ``max_new_tokens``). Coerced to ``int`` because UI sliders may
            deliver floats.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        Only the newly generated text, with surrounding whitespace stripped.
    """
    # Tokenize the input prompt and move it to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    # Generate response (inference only, so no gradients are tracked)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_length),
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )

    # The returned sequence starts with the prompt tokens. Slice them off by
    # token count instead of matching the prompt text in the decoded string:
    # decoding does not always reproduce the prompt character-for-character,
    # and a substring match followed by a prefix slice can cut off the wrong
    # characters.
    generated_tokens = outputs[0][prompt_token_count:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return response.strip()
39
 
40
def _build_prompt(raw_prompt, game_stats_template, template_type):
    """Combine the statistics text and the user prompt according to ``template_type``."""
    if template_type == "Raw Prompt Only":
        return raw_prompt
    if template_type == "Template + Prompt":
        return f"{game_stats_template}\n\n{raw_prompt}"
    if template_type == "Custom Format":
        return f"{game_stats_template}\n\nBased on the game statistics above, {raw_prompt}"
    raise ValueError(f"Unknown template type: {template_type!r}")


def process_input(
    raw_prompt,
    game_stats_template,
    template_type,
    max_length,
    temperature,
    top_p
):
    """Process the input and template to create the final prompt for the model.

    Args:
        raw_prompt: The user's question or instruction.
        game_stats_template: Free-form game statistics text.
        template_type: One of "Raw Prompt Only", "Template + Prompt" or
            "Custom Format"; selects how the two texts are combined.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        A ``(final_prompt, response)`` tuple. When the assembled prompt is
        empty or whitespace only, the model is not called and ``response`` is
        a short message asking for input.

    Raises:
        ValueError: If ``template_type`` is not a recognised option.
    """
    # Cleared textboxes may deliver None; treat that as empty text.
    raw_prompt = raw_prompt or ""
    game_stats_template = game_stats_template or ""

    final_prompt = _build_prompt(raw_prompt, game_stats_template, template_type)

    # There is nothing meaningful to generate from an empty prompt, so skip
    # the model call instead of sampling from no context.
    if not final_prompt.strip():
        return final_prompt, "Please enter a prompt before generating a response."

    # Generate response from the model; slider values are normalised to the
    # numeric types the generation call expects.
    response = generate_response(
        final_prompt,
        max_length=int(max_length),
        temperature=float(temperature),
        top_p=float(top_p),
    )

    return final_prompt, response
68
 
69
# Create the Gradio interface
_TEMPLATE_CHOICES = ["Raw Prompt Only", "Template + Prompt", "Custom Format"]

with gr.Blocks() as demo:
    gr.Markdown("# Qwen2 0.5B Game Analysis Tester")
    gr.Markdown("Use this interface to test how the Qwen2 0.5B model responds to different prompts about your game statistics.")

    with gr.Row():
        # Left column: prompt construction and sampling controls.
        with gr.Column():
            template_type = gr.Radio(
                _TEMPLATE_CHOICES,
                value="Template + Prompt",
                label="Prompt Template Type",
            )

            game_stats_template = gr.Textbox(
                lines=10,
                label="Game Statistics Template",
                placeholder="Enter your game statistics here (scores, round history, etc.)",
            )

            raw_prompt = gr.Textbox(
                lines=3,
                label="Prompt",
                placeholder="What do you want the model to analyze or respond to?",
            )

            with gr.Row():
                max_length = gr.Slider(
                    label="Max Response Length",
                    minimum=50,
                    maximum=1024,
                    step=1,
                    value=256,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=1.5,
                    step=0.1,
                    value=0.7,
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.9,
                )

            submit_btn = gr.Button("Generate Response")

        # Right column: the exact prompt that was sent and the model's answer.
        with gr.Column():
            final_prompt_display = gr.Textbox(
                lines=10,
                label="Final Prompt Sent to Model",
            )
            response_display = gr.Textbox(
                lines=15,
                label="Model Response",
            )

    # The input order below must match process_input's parameter order.
    submit_btn.click(
        process_input,
        inputs=[
            raw_prompt,
            game_stats_template,
            template_type,
            max_length,
            temperature,
            top_p,
        ],
        outputs=[final_prompt_display, response_display],
    )

    gr.Markdown("""
    ## Tips for Testing

    1. Start with simple prompts to gauge the model's basic understanding
    2. Gradually increase complexity to find the model's limitations
    3. Try different prompt formats to see which works best
    4. Experiment with temperature and top_p to find optimal settings
    5. Document which prompts work well as candidates for fine-tuning
    """)

# Launch the demo
demo.launch()