Skip to content

Added bicubic support for interpolation with AA#3810

Merged
fmassa merged 5 commits intopytorch:masterfrom
Quansight:vfdev-5/interpolate-aa-cpu-more-modes
May 13, 2021
Merged

Added bicubic support for interpolation with AA#3810
fmassa merged 5 commits intopytorch:masterfrom
Quansight:vfdev-5/interpolate-aa-cpu-more-modes

Conversation

@vfdev-5
Copy link
Copy Markdown
Contributor

@vfdev-5 vfdev-5 commented May 11, 2021

Description:

Note:

@vfdev-5 vfdev-5 requested a review from fmassa May 11, 2021 12:48
Copy link
Copy Markdown
Member

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks!

I think this is good to merge. Can you show me what type of differences in interpolation we have between different methods?

// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L46-L62
static inline scalar_t _filter(scalar_t x) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the future: check if this is equivalent / similar to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/UpSample.h#L324-L332

Comment on lines +561 to +563
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic_aa"),
TORCH_FN(interpolate_bicubic_aa_forward_kernel));
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Up to you, but another option would be to handle this all inside a single function in C++, and don't expose multiple variants to be dispatched on the python side of things.

# High value is mostly required for test cases with
# downsampling and upsampling where we can not exactly
# match PIL implementation.
accepted_tol = 15.0
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you share some image examples with me of the image difference between PIL and your implementation?

Something like plt.imshow(pil_interp - tv_interp) so that I can see what types of differences we are seeing here?

@vfdev-5
Copy link
Copy Markdown
Contributor Author

vfdev-5 commented May 12, 2021

Images for the interpolation with AA for bilinear and bicubic modes:
(better visualization with GH light mode)

  • BILINEAR

image
image

  • BICUBIC

image
image

@fmassa
Copy link
Copy Markdown
Member

fmassa commented May 12, 2021

Thanks! I think the differences might be due to PIL computing the bicubic interpolation with integers while we use floats, so they might have more rounding errors due to the more expensive computations.

Let's get this merged, can you rebase the PR?

@fmassa fmassa merged commit 0fd0f50 into pytorch:master May 13, 2021
@vfdev-5 vfdev-5 deleted the vfdev-5/interpolate-aa-cpu-more-modes branch May 13, 2021 09:42
facebook-github-bot pushed a commit that referenced this pull request May 19, 2021
Summary:
* Added support for bicubic mode with AA

* Updated comment in the test

Reviewed By: cpuhrsch

Differential Revision: D28538771

fbshipit-source-id: 8c5bc434a8b3478c2088b46886a28c561d666b55
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this pull request Dec 29, 2021
Summary:
Description:
- Added antialias flag to interpolate (CPU only)
  - forward and backward for bicubic mode
  - added tests

Previous PR for bilinear, #65142

### Benchmarks

<details>
<summary>
Forward pass, CPU. PTH interpolation vs PIL
</summary>

Cases:
- PTH RGB 3 Channels, float32 vs PIL RGB uint8 (apples vs pears)
- PTH 1 Channel, float32 vs PIL 1 Channel Float

Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112

```
Torch config: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.1
  - NVCC architecture flags: -gencode;arch=compute_61,code=sm_61
  - CuDNN 8.0.5
  - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF,

Num threads: 1
[------------------- Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (320, 196) -------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitb0bdf58
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                4.5                |          5.2
      channels_last non-contiguous torch.float32  |                4.5                |          5.3

Times are in milliseconds (ms).

[------------------- Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (460, 220) -------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitb0bdf58
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                5.7                |          6.4
      channels_last non-contiguous torch.float32  |                5.7                |          6.4

Times are in milliseconds (ms).

[------------------- Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (120, 96) --------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitb0bdf58
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                3.0                |          4.0
      channels_last non-contiguous torch.float32  |                2.9                |          4.1

Times are in milliseconds (ms).

[------------------ Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (1200, 196) -------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitb0bdf58
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                14.7               |          17.1
      channels_last non-contiguous torch.float32  |                14.8               |          17.2

Times are in milliseconds (ms).

[------------------ Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (120, 1200) -------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitb0bdf58
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                3.5                |          3.9
      channels_last non-contiguous torch.float32  |                3.5                |          3.9

Times are in milliseconds (ms).

[---------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (320, 196) ---------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitb0bdf58
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |               2.4               |          1.8

Times are in milliseconds (ms).

[---------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (460, 220) ---------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitb0bdf58
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |               3.1               |          2.2

Times are in milliseconds (ms).

[---------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (120, 96) ----------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitb0bdf58
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |               1.6               |          1.4

Times are in milliseconds (ms).

[--------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (1200, 196) ---------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitb0bdf58
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |               7.9               |          5.7

Times are in milliseconds (ms).

[--------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (120, 1200) ---------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitb0bdf58
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |               1.7               |          1.3

Times are in milliseconds (ms).

```

</details>

Code is moved from torchvision: pytorch/vision#3810 and pytorch/vision#4208

Pull Request resolved: #68819

Reviewed By: mikaylagawarecki

Differential Revision: D33339117

Pulled By: jbschlosser

fbshipit-source-id: 6a0443bbba5439f52c7dbc1be819b75634cf67c4
wconstab pushed a commit to pytorch/pytorch that referenced this pull request Jan 5, 2022
Summary:
Description:
- Added antialias flag to interpolate (CPU only)
  - forward and backward for bicubic mode
  - added tests

Previous PR for bilinear, #65142

### Benchmarks

<details>
<summary>
Forward pass, CPU. PTH interpolation vs PIL
</summary>

Cases:
- PTH RGB 3 Channels, float32 vs PIL RGB uint8 (apples vs pears)
- PTH 1 Channel, float32 vs PIL 1 Channel Float

Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112

```
Torch config: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.1
  - NVCC architecture flags: -gencode;arch=compute_61,code=sm_61
  - CuDNN 8.0.5
  - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.11.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF,

Num threads: 1
[------------------- Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (320, 196) -------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitb0bdf58
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                4.5                |          5.2
      channels_last non-contiguous torch.float32  |                4.5                |          5.3

Times are in milliseconds (ms).

[------------------- Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (460, 220) -------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitb0bdf58
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                5.7                |          6.4
      channels_last non-contiguous torch.float32  |                5.7                |          6.4

Times are in milliseconds (ms).

[------------------- Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (120, 96) --------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitb0bdf58
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                3.0                |          4.0
      channels_last non-contiguous torch.float32  |                2.9                |          4.1

Times are in milliseconds (ms).

[------------------ Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (1200, 196) -------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitb0bdf58
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                14.7               |          17.1
      channels_last non-contiguous torch.float32  |                14.8               |          17.2

Times are in milliseconds (ms).

[------------------ Downsampling (bicubic): torch.Size([1, 3, 906, 438]) -> (120, 1200) -------------------]
                                                  |  Reference, PIL 8.4.0, mode: RGB  |  1.11.0a0+gitb0bdf58
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                3.5                |          3.9
      channels_last non-contiguous torch.float32  |                3.5                |          3.9

Times are in milliseconds (ms).

[---------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (320, 196) ---------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitb0bdf58
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |               2.4               |          1.8

Times are in milliseconds (ms).

[---------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (460, 220) ---------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitb0bdf58
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |               3.1               |          2.2

Times are in milliseconds (ms).

[---------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (120, 96) ----------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitb0bdf58
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |               1.6               |          1.4

Times are in milliseconds (ms).

[--------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (1200, 196) ---------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitb0bdf58
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |               7.9               |          5.7

Times are in milliseconds (ms).

[--------- Downsampling (bicubic): torch.Size([1, 1, 906, 438]) -> (120, 1200) ---------]
                                 |  Reference, PIL 8.4.0, mode: F  |  1.11.0a0+gitb0bdf58
1 threads: ------------------------------------------------------------------------------
       contiguous torch.float32  |               1.7               |          1.3

Times are in milliseconds (ms).

```

</details>

Code is moved from torchvision: pytorch/vision#3810 and pytorch/vision#4208

Pull Request resolved: #68819

Reviewed By: mikaylagawarecki

Differential Revision: D33339117

Pulled By: jbschlosser

fbshipit-source-id: 6a0443bbba5439f52c7dbc1be819b75634cf67c4
justinchuby added a commit to microsoft/onnxscript that referenced this pull request Mar 19, 2026
…#2849)

When exporting `F.interpolate(mode='bicubic', antialias=True)`, the ONNX
Resize node was emitted with `cubic_coeff_a=-0.75` (OpenCV-compatible),
but PyTorch uses `-0.5` (Keys/PIL-compatible) for the antialias path.
This caused ~32x higher numerical error vs. PyTorch when running the
exported model in ONNX Runtime.

## Changes

- **`_aten_upsample_output_size` / `_aten_upsample_scales`**: Added
`cubic_coeff_a: float = -0.75` parameter (default preserves existing
behavior for non-antialias cases) and thread it through to `op.Resize`.
- **`aten__upsample_bicubic2d_aa`**: Pass `cubic_coeff_a=-0.5` to match
PyTorch's runtime behavior when `antialias=True`.

```python
# antialias=True  → cubic_coeff_a=-0.5  (Keys/PIL-compatible)  ✓
# antialias=False → cubic_coeff_a=-0.75 (OpenCV-compatible)    ✓
```

<!-- START COPILOT ORIGINAL PROMPT -->



<details>

<summary>Original prompt</summary>


----

*This section details on the original issue you should resolve*

<issue_title>ONNX dynamo export writes cubic_coeff_a=-0.75 for bicubic
antialias=True (should be -0.5)</issue_title>
<issue_description>### 🐛 Describe the bug

# ONNX dynamo export writes cubic_coeff_a=-0.75 for bicubic
antialias=True (should be -0.5)

## Bug

When exporting `F.interpolate(mode='bicubic', antialias=True)` to ONNX
via the dynamo exporter, the Resize node is written with
`cubic_coeff_a=-0.75`. However, PyTorch internally uses
`cubic_coeff_a=-0.5` (Keys interpolation) when `antialias=True`, as
documented in the source:

```cpp
// aten/src/ATen/native/cpu/UpSampleKernel.cpp, line ~1347
// We are using -0.5 for bicubic, antialiasing=true (compatibility with PIL)
// and using -0.75 for bicubic, antialiasing=false (compatibility with Opencv)
constexpr scalar_t a = use_keys_cubic ? -0.5 : -0.75;
```

The exported ONNX model therefore produces different results than
PyTorch when run in ONNX Runtime (or any runtime that correctly respects
the `cubic_coeff_a` attribute).

The `-0.75` value was originally hardcoded in PR pytorch/pytorch#24805
for the non-antialias case and was carried forward without accounting
for the antialias path. The distinction between `-0.5` (Keys,
PIL-compatible) and `-0.75` (OpenCV-compatible) based on the antialias
flag was introduced in the ATen kernels via pytorch/vision#3810 and
pytorch#68819.

The legacy TorchScript exporter does not support `antialias=True` at all
(`UnsupportedOperatorError`), so this only affects the dynamo exporter.

## To reproduce

```python
import io
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn
import torch.nn.functional as F


class BicubicAA(nn.Module):
    def forward(self, x):
        return F.interpolate(x, size=[224, 224], mode="bicubic",
                             align_corners=False, antialias=True)


# Export
model = BicubicAA()
model.eval()
x = torch.rand(1, 3, 800, 600)
buf = io.BytesIO()
torch.onnx.export(model, (x,), buf, opset_version=18, dynamo=True)
buf.seek(0)
onnx_model = onnx.load(buf)

# Inspect: cubic_coeff_a is -0.75 (wrong for antialias=True)
for node in onnx_model.graph.node:
    if node.op_type == "Resize":
        for attr in node.attribute:
            if attr.name == "cubic_coeff_a":
                print(f"Exported cubic_coeff_a = {attr.f}")  # -0.75
            if attr.name == "antialias":
                print(f"Exported antialias = {attr.i}")       # 1

# Numerical impact
with torch.no_grad():
    pt_out = model(x).numpy()

buf.seek(0)
sess = ort.InferenceSession(buf.read())
ort_wrong = sess.run(None, {"x": x.numpy()})[0]

# Patch to correct value and re-run
for node in onnx_model.graph.node:
    if node.op_type == "Resize":
        for attr in node.attribute:
            if attr.name == "cubic_coeff_a":
                attr.f = -0.5
buf2 = io.BytesIO()
onnx.save(onnx_model, buf2)
buf2.seek(0)
sess2 = ort.InferenceSession(buf2.read())
ort_fixed = sess2.run(None, {"x": x.numpy()})[0]

print(f"PyTorch vs ONNX (exported a=-0.75): mean={np.abs(ort_wrong - pt_out).mean():.2e}")
print(f"PyTorch vs ONNX (patched  a=-0.50): mean={np.abs(ort_fixed - pt_out).mean():.2e}")
```

Output:

```
Exported cubic_coeff_a = -0.75
Exported antialias = 1
PyTorch vs ONNX (exported a=-0.75): mean=5.31e-03
PyTorch vs ONNX (patched  a=-0.50): mean=1.67e-04
```

Patching `cubic_coeff_a` to `-0.5` reduces mean error by 32x, confirming
that PyTorch uses `-0.5` at runtime but the exporter writes `-0.75`.

## Expected behavior

When `antialias=True`, the ONNX Resize node should be exported with
`cubic_coeff_a=-0.5` to match PyTorch's runtime behavior. When
`antialias=False`, `cubic_coeff_a=-0.75` is correct.

### Versions

Collecting environment information...
PyTorch version: 2.10.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 4.2.3
Libc version: glibc-2.31

Python version: 3.12.12 (main, Feb 3 2026, 22:51:04) [Clang 21.1.4 ]
(64-bit runtime)
Python platform: Linux-5.4.0-208-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.2.152
CUDA_MODULE_LOADING set to: 
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 565.57.01
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      43 bits physical...

</details>



<!-- START COPILOT CODING AGENT SUFFIX -->

- Fixes pytorch/pytorch#177138

<!-- START COPILOT CODING AGENT TIPS -->
---

🔒 GitHub Advanced Security automatically protects Copilot coding agent
pull requests. You can protect all pull requests by enabling Advanced
Security for your repositories. [Learn more about Advanced
Security.](https://gh.io/cca-advanced-security)

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants