diff --git a/julia/deps/build.jl b/julia/deps/build.jl index 9a719be850f5..badca65be577 100644 --- a/julia/deps/build.jl +++ b/julia/deps/build.jl @@ -51,9 +51,10 @@ elseif Sys.islinux() end if Sys.isunix() - try - push!(CUDAPATHS, replace(strip(read(`which nvcc`, String)), "bin/nvcc", "lib64")) - catch + nvcc_path = Sys.which("nvcc") + if nvcc_path ≢ nothing + @info "Found nvcc: $nvcc_path" + push!(CUDAPATHS, replace(nvcc_path, "bin/nvcc", "lib64")) end end diff --git a/julia/src/MXNet.jl b/julia/src/MXNet.jl index 70eda96da9d6..89ec88b52cdc 100644 --- a/julia/src/MXNet.jl +++ b/julia/src/MXNet.jl @@ -137,6 +137,7 @@ export to_graphviz include("base.jl") +include("runtime.jl") include("context.jl") include("util.jl") diff --git a/julia/src/base.jl b/julia/src/base.jl index 61779d194a94..683146402620 100644 --- a/julia/src/base.jl +++ b/julia/src/base.jl @@ -85,12 +85,11 @@ function mx_get_last_error() end "Utility macro to call MXNet API functions" -macro mxcall(fv, argtypes, args...) - f = eval(fv) +macro mxcall(f, argtypes, args...) args = map(esc, args) quote - _mxret = ccall(($(QuoteNode(f)), $MXNET_LIB), - Cint, $argtypes, $(args...)) + _mxret = ccall(($f, $MXNET_LIB), + Cint, $(esc(argtypes)), $(args...)) if _mxret != 0 err_msg = mx_get_last_error() throw(MXError(err_msg)) diff --git a/julia/src/runtime.jl b/julia/src/runtime.jl new file mode 100644 index 000000000000..cedcced9d29a --- /dev/null +++ b/julia/src/runtime.jl @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# License); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# runtime detection of compile time features in the native library + +module MXRuntime + +using ..mx + +export LibFeature +export feature_list, isenabled + +# defined in include/mxnet/c_api.h +struct LibFeature + _name::Ptr{Cchar} + enabled::Bool +end + +function Base.getproperty(x::LibFeature, p::Symbol) + (p == :name) && return unsafe_string(getfield(x, :_name)) + getfield(x, p) +end + +Base.show(io::IO, x::LibFeature) = + print(io, ifelse(x.enabled, "✔", "✖"), " ", x.name) + +""" + feature_list() + +Check the library for compile-time features. +The list of features are maintained in libinfo.h and libinfo.cc +""" +function feature_list() + ref = Ref{Ptr{LibFeature}}(C_NULL) + s = Ref{Csize_t}(C_NULL) + @mx.mxcall(:MXLibInfoFeatures, (Ref{Ptr{LibFeature}}, Ref{Csize_t}), ref, s) + unsafe_wrap(Array, ref[], s[]) +end + +""" + isenabled(x::Symbol)::Bool + +Returns the given runtime feature is enabled or not. + +```julia-repl +julia> mx.isenabled(:CUDA) +false + +julia> mx.isenabled(:CPU_SSE) +true +``` + +See also `mx.feature_list()`. +""" +isenabled(x::Symbol) = + any(feature_list()) do i + Symbol(i.name) == x && i.enabled + end + +end # module MXRuntime + +using .MXRuntime