[aten] 8 bytes aligned vector loads for bf16 and fp16 dtypes in torch.cat #150233
zhaozhul wants to merge 1 commit into pytorch:main from
Conversation
Did you try actually launching half kernels with the same 128-bit loads? Is your improvement just due to the vectorized path that narrow types didn't take previously, so that you don't need the extra 64-bit instantiation?
Thanks Natalia. We originally tried to enable such vectorized loads with 128-bit width as well, but there was a major perf regression: for fp16 types, torch.cat was even slower than copy. This might be related to increased register accesses (since the bytes-per-element is half of fp32, hence double the register accesses during vectorized loads).
Registers are 4 bytes, and the load instruction should use fewer of them. But maybe index computations exert register pressure.
Yes, kILP doubled due to the smaller byte size here, hence more register pressure: 8ac1794#diff-c5d250b30de1c137c7be334acce924451e59a89f7135467204b30fa1b2462e0aL256
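For concreteness, here is a minimal CUDA sketch of the trade-off being discussed. It is not the actual ATen torch.cat kernel; the names `AlignedVecLoadBytes` and `vectorized_copy` are only illustrative. The per-thread unroll factor kILP is the vector width in bytes divided by `sizeof(T)`, so with 16-byte loads a 2-byte type would unroll 8 elements per thread, while capping the width at 8 bytes keeps kILP at 4:

```
// Minimal sketch (assumed names; not the actual ATen torch.cat kernel):
// choose the vector-load width from the element size, then derive the
// per-thread unroll factor kILP from it.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

// 16-byte loads for 4-byte element types, 8-byte loads for 2-byte types.
template <typename T>
struct AlignedVecLoadBytes {
  static constexpr int value = sizeof(T) >= 4 ? 16 : 8;
};

// Map a byte width to a vector type: uint4 occupies four 32-bit registers,
// uint2 only two.
template <int Bytes> struct VecType;
template <> struct VecType<16> { using type = uint4; };
template <> struct VecType<8>  { using type = uint2; };

template <typename T>
__global__ void vectorized_copy(T* __restrict__ dst, const T* __restrict__ src,
                                int n) {
  constexpr int kVecBytes = AlignedVecLoadBytes<T>::value;
  constexpr int kILP = kVecBytes / sizeof(T);  // elements handled per thread
  using Vec = typename VecType<kVecBytes>::type;

  int i = (blockIdx.x * blockDim.x + threadIdx.x) * kILP;
  if (i + kILP <= n) {
    // One aligned vector load/store moves kILP elements; src and dst are
    // assumed kVecBytes-aligned (cudaMalloc guarantees that here).
    *reinterpret_cast<Vec*>(dst + i) = *reinterpret_cast<const Vec*>(src + i);
  } else {
    for (; i < n; ++i) dst[i] = src[i];  // scalar tail
  }
}

int main() {
  constexpr int n = 1 << 20;
  __half *src, *dst;
  cudaMalloc(&src, n * sizeof(__half));
  cudaMalloc(&dst, n * sizeof(__half));
  cudaMemset(src, 0, n * sizeof(__half));

  constexpr int kILP = AlignedVecLoadBytes<__half>::value / sizeof(__half);
  const int threads = 256;
  const int blocks = (n / kILP + threads - 1) / threads;
  vectorized_copy<__half><<<blocks, threads>>>(dst, src, n);
  cudaDeviceSynchronize();

  // With 8-byte loads kILP == 4 for __half; 16-byte loads would make it 8.
  printf("kILP for __half = %d\n", kILP);
  cudaFree(src);
  cudaFree(dst);
  return 0;
}
```

The sketch only shows how kILP follows from the chosen load width; whether the extra registers come from the loads themselves or from the doubled unroll (index math and per-element state) is exactly the question debated above.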
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[aten] 8 bytes aligned vector loads for bf16 and fp16 dtypes in torch.cat (pytorch#150233)

Enable aligned vector loading for 2-byte datatypes in torch.cat. Specifically:

1. reduce the vector length to 8 bytes for 2-byte types (fp16, bf16, etc.)
2. enable it through a conditional template

The reason why 8-byte vector loading was chosen for fp16 and bf16: a 16-byte load results in heavier register overhead (i.e. 4 registers per load for fp32 -> 8 registers per load for fp16). Therefore, to keep the benefits of vectorized loading, we reduced ALIGNED_VEC_LOAD_BYTES to 8 for fp16 and bf16.

### perf testing:

before:
```
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.float32:
         B  pt_eager      copy
0    100.0  0.022621  0.036162
1   1000.0  0.133616  0.207051
2  10000.0  1.326848  1.848768
3  20000.0  2.744544  3.692128
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.bfloat16:
         B  pt_eager      copy
0    100.0  0.022434  0.035477
1   1000.0  0.140608  0.144518
2  10000.0  1.303792  1.229584
3  20000.0  2.668288  2.436160
```

after:
```
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.float32:
         B  pt_eager      copy
0    100.0  0.022608  0.036328
1   1000.0  0.133861  0.207399
2  10000.0  1.325120  1.847136
3  20000.0  2.726528  3.693184
torch-cat-D1-30108-D2-624-D3-772-dtype-torch.bfloat16:
         B  pt_eager      copy
0    100.0  0.019942  0.035482
1   1000.0  0.084858  0.144544
2  10000.0  0.924384  1.230672
3  20000.0  1.944448  2.436480
```

### bw analysis:

bw on fp16/bf16 increased by 40%-50% for large tensors

before:
```
Bandwidth (GB/s) for ((16384, 16384), 1) int8;fp16;fp32;int32;fp64;long|869.87|1382.74|1956.46|1952.73|1969.03|1963.66
Bandwidth (GB/s) for ((4194304,), 0) int8;fp16;fp32;int32;fp64;long|568.43|926.53|1589.20|1567.52|1771.54|1783.68
Bandwidth (GB/s) for ((16777216,), 0) int8;fp16;fp32;int32;fp64;long|752.07|1269.50|1894.86|1900.85|1954.10|1955.08
Bandwidth (GB/s) for ((33554432,), 0) int8;fp16;fp32;int32;fp64;long|807.08|1354.69|1960.48|1962.45|1972.73|1973.85
Bandwidth (GB/s) for ((134217728,), 0) int8;fp16;fp32;int32;fp64;long|864.02|1398.02|1963.43|1955.32|1963.37|1969.96
```

after:
```
Bandwidth (GB/s) for ((16384, 16384), 1) int8;fp16;fp32;int32;fp64;long|873.08|1892.16|1954.35|1962.51|1962.03|1965.98
Bandwidth (GB/s) for ((4194304,), 0) int8;fp16;fp32;int32;fp64;long|575.13|1242.45|1576.37|1571.30|1769.94|1790.22
Bandwidth (GB/s) for ((16777216,), 0) int8;fp16;fp32;int32;fp64;long|742.92|1734.57|1887.99|1897.62|1940.99|1959.25
Bandwidth (GB/s) for ((33554432,), 0) int8;fp16;fp32;int32;fp64;long|802.60|1865.45|1952.64|1947.53|1974.47|1973.48
Bandwidth (GB/s) for ((134217728,), 0) int8;fp16;fp32;int32;fp64;long|865.32|1939.07|1965.72|1963.25|1969.06|1968.72
```

### Perf testing code:
```
# pyre-strict
from typing import List, Optional, Tuple

import click
import pandas as pd
import torch

# @Manual=//triton:triton
import triton

# CUDA_VISIBLE_DEVICEs=7 buck2 run @mode/opt //scripts/zhaozhu:cat_bench


@click.command()
@click.option("--data-type", type=str, default="bf16")
@click.option("--return-result", type=bool, default=False)
def main(
    data_type: str,
    return_result: bool,
) -> Optional[Tuple[List[triton.testing.Benchmark], List[pd.DataFrame]]]:
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    if data_type == "fp32":
        dtype = torch.float32
    elif data_type == "fp16":
        dtype = torch.float16
    elif data_type == "bf16":
        dtype = torch.bfloat16
    else:
        raise ValueError(f"Unsupported data type: {data_type}.")

    D1 = int(torch.randint(low=10000, high=50000, size=(1,)).item())
    D2 = int(torch.randint(low=100, high=1000, size=(1,)).item())
    D3 = int(torch.randint(low=500, high=1000, size=(1,)).item())

    configs: List[triton.testing.Benchmark] = [
        triton.testing.Benchmark(
            x_names=["B"],
            x_vals=[100, 1000, 10000, 20000],
            line_arg="provider",
            line_vals=["pt_eager", "copy"],
            line_names=["pt_eager", "copy"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="ms",
            plot_name=f"torch-cat-D1-{D1}-D2-{D2}-D3-{D3}-dtype-{dtype}",
            args={
                "D1": D1,
                "D2": D2,
                "D3": D3,
                "dtype": dtype,
            },
        )
    ]

    @triton.testing.perf_report(configs)
    def bench_cat(
        B: int,
        D1: int,
        D2: int,
        D3: int,
        dtype: torch.dtype,
        provider: str,
    ) -> float:
        warmup = 10
        rep = 3
        tensors = []
        a = torch.empty(
            # (B, 30108),
            (B, D1),
            dtype=dtype,
            device=torch.device("cuda"),
        ).uniform_(-1.0, 1.0)
        b = torch.empty(
            # (B, 624),
            (B, D2),
            dtype=dtype,
            device=torch.device("cuda"),
        ).uniform_(-1.0, 1.0)
        c = torch.empty(
            # (B, 772),
            (B, D3),
            dtype=dtype,
            device=torch.device("cuda"),
        ).uniform_(-1.0, 1.0)
        tensors = [a, b, c]
        total_cols: int = int(a.shape[1] + b.shape[1] + c.shape[1])

        def torch_copy(
            tensors: List[torch.Tensor], is_inplace: bool = True
        ) -> torch.Tensor:
            f = torch.zeros([B, total_cols], dtype=dtype, device=torch.device("cuda"))
            col_idx = 0
            for t in tensors:
                temp = f[:, col_idx : col_idx + t.shape[1]]
                if is_inplace:
                    temp.copy_(t)
                else:
                    f[:, col_idx : col_idx + t.shape[1]] = t
                col_idx += t.shape[1]
            return f

        def torch_cat(tensors: List[torch.Tensor]) -> torch.Tensor:
            return torch.cat(tensors, dim=1)

        ref = torch_cat(tensors)
        real = torch_copy(tensors, is_inplace=False)
        torch.testing.assert_allclose(ref, real)

        if provider == "pt_eager":
            fn = lambda: torch_cat(tensors)  # noqa E731
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms
        elif provider == "stack":

            def torch_stack(tensors: List[torch.Tensor]) -> torch.Tensor:
                return torch.stack(tensors, dim=1).view(-1, total_cols)

            fn = lambda: torch_stack(tensors)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms
        elif provider == "copy":
            fn = lambda: torch_copy(tensors)
            ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
            return ms
        else:
            raise ValueError(f"unsupported provider: {provider}")

    df = bench_cat.run(print_data=True, return_df=return_result)
    if return_result:
        return configs, df


if __name__ == "__main__":
    main()
```

and bw analysis code is from: pytorch#102815

Pull Request resolved: pytorch#150233
Approved by: https://github.com/ngimel
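On the "Bandwidth (GB/s)" numbers above: the analysis script from pytorch#102815 is not reproduced here. As a rough sketch of how such an effective-bandwidth figure can be derived (assuming the usual read-plus-write byte count over CUDA-event-measured time; the element count and dtype size below are illustrative, matching the largest 1-D fp16/bf16 case):

```
// Hedged sketch of an effective-bandwidth measurement for a device-to-device
// copy of 2-byte elements (illustrative only; the actual analysis code is in
// pytorch#102815).
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const size_t n = size_t(1) << 27;   // 134217728 elements, as in the last row above
  const size_t bytes = n * 2;         // 2-byte dtype (fp16/bf16)
  void *src = nullptr, *dst = nullptr;
  cudaMalloc(&src, bytes);
  cudaMalloc(&dst, bytes);

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start);
  cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToDevice);
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);

  // A concat, like a plain copy, reads and writes every element once.
  const double gb_per_s = (2.0 * bytes) / (ms * 1e-3) / 1e9;
  printf("effective bandwidth: %.2f GB/s\n", gb_per_s);

  cudaFree(src);
  cudaFree(dst);
  return 0;
}
```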