Skip to content

Commit cdabef1

Browse files
authored
Changing Decoder trait to be more composable. (#938)
* Changing `Decoder` trait to be more composable. Fix #872 * Fixing Python side. * Fixing test. * Updating cleanup signature, removing turbofish.
1 parent 1f1f86d commit cdabef1

11 files changed

Lines changed: 147 additions & 80 deletions

File tree

bindings/node/lib/bindings/decoders.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ describe("wordPieceDecoder", () => {
1212
it("can decode arrays of strings", () => {
1313
expect(
1414
wordPieceDecoder().decode(["Hel", "##lo", "there", "my", "fr", "##iend"])
15-
).toEqual("Hello there my friend");
15+
).toEqual(["Hel", "lo", " there", " my", " fr", "iend"]);
1616
});
1717
});
1818

@@ -39,6 +39,6 @@ describe("ctcDecoder", () => {
3939
it("encodes correctly", () => {
4040
expect(
4141
ctcDecoder().decode(["<pad>", "h", "h", "e", "e", "l", "l", "<pad>", "l", "l", "o"])
42-
).toEqual("hello");
42+
).toEqual(["h", "e", "l", "l", "o"]);
4343
});
4444
});

bindings/node/native/src/decoders.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ pub struct Decoder {
1414
}
1515

1616
impl tk::Decoder for Decoder {
17-
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
17+
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
1818
self.decoder
1919
.as_ref()
2020
.ok_or("Uninitialized Decoder")?
@@ -41,7 +41,13 @@ declare_types! {
4141
.decode(tokens)
4242
.map_err(|e| Error(format!("{}", e)))?;
4343

44-
Ok(cx.string(output).upcast())
44+
let decoded = JsArray::new(&mut cx, output.len() as u32);
45+
for (i, token) in output.into_iter().enumerate() {
46+
let js_token = cx.string(token);
47+
decoded.set(&mut cx, i as u32, js_token)?;
48+
}
49+
50+
Ok(decoded.upcast())
4551
}
4652
}
4753
}

bindings/python/src/decoders.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ impl PyDecoder {
5151
}
5252

5353
impl Decoder for PyDecoder {
54-
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
54+
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
5555
self.decoder.decode(tokens)
5656
}
5757
}
@@ -98,7 +98,7 @@ impl PyDecoder {
9898
/// Returns:
9999
/// :obj:`List[str]`: The decoded tokens
100100
#[text_signature = "(self, tokens)"]
101-
fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
101+
fn decode(&self, tokens: Vec<String>) -> PyResult<Vec<String>> {
102102
ToPyResult(self.decoder.decode(tokens)).into()
103103
}
104104
}
@@ -337,12 +337,12 @@ impl CustomDecoder {
337337
}
338338

339339
impl Decoder for CustomDecoder {
340-
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
340+
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
341341
Python::with_gil(|py| {
342342
let decoded = self
343343
.inner
344344
.call_method(py, "decode", (tokens,), None)?
345-
.extract::<String>(py)?;
345+
.extract(py)?;
346346
Ok(decoded)
347347
})
348348
}
@@ -396,7 +396,7 @@ where
396396
}
397397

398398
impl Decoder for PyDecoderWrapper {
399-
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
399+
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
400400
match self {
401401
PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode(tokens),
402402
PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode(tokens),

bindings/python/tests/bindings/test_decoders.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_instantiate(self):
1414

1515
def test_decoding(self):
1616
decoder = ByteLevel()
17-
assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == "My name is John"
17+
assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == ["My name is John"]
1818

1919
def test_manual_reload(self):
2020
byte_level = ByteLevel()
@@ -34,11 +34,25 @@ def test_instantiate(self):
3434

3535
def test_decoding(self):
3636
decoder = WordPiece()
37-
assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == "My name is John"
38-
assert decoder.decode(["I", "'m", "Jo", "##hn"]) == "I'm John"
37+
assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == [
38+
"My",
39+
" na",
40+
"me",
41+
" is",
42+
" Jo",
43+
"hn",
44+
]
45+
assert decoder.decode(["I", "'m", "Jo", "##hn"]) == ["I", "'m", " Jo", "hn"]
3946
decoder = WordPiece(prefix="__", cleanup=False)
40-
assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == "My name is John"
41-
assert decoder.decode(["I", "'m", "Jo", "__hn"]) == "I 'm John"
47+
assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == [
48+
"My",
49+
" na",
50+
"me",
51+
" is",
52+
" Jo",
53+
"hn",
54+
]
55+
assert decoder.decode(["I", "'m", "Jo", "__hn"]) == ["I", " 'm", " Jo", "hn"]
4256

4357
def test_can_modify(self):
4458
decoder = WordPiece(prefix="$$", cleanup=False)
@@ -66,9 +80,9 @@ def test_instantiate(self):
6680

6781
def test_decoding(self):
6882
decoder = Metaspace()
69-
assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == "My name is John"
83+
assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == ["My", " name", " is", " John"]
7084
decoder = Metaspace(replacement="-", add_prefix_space=False)
71-
assert decoder.decode(["-My", "-name", "-is", "-John"]) == " My name is John"
85+
assert decoder.decode(["-My", "-name", "-is", "-John"]) == [" My", " name", " is", " John"]
7286

7387
def test_can_modify(self):
7488
decoder = Metaspace(replacement="*", add_prefix_space=False)
@@ -93,12 +107,23 @@ def test_instantiate(self):
93107

94108
def test_decoding(self):
95109
decoder = BPEDecoder()
96-
assert (
97-
decoder.decode(["My</w>", "na", "me</w>", "is</w>", "Jo", "hn</w>"])
98-
== "My name is John"
99-
)
110+
assert decoder.decode(["My</w>", "na", "me</w>", "is</w>", "Jo", "hn</w>"]) == [
111+
"My ",
112+
"na",
113+
"me ",
114+
"is ",
115+
"Jo",
116+
"hn",
117+
]
100118
decoder = BPEDecoder(suffix="_")
101-
assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == "My name is John"
119+
assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == [
120+
"My ",
121+
"na",
122+
"me ",
123+
"is ",
124+
"Jo",
125+
"hn",
126+
]
102127

103128
def test_can_modify(self):
104129
decoder = BPEDecoder(suffix="123")
@@ -120,19 +145,13 @@ def test_instantiate(self):
120145

121146
def test_decoding(self):
122147
decoder = CTC()
123-
assert (
124-
decoder.decode(
125-
["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
126-
)
127-
== "hello"
128-
)
148+
assert decoder.decode(
149+
["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
150+
) == ["h", "e", "l", "l", "o"]
129151
decoder = CTC(pad_token="[PAD]")
130-
assert (
131-
decoder.decode(
132-
["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
133-
)
134-
== "hello"
135-
)
152+
assert decoder.decode(
153+
["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
154+
) == ["h", "e", "l", "l", "o"]
136155

137156
def test_can_modify(self):
138157
decoder = CTC(pad_token="[PAD]")

tokenizers/src/decoders/bpe.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,15 @@ impl Default for BPEDecoder {
2424
}
2525

2626
impl Decoder for BPEDecoder {
27-
fn decode(&self, tokens: Vec<String>) -> Result<String> {
28-
Ok(tokens.join("").replace(&self.suffix, " ").trim().to_owned())
27+
fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
28+
let n = tokens.len() - 1;
29+
Ok(tokens
30+
.into_iter()
31+
.enumerate()
32+
.map(|(i, token)| {
33+
let replacement = if i == n { "" } else { " " };
34+
token.replace(&self.suffix, replacement)
35+
})
36+
.collect())
2937
}
3038
}

tokenizers/src/decoders/ctc.rs

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,23 @@ impl Default for CTC {
4242
}
4343

4444
impl Decoder for CTC {
45-
fn decode(&self, tokens: Vec<String>) -> Result<String> {
46-
let mut output = tokens
45+
fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
46+
Ok(tokens
4747
.into_iter()
4848
.dedup()
49-
.join("")
50-
.replace(&self.pad_token, "");
51-
if self.cleanup {
52-
output = wordpiece::cleanup(output).replace(&self.word_delimiter_token, " ");
53-
}
54-
Ok(output)
49+
.filter_map(|token| {
50+
let mut replaced = token.replace(&self.pad_token, "");
51+
if self.cleanup {
52+
replaced =
53+
wordpiece::cleanup(&replaced).replace(&self.word_delimiter_token, " ");
54+
}
55+
if replaced.is_empty() {
56+
None
57+
} else {
58+
Some(replaced)
59+
}
60+
})
61+
.collect())
5562
}
5663
}
5764

@@ -67,7 +74,7 @@ mod tests {
6774
.collect();
6875
assert_eq!(
6976
ctc_decoder.decode(id_to_string_result).unwrap(),
70-
"hello".to_string()
77+
vec!["h", "e", "l", "l", "o"]
7178
);
7279
}
7380
#[test]
@@ -79,7 +86,7 @@ mod tests {
7986
.collect();
8087
assert_eq!(
8188
ctc_decoder.decode(id_to_string_result).unwrap(),
82-
"hello world".to_string()
89+
vec!["h", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"]
8390
);
8491
}
8592
#[test]
@@ -88,7 +95,11 @@ mod tests {
8895
let id_to_string_result = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> A | | <pad> M <pad> <pad> <pad> <pad> A <pad> <pad> N <pad> <pad> <pad> | | | <pad> <pad> <pad> <pad> S <pad> <pad> <pad> A I <pad> D D | | T T <pad> O <pad> | | T H E E | | | <pad> U U <pad> N N <pad> I <pad> <pad> V <pad> <pad> <pad> E R R <pad> <pad> <pad> S E E | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> S S <pad> <pad> <pad> <pad> I <pad> R R <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> <pad> <pad> | <pad> <pad> <pad> E X <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> S <pad> <pad> T <pad> <pad> | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>".split(' ').map(|s| s.to_string()).collect();
8996
assert_eq!(
9097
ctc_decoder.decode(id_to_string_result).unwrap(),
91-
"A MAN SAID TO THE UNIVERSE SIR I EXIST ".to_string()
98+
vec![
99+
"A", " ", "M", "A", "N", " ", "S", "A", "I", "D", " ", "T", "O", " ", "T", "H",
100+
"E", " ", "U", "N", "I", "V", "E", "R", "S", "E", " ", "S", "I", "R", " ", "I",
101+
" ", "E", "X", "I", "S", "T", " "
102+
]
92103
);
93104
}
94105
#[test]
@@ -97,7 +108,13 @@ mod tests {
97108
let id_to_string_result = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> I <pad> S S | | <pad> <pad> <pad> I N <pad> <pad> S <pad> T T <pad> <pad> A N C C T <pad> | | | | | <pad> <pad> <pad> <pad> P <pad> <pad> <pad> <pad> A <pad> <pad> N N N <pad> <pad> I <pad> C <pad> <pad> | | <pad> W <pad> <pad> A S <pad> | | <pad> <pad> <pad> F <pad> <pad> O L <pad> <pad> L L O O W E E D | | <pad> B <pad> <pad> <pad> Y <pad> | | | A | | <pad> S S S <pad> M M <pad> <pad> <pad> A L L <pad> <pad> <pad> <pad> L <pad> | | | <pad> <pad> <pad> <pad> S H H <pad> <pad> <pad> <pad> A R R <pad> <pad> P <pad> <pad> | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> B <pad> <pad> L L <pad> <pad> <pad> <pad> <pad> O W W <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> <pad> <pad> <pad> <pad> <pad> <pad> I G H H | | <pad> <pad> O N <pad> | | H <pad> I S S | | <pad> <pad> C H H <pad> <pad> <pad> E <pad> S S <pad> T T <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>".split(' ').map(|s| s.to_string()).collect();
98109
assert_eq!(
99110
ctc_decoder.decode(id_to_string_result).unwrap(),
100-
"HIS INSTANCT PANIC WAS FOLLOWED BY A SMALL SHARP BLOW HIGH ON HIS CHEST ".to_string()
111+
vec![
112+
"H", "I", "S", " ", "I", "N", "S", "T", "A", "N", "C", "T", " ", "P", "A", "N",
113+
"I", "C", " ", "W", "A", "S", " ", "F", "O", "L", "L", "O", "W", "E", "D", " ",
114+
"B", "Y", " ", "A", " ", "S", "M", "A", "L", "L", " ", "S", "H", "A", "R", "P",
115+
" ", "B", "L", "O", "W", " ", "H", "I", "G", "H", " ", "O", "N", " ", "H", "I",
116+
"S", " ", "C", "H", "E", "S", "T", " "
117+
]
101118
);
102119
}
103120
}

tokenizers/src/decoders/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ pub enum DecoderWrapper {
2626
}
2727

2828
impl Decoder for DecoderWrapper {
29-
fn decode(&self, tokens: Vec<String>) -> Result<String> {
29+
fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
3030
match self {
3131
Self::BPE(bpe) => bpe.decode(tokens),
3232
Self::ByteLevel(bl) => bl.decode(tokens),

tokenizers/src/decoders/wordpiece.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ impl Default for WordPiece {
2828
}
2929
}
3030
}
31-
pub fn cleanup(dirty_input: String) -> String {
31+
pub fn cleanup(dirty_input: &str) -> String {
3232
dirty_input
3333
.replace(" .", ".")
3434
.replace(" ?", "?")
@@ -44,12 +44,21 @@ pub fn cleanup(dirty_input: String) -> String {
4444
}
4545

4646
impl Decoder for WordPiece {
47-
fn decode(&self, tokens: Vec<String>) -> Result<String> {
48-
let mut output = tokens.join(" ").replace(&format!(" {}", self.prefix), "");
49-
if self.cleanup {
50-
output = cleanup(output);
51-
}
52-
53-
Ok(output)
47+
fn decode(&self, mut tokens: Vec<String>) -> Result<Vec<String>> {
48+
tokens
49+
.iter_mut()
50+
.enumerate()
51+
.map(|(i, token)| {
52+
if token.starts_with(&self.prefix) {
53+
*token = token.replacen(&self.prefix, "", 1);
54+
} else if i != 0 {
55+
*token = format!(" {}", token);
56+
}
57+
if self.cleanup {
58+
*token = cleanup(token);
59+
}
60+
Ok(token.to_string())
61+
})
62+
.collect::<Result<_>>()
5463
}
5564
}

tokenizers/src/pre_tokenizers/byte_level.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,11 @@ impl PreTokenizer for ByteLevel {
124124

125125
/// As a `Decoder`, `ByteLevel` is in charge of converting any byte-level characters to their
126126
/// unicode counterpart, before merging everything back into a single String.
127+
/// This decoder will consume the tokens and merge them in one step to alleviate
128+
/// the fact that a single decoded token might be a byte not representable
129+
/// as a String.
127130
impl Decoder for ByteLevel {
128-
fn decode(&self, tokens: Vec<String>) -> Result<String> {
131+
fn decode(&self, tokens: Vec<String>) -> Result<Vec<String>> {
129132
let toks = tokens
130133
.into_iter()
131134
.flat_map(|t| {
@@ -138,8 +141,8 @@ impl Decoder for ByteLevel {
138141
})
139142
.unwrap_or_else(|| t.as_bytes().to_vec())
140143
})
141-
.collect::<Vec<_>>();
142-
Ok(String::from_utf8_lossy(&toks).into_owned())
144+
.collect::<Vec<u8>>();
145+
Ok(vec![String::from_utf8_lossy(&toks).to_string()])
143146
}
144147
}
145148

@@ -248,7 +251,6 @@ mod tests {
248251
fn decoding() {
249252
let bytelevel = ByteLevel::default().add_prefix_space(false);
250253
assert_eq!(
251-
"Hello my friend, how is your day going?",
252254
bytelevel
253255
.decode(
254256
vec![
@@ -259,7 +261,8 @@ mod tests {
259261
.map(|s| s.into())
260262
.collect::<Vec<String>>()
261263
)
262-
.unwrap()
264+
.unwrap(),
265+
vec!["Hello my friend, how is your day going?"]
263266
);
264267
}
265268

@@ -311,7 +314,7 @@ mod tests {
311314
.iter()
312315
.flat_map(|(s, _, _)| s.split("").map(|t| t.into()))
313316
.collect::<Vec<_>>();
314-
assert_eq!(sample, bytelevel.decode(separated_tokens).unwrap());
317+
assert_eq!(sample, bytelevel.decode(separated_tokens).unwrap().join(""));
315318
}
316319
}
317320

@@ -507,7 +510,7 @@ mod tests {
507510
"[PA D]".into()
508511
])
509512
.unwrap(),
510-
"Hello there dear friend! [PA D]"
513+
vec!["Hello there dear friend! [PA D]"]
511514
);
512515
}
513516
}

0 commit comments

Comments
 (0)