Skip to content

Commit

Permalink
Add ignore and custom modules for aten backend
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed Jul 16, 2024
1 parent cc05220 commit f7506d6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
16 changes: 12 additions & 4 deletions ptflops/aten_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys
import traceback
from collections import defaultdict
from copy import deepcopy
from functools import partial
from typing import Optional, Tuple, Union

Expand All @@ -23,12 +24,15 @@

class FlopCounterMode(TorchDispatchMode):
def __init__(self, module=None, verbose=False, print_per_layer_stat=False,
output_params=None):
output_params=None, custom_hooks={}, ignored_ops=[]):
self.verbose = verbose
if output_params is None:
output_params = defaultdict(dict)
self.output_params = output_params
self.print_fn = partial(print, **self.output_params['print_params'])
self.all_ops = deepcopy(ATEN_OPS_MAPPING)
self.all_ops.update(custom_hooks)
self.ignored_ops = ignored_ops

self.print_per_layer_stat = print_per_layer_stat
self.flop_counts = defaultdict(lambda: defaultdict(int))
Expand Down Expand Up @@ -82,8 +86,11 @@ def normalize_tuple(x):

out = func(*args, **kwargs)
func_packet = func._overloadpacket
if func_packet in ATEN_OPS_MAPPING:
flop_count = ATEN_OPS_MAPPING[func_packet](args, normalize_tuple(out))

if func_packet in self.ignored_ops:
self.print_fn(f'Warning: {func_packet} operation is ignored')
elif func_packet in self.all_ops:
flop_count = self.all_ops[func_packet](args, normalize_tuple(out))
for par in self.parents:
self.flop_counts[par][func_packet] += flop_count
elif self.verbose:
Expand Down Expand Up @@ -119,7 +126,8 @@ def get_flops_aten(model, input_res,
batch = torch.ones(()).new_empty((1, *input_res))

try:
counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params)
counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params,
custom_modules_hooks, ignore_modules)
with counter:
if isinstance(batch, dict):
_ = model(**batch)
Expand Down
13 changes: 7 additions & 6 deletions ptflops/flops_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def get_model_complexity_info(model: nn.Module,
input_constructor: Optional[Callable[[Tuple], Dict]] = None,
ost: TextIO = sys.stdout,
verbose: bool = False,
ignore_modules: List[nn.Module] = [],
custom_modules_hooks: Dict[nn.Module, Any] = {},
ignore_modules: List[Union[nn.Module, Any]] = [],
custom_modules_hooks: Dict[Union[nn.Module, Any], Any] = {},
backend: Union[str, FLOPS_BACKEND] = FLOPS_BACKEND.ATEN,
flops_units: Optional[str] = None,
param_units: Optional[str] = None,
Expand Down Expand Up @@ -61,10 +61,11 @@ def get_model_complexity_info(model: nn.Module,
:type ost: TextIO
:param verbose: Parameter to control printing of extra information and warnings.
:type verbose: bool
:param ignore_modules: A list of torch.nn.Module modules to ignore.
:type ignore_modules: nn.Module
:param custom_modules_hooks: A dict that contains custom hooks on torch modules.
:type custom_modules_hooks: Dict[nn.Module, Any]
:param ignore_modules: A list of torch.nn.Module or torch.ops.aten modules to ignore.
:type ignore_modules: List[Union[nn.Module, Any]]
:param custom_modules_hooks: A dict that contains custom hooks for torch.nn.Module or
torch.ops.aten modules.
:type custom_modules_hooks: Dict[Union[nn.Module, Any], Any]
:param backend: Backend that is used for evaluating model complexity.
:type backend: FLOPS_BACKEND
:param flops_units: Units for string representation of MACs (GMac, MMac or KMac).
Expand Down

0 comments on commit f7506d6

Please sign in to comment.