
[Functionalization] Enable FSDP#4691

Merged
alanwaketan merged 3 commits into functionalization from alanwaketn/fsdp_func
Feb 25, 2023

Conversation

@alanwaketan
Collaborator

@alanwaketan alanwaketan commented Feb 24, 2023

Summary:
This pull request enables FSDP by replacing .set_ with our own _replace_xla_tensor API. The reason is that the functionalization pass reapplies the new value to all of the tensor's aliases, since .set_ is an in-place op. However, that reapplication assumes the source and the destination share the same number of elements (view_copy), and .set_ doesn't follow this rule.

P.S. It also removes two .data tests that are no longer applicable.

Test Plan:
CI.
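The element-count assumption described above can be illustrated with a toy sketch. This is pure Python, not the actual torch_xla or functionalization code; the error string mirrors the shape check that fires in torch_xla's data_ops.cpp (see the crash log below), and replay_onto_alias is a hypothetical stand-in for the alias-regeneration step.

```python
# Hedged, illustrative sketch: why the functionalization pass's alias replay
# rejects .set_ when element counts differ. After an in-place op, each alias
# is regenerated from the updated base via a view_copy-style reshape, which
# only works when source and destination hold the same number of elements.
from math import prod

def replay_onto_alias(src_shape, alias_shape):
    src_count = prod(src_shape)
    alias_count = prod(alias_shape)
    if src_count != alias_count:
        # Mirrors the shape check in torch_xla's data_ops.cpp.
        raise RuntimeError(f"Check failed: {src_count} vs. {alias_count}")
    return alias_shape

# FSDP frees a full parameter by .set_()-ing it to a 1-element placeholder,
# but the parameter's aliases still expect the full element count.
try:
    replay_onto_alias((1,), (2048000,))
except RuntimeError as e:
    print(e)  # Check failed: 1 vs. 2048000
```

A same-element-count replacement (the view_copy case) passes the check; .set_ to a differently sized placeholder does not, which is why it needs a dedicated API.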

@alanwaketan
Collaborator Author

@bdhirsh Hit a weird crash while running .set_ in our ResNet FSDP:

root@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=CPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
==> Preparing data..
Epoch 1 train begin 04:51:55
Traceback (most recent call last):
  File "test/test_train_mp_imagenet_fsdp.py", line 389, in <module>
    _mp_fn(0, FLAGS)
  File "test/test_train_mp_imagenet_fsdp.py", line 380, in _mp_fn
    accuracy = train_imagenet()
  File "test/test_train_mp_imagenet_fsdp.py", line 351, in train_imagenet
    train_loop_fn(train_device_loader, epoch)
  File "test/test_train_mp_imagenet_fsdp.py", line 322, in train_loop_fn
    loss.backward()
  File "/workspaces/work/pytorch/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/workspaces/work/pytorch/torch/autograd/__init__.py", line 204, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/workspaces/work/pytorch/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspaces/work/pytorch/xla/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py", line 1188, in _post_backward_hook
    self._free_full_params(
  File "/workspaces/work/pytorch/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspaces/work/pytorch/xla/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py", line 1398, in _free_full_params
    p.set_(self._dummy_data_placeholder)
RuntimeError: /workspaces/work/pytorch/xla/torch_xla/csrc/data_ops.cpp:69 : Check failed: total_element_count == xla::util::Multiply<int64_t>(output_sizes) (1 vs. 2048000)
*** Begin stack trace ***
        tsl::CurrentStackTrace[abi:cxx11]()
        torch_xla::GetCompleteShape(absl::lts_20220623::Span<long const>, absl::lts_20220623::Span<long const>)
        torch_xla::tensor_methods::view(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, absl::lts_20220623::Span<long const>)
        torch_xla::XLANativeFunctions::view_copy_symint(at::Tensor const&, c10::ArrayRef<c10::SymInt>)



        at::_ops::view_copy::call(at::Tensor const&, c10::ArrayRef<c10::SymInt>)


        std::function<at::Tensor (at::Tensor const&, long)>::operator()(at::Tensor const&, long) const
        at::FunctionalTensorWrapper::regenerate_from_base()
        at::FunctionalTensorWrapper::sync_()
        at::functionalization::impl::sync(at::Tensor const&)




        at::_ops::set__source_Tensor::redispatch(c10::DispatchKeySet, at::Tensor&, at::Tensor const&)




        at::_ops::set__source_Tensor::redispatch(c10::DispatchKeySet, at::Tensor&, at::Tensor const&)




        at::_ops::set__source_Tensor::call(at::Tensor&, at::Tensor const&)




        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        PyVectorcall_Call
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall

        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        PyVectorcall_Call
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall

        _PyObject_FastCallDict

        PyObject_Call

        torch::autograd::PyFunctionPostHook::operator()(std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&)


        torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&)
        torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&)
        torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool)
        torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool)







        clone
*** End stack trace ***
(2048000) vs. (1)

@JackCaoG
Collaborator

Is resnet failure blocking?

@alanwaketan
Collaborator Author

Is resnet failure blocking?

No, I created a new API, _replace_xla_tensor, to work around it.
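Conceptually, the workaround can be sketched as follows. This is a toy illustration only: the real _replace_xla_tensor lives in torch_xla's C++ layer, and ToyXlaTensor/replace_tensor here are hypothetical names. The point is that a wholesale swap of the backing value sidesteps the in-place-op path that functionalization would replay onto aliases.

```python
# Hedged sketch: swapping the destination's backing data directly, instead of
# routing through an in-place op like .set_ that functionalization replays
# onto every alias with a same-element-count view_copy.
class ToyXlaTensor:
    def __init__(self, values):
        self.values = values  # stand-in for the backing XLA data

def replace_tensor(dst, src):
    # Wholesale swap: no view_copy replay, so dst and src may legitimately
    # have different element counts (e.g. a 1-element placeholder).
    dst.values = list(src.values)
    return dst

param = ToyXlaTensor([0.0] * 2048000)  # flattened full parameter
placeholder = ToyXlaTensor([0.0])      # FSDP's dummy placeholder
replace_tensor(param, placeholder)
print(len(param.values))  # 1
```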

Comment thread test/test_operations.py (Outdated)
self.assertEqual(met.counter_value('DestroyXlaTensor'), 5)

# shouldn't crash
t2.cpu()
Collaborator

nit, can we just do a value check here instead of just calling .cpu?

Collaborator Author


Yeah, let me add that.

@alanwaketan
Collaborator Author

Thanks @JackCaoG for approving the change. Will merge it after the CIs are green.

@bdhirsh The crash on the .set_() is still real. If we can fix that, we can then reuse .set_() again.

@alanwaketan alanwaketan merged commit cf44d6f into functionalization Feb 25, 2023
@bdhirsh
Contributor

bdhirsh commented Feb 27, 2023

Sounds good. I'm having trouble repro'ing that crash without XLA, although I am able to repro at least one issue with .set_():

import torch
from functorch import functionalize

def f(x):
    y = torch.ones(2)
    x.view(-1)
    x.set_(y.storage())
    return x

x = torch.zeros(2)
out = functionalize(f)(x)

fails with:

  File "tmp5.py", line 7, in f
    x.set_(y.storage())
RuntimeError: t.storage().use_count() == 1 INTERNAL ASSERT FAILED at "/scratch/hirsheybar2/work/pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp":197, please report a bug to PyTorch.

Glad that you have a workaround for now!

@alanwaketan
Collaborator Author


Thanks, Brian. Let me see if I can write a repro for you without XLA.

alanwaketan added a commit that referenced this pull request Mar 1, 2023