Update model.py
model.py CHANGED
@@ -1672,153 +1672,4 @@ class UNet(nn.Module):
             out
         )  # (batch_size, self.conv_out_channels, h, w) -> (batch_size, image_channels, h, w)
 
-        return out  # (batch_size, image_channels, h, w)
-
-
-def sample_ddpm_inference(
-    unet,
-    vae,
-    text_prompt,
-    mask_image_pil=None,
-    guidance_scale=1.0,
-    device=torch.device("cpu"),
-):
-    """
-    Given a text prompt and (optionally) an image condition (as a PIL image),
-    sample from the diffusion model and return a generated image (PIL image).
-    """
-    # Create noise scheduler
-    scheduler = LinearNoiseScheduler(
-        num_timesteps=diffusion_params["num_timesteps"],
-        beta_start=diffusion_params["beta_start"],
-        beta_end=diffusion_params["beta_end"],
-    )
-    # Get conditioning config from ldm_params
-    condition_config = ldm_params.get("condition_config", None)
-    condition_types = (
-        condition_config.get("condition_types", [])
-        if condition_config is not None
-        else []
-    )
-
-    # Load text tokenizer/model for conditioning
-    text_model_type = condition_config["text_condition_config"]["text_embed_model"]
-    text_tokenizer, text_model = get_tokenizer_and_model(text_model_type, device=device)
-
-    # Get empty text representation for classifier-free guidance
-    empty_text_embed = get_text_representation([""], text_tokenizer, text_model, device)
-
-    # Get text representation of the input prompt
-    text_prompt_embed = get_text_representation(
-        [text_prompt], text_tokenizer, text_model, device
-    )
-
-    # Prepare image conditioning:
-    # If the user uploaded a mask image (should be a PIL image), convert it; otherwise, use zeros.
-    if "image" in condition_types:
-        if mask_image_pil is not None:
-            mask_transform = transforms.Compose(
-                [
-                    transforms.Resize(
-                        (
-                            ldm_params["condition_config"]["image_condition_config"][
-                                "image_condition_h"
-                            ],
-                            ldm_params["condition_config"]["image_condition_config"][
-                                "image_condition_w"
-                            ],
-                        )
-                    ),
-                    transforms.ToTensor(),
-                ]
-            )
-            mask_tensor = (
-                mask_transform(mask_image_pil).unsqueeze(0).to(device)
-            )  # (1, channels, H, W)
-        else:
-            # Create a zero mask with the required number of channels (e.g. 18)
-            ic = ldm_params["condition_config"]["image_condition_config"][
-                "image_condition_input_channels"
-            ]
-            H = ldm_params["condition_config"]["image_condition_config"][
-                "image_condition_h"
-            ]
-            W = ldm_params["condition_config"]["image_condition_config"][
-                "image_condition_w"
-            ]
-            mask_tensor = torch.zeros((1, ic, H, W), device=device)
-    else:
-        mask_tensor = None
-
-    # Build conditioning dictionaries for classifier-free guidance:
-    # For unconditional, we use empty text and zero mask.
-    uncond_input = {}
-    cond_input = {}
-    if "text" in condition_types:
-        uncond_input["text"] = empty_text_embed
-        cond_input["text"] = text_prompt_embed
-    if "image" in condition_types:
-        # Use zeros for unconditioning, and the provided mask for conditioning.
-        uncond_input["image"] = torch.zeros_like(mask_tensor)
-        cond_input["image"] = mask_tensor
-
-    # Load the diffusion UNet (and assume it has been pretrained and saved)
-    # unet = UNet(
-    #     image_channels=autoencoder_params["z_channels"], model_config=ldm_params
-    # ).to(device)
-    # ldm_checkpoint_path = os.path.join(
-    #     train_params["task_name"], train_params["ldm_ckpt_name"]
-    # )
-    # if os.path.exists(ldm_checkpoint_path):
-    #     checkpoint = torch.load(ldm_checkpoint_path, map_location=device)
-    #     unet.load_state_dict(checkpoint["model_state_dict"])
-    # unet.eval()
-
-    # Load VQVAE (assume pretrained and saved)
-    # vae = VQVAE(
-    #     image_channels=dataset_params["image_channels"], model_config=autoencoder_params
-    # ).to(device)
-    # vae_checkpoint_path = os.path.join(
-    #     train_params["task_name"], train_params["vqvae_autoencoder_ckpt_name"]
-    # )
-    # if os.path.exists(vae_checkpoint_path):
-    #     checkpoint = torch.load(vae_checkpoint_path, map_location=device)
-    #     vae.load_state_dict(checkpoint["model_state_dict"])
-    # vae.eval()
-
-    # Determine latent shape from VQVAE: (batch, z_channels, H_lat, W_lat)
-    # For example, if image_size is 256 and there are 3 downsamplings, H_lat = 256 // 8 = 32.
-    latent_size = dataset_params["image_size"] // (
-        2 ** sum(autoencoder_params["down_sample"])
-    )
-    batch = train_params["num_samples"]
-    z_channels = autoencoder_params["z_channels"]
-
-    # Sample initial latent noise
-    xt = torch.randn((batch, z_channels, latent_size, latent_size), device=device)
-
-    # Sampling loop (reverse diffusion)
-    T = diffusion_params["num_timesteps"]
-    for i in reversed(range(T)):
-        t = torch.full((batch,), i, dtype=torch.long, device=device)
-        # Get conditional noise prediction
-        noise_pred_cond = unet(xt, t, cond_input)
-        if guidance_scale > 1:
-            noise_pred_uncond = unet(xt, t, uncond_input)
-            noise_pred = noise_pred_uncond + guidance_scale * (
-                noise_pred_cond - noise_pred_uncond
-            )
-        else:
-            noise_pred = noise_pred_cond
-        xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, t)
-
-        with torch.no_grad():
-            generated = vae.decode(xt)
-
-        generated = torch.clamp(generated, -1, 1)
-        generated = (generated + 1) / 2  # scale to [0,1]
-        grid = make_grid(generated, nrow=1)
-        pil_img = transforms.ToPILImage()(grid.cpu())
-
-        if i % 10 == 0:
-            yield pil_img
+        return out  # (batch_size, image_channels, h, w)
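
For context, the removed sample_ddpm_inference is a generator: it runs the reverse-diffusion loop with classifier-free guidance and yields an intermediate PIL preview every 10 timesteps, with the final yield (at timestep 0) being the fully denoised sample. Below is a minimal sketch of how such a generator might be driven; the pretrained unet and vae objects, the prompt, the guidance scale, and the output path are illustrative assumptions (the function also relies on the module-level diffusion_params/ldm_params/dataset_params/autoencoder_params/train_params dictionaries being defined), not part of this diff.

# Hypothetical driver for the removed generator. Assumes `unet`, `vae`,
# and the module-level config dictionaries are already loaded elsewhere.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

last_preview = None
for preview in sample_ddpm_inference(
    unet,
    vae,
    text_prompt="A portrait photo of a person with black hair",  # illustrative prompt
    mask_image_pil=None,   # no spatial mask -> zero-mask image conditioning
    guidance_scale=7.5,    # >1 enables classifier-free guidance
    device=device,
):
    last_preview = preview  # PIL.Image yielded every 10 timesteps

if last_preview is not None:
    last_preview.save("sample.png")  # hypothetical output path

Because the generator streams previews, a UI layer (for example a Gradio callback) could display each yielded image as denoising progresses instead of waiting for the final sample.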