Validate input channels match kernel in depthwise/separable conv #22687

Merged
hertschuh merged 6 commits into keras-team:master from rstar327:fix-separable-conv-channels-first
Apr 24, 2026
Conversation

@rstar327
Contributor

Description

Fixes #22516.

keras.ops.separable_conv and keras.ops.depthwise_conv previously gave confusing error messages when the input's channel dimension didn't match the kernel's expected input channels (e.g., passing channels-last shaped data with data_format="channels_first"):

  • Eager: TF backend error: input depth must be evenly divisible by filter depth: 5 vs 3
  • Symbolic: Kernel shape must have the same length as input

Neither message clearly identifies the root cause: a mismatch between the input's channel count and the kernel's expected input channels.

This PR adds explicit channel validation in both the symbolic path (compute_output_spec) and the eager path (before dispatching to the backend), producing a clear error:

ValueError: The number of input channels must match the kernel's input channels.
Received: input channels=5, kernel input channels=3, data_format='channels_first'.

Changes:

  • Added _check_input_channels_match_kernel helper in keras/src/ops/nn.py
  • Added channel validation in DepthwiseConv.compute_output_spec (symbolic path)
  • Added channel validation in depthwise_conv() and separable_conv() (eager path)
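
To make the failure mode concrete, here is a minimal sketch of how the channel axis is read from a shape under each `data_format`. The shapes are illustrative assumptions (the exact repro from #22516 is not quoted in this thread), and `channel_count` is a hypothetical name, not a Keras function. The kernel's input channels are taken from `kernel.shape[-2]`, as in the helper in this PR.

```python
def channel_count(input_shape, data_format):
    """Return the size of the axis that data_format designates as channels."""
    if data_format == "channels_last":
        return input_shape[-1]
    return input_shape[1]


# Assumed shapes: a channels-last batch (N, H, W, C) with 3 channels and a
# depthwise kernel laid out as (kh, kw, in_channels, multiplier).
input_shape = (2, 5, 5, 3)
kernel_shape = (3, 3, 3, 1)
kernel_in_channels = kernel_shape[-2]

# Declared correctly, the last axis is read and matches the kernel.
matching = channel_count(input_shape, "channels_last")

# Declared as channels_first, axis 1 (really the height) is read instead.
mismatched = channel_count(input_shape, "channels_first")

print(matching, kernel_in_channels)
print(mismatched, kernel_in_channels)
```

With these assumed shapes the second case compares 5 against 3, which is the kind of mismatch the new error message reports.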

Contributor Agreement

  • I am a human, and not a bot.
  • I will be responsible for responding to review comments in a timely manner.
  • I will work with the maintainers to push this PR forward until submission.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a _check_input_channels_match_kernel helper function to validate input channels for depthwise and separable convolutions. The review feedback suggests enhancing this helper to support symbolic tensors with None dimensions, which would allow for its reuse in DepthwiseConv.compute_output_spec and eliminate redundant validation logic.

Comment thread keras/src/ops/nn.py Outdated
Comment on lines +1527 to +1538
def _check_input_channels_match_kernel(inputs, kernel, data_format):
    input_channels = (
        inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
    )
    kernel_input_channels = kernel.shape[-2]
    if input_channels != kernel_input_channels:
        raise ValueError(
            "The number of input channels must match the kernel's "
            f"input channels. Received: input channels="
            f"{input_channels}, kernel input channels="
            f"{kernel_input_channels}, data_format='{data_format}'."
        )
Contributor

medium

To improve reusability and reduce code duplication, you can modify this helper function to also handle symbolic tensors, where shape dimensions can be None. This would allow you to use it in DepthwiseConv.compute_output_spec as well.

By adding is not None checks, this function becomes safe for both eager and symbolic execution paths.

def _check_input_channels_match_kernel(inputs, kernel, data_format):
    input_channels = (
        inputs.shape[-1] if data_format == "channels_last" else inputs.shape[1]
    )
    kernel_input_channels = kernel.shape[-2]
    if (
        input_channels is not None
        and kernel_input_channels is not None
        and input_channels != kernel_input_channels
    ):
        raise ValueError(
            "The number of input channels must match the kernel's "
            f"input channels. Received: input channels="
            f"{input_channels}, kernel input channels="
            f"{kernel_input_channels}, data_format='{data_format}'."
        )

Comment thread keras/src/ops/nn.py Outdated
Comment on lines +1568 to +1585
input_channels = (
    inputs.shape[-1]
    if self.data_format == "channels_last"
    else inputs.shape[1]
)
kernel_input_channels = kernel.shape[-2]
if (
    input_channels is not None
    and kernel_input_channels is not None
    and input_channels != kernel_input_channels
):
    raise ValueError(
        "The number of input channels must match the kernel's "
        f"input channels. Received: input channels="
        f"{input_channels}, kernel input channels="
        f"{kernel_input_channels}, data_format="
        f"'{self.data_format}'."
    )
Contributor

medium

Following the suggested change to _check_input_channels_match_kernel, you can replace this entire validation block with a single call to the updated helper function. This simplifies the code and removes duplication.

        _check_input_channels_match_kernel(inputs, kernel, self.data_format)

@codecov-commenter

codecov-commenter commented Apr 17, 2026

Codecov Report

❌ Patch coverage is 94.73684% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 83.03%. Comparing base (e94cb07) to head (2fae3a7).
⚠️ Report is 24 commits behind head on master.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| keras/src/backend/common/backend_utils.py | 60.00% | 1 Missing and 1 partial ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22687      +/-   ##
==========================================
+ Coverage   82.95%   83.03%   +0.08%     
==========================================
  Files         596      596              
  Lines       69252    69858     +606     
  Branches    10814    10884      +70     
==========================================
+ Hits        57451    58010     +559     
- Misses       8973     8999      +26     
- Partials     2828     2849      +21     

| Flag | Coverage Δ |
|---|---|
| keras | 82.86% <94.73%> (+0.08%) ⬆️ |
| keras-jax | 58.35% <42.10%> (-0.37%) ⬇️ |
| keras-numpy | 54.23% <44.73%> (-0.34%) ⬇️ |
| keras-openvino | 59.77% <23.68%> (+0.34%) ⬆️ |
| keras-tensorflow | 59.90% <34.21%> (-0.38%) ⬇️ |
| keras-torch | 58.69% <44.73%> (-0.37%) ⬇️ |


@keerthanakadiri keerthanakadiri added the stat:awaiting keras-eng Awaiting response from Keras engineer label Apr 17, 2026
Collaborator

@hertschuh hertschuh left a comment

Thanks for looking into this!

Comment thread keras/src/ops/nn.py Outdated
        return DepthwiseConv(
            strides, padding, data_format, dilation_rate
        ).symbolic_call(inputs, kernel)
    _check_input_channels_match_kernel(inputs, kernel, data_format)
Collaborator

The reason why we don't do any validation involving the inputs in the backend-independent code is that there are scenarios where this won't work:

  • inputs can be a nested Python array and won't have a shape
  • inputs can be a tracer / placeholder tensor and have dynamic dimensions

I would check which backends don't handle this correctly, then add this check to the backends that need it, after the convert_to_tensor(inputs) call.
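
As an illustration of the first scenario, the sketch below shows that a nested Python list exposes no `shape` attribute until something converts it. `infer_shape` is a hypothetical pure-Python stand-in for what a backend would see after `convert_to_tensor`; it is not a Keras API and ignores ragged inputs.

```python
def infer_shape(value):
    """Hypothetical stand-in: derive a shape tuple from nested lists."""
    shape = []
    while isinstance(value, (list, tuple)):
        shape.append(len(value))
        # Descend into the first element; stop on an empty sequence.
        value = value[0] if value else None
    return tuple(shape)


# A nested list with 2 rows of 4 positions and 3 values each.
nested = [[[0.0, 0.0, 0.0]] * 4] * 2

has_shape = hasattr(nested, "shape")
converted_shape = infer_shape(nested)

print(has_shape)
print(converted_shape)
```

Any check written as `inputs.shape[...]` in backend-independent code would fail with an `AttributeError` on such an input, which is why the validation belongs after conversion.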

@hertschuh hertschuh added stat:awaiting response from contributor and removed stat:awaiting keras-eng Awaiting response from Keras engineer labels Apr 23, 2026
@rstar327
Contributor Author

Thanks for the review! I've pushed a refactor addressing the feedback:

  • Removed the check from keras/src/ops/nn.py (both the standalone helper and the DepthwiseConv.compute_output_spec check), so the backend-independent path no longer touches inputs.shape.
  • Added a single shared helper check_depthwise_conv_input_channels in keras/src/backend/common/backend_utils.py.
  • Called it from each of the four backends (TF / JAX / Torch / NumPy) after convert_to_tensor(inputs), so we only run it against a concrete tensor. The JAX / Torch / NumPy separable_conv already delegate to their own depthwise_conv, so the check fires once per call.

Verified all four backends now raise the same clear ValueError on the #22516 repro, and the existing nn_test.py + separable_conv_test.py + depthwise_conv_test.py suites pass on all four backends. Let me know if you'd like any adjustments.

Comment on lines +321 to +327
if (
    input_channels is None
    or kernel_input_channels is None
    or input_channels == kernel_input_channels
):
    return
raise ValueError(
Collaborator

Write it like this:

if (
    isinstance(input_channels, int)
    and isinstance(kernel_input_channels, int)
    and input_channels != kernel_input_channels
):
    raise ValueError(...)

The reason is that dynamic dimensions can come in different forms, not just None, during tracing.
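
A small sketch of the difference, under the assumption that a dynamic dimension may be some object other than `None`. `SymbolicDim` is a hypothetical placeholder class, not a real backend type, and both check functions are simplified versions of the logic discussed here.

```python
class SymbolicDim:
    """Hypothetical stand-in for a dynamic dimension produced during tracing."""

    def __init__(self, name):
        self.name = name


def check_none_guard(input_channels, kernel_input_channels):
    # Earlier form: only None is treated as "unknown".
    if (
        input_channels is None
        or kernel_input_channels is None
        or input_channels == kernel_input_channels
    ):
        return
    raise ValueError("channel mismatch")


def check_int_guard(input_channels, kernel_input_channels):
    # Suggested form: compare only when both sides are concrete ints.
    if (
        isinstance(input_channels, int)
        and isinstance(kernel_input_channels, int)
        and input_channels != kernel_input_channels
    ):
        raise ValueError("channel mismatch")


def raises(check, *args):
    """Return True if the check raises ValueError for the given arguments."""
    try:
        check(*args)
    except ValueError:
        return True
    return False


dynamic = SymbolicDim("c")
print(raises(check_none_guard, dynamic, 3))  # unknown dim is compared anyway
print(raises(check_int_guard, dynamic, 3))   # unknown dim is skipped
print(raises(check_int_guard, 5, 3))         # concrete mismatch is still caught
```

In this sketch the `None`-only guard raises on a dimension whose value is simply not known yet, while the `isinstance` guard validates only concrete sizes.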

Comment on lines +853 to 861
inputs = convert_to_tensor(inputs)
kernel = convert_to_tensor(kernel)
check_depthwise_conv_input_channels(inputs, kernel, data_format)
num_spatial_dims = len(inputs.shape) - 2
if num_spatial_dims > 2:
    raise ValueError(
        "`inputs` rank must be 3 (1D conv) or 4 (2D conv). Received: "
        f"{inputs.ndim}."
    )
Collaborator

Can you move the check_depthwise_conv_input_channels call after the check on num_spatial_dims?

The reason is that if the number of dimensions is incorrect, the math in check_depthwise_conv_input_channels will be incorrect and meaningless and may raise a confusing error.
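
The ordering concern can be sketched with plain shape tuples. `validate_conv_inputs` is a hypothetical wrapper written for illustration, not the backend function itself; it applies the rank guard first and the channel comparison second.

```python
def validate_conv_inputs(input_shape, kernel_shape, data_format):
    """Hypothetical wrapper: rank guard first, channel check second."""
    num_spatial_dims = len(input_shape) - 2
    if num_spatial_dims > 2:
        raise ValueError(
            "`inputs` rank must be 3 (1D conv) or 4 (2D conv). "
            f"Received: {len(input_shape)}."
        )
    # Only reached when the rank is plausible, so the axis lookup is meaningful.
    input_channels = (
        input_shape[-1] if data_format == "channels_last" else input_shape[1]
    )
    kernel_input_channels = kernel_shape[-2]
    if (
        isinstance(input_channels, int)
        and isinstance(kernel_input_channels, int)
        and input_channels != kernel_input_channels
    ):
        raise ValueError(
            f"input channels={input_channels}, "
            f"kernel input channels={kernel_input_channels}"
        )


# A rank-5 input: the rank error is reported, not a channel count read
# from an axis that has no defined meaning for this op.
try:
    validate_conv_inputs((1, 4, 8, 8, 8), (3, 3, 4, 1), "channels_last")
except ValueError as error:
    message = str(error)
    print(message)
```

Had the channel check run first on this assumed input, it would have compared 8 against 4 and reported a channel mismatch, hiding the real problem.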

Comment on lines +689 to +691
inputs = convert_to_tensor(inputs)
kernel = convert_to_tensor(kernel)
check_depthwise_conv_input_channels(inputs, kernel, data_format)
Collaborator

Any reason why you didn't do separable_conv?

Comment on lines +934 to 943
inputs = convert_to_tensor(inputs)
depthwise_kernel = convert_to_tensor(depthwise_kernel)
pointwise_kernel = convert_to_tensor(pointwise_kernel)
check_depthwise_conv_input_channels(inputs, depthwise_kernel, data_format)
num_spatial_dims = len(inputs.shape) - 2
if num_spatial_dims > 2:
    raise ValueError(
        "`num_spatial_dims` must be 1 or 2. Received: "
        f"num_spatial_dims={num_spatial_dims}."
    )
Collaborator

Can you move the check_depthwise_conv_input_channels call after the check on num_spatial_dims?

The reason is that if the number of dimensions is incorrect, the math in check_depthwise_conv_input_channels will be incorrect and meaningless and may raise a confusing error.

Comment on lines +789 to +791
inputs = convert_to_tensor(inputs)
kernel = convert_to_tensor(kernel)
check_depthwise_conv_input_channels(inputs, kernel, data_format)
Collaborator

Any reason why you didn't do separable_conv?

Comment on lines +662 to +665
data_format = backend.standardize_data_format(data_format)
inputs = convert_to_tensor(inputs)
kernel = convert_to_tensor(kernel)
check_depthwise_conv_input_channels(inputs, kernel, data_format)
Collaborator

Any reason why you didn't do separable_conv?

@rstar327
Contributor Author

Thanks for the review! Pushed 2fae3a72d addressing all four points:

  1. check_depthwise_conv_input_channels now uses isinstance(..., int) instead of an explicit None check, so dynamic/tracing dims in other forms fall through to the no-op path instead of accidentally passing the validation.
  2. TF depthwise_conv / separable_conv: moved the channel check after the num_spatial_dims > 2 guard so we don't run wrong-axis math on a misranked input.
  3. JAX / Torch / NumPy separable_conv: added the check explicitly. My original reasoning was that these three backends' separable_conv delegates to their own depthwise_conv (which already has the check), but you're right that making it explicit is cleaner and safer against future refactors.

Re-verified the #22516 repro raises the clear error on all four backends for both depthwise_conv and separable_conv, and the existing nn_test.py depthwise/separable tests (24 each) pass on TF / JAX / Torch / NumPy.
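
For readers following the delegation point, here is a toy sketch of the resulting call structure. All three functions are simplified stand-ins operating on shape tuples; they only mirror the pattern described above (an explicit check in `separable_conv`, which then delegates to `depthwise_conv`) and are not the backend implementations.

```python
check_calls = []


def check_channels(input_shape, kernel_shape, data_format):
    """Toy version of the shared channel check; records each invocation."""
    check_calls.append(data_format)
    input_channels = (
        input_shape[-1] if data_format == "channels_last" else input_shape[1]
    )
    kernel_input_channels = kernel_shape[-2]
    if (
        isinstance(input_channels, int)
        and isinstance(kernel_input_channels, int)
        and input_channels != kernel_input_channels
    ):
        raise ValueError("channel mismatch")


def depthwise_conv(input_shape, kernel_shape, data_format="channels_last"):
    check_channels(input_shape, kernel_shape, data_format)
    return "depthwise result"


def separable_conv(input_shape, kernel_shape, data_format="channels_last"):
    # Explicit check: still present if the delegation below is refactored away.
    check_channels(input_shape, kernel_shape, data_format)
    return depthwise_conv(input_shape, kernel_shape, data_format)


result = separable_conv((2, 5, 5, 3), (3, 3, 3, 1))
print(result, len(check_calls))
```

In this toy the check runs twice on a valid call. Since it only reads two shape entries, the repetition costs little, and the explicit call keeps `separable_conv` validated on its own.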

@hertschuh
Collaborator

> 3. JAX / Torch / NumPy separable_conv: added the check explicitly. My original reasoning was that these three backends' separable_conv delegates to their own depthwise_conv (which already has the check), but you're right that making it explicit is cleaner and safer against future refactors.

Oh, I did not realize that.

Collaborator

@hertschuh hertschuh left a comment

Thanks for your contribution!

@google-ml-butler google-ml-butler Bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Apr 24, 2026
@hertschuh hertschuh merged commit 634e895 into keras-team:master Apr 24, 2026
11 checks passed

Labels

kokoro:force-run ready to pull Ready to be merged into the codebase size:S

Projects

None yet

Development

Successfully merging this pull request may close these issues.

keras.ops.separable_conv mishandles data_format shape validation and kernel shape inference

5 participants