Skip to content

Commit

Permalink
Enhance package availability check to support multiple distribution n…
Browse files Browse the repository at this point in the history
…ames
  • Loading branch information
kazssym committed Jan 19, 2025
1 parent adcae38 commit 7c1ee2f
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions optimum/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,21 @@
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}


def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
def _is_package_available(pkg_name: str, return_version: bool = False, dist_names: Union[list[str], None] = None) -> Union[Tuple[bool, str], bool]:
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
package_version = importlib.metadata.version(pkg_name)
package_exists = True
except importlib.metadata.PackageNotFoundError:
if dist_names is None:
dist_names = [pkg_name]
for dist_name in dist_names:
try:
package_version = importlib.metadata.version(dist_name)
package_exists = True
break
except importlib.metadata.PackageNotFoundError:
pass
else:
package_exists = False
if return_version:
return package_exists, package_version
Expand All @@ -66,7 +72,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_torch_available, _torch_version = _is_package_available("torch", return_version=True)

# importlib.metadata.version seem to not be robust with the ONNX Runtime extensions (`onnxruntime-gpu`, etc.)
_onnxruntime_available = _is_package_available("onnxruntime", return_version=False)
_onnxruntime_available = _is_package_available("onnxruntime", return_version=False, dist_names=["onnxruntime, onnxruntime-gpu", "onnxruntime-directml"])

# TODO : Remove
torch_version = version.parse(importlib.metadata.version("torch")) if _torch_available else None
Expand Down

0 comments on commit 7c1ee2f

Please sign in to comment.