# NOTE(review): the three lines here ("Spaces:" / "Sleeping" / "Sleeping") were
# Hugging Face Spaces status-banner residue from a web scrape, not code;
# converted to this comment so the module parses.
import gradio as gr
import json
import collections
from transformers import AutoModelForCausalLM, AutoTokenizer

# Define model name - use a very small model so the app runs on CPU-only Spaces.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # Chat-tuned 1.1B parameter model

print(f"Loading model {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    low_cpu_mem_usage=True,  # CPU-friendly settings
    device_map="cpu",        # Force CPU usage
)
print("Model loaded successfully!")
def analyze_patterns(player_history, opponent_history):
    """Perform basic pattern analysis on the game history.

    Args:
        player_history: list of the player's past moves ("Rock"/"Paper"/"Scissors").
        opponent_history: list of the opponent's past moves, in round order.

    Returns:
        dict with keys:
            "move_frequencies": per-move "count/total (pct%)" strings
                (empty dict when there is no history yet),
            "response_patterns": most common opponent reply after each of the
                player's moves (only patterns with at least one observation),
            "sequence_prediction": present only when the opponent's last 3
                moves recur earlier in history with a known follow-up move.
    """
    analysis = {}

    total_moves = len(opponent_history)
    # Guard: with no history there is nothing to analyze — returning empty
    # summaries avoids a ZeroDivisionError in the percentage math below.
    if total_moves == 0:
        analysis["move_frequencies"] = {}
        analysis["response_patterns"] = {}
        return analysis

    # Count frequencies of each opponent move.
    move_counts = collections.Counter(opponent_history)
    analysis["move_frequencies"] = {
        move: f"{move_counts[move]}/{total_moves} "
              f"({move_counts[move] * 100 / total_moves:.1f}%)"
        for move in ("Rock", "Paper", "Scissors")
    }

    # Response patterns: what the opponent plays in the round AFTER each of
    # the player's moves.
    response_patterns = {
        f"After_{move}": collections.Counter()
        for move in ("Rock", "Paper", "Scissors")
    }
    # Bound by the shorter list so mismatched history lengths cannot raise
    # an IndexError on opponent_history[i + 1].
    paired_rounds = min(len(player_history), total_moves) - 1
    for i in range(paired_rounds):
        response_patterns[f"After_{player_history[i]}"][opponent_history[i + 1]] += 1

    analysis["response_patterns"] = {}
    for pattern, counter in response_patterns.items():
        observed = sum(counter.values())
        if observed > 0:
            move, count = counter.most_common(1)[0]
            analysis["response_patterns"][pattern] = f"{move} ({count}/{observed})"

    # Sequence analysis: does the opponent's latest 3-move run appear earlier,
    # and if so, what did they play right after it? Needs >= 3 moves of history.
    if total_moves >= 3:
        last_moves = opponent_history[-3:]
        # i ranges so that opponent_history[i + 3] always exists — the original
        # inner bounds check was redundant and has been dropped.
        followups = [
            opponent_history[i + 3]
            for i in range(total_moves - 3)
            if opponent_history[i:i + 3] == last_moves
        ]
        if followups:
            move, count = collections.Counter(followups).most_common(1)[0]
            analysis["sequence_prediction"] = (
                f"After sequence {' → '.join(last_moves)}, opponent most often "
                f"plays {move} ({count}/{len(followups)} times)"
            )

    return analysis
def format_rps_game_prompt(game_data):
    """Format Rock-Paper-Scissors game data into an LLM prompt with pattern analysis.

    Args:
        game_data: dict or JSON string with keys "player_history",
            "opponent_history", "player_score", "opponent_score", "draws".
            Missing keys default to empty histories / zero scores.

    Returns:
        A prompt string for the LLM, or an error-message string when the
        game data cannot be parsed (callers display either verbatim).
    """
    try:
        # Accept either a dict or its JSON serialization.
        if isinstance(game_data, str):
            game_data = json.loads(game_data)

        # Extract key game information with safe defaults.
        player_history = game_data.get("player_history", [])
        opponent_history = game_data.get("opponent_history", [])
        rounds_played = len(player_history)
        player_score = game_data.get("player_score", 0)
        opponent_score = game_data.get("opponent_score", 0)
        draws = game_data.get("draws", 0)

        # Summarize opponent tendencies to ground the model's reasoning.
        pattern_analysis = analyze_patterns(player_history, opponent_history)
        move_frequencies = pattern_analysis.get("move_frequencies", {})
        response_patterns = pattern_analysis.get("response_patterns", {})
        sequence_prediction = pattern_analysis.get(
            "sequence_prediction", "No clear sequence pattern detected"
        )

        # Create a specific prompt embedding the analysis; the trailing
        # "Recommendation:" instruction gives the model a parseable format.
        prompt = f"""You are an expert Rock-Paper-Scissors strategy advisor helping a player win.
Game State:
- Rounds played: {rounds_played}
- Player score: {player_score}
- Opponent score: {opponent_score}
- Draws: {draws}
- Player's last 5 moves: {', '.join(player_history[-5:])}
- Opponent's last 5 moves: {', '.join(opponent_history[-5:])}
Pattern Analysis:
- Opponent's move frequencies:
  * Rock: {move_frequencies.get('Rock', 'N/A')}
  * Paper: {move_frequencies.get('Paper', 'N/A')}
  * Scissors: {move_frequencies.get('Scissors', 'N/A')}
- Opponent's response patterns:
  * After player's Rock: {response_patterns.get('After_Rock', 'No clear pattern')}
  * After player's Paper: {response_patterns.get('After_Paper', 'No clear pattern')}
  * After player's Scissors: {response_patterns.get('After_Scissors', 'No clear pattern')}
- Sequence analysis: {sequence_prediction}
Based on this pattern analysis, what should the player choose next (Rock, Paper, or Scissors)?
Explain your reasoning step-by-step, then end with: "Recommendation: [Rock/Paper/Scissors]"
"""
        return prompt
    except Exception as e:
        # Best-effort: the UI shows this string instead of crashing.
        return f"Error formatting prompt: {str(e)}\n\nPlease provide game data in a valid JSON format."
def generate_advice(game_data):
    """Generate next-move advice for the player using the LLM.

    Args:
        game_data: JSON string or dict describing the game state (see
            format_rps_game_prompt for the expected keys).

    Returns:
        The model's advice text. When the model's reply is very short, a
        frequency-based fallback recommendation is appended. On failure an
        error-message string is returned instead of raising.
    """
    try:
        # Format the prompt from the game state.
        prompt = format_rps_game_prompt(game_data)

        # Generate a response on CPU; cap new tokens to bound latency.
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(
            inputs["input_ids"],
            max_new_tokens=150,  # Allow more tokens for a more detailed response
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Causal LMs echo the prompt; strip it to keep only the new text.
        if full_response.startswith(prompt):
            response = full_response[len(prompt):].strip()
        else:
            response = full_response

        # Fallback: if the model said almost nothing, derive a recommendation
        # directly from the opponent's move frequencies.
        if len(response) < 30:
            # Parse once and reuse (the original parsed the JSON twice).
            parsed = json.loads(game_data) if isinstance(game_data, str) else game_data
            pattern_analysis = analyze_patterns(
                parsed["player_history"], parsed["opponent_history"]
            )
            move_freqs = pattern_analysis.get("move_frequencies", {})
            # Frequency strings look like "2/5 (40.0%)"; extract the percentage.
            max_move = max(
                ["Rock", "Paper", "Scissors"],
                key=lambda m: float(
                    move_freqs.get(m, "0/0 (0%)").split("(")[1].split("%")[0]
                ),
            )
            # Choose the counter to the opponent's most frequent move.
            if max_move == "Rock":
                suggestion = "Paper"
            elif max_move == "Paper":
                suggestion = "Scissors"
            else:
                suggestion = "Rock"
            response += (
                f"\n\nBased on opponent's move frequencies, they play {max_move} "
                f"most often. \nRecommendation: {suggestion}"
            )

        return response
    except Exception as e:
        # Best-effort: the UI shows this string instead of crashing.
        return f"Error generating advice: {str(e)}"
# Sample game data shown as the input placeholder in the UI below.
sample_game_data = {
    "player_history": ["Rock", "Paper", "Scissors", "Rock", "Paper"],
    "opponent_history": ["Scissors", "Rock", "Paper", "Scissors", "Rock"],
    "player_score": 3,
    "opponent_score": 2,
    "draws": 0,
}
# Create the Gradio interface: JSON game state in, LLM advice text out.
with gr.Blocks(title="Rock-Paper-Scissors AI Assistant") as demo:
    gr.Markdown("# Rock-Paper-Scissors AI Assistant")
    gr.Markdown("Enter your game data to get advice on your next move.")
    with gr.Row():
        with gr.Column():
            game_data_input = gr.Textbox(
                label="Game State (JSON)",
                # Show the sample payload so users know the expected schema.
                placeholder=json.dumps(sample_game_data, indent=2),
                lines=10,
            )
            submit_btn = gr.Button("Get Advice")
        with gr.Column():
            output = gr.Textbox(label="AI Advice", lines=10)
    # Wire the button to the advice generator.
    submit_btn.click(generate_advice, inputs=[game_data_input], outputs=[output])

# Launch the app
demo.launch()