Skip to content

Commit

Permalink
metal lowbit kernels: pip install (#1785)
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelcandales authored Mar 1, 2025
1 parent 4a4925f commit 8f93751
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,4 @@ checkpoints/

# Experimental
torchao/experimental/cmake-out
torchao/experimental/deps
18 changes: 18 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,20 @@ def use_debug_mode():
CUDAExtension,
)

# Flag controlling whether the experimental MPS (Metal) ops get compiled.
# It is truthy only when all three conditions hold: the
# TORCHAO_BUILD_EXPERIMENTAL_MPS environment variable is exactly "1",
# `build_torchao_experimental` (defined outside this hunk) is truthy, and
# torch reports MPS as available.
build_torchao_experimental_mps = (
os.getenv("TORCHAO_BUILD_EXPERIMENTAL_MPS") == "1"
and build_torchao_experimental
and torch.mps.is_available()
)

# Diagnostics for users who asked for the MPS build: each check below only
# prints a message; none of them raises or modifies the flag computed above.
# NOTE(review): indentation was lost in this rendering of the diff; the three
# checks are presumably all nested directly under the env-var test -- confirm
# against setup.py.
if os.getenv("TORCHAO_BUILD_EXPERIMENTAL_MPS") == "1":
# `use_cpp` is defined outside this hunk; it is compared to the string "1".
if use_cpp != "1":
print("Building experimental MPS ops requires USE_CPP=1")
# Message is printed unless the machine string starts with "arm64" AND the
# system is "Darwin".
if not platform.machine().startswith("arm64") or platform.system() != "Darwin":
print("Experimental MPS ops require Apple Silicon.")
if not torch.mps.is_available():
print("MPS not available. Skipping compilation of experimental MPS ops.")

# Constant known variables used throughout this file
cwd = os.path.abspath(os.path.curdir)
third_party_path = os.path.join(cwd, "third_party")
Expand Down Expand Up @@ -174,15 +188,19 @@ def build_cmake(self, ext):
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)

build_mps_ops = "ON" if build_torchao_experimental_mps else "OFF"

subprocess.check_call(
[
"cmake",
ext.sourcedir,
"-DCMAKE_BUILD_TYPE=" + build_type,
# Disabled for now because 1) KleidiAI increases build time, and 2) KleidiAI has accuracy issues due to BF16
"-DTORCHAO_BUILD_KLEIDIAI=OFF",
"-DTORCHAO_BUILD_MPS_OPS=" + build_mps_ops,
"-DTorch_DIR=" + torch_dir,
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DCMAKE_INSTALL_PREFIX=cmake-out",
],
cwd=self.build_temp,
)
Expand Down
7 changes: 7 additions & 0 deletions torchao/experimental/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ if (NOT CMAKE_BUILD_TYPE)
endif()

option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." OFF)
option(TORCHAO_BUILD_MPS_OPS "Building torchao MPS ops" OFF)


if(NOT TORCHAO_INCLUDE_DIRS)
Expand Down Expand Up @@ -51,6 +52,12 @@ if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
torchao_ops_linear_8bit_act_xbit_weight_aten
torchao_ops_embedding_xbit_aten
)
if (TORCHAO_BUILD_MPS_OPS)
message(STATUS "Building with MPS support")
add_subdirectory(ops/mps)
target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten)
endif()

install(
TARGETS torchao_ops_aten
EXPORT _targets
Expand Down
3 changes: 2 additions & 1 deletion torchao/experimental/ops/mps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ find_package(Torch REQUIRED)
# Generate metal_shader_lib.h by running gen_metal_shader_lib.py
set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal)
file(GLOB METAL_FILES ${METAL_SHADERS_DIR}/*.metal)
set(METAL_SHADERS_YAML ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal.yaml)
set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py)
set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
add_custom_command(
OUTPUT ${GENERATED_METAL_SHADER_LIB}
COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB}
DEPENDS ${METAL_FILES} ${GEN_SCRIPT}
DEPENDS ${METAL_FILES} ${METAL_SHADERS_YAML} ${GEN_SCRIPT}
COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py"
)
add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB})
Expand Down
9 changes: 5 additions & 4 deletions torchao/experimental/ops/mps/test/test_lowbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@
import torch
from parameterized import parameterized

libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)
import torchao # noqa: F401

# Probe torch.ops.torchao for the 1- to 7-bit linear and pack-weight ops.
# Presumably these are already registered by `import torchao` when the MPS ops
# were built into the package -- TODO confirm.
try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError:
# At least one op is missing: fall back to loading the standalone dylib
# from ../cmake-out/lib/ relative to this test file's directory.
# NOTE(review): the bare `except:` below discards the original error and
# also catches KeyboardInterrupt/SystemExit; consider `except Exception as e`
# with `raise ... from e`.
try:
libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)
torch.ops.load_library(libpath)
except:
raise RuntimeError(f"Failed to load library {libpath}")
Expand Down
10 changes: 5 additions & 5 deletions torchao/experimental/ops/mps/test/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@
import torch
from parameterized import parameterized

import torchao # noqa: F401
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer, _quantize

libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)

# Probe torch.ops.torchao for the 1- to 7-bit linear and pack-weight ops.
# Presumably these are already registered by `import torchao` when the MPS ops
# were built into the package -- TODO confirm.
try:
for nbit in range(1, 8):
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
except AttributeError:
# At least one op is missing: fall back to loading the standalone dylib
# from ../cmake-out/lib/ relative to this test file's directory.
# NOTE(review): the bare `except:` below discards the original error and
# also catches KeyboardInterrupt/SystemExit; consider `except Exception as e`
# with `raise ... from e`.
try:
libname = "libtorchao_ops_mps_aten.dylib"
libpath = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
)
torch.ops.load_library(libpath)
except:
raise RuntimeError(f"Failed to load library {libpath}")
Expand Down

0 comments on commit 8f93751

Please sign in to comment.