@@ -22,94 +22,133 @@ use tk::tokenizer::{
2222
2323#[ pyclass( dict, module = "tokenizers" ) ]
2424pub struct AddedToken {
25- pub token : tk:: tokenizer:: AddedToken ,
25+ pub content : String ,
26+ pub is_special_token : bool ,
27+ pub single_word : Option < bool > ,
28+ pub lstrip : Option < bool > ,
29+ pub rstrip : Option < bool > ,
30+ pub normalized : Option < bool > ,
2631}
32+ impl AddedToken {
33+ pub fn from < S : Into < String > > ( content : S , is_special_token : Option < bool > ) -> Self {
34+ Self {
35+ content : content. into ( ) ,
36+ is_special_token : is_special_token. unwrap_or ( false ) ,
37+ single_word : None ,
38+ lstrip : None ,
39+ rstrip : None ,
40+ normalized : None ,
41+ }
42+ }
43+
44+ pub fn get_token ( & self ) -> tk:: tokenizer:: AddedToken {
45+ let mut token = tk:: AddedToken :: from ( & self . content , self . is_special_token ) ;
46+
47+ if let Some ( sw) = self . single_word {
48+ token = token. single_word ( sw) ;
49+ }
50+ if let Some ( ls) = self . lstrip {
51+ token = token. lstrip ( ls) ;
52+ }
53+ if let Some ( rs) = self . rstrip {
54+ token = token. rstrip ( rs) ;
55+ }
56+ if let Some ( n) = self . normalized {
57+ token = token. normalized ( n) ;
58+ }
59+
60+ token
61+ }
62+
63+ pub fn as_pydict < ' py > ( & self , py : Python < ' py > ) -> PyResult < & ' py PyDict > {
64+ let dict = PyDict :: new ( py) ;
65+ let token = self . get_token ( ) ;
66+
67+ dict. set_item ( "content" , token. content ) ?;
68+ dict. set_item ( "single_word" , token. single_word ) ?;
69+ dict. set_item ( "lstrip" , token. lstrip ) ?;
70+ dict. set_item ( "rstrip" , token. rstrip ) ?;
71+ dict. set_item ( "normalized" , token. normalized ) ?;
72+
73+ Ok ( dict)
74+ }
75+ }
76+
2777#[ pymethods]
2878impl AddedToken {
2979 #[ new]
3080 #[ args( kwargs = "**" ) ]
31- fn new ( content : & str , is_special_token : bool , kwargs : Option < & PyDict > ) -> PyResult < Self > {
32- let mut token = tk :: tokenizer :: AddedToken :: from ( content, is_special_token ) ;
81+ fn new ( content : Option < & str > , kwargs : Option < & PyDict > ) -> PyResult < Self > {
82+ let mut token = AddedToken :: from ( content. unwrap_or ( "" ) , None ) ;
3383
3484 if let Some ( kwargs) = kwargs {
3585 for ( key, value) in kwargs {
3686 let key: & str = key. extract ( ) ?;
3787 match key {
38- "single_word" => token = token . single_word ( value. extract ( ) ?) ,
39- "lstrip" => token = token . lstrip ( value. extract ( ) ?) ,
40- "rstrip" => token = token . rstrip ( value. extract ( ) ?) ,
41- "normalized" => token = token . normalized ( value. extract ( ) ?) ,
88+ "single_word" => token. single_word = Some ( value. extract ( ) ?) ,
89+ "lstrip" => token. lstrip = Some ( value. extract ( ) ?) ,
90+ "rstrip" => token. rstrip = Some ( value. extract ( ) ?) ,
91+ "normalized" => token. normalized = Some ( value. extract ( ) ?) ,
4292 _ => println ! ( "Ignored unknown kwarg option {}" , key) ,
4393 }
4494 }
4595 }
4696
47- Ok ( AddedToken { token } )
97+ Ok ( token)
4898 }
4999
50- fn __getstate__ ( & self , py : Python ) -> PyResult < PyObject > {
51- let data = serde_json:: to_string ( & self . token ) . map_err ( |e| {
52- exceptions:: Exception :: py_err ( format ! (
53- "Error while attempting to pickle AddedToken: {}" ,
54- e. to_string( )
55- ) )
56- } ) ?;
57- Ok ( PyBytes :: new ( py, data. as_bytes ( ) ) . to_object ( py) )
100+ fn __getstate__ < ' py > ( & self , py : Python < ' py > ) -> PyResult < & ' py PyDict > {
101+ self . as_pydict ( py)
58102 }
59103
60104 fn __setstate__ ( & mut self , py : Python , state : PyObject ) -> PyResult < ( ) > {
61- match state. extract :: < & PyBytes > ( py) {
62- Ok ( s) => {
63- self . token = serde_json:: from_slice ( s. as_bytes ( ) ) . map_err ( |e| {
64- exceptions:: Exception :: py_err ( format ! (
65- "Error while attempting to unpickle AddedToken: {}" ,
66- e. to_string( )
67- ) )
68- } ) ?;
105+ match state. extract :: < & PyDict > ( py) {
106+ Ok ( state) => {
107+ for ( key, value) in state {
108+ let key: & str = key. extract ( ) ?;
109+ match key {
110+ "single_word" => self . single_word = Some ( value. extract ( ) ?) ,
111+ "lstrip" => self . lstrip = Some ( value. extract ( ) ?) ,
112+ "rstrip" => self . rstrip = Some ( value. extract ( ) ?) ,
113+ "normalized" => self . normalized = Some ( value. extract ( ) ?) ,
114+ _ => { }
115+ }
116+ }
69117 Ok ( ( ) )
70118 }
71119 Err ( e) => Err ( e) ,
72120 }
73121 }
74122
75- fn __getnewargs__ < ' p > ( & self , py : Python < ' p > ) -> PyResult < & ' p PyTuple > {
76- // We don't really care about the values of `content` & `is_special_token` here because
77- // they will get overriden by `__setstate__`
78- let content: PyObject = "" . into_py ( py) ;
79- let is_special_token: PyObject = false . into_py ( py) ;
80- let args = PyTuple :: new ( py, vec ! [ content, is_special_token] ) ;
81- Ok ( args)
82- }
83-
84123 #[ getter]
85124 fn get_content ( & self ) -> & str {
86- & self . token . content
125+ & self . content
87126 }
88127
89128 #[ getter]
90129 fn get_rstrip ( & self ) -> bool {
91- self . token . rstrip
130+ self . get_token ( ) . rstrip
92131 }
93132
94133 #[ getter]
95134 fn get_lstrip ( & self ) -> bool {
96- self . token . lstrip
135+ self . get_token ( ) . lstrip
97136 }
98137
99138 #[ getter]
100139 fn get_single_word ( & self ) -> bool {
101- self . token . single_word
140+ self . get_token ( ) . single_word
102141 }
103142
104143 #[ getter]
105144 fn get_normalized ( & self ) -> bool {
106- self . token . normalized
145+ self . get_token ( ) . normalized
107146 }
108147}
109148#[ pyproto]
110149impl PyObjectProtocol for AddedToken {
111150 fn __str__ ( & ' p self ) -> PyResult < & ' p str > {
112- Ok ( & self . token . content )
151+ Ok ( & self . content )
113152 }
114153
115154 fn __repr__ ( & self ) -> PyResult < String > {
@@ -118,13 +157,14 @@ impl PyObjectProtocol for AddedToken {
118157 false => "False" ,
119158 } ;
120159
160+ let token = self . get_token ( ) ;
121161 Ok ( format ! (
122162 "AddedToken(\" {}\" , rstrip={}, lstrip={}, single_word={}, normalized={})" ,
123- self . token . content,
124- bool_to_python( self . token. rstrip) ,
125- bool_to_python( self . token. lstrip) ,
126- bool_to_python( self . token. single_word) ,
127- bool_to_python( self . token. normalized)
163+ self . content,
164+ bool_to_python( token. rstrip) ,
165+ bool_to_python( token. lstrip) ,
166+ bool_to_python( token. single_word) ,
167+ bool_to_python( token. normalized)
128168 ) )
129169 }
130170}
@@ -583,9 +623,10 @@ impl Tokenizer {
583623 . into_iter ( )
584624 . map ( |token| {
585625 if let Ok ( content) = token. extract :: < String > ( ) {
586- Ok ( tk:: tokenizer:: AddedToken :: from ( content, false ) )
587- } else if let Ok ( token) = token. extract :: < PyRef < AddedToken > > ( ) {
588- Ok ( token. token . clone ( ) )
626+ Ok ( AddedToken :: from ( content, Some ( false ) ) . get_token ( ) )
627+ } else if let Ok ( mut token) = token. extract :: < PyRefMut < AddedToken > > ( ) {
628+ token. is_special_token = false ;
629+ Ok ( token. get_token ( ) )
589630 } else {
590631 Err ( exceptions:: Exception :: py_err (
591632 "Input must be a List[Union[str, AddedToken]]" ,
@@ -603,8 +644,9 @@ impl Tokenizer {
603644 . map ( |token| {
604645 if let Ok ( content) = token. extract :: < String > ( ) {
605646 Ok ( tk:: tokenizer:: AddedToken :: from ( content, true ) )
606- } else if let Ok ( token) = token. extract :: < PyRef < AddedToken > > ( ) {
607- Ok ( token. token . clone ( ) )
647+ } else if let Ok ( mut token) = token. extract :: < PyRefMut < AddedToken > > ( ) {
648+ token. is_special_token = true ;
649+ Ok ( token. get_token ( ) )
608650 } else {
609651 Err ( exceptions:: Exception :: py_err (
610652 "Input must be a List[Union[str, AddedToken]]" ,
0 commit comments