Daniel Kantor
committed
Commit 0cfafdf · 1 Parent(s): 87921bd
fixes in model size logic
backend/app/utils/model_validation.py
CHANGED
@@ -88,28 +88,34 @@ class ModelValidator:
     async def get_model_size(
         self, model_info: Any, precision: str, base_model: str, revision: str
     ) -> Tuple[Optional[float], Optional[str]]:
-        """
+        """
+        Get model size in billions of parameters.
+
+        First, try to use safetensors metadata (which includes a parameter count).
+        If that isn’t available, then as a fallback, use file metadata from the repository
+        to sum the sizes of weight files.
+
+        For the fallback, we assume (for example) that for float16 storage each parameter takes ~2 bytes.
+        For GPTQ models (detected via the precision argument or model ID), we adjust by a factor (e.g. 8).
 
-
-
-        use file metadata (summing the sizes of weight files) and estimate the parameter count.
+        Returns:
+            Tuple of (model_size_in_billions, error_message). If successful, error_message is None.
         """
         try:
             logger.info(
                 LogFormatter.info(f"Checking model size for {model_info.modelId}")
             )
 
-            # Check if model is adapter
+            # Check if model is an adapter by looking for an adapter config file.
             is_adapter = any(
-                s.rfilename == "adapter_config.json"
+                hasattr(s, "rfilename") and s.rfilename == "adapter_config.json"
                 for s in model_info.siblings
-                if hasattr(s, "rfilename")
             )
 
-            model_size = None  # will hold total parameter count
+            model_size = None  # This will hold the total parameter count if available.
 
             if is_adapter and base_model:
-                # For adapters, we need both adapter and base model
+                # For adapters, we need to get both the adapter and base model metadata.
                 adapter_meta = await self.get_safetensors_metadata(
                     model_info.id, is_adapter=True, revision=revision
                 )
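The rewritten docstring spells out the strategy: read the parameter count from safetensors metadata when available, otherwise estimate from the sizes of the weight files. The metadata path can be exercised directly against the Hub with the public `huggingface_hub` helpers. A minimal sketch with a placeholder repo id, standing in for the class's own `get_safetensors_metadata` wrapper:

```python
from huggingface_hub import HfApi, get_safetensors_metadata

api = HfApi()
info = api.model_info("some-org/some-model")  # hypothetical repo id

# Same adapter check as the patched any(...) above: an adapter repo is
# recognized by the presence of adapter_config.json among its files.
is_adapter = any(
    getattr(s, "rfilename", None) == "adapter_config.json" for s in info.siblings
)

if not is_adapter:
    # SafetensorsRepoMetadata.parameter_count maps dtype -> count,
    # e.g. {"F16": 7241732096}; summing the values gives the total.
    meta = get_safetensors_metadata("some-org/some-model")
    total_params = sum(meta.parameter_count.values())
    print(f"~{total_params / 1e9:.3f}B parameters")
```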
@@ -129,23 +135,37 @@ class ModelValidator:
                 model_size = sum(meta.parameter_count.values())
 
             if model_size is not None:
-                # Adjust
+                # Adjust for GPTQ models if necessary.
                 factor = (
                     8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
                 )
-                # Convert parameter count to billions
                 model_size = round((model_size / 1e9) * factor, 3)
                 logger.info(
-                    LogFormatter.success(
+                    LogFormatter.success(
+                        f"Model size: {model_size}B parameters (from safetensors metadata)"
+                    )
                 )
                 return model_size, None
 
-            # Fallback: use file metadata
+            # Fallback: use file metadata from the repository.
             logger.info(
                 "Safetensors metadata not available. Falling back to file metadata to estimate model size."
             )
             weight_file_extensions = [".bin", ".safetensors"]
             fallback_size_bytes = 0
+
+            # If model_info does not contain file metadata, re-fetch with files_metadata=True.
+            if not model_info.siblings or all(
+                getattr(s, "size", None) is None for s in model_info.siblings
+            ):
+                logger.info(
+                    "Re-fetching model info with file metadata for fallback estimation."
+                )
+                model_info = await asyncio.to_thread(
+                    self.api.model_info, model_info.id, files_metadata=True
+                )
+
+            # Sum up the sizes of files that appear to be weight files.
             for sibling in model_info.siblings:
                 if hasattr(sibling, "rfilename") and sibling.size is not None:
                     if any(
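The main fix in this hunk is the re-fetch: `HfApi.model_info()` only includes per-file sizes when called with `files_metadata=True`, so without this step every `sibling.size` can be `None` and the fallback sum stays at zero. A small sketch of that behavior, with a hypothetical repo id and helper name:

```python
import asyncio

from huggingface_hub import HfApi

api = HfApi()

async def weight_bytes(repo_id: str, revision: str = "main") -> int:
    """Sum the on-disk bytes of likely weight files (illustrative helper)."""
    # files_metadata=True asks the Hub to include per-file sizes; the HTTP
    # call is blocking, so it runs in a worker thread as in the patch above.
    info = await asyncio.to_thread(
        api.model_info, repo_id, revision=revision, files_metadata=True
    )
    return sum(
        s.size
        for s in info.siblings
        if s.size is not None and s.rfilename.endswith((".bin", ".safetensors"))
    )

# asyncio.run(weight_bytes("some-org/some-model"))  # hypothetical repo id
```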
@@ -155,8 +175,8 @@ class ModelValidator:
                         fallback_size_bytes += sibling.size
 
             if fallback_size_bytes > 0:
-                #
-                #
+                # Estimate parameter count based on file size.
+                # For float16 weights we assume ~2 bytes per parameter.
                 factor = (
                     8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
                 )