jw2yang commited on
Commit
9c781c3
·
1 Parent(s): 6c8c423
Files changed (2) hide show
  1. app.py +199 -0
  2. requirements.txt +37 -0
app.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pygame
2
+ import numpy as np
3
+ import gradio as gr
4
+ import time
5
+ import torch
6
+ from PIL import Image
7
+ from transformers import AutoModelForCausalLM, AutoProcessor
8
+ import re
9
+ import random
10
+
11
+ pygame.mixer.quit() # Disable sound
12
+
13
+ # Constants
14
+ WIDTH, HEIGHT = 800, 800
15
+ GRID_SIZE = 80
16
+ WHITE = (255, 255, 255)
17
+ GREEN = (34, 139, 34) # Forest green - more like an apple
18
+ RED = (200, 50, 50)
19
+ BLACK = (0, 0, 0)
20
+ GRAY = (128, 128, 128)
21
+ YELLOW = (218, 165, 32) # Golden yellow color
22
+
23
+ # Directions
24
+ UP = (0, -1)
25
+ DOWN = (0, 1)
26
+ LEFT = (-1, 0)
27
+ RIGHT = (1, 0)
28
+ STATIC = (0, 0)
29
+
30
+ ACTIONS = ["up", "down", "left", "right", "static"]
31
+
32
+ # Load AI Model
33
+ magma_model_id = "microsoft/Magma-8B"
34
+ magam_model = AutoModelForCausalLM.from_pretrained(magma_model_id, trust_remote_code=True)
35
+ magma_processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True)
36
+ magam_model.to("cuda")
37
+
38
+ # Load magma image
39
+ magma_img = pygame.image.load("./assets/images/magma_game.png")
40
+ magma_img = pygame.transform.scale(magma_img, (GRID_SIZE, GRID_SIZE))
41
+
42
+ class MagmaFindGPU:
43
+ def __init__(self):
44
+ self.reset()
45
+
46
+ def reset(self):
47
+ self.snake = [(5, 5)]
48
+ self.direction = RIGHT
49
+ self.score = 0
50
+ self.game_over = False
51
+ self.place_target()
52
+
53
+ def place_target(self):
54
+ while True:
55
+ target_x = np.random.randint(1, WIDTH // GRID_SIZE - 1)
56
+ target_y = np.random.randint(1, HEIGHT // GRID_SIZE - 1)
57
+ if (target_x, target_y) not in self.snake:
58
+ self.target = (target_x, target_y)
59
+ break
60
+
61
+ def step(self, action):
62
+ if action == "up":
63
+ self.direction = UP
64
+ elif action == "down":
65
+ self.direction = DOWN
66
+ elif action == "left":
67
+ self.direction = LEFT
68
+ elif action == "right":
69
+ self.direction = RIGHT
70
+ elif action == "static":
71
+ self.direction = STATIC
72
+
73
+ if self.game_over:
74
+ return self.render(), self.score
75
+
76
+ new_head = (self.snake[0][0] + self.direction[0], self.snake[0][1] + self.direction[1])
77
+
78
+ if new_head[0] < 0 or new_head[1] < 0 or new_head[0] >= WIDTH // GRID_SIZE or new_head[1] >= HEIGHT // GRID_SIZE:
79
+ self.game_over = True
80
+ return self.render(), self.score
81
+
82
+ self.snake = [new_head] # Keep only the head (single block snake)
83
+
84
+ # Check if the target is covered by four surrounding squares
85
+ head_x, head_y = self.snake[0]
86
+ neighbors = set([(head_x, head_y - 1), (head_x, head_y + 1), (head_x - 1, head_y), (head_x + 1, head_y)])
87
+
88
+ if neighbors.issuperset(set([self.target])):
89
+ self.score += 1
90
+ self.place_target()
91
+
92
+ return self.render(), self.score
93
+
94
+ def render(self):
95
+ pygame.init()
96
+ surface = pygame.Surface((WIDTH, HEIGHT))
97
+ surface.fill(BLACK)
98
+
99
+ head_x, head_y = self.snake[0]
100
+ surface.blit(magma_img, (head_x * GRID_SIZE, head_y * GRID_SIZE))
101
+
102
+ # pygame.draw.rect(surface, RED, (self.snake[0][0] * GRID_SIZE, self.snake[0][1] * GRID_SIZE, GRID_SIZE, GRID_SIZE))
103
+ pygame.draw.rect(surface, GREEN, (self.target[0] * GRID_SIZE, self.target[1] * GRID_SIZE, GRID_SIZE, GRID_SIZE))
104
+
105
+ # Draw four surrounding squares with labels
106
+ head_x, head_y = self.snake[0]
107
+ neighbors = [(head_x, head_y - 1), (head_x, head_y + 1), (head_x - 1, head_y), (head_x + 1, head_y)]
108
+ labels = ["1", "2", "3", "4"]
109
+ font = pygame.font.Font(None, 48)
110
+
111
+ # clone surface
112
+ surface_nomark = surface.copy()
113
+ for i, (nx, ny) in enumerate(neighbors):
114
+ if 0 <= nx < WIDTH // GRID_SIZE and 0 <= ny < HEIGHT // GRID_SIZE:
115
+ pygame.draw.rect(surface, RED, (nx * GRID_SIZE, ny * GRID_SIZE, GRID_SIZE, GRID_SIZE), GRID_SIZE)
116
+ # pygame.draw.rect(surface_nomark, RED, (nx * GRID_SIZE, ny * GRID_SIZE, GRID_SIZE, GRID_SIZE), GRID_SIZE)
117
+
118
+ text = font.render(labels[i], True, WHITE)
119
+ text_rect = text.get_rect(center=(nx * GRID_SIZE + GRID_SIZE // 2, ny * GRID_SIZE + GRID_SIZE // 2))
120
+ surface.blit(text, text_rect)
121
+
122
+ return np.array(pygame.surfarray.array3d(surface_nomark)).swapaxes(0, 1), np.array(pygame.surfarray.array3d(surface)).swapaxes(0, 1)
123
+
124
+ def get_state(self):
125
+ return self.render()
126
+
127
+ game = MagmaFindGPU()
128
+
129
+ def play_game():
130
+ state, state_som = game.get_state()
131
+ pil_img = Image.fromarray(state_som)
132
+ convs = [
133
+ {"role": "system", "content": "You are an agent that can see, talk, and act."},
134
+ {"role": "user", "content": "<image_start><image><image_end>\nWhich mark is closer to green block? Answer with a single number."},
135
+ ]
136
+ prompt = magma_processor.tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
137
+ inputs = magma_processor(images=[pil_img], texts=prompt, return_tensors="pt")
138
+ inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
139
+ inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
140
+ inputs = inputs.to("cuda")
141
+ generation_args = {
142
+ "max_new_tokens": 10,
143
+ "temperature": 0,
144
+ "do_sample": False,
145
+ "use_cache": True,
146
+ "num_beams": 1,
147
+ }
148
+ with torch.inference_mode():
149
+ generate_ids = magam_model.generate(**inputs, **generation_args)
150
+ generate_ids = generate_ids[:, inputs["input_ids"].shape[-1] :]
151
+ action = magma_processor.decode(generate_ids[0], skip_special_tokens=True).strip()
152
+ # extract mark id fro action use re
153
+ match = re.search(r'\d+', action)
154
+ if match:
155
+ action = match.group(0)
156
+ if action.isdigit() and 1 <= int(action) <= 4:
157
+ # epsilon sampling
158
+ if random.random() < 0.1:
159
+ action = random.choice(ACTIONS[:-1])
160
+ else:
161
+ action = ACTIONS[int(action) - 1]
162
+ else:
163
+ # random choose one from the pool
164
+ action = random.choice(ACTIONS[:-1])
165
+ else:
166
+ action = random.choice(ACTIONS[:-1])
167
+
168
+ img, score = game.step(action)
169
+ img = img[0]
170
+ return img, f"Score: {score}"
171
+
172
+ def reset_game():
173
+ game.reset()
174
+ return game.render()[0], "Score: 0"
175
+
176
+ MARKDOWN = """
177
+ <div align="center">
178
+ <h2>Magma: A Foundation Model for Multimodal AI Agents</h2>
179
+
180
+ Game: Magma finds the apple by moving up, down, left and right.
181
+
182
+ \[[arXiv Paper](https://www.arxiv.org/pdf/2502.13130)\] &nbsp; \[[Project Page](https://microsoft.github.io/Magma/)\] &nbsp; \[[Github Repo](https://github.com/microsoft/Magma)\] &nbsp; \[[Hugging Face Model](https://huggingface.co/microsoft/Magma-8B)\] &nbsp;
183
+
184
+ This demo is powered by [Gradio](https://gradio.app/).
185
+ </div>
186
+ """
187
+
188
+ with gr.Blocks() as interface:
189
+ gr.Markdown(MARKDOWN)
190
+ with gr.Row():
191
+ image_output = gr.Image(label="Game Screen")
192
+ score_output = gr.Text(label="Score")
193
+ with gr.Row():
194
+ start_btn = gr.Button("Start/Reset Game")
195
+
196
+ interface.load(fn=play_game, every=1, inputs=[], outputs=[image_output, score_output])
197
+ start_btn.click(fn=reset_game, inputs=[], outputs=[image_output, score_output])
198
+
199
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.3.1
2
+ torchvision==0.18.1
3
+ pytorch-lightning>=1.0.8
4
+ transformers @ git+https://github.com/jwyang/transformers.git@dev/jwyang-v4.44.1
5
+ tokenizers>=0.15.0
6
+ sentencepiece==0.1.99
7
+ shortuuid
8
+ accelerate==0.34.2
9
+ peft==0.4.0
10
+ bitsandbytes==0.44.1
11
+ pydantic>=2.0
12
+ markdown2[all]
13
+ numpy
14
+ scikit-learn==1.5.0
15
+ gradio==4.44.1
16
+ gradio_client
17
+ spaces
18
+ requests
19
+ httpx
20
+ uvicorn
21
+ fastapi
22
+ einops==0.6.1
23
+ einops-exts==0.0.4
24
+ timm==0.9.12
25
+ tensorflow==2.15.0
26
+ tensorflow_datasets==4.9.3
27
+ tensorflow_graphics==2021.12.3
28
+ draccus
29
+ pyav
30
+ numba
31
+ loguru
32
+ sacrebleu
33
+ evaluate
34
+ sqlitedict
35
+ open_clip_torch
36
+ supervision==0.18.0
37
+ ultralytics==8.3.78