Skip to content

Commit 92cef32

Browse files
authored
Num failed inferences (#23830)
* created test_increment_num_failed_inferences and test_num_failed_inferences_no_failures
* added assertRaises to test_increment_num_failed_inferences
* added num_failed_inferences to _MetricsCollector
* changed error handling and update() implementation
* updated metric name in tests
* removed unnecessary else blocking
* removed unnecessary inference_args from test_increment_failed_batches_counter()
* changed final test_increment_failed_batches_counter assertion
* clarified error handling and updated failed_batches_counter initialization
* decreased examples array length to 1 to ensure repeatability
* troubleshooting tests
* trying to get test_increment_failed_batches_counter to fail as expected
* corrected assertion details
* simplified assertRaises and added reminder comment to assertEqual counter
* lint test
* lint test passed, resetting pre-commit-config.yaml
* fixed lingering linting issues
* shortened comment line to comply with linting
* formatter worked its magic
1 parent 2341f61 commit 92cef32

3 files changed

Lines changed: 134 additions & 41 deletions

File tree

sdks/python/apache_beam/examples/snippets/transforms/aggregation/groupby_test.py

Lines changed: 95 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def test_groupby_expr(self):
8282
with beam.Pipeline() as p:
8383
grouped = (
8484
p
85-
| beam.Create(['strawberry', 'raspberry', 'blueberry', 'blackberry', 'banana'])
85+
| beam.Create(
86+
['strawberry', 'raspberry', 'blueberry', 'blackberry', 'banana'])
8687
| beam.GroupBy(lambda s: s[0]))
8788
# [END groupby_expr]
8889

@@ -101,8 +102,10 @@ def test_groupby_two_exprs(self):
101102
with beam.Pipeline() as p:
102103
grouped = (
103104
p
104-
| beam.Create(['strawberry', 'raspberry', 'blueberry', 'blackberry', 'banana'])
105-
| beam.GroupBy(letter=lambda s: s[0], is_berry=lambda s: 'berry' in s))
105+
| beam.Create(
106+
['strawberry', 'raspberry', 'blueberry', 'blackberry', 'banana'])
107+
| beam.GroupBy(
108+
letter=lambda s: s[0], is_berry=lambda s: 'berry' in s))
106109
# [END groupby_two_exprs]
107110

108111
expected = [
@@ -123,18 +126,44 @@ def test_group_by_attr(self):
123126

124127
expected = [
125128
#[START groupby_attr_result]
126-
('pie',
127-
[
128-
beam.Row(recipe='pie', fruit='strawberry', quantity=3, unit_price=1.50),
129-
beam.Row(recipe='pie', fruit='raspberry', quantity=1, unit_price=3.50),
130-
beam.Row(recipe='pie', fruit='blackberry', quantity=1, unit_price=4.00),
131-
beam.Row(recipe='pie', fruit='blueberry', quantity=1, unit_price=2.00),
132-
]),
133-
('muffin',
134-
[
135-
beam.Row(recipe='muffin', fruit='blueberry', quantity=2, unit_price=2.00),
136-
beam.Row(recipe='muffin', fruit='banana', quantity=3, unit_price=1.00),
137-
]),
129+
(
130+
'pie',
131+
[
132+
beam.Row(
133+
recipe='pie',
134+
fruit='strawberry',
135+
quantity=3,
136+
unit_price=1.50),
137+
beam.Row(
138+
recipe='pie',
139+
fruit='raspberry',
140+
quantity=1,
141+
unit_price=3.50),
142+
beam.Row(
143+
recipe='pie',
144+
fruit='blackberry',
145+
quantity=1,
146+
unit_price=4.00),
147+
beam.Row(
148+
recipe='pie',
149+
fruit='blueberry',
150+
quantity=1,
151+
unit_price=2.00),
152+
]),
153+
(
154+
'muffin',
155+
[
156+
beam.Row(
157+
recipe='muffin',
158+
fruit='blueberry',
159+
quantity=2,
160+
unit_price=2.00),
161+
beam.Row(
162+
recipe='muffin',
163+
fruit='banana',
164+
quantity=3,
165+
unit_price=1.00),
166+
]),
138167
#[END groupby_attr_result]
139168
]
140169
assert_that(grouped | beam.MapTuple(normalize_kv), equal_to(expected))
@@ -149,21 +178,48 @@ def test_group_by_attr_expr(self):
149178

150179
expected = [
151180
#[START groupby_attr_expr_result]
152-
(NamedTuple(recipe='pie', is_berry=True),
153-
[
154-
beam.Row(recipe='pie', fruit='strawberry', quantity=3, unit_price=1.50),
155-
beam.Row(recipe='pie', fruit='raspberry', quantity=1, unit_price=3.50),
156-
beam.Row(recipe='pie', fruit='blackberry', quantity=1, unit_price=4.00),
157-
beam.Row(recipe='pie', fruit='blueberry', quantity=1, unit_price=2.00),
158-
]),
159-
(NamedTuple(recipe='muffin', is_berry=True),
160-
[
161-
beam.Row(recipe='muffin', fruit='blueberry', quantity=2, unit_price=2.00),
162-
]),
163-
(NamedTuple(recipe='muffin', is_berry=False),
164-
[
165-
beam.Row(recipe='muffin', fruit='banana', quantity=3, unit_price=1.00),
166-
]),
181+
(
182+
NamedTuple(recipe='pie', is_berry=True),
183+
[
184+
beam.Row(
185+
recipe='pie',
186+
fruit='strawberry',
187+
quantity=3,
188+
unit_price=1.50),
189+
beam.Row(
190+
recipe='pie',
191+
fruit='raspberry',
192+
quantity=1,
193+
unit_price=3.50),
194+
beam.Row(
195+
recipe='pie',
196+
fruit='blackberry',
197+
quantity=1,
198+
unit_price=4.00),
199+
beam.Row(
200+
recipe='pie',
201+
fruit='blueberry',
202+
quantity=1,
203+
unit_price=2.00),
204+
]),
205+
(
206+
NamedTuple(recipe='muffin', is_berry=True),
207+
[
208+
beam.Row(
209+
recipe='muffin',
210+
fruit='blueberry',
211+
quantity=2,
212+
unit_price=2.00),
213+
]),
214+
(
215+
NamedTuple(recipe='muffin', is_berry=False),
216+
[
217+
beam.Row(
218+
recipe='muffin',
219+
fruit='banana',
220+
quantity=3,
221+
unit_price=1.00),
222+
]),
167223
#[END groupby_attr_expr_result]
168224
]
169225
assert_that(grouped | beam.MapTuple(normalize_kv), equal_to(expected))
@@ -174,8 +230,8 @@ def test_simple_aggregate(self):
174230
grouped = (
175231
p
176232
| beam.Create(GROCERY_LIST)
177-
| beam.GroupBy('fruit')
178-
.aggregate_field('quantity', sum, 'total_quantity'))
233+
| beam.GroupBy('fruit').aggregate_field(
234+
'quantity', sum, 'total_quantity'))
179235
# [END simple_aggregate]
180236

181237
expected = [
@@ -195,9 +251,9 @@ def test_expr_aggregate(self):
195251
grouped = (
196252
p
197253
| beam.Create(GROCERY_LIST)
198-
| beam.GroupBy('recipe')
199-
.aggregate_field('quantity', sum, 'total_quantity')
200-
.aggregate_field(lambda x: x.quantity * x.unit_price, sum, 'price'))
254+
| beam.GroupBy('recipe').aggregate_field(
255+
'quantity', sum, 'total_quantity').aggregate_field(
256+
lambda x: x.quantity * x.unit_price, sum, 'price'))
201257
# [END expr_aggregate]
202258

203259
expected = [
@@ -214,10 +270,10 @@ def test_global_aggregate(self):
214270
grouped = (
215271
p
216272
| beam.Create(GROCERY_LIST)
217-
| beam.GroupBy()
218-
.aggregate_field('unit_price', min, 'min_price')
219-
.aggregate_field('unit_price', MeanCombineFn(), 'mean_price')
220-
.aggregate_field('unit_price', max, 'max_price'))
273+
| beam.GroupBy().aggregate_field(
274+
'unit_price', min, 'min_price').aggregate_field(
275+
'unit_price', MeanCombineFn(), 'mean_price').aggregate_field(
276+
'unit_price', max, 'max_price'))
221277
# [END global_aggregate]
222278

223279
expected = [

sdks/python/apache_beam/ml/inference/base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,8 @@ def __init__(self, namespace: str):
337337
# Metrics
338338
self._inference_counter = beam.metrics.Metrics.counter(
339339
namespace, 'num_inferences')
340+
self.failed_batches_counter = beam.metrics.Metrics.counter(
341+
namespace, 'failed_batches_counter')
340342
self._inference_request_batch_size = beam.metrics.Metrics.distribution(
341343
namespace, 'inference_request_batch_size')
342344
self._inference_request_batch_byte_size = (
@@ -426,8 +428,12 @@ def setup(self):
426428

427429
def process(self, batch, inference_args):
428430
start_time = _to_microseconds(self._clock.time_ns())
429-
result_generator = self._model_handler.run_inference(
430-
batch, self._model, inference_args)
431+
try:
432+
result_generator = self._model_handler.run_inference(
433+
batch, self._model, inference_args)
434+
except BaseException as e:
435+
self._metrics_collector.failed_batches_counter.inc()
436+
raise e
431437
predictions = list(result_generator)
432438

433439
end_time = _to_microseconds(self._clock.time_ns())

sdks/python/apache_beam/ml/inference/base_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,37 @@ def test_unexpected_inference_args_passed(self):
175175
FakeModelHandlerFailsOnInferenceArgs(),
176176
inference_args=inference_args)
177177

178+
def test_increment_failed_batches_counter(self):
179+
with self.assertRaises(ValueError):
180+
with TestPipeline() as pipeline:
181+
examples = [7]
182+
pcoll = pipeline | 'start' >> beam.Create(examples)
183+
_ = pcoll | base.RunInference(FakeModelHandlerExpectedInferenceArgs())
184+
run_result = pipeline.run()
185+
run_result.wait_until_finish()
186+
187+
metric_results = (
188+
run_result.metrics().query(
189+
MetricsFilter().with_name('failed_batches_counter')))
190+
num_failed_batches_counter = metric_results['counters'][0]
191+
self.assertEqual(num_failed_batches_counter.committed, 3)
192+
# !!!: The above will need to be updated if retry behavior changes
193+
194+
def test_failed_batches_counter_no_failures(self):
195+
pipeline = TestPipeline()
196+
examples = [7]
197+
pcoll = pipeline | 'start' >> beam.Create(examples)
198+
inference_args = {'key': True}
199+
_ = pcoll | base.RunInference(
200+
FakeModelHandlerExpectedInferenceArgs(), inference_args=inference_args)
201+
run_result = pipeline.run()
202+
run_result.wait_until_finish()
203+
204+
metric_results = (
205+
run_result.metrics().query(
206+
MetricsFilter().with_name('failed_batches_counter')))
207+
self.assertEqual(len(metric_results['counters']), 0)
208+
178209
def test_counted_metrics(self):
179210
pipeline = TestPipeline()
180211
examples = [1, 5, 3, 10]

0 commit comments

Comments (0)