add Paella (Fast Text-Conditional Discrete Denoising on Vector-Quantized Latent Spaces) #2058
add Paella (Fast Text-Conditional Discrete Denoising on Vector-Quantized Latent Spaces) #2058 — aengusng8 wants to merge 2 commits into huggingface:main from aengusng8:add-paella
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
|
Hey @aengusng8, Super cool! This already looks great :-) Please ping me @pcuenca if you'd like to have a review |
# Copied from Paella/modules.py
class ModulatedLayerNorm(nn.Module):
    """LayerNorm whose output can be scaled and shifted by an external modulation map.

    With a modulation tensor ``w``, the output is ``gamma * w * LN(x) + beta * w``,
    where ``gamma`` and ``beta`` are learnable scalars. Without ``w`` it reduces to
    a plain LayerNorm over the channel dimension.
    """

    def __init__(self, num_features, eps=1e-6, channels_first=True):
        super().__init__()
        self.ln = nn.LayerNorm(num_features, eps=eps)
        # Learnable scalar scale/shift used only on the modulated path.
        self.gamma = nn.Parameter(torch.randn(1, 1, 1))
        self.beta = nn.Parameter(torch.randn(1, 1, 1))
        # If True, inputs are NCHW and get permuted to NHWC so LayerNorm sees channels last.
        self.channels_first = channels_first

    def forward(self, x, w=None):
        if self.channels_first:
            x = x.permute(0, 2, 3, 1)
        if w is None:
            out = self.ln(x)
        else:
            out = self.gamma * w * self.ln(x) + self.beta * w
        if self.channels_first:
            out = out.permute(0, 3, 1, 2)
        return out
|
|
||
|
|
||
class ResBlock(nn.Module):
    """Paella residual block: depthwise conv -> modulated LayerNorm -> channelwise MLP.

    Args:
        c: number of input/output channels.
        c_hidden: hidden width of the channelwise MLP.
        c_cond: channels of the conditioning map ``s`` (0 disables conditioning).
        c_skip: channels of an optional skip tensor concatenated before the MLP.
        scaler: optional module applied to the output (e.g. an up/downsampler).
        layer_scale_init_value: init value for the LayerScale gamma; <= 0 disables it.
    """

    def __init__(self, c, c_hidden, c_cond=0, c_skip=0, scaler=None, layer_scale_init_value=1e-6):
        super().__init__()
        self.depthwise = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(c, c, kernel_size=3, groups=c))
        self.ln = ModulatedLayerNorm(c, channels_first=False)
        self.channelwise = nn.Sequential(
            nn.Linear(c + c_skip, c_hidden),
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )
        if layer_scale_init_value > 0:
            # LayerScale: start the residual branch near zero.
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c), requires_grad=True)
        else:
            self.gamma = None
        self.scaler = scaler
        if c_cond > 0:
            self.cond_mapper = nn.Linear(c_cond, c)

    def forward(self, x, s=None, skip=None):
        res = x
        x = self.depthwise(x)
        if s is not None:
            # Broadcast a 1x1 conditioning map, or resize a spatially mismatched one.
            if s.size(2) == s.size(3) == 1:
                s = s.expand(-1, -1, x.size(2), x.size(3))
            elif s.size(2) != x.size(2) or s.size(3) != x.size(3):
                s = nn.functional.interpolate(s, size=x.shape[-2:], mode="bilinear")
            s = self.cond_mapper(s.permute(0, 2, 3, 1))
        # ModulatedLayerNorm expects channels-last input.
        x = self.ln(x.permute(0, 2, 3, 1), s)
        if skip is not None:
            x = torch.cat([x, skip.permute(0, 2, 3, 1)], dim=-1)
        x = self.channelwise(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = res + x.permute(0, 3, 1, 2)
        if self.scaler is not None:
            x = self.scaler(x)
        return x
|
|
||
|
|
||
class DenoiseUNet(nn.Module):
    """Paella token-space denoising U-Net.

    Operates on a 2D grid of discrete VQ token indices: tokens are embedded,
    processed by a conv U-Net conditioned on an embedding ``c`` (e.g. CLIP) and a
    noise level ``r``, and projected back to per-token class logits.

    Args:
        num_vec_classes: size of the VQ codebook (number of token classes).
        c_hidden: channel width at the lowest-resolution level.
        c_clip: dimensionality of the conditioning embedding.
        c_r: dimensionality of the sinusoidal noise-level embedding.
        down_levels: number of ResBlocks per encoder level.
        up_levels: number of ResBlocks per decoder level.
    """

    def __init__(
        self,
        num_vec_classes,
        c_hidden=1280,
        c_clip=1024,
        c_r=64,
        down_levels=(4, 8, 16),  # tuple defaults: a mutable list default would be shared across instances
        up_levels=(16, 8, 4),
    ):
        super().__init__()
        self.num_vec_classes = num_vec_classes
        self.c_r = c_r
        self.down_levels = down_levels
        self.up_levels = up_levels
        # Channel widths per level, doubling towards the bottleneck: [c_hidden/2^(n-1), ..., c_hidden].
        c_levels = [c_hidden // (2**i) for i in reversed(range(len(down_levels)))]
        self.embedding = nn.Embedding(num_vec_classes, c_levels[0])

        # DOWN BLOCKS
        self.down_blocks = nn.ModuleList()
        for i, num_blocks in enumerate(down_levels):
            blocks = []
            if i > 0:
                # Strided conv halves the spatial resolution between levels.
                blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            for _ in range(num_blocks):
                block = ResBlock(c_levels[i], c_levels[i] * 4, c_clip + c_r)
                # Down-scale the residual branch so the summed contribution of all blocks stays O(1).
                block.channelwise[-1].weight.data *= np.sqrt(1 / sum(down_levels))
                blocks.append(block)
            self.down_blocks.append(nn.ModuleList(blocks))

        # UP BLOCKS
        self.up_blocks = nn.ModuleList()
        for i, num_blocks in enumerate(up_levels):
            blocks = []
            for j in range(num_blocks):
                # First block of each non-top level also consumes the encoder skip tensor.
                block = ResBlock(
                    c_levels[len(c_levels) - 1 - i],
                    c_levels[len(c_levels) - 1 - i] * 4,
                    c_clip + c_r,
                    c_levels[len(c_levels) - 1 - i] if (j == 0 and i > 0) else 0,
                )
                block.channelwise[-1].weight.data *= np.sqrt(1 / sum(up_levels))
                blocks.append(block)
            if i < len(up_levels) - 1:
                # Transposed conv doubles the spatial resolution between levels.
                blocks.append(
                    nn.ConvTranspose2d(
                        c_levels[len(c_levels) - 1 - i],
                        c_levels[len(c_levels) - 2 - i],
                        kernel_size=4,
                        stride=2,
                        padding=1,
                    )
                )
            self.up_blocks.append(nn.ModuleList(blocks))

        # Final 1x1 conv projects features back to per-token logits.
        self.clf = nn.Conv2d(c_levels[0], num_vec_classes, kernel_size=1)

    def gamma(self, r):
        """Cosine noise schedule: map r in [0, 1] to cos(r * pi / 2)."""
        return (r * torch.pi / 2).cos()

    def gen_r_embedding(self, r, max_positions=10000):
        """Sinusoidal embedding of the (schedule-warped) noise level r, shape (B, c_r)."""
        dtype = r.dtype
        r = self.gamma(r) * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad odd embedding widths
            emb = nn.functional.pad(emb, (0, 1), mode="constant")
        return emb.to(dtype)

    def _down_encode_(self, x, s):
        """Run the encoder; return per-level outputs, deepest level first."""
        level_outputs = []
        for blocks in self.down_blocks:
            for block in blocks:
                if isinstance(block, ResBlock):
                    x = block(x, s)
                else:
                    # Downsampling conv between levels.
                    x = block(x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, s):
        """Run the decoder over encoder outputs (deepest first), consuming skips."""
        x = level_outputs[0]
        for i, blocks in enumerate(self.up_blocks):
            for j, block in enumerate(blocks):
                if isinstance(block, ResBlock):
                    if i > 0 and j == 0:
                        # First block of each non-top level merges the encoder skip.
                        x = block(x, s, level_outputs[i])
                    else:
                        x = block(x, s)
                else:
                    # Upsampling transposed conv between levels.
                    x = block(x)
        return x

    def forward(self, x, c, r):  # r is a uniform value between 0 and 1
        """Predict per-token logits for token grid x under conditioning c and noise level r.

        Args:
            x: (B, H, W) long tensor of VQ token indices.
            c: (B, c_clip) global embedding, or (B, c_clip, H', W') spatial map.
            r: (B,) noise levels in [0, 1].

        Returns:
            (B, num_vec_classes, H, W) logits.
        """
        r_embed = self.gen_r_embedding(r)
        x = self.embedding(x).permute(0, 3, 1, 2)
        if len(c.shape) == 2:
            # Global conditioning: concat and broadcast as a 1x1 spatial map.
            s = torch.cat([c, r_embed], dim=-1)[:, :, None, None]
        else:
            # Spatial conditioning: expand r embedding to the map's resolution.
            r_embed = r_embed[:, :, None, None].expand(-1, -1, c.size(2), c.size(3))
            s = torch.cat([c, r_embed], dim=1)
        level_outputs = self._down_encode_(x, s)
        x = self._up_decode(level_outputs, s)
        x = self.clf(x)
        return x
There was a problem hiding this comment.
Hi @pcuenca (cc @patrickvonplaten), should I use layers, blocks, or models in the diffusers\src\diffusers\models folder to replace some parts of the original Paella model class, or should I keep the original Paella model class unchanged?
There was a problem hiding this comment.
Hey @aengusng8,
No worries! Thanks a lot for working on this :-)
It would be amazing if you could try to "mold" your code into the existing UNet2DConditionModel class:
Also we've just added a design philosophy that might help: https://huggingface.co/docs/diffusers/main/en/conceptual/philosophy
So it would be super cool if you could gauge whether it's possible to "force" the whole modeling code into UNet2DConditionModel — feel free to design your own, new unet up and down classes.
|
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
|
Any updates? Is this based on https://github.com/dome272/Paella ? |
Hi, @patrickvonplaten. This is my draft PR that we recently mentioned in Discord. I am 90% complete to move Paella to our library, and I think I need your help to finalize this progress.
What is done?
"Pipeline" and "Scheduler" are ready to run, check my Kaggle notebook: https://www.kaggle.com/code/aengusng/notebookd7ca68b633/notebook
Note: run this by CPU only in Colab, or CPU/GPU in Kaggle.
Current bottleneck problems?
I have a few questions that I would appreciate your help with:
- Should I use layers, blocks, or models in the `diffusers/src/diffusers/models` folder to replace some parts of the original Paella model class, or should I keep the original Paella model class unchanged?
- Should I keep `einops`, `rudalle`, and `open_clip_torch`, since they are part of the author's code? Their `vqvae` is initialized from `rudalle.get_vae`, and their `text_encoder` and `tokenizer` are initialized from `open_clip`.
- How should I save and upload the model classes/configurations of `vqvae`, `text_encoder`, and `tokenizer` that are outside of Diffusers (like https://huggingface.co/CompVis/stable-diffusion-v1-4)?

What is next?
Updated: Closed this PR because comparing internal and external models takes time and deliberation.