Use dataclass instead of TypedDict for config #1675

@xuantengh

Currently, nemo-rl uses TypedDict to host the config options:

class GRPOConfig(TypedDict):
    num_prompts_per_step: int
    num_generations_per_prompt: int
    max_num_epochs: int
    max_num_steps: int
    max_rollout_turns: int
    normalize_rewards: bool
    use_leave_one_out_baseline: bool
    val_period: int
    val_batch_size: int
    val_at_start: bool
    max_val_samples: int
    seed: int
    async_grpo: NotRequired[AsyncGRPOConfig]
    overlong_filtering: NotRequired[bool]
    # whether to enable dynamic sampling, i.e.
    # whether to discard prompts whose rewards have zero standard deviation
    use_dynamic_sampling: bool
    # When using dynamic sampling, the maximum number of batches to generate
    # before throwing an error
    dynamic_sampling_max_gen_batches: NotRequired[int]
    # When using dynamic sampling, generation prompt batch size will equal
    # num_prompts_per_step * batch_multiplier
    batch_multiplier: NotRequired[float]
    reward_shaping: RewardShapingConfig
    reward_scaling: RewardScalingConfig


class GRPOSaveState(TypedDict):
    consumed_samples: int
    current_step: int
    current_epoch: int
    total_steps: int
    total_valid_tokens: int  # Track total number of non-padding tokens during training
    val_reward: NotRequired[
        float
    ]  # Optional field - may not be present during training


def _default_grpo_save_state() -> GRPOSaveState:
    return {
        "consumed_samples": 0,
        "current_step": 0,
        "current_epoch": 0,
        "total_steps": 0,
        "total_valid_tokens": 0,
        "val_reward": -99999999.0,
    }


class GRPOLoggerConfig(LoggerConfig):
    num_val_samples_to_print: int  # number of val samples to print to stdout


class MasterConfig(TypedDict):
    policy: PolicyConfig
    loss_fn: ClippedPGLossConfig
    env: dict[str, Any]
    data: DataConfig
    grpo: GRPOConfig
    logger: GRPOLoggerConfig
    cluster: ClusterConfig
    checkpointing: CheckpointingConfig

We could consider using a dataclass to host these options instead, for example:

@dataclass
class GRPOConfig:
    num_prompts_per_step: int
    num_generations_per_prompt: int
    max_num_epochs: int
    max_num_steps: int
    # ...

This has several benefits. It is friendlier to type checkers such as pyrefly and mypy, and it eases development: for example, you can access config.logger.log_dir directly, with type hints and completion in the IDE/editor, instead of typing config["logger"]["log_dir"]. We could also add runtime validation logic in the dataclass's __post_init__ method.
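
To make the attribute access and validation points concrete, here is a minimal sketch using only the standard library. It is not the actual nemo-rl code: the trimmed-down field set, the default values, the mapping of NotRequired fields to Optional with a None default, and the specific validation rules are all assumptions chosen for illustration.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class LoggerConfig:
    # Hypothetical default; the real LoggerConfig has more fields.
    log_dir: str = "logs"


@dataclass
class GRPOConfig:
    num_prompts_per_step: int
    num_generations_per_prompt: int
    use_dynamic_sampling: bool = False
    # Assumption: a NotRequired[int] key becomes an Optional field defaulting to None.
    dynamic_sampling_max_gen_batches: Optional[int] = None

    def __post_init__(self) -> None:
        # Illustrative runtime checks; these rules are assumptions, not project policy.
        if self.num_prompts_per_step <= 0:
            raise ValueError("num_prompts_per_step must be positive")
        if self.use_dynamic_sampling and self.dynamic_sampling_max_gen_batches is None:
            raise ValueError(
                "dynamic_sampling_max_gen_batches is required "
                "when use_dynamic_sampling is enabled"
            )


@dataclass
class MasterConfig:
    grpo: GRPOConfig
    logger: LoggerConfig = field(default_factory=LoggerConfig)


config = MasterConfig(
    grpo=GRPOConfig(num_prompts_per_step=32, num_generations_per_prompt=16)
)
# Attribute access is checked by mypy/pyrefly and completed by the editor.
print(config.logger.log_dir)
print(config.grpo.num_prompts_per_step)
```

With this shape, a misspelled attribute such as config.grpo.num_prompt_per_step is reported by the type checker, and an invalid combination of options fails at construction time rather than deep inside the training loop.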

A dataclass instance could be created from an OmegaConf config using the OmegaConf.to_object method.
