|  | import gradio as gr | 
					
						
						|  | import pandas as pd | 
					
						
						|  | import numpy as np | 
					
						
						|  | import umap | 
					
						
						|  | import json | 
					
						
						|  | import matplotlib.pyplot as plt | 
					
						
						|  | import os | 
					
						
						|  | import scanpy as sc | 
					
						
						|  | import subprocess | 
					
						
						|  | import sys | 
					
						
						|  | from io import BytesIO | 
					
						
						|  | from sklearn.linear_model import LogisticRegression | 
					
						
						|  | from huggingface_hub import hf_hub_download | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def load_and_predict_with_classifier(x, model_path, output_path, save): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with open(model_path, 'r') as f: | 
					
						
						|  | model_params = json.load(f) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | model_loaded = LogisticRegression(multi_class='multinomial', solver='lbfgs', max_iter=1000) | 
					
						
						|  | model_loaded.coef_ = np.array(model_params["coef"]) | 
					
						
						|  | model_loaded.intercept_ = np.array(model_params["intercept"]) | 
					
						
						|  | model_loaded.classes_ = np.array(model_params["classes"]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | y_pred = model_loaded.predict(x) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if save: | 
					
						
						|  | df = pd.DataFrame(y_pred, columns=["predicted_cell_type"]) | 
					
						
						|  | df.to_csv(output_path, index=False, header=False) | 
					
						
						|  |  | 
					
						
						|  | return y_pred | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def plot_umap(adata): | 
					
						
						|  |  | 
					
						
						|  | labels = pd.Categorical(adata.obs["cell_type"]) | 
					
						
						|  |  | 
					
						
						|  | reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42) | 
					
						
						|  | embedding = reducer.fit_transform(adata.obsm["X_uce"]) | 
					
						
						|  |  | 
					
						
						|  | plt.figure(figsize=(10, 8)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | scatter = plt.scatter(embedding[:, 0], embedding[:, 1], c=labels.codes, cmap='Set1', s=50, alpha=0.6) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | handles = [] | 
					
						
						|  | for i, cell_type in enumerate(labels.categories): | 
					
						
						|  | handles.append(plt.Line2D([0], [0], marker='o', color='w', label=cell_type, | 
					
						
						|  | markerfacecolor=plt.cm.Set1(i / len(labels.categories)), markersize=10)) | 
					
						
						|  |  | 
					
						
						|  | plt.legend(handles=handles, title='Cell Type') | 
					
						
						|  | plt.title('UMAP projection of the data') | 
					
						
						|  | plt.xlabel('UMAP1') | 
					
						
						|  | plt.ylabel('UMAP2') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | buf = BytesIO() | 
					
						
						|  | plt.savefig(buf, format='png') | 
					
						
						|  | buf.seek(0) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | img = plt.imread(buf, format='png') | 
					
						
						|  |  | 
					
						
						|  | return img | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def toggle_file_input(default_dataset): | 
					
						
						|  | if default_dataset != "None": | 
					
						
						|  | return gr.update(interactive=False) | 
					
						
						|  | else: | 
					
						
						|  | return gr.update(interactive=True) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def main(input_file_path, species, default_dataset): | 
					
						
						|  |  | 
					
						
						|  | BASE_PATH = '/home/user/app/UCE/' | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | os.system('git clone https://github.com/minwoosun/UCE.git') | 
					
						
						|  | os.chdir(BASE_PATH) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | sys.path.append(BASE_PATH) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | default_dataset_1_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="100_pbmcs_proc_subset.h5ad") | 
					
						
						|  | default_dataset_2_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="1k_pbmcs_proc_subset.h5ad") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if default_dataset == "PBMC 100 cells": | 
					
						
						|  | input_file_path = default_dataset_1_path | 
					
						
						|  | elif default_dataset == "PBMC 1000 cells": | 
					
						
						|  | input_file_path = default_dataset_2_path | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from evaluate import AnndataProcessor | 
					
						
						|  | from accelerate import Accelerator | 
					
						
						|  |  | 
					
						
						|  | model_loc = 'minwoosun/uce-100m' | 
					
						
						|  |  | 
					
						
						|  | print(input_file_path) | 
					
						
						|  | print(BASE_PATH) | 
					
						
						|  | print(model_loc) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | command = [ | 
					
						
						|  | 'python', | 
					
						
						|  | BASE_PATH + 'eval_single_anndata.py', | 
					
						
						|  | '--adata_path', input_file_path, | 
					
						
						|  | '--dir', BASE_PATH, | 
					
						
						|  | '--model_loc', model_loc | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | print("Running command:", command) | 
					
						
						|  |  | 
					
						
						|  | print("---> RUNNING UCE") | 
					
						
						|  | result = subprocess.run(command, capture_output=True, text=True, check=True) | 
					
						
						|  | print(result.stdout) | 
					
						
						|  | print(result.stderr) | 
					
						
						|  | print("---> FINSIH UCE") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | file_name_with_ext = os.path.basename(input_file_path) | 
					
						
						|  | file_name = os.path.splitext(file_name_with_ext)[0] | 
					
						
						|  | pred_file = BASE_PATH + f"{file_name}_predictions.csv" | 
					
						
						|  | model_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="tabula_sapiens_v1_logistic_regression_model_weights.json") | 
					
						
						|  |  | 
					
						
						|  | file_name_with_ext = os.path.basename(input_file_path) | 
					
						
						|  | file_name = os.path.splitext(file_name_with_ext)[0] | 
					
						
						|  | output_file = BASE_PATH + f"{file_name}_uce_adata.h5ad" | 
					
						
						|  | adata = sc.read_h5ad(output_file) | 
					
						
						|  | x = adata.obsm['X_uce'] | 
					
						
						|  |  | 
					
						
						|  | y_pred = load_and_predict_with_classifier(x, model_path, pred_file, save=True) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | img = plot_umap(adata) | 
					
						
						|  |  | 
					
						
						|  | return img, output_file, pred_file | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  |  | 
					
						
						|  | with gr.Blocks() as demo: | 
					
						
						|  | gr.Markdown( | 
					
						
						|  | ''' | 
					
						
						|  | <div style="text-align:center; margin-bottom:20px;"> | 
					
						
						|  | <span style="font-size:3em; font-weight:bold;">UCE 100M Demo 🦠</span> | 
					
						
						|  | </div> | 
					
						
						|  | <div style="text-align:center; margin-bottom:10px;"> | 
					
						
						|  | <span style="font-size:1.5em; font-weight:bold;">Universal Cell Embeddings: Zero-Shot Cell-Type Classification in Action!</span> | 
					
						
						|  | </div> | 
					
						
						|  | <div style="text-align:center; margin-bottom:20px;"> | 
					
						
						|  | <a href="https://github.com/minwoosun/UCE"> | 
					
						
						|  | <img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block; margin-right:10px;"> | 
					
						
						|  | </a> | 
					
						
						|  | <a href="https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1"> | 
					
						
						|  | <img src="https://img.shields.io/badge/bioRxiv-2023.11.28.568918-green?style=plastic" alt="Paper" style="display:inline-block; margin-right:10px;"> | 
					
						
						|  | </a> | 
					
						
						|  | <a href="https://colab.research.google.com/drive/1opud0BVWr76IM8UnGgTomVggui_xC4p0?usp=sharing"> | 
					
						
						|  | <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" style="display:inline-block; margin-right:10px;"> | 
					
						
						|  | </a> | 
					
						
						|  | </div> | 
					
						
						|  | <div style="text-align:left; margin-bottom:20px;"> | 
					
						
						|  | Upload a `.h5ad` single cell gene expression file and select the species (Human/Mouse). | 
					
						
						|  | The demo will generate UMAP projections of the embeddings and allow you to download the embeddings for further analysis. | 
					
						
						|  | </div> | 
					
						
						|  | <div style="margin-bottom:20px;"> | 
					
						
						|  | <ol style="list-style:none; padding-left:0;"> | 
					
						
						|  | <li>1. Upload your `.h5ad` file or select one of the default datasets (subset of 10x pbmc data)</li> | 
					
						
						|  | <li>2. Select the species</li> | 
					
						
						|  | <li>3. Click "Run" to view the UMAP scatter plot</li> | 
					
						
						|  | <li>4. Download the UCE embeddings and predicted cell-types</li> | 
					
						
						|  | </ol> | 
					
						
						|  | </div> | 
					
						
						|  | <div style="text-align:left; line-height:1.8;"> | 
					
						
						|  | Please consider citing the following paper if you use this tool in your research: | 
					
						
						|  | </div> | 
					
						
						|  | <div style="text-align:left; line-height:1.8;"> | 
					
						
						|  | Rosen, Y., Roohani, Y., Agarwal, A., Samotorčan, L., Tabula Sapiens Consortium, Quake, S. R., & Leskovec, J. Universal Cell Embeddings: A Foundation Model for Cell Biology. bioRxiv. https://doi.org/10.1101/2023.11.28.568918 | 
					
						
						|  | </div> | 
					
						
						|  | ''' | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | file_input = gr.File(label="Upload a .h5ad single cell gene expression file or select a default dataset below") | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | species_input = gr.Dropdown(choices=["human", "mouse"], label="Select species") | 
					
						
						|  | default_dataset_input = gr.Dropdown(choices=["None", "PBMC 100 cells", "PBMC 1000 cells"], label="Select default dataset") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | default_dataset_input.change( | 
					
						
						|  | toggle_file_input, | 
					
						
						|  | inputs=[default_dataset_input], | 
					
						
						|  | outputs=[file_input] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | run_button = gr.Button("Run", elem_classes="run-button") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | image_output = gr.Image(type="numpy", label="UMAP_of_UCE_Embeddings") | 
					
						
						|  | file_output = gr.File(label="Download embeddings") | 
					
						
						|  | pred_output = gr.File(label="Download predictions") | 
					
						
						|  |  | 
					
						
						|  | print(image_output) | 
					
						
						|  | print(file_output) | 
					
						
						|  | print(pred_output) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | run_button.click( | 
					
						
						|  | fn=main, | 
					
						
						|  | inputs=[file_input, species_input, default_dataset_input], | 
					
						
						|  | outputs=[image_output, file_output, pred_output] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | demo.launch() | 
					
						
						|  |  |