kaupane commited on
Commit
a7df152
·
verified ·
1 Parent(s): 2e49ba4

Update models/DiT.py

Browse files

Added PyTorchModelHubMixin to allow from_pretrained method

Files changed (1) hide show
  1. models/DiT.py +3 -1
models/DiT.py CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
3
  import torch.nn.functional as F
4
  import math
5
  from timm.models.vision_transformer import PatchEmbed
 
6
 
7
  class TimestepEmbedder(nn.Module):
8
  """Module to create timestep's embedding."""
@@ -65,7 +66,8 @@ class DiTBlock(nn.Module):
65
  x = x * (1+gamma_2.unsqueeze(1)) + beta_2.unsqueeze(1)
66
  return x
67
 
68
- class DiT(nn.Module):
 
69
  def __init__(self,
70
  num_blocks=10,
71
  hidden_size=640,
 
3
  import torch.nn.functional as F
4
  import math
5
  from timm.models.vision_transformer import PatchEmbed
6
+ from huggingface_hub import PyTorchModelHubMixin
7
 
8
  class TimestepEmbedder(nn.Module):
9
  """Module to create timestep's embedding."""
 
66
  x = x * (1+gamma_2.unsqueeze(1)) + beta_2.unsqueeze(1)
67
  return x
68
 
69
+ class DiT(nn.Module,
70
+ PyTorchModelHubMixin):
71
  def __init__(self,
72
  num_blocks=10,
73
  hidden_size=640,