[RLlib] Fix performance and functionality flaws in attention nets (via Trajectory view API).#11729
[RLlib] Fix performance and functionality flaws in attention nets (via Trajectory view API). #11729 — sven1977 wants to merge 194 commits into ray-project:master from
Conversation
…ectory_view_api_attention_nets
…nto trajectory_view_api_attention_nets # Conflicts: # rllib/models/tf/attention_net.py # rllib/policy/view_requirement.py
…ectory_view_api_enable_by_default_for_all_simple
…ectory_view_api_attention_nets
…ectory_view_api_enable_by_default_for_all_simple
…ectory_view_api_attention_nets # Conflicts: # rllib/agents/ppo/ppo_tf_policy.py # rllib/evaluation/tests/test_trajectory_view_api.py # rllib/policy/tf_policy_template.py # rllib/utils/tf_ops.py
…on_nets # Conflicts: # rllib/agents/ppo/ppo_tf_policy.py # rllib/evaluation/collectors/simple_list_collector.py # rllib/evaluation/tests/test_trajectory_view_api.py # rllib/policy/dynamic_tf_policy.py # rllib/policy/policy.py # rllib/policy/sample_batch.py # rllib/policy/view_requirement.py
| def __init__(self, shift_before: int = 0): | ||
| self.shift_before = max(shift_before, 1) | ||
| def __init__(self, view_reqs): | ||
| self.shift_before = -min( |
There was a problem hiding this comment.
Might want to add comment to describe what this code does!
| def add_init_obs(self, episode_id: EpisodeID, agent_index: int, | ||
| env_id: EnvID, t: int, init_obs: TensorType, | ||
| view_requirements: Dict[str, ViewRequirement]) -> None: | ||
| """Adds an initial observation (after reset) to the Agent's trajectory. |
There was a problem hiding this comment.
Change the description. It adds more than a single observation.
There was a problem hiding this comment.
No, it doesn't — it really just adds a single one. Same as it used to work with SampleBatchBuilder.
| / view_req.batch_repeat_value)) | ||
| repeat_count = (view_req.data_rel_pos_to - | ||
| view_req.data_rel_pos_from + 1) | ||
| data = np.asarray([ |
There was a problem hiding this comment.
Same as above, a bit confused. Add comments on what these lines of code do.
There was a problem hiding this comment.
Provided an example.
| shift = view_req.data_rel_pos + obs_shift | ||
| # Shift is exactly 0: Use trajectory as is. | ||
| if shift == 0: | ||
| data = np_data[data_col][self.shift_before:] |
There was a problem hiding this comment.
Provided an example.
| [np.zeros(shape=shape, dtype=dtype) | ||
| for _ in range(shift)] | ||
|
|
||
| def _get_input_dict(self, view_reqs, abs_pos: int = -1) -> \ |
There was a problem hiding this comment.
Add description what this method does
| batch = SampleBatch(self.buffers) | ||
| assert SampleBatch.UNROLL_ID in batch.data | ||
| batch = SampleBatch( | ||
| self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True) |
There was a problem hiding this comment.
What is _dont_check_lens?
| return batch | ||
|
|
||
|
|
||
| class _PolicyCollectorGroup: |
There was a problem hiding this comment.
Probably add comments on what this is
rllib/policy/sample_batch.py
Outdated
| for i, seq_len in enumerate(self.seq_lens): | ||
| count += seq_len | ||
| if count >= end: | ||
| data["state_in_0"] = self.data["state_in_0"][state_start: |
There was a problem hiding this comment.
Add comment on what this does
| # Range of indices on time-axis, make sure to create | ||
| if view_req.data_rel_pos_from is not None: | ||
| ret[view_col] = np.zeros_like([[ | ||
| view_req.space.sample() |
There was a problem hiding this comment.
Same add comment here
| return x | ||
|
|
||
|
|
||
| class PositionalEmbedding(tf.keras.layers.Layer): |
There was a problem hiding this comment.
Add comments on what this does (how it initializes embedding per position based on cos/sin something)
…ectory_view_api_attention_nets # Conflicts: # rllib/agents/trainer.py # rllib/evaluation/collectors/simple_list_collector.py # rllib/evaluation/tests/test_trajectory_view_api.py # rllib/models/tf/attention_net.py # rllib/policy/policy.py # rllib/policy/torch_policy_template.py # rllib/policy/view_requirement.py # src/ray/raylet/node_manager.cc
…ectory_view_api_attention_nets # Conflicts: # rllib/agents/trainer.py # rllib/evaluation/collectors/simple_list_collector.py # rllib/evaluation/tests/test_trajectory_view_api.py # rllib/models/tf/attention_net.py # rllib/policy/policy.py # rllib/policy/torch_policy_template.py # rllib/policy/view_requirement.py # src/ray/raylet/node_manager.cc
…ectory_view_api_attention_nets # Conflicts: # rllib/agents/ppo/appo_tf_policy.py # rllib/agents/ppo/ppo_torch_policy.py # rllib/agents/qmix/model.py # rllib/evaluation/collectors/simple_list_collector.py # rllib/evaluation/rollout_worker.py # rllib/evaluation/tests/test_trajectory_view_api.py # rllib/models/modelv2.py # rllib/policy/dynamic_tf_policy.py # rllib/policy/policy.py # rllib/policy/sample_batch.py # rllib/policy/view_requirement.py
…ectory_view_api_attention_nets
…ectory_view_api_attention_nets
…ectory_view_api_attention_nets
…ectory_view_api_attention_nets
|
Moved here: |
Why are these changes needed?
Related issue number
Checks
scripts/format.sh to lint the changes in this PR.