Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,696 Bytes
95a02fd 77269e5 52f9197 be66f33 52f9197 be66f33 d71875f be66f33 95a02fd be66f33 f22f03c be66f33 52f9197 95a02fd a202591 ffaef20 a202591 95a02fd ffaef20 95a02fd 6e2a03e 95a02fd 77269e5 6e2a03e 77269e5 95a02fd 77269e5 f22f03c 95a02fd f44174e f22f03c 77269e5 52f9197 cd21582 52f9197 7387897 52f9197 cd21582 7387897 77269e5 7387897 be66f33 7387897 52f9197 be66f33 7387897 be66f33 cd21582 7387897 52f9197 f1e86bd cd21582 52f9197 cd21582 d71875f bf573cf d71875f cd21582 d71875f cd21582 d71875f cd21582 7387897 77269e5 f1e86bd bf573cf e71c8e0 95a02fd d71875f bf573cf d71875f cd21582 d71875f ffaef20 77269e5 ffaef20 a202591 f1e86bd cd21582 77269e5 cd21582 d71875f 7387897 bf573cf d71875f e71c8e0 bf573cf be66f33 3c00c76 e71c8e0 3c00c76 4b4a98b 3c00c76 be66f33 3c00c76 be66f33 3c00c76 52f9197 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import json
import os
from functools import cache
from pickle import load
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from msma import (
ScoreFlow,
build_model_from_config,
build_model_from_pickle,
config_presets,
)
@cache
def load_model(modeldir, preset="edm2-img64-s-fid", device="cpu"):
    """Build a ScoreFlow model from a local preset directory.

    Reads ``{modeldir}/{preset}/config.json``, builds the score network with
    ``build_model_from_pickle(preset=preset)``, wraps it in ``ScoreFlow`` using
    the config's ``"PatchFlow"`` keyword arguments, and loads the flow weights
    from ``{modeldir}/{preset}/flow.pt``. Results are memoized per argument
    tuple.

    Args:
        modeldir: Root directory containing one sub-directory per preset.
        preset: Name of the preset sub-directory / score-network preset.
        device: Device the finished model is moved to.

    Returns:
        The ScoreFlow model on ``device``, in eval mode with gradients disabled.
    """
    modeldir = f"{modeldir}/{preset}"
    with open(f"{modeldir}/config.json", "rb") as f:
        model_params = json.load(f)

    scorenet = build_model_from_pickle(preset=preset)
    model = ScoreFlow(scorenet, **model_params["PatchFlow"])

    # Deserialize onto the CPU: a checkpoint saved from a GPU would otherwise
    # fail to load on a CPU-only host. The whole model is moved to `device`
    # at the end.
    flow_state = torch.load(f"{modeldir}/flow.pt", map_location="cpu")
    model.flow.load_state_dict(flow_state)
    print("Loaded:", model_params)

    # Inference-only model (the Hub loader does the same): eval mode, no grads.
    model = model.eval().requires_grad_(False)
    return model.to(device)
@cache
def load_model_from_hub(preset, device):
    """Download a preset from the HuggingFace Hub and build its model.

    Fetches ``config.json``, ``gmm.pkl``, ``refscores.npz`` and
    ``model.safetensors`` from the ``preset`` subfolder of the
    ``ahsanMah/localizing-edm`` repo into a local cache directory
    (``$DNNLIB_CACHE_DIR`` if set, otherwise ``/tmp/``). It then builds the
    model from the config and strictly loads the safetensors weights.
    Results are memoized per ``(preset, device)``.

    Args:
        preset: Subfolder of the Hub repo that holds the preset files.
        device: Device the model is moved to.

    Returns:
        Tuple ``(model, modeldir)``: the model in eval mode with gradients
        disabled, and the local directory holding the downloaded files.
    """
    cache_dir = os.environ.get("DNNLIB_CACHE_DIR", "/tmp/")

    required_files = ("config.json", "gmm.pkl", "refscores.npz", "model.safetensors")
    local_paths = [
        hf_hub_download(
            repo_id="ahsanMah/localizing-edm",
            subfolder=preset,
            filename=name,
            cache_dir=cache_dir,
        )
        for name in required_files
    ]
    # All files land in the same snapshot folder; use the last one's parent.
    modeldir = os.path.dirname(local_paths[-1])
    print("HF Cache Dir:", modeldir)

    with open(f"{modeldir}/config.json", "rb") as config_file:
        model_params = json.load(config_file)
    print("Loaded:", model_params)

    model = build_model_from_config(model_params)
    weights = load_file(f"{modeldir}/model.safetensors")
    model.load_state_dict(weights, strict=True)
    model = model.eval().requires_grad_(False)
    model.to(device)
    return model, modeldir
@cache
def load_reference_scores(model_dir):
    """Load the reference scores stored next to a model preset.

    Args:
        model_dir: Directory containing ``refscores.npz``.

    Returns:
        The array saved under the archive's default key ``arr_0``. The result
        is memoized per ``model_dir`` and shared between callers, so it is
        returned read-only to keep the cached copy from being mutated.
    """
    # np.load's second positional parameter is `mmap_mode`, not a file open
    # mode, so no mode argument is passed here.
    with np.load(f"{model_dir}/refscores.npz") as archive:
        ref_nll = archive["arr_0"]
    ref_nll.setflags(write=False)
    return ref_nll
@cache
def _load_gmm(model_dir):
    """Unpickle and memoize the fitted mixture model stored in ``model_dir``."""
    # SECURITY: unpickling can execute arbitrary code; gmm.pkl must only ever
    # be loaded from a trusted location.
    with open(f"{model_dir}/gmm.pkl", "rb") as f:
        return load(f)


def compute_gmm_likelihood(x_score, model_dir):
    """Score features under the preset's GMM and rank them against references.

    Args:
        x_score: Feature array passed to the mixture model's ``score`` method
            (presumably shape ``(n_samples, n_features)``; TODO confirm).
        model_dir: Preset directory containing ``gmm.pkl`` and
            ``refscores.npz``.

    Returns:
        Tuple ``(nll, percentile, ref_nll)``. ``nll`` is the negated
        ``clf.score(x_score)``. ``percentile`` is the percentage of reference
        scores strictly below ``nll``. ``ref_nll`` is the reference array.
    """
    # The GMM is loaded once per directory instead of being unpickled on
    # every request.
    clf = _load_gmm(model_dir)
    nll = -clf.score(x_score)
    ref_nll = load_reference_scores(model_dir)
    percentile = (ref_nll < nll).mean() * 100
    return nll, percentile, ref_nll
def plot_against_reference(nll, ref_nll):
    """Plot the reference-score histogram with the image's score marked.

    Args:
        nll: Score of the current image, drawn as a dashed red vertical line.
        ref_nll: Reference scores, drawn as a 25-bin histogram.

    Returns:
        A matplotlib ``Figure`` containing the plot.
    """
    # Build the figure without pyplot. pyplot keeps a global registry of open
    # figures (never freed here) and a global "current axes", and neither is
    # safe across concurrent requests in a web app.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    ax.hist(ref_nll, label="Reference Scores", bins=25)
    ax.axvline(nll, label="Image Score", c="red", ls="--")
    ax.legend()
    fig.tight_layout()
    return fig
def plot_heatmap(img: Image.Image, heatmap: np.ndarray) -> Image.Image:
    """Blend a colorized anomaly heatmap over an image.

    The map ``heatmap[0, 0]`` is processed in these steps:
    1. Negate it.
    2. Clip it to its [80th, 99.9th] quantile range.
    3. Rescale it to [0, 1].
    4. Colorize it with the ``gist_heat`` colormap.
    5. Resize it bilinearly to ``img.size``.
    6. Alpha-blend it over ``img``, with the overlay weighted 0.6.

    Args:
        img: Base image.
        heatmap: Array whose ``[0, 0]`` slice is a 2-D map (presumably
            ``(batch, channel, H, W)``; TODO confirm).

    Returns:
        The blended RGB image, the same size as ``img``.
    """
    cmap = plt.get_cmap("gist_heat")
    h = -heatmap[0, 0]
    qmin, qmax = np.quantile(h, 0.8), np.quantile(h, 0.999)
    h = np.clip(h, a_min=qmin, a_max=qmax)

    span = h.max() - h.min()
    # A flat map would otherwise compute 0/0 (NaN plus a RuntimeWarning);
    # map it to 0 instead, which is black in gist_heat.
    h = (h - h.min()) / span if span > 0 else np.zeros_like(h)

    rgb = cmap(h, bytes=True)[:, :, :3]
    overlay = Image.fromarray(rgb).resize(img.size, resample=Image.Resampling.BILINEAR)
    # Image.blend requires matching modes; the overlay is always RGB.
    base = img if img.mode == "RGB" else img.convert("RGB")
    return Image.blend(base, overlay, alpha=0.6)
@torch.no_grad()
def run_inference(model, img):
    """Run the score network and the flow on a batch of images.

    The batch is first resized bilinearly to 64x64.

    Args:
        model: Object exposing ``scorenet(x)`` and callable as ``model(x)``.
        img: Image batch accepted by ``torch.nn.functional.interpolate``.

    Returns:
        Tuple ``(img_likelihood, score_norms)`` as NumPy arrays on the CPU.
        ``img_likelihood`` is ``model(x)``. ``score_norms`` is the L2 norm of
        ``model.scorenet(x)`` taken over dimensions 2, 3 and 4.
    """
    x = torch.nn.functional.interpolate(img, size=64, mode="bilinear")
    scores = model.scorenet(x)
    score_norm = scores.square().sum(dim=(2, 3, 4)) ** 0.5
    likelihood = model(x)
    return likelihood.cpu().numpy(), score_norm.cpu().numpy()
def localize_anomalies(input_img, preset="edm2-img64-s-fid", load_from_hub=False):
    """Score an image for anomalies and localize them with a heatmap.

    The image is converted to RGB and resized to 64x64. It is then scored by
    the selected preset's model, and the resulting score norms are evaluated
    under the preset's GMM and ranked against its reference scores.

    Args:
        input_img: PIL image to analyse.
        preset: Model preset name. An empty value (e.g. an unselected
            dropdown) falls back to ``"edm2-img64-s-fid"``.
        load_from_hub: If True, fetch the preset from the HuggingFace Hub;
            otherwise load it from the local ``models/<preset>`` directory.

    Returns:
        Tuple ``(summary, heatmap, histogram)``:
        - a text line with the anomaly score and its percentile;
        - the heatmap overlay, resized back to the input's original size;
        - a figure comparing the score with the reference distribution.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_img is None:
        raise gr.Error("Please upload an input image.")
    if not preset:
        preset = "edm2-img64-s-fid"

    orig_size = input_img.size
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load (or reuse the cached) model outside inference mode, so the cached
    # parameters are ordinary tensors rather than inference tensors.
    if load_from_hub:
        model, modeldir = load_model_from_hub(preset=preset, device=device)
    else:
        model = load_model(modeldir="models", preset=preset, device=device)
        modeldir = f"models/{preset}"

    # The HWC -> CHW permute below needs exactly 3 channels, so grayscale,
    # RGBA and palette uploads are converted to RGB first.
    input_img = input_img.convert("RGB").resize(
        size=(64, 64), resample=Image.Resampling.LANCZOS
    )

    with torch.inference_mode():
        img = torch.from_numpy(np.array(input_img)).permute(2, 0, 1).unsqueeze(0)
        img = img.float().to(device)
        img_likelihood, score_norms = run_inference(model, img)
        nll, pct, ref_nll = compute_gmm_likelihood(score_norms, model_dir=modeldir)

    outstr = f"Anomaly score: {nll:.3f} / {pct:.2f} percentile"
    histplot = plot_against_reference(nll, ref_nll)
    heatmapplot = plot_heatmap(input_img, img_likelihood).resize(orig_size)
    return outstr, heatmapplot, histplot
def build_demo(inference_fn):
    """Assemble the Gradio interface around ``inference_fn``.

    ``inference_fn`` receives ``(image, preset, load_from_hub)`` and must
    return ``(text, image, plot)``, matching the inputs and outputs below.

    Args:
        inference_fn: Callable that performs the anomaly localization.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    default_preset = "edm2-img64-s-fid"
    demo = gr.Interface(
        fn=inference_fn,
        inputs=[
            gr.Image(type="pil", label="Input Image"),
            gr.Dropdown(
                # A concrete list rather than a dict view; preselect the
                # preset used by the examples so the field is never empty.
                choices=list(config_presets.keys()),
                value=default_preset,
                label="Score Model Preset",
                info="The preset of the underlying score estimator. These are the EDM2 diffusion models from Karras et.al.",
            ),
            gr.Checkbox(
                label="HuggingFace Hub",
                value=True,
                info="Load a pretrained model from HuggingFace. Uncheck to use a model from `models` directory.",
            ),
        ],
        outputs=[
            gr.Text(
                label="Estimated global outlier scores - Percentiles with respect to Imagenette Scores"
            ),
            gr.Image(label="Anomaly Heatmap", min_width=160),
            gr.Plot(label="Comparing to Imagenette"),
        ],
        examples=[
            ["samples/duckelephant.jpeg", default_preset, True],
            ["samples/sharkhorse.jpeg", default_preset, True],
            ["samples/goldfish.jpeg", default_preset, True],
        ],
        cache_examples=False,
    )
    return demo
# Build the app at import time; start the server only when run as a script.
demo = build_demo(inference_fn=localize_anomalies)

if __name__ == "__main__":
    demo.launch()
|