-
Notifications
You must be signed in to change notification settings - Fork 527
Expand file tree
/
Copy pathtensor_parallel.py
More file actions
262 lines (210 loc) · 7.59 KB
/
Copy pathtensor_parallel.py
File metadata and controls
262 lines (210 loc) · 7.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Sequence
import torch
import torch.distributed as dist
from my_dtype_tensor_subclass import MyDTypeTensor
from torch.distributed import DeviceMesh
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import fill_defaults
# a tensor subclass that supports tensor parallelism with DTensor
class MyDTypeTensorTP(MyDTypeTensor):
    """Subclass of ``MyDTypeTensor`` used for the tensor-parallel example.

    It adds no state or methods of its own; the aten op overrides it needs are
    registered on it through the ``implements`` alias defined right after it.
    """


# Used below as a decorator to register aten-op overrides for MyDTypeTensorTP.
implements = MyDTypeTensorTP.implements
# Short alias for the aten operator namespace.
aten = torch.ops.aten
@implements([aten._to_copy.default, aten.clone.default])
def _(func, types, args, kwargs):
    """Handle ``aten._to_copy`` and ``aten.clone`` by cloning the wrapped data.

    ``torch.clone`` is applied through ``_apply_fn_to_data``. Any kwargs of the
    op (e.g. a target dtype/device for ``_to_copy``) are not applied to the
    data; they are only forwarded to ``return_and_correct_aliasing``.
    """
    source = args[0]
    cloned = source._apply_fn_to_data(torch.clone)
    return return_and_correct_aliasing(func, args, kwargs, cloned)
@implements([aten.alias.default])
def _(func, types, args, kwargs):
    """Handle ``aten.alias`` by aliasing the wrapped data via ``_apply_fn_to_data``."""
    source = args[0]
    aliased = source._apply_fn_to_data(aten.alias)
    return return_and_correct_aliasing(func, args, kwargs, aliased)
@implements([aten.split.Tensor])
def _(func, types, args, kwargs):
    """Handle ``aten.split`` by splitting ``int_data`` and ``scale`` alike.

    Both inner tensors are split with the same positional/keyword arguments.
    The resulting pieces are paired up and rewrapped as ``MyDTypeTensorTP``
    instances, which keep the source's ``transposed`` flag and ``dtype``.
    Returns a list of the rewrapped pieces.
    """
    source = args[0]
    split_args = args[1:]
    data_pieces = func(source.int_data, *split_args, **kwargs)
    scale_pieces = func(source.scale, *split_args, **kwargs)
    # NOTE(review): zip() stops at the shorter list; this assumes both inner
    # tensors split into the same number of pieces — confirm for dim != 0.
    return [
        MyDTypeTensorTP(
            piece_data,
            piece_scale,
            transposed=source.transposed,
            dtype=source.dtype,
        )
        for piece_data, piece_scale in zip(data_pieces, scale_pieces)
    ]
@implements([aten.empty_like.default])
def _(func, types, args, kwargs):
    """Handle ``aten.empty_like`` for ``MyDTypeTensorTP``.

    Only ``int_data`` is passed through ``empty_like``. The source's ``scale``
    tensor object is reused as-is (shared, not copied), and ``transposed`` and
    ``dtype`` are carried over.
    """
    source = args[0]
    fresh_data = func(source.int_data, *args[1:], **kwargs)
    return MyDTypeTensorTP(
        fresh_data,
        source.scale,
        transposed=source.transposed,
        dtype=source.dtype,
    )
@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
    """Handle ``aten.slice`` for ``MyDTypeTensorTP`` along dim 0 or dim 1.

    Missing trailing arguments are filled via ``fill_defaults`` as
    ``dim=0, start=None, end=None, step=1``; only ``step == 1`` is supported.
    ``end`` values that are ``None`` or past the end of the dimension are
    clamped to the dimension size. A negative ``dim`` is normalized against
    the tensor's rank.

    - dim 0: the slice is applied to every wrapped tensor via
      ``_apply_fn_to_data`` (so ``scale`` is sliced too), and aliasing is
      fixed up with ``return_and_correct_aliasing``.
    - dim 1: only ``int_data`` is sliced; ``scale`` is kept and flattened with
      ``view(-1)``.

    Raises:
        NotImplementedError: for any other dimension.
    """
    src, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
    assert step == 1
    requested_dim = dim
    if dim < 0:
        # Let e.g. dim=-1 on a 2-D tensor take the dim == 1 path.
        dim += len(src.shape)
    dim_size = src.shape[dim]
    # aten.slice's `end` is optional (None means "to the end"); comparing None
    # against an int would raise TypeError, so treat it like an overshoot.
    if end is None or end >= dim_size:
        end = dim_size
    if dim == 0:
        return return_and_correct_aliasing(
            func,
            args,
            kwargs,
            src._apply_fn_to_data(
                lambda x: aten.slice.Tensor(x, dim, start, end, step)
            ),
        )
    if dim == 1:
        return MyDTypeTensorTP(
            aten.slice.Tensor(src.int_data, dim, start, end, step),
            src.scale.view(-1),
            transposed=src.transposed,
            dtype=src.dtype,
        )
    raise NotImplementedError(
        f"MyDTypeTensorTP slice with dim={requested_dim} is not supported"
    )
# this is needed for DTensor.from_local() and for flattening tensor
@implements(aten.view.default)
def _(func, types, args, kwargs):
    """Handle ``aten.view`` for the identity view and for ``shape == [-1]``.

    In both accepted cases a new instance of the same class is built from the
    unchanged ``int_data``/``scale``/``transposed``/``dtype``. The ``[-1]``
    request therefore does not actually reshape the wrapped data.

    Raises:
        ValueError: for any other target shape.
    """
    src, target_shape = args
    if tuple(src.shape) != tuple(target_shape) and not (
        len(target_shape) == 1 and target_shape[0] == -1
    ):
        raise ValueError(
            f"{src.__class__.__name__} only supports .view() with same shape or shape=[-1]"
        )
    return src.__class__(
        src.int_data, src.scale, transposed=src.transposed, dtype=src.dtype
    )
@implements(aten.t.default)
def _(func, types, args, kwargs):
    """Handle ``aten.t`` lazily by toggling the ``transposed`` flag.

    The wrapped ``int_data`` and ``scale`` are reused untouched; only the
    flag flips.
    """
    src = args[0]
    flipped = MyDTypeTensorTP(
        src.int_data,
        src.scale,
        transposed=not src.transposed,
        dtype=src.dtype,
    )
    return return_and_correct_aliasing(func, args, kwargs, flipped)
@implements(aten.addmm.default)
def _(func, types, args, kwargs):
    """Fallback for ``aten.addmm`` that dequantizes the weight first.

    The aten schema is ``addmm(self, mat1, mat2, *, beta=1, alpha=1)`` and it
    computes ``beta * self + alpha * (mat1 @ mat2)``. Here ``self`` is the
    bias, ``mat1`` the activation and ``mat2`` the (quantized) weight. The
    weight is dequantized and the plain op is re-issued with the arguments in
    schema order. Any ``beta``/``alpha`` kwargs are forwarded.
    """
    bias, input_tensor, weight_tensor = args[0], args[1], args[2]
    weight_tensor = weight_tensor.dequantize()
    # Bias must be the first argument: passing it last would compute
    # input + weight @ bias instead of bias + input @ weight.
    return aten.addmm(bias, input_tensor, weight_tensor, **(kwargs or {}))
@implements(aten.mm.default)
def _(func, types, args, kwargs):
    """Fallback for ``aten.mm``: dequantize the right operand, then multiply."""
    lhs, quantized_rhs = args[0], args[1]
    dense_rhs = quantized_rhs.dequantize()
    return aten.mm(lhs, dense_rhs)
class M(torch.nn.Module):
    """Toy model made of a single bias-free ``torch.nn.Linear`` projection."""

    def __init__(self, in_features, out_features, **kwargs) -> None:
        super().__init__(**kwargs)
        self.linear = torch.nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection to ``x``."""
        projection = self.linear
        return projection(x)
# Alias for the float -> MyDTypeTensorTP conversion used by quantize().
to_my_dtype_tp = MyDTypeTensorTP.from_float


def quantize(m: torch.nn.Module) -> torch.nn.Module:
    """
    Quantize the model in place and return it.

    Replaces ``m.linear.weight`` with a frozen (``requires_grad=False``)
    parameter that wraps the ``MyDTypeTensorTP`` conversion of the original
    weight.
    """
    quantized_weight = to_my_dtype_tp(m.linear.weight)
    m.linear.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    return m
def shard(
    full_tensor: torch.Tensor,
    device_mesh: DeviceMesh,
    placements: Sequence[Placement],
) -> DTensor:
    """
    Build a DTensor holding this rank's shard of ``full_tensor``.

    Add a shard function to simplify both colwise_shard and rowwise_shard. The
    shard function accepts a full tensor, and returns a DTensor based on
    indicated placements. Goal is to move the shard function as a static method
    of DTensor, e.g.
    dtensor = DTensor.shard(full_tensor, device_mesh, placement)

    Args:
        full_tensor: the complete (unsharded) tensor, present on every rank.
        device_mesh: mesh the resulting DTensor lives on.
        placements: placement per mesh dimension, e.g. ``[Shard(0)]``.

    Returns:
        ``DTensor.from_local`` of the locally owned slice of ``full_tensor``.
    """
    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset

    # Shape of this rank's shard and where it starts in the global tensor.
    shape, offset = compute_local_shape_and_global_offset(
        full_tensor.shape, device_mesh, placements
    )
    # Index with a tuple of slices: a *list* of slices is the legacy
    # "non-tuple sequence" form of multidimensional indexing, which NumPy has
    # removed and PyTorch accepts only through a compatibility heuristic.
    slices = tuple(
        slice(cur_offset, cur_offset + cur_shape)
        for cur_shape, cur_offset in zip(shape, offset)
    )
    local_tensor = full_tensor[slices]
    return DTensor.from_local(local_tensor, device_mesh, placements)
def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    """
    Shard linear layer of the model in column-wise fashion.

    The module is modified in place: ``m.linear.weight`` becomes a frozen
    parameter wrapping a DTensor sharded with ``Shard(0)`` over ``mesh``.
    The same module is returned.
    """
    # "Column-wise" is with respect to A^T; for the stored weight A that means
    # splitting dim 0 (its rows).
    m.linear.weight = torch.nn.Parameter(
        shard(m.linear.weight, mesh, [Shard(0)]), requires_grad=False
    )
    return m
def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
    """
    Shard linear layer of the model in row-wise fashion.

    The module is modified in place: ``m.linear.weight`` becomes a frozen
    parameter wrapping a DTensor sharded with ``Shard(1)`` over ``mesh``.
    The same module is returned.
    """
    # "Row-wise" is with respect to A^T; for the stored weight A that means
    # splitting dim 1 (its columns).
    m.linear.weight = torch.nn.Parameter(
        shard(m.linear.weight, mesh, [Shard(1)]), requires_grad=False
    )
    return m
########
# Test #
########
def main():
    """End-to-end smoke test: quantize, shard across ranks, then torch.compile.

    Reads ``WORLD_SIZE`` and ``RANK`` from the environment (as set by a
    launcher such as ``torchrun``), places each rank on a CUDA device and
    uses the NCCL backend.
    """
    # To make sure different ranks create the same module
    torch.manual_seed(5)

    # Get rank and device
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    # Make `device` this process's current CUDA device before NCCL init, so
    # communicators and later "cuda" allocations don't default to cuda:0.
    torch.cuda.set_device(device)

    # Original model
    proj_up = M(1024, 2048).to(device)
    proj_dn = M(2048, 1024).to(device)
    example_input = 100 * torch.randn(128, 1024, device=device)
    proj_dn(proj_up(example_input))

    # Quantize the model
    up_quant = quantize(proj_up)
    dn_quant = quantize(proj_dn)
    dn_quant(up_quant(example_input))
    print("Quantization works!")

    # Create a device mesh
    dist.init_process_group(backend="nccl")
    try:
        mesh = dist.init_device_mesh("cuda", (world_size,))

        # Shard the models
        up_dist = colwise_shard(up_quant, mesh)
        dn_dist = rowwise_shard(dn_quant, mesh)

        # We need to turn inputs into DTensor form as well -- just a format change
        input_dtensor = DTensor.from_local(example_input, mesh, [Replicate()])

        y_d = dn_dist(up_dist(input_dtensor))
        print("Distributed result:", y_d)
        print("Distributed works!")

        up_compiled = torch.compile(up_dist)
        y_up = up_compiled(input_dtensor)
        dn_compiled = torch.compile(dn_dist)
        y_dn = dn_compiled(y_up)
        print("compiled result:", y_dn)
        print("torch.compile works!")
    finally:
        # Always tear the process group down, even if a step above raised.
        dist.destroy_process_group()


if __name__ == "__main__":
    main()