[data] make random_sample() reproducible #51401

richardliaw merged 1 commit into ray-project:master
Conversation
Force-pushed 9190c43 to 09d7814
@alexeykudinkin can you help review? Thanks!
Will be addressed in #46088

@alexeykudinkin I am aware of that PR and do not think the solution is correct. See my example here: #49443 (comment)
Hey @wingkitlee0, @alexeykudinkin was referring to a different PR: https://github.com/ray-project/ray/pull/46088/files
Actually, after reading more of both PRs, I think this PR is better because it avoids the small-batch issue mentioned in #46088. cc @alexeykudinkin
python/ray/data/dataset.py
Outdated
I'd like to avoid exposing this flag.
Instead, we can add a get_current() method in TaskContext
and we don't need passing the batch_idx either.
We can just seed the random generator once per task.
this can be done by either 1) use a class-based UDF or 2) save some state in TaskContext.kwargs. (1) would be better.
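Option (1), seeding once per task via a class-based UDF, can be sketched roughly as follows. This is a hypothetical stand-alone UDF, not Ray's actual code; the `task_idx` and `fraction` wiring is assumed for illustration:

```python
import numpy as np

class SeededSampler:
    """Hypothetical class-based UDF: the RNG is created once per worker in
    __init__, so every batch the worker processes continues one stream
    instead of restarting it per batch."""

    def __init__(self, seed: int, task_idx: int, fraction: float = 0.05):
        # Seeding with [task_idx, seed] gives each task an independent,
        # reproducible stream (NumPy's sequence-of-seeds scheme).
        self.rng = np.random.default_rng([task_idx, seed])
        self.fraction = fraction

    def __call__(self, batch: dict) -> dict:
        # Sample the same rows across all columns of the batch.
        n = len(next(iter(batch.values())))
        mask = self.rng.random(n) < self.fraction
        return {k: np.asarray(v)[mask] for k, v in batch.items()}
```

With the same `(seed, task_idx)` pair, two instances produce identical samples; different `task_idx` values produce independent streams.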
Re "use a class-based UDF": a class-based UDF requires `concurrency` to be set. Any way to get around that?
You can use a class-based UDF with compute=TaskPoolComputeStrategy
Interesting. `get_compute_strategy` in `ray/data/_internal/util.py` seems to discourage the combination of `CallableClass` + `TaskPoolStrategy` (`get_compute_strategy` is called by `_map_batches_without_batch_size_validation`). I can find ways to bypass that, but I am curious whether it would work...
By default, we choose ActorPoolStrategy if the UDF is a class.
The motivation is to simplify the usage and avoid users having to specify the strategy.
But in theory, CallableClass + TaskPoolStrategy should also be feasible.
You can probably move get_compute_strategy out of _map_batches_without_batch_size_validation.
Not sure if there are other issues. If you find you have to tweak too many things, TaskContext.kwargs is also fine.
python/ray/data/tests/test_map.py
Outdated
We can probably just check against hard-coded expected results.
Otherwise, we'll also need to test that get_expected_mask_indices is deterministic.
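The point can be illustrated with plain NumPy (no Ray); `sampled_indices` is a hypothetical stand-in for the sampling logic, not the actual test helper. With a fixed seed the sampled indices are fully determined, so a test can compare against a literal captured once, rather than re-deriving the expectation through a helper whose own determinism would then need testing:

```python
import numpy as np

def sampled_indices(seed: int, task_idx: int, n: int, fraction: float):
    # Same seeding scheme as in the PR: a sequence [task_idx, seed]
    # fully determines the stream, hence the sampled indices.
    rng = np.random.default_rng([task_idx, seed])
    return np.flatnonzero(rng.random(n) < fraction)

# Two runs with identical arguments must produce identical indices.
run1 = sampled_indices(seed=1234, task_idx=0, n=100, fraction=0.1)
run2 = sampled_indices(seed=1234, task_idx=0, n=100, fraction=0.1)
```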
python/ray/data/tests/test_map.py
Outdated
@wingkitlee0 yes, you're right. I overlooked the fact that we'll be generating identical sequences per block, which isn't ideal.
Let's avoid this param
Sure.
However, `random_sample` will then need to call `_map_batches_without_batch_size_validation`, which does not have default args, so there will be a bunch of hardcoded default values in `random_sample`. Any advice?
python/ray/data/dataset.py
Outdated
Why do we need batch_idx?
My original thought was to keep this a pure and stateless function.
Force-pushed 09d7814 to a8b8d33
@wingkitlee0 - could you fix the tests, and ping when ready for another review?
Force-pushed 3382a4e to 18736ba
Force-pushed 18736ba to 0c1dd08
Force-pushed 0c1dd08 to e3d6d91
It's ready for re-review!
python/ray/data/dataset.py
Outdated
```python
else:
    rng = np.random.default_rng(
        [ctx.kwargs.get("batch_idx", 0), ctx.task_idx, seed]
    )
```
I think we can get rid of the include_task_ctx flag as well.

- We can add a thread-local variable in `TaskContext` to allow accessing the current `TaskContext`. Basically, call `TaskContext.set_current`/`reset_current` in `_map_task`.
- I don't think we need batch_idx here. We just need to create the rng object once per task and store it in `TaskContext.kwargs`.
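A minimal sketch of the thread-local pattern suggested here. The names mirror the discussion (`TaskContext`, `set_current`, `get_current`), but this is a simplified illustration, not Ray's actual implementation:

```python
import threading
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class TaskContext:
    task_idx: int
    kwargs: Dict[str, Any] = field(default_factory=dict)

    # Class-level thread-local storage; each worker thread sees its own
    # "current" context, so concurrently running tasks don't interfere.
    _local = threading.local()

    @classmethod
    def set_current(cls, ctx: "TaskContext") -> None:
        # Called at the start of a map task (e.g. in _map_task).
        cls._local.ctx = ctx

    @classmethod
    def reset_current(cls) -> None:
        # Called when the task finishes, so stale state isn't reused.
        cls._local.ctx = None

    @classmethod
    def get_current(cls) -> Optional["TaskContext"]:
        # UDFs call this to reach the running task's context.
        return getattr(cls._local, "ctx", None)
```

The map task brackets the UDF invocation with `set_current`/`reset_current`, and the UDF reaches `task_idx` and any cached state through `get_current()`.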
Great suggestion. It will reduce code changes. I will probably need to update the tests a little bit.
raulchen left a comment:

LGTM with one last comment
```python
kwargs: Dict[str, Any] = field(default_factory=dict)

@classmethod
def get_current(cls, create_if_not_exists=True, **kwargs) -> "TaskContext":
```
Nit: I don't think `create_if_not_exists` and `**kwargs` are needed for this PR.
Let's remove them.
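With those parameters removed, the "create the RNG once per task" caching might look like the following. `get_or_create_rng` is a hypothetical helper (not from the PR); `ctx` stands for the current TaskContext and `seed` comes from `random_sample`:

```python
import numpy as np

def get_or_create_rng(ctx, seed):
    # The generator is created on first use and cached in
    # TaskContext.kwargs, so later batches in the same task continue
    # the same stream instead of restarting it.
    if "rng" not in ctx.kwargs:
        ctx.kwargs["rng"] = np.random.default_rng([ctx.task_idx, seed])
    return ctx.kwargs["rng"]
```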
Force-pushed fe368ee to 35d4280
- using a TaskContext to access task_idx

Signed-off-by: Kit Lee <7000003+wingkitlee0@users.noreply.github.com>
Force-pushed 35d4280 to 8f3b7b4
## Why are these changes needed?

**Problem**

The current `random_sample()` does not work with a fixed seed (ray-project#40406). Previous attempts (changing the global seed or passing the same seed/state to all workers) also do not work.

**Solution [updated after PR review]**

To use random generators in parallel, we need to be careful about the seed/state passed into `map_batches`. NumPy describes a few methods; one of them is to use a sequence of integer seeds (https://numpy.org/doc/2.2/reference/random/parallel.html#sequence-of-integer-seeds). In Ray Data, we can construct a `random_sample()` UDF that has access to a "block id" via the (thread-local) `TaskContext`, and use `[block_id, seed]` to initialize an RNG. Because a Ray task may be reused for different blocks, the RNG is saved into `TaskContext.kwargs`.

**Proposed fix [updated after PR review]**

We add `set_current()`/`get_current()` methods to `TaskContext`, which allow the UDF to get a thread-local copy with access to `task_idx` and the previously initialized RNG. This removes the need for the extra arguments in the original proposal.

**After fix**

```python
In [9]: ds = ray.data.range(1000)

In [10]: ds.random_sample(0.05, seed=1234).take_batch()
Out[10]: {'id': array([ 27, 54, 72, 111, 136, 144, 147, 168, 200, 224, 225, 245, 247, 248, 307, 312, 313, 340, 347, 375])}

In [11]: ds.random_sample(0.05, seed=1234).take_batch()
Out[11]: {'id': array([ 27, 54, 72, 111, 136, 144, 147, 168, 200, 224, 225, 245, 247, 248, 307, 312, 313, 340, 347, 375])}
```

## Related issue number

This issue has been raised a few times:

Closes ray-project#40406, ray-project#48497

Other implementations did not solve the root cause: ray-project#46088, ray-project#49443

## Checks

- [x] I've signed off every commit (by using the `-s` flag, i.e., `git commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
  - [x] Unit tests
  - [ ] Release tests
  - [ ] This PR is not tested :(

Signed-off-by: Kit Lee <7000003+wingkitlee0@users.noreply.github.com>
Signed-off-by: Steve Han <stevehan2001@gmail.com>
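The sequence-of-integer-seeds scheme the description refers to can be demonstrated with plain NumPy (the block ids and seed below are illustrative):

```python
import numpy as np

# Each stream is seeded with [block_id, user_seed]; re-creating a stream
# with the same pair reproduces it exactly, while different block ids give
# statistically independent streams.
seed = 1234
draws = [np.random.default_rng([block_id, seed]).random(5) for block_id in range(4)]
redraws = [np.random.default_rng([block_id, seed]).random(5) for block_id in range(4)]
```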