1- use super :: { PyType , PyTypeRef } ;
1+ use super :: PyType ;
22use crate :: {
3- Context , Py , PyObjectRef , PyPayload , PyResult , VirtualMachine ,
3+ AsObject , Context , Py , PyObjectRef , PyPayload , PyRef , PyResult , TryFromObject , VirtualMachine ,
44 builtins:: PyTupleRef ,
55 class:: PyClassImpl ,
6- function:: PosArgs ,
6+ function:: { ArgIntoBool , OptionalArg , PosArgs } ,
77 protocol:: { PyIter , PyIterReturn } ,
8- raise_if_stop,
98 types:: { Constructor , IterNext , Iterable , SelfIter } ,
109} ;
10+ use rustpython_common:: atomic:: { self , PyAtomic , Radium } ;
1111
1212#[ pyclass( module = false , name = "map" , traverse) ]
1313#[ derive( Debug ) ]
1414pub struct PyMap {
1515 mapper : PyObjectRef ,
1616 iterators : Vec < PyIter > ,
17+ #[ pytraverse( skip) ]
18+ strict : PyAtomic < bool > ,
1719}
1820
1921impl PyPayload for PyMap {
@@ -23,16 +25,27 @@ impl PyPayload for PyMap {
2325 }
2426}
2527
28+ #[ derive( FromArgs ) ]
29+ pub struct PyMapNewArgs {
30+ #[ pyarg( named, optional) ]
31+ strict : OptionalArg < bool > ,
32+ }
33+
2634impl Constructor for PyMap {
27- type Args = ( PyObjectRef , PosArgs < PyIter > ) ;
35+ type Args = ( PyObjectRef , PosArgs < PyIter > , PyMapNewArgs ) ;
2836
2937 fn py_new (
3038 _cls : & Py < PyType > ,
31- ( mapper, iterators) : Self :: Args ,
39+ ( mapper, iterators, args ) : Self :: Args ,
3240 _vm : & VirtualMachine ,
3341 ) -> PyResult < Self > {
3442 let iterators = iterators. into_vec ( ) ;
35- Ok ( Self { mapper, iterators } )
43+ let strict = Radium :: new ( args. strict . unwrap_or ( false ) ) ;
44+ Ok ( Self {
45+ mapper,
46+ iterators,
47+ strict,
48+ } )
3649 }
3750}
3851
@@ -48,21 +61,76 @@ impl PyMap {
4861 }
4962
5063 #[ pymethod]
51- fn __reduce__ ( & self , vm : & VirtualMachine ) -> ( PyTypeRef , PyTupleRef ) {
52- let mut vec = vec ! [ self . mapper. clone( ) ] ;
53- vec. extend ( self . iterators . iter ( ) . map ( |o| o. clone ( ) . into ( ) ) ) ;
54- ( vm. ctx . types . map_type . to_owned ( ) , vm. new_tuple ( vec) )
64+ fn __reduce__ ( zelf : PyRef < Self > , vm : & VirtualMachine ) -> PyResult < PyTupleRef > {
65+ let cls = zelf. class ( ) . to_owned ( ) ;
66+ let mut vec = vec ! [ zelf. mapper. clone( ) ] ;
67+ vec. extend ( zelf. iterators . iter ( ) . map ( |o| o. clone ( ) . into ( ) ) ) ;
68+ let tuple_args = vm. ctx . new_tuple ( vec) ;
69+ Ok ( if zelf. strict . load ( atomic:: Ordering :: Acquire ) {
70+ vm. new_tuple ( ( cls, tuple_args, true ) )
71+ } else {
72+ vm. new_tuple ( ( cls, tuple_args) )
73+ } )
74+ }
75+
76+ #[ pymethod]
77+ fn __setstate__ ( zelf : PyRef < Self > , state : PyObjectRef , vm : & VirtualMachine ) -> PyResult < ( ) > {
78+ if let Ok ( obj) = ArgIntoBool :: try_from_object ( vm, state) {
79+ zelf. strict . store ( obj. into ( ) , atomic:: Ordering :: Release ) ;
80+ }
81+ Ok ( ( ) )
5582 }
5683}
5784
5885impl SelfIter for PyMap { }
5986
6087impl IterNext for PyMap {
6188 fn next ( zelf : & Py < Self > , vm : & VirtualMachine ) -> PyResult < PyIterReturn > {
89+ let strict = zelf. strict . load ( atomic:: Ordering :: Acquire ) ;
6290 let mut next_objs = Vec :: new ( ) ;
63- for iterator in & zelf. iterators {
64- let item = raise_if_stop ! ( iterator. next( vm) ?) ;
65- next_objs. push ( item) ;
91+ let mut stopped_at: Option < usize > = None ;
92+
93+ for ( idx, iterator) in zelf. iterators . iter ( ) . enumerate ( ) {
94+ match iterator. next ( vm) ? {
95+ PyIterReturn :: Return ( obj) => {
96+ if let Some ( stopped_idx) = stopped_at {
97+ if strict {
98+ let plural = if stopped_idx == 0 { " " } else { "s 1-" } ;
99+ return Err ( vm. new_value_error ( format ! (
100+ "map() argument {} is longer than argument{}{}" ,
101+ idx + 1 ,
102+ plural,
103+ stopped_idx + 1 ,
104+ ) ) ) ;
105+ }
106+ return Ok ( PyIterReturn :: StopIteration ( None ) ) ;
107+ }
108+ next_objs. push ( obj) ;
109+ }
110+ PyIterReturn :: StopIteration ( v) => {
111+ if stopped_at. is_some ( ) {
112+ continue ;
113+ }
114+ if strict && idx > 0 {
115+ let plural = if idx == 1 { " " } else { "s 1-" } ;
116+ return Err ( vm. new_value_error ( format ! (
117+ "map() argument {} is shorter than argument{}{}" ,
118+ idx + 1 ,
119+ plural,
120+ idx,
121+ ) ) ) ;
122+ }
123+ if strict {
124+ stopped_at = Some ( idx) ;
125+ } else {
126+ return Ok ( PyIterReturn :: StopIteration ( v) ) ;
127+ }
128+ }
129+ }
130+ }
131+
132+ if stopped_at. is_some ( ) {
133+ return Ok ( PyIterReturn :: StopIteration ( None ) ) ;
66134 }
67135
68136 // the mapper itself can raise StopIteration which does stop the map iteration
0 commit comments