#I implemented a beam search and has run into a memory leakage problem.
def top_k_action(log_action_dist, action_space):
"""
Get top k action for batches
- k = beam_size if the beam size is smaller than or equal to the beam action space size
- k = beam_action_space_size otherwise
:param log_action_dist: [batch_size*beam_size, action_space_size]
:param action_space (r_space, e_space):
r_space: [batch_size*beam_size, action_space_size]
e_space: [batch_size*beam_size, action_space_size]
:return:
(next_r, next_e), action_prob, action_offset: [batch_size*k]
"""
action_space_size = action_space[0].size()[1]
# => [batch_size, k'*action_space_size]]
log_action_dist = log_action_dist.view(batch_size, -1)
beam_action_space_size = log_action_dist.size()[1]
assert(beam_action_space_size % action_space_size == 0)
last_k = beam_action_space_size / action_space_size
k = beam_size if beam_size < beam_action_space_size else beam_action_space_size
# [batch_size, k]
action_prob, action_ind = torch.topk(log_action_dist, k)
# DEBUGGING -- numpy topk operator
# log_action_dist_cpu = log_action_dist.data.cpu().numpy()
# action_prob = src.ops.var_cuda(torch.FloatTensor(np.sort(log_action_dist_cpu, axis=1)[:, -k:]))
# action_ind = src.ops.int_var_cuda(torch.LongTensor(np.argsort(log_action_dist_cpu, axis=1)[:, -k:]))
# compute parent offset
# [batch_size, 1]
action_batch_offset = (Variable(torch.arange(batch_size), volatile=True) * last_k)\
.unsqueeze(1).type(torch.LongTensor).cuda()
# [batch_size, k]
action_beam_offset = action_ind / action_space_size
# [batch_size, k] => [batch_size*k]
action_offset = (action_batch_offset + action_beam_offset).view(-1)
# compute action indices
# [batch_size*k', action_space_size] => [batch_size*k, action_space_size]
r_space = action_space[0][action_offset]
e_space = action_space[1][action_offset]
# [batch_size*k, 1]
action_space_ind = (action_ind % action_space_size).view(-1, 1).detach()
next_r = src.ops.batch_lookup(r_space, action_space_ind)
next_e = src.ops.batch_lookup(e_space, action_space_ind)
# [batch_size, k] => [batch_size*k]
action_prob = action_prob.view(-1)
return (next_r, next_e), action_prob, action_offset
When I run beam search, the memory continuously increases as it searches over more examples and eventually it hits the following out of memory error during the foward pass:
THCudaCheck FAIL file=/tmp/pip-a7h6fthz-build/aten/src/THC/generic/THCStorage.cu line=58 error=2 : out of memory
Traceback (most recent call last):
File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/Projects/dtt/src/experiments.py", line 159, in <module>
main()
File "/home/Projects/dtt/src/experiments.py", line 152, in main
inference(pn, kg, rlf)
File "/home/Projects/dtt/src/experiments.py", line 48, in inference
rlf.inference(pn, kg, test_data, with_label=False, verbose=True)
File "/home/Projects/dtt/src/rl.py", line 157, in inference
pn, e1, r, e2, kg, num_steps=self.num_rollout_steps, beam_size=beam_size)
File "/home/Projects/dtt/src/search.py", line 80, in beam_search
action_dist, _ = pn.policy_nn(e, q, e_s, None, kg, use_dynamic_batching=False)
File "/home/Projects/dtt/src/policy_network.py", line 151, in policy_nn
A = self.get_action_embedding(KG, (r_space, e_space))
File "/home/Projects/dtt/src/policy_network.py", line 277, in get_action_embedding
action_embedding = torch.cat([relation_embedding, entity_embedding], dim=-1)
RuntimeError: cuda runtime error (2) : out of memory at /tmp/pip-a7h6fthz-build/aten/src/THC/generic/THCStorage.cu:58
However, the program runs error free if I commented out the line using the topk operator and replace the results of that line using random tensors.
#I implemented a beam search and has run into a memory leakage problem.
When I run beam search, the memory continuously increases as it searches over more examples and eventually it hits the following out of memory error during the foward pass:
However, the program runs error free if I commented out the line using the
topkoperator and replace the results of that line using random tensors.