File size: 6,706 Bytes
691f8f6
 
f81db88
691f8f6
f81db88
691f8f6
 
 
 
d54d01e
691f8f6
f81db88
691f8f6
 
 
4481b1f
 
f81db88
 
4481b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f81db88
4481b1f
f81db88
 
31a75e1
4481b1f
31a75e1
 
 
 
 
 
 
4481b1f
 
 
 
 
31a75e1
 
 
 
4481b1f
31a75e1
 
 
 
 
 
8686715
4481b1f
8686715
4481b1f
8686715
4481b1f
691f8f6
4481b1f
 
6deb948
4481b1f
 
6deb948
4481b1f
6deb948
 
4481b1f
6deb948
4481b1f
 
691f8f6
4481b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f81db88
4481b1f
 
f81db88
4481b1f
 
f81db88
 
4481b1f
 
f81db88
4481b1f
 
31a75e1
691f8f6
4481b1f
691f8f6
4481b1f
3b420fc
4481b1f
 
 
dcbe609
4481b1f
5cdd2ae
4481b1f
 
 
 
 
 
 
 
5cdd2ae
4481b1f
 
 
e2972d3
4481b1f
 
a92a26c
4481b1f
 
 
 
a92a26c
4481b1f
a92a26c
f81db88
a14ec14
4481b1f
 
efca0ea
5cdd2ae
a92a26c
4481b1f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import gradio as gr
import pandas as pd
import numpy as np
import umap
import json
import matplotlib.pyplot as plt
import os
import scanpy as sc
import subprocess
import sys
from io import BytesIO
from sklearn.linear_model import LogisticRegression
from huggingface_hub import hf_hub_download


def load_model_params(model_path):
    """Return the classifier parameters deserialized from the JSON file at *model_path*.

    The result is whatever ``json.load`` produces; callers in this module read
    the "coef", "intercept" and "classes" entries from it.
    """
    with open(model_path) as fh:
        return json.load(fh)

def reconstruct_classifier(model_params):
    """Rebuild a ready-to-predict LogisticRegression from serialized parameters.

    Parameters
    ----------
    model_params : dict
        Mapping with "coef", "intercept" and "classes" entries. Presumably
        these were produced from a fitted model's ``coef_``/``intercept_``/
        ``classes_`` via ``tolist()`` -- confirm against the JSON producer.

    Returns
    -------
    LogisticRegression
        An estimator whose learned attributes are assigned directly, so
        ``predict`` can be called without fitting.
    """
    # `multi_class` is deprecated since scikit-learn 1.5 (FutureWarning) and
    # slated for removal. `predict` takes the argmax of `decision_function` and
    # never consults it, and lbfgs with 3+ classes is multinomial by default.
    model = LogisticRegression(solver='lbfgs', max_iter=1000)
    model.coef_ = np.array(model_params["coef"])
    model.intercept_ = np.array(model_params["intercept"])
    model.classes_ = np.array(model_params["classes"])
    if model.coef_.ndim == 2:
        # Lets scikit-learn reject inputs with the wrong feature count with a
        # descriptive error rather than a bare matrix-shape failure.
        model.n_features_in_ = model.coef_.shape[1]
    return model

def save_predictions(y_pred, output_path):
    """Write *y_pred* to *output_path* as CSV, one value per line.

    Neither the index nor a header row is written, so the
    "predicted_cell_type" column label does not appear in the file.
    """
    pd.DataFrame(y_pred, columns=["predicted_cell_type"]).to_csv(
        output_path,
        header=False,
        index=False,
    )

def load_and_predict_with_classifier(x, model_path, output_path, save=False):
    """Predict labels for *x* using the classifier serialized at *model_path*.

    When *save* is true the predictions are also written to *output_path* via
    ``save_predictions``. The predicted labels are returned either way.
    """
    classifier = reconstruct_classifier(load_model_params(model_path))
    predictions = classifier.predict(x)
    if save:
        save_predictions(predictions, output_path)
    return predictions

def plot_umap(adata):
    """Embed ``adata.obsm["X_uce"]`` in 2-D with UMAP and render a labelled scatter plot.

    Points are coloured by ``adata.obs["cell_type"]`` using the Set1 colormap.

    Parameters
    ----------
    adata : AnnData
        Must provide ``obs["cell_type"]`` and ``obsm["X_uce"]``; a missing key
        raises ``KeyError``.

    Returns
    -------
    numpy.ndarray
        The rendered PNG, decoded with ``plt.imread``.
    """
    labels = pd.Categorical(adata.obs["cell_type"])
    n_categories = len(labels.categories)
    reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
    embedding = reducer.fit_transform(adata.obsm["X_uce"])

    cmap = plt.cm.Set1
    # Fix the colour normalisation to code / n_categories so every point gets
    # exactly the colour its legend entry shows (cmap(i / n_categories)).
    # Without vmin/vmax, scatter rescales codes over (n - 1) and the legend
    # colours drift from the plotted ones. Set1 has only 9 colours, so with
    # more categories some entries share a colour.
    scale = max(n_categories, 1)

    fig = plt.figure(figsize=(10, 8))
    try:
        plt.scatter(embedding[:, 0], embedding[:, 1], c=labels.codes, cmap=cmap,
                    vmin=0, vmax=scale, s=50, alpha=0.6)

        handles = [
            plt.Line2D([0], [0], marker='o', color='w', label=cell_type,
                       markerfacecolor=cmap(i / scale), markersize=10)
            for i, cell_type in enumerate(labels.categories)
        ]
        plt.legend(handles=handles, title='Cell Type')
        plt.title('UMAP projection of the data')
        plt.xlabel('UMAP1')
        plt.ylabel('UMAP2')

        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # This runs once per Gradio request; pyplot keeps every open figure
        # alive, so close it or memory grows for the life of the server.
        plt.close(fig)

    buf.seek(0)
    return plt.imread(buf, format='png')

def toggle_file_input(default_dataset):
    """Enable the upload widget only when the "None" default-dataset choice is selected.

    Any other selection (including a non-string value) yields a
    ``gr.update`` that makes the file input non-interactive.
    """
    allow_upload = default_dataset == "None"
    return gr.update(interactive=allow_upload)

def run_uce_model(input_file_path, model_dir, model_loc):
    """Run ``eval_single_anndata.py`` from *model_dir* on *input_file_path* in a child process.

    The child uses the current interpreter (``sys.executable``).
    ``subprocess.CalledProcessError`` is raised if it exits non-zero.
    """
    script_path = os.path.join(model_dir, 'eval_single_anndata.py')
    cli_args = [
        '--adata_path', input_file_path,
        '--dir', model_dir,
        '--model_loc', model_loc,
    ]
    subprocess.run([sys.executable, script_path, *cli_args], check=True)

def _ensure_uce_repo(repo_url, repo_dir):
    """Clone the UCE repository into *repo_dir* if needed and make it importable."""
    if not os.path.exists(repo_dir):
        # Clone to the exact path checked above; a bare `git clone URL` would
        # land in the current working directory instead of repo_dir.
        subprocess.run(['git', 'clone', repo_url, repo_dir], check=True)
    if repo_dir not in sys.path:
        # main() runs once per request; avoid piling up duplicate entries.
        sys.path.append(repo_dir)


def _resolve_input_path(input_file_path, default_dataset):
    """Return the local .h5ad path to process; a selected default dataset wins over an upload."""
    default_dataset_files = {
        "PBMC 100 cells": "100_pbmcs_proc_subset.h5ad",
        "PBMC 1000 cells": "1k_pbmcs_proc_subset.h5ad",
    }
    if default_dataset in default_dataset_files:
        # Download only the dataset that was actually selected.
        return hf_hub_download(repo_id="minwoosun/uce-misc",
                               filename=default_dataset_files[default_dataset])
    if input_file_path is None:
        raise gr.Error("Please upload a .h5ad file or select a default dataset.")
    # Some Gradio versions pass a tempfile wrapper rather than a path string.
    return getattr(input_file_path, "name", input_file_path)


def main(input_file_path, species, default_dataset):
    """Run the demo pipeline for one request: UCE embedding, classification, UMAP.

    Parameters
    ----------
    input_file_path : str or file wrapper or None
        Uploaded .h5ad file from the Gradio file widget. Ignored when a default
        dataset is selected.
    species : str
        Species chosen in the UI.
    default_dataset : str or None
        One of the default dataset choices; anything not recognised falls back
        to the uploaded file.

    Returns
    -------
    tuple
        (UMAP image array, path to the UCE embeddings .h5ad, path to the
        predictions CSV).

    Raises
    ------
    gradio.Error
        If neither a file nor a default dataset is provided.
    subprocess.CalledProcessError
        If cloning the repository or running UCE fails.
    """
    repo_url = 'https://github.com/minwoosun/UCE.git'
    repo_dir = '/home/user/app/UCE'
    _ensure_uce_repo(repo_url, repo_dir)

    input_file_path = _resolve_input_path(input_file_path, default_dataset)

    # NOTE(review): `species` is collected from the UI but never forwarded to
    # UCE here; confirm whether eval_single_anndata.py should receive it.
    run_uce_model(input_file_path, repo_dir, 'minwoosun/uce-100m')

    # UCE writes its output next to the repo as "<input stem>_uce_adata.h5ad".
    stem = os.path.splitext(os.path.basename(input_file_path))[0]
    embeddings_path = os.path.join(repo_dir, f"{stem}_uce_adata.h5ad")
    adata = sc.read_h5ad(embeddings_path)

    model_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="tabula_sapiens_v1_logistic_regression_model_weights.json")
    pred_file = os.path.join(repo_dir, f"{stem}_predictions.csv")
    load_and_predict_with_classifier(adata.obsm['X_uce'], model_path, pred_file, save=True)

    img = plot_umap(adata)

    return img, embeddings_path, pred_file

# Gradio UI

def create_demo():
    """Build the Gradio Blocks UI for the UCE demo and launch it.

    The Run button calls ``main`` with the uploaded file, species and
    default-dataset selection. It displays the returned UMAP image and offers
    the embeddings and predictions files for download. Selecting a default
    dataset disables the file upload via ``toggle_file_input``.
    """

    with gr.Blocks() as demo:
        # Inside an HTML block Markdown syntax (e.g. backticks) is not
        # rendered, so inline code uses <code> tags.
        gr.Markdown("""
            <div style="text-align:center; margin-bottom:20px;">
                <h1>UCE 100M Demo 🦠</h1>
                <h2>Universal Cell Embeddings: Zero-Shot Cell-Type Classification in Action!</h2>
                <div style="margin-top:10px;">
                    <a href="https://github.com/minwoosun/UCE"><img src="https://badges.aleen42.com/src/github.svg" alt="GitHub"></a>
                    <a href="https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1"><img src="https://img.shields.io/badge/bioRxiv-2023.11.28.568918-green?style=plastic" alt="Paper"></a>
                    <a href="https://colab.research.google.com/drive/1opud0BVWr76IM8UnGgTomVggui_xC4p0?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
                </div>
                <p>Upload a <code>.h5ad</code> single cell gene expression file or select a default dataset, choose the species, and click Run to generate UMAP projections and download the embeddings and predicted cell types.</p>
            </div>
        """)

        # Inputs
        file_input = gr.File(label="Upload a .h5ad single cell gene expression file or select a default dataset below")
        species_input = gr.Dropdown(choices=["human", "mouse"], label="Select species")
        default_dataset_input = gr.Dropdown(choices=["None", "PBMC 100 cells", "PBMC 1000 cells"], label="Select default dataset")

        # Disable the upload widget whenever a default dataset is chosen.
        default_dataset_input.change(toggle_file_input, inputs=[default_dataset_input], outputs=[file_input])

        # Outputs
        run_button = gr.Button("Run")
        with gr.Row():
            image_output = gr.Image(type="numpy", label="UMAP of UCE Embeddings")
            file_output = gr.File(label="Download embeddings")
            pred_output = gr.File(label="Download predictions")

        # Run the function on button click
        run_button.click(fn=main, inputs=[file_input, species_input, default_dataset_input], outputs=[image_output, file_output, pred_output])

    demo.launch()

if __name__ == "__main__":
    # Script entry point: build the Gradio UI; create_demo() also calls demo.launch().
    create_demo()