Skip to content

Validate channel count in rgb_to_grayscale and RandomColorDegeneration#22685

Merged
hertschuh merged 3 commits intokeras-team:masterfrom
rstar327:validate-rgb-to-grayscale-channels
Apr 20, 2026
Merged

Validate channel count in rgb_to_grayscale and RandomColorDegeneration#22685
hertschuh merged 3 commits intokeras-team:masterfrom
rstar327:validate-rgb-to-grayscale-channels

Conversation

@rstar327
Copy link
Copy Markdown
Contributor

Description

Fixes #22536.

keras.ops.image.rgb_to_grayscale and keras.layers.RandomColorDegeneration previously failed inconsistently when the channel axis size did not match the expected RGB count (3 for RGB or 1 for grayscale):

  • Eager execution crashed with a confusing Matrix size-incompatible error (TF/JAX/NumPy backends).
  • Symbolic execution silently succeeded, even though the configuration was invalid.

This PR adds explicit channel validation:

  • RGBToGrayscale.compute_output_spec now validates the channel axis size (matches RGBToHSV behavior).
  • rgb_to_grayscale in TF, JAX, and NumPy backends now validates the channel axis size and treats single-channel input as a no-op (matches existing Torch and OpenVINO behavior).
  • RandomColorDegeneration.compute_output_shape now validates the channel axis size so the symbolic path raises a clear error before runtime.
  • Added a unit test covering the invalid channel case.

Contributor Agreement

  • I am a human, and not a bot.
  • I will be responsible for responding to review comments in a timely manner.
  • I will work with the maintainers to push this PR forward until submission.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements channel and rank validation for image operations across JAX, NumPy, and TensorFlow backends, specifically targeting rgb_to_grayscale and related preprocessing layers. Feedback was provided to ensure the TensorFlow backend correctly handles dynamic shapes where the channel dimension may be None during graph construction. Additionally, it is recommended to enhance error messages in the RandomColorDegeneration layer to include the full input shape for better debugging context, aligning with Keras API design guidelines.

Comment on lines +61 to +67
if images.shape[channels_axis] not in (1, 3):
raise ValueError(
"Invalid channel size: expected 3 (RGB) or 1 (Grayscale). "
f"Received input with shape: images.shape={images.shape}"
)
if images.shape[channels_axis] == 1:
return tf.identity(images)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The channel validation logic will raise a ValueError if the channel dimension is dynamic (i.e., None). In TensorFlow, images.shape[channels_axis] can be None during graph construction or when using tf.function with unknown shapes. The validation should only be performed if the channel count is statically known.

Suggested change
if images.shape[channels_axis] not in (1, 3):
raise ValueError(
"Invalid channel size: expected 3 (RGB) or 1 (Grayscale). "
f"Received input with shape: images.shape={images.shape}"
)
if images.shape[channels_axis] == 1:
return tf.identity(images)
channels = images.shape[channels_axis]
if channels is not None and channels not in (1, 3):
raise ValueError(
"Invalid channel size: expected 3 (RGB) or 1 (Grayscale). "
f"Received input with shape: images.shape={images.shape}"
)
if channels == 1:
return tf.identity(images)

Comment on lines +142 to +154
if len(input_shape) not in (3, 4):
raise ValueError(
"Invalid images rank: expected rank 3 (single image) "
"or rank 4 (batch of images). "
f"Received: input_shape={input_shape}"
)
channels_axis = -1 if self.data_format == "channels_last" else -3
channels = input_shape[channels_axis]
if channels is not None and channels != 3:
raise ValueError(
"Input images must have 3 channels, but received images with "
f"{channels} channels."
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The error messages in compute_output_shape should be more informative and consistent with other image operations in Keras. Specifically, including the full input shape in the channel validation error provides better context for debugging. Also, using images.shape instead of input_shape in the rank error message matches the convention used in RGBToGrayscale.

Suggested change
if len(input_shape) not in (3, 4):
raise ValueError(
"Invalid images rank: expected rank 3 (single image) "
"or rank 4 (batch of images). "
f"Received: input_shape={input_shape}"
)
channels_axis = -1 if self.data_format == "channels_last" else -3
channels = input_shape[channels_axis]
if channels is not None and channels != 3:
raise ValueError(
"Input images must have 3 channels, but received images with "
f"{channels} channels."
)
if len(input_shape) not in (3, 4):
raise ValueError(
"Invalid images rank: expected rank 3 (single image) "
"or rank 4 (batch of images). "
f"Received: images.shape={input_shape}"
)
channels_axis = -1 if self.data_format == "channels_last" else -3
channels = input_shape[channels_axis]
if channels is not None and channels != 3:
raise ValueError(
"Input images must have 3 channels, but received images with "
f"{channels} channels. "
f"Received input with shape: images.shape={input_shape}"
)
References
  1. Error messages should be contextual, informative, and actionable. Providing the full shape helps the user understand why the validation failed. (link)

@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented Apr 16, 2026

Codecov Report

❌ Patch coverage is 56.52174% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.95%. Comparing base (e94cb07) to head (1d55990).

Files with missing lines Patch % Lines
...g/image_preprocessing/random_color_degeneration.py 33.33% 2 Missing and 2 partials ⚠️
keras/src/backend/jax/image.py 50.00% 1 Missing and 1 partial ⚠️
keras/src/backend/numpy/image.py 50.00% 1 Missing and 1 partial ⚠️
keras/src/backend/tensorflow/image.py 50.00% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22685      +/-   ##
==========================================
- Coverage   82.95%   82.95%   -0.01%     
==========================================
  Files         596      596              
  Lines       69252    69272      +20     
  Branches    10814    10822       +8     
==========================================
+ Hits        57451    57465      +14     
- Misses       8973     8976       +3     
- Partials     2828     2831       +3     
Flag Coverage Δ
keras 82.77% <56.52%> (-0.01%) ⬇️
keras-jax 58.71% <39.13%> (-0.01%) ⬇️
keras-numpy 54.55% <39.13%> (-0.01%) ⬇️
keras-openvino 59.42% <30.43%> (-0.01%) ⬇️
keras-tensorflow 60.27% <39.13%> (-0.01%) ⬇️
keras-torch 59.04% <30.43%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@rstar327
Copy link
Copy Markdown
Contributor Author

Done

Copy link
Copy Markdown
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fixes!

@google-ml-butler google-ml-butler Bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Apr 20, 2026
@hertschuh hertschuh merged commit 66426ea into keras-team:master Apr 20, 2026
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

kokoro:force-run ready to pull Ready to be merged into the codebase size:M

Projects

None yet

Development

Successfully merging this pull request may close these issues.

RandomColorDegeneration fails on eager tensor inputs with matrix size-incompatible error

5 participants