Skip to content

Fix bug in ModelCatalog when using custom action distribution#12846

Merged
sven1977 merged 4 commits intoray-project:masterfrom
janblumenkamp:fix_action_dist_catalog
Jan 25, 2021
Merged

Fix bug in ModelCatalog when using custom action distribution#12846
sven1977 merged 4 commits intoray-project:masterfrom
janblumenkamp:fix_action_dist_catalog

Conversation

@janblumenkamp
Copy link
Copy Markdown
Contributor

Why are these changes needed?

Previously, when a custom action distribution is defined, ModelCatalog._get_multi_action_distribution was assigned to dist_cls, which later is returned as tuple together with dist_cls.required_model_output_shape. This tuple is already returned from _get_multi_action_distribution though, so it can directly be returned from get_action_dist.

Related issue number

None

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@janblumenkamp janblumenkamp changed the title return tuple returned from _get_multi_action_distribution when using … Fix bug in ModelCatalog when using custom action distribution Dec 14, 2020
@janblumenkamp
Copy link
Copy Markdown
Contributor Author

@sven1977 (I can't assign you as reviewer)

@sven1977
Copy link
Copy Markdown
Contributor

sven1977 commented Jan 8, 2021

Thanks @janblumenkamp for this PR.
Could you also change this line as indicated below in catalog.py? It's another bug along the same lines.

   @staticmethod
    def _get_multi_action_distribution(dist_class, action_space, config,
                                       framework):
        # In case the custom distribution is a child of MultiActionDistr.
        # If users want to completely ignore the suggested child
        # distributions, they should simply do so in their custom class'
        # constructor.
        if issubclass(dist_class,
                      (MultiActionDistribution, TorchMultiActionDistribution)):
            flat_action_space = flatten_space(action_space)
            child_dists_and_in_lens = tree.map_structure(
                lambda s: ModelCatalog.get_action_dist(
                    s, config, framework=framework), flat_action_space)
            child_dists = [e[0] for e in child_dists_and_in_lens]
            input_lens = [int(e[1]) for e in child_dists_and_in_lens]
            return partial(
                dist_class,
                action_space=action_space,
                child_distributions=child_dists,
                input_lens=input_lens), int(sum(input_lens))
        return dist_class, dist_class.required_model_output_shape(
            action_space, config) # <<------ HERE!!!

@janblumenkamp janblumenkamp force-pushed the fix_action_dist_catalog branch from d5224a0 to 67bd26e Compare January 9, 2021 20:37
@sven1977 sven1977 self-assigned this Jan 11, 2021
@sven1977
Copy link
Copy Markdown
Contributor

Hey @janblumenkamp , thanks for doing this PR! Could you fix the test_catalog test case? It's just a simple key error caused by an empty config inside the test case I think.

@janblumenkamp
Copy link
Copy Markdown
Contributor Author

I think this fixed the issue, but I am not sure if the current fails are caused by me or by master?

dist_cls = ModelCatalog._get_multi_action_distribution(
dist_cls, action_space, {}, framework)
return ModelCatalog._get_multi_action_distribution(
dist_cls, action_space, config, framework)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing the original config here seems to lead to recursion, as _get_multi_action_distribution also calls get_action_dist and config still contains the custom_action_dist key.

Perhaps:

if config.get("custom_action_dist"):
            custom_action_config = config.copy()
            action_dist_name = custom_action_config.pop("custom_action_dist")

            logger.debug(
                "Using custom action distribution {}".format(action_dist_name))
            dist_cls = _global_registry.get(RLLIB_ACTION_DIST,
                                            action_dist_name)
            return ModelCatalog._get_multi_action_distribution(
                dist_cls, action_space, custom_action_config, framework)

@sven1977
Copy link
Copy Markdown
Contributor

All good now! Thanks everyone.

Copy link
Copy Markdown
Contributor

@sven1977 sven1977 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @janblumenkamp for the fixes!

@sven1977 sven1977 merged commit 964689b into ray-project:master Jan 25, 2021
@janblumenkamp
Copy link
Copy Markdown
Contributor Author

Sorry for not responding to zseymours review earlier! Thanks for merging, Sven. Out of curiosity, how was the point raised in the review fixed eventually? It doesn't seem to be fixed in master.

fishbone pushed a commit to fishbone/ray that referenced this pull request Feb 16, 2021
…ray-project#12846)

* return tuple returned from _get_multi_action_distribution when using custom action dict

* Always return dst_class and required_model_output_shape in _get_multi_action_distribution

* pass model config to _get_multi_action_distribution
fishbone added a commit to fishbone/ray that referenced this pull request Feb 16, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants