[Refactor] Using mmcv transformer bricks to refactor vit. #571
xvjiarui merged 14 commits into open-mmlab:master
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##           master     #571      +/-   ##
==========================================
- Coverage   85.95%   85.45%    -0.50%
==========================================
  Files         101      101
  Lines        5234     5220       -14
  Branches      828      840       +12
==========================================
- Hits         4499     4461       -38
- Misses        561      586       +25
+ Partials      174      173        -1
```

Flags with carried forward coverage won't be shown.
mmseg/models/utils/helpers.py (Outdated, new file)

```python
import logging
import collections.abc
from itertools import repeat


# From PyTorch internals
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
```
This file is not necessary. Use `from torch.nn.modules.utils import _pair as to_2tuple` instead.
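To illustrate the reviewer's point, here is a minimal stdlib-only sketch of the `_ntuple` helper above; `torch.nn.modules.utils._pair` (the suggested replacement) behaves the same way for `n=2`, so the hand-rolled file is redundant:

```python
import collections.abc
from itertools import repeat


# Sketch of the _ntuple factory from helpers.py: scalars are repeated
# into an n-tuple, iterables are passed through unchanged.
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)

print(to_2tuple(7))       # scalar repeated -> (7, 7)
print(to_2tuple((3, 5)))  # iterable passed through -> (3, 5)
```

Since this exact behavior already ships with PyTorch, importing `_pair` avoids maintaining a duplicate.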
mmseg/models/backbones/vit.py (Outdated)

```python
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
for n, m in self.named_modules():
    if isinstance(m, Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            if 'mlp' in n:
                normal_init(m.bias, std=1e-6)
            else:
                constant_init(m.bias, 0)
    elif isinstance(m, Conv2d):
        kaiming_init(m.weight, mode='fan_in')
        if m.bias is not None:
            constant_init(m.bias, 0)
    elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
        constant_init(m.bias, 0)
        constant_init(m.weight, 1.0)
else:
    raise TypeError('pretrained must be a str or None')
# Modified from ClassyVision
nn.init.normal_(self.pos_embed, std=0.02)
```
Why remove the initialization?
1. Use timm style init_weights; 2. Remove to_xtuple and trunc_norm_;
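For context on the `trunc_normal_` calls under discussion, here is an illustrative stdlib-only sketch of truncated-normal sampling via rejection sampling. Note this is a simplification: it truncates at two standard deviations for illustration, whereas the real `trunc_normal_` (in timm and `torch.nn.init`) takes absolute bounds `a` and `b`:

```python
import random


# Illustrative sketch only, not timm's implementation: redraw a Gaussian
# sample until it falls within [-2*std, 2*std], so no weight starts far
# out in the tails.
def trunc_normal_sample(std=0.02, num_std=2.0):
    while True:
        x = random.gauss(0.0, std)
        if -num_std * std <= x <= num_std * std:
            return x


samples = [trunc_normal_sample() for _ in range(1000)]
# Every sample lies within the truncation bounds.
assert all(-0.04 <= s <= 0.04 for s in samples)
```

The point of truncation in ViT init is to keep all initial weights close to zero, which timm's `jax_impl`-style scheme relies on.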
mmseg/models/backbones/vit.py (Outdated)

```python
else:
    state_dict = checkpoint

if 'rwightman/pytorch-image-models' in pretrained:
```
If the user downloaded the weights from timm and would like to init the model from a local path, the condition does not hold.
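The reviewer's concern can be made concrete with a small hypothetical sketch (`is_timm_url` is an illustrative name, not from the PR): keying on the URL substring misses timm weights that were downloaded to disk first, which motivates an explicit style flag instead of string sniffing:

```python
# Hypothetical reproduction of the fragile check in vit.py: it only
# recognizes timm weights when loaded directly from the upstream URL.
def is_timm_url(pretrained):
    return 'rwightman/pytorch-image-models' in pretrained


# Detected: weight referenced by its upstream repository URL.
print(is_timm_url('https://github.com/rwightman/pytorch-image-models/weights/vit.pth'))
# Missed: the same timm-style weight saved to a local path.
print(is_timm_url('/home/user/checkpoints/vit_base_timm.pth'))
```

The PR ultimately resolves this by adding a `pretrain_style` argument so the user states the weight format explicitly.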
.github/workflows/build.yml (Outdated)

```yaml
include:
  - torch: 1.3.0+cpu
    torchvision: 0.4.1+cpu
    torch_version: 1.3.0
```
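For readers unfamiliar with GitHub Actions, a sketch of how such an `include` entry sits inside a build matrix (the surrounding keys here are illustrative, not copied from the PR's workflow):

```yaml
# Illustrative matrix layout, assuming a job like the PR's build workflow.
strategy:
  matrix:
    python-version: [3.7]
    include:
      # Pin a matching CPU-only torch/torchvision pair; torch_version is
      # a plain variable the later steps can reference.
      - torch: 1.3.0+cpu
        torchvision: 0.4.1+cpu
        torch_version: 1.3.0
```

Each `include` entry extends the matrix with extra key/value pairs, which is how the CI pins compatible torch and torchvision wheels per job.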
```python
with_cp (bool): Use checkpoint or not. Using checkpoint will save
    some memory while slowing down the training speed. Default: False.
pretrain_style (str): Choose to use timm or mmcls pretrain weights.
    Default: timm.
```
We should explain what options are supported, and add an assert.
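A minimal sketch of what the reviewer asks for, with the two option names taken from the docstring above (the helper name is hypothetical, not from the PR):

```python
# Supported values for pretrain_style, per the docstring above.
SUPPORTED_PRETRAIN_STYLES = ('timm', 'mmcls')


def check_pretrain_style(pretrain_style):
    # Fail fast with a readable message instead of silently mis-loading
    # weights later in init_weights().
    assert pretrain_style in SUPPORTED_PRETRAIN_STYLES, (
        f'pretrain_style must be one of {SUPPORTED_PRETRAIN_STYLES}, '
        f'got {pretrain_style!r}')


check_pretrain_style('timm')   # accepted
check_pretrain_style('mmcls')  # accepted
```

In the backbone itself, such an assert would live in `__init__` so invalid configs are rejected at construction time.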
…#571)

* [Refactor] Using mmcv bricks to refactor vit
* Follow the vit code structure from mmclassification
* Add MMCV install into CI system.
* Add to 'Install MMCV' CI item
* Add 'Install MMCV_CPU' and 'Install MMCV_GPU CI' items
* Fix & Add
  1. Fix low code coverage of vit.py;
  2. Remove HybirdEmbed;
  3. Fix doc string of VisionTransformer;
* Add helpers unit test.
* Add converter to convert vit pretrain weights from timm style to mmcls style.
* Clean some rebundant code and refactor init
  1. Use timm style init_weights;
  2. Remove to_xtuple and trunc_norm_;
* Add comments for VisionTransformer.init_weights()
* Add arg: pretrain_style to choose timm or mmcls vit pretrain weights.
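The commit list mentions a converter from timm-style to mmcls-style pretrain weights. A hypothetical sketch of the general shape of such a key-remapping converter (the `blocks.` -> `layers.` rename is an assumed example; the actual mapping lives in the PR's converter script):

```python
# Hypothetical state_dict key converter; the real mapping in the PR's
# script differs and covers more keys.
def convert_timm_to_mmcls(state_dict):
    new_dict = {}
    for key, value in state_dict.items():
        # Assumed rename for illustration: timm names transformer blocks
        # 'blocks.N', while mmcls-style code uses 'layers.N'.
        new_key = key.replace('blocks.', 'layers.')
        new_dict[new_key] = value
    return new_dict


sd = {'blocks.0.attn.qkv.weight': 1, 'cls_token': 2}
print(convert_timm_to_mmcls(sd))
```

Because only the key strings are rewritten, the tensors themselves are reused untouched, so conversion is cheap even for large checkpoints.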
…en-mmlab#571)

* polish README
* fix typo
The foundation of this PR:
mmcv: open-mmlab/mmcv#978 (merged)
mmclassification: open-mmlab/mmpretrain#295