Skip to content

Commit

Permalink
Update ldiv! methods for AMDGPU.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Nov 27, 2023
1 parent be04834 commit c0b50bf
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
18 changes: 14 additions & 4 deletions ext/AMDGPU/ic0.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,27 @@ end

for ArrayType in (:(ROCVector{T}), :(ROCMatrix{T}))
@eval begin
function ldiv!(ic::AMD_IC0{ROCSparseMatrixCSR{T,Cint}}, x::$ArrayType) where T <: BlasFloat
ldiv!(LowerTriangular(ic.P), x) # Forward substitution with L
ldiv!(LowerTriangular(ic.P)', x) # Backward substitution with Lᴴ
return x
end

function ldiv!(y::$ArrayType, ic::AMD_IC0{ROCSparseMatrixCSR{T,Cint}}, x::$ArrayType) where T <: BlasFloat
copyto!(y, x)
ldiv!(LowerTriangular(ic.P), y) # Forward substitution with L
ldiv!(LowerTriangular(ic.P)', y) # Backward substitution with Lᴴ
ldiv!(ic, y)
return y
end

function ldiv!(ic::AMD_IC0{ROCSparseMatrixCSC{T,Cint}}, x::$ArrayType) where T <: BlasReal
ldiv!(UpperTriangular(ic.P)', x) # Forward substitution with L
ldiv!(UpperTriangular(ic.P), x) # Backward substitution with Lᴴ
return x
end

function ldiv!(y::$ArrayType, ic::AMD_IC0{ROCSparseMatrixCSC{T,Cint}}, x::$ArrayType) where T <: BlasReal
copyto!(y, x)
ldiv!(UpperTriangular(ic.P)', y) # Forward substitution with L
ldiv!(UpperTriangular(ic.P), y) # Backward substitution with Lᴴ
ldiv!(ic, y)
return y
end
end
Expand Down
18 changes: 14 additions & 4 deletions ext/AMDGPU/ilu0.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,27 @@ end

for ArrayType in (:(ROCVector{T}), :(ROCMatrix{T}))
@eval begin
function ldiv!(ilu::AMD_ILU0{ROCSparseMatrixCSR{T,Cint}}, x::$ArrayType) where T <: BlasFloat
ldiv!(UnitLowerTriangular(ilu.P), x) # Forward substitution with L
ldiv!(UpperTriangular(ilu.P), x) # Backward substitution with U
return x
end

function ldiv!(y::$ArrayType, ilu::AMD_ILU0{ROCSparseMatrixCSR{T,Cint}}, x::$ArrayType) where T <: BlasFloat
copyto!(y, x)
ldiv!(UnitLowerTriangular(ilu.P), y) # Forward substitution with L
ldiv!(UpperTriangular(ilu.P), y) # Backward substitution with U
ldiv!(ilu, y)
return y
end

function ldiv!(y::$ArrayType, ilu::AMD_ILU0{ROCSparseMatrixCSC{T,Cint}}, x::$ArrayType) where T <: BlasReal
copyto!(y, x)
function ldiv!(ilu::AMD_ILU0{ROCSparseMatrixCSC{T,Cint}}, x::$ArrayType) where T <: BlasReal
ldiv!(LowerTriangular(ilu.P), y) # Forward substitution with L
ldiv!(UnitUpperTriangular(ilu.P), y) # Backward substitution with U
return x
end

function ldiv!(y::$ArrayType, ilu::AMD_ILU0{ROCSparseMatrixCSC{T,Cint}}, x::$ArrayType) where T <: BlasReal
copyto!(y, x)
ldiv!(ilu, y)
return y
end
end
Expand Down

0 comments on commit c0b50bf

Please sign in to comment.