[data] make random_sample() reproducible #51401

Merged
richardliaw merged 1 commit into ray-project:master from wingkitlee0:klee/random-sample-fixed-seed
Apr 10, 2025

Conversation

@wingkitlee0
Contributor

@wingkitlee0 wingkitlee0 commented Mar 15, 2025

Why are these changes needed?

Problem

The current random_sample() does not produce deterministic results with a fixed seed (#40406).

Previous attempts (changing the global seed or passing the same seed/state to workers) also do not work.

Solution [Updated after PR review]

To use random generators in parallel, we need to be careful about the seed/state passed into map_batches. NumPy describes a few methods for this; one of them is to use a sequence of integer seeds (https://numpy.org/doc/2.2/reference/random/parallel.html#sequence-of-integer-seeds). In Ray Data, we can construct a random_sample() UDF that has access to a "block id" via TaskContext (which is thread-local) and use [block_id, seed] to initialize an RNG. Because a Ray task may be reused for different blocks, the RNG is saved in TaskContext.kwargs.
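The "sequence of integer seeds" approach can be illustrated with plain NumPy (a standalone sketch; the stream ids here stand in for Ray Data block/task indices):

```python
import numpy as np

seed = 1234

# Seeding with a [stream_id, seed] sequence gives each parallel worker
# its own reproducible stream: the same sequence always reproduces the
# same stream, while different stream ids give independent streams.
rng_block0 = np.random.default_rng([0, seed])
rng_block0_again = np.random.default_rng([0, seed])
rng_block1 = np.random.default_rng([1, seed])

a = rng_block0.random(5)
b = rng_block0_again.random(5)
c = rng_block1.random(5)

assert np.array_equal(a, b)      # reproducible per stream id
assert not np.array_equal(a, c)  # independent across stream ids
```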

Proposed fix [Updated after PR review]

We add set_current()/get_current() methods to TaskContext, which allow the UDF to get a local copy. The copy has access to the task_idx and the previously initialized RNG. This removes the need for the extra arguments in the original proposal.
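A minimal sketch of the thread-local pattern described above (names and signatures are illustrative; Ray's actual TaskContext implementation differs):

```python
import threading
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

# Thread-local slot holding the context of the currently running task.
_current = threading.local()


@dataclass
class TaskContext:
    """Per-task context; kwargs holds task-scoped state such as an RNG."""

    task_idx: int
    kwargs: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def set_current(cls, ctx: "TaskContext") -> None:
        _current.ctx = ctx

    @classmethod
    def get_current(cls) -> Optional["TaskContext"]:
        return getattr(_current, "ctx", None)

    @classmethod
    def reset_current(cls) -> None:
        _current.ctx = None


# The map task would set the context before invoking the UDF, so the
# UDF can call TaskContext.get_current() without extra arguments.
TaskContext.set_current(TaskContext(task_idx=3))
ctx = TaskContext.get_current()
assert ctx is not None and ctx.task_idx == 3
TaskContext.reset_current()
assert TaskContext.get_current() is None
```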

After fix

```python
In [9]: ds = ray.data.range(1000)

In [10]: ds.random_sample(0.05, seed=1234).take_batch()
Out[10]:
{'id': array([ 27,  54,  72, 111, 136, 144, 147, 168, 200, 224, 225, 245, 247,
        248, 307, 312, 313, 340, 347, 375])}

In [11]: ds.random_sample(0.05, seed=1234).take_batch()
Out[11]:
{'id': array([ 27,  54,  72, 111, 136, 144, 147, 168, 200, 224, 225, 245, 247,
        248, 307, 312, 313, 340, 347, 375])}
```

Related issue number

This issue has been raised a few times:
Closes #40406 and #48497.

Other implementations did not solve the root cause:
#46088
#49443

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@wingkitlee0 wingkitlee0 changed the title Make random_sample() reproducible [data] make random_sample() reproducible Mar 15, 2025
@wingkitlee0 wingkitlee0 force-pushed the klee/random-sample-fixed-seed branch 8 times, most recently from 9190c43 to 09d7814 Compare March 16, 2025 21:06
@wingkitlee0 wingkitlee0 marked this pull request as ready for review March 16, 2025 23:35
@wingkitlee0 wingkitlee0 requested a review from a team as a code owner March 16, 2025 23:35
@wingkitlee0
Contributor Author

@alexeykudinkin can you help review? thanks

@alexeykudinkin
Contributor

Will be addressed in #46088

@wingkitlee0
Contributor Author

Will be addressed in #46088

@alexeykudinkin I am aware of that PR and do not think the solution is correct. See my example here: #49443 (comment)

@jcotant1 jcotant1 added the data Ray Data-related issues label Mar 24, 2025
@raulchen
Contributor

hey @wingkitlee0 , @alexeykudinkin was referring to a different PR https://github.com/ray-project/ray/pull/46088/files
That PR looks good to me as well. And it's simpler.
If you still have concerns, please comment on that PR.

@raulchen
Contributor

Actually, after reading more on both PRs, I think this PR is better, because it can avoid the small-batch issue mentioned in #46088 cc @alexeykudinkin

Contributor

I'd like to avoid exposing this flag.
Instead, we can add a get_current() method in TaskContext

Contributor

And we don't need to pass batch_idx either; we can just seed the random generator once per task.
This can be done by either 1) using a class-based UDF or 2) saving some state in TaskContext.kwargs. (1) would be better.

Contributor Author

  1. use a class-based UDF

A class-based UDF requires concurrency to be set. Is there any way to get around that?

Contributor

You can use a class-based UDF with compute=TaskPoolComputeStrategy

Contributor Author

Interesting. ray/data/_internal/util.py's get_compute_strategy seems to discourage the use of CallableClass + TaskPoolStrategy (get_compute_strategy is called by _map_batches_without_batch_size_validation). I can find ways to bypass that, but I am curious whether it would work...

Contributor

By default, we choose ActorPoolStrategy if the UDF is a class.
The motivation is to simplify the usage and avoid users having to specify the strategy.
But in theory, CallableClass + TaskPoolStrategy should also be feasible.
You can probably move get_compute_strategy out of _map_batches_without_batch_size_validation.
Not sure if there are other issues. If you find you have to tweak too many things, TaskContext.kwargs is also fine.

Contributor

We can probably just check against hard-coded expected results; otherwise, we'll also need to test that get_expected_mask_indices is deterministic.
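A test in the spirit of this suggestion might look roughly like the following (illustrative only; the function and test names are hypothetical, and the real test would hard-code the expected index list rather than derive it):

```python
import numpy as np


def sample_indices(task_idx, seed, n=100, fraction=0.05):
    """Sampled row indices for a [task_idx, seed]-seeded stream."""
    rng = np.random.default_rng([task_idx, seed])
    return np.flatnonzero(rng.random(n) < fraction).tolist()


def test_random_sample_is_deterministic():
    # The real test would compare against a hard-coded list; here we
    # assert the reproducibility that a hard-coded expectation relies on.
    first = sample_indices(task_idx=0, seed=1234)
    assert sample_indices(task_idx=0, seed=1234) == first


test_random_sample_is_deterministic()
```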

Contributor

nice test

@alexeykudinkin alexeykudinkin added the go add ONLY when ready to merge, run all tests label Mar 26, 2025
@alexeykudinkin
Copy link
Copy Markdown
Contributor

@wingkitlee0 yes, you're right. I overlooked the fact that we'll be generating identical sequences per block which isn't ideal.

Contributor

Let's avoid this param

Contributor Author

Sure.
However, it will now need to call _map_batches_without_batch_size_validation, which does not have default args, so there will be a bunch of hardcoded default values in random_sample. Any advice?

Contributor

Why do we need batch_idx?

Contributor Author

My original thought was to keep this a pure and stateless function.

@richardliaw richardliaw added the @external-author-action-required Alternate tag for PRs where the author doesn't have labeling permission. label Mar 28, 2025
@wingkitlee0 wingkitlee0 force-pushed the klee/random-sample-fixed-seed branch from 09d7814 to a8b8d33 Compare March 29, 2025 22:19
@richardliaw
Contributor

@wingkitlee0 - could you fix tests? and ping when ready for another review?

@wingkitlee0 wingkitlee0 force-pushed the klee/random-sample-fixed-seed branch 3 times, most recently from 3382a4e to 18736ba Compare April 4, 2025 01:35
@wingkitlee0 wingkitlee0 force-pushed the klee/random-sample-fixed-seed branch from 18736ba to 0c1dd08 Compare April 4, 2025 01:41
@hainesmichaelc hainesmichaelc added the community-contribution Contributed by the community label Apr 4, 2025
@wingkitlee0 wingkitlee0 force-pushed the klee/random-sample-fixed-seed branch from 0c1dd08 to e3d6d91 Compare April 4, 2025 03:49
@wingkitlee0
Contributor Author

It's ready for re-review!

  • No change to the public APIs.
  • Used TaskContext.kwargs in the end. I checked TaskPoolStrategy etc.: it can't use class-based UDFs right now because it skips the step that calls the constructor.
  • Updated AbstractUDFMap to use kwargs (the use of positional args + defaults made the previous pipeline failures hard to track down: the object was instantiated successfully with mismatched args, and the error wasn't raised until many steps later).
  • Simplified the unit tests

```python
else:
    rng = np.random.default_rng(
        [ctx.kwargs.get("batch_idx", 0), ctx.task_idx, seed]
    )
```
Contributor

I think we can get rid of the include_task_ctx flag as well.

  1. We can add a thread-local variable in TaskContext to allow accessing the current TaskContext. Basically, call TaskContext.set_current/reset_current in _map_task.
  2. I don't think we need batch_idx here. We just need to create the rng object once per task and store it in TaskContext.kwargs.
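The once-per-task seeding idea can be sketched as follows (a simplified, self-contained stand-in: this TaskContext is a hypothetical mock, not Ray's actual class, and random_sample_batch is an illustrative UDF):

```python
import numpy as np


class TaskContext:
    """Hypothetical stand-in for Ray Data's TaskContext."""

    def __init__(self, task_idx):
        self.task_idx = task_idx
        self.kwargs = {}


def random_sample_batch(batch, ctx, fraction, seed):
    # Create the RNG once per task, seeded by [task_idx, seed], and
    # cache it in ctx.kwargs so later batches in the same task reuse
    # the same stream instead of restarting it.
    if "rng" not in ctx.kwargs:
        ctx.kwargs["rng"] = np.random.default_rng([ctx.task_idx, seed])
    rng = ctx.kwargs["rng"]
    mask = rng.random(len(batch)) < fraction
    return [row for row, keep in zip(batch, mask) if keep]


batch = list(range(100))
sample1 = random_sample_batch(batch, TaskContext(task_idx=0), 0.1, seed=1234)

# A fresh context with the same task_idx and seed reproduces the sample.
sample2 = random_sample_batch(batch, TaskContext(task_idx=0), 0.1, seed=1234)
assert sample1 == sample2
```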

Contributor Author

Great suggestion. It will reduce code changes. I will probably need to update the tests a little bit.

Contributor

@raulchen raulchen left a comment

LGTM with one last comment

```python
    kwargs: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def get_current(cls, create_if_not_exists=True, **kwargs) -> "TaskContext":
```
Contributor

nit: I don't think create_if_not_exists and kwargs are needed for this PR. Let's remove them.

@wingkitlee0 wingkitlee0 force-pushed the klee/random-sample-fixed-seed branch from fe368ee to 35d4280 Compare April 9, 2025 00:07
- using a TaskContext to access task_idx

Signed-off-by: Kit Lee <7000003+wingkitlee0@users.noreply.github.com>
@wingkitlee0 wingkitlee0 force-pushed the klee/random-sample-fixed-seed branch from 35d4280 to 8f3b7b4 Compare April 9, 2025 01:27
@richardliaw richardliaw merged commit fa03256 into ray-project:master Apr 10, 2025
5 checks passed
han-steve pushed a commit to han-steve/ray that referenced this pull request Apr 11, 2025


Development

Successfully merging this pull request may close these issues.

[Ray Data] Dataset.random_sample() does not return deterministic results even when seed is set
