kaupane commited on
Commit
2e49ba4
·
verified ·
1 Parent(s): 9ad5c03

Update app.py

Browse files

test on .from_pretrained

Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -34,6 +34,7 @@ def load_dit_model(dit_size):
34
  # Configure model based on size
35
  if dit_size == "S":
36
  model = DiT(num_blocks=8, hidden_size=384, num_heads=6)
 
37
  elif dit_size == "B":
38
  model = DiT(num_blocks=12, hidden_size=640, num_heads=10)
39
  elif dit_size == "L":
@@ -42,8 +43,8 @@ def load_dit_model(dit_size):
42
  raise ValueError(f"Invalid DiT size: {dit_size}")
43
 
44
  # Load checkpoint
45
- checkpoint = torch.load(ckpt_path, map_location="cpu")
46
- model.load_state_dict(checkpoint["model_state_dict"])
47
 
48
  return model
49
 
 
34
  # Configure model based on size
35
  if dit_size == "S":
36
  model = DiT(num_blocks=8, hidden_size=384, num_heads=6)
37
+ model.from_pretrained("kaupane/DiT-Wikiart-Small")
38
  elif dit_size == "B":
39
  model = DiT(num_blocks=12, hidden_size=640, num_heads=10)
40
  elif dit_size == "L":
 
43
  raise ValueError(f"Invalid DiT size: {dit_size}")
44
 
45
  # Load checkpoint
46
+ #checkpoint = torch.load(ckpt_path, map_location="cpu")
47
+ #model.load_state_dict(checkpoint["model_state_dict"])
48
 
49
  return model
50