Skip to content

Commit 61b3b4f

Browse files
authored
Add strict parameter to map() builtin (#7405)
* Add strict parameter to map() builtin * Refactor map IterNext to match zip style
1 parent f26752c commit 61b3b4f

File tree

2 files changed

+67
-19
lines changed

2 files changed

+67
-19
lines changed

Lib/test/test_builtin.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,7 +1365,6 @@ def test_map_pickle(self):
13651365

13661366
# strict map tests based on strict zip tests
13671367

1368-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument strict
13691368
def test_map_pickle_strict(self):
13701369
a = (1, 2, 3)
13711370
b = (4, 5, 6)
@@ -1374,7 +1373,6 @@ def test_map_pickle_strict(self):
13741373
m1 = map(pack, a, b, strict=True)
13751374
self.check_iter_pickle(m1, t, proto)
13761375

1377-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument strict
13781376
def test_map_pickle_strict_fail(self):
13791377
a = (1, 2, 3)
13801378
b = (4, 5, 6, 7)
@@ -1385,7 +1383,6 @@ def test_map_pickle_strict_fail(self):
13851383
self.assertEqual(self.iter_error(m1, ValueError), t)
13861384
self.assertEqual(self.iter_error(m2, ValueError), t)
13871385

1388-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument strict
13891386
def test_map_strict(self):
13901387
self.assertEqual(tuple(map(pack, (1, 2, 3), 'abc', strict=True)),
13911388
((1, 'a'), (2, 'b'), (3, 'c')))
@@ -1412,7 +1409,6 @@ def test_map_strict(self):
14121409
self.assertRaises(ValueError, tuple,
14131410
map(pack, 'a', t2, t3, strict=True))
14141411

1415-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument strict
14161412
def test_map_strict_iterators(self):
14171413
x = iter(range(5))
14181414
y = [0]
@@ -1422,7 +1418,6 @@ def test_map_strict_iterators(self):
14221418
self.assertEqual(next(x), 2)
14231419
self.assertEqual(next(z), 1)
14241420

1425-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument strict
14261421
def test_map_strict_error_handling(self):
14271422

14281423
class Error(Exception):
@@ -1456,7 +1451,6 @@ def __next__(self):
14561451
l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
14571452
self.assertEqual(l8, [(2, "A"), (1, "B")])
14581453

1459-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument strict
14601454
def test_map_strict_error_handling_stopiteration(self):
14611455

14621456
class Iter:

crates/vm/src/builtins/map.rs

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
1-
use super::{PyType, PyTypeRef};
1+
use super::PyType;
22
use crate::{
3-
Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
3+
AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine,
44
builtins::PyTupleRef,
55
class::PyClassImpl,
6-
function::PosArgs,
6+
function::{ArgIntoBool, OptionalArg, PosArgs},
77
protocol::{PyIter, PyIterReturn},
8-
raise_if_stop,
98
types::{Constructor, IterNext, Iterable, SelfIter},
109
};
10+
use rustpython_common::atomic::{self, PyAtomic, Radium};
1111

1212
#[pyclass(module = false, name = "map", traverse)]
1313
#[derive(Debug)]
1414
pub struct PyMap {
1515
mapper: PyObjectRef,
1616
iterators: Vec<PyIter>,
17+
#[pytraverse(skip)]
18+
strict: PyAtomic<bool>,
1719
}
1820

1921
impl PyPayload for PyMap {
@@ -23,16 +25,27 @@ impl PyPayload for PyMap {
2325
}
2426
}
2527

28+
#[derive(FromArgs)]
29+
pub struct PyMapNewArgs {
30+
#[pyarg(named, optional)]
31+
strict: OptionalArg<bool>,
32+
}
33+
2634
impl Constructor for PyMap {
27-
type Args = (PyObjectRef, PosArgs<PyIter>);
35+
type Args = (PyObjectRef, PosArgs<PyIter>, PyMapNewArgs);
2836

2937
fn py_new(
3038
_cls: &Py<PyType>,
31-
(mapper, iterators): Self::Args,
39+
(mapper, iterators, args): Self::Args,
3240
_vm: &VirtualMachine,
3341
) -> PyResult<Self> {
3442
let iterators = iterators.into_vec();
35-
Ok(Self { mapper, iterators })
43+
let strict = Radium::new(args.strict.unwrap_or(false));
44+
Ok(Self {
45+
mapper,
46+
iterators,
47+
strict,
48+
})
3649
}
3750
}
3851

@@ -48,10 +61,24 @@ impl PyMap {
4861
}
4962

5063
#[pymethod]
51-
fn __reduce__(&self, vm: &VirtualMachine) -> (PyTypeRef, PyTupleRef) {
52-
let mut vec = vec![self.mapper.clone()];
53-
vec.extend(self.iterators.iter().map(|o| o.clone().into()));
54-
(vm.ctx.types.map_type.to_owned(), vm.new_tuple(vec))
64+
fn __reduce__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
65+
let cls = zelf.class().to_owned();
66+
let mut vec = vec![zelf.mapper.clone()];
67+
vec.extend(zelf.iterators.iter().map(|o| o.clone().into()));
68+
let tuple_args = vm.ctx.new_tuple(vec);
69+
Ok(if zelf.strict.load(atomic::Ordering::Acquire) {
70+
vm.new_tuple((cls, tuple_args, true))
71+
} else {
72+
vm.new_tuple((cls, tuple_args))
73+
})
74+
}
75+
76+
#[pymethod]
77+
fn __setstate__(zelf: PyRef<Self>, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
78+
if let Ok(obj) = ArgIntoBool::try_from_object(vm, state) {
79+
zelf.strict.store(obj.into(), atomic::Ordering::Release);
80+
}
81+
Ok(())
5582
}
5683
}
5784

@@ -60,8 +87,35 @@ impl SelfIter for PyMap {}
6087
impl IterNext for PyMap {
6188
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
6289
let mut next_objs = Vec::new();
63-
for iterator in &zelf.iterators {
64-
let item = raise_if_stop!(iterator.next(vm)?);
90+
for (idx, iterator) in zelf.iterators.iter().enumerate() {
91+
let item = match iterator.next(vm)? {
92+
PyIterReturn::Return(obj) => obj,
93+
PyIterReturn::StopIteration(v) => {
94+
if zelf.strict.load(atomic::Ordering::Acquire) {
95+
if idx > 0 {
96+
let plural = if idx == 1 { " " } else { "s 1-" };
97+
return Err(vm.new_value_error(format!(
98+
"map() argument {} is shorter than argument{}{}",
99+
idx + 1,
100+
plural,
101+
idx,
102+
)));
103+
}
104+
for (idx, iterator) in zelf.iterators[1..].iter().enumerate() {
105+
if let PyIterReturn::Return(_) = iterator.next(vm)? {
106+
let plural = if idx == 0 { " " } else { "s 1-" };
107+
return Err(vm.new_value_error(format!(
108+
"map() argument {} is longer than argument{}{}",
109+
idx + 2,
110+
plural,
111+
idx + 1,
112+
)));
113+
}
114+
}
115+
}
116+
return Ok(PyIterReturn::StopIteration(v));
117+
}
118+
};
65119
next_objs.push(item);
66120
}
67121

0 commit comments

Comments
 (0)