Fix TextVectorization custom callables receiving TF tensors on non-TF backends #22682
Conversation
…d of TF tensors on non-TF backends. When using a non-TensorFlow backend (e.g., JAX), TextVectorization's custom `standardize` and `split` callables received `tf.EagerTensor` objects because the layer uses TensorFlow internally for string manipulation. This created an unexpected TensorFlow dependency for the user-provided functions. This change ensures that: 1. Inputs are converted to NumPy arrays before being passed to custom callables when the backend is not TensorFlow. 2. Results from the callables are converted back to TensorFlow string tensors so the layer's internal processing can continue. This makes the layer more backend-agnostic and avoids an `AttributeError` when users call Python string methods on what they expect to be strings or NumPy-like objects. Co-authored-by: divyashreepathihalli <78194266+divyashreepathihalli@users.noreply.github.com>
👋 Jules, reporting for duty! I'm here to lend a hand with this pull request. When you start a review, I'll add a 👀 emoji to each comment to let you know I've read it. I'll focus on feedback directed at me and will do my best to stay out of conversations between you and other bots or reviewers to keep the noise down. I'll push a commit with your requested changes shortly after. Please note there might be a delay between these steps, but rest assured I'm on the job! For more direct control, you can switch me to Reactive Mode. When this mode is on, I will only act on comments where you specifically mention me. New to Jules? Learn more at jules.google/docs. For security, I will only act on instructions from the user who triggered this task.
Code Review
This pull request adds logic to the TextVectorization layer to support custom standardization and splitting callables when using non-TensorFlow backends by performing intermediate NumPy conversions. Review feedback highlights a potential failure when tf_utils.ensure_tensor encounters ragged data from custom splitters, suggests refactoring duplicated backend-check logic into a helper function, and requests the removal of temporary test files from the repository root.
```python
inputs = backend.convert_to_numpy(inputs)
inputs = self._split(inputs)
if backend.backend() != "tensorflow":
    inputs = tf_utils.ensure_tensor(inputs, dtype=tf.string)
```
Using tf_utils.ensure_tensor here is problematic because it relies on tf.convert_to_tensor, which cannot handle non-rectangular (ragged) Python sequences. Custom split callables almost always return ragged data (e.g., a list of lists of strings where inner lists have different lengths). This will cause a ValueError: Can't convert non-rectangular Python sequence to Tensor when running with non-TensorFlow backends. You should use a conversion method that supports raggedness, such as tf.ragged.constant, if the output is a nested list.
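A minimal, TF-free sketch of the branching this comment suggests (all names here are hypothetical): detect whether the output of a custom `split` callable is ragged before attempting a rectangular conversion. In the real layer the ragged branch would call `tf.ragged.constant`; a plain nested list stands in for it here.

```python
import numpy as np


def is_ragged(rows):
    """Return True if a list of lists has rows of unequal length."""
    lengths = {len(r) for r in rows}
    return len(lengths) > 1


def to_tensor_like(rows):
    # Hypothetical helper: route ragged output from a custom `split`
    # callable away from a plain rectangular conversion, which would
    # otherwise raise "Can't convert non-rectangular Python sequence".
    if is_ragged(rows):
        # In the real layer this branch would use tf.ragged.constant(rows).
        return [list(r) for r in rows]
    return np.asarray(rows)


# A custom splitter typically returns ragged rows:
split_output = [["hello", "world"], ["hi"]]
result = to_tensor_like(split_output)
```

The point is only that the rectangular path must not be the sole code path; any conversion used after a custom splitter needs a raggedness check first.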
```python
if backend.backend() != "tensorflow":
    inputs = backend.convert_to_numpy(inputs)
inputs = self._standardize(inputs)
if backend.backend() != "tensorflow":
    inputs = tf_utils.ensure_tensor(inputs, dtype=tf.string)
```
The check if backend.backend() != "tensorflow": and the associated conversion logic are duplicated for both the standardize and split callables. This repetition makes the code harder to maintain and slightly less efficient. Consider determining the backend once at the start of _preprocess or using a small helper function to handle the conditional NumPy conversion.
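The refactor this comment asks for could look roughly like the following sketch (helper name and NumPy stand-ins are hypothetical; the real code would use `backend.convert_to_numpy` and `tf_utils.ensure_tensor`):

```python
import numpy as np


def _apply_with_numpy_roundtrip(fn, inputs, backend_name):
    """Hypothetical helper: convert inputs to NumPy for non-TF backends,
    apply the user callable, and convert the result back afterwards."""
    if backend_name != "tensorflow":
        inputs = np.asarray(inputs)  # stand-in for backend.convert_to_numpy
    outputs = fn(inputs)
    if backend_name != "tensorflow":
        outputs = np.asarray(outputs)  # stand-in for tf_utils.ensure_tensor
    return outputs


# Both the standardize and split paths would call the same helper:
standardize = lambda x: np.char.lower(x.astype(str))
out = _apply_with_numpy_roundtrip(standardize, ["HELLO", "World"], "jax")
```

Determining the backend once and funneling both callable sites through one helper removes the duplicated `if backend.backend() != "tensorflow":` blocks.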
```python
import os
import sys

backend = sys.argv[1] if len(sys.argv) > 1 else "jax"
os.environ["KERAS_BACKEND"] = backend

import re
import string

import keras
import numpy as np

strip_chars = string.punctuation


def my_standardize(input_string):
    print(f"Backend: {os.environ['KERAS_BACKEND']}")
    print(f"Type of input_string: {type(input_string)}")
    try:
        input_string = input_string.lower()
    except AttributeError:
        # If it's a tensor, we might need to handle it differently,
        # but the user expects .lower() to work.
        raise
    return re.sub(f"[{re.escape(strip_chars)}]", "", input_string)


layer = keras.layers.TextVectorization(standardize=my_standardize)
try:
    layer.adapt(["Hello, world."])
    print("Adapt successful")
except Exception as e:
    print(f"Caught exception: {e}")
```
```python
import os

os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
import tensorflow as tf

from keras.src import backend
from keras.src.utils import backend_utils

t = tf.constant(["Hello", "World"])
print(f"TF tensor: {t}")
converted = backend_utils.convert_tf_tensor(t)
print(f"Converted type: {type(converted)}")
print(f"Converted value: {converted}")
```
…-TF backends. TextVectorization uses TensorFlow internally even for non-TF backends. Previously, this caused custom callables to receive `tf.EagerTensor` objects instead of the expected backend-native or agnostic types. This change converts inputs to NumPy arrays before passing them to custom `standardize` or `split` callables when the backend is not TensorFlow. The results are then converted back to TensorFlow string tensors for internal processing. This makes the layer more consistent with the "TensorFlow independence" expected when using other backends like JAX or PyTorch. Note: The CI failure regarding "contributor agreement terms" is a process issue that cannot be resolved via code changes by the agent. Co-authored-by: divyashreepathihalli <78194266+divyashreepathihalli@users.noreply.github.com>
…ings on non-TF backends. TextVectorization uses TensorFlow internally even for non-TF backends. Previously, this caused custom callables to receive `tf.EagerTensor` objects instead of the expected backend-native or agnostic types. This change:
1. Converts inputs to NumPy arrays before passing them to custom `standardize` or `split` callables when the backend is not TensorFlow.
2. Converts scalar inputs to plain Python strings (decoded using the layer's encoding) to match user expectations.
3. Converts results back to TensorFlow string tensors for internal processing.
This makes the layer more consistent with the "TensorFlow independence" expected when using other backends like JAX or PyTorch. Note: the CI failure regarding "contributor agreement terms" is a process issue related to the bot nature of this contribution; manual intervention may be required to check the boxes in the PR. Co-authored-by: divyashreepathihalli <78194266+divyashreepathihalli@users.noreply.github.com>
…on-TF backends. TextVectorization uses TensorFlow internally for string manipulation even when a different Keras backend (like JAX) is active. This caused custom `standardize` and `split` callables to receive `tf.EagerTensor` objects, breaking user code that expects plain Python strings (e.g., using `.lower()`). This change:
1. Detects when a non-TF backend is used in eager mode.
2. Converts the input TF tensor to NumPy.
3. Decodes bytes to strings.
4. Uses `np.vectorize` to apply the user's callable element-wise to the batch, passing actual Python strings.
5. Converts the results back to a TensorFlow string tensor for continued internal processing.
6. Gracefully handles cases where the callable might return a TF tensor.
This provides a backend-agnostic experience for users while maintaining the performance benefits of the layer's internal TF-based implementation. Verified with JAX and TensorFlow backends. Co-authored-by: divyashreepathihalli <78194266+divyashreepathihalli@users.noreply.github.com>
Codecov Report ❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #22682 +/- ##
==========================================
- Coverage 82.95% 82.93% -0.02%
==========================================
Files 596 596
Lines 69200 69237 +37
Branches 10806 10817 +11
==========================================
+ Hits 57402 57423 +21
- Misses 8969 8976 +7
- Partials 2829 2838 +9
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
… backends. TextVectorization uses TensorFlow internally for string manipulation even when a different Keras backend (like JAX) is active. This caused custom `standardize` and `split` callables to receive `tf.EagerTensor` objects, breaking user code that expects plain Python strings (e.g., using `.lower()`). This change:
1. Detects when a non-TF backend is used in eager mode.
2. Converts the input TF tensor to NumPy strings/bytes.
3. Decodes bytes to actual Python strings.
4. Applies the user's callable to the batch (using `np.vectorize` for standardization or a list comprehension for splitting to handle ragged results).
5. Converts the results back to a TensorFlow string tensor (handling rectangular or ragged results as appropriate).
6. Handles cases where the callable might return a TF tensor.
This provides a backend-agnostic experience for users while maintaining the functionality of the layer's internal TF-based implementation. Verified with JAX and TensorFlow backends. Fixes # (issue reported by user) Co-authored-by: divyashreepathihalli <78194266+divyashreepathihalli@users.noreply.github.com>
… backends. TextVectorization uses TensorFlow internally for string manipulation even when a different Keras backend (like JAX or OpenVINO) is active. This caused custom `standardize` and `split` callables to receive `tf.EagerTensor` objects instead of the expected plain Python strings. This change:
1. Detects when a non-TF backend is used in eager mode.
2. Converts the input TF tensor to NumPy strings and then actual Python strings.
3. Applies the user's callable to the batch:
   - For `standardize`: uses `np.vectorize` for element-wise application.
   - For `split`: uses a list comprehension to handle potentially ragged results.
4. Converts the results back to TensorFlow string tensors (using `tf.ragged.constant` for splits if necessary) to continue with the layer's internal processing.
5. Gracefully handles cases where the callable might return a TF tensor (preserving compatibility with existing Keras tests).
This ensures a consistent, backend-agnostic experience for users while leveraging the layer's optimized internal implementation. Verified with JAX and TensorFlow backends. Passes pre-commit checks. Co-authored-by: divyashreepathihalli <78194266+divyashreepathihalli@users.noreply.github.com>
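The application strategy this commit describes — `np.vectorize` for the element-wise `standardize` path, a list comprehension for the ragged `split` path — can be sketched in plain NumPy as follows (helper names are hypothetical, not the actual layer internals):

```python
import numpy as np


def apply_standardize(fn, batch):
    # Element-wise application of a string -> string callable, as the
    # commit describes using np.vectorize for standardization.
    return np.vectorize(fn)(np.asarray(batch, dtype=str))


def apply_split(fn, batch):
    # Splitting usually produces ragged rows (different token counts per
    # sample), so a list comprehension is used instead of vectorization.
    return [fn(str(s)) for s in batch]


std = apply_standardize(str.lower, ["Hello", "WORLD"])
rows = apply_split(str.split, ["a b c", "d"])
```

In the real layer, `rows` would then be fed to `tf.ragged.constant` when the row lengths differ, while `std` can go straight through a rectangular string-tensor conversion.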
When using TextVectorization with a custom `standardize` or `split` callable and a non-TensorFlow backend (like JAX), the callable received a `tf.EagerTensor`. This caused issues because users expected either a plain Python string or at least something that doesn't require TensorFlow knowledge to handle.

I have modified `TextVectorization._preprocess` to convert the inputs to NumPy arrays before passing them to the custom callables if the active Keras backend is not TensorFlow. After the callable returns, the result is converted back to a TensorFlow string tensor so it can be processed by the rest of the layer (which relies on `tf.strings` ops).

Summary of changes:
- Modified `keras/src/layers/preprocessing/text_vectorization.py`.
- Added NumPy conversion around the custom `standardize` and `split` callables.
- Custom callables no longer receive `tf.EagerTensor` for JAX/Torch/NumPy backend users.

Fixes #22626
PR created automatically by Jules for task 7273620509560818728 started by @divyashreepathihalli
Contributor Agreement
Please check all boxes below before submitting your PR for review: