amildravid4292 committed on
Commit
bc19114
·
verified ·
1 Parent(s): bf73861

Update lora_w2w.py

Browse files
Files changed (1) hide show
  1. lora_w2w.py +4 -5
lora_w2w.py CHANGED
@@ -82,8 +82,8 @@ class LoRAModule(nn.Module):
82
  if type(alpha) == torch.Tensor:
83
  alpha = alpha.detach().numpy()
84
  alpha = lora_dim if alpha is None or alpha == 0 else alpha
85
- self.scale = alpha / self.lora_dim
86
- self.scale = self.scale.bfloat16()
87
 
88
 
89
  self.multiplier = multiplier
@@ -95,11 +95,10 @@ class LoRAModule(nn.Module):
95
  del self.org_module
96
 
97
  def forward(self, x):
98
- print(self.org_forward(x).dtype)
99
- print((x@(([email protected])*self.std1+self.mean1).T).dtype)
100
 
101
  return self.org_forward(x) +\
102
- (x@(([email protected])*self.std1+self.mean1).T)@((([email protected])*self.std2+self.mean2))*self.multiplier*self.scale
103
 
104
 
105
 
 
82
  if type(alpha) == torch.Tensor:
83
  alpha = alpha.detach().numpy()
84
  alpha = lora_dim if alpha is None or alpha == 0 else alpha
85
+ # self.scale = alpha / self.lora_dim
86
+ # self.scale = self.scale.bfloat16()
87
 
88
 
89
  self.multiplier = multiplier
 
95
  del self.org_module
96
 
97
  def forward(self, x):
98
+
 
99
 
100
  return self.org_forward(x) +\
101
+ (x@(([email protected])*self.std1+self.mean1).T)@((([email protected])*self.std2+self.mean2))#*self.multiplier*self.scale
102
 
103
 
104