From 435e19e923ede19a0cb2075a287e8cd94c2e371a Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 22 Nov 2023 15:01:04 -0800 Subject: [PATCH] [pytorch] Fixes windows load nvfuser_codegen bug --- bom/build.gradle | 18 +++++++++--------- .../main/java/ai/djl/pytorch/jni/LibUtils.java | 5 +---- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/bom/build.gradle b/bom/build.gradle index 736e8e4a59c..aa17702f99d 100644 --- a/bom/build.gradle +++ b/bom/build.gradle @@ -109,15 +109,15 @@ publishing { addDependency(dependencies, "ai.djl.mxnet", "mxnet-native-mkl", "linux-x86_64", "${mxnet_version}") addDependency(dependencies, "ai.djl.mxnet", "mxnet-native-mkl", "win-x86_64", "${mxnet_version}") addDependency(dependencies, "ai.djl.mxnet", "mxnet-native-cu112mkl", "linux-x86_64", "${mxnet_version}") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu", "osx-x86_64", "${pytorch_version}-SNAPSHOT") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu", "osx-aarch64", "${pytorch_version}-SNAPSHOT") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu", "linux-x86_64", "${pytorch_version}-SNAPSHOT") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu", "win-x86_64", "${pytorch_version}-SNAPSHOT") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu-precxx11", "linux-x86_64", "${pytorch_version}-SNAPSHOT") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu-precxx11", "linux-aarch64", "${pytorch_version}-SNAPSHOT") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121", "linux-x86_64", "${pytorch_version}-SNAPSHOT") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121", "win-x86_64", "${pytorch_version}-SNAPSHOT") - addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121-precxx11", "linux-x86_64", "${pytorch_version}-SNAPSHOT") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu", "osx-x86_64", "${pytorch_version}") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu", "osx-aarch64", "${pytorch_version}") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu", "linux-x86_64", "${pytorch_version}") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu", "win-x86_64", "${pytorch_version}") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu-precxx11", "linux-x86_64", "${pytorch_version}") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cpu-precxx11", "linux-aarch64", "${pytorch_version}") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121", "linux-x86_64", "${pytorch_version}") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121", "win-x86_64", "${pytorch_version}") + addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu121-precxx11", "linux-x86_64", "${pytorch_version}") addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu117", "linux-x86_64", "1.13.1") addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu117", "win-x86_64", "1.13.1") addDependency(dependencies, "ai.djl.pytorch", "pytorch-native-cu117-precxx11", "linux-x86_64", "1.13.1") diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java index 9d422463910..860c1bc48e8 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java @@ -106,10 +106,6 @@ public static String getLibtorchPath() { private static void loadLibTorch(LibTorch libTorch) { Path libDir = libTorch.dir.toAbsolutePath(); - if ("1.8.1".equals(getVersion()) && System.getProperty("os.name").startsWith("Mac")) { - // PyTorch 1.8.1 libtorch_cpu.dylib cannot be loaded individually - return; - } boolean isCuda = libTorch.flavor.contains("cu"); List deferred = Arrays.asList( @@ -120,6 +116,7 @@ private static void loadLibTorch(LibTorch libTorch) { System.mapLibraryName("torch_cuda_cpp"), System.mapLibraryName("torch_cuda_cu"), System.mapLibraryName("torch_cuda"), + System.mapLibraryName("nvfuser_codegen"), System.mapLibraryName("torch")); Set loadLater = new HashSet<>(deferred);