[RLlib] New ConnectorV2 API #02: SingleAgentEpisode enhancements (#41075)
Conversation
    actions: List[ActType] = None,
    rewards: List[SupportsFloat] = None,
    infos: List[Dict] = None,
    states=None,
State outputs are no longer needed as a separate field. They are treated just like any other extra model outputs (e.g. as a (possibly nested) dict under the STATE_OUT key).
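An illustrative sketch (not RLlib source) of what this means for the data layout: state outputs now live under a `STATE_OUT` key inside `extra_model_outputs`, alongside any other model outputs, instead of in a separate `states` field. The constant name and the LSTM-style state dict below are assumptions for illustration.

```python
# Assumed constant name for illustration; RLlib defines its own key.
STATE_OUT = "state_out"

# States are just another (possibly nested) entry among the extra
# model outputs, one value per timestep.
extra_model_outputs = {
    "action_logp": [-0.69],
    # E.g. an LSTM's hidden/cell states as a nested dict:
    STATE_OUT: [{"h": [0.0] * 4, "c": [0.0] * 4}],
}
```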
rllib/env/single_agent_episode.py
)
if self.t_started < len(self.observations) - 1:
    self.t = len(self.observations) - 1
self._len_pre_buffer = len(self.rewards)
Added concept of a "lookback buffer" inside an ongoing episode.
This allows custom connectors to look back at previous data for a certain (user-defined) number of timesteps, e.g. to add "prev. reward", "prev. 5 actions", etc. to a model's input (via custom connectors).
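A minimal sketch of the lookback-buffer idea (a hypothetical stand-in, not RLlib's actual `SingleAgentEpisode` implementation): rewards recorded before the current chunk remain available so a connector can build inputs such as "prev. 5 rewards".

```python
# Hypothetical helper illustrating the lookback concept.
class RewardLookback:
    def __init__(self, lookback: int = 5):
        # User-defined number of past timesteps to keep accessible.
        self.lookback = lookback
        self.rewards: list = []

    def append(self, reward: float) -> None:
        self.rewards.append(reward)

    def prev_rewards(self, n: int) -> list:
        # Return up to the last `n` rewards (bounded by the lookback size).
        assert n <= self.lookback
        return self.rewards[-n:]

buf = RewardLookback(lookback=5)
for r in [1.0, 0.5, 2.0]:
    buf.append(r)
```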
rllib/env/single_agent_episode.py
| """ | ||
| assert episode_chunk.id_ == self.id_ | ||
| assert not self.is_done | ||
| assert not self.is_done and not self.is_numpy |
For simplicity, we assume that Episode is still in the "list-format" (not numpyized yet).
We might have to change this concat_episode() API in the future, but right now, it's only used inside DreamerV3's replay buffer anyway (and in some test cases).
self.validate()

- def add_initial_observation(
+ def add_env_reset(
Changed names for clarity:
- `add_env_reset()`: Add all the data returned by an `env.reset()` call.
- `add_env_step()`: Add all the data returned by an `env.step()` call.
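A hedged sketch of the call pattern behind the rename (the stand-in class below is illustrative only; argument names are assumed from the discussion, not verified against RLlib's actual signatures):

```python
# Tiny stand-in illustrating the add_env_reset / add_env_step split.
class TinyEpisode:
    def __init__(self):
        self.observations, self.actions, self.rewards = [], [], []

    def add_env_reset(self, observation, infos=None):
        # Stores what env.reset() returned (initial obs, infos).
        self.observations.append(observation)

    def add_env_step(self, observation, action, reward, infos=None):
        # Stores what one env.step() call returned, plus the action taken.
        self.observations.append(observation)
        self.actions.append(action)
        self.rewards.append(reward)

episode = TinyEpisode()
episode.add_env_reset(observation=0)
episode.add_env_step(observation=1, action=0, reward=1.0)
```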
self.validate()

- def add_timestep(
+ def add_env_step(
-     self.extra_model_outputs[k] = [v]
- else:
-     self.extra_model_outputs[k].append(v)
+ self.extra_model_outputs[k].append(v)
Simplified via `defaultdict`.
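The simplification boils down to the standard `collections.defaultdict` pattern: with a `defaultdict(list)` there is no need to special-case the first value for a key.

```python
from collections import defaultdict

# First access to a missing key creates an empty list automatically,
# so a single append works for both the first and subsequent values.
extra_model_outputs = defaultdict(list)
for k, v in [("logits", [0.1]), ("logits", [0.2]), ("state_out", 0)]:
    extra_model_outputs[k].append(v)
```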
rllib/env/single_agent_episode.py
- self.observations = np.array(self.observations)
- self.actions = np.array(self.actions)
+ self.observations = batch(self.observations)
Allow for nested obs/action spaces.
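A minimal sketch of why `batch()` is needed here instead of a plain `np.array()` call (the `batch` function below is a hypothetical stand-in, not RLlib's actual utility): with nested dict observations, each leaf must be stacked across timesteps while the dict structure is preserved.

```python
import numpy as np

# Hypothetical stand-in for a batch() utility: recursively stack a list
# of (possibly nested) dicts into a dict of arrays.
def batch(items):
    if isinstance(items[0], dict):
        return {k: batch([it[k] for it in items]) for k in items[0]}
    return np.array(items)

# Two timesteps of a nested observation space:
obs = [
    {"pos": [0.0, 0.0], "vel": 1.0},
    {"pos": [1.0, 1.0], "vel": 2.0},
]
stacked = batch(obs)
```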
rllib/env/single_agent_episode.py
self.render_images = np.array(self.render_images, dtype=np.uint8)
for k, v in self.extra_model_outputs.items():
-     self.extra_model_outputs[k] = np.array(v)
+     self.extra_model_outputs[k] = batch(v)
Allow for complex (nested) model outs (especially now that states are part of these extra model outs).
Can we use batch() instead of the np.array conversion everywhere? That would let us unit-test batch(), make sure its behavior is predictable, and reuse it everywhere.
The argument against it is that this would be overkill (we know that rewards are only a list of floats, never complex structs). But yes, batch() should work on these as well, of course. There is a proper unit test for batch(), which was added recently.
Solved: I added an extra test for batch/unbatch on simple structs AND used batch() everywhere in this method (even on rewards).
rllib/env/single_agent_episode.py
        for k in extra_model_output_keys
    },
)
def get_observations(self, indices: Optional[Union[int, List[int], slice]] = None) -> Any:
Added these very practical new APIs to get data from the episode in a user friendly fashion.
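A hedged sketch of the indexing semantics suggested by the signature above (the stand-in class is illustrative; only the `indices: Optional[Union[int, List[int], slice]]` shape is taken from the diff, the exact behavior is an assumption): a single int, a list of ints, or a slice all select observations.

```python
# Illustrative stand-in for the new getter-style APIs.
class EpisodeGetters:
    def __init__(self, observations):
        self.observations = observations

    def get_observations(self, indices=None):
        # None -> everything; int/slice -> direct indexing;
        # list of ints -> gather the matching items.
        if indices is None:
            return self.observations
        if isinstance(indices, (int, slice)):
            return self.observations[indices]
        return [self.observations[i] for i in indices]

ep = EpisodeGetters([10, 11, 12, 13])
```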
    },
)

@staticmethod
We'll try to get rid of SampleBatch eventually (it's kind of an overloaded mess). There is no application currently that requires constructing an episode from an existing SampleBatch (only the other way around: Episode -> SampleBatch)
rllib/env/single_agent_env_runner.py
gym.register(
    "custom-env-v0",
    partial(
if (
This is a bug fix. Otherwise, passing a class to config.environment(env=[some class]) does not work (only strings do).
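A sketch of the fixed logic as I read the comment (the helper and its name are hypothetical, not RLlib's actual code): an env passed via `config.environment(env=...)` may be either a registered string id or an env class, and only the string case needs gym-style registration/lookup.

```python
# Hypothetical helper illustrating the string-vs-class distinction.
def resolve_env(env, env_config):
    if isinstance(env, str):
        # String id: in real code this would go through
        # gym.register(...) / gym.make(env); stubbed here.
        return ("make_by_id", env)
    # Env class: instantiate it directly with the config dict.
    return env(env_config)

class MyEnv:
    def __init__(self, config):
        self.config = config

made = resolve_env(MyEnv, {"size": 4})
```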
# TODO (simon): Check, if this works for the default
# stateful encoders.
- initial_state={k: s[i] for k, s in states.items()},
+ self._episodes[i].add_env_reset(
Cleaner naming of these Episode methods:
- `add_env_reset()`
- `add_env_step()`

Both add the return values of the respective gym.Env calls to an episode.
self._ts_since_last_metrics: int = 0
self._weights_seq_no: int = 0

# TODO (sven): This is a temporary solution. STATE_OUTs
Temporary fix: we need the new connectors to make this work without having to keep self._states around here. The follow-up PRs are lined up and require this one to be merged first.
infos = []
- extra_model_outputs = []
- states = np.random.random(10)
+ extra_model_outputs = {"extra_1": [], "state_out": np.random.random()}
Fixed the tests so that state_out is just another extra_model_outputs key.
The RLlib tests are still failing.
This PR is the 2nd in the "enhanced/new ConnectorV2 API" series:

- `SingleAgentEpisode` class: more consistent API names and additional convenience getter APIs for obs, actions, etc.; removes the `state` property from Episodes (now just another `extra_model_outputs` subkey).
- `SingleAgentEnvRunner`: merged and activated test cases.

Why are these changes needed?
Related issue number
Checks
- I've signed off every commit (`git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.