Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"revision": 2
"revision": 3
}
290 changes: 187 additions & 103 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
#

"""This module defines the basic MapToFields operation."""

import datetime
import itertools
import re
from collections import abc
from collections.abc import Callable
from collections.abc import Collection
from collections.abc import Iterable
Comment thread
derrickaw marked this conversation as resolved.
Comment thread
derrickaw marked this conversation as resolved.
Comment thread
derrickaw marked this conversation as resolved.
Comment thread
derrickaw marked this conversation as resolved.
Expand All @@ -29,6 +30,8 @@
from typing import TypeVar
from typing import Union

import jinja2

import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.portability.api import schema_pb2
Expand All @@ -53,20 +56,40 @@
from apache_beam.yaml.yaml_errors import maybe_with_exception_handling_transform_fn
from apache_beam.yaml.yaml_provider import dicts_to_rows
Comment thread
derrickaw marked this conversation as resolved.
Comment thread
derrickaw marked this conversation as resolved.
Comment thread
derrickaw marked this conversation as resolved.

# Import js2py package if it exists
try:
import js2py
from js2py.base import JsObjectWrapper
from py_mini_racer import MiniRacer
except ImportError:
js2py = None
JsObjectWrapper = object
MiniRacer = None

# Matches strings of the exact shape ``YYYY-MM-DDTHH:MM:SS.mmmZ``: three
# fractional-second digits and a literal ``Z`` suffix.
# NOTE(review): presumably the format JS emits when a Date is serialized to
# a string (e.g. Date.prototype.toISOString) - confirm.
_JS_DATE_ISO_REGEX = re.compile(
r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$')
# Matches ASCII-only names built from letters, digits, ``_`` and ``$`` that do
# not start with a digit. JS reserved words are not excluded by this pattern.
_JS_IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_$][a-zA-Z0-9_$]*$')

# Maps a transform name to the name of one of its configuration fields.
# NOTE(review): the lookup site is outside this view; presumably this is the
# field that may be written as a bare expression string - confirm.
_str_expression_fields = {
'AssignTimestamps': 'timestamp',
'Filter': 'keep',
'Partition': 'by',
}

# Jinja2 template that renders one JavaScript arrow function.
#
# Template variables:
#   func_id:      name of the JS variable the function is assigned to.
#   valid_fields: input field names; each is bound as a ``const`` of the same
#                 name, read from ``__row__`` with bracket notation, so the
#                 names must be usable as JS identifiers.
#   expr:         JS expression inserted verbatim as the return value.
JS_EXPR_TEMPLATE = jinja2.Template(
"""
var {{ func_id }} = (__row__) => {
{% for field in valid_fields %}
const {{ field }} = __row__['{{ field }}'];
{% endfor %}
return ({{ expr }});
};
""")

# Jinja2 template that renders the JavaScript entry point ``__aggregate_fn__``.
#
# Template variables:
#   field_funcs: mapping of output field name -> name of a JS function that
#                takes ``__row__`` and returns that field's value.
#
# The rendered function returns an object literal with one key per output
# field. Keys are emitted through the ``tojson`` filter (which supplies the
# surrounding quotes and escapes the content) so that output names containing
# quotes, backslashes or other special characters still form a valid JS string
# literal instead of breaking out of it.
JS_AGGREGATOR_TEMPLATE = jinja2.Template(
"""
var __aggregate_fn__ = (__row__) => ({
{% for name, func_name in field_funcs.items() %}
{{ name | tojson }}: {{ func_name }}(__row__){% if not loop.last %},{% endif %}
{% endfor %}
});
""")


def normalize_mapping(spec):
"""
Expand Down Expand Up @@ -178,112 +201,162 @@ def _check_mapping_arguments(
raise ValueError(f'{transform_name} cannot specify "name" without "path"')


# js2py's JsObjectWrapper object has a self-referencing __dict__ property
# that cannot be pickled without implementing the __getstate__ and
# __setstate__ methods.
class _CustomJsObjectWrapper(JsObjectWrapper):
  """``JsObjectWrapper`` subclass with explicit pickling hooks.

  The wrapper is rebuilt around the ``_obj`` entry of the given object's
  ``__dict__``; state is saved as a shallow copy of the instance ``__dict__``
  and restored by merging it back in.
  """
  def __init__(self, js_obj):
    wrapped = js_obj.__dict__['_obj']
    super().__init__(wrapped)

  def __getstate__(self):
    # Shallow copy so the pickled state is detached from the live instance.
    return dict(self.__dict__)

  def __setstate__(self, state):
    self.__dict__.update(state)


# TODO(yaml) Improve type inference for JS UDFs
def py_value_to_js_dict(py_value):
  """Recursively converts a row-like Python value into plain containers.

  Named tuples (tuples exposing ``_asdict``) and ``beam.Row`` instances are
  first turned into dicts. Dicts are converted value by value, any other
  non-string iterable becomes a list of converted items, and everything else
  is returned unchanged.

  Args:
    py_value: the value to convert.

  Returns:
    A dict, a list, or the original value.
  """
  is_named_tuple = (
      isinstance(py_value, tuple) and hasattr(py_value, '_asdict'))
  if is_named_tuple or isinstance(py_value, beam.Row):
    py_value = py_value._asdict()

  if isinstance(py_value, dict):
    return {
        key: py_value_to_js_dict(value)
        for key, value in py_value.items()
    }
  # Strings are iterable but must stay scalar values.
  if not isinstance(py_value, str) and isinstance(py_value, Iterable):
    return [py_value_to_js_dict(value) for value in py_value]
  return py_value


# TODO(yaml) Consider adding optional language version parameter to support
# ECMAScript 5 and 6
def _expand_javascript_mapping_func(
original_fields, expression=None, callable=None, path=None, name=None):

# Check for installed js2py package
if js2py is None:
raise ValueError(
"Javascript mapping functions are not supported on"
" Python 3.12 or later.")

# import remaining js2py objects
from js2py import base
from js2py.constructors import jsdate
from js2py.internals import simplex

js_array_type = (
base.PyJsArray,
base.PyJsArrayBuffer,
base.PyJsInt8Array,
base.PyJsUint8Array,
base.PyJsUint8ClampedArray,
base.PyJsInt16Array,
base.PyJsUint16Array,
base.PyJsInt32Array,
base.PyJsUint32Array,
base.PyJsFloat32Array,
base.PyJsFloat64Array)

def _js_object_to_py_object(obj):
if isinstance(obj, (base.PyJsNumber, base.PyJsString, base.PyJsBoolean)):
return base.to_python(obj)
elif isinstance(obj, js_array_type):
return [_js_object_to_py_object(value) for value in obj.to_list()]
elif isinstance(obj, jsdate.PyJsDate):
return obj.to_utc_dt()
elif isinstance(obj, (base.PyJsNull, base.PyJsUndefined)):
return None
elif isinstance(obj, base.PyJsError):
raise RuntimeError(obj['message'])
elif isinstance(obj, base.PyJsObject):
return {
key: _js_object_to_py_object(value['value'])
for (key, value) in obj.own.items()
}
elif isinstance(obj, base.JsObjectWrapper):
return _js_object_to_py_object(obj._obj)

def js_to_py(obj):
  """Converts mini-racer mapped objects to standard Python types.

  This is needed because ctx.eval returns objects that implement Mapping
  and Iterable but are not picklable (like JSMappedObjectImpl and JSArrayImpl),
  which would fail when Beam tries to serialize rows containing them.
  We also preserve datetime objects which are correctly produced by ctx.eval
  for JS Date objects.

  Conversion rules, applied recursively:
    * ``datetime.datetime`` values are returned as-is.
    * Mappings become dicts (keys kept, values converted).
    * Strings matching ``_JS_DATE_ISO_REGEX`` are parsed as UTC datetimes;
      if parsing fails the original string is returned.
    * Other iterables except ``bytes`` become lists.
    * Anything else is returned unchanged.

  Args:
    obj: a value returned from the JS context.

  Returns:
    The converted value.
  """
  if isinstance(obj, datetime.datetime):
    return obj
  if isinstance(obj, Mapping):
    return {k: js_to_py(v) for k, v in obj.items()}
  if isinstance(obj, str):
    if _JS_DATE_ISO_REGEX.match(obj):
      try:
        # Swap the trailing 'Z' for an explicit UTC offset before parsing.
        return datetime.datetime.fromisoformat(obj[:-1] + '+00:00')
      except ValueError:
        # Right shape but not a real date (e.g. month 13): keep the string.
        return obj
    return obj
  if isinstance(obj, Iterable) and not isinstance(obj, bytes):
    return [js_to_py(v) for v in obj]
  return obj

if expression:
source = '\n'.join(['function(__row__) {'] + [
f' {name} = __row__.{name}'
for name in original_fields if name in expression
] + [' return (' + expression + ')'] + ['}'])
js_func = _CustomJsObjectWrapper(js2py.eval_js(source))

elif callable:
js_func = _CustomJsObjectWrapper(js2py.eval_js(callable))
class JsFilterDoFn(beam.DoFn):
  """DoFn that emits only the elements a JavaScript function accepts.

  The JS source is evaluated in a fresh ``MiniRacer`` context created in
  ``setup``; ``process`` then calls the named function once per element and
  yields the element when the converted result is truthy.
  """
  def __init__(self, udf_code, function_name):
    # The JS context is created in setup(), not here.
    self.ctx = None
    self.function_name = function_name
    self.udf_code = udf_code

  def setup(self):
    self.ctx = MiniRacer()
    self.ctx.eval(self.udf_code)

  def process(self, element):
    js_row = py_value_to_js_dict(element)
    verdict = js_to_py(self.ctx.call(self.function_name, js_row))
    if not verdict:
      return
    yield element


class JsMapToFieldsDoFn(beam.DoFn):
  """DoFn that computes output fields with JavaScript.

  At construction time every output field is classified as either a
  passthrough (its spec is a string naming a key of ``input_schema``) or a
  JS-computed field. The JS for all computed fields is concatenated into one
  script, followed by an ``__aggregate_fn__`` entry point that returns every
  computed field in a single call. The script is evaluated in ``setup``.

  Args:
    fields: mapping of output field name to its spec: a string, or a dict
      containing ``expression``, ``callable``, or ``path`` together with
      ``name``.
    original_fields: names of the input fields; those that occur in an
      expression and are valid JS identifiers are bound as local constants.
    input_schema: container of input field names used to detect passthroughs.

  Raises:
    ValueError: if a field spec has none of the supported keys, or if a
      ``path`` does not end in ``.js``.
  """
  def __init__(self, fields, original_fields, input_schema):
    # The JS context is created in setup(), not here.
    self.ctx = None
    self.field_funcs = {}
    self.passthrough_fields = []

    script = []
    for i, (name, expr) in enumerate(fields.items()):
      if isinstance(expr, str) and expr in input_schema:
        self.passthrough_fields.append((name, expr))
        continue

      if isinstance(expr, str):
        expr = {'expression': expr}

      # We use numeric indexing (func_{i}) instead of reusing the output field
      # name to prevent syntax errors if output names contain spaces or hyphens.
      # We also use bracket notation for robustness against input field names
      # that aren't compliant dot-access identifiers.
      func_id = f"func_{i}"
      if 'expression' in expr:
        e = expr['expression']
        valid_fields = [
            n for n in original_fields
            if n in e and _JS_IDENTIFIER_PATTERN.match(n)
        ]
        code = JS_EXPR_TEMPLATE.render(
            func_id=func_id, valid_fields=valid_fields, expr=e)
        script.append(code.strip())
        self.field_funcs[name] = func_id
      elif 'callable' in expr:
        script.append(f"var {func_id} = {expr['callable']}")
        self.field_funcs[name] = func_id
      elif 'path' in expr and 'name' in expr:
        path = expr['path']
        # Same validation as the path branch of _get_javascript_udf_code.
        if not path.endswith('.js'):
          raise ValueError(f'File "{path}" is not a valid .js file.')
        # Close the handle deterministically instead of leaking it.
        with FileSystems.open(path) as js_file:
          udf_code = js_file.read().decode()
        script.append(udf_code)
        self.field_funcs[name] = expr['name']
      else:
        # Fail loudly rather than silently dropping the output field.
        raise ValueError("Must specify expression, callable, or path.")

    if self.field_funcs:
      code = JS_AGGREGATOR_TEMPLATE.render(field_funcs=self.field_funcs)
      script.append(code.strip())

    self.script = "\n".join(script) if script else None

  def setup(self):
    self.ctx = MiniRacer()
    if self.script:
      self.ctx.eval(self.script)

  def process(self, element):
    row_as_dict = py_value_to_js_dict(element)
    result_dict = {}

    # Handle passthrough fields
    for name, src in self.passthrough_fields:
      result_dict[name] = row_as_dict.get(src)

    # Handle JS fields via single aggregate call
    if self.field_funcs:
      res = self.ctx.call("__aggregate_fn__", row_as_dict)
      result_dict.update(js_to_py(res))

    yield dicts_to_rows(result_dict)


def _get_javascript_udf_code(
    original_fields,
    function_name="func",
    expression=None,
    callable=None,
    path=None,
    name=None):
  """Builds the JavaScript source for a UDF and the name to call it by.

  Exactly one of ``path``, ``expression`` or ``callable`` is used, checked in
  that order.

  Args:
    original_fields: names of the input fields. For an expression, each name
      that occurs in the expression and is a valid JS identifier is bound as
      a local constant read from ``__row__``.
    function_name: JS variable name given to a generated function.
    expression: JS expression evaluated against the row.
    callable: JS source of a function taking the row.
    path: path to a ``.js`` file that defines the function.
    name: name of the function defined in the file at ``path``.

  Returns:
    A ``(udf_code, function_name)`` tuple. For ``path`` the second item is
    ``name``; otherwise it is ``function_name``.

  Raises:
    ValueError: if py-mini-racer is not installed, if ``path`` does not end
      in ``.js``, or if none of the three sources is given.
  """
  if MiniRacer is None:
    raise ValueError(
        "JavaScript mapping functions require the 'py-mini-racer' package to be installed."
    )

  if path:
    if not path.endswith('.js'):
      raise ValueError(f'File "{path}" is not a valid .js file.')
    # Close the handle deterministically instead of leaking it.
    with FileSystems.open(path) as js_file:
      udf_code = js_file.read().decode()
    return udf_code, name

  if expression:
    # We use bracket notation for robustness against field names that
    # aren't compliant dot-access identifiers. Names that cannot be declared
    # as a JS constant are skipped, matching JsMapToFieldsDoFn.
    bindings = " ".join(
        f"const {n} = __row__['{n}'];" for n in original_fields
        if n in expression and _JS_IDENTIFIER_PATTERN.match(n))
    udf_code = (
        f"var {function_name} = (__row__) => {{ " + bindings +
        f" return ({expression}); }}")
    return udf_code, function_name

  if callable:
    udf_code = f"var {function_name} = {callable}"
    return udf_code, function_name

  raise ValueError("Must specify expression, callable, or path.")


def _expand_python_mapping_func(
Expand Down Expand Up @@ -390,14 +463,10 @@ def _as_callable(original_fields, expr, transform_name, language, input_schema):
explicit_type = expr.pop('output_type', None)
_check_mapping_arguments(transform_name, **expr)

if language == "javascript":
func = _expand_javascript_mapping_func(original_fields, **expr)
elif language in ("python", "generic", None):
if language in ("python", "generic", None):
func = _expand_python_mapping_func(original_fields, **expr)
else:
raise ValueError(
f'Unknown language for mapping transform: {language}. '
'Supported languages are "javascript" and "python."')
raise ValueError(f'Language {language} not supported in this context.')

if explicit_type:
if isinstance(explicit_type, str):
Expand Down Expand Up @@ -636,8 +705,17 @@ def _PyJsFilter(
error_handling: Whether and where to output records that throw errors when
the above expressions are evaluated.
""" # pylint: disable=line-too-long
keep_fn = _as_callable_for_pcoll(pcoll, keep, "keep", language or 'generic')
return pcoll | beam.Filter(keep_fn)
if language == 'javascript':
if isinstance(keep, str):
keep = {'expression': keep}
udf_code, function_name = _get_javascript_udf_code(
[f.name for f in schema_from_element_type(pcoll.element_type).fields],
Comment thread
derrickaw marked this conversation as resolved.
Comment thread
derrickaw marked this conversation as resolved.
**keep
)
return pcoll | beam.ParDo(JsFilterDoFn(udf_code, function_name))
Comment thread
derrickaw marked this conversation as resolved.
else:
keep_fn = _as_callable_for_pcoll(pcoll, keep, "keep", language or 'generic')
return pcoll | beam.Filter(keep_fn)
Comment thread
derrickaw marked this conversation as resolved.


def is_expr(v):
Expand Down Expand Up @@ -709,10 +787,16 @@ def _PyJsMapToFields(
""" # pylint: disable=line-too-long
input_schema, fields = normalize_fields(
pcoll, fields, drop or (), append, language=language or 'generic')
original_fields = list(input_schema.keys())

if language == 'javascript':
options.YamlOptions.check_enabled(pcoll.pipeline, 'javascript')

original_fields = list(input_schema.keys())
if MiniRacer is None:
raise ValueError(
"JavaScript mapping functions require the 'py-mini-racer' package to be installed."
)
return pcoll | beam.ParDo(
JsMapToFieldsDoFn(fields, original_fields, input_schema))
Comment thread
derrickaw marked this conversation as resolved.
Comment thread
derrickaw marked this conversation as resolved.

return pcoll | beam.Select(
**{
Expand Down
Loading
Loading