Update handler.py
Browse files- handler.py +1 -1
handler.py
CHANGED
@@ -12,7 +12,7 @@ class EndpointHandler():
|
|
12 |
path,
|
13 |
torch_dtype=torch.bfloat16,
|
14 |
device_map="cuda:0", # or "mps" if on Apple Silicon
|
15 |
-
attn_implementation="flash_attention_2"
|
16 |
).eval()
|
17 |
self.processor = ColQwen2Processor.from_pretrained(path)
|
18 |
# self.model = torch.compile(self.model)
|
|
|
12 |
path,
|
13 |
torch_dtype=torch.bfloat16,
|
14 |
device_map="cuda:0", # or "mps" if on Apple Silicon
|
15 |
+
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
|
16 |
).eval()
|
17 |
self.processor = ColQwen2Processor.from_pretrained(path)
|
18 |
# self.model = torch.compile(self.model)
|