# Page-scrape residue, kept as a comment so the file parses:
# Smith42's picture / init / 55dabfb
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.subplots as sp
from datasets import load_dataset
import umap
# Load dataset function
def load_dataset_from_hub(dataset_name, split="test"):
    """Fetch one split of a dataset, reporting failure as a value.

    Args:
        dataset_name: Identifier passed straight to ``load_dataset``.
        split: Split name passed to ``load_dataset`` (defaults to ``"test"``).

    Returns:
        ``(dataset, None)`` when loading succeeds, otherwise
        ``(None, message)`` where ``message`` is the exception text.
    """
    try:
        loaded = load_dataset(dataset_name, split=split)
    except Exception as exc:
        # Callers expect an error string rather than a raised exception.
        return None, str(exc)
    return loaded, None
# Create visualization function
def create_visualization(split, color_col, log):
    """Build a multi-panel scatter figure of embeddings coloured by a property.

    Loads the requested split of ``Smith42/galaxies_with_embeddings`` and, for
    each embedding column, plots its first two components in its own subplot,
    colouring the points by the values of ``color_col``.

    Args:
        split: Dataset split name forwarded to ``load_dataset_from_hub``.
        color_col: Name of the dataset column used to colour the points. Its
            values must be convertible to ``float``.
        log: When truthy, colour by the natural logarithm of the column.

    Returns:
        ``(figure, None)`` on success, or ``(None, message)`` when no property
        is selected, the dataset cannot be loaded, or building the figure
        fails.
    """
    # Reject a missing selection before paying for a dataset load.
    if not color_col:
        return None, "Error creating viz: no property selected"

    # Load the dataset
    dataset, error = load_dataset_from_hub("Smith42/galaxies_with_embeddings", split)
    if error or dataset is None:
        return None, f"Error loading dataset: {error}"

    try:
        # (embedding column, subplot title) pairs, in left-to-right order.
        panels = [
            ("p16k00_pca", "k = 0%"),
            ("p16k01_pca", "k = 1%"),
            ("p16k10_pca", "k = 10%"),
        ]
        embedding_cols = [name for name, _ in panels]

        # Extract embeddings and color values
        embeddings = dataset.select_columns(embedding_cols)
        colors = np.array(dataset[color_col], dtype=float)
        if log:
            # log(0) and log(<0) are expected for some properties; silence the
            # warnings here and discard the resulting non-finite values below.
            with np.errstate(divide="ignore", invalid="ignore"):
                colors = np.log(colors)
        # Map +/-inf to NaN so that dropna() removes those rows as well;
        # otherwise -inf from log(0) would survive into the plotted colours.
        colors = np.where(np.isfinite(colors), colors, np.nan)

        fig = sp.make_subplots(
            cols=len(panels),
            subplot_titles=[title for _, title in panels],
        )
        for subplot_col, embedding_col in enumerate(embedding_cols, start=1):
            # Indexing below requires a 2-D array with at least two columns.
            emb_ar = np.array(embeddings[embedding_col])
            df = pd.DataFrame({
                'x': emb_ar[:, 0],
                'y': emb_ar[:, 1],
                'color': colors,
            }).dropna()
            # Build the scatter with plotly express, then move its single
            # trace into the shared subplot figure.
            scatter = px.scatter(df, x='x', y='y', color='color')
            fig.add_trace(scatter.data[0], row=1, col=subplot_col)
        return fig, None
    except Exception as e:
        return None, f"Error creating viz: {str(e)}"
# Maps each category label (shown in the "Property category" dropdown) to the
# dataset column names offered in the "Property" dropdown for that category.
# Insertion order matters: the first category and its first column are the
# defaults used by the interface.
property_groups = {
    "Basic Identifiers": [
        "dr8_id", "ra", "dec", "brickid", "objid",
        "file_name", "iauname",
    ],
    "Galaxy Morphology": [
        "smooth-or-featured_smooth_fraction",
        "smooth-or-featured_featured-or-disk_fraction",
        "smooth-or-featured_artifact_fraction",
        "disk-edge-on_yes_fraction", "disk-edge-on_no_fraction",
        "has-spiral-arms_yes_fraction", "has-spiral-arms_no_fraction",
        "bar_strong_fraction", "bar_weak_fraction", "bar_no_fraction",
        "bulge-size_dominant_fraction", "bulge-size_large_fraction",
        "bulge-size_moderate_fraction", "bulge-size_small_fraction",
        "bulge-size_none_fraction",
        "how-rounded_round_fraction", "how-rounded_in-between_fraction",
        "how-rounded_cigar-shaped_fraction",
        "edge-on-bulge_boxy_fraction", "edge-on-bulge_none_fraction",
        "edge-on-bulge_rounded_fraction",
        "spiral-winding_tight_fraction", "spiral-winding_medium_fraction",
        "spiral-winding_loose_fraction",
        "spiral-arm-count_1_fraction", "spiral-arm-count_2_fraction",
        "spiral-arm-count_3_fraction", "spiral-arm-count_4_fraction",
        "spiral-arm-count_more-than-4_fraction",
        "spiral-arm-count_cant-tell_fraction",
        "merging_none_fraction", "merging_minor-disturbance_fraction",
        "merging_major-disturbance_fraction", "merging_merger_fraction",
    ],
    "Physical Size Parameters": [
        "est_petro_th50", "est_petro_th50_kpc",
        "petro_theta", "petro_th50", "petro_th90",
        "petro_phi50", "petro_phi90", "petro_ba50", "petro_ba90",
        "elpetro_ba", "elpetro_phi", "elpetro_flux_r", "elpetro_theta_r",
    ],
    "Photometric Properties": [
        "mag_r_desi", "mag_g_desi", "mag_z_desi",
        "mag_f", "mag_n", "mag_u", "mag_g", "mag_r", "mag_i", "mag_z",
        "u_minus_r",
        "sersic_n", "sersic_ba", "sersic_phi",
        "elpetro_absmag_f", "elpetro_absmag_n", "elpetro_absmag_u",
        "elpetro_absmag_g", "elpetro_absmag_r", "elpetro_absmag_i",
        "elpetro_absmag_z",
        "sersic_nmgy_f", "sersic_nmgy_n", "sersic_nmgy_u", "sersic_nmgy_g",
        "sersic_nmgy_r", "sersic_nmgy_i", "sersic_nmgy_z",
    ],
    "Mass and Redshift": [
        "elpetro_mass", "elpetro_mass_log",
        "redshift", "redshift_nsa", "redshift_ossy",
        "photo_z", "photo_zerr", "spec_z",
    ],
    "Star Formation Properties": [
        "fibre_sfr_avg", "fibre_sfr_entropy", "fibre_sfr_median",
        "fibre_sfr_mode", "fibre_sfr_p16", "fibre_sfr_p2p5",
        "fibre_sfr_p84", "fibre_sfr_p97p5",
        "fibre_ssfr_avg", "fibre_ssfr_entropy", "fibre_ssfr_median",
        "fibre_ssfr_mode", "fibre_ssfr_p16", "fibre_ssfr_p2p5",
        "fibre_ssfr_p84", "fibre_ssfr_p97p5",
        "total_ssfr_avg", "total_ssfr_entropy", "total_ssfr_flag",
        "total_ssfr_median", "total_ssfr_mode", "total_ssfr_p16",
        "total_ssfr_p2p5", "total_ssfr_p84", "total_ssfr_p97p5",
        "total_sfr_avg", "total_sfr_entropy", "total_sfr_flag",
        "total_sfr_median", "total_sfr_mode", "total_sfr_p16",
        "total_sfr_p2p5", "total_sfr_p84", "total_sfr_p97p5",
    ],
    "AGN Properties": [
        "log_l_oiii", "fwhm", "e_fwhm", "equiv_width", "log_l_ha",
        "log_m_bh", "upper_e_log_m_bh", "lower_e_log_m_bh",
        "log_bolometric_l",
    ],
    "HI Properties": [
        "W50", "sigW", "W20",
        "HIflux", "sigflux", "SNR", "RMS",
        "Dist", "sigDist", "logMH", "siglogMH",
    ],
    "PhotoZ Catalog": [
        "photoz_id", "ra_photoz", "dec_photoz",
        "mag_abs_g_photoz", "mag_abs_r_photoz", "mag_abs_z_photoz",
        "mass_inf_photoz", "mass_med_photoz", "mass_sup_photoz",
        "sfr_inf_photoz", "sfr_sup_photoz",
        "ssfr_inf_photoz", "ssfr_med_photoz", "ssfr_sup_photoz",
        "sky_separation_arcsec_from_photoz",
    ],
}
# Define the Gradio interface
with gr.Blocks(title="Galaxy embeddings") as demo:
    gr.Markdown("# Sparse galaxy embeddings")

    # Category selected when the page first loads; its properties seed the
    # "Property" dropdown.
    _default_group = next(iter(property_groups))

    # NOTE(review): the source's indentation was lost; the four inputs are
    # assumed to share one row, with the button below it — confirm the layout.
    with gr.Row():
        split_input = gr.Dropdown(
            label="Split",
            value="test",
            choices=["test", "validation"]
        )
        group_dropdown = gr.Dropdown(
            label="Property category",
            choices=list(property_groups.keys()),
            value=_default_group
        )
        color_col = gr.Dropdown(
            label="Property",
            choices=property_groups[_default_group],
            # Start on the first property, matching what update_properties
            # selects whenever the category changes, so that clicking the
            # button right away never submits an empty selection.
            value=property_groups[_default_group][0]
        )
        log = gr.Checkbox(
            label="Take log?",
            value=False
        )

    visualize_btn = gr.Button("Let's go!")
    # Hidden by default; run_visualization reveals it when there is an error.
    error_output = gr.Textbox(label="Errors", visible=False)

    def update_properties(group):
        """Repopulate the property dropdown for ``group`` and select its first entry."""
        return gr.update(choices=property_groups[group], value=property_groups[group][0])

    def run_visualization(split, prop, take_log):
        """Run create_visualization and surface any error message to the user.

        Returns the figure (or None) for the plot, plus an update that shows
        the error textbox with the message when one exists and hides it again
        on success.
        """
        fig, error = create_visualization(split, prop, take_log)
        return fig, gr.update(value=error or "", visible=bool(error))

    group_dropdown.change(
        fn=update_properties,
        inputs=[group_dropdown],
        outputs=[color_col]
    )

    with gr.Row():
        plot_output = gr.Plot(label="Visualization")

    visualize_btn.click(
        fn=run_visualization,
        inputs=[split_input, color_col, log],
        outputs=[plot_output, error_output]
    )

demo.launch()