[rllib] Allow for JAX framework #8732

@KristianHolsheimer

What is the problem?

I've been using JAX as my framework for a little while now. I just upgraded to the nightly build (due to some unrelated issues) and now RLlib is telling me I need to install TensorFlow or Torch.
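One way to confirm which deep-learning frameworks are actually importable in an environment is to probe for them with `importlib.util.find_spec`. It locates a top-level package without importing it. The helper name `installed_frameworks` is hypothetical, and the module names `tensorflow`, `torch` and `jax` are assumed to be the usual import names of those packages:

```python
import importlib.util


def installed_frameworks(names=("tensorflow", "torch", "jax")):
    """Hypothetical helper: map each top-level module name to whether it is importable.

    find_spec() only locates the module and does not import it, so this stays
    cheap even for heavy packages such as TensorFlow.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, available in installed_frameworks().items():
        print(f"{name}: {'installed' if available else 'missing'}")
```

In the setup described here, this would presumably report `jax` as installed and the other two as missing. That is the combination the nightly build now refuses to run with.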

I tried setting {'framework': 'jax', ...} in my trainer config, but that results in a different error: RLlib does not recognize any framework other than one of [tf|tfe|torch|auto].
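The error amounts to a whitelist check on the `framework` key. The sketch below is not RLlib's actual validation code. `check_framework` is a hypothetical validator that mirrors the accepted values quoted in the error (`tf`, `tfe`, `torch`, `auto`) and shows the change this issue asks for, which is adding `jax` to that set:

```python
# Values accepted according to the error message quoted in this issue.
KNOWN_FRAMEWORKS = {"tf", "tfe", "torch", "auto"}

# Hypothetical: the same set with the value requested in this issue.
KNOWN_FRAMEWORKS_WITH_JAX = KNOWN_FRAMEWORKS | {"jax"}


def check_framework(config, allowed=KNOWN_FRAMEWORKS_WITH_JAX):
    """Hypothetical validator: return config["framework"] or raise ValueError.

    Assumes the "framework" key is present in the config dict.
    """
    framework = config["framework"]
    if framework not in allowed:
        raise ValueError(
            f"framework={framework!r} is not supported; "
            f"expected one of {sorted(allowed)}"
        )
    return framework
```

With the extended set, `check_framework({'framework': 'jax'})` returns `'jax'`. Passing `allowed=KNOWN_FRAMEWORKS` instead raises `ValueError`, which is roughly the behaviour reported above.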

Ray version: ray-0.9.0.dev0, Python 3.8, Ubuntu Linux 20.04 LTS

Script to reproduce:

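import ray
from ray import tune
from ray.rllib.policy.policy import Policy as BasePolicy
from ray.rllib.agents.trainer_template import build_trainer


class Policy(BasePolicy):
    """Minimal policy that does not use any deep-learning framework."""

    def compute_actions(self, obs_batch, state_batches=None, **kwargs):
        # One random action per observation; no RNN state, no extra info.
        actions = [self.action_space.sample() for _ in obs_batch]
        return actions, [], {}

    def get_weights(self):
        # No trainable parameters to sync.
        pass

    def set_weights(self, weights):
        pass

    def learn_on_batch(self, sample_batch):
        # Nothing to learn; return an empty stats dict.
        return {}


trainer = build_trainer(
    name='foo',
    default_policy=Policy,
)

ray.init()
tune.run(
    trainer,
    config={
        'framework': 'jax',  # rejected: only tf|tfe|torch|auto are accepted
        'env': 'FrozenLake-v0',
    },
    stop={'training_iteration': 1},
)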
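Nothing in the reproduction's `Policy` touches TensorFlow or Torch. Its `compute_actions` only samples from the action space and returns the `(actions, state_outs, info)` triple that RLlib policies return. The standalone sketch below exercises that contract without Ray. `DiscreteSpace` is a hypothetical stand-in for a Gym `Discrete` space (FrozenLake has 4 actions), and `RandomPolicy` is a hypothetical framework-free copy of the script's policy logic:

```python
import random


class DiscreteSpace:
    """Hypothetical stand-in for gym.spaces.Discrete(n)."""

    def __init__(self, n, seed=None):
        self.n = n
        self._rng = random.Random(seed)

    def sample(self):
        # Uniform integer in [0, n).
        return self._rng.randrange(self.n)


class RandomPolicy:
    """Framework-free copy of the reproduction script's compute_actions logic."""

    def __init__(self, action_space):
        self.action_space = action_space

    def compute_actions(self, obs_batch, state_batches=None, **kwargs):
        # One sampled action per observation; no RNN state; no extra info.
        actions = [self.action_space.sample() for _ in obs_batch]
        return actions, [], {}


policy = RandomPolicy(DiscreteSpace(4, seed=0))
actions, state_outs, info = policy.compute_actions([0, 5, 15])
print(actions, state_outs, info)
```

Because this logic runs with no framework at all, a policy like this has no real dependency on TensorFlow or Torch.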

If we cannot run your script, we cannot fix your issue.

  • I have verified my script runs in a clean environment and reproduces the issue.
  • I have verified the issue also occurs with the latest wheels.

Labels: P1 (Issue that should be fixed within a few weeks), bug (Something that is supposed to be working; but isn't)
