import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    """Low-rank adapter: computes alpha * (x @ A @ B) with A of shape (in_dim, rank) and B of shape (rank, out_dim)."""
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        # A starts as scaled Gaussian noise, B as zeros, so the adapter
        # contributes nothing before training begins.
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        # Low-rank update, scaled by alpha.
        return self.alpha * (x @ self.A @ self.B)

class LinearWithLoRA(nn.Module):
    """Wraps an existing nn.Linear and adds one or more selectable LoRA adapters."""
    def __init__(self, linear, rank, alpha,
                 weak_lora_alpha=0.1, number_of_lora=1):
        super().__init__()
        self.linear = linear
        # One LoRALayer per adapter; lora_idx selects which one is active.
        self.lora = nn.ModuleList([LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        ) for _ in range(number_of_lora)])
        self.use_lora = True
        self.lora_idx = 0
        # Note: weak_lora_alpha is accepted but not used inside this layer.

    def forward(self, x):
        if self.use_lora:
            # Base linear output plus the active adapter's low-rank correction.
            return self.linear(x) + self.lora[self.lora_idx](x)
        return self.linear(x)

def replace_linear_with_lora(module, rank=64, alpha=1., tag=0,
                             weak_lora_alpha=0.1, number_of_lora=1):
    """Recursively replace every nn.Linear in `module` with a LinearWithLoRA wrapper."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, LinearWithLoRA(
                child, rank, alpha,
                weak_lora_alpha=weak_lora_alpha,
                number_of_lora=number_of_lora))
        else:
            # Recurse into composite modules (e.g. Sequential, attention blocks).
            replace_linear_with_lora(child, rank, alpha, tag,
                                     weak_lora_alpha=weak_lora_alpha,
                                     number_of_lora=number_of_lora)


def lora_false(model, lora_idx=0):
    """Disable LoRA everywhere so the model runs with the base linear weights only."""
    for name, module in model.named_modules():
        if isinstance(module, LinearWithLoRA):
            module.use_lora = False
            module.lora_idx = lora_idx


def lora_true(model, lora_idx=0):
    """Enable the adapter at `lora_idx` and freeze all other adapters."""
    for name, module in model.named_modules():
        if isinstance(module, LinearWithLoRA):
            module.use_lora = True
            module.lora_idx = lora_idx
            for i, lora in enumerate(module.lora):
                if i != lora_idx:
                    # Freeze inactive adapters and drop any stale gradients.
                    lora.A.requires_grad = False
                    lora.B.requires_grad = False
                    if lora.A.grad is not None:
                        del lora.A.grad
                    if lora.B.grad is not None:
                        del lora.B.grad
                else:
                    # Only the selected adapter remains trainable.
                    lora.A.requires_grad = True
                    lora.B.requires_grad = True
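

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): wrap a
# small hypothetical MLP, then switch between two adapters and the plain base
# model. The model definition and sizes below are assumptions for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    base = nn.Sequential(
        nn.Linear(16, 32),
        nn.ReLU(),
        nn.Linear(32, 4),
    )

    # Freeze the base weights; only LoRA parameters would be trained.
    for p in base.parameters():
        p.requires_grad = False

    # Swap every nn.Linear for a LinearWithLoRA holding two adapters.
    replace_linear_with_lora(base, rank=4, alpha=1., number_of_lora=2)

    # Run with adapter 0, then adapter 1, then with the base weights only.
    lora_true(base, lora_idx=0)
    out_adapter0 = base(torch.randn(2, 16))

    lora_true(base, lora_idx=1)
    out_adapter1 = base(torch.randn(2, 16))

    lora_false(base)
    out_base_only = base(torch.randn(2, 16))

    print(out_adapter0.shape, out_adapter1.shape, out_base_only.shape)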