sharding empty uses double peak mem and residual mem #13513

@chenyuxyz

Description

from tinygrad import Tensor, Device
ds = tuple(f"{Device.DEFAULT}:{i}" for i in range(8))
t = Tensor.empty(10**9).shard(ds, 0).realize()
scheduled 16 kernels in 6.19 ms (221 uops in cache)
*** AMD        1 E_1953125_16_4                                 arg  2 mem   4.50 GB tm    478.56us/     0.48ms (      0 GFLOPS 2090|2090   GB/s) 
*** AMD        2 E_1953125_16_4n1                               arg  2 mem   5.00 GB tm    483.28us/     0.96ms (      0 GFLOPS 2069|2069   GB/s) 
*** AMD        3 E_1953125_16_4n2                               arg  2 mem   5.50 GB tm    485.12us/     1.45ms (      0 GFLOPS 2061|2061   GB/s) 
*** AMD        4 E_1953125_16_4n3                               arg  2 mem   6.00 GB tm    486.72us/     1.93ms (      0 GFLOPS 2055|2055   GB/s) 
*** AMD        5 E_1953125_16_4n4                               arg  2 mem   6.50 GB tm    486.88us/     2.42ms (      0 GFLOPS 2054|2054   GB/s) 
*** AMD        6 E_1953125_16_4n5                               arg  2 mem   7.00 GB tm    489.72us/     2.91ms (      0 GFLOPS 2042|2042   GB/s) 
*** AMD        7 E_1953125_16_4n6                               arg  2 mem   7.50 GB tm    487.40us/     3.40ms (      0 GFLOPS 2052|2052   GB/s) 
*** AMD        8 E_1953125_16_4n7                               arg  2 mem   8.00 GB tm    492.20us/     3.89ms (      0 GFLOPS 2032|2032   GB/s) 
*** AMD        9 E_1953125_16_4n8                               arg  2 mem   4.50 GB tm    491.52us/     4.38ms (      0 GFLOPS 2035|2035   GB/s) 
*** AMD:1     10 xfer  500.00M,   AMD:1 <- AMD                  arg  2 mem   4.50 GB tm   8730.90us/    13.11ms (      0 GFLOPS   57|57     GB/s) 
*** AMD:2     11 xfer  500.00M,   AMD:2 <- AMD                  arg  2 mem   4.50 GB tm   8731.92us/    21.84ms (      0 GFLOPS   57|57     GB/s) 
*** AMD:3     12 xfer  500.00M,   AMD:3 <- AMD                  arg  2 mem   4.50 GB tm   8710.56us/    30.55ms (      0 GFLOPS   57|57     GB/s) 
*** AMD:4     13 xfer  500.00M,   AMD:4 <- AMD                  arg  2 mem   4.50 GB tm   8702.11us/    39.26ms (      0 GFLOPS   57|57     GB/s) 
*** AMD:5     14 xfer  500.00M,   AMD:5 <- AMD                  arg  2 mem   4.50 GB tm   8687.91us/    47.94ms (      0 GFLOPS   58|58     GB/s) 
*** AMD:6     15 xfer  500.00M,   AMD:6 <- AMD                  arg  2 mem   4.50 GB tm   8707.93us/    56.65ms (      0 GFLOPS   57|57     GB/s) 
*** AMD:7     16 xfer  500.00M,   AMD:7 <- AMD                  arg  2 mem   4.50 GB tm   8703.59us/    65.36ms (      0 GFLOPS   57|57     GB/s) 

Because it creates the full empty tensor on device 0 first and then copies each shard out, peak mem is 8 GB and it ends at 4.5 GB. This should create the empty shard directly on each device and use 4 GB total.
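For reference, a rough sanity check of the numbers above. This is only back-of-the-envelope arithmetic, not tinygrad code. It assumes `Tensor.empty` defaults to float32 (4 bytes per element) and that the log reports decimal GB, which is consistent with the `xfer 500.00M` lines. The observed peak and final values are read off the log; all variable names are made up for illustration.

```python
# Rough memory accounting for the repro above (not tinygrad code).
# Assumption: default dtype is float32, so 4 bytes per element.
NUM_ELEMENTS = 10**9
BYTES_PER_ELEMENT = 4
NUM_DEVICES = 8
GB = 10**9  # assumption: the log uses decimal GB (matches "xfer 500.00M")

# Size of the full tensor and of one shard along axis 0.
full_gb = NUM_ELEMENTS * BYTES_PER_ELEMENT / GB
shard_gb = full_gb / NUM_DEVICES

# What we'd expect if each device only allocated its own shard.
expected_total_gb = shard_gb * NUM_DEVICES

# Values read from the "mem" column of the log above.
observed_peak_gb = 8.0
observed_final_gb = 4.5

peak_ratio = observed_peak_gb / expected_total_gb
residual_extra_gb = observed_final_gb - expected_total_gb

print(f"full tensor: {full_gb} GB, one shard: {shard_gb} GB")
print(f"expected total: {expected_total_gb} GB")
print(f"peak is {peak_ratio}x expected, residual is {residual_extra_gb} GB over")
```

So the peak is twice the expected total, and the residual overshoot equals the size of exactly one shard.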
