|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
import umap |
|
import json |
|
import matplotlib.pyplot as plt |
|
import os |
|
import scanpy as sc |
|
import subprocess |
|
import sys |
|
from io import BytesIO |
|
from sklearn.linear_model import LogisticRegression |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
def load_model_params(model_path):
    """Load model parameters from a JSON file.

    Args:
        model_path: Path of the JSON file to read.

    Returns:
        The decoded JSON document (whatever ``json.load`` produces for it).
    """
    # JSON is UTF-8 by specification; pin the encoding so that decoding does
    # not depend on the platform's default locale.
    with open(model_path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
|
|
def reconstruct_classifier(model_params):
    """Rebuild a logistic regression classifier from stored parameters.

    No fitting takes place: the ``"coef"``, ``"intercept"`` and ``"classes"``
    entries of ``model_params`` are converted to NumPy arrays and assigned
    directly to the estimator's ``coef_``, ``intercept_`` and ``classes_``.

    Args:
        model_params: Mapping with the keys ``"coef"``, ``"intercept"`` and
            ``"classes"``.

    Returns:
        The ``LogisticRegression`` instance carrying those attributes.
    """
    classifier = LogisticRegression(
        multi_class='multinomial',
        solver='lbfgs',
        max_iter=1000,
    )
    # Attach the stored attributes one by one, in a fixed order.
    for attribute, key in (
        ("coef_", "coef"),
        ("intercept_", "intercept"),
        ("classes_", "classes"),
    ):
        setattr(classifier, attribute, np.array(model_params[key]))
    return classifier
|
|
|
def save_predictions(y_pred, output_path):
    """Write predictions to ``output_path`` as a bare CSV file.

    The output has one prediction per row, with neither a header row nor an
    index column.

    Args:
        y_pred: Predicted labels to store.
        output_path: Destination of the CSV file.
    """
    predictions = pd.DataFrame(data=y_pred, columns=["predicted_cell_type"])
    predictions.to_csv(path_or_buf=output_path, header=False, index=False)
|
|
|
def load_and_predict_with_classifier(x, model_path, output_path, save=False):
    """Predict labels for ``x`` with the classifier stored at ``model_path``.

    Args:
        x: Feature matrix handed to the classifier's ``predict``.
        model_path: JSON file holding the classifier parameters.
        output_path: Where the predictions are written when ``save`` is true.
        save: Whether to also store the predictions as CSV.

    Returns:
        The predictions returned by the classifier.
    """
    classifier = reconstruct_classifier(load_model_params(model_path))
    predictions = classifier.predict(x)
    if not save:
        return predictions
    save_predictions(predictions, output_path)
    return predictions
|
|
|
def plot_umap(adata):
    """Generate a UMAP plot from the provided AnnData object.

    The embeddings in ``adata.obsm["X_uce"]`` are reduced to two dimensions
    with UMAP and drawn as a scatter plot coloured by ``adata.obs["cell_type"]``.

    Args:
        adata: Object exposing ``obs["cell_type"]`` (one label per cell) and
            ``obsm["X_uce"]`` (the embedding matrix passed to UMAP).

    Returns:
        The rendered PNG, decoded back into an image array by ``plt.imread``.
    """
    labels = pd.Categorical(adata.obs["cell_type"])
    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
    embedding = reducer.fit_transform(adata.obsm["X_uce"])

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        scatter = ax.scatter(
            embedding[:, 0],
            embedding[:, 1],
            c=labels.codes,
            cmap='Set1',
            s=50,
            alpha=0.6,
        )

        # The scatter maps category codes to colours through its own norm
        # (min-max scaling of the codes). The legend must go through the same
        # norm and colormap, otherwise its swatches do not match the points.
        handles = [
            plt.Line2D(
                [0], [0],
                marker='o',
                color='w',
                label=cell_type,
                markerfacecolor=scatter.cmap(scatter.norm(code)),
                markersize=10,
            )
            for code, cell_type in enumerate(labels.categories)
        ]
        ax.legend(handles=handles, title='Cell Type')
        ax.set_title('UMAP projection of the data')
        ax.set_xlabel('UMAP1')
        ax.set_ylabel('UMAP2')

        # Render to an in-memory PNG and read it back as an array.
        buf = BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        return plt.imread(buf, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate figures in memory.
        plt.close(fig)
|
|
|
def toggle_file_input(default_dataset):
    """Enable or disable the upload widget based on the dataset selection.

    Args:
        default_dataset: Current value of the default-dataset selector.

    Returns:
        A ``gr.update`` that makes the target interactive only when the
        selection is the string ``"None"``.
    """
    upload_enabled = default_dataset == "None"
    return gr.update(interactive=upload_enabled)
|
|
|
def run_uce_model(input_file_path, model_dir, model_loc):
    """Run the UCE evaluation script on one AnnData file.

    Launches ``eval_single_anndata.py`` from ``model_dir`` with the current
    Python interpreter and waits for it to finish.

    Args:
        input_file_path: Passed to the script as ``--adata_path``.
        model_dir: Directory containing the script; also passed as ``--dir``.
        model_loc: Passed to the script as ``--model_loc``.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status.
    """
    script_path = os.path.join(model_dir, 'eval_single_anndata.py')
    cli_options = {
        '--adata_path': input_file_path,
        '--dir': model_dir,
        '--model_loc': model_loc,
    }
    argv = [sys.executable, script_path]
    for flag, value in cli_options.items():
        argv += [flag, value]
    subprocess.run(argv, check=True)
|
|
|
def main(input_file_path, species, default_dataset):
    """Main function to execute the demo logic.

    Makes sure the UCE repository is available, resolves the input dataset,
    runs the UCE model on it, classifies the resulting embeddings and renders
    a UMAP plot.

    Args:
        input_file_path: The uploaded file (a path, or an object exposing the
            path as ``.name``). Ignored when a default dataset is selected.
        species: Species chosen in the UI. NOTE: it is accepted for interface
            compatibility but is currently not forwarded to the model run.
        default_dataset: Name of a bundled dataset; any other value means
            "use the uploaded file".

    Returns:
        A tuple ``(image, embeddings_path, predictions_path)``.

    Raises:
        gr.Error: If neither an upload nor a default dataset was provided.
        subprocess.CalledProcessError: If cloning the repository or running
            the model fails.
    """
    repo_url = 'https://github.com/minwoosun/UCE.git'
    repo_dir = '/home/user/app/UCE'
    if not os.path.exists(repo_dir):
        # Clone explicitly into repo_dir so the checkout lands where the rest
        # of this function looks for it, whatever the working directory is.
        subprocess.run(['git', 'clone', repo_url, repo_dir], check=True)

    # Avoid growing sys.path by one duplicate entry per request.
    if repo_dir not in sys.path:
        sys.path.append(repo_dir)

    # Only the dataset that was actually selected is downloaded.
    default_dataset_files = {
        "PBMC 100 cells": "100_pbmcs_proc_subset.h5ad",
        "PBMC 1000 cells": "1k_pbmcs_proc_subset.h5ad",
    }
    if default_dataset in default_dataset_files:
        input_file_path = hf_hub_download(
            repo_id="minwoosun/uce-misc",
            filename=default_dataset_files[default_dataset],
        )
    elif input_file_path is None:
        raise gr.Error("Please upload a .h5ad file or select a default dataset.")
    else:
        # Depending on the Gradio version, a File input may arrive as a plain
        # path or as a temp-file object carrying the path in `.name`.
        input_file_path = getattr(input_file_path, "name", input_file_path)

    run_uce_model(input_file_path, repo_dir, 'minwoosun/uce-100m')

    # All output files are named after the input file's stem inside repo_dir.
    stem = os.path.splitext(os.path.basename(input_file_path))[0]
    embeddings_file = os.path.join(repo_dir, f"{stem}_uce_adata.h5ad")
    pred_file = os.path.join(repo_dir, f"{stem}_predictions.csv")

    adata = sc.read_h5ad(embeddings_file)
    x = adata.obsm['X_uce']

    model_path = hf_hub_download(
        repo_id="minwoosun/uce-misc",
        filename="tabula_sapiens_v1_logistic_regression_model_weights.json",
    )
    # Called for its side effect: the predictions are written to pred_file.
    load_and_predict_with_classifier(x, model_path, pred_file, save=True)

    img = plot_umap(adata)

    return img, embeddings_file, pred_file
|
|
|
|
|
|
|
def create_demo() -> None:
    """Create and launch the Gradio demo.

    Builds a Blocks interface with a header, the three inputs (file upload,
    species, default dataset), a run button and three outputs (UMAP image,
    embeddings file, predictions file), wires the callbacks and then launches
    the app.
    """

    with gr.Blocks() as demo:
        # Page header: title, subtitle, badge links and a short usage hint.
        gr.Markdown("""
<div style="text-align:center; margin-bottom:20px;">
<h1>UCE 100M Demo 🦠</h1>
<h2>Universal Cell Embeddings: Zero-Shot Cell-Type Classification in Action!</h2>
<div style="margin-top:10px;">
<a href="https://github.com/minwoosun/UCE"><img src="https://badges.aleen42.com/src/github.svg" alt="GitHub"></a>
<a href="https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1"><img src="https://img.shields.io/badge/bioRxiv-2023.11.28.568918-green?style=plastic" alt="Paper"></a>
<a href="https://colab.research.google.com/drive/1opud0BVWr76IM8UnGgTomVggui_xC4p0?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
</div>
<p>Upload a `.h5ad` single cell gene expression file or select the species to generate UMAP projections and download the embeddings.</p>
</div>
""")

        # Inputs, in the order `main` expects them.
        file_input = gr.File(label="Upload a .h5ad single cell gene expression file or select a default dataset below")
        species_input = gr.Dropdown(choices=["human", "mouse"], label="Select species")
        default_dataset_input = gr.Dropdown(choices=["None", "PBMC 100 cells", "PBMC 1000 cells"], label="Select default dataset")

        # Whenever the dataset selection changes, toggle_file_input decides
        # whether the upload widget stays interactive.
        default_dataset_input.change(toggle_file_input, inputs=[default_dataset_input], outputs=[file_input])

        run_button = gr.Button("Run")
        # Outputs, in the order of the tuple returned by `main`.
        with gr.Row():
            image_output = gr.Image(type="numpy", label="UMAP of UCE Embeddings")
            file_output = gr.File(label="Download embeddings")
            pred_output = gr.File(label="Download predictions")

        # Clicking "Run" executes the whole pipeline via `main`.
        run_button.click(fn=main, inputs=[file_input, species_input, default_dataset_input], outputs=[image_output, file_output, pred_output])

    demo.launch()
|
|
|
# Build and launch the demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    create_demo()
|
|