Spaces:
Sleeping
Sleeping
import streamlit as st | |
from PIL import Image | |
import numpy as np | |
from data.rg_masks import get_transforms | |
from models import tiramisu | |
from torchvision.transforms.functional import to_pil_image | |
import torch | |
from astropy.io import fits | |
def load_fits(path):
    """Read the data of a FITS file as a float32 array with an extra axis.

    Args:
        path: Source handed straight to ``fits.getdata`` (the caller in this
            file passes the uploaded file object).

    Returns:
        The FITS data cast to ``np.float32`` with a new length-1 axis
        inserted at position 2.
    """
    pixels = fits.getdata(path)
    as_float = pixels.astype(np.float32)
    return np.expand_dims(as_float, axis=2)
def load_image(path):
    """Load a raster image and return one channel with a trailing axis.

    Args:
        path: Filename or file-like object passed to ``Image.open``.

    Returns:
        ``np.ndarray`` holding the first channel of the image with a new
        length-1 axis inserted at position 2. Images that decode to a 2-D
        array (single-channel modes) are used as they are.
    """
    image = Image.open(path)
    array = np.array(image)
    # Single-channel images decode to a 2-D array; indexing a third axis on
    # them would raise IndexError, so only slice when a channel axis exists.
    if array.ndim == 2:
        channel = array
    else:
        channel = array[:, :, 0]
    return np.expand_dims(channel, 2)
def load_weights(model, fpath, device="cuda"):
    """Restore a model's parameters in place from a checkpoint file.

    Args:
        model: Object exposing ``load_state_dict``.
        fpath: Checkpoint location given to ``torch.load``; the loaded object
            must provide a ``'state_dict'`` entry.
        device: Device specification used as ``map_location`` when loading.
    """
    print("loading weights '{}'".format(fpath))
    target = torch.device(device)
    checkpoint = torch.load(fpath, map_location=target)
    state = checkpoint['state_dict']
    model.load_state_dict(state)
# Function to apply color overlay to the input image based on the segmentation mask
def apply_color_overlay(input_image, segmentation_mask, alpha=0.5):
    """Blend a colour-coded segmentation mask over an image.

    Pixels labelled 1 are drawn in the red channel, 2 in green and 3 in blue;
    any other label contributes no colour.

    Args:
        input_image: PIL image to draw on. It is converted to the overlay's
            mode when the two differ.
        segmentation_mask: Tensor of class labels. Assumed to be shaped
            (1, H, W) so the three per-class planes concatenate to (3, H, W)
            -- TODO confirm against the caller.
        alpha: Interpolation factor forwarded to ``Image.blend``.

    Returns:
        The blended PIL image.
    """
    planes = [(segmentation_mask == label).float() for label in (1, 2, 3)]
    overlay = to_pil_image(torch.cat(planes, dim=0))
    # Image.blend rejects images whose modes differ, so bring the input to the
    # overlay's mode instead of failing on e.g. grayscale or RGBA inputs.
    if input_image.mode != overlay.mode:
        input_image = input_image.convert(overlay.mode)
    return Image.blend(input_image, overlay, alpha=alpha)
# Streamlit app
def main():
    """Render the page: upload an image, segment it, show input and overlay.

    Builds an ``FCDenseNet67`` with four classes, loads ``weights/real.pth``
    and, once a file has been uploaded, runs the model on it and displays the
    input next to the colour overlay of the predicted classes.
    """
    st.title("Tiramisu for semantic segmentation of radio astronomy images")
    st.write("Upload an image and see the segmentation result!")
    uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "fits"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = tiramisu.FCDenseNet67(n_classes=4).to(device)
    load_weights(model, "weights/real.pth", device)
    model.eval()

    st.markdown(
        """
        Category Legend:
        - :blue[Extended]
        - :green[Compact]
        - :red[Spurious]
        """
    )

    if uploaded_image is None:
        return

    # Load the uploaded image. The extension is compared case-insensitively
    # so that a file named e.g. "MAP.FITS" is still read by the FITS loader.
    if uploaded_image.name.lower().endswith(".fits"):
        input_array = load_fits(uploaded_image)
    else:
        input_array = load_image(uploaded_image)

    # Move the trailing axis added by the loaders to the front.
    input_array = input_array.transpose(2, 0, 1)
    transforms = get_transforms(input_array.shape[1])
    image = transforms(input_array)
    image = image.to(device)

    with torch.no_grad():
        output = model(image)
        # Index of the highest-scoring class along dimension 1.
        preds = output.argmax(1)

    # NOTE(review): image[0] is presumably the single sample of a batch --
    # confirm against what get_transforms returns.
    pil_image = to_pil_image(image[0])

    # Apply color overlay to the input image
    segmented_image = apply_color_overlay(pil_image, preds)

    # Display the input image and the segmented output
    st.image([pil_image, segmented_image], caption=["Input Image", "Segmented Output"], width=300)
# Start the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()