Concatenate directly into shared memory when constructing batches #1323

Merged: soumith merged 1 commit into pytorch:master from colesbury:cat on Apr 22, 2017

Conversation

@colesbury
Member

This saves an extra memory copy, which speeds up data loading a bit
(5-10% with accimage).

As part of this change:

  • torch.cat accepts a keyword argument out
  • specifying out=None is treated like not specifying out
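
As an illustration, the new keyword might be used like this (a minimal sketch with arbitrary shapes, assuming out behaves as described above):

import torch

a = torch.randn(2, 4)
b = torch.randn(2, 4)

# Concatenate along dim 1 directly into a preallocated tensor,
# skipping the extra copy mentioned in the description.
out = torch.Tensor(2, 8)
torch.cat([a, b], 1, out=out)

# Passing out=None is the same as omitting out entirely.
result = torch.cat([a, b], 1, out=None)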

Contributor

@apaszke left a comment


LGTM 👍

out = None
if _use_shared_memory:
    # If we're in a background process, concatenate directly into a
    # shared memory tensor to avoid an extra copy
    numel = sum([x.numel() for x in batch])
    storage = batch[0].storage()._new_shared(numel)
    out = batch[0].new(storage)
return torch.stack(batch, 0, out=out)
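
The allocation via _new_shared is the key detail: it places the destination storage in shared memory, so when this branch runs in a DataLoader worker process the stacked batch can be handed to the main process without another copy.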


@soumith merged commit 24d92b5 into pytorch:master on Apr 22, 2017
@chenyuntc
Contributor

This seems to cause an error:

import torch as t
a=t.Tensor(2,4)
b=t.Tensor(2,4)
c=t.Tensor(2,8)
t.cat(a,b,out=c)
TypeError                                 Traceback (most recent call last)
<ipython-input-1-ce3d61609aec> in <module>()
      3 b=t.Tensor(2,4)
      4 c=t.Tensor(2,8)
----> 5 t.cat(a,b,out=c)

TypeError: cat received an invalid combination of arguments - got (torch.FloatTensor, torch.FloatTensor, out=torch.FloatTensor), but expected one of:
 * (sequence[torch.FloatTensor] seq)
 * (sequence[torch.FloatTensor] seq, int dim)

Also see https://discuss.pytorch.org/t/cat-got-an-unexpected-keyword-argument-out/2151
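
For what it's worth, the expected signatures in the traceback take a sequence of tensors, so a call of the following form matches the advertised interface (an illustrative sketch, not a confirmed fix for the reported bug):

import torch as t

a = t.Tensor(2, 4)
b = t.Tensor(2, 4)
c = t.Tensor(2, 8)

# cat expects a sequence; concatenating two (2, 4) tensors
# along dim 1 yields the (2, 8) shape of c.
t.cat([a, b], 1, out=c)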

@colesbury
Member Author

colesbury commented Apr 22, 2017 via email

bunelr added a commit to bunelr/pytorch that referenced this pull request May 13, 2017
soumith pushed a commit that referenced this pull request May 14, 2017
Jiaming-Liu pushed a commit to Jiaming-Liu/pytorch that referenced this pull request May 18, 2017
@boeddeker
Contributor

@colesbury Is there a reason why numpy arrays in default_collate do not use shared memory?
I would suggest changing

return torch.stack([torch.from_numpy(b) for b in batch], 0)

to

return default_collate([torch.from_numpy(b) for b in batch])
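
In context, the suggested change recurses so that numpy arrays reach the tensor branch of default_collate, which (since this PR) stacks into shared memory inside worker processes. A minimal sketch of the relevant branches; the real collate logic (shared-memory allocation, the other element types) is elided:

import numpy as np
import torch

def default_collate(batch):
    # Tensor branch: in a worker process this is where the
    # shared-memory stacking from this PR happens (simplified here).
    if torch.is_tensor(batch[0]):
        return torch.stack(batch, 0)
    # Proposed change: recurse so numpy arrays take the tensor
    # branch above instead of being stacked directly.
    if isinstance(batch[0], np.ndarray):
        return default_collate([torch.from_numpy(b) for b in batch])
    raise TypeError("unsupported batch element: %s" % type(batch[0]))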

@colesbury
Member Author

@boeddeker that seems fine. Can you send a PR with the change?

@boeddeker
Contributor

ok, I opened a PR

facebook-github-bot pushed a commit that referenced this pull request Dec 30, 2018
… numpy (#14534)

Summary:
Since #1323, tensors are shared via shared memory, but this feature is not active for numpy arrays.
This PR fixes that.
Pull Request resolved: #14534

Differential Revision: D13561649

Pulled By: soumith

fbshipit-source-id: b6bc9e99fb91e8b675c2ef131fba9fa11c1647c0
eqy pushed a commit to eqy/pytorch that referenced this pull request Jan 20, 2022
scotts added a commit to scotts/pytorch that referenced this pull request Mar 31, 2026
8b42d4c..185fe9c includes the following commits:

185fe9c Expose occupany limiting factors (pytorch#1330)
0c8ede0 remove the rocprofiler early exit hack (pytorch#1329)
4826a43 Remove duplicate test ignore (pytorch#1328)
37fada9 Ensure that async doesn't loop while sync is active (pytorch#1327)
628e1d0 Add host_name to OSS Kineto trace metadata via gethostname() (pytorch#1323)
9d7373b Revert D97166802 (pytorch#1326)
3a61657 Fix Lingering INT32 Overflow (pytorch#1324)
50a0085 Re-enabled some hardcoded tests (pytorch#1321)
e19dd92 Expose occupany limiting factors (pytorch#1322)

Authored with Claude.
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026