I would like to use a simple TF model with batch normalization together with PPO. I am considering tf.layers.batch_normalization (https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization). The problem I am facing is that for this batch normalization layer to work correctly, you have to run its moving mean/variance update operations in addition to the main NN graph computation. As far as I can tell, there is no way to pass these operations into extra_ops in the self.par_opt.optimize(extra_ops=...) call inside PPOEvaluator.run_sgd_minibatch(..). Is anyone aware of a workaround I could use? Is there perhaps a proper way of doing this that I have missed? Apologies if this is a dumb question; I am fairly new to both TF and RLlib.
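For reference, this is the pattern the linked docs prescribe: fetch the UPDATE_OPS collection and make the train op depend on it. The sketch below (written against the `tf.compat.v1` namespace so it also runs under newer TF, with a made-up toy loss) shows the extra operations I mean:

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

x = tf.placeholder(tf.float32, shape=(None, 4))
is_training = tf.placeholder(tf.bool, shape=())

# batch_normalization registers its moving mean/variance update ops
# in the UPDATE_OPS collection; they do NOT run automatically.
h = tf.layers.batch_normalization(x, training=is_training)
loss = tf.reduce_mean(tf.square(h))  # hypothetical loss, just for illustration

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # train_op now runs the moving-average updates as a side effect
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op, feed_dict={x: np.random.randn(8, 4), is_training: True})
```

It is this control-dependency (or an equivalent extra fetch) that I cannot currently express through the optimize() call.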
EDIT: Additionally, I don't think there is currently a way to add extra values to the TF feed_dict during training/prediction. In my use case this is necessary, because batch normalization has to be switched between training mode (batch statistics) and prediction mode (moving statistics).
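To illustrate what I mean, in plain TF the mode switch is usually done by feeding a boolean placeholder (again `tf.compat.v1` style, with hypothetical names); it is exactly this kind of per-call feed_dict entry that I see no way to inject:

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

x = tf.placeholder(tf.float32, shape=(None, 4))
# placeholder_with_default lets prediction run without feeding anything
# (defaults to inference mode), while training feeds True explicitly.
is_training = tf.placeholder_with_default(False, shape=(), name="is_training")
out = tf.layers.batch_normalization(x, training=is_training)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.randn(8, 4).astype(np.float32)
    # prediction mode: normalizes with the moving statistics
    pred = sess.run(out, feed_dict={x: batch})
    # training mode: normalizes with the batch's own mean/variance
    train = sess.run(out, feed_dict={x: batch, is_training: True})
```

A placeholder with a default at least keeps prediction working unchanged, but training still needs that one extra feed_dict entry.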