Embedding.from_config does not validate that input_dim and `output_d… #22716
cantenesse wants to merge 4 commits into keras-team:master
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Code Review
This pull request introduces explicit type validation for input_dim and output_dim in the Embedding layer, ensuring they are Python integers during both initialization and deserialization via from_config. Corresponding unit tests have been added to verify that non-integer values trigger a ValueError. The review feedback points out that because booleans are a subclass of int in Python, the current isinstance checks will silently accept True or False as valid dimensions. Suggestions were provided to explicitly exclude booleans from the validation logic and to include boolean edge cases in the test suite.
    if not isinstance(input_dim, int):
        raise ValueError(
            "`input_dim` must be a Python int. "
            f"Received: input_dim={input_dim!r} (of type {type(input_dim)})"
        )
    if not isinstance(output_dim, int):
        raise ValueError(
            "`output_dim` must be a Python int. "
            f"Received: output_dim={output_dim!r} (of type {type(output_dim)})"
        )
In Python, bool is a subclass of int, so isinstance(value, int) returns True for booleans. This allows True or False to be silently accepted as valid dimensions (e.g., input_dim=True becomes 1), which is almost certainly a user error. To ensure robust validation of "Python int" dimensions, booleans should be explicitly excluded.
Suggested change:

    if not isinstance(input_dim, int) or isinstance(input_dim, bool):
        raise ValueError(
            "`input_dim` must be a Python int. "
            f"Received: input_dim={input_dim!r} (of type {type(input_dim)})"
        )
    if not isinstance(output_dim, int) or isinstance(output_dim, bool):
        raise ValueError(
            "`output_dim` must be a Python int. "
            f"Received: output_dim={output_dim!r} (of type {type(output_dim)})"
        )
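The point about booleans can be checked in isolation, without Keras. The following is a minimal sketch, assuming a hypothetical helper named `check_int_dim` (it is not part of Keras or of this pull request) that applies the same condition as the suggested change:

```python
def check_int_dim(name, value):
    """Hypothetical helper, not part of Keras: accept only true Python ints."""
    # bool is a subclass of int, so it has to be excluded explicitly.
    if not isinstance(value, int) or isinstance(value, bool):
        raise ValueError(
            f"`{name}` must be a Python int. "
            f"Received: {name}={value!r} (of type {type(value)})"
        )
    return value


# Why the extra isinstance(value, bool) clause is needed:
print(issubclass(bool, int))  # True
print(isinstance(True, int))  # True
print(True + 0)               # 1
```

With this helper, `check_int_dim("input_dim", 3)` returns `3`, while `True`, `False`, `3.7`, and `None` all raise `ValueError`.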
    def test_input_dim_output_dim_must_be_int(self):
        with self.assertRaisesRegex(ValueError, "input_dim"):
            layers.Embedding(input_dim=3.7, output_dim=2)
        with self.assertRaisesRegex(ValueError, "output_dim"):
            layers.Embedding(input_dim=3, output_dim=2.0)
        # Valid ints should not raise
        layers.Embedding(input_dim=3, output_dim=2)
It is recommended to add test cases for boolean values to ensure they are correctly rejected: bool is a subclass of int in Python, so True and False pass a plain isinstance check even though they are invalid as layer dimensions.
Suggested change:

    def test_input_dim_output_dim_must_be_int(self):
        for val in [3.7, True, False]:
            with self.assertRaisesRegex(ValueError, "input_dim"):
                layers.Embedding(input_dim=val, output_dim=2)
        for val in [2.0, True, False]:
            with self.assertRaisesRegex(ValueError, "output_dim"):
                layers.Embedding(input_dim=3, output_dim=val)
        # Valid ints should not raise
        layers.Embedding(input_dim=3, output_dim=2)
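One possible refinement of a looped test like this is `unittest`'s `subTest`, which reports each rejected value separately instead of stopping at the first failure. The sketch below is an assumption-laden illustration: `make_embedding` is a hypothetical stand-in for `layers.Embedding` that performs only the dimension validation, so the example runs without Keras installed.

```python
import unittest


def make_embedding(input_dim, output_dim):
    """Hypothetical stand-in for layers.Embedding; validation only."""
    for name, value in (("input_dim", input_dim), ("output_dim", output_dim)):
        if not isinstance(value, int) or isinstance(value, bool):
            raise ValueError(
                f"`{name}` must be a Python int. Received: {name}={value!r}"
            )
    return (input_dim, output_dim)


class EmbeddingDimTest(unittest.TestCase):
    def test_rejects_non_int_dims(self):
        for val in [3.7, True, False]:
            # Each value is reported as its own sub-test on failure.
            with self.subTest(input_dim=val):
                with self.assertRaisesRegex(ValueError, "input_dim"):
                    make_embedding(input_dim=val, output_dim=2)
        for val in [2.0, True, False]:
            with self.subTest(output_dim=val):
                with self.assertRaisesRegex(ValueError, "output_dim"):
                    make_embedding(input_dim=3, output_dim=val)

    def test_accepts_valid_ints(self):
        # Valid ints should not raise.
        self.assertEqual(make_embedding(3, 2), (3, 2))
```

Saved as a module, this can be run with `python -m unittest`. Whether the Keras test base class is used with `subTest` elsewhere in the repository is not something this sketch assumes.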
Please accept the CLA and Contributor Agreement

Fixed by #22718
Problem
`Embedding.from_config` does not validate that `input_dim` and `output_dim` are Python ints before constructing the layer, allowing float values (e.g. 3.7) to be silently accepted as valid configuration, potentially causing silent truncation or unexpected behavior when the layer is built.

Approach
Add explicit type validation in `Embedding.from_config` (`keras/src/layers/core/embedding.py`) to check that `input_dim` and `output_dim` in the config dict are Python `int` instances before calling `super().from_config(config)`. Raise a descriptive `ValueError` matching the style of the existing validation in `__init__` if either value is not an `int`. Add corresponding test cases in `embedding_test.py` that call `Embedding.from_config` with float values and assert a `ValueError` is raised.

Review comments

WARN
`keras/src/layers/core/embedding.py:348`: If `input_dim` or `output_dim` is absent from the config dict, `config.get(dim_name)` returns `None`, causing the validator to raise "`input_dim` must be a Python int. Received: input_dim=None (of type <class 'NoneType'>)". A missing required key is a different error condition from a wrong type, and the current message is misleading. Consider checking `if dim_name not in config` first and raising a KeyError or a more descriptive ValueError, or at minimum noting the missing-key case in the message.

NIT
`keras/src/layers/core/embedding.py:348`: Python's `bool` is a subclass of `int`, so `isinstance(True, int)` returns `True`. `input_dim=True` or `output_dim=False` would silently pass validation as 1 and 0 respectively. This is consistent with the `__init__` validation added in earlier sessions, but if stricter rejection is desired, add `and not isinstance(value, bool)` to each check.

NIT
`keras/src/layers/core/embedding_test.py:882`: `test_from_config_accepts_valid_int_dims` passes a bare minimal config `{"input_dim": 4, "output_dim": 3}` missing many keys that `from_config` may expect (e.g. `quantization_config` deserialization path). If `serialization_lib.deserialize_keras_object(None)` or subsequent layer construction raises for other reasons, this test fails for reasons unrelated to type validation. Consider deriving the config from `Embedding(4, 3).get_config()` to test the real round-trip path.

Generated by swe session #4
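The three review comments can be illustrated together in a small, Keras-free sketch. `TinyEmbedding` below is a hypothetical stand-in class, not the Keras `Embedding` layer, and its error wording is an assumption. It separates the missing-key case from the wrong-type case, rejects booleans, and is exercised through a `get_config` round trip.

```python
class TinyEmbedding:
    """Hypothetical stand-in for a layer; not the Keras Embedding class."""

    def __init__(self, input_dim, output_dim):
        self.input_dim = input_dim
        self.output_dim = output_dim

    def get_config(self):
        return {"input_dim": self.input_dim, "output_dim": self.output_dim}

    @classmethod
    def from_config(cls, config):
        for dim_name in ("input_dim", "output_dim"):
            # Missing key: report it as such, not as a NoneType type error.
            if dim_name not in config:
                raise ValueError(f"`{dim_name}` is missing from the config.")
            value = config[dim_name]
            # Wrong type: reject floats, None, and bools (bool subclasses int).
            if not isinstance(value, int) or isinstance(value, bool):
                raise ValueError(
                    f"`{dim_name}` must be a Python int. "
                    f"Received: {dim_name}={value!r} (of type {type(value)})"
                )
        return cls(**config)


# Round trip through get_config rather than a hand-written minimal dict.
restored = TinyEmbedding.from_config(TinyEmbedding(4, 3).get_config())
print(restored.get_config())
```

Deriving the config from `get_config()` keeps the test input in step with whatever keys the class actually serializes, which is the motivation given in the last review comment.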