Skip to content

Commit

Permalink
add __repr__ and model_list in ModelCreator (#123)
Browse files Browse the repository at this point in the history
* add __repr__ in Creator

* refine info

* add model_list into ModelCreator

* update ChangeLog

* auto format by CI

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
  • Loading branch information
rentainhe and oneflow-ci-bot authored Jan 22, 2022
1 parent 356207d commit d870085
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
3 changes: 2 additions & 1 deletion docs/source/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@

**Improvements**

- Refator `trunc_normal_` and `linspace` usage in Swin-T, Cross-Former, PVT and CSWin models. [#100](/~https://github.com/Oneflow-Inc/vision/pull/100)
- Refator `trunc_normal_` and `linspace` usage in Swin-T, Cross-Former, PVT and CSWin models [#100](/~https://github.com/Oneflow-Inc/vision/pull/100)
- Refator `Vision Transformer` model [#115](/~https://github.com/Oneflow-Inc/vision/pull/115)
- Refine `flowvision.models.ModelCreator` to support `ModelCreator.model_list` func [#123](/~https://github.com/Oneflow-Inc/vision/pull/123)


**Docs Update**
Expand Down
31 changes: 31 additions & 0 deletions flowvision/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
import oneflow as flow


def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]


class ModelCreator(object):
_model_list = defaultdict(
set
Expand Down Expand Up @@ -84,3 +88,30 @@ def model_table(filter="", pretrained=False, **kwargs):
table_items, headers=table_headers, tablefmt="fancy_grid", **kwargs
)
return table

@staticmethod
def model_list(filter="", pretrained=False, **kwargs):
all_models = ModelCreator._model_entrypoints.keys()
if filter:
models = []
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
for f in include_filters:
include_models = fnmatch.filter(all_models, f)
if len(include_models):
models = set(models).union(include_models)
else:
models = all_models

sorted_model = list(sorted(models))
if pretrained:
for model in sorted_model:
if not ModelCreator._model_list[model]:
sorted_model.remove(model)

return sorted_model

def __repr__(self) -> str:
all_model_table = ModelCreator.model_table("")
return "Registry of all models:\n" + all_model_table

__str__ = __repr__

0 comments on commit d870085

Please sign in to comment.