File size: 4,057 Bytes
00f809d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9905491
 
 
 
 
 
 
 
 
 
 
 
00f809d
 
 
9905491
00f809d
 
 
 
 
9905491
 
 
00f809d
 
 
9905491
 
 
 
 
 
00f809d
9905491
 
00f809d
9905491
 
00f809d
9905491
 
00f809d
9905491
 
 
 
 
00f809d
9905491
 
00f809d
 
9905491
00f809d
 
 
9905491
 
 
00f809d
9905491
 
00f809d
9905491
 
 
00f809d
9905491
00f809d
 
9905491
00f809d
9905491
00f809d
9905491
 
 
 
00f809d
9905491
 
 
 
00f809d
 
 
9905491
 
00f809d
9905491
00f809d
 
 
 
 
9905491
 
00f809d
9905491
 
7ffa4f6
9905491
7ffa4f6
9905491
7ffa4f6
9905491
00f809d
9905491
ae64ae9
9905491
00f809d
9905491
 
 
ae64ae9
9905491
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Public names exported by this app module (used by ``from <module> import *``).
# ``to_oberlay_image`` (sic) matches the spelling of the function defined below.
__all__ = [
    "ORGAN",
    "IMAGE_SIZE",
    "MODEL_NAME",
    "THRESHOLD",
    "CODES",
    "learn",
    "title",
    "description",
    "examples",
    "interpretation",
    "demo",
    "x_getter",
    "y_getter",
    "splitter",
    "make3D",
    "predict",
    "infer",
    "remove_small_segs",
    "to_oberlay_image",
]

import numpy as np
import pandas as pd
import skimage
from fastai.vision.all import *
import segmentation_models_pytorch as smp

import gradio as gr

ORGAN = "kidney"
IMAGE_SIZE = 512
MODEL_NAME = "unetpp_b4_th60_d9414.pkl"
# The decision threshold is encoded in the model filename: the third
# "_"-separated field is "th<percent>", e.g. "th60" -> 0.60.
THRESHOLD = float(MODEL_NAME.split("_")[2][len("th"):]) / 100
# Segmentation class names, indexed by class id (FTU = functional tissue unit).
CODES = ["Background", "FTU"]


def x_getter(r):
    """Return the value stored under the ``"fnames"`` key of record ``r``.

    Presumably the ``get_x`` of the DataBlock used to train the exported
    learner, with ``r`` a DataFrame row (TODO confirm against training code).
    Raises ``KeyError`` if ``r`` has no ``"fnames"`` entry.
    """
    fnames = r["fnames"]
    return fnames


def y_getter(r):
    """Decode the run-length-encoded mask stored in record ``r``.

    Reads ``r["rle"]`` and the ``"img_height"`` / ``"img_width"`` fields
    (cast to ``int``), decodes with fastai's ``rle_decode`` at shape
    ``(height, width)`` and returns the transposed array.
    """
    encoded = r["rle"]
    height, width = int(r["img_height"]), int(r["img_width"])
    # NOTE(review): transposing a (height, width) decode yields a
    # (width, height) array; presumably the RLE is column-major. For
    # non-square images confirm this matches the image orientation.
    decoded = rle_decode(encoded, (height, width))
    return decoded.T


def splitter(model):
    """Split ``model``'s parameters into two groups.

    Group 0 holds the encoder parameters. Group 1 holds the decoder and
    segmentation-head parameters, which the original code calls
    "untrained". Presumably this lets fastai freeze or unfreeze the encoder
    separately (TODO confirm).
    """
    encoder_group = L(model.encoder.parameters())
    head_group = L(
        [*model.decoder.parameters(), *model.segmentation_head.parameters()]
    )
    return L([encoder_group, head_group])


# Load the exported fastai Learner from the MODEL_NAME pickle file.
# NOTE(review): presumably x_getter, y_getter and splitter are defined above
# this line because the pickled Learner references them by name; confirm
# before moving or renaming them. Only load trusted model files, since
# unpickling can execute arbitrary code.
learn = load_learner(MODEL_NAME)


def make3D(t: np.ndarray) -> np.ndarray:
    """Replicate ``t`` three times along a new axis 2.

    A 2-D ``(H, W)`` array becomes ``(H, W, 3)`` with every channel equal to
    ``t``, e.g. to give a single-channel mask an RGB-like shape.

    Args:
        t: array (or array-like) with at least two dimensions.

    Returns:
        A new array of shape ``t.shape[:2] + (3,) + t.shape[2:]`` with the
        same dtype as ``t``.
    """
    # np.stack inserts the new axis and concatenates in one step; this is
    # equivalent to expand_dims(axis=2) followed by concatenate(axis=2).
    return np.stack((t, t, t), axis=2)


def predict(fn, cutoff_area=200):
    """Segment ``fn`` and return ``(overlay_image, area_table)``.

    This is the Gradio entry point. It runs ``infer``, drops connected
    components smaller than ``cutoff_area`` pixels via ``remove_small_segs``,
    and returns the tinted overlay plus the per-component area DataFrame.
    Note that the default here (200) differs from the 250 default of
    ``remove_small_segs``; 200 is what the app uses.
    """
    segmented = remove_small_segs(infer(fn), cutoff_area=cutoff_area)
    overlay = to_oberlay_image(segmented)
    return overlay, segmented["df"]


def infer(fn):
    """Run the learner on ``fn`` and return the transformed image and mask.

    Args:
        fn: anything ``PILImage.create`` accepts (e.g. a path, PIL image or
            array, as passed in by Gradio).

    Returns:
        dict with
          ``"tf_image"``: the learner's decoded input, transposed from
            channel-first to channel-last and cast to ``uint8``;
          ``"tf_mask"``: ``uint8`` array, 1 where the softmax probability
            of class 1 (``CODES[1]``) exceeds ``THRESHOLD``, else 0.
    """
    img = PILImage.create(fn)
    tf_img, _, _, preds = learn.predict(img, with_input=True)
    # NOTE(review): assumes ``preds`` are raw per-class scores of shape
    # (n_classes, H, W). If the exported loss function already applies a
    # softmax activation inside ``get_preds``, this applies it a second
    # time; confirm.
    probs = F.softmax(preds.float(), dim=0)
    mask = np.array((probs > THRESHOLD).int()[1], dtype=np.uint8)
    # Channel-first -> channel-last, computed once and returned at the
    # learner's transformed resolution.
    image = tf_img.numpy().transpose(1, 2, 0).astype(np.uint8)
    return {
        "tf_image": image,
        "tf_mask": mask,
    }


def remove_small_segs(data, cutoff_area=250):
    """Drop connected mask components smaller than ``cutoff_area`` pixels.

    Labels ``data["tf_mask"]`` with ``skimage.measure.label`` (skimage's
    default connectivity) and keeps only components whose area is at least
    ``cutoff_area``.

    Args:
        data: dict containing a binary ``"tf_mask"`` array (as built by
            ``infer``). It is modified in place.
        cutoff_area: minimum component area, in pixels, to keep.

    Returns:
        The same ``data`` dict, with
          ``"tf_mask"`` replaced by a 0/1 ``uint8`` mask of the kept
            components, and
          ``"df"`` set to a DataFrame with columns ``"Glomerulus"``
            (1-based running index of kept components, in label order) and
            ``"Area (in px)"``.
    """
    labeled_mask = skimage.measure.label(data["tf_mask"])
    props = skimage.measure.regionprops(labeled_mask)
    kept = [prop for prop in props if prop.area >= cutoff_area]

    # Build the cleaned mask in one vectorised pass, using each region's own
    # label, instead of a full-image comparison per discarded region.
    keep_labels = [prop.label for prop in kept]
    data["tf_mask"] = np.isin(labeled_mask, keep_labels).astype(np.uint8)
    data["df"] = pd.DataFrame(
        {
            "Glomerulus": list(range(1, len(kept) + 1)),
            "Area (in px)": [prop.area for prop in kept],
        }
    )
    return data


def to_oberlay_image(data):
    """Paint the predicted mask onto the transformed image as a red tint.

    Args:
        data: dict with ``"tf_image"`` (presumably an ``(H, W, 3)`` ``uint8``
            array) and ``"tf_mask"`` (a 0/1 ``uint8`` array with the same
            height/width), as produced by ``infer`` / ``remove_small_segs``.

    Returns:
        An RGBA PIL image. Pixels where the mask is 1 are blended at ~50%
        with the colour (255, 80, 80); other pixels are unchanged.
    """
    image_arr, mask_arr = data["tf_image"], data["tf_mask"]

    tint = np.zeros_like(image_arr)
    for channel, value in enumerate((255, 80, 80)):
        tint[:, :, channel] = value

    canvas = Image.fromarray(image_arr).convert("RGBA")
    tint_layer = Image.fromarray(tint).convert("RGBA")
    # A mask value of 1 becomes alpha 127 (255 * 0.5, truncated), so the tint
    # is pasted at roughly half opacity.
    alpha = Image.fromarray((mask_arr * 255 * 0.5).astype(np.uint8))

    canvas.paste(tint_layer, (0, 0), alpha)
    return canvas


# Title and markdown description passed to the Gradio interface.
title = "Glomerulus Segmentation"
description = """
A web app that segments glomeruli in histological kidney slices!

The model deployed here is a [UNet++](https://arxiv.org/abs/1807.10165) with an [efficientnet-b4](https://arxiv.org/abs/1905.11946) encoder from the [segmentation_models_pytorch](https://github.com/qubvel/segmentation_models.pytorch) library.

The provided example images are a random subset of kidney slices from the [Human Protein Atlas](https://www.proteinatlas.org/). They were collected separately from model training and were not part of the training, validation, or test sets.

Here is my corresponding [blog post](https://fhatje.github.io/posts/glomseg/train_model.html).
"""

# Paths (as strings) of every image file fastai finds under ./example_images.
examples = list(map(str, get_image_files("example_images")))
# NOTE(review): not passed to gr.Interface; kept because it is in __all__.
interpretation = "default"

# Gradio UI: a single image input mapped by ``predict`` to two outputs, the
# overlay image and the per-component area table.
demo = gr.Interface(
    fn=predict,
    title=title,
    description=description,
    inputs=gr.components.Image(width=IMAGE_SIZE, height=IMAGE_SIZE),
    outputs=[gr.components.Image(), gr.components.DataFrame()],
    examples=examples,
)

# Start the Gradio server; this runs as soon as the module is executed.
demo.launch()