RishabA committed on
Commit 9835792 · verified · 1 Parent(s): b2cf8f4

Update model.py

Files changed (1)
  1. model.py +1 -150
model.py CHANGED
@@ -1672,153 +1672,4 @@ class UNet(nn.Module):
             out
         )  # (batch_size, self.conv_out_channels, h, w) -> (batch_size, image_channels, h, w)
 
-        return out  # (batch_size, image_channels, h, w)
-
-
-def sample_ddpm_inference(
-    unet,
-    vae,
-    text_prompt,
-    mask_image_pil=None,
-    guidance_scale=1.0,
-    device=torch.device("cpu"),
-):
-    """
-    Given a text prompt and (optionally) an image condition (as a PIL image),
-    sample from the diffusion model and return a generated image (PIL image).
-    """
-    # Create noise scheduler
-    scheduler = LinearNoiseScheduler(
-        num_timesteps=diffusion_params["num_timesteps"],
-        beta_start=diffusion_params["beta_start"],
-        beta_end=diffusion_params["beta_end"],
-    )
-    # Get conditioning config from ldm_params
-    condition_config = ldm_params.get("condition_config", None)
-    condition_types = (
-        condition_config.get("condition_types", [])
-        if condition_config is not None
-        else []
-    )
-
-    # Load text tokenizer/model for conditioning
-    text_model_type = condition_config["text_condition_config"]["text_embed_model"]
-    text_tokenizer, text_model = get_tokenizer_and_model(text_model_type, device=device)
-
-    # Get empty text representation for classifier-free guidance
-    empty_text_embed = get_text_representation([""], text_tokenizer, text_model, device)
-
-    # Get text representation of the input prompt
-    text_prompt_embed = get_text_representation(
-        [text_prompt], text_tokenizer, text_model, device
-    )
-
-    # Prepare image conditioning:
-    # If the user uploaded a mask image (should be a PIL image), convert it; otherwise, use zeros.
-    if "image" in condition_types:
-        if mask_image_pil is not None:
-            mask_transform = transforms.Compose(
-                [
-                    transforms.Resize(
-                        (
-                            ldm_params["condition_config"]["image_condition_config"][
-                                "image_condition_h"
-                            ],
-                            ldm_params["condition_config"]["image_condition_config"][
-                                "image_condition_w"
-                            ],
-                        )
-                    ),
-                    transforms.ToTensor(),
-                ]
-            )
-            mask_tensor = (
-                mask_transform(mask_image_pil).unsqueeze(0).to(device)
-            )  # (1, channels, H, W)
-        else:
-            # Create a zero mask with the required number of channels (e.g. 18)
-            ic = ldm_params["condition_config"]["image_condition_config"][
-                "image_condition_input_channels"
-            ]
-            H = ldm_params["condition_config"]["image_condition_config"][
-                "image_condition_h"
-            ]
-            W = ldm_params["condition_config"]["image_condition_config"][
-                "image_condition_w"
-            ]
-            mask_tensor = torch.zeros((1, ic, H, W), device=device)
-    else:
-        mask_tensor = None
-
-    # Build conditioning dictionaries for classifier-free guidance:
-    # For unconditional, we use empty text and zero mask.
-    uncond_input = {}
-    cond_input = {}
-    if "text" in condition_types:
-        uncond_input["text"] = empty_text_embed
-        cond_input["text"] = text_prompt_embed
-    if "image" in condition_types:
-        # Use zeros for unconditioning, and the provided mask for conditioning.
-        uncond_input["image"] = torch.zeros_like(mask_tensor)
-        cond_input["image"] = mask_tensor
-
-    # Load the diffusion UNet (and assume it has been pretrained and saved)
-    # unet = UNet(
-    #     image_channels=autoencoder_params["z_channels"], model_config=ldm_params
-    # ).to(device)
-    # ldm_checkpoint_path = os.path.join(
-    #     train_params["task_name"], train_params["ldm_ckpt_name"]
-    # )
-    # if os.path.exists(ldm_checkpoint_path):
-    #     checkpoint = torch.load(ldm_checkpoint_path, map_location=device)
-    #     unet.load_state_dict(checkpoint["model_state_dict"])
-    # unet.eval()
-
-    # Load VQVAE (assume pretrained and saved)
-    # vae = VQVAE(
-    #     image_channels=dataset_params["image_channels"], model_config=autoencoder_params
-    # ).to(device)
-    # vae_checkpoint_path = os.path.join(
-    #     train_params["task_name"], train_params["vqvae_autoencoder_ckpt_name"]
-    # )
-    # if os.path.exists(vae_checkpoint_path):
-    #     checkpoint = torch.load(vae_checkpoint_path, map_location=device)
-    #     vae.load_state_dict(checkpoint["model_state_dict"])
-    # vae.eval()
-
-    # Determine latent shape from VQVAE: (batch, z_channels, H_lat, W_lat)
-    # For example, if image_size is 256 and there are 3 downsamplings, H_lat = 256 // 8 = 32.
-    latent_size = dataset_params["image_size"] // (
-        2 ** sum(autoencoder_params["down_sample"])
-    )
-    batch = train_params["num_samples"]
-    z_channels = autoencoder_params["z_channels"]
-
-    # Sample initial latent noise
-    xt = torch.randn((batch, z_channels, latent_size, latent_size), device=device)
-
-    # Sampling loop (reverse diffusion)
-    T = diffusion_params["num_timesteps"]
-    for i in reversed(range(T)):
-        t = torch.full((batch,), i, dtype=torch.long, device=device)
-        # Get conditional noise prediction
-        noise_pred_cond = unet(xt, t, cond_input)
-        if guidance_scale > 1:
-            noise_pred_uncond = unet(xt, t, uncond_input)
-            noise_pred = noise_pred_uncond + guidance_scale * (
-                noise_pred_cond - noise_pred_uncond
-            )
-        else:
-            noise_pred = noise_pred_cond
-        xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, t)
-
-        with torch.no_grad():
-            generated = vae.decode(xt)
-
-        generated = torch.clamp(generated, -1, 1)
-        generated = (generated + 1) / 2  # scale to [0,1]
-        grid = make_grid(generated, nrow=1)
-        pil_img = transforms.ToPILImage()(grid.cpu())
-
-        if i % 10 == 0:
-            yield pil_img
+        return out  # (batch_size, image_channels, h, w)
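
For reference, the removed sample_ddpm_inference is a generator: it yields intermediate PIL images during the reverse-diffusion loop rather than returning a single result. Below is a minimal caller sketch, assuming unet and vae have already been instantiated and loaded from their checkpoints (as in the commented-out loading code in the removed function) and that the module-level config dicts (diffusion_params, ldm_params, autoencoder_params, dataset_params, train_params) are in scope; the prompt and guidance scale are illustrative values, not ones used in this repo.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `unet` and `vae` are assumed to be pretrained models already moved to `device`,
# e.g. restored from the checkpoints referenced in the commented-out loading code.
frames = []
for pil_img in sample_ddpm_inference(
    unet,
    vae,
    text_prompt="a smiling person with blond hair",  # illustrative prompt
    mask_image_pil=None,     # no mask supplied; zeros are used for image conditioning
    guidance_scale=7.5,      # illustrative value; > 1 enables classifier-free guidance
    device=device,
):
    frames.append(pil_img)   # intermediate previews yielded during sampling

frames[-1].save("sample.png")  # the last yielded frame is the most denoised sample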