from functools import partial

import torch
from optimum.quanto.tensor import QTensor

def hacked_state_dict(self, *args, **kwargs):
    """Return a state dict with quanto tensors dequantized back to their original dtype."""
    orig_state_dict = self.orig_state_dict(*args, **kwargs)
    new_state_dict = {}
    for key, value in orig_state_dict.items():
        # scale entries are consumed while processing the matching ._data entry
        if key.endswith("._scale"):
            continue
        if key.endswith(".input_scale"):
            continue
        if key.endswith(".output_scale"):
            continue
        if key.endswith("._data"):
            key = key[: -len("._data")]
            scale = orig_state_dict[key + "._scale"]
            # the scale tensor carries the original (pre-quantization) dtype
            dtype = scale.dtype
            scale = scale.float()
            value = value.float()
            dequantized = value * scale
            # handle input and output scaling if they exist
            input_scale = orig_state_dict.get(key + ".input_scale")
            if input_scale is not None:
                # only an identity scale can be safely folded away
                if input_scale.item() != 1.0:
                    raise ValueError("Input scale is not 1.0, cannot dequantize")
            output_scale = orig_state_dict.get(key + ".output_scale")
            if output_scale is not None:
                # only an identity scale can be safely folded away
                if output_scale.item() != 1.0:
                    raise ValueError("Output scale is not 1.0, cannot dequantize")
            new_state_dict[key] = dequantized.to("cpu", dtype=dtype)
        else:
            new_state_dict[key] = value
    return new_state_dict


# Monkey-patches model.state_dict so checkpoints are dequantized before saving.
def patch_dequantization_on_save(model):
    model.orig_state_dict = model.state_dict
    model.state_dict = partial(hacked_state_dict, model)
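

# A minimal usage sketch of the patch above, assuming optimum.quanto's
# top-level quantize/freeze/qint8 API. The toy model and output path are
# hypothetical, purely for illustration; they are not part of this module.
def _example_dequantize_on_save():
    from optimum.quanto import freeze, qint8, quantize

    model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.Linear(32, 16))
    quantize(model, weights=qint8)
    freeze(model)
    patch_dequantization_on_save(model)
    # state_dict() now yields plain float tensors instead of ._data/._scale
    # pairs, so the checkpoint loads into an unquantized copy of the model.
    state_dict = model.state_dict()
    assert not any(k.endswith("._data") for k in state_dict)
    torch.save(state_dict, "model_dequantized.pt")  # hypothetical output path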


def dequantize_parameter(module: torch.nn.Module, param_name: str) -> bool:
    """
    Convert a quantized parameter back to a regular Parameter with floating
    point values.

    Args:
        module: The module containing the parameter to dequantize
        param_name: Name of the parameter to dequantize (e.g., 'weight', 'bias')

    Returns:
        bool: True if the parameter was dequantized, False if it was already
            unquantized
    """
    # Check that the parameter exists
    if not hasattr(module, param_name):
        raise AttributeError(f"Module has no parameter named '{param_name}'")
    param = getattr(module, param_name)
    if not isinstance(param, torch.nn.Parameter):
        raise TypeError(f"'{param_name}' is not a Parameter")
    # If it's not quantized, there is nothing to do
    if not isinstance(param, QTensor):
        return False
    # Convert to a float tensor while preserving device and requires_grad
    with torch.no_grad():
        float_tensor = param.float()
        new_param = torch.nn.Parameter(
            float_tensor,
            requires_grad=param.requires_grad,
        )
    # Replace the quantized parameter on the module
    setattr(module, param_name, new_param)
    return True
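

# A minimal usage sketch for dequantize_parameter: walk a model and convert
# every quantized weight/bias back to a plain float Parameter in place. The
# helper name and the ("weight", "bias") list are illustrative assumptions,
# not part of this module.
def _example_dequantize_model(model: torch.nn.Module) -> int:
    """Return how many parameters were converted back to float tensors."""
    num_converted = 0
    for module in model.modules():
        for param_name in ("weight", "bias"):
            # guard with getattr/isinstance so missing or non-Parameter
            # attributes are skipped instead of raising
            if isinstance(getattr(module, param_name, None), torch.nn.Parameter):
                if dequantize_parameter(module, param_name):
                    num_converted += 1
    return num_converted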