Daniel Kantor committed
Commit 87921bd · Parent(s): afcd31f
add fallback logic for getting the model size
backend/app/utils/model_validation.py
CHANGED
@@ -88,7 +88,12 @@ class ModelValidator:
     async def get_model_size(
         self, model_info: Any, precision: str, base_model: str, revision: str
     ) -> Tuple[Optional[float], Optional[str]]:
-        """Get model size in billions of parameters"""
+        """Get model size in billions of parameters.
+
+        First, try to use safetensors metadata (which provides parameter counts).
+        If that isn’t available (i.e. for non-safetensors models), then as a fallback,
+        use file metadata (summing the sizes of weight files) and estimate the parameter count.
+        """
         try:
             logger.info(
                 LogFormatter.info(f"Checking model size for {model_info.modelId}")
@@ -101,43 +106,73 @@ class ModelValidator:
                 if hasattr(s, "rfilename")
             )

-            model_size = None
+            model_size = None  # will hold total parameter count (as a number)

             if is_adapter and base_model:
-                # For adapters, we need both adapter and base model sizes
+                # For adapters, we need both adapter and base model sizes from safetensors metadata.
                 adapter_meta = await self.get_safetensors_metadata(
                     model_info.id, is_adapter=True, revision=revision
                 )
                 base_meta = await self.get_safetensors_metadata(
                     base_model, revision="main"
                 )
-
                 if adapter_meta and base_meta:
                     adapter_size = sum(adapter_meta.parameter_count.values())
                     base_size = sum(base_meta.parameter_count.values())
                     model_size = adapter_size + base_size
             else:
-                # For regular models,
+                # For regular models, try to get the model size from safetensors metadata.
                 meta = await self.get_safetensors_metadata(
                     model_info.id, revision=revision
                 )
                 if meta:
-                    model_size = sum(meta.parameter_count.values())
+                    model_size = sum(meta.parameter_count.values())

-            if model_size is None:
+            if model_size is not None:
+                # Adjust size for GPTQ models if needed
+                factor = (
+                    8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
+                )
+                # Convert parameter count to billions
+                model_size = round((model_size / 1e9) * factor, 3)
+                logger.info(
+                    LogFormatter.success(f"Model size: {model_size}B parameters")
+                )
+                return model_size, None

+            # Fallback: use file metadata (siblings) to estimate model size
+            logger.info(
+                "Safetensors metadata not available. Falling back to file metadata to estimate model size."
             )
+            weight_file_extensions = [".bin", ".safetensors"]
+            fallback_size_bytes = 0
+            for sibling in model_info.siblings:
+                if hasattr(sibling, "rfilename") and sibling.size is not None:
+                    if any(
+                        sibling.rfilename.endswith(ext)
+                        for ext in weight_file_extensions
+                    ):
+                        fallback_size_bytes += sibling.size
+
+            if fallback_size_bytes > 0:
+                # Assume float16 storage where each parameter takes ~2 bytes.
+                # Then estimate parameter count and adjust for GPTQ if needed.
+                factor = (
+                    8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
+                )
+                estimated_param_count = (fallback_size_bytes / 2) * factor
+                model_size = round(estimated_param_count / 1e9, 3)  # in billions
+                logger.info(
+                    LogFormatter.success(
+                        f"Fallback model size: {model_size}B parameters"
+                    )
+                )
+                return model_size, None
+            else:
+                return (
+                    None,
+                    "Model size could not be determined using file metadata fallback",
+                )

         except Exception as e:
             logger.error(LogFormatter.error(f"Error while determining model size: {e}"))
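For reference, the two-step logic introduced here can be reproduced outside the validator with the public huggingface_hub client. The sketch below is a minimal approximation of the same idea, not the validator's code: it first asks the Hub for safetensors metadata (which exposes per-dtype parameter counts) and, only if that fails, sums the sizes of .safetensors/.bin files and assumes roughly 2 bytes per parameter (float16 storage). The helper name estimate_model_size_b, the default precision argument, the broad except clause, and the use of files_metadata=True are illustrative assumptions rather than details of this commit; the validator's own get_safetensors_metadata helper may behave differently from the hub function of the same name.

# Minimal standalone sketch (assumptions noted above), mirroring the two-step
# idea in this commit with the public huggingface_hub client.
from typing import Optional, Tuple

from huggingface_hub import HfApi, get_safetensors_metadata


def estimate_model_size_b(
    repo_id: str, revision: str = "main", precision: str = "float16"
) -> Tuple[Optional[float], Optional[str]]:
    """Return (size_in_billions, error_message), following the diff's contract."""
    # The diff scales GPTQ checkpoints by 8; the same convention is copied here.
    factor = 8 if (precision == "GPTQ" or "gptq" in repo_id.lower()) else 1

    # 1) Preferred path: safetensors metadata exposes per-dtype parameter counts.
    try:
        meta = get_safetensors_metadata(repo_id, revision=revision)
        param_count = sum(meta.parameter_count.values())
        return round((param_count / 1e9) * factor, 3), None
    except Exception:
        # Not a safetensors repo (or metadata unavailable): fall through to the fallback.
        pass

    # 2) Fallback path: sum the sizes of weight files and assume ~2 bytes per
    #    parameter (float16 storage), as the new code does.
    info = HfApi().model_info(repo_id, revision=revision, files_metadata=True)
    weight_bytes = sum(
        s.size
        for s in (info.siblings or [])
        if s.size is not None and s.rfilename.endswith((".safetensors", ".bin"))
    )
    if weight_bytes == 0:
        return None, "Model size could not be determined"

    estimated_params = (weight_bytes / 2) * factor
    return round(estimated_params / 1e9, 3), None


if __name__ == "__main__":
    size_b, err = estimate_model_size_b("gpt2")
    print(size_b if err is None else err)

As a worked example of the fallback arithmetic: a repo whose weight files total 26 GB would be estimated at 26e9 / 2 = 13e9 parameters, reported as 13.0B. The 2-bytes-per-parameter assumption matches float16/bfloat16 checkpoints; an fp32-only repo would be overestimated by roughly a factor of two, which is presumably acceptable for a coarse size check.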