kaupane commited on
Commit
23f7c2b
·
verified ·
1 Parent(s): e8b2f0e

Update models/DiT.py

Browse files

Adjusted TimestepEmbedder.forward to fix tensor dtype inconsistency

Files changed (1) hide show
  1. models/DiT.py +5 -1
models/DiT.py CHANGED
@@ -22,7 +22,11 @@ class TimestepEmbedder(nn.Module):
22
  ).to(device=t.device)
23
  args = torch.einsum('i,j->ij', t, freqs.to(t.device))
24
  freqs = torch.cat([torch.cos(args),torch.sin(args)],dim=-1)
25
- return self.mlp(freqs)
 
 
 
 
26
 
27
  class ViTAttn(nn.Module):
28
  def __init__(self,hidden_size,num_heads):
 
22
  ).to(device=t.device)
23
  args = torch.einsum('i,j->ij', t, freqs.to(t.device))
24
  freqs = torch.cat([torch.cos(args),torch.sin(args)],dim=-1)
25
+
26
+ mlp_dtype = next(self.mlp.parameters()).dtype
27
+ freqs_casted = freqs.to(mlp_dtype)
28
+
29
+ return self.mlp(freqs_casted)
30
 
31
  class ViTAttn(nn.Module):
32
  def __init__(self,hidden_size,num_heads):