"""Gradio app that plots galaxy embeddings coloured by a catalogue property.

For a chosen dataset split and property column, three scatter plots are drawn
side by side (one per embedding column) with points coloured by the property.
"""

import functools

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.subplots as sp
from datasets import load_dataset
import umap


@functools.lru_cache(maxsize=4)
def _load_split(dataset_name, split):
    """Load one split of a dataset, memoising successful loads.

    ``lru_cache`` does not store raised exceptions, so a failed load is
    retried on the next call instead of being remembered.
    """
    return load_dataset(dataset_name, split=split)


def load_dataset_from_hub(dataset_name, split="test"):
    """Load ``split`` of ``dataset_name`` without raising.

    Args:
        dataset_name: Name passed to ``datasets.load_dataset``.
        split: Split to load; defaults to ``"test"``.

    Returns:
        ``(dataset, None)`` on success, or ``(None, message)`` where
        ``message`` is the string form of the exception on failure.
    """
    try:
        return _load_split(dataset_name, split), None
    except Exception as e:  # UI boundary: report the failure, do not crash
        return None, str(e)


def create_visualization(split, color_col, log):
    """Build the three-panel embedding scatter figure.

    Args:
        split: Dataset split to load.
        color_col: Name of the dataset column used to colour the points.
        log: When true, the colour values are replaced by their natural log.

    Returns:
        ``(figure, None)`` on success, or ``(None, message)`` on failure.
    """
    if not color_col:
        # The property dropdown starts without a selection.
        return None, "Error creating viz: no property selected"

    dataset, error = load_dataset_from_hub("Smith42/galaxies_with_embeddings", split)
    if error:
        return None, f"Error loading dataset: {error}"

    try:
        embedding_cols = ["p16k00_pca", "p16k01_pca", "p16k10_pca"]

        # Extract embeddings and colour values.
        embeddings = dataset.select_columns(embedding_cols)
        colors = np.array(dataset[color_col], dtype=float)
        if log:
            with np.errstate(divide="ignore", invalid="ignore"):
                colors = np.log(colors)
            # log(0) is -inf and log(negative) is nan; turn both into NaN so
            # the dropna() below removes those rows rather than letting an
            # infinite value stretch the colour scale.
            colors[~np.isfinite(colors)] = np.nan

        fig = sp.make_subplots(cols=3, subplot_titles=["k = 0%", "k = 1%", "k = 10%"])

        # Subplot columns are 1-based.
        for col, embedding_col in enumerate(embedding_cols, start=1):
            emb_ar = np.array(embeddings[embedding_col])
            # Only the first two components of each embedding are plotted.
            df = pd.DataFrame({
                'x': emb_ar[:, 0],
                'y': emb_ar[:, 1],
                'color': colors,
            }).dropna()

            scatter = px.scatter(df, x='x', y='y', color='color')
            fig.add_trace(scatter.data[0], row=1, col=col)

        return fig, None
    except Exception as e:  # UI boundary: surface the message to the user
        return None, f"Error creating viz: {str(e)}"


# Dataset property columns, grouped by category for the two linked dropdowns.
property_groups = {
    "Basic Identifiers": [
        "dr8_id", "ra", "dec", "brickid", "objid", "file_name", "iauname"
    ],
    "Galaxy Morphology": [
        "smooth-or-featured_smooth_fraction",
        "smooth-or-featured_featured-or-disk_fraction",
        "smooth-or-featured_artifact_fraction",
        "disk-edge-on_yes_fraction",
        "disk-edge-on_no_fraction",
        "has-spiral-arms_yes_fraction",
        "has-spiral-arms_no_fraction",
        "bar_strong_fraction",
        "bar_weak_fraction",
        "bar_no_fraction",
        "bulge-size_dominant_fraction",
        "bulge-size_large_fraction",
        "bulge-size_moderate_fraction",
        "bulge-size_small_fraction",
        "bulge-size_none_fraction",
        "how-rounded_round_fraction",
        "how-rounded_in-between_fraction",
        "how-rounded_cigar-shaped_fraction",
        "edge-on-bulge_boxy_fraction",
        "edge-on-bulge_none_fraction",
        "edge-on-bulge_rounded_fraction",
        "spiral-winding_tight_fraction",
        "spiral-winding_medium_fraction",
        "spiral-winding_loose_fraction",
        "spiral-arm-count_1_fraction",
        "spiral-arm-count_2_fraction",
        "spiral-arm-count_3_fraction",
        "spiral-arm-count_4_fraction",
        "spiral-arm-count_more-than-4_fraction",
        "spiral-arm-count_cant-tell_fraction",
        "merging_none_fraction",
        "merging_minor-disturbance_fraction",
        "merging_major-disturbance_fraction",
        "merging_merger_fraction"
    ],
    "Physical Size Parameters": [
        "est_petro_th50", "est_petro_th50_kpc", "petro_theta", "petro_th50",
        "petro_th90", "petro_phi50", "petro_phi90", "petro_ba50", "petro_ba90",
        "elpetro_ba", "elpetro_phi", "elpetro_flux_r", "elpetro_theta_r"
    ],
    "Photometric Properties": [
        "mag_r_desi", "mag_g_desi", "mag_z_desi", "mag_f", "mag_n", "mag_u",
        "mag_g", "mag_r", "mag_i", "mag_z", "u_minus_r", "sersic_n",
        "sersic_ba", "sersic_phi", "elpetro_absmag_f", "elpetro_absmag_n",
        "elpetro_absmag_u", "elpetro_absmag_g", "elpetro_absmag_r",
        "elpetro_absmag_i", "elpetro_absmag_z", "sersic_nmgy_f",
        "sersic_nmgy_n", "sersic_nmgy_u", "sersic_nmgy_g", "sersic_nmgy_r",
        "sersic_nmgy_i", "sersic_nmgy_z"
    ],
    "Mass and Redshift": [
        "elpetro_mass", "elpetro_mass_log", "redshift", "redshift_nsa",
        "redshift_ossy", "photo_z", "photo_zerr", "spec_z"
    ],
    "Star Formation Properties": [
        "fibre_sfr_avg", "fibre_sfr_entropy", "fibre_sfr_median",
        "fibre_sfr_mode", "fibre_sfr_p16", "fibre_sfr_p2p5", "fibre_sfr_p84",
        "fibre_sfr_p97p5", "fibre_ssfr_avg", "fibre_ssfr_entropy",
        "fibre_ssfr_median", "fibre_ssfr_mode", "fibre_ssfr_p16",
        "fibre_ssfr_p2p5", "fibre_ssfr_p84", "fibre_ssfr_p97p5",
        "total_ssfr_avg", "total_ssfr_entropy", "total_ssfr_flag",
        "total_ssfr_median", "total_ssfr_mode", "total_ssfr_p16",
        "total_ssfr_p2p5", "total_ssfr_p84", "total_ssfr_p97p5",
        "total_sfr_avg", "total_sfr_entropy", "total_sfr_flag",
        "total_sfr_median", "total_sfr_mode", "total_sfr_p16",
        "total_sfr_p2p5", "total_sfr_p84", "total_sfr_p97p5"
    ],
    "AGN Properties": [
        "log_l_oiii", "fwhm", "e_fwhm", "equiv_width", "log_l_ha", "log_m_bh",
        "upper_e_log_m_bh", "lower_e_log_m_bh", "log_bolometric_l"
    ],
    "HI Properties": [
        "W50", "sigW", "W20", "HIflux", "sigflux", "SNR", "RMS", "Dist",
        "sigDist", "logMH", "siglogMH"
    ],
    "PhotoZ Catalog": [
        "photoz_id", "ra_photoz", "dec_photoz", "mag_abs_g_photoz",
        "mag_abs_r_photoz", "mag_abs_z_photoz", "mass_inf_photoz",
        "mass_med_photoz", "mass_sup_photoz", "sfr_inf_photoz",
        "sfr_sup_photoz", "ssfr_inf_photoz", "ssfr_med_photoz",
        "ssfr_sup_photoz", "sky_separation_arcsec_from_photoz"
    ]
}

# First category in the mapping; it seeds both dropdowns.
_default_group = next(iter(property_groups))

# Define the Gradio interface
with gr.Blocks(title="Galaxy embeddings") as demo:
    gr.Markdown("# Sparse galaxy embeddings")

    with gr.Row():
        split_input = gr.Dropdown(
            label="Split",
            value="test",
            choices=["test", "validation"]
        )
        group_dropdown = gr.Dropdown(
            label="Property category",
            choices=list(property_groups),
            value=_default_group
        )
        color_col = gr.Dropdown(
            label="Property",
            choices=property_groups[_default_group]
        )
        log = gr.Checkbox(
            label="Take log?",
            value=False
        )

    visualize_btn = gr.Button("Let's go!")
    # Hidden until there is something to report; see _run_visualization.
    error_output = gr.Textbox(label="Errors", visible=False)

    def update_properties(group):
        """Repopulate the property dropdown for ``group``, selecting its first entry."""
        return gr.update(choices=property_groups[group], value=property_groups[group][0])

    def _run_visualization(split, prop, take_log):
        """Call ``create_visualization`` and reveal the error box only on failure."""
        fig, message = create_visualization(split, prop, take_log)
        if message:
            return fig, gr.update(value=message, visible=True)
        return fig, gr.update(value="", visible=False)

    group_dropdown.change(
        fn=update_properties,
        inputs=[group_dropdown],
        outputs=[color_col]
    )

    with gr.Row():
        plot_output = gr.Plot(label="Visualization")

    visualize_btn.click(
        fn=_run_visualization,
        inputs=[split_input, color_col, log],
        outputs=[plot_output, error_output]
    )

# Launched unconditionally at module level, as in the original script.
demo.launch()