Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
12a6a7a
1
Parent(s):
6b12b83
add all
Browse files- app.py +10 -9
- zero123plus/pipeline.py +1 -1
app.py
CHANGED
@@ -58,16 +58,17 @@ def load_models():
|
|
58 |
|
59 |
# Load custom UNet
|
60 |
print('Loading custom UNet...')
|
61 |
-
unet_path = "best_21.ckpt"
|
62 |
-
state_dict = torch.load(unet_path, map_location='cpu')
|
63 |
|
64 |
-
# Process the state dict to match the model keys
|
65 |
-
if 'state_dict' in state_dict:
|
66 |
-
|
67 |
-
|
68 |
-
else:
|
69 |
-
|
70 |
-
|
|
|
71 |
pipeline = pipeline.to(device).to(torch_dtype=torch.float16)
|
72 |
|
73 |
# Load reconstruction model
|
|
|
58 |
|
59 |
# Load custom UNet
|
60 |
print('Loading custom UNet...')
|
61 |
+
# unet_path = "best_21.ckpt"
|
62 |
+
# state_dict = torch.load(unet_path, map_location='cpu')
|
63 |
|
64 |
+
# # Process the state dict to match the model keys
|
65 |
+
# if 'state_dict' in state_dict:
|
66 |
+
# new_state_dict = {key.replace('unet.unet.', ''): value for key, value in state_dict['state_dict'].items()}
|
67 |
+
# pipeline.unet.load_state_dict(new_state_dict, strict=False)
|
68 |
+
# else:
|
69 |
+
# pipeline.unet.load_state_dict(state_dict, strict=False)
|
70 |
+
# pipeline.unet.push_to_hub("YiftachEde/Sharp-It")
|
71 |
+
pipeline.unet = pipeline.unet.from_pretrained("YiftachEde/Sharp-It").to(torch.float16)
|
72 |
pipeline = pipeline.to(device).to(torch_dtype=torch.float16)
|
73 |
|
74 |
# Load reconstruction model
|
zero123plus/pipeline.py
CHANGED
@@ -974,7 +974,7 @@ class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
|
|
974 |
latent_model_input, t
|
975 |
)
|
976 |
latent_model_input = torch.cat([latent_model_input, cond_latent], dim=1)
|
977 |
-
|
978 |
# predict the noise residual
|
979 |
noise_pred = self.unet(
|
980 |
latent_model_input,
|
|
|
974 |
latent_model_input, t
|
975 |
)
|
976 |
latent_model_input = torch.cat([latent_model_input, cond_latent], dim=1)
|
977 |
+
# latent_model_input = latent_model_input.half()
|
978 |
# predict the noise residual
|
979 |
noise_pred = self.unet(
|
980 |
latent_model_input,
|