14 changes: 14 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -320,11 +320,25 @@ Tensor& remainder_(Tensor& self, const Tensor& other) {
}

Tensor& floor_divide_out(Tensor& result, const Tensor& self, const Tensor& other) {
+  TORCH_WARN_ONCE(
+    "floor_divide is deprecated, and will be removed in a future version of PyTorch. "
+    "It currently rounds toward 0 (like the 'trunc' function, NOT 'floor'). "
+    "This results in incorrect rounding for negative values.\n"
+    "To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), "
+    "or for actual floor division, use torch.div(a, b, rounding_mode='floor')."
+  );
// FIXME: Not actually doing floor division (#43874)
return div_trunc_out(self, other, result);
}

Tensor floor_divide(const Tensor& self, const Tensor& other) {
+  TORCH_WARN_ONCE(
+    "floor_divide is deprecated, and will be removed in a future version of PyTorch. "
+    "It currently rounds toward 0 (like the 'trunc' function, NOT 'floor'). "
+    "This results in incorrect rounding for negative values.\n"
+    "To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), "
+    "or for actual floor division, use torch.div(a, b, rounding_mode='floor')."
+  );
Tensor result;
auto iter = TensorIterator::binary_op(result, self, other);
div_trunc_stub(iter.device_type(), iter);
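The distinction the new warning draws only shows up for negative quotients. A minimal sketch of the two suggested replacements, assuming a PyTorch build where `torch.div` accepts `rounding_mode` (the API this deprecation points to):

```python
import torch

a = torch.tensor([-7, 7])
b = torch.tensor([2, 2])

# 'trunc' rounds toward zero (the behavior the deprecated floor_divide
# actually implements).
print(torch.div(a, b, rounding_mode='trunc'))  # tensor([-3,  3])

# 'floor' rounds toward negative infinity (what Python's // does).
print(torch.div(a, b, rounding_mode='floor'))  # tensor([-4,  3])
print(-7 // 2)  # -4
```

The two modes agree whenever the quotient is non-negative, which is why many callers can migrate to either without a behavior change.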
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Indexing.cu
@@ -203,7 +203,7 @@ void index_put_accum_kernel(Tensor & self, const c10::List<c10::optional<Tensor>
using device_ptr = thrust::device_ptr<int64_t>;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  linearIndex.floor_divide_(sliceSize);
+  linearIndex.divide_(sliceSize, "trunc");
{
sorted_indices.copy_(linearIndex);
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
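Swapping `floor_divide_` for a truncating `divide_` is behavior-preserving here because linear indices are non-negative, and truncation and flooring coincide for non-negative operands. A quick Python check of that claim, with illustrative values:

```python
import torch

linear_index = torch.arange(24)  # linear indices are never negative
slice_size = 4

# For non-negative dividends, 'trunc' and 'floor' give identical results.
assert torch.equal(
    torch.div(linear_index, slice_size, rounding_mode='trunc'),
    torch.div(linear_index, slice_size, rounding_mode='floor'),
)
```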
2 changes: 1 addition & 1 deletion aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
@@ -138,7 +138,7 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) {
// broadcasting logic; instead, it will blast the elements from one
// to the other so long as the numel is the same
indicesSlice.copy_(indices1D);
-      indices1D.floor_divide_(self.size(d));
+      indices1D.divide_(self.size(d), "trunc");
indicesSlice.add_(indices1D, -self.size(d));
}
}
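For context, this loop unflattens a linearized sparse index one dimension at a time: the truncating divide strips off the current dimension, and the fused `add_` with a negative alpha recovers that dimension's coordinate. A minimal Python rendering of the same arithmetic, with hypothetical values:

```python
import torch

size_d = 5                                # stands in for self.size(d)
indices1D = torch.tensor([0, 7, 13, 19])  # flattened as rest * size_d + coord

indices_slice = indices1D.clone()         # indicesSlice.copy_(indices1D)
indices1D = torch.div(indices1D, size_d, rounding_mode='trunc')
indices_slice -= indices1D * size_d       # indicesSlice.add_(indices1D, -size_d)

print(indices1D)      # tensor([0, 1, 2, 3]), the remaining flattened index
print(indices_slice)  # tensor([0, 2, 3, 4]), the coordinate along dim d
```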
39 changes: 19 additions & 20 deletions test/jit/test_save_load.py
@@ -90,17 +90,16 @@ def _try_fn(self, fn, *args, **kwargs):
return e

def _verify_no(self, kind, m):
-        node_count = sum(kind in str(n) for n in m.graph.nodes())
-        self.assertEqual(node_count, 0)
+        self._verify_count(kind, m, 0)

def _verify_count(self, kind, m, count):
-        node_count = sum(kind in str(n) for n in m.graph.nodes())
+        node_count = sum(str(n).count(kind) for n in m.graph.nodes())
self.assertEqual(node_count, count)

"""
Tests that verify Torchscript remaps aten::div(_) from versions 0-3
to call either aten::true_divide(_), if an input is a float type,
-    or aten::floor_divide(_) otherwise.
+    or truncated aten::divide(_) otherwise.

NOTE: currently compares against current div behavior, too, since
div behavior has not yet been updated.
@@ -110,7 +109,7 @@ def test_versioned_div_tensor(self):
def historic_div(self, other):
if self.is_floating_point() or other.is_floating_point():
return self.true_divide(other)
-            return self.floor_divide(other)
+            return self.divide(other, rounding_mode='trunc')

# Tensor x Tensor
class MyModule(torch.nn.Module):
@@ -130,8 +129,8 @@ def forward(self, a, b):
except Exception as e:
self.skipTest("Failed to load fixture!")

self._verify_count("aten::div", v3_module, 3) # true_divide aliases to div
self._verify_count("aten::floor_divide", v3_module, 3)
self._verify_count("aten::div", v3_module, 6) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument

current_module = self._save_load_module(MyModule)
self._verify_count("aten::div", current_module, 3)
@@ -158,7 +157,7 @@ def test_versioned_div_tensor_inplace(self):
def historic_div_(self, other):
if self.is_floating_point() or other.is_floating_point():
return self.true_divide_(other)
-            return self.floor_divide_(other)
+            return self.divide_(other, rounding_mode='trunc')

class MyModule(torch.nn.Module):
def __init__(self):
@@ -173,8 +172,8 @@ def forward(self, a, b):
except Exception as e:
self.skipTest("Failed to load fixture!")

self._verify_count("aten::div", v3_module, 1) # true_divide aliases to div
self._verify_count("aten::floor_divide", v3_module, 1)
self._verify_count("aten::div", v3_module, 2) # true_divide and divide both alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument

current_module = self._save_load_module(MyModule)
self._verify_count("aten::div", current_module, 1)
@@ -204,7 +203,7 @@ def test_versioned_div_tensor_out(self):
def historic_div_out(self, other, out):
if self.is_floating_point() or other.is_floating_point() or out.is_floating_point():
return torch.true_divide(self, other, out=out)
-            return torch.floor_divide(self, other, out=out)
+            return torch.divide(self, other, out=out, rounding_mode='trunc')

class MyModule(torch.nn.Module):
def __init__(self):
@@ -218,8 +217,8 @@ def forward(self, a, b, out):
except Exception as e:
self.skipTest("Failed to load fixture!")

self._verify_count("aten::div", v3_module, 1) # true_divide aliases to div
self._verify_count("aten::floor_divide", v3_module, 1)
self._verify_count("aten::div", v3_module, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', v3_module, 1) # rounding_mode argument

current_module = self._save_load_module(MyModule)
self._verify_count("aten::div", current_module, 1)
@@ -254,7 +253,7 @@ def historic_div_scalar_float(self, other: float):
def historic_div_scalar_int(self, other: int):
if self.is_floating_point():
return torch.true_divide(self, other)
-            return torch.floor_divide(self, other)
+            return torch.divide(self, other, rounding_mode='trunc')

class MyModuleFloat(torch.nn.Module):
def __init__(self):
Expand All @@ -277,8 +276,8 @@ def forward(self, a, b: int):
self.skipTest("Failed to load fixture!")

for m in (v3_module_float, v3_module_int):
self._verify_count("aten::div", m, 1) # true_divide aliases to div
self._verify_count("aten::floor_divide", m, 1)
self._verify_count("aten::div", m, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument

current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
@@ -314,7 +313,7 @@ def historic_div_scalar_float_reciprocal(self, other: float):
def historic_div_scalar_int_reciprocal(self, other: int):
if self.is_floating_point():
return other / self
-            return other // self
+            return torch.divide(other, self, rounding_mode='trunc')

class MyModuleFloat(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -385,7 +384,7 @@ def historic_div_scalar_int_inplace(self, other: int):
if self.is_floating_point():
return self.true_divide_(other)

-            return self.floor_divide_(other)
+            return self.divide_(other, rounding_mode='trunc')

class MyModuleFloat(torch.nn.Module):
def __init__(self):
@@ -410,8 +409,8 @@ def forward(self, a, b: int):
self.skipTest("Failed to load fixture!")

for m in (v3_module_float, v3_module_int):
self._verify_count("aten::div", m, 1) # true_divide aliases to div
self._verify_count("aten::floor_divide", m, 1)
self._verify_count("aten::div_", m, 2) # true_divide and divide alias to div
self._verify_count('prim::Constant[value="trunc"]', m, 1) # rounding_mode argument

current_module_float = self._save_load_module(MyModuleFloat)
current_module_int = self._save_load_module(MyModuleInt)
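The `_verify_count` change above swaps membership testing for substring counting: `kind in str(n)` contributes at most one per node, while `str(n).count(kind)` tallies every occurrence, which matters once a single remapped node string can mention an operator more than once. A small illustration with hypothetical node strings:

```python
# Hypothetical graph-node strings, for illustration only.
nodes = [
    '%r : Tensor = aten::div(%a, %b)',
    '%s : Tensor = prim::CallFunction(%f, %a, %b)  # calls aten::div, aten::div_',
]

kind = 'aten::div'
print(sum(kind in n for n in nodes))      # 2: one hit per matching node
print(sum(n.count(kind) for n in nodes))  # 3: every occurrence counted
```

Counting with `str.count` also picks up prefixed matches (for example, "aten::div" occurs inside "aten::div_" and "aten::divide"), so the expected counts have to account for that.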
96 changes: 58 additions & 38 deletions test/test_binary_ufuncs.py
@@ -442,7 +442,8 @@ def test_inplace_dunders(self, device):
t -= 1
t *= 1
t /= 1
-        t //= 1
+        with self.maybeWarnsRegex(UserWarning, 'floor_divide'):
+            t //= 1
t %= 1
self.assertEqual(expected, t.data_ptr())

@@ -801,6 +802,17 @@ def test_add_with_tail(self, device, dtype):
def test_cross_device_binary_ops(self, devices):
vals = (1., (2.,))
cpu_tensor = torch.randn(2, 2)

+        def do_test(op, a, b):
+            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
+                op(a, b)
+            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
+                op(b, a)
+            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
+                op(a, cpu_tensor)
+            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
+                op(cpu_tensor, a)

for op in (operator.add, torch.add,
operator.sub, torch.sub,
operator.mul, torch.mul,
@@ -810,14 +822,11 @@ def test_cross_device_binary_ops(self, devices):
a = torch.tensor(a, device=devices[0])
b = torch.tensor(b, device=devices[1])

-            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
-                op(a, b)
-            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
-                op(b, a)
-            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
-                op(a, cpu_tensor)
-            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
-                op(cpu_tensor, a)
+            if op in (operator.floordiv, torch.floor_divide):
+                with self.maybeWarnsRegex(UserWarning, "floor_divide"):
+                    do_test(op, a, b)
+            else:
+                do_test(op, a, b)

# This test ensures that a scalar Tensor can be safely used
# in a binary operation in conjunction with a Tensor on all
@@ -871,8 +880,9 @@ def _scalar_helper(python_op, torch_op):

_scalar_helper(operator.truediv, operator.truediv)
_scalar_helper(operator.truediv, torch.true_divide)
-        _scalar_helper(lambda a, b: math.trunc(a / b), operator.floordiv)
-        _scalar_helper(lambda a, b: math.trunc(a / b), torch.floor_divide)
+        with self.maybeWarnsRegex(UserWarning, 'floor_divide'):
+            _scalar_helper(lambda a, b: math.trunc(a / b), operator.floordiv)
+            _scalar_helper(lambda a, b: math.trunc(a / b), torch.floor_divide)

# NOTE: torch.floor_divide currently truncates instead of flooring.
# See https://github.com/pytorch/pytorch/issues/43874.
@@ -902,7 +912,8 @@ def _wrapped_floordiv(a, b):
b_t = torch.tensor(b, device=device)

self.assertEqual(scripted_div(a_t, b_t), expected_div)
-            self.assertEqual(scripted_floordiv(a_t, b_t), expected_truncdiv)
+            with self.maybeWarnsRegex(UserWarning, 'floor_divide'):
+                self.assertEqual(scripted_floordiv(a_t, b_t), expected_truncdiv)

# Creates jitted functions of one tensor
def _wrapped_div_scalar(a):
@@ -932,20 +943,22 @@ def _wrapped_rfloordiv_scalar(a):
a_t = torch.tensor(a, device=device)

self.assertEqual(a / 5, scripted_div_scalar(a_t))
-            self.assertEqual(math.trunc(a / 5), scripted_floordiv_scalar(a_t))
+            with self.maybeWarnsRegex(UserWarning, 'floor_divide'):
+                self.assertEqual(math.trunc(a / 5), scripted_floordiv_scalar(a_t))

# Skips zero divisors
if a == 0:
continue

self.assertEqual(5 / a, scripted_rdiv_scalar(a_t))

-            # Handles Issue 45199 (see comment above)
-            if a_t.is_floating_point():
-                with self.assertRaises(RuntimeError):
-                    scripted_rfloordiv_scalar(a_t)
-            else:
-                self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t))
+            with self.maybeWarnsRegex(UserWarning, 'floor_divide'):
+                # Handles Issue 45199 (see comment above)
+                if a_t.is_floating_point():
+                    with self.assertRaises(RuntimeError):
+                        scripted_rfloordiv_scalar(a_t)
+                else:
+                    self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t))

# NOTE: torch.floor_divide currently truncates instead of flooring
# the quotient. See https://github.com/pytorch/pytorch/issues/43874.
@@ -1039,23 +1052,26 @@ def _wrapped_ifloordiv_scalar(a):
if not a_t.is_floating_point() and b_t.is_floating_point():
# Inplace modification fails because a float tensor is required
# if the divisor is a float tensor
-                with self.assertRaises(RuntimeError):
+                with self.assertRaises(RuntimeError), self.maybeWarnsRegex(UserWarning, "floor_divide"):
a_t.clone().floor_divide_(b_t)
-                with self.assertRaises(RuntimeError):
+                with self.assertRaises(RuntimeError), self.maybeWarnsRegex(UserWarning, "floor_divide"):
scripted_floor_divide_tensor(a_t.clone(), b_t)
tmp = a_t.clone()
-                with self.assertRaises(RuntimeError):
+                with self.assertRaises(RuntimeError), self.maybeWarnsRegex(UserWarning, "floor_divide"):
tmp //= b_t
else:
# Inplace modification is OK when both or neither tensor is
# a float tensor
-                self.assertEqual(a_t.clone().floor_divide_(b_t).item(), expected_itruncdiv)
-                self.assertEqual(scripted_floor_divide__tensor(a_t.clone(), b_t).item(), expected_itruncdiv)
+                with self.maybeWarnsRegex(UserWarning, "floor_divide"):
+                    self.assertEqual(a_t.clone().floor_divide_(b_t).item(), expected_itruncdiv)
+                    self.assertEqual(scripted_floor_divide__tensor(a_t.clone(), b_t).item(), expected_itruncdiv)
tmp = a_t.clone()
-                tmp //= b_t
+                with self.maybeWarnsRegex(UserWarning, "floor_divide"):
+                    tmp //= b_t
self.assertEqual(tmp.item(), expected_itruncdiv)

-            self.assertEqual(scripted_floor_divide__scalar(a_t), math.trunc(a / 5))
+            with self.maybeWarnsRegex(UserWarning, "floor_divide"):
+                self.assertEqual(scripted_floor_divide__scalar(a_t), math.trunc(a / 5))

# Tests binary op equivalence with Python builtin ops
# Also tests that reverse operations are equivalent to forward ops
@@ -1431,7 +1447,8 @@ def test_floor_divide_tensor(self, device, dtype):
x = torch.randn(10, device=device).mul(30).to(dtype)
y = torch.arange(1, 11, dtype=dtype, device=device)

-        z = x // y
+        with self.maybeWarnsRegex(UserWarning, "floor_divide"):
+            z = x // y
z_alt = torch.trunc(x.double() / y.double()).to(dtype)

self.assertEqual(z.dtype, x.dtype)
@@ -1442,7 +1459,8 @@ def test_floor_divide_scalar(self, device, dtype):
def test_floor_divide_scalar(self, device, dtype):
x = torch.randn(100, device=device).mul(10).to(dtype)

-        z = x // 3
+        with self.maybeWarnsRegex(UserWarning, "floor_divide"):
+            z = x // 3
z_alt = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=x.dtype, device=device)

self.assertEqual(z.dtype, x.dtype)
@@ -1456,17 +1474,18 @@ def test_floor_divide_out(self, device, dtype):
y = torch.arange(1, 11, dtype=dtype, device=device)
o = torch.empty(10, dtype=dtype, device=device)

-        torch.floor_divide(x, y, out=o)
-        self.assertEqual(o, x // y)
+        with self.maybeWarnsRegex(UserWarning, "floor_divide"):
+            torch.floor_divide(x, y, out=o)
+            self.assertEqual(o, x // y)

-        # Tests scalar with out
-        torch.floor_divide(x, 2, out=o)
-        self.assertEqual(o, x // 2)
+            # Tests scalar with out
+            torch.floor_divide(x, 2, out=o)
+            self.assertEqual(o, x // 2)

-        if dtype == torch.int:
-            o = torch.empty(10, dtype=torch.float, device=device)
-            torch.floor_divide(x, y, out=o)
-            self.assertEqual(o, torch.floor_divide(x.float(), y.float()))
+            if dtype == torch.int:
+                o = torch.empty(10, dtype=torch.float, device=device)
+                torch.floor_divide(x, y, out=o)
+                self.assertEqual(o, torch.floor_divide(x.float(), y.float()))

@onlyCPU
@dtypes(*torch.testing.get_all_math_dtypes('cpu'))
@@ -1720,7 +1739,8 @@ def test_floor_divide_zero(self, device, dtype):
a = torch.tensor([0, 1], dtype=dtype, device=device)
b = torch.tensor([0, 1], dtype=dtype, device=device)
with self.assertRaisesRegex(RuntimeError, 'ZeroDivisionError'):
-            a // b
+            with self.maybeWarnsRegex(UserWarning, "floor_divide"):
+                a // b

@unittest.skipIf(TEST_WITH_ASAN, "Integer overflows are not allowed under ASAN")
@dtypes(*torch.testing.get_all_dtypes())
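A note on why these tests wrap floor division in `maybeWarnsRegex` rather than asserting the warning unconditionally: the C++ side emits the deprecation through `TORCH_WARN_ONCE`, so it appears at most once per process, and any individual call may legitimately be silent. A sketch of that behavior, assuming a fresh interpreter:

```python
import warnings

import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # Python-side filters cannot re-arm it
    torch.floor_divide(torch.tensor([7]), torch.tensor([2]))
    torch.floor_divide(torch.tensor([7]), torch.tensor([2]))

# TORCH_WARN_ONCE tracks "already warned" in C++, so only the first
# call produces a UserWarning here.
print(len(caught))  # 1
```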
5 changes: 3 additions & 2 deletions test/test_jit.py
@@ -5510,7 +5510,7 @@ def fn(x):
def test_integral_shape_inference(self):
cu = torch.jit.CompilationUnit('''
def test_integral_shape_inference(a):
-            return a * a
+            return a * a
''')
inputs = [torch.ones(10, 10, dtype=torch.long)]
outputs = torch.ones(10, 10)
@@ -11861,7 +11861,8 @@ def fn():
cu = torch.jit.CompilationUnit(funcs_str)
f_script = cu.fn
f = scope['fn']
-            self.assertEqual(f_script(), f())
+            with self.maybeWarnsRegex(UserWarning, "floor_divide"):
+                self.assertEqual(f_script(), f())

def test_call_python_fn_from_script_fn(self):
@torch.jit.ignore