From 4cbea2e845eaf15d4b08b3daffe6a258ee3d2907 Mon Sep 17 00:00:00 2001 From: Vignesh Raghunathan Date: Tue, 10 Feb 2026 09:34:44 +0100 Subject: [PATCH 01/10] Fix dtype for pandas --- message_ix_models/model/water/data/water_for_ppl.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/message_ix_models/model/water/data/water_for_ppl.py b/message_ix_models/model/water/data/water_for_ppl.py index 77d907b147..338c77d28a 100644 --- a/message_ix_models/model/water/data/water_for_ppl.py +++ b/message_ix_models/model/water/data/water_for_ppl.py @@ -121,9 +121,11 @@ def _compute_cooling_rates(input_cool: pd.DataFrame) -> pd.DataFrame: # Cooling fraction: heat to be rejected input_cool["cooling_fraction"] = input_cool.apply( - lambda r: r["value"] - 1 - if "hpl" in str(r.get("parent_tech", "")) - else r["value"] * (1 - flue_loss) - 1, + lambda r: ( + r["value"] - 1 + if "hpl" in str(r.get("parent_tech", "")) + else r["value"] * (1 - flue_loss) - 1 + ), axis=1, ) From d2bcef3be90b44cbdebebfb6f89a171eb758bee6 Mon Sep 17 00:00:00 2001 From: Vignesh Raghunathan Date: Wed, 27 Aug 2025 21:05:34 +0200 Subject: [PATCH 02/10] Add basin filtering Use basin filtering everywhere More filtering More filtering Fix accidental removal Add basin reduction options Clean up filtering code Add min 1 basin region pair req Add reg_to_basin return tech consolidate return logic Fix failing tests Tests failure due to missing valid_basins list. Adding list to context to fix issue. --- message_ix_models/model/water/build.py | 8 +- message_ix_models/model/water/cli.py | 33 +++++++- message_ix_models/model/water/data/demands.py | 37 +++++++-- .../model/water/data/infrastructure.py | 13 +++ .../model/water/data/irrigation.py | 4 + .../model/water/data/water_supply.py | 15 ++++ message_ix_models/model/water/utils.py | 82 +++++++++++++++++++ .../tests/model/water/conftest.py | 33 ++++++++ .../tests/model/water/data/test_irrigation.py | 4 + .../model/water/data/test_water_for_ppl.py | 12 +-- 10 files changed, 224 insertions(+), 17 deletions(-) diff --git a/message_ix_models/model/water/build.py b/message_ix_models/model/water/build.py index 811d8152ca..795b65a397 100644 --- a/message_ix_models/model/water/build.py +++ b/message_ix_models/model/water/build.py @@ -10,7 +10,7 @@ from message_ix_models.model.structure import get_codes from message_ix_models.util import broadcast, package_data_path -from .utils import read_config +from .utils import filter_basins_by_region, read_config log = logging.getLogger(__name__) @@ -556,6 +556,10 @@ def map_basin(context: Context) -> Mapping[str, ScenarioInfo]: PATH = package_data_path("water", "delineation", FILE) df = pd.read_csv(PATH) + + # Apply basin filter to reduce number of basins per region + df = filter_basins_by_region(df, context) + # Assigning proper nomenclature df["node"] = "B" + df["BCU_name"].astype(str) df["mode"] = "M" + df["BCU_name"].astype(str) @@ -577,6 +581,8 @@ def map_basin(context: Context) -> Mapping[str, ScenarioInfo]: results["map_node"] = nodes context.all_nodes = df["node"] + # Store the filtered basin names for use in other functions + context.valid_basins = set(df["BCU_name"].astype(str)) for set_name, config in results.items(): # Sets to add diff --git a/message_ix_models/model/water/cli.py b/message_ix_models/model/water/cli.py index a485bbd390..0861de37d3 100644 --- a/message_ix_models/model/water/cli.py +++ b/message_ix_models/model/water/cli.py @@ -123,13 +123,44 @@ def water_ini(context: "Context", regions, time): is_flag=True, help="Defines whether the model solves with macro", ) +@click.option( + "--reduced-basin/--no-reduced-basin", + default=False, + help="Enable reduced basin filtering", +) +@click.option( + "--filter-list", + multiple=True, + help="Specific basins to include (can be used multiple times)", +) +@click.option( + "--num-basins", + type=int, + help="Number of basins per region to keep when reduced-basin is enabled", +) @common_params("regions") @scenario_param("--ssp") -def nexus_cli(context: "Context", regions, rcps, sdgs, rels, macro=False): +def nexus_cli( + context: "Context", + regions, + rcps, + sdgs, + rels, + macro=False, + reduced_basin=False, + filter_list=None, + num_basins=None, +): """ Add basin structure connected to the energy sector and water balance linking different water demands to supply. """ + # Set basin filtering configuration on context + context.reduced_basin = reduced_basin + if filter_list: + context.filter_list = list(filter_list) + if num_basins is not None: + context.num_basins = num_basins nexus(context, regions, rcps, sdgs, rels, macro) diff --git a/message_ix_models/model/water/data/demands.py b/message_ix_models/model/water/data/demands.py index 25f30caaa2..2469272559 100644 --- a/message_ix_models/model/water/data/demands.py +++ b/message_ix_models/model/water/data/demands.py @@ -216,6 +216,9 @@ def add_sectoral_demands(context: "Context") -> dict[str, pd.DataFrame]: df_dmds.sort_values(["year", "node", "variable", "value"], inplace=True) df_dmds["time"] = "year" + + # Filter to only include basins that exist after basin filtering + df_dmds = df_dmds[df_dmds["node"].isin(context.valid_basins)] # Write final interpolated values as csv # df2_f.to_csv('final_interpolated_values.csv') @@ -238,6 +241,9 @@ def add_sectoral_demands(context: "Context") -> dict[str, pd.DataFrame]: ) df_m = df_m[["year", "pid", "variable", "value", "month"]] df_m.columns = pd.Index(["year", "node", "variable", "value", "time"]) + + # Filter monthly data to only include valid basins + df_m = df_m[df_m["node"].isin(context.valid_basins)] # remove yearly parts from df_dms df_dmds = df_dmds[ @@ -767,13 +773,11 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: "water", "delineation", f"basins_by_region_simpl_{context.regions}.csv" ) df_x = pd.read_csv(PATH) + + # Filter to only include valid basins + df_x = df_x[df_x["BCU_name"].isin(context.valid_basins)] if "year" in context.time: - # path for reading basin delineation file - PATH = package_data_path( - "water", "delineation", f"basins_by_region_simpl_{context.regions}.csv" - ) - df_x = pd.read_csv(PATH) # Adding freshwater supply constraints # Reading data, the data is spatially and temprally aggregated from GHMs path1 = package_data_path( @@ -784,6 +788,14 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: # Read rcp 2.6 data df_sw = pd.read_csv(path1) df_sw.drop(["Unnamed: 0"], axis=1, inplace=True) + + # Filter columns to only include valid basins + # The columns are years, so we need to filter rows based on the original basin order + # First, get the indices of valid basins from the original full list + full_basin_df = pd.read_csv(PATH) # Read full basin list again + valid_indices = full_basin_df[full_basin_df["BCU_name"].isin(context.valid_basins)].index + df_sw = df_sw.iloc[valid_indices] # Keep only rows for valid basins + df_sw.reset_index(drop=True, inplace=True) df_sw.index = df_x["BCU_name"].index df_sw = df_sw.stack().reset_index() @@ -809,6 +821,11 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: # Read groundwater data df_gw = pd.read_csv(path1) df_gw.drop(["Unnamed: 0"], axis=1, inplace=True) + + # Filter to only include valid basins (same as df_sw) + df_gw = df_gw.iloc[valid_indices] # Use same valid_indices from above + df_gw.reset_index(drop=True, inplace=True) + df_gw.index = df_x["BCU_name"].index df_gw = df_gw.stack().reset_index() df_gw.columns = pd.Index(["Region", "years", "value"]) @@ -832,6 +849,12 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: ) df_sw = pd.read_csv(path1) df_sw.drop(["Unnamed: 0"], axis=1, inplace=True) + + # Filter to only include valid basins + full_basin_df = pd.read_csv(PATH) # Read full basin list again + valid_indices = full_basin_df[full_basin_df["BCU_name"].isin(context.valid_basins)].index + df_sw = df_sw.iloc[valid_indices] + df_sw.reset_index(drop=True, inplace=True) df_sw.index = df_x["BCU_name"].index df_sw = df_sw.stack().reset_index() @@ -855,6 +878,10 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: ) df_gw = pd.read_csv(path1) df_gw.drop(["Unnamed: 0"], axis=1, inplace=True) + + # Filter to only include valid basins (same as df_sw) + df_gw = df_gw.iloc[valid_indices] # Use same valid_indices from above + df_gw.reset_index(drop=True, inplace=True) df_gw.index = df_x["BCU_name"].index df_gw = df_gw.stack().reset_index() diff --git a/message_ix_models/model/water/data/infrastructure.py b/message_ix_models/model/water/data/infrastructure.py index 10a929bf1b..5953d27c25 100644 --- a/message_ix_models/model/water/data/infrastructure.py +++ b/message_ix_models/model/water/data/infrastructure.py @@ -291,6 +291,10 @@ def add_infrastructure_techs(context: "Context") -> dict[str, pd.DataFrame]: PATH = package_data_path("water", "delineation", FILE2) df_node = pd.read_csv(PATH) + + # Filter to only valid basins (already filtered in map_basin) + df_node = df_node[df_node["BCU_name"].isin(context.valid_basins)] + # Assigning proper nomenclature df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str) @@ -807,6 +811,10 @@ def add_desalination(context: "Context") -> dict[str, pd.DataFrame]: PATH = package_data_path("water", "delineation", FILE2) df_node = pd.read_csv(PATH) + + # Filter to only valid basins (already filtered in map_basin) + df_node = df_node[df_node["BCU_name"].isin(context.valid_basins)] + # Assigning proper nomenclature df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str) @@ -815,6 +823,11 @@ def add_desalination(context: "Context") -> dict[str, pd.DataFrame]: if context.type_reg == "country" else f"{context.regions}_" + df_node["REGION"].astype(str) ) + + # Filter historical and projected data to only include basins that exist after filtering + df_hist = df_hist[df_hist["BCU_name"].isin(context.valid_basins)] + df_proj = df_proj[df_proj["BCU_name"].isin(context.valid_basins)] + # output dataframe linking to desal tech types out_df = ( make_df( diff --git a/message_ix_models/model/water/data/irrigation.py b/message_ix_models/model/water/data/irrigation.py index efabe6e2c4..e6046a86a7 100644 --- a/message_ix_models/model/water/data/irrigation.py +++ b/message_ix_models/model/water/data/irrigation.py @@ -32,6 +32,10 @@ def add_irr_structure(context: "Context") -> dict[str, pd.DataFrame]: FILE2 = f"basins_by_region_simpl_{context.regions}.csv" PATH = package_data_path("water", "delineation", FILE2) df_node = pd.read_csv(PATH) + + # Filter to only include valid basins + df_node = df_node[df_node["BCU_name"].isin(context.valid_basins)] + # Assigning proper nomenclature df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str) diff --git a/message_ix_models/model/water/data/water_supply.py b/message_ix_models/model/water/data/water_supply.py index fb79e9ab30..d27c22cf05 100644 --- a/message_ix_models/model/water/data/water_supply.py +++ b/message_ix_models/model/water/data/water_supply.py @@ -11,6 +11,7 @@ KM3_TO_MCM, USD_KM3_TO_USD_MCM, GWa_KM3_TO_GWa_MCM, + filter_basins_by_region, get_vintage_and_active_years, ) from message_ix_models.util import ( @@ -42,6 +43,9 @@ def map_basin_region_wat(context: "Context") -> pd.DataFrame: "water", "delineation", f"basins_by_region_simpl_{context.regions}.csv" ) df_x = pd.read_csv(PATH) + + # Filter to only include valid basins + df_x = df_x[df_x["BCU_name"].isin(context.valid_basins)] # Adding freshwater supply constraints # Reading data, the data is spatially and temprally aggregated from GHMs path1 = package_data_path( @@ -94,6 +98,9 @@ def map_basin_region_wat(context: "Context") -> pd.DataFrame: "water", "delineation", f"basins_by_region_simpl_{context.regions}.csv" ) df_x = pd.read_csv(PATH) + + # Filter to only include valid basins + df_x = df_x[df_x["BCU_name"].isin(context.valid_basins)] # Reading data, the data is spatially and temporally aggregated from GHMs df_sw["BCU_name"] = df_x["BCU_name"] @@ -165,6 +172,10 @@ def add_water_supply(context: "Context") -> dict[str, pd.DataFrame]: PATH = package_data_path("water", "delineation", FILE) df_node = pd.read_csv(PATH) + + # Apply basin filter to reduce number of basins per region + df_node = filter_basins_by_region(df_node, context) + # Assigning proper nomenclature df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str) @@ -191,6 +202,10 @@ def add_water_supply(context: "Context") -> dict[str, pd.DataFrame]: FILE2 = f"historical_new_cap_gw_sw_km3_year_{context.regions}.csv" PATH2 = package_data_path("water", "availability", FILE2) df_hist = pd.read_csv(PATH2) + + # Filter to only include valid basins + df_hist = df_hist[df_hist["BCU_name"].isin(context.valid_basins)] + df_hist["BCU_name"] = "B" + df_hist["BCU_name"].astype(str) if context.nexus_set == "cooling": diff --git a/message_ix_models/model/water/utils.py b/message_ix_models/model/water/utils.py index 1641b24c7d..3bc3e3fcd0 100644 --- a/message_ix_models/model/water/utils.py +++ b/message_ix_models/model/water/utils.py @@ -78,6 +78,88 @@ def read_config(context: Context | None = None): return context +def filter_basins_by_region( + df_basins: pd.DataFrame, + context: Optional[Context] = None, + n_per_region: int = 3, +) -> pd.DataFrame: + """Filter basins based on context configuration. + + Parameters + ---------- + df_basins : pd.DataFrame + DataFrame with basin data including 'REGION' and 'BCU_name' columns + context : Context, optional + Context object that may contain basin filtering configuration + n_per_region : int, default 3 + Default number of basins to keep per region (used as fallback) + + Returns + ------- + pd.DataFrame + Filtered DataFrame based on configuration + """ + if not context: + context = Context.get_instance(-1) + + # Check if reduced basin filtering is enabled + reduced_basin = getattr(context, 'reduced_basin', False) + + if not reduced_basin: + # No filtering, return original dataframe + log.info("Basin filtering disabled, returning all basins") + return df_basins + + # Basin filtering is enabled + filter_list = getattr(context, 'filter_list', None) + num_basins = getattr(context, 'num_basins', None) + + if filter_list: + # Filter to specific basin list + filtered = df_basins[df_basins['BCU_name'].isin(filter_list)] + + # Check if we have at least 1 basin per R12 region + all_regions = set(df_basins['REGION'].unique()) + filtered_regions = set(filtered['REGION'].unique()) + missing_regions = all_regions - filtered_regions + + if missing_regions: + log.info(f"Adding one basin per missing region: {missing_regions}") + # Add one basin from each missing region + for region in missing_regions: + region_basins = df_basins[df_basins['REGION'] == region] + # Add the first basin from this region + filtered = pd.concat( + [filtered, region_basins.head(1)], ignore_index=True + ) + + log.info( + f"Filtered basins from {len(df_basins)} to {len(filtered)} " + f"using custom filter list: {filter_list} (with 1 basin per missing region)" + ) + + return filtered.reset_index(drop=True) + + elif num_basins is not None: + # Use specified number of basins per region + n_per_region = num_basins + # else: use function default n_per_region + + # Group by region and take first n rows from each group + if 'REGION' not in df_basins.columns: + log.info("REGION column not found, cannot filter by region") + return df_basins + + filtered = df_basins.groupby('REGION', group_keys=False).apply( + lambda x: x.head(n_per_region) + ).reset_index(drop=True) + + log.info(f"Filtered basins from {len(df_basins)} to {len(filtered)} " + f"(keeping first {n_per_region} per region)") + + return filtered + + @lru_cache() def map_add_on(rtype=Code): """Map addon & type_addon in ``sets.yaml``.""" diff --git a/message_ix_models/tests/model/water/conftest.py b/message_ix_models/tests/model/water/conftest.py index b784053590..4a47f28388 100644 --- a/message_ix_models/tests/model/water/conftest.py +++ b/message_ix_models/tests/model/water/conftest.py @@ -6,6 +6,36 @@ from message_ix_models import ScenarioInfo from message_ix_models.model.structure import get_codes +from message_ix_models.util import package_data_path + + +def setup_valid_basins(context, regions="R12"): + """Set up valid_basins attribute for test contexts. + + This helper function ensures that test contexts have the valid_basins + attribute that is normally set by the map_basin() function during + model building. This is required for basin filtering functionality. + + Parameters + ---------- + context : Context + Test context object that needs valid_basins attribute + regions : str, default "R12" + Region code for basin delineation file + """ + from message_ix_models.model.water.utils import filter_basins_by_region + + basin_file = f"basins_by_region_simpl_{regions}.csv" + basin_path = package_data_path("water", "delineation", basin_file) + df_basins = pd.read_csv(basin_path) + + # Apply basin filtering if enabled + df_filtered = filter_basins_by_region(df_basins, context) + + # Set valid_basins as set of basin names + context.valid_basins = set(df_filtered["BCU_name"].astype(str)) + + return context @pytest.fixture @@ -36,6 +66,9 @@ def water_context(test_context, request): nodes = list(map(str, nodes[nodes.index("World")].child)) test_context.map_ISO_c = {test_context.regions: nodes[0]} + # Set up valid_basins for basin filtering + setup_valid_basins(test_context, regions=test_context.regions) + return test_context diff --git a/message_ix_models/tests/model/water/data/test_irrigation.py b/message_ix_models/tests/model/water/data/test_irrigation.py index 81d29f48f5..ab1ee98cfd 100644 --- a/message_ix_models/tests/model/water/data/test_irrigation.py +++ b/message_ix_models/tests/model/water/data/test_irrigation.py @@ -3,6 +3,7 @@ from message_ix_models import ScenarioInfo from message_ix_models.model.structure import get_codes from message_ix_models.model.water.data.irrigation import add_irr_structure +from message_ix_models.tests.model.water.conftest import setup_valid_basins def test_add_irr_structure(test_context): @@ -30,6 +31,9 @@ def test_add_irr_structure(test_context): # FIXME same as above test_context["water build info"] = ScenarioInfo(s) + # Set up valid_basins for basin filtering + setup_valid_basins(test_context, regions=test_context.regions) + # Call the function to be tested result = add_irr_structure(test_context) diff --git a/message_ix_models/tests/model/water/data/test_water_for_ppl.py b/message_ix_models/tests/model/water/data/test_water_for_ppl.py index 74b4ad1c27..c57d4fd97e 100644 --- a/message_ix_models/tests/model/water/data/test_water_for_ppl.py +++ b/message_ix_models/tests/model/water/data/test_water_for_ppl.py @@ -7,7 +7,7 @@ cool_tech, non_cooling_tec, ) -from message_ix_models.util import package_data_path +from message_ix_models.tests.model.water.conftest import setup_valid_basins def _get_test_node(regions: str) -> str: @@ -21,14 +21,6 @@ def _get_test_node(regions: str) -> str: return regions -def _setup_valid_basins(context) -> None: - """Set up valid_basins from basin delineation file.""" - basin_file = f"basins_by_region_simpl_{context.regions}.csv" - basin_path = package_data_path("water", "delineation", basin_file) - df_basins = pd.read_csv(basin_path) - context.valid_basins = set(df_basins["BCU_name"].astype(str)) - - @pytest.mark.usefixtures("ssp_user_data") @pytest.mark.parametrize( "water_context", @@ -139,7 +131,7 @@ def test_cool_tec(request, water_context, assert_message_params): water_context["water build info"] = ScenarioInfo(scenario_obj=s) # Set up valid_basins for water_for_ppl functions - _setup_valid_basins(water_context) + setup_valid_basins(water_context, regions=water_context.regions) result = cool_tech(context=water_context) From 3dab99449341a92d48daf1cc4880650787dc7315 Mon Sep 17 00:00:00 2001 From: Vignesh Raghunathan Date: Tue, 20 Jan 2026 14:04:38 +0100 Subject: [PATCH 03/10] Fix basin filtering for cooling mode --- message_ix_models/model/water/data/water_supply.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/message_ix_models/model/water/data/water_supply.py b/message_ix_models/model/water/data/water_supply.py index d27c22cf05..27bfa79ed3 100644 --- a/message_ix_models/model/water/data/water_supply.py +++ b/message_ix_models/model/water/data/water_supply.py @@ -202,11 +202,11 @@ def add_water_supply(context: "Context") -> dict[str, pd.DataFrame]: FILE2 = f"historical_new_cap_gw_sw_km3_year_{context.regions}.csv" PATH2 = package_data_path("water", "availability", FILE2) df_hist = pd.read_csv(PATH2) - - # Filter to only include valid basins - df_hist = df_hist[df_hist["BCU_name"].isin(context.valid_basins)] - - df_hist["BCU_name"] = "B" + df_hist["BCU_name"].astype(str) + + # Filter to only include valid basins (nexus mode only) + if context.nexus_set == "nexus": + df_hist = df_hist[df_hist["BCU_name"].isin(context.valid_basins)] + df_hist["BCU_name"] = "B" + df_hist["BCU_name"].astype(str) if context.nexus_set == "cooling": # Add output df for surfacewater supply for regions From 57550bca058b3d2ebe9575f05bc0d470aeb2accd Mon Sep 17 00:00:00 2001 From: Vignesh Raghunathan Date: Tue, 10 Feb 2026 17:31:20 +0100 Subject: [PATCH 04/10] Add demand/supply stress-based basin selection --- doc/whatsnew.rst | 5 + message_ix_models/model/water/cli.py | 8 + .../model/water/data/water_supply.py | 31 +-- message_ix_models/model/water/utils.py | 193 ++++++++++++++++-- .../tests/model/water/test_utils.py | 87 ++++++++ 5 files changed, 294 insertions(+), 30 deletions(-) diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst index 40f645738d..4ee9d0876c 100644 --- a/doc/whatsnew.rst +++ b/doc/whatsnew.rst @@ -8,6 +8,11 @@ Next release released 2026-01-21 (:pull:`470`). - New module :mod:`tools.bilateralize ` to change scenarios to a bilateral representation of trade (:pull:`438`). + +- Add reduced basin filtering for water module with ``--reduced-basin`` and + demand/supply stress-based selection via ``--basin-selection stress`` + (:issue:`414`). + - Fix water module parameter bugs and refactor cooling (:pull:`405`): infrastructure M1/Mf mode fixes, regional average shares for cooling allocation, water supply level hierarchy corrections, and test suite improvements. diff --git a/message_ix_models/model/water/cli.py b/message_ix_models/model/water/cli.py index 0861de37d3..d85b3446fb 100644 --- a/message_ix_models/model/water/cli.py +++ b/message_ix_models/model/water/cli.py @@ -138,6 +138,12 @@ def water_ini(context: "Context", regions, time): type=int, help="Number of basins per region to keep when reduced-basin is enabled", ) +@click.option( + "--basin-selection", + type=click.Choice(["first_k", "stress"]), + default="first_k", + help="Basin selection: first_k (CSV order) or stress (demand/supply span)", +) @common_params("regions") @scenario_param("--ssp") def nexus_cli( @@ -150,6 +156,7 @@ def nexus_cli( reduced_basin=False, filter_list=None, num_basins=None, + basin_selection="first_k", ): """ Add basin structure connected to the energy sector and @@ -161,6 +168,7 @@ def nexus_cli( context.filter_list = list(filter_list) if num_basins is not None: context.num_basins = num_basins + context.basin_selection = basin_selection nexus(context, regions, rcps, sdgs, rels, macro) diff --git a/message_ix_models/model/water/data/water_supply.py b/message_ix_models/model/water/data/water_supply.py index 27bfa79ed3..a4d070a043 100644 --- a/message_ix_models/model/water/data/water_supply.py +++ b/message_ix_models/model/water/data/water_supply.py @@ -42,10 +42,13 @@ def map_basin_region_wat(context: "Context") -> pd.DataFrame: PATH = package_data_path( "water", "delineation", f"basins_by_region_simpl_{context.regions}.csv" ) - df_x = pd.read_csv(PATH) - - # Filter to only include valid basins - df_x = df_x[df_x["BCU_name"].isin(context.valid_basins)] + df_x_full = pd.read_csv(PATH) + + # Get positional indices of valid basins from the unfiltered list + valid_mask = df_x_full["BCU_name"].isin(context.valid_basins) + valid_indices = df_x_full[valid_mask].index + df_x = df_x_full[valid_mask].reset_index(drop=True) + # Adding freshwater supply constraints # Reading data, the data is spatially and temprally aggregated from GHMs path1 = package_data_path( @@ -57,7 +60,8 @@ def map_basin_region_wat(context: "Context") -> pd.DataFrame: df_sw = pd.read_csv(path1) df_sw.drop(["Unnamed: 0"], axis=1, inplace=True) - # Reading data, the data is spatially and temporally aggregated from GHMs + # Filter df_sw to matching positional rows, then reset both indices + df_sw = df_sw.iloc[valid_indices].reset_index(drop=True) df_sw["BCU_name"] = df_x["BCU_name"] df_sw["MSGREG"] = ( context.map_ISO_c[context.regions] @@ -97,12 +101,15 @@ def map_basin_region_wat(context: "Context") -> pd.DataFrame: PATH = package_data_path( "water", "delineation", f"basins_by_region_simpl_{context.regions}.csv" ) - df_x = pd.read_csv(PATH) - - # Filter to only include valid basins - df_x = df_x[df_x["BCU_name"].isin(context.valid_basins)] + df_x_full = pd.read_csv(PATH) - # Reading data, the data is spatially and temporally aggregated from GHMs + # Get positional indices of valid basins from the unfiltered list + valid_mask = df_x_full["BCU_name"].isin(context.valid_basins) + valid_indices = df_x_full[valid_mask].index + df_x = df_x_full[valid_mask].reset_index(drop=True) + + # Filter df_sw to matching positional rows, then reset both indices + df_sw = df_sw.iloc[valid_indices].reset_index(drop=True) df_sw["BCU_name"] = df_x["BCU_name"] df_sw["MSGREG"] = ( @@ -172,10 +179,10 @@ def add_water_supply(context: "Context") -> dict[str, pd.DataFrame]: PATH = package_data_path("water", "delineation", FILE) df_node = pd.read_csv(PATH) - + # Apply basin filter to reduce number of basins per region df_node = filter_basins_by_region(df_node, context) - + # Assigning proper nomenclature df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str) diff --git a/message_ix_models/model/water/utils.py b/message_ix_models/model/water/utils.py index 3bc3e3fcd0..8e4141df58 100644 --- a/message_ix_models/model/water/utils.py +++ b/message_ix_models/model/water/utils.py @@ -12,7 +12,7 @@ from message_ix_models import Context from message_ix_models.model.structure import get_codes -from message_ix_models.util import load_package_data +from message_ix_models.util import load_package_data, package_data_path log = logging.getLogger(__name__) @@ -80,7 +80,7 @@ def read_config(context: Context | None = None): def filter_basins_by_region( df_basins: pd.DataFrame, - context: Optional[Context] = None, + context: Context | None = None, n_per_region: int = 3, ) -> pd.DataFrame: """Filter basins based on context configuration. @@ -103,7 +103,7 @@ def filter_basins_by_region( context = Context.get_instance(-1) # Check if reduced basin filtering is enabled - reduced_basin = getattr(context, 'reduced_basin', False) + reduced_basin = getattr(context, "reduced_basin", False) if not reduced_basin: # No filtering, return original dataframe @@ -111,23 +111,23 @@ def filter_basins_by_region( return df_basins # Basin filtering is enabled - filter_list = getattr(context, 'filter_list', None) - num_basins = getattr(context, 'num_basins', None) + filter_list = getattr(context, "filter_list", None) + num_basins = getattr(context, "num_basins", None) if filter_list: # Filter to specific basin list - filtered = df_basins[df_basins['BCU_name'].isin(filter_list)] + filtered = df_basins[df_basins["BCU_name"].isin(filter_list)] # Check if we have at least 1 basin per R12 region - all_regions = set(df_basins['REGION'].unique()) - filtered_regions = set(filtered['REGION'].unique()) + all_regions = set(df_basins["REGION"].unique()) + filtered_regions = set(filtered["REGION"].unique()) missing_regions = all_regions - filtered_regions if missing_regions: log.info(f"Adding one basin per missing region: {missing_regions}") # Add one basin from each missing region for region in missing_regions: - region_basins = df_basins[df_basins['REGION'] == region] + region_basins = df_basins[df_basins["REGION"] == region] # Add the first basin from this region filtered = pd.concat( [filtered, region_basins.head(1)], ignore_index=True @@ -140,26 +140,183 @@ def filter_basins_by_region( return filtered.reset_index(drop=True) - elif num_basins is not None: - # Use specified number of basins per region + # Check for stress-based selection mode + basin_selection = getattr(context, "basin_selection", "first_k") + + if basin_selection == "stress": + n = num_basins if num_basins is not None else n_per_region + ssp = getattr(context, "ssp", "SSP2") + stress_df = compute_basin_demand_ratio(context.regions, ssp=ssp) + selected = _select_by_stress(stress_df, n_per_region=n) + + filtered = df_basins[df_basins["BCU_name"].isin(selected)].reset_index( + drop=True + ) + log.info( + f"Stress-based selection: {len(df_basins)} -> {len(filtered)} basins " + f"(n_per_region={n})" + ) + return filtered + + if num_basins is not None: n_per_region = num_basins - # else: use function default n_per_region # Group by region and take first n rows from each group - if 'REGION' not in df_basins.columns: + if "REGION" not in df_basins.columns: log.info("REGION column not found, cannot filter by region") return df_basins - filtered = df_basins.groupby('REGION', group_keys=False).apply( - lambda x: x.head(n_per_region) - ).reset_index(drop=True) + filtered = ( + df_basins.groupby("REGION", group_keys=False) + .apply(lambda x: x.head(n_per_region)) + .reset_index(drop=True) + ) - log.info(f"Filtered basins from {len(df_basins)} to {len(filtered)} " - f"(keeping first {n_per_region} per region)") + log.info( + f"Filtered basins from {len(df_basins)} to {len(filtered)} " + f"(keeping first {n_per_region} per region)" + ) return filtered +def compute_basin_demand_ratio( + regions: str = "R12", + ssp: str = "SSP2", + demand_year: int = 2050, +) -> pd.DataFrame: + """Compute basin-level demand/supply ratio from pre-build CSV data. + + Demand = urban + rural + manufacturing withdrawals (MCM/year). + Supply = (surface water + groundwater recharge) mean across years (km3 -> MCM). + + Parameters + ---------- + regions : str + Region codelist (e.g. "R12"). + ssp : str + SSP scenario for demand file naming. + demand_year : int + Year to use for demand values (later years show higher stress). + + Returns + ------- + pd.DataFrame + Columns: BCU_name, REGION, supply_mcm, demand_mcm, demand_ratio. + """ + ssp_label = ssp.lower().replace("ssp", "ssp") # SSP2 -> ssp2 + + basins = pd.read_csv( + package_data_path( + "water", "delineation", f"basins_by_region_simpl_{regions}.csv" + ) + ) + + # Supply: surface + groundwater, mean across year columns, km3 -> MCM + qtot = pd.read_csv( + package_data_path( + "water", "availability", f"qtot_5y_no_climate_low_{regions}.csv" + ) + ).drop(columns=["Unnamed: 0"], errors="ignore") + qr = pd.read_csv( + package_data_path( + "water", "availability", f"qr_5y_no_climate_low_{regions}.csv" + ) + ).drop(columns=["Unnamed: 0"], errors="ignore") + supply_mcm = (qtot.mean(axis=1) + qr.mean(axis=1)) * KM3_TO_MCM + + # Demand: urban + rural + manufacturing withdrawals at demand_year + demand_path = package_data_path("water", "demands", "harmonized", regions) + demand_files = [ + f"{ssp_label}_regional_urban_withdrawal2_baseline.csv", + f"{ssp_label}_regional_rural_withdrawal_baseline.csv", + f"{ssp_label}_regional_manufacturing_withdrawal_baseline.csv", + ] + + total_demand = pd.Series(0.0, index=basins["BCU_name"].astype(str)) + for fname in demand_files: + df = pd.read_csv(demand_path / fname) + row = df[df.iloc[:, 0] == demand_year] + if row.empty: + log.warning(f"Year {demand_year} not found in {fname}") + continue + vals = row.iloc[0, 1:].astype(float) + # Align by basin name + for bcu in total_demand.index: + if bcu in vals.index: + total_demand[bcu] += vals[bcu] + + result = pd.DataFrame( + { + "BCU_name": basins["BCU_name"], + "REGION": basins["REGION"], + "supply_mcm": supply_mcm.values, + } + ) + result["demand_mcm"] = result["BCU_name"].astype(str).map(total_demand).fillna(0.0) + safe_supply = result["supply_mcm"].replace(0, float("inf")) + result["demand_ratio"] = result["demand_mcm"] / safe_supply + + return result + + +def _diversity_select(group_sorted: pd.DataFrame, n_per_region: int) -> set[str]: + """Select basins spanning a range via evenly spaced quantile positions. + + Parameters + ---------- + group_sorted : pd.DataFrame + Single-region subset, pre-sorted by the target metric. + n_per_region : int + Target number of basins. + + Returns + ------- + set[str] + Selected BCU_name values. + """ + n = len(group_sorted) + if n <= n_per_region: + return set(group_sorted["BCU_name"]) + if n_per_region == 1: + return {group_sorted.iloc[n // 2]["BCU_name"]} + if n_per_region == 2: + return { + group_sorted.iloc[0]["BCU_name"], + group_sorted.iloc[-1]["BCU_name"], + } + positions = [i / (n_per_region - 1) for i in range(n_per_region)] + indices = {int(round(p * (n - 1))) for p in positions} + return {group_sorted.iloc[i]["BCU_name"] for i in indices} + + +def _select_by_stress( + stress_df: pd.DataFrame, + n_per_region: int = 3, +) -> set[str]: + """Select basins spanning the demand/supply ratio range per region. + + Ensures the reduced model includes basins across the stress spectrum: + low-stress (demand << supply) through high-stress (demand ~ supply). + + Parameters + ---------- + stress_df : pd.DataFrame + Output of compute_basin_demand_ratio(). + n_per_region : int + Target number of basins per region. + """ + selected: set[str] = set() + + for region, group in stress_df.groupby("REGION"): + group_sorted = group.sort_values("demand_ratio").reset_index(drop=True) + basins = _diversity_select(group_sorted, n_per_region) + selected.update(basins) + log.info(f"{region}: {len(basins)} basins selected") + + return selected + + @lru_cache() def map_add_on(rtype=Code): """Map addon & type_addon in ``sets.yaml``.""" diff --git a/message_ix_models/tests/model/water/test_utils.py b/message_ix_models/tests/model/water/test_utils.py index b56f41af0f..330535c8d8 100644 --- a/message_ix_models/tests/model/water/test_utils.py +++ b/message_ix_models/tests/model/water/test_utils.py @@ -2,6 +2,9 @@ import pytest from message_ix_models.model.water.utils import ( + _select_by_stress, + compute_basin_demand_ratio, + filter_basins_by_region, get_vintage_and_active_years, read_config, ) @@ -70,3 +73,87 @@ def test_get_vintage_and_active_years( result = get_vintage_and_active_years(mock_scenario_info, technical_lifetime) expected = pd.DataFrame(expected_data) pd.testing.assert_frame_equal(result, expected) + + +# --- Tests for stress-based basin selection --- + + +class TestComputeBasinDemandRatio: + """Tests for compute_basin_demand_ratio().""" + + def test_r12_shape_and_columns(self): + result = compute_basin_demand_ratio("R12") + assert len(result) == 217 + expected_cols = { + "BCU_name", + "REGION", + "supply_mcm", + "demand_mcm", + "demand_ratio", + } + assert expected_cols == set(result.columns) + + def test_no_nan_in_ratio(self): + result = compute_basin_demand_ratio("R12") + assert not result["demand_ratio"].isna().any() + + def test_all_regions_present(self): + result = compute_basin_demand_ratio("R12") + assert len(result["REGION"].unique()) == 12 + + def test_high_stress_basins_exist(self): + """At least some basins should have demand/supply > 10%.""" + result = compute_basin_demand_ratio("R12", demand_year=2050) + high_stress = result[result["demand_ratio"] > 0.10] + assert len(high_stress) > 0, "No high-stress basins found" + + +class TestSelectByStress: + """Tests for _select_by_stress().""" + + @pytest.fixture + def stress_df(self): + return compute_basin_demand_ratio("R12") + + @pytest.mark.parametrize("n_per_region", [1, 2, 3, 5]) + def test_all_regions_covered(self, stress_df, n_per_region): + selected = _select_by_stress(stress_df, n_per_region=n_per_region) + assert len(selected) > 0 + selected_df = stress_df[stress_df["BCU_name"].isin(selected)] + assert len(selected_df["REGION"].unique()) == 12 + + def test_min_max_included_for_n2(self, stress_df): + """For n=2, lowest and highest stress basins per region should be selected.""" + selected = _select_by_stress(stress_df, n_per_region=2) + for _, group in stress_df.groupby("REGION"): + sorted_g = group.sort_values("demand_ratio") + assert sorted_g.iloc[0]["BCU_name"] in selected + assert sorted_g.iloc[-1]["BCU_name"] in selected + + def test_n_per_region_respected(self, stress_df): + selected = _select_by_stress(stress_df, n_per_region=2) + selected_df = stress_df[stress_df["BCU_name"].isin(selected)] + for _, group in selected_df.groupby("REGION"): + assert len(group) <= 2 + + +class TestFilterBasinsByRegionStress: + """Test stress mode integration in filter_basins_by_region().""" + + def test_stress_mode_returns_valid_output(self, test_context): + from message_ix_models.util import package_data_path + + df_basins = pd.read_csv( + package_data_path("water", "delineation", "basins_by_region_simpl_R12.csv") + ) + test_context.reduced_basin = True + test_context.basin_selection = "stress" + test_context.regions = "R12" + test_context.ssp = "SSP2" + + filtered = filter_basins_by_region(df_basins, test_context, n_per_region=2) + + assert len(filtered) > 0 + assert len(filtered) < len(df_basins) + assert len(filtered["REGION"].unique()) == 12 + assert not filtered["BCU_name"].isna().any() From 9587edda1f3429509ab7eee60e313129ebcad853 Mon Sep 17 00:00:00 2001 From: Vignesh Raghunathan Date: Tue, 10 Feb 2026 18:16:08 +0100 Subject: [PATCH 05/10] Fix lint --- message_ix_models/model/water/data/demands.py | 34 ++++++++++--------- .../model/water/data/infrastructure.py | 14 ++++---- .../model/water/data/irrigation.py | 4 +-- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/message_ix_models/model/water/data/demands.py b/message_ix_models/model/water/data/demands.py index 2469272559..fbacbff725 100644 --- a/message_ix_models/model/water/data/demands.py +++ b/message_ix_models/model/water/data/demands.py @@ -216,7 +216,7 @@ def add_sectoral_demands(context: "Context") -> dict[str, pd.DataFrame]: df_dmds.sort_values(["year", "node", "variable", "value"], inplace=True) df_dmds["time"] = "year" - + # Filter to only include basins that exist after basin filtering df_dmds = df_dmds[df_dmds["node"].isin(context.valid_basins)] @@ -241,7 +241,7 @@ def add_sectoral_demands(context: "Context") -> dict[str, pd.DataFrame]: ) df_m = df_m[["year", "pid", "variable", "value", "month"]] df_m.columns = pd.Index(["year", "node", "variable", "value", "time"]) - + # Filter monthly data to only include valid basins df_m = df_m[df_m["node"].isin(context.valid_basins)] @@ -773,7 +773,7 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: "water", "delineation", f"basins_by_region_simpl_{context.regions}.csv" ) df_x = pd.read_csv(PATH) - + # Filter to only include valid basins df_x = df_x[df_x["BCU_name"].isin(context.valid_basins)] @@ -788,12 +788,12 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: # Read rcp 2.6 data df_sw = pd.read_csv(path1) df_sw.drop(["Unnamed: 0"], axis=1, inplace=True) - - # Filter columns to only include valid basins - # The columns are years, so we need to filter rows based on the original basin order - # First, get the indices of valid basins from the original full list - full_basin_df = pd.read_csv(PATH) # Read full basin list again - valid_indices = full_basin_df[full_basin_df["BCU_name"].isin(context.valid_basins)].index + + # Filter rows to valid basins using index positions from full list + full_basin_df = pd.read_csv(PATH) + valid_indices = full_basin_df[ + full_basin_df["BCU_name"].isin(context.valid_basins) + ].index df_sw = df_sw.iloc[valid_indices] # Keep only rows for valid basins df_sw.reset_index(drop=True, inplace=True) @@ -821,11 +821,11 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: # Read groundwater data df_gw = pd.read_csv(path1) df_gw.drop(["Unnamed: 0"], axis=1, inplace=True) - + # Filter to only include valid basins (same as df_sw) df_gw = df_gw.iloc[valid_indices] # Use same valid_indices from above df_gw.reset_index(drop=True, inplace=True) - + df_gw.index = df_x["BCU_name"].index df_gw = df_gw.stack().reset_index() df_gw.columns = pd.Index(["Region", "years", "value"]) @@ -849,10 +849,12 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: ) df_sw = pd.read_csv(path1) df_sw.drop(["Unnamed: 0"], axis=1, inplace=True) - - # Filter to only include valid basins - full_basin_df = pd.read_csv(PATH) # Read full basin list again - valid_indices = full_basin_df[full_basin_df["BCU_name"].isin(context.valid_basins)].index + + # Filter rows to valid basins + full_basin_df = pd.read_csv(PATH) + valid_indices = full_basin_df[ + full_basin_df["BCU_name"].isin(context.valid_basins) + ].index df_sw = df_sw.iloc[valid_indices] df_sw.reset_index(drop=True, inplace=True) @@ -878,7 +880,7 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: ) df_gw = pd.read_csv(path1) df_gw.drop(["Unnamed: 0"], axis=1, inplace=True) - + # Filter to only include valid basins (same as df_sw) df_gw = df_gw.iloc[valid_indices] # Use same valid_indices from above df_gw.reset_index(drop=True, inplace=True) diff --git a/message_ix_models/model/water/data/infrastructure.py b/message_ix_models/model/water/data/infrastructure.py index 5953d27c25..169e398eed 100644 --- a/message_ix_models/model/water/data/infrastructure.py +++ b/message_ix_models/model/water/data/infrastructure.py @@ -291,10 +291,10 @@ def add_infrastructure_techs(context: "Context") -> dict[str, pd.DataFrame]: PATH = package_data_path("water", "delineation", FILE2) df_node = pd.read_csv(PATH) - + # Filter to only valid basins (already filtered in map_basin) df_node = df_node[df_node["BCU_name"].isin(context.valid_basins)] - + # Assigning proper nomenclature df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str) @@ -811,10 +811,10 @@ def add_desalination(context: "Context") -> dict[str, pd.DataFrame]: PATH = package_data_path("water", "delineation", FILE2) df_node = pd.read_csv(PATH) - + # Filter to only valid basins (already filtered in map_basin) df_node = df_node[df_node["BCU_name"].isin(context.valid_basins)] - + # Assigning proper nomenclature df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str) @@ -823,11 +823,11 @@ def add_desalination(context: "Context") -> dict[str, pd.DataFrame]: if context.type_reg == "country" else f"{context.regions}_" + df_node["REGION"].astype(str) ) - - # Filter historical and projected data to only include basins that exist after filtering + + # Filter to basins that exist after filtering df_hist = df_hist[df_hist["BCU_name"].isin(context.valid_basins)] df_proj = df_proj[df_proj["BCU_name"].isin(context.valid_basins)] - + # output dataframe linking to desal tech types out_df = ( make_df( diff --git a/message_ix_models/model/water/data/irrigation.py b/message_ix_models/model/water/data/irrigation.py index e6046a86a7..5626aa2537 100644 --- a/message_ix_models/model/water/data/irrigation.py +++ b/message_ix_models/model/water/data/irrigation.py @@ -32,10 +32,10 @@ def add_irr_structure(context: "Context") -> dict[str, pd.DataFrame]: FILE2 = f"basins_by_region_simpl_{context.regions}.csv" PATH = package_data_path("water", "delineation", FILE2) df_node = pd.read_csv(PATH) - + # Filter to only include valid basins df_node = df_node[df_node["BCU_name"].isin(context.valid_basins)] - + # Assigning proper nomenclature df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str) From a859132405d316f0440b8d24faba8a0e0118d1e4 Mon Sep 17 00:00:00 2001 From: Vignesh Raghunathan Date: Tue, 17 Feb 2026 17:47:49 +0100 Subject: [PATCH 06/10] Make filter_list additive on automatic basin selection filter_list now augments the automatic selection (first_k or stress) rather than replacing it. Docstring updated to document the two-step design. Test parametrized over both selection modes. --- message_ix_models/model/water/utils.py | 106 ++++++++---------- .../tests/model/water/test_utils.py | 37 +++++- 2 files changed, 84 insertions(+), 59 deletions(-) diff --git a/message_ix_models/model/water/utils.py b/message_ix_models/model/water/utils.py index 8e4141df58..9101c199e8 100644 --- a/message_ix_models/model/water/utils.py +++ b/message_ix_models/model/water/utils.py @@ -85,19 +85,34 @@ def filter_basins_by_region( ) -> pd.DataFrame: """Filter basins based on context configuration. + Selection is two-step: + + 1. **Automatic selection** — either ``"first_k"`` (head *n* per region) or + ``"stress"`` (diversity-sampled across the demand/supply ratio spectrum), + controlled by ``context.basin_selection``. + 2. **filter_list augmentation** — if ``context.filter_list`` is set, those + basins are *added* to the automatic set (union, not replacement). + Parameters ---------- df_basins : pd.DataFrame - DataFrame with basin data including 'REGION' and 'BCU_name' columns + DataFrame with basin data including 'REGION' and 'BCU_name' columns. context : Context, optional - Context object that may contain basin filtering configuration + Context object that may contain: + + - ``reduced_basin`` (bool): enable filtering (default False). + - ``basin_selection`` (str): ``"first_k"`` or ``"stress"`` + (default ``"first_k"``). + - ``num_basins`` (int): override *n_per_region*. + - ``filter_list`` (list[str]): additional BCU_name values to include + on top of the automatic selection. n_per_region : int, default 3 - Default number of basins to keep per region (used as fallback) + Default number of basins to keep per region (used as fallback). Returns ------- pd.DataFrame - Filtered DataFrame based on configuration + Filtered DataFrame based on configuration. """ if not context: context = Context.get_instance(-1) @@ -110,74 +125,51 @@ def filter_basins_by_region( log.info("Basin filtering disabled, returning all basins") return df_basins - # Basin filtering is enabled + # Basin filtering is enabled — run automatic selection, then augment with + # filter_list if provided. filter_list = getattr(context, "filter_list", None) num_basins = getattr(context, "num_basins", None) - - if filter_list: - # Filter to specific basin list - filtered = df_basins[df_basins["BCU_name"].isin(filter_list)] - - # Check if we have at least 1 basin per R12 region - all_regions = set(df_basins["REGION"].unique()) - filtered_regions = set(filtered["REGION"].unique()) - missing_regions = all_regions - filtered_regions - - if missing_regions: - log.info(f"Adding one basin per missing region: {missing_regions}") - # Add one basin from each missing region - for region in missing_regions: - region_basins = df_basins[df_basins["REGION"] == region] - # Add the first basin from this region - filtered = pd.concat( - [filtered, region_basins.head(1)], ignore_index=True - ) - - log.info( - f"Filtered basins from {len(df_basins)} to {len(filtered)} " - f"using custom filter list: {filter_list} (with 1 basin per missing region)" - ) - - return filtered.reset_index(drop=True) - - # Check for stress-based selection mode basin_selection = getattr(context, "basin_selection", "first_k") + # Step 1: automatic selection (stress or first_k) + if "REGION" not in df_basins.columns: + log.info("REGION column not found, cannot filter by region") + return df_basins + if basin_selection == "stress": n = num_basins if num_basins is not None else n_per_region ssp = getattr(context, "ssp", "SSP2") stress_df = compute_basin_demand_ratio(context.regions, ssp=ssp) selected = _select_by_stress(stress_df, n_per_region=n) - - filtered = df_basins[df_basins["BCU_name"].isin(selected)].reset_index( - drop=True - ) + filtered = df_basins[df_basins["BCU_name"].isin(selected)] log.info( f"Stress-based selection: {len(df_basins)} -> {len(filtered)} basins " f"(n_per_region={n})" ) - return filtered - - if num_basins is not None: - n_per_region = num_basins - - # Group by region and take first n rows from each group - if "REGION" not in df_basins.columns: - log.info("REGION column not found, cannot filter by region") - return df_basins - - filtered = ( - df_basins.groupby("REGION", group_keys=False) - .apply(lambda x: x.head(n_per_region)) - .reset_index(drop=True) - ) + else: + if num_basins is not None: + n_per_region = num_basins + filtered = df_basins.groupby("REGION", group_keys=False).apply( + lambda x: x.head(n_per_region) + ) + log.info( + f"first_k selection: {len(df_basins)} -> {len(filtered)} basins " + f"(n_per_region={n_per_region})" + ) - log.info( - f"Filtered basins from {len(df_basins)} to {len(filtered)} " - f"(keeping first {n_per_region} per region)" - ) + # Step 2: augment with filter_list (additive — union with automatic selection) + if filter_list: + extra = df_basins[ + df_basins["BCU_name"].isin(filter_list) + & ~df_basins["BCU_name"].isin(filtered["BCU_name"]) + ] + if len(extra): + log.info( + f"filter_list adds {len(extra)} basins on top of automatic selection" + ) + filtered = pd.concat([filtered, extra], ignore_index=True) - return filtered + return filtered.reset_index(drop=True) def compute_basin_demand_ratio( diff --git a/message_ix_models/tests/model/water/test_utils.py b/message_ix_models/tests/model/water/test_utils.py index 330535c8d8..af7c6288bb 100644 --- a/message_ix_models/tests/model/water/test_utils.py +++ b/message_ix_models/tests/model/water/test_utils.py @@ -140,12 +140,15 @@ def test_n_per_region_respected(self, stress_df): class TestFilterBasinsByRegionStress: """Test stress mode integration in filter_basins_by_region().""" - def test_stress_mode_returns_valid_output(self, test_context): + @pytest.fixture + def df_basins(self): from message_ix_models.util import package_data_path - df_basins = pd.read_csv( + return pd.read_csv( package_data_path("water", "delineation", "basins_by_region_simpl_R12.csv") ) + + def test_stress_mode_returns_valid_output(self, test_context, df_basins): test_context.reduced_basin = True test_context.basin_selection = "stress" test_context.regions = "R12" @@ -157,3 +160,33 @@ def test_stress_mode_returns_valid_output(self, test_context): assert len(filtered) < len(df_basins) assert len(filtered["REGION"].unique()) == 12 assert not filtered["BCU_name"].isna().any() + + @pytest.mark.parametrize("basin_selection", ["first_k", "stress"]) + def test_filter_list_additive_on_automatic( + self, test_context, df_basins, basin_selection + ): + """filter_list augments automatic selection rather than replacing it.""" + test_context.reduced_basin = True + test_context.basin_selection = basin_selection + test_context.regions = "R12" + test_context.ssp = "SSP2" + + # Run automatic selection alone (1 per region = 12 basins) + auto_only = filter_basins_by_region(df_basins, test_context, n_per_region=1) + auto_basins = set(auto_only["BCU_name"]) + + # Pick a basin NOT in the automatic set to add via filter_list + all_basins = set(df_basins["BCU_name"]) + extra_basins = list(all_basins - auto_basins)[:3] + assert len(extra_basins) > 0, "Need basins outside automatic set" + + test_context.filter_list = extra_basins + combined = filter_basins_by_region(df_basins, test_context, n_per_region=1) + combined_basins = set(combined["BCU_name"]) + + # Automatic basins are still present + assert auto_basins <= combined_basins + # Extra basins were added + assert set(extra_basins) <= combined_basins + # Total is the union + assert len(combined) == len(auto_basins) + len(extra_basins) From 8438738bfdbe4920a8352675be15c3d7f23e36d9 Mon Sep 17 00:00:00 2001 From: Vignesh Raghunathan Date: Thu, 26 Feb 2026 16:03:11 +0100 Subject: [PATCH 07/10] Fix pandas 3.0 compat in water data functions - utils.py: groupby().apply() drops groupby column in pandas 3.0; replace with cumcount() boolean mask - demands.py: cast StringDtype columns to object before xr.Dataset() - water_supply.py: filter delineation and e-flow CSVs to valid_basins in add_e_flow (mirrors read_water_availability pattern) --- message_ix_models/model/water/data/demands.py | 2 ++ message_ix_models/model/water/data/water_supply.py | 13 +++++++++---- message_ix_models/model/water/utils.py | 13 ++++++++++--- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/message_ix_models/model/water/data/demands.py b/message_ix_models/model/water/data/demands.py index fbacbff725..2bf9a1c140 100644 --- a/message_ix_models/model/water/data/demands.py +++ b/message_ix_models/model/water/data/demands.py @@ -201,6 +201,8 @@ def add_sectoral_demands(context: "Context") -> dict[str, pd.DataFrame]: for key, df in d.items(): df.rename(columns={"Unnamed: 0": "year"}, inplace=True) df.set_index("year", inplace=True) + # Cast column index from StringDtype to object for xarray compatibility + df.columns = df.columns.astype(object) dfs[key] = df # convert the dictionary of dataframes to xarray diff --git a/message_ix_models/model/water/data/water_supply.py b/message_ix_models/model/water/data/water_supply.py index a4d070a043..21e82ec0ad 100644 --- a/message_ix_models/model/water/data/water_supply.py +++ b/message_ix_models/model/water/data/water_supply.py @@ -1002,11 +1002,16 @@ def add_e_flow(context: "Context") -> dict[str, pd.DataFrame]: # Reading data, the data is spatially and temprally aggregated from GHMs df_sw, df_gw = read_water_availability(context) - # reading sample for assiging basins + # reading sample for assigning basins PATH = package_data_path( "water", "delineation", f"basins_by_region_simpl_{context.regions}.csv" ) - df_x = pd.read_csv(PATH) + df_x_full = pd.read_csv(PATH) + # Index positions of valid basins in the full CSV (for positional CSV filtering) + valid_indices = df_x_full[df_x_full["BCU_name"].isin(context.valid_basins)].index + df_x = df_x_full[df_x_full["BCU_name"].isin(context.valid_basins)].reset_index( + drop=True + ) dmd_df = make_df( "demand", @@ -1030,6 +1035,7 @@ def add_e_flow(context: "Context") -> dict[str, pd.DataFrame]: ) df_env = pd.read_csv(path1) df_env.drop(["Unnamed: 0"], axis=1, inplace=True) + df_env = df_env.iloc[valid_indices].reset_index(drop=True) df_env.index = df_x["BCU_name"].index df_env = df_env.stack().reset_index() df_env.columns = pd.Index(["Region", "years", "value"]) @@ -1052,8 +1058,7 @@ def add_e_flow(context: "Context") -> dict[str, pd.DataFrame]: ) df_env = pd.read_csv(path1) df_env.drop(["Unnamed: 0"], axis=1, inplace=True) - # new_cols = pd.to_datetime(df_env.columns, format="%Y/%m/%d") - # df_env.columns = new_cols + df_env = df_env.iloc[valid_indices].reset_index(drop=True) df_env.index = df_x["BCU_name"].index df_env = df_env.stack().reset_index() df_env.columns = pd.Index(["Region", "years", "value"]) diff --git a/message_ix_models/model/water/utils.py b/message_ix_models/model/water/utils.py index 9101c199e8..624b316bc0 100644 --- a/message_ix_models/model/water/utils.py +++ b/message_ix_models/model/water/utils.py @@ -131,6 +131,14 @@ def filter_basins_by_region( num_basins = getattr(context, "num_basins", None) basin_selection = getattr(context, "basin_selection", "first_k") + if num_basins is None: + log.info(f"num_basins not set, using default n_per_region={n_per_region}") + elif num_basins < 3: + log.warning( + f"num_basins={num_basins} is below 3; results may not capture " + f"sufficient basin diversity per region" + ) + # Step 1: automatic selection (stress or first_k) if "REGION" not in df_basins.columns: log.info("REGION column not found, cannot filter by region") @@ -149,9 +157,8 @@ def filter_basins_by_region( else: if num_basins is not None: n_per_region = num_basins - filtered = df_basins.groupby("REGION", group_keys=False).apply( - lambda x: x.head(n_per_region) - ) + mask = df_basins.groupby("REGION").cumcount() < n_per_region + filtered = df_basins[mask] log.info( f"first_k selection: {len(df_basins)} -> {len(filtered)} basins " f"(n_per_region={n_per_region})" From d220eb1b53e377ae7da7856aef709315b5159173 Mon Sep 17 00:00:00 2001 From: Vignesh Raghunathan Date: Thu, 26 Feb 2026 16:03:26 +0100 Subject: [PATCH 08/10] Refactor water tests: add water_params factory, reduced-basin coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Centralize region→type_reg mapping in water_params() factory. Replace all inline param dicts. Add reduced_basin=True variants across data tests. Refactor test_irrigation to use water_context fixture. --- .../tests/model/water/conftest.py | 30 ++++- .../tests/model/water/data/test_demands.py | 66 +++------ .../model/water/data/test_infrastructure.py | 27 ++-- .../tests/model/water/data/test_irrigation.py | 58 +++----- .../model/water/data/test_water_supply.py | 49 ++----- .../tests/model/water/test_utils.py | 125 +++++++----------- 6 files changed, 133 insertions(+), 222 deletions(-) diff --git a/message_ix_models/tests/model/water/conftest.py b/message_ix_models/tests/model/water/conftest.py index 4a47f28388..cb1d5253c7 100644 --- a/message_ix_models/tests/model/water/conftest.py +++ b/message_ix_models/tests/model/water/conftest.py @@ -8,6 +8,25 @@ from message_ix_models.model.structure import get_codes from message_ix_models.util import package_data_path +REGION_CONFIG = { + "R11": {"regions": "R11", "type_reg": "global"}, + "R12": {"regions": "R12", "type_reg": "global"}, + "ZMB": {"regions": "ZMB", "type_reg": "country"}, +} + + +def water_params(region, *, reduced_basin=False, **extra): + """Build a water_context param dict for the given region. + + Encodes the region→type_reg mapping once. When *reduced_basin* is True, + adds ``first_k`` selection with ``num_basins=2`` (24 basins total). + """ + base = {**REGION_CONFIG[region]} + if reduced_basin: + base.update(reduced_basin=True, basin_selection="first_k", num_basins=2) + base.update(extra) + return base + def setup_valid_basins(context, regions="R12"): """Set up valid_basins attribute for test contexts. @@ -56,7 +75,16 @@ def water_context(test_context, request): test_context.nexus_set = params.get("nexus_set", "nexus") # Optional attributes - for attr in ["RCP", "REL", "SDG", "ssp"]: + for attr in [ + "RCP", + "REL", + "SDG", + "ssp", + "reduced_basin", + "basin_selection", + "num_basins", + "filter_list", + ]: if attr in params: setattr(test_context, attr, params[attr]) diff --git a/message_ix_models/tests/model/water/data/test_demands.py b/message_ix_models/tests/model/water/data/test_demands.py index fe03839e1b..27e796ea67 100644 --- a/message_ix_models/tests/model/water/data/test_demands.py +++ b/message_ix_models/tests/model/water/data/test_demands.py @@ -6,18 +6,16 @@ add_sectoral_demands, add_water_availability, ) +from message_ix_models.tests.model.water.conftest import water_params @pytest.mark.parametrize( "water_context", [ - # Global R11 - {"regions": "R11", "type_reg": "global", "SDG": "baseline", "time": "year"}, - # Global R12 - {"regions": "R12", "type_reg": "global", "SDG": "baseline", "time": "year"}, - # Country ZMB - {"regions": "ZMB", "type_reg": "country", "SDG": "baseline", "time": "year"}, - # SDG="SDG" excluded: requires policy data files + water_params("R11", SDG="baseline", time="year"), + water_params("R12", SDG="baseline", time="year"), + water_params("ZMB", SDG="baseline", time="year"), + water_params("R12", reduced_basin=True, SDG="baseline", time="year"), ], indirect=True, ) @@ -55,44 +53,12 @@ def test_add_sectoral_demands(water_context, water_scenario, assert_message_para @pytest.mark.parametrize( "water_context", [ - # Global R11 (no monthly data) - { - "regions": "R11", - "type_reg": "global", - "RCP": "2p6", - "REL": "low", - "time": "year", - }, - # Global R12 (monthly exists for REL=low) - { - "regions": "R12", - "type_reg": "global", - "RCP": "2p6", - "REL": "low", - "time": "year", - }, - { - "regions": "R12", - "type_reg": "global", - "RCP": "2p6", - "REL": "low", - "time": "month", - }, - # Country ZMB (monthly exists for REL=low) - { - "regions": "ZMB", - "type_reg": "country", - "RCP": "2p6", - "REL": "low", - "time": "year", - }, - { - "regions": "ZMB", - "type_reg": "country", - "RCP": "2p6", - "REL": "low", - "time": "month", - }, + water_params("R11", RCP="2p6", REL="low", time="year"), + water_params("R12", RCP="2p6", REL="low", time="year"), + water_params("R12", RCP="2p6", REL="low", time="month"), + water_params("ZMB", RCP="2p6", REL="low", time="year"), + water_params("ZMB", RCP="2p6", REL="low", time="month"), + water_params("R12", reduced_basin=True, RCP="2p6", REL="low", time="year"), ], indirect=True, ) @@ -122,12 +88,10 @@ def test_add_water_availability(water_context, assert_message_params): @pytest.mark.parametrize( "water_context", [ - # Global R11 - {"regions": "R11", "type_reg": "global"}, - # Global R12 - {"regions": "R12", "type_reg": "global"}, - # Country ZMB - {"regions": "ZMB", "type_reg": "country"}, + water_params("R11"), + water_params("R12"), + water_params("ZMB"), + water_params("R12", reduced_basin=True), ], indirect=True, ) diff --git a/message_ix_models/tests/model/water/data/test_infrastructure.py b/message_ix_models/tests/model/water/data/test_infrastructure.py index 29847d7293..691bf86124 100644 --- a/message_ix_models/tests/model/water/data/test_infrastructure.py +++ b/message_ix_models/tests/model/water/data/test_infrastructure.py @@ -4,20 +4,19 @@ add_desalination, add_infrastructure_techs, ) +from message_ix_models.tests.model.water.conftest import water_params @pytest.mark.parametrize( "water_context", [ - # Global R11 - {"regions": "R11", "type_reg": "global", "SDG": "baseline"}, - {"regions": "R11", "type_reg": "global", "SDG": "not_baseline"}, - # Global R12 - {"regions": "R12", "type_reg": "global", "SDG": "baseline"}, - {"regions": "R12", "type_reg": "global", "SDG": "not_baseline"}, - # Country ZMB - {"regions": "ZMB", "type_reg": "country", "SDG": "baseline"}, - {"regions": "ZMB", "type_reg": "country", "SDG": "not_baseline"}, + water_params("R11", SDG="baseline"), + water_params("R11", SDG="not_baseline"), + water_params("R12", SDG="baseline"), + water_params("R12", SDG="not_baseline"), + water_params("ZMB", SDG="baseline"), + water_params("ZMB", SDG="not_baseline"), + water_params("R12", reduced_basin=True, SDG="baseline"), ], indirect=True, ) @@ -39,12 +38,10 @@ def test_add_infrastructure_techs( @pytest.mark.parametrize( "water_context", [ - # Global R11 (has 6p0) - {"regions": "R11", "type_reg": "global", "RCP": "6p0"}, - # Global R12 (no 6p0, use 7p0) - {"regions": "R12", "type_reg": "global", "RCP": "7p0"}, - # Country ZMB (no 6p0, use 7p0) - {"regions": "ZMB", "type_reg": "country", "RCP": "7p0"}, + water_params("R11", RCP="6p0"), + water_params("R12", RCP="7p0"), + water_params("ZMB", RCP="7p0"), + water_params("R12", reduced_basin=True, RCP="7p0"), ], indirect=True, ) diff --git a/message_ix_models/tests/model/water/data/test_irrigation.py b/message_ix_models/tests/model/water/data/test_irrigation.py index ab1ee98cfd..3df77860a0 100644 --- a/message_ix_models/tests/model/water/data/test_irrigation.py +++ b/message_ix_models/tests/model/water/data/test_irrigation.py @@ -1,46 +1,26 @@ -from message_ix import Scenario +import pytest -from message_ix_models import ScenarioInfo -from message_ix_models.model.structure import get_codes from message_ix_models.model.water.data.irrigation import add_irr_structure -from message_ix_models.tests.model.water.conftest import setup_valid_basins +from message_ix_models.tests.model.water.conftest import water_params + + +@pytest.mark.parametrize( + "water_context", + [ + water_params("ZMB"), + water_params("ZMB", reduced_basin=True), + ], + indirect=True, +) +def test_add_irr_structure( + water_context, water_scenario, assert_input_output_structure +): + """Test add_irr_structure with country model configurations.""" + result = add_irr_structure(water_context) - -def test_add_irr_structure(test_context): - # FIXME You probably want this to be part of a common setup rather than writing - # something like this for every test - test_context.type_reg = "country" - test_context.regions = "ZMB" - nodes = get_codes(f"node/{test_context.regions}") - nodes = list(map(str, nodes[nodes.index("World")].child)) - test_context.map_ISO_c = {test_context.regions: nodes[0]} - - mp = test_context.get_platform() - scenario_info = { - "mp": mp, - "model": "test water model", - "scenario": "test water scenario", - "version": "new", - } - s = Scenario(**scenario_info) - s.add_horizon(year=[2020, 2030, 2040]) - s.add_set("technology", ["tech1", "tech2"]) - s.add_set("node", ["loc1", "loc2"]) - s.add_set("year", [2020, 2030, 2040]) - - # FIXME same as above - test_context["water build info"] = ScenarioInfo(s) - - # Set up valid_basins for basin filtering - setup_valid_basins(test_context, regions=test_context.regions) - - # Call the function to be tested - result = add_irr_structure(test_context) - - # Assert the results assert isinstance(result, dict) - assert "input" in result - assert "output" in result + assert_input_output_structure(result) + assert all( col in result["input"].columns for col in [ diff --git a/message_ix_models/tests/model/water/data/test_water_supply.py b/message_ix_models/tests/model/water/data/test_water_supply.py index 1aae1299b0..c61bf476fb 100644 --- a/message_ix_models/tests/model/water/data/test_water_supply.py +++ b/message_ix_models/tests/model/water/data/test_water_supply.py @@ -7,17 +7,16 @@ add_water_supply, map_basin_region_wat, ) +from message_ix_models.tests.model.water.conftest import water_params @pytest.mark.parametrize( "water_context", [ - # Global R11 - {"regions": "R11", "type_reg": "global", "RCP": "2p6", "REL": "med"}, - # Global R12 - {"regions": "R12", "type_reg": "global", "RCP": "2p6", "REL": "med"}, - # Country ZMB - {"regions": "ZMB", "type_reg": "country", "RCP": "2p6", "REL": "med"}, + water_params("R11", RCP="2p6", REL="med"), + water_params("R12", RCP="2p6", REL="med"), + water_params("ZMB", RCP="2p6", REL="med"), + water_params("R12", reduced_basin=True, RCP="2p6", REL="med"), ], indirect=True, ) @@ -39,12 +38,10 @@ def test_map_basin_region_wat(water_context): @pytest.mark.parametrize( "water_context", [ - # Global R11 - {"regions": "R11", "type_reg": "global", "RCP": "2p6", "REL": "med"}, - # Global R12 - {"regions": "R12", "type_reg": "global", "RCP": "2p6", "REL": "med"}, - # Country ZMB - {"regions": "ZMB", "type_reg": "country", "RCP": "2p6", "REL": "med"}, + water_params("R11", RCP="2p6", REL="med"), + water_params("R12", RCP="2p6", REL="med"), + water_params("ZMB", RCP="2p6", REL="med"), + water_params("R12", reduced_basin=True, RCP="2p6", REL="med"), ], indirect=True, ) @@ -64,30 +61,10 @@ def test_add_water_supply(water_context, water_scenario, assert_message_params): @pytest.mark.parametrize( "water_context", [ - # Global R11 - { - "regions": "R11", - "type_reg": "global", - "RCP": "2p6", - "REL": "med", - "SDG": True, - }, - # Global R12 - { - "regions": "R12", - "type_reg": "global", - "RCP": "2p6", - "REL": "med", - "SDG": True, - }, - # Country ZMB - { - "regions": "ZMB", - "type_reg": "country", - "RCP": "2p6", - "REL": "med", - "SDG": True, - }, + water_params("R11", RCP="2p6", REL="med", SDG=True), + water_params("R12", RCP="2p6", REL="med", SDG=True), + water_params("ZMB", RCP="2p6", REL="med", SDG=True), + water_params("R12", reduced_basin=True, RCP="2p6", REL="med", SDG=True), ], indirect=True, ) diff --git a/message_ix_models/tests/model/water/test_utils.py b/message_ix_models/tests/model/water/test_utils.py index af7c6288bb..0bd9e9c881 100644 --- a/message_ix_models/tests/model/water/test_utils.py +++ b/message_ix_models/tests/model/water/test_utils.py @@ -78,34 +78,19 @@ def test_get_vintage_and_active_years( # --- Tests for stress-based basin selection --- -class TestComputeBasinDemandRatio: - """Tests for compute_basin_demand_ratio().""" - - def test_r12_shape_and_columns(self): - result = compute_basin_demand_ratio("R12") - assert len(result) == 217 - expected_cols = { - "BCU_name", - "REGION", - "supply_mcm", - "demand_mcm", - "demand_ratio", - } - assert expected_cols == set(result.columns) - - def test_no_nan_in_ratio(self): - result = compute_basin_demand_ratio("R12") - assert not result["demand_ratio"].isna().any() - - def test_all_regions_present(self): - result = compute_basin_demand_ratio("R12") - assert len(result["REGION"].unique()) == 12 - - def test_high_stress_basins_exist(self): - """At least some basins should have demand/supply > 10%.""" - result = compute_basin_demand_ratio("R12", demand_year=2050) - high_stress = result[result["demand_ratio"] > 0.10] - assert len(high_stress) > 0, "No high-stress basins found" +def test_compute_basin_demand_ratio(): + """Shape, columns, completeness, stress signal for basin demand ratio.""" + result = compute_basin_demand_ratio("R12") + assert len(result) == 217 + assert {"BCU_name", "REGION", "supply_mcm", "demand_mcm", "demand_ratio"} == set( + result.columns + ) + assert not result["demand_ratio"].isna().any() + assert len(result["REGION"].unique()) == 12 + + # Separate call — different args + result_2050 = compute_basin_demand_ratio("R12", demand_year=2050) + assert (result_2050["demand_ratio"] > 0.10).any(), "No high-stress basins found" class TestSelectByStress: @@ -137,56 +122,36 @@ def test_n_per_region_respected(self, stress_df): assert len(group) <= 2 -class TestFilterBasinsByRegionStress: - """Test stress mode integration in filter_basins_by_region().""" +@pytest.mark.parametrize("basin_selection", ["first_k", "stress"]) +def test_filter_list_additive_on_automatic(test_context, basin_selection): + """filter_list augments automatic selection rather than replacing it.""" + from message_ix_models.util import package_data_path - @pytest.fixture - def df_basins(self): - from message_ix_models.util import package_data_path - - return pd.read_csv( - package_data_path("water", "delineation", "basins_by_region_simpl_R12.csv") - ) - - def test_stress_mode_returns_valid_output(self, test_context, df_basins): - test_context.reduced_basin = True - test_context.basin_selection = "stress" - test_context.regions = "R12" - test_context.ssp = "SSP2" - - filtered = filter_basins_by_region(df_basins, test_context, n_per_region=2) - - assert len(filtered) > 0 - assert len(filtered) < len(df_basins) - assert len(filtered["REGION"].unique()) == 12 - assert not filtered["BCU_name"].isna().any() - - @pytest.mark.parametrize("basin_selection", ["first_k", "stress"]) - def test_filter_list_additive_on_automatic( - self, test_context, df_basins, basin_selection - ): - """filter_list augments automatic selection rather than replacing it.""" - test_context.reduced_basin = True - test_context.basin_selection = basin_selection - test_context.regions = "R12" - test_context.ssp = "SSP2" - - # Run automatic selection alone (1 per region = 12 basins) - auto_only = filter_basins_by_region(df_basins, test_context, n_per_region=1) - auto_basins = set(auto_only["BCU_name"]) - - # Pick a basin NOT in the automatic set to add via filter_list - all_basins = set(df_basins["BCU_name"]) - extra_basins = list(all_basins - auto_basins)[:3] - assert len(extra_basins) > 0, "Need basins outside automatic set" - - test_context.filter_list = extra_basins - combined = filter_basins_by_region(df_basins, test_context, n_per_region=1) - combined_basins = set(combined["BCU_name"]) - - # Automatic basins are still present - assert auto_basins <= combined_basins - # Extra basins were added - assert set(extra_basins) <= combined_basins - # Total is the union - assert len(combined) == len(auto_basins) + len(extra_basins) + df_basins = pd.read_csv( + package_data_path("water", "delineation", "basins_by_region_simpl_R12.csv") + ) + + test_context.reduced_basin = True + test_context.basin_selection = basin_selection + test_context.regions = "R12" + test_context.ssp = "SSP2" + + # Run automatic selection alone (1 per region = 12 basins) + auto_only = filter_basins_by_region(df_basins, test_context, n_per_region=1) + auto_basins = set(auto_only["BCU_name"]) + + # Pick a basin NOT in the automatic set to add via filter_list + all_basins = set(df_basins["BCU_name"]) + extra_basins = list(all_basins - auto_basins)[:3] + assert len(extra_basins) > 0, "Need basins outside automatic set" + + test_context.filter_list = extra_basins + combined = filter_basins_by_region(df_basins, test_context, n_per_region=1) + combined_basins = set(combined["BCU_name"]) + + # Automatic basins are still present + assert auto_basins <= combined_basins + # Extra basins were added + assert set(extra_basins) <= combined_basins + # Total is the union + assert len(combined) == len(auto_basins) + len(extra_basins) From 3eaf70fa4cb68004bb982b21b515a3efcdf8c8f3 Mon Sep 17 00:00:00 2001 From: Vignesh Raghunathan Date: Thu, 26 Feb 2026 16:03:35 +0100 Subject: [PATCH 09/10] Update docs and whatsnew for basin filtering --- doc/water/index.rst | 34 ++++++++++++++++++++++++++++------ doc/whatsnew.rst | 2 +- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/doc/water/index.rst b/doc/water/index.rst index 44fa18a1ad..f8a4a8e1b9 100644 --- a/doc/water/index.rst +++ b/doc/water/index.rst @@ -15,9 +15,12 @@ This work extends the water sector linkage described by Parkinson et al. (2019) CLI usage ========= -Use the :doc:`CLI ` command ``mix-data water`` to invoke the commands defined in :mod:`.water.cli`. Example: -``mix-models --url=ixmp://ixmp_dev/ENGAGE_SSP2_v4.1.7/baseline_clone_test water cooling`` -model and scenario specifications can be either set manually in ``cli.py`` or specified in the ``--url`` option +Use the :doc:`CLI ` command ``mix-models water-ix`` to invoke the commands defined in :mod:`.water.cli`. +Model and scenario specifications can be set via the ``--url`` option or in ``cli.py``. + +Example:: + + mix-models --url=ixmp://ixmp_dev/ENGAGE_SSP2_v4.1.7/baseline_clone_test water-ix nexus .. code:: @@ -37,18 +40,37 @@ model and scenario specifications can be either set manually in ``cli.py`` or sp nexus Add basin structure connected to the energy sector and water... report function to run the water report_full from cli to the scenario... +.. code:: + + Usage: mix-models water-ix nexus [OPTIONS] + + Options: + --rcps [no_climate|6p0|2p6|7p0] Climate scenario (default: no_climate). + --rels [low|med|high] Hydrological data reliability (default: low). + --sdgs TEXT Water SDG measures (default: baseline). + --macro Solve with MESSAGE-MACRO. + --reduced-basin / --no-reduced-basin + Enable basin filtering (default: off). + --basin-selection [first_k|stress] Automatic selection method (default: first_k). + first_k: head n basins per region in CSV order. + stress: sample across demand/supply ratio spectrum. + --num-basins INTEGER Basins per region (default: 3). + --filter-list TEXT Extra basins to add to the automatic selection + (repeatable). Final set is the union of automatic + selection and filter-list entries. + Country vs Global implementation -------------------------------- The :mod:`message_ix_models.model.water` is designed to being able to add water components to either a global R11 (or R12) model or any country model designed with `the MESSAGEix single country `_ model prototype. -For any of the region configuration a shapefile is needed to run the pre-processing part, while, once the data is prepared, only a .csv file similar to those in `message_ix_models.data.water.delineation` is needed. +For any of the region configuration a shapefile is needed to run the pre-processing part, while, once the data is prepared, only a .csv file similar to those in :file:`message_ix_models/data/water/delineation/` is needed. To work with a country model please ensure that: 1. country model and scenario are specified either in ``--url`` or in the ``cli.py`` script 2. the option ``--regions`` is used with the ISO3 code of the country (e.g. for Israel ``--regions=ISR``) -3. Following the Israel example add a 'country'.yaml file in `message_ix_models.data.node` for the specific country -4. Following the Israel example add the country ISO3 code in the 'regions' options in `message_ix_models.utils.click` +3. Following the Israel example add a 'country'.yaml file in :file:`message_ix_models/data/node/` for the specific country +4. Following the Israel example add the country ISO3 code in the 'regions' options in :mod:`message_ix_models.util.click` Annual vs sub-annual implementation ----------------------------------- diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst index 4ee9d0876c..179d3ab75b 100644 --- a/doc/whatsnew.rst +++ b/doc/whatsnew.rst @@ -11,7 +11,7 @@ Next release - Add reduced basin filtering for water module with ``--reduced-basin`` and demand/supply stress-based selection via ``--basin-selection stress`` - (:issue:`414`). + (:pull:`432`, :issue:`414`). - Fix water module parameter bugs and refactor cooling (:pull:`405`): infrastructure M1/Mf mode fixes, regional average shares for cooling allocation, From df6c5dc7e14f1684244460a496c6fe31992ad0b9 Mon Sep 17 00:00:00 2001 From: Vignesh Raghunathan Date: Thu, 26 Feb 2026 16:24:08 +0100 Subject: [PATCH 10/10] Add FIXME for indirect region derivation in water_for_ppl --- message_ix_models/model/water/data/water_for_ppl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/message_ix_models/model/water/data/water_for_ppl.py b/message_ix_models/model/water/data/water_for_ppl.py index 338c77d28a..a63438c9df 100644 --- a/message_ix_models/model/water/data/water_for_ppl.py +++ b/message_ix_models/model/water/data/water_for_ppl.py @@ -54,7 +54,7 @@ def _load_scenario_and_cooling_data( "water", "delineation", f"basins_by_region_simpl_{context.regions}.csv" ) - # Load basin delineation + # FIXME Derive node_region from scenario/codelist rather than basin CSV df_node = pd.read_csv(basin_path) df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str)