uce_demo / app.py
minwoosun's picture
Modularize code
4481b1f verified
raw
history blame
6.71 kB
import gradio as gr
import pandas as pd
import numpy as np
import umap
import json
import matplotlib.pyplot as plt
import os
import scanpy as sc
import subprocess
import sys
from io import BytesIO
from sklearn.linear_model import LogisticRegression
from huggingface_hub import hf_hub_download
def load_model_params(model_path):
    """Load model parameters from a JSON file.

    Args:
        model_path: Path of the JSON file holding the serialized parameters.

    Returns:
        The decoded JSON document (whatever ``json.load`` produces for it).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's locale
    # encoding, which would mis-decode non-ASCII class labels.
    with open(model_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def reconstruct_classifier(model_params):
    """Rebuild a logistic-regression classifier from stored parameters.

    The estimator is never fitted; its ``coef_``, ``intercept_`` and
    ``classes_`` attributes are assigned directly from the ``"coef"``,
    ``"intercept"`` and ``"classes"`` entries of *model_params*.

    Args:
        model_params: Mapping with the keys ``"coef"``, ``"intercept"`` and
            ``"classes"``.

    Returns:
        A ``LogisticRegression`` instance carrying the stored parameters.

    Raises:
        KeyError: If one of the three keys is missing.
    """
    classifier = LogisticRegression(max_iter=1000, solver='lbfgs', multi_class='multinomial')
    # (estimator attribute, key in the serialized parameters)
    attribute_sources = (
        ("coef_", "coef"),
        ("intercept_", "intercept"),
        ("classes_", "classes"),
    )
    for attribute, key in attribute_sources:
        setattr(classifier, attribute, np.array(model_params[key]))
    return classifier
def save_predictions(y_pred, output_path):
    """Write predicted labels to a CSV file, one prediction per row.

    The file is written without a header line and without an index column.

    Args:
        y_pred: Predicted labels.
        output_path: Destination path of the CSV file.
    """
    pd.DataFrame(data=y_pred, columns=["predicted_cell_type"]).to_csv(
        output_path,
        header=False,
        index=False,
    )
def load_and_predict_with_classifier(x, model_path, output_path, save=False):
    """Predict labels for *x* with the classifier stored at *model_path*.

    Args:
        x: Feature matrix passed to the classifier's ``predict``.
        model_path: JSON file containing the classifier parameters.
        output_path: CSV destination; only used when *save* is true.
        save: Whether to also write the predictions to *output_path*.

    Returns:
        The predictions returned by the classifier.
    """
    classifier = reconstruct_classifier(load_model_params(model_path))
    predictions = classifier.predict(x)
    if not save:
        return predictions
    save_predictions(predictions, output_path)
    return predictions
def plot_umap(adata):
    """Generate a UMAP plot from the provided AnnData object.

    The embedding stored in ``adata.obsm["X_uce"]`` is projected to two
    dimensions and each point is coloured by ``adata.obs["cell_type"]``.

    Args:
        adata: Object exposing ``obs["cell_type"]`` and ``obsm["X_uce"]``.

    Returns:
        The rendered figure as an image array, as decoded by ``plt.imread``
        from an in-memory PNG.

    Raises:
        KeyError: If ``"cell_type"`` or ``"X_uce"`` is missing from *adata*.
    """
    labels = pd.Categorical(adata.obs["cell_type"])
    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
    embedding = reducer.fit_transform(adata.obsm["X_uce"])

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        scatter = ax.scatter(
            embedding[:, 0], embedding[:, 1],
            c=labels.codes, cmap='Set1', s=50, alpha=0.6,
        )
        # Take the legend colours from the scatter's own norm + colormap.
        # Sampling the colormap at i / n_categories (as before) does not
        # reproduce the scatter's normalisation of the codes, so legend
        # entries could show a different colour than the points they label.
        handles = [
            plt.Line2D([0], [0], marker='o', color='w', label=cell_type,
                       markerfacecolor=scatter.cmap(scatter.norm(code)), markersize=10)
            for code, cell_type in enumerate(labels.categories)
        ]
        ax.legend(handles=handles, title='Cell Type')
        ax.set_title('UMAP projection of the data')
        ax.set_xlabel('UMAP1')
        ax.set_ylabel('UMAP2')

        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each request would leave one more figure in memory.
        plt.close(fig)

    buf.seek(0)
    return plt.imread(buf, format='png')
def toggle_file_input(default_dataset):
    """Enable the file input only while no default dataset is selected.

    Args:
        default_dataset: Current value of the default-dataset dropdown; the
            string ``"None"`` means that no default dataset is chosen.

    Returns:
        A Gradio update setting the ``interactive`` flag of the file input.
    """
    upload_allowed = default_dataset == "None"
    return gr.update(interactive=upload_allowed)
def run_uce_model(input_file_path, model_dir, model_loc, species=None):
    """Run UCE model on the provided AnnData file.

    Launches ``eval_single_anndata.py`` from *model_dir* in a subprocess,
    using the interpreter that runs this application.

    Args:
        input_file_path: Path passed to the script as ``--adata_path``.
        model_dir: Directory containing ``eval_single_anndata.py``; also
            passed to the script as ``--dir``.
        model_loc: Value passed to the script as ``--model_loc``.
        species: Optional value passed as ``--species``. When omitted or
            empty, the flag is not passed and the script's own default
            applies.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status.
    """
    command = [
        sys.executable,
        os.path.join(model_dir, 'eval_single_anndata.py'),
        '--adata_path', input_file_path,
        '--dir', model_dir,
        '--model_loc', model_loc,
    ]
    if species:
        command += ['--species', species]
    subprocess.run(command, check=True)


def main(input_file_path, species, default_dataset):
    """Main function to execute the demo logic.

    Makes sure the UCE repository is available, runs the UCE model on the
    selected dataset, classifies the resulting embeddings and renders a
    UMAP plot.

    Args:
        input_file_path: Uploaded file, either a path or an object whose
            ``name`` attribute is the path. Ignored when *default_dataset*
            names a bundled dataset.
        species: Species chosen in the UI. Forwarded to the UCE script for
            uploaded files; may be ``None`` when nothing was selected.
        default_dataset: Name of a bundled dataset, or any other value
            (e.g. ``"None"``) to use the uploaded file.

    Returns:
        A tuple ``(image, embeddings_path, predictions_path)``.

    Raises:
        gr.Error: If neither an upload nor a bundled dataset was provided.
        subprocess.CalledProcessError: If cloning the repository or running
            the UCE script fails.
    """
    # Clone the UCE repository and set paths
    repo_url = 'https://github.com/minwoosun/UCE.git'
    repo_dir = '/home/user/app/UCE'
    if not os.path.exists(repo_dir):
        # Clone into repo_dir explicitly so the checkout does not depend on
        # the current working directory.
        subprocess.run(['git', 'clone', repo_url, repo_dir], check=True)
    if repo_dir not in sys.path:
        # main() runs once per request; avoid growing sys.path every time.
        sys.path.append(repo_dir)

    # Handle default datasets. Only the selected dataset is fetched, instead
    # of resolving every bundled file on each call.
    default_dataset_files = {
        "PBMC 100 cells": "100_pbmcs_proc_subset.h5ad",
        "PBMC 1000 cells": "1k_pbmcs_proc_subset.h5ad",
    }
    if default_dataset in default_dataset_files:
        input_file_path = hf_hub_download(
            repo_id="minwoosun/uce-misc",
            filename=default_dataset_files[default_dataset],
        )
        # Bundled datasets keep running with the script's default species,
        # exactly as before; the species dropdown applies to uploads.
        species = None
    elif input_file_path is None:
        raise gr.Error("Please upload a .h5ad file or select a default dataset.")
    elif not isinstance(input_file_path, (str, os.PathLike)):
        # Some Gradio versions hand over a temp-file wrapper, not a path.
        input_file_path = input_file_path.name

    # Run UCE model
    run_uce_model(input_file_path, repo_dir, 'minwoosun/uce-100m', species=species)

    # Load UCE embeddings and perform classification
    stem = os.path.splitext(os.path.basename(input_file_path))[0]
    embeddings_path = os.path.join(repo_dir, f"{stem}_uce_adata.h5ad")
    adata = sc.read_h5ad(embeddings_path)
    x = adata.obsm['X_uce']
    model_path = hf_hub_download(
        repo_id="minwoosun/uce-misc",
        filename="tabula_sapiens_v1_logistic_regression_model_weights.json",
    )
    pred_file = os.path.join(repo_dir, f"{stem}_predictions.csv")
    load_and_predict_with_classifier(x, model_path, pred_file, save=True)

    # Generate UMAP plot
    img = plot_umap(adata)
    return img, embeddings_path, pred_file
# Gradio UI
def create_demo():
    """Build the Gradio Blocks interface and launch it.

    The page shows a header, the inputs (file upload, species and default
    dataset dropdowns), a "Run" button, and a row with the UMAP image and
    the two downloadable result files. Clicking the button calls ``main``.
    """
    header_html = """
<div style="text-align:center; margin-bottom:20px;">
<h1>UCE 100M Demo 🦠</h1>
<h2>Universal Cell Embeddings: Zero-Shot Cell-Type Classification in Action!</h2>
<div style="margin-top:10px;">
<a href="https://github.com/minwoosun/UCE"><img src="https://badges.aleen42.com/src/github.svg" alt="GitHub"></a>
<a href="https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1"><img src="https://img.shields.io/badge/bioRxiv-2023.11.28.568918-green?style=plastic" alt="Paper"></a>
<a href="https://colab.research.google.com/drive/1opud0BVWr76IM8UnGgTomVggui_xC4p0?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
</div>
<p>Upload a `.h5ad` single cell gene expression file or select the species to generate UMAP projections and download the embeddings.</p>
</div>
"""
    species_choices = ["human", "mouse"]
    dataset_choices = ["None", "PBMC 100 cells", "PBMC 1000 cells"]

    with gr.Blocks() as demo:
        gr.Markdown(header_html)

        # Inputs
        upload = gr.File(
            label="Upload a .h5ad single cell gene expression file or select a default dataset below",
        )
        species_dropdown = gr.Dropdown(choices=species_choices, label="Select species")
        dataset_dropdown = gr.Dropdown(choices=dataset_choices, label="Select default dataset")
        # Enable or disable the upload widget whenever the dataset choice changes.
        dataset_dropdown.change(
            toggle_file_input,
            inputs=[dataset_dropdown],
            outputs=[upload],
        )

        # Outputs
        run_btn = gr.Button("Run")
        with gr.Row():
            umap_image = gr.Image(type="numpy", label="UMAP of UCE Embeddings")
            embeddings_file = gr.File(label="Download embeddings")
            predictions_file = gr.File(label="Download predictions")

        # Run the function on button click
        run_btn.click(
            fn=main,
            inputs=[upload, species_dropdown, dataset_dropdown],
            outputs=[umap_image, embeddings_file, predictions_file],
        )

    demo.launch()
# Build and launch the demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    create_demo()