Skip to content

Commit cf0cf3b

Browse files
authored
Add an annotation to expose transforms to yaml. (#28208)
We should add this to all transforms that are simply parameterized.
1 parent 141e3e6 commit cf0cf3b

2 files changed

Lines changed: 82 additions & 0 deletions

File tree

sdks/python/apache_beam/transforms/ptransform.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ class and wrapper class that allows lambda functions to be used as
3838

3939
import copy
4040
import itertools
41+
import json
4142
import logging
4243
import operator
4344
import os
4445
import sys
4546
import threading
47+
import warnings
4648
from functools import reduce
4749
from functools import wraps
4850
from typing import TYPE_CHECKING
@@ -83,6 +85,7 @@ class and wrapper class that allows lambda functions to be used as
8385
from apache_beam.typehints.trivial_inference import instance_to_type
8486
from apache_beam.typehints.typehints import validate_composite_type_param
8587
from apache_beam.utils import proto_utils
88+
from apache_beam.utils import python_callable
8689

8790
if TYPE_CHECKING:
8891
from apache_beam import coders
@@ -95,6 +98,7 @@ class and wrapper class that allows lambda functions to be used as
9598
'PTransform',
9699
'ptransform_fn',
97100
'label_from_callable',
101+
'annotate_yaml',
98102
]
99103

100104
_LOGGER = logging.getLogger(__name__)
@@ -1096,3 +1100,51 @@ def __ror__(self, pvalueish, _unused=None):
10961100

10971101
def expand(self, pvalue):
10981102
raise RuntimeError("Should never be expanded directly.")
1103+
1104+
1105+
# Defined here to avoid circular import issues for Beam library transforms.
1106+
def annotate_yaml(constructor):
1107+
"""Causes instances of this transform to be annotated with their yaml syntax.
1108+
1109+
Should only be used for transforms that are fully defined by their constructor
1110+
arguments.
1111+
"""
1112+
@wraps(constructor)
1113+
def wrapper(*args, **kwargs):
1114+
transform = constructor(*args, **kwargs)
1115+
1116+
fully_qualified_name = (
1117+
f'{constructor.__module__}.{constructor.__qualname__}')
1118+
try:
1119+
imported_constructor = (
1120+
python_callable.PythonCallableWithSource.
1121+
load_from_fully_qualified_name(fully_qualified_name))
1122+
if imported_constructor != wrapper:
1123+
raise ImportError('Different object.')
1124+
except ImportError:
1125+
warnings.warn(f'Cannot import {constructor} as {fully_qualified_name}.')
1126+
return transform
1127+
1128+
try:
1129+
config = json.dumps({
1130+
'constructor': fully_qualified_name,
1131+
'args': args,
1132+
'kwargs': kwargs,
1133+
})
1134+
except TypeError as exn:
1135+
warnings.warn(
1136+
f'Cannot serialize arguments for {constructor} as json: {exn}')
1137+
return transform
1138+
1139+
original_annotations = transform.annotations
1140+
transform.annotations = lambda: {
1141+
**original_annotations(),
1142+
# These override whatever may have been provided earlier.
1143+
# The outermost call is expected to be the most specific.
1144+
'yaml_provider': 'python',
1145+
'yaml_type': 'PyTransform',
1146+
'yaml_args': config,
1147+
}
1148+
return transform
1149+
1150+
return wrapper

sdks/python/apache_beam/yaml/yaml_transform_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,23 @@ def test_name_is_ambiguous(self):
250250
output: AnotherFilter
251251
''')
252252

253+
def test_annotations(self):
254+
t = LinearTransform(5, b=100)
255+
annotations = t.annotations()
256+
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
257+
pickle_library='cloudpickle')) as p:
258+
result = p | YamlTransform(
259+
'''
260+
type: chain
261+
transforms:
262+
- type: Create
263+
config:
264+
elements: [0, 1, 2, 3]
265+
- type: %r
266+
config: %s
267+
''' % (annotations['yaml_type'], annotations['yaml_args']))
268+
assert_that(result, equal_to([100, 105, 110, 115]))
269+
253270

254271
class CreateTimestamped(beam.PTransform):
255272
def __init__(self, elements):
@@ -631,6 +648,19 @@ def test_prefers_same_provider_class(self):
631648
label='StartWith3')
632649

633650

651+
@beam.transforms.ptransform.annotate_yaml
652+
class LinearTransform(beam.PTransform):
653+
"""A transform used for testing annotate_yaml."""
654+
def __init__(self, a, b):
655+
self._a = a
656+
self._b = b
657+
658+
def expand(self, pcoll):
659+
a = self._a
660+
b = self._b
661+
return pcoll | beam.Map(lambda x: a * x + b)
662+
663+
634664
if __name__ == '__main__':
635665
logging.getLogger().setLevel(logging.INFO)
636666
unittest.main()

0 commit comments

Comments
 (0)