[DataPipe] Add RandomSplitter (without buffer)#724
[DataPipe] Add RandomSplitter (without buffer)#724NivekT wants to merge 18 commits intogh/NivekT/86/basefrom
Conversation
[ghstack-poisoned]
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful). TODO: * Decide if we like this or the buffer version better. Or we can add both. * Determines if the API related to randomness needs further extension (we might need to add set_seed) * More tests. See #712 for related discussion. See #723 for the version with buffer. [ghstack-poisoned]
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful). TODO: * Decide if we like this or the buffer version better. Or we can add both. * Determines if the API related to randomness needs further extension (we might need to add set_seed) * More tests. Implementation note: * I decided against reusing `_ChildDataPipe` since its features are overly complicated for this use case. See #712 for related discussion. See #723 for the version with buffer. [ghstack-poisoned]
|
Offline: Discussion:
|
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful). TODO: * Decide if we like this or the buffer version better. Or we can add both. * Determines if the API related to randomness needs further extension (we might need to add set_seed) * More tests. Implementation note: * I decided against reusing `_ChildDataPipe` since its features are overly complicated for this use case. See #712 for related discussion. See #723 for the version with buffer. [ghstack-poisoned]
NivekT
left a comment
There was a problem hiding this comment.
Updated this PR based on our discussion.
Note: I decided against having an option to change seed automatically after each iteration, because there are situations where the first iteration is for test and the second iteration is for valid. Changing seed will be confusing and causes inconsistency. Users should use set_seed instead to update the seed when necessary.
| "dataframe": "torcharrow.DataFrame", | ||
| "end_caching": "IterDataPipe", | ||
| "unzip": "List[IterDataPipe]", | ||
| "random_split": "Union[IterDataPipe, List[IterDataPipe]]", |
There was a problem hiding this comment.
Is this return type fine? It will either be a single IterDataPipe or a List[...] depending if target is specified.
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful). Implementation note: * I decided against reusing `_ChildDataPipe` since its features are overly complicated for this use case. * I also decided against having an option to change seed automatically after each iteration, because there are situations where the first iteration is for `test` and the second iteration is for `valid`. Changing seed will be confusing and causes inconsistency. See #712 for related discussion. See #723 for the version with buffer. [ghstack-poisoned]
|
@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful). Implementation note: * I decided against reusing `_ChildDataPipe` since its features are overly complicated for this use case. * I also decided against having an option to change seed automatically after each iteration, because there are situations where the first iteration is for `test` and the second iteration is for `valid`. Changing seed will be confusing and causes inconsistency. See #712 for related discussion. See #723 for the version with buffer. Differential Revision: [D38675266](https://our.internmc.facebook.com/intern/diff/D38675266) [ghstack-poisoned]
|
@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful). Implementation note: * I decided against reusing `_ChildDataPipe` since its features are overly complicated for this use case. * I also decided against having an option to change seed automatically after each iteration, because there are situations where the first iteration is for `test` and the second iteration is for `valid`. Changing seed will be confusing and causes inconsistency. See #712 for related discussion. See #723 for the version with buffer. Differential Revision: [D38675266](https://our.internmc.facebook.com/intern/diff/D38675266) [ghstack-poisoned]
|
@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
|
Can we derive |
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful). Implementation note: * I decided against reusing `_ChildDataPipe` since its features are overly complicated for this use case. * I also decided against having an option to change seed automatically after each iteration, because there are situations where the first iteration is for `test` and the second iteration is for `valid`. Changing seed will be confusing and causes inconsistency. See #712 for related discussion. See #723 for the version with buffer. Differential Revision: [D38675266](https://our.internmc.facebook.com/intern/diff/D38675266) [ghstack-poisoned]
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful). Implementation note: * I decided against reusing `_ChildDataPipe` since its features are overly complicated for this use case. * I also decided against having an option to change seed automatically after each iteration, because there are situations where the first iteration is for `test` and the second iteration is for `valid`. Changing seed will be confusing and causes inconsistency. See #712 for related discussion. See #723 for the version with buffer. Differential Revision: [D38675266](https://our.internmc.facebook.com/intern/diff/D38675266) [ghstack-poisoned]
NivekT
left a comment
There was a problem hiding this comment.
I think this should be mostly consistent with Shuffler with the exception that the seed doesn't change per epoch.
ejguan
left a comment
There was a problem hiding this comment.
Overall, LGTM with a few comments
| @staticmethod | ||
| def normalize_weights(weights, total_length: int): | ||
| total_weight = sum(weights.values()) | ||
| return {k: round(float(w) * total_length / total_weight) for k, w in weights.items()} |
There was a problem hiding this comment.
And, using round might lead to the sum of weights not equal to total_length.
There was a problem hiding this comment.
You are right. This also means we have to renormalize the weights after one of the keys run out. I have updated my implementation. Please verify.
There was a problem hiding this comment.
I am unsure. If one key runs out, why do we need to renormalize? If one weight goes to 0, the corresponding key should never be returned by random.choices IMHO.
There was a problem hiding this comment.
Let's say we have total_length=10 and weights=[3.33, 3.33, 3.33].
We draw 4 '0's in a row, the remaining weights will be weights=[-0.67, 3.33, 3.33].
We can set the negative value to 0 and get weights=[0, 3.33, 3.33]. Now the remaining number of draws is 6 but the weights sum up to 6.66.
It should not impact the result/output, but if we choose to normalize, it will bring the sum back to 6 to match the remaining number of draws. I can take it out since it is technically unnecessary computation.
There was a problem hiding this comment.
I think I will take it out, because it will greatly simplify the other functions. For example, normalize_weights will no longer have to consider the list case.
There was a problem hiding this comment.
Then, how about we always convert weights to a list of integer. For example, we do floor over all the elements except the last one. And, the rest of elements send to the last key. Then, we don't need to normalize and the weight is kept as integer then no negative value will happen.
There was a problem hiding this comment.
Two things:
- I was wrong about not needing to normalize after a weight hitting negative. Actually it is necessary.
Example: Splitting 10 elements to 9 DPs: [1.11, 1.11, 1.11, 1.11, 1.11, 1.11, 1.11, 1.11, 1.11]
Without normalization, you may end up with 2 or more DPs with more than one element, while the ideal case is 8
DPs with 1 element and 1 DP with 2. - I considered your suggestion to
floorall but the last one. It changes the weights from its true distributino.
Example: 10 elements split into two DPs with weights [0.89, 0.11], most of the time you want 9 element in the first
and 1 element second DP.
Because of these complications, I am going to add a note, telling users that if they want certainty, please provide integer weights that sum up to total_length.
There was a problem hiding this comment.
IIUC, the current implementation is similar to doing ceiling to all keys until the last requested. That's why we have to do re-normalization when each time one key is depleted. It's fine but I feel like we can do it in advance without doing normalization multiple times.
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful). Implementation note: * I decided against reusing `_ChildDataPipe` since its features are overly complicated for this use case. * I also decided against having an option to change seed automatically after each iteration, because there are situations where the first iteration is for `test` and the second iteration is for `valid`. Changing seed will be confusing and causes inconsistency. See #712 for related discussion. See #723 for the version with buffer. Differential Revision: [D38675266](https://our.internmc.facebook.com/intern/diff/D38675266) [ghstack-poisoned]
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful). Implementation note: * I decided against reusing `_ChildDataPipe` since its features are overly complicated for this use case. * I also decided against having an option to change seed automatically after each iteration, because there are situations where the first iteration is for `test` and the second iteration is for `valid`. Changing seed will be confusing and causes inconsistency. See #712 for related discussion. See #723 for the version with buffer. Differential Revision: [D38675266](https://our.internmc.facebook.com/intern/diff/D38675266) [ghstack-poisoned]
|
@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful). Implementation note: * I decided against reusing `_ChildDataPipe` since its features are overly complicated for this use case. * I also decided against having an option to change seed automatically after each iteration, because there are situations where the first iteration is for `test` and the second iteration is for `valid`. Changing seed will be confusing and causes inconsistency. See #712 for related discussion. See #723 for the version with buffer. Differential Revision: [D38675266](https://our.internmc.facebook.com/intern/diff/D38675266) [ghstack-poisoned]
|
Thanks for the helpful comments. It is simpler than before now! |
|
@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful). Implementation note: * I decided against reusing `_ChildDataPipe` since its features are overly complicated for this use case. * I also decided against having an option to change seed automatically after each iteration, because there are situations where the first iteration is for `test` and the second iteration is for `valid`. Changing seed will be confusing and causes inconsistency. See #712 for related discussion. See #723 for the version with buffer. Differential Revision: [D38675266](https://our.internmc.facebook.com/intern/diff/D38675266) [ghstack-poisoned]
|
|
||
| def get_length(self, target: T) -> int: | ||
| if not self._lengths: | ||
| if all(w.is_integer() for w in self.norm_weights) and sum(self.norm_weights) == self.total_length: |
There was a problem hiding this comment.
We can potentially assert this during __init__, such that if the normalized weights aren't integer or cannot sum up to total_length, we will raise exception.
|
@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful). Implementation note: * I decided against reusing `_ChildDataPipe` since its features are overly complicated for this use case. * I also decided against having an option to change seed automatically after each iteration, because there are situations where the first iteration is for `test` and the second iteration is for `valid`. Changing seed will be confusing and causes inconsistency. See #712 for related discussion. See #723 for the version with buffer. Differential Revision: [D38675266](https://our.internmc.facebook.com/intern/diff/D38675266) [ghstack-poisoned]
|
@NivekT has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
Stack from ghstack:
This PR adds RandomSplitter without a buffer. The upside is that this uses less memory (good for memory-bound cases) but the downside are 1) only one group can be iterated through at a time and 2) it skips over all the groups that do not match the target (which is potentially wasteful).
Implementation note:
_ChildDataPipesince its features are overly complicated for this use case.testand the second iteration is forvalid. Changing seed will be confusing and causes inconsistency.See #712 for related discussion.
See #723 for the version with buffer.
Differential Revision: D38675266