Skip to content

Commit 72cb53f

Browse files
committed
Update on "Slicing with backed should produce backed output when possible "
* `x[0:s1]` where x.size(0) = `s0-1` should produce `Min(s1, s0-1)` * Before this PR, it would produce `u0`. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo [ghstack-poisoned]
2 parents d0d5f4f + 631961b commit 72cb53f

1 file changed

Lines changed: 1 addition & 12 deletions

File tree

test/distributed/tensor/test_dtensor_export.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -543,18 +543,7 @@ def forward(self, x):
543543
%item : [num_users=2] = call_method[target=item](args = (%clamp,), kwargs = {})
544544
%ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%item, 1), kwargs = {})
545545
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 1 on node 'ge_1'), kwargs = {})
546-
%getitem : [num_users=3] = call_function[target=operator.getitem](args = (%l_x_, slice(None, item, None)), kwargs = {})
547-
%getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%getitem, _local_tensor), kwargs = {})
548-
%sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getattr_1, 0), kwargs = {})
549-
%sym_size_int_1 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem, 0), kwargs = {})
550-
%ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
551-
%_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u2 >= 0 on node 'ge_2'), kwargs = {})
552-
%le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 4), kwargs = {})
553-
%_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u2 <= 4 on node 'le'), kwargs = {})
554-
%ge_3 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
555-
%_assert_scalar_default_3 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_3, Runtime assertion failed for expression u1 >= 0 on node 'ge_3'), kwargs = {})
556-
%le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 4), kwargs = {})
557-
%_assert_scalar_default_4 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u1 <= 4 on node 'le_1'), kwargs = {})
546+
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%l_x_, slice(None, item, None)), kwargs = {})
558547
return (getitem,)""", # noqa: B950
559548
)
560549

0 commit comments

Comments
 (0)