"""Load a causal language model named by ``$MODEL_NAME`` and report its size.

The script reads the model identifier from the ``MODEL_NAME`` environment
variable and loads it with ``AutoModelForCausalLM.from_pretrained`` using
``device_map="auto"`` and ``torch_dtype="bfloat16"``. It then prints the model
name followed by two parameter counts:

1. the sum of ``numel()`` over ``model.parameters()``, and
2. ``model.num_parameters()`` as reported by the model itself.

Printing both lets the two counting methods be compared side by side.
"""
import os

from transformers import AutoModelForCausalLM

model_name = os.getenv('MODEL_NAME')
if not model_name:
    # Fail fast with an actionable message. Without this guard, None or ""
    # reaches from_pretrained and produces a confusing repo-id/path error.
    raise SystemExit("MODEL_NAME environment variable is not set or is empty")

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="bfloat16",
)

# Same output as before: name, manual parameter count, model-reported count.
print(model_name, sum(p.numel() for p in model.parameters()), model.num_parameters())