diff --git a/src/segtraq/SegTraQ.py b/src/segtraq/SegTraQ.py index a1275eed..a255487a 100644 --- a/src/segtraq/SegTraQ.py +++ b/src/segtraq/SegTraQ.py @@ -1645,6 +1645,7 @@ def similarity_top_bottom( normalization: str | None = "log", min_genes: int = 5, min_transcripts: int = 10, + n_pcs: int | None = 30, inplace: bool = True, ): return vl.similarity_top_bottom( @@ -1666,6 +1667,7 @@ def similarity_top_bottom( scale=scale, min_genes=min_genes, min_transcripts=min_transcripts, + n_pcs=n_pcs, inplace=inplace, ) diff --git a/src/segtraq/vl/volume.py b/src/segtraq/vl/volume.py index 4d5951fc..95da7d47 100644 --- a/src/segtraq/vl/volume.py +++ b/src/segtraq/vl/volume.py @@ -2,10 +2,11 @@ import numpy as np import pandas as pd import spatialdata as sd +from scipy.linalg import norm from shapely.ops import unary_union -from sklearn.metrics.pairwise import cosine_similarity +from sklearn.decomposition import PCA -from ..utils import _ensure_index, _is_background, estimate_theta_simple, merge_into_obs, pearson_residuals +from ..utils import _ensure_index, _is_background, merge_into_obs from .utils import _correct_z_drift @@ -127,6 +128,21 @@ def vertical_signal_integrity_per_cell( return out +def _cell_by_gene_from_transcripts(tx_df: pd.DataFrame, cell_key: str, gene_key: str): + return tx_df.groupby([cell_key, gene_key], observed=True).size().unstack(fill_value=0) + + +def _normalize(x: pd.DataFrame, normalization: str | None, scale: float = 1e4) -> pd.DataFrame: + match normalization: + case "log": + x_sum = x.sum(axis=1).replace(0, np.nan) + return np.log1p(x.div(x_sum, axis=0) * scale).fillna(0.0) + case "raw" | None: + return x + case _: + raise ValueError("Invalid `normalization`") + + def similarity_top_bottom( sdata, tables_key: str = "table", @@ -146,6 +162,7 @@ def similarity_top_bottom( scale: float = 1e4, min_genes: int = 5, min_transcripts: int = 10, + n_pcs: int | None = 30, inplace: bool = True, ): """ @@ -207,6 +224,8 @@ def similarity_top_bottom( Minimum number of genes with nonzero counts in (bottom OR top) required to score a cell. min_transcripts : int, default=10 Minimum number of transcripts required in EACH part (bottom and top) to score a cell. + n_pcs : int | None, default=30 + Number of components to use for PCA. If None, no PCA will be performed. inplace : bool, default=True Whether to add the results to `sdata.tables[tables_key].obs`. @@ -221,20 +240,19 @@ def similarity_top_bottom( # subset points and drop rows with missing cell identifiers, genes or coordinates pts = sdata.points[points_key] cols = [points_cell_id_key, points_gene_key, points_x_key, points_y_key, points_z_key] - - tx = pts[cols] - tx = tx.dropna(subset=cols) + tx = pts[cols].dropna() # remove background transcripts is_bg = _is_background(tx[points_cell_id_key], points_background_id) tx = tx[~is_bg] # ensure genes match table var_names from the anndata object - valid_features = pd.Index(sdata.tables[tables_key].var_names) + valid_features = sdata.tables[tables_key].var_names tx = tx[tx[points_gene_key].isin(valid_features)] # cast into pandas Dataframe if Dask Array tx = tx.compute() if hasattr(tx, "compute") else tx + assert isinstance(tx, pd.DataFrame) tx = tx.reset_index(drop=True) # Optionally correct z-drift before defining top/bottom subsets @@ -251,98 +269,56 @@ def similarity_top_bottom( tx["_z_for_split"] = tx[points_z_key].to_numpy(dtype=float) # compute per-cell quantile cutoffs - z = tx["_z_for_split"] - tx["_z_bottom"] = tx.groupby(points_cell_id_key, observed=True)["_z_for_split"].transform(lambda s: s.quantile(q)) - tx["_z_top"] = tx.groupby(points_cell_id_key, observed=True)["_z_for_split"].transform( - lambda s: s.quantile(1.0 - q) - ) + z_by_cell = tx.groupby(points_cell_id_key, observed=True)["_z_for_split"] + tx["_z_bottom"] = z_by_cell.transform(lambda s: s.quantile(q)) + tx["_z_top"] = z_by_cell.transform(lambda s: s.quantile(1.0 - q)) - tx["_is_bottom"] = z <= tx["_z_bottom"] - tx["_is_top"] = z >= tx["_z_top"] + tx["_is_bottom"] = tx["_z_for_split"] <= tx["_z_bottom"] + tx["_is_top"] = tx["_z_for_split"] >= tx["_z_top"] # counts per part - counts_bottom = ( - tx[tx["_is_bottom"]].groupby([points_cell_id_key, points_gene_key], observed=True).size().unstack(fill_value=0) - ) - counts_top = ( - tx[tx["_is_top"]].groupby([points_cell_id_key, points_gene_key], observed=True).size().unstack(fill_value=0) - ) + counts_bottom = _cell_by_gene_from_transcripts(tx[tx["_is_bottom"]], points_cell_id_key, points_gene_key) + counts_top = _cell_by_gene_from_transcripts(tx[tx["_is_top"]], points_cell_id_key, points_gene_key) # align top and bottom cells/genes common_cells = counts_bottom.index.intersection(counts_top.index) - common_genes = counts_bottom.columns.intersection(counts_top.columns) + all_genes = tx[points_gene_key].unique() # counts of the common cells per bottom/top - counts_bottom_raw = counts_bottom.loc[common_cells, common_genes] - counts_top_raw = counts_top.loc[common_cells, common_genes] + counts_bottom_raw = counts_bottom.loc[common_cells].reindex(columns=all_genes, fill_value=0) + counts_top_raw = counts_top.loc[common_cells].reindex(columns=all_genes, fill_value=0) # total number of transcripts per bottom/top n_tx_bottom = counts_bottom_raw.sum(axis=1) n_tx_top = counts_top_raw.sum(axis=1) - if normalization == "pearson": - # aggregate the counts to total counts - X = np.vstack([counts_bottom_raw.to_numpy(), counts_top_raw.to_numpy()]) - # estimate the overdispersion parameter from the counts per region according to the - # variance of a negative binomial; var = mu + mu^2 / theta - solve for theta - theta = estimate_theta_simple(X) - # normalize the total counts data with analytical pearson residuals - R = pearson_residuals(X, theta=theta, clip=None) - # take them apart again - n = counts_bottom_raw.shape[0] - bottom_norm = R[:n, :] - top_norm = R[n:, :] - elif normalization == "log": - # within-cell normalization using (bottom + top) - total_counts = (counts_bottom_raw + counts_top_raw).sum(axis=1).replace(0, np.nan) - bottom_norm = counts_bottom_raw.div(total_counts, axis=0) * scale - top_norm = counts_top_raw.div(total_counts, axis=0) * scale - bottom_norm = np.log1p(bottom_norm).fillna(0.0) - top_norm = np.log1p(top_norm).fillna(0.0) - elif normalization == "raw": - bottom_norm = counts_bottom_raw - top_norm = counts_top_raw - else: - bottom_norm = counts_bottom_raw - top_norm = counts_top_raw - - # cast into dataframe - bottom_norm = pd.DataFrame(bottom_norm, columns=counts_bottom_raw.columns, index=counts_bottom_raw.index) - top_norm = pd.DataFrame(top_norm, columns=counts_top_raw.columns, index=counts_top_raw.index) - - rows = [] - for cid in common_cells: - # get the raw counts for the common cell ID - x_raw = counts_bottom_raw.loc[cid].to_numpy(dtype=float) - y_raw = counts_top_raw.loc[cid].to_numpy(dtype=float) - - # genes nonzero in at least one part - mask = (x_raw != 0) | (y_raw != 0) - n_genes_kept = int(mask.sum()) - - # thresholds - if n_tx_bottom.loc[cid] < min_transcripts or n_tx_top.loc[cid] < min_transcripts or n_genes_kept < min_genes: - sim = np.nan - else: - # extract only normalized expression per matching cells for - # non zero count genes - x = bottom_norm.loc[cid].to_numpy(dtype=float)[mask] - y = top_norm.loc[cid].to_numpy(dtype=float)[mask] - - if np.all(x == 0) or np.all(y == 0): - sim = np.nan - else: - sim = cosine_similarity(x.reshape(1, -1), y.reshape(1, -1))[0, 0] - - rows.append((cid, sim)) - - out = pd.DataFrame( - rows, - columns=[ - tables_cell_id_key, - "cosine_sim_top_bottom_z", - ], - ) + # filter cells by number of genes and min counts in top/bottom + n_genes = ((counts_bottom_raw != 0) | (counts_top_raw != 0)).sum(axis=1) + drop_cell = (n_tx_bottom < min_transcripts) | (n_tx_top < min_transcripts) | (n_genes < min_genes) + + counts_bottom_raw = counts_bottom_raw.loc[~drop_cell] + counts_top_raw = counts_top_raw.loc[~drop_cell] + common_cells = counts_bottom_raw.index + + top_norm = _normalize(counts_top_raw, normalization, scale) + bottom_norm = _normalize(counts_bottom_raw, normalization, scale) + + # transform normalized counts into PCA space that is fit based on whole cells + if n_pcs is not None: + # TODO: probably tables["table"].X can be reused or similar + counts_cell = _cell_by_gene_from_transcripts(tx, points_cell_id_key, points_gene_key) + counts_cell = counts_cell.reindex(columns=all_genes, fill_value=0) + cell_norm = _normalize(counts_cell, normalization, scale) + pca = PCA(n_components=n_pcs, random_state=seed).fit(cell_norm) + pca.set_output(transform="pandas") + + bottom_norm = pca.transform(bottom_norm) + top_norm = pca.transform(top_norm) + assert isinstance(bottom_norm, pd.DataFrame) + assert isinstance(top_norm, pd.DataFrame) + + cosine_similarity = np.sum(top_norm * bottom_norm, axis=1) / (norm(top_norm, axis=1) * norm(bottom_norm, axis=1)) + out = cosine_similarity.to_frame("cosine_sim_top_bottom_z").rename_axis(tables_cell_id_key) if inplace: merge_into_obs(