leaderboard / src /submit.py
Alvaro Romo
Fixed logging messages and refactored code. Added log in private dataset
c2f297a
raw
history blame contribute delete
916 Bytes
from dataclasses import dataclass
from transformers import AutoConfig
@dataclass
class ModelSizeChecker:
    """Check whether a model's size is within the evaluation limit.

    The limit is ``MAX_SIZE_IN_B`` multiplied by a factor that depends on
    the model's precision (see ``get_precision_factor``).

    Attributes:
        model: Identifier passed to ``AutoConfig.from_pretrained`` when the
            precision is ``"GPTQ"``.
        precision: One of ``"float16"``, ``"bfloat16"``, ``"8bit"``,
            ``"4bit"`` or ``"GPTQ"``.
        model_size_in_b: Model size compared against the limit
            (presumably billions of parameters -- confirm with callers).
    """

    model: str
    precision: str
    model_size_in_b: float

    # The attributes below are deliberately left unannotated so that
    # @dataclass does not turn them into constructor fields.

    # Size limit for a model whose precision factor is 1.
    MAX_SIZE_IN_B = 140

    # Precision names whose factor does not depend on the model's config.
    _FIXED_PRECISION_FACTORS = {
        "float16": 1,
        "bfloat16": 1,
        "8bit": 2,
        "4bit": 4,
    }

    # GPTQ quantization bit width -> precision factor.
    _GPTQ_BITS_TO_FACTOR = {2: 8, 3: 6, 4: 4, 8: 2}

    def get_precision_factor(self):
        """Return the multiplier applied to the size limit for this precision.

        For ``"GPTQ"`` the model's config is loaded with
        ``AutoConfig.from_pretrained`` and the factor is derived from
        ``quantization_config["bits"]``; bit widths without a known factor
        fall back to 1.

        Raises:
            ValueError: If ``precision`` is not a recognised value.
        """
        fixed_factor = self._FIXED_PRECISION_FACTORS.get(self.precision)
        if fixed_factor is not None:
            return fixed_factor

        if self.precision == "GPTQ":
            config = AutoConfig.from_pretrained(self.model)
            num_bits = int(config.quantization_config["bits"])
            return self._GPTQ_BITS_TO_FACTOR.get(num_bits, 1)

        # ValueError is a subclass of Exception, so callers that catch the
        # previously raised bare Exception keep working.
        raise ValueError(f"Unknown precision {self.precision}.")

    def can_evaluate(self):
        """Return True if the model size does not exceed the scaled limit."""
        size_limit = self.MAX_SIZE_IN_B * self.get_precision_factor()
        return self.model_size_in_b <= size_limit