Skip to content

Commit

Permalink
Use CUDAEx for CuArrays automatically (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf authored Jan 18, 2021
1 parent f888892 commit cf29e3e
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ CUDA = "2"
GPUArrays = "6"
InitialValues = "0.2"
SplittablesBase = "0.1"
Transducers = "0.4.54"
Transducers = "0.4.55"
julia = "1"
7 changes: 3 additions & 4 deletions examples/histogram_msd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ using FLoops
using FoldsCUDA
using Setfield

function histogram_msd(xs, ex = xs isa CuArray ? CUDAEx() : ThreadedEx())
function histogram_msd(xs, ex = nothing)
zs = ntuple(_ -> 0, 9) # a tuple of 9 zeros
@floop ex for x in xs
d = msd(x)
Expand Down Expand Up @@ -78,14 +78,13 @@ aspercentage(hist1)
# However, you need to explicitly specify to use CUDA by, e.g.,
# passing `CUDAEx` to `@floop`:

executor = has_cuda_gpu() ? CUDAEx() : ThreadedEx() # fallback to thread
hist2 = histogram_msd((x^2 for x in 1:10^8), executor)
hist2 = histogram_msd(x^2 for x in 1:10^8)

# Frequency in percentage:
aspercentage(hist2)
#-

hist3 = histogram_msd((exp(x) for x in range(1, 35, length=10^8)), executor)
hist3 = histogram_msd(exp(x) for x in range(1, 35, length=10^8))

# Frequency in percentage:
aspercentage(hist3)
2 changes: 2 additions & 0 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ popsimd(; simd = nothing, kwargs...) = kwargs

Transducers.transduce(xf, rf::RF, init, xs, exc::CUDAEx) where {RF} =
transduce_cuda(xf, rf, init, xs; popsimd(; exc.kwargs...)...)

Transducers.executor_type(::CuArray) = CUDAEx
8 changes: 4 additions & 4 deletions test/environments/main/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ version = "0.1.3"

[[FLoops]]
deps = ["Compat", "FLoopsBase", "JuliaVariables", "MLStyle", "Setfield", "Transducers"]
git-tree-sha1 = "4f5835d6d465f65544e4fbfaa4ee1e2dfb1bcbd3"
git-tree-sha1 = "e38ee2e0fdb95053d26593f92ffaa8353b9ff53f"
repo-rev = "release"
repo-url = "/~https://github.com/JuliaFolds/FLoops.jl.git"
uuid = "cc61a311-1640-44b5-9fba-1b764f453329"
version = "0.1.5"
version = "0.1.6"

[[FLoopsBase]]
deps = ["ContextVariablesX"]
Expand Down Expand Up @@ -370,11 +370,11 @@ version = "0.5.7"

[[Transducers]]
deps = ["ArgCheck", "BangBang", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"]
git-tree-sha1 = "266420fe31e9c86abcf0d4905c75b67dac087c33"
git-tree-sha1 = "6881d54b5c7235540804dcaf2d7b1320b2832b77"
repo-rev = "release"
repo-url = "/~https://github.com/JuliaFolds/Transducers.jl.git"
uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999"
version = "0.4.54"
version = "0.4.55"

[[UUIDs]]
deps = ["Random", "SHA"]
Expand Down

0 comments on commit cf29e3e

Please sign in to comment.