# Spaces: Sleeping -- page-status text captured along with the source; not part of the program.
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import plotly.express as px | |
import plotly.subplots as sp | |
from datasets import load_dataset | |
import umap | |
# Load dataset function
def load_dataset_from_hub(dataset_name: str, split: str = "test") -> tuple:
    """Load one split of a dataset, reporting failure as a message.

    Args:
        dataset_name: Dataset identifier passed straight to ``load_dataset``.
        split: Name of the split to load.

    Returns:
        ``(dataset, None)`` on success, or ``(None, message)`` when loading
        raises. ``message`` is never empty, so callers may test it for
        truthiness to detect failure.
    """
    try:
        return load_dataset(dataset_name, split=split), None
    except Exception as e:
        # Deliberately broad: any failure is turned into text for the caller
        # to display. Some exceptions stringify to "", which would look like
        # "no error"; fall back to the repr in that case.
        return None, str(e) or repr(e)
# Create visualization function
def create_visualization(split, color_col, log):
    """Build a three-panel scatter figure of embeddings coloured by one column.

    Loads the requested split of "Smith42/galaxies_with_embeddings" and draws
    one scatter subplot per embedding column, using the first two components
    of each embedding as x/y and ``color_col`` as the marker colour.

    Args:
        split: Dataset split name handed to the loader.
        color_col: Name of the dataset column used to colour the points; its
            values must be convertible to float.
        log: When truthy, colour by the natural logarithm of the column.

    Returns:
        ``(figure, None)`` on success, or ``(None, message)`` when loading
        the dataset or building the figure fails.
    """
    # Load the dataset
    dataset, error = load_dataset_from_hub("Smith42/galaxies_with_embeddings", split)
    # Check the dataset as well as the message: a failure reported with an
    # empty message must not fall through into the plotting code.
    if error or dataset is None:
        return None, f"Error loading dataset: {error or 'unknown error'}"

    try:
        embedding_cols = ["p16k00_pca", "p16k01_pca", "p16k10_pca"]
        panel_titles = ["k = 0%", "k = 1%", "k = 10%"]

        # Extract embeddings and color values
        embeddings = dataset.select_columns(embedding_cols)
        colors = np.array(dataset[color_col], dtype=float)
        if log:
            # log(0) is -inf and log(<0) is NaN. Map every non-finite result
            # to NaN so dropna() below removes those points consistently,
            # and silence the warnings numpy would otherwise emit.
            with np.errstate(divide="ignore", invalid="ignore"):
                colors = np.log(colors)
            colors[~np.isfinite(colors)] = np.nan

        fig = sp.make_subplots(cols=len(embedding_cols), subplot_titles=panel_titles)
        for panel, embedding_col in enumerate(embedding_cols, start=1):
            emb_ar = np.array(embeddings[embedding_col])
            # Only the first two embedding components are plotted.
            df = pd.DataFrame({
                "x": emb_ar[:, 0],
                "y": emb_ar[:, 1],
                "color": colors,
            }).dropna()
            # px.scatter is used only to build the trace; its first trace is
            # copied into the matching cell of the subplot grid.
            scatter = px.scatter(df, x="x", y="y", color="color")
            fig.add_trace(scatter.data[0], row=1, col=panel)
        return fig, None
    except Exception as e:
        # Broad on purpose: any failure is reported as text, not raised.
        return None, f"Error creating viz: {str(e)}"
# Dataset column names offered in the UI, grouped by category.
# Key order matters: the first category is the default selection of the
# category dropdown, and each category's first column is selected whenever
# the category changes.
# NOTE(review): some "Basic Identifiers" entries (e.g. "file_name", "iauname")
# look non-numeric, while the plotting code casts the chosen column to float
# -- confirm they are meant to be selectable.
property_groups = {
    "Basic Identifiers": [
        "dr8_id", "ra", "dec", "brickid", "objid", "file_name", "iauname",
    ],
    "Galaxy Morphology": [
        "smooth-or-featured_smooth_fraction",
        "smooth-or-featured_featured-or-disk_fraction",
        "smooth-or-featured_artifact_fraction",
        "disk-edge-on_yes_fraction",
        "disk-edge-on_no_fraction",
        "has-spiral-arms_yes_fraction",
        "has-spiral-arms_no_fraction",
        "bar_strong_fraction",
        "bar_weak_fraction",
        "bar_no_fraction",
        "bulge-size_dominant_fraction",
        "bulge-size_large_fraction",
        "bulge-size_moderate_fraction",
        "bulge-size_small_fraction",
        "bulge-size_none_fraction",
        "how-rounded_round_fraction",
        "how-rounded_in-between_fraction",
        "how-rounded_cigar-shaped_fraction",
        "edge-on-bulge_boxy_fraction",
        "edge-on-bulge_none_fraction",
        "edge-on-bulge_rounded_fraction",
        "spiral-winding_tight_fraction",
        "spiral-winding_medium_fraction",
        "spiral-winding_loose_fraction",
        "spiral-arm-count_1_fraction",
        "spiral-arm-count_2_fraction",
        "spiral-arm-count_3_fraction",
        "spiral-arm-count_4_fraction",
        "spiral-arm-count_more-than-4_fraction",
        "spiral-arm-count_cant-tell_fraction",
        "merging_none_fraction",
        "merging_minor-disturbance_fraction",
        "merging_major-disturbance_fraction",
        "merging_merger_fraction",
    ],
    "Physical Size Parameters": [
        "est_petro_th50", "est_petro_th50_kpc",
        "petro_theta", "petro_th50", "petro_th90",
        "petro_phi50", "petro_phi90", "petro_ba50", "petro_ba90",
        "elpetro_ba", "elpetro_phi", "elpetro_flux_r", "elpetro_theta_r",
    ],
    "Photometric Properties": [
        "mag_r_desi", "mag_g_desi", "mag_z_desi",
        "mag_f", "mag_n", "mag_u", "mag_g", "mag_r", "mag_i", "mag_z",
        "u_minus_r",
        "sersic_n", "sersic_ba", "sersic_phi",
        "elpetro_absmag_f", "elpetro_absmag_n", "elpetro_absmag_u",
        "elpetro_absmag_g", "elpetro_absmag_r", "elpetro_absmag_i",
        "elpetro_absmag_z",
        "sersic_nmgy_f", "sersic_nmgy_n", "sersic_nmgy_u", "sersic_nmgy_g",
        "sersic_nmgy_r", "sersic_nmgy_i", "sersic_nmgy_z",
    ],
    "Mass and Redshift": [
        "elpetro_mass", "elpetro_mass_log",
        "redshift", "redshift_nsa", "redshift_ossy",
        "photo_z", "photo_zerr", "spec_z",
    ],
    "Star Formation Properties": [
        "fibre_sfr_avg", "fibre_sfr_entropy", "fibre_sfr_median",
        "fibre_sfr_mode", "fibre_sfr_p16", "fibre_sfr_p2p5",
        "fibre_sfr_p84", "fibre_sfr_p97p5",
        "fibre_ssfr_avg", "fibre_ssfr_entropy", "fibre_ssfr_median",
        "fibre_ssfr_mode", "fibre_ssfr_p16", "fibre_ssfr_p2p5",
        "fibre_ssfr_p84", "fibre_ssfr_p97p5",
        "total_ssfr_avg", "total_ssfr_entropy", "total_ssfr_flag",
        "total_ssfr_median", "total_ssfr_mode", "total_ssfr_p16",
        "total_ssfr_p2p5", "total_ssfr_p84", "total_ssfr_p97p5",
        "total_sfr_avg", "total_sfr_entropy", "total_sfr_flag",
        "total_sfr_median", "total_sfr_mode", "total_sfr_p16",
        "total_sfr_p2p5", "total_sfr_p84", "total_sfr_p97p5",
    ],
    "AGN Properties": [
        "log_l_oiii", "fwhm", "e_fwhm", "equiv_width", "log_l_ha",
        "log_m_bh", "upper_e_log_m_bh", "lower_e_log_m_bh",
        "log_bolometric_l",
    ],
    "HI Properties": [
        "W50", "sigW", "W20", "HIflux", "sigflux", "SNR", "RMS",
        "Dist", "sigDist", "logMH", "siglogMH",
    ],
    "PhotoZ Catalog": [
        "photoz_id", "ra_photoz", "dec_photoz",
        "mag_abs_g_photoz", "mag_abs_r_photoz", "mag_abs_z_photoz",
        "mass_inf_photoz", "mass_med_photoz", "mass_sup_photoz",
        "sfr_inf_photoz", "sfr_sup_photoz",
        "ssfr_inf_photoz", "ssfr_med_photoz", "ssfr_sup_photoz",
        "sky_separation_arcsec_from_photoz",
    ],
}
# Define the Gradio interface
# NOTE(review): the original indentation was lost, so which components sit
# inside each gr.Row() is reconstructed here -- confirm against the intended
# layout.
with gr.Blocks(title="Galaxy embeddings") as demo:
    gr.Markdown("# Sparse galaxy embeddings")

    # The first category (in dict order) seeds both dropdowns.
    default_group = next(iter(property_groups))

    with gr.Row():
        split_input = gr.Dropdown(
            label="Split",
            value="test",
            choices=["test", "validation"],
        )
        group_dropdown = gr.Dropdown(
            label="Property category",
            choices=list(property_groups.keys()),
            value=default_group,
        )
        color_col = gr.Dropdown(
            label="Property",
            choices=property_groups[default_group],
            # Start with a selection so the button works before the user
            # touches this dropdown; mirrors what update_properties picks.
            value=property_groups[default_group][0],
        )
        log = gr.Checkbox(
            label="Take log?",
            value=False,
        )

    visualize_btn = gr.Button("Let's go!")
    # Hidden until there is something to report; see run_visualization.
    error_output = gr.Textbox(label="Errors", visible=False)

    def update_properties(group):
        """Repopulate the property dropdown for the chosen category."""
        return gr.update(choices=property_groups[group], value=property_groups[group][0])

    def run_visualization(split, prop, take_log):
        """Build the figure and reveal the error box only when a message exists."""
        figure, message = create_visualization(split, prop, take_log)
        # The textbox starts hidden, so its visibility must be switched on
        # explicitly or error messages would never reach the user.
        return figure, gr.update(value=message or "", visible=bool(message))

    group_dropdown.change(
        fn=update_properties,
        inputs=[group_dropdown],
        outputs=[color_col],
    )

    with gr.Row():
        plot_output = gr.Plot(label="Visualization")

    visualize_btn.click(
        fn=run_visualization,
        inputs=[split_input, color_col, log],
        outputs=[plot_output, error_output],
    )

demo.launch()