import torch
import torch.nn as nn


class LoRALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        # Scale A's initialization by 1/sqrt(rank) to keep the update's
        # magnitude roughly independent of the chosen rank.
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        # A starts with small random values and B starts at zero, so the
        # low-rank update A @ B is zero at initialization.
        self.A = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        # Low-rank update scaled by alpha; output shape is (..., out_dim).
        return self.alpha * (x @ self.A @ self.B)
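
# A minimal sanity check (illustrative shapes, not part of the original
# module): since B is zero-initialized, a fresh LoRALayer outputs zeros,
# so adding it to a base layer leaves the model unchanged at step 0.
_layer = LoRALayer(in_dim=16, out_dim=8, rank=4, alpha=1.0)
_x = torch.randn(2, 16)
_out = _layer(_x)
assert _out.shape == (2, 8)
assert torch.allclose(_out, torch.zeros_like(_out))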
class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha,
                 weak_lora_alpha=0.1, number_of_lora=1):
        super().__init__()
        self.linear = linear
        # One LoRA adapter per entry; lora_idx picks the active adapter.
        # Note: weak_lora_alpha is accepted but not used in this class.
        self.lora = nn.ModuleList([
            LoRALayer(linear.in_features, linear.out_features, rank, alpha)
            for _ in range(number_of_lora)
        ])
        self.use_lora = True
        self.lora_idx = 0

    def forward(self, x):
        if self.use_lora:
            return self.linear(x) + self.lora[self.lora_idx](x)
        return self.linear(x)
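
# Illustrative check (hypothetical dimensions): wrapping an nn.Linear is a
# no-op at initialization because every adapter's B matrix is zero, and
# switching lora_idx only matters once an adapter has been trained.
_base = nn.Linear(16, 8)
_wrapped = LinearWithLoRA(_base, rank=4, alpha=1.0, number_of_lora=2)
_x = torch.randn(2, 16)
assert torch.allclose(_wrapped(_x), _base(_x))
_wrapped.lora_idx = 1
assert torch.allclose(_wrapped(_x), _base(_x))  # adapter 1 is also untrained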
def replace_linear_with_lora(module, rank=64, alpha=1., tag=0,
                             weak_lora_alpha=0.1, number_of_lora=1):
    # Recursively swap every nn.Linear in the module tree for a
    # LinearWithLoRA wrapper around the original layer (tag is unused).
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, LinearWithLoRA(
                child, rank, alpha,
                weak_lora_alpha=weak_lora_alpha, number_of_lora=number_of_lora))
        else:
            replace_linear_with_lora(child, rank, alpha, tag,
                weak_lora_alpha=weak_lora_alpha, number_of_lora=number_of_lora)
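
# Typical LoRA training freezes the base weights and trains only the
# adapters; replace_linear_with_lora does not do this itself, so the
# freeze is sketched here on a hypothetical toy model.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
replace_linear_with_lora(model, rank=4, alpha=1.0, number_of_lora=2)

for _name, _param in model.named_parameters():
    # After replacement, adapter parameters live under a ".lora." submodule.
    _param.requires_grad = ".lora." in _name

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable}")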
def lora_false(model, lora_idx=0):
    # Disable the LoRA path everywhere: the wrapped layers fall back to
    # their base nn.Linear in forward().
    for name, module in model.named_modules():
        if isinstance(module, LinearWithLoRA):
            module.use_lora = False
            module.lora_idx = lora_idx
def lora_true(model, lora_idx=0):
    # Enable the adapter at lora_idx and make it the only trainable one:
    # every other adapter is frozen and its stale gradients are dropped.
    for name, module in model.named_modules():
        if isinstance(module, LinearWithLoRA):
            module.use_lora = True
            module.lora_idx = lora_idx
            for i, lora in enumerate(module.lora):
                if i != lora_idx:
                    lora.A.requires_grad = False
                    lora.B.requires_grad = False
                    if lora.A.grad is not None:
                        del lora.A.grad
                    if lora.B.grad is not None:
                        del lora.B.grad
                else:
                    lora.A.requires_grad = True
                    lora.B.requires_grad = True
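
# Usage sketch (continuing the hypothetical toy model above): activate one
# adapter per task, then bypass all adapters to recover the base model.
lora_true(model, lora_idx=0)
# ... optimizer steps for task A update only adapter-0 weights ...

lora_true(model, lora_idx=1)
# ... optimizer steps for task B update only adapter-1 weights ...

lora_false(model)  # inference with the unmodified base model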