diff --git a/cnvlib/commands.py b/cnvlib/commands.py index 440c0630..4ba2b75c 100644 --- a/cnvlib/commands.py +++ b/cnvlib/commands.py @@ -2798,6 +2798,7 @@ def _cmd_import_rna(args: argparse.Namespace) -> None: args.do_gc, args.do_txlen, args.max_log2, + min_sample_fraction=args.min_sample_fraction, ) logging.info("Writing output files") if args.output: @@ -2850,6 +2851,16 @@ def _cmd_import_rna(args: argparse.Namespace) -> None: help="""Maximum log2 ratio in output. Observed values above this limit will be replaced with this value. [Default: %(default)s]""", ) +P_import_rna.add_argument( + "--min-sample-fraction", + metavar="FLOAT", + default=0.5, + type=float, + help="""Keep a gene only if it is expressed (read count >= 1) in at least this + fraction of samples. Lower this for single-cell or sparse cohorts, where + most genes are expressed in fewer than half of cells; values below ~0.2 + admit genes whose log2 ratios are likely noise. [Default: %(default)s]""", +) P_import_rna.add_argument( "-n", "--normal", diff --git a/cnvlib/import_rna.py b/cnvlib/import_rna.py index 57920d34..1c1d3a62 100644 --- a/cnvlib/import_rna.py +++ b/cnvlib/import_rna.py @@ -19,6 +19,7 @@ def do_import_rna( do_txlen=True, max_log2=3, diploid_parx_genome=None, + min_sample_fraction=0.5, ): """Convert a cohort of per-gene read counts to CNVkit .cnr format. @@ -46,6 +47,11 @@ def do_import_rna( diploid_parx_genome : str, optional Reference genome name for pseudo-autosomal region handling (e.g., 'hg19', 'hg38'). + min_sample_fraction : float, optional + Minimum fraction of samples in which a gene must be expressed + (count >= 1) for it to be retained. Default 0.5 preserves the legacy + filter. Lower it for single-cell or sparse cohorts. See + :func:`cnvlib.rna.filter_probes`. Returns ------- @@ -68,8 +74,10 @@ def do_import_rna( sample_counts = aggregate_gene_counts(gene_count_fnames) tx_lengths = None else: - raise RuntimeError("Unrecognized input format name: {in_format!r}") - sample_counts = rna.filter_probes(sample_counts) + raise RuntimeError(f"Unrecognized input format name: {in_format!r}") + sample_counts = rna.filter_probes( + sample_counts, min_sample_fraction=min_sample_fraction + ) logging.info( "Loading gene metadata%s", diff --git a/cnvlib/rna.py b/cnvlib/rna.py index b42a8532..9f954bf3 100644 --- a/cnvlib/rna.py +++ b/cnvlib/rna.py @@ -49,7 +49,7 @@ def selector(string): return selector -def filter_probes(sample_counts): +def filter_probes(sample_counts, min_sample_fraction=0.5): """Filter probes to only include high-quality, transcribed genes. The human genome has ~25,000 protein coding genes, yet the RSEM output @@ -58,14 +58,41 @@ def filter_probes(sample_counts): mapped genes in contigs that have not been linked to the 24 chromosomes (e.g. HLA region). Others correspond to pseudo-genes and non-coding genes. For the purposes of copy number inference, these rows are best removed. + + A gene is retained when it has a detectable transcript (count >= 1) in at + least ``min_sample_fraction`` of the samples. This is expressed as a + quantile threshold rather than a literal sample count: the gene's + ``(1 - min_sample_fraction)`` quantile of per-sample counts must be >= 1. + + Parameters + ---------- + sample_counts : pandas.DataFrame + Per-gene (rows) read counts across samples (columns). + min_sample_fraction : float, optional + Minimum fraction of samples in which a gene must be expressed for it to + be retained, in [0, 1]. The default 0.5 reproduces the historical + ``median(counts) >= 1`` rule bit-exact (the median is the 0.5 quantile). + Lower values are more permissive, which is appropriate for single-cell + or otherwise sparse cohorts where most genes are expressed in fewer than + half of cells. """ - gene_medians = sample_counts.median(axis=1) - # Make sure the gene has detectable transcript in at least half of samples - is_mostly_transcribed = gene_medians >= 1.0 + if not 0.0 <= min_sample_fraction <= 1.0: + raise ValueError( + f"min_sample_fraction must be in [0, 1], got {min_sample_fraction!r}" + ) + if sample_counts.empty: + # Nothing to filter; DataFrame.quantile(axis=1) raises on an empty frame + # whereas the legacy median(axis=1) returned an empty result. + return sample_counts + # The (1 - f) quantile >= 1 means at least fraction f of samples express the + # gene. At f=0.5 this is the median, preserving the legacy behavior exactly. + gene_quantiles = sample_counts.quantile(1.0 - min_sample_fraction, axis=1) + is_mostly_transcribed = gene_quantiles >= 1.0 logging.info( - "Dropping %d / %d rarely expressed genes from input samples", + "Dropping %d / %d genes expressed in fewer than %g%% of input samples", (~is_mostly_transcribed).sum(), len(is_mostly_transcribed), + 100 * min_sample_fraction, ) return sample_counts[is_mostly_transcribed] diff --git a/doc/rna.rst b/doc/rna.rst index be2070d8..1981e99f 100644 --- a/doc/rna.rst +++ b/doc/rna.rst @@ -51,6 +51,27 @@ Input file sources: You can also create the equivalent on your own from the output of another RNA quantification tool like Salmon or Kallisto. +Gene filtering +~~~~~~~~~~~~~~ + +Before normalization, ``import-rna`` discards genes that are rarely expressed +across the cohort: a gene is kept only if it has a detectable transcript (read +count of at least 1) in a minimum fraction of the input samples. By default this +fraction is 0.5, i.e. a gene must be expressed in at least half of the samples, +which reproduces the behavior of earlier CNVkit versions. + +The ``--min-sample-fraction`` option lowers this threshold for single-cell or +otherwise sparse cohorts, where most genes are legitimately expressed in fewer +than half of the cells and the default would discard informative genes:: + + cnvkit.py import-rna *.txt --gene-resource data/ensembl-gene-info.hg38.tsv \ + --min-sample-fraction 0.2 --output-dir out/ + +Lowering the threshold retains more genes in the output ``.cnr`` files, which can +shift downstream segmentation; this is a deliberate trade-off for sparse data. +Values below roughly 0.2 admit genes whose expression is detected in only a small +minority of samples, so their log2 ratios are likely dominated by noise. + Segmentation ------------ diff --git a/test/test_rna.py b/test/test_rna.py index c658c171..b63b5a0e 100644 --- a/test/test_rna.py +++ b/test/test_rna.py @@ -1,6 +1,8 @@ #!/usr/bin/env python """Unit tests for RNA import functionality (cnvlib.rna).""" +import ast +import inspect import logging import os import tempfile @@ -10,7 +12,7 @@ import numpy as np import pandas as pd -from cnvlib import import_rna, rna +from cnvlib import commands, import_rna, rna logging.basicConfig(level=logging.ERROR, format="%(message)s") @@ -271,6 +273,151 @@ def test_safe_log2_zero_handling(self): self.assertTrue(result[3] > result[2]) +class FilterProbesTests(unittest.TestCase): + """``filter_probes`` keeps a gene when enough samples express it. + + The filter is a quantile threshold, not a literal sample count: a gene is + retained when the ``(1 - min_sample_fraction)`` quantile of its per-sample + counts is >= 1. At the default ``min_sample_fraction=0.5`` this is exactly + the legacy ``median(counts) >= 1`` rule, preserved bit-exact (#448). + """ + + @staticmethod + def _counts_expressed_in(n_expressed, n_samples=10, level=100): + """A 3-gene matrix; the middle gene is expressed in exactly N samples. + + The flanking genes are expressed everywhere / nowhere so the frame is + never empty and the assertions isolate the middle gene's fate. + """ + rows = { + "gene_all": [level] * n_samples, # always kept + "gene_mid": [level] * n_expressed + [0] * (n_samples - n_expressed), + "gene_none": [0] * n_samples, # always dropped + } + return pd.DataFrame.from_dict( + rows, orient="index", columns=[f"s{i}" for i in range(n_samples)] + ) + + def test_default_matches_legacy_median_rule(self): + """Default fraction reproduces the historical ``median >= 1`` filter bit-exact.""" + rng = np.random.default_rng(0) + counts = pd.DataFrame( + rng.poisson(2, size=(500, 7)), + index=[f"g{i}" for i in range(500)], + ) + legacy = counts[counts.median(axis=1) >= 1.0] + new = rna.filter_probes(counts) + self.assertTrue(new.equals(legacy)) + + def test_retained_iff_fraction_meets_threshold(self): + """Gene kept iff (N expressed / M samples) >= min_sample_fraction.""" + m = 10 + for n in range(m + 1): + counts = self._counts_expressed_in(n, n_samples=m) + for f in (0.1, 0.2, 0.3, 0.5, 0.7, 0.9, 1.0): + kept = ( + "gene_mid" in rna.filter_probes(counts, min_sample_fraction=f).index + ) + expected = (n / m) >= f + self.assertEqual( + kept, + expected, + f"N={n}/{m} (fraction {n / m}) with min_sample_fraction={f}: " + f"expected kept={expected}, got {kept}", + ) + + def test_lower_fraction_is_more_permissive(self): + """Lowering the threshold never drops a gene the stricter run kept.""" + counts = self._counts_expressed_in(3, n_samples=10) + strict = set(rna.filter_probes(counts, min_sample_fraction=0.5).index) + loose = set(rna.filter_probes(counts, min_sample_fraction=0.2).index) + self.assertTrue(strict.issubset(loose)) + # The single-cell-style low threshold rescues the sparsely-expressed gene. + self.assertNotIn("gene_mid", strict) + self.assertIn("gene_mid", loose) + + def test_invalid_fraction_raises(self): + counts = self._counts_expressed_in(5) + for bad in (-0.1, 1.5, 2.0): + with self.assertRaises(ValueError): + rna.filter_probes(counts, min_sample_fraction=bad) + + def test_empty_input_returns_empty(self): + """Empty frame passes through unchanged (quantile(axis=1) raises on empty). + + Regression guard: the legacy ``median(axis=1)`` path returned an empty + result on an empty frame, whereas ``quantile(axis=1)`` raises + ``ValueError: no types given``. + """ + empty = pd.DataFrame() + self.assertEqual(rna.filter_probes(empty).shape, (0, 0)) + self.assertEqual( + rna.filter_probes(empty, min_sample_fraction=0.2).shape, (0, 0) + ) + + def test_single_sample_uses_that_sample(self): + """With one sample, the quantile collapses to that sample's count.""" + counts = pd.DataFrame({"s0": [0, 5, 1]}, index=["a", "b", "c"]) + self.assertEqual(list(rna.filter_probes(counts).index), ["b", "c"]) + + +class FilterProbesPlumbingTests(unittest.TestCase): + """``min_sample_fraction`` is threaded CLI -> do_import_rna -> filter_probes. + + AST-level guards (cheap, no fixtures) so a dropped kwarg fails at collection + rather than silently reverting single-cell cohorts to the 0.5 default. + """ + + @staticmethod + def _calls_to(source_obj, attr, value_id): + tree = ast.parse(inspect.getsource(source_obj)) + return [ + node + for node in ast.walk(tree) + if isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and node.func.attr == attr + and isinstance(node.func.value, ast.Name) + and node.func.value.id == value_id + ] + + def test_do_import_rna_passes_fraction_to_filter_probes(self): + calls = self._calls_to(import_rna.do_import_rna, "filter_probes", "rna") + self.assertGreater(len(calls), 0) + for call in calls: + self.assertIn( + "min_sample_fraction", + {kw.arg for kw in call.keywords}, + "do_import_rna must forward min_sample_fraction to rna.filter_probes", + ) + + def test_cmd_import_rna_passes_fraction_to_do_import_rna(self): + calls = self._calls_to(commands._cmd_import_rna, "do_import_rna", "import_rna") + self.assertGreater(len(calls), 0) + for call in calls: + self.assertIn( + "min_sample_fraction", + {kw.arg for kw in call.keywords}, + "_cmd_import_rna must forward args.min_sample_fraction to " + "do_import_rna", + ) + + def test_signatures_default_to_one_half(self): + """Default preserved at both layers so existing behavior is unchanged.""" + self.assertEqual( + inspect.signature(rna.filter_probes) + .parameters["min_sample_fraction"] + .default, + 0.5, + ) + self.assertEqual( + inspect.signature(import_rna.do_import_rna) + .parameters["min_sample_fraction"] + .default, + 0.5, + ) + + class ImportRnaIntegrationTests(unittest.TestCase): """End-to-end `import-rna` from count files + gene resource to .cnr.""" @@ -301,8 +448,10 @@ def test_do_import_rna_counts_to_cnr(self): self.assertEqual(cnrs[0].sample_id, "rna-sample-A") def test_do_import_rna_unknown_format_raises(self): - with self.assertRaises(RuntimeError): + with self.assertRaises(RuntimeError) as cm: import_rna.do_import_rna(self.COUNT_FILES[:1], "bogus", self.GENE_RESOURCE) + # The offending format name is interpolated into the message (not literal). + self.assertIn("bogus", str(cm.exception)) class NormalizeReadDepthsNormalAnchorTests(unittest.TestCase):