TypeError when not passing total_episodes in PPOv2Trainer #1740

@meng-wenlong

Hi! I've been trying to run examples/scripts/ppo/ppo_tldr.py and hit "TypeError: 'float' object cannot be interpreted as an integer" on the line for update in range(1, args.num_updates + 1). After reading ppov2_trainer.py, I think the underlying issue originates here:

        #########
        # calculate various batch sizes
        #########
        if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
            args.total_episodes = args.num_train_epochs * self.train_dataset_len

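num_train_epochs is a float argument in TrainingArguments, which PPOv2Config inherits from, so when total_episodes is not passed, args.total_episodes ends up as a float and that float reaches args.num_updates. I suppose we should use args.num_ppo_epochs here.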
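To make the failure easier to see without running the trainer, here is a minimal standalone sketch. It assumes that num_updates is derived from total_episodes by floor division with a batch size, which is my reading of the trainer and not something shown in the snippet above. The function names and the example values (3.0 epochs, 100 samples, batch size 8) are hypothetical. The int() cast is shown only as one possible workaround, separate from the num_ppo_epochs idea above.

```python
def num_updates_unfixed(num_train_epochs: float, train_dataset_len: int, batch_size: int):
    """Mimic the reported behaviour: a float epoch count makes total_episodes a float."""
    total_episodes = num_train_epochs * train_dataset_len  # float * int -> float
    # Assumption: num_updates comes from floor division; float // int is still a float.
    return total_episodes // batch_size


def num_updates_with_cast(num_train_epochs: float, train_dataset_len: int, batch_size: int) -> int:
    """One possible workaround: cast total_episodes to int before using it."""
    total_episodes = int(num_train_epochs * train_dataset_len)
    return total_episodes // batch_size


def loop_raises_type_error(num_updates) -> bool:
    """Run the same loop header as the trainer and report whether it raises TypeError."""
    try:
        for _ in range(1, num_updates + 1):
            pass
    except TypeError:
        return True
    return False


if __name__ == "__main__":
    # Hypothetical values, not taken from ppo_tldr.py.
    broken = num_updates_unfixed(3.0, 100, 8)
    fixed = num_updates_with_cast(3.0, 100, 8)
    print(type(broken).__name__, loop_raises_type_error(broken))
    print(type(fixed).__name__, loop_raises_type_error(fixed))
```

Passing total_episodes explicitly as an integer avoids this code path entirely, since the branch above only runs when it is None.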
