Prepare to move up to Gemma3
custom_llm.py  (+11 -2)
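This commit keeps 'google/gemma-2-9b-it' as the active model but stages commented-out Gemma 3 checkpoints and splits the dtype/quantization setup by device, so switching models later is a one-line flip of a comment.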
@@ -32,8 +32,16 @@ async def models_lifespan(app: FastAPI):
     #model_name = 'google/gemma-1.1-7b-it'
     #model_name = 'google/gemma-1.1-2b-it'
     model_name = 'google/gemma-2-9b-it'
+    #model_name = 'google/gemma-3-12b-it'
+    #model_name = 'google/gemma-3-4b-it'
 
-
+    if USE_GPU:
+        dtype = torch.bfloat16
+        from transformers import TorchAoConfig
+        quantization_config = None#TorchAoConfig("int4_weight_only", group_size=128)
+    else:
+        dtype = torch.float16
+        quantization_config = None
 
     ml_models["llm"] = llm = {
         'tokenizer': AutoTokenizer.from_pretrained(model_name),
@@ -41,7 +49,8 @@ async def models_lifespan(app: FastAPI):
             model_name,
             device_map="auto" if USE_GPU else "cpu",
             torch_dtype=dtype,
-            attn_implementation='eager'
+            attn_implementation='eager',
+            quantization_config=quantization_config,
         )
     }
     print("Loaded llm with device map:")
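For context, here is a minimal, self-contained sketch of the loading pattern the diff converges on, with the int4 weight-only quantization switched on rather than left commented out. This is a sketch under assumptions, not the file itself: USE_GPU and ml_models are defined elsewhere in custom_llm.py, the 'model' dict key is assumed (its line falls outside the hunks), and the TorchAoConfig path additionally requires the torchao package.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

USE_GPU = torch.cuda.is_available()  # assumption: the real file sets this elsewhere
model_name = 'google/gemma-2-9b-it'  # swap for a gemma-3 checkpoint when ready

if USE_GPU:
    dtype = torch.bfloat16
    # int4 weight-only quantization via torchao (pip install torchao)
    quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
else:
    dtype = torch.float16
    quantization_config = None

llm = {
    'tokenizer': AutoTokenizer.from_pretrained(model_name),
    'model': AutoModelForCausalLM.from_pretrained(  # 'model' key assumed
        model_name,
        device_map="auto" if USE_GPU else "cpu",
        torch_dtype=dtype,
        attn_implementation='eager',  # carried over from the original call
        quantization_config=quantization_config,
    ),
}
print("Loaded llm with device map:")
print(llm['model'].hf_device_map if USE_GPU else "cpu")

Note that passing quantization_config=None is a no-op in from_pretrained, which is presumably why the commit lands the keyword argument now: enabling TorchAoConfig later means uncommenting one expression without touching the call site again.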