Spaces:
Sleeping
Sleeping
Update models/DiT.py
Browse filesAdded PyTorchModelHubMixin to allow from_pretrained method
- models/DiT.py +3 -1
models/DiT.py
CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
|
|
3 |
import torch.nn.functional as F
|
4 |
import math
|
5 |
from timm.models.vision_transformer import PatchEmbed
|
|
|
6 |
|
7 |
class TimestepEmbedder(nn.Module):
|
8 |
"""Module to create timestep's embedding."""
|
@@ -65,7 +66,8 @@ class DiTBlock(nn.Module):
|
|
65 |
x = x * (1+gamma_2.unsqueeze(1)) + beta_2.unsqueeze(1)
|
66 |
return x
|
67 |
|
68 |
-
class DiT(nn.Module
|
|
|
69 |
def __init__(self,
|
70 |
num_blocks=10,
|
71 |
hidden_size=640,
|
|
|
3 |
import torch.nn.functional as F
|
4 |
import math
|
5 |
from timm.models.vision_transformer import PatchEmbed
|
6 |
+
from huggingface_hub import PyTorchModelHubMixin
|
7 |
|
8 |
class TimestepEmbedder(nn.Module):
|
9 |
"""Module to create timestep's embedding."""
|
|
|
66 |
x = x * (1+gamma_2.unsqueeze(1)) + beta_2.unsqueeze(1)
|
67 |
return x
|
68 |
|
69 |
+
class DiT(nn.Module,
|
70 |
+
PyTorchModelHubMixin):
|
71 |
def __init__(self,
|
72 |
num_blocks=10,
|
73 |
hidden_size=640,
|