[AIR] Add distributed torch_geometric example #23580
amogkam merged 13 commits into ray-project:master from …h-geometric-examples
Conversation

python/ray/ml/examples/pytorch_geometric/distributed_reddit_example.py (outdated; resolved)
    )

    # Disable distributed sampler since the train_loader has already been split above.
    train_loader = train.torch.prepare_data_loader(train_loader, add_dist_sampler=False)
Dumb question: why do we do the split separately, instead of combining it into prepare_data_loader?
You need to use torch geometric's NeighborSampler for sampling subgraphs from the overall graph, instead of the standard DistributedSampler.
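To make the ordering concrete, here is a hypothetical, torch-free sketch of the flow being discussed: each worker gets its own shard of training node ids up front (that shard would then feed a NeighborSampler), so the DistributedSampler that prepare_data_loader would normally add has to be disabled. The helper name `split_node_ids` is illustrative, not from the PR.

```python
# Hypothetical sketch (pure Python, no torch): shard training node ids
# across workers manually, since NeighborSampler handles the actual
# subgraph sampling. The automatic DistributedSampler added by
# prepare_data_loader must then be turned off with add_dist_sampler=False.

def split_node_ids(node_ids, num_workers, worker_rank):
    """Give each worker a contiguous shard of the training node ids."""
    shard = (len(node_ids) + num_workers - 1) // num_workers
    return node_ids[worker_rank * shard:(worker_rank + 1) * shard]

node_ids = list(range(10))
shards = [split_node_ids(node_ids, num_workers=3, worker_rank=r) for r in range(3)]
print(shards)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# Each shard would feed a NeighborSampler-backed loader; afterwards:
# train.torch.prepare_data_loader(loader, add_dist_sampler=False)
```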
    return x.log_softmax(dim=-1)

    @torch.no_grad()
    def inference(self, x_all, subgraph_loader):
Is this planned to be used for the predictor implementation?
Eventually yes, but the challenge for prediction is how to add "fresh data" to the graph to run inference on.
        scaling_config={"num_workers": num_workers, "use_gpu": use_gpu},
    )
    result = trainer.fit()
    print(result.metrics)
What does prediction look like?
Prediction is not supported for now; we need to be able to add "fresh data" to the existing graph and then re-run the inference algorithm on the new data.
    torch==1.9.0+cu111
    torchvision==0.10.0+cu111

    -f https://data.pyg.org/whl/torch-1.9.0+cu111.html
Curious, what is this for?
These are required dependencies for PyTorch Geometric.
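For context, a sketch of what such a requirements fragment looks like. The extra torch-scatter/torch-sparse/torch-geometric lines below are illustrative, not copied from this PR; pip's -f (--find-links) flag adds the given URL as an extra place to look for wheels, which is where PyTorch Geometric hosts compiled companion packages built against specific torch/CUDA versions.

```text
torch==1.9.0+cu111
torchvision==0.10.0+cu111

# Extra wheel index: PyTorch Geometric's compiled companion packages
# (e.g. torch-scatter, torch-sparse) are built per torch/CUDA version
# and hosted here rather than on PyPI.
-f https://data.pyg.org/whl/torch-1.9.0+cu111.html
torch-scatter
torch-sparse
torch-geometric
```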
python/ray/ml/examples/pytorch_geometric/distributed_reddit_example.py (outdated; resolved)
    self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, adjs):
        for i, (edge_index, _, size) in enumerate(adjs):
Can you comment a bit on the format of this adjs matrix? Especially: 1. what does size mean in this context? and 2. how do we make sure there are always enough hidden layers to handle the adjacency links in adjs?
Added a comment here, but more information is in the torch geometric docs.
For 2, we pass a sizes list to the NeighborSampler, so the length of this list should match the number of layers in the model.
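A hypothetical, torch-free sketch of that invariant (the helper name is illustrative): the sizes list handed to NeighborSampler has one entry per sampling hop, each hop produces one (edge_index, e_id, size) tuple in adjs, so len(sizes) must equal the number of conv layers in the model.

```python
# Hypothetical sketch (pure Python, no torch): each entry of `sizes` is a
# neighbor fan-out for one sampling hop, and each hop yields one
# (edge_index, e_id, size) tuple in `adjs`. `size` is the
# (num_source_nodes, num_target_nodes) shape of that hop's bipartite graph.

def sampler_matches_model(sizes, num_layers):
    """True when each GNN layer gets exactly one sampled hop."""
    return len(sizes) == num_layers

# e.g. a 2-layer GraphSAGE sampling 25 neighbors at hop 1 and 10 at hop 2:
print(sampler_matches_model([25, 10], num_layers=2))  # True
print(sampler_matches_model([25], num_layers=2))      # False
```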
        x = F.relu(x)
        xs.append(x.cpu())

    x_all = torch.cat(xs, dim=0)
This looks a bit weird to me; I think I am just clueless. If we overwrite the entire x_all here, we will only have features for the nodes that we scored with the last layer. It feels more appropriate to me to update the weights of xs in x_all, rather than simply x_all = ...?
This is just the inference code, so there are no weight updates. I think this works the same way as a standard feed-forward neural network: we only want the output of the last layer, and we don't care about the hidden states during inference.
OK, I understand this now. subgraph_loader actually samples a subgraph for every single node in the graph. So if there are n nodes in the graph, the inner loop will run n times; each time, we are essentially aggregating data from all nodes neighboring that specific node. At the end, torch.cat(xs) gives us a new, updated graph, since xs contains data for every single node at that point. Interesting design.
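The layer-wise pattern described above can be sketched without torch at all (everything below is an illustrative stand-in, not the PR's code): for each layer, iterate over batches that together cover all nodes, collect each batch's output in xs, then rebuild the full feature set before moving on to the next layer, so overwriting x_all is exactly right.

```python
# Hypothetical, torch-free sketch of layer-wise full-graph inference:
# one full pass over ALL nodes per layer; per-batch outputs accumulate
# in `xs`, and rebuilding x_all plays the role of torch.cat(xs, dim=0).

def layerwise_inference(x_all, layers, batch_size):
    for layer in layers:                        # one full pass per layer
        xs = []
        for start in range(0, len(x_all), batch_size):
            batch = x_all[start:start + batch_size]
            xs.extend(layer(v) for v in batch)  # stand-in for conv + relu
        x_all = xs                              # overwrite: every node was updated
    return x_all

# Two toy "layers" that just transform features; every node is touched
# on every pass, so nothing is lost by the overwrite.
result = layerwise_inference([1.0, 2.0, 3.0],
                             [lambda v: v * 2, lambda v: v + 1],
                             batch_size=2)
print(result)  # [3.0, 5.0, 7.0]
```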
    -f https://data.pyg.org/whl/torch-1.9.0+cu111.html
Do we need to make these changes to requirements_dl.txt (line 6 above)?
Since there's only a GPU test, I think it should be fine for now.
Oh, but doesn't that make the instruction in line 6 no longer true? Do we actually want these to be in the CPU docker image as well? Alternative solution: move these above that line.
Updated the comment to reflect the changes!
python/ray/ml/examples/pytorch_geometric/distributed_reddit_example.py (outdated; resolved)
.buildkite/pipeline.gpu.large.yml (Outdated)
    commands:
      - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT
      - DATA_PROCESSING_TESTING=1 TRAIN_TESTING=1 TUNE_TESTING=1 ./ci/travis/install-dependencies.sh
      - DATA_PROCESSING_TESTING=1 TRAIN_TESTING=1 TUNE_TESTING=1 PYTHON=3.7 ./ci/travis/install-dependencies.sh
For my learning, is this needed?
Torch geometric does not support python 3.6.
We could make a separate build just for 3.7, but I thought it would be better to upgrade everything to 3.7, since that is what we currently do for Tune anyway.
Oh wait, isn't the default value 3.7?
No, it's 3.6, I believe.
Ah, I believe it was updated for GPU images here. But similar to my comment on that PR, having it explicit makes sense (in case we change the default version in the future).
Ah, got it. Actually, some earlier versions of torch geometric support python 3.6, but the later versions don't. In any case, it's fine to have this be explicit.
…ample.py Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
…h-geometric-examples
gjoliver left a comment:
Sorry about the delay; I have a few minor questions/comments.
    # the creator stream.
    for i in item:
        if isinstance(i, torch.Tensor):
            i.record_stream(curr_stream)
Can you comment on what may show up here as well, and why you need this if statement now?
The pytorch dataloader can actually return a batch of anything. In all of our examples and tests so far, our data loaders return batches of tensors, but in this case the torch geometric data loader also returns the batch size, node IDs, etc., which are not all tensors.
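A hypothetical, torch-free sketch of the guard being discussed (FakeTensor and record_batch are stand-ins invented for illustration): a batch may mix tensors with plain Python values, so a tensor-only call like record_stream must be gated on a type check.

```python
# Hypothetical sketch (pure Python, no torch) of guarding a tensor-only
# method: only tensor-like items in a mixed batch get record_stream;
# plain values (batch size, node id lists) are skipped.

class FakeTensor:
    """Stand-in for torch.Tensor; remembers which stream touched it."""
    def __init__(self):
        self.stream = None
    def record_stream(self, stream):
        self.stream = stream

def record_batch(item, curr_stream):
    for i in item:
        # ints and lists have no record_stream; the isinstance check
        # mirrors `isinstance(i, torch.Tensor)` in the real code.
        if isinstance(i, FakeTensor):
            i.record_stream(curr_stream)

batch = (32, FakeTensor(), [0, 1, 2])  # batch_size, features, node ids
record_batch(batch, curr_stream="stream-0")
print(batch[1].stream)  # stream-0
```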
    # Use 10% of nodes for validation and 10% for testing.
    fake_dataset = FakeDataset(transform=RandomNodeSplit(num_val=0.1, num_test=0.1))

    def gen_dataset():
Feels a little unnecessary. Why don't we simply return fake_dataset here, and below in the configuration say "dataset_fn": gen_fake_dataset?
Good point 😅. Made a follow-up PR here: #24080!
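The suggested simplification, sketched in plain Python (gen_fake_dataset and the config key are illustrative stand-ins, not the example's exact code): pass the generating function itself under the config key instead of wrapping an already-built dataset in an extra closure.

```python
# Hypothetical sketch: hand the trainer config a function reference;
# the trainer can then call it once per worker, with no extra wrapper.

def gen_fake_dataset():
    # Stand-in for FakeDataset(...); any zero-arg callable returning
    # the dataset works.
    return ["node-0", "node-1", "node-2"]

# Instead of  config = {"dataset_fn": lambda: fake_dataset}  with a
# pre-built dataset, pass the function itself:
config = {"dataset_fn": gen_fake_dataset}
dataset = config["dataset_fn"]()
print(len(dataset))  # 3
```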
    def inference(self, x_all, subgraph_loader):
        for i in range(self.num_layers):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
Actually, reading this again now, I am still a bit curious how a user should use this inference call. It will only work if subgraph_loader iterates through all nodes in the graph, so:
- how does a user construct such a subgraph loader?
- is it really a common case that someone would want to score an entire graph?
I think the intent is to use this just for validation and testing, not for actual live predictions.
We will need to figure out the inference/prediction story more later. This was copied over from the example in torch geometric, but let me rename it to "test" to make this clearer.
Add example for distributed pytorch geometric (graph learning) with Ray AIR
This only showcases distributed training, but with data small enough that it can be loaded in by each training worker individually. Distributed data ingest is out of scope for this PR.
Why are these changes needed?
Related issue number
Checks
- I've run scripts/format.sh to lint the changes in this PR.