Skip to content

Commit 60e75ca

Browse files
zhumakhanZhumakhan
authored andcommitted
Setitem with tensor values. And Boolean type promotion (#290)
- setitem with tensor values - boolean type promotion --------- Co-authored-by: Zhumakhan <nazirzhumakhan@gmail,.com>
1 parent d047440 commit 60e75ca

9 files changed

Lines changed: 132 additions & 19 deletions

File tree

python/hidet/graph/frontend/torch/register_functions.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,6 @@ def setitem(x: Tensor, item, setvalue):
269269
if not isinstance(item, tuple):
270270
item = tuple([item])
271271

272-
if not isinstance(setvalue, (int, float)):
273-
raise NotImplementedError('Currently Tensor __setitem__ only supports int or float values')
274-
275272
# now, the item could have
276273
# 1. integer index
277274
# 2. slice
@@ -321,7 +318,11 @@ def setitem(x: Tensor, item, setvalue):
321318
ends.append(v.stop)
322319
steps.append(v.step)
323320

324-
out = ops.set_strided_slice(x, starts, ends, steps, setvalue)
321+
if isinstance(setvalue, Tensor):
322+
squeeze_dims = [i for i, dimlen in enumerate(setvalue.shape) if dimlen == 1]
323+
setvalue = ops.squeeze(setvalue, squeeze_dims)
324+
325+
out = ops.set_strided_slice(x, setvalue, starts, ends, steps)
325326
return out
326327

327328

python/hidet/graph/ops/arithmetic.py

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,11 @@ class SetStridedSliceTask(Task):
192192
def __init__(
193193
self,
194194
data: TensorNode,
195+
setvalue: Union[int, float, TensorNode],
195196
starts: List[Optional[int]],
196197
ends: List[Optional[int]],
197198
axes: List[int],
198199
strides: List[int],
199-
setvalue: [Union[int, float]],
200200
):
201201
assert len(starts) == len(ends) == len(axes) == len(strides)
202202
if len(axes) != len(set(axes)):
@@ -216,17 +216,37 @@ def __init__(
216216
)
217217
axis2info[axis] = (start, end, stride)
218218

219+
if isinstance(setvalue, TensorNode):
220+
inputs = [data, setvalue]
221+
else:
222+
inputs = [data]
223+
219224
def fmap(indices):
220-
ret = data.type.dtype(setvalue)
225+
if isinstance(setvalue, TensorNode):
226+
ret = setvalue
227+
else:
228+
ret = data.type.dtype(setvalue)
229+
230+
new_val = True
231+
new_indices = []
221232
for axis, index in enumerate(indices):
222233
start, end, stride = axis2info[axis]
223-
ret = if_then_else(
224-
logical_or(index < start, index >= end, (index - start) % stride != 0), data[indices], ret
234+
new_val = if_then_else(
235+
logical_or(index < start, index >= end, (index - start) % stride != 0), False, new_val
225236
)
226-
return ret
237+
if start + 1 < end:
238+
new_indices.append((index - start) // stride)
239+
240+
if isinstance(setvalue, TensorNode):
241+
if len(new_indices) != 0:
242+
ret = ret[new_indices]
243+
else:
244+
ret = ret[(0,) * ret.ndim]
245+
246+
return if_then_else(new_val, ret, data[indices])
227247

228248
out = compute('out', shape=output_shape, fcompute=lambda *indices: fmap(indices))
229-
super().__init__(name='set_slice', inputs=[data], outputs=[out])
249+
super().__init__(name='set_slice', inputs=inputs, outputs=[out])
230250

231251

232252
class RollTask(Task):
@@ -702,18 +722,29 @@ class SetStridedSliceOp(Operator):
702722
def __init__(
703723
self,
704724
data: Tensor,
725+
setvalue: Optional[Union[int, float, Tensor]],
705726
starts: Sequence[Optional[int]],
706727
ends: Sequence[Optional[int]],
707728
strides: Optional[Sequence[Optional[int]]] = None,
708-
setvalue: Optional[Union[int, float]] = 0.0,
709729
):
730+
if setvalue is None:
731+
setvalue = 0.0
732+
710733
starts, ends, axes, strides = normalize_slice(data.shape, starts, ends, axes=None, strides=strides)
711-
task = SetStridedSliceTask(input_like(data, 'data'), starts, ends, axes, strides, setvalue)
712-
super().__init__(
713-
inputs=[data],
714-
attributes={'starts': starts, 'ends': ends, 'strides': strides, 'setvalue': setvalue},
715-
task=task,
716-
)
734+
735+
attributes = {'starts': starts, 'ends': ends, 'strides': strides}
736+
737+
if isinstance(setvalue, Tensor):
738+
task = SetStridedSliceTask(
739+
input_like(data, 'data'), input_like(setvalue, 'setvalue'), starts, ends, axes, strides
740+
)
741+
inputs = [data, setvalue]
742+
else:
743+
task = SetStridedSliceTask(input_like(data, 'data'), setvalue, starts, ends, axes, strides)
744+
inputs = [data]
745+
attributes['setvalue'] = setvalue
746+
747+
super().__init__(inputs=inputs, attributes=attributes, task=task)
717748

718749

719750
class RollOp(Operator):
@@ -1070,9 +1101,9 @@ def composite_elementwise(
10701101

10711102
def set_strided_slice(
10721103
data: Tensor,
1104+
setvalue: Optional[Union[int, float]],
10731105
starts: Sequence[Optional[int]],
10741106
ends: Sequence[Optional[int]],
10751107
strides: Optional[Sequence[Optional[int]]] = None,
1076-
setvalue: Optional[Union[int, float]] = 0.0,
10771108
) -> Tensor:
1078-
return SetStridedSliceOp(data, starts, ends, strides, setvalue).outputs[0]
1109+
return SetStridedSliceOp(data, setvalue, starts, ends, strides).outputs[0]

python/hidet/ir/dtypes/boolean.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def is_complex(self) -> bool:
3535
def is_vector(self) -> bool:
3636
return False
3737

38+
def is_boolean(self) -> bool:
39+
return True
40+
3841
def constant(self, value: Any):
3942
from hidet.ir.expr import constant
4043

python/hidet/ir/dtypes/complex.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def is_vector(self) -> bool:
3535
def is_complex(self) -> bool:
3636
return True
3737

38+
def is_boolean(self) -> bool:
39+
return False
40+
3841
def constant(self, value: Any):
3942
from hidet.ir.expr import Constant, constant
4043

python/hidet/ir/dtypes/floats.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ def is_complex(self) -> bool:
5151
def is_vector(self) -> bool:
5252
return False
5353

54+
def is_boolean(self) -> bool:
55+
return False
56+
5457
def constant(self, value: Any):
5558
from hidet.ir.expr import Constant, constant
5659

python/hidet/ir/dtypes/integer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ def is_complex(self) -> bool:
4545
def is_vector(self) -> bool:
4646
return False
4747

48+
def is_boolean(self) -> bool:
49+
return False
50+
4851
def constant(self, value: Any):
4952
from hidet.ir.expr import Constant, constant
5053

python/hidet/ir/dtypes/promotion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def promote_type(t1: DataType, t2: DataType) -> DataType:
131131
return t2
132132
elif t1.is_float() and t2.is_integer():
133133
return t1
134+
elif t1.is_boolean() and t2.is_boolean():
135+
return t1
134136
else:
135137
pair = (t1, t2)
136138
if pair not in _promotion_table:

python/hidet/ir/type.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ def is_complex(self) -> bool:
174174
def is_vector(self) -> bool:
175175
raise NotImplementedError()
176176

177+
def is_boolean(self) -> bool:
178+
raise NotImplementedError()
179+
177180
def constant(self, value: Any):
178181
raise NotImplementedError()
179182

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
import pytest
13+
import torch
14+
from hidet.testing.torch_utils import check_module, FunctionalModule
15+
16+
17+
@pytest.mark.parametrize(
18+
"a_shape, b_shape, indices",
19+
[
20+
[[10, 11, 12, 13], [10, 9, 8], (slice(None), slice(9), slice(8), 0)],
21+
[[5, 4, 3, 2], [2, 2], (slice(2, 4), slice(2, 4), 0, 1)],
22+
[[1, 1, 1024], [10], (0, 0, slice(1000, 1010))],
23+
[[4, 4, 4, 4], [1, 1, 1], (slice(1), 0, slice(1), slice(1))],
24+
[[4, 4, 4, 4], [2, 3], (3, slice(2), 2, slice(3))],
25+
[[4, 4, 4, 4], [2, 3], (slice(2), 0, slice(3), 0)],
26+
[[4, 4, 4, 4], [4, 4], (0, Ellipsis, 0)],
27+
[[10, 10], [10, 10], (Ellipsis,)],
28+
[[1, 3, 28, 28, 85], [1, 3, 28, 28, 2], (Ellipsis, slice(2))],
29+
],
30+
)
31+
def test_setitem_with_tensor(a_shape, b_shape, indices):
32+
def check_setitem(x, y, indices):
33+
x[indices] = y
34+
return x
35+
36+
check_module(
37+
FunctionalModule(op=check_setitem), args=[torch.randn(a_shape), torch.randn(b_shape), indices], atol=0, rtol=0
38+
)
39+
40+
41+
@pytest.mark.parametrize(
42+
"a_shape, setvalue, indices",
43+
[
44+
[[10, 11, 12, 13], 1.0, (slice(None), slice(9), slice(8), 0)],
45+
[[5, 4, 3, 2], 1.0, (slice(2, 4), slice(2, 4), 0, 1)],
46+
[[1, 1, 1024], 1.0, (0, 0, slice(1000, 1010))],
47+
[[4, 4, 4, 4], 1.0, (slice(1), 0, slice(1), slice(1))],
48+
[[4, 4, 4, 4], 1.0, (3, slice(2), 2, slice(3))],
49+
[[4, 4, 4, 4], 1.0, (slice(2), 0, slice(3), 0)],
50+
[[4, 4, 4, 4], 1.0, (0, Ellipsis, 0)],
51+
[[10, 10], 1.0, (Ellipsis,)],
52+
[[1, 3, 28, 28, 85], 1.0, (Ellipsis, slice(2))],
53+
],
54+
)
55+
def test_setitem_with_scalar(a_shape, setvalue, indices):
56+
def check_setitem(x, setvalue, indices):
57+
x[indices] = setvalue
58+
return x
59+
60+
check_module(FunctionalModule(op=check_setitem), args=[torch.randn(a_shape), setvalue, indices], atol=0, rtol=0)
61+
62+
63+
if __name__ == '__main__':
64+
pytest.main([__file__])

0 commit comments

Comments
 (0)