File size: 6,706 Bytes
691f8f6 f81db88 691f8f6 f81db88 691f8f6 d54d01e 691f8f6 f81db88 691f8f6 4481b1f f81db88 4481b1f f81db88 4481b1f f81db88 31a75e1 4481b1f 31a75e1 4481b1f 31a75e1 4481b1f 31a75e1 8686715 4481b1f 8686715 4481b1f 8686715 4481b1f 691f8f6 4481b1f 6deb948 4481b1f 6deb948 4481b1f 6deb948 4481b1f 6deb948 4481b1f 691f8f6 4481b1f f81db88 4481b1f f81db88 4481b1f f81db88 4481b1f f81db88 4481b1f 31a75e1 691f8f6 4481b1f 691f8f6 4481b1f 3b420fc 4481b1f dcbe609 4481b1f 5cdd2ae 4481b1f 5cdd2ae 4481b1f e2972d3 4481b1f a92a26c 4481b1f a92a26c 4481b1f a92a26c f81db88 a14ec14 4481b1f efca0ea 5cdd2ae a92a26c 4481b1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import gradio as gr
import pandas as pd
import numpy as np
import umap
import json
import matplotlib.pyplot as plt
import os
import scanpy as sc
import subprocess
import sys
from io import BytesIO
from sklearn.linear_model import LogisticRegression
from huggingface_hub import hf_hub_download
def load_model_params(model_path):
    """Load model parameters from a JSON file.

    Args:
        model_path: Path of the JSON file to read.

    Returns:
        The decoded JSON document, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON interchange is UTF-8; pin the encoding so decoding does not depend
    # on the locale default of the machine running the demo.
    with open(model_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def reconstruct_classifier(model_params):
    """Rebuild a logistic-regression classifier from stored parameters.

    The estimator is never fitted here: the ``coef_``, ``intercept_`` and
    ``classes_`` attributes are assigned directly from ``model_params``.

    Args:
        model_params: Mapping with the keys ``"coef"``, ``"intercept"`` and
            ``"classes"``; each value is converted with ``np.array``.

    Returns:
        A ``LogisticRegression`` instance carrying the supplied parameters.
    """
    classifier = LogisticRegression(max_iter=1000, solver='lbfgs', multi_class='multinomial')
    # (estimator attribute, key in model_params), applied in this order.
    attribute_sources = (
        ("coef_", "coef"),
        ("intercept_", "intercept"),
        ("classes_", "classes"),
    )
    for attribute, key in attribute_sources:
        setattr(classifier, attribute, np.array(model_params[key]))
    return classifier
def save_predictions(y_pred, output_path):
    """Write predictions to ``output_path`` as CSV.

    The file contains one prediction per row, with neither a header line nor
    an index column.

    Args:
        y_pred: Predicted labels, passed as the data of a one-column DataFrame.
        output_path: Destination handed to ``DataFrame.to_csv``.
    """
    column_names = ["predicted_cell_type"]
    predictions = pd.DataFrame(data=y_pred, columns=column_names)
    predictions.to_csv(output_path, header=False, index=False)
def load_and_predict_with_classifier(x, model_path, output_path, save=False):
    """Predict labels for ``x`` with the classifier stored at ``model_path``.

    Args:
        x: Feature matrix passed to the classifier's ``predict`` method.
        model_path: JSON file read by ``load_model_params``.
        output_path: CSV destination; only used when ``save`` is true.
        save: When true, also write the predictions via ``save_predictions``.

    Returns:
        The predictions returned by the classifier's ``predict`` method.
    """
    classifier = reconstruct_classifier(load_model_params(model_path))
    predictions = classifier.predict(x)
    if not save:
        return predictions
    save_predictions(predictions, output_path)
    return predictions
def plot_umap(adata):
    """Generate a UMAP plot from the provided AnnData object.

    A 2-D UMAP is fitted on ``adata.obsm["X_uce"]`` and drawn as a scatter
    plot coloured by the categories of ``adata.obs["cell_type"]``.

    Args:
        adata: Object exposing ``obs["cell_type"]`` and ``obsm["X_uce"]``.

    Returns:
        The rendered figure as the image array produced by ``plt.imread`` on
        the PNG bytes of the plot.

    Raises:
        KeyError: If ``"cell_type"`` or ``"X_uce"`` is missing from ``adata``.
    """
    labels = pd.Categorical(adata.obs["cell_type"])
    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
    embedding = reducer.fit_transform(adata.obsm["X_uce"])

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        scatter = ax.scatter(embedding[:, 0], embedding[:, 1], c=labels.codes,
                             cmap='Set1', s=50, alpha=0.6)
        # Ask the scatter itself which colour each category code was drawn
        # with. Recomputing it as Set1(i / n_categories) does not match the
        # normalisation the scatter applies to the codes, so legend entries
        # would show different colours than the points they describe.
        handles = [
            plt.Line2D([0], [0], marker='o', color='w', label=cell_type,
                       markerfacecolor=scatter.cmap(scatter.norm(code)), markersize=10)
            for code, cell_type in enumerate(labels.categories)
        ]
        ax.legend(handles=handles, title='Cell Type')
        ax.set_title('UMAP projection of the data')
        ax.set_xlabel('UMAP1')
        ax.set_ylabel('UMAP2')

        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each call would leak one figure for the lifetime of the process.
        plt.close(fig)

    buf.seek(0)
    return plt.imread(buf, format='png')
def toggle_file_input(default_dataset):
    """Enable the file input only when no default dataset is chosen.

    Args:
        default_dataset: Current value of the default-dataset dropdown; the
            string ``"None"`` means that no default dataset is selected.

    Returns:
        A ``gr.update`` setting ``interactive`` to true for ``"None"`` and to
        false for any other value.
    """
    upload_allowed = default_dataset == "None"
    return gr.update(interactive=upload_allowed)
def run_uce_model(input_file_path, model_dir, model_loc):
    """Run the UCE evaluation script on an AnnData file.

    ``eval_single_anndata.py`` inside ``model_dir`` is executed with the
    current Python interpreter as a child process.

    Args:
        input_file_path: Value passed as ``--adata_path``.
        model_dir: Directory containing the script; also passed as ``--dir``.
        model_loc: Value passed as ``--model_loc``.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status.
    """
    script_path = os.path.join(model_dir, 'eval_single_anndata.py')
    cli_options = {
        '--adata_path': input_file_path,
        '--dir': model_dir,
        '--model_loc': model_loc,
    }
    argv = [sys.executable, script_path]
    for flag, value in cli_options.items():
        argv += [flag, value]
    subprocess.run(argv, check=True)
def main(input_file_path, species, default_dataset):
    """Main function to execute the demo logic.

    Ensures the UCE repository is available, resolves the input file (a
    bundled default dataset or the user's upload), runs the UCE model on it,
    classifies the resulting embeddings and renders a UMAP plot.

    Args:
        input_file_path: Uploaded file as delivered by the file input; ignored
            when ``default_dataset`` names a bundled dataset.
        species: Value of the species dropdown.
            NOTE(review): not used anywhere in this function - confirm whether
            it should be forwarded to the UCE script.
        default_dataset: Name of a bundled dataset; any other value means the
            uploaded file is used.

    Returns:
        A tuple ``(image, embeddings_path, predictions_path)``: the UMAP image
        from ``plot_umap``, the path of the ``*_uce_adata.h5ad`` file and the
        path of the predictions CSV.

    Raises:
        gr.Error: If neither an upload nor a default dataset was provided.
        subprocess.CalledProcessError: If cloning the repository or running
            the UCE script fails.
    """
    # Clone the UCE repository and set paths
    repo_url = 'https://github.com/minwoosun/UCE.git'
    repo_dir = '/home/user/app/UCE'
    if not os.path.exists(repo_dir):
        # Name the target directory explicitly; otherwise git clones into
        # "<current working directory>/UCE", which need not be repo_dir.
        subprocess.run(['git', 'clone', repo_url, repo_dir], check=True)
    # This function runs once per button click; do not grow sys.path each time.
    if repo_dir not in sys.path:
        sys.path.append(repo_dir)

    # Handle default datasets: download only the one that was selected rather
    # than fetching every bundled dataset on each call.
    default_dataset_files = {
        "PBMC 100 cells": "100_pbmcs_proc_subset.h5ad",
        "PBMC 1000 cells": "1k_pbmcs_proc_subset.h5ad",
    }
    if default_dataset in default_dataset_files:
        input_file_path = hf_hub_download(
            repo_id="minwoosun/uce-misc",
            filename=default_dataset_files[default_dataset],
        )
    elif input_file_path is None:
        raise gr.Error("Please upload a .h5ad file or select a default dataset.")
    else:
        # Presumably some Gradio versions pass a file-like object exposing the
        # path as `.name` instead of a plain string - TODO confirm against the
        # deployed Gradio version. Plain strings pass through unchanged.
        input_file_path = getattr(input_file_path, "name", input_file_path)

    # Run UCE model
    run_uce_model(input_file_path, repo_dir, 'minwoosun/uce-100m')

    # Output files are named after the input file's base name (no extension).
    stem = os.path.splitext(os.path.basename(input_file_path))[0]
    embeddings_path = os.path.join(repo_dir, f"{stem}_uce_adata.h5ad")
    pred_file = os.path.join(repo_dir, f"{stem}_predictions.csv")

    # Load UCE embeddings and perform classification
    adata = sc.read_h5ad(embeddings_path)
    model_path = hf_hub_download(
        repo_id="minwoosun/uce-misc",
        filename="tabula_sapiens_v1_logistic_regression_model_weights.json",
    )
    # save=True writes the predictions to pred_file; the returned array is not
    # needed here.
    load_and_predict_with_classifier(adata.obsm['X_uce'], model_path, pred_file, save=True)

    # Generate UMAP plot
    img = plot_umap(adata)
    return img, embeddings_path, pred_file
# Gradio UI
def create_demo():
    """Build the Gradio interface and launch it.

    The page consists of an HTML header, the three inputs (file upload,
    species, default dataset), a "Run" button and the three outputs (UMAP
    image, embeddings file, predictions file). Clicking the button calls
    ``main`` with the input values.
    """
    # Header markup, assembled line by line; the empty first and last entries
    # produce the leading and trailing newline of the markdown text.
    header_markup = "\n".join([
        '',
        '<div style="text-align:center; margin-bottom:20px;">',
        '<h1>UCE 100M Demo 🦠</h1>',
        '<h2>Universal Cell Embeddings: Zero-Shot Cell-Type Classification in Action!</h2>',
        '<div style="margin-top:10px;">',
        '<a href="https://github.com/minwoosun/UCE"><img src="https://badges.aleen42.com/src/github.svg" alt="GitHub"></a>',
        '<a href="https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1"><img src="https://img.shields.io/badge/bioRxiv-2023.11.28.568918-green?style=plastic" alt="Paper"></a>',
        '<a href="https://colab.research.google.com/drive/1opud0BVWr76IM8UnGgTomVggui_xC4p0?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>',
        '</div>',
        '<p>Upload a `.h5ad` single cell gene expression file or select the species to generate UMAP projections and download the embeddings.</p>',
        '</div>',
        '',
    ])
    species_choices = ["human", "mouse"]
    dataset_choices = ["None", "PBMC 100 cells", "PBMC 1000 cells"]

    with gr.Blocks() as demo:
        gr.Markdown(header_markup)

        # Inputs
        upload = gr.File(label="Upload a .h5ad single cell gene expression file or select a default dataset below")
        species_choice = gr.Dropdown(choices=species_choices, label="Select species")
        dataset_choice = gr.Dropdown(choices=dataset_choices, label="Select default dataset")
        # Selecting a default dataset toggles whether the upload box is usable.
        dataset_choice.change(fn=toggle_file_input, inputs=[dataset_choice], outputs=[upload])

        # Outputs
        run_button = gr.Button("Run")
        with gr.Row():
            umap_image = gr.Image(type="numpy", label="UMAP of UCE Embeddings")
            embeddings_file = gr.File(label="Download embeddings")
            predictions_file = gr.File(label="Download predictions")

        # Run the function on button click
        run_button.click(
            fn=main,
            inputs=[upload, species_choice, dataset_choice],
            outputs=[umap_image, embeddings_file, predictions_file],
        )

    demo.launch()
# Build and launch the demo only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    create_demo()
|