Skip to content

[Script] Add list comprehension support in hidet script#235

Merged
yaoyaoding merged 9 commits intohidet-org:mainfrom
yaoyaoding:update-hidet-script
May 19, 2023
Merged

[Script] Add list comprehension support in hidet script#235
yaoyaoding merged 9 commits intohidet-org:mainfrom
yaoyaoding:update-hidet-script

Conversation

@yaoyaoding
Copy link
Copy Markdown
Member

Add comprehension expressions for list, set and dict

def test_list_comprehension():
    from hidet.lang import attrs, printf

    with hidet.script_module() as script_module:

        @hidet.script
        def func():
            attrs.func_kind = 'host_kernel'
            bs = 1.0 + 1
            shape = [bs, 3, 224, 224]
            a = [1, 2, 3]
            b = [i + 1 for i in range(3) if i != 2]
            c = [s / (i + 1) for i, s in enumerate(shape)]
            printf("%d %d %d\n", a[0], a[1], a[2])
            printf("%d\n", b[0])
            printf("%f\n", c[0])
            printf("%d\n", len(b))

    func = script_module.build()
    func()
def test_dict_comprehension():
    from hidet.lang import attrs, printf

    with hidet.script_module() as script_module:

        @hidet.script
        def func():
            attrs.func_kind = 'host_kernel'
            bs = 1.0 + 1
            shape = [bs, 3, 224, 224]
            a = {k: v for k, v in enumerate(shape)}
            b = {i: (j + 1) / 3 for i, j in enumerate(range(3)) if j != 2}
            printf("%d %d %d\n", a[0], a[1], a[2])
            printf("%f\n", b[1])
            printf("%d\n", len(b))

    func = script_module.build()
    func()

Allow pass a tuple or list to grid

    Usage 1: specify the grid dimensions with positional arguments.
      for i, j in grid(2, 3):
          printf("%d %d\n", i, j)

      for indices in grid(2, 3):
          printf("%d %d\n", indices[0], indices[1])

      for i in grid(2):
          printf("%d\n", i)

    Usage 2: specify the grid dimensions with a list or tuple.
      for indices in grid([2, 3]):
          printf("%d %d\n", indices[0], indices[1])

      for indices in grid([2]):  # indices is a tuple with one element
          printf("%d %d\n", indices[0])

    Usage 3: specify the loop attribute
      for i, j in grid(2, 3, attrs='up'):   # loop i is unrolled while loop j is parallelized
          printf("%d %d\n", i, j)
def demo_softmax(shape: List[int], axis: int):
    from hidet.lang import f32
    from hidet.lang import attrs
    from hidet.lang import tensor
    from hidet.lang import grid
    import math

    with hidet.script_module() as script_module:

        @hidet.script
        def kernel(x: f32[shape], y: f32[shape]):
            attrs.func_kind = 'host_kernel'
            spatial_shape = shape[:axis] + shape[axis + 1 :]
            reduce_extent = shape[axis]

            max_value = tensor('default', f32, shape=spatial_shape)  # max(x, axis)
            exp_value = tensor('default', f32, shape=shape)  # exp(x - max)
            sum_value = tensor('default', f32, shape=spatial_shape)  # sum(exp(x - max), axis)

            # max value
            for indices in grid(spatial_shape):
                max_value[indices] = -1e10
                for k in range(reduce_extent):
                    max_value[indices] = max(max_value[indices], x[indices[:axis] + (k,) + indices[axis:]])

            # exp(x - max)
            for indices in grid(shape):
                exp_value[indices] = math.exp(x[indices] - max_value[indices[:axis] + indices[axis + 1 :]])

            # sum(exp(x - max))
            for indices in grid(spatial_shape):
                sum_value[indices] = 0.0
                for k in range(reduce_extent):
                    sum_value[indices] += exp_value[indices[:axis] + (k,) + indices[axis:]]

            # exp(x - max) / sum(exp(x - max))
            for indices in grid(shape):
                y[indices] = exp_value[indices] / sum_value[indices[:axis] + indices[axis + 1 :]]

    func = script_module.build()
    x = hidet.randn(shape)
    y1 = hidet.ops.softmax(x, axis)
    y2 = hidet.empty(shape)
    func(x, y2)
    numpy.testing.assert_allclose(y1.numpy(), y2.numpy(), rtol=1e-5, atol=1e-5)

@yaoyaoding yaoyaoding merged commit ff2af86 into hidet-org:main May 19, 2023
@yaoyaoding yaoyaoding deleted the update-hidet-script branch May 19, 2023 04:27
vadiklyutiy pushed a commit that referenced this pull request Jul 22, 2024
Adding support for the operator `torch.as_tensor`, which was encountered
in #221

Also added more tests for `torch.argmax, torch.argmin` as discussed in
#234
vadiklyutiy pushed a commit that referenced this pull request Jul 23, 2024
Adding support for the operator `torch.as_tensor`, which was encountered
in #221

Also added more tests for `torch.argmax, torch.argmin` as discussed in
#234
vadiklyutiy pushed a commit that referenced this pull request Dec 26, 2024
Adding support for the operator `torch.as_tensor`, which was encountered
in #221

Also added more tests for `torch.argmax, torch.argmin` as discussed in
#234
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant