diff --git a/conll03_nel_eval/__main__.py b/conll03_nel_eval/__main__.py
index b3af1fb..557a824 100755
--- a/conll03_nel_eval/__main__.py
+++ b/conll03_nel_eval/__main__.py
@@ -7,7 +7,7 @@
 from .prepare import Prepare
 from .evaluate import Evaluate
 from .analyze import Analyze
-from .significance import Significance
+from .significance import Significance, Confidence
 from .formats import Unstitch, Stitch, Tagme
 from .fetch_map import FetchMapping
 from .filter import FilterMentions
@@ -18,6 +18,7 @@
     Evaluate,
     Analyze,
     Significance,
+    Confidence,
     Prepare,
     FilterMentions,
     Unstitch,
diff --git a/conll03_nel_eval/significance.py b/conll03_nel_eval/significance.py
index 5e1ee76..d42be4c 100644
--- a/conll03_nel_eval/significance.py
+++ b/conll03_nel_eval/significance.py
@@ -88,7 +88,7 @@ def count_permutation_trials(per_doc1, per_doc2, base_diff, n_trials):
     return dict(zip(metrics, better))
 
 
-def _bootstrap_trial(per_doc1, per_doc2):
+def _paired_bootstrap_trial(per_doc1, per_doc2):
     indices = [random.randint(0, len(per_doc1) - 1)
                for i in xrange(len(per_doc1))]
     pseudo1 = sum((per_doc1[i] for i in indices), Matrix())
@@ -102,12 +102,21 @@ def count_bootstrap_trials(per_doc1, per_doc2, base_diff, n_trials):
     signs = [base >= 0 for base in bases]
     same_sign = [0] * len(metrics)
     for _ in xrange(n_trials):
-        result = _bootstrap_trial(per_doc1, per_doc2)
+        result = _paired_bootstrap_trial(per_doc1, per_doc2)
         for i, metric in enumerate(metrics):
-            same_sign[i] += signs[i] == (result[metric] >= 0)
+            same_sign[i] += signs[i] == (result[metric] < 0)
     return dict(zip(metrics, same_sign))
 
 
+def _job_shares(n_jobs, trials):
+    if n_jobs == -1:
+        n_jobs = cpu_count()
+    shares = [trials // n_jobs] * n_jobs
+    for i in range(trials - sum(shares)):
+        shares[i] += 1
+    return shares
+
+
 class Significance(object):
     """Test for pairwise significance between systems"""
 
@@ -160,15 +169,8 @@ def significance(self, (per_doc1, overall1), (per_doc2, overall2)):
         randomized_diffs = functools.partial(self.METHODS[self.method],
                                              per_doc1, per_doc2,
                                              base_diff)
-        n_jobs = self.n_jobs
-        if n_jobs == -1:
-            n_jobs = cpu_count()
-        shares = [self.trials // n_jobs] * n_jobs
-        for i in range(self.trials - sum(shares)):
-            shares[i] += 1
-
         results = Parallel(n_jobs=self.n_jobs)(delayed(randomized_diffs)(share)
-                                               for share in shares)
+                                               for share in _job_shares(self.n_jobs, self.trials))
         all_counts = []
         for result in results:
             metrics, counts = zip(*result.iteritems())
@@ -197,3 +199,98 @@ def add_arguments(cls, p):
                        choices=LMATCH_SETS.keys())
         p.set_defaults(cls=cls)
         return p
+
+
+def bootstrap_trials(per_doc, n_trials, metrics):
+    """Bootstrap results over a single system output"""
+    history = defaultdict(list)
+    for _ in xrange(n_trials):
+        indices = [random.randint(0, len(per_doc) - 1)
+                   for i in xrange(len(per_doc))]
+        result = sum((per_doc[i] for i in indices), Matrix()).results
+        for metric in metrics:
+            history[metric].append(result[metric])
+    return dict(history)
+
+
+def _percentile(ordered, p):
+    # As per http://www.itl.nist.gov/div898/handbook/prc/section2/prc252.htm
+    k, d = divmod(p / 100.0 * (len(ordered) + 1), 1)
+    # k is the integer part, d the decimal part
+    k = int(k)
+    if 0 < k < len(ordered):
+        lo, hi = ordered[k - 1:k + 1]
+        return lo + d * (hi - lo)
+    elif k == 0:
+        return ordered[0]
+    else:
+        return ordered[-1]
+
+
+class Confidence(object):
+    """Calculate percentile bootstrap confidence intervals for a system
+    """
+    def __init__(self, system, gold, trials=10000, percentiles=(90, 95, 99),
+                 n_jobs=1, metrics=['precision', 'recall', 'fscore'],
+                 lmatches=DEFAULT_LMATCH_SET):
+        # Check whether the joblib import worked; give a more useful error.
+        if Parallel is None:
+            raise ImportError('Package "joblib" not available; please install it to calculate confidence intervals.')
+        self.system = system
+        self.gold = gold
+        self.trials = trials
+        self.n_jobs = n_jobs
+        self.lmatches = LMATCH_SETS[lmatches]
+        self.metrics = metrics
+        self.percentiles = percentiles
+
+    def intervals(self, per_doc):
+        results = Parallel(n_jobs=self.n_jobs)(delayed(bootstrap_trials)(per_doc, share, self.metrics)
+                                               for share in _job_shares(self.n_jobs, self.trials))
+        history = defaultdict(list)
+        for res in results:
+            for metric in self.metrics:
+                history[metric].extend(res[metric])
+
+        ret = {}
+        for metric, values in history.items():
+            values.sort()
+            ret[metric] = [(_percentile(values, (100 - p) / 2.0),
+                            _percentile(values, 100 - (100 - p) / 2.0))
+                           for p in self.percentiles]
+        return ret
+
+    def calculate_all(self):
+        gold = list(Reader(open(self.gold)))
+        system = list(Reader(open(self.system)))
+        doc_pairs = list(Evaluate.iter_pairs(system, gold))
+        counts = {}
+        for match, per_doc, overall in Evaluate.count_all(doc_pairs, self.lmatches):
+            counts[match] = (per_doc, overall)
+        results = [{'match': match,
+                    'overall': {k: v for k, v in overall.results.items() if k in self.metrics},
+                    'intervals': self.intervals(per_doc)}
+                   for match, (per_doc, overall) in sorted(counts.iteritems(),
+                                                           key=lambda (k, v): self.lmatches.index(k))]
+        return results
+
+    def __call__(self):
+        return json_format(self.calculate_all(), self.metrics)
+
+    @classmethod
+    def add_arguments(cls, p):
+        p.add_argument('system', metavar='FILE')
+        p.add_argument('-g', '--gold')
+        p.add_argument('-n', '--trials', default=10000, type=int)
+        p.add_argument('-j', '--n_jobs', default=1, type=int,
+                       help='Number of parallel processes, use -1 for all CPUs')
+        p.add_argument('--metrics', default='precision recall fscore'.split(),
+                       type=lambda x: x.split(','),
+                       help='Calculate intervals for which metrics (default: precision,recall,fscore)')
+        p.add_argument('--percentiles', default=(90, 95, 99),
+                       type=lambda x: map(float, x.split(',')),
+                       help='Output confidence intervals at these percentiles (default: 90,95,99)')
+        p.add_argument('-l', '--lmatches', default=DEFAULT_LMATCH_SET,
+                       choices=LMATCH_SETS.keys())
+        p.set_defaults(cls=cls)
+        return p
diff --git a/conll03_nel_eval/tac.py b/conll03_nel_eval/tac.py
index 9557a4c..7284193 100644
--- a/conll03_nel_eval/tac.py
+++ b/conll03_nel_eval/tac.py
@@ -28,7 +28,7 @@ def __call__(self):
     @classmethod
     def add_arguments(cls, p):
         p.add_argument('system', metavar='FILE', help='link annotations')
-        p.add_argument('-q', '--queries', help='mention annotations')
+        p.add_argument('-q', '--queries', required=True, help='mention annotations')
         p.add_argument('-m', '--mapping', help='mapping for titles')
         p.set_defaults(cls=cls)
         return p
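
For intuition, Confidence.intervals implements the standard percentile bootstrap: resample documents with replacement, recompute each metric on the pseudo-corpus, repeat for the requested number of trials, and read the interval bounds off the sorted replicates. The self-contained sketch below replays that procedure on fake data; the names percentile and per_doc_fscore are illustrative stand-ins, and where the patch resamples per-document Matrix counts and recomputes micro-averaged scores, this sketch simply averages pre-computed per-document f-scores.

    import random

    def percentile(ordered, p):
        # Same closest-ranks interpolation as _percentile in the patch.
        k, d = divmod(p / 100.0 * (len(ordered) + 1), 1)
        k = int(k)
        if 0 < k < len(ordered):
            lo, hi = ordered[k - 1:k + 1]
            return lo + d * (hi - lo)
        return ordered[0] if k == 0 else ordered[-1]

    random.seed(0)
    # Fake per-document f-scores standing in for per-document Matrix counts.
    per_doc_fscore = [random.gauss(0.7, 0.1) for _ in range(200)]

    replicates = []
    for _ in range(10000):
        # One bootstrap trial: resample docs with replacement, re-aggregate.
        sample = [random.choice(per_doc_fscore) for _ in per_doc_fscore]
        replicates.append(sum(sample) / len(sample))
    replicates.sort()

    for p in (90, 95, 99):
        lo = percentile(replicates, (100 - p) / 2.0)
        hi = percentile(replicates, 100 - (100 - p) / 2.0)
        print('%d%% interval: (%.3f, %.3f)' % (p, lo, hi))

The explicit float denominators matter here: significance.py uses xrange and iteritems, so it targets Python 2, where integer division would truncate (100 - 95) / 2 to 2 and silently compute the 95% bounds at the 2nd and 98th percentiles.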