[RLlib] MetricsLogger: Fix get/set_state to handle tensors in self.values.#53514
Conversation
Pull Request Overview
This PR updates Stats.get_state and Stats.from_state to correctly handle tensor values in self.values, ensuring state serialization avoids returning raw tensors.
- Introduces `convert_to_numpy` in `get_state` to serialize any tensors to NumPy.
- Adds a `_could_be_tensor` flag to track potential tensor insertions and skips repopulating `values` in `from_state` when tensors were present.
- Refactors `check_value` to detect zero-dimensional tensors and enforce scalar reduction behavior.
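A minimal sketch of the pattern described in the overview, not RLlib's actual code. `FakeTensor`, `MiniStats`, and `to_plain` are hypothetical stand-ins (for a framework tensor, `Stats`, and `convert_to_numpy`, respectively), chosen so the example runs without Torch, TF, or NumPy.

```python
from collections import deque
from typing import Any, Dict


class FakeTensor:
    """Hypothetical stand-in for a zero-dimensional framework tensor."""

    def __init__(self, value: float):
        self._value = value

    def item(self) -> float:
        return self._value


def to_plain(value: Any) -> Any:
    """Stand-in for `convert_to_numpy`: unwrap tensors, pass the rest through."""
    return value.item() if isinstance(value, FakeTensor) else value


class MiniStats:
    """Hypothetical, heavily simplified illustration of the state handling."""

    def __init__(self, window: int = 3):
        self.window = window
        self.values = deque(maxlen=window)
        self._could_be_tensor = False

    def push(self, value: Any) -> None:
        # Remember that a tensor may have been inserted at some point.
        if isinstance(value, FakeTensor):
            self._could_be_tensor = True
        self.values.append(value)

    def get_state(self) -> Dict[str, Any]:
        # Never hand raw tensors to the checkpoint.
        return {
            "window": self.window,
            "values": [to_plain(v) for v in self.values],
            "_could_be_tensor": self._could_be_tensor,
        }

    @staticmethod
    def from_state(state: Dict[str, Any]) -> "MiniStats":
        stats = MiniStats(window=state["window"])
        # Start fresh if tensors were involved; otherwise restore the values.
        if not state.get("_could_be_tensor", False):
            stats.values.extend(state["values"])
        return stats
```

In this sketch, a state saved after pushing tensors restores to an empty `values` container, while a state built from plain floats restores unchanged.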
Comments suppressed due to low confidence (3)
rllib/utils/metrics/stats.py:187
- [nitpick] The private flag `_could_be_tensor` may be clearer as `_may_have_tensors` or `_has_tensors` to reflect its boolean purpose more directly.
self._could_be_tensor = False
rllib/utils/metrics/stats.py:639
- There are no explicit unit tests covering tensor serialization in `get_state` and recovery in `from_state`. Adding tests for pushing Torch/TF tensors and ensuring correct NumPy output will prevent regressions.
def get_state(self) -> Dict[str, Any]:
rllib/utils/metrics/stats.py:642
- `convert_to_numpy` may not handle `deque` inputs uniformly. It could be safer to convert `self.values` to a list first, e.g., `convert_to_numpy(list(self.values))`.
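The suggestion amounts to snapshotting the deque as a plain list before conversion. A small sketch with a hypothetical `snapshot_values` helper; whether `convert_to_numpy` actually mishandles a `deque` is not confirmed in this thread.

```python
from collections import deque
from typing import Any, List


def snapshot_values(values: Any) -> List[Any]:
    """Hypothetical helper: copy a list or deque into a plain list."""
    return list(values)


window = deque([1.0, 2.0, 3.0], maxlen=2)  # keeps only the last two items
snapshot = snapshot_values(window)

# The snapshot is a plain list and independent of the live deque.
window.append(4.0)

# A list carries no `maxlen`, so the window size must be stored separately
# and re-applied when the deque is rebuilt.
restored = deque(snapshot, maxlen=2)
```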
"values": convert_to_numpy(self.values),
rllib/utils/metrics/stats.py
Outdated
| len_before_reduce = len(self) | ||
| if self._has_new_values: | ||
| # Only calculate and update history if there were new values pushed since last reduce | ||
| # Only calculate and update history if there were new values pushed since' |
There's a stray apostrophe after `since`. Removing it will clean up the comment.
| # Only calculate and update history if there were new values pushed since' | |
| # Only calculate and update history if there were new values pushed since |
|
|
||
| @staticmethod | ||
| def from_state(state: Dict[str, Any], throughputs=False) -> "Stats": | ||
| def from_state(state: Dict[str, Any]) -> "Stats": |
The `throughputs` parameter was removed from `from_state`, which may break existing callers relying on that signature. Consider retaining a backward-compatible overload or documenting the change.
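One common way to keep such a signature backward compatible is to accept the old argument, ignore it, and emit a deprecation warning. The sketch below illustrates that idea only; `CompatStats` is a hypothetical class, and this is not what the PR does.

```python
import warnings
from typing import Any, Dict, List, Optional


class CompatStats:
    """Hypothetical class used only to illustrate a deprecated keyword."""

    def __init__(self, values: List[Any]):
        self.values = list(values)

    @staticmethod
    def from_state(
        state: Dict[str, Any], throughputs: Optional[bool] = None
    ) -> "CompatStats":
        # Accept the old argument, but ignore it and tell the caller.
        if throughputs is not None:
            warnings.warn(
                "`throughputs` is ignored and will be removed.",
                DeprecationWarning,
                stacklevel=2,
            )
        return CompatStats(state["values"])
```

Existing calls such as `CompatStats.from_state(state, throughputs=False)` keep working and surface a `DeprecationWarning` instead of raising a `TypeError`.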
| # whether we are on a supported device). | ||
| values = state["values"] | ||
| if "_could_be_tensor" in state and state["_could_be_tensor"]: | ||
| values = [] |
- If we do this, we should state it as a known limitation in `metrics-logger.rst`.
- Alternatively, can we still check whether we are on a supported device and keep track of GPU tensors?
There was a problem hiding this comment.
- That's true. This is a limitation for now, albeit a very small one, since tensor logging is normally only done for loss metrics (loss, entropy, etc.), and these are very often ephemeral values where it's not a problem at all to just start fresh after loading a checkpoint (`window=1` anyways?).
- Yeah, but then we would have to store the original device as well, which quickly becomes super messy (e.g., when you transfer a checkpoint from one cluster type (GPUs) to another (no GPUs)).
There was a problem hiding this comment.
- What if `window != 1`? It's a user-facing class.
- Roger! Yeah, there is some ugly complexity here.
rllib/utils/metrics/stats.py
Outdated
| self.values: Union[List, deque.Deque] = None | ||
| self._set_values(force_list(init_values)) | ||
|
|
||
| self._could_be_tensor = False |
From reading the code, could `self._could_be_tensor` be renamed to `self._is_tensor`?
ArturNiederfahrenhorst
left a comment
Left some comments, but no blockers.
…ics_logger_get_set_state_handling_tensors
….values`. (#53514) Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
MetricsLogger: Fix `get/set_state` to handle tensors in `self.values`.
- [...] `self._may_have_tensors` flag to be True.
- [...] `self._may_have_tensors` flag and does NOT populate the `self.values` field, if it's True.

Why are these changes needed?
Related issue number
Closes #53467