Daniel Kantor committed
Commit 87921bd · 1 Parent(s): afcd31f

add fallback logic for getting the model size

Files changed (1): backend/app/utils/model_validation.py (+53 -18)
backend/app/utils/model_validation.py CHANGED
@@ -88,7 +88,12 @@ class ModelValidator:
     async def get_model_size(
         self, model_info: Any, precision: str, base_model: str, revision: str
     ) -> Tuple[Optional[float], Optional[str]]:
-        """Get model size in billions of parameters"""
+        """Get model size in billions of parameters.
+
+        First, try to use safetensors metadata (which provides parameter counts).
+        If that isn't available (i.e. for non-safetensors models), then as a fallback,
+        use file metadata (summing the sizes of weight files) and estimate the parameter count.
+        """
         try:
             logger.info(
                 LogFormatter.info(f"Checking model size for {model_info.modelId}")
@@ -101,43 +106,73 @@
                 if hasattr(s, "rfilename")
             )

-            # Try to get size from safetensors first
-            model_size = None
+            model_size = None  # will hold total parameter count (as a number)

             if is_adapter and base_model:
-                # For adapters, we need both adapter and base model sizes
+                # For adapters, we need both adapter and base model sizes from safetensors metadata.
                 adapter_meta = await self.get_safetensors_metadata(
                     model_info.id, is_adapter=True, revision=revision
                 )
                 base_meta = await self.get_safetensors_metadata(
                     base_model, revision="main"
                 )
-
                 if adapter_meta and base_meta:
                     adapter_size = sum(adapter_meta.parameter_count.values())
                     base_size = sum(base_meta.parameter_count.values())
                     model_size = adapter_size + base_size
             else:
-                # For regular models, just get the model size
+                # For regular models, try to get the model size from safetensors metadata.
                 meta = await self.get_safetensors_metadata(
                     model_info.id, revision=revision
                 )
                 if meta:
-                    model_size = sum(meta.parameter_count.values())  # total params
+                    model_size = sum(meta.parameter_count.values())

-            if model_size is None:
-                # If model size could not be determined, return an error
-                return None, "Model size could not be determined"
+            if model_size is not None:
+                # Adjust size for GPTQ models if needed
+                factor = (
+                    8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
+                )
+                # Convert parameter count to billions
+                model_size = round((model_size / 1e9) * factor, 3)
+                logger.info(
+                    LogFormatter.success(f"Model size: {model_size}B parameters")
+                )
+                return model_size, None

-            # Adjust size for GPTQ models
-            size_factor = (
-                8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
+            # Fallback: use file metadata (siblings) to estimate model size
+            logger.info(
+                "Safetensors metadata not available. Falling back to file metadata to estimate model size."
             )
-            model_size = model_size / 1e9  # Convert to billions, assuming float16
-            model_size = round(size_factor * model_size, 3)
-
-            logger.info(LogFormatter.success(f"Model size: {model_size}B parameters"))
-            return model_size, None
+            weight_file_extensions = [".bin", ".safetensors"]
+            fallback_size_bytes = 0
+            for sibling in model_info.siblings:
+                if hasattr(sibling, "rfilename") and sibling.size is not None:
+                    if any(
+                        sibling.rfilename.endswith(ext)
+                        for ext in weight_file_extensions
+                    ):
+                        fallback_size_bytes += sibling.size
+
+            if fallback_size_bytes > 0:
+                # Assume float16 storage where each parameter takes ~2 bytes.
+                # Then estimate parameter count and adjust for GPTQ if needed.
+                factor = (
+                    8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
+                )
+                estimated_param_count = (fallback_size_bytes / 2) * factor
+                model_size = round(estimated_param_count / 1e9, 3)  # in billions
+                logger.info(
+                    LogFormatter.success(
+                        f"Fallback model size: {model_size}B parameters"
+                    )
+                )
+                return model_size, None
+            else:
+                return (
+                    None,
+                    "Model size could not be determined using file metadata fallback",
+                )

         except Exception as e:
             logger.error(LogFormatter.error(f"Error while determining model size: {e}"))
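For context on the primary path above: safetensors metadata records an exact per-dtype parameter count, so the size can be computed without downloading any weights. The validator's own get_safetensors_metadata helper is not shown in this diff; the sketch below reproduces the same idea with the get_safetensors_metadata function from huggingface_hub (available in recent releases), and the wrapper function name is illustrative.

# Minimal sketch of the primary path, assuming a recent huggingface_hub
# release that exposes get_safetensors_metadata(). The validator's
# internal helper of the same name is not shown here and may differ.
from huggingface_hub import get_safetensors_metadata

def params_in_billions(repo_id: str, revision: str = "main") -> float:
    meta = get_safetensors_metadata(repo_id, revision=revision)
    # parameter_count maps dtype name to parameter count, e.g. {"F16": ...}
    total = sum(meta.parameter_count.values())
    return round(total / 1e9, 3)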
 
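The fallback depends on sibling.size being populated, which the Hub API only does when file metadata is requested explicitly. Below is a standalone sketch of the same estimate, assuming HfApi.model_info with its files_metadata=True parameter; the wrapper function name is illustrative.

# Minimal sketch of the fallback: sum the sizes of weight files and assume
# float16 storage at ~2 bytes per parameter. Without files_metadata=True,
# sibling.size is None for every file and there is nothing to sum.
from typing import Optional
from huggingface_hub import HfApi

def estimate_params_in_billions(repo_id: str, revision: str = "main") -> Optional[float]:
    info = HfApi().model_info(repo_id, revision=revision, files_metadata=True)
    total_bytes = sum(
        s.size
        for s in info.siblings
        if s.size is not None and s.rfilename.endswith((".bin", ".safetensors"))
    )
    if total_bytes == 0:
        return None  # no recognizable weight files in the repo
    return round(total_bytes / 2 / 1e9, 3)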
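To make the unit conversion concrete, here is the fallback arithmetic worked through with a made-up byte count:

# Worked example of the fallback arithmetic (hypothetical numbers).
fallback_size_bytes = 14_483_464_192   # ~13.5 GiB of weight files
params = fallback_size_bytes / 2       # assume 2 bytes per fp16 parameter
print(round(params / 1e9, 3))          # 7.242 -> reported as 7.242B

# For precision == "GPTQ" (or "gptq" in the repo id) the commit applies a
# factor of 8, on the assumption that packed quantized weights store far
# more parameters per byte than float16:
print(round(params * 8 / 1e9, 3))      # 57.934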