# uce_demo / app.py
# Author: minwoosun — commit f81db88 "Incorporate classifier" (verified)
import gradio as gr
import pandas as pd
import numpy as np
import umap
import json
import matplotlib.pyplot as plt
import os
# import tempfile
import scanpy as sc
# import argparse
import subprocess
import sys
from evaluate import AnndataProcessor
from accelerate import Accelerator
from io import BytesIO
from sklearn.linear_model import LogisticRegression
from huggingface_hub import hf_hub_download
def load_and_predict_with_classifier(x, model_path, output_path, *, save=True):
    """Predict labels for ``x`` with a logistic-regression model stored as JSON.

    The JSON file must contain the keys ``"coef"``, ``"intercept"`` and
    ``"classes"``. They are assigned directly onto a fresh
    ``LogisticRegression`` instance, so no fitting happens here.

    Args:
        x: Feature matrix passed to ``LogisticRegression.predict``. In this
            app it is the UCE embedding matrix ``adata.obsm['X_uce']``.
        model_path: Path to the JSON file holding the model parameters.
        output_path: CSV file the predictions are written to when ``save``
            is true. The CSV has one label per line and no header row.
        save: Whether to write the predictions to ``output_path``.

    Returns:
        The array of predicted labels (values drawn from ``"classes"``), one
        per row of ``x``.
    """
    with open(model_path, "r", encoding="utf-8") as f:
        model_params = json.load(f)

    # Rebuild the classifier from its stored parameters. ``multi_class`` is
    # deliberately not passed: predict() does not use it, and recent
    # scikit-learn releases deprecate the argument.
    model_loaded = LogisticRegression(solver="lbfgs", max_iter=1000)
    model_loaded.coef_ = np.array(model_params["coef"])
    model_loaded.intercept_ = np.array(model_params["intercept"])
    model_loaded.classes_ = np.array(model_params["classes"])

    y_pred = model_loaded.predict(x)

    if save:
        df = pd.DataFrame(y_pred, columns=["predicted_cell_type"])
        df.to_csv(output_path, index=False, header=False)
    return y_pred
def _ensure_uce_repo(repo_dir="/home/user/app/UCE", repo_url="https://github.com/minwoosun/UCE.git"):
    """Clone the UCE repo into ``repo_dir`` once, cd into it, and add it to ``sys.path``.

    Cloning into an explicit target path keeps repeated calls idempotent. A
    bare ``git clone`` resolves against the current directory, which after the
    first call is the repo itself.
    """
    print("Current Working Directory:", os.getcwd())
    if not os.path.isdir(repo_dir):
        subprocess.run(["git", "clone", repo_url, repo_dir], check=True)
    os.chdir(repo_dir)
    print("Current Working Directory:", os.getcwd())
    # Guard so the path is not appended again on every request.
    if repo_dir not in sys.path:
        sys.path.append(repo_dir)


def _run_uce(input_file_path, species, dir_path, model_loc):
    """Run UCE's ``eval_single_anndata.py`` on ``input_file_path`` and echo its output.

    Raises:
        subprocess.CalledProcessError: if the script exits non-zero. Its
            captured stdout/stderr are printed first, so the failure is
            diagnosable.
    """
    command = [
        sys.executable,  # same interpreter/environment as this app
        os.path.join(dir_path, "eval_single_anndata.py"),
        "--adata_path", input_file_path,
        "--dir", dir_path,
        # Forward the species chosen in the UI. Assumes eval_single_anndata.py
        # accepts --species, as documented in the UCE README; confirm if the
        # fork changes its CLI.
        "--species", species,
        "--model_loc", model_loc,
    ]
    print("Running command:", command)
    print("---> RUNNING UCE")
    try:
        result = subprocess.run(command, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as err:
        # With check=True the captured output would otherwise be lost.
        print(err.stdout)
        print(err.stderr)
        raise
    print(result.stdout)
    print(result.stderr)
    print("---> FINISH UCE")


def _plot_umap(embeddings, cell_labels, legend_title="Cell Type"):
    """Return a PNG-decoded image array of a 2-D UMAP of ``embeddings`` coloured by ``cell_labels``."""
    labels = pd.Categorical(cell_labels)
    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
    embedding = reducer.fit_transform(embeddings)

    n_categories = len(labels.categories)
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Pin the colour normalisation so code i always maps to the same colour.
        # NOTE: Set1 has only 9 distinct colours, so with more categories some
        # colours are shared.
        scatter = ax.scatter(
            embedding[:, 0],
            embedding[:, 1],
            c=labels.codes,
            cmap="Set1",
            vmin=0,
            vmax=max(n_categories - 1, 1),
            s=50,
            alpha=0.6,
        )
        # Take legend colours from the scatter's own cmap + norm so they match
        # the plotted points.
        handles = [
            plt.Line2D([0], [0], marker="o", color="w", label=cell_type,
                       markerfacecolor=scatter.to_rgba(i), markersize=10)
            for i, cell_type in enumerate(labels.categories)
        ]
        ax.legend(handles=handles, title=legend_title)
        ax.set_title("UMAP projection of the data")
        ax.set_xlabel("UMAP1")
        ax.set_ylabel("UMAP2")

        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Close the figure so a long-running server does not accumulate figures.
        plt.close(fig)
    buf.seek(0)
    return plt.imread(buf, format="png")


def main(input_file_path, species):
    """Embed an uploaded ``.h5ad`` file with UCE, classify cell types, and plot a UMAP.

    Args:
        input_file_path: Path of the uploaded AnnData file. Either a plain path
            or an object exposing the path as ``.name``, as some Gradio
            versions pass.
        species: Species chosen in the UI. ``None``/empty falls back to
            ``"human"``.

    Returns:
        A tuple ``(img, output_file, pred_file)``:
            img: the UMAP plot as an image array.
            output_file: path of the ``*_uce_adata.h5ad`` file read back after
                UCE runs.
            pred_file: path of the predictions CSV.

    Raises:
        ValueError: if ``input_file_path`` is missing or does not exist.
        subprocess.CalledProcessError: if cloning or running UCE fails.
    """
    input_file_path = getattr(input_file_path, "name", input_file_path)
    # Validate before doing any expensive setup work.
    if input_file_path is None or not os.path.exists(input_file_path):
        raise ValueError(f"Invalid adata_path: {input_file_path}. Please check if the file exists.")
    # The subprocess runs from the repo directory, so pass an absolute path.
    input_file_path = os.path.abspath(input_file_path)
    species = species or "human"

    repo_dir = "/home/user/app/UCE"
    dir_path = "/home/user/app/UCE/"
    model_loc = "minwoosun/uce-100m"

    _ensure_uce_repo(repo_dir)

    ##############
    #    UCE     #
    ##############
    print(input_file_path)
    print(dir_path)
    print(model_loc)
    _run_uce(input_file_path, species, dir_path, model_loc)

    ################################
    #   Cell-type classification   #
    ################################
    file_name = os.path.splitext(os.path.basename(input_file_path))[0]
    pred_file = os.path.join(dir_path, f"{file_name}_predictions.csv")
    # Assumes eval_single_anndata.py writes "<name>_uce_adata.h5ad" into --dir.
    output_file = os.path.join(dir_path, f"{file_name}_uce_adata.h5ad")

    # NOTE(review): the classifier is trained on Tabula Sapiens (human);
    # predictions for mouse input may not be meaningful.
    model_path = hf_hub_download(
        repo_id="minwoosun/uce-misc",
        filename="tabula_sapiens_v1_logistic_regression_model_weights.json",
    )
    adata = sc.read_h5ad(output_file)
    x = adata.obsm["X_uce"]
    y_pred = load_and_predict_with_classifier(x, model_path, pred_file, save=True)

    ##############
    #    UMAP    #
    ##############
    # Colour by the uploaded annotations when present; otherwise use the
    # classifier's predictions instead of failing with a KeyError.
    if "cell_type" in adata.obs:
        img = _plot_umap(x, adata.obs["cell_type"], legend_title="Cell Type")
    else:
        img = _plot_umap(x, y_pred, legend_title="Predicted Cell Type")

    return img, output_file, pred_file
if __name__ == "__main__":
    # Force a light theme regardless of the viewer's browser preference.
    css = """
body {background-color: white; color: black;}
.gradio-container {background-color: white; color: black;}
.gr-file, .gr-image {background-color: #f0f0f0; color: black; border-color: black;}
"""

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            '''
<div style="text-align:center; margin-bottom:20px;">
<span style="font-size:3em; font-weight:bold; color: black;">UCE 100M Demo 🦠</span>
</div>
<div style="text-align:center; margin-bottom:10px;">
<span style="font-size:1.5em; font-weight:bold; color: black;">Universal Cell Embeddings: Explore Single Cell Data</span>
</div>
<div style="text-align:center; margin-bottom:20px;">
<a href="https://github.com/minwoosun/UCE">
<img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block; margin-right:10px;">
</a>
<a href="https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1">
<img src="https://img.shields.io/badge/bioRxiv-2023.11.28.568918-green?style=plastic" alt="Paper" style="display:inline-block; margin-right:10px;">
</a>
</div>
<div style="text-align:left; margin-bottom:20px; color: black;">
Upload a `.h5ad` single cell gene expression file and select the species (Human/Mouse).
The demo will embed the cells with UCE, predict cell types with a logistic-regression classifier, show a UMAP projection of the embeddings, and let you download the embeddings and the predicted cell types for further analysis.
</div>
<div style="margin-bottom:20px; color: black;">
<ol style="list-style:none; padding-left:0;">
<li>1. Upload your `.h5ad` file</li>
<li>2. Select the species</li>
<li>3. Click "Run" to view the UMAP scatter plot</li>
<li>4. Download the UCE embeddings (.h5ad) and the predicted cell types (.csv)</li>
</ol>
</div>
<div style="text-align:left; line-height:1.8; color: black;">
Please consider citing the following paper if you use this tool in your research:
</div>
<div style="text-align:left; line-height:1.8; color: black;">
Rosen, Y., Roohani, Y., Agarwal, A., Samotorčan, L., Tabula Sapiens Consortium, Quake, S. R., & Leskovec, J. Universal Cell Embeddings: A Foundation Model for Cell Biology. bioRxiv. https://doi.org/10.1101/2023.11.28.568918
</div>
'''
        )

        # Define Gradio inputs and outputs.
        file_input = gr.File(label="Upload a .h5ad single cell gene expression file")
        # Default to "human" so that pressing Run without touching the dropdown
        # does not send None to main().
        species_input = gr.Dropdown(choices=["human", "mouse"], value="human", label="Select species")
        run_button = gr.Button("Run")

        # Arrange the UMAP plot and the downloadable files side by side.
        with gr.Row():
            image_output = gr.Image(type="numpy", label="UMAP of UCE Embeddings")
            file_output = gr.File(label="Download embeddings")
            pred_output = gr.File(label="Download predictions")

        # main() returns (image, embeddings file, predictions file), matching
        # the order of ``outputs``.
        run_button.click(
            fn=main,
            inputs=[file_input, species_input],
            outputs=[image_output, file_output, pred_output],
        )

    demo.launch()