ONNX dynamo export writes cubic_coeff_a=-0.75 for bicubic antialias=True (should be -0.5) #177138
🐛 Describe the bug
Bug
When exporting F.interpolate(mode='bicubic', antialias=True) to ONNX via the dynamo exporter, the Resize node is written with cubic_coeff_a=-0.75. However, PyTorch internally uses cubic_coeff_a=-0.5 (Keys interpolation) when antialias=True, as documented in the source:
```cpp
// aten/src/ATen/native/cpu/UpSampleKernel.cpp, line ~1347
// We are using -0.5 for bicubic, antialiasing=true (compatibility with PIL)
// and using -0.75 for bicubic, antialiasing=false (compatibility with Opencv)
constexpr scalar_t a = use_keys_cubic ? -0.5 : -0.75;
```

The exported ONNX model therefore produces different results from PyTorch when run in ONNX Runtime (or any runtime that correctly respects the cubic_coeff_a attribute).
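For reference, cubic_coeff_a is the free parameter a of the Keys cubic convolution kernel. The sketch below is a plain-Python version of that kernel, written as I understand the ONNX Resize specification defines it for mode="cubic"; the helper names `cubic_weight` and `tap_weights` are hypothetical and not part of PyTorch or ONNX. It shows that the four tap weights for a sample halfway between two pixels already differ between a=-0.5 and a=-0.75, even though both sets sum to 1.

```python
def cubic_weight(x: float, a: float) -> float:
    """Keys cubic convolution kernel with free parameter a (hypothetical helper)."""
    x = abs(x)
    if x <= 1.0:
        # Inner lobe: (a + 2)|x|^3 - (a + 3)|x|^2 + 1
        return (a + 2.0) * x**3 - (a + 3.0) * x**2 + 1.0
    if x < 2.0:
        # Outer (negative) lobe: a|x|^3 - 5a|x|^2 + 8a|x| - 4a
        return a * x**3 - 5.0 * a * x**2 + 8.0 * a * x - 4.0 * a
    return 0.0


def tap_weights(t: float, a: float) -> list:
    """Weights of the 4 taps at offsets -1, 0, 1, 2 for a sample at fractional position t."""
    return [cubic_weight(t - k, a) for k in (-1, 0, 1, 2)]


print(tap_weights(0.5, -0.5))   # [-0.0625, 0.5625, 0.5625, -0.0625]
print(tap_weights(0.5, -0.75))  # [-0.09375, 0.59375, 0.59375, -0.09375]
```

With a=-0.75 the negative lobe is deeper, which sharpens edges (the OpenCV convention); a=-0.5 is the PIL convention that PyTorch's antialias path follows.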
The -0.75 value was originally hardcoded in PR #24805 for the non-antialias case and was carried forward without accounting for the antialias path. The distinction between -0.5 (Keys, PIL-compatible) and -0.75 (OpenCV-compatible) based on the antialias flag was introduced in the ATen kernels via pytorch/vision#3810 and #68819.
The legacy TorchScript exporter does not support antialias=True at all (UnsupportedOperatorError), so this only affects the dynamo exporter.
To reproduce
```python
import io

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn
import torch.nn.functional as F


class BicubicAA(nn.Module):
    def forward(self, x):
        return F.interpolate(x, size=[224, 224], mode="bicubic",
                             align_corners=False, antialias=True)


# Export
model = BicubicAA()
model.eval()
x = torch.rand(1, 3, 800, 600)
buf = io.BytesIO()
torch.onnx.export(model, (x,), buf, opset_version=18, dynamo=True)
buf.seek(0)
onnx_model = onnx.load(buf)

# Inspect: cubic_coeff_a is -0.75 (wrong for antialias=True)
for node in onnx_model.graph.node:
    if node.op_type == "Resize":
        for attr in node.attribute:
            if attr.name == "cubic_coeff_a":
                print(f"Exported cubic_coeff_a = {attr.f}")  # -0.75
            if attr.name == "antialias":
                print(f"Exported antialias = {attr.i}")  # 1

# Numerical impact
with torch.no_grad():
    pt_out = model(x).numpy()
buf.seek(0)
sess = ort.InferenceSession(buf.read(), providers=["CPUExecutionProvider"])
ort_wrong = sess.run(None, {"x": x.numpy()})[0]

# Patch to correct value and re-run
for node in onnx_model.graph.node:
    if node.op_type == "Resize":
        for attr in node.attribute:
            if attr.name == "cubic_coeff_a":
                attr.f = -0.5
buf2 = io.BytesIO()
onnx.save(onnx_model, buf2)
buf2.seek(0)
sess2 = ort.InferenceSession(buf2.read(), providers=["CPUExecutionProvider"])
ort_fixed = sess2.run(None, {"x": x.numpy()})[0]

print(f"PyTorch vs ONNX (exported a=-0.75): mean={np.abs(ort_wrong - pt_out).mean():.2e}")
print(f"PyTorch vs ONNX (patched a=-0.50): mean={np.abs(ort_fixed - pt_out).mean():.2e}")
```

Output:

```
Exported cubic_coeff_a = -0.75
Exported antialias = 1
PyTorch vs ONNX (exported a=-0.75): mean=5.31e-03
PyTorch vs ONNX (patched a=-0.50): mean=1.67e-04
```
Patching cubic_coeff_a to -0.5 reduces the mean absolute error by about 32x (5.31e-03 to 1.67e-04), confirming that PyTorch uses -0.5 at runtime while the exporter writes -0.75.
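To see why the coefficient alone shifts the output, here is a minimal 1-D sketch of antialiased cubic downsampling. It assumes the PIL-style scheme the ATen comment refers to: when downscaling by a factor `scale`, the kernel is stretched to a support of `2 * scale` input pixels and the weights are renormalized to sum to 1. This is an illustration, not PyTorch's actual kernel; `resize_1d_aa` is a hypothetical name, and `align_corners=False` is assumed, as in the repro. The kernel is restated so the block runs on its own.

```python
def cubic_weight(x: float, a: float) -> float:
    # Keys cubic convolution kernel with parameter a.
    x = abs(x)
    if x <= 1.0:
        return ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0
    if x < 2.0:
        return (((x - 5.0) * x + 8.0) * x - 4.0) * a
    return 0.0


def resize_1d_aa(signal: list, out_size: int, a: float) -> list:
    """Antialiased cubic resize of a 1-D signal (hypothetical, PIL-style sketch)."""
    in_size = len(signal)
    scale = in_size / out_size
    # When downscaling, stretch the kernel by `scale`; when upscaling, keep it as-is.
    stretch = max(scale, 1.0)
    support = 2.0 * stretch
    out = []
    for i in range(out_size):
        center = scale * (i + 0.5)  # output pixel center in input coordinates
        lo = max(int(center - support + 0.5), 0)
        hi = min(int(center + support + 0.5), in_size)
        weights = [cubic_weight((j - center + 0.5) / stretch, a) for j in range(lo, hi)]
        total = sum(weights)  # renormalize so the weights sum to 1
        out.append(sum(w * signal[j] for w, j in zip(weights, range(lo, hi))) / total)
    return out


step = [0.0] * 10 + [1.0] * 10
print(resize_1d_aa(step, 5, -0.5))
print(resize_1d_aa(step, 5, -0.75))
```

Both settings preserve a constant signal because of the renormalization, but output pixels whose footprint overlaps the step get different values, since the negative lobe (scaled by a) enters both the weighted sum and the normalizer.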
Expected behavior
When antialias=True, the ONNX Resize node should be exported with cubic_coeff_a=-0.5 to match PyTorch's runtime behavior. When antialias=False, cubic_coeff_a=-0.75 is correct.
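A check for this rule could look like the sketch below. It assumes each node has already been reduced to an `(op_type, attributes)` pair, for example by reading `node.attribute` with `onnx.helper.get_attribute_value`; `expected_cubic_coeff_a` and `find_mismatched_resizes` are hypothetical helpers, not exporter APIs. Absent attributes fall back to the ONNX Resize defaults (`mode="nearest"`, `antialias=0`, `cubic_coeff_a=-0.75`).

```python
from typing import Dict, List, Tuple, Union

AttrValue = Union[int, float, str, bytes]

# ONNX Resize defaults used when an attribute is not present on the node.
ONNX_RESIZE_DEFAULTS: Dict[str, AttrValue] = {
    "mode": "nearest",
    "antialias": 0,
    "cubic_coeff_a": -0.75,
}


def expected_cubic_coeff_a(antialias: bool) -> float:
    # Mirrors the ATen rule: -0.5 (Keys, PIL) with antialias, -0.75 (OpenCV) without.
    return -0.5 if antialias else -0.75


def find_mismatched_resizes(nodes: List[Tuple[str, Dict[str, AttrValue]]]) -> List[int]:
    """Return indices of cubic Resize nodes whose cubic_coeff_a disagrees with PyTorch."""
    bad = []
    for idx, (op_type, attrs) in enumerate(nodes):
        if op_type != "Resize":
            continue
        merged = {**ONNX_RESIZE_DEFAULTS, **attrs}
        mode = merged["mode"]
        if isinstance(mode, bytes):  # string attributes come back as bytes from ONNX protos
            mode = mode.decode()
        if mode != "cubic":
            continue
        want = expected_cubic_coeff_a(bool(merged["antialias"]))
        if abs(float(merged["cubic_coeff_a"]) - want) > 1e-6:
            bad.append(idx)
    return bad
```

Run against the node exported in the repro (mode cubic, antialias 1, cubic_coeff_a -0.75), this check would flag it, while a non-antialiased cubic Resize left at the default -0.75 would pass.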
Versions
Collecting environment information...
PyTorch version: 2.10.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 4.2.3
Libc version: glibc-2.31
Python version: 3.12.12 (main, Feb 3 2026, 22:51:04) [Clang 21.1.4 ] (64-bit runtime)
Python platform: Linux-5.4.0-208-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.2.152
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB
Nvidia driver version: 565.57.01
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical, 48 bits virtual
CPU(s): 256
On-line CPU(s) list: 0-255
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
NUMA node(s): 8
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7742 64-Core Processor
Stepping: 0
Frequency boost: enabled
CPU MHz: 3161.415
CPU max MHz: 2250.0000
CPU min MHz: 1500.0000
BogoMIPS: 4491.50
Virtualization: AMD-V
L1d cache: 4 MiB
L1i cache: 4 MiB
L2 cache: 64 MiB
L3 cache: 512 MiB
NUMA node0 CPU(s): 0-15,128-143
NUMA node1 CPU(s): 16-31,144-159
NUMA node2 CPU(s): 32-47,160-175
NUMA node3 CPU(s): 48-63,176-191
NUMA node4 CPU(s): 64-79,192-207
NUMA node5 CPU(s): 80-95,208-223
NUMA node6 CPU(s): 96-111,224-239
NUMA node7 CPU(s): 112-127,240-255
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP conditional; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip rdpid overflow_recov succor smca sme sev sev_es
Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect
onnx: 1.20.1
onnxruntime: 1.24.3
onnxscript: 0.6.2
cc @justinchuby @titaiwangms @chauhang @penguinwu @avikchaudhuri @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4