import torch.nn as nn
from .fairseq_dropout import FairseqDropout
class Lora(nn.Module):
    """Low-rank adapter (LoRA) block: project down to rank ``r``, then back up.

    The down-projection (``linear_u``) and up-projection (``linear_v``) can be
    shared across layers by passing pre-built ``shared_u`` / ``shared_v``
    linear modules.
    """

    def __init__(self, u_dim, r, v_dim, shared_u=None, shared_v=None, dropout=0.0):
        super().__init__()
        self.u_dim = u_dim
        self.d_dim = r  # low-rank (bottleneck) dimension
        self.v_dim = v_dim
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        # Down-projection: u_dim -> r (reuse the shared module if provided).
        if shared_u is not None:
            assert shared_u.weight.size() == (r, u_dim)
            self.linear_u = shared_u
        else:
            self.linear_u = nn.Linear(u_dim, r)
        # Up-projection: r -> v_dim (reuse the shared module if provided).
        if shared_v is not None:
            assert shared_v.weight.size() == (v_dim, r)
            self.linear_v = shared_v
        else:
            self.linear_v = nn.Linear(r, v_dim)

    def forward(self, x):
        # Low-rank path: down-project, dropout, up-project, dropout.
        x = self.linear_u(x)
        x = self.dropout_module(x)
        x = self.linear_v(x)
        x = self.dropout_module(x)
        return x

    @classmethod
    def ratio_r(cls, layerid):
        # Per-layer rank ratio for a 6-layer stack: 2, 1.5, 1, 0.5, 0.5, 0.5
        # (larger ranks in the lower layers, a constant 0.5 from layer 3 on).
        if layerid < 4:
            return 2 - 0.5 * layerid
        else:
            return 0.5
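

# Usage sketch (illustrative only; the dimensions, the frozen projection, and
# the per-layer rank formula below are assumptions, not part of this module):
#
#   import torch
#
#   lora = Lora(u_dim=512, r=8, v_dim=512, dropout=0.1)
#   x = torch.randn(10, 4, 512)   # (seq_len, batch, u_dim)
#   delta = lora(x)               # (seq_len, batch, v_dim)
#
#   # In a typical LoRA setup the low-rank update is added to the output of a
#   # frozen projection with matching dimensions:
#   #   out = frozen_linear(x) + delta
#
#   # ratio_r presumably scales a base rank per layer, e.g.:
#   #   r_for_layer = max(1, int(base_r * Lora.ratio_r(layer_id)))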