
[Distributed][auto-parallel] Automatic Partitioning #344

Merged
soodoshll merged 9 commits into hidet-org:auto-parallel from soodoshll:auto-parallel-2
Aug 22, 2023

Conversation

@soodoshll
Collaborator

This PR contains:

  1. An ILP solver that finds the optimal partitioning plan for a given model, minimizing communication cost while ensuring that the parameters do not exceed the memory budget;
  2. hidet.distributed.partition, which partitions the whole flow graph according to the plan from 1) and saves the partitions to disk;
  3. hidet.distributed.load_partition, which loads a partition from disk onto the desired device;
  4. a launch script;
  5. a small ResNet example (example/distributed/resnet.py), which can be run with python -m hidet.distributed.launch [n_gpus] resnet.py. The memory budget is set to 24 MiB, which is less than the weight size of a ResNet-18 model, so the example exercises tensor parallelism.
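The ILP in 1) can be pictured with a toy model: each parameter gets a binary variable (shard vs. replicate), the objective sums the communication costs incurred by sharding, and a constraint keeps resident parameter bytes under the budget. The PR solves this with the mip package; below is only a brute-force sketch of the same optimization, with all names (partition_plan, comm_cost) hypothetical:

```python
from itertools import product

def partition_plan(params, comm_cost, budget, num_shards=2):
    """Toy stand-in for the ILP: enumerate 0/1 shard decisions.

    params:    list of parameter sizes in bytes
    comm_cost: per-parameter communication cost incurred if sharded
    budget:    per-device memory budget in bytes
    Returns the cheapest feasible assignment (1 = shard, 0 = replicate).
    """
    best, best_cost = None, float('inf')
    for plan in product([0, 1], repeat=len(params)):
        # replicated params cost full size, sharded ones 1/num_shards of it
        mem = sum(p if s == 0 else p // num_shards
                  for p, s in zip(params, plan))
        if mem > budget:
            continue  # violates the memory budget constraint
        cost = sum(c * s for c, s in zip(comm_cost, plan))
        if cost < best_cost:
            best, best_cost = plan, cost
    return best, best_cost

plan, cost = partition_plan([100, 200], comm_cost=[5, 1], budget=250)
print(plan, cost)  # → (0, 1) 1: shard only the cheap-to-shard parameter
```

The real solver scales this with ILP machinery instead of enumeration, which is exactly why solve time becomes an issue for large models (see below).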

Several issues I plan to solve in the future:

  1. ILP solving is slow for large models like LLaMA. We need to coalesce nodes before sending the graph to the ILP solver;
  2. Multiple processes on one machine compiling in parallel often trigger conflicts in the local filesystem. We should implement a locking mechanism in the future.
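The locking mechanism mentioned in 2) is usually an exclusive lock file guarding the shared compilation cache. A minimal stdlib sketch of the idea (the eventual hidet mechanism may differ, e.g. by using the filelock package):

```python
import os
import time
from contextlib import contextmanager

@contextmanager
def cache_lock(path, poll=0.05):
    """Spin until we exclusively create the lock file, then hold it."""
    lock = path + '.lock'
    while True:
        try:
            # O_EXCL makes creation atomic: exactly one process wins
            fd = os.open(lock, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            break
        except FileExistsError:
            time.sleep(poll)  # another process is compiling; wait
    try:
        yield
    finally:
        os.close(fd)
        os.remove(lock)  # release so the next process can proceed

with cache_lock('/tmp/hidet_demo_cache'):
    pass  # compile / write cache entries here
```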

@soodoshll
Collaborator Author

@yaoyaoding @xinli-git this PR is ready for review.

x = hidet.zeros([32, 3, 224, 224], device='cuda')
opt_graph = hidet.graph.optimize(flow_graph)
compiled = opt_graph.build()
print(compiled(x))
Contributor


Consider adding an all_close check against y_truth?
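The suggested check compares the compiled graph's output against a reference within a tolerance. A stdlib approximation of numpy.allclose's behavior, with a hypothetical all_close helper:

```python
import math

def all_close(xs, ys, rtol=1e-5, atol=1e-8):
    """Elementwise closeness check in the spirit of numpy.allclose."""
    return all(math.isclose(x, y, rel_tol=rtol, abs_tol=atol)
               for x, y in zip(xs, ys))

# outputs that differ only by floating-point noise still pass
print(all_close([1.0, 2.0], [1.0, 2.0 + 1e-9]))  # → True
```

In the example itself one would typically compare compiled(x) against the eager (uncompiled) output of the same graph.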

parser.add_argument('script_args', nargs=argparse.REMAINDER)
args = parser.parse_args()

procs = []
Contributor


Would it be possible to follow this convention and make it a CLI command?
https://github.com/hidet-org/hidet/blob/main/setup.py#L40

Member


The current launch method is fine with me. Another option is to add a sub-command like

$ hidet dist launch resnet.py

from .rule import op_shard_rule_search
from .shard import OpShardSpec, TensorShardSpec, connect, node_comm_cost

# I copied it from compiled_graph.py
Contributor


nitpick: possibly remove this

for node in tqdm.tqdm(g.nodes):
node_str = str(node)
if node_str not in cache:
cache[node_str] = op_shard_rule_search(node, num_shards)
Contributor


If node_str can uniquely identify the sharding rule, maybe it would be cleaner to add an lru_cache decorator to the op_shard_rule_search function?
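The suggestion would look roughly like the sketch below. One caveat: lru_cache requires hashable arguments, and operator nodes typically are not, which is likely why the manual dict keyed by str(node) exists; a decorated wrapper taking the string form sidesteps that (shard_rule_for is a hypothetical name, and the body is a stand-in for the real search):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def shard_rule_for(node_str, num_shards):
    # stand-in for op_shard_rule_search(node, num_shards)
    return f'rule({node_str}, {num_shards})'

# key by the node's string form, since nodes themselves are unhashable
a = shard_rule_for('conv2d', 4)
b = shard_rule_for('conv2d', 4)
assert a is b  # second call hits the cache, same object returned
print(shard_rule_for.cache_info().hits)  # → 1
```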

sharded = xsum(p_vars)
param_mem += (num_shards - ((num_shards - 1) * sharded)) * (p.nbytes // num_shards)
param_tot += p.nbytes
print(f"Total parameter size: {param_tot/1024**3} GiB")
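The param_mem expression above interpolates between the two cases of the binary variable: when sharded == 1 it contributes p.nbytes // num_shards (one shard's worth), and when sharded == 0 it contributes num_shards * (p.nbytes // num_shards), i.e. the full replicated size up to integer rounding. A quick check of that arithmetic:

```python
def shard_mem(nbytes, num_shards, sharded):
    # sharded is the 0/1 indicator from the ILP's binary variables
    return (num_shards - (num_shards - 1) * sharded) * (nbytes // num_shards)

assert shard_mem(1024, 4, sharded=1) == 256   # sharded: 1/4 of the bytes
assert shard_mem(1024, 4, sharded=0) == 1024  # replicated: full size
print('ok')
```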
Contributor


print -> logging.info, like here:

logger = logging.Logger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())

https://github.com/hidet-org/hidet/blob/main/python/hidet/drivers/build_task.py

requirements.txt Outdated
requests

# for auto-parallelization
mip
Contributor


I recommend that we make the distributed feature an extra in setup.py:

    extras_require={
        'distributed': ['filelock', 'mip', ...],
    },

Member


Yeah, I agree with Xin that it is better to put extra dependencies like 'mip' into extras_require and have the user install them via pip install hidet[distributed].

Member

@yaoyaoding yaoyaoding left a comment


I went over all the code roughly and it looks good to me. Thanks @soodoshll!

parser.add_argument('script_args', nargs=argparse.REMAINDER)
args = parser.parse_args()

procs = []
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current launch method is fine to me. Another option is to add a sub-command like

$ hidet dist launch resnet.py

requirements.txt Outdated
requests

# for auto-parallelization
mip No newline at end of file
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree with Xin that it is better to put the extra dependency like 'mip' to extra and require the user to install via pip install hidet[distributed].

Contributor

@xinli-git xinli-git left a comment


Thanks @soodoshll, looking forward to trying it out on some LLMs :)

@soodoshll soodoshll merged commit a12680d into hidet-org:auto-parallel Aug 22, 2023
vadiklyutiy pushed a commit that referenced this pull request Jul 22, 2024
vadiklyutiy pushed a commit that referenced this pull request Jul 23, 2024
vadiklyutiy pushed a commit that referenced this pull request Dec 26, 2024