[rllib] Allow passing in additional TF operations for evaluation for PPO (e.g. for batch normalization) #2023

@kovalevvlad

Description

I would like to use a simple TF model with batch normalization along with PPO. I am considering the following TF feature: https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization. The problem I am facing is that for this batch normalization layer to work, you have to run the moving mean/variance update operations in addition to the NN graph computation. As far as I can tell, there is no way of passing these operations into `extra_ops` in the `self.par_opt.optimize(extra_ops=...)` call inside `PPOEvaluator.run_sgd_minibatch(...)`. Is anyone aware of a workaround that I could use? Is there perhaps a proper way of doing this that I have missed? Apologies if this is a dumb question; I am fairly new to both TF and rllib.
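For context, outside of rllib the usual TF 1.x pattern is to fetch the update ops that `tf.layers.batch_normalization` registers in the `UPDATE_OPS` collection and make the training step depend on them, so the moving statistics are refreshed on every optimizer step. A minimal sketch (plain TF, no rllib; the placeholder names and loss are illustrative only):

```python
import tensorflow as tf  # TF 1.x API assumed

x = tf.placeholder(tf.float32, [None, 4])
is_training = tf.placeholder(tf.bool, [])

# batch_normalization adds its moving mean/variance update ops
# to the tf.GraphKeys.UPDATE_OPS collection as a side effect.
h = tf.layers.batch_normalization(x, training=is_training)
loss = tf.reduce_mean(tf.square(h))

# Make the optimizer step depend on the update ops, so a single
# sess.run(train_op, ...) also advances the moving statistics.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
```

The ask in this issue is effectively for rllib to expose a hook for exactly these `update_ops`, either by forwarding them through `extra_ops` in `self.par_opt.optimize(...)` or by letting users wrap the train op in `tf.control_dependencies` as above.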

EDIT: Additionally, I don't think there is currently a way to add extra values to the TF feed_dict during training/prediction. In my use case, this is necessary to switch batch normalization between training and prediction modes.
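The shape of the requested hook could be as simple as merging caller-supplied entries into the feed_dict rllib already builds. A minimal sketch of such a helper (hypothetical: rllib has no such hook today, which is what this issue requests; the name `build_feed_dict` and both parameters are made up for illustration):

```python
def build_feed_dict(base_feed, extra_feed=None):
    """Merge the evaluator's standard feed_dict with extra entries.

    `extra_feed` would carry e.g. {is_training_placeholder: True} so a
    user can flip batch normalization between training and prediction
    modes. Caller-supplied entries deliberately win on key collisions.
    """
    feed = dict(base_feed)  # copy so the evaluator's dict is untouched
    if extra_feed:
        feed.update(extra_feed)
    return feed
```

Usage would look like `sess.run(fetches, feed_dict=build_feed_dict(standard_feed, {is_training: True}))` during training and `{is_training: False}` at prediction time.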

Labels

good-first-issue: Great starter issue for someone just starting to contribute to Ray
