Skip to content

[aten] 8 bytes aligned vector loads for bf16 and fp16 dtypes in torch.cat#150233

Closed
zhaozhul wants to merge 1 commit intopytorch:mainfrom
zhaozhul:main
Closed

[aten] 8 bytes aligned vector loads for bf16 and fp16 dtypes in torch.cat#150233
zhaozhul wants to merge 1 commit intopytorch:mainfrom
zhaozhul:main

Conversation

@zhaozhul
Copy link
Copy Markdown
Contributor

Enable aligned vector loads for 2-byte data types in torch.cat. Specifically:

  1. reduce the vector length to 8 bytes for 2-byte types (fp16, bf16 etc)
  2. enable through a conditional template

The reason why 8-byte vector loading was chosen for fp16 and bf16:
A 16-byte load results in heavier register overhead for 2-byte types (i.e. 4 registers per load for fp32 -> 8 registers per load for fp16). Therefore, to still get the benefit of vectorized loading, we reduced ALIGNED_VEC_LOAD_BYTES to 8 for fp16 and bf16.

perf testing:

before:

torch-cat-D1-30108-D2-624-D3-772-dtype-torch.float32:
         B  pt_eager      copy
0    100.0  0.022621  0.036162
1   1000.0  0.133616  0.207051
2  10000.0  1.326848  1.848768
3  20000.0  2.744544  3.692128
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.bfloat16:
         B  pt_eager      copy
0    100.0  0.022434  0.035477
1   1000.0  0.140608  0.144518
2  10000.0  1.303792  1.229584
3  20000.0  2.668288  2.436160

after:

torch-cat-D1-30108-D2-624-D3-772-dtype-torch.float32:
         B  pt_eager      copy
0    100.0  0.022608  0.036328
1   1000.0  0.133861  0.207399
2  10000.0  1.325120  1.847136
3  20000.0  2.726528  3.693184
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.bfloat16:
         B  pt_eager      copy
0    100.0  0.019942  0.035482
1   1000.0  0.084858  0.144544
2  10000.0  0.924384  1.230672
3  20000.0  1.944448  2.436480

bw analysis:

Bandwidth for fp16/bf16 increased by 40%-50% for large tensors.

before:

Bandwidth (GB/s) for ((16384, 16384), 1) int8;fp16;fp32;int32;fp64;long|869.87|1382.74|1956.46|1952.73|1969.03|1963.66
Bandwidth (GB/s) for ((4194304,), 0) int8;fp16;fp32;int32;fp64;long|568.43|926.53|1589.20|1567.52|1771.54|1783.68
Bandwidth (GB/s) for ((16777216,), 0) int8;fp16;fp32;int32;fp64;long|752.07|1269.50|1894.86|1900.85|1954.10|1955.08
Bandwidth (GB/s) for ((33554432,), 0) int8;fp16;fp32;int32;fp64;long|807.08|1354.69|1960.48|1962.45|1972.73|1973.85
Bandwidth (GB/s) for ((134217728,), 0) int8;fp16;fp32;int32;fp64;long|864.02|1398.02|1963.43|1955.32|1963.37|1969.96

after:

Bandwidth (GB/s) for ((16384, 16384), 1) int8;fp16;fp32;int32;fp64;long|873.08|1892.16|1954.35|1962.51|1962.03|1965.98
Bandwidth (GB/s) for ((4194304,), 0) int8;fp16;fp32;int32;fp64;long|575.13|1242.45|1576.37|1571.30|1769.94|1790.22
Bandwidth (GB/s) for ((16777216,), 0) int8;fp16;fp32;int32;fp64;long|742.92|1734.57|1887.99|1897.62|1940.99|1959.25
Bandwidth (GB/s) for ((33554432,), 0) int8;fp16;fp32;int32;fp64;long|802.60|1865.45|1952.64|1947.53|1974.47|1973.48
Bandwidth (GB/s) for ((134217728,), 0) int8;fp16;fp32;int32;fp64;long|865.32|1939.07|1965.72|1963.25|1969.06|1968.72

Perf testing code:

# pyre-strict
from typing import List, Optional, Tuple

import click
import pandas as pd

import torch

# @manual=//triton:triton
import triton


# Usage: CUDA_VISIBLE_DEVICES=7 buck2 run @mode/opt //scripts/zhaozhu:cat_bench


@click.command()
@click.option("--data-type", type=str, default="bf16")
@click.option("--return-result", type=bool, default=False)
def main(
    data_type: str,
    return_result: bool,
) -> Optional[Tuple[List[triton.testing.Benchmark], List[pd.DataFrame]]]:
    """Benchmark ``torch.cat`` along dim 1 against a slice-copy baseline on CUDA.

    Three 2-D tensors of shape ``(B, D1)``, ``(B, D2)`` and ``(B, D3)`` are
    concatenated along dim 1.  ``D1``/``D2``/``D3`` are drawn at random once
    per invocation and ``B`` sweeps over a fixed list of batch sizes.  Two
    providers are timed:

    * ``pt_eager`` - ``torch.cat(tensors, dim=1)``
    * ``copy``     - allocate the output and ``copy_`` each input into its
      column slice

    Args:
        data_type: One of ``"fp32"``, ``"fp16"`` or ``"bf16"``.
        return_result: When true, ask the perf report for its data frames and
            return them together with the benchmark configs.

    Returns:
        ``(configs, df)`` when ``return_result`` is true, where ``df`` is
        whatever ``bench_cat.run(..., return_df=True)`` produced; otherwise
        ``None``.

    Raises:
        ValueError: If ``data_type`` is not one of the supported names.
    """
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    supported_dtypes = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    if data_type not in supported_dtypes:
        raise ValueError(f"Unsupported data type: {data_type}.")
    dtype = supported_dtypes[data_type]

    # Random column widths so the benchmark is not tuned to one specific
    # (possibly conveniently aligned) shape.
    D1 = int(torch.randint(low=10000, high=50000, size=(1,)).item())
    D2 = int(torch.randint(low=100, high=1000, size=(1,)).item())
    D3 = int(torch.randint(low=500, high=1000, size=(1,)).item())

    configs: List[triton.testing.Benchmark] = [
        triton.testing.Benchmark(
            x_names=["B"],
            x_vals=[100, 1000, 10000, 20000],
            line_arg="provider",
            line_vals=["pt_eager", "copy"],
            line_names=["pt_eager", "copy"],
            # One style per plotted line.
            styles=[("blue", "-"), ("green", "-")],
            ylabel="ms",
            plot_name=f"torch-cat-D1-{D1}-D2-{D2}-D3-{D3}-dtype-{dtype}",
            args={
                "D1": D1,
                "D2": D2,
                "D3": D3,
                "dtype": dtype,
            },
        )
    ]

    @triton.testing.perf_report(configs)
    def bench_cat(
        B: int,
        D1: int,
        D2: int,
        D3: int,
        dtype: torch.dtype,
        provider: str,
    ) -> float:
        """Time one provider for one batch size ``B`` and return its result."""
        warmup = 10
        rep = 3
        device = torch.device("cuda")

        def make_input(cols: int) -> torch.Tensor:
            # Uniform values in [-1, 1); contents are irrelevant to timing
            # but make the correctness check below meaningful.
            return torch.empty((B, cols), dtype=dtype, device=device).uniform_(
                -1.0, 1.0
            )

        tensors = [make_input(D1), make_input(D2), make_input(D3)]
        total_cols: int = sum(int(t.shape[1]) for t in tensors)

        def torch_copy(
            tensors: List[torch.Tensor], is_inplace: bool = True
        ) -> torch.Tensor:
            """Baseline: write each input into its column slice of the output."""
            out = torch.zeros([B, total_cols], dtype=dtype, device=device)
            col_idx = 0
            for t in tensors:
                width = t.shape[1]
                if is_inplace:
                    out[:, col_idx : col_idx + width].copy_(t)
                else:
                    out[:, col_idx : col_idx + width] = t
                col_idx += width
            return out

        def torch_cat(tensors: List[torch.Tensor]) -> torch.Tensor:
            return torch.cat(tensors, dim=1)

        # Sanity check before timing.  Both paths only move data, so the
        # results must match bit-for-bit.  ``assert_close`` replaces the
        # deprecated ``torch.testing.assert_allclose``.
        ref = torch_cat(tensors)
        real = torch_copy(tensors, is_inplace=False)
        torch.testing.assert_close(ref, real, rtol=0, atol=0)

        # NOTE: the former "stack" provider was dropped: it was never listed
        # in ``line_vals`` and ``torch.stack`` cannot combine inputs whose
        # widths differ, so it could not have run.
        providers = {
            "pt_eager": lambda: torch_cat(tensors),
            "copy": lambda: torch_copy(tensors),
        }
        if provider not in providers:
            raise ValueError(f"unsupported provider: {provider}")

        ms = triton.testing.do_bench(providers[provider], warmup=warmup, rep=rep)
        return ms

    df = bench_cat.run(print_data=True, return_df=return_result)

    if return_result:
        return configs, df
    return None


if __name__ == "__main__":
    main()

and bw analysis code is from: #102815

@zhaozhul zhaozhul added the topic: performance topic category label Mar 28, 2025
@zhaozhul zhaozhul requested review from eqy and syed-ahmed as code owners March 28, 2025 22:15
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented Mar 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150233

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 8ac1794 with merge base 7ac0658 (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla
Copy link
Copy Markdown

linux-foundation-easycla Bot commented Mar 28, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: zhaozhul / name: Zhao Zhu (8ac1794)

@pytorch-bot pytorch-bot Bot added the release notes: cuda release notes category label Mar 28, 2025
@ngimel
Copy link
Copy Markdown
Collaborator

ngimel commented Apr 3, 2025

Did you try actually launching half kernels with the same 128-bit loads? Is your improvement just due to vectorized path that narrow types didn't take previously, and you don't need extra 64-bit instantiation?

@zhaozhul
Copy link
Copy Markdown
Contributor Author

zhaozhul commented Apr 3, 2025

Did you try actually launching half kernels with the same 128-bit loads? Is your improvement just due to vectorized path that narrow types didn't take previously, and you don't need extra 64-bit instantiation?

Thanks Natalia. We originally tried enabling the vectorized path with 128-bit loads as well, but it caused a major perf regression:

For fp16, torch.cat is then even slower than copy. This might be related to increased register usage: since each element is half the size of an fp32 element, a vectorized load of the same width touches twice as many elements.

torch-cat-D1-14318-D2-479-D3-843-dtype-torch.float16: 8 reg
         B  pt_eager      copy
0    100.0  0.027242  0.034411
1   1000.0  0.200789  0.144489
2  10000.0  1.934080  1.230432
3  20000.0  3.862880  2.433824

torch-cat-D1-46070-D2-493-D3-526-dtype-torch.float32: 4 reg
         B  pt_eager      copy
0    100.0  0.022901  0.035566
1   1000.0  0.133395  0.206304
2  10000.0  1.317136  1.843776
3  20000.0  2.708800  3.693952

torch-cat-D1-14655-D2-116-D3-770-dtype-torch.float64: 2 reg
         B  pt_eager      copy
0    100.0  0.033610  0.048977
1   1000.0  0.253016  0.355499
2  10000.0  2.527584  3.365216
3  20000.0  5.111072  6.726208

@ngimel
Copy link
Copy Markdown
Collaborator

ngimel commented Apr 3, 2025

Registers are 4 byte, and the load instruction should use fewer of them. But maybe index computations exert register pressure.

@zhaozhul
Copy link
Copy Markdown
Contributor Author

zhaozhul commented Apr 3, 2025

Registers are 4 byte, and the load instruction should use fewer of them. But maybe index computations exert register pressure.

Yes, kILP doubled because of the smaller element size here, hence more register pressure:

    IndexType v_elementOffset[kILP];
    T reg_data[kILP];

8ac1794#diff-c5d250b30de1c137c7be334acce924451e59a89f7135467204b30fa1b2462e0aL256

@zhaozhul
Copy link
Copy Markdown
Contributor Author

zhaozhul commented Apr 3, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 3, 2025
@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

timocafe pushed a commit to timocafe/pytorch that referenced this pull request Apr 16, 2025
….cat (pytorch#150233)

Enable aligned vector loading for 2 bytes datatypes in torch.cat. Specifically:
1. reduce the vector length to 8 bytes for 2-byte types (fp16, bf16 etc)
2. enable through a conditional template

The reason why 8-byte vector loading was chosen for fp16 and bf16:
A 16-byte load results in heavier register overhead for 2-byte types (i.e. 4 registers per load for fp32 -> 8 registers per load for fp16). Therefore, to still get the benefit of vectorized loading, we reduced ALIGNED_VEC_LOAD_BYTES to 8 for fp16 and bf16.

### perf testing:

before:
```
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.float32:
         B  pt_eager      copy
0    100.0  0.022621  0.036162
1   1000.0  0.133616  0.207051
2  10000.0  1.326848  1.848768
3  20000.0  2.744544  3.692128
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.bfloat16:
         B  pt_eager      copy
0    100.0  0.022434  0.035477
1   1000.0  0.140608  0.144518
2  10000.0  1.303792  1.229584
3  20000.0  2.668288  2.436160
```

after:
```
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.float32:
         B  pt_eager      copy
0    100.0  0.022608  0.036328
1   1000.0  0.133861  0.207399
2  10000.0  1.325120  1.847136
3  20000.0  2.726528  3.693184
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.bfloat16:
         B  pt_eager      copy
0    100.0  0.019942  0.035482
1   1000.0  0.084858  0.144544
2  10000.0  0.924384  1.230672
3  20000.0  1.944448  2.436480

```

### bw analysis:
bw on fp16/bf16 got increased by 40%-50% for large tensors

before:
```
Bandwidth (GB/s) for ((16384, 16384), 1) int8;fp16;fp32;int32;fp64;long|869.87|1382.74|1956.46|1952.73|1969.03|1963.66
Bandwidth (GB/s) for ((4194304,), 0) int8;fp16;fp32;int32;fp64;long|568.43|926.53|1589.20|1567.52|1771.54|1783.68
Bandwidth (GB/s) for ((16777216,), 0) int8;fp16;fp32;int32;fp64;long|752.07|1269.50|1894.86|1900.85|1954.10|1955.08
Bandwidth (GB/s) for ((33554432,), 0) int8;fp16;fp32;int32;fp64;long|807.08|1354.69|1960.48|1962.45|1972.73|1973.85
Bandwidth (GB/s) for ((134217728,), 0) int8;fp16;fp32;int32;fp64;long|864.02|1398.02|1963.43|1955.32|1963.37|1969.96
```

after:
```
Bandwidth (GB/s) for ((16384, 16384), 1) int8;fp16;fp32;int32;fp64;long|873.08|1892.16|1954.35|1962.51|1962.03|1965.98
Bandwidth (GB/s) for ((4194304,), 0) int8;fp16;fp32;int32;fp64;long|575.13|1242.45|1576.37|1571.30|1769.94|1790.22
Bandwidth (GB/s) for ((16777216,), 0) int8;fp16;fp32;int32;fp64;long|742.92|1734.57|1887.99|1897.62|1940.99|1959.25
Bandwidth (GB/s) for ((33554432,), 0) int8;fp16;fp32;int32;fp64;long|802.60|1865.45|1952.64|1947.53|1974.47|1973.48
Bandwidth (GB/s) for ((134217728,), 0) int8;fp16;fp32;int32;fp64;long|865.32|1939.07|1965.72|1963.25|1969.06|1968.72
```

### Perf testing code:

```
# pyre-strict
from typing import List, Optional, Tuple

import click
import pandas as pd

import torch

# @manual=//triton:triton
import triton

# Usage: CUDA_VISIBLE_DEVICES=7 buck2 run @mode/opt //scripts/zhaozhu:cat_bench

@click.command()
@click.option("--data-type", type=str, default="bf16")
@click.option("--return-result", type=bool, default=False)
def main(
    data_type: str,
    return_result: bool,
) -> Optional[Tuple[List[triton.testing.Benchmark], List[pd.DataFrame]]]:
    """Benchmark ``torch.cat`` along dim 1 against a slice-copy baseline on CUDA.

    Three 2-D tensors of shape ``(B, D1)``, ``(B, D2)`` and ``(B, D3)`` are
    concatenated along dim 1.  ``D1``/``D2``/``D3`` are drawn at random once
    per invocation and ``B`` sweeps over a fixed list of batch sizes.  Two
    providers are timed:

    * ``pt_eager`` - ``torch.cat(tensors, dim=1)``
    * ``copy``     - allocate the output and ``copy_`` each input into its
      column slice

    Args:
        data_type: One of ``"fp32"``, ``"fp16"`` or ``"bf16"``.
        return_result: When true, ask the perf report for its data frames and
            return them together with the benchmark configs.

    Returns:
        ``(configs, df)`` when ``return_result`` is true, where ``df`` is
        whatever ``bench_cat.run(..., return_df=True)`` produced; otherwise
        ``None``.

    Raises:
        ValueError: If ``data_type`` is not one of the supported names.
    """
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    supported_dtypes = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    if data_type not in supported_dtypes:
        raise ValueError(f"Unsupported data type: {data_type}.")
    dtype = supported_dtypes[data_type]

    # Random column widths so the benchmark is not tuned to one specific
    # (possibly conveniently aligned) shape.
    D1 = int(torch.randint(low=10000, high=50000, size=(1,)).item())
    D2 = int(torch.randint(low=100, high=1000, size=(1,)).item())
    D3 = int(torch.randint(low=500, high=1000, size=(1,)).item())

    configs: List[triton.testing.Benchmark] = [
        triton.testing.Benchmark(
            x_names=["B"],
            x_vals=[100, 1000, 10000, 20000],
            line_arg="provider",
            line_vals=["pt_eager", "copy"],
            line_names=["pt_eager", "copy"],
            # One style per plotted line.
            styles=[("blue", "-"), ("green", "-")],
            ylabel="ms",
            plot_name=f"torch-cat-D1-{D1}-D2-{D2}-D3-{D3}-dtype-{dtype}",
            args={
                "D1": D1,
                "D2": D2,
                "D3": D3,
                "dtype": dtype,
            },
        )
    ]

    @triton.testing.perf_report(configs)
    def bench_cat(
        B: int,
        D1: int,
        D2: int,
        D3: int,
        dtype: torch.dtype,
        provider: str,
    ) -> float:
        """Time one provider for one batch size ``B`` and return its result."""
        warmup = 10
        rep = 3
        device = torch.device("cuda")

        def make_input(cols: int) -> torch.Tensor:
            # Uniform values in [-1, 1); contents are irrelevant to timing
            # but make the correctness check below meaningful.
            return torch.empty((B, cols), dtype=dtype, device=device).uniform_(
                -1.0, 1.0
            )

        tensors = [make_input(D1), make_input(D2), make_input(D3)]
        total_cols: int = sum(int(t.shape[1]) for t in tensors)

        def torch_copy(
            tensors: List[torch.Tensor], is_inplace: bool = True
        ) -> torch.Tensor:
            """Baseline: write each input into its column slice of the output."""
            out = torch.zeros([B, total_cols], dtype=dtype, device=device)
            col_idx = 0
            for t in tensors:
                width = t.shape[1]
                if is_inplace:
                    out[:, col_idx : col_idx + width].copy_(t)
                else:
                    out[:, col_idx : col_idx + width] = t
                col_idx += width
            return out

        def torch_cat(tensors: List[torch.Tensor]) -> torch.Tensor:
            return torch.cat(tensors, dim=1)

        # Sanity check before timing.  Both paths only move data, so the
        # results must match bit-for-bit.  ``assert_close`` replaces the
        # deprecated ``torch.testing.assert_allclose``.
        ref = torch_cat(tensors)
        real = torch_copy(tensors, is_inplace=False)
        torch.testing.assert_close(ref, real, rtol=0, atol=0)

        # NOTE: the former "stack" provider was dropped: it was never listed
        # in ``line_vals`` and ``torch.stack`` cannot combine inputs whose
        # widths differ, so it could not have run.
        providers = {
            "pt_eager": lambda: torch_cat(tensors),
            "copy": lambda: torch_copy(tensors),
        }
        if provider not in providers:
            raise ValueError(f"unsupported provider: {provider}")

        ms = triton.testing.do_bench(providers[provider], warmup=warmup, rep=rep)
        return ms

    df = bench_cat.run(print_data=True, return_df=return_result)

    if return_result:
        return configs, df
    return None

if __name__ == "__main__":
    main()
```

and bw analysis code is from: pytorch#102815

Pull Request resolved: pytorch#150233
Approved by: https://github.com/ngimel
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
….cat (pytorch#150233)

Enable aligned vector loading for 2 bytes datatypes in torch.cat. Specifically:
1. reduce the vector length to 8 bytes for 2-byte types (fp16, bf16 etc)
2. enable through a conditional template

The reason why 8-byte vector loading was chosen for fp16 and bf16:
A 16-byte load results in heavier register overhead for 2-byte types (i.e. 4 registers per load for fp32 -> 8 registers per load for fp16). Therefore, to still get the benefit of vectorized loading, we reduced ALIGNED_VEC_LOAD_BYTES to 8 for fp16 and bf16.

### perf testing:

before:
```
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.float32:
         B  pt_eager      copy
0    100.0  0.022621  0.036162
1   1000.0  0.133616  0.207051
2  10000.0  1.326848  1.848768
3  20000.0  2.744544  3.692128
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.bfloat16:
         B  pt_eager      copy
0    100.0  0.022434  0.035477
1   1000.0  0.140608  0.144518
2  10000.0  1.303792  1.229584
3  20000.0  2.668288  2.436160
```

after:
```
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.float32:
         B  pt_eager      copy
0    100.0  0.022608  0.036328
1   1000.0  0.133861  0.207399
2  10000.0  1.325120  1.847136
3  20000.0  2.726528  3.693184
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.bfloat16:
         B  pt_eager      copy
0    100.0  0.019942  0.035482
1   1000.0  0.084858  0.144544
2  10000.0  0.924384  1.230672
3  20000.0  1.944448  2.436480

```

### bw analysis:
bw on fp16/bf16 got increased by 40%-50% for large tensors

before:
```
Bandwidth (GB/s) for ((16384, 16384), 1) int8;fp16;fp32;int32;fp64;long|869.87|1382.74|1956.46|1952.73|1969.03|1963.66
Bandwidth (GB/s) for ((4194304,), 0) int8;fp16;fp32;int32;fp64;long|568.43|926.53|1589.20|1567.52|1771.54|1783.68
Bandwidth (GB/s) for ((16777216,), 0) int8;fp16;fp32;int32;fp64;long|752.07|1269.50|1894.86|1900.85|1954.10|1955.08
Bandwidth (GB/s) for ((33554432,), 0) int8;fp16;fp32;int32;fp64;long|807.08|1354.69|1960.48|1962.45|1972.73|1973.85
Bandwidth (GB/s) for ((134217728,), 0) int8;fp16;fp32;int32;fp64;long|864.02|1398.02|1963.43|1955.32|1963.37|1969.96
```

after:
```
Bandwidth (GB/s) for ((16384, 16384), 1) int8;fp16;fp32;int32;fp64;long|873.08|1892.16|1954.35|1962.51|1962.03|1965.98
Bandwidth (GB/s) for ((4194304,), 0) int8;fp16;fp32;int32;fp64;long|575.13|1242.45|1576.37|1571.30|1769.94|1790.22
Bandwidth (GB/s) for ((16777216,), 0) int8;fp16;fp32;int32;fp64;long|742.92|1734.57|1887.99|1897.62|1940.99|1959.25
Bandwidth (GB/s) for ((33554432,), 0) int8;fp16;fp32;int32;fp64;long|802.60|1865.45|1952.64|1947.53|1974.47|1973.48
Bandwidth (GB/s) for ((134217728,), 0) int8;fp16;fp32;int32;fp64;long|865.32|1939.07|1965.72|1963.25|1969.06|1968.72
```

### Perf testing code:

```
# pyre-strict
from typing import List, Optional, Tuple

import click
import pandas as pd

import torch

# @manual=//triton:triton
import triton

# Usage: CUDA_VISIBLE_DEVICES=7 buck2 run @mode/opt //scripts/zhaozhu:cat_bench

@click.command()
@click.option("--data-type", type=str, default="bf16")
@click.option("--return-result", type=bool, default=False)
def main(
    data_type: str,
    return_result: bool,
) -> Optional[Tuple[List[triton.testing.Benchmark], List[pd.DataFrame]]]:
    """Benchmark ``torch.cat`` along dim 1 against a slice-copy baseline on CUDA.

    Three 2-D tensors of shape ``(B, D1)``, ``(B, D2)`` and ``(B, D3)`` are
    concatenated along dim 1.  ``D1``/``D2``/``D3`` are drawn at random once
    per invocation and ``B`` sweeps over a fixed list of batch sizes.  Two
    providers are timed:

    * ``pt_eager`` - ``torch.cat(tensors, dim=1)``
    * ``copy``     - allocate the output and ``copy_`` each input into its
      column slice

    Args:
        data_type: One of ``"fp32"``, ``"fp16"`` or ``"bf16"``.
        return_result: When true, ask the perf report for its data frames and
            return them together with the benchmark configs.

    Returns:
        ``(configs, df)`` when ``return_result`` is true, where ``df`` is
        whatever ``bench_cat.run(..., return_df=True)`` produced; otherwise
        ``None``.

    Raises:
        ValueError: If ``data_type`` is not one of the supported names.
    """
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    supported_dtypes = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    if data_type not in supported_dtypes:
        raise ValueError(f"Unsupported data type: {data_type}.")
    dtype = supported_dtypes[data_type]

    # Random column widths so the benchmark is not tuned to one specific
    # (possibly conveniently aligned) shape.
    D1 = int(torch.randint(low=10000, high=50000, size=(1,)).item())
    D2 = int(torch.randint(low=100, high=1000, size=(1,)).item())
    D3 = int(torch.randint(low=500, high=1000, size=(1,)).item())

    configs: List[triton.testing.Benchmark] = [
        triton.testing.Benchmark(
            x_names=["B"],
            x_vals=[100, 1000, 10000, 20000],
            line_arg="provider",
            line_vals=["pt_eager", "copy"],
            line_names=["pt_eager", "copy"],
            # One style per plotted line.
            styles=[("blue", "-"), ("green", "-")],
            ylabel="ms",
            plot_name=f"torch-cat-D1-{D1}-D2-{D2}-D3-{D3}-dtype-{dtype}",
            args={
                "D1": D1,
                "D2": D2,
                "D3": D3,
                "dtype": dtype,
            },
        )
    ]

    @triton.testing.perf_report(configs)
    def bench_cat(
        B: int,
        D1: int,
        D2: int,
        D3: int,
        dtype: torch.dtype,
        provider: str,
    ) -> float:
        """Time one provider for one batch size ``B`` and return its result."""
        warmup = 10
        rep = 3
        device = torch.device("cuda")

        def make_input(cols: int) -> torch.Tensor:
            # Uniform values in [-1, 1); contents are irrelevant to timing
            # but make the correctness check below meaningful.
            return torch.empty((B, cols), dtype=dtype, device=device).uniform_(
                -1.0, 1.0
            )

        tensors = [make_input(D1), make_input(D2), make_input(D3)]
        total_cols: int = sum(int(t.shape[1]) for t in tensors)

        def torch_copy(
            tensors: List[torch.Tensor], is_inplace: bool = True
        ) -> torch.Tensor:
            """Baseline: write each input into its column slice of the output."""
            out = torch.zeros([B, total_cols], dtype=dtype, device=device)
            col_idx = 0
            for t in tensors:
                width = t.shape[1]
                if is_inplace:
                    out[:, col_idx : col_idx + width].copy_(t)
                else:
                    out[:, col_idx : col_idx + width] = t
                col_idx += width
            return out

        def torch_cat(tensors: List[torch.Tensor]) -> torch.Tensor:
            return torch.cat(tensors, dim=1)

        # Sanity check before timing.  Both paths only move data, so the
        # results must match bit-for-bit.  ``assert_close`` replaces the
        # deprecated ``torch.testing.assert_allclose``.
        ref = torch_cat(tensors)
        real = torch_copy(tensors, is_inplace=False)
        torch.testing.assert_close(ref, real, rtol=0, atol=0)

        # NOTE: the former "stack" provider was dropped: it was never listed
        # in ``line_vals`` and ``torch.stack`` cannot combine inputs whose
        # widths differ, so it could not have run.
        providers = {
            "pt_eager": lambda: torch_cat(tensors),
            "copy": lambda: torch_copy(tensors),
        }
        if provider not in providers:
            raise ValueError(f"unsupported provider: {provider}")

        ms = triton.testing.do_bench(providers[provider], warmup=warmup, rep=rep)
        return ms

    df = bench_cat.run(print_data=True, return_df=return_result)

    if return_result:
        return configs, df
    return None

if __name__ == "__main__":
    main()
```

and bw analysis code is from: pytorch#102815

Pull Request resolved: pytorch#150233
Approved by: https://github.com/ngimel
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: cuda release notes category topic: performance topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants