jw2yang committed on
Commit
0ac095e
·
1 Parent(s): b84c65f
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -8,6 +8,9 @@ from transformers import AutoModelForCausalLM, AutoProcessor
8
  import re
9
  import random
10
 
 
 
 
11
  pygame.mixer.quit() # Disable sound
12
 
13
  # Constants
@@ -30,8 +33,9 @@ STATIC = (0, 0)
30
  ACTIONS = ["up", "down", "left", "right", "static"]
31
 
32
  # Load AI Model
 
33
  magma_model_id = "microsoft/Magma-8B"
34
- magam_model = AutoModelForCausalLM.from_pretrained(magma_model_id, trust_remote_code=True)
35
  magma_processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True)
36
  magam_model.to("cuda")
37
 
@@ -137,7 +141,7 @@ def play_game():
137
  inputs = magma_processor(images=[pil_img], texts=prompt, return_tensors="pt")
138
  inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
139
  inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
140
- inputs = inputs.to("cuda")
141
  generation_args = {
142
  "max_new_tokens": 10,
143
  "temperature": 0,
 
8
  import re
9
  import random
10
 
11
+ # add a command for installing flash-attn
12
+ os.system('pip install flash-attn --no-build-isolation')
13
+
14
  pygame.mixer.quit() # Disable sound
15
 
16
  # Constants
 
33
  ACTIONS = ["up", "down", "left", "right", "static"]
34
 
35
  # Load AI Model
36
+ dtype = torch.bfloat16
37
  magma_model_id = "microsoft/Magma-8B"
38
+ magam_model = AutoModelForCausalLM.from_pretrained(magma_model_id, trust_remote_code=True, torch_dtype=dtype)
39
  magma_processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True)
40
  magam_model.to("cuda")
41
 
 
141
  inputs = magma_processor(images=[pil_img], texts=prompt, return_tensors="pt")
142
  inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
143
  inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
144
+ inputs = inputs.to("cuda").to(dtype)
145
  generation_args = {
146
  "max_new_tokens": 10,
147
  "temperature": 0,