amildravid4292 commited on
Commit
bf73861
·
verified ·
1 Parent(s): d76e471

Update lora_w2w.py

Browse files
Files changed (1) hide show
  1. lora_w2w.py +2 -1
lora_w2w.py CHANGED
@@ -83,7 +83,7 @@ class LoRAModule(nn.Module):
83
  alpha = alpha.detach().numpy()
84
  alpha = lora_dim if alpha is None or alpha == 0 else alpha
85
  self.scale = alpha / self.lora_dim
86
- #self.scale = self.scale.bfloat16()
87
 
88
 
89
  self.multiplier = multiplier
@@ -96,6 +96,7 @@ class LoRAModule(nn.Module):
96
 
97
  def forward(self, x):
98
  print(self.org_forward(x).dtype)
 
99
 
100
  return self.org_forward(x) +\
101
  (x@(([email protected])*self.std1+self.mean1).T)@((([email protected])*self.std2+self.mean2))*self.multiplier*self.scale
 
83
  alpha = alpha.detach().numpy()
84
  alpha = lora_dim if alpha is None or alpha == 0 else alpha
85
  self.scale = alpha / self.lora_dim
86
+ self.scale = self.scale.bfloat16()
87
 
88
 
89
  self.multiplier = multiplier
 
96
 
97
  def forward(self, x):
98
  print(self.org_forward(x).dtype)
99
+ print((x@(([email protected])*self.std1+self.mean1).T).dtype)
100
 
101
  return self.org_forward(x) +\
102
  (x@(([email protected])*self.std1+self.mean1).T)@((([email protected])*self.std2+self.mean2))*self.multiplier*self.scale