
Deprecate explicit shape promotion in PyTorch/XLA #6121

Merged
qihqi merged 1 commit into master from sdasgup3/deprecate-implicit-broadcasting
Dec 12, 2023

Conversation

@sdasgup3
Collaborator

Proposal

Currently, both PyTorch/XLA and XLA handle implicit broadcasting for static/bounded-dynamic shapes. The handling in PyTorch/XLA's codebase is not uniform: some operations rely on PyTorch/XLA's own implementation and others on XLA's. The goal of this document is to investigate removing implicit-broadcasting support from PyTorch/XLA so that XLA is the single source of truth.

Current state of Broadcasting in PyTorch/XLA vs XLA

Implicit broadcasting for static and bounded dynamic dimensions is supported in both the PyTorch/XLA and XLA codebases.

PyTorch/XLA

XlaHelpers::Promote is used to perform the implicit broadcasting, via shape-promotion utilities such as XlaHelpers::PromoteShapes.

However, the handling of implicit broadcasting in the PyTorch/XLA codebase is not uniform: some operations promote explicitly (e.g. max and arithmetic ops via XlaHelpers::Promote), while others rely on XlaBuilder's implicit broadcasting (e.g. in Relu, xla::Max is called without any promotion).

XLA

HLO currently handles implicit broadcasting for binary operations, ternary operations, and map, for both static shapes and bounded dynamic shapes, in three steps.

Here are some important points to note:

  1. Currently, PyTorch/XLA calls XlaHelpers::Promote before calling the XlaBuilder APIs for broadcastable operations, ensuring the operand shapes are already promoted and thereby leaving XLA's implicit-broadcasting support unexercised.
  2. Currently, both PyTorch/XLA and XLA support "same rank broadcasting" and "scalar broadcasting" with identical semantics (refer to the Appendix for more information). For "different rank broadcasting", XLA explicitly needs a broadcast_dimensions specification (refer to xla-broadcasting-principles); see the sketch after this list.
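To make point 2 concrete, here is a minimal XlaBuilder sketch (not from this PR; the include paths and parameter setup are illustrative assumptions): same-rank broadcasting over size-1 dimensions is handled implicitly, while different-rank broadcasting needs an explicit broadcast_dimensions.

#include "xla/client/xla_builder.h"
#include "xla/shape_util.h"

void BroadcastExamples(xla::XlaBuilder* b) {
  // Same-rank broadcasting: f32[10,1] * f32[1,5] -> f32[10,5] is expanded
  // implicitly by XlaBuilder (degenerate size-1 dimensions are broadcast).
  xla::XlaOp a = xla::Parameter(
      b, 0, xla::ShapeUtil::MakeShape(xla::F32, {10, 1}), "a");
  xla::XlaOp c = xla::Parameter(
      b, 1, xla::ShapeUtil::MakeShape(xla::F32, {1, 5}), "c");
  xla::Mul(a, c);

  // Different-rank broadcasting: f32[3] against f32[2,3] requires explicit
  // broadcast_dimensions; {1} maps the rank-1 operand's only dimension onto
  // dimension 1 of the rank-2 operand (NumPy-style right alignment).
  xla::XlaOp d = xla::Parameter(
      b, 2, xla::ShapeUtil::MakeShape(xla::F32, {2, 3}), "d");
  xla::XlaOp e = xla::Parameter(
      b, 3, xla::ShapeUtil::MakeShape(xla::F32, {3}), "e");
  xla::Add(d, e, /*broadcast_dimensions=*/{1});
}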

Design Proposal

Given that the HLO/StableHLO emitted via implicit broadcasting in PyTorch/XLA and in XLA are semantically identical (refer to the Appendix), we propose to handle broadcasting only in XLA, with the following changes:

  • Prevent XlaHelpers::PromoteShapes from doing implicit broadcasting.
  • Update each call site of XlaHelpers::Promote so that the subsequent XLA client API calls (e.g. xla::Min, xla::Atan2) are passed explicit broadcast dimensions.
    • Given that PyTorch follows NumPy-style implicit-broadcasting rules, the broadcast_dimensions can be derived from the operand shapes alone:
std::vector<int64_t> XlaHelpers::getBroadcastDimensions(xla::XlaOp op1,
                                                        xla::XlaOp op2) {
  const xla::Shape& shape1 = ShapeHelper::ShapeOfXlaOp(op1);
  const xla::Shape& shape2 = ShapeHelper::ShapeOfXlaOp(op2);
  // Scalars and same-rank operands need no explicit broadcast dimensions;
  // XlaBuilder handles those cases implicitly.
  if (shape1.rank() == 0 || shape2.rank() == 0 ||
      shape1.rank() == shape2.rank())
    return {};

  // Right-align the lower-rank operand with the higher-rank operand
  // (NumPy rule): its dimension i maps to dimension i + rank difference.
  std::vector<int64_t> broadcast_dimensions(
      shape1.rank() <= shape2.rank() ? shape1.rank() : shape2.rank());
  std::iota(broadcast_dimensions.begin(), broadcast_dimensions.end(),
            std::abs(shape1.rank() - shape2.rank()));
  return broadcast_dimensions;
}
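
At a call site, the derived dimensions would then be passed straight to the XLA client op instead of promoting the shapes up front. A hedged sketch (BuildMin is a hypothetical wrapper, not code from this PR):

// Hypothetical call site: rather than XlaHelpers::Promote(input, other),
// derive the broadcast dimensions and let XLA perform the broadcasting.
xla::XlaOp BuildMin(xla::XlaOp input, xla::XlaOp other) {
  return xla::Min(input, other,
                  XlaHelpers::getBroadcastDimensions(input, other));
}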

Appendix

Examples demonstrating the current state

Next we demonstrate the current state of implicit broadcasting in the PyTorch/XLA and XLA codebases. To extract the XLA behavior, we bypass the implicit broadcasting in PyTorch/XLA using a patch that disables its shape promotion.

Same rank broadcasting

PyTorch code
import torch
import torch_xla
from torch_xla.core import xla_model as xm
from typing import Tuple, Type, Callable, Union, List

device = xm.xla_device()

## multiply (same rank broadcasting)
a = torch.randn((10,1)).to(device=device)
b = torch.randn((1, 5)).to(device=device)


c = a * b
hlo_content = torch_xla._XLAC._get_xla_tensors_hlo([c])
print(hlo_content)
print(xm.get_stablehlo([c]))
Broadcasting via PyTorch/XLA
//HLO
HloModule IrToHlo.11, entry_computation_layout={(f32[1,5]{1,0}, f32[10,1]{1,0})->(f32[10,5]{1,0})}

ENTRY %IrToHlo.11 (p0.1: f32[1,5], p1.2: f32[10,1]) -> (f32[10,5]) {
  %p1.2 = f32[10,1]{1,0} parameter(1)
  %broadcast.6 = f32[10,1]{1,0} broadcast(f32[10,1]{1,0} %p1.2), dimensions={0,1}
  %reshape.7 = f32[10]{0} reshape(f32[10,1]{1,0} %broadcast.6)
  %broadcast.8 = f32[10,5]{1,0} broadcast(f32[10]{0} %reshape.7), dimensions={0}
  %p0.1 = f32[1,5]{1,0} parameter(0)
  %broadcast.3 = f32[1,5]{1,0} broadcast(f32[1,5]{1,0} %p0.1), dimensions={0,1}
  %reshape.4 = f32[5]{0} reshape(f32[1,5]{1,0} %broadcast.3)
  %broadcast.5 = f32[10,5]{1,0} broadcast(f32[5]{0} %reshape.4), dimensions={1}
  %multiply.9 = f32[10,5]{1,0} multiply(f32[10,5]{1,0} %broadcast.8, f32[10,5]{1,0} %broadcast.5)
  ROOT %tuple.10 = (f32[10,5]{1,0}) tuple(f32[10,5]{1,0} %multiply.9)
}

// Stablehlo
module @IrToHlo.11 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<1x5xf32>, %arg1: tensor<10x1xf32>) -> tensor<10x5xf32> {
    %0 = stablehlo.reshape %arg1 : (tensor<10x1xf32>) -> tensor<10xf32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<10xf32>) -> tensor<10x5xf32>
    %2 = stablehlo.reshape %arg0 : (tensor<1x5xf32>) -> tensor<5xf32>
    %3 = stablehlo.broadcast_in_dim %2, dims = [1] : (tensor<5xf32>) -> tensor<10x5xf32>
    %4 = stablehlo.multiply %1, %3 : tensor<10x5xf32>
    return %4 : tensor<10x5xf32>
  }
}
Broadcasting via XLA
//HLO
HloModule IrToHlo.9, entry_computation_layout={(f32[1,5]{1,0}, f32[10,1]{1,0})->(f32[10,5]{1,0})}

ENTRY %IrToHlo.9 (p0.1: f32[1,5], p1.2: f32[10,1]) -> (f32[10,5]) {
  %p1.2 = f32[10,1]{1,0} parameter(1)
  %reshape.3 = f32[10]{0} reshape(f32[10,1]{1,0} %p1.2)
  %broadcast.4 = f32[10,5]{1,0} broadcast(f32[10]{0} %reshape.3), dimensions={0}
  %p0.1 = f32[1,5]{1,0} parameter(0)
  %reshape.5 = f32[5]{0} reshape(f32[1,5]{1,0} %p0.1)
  %broadcast.6 = f32[10,5]{1,0} broadcast(f32[5]{0} %reshape.5), dimensions={1}
  %multiply.7 = f32[10,5]{1,0} multiply(f32[10,5]{1,0} %broadcast.4, f32[10,5]{1,0} %broadcast.6)
  ROOT %tuple.8 = (f32[10,5]{1,0}) tuple(f32[10,5]{1,0} %multiply.7)
}

// Stablehlo
Same as above

Scalar Broadcasting

PyTorch code
import torch
import torch_xla
from torch_xla.core import xla_model as xm
from typing import Tuple, Type, Callable, Union, List

device = xm.xla_device()

## multiply (scalar broadcasting)
a = torch.randn(()).to(device=device)
b = torch.randn((1, 5)).to(device=device)


c = a * b
hlo_content = torch_xla._XLAC._get_xla_tensors_hlo([c])
print(hlo_content)
print(xm.get_stablehlo([c]))
Broadcasting via PyTorch/XLA
//HLO
HloModule IrToHlo.6, entry_computation_layout={(f32[1,5]{1,0}, f32[])->(f32[1,5]{1,0})}

ENTRY %IrToHlo.6 (p0.1: f32[1,5], p1.2: f32[]) -> (f32[1,5]) {
  %p1.2 = f32[] parameter(1)
  %broadcast.3 = f32[1,5]{1,0} broadcast(f32[] %p1.2), dimensions={}
  %p0.1 = f32[1,5]{1,0} parameter(0)
  %multiply.4 = f32[1,5]{1,0} multiply(f32[1,5]{1,0} %broadcast.3, f32[1,5]{1,0} %p0.1)
  ROOT %tuple.5 = (f32[1,5]{1,0}) tuple(f32[1,5]{1,0} %multiply.4)
}

// Stablehlo
module @IrToHlo.6 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<1x5xf32>, %arg1: tensor<f32>) -> tensor<1x5xf32> {
    %0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<f32>) -> tensor<1x5xf32>
    %1 = stablehlo.multiply %0, %arg0 : tensor<1x5xf32>
    return %1 : tensor<1x5xf32>
  }
}
Broadcasting via XLA
//HLO
// same as above

// Stablehlo
Same as above

Different rank broadcasting

PyTorch code
import torch
import torch_xla
from torch_xla.core import xla_model as xm
from typing import Tuple, Type, Callable, Union, List

device = xm.xla_device()

## multiply (different rank broadcasting)
a = torch.randn((10,1)).to(device=device)
b = torch.randn((6, 8, 1, 5)).to(device=device)
c = a * b
hlo_content = torch_xla._XLAC._get_xla_tensors_hlo([c])
print(hlo_content)
print(xm.get_stablehlo([c]))
Broadcasting via PyTorch/XLA
// HLO
HloModule IrToHlo.12, entry_computation_layout={(f32[6,8,1,5]{3,2,1,0}, f32[10,1]{1,0})->(f32[6,8,10,5]{3,2,1,0})}

ENTRY %IrToHlo.12 (p0.1: f32[6,8,1,5], p1.2: f32[10,1]) -> (f32[6,8,10,5]) {
  %p1.2 = f32[10,1]{1,0} parameter(1)
  %broadcast.6 = f32[10,1]{1,0} broadcast(f32[10,1]{1,0} %p1.2), dimensions={0,1}
  %reshape.7 = f32[10]{0} reshape(f32[10,1]{1,0} %broadcast.6)
  %broadcast.8 = f32[10,5]{1,0} broadcast(f32[10]{0} %reshape.7), dimensions={0}
  %broadcast.9 = f32[6,8,10,5]{3,2,1,0} broadcast(f32[10,5]{1,0} %broadcast.8), dimensions={2,3}
  %p0.1 = f32[6,8,1,5]{3,2,1,0} parameter(0)
  %broadcast.3 = f32[6,8,1,5]{3,2,1,0} broadcast(f32[6,8,1,5]{3,2,1,0} %p0.1), dimensions={0,1,2,3}
  %reshape.4 = f32[6,8,5]{2,1,0} reshape(f32[6,8,1,5]{3,2,1,0} %broadcast.3)
  %broadcast.5 = f32[6,8,10,5]{3,2,1,0} broadcast(f32[6,8,5]{2,1,0} %reshape.4), dimensions={0,1,3}
  %multiply.10 = f32[6,8,10,5]{3,2,1,0} multiply(f32[6,8,10,5]{3,2,1,0} %broadcast.9, f32[6,8,10,5]{3,2,1,0} %broadcast.5)
  ROOT %tuple.11 = (f32[6,8,10,5]{3,2,1,0}) tuple(f32[6,8,10,5]{3,2,1,0} %multiply.10)
}

// Stablehlo
module @IrToHlo.12 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<6x8x1x5xf32>, %arg1: tensor<10x1xf32>) -> tensor<6x8x10x5xf32> {
    %0 = stablehlo.reshape %arg1 : (tensor<10x1xf32>) -> tensor<10xf32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [2] : (tensor<10xf32>) -> tensor<6x8x10x5xf32>
    %2 = stablehlo.reshape %arg0 : (tensor<6x8x1x5xf32>) -> tensor<6x8x5xf32>
    %3 = stablehlo.broadcast_in_dim %2, dims = [0, 1, 3] : (tensor<6x8x5xf32>) -> tensor<6x8x10x5xf32>
    %4 = stablehlo.multiply %1, %3 : tensor<6x8x10x5xf32>
    return %4 : tensor<6x8x10x5xf32>
  }
}
Broadcasting via XLA

XLA fails, as broadcast_dimensions has to be specified when the operand ranks differ.
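
Under the proposed design, this case would be handled by deriving the dimensions explicitly. A hedged sketch for the shapes above (a: f32[10,1], b: f32[6,8,1,5]; the XlaOps a and b are assumed to be in scope):

// XlaHelpers::getBroadcastDimensions right-aligns the rank-2 operand with
// the rank-4 operand and returns {2, 3}, so the multiply can be emitted as:
xla::XlaOp c = xla::Mul(a, b, /*broadcast_dimensions=*/{2, 3});
// This yields the same f32[6,8,10,5] result as the PyTorch/XLA path above.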

@sdasgup3 sdasgup3 requested review from lsy323 and qihqi December 12, 2023 19:30
@sdasgup3 sdasgup3 changed the title from "Removal of explicit promotion in PT/XLA" to "Deprecate explicit shape promotion in PyTorch/XLA" Dec 12, 2023
Collaborator

@qihqi qihqi left a comment


thanks, this is great!
