diff --git a/message_ix_models/tests/test_report.py b/message_ix_models/tests/test_report.py index ba8fe16327..e00ed7e2a9 100644 --- a/message_ix_models/tests/test_report.py +++ b/message_ix_models/tests/test_report.py @@ -2,7 +2,7 @@ import re from importlib.metadata import version -from typing import List +from typing import List, Optional import numpy as np import pandas as pd @@ -315,10 +315,37 @@ def test_prepare_reporter(test_context): assert 14299 <= len(rep.graph) - N +# Filters for comparison +PE0 = r"Primary Energy\|(Coal|Gas|Hydro|Nuclear|Solar|Wind)" +PE1 = r"Primary Energy\|(Coal|Gas|Solar|Wind)" +E = ( + r"Emissions\|CO2\|Energy\|Demand\|Transportation\|Road Rail and Domestic " + "Shipping" +) + +IGNORE = [ + # Other 'variable' codes are missing from `obs` + re.compile(f"variable='(?!{PE0}).*': no right data"), + # 'variable' codes with further parts are missing from `obs` + re.compile(f"variable='{PE0}.*': no right data"), + # For `pe1` (NB: not Hydro or Solar) units and most values differ + re.compile(f"variable='{PE1}.*': units mismatch .*EJ/yr.*'', nan"), + re.compile(r"variable='Primary Energy|Coal': 220 of 240 values with \|diff"), + re.compile(r"variable='Primary Energy|Gas': 234 of 240 values with \|diff"), + re.compile(r"variable='Primary Energy|Solar': 191 of 240 values with \|diff"), + re.compile(r"variable='Primary Energy|Wind': 179 of 240 values with \|diff"), + # For `e` units and most values differ + re.compile(f"variable='{E}': units mismatch: .*Mt CO2/yr.*Mt / a"), + re.compile(rf"variable='{E}': 20 missing right entries"), + re.compile(rf"variable='{E}': 220 of 240 values with \|diff"), +] + + @to_simulate.minimum_version def test_compare(test_context): """Compare the output of genno-based and legacy reporting.""" - key = "pe test" + key = "all::iamc" + # key = "pe test" # Obtain the output from reporting `key` on `snapshot_id` snapshot_id: int = 1 @@ -340,24 +367,8 @@ def test_compare(test_context): engine="pyarrow", ) - # Filters for comparison - pe0 = r"Primary Energy\|(Coal|Gas|Hydro|Nuclear|Solar|Wind)" - pe1 = r"Primary Energy\|(Coal|Gas|Solar|Wind)" - ignore = [ - # Other 'variable' codes are missing from `obs` - re.compile(f"variable='(?!{pe0}).*': no right data"), - # 'variable' codes with further parts are missing from `obs` - re.compile(f"variable='{pe0}.*': no right data"), - # For `pe1` (NB: not Hydro or Solar) units and most values differ - re.compile(f"variable='{pe1}.*': units mismatch .*EJ/yr.*'', nan"), - re.compile(r"variable='Primary Energy|Coal': 220 of 240 values with \|diff"), - re.compile(r"variable='Primary Energy|Gas': 234 of 240 values with \|diff"), - re.compile(r"variable='Primary Energy|Solar': 191 of 240 values with \|diff"), - re.compile(r"variable='Primary Energy|Wind': 179 of 240 values with \|diff"), - ] - # Perform the comparison, ignoring some messages - if messages := compare_iamc(exp, obs, ignore=ignore): + if messages := compare_iamc(exp, obs, ignore=IGNORE): # Other messages that were not explicitly ignored → some error print("\n".join(messages)) assert False @@ -369,8 +380,8 @@ def compare_iamc( """Compare IAMC-structured data in `left` and `right`; return a list of messages.""" result = [] - def record(message: str) -> None: - if any(p.match(message) for p in ignore): + def record(message: str, condition: Optional[bool] = True) -> None: + if not condition or any(p.match(message) for p in ignore): return result.append(message) @@ -388,16 +399,29 @@ def checks(df: pd.DataFrame): "value_rel = value_diff / value_left" ) + na_left = tmp.isna()[["unit_left", "value_left"]] + if na_left.any(axis=None): + record(f"{prefix} {na_left.sum(axis=0).max()} missing left entries") + tmp = tmp[~na_left.any(axis=1)] + na_right = tmp.isna()[["unit_right", "value_right"]] + if na_right.any(axis=None): + record(f"{prefix} {na_right.sum(axis=0).max()} missing right entries") + tmp = tmp[~na_right.any(axis=1)] + units_left = set(tmp.unit_left.unique()) units_right = set(tmp.unit_right.unique()) - if units_left != units_right: - record(f"{prefix} units mismatch: {units_left} != {units_right}") + record( + condition=units_left != units_right, + message=f"{prefix} units mismatch: {units_left} != {units_right}", + ) N0 = len(df) mask1 = tmp.query("abs(value_diff) > @atol") - if len(mask1): - record(f"{prefix} {len(mask1)} of {N0} values with |diff| > {atol}") + record( + condition=len(mask1), + message=f"{prefix} {len(mask1)} of {N0} values with |diff| > {atol}", + ) for (model, scenario), group_0 in left.merge( right,