|
| 1 | +from typing import Tuple |
| 2 | +from hidet import nn |
| 3 | +from hidet.apps.diffusion.modeling.stable_diffusion.downsample import Downsample2D |
| 4 | +from hidet.apps.diffusion.modeling.stable_diffusion.resnet_blocks import ResnetBlock2D |
| 5 | +from hidet.apps.diffusion.modeling.stable_diffusion.transformer_blocks import Transformer2DModel |
| 6 | +from hidet.apps.diffusion.modeling.stable_diffusion.upsample import Upsample2D |
| 7 | +from hidet.graph.tensor import Tensor |
| 8 | +from hidet.graph.ops import concat |
| 9 | + |
| 10 | + |
| 11 | +class CrossAttnDownBlock2D(nn.Module[Tensor]): |
| 12 | + def __init__(self, **kwargs): |
| 13 | + super().__init__() |
| 14 | + self.has_cross_attention = True |
| 15 | + self.resnets = [] |
| 16 | + self.attentions = [] |
| 17 | + |
| 18 | + transformer_layers_per_block = kwargs["transformer_layers_per_block"] |
| 19 | + num_layers = kwargs["num_layers"] |
| 20 | + |
| 21 | + if isinstance(transformer_layers_per_block, int): |
| 22 | + transformer_layers_per_block = [transformer_layers_per_block] * num_layers |
| 23 | + |
| 24 | + for i in range(num_layers): |
| 25 | + input_channels = kwargs["input_channels"] if i == 0 else kwargs["output_channels"] |
| 26 | + self.resnets.append(ResnetBlock2D(**{**kwargs, "input_channels": input_channels})) |
| 27 | + self.attentions.append( |
| 28 | + Transformer2DModel( |
| 29 | + **{ |
| 30 | + **kwargs, |
| 31 | + "attention_head_dim": kwargs["output_channels"] // kwargs["num_attention_heads"], |
| 32 | + "input_channels": kwargs["output_channels"], |
| 33 | + "num_layers": transformer_layers_per_block[i], |
| 34 | + } |
| 35 | + ) |
| 36 | + ) |
| 37 | + |
| 38 | + self.resnets = nn.ModuleList(self.resnets) |
| 39 | + self.attentions = nn.ModuleList(self.attentions) |
| 40 | + |
| 41 | + if kwargs["add_downsample"]: |
| 42 | + self.downsamplers = nn.ModuleList([Downsample2D(kwargs["output_channels"], **kwargs)]) |
| 43 | + else: |
| 44 | + self.downsamplers = None |
| 45 | + |
| 46 | + def forward(self, hidden_states: Tensor, temb: Tensor, encoder_hidden_states: Tensor) -> Tensor: |
| 47 | + output_states = () |
| 48 | + blocks = list(zip(self.resnets, self.attentions)) |
| 49 | + |
| 50 | + for resnet, attn in blocks: |
| 51 | + hidden_states = resnet(hidden_states, temb) |
| 52 | + hidden_states = attn(hidden_states, encoder_hidden_states) |
| 53 | + |
| 54 | + output_states += (hidden_states,) |
| 55 | + |
| 56 | + if self.downsamplers is not None: |
| 57 | + for downsampler in self.downsamplers: |
| 58 | + hidden_states = downsampler(hidden_states) |
| 59 | + |
| 60 | + output_states += (hidden_states,) |
| 61 | + |
| 62 | + return hidden_states, output_states |
| 63 | + |
| 64 | + |
| 65 | +class DownBlock2D(nn.Module[Tensor]): |
| 66 | + def __init__(self, **kwargs): |
| 67 | + super().__init__() |
| 68 | + self.has_cross_attention = False |
| 69 | + self.resnets = [] |
| 70 | + |
| 71 | + for i in range(kwargs["num_layers"]): |
| 72 | + input_channels = kwargs["input_channels"] if i == 0 else kwargs["output_channels"] |
| 73 | + self.resnets.append(ResnetBlock2D(**{**kwargs, "input_channels": input_channels})) |
| 74 | + |
| 75 | + self.resnets = nn.ModuleList(self.resnets) |
| 76 | + if kwargs["add_downsample"]: |
| 77 | + self.downsamplers = nn.ModuleList([Downsample2D(kwargs["output_channels"], **kwargs)]) |
| 78 | + else: |
| 79 | + self.downsamplers = None |
| 80 | + |
| 81 | + def forward(self, hidden_states: Tensor, temb: Tensor) -> Tensor: |
| 82 | + output_states = () |
| 83 | + |
| 84 | + for resnet in self.resnets: |
| 85 | + hidden_states = resnet(hidden_states, temb) |
| 86 | + output_states = output_states + (hidden_states,) |
| 87 | + |
| 88 | + if self.downsamplers is not None: |
| 89 | + for downsampler in self.downsamplers: |
| 90 | + hidden_states = downsampler(hidden_states) |
| 91 | + |
| 92 | + output_states = output_states + (hidden_states,) |
| 93 | + |
| 94 | + return hidden_states, output_states |
| 95 | + |
| 96 | + |
| 97 | +class MidBlock2DCrossAttn(nn.Module[Tensor]): |
| 98 | + def __init__(self, **kwargs): |
| 99 | + super().__init__() |
| 100 | + |
| 101 | + self.has_cross_attention = True |
| 102 | + |
| 103 | + transformer_layers_per_block = kwargs["transformer_layers_per_block"] |
| 104 | + if isinstance(kwargs["transformer_layers_per_block"], int): |
| 105 | + transformer_layers_per_block = [transformer_layers_per_block] * kwargs["num_layers"] |
| 106 | + |
| 107 | + self.resnets = [ResnetBlock2D(**{**kwargs, "input_channels": kwargs["input_channels"]})] |
| 108 | + self.attentions = [] |
| 109 | + |
| 110 | + for i in range(kwargs["num_layers"]): |
| 111 | + self.attentions.append( |
| 112 | + Transformer2DModel( |
| 113 | + **{ |
| 114 | + **kwargs, |
| 115 | + "attention_head_dim": kwargs["input_channels"] // kwargs["num_attention_heads"], |
| 116 | + "input_channels": kwargs["input_channels"], |
| 117 | + "num_layers": transformer_layers_per_block[i], |
| 118 | + } |
| 119 | + ) |
| 120 | + ) |
| 121 | + |
| 122 | + self.resnets.append(ResnetBlock2D(**{**kwargs, "input_channels": kwargs["input_channels"]})) |
| 123 | + |
| 124 | + self.resnets = nn.ModuleList(self.resnets) |
| 125 | + self.attentions = nn.ModuleList(self.attentions) |
| 126 | + |
| 127 | + def forward(self, hidden_states: Tensor, temb: Tensor, encoder_hidden_states: Tensor) -> Tensor: |
| 128 | + hidden_states = self.resnets[0](hidden_states, temb) |
| 129 | + |
| 130 | + for attn, resnet in zip(self.attentions, self.resnets[1:]): |
| 131 | + hidden_states = attn(hidden_states, encoder_hidden_states) |
| 132 | + hidden_states = resnet(hidden_states, temb) |
| 133 | + |
| 134 | + return hidden_states |
| 135 | + |
| 136 | + |
| 137 | +class CrossAttnUpBlock2D(nn.Module[Tensor]): |
| 138 | + def __init__(self, **kwargs): |
| 139 | + super().__init__() |
| 140 | + self.has_cross_attention = True |
| 141 | + num_layers = kwargs["num_layers"] |
| 142 | + |
| 143 | + transformer_layers_per_block = kwargs["transformer_layers_per_block"] |
| 144 | + if isinstance(transformer_layers_per_block, int): |
| 145 | + transformer_layers_per_block = [transformer_layers_per_block] * num_layers |
| 146 | + |
| 147 | + self.resnets = [] |
| 148 | + self.attentions = [] |
| 149 | + for i in range(num_layers): |
| 150 | + res_skip_channels = kwargs["input_channels"] if (i == num_layers - 1) else kwargs["output_channels"] |
| 151 | + resnet_in_channels = kwargs["prev_output_channel"] if i == 0 else kwargs["output_channels"] |
| 152 | + input_channels = resnet_in_channels + res_skip_channels |
| 153 | + |
| 154 | + self.resnets.append(ResnetBlock2D(**{**kwargs, "input_channels": input_channels})) |
| 155 | + |
| 156 | + self.attentions.append( |
| 157 | + Transformer2DModel( |
| 158 | + **{ |
| 159 | + **kwargs, |
| 160 | + "attention_head_dim": kwargs["output_channels"] // kwargs["num_attention_heads"], |
| 161 | + "input_channels": kwargs["output_channels"], |
| 162 | + "num_layers": transformer_layers_per_block[i], |
| 163 | + } |
| 164 | + ) |
| 165 | + ) |
| 166 | + |
| 167 | + self.resnets = nn.ModuleList(self.resnets) |
| 168 | + self.attentions = nn.ModuleList(self.attentions) |
| 169 | + |
| 170 | + if kwargs["add_upsample"]: |
| 171 | + self.upsamplers = nn.ModuleList([Upsample2D(kwargs["output_channels"], **kwargs)]) |
| 172 | + else: |
| 173 | + self.upsamplers = None |
| 174 | + |
| 175 | + def forward( |
| 176 | + self, |
| 177 | + hidden_states: Tensor, |
| 178 | + res_hidden_states_tuple: Tuple[Tensor], |
| 179 | + temb: Tensor, |
| 180 | + upsample_size: int, |
| 181 | + encoder_hidden_states: Tensor, |
| 182 | + is_final_block=False, |
| 183 | + ) -> Tensor: |
| 184 | + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): |
| 185 | + res_hidden_states = res_hidden_states_tuple[-1] |
| 186 | + res_hidden_states_tuple = res_hidden_states_tuple[:-1] |
| 187 | + |
| 188 | + hidden_states = concat([hidden_states, res_hidden_states], axis=1) |
| 189 | + |
| 190 | + hidden_states = resnet(hidden_states, temb) |
| 191 | + hidden_states = attn( |
| 192 | + hidden_states, |
| 193 | + encoder_hidden_states=encoder_hidden_states, |
| 194 | + temperature_scaling=2 if is_final_block and i == 1 else 1, |
| 195 | + ) |
| 196 | + |
| 197 | + if self.upsamplers is not None: |
| 198 | + for upsampler in self.upsamplers: |
| 199 | + hidden_states = upsampler(hidden_states, upsample_size) |
| 200 | + |
| 201 | + return hidden_states |
| 202 | + |
| 203 | + |
| 204 | +class UpBlock2D(nn.Module[Tensor]): |
| 205 | + def __init__(self, **kwargs): |
| 206 | + super().__init__() |
| 207 | + self.has_cross_attention = False |
| 208 | + self.resnets = [] |
| 209 | + |
| 210 | + for i in range(kwargs["num_layers"]): |
| 211 | + res_skip_channels = ( |
| 212 | + kwargs["input_channels"] if (i == kwargs["num_layers"] - 1) else kwargs["output_channels"] |
| 213 | + ) |
| 214 | + resnet_input_channels = kwargs["prev_output_channel"] if i == 0 else kwargs["output_channels"] |
| 215 | + input_channels = res_skip_channels + resnet_input_channels |
| 216 | + |
| 217 | + self.resnets.append(ResnetBlock2D(**{**kwargs, "input_channels": input_channels})) |
| 218 | + |
| 219 | + self.resnets = nn.ModuleList(self.resnets) |
| 220 | + if kwargs["add_upsample"]: |
| 221 | + self.upsamplers = nn.ModuleList([Upsample2D(kwargs["output_channels"], **kwargs)]) |
| 222 | + else: |
| 223 | + self.upsamplers = None |
| 224 | + |
| 225 | + def forward( |
| 226 | + self, hidden_states: Tensor, res_hidden_states_tuple: Tuple[Tensor], temb: Tensor, upsample_size: int |
| 227 | + ) -> Tensor: |
| 228 | + for resnet in self.resnets: |
| 229 | + res_hidden_states = res_hidden_states_tuple[-1] |
| 230 | + res_hidden_states_tuple = res_hidden_states_tuple[:-1] |
| 231 | + |
| 232 | + hidden_states = concat([hidden_states, res_hidden_states], axis=1) |
| 233 | + hidden_states = resnet(hidden_states, temb) |
| 234 | + |
| 235 | + if self.upsamplers is not None: |
| 236 | + for upsampler in self.upsamplers: |
| 237 | + hidden_states = upsampler(hidden_states, upsample_size) |
| 238 | + |
| 239 | + return hidden_states |
0 commit comments