# ---- diff: new file ror/README.md ----
# # ROR Affiliations Plugin
# TODO

# ---- diff: new file ror/indico_ror/__init__.py ----
# This file is part of the Indico plugins.
# Copyright (C) 2002 - 2026 CERN
#
# The Indico plugins are free software; you can redistribute
# them and/or modify them under the terms of the MIT License;
# see the LICENSE file for more details.

from indico.core import signals
from indico.util.i18n import make_bound_gettext


# Translation function bound to the 'ror' plugin name.
_ = make_bound_gettext('ror')


@signals.core.import_tasks.connect
def _import_tasks(sender, **kwargs):
    """Import the plugin's task module when the ``import_tasks`` signal fires.

    The import is done only for its side effect (the module is never used
    here), hence the F401 suppression.
    """
    import indico_ror.task  # noqa: F401


# ---- diff: new file ror/indico_ror/matching.py ----
# This file is part of the Indico plugins.
# Copyright (C) 2002 - 2026 CERN
#
# The Indico plugins are free software; you can redistribute
# them and/or modify them under the terms of the MIT License;
# see the LICENSE file for more details.
import ollama
from langchain_ollama import OllamaEmbeddings
from sqlalchemy import delete, literal, select, union_all

from indico.core.db import db
from indico.modules.affiliations.search import AffiliationSearchMatch, AffiliationSearchProvider

from indico_ror.models.affiliation_vs_document import AffiliationVectorStoreDocument


def is_model_pulled(model_name: str) -> bool:
    """Return whether ``ollama.list()`` reports a model matching *model_name*.

    A listed model matches when its full name equals *model_name* (so a tagged
    name such as ``'foo:latest'`` can be found) or when the part before the
    first ``':'`` equals *model_name*.
    """
    return any(
        model.model == model_name or model.model.split(':')[0] == model_name
        for model in ollama.list().models
    )


def ensure_model(model_name: str):
    """Pull a model if it's not already available."""
    if not is_model_pulled(model_name):
        # TODO: ollama.pull downloads from some remote archive; we should instead pre-fetch
        # whichever models we actually need
        ollama.pull(model_name)


class PSQLVectorStoreAffiliationSearchProvider(AffiliationSearchProvider):
    """Affiliation search provider storing embeddings in ``AffiliationVectorStoreDocument`` rows.

    Texts are embedded with ``OllamaEmbeddings`` and compared by cosine distance.
    """

    def __init__(
        self, model: str = 'jina/jina-embeddings-v2-small-en', batch_size: int = 512, threshold: float = 0.3
    ) -> None:
        """Set up the embedding model.

        :param model: Name of the Ollama model; it is pulled if not available yet.
        :param batch_size: Number of texts embedded (and flushed) at a time by `add`.
        :param threshold: Only documents with a cosine distance below this value are
                          returned as matches.
        :raises ValueError: if *batch_size* is smaller than 1.
        """
        # A negative step would make `add` silently skip everything, zero would fail late
        if batch_size < 1:
            raise ValueError(f'batch_size must be at least 1, got {batch_size}')
        ensure_model(model)
        self.model = model
        self.embeddings = OllamaEmbeddings(model=model)
        self.batch_size = batch_size
        self.threshold = threshold

    @staticmethod
    def cosine_distance_to_score(distance: float) -> float:
        """Convert a cosine distance into a score (``1 - distance``)."""
        return 1 - distance

    def init(self, texts: list[str], affiliation_ids: list[int]) -> None:
        """Populate the store; equivalent to `add`."""
        return self.add(texts, affiliation_ids)

    def add(self, texts: list[str], affiliation_ids: list[int]) -> None:
        """Embed *texts* and store one document per text.

        ``texts[i]`` is stored with ``affiliation_ids[i]``. The session is flushed
        after each batch of ``batch_size`` texts; nothing is committed.

        :raises ValueError: if the two lists differ in length.
        """
        if len(texts) != len(affiliation_ids):
            raise ValueError(
                f'texts and affiliation_ids must have the same length ({len(texts)} != {len(affiliation_ids)})'
            )
        for start in range(0, len(texts), self.batch_size):
            batch_texts = texts[start:start + self.batch_size]
            batch_ids = affiliation_ids[start:start + self.batch_size]
            vectors = self.embeddings.embed_documents(batch_texts)
            # strict=True: fail loudly if the embedder returns a different number of vectors
            for text, vector, affiliation_id in zip(batch_texts, vectors, batch_ids, strict=True):
                db.session.add(AffiliationVectorStoreDocument(
                    content=text,
                    embedding=vector,
                    affiliation_id=affiliation_id,
                ))
            db.session.flush()

    def update(self, texts: list[str], affiliation_ids: list[int], changed_affiliations: list[int]) -> None:
        """Replace the documents of *changed_affiliations* with the given texts.

        Existing documents of the changed affiliations are always removed, even
        when there are no new texts to store for them.
        """
        self.delete(changed_affiliations)
        self.add(texts, affiliation_ids)

    def delete(self, affiliation_ids: list[int]) -> None:
        """Delete all documents belonging to the given affiliation IDs."""
        if not affiliation_ids:
            return
        db.session.execute(
            delete(AffiliationVectorStoreDocument)
            .where(AffiliationVectorStoreDocument.affiliation_id.in_(affiliation_ids))
        )

    def match_embeddings(
        self, embeddings: list[list[float]], k: int = 1
    ) -> list[list[AffiliationSearchMatch]]:
        """Find the closest documents for several embeddings with a single query.

        :return: One list per input embedding, in input order. Each list holds at
                 most *k* matches whose cosine distance is below ``threshold``,
                 closest first.
        """
        # union_all() needs at least one SELECT
        if not embeddings:
            return []

        subqueries = []
        for i, embedding in enumerate(embeddings):
            distance = AffiliationVectorStoreDocument.embedding.cosine_distance(embedding).label('distance')
            subqueries.append(
                select(
                    literal(i).label('embedding_index'),
                    AffiliationVectorStoreDocument.id.label('id'),
                    distance
                ).where(distance < self.threshold)
                .order_by(distance)
                .limit(k)
            )
        combined = union_all(*subqueries).subquery()

        results = db.session.execute(
            select(
                combined.c.embedding_index,
                AffiliationVectorStoreDocument,
                combined.c.distance
            )
            .join(AffiliationVectorStoreDocument, AffiliationVectorStoreDocument.id == combined.c.id)
            .order_by(combined.c.embedding_index, combined.c.distance)
        ).all()

        # Embeddings without any match keep an empty list
        grouped: list[list[AffiliationSearchMatch]] = [[] for _ in embeddings]
        for embedding_index, doc, dist in results:
            grouped[embedding_index].append(AffiliationSearchMatch(
                score=self.cosine_distance_to_score(dist), text=doc.content, affiliation_id=doc.affiliation_id
            ))
        return grouped

    def match_embedding(self, embedding: list[float], k: int = 1) -> list[AffiliationSearchMatch]:
        """Return at most *k* matches for one embedding, closest first."""
        distance = AffiliationVectorStoreDocument.embedding.cosine_distance(embedding).label('distance')
        rows = db.session.execute(
            select(AffiliationVectorStoreDocument, distance)
            .where(distance < self.threshold)
            .order_by(distance)
            .limit(k)
        ).all()
        return [
            AffiliationSearchMatch(
                score=self.cosine_distance_to_score(dist), text=doc.content, affiliation_id=doc.affiliation_id
            )
            for doc, dist in rows
        ]

    def match_many(self, texts: list[str], k: int = 1) -> list[list[AffiliationSearchMatch]]:
        """Embed each text as a query and match them all via `match_embeddings`."""
        return self.match_embeddings([self.embeddings.embed_query(text) for text in texts], k)

    def match(self, text: str, k: int = 1) -> list[AffiliationSearchMatch]:
        """Embed *text* as a query and return its matches via `match_embedding`."""
        return self.match_embedding(self.embeddings.embed_query(text), k)


# ---- diff: new (empty) file ror/indico_ror/migrations/.no-header ----
# ---- diff: new file ror/indico_ror/migrations/20260313_1243_f95cb312d2bb_add_affiliations_vector_store.py ----
"""Add affiliations vector store.

Revision ID: f95cb312d2bb
Revises:
Create Date: 2026-03-13 12:43:32.863241
"""

import sqlalchemy as sa
from alembic import context, op
from pgvector.sqlalchemy import Vector
from sqlalchemy.sql.ddl import CreateSchema, DropSchema


# revision identifiers, used by Alembic.
revision = 'f95cb312d2bb'
down_revision = None
branch_labels = None
depends_on = None


def _require_vector_extension():
    """Raise unless ``pg_extension`` contains a row for the ``vector`` extension."""
    found = context.get_bind().execute("SELECT oid FROM pg_extension where extname = 'vector'")
    if found.rowcount == 0:
        raise Exception('The pg_extension "vector" must be enabled to run this update')


def upgrade():
    """Create the ``plugin_ror`` schema and its ``affiliation_documents`` table."""
    # The extension check is skipped in offline mode
    if not context.is_offline_mode():
        _require_vector_extension()

    op.execute(CreateSchema('plugin_ror'))
    table_items = [
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('content', sa.Text(), nullable=False),
        sa.Column('embedding', Vector(dim=512), nullable=False),
        sa.Column('affiliation_id', sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(['affiliation_id'], ['indico.affiliations.id'], ondelete='cascade'),
        sa.PrimaryKeyConstraint('id'),
    ]
    op.create_table('affiliation_documents', *table_items, schema='plugin_ror')


def downgrade():
    """Drop the table first, then the schema that contained it."""
    op.drop_table('affiliation_documents', schema='plugin_ror')
    op.execute(DropSchema('plugin_ror'))


# ---- diff: new file ror/indico_ror/models/affiliation_vs_document.py ----
# This file is part of the Indico plugins.
# Copyright (C) 2002 - 2026 CERN
#
# The Indico plugins are free software; you can redistribute
# them and/or modify them under the terms of the MIT License;
# see the LICENSE file for more details.
from pgvector.sqlalchemy import Vector

from indico.core.db import db


class AffiliationVectorStoreDocument(db.Model):
    """A text of an affiliation together with its embedding vector."""

    __tablename__ = 'affiliation_documents'
    __table_args__ = {'schema': 'plugin_ror'}

    # Primary key
    id = db.Column(db.Integer, primary_key=True)
    # The text that was embedded
    content = db.Column(db.Text, nullable=False)
    # 512-dimensional embedding of `content`
    embedding = db.Column(Vector(512), nullable=False)
    # The affiliation this text belongs to; documents are removed with their affiliation
    affiliation_id = db.Column(db.Integer, db.ForeignKey('indico.affiliations.id', ondelete='cascade'), nullable=True)
    affiliation = db.relationship(
        'Affiliation',
        backref=db.backref('ror_vector_store_documents', cascade='all, delete-orphan', lazy='dynamic'),
    )

    def __repr__(self):
        # The embedding is left out on purpose: 512 floats would make the repr unreadable
        return (
            f'<AffiliationVectorStoreDocument({self.id}, affiliation_id={self.affiliation_id}, '
            f'content={self.content!r})>'
        )


# ---- diff: new file ror/indico_ror/plugin.py ----
# This file is part of the Indico plugins.
# Copyright (C) 2002 - 2026 CERN
#
# The Indico plugins are free software; you can redistribute
# them and/or modify them under the terms of the MIT License;
# see the LICENSE file for more details.
import csv
import io
import pathlib
import time
import zipfile
from collections import defaultdict

import click
import requests
import sqlalchemy

import indico
from indico.cli.core import cli_group
from indico.core import signals
from indico.core.db import db
from indico.core.plugins import IndicoPlugin
from indico.modules.events.abstracts.models.persons import AbstractPersonLink
from indico.modules.events.contributions.models.persons import ContributionPersonLink, SubContributionPersonLink
from indico.modules.events.models.persons import EventPerson, EventPersonLink
from indico.modules.events.sessions.models.persons import SessionBlockPersonLink
from indico.modules.users.models.affiliations import Affiliation
from indico.modules.users.models.users import User
from indico.util.console import verbose_iterator

from indico_ror.matching import PSQLVectorStoreAffiliationSearchProvider


# Classes whose free-text ``affiliation`` is read by ``ror match`` and rewritten by ``ror apply``
AFFILIATION_BACKREF_CLASSES = [
    AbstractPersonLink,
    ContributionPersonLink,
    EventPersonLink,
    EventPerson,
    SessionBlockPersonLink,
    SubContributionPersonLink,
    User,
]

# Header row written by ``ror match`` and checked by ``ror apply``
CSV_MATCHES_HEADER = ('Affiliation Text', 'Match Text', 'Match ID', 'Confidence')


def fetch_ror(timeout: float = 60):
    """Download the first file attached to the ROR record on Zenodo.

    :param timeout: Timeout in seconds passed to each HTTP request.
    :return: The ``requests`` response containing the file, or ``None`` if
             either request did not return status 200 or the record lists no file.
    """
    headers = {
        'User-Agent': f'Indico/{indico.__version__}'
    }

    base_url = 'https://zenodo.org/api/records/6347574'
    click.echo(f"fetching records from '{base_url}'...")

    record_response = requests.get(base_url, headers=headers, allow_redirects=True, timeout=timeout)
    if record_response.status_code != 200:
        return None

    try:
        first_file = record_response.json()['files'][0]
        filename = first_file['key']
        file_url = first_file['links']['self']
    except (KeyError, IndexError, TypeError):
        click.echo('the record does not list any downloadable file')
        return None

    click.echo(f"fetching '{filename}' from '{file_url}'...")

    file_response = requests.get(file_url, headers=headers, allow_redirects=True, timeout=timeout)
    # Don't hand an error page to the zip extraction
    if file_response.status_code != 200:
        return None
    return file_response


def parse_csv(contents: str):
    """Parse CSV text into a column-oriented dict.

    :return: A ``defaultdict(list)`` mapping each header name to the list of its
             values, one entry per data row. Empty input yields an empty dict.
    :raises ValueError: if a data row has a different number of fields than the header.
    """
    csv_dict = defaultdict(list)
    # newline='' lets the csv module find the record boundaries itself; str.splitlines()
    # would also split on characters such as U+2028 and drop newlines inside quoted fields
    reader = csv.reader(io.StringIO(contents, newline=''))
    header = next(reader, None)
    if header is None:
        return csv_dict
    for key in header:
        # Make every column exist even if there are no data rows
        csv_dict[key]  # noqa: B018
    for row in reader:
        if not row:
            # blank line
            continue
        for key, value in zip(header, row, strict=True):
            csv_dict[key].append(value)
    return csv_dict


def extract_csv_rows(csv_dict, names):
    """Return a new dict with only the columns of *csv_dict* listed in *names*.

    The column lists themselves are shared with *csv_dict*, not copied.
    """
    return {
        key: value
        for key, value in csv_dict.items() if key in names
    }


def iterate_csv_rows(csv_dict):
    """Yield one ``{column: value}`` dict per row of a column-oriented dict.

    The number of rows is taken from the first column; nothing is yielded for
    a dict without columns.
    """
    if not csv_dict:
        return
    size = len(next(iter(csv_dict.values())))
    for row_n in range(size):
        yield {
            key: column[row_n]
            for key, column in csv_dict.items()
        }


def get_ror_csv() -> str | None:
    """Download the ROR archive and return the decoded contents of its first ``.csv`` member.

    :return: The CSV text, or ``None`` if the download failed or the archive
             contains no CSV file.
    """
    ror_zip = fetch_ror()
    if ror_zip is None:
        click.echo('failed to download the ROR archive')
        return None

    click.echo('extracting contents from zip...')
    with zipfile.ZipFile(io.BytesIO(ror_zip.content), 'r') as zf:
        for name in zf.namelist():
            if name.endswith('.csv'):
                click.echo('done')
                return zf.read(name).decode('utf-8')

    click.echo('failed to extract contents from zip (no CSV file included)')
    return None


def sanitize_csv_name(name: str) -> str:
    """Return the part of *name* after its last ``':'`` (or all of it), stripped."""
    return name.rpartition(':')[2].strip()


def parse_csv_names(names: str) -> set[str]:
    """Split a ``';'``-separated string and return the set of sanitized names."""
    return {sanitize_csv_name(name) for name in names.split(';')}


def parse_csv_name_column(csv_dict, column):
    """Replace, in place, every value of *column* by its set of parsed names."""
    values = csv_dict[column]
    # Slice assignment keeps the list object, so dicts sharing the column see the change
    values[:] = [parse_csv_names(names) for names in values]


def parse_csv_id_column(csv_dict, column):
    """Replace, in place, every value of *column* by its part after the last ``'/'``."""
    values = csv_dict[column]
    values[:] = [ror_id.split('/')[-1] for ror_id in values]


def update_ror_affiliations(ror_affiliations) -> tuple:
    """Bring the ROR affiliations in the database in line with *ror_affiliations*.

    Existing, non-deleted affiliations carrying a ``ror_id`` in their ``meta`` are
    updated when a field differs, records without a counterpart are added, and
    affiliations no longer present in *ror_affiliations* are deleted. The session
    is flushed but not committed.

    :param ror_affiliations: Mapping of ROR ID to a dict of ``Affiliation`` fields.
    :return: ``(added, updated, deleted)``, each a dict mapping affiliation ID to
             the ``Affiliation``.
    """
    start = time.perf_counter()

    click.echo('updating affiliations...')
    affiliations_by_ror_id = {
        aff.meta['ror_id']: aff
        for aff in Affiliation.query.filter(Affiliation.meta['ror_id'].isnot(None), Affiliation.is_deleted.is_(False))
    }

    updated = {}
    new_affiliations = []
    # For each affiliation in the ROR records, try to match it with an existing one in the database.
    # If a match is found, update the old entry as needed; otherwise, add a new entry.
    for ror_id, aff_data in verbose_iterator(ror_affiliations.items(), len(ror_affiliations), print_total_time=True):
        aff = affiliations_by_ror_id.pop(ror_id, None)
        if aff is None:
            new_data = dict(aff_data)
            if 'alt_names' in new_data:
                # Same representation as the update path below: a sorted list
                new_data['alt_names'] = sorted(new_data['alt_names'])
            new_aff = Affiliation(**new_data, meta={'ror_id': ror_id})
            db.session.add(new_aff)
            new_affiliations.append(new_aff)
            continue

        changed = False
        for field, data in aff_data.items():
            if field == 'alt_names':
                # Order is irrelevant when deciding whether the alternative names changed
                new_names = set(data)
                if set(aff.alt_names) != new_names:
                    aff.alt_names = sorted(new_names)
                    changed = True
            elif getattr(aff, field) != data:
                setattr(aff, field, data)
                changed = True
        if changed:
            updated[aff.id] = aff

    # The remaining affiliations in the database that were not matched are not
    # present in the new ROR records, so they must be deleted
    deleted = {}
    for remaining_affiliation in affiliations_by_ror_id.values():
        deleted[remaining_affiliation.id] = remaining_affiliation
        db.session.delete(remaining_affiliation)

    click.echo('flushing database session...')
    db.session.flush()

    # Create the dict here since before flush() we don't have IDs
    added = {aff.id: aff for aff in new_affiliations}

    elapsed = time.perf_counter() - start
    click.echo(
        f'updated {len(updated)}, added {len(added)}, and deleted {len(deleted)} affiliations in {elapsed:.2f} seconds'
    )

    return added, updated, deleted


def make_vectorstore_data(affiliations: dict[int, Affiliation]) -> tuple[list[str], list[int]]:
    """Build the parallel text/ID lists stored in the vector store.

    Every affiliation contributes its name followed by each of its alternative
    names; the affiliation ID is repeated once per text.
    """
    names, metadatas = [], []
    for affiliation_id, affiliation in affiliations.items():
        texts = [affiliation.name, *affiliation.alt_names]
        names.extend(texts)
        metadatas.extend([affiliation_id] * len(texts))
    return names, metadatas


def do_ror_sync(csv_dict: dict, reset: bool, dry_run: bool, batch_size: int) -> None:
    """Synchronize affiliations and their vector store documents with a parsed ROR CSV.

    :param csv_dict: Column-oriented dict as returned by `parse_csv`. The ID and
                     name columns are parsed in place.
    :param reset: Delete the existing non-deleted ROR affiliations first.
    :param dry_run: Do everything except committing the session.
    :param batch_size: Batch size used when calculating embeddings.
    :raises ValueError: if a required column is missing or the CSV contains no
                        affiliation at all.
    """
    wanted_columns = {
        'id',  # id
        'names.types.ror_display',  # name
        'locations.geonames_details.country_code',  # country code
        'locations.geonames_details.name',  # city
        'names.types.acronym',  # alt name
        'names.types.alias',  # alt name
        'names.types.label',  # alt name
    }
    filtered = extract_csv_rows(csv_dict, wanted_columns)
    if missing := wanted_columns - filtered.keys():
        raise ValueError(f'ROR CSV is missing required columns: {", ".join(sorted(missing))}')

    parse_csv_id_column(filtered, 'id')
    for column in ('names.types.acronym', 'names.types.alias', 'names.types.label'):
        parse_csv_name_column(filtered, column)

    rows = len(filtered['id'])

    affiliations = {}
    for row in verbose_iterator(iterate_csv_rows(filtered), rows, print_total_time=True):
        alts = (row['names.types.acronym'] | row['names.types.alias'] | row['names.types.label']) - {''}
        affiliations[row['id']] = {
            'name': row['names.types.ror_display'],
            'city': row['locations.geonames_details.name'],
            'country_code': row['locations.geonames_details.country_code'],
            'alt_names': alts
        }

    # Syncing against nothing would delete every existing ROR affiliation
    if not affiliations:
        raise ValueError('the ROR CSV does not contain any affiliation')

    if reset:
        click.echo('deleting previous ROR affiliations from the database...')
        db.session.execute(
            sqlalchemy.delete(Affiliation)
            .where(Affiliation.meta['ror_id'].isnot(None), Affiliation.is_deleted.is_(False))
            .execution_options(synchronize_session='fetch')
        )

    added, updated, deleted = update_ror_affiliations(affiliations)
    added_texts, added_ids = make_vectorstore_data(added)
    updated_texts, updated_ids = make_vectorstore_data(updated)

    click.echo('updating vector store...')
    search_engine = PSQLVectorStoreAffiliationSearchProvider(batch_size=batch_size)
    search_engine.delete(list(deleted))
    search_engine.update(updated_texts, updated_ids, list(updated))
    search_engine.add(added_texts, added_ids)

    if not dry_run:
        click.echo('committing database changes...')
        db.session.commit()


class RORPlugin(IndicoPlugin):
    """ROR affiliations.

    Provides access to ROR affiliation information.
    """

    configurable = False

    def init(self):
        """Connect the search provider and CLI signal handlers."""
        super().init()
        self.connect(signals.affiliations.get_affiliation_search_providers, self.get_search_providers)
        self.connect(signals.plugin.cli, self._extend_indico_cli)

    def get_search_providers(self, sender, **kwargs):
        """Return the search provider class offered by this plugin."""
        return PSQLVectorStoreAffiliationSearchProvider

    def _extend_indico_cli(self, sender, **kwargs):
        """Build and return the ``ror`` CLI group."""
        @cli_group()
        def ror():
            """Manage ROR storage."""

        @ror.command()
        @click.option(
            '--output', type=click.Path(dir_okay=False, path_type=pathlib.Path),
            default='ror.csv', help='The output file name.'
        )
        def download(output: pathlib.Path) -> None:
            """Download affiliation metadata from ROR registry."""
            csv_contents = get_ror_csv()
            if csv_contents is None:
                raise click.ClickException('could not retrieve the ROR CSV')
            # The contents were decoded as UTF-8, so write them back the same way
            pathlib.Path(output).write_text(csv_contents, encoding='utf-8')

        @ror.command()
        @click.option(
            '--csv', 'csv_path', type=click.Path(dir_okay=False, path_type=pathlib.Path),
            help='Path to a pre-downloaded CSV.'
        )
        @click.option(
            '--reset', is_flag=True, help='Delete all previously existing affiliations from ROR before starting.'
        )
        @click.option('--dry-run', is_flag=True, help="Don't persist any changes to the database.")
        @click.option(
            '--batch-size', default=512, type=click.IntRange(min=1),
            help='Change the batch size when calculating embeddings.'
        )
        def sync(csv_path: pathlib.Path | None, reset: bool, dry_run: bool, batch_size: int) -> None:
            """Update the affiliations in the database from ROR registry."""
            if csv_path is None:
                csv_contents = get_ror_csv()
                if csv_contents is None:
                    raise click.ClickException('could not retrieve the ROR CSV')
            else:
                csv_contents = csv_path.read_text(encoding='utf-8')

            click.echo('parsing ROR affiliations CSV...')
            csv_dict = parse_csv(csv_contents)
            do_ror_sync(csv_dict, reset, dry_run, batch_size)
            click.echo('done')

        @ror.command()
        @click.argument(
            'output', type=click.Path(dir_okay=False, path_type=pathlib.Path),
        )
        def match(output: pathlib.Path) -> None:
            """Match "free-text" affiliations with affiliations stored in the database."""
            search_engine = PSQLVectorStoreAffiliationSearchProvider()

            click.echo('loading affiliations...')
            affiliations: set[str] = set()
            for cls in AFFILIATION_BACKREF_CLASSES:
                affiliations.update(
                    str(cwa.affiliation)
                    for cwa in cls.query.filter(cls.affiliation.is_not(None), cls.affiliation != '').all()  # noqa: PLC1901
                )

            click.echo('matching...')
            results = []
            # Sorted so that the output file is reproducible
            for affiliation in sorted(affiliations):
                matches = search_engine.match(affiliation, 1)
                if matches:
                    results.append((affiliation, matches[0]))

            click.echo('saving results...')
            with output.open('w', newline='', encoding='utf-8') as csvf:
                writer = csv.writer(csvf)
                writer.writerow(CSV_MATCHES_HEADER)
                writer.writerows(
                    (text, found.text, found.affiliation_id, found.score)
                    for text, found in results
                )

            click.echo('done')

        @ror.command()
        @click.argument(
            'matches', type=click.Path(dir_okay=False, exists=True, path_type=pathlib.Path)
        )
        @click.option('--force', '-f', is_flag=True, help='By-pass the header check.')
        @click.option('--keep-original', '-k', is_flag=True, help='Keep original affiliation text after updating.')
        @click.option('--dry-run', is_flag=True, help="Don't persist any changes to the database.")
        def apply(matches: pathlib.Path, force: bool, dry_run: bool, keep_original: bool) -> None:
            """Apply a set of previously found matches."""
            click.echo('reading matches from file...')
            with matches.open(newline='', encoding='utf-8') as csvf:
                reader = csv.reader(csvf)
                # An empty file has no header row at all
                headers = tuple(next(reader, ()))

                # TODO: actually fetch the right columns to work with differently formatted CSVs
                if not force and headers != CSV_MATCHES_HEADER:
                    click.secho('error: CSV header mismatch; re-run with --force to proceed', fg='red')
                    return

                try:
                    correspondence = {
                        text: int(match_id) for text, _, match_id, *_ in reader
                    }
                except ValueError:
                    click.secho(f'error: invalid match ID encountered while parsing {str(matches)!r}', fg='red')
                    return

            click.echo('retrieving matches from the database...')
            affiliation_texts = list(correspondence)
            objects_to_update = [
                object_with_affiliation
                for cls in AFFILIATION_BACKREF_CLASSES
                for object_with_affiliation in cls.query.filter(cls.affiliation.in_(affiliation_texts))
            ]

            click.echo('updating...')
            updates = 0
            for object_to_update in objects_to_update:
                match_result = correspondence.get(object_to_update.affiliation)
                if match_result is None:
                    click.secho(
                        f"warning: couldn't get match for entry {object_to_update} "
                        f"with affiliation {object_to_update.affiliation!r}",
                        fg='yellow'
                    )
                    continue

                updates += 1
                object_to_update.affiliation_id = match_result
                if not keep_original:
                    object_to_update.affiliation = None

            if not dry_run:
                db.session.commit()

            click.echo(f'updated {updates}/{len(objects_to_update)} entries')

        return ror


# ---- diff: new file ror/indico_ror/task.py ----
# This file is part of the Indico plugins.
# Copyright (C) 2002 - 2026 CERN
#
# The Indico plugins are free software; you can redistribute
# them and/or modify them under the terms of the MIT License;
# see the LICENSE file for more details.

from celery.schedules import crontab

from indico.core.celery import celery

from indico_ror.plugin import do_ror_sync, get_ror_csv, parse_csv


# NOTE(review): crontab(minute=0) schedules the task at minute 0 of every hour, and each
# run downloads the complete ROR archive -- confirm that an hourly sync is intended.
@celery.periodic_task(name='sync_ror', run_every=crontab(minute=0))
def sync_ror():
    """Download the current ROR CSV and synchronize the database with it.

    :raises RuntimeError: if no CSV could be retrieved.
    """
    csv_contents = get_ror_csv()
    # get_ror_csv() returns None when the archive contains no CSV file
    if csv_contents is None:
        raise RuntimeError('could not retrieve the ROR CSV')
    csv_dict = parse_csv(csv_contents)
    do_ror_sync(csv_dict, reset=False, dry_run=False, batch_size=512)


# ---- diff: new file ror/pyproject.toml ----
# [project]
# name = 'indico-plugin-ror'
# description = 'ROR Affiliations plugin for Indico'
# readme = 'README.md'
# version = '3.3.13-dev'
# license = 'MIT'
# authors = [{ name = 'Indico Team', email = 'indico-team@cern.ch' }]
# classifiers = [
#     'Environment :: Plugins',
#     'Environment :: Web Environment',
#     'License :: OSI Approved :: MIT License',
#     'Programming Language :: Python :: 3.12',
# ]
# requires-python = '>=3.12.2, <3.13'
# dependencies = ['indico>=3.3.13.dev0', 'langchain_core', 'langchain_ollama', 'pgvector']
#
# [project.urls]
# GitHub = 'https://github.com/indico/indico-plugins'
#
# [project.entry-points.'indico.plugins']
# ror = 'indico_ror.plugin:RORPlugin'
#
# [build-system]
# requires = ['hatchling==1.28.0']
# build-backend = 'hatchling.build'
#
# [tool.hatch.build]
# packages = ['indico_ror']
# exclude = [
#     '*.no-header',
#     '.keep',
#     # exclude original client sources (they are all included in source maps anyway)
#     'indico_*/client/',
#     # no need for tests outside development
#     'test_snapshots/',
#     'tests/',
#     '*_test.py',
# ]
#
# [tool.hatch.build.targets.sdist.hooks.custom]
# path = '../hatch_build.py'
# dependencies = ['babel==2.18.0']

# ---- diff: new file ror/pytest.ini ----
+[pytest] +; more verbose summary (include skip/fail/error/warning) +addopts = -rsfEw +; only check for tests in suffixed files +python_files = *_test.py +; we need the ror plugin to be loaded +indico_plugins = ror +; use psql's vector extension for vector storage +indico_pg_extensions = ["vector"] \ No newline at end of file