diff --git a/opacus/tests/module_validator_test.py b/opacus/tests/module_validator_test.py index 4e0d9b4ce..26cd73336 100644 --- a/opacus/tests/module_validator_test.py +++ b/opacus/tests/module_validator_test.py @@ -159,3 +159,46 @@ def test_fix_bn_with_args(self) -> None: with self.assertRaises(ValueError): ModuleValidator.fix(m, replace_bn_with_in=True, num_groups=4) + + def test_validate_error_includes_module_name(self): + """Verify that validation errors include the module path and type.""" + model = nn.Sequential( + OrderedDict( + [ + ("fc", nn.Linear(4, 8)), + ("bn", nn.BatchNorm1d(8)), + ] + ) + ) + errors = ModuleValidator.validate(model) + self.assertGreater(len(errors), 0) + # The error message should contain the module name "bn" and type "BatchNorm1d" + error_message = str(errors[0]) + self.assertIn("bn", error_message) + self.assertIn("BatchNorm1d", error_message) + + def test_validate_error_includes_nested_module_name(self): + """Verify that validation errors include the full nested module path.""" + model = nn.Sequential( + OrderedDict( + [ + ( + "block", + nn.Sequential( + OrderedDict( + [ + ("linear", nn.Linear(4, 8)), + ("norm", nn.BatchNorm1d(8)), + ] + ) + ), + ), + ] + ) + ) + errors = ModuleValidator.validate(model) + self.assertGreater(len(errors), 0) + # The error message should contain the full path "block.norm" + error_message = str(errors[0]) + self.assertIn("block.norm", error_message) + self.assertIn("BatchNorm1d", error_message) diff --git a/opacus/validators/module_validator.py b/opacus/validators/module_validator.py index b3d5fb58a..c08bcfc7d 100644 --- a/opacus/validators/module_validator.py +++ b/opacus/validators/module_validator.py @@ -59,11 +59,17 @@ def validate( IllegalModuleConfigurationError("Model needs to be in training mode") ) # 2. perform module specific validations for trainable modules. - # TODO: use module name here - it's useful part of error message - for _, sub_module in trainable_modules(module): + for module_name, sub_module in trainable_modules(module): if type(sub_module) in ModuleValidator.VALIDATORS: sub_module_validator = ModuleValidator.VALIDATORS[type(sub_module)] - errors.extend(sub_module_validator(sub_module)) + sub_errors = sub_module_validator(sub_module) + for err in sub_errors: + # Prepend module name to error message for easier debugging + err.args = ( + f"{module_name} ({type(sub_module).__name__}): {err.args[0]}", + *err.args[1:], + ) + errors.extend(sub_errors) # raise/return as needed if strict and len(errors) > 0: raise UnsupportedModuleError(errors)