diff --git a/message_ix/tools/graph_viz_tool/dynamic_graph.py b/message_ix/tools/graph_viz_tool/dynamic_graph.py new file mode 100644 index 000000000..f56c17753 --- /dev/null +++ b/message_ix/tools/graph_viz_tool/dynamic_graph.py @@ -0,0 +1,106 @@ +from typing import List + +import pandas as pd +import pyvis as pv +from message_ix_models.util import load_package_data + +pyvis_hierarchy_config = { + "enabled": True, + "direction": "LR", + "blockShifting": True, + "edgeMinimization": False, + "parentCentralization": False, + "levelSeparation": 150, + "treeSpacing": 200, +} + +T_D_TECS_COLOR = "#f7cf26" + + +def load_technology_annotations(): + tec_codes = load_package_data("technology") + tec_descriptions = {k: v.get("description") for k, v in tec_codes.items()} + return tec_descriptions + + +class DynamicGraph: + def __init__(self, commodities: List[str]): + self.graph = pv.network.Network( + layout="reingold_tilford", + select_menu=True, + directed=True, + heading="commodity flow graph", + ) + self.graph.options.physics.enabled = False + self.graph.options.layout.hierarchical = pyvis_hierarchy_config + self.tec_annotations = load_technology_annotations() + self.commodities = commodities + + def add_graph_elements( + self, levels: pd.DataFrame, tecs_in: pd.DataFrame, tecs_out: pd.DataFrame + ): + level_location = {k: 2 * v for v, k in enumerate(levels.level)} + for row in levels.iterrows(): + row = row[1] + self.graph.add_node( + row.level, row.level, shape="box", level=level_location.get(row.level) + ) + + for row in tecs_in.iterrows(): + row = row[1] + self.graph.add_node( + row.technology, + row.technology, + level=level_location.get(row.level_in) + 1, + title=self.tec_annotations.get(row.technology), + ) + self.graph.add_edge(row.level_in, row.technology, title=row.comm_in) + + for row in tecs_out.iterrows(): + row = row[1] + self.graph.add_node( + row.technology, + row.technology, + level=level_location.get(row.level_out) - 1, + title=self.tec_annotations.get(row.technology), + ) + self.graph.add_edge(row.technology, row.level_out, title=row.comm_out) + + def highlight_transmission_tecs(self, tecs: pd.DataFrame): + # highlight transmission technologies with different color + transmission_tecs = tecs[ + (tecs["comm_in"] == tecs["comm_out"]) + & (tecs["comm_out"].isin(self.commodities)) + ].technology.to_list() + for node in self.graph.nodes: + if node.get("label") in transmission_tecs: + node["color"] = T_D_TECS_COLOR + + +if __name__ == "__main__": + import ixmp + import message_ix + import yaml + from graphviz_res import GraphData + + mp = ixmp.Platform("") + scen = message_ix.Scenario(mp, model="", scenario="") + + # Load the YAML configuration file + with open("graph_config.yaml", "r") as file: + config = yaml.safe_load(file) + + scenario_data = GraphData(scen, config["graph"]) + + commodities = ["coal"] + graph = DynamicGraph(commodities) + levels = scenario_data.levels[scenario_data.levels["commodity"].isin(commodities)] + tecs_in = scenario_data.technologies[ + (scenario_data.technologies["comm_in"].isin(commodities)) + ] + tecs_out = scenario_data.technologies[ + (scenario_data.technologies["comm_out"].isin(commodities)) + ] + graph.add_graph_elements(levels, tecs_in, tecs_out) + graph.highlight_transmission_tecs(scenario_data.technologies) + graph.graph.save_graph("dynamic_commodity_graph_rendering.html") diff --git a/message_ix/tools/graph_viz_tool/graph_config.yaml b/message_ix/tools/graph_viz_tool/graph_config.yaml new file mode 100644 index 000000000..2bd9082e4 --- /dev/null +++ b/message_ix/tools/graph_viz_tool/graph_config.yaml @@ -0,0 +1,20 @@ +# config.yaml + +graph: + node: + shape: "rect" + style: "filled" + fontsize: "20" + height: "0.005" + width: "0.0001" + colors: + default: "lightblue" + level: "lightgrey" + phantom: "invisible" + levels: + primary: 0 + secondary: 15 + final: 35 + useful: 60 + technologies: + [] \ No newline at end of file diff --git a/message_ix/tools/graph_viz_tool/graphviz_multi_comm.py b/message_ix/tools/graph_viz_tool/graphviz_multi_comm.py new file mode 100644 index 000000000..d7e12648b --- /dev/null +++ b/message_ix/tools/graph_viz_tool/graphviz_multi_comm.py @@ -0,0 +1,299 @@ +import os +import warnings + +import graphviz +import pandas as pd + +warnings.filterwarnings("ignore", category=RuntimeWarning) + + +# -------------------------- +# Load or extract I/O (caching logic preserved) +# -------------------------- +def load_or_extract_io(scen, model, scenario, node, year, commodities): + in_file = f"{model}_{scenario}_inputs.csv" + out_file = f"{model}_{scenario}_outputs.csv" + dem_file = f"{model}_{scenario}_demand.csv" + + if ( + os.path.exists(in_file) + and os.path.exists(out_file) + and os.path.exists(dem_file) + ): + print(f"Reading cached data: {in_file}, {out_file}, {dem_file}") + df_in = pd.read_csv(in_file) + df_out = pd.read_csv(out_file) + df_dem = pd.read_csv(dem_file) + return df_in, df_out, df_dem + + print("Extracting from scenario…") + df_in = scen.par("input", {"node_loc": node, "year_act": year}).copy() + df_out = scen.par("output", {"node_loc": node, "year_act": year}).copy() + df_dem = scen.par("demand", {"node": node, "year": year}).copy() + + # Normalize column names + df_in = df_in.rename(columns={"commodity": "comm_in", "level": "level_in"}) + df_out = df_out.rename(columns={"commodity": "comm_out", "level": "level_out"}) + df_dem = df_dem.rename(columns={"commodity": "comm", "level": "level"}) + + # Filter to commodities of interest + if commodities: + df_in = df_in[df_in["comm_in"].isin(commodities)].copy() + df_out = df_out[df_out["comm_out"].isin(commodities)].copy() + df_dem = df_dem[df_dem["comm"].isin(commodities)].copy() + + # Save to CSV + df_in.to_csv(in_file, index=False) + df_out.to_csv(out_file, index=False) + df_dem.to_csv(dem_file, index=False) + print(f"Saved to {in_file}, {out_file}, {dem_file}") + + return df_in, df_out, df_dem + + +# -------------------------- +# Small helper: enforce comm_* and level_* names (safe-guard) +# -------------------------- +def ensure_comm_level_cols(df_in, df_out, df_dem=None): + if "comm_in" not in df_in.columns: + if "commodity" in df_in.columns: + df_in = df_in.rename(columns={"commodity": "comm_in"}) + else: + df_in["comm_in"] = "" + if "level_in" not in df_in.columns: + if "level" in df_in.columns: + df_in = df_in.rename(columns={"level": "level_in"}) + else: + df_in["level_in"] = "" + + if "comm_out" not in df_out.columns: + if "commodity" in df_out.columns: + df_out = df_out.rename(columns={"commodity": "comm_out"}) + else: + df_out["comm_out"] = "" + if "level_out" not in df_out.columns: + if "level" in df_out.columns: + df_out = df_out.rename(columns={"level": "level_out"}) + else: + df_out["level_out"] = "" + + if df_dem is not None: + if "comm" not in df_dem.columns and "commodity" in df_dem.columns: + df_dem = df_dem.rename(columns={"commodity": "comm"}) + if "level" not in df_dem.columns: + df_dem["level"] = "" + + return df_in, df_out, df_dem + + +# -------------------------- +# Update Graphviz plotting +# -------------------------- +def plot_flows_graphviz(df_in, df_out, df_dem, model, scenario, commodities): + dot = graphviz.Digraph(comment=f"{model} {scenario} flows", format="png") + dot.attr(rankdir="LR", splines="ortho") + + # Set of demand commodity.level + demand_comms = ( + {f"{r.comm}.{r.level}" for r in df_dem.itertuples()} + if not df_dem.empty + else set() + ) + + # --- Build sets of techs connected to main commodities --- + techs_in = df_in[df_in["comm_in"].isin(commodities)]["technology"].unique() + techs_out = df_out[df_out["comm_out"].isin(commodities)]["technology"].unique() + techs = set(techs_in).union(techs_out) + + # --- Main commodities actually connected --- + main_in = df_in[df_in["technology"].isin(techs)][["comm_in", "level_in"]] + main_out = df_out[df_out["technology"].isin(techs)][["comm_out", "level_out"]] + + main_comms = { + f"{r.comm_in}.{r.level_in}" + for r in main_in.itertuples() + if r.comm_in in commodities + } | { + f"{r.comm_out}.{r.level_out}" + for r in main_out.itertuples() + if r.comm_out in commodities + } + + # --- Extra commodities connected to those techs --- + extra_in = { + f"{r.comm_in}.{r.level_in}" + for r in main_in.itertuples() + if r.comm_in not in commodities + } + extra_out = { + f"{r.comm_out}.{r.level_out}" + for r in main_out.itertuples() + if r.comm_out not in commodities + } + extra_comms = extra_in | extra_out + + # --- Add commodity nodes --- + for c in main_comms: + if c in demand_comms: + dot.node( + c, + label=c, + shape="ellipse", + style="filled", + fillcolor="violet", + color="black", + fontcolor="black", + ) + else: + dot.node( + c, + label=c, + shape="ellipse", + style="filled", + fillcolor="lightgray", + color="black", + fontcolor="black", + ) + + for c in extra_comms: + dot.node(c, label=c, shape="ellipse", color="red", fontcolor="red") + + # --- Technology nodes --- + for t in techs: + dot.node(t, shape="box", style="rounded,filled", fillcolor="lightblue") + + # --- Deduplicate edges by (src, dst, type) --- + edges_seen = {} + for _, row in df_in[df_in["technology"].isin(techs)].iterrows(): + src = f"{row.comm_in}.{row.level_in}" + key = (src, row.technology, "in") + year_vtg = row.get("year_vtg", row.year_act) + edges_seen.setdefault(key, []).append(year_vtg) + + for _, row in df_out[df_out["technology"].isin(techs)].iterrows(): + dst = f"{row.comm_out}.{row.level_out}" + key = (row.technology, dst, "out") + year_vtg = row.get("year_vtg", row.year_act) + edges_seen.setdefault(key, []).append(year_vtg) + + # --- Draw edges --- + for (src, dst, typ), years in edges_seen.items(): + latest = max(years) + multiple = len(set(years)) > 1 + style = "dashed" if multiple else "solid" + color = "red" if (src in extra_comms or dst in extra_comms) else "black" + dot.edge(src, dst, style=style, color=color) + + # --- Legend --- + with dot.subgraph(name="cluster_legend") as c: + c.attr(label="Legend", fontsize="10", rankdir="LR") + c.node("solid_arrow", label="single vintage", shape="plaintext") + c.edge("solid_arrow", "dashed_arrow", style="solid") + c.node("dashed_arrow", label="multiple vintages", shape="plaintext") + c.edge("solid_arrow", "dashed_arrow", style="dashed") + + # Region/year info + c.node("region_year", label=f"Region: {NODE}, Year: {YEAR}", shape="plaintext") + + # Boxes meaning + c.node( + "tech_box", + label="Technology", + shape="box", + style="rounded,filled", + fillcolor="lightblue", + ) + c.node( + "comm_grey", + label="Commodity", + shape="ellipse", + style="filled", + fillcolor="lightgray", + ) + c.node( + "comm_violet", + label="Demand commodity", + shape="ellipse", + style="filled", + fillcolor="violet", + ) + # c.node( + # "comm_red", + # label="Extra linked commodity", + # shape="ellipse", + # color="red", + # fontcolor="red", + # ) + + # --- Save --- + png_name = f"{model}_{scenario}_ascii_flows" + svg_name = f"{model}_{scenario}_ascii_flows" + + dot.format = "png" + dot.render(filename=png_name, cleanup=True) + dot.format = "svg" + dot.render(filename=svg_name, cleanup=True) + + print(f"Saved Graphviz PNG: {png_name}.png, SVG: {svg_name}.svg") + + +# -------------------------- +# Main +# -------------------------- +if __name__ == "__main__": + """ + Note: need to delete the csv files if you are re-running with different + commodities + """ + # ==== User-editable configuration ==== + PLATFORM_NAME = "" # only used if CSV cache missing + MODEL = "" # e.g. "MESSAGEix-Nexus" + SCENARIO = "" + NODE = "R12_CHN" + YEAR = 2050 + # Define the commodities you want (only these will be kept) + COMMODITIES = [ + "electr", + "coal", + ] + # ===================================== + + in_file = f"{MODEL}_{SCENARIO}_inputs.csv" + out_file = f"{MODEL}_{SCENARIO}_outputs.csv" + dem_file = f"{MODEL}_{SCENARIO}_demand.csv" + + # If cached CSVs exist, load without connecting to ixmp + if ( + os.path.exists(in_file) + and os.path.exists(out_file) + and os.path.exists(dem_file) + ): + print("Loading cached CSVs…") + df_in = pd.read_csv(in_file) + df_out = pd.read_csv(out_file) + df_dem = pd.read_csv(dem_file) + # Normalize column names (old CSVs may be different) + df_in, df_out, df_dem = ensure_comm_level_cols(df_in, df_out, df_dem) + # Apply commodity filter again (defensive) + if COMMODITIES: + df_in = df_in[df_in["comm_in"].isin(COMMODITIES)].copy() + df_out = df_out[df_out["comm_out"].isin(COMMODITIES)].copy() + else: + # Need to connect and extract + import ixmp + + import message_ix + + mp = ixmp.Platform(PLATFORM_NAME) + scen = message_ix.Scenario(mp, model=MODEL, scenario=SCENARIO) + df_in, df_out, df_dem = load_or_extract_io( + scen, MODEL, SCENARIO, NODE, YEAR, COMMODITIES + ) + # Explicit cleanup + mp.close_db() + + # Final safety normalization (ensure col names exist) + df_in, df_out, df_dem = ensure_comm_level_cols(df_in, df_out, df_dem) + + # Plots + plot_flows_graphviz(df_in, df_out, df_dem, MODEL, SCENARIO, COMMODITIES) diff --git a/message_ix/tools/graph_viz_tool/graphviz_res.py b/message_ix/tools/graph_viz_tool/graphviz_res.py new file mode 100644 index 000000000..11fe4c99c --- /dev/null +++ b/message_ix/tools/graph_viz_tool/graphviz_res.py @@ -0,0 +1,340 @@ +# this is a python implementation of the VBA RES tool developed by a former IIASA +# collaborator from Iran (name?). +# The original tool (including VBA source code) can be downloaded here: +# https://github.com/user-attachments/files/18109634/Fw_.VBA.Macros.to.Visualize.RES.diagram.zip + +import random + +import pandas as pd +from graphviz import Digraph + +import message_ix + + +def random_color_hex() -> str: + return "#{:06x}".format(random.randint(0, 0xFFFFFF)) + + +class GraphData: + def __init__(self, scenario: message_ix.Scenario, config): + inp, out = self.get_scenario_data(scenario, config) + self.levels = self.gen_level_data(out, config) + self.technologies = self.gen_technology_data(inp, out, config) + + def get_scenario_data(self, scenario, config): + filters = {"level": config.get("levels").keys()} + if len(config.get("technologies")): + filters["technology"] = config.get("technologies").keys() + out = scenario.par("output", filters=filters) + inp = scenario.par("input", filters=filters) + + # filter out flows with 0 values + out = out[out["value"] != 0] + inp = inp[inp["value"] != 0] + return inp, out + + def gen_technology_data( + self, inp: pd.DataFrame, out: pd.DataFrame, config + ) -> pd.DataFrame: + """Create technology dataframe required to generate RES diagram. + + Takes unique combinations of level, commodity and technology columns from + input and output parameters of a scenario. Joins them for output and input and + generates a y-coordinate for each combination. Y coordinates are assigned by + sorting the technologies by input commodity. + + Parameters + ---------- + inp + Input parameter data from a message_ix scenario + out + Output parameter data from a message_ix scenario + + Returns + ------- + pd.DataFrame + """ + + tec_out = out[["technology", "commodity", "level"]].drop_duplicates() + tec_out = tec_out.rename( + columns={"commodity": "comm_out", "level": "level_out"} + ) + tec_inp = inp[["technology", "commodity", "level"]].drop_duplicates() + tec_inp = tec_inp.rename(columns={"commodity": "comm_in", "level": "level_in"}) + + tecs = tec_inp.set_index("technology").join(tec_out.set_index("technology")) + tecs = tecs[ + (tecs["level_in"].isin(config.get("levels").keys())) + & (tecs["level_out"].isin(config.get("levels").keys())) + ] + tecs = tecs.reset_index() + tecs_new = pd.DataFrame() + for level in tecs["level_in"].unique(): + prims = ( + tecs[tecs["level_in"] == level].copy(deep=True).sort_values(["comm_in"]) + ) + prims["y_coord"] = prims.reset_index().index.to_list() + tecs_new = pd.concat([tecs_new, pd.DataFrame(prims)]) + return tecs_new + + def gen_level_data(self, out: pd.DataFrame, config) -> pd.DataFrame: + """Create commodity level dataframe required to generate RES diagram. + + Takes unique combinations of level-commodity columns. + X-coordinates are assigned by grouping the commodities by level. + + Parameters + ---------- + out + Output parameter data from a message_ix scenario + + Returns + ------- + pd.DataFrame + """ + + levels = out[["commodity", "level"]].drop_duplicates().set_index("level") + levels_new = pd.DataFrame() + for level in levels.index.get_level_values(0): + prims = levels.loc[[level]].copy(deep=True) + prims["rank"] = prims.rank() + levels_new = pd.concat([levels_new, pd.DataFrame(prims)]) + levels_new = levels_new.reset_index().drop_duplicates() + levels_new["x_coord"] = levels_new.apply( + lambda x: x["rank"] + config.get("levels").get(x.level, 0), axis=1 + ) + return levels_new + + +class GraphBuilder: + def __init__(self): + self.graph = Digraph(engine="neato") + self.color_dict = {} + + def add_level_nodes(self, levels: pd.DataFrame) -> None: + """Add the commodity levels to the RES stored in the "levels" DataFrame. + + Each commodity level pair needs to be mapped to an x-coordinate. + + Parameters + ---------- + dot + Graph to add node to + levels + DataFrame that contains [commodity, level and x_coord] as columns + color_map + Dictionary that contains a color mapping for each commodity in levels + """ + for row in levels.iterrows(): + row = row[1] + x = row["x_coord"] + id = row["commodity"] + "-" + row["level"] + label = row["commodity"] + y = 55 + self.graph.node( + f"{id}", + label="", + xlabel=label, + labelloc="t", + shape="rect", + style="filled", + height=f"{y * 2}", + width="0.0001", + pos=f"{x},{y}!", + color=self.color_dict[label], + ) + + def add_phantom_level_node( + self, id: str, x: int, y: int, hide: bool = True + ) -> None: + """Add a "level" phantom node to the RES at the given coordinates. + + Levels are represented as thin vertical lines in the RES. + + Parameters + ---------- + id + Identifier of new level node + x + X-coordinate of node + y + Y-coordinate of node + hide + phantom nodes are hidden by default, can be set to False for debugging + """ + + label = "" if hide else id + style = "invisible" if hide else "solid" + self.graph.node( + id, + label=label, + pos=f"{x},{y}!", + shape="rect", + height="0.005", + width="0.0001", + style=style, + ) + + def add_tec_node( + self, name: str, x: int, y: int, height: float, log: bool = False + ) -> None: + """Add a technology node to the RES at the specified coordinates. + + Technologies are represented as rectangles in the RES with given height. + + Parameters + ---------- + name + Identifier of new node + x + X-coordinate of node + y + Y-coordinate of node + height + height of node shape + log + If True print out name and y value for debugging + """ + + if log: + print(name, y) + self.graph.node( + name, + label=name, + fontsize="20", + shape="rect", + style="filled", + color="lightblue", + pos=f"{x},{y}!", + height=f"{height}", + ) + + def add_phantom_tec_node( + self, id: str, x: int, y: int, label: str, hide: bool = True + ) -> None: + """Add a "technology" phantom node to the RES at the given coordinates. + + Parameters + ---------- + dot + Graph to add node to + name + Identifier of new node + x + X-coordinate of node + y + Y-coordinate of node + label + (hidden) label of node + hide + If False do not hide phantom nodes to debug placement issues + """ + style = "invisible" if hide else "solid" + self.graph.node( + id, label=label, fontsize="20", pos=f"{x},{y}!", shape="rect", style=style + ) + + def add_edge(self, origin: str, destination: str, color: str) -> None: + self.graph.edge(origin, destination, color=color) + + def build_graph(self, graph_data: GraphData) -> None: + levels = graph_data.levels + tecs = graph_data.technologies + self.color_dict = {k: random_color_hex() for k in levels["commodity"].unique()} + self.add_level_nodes(levels) + + phantom_node_in_prev = None + phantom_node_out_prev = None + for tec in tecs.technology.unique(): + if tec == "furnace_biomass_refining": + continue + rows = tecs[tecs["technology"] == tec] + y = rows["y_coord"].min() * 1.5 + name = tec + tec_x_coord = ( + levels[levels["level"] == rows["level_in"].values[0]].x_coord.max() + + 2.5 + ) + height = rows.index.size - 1 + self.add_tec_node( + name, tec_x_coord, y + height * 0.25, 0.3 + height * 0.35, log=False + ) + + for i, row in enumerate(rows.iterrows()): + y_in = y + i / 4 + y_out = y + i / 4 - 0.1 + row = row[1] + phantom_in_x_coord = levels[ + (levels["level"] == row["level_in"]) + & (levels["commodity"] == row["comm_in"]) + ].x_coord.values[0] + phantom_node_id_in = f"{row['level_in']}_{name}_{row['comm_in']}" + if phantom_node_id_in != phantom_node_in_prev: + self.add_phantom_level_node( + phantom_node_id_in, phantom_in_x_coord, y_in, hide=True + ) + self.add_phantom_tec_node( + phantom_node_id_in + "_tec", tec_x_coord, y_in, name, hide=True + ) + self.add_edge( + phantom_node_id_in, + phantom_node_id_in + "_tec", + self.color_dict[row["comm_in"]], + ) + phantom_node_in_prev = phantom_node_id_in + + phantom_out_x_coord = levels[ + (levels["level"] == row["level_out"]) + & (levels["commodity"] == row["comm_out"]) + ].x_coord.values[0] + phantom_node_id_out = f"{row['level_out']}_{name}_{row['comm_out']}" + if phantom_node_id_out != phantom_node_out_prev: + self.add_phantom_level_node( + phantom_node_id_out, phantom_out_x_coord, y_out, hide=True + ) + self.add_phantom_tec_node( + phantom_node_id_out + "_tec", + tec_x_coord, + y_out, + name, + hide=True, + ) + self.add_edge( + phantom_node_id_out + "_tec", + phantom_node_id_out, + self.color_dict[row["comm_out"]], + ) + phantom_node_out_prev = phantom_node_id_out + + def render(self, filename: str) -> None: + self.graph.render(filename, format="svg", cleanup=False) + + def calculate_level_line_height(self): + # the vertical lines that represent the commodities at each level + # need to be long enough to fit all the technology nodes. At the moment, + # the height is hardcoded and manually adjusted. + raise NotImplementedError + + def calculate_node_height(self): + # if technology nodes should be different height based on the amount of connected edges + # then node height needs to be calculated + raise NotImplementedError + + +if __name__ == "__main__": + import ixmp + import message_ix + import yaml + + # Load the YAML configuration file + with open("graph_config.yaml", "r") as file: + config = yaml.safe_load(file) + + mp = ixmp.Platform("") + scen = message_ix.Scenario(mp, model="", scenario="") + + scenario_data = GraphData(scen, config["graph"]) + + graph_builder = GraphBuilder() + graph_builder.build_graph(scenario_data) + graph_builder.render("res_graphviz_rendering")