Skip to content

torch.topk Memory Leakage #3959

@todpole3

Description

@todpole3

#I implemented a beam search and has run into a memory leakage problem.

def top_k_action(log_action_dist, action_space):
        """
        Get top k action for batches
            - k = beam_size if the beam size is smaller than or equal to the beam action space size
            - k = beam_action_space_size otherwise
        :param log_action_dist: [batch_size*beam_size, action_space_size]
        :param action_space (r_space, e_space):
            r_space: [batch_size*beam_size, action_space_size]
            e_space: [batch_size*beam_size, action_space_size]
        :return:
            (next_r, next_e), action_prob, action_offset: [batch_size*k]
        """
        action_space_size = action_space[0].size()[1]
        # => [batch_size, k'*action_space_size]]
        log_action_dist = log_action_dist.view(batch_size, -1)
        beam_action_space_size = log_action_dist.size()[1]
        assert(beam_action_space_size % action_space_size == 0)
        last_k = beam_action_space_size / action_space_size
        k = beam_size if beam_size < beam_action_space_size else beam_action_space_size
        # [batch_size, k]
        action_prob, action_ind = torch.topk(log_action_dist, k)
        # DEBUGGING -- numpy topk operator
        # log_action_dist_cpu = log_action_dist.data.cpu().numpy()
        # action_prob = src.ops.var_cuda(torch.FloatTensor(np.sort(log_action_dist_cpu, axis=1)[:, -k:]))
        # action_ind = src.ops.int_var_cuda(torch.LongTensor(np.argsort(log_action_dist_cpu, axis=1)[:, -k:]))
        # compute parent offset
        # [batch_size, 1]
        action_batch_offset = (Variable(torch.arange(batch_size), volatile=True) * last_k)\
            .unsqueeze(1).type(torch.LongTensor).cuda()
        # [batch_size, k]
        action_beam_offset = action_ind / action_space_size
        # [batch_size, k] => [batch_size*k]
        action_offset = (action_batch_offset + action_beam_offset).view(-1)
        # compute action indices
        # [batch_size*k', action_space_size] => [batch_size*k, action_space_size]
        r_space = action_space[0][action_offset]
        e_space = action_space[1][action_offset]
        # [batch_size*k, 1]
        action_space_ind = (action_ind % action_space_size).view(-1, 1).detach()
        next_r = src.ops.batch_lookup(r_space, action_space_ind)
        next_e = src.ops.batch_lookup(e_space, action_space_ind)
        # [batch_size, k] => [batch_size*k] 
        action_prob = action_prob.view(-1)
        return (next_r, next_e), action_prob, action_offset

When I run beam search, the memory continuously increases as it searches over more examples and eventually it hits the following out of memory error during the foward pass:

THCudaCheck FAIL file=/tmp/pip-a7h6fthz-build/aten/src/THC/generic/THCStorage.cu line=58 error=2 : out of memory

Traceback (most recent call last):
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/Projects/dtt/src/experiments.py", line 159, in <module>
    main()
  File "/home/Projects/dtt/src/experiments.py", line 152, in main
    inference(pn, kg, rlf)
  File "/home/Projects/dtt/src/experiments.py", line 48, in inference
    rlf.inference(pn, kg, test_data, with_label=False, verbose=True)
  File "/home/Projects/dtt/src/rl.py", line 157, in inference
    pn, e1, r, e2, kg, num_steps=self.num_rollout_steps, beam_size=beam_size)
  File "/home/Projects/dtt/src/search.py", line 80, in beam_search
    action_dist, _ = pn.policy_nn(e, q, e_s, None, kg, use_dynamic_batching=False)
  File "/home/Projects/dtt/src/policy_network.py", line 151, in policy_nn
    A = self.get_action_embedding(KG, (r_space, e_space))
  File "/home/Projects/dtt/src/policy_network.py", line 277, in get_action_embedding
    action_embedding = torch.cat([relation_embedding, entity_embedding], dim=-1)
RuntimeError: cuda runtime error (2) : out of memory at /tmp/pip-a7h6fthz-build/aten/src/THC/generic/THCStorage.cu:58

However, the program runs error free if I commented out the line using the topk operator and replace the results of that line using random tensors.

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions