autodecoder
Browse files
app.py
CHANGED
@@ -135,21 +135,21 @@ clip = FrozenCLIPEmbedder()
|
|
135 |
clip.eval()
|
136 |
clip.to(device)
|
137 |
|
138 |
-
#
|
139 |
-
|
140 |
-
|
141 |
|
142 |
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
|
148 |
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
|
154 |
|
155 |
@spaces.GPU #[uncomment to use ZeroGPU]
|
|
|
135 |
clip.eval()
|
136 |
clip.to(device)
|
137 |
|
138 |
+
# Load autoencoder.
|
139 |
+
autoencoder = libs.autoencoder.get_model(**config_1.autoencoder)
|
140 |
+
autoencoder.to(device)
|
141 |
|
142 |
|
143 |
+
@torch.cuda.amp.autocast()
|
144 |
+
def encode(_batch: torch.Tensor) -> torch.Tensor:
|
145 |
+
"""Encode a batch of images using the autoencoder."""
|
146 |
+
return autoencoder.encode(_batch)
|
147 |
|
148 |
|
149 |
+
@torch.cuda.amp.autocast()
|
150 |
+
def decode(_batch: torch.Tensor) -> torch.Tensor:
|
151 |
+
"""Decode a batch of latent vectors using the autoencoder."""
|
152 |
+
return autoencoder.decode(_batch)
|
153 |
|
154 |
|
155 |
@spaces.GPU #[uncomment to use ZeroGPU]
|