# GalaxyAnnotator / app.py
# Author: pawipa
# Commit 97e707e: "restore configs."
import gradio as gr
from astropy.io import fits
import matplotlib.pyplot as plt
import numpy as np
import io
from PIL import Image
import astropy.units as u
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
from astropy import coordinates as coord
from astropy.wcs.utils import skycoord_to_pixel
from astroquery.simbad import Simbad
import pandas as pd
import matplotlib.patches as patches
# Dark figure theme for every plot rendered by this app.
plt.style.use('dark_background')
# Assigning None removes PIL's decompression-bomb pixel limit altogether
# (rather than raising it), so arbitrarily large images can be opened.
Image.MAX_IMAGE_PIXELS = None

# State shared between the Gradio callbacks; empty until a file is loaded.
global_data = global_header = None
global_dataframe = pd.DataFrame()
def show_csv(file):
    """
    Load a catalogue CSV and offer its "TYPE" values as checkbox choices.

    Parameters
    ----------
    file : str, os.PathLike or upload object
        Path to the CSV, or an object exposing the path as ``.name``.
        The first CSV column is used as the index.

    Returns
    -------
    tuple
        ``(DataFrame, gr.CheckboxGroup)`` on success, with every distinct
        "TYPE" value pre-selected; ``("Error: ...", None)`` otherwise.
        ``global_dataframe`` is only replaced on success.
    """
    import os  # local import: only needed to recognise path-like arguments

    global global_dataframe
    try:
        # Gradio hands over either a plain path or a tempfile-like wrapper
        # depending on the File component's ``type``; accept both.
        path = file if isinstance(file, (str, os.PathLike)) else getattr(file, "name", file)
        df = pd.read_csv(path, index_col=0)
    except Exception as e:
        return f"Error: {str(e)}", None
    if "TYPE" not in df.columns:
        # Keep the previous catalogue: the plotting code filters on "TYPE".
        return "Error: CSV does not contain a 'TYPE' column.", None
    global_dataframe = df  # Store the dataframe globally for filtering
    unique_types = df["TYPE"].unique().tolist()
    return df, gr.CheckboxGroup(label="Select Catalogue", choices=unique_types, value=unique_types, interactive=True)
# Callback for the "Query Simbad for Galaxies" button.
def query_update_table():
    """
    Query SIMBAD for galaxies in the field of the loaded FITS image.

    Builds a circular ``region()`` query around the header's reference
    coordinate (CRVAL1/CRVAL2). It converts the returned RA/Dec to pixel
    positions with the header's WCS (third axis dropped) and keeps only the
    objects that land on the image. Each object is labelled with the
    catalogue prefix of its SIMBAD ``main_id`` (e.g. "NGC 1234" -> "NGC")
    in a "TYPE" column. On success the result replaces ``global_dataframe``.

    Returns
    -------
    tuple
        ``(DataFrame, gr.CheckboxGroup)`` on success, with every catalogue
        prefix pre-selected; ``("Error: ...", None)`` on failure.
    """
    global global_dataframe
    if global_header is None or global_data is None:
        return "Error: upload a plate solved FITS file before querying SIMBAD.", None
    try:
        Simbad.TIMEOUT = 120
        wcs = WCS(global_header).dropaxis(2)
        target_coord = SkyCoord(ra=global_header['CRVAL1'], dec=global_header['CRVAL2'],
                                unit=(u.deg, u.deg), frame='icrs')
        # Radius = larger pixel scale x longer image side, i.e. the full image
        # extent. This is generous; objects off the image are discarded below.
        radius_deg = (max(abs(global_header['CDELT1']), abs(global_header['CDELT2']))
                      * max(global_header['NAXIS1'], global_header['NAXIS2']))
        # ":+" always writes the declination's sign, e.g. "+12.5" / "-3.2".
        custom_query = (f"region(CIRCLE, {target_coord.ra.deg} "
                        f"{target_coord.dec.deg:+}, {radius_deg}d)")
        print(f'Query={custom_query}')
        result_table = Simbad.query_criteria(custom_query, otype='galaxy')
        if result_table is None or len(result_table) == 0:
            return "Error: SIMBAD returned no galaxies for this field.", None
        df = result_table.to_pandas().set_index('main_id')
        # One vectorised conversion instead of building a SkyCoord per row.
        xs, ys = skycoord_to_pixel(
            SkyCoord(df['ra'].values, df['dec'].values, unit=(u.deg, u.deg), frame='icrs'), wcs)
        height, width = global_data.shape[:2]
        # With origin 0, pixel centres sit on integer coordinates, so pixel i
        # covers [i - 0.5, i + 0.5). NaN positions fail every comparison and
        # are dropped as well.
        inside = (xs >= -0.5) & (xs < width - 0.5) & (ys >= -0.5) & (ys < height - 0.5)
        df, xs, ys = df[inside].copy(), xs[inside], ys[inside]
        df['Pixel_Position'] = list(zip(xs, ys))
        # Nearest pixel index (rounding, not truncation toward zero).
        df['px'] = np.floor(xs + 0.5).astype(int)
        df['py'] = np.floor(ys + 0.5).astype(int)
        df = df.reset_index()
        # Catalogue prefix: text before the first space, then before any "+".
        df['TYPE'] = df['main_id'].apply(lambda x: x.split(' ')[0].split('+')[0])
        df = df.sort_values(by=['px', 'py']).reset_index(drop=True)
    except Exception as e:
        return f"Error: {str(e)}", None
    global_dataframe = df  # Store the dataframe globally for filtering
    unique_types = df["TYPE"].unique().tolist()
    return df, gr.CheckboxGroup(label="Select Catalogue", choices=unique_types, value=unique_types, interactive=True)
def load_fits_image(file, type_checkboxes, title, axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method):
    """
    Load an uploaded FITS file into the shared state and render it.

    Reads the primary HDU. The data must be 3-D; its first numpy axis is
    moved last, ``(C, H, W) -> (H, W, C)`` (presumably colour channels,
    TODO confirm). The result is stored in ``global_data`` and the header in
    ``global_header``. All remaining arguments are forwarded unchanged to
    ``update_images_and_tables``.

    Returns
    -------
    tuple
        The ``(table, image, patches_image)`` triple from
        ``update_images_and_tables``, or ``(None, None, None)`` when *file*
        is ``None``. That happens because the File component's change event
        also fires when the upload is cleared.
    """
    global global_header, global_data
    if file is None:
        return None, None, None
    # The context manager releases the file handle even if reading fails.
    with fits.open(file) as hdul:
        primary = hdul[0]
        # Same element mapping as swapping axes (0, 2) then (0, 1).
        data = np.moveaxis(primary.data, 0, -1)
        header = primary.header
    # Only publish the new state once both parts were read successfully.
    global_data, global_header = data, header
    return update_images_and_tables(type_checkboxes, title, axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method)
def _select_patches(df, selected_types, patch_size, height, width, sort_method):
    """Return the sorted catalogue rows that get a patch, or None if there are none.

    A row is kept when its "TYPE" is in *selected_types* and the square patch
    of side *patch_size* centred on (px, py) lies strictly inside an image
    of *height* x *width* pixels.
    """
    if not selected_types or df.empty:
        return None
    half = patch_size // 2
    df = df[df["TYPE"].isin(selected_types)]
    inside = ((df.px - half > 0) & (df.px + half < width)
              & (df.py - half > 0) & (df.py + half < height))
    df = df[inside]
    if df.empty:
        return None
    # Use one multi-key sort. Two successive sort_values calls would lose the
    # first ordering, because pandas' default quicksort is not stable.
    sort_keys = {
        "by Catalogue": ["TYPE", "px", "py"],
        "by x": ["px", "py"],
        "by y": ["py", "px"],
    }.get(sort_method)
    if sort_keys is not None:
        df = df.sort_values(by=sort_keys)
    # The 0..n-1 index doubles as the annotation number and the grid position.
    return df.reset_index(drop=True)


def _figure_to_image(fig):
    """Render *fig* to an in-memory PNG and return it as a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=.1, dpi=200)
    buf.seek(0)
    return Image.open(buf)


def _render_overview(data, header, df, title, axis_options, patch_size, fontsize, alpha, linewidth, scale, patch_color):
    """Draw the image with WCS axes and numbered patch outlines.

    Returns ``(image, cutouts)``: the rendered figure as a PIL image and one
    pixel cut-out per row of *df* (empty list when *df* is None).
    """
    wcs = WCS(header).dropaxis(2)
    height, width = data.shape[:2]
    half = patch_size // 2
    # figsize is (width, height): keep the figure's aspect equal to the image's.
    fig = plt.figure(figsize=(width / height * scale, scale))
    try:
        ax = fig.add_subplot(projection=wcs, label='overlays')
        ax.imshow(data, origin='lower')
        if "with Grid" in axis_options:
            ax.coords.grid(True, color='white', ls='-', alpha=.5)
        if "with Axis Annotation" in axis_options:
            ax.coords[0].set_axislabel('Right Ascension (J2000)', fontsize=fontsize + 2)
            ax.coords[1].set_axislabel('Declination (J2000)', fontsize=fontsize + 2)
        else:
            ax.axis('off')
        ax.set_title(title, fontsize=fontsize + 4)
        cutouts = []
        if df is not None:
            for number, row in enumerate(df.itertuples(index=False), start=1):
                ax.add_patch(patches.Rectangle(
                    (row.px - half, row.py - half), patch_size, patch_size,
                    alpha=alpha, linewidth=linewidth, edgecolor=patch_color, facecolor='none'))
                ax.text(row.px, row.py + half, str(number),
                        ha='center', va='bottom', color=patch_color, fontsize=fontsize)
                cutouts.append(data[row.py - half:row.py + half, row.px - half:row.px + half])
        fig.tight_layout()
        return _figure_to_image(fig), cutouts
    finally:
        # Always release the figure; this code runs inside a long-lived server.
        plt.close(fig)


def _render_patch_grid(df, cutouts, per_row, fontsize, scale):
    """Lay out *cutouts* in a grid with *per_row* columns, titled by ``main_id``."""
    n_cols = int(per_row)
    n_rows = int(np.ceil(len(cutouts) / n_cols))
    cell = max(1, scale // 3)
    # squeeze=False keeps axarr 2-D even for a single row of thumbnails.
    fig, axarr = plt.subplots(n_rows, n_cols, figsize=(n_cols * cell, n_rows * cell), squeeze=False)
    try:
        cells = zip(axarr.flat, df.itertuples(index=False), cutouts)
        for number, (ax, row, cutout) in enumerate(cells, start=1):
            # Flip vertically: the overview is drawn with origin='lower'.
            ax.imshow(cutout[::-1])
            ax.set_title(row.main_id, fontsize=fontsize - 2)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.text(2, 2, str(number), ha='left', va='top', fontsize=fontsize + 6)
        for ax in axarr.flat[len(cutouts):]:
            ax.axis('off')
        fig.tight_layout()
        return _figure_to_image(fig)
    finally:
        plt.close(fig)


def update_images_and_tables(selected_types, title, selected_axis_options, num_rows, patch_size, fontsize, alpha, linewidth, scale, patch_color, sort_method):
    """
    Re-render the annotated overview image, the patch grid and the table.

    Reads the module-level ``global_data``, ``global_header`` and
    ``global_dataframe``.

    Parameters
    ----------
    selected_types : list of str
        "TYPE" values (catalogue prefixes) to annotate.
    title : str
        Title drawn above the overview image.
    selected_axis_options : list of str
        May contain "with Grid" and/or "with Axis Annotation".
    num_rows : int
        Number of thumbnails per row of the patch grid (despite the name).
    patch_size : int
        Side length of each square patch, in pixels.
    fontsize, alpha, linewidth, scale, patch_color
        Styling forwarded to matplotlib.
    sort_method : str
        "by Catalogue", "by x" or "by y".

    Returns
    -------
    tuple
        ``(table, overview_image, patches_image)``.
        - All three are ``None`` before a FITS file has been loaded.
        - ``table`` and ``patches_image`` are ``None`` when no patch is
          selected.
        - On failure the first element is an ``"Error: ..."`` string and the
          other two are ``None``.
    """
    if global_data is None or global_header is None:
        # Option widgets fire change events even before anything is uploaded.
        return None, None, None
    try:
        patch_size = int(patch_size)  # sliders may deliver floats; slicing needs ints
        height, width = global_data.shape[:2]
        filtered_df = _select_patches(global_dataframe, selected_types, patch_size,
                                      height, width, sort_method)
        image, cutouts = _render_overview(global_data, global_header, filtered_df, title,
                                          selected_axis_options or [], patch_size, fontsize,
                                          alpha, linewidth, scale, patch_color)
        if filtered_df is None:
            return None, image, None
        patches_image = _render_patch_grid(filtered_df, cutouts, num_rows, fontsize, scale)
        return filtered_df, image, patches_image
    except Exception as e:
        # Always match the three Gradio outputs this callback is wired to.
        return f"Error: {str(e)}", None, None
# Gradio interface
with gr.Blocks(css=".btn-green {background-color: green; color: white;}") as gui:
    gr.Markdown("# What's in my image?")
    # Options Area
    with gr.Row() as options_gui:
        # Feeds the `num_rows` parameter, which sets how many patch thumbnails
        # are placed side by side in each row of the patch grid.
        num_rows = gr.Number(label="Patches per Row", value=16, minimum=2, precision=0, interactive=True)
        title = gr.Textbox(label="Image Title", value="Custom Title", interactive=True)
        patch_size = gr.Slider(label="Patch Size", minimum=16, maximum=128, step=8, value=32,
                               interactive=True)
        fontsize = gr.Slider(label="Fontsize", minimum=6, maximum=26, step=1, value=10,
                             interactive=True)
        alpha = gr.Slider(label="Alpha", minimum=0., maximum=1., step=.1, value=1.,
                          interactive=True)
        linewidth = gr.Slider(label="Linewidth", minimum=1, maximum=4, step=1, value=1,
                              interactive=True)
        scale = gr.Slider(label="Scale", minimum=1, maximum=20, step=1, value=10,
                          interactive=True)
        patch_color = gr.ColorPicker(label="Patch Color", value="#FFFFFF", interactive=True)
        sort_method = gr.Dropdown(label="Sorting Method", choices=["by Catalogue", "by x", "by y"],
                                  value="by Catalogue", interactive=True)
        axis_options = gr.CheckboxGroup(
            label="Select options",
            choices=["with Grid", "with Axis Annotation"],
            value=["with Grid", "with Axis Annotation"],  # Preselected values
            interactive=True,  # Makes it interactive
        )
    gr.Markdown("Upload a plate solved `.fits` file (32 bit) to display its content.")
    file_input = gr.File(label="Upload .fits File", type="filepath")
    #file_input_csv = gr.File(label="Upload .csv File")
    greet_button = gr.Button("Query Simbad for Galaxies")  # Create the button
    fits_image = gr.Image(label="Input Image", type="pil")
    type_checkboxes = gr.CheckboxGroup(label="Select Catalogue")
    patches_image = gr.Image(label="Patches Image", type="pil")
    csv_table = gr.DataFrame(label="CSV Table")

    # Every control that affects rendering. The order matches the parameters
    # of update_images_and_tables, and of load_fits_image after its file argument.
    track_options = [type_checkboxes, title, axis_options, num_rows, patch_size, fontsize,
                     alpha, linewidth, scale, patch_color, sort_method]
    render_outputs = [csv_table, fits_image, patches_image]
    file_input.change(load_fits_image,
                      inputs=[file_input] + track_options,
                      outputs=render_outputs)
    # One listener per control. type_checkboxes is already in track_options;
    # registering it a second time would render everything twice per change.
    for option_i in track_options:
        option_i.change(update_images_and_tables,
                        inputs=track_options,
                        outputs=render_outputs)
    # Optional CSV catalogue upload (currently disabled):
    #file_input_csv.change(show_csv,
    #                      inputs=file_input_csv,
    #                      outputs=[csv_table, type_checkboxes])
    greet_button.click(query_update_table, inputs=None, outputs=[csv_table, type_checkboxes])

if __name__ == "__main__":
    gui.launch(debug=True)