
[Bug] Checkpoint fails with "with_cp=True" in mmdet swin-transformer distillation.  #224

Description

@Mrxiaoyuer

Describe the bug


[Hi developers, thanks for your efforts. I was trying to distill Swin Transformers and set (with_cp=True) to train with a larger batch size, but it raises the errors attached below.]
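For context, with_cp=True in the MMCV Swin implementation enables gradient checkpointing, i.e. block activations are recomputed during backward instead of being stored. The snippet below is only a minimal illustrative sketch of that mechanism; the CheckpointedBlock class and its dimensions are made up for illustration and are not mmcv/mmdet code.

# Minimal sketch of gradient checkpointing as toggled by with_cp=True.
# Activations inside the wrapped forward are recomputed during backward,
# trading extra compute for memory so larger batch sizes fit.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):

    def __init__(self, dim=96):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.with_cp = True  # analogous to the config flag

    def forward(self, x):
        if self.with_cp and x.requires_grad:
            # Reentrant checkpointing: backward re-runs this forward,
            # firing a second set of autograd hooks on the same parameters.
            return checkpoint(self.mlp, x)
        return self.mlp(x)


if __name__ == '__main__':
    block = CheckpointedBlock()
    out = block(torch.randn(2, 49, 96, requires_grad=True))
    out.sum().backward()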

To Reproduce

[python ./tools/mmdet/train_mmdet.py ./configs/distill/cwd/my_config.py]

Post related information

  1. The output of pip list | grep "mmcv\|mmrazor\|^torch"
    [here]
  2. Your config file if you modified it or created a new one.
[_base_ = [
    '../../_base_/datasets/mmdet/coco_detection.py',
    '../../_base_/schedules/mmdet/schedule_1x.py',
    '../../_base_/mmdet_runtime.py'
]

pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth'  # noqa

student = dict(
    type='mmdet.MaskRCNN',
    backbone=dict(
        type='SwinTransformer',
        embed_dims=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        with_cp=True,
        convert_weights=True,
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(
        type='FPN',
        in_channels=[96, 192, 384, 768],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        ),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))

teacher_ckpt = 'https://github.com/Mrxiaoyuer/coco_checkpoints/releases/download/v0.0.2/swinl_epoch_24.pth'  # noqa

teacher = dict(
    type='mmdet.MaskRCNN',
    init_cfg=dict(type='Pretrained', checkpoint=teacher_ckpt),
    backbone=dict(
        type='SwinTransformer',
        embed_dims=192,
        depths=[2, 2, 18, 2],
        num_heads=[6, 12, 24, 48],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        with_cp=False,
        convert_weights=True,
        init_cfg=None),
    neck=dict(
        type='FPN',
        in_channels=[192, 384, 768, 1536],
        out_channels=256,
        start_level=0,
        add_extra_convs='on_output',
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
        ),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            mask_size=28,
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))


# algorithm setting
algorithm = dict(
    type='GeneralDistill',
    architecture=dict(
        type='MMDetArchitecture',
        model=student,
    ),
    distiller=dict(
        type='SingleTeacherDistiller',
        teacher=teacher,
        teacher_trainable=False,
        components=[
            dict(
                student_module='neck.fpn_convs.3.conv',
                teacher_module='neck.fpn_convs.3.conv',
                losses=[
                    dict(
                        type='ChannelWiseDivergence',
                        name='loss_cwd_cls_head',
                        tau=1,
                        loss_weight=5,
                    )
                ])
        ]),
)

find_unused_parameters = True

# optimizer
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))
lr_config = dict(warmup_iters=1000, step=[8, 11])
runner = dict(max_epochs=12)]
  3. Your train log file if you meet the problem during training.
    [loading annotations into memory...
    Done (t=51.46s)
    creating index...
    Done (t=51.98s)
    creating index...
    Done (t=51.93s)
    creating index...
    Done (t=51.97s)
    creating index...
    index created!
    index created!
    index created!
    index created!
    loading annotations into memory...
    loading annotations into memory...
    loading annotations into memory...
    loading annotations into memory...
    Done (t=1.01s)
    creating index...
    Done (t=1.24s)
    creating index...
    Done (t=1.25s)
    creating index...
    Done (t=1.02s)
    creating index...
    index created!
    index created!
    index created!
    index created!
    2022-08-08 20:44:14,703 - mmdet - INFO - Start running, host: root@da8f74b40bf146bf925ac3a4349ccef80000Z0, work_dir: /mnt/azureml/cr/j/89a4752b1bff40f9951561de40f847a6/cap/data-capability/wd/INPUT_3324fea54867408dbec17da2dfe673f2_data/fuxun/output/fuxun-mmrazor-cwd-swin-l-s_1x_adamw_2xbs/cwd_cls_head_swin-l-s_1x_coco_adamw_2xbs
    2022-08-08 20:44:14,704 - mmdet - INFO - Hooks will be executed in the following order:
    before_run:
    (VERY_HIGH ) StepLrUpdaterHook
    (ABOVE_NORMAL) Fp16OptimizerHook
    (NORMAL ) CheckpointHook
    (LOW ) DistEvalHook
    (VERY_LOW ) TextLoggerHook

before_train_epoch:
(VERY_HIGH ) StepLrUpdaterHook
(NORMAL ) NumClassCheckHook
(NORMAL ) DistSamplerSeedHook
(LOW ) IterTimerHook
(LOW ) DistEvalHook
(VERY_LOW ) TextLoggerHook

before_train_iter:
(VERY_HIGH ) StepLrUpdaterHook
(LOW ) IterTimerHook
(LOW ) DistEvalHook

after_train_iter:
(ABOVE_NORMAL) Fp16OptimizerHook
(NORMAL ) CheckpointHook
(LOW ) IterTimerHook
(LOW ) DistEvalHook
(VERY_LOW ) TextLoggerHook

after_train_epoch:
(NORMAL ) CheckpointHook
(LOW ) DistEvalHook
(VERY_LOW ) TextLoggerHook

before_val_epoch:
(NORMAL ) NumClassCheckHook
(NORMAL ) DistSamplerSeedHook
(LOW ) IterTimerHook
(VERY_LOW ) TextLoggerHook

before_val_iter:
(LOW ) IterTimerHook

after_val_iter:
(LOW ) IterTimerHook

after_val_epoch:
(VERY_LOW ) TextLoggerHook

after_run:
(VERY_LOW ) TextLoggerHook

2022-08-08 20:44:14,704 - mmdet - INFO - workflow: [('train', 1)], max: 12 epochs
2022-08-08 20:44:14,710 - mmdet - INFO - Checkpoints will be saved to /mnt/azureml/cr/j/89a4752b1bff40f9951561de40f847a6/cap/data-capability/wd/INPUT_3324fea54867408dbec17da2dfe673f2_data/fuxun/output/fuxun-mmrazor-cwd-swin-l-s_1x_adamw_2xbs/cwd_cls_head_swin-l-s_1x_coco_adamw_2xbs by HardDiskBackend.
/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448265233/work/c10/core/TensorImpl.h:1156.)
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448265233/work/c10/core/TensorImpl.h:1156.)
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448265233/work/c10/core/TensorImpl.h:1156.)
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1623448265233/work/c10/core/TensorImpl.h:1156.)
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Traceback (most recent call last):
File "tools/mmdet/train_mmdet.py", line 230, in
main()
File "tools/mmdet/train_mmdet.py", line 226, in main
meta=meta)
File "/tmp/code/mmrazor/apis/mmdet/train.py", line 206, in train_mmdet_model
runner.run(data_loader, cfg.workflow)
File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
epoch_runner(data_loaders[i], **kwargs)
File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 51, in train
self.call_hook('after_train_iter')
File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/base_runner.py", line 309, in call_hook
getattr(hook, fn_name)(self)
File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/hooks/optimizer.py", line 272, in after_train_iter
self.loss_scaler.scale(runner.outputs['loss']).backward()
File "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py", line 255, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/opt/conda/lib/python3.7/site-packages/torch/autograd/init.py", line 149, in backward
allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
File "/opt/conda/lib/python3.7/site-packages/torch/autograd/function.py", line 87, in apply
return self._forward_cls.backward(self, *args) # type: ignore[attr-defined]
File "/opt/conda/lib/python3.7/site-packages/torch/utils/checkpoint.py", line 138, in backward
torch.autograd.backward(outputs_with_grad, args_with_grad)
File "/opt/conda/lib/python3.7/site-packages/torch/autograd/init.py", line 149, in backward
allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the forward function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple checkpoint functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 323 has been marked as ready twice. This means that multiple autograd engine hooks have fired for this particular parameter during this iteration. You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print parameter names for further debugging.]
  4. Other code you modified in the mmrazor folder.
    [here]

Additional context


[(1) I tried adding (with_cp=True) in the r101_r50 CWD distillation and it works fine, so this is probably related to the Swin Transformer structure.
(2) I also tried (with_cp=True) by training the Swin Transformer alone in the mmdetection repo and it works fine.]
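As a possible workaround, the RuntimeError above itself suggests _set_static_graph() when the same checkpointed parameters take part in multiple reentrant backward passes under DDP. Below is only a hedged sketch of calling it on the DDP-wrapped model before the first iteration; where exactly to hook this into the mmrazor/mmdet training loop is an assumption, and enable_static_graph is an illustrative helper name, not an existing API.

# Illustrative helper (not part of mmrazor): mark the DDP graph as static
# so reentrant checkpoint backwards do not trip the "marked ready twice" check.
from torch.nn.parallel import DistributedDataParallel as DDP


def enable_static_graph(model):
    """Call once after DDP wrapping and before the first forward pass."""
    if isinstance(model, DDP) and hasattr(model, '_set_static_graph'):
        # Private API in the torch version shown in this log; newer PyTorch
        # releases expose the same behaviour as DDP(..., static_graph=True).
        model._set_static_graph()
    return model

Whether this is safe here depends on the distillation graph actually being static across iterations, which the error message lists as the precondition.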
