Add functions to emit custom call to place a buffer to host and device. #8350

Merged
tengyifei merged 2 commits into master from hanq_host_offload on Nov 4, 2024
Conversation

@qihqi (Collaborator) commented on Nov 1, 2024

This is used for host offloading.

Example code, and the StableHLO that JAX emits for it:
```python
def policy(prim, *avals, **params) -> Offloadable:
  return Offloadable(src='device', dst='pinned_host')

@functools.partial(jax.remat, policy=policy)
def f(x):
  x = jnp.sin(x)
  x = jnp.sin(x)
  return jnp.sum(x)
```

becomes:
```mlir
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<16xf32> {mhlo.layout_mode = "default"}) -> (tensor<16xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.sine %arg0 : tensor<16xf32>
    %1 = stablehlo.cosine %arg0 : tensor<16xf32>
    %2 = stablehlo.custom_call @annotate_device_placement(%1) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
    %3 = stablehlo.cosine %0 : tensor<16xf32>
    %4 = stablehlo.custom_call @annotate_device_placement(%3) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "pinned_host"}} : (tensor<16xf32>) -> tensor<16xf32>
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %5:3 = stablehlo.optimization_barrier %2, %4, %cst : tensor<16xf32>, tensor<16xf32>, tensor<f32>
    %6 = stablehlo.custom_call @annotate_device_placement(%5#0) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "device"}} : (tensor<16xf32>) -> tensor<16xf32>
    %7 = stablehlo.custom_call @annotate_device_placement(%5#1) {has_side_effect = true, mhlo.frontend_attributes = {_xla_buffer_placement = "device"}} : (tensor<16xf32>) -> tensor<16xf32>
    %8 = stablehlo.broadcast_in_dim %5#2, dims = [] : (tensor<f32>) -> tensor<16xf32>
    %9 = stablehlo.multiply %8, %7 : tensor<16xf32>
    %10 = stablehlo.multiply %9, %6 : tensor<16xf32>
    return %10 : tensor<16xf32>
  }
}
```
The `cosine` results are the residuals the remat policy offloads: each is tagged with a `@annotate_device_placement` custom call carrying the frontend attribute `_xla_buffer_placement = "pinned_host"`, then tagged back to `"device"` after the `optimization_barrier`, right before it is consumed.
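For comparison, here is a minimal usage sketch from the PyTorch/XLA side. The helper names `place_to_host` / `place_to_device` and the module path `torch_xla.experimental.stablehlo_custom_call` are assumptions inferred from this PR's title and the test file it touches, not confirmed on this page:

```python
# Hypothetical sketch, not the confirmed API: assumes this PR exposes
# place_to_host / place_to_device, each emitting a
# stablehlo.custom_call @annotate_device_placement with the
# _xla_buffer_placement frontend attribute, as in the MLIR above.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.stablehlo_custom_call import (  # assumed module path
    place_to_host, place_to_device)

x = torch.randn(16, device=xm.xla_device())
residual = torch.cos(x)               # an intermediate worth offloading
residual = place_to_host(residual)    # annotate buffer for pinned host memory
y = torch.sin(x)                      # device compute continues meanwhile
residual = place_to_device(residual)  # annotate buffer to move back to device
out = (y * residual).sum()            # consume the fetched residual on device
```

This mirrors the JAX lowering above: tensors are tagged `pinned_host` while saved and re-tagged `device` just before use, and XLA's memory-placement passes act on the annotations.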

@qihqi requested a review from @tengyifei on November 1, 2024
@qihqi force-pushed the hanq_host_offload branch from 96f0d78 to fc408f2 on November 1, 2024
Review thread on test/stablehlo/test_stablehlo_custom_call.py (outdated)
@tengyifei (Collaborator)

Need to format the Python files to pass the linter.

@tengyifei merged commit a19a996 into master on Nov 4, 2024