[Trainer] accelerate context parallel support in Trainer #40205
Conversation
src/transformers/trainer.py
Outdated
self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
self.is_cp_enabled = (
    getattr(self.accelerator.state, "parallelism_config", None) is not None
    and getattr(self.accelerator.state.parallelism_config, "cp_size", 1) > 1
Should we only rely on parallelism_config to configure CP?
src/transformers/trainer.py
Outdated
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
self.is_cp_enabled = (
It would be great to just use self.parallelism_config = getattr(self.accelerator, "parallelism_config", None), and also to have a ref for parallelism_config in TrainerState.
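For illustration, a minimal sketch of that suggestion (attribute names follow the diffs in this thread; this is not the merged code):

```python
# Sketch: store the accelerate parallelism config once on the Trainer and
# derive the CP flag from it, instead of re-reading accelerator.state each time.
self.parallelism_config = getattr(self.accelerator, "parallelism_config", None)
self.is_cp_enabled = (
    self.parallelism_config is not None and getattr(self.parallelism_config, "cp_size", 1) > 1
)
```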
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
src/transformers/training_args.py
Outdated
if not self.fsdp:
    from accelerate.utils import FullyShardedDataParallelPlugin

    self.fsdp_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        auto_wrap_policy="transformer_based_wrap",
        state_dict_type="FULL_STATE_DICT",
    )
else:
    # Ensure FSDP v2 is used when context parallelism is enabled
    if self.fsdp_config.get("version", 1) != 2:
        logger.warning("Context parallelism requires FSDP v2. Updating FSDP config to use version 2.")
        self.fsdp_config["version"] = 2
Shouldn't it warn the user when it enables FSDP without explicit user configuration?
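As a rough sketch of the suggested behavior (reusing the plugin arguments from the diff above; the warning wording is illustrative, not the merged code):

```python
# Warn when FSDP v2 is turned on implicitly because context parallelism was
# requested without any user-provided FSDP configuration.
if not self.fsdp:
    logger.warning(
        "Context parallelism requires FSDP v2, but no FSDP configuration was provided. "
        "Enabling FSDP v2 with default settings."
    )
    from accelerate.utils import FullyShardedDataParallelPlugin

    self.fsdp_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        auto_wrap_policy="transformer_based_wrap",
        state_dict_type="FULL_STATE_DICT",
    )
```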
src/transformers/trainer.py
Outdated
    and num_items_in_batch is not None
):
    loss *= self.accelerator.num_processes
# if (
TODO: Need to understand why we need this realistically.
@SunMarc I have fixed the issues you raised.
src/transformers/trainer.py
Outdated
logger.info(f"Saving model checkpoint to {output_dir}")

# Defer to accelerate's get_state_dict when using distributed setups that require special state dict handling
if state_dict is None and (self.is_fsdp2 or self.is_deepspeed_enabled):
We don't need this at all. save_pretrained works with torch parallelism just fine. I suppose we do want to keep this for non-transformers models only?
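A possible shape for keeping the branch only for custom models (a sketch under the assumption that is_fsdp2 and get_state_dict behave as in the diff above; not the merged code):

```python
from transformers import PreTrainedModel

# Only fall back to accelerate's state-dict gathering for models that
# save_pretrained cannot already handle, i.e. custom non-PreTrainedModel models.
unwrapped = self.accelerator.unwrap_model(self.model)
if state_dict is None and not isinstance(unwrapped, PreTrainedModel) and (self.is_fsdp2 or self.is_deepspeed_enabled):
    state_dict = self.accelerator.get_state_dict(self.model)
```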
Failing tests seem unrelated.
SunMarc left a comment
Thanks! A few nits but overall LGTM.
src/transformers/trainer.py
Outdated
if state_dict is None and (getattr(self.accelerator, "is_fsdp2", False) or self.is_deepspeed_enabled):
    state_dict = self.accelerator.get_state_dict(self.model)
Is there an issue with how things are currently handled? Just to better understand.
I think it would just silently fail at this point, but it's with custom models, which is a rather rare use-case.
if (
    getattr(self.accelerator, "parallelism_config") is not None
    and self.accelerator.parallelism_config.cp_enabled
):
We still potentially need to fix that, but otherwise we can do it in a follow-up.
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
On it, @SunMarc.
Am I missing something or was this feature merged w/o adding any tests? I'm working on an HF Trainer integration PR for ALST/UlyssesSP via huggingface/accelerate#3817 and I was hoping to have some existing CP tests I could extend/copy, but I can't find any. How will you know if this feature breaks if you have no tests? The Accelerate side doesn't test most of this feature either. I'm puzzled.
CI test PR #41860

What does this PR do?
Add support for context parallelism in the Trainer.
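For context, a minimal usage sketch. The ParallelismConfig import path and the Accelerator(parallelism_config=...) keyword are assumptions about the accelerate API, not confirmed by this thread; the Trainer-side check mirrors the first diff above.

```python
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig  # assumed import path

# Configure context parallelism on the accelerate side: shard the sequence
# dimension of attention across 2 ranks.
pc = ParallelismConfig(cp_size=2)
accelerator = Accelerator(parallelism_config=pc)  # assumed keyword argument

# Trainer-side detection then reduces to the check shown at the top of this thread.
is_cp_enabled = (
    getattr(accelerator.state, "parallelism_config", None) is not None
    and getattr(accelerator.state.parallelism_config, "cp_size", 1) > 1
)
```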