Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse filestest on .from_pretrained
app.py
CHANGED
@@ -34,6 +34,7 @@ def load_dit_model(dit_size):
|
|
34 |
# Configure model based on size
|
35 |
if dit_size == "S":
|
36 |
model = DiT(num_blocks=8, hidden_size=384, num_heads=6)
|
|
|
37 |
elif dit_size == "B":
|
38 |
model = DiT(num_blocks=12, hidden_size=640, num_heads=10)
|
39 |
elif dit_size == "L":
|
@@ -42,8 +43,8 @@ def load_dit_model(dit_size):
|
|
42 |
raise ValueError(f"Invalid DiT size: {dit_size}")
|
43 |
|
44 |
# Load checkpoint
|
45 |
-
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
46 |
-
model.load_state_dict(checkpoint["model_state_dict"])
|
47 |
|
48 |
return model
|
49 |
|
|
|
34 |
# Configure model based on size
|
35 |
if dit_size == "S":
|
36 |
model = DiT(num_blocks=8, hidden_size=384, num_heads=6)
|
37 |
+
model.from_pretrained("kaupane/DiT-Wikiart-Small")
|
38 |
elif dit_size == "B":
|
39 |
model = DiT(num_blocks=12, hidden_size=640, num_heads=10)
|
40 |
elif dit_size == "L":
|
|
|
43 |
raise ValueError(f"Invalid DiT size: {dit_size}")
|
44 |
|
45 |
# Load checkpoint
|
46 |
+
#checkpoint = torch.load(ckpt_path, map_location="cpu")
|
47 |
+
#model.load_state_dict(checkpoint["model_state_dict"])
|
48 |
|
49 |
return model
|
50 |
|