diff --git a/docs/source/changelog.md b/docs/source/changelog.md index 54c991b7..60652e6a 100644 --- a/docs/source/changelog.md +++ b/docs/source/changelog.md @@ -14,8 +14,9 @@ **Improvements** -- Refator `trunc_normal_` and `linspace` usage in Swin-T, Cross-Former, PVT and CSWin models. [#100](/~https://github.com/Oneflow-Inc/vision/pull/100) +- Refator `trunc_normal_` and `linspace` usage in Swin-T, Cross-Former, PVT and CSWin models [#100](/~https://github.com/Oneflow-Inc/vision/pull/100) - Refator `Vision Transformer` model [#115](/~https://github.com/Oneflow-Inc/vision/pull/115) +- Refine `flowvision.models.ModelCreator` to support `ModelCreator.model_list` func [#123](/~https://github.com/Oneflow-Inc/vision/pull/123) **Docs Update** diff --git a/flowvision/models/registry.py b/flowvision/models/registry.py index c2c53b51..15775dfd 100644 --- a/flowvision/models/registry.py +++ b/flowvision/models/registry.py @@ -11,6 +11,10 @@ import oneflow as flow +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] + + class ModelCreator(object): _model_list = defaultdict( set @@ -84,3 +88,30 @@ def model_table(filter="", pretrained=False, **kwargs): table_items, headers=table_headers, tablefmt="fancy_grid", **kwargs ) return table + + @staticmethod + def model_list(filter="", pretrained=False, **kwargs): + all_models = ModelCreator._model_entrypoints.keys() + if filter: + models = [] + include_filters = filter if isinstance(filter, (tuple, list)) else [filter] + for f in include_filters: + include_models = fnmatch.filter(all_models, f) + if len(include_models): + models = set(models).union(include_models) + else: + models = all_models + + sorted_model = list(sorted(models)) + if pretrained: + for model in sorted_model: + if not ModelCreator._model_list[model]: + sorted_model.remove(model) + + return sorted_model + + def __repr__(self) -> str: + all_model_table = ModelCreator.model_table("") + return "Registry of all models:\n" + all_model_table + + __str__ = __repr__