Fix TextVectorization custom callables receiving TF tensors on non-TF backends #22682

Closed
divyashreepathihalli wants to merge 6 commits into master from
fix-text-vectorization-custom-callable-tf-tensor-7273620509560818728

Conversation

@divyashreepathihalli
Collaborator

@divyashreepathihalli divyashreepathihalli commented Apr 15, 2026

When using TextVectorization with a custom standardize or split callable and a non-TensorFlow backend (like JAX), the callable received a tf.EagerTensor. This caused issues because users expected either a plain Python string or at least something that doesn't require TensorFlow knowledge to handle.

I have modified TextVectorization._preprocess to convert the inputs to NumPy arrays before passing them to the custom callables if the active Keras backend is not TensorFlow. After the callable returns, the result is converted back to a TensorFlow string tensor so it can be processed by the rest of the layer (which relies on tf.strings ops).

Summary of changes:

  • Modified keras/src/layers/preprocessing/text_vectorization.py.
  • Added conversion to NumPy before custom standardize and split callables.
  • Added conversion back to TF tensor after the callables.
  • This removes the forced exposure to tf.EagerTensor for JAX/Torch/NumPy backend users.
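
One caveat on the NumPy conversion, sketched here under assumptions: converting a TensorFlow string tensor to NumPy yields an object array of `bytes`, so a callable written for a single Python `str` still cannot call `.lower()` on the array itself. The array below is a hand-built stand-in for that converted input, and `my_standardize` is a hypothetical user callable modeled on the reproduction script in this PR.

```python
import re
import string

import numpy as np

# Stand-in for a TF string tensor after NumPy conversion:
# an object array holding UTF-8 encoded bytes.
batch = np.array([b"Hello, world."], dtype=object)


def my_standardize(text):
    """Hypothetical user callable that expects one Python str."""
    text = text.lower()
    return re.sub(f"[{re.escape(string.punctuation)}]", "", text)


# Passing the whole array fails: ndarray has no str methods.
try:
    my_standardize(batch)
    array_call_failed = False
except AttributeError:
    array_call_failed = True

# Decoding each element first gives the callable real Python strings.
cleaned = [my_standardize(item.decode("utf-8")) for item in batch]
```

The later commits in this PR take the decoding route, passing Python strings rather than arrays to the callable.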

Fixes #22626


PR created automatically by Jules for task 7273620509560818728 started by @divyashreepathihalli

Contributor Agreement

Please check all boxes below before submitting your PR for review:

  • I am a human, and not a bot.
  • I will be responsible for responding to review comments in a timely manner.
  • I will work with the maintainers to push this PR forward until submission.

…d of TF tensors on non-TF backends.

When using a non-TensorFlow backend (e.g., JAX), TextVectorization's custom `standardize` and `split` callables were receiving `tf.EagerTensor` objects because the layer uses TensorFlow internally for string manipulation. This created an unexpected dependency on TensorFlow for the user-provided functions.

This change ensures that:
1. Inputs are converted to NumPy arrays before being passed to custom callables when the backend is not TensorFlow.
2. Results from the callables are converted back to TensorFlow string tensors to continue with the layer's internal processing.

This makes the layer more backend-agnostic and avoids AttributeError when users try to use Python string methods on what they expect to be strings or numpy-like objects.

Co-authored-by: divyashreepathihalli <78194266+divyashreepathihalli@users.noreply.github.com>
@google-labs-jules

👋 Jules, reporting for duty! I'm here to lend a hand with this pull request.

When you start a review, I'll add a 👀 emoji to each comment to let you know I've read it. I'll focus on feedback directed at me and will do my best to stay out of conversations between you and other bots or reviewers to keep the noise down.

I'll push a commit with your requested changes shortly after. Please note there might be a delay between these steps, but rest assured I'm on the job!

For more direct control, you can switch me to Reactive Mode. When this mode is on, I will only act on comments where you specifically mention me with @jules. You can find this option in the Pull Request section of your global Jules UI settings. You can always switch back!

New to Jules? Learn more at jules.google/docs.


For security, I will only act on instructions from the user who triggered this task.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request adds logic to the TextVectorization layer to support custom standardization and splitting callables when using non-TensorFlow backends by performing intermediate NumPy conversions. Review feedback highlights a potential failure when tf_utils.ensure_tensor encounters ragged data from custom splitters, suggests refactoring duplicated backend-check logic into a helper function, and requests the removal of temporary test files from the repository root.

    inputs = backend.convert_to_numpy(inputs)
inputs = self._split(inputs)
if backend.backend() != "tensorflow":
    inputs = tf_utils.ensure_tensor(inputs, dtype=tf.string)
Contributor

high

Using `tf_utils.ensure_tensor` here is problematic because it relies on `tf.convert_to_tensor`, which cannot handle non-rectangular (ragged) Python sequences. Custom split callables almost always return ragged data (e.g., a list of lists of strings where inner lists have different lengths). This will cause a `ValueError: Can't convert non-rectangular Python sequence to Tensor` when running with non-TensorFlow backends. You should use a conversion method that supports raggedness, such as `tf.ragged.constant`, if the output is a nested list.
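
The dense-versus-ragged decision suggested here could look roughly like the sketch below. `is_ragged` and `to_string_tensor` are hypothetical helper names, not part of Keras, and the TensorFlow import is deferred so that only the conversion step needs TensorFlow installed.

```python
def is_ragged(rows):
    """Return True if the inner sequences have differing lengths."""
    lengths = {len(row) for row in rows}
    return len(lengths) > 1


def to_string_tensor(rows):
    """Convert split output to a TF string tensor, ragged if needed."""
    import tensorflow as tf  # assumed installed; imported lazily

    if is_ragged(rows):
        # tf.ragged.constant accepts nested lists of unequal length.
        return tf.ragged.constant(rows, dtype=tf.string)
    return tf.constant(rows, dtype=tf.string)


# A whitespace split typically produces rows of different lengths.
tokens = [s.split() for s in ["hello world again", "hi"]]
ragged = is_ragged(tokens)
```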

Comment on lines +542 to +546
if backend.backend() != "tensorflow":
    inputs = backend.convert_to_numpy(inputs)
inputs = self._standardize(inputs)
if backend.backend() != "tensorflow":
    inputs = tf_utils.ensure_tensor(inputs, dtype=tf.string)
Contributor

medium

The check `if backend.backend() != "tensorflow":` and the associated conversion logic are duplicated for both the `standardize` and `split` callables. This repetition makes the code harder to maintain and slightly less efficient. Consider determining the backend once at the start of `_preprocess` or using a small helper function to handle the conditional NumPy conversion.
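
One possible shape for such a helper, as a sketch only: the name `call_user_fn` and its parameters are hypothetical, and the converters are passed in as arguments so the example stays independent of any backend.

```python
def call_user_fn(fn, inputs, *, is_tf_backend, to_numpy, to_tf):
    """Run a user callable, converting around it only off-TensorFlow."""
    if is_tf_backend:
        return fn(inputs)
    return to_tf(fn(to_numpy(inputs)))


# Stub converters that record when they are invoked.
calls = []


def fake_to_numpy(x):
    calls.append("to_numpy")
    return x


def fake_to_tf(x):
    calls.append("to_tf")
    return x


result = call_user_fn(
    str.upper,
    "abc",
    is_tf_backend=False,
    to_numpy=fake_to_numpy,
    to_tf=fake_to_tf,
)
```

Both the `standardize` and `split` call sites could then go through this one code path, with the backend check evaluated once.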

Comment thread repro.py Outdated
Comment on lines +1 to +29
import os
import sys

backend = sys.argv[1] if len(sys.argv) > 1 else "jax"
os.environ["KERAS_BACKEND"] = backend

import re
import string
import keras
import numpy as np

strip_chars = string.punctuation
def my_standardize(input_string):
    print(f"Backend: {os.environ['KERAS_BACKEND']}")
    print(f"Type of input_string: {type(input_string)}")
    try:
        input_string = input_string.lower()
    except AttributeError:
        # If it's a tensor, we might need to handle it differently
        # But the user expects .lower() to work
        raise
    return re.sub(f"[{re.escape(strip_chars)}]", "", input_string)

layer = keras.layers.TextVectorization(standardize=my_standardize)
try:
    layer.adapt(["Hello, world."])
    print("Adapt successful")
except Exception as e:
    print(f"Caught exception: {e}")
Contributor

medium

This reproduction script appears to be a temporary file and should not be committed to the repository root. Please remove it and move any necessary test logic to the appropriate test file in keras/src/layers/preprocessing/.

Comment thread test_convert.py Outdated
Comment on lines +1 to +12
import os
os.environ["KERAS_BACKEND"] = "jax"
import numpy as np
import tensorflow as tf
from keras.src import backend
from keras.src.utils import backend_utils

t = tf.constant(["Hello", "World"])
print(f"TF tensor: {t}")
converted = backend_utils.convert_tf_tensor(t)
print(f"Converted type: {type(converted)}")
print(f"Converted value: {converted}")
Contributor

medium

This test script appears to be a temporary file and should not be committed to the repository root. Please remove it.

google-labs-jules Bot and others added 3 commits April 15, 2026 23:21
…-TF backends.

TextVectorization uses TensorFlow internally even for non-TF backends.
Previously, this caused custom callables to receive `tf.EagerTensor`
objects instead of the expected backend-native or agnostic types.

This change converts inputs to NumPy arrays before passing them to
custom `standardize` or `split` callables when the backend is not
TensorFlow. The results are then converted back to TensorFlow string
tensors for internal processing.

This makes the layer more consistent with the "TensorFlow independence"
expected when using other backends like JAX or PyTorch.

Note: The CI failure regarding "contributor agreement terms" is a
process issue that cannot be resolved via code changes by the agent.

Co-authored-by: divyashreepathihalli <78194266+divyashreepathihalli@users.noreply.github.com>
…ings on non-TF backends.

TextVectorization uses TensorFlow internally even for non-TF backends.
Previously, this caused custom callables to receive `tf.EagerTensor`
objects instead of the expected backend-native or agnostic types.

This change:
1. Converts inputs to NumPy arrays before passing them to custom
   `standardize` or `split` callables when the backend is not TensorFlow.
2. If the input is a scalar, it is converted to a plain Python string
   (decoded using the layer's encoding) to match user expectations.
3. Converts results back to TensorFlow string tensors for internal processing.

This makes the layer more consistent with the "TensorFlow independence"
expected when using other backends like JAX or PyTorch.

Note: The CI failure regarding "contributor agreement terms" is a
process issue related to the bot-nature of this contribution and
manual intervention may be required to check the boxes in the PR.

Co-authored-by: divyashreepathihalli <78194266+divyashreepathihalli@users.noreply.github.com>
…on-TF backends.

TextVectorization uses TensorFlow internally for string manipulation even
when a different Keras backend (like JAX) is active. This caused custom
`standardize` and `split` callables to receive `tf.EagerTensor` objects,
breaking user code that expects plain Python strings (e.g., using `.lower()`).

This change:
1. Detects when a non-TF backend is used in eager mode.
2. Converts the input TF tensor to NumPy.
3. Decodes bytes to strings.
4. Uses `np.vectorize` to apply the user's callable element-wise to the
   batch, passing actual Python strings.
5. Converts the results back to a TensorFlow string tensor for continued
   internal processing.
6. Gracefully handles cases where the callable might return a TF tensor.
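
Steps 2 to 4 above might be sketched as follows, leaving out the TensorFlow conversions on either side. `apply_elementwise` and `strip_and_lower` are hypothetical names, the input array is a hand-built stand-in for the NumPy-converted tensor, and UTF-8 is assumed as the encoding.

```python
import numpy as np


def apply_elementwise(fn, batch, encoding="utf-8"):
    """Decode bytes to str, then apply fn to every element."""
    flat = [
        item.decode(encoding) if isinstance(item, bytes) else item
        for item in batch.ravel()
    ]
    strings = np.array(flat, dtype=object).reshape(batch.shape)
    # otypes=[object] keeps results as Python strings and lets
    # np.vectorize handle empty batches without probing fn.
    return np.vectorize(fn, otypes=[object])(strings)


def strip_and_lower(text):
    """Hypothetical user callable operating on one Python str."""
    return text.strip().lower()


batch = np.array([b"  Hello ", b"WORLD"], dtype=object)
result = apply_elementwise(strip_and_lower, batch)
```

Note that `np.vectorize` is a convenience wrapper around a Python-level loop, so this buys a simpler interface for the callable rather than speed.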

This provides a backend-agnostic experience for users while maintaining
the performance benefits of the layer's internal TF-based implementation.

Verified with JAX and TensorFlow backends.

Co-authored-by: divyashreepathihalli <78194266+divyashreepathihalli@users.noreply.github.com>
@codecov-commenter

codecov-commenter commented Apr 16, 2026

Codecov Report

❌ Patch coverage is 58.97436% with 16 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.93%. Comparing base (a44f5be) to head (ba739ba).
⚠️ Report is 1 commit behind head on master.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...ras/src/layers/preprocessing/text_vectorization.py | 58.97% | 7 Missing and 9 partials ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22682      +/-   ##
==========================================
- Coverage   82.95%   82.93%   -0.02%     
==========================================
  Files         596      596              
  Lines       69200    69237      +37     
  Branches    10806    10817      +11     
==========================================
+ Hits        57402    57423      +21     
- Misses       8969     8976       +7     
- Partials     2829     2838       +9     

| Flag | Coverage Δ |
|---|---|
| keras | 82.75% <53.84%> (-0.02%) ⬇️ |
| keras-jax | 58.75% <48.71%> (-0.01%) ⬇️ |
| keras-numpy | 54.59% <48.71%> (-0.01%) ⬇️ |
| keras-openvino | 59.38% <48.71%> (-0.01%) ⬇️ |
| keras-tensorflow | 60.28% <5.12%> (-0.04%) ⬇️ |
| keras-torch | 59.09% <48.71%> (-0.01%) ⬇️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

google-labs-jules Bot and others added 2 commits April 16, 2026 01:18
… backends.

TextVectorization uses TensorFlow internally for string manipulation even
when a different Keras backend (like JAX) is active. This caused custom
`standardize` and `split` callables to receive `tf.EagerTensor` objects,
breaking user code that expects plain Python strings (e.g., using `.lower()`).

This change:
1. Detects when a non-TF backend is used in eager mode.
2. Converts the input TF tensor to NumPy strings/bytes.
3. Decodes bytes to actual Python strings.
4. Applies the user's callable to the batch (using `np.vectorize` for
   standardization or a list comprehension for splitting to handle ragged
   results).
5. Converts the results back to a TensorFlow string tensor (handling
   rectangular or ragged results as appropriate).
6. Handles cases where the callable might return a TF tensor.

This provides a backend-agnostic experience for users while maintaining
the functionality of the layer's internal TF-based implementation.

Verified with JAX and TensorFlow backends.
Fixes # (issue reported by user)

Co-authored-by: divyashreepathihalli <78194266+divyashreepathihalli@users.noreply.github.com>
… backends.

TextVectorization uses TensorFlow internally for string manipulation even
when a different Keras backend (like JAX or OpenVINO) is active. This caused
custom `standardize` and `split` callables to receive `tf.EagerTensor`
objects instead of the expected plain Python strings.

This change:
1. Detects when a non-TF backend is used in eager mode.
2. Converts the input TF tensor to NumPy strings and then actual Python strings.
3. Applies the user's callable to the batch:
   - For `standardize`: uses `np.vectorize` for element-wise application.
   - For `split`: uses list comprehension to handle potentially ragged results.
4. Converts the results back to TensorFlow string tensors (using
   `tf.ragged.constant` for splits if necessary) to continue with the
   layer's internal processing.
5. Gracefully handles cases where the callable might return a TF tensor
   (preserving compatibility with existing Keras tests).
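
The `split` branch in step 3 could be sketched in plain Python as below. `split_batch`, `_to_str`, and `split_on_commas` are hypothetical names, and UTF-8 is assumed as the encoding. The result is left as a nested list; per step 4, such a list would then be handed to `tf.ragged.constant`.

```python
def _to_str(item, encoding):
    """Decode bytes to str; pass through anything already decoded."""
    return item.decode(encoding) if isinstance(item, bytes) else item


def split_batch(split_fn, batch, encoding="utf-8"):
    """Apply split_fn per element; rows may differ in length."""
    return [list(split_fn(_to_str(item, encoding))) for item in batch]


def split_on_commas(text):
    """Hypothetical user callable operating on one Python str."""
    return [part.strip() for part in text.split(",")]


rows = split_batch(split_on_commas, [b"a, b, c", b"d"])
```

A list comprehension suits this case because each row may have a different length, which a rectangular array-returning wrapper such as `np.vectorize` does not represent naturally.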

This ensures a consistent, backend-agnostic experience for users while
leveraging the layer's optimized internal implementation.

Verified with JAX and TensorFlow backends. Passes pre-commit checks.

Co-authored-by: divyashreepathihalli <78194266+divyashreepathihalli@users.noreply.github.com>
@keerthanakadiri added the stat:awaiting keras-eng label (Awaiting response from Keras engineer) on Apr 17, 2026
@hertschuh hertschuh closed this Apr 20, 2026

Labels

size:M, stat:awaiting keras-eng (Awaiting response from Keras engineer)

Development

Successfully merging this pull request may close these issues.

TextVectorization custom standardize callable receives EagerTensor instead of a Python string

5 participants