Daniel Kantor committed
Commit 0cfafdf · 1 Parent(s): 87921bd

fixes in model size logic

Files changed (1)
  1. backend/app/utils/model_validation.py +35 -15
backend/app/utils/model_validation.py CHANGED
@@ -88,28 +88,34 @@ class ModelValidator:
     async def get_model_size(
         self, model_info: Any, precision: str, base_model: str, revision: str
     ) -> Tuple[Optional[float], Optional[str]]:
-        """Get model size in billions of parameters.
-
-        First, try to use safetensors metadata (which provides parameter counts).
-        If that isn’t available (i.e. for non-safetensors models), then as a fallback,
-        use file metadata (summing the sizes of weight files) and estimate the parameter count.
+        """
+        Get model size in billions of parameters.
+
+        First, try to use safetensors metadata (which includes a parameter count).
+        If that isn’t available, then as a fallback, use file metadata from the repository
+        to sum the sizes of weight files.
+
+        For the fallback, we assume (for example) that for float16 storage each parameter takes ~2 bytes.
+        For GPTQ models (detected via the precision argument or model ID), we adjust by a factor (e.g. 8).
+
+        Returns:
+            Tuple of (model_size_in_billions, error_message). If successful, error_message is None.
         """
         try:
             logger.info(
                 LogFormatter.info(f"Checking model size for {model_info.modelId}")
             )
 
-            # Check if model is adapter
+            # Check if model is an adapter by looking for an adapter config file.
             is_adapter = any(
-                s.rfilename == "adapter_config.json"
+                hasattr(s, "rfilename") and s.rfilename == "adapter_config.json"
                 for s in model_info.siblings
-                if hasattr(s, "rfilename")
             )
 
-            model_size = None  # will hold total parameter count (as a number)
+            model_size = None  # This will hold the total parameter count if available.
 
             if is_adapter and base_model:
-                # For adapters, we need both adapter and base model sizes from safetensors metadata.
+                # For adapters, we need to get both the adapter and base model metadata.
                 adapter_meta = await self.get_safetensors_metadata(
                     model_info.id, is_adapter=True, revision=revision
                 )
@@ -129,23 +135,37 @@ class ModelValidator:
             model_size = sum(meta.parameter_count.values())
 
             if model_size is not None:
-                # Adjust size for GPTQ models if needed
+                # Adjust for GPTQ models if necessary.
                 factor = (
                     8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
                 )
-                # Convert parameter count to billions
                 model_size = round((model_size / 1e9) * factor, 3)
                 logger.info(
-                    LogFormatter.success(f"Model size: {model_size}B parameters")
+                    LogFormatter.success(
+                        f"Model size: {model_size}B parameters (from safetensors metadata)"
+                    )
                 )
                 return model_size, None
 
-            # Fallback: use file metadata (siblings) to estimate model size
+            # Fallback: use file metadata from the repository.
             logger.info(
                 "Safetensors metadata not available. Falling back to file metadata to estimate model size."
             )
             weight_file_extensions = [".bin", ".safetensors"]
             fallback_size_bytes = 0
+
+            # If model_info does not contain file metadata, re-fetch with files_metadata=True.
+            if not model_info.siblings or all(
+                getattr(s, "size", None) is None for s in model_info.siblings
+            ):
+                logger.info(
+                    "Re-fetching model info with file metadata for fallback estimation."
+                )
+                model_info = await asyncio.to_thread(
+                    self.api.model_info, model_info.id, files_metadata=True
+                )
+
+            # Sum up the sizes of files that appear to be weight files.
             for sibling in model_info.siblings:
                 if hasattr(sibling, "rfilename") and sibling.size is not None:
                     if any(
@@ -155,8 +175,8 @@
                         fallback_size_bytes += sibling.size
 
             if fallback_size_bytes > 0:
-                # Assume float16 storage where each parameter takes ~2 bytes.
-                # Then estimate parameter count and adjust for GPTQ if needed.
+                # Estimate parameter count based on file size.
+                # For float16 weights we assume ~2 bytes per parameter.
                 factor = (
                     8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
                 )
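
For context, the primary path reads the parameter count straight from safetensors metadata. Below is a minimal standalone sketch of that lookup; it assumes the get_safetensors_metadata helper from huggingface_hub (the validator's own get_safetensors_metadata method is a separate wrapper), and the function name params_from_safetensors is illustrative only.

from huggingface_hub import get_safetensors_metadata

def params_from_safetensors(repo_id: str, revision: str = "main") -> float:
    """Parameter count in billions, read from safetensors metadata (sketch, not the validator)."""
    meta = get_safetensors_metadata(repo_id, revision=revision)
    # parameter_count maps dtype -> count, e.g. {"F16": 6738415616}; sum across dtypes.
    total_params = sum(meta.parameter_count.values())
    return round(total_params / 1e9, 3)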
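And a rough sketch of the fallback the new docstring describes: fetch per-file sizes with files_metadata=True, sum the weight files, and convert bytes to a parameter count assuming ~2 bytes per parameter for float16 and an 8x adjustment for GPTQ. The standalone helper estimate_size_from_files is hypothetical; only the assumptions mirrored from the diff are taken from the change itself.

import asyncio

from huggingface_hub import HfApi

WEIGHT_FILE_EXTENSIONS = (".bin", ".safetensors")

async def estimate_size_from_files(repo_id: str, precision: str, revision: str = "main") -> float:
    """Estimated size in billions of parameters from repo file sizes (sketch, not the validator)."""
    api = HfApi()
    # files_metadata=True asks the Hub to include per-file sizes on the siblings.
    info = await asyncio.to_thread(
        api.model_info, repo_id, revision=revision, files_metadata=True
    )
    size_bytes = sum(
        s.size
        for s in info.siblings
        if s.size is not None and s.rfilename.endswith(WEIGHT_FILE_EXTENSIONS)
    )
    factor = 8 if (precision == "GPTQ" or "gptq" in repo_id.lower()) else 1
    # ~2 bytes per parameter for float16 storage, reported in billions.
    return round((size_bytes / 2 / 1e9) * factor, 3)

Under the 2-bytes-per-parameter assumption, a ~13 GB float16 checkpoint works out to roughly 6.5B parameters.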