Integrating samples from the DLPFC dataset
The human dorsolateral prefrontal cortex (DLPFC) dataset is a spatially resolved transcriptomics (SRT) dataset. It contains three sets of slices; each set consists of four vertically adjacent slices from a single donor.
In this example, we demonstrate INSTINCT's ability to integrate SRT samples.
[ ]:
import os
import csv
import torch
import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture
import INSTINCT
import warnings
warnings.filterwarnings("ignore")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.metrics.cluster import fowlkes_mallows_score
from sklearn.metrics.cluster import homogeneity_score
from sklearn.metrics.cluster import adjusted_mutual_info_score
from sklearn.metrics.cluster import completeness_score
import sklearn
import sklearn.neighbors
import networkx as nx
import scib
[ ]:
def match_cluster_labels(true_labels, est_labels):
    """Relabel estimated clusters so that they line up with reference labels.

    The sorted unique reference labels and the sorted unique estimated labels
    form the rows and columns of a contingency (overlap-count) table. A
    one-to-one assignment maximizing the total overlap is then solved with the
    Hungarian algorithm (``scipy.optimize.linear_sum_assignment``).

    Parameters
    ----------
    true_labels : iterable
        Reference label of every observation (e.g. a manual annotation).
    est_labels : iterable
        Estimated cluster label of every observation, aligned with
        ``true_labels``.

    Returns
    -------
    numpy.ndarray of int
        For each observation, the 0-based position (in sorted order) of the
        reference label its estimated cluster was matched to. If there are
        more estimated clusters than reference labels, the unmatched clusters
        receive ``n_ref``, ``n_ref + 1``, ... in the sorted order of their
        original labels (``n_ref`` = number of distinct reference labels).
        Ties between equally good assignments are broken by the solver.

    Raises
    ------
    ValueError
        If the two label sequences have different lengths.
    """
    from scipy.optimize import linear_sum_assignment

    true_arr = np.array(list(true_labels))
    est_arr = np.array(list(est_labels))
    if true_arr.shape != est_arr.shape:
        raise ValueError(
            f"true_labels and est_labels must have the same length "
            f"({true_arr.shape[0]} != {est_arr.shape[0]})"
        )

    # Sorted unique categories plus, for every observation, its category index.
    true_cats, true_codes = np.unique(true_arr, return_inverse=True)
    est_cats, est_codes = np.unique(est_arr, return_inverse=True)
    true_codes = np.ravel(true_codes)
    est_codes = np.ravel(est_codes)
    n_true, n_est = len(true_cats), len(est_cats)

    # overlap[i, j] = number of observations with reference i and estimate j,
    # accumulated in one pass (np.add.at handles repeated index pairs).
    overlap = np.zeros((n_true, n_est), dtype=np.int64)
    np.add.at(overlap, (true_codes, est_codes), 1)

    # On a rectangular table this assigns min(n_true, n_est) pairs.
    rows, cols = linear_sum_assignment(overlap, maximize=True)

    # mapping[j] = new label for estimated category j.
    mapping = np.empty(n_est, dtype=np.int64)
    mapping[cols] = rows
    # Extra estimated clusters (only when n_est > n_true) get fresh labels
    # after the reference ones; setdiff1d returns them sorted.
    unmatched = np.setdiff1d(np.arange(n_est), cols)
    mapping[unmatched] = n_true + np.arange(len(unmatched))

    return mapping[est_codes]
def cluster_metrics(target, pred):
    """Compare a predicted clustering against reference labels.

    Computes ARI, AMI, NMI, FMI, completeness and homogeneity with
    scikit-learn, each called as ``score(target, pred)``. The scores are
    printed to three decimals and returned in that order.

    Parameters
    ----------
    target : array-like
        Reference labels, one per observation.
    pred : array-like
        Predicted cluster labels, aligned with ``target``.

    Returns
    -------
    tuple of float
        ``(ari, ami, nmi, fmi, comp, homo)``.
    """
    reference = np.array(target)
    predicted = np.array(pred)
    # Order matters: it fixes both the printed order and the returned tuple.
    scorers = (
        adjusted_rand_score,
        adjusted_mutual_info_score,
        normalized_mutual_info_score,
        fowlkes_mallows_score,
        completeness_score,
        homogeneity_score,
    )
    scores = tuple(scorer(reference, predicted) for scorer in scorers)
    print('ARI: %.3f, AMI: %.3f, NMI: %.3f, FMI: %.3f, Comp: %.3f, Homo: %.3f' % scores)
    return scores
def mean_average_precision(x: np.ndarray, y: np.ndarray, k: int = 30, **kwargs) -> float:
    r"""
    Mean average precision of label agreement among nearest neighbours.

    Each row of ``x`` is queried against ``x`` itself for its ``k + 1``
    nearest neighbours (capped at the number of rows). The first returned
    neighbour is dropped: it is typically the query point itself, because
    ``x`` is both the fitted data and the query. Over the remaining
    neighbours, a 0/1 hit vector (same label as the query or not) is formed.
    Its average precision is the mean of the running precision taken at the
    hit positions, or 0 when there are no hits. The result is the mean over
    all rows.

    Parameters
    ----------
    x
        Coordinates (one row per observation)
    y
        Cell_type/Layer labels, aligned with the rows of ``x``
    k
        k neighbors
    **kwargs
        Additional keyword arguments are passed to
        :class:`sklearn.neighbors.NearestNeighbors`

    Returns
    -------
    map
        Mean average precision
    """
    labels = np.array(y)
    n_neighbors = min(labels.shape[0], k + 1)
    nn_model = sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors, **kwargs)
    nn_model.fit(x)
    neighbor_idx = nn_model.kneighbors(x, return_distance=False)
    # hits[i, j] is True when the (j+1)-th neighbour of i shares i's label.
    hits = np.equal(labels[neighbor_idx[:, 1:]], labels[:, np.newaxis])

    def _row_average_precision(row_hits: np.ndarray) -> float:
        if not np.any(row_hits):
            return 0.0
        ranks = np.arange(1, row_hits.size + 1)
        running_precision = np.cumsum(row_hits) / ranks
        return running_precision[row_hits].mean().item()

    per_row = np.array([_row_average_precision(row) for row in hits])
    return per_row.mean().item()
def rep_metrics(adata, origin_concat, use_rep, label_key, batch_key, k_map=30):
    """Score an integrated embedding for label conservation and batch mixing.

    Parameters
    ----------
    adata
        AnnData holding the embedding in ``adata.obsm[use_rep]`` and the label
        and batch columns in ``adata.obs``.
    origin_concat
        AnnData of the data before integration; passed to
        ``scib.me.pcr_comparison`` as the reference for the batch PCR score.
    use_rep
        Key of the embedding in ``adata.obsm``.
    label_key
        Column of ``adata.obs`` with the biological labels (e.g. layers).
    batch_key
        Column of ``adata.obs`` with the batch / slice identifiers.
    k_map
        Number of neighbours used for the mean average precision.

    Returns
    -------
    tuple
        ``(mAP, label ASW, batch ASW, batch PCR, kBET, graph connectivity)``.
        The values are also printed.

    Raises
    ------
    KeyError
        If ``label_key`` or ``batch_key`` is missing from ``adata.obs``, or
        ``use_rep`` is missing from ``adata.obsm``.

    Notes
    -----
    Mutates its inputs:
    - the label and batch columns of ``adata.obs`` are cast to string
      categoricals;
    - ``origin_concat.X`` is cast to float;
    - ``sc.pp.neighbors`` stores a neighbour graph in ``adata``.
    """
    missing = [f"adata.obs[{key!r}]" for key in (label_key, batch_key) if key not in adata.obs]
    if use_rep not in adata.obsm:
        missing.append(f"adata.obsm[{use_rep!r}]")
    if missing:
        # Fail loudly: returning None would only surface later as an opaque
        # error when the caller unpacks the six-value result.
        raise KeyError(f"rep_metrics: missing {', '.join(missing)}")

    adata.obs[label_key] = adata.obs[label_key].astype(str).astype("category")
    adata.obs[batch_key] = adata.obs[batch_key].astype(str).astype("category")
    origin_concat.X = origin_concat.X.astype(float)

    # kNN graph on the embedding; presumably required by the graph-based scib
    # metric (graph_connectivity) below.
    sc.pp.neighbors(adata, use_rep=use_rep)

    MAP = mean_average_precision(adata.obsm[use_rep].copy(), adata.obs[label_key], k=k_map)
    cell_type_ASW = scib.me.silhouette(adata, label_key=label_key, embed=use_rep)
    # g_iLISI = scib.me.ilisi_graph(adata, batch_key=batch_key, type_="embed", use_rep=use_rep)
    batch_ASW = scib.me.silhouette_batch(adata, batch_key=batch_key, label_key=label_key,
                                         embed=use_rep, verbose=False)
    batch_PCR = scib.me.pcr_comparison(origin_concat, adata, covariate=batch_key, embed=use_rep)
    kBET = scib.me.kBET(adata, batch_key=batch_key, label_key=label_key, type_='embed', embed=use_rep)
    g_conn = scib.me.graph_connectivity(adata, label_key=label_key)

    print('mAP: %.3f, Cell type ASW: %.3f, Batch ASW: %.3f, Batch PCR: %.3f, kBET: %.3f, Graph connectivity: %.3f' %
          (MAP, cell_type_ASW, batch_ASW, batch_PCR, kBET, g_conn))
    return MAP, cell_type_ASW, batch_ASW, batch_PCR, kBET, g_conn
Run model
To preprocess the SRT data, we use `INSTINCT.preprocess_SRT()`.
[ ]:
# DLPFC: for each group of four slices, integrate with INSTINCT, cluster the
# joint embedding with a Gaussian mixture, and report clustering and
# integration metrics.
data_dir = '../../data/STdata/10xVisium/DLPFC_Maynard2021/'
sample_group_list = [['151507', '151508', '151509', '151510'],
                     ['151669', '151670', '151671', '151672'],
                     ['151673', '151674', '151675', '151676']]
# Number of GMM components for each group, aligned with sample_group_list.
n_cluster_list = [7, 5, 7]

save_dir = '../../results/DLPFC_Maynard2021/'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(save_dir, exist_ok=True)

# The per-group metric names below are overwritten on every iteration, so keep
# a copy of each group's results here.
metrics_per_group = []

for idx, slice_name_list in enumerate(sample_group_list):
    # load data
    slice_index_list = list(range(len(slice_name_list)))
    rna_list = []
    for sample in slice_name_list:
        adata = sc.read_visium(path=data_dir + f'{sample}/', count_file=sample + '_filtered_feature_bc_matrix.h5')
        adata.var_names_make_unique()

        # read the annotation; every spot in adata must have an annotation
        # row, otherwise the .loc lookups below cannot align
        Ann_df = pd.read_csv(data_dir + f'{sample}/meta_data.csv', sep=',', index_col=0)
        if not adata.obs_names.isin(Ann_df.index).all():
            raise ValueError("Some spots in adata.obs_names are not present in the annotation file")
        adata.obs['image_row'] = Ann_df.loc[adata.obs_names, 'imagerow']
        adata.obs['image_col'] = Ann_df.loc[adata.obs_names, 'imagecol']
        adata.obs['Manual_Annotation'] = Ann_df.loc[adata.obs_names, 'ManualAnnotation']

        # suffix barcodes with the sample id so they stay unique after concatenation
        adata.obs_names = [x + '_' + sample for x in adata.obs_names]
        rna_list.append(adata)

    # concatenation
    adata_concat = ad.concat(rna_list, label="slice_name", keys=slice_name_list)

    # preprocess SRT data
    print('Start preprocessing')
    rna_list, adata_concat = INSTINCT.preprocess_SRT(rna_list, adata_concat, n_top_genes=5000)
    print(adata_concat.shape)
    print('Done!')

    # Pre-integration reference for the batch PCR metric (keys are slice indices,
    # matching `result` below).
    origin_concat = ad.concat(rna_list, label="slice_name", keys=slice_index_list)

    print('Applying PCA to reduce the feature dimension to 100 ...')
    pca = PCA(n_components=100, random_state=1234)
    expr_matrix = adata_concat.X
    # .toarray() exists only on sparse matrices; accept an already dense X too.
    expr_dense = expr_matrix.toarray() if hasattr(expr_matrix, 'toarray') else np.asarray(expr_matrix)
    input_matrix = pca.fit_transform(expr_dense)
    np.save(save_dir + f'input_matrix_group{idx}.npy', input_matrix)
    print('Done !')

    # NOTE(review): the PCA result is reloaded right after saving; presumably so
    # a later run can start from the cached file — confirm intent.
    input_matrix = np.load(save_dir + f'input_matrix_group{idx}.npy')
    adata_concat.obsm['X_pca'] = input_matrix

    # calculate the spatial graph
    INSTINCT.create_neighbor_graph(rna_list, adata_concat)

    # cumulative spot counts: slice i occupies rows spots_count[i]:spots_count[i+1]
    spots_count = [0]
    n = 0
    for sample in rna_list:
        num = sample.shape[0]
        n += num
        spots_count.append(n)

    INSTINCT_model = INSTINCT.INSTINCT_Model(rna_list, adata_concat, device=device)
    INSTINCT_model.train(report_loss=True, report_interval=100)
    INSTINCT_model.eval(rna_list)

    result = ad.concat(rna_list, label="slice_name", keys=slice_index_list)

    with open(save_dir + f'INSTINCT_embed_group{idx}.csv', 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerows(result.obsm['INSTINCT_latent'])

    with open(save_dir + f'INSTINCT_noise_embed_group{idx}.csv', 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerows(result.obsm['INSTINCT_latent_noise'])

    gm = GaussianMixture(n_components=n_cluster_list[idx], covariance_type='tied', random_state=1234)
    y = gm.fit_predict(result.obsm['INSTINCT_latent'], y=None)
    result.obs["gm_clusters"] = pd.Series(y, index=result.obs.index, dtype='category')
    result.obs['matched_clusters'] = pd.Series(match_cluster_labels(result.obs['Manual_Annotation'],
                                                                    result.obs["gm_clusters"]),
                                               index=result.obs.index, dtype='category')

    ari, ami, nmi, fmi, comp, homo = cluster_metrics(result.obs['Manual_Annotation'],
                                                     result.obs['matched_clusters'].tolist())
    # `map_score` rather than `map`, so the builtin map() is not shadowed.
    map_score, c_asw, b_asw, b_pcr, kbet, g_conn = rep_metrics(result, origin_concat, use_rep='INSTINCT_latent',
                                                               label_key='Manual_Annotation', batch_key='slice_name')

    metrics_per_group.append({
        'group': idx, 'ARI': ari, 'AMI': ami, 'NMI': nmi, 'FMI': fmi, 'Comp': comp, 'Homo': homo,
        'mAP': map_score, 'Cell type ASW': c_asw, 'Batch ASW': b_asw, 'Batch PCR': b_pcr,
        'kBET': kbet, 'Graph connectivity': g_conn,
    })