Skip to content

Commit e1a7a29

Browse files
committed
Add strict parameter to map() builtin
1 parent 3f20619 commit e1a7a29

File tree

2 files changed

+82
-20
lines changed

2 files changed

+82
-20
lines changed

Lib/test/test_builtin.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,7 +1365,6 @@ def test_map_pickle(self):
13651365

13661366
# strict map tests based on strict zip tests
13671367

1368-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument strict
13691368
def test_map_pickle_strict(self):
13701369
a = (1, 2, 3)
13711370
b = (4, 5, 6)
@@ -1374,7 +1373,6 @@ def test_map_pickle_strict(self):
13741373
m1 = map(pack, a, b, strict=True)
13751374
self.check_iter_pickle(m1, t, proto)
13761375

1377-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument strict
13781376
def test_map_pickle_strict_fail(self):
13791377
a = (1, 2, 3)
13801378
b = (4, 5, 6, 7)
@@ -1385,7 +1383,6 @@ def test_map_pickle_strict_fail(self):
13851383
self.assertEqual(self.iter_error(m1, ValueError), t)
13861384
self.assertEqual(self.iter_error(m2, ValueError), t)
13871385

1388-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument strict
13891386
def test_map_strict(self):
13901387
self.assertEqual(tuple(map(pack, (1, 2, 3), 'abc', strict=True)),
13911388
((1, 'a'), (2, 'b'), (3, 'c')))
@@ -1412,7 +1409,6 @@ def test_map_strict(self):
14121409
self.assertRaises(ValueError, tuple,
14131410
map(pack, 'a', t2, t3, strict=True))
14141411

1415-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument strict
14161412
def test_map_strict_iterators(self):
14171413
x = iter(range(5))
14181414
y = [0]
@@ -1422,7 +1418,6 @@ def test_map_strict_iterators(self):
14221418
self.assertEqual(next(x), 2)
14231419
self.assertEqual(next(z), 1)
14241420

1425-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument strict
14261421
def test_map_strict_error_handling(self):
14271422

14281423
class Error(Exception):
@@ -1456,7 +1451,6 @@ def __next__(self):
14561451
l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
14571452
self.assertEqual(l8, [(2, "A"), (1, "B")])
14581453

1459-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Unexpected keyword argument strict
14601454
def test_map_strict_error_handling_stopiteration(self):
14611455

14621456
class Iter:

crates/vm/src/builtins/map.rs

Lines changed: 82 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
1-
use super::{PyType, PyTypeRef};
1+
use super::PyType;
22
use crate::{
3-
Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
3+
AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine,
44
builtins::PyTupleRef,
55
class::PyClassImpl,
6-
function::PosArgs,
6+
function::{ArgIntoBool, OptionalArg, PosArgs},
77
protocol::{PyIter, PyIterReturn},
8-
raise_if_stop,
98
types::{Constructor, IterNext, Iterable, SelfIter},
109
};
10+
use rustpython_common::atomic::{self, PyAtomic, Radium};
1111

1212
#[pyclass(module = false, name = "map", traverse)]
1313
#[derive(Debug)]
1414
pub struct PyMap {
1515
mapper: PyObjectRef,
1616
iterators: Vec<PyIter>,
17+
#[pytraverse(skip)]
18+
strict: PyAtomic<bool>,
1719
}
1820

1921
impl PyPayload for PyMap {
@@ -23,16 +25,27 @@ impl PyPayload for PyMap {
2325
}
2426
}
2527

28+
#[derive(FromArgs)]
29+
pub struct PyMapNewArgs {
30+
#[pyarg(named, optional)]
31+
strict: OptionalArg<bool>,
32+
}
33+
2634
impl Constructor for PyMap {
27-
type Args = (PyObjectRef, PosArgs<PyIter>);
35+
type Args = (PyObjectRef, PosArgs<PyIter>, PyMapNewArgs);
2836

2937
fn py_new(
3038
_cls: &Py<PyType>,
31-
(mapper, iterators): Self::Args,
39+
(mapper, iterators, args): Self::Args,
3240
_vm: &VirtualMachine,
3341
) -> PyResult<Self> {
3442
let iterators = iterators.into_vec();
35-
Ok(Self { mapper, iterators })
43+
let strict = Radium::new(args.strict.unwrap_or(false));
44+
Ok(Self {
45+
mapper,
46+
iterators,
47+
strict,
48+
})
3649
}
3750
}
3851

@@ -48,21 +61,76 @@ impl PyMap {
4861
}
4962

5063
#[pymethod]
51-
fn __reduce__(&self, vm: &VirtualMachine) -> (PyTypeRef, PyTupleRef) {
52-
let mut vec = vec![self.mapper.clone()];
53-
vec.extend(self.iterators.iter().map(|o| o.clone().into()));
54-
(vm.ctx.types.map_type.to_owned(), vm.new_tuple(vec))
64+
fn __reduce__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
65+
let cls = zelf.class().to_owned();
66+
let mut vec = vec![zelf.mapper.clone()];
67+
vec.extend(zelf.iterators.iter().map(|o| o.clone().into()));
68+
let tuple_args = vm.ctx.new_tuple(vec);
69+
Ok(if zelf.strict.load(atomic::Ordering::Acquire) {
70+
vm.new_tuple((cls, tuple_args, true))
71+
} else {
72+
vm.new_tuple((cls, tuple_args))
73+
})
74+
}
75+
76+
#[pymethod]
77+
fn __setstate__(zelf: PyRef<Self>, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
78+
if let Ok(obj) = ArgIntoBool::try_from_object(vm, state) {
79+
zelf.strict.store(obj.into(), atomic::Ordering::Release);
80+
}
81+
Ok(())
5582
}
5683
}
5784

5885
impl SelfIter for PyMap {}
5986

6087
impl IterNext for PyMap {
6188
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
89+
let strict = zelf.strict.load(atomic::Ordering::Acquire);
6290
let mut next_objs = Vec::new();
63-
for iterator in &zelf.iterators {
64-
let item = raise_if_stop!(iterator.next(vm)?);
65-
next_objs.push(item);
91+
let mut stopped_at: Option<usize> = None;
92+
93+
for (idx, iterator) in zelf.iterators.iter().enumerate() {
94+
match iterator.next(vm)? {
95+
PyIterReturn::Return(obj) => {
96+
if let Some(stopped_idx) = stopped_at {
97+
if strict {
98+
let plural = if stopped_idx == 0 { " " } else { "s 1-" };
99+
return Err(vm.new_value_error(format!(
100+
"map() argument {} is longer than argument{}{}",
101+
idx + 1,
102+
plural,
103+
stopped_idx + 1,
104+
)));
105+
}
106+
return Ok(PyIterReturn::StopIteration(None));
107+
}
108+
next_objs.push(obj);
109+
}
110+
PyIterReturn::StopIteration(v) => {
111+
if stopped_at.is_some() {
112+
continue;
113+
}
114+
if strict && idx > 0 {
115+
let plural = if idx == 1 { " " } else { "s 1-" };
116+
return Err(vm.new_value_error(format!(
117+
"map() argument {} is shorter than argument{}{}",
118+
idx + 1,
119+
plural,
120+
idx,
121+
)));
122+
}
123+
if strict {
124+
stopped_at = Some(idx);
125+
} else {
126+
return Ok(PyIterReturn::StopIteration(v));
127+
}
128+
}
129+
}
130+
}
131+
132+
if stopped_at.is_some() {
133+
return Ok(PyIterReturn::StopIteration(None));
66134
}
67135

68136
// the mapper itself can raise StopIteration which does stop the map iteration

0 commit comments

Comments
 (0)