__all__ = [
"ORGAN",
"IMAGE_SIZE",
"MODEL_NAME",
"THRESHOLD",
"CODES",
"learn",
"title",
"description",
"examples",
"interpretation",
"demo",
"x_getter",
"y_getter",
"splitter",
"make3D",
"predict",
"infer",
"remove_small_segs",
"to_oberlay_image",
]
import numpy as np
import pandas as pd
import skimage.measure  # explicit submodule import; plain "import skimage" does not load it on older versions
from fastai.vision.all import *
import segmentation_models_pytorch as smp
import gradio as gr
ORGAN = "kidney"
IMAGE_SIZE = 512
MODEL_NAME = "unetpp_b4_th60_d9414.pkl"
# parse the threshold out of the file name, e.g. "unetpp_b4_th60_d9414.pkl" -> "th60" -> 0.60
THRESHOLD = float(MODEL_NAME.split("_")[2][2:]) / 100.0
CODES = ["Background", "FTU"] # FTU = functional tissue unit
def x_getter(r):
    # Path to the input image for a dataframe row.
    return r["fnames"]
def y_getter(r):
    # Decode the run-length-encoded target mask for a dataframe row.
    rle = r["rle"]
    shape = (int(r["img_height"]), int(r["img_width"]))
    return rle_decode(rle, shape).T
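# `rle_decode` is not defined in this file and is not part of fastai; it
# presumably comes from the accompanying training notebook. A minimal sketch,
# assuming the standard Kaggle-style run-length decoder, keeps this file
# self-contained (the transpose in y_getter matches the dataset's RLE
# orientation convention):
def rle_decode(mask_rle: str, shape: tuple) -> np.ndarray:
    s = mask_rle.split()
    starts, lengths = (np.asarray(x, dtype=int) for x in (s[0::2], s[1::2]))
    starts -= 1  # RLE positions are 1-indexed
    ends = starts + lengths
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        flat[lo:hi] = 1
    return flat.reshape(shape)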
def splitter(model):
    # Split parameters into a pretrained-encoder group and a randomly
    # initialized decoder + segmentation-head group.
    enc_params = L(model.encoder.parameters())
    dec_params = L(model.decoder.parameters())
    sg_params = L(model.segmentation_head.parameters())
    untrained_params = L([*dec_params, *sg_params])
    return L([enc_params, untrained_params])
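# The two groups let fastai apply discriminative learning rates during
# training, e.g. learn.fit_one_cycle(n, lr_max=slice(1e-5, 1e-3)) updates the
# pretrained encoder more gently than the freshly initialized decoder and head.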
learn = load_learner(MODEL_NAME)
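# load_learner unpickles the Learner, so x_getter, y_getter, and splitter must
# be defined under these exact names before the call above.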
def make3D(t: np.ndarray) -> np.ndarray:
    # Replicate a single-channel (H, W) array into an (H, W, 3) image.
    t = np.expand_dims(t, axis=2)
    t = np.concatenate((t, t, t), axis=2)
    return t
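# Note: make3D is exported but not called below; np.stack((t, t, t), axis=2)
# would be an equivalent one-liner.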
def predict(fn, cutoff_area=200):
    data = infer(fn)
    data = remove_small_segs(data, cutoff_area=cutoff_area)
    return to_overlay_image(data), data["df"]
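# Example usage (hypothetical file name):
#   overlay, table = predict("example_images/slice_01.png")
# returns a PIL overlay image and a DataFrame with one row per glomerulus.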
def infer(fn):
    img = PILImage.create(fn)
    # with_input=True also returns the transformed input tensor (C, H, W)
    tf_img, _, _, preds = learn.predict(img, with_input=True)
    # channel 1 is the FTU class; keep pixels whose probability exceeds THRESHOLD
    mask = (F.softmax(preds.float(), dim=0) > THRESHOLD).int()[1]
    mask = np.array(mask, dtype=np.uint8)
    return {
        "tf_image": tf_img.numpy().transpose(1, 2, 0).astype(np.uint8),
        "tf_mask": mask,
    }
def remove_small_segs(data, cutoff_area=250):
    # Drop connected components smaller than cutoff_area (in pixels) and
    # tabulate the area of every surviving glomerulus.
    labeled_mask = skimage.measure.label(data["tf_mask"])
    props = skimage.measure.regionprops(labeled_mask)
    df = {"Glomerulus": [], "Area (in px)": []}
    for prop in props:
        if prop.area < cutoff_area:
            labeled_mask[labeled_mask == prop.label] = 0
            continue
        df["Glomerulus"].append(len(df["Glomerulus"]) + 1)
        df["Area (in px)"].append(prop.area)
    labeled_mask[labeled_mask > 0] = 1  # collapse labels back to a binary mask
    data["tf_mask"] = labeled_mask.astype(np.uint8)
    data["df"] = pd.DataFrame(df)
    return data
def to_overlay_image(data):
    img, msk = data["tf_image"], data["tf_mask"]
    # Solid overlay color (RGB 255, 80, 80), blended at 50% opacity via an
    # alpha channel derived from the binary mask.
    msk_im = np.zeros_like(img)
    msk_im[:, :, 0] = 255
    msk_im[:, :, 1] = 80
    msk_im[:, :, 2] = 80
    img = Image.fromarray(img).convert("RGBA")
    msk_im = Image.fromarray(msk_im).convert("RGBA")
    msk = Image.fromarray((msk * 255 * 0.5).astype(np.uint8))  # alpha: 0 or 127
    img.paste(msk_im, (0, 0), msk)
    return img
title = "Glomerulus Segmentation"
description = """
A web app that segments glomeruli in histological kidney slices!
The model deployed here is a [UNet++](https://arxiv.org/abs/1807.10165) with an [efficientnet-b4](https://arxiv.org/abs/1905.11946) encoder from the [segmentation_models_pytorch](https://github.com/qubvel/segmentation_models.pytorch) library.
The provided example images are a random subset of kidney slices from the [Human Protein Atlas](https://www.proteinatlas.org/). They were collected independently of model development and were not part of the training, validation, or test sets.
Here is my corresponding [blog post](https://fhatje.github.io/posts/glomseg/train_model.html).
"""
examples = [str(p) for p in get_image_files("example_images")]
interpretation = "default"  # exported in __all__ but unused below
demo = gr.Interface(
fn=predict,
inputs=gr.components.Image(width=IMAGE_SIZE, height=IMAGE_SIZE),
outputs=[gr.components.Image(), gr.components.DataFrame()],
title=title,
description=description,
examples=examples,
)
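# predict returns (overlay image, dataframe), matching the two output components.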
if __name__ == "__main__":
    demo.launch()