rui3000 committed
Commit 9c848ec · verified · 1 Parent(s): af80097

Update app.py

Files changed (1): app.py +112 -14
app.py CHANGED
@@ -1,9 +1,10 @@
 import gradio as gr
 import json
+import collections
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 # Define model name - use a very small model
-MODEL_NAME = "EleutherAI/pythia-70m" # Extremely small model, no quantization needed
+MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Chat-tuned 1.1B parameter model
 
 print(f"Loading model {MODEL_NAME}...")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
@@ -14,8 +15,57 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 print("Model loaded successfully!")
 
+def analyze_patterns(player_history, opponent_history):
+    """Perform basic pattern analysis on the game history"""
+    analysis = {}
+
+    # Count frequencies of each move
+    move_counts = collections.Counter(opponent_history)
+    total_moves = len(opponent_history)
+
+    analysis["move_frequencies"] = {
+        "Rock": f"{move_counts.get('Rock', 0)}/{total_moves} ({move_counts.get('Rock', 0)*100/total_moves:.1f}%)",
+        "Paper": f"{move_counts.get('Paper', 0)}/{total_moves} ({move_counts.get('Paper', 0)*100/total_moves:.1f}%)",
+        "Scissors": f"{move_counts.get('Scissors', 0)}/{total_moves} ({move_counts.get('Scissors', 0)*100/total_moves:.1f}%)"
+    }
+
+    # Check response patterns (what opponent plays after player's moves)
+    response_patterns = {
+        "After_Rock": collections.Counter(),
+        "After_Paper": collections.Counter(),
+        "After_Scissors": collections.Counter()
+    }
+
+    for i in range(len(player_history) - 1):
+        player_move = player_history[i]
+        opponent_next = opponent_history[i + 1]
+        response_patterns[f"After_{player_move}"][opponent_next] += 1
+
+    analysis["response_patterns"] = {}
+    for pattern, counter in response_patterns.items():
+        if sum(counter.values()) > 0:
+            most_common = counter.most_common(1)[0]
+            analysis["response_patterns"][pattern] = f"{most_common[0]} ({most_common[1]}/{sum(counter.values())})"
+
+    # Check for repeating sequences
+    last_moves = opponent_history[-3:]
+    repeated_sequences = []
+
+    # Look for this sequence in history
+    for i in range(len(opponent_history) - 3):
+        if opponent_history[i:i+3] == last_moves:
+            if i+3 < len(opponent_history):
+                repeated_sequences.append(opponent_history[i+3])
+
+    if repeated_sequences:
+        counter = collections.Counter(repeated_sequences)
+        most_common = counter.most_common(1)[0]
+        analysis["sequence_prediction"] = f"After sequence {' → '.join(last_moves)}, opponent most often plays {most_common[0]} ({most_common[1]}/{len(repeated_sequences)} times)"
+
+    return analysis
+
 def format_rps_game_prompt(game_data):
-    """Format Rock-Paper-Scissors game data into a prompt for the LLM"""
+    """Format Rock-Paper-Scissors game data into a prompt for the LLM with pattern analysis"""
     try:
         # Parse the JSON game data
         if isinstance(game_data, str):
@@ -27,19 +77,42 @@ def format_rps_game_prompt(game_data):
         rounds_played = len(player_history)
         player_score = game_data.get("player_score", 0)
         opponent_score = game_data.get("opponent_score", 0)
+        draws = game_data.get("draws", 0)
 
-        # Create the prompt
-        prompt = f"""You are an assistant helping a player win at Rock-Paper-Scissors.
+        # Perform pattern analysis
+        pattern_analysis = analyze_patterns(player_history, opponent_history)
+
+        # Format analysis for prompt
+        move_frequencies = pattern_analysis.get("move_frequencies", {})
+        response_patterns = pattern_analysis.get("response_patterns", {})
+        sequence_prediction = pattern_analysis.get("sequence_prediction", "No clear sequence pattern detected")
+
+        # Create a more specific prompt with the analysis
+        prompt = f"""You are an expert Rock-Paper-Scissors strategy advisor helping a player win.
 
 Game State:
 - Rounds played: {rounds_played}
 - Player score: {player_score}
 - Opponent score: {opponent_score}
-- Player move history: {', '.join(player_history)}
-- Opponent move history: {', '.join(opponent_history)}
+- Draws: {draws}
+- Player's last 5 moves: {', '.join(player_history[-5:])}
+- Opponent's last 5 moves: {', '.join(opponent_history[-5:])}
+
+Pattern Analysis:
+- Opponent's move frequencies:
+  * Rock: {move_frequencies.get('Rock', 'N/A')}
+  * Paper: {move_frequencies.get('Paper', 'N/A')}
+  * Scissors: {move_frequencies.get('Scissors', 'N/A')}
+
+- Opponent's response patterns:
+  * After player's Rock: {response_patterns.get('After_Rock', 'No clear pattern')}
+  * After player's Paper: {response_patterns.get('After_Paper', 'No clear pattern')}
+  * After player's Scissors: {response_patterns.get('After_Scissors', 'No clear pattern')}
+
+- Sequence analysis: {sequence_prediction}
 
-Based on the opponent's pattern of moves, what should the player choose next (Rock, Paper, or Scissors)?
-Explain your reasoning, then provide a clear recommendation.
+Based on this pattern analysis, what should the player choose next (Rock, Paper, or Scissors)?
+Explain your reasoning step-by-step, then end with: "Recommendation: [Rock/Paper/Scissors]"
 """
         return prompt
     except Exception as e:
@@ -57,17 +130,41 @@ def generate_advice(game_data):
         # Set max_length to avoid excessive generation
        outputs = model.generate(
            inputs["input_ids"],
-            max_new_tokens=100, # Limit token generation
+            max_new_tokens=150, # Allow more tokens for a more detailed response
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-        # Remove the prompt from the response
-        if response.startswith(prompt):
-            response = response[len(prompt):].strip()
+        # Extract just the model's response (remove the prompt)
+        if full_response.startswith(prompt):
+            response = full_response[len(prompt):].strip()
+        else:
+            response = full_response
 
+        # If model response is too short, add fallback advice
+        if len(response) < 30:
+            pattern_analysis = analyze_patterns(
+                json.loads(game_data)["player_history"] if isinstance(game_data, str) else game_data["player_history"],
+                json.loads(game_data)["opponent_history"] if isinstance(game_data, str) else game_data["opponent_history"]
+            )
+
+            # Simple fallback strategy based on pattern analysis
+            move_freqs = pattern_analysis.get("move_frequencies", {})
+            max_move = max(["Rock", "Paper", "Scissors"],
+                           key=lambda m: float(move_freqs.get(m, "0/0 (0%)").split("(")[1].split("%")[0]))
+
+            # Choose counter to opponent's most frequent move
+            if max_move == "Rock":
+                suggestion = "Paper"
+            elif max_move == "Paper":
+                suggestion = "Scissors"
+            else:
+                suggestion = "Rock"
+
+            response += f"\n\nBased on opponent's move frequencies, they play {max_move} most often. \nRecommendation: {suggestion}"
+
        return response
    except Exception as e:
        return f"Error generating advice: {str(e)}"
@@ -77,7 +174,8 @@ sample_game_data = {
    "player_history": ["Rock", "Paper", "Scissors", "Rock", "Paper"],
    "opponent_history": ["Scissors", "Rock", "Paper", "Scissors", "Rock"],
    "player_score": 3,
-    "opponent_score": 2
+    "opponent_score": 2,
+    "draws": 0
 }
 
 # Create Gradio interface
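
For a quick sense of what the new pattern analysis feeds into the prompt, the sketch below (not part of the commit; just collections applied to the history from sample_game_data) reproduces the two summaries analyze_patterns stores under "move_frequencies" and "response_patterns". The printed values are what this five-round sample should yield.

import collections

# History taken from sample_game_data in app.py
player_history = ["Rock", "Paper", "Scissors", "Rock", "Paper"]
opponent_history = ["Scissors", "Rock", "Paper", "Scissors", "Rock"]

# Opponent move frequencies (analyze_patterns's "move_frequencies")
move_counts = collections.Counter(opponent_history)
total = len(opponent_history)
for move in ("Rock", "Paper", "Scissors"):
    print(f"{move}: {move_counts.get(move, 0)}/{total} ({move_counts.get(move, 0) * 100 / total:.1f}%)")
# Rock: 2/5 (40.0%)
# Paper: 1/5 (20.0%)
# Scissors: 2/5 (40.0%)

# Opponent's reply after each player move (analyze_patterns's "response_patterns")
responses = collections.defaultdict(collections.Counter)
for player_move, opponent_next in zip(player_history, opponent_history[1:]):
    responses[f"After_{player_move}"][opponent_next] += 1
print(dict(responses))
# {'After_Rock': Counter({'Rock': 2}), 'After_Paper': Counter({'Paper': 1}), 'After_Scissors': Counter({'Scissors': 1})}

On this sample the opponent copies the player's previous move every round, which is the kind of signal the updated prompt, and the length-based fallback in generate_advice, are meant to surface.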