Skip to content

Commit cf66672

Browse files
tetrominocopybara-github
authored andcommitted
Support key callback in Starlark min/max builtins
This is required by the language spec, but was not implemented in Bazel. See https://github.com/bazelbuild/starlark/blob/master/spec.md#max Fixes #15022 Also take the opportunity to adjust sorted's signature for `key` to match. RELNOTES: Starlark `min` and `max` buitins now allow a `key` callback, similarly to `sorted`. PiperOrigin-RevId: 623547043 Change-Id: I71d44aa715793f9f2260f9b20b876694154ff352
1 parent 999762d commit cf66672

3 files changed

Lines changed: 293 additions & 67 deletions

File tree

src/main/java/net/starlark/java/eval/MethodLibrary.java

Lines changed: 144 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,17 @@
1414

1515
package net.starlark.java.eval;
1616

17+
import static com.google.common.collect.Streams.stream;
18+
import static java.util.Comparator.comparing;
19+
1720
import com.google.common.base.Ascii;
21+
import com.google.common.base.Throwables;
1822
import com.google.common.collect.Ordering;
1923
import java.util.Arrays;
2024
import java.util.Comparator;
2125
import java.util.Iterator;
2226
import java.util.NoSuchElementException;
27+
import java.util.Optional;
2328
import net.starlark.java.annot.Param;
2429
import net.starlark.java.annot.ParamType;
2530
import net.starlark.java.annot.StarlarkBuiltin;
@@ -31,46 +36,151 @@ class MethodLibrary {
3136
@StarlarkMethod(
3237
name = "min",
3338
doc =
34-
"Returns the smallest one of all given arguments. "
35-
+ "If only one argument is provided, it must be a non-empty iterable. "
36-
+ "It is an error if elements are not comparable (for example int with string), "
37-
+ "or if no arguments are given. "
38-
+ "<pre class=\"language-python\">min(2, 5, 4) == 2\n"
39-
+ "min([5, 6, 3]) == 3</pre>",
40-
extraPositionals = @Param(name = "args", doc = "The elements to be checked."))
41-
public Object min(Sequence<?> args) throws EvalException {
42-
return findExtreme(args, Starlark.ORDERING.reverse());
39+
"Returns the smallest one of all given arguments. If only one positional argument is"
40+
+ " provided, it must be a non-empty iterable. It is an error if elements are not"
41+
+ " comparable (for example int with string), or if no arguments are given."
42+
+ "<pre class=\"language-python\">\n" //
43+
+ "min(2, 5, 4) == 2\n"
44+
+ "min([5, 6, 3]) == 3\n"
45+
+ "min(\"six\", \"three\", \"four\", key = len) == \"six\" # the shortest\n"
46+
+ "min([2, -2, -1, 1], key = abs) == -1 # the first encountered with minimal key"
47+
+ " value\n"
48+
+ "</pre>",
49+
extraPositionals = @Param(name = "args", doc = "The elements to be checked."),
50+
parameters = {
51+
@Param(
52+
name = "key",
53+
named = true,
54+
positional = false,
55+
allowedTypes = {
56+
@ParamType(type = StarlarkCallable.class),
57+
@ParamType(type = NoneType.class),
58+
},
59+
doc = "An optional function applied to each element before comparison.",
60+
defaultValue = "None")
61+
},
62+
useStarlarkThread = true)
63+
public Object min(Object key, Sequence<?> args, StarlarkThread thread)
64+
throws EvalException, InterruptedException {
65+
return findExtreme(
66+
args,
67+
Starlark.toJavaOptional(key, StarlarkCallable.class),
68+
Starlark.ORDERING.reverse(),
69+
thread);
4370
}
4471

4572
@StarlarkMethod(
4673
name = "max",
4774
doc =
48-
"Returns the largest one of all given arguments. "
49-
+ "If only one argument is provided, it must be a non-empty iterable."
50-
+ "It is an error if elements are not comparable (for example int with string), "
51-
+ "or if no arguments are given. "
52-
+ "<pre class=\"language-python\">max(2, 5, 4) == 5\n"
53-
+ "max([5, 6, 3]) == 6</pre>",
54-
extraPositionals = @Param(name = "args", doc = "The elements to be checked."))
55-
public Object max(Sequence<?> args) throws EvalException {
56-
return findExtreme(args, Starlark.ORDERING);
75+
"Returns the largest one of all given arguments. If only one positional argument is"
76+
+ " provided, it must be a non-empty iterable.It is an error if elements are not"
77+
+ " comparable (for example int with string), or if no arguments are given."
78+
+ "<pre class=\"language-python\">\n" //
79+
+ "max(2, 5, 4) == 5\n"
80+
+ "max([5, 6, 3]) == 6\n"
81+
+ "max(\"two\", \"three\", \"four\", key = len) ==\"three\" # the longest\n"
82+
+ "max([1, -1, -2, 2], key = abs) == -2 # the first encountered with maximal key"
83+
+ " value\n"
84+
+ "</pre>",
85+
extraPositionals = @Param(name = "args", doc = "The elements to be checked."),
86+
parameters = {
87+
@Param(
88+
name = "key",
89+
named = true,
90+
positional = false,
91+
allowedTypes = {
92+
@ParamType(type = StarlarkCallable.class),
93+
@ParamType(type = NoneType.class),
94+
},
95+
doc = "An optional function applied to each element before comparison.",
96+
defaultValue = "None")
97+
},
98+
useStarlarkThread = true)
99+
public Object max(Object key, Sequence<?> args, StarlarkThread thread)
100+
throws EvalException, InterruptedException {
101+
return findExtreme(
102+
args, Starlark.toJavaOptional(key, StarlarkCallable.class), Starlark.ORDERING, thread);
57103
}
58104

59105
/** Returns the maximum element from this list, as determined by maxOrdering. */
60-
private static Object findExtreme(Sequence<?> args, Ordering<Object> maxOrdering)
61-
throws EvalException {
106+
private static Object findExtreme(
107+
Sequence<?> args,
108+
Optional<StarlarkCallable> keyFn,
109+
Ordering<Object> maxOrdering,
110+
StarlarkThread thread)
111+
throws EvalException, InterruptedException {
62112
// Args can either be a list of items to compare, or a singleton list whose element is an
63113
// iterable of items to compare. In either case, there must be at least one item to compare.
64114
Iterable<?> items = (args.size() == 1) ? Starlark.toIterable(args.get(0)) : args;
65115
try {
66-
return maxOrdering.max(items);
116+
if (keyFn.isPresent()) {
117+
try {
118+
return stream(items)
119+
.map(value -> ValueWithComparisonKey.make(value, keyFn.get(), thread))
120+
.max(comparing(ValueWithComparisonKey::getComparisonKey, maxOrdering))
121+
.get()
122+
.getValue();
123+
} catch (ValueWithComparisonKey.KeyCallException ex) {
124+
Throwables.throwIfInstanceOf(ex.getCause(), EvalException.class);
125+
Throwables.throwIfInstanceOf(ex.getCause(), InterruptedException.class);
126+
throw new AssertionError("Got invalid ValueWithComparisonKey.KeyCallException", ex);
127+
}
128+
} else {
129+
return maxOrdering.max(items);
130+
}
67131
} catch (ClassCastException ex) {
68132
throw new EvalException(ex.getMessage()); // e.g. unsupported comparison: int <=> string
69133
} catch (NoSuchElementException ex) {
70134
throw new EvalException("expected at least one item", ex);
71135
}
72136
}
73137

138+
/**
139+
* Original value decorated with its comparison key; storing the comparison key alongside the
140+
* value ensures that we call the comparison key computation function only once per original value
141+
* (which is important in case the function has side effects).
142+
*/
143+
private static final class ValueWithComparisonKey {
144+
private final Object value;
145+
private final Object comparisonKey;
146+
147+
private ValueWithComparisonKey(Object value, Object comparisonKey) {
148+
this.value = value;
149+
this.comparisonKey = comparisonKey;
150+
}
151+
152+
/**
153+
* @throws KeyCallException wrapping the exception thrown by the underlying {@link
154+
* Starlark#fastcall} call if it threw.
155+
*/
156+
static ValueWithComparisonKey make(
157+
Object value, StarlarkCallable keyFn, StarlarkThread thread) {
158+
Object[] positional = {value};
159+
Object[] named = {};
160+
try {
161+
return new ValueWithComparisonKey(
162+
value, Starlark.fastcall(thread, keyFn, positional, named));
163+
} catch (EvalException | InterruptedException ex) {
164+
throw new KeyCallException(ex);
165+
}
166+
}
167+
168+
Object getValue() {
169+
return value;
170+
}
171+
172+
Object getComparisonKey() {
173+
return comparisonKey;
174+
}
175+
176+
/** An unchecked exception wrapping an exception thrown by {@link Starlark#fastcall}. */
177+
private static final class KeyCallException extends RuntimeException {
178+
KeyCallException(Exception cause) {
179+
super(cause);
180+
}
181+
}
182+
}
183+
74184
@StarlarkMethod(
75185
name = "abs",
76186
doc =
@@ -140,16 +250,24 @@ private static boolean hasElementWithBooleanValue(Object seq, boolean value)
140250
+ " using x < y. The elements are sorted into ascending order, unless the reverse"
141251
+ " argument is True, in which case the order is descending.\n"
142252
+ " Sorting is stable: elements that compare equal retain their original relative"
143-
+ " order.\n"
144-
+ "<pre class=\"language-python\">sorted([3, 5, 4]) == [3, 4, 5]</pre>",
253+
+ " order.\n" //
254+
+ "<pre class=\"language-python\">\n" //
255+
+ "sorted([3, 5, 4]) == [3, 4, 5]\n" //
256+
+ "sorted([3, 5, 4], reverse = True) == [5, 4, 3]\n" //
257+
+ "sorted([\"two\", \"three\", \"four\"], key = len) == [\"two\", \"four\","
258+
+ " \"three\"] # sort by length\n" //
259+
+ "</pre>",
145260
parameters = {
146261
@Param(name = "iterable", doc = "The iterable sequence to sort."),
147262
@Param(
148263
name = "key",
149-
doc = "An optional function applied to each element before comparison.",
150264
named = true,
151-
defaultValue = "None",
152-
positional = false),
265+
allowedTypes = {
266+
@ParamType(type = StarlarkCallable.class),
267+
@ParamType(type = NoneType.class),
268+
},
269+
doc = "An optional function applied to each element before comparison.",
270+
defaultValue = "None"),
153271
@Param(
154272
name = "reverse",
155273
doc = "Return results in descending order.",
@@ -177,9 +295,6 @@ public StarlarkList<?> sorted(
177295
// The user provided a key function.
178296
// We must call it exactly once per element, in order,
179297
// so use the decorate/sort/undecorate pattern.
180-
if (!(key instanceof StarlarkCallable)) {
181-
throw Starlark.errorf("for key, got %s, want callable", Starlark.type(key));
182-
}
183298
StarlarkCallable keyfn = (StarlarkCallable) key;
184299

185300
// decorate

src/test/java/net/starlark/java/eval/testdata/min_max.star

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ assert_eq(min([1, 2], [3]), [1, 2])
88
assert_eq(min([1, 5], [1, 6], [2, 4], [0, 6]), [0, 6])
99
assert_eq(min([-1]), -1)
1010
assert_eq(min([5, 2, 3]), 2)
11-
assert_eq(min({1: 2, -1: 3}), -1)
12-
assert_eq(min({2: None}), 2)
11+
assert_eq(min({1: 2, -1: 3}), -1) # a single dict argument is treated as its sequence of keys
12+
assert_eq(min({2: None}), 2) # a single dict argument is treated as its sequence of keys
1313
assert_eq(min(-1, 2), -1)
1414
assert_eq(min(5, 2, 3), 2)
1515
assert_eq(min(1, 1, 1, 1, 1, 1), 1)
@@ -21,15 +21,62 @@ assert_fails(lambda: min([]), "expected at least one item")
2121
assert_fails(lambda: min(1, "2", True), "unsupported comparison: int <=> string")
2222
assert_fails(lambda: min([1, "2", True]), "unsupported comparison: int <=> string")
2323

24+
# min with key
25+
assert_eq(min("aBcDeFXyZ".elems(), key = lambda s: s.upper()), "a")
26+
assert_eq(min("test", "xyz", key = len), "xyz")
27+
assert_eq(min([4, 5], [1], key = lambda x: x), [1])
28+
assert_eq(min([1, 2], [3], key = lambda x: x), [1, 2])
29+
assert_eq(min([1, 5], [1, 6], [2, 4], [0, 6], key = lambda x: x), [0, 6])
30+
assert_eq(min([1, 5], [1, 6], [2, 4], [0, 6], key = lambda x: x[1]), [2, 4])
31+
assert_eq(min([-1], key = lambda x: x), -1)
32+
assert_eq(min([5, 2, 3], key = lambda x: x), 2)
33+
assert_eq(min({1: 2, -1: 3}, key = lambda x: x), -1) # a single dict argument is treated as its sequence of keys
34+
assert_eq(min({2: None}, key = lambda x: x), 2) # a single dict argument is treated as its sequence of keys
35+
assert_eq(min(-1, 2, key = lambda x: x), -1)
36+
assert_eq(min(5, 2, 3, key = lambda x: x), 2)
37+
assert_eq(min(1, 1, 1, 1, 1, 1, key = lambda x: -x), 1)
38+
assert_eq(min([1, 1, 1, 1, 1, 1], key = lambda x: -x), 1)
39+
assert_fails(lambda: min(1, key = lambda x: x), "type 'int' is not iterable")
40+
assert_fails(lambda: min(key = lambda x: x), "expected at least one item")
41+
assert_fails(lambda: min([], key = lambda x: x), "expected at least one item")
42+
assert_fails(lambda: min([1], ["2"], [True], key = lambda x: x[0]), "unsupported comparison: (int <=> string|string <=> int)")
43+
assert_fails(lambda: min([[1], ["2"], [True]], key = lambda x: x[0]), "unsupported comparison: (int <=> string|string <=> int)")
44+
45+
# verify min with key chooses first value with minimal key
46+
assert_eq(min(1, -1, -2, 2, key = abs), 1)
47+
assert_eq(min([1, -1, -2, 2], key = abs), 1)
48+
49+
# min with failing key
50+
assert_fails(lambda: min(0, 1, 2, 3, 4, key = lambda x: "foo".elems()[x]), "index out of range \\(index is 3, but sequence has 3 elements\\)")
51+
assert_fails(lambda: min([0, 1, 2, 3, 4], key = lambda x: "foo".elems()[x]), "index out of range \\(index is 3, but sequence has 3 elements\\)")
52+
53+
# min with non-callable key
54+
assert_fails(lambda: min(1, 2, 3, key = "hello"), "parameter 'key' got value of type 'string', want 'callable or NoneType'")
55+
assert_fails(lambda: min([1, 2, 3], key = "hello"), "parameter 'key' got value of type 'string', want 'callable or NoneType'")
56+
57+
# verify min with key invokes key callback exactly once per item
58+
def make_counting_identity():
59+
call_count = {}
60+
61+
def counting_identity(x):
62+
call_count[x] = call_count.get(x, 0) + 1
63+
return x
64+
65+
return counting_identity, call_count
66+
67+
min_counting_identity, min_call_count = make_counting_identity()
68+
assert_eq(min("min".elems(), key = min_counting_identity), "i")
69+
assert_eq(min_call_count, {"m": 1, "i": 1, "n": 1})
70+
2471
# max
2572
assert_eq(max("abcdefxyz".elems()), "z")
2673
assert_eq(max("test", "xyz"), "xyz")
2774
assert_eq(max("test", "xyz"), "xyz")
2875
assert_eq(max([1, 2], [5]), [5])
2976
assert_eq(max([-1]), -1)
3077
assert_eq(max([5, 2, 3]), 5)
31-
assert_eq(max({1: 2, -1: 3}), 1)
32-
assert_eq(max({2: None}), 2)
78+
assert_eq(max({1: 2, -1: 3}), 1) # a single dict argument is treated as its sequence of keys
79+
assert_eq(max({2: None}), 2) # a single dict argument is treated as its sequence of keys
3380
assert_eq(max(-1, 2), 2)
3481
assert_eq(max(5, 2, 3), 5)
3582
assert_eq(max(1, 1, 1, 1, 1, 1), 1)
@@ -40,3 +87,38 @@ assert_fails(lambda: max(), "expected at least one item")
4087
assert_fails(lambda: max([]), "expected at least one item")
4188
assert_fails(lambda: max(1, "2", True), "unsupported comparison: int <=> string")
4289
assert_fails(lambda: max([1, "2", True]), "unsupported comparison: int <=> string")
90+
91+
# max with key
92+
assert_eq(max("aBcDeFXyZ".elems(), key = lambda s: s.lower()), "Z")
93+
assert_eq(max("test", "xyz", key = len), "test")
94+
assert_eq(max([1, 2], [5], key = lambda x: x), [5])
95+
assert_eq(max([-1], key = lambda x: x), -1)
96+
assert_eq(max([5, 2, 3], key = lambda x: x), 5)
97+
assert_eq(max({1: 2, -1: 3}, key = lambda x: x), 1) # a single dict argument is treated as its sequence of keys
98+
assert_eq(max({2: None}, key = lambda x: x), 2) # a single dict argument is treated as its sequence of keys
99+
assert_eq(max(-1, 2, key = lambda x: x), 2)
100+
assert_eq(max(5, 2, 3, key = lambda x: x), 5)
101+
assert_eq(max(1, 1, 1, 1, 1, 1, key = lambda x: -x), 1)
102+
assert_eq(max([1, 1, 1, 1, 1, 1], key = lambda x: -x), 1)
103+
assert_fails(lambda: max(1, key = lambda x: x), "type 'int' is not iterable")
104+
assert_fails(lambda: max(key = lambda x: x), "expected at least one item")
105+
assert_fails(lambda: max([], key = lambda x: x), "expected at least one item")
106+
assert_fails(lambda: max([1], ["2"], [True], key = lambda x: x[0]), "unsupported comparison: (int <=> string|string <=> int)")
107+
assert_fails(lambda: max([[1], ["2"], [True]], key = lambda x: x[0]), "unsupported comparison: (int <=> string|string <=> int)")
108+
109+
# verify max with key chooses first value with minimal key
110+
assert_eq(max(1, -1, -2, 2, key = abs), -2)
111+
assert_eq(max([1, -1, -2, 2], key = abs), -2)
112+
113+
# max with failing key
114+
assert_fails(lambda: max(0, 1, 2, 3, 4, key = lambda i: "xyz".elems()[i]), "index out of range \\(index is 3, but sequence has 3 elements\\)")
115+
assert_fails(lambda: max([0, 1, 2, 3, 4], key = lambda i: "xyz".elems()[i]), "index out of range \\(index is 3, but sequence has 3 elements\\)")
116+
117+
# max with non-callable key
118+
assert_fails(lambda: max(1, 2, 3, key = "hello"), "parameter 'key' got value of type 'string', want 'callable or NoneType'")
119+
assert_fails(lambda: max([1, 2, 3], key = "hello"), "parameter 'key' got value of type 'string', want 'callable or NoneType'")
120+
121+
# verify max with key invokes key callback exactly once per item
122+
max_counting_identity, max_call_count = make_counting_identity()
123+
assert_eq(max("max".elems(), key = max_counting_identity), "x")
124+
assert_eq(max_call_count, {"m": 1, "a": 1, "x": 1})

0 commit comments

Comments
 (0)