Update handler.py
Browse files
- handler.py +2 -2
handler.py
CHANGED
@@ -12,10 +12,10 @@ class EndpointHandler():
|
|
12 |
path,
|
13 |
torch_dtype=torch.bfloat16,
|
14 |
device_map="cuda:0", # or "mps" if on Apple Silicon
|
15 |
-
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
|
16 |
).eval()
|
17 |
self.processor = ColQwen2Processor.from_pretrained(path)
|
18 |
-
self.model = torch.compile(self.model)
|
19 |
print(f"Model and processor loaded {'with' if is_flash_attn_2_available() else 'without'} FA2")
|
20 |
|
21 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
12 |
path,
|
13 |
torch_dtype=torch.bfloat16,
|
14 |
device_map="cuda:0", # or "mps" if on Apple Silicon
|
15 |
+
attn_implementation="flash_attention_2" # if is_flash_attn_2_available() else None,
|
16 |
).eval()
|
17 |
self.processor = ColQwen2Processor.from_pretrained(path)
|
18 |
+
# self.model = torch.compile(self.model)
|
19 |
print(f"Model and processor loaded {'with' if is_flash_attn_2_available() else 'without'} FA2")
|
20 |
|
21 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|