# NOTE(review): removed non-Python page-scrape residue (file-size header,
# commit hashes, and a line-number gutter) that made this file unparseable.
import streamlit as st
from PIL import Image
import numpy as np
from data.rg_masks import get_transforms
from models import tiramisu
from torchvision.transforms.functional import to_pil_image
import torch
from astropy.io import fits


def load_fits(path):
    """Read a FITS file and return its data as a (H, W, 1) float32 array."""
    data = fits.getdata(path).astype(np.float32)
    # Append a trailing channel axis so downstream code sees (H, W, C).
    return data[:, :, np.newaxis]

def load_image(path):
    """Load a raster image (jpg/jpeg/png) as a (H, W, 1) single-channel array.

    Only the first channel is kept for multi-channel inputs, since the
    segmentation model consumes one-channel images.
    """
    image = Image.open(path)
    array = np.array(image)
    if array.ndim == 2:
        # Grayscale images decode to (H, W); the original code assumed a
        # channel axis and would raise IndexError here. Add the axis directly.
        array = np.expand_dims(array, 2)
    else:
        # RGB/RGBA: keep only the first (red) channel.
        array = np.expand_dims(array[:, :, 0], 2)

    return array

def load_weights(model, fpath, device="cuda"):
    """Restore *model* in place from the checkpoint file at *fpath*.

    The checkpoint is expected to be a dict with a ``'state_dict'`` entry;
    *device* controls where the tensors are mapped on load.
    """
    print("loading weights '{}'".format(fpath))
    checkpoint = torch.load(fpath, map_location=torch.device(device))
    model.load_state_dict(checkpoint["state_dict"])


# Function to apply color overlay to the input image based on the segmentation mask
def apply_color_overlay(input_image, segmentation_mask, alpha=0.5):
    """Blend a per-class color overlay onto *input_image*.

    Mask values 1/2/3 are mapped to the red/green/blue channels respectively;
    value 0 (background) stays black. *alpha* is the overlay blend weight.
    """
    r = (segmentation_mask == 1).float()
    g = (segmentation_mask == 2).float()
    b = (segmentation_mask == 3).float()
    overlay = torch.cat([r, g, b], dim=0)
    overlay = to_pil_image(overlay)
    # Image.blend requires both operands to share a mode. The input produced
    # by to_pil_image on a single-channel tensor is mode "F"/"L", while the
    # overlay is RGB, which would raise ValueError — convert first.
    output = Image.blend(input_image.convert("RGB"), overlay, alpha=alpha)
    return output

# Streamlit app
def main():
    st.title("Tiramisu for semantic segmentation of radio astronomy images")
    st.write("Upload an image and see the segmentation result!")

    uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "fits"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = tiramisu.FCDenseNet67(n_classes=4).to(device)
    load_weights(model, "weights/real.pth", device)
    model.eval()

    st.markdown(
        """
        Category Legend:
        - :blue[Extended]
        - :green[Compact]
        - :red[Spurious]
        """
        )
    if uploaded_image is not None:
        # Load the uploaded image
        if uploaded_image.name.endswith(".fits"):
            input_array = load_fits(uploaded_image)
        else:
            input_array = load_image(uploaded_image)

        input_array = input_array.transpose(2,0,1)
        transforms = get_transforms(input_array.shape[1])
        image = transforms(input_array)
        image = image.to(device)

        with torch.no_grad():
            output = model(image)
        preds = output.argmax(1)

        pil_image = to_pil_image(image[0])
        # Apply color overlay to the input image
        segmented_image = apply_color_overlay(pil_image, preds)

        # Display the input image and the segmented output
        st.image([pil_image, segmented_image], caption=["Input Image", "Segmented Output"], width=300)

if __name__ == "__main__":
    main()