Spatial domain identification and UMAP visualization
[ ]:
import numpy as np
import anndata as ad
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import networkx as nx
from umap.umap_ import UMAP
from sklearn.mixture import GaussianMixture
from matplotlib.lines import Line2D
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42
import warnings
warnings.filterwarnings("ignore")
[ ]:
def match_cluster_labels(true_labels, est_labels):
    """Relabel estimated clusters so that they line up with reference labels.

    A complete bipartite graph is built between the sorted unique reference
    labels and the sorted unique estimated labels. Each edge carries the
    negated number of observations that have both labels, so the minimum-weight
    full matching maximizes the total overlap between paired categories.

    Parameters
    ----------
    true_labels : iterable
        Reference label of every observation.
    est_labels : iterable
        Estimated cluster label of every observation, in the same order as
        ``true_labels``.

    Returns
    -------
    numpy.ndarray
        One integer per observation. For a matched estimated cluster this is
        the 0-based position of its partner in the sorted unique reference
        labels. Estimated clusters left without a partner (possible only when
        there are more estimated than reference categories) receive
        ``len(reference categories) + k``, where ``k`` is their rank among the
        unmatched clusters in sorted order.
    """
    true_labels_arr = np.array(list(true_labels))
    est_labels_arr = np.array(list(est_labels))
    org_cat = list(np.sort(list(pd.unique(true_labels))))
    est_cat = list(np.sort(list(pd.unique(est_labels))))

    # Reference categories become nodes 1..n and estimated categories become
    # nodes -1..-m, so the two sides can never share a node id.
    B = nx.Graph()
    B.add_nodes_from([i + 1 for i in range(len(org_cat))], bipartite=0)
    B.add_nodes_from([-j - 1 for j in range(len(est_cat))], bipartite=1)
    for i, org in enumerate(org_cat):
        org_mask = true_labels_arr == org  # invariant across the inner loop
        for j, est in enumerate(est_cat):
            weight = np.sum(org_mask * (est_labels_arr == est))
            B.add_edge(i + 1, -j - 1, weight=-weight)
    match = nx.algorithms.bipartite.matching.minimum_weight_full_matching(B)

    # Resolve each estimated category once instead of scanning ``est_cat``
    # with ``list.index`` for every single observation.
    relabel = {}
    n_unmatched = 0
    for j, est in enumerate(est_cat):
        node = -j - 1
        if node in match:
            relabel[est] = match[node] - 1
        else:
            relabel[est] = len(org_cat) + n_unmatched
            n_unmatched += 1
    return np.array([relabel[c] for c in est_labels_arr])
def _scatter_embedding(sp_embedding, colors, size, title, frame_color=None):
    """Scatter a 2-D embedding on a new 5x5 inch figure with ticks hidden.

    When ``frame_color`` is truthy the axes frame is drawn in that colour with
    a line width of 2. The styling is applied through ``plt.rc_context`` so the
    global rcParams are restored once the axes exist.
    """
    plt.figure(figsize=(5, 5))
    frame_rc = {'axes.edgecolor': frame_color, 'axes.linewidth': 2} if frame_color else {}
    # The first pyplot drawing call creates the axes, so it has to run inside
    # the context for the frame styling to be picked up.
    with plt.rc_context(frame_rc):
        plt.scatter(sp_embedding[:, 0], sp_embedding[:, 1], s=size, c=colors)
    plt.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
                    labelleft=False, labelbottom=False, grid_alpha=0)
    plt.title(title, fontsize=14)


def plot_DLPFC(rna_list, adata_concat, ground_truth_key, matched_clusters_key, model, group_idx, cluster_to_color_map,
               matched_to_color_map, cluster_orders, slice_name_list, cls_list, sp_embedding,
               save_root=None, frame_color=None, file_format='pdf', save=False, plot=False):
    """Plot spatial clustering results and embedding scatters for one group.

    Four figures are produced:

    1. A 2x4 grid with, per slice, the ground-truth colouring (top row) and
       the matched-cluster colouring (bottom row) at the spatial coordinates.
       Legends are attached to the fourth column.
    2. The 2-D embedding coloured by slice.
    3. The 2-D embedding coloured by ground-truth annotation.
    4. The 2-D embedding coloured by matched cluster.

    Parameters
    ----------
    rna_list : list
        Per-slice objects exposing ``.obs`` and ``.obsm['spatial']``.
    adata_concat : object
        Concatenated object exposing ``.shape`` and ``.obs`` (including a
        ``'slice_name'`` column).
    ground_truth_key, matched_clusters_key : str
        ``obs`` columns holding the annotation and the matched cluster ids.
    model : str
        Model name used in titles and output file names.
    group_idx : int
        Index into the sample names ``['A', 'B', 'C']``; also used in file
        names.
    cluster_to_color_map, matched_to_color_map : dict
        Colour look-ups for annotation names and matched cluster ids.
    cluster_orders : iterable
        Matched cluster ids shown in the cluster legend; entries are labelled
        by their position in this iterable.
    slice_name_list : list of str
        Slice names, used for panel titles and for slice colouring (four
        slice colours are defined).
    cls_list : list of str
        Annotation names shown in the ground-truth legend.
    sp_embedding : array-like
        Array whose first two columns are the embedding coordinates, one row
        per observation of ``adata_concat``.
    save_root : str, optional
        Output directory; required when ``save`` is true.
    frame_color : optional
        Frame colour of the embedding scatters.
    file_format : str
        Extension of the saved figures.
    save : bool
        Write every figure below ``save_root``.
    plot : bool
        Call ``plt.show()`` after all figures have been drawn.

    Raises
    ------
    ValueError
        If ``save`` is true but ``save_root`` is ``None``.
    """
    # Fail before drawing anything rather than with a TypeError on the first
    # path concatenation.
    if save and save_root is None:
        raise ValueError("save_root must be provided when save=True")

    samples = ['A', 'B', 'C']
    sample_name = samples[group_idx]

    # ---- spatial panels: ground truth (row 0) vs. matched clusters (row 1)
    fig, axs = plt.subplots(2, 4, figsize=(15, 7))
    fig.suptitle(f'{model} Clustering Results (Sample {sample_name})', fontsize=16)
    for i, rna in enumerate(rna_list):
        coords = rna.obsm['spatial']
        real_colors = list(rna.obs[ground_truth_key].astype('str').map(cluster_to_color_map))
        cluster_colors = list(rna.obs[matched_clusters_key].map(matched_to_color_map))
        panels = ((0, real_colors, 'Ground Truth'), (1, cluster_colors, 'Cluster Results'))
        for row, point_colors, caption in panels:
            ax = axs[row, i]
            ax.scatter(coords[:, 0], coords[:, 1], linewidth=0.5, s=30,
                       marker=".", color=point_colors, alpha=0.9)
            ax.set_title(f'{slice_name_list[i]} ({caption})', size=12)
            ax.invert_yaxis()
            ax.axis('off')

    legend_handles_1 = [
        Line2D([0], [0], marker='o', color='w', markersize=8, markerfacecolor=cluster_to_color_map[cluster],
               label=cluster) for cluster in cls_list
    ]
    axs[0, 3].legend(
        handles=legend_handles_1,
        fontsize=8, title='Spot-types', title_fontsize=10, bbox_to_anchor=(1, 1.15))
    legend_handles_2 = [
        Line2D([0], [0], marker='o', color='w', markersize=8, markerfacecolor=matched_to_color_map[order],
               label=f'{i}') for i, order in enumerate(cluster_orders)
    ]
    axs[1, 3].legend(
        handles=legend_handles_2,
        fontsize=8, title='Clusters', title_fontsize=10, bbox_to_anchor=(1, 1.1))
    fig.subplots_adjust(left=0.05, top=None, bottom=None, right=0.85)
    if save:
        save_path = save_root + f'/{model}_group{group_idx}_clustering_results.{file_format}'
        fig.savefig(save_path, dpi=500)

    # ---- embedding scatters
    n_spots = adata_concat.shape[0]
    size = 10000 / n_spots  # marker area shrinks as the number of spots grows
    colors_for_slices = [[0.2298057, 0.29871797, 0.75368315],
                         [0.70567316, 0.01555616, 0.15023281],
                         [0.2298057, 0.70567316, 0.15023281],
                         [0.5830223, 0.59200322, 0.12993134]]
    slice_cmap = {slice_name_list[i]: colors_for_slices[i] for i in range(len(slice_name_list))}

    colors = list(adata_concat.obs['slice_name'].astype('str').map(slice_cmap))
    _scatter_embedding(sp_embedding, colors, size,
                       f'Slices ({model}/Sample {sample_name})', frame_color)
    if save:
        save_path = save_root + f"/{model}_group{group_idx}_slices_umap.{file_format}"
        plt.savefig(save_path)

    colors = list(adata_concat.obs[ground_truth_key].astype('str').map(cluster_to_color_map))
    _scatter_embedding(sp_embedding, colors, size,
                       f'Annotated Spot-types ({model}/Sample {sample_name})', frame_color)
    if save:
        save_path = save_root + f"/{model}_group{group_idx}_annotated_clusters_umap.{file_format}"
        plt.savefig(save_path)

    colors = list(adata_concat.obs[matched_clusters_key].map(matched_to_color_map))
    _scatter_embedding(sp_embedding, colors, size,
                       f'Identified Clusters ({model}/Sample {sample_name})', frame_color)
    if save:
        save_path = save_root + f"/{model}_group{group_idx}_identified_clusters_umap.{file_format}"
        plt.savefig(save_path)

    if plot:
        plt.show()
[ ]:
# Directory the embeddings are read from and the figures are written to,
# plus the switch that enables writing figures.
save_dir = '../../results/DLPFC_Maynard2021/'
save = True

# DLPFC
data_dir = '../../data/STdata/10xVisium/DLPFC_Maynard2021/'

# Each inner list holds the slice ids that are processed together as a group.
sample_group_list = [
    ['151507', '151508', '151509', '151510'],
    ['151669', '151670', '151671', '151672'],
    ['151673', '151674', '151675', '151676'],
]
cls_list = ['Layer1', 'Layer2', 'Layer3', 'Layer4', 'Layer5', 'Layer6', 'WM']

# Number of mixture components fitted for each group, in group order.
num_clusters_list = [7, 5, 7]
samples = ['A', 'B', 'C']
file_format = 'pdf'

# Annotation name -> colour: the first seven colours of the current seaborn
# palette, assigned in the order of ``cls_list``.
layer_to_color_map = {name: sns.color_palette()[pos] for pos, name in enumerate(cls_list)}

# Matched cluster id (1..7) -> colour, using the same seven palette entries.
matched_to_color_map = {cluster_id: sns.color_palette()[cluster_id - 1] for cluster_id in range(1, 8)}
[ ]:
# UMAP model, with a fixed seed, that is later fitted on the latent
# representation of each group to obtain a 2-D embedding.
reducer = UMAP(
    n_neighbors=30,
    n_components=2,
    metric="correlation",
    metric_kwds=None,
    n_epochs=None,
    learning_rate=1.0,
    min_dist=0.3,
    spread=1.0,
    a=None,
    b=None,
    set_op_mix_ratio=1.0,
    local_connectivity=1,
    repulsion_strength=1,
    negative_sample_rate=5,
    angular_rp_forest=False,
    random_state=1234,
    verbose=False,
)
[ ]:
# Process every slice group: load the slices, cluster the precomputed latent
# embedding, align cluster ids with the manual annotation and plot the result.
for idx, slice_name_list in enumerate(sample_group_list):
    slice_index_list = list(range(len(slice_name_list)))

    rna_list = []
    for sample in slice_name_list:
        adata = sc.read_visium(path=data_dir + f'{sample}/',
                               count_file=sample + '_filtered_feature_bc_matrix.h5')
        adata.var_names_make_unique()

        # read the annotation
        Ann_df = pd.read_csv(data_dir + f'{sample}/meta_data.csv', sep=',', index_col=0)
        if not Ann_df.index.isin(adata.obs_names).all():
            raise ValueError("Some rows in the annotation file are not present in the adata.obs_names")
        # Copy the annotation columns over, aligned on the spot barcodes.
        for obs_key, ann_col in (('image_row', 'imagerow'),
                                 ('image_col', 'imagecol'),
                                 ('Manual_Annotation', 'ManualAnnotation')):
            adata.obs[obs_key] = Ann_df.loc[adata.obs_names, ann_col]

        # Suffix the barcodes with the slice id before the slices are concatenated.
        adata.obs_names = [f'{barcode}_{sample}' for barcode in adata.obs_names]
        rna_list.append(adata)

    # concatenation
    adata_concat = ad.concat(rna_list, label="slice_name", keys=slice_name_list)

    # Attach the saved latent embedding and cluster it with a Gaussian mixture.
    embed = pd.read_csv(save_dir + f'/INSTINCT_embed_group{idx}.csv', header=None).values
    adata_concat.obsm['latent'] = embed
    gm = GaussianMixture(n_components=num_clusters_list[idx], covariance_type='tied', random_state=1234)
    y = gm.fit_predict(adata_concat.obsm['latent'], y=None)
    adata_concat.obs["gm_clusters"] = pd.Series(y, index=adata_concat.obs.index, dtype='category')

    # Keep only the spots that carry a manual annotation, both in the
    # concatenated object and in every individual slice.
    adata_concat = adata_concat[~adata_concat.obs['Manual_Annotation'].isna(), :]
    rna_list = [rna[~rna.obs['Manual_Annotation'].isna(), :] for rna in rna_list]

    # Cumulative spot counts: slice i occupies rows
    # spots_count[i]:spots_count[i + 1] of the concatenated object.
    spots_count = [0]
    for rna in rna_list:
        spots_count.append(spots_count[-1] + rna.shape[0])

    # Group 1 is shifted by 3 and every other group by 1, so the matched ids
    # index ``matched_to_color_map`` as intended.
    label_offset = 3 if idx == 1 else 1
    matched = label_offset + match_cluster_labels(
        adata_concat.obs['Manual_Annotation'], adata_concat.obs["gm_clusters"])
    adata_concat.obs['matched_clusters'] = list(pd.Series(
        matched, index=adata_concat.obs.index, dtype='category'))
    my_clusters = np.sort(list(set(adata_concat.obs['matched_clusters'])))

    # Hand every slice its own section of the matched cluster ids.
    for i, (start, stop) in enumerate(zip(spots_count[:-1], spots_count[1:])):
        rna_list[i].obs['matched_clusters'] = list(adata_concat.obs['matched_clusters'][start:stop])

    # plot clustering results
    sp_embedding = reducer.fit_transform(adata_concat.obsm['latent'])
    plot_DLPFC(rna_list, adata_concat, 'Manual_Annotation', 'matched_clusters', 'INSTINCT', idx, layer_to_color_map,
               matched_to_color_map, my_clusters, slice_name_list, cls_list, sp_embedding,
               save_root=save_dir, frame_color='darkviolet', file_format=file_format,
               save=save, plot=True)