import torch
import torch.nn as nn


class LoRALayer(torch.nn.Module):
    # A single low-rank adapter: out = alpha * (x @ A @ B).
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        # A starts small (std 1/sqrt(rank)); B starts at zero, so the
        # adapter is a no-op until trained.
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        # Project down to `rank`, back up to out_dim, then scale by alpha.
        x = self.alpha * (x @ self.A @ self.B)
        return x


class LinearWithLoRA(torch.nn.Module):
    # Wraps a (typically frozen) nn.Linear and adds one or more
    # switchable LoRA adapters; `lora_idx` selects the active one.
    def __init__(self, linear, rank, alpha,
                 weak_lora_alpha=0.1, number_of_lora=1):
        super().__init__()
        self.linear = linear
        # `weak_lora_alpha` is accepted for API compatibility but unused here.
        self.lora = nn.ModuleList([LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        ) for _ in range(number_of_lora)])
        self.use_lora = True
        self.lora_idx = 0

    def forward(self, x):
        if self.use_lora:
            # Base output plus the active adapter's low-rank correction.
            return self.linear(x) + self.lora[self.lora_idx](x)
        else:
            return self.linear(x)


def replace_linear_with_lora(module, rank=64, alpha=1., tag=0,
                             weak_lora_alpha=0.1, number_of_lora=1):
    # Recursively replace every nn.Linear in `module` with a LinearWithLoRA wrapper.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, LinearWithLoRA(child, rank, alpha,
                    weak_lora_alpha=weak_lora_alpha, number_of_lora=number_of_lora))
        else:
            replace_linear_with_lora(child, rank, alpha, tag,
                    weak_lora_alpha=weak_lora_alpha, number_of_lora=number_of_lora)


def lora_false(model, lora_idx=0):
    # Turn the LoRA path off everywhere; the wrapped linears run unmodified.
    for name, module in model.named_modules():
        if isinstance(module, LinearWithLoRA):
            module.use_lora = False
            module.lora_idx = lora_idx


def lora_true(model, lora_idx=0):
    # Activate the adapter at `lora_idx` and freeze all the others.
    for name, module in model.named_modules():
        if isinstance(module, LinearWithLoRA):
            module.use_lora = True
            module.lora_idx = lora_idx
            for i, lora in enumerate(module.lora):
                if i != lora_idx:
                    # Freeze inactive adapters and drop any stale gradients.
                    lora.A.requires_grad = False
                    lora.B.requires_grad = False
                    if lora.A.grad is not None:
                        del lora.A.grad
                    if lora.B.grad is not None:
                        del lora.B.grad
                else:
                    lora.A.requires_grad = True
                    lora.B.requires_grad = True
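

# --- Minimal usage sketch (illustrative addition, not part of the original
# file). Assumes a toy nn.Sequential stand-in for a real model; it shows the
# intended workflow: freeze the base weights, wrap the linears with two
# adapters, then switch between them with lora_true/lora_false.
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    for param in model.parameters():
        param.requires_grad = False  # freeze the base weights first

    # Wrap every nn.Linear with two rank-4 LoRA adapters.
    replace_linear_with_lora(model, rank=4, alpha=1., number_of_lora=2)

    x = torch.randn(8, 16)

    lora_true(model, lora_idx=0)   # train/evaluate with adapter 0
    out_adapter0 = model(x)

    lora_true(model, lora_idx=1)   # switch to adapter 1
    out_adapter1 = model(x)

    lora_false(model)              # plain base model, no adapter
    out_base = model(x)

    # B is zero-initialized, so all three outputs coincide before training.
    print(torch.allclose(out_adapter0, out_base),
          torch.allclose(out_adapter1, out_base))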