
stablehlo.uniform_quantize cannot be serialized to bytecode #1812

@lsy323


What happened?

An MLIR module containing stablehlo.uniform_quantize/uniform_dequantize ops fails during bytecode serialization with the following error:

loc("custom-call.6"): error: failed to legalize operation 'stablehlo.uniform_quantize' that was explicitly marked illegal

However, the same MLIR module can be serialized to the readable (textual) format:

module @IrToHlo.18 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<10x3x3x3xf32>, %arg2: tensor<1x3x10x10xf32>) -> tensor<1x10x8x8xf32> {
    %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x10x8x8xf32>
    %1 = stablehlo.uniform_quantize %arg2 : (tensor<1x3x10x10xf32>) -> tensor<1x3x10x10x!quant.uniform<i8:f32, 1.000000e+00>>
    %2 = stablehlo.uniform_dequantize %1 : (tensor<1x3x10x10x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x3x10x10xf32>
    %3 = stablehlo.uniform_quantize %arg1 : (tensor<10x3x3x3xf32>) -> tensor<10x3x3x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>
    %4 = stablehlo.uniform_dequantize %3 : (tensor<10x3x3x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<10x3x3x3xf32>
    %5 = stablehlo.convolution(%2, %4) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<1x3x10x10xf32>, tensor<10x3x3x3xf32>) -> tensor<1x10x8x8xf32>
    %6 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<10xf32>) -> tensor<1x10x8x8xf32>
    %7 = stablehlo.add %5, %6 : tensor<1x10x8x8xf32>
    %8 = stablehlo.maximum %7, %0 : tensor<1x10x8x8xf32>
    %9 = stablehlo.uniform_quantize %8 : (tensor<1x10x8x8xf32>) -> tensor<1x10x8x8x!quant.uniform<i8:f32, 1.000000e+00>>
    %10 = stablehlo.uniform_dequantize %9 : (tensor<1x10x8x8x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x10x8x8xf32>
    return %10 : tensor<1x10x8x8xf32>
  }
}
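To confirm which ops in a textual dump are the ones named in the error before handing the module to the bytecode serializer, a plain text scan can help. This is only a rough sketch: it uses a regular expression rather than any MLIR or StableHLO API, and `find_quant_ops` and `SAMPLE` are hypothetical names introduced here for illustration.

```python
import re

# Matches the two op names mentioned in this report.
_QUANT_OP = re.compile(r"\b(stablehlo\.uniform_(?:de)?quantize)\b")

def find_quant_ops(mlir_text):
    """Return (line_number, op_name) pairs for each quantize/dequantize op.

    This is a text-level scan of the readable format, not a real MLIR parse.
    """
    hits = []
    for lineno, line in enumerate(mlir_text.splitlines(), start=1):
        for match in _QUANT_OP.finditer(line):
            hits.append((lineno, match.group(1)))
    return hits

# Hypothetical, shortened module in the same style as the dump above.
SAMPLE = """\
func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  %0 = stablehlo.uniform_quantize %arg0 : (tensor<4xf32>) -> tensor<4x!quant.uniform<i8:f32, 1.000000e+00>>
  %1 = stablehlo.uniform_dequantize %0 : (tensor<4x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}
"""

for lineno, op in find_quant_ops(SAMPLE):
    print(f"line {lineno}: {op}")
```

This does not work around the failure; it only lists the ops in the dump that match the names in the error message.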

Steps to reproduce your issue

The StableHLO module with the uniform_quantize/uniform_dequantize ops is generated through PyTorch -> PyTorch/XLA -> StableHLO. Reproducing the bug end to end requires changes in PyTorch/XLA and in the HLO -> StableHLO converter. Those changes have not yet been merged to PyTorch/XLA head, but will be merged soon as an experimental feature. Please let me know if an end-to-end repro is needed.

This function is used to serialize stablehlo bytecode in PyTorch/XLA

Version information

StableHLO commit 46a2506

from openxla/xla: 51b59cfb1999c6f1b3ec59851675044b2c502aae
