Spaces:
Sleeping
Sleeping
File size: 2,650 Bytes
80ab65e be4ce2d 80ab65e 37b1d1c 80ab65e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import streamlit as st
from PIL import Image
import numpy as np
from data.rg_masks import get_transforms
from models import tiramisu
from torchvision.transforms.functional import to_pil_image
import torch
from astropy.io import fits
def load_fits(path):
    """Read FITS image data as a float32 ``(H, W, 1)`` array.

    Args:
        path: a filename or binary file-like object accepted by
            ``astropy.io.fits.getdata`` (which picks the HDU to read).

    Returns:
        ``numpy.ndarray`` of shape ``(H, W, 1)`` and dtype float32.

    Raises:
        ValueError: if, after dropping leading length-1 axes, the data is
            not two-dimensional.
    """
    # Casting to native float32 also yields a native-byte-order array.
    array = fits.getdata(path).astype(np.float32)
    # Drop leading length-1 axes (e.g. degenerate spectral/polarisation axes
    # in radio cubes) so a (1, 1, H, W) cube reads the same as an (H, W)
    # image. Only leading axes are removed, so a 1-pixel-wide image is kept.
    while array.ndim > 2 and array.shape[0] == 1:
        array = array[0]
    if array.ndim != 2:
        raise ValueError(
            f"expected 2-D image data in FITS file, got shape {array.shape}"
        )
    # Add a trailing channel axis to match load_image's output layout.
    return np.expand_dims(array, 2)
def load_image(path):
    """Read a raster image and return a single channel as an ``(H, W, 1)`` array.

    Args:
        path: anything ``PIL.Image.open`` accepts (a filename or a binary
            file-like object such as an uploaded file).

    Returns:
        ``numpy.ndarray`` of shape ``(H, W, 1)``. For multi-channel images
        only channel 0 is kept; single-channel images (e.g. greyscale) are
        used as-is.
    """
    # The context manager releases the file handle PIL opens for a path.
    with Image.open(path) as image:
        if image.mode == "P":
            # Palette images decode to colour-table indices, not intensities.
            image = image.convert("RGB")
        array = np.array(image)
    # Greyscale images decode to 2-D arrays; slicing [:, :, 0] on them would
    # raise IndexError, so only pick a channel when a channel axis exists.
    if array.ndim == 3:
        array = array[:, :, 0]
    return np.expand_dims(array, 2)
def load_weights(model, fpath, device="cuda"):
    """Load checkpoint parameters into ``model`` in place.

    The checkpoint may be either a dict holding the parameters under the
    ``"state_dict"`` key or a bare state dict, e.g. one written with
    ``torch.save(model.state_dict(), fpath)``.

    Args:
        model: ``torch.nn.Module`` that receives the weights via
            ``load_state_dict``.
        fpath: path (or file-like object) passed to ``torch.load``.
        device: device (string or ``torch.device``) onto which stored
            tensors are mapped while loading.
    """
    print("loading weights '{}'".format(fpath))
    checkpoint = torch.load(fpath, map_location=torch.device(device))
    # Wrapped checkpoints keep the parameters under "state_dict"; a bare
    # state dict is itself the mapping of parameter names to tensors.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
# Function to apply color overlay to the input image based on the segmentation mask
def apply_color_overlay(input_image, segmentation_mask, alpha=0.5):
    """Blend a class-coloured overlay onto ``input_image``.

    Pixels of class 1 are painted red, class 2 green and class 3 blue; any
    other class value contributes black to the overlay.

    Args:
        input_image: PIL image to draw on. It is converted to RGB first,
            because ``Image.blend`` requires both images to share a mode and
            the overlay is always RGB.
        segmentation_mask: integer tensor of class indices with shape
            ``(H, W)`` or ``(1, H, W)``; ``H, W`` must match the image size.
        alpha: overlay weight in the blend (0 = input only, 1 = overlay only).

    Returns:
        A new RGB PIL image.
    """
    # Accept a single mask with or without a leading batch axis of size 1.
    mask = segmentation_mask.reshape(segmentation_mask.shape[-2:])
    # One 0/1 float plane per colour channel -> (3, H, W) RGB overlay.
    overlay = torch.stack([mask == 1, mask == 2, mask == 3]).float()
    overlay = to_pil_image(overlay)
    # A single-channel input can arrive as a mode "L" image; blending it
    # with the RGB overlay would raise ValueError("images do not match").
    base = input_image.convert("RGB")
    return Image.blend(base, overlay, alpha=alpha)
# Streamlit app

# Streamlit re-executes the whole script on every interaction; cache the
# model across reruns when st.cache_resource exists, else load each time.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model(device_name):
    """Build FCDenseNet67 on ``device_name``, load weights, and set eval mode."""
    device = torch.device(device_name)
    model = tiramisu.FCDenseNet67(n_classes=4).to(device)
    load_weights(model, "weights/real.pth", device)
    model.eval()
    return model


def _read_upload(uploaded_file):
    """Read an uploaded file with load_fits or load_image based on its name."""
    # Match the extension case-insensitively so e.g. "cutout.FITS" goes to
    # the FITS reader instead of being handed to PIL.
    if uploaded_file.name.lower().endswith(".fits"):
        return load_fits(uploaded_file)
    return load_image(uploaded_file)


def main():
    """Render the page, segment an uploaded image, and show the result."""
    st.title("Tiramisu for semantic segmentation of radio astronomy images")
    st.write("Upload an image and see the segmentation result!")
    uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "fits"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(device.type)

    st.markdown(
        """
Category Legend:
- :blue[Extended]
- :green[Compact]
- :red[Spurious]
"""
    )

    if uploaded_image is None:
        return

    # Loaders return channel-last arrays; move the channel axis to the front.
    input_array = _read_upload(uploaded_image).transpose(2, 0, 1)
    transforms = get_transforms(input_array.shape[1])
    image = transforms(input_array).to(device)

    with torch.no_grad():
        output = model(image)
    # Per-pixel class index: argmax over the class axis.
    preds = output.argmax(1)

    pil_image = to_pil_image(image[0])
    # Apply color overlay to the input image
    segmented_image = apply_color_overlay(pil_image, preds)
    # Display the input image and the segmented output
    st.image([pil_image, segmented_image], caption=["Input Image", "Segmented Output"], width=300)
# Entry point: `streamlit run <this file>` executes the script as __main__.
if __name__ == "__main__":
    main()