{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1566fa5a-215b-4c57-9f45-ce9cd0382b00",
   "metadata": {},
   "source": [
    "# Spatial domain identification and UMAP visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f2566dd-983d-493b-befa-2ee306252fcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import anndata as ad\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import scanpy as sc\n",
    "import networkx as nx\n",
    "from umap.umap_ import UMAP\n",
    "from sklearn.mixture import GaussianMixture\n",
    "\n",
    "from matplotlib.lines import Line2D\n",
    "import matplotlib as mpl\n",
    "# Type 42 (TrueType) fonts are embedded in the exported PDF/PS files.\n",
    "mpl.rcParams['pdf.fonttype'] = 42\n",
    "mpl.rcParams['ps.fonttype'] = 42\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48c454ef-85ac-4f4e-a681-e5d846932cd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def match_cluster_labels(true_labels, est_labels):\n",
    "    \"\"\"Relabel estimated clusters so they line up with the ground-truth categories.\n",
    "\n",
    "    A complete bipartite graph links every sorted unique ground-truth category to\n",
    "    every sorted unique estimated cluster. Each edge carries the negated number of\n",
    "    observations shared by the pair, so the minimum-weight full matching is the\n",
    "    assignment with the largest total overlap.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    true_labels : iterable\n",
    "        Ground-truth label of each observation.\n",
    "    est_labels : iterable\n",
    "        Estimated cluster label of each observation, in the same order.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        One integer per observation: the 0-based rank (in sorted order) of the\n",
    "        ground-truth category matched to that observation's cluster. If there are\n",
    "        more clusters than categories, a cluster without a partner gets\n",
    "        ``len(categories) + k``, where ``k`` is its position among the unmatched\n",
    "        clusters in sorted order.\n",
    "    \"\"\"\n",
    "    true_labels_arr = np.array(list(true_labels))\n",
    "    est_labels_arr = np.array(list(est_labels))\n",
    "\n",
    "    org_cat = list(np.sort(list(pd.unique(true_labels))))\n",
    "    est_cat = list(np.sort(list(pd.unique(est_labels))))\n",
    "\n",
    "    # Categories are nodes 1..n and clusters are nodes -1..-m, so the two sides of\n",
    "    # the bipartite graph can never share a node id.\n",
    "    B = nx.Graph()\n",
    "    B.add_nodes_from([i + 1 for i in range(len(org_cat))], bipartite=0)\n",
    "    B.add_nodes_from([-j - 1 for j in range(len(est_cat))], bipartite=1)\n",
    "\n",
    "    for i in range(len(org_cat)):\n",
    "        for j in range(len(est_cat)):\n",
    "            # Negated overlap: minimising the total weight maximises the overlap.\n",
    "            weight = np.sum((true_labels_arr == org_cat[i]) * (est_labels_arr == est_cat[j]))\n",
    "            B.add_edge(i + 1, -j - 1, weight=-weight)\n",
    "\n",
    "    match = nx.algorithms.bipartite.matching.minimum_weight_full_matching(B)\n",
    "\n",
    "    # Resolve each cluster once instead of searching est_cat for every observation.\n",
    "    unmatched = [c for j, c in enumerate(est_cat) if (-j - 1) not in match]\n",
    "    relabel = {}\n",
    "    for j, c in enumerate(est_cat):\n",
    "        node = -j - 1\n",
    "        if node in match:\n",
    "            relabel[c] = match[node] - 1\n",
    "        else:\n",
    "            # More clusters than categories: leftovers are numbered after the categories.\n",
    "            relabel[c] = len(org_cat) + unmatched.index(c)\n",
    "\n",
    "    return np.array([relabel[c] for c in est_labels_arr])\n",
    "\n",
    "\n",
    "def _save_figure(fig, save_root, file_name, **savefig_kwargs):\n",
    "    \"\"\"Save ``fig`` as ``save_root/file_name``, creating ``save_root`` if it is missing.\"\"\"\n",
    "    os.makedirs(save_root, exist_ok=True)\n",
    "    fig.savefig(os.path.join(save_root, file_name), **savefig_kwargs)\n",
    "\n",
    "\n",
    "def _plot_embedding(embedding, colors, size, title, frame_color=None):\n",
    "    \"\"\"Scatter a 2-D embedding on a new 5x5-inch figure and return that figure.\"\"\"\n",
    "    # Scope the frame styling to this figure; a bare plt.rc() call would change the\n",
    "    # global rcParams and leak into every figure created afterwards.\n",
    "    frame_style = {'axes.edgecolor': frame_color, 'axes.linewidth': 2} if frame_color else {}\n",
    "    with mpl.rc_context(frame_style):\n",
    "        fig, ax = plt.subplots(figsize=(5, 5))\n",
    "        ax.scatter(embedding[:, 0], embedding[:, 1], s=size, c=colors)\n",
    "        ax.tick_params(axis='both', bottom=False, top=False, left=False, right=False,\n",
    "                       labelleft=False, labelbottom=False, grid_alpha=0)\n",
    "        ax.set_title(title, fontsize=14)\n",
    "    return fig\n",
    "\n",
    "\n",
    "def plot_DLPFC(rna_list, adata_concat, ground_truth_key, matched_clusters_key, model, group_idx, cluster_to_color_map,\n",
    "               matched_to_color_map, cluster_orders, slice_name_list, cls_list, sp_embedding,\n",
    "               save_root=None, frame_color=None, file_format='pdf', save=False, plot=False):\n",
    "    \"\"\"Draw and optionally save the clustering figures for one DLPFC sample group.\n",
    "\n",
    "    Four figures are produced: a 2x4 grid with the ground-truth annotation (top\n",
    "    row) and the matched clusters (bottom row) of every slice in spatial\n",
    "    coordinates, followed by three scatters of ``sp_embedding`` coloured by slice,\n",
    "    by ground-truth annotation and by matched cluster.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    rna_list : list\n",
    "        Per-slice AnnData objects providing ``obsm['spatial']`` and the two label\n",
    "        columns in ``obs``. The grid has four columns, so at most four slices fit.\n",
    "    adata_concat : AnnData\n",
    "        Concatenation of the slices; its ``obs`` must hold ``'slice_name'`` and the\n",
    "        two label columns.\n",
    "    ground_truth_key, matched_clusters_key : str\n",
    "        Names of the ``obs`` columns with the annotation and the matched clusters.\n",
    "    model : str\n",
    "        Model name used in titles and file names.\n",
    "    group_idx : int\n",
    "        Index of the sample group; 0, 1, 2 are titled Sample A, B, C.\n",
    "    cluster_to_color_map : dict\n",
    "        Colour for each annotation label (looked up by its string form).\n",
    "    matched_to_color_map : dict\n",
    "        Colour for each matched cluster label.\n",
    "    cluster_orders : iterable\n",
    "        Matched cluster labels to list in the cluster legend; entries are labelled\n",
    "        by their position, starting at 0.\n",
    "    slice_name_list : list of str\n",
    "        Slice names, used for panel titles and slice colours (at most four).\n",
    "    cls_list : list\n",
    "        Annotation labels to list in the spot-type legend.\n",
    "    sp_embedding : numpy.ndarray\n",
    "        2-D embedding with one row per observation of ``adata_concat``.\n",
    "    save_root : str, optional\n",
    "        Output directory; required when ``save`` is true. Created if missing.\n",
    "    frame_color : optional\n",
    "        Edge colour of the embedding plots' frames; a falsy value keeps the default.\n",
    "    file_format : str\n",
    "        Extension of the saved files.\n",
    "    save : bool\n",
    "        Whether to write the figures to ``save_root``.\n",
    "    plot : bool\n",
    "        Show the figures with ``plt.show()``. When false the figures are closed\n",
    "        before returning so that repeated calls do not accumulate open figures.\n",
    "    \"\"\"\n",
    "    samples = ['A', 'B', 'C']\n",
    "    sample_name = samples[group_idx]\n",
    "    file_stem = f'{model}_group{group_idx}'\n",
    "    figures = []\n",
    "\n",
    "    # Figure 1: spatial maps, ground truth on top and matched clusters below.\n",
    "    fig, axs = plt.subplots(2, 4, figsize=(15, 7))\n",
    "    figures.append(fig)\n",
    "    fig.suptitle(f'{model} Clustering Results (Sample {sample_name})', fontsize=16)\n",
    "    for i in range(len(rna_list)):\n",
    "        coords = rna_list[i].obsm['spatial']\n",
    "\n",
    "        real_colors = list(rna_list[i].obs[ground_truth_key].astype('str').map(cluster_to_color_map))\n",
    "        axs[0, i].scatter(coords[:, 0], coords[:, 1], linewidth=0.5, s=30,\n",
    "                          marker='.', color=real_colors, alpha=0.9)\n",
    "        axs[0, i].set_title(f'{slice_name_list[i]} (Ground Truth)', size=12)\n",
    "        axs[0, i].invert_yaxis()\n",
    "        axs[0, i].axis('off')\n",
    "\n",
    "        cluster_colors = list(rna_list[i].obs[matched_clusters_key].map(matched_to_color_map))\n",
    "        axs[1, i].scatter(coords[:, 0], coords[:, 1], linewidth=0.5, s=30,\n",
    "                          marker='.', color=cluster_colors, alpha=0.9)\n",
    "        axs[1, i].set_title(f'{slice_name_list[i]} (Cluster Results)', size=12)\n",
    "        axs[1, i].invert_yaxis()\n",
    "        axs[1, i].axis('off')\n",
    "\n",
    "    legend_handles_1 = [\n",
    "        Line2D([0], [0], marker='o', color='w', markersize=8, markerfacecolor=cluster_to_color_map[cluster],\n",
    "               label=cluster) for cluster in cls_list\n",
    "    ]\n",
    "    axs[0, 3].legend(\n",
    "        handles=legend_handles_1,\n",
    "        fontsize=8, title='Spot-types', title_fontsize=10, bbox_to_anchor=(1, 1.15))\n",
    "    legend_handles_2 = [\n",
    "        Line2D([0], [0], marker='o', color='w', markersize=8, markerfacecolor=matched_to_color_map[order],\n",
    "               label=f'{i}') for i, order in enumerate(cluster_orders)\n",
    "    ]\n",
    "    axs[1, 3].legend(\n",
    "        handles=legend_handles_2,\n",
    "        fontsize=8, title='Clusters', title_fontsize=10, bbox_to_anchor=(1, 1.1))\n",
    "    # Leave room on the right for the two legends.\n",
    "    fig.subplots_adjust(left=0.05, top=None, bottom=None, right=0.85)\n",
    "    if save:\n",
    "        _save_figure(fig, save_root, f'{file_stem}_clustering_results.{file_format}', dpi=500)\n",
    "\n",
    "    # Figures 2-4: the same embedding coloured three different ways.\n",
    "    n_spots = adata_concat.shape[0]\n",
    "    size = 10000 / n_spots  # marker size shrinks as the number of spots grows\n",
    "    embedding = sp_embedding[np.arange(n_spots)]\n",
    "    colors_for_slices = [[0.2298057, 0.29871797, 0.75368315],\n",
    "                         [0.70567316, 0.01555616, 0.15023281],\n",
    "                         [0.2298057, 0.70567316, 0.15023281],\n",
    "                         [0.5830223, 0.59200322, 0.12993134]]\n",
    "    slice_cmap = {slice_name_list[i]: colors_for_slices[i] for i in range(len(slice_name_list))}\n",
    "\n",
    "    obs = adata_concat.obs\n",
    "    panels = [\n",
    "        (obs['slice_name'].astype('str').map(slice_cmap),\n",
    "         f'Slices ({model}/Sample {sample_name})', 'slices_umap'),\n",
    "        (obs[ground_truth_key].astype('str').map(cluster_to_color_map),\n",
    "         f'Annotated Spot-types ({model}/Sample {sample_name})', 'annotated_clusters_umap'),\n",
    "        (obs[matched_clusters_key].map(matched_to_color_map),\n",
    "         f'Identified Clusters ({model}/Sample {sample_name})', 'identified_clusters_umap'),\n",
    "    ]\n",
    "    for colors, title, suffix in panels:\n",
    "        fig = _plot_embedding(embedding, list(colors), size, title, frame_color=frame_color)\n",
    "        figures.append(fig)\n",
    "        if save:\n",
    "            _save_figure(fig, save_root, f'{file_stem}_{suffix}.{file_format}')\n",
    "\n",
    "    if plot:\n",
    "        plt.show()\n",
    "    else:\n",
    "        # Nothing will display these figures, so release them.\n",
    "        for fig in figures:\n",
    "            plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbd4f423-65f9-447a-a806-74aedb6ac581",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = '../../results/DLPFC_Maynard2021/'\n",
    "save = True\n",
    "\n",
    "# DLPFC\n",
    "data_dir = '../../data/STdata/10xVisium/DLPFC_Maynard2021/'\n",
    "sample_group_list = [['151507', '151508', '151509', '151510'],\n",
    "                     ['151669', '151670', '151671', '151672'],\n",
    "                     ['151673', '151674', '151675', '151676']]\n",
    "cls_list = ['Layer1', 'Layer2', 'Layer3', 'Layer4', 'Layer5', 'Layer6', 'WM']\n",
    "num_clusters_list = [7, 5, 7]\n",
    "# Offset added to the 0-based matched labels of each group so that they become keys\n",
    "# of matched_to_color_map. NOTE(review): group B uses 3, presumably because its\n",
    "# annotation starts at Layer3 and the cluster colours should line up with the layer\n",
    "# colours - confirm against the data.\n",
    "label_offset_list = [1, 3, 1]\n",
    "samples = ['A', 'B', 'C']\n",
    "\n",
    "file_format = 'pdf'\n",
    "\n",
    "# One palette lookup; the i-th annotation label and matched cluster i + 1 share a colour.\n",
    "palette = sns.color_palette()\n",
    "layer_to_color_map = dict(zip(cls_list, palette))\n",
    "matched_to_color_map = {i + 1: palette[i] for i in range(len(cls_list))}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "456ff028-b6b5-49dd-8f30-f06c40b87621",
   "metadata": {},
   "outputs": [],
   "source": [
    "reducer = UMAP(n_neighbors=30, n_components=2, metric='correlation', n_epochs=None, learning_rate=1.0,\n",
    "               min_dist=0.3, spread=1.0, set_op_mix_ratio=1.0, local_connectivity=1, repulsion_strength=1,\n",
    "               negative_sample_rate=5, a=None, b=None, random_state=1234, metric_kwds=None,\n",
    "               angular_rp_forest=False, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a49e116a-ffbb-4296-aa0f-353ecf4aee62",
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx, slice_name_list in enumerate(sample_group_list):\n",
    "\n",
    "    rna_list = []\n",
    "    for sample in slice_name_list:\n",
    "        sample_dir = os.path.join(data_dir, sample)\n",
    "        adata = sc.read_visium(path=sample_dir,\n",
    "                               count_file=sample + '_filtered_feature_bc_matrix.h5')\n",
    "        adata.var_names_make_unique()\n",
    "\n",
    "        # read the annotation\n",
    "        Ann_df = pd.read_csv(os.path.join(sample_dir, 'meta_data.csv'), sep=',', index_col=0)\n",
    "\n",
    "        if not all(Ann_df.index.isin(adata.obs_names)):\n",
    "            raise ValueError('Some rows in the annotation file are not present in the adata.obs_names')\n",
    "\n",
    "        adata.obs['image_row'] = Ann_df.loc[adata.obs_names, 'imagerow']\n",
    "        adata.obs['image_col'] = Ann_df.loc[adata.obs_names, 'imagecol']\n",
    "        adata.obs['Manual_Annotation'] = Ann_df.loc[adata.obs_names, 'ManualAnnotation']\n",
    "\n",
    "        # Suffix the barcodes with the slice name so they stay unique after concatenation.\n",
    "        adata.obs_names = [x + '_' + sample for x in adata.obs_names]\n",
    "        rna_list.append(adata)\n",
    "\n",
    "    # concatenation\n",
    "    adata_concat = ad.concat(rna_list, label='slice_name', keys=slice_name_list)\n",
    "\n",
    "    # cluster the saved INSTINCT embedding (one row per spot of adata_concat)\n",
    "    embed = pd.read_csv(os.path.join(save_dir, f'INSTINCT_embed_group{idx}.csv'), header=None).values\n",
    "    adata_concat.obsm['latent'] = embed\n",
    "    gm = GaussianMixture(n_components=num_clusters_list[idx], covariance_type='tied', random_state=1234)\n",
    "    y = gm.fit_predict(adata_concat.obsm['latent'], y=None)\n",
    "    adata_concat.obs['gm_clusters'] = pd.Series(y, index=adata_concat.obs.index, dtype='category')\n",
    "\n",
    "    # Drop unannotated spots only after clustering. Subsetting returns views, so copy\n",
    "    # them into real objects before new obs columns are written.\n",
    "    adata_concat = adata_concat[~adata_concat.obs['Manual_Annotation'].isna(), :].copy()\n",
    "    rna_list = [slice_adata[~slice_adata.obs['Manual_Annotation'].isna(), :].copy()\n",
    "                for slice_adata in rna_list]\n",
    "\n",
    "    adata_concat.obs['matched_clusters'] = label_offset_list[idx] + match_cluster_labels(\n",
    "        adata_concat.obs['Manual_Annotation'], adata_concat.obs['gm_clusters'])\n",
    "    my_clusters = np.sort(list(set(adata_concat.obs['matched_clusters'])))\n",
    "\n",
    "    # Copy the labels back to each slice by spot name rather than by row position.\n",
    "    for slice_adata in rna_list:\n",
    "        slice_adata.obs['matched_clusters'] = adata_concat.obs.loc[\n",
    "            slice_adata.obs_names, 'matched_clusters'].to_numpy()\n",
    "\n",
    "    sp_embedding = reducer.fit_transform(adata_concat.obsm['latent'])\n",
    "\n",
    "    plot_DLPFC(rna_list, adata_concat, 'Manual_Annotation', 'matched_clusters', 'INSTINCT', idx, layer_to_color_map,\n",
    "               matched_to_color_map, my_clusters, slice_name_list, cls_list, sp_embedding,\n",
    "               save_root=save_dir, frame_color='darkviolet', file_format=file_format,\n",
    "               save=save, plot=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}