Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
637 changes: 417 additions & 220 deletions docs/notebooks/10x_xenium_focus.ipynb

Large diffs are not rendered by default.

1,051 changes: 824 additions & 227 deletions docs/notebooks/10x_xenium_focus_simplified.ipynb

Large diffs are not rendered by default.

574 changes: 144 additions & 430 deletions docs/notebooks/clustering_stability.ipynb

Large diffs are not rendered by default.

166 changes: 84 additions & 82 deletions docs/notebooks/plotting.ipynb

Large diffs are not rendered by default.

503 changes: 326 additions & 177 deletions docs/notebooks/supervised.ipynb

Large diffs are not rendered by default.

21 changes: 9 additions & 12 deletions src/segtraq/SegTraQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from anndata import AnnData

from . import bl, cs, pl, ps, rs, sp, vl
from .constants import SEGTRAQ_CELL_ID_KEY
from .utils import _filter_control_and_low_quality_transcripts, _get_genes, validate_spatialdata
from .utils import filter_cells as _filter_cells
from .utils import markers_from_reference as _markers_from_reference
Expand Down Expand Up @@ -127,7 +128,7 @@ def __init__(
Cell ID key for `sdata.shapes[shapes_key]`. Must match either the shapes index name
or a column name (which will be set as the index if needed).
If `None`, the index is assumed to contain cell IDs and
renamed to "segtraq_cell_id".
renamed to "segtraq_id".

nucleus_shapes_key : str or None, optional, default="nucleus_boundaries"
Key in `sdata.shapes` for nucleus boundary polygons, if available.
Expand Down Expand Up @@ -182,7 +183,7 @@ def __init__(
# optionally filter out low-quality and control transcripts that would otherwise skew metrics
if filter_low_quality_transcripts:
resolved_kwargs = {**DEFAULT_FILTER_KWARGS, **(filter_kwargs or {})}
# by default, this modifies the input SpatialData object in-place and
# by default, this modifies the input SpatialData object in-place and
# issues a warning about this
sdata_new = _filter_control_and_low_quality_transcripts(
sdata,
Expand Down Expand Up @@ -226,8 +227,8 @@ def __init__(
self.points_gene_key = points_gene_key

self.shapes_key = shapes_key
self.shapes_cell_id_key = shapes_cell_id_key
self.nucleus_shapes_key = nucleus_shapes_key
self.shapes_cell_id_key = shapes_cell_id_key if shapes_cell_id_key is not None else SEGTRAQ_CELL_ID_KEY
self.nucleus_shapes_key = nucleus_shapes_key if nucleus_shapes_key is not None else SEGTRAQ_CELL_ID_KEY

self.bl = _BLFacade(self)
self.rs = _RSFacade(self)
Expand Down Expand Up @@ -1803,7 +1804,6 @@ def purity(
frac_cells_subset: float = 0.63,
key_prefix: str = "leiden_subset",
use_hvg: bool = False,
representation: str | None = None,
inplace: bool = True,
leiden_kwargs: dict | None = None,
) -> float:
Expand All @@ -1814,7 +1814,6 @@ def purity(
tables_key=self._p.tables_key,
key_prefix=key_prefix,
use_hvg=use_hvg,
representation=representation,
inplace=inplace,
leiden_kwargs=leiden_kwargs,
)
Expand All @@ -1827,7 +1826,6 @@ def adjusted_rand_index(
frac_cells_subset: float = 0.63,
key_prefix: str = "leiden_subset",
use_hvg: bool = False,
representation: str | None = None,
inplace: bool = True,
leiden_kwargs: dict | None = None,
) -> float:
Expand All @@ -1838,7 +1836,6 @@ def adjusted_rand_index(
key_prefix=key_prefix,
tables_key=self._p.tables_key,
use_hvg=use_hvg,
representation=representation,
inplace=inplace,
leiden_kwargs=leiden_kwargs,
)
Expand Down Expand Up @@ -1982,22 +1979,22 @@ class _PLFacade:
def __init__(self, parent: "SegTraQ") -> None:
self._p = parent

def transcript_distribution_across_space(self, smoothing: int = 10):
def transcript_distribution_across_space(self, filter_size: int = 21):
return pl.transcript_distribution_across_space(
sdata=self._p.sdata,
axes=(self._p.points_x_key, self._p.points_y_key),
smoothing=smoothing,
filter_size=filter_size,
points_key=self._p.points_key,
)

transcript_distribution_across_space.__doc__ = pl.transcript_distribution_across_space.__doc__

def feature_distribution_across_space(self, features: list[str], smoothing: int = 10):
def feature_distribution_across_space(self, features: list[str], filter_size: int = 21):
return pl.feature_distribution_across_space(
sdata=self._p.sdata,
features=features,
axes=(self._p.tables_centroid_x_key, self._p.tables_centroid_y_key),
smoothing=smoothing,
filter_size=filter_size,
tables_key=self._p.tables_key,
)

Expand Down
18 changes: 18 additions & 0 deletions src/segtraq/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# This file defines conventions for where SegTraQ stores its results in the AnnData object.
# This is to avoid conflicts with other tools and to make it clear which data is generated by SegTraQ.

# ── Layers ────────────────────────────────────────────────────────────────────
NORM_LOG_LAYER = "norm_log_segtraq"

# ── Obsm ──────────────────────────────────────────────────────────────────────
PCA_KEY = "X_pca_segtraq"

# ── Uns ───────────────────────────────────────────────────────────────────────
NEIGHBORS_KEY = "neighbors_segtraq"

# ── Obsp ──────────────────────────────────────────────────────────────────────
CONNECTIVITIES_KEY = "connectivities_segtraq"
DISTANCES_KEY = "distances_segtraq"

# ── Cell ID ───────────────────────────────────────────────────────────────────
SEGTRAQ_CELL_ID_KEY = "segtraq_id"
80 changes: 40 additions & 40 deletions src/segtraq/cs/clustering_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import spatialdata as sd
from sklearn.metrics import silhouette_score as _silhouette_score

from ..constants import CONNECTIVITIES_KEY, NEIGHBORS_KEY, PCA_KEY
from ..utils import _get_pca_and_neighbors
from .utils import (
_cluster_connectedness,
ari_mean,
Expand All @@ -22,6 +24,9 @@ def cluster_connectedness(
random_state: int = 42,
cell_type_key: str | None = None,
use_hvg: bool = False,
n_neighbors: int = 15,
n_pcs: int = 50,
target_sum: float | None = None,
inplace: bool = True,
leiden_kwargs: dict | None = None,
) -> float:
Expand Down Expand Up @@ -49,6 +54,13 @@ def cluster_connectedness(
If provided, compute cluster connectedness for this clustering only.
use_hvg: bool, optional
Whether to use highly variable genes (HVGs) for PCA. By default False.
n_neighbors: int, optional
Number of neighbors to use for computing the connectivity matrix. Default is 15.
n_pcs: int, optional
Number of principal components to compute for PCA. Default is 50.
target_sum: float | None, optional
Target sum for normalization in `scanpy.pp.normalize_total()` before PCA.
Default is None.
inplace : bool, optional
Whether to store the computed cluster connectedness in sdata.uns, by default True.
leiden_kwargs : dict, optional
Expand All @@ -74,25 +86,21 @@ def cluster_connectedness(
labels = adata.obs[cell_type_key].values
valid_labels = labels[~pd.isna(labels)]
if len(pd.unique(valid_labels)) > 1:
if "connectivities" not in adata.obsp:
raise ValueError(
"Connectivities not found in adata.obsp['connectivities']. "
"Please compute neighbors first by running sc.pp.neighbors(adata)."
)
if CONNECTIVITIES_KEY not in adata.obsp:
adata = _get_pca_and_neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs, target_sum=target_sum)
sdata.tables[tables_key] = adata
distance_val = _cluster_connectedness(
adata.obsp["connectivities"],
adata.obsp[CONNECTIVITIES_KEY],
labels,
use_weights=use_weights,
)
return float(distance_val)
else:
raise ValueError(f"cell_type_key '{cell_type_key}' must contain more than one cluster")

if "neighbors" not in adata.uns:
raise ValueError(
f"Neighbors not found in adata. Please use scanpy to compute neighbors:\n"
f"adata=st_obj.sdata.tables['{tables_key}']; sc.pp.neighbors(adata)."
)
if NEIGHBORS_KEY not in adata.uns:
adata = _get_pca_and_neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs, target_sum=target_sum)
sdata.tables[tables_key] = adata

for res in resolution:
key_added, _, _ = run_leiden_clustering_on_random_subset(
Expand All @@ -112,7 +120,7 @@ def cluster_connectedness(

if len(pd.unique(valid_labels)) > 1:
# Slice connectivity matrix to valid cells only — both rows AND columns
connectivity_subset = adata.obsp["connectivities"][np.ix_(valid_mask, valid_mask)]
connectivity_subset = adata.obsp[CONNECTIVITIES_KEY][np.ix_(valid_mask, valid_mask)]

distance_val = _cluster_connectedness(
connectivity_subset,
Expand All @@ -137,6 +145,9 @@ def silhouette_score(
random_state: int = 42,
cell_type_key: str | None = None,
use_hvg: bool = False,
n_neighbors: int = 15,
n_pcs: int = 50,
target_sum: float | None = None,
inplace: bool = True,
leiden_kwargs: dict | None = None,
) -> float:
Expand All @@ -162,6 +173,13 @@ def silhouette_score(
If provided, compute the silhouette score for provided labels.
use_hvg: bool, optional
Whether to use highly variable genes (HVGs) for PCA. By default False.
n_neighbors: int, optional
Number of neighbors to use for computing the connectivity matrix. Default is 15.
n_pcs: int, optional
Number of principal components to compute for PCA. Default is 50.
target_sum: float | None, optional
Target sum for normalization in `scanpy.pp.normalize_total()` before PCA.
Default is None.
inplace : bool, optional
Whether to store the computed silhouette score in sdata.uns, by default True.
leiden_kwargs : dict, optional
Expand All @@ -188,12 +206,13 @@ def silhouette_score(

labels_nn = adata.obs[cell_type_key].dropna()
if labels_nn.nunique() > 1: # Ensure more than one cluster exists
if "X_pca" not in adata.obsm:
raise ValueError("PCA coordinates not found in adata.obsm['X_pca']. Please run PCA first.")
if PCA_KEY not in adata.obsm:
adata = _get_pca_and_neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs, target_sum=target_sum)
sdata.tables[tables_key] = adata
# remove NaN labels
adata_subset = adata[~pd.isna(adata.obs[cell_type_key]), :]
labels = adata_subset.obs[cell_type_key].values
silhouette_avg = _silhouette_score(adata_subset.obsm["X_pca"], labels, metric=metric)
silhouette_avg = _silhouette_score(adata_subset.obsm[PCA_KEY], labels, metric=metric)
best_silhouette_score = float(silhouette_avg)
key = "silhouette_score_labels"

Expand All @@ -207,11 +226,9 @@ def silhouette_score(
else:
# ensure that we already have neighbors computed
# this way we avoid recomputing neighbors multiple times (for the different resolutions)
if "neighbors" not in adata.uns:
raise ValueError(
f"Neighbors not found in adata. Please use scanpy to compute neighbors:\n"
f"adata=st_obj.sdata.tables['{tables_key}']; sc.pp.neighbors(adata)."
)
if NEIGHBORS_KEY not in adata.uns:
adata = _get_pca_and_neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs, target_sum=target_sum)
sdata.tables[tables_key] = adata

key = "silhouette_score"
for res in resolution:
Expand All @@ -230,12 +247,9 @@ def silhouette_score(

if len(pd.unique(labels)) > 1: # Ensure more than one cluster exists
if pca is None:
raise ValueError(
"PCA coordinates are required for silhouette score calculation, "
"but no PCA embedding was found. "
"Please compute PCA with `sc.pp.pca(adata)` and then recompute neighbors with "
"`sc.pp.neighbors(adata)`."
)
adata = _get_pca_and_neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs, target_sum=target_sum)
sdata.tables[tables_key] = adata
pca = adata.obsm[PCA_KEY]

silhouette_avg = _silhouette_score(pca, labels, metric=metric)

Expand All @@ -255,7 +269,6 @@ def purity(
tables_key: str = "table",
key_prefix: str = "leiden_subset",
use_hvg: bool = False,
representation: str | None = None,
inplace: bool = True,
leiden_kwargs: dict | None = None,
) -> float:
Expand All @@ -275,11 +288,6 @@ def purity(
The prefix for the keys under which the clustering results are stored, by default "leiden_subset".
use_hvg: bool, optional
Whether to use highly variable genes (HVGs) for PCA. By default False.
representation : str | None, optional
Key in `adata.obsm` specifying the feature representation used to compute
the k-nearest neighbor graph before clustering. This is passed to
`scanpy.pp.neighbors(..., use_rep=representation)`.
If `None`, a PCA ('X_pca') embedding is computed internally.
inplace : bool, optional
Whether to store the computed purity in sdata.uns, by default True.
leiden_kwargs : dict, optional
Expand All @@ -303,7 +311,6 @@ def purity(
key_prefix=key_prefix,
use_hvg=use_hvg,
random_state=random_state,
representation=representation,
leiden_kwargs=leiden_kwargs,
)
cluster_keys.append(key_added)
Expand All @@ -324,7 +331,6 @@ def adjusted_rand_index(
tables_key: str = "table",
key_prefix: str = "leiden_subset",
use_hvg: bool = False,
representation: str | None = None,
inplace: bool = True,
leiden_kwargs: dict | None = None,
) -> float:
Expand All @@ -345,11 +351,6 @@ def adjusted_rand_index(
The prefix for the keys under which the clustering results are stored, by default "leiden_subset".
use_hvg: bool, optional
Whether to use highly variable genes (HVGs) for PCA. By default False.
representation : str | None, optional
Key in `adata.obsm` specifying the feature representation used to compute
the k-nearest neighbor graph before clustering. This is passed to
`scanpy.pp.neighbors(..., use_rep=representation)`.
If `None`, a PCA ('X_pca') embedding is computed internally.
inplace : bool, optional
Whether to store the computed ARI in sdata.uns, by default True.
leiden_kwargs : dict, optional
Expand All @@ -374,7 +375,6 @@ def adjusted_rand_index(
key_prefix=key_prefix,
use_hvg=use_hvg,
random_state=random_state,
representation=representation,
leiden_kwargs=leiden_kwargs,
)
cluster_keys.append(key_added)
Expand Down
Loading
Loading