{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "80d94024-43cc-488d-aa1d-7d29fef337d8",
   "metadata": {},
   "source": [
    "# Integrating samples from the DLPFC dataset\n",
    "The human dorsolateral prefrontal cortex (DLPFC) dataset is an SRT dataset containing three sets of slices; each set consists of four vertically adjacent slices from a single donor.\n",
    "In this tutorial, we demonstrate that INSTINCT can integrate multiple SRT samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b19195ba-7da5-4978-a34c-979c2dc483c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import csv\n",
    "import warnings\n",
    "\n",
    "# Silence library warnings (set before the imports below so import-time warnings are hidden too).\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import anndata as ad\n",
    "import scanpy as sc\n",
    "import networkx as nx\n",
    "import scib\n",
    "import sklearn\n",
    "import sklearn.neighbors\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.mixture import GaussianMixture\n",
    "from sklearn.metrics.cluster import (\n",
    "    adjusted_mutual_info_score,\n",
    "    adjusted_rand_score,\n",
    "    completeness_score,\n",
    "    fowlkes_mallows_score,\n",
    "    homogeneity_score,\n",
    "    normalized_mutual_info_score,\n",
    ")\n",
    "\n",
    "import INSTINCT\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1f1dfd7-21a0-45d3-ad70-69d9f2f166c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def match_cluster_labels(true_labels, est_labels):\n",
    "    \"\"\"Map estimated cluster labels onto ground-truth categories by maximum overlap.\n",
    "\n",
    "    A complete bipartite graph is built between the sorted unique true labels\n",
    "    (nodes 1, 2, ...) and the sorted unique estimated labels (nodes -1, -2, ...).\n",
    "    Each edge weight is the negated number of spots the two labels share, so the\n",
    "    minimum-weight full matching pairs labels with maximal total overlap.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    true_labels\n",
    "        Ground-truth label of every spot.\n",
    "    est_labels\n",
    "        Estimated cluster label of every spot, in the same order as ``true_labels``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        One integer per spot: the 0-based position (in sorted order) of the true\n",
    "        category matched to the spot's cluster. If there are more estimated\n",
    "        clusters than true categories, the k-th unmatched cluster (k = 0, 1, ...)\n",
    "        is labelled ``number of true categories + k``.\n",
    "    \"\"\"\n",
    "    true_labels_arr = np.array(list(true_labels))\n",
    "    est_labels_arr = np.array(list(est_labels))\n",
    "\n",
    "    org_cat = list(np.sort(list(pd.unique(true_labels))))\n",
    "    est_cat = list(np.sort(list(pd.unique(est_labels))))\n",
    "\n",
    "    # Node i + 1 is the i-th true category and node -j - 1 the j-th estimated\n",
    "    # cluster; the offset keeps the two sides from both using node 0.\n",
    "    B = nx.Graph()\n",
    "    B.add_nodes_from([i + 1 for i in range(len(org_cat))], bipartite=0)\n",
    "    B.add_nodes_from([-j - 1 for j in range(len(est_cat))], bipartite=1)\n",
    "\n",
    "    for i in range(len(org_cat)):\n",
    "        for j in range(len(est_cat)):\n",
    "            overlap = np.sum((true_labels_arr == org_cat[i]) * (est_labels_arr == est_cat[j]))\n",
    "            # Negated so that the minimum-weight matching maximises the overlap.\n",
    "            B.add_edge(i + 1, -j - 1, weight=-overlap)\n",
    "\n",
    "    match = nx.algorithms.bipartite.matching.minimum_weight_full_matching(B)\n",
    "\n",
    "    if len(org_cat) >= len(est_cat):\n",
    "        return np.array([match[-est_cat.index(c) - 1] - 1 for c in est_labels_arr])\n",
    "\n",
    "    # More estimated clusters than true categories: unmatched clusters get new\n",
    "    # labels numbered after the true-category indices.\n",
    "    unmatched = [c for c in est_cat if (-est_cat.index(c) - 1) not in match]\n",
    "    relabeled = []\n",
    "    for c in est_labels_arr:\n",
    "        node = -est_cat.index(c) - 1\n",
    "        if node in match:\n",
    "            relabeled.append(match[node] - 1)\n",
    "        else:\n",
    "            relabeled.append(len(org_cat) + unmatched.index(c))\n",
    "    return np.array(relabeled)\n",
    "\n",
    "\n",
    "def cluster_metrics(target, pred):\n",
    "    \"\"\"Print and return clustering agreement scores of ``pred`` against ``target``.\n",
    "\n",
    "    These sklearn scores compare partitions, so ``pred`` does not need to use\n",
    "    the same label values as ``target``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of float\n",
    "        (ARI, AMI, NMI, FMI, completeness, homogeneity).\n",
    "    \"\"\"\n",
    "    target = np.array(target)\n",
    "    pred = np.array(pred)\n",
    "\n",
    "    ari = adjusted_rand_score(target, pred)\n",
    "    ami = adjusted_mutual_info_score(target, pred)\n",
    "    nmi = normalized_mutual_info_score(target, pred)\n",
    "    fmi = fowlkes_mallows_score(target, pred)\n",
    "    comp = completeness_score(target, pred)\n",
    "    homo = homogeneity_score(target, pred)\n",
    "    print('ARI: %.3f, AMI: %.3f, NMI: %.3f, FMI: %.3f, Comp: %.3f, Homo: %.3f' % (ari, ami, nmi, fmi, comp, homo))\n",
    "\n",
    "    return ari, ami, nmi, fmi, comp, homo\n",
    "\n",
    "\n",
    "def mean_average_precision(x: np.ndarray, y: np.ndarray, k: int = 30, **kwargs) -> float:\n",
    "    r\"\"\"\n",
    "    Mean average precision of label agreement among the k nearest neighbors\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x\n",
    "        Coordinates (one row per spot)\n",
    "    y\n",
    "        Cell_type/Layer labels\n",
    "    k\n",
    "        k neighbors\n",
    "    **kwargs\n",
    "        Additional keyword arguments are passed to\n",
    "        :class:`sklearn.neighbors.NearestNeighbors`\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    map\n",
    "        Mean average precision\n",
    "    \"\"\"\n",
    "\n",
    "    def _average_precision(match: np.ndarray) -> float:\n",
    "        # Average of the precision-at-rank values taken at every rank whose neighbor shares the query's label.\n",
    "        if np.any(match):\n",
    "            cummean = np.cumsum(match) / (np.arange(match.size) + 1)\n",
    "            return cummean[match].mean().item()\n",
    "        return 0.0\n",
    "\n",
    "    y = np.array(y)\n",
    "    knn = sklearn.neighbors.NearestNeighbors(n_neighbors=min(y.shape[0], k + 1), **kwargs).fit(x)\n",
    "    nni = knn.kneighbors(x, return_distance=False)\n",
    "    # Query k + 1 neighbors and drop column 0, which is normally the query point itself.\n",
    "    match = np.equal(y[nni[:, 1:]], np.expand_dims(y, 1))\n",
    "\n",
    "    return np.apply_along_axis(_average_precision, 1, match).mean().item()\n",
    "\n",
    "\n",
    "def rep_metrics(adata, origin_concat, use_rep, label_key, batch_key, k_map=30):\n",
    "    \"\"\"Compute bio-conservation and batch-correction metrics for an embedding.\n",
    "\n",
    "    Note: this mutates its inputs -- ``adata.obs[label_key]`` and\n",
    "    ``adata.obs[batch_key]`` are cast to string categories, ``sc.pp.neighbors``\n",
    "    writes a neighbor graph into ``adata``, and ``origin_concat.X`` is cast to float.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    adata\n",
    "        AnnData holding the embedding to evaluate in ``adata.obsm[use_rep]``.\n",
    "    origin_concat\n",
    "        AnnData passed to ``scib.me.pcr_comparison`` as the pre-integration reference.\n",
    "    use_rep\n",
    "        Key of the embedding in ``adata.obsm``.\n",
    "    label_key, batch_key\n",
    "        Columns of ``adata.obs`` with the biological labels and the batch labels.\n",
    "    k_map\n",
    "        Number of neighbors used for mAP.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of float\n",
    "        (mAP, cell-type ASW, batch ASW, batch PCR, kBET, graph connectivity).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    KeyError\n",
    "        If ``label_key``/``batch_key`` is missing from ``adata.obs`` or\n",
    "        ``use_rep`` is missing from ``adata.obsm``.\n",
    "    \"\"\"\n",
    "    missing = []\n",
    "    if label_key not in adata.obs:\n",
    "        missing.append(f'adata.obs[{label_key!r}]')\n",
    "    if batch_key not in adata.obs:\n",
    "        missing.append(f'adata.obs[{batch_key!r}]')\n",
    "    if use_rep not in adata.obsm:\n",
    "        missing.append(f'adata.obsm[{use_rep!r}]')\n",
    "    if missing:\n",
    "        # Fail loudly: returning None would only surface later as an unrelated\n",
    "        # TypeError in callers that unpack the returned tuple.\n",
    "        raise KeyError('Missing required keys: ' + ', '.join(missing))\n",
    "\n",
    "    adata.obs[label_key] = adata.obs[label_key].astype(str).astype(\"category\")\n",
    "    adata.obs[batch_key] = adata.obs[batch_key].astype(str).astype(\"category\")\n",
    "    origin_concat.X = origin_concat.X.astype(float)\n",
    "    sc.pp.neighbors(adata, use_rep=use_rep)\n",
    "\n",
    "    MAP = mean_average_precision(adata.obsm[use_rep].copy(), adata.obs[label_key], k=k_map)\n",
    "    cell_type_ASW = scib.me.silhouette(adata, label_key=label_key, embed=use_rep)\n",
    "    batch_ASW = scib.me.silhouette_batch(adata, batch_key=batch_key, label_key=label_key, embed=use_rep, verbose=False)\n",
    "    batch_PCR = scib.me.pcr_comparison(origin_concat, adata, covariate=batch_key, embed=use_rep)\n",
    "    kBET = scib.me.kBET(adata, batch_key=batch_key, label_key=label_key, type_='embed', embed=use_rep)\n",
    "    g_conn = scib.me.graph_connectivity(adata, label_key=label_key)\n",
    "    print('mAP: %.3f, Cell type ASW: %.3f, Batch ASW: %.3f, Batch PCR: %.3f, kBET: %.3f, Graph connectivity: %.3f' %\n",
    "          (MAP, cell_type_ASW, batch_ASW, batch_PCR, kBET, g_conn))\n",
    "\n",
    "    return MAP, cell_type_ASW, batch_ASW, batch_PCR, kBET, g_conn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9f23ac4-1a29-40f2-9f28-e34a9cc3c7d8",
   "metadata": {},
   "source": [
"### Run model\n",
    "For each group of four slices, we preprocess the SRT data with `INSTINCT.preprocess_SRT()`, reduce it to 100 principal components, train INSTINCT, cluster the learned embeddings with a Gaussian mixture model, and evaluate the clustering and integration metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "320ddcbe-60b5-45bb-97a1-60342e2f1a6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# DLPFC\n",
    "data_dir = '../../data/STdata/10xVisium/DLPFC_Maynard2021/'\n",
    "sample_group_list = [['151507', '151508', '151509', '151510'],\n",
    "                     ['151669', '151670', '151671', '151672'],\n",
    "                     ['151673', '151674', '151675', '151676']]\n",
    "n_cluster_list = [7, 5, 7]  # number of GMM clusters for each sample group\n",
    "\n",
    "save_dir = '../../results/DLPFC_Maynard2021/'\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "results_summary = []\n",
    "\n",
    "for idx, (slice_name_list, n_cluster) in enumerate(zip(sample_group_list, n_cluster_list)):\n",
    "\n",
    "    # load data\n",
    "    slice_index_list = list(range(len(slice_name_list)))\n",
    "\n",
    "    rna_list = []\n",
    "    for sample in slice_name_list:\n",
    "        adata = sc.read_visium(path=data_dir + f'{sample}/', count_file=sample + '_filtered_feature_bc_matrix.h5')\n",
    "        adata.var_names_make_unique()\n",
    "\n",
    "        # read the annotation\n",
    "        Ann_df = pd.read_csv(data_dir + f'{sample}/meta_data.csv', sep=',', index_col=0)\n",
    "\n",
    "        # Every spot in adata needs an annotation row, otherwise the .loc lookups below raise KeyError.\n",
    "        missing = ~adata.obs_names.isin(Ann_df.index)\n",
    "        if missing.any():\n",
    "            raise ValueError(f'{missing.sum()} spots of slice {sample} are not present in the annotation file')\n",
    "\n",
    "        adata.obs['image_row'] = Ann_df.loc[adata.obs_names, 'imagerow']\n",
    "        adata.obs['image_col'] = Ann_df.loc[adata.obs_names, 'imagecol']\n",
    "        adata.obs['Manual_Annotation'] = Ann_df.loc[adata.obs_names, 'ManualAnnotation']\n",
    "\n",
    "        # suffix barcodes with the slice name so spot names are distinct across slices\n",
    "        adata.obs_names = [x + '_' + sample for x in adata.obs_names]\n",
    "        rna_list.append(adata)\n",
    "\n",
    "    # concatenation\n",
    "    adata_concat = ad.concat(rna_list, label=\"slice_name\", keys=slice_name_list)\n",
    "\n",
    "    # preprocess SRT data\n",
    "    print('Start preprocessing')\n",
    "    rna_list, adata_concat = INSTINCT.preprocess_SRT(rna_list, adata_concat, n_top_genes=5000)\n",
    "    print(adata_concat.shape)\n",
    "    print('Done!')\n",
    "\n",
    "    origin_concat = ad.concat(rna_list, label=\"slice_name\", keys=slice_index_list)\n",
    "\n",
    "    print('Applying PCA to reduce the feature dimension to 100 ...')\n",
    "    X = adata_concat.X\n",
    "    # .toarray() only exists on sparse matrices; fall back to np.asarray for a dense X.\n",
    "    X = X.toarray() if hasattr(X, 'toarray') else np.asarray(X)\n",
    "    pca = PCA(n_components=100, random_state=1234)\n",
    "    input_matrix = pca.fit_transform(X)\n",
    "    np.save(save_dir + f'input_matrix_group{idx}.npy', input_matrix)\n",
    "    print('Done !')\n",
    "\n",
    "    adata_concat.obsm['X_pca'] = input_matrix\n",
    "\n",
    "    # calculate the spatial graph\n",
    "    INSTINCT.create_neighbor_graph(rna_list, adata_concat)\n",
    "\n",
    "    # Seed the global RNGs before training (INSTINCT may also seed internally; not visible here).\n",
    "    np.random.seed(1234)\n",
    "    torch.manual_seed(1234)\n",
    "    INSTINCT_model = INSTINCT.INSTINCT_Model(rna_list, adata_concat, device=device)\n",
    "    INSTINCT_model.train(report_loss=True, report_interval=100)\n",
    "    INSTINCT_model.eval(rna_list)\n",
    "\n",
    "    result = ad.concat(rna_list, label=\"slice_name\", keys=slice_index_list)\n",
    "\n",
    "    with open(save_dir + f'INSTINCT_embed_group{idx}.csv', 'w', newline='') as file:\n",
    "        csv.writer(file).writerows(result.obsm['INSTINCT_latent'])\n",
    "\n",
    "    with open(save_dir + f'INSTINCT_noise_embed_group{idx}.csv', 'w', newline='') as file:\n",
    "        csv.writer(file).writerows(result.obsm['INSTINCT_latent_noise'])\n",
    "\n",
    "    gm = GaussianMixture(n_components=n_cluster, covariance_type='tied', random_state=1234)\n",
    "    y = gm.fit_predict(result.obsm['INSTINCT_latent'], y=None)\n",
    "    result.obs[\"gm_clusters\"] = pd.Series(y, index=result.obs.index, dtype='category')\n",
    "    result.obs['matched_clusters'] = pd.Series(match_cluster_labels(result.obs['Manual_Annotation'],\n",
    "                                                                    result.obs[\"gm_clusters\"]),\n",
    "                                               index=result.obs.index, dtype='category')\n",
    "\n",
    "    ari, ami, nmi, fmi, comp, homo = cluster_metrics(result.obs['Manual_Annotation'],\n",
    "                                                     result.obs['matched_clusters'].tolist())\n",
    "    # named map_score to avoid shadowing the built-in map()\n",
    "    map_score, c_asw, b_asw, b_pcr, kbet, g_conn = rep_metrics(result, origin_concat, use_rep='INSTINCT_latent',\n",
    "                                                               label_key='Manual_Annotation', batch_key='slice_name')\n",
    "\n",
    "    results_summary.append({'group': idx, 'slices': ', '.join(slice_name_list),\n",
    "                            'ARI': ari, 'AMI': ami, 'NMI': nmi, 'FMI': fmi, 'Comp': comp, 'Homo': homo,\n",
    "                            'mAP': map_score, 'Cell type ASW': c_asw, 'Batch ASW': b_asw,\n",
    "                            'Batch PCR': b_pcr, 'kBET': kbet, 'Graph connectivity': g_conn})\n",
    "\n",
    "pd.DataFrame(results_summary).set_index('group')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}