From 941e570ce6095aaafe63adeaa6806b93958c9643 Mon Sep 17 00:00:00 2001 From: Harrison Date: Fri, 5 Jun 2026 11:29:57 +0100 Subject: [PATCH 1/2] refactor: Use pymetkit's PARAMDB --- pyproject.toml | 1 + src/anemoi/utils/grib.py | 222 +----- src/anemoi/utils/settings_schema/paramdb.py | 9 +- .../settings_schema/settings.defaults.toml | 15 +- tests/parameter_metadata_test.yaml | 83 ++ tests/test_grib.py | 713 ++++-------------- 6 files changed, 248 insertions(+), 795 deletions(-) create mode 100644 tests/parameter_metadata_test.yaml diff --git a/pyproject.toml b/pyproject.toml index 1392f0b..17ce05f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "numpy", "pydantic>=2.9", "pydantic-settings>=2.14.1", + "pymetkit", "python-dateutil", "pyyaml", "rich", diff --git a/src/anemoi/utils/grib.py b/src/anemoi/utils/grib.py index 4c3d430..cd6d4f1 100644 --- a/src/anemoi/utils/grib.py +++ b/src/anemoi/utils/grib.py @@ -1,4 +1,4 @@ -# (C) Copyright 2024 Anemoi contributors. +# (C) Copyright 2024- Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. @@ -10,212 +10,26 @@ """Utilities for working with GRIB parameters. +Aliases Pymetkit's ParamDB to provide a consistent interface for GRIB parameter lookups across Anemoi, with configuration controlled by Anemoi settings. +Use `PARAMDB` for direct access. + See https://codes.ecmwf.int/grib/param-db/ for more information. """ -import json -import logging -import os -import re -import warnings -from functools import cache +from datetime import timedelta -import requests +from pymetkit import ParamDB -from .caching import cached from .settings import AnemoiSettings -LOG = logging.getLogger(__name__) - SETTINGS = AnemoiSettings() """Anemoi settings, loaded on module import.""" - -@cached(collection="grib", expires=SETTINGS.paramdb.cache_length * 24 * 3600) -def _units() -> dict[str, str]: - """Fetch and cache GRIB parameter units. - - Returns - ------- - dict - A dictionary mapping unit ids to their names. - """ - r = requests.get("https://codes.ecmwf.int/parameter-database/api/v1/unit/") - r.raise_for_status() - units = r.json() - return {str(u["id"]): u["name"] for u in units} - - -@cache -def _get_local_db(local_db: str) -> list[dict[str, str | int | list[str]]]: - """Open the local GRIB parameter database. - - Parameters - ---------- - local_db : str - Path to the local cache file. - - Returns - ------- - list[dict[str, str | int | list[str]]] - A list of dictionaries containing parameter details. - - Raises - ------ - FileNotFoundError - If the local db file is not found. - """ - - if not os.path.exists(local_db): - raise FileNotFoundError(f"Local cache file {local_db} not found.") - - return json.load(open(local_db, "r")) - - -@cache -def _local_search_param(name: str) -> list[dict[str, str | int | list[str]]]: - """Search for a GRIB parameter by name in the local cache. - - This is used to avoid making API calls when the local cache is available. - - Parameters - ---------- - name : str - Parameter name to search for. - - Returns - ------- - list - A list of dictionaries containing parameter details. - - Raises - ------ - KeyError - If no parameter is found. - """ - local_cache = SETTINGS.paramdb.local_cache - assert local_cache is not None, "Local cache is not configured." - - local_param_db = _get_local_db(local_cache) - - matched_params = [param for param in local_param_db if param["shortname"] == name] - if matched_params: - return matched_params - - raise KeyError(f"{name} not found in local cache.") - - -@cached(collection="grib", expires=SETTINGS.paramdb.cache_length * 24 * 3600) -def _online_search_param(name: str, **filters) -> list[dict[str, str | int | list[str]]]: - """Search for a GRIB parameter by name using the online API. - - Parameters - ---------- - name : str - Parameter name to search for. - filters : Any - Additional filters to disambiguate parameters with the same shortname (e.g. origin, encoding, table, discipline, category). - - Returns - ------- - list - A list of dictionaries containing parameter details. - """ - r = requests.get( - f"https://codes.ecmwf.int/parameter-database/api/v1/param/?search=^{name}$®ex=true{''.join(f'&{k}={v}' for k, v in filters.items())}" - ) - r.raise_for_status() - return r.json() - - -def _search_param(name: str, **filters) -> dict[str, str | int | list[str]]: - """Search for a GRIB parameter by name. - - Parameters - ---------- - name : str - Parameter name to search for. - filters : Any - Additional filters to disambiguate parameters with the same shortname (e.g. origin, encoding, table, discipline, category). - - Returns - ------- - dict - A dictionary containing parameter details. - - Raises - ------ - KeyError - If no parameter is found. - """ - if "origin" in filters and isinstance(filters["origin"], str): - filters["origin"] = origin(filters["origin"])["id"] - - name = re.escape(name) - - if SETTINGS.paramdb.local_cache is not None: - if filters: - warnings.warn("Filters are ignored when using local cache.") - results = _local_search_param(name) - else: - results = _online_search_param(name, **filters) - - if len(results) == 0: - raise KeyError(f"{name} not found in parameter database.") - - if len(results) > 1: - names = [f"{r.get('id')} ({r.get('name')})" for r in results] - dissemination = [r for r in results if "dissemination" in r.get("access_ids", [])] # type: ignore[reportOperatorIssue] - if len(dissemination) == 1: - return dissemination[0] - - warnings.warn(f"{name} is ambiguous: {', '.join(names)}.") - if "origin" not in filters and SETTINGS.paramdb.local_cache is None: - warnings.warn(f"Applying origin='{SETTINGS.paramdb.default_origin}' in an attempt to disambiguate {name}.") - try: - filtered_param = _search_param(name, **{**filters, "origin": SETTINGS.paramdb.default_origin}) - warnings.warn( - f"Disambiguated {name} to id: {filtered_param['id']} ({filtered_param.get('name', 'unknown')})." - ) - return filtered_param - except KeyError: - pass - - warnings.warn(f"Failed to disambiguate {name}'. Returning the first match: {names[0]}.") - results = sorted(results, key=lambda x: x["id"]) - - return results[0] - - -@cached(collection="grib", expires=SETTINGS.paramdb.cache_length * 24 * 3600) -def origin(name: str) -> dict[str, str | int]: - """Search for an id of an origin by name. - - Parameters - ---------- - name : str - Origin name to search for. - - Returns - ------- - dict - A dictionary containing origin details. - - Raises - ------ - KeyError - If no origin is found. - """ - name = re.escape(name) - r = requests.get("https://codes.ecmwf.int/parameter-database/api/v1/origin/") - r.raise_for_status() - results = r.json() - - for result in results: - if result["abbreviation"] == name: - return result - - raise KeyError(f"{name} not found in origin database.") +PARAMDB = ParamDB( + mode=SETTINGS.paramdb.mode, + cache_ttl=timedelta(days=SETTINGS.paramdb.cache_length), + yaml_path=SETTINGS.paramdb.local_data, +) def shortname_to_paramid(shortname: str, **filters) -> int: @@ -226,7 +40,7 @@ def shortname_to_paramid(shortname: str, **filters) -> int: shortname : str Parameter shortname. filters : Any - Additional filters to disambiguate parameters with the same shortname (e.g. origin, encoding, table, discipline, category). + Additional filters to disambiguate parameters with the same shortname (e.g. origin, access, table). Returns ------- @@ -236,7 +50,7 @@ def shortname_to_paramid(shortname: str, **filters) -> int: >>> shortname_to_paramid("2t") 167 """ - return _search_param(shortname, **filters)["id"] # type: ignore[reportReturnType] + return PARAMDB.shortname_to_param_id(shortname, **filters) def paramid_to_shortname(paramid: int, **filters) -> str: @@ -247,7 +61,7 @@ def paramid_to_shortname(paramid: int, **filters) -> str: paramid : int Parameter id. filters : Any - Additional filters to disambiguate parameters with the same shortname (e.g. origin, encoding, table, discipline, category). + Additional filters to disambiguate parameters with the same shortname (e.g. origin, access, table, discipline, category). Returns ------- @@ -257,7 +71,7 @@ def paramid_to_shortname(paramid: int, **filters) -> str: >>> paramid_to_shortname(167) '2t' """ - return _search_param(str(paramid), **filters)["shortname"] # type: ignore[reportReturnType] + return PARAMDB.param_id_to_shortname(paramid, **filters) # type: ignore[reportReturnType] def units(param: int | str) -> str: @@ -273,12 +87,10 @@ def units(param: int | str) -> str: str Parameter unit. - >>> unit(167) + >>> units(167) 'K' """ - - unit_id = str(_search_param(str(param))["unit_id"]) - return _units()[unit_id] + return PARAMDB.get_units(param) def must_be_positive(param: int | str) -> bool: diff --git a/src/anemoi/utils/settings_schema/paramdb.py b/src/anemoi/utils/settings_schema/paramdb.py index 3fa0c80..46a2839 100644 --- a/src/anemoi/utils/settings_schema/paramdb.py +++ b/src/anemoi/utils/settings_schema/paramdb.py @@ -7,6 +7,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from typing import Literal from typing import Optional from pydantic import FilePath @@ -22,11 +23,11 @@ class ParamDBConfig(AnemoiBaseSettingsSchema): and cache lifetime). """ - default_origin: str = "ecmf" - """Default origin to use when disambiguating parameters with the same shortname.""" + mode: Literal["online", "offline"] = "offline" + """Mode for GRIB parameter lookups. 'online' uses the ECMWF API, 'offline' uses a local cache file.""" cache_length: int = 30 """Cache length in days for GRIB parameter lookups.""" - local_cache: Optional[FilePath] = None - """Path to a local JSON cache file for GRIB parameters. If set, used instead of the online API.""" + local_data: Optional[FilePath] = None + """Path to a local YAML cache file for GRIB parameters. If set, used instead of the online API.""" diff --git a/src/anemoi/utils/settings_schema/settings.defaults.toml b/src/anemoi/utils/settings_schema/settings.defaults.toml index 6d63f4d..4cd0f83 100644 --- a/src/anemoi/utils/settings_schema/settings.defaults.toml +++ b/src/anemoi/utils/settings_schema/settings.defaults.toml @@ -111,17 +111,16 @@ ignore_naming_conventions = false # ============================================================================= [paramdb] -# Default origin to use when disambiguating parameters with the same shortname. -# Common values: "ecmf", "destine". -default_origin = "ecmf" +# Mode for GRIB parameter lookups. 'online' uses the ECMWF API, 'offline' uses a local cache file. +mode = "offline" -# Cache length in days for GRIB parameter lookups from the online API. +# Cache length in days for GRIB parameter lookups. cache_length = 30 -# Path to a local JSON cache file for GRIB parameters. -# If set, this will be used instead of making online API calls. -# Leave empty / unset to use the online API. -# local_cache = "/path/to/parameters.json" +# Path to a local YAML cache file for GRIB parameters. +# If set, used instead of the offline metkit cache. +# Leave empty / unset to use the online / offline ECMWF API. +# local_data = "/path/to/parameters.yaml" # ============================================================================= diff --git a/tests/parameter_metadata_test.yaml b/tests/parameter_metadata_test.yaml new file mode 100644 index 0000000..277636f --- /dev/null +++ b/tests/parameter_metadata_test.yaml @@ -0,0 +1,83 @@ +- id: 167 + shortname: 2t + longname: 2 metre temperature + units: K + origin_ids: [98] + access_ids: [dissemination] + +- id: 228 + shortname: tp + longname: Total precipitation + units: m + origin_ids: [98] + access_ids: [dissemination] + +- id: 130 + shortname: t + longname: Temperature + units: K + origin_ids: [98] + access_ids: [dissemination] + +- id: 500014 + shortname: t + longname: Temperature + units: K + origin_ids: [7] + access_ids: [] + +- id: 134 + shortname: sp + longname: Surface pressure + units: Pa + origin_ids: [98] + access_ids: [dissemination] + +- id: 151 + shortname: msl + longname: Mean sea level pressure + units: Pa + origin_ids: [98] + access_ids: [dissemination] + +- id: 165 + shortname: 10u + longname: 10 metre U wind component + units: m s**-1 + origin_ids: [98] + access_ids: [dissemination] + +- id: 166 + shortname: 10v + longname: 10 metre V wind component + units: m s**-1 + origin_ids: [98] + access_ids: [dissemination] + +- id: 31 + shortname: ci + longname: Sea ice area fraction + units: (0 - 1) + origin_ids: [98] + access_ids: [dissemination] + +- id: 260048 + shortname: sf + longname: Snowfall + units: m of water equivalent + origin_ids: [98] + access_ids: [dissemination] + +- id: 59 + shortname: cape + longname: Convective available potential energy + units: J kg**-1 + origin_ids: [98] + access_ids: [dissemination] + +- id: 75 + shortname: crwc + longname: Specific rain water content + units: kg kg**-1 + origin_ids: [98] + access_ids: [dissemination] diff --git a/tests/test_grib.py b/tests/test_grib.py index d047418..47763de 100644 --- a/tests/test_grib.py +++ b/tests/test_grib.py @@ -7,630 +7,187 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -"""Tests for grib.py — GRIB parameter lookup utilities. +"""Tests for anemoi.utils.grib -- API consistency. -Covers the filtering/disambiguation logic added in the advanced-shortname-filtering -feature, including origin-based filtering and default origin fallback. +Each test is parametrised over two ParamDB backends: + +- **local**: our own test YAML fixture -- always available. +- **bundled**: pymetkit's shipped parameter_metadata.yaml. + Currently xfail because the YAML is not packaged with the dev install. + Will start passing (and enforcing) once pymetkit ships it. """ from __future__ import annotations -import json -from contextlib import contextmanager +import sys from pathlib import Path -from unittest.mock import MagicMock from unittest.mock import patch import pytest +from pymetkit import ParamDB # --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _mock_response(json_data, status_code=200): - """Create a mock requests.Response.""" - resp = MagicMock() - resp.status_code = status_code - resp.json.return_value = json_data - resp.raise_for_status.return_value = None - return resp - - -# Sample API payloads ------------------------------------------------------- - -PARAM_2T_ECMF = { - "id": 167, - "shortname": "2t", - "name": "2 metre temperature", - "unit_id": 1, - "access_ids": ["dissemination"], -} - -PARAM_2T_OTHER = { - "id": 500167, - "shortname": "2t", - "name": "2 metre temperature (other centre)", - "unit_id": 1, - "access_ids": [], -} - -PARAM_AMBIGUOUS_A = { - "id": 200, - "shortname": "xx", - "name": "Ambiguous param A", - "unit_id": 1, - "access_ids": [], -} - -PARAM_AMBIGUOUS_B = { - "id": 300, - "shortname": "xx", - "name": "Ambiguous param B", - "unit_id": 1, - "access_ids": [], -} - -ORIGIN_ECMF = { - "id": 98, - "abbreviation": "ecmf", - "name": "European Centre for Medium-Range Weather Forecasts", -} -ORIGIN_DESTINE = { - "id": -60, - "abbreviation": "destine", - "name": "Destination Earth Local parameter definitions", -} - -ALL_ORIGINS = [ORIGIN_ECMF, ORIGIN_DESTINE] - - -# --------------------------------------------------------------------------- -# Fixtures — patch caching so every call goes straight through +# Fixtures # --------------------------------------------------------------------------- +FIXTURE_YAML = Path(__file__).parent / "parameter_metadata_test.yaml" -@pytest.fixture(autouse=True) -def _bypass_cache(monkeypatch): - """Replace the @cached decorator with a no-op so tests hit the mocked API.""" - import anemoi.utils.caching as caching_mod - - def _noop_cached(*args, **kwargs): - def decorator(func): - return func +_bundled_xfail = pytest.param( + "bundled", + marks=pytest.mark.xfail( + raises=FileNotFoundError, + reason="pymetkit bundled YAML not installed in this environment", + strict=False, + ), +) - return decorator - monkeypatch.setattr(caching_mod, "cached", _noop_cached) +@pytest.fixture(params=["local", _bundled_xfail]) +def grib(request, monkeypatch): + """Import anemoi.utils.grib and swap PARAMDB to the requested backend. - # Re-import the module so the patched decorator is applied - import importlib + For 'local', the module is imported with settings pointing at our test + YAML. For 'bundled', PARAMDB is replaced with a ParamDB using pymetkit's + default YAML lookup (which may not exist). + """ + from anemoi.utils.settings import AnemoiSettings - import anemoi.utils.grib as grib_mod - - importlib.reload(grib_mod) - yield - # Reload again to restore original state for other test modules - importlib.reload(grib_mod) + yaml_path = FIXTURE_YAML.resolve() + settings = AnemoiSettings() + settings.paramdb.local_data = yaml_path + # Ensure a clean import of the grib module with patched settings + cached = sys.modules.pop("anemoi.utils.grib", None) + try: + with patch("anemoi.utils.settings.AnemoiSettings", return_value=settings): + import anemoi.utils.grib as grib_mod -def _grib(): - """Import the reloaded grib module.""" - import anemoi.utils.grib as grib_mod + if request.param == "bundled": + bundled = ParamDB(mode="offline") + monkeypatch.setattr(grib_mod, "PARAMDB", bundled) - return grib_mod + yield grib_mod + finally: + sys.modules.pop("anemoi.utils.grib", None) + if cached is not None: + sys.modules["anemoi.utils.grib"] = cached # --------------------------------------------------------------------------- -# Basic lookup tests +# shortname_to_paramid # --------------------------------------------------------------------------- class TestShortNameToParamId: - """Tests for shortname_to_paramid.""" - - @patch("anemoi.utils.grib.requests.get") - def test_single_result(self, mock_get): - """A unique match returns the parameter id directly.""" - mock_get.return_value = _mock_response([PARAM_2T_ECMF]) - - grib = _grib() - assert grib.shortname_to_paramid("2t") == 167 - - @patch("anemoi.utils.grib.requests.get") - def test_no_result_raises_key_error(self, mock_get): - """No matches should raise KeyError.""" - mock_get.return_value = _mock_response([]) - - grib = _grib() + @pytest.mark.parametrize( + "shortname,expected_id", + [ + ("2t", 167), + ("tp", 228), + ("sp", 134), + ("msl", 151), + ("ci", 31), + ], + ) + def test_known_params(self, grib, shortname, expected_id): + result = grib.shortname_to_paramid(shortname) + assert result == expected_id + assert isinstance(result, int) + + def test_unknown_raises_key_error(self, grib): with pytest.raises(KeyError): - grib.shortname_to_paramid("nonexistent") - - -class TestParamIdToShortName: - """Tests for paramid_to_shortname.""" - - @patch("anemoi.utils.grib.requests.get") - def test_single_result(self, mock_get): - """A unique match returns the shortname.""" - mock_get.return_value = _mock_response([PARAM_2T_ECMF]) - - grib = _grib() - assert grib.paramid_to_shortname(167) == "2t" - - -# --------------------------------------------------------------------------- -# Disambiguation tests -# --------------------------------------------------------------------------- - - -class TestDisambiguationByDissemination: - """When multiple results exist, a single 'dissemination' entry wins.""" - - @patch("anemoi.utils.grib.requests.get") - def test_dissemination_filter_selects_correct_entry(self, mock_get): - """If exactly one result has 'dissemination' access, it is returned.""" - # The correct entry (PARAM_2T_ECMF) is second in the response list - mock_get.return_value = _mock_response([PARAM_2T_OTHER, PARAM_2T_ECMF]) - - grib = _grib() - result = grib.shortname_to_paramid("2t") - assert result == 167, ( - "Should pick the entry with 'dissemination' access (id=167), " "not the first entry (id=500167)" - ) - - -class TestDisambiguationByDefaultOrigin: - """When dissemination doesn't resolve ambiguity, default_origin is applied.""" - - @patch("anemoi.utils.grib.requests.get") - def test_default_origin_applied_when_ambiguous(self, mock_get): - """Two results, neither with dissemination → falls back to default_origin filter. - - The second API call (with origin filter) returns the correct single match. - """ - - # First call: ambiguous, two results without dissemination - # Second call (with origin): returns only the correct one - # Third call: origin lookup - def side_effect(url, **kwargs): - if "origin/" in url: - return _mock_response(ALL_ORIGINS) - if "origin=" in url: - # Filtered call returns the correct entry only - return _mock_response([PARAM_AMBIGUOUS_A]) - # Unfiltered call returns ambiguous results - return _mock_response([PARAM_AMBIGUOUS_B, PARAM_AMBIGUOUS_A]) - - mock_get.side_effect = side_effect - - grib = _grib() - with pytest.warns(UserWarning): - result = grib.shortname_to_paramid("xx") - # The default_origin filter resolves to PARAM_AMBIGUOUS_A (id=200) - assert result == 200 - - @patch("anemoi.utils.grib.requests.get") - def test_default_origin_fallback_to_sorted_first(self, mock_get): - """If default_origin also fails (KeyError), returns sorted first result.""" - - def side_effect(url, **kwargs): - if "origin/" in url: - return _mock_response(ALL_ORIGINS) - if "origin=" in url: - # Origin filter yields nothing - return _mock_response([]) - # Unfiltered: ambiguous - return _mock_response([PARAM_AMBIGUOUS_B, PARAM_AMBIGUOUS_A]) - - mock_get.side_effect = side_effect - - grib = _grib() - # Should fall back to sorted-first → id 200 (PARAM_AMBIGUOUS_A) - with pytest.warns(UserWarning): - result = grib.shortname_to_paramid("xx") - assert result == 200, "Should return the entry with the lowest id after sorting" + grib.shortname_to_paramid("nonexistent_xyz") # --------------------------------------------------------------------------- -# Explicit filter tests +# paramid_to_shortname # --------------------------------------------------------------------------- -class TestExplicitOriginFilter: - """Passing origin= explicitly should change the returned parameter.""" - - @patch("anemoi.utils.grib.requests.get") - def test_origin_filter_changes_returned_key(self, mock_get): - """Filtering by origin should return the correct entry even when it is - the second item in the unfiltered response. - - This verifies that filtering by origin (and by extension the default - origin) correctly selects a different entry than a naive 'first result' - approach would. - """ - - def side_effect(url, **kwargs): - if "origin/" in url: - return _mock_response(ALL_ORIGINS) - if "origin=98" in url: - # When filtered by ecmf origin (id=98), only the ECMF entry - return _mock_response([PARAM_2T_ECMF]) - if "origin=-60" in url: - # When filtered by destine origin (id=-60), only the other entry - return _mock_response([PARAM_2T_OTHER]) - # Unfiltered - return _mock_response([PARAM_2T_OTHER, PARAM_2T_ECMF]) - - mock_get.side_effect = side_effect - - grib = _grib() - - # Without filter the correct entry (167) is the SECOND in the response. - # The dissemination filter would pick it, but with an explicit origin - # the API is called with origin= directly and returns only that entry. - result_ecmf = grib.shortname_to_paramid("2t") - assert ( - result_ecmf == 167 - ), "Filtering by nothing should return the ECMF entry with id 167, as it is set by default" - - result_ecmf = grib.shortname_to_paramid("2t", origin="ecmf") - assert result_ecmf == 167 - - result_destine = grib.shortname_to_paramid("2t", origin="destine") - assert result_destine == 500167 - - @patch("anemoi.utils.grib.requests.get") - def test_origin_filter_not_reapplied_when_already_provided(self, mock_get): - """If the caller already passes origin=, the default_origin fallback - should NOT be applied on top of it. - """ - - def side_effect(url, **kwargs): - if "origin/" in url: - return _mock_response(ALL_ORIGINS) - # Even with origin filter, still ambiguous (edge case) - return _mock_response([PARAM_AMBIGUOUS_B, PARAM_AMBIGUOUS_A]) - - mock_get.side_effect = side_effect - - grib = _grib() - with pytest.warns(UserWarning): - result = grib.shortname_to_paramid("xx", origin="destine") - # Should NOT recurse with default_origin because origin was already - # provided. Falls through to sorted first → id 200. - assert result == 200 - - @patch("anemoi.utils.grib.requests.get") - def test_paramid_to_shortname_with_origin_filter(self, mock_get): - """paramid_to_shortname also forwards filters correctly.""" - - def side_effect(url, **kwargs): - if "origin/" in url: - return _mock_response(ALL_ORIGINS) - if "origin=" in url: - return _mock_response([PARAM_2T_ECMF]) - return _mock_response([PARAM_2T_OTHER, PARAM_2T_ECMF]) - - mock_get.side_effect = side_effect - - grib = _grib() - result = grib.paramid_to_shortname(167, origin="ecmf") - assert result == "2t" - - -# --------------------------------------------------------------------------- -# _search_origin tests -# --------------------------------------------------------------------------- - - -class TestSearchOrigin: - """Tests for the _search_origin helper.""" - - @patch("anemoi.utils.grib.requests.get") - def test_known_origin(self, mock_get): - mock_get.return_value = _mock_response(ALL_ORIGINS) - - grib = _grib() - result = grib.origin("ecmf") - assert result["id"] == 98 - assert result["abbreviation"] == "ecmf" - - @patch("anemoi.utils.grib.requests.get") - def test_unknown_origin_raises_key_error(self, mock_get): - mock_get.return_value = _mock_response(ALL_ORIGINS) - - grib = _grib() +class TestParamIdToShortName: + @pytest.mark.parametrize( + "paramid,expected_name", + [ + (167, "2t"), + (228, "tp"), + (134, "sp"), + (151, "msl"), + (31, "ci"), + ], + ) + def test_known_params(self, grib, paramid, expected_name): + result = grib.paramid_to_shortname(paramid) + assert result == expected_name + assert isinstance(result, str) + + def test_unknown_raises_key_error(self, grib): with pytest.raises(KeyError): - grib.origin("zzzz") + grib.paramid_to_shortname(9999999) # --------------------------------------------------------------------------- -# units / must_be_positive tests +# Roundtrip consistency # --------------------------------------------------------------------------- -class TestUnits: - """Tests for the units() and must_be_positive() helpers.""" - - @patch("anemoi.utils.grib.requests.get") - def test_units_returns_correct_string(self, mock_get): - unit_data = [{"id": 1, "name": "K"}, {"id": 2, "name": "m"}] - - def side_effect(url, **kwargs): - if "unit/" in url: - return _mock_response(unit_data) - return _mock_response([PARAM_2T_ECMF]) - - mock_get.side_effect = side_effect - - grib = _grib() - assert grib.units("2t") == "K" - - @patch("anemoi.utils.grib.requests.get") - def test_must_be_positive(self, mock_get): - unit_data = [{"id": 1, "name": "K"}, {"id": 2, "name": "m"}] - param_tp = {**PARAM_2T_ECMF, "id": 228, "shortname": "tp", "unit_id": 2} - - def side_effect(url, **kwargs): - if "unit/" in url: - return _mock_response(unit_data) - return _mock_response([param_tp]) - - mock_get.side_effect = side_effect - - grib = _grib() - assert grib.must_be_positive("tp") is True +class TestRoundtripConsistency: + @pytest.mark.parametrize("shortname", ["2t", "tp", "sp", "msl", "10u", "10v"]) + def test_shortname_roundtrip(self, grib, shortname): + paramid = grib.shortname_to_paramid(shortname) + assert grib.paramid_to_shortname(paramid) == shortname - @patch("anemoi.utils.grib.requests.get") - def test_must_be_positive_false(self, mock_get): - unit_data = [{"id": 1, "name": "K"}] - - def side_effect(url, **kwargs): - if "unit/" in url: - return _mock_response(unit_data) - return _mock_response([PARAM_2T_ECMF]) - - mock_get.side_effect = side_effect - - grib = _grib() - assert grib.must_be_positive("2t") is False + @pytest.mark.parametrize("paramid", [167, 228, 134, 151, 165, 166]) + def test_paramid_roundtrip(self, grib, paramid): + shortname = grib.paramid_to_shortname(paramid) + assert grib.shortname_to_paramid(shortname) == paramid # --------------------------------------------------------------------------- -# URL construction tests +# units # --------------------------------------------------------------------------- -class TestUrlConstruction: - """Verify that filter kwargs are appended to the API URL.""" - - @patch("anemoi.utils.grib.requests.get") - def test_filters_appended_to_url(self, mock_get): - mock_get.return_value = _mock_response([PARAM_2T_ECMF]) - - grib = _grib() - grib.shortname_to_paramid("2t", encoding=1, table=128) - - call_url = mock_get.call_args[0][0] - assert "encoding=1" in call_url - assert "table=128" in call_url - - @patch("anemoi.utils.grib.requests.get") - def test_origin_string_resolved_to_id_in_url(self, mock_get): - """When origin is a string, it should be resolved to a numeric id - via _search_origin before being appended to the URL. - """ - - def side_effect(url, **kwargs): - if "origin/" in url: - return _mock_response(ALL_ORIGINS) - return _mock_response([PARAM_2T_ECMF]) - - mock_get.side_effect = side_effect - - grib = _grib() - grib.shortname_to_paramid("2t", origin="ecmf") - - # Find the param search call (not the origin lookup) - param_calls = [c for c in mock_get.call_args_list if "param/" in str(c)] - assert len(param_calls) >= 1 - param_url = param_calls[0][0][0] - assert "origin=98" in param_url, f"Expected origin to be resolved to numeric id 98, got URL: {param_url}" - - -# --------------------------------------------------------------------------- -# Local cache tests -# --------------------------------------------------------------------------- - -# Minimal mock data modelled after parameters.json structure -LOCAL_CACHE_DATA = [ - { - "id": 1, - "name": "Stream function", - "shortname": "strf", - "unit_id": 1, - "encoding_ids": ["grib1", "grib2"], - "access_ids": ["dissemination"], - "published": True, - "pending": False, - "retired": False, - }, - { - "id": 2, - "name": "Velocity potential", - "shortname": "vp", - "unit_id": 1, - "encoding_ids": ["grib1", "grib2"], - "access_ids": ["dissemination"], - "published": True, - "pending": False, - "retired": False, - }, - { - "id": 10, - "name": "Wind speed", - "shortname": "ws", - "unit_id": 5, - "encoding_ids": ["grib1", "grib2"], - "access_ids": ["dissemination"], - "published": True, - "pending": False, - "retired": False, - }, - { - "id": 31, - "name": "Sea ice area fraction", - "shortname": "ci", - "unit_id": 3, - "encoding_ids": ["grib1", "grib2"], - "access_ids": ["dissemination"], - "published": True, - "pending": False, - "retired": False, - }, - { - "id": 34, - "name": "Sea surface temperature", - "shortname": "sst", - "unit_id": 2, - "encoding_ids": ["grib1", "grib2"], - "access_ids": ["dissemination"], - "published": True, - "pending": False, - "retired": False, - }, - { - "id": 54, - "name": "Pressure", - "shortname": "pres", - "unit_id": 16, - "encoding_ids": ["grib1", "grib2"], - "access_ids": ["dissemination"], - "published": True, - "pending": False, - "retired": False, - }, - { - "id": 59, - "name": "Convective available potential energy", - "shortname": "cape", - "unit_id": 17, - "encoding_ids": ["grib1", "grib2"], - "access_ids": [], - "published": True, - "pending": False, - "retired": False, - }, - { - "id": 228059, - "name": "Convective available potential energy", - "shortname": "cape", - "unit_id": 17, - "encoding_ids": ["grib1", "grib2"], - "access_ids": [], - "published": True, - "pending": False, - "retired": False, - }, -] - - -@pytest.fixture() -def local_cache_file(tmp_path): - """Write LOCAL_CACHE_DATA to a temporary JSON file and return its path.""" - cache_file = tmp_path / "parameters.json" - cache_file.write_text(json.dumps(LOCAL_CACHE_DATA)) - return str(cache_file) - - -@contextmanager -def _patch_local_cache(grib_mod, cache_path): - """Temporarily set SETTINGS.paramdb.local_cache on the grib module and clear functools caches.""" - old_value = grib_mod.SETTINGS.paramdb.local_cache - grib_mod.SETTINGS.paramdb.local_cache = Path(cache_path) if cache_path else None - # Clear functools.cache so results are not stale across tests - grib_mod._local_search_param.cache_clear() - grib_mod._get_local_db.cache_clear() - try: - yield - finally: - grib_mod.SETTINGS.paramdb.local_cache = old_value - grib_mod._local_search_param.cache_clear() - grib_mod._get_local_db.cache_clear() - - -class TestLocalCacheSearch: - """Tests for _local_search_param and the SETTINGS.paramdb.local_cache routing in _search_param.""" - - def test_local_search_param_returns_single_match(self, local_cache_file): - """_local_search_param returns a one-element list for a known shortname.""" - grib = _grib() - with _patch_local_cache(grib, local_cache_file): - results = grib._local_search_param("sst") - assert len(results) == 1 - assert results[0]["shortname"] == "sst" - assert results[0]["id"] == 34 - - def test_local_search_param_raises_on_missing(self, local_cache_file): - """_local_search_param raises KeyError for an unknown shortname.""" - grib = _grib() - with _patch_local_cache(grib, local_cache_file): - with pytest.raises(KeyError, match="not found in local cache"): - grib._local_search_param("nonexistent_param_xyz") - - @patch("anemoi.utils.grib.requests.get") - def test_search_param_local_cache_no_network(self, mock_get, local_cache_file): - """Verify requests.get is never called when local_cache is configured.""" - grib = _grib() - with _patch_local_cache(grib, local_cache_file): - result = grib._search_param("ws") - mock_get.assert_not_called() - assert result["shortname"] == "ws" - assert result["id"] == 10 - - @patch("anemoi.utils.grib.requests.get") - def test_shortname_to_paramid_via_local_cache(self, mock_get, local_cache_file): - """End-to-end: shortname_to_paramid works with the local cache.""" - grib = _grib() - with _patch_local_cache(grib, local_cache_file): - assert grib.shortname_to_paramid("ci") == 31 - assert grib.shortname_to_paramid("sst") == 34 - assert grib.shortname_to_paramid("pres") == 54 - mock_get.assert_not_called() - - @patch("anemoi.utils.grib.requests.get") - def test_paramid_to_shortname_via_local_cache(self, mock_get, local_cache_file): - """_search_param finds entries via local cache for reverse lookups.""" - grib = _grib() - with _patch_local_cache(grib, local_cache_file): - result = grib._search_param("sst") - assert result["shortname"] == "sst" - mock_get.assert_not_called() - - @patch("anemoi.utils.grib.warnings.warn") - def test_local_cache_filters_ignored_with_warning(self, mock_warning, local_cache_file): - """When local_cache is set, passing filters emits a warning.""" - grib = _grib() - with _patch_local_cache(grib, local_cache_file): - result = grib._search_param("sst", origin=98) - mock_warning.assert_called() - warning_msg = mock_warning.call_args[0][0] - assert "ignored" in warning_msg.lower() or "Filters" in warning_msg - assert result["id"] == 34 - - def test_local_search_multiple_params(self, local_cache_file): - """Verify several known parameters from the mock cache are found.""" - grib = _grib() - with _patch_local_cache(grib, local_cache_file): - expected = { - "strf": 1, - "vp": 2, - "ws": 10, - } - for shortname, expected_id in expected.items(): - results = grib._local_search_param(shortname) - assert len(results) == 1 - assert ( - results[0]["id"] == expected_id - ), f"Expected {shortname} -> id={expected_id}, got {results[0]['id']}" +class TestUnits: + @pytest.mark.parametrize( + "param,expected_unit", + [ + ("2t", "K"), + ("tp", "m"), + ("sp", "Pa"), + ("10u", "m s**-1"), + (167, "K"), + (228, "m"), + ], + ) + def test_known_units(self, grib, param, expected_unit): + result = grib.units(param) + assert result == expected_unit + assert isinstance(result, str) + + +# --------------------------------------------------------------------------- +# must_be_positive +# --------------------------------------------------------------------------- + + +class TestMustBePositive: + @pytest.mark.parametrize( + "param,expected", + [ + ("tp", True), # units: m + ("crwc", True), # units: kg kg**-1 + ("sf", True), # units: m of water equivalent + ("2t", False), # units: K + ("sp", False), # units: Pa + ("10u", False), # units: m s**-1 + ("cape", False), # units: J kg**-1 + ], + ) + def test_positive_classification(self, grib, param, expected): + result = grib.must_be_positive(param) + assert result is expected + assert isinstance(result, bool) + + def test_accepts_paramid(self, grib): + assert grib.must_be_positive(228) is True # tp, units: m + assert grib.must_be_positive(167) is False # 2t, units: K From 858400078c5a1e13971cfb0842261e93c0669574 Mon Sep 17 00:00:00 2001 From: Harrison Date: Fri, 5 Jun 2026 12:04:39 +0100 Subject: [PATCH 2/2] fix: test parameterisation --- tests/test_grib.py | 57 ++++++++++++++++------------------------------ 1 file changed, 20 insertions(+), 37 deletions(-) diff --git a/tests/test_grib.py b/tests/test_grib.py index 47763de..d92a878 100644 --- a/tests/test_grib.py +++ b/tests/test_grib.py @@ -19,9 +19,7 @@ from __future__ import annotations -import sys from pathlib import Path -from unittest.mock import patch import pytest from pymetkit import ParamDB @@ -32,45 +30,30 @@ FIXTURE_YAML = Path(__file__).parent / "parameter_metadata_test.yaml" -_bundled_xfail = pytest.param( - "bundled", - marks=pytest.mark.xfail( - raises=FileNotFoundError, - reason="pymetkit bundled YAML not installed in this environment", - strict=False, - ), -) +@pytest.fixture(scope="session") +def _local_paramdb(): + return ParamDB(mode="offline", yaml_path=FIXTURE_YAML.resolve()) -@pytest.fixture(params=["local", _bundled_xfail]) -def grib(request, monkeypatch): - """Import anemoi.utils.grib and swap PARAMDB to the requested backend. - For 'local', the module is imported with settings pointing at our test - YAML. For 'bundled', PARAMDB is replaced with a ParamDB using pymetkit's - default YAML lookup (which may not exist). +@pytest.fixture(scope="session") +def _offline_paramdb(): + return ParamDB(mode="offline") + + +@pytest.fixture(params=["local", "bundled"]) +def grib(request, monkeypatch, _local_paramdb, _offline_paramdb): + """Swap anemoi.utils.grib.PARAMDB to the requested backend. + + 'local' reuses a session-scoped ParamDB pointing at our test YAML. + 'bundled' constructs one from pymetkit's default YAML lookup (which may + not exist, in which case the test xfails on FileNotFoundError). """ - from anemoi.utils.settings import AnemoiSettings - - yaml_path = FIXTURE_YAML.resolve() - settings = AnemoiSettings() - settings.paramdb.local_data = yaml_path - - # Ensure a clean import of the grib module with patched settings - cached = sys.modules.pop("anemoi.utils.grib", None) - try: - with patch("anemoi.utils.settings.AnemoiSettings", return_value=settings): - import anemoi.utils.grib as grib_mod - - if request.param == "bundled": - bundled = ParamDB(mode="offline") - monkeypatch.setattr(grib_mod, "PARAMDB", bundled) - - yield grib_mod - finally: - sys.modules.pop("anemoi.utils.grib", None) - if cached is not None: - sys.modules["anemoi.utils.grib"] = cached + import anemoi.utils.grib as grib_mod + + db = _local_paramdb if request.param == "local" else _offline_paramdb + monkeypatch.setattr(grib_mod, "PARAMDB", db) + return grib_mod # ---------------------------------------------------------------------------