Currently, nemo-rl uses TypedDict to host the config options:

    class GRPOConfig(TypedDict):
        num_prompts_per_step: int
        num_generations_per_prompt: int
        max_num_epochs: int
        max_num_steps: int
        max_rollout_turns: int
        normalize_rewards: bool
        use_leave_one_out_baseline: bool
        val_period: int
        val_batch_size: int
        val_at_start: bool
        max_val_samples: int
        seed: int
        async_grpo: NotRequired[AsyncGRPOConfig]
        overlong_filtering: NotRequired[bool]
        # whether to enable dynamic sampling, i.e.
        # whether to discard prompts whose rewards have zero standard deviation
        use_dynamic_sampling: bool
        # When using dynamic sampling, the maximum number of batches to generate
        # before throwing an error
        dynamic_sampling_max_gen_batches: NotRequired[int]
        # When using dynamic sampling, generation prompt batch size will equal
        # num_prompts_per_step * batch_multiplier
        batch_multiplier: NotRequired[float]
        reward_shaping: RewardShapingConfig
        reward_scaling: RewardScalingConfig


    class GRPOSaveState(TypedDict):
        consumed_samples: int
        current_step: int
        current_epoch: int
        total_steps: int
        total_valid_tokens: int  # Track total number of non-padding tokens during training
        val_reward: NotRequired[
            float
        ]  # Optional field - may not be present during training


    def _default_grpo_save_state() -> GRPOSaveState:
        return {
            "consumed_samples": 0,
            "current_step": 0,
            "current_epoch": 0,
            "total_steps": 0,
            "total_valid_tokens": 0,
            "val_reward": -99999999.0,
        }


    class GRPOLoggerConfig(LoggerConfig):
        num_val_samples_to_print: int  # number of val samples to print to stdout


    class MasterConfig(TypedDict):
        policy: PolicyConfig
        loss_fn: ClippedPGLossConfig
        env: dict[str, Any]
        data: DataConfig
        grpo: GRPOConfig
        logger: GRPOLoggerConfig
        cluster: ClusterConfig
        checkpointing: CheckpointingConfig

We could consider using a dataclass to host these options instead, for example:

    from dataclasses import dataclass


    @dataclass
    class GRPOConfig:
        num_prompts_per_step: int
        num_generations_per_prompt: int
        max_num_epochs: int
        max_num_steps: int
        # ...

This has several benefits. It is friendlier to type checkers such as pyrefly and mypy, and it eases development: for example, you can access config.logger.log_dir directly, with type hints and completion in an IDE/editor, instead of typing config["logger"]["log_dir"]. We could also add runtime validation logic in the dataclass's __post_init__ method.
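To make the attribute access and the __post_init__ idea concrete, here is a minimal sketch using only the standard library. It is not the proposed implementation: the field subset is trimmed, the logger fields are flattened into one hypothetical `LoggerConfig` class, and the validation rules are assumptions chosen purely for illustration.

```python
from dataclasses import dataclass


@dataclass
class LoggerConfig:
    # Trimmed, flattened field set (assumption for this sketch).
    log_dir: str
    num_val_samples_to_print: int


@dataclass
class GRPOConfig:
    num_prompts_per_step: int
    num_generations_per_prompt: int
    max_num_steps: int

    def __post_init__(self) -> None:
        # Hypothetical validation rules; the real constraints may differ.
        if self.num_prompts_per_step <= 0:
            raise ValueError("num_prompts_per_step must be positive")
        if self.num_generations_per_prompt <= 0:
            raise ValueError("num_generations_per_prompt must be positive")


@dataclass
class MasterConfig:
    grpo: GRPOConfig
    logger: LoggerConfig


config = MasterConfig(
    grpo=GRPOConfig(
        num_prompts_per_step=32,
        num_generations_per_prompt=16,
        max_num_steps=1000,
    ),
    logger=LoggerConfig(log_dir="logs", num_val_samples_to_print=3),
)

# Attribute access is checked by mypy/pyrefly and completed by the editor,
# unlike config["logger"]["log_dir"] on a plain dict.
log_dir = config.logger.log_dir
```

With this shape, a misspelled attribute such as config.logger.logdir is reported by the type checker, and an invalid value is rejected at construction time rather than deep inside the training loop.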
The dataclass instance could be created from an OmegaConf config using the OmegaConf.to_object method.
The TypedDict definitions quoted above come from RL/nemo_rl/algorithms/grpo.py, lines 115 to 177 in b238e41.