from functools import partial
from optimum.quanto.tensor import QTensor
import torch


def hacked_state_dict(self, *args, **kwargs):
    """Replacement `state_dict` that rebuilds plain float tensors.

    optimum-quanto serializes a quantized parameter as a `<name>._data` /
    `<name>._scale` pair, plus optional `<name>.input_scale` and
    `<name>.output_scale` entries. This drops those auxiliary keys and stores
    the dequantized weight under the bare `<name>` key instead.
    """
    orig_state_dict = self.orig_state_dict(*args, **kwargs)
    new_state_dict = {}
    for key, value in orig_state_dict.items():
        if key.endswith("._scale"):
            continue
        if key.endswith(".input_scale"):
            continue
        if key.endswith(".output_scale"):
            continue
        if key.endswith("._data"):
            key = key[:-6]
            scale = orig_state_dict[key + "._scale"]
            # scale is the original dtype
            dtype = scale.dtype
            scale = scale.float()
            value = value.float()
            dequantized = value * scale
            
            # input/output activation scales, if present, must be the identity
            # (1.0); anything else cannot simply be dropped when rebuilding
            # plain float weights
            input_scale = orig_state_dict.get(key + ".input_scale")
            if input_scale is not None and input_scale.item() != 1.0:
                raise ValueError("Input scale is not 1.0, cannot dequantize")

            output_scale = orig_state_dict.get(key + ".output_scale")
            if output_scale is not None and output_scale.item() != 1.0:
                raise ValueError("Output scale is not 1.0, cannot dequantize")
            
            new_state_dict[key] = dequantized.to('cpu', dtype=dtype)
        else:
            new_state_dict[key] = value
    return new_state_dict
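
# Worked micro-example (hedged, synthetic numbers) of what the loop above does:
#   _data  = tensor([ 64, -32], dtype=torch.int8)
#   _scale = tensor(0.5, dtype=torch.bfloat16)
# rebuilt weight: _data.float() * _scale.float() == tensor([32.0, -16.0]),
# then cast back to bfloat16, moved to CPU, and stored under the bare
# parameter name with the "._data" suffix removed.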

def patch_dequantization_on_save(model):
    """Monkey-patch `model.state_dict` so checkpoints are dequantized on save."""
    model.orig_state_dict = model.state_dict
    model.state_dict = partial(hacked_state_dict, model)
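
# Usage sketch (a hedged example; the quantize/freeze calls assume the
# documented optimum-quanto API and are not part of the original file):
#
#   from optimum.quanto import quantize, freeze, qint8
#   model = torch.nn.Sequential(torch.nn.Linear(8, 8))
#   quantize(model, weights=qint8)
#   freeze(model)
#   patch_dequantization_on_save(model)
#   torch.save(model.state_dict(), "dequantized.pt")  # plain float tensors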
  
  
def dequantize_parameter(module: torch.nn.Module, param_name: str) -> bool:
    """
    Convert a quantized parameter back to a regular Parameter with floating
    point values.

    Args:
        module: The module containing the parameter to dequantize
        param_name: Name of the parameter to dequantize (e.g., 'weight', 'bias')

    Returns:
        bool: True if the parameter was dequantized, False if it was not
            quantized to begin with
    """
    
    # Check if the parameter exists
    if not hasattr(module, param_name):
        raise AttributeError(f"Module has no parameter named '{param_name}'")
    
    param = getattr(module, param_name)
    
    # A non-Parameter attribute is a usage error; a Parameter that is not a
    # QTensor is already dequantized, so there is nothing to do
    if not isinstance(param, torch.nn.Parameter):
        raise TypeError(f"'{param_name}' is not a Parameter")
    if not isinstance(param, QTensor):
        return False
        
    # Convert to float tensor while preserving device and requires_grad
    with torch.no_grad():
        float_tensor = param.float()
        new_param = torch.nn.Parameter(
            float_tensor,
            requires_grad=param.requires_grad
        )
    
    # Replace the parameter
    setattr(module, param_name, new_param)
    
    return True
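

# A hedged convenience sketch; `dequantize_model` is an assumed helper name,
# not part of the original file. It walks every submodule and restores the
# common 'weight'/'bias' parameters to plain float Parameters in place.
def dequantize_model(model: torch.nn.Module) -> int:
    """Dequantize every 'weight'/'bias' Parameter in `model`; return the count."""
    count = 0
    for module in model.modules():
        for param_name in ("weight", "bias"):
            param = getattr(module, param_name, None)
            # only touch real Parameters; dequantize_parameter returns False
            # for ones that are not quantized
            if isinstance(param, torch.nn.Parameter) and dequantize_parameter(module, param_name):
                count += 1
    return count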