Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Row gather linop #901

Merged
merged 11 commits into from
Feb 11, 2022
40 changes: 34 additions & 6 deletions common/unified/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "common/unified/base/kernel_launch.hpp"
#include "common/unified/base/kernel_launch_reduction.hpp"
#include "core/base/mixed_precision_types.hpp"
#include "core/components/prefix_sum_kernels.hpp"


Expand Down Expand Up @@ -413,25 +414,52 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_DENSE_INV_SYMM_PERMUTE_KERNEL);


template <typename ValueType, typename IndexType>
template <typename ValueType, typename OutputType, typename IndexType>
void row_gather(std::shared_ptr<const DefaultExecutor> exec,
const Array<IndexType>* row_indices,
const Array<IndexType>* row_idxs,
const matrix::Dense<ValueType>* orig,
matrix::Dense<ValueType>* row_gathered)
matrix::Dense<OutputType>* row_collection)
{
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto orig, auto rows, auto gathered) {
gathered(row, col) = orig(rows[row], col);
},
dim<2>{row_indices->get_num_elems(), orig->get_size()[1]}, orig,
*row_indices, row_gathered);
dim<2>{row_idxs->get_num_elems(), orig->get_size()[1]}, orig, *row_idxs,
row_collection);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(
GKO_DECLARE_DENSE_ROW_GATHER_KERNEL);


template <typename ValueType, typename OutputType, typename IndexType>
void advanced_row_gather(std::shared_ptr<const DefaultExecutor> exec,
const matrix::Dense<ValueType>* alpha,
const Array<IndexType>* row_idxs,
const matrix::Dense<ValueType>* orig,
const matrix::Dense<ValueType>* beta,
matrix::Dense<OutputType>* row_collection)
{
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto orig, auto rows,
auto beta, auto gathered) {
using type = device_type<highest_precision<ValueType, OutputType>>;
gathered(row, col) =
static_cast<type>(alpha[0] * orig(rows[row], col)) +
static_cast<type>(beta[0]) *
static_cast<type>(gathered(row, col));
},
dim<2>{row_idxs->get_num_elems(), orig->get_size()[1]},
alpha->get_const_values(), orig, *row_idxs, beta->get_const_values(),
row_collection);
}

GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(
GKO_DECLARE_DENSE_ADVANCED_ROW_GATHER_KERNEL);


template <typename ValueType, typename IndexType>
void column_permute(std::shared_ptr<const DefaultExecutor> exec,
const Array<IndexType>* permutation_indices,
Expand Down
3 changes: 2 additions & 1 deletion core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,18 @@ target_sources(ginkgo
matrix/permutation.cpp
matrix/sellp.cpp
matrix/sparsity_csr.cpp
matrix/row_gatherer.cpp
multigrid/amgx_pgm.cpp
preconditioner/isai.cpp
preconditioner/jacobi.cpp
reorder/rcm.cpp
solver/bicg.cpp
solver/bicgstab.cpp
solver/cb_gmres.cpp
solver/cg.cpp
solver/cgs.cpp
solver/fcg.cpp
solver/gmres.cpp
solver/cb_gmres.cpp
solver/idr.cpp
solver/ir.cpp
solver/lower_trs.cpp
Expand Down
136 changes: 136 additions & 0 deletions core/base/dispatch_helper.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*******************************<GINKGO LICENSE>******************************
Copyright (c) 2017-2022, the Ginkgo authors
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:

1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************<GINKGO LICENSE>*******************************/

#ifndef GKO_CORE_BASE_DISPATCH_HELPER_HPP_
#define GKO_CORE_BASE_DISPATCH_HELPER_HPP_


#include <memory>


#include <ginkgo/core/base/exception_helpers.hpp>


namespace gko {


/**
* run uses template to go through the list and select the valid
* template and run it.
*
* @tparam T the type of input object
* @tparam Func the function will run if the object can be converted to K
* @tparam ...Args the additional arguments for the Func
*
* @note this is the end case
*/
template <typename T, typename Func, typename... Args>
void run(T, Func, Args...)
{
GKO_NOT_IMPLEMENTED;
}

/**
* run uses template to go through the list and select the valid
* template and run it.
*
* @tparam K the current type tried in the convertion
* @tparam ...Types the other types will be tried in the conversion if K fails
* @tparam T the type of input object
* @tparam Func the function will run if the object can be converted to K
* @tparam ...Args the additional arguments for the Func
*
* @param obj the input object waiting converted
* @param f the function will run if obj can be converted successfully
* @param args the additional arguments for the function
*/
template <typename K, typename... Types, typename T, typename Func,
typename... Args>
void run(T obj, Func f, Args... args)
{
if (auto dobj = dynamic_cast<K>(obj)) {
f(dobj, args...);
} else {
run<Types...>(obj, f, args...);
}
}

/**
* run uses template to go through the list and select the valid
* template and run it.
*
* @tparam Base the Base class with one template
* @tparam T the type of input object waiting converted
* @tparam Func the validation
* @tparam ...Args the variadic arguments.
*
* @note this is the end case
*/
template <template <typename> class Base, typename T, typename Func,
typename... Args>
void run(T, Func, Args...)
{
GKO_NOT_IMPLEMENTED;
}

/**
* run uses template to go through the list and select the valid
* template and run it.
*
* @tparam Base the Base class with one template
* @tparam K the current template type of B. pointer of const Base<K> is tried
* in the convertion.
* @tparam ...Types the other types will be tried in the conversion if K fails
* @tparam T the type of input object waiting converted
* @tparam Func the function will run if the object can be converted to pointer
* of const Base<K>
* @tparam ...Args the additional arguments for the Func
*
* @param obj the input object waiting converted
* @param f the function will run if obj can be converted successfully
* @param args the additional arguments for the function
*/
template <template <typename> class Base, typename K, typename... Types,
typename T, typename func, typename... Args>
void run(T obj, func f, Args... args)
{
if (auto dobj = std::dynamic_pointer_cast<const Base<K>>(obj)) {
f(dobj, args...);
} else {
run<Base, Types...>(obj, f, args...);
}
}


} // namespace gko

#endif // GKO_CORE_BASE_DISPATCH_HELPER_HPP_
24 changes: 24 additions & 0 deletions core/base/mixed_precision_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,28 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE(_macro, int64)


#ifdef GINKGO_MIXED_PRECISION
#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(_macro, ...) \
template _macro(float, float, __VA_ARGS__); \
template _macro(float, double, __VA_ARGS__); \
template _macro(double, float, __VA_ARGS__); \
template _macro(double, double, __VA_ARGS__); \
template _macro(std::complex<float>, std::complex<float>, __VA_ARGS__); \
template _macro(std::complex<float>, std::complex<double>, __VA_ARGS__); \
template _macro(std::complex<double>, std::complex<float>, __VA_ARGS__); \
template _macro(std::complex<double>, std::complex<double>, __VA_ARGS__)
#else
#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(_macro, ...) \
template _macro(float, float, __VA_ARGS__); \
template _macro(double, double, __VA_ARGS__); \
template _macro(std::complex<float>, std::complex<float>, __VA_ARGS__); \
template _macro(std::complex<double>, std::complex<double>, __VA_ARGS__)
#endif


#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(_macro) \
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(_macro, int32); \
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(_macro, int64)


#endif // GKO_CORE_BASE_MIXED_PRECISION_TYPES_HPP_
11 changes: 10 additions & 1 deletion core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
GKO_NOT_COMPILED(GKO_HOOK_MODULE); \
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE(_macro)

#define GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2(_macro) \
template <typename InputValueType, typename OutputValueType, \
typename IndexType> \
_macro(InputValueType, OutputValueType, IndexType) \
GKO_NOT_COMPILED(GKO_HOOK_MODULE); \
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(_macro)

#define GKO_STUB_TEMPLATE_TYPE(_macro) \
template <typename IndexType> \
_macro(IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \
Expand Down Expand Up @@ -270,7 +277,9 @@ GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_TRANSPOSE_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_SYMM_PERMUTE_KERNEL);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_INV_SYMM_PERMUTE_KERNEL);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_ROW_GATHER_KERNEL);
GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2(GKO_DECLARE_DENSE_ROW_GATHER_KERNEL);
GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2(
GKO_DECLARE_DENSE_ADVANCED_ROW_GATHER_KERNEL);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_COLUMN_PERMUTE_KERNEL);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_INV_ROW_PERMUTE_KERNEL);
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_INV_COLUMN_PERMUTE_KERNEL);
Expand Down
Loading