import torch
from functorch.experimental import functionalize
def check(fn, primals_size, primals_dtype, tangents_size, tangents_dtype):
    """Assert that ``fn`` and ``functionalize(fn)`` produce matching outputs.

    Inputs are allocated from the given size/dtype specs and passed to ``fn``
    as ``(*primals, *tangents)``. The same input tensors are used for both
    the reference run and the functionalized run.

    Args:
        fn: Callable taking the primal tensors followed by the tangent
            tensors and returning an iterable of outputs. Individual
            outputs may be ``None``.
        primals_size: Sizes of the primal inputs, one entry per tensor.
        primals_dtype: Dtypes of the primal inputs, parallel to
            ``primals_size``.
        tangents_size: Sizes of the tangent inputs, one entry per tensor.
        tangents_dtype: Dtypes of the tangent inputs, parallel to
            ``tangents_size``.

    Raises:
        AssertionError: If an output pair differs, or if only one side of a
            pair is ``None``.
    """
    primals = [
        torch.empty(size, dtype=dtype)
        for size, dtype in zip(primals_size, primals_dtype)
    ]
    # An int64 second primal is zero-filled instead of being left
    # uninitialized (presumably it is consumed as an index tensor, and zeros
    # are always in range). The guard only needs index 1 to exist, hence
    # ``>= 2`` rather than ``> 2``.
    if len(primals_dtype) >= 2 and primals_dtype[1] == torch.int64:
        primals[1] = torch.zeros(primals_size[1], dtype=primals_dtype[1])
    tangents = [
        torch.empty(size, dtype=dtype)
        for size, dtype in zip(tangents_size, tangents_dtype)
    ]
    inputs = (*primals, *tangents)

    ref = fn(*inputs)
    res = functionalize(fn)(*inputs)

    for expected, actual in zip(ref, res):
        # Non-tensor placeholders cannot go through allclose; they only
        # have to agree on being absent.
        if expected is None or actual is None:
            assert expected is None and actual is None
            continue
        # torch.empty inputs may hold NaNs; both runs see the same inputs,
        # so NaNs in matching positions count as agreement.
        assert torch.allclose(expected, actual, equal_nan=True)
########### Function - zero_ #####################
def fn1(tangents_1):
    """Return a zero-filled ``[1, 3, 2, 10]`` tensor produced by in-place ``zero_``.

    The buffer is allocated with ``new_empty`` called on ``tangents_1``; the
    values stored in ``tangents_1`` are never read.

    Returns:
        A 2-tuple ``(zeroed_buffer, None)``.
    """
    aten = torch.ops.aten
    out_shape = [1, 3, 2, 10]
    # Allocate, then clear in place; the in-place op is kept as-is.
    zeroed = aten.zero_(aten.new_empty(tangents_1, out_shape))
    return zeroed, None
# fn1 takes no primals and a single float32 tangent.
primals_size, primals_dtype = [], []
tangents_size = [torch.Size((1, 3, 2, 1))]
tangents_dtype = [torch.float32]
# Compare the plain and functionalized runs of fn1.
check(fn1, [], [], tangents_size, tangents_dtype)
########### Function - zero_ #####################
def fn2(primals_1, primals_2, tangents_1):
    """Run an exp/gather step and a zero_/scatter_add_/mul step on the inputs.

    Args:
        primals_1: Tensor passed through ``exp``.
        primals_2: Index tensor used along dim 3 for both ``gather`` and
            ``scatter_add_``.
        tangents_1: Tensor scattered into a zeroed ``[1, 3, 2, 10]`` buffer
            that is allocated from it with ``new_empty``.

    Returns:
        A 3-tuple ``(gathered, product, None)``: ``gathered`` is
        ``exp(primals_1)`` gathered at ``primals_2``; ``product`` is the
        scattered buffer multiplied by ``exp(primals_1)``.
    """
    aten = torch.ops.aten
    index_dim = 3

    exp_out = aten.exp(primals_1)
    gathered = aten.gather(exp_out, index_dim, primals_2)

    # Zero a fresh buffer in place, then accumulate the tangent into it in
    # place; both in-place ops are kept as-is.
    accum = aten.zero_(aten.new_empty(tangents_1, [1, 3, 2, 10]))
    accum = aten.scatter_add_(accum, index_dim, primals_2, tangents_1)

    product = aten.mul(accum, exp_out)
    # Disabled alternative return kept from the original:
    # return pytree.tree_unflatten([gather, mul, None], self._out_spec)
    return gathered, product, None
# Input spec for fn2: a float32 primal, an int64 primal, and one float32
# tangent.
primals_size = [torch.Size((1, 3, 2, 10)), torch.Size((1, 3, 2, 1))]
primals_dtype = [torch.float32, torch.int64]
tangents_size = [torch.Size((1, 3, 2, 1))]
tangents_dtype = [torch.float32]
# NOTE(review): the fn2 comparison is disabled here; the reason is not
# recorded in this file -- confirm before re-enabling.
# check(fn2, primals_size, primals_dtype, tangents_size, tangents_dtype)