Commit 364ba9c

KTong821 authored and vadiklyutiy committed
[Graph] Add major UNet building components (#97)
Add UNet Down, Up, and Mid block definitions and attention transformer utility layer. Modules are designed so that kwargs passed to constructors are all the same config from huggingface with minimal changes - lots of shared values and too many parameters to list individually. Same kwargs are passed to nested objects. Open to other suggestions, although this is a single use case problem. Towards #57.
1 parent cdff99a commit 364ba9c

2 files changed

Lines changed: 378 additions & 0 deletions
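To make the shared-config convention concrete before the diff, here is a minimal sketch of the kwargs-forwarding idiom the new blocks use. The dict keys and values below are illustrative stand-ins for the (much larger) Hugging Face UNet config and are not part of this commit.

# Hypothetical sketch of the kwargs-forwarding convention used in this commit.
# "config" stands in for the full Hugging Face UNet config; only a few keys are shown.
config = {
    "input_channels": 320,
    "output_channels": 640,
    "num_layers": 2,
    "num_attention_heads": 8,
    "attention_head_dim": 40,
    "transformer_layers_per_block": 1,
}

# Each block forwards the same dict to its nested modules, overriding only the
# keys that differ for that child (the same idiom appears throughout the diff below):
child_kwargs = {
    **config,
    "input_channels": config["output_channels"],
    "attention_head_dim": config["output_channels"] // config["num_attention_heads"],
}
# e.g. Transformer2DModel(**child_kwargs) inside CrossAttnDownBlock2D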

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
from typing import Optional

from hidet.graph import nn, ops
from hidet.graph.tensor import Tensor


class FeedForward(nn.Module[Tensor]):
    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        activation_fn: str = "geglu",
        inner_dim: Optional[int] = None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn != "geglu":
            raise NotImplementedError("Expected geglu for feedforward activation.")

        act_fn = nn.Geglu(dim, inner_dim, bias=bias)

        self.net = []
        self.net.append(act_fn)
        self.net.append(nn.Identity())  # replaces dropout
        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
        self.net = nn.Sequential(self.net)

    def forward(self, x) -> Tensor:
        return self.net(x)


class BasicTransformerBlock(nn.Module[Tensor]):
    def __init__(self, dim: int, **kwargs):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.attn1 = nn.CrossAttention(
            dim, heads=kwargs["num_attention_heads"], dim_head=kwargs["attention_head_dim"], upcast=True, out_bias=True
        )

        self.norm2 = nn.LayerNorm(dim)
        self.attn2 = nn.CrossAttention(
            dim,
            cross_attention_dim=kwargs["cross_attention_dim"],
            heads=kwargs["num_attention_heads"],
            dim_head=kwargs["attention_head_dim"],
            upcast=True,
            out_bias=True,
        )

        self.norm3 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, activation_fn="geglu", bias=True)

    def forward(self, hidden_states: Tensor, encoder_hidden_states: Tensor, temperature_scaling: float = 1.0) -> Tensor:
        norm_hidden_states = self.norm1(hidden_states)

        attn_output = self.attn1(norm_hidden_states, temperature_scaling=temperature_scaling)

        hidden_states = attn_output + hidden_states
        if len(hidden_states.shape) == 4:
            hidden_states = hidden_states.squeeze(1)

        norm_hidden_states = self.norm2(hidden_states)

        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)

        hidden_states = attn_output + hidden_states

        norm_hidden_states = self.norm3(hidden_states)

        ff_output = self.ff(norm_hidden_states)
        hidden_states = ff_output + hidden_states
        if len(hidden_states.shape) == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states


class Transformer2DModel(nn.Module[Tensor]):
    def __init__(self, **kwargs):
        super().__init__()

        inner_dim = kwargs["num_attention_heads"] * kwargs["attention_head_dim"]
        self.use_linear_projection = kwargs["use_linear_projection"]

        self.norm = nn.GroupNorm(kwargs["resnet_groups"], kwargs["input_channels"], eps=1e-6, affine=True)
        if kwargs["use_linear_projection"]:
            self.proj_in = nn.Linear(kwargs["input_channels"], inner_dim)
        else:
            self.proj_in = nn.Conv2d(kwargs["input_channels"], inner_dim, kernel_size=1)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, **kwargs) for _ in range(kwargs["num_layers"])]
        )

        self.output_channels = (
            kwargs["input_channels"] if kwargs.get("output_channels", None) is None else kwargs["output_channels"]
        )

        if kwargs["use_linear_projection"]:
            self.proj_out = nn.Linear(inner_dim, kwargs["input_channels"])
        else:
            self.proj_out = nn.Conv2d(inner_dim, kwargs["input_channels"], kernel_size=1)

    def forward(self, hidden_states: Tensor, encoder_hidden_states: Tensor, temperature_scaling: float = 1.0) -> Tensor:
        bs, _, h, w = hidden_states.shape
        residuals = hidden_states
        hidden_states = self.norm(hidden_states)

        def compress_hidden_states(x):
            return ops.permute_dims(x, (0, 2, 3, 1)).reshape((bs, h * w, x.shape[1]))

        def decompress_hidden_states(x):
            return ops.permute_dims(x.reshape((bs, h, w, inner_dim)), (0, 3, 1, 2)).contiguous()

        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = compress_hidden_states(hidden_states)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = compress_hidden_states(hidden_states)
            hidden_states = self.proj_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, encoder_hidden_states, temperature_scaling=temperature_scaling)

        if not self.use_linear_projection:
            hidden_states = decompress_hidden_states(hidden_states)
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = decompress_hidden_states(hidden_states)

        return hidden_states + residuals
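A brief usage sketch for the layer above. The config values and tensor shapes are illustrative (chosen to resemble a typical Stable Diffusion setup) and are not part of this commit; the example assumes it runs in the same module where Transformer2DModel is defined and that hidet.randn is available for creating test inputs.

# Hypothetical example -- values and shapes are illustrative, not from this commit.
import hidet

attn_layer = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,
    input_channels=320,
    resnet_groups=32,
    num_layers=1,
    use_linear_projection=False,
    cross_attention_dim=768,
)

hidden_states = hidet.randn([1, 320, 64, 64])            # (batch, channels, height, width)
encoder_hidden_states = hidet.randn([1, 77, 768])         # (batch, tokens, cross_attention_dim)
out = attn_layer(hidden_states, encoder_hidden_states)    # same shape as hidden_states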
Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
from typing import Tuple
from hidet import nn
from hidet.apps.diffusion.modeling.stable_diffusion.downsample import Downsample2D
from hidet.apps.diffusion.modeling.stable_diffusion.resnet_blocks import ResnetBlock2D
from hidet.apps.diffusion.modeling.stable_diffusion.transformer_blocks import Transformer2DModel
from hidet.apps.diffusion.modeling.stable_diffusion.upsample import Upsample2D
from hidet.graph.tensor import Tensor
from hidet.graph.ops import concat


class CrossAttnDownBlock2D(nn.Module[Tensor]):
    def __init__(self, **kwargs):
        super().__init__()
        self.has_cross_attention = True
        self.resnets = []
        self.attentions = []

        transformer_layers_per_block = kwargs["transformer_layers_per_block"]
        num_layers = kwargs["num_layers"]

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        for i in range(num_layers):
            input_channels = kwargs["input_channels"] if i == 0 else kwargs["output_channels"]
            self.resnets.append(ResnetBlock2D(**{**kwargs, "input_channels": input_channels}))
            self.attentions.append(
                Transformer2DModel(
                    **{
                        **kwargs,
                        "attention_head_dim": kwargs["output_channels"] // kwargs["num_attention_heads"],
                        "input_channels": kwargs["output_channels"],
                        "num_layers": transformer_layers_per_block[i],
                    }
                )
            )

        self.resnets = nn.ModuleList(self.resnets)
        self.attentions = nn.ModuleList(self.attentions)

        if kwargs["add_downsample"]:
            self.downsamplers = nn.ModuleList([Downsample2D(kwargs["output_channels"], **kwargs)])
        else:
            self.downsamplers = None

    def forward(self, hidden_states: Tensor, temb: Tensor, encoder_hidden_states: Tensor) -> Tensor:
        output_states = ()
        blocks = list(zip(self.resnets, self.attentions))

        for resnet, attn in blocks:
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states)

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states


class DownBlock2D(nn.Module[Tensor]):
    def __init__(self, **kwargs):
        super().__init__()
        self.has_cross_attention = False
        self.resnets = []

        for i in range(kwargs["num_layers"]):
            input_channels = kwargs["input_channels"] if i == 0 else kwargs["output_channels"]
            self.resnets.append(ResnetBlock2D(**{**kwargs, "input_channels": input_channels}))

        self.resnets = nn.ModuleList(self.resnets)
        if kwargs["add_downsample"]:
            self.downsamplers = nn.ModuleList([Downsample2D(kwargs["output_channels"], **kwargs)])
        else:
            self.downsamplers = None

    def forward(self, hidden_states: Tensor, temb: Tensor) -> Tensor:
        output_states = ()

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states


class MidBlock2DCrossAttn(nn.Module[Tensor]):
    def __init__(self, **kwargs):
        super().__init__()

        self.has_cross_attention = True

        transformer_layers_per_block = kwargs["transformer_layers_per_block"]
        if isinstance(kwargs["transformer_layers_per_block"], int):
            transformer_layers_per_block = [transformer_layers_per_block] * kwargs["num_layers"]

        self.resnets = [ResnetBlock2D(**{**kwargs, "input_channels": kwargs["input_channels"]})]
        self.attentions = []

        for i in range(kwargs["num_layers"]):
            self.attentions.append(
                Transformer2DModel(
                    **{
                        **kwargs,
                        "attention_head_dim": kwargs["input_channels"] // kwargs["num_attention_heads"],
                        "input_channels": kwargs["input_channels"],
                        "num_layers": transformer_layers_per_block[i],
                    }
                )
            )

            self.resnets.append(ResnetBlock2D(**{**kwargs, "input_channels": kwargs["input_channels"]}))

        self.resnets = nn.ModuleList(self.resnets)
        self.attentions = nn.ModuleList(self.attentions)

    def forward(self, hidden_states: Tensor, temb: Tensor, encoder_hidden_states: Tensor) -> Tensor:
        hidden_states = self.resnets[0](hidden_states, temb)

        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states)
            hidden_states = resnet(hidden_states, temb)

        return hidden_states


class CrossAttnUpBlock2D(nn.Module[Tensor]):
    def __init__(self, **kwargs):
        super().__init__()
        self.has_cross_attention = True
        num_layers = kwargs["num_layers"]

        transformer_layers_per_block = kwargs["transformer_layers_per_block"]
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        self.resnets = []
        self.attentions = []
        for i in range(num_layers):
            res_skip_channels = kwargs["input_channels"] if (i == num_layers - 1) else kwargs["output_channels"]
            resnet_in_channels = kwargs["prev_output_channel"] if i == 0 else kwargs["output_channels"]
            input_channels = resnet_in_channels + res_skip_channels

            self.resnets.append(ResnetBlock2D(**{**kwargs, "input_channels": input_channels}))

            self.attentions.append(
                Transformer2DModel(
                    **{
                        **kwargs,
                        "attention_head_dim": kwargs["output_channels"] // kwargs["num_attention_heads"],
                        "input_channels": kwargs["output_channels"],
                        "num_layers": transformer_layers_per_block[i],
                    }
                )
            )

        self.resnets = nn.ModuleList(self.resnets)
        self.attentions = nn.ModuleList(self.attentions)

        if kwargs["add_upsample"]:
            self.upsamplers = nn.ModuleList([Upsample2D(kwargs["output_channels"], **kwargs)])
        else:
            self.upsamplers = None

    def forward(
        self,
        hidden_states: Tensor,
        res_hidden_states_tuple: Tuple[Tensor],
        temb: Tensor,
        upsample_size: int,
        encoder_hidden_states: Tensor,
        is_final_block=False,
    ) -> Tensor:
        for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            hidden_states = concat([hidden_states, res_hidden_states], axis=1)

            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temperature_scaling=2 if is_final_block and i == 1 else 1,
            )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


class UpBlock2D(nn.Module[Tensor]):
    def __init__(self, **kwargs):
        super().__init__()
        self.has_cross_attention = False
        self.resnets = []

        for i in range(kwargs["num_layers"]):
            res_skip_channels = (
                kwargs["input_channels"] if (i == kwargs["num_layers"] - 1) else kwargs["output_channels"]
            )
            resnet_input_channels = kwargs["prev_output_channel"] if i == 0 else kwargs["output_channels"]
            input_channels = res_skip_channels + resnet_input_channels

            self.resnets.append(ResnetBlock2D(**{**kwargs, "input_channels": input_channels}))

        self.resnets = nn.ModuleList(self.resnets)
        if kwargs["add_upsample"]:
            self.upsamplers = nn.ModuleList([Upsample2D(kwargs["output_channels"], **kwargs)])
        else:
            self.upsamplers = None

    def forward(
        self, hidden_states: Tensor, res_hidden_states_tuple: Tuple[Tensor], temb: Tensor, upsample_size: int
    ) -> Tensor:
        for resnet in self.resnets:
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            hidden_states = concat([hidden_states, res_hidden_states], axis=1)
            hidden_states = resnet(hidden_states, temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
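The down and up blocks are intended to compose through the skip-connection tuples they return and consume: each down block returns (hidden_states, output_states), and each up block pops entries from the end of res_hidden_states_tuple, one per resnet. The driver loop below is a hedged sketch of that handshake in the diffusers-style wiring; it is not part of this commit, and sample, temb, encoder_hidden_states, upsample_size, and the block lists are assumed to be assembled by the surrounding UNet (towards #57).

# Illustrative driver loop -- not from this commit; all free names are assumed
# to be provided by the UNet model that will wire these blocks together.
skip_states = (sample,)                      # initial conv output, kept as the first skip
for down_block in down_blocks:
    if down_block.has_cross_attention:
        sample, states = down_block(sample, temb, encoder_hidden_states)
    else:
        sample, states = down_block(sample, temb)
    skip_states += states

sample = mid_block(sample, temb, encoder_hidden_states)

for up_block in up_blocks:
    num_resnets = len(up_block.resnets)
    res_states = skip_states[-num_resnets:]  # consumed back-to-front inside forward
    skip_states = skip_states[:-num_resnets]
    if up_block.has_cross_attention:
        sample = up_block(sample, res_states, temb, upsample_size, encoder_hidden_states)
    else:
        sample = up_block(sample, res_states, temb, upsample_size)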

0 commit comments
