{ "cells": [ { "cell_type": "markdown", "id": "bd9a810d-9f68-4158-81c6-2a894e1c4150", "metadata": {}, "source": [ "# Integrating all six slices from spatial ATAC ME dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "ab5357cf-e2a1-4b1c-bacf-0858f91b2115", "metadata": {}, "outputs": [], "source": [ "import os\n", "import csv\n", "import torch\n", "import numpy as np\n", "import pandas as pd\n", "import anndata as ad\n", "\n", "from sklearn.decomposition import PCA\n", "from sklearn.mixture import GaussianMixture\n", "\n", "import INSTINCT\n", "\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "from sklearn.metrics.cluster import adjusted_rand_score\n", "from sklearn.metrics.cluster import normalized_mutual_info_score\n", "from sklearn.metrics.cluster import fowlkes_mallows_score\n", "from sklearn.metrics.cluster import homogeneity_score\n", "from sklearn.metrics.cluster import adjusted_mutual_info_score\n", "from sklearn.metrics.cluster import completeness_score\n", "import sklearn\n", "import sklearn.neighbors\n", "import networkx as nx\n", "\n", "import scib\n", "import scanpy as sc" ] }, { "cell_type": "markdown", "id": "816e3764-8d05-46d8-b282-7a77ce28c5d5", "metadata": {}, "source": [ "### Load raw data\n", "The peaks have already been merged by the original study, so their is no need to merge them again." ] }, { "cell_type": "code", "execution_count": null, "id": "0bdcd8d2-305a-42ce-a46c-957c7a0d9a11", "metadata": {}, "outputs": [], "source": [ "# mouse embryo\n", "data_dir = '../../data/spCASdata/MouseEmbryo_Llorens-Bobadilla2023/spATAC/'\n", "save_dir = '../../results/MouseEmbryo_Llorens-Bobadilla2023/all/'\n", "\n", "slice_name_list = ['E12_5-S1', 'E12_5-S2', 'E13_5-S1', 'E13_5-S2', 'E15_5-S1', 'E15_5-S2']\n", "slice_index_list = list(range(len(slice_name_list)))\n", "\n", "if not os.path.exists(save_dir):\n", " os.makedirs(save_dir)\n", "\n", "# load dataset\n", "cas_list = [ad.read_h5ad(data_dir + sample + '.h5ad') for sample in slice_name_list]\n", "for i in range(len(cas_list)):\n", " cas_list[i].obs_names = [x + '_' + slice_name_list[i] for x in cas_list[i].obs_names]\n", "\n", "# concatenation\n", "adata_concat = ad.concat(cas_list, label=\"slice_name\", keys=slice_name_list)" ] }, { "cell_type": "markdown", "id": "57e14ef2-b872-47df-9e1f-5accaccef9fd", "metadata": {}, "source": [ "### Data preprocessing\n", "Since the data matirces are fragment count matrices already, we set use_fragment_count=False" ] }, { "cell_type": "code", "execution_count": null, "id": "965c8bb5-a85e-4e91-adfb-af4d187335ac", "metadata": {}, "outputs": [], "source": [ "# preprocess CAS data\n", "# peaks are already merged and fragment counts are stored in the data matrices\n", "print('Start preprocessing')\n", "INSTINCT.preprocess_CAS(cas_list, adata_concat, use_fragment_count=False, min_cells_rate=0.003)\n", "print('Done!')\n", "print(adata_concat)" ] }, { "cell_type": "code", "execution_count": null, "id": "6544f7a3-1ebb-44cd-9ebf-50b325911b27", "metadata": {}, "outputs": [], "source": [ "adata_concat.write_h5ad(save_dir + f\"preprocessed_concat.h5ad\")\n", "for i in range(len(slice_name_list)):\n", " cas_list[i].write_h5ad(save_dir + f\"filtered_{slice_name_list[i]}.h5ad\")\n", "\n", "cas_list = [ad.read_h5ad(save_dir + f\"filtered_{sample}.h5ad\") for sample in slice_name_list]\n", "origin_concat = ad.concat(cas_list, label=\"slice_idx\", keys=slice_index_list)\n", "adata_concat = ad.read_h5ad(save_dir + f\"preprocessed_concat.h5ad\")" ] }, { "cell_type": "markdown", "id": "946446bd-fc53-4baa-b808-342c880e000d", "metadata": {}, "source": [ "### Perform PCA" ] }, { "cell_type": "code", "execution_count": null, "id": "867b0213-0ff6-420e-b7e9-7fe380eec8fd", "metadata": {}, "outputs": [], "source": [ "print(f'Applying PCA to reduce the feature dimension to 100 ...')\n", "pca = PCA(n_components=100, random_state=1234)\n", "input_matrix = pca.fit_transform(adata_concat.X.toarray())\n", "np.save(save_dir + f'input_matrix.npy', input_matrix)\n", "print('Done !')\n", "\n", "input_matrix = np.load(save_dir + 'input_matrix.npy')\n", "adata_concat.obsm['X_pca'] = input_matrix" ] }, { "cell_type": "markdown", "id": "bb81959c-5e51-429e-bd18-fd446775ed39", "metadata": {}, "source": [ "### Create neighbor graph" ] }, { "cell_type": "code", "execution_count": null, "id": "e5bd596a-746a-4749-9297-3eeeb8408e73", "metadata": {}, "outputs": [], "source": [ "# calculate the spatial graph\n", "INSTINCT.create_neighbor_graph(cas_list, adata_concat)" ] }, { "cell_type": "markdown", "id": "b77481d7-dcc1-4c7b-aa68-0859e772aa29", "metadata": {}, "source": [ "### Run model" ] }, { "cell_type": "code", "execution_count": null, "id": "ae0dee9c-56a5-4c0b-a4de-53b2291bc542", "metadata": {}, "outputs": [], "source": [ "INSTINCT_model = INSTINCT.INSTINCT_Model(cas_list, adata_concat, device=device)\n", "\n", "INSTINCT_model.train(report_loss=True, report_interval=100)\n", "\n", "INSTINCT_model.eval(cas_list)" ] }, { "cell_type": "code", "execution_count": null, "id": "7fd8d79d-1b58-4e5f-8ba5-cf7f13eddc56", "metadata": {}, "outputs": [], "source": [ "result = ad.concat(cas_list, label=\"slice_idx\", keys=slice_index_list)\n", "\n", "with open(save_dir + 'INSTINCT_embed.csv', 'w', newline='') as file:\n", " writer = csv.writer(file)\n", " writer.writerows(result.obsm['INSTINCT_latent'])\n", "\n", "with open(save_dir + 'INSTINCT_noise_embed.csv', 'w', newline='') as file:\n", " writer = csv.writer(file)\n", " writer.writerows(result.obsm['INSTINCT_latent_noise'])" ] }, { "cell_type": "markdown", "id": "e1407e21-c454-4988-a7c8-94f18fa728de", "metadata": {}, "source": [ "### Clustering " ] }, { "cell_type": "code", "execution_count": null, "id": "0dc54071-8c49-4607-91e1-6aacd81713f0", "metadata": {}, "outputs": [], "source": [ "def match_cluster_labels(true_labels, est_labels):\n", " true_labels_arr = np.array(list(true_labels))\n", " est_labels_arr = np.array(list(est_labels))\n", "\n", " org_cat = list(np.sort(list(pd.unique(true_labels))))\n", " est_cat = list(np.sort(list(pd.unique(est_labels))))\n", "\n", " B = nx.Graph()\n", " B.add_nodes_from([i + 1 for i in range(len(org_cat))], bipartite=0)\n", " B.add_nodes_from([-j - 1 for j in range(len(est_cat))], bipartite=1)\n", "\n", " for i in range(len(org_cat)):\n", " for j in range(len(est_cat)):\n", " weight = np.sum((true_labels_arr == org_cat[i]) * (est_labels_arr == est_cat[j]))\n", " B.add_edge(i + 1, -j - 1, weight=-weight)\n", "\n", " match = nx.algorithms.bipartite.matching.minimum_weight_full_matching(B)\n", "\n", " if len(org_cat) >= len(est_cat):\n", " return np.array([match[-est_cat.index(c) - 1] - 1 for c in est_labels_arr])\n", " else:\n", " unmatched = [c for c in est_cat if not (-est_cat.index(c) - 1) in match.keys()]\n", " l = []\n", " for c in est_labels_arr:\n", " if (-est_cat.index(c) - 1) in match:\n", " l.append(match[-est_cat.index(c) - 1] - 1)\n", " else:\n", " l.append(len(org_cat) + unmatched.index(c))\n", " return np.array(l)" ] }, { "cell_type": "code", "execution_count": null, "id": "9fb198bb-5715-48e3-954b-26cd5dfea59a", "metadata": {}, "outputs": [], "source": [ "gm = GaussianMixture(n_components=11, covariance_type='tied', random_state=1234)\n", "y = gm.fit_predict(result.obsm['INSTINCT_latent'], y=None)\n", "result.obs[\"gm_clusters\"] = pd.Series(y, index=result.obs.index, dtype='category')\n", "result.obs['matched_clusters'] = pd.Series(match_cluster_labels(result.obs['clusters'],\n", " result.obs[\"gm_clusters\"]),\n", " index=result.obs.index, dtype='category')" ] }, { "cell_type": "markdown", "id": "04a23b01-d020-430f-a91e-91594e2908a9", "metadata": {}, "source": [ "### Evaluation" ] }, { "cell_type": "code", "execution_count": null, "id": "d0c76cae-9b31-492a-95e1-82dd9dfab2f1", "metadata": {}, "outputs": [], "source": [ "def cluster_metrics(target, pred):\n", " target = np.array(target)\n", " pred = np.array(pred)\n", " \n", " ari = adjusted_rand_score(target, pred)\n", " ami = adjusted_mutual_info_score(target, pred)\n", " nmi = normalized_mutual_info_score(target, pred)\n", " fmi = fowlkes_mallows_score(target, pred)\n", " comp = completeness_score(target, pred)\n", " homo = homogeneity_score(target, pred)\n", " print('ARI: %.3f, AMI: %.3f, NMI: %.3f, FMI: %.3f, Comp: %.3f, Homo: %.3f' % (ari, ami, nmi, fmi, comp, homo))\n", " \n", " return ari, ami, nmi, fmi, comp, homo\n", "\n", "\n", "def mean_average_precision(x: np.ndarray, y: np.ndarray, k: int=30, **kwargs) -> float:\n", " r\"\"\"\n", " Mean average precision\n", " Parameters\n", " ----------\n", " x\n", " Coordinates\n", " y\n", " Cell_type/Layer labels\n", " k\n", " k neighbors\n", " **kwargs\n", " Additional keyword arguments are passed to\n", " :class:`sklearn.neighbors.NearestNeighbors`\n", " Returns\n", " -------\n", " map\n", " Mean average precision\n", " \"\"\"\n", " \n", " def _average_precision(match: np.ndarray) -> float:\n", " if np.any(match):\n", " cummean = np.cumsum(match) / (np.arange(match.size) + 1)\n", " return cummean[match].mean().item()\n", " return 0.0\n", " \n", " y = np.array(y)\n", " knn = sklearn.neighbors.NearestNeighbors(n_neighbors=min(y.shape[0], k + 1), **kwargs).fit(x)\n", " nni = knn.kneighbors(x, return_distance=False)\n", " match = np.equal(y[nni[:, 1:]], np.expand_dims(y, 1))\n", " \n", " return np.apply_along_axis(_average_precision, 1, match).mean().item()\n", "\n", "\n", "def rep_metrics(adata, origin_concat, use_rep, label_key, batch_key, k_map=30):\n", " if label_key not in adata.obs or batch_key not in adata.obs or use_rep not in adata.obsm:\n", " print(\"KeyError\")\n", " return None\n", " \n", " adata.obs[label_key] = adata.obs[label_key].astype(str).astype(\"category\")\n", " adata.obs[batch_key] = adata.obs[batch_key].astype(str).astype(\"category\")\n", " origin_concat.X = origin_concat.X.astype(float)\n", " sc.pp.neighbors(adata, use_rep=use_rep)\n", "\n", " MAP = mean_average_precision(adata.obsm[use_rep].copy(), adata.obs[label_key], k=k_map)\n", " cell_type_ASW = scib.me.silhouette(adata, label_key=label_key, embed=use_rep)\n", " # g_iLISI = scib.me.ilisi_graph(adata, batch_key=batch_key, type_=\"embed\", use_rep=use_rep)\n", " batch_ASW = scib.me.silhouette_batch(adata, batch_key=batch_key, label_key=label_key, embed=use_rep, verbose=False)\n", " batch_PCR = scib.me.pcr_comparison(origin_concat, adata, covariate=batch_key, embed=use_rep)\n", " kBET = scib.me.kBET(adata, batch_key=batch_key, label_key=label_key, type_='embed', embed=use_rep)\n", " g_conn = scib.me.graph_connectivity(adata, label_key=label_key)\n", " print('mAP: %.3f, Cell type ASW: %.3f, Batch ASW: %.3f, Batch PCR: %.3f, kBET: %.3f, Graph connectivity: %.3f' %\n", " (MAP, cell_type_ASW, batch_ASW, batch_PCR, kBET, g_conn))\n", " \n", " return MAP, cell_type_ASW, batch_ASW, batch_PCR, kBET, g_conn" ] }, { "cell_type": "code", "execution_count": null, "id": "b829fb3b-836b-4a88-af15-996efb348f46", "metadata": {}, "outputs": [], "source": [ "ari, ami, nmi, fmi, comp, homo = cluster_metrics(result.obs['clusters'],\n", " result.obs['matched_clusters'].tolist())\n", "map, c_asw, b_asw, b_pcr, kbet, g_conn = rep_metrics(result, origin_concat, use_rep='INSTINCT_latent',\n", " label_key='clusters', batch_key='slice_idx')" ] }, { "cell_type": "markdown", "id": "03dc07ca-e410-422c-9b3f-6add8bbf83b2", "metadata": {}, "source": [ "### Spatial domain identification and UMAP visualization" ] }, { "cell_type": "code", "execution_count": null, "id": "1368d09e-9b59-4532-8ecf-572d5598247e", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import anndata as ad\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import matplotlib.patches as mpatches\n", "from umap.umap_ import UMAP\n", "from sklearn.mixture import GaussianMixture\n", "\n", "from matplotlib.lines import Line2D\n", "import matplotlib as mpl\n", "mpl.rcParams['pdf.fonttype'] = 42\n", "mpl.rcParams['ps.fonttype'] = 42" ] }, { "cell_type": "code", "execution_count": null, "id": "c20610d7-6f6d-428b-b7b7-c1ab9f9928d6", "metadata": {}, "outputs": [], "source": [ "save_dir = '../../results/MouseEmbryo_Llorens-Bobadilla2023/all/'\n", "save = True\n", "\n", "cluster_list = ['Forebrain', 'Midbrain', 'Hindbrain', 'Periventricular', 'Meningeal_PNS_1', 'Meningeal_PNS_2',\n", " 'Internal', 'Facial_bone', 'Muscle_heart', 'Limb', 'Liver']\n", "\n", "label_list = ['Forebrain', 'Midbrain', 'Hindbrain', 'Periventricular', 'Meningeal/PNS_1', 'Meningeal/PNS_2',\n", " 'Internal', 'Facial/bone', 'Muscle/heart', 'Limb', 'Liver']\n", "\n", "color_list = ['royalblue', 'dodgerblue', 'deepskyblue', 'forestgreen', 'yellowgreen', 'y',\n", " 'grey', 'crimson', 'deeppink', 'orchid', 'orange']\n", "\n", "order_list = [1, 8, 2, 10, 6, 7, 3, 0, 9, 4, 5]\n", "\n", "cluster_to_color_map = {cluster: color for cluster, color in zip(cluster_list, color_list)}\n", "order_to_cluster_map = {order: cluster for order, cluster in zip(order_list, cluster_list)}\n", "\n", "reducer = UMAP(n_neighbors=30, n_components=2, metric=\"correlation\", n_epochs=None, learning_rate=1.0,\n", " min_dist=0.3, spread=1.0, set_op_mix_ratio=1.0, local_connectivity=1, repulsion_strength=1,\n", " negative_sample_rate=5, a=None, b=None, random_state=1234, metric_kwds=None,\n", " angular_rp_forest=False, verbose=False)\n", "\n", "slice_name_list = ['E12_5-S1', 'E12_5-S2', 'E13_5-S1', 'E13_5-S2', 'E15_5-S1', 'E15_5-S2']\n", "cas_list = [ad.read_h5ad(save_dir + f\"filtered_{sample}.h5ad\") for sample in slice_name_list]\n", "adata_concat = ad.concat(cas_list, label='slice_name', keys=slice_name_list)\n", "\n", "spots_count = [0]\n", "n = 0\n", "for sample in cas_list:\n", " num = sample.shape[0]\n", " n += num\n", " spots_count.append(n)\n", "\n", "embed = pd.read_csv(save_dir + f'INSTINCT_embed.csv', header=None).values\n", "adata_concat.obsm['latent'] = embed\n", "\n", "gm = GaussianMixture(n_components=len(cluster_list), covariance_type='tied', random_state=1234)\n", "y = gm.fit_predict(adata_concat.obsm['latent'], y=None)\n", "adata_concat.obs[\"gm_clusters\"] = pd.Series(y, index=adata_concat.obs.index, dtype='category')\n", "adata_concat.obs['matched_clusters'] = pd.Series(match_cluster_labels(\n", " adata_concat.obs['clusters'], adata_concat.obs[\"gm_clusters\"]),\n", " index=adata_concat.obs.index, dtype='category')\n", "# adata_concat.obs['matched_clusters'] = list(adata_concat.obs['matched_clusters'].map(order_to_cluster_map))\n", "my_clusters = np.sort(list(set(adata_concat.obs['matched_clusters'])))\n", "matched_colors = [cluster_to_color_map[order_to_cluster_map[order]] for order in my_clusters]\n", "matched_to_color_map = {matched: color for matched, color in zip(my_clusters, matched_colors)}\n", "\n", "for i in range(len(cas_list)):\n", " cas_list[i].obs['matched_clusters'] = adata_concat.obs['matched_clusters'][spots_count[i]:spots_count[i+1]]\n", "\n", "sp_embedding = reducer.fit_transform(adata_concat.obsm['latent'])" ] }, { "cell_type": "code", "execution_count": null, "id": "0c1e56f7-3124-46c4-a3c5-99e89cdf5278", "metadata": {}, "outputs": [], "source": [ "def plot_mouseembryo_6(cas_list, adata_concat, ground_truth_key, matched_clusters_key, model,\n", " cluster_to_color_map, matched_to_color_map, cluster_orders,\n", " slice_name_list, sp_embedding,\n", " save_root=None, frame_color=None, save=False, plot=False):\n", "\n", " fig, axs = plt.subplots(2, 3, figsize=(10, 6))\n", " fig.suptitle(f'{model} Clustering Results', fontsize=16)\n", " for i in range(len(cas_list)):\n", " if slice_name_list[i] == 'E12_5-S1' or slice_name_list[i] == 'E12_5-S2':\n", " size = 20\n", " else:\n", " size = 15\n", " if slice_name_list[i] == 'E15_5-S1':\n", " axs[int(i % 2), int(i / 2)].invert_xaxis()\n", " axs[int(i % 2), int(i / 2)].invert_yaxis()\n", " cluster_colors = list(cas_list[i].obs[matched_clusters_key].map(matched_to_color_map))\n", " axs[int(i % 2), int(i / 2)].scatter(cas_list[i].obsm['spatial'][:, 1], cas_list[i].obsm['spatial'][:, 0],\n", " linewidth=0.5, s=size, marker=\".\", color=cluster_colors, alpha=0.9)\n", " axs[int(i % 2), int(i / 2)].set_title(f'{slice_name_list[i]} (Cluster Results)', size=12)\n", " axs[int(i % 2), int(i / 2)].axis('off')\n", "\n", " legend_handles = [\n", " Line2D([0], [0], marker='o', color='w', markersize=8, markerfacecolor=matched_to_color_map[order],\n", " label=f'{i}') for i, order in enumerate(cluster_orders)\n", " ]\n", " axs[0, 2].legend(\n", " handles=legend_handles,\n", " fontsize=8, title='Clusters', title_fontsize=10, bbox_to_anchor=(1, 1))\n", " plt.gcf().subplots_adjust(left=0.05, top=None, bottom=None, right=0.85)\n", " if save:\n", " save_path = save_root + f'/{model}_clustering_results.pdf'\n", " plt.savefig(save_path)\n", "\n", " n_spots = adata_concat.shape[0]\n", " size = 10000 / n_spots\n", " order = np.arange(n_spots)\n", " colors_for_slices = ['deeppink', 'hotpink', 'darkgoldenrod', 'goldenrod', 'c', 'cyan']\n", " slice_cmap = {slice_name_list[i]: colors_for_slices[i] for i in range(len(slice_name_list))}\n", " colors = list(adata_concat.obs['slice_name'].astype('str').map(slice_cmap))\n", " plt.figure(figsize=(5, 5))\n", " if frame_color:\n", " plt.rc('axes', edgecolor=frame_color, linewidth=2)\n", " plt.scatter(sp_embedding[order, 0], sp_embedding[order, 1], s=size, c=colors)\n", " plt.tick_params(axis='both', bottom=False, top=False, left=False, right=False,\n", " labelleft=False, labelbottom=False, grid_alpha=0)\n", " legend_handles = [\n", " Line2D([0], [0], marker='o', color='w', markersize=8, markerfacecolor=slice_cmap[slice_name_list[i]],\n", " label=slice_name_list[i])\n", " for i in range(len(slice_name_list))\n", " ]\n", " plt.legend(handles=legend_handles, fontsize=8, title='Slices', title_fontsize=10,\n", " loc='upper left')\n", " plt.title(f'Slices ({model})', fontsize=16)\n", " if save:\n", " save_path = save_root + f\"/{model}_slices_umap.pdf\"\n", " plt.savefig(save_path)\n", "\n", " colors = list(adata_concat.obs[ground_truth_key].astype('str').map(cluster_to_color_map))\n", " plt.figure(figsize=(5, 5))\n", " if frame_color:\n", " plt.rc('axes', edgecolor=frame_color, linewidth=2)\n", " plt.scatter(sp_embedding[order, 0], sp_embedding[order, 1], s=size, c=colors)\n", " plt.tick_params(axis='both', bottom=False, top=False, left=False, right=False,\n", " labelleft=False, labelbottom=False, grid_alpha=0)\n", " plt.title(f'Annotated Spot-types ({model})', fontsize=16)\n", " if save:\n", " save_path = save_root + f\"/{model}_annotated_clusters_umap.pdf\"\n", " plt.savefig(save_path)\n", "\n", " colors = list(adata_concat.obs[matched_clusters_key].map(matched_to_color_map))\n", " plt.figure(figsize=(5, 5))\n", " if frame_color:\n", " plt.rc('axes', edgecolor=frame_color, linewidth=2)\n", " plt.scatter(sp_embedding[order, 0], sp_embedding[order, 1], s=size, c=colors)\n", " plt.tick_params(axis='both', bottom=False, top=False, left=False, right=False,\n", " labelleft=False, labelbottom=False, grid_alpha=0)\n", " plt.title(f'Identified Clusters ({model})', fontsize=16)\n", " if save:\n", " save_path = save_root + f\"/{model}_identified_clusters_umap.pdf\"\n", " plt.savefig(save_path)\n", "\n", " if plot:\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "aab925c2-8897-4b48-a268-b529d008f561", "metadata": {}, "outputs": [], "source": [ "plot_mouseembryo_6(cas_list, adata_concat, 'clusters', 'matched_clusters', 'INSTINCT', cluster_to_color_map,\n", " matched_to_color_map, my_clusters, slice_name_list, sp_embedding,\n", " save_root=save_dir, frame_color='darkviolet', save=save, plot=True)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }