Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add __repr__ and model_list in ModelCreator #123

Merged
merged 5 commits into from
Jan 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

**Improvements**

- Refator `trunc_normal_` and `linspace` usage in Swin-T, Cross-Former, PVT and CSWin models. [#100](/~https://github.com/Oneflow-Inc/vision/pull/100)
- Refator `trunc_normal_` and `linspace` usage in Swin-T, Cross-Former, PVT and CSWin models [#100](/~https://github.com/Oneflow-Inc/vision/pull/100)
- Refator `Vision Transformer` model [#115](/~https://github.com/Oneflow-Inc/vision/pull/115)
- Refine `flowvision.models.ModelCreator` to support `ModelCreator.model_list` func [#123](/~https://github.com/Oneflow-Inc/vision/pull/123)


**Docs Update**
Expand Down
31 changes: 31 additions & 0 deletions flowvision/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
import oneflow as flow


def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]


class ModelCreator(object):
_model_list = defaultdict(
set
Expand Down Expand Up @@ -84,3 +88,30 @@ def model_table(filter="", pretrained=False, **kwargs):
table_items, headers=table_headers, tablefmt="fancy_grid", **kwargs
)
return table

@staticmethod
def model_list(filter="", pretrained=False, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

把文档加上吧,描述一下这个函数功能是什么

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

把文档加上吧,描述一下这个函数功能是什么

写到docs里了,在这个PR下 #124

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

all_models = ModelCreator._model_entrypoints.keys()
if filter:
models = []
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
for f in include_filters:
include_models = fnmatch.filter(all_models, f)
if len(include_models):
models = set(models).union(include_models)
else:
models = all_models

sorted_model = list(sorted(models))
if pretrained:
for model in sorted_model:
if not ModelCreator._model_list[model]:
sorted_model.remove(model)

return sorted_model

def __repr__(self) -> str:
all_model_table = ModelCreator.model_table("")
return "Registry of all models:\n" + all_model_table

__str__ = __repr__