diff --git a/doc/water/index.rst b/doc/water/index.rst index f8a4a8e1b9..a5ad94f021 100644 --- a/doc/water/index.rst +++ b/doc/water/index.rst @@ -96,6 +96,9 @@ Build and run .. automodule:: message_ix_models.model.water.build :members: +.. automodule:: message_ix_models.model.water.config + :members: + Data preparation ---------------- diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst index ec052a8413..c42d62c50a 100644 --- a/doc/whatsnew.rst +++ b/doc/whatsnew.rst @@ -6,6 +6,7 @@ Next release - Add IAMC code list :class:`~.iamc.structure.CL_SCENARIO_DIAGNOSTIC` (:pull:`501`). - New module :ref:`tools-newclimate` (:pull:`499`). +- Add :class:`.model.water.Config` to collect water module settings (:pull:`509`). - Add :doc:`/api/model-bmt` (:pull:`433`). - Add diff --git a/message_ix_models/model/water/__init__.py b/message_ix_models/model/water/__init__.py index 46583ecf73..9bdedaf45f 100644 --- a/message_ix_models/model/water/__init__.py +++ b/message_ix_models/model/water/__init__.py @@ -1,4 +1,5 @@ +from .config import Config from .data import demands, water_supply from .utils import read_config -__all__ = ["demands", "read_config", "water_supply"] +__all__ = ["Config", "demands", "read_config", "water_supply"] diff --git a/message_ix_models/model/water/build.py b/message_ix_models/model/water/build.py index 795b65a397..ebb2989b05 100644 --- a/message_ix_models/model/water/build.py +++ b/message_ix_models/model/water/build.py @@ -10,6 +10,7 @@ from message_ix_models.model.structure import get_codes from message_ix_models.util import broadcast, package_data_path +from .config import Config from .utils import filter_basins_by_region, read_config log = logging.getLogger(__name__) @@ -184,9 +185,10 @@ def cat_tec_cooling_calib( - 'tec': Name of the cooling technology. - regions_df: A list of unique region nodes from the scenario. """ + cfg = Config.from_context(context) FILE1 = ( "cooltech_cost_and_shares_" - + (f"ssp_msg_{context.regions}" if context.type_reg == "global" else "country") + + (f"ssp_msg_{context.regions}" if cfg.type_reg == "global" else "country") + ".csv" ) path1 = package_data_path("water", "ppl_cooling_tech", FILE1) @@ -288,7 +290,9 @@ def get_spec(context: Context) -> Mapping[str, ScenarioInfo]: # Elements to add add.set[set_name].extend(config.get("add", [])) - if context.nexus_set == "nexus": + cfg = Config.from_context(context) + + if cfg.nexus_set == "nexus": # Merge technology.yaml with set.yaml context["water set"]["nexus"]["technology"]["add"] = context[ "water technology" @@ -311,7 +315,7 @@ def get_spec(context: Context) -> Mapping[str, ScenarioInfo]: # Share commodity for groundwater results = {} - df_node = context.all_nodes + df_node = cfg.all_nodes n = len(df_node.values) d = { @@ -550,6 +554,7 @@ def map_basin(context: Context) -> Mapping[str, ScenarioInfo]: # define an empty dictionary results = {} + cfg = Config.from_context(context) # read csv file for basin names and region mapping # reading basin_delineation FILE = f"basins_by_region_simpl_{context.regions}.csv" @@ -564,8 +569,8 @@ def map_basin(context: Context) -> Mapping[str, ScenarioInfo]: df["node"] = "B" + df["BCU_name"].astype(str) df["mode"] = "M" + df["BCU_name"].astype(str) df["region"] = ( - context.map_ISO_c[context.regions] - if context.type_reg == "country" + cfg.map_ISO_c[context.regions] + if cfg.type_reg == "country" else f"{context.regions}_" + df["REGION"].astype(str) ) @@ -580,9 +585,9 @@ def map_basin(context: Context) -> Mapping[str, ScenarioInfo]: results["map_node"] = nodes - context.all_nodes = df["node"] + cfg.all_nodes = df["node"] # Store the filtered basin names for use in other functions - context.valid_basins = set(df["BCU_name"].astype(str)) + cfg.valid_basins = set(df["BCU_name"].astype(str)) for set_name, config in results.items(): # Sets to add @@ -604,7 +609,7 @@ def main(context: Context, scenario, **options): log.info("Set up MESSAGEix-Nexus") - if context.nexus_set == "nexus": + if Config.from_context(context).nexus_set == "nexus": # Add water balance spec = map_basin(context) diff --git a/message_ix_models/model/water/cli.py b/message_ix_models/model/water/cli.py index d85b3446fb..b0d6709d3f 100644 --- a/message_ix_models/model/water/cli.py +++ b/message_ix_models/model/water/cli.py @@ -8,6 +8,8 @@ from message_ix_models.model.structure import get_codes from message_ix_models.util.click import common_params, scenario_param +from .config import Config + if TYPE_CHECKING: from message_ix import Scenario log = logging.getLogger(__name__) @@ -39,6 +41,8 @@ def water_ini(context: "Context", regions, time): from .utils import read_config + config = Config.from_context(context) + # Ensure water model configuration is loaded read_config(context) if not context.scenario_info: @@ -61,9 +65,9 @@ def water_ini(context: "Context", regions, time): regions = "R12" # add an attribute to distinguish country models if regions in ["R11", "R12", "R14", "R32", "RCP"]: - context.type_reg = "global" + config.type_reg = "global" else: - context.type_reg = "country" + config.type_reg = "country" context.regions = regions # create a mapping ISO code : @@ -71,10 +75,10 @@ def water_ini(context: "Context", regions, time): # only needed for 1-country models n_codes = get_codes(f"node/{context.regions}") nodes = list(map(str, n_codes[n_codes.index(Code(id="World"))].child)) - if context.type_reg == "country": + if config.type_reg == "country": map_ISO_c = {context.regions: nodes[0]} - context.map_ISO_c = map_ISO_c - log.info(f"mapping {context.map_ISO_c[context.regions]}") + config.map_ISO_c = map_ISO_c + log.info(f"mapping {config.map_ISO_c[context.regions]}") # deinfe the timestep if not time: @@ -82,13 +86,13 @@ def water_ini(context: "Context", regions, time): time = sc_ref.set("time") sub_time = list(time[time != "year"]) if len(sub_time) == 0: - context.time = ["year"] + config.time = ["year"] else: - context.time = sub_time + config.time = sub_time else: - context.time = [time] - log.info(f"Using the following time-step for the water module: {context.time}") + config.time = [time] + log.info(f"Using the following time-step for the water module: {config.time}") # setting the time information in context @@ -163,12 +167,11 @@ def nexus_cli( water balance linking different water demands to supply. """ # Set basin filtering configuration on context - context.reduced_basin = reduced_basin - if filter_list: - context.filter_list = list(filter_list) - if num_basins is not None: - context.num_basins = num_basins - context.basin_selection = basin_selection + config = Config.from_context(context) + config.reduced_basin = reduced_basin + config.filter_list = list(filter_list or []) + config.num_basins = num_basins + config.basin_selection = basin_selection nexus(context, regions, rcps, sdgs, rels, macro) @@ -193,17 +196,18 @@ def nexus(context: "Context", regions, rcps, sdgs, rels, macro=False): Specifies the reliability of hydrological data ['low','mid','high'] """ # add input information to the class context - context.nexus_set = "nexus" + config = Config.from_context(context) + config.nexus_set = "nexus" if not context.regions: context.regions = regions - context.RCP = rcps - context.SDG = sdgs - context.REL = rels + config.RCP = rcps + config.SDG = sdgs + config.REL = rels log.info( - f"SSP assumption is {context.ssp}. SDG is {context.SDG}. " - f"RCP is {context.RCP}. REL is {context.REL}." + f"SSP assumption is {context.ssp}. SDG is {config.SDG}. " + f"RCP is {config.RCP}. REL is {config.REL}." ) from .build import main as build @@ -298,12 +302,13 @@ def cooling( Specifies the climate scenario used ['no_climate','6p0','2p6'] """ - context.nexus_set = "cooling" - context.RCP = rcps - context.REL = rels + config = Config.from_context(context) + config.nexus_set = "cooling" + config.RCP = rcps + config.REL = rels log.info( - f"SSP assumption is {context.ssp}. RCP is {context.RCP}. REL is {context.REL}." + f"SSP assumption is {context.ssp}. RCP is {config.RCP}. REL is {config.REL}." ) from .build import main as build diff --git a/message_ix_models/model/water/config.py b/message_ix_models/model/water/config.py new file mode 100644 index 0000000000..251bd7e089 --- /dev/null +++ b/message_ix_models/model/water/config.py @@ -0,0 +1,65 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal + +from message_ix_models.util.config import ConfigHelper + +if TYPE_CHECKING: + from message_ix_models import Context + + +@dataclass +class Config(ConfigHelper): + """Settings for :mod:`message_ix_models.model.water`.""" + + #: Water build variant. + nexus_set: Literal["cooling", "nexus"] = "nexus" + + #: Climate scenario used for water availability and cooling data. + RCP: str = "no_climate" + + #: Water SDG policy setting. + SDG: str = "baseline" + + #: Hydrological-data reliability setting. + REL: str = "low" + + #: Time slices handled by the water module. + time: list[str] = field(default_factory=lambda: ["year"]) + + #: Whether :attr:`regions` names a global node codelist or a single country. + type_reg: Literal["country", "global"] = "global" + + #: Single-country model mapping from region code to ISO code. + map_ISO_c: dict[str, str] = field(default_factory=dict) + + #: Enable basin filtering. + reduced_basin: bool = False + + #: Basins to add to the automatic reduced-basin selection. + filter_list: list[str] = field(default_factory=list) + + #: Number of basins per region to keep when :attr:`reduced_basin` is enabled. + num_basins: int | None = None + + #: Automatic reduced-basin selection method. + basin_selection: Literal["first_k", "stress"] = "first_k" + + #: Basin names retained after optional filtering. + valid_basins: set[str] = field(default_factory=set) + + #: Water-module basin node labels, populated during structural mapping. + all_nodes: Any = None + + @classmethod + def from_context(cls, context: "Context") -> "Config": + """Return the shared water configuration for `context`. + + Create `context.water` if missing, or convert a mapping to `Config`. + Repeated calls return the same `Config` instance. + """ + if "water" not in context: + context["water"] = cls() + elif isinstance(context["water"], dict): + context["water"] = cls.from_dict(context["water"]) + + return context["water"] diff --git a/message_ix_models/model/water/data/__init__.py b/message_ix_models/model/water/data/__init__.py index d118c174a5..9463aad062 100644 --- a/message_ix_models/model/water/data/__init__.py +++ b/message_ix_models/model/water/data/__init__.py @@ -9,6 +9,7 @@ import pandas as pd from message_ix_models import ScenarioInfo +from message_ix_models.model.water.config import Config from message_ix_models.util import add_par_data from .demands import add_irrigation_demand, add_sectoral_demands, add_water_availability @@ -57,12 +58,13 @@ def add_data(scenario, context: "Context", dry_run=False): info = ScenarioInfo(scenario) context["water build info"] = info + cfg = Config.from_context(context) data_funcs: list[DataFunc] = ( [add_water_supply, cool_tech, non_cooling_tec] - if context.nexus_set == "cooling" + if cfg.nexus_set == "cooling" else DATA_FUNCTIONS - if context.type_reg == "global" + if cfg.type_reg == "global" else DATA_FUNCTIONS_COUNTRY ) diff --git a/message_ix_models/model/water/data/demands.py b/message_ix_models/model/water/data/demands.py index 2bf9a1c140..6d4c42a065 100644 --- a/message_ix_models/model/water/data/demands.py +++ b/message_ix_models/model/water/data/demands.py @@ -9,6 +9,7 @@ import xarray as xr from message_ix import make_df +from message_ix_models.model.water.config import Config from message_ix_models.model.water.utils import KM3_TO_MCM from message_ix_models.util import broadcast, package_data_path @@ -178,12 +179,13 @@ def add_sectoral_demands(context: "Context") -> dict[str, pd.DataFrame]: results = {} # Reference to the water configuration + cfg = Config.from_context(context) info = context["water build info"] year_vtgs = tuple(range(2010, info.Y[0], 5)) # defines path to read in demand data region = f"{context.regions}" - sub_time = context.time + sub_time = cfg.time path = package_data_path("water", "demands", "harmonized", region, ".") # make sure all of the csvs have format, otherwise it might not work list_of_csvs = list(path.glob("ssp2_regional_*.csv")) @@ -220,14 +222,14 @@ def add_sectoral_demands(context: "Context") -> dict[str, pd.DataFrame]: df_dmds["time"] = "year" # Filter to only include basins that exist after basin filtering - df_dmds = df_dmds[df_dmds["node"].isin(context.valid_basins)] + df_dmds = df_dmds[df_dmds["node"].isin(cfg.valid_basins)] # Write final interpolated values as csv # df2_f.to_csv('final_interpolated_values.csv') # if we are using sub-annual timesteps we replace the rural and municipal # withdrawals and return flows with monthly data and also add industrial - if "year" not in context.time: + if "year" not in cfg.time: PATH = package_data_path( "water", "demands", "harmonized", region, "ssp2_m_water_demands.csv" ) @@ -245,7 +247,7 @@ def add_sectoral_demands(context: "Context") -> dict[str, pd.DataFrame]: df_m.columns = pd.Index(["year", "node", "variable", "value", "time"]) # Filter monthly data to only include valid basins - df_m = df_m[df_m["node"].isin(context.valid_basins)] + df_m = df_m[df_m["node"].isin(cfg.valid_basins)] # remove yearly parts from df_dms df_dmds = df_dmds[ @@ -307,9 +309,9 @@ def add_sectoral_demands(context: "Context") -> dict[str, pd.DataFrame]: ] ) - if context.SDG != "baseline": + if cfg.SDG != "baseline": # only if SDG exactly equal to SDG, otherwise other policies are possible - pol_scen = context.SDG + pol_scen = cfg.SDG if pol_scen == "SDG": # reading basin mapping to countries FILE2 = f"basins_country_{context.regions}.csv" @@ -769,6 +771,7 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: """ # Reference to the water configuration + cfg = Config.from_context(context) info = context["water build info"] # reading sample for assiging basins PATH = package_data_path( @@ -777,15 +780,15 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: df_x = pd.read_csv(PATH) # Filter to only include valid basins - df_x = df_x[df_x["BCU_name"].isin(context.valid_basins)] + df_x = df_x[df_x["BCU_name"].isin(cfg.valid_basins)] - if "year" in context.time: + if "year" in cfg.time: # Adding freshwater supply constraints # Reading data, the data is spatially and temprally aggregated from GHMs path1 = package_data_path( "water", "availability", - f"qtot_5y_{context.RCP}_{context.REL}_{context.regions}.csv", + f"qtot_5y_{cfg.RCP}_{cfg.REL}_{context.regions}.csv", ) # Read rcp 2.6 data df_sw = pd.read_csv(path1) @@ -794,7 +797,7 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: # Filter rows to valid basins using index positions from full list full_basin_df = pd.read_csv(PATH) valid_indices = full_basin_df[ - full_basin_df["BCU_name"].isin(context.valid_basins) + full_basin_df["BCU_name"].isin(cfg.valid_basins) ].index df_sw = df_sw.iloc[valid_indices] # Keep only rows for valid basins df_sw.reset_index(drop=True, inplace=True) @@ -817,7 +820,7 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: path1 = package_data_path( "water", "availability", - f"qr_5y_{context.RCP}_{context.REL}_{context.regions}.csv", + f"qr_5y_{cfg.RCP}_{cfg.REL}_{context.regions}.csv", ) # Read groundwater data @@ -847,7 +850,7 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: path1 = package_data_path( "water", "availability", - f"qtot_5y_m_{context.RCP}_{context.REL}_{context.regions}.csv", + f"qtot_5y_m_{cfg.RCP}_{cfg.REL}_{context.regions}.csv", ) df_sw = pd.read_csv(path1) df_sw.drop(["Unnamed: 0"], axis=1, inplace=True) @@ -855,7 +858,7 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: # Filter rows to valid basins full_basin_df = pd.read_csv(PATH) valid_indices = full_basin_df[ - full_basin_df["BCU_name"].isin(context.valid_basins) + full_basin_df["BCU_name"].isin(cfg.valid_basins) ].index df_sw = df_sw.iloc[valid_indices] df_sw.reset_index(drop=True, inplace=True) @@ -878,7 +881,7 @@ def read_water_availability(context: "Context") -> Sequence[pd.DataFrame]: path1 = package_data_path( "water", "availability", - f"qr_5y_m_{context.RCP}_{context.REL}_{context.regions}.csv", + f"qr_5y_m_{cfg.RCP}_{cfg.REL}_{context.regions}.csv", ) df_gw = pd.read_csv(path1) df_gw.drop(["Unnamed: 0"], axis=1, inplace=True) diff --git a/message_ix_models/model/water/data/infrastructure.py b/message_ix_models/model/water/data/infrastructure.py index 169e398eed..d8051348ec 100644 --- a/message_ix_models/model/water/data/infrastructure.py +++ b/message_ix_models/model/water/data/infrastructure.py @@ -9,6 +9,7 @@ from message_ix import make_df from message_ix_models import Context, ScenarioInfo +from message_ix_models.model.water.config import Config from message_ix_models.model.water.utils import ( ANNUAL_CAPACITY_FACTOR, KM3_TO_MCM, @@ -273,11 +274,12 @@ def add_infrastructure_techs(context: "Context") -> dict[str, pd.DataFrame]: ``context["water build info"]``, plus the additional year 2010. """ # Reference to the water configuration + cfg = Config.from_context(context) info = context["water build info"] # define an empty dictionary results = {} - sub_time = pd.Series(context.time) + sub_time = pd.Series(cfg.time) # load the scenario from context scen = context.get_scenario() @@ -293,14 +295,14 @@ def add_infrastructure_techs(context: "Context") -> dict[str, pd.DataFrame]: df_node = pd.read_csv(PATH) # Filter to only valid basins (already filtered in map_basin) - df_node = df_node[df_node["BCU_name"].isin(context.valid_basins)] + df_node = df_node[df_node["BCU_name"].isin(cfg.valid_basins)] # Assigning proper nomenclature df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str) df_node["region"] = ( - context.map_ISO_c[context.regions] - if context.type_reg == "country" + cfg.map_ISO_c[context.regions] + if cfg.type_reg == "country" else f"{context.regions}_" + df_node["REGION"].astype(str) ) @@ -322,7 +324,7 @@ def add_infrastructure_techs(context: "Context") -> dict[str, pd.DataFrame]: df_elec = df[df["incmd"] == "electr"].reset_index() inp_df = start_creating_input_dataframe( - sdg=context.SDG, + sdg=cfg.SDG, df_node=df_node, df_non_elec=df_non_elec, df_dist=df_dist, @@ -383,9 +385,7 @@ def add_infrastructure_techs(context: "Context") -> dict[str, pd.DataFrame]: ) # Process distribution outputs using helper function - dist_out = _make_dist_output( - df_out_dist, scenario_info, df_node, sub_time, context.SDG - ) + dist_out = _make_dist_output(df_out_dist, scenario_info, df_node, sub_time, cfg.SDG) out_df = pd.concat([out_df, dist_out]) if not dist_out.empty else out_df results["output"] = out_df @@ -496,7 +496,7 @@ def add_infrastructure_techs(context: "Context") -> dict[str, pd.DataFrame]: df_var = df_inv[~df_inv["tec"].isin(techs)] df_var_dist = df_inv[df_inv["tec"].isin(techs)] - if context.SDG != "baseline": + if cfg.SDG != "baseline": for index, rows in df_var.iterrows(): # Check if this is a dummy technology use_same_year = is_dummy_technology(rows) @@ -646,6 +646,7 @@ def prepare_input_dataframe( df_elec: pd.DataFrame, ) -> defaultdict[Any, list]: result_dc = defaultdict(list) + cfg = Config.from_context(context) # Unit 1 KWh/m^3 = 10^3 GWh/Km^3 = 1 GWh/MCM, # Parkinson et al. # which is the only explanation as to how the model solved. @@ -654,7 +655,7 @@ def prepare_input_dataframe( # Check if this is a dummy technology (for distribution techs) use_same_year = is_dummy_technology(rows) - if context.SDG != "baseline": + if cfg.SDG != "baseline": inp = make_df( "input", technology=rows["tec"], @@ -774,7 +775,8 @@ def add_desalination(context: "Context") -> dict[str, pd.DataFrame]: """ # define an empty dictionary results = {} - sub_time = pd.Series(context.time) + cfg = Config.from_context(context) + sub_time = pd.Series(cfg.time) # Reference to the water configuration info = context["water build info"] @@ -801,7 +803,7 @@ def add_desalination(context: "Context") -> dict[str, pd.DataFrame]: df_desal = pd.read_csv(path) df_hist = pd.read_csv(path2) df_proj = pd.read_csv(path3) - df_proj = df_proj[df_proj["rcp"] == f"{context.RCP}"] + df_proj = df_proj[df_proj["rcp"] == f"{cfg.RCP}"] df_proj = df_proj[~(df_proj["year"] == 2065) & ~(df_proj["year"] == 2075)] df_proj.reset_index(inplace=True, drop=True) df_proj = df_proj[df_proj["year"].isin(info.Y)] @@ -813,20 +815,20 @@ def add_desalination(context: "Context") -> dict[str, pd.DataFrame]: df_node = pd.read_csv(PATH) # Filter to only valid basins (already filtered in map_basin) - df_node = df_node[df_node["BCU_name"].isin(context.valid_basins)] + df_node = df_node[df_node["BCU_name"].isin(cfg.valid_basins)] # Assigning proper nomenclature df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str) df_node["region"] = ( - context.map_ISO_c[context.regions] - if context.type_reg == "country" + cfg.map_ISO_c[context.regions] + if cfg.type_reg == "country" else f"{context.regions}_" + df_node["REGION"].astype(str) ) # Filter to basins that exist after filtering - df_hist = df_hist[df_hist["BCU_name"].isin(context.valid_basins)] - df_proj = df_proj[df_proj["BCU_name"].isin(context.valid_basins)] + df_hist = df_hist[df_hist["BCU_name"].isin(cfg.valid_basins)] + df_proj = df_proj[df_proj["BCU_name"].isin(cfg.valid_basins)] # output dataframe linking to desal tech types out_df = ( diff --git a/message_ix_models/model/water/data/irrigation.py b/message_ix_models/model/water/data/irrigation.py index 5626aa2537..48c2f69a18 100644 --- a/message_ix_models/model/water/data/irrigation.py +++ b/message_ix_models/model/water/data/irrigation.py @@ -4,6 +4,7 @@ from message_ix import make_df from message_ix_models import Context +from message_ix_models.model.water.config import Config from message_ix_models.util import broadcast, package_data_path @@ -27,6 +28,7 @@ def add_irr_structure(context: "Context") -> dict[str, pd.DataFrame]: # define an empty dictionary results = {} + cfg = Config.from_context(context) # reading basin_delineation FILE2 = f"basins_by_region_simpl_{context.regions}.csv" @@ -34,14 +36,14 @@ def add_irr_structure(context: "Context") -> dict[str, pd.DataFrame]: df_node = pd.read_csv(PATH) # Filter to only include valid basins - df_node = df_node[df_node["BCU_name"].isin(context.valid_basins)] + df_node = df_node[df_node["BCU_name"].isin(cfg.valid_basins)] # Assigning proper nomenclature df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str) df_node["region"] = ( - context.map_ISO_c[context.regions] - if context.type_reg == "country" + cfg.map_ISO_c[context.regions] + if cfg.type_reg == "country" else f"{context.regions}_" + df_node["REGION"].astype(str) ) diff --git a/message_ix_models/model/water/data/water_for_ppl.py b/message_ix_models/model/water/data/water_for_ppl.py index a63438c9df..8cb84c9014 100644 --- a/message_ix_models/model/water/data/water_for_ppl.py +++ b/message_ix_models/model/water/data/water_for_ppl.py @@ -9,6 +9,7 @@ from message_ix import make_df from message_ix_models import Context +from message_ix_models.model.water.config import Config from message_ix_models.model.water.data.water_supply import map_basin_region_wat from message_ix_models.model.water.utils import get_vintage_and_active_years from message_ix_models.util import ( @@ -43,11 +44,12 @@ def _load_scenario_and_cooling_data( scen : Scenario object cost_share_df : DataFrame with cost and share data """ + cfg = Config.from_context(context) # File paths tech_perf_path = package_data_path( "water", "ppl_cooling_tech", "tech_water_performance_ssp_msg.csv" ) - suffix = "ssp_msg_" + context.regions if context.type_reg == "global" else "country" + suffix = "ssp_msg_" + context.regions if cfg.type_reg == "global" else "country" cost_share_file = f"cooltech_cost_and_shares_{suffix}.csv" cost_share_path = package_data_path("water", "ppl_cooling_tech", cost_share_file) basin_path = package_data_path( @@ -59,8 +61,8 @@ def _load_scenario_and_cooling_data( df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str) df_node["region"] = ( - context.map_ISO_c[context.regions] - if context.type_reg == "country" + cfg.map_ISO_c[context.regions] + if cfg.type_reg == "country" else f"{context.regions}_" + df_node["REGION"].astype(str) ) node_region = df_node["region"].unique() @@ -152,7 +154,8 @@ def _compute_cooling_rates(input_cool: pd.DataFrame) -> pd.DataFrame: def _make_input_params(input_cool: pd.DataFrame, context: "Context") -> pd.DataFrame: """Generate input parameter DataFrame.""" - commodity = "surfacewater" if context.nexus_set == "nexus" else "freshwater" + cfg = Config.from_context(context) + commodity = "surfacewater" if cfg.nexus_set == "nexus" else "freshwater" # Electricity input for parasitic demand electr = input_cool[input_cool["parasitic_electricity_demand_fraction"] > 0].copy() @@ -325,18 +328,19 @@ def _make_technical_lifetime( def _make_capacity_factor(inp: pd.DataFrame, context: "Context") -> pd.DataFrame: """Generate capacity_factor parameter with optional climate impacts.""" + cfg = Config.from_context(context) cap_fact = make_matched_dfs(inp, capacity_factor=1)["capacity_factor"] cap_fact["unit"] = "-" # capacity_factor is dimensionless cap_fact = cap_fact.drop_duplicates().reset_index(drop=True) - if context.RCP == "no_climate": + if cfg.RCP == "no_climate": return cap_fact # Apply climate impacts on freshwater cooling impact_path = package_data_path( "water", "ppl_cooling_tech", - f"power_plant_cooling_impact_MESSAGE_{context.regions}_{context.RCP}.csv", + f"power_plant_cooling_impact_MESSAGE_{context.regions}_{cfg.RCP}.csv", ) df_impact = pd.read_csv(impact_path) @@ -569,7 +573,7 @@ def _add_nexus_params( results: dict, context: "Context", node_region: list, info ) -> None: """Add basin-region distribution for nexus mode.""" - if context.nexus_set != "nexus": + if Config.from_context(context).nexus_set != "nexus": return df_sw = map_basin_region_wat(context) @@ -797,7 +801,11 @@ def non_cooling_tec(context: "Context", scenario=None) -> dict[str, pd.DataFrame merged = output_data.merge(non_cool, on="technology", how="right").dropna() - commodity = "surfacewater" if context.nexus_set == "nexus" else "freshwater" + commodity = ( + "surfacewater" + if Config.from_context(context).nexus_set == "nexus" + else "freshwater" + ) inp = make_df( "input", diff --git a/message_ix_models/model/water/data/water_supply.py b/message_ix_models/model/water/data/water_supply.py index 21e82ec0ad..707c084e71 100644 --- a/message_ix_models/model/water/data/water_supply.py +++ b/message_ix_models/model/water/data/water_supply.py @@ -5,6 +5,7 @@ from message_ix import make_df from message_ix_models import Context, ScenarioInfo +from message_ix_models.model.water.config import Config from message_ix_models.model.water.data.demands import read_water_availability from message_ix_models.model.water.utils import ( ANNUAL_CAPACITY_FACTOR, @@ -36,16 +37,17 @@ def map_basin_region_wat(context: "Context") -> pd.DataFrame: ------- data : pandas.DataFrame """ + cfg = Config.from_context(context) info = context["water build info"] - if "year" in context.time: + if "year" in cfg.time: PATH = package_data_path( "water", "delineation", f"basins_by_region_simpl_{context.regions}.csv" ) df_x_full = pd.read_csv(PATH) # Get positional indices of valid basins from the unfiltered list - valid_mask = df_x_full["BCU_name"].isin(context.valid_basins) + valid_mask = df_x_full["BCU_name"].isin(cfg.valid_basins) valid_indices = df_x_full[valid_mask].index df_x = df_x_full[valid_mask].reset_index(drop=True) @@ -54,7 +56,7 @@ def map_basin_region_wat(context: "Context") -> pd.DataFrame: path1 = package_data_path( "water", "availability", - f"qtot_5y_{context.RCP}_{context.REL}_{context.regions}.csv", + f"qtot_5y_{cfg.RCP}_{cfg.REL}_{context.regions}.csv", ) df_sw = pd.read_csv(path1) @@ -64,8 +66,8 @@ def map_basin_region_wat(context: "Context") -> pd.DataFrame: df_sw = df_sw.iloc[valid_indices].reset_index(drop=True) df_sw["BCU_name"] = df_x["BCU_name"] df_sw["MSGREG"] = ( - context.map_ISO_c[context.regions] - if context.type_reg == "country" + cfg.map_ISO_c[context.regions] + if cfg.type_reg == "country" else f"{context.regions}_" + df_sw["BCU_name"].str.split("|").str[-1] ) @@ -93,7 +95,7 @@ def map_basin_region_wat(context: "Context") -> pd.DataFrame: path3 = package_data_path( "water", "availability", - f"qtot_5y_m_{context.RCP}_{context.REL}_{context.regions}.csv", + f"qtot_5y_m_{cfg.RCP}_{cfg.REL}_{context.regions}.csv", ) df_sw = pd.read_csv(path3) @@ -104,7 +106,7 @@ def map_basin_region_wat(context: "Context") -> pd.DataFrame: df_x_full = pd.read_csv(PATH) # Get positional indices of valid basins from the unfiltered list - valid_mask = df_x_full["BCU_name"].isin(context.valid_basins) + valid_mask = df_x_full["BCU_name"].isin(cfg.valid_basins) valid_indices = df_x_full[valid_mask].index df_x = df_x_full[valid_mask].reset_index(drop=True) @@ -113,8 +115,8 @@ def map_basin_region_wat(context: "Context") -> pd.DataFrame: df_sw["BCU_name"] = df_x["BCU_name"] df_sw["MSGREG"] = ( - context.map_ISO_c[context.regions] - if context.type_reg == "country" + cfg.map_ISO_c[context.regions] + if cfg.type_reg == "country" else f"{context.regions}_" + df_sw["BCU_name"].str.split("|").str[-1] ) @@ -161,6 +163,7 @@ def add_water_supply(context: "Context") -> dict[str, pd.DataFrame]: results = {} # Reference to the water configuration + cfg = Config.from_context(context) info = context["water build info"] # load the scenario from context scen = context.get_scenario() @@ -168,7 +171,7 @@ def add_water_supply(context: "Context") -> dict[str, pd.DataFrame]: fut_year = info.Y year_wat = (*range(2010, info.Y[0] + 1, 5), *info.Y) last_vtg_yr = info.Y[0] - 5 - sub_time = context.time + sub_time = cfg.time scen_info = ScenarioInfo(scen) print(" future year = ", fut_year) @@ -187,8 +190,8 @@ def add_water_supply(context: "Context") -> dict[str, pd.DataFrame]: df_node["node"] = "B" + df_node["BCU_name"].astype(str) df_node["mode"] = "M" + df_node["BCU_name"].astype(str) df_node["region"] = ( - context.map_ISO_c[context.regions] - if context.type_reg == "country" + cfg.map_ISO_c[context.regions] + if cfg.type_reg == "country" else f"{context.regions}_" + df_node["REGION"].astype(str) ) @@ -200,8 +203,8 @@ def add_water_supply(context: "Context") -> dict[str, pd.DataFrame]: PATH1 = package_data_path("water", "availability", FILE1) df_gwt = pd.read_csv(PATH1) df_gwt["region"] = ( - context.map_ISO_c[context.regions] - if context.type_reg == "country" + cfg.map_ISO_c[context.regions] + if cfg.type_reg == "country" else f"{context.regions}_" + df_gwt["REGION"].astype(str) ) @@ -211,11 +214,11 @@ def add_water_supply(context: "Context") -> dict[str, pd.DataFrame]: df_hist = pd.read_csv(PATH2) # Filter to only include valid basins (nexus mode only) - if context.nexus_set == "nexus": - df_hist = df_hist[df_hist["BCU_name"].isin(context.valid_basins)] + if cfg.nexus_set == "nexus": + df_hist = df_hist[df_hist["BCU_name"].isin(cfg.valid_basins)] df_hist["BCU_name"] = "B" + df_hist["BCU_name"].astype(str) - if context.nexus_set == "cooling": + if cfg.nexus_set == "cooling": # Add output df for surfacewater supply for regions output_df = ( make_df( @@ -324,7 +327,7 @@ def add_water_supply(context: "Context") -> dict[str, pd.DataFrame]: results["input"] = pd.concat([inp_sw, inp_gw], ignore_index=True) - elif context.nexus_set == "nexus": + elif cfg.nexus_set == "nexus": # input data frame for slack technology balancing equality with demands inp = ( make_df( @@ -559,7 +562,7 @@ def add_water_supply(context: "Context") -> dict[str, pd.DataFrame]: ] ) - if context.type_reg == "global": + if cfg.type_reg == "global": inp.loc[ (inp["technology"].str.contains("extract_gw_fossil")) & (inp["year_act"] == 2020) @@ -995,6 +998,7 @@ def add_e_flow(context: "Context") -> dict[str, pd.DataFrame]: """ # define an empty dictionary results = {} + cfg = Config.from_context(context) info = context["water build info"] @@ -1008,10 +1012,9 @@ def add_e_flow(context: "Context") -> dict[str, pd.DataFrame]: ) df_x_full = pd.read_csv(PATH) # Index positions of valid basins in the full CSV (for positional CSV filtering) - valid_indices = df_x_full[df_x_full["BCU_name"].isin(context.valid_basins)].index - df_x = df_x_full[df_x_full["BCU_name"].isin(context.valid_basins)].reset_index( - drop=True - ) + valid_mask = df_x_full["BCU_name"].isin(cfg.valid_basins) + valid_indices = df_x_full[valid_mask].index + df_x = df_x_full[valid_mask].reset_index(drop=True) dmd_df = make_df( "demand", @@ -1026,12 +1029,12 @@ def add_e_flow(context: "Context") -> dict[str, pd.DataFrame]: dmd_df = dmd_df[dmd_df["year"] >= 2025].reset_index(drop=True) dmd_df["value"] = dmd_df["value"].apply(lambda x: x if x >= 0 else 0) - if "year" in context.time: + if "year" in cfg.time: # Reading data, the data is spatially and temporally aggregated from GHMs path1 = package_data_path( "water", "availability", - f"e-flow_{context.RCP}_{context.regions}.csv", + f"e-flow_{cfg.RCP}_{context.regions}.csv", ) df_env = pd.read_csv(path1) df_env.drop(["Unnamed: 0"], axis=1, inplace=True) @@ -1054,7 +1057,7 @@ def add_e_flow(context: "Context") -> dict[str, pd.DataFrame]: path1 = package_data_path( "water", "availability", - f"e-flow_5y_m_{context.RCP}_{context.regions}.csv", + f"e-flow_5y_m_{cfg.RCP}_{context.regions}.csv", ) df_env = pd.read_csv(path1) df_env.drop(["Unnamed: 0"], axis=1, inplace=True) @@ -1074,7 +1077,7 @@ def add_e_flow(context: "Context") -> dict[str, pd.DataFrame]: df_env = df_env[df_env["year"].isin(info.Y)] # Return a processed dataframe for env flow calculations - if context.SDG != "baseline": + if cfg.SDG != "baseline": # dataframe to put constraints on env flows eflow_df = make_df( "bound_activity_lo", diff --git a/message_ix_models/model/water/utils.py b/message_ix_models/model/water/utils.py index 624b316bc0..6eb49a6b37 100644 --- a/message_ix_models/model/water/utils.py +++ b/message_ix_models/model/water/utils.py @@ -14,6 +14,8 @@ from message_ix_models.model.structure import get_codes from message_ix_models.util import load_package_data, package_data_path +from .config import Config + log = logging.getLogger(__name__) if TYPE_CHECKING: @@ -117,19 +119,18 @@ def filter_basins_by_region( if not context: context = Context.get_instance(-1) - # Check if reduced basin filtering is enabled - reduced_basin = getattr(context, "reduced_basin", False) + cfg = Config.from_context(context) - if not reduced_basin: + if not cfg.reduced_basin: # No filtering, return original dataframe log.info("Basin filtering disabled, returning all basins") return df_basins # Basin filtering is enabled — run automatic selection, then augment with # filter_list if provided. - filter_list = getattr(context, "filter_list", None) - num_basins = getattr(context, "num_basins", None) - basin_selection = getattr(context, "basin_selection", "first_k") + filter_list = cfg.filter_list + num_basins = cfg.num_basins + basin_selection = cfg.basin_selection if num_basins is None: log.info(f"num_basins not set, using default n_per_region={n_per_region}") diff --git a/message_ix_models/tests/model/water/conftest.py b/message_ix_models/tests/model/water/conftest.py index cb1d5253c7..5f16f783b2 100644 --- a/message_ix_models/tests/model/water/conftest.py +++ b/message_ix_models/tests/model/water/conftest.py @@ -6,6 +6,7 @@ from message_ix_models import ScenarioInfo from message_ix_models.model.structure import get_codes +from message_ix_models.model.water.config import Config from message_ix_models.util import package_data_path REGION_CONFIG = { @@ -52,7 +53,7 @@ def setup_valid_basins(context, regions="R12"): df_filtered = filter_basins_by_region(df_basins, context) # Set valid_basins as set of basin names - context.valid_basins = set(df_filtered["BCU_name"].astype(str)) + Config.from_context(context).valid_basins = set(df_filtered["BCU_name"].astype(str)) return context @@ -70,29 +71,31 @@ def water_context(test_context, request): # Apply defaults test_context.regions = params.get("regions", "R11") - test_context.type_reg = params.get("type_reg", "global") - test_context.time = params.get("time", "year") - test_context.nexus_set = params.get("nexus_set", "nexus") + cfg = Config.from_context(test_context) + cfg.type_reg = params.get("type_reg", "global") + cfg.time = params.get("time", "year") + cfg.nexus_set = params.get("nexus_set", "nexus") # Optional attributes for attr in [ "RCP", "REL", "SDG", - "ssp", "reduced_basin", "basin_selection", "num_basins", "filter_list", ]: if attr in params: - setattr(test_context, attr, params[attr]) + setattr(cfg, attr, params[attr]) + if "ssp" in params: + test_context.ssp = params["ssp"] # Node mapping for country models - if test_context.type_reg == "country": + if cfg.type_reg == "country": nodes = get_codes(f"node/{test_context.regions}") nodes = list(map(str, nodes[nodes.index("World")].child)) - test_context.map_ISO_c = {test_context.regions: nodes[0]} + cfg.map_ISO_c = {test_context.regions: nodes[0]} # Set up valid_basins for basin filtering setup_valid_basins(test_context, regions=test_context.regions) diff --git a/message_ix_models/tests/model/water/test_build.py b/message_ix_models/tests/model/water/test_build.py index 92e8fcafc4..6204f904d4 100644 --- a/message_ix_models/tests/model/water/test_build.py +++ b/message_ix_models/tests/model/water/test_build.py @@ -5,6 +5,7 @@ from message_ix_models.model.structure import get_codes from message_ix_models.model.water.build import cat_tec_cooling, get_spec, map_basin from message_ix_models.model.water.build import main as build +from message_ix_models.model.water.config import Config @pytest.mark.xfail(reason="Temporary, for #106") @@ -16,15 +17,16 @@ def test_build(request, test_context): # TODO If all water functions require these keys, set this up in a central location # or via default value # Ensure test_context has all necessary keys for build() - test_context.nexus_set = "nexus" - test_context.type_reg = "global" - test_context.time = "year" + cfg = Config.from_context(test_context) + cfg.nexus_set = "nexus" + cfg.type_reg = "global" + cfg.time = ["year"] nodes = get_codes(f"node/{test_context.regions}") nodes = list(map(str, nodes[nodes.index("World")].child)) map_ISO_c = {test_context.regions: nodes[0]} - test_context.map_ISO_c = map_ISO_c - test_context.RCP = "6p0" - test_context.REL = "med" + cfg.map_ISO_c = map_ISO_c + cfg.RCP = "6p0" + cfg.REL = "med" test_context["water build info"] = ScenarioInfo(scenario_obj=scenario) # Code runs on the bare RES @@ -80,9 +82,10 @@ def parametrize_for_cat_tec(request, context): @pytest.mark.parametrize("nexus_set", ["nexus", "cooling"]) def test_get_spec(request, test_context, nexus_set): # Ensure test_context has all necessary keys for get_spec() - test_context.nexus_set = nexus_set + cfg = Config.from_context(test_context) + cfg.nexus_set = nexus_set test_context.model.regions = "R12" - test_context.type_reg = "global" + cfg.type_reg = "global" test_context = parametrize_for_cat_tec(request, test_context) diff --git a/message_ix_models/tests/model/water/test_config.py b/message_ix_models/tests/model/water/test_config.py new file mode 100644 index 0000000000..50285f101b --- /dev/null +++ b/message_ix_models/tests/model/water/test_config.py @@ -0,0 +1,12 @@ +from message_ix_models.model.water.config import Config + + +def test_config_from_context(test_context) -> None: + test_context.water = {"RCP": "2p6", "SDG": "SDG", "REL": "med"} + + cfg = Config.from_context(test_context) + + assert isinstance(cfg, Config) + assert cfg.RCP == "2p6" + assert cfg.SDG == "SDG" + assert cfg.REL == "med" diff --git a/message_ix_models/tests/model/water/test_report.py b/message_ix_models/tests/model/water/test_report.py index 3bfb7d001d..3e0794416e 100644 --- a/message_ix_models/tests/model/water/test_report.py +++ b/message_ix_models/tests/model/water/test_report.py @@ -8,6 +8,7 @@ from message_ix_models import ScenarioInfo from message_ix_models.model.structure import get_codes +from message_ix_models.model.water.config import Config from message_ix_models.model.water.report import ( ScenarioMetadata, aggregate_totals, @@ -23,8 +24,9 @@ def test_report_full(test_context: Any, request: pytest.FixtureRequest) -> None: # FIXME You probably want this to be part of a common setup rather than writing # something like this for every test - test_context.time = "year" - test_context.type_reg = "global" + cfg = Config.from_context(test_context) + cfg.time = ["year"] + cfg.type_reg = "global" test_context.regions = "R12" codes = get_codes(f"node/{test_context.regions}") world_code = [n for n in codes if str(n) == "World"][0] diff --git a/message_ix_models/tests/model/water/test_utils.py b/message_ix_models/tests/model/water/test_utils.py index 0bd9e9c881..366eb3e954 100644 --- a/message_ix_models/tests/model/water/test_utils.py +++ b/message_ix_models/tests/model/water/test_utils.py @@ -125,14 +125,16 @@ def test_n_per_region_respected(self, stress_df): @pytest.mark.parametrize("basin_selection", ["first_k", "stress"]) def test_filter_list_additive_on_automatic(test_context, basin_selection): """filter_list augments automatic selection rather than replacing it.""" + from message_ix_models.model.water.config import Config from message_ix_models.util import package_data_path df_basins = pd.read_csv( package_data_path("water", "delineation", "basins_by_region_simpl_R12.csv") ) - test_context.reduced_basin = True - test_context.basin_selection = basin_selection + cfg = Config.from_context(test_context) + cfg.reduced_basin = True + cfg.basin_selection = basin_selection test_context.regions = "R12" test_context.ssp = "SSP2" @@ -145,7 +147,7 @@ def test_filter_list_additive_on_automatic(test_context, basin_selection): extra_basins = list(all_basins - auto_basins)[:3] assert len(extra_basins) > 0, "Need basins outside automatic set" - test_context.filter_list = extra_basins + cfg.filter_list = extra_basins combined = filter_basins_by_region(df_basins, test_context, n_per_region=1) combined_basins = set(combined["BCU_name"]) diff --git a/message_ix_models/tests/util/test_click.py b/message_ix_models/tests/util/test_click.py index c638217b73..372951abd0 100644 --- a/message_ix_models/tests/util/test_click.py +++ b/message_ix_models/tests/util/test_click.py @@ -36,10 +36,10 @@ def func(ctx, rep_out_path): assert result.output.startswith(f"{expected}\n") -def test_regions(mix_models_cli): - """--regions=… used on both group and a command within the group. +def test_regions(mix_models_cli) -> None: + """--regions=… redundantly declared on a group and its subcommand. - If the option is not provided to the inner command, the value given to the outer + If the option is not provided to the subcommand, the value given to the parent group should persist. """ @@ -64,6 +64,26 @@ def inner(context, regions): assert "ZMB" == result.output.strip() +def test_scenario_param_preserves_outer_value(mix_models_cli) -> None: + """An unset subcommand --ssp does not overwrite the parent command's value.""" + + @click.group() + @scenario_param("--ssp", default="SSP2") + def outer(): + pass + + @outer.command() + @scenario_param("--ssp") + @click.pass_obj + def inner(context): + print(context.ssp) + + with temporary_command(cli_test_group, outer): + result = mix_models_cli.assert_exit_0(["_test", "outer", "--ssp=SSP1", "inner"]) + + assert "SSP1" == result.output.strip() + + @pytest.mark.parametrize( "args, command, expected", [ diff --git a/message_ix_models/util/click.py b/message_ix_models/util/click.py index a1bee15695..41a1a88fac 100644 --- a/message_ix_models/util/click.py +++ b/message_ix_models/util/click.py @@ -211,11 +211,17 @@ def store_context(context: click.Context | Context, param, value): Use this for parameters that are not used directly in a @click.command() function, but need to be carried by the Context for later use. """ - setattr( - context.obj if isinstance(context, click.Context) else context, - param.name, - value, - ) + target = context.obj if isinstance(context, click.Context) else context + + # In nested commands, an omitted inner option should not clear a value already set + # by an outer command. + if value is None and ( + (isinstance(target, Context) and param.name in target) + or (not isinstance(target, Context) and hasattr(target, param.name)) + ): + return value + + setattr(target, param.name, value) return value