Validate channel count in rgb_to_grayscale and RandomColorDegeneration#22685
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements channel and rank validation for image operations across JAX, NumPy, and TensorFlow backends, specifically targeting rgb_to_grayscale and related preprocessing layers. Feedback was provided to ensure the TensorFlow backend correctly handles dynamic shapes where the channel dimension may be None during graph construction. Additionally, it is recommended to enhance error messages in the RandomColorDegeneration layer to include the full input shape for better debugging context, aligning with Keras API design guidelines.
| if images.shape[channels_axis] not in (1, 3): | ||
| raise ValueError( | ||
| "Invalid channel size: expected 3 (RGB) or 1 (Grayscale). " | ||
| f"Received input with shape: images.shape={images.shape}" | ||
| ) | ||
| if images.shape[channels_axis] == 1: | ||
| return tf.identity(images) |
There was a problem hiding this comment.
The channel validation logic will raise a ValueError if the channel dimension is dynamic (i.e., None). In TensorFlow, images.shape[channels_axis] can be None during graph construction or when using tf.function with unknown shapes. The validation should only be performed if the channel count is statically known.
| if images.shape[channels_axis] not in (1, 3): | |
| raise ValueError( | |
| "Invalid channel size: expected 3 (RGB) or 1 (Grayscale). " | |
| f"Received input with shape: images.shape={images.shape}" | |
| ) | |
| if images.shape[channels_axis] == 1: | |
| return tf.identity(images) | |
| channels = images.shape[channels_axis] | |
| if channels is not None and channels not in (1, 3): | |
| raise ValueError( | |
| "Invalid channel size: expected 3 (RGB) or 1 (Grayscale). " | |
| f"Received input with shape: images.shape={images.shape}" | |
| ) | |
| if channels == 1: | |
| return tf.identity(images) |
| if len(input_shape) not in (3, 4): | ||
| raise ValueError( | ||
| "Invalid images rank: expected rank 3 (single image) " | ||
| "or rank 4 (batch of images). " | ||
| f"Received: input_shape={input_shape}" | ||
| ) | ||
| channels_axis = -1 if self.data_format == "channels_last" else -3 | ||
| channels = input_shape[channels_axis] | ||
| if channels is not None and channels != 3: | ||
| raise ValueError( | ||
| "Input images must have 3 channels, but received images with " | ||
| f"{channels} channels." | ||
| ) |
There was a problem hiding this comment.
The error messages in compute_output_shape should be more informative and consistent with other image operations in Keras. Specifically, including the full input shape in the channel validation error provides better context for debugging. Also, using images.shape instead of input_shape in the rank error message matches the convention used in RGBToGrayscale.
| if len(input_shape) not in (3, 4): | |
| raise ValueError( | |
| "Invalid images rank: expected rank 3 (single image) " | |
| "or rank 4 (batch of images). " | |
| f"Received: input_shape={input_shape}" | |
| ) | |
| channels_axis = -1 if self.data_format == "channels_last" else -3 | |
| channels = input_shape[channels_axis] | |
| if channels is not None and channels != 3: | |
| raise ValueError( | |
| "Input images must have 3 channels, but received images with " | |
| f"{channels} channels." | |
| ) | |
| if len(input_shape) not in (3, 4): | |
| raise ValueError( | |
| "Invalid images rank: expected rank 3 (single image) " | |
| "or rank 4 (batch of images). " | |
| f"Received: images.shape={input_shape}" | |
| ) | |
| channels_axis = -1 if self.data_format == "channels_last" else -3 | |
| channels = input_shape[channels_axis] | |
| if channels is not None and channels != 3: | |
| raise ValueError( | |
| "Input images must have 3 channels, but received images with " | |
| f"{channels} channels. " | |
| f"Received input with shape: images.shape={input_shape}" | |
| ) |
References
- Error messages should be contextual, informative, and actionable. Providing the full shape helps the user understand why the validation failed. (link)
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #22685 +/- ##
==========================================
- Coverage 82.95% 82.95% -0.01%
==========================================
Files 596 596
Lines 69252 69272 +20
Branches 10814 10822 +8
==========================================
+ Hits 57451 57465 +14
- Misses 8973 8976 +3
- Partials 2828 2831 +3
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
9c6f06e to
7b5c21b
Compare
|
Done |
hertschuh
left a comment
There was a problem hiding this comment.
Thanks for the fixes!
Description
Fixes #22536.
keras.ops.image.rgb_to_grayscaleandkeras.layers.RandomColorDegenerationpreviously failed inconsistently when the channel axis size did not match the expected RGB count (3 for RGB or 1 for grayscale):Matrix size-incompatibleerror (TF/JAX/NumPy backends).This PR adds explicit channel validation:
RGBToGrayscale.compute_output_specnow validates the channel axis size (matchesRGBToHSVbehavior).rgb_to_grayscalein TF, JAX, and NumPy backends now validates the channel axis size and treats single-channel input as a no-op (matches existing Torch and OpenVINO behavior).RandomColorDegeneration.compute_output_shapenow validates the channel axis size so the symbolic path raises a clear error before runtime.Contributor Agreement