Skip to content

Commit

Permalink
move calibrate_bn_mixin to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoyang07 committed Dec 6, 2022
1 parent a3f78bc commit 6e5224a
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 11 deletions.
3 changes: 1 addition & 2 deletions mmrazor/engine/runner/evolution_search_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from mmrazor.registry import LOOPS, TASK_UTILS
from mmrazor.structures import Candidates, export_fix_subnet
from mmrazor.utils import SupportRandomSubnet
from .mixins import CalibrateBNMixin
from .utils import check_subnet_resources, crossover
from .utils import CalibrateBNMixin, check_subnet_resources, crossover


@LOOPS.register_module()
Expand Down
13 changes: 6 additions & 7 deletions mmrazor/engine/runner/subnet_val_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from mmrazor.models.utils import add_prefix
from mmrazor.registry import LOOPS
from .mixins import CalibrateBNMixin
from .utils import CalibrateBNMixin


@LOOPS.register_module()
Expand All @@ -33,12 +33,11 @@ def __init__(self,
self.evaluate_fixed_subnet = evaluate_fixed_subnet
self.calibrate_sample_num = calibrate_sample_num

# remove CheckpointHook to avoid extra problems when testing.
if self.evaluate_fixed_subnet:
for hook in self.runner._hooks:
if isinstance(hook, CheckpointHook):
self.runner._hooks.remove(hook)
break
# remove CheckpointHook to avoid extra problems.
for hook in self.runner._hooks:
if isinstance(hook, CheckpointHook):
self.runner._hooks.remove(hook)
break

def run(self):
"""Launch validation."""
Expand Down
3 changes: 2 additions & 1 deletion mmrazor/engine/runner/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .calibrate_bn_mixin import CalibrateBNMixin
from .check import check_subnet_resources
from .genetic import crossover

__all__ = ['crossover', 'check_subnet_resources']
__all__ = ['crossover', 'check_subnet_resources', 'CalibrateBNMixin']
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset

from mmrazor.engine.runner.mixins import CalibrateBNMixin
from mmrazor.engine.runner.utils import CalibrateBNMixin


class ToyModel(nn.Module):
Expand Down

0 comments on commit 6e5224a

Please sign in to comment.