Spaces:
Configuration error
Configuration error
update
Browse files
app.py
CHANGED
@@ -8,6 +8,9 @@ from transformers import AutoModelForCausalLM, AutoProcessor
|
|
8 |
import re
|
9 |
import random
|
10 |
|
|
|
|
|
|
|
11 |
pygame.mixer.quit() # Disable sound
|
12 |
|
13 |
# Constants
|
@@ -30,8 +33,9 @@ STATIC = (0, 0)
|
|
30 |
ACTIONS = ["up", "down", "left", "right", "static"]
|
31 |
|
32 |
# Load AI Model
|
|
|
33 |
magma_model_id = "microsoft/Magma-8B"
|
34 |
-
magam_model = AutoModelForCausalLM.from_pretrained(magma_model_id, trust_remote_code=True)
|
35 |
magma_processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True)
|
36 |
magam_model.to("cuda")
|
37 |
|
@@ -137,7 +141,7 @@ def play_game():
|
|
137 |
inputs = magma_processor(images=[pil_img], texts=prompt, return_tensors="pt")
|
138 |
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
|
139 |
inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
|
140 |
-
inputs = inputs.to("cuda")
|
141 |
generation_args = {
|
142 |
"max_new_tokens": 10,
|
143 |
"temperature": 0,
|
|
|
8 |
import re
|
9 |
import random
|
10 |
|
11 |
+
# add a command for installing flash-attn
|
12 |
+
os.system('pip install flash-attn --no-build-isolation')
|
13 |
+
|
14 |
pygame.mixer.quit() # Disable sound
|
15 |
|
16 |
# Constants
|
|
|
33 |
ACTIONS = ["up", "down", "left", "right", "static"]
|
34 |
|
35 |
# Load AI Model
|
36 |
+
dtype = torch.bfloat16
|
37 |
magma_model_id = "microsoft/Magma-8B"
|
38 |
+
magam_model = AutoModelForCausalLM.from_pretrained(magma_model_id, trust_remote_code=True, torch_dtype=dtype)
|
39 |
magma_processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True)
|
40 |
magam_model.to("cuda")
|
41 |
|
|
|
141 |
inputs = magma_processor(images=[pil_img], texts=prompt, return_tensors="pt")
|
142 |
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
|
143 |
inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
|
144 |
+
inputs = inputs.to("cuda").to(dtype)
|
145 |
generation_args = {
|
146 |
"max_new_tokens": 10,
|
147 |
"temperature": 0,
|