-
Notifications
You must be signed in to change notification settings - Fork 1
PCA transform before cosine calculation #146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,10 +2,11 @@ | |
| import numpy as np | ||
| import pandas as pd | ||
| import spatialdata as sd | ||
| from scipy.linalg import norm | ||
| from shapely.ops import unary_union | ||
| from sklearn.metrics.pairwise import cosine_similarity | ||
| from sklearn.decomposition import PCA | ||
|
|
||
| from ..utils import _ensure_index, _is_background, estimate_theta_simple, merge_into_obs, pearson_residuals | ||
| from ..utils import _ensure_index, _is_background, merge_into_obs | ||
| from .utils import _correct_z_drift | ||
|
|
||
|
|
||
|
|
@@ -127,6 +128,21 @@ def vertical_signal_integrity_per_cell( | |
| return out | ||
|
|
||
|
|
||
def _cell_by_gene_from_transcripts(tx_df: pd.DataFrame, cell_key: str, gene_key: str):
    """
    Pivot a transcript table into a count matrix.

    Every row of `tx_df` is counted once. The result is indexed by the observed
    values of `cell_key`, has one column per observed value of `gene_key`, and
    holds 0 for (cell, gene) pairs that never occur together.
    """
    # number of rows for each (cell, gene) pair that actually occurs
    pair_counts = tx_df.groupby([cell_key, gene_key], observed=True).size()
    # move the gene level into the columns; absent pairs are filled with 0
    return pair_counts.unstack(fill_value=0)
|
|
||
|
|
||
| def _normalize(x: pd.DataFrame, normalization: str | None, scale: float = 1e4) -> pd.DataFrame: | ||
| match normalization: | ||
| case "log": | ||
| x_sum = x.sum(axis=1).replace(0, np.nan) | ||
| return np.log1p(x.div(x_sum, axis=0) * scale).fillna(0.0) | ||
| case "raw" | None: | ||
| return x | ||
| case _: | ||
| raise ValueError("Invalid `normalization`") | ||
|
|
||
|
|
||
| def similarity_top_bottom( | ||
| sdata, | ||
| tables_key: str = "table", | ||
|
|
@@ -146,6 +162,7 @@ def similarity_top_bottom( | |
| scale: float = 1e4, | ||
| min_genes: int = 5, | ||
| min_transcripts: int = 10, | ||
| n_pcs: int | None = 30, | ||
| inplace: bool = True, | ||
| ): | ||
| """ | ||
|
|
@@ -207,6 +224,8 @@ def similarity_top_bottom( | |
| Minimum number of genes with nonzero counts in (bottom OR top) required to score a cell. | ||
| min_transcripts : int, default=10 | ||
| Minimum number of transcripts required in EACH part (bottom and top) to score a cell. | ||
| n_pcs : int | None, default=30 | ||
| Number of components to use for PCA. If None, no PCA will be performed. | ||
| inplace : bool, default=True | ||
| Whether to add the results to `sdata.tables[tables_key].obs`. | ||
|
|
||
|
|
@@ -221,20 +240,19 @@ def similarity_top_bottom( | |
| # subset points and drop rows with missing cell identifiers, genes or coordinates | ||
| pts = sdata.points[points_key] | ||
| cols = [points_cell_id_key, points_gene_key, points_x_key, points_y_key, points_z_key] | ||
|
|
||
| tx = pts[cols] | ||
| tx = tx.dropna(subset=cols) | ||
| tx = pts[cols].dropna() | ||
|
|
||
| # remove background transcripts | ||
| is_bg = _is_background(tx[points_cell_id_key], points_background_id) | ||
| tx = tx[~is_bg] | ||
|
|
||
| # ensure genes match table var_names from the anndata object | ||
| valid_features = pd.Index(sdata.tables[tables_key].var_names) | ||
| valid_features = sdata.tables[tables_key].var_names | ||
| tx = tx[tx[points_gene_key].isin(valid_features)] | ||
|
|
||
| # cast into pandas Dataframe if Dask Array | ||
| tx = tx.compute() if hasattr(tx, "compute") else tx | ||
| assert isinstance(tx, pd.DataFrame) | ||
| tx = tx.reset_index(drop=True) | ||
|
|
||
| # Optionally correct z-drift before defining top/bottom subsets | ||
|
|
@@ -251,98 +269,56 @@ def similarity_top_bottom( | |
| tx["_z_for_split"] = tx[points_z_key].to_numpy(dtype=float) | ||
|
|
||
| # compute per-cell quantile cutoffs | ||
| z = tx["_z_for_split"] | ||
| tx["_z_bottom"] = tx.groupby(points_cell_id_key, observed=True)["_z_for_split"].transform(lambda s: s.quantile(q)) | ||
| tx["_z_top"] = tx.groupby(points_cell_id_key, observed=True)["_z_for_split"].transform( | ||
| lambda s: s.quantile(1.0 - q) | ||
| ) | ||
| z_by_cell = tx.groupby(points_cell_id_key, observed=True)["_z_for_split"] | ||
| tx["_z_bottom"] = z_by_cell.transform(lambda s: s.quantile(q)) | ||
| tx["_z_top"] = z_by_cell.transform(lambda s: s.quantile(1.0 - q)) | ||
|
|
||
| tx["_is_bottom"] = z <= tx["_z_bottom"] | ||
| tx["_is_top"] = z >= tx["_z_top"] | ||
| tx["_is_bottom"] = tx["_z_for_split"] <= tx["_z_bottom"] | ||
| tx["_is_top"] = tx["_z_for_split"] >= tx["_z_top"] | ||
|
|
||
| # counts per part | ||
| counts_bottom = ( | ||
| tx[tx["_is_bottom"]].groupby([points_cell_id_key, points_gene_key], observed=True).size().unstack(fill_value=0) | ||
| ) | ||
| counts_top = ( | ||
| tx[tx["_is_top"]].groupby([points_cell_id_key, points_gene_key], observed=True).size().unstack(fill_value=0) | ||
| ) | ||
| counts_bottom = _cell_by_gene_from_transcripts(tx[tx["_is_bottom"]], points_cell_id_key, points_gene_key) | ||
| counts_top = _cell_by_gene_from_transcripts(tx[tx["_is_top"]], points_cell_id_key, points_gene_key) | ||
|
|
||
| # align top and bottom cells/genes | ||
| common_cells = counts_bottom.index.intersection(counts_top.index) | ||
| common_genes = counts_bottom.columns.intersection(counts_top.columns) | ||
| all_genes = tx[points_gene_key].unique() | ||
|
|
||
| # counts of the common cells per bottom/top | ||
| counts_bottom_raw = counts_bottom.loc[common_cells, common_genes] | ||
| counts_top_raw = counts_top.loc[common_cells, common_genes] | ||
| counts_bottom_raw = counts_bottom.loc[common_cells].reindex(columns=all_genes, fill_value=0) | ||
| counts_top_raw = counts_top.loc[common_cells].reindex(columns=all_genes, fill_value=0) | ||
|
|
||
| # total number of transcripts per bottom/top | ||
| n_tx_bottom = counts_bottom_raw.sum(axis=1) | ||
| n_tx_top = counts_top_raw.sum(axis=1) | ||
|
|
||
| if normalization == "pearson": | ||
| # aggregate the counts to total counts | ||
| X = np.vstack([counts_bottom_raw.to_numpy(), counts_top_raw.to_numpy()]) | ||
| # estimate the overdispersion parameter from the counts per region according to the | ||
| # variance of a negative binomial; var = mu + mu^2 / theta - solve for theta | ||
| theta = estimate_theta_simple(X) | ||
| # normalize the total counts data with analytical pearson residuals | ||
| R = pearson_residuals(X, theta=theta, clip=None) | ||
| # take them apart again | ||
| n = counts_bottom_raw.shape[0] | ||
| bottom_norm = R[:n, :] | ||
| top_norm = R[n:, :] | ||
| elif normalization == "log": | ||
| # within-cell normalization using (bottom + top) | ||
| total_counts = (counts_bottom_raw + counts_top_raw).sum(axis=1).replace(0, np.nan) | ||
| bottom_norm = counts_bottom_raw.div(total_counts, axis=0) * scale | ||
| top_norm = counts_top_raw.div(total_counts, axis=0) * scale | ||
| bottom_norm = np.log1p(bottom_norm).fillna(0.0) | ||
| top_norm = np.log1p(top_norm).fillna(0.0) | ||
| elif normalization == "raw": | ||
| bottom_norm = counts_bottom_raw | ||
| top_norm = counts_top_raw | ||
| else: | ||
| bottom_norm = counts_bottom_raw | ||
| top_norm = counts_top_raw | ||
|
|
||
| # cast into dataframe | ||
| bottom_norm = pd.DataFrame(bottom_norm, columns=counts_bottom_raw.columns, index=counts_bottom_raw.index) | ||
| top_norm = pd.DataFrame(top_norm, columns=counts_top_raw.columns, index=counts_top_raw.index) | ||
|
|
||
| rows = [] | ||
| for cid in common_cells: | ||
| # get the raw counts for the common cell ID | ||
| x_raw = counts_bottom_raw.loc[cid].to_numpy(dtype=float) | ||
| y_raw = counts_top_raw.loc[cid].to_numpy(dtype=float) | ||
|
|
||
| # genes nonzero in at least one part | ||
| mask = (x_raw != 0) | (y_raw != 0) | ||
| n_genes_kept = int(mask.sum()) | ||
|
|
||
| # thresholds | ||
| if n_tx_bottom.loc[cid] < min_transcripts or n_tx_top.loc[cid] < min_transcripts or n_genes_kept < min_genes: | ||
| sim = np.nan | ||
| else: | ||
| # extract only normalized expression per matching cells for | ||
| # non zero count genes | ||
| x = bottom_norm.loc[cid].to_numpy(dtype=float)[mask] | ||
| y = top_norm.loc[cid].to_numpy(dtype=float)[mask] | ||
|
|
||
| if np.all(x == 0) or np.all(y == 0): | ||
| sim = np.nan | ||
| else: | ||
| sim = cosine_similarity(x.reshape(1, -1), y.reshape(1, -1))[0, 0] | ||
|
|
||
| rows.append((cid, sim)) | ||
|
|
||
| out = pd.DataFrame( | ||
| rows, | ||
| columns=[ | ||
| tables_cell_id_key, | ||
| "cosine_sim_top_bottom_z", | ||
| ], | ||
| ) | ||
| # filter cells by number of genes and min counts in top/bottom | ||
| n_genes = ((counts_bottom_raw != 0) | (counts_top_raw != 0)).sum(axis=1) | ||
| drop_cell = (n_tx_bottom < min_transcripts) | (n_tx_top < min_transcripts) | (n_genes < min_genes) | ||
|
|
||
| counts_bottom_raw = counts_bottom_raw.loc[~drop_cell] | ||
| counts_top_raw = counts_top_raw.loc[~drop_cell] | ||
| common_cells = counts_bottom_raw.index | ||
|
|
||
| top_norm = _normalize(counts_top_raw, normalization, scale) | ||
| bottom_norm = _normalize(counts_bottom_raw, normalization, scale) | ||
|
|
||
| # transform normalized counts into PCA space that is fit based on whole cells | ||
| if n_pcs is not None: | ||
| # TODO: probably tables["table"].X can be reused or similar | ||
| counts_cell = _cell_by_gene_from_transcripts(tx, points_cell_id_key, points_gene_key) | ||
| counts_cell = counts_cell.reindex(columns=all_genes, fill_value=0) | ||
| cell_norm = _normalize(counts_cell, normalization, scale) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we normalise the entire dataset but then compute PCA on the two subsets individually?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Actually, since you fit the PCA on the full dataset, it probably also makes sense to normalise and scale on the entire dataset. My only slight worry is that the two subsets might actually be very different in, say, sparsity, and then we get a kind of "averaged out" result.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I agree. For now I just kept it as it was before but probably should be changed. |
||
| pca = PCA(n_components=n_pcs, random_state=seed).fit(cell_norm) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also, should we scale before the PCA? Something like StandardScaler could actually replace the log norm/scale from above?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could be an idea, but then you would also need to scale the top/bottom using the same transformation. Also consider that at some point it might make sense to switch to sparse arrays, and then I always find StandardScaler a bit awkward because you can't correct the mean. |
||
| pca.set_output(transform="pandas") | ||
|
|
||
| bottom_norm = pca.transform(bottom_norm) | ||
| top_norm = pca.transform(top_norm) | ||
| assert isinstance(bottom_norm, pd.DataFrame) | ||
| assert isinstance(top_norm, pd.DataFrame) | ||
|
|
||
| cosine_similarity = np.sum(top_norm * bottom_norm, axis=1) / (norm(top_norm, axis=1) * norm(bottom_norm, axis=1)) | ||
| out = cosine_similarity.to_frame("cosine_sim_top_bottom_z").rename_axis(tables_cell_id_key) | ||
|
|
||
| if inplace: | ||
| merge_into_obs( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Potentially, it would be nice to go via percent variance explained as a threshold rather than the number of PCs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, that makes it more complicated; also, you then technically don't know how many PCs you need to calculate.