kcarnold committed on
Commit
819c4b4
·
1 Parent(s): 697e79c

Prepare to move up to Gemma3

Browse files
Files changed (1) hide show
  1. custom_llm.py +11 -2
custom_llm.py CHANGED
@@ -32,8 +32,16 @@ async def models_lifespan(app: FastAPI):
32
  #model_name = 'google/gemma-1.1-7b-it'
33
  #model_name = 'google/gemma-1.1-2b-it'
34
  model_name = 'google/gemma-2-9b-it'
 
 
35
 
36
- dtype = torch.bfloat16 if USE_GPU else torch.float16
 
 
 
 
 
 
37
 
38
  ml_models["llm"] = llm = {
39
  'tokenizer': AutoTokenizer.from_pretrained(model_name),
@@ -41,7 +49,8 @@ async def models_lifespan(app: FastAPI):
41
  model_name,
42
  device_map="auto" if USE_GPU else "cpu",
43
  torch_dtype=dtype,
44
- attn_implementation='eager'
 
45
  )
46
  }
47
  print("Loaded llm with device map:")
 
32
  #model_name = 'google/gemma-1.1-7b-it'
33
  #model_name = 'google/gemma-1.1-2b-it'
34
  model_name = 'google/gemma-2-9b-it'
35
+ #model_name = 'google/gemma-3-12b-it'
36
+ #model_name = 'google/gemma-3-4b-it'
37
 
38
+ if USE_GPU:
39
+ dtype = torch.bfloat16
40
+ from transformers import TorchAoConfig
41
+ quantization_config = None#TorchAoConfig("int4_weight_only", group_size=128)
42
+ else:
43
+ dtype = torch.float16
44
+ quantization_config = None
45
 
46
  ml_models["llm"] = llm = {
47
  'tokenizer': AutoTokenizer.from_pretrained(model_name),
 
49
  model_name,
50
  device_map="auto" if USE_GPU else "cpu",
51
  torch_dtype=dtype,
52
+ attn_implementation='eager',
53
+ quantization_config=quantization_config,
54
  )
55
  }
56
  print("Loaded llm with device map:")