[Data] Make the seed take effect in Dataset.random_sample()#46088

Closed
liuxsh9 wants to merge 2 commits into ray-project:master from liuxsh9:random_sample

@liuxsh9
Contributor

@liuxsh9 liuxsh9 commented Jun 17, 2024

Why are these changes needed?

A quick fix for #40406. random_sample runs inside map_batches, so a seed set externally did not take effect in that context. After the fix:

ds = ray.data.range(100000)
ds.random_sample(0.2, seed=1234).count()
Out[3]: 20775
ds.random_sample(0.2, seed=1234).count()
Out[4]: 20775
ds.random_sample(0.2, seed=1234).count()
Out[5]: 20775

However, after the fix, we observed that in some cases, when the seed is set, the returned row count deviates significantly from the expected fraction * total_rows. For example:

ds = ray.data.range(1000)
ds.random_sample(0.5, seed=1111).count()
Out[7]: 1000
ds.random_sample(0.5, seed=1234).count()
Out[8]: 500
ds.random_sample(0.5, seed=5).count()
Out[9]: 0

The issue seems to be that map_batches groups the rows into very small batches (e.g., 2 rows per batch) and performs random sampling within each batch. With a fixed seed, every batch draws the same random numbers, so the same positions are selected in each batch, which yields overall sampling ratios of exactly 100%, 50%, or 0%.
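
The effect can be reproduced without Ray. The following is a minimal sketch, not the code in this PR: `sample_with_reseeded_batches` is a hypothetical helper that assumes every batch re-creates its RNG from the same seed, and it uses the standard-library `random` module in place of whatever generator the real UDF uses.

```python
import random


def sample_with_reseeded_batches(rows, fraction, seed, batch_size):
    """Hypothetical stand-in for per-batch sampling with a re-applied seed."""
    kept = []
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        # Assumption: the same seed is applied at the start of every batch,
        # so every batch sees the identical sequence of random draws.
        rng = random.Random(seed)
        kept.extend(row for row in batch if rng.random() < fraction)
    return kept


rows = list(range(1000))
ratios = {
    seed: len(sample_with_reseeded_batches(rows, 0.5, seed, batch_size=2)) / len(rows)
    for seed in (1111, 1234, 5)
}
# With 2 rows per batch, each seed keeps the same 0, 1, or 2 positions in
# every batch, so each ratio can only be 0.0, 0.5, or 1.0.
print(ratios)
```

Which of the three ratios a given seed produces depends on the generator, so the values printed here need not match the Ray output above.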

We would like the community's input: once a seed is set, should we drop the map_batches operation to get more stable and reproducible results? Or should we set the default batch_size explicitly, so that the effective batch size is predictable and consistent?

Related issue number

#40406

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

liuxsh9 added 2 commits June 17, 2024 19:28
Signed-off-by: Xiaoshuang Liu <liuxiaoshuang4@huawei.com>
Signed-off-by: Xiaoshuang Liu <liuxiaoshuang4@huawei.com>
@liuxsh9
Contributor Author

liuxsh9 commented Jun 17, 2024

Perhaps you are also interested in this one @Bye-legumes @nemo9cby, please feel free to provide some suggestions!

@anyscalesam anyscalesam added triage Needs triage (eg: priority, bug/not-bug, and owning component) data Ray Data-related issues labels Aug 6, 2024
@alexeykudinkin
Contributor

@liuxsh9 can you please add a test for it?

@alexeykudinkin alexeykudinkin added the go add ONLY when ready to merge, run all tests label Mar 24, 2025
@raulchen
Contributor

hey @liuxsh9 , I think #51401 can avoid the issue you mentioned.

@pcmoritz pcmoritz requested a review from a team as a code owner March 26, 2025 22:34
@richardliaw
Contributor

Hi, we'll be taking #51401 -- feel free to work with @wingkitlee0 on this. thanks a bunch for the contribution!

richardliaw pushed a commit that referenced this pull request Apr 10, 2025

## Why are these changes needed?

**Problem**

Currently, `random_sample()` does not produce reproducible results with a
fixed seed (#40406).

Previous attempts (changing the global seed or passing the same
seed/state to workers) also do not work.

**Solution [Updated after PR review]**

To use random generators in parallel, we need to be careful about the
seed/state that is passed into `map_batches`. `numpy` describes a few
methods, one of which is to use a sequence of integer seeds:
https://numpy.org/doc/2.2/reference/random/parallel.html#sequence-of-integer-seeds.
In Ray Data, we can construct a `random_sample()` UDF that has access to
a "block id" via `TaskContext` (which is thread-local) and uses
`[block_id, seed]` to initialize an RNG. Because a Ray task may be reused
for different blocks, the RNG is saved in `TaskContext.kwargs`.
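
The sequence-of-integer-seeds idea can be sketched outside Ray as follows. This is an illustration only: `sample_block` is a hypothetical helper, and the blocks are plain NumPy arrays rather than Ray Data blocks. The only library call it relies on is `np.random.default_rng`, which accepts a sequence of integers as its seed.

```python
import numpy as np


def sample_block(block, fraction, seed, block_id):
    """Hypothetical per-block sampler, not Ray's implementation."""
    # Seeding with [block_id, seed] gives each block its own stream while
    # keeping the result reproducible for a fixed seed.
    rng = np.random.default_rng([block_id, seed])
    mask = rng.random(len(block)) < fraction
    return block[mask]


# Stand-in for a dataset of 1000 rows split into 4 blocks.
blocks = np.array_split(np.arange(1000), 4)

first = [sample_block(b, 0.05, 1234, i) for i, b in enumerate(blocks)]
second = [sample_block(b, 0.05, 1234, i) for i, b in enumerate(blocks)]

# Both passes select the same rows because each block's RNG is rebuilt
# from the same [block_id, seed] pair.
print(all(np.array_equal(a, b) for a, b in zip(first, second)))
```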

**Proposed fix  [Updated after PR review]**

We add `set/get_current()` methods to `TaskContext`, which allow the UDF
to get a local copy of the context. Through it, the UDF has access to the
`task_idx` and the previously initialized RNG. This removes the need for
the extra arguments in the original proposal.
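
The pattern of a thread-local context that caches an RNG can be sketched as below. Everything here is hypothetical: `SketchTaskContext`, `sample_batch`, and `run_task` are stand-ins written for this illustration, not Ray's `TaskContext` API, and the standard-library `random.Random` with a string seed replaces the NumPy generator.

```python
import random
import threading
from dataclasses import dataclass, field
from typing import Optional

_local = threading.local()


@dataclass
class SketchTaskContext:
    """Hypothetical stand-in for a per-task context, not Ray's class."""

    task_idx: int
    kwargs: dict = field(default_factory=dict)

    @classmethod
    def get_current(cls) -> Optional["SketchTaskContext"]:
        return getattr(_local, "ctx", None)

    @classmethod
    def set_current(cls, ctx: "SketchTaskContext") -> None:
        _local.ctx = ctx


def sample_batch(batch, fraction, seed):
    ctx = SketchTaskContext.get_current()
    rng = ctx.kwargs.get("rng")
    if rng is None:
        # First batch of this task: derive the RNG from (task_idx, seed)
        # and cache it so later batches continue the same stream.
        rng = random.Random(f"{ctx.task_idx}:{seed}")
        ctx.kwargs["rng"] = rng
    return [row for row in batch if rng.random() < fraction]


def run_task(task_idx, batches, fraction, seed):
    # A fresh context per task run, so a rerun starts from the same state.
    SketchTaskContext.set_current(SketchTaskContext(task_idx=task_idx))
    return [sample_batch(b, fraction, seed) for b in batches]


batches = [list(range(0, 50)), list(range(50, 100))]
run_a = run_task(0, batches, 0.2, 1234)
run_b = run_task(0, batches, 0.2, 1234)
print(run_a == run_b)
```

Because the UDF reads the context itself, it needs no extra arguments beyond the batch, the fraction, and the seed.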

**After fix**
```python
In [9]: ds = ray.data.range(1000)

In [10]: ds.random_sample(0.05, seed=1234).take_batch()
Out[10]:
{'id': array([ 27,  54,  72, 111, 136, 144, 147, 168, 200, 224, 225, 245, 247,
        248, 307, 312, 313, 340, 347, 375])}

In [11]: ds.random_sample(0.05, seed=1234).take_batch()
Out[11]:
{'id': array([ 27,  54,  72, 111, 136, 144, 147, 168, 200, 224, 225, 245, 247,
        248, 307, 312, 313, 340, 347, 375])}
```

## Related issue number

This issue has been raised a few times:
Closes #40406 #48497 

Other implementations did not solve the root cause:
#46088
#49443

## Checks

- [x] I've signed off every commit (by using the -s flag, i.e., `git
commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I
added a
method in Tune, I've added it in `doc/source/tune/api/` under the
           corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a
few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [x] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

Signed-off-by: Kit Lee <7000003+wingkitlee0@users.noreply.github.com>
han-steve pushed a commit to han-steve/ray that referenced this pull request Apr 11, 2025
Signed-off-by: Steve Han <stevehan2001@gmail.com>