generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathrnn.py
More file actions
57 lines (50 loc) · 2 KB
/
rnn.py
File metadata and controls
57 lines (50 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# AUTOGENERATED! DO NOT EDIT! File to edit: 09_rnn.ipynb (unless otherwise specified).
__all__ = ['generate_data', 'encode', 'decode']
# Cell
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import LSTM, TimeDistributed, Dense, RepeatVector
from tensorflow.keras.models import Sequential
from tqdm import tqdm
#export
def generate_data(training_size=10, digits=None, maxlen=None, progress=True):
    """Generate unique random addition problems and their answers.

    Each question is the string ``'a+b'`` left-padded with spaces to
    ``maxlen`` characters. ``a`` and ``b`` are non-negative integers written
    with between 1 and ``digits`` decimal digits. Each answer is
    ``str(a + b)`` left-padded with spaces to ``digits + 1`` characters.

    Two problems whose operands form the same unordered pair (e.g. ``1+2``
    and ``2+1``) count as duplicates; only the first one drawn is kept.

    Args:
        training_size: Number of (question, answer) pairs to generate.
        digits: Maximum number of digits per operand. Defaults to the
            module-level ``DIGITS``.
        maxlen: Width the question strings are padded to. Defaults to the
            module-level ``MAXLEN``.
        progress: If true, show a ``tqdm`` progress bar while generating.

    Returns:
        A tuple ``(X, y)`` of equal-length lists of question strings and
        answer strings.

    Raises:
        ValueError: If ``training_size`` exceeds the number of distinct
            unordered operand pairs. Such a request could never be satisfied,
            so the loop would otherwise run forever.
    """
    if digits is None:
        digits = DIGITS
    if maxlen is None:
        maxlen = MAXLEN

    # Operands take every value in 0 .. 10**digits - 1 (leading zeros are
    # allowed in the drawn strings). That gives n*(n+1)/2 unordered pairs,
    # repetition included.
    n_values = 10 ** digits
    capacity = n_values * (n_values + 1) // 2
    if training_size > capacity:
        raise ValueError(
            f'training_size={training_size} exceeds the {capacity} unique '
            f'operand pairs available with digits={digits}')

    digit_chars = list('0123456789')  # hoisted: built once, not once per digit

    def random_operand():
        # Draw the digit count first and then each digit. This matches the
        # original order of random draws, so seeded runs produce the same data.
        n_digits = np.random.randint(1, digits + 1)
        return int(''.join(np.random.choice(digit_chars) for _ in range(n_digits)))

    X, y = [], []
    seen = set()
    p_bar = tqdm(total=training_size) if progress else None
    try:
        while len(X) < training_size:
            a = random_operand()
            b = random_operand()
            pair = tuple(sorted((a, b)))  # order-insensitive duplicate key
            if pair in seen:
                continue
            seen.add(pair)
            question = '{}+{}'.format(a, b)
            question = ' ' * (maxlen - len(question)) + question
            answer = str(a + b)
            answer = ' ' * ((digits + 1) - len(answer)) + answer
            X.append(question)
            y.append(answer)
            if p_bar is not None:
                p_bar.update(1)
    finally:
        # Always release the progress bar, even if generation is interrupted.
        if p_bar is not None:
            p_bar.close()
    return X, y
#export
def encode(questions, answers, alphabet, maxlen=None, digits=None):
    """One-hot encode question and answer strings over ``alphabet``.

    Character ``c`` at position ``i`` of string ``r`` sets
    ``out[r, i, alphabet.index(c)] = 1``. Positions past the end of a string
    stay all-zero.

    Args:
        questions: Sequence of question strings, each at most ``maxlen`` long.
        answers: Sequence of answer strings, each at most ``digits + 1`` long.
        alphabet: Ordered collection of characters; its order defines the
            one-hot index of each character.
        maxlen: Time-step width of the question tensor. Defaults to the
            module-level ``MAXLEN``.
        digits: Operand digit count; answers use ``digits + 1`` time steps.
            Defaults to the module-level ``DIGITS``.

    Returns:
        A tuple ``(x, y)`` of float arrays with shapes
        ``(len(questions), maxlen, len(alphabet))`` and
        ``(len(answers), digits + 1, len(alphabet))``.

    Raises:
        KeyError: If a string contains a character not in ``alphabet``.
        IndexError: If a string is longer than its tensor's time-step width.
    """
    if maxlen is None:
        maxlen = MAXLEN
    if digits is None:
        digits = DIGITS
    char_to_index = {c: i for i, c in enumerate(alphabet)}
    x = np.zeros((len(questions), maxlen, len(alphabet)))
    # Size y by the answers themselves. The original sized it by the questions,
    # which breaks (or silently zero-pads) when the two lengths differ.
    y = np.zeros((len(answers), digits + 1, len(alphabet)))
    for row, question in enumerate(questions):
        for pos, char in enumerate(question):
            x[row, pos, char_to_index[char]] = 1
    for row, answer in enumerate(answers):
        for pos, char in enumerate(answer):
            y[row, pos, char_to_index[char]] = 1
    return x, y
#export
def decode(seq, alphabet, calc_argmax=True):
    """Turn a sequence of alphabet indices (or one-hot rows) back into text.

    Args:
        seq: Either integer indices into ``alphabet``, or, when
            ``calc_argmax`` is true, an array whose last axis holds per-character
            scores that are reduced with ``argmax``.
        alphabet: Ordered collection of characters used for encoding.
        calc_argmax: Whether to take ``argmax`` over the last axis first.

    Returns:
        The decoded string.
    """
    lookup = {}
    for position, char in enumerate(alphabet):
        lookup[position] = char
    indices = np.argmax(seq, axis=-1) if calc_argmax else seq
    return ''.join([lookup[idx] for idx in indices])