YiftachEde commited on
Commit
12a6a7a
·
1 Parent(s): 6b12b83
Files changed (2) hide show
  1. app.py +10 -9
  2. zero123plus/pipeline.py +1 -1
app.py CHANGED
@@ -58,16 +58,17 @@ def load_models():
58
 
59
  # Load custom UNet
60
  print('Loading custom UNet...')
61
- unet_path = "best_21.ckpt"
62
- state_dict = torch.load(unet_path, map_location='cpu')
63
 
64
- # Process the state dict to match the model keys
65
- if 'state_dict' in state_dict:
66
- new_state_dict = {key.replace('unet.unet.', ''): value for key, value in state_dict['state_dict'].items()}
67
- pipeline.unet.load_state_dict(new_state_dict, strict=False)
68
- else:
69
- pipeline.unet.load_state_dict(state_dict, strict=False)
70
-
 
71
  pipeline = pipeline.to(device).to(torch_dtype=torch.float16)
72
 
73
  # Load reconstruction model
 
58
 
59
  # Load custom UNet
60
  print('Loading custom UNet...')
61
+ # unet_path = "best_21.ckpt"
62
+ # state_dict = torch.load(unet_path, map_location='cpu')
63
 
64
+ # # Process the state dict to match the model keys
65
+ # if 'state_dict' in state_dict:
66
+ # new_state_dict = {key.replace('unet.unet.', ''): value for key, value in state_dict['state_dict'].items()}
67
+ # pipeline.unet.load_state_dict(new_state_dict, strict=False)
68
+ # else:
69
+ # pipeline.unet.load_state_dict(state_dict, strict=False)
70
+ # pipeline.unet.push_to_hub("YiftachEde/Sharp-It")
71
+ pipeline.unet = pipeline.unet.from_pretrained("YiftachEde/Sharp-It").to(torch.float16)
72
  pipeline = pipeline.to(device).to(torch_dtype=torch.float16)
73
 
74
  # Load reconstruction model
zero123plus/pipeline.py CHANGED
@@ -974,7 +974,7 @@ class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
974
  latent_model_input, t
975
  )
976
  latent_model_input = torch.cat([latent_model_input, cond_latent], dim=1)
977
-
978
  # predict the noise residual
979
  noise_pred = self.unet(
980
  latent_model_input,
 
974
  latent_model_input, t
975
  )
976
  latent_model_input = torch.cat([latent_model_input, cond_latent], dim=1)
977
+ # latent_model_input = latent_model_input.half()
978
  # predict the noise residual
979
  noise_pred = self.unet(
980
  latent_model_input,