3 changes: 2 additions & 1 deletion conll03_nel_eval/__main__.py
@@ -7,7 +7,7 @@
 from .prepare import Prepare
 from .evaluate import Evaluate
 from .analyze import Analyze
-from .significance import Significance
+from .significance import Significance, Confidence
 from .formats import Unstitch, Stitch, Tagme
 from .fetch_map import FetchMapping
 from .filter import FilterMentions
@@ -18,6 +18,7 @@
     Evaluate,
     Analyze,
     Significance,
+    Confidence,
     Prepare,
     FilterMentions,
     Unstitch,
119 changes: 108 additions & 11 deletions conll03_nel_eval/significance.py
@@ -88,7 +88,7 @@ def count_permutation_trials(per_doc1, per_doc2, base_diff, n_trials):
     return dict(zip(metrics, better))


-def _bootstrap_trial(per_doc1, per_doc2):
+def _paired_bootstrap_trial(per_doc1, per_doc2):
     indices = [random.randint(0, len(per_doc1) - 1)
                for i in xrange(len(per_doc1))]
     pseudo1 = sum((per_doc1[i] for i in indices), Matrix())
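
Note: a paired bootstrap trial resamples document indices with replacement and applies the same indices to both systems, so each pseudo-sample compares the two systems on an identical set of documents. A minimal self-contained sketch of the idea, using plain per-document scores in place of the repo's Matrix objects (all names here are illustrative, not from the codebase):

import random

def paired_bootstrap_trial_demo(scores1, scores2):
    # Draw one pseudo-sample: the SAME resampled indices are applied to
    # both systems, preserving the pairing between their outputs.
    n = len(scores1)
    indices = [random.randrange(n) for _ in range(n)]
    pseudo1 = sum(scores1[i] for i in indices) / n
    pseudo2 = sum(scores2[i] for i in indices) / n
    return pseudo1 - pseudo2

random.seed(0)
sys1 = [0.80, 0.60, 0.90, 0.70]  # toy per-document fscores
sys2 = [0.70, 0.60, 0.80, 0.70]
print(paired_bootstrap_trial_demo(sys1, sys2))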
@@ -102,12 +102,21 @@ def count_bootstrap_trials(per_doc1, per_doc2, base_diff, n_trials):
     signs = [base >= 0 for base in bases]
     same_sign = [0] * len(metrics)
     for _ in xrange(n_trials):
-        result = _bootstrap_trial(per_doc1, per_doc2)
+        result = _paired_bootstrap_trial(per_doc1, per_doc2)
         for i, metric in enumerate(metrics):
-            same_sign[i] += signs[i] == (result[metric] >= 0)
+            same_sign[i] += signs[i] == (result[metric] < 0)
     return dict(zip(metrics, same_sign))


+def _job_shares(n_jobs, trials):
+    if n_jobs == -1:
+        n_jobs = cpu_count()
+    shares = [trials // n_jobs] * n_jobs
+    for i in range(trials - sum(shares)):
+        shares[i] += 1
+    return shares
+
+
 class Significance(object):
     """Test for pairwise significance between systems"""

@@ -160,15 +169,8 @@ def significance(self, (per_doc1, overall1), (per_doc2, overall2)):
         randomized_diffs = functools.partial(self.METHODS[self.method],
                                              per_doc1, per_doc2,
                                              base_diff)
-        n_jobs = self.n_jobs
-        if n_jobs == -1:
-            n_jobs = cpu_count()
-        shares = [self.trials // n_jobs] * n_jobs
-        for i in range(self.trials - sum(shares)):
-            shares[i] += 1
-
         results = Parallel(n_jobs=self.n_jobs)(delayed(randomized_diffs)(share)
-                                               for share in shares)
+                                               for share in _job_shares(self.n_jobs, self.trials))
         all_counts = []
         for result in results:
             metrics, counts = zip(*result.iteritems())
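
Note: the refactor routes Significance through the same split-then-scatter pattern now shared with Confidence: _job_shares decides how many trials each worker runs, and joblib's Parallel/delayed fans the batches out and gathers the per-batch results. A sketch of that pattern with a stand-in task (requires joblib; count_heads is a hypothetical placeholder for randomized_diffs):

import random
from joblib import Parallel, delayed

def count_heads(n_trials):  # stand-in for the real randomized_diffs
    return sum(random.random() < 0.5 for _ in range(n_trials))

shares = [2500, 2500, 2500, 2500]  # e.g. the result of _job_shares(4, 10000)
results = Parallel(n_jobs=4)(delayed(count_heads)(share) for share in shares)
print('%d heads in 10000 trials' % sum(results))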
@@ -197,3 +199,98 @@ def add_arguments(cls, p):
                        choices=LMATCH_SETS.keys())
         p.set_defaults(cls=cls)
         return p
+
+
+def bootstrap_trials(per_doc, n_trials, metrics):
+    """Bootstrap results over a single system output"""
+    history = defaultdict(list)
+    for _ in xrange(n_trials):
+        indices = [random.randint(0, len(per_doc) - 1)
+                   for i in xrange(len(per_doc))]
+        result = sum((per_doc[i] for i in indices), Matrix()).results
+        for metric in metrics:
+            history[metric].append(result[metric])
+    return dict(history)
+
+
+def _percentile(ordered, p):
+    # As per http://www.itl.nist.gov/div898/handbook/prc/section2/prc252.htm
+    k, d = divmod(p / 100 * (len(ordered) + 1), 1)
+    # k is the integer part of the rank, d the decimal part
+    k = int(k)
+    if 0 < k < len(ordered):
+        lo, hi = ordered[k - 1:k + 1]
+        return lo + d * (hi - lo)
+    elif k == 0:
+        return ordered[0]
+    else:
+        return ordered[-1]
+
+
+class Confidence(object):
+    """Calculate percentile bootstrap confidence intervals for a system"""
+    def __init__(self, system, gold, trials=10000, percentiles=(90, 95, 99),
+                 n_jobs=1, metrics=['precision', 'recall', 'fscore'],
+                 lmatches=DEFAULT_LMATCH_SET):
+        # Check whether the joblib import worked; give a more useful error.
+        if Parallel is None:
+            raise ImportError('Package "joblib" not available; please install it to run significance tests.')
+        self.system = system
+        self.gold = gold
+        self.trials = trials
+        self.n_jobs = n_jobs
+        self.lmatches = LMATCH_SETS[lmatches]
+        self.metrics = metrics
+        self.percentiles = percentiles
+
+    def intervals(self, per_doc):
+        results = Parallel(n_jobs=self.n_jobs)(delayed(bootstrap_trials)(per_doc, share, self.metrics)
+                                               for share in _job_shares(self.n_jobs, self.trials))
+        history = defaultdict(list)
+        for res in results:
+            for metric in self.metrics:
+                history[metric].extend(res[metric])
+
+        ret = {}
+        for metric, values in history.items():
+            values.sort()
+            ret[metric] = [(_percentile(values, (100 - p) / 2),
+                            _percentile(values, 100 - (100 - p) / 2))
+                           for p in self.percentiles]
+        return ret
+
+    def calculate_all(self):
+        gold = list(Reader(open(self.gold)))
+        system = list(Reader(open(self.system)))
+        doc_pairs = list(Evaluate.iter_pairs(system, gold))
+        counts = {}
+        for match, per_doc, overall in Evaluate.count_all(doc_pairs, self.lmatches):
+            counts[match] = (per_doc, overall)
+        results = [{'match': match,
+                    'overall': {k: v for k, v in overall.results.items() if k in self.metrics},
+                    'intervals': self.intervals(per_doc)}
+                   for match, (per_doc, overall) in sorted(counts.iteritems(),
+                                                           key=lambda (k, v): self.lmatches.index(k))]
+        return results
+
+    def __call__(self):
+        return json_format(self.calculate_all(), self.metrics)
+
+    @classmethod
+    def add_arguments(cls, p):
+        p.add_argument('system', metavar='FILE')
+        p.add_argument('-g', '--gold')
+        p.add_argument('-n', '--trials', default=10000, type=int)
+        p.add_argument('-j', '--n_jobs', default=1, type=int,
+                       help='Number of parallel processes; use -1 for all CPUs')
+        p.add_argument('--metrics', default='precision recall fscore'.split(),
+                       type=lambda x: x.split(','),
+                       help='Test significance for these metrics (default: precision,recall,fscore)')
+        p.add_argument('--percentiles', default=(90, 95, 99),
+                       type=lambda x: map(float, x.split(',')),
+                       help='Output confidence intervals at these percentiles (default: 90,95,99)')
+        p.add_argument('-l', '--lmatches', default=DEFAULT_LMATCH_SET,
+                       choices=LMATCH_SETS.keys())
+        p.set_defaults(cls=cls)
+        return p
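
Note: for a confidence level p, intervals takes the (100 - p)/2 and 100 - (100 - p)/2 percentiles of the sorted bootstrap replicates (e.g. the 2.5th and 97.5th for p = 95), with _percentile interpolating linearly between ranks per the NIST handbook. A standalone sketch of that logic on fake replicates; it uses explicit float arithmetic, whereas the code above presumably relies on a from __future__ import division elsewhere in the module:

import random

def percentile(ordered, p):
    # NIST-style percentile: rank = p/100 * (n + 1), then interpolate
    # between the neighbouring order statistics.
    k, d = divmod(p / 100.0 * (len(ordered) + 1), 1)
    k = int(k)
    if 0 < k < len(ordered):
        lo, hi = ordered[k - 1:k + 1]
        return lo + d * (hi - lo)
    return ordered[0] if k == 0 else ordered[-1]

random.seed(0)
# Fake bootstrap replicates of an fscore centred near 0.75.
values = sorted(random.gauss(0.75, 0.02) for _ in range(10000))
for p in (90, 95, 99):
    lo = percentile(values, (100 - p) / 2.0)
    hi = percentile(values, 100 - (100 - p) / 2.0)
    print('%d%% CI: (%.3f, %.3f)' % (p, lo, hi))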
2 changes: 1 addition & 1 deletion conll03_nel_eval/tac.py
@@ -28,7 +28,7 @@ def __call__(self):
     @classmethod
     def add_arguments(cls, p):
         p.add_argument('system', metavar='FILE', help='link annotations')
-        p.add_argument('-q', '--queries', help='mention annotations')
+        p.add_argument('-q', '--queries', required=True, help='mention annotations')
         p.add_argument('-m', '--mapping', help='mapping for titles')
         p.set_defaults(cls=cls)
         return p
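
Note: making -q/--queries required means argparse itself rejects invocations that omit the mention annotations, rather than failing later with a missing-file error. A minimal illustration (the prog name and file names are hypothetical):

import argparse

p = argparse.ArgumentParser(prog='tac-eval')
p.add_argument('system', metavar='FILE', help='link annotations')
p.add_argument('-q', '--queries', required=True, help='mention annotations')

print(p.parse_args(['links.tab', '-q', 'queries.xml']))  # parses fine
try:
    p.parse_args(['links.tab'])  # argparse exits: --queries is now mandatory
except SystemExit:
    print('rejected: -q/--queries must be supplied')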