From e6b4d55aa10b4bb4b37e0f7de02a4461d3bb82c7 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Wed, 31 Oct 2018 09:30:57 -0700 Subject: [PATCH 01/12] Add header files required by horovod --- include/dlpack/dlpack.h | 141 + include/dmlc/any.h | 371 ++ include/dmlc/array_view.h | 128 + include/dmlc/base.h | 291 ++ include/dmlc/blockingconcurrentqueue.h | 991 +++++ include/dmlc/common.h | 85 + include/dmlc/concurrency.h | 258 ++ include/dmlc/concurrentqueue.h | 3719 +++++++++++++++++ include/dmlc/config.h | 186 + include/dmlc/data.h | 397 ++ include/dmlc/endian.h | 44 + include/dmlc/input_split_shuffle.h | 168 + include/dmlc/io.h | 522 +++ include/dmlc/json.h | 981 +++++ include/dmlc/logging.h | 424 ++ include/dmlc/lua.h | 739 ++++ include/dmlc/memory.h | 261 ++ include/dmlc/memory_io.h | 105 + include/dmlc/omp.h | 47 + include/dmlc/optional.h | 261 ++ include/dmlc/parameter.h | 1065 +++++ include/dmlc/recordio.h | 196 + include/dmlc/registry.h | 306 ++ include/dmlc/serializer.h | 410 ++ include/dmlc/thread_group.h | 808 ++++ include/dmlc/thread_local.h | 83 + include/dmlc/threadediter.h | 475 +++ include/dmlc/timer.h | 49 + include/dmlc/type_traits.h | 191 + include/mshadow/README.md | 8 + include/mshadow/base.h | 1106 +++++ include/mshadow/cuda/reduce.cuh | 120 + include/mshadow/cuda/tensor_gpu-inl.cuh | 828 ++++ include/mshadow/dot_engine-inl.h | 906 ++++ include/mshadow/expr_engine-inl.h | 482 +++ include/mshadow/expr_scalar-inl.h | 165 + include/mshadow/expression.h | 416 ++ include/mshadow/extension.h | 41 + include/mshadow/extension/broadcast.h | 165 + .../mshadow/extension/broadcast_with_axis.h | 258 ++ include/mshadow/extension/channel_pool.h | 108 + include/mshadow/extension/channel_unpool.h | 137 + include/mshadow/extension/choose.h | 90 + include/mshadow/extension/complex.h | 525 +++ include/mshadow/extension/concat.h | 194 + include/mshadow/extension/crop.h | 119 + include/mshadow/extension/fill.h | 103 + include/mshadow/extension/flip.h | 132 + include/mshadow/extension/implicit_gemm.h | 128 + include/mshadow/extension/mask.h | 97 + include/mshadow/extension/mirror.h | 62 + include/mshadow/extension/one_hot.h | 87 + include/mshadow/extension/pack_col2patch.h | 154 + include/mshadow/extension/pad.h | 111 + include/mshadow/extension/range.h | 118 + include/mshadow/extension/reduce_with_axis.h | 136 + include/mshadow/extension/reduceto1d.h | 104 + include/mshadow/extension/reshape.h | 87 + include/mshadow/extension/slice.h | 156 + include/mshadow/extension/slice_ex.h | 135 + include/mshadow/extension/spatial_pool.h | 152 + include/mshadow/extension/spatial_unpool.h | 135 + .../extension/spatial_upsampling_nearest.h | 71 + include/mshadow/extension/swapaxis.h | 110 + include/mshadow/extension/take.h | 99 + include/mshadow/extension/take_grad.h | 111 + include/mshadow/extension/transpose.h | 200 + include/mshadow/extension/unpack_patch2col.h | 151 + include/mshadow/half.h | 288 ++ include/mshadow/half2.h | 143 + include/mshadow/io.h | 137 + include/mshadow/logging.h | 234 ++ include/mshadow/packet-inl.h | 413 ++ include/mshadow/packet/plain-inl.h | 76 + include/mshadow/packet/sse-inl.h | 147 + include/mshadow/random.h | 570 +++ include/mshadow/stream_gpu-inl.h | 212 + include/mshadow/tensor.h | 1078 +++++ include/mshadow/tensor_container.h | 208 + include/mshadow/tensor_cpu-inl.h | 627 +++ include/mshadow/tensor_gpu-inl.h | 245 ++ include/nnvm/base.h | 35 + include/nnvm/c_api.h | 388 ++ include/nnvm/compiler/op_attr_types.h | 101 + include/nnvm/compiler/packed_func_ext.h | 59 + include/nnvm/compiler/util.h | 33 + include/nnvm/graph.h | 315 ++ include/nnvm/graph_attr_types.h | 112 + include/nnvm/layout.h | 455 ++ include/nnvm/node.h | 201 + include/nnvm/op.h | 562 +++ include/nnvm/op_attr_types.h | 219 + include/nnvm/pass.h | 128 + include/nnvm/pass_functions.h | 190 + include/nnvm/symbolic.h | 217 + include/nnvm/top/README | 1 + include/nnvm/top/nn.h | 498 +++ include/nnvm/top/tensor.h | 301 ++ include/nnvm/tuple.h | 633 +++ 99 files changed, 30835 insertions(+) create mode 100644 include/dlpack/dlpack.h create mode 100644 include/dmlc/any.h create mode 100644 include/dmlc/array_view.h create mode 100644 include/dmlc/base.h create mode 100644 include/dmlc/blockingconcurrentqueue.h create mode 100644 include/dmlc/common.h create mode 100644 include/dmlc/concurrency.h create mode 100644 include/dmlc/concurrentqueue.h create mode 100644 include/dmlc/config.h create mode 100644 include/dmlc/data.h create mode 100644 include/dmlc/endian.h create mode 100644 include/dmlc/input_split_shuffle.h create mode 100644 include/dmlc/io.h create mode 100644 include/dmlc/json.h create mode 100644 include/dmlc/logging.h create mode 100644 include/dmlc/lua.h create mode 100644 include/dmlc/memory.h create mode 100644 include/dmlc/memory_io.h create mode 100644 include/dmlc/omp.h create mode 100644 include/dmlc/optional.h create mode 100644 include/dmlc/parameter.h create mode 100644 include/dmlc/recordio.h create mode 100644 include/dmlc/registry.h create mode 100644 include/dmlc/serializer.h create mode 100644 include/dmlc/thread_group.h create mode 100644 include/dmlc/thread_local.h create mode 100644 include/dmlc/threadediter.h create mode 100644 include/dmlc/timer.h create mode 100644 include/dmlc/type_traits.h create mode 100644 include/mshadow/README.md create mode 100755 include/mshadow/base.h create mode 100644 include/mshadow/cuda/reduce.cuh create mode 100755 include/mshadow/cuda/tensor_gpu-inl.cuh create mode 100644 include/mshadow/dot_engine-inl.h create mode 100644 include/mshadow/expr_engine-inl.h create mode 100644 include/mshadow/expr_scalar-inl.h create mode 100644 include/mshadow/expression.h create mode 100644 include/mshadow/extension.h create mode 100644 include/mshadow/extension/broadcast.h create mode 100644 include/mshadow/extension/broadcast_with_axis.h create mode 100644 include/mshadow/extension/channel_pool.h create mode 100644 include/mshadow/extension/channel_unpool.h create mode 100644 include/mshadow/extension/choose.h create mode 100644 include/mshadow/extension/complex.h create mode 100644 include/mshadow/extension/concat.h create mode 100644 include/mshadow/extension/crop.h create mode 100644 include/mshadow/extension/fill.h create mode 100644 include/mshadow/extension/flip.h create mode 100644 include/mshadow/extension/implicit_gemm.h create mode 100644 include/mshadow/extension/mask.h create mode 100644 include/mshadow/extension/mirror.h create mode 100644 include/mshadow/extension/one_hot.h create mode 100644 include/mshadow/extension/pack_col2patch.h create mode 100644 include/mshadow/extension/pad.h create mode 100644 include/mshadow/extension/range.h create mode 100644 include/mshadow/extension/reduce_with_axis.h create mode 100644 include/mshadow/extension/reduceto1d.h create mode 100644 include/mshadow/extension/reshape.h create mode 100644 include/mshadow/extension/slice.h create mode 100644 include/mshadow/extension/slice_ex.h create mode 100644 include/mshadow/extension/spatial_pool.h create mode 100644 include/mshadow/extension/spatial_unpool.h create mode 100644 include/mshadow/extension/spatial_upsampling_nearest.h create mode 100644 include/mshadow/extension/swapaxis.h create mode 100644 include/mshadow/extension/take.h create mode 100644 include/mshadow/extension/take_grad.h create mode 100644 include/mshadow/extension/transpose.h create mode 100644 include/mshadow/extension/unpack_patch2col.h create mode 100644 include/mshadow/half.h create mode 100755 include/mshadow/half2.h create mode 100644 include/mshadow/io.h create mode 100644 include/mshadow/logging.h create mode 100644 include/mshadow/packet-inl.h create mode 100644 include/mshadow/packet/plain-inl.h create mode 100644 include/mshadow/packet/sse-inl.h create mode 100644 include/mshadow/random.h create mode 100644 include/mshadow/stream_gpu-inl.h create mode 100755 include/mshadow/tensor.h create mode 100644 include/mshadow/tensor_container.h create mode 100755 include/mshadow/tensor_cpu-inl.h create mode 100755 include/mshadow/tensor_gpu-inl.h create mode 100644 include/nnvm/base.h create mode 100644 include/nnvm/c_api.h create mode 100644 include/nnvm/compiler/op_attr_types.h create mode 100644 include/nnvm/compiler/packed_func_ext.h create mode 100644 include/nnvm/compiler/util.h create mode 100644 include/nnvm/graph.h create mode 100644 include/nnvm/graph_attr_types.h create mode 100644 include/nnvm/layout.h create mode 100644 include/nnvm/node.h create mode 100644 include/nnvm/op.h create mode 100644 include/nnvm/op_attr_types.h create mode 100644 include/nnvm/pass.h create mode 100644 include/nnvm/pass_functions.h create mode 100644 include/nnvm/symbolic.h create mode 100644 include/nnvm/top/README create mode 100644 include/nnvm/top/nn.h create mode 100644 include/nnvm/top/tensor.h create mode 100644 include/nnvm/tuple.h diff --git a/include/dlpack/dlpack.h b/include/dlpack/dlpack.h new file mode 100644 index 000000000000..f8dc8fcd2cdf --- /dev/null +++ b/include/dlpack/dlpack.h @@ -0,0 +1,141 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file dlpack.h + * \brief The common header of DLPack. + */ +#ifndef DLPACK_DLPACK_H_ +#define DLPACK_DLPACK_H_ + +#ifdef __cplusplus +#define DLPACK_EXTERN_C extern "C" +#else +#define DLPACK_EXTERN_C +#endif + +/*! \brief The current version of dlpack */ +#define DLPACK_VERSION 010 + +/*! \brief DLPACK_DLL prefix for windows */ +#ifdef _WIN32 +#ifdef DLPACK_EXPORTS +#define DLPACK_DLL __declspec(dllexport) +#else +#define DLPACK_DLL __declspec(dllimport) +#endif +#else +#define DLPACK_DLL +#endif + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif +/*! + * \brief The device type in DLContext. + */ +typedef enum { + kDLCPU = 1, + kDLGPU = 2, + // kDLCPUPinned = kDLCPU | kDLGPU + kDLCPUPinned = 3, + kDLOpenCL = 4, + kDLMetal = 8, + kDLVPI = 9, + kDLROCM = 10, +} DLDeviceType; + +/*! + * \brief A Device context for Tensor and operator. + */ +typedef struct { + /*! \brief The device type used in the device. */ + DLDeviceType device_type; + /*! \brief The device index */ + int device_id; +} DLContext; + +/*! + * \brief The type code options DLDataType. + */ +typedef enum { + kDLInt = 0U, + kDLUInt = 1U, + kDLFloat = 2U, +} DLDataTypeCode; + +/*! + * \brief The data type the tensor can hold. + * + * Examples + * - float: type_code = 2, bits = 32, lanes=1 + * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 + * - int8: type_code = 0, bits = 8, lanes=1 + */ +typedef struct { + /*! + * \brief Type code of base types. + * We keep it uint8_t instead of DLDataTypeCode for minimal memory + * footprint, but the value should be one of DLDataTypeCode enum values. + * */ + uint8_t code; + /*! + * \brief Number of bits, common choices are 8, 16, 32. + */ + uint8_t bits; + /*! \brief Number of lanes in the type, used for vector types. */ + uint16_t lanes; +} DLDataType; + +/*! + * \brief Plain C Tensor object, does not manage memory. + */ +typedef struct { + /*! + * \brief The opaque data pointer points to the allocated data. + * This will be CUDA device pointer or cl_mem handle in OpenCL. + * This pointer is always aligns to 256 bytes as in CUDA. + */ + void* data; + /*! \brief The device context of the tensor */ + DLContext ctx; + /*! \brief Number of dimensions */ + int ndim; + /*! \brief The data type of the pointer*/ + DLDataType dtype; + /*! \brief The shape of the tensor */ + int64_t* shape; + /*! + * \brief strides of the tensor, + * can be NULL, indicating tensor is compact. + */ + int64_t* strides; + /*! \brief The offset in bytes to the beginning pointer to data */ + uint64_t byte_offset; +} DLTensor; + +/*! + * \brief C Tensor object, manage memory of DLTensor. This data structure is + * intended to faciliate the borrowing of DLTensor by another framework. It is + * not meant to transfer the tensor. When the borrowing framework doesn't need + * the tensor, it should call the deleter to notify the host that the resource + * is no longer needed. + */ +typedef struct DLManagedTensor { + /*! \brief DLTensor which is being memory managed */ + DLTensor dl_tensor; + /*! \brief the context of the original host framework of DLManagedTensor in + * which DLManagedTensor is used in the framework. It can also be NULL. + */ + void * manager_ctx; + /*! \brief Destructor signature void (*)(void*) - this should be called + * to destruct manager_ctx which holds the DLManagedTensor. It can be NULL + * if there is no way for the caller to provide a reasonable destructor. + */ + void (*deleter)(struct DLManagedTensor * self); +} DLManagedTensor; +#ifdef __cplusplus +} // DLPACK_EXTERN_C +#endif +#endif // DLPACK_DLPACK_H_ diff --git a/include/dmlc/any.h b/include/dmlc/any.h new file mode 100644 index 000000000000..8041bf7ee53a --- /dev/null +++ b/include/dmlc/any.h @@ -0,0 +1,371 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file any.h + * \brief Container to hold any data type. + */ +#ifndef DMLC_ANY_H_ +#define DMLC_ANY_H_ + +// This code need c++11 to compile +#include +#include +#include +#include + +#include "./base.h" +#include "./logging.h" + +namespace dmlc { +// forward declare any; +class any; + +/*! + * Get a reference to content stored in the any as type T. + * This will cause an error if + * T does not match the type stored. + * This function is not part of std::any standard. + * + * \param src The source source any container. + * \return The reference of content + * \tparam T The type of the value to be fetched. + */ +template +inline T& get(any& src); // NOLINT(*) + +/*! + * Get the const reference content stored in the any as type T. + * This will cause an error if + * T does not match the type stored. + * This function is not part of std::any standard. + * + * \param src The source source any container. + * \return The reference of content + * \tparam T The type of the value to be fetched. + */ +template +inline const T& get(const any& src); + +/*! + * \brief An any class that is compatible to std::any in c++17. + * + * \code + * dmlc::any a = std::string("mydear"), b = 1; + * // get reference out and add it + * dmlc::get(b) += 1; + * // a is now string + * LOG(INFO) << dmlc::get(a); + * // a is now 2, the string stored will be properly destructed + * a = std::move(b); + * LOG(INFO) << dmlc::get(a); + * \endcode + * \sa get + */ +class any { + public: + /*! \brief default constructor */ + inline any() = default; + /*! + * \brief move constructor from another any + * \param other The other any to be moved + */ + inline any(any&& other); // NOLINT(*) + /*! + * \brief copy constructor + * \param other The other any to be copied + */ + inline any(const any& other); // NOLINT(*) + /*! + * \brief constructor from any types + * \param other The other types to be constructed into any. + * \tparam T The value type of other. + */ + template + inline any(T&& other); // NOLINT(*) + /*! \brief destructor */ + inline ~any(); + /*! + * \brief assign operator from other + * \param other The other any to be copy or moved. + * \return self + */ + inline any& operator=(any&& other); + /*! + * \brief assign operator from other + * \param other The other any to be copy or moved. + * \return self + */ + inline any& operator=(const any& other); + /*! + * \brief assign operator from any type. + * \param other The other any to be copy or moved. + * \tparam T The value type of other. + * \return self + */ + template + inline any& operator=(T&& other); + /*! + * \return whether the container is empty. + */ + inline bool empty() const; + /*! + * \brief clear the content of container + */ + inline void clear(); + /*! + * swap current content with other + * \param other The other data to be swapped. + */ + inline void swap(any& other); // NOLINT(*) + /*! + * \return The type_info about the stored type. + */ + inline const std::type_info& type() const; + /*! \brief Construct value of type T inplace */ + template + inline void construct(Args&&... args); + + private: + //! \cond Doxygen_Suppress + // declare of helper class + template + class TypeOnHeap; + template + class TypeOnStack; + template + class TypeInfo; + // size of stack space, it takes 32 bytes for one any type. + static const size_t kStack = sizeof(void*) * 3; + static const size_t kAlign = sizeof(void*); + // container use dynamic storage only when space runs lager + union Data { + // stack space + std::aligned_storage::type stack; + // pointer to heap space + void* pheap; + }; + // type specific information + struct Type { + // destructor function + void (*destroy)(Data* data); + // copy constructor + void (*create_from_data)(Data* dst, const Data& src); + // the type info function + const std::type_info* ptype_info; + }; + // constant to check if data can be stored on heap. + template + struct data_on_stack { + static const bool value = alignof(T) <= kAlign && sizeof(T) <= kStack; + }; + // declare friend with + template + friend T& get(any& src); // NOLINT(*) + template + friend const T& get(const any& src); + // internal construct function + inline void construct(any&& other); + // internal construct function + inline void construct(const any& other); + // internal function to check if type is correct. + template + inline void check_type() const; + // internal type specific information + const Type* type_{nullptr}; + // internal data + Data data_; +}; + +template +inline any::any(T&& other) { + typedef typename std::decay::type DT; + if (std::is_same::value) { + this->construct(std::forward(other)); + } else { + static_assert(std::is_copy_constructible
::value, + "Any can only hold value that is copy constructable"); + type_ = TypeInfo
::get_type(); + if (data_on_stack
::value) { +#pragma GCC diagnostic push +#if 6 <= __GNUC__ +#pragma GCC diagnostic ignored "-Wplacement-new" +#endif + new (&(data_.stack)) DT(std::forward(other)); +#pragma GCC diagnostic pop + } else { + data_.pheap = new DT(std::forward(other)); + } + } +} + +inline any::any(any&& other) { + this->construct(std::move(other)); +} + +inline any::any(const any& other) { + this->construct(other); +} + +inline void any::construct(any&& other) { + type_ = other.type_; + data_ = other.data_; + other.type_ = nullptr; +} + +inline void any::construct(const any& other) { + type_ = other.type_; + if (type_ != nullptr) { + type_->create_from_data(&data_, other.data_); + } +} + +template +inline void any::construct(Args&&... args) { + clear(); + typedef typename std::decay::type DT; + type_ = TypeInfo
::get_type(); + if (data_on_stack
::value) { +#pragma GCC diagnostic push +#if 6 <= __GNUC__ +#pragma GCC diagnostic ignored "-Wplacement-new" +#endif + new (&(data_.stack)) DT(std::forward(args)...); +#pragma GCC diagnostic pop + } else { + data_.pheap = new DT(std::forward(args)...); + } +} + +inline any::~any() { + this->clear(); +} + +inline any& any::operator=(any&& other) { + any(std::move(other)).swap(*this); + return *this; +} + +inline any& any::operator=(const any& other) { + any(other).swap(*this); + return *this; +} + +template +inline any& any::operator=(T&& other) { + any(std::forward(other)).swap(*this); + return *this; +} + +inline void any::swap(any& other) { // NOLINT(*) + std::swap(type_, other.type_); + std::swap(data_, other.data_); +} + +inline void any::clear() { + if (type_ != nullptr) { + if (type_->destroy != nullptr) { + type_->destroy(&data_); + } + type_ = nullptr; + } +} + +inline bool any::empty() const { + return type_ == nullptr; +} + +inline const std::type_info& any::type() const { + if (type_ != nullptr) { + return *(type_->ptype_info); + } else { + return typeid(void); + } +} + +template +inline void any::check_type() const { + CHECK(type_ != nullptr) + << "The any container is empty" + << " requested=" << typeid(T).name(); + CHECK(*(type_->ptype_info) == typeid(T)) + << "The stored type mismatch" + << " stored=" << type_->ptype_info->name() + << " requested=" << typeid(T).name(); +} + +template +inline const T& get(const any& src) { + src.check_type(); + return *any::TypeInfo::get_ptr(&(src.data_)); +} + +template +inline T& get(any& src) { // NOLINT(*) + src.check_type(); + return *any::TypeInfo::get_ptr(&(src.data_)); +} + +template +class any::TypeOnHeap { + public: + inline static T* get_ptr(any::Data* data) { + return static_cast(data->pheap); + } + inline static const T* get_ptr(const any::Data* data) { + return static_cast(data->pheap); + } + inline static void create_from_data(any::Data* dst, const any::Data& data) { + dst->pheap = new T(*get_ptr(&data)); + } + inline static void destroy(Data* data) { + delete static_cast(data->pheap); + } +}; + +template +class any::TypeOnStack { + public: + inline static T* get_ptr(any::Data* data) { + return reinterpret_cast(&(data->stack)); + } + inline static const T* get_ptr(const any::Data* data) { + return reinterpret_cast(&(data->stack)); + } + inline static void create_from_data(any::Data* dst, const any::Data& data) { + new (&(dst->stack)) T(*get_ptr(&data)); + } + inline static void destroy(Data* data) { + T* dptr = reinterpret_cast(&(data->stack)); + dptr->~T(); + } +}; + +template +class any::TypeInfo + : public std::conditional::value, + any::TypeOnStack, + any::TypeOnHeap >::type { + public: + inline static const Type* get_type() { + static TypeInfo tp; + return &(tp.type_); + } + + private: + // local type + Type type_; + // constructor + TypeInfo() { + if (std::is_pod::value && data_on_stack::value) { + type_.destroy = nullptr; + } else { + type_.destroy = TypeInfo::destroy; + } + type_.create_from_data = TypeInfo::create_from_data; + type_.ptype_info = &typeid(T); + } +}; +//! \endcond + +} // namespace dmlc + +#endif // DMLC_ANY_H_ diff --git a/include/dmlc/array_view.h b/include/dmlc/array_view.h new file mode 100644 index 000000000000..5e01a78cc53d --- /dev/null +++ b/include/dmlc/array_view.h @@ -0,0 +1,128 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file array_view.h + * \brief Read only data structure to reference array + */ +#ifndef DMLC_ARRAY_VIEW_H_ +#define DMLC_ARRAY_VIEW_H_ + +#include +#include + +namespace dmlc { + +/*! + * \brief Read only data structure to reference continuous memory region of array. + * Provide unified view for vector, array and C style array. + * This data structure do not guarantee aliveness of referenced array. + * + * Make sure do not use array_view to record data in async function closures. + * Also do not use array_view to create reference to temporary data structure. + * + * \tparam ValueType The value + * + * \code + * std::vector myvec{1,2,3}; + * dmlc::array_view view(myvec); + * // indexed visit to the view. + * LOG(INFO) << view[0]; + * + * for (int v : view) { + * // visit each element in the view + * } + * \endcode + */ +template +class array_view { + public: + /*! \brief default constructor */ + array_view() = default; + /*! + * \brief default copy constructor + * \param other another array view. + */ + array_view(const array_view &other) = default; // NOLINT(*) +#ifndef _MSC_VER + /*! + * \brief default move constructor + * \param other another array view. + */ + array_view(array_view&& other) = default; // NOLINT(*) +#else + /*! + * \brief default move constructor + * \param other another array view. + */ + array_view(array_view&& other) { // NOLINT(*) + begin_ = other.begin_; + size_ = other.size_; + other.begin_ = nullptr; + } +#endif + /*! + * \brief default assign constructor + * \param other another array view. + * \return self. + */ + array_view& operator=(const array_view& other) = default; // NOLINT(*) + /*! + * \brief construct array view std::vector + * \param other vector container + */ + array_view(const std::vector& other) { // NOLINT(*) + if (other.size() != 0) { + begin_ = &other[0]; size_ = other.size(); + } + } + /*! + * \brief construct array std::array + * \param other another array view. + */ + template + array_view(const std::array& other) { // NOLINT(*) + if (size != 0) { + begin_ = &other[0]; size_ = size; + } + } + /*! + * \brief construct array view from continuous segment + * \param begin beginning pointre + * \param end end pointer + */ + array_view(const ValueType* begin, const ValueType* end) { + if (begin < end) { + begin_ = begin; + size_ = end - begin; + } + } + /*! \return size of the array */ + inline size_t size() const { + return size_; + } + /*! \return begin of the array */ + inline const ValueType* begin() const { + return begin_; + } + /*! \return end point of the array */ + inline const ValueType* end() const { + return begin_ + size_; + } + /*! + * \brief get i-th element from the view + * \param i The index. + * \return const reference to i-th element. + */ + inline const ValueType& operator[](size_t i) const { + return begin_[i]; + } + + private: + /*! \brief the begin of the view */ + const ValueType* begin_{nullptr}; + /*! \brief The size of the view */ + size_t size_{0}; +}; + +} // namespace dmlc + +#endif // DMLC_ARRAY_VIEW_H_ diff --git a/include/dmlc/base.h b/include/dmlc/base.h new file mode 100644 index 000000000000..1caf487e9365 --- /dev/null +++ b/include/dmlc/base.h @@ -0,0 +1,291 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file base.h + * \brief defines configuration macros + */ +#ifndef DMLC_BASE_H_ +#define DMLC_BASE_H_ + +/*! \brief whether use glog for logging */ +#ifndef DMLC_USE_GLOG +#define DMLC_USE_GLOG 0 +#endif + +/*! + * \brief whether throw dmlc::Error instead of + * directly calling abort when FATAL error occured + * NOTE: this may still not be perfect. + * do not use FATAL and CHECK in destructors + */ +#ifndef DMLC_LOG_FATAL_THROW +#define DMLC_LOG_FATAL_THROW 1 +#endif + +/*! + * \brief whether always log a message before throw + * This can help identify the error that cannot be catched. + */ +#ifndef DMLC_LOG_BEFORE_THROW +#define DMLC_LOG_BEFORE_THROW 0 +#endif + +/*! + * \brief Whether to use customized logger, + * whose output can be decided by other libraries. + */ +#ifndef DMLC_LOG_CUSTOMIZE +#define DMLC_LOG_CUSTOMIZE 0 +#endif + +/*! + * \brief Whether to print stack trace for fatal error, + * enabled on linux when using gcc. + */ +#if (defined(__GNUC__) && !defined(__MINGW32__)\ + && !defined(__sun) && !defined(__SVR4)\ + && !(defined __MINGW64__) && !(defined __ANDROID__)) +#if (!defined(DMLC_LOG_STACK_TRACE)) +#define DMLC_LOG_STACK_TRACE 1 +#endif +#if (!defined(DMLC_LOG_STACK_TRACE_SIZE)) +#define DMLC_LOG_STACK_TRACE_SIZE 10 +#endif +#endif + +/*! \brief whether compile with hdfs support */ +#ifndef DMLC_USE_HDFS +#define DMLC_USE_HDFS 0 +#endif + +/*! \brief whether compile with s3 support */ +#ifndef DMLC_USE_S3 +#define DMLC_USE_S3 0 +#endif + +/*! \brief whether or not use parameter server */ +#ifndef DMLC_USE_PS +#define DMLC_USE_PS 0 +#endif + +/*! \brief whether or not use c++11 support */ +#ifndef DMLC_USE_CXX11 +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || defined(_MSC_VER) +#define DMLC_USE_CXX11 1 +#else +#define DMLC_USE_CXX11 (__cplusplus >= 201103L) +#endif +#endif + +/*! \brief strict CXX11 support */ +#ifndef DMLC_STRICT_CXX11 +#if defined(_MSC_VER) +#define DMLC_STRICT_CXX11 1 +#else +#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L) +#endif +#endif + +/*! \brief Whether cxx11 thread local is supported */ +#ifndef DMLC_CXX11_THREAD_LOCAL +#if defined(_MSC_VER) +#define DMLC_CXX11_THREAD_LOCAL (_MSC_VER >= 1900) +#elif defined(__clang__) +#define DMLC_CXX11_THREAD_LOCAL (__has_feature(cxx_thread_local)) +#else +#define DMLC_CXX11_THREAD_LOCAL (__cplusplus >= 201103L) +#endif +#endif + + +/*! \brief whether RTTI is enabled */ +#ifndef DMLC_ENABLE_RTTI +#define DMLC_ENABLE_RTTI 1 +#endif + +/*! \brief whether use fopen64 */ +#ifndef DMLC_USE_FOPEN64 +#define DMLC_USE_FOPEN64 1 +#endif + +/// check if g++ is before 4.6 +#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__) +#if __GNUC__ == 4 && __GNUC_MINOR__ < 6 +#pragma message("Will need g++-4.6 or higher to compile all" \ + "the features in dmlc-core, " \ + "compile without c++0x, some features may be disabled") +#undef DMLC_USE_CXX11 +#define DMLC_USE_CXX11 0 +#endif +#endif + +/*! + * \brief Use little endian for binary serialization + * if this is set to 0, use big endian. + */ +#ifndef DMLC_IO_USE_LITTLE_ENDIAN +#define DMLC_IO_USE_LITTLE_ENDIAN 1 +#endif + +/*! + * \brief Enable std::thread related modules, + * Used to disable some module in mingw compile. + */ +#ifndef DMLC_ENABLE_STD_THREAD +#define DMLC_ENABLE_STD_THREAD DMLC_USE_CXX11 +#endif + +/*! \brief whether enable regex support, actually need g++-4.9 or higher*/ +#ifndef DMLC_USE_REGEX +#define DMLC_USE_REGEX DMLC_STRICT_CXX11 +#endif + +/*! \brief helper macro to supress unused warning */ +#if defined(__GNUC__) +#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused)) +#else +#define DMLC_ATTRIBUTE_UNUSED +#endif + +/*! \brief helper macro to generate string concat */ +#define DMLC_STR_CONCAT_(__x, __y) __x##__y +#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y) + +/*! + * \brief Disable copy constructor and assignment operator. + * + * If C++11 is supported, both copy and move constructors and + * assignment operators are deleted explicitly. Otherwise, they are + * only declared but not implemented. Place this macro in private + * section if C++11 is not available. + */ +#ifndef DISALLOW_COPY_AND_ASSIGN +# if DMLC_USE_CXX11 +# define DISALLOW_COPY_AND_ASSIGN(T) \ + T(T const&) = delete; \ + T(T&&) = delete; \ + T& operator=(T const&) = delete; \ + T& operator=(T&&) = delete +# else +# define DISALLOW_COPY_AND_ASSIGN(T) \ + T(T const&); \ + T& operator=(T const&) +# endif +#endif + +#if DMLC_USE_FOPEN64 && \ + (!defined(__GNUC__) || (defined __ANDROID__) || ((defined __MINGW32__) && !(defined __MINGW64__))) +#define fopen64 std::fopen +#endif + +#ifdef __APPLE__ +# define off64_t off_t +# if DMLC_USE_FOPEN64 +# define fopen64 std::fopen +# endif +#endif + +#ifdef _MSC_VER +#if _MSC_VER < 1900 +// NOTE: sprintf_s is not equivalent to snprintf, +// they are equivalent when success, which is sufficient for our case +#define snprintf sprintf_s +#define vsnprintf vsprintf_s +#endif +#else +#ifdef _FILE_OFFSET_BITS +#if _FILE_OFFSET_BITS == 32 +#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit") +#endif +#endif + + +extern "C" { +#include +} +#endif + +#ifdef _MSC_VER +//! \cond Doxygen_Suppress +typedef signed char int8_t; +typedef __int16 int16_t; +typedef __int32 int32_t; +typedef __int64 int64_t; +typedef unsigned char uint8_t; +typedef unsigned __int16 uint16_t; +typedef unsigned __int32 uint32_t; +typedef unsigned __int64 uint64_t; +//! \endcond +#else +#include +#endif +#include +#include + +#if defined(_MSC_VER) && _MSC_VER < 1900 +#define noexcept_true throw () +#define noexcept_false +#define noexcept(a) noexcept_##a +#endif + +#if DMLC_USE_CXX11 +#define DMLC_THROW_EXCEPTION noexcept(false) +#define DMLC_NO_EXCEPTION noexcept(true) +#else +#define DMLC_THROW_EXCEPTION +#define DMLC_NO_EXCEPTION +#endif + +/*! \brief namespace for dmlc */ +namespace dmlc { +/*! + * \brief safely get the beginning address of a vector + * \param vec input vector + * \return beginning address of a vector + */ +template +inline T *BeginPtr(std::vector &vec) { // NOLINT(*) + if (vec.size() == 0) { + return NULL; + } else { + return &vec[0]; + } +} +/*! + * \brief get the beginning address of a const vector + * \param vec input vector + * \return beginning address of a vector + */ +template +inline const T *BeginPtr(const std::vector &vec) { + if (vec.size() == 0) { + return NULL; + } else { + return &vec[0]; + } +} +/*! + * \brief get the beginning address of a string + * \param str input string + * \return beginning address of a string + */ +inline char* BeginPtr(std::string &str) { // NOLINT(*) + if (str.length() == 0) return NULL; + return &str[0]; +} +/*! + * \brief get the beginning address of a const string + * \param str input string + * \return beginning address of a string + */ +inline const char* BeginPtr(const std::string &str) { + if (str.length() == 0) return NULL; + return &str[0]; +} +} // namespace dmlc + +#if defined(_MSC_VER) && _MSC_VER < 1900 +#define constexpr const +#define alignof __alignof +#endif + +#endif // DMLC_BASE_H_ diff --git a/include/dmlc/blockingconcurrentqueue.h b/include/dmlc/blockingconcurrentqueue.h new file mode 100644 index 000000000000..9d249430289b --- /dev/null +++ b/include/dmlc/blockingconcurrentqueue.h @@ -0,0 +1,991 @@ +//! \cond Doxygen_Suppress +// Provides an efficient blocking version of moodycamel::ConcurrentQueue. +// ©2015-2016 Cameron Desrochers. Distributed under the terms of the simplified +// BSD license, available at the top of concurrentqueue.h. +// Uses Jeff Preshing's semaphore implementation (under the terms of its +// separate zlib license, embedded below). + +#ifndef DMLC_BLOCKINGCONCURRENTQUEUE_H_ +#define DMLC_BLOCKINGCONCURRENTQUEUE_H_ + +#pragma once + +#include "concurrentqueue.h" +#include +#include +#include +#include +#include + +#if defined(_WIN32) +// Avoid including windows.h in a header; we only need a handful of +// items, so we'll redeclare them here (this is relatively safe since +// the API generally has to remain stable between Windows versions). +// I know this is an ugly hack but it still beats polluting the global +// namespace with thousands of generic names or adding a .cpp for nothing. +extern "C" { + struct _SECURITY_ATTRIBUTES; + __declspec(dllimport) void* __stdcall CreateSemaphoreW(_SECURITY_ATTRIBUTES* lpSemaphoreAttributes, long lInitialCount, long lMaximumCount, const wchar_t* lpName); + __declspec(dllimport) int __stdcall CloseHandle(void* hObject); + __declspec(dllimport) unsigned long __stdcall WaitForSingleObject(void* hHandle, unsigned long dwMilliseconds); + __declspec(dllimport) int __stdcall ReleaseSemaphore(void* hSemaphore, long lReleaseCount, long* lpPreviousCount); +} +#elif defined(__MACH__) +#include +#elif defined(__unix__) +#include +#endif + +namespace dmlc { + +namespace moodycamel +{ +namespace details +{ + // Code in the mpmc_sema namespace below is an adaptation of Jeff Preshing's + // portable + lightweight semaphore implementations, originally from + // /~https://github.com/preshing/cpp11-on-multicore/blob/master/common/sema.h + // LICENSE: + // Copyright (c) 2015 Jeff Preshing + // + // This software is provided 'as-is', without any express or implied + // warranty. In no event will the authors be held liable for any damages + // arising from the use of this software. + // + // Permission is granted to anyone to use this software for any purpose, + // including commercial applications, and to alter it and redistribute it + // freely, subject to the following restrictions: + // + // 1. The origin of this software must not be misrepresented; you must not + // claim that you wrote the original software. If you use this software + // in a product, an acknowledgement in the product documentation would be + // appreciated but is not required. + // 2. Altered source versions must be plainly marked as such, and must not be + // misrepresented as being the original software. + // 3. This notice may not be removed or altered from any source distribution. + namespace mpmc_sema + { +#if defined(_WIN32) + class Semaphore + { + private: + void* m_hSema; + + Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; + Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; + + public: + Semaphore(int initialCount = 0) + { + assert(initialCount >= 0); + const long maxLong = 0x7fffffff; + m_hSema = CreateSemaphoreW(nullptr, initialCount, maxLong, nullptr); + } + + ~Semaphore() + { + CloseHandle(m_hSema); + } + + void wait() + { + const unsigned long infinite = 0xffffffff; + WaitForSingleObject(m_hSema, infinite); + } + + bool try_wait() + { + const unsigned long RC_WAIT_TIMEOUT = 0x00000102; + return WaitForSingleObject(m_hSema, 0) != RC_WAIT_TIMEOUT; + } + + bool timed_wait(std::uint64_t usecs) + { + const unsigned long RC_WAIT_TIMEOUT = 0x00000102; + return WaitForSingleObject(m_hSema, (unsigned long)(usecs / 1000)) != RC_WAIT_TIMEOUT; + } + + void signal(int count = 1) + { + ReleaseSemaphore(m_hSema, count, nullptr); + } + }; +#elif defined(__MACH__) + //--------------------------------------------------------- + // Semaphore (Apple iOS and OSX) + // Can't use POSIX semaphores due to http://lists.apple.com/archives/darwin-kernel/2009/Apr/msg00010.html + //--------------------------------------------------------- + class Semaphore + { + private: + semaphore_t m_sema; + + Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; + Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; + + public: + Semaphore(int initialCount = 0) + { + assert(initialCount >= 0); + semaphore_create(mach_task_self(), &m_sema, SYNC_POLICY_FIFO, initialCount); + } + + ~Semaphore() + { + semaphore_destroy(mach_task_self(), m_sema); + } + + void wait() + { + semaphore_wait(m_sema); + } + + bool try_wait() + { + return timed_wait(0); + } + + bool timed_wait(std::uint64_t timeout_usecs) + { + mach_timespec_t ts; + ts.tv_sec = static_cast(timeout_usecs / 1000000); + ts.tv_nsec = (timeout_usecs % 1000000) * 1000; + + // added in OSX 10.10: https://developer.apple.com/library/prerelease/mac/documentation/General/Reference/APIDiffsMacOSX10_10SeedDiff/modules/Darwin.html + kern_return_t rc = semaphore_timedwait(m_sema, ts); + + return rc != KERN_OPERATION_TIMED_OUT; + } + + void signal() + { + semaphore_signal(m_sema); + } + + void signal(int count) + { + while (count-- > 0) + { + semaphore_signal(m_sema); + } + } + }; +#elif defined(__unix__) + //--------------------------------------------------------- + // Semaphore (POSIX, Linux) + //--------------------------------------------------------- + class Semaphore + { + private: + sem_t m_sema; + + Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; + Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; + + public: + Semaphore(int initialCount = 0) + { + assert(initialCount >= 0); + sem_init(&m_sema, 0, initialCount); + } + + ~Semaphore() + { + sem_destroy(&m_sema); + } + + void wait() + { + // http://stackoverflow.com/questions/2013181/gdb-causes-sem-wait-to-fail-with-eintr-error + int rc; + do { + rc = sem_wait(&m_sema); + } while (rc == -1 && errno == EINTR); + } + + bool try_wait() + { + int rc; + do { + rc = sem_trywait(&m_sema); + } while (rc == -1 && errno == EINTR); + return !(rc == -1 && errno == EAGAIN); + } + + bool timed_wait(std::uint64_t usecs) + { + struct timespec ts; + const int usecs_in_1_sec = 1000000; + const int nsecs_in_1_sec = 1000000000; + clock_gettime(CLOCK_REALTIME, &ts); + ts.tv_sec += usecs / usecs_in_1_sec; + ts.tv_nsec += (usecs % usecs_in_1_sec) * 1000; + // sem_timedwait bombs if you have more than 1e9 in tv_nsec + // so we have to clean things up before passing it in + if (ts.tv_nsec >= nsecs_in_1_sec) { + ts.tv_nsec -= nsecs_in_1_sec; + ++ts.tv_sec; + } + + int rc; + do { + rc = sem_timedwait(&m_sema, &ts); + } while (rc == -1 && errno == EINTR); + return !(rc == -1 && errno == ETIMEDOUT); + } + + void signal() + { + sem_post(&m_sema); + } + + void signal(int count) + { + while (count-- > 0) + { + sem_post(&m_sema); + } + } + }; +#else +#error Unsupported platform! (No semaphore wrapper available) +#endif + + //--------------------------------------------------------- + // LightweightSemaphore + //--------------------------------------------------------- + class LightweightSemaphore + { + public: + typedef std::make_signed::type ssize_t; + + private: + std::atomic m_count; + Semaphore m_sema; + + bool waitWithPartialSpinning(std::int64_t timeout_usecs = -1) + { + ssize_t oldCount; + // Is there a better way to set the initial spin count? + // If we lower it to 1000, testBenaphore becomes 15x slower on my Core i7-5930K Windows PC, + // as threads start hitting the kernel semaphore. + int spin = 10000; + while (--spin >= 0) + { + oldCount = m_count.load(std::memory_order_relaxed); + if ((oldCount > 0) && m_count.compare_exchange_strong(oldCount, oldCount - 1, std::memory_order_acquire, std::memory_order_relaxed)) + return true; + std::atomic_signal_fence(std::memory_order_acquire); // Prevent the compiler from collapsing the loop. + } + oldCount = m_count.fetch_sub(1, std::memory_order_acquire); + if (oldCount > 0) + return true; + if (timeout_usecs < 0) + { + m_sema.wait(); + return true; + } + if (m_sema.timed_wait((std::uint64_t)timeout_usecs)) + return true; + // At this point, we've timed out waiting for the semaphore, but the + // count is still decremented indicating we may still be waiting on + // it. So we have to re-adjust the count, but only if the semaphore + // wasn't signaled enough times for us too since then. If it was, we + // need to release the semaphore too. + while (true) + { + oldCount = m_count.load(std::memory_order_acquire); + if (oldCount >= 0 && m_sema.try_wait()) + return true; + if (oldCount < 0 && m_count.compare_exchange_strong(oldCount, oldCount + 1, std::memory_order_relaxed, std::memory_order_relaxed)) + return false; + } + } + + ssize_t waitManyWithPartialSpinning(ssize_t max, std::int64_t timeout_usecs = -1) + { + assert(max > 0); + ssize_t oldCount; + int spin = 10000; + while (--spin >= 0) + { + oldCount = m_count.load(std::memory_order_relaxed); + if (oldCount > 0) + { + ssize_t newCount = oldCount > max ? oldCount - max : 0; + if (m_count.compare_exchange_strong(oldCount, newCount, std::memory_order_acquire, std::memory_order_relaxed)) + return oldCount - newCount; + } + std::atomic_signal_fence(std::memory_order_acquire); + } + oldCount = m_count.fetch_sub(1, std::memory_order_acquire); + if (oldCount <= 0) + { + if (timeout_usecs < 0) + m_sema.wait(); + else if (!m_sema.timed_wait((std::uint64_t)timeout_usecs)) + { + while (true) + { + oldCount = m_count.load(std::memory_order_acquire); + if (oldCount >= 0 && m_sema.try_wait()) + break; + if (oldCount < 0 && m_count.compare_exchange_strong(oldCount, oldCount + 1, std::memory_order_relaxed, std::memory_order_relaxed)) + return 0; + } + } + } + if (max > 1) + return 1 + tryWaitMany(max - 1); + return 1; + } + + public: + LightweightSemaphore(ssize_t initialCount = 0) : m_count(initialCount) + { + assert(initialCount >= 0); + } + + bool tryWait() + { + ssize_t oldCount = m_count.load(std::memory_order_relaxed); + while (oldCount > 0) + { + if (m_count.compare_exchange_weak(oldCount, oldCount - 1, std::memory_order_acquire, std::memory_order_relaxed)) + return true; + } + return false; + } + + void wait() + { + if (!tryWait()) + waitWithPartialSpinning(); + } + + bool wait(std::int64_t timeout_usecs) + { + return tryWait() || waitWithPartialSpinning(timeout_usecs); + } + + // Acquires between 0 and (greedily) max, inclusive + ssize_t tryWaitMany(ssize_t max) + { + assert(max >= 0); + ssize_t oldCount = m_count.load(std::memory_order_relaxed); + while (oldCount > 0) + { + ssize_t newCount = oldCount > max ? oldCount - max : 0; + if (m_count.compare_exchange_weak(oldCount, newCount, std::memory_order_acquire, std::memory_order_relaxed)) + return oldCount - newCount; + } + return 0; + } + + // Acquires at least one, and (greedily) at most max + ssize_t waitMany(ssize_t max, std::int64_t timeout_usecs) + { + assert(max >= 0); + ssize_t result = tryWaitMany(max); + if (result == 0 && max > 0) + result = waitManyWithPartialSpinning(max, timeout_usecs); + return result; + } + + ssize_t waitMany(ssize_t max) + { + ssize_t result = waitMany(max, -1); + assert(result > 0); + return result; + } + + void signal(ssize_t count = 1) + { + assert(count >= 0); + ssize_t oldCount = m_count.fetch_add(count, std::memory_order_release); + ssize_t toRelease = -oldCount < count ? -oldCount : count; + if (toRelease > 0) + { + m_sema.signal((int)toRelease); + } + } + + ssize_t availableApprox() const + { + ssize_t count = m_count.load(std::memory_order_relaxed); + return count > 0 ? count : 0; + } + }; + } // end namespace mpmc_sema +} // end namespace details + + +// This is a blocking version of the queue. It has an almost identical interface to +// the normal non-blocking version, with the addition of various wait_dequeue() methods +// and the removal of producer-specific dequeue methods. +template +class BlockingConcurrentQueue +{ +private: + typedef ::dmlc::moodycamel::ConcurrentQueue ConcurrentQueue; + typedef details::mpmc_sema::LightweightSemaphore LightweightSemaphore; + +public: + typedef typename ConcurrentQueue::producer_token_t producer_token_t; + typedef typename ConcurrentQueue::consumer_token_t consumer_token_t; + + typedef typename ConcurrentQueue::index_t index_t; + typedef typename ConcurrentQueue::size_t size_t; + typedef typename std::make_signed::type ssize_t; + + static const size_t BLOCK_SIZE = ConcurrentQueue::BLOCK_SIZE; + static const size_t EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD = ConcurrentQueue::EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD; + static const size_t EXPLICIT_INITIAL_INDEX_SIZE = ConcurrentQueue::EXPLICIT_INITIAL_INDEX_SIZE; + static const size_t IMPLICIT_INITIAL_INDEX_SIZE = ConcurrentQueue::IMPLICIT_INITIAL_INDEX_SIZE; + static const size_t INITIAL_IMPLICIT_PRODUCER_HASH_SIZE = ConcurrentQueue::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE; + static const std::uint32_t EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE = ConcurrentQueue::EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE; + static const size_t MAX_SUBQUEUE_SIZE = ConcurrentQueue::MAX_SUBQUEUE_SIZE; + +public: + // Creates a queue with at least `capacity` element slots; note that the + // actual number of elements that can be inserted without additional memory + // allocation depends on the number of producers and the block size (e.g. if + // the block size is equal to `capacity`, only a single block will be allocated + // up-front, which means only a single producer will be able to enqueue elements + // without an extra allocation -- blocks aren't shared between producers). + // This method is not thread safe -- it is up to the user to ensure that the + // queue is fully constructed before it starts being used by other threads (this + // includes making the memory effects of construction visible, possibly with a + // memory barrier). + explicit BlockingConcurrentQueue(size_t capacity = 6 * BLOCK_SIZE) + : inner(capacity), sema(create(), &BlockingConcurrentQueue::template destroy) + { + assert(reinterpret_cast((BlockingConcurrentQueue*)1) == &((BlockingConcurrentQueue*)1)->inner && "BlockingConcurrentQueue must have ConcurrentQueue as its first member"); + if (!sema) { + MOODYCAMEL_THROW(std::bad_alloc()); + } + } + + BlockingConcurrentQueue(size_t minCapacity, size_t maxExplicitProducers, size_t maxImplicitProducers) + : inner(minCapacity, maxExplicitProducers, maxImplicitProducers), sema(create(), &BlockingConcurrentQueue::template destroy) + { + assert(reinterpret_cast((BlockingConcurrentQueue*)1) == &((BlockingConcurrentQueue*)1)->inner && "BlockingConcurrentQueue must have ConcurrentQueue as its first member"); + if (!sema) { + MOODYCAMEL_THROW(std::bad_alloc()); + } + } + + // Disable copying and copy assignment + BlockingConcurrentQueue(BlockingConcurrentQueue const&) MOODYCAMEL_DELETE_FUNCTION; + BlockingConcurrentQueue& operator=(BlockingConcurrentQueue const&) MOODYCAMEL_DELETE_FUNCTION; + + // Moving is supported, but note that it is *not* a thread-safe operation. + // Nobody can use the queue while it's being moved, and the memory effects + // of that move must be propagated to other threads before they can use it. + // Note: When a queue is moved, its tokens are still valid but can only be + // used with the destination queue (i.e. semantically they are moved along + // with the queue itself). + BlockingConcurrentQueue(BlockingConcurrentQueue&& other) MOODYCAMEL_NOEXCEPT + : inner(std::move(other.inner)), sema(std::move(other.sema)) + { } + + inline BlockingConcurrentQueue& operator=(BlockingConcurrentQueue&& other) MOODYCAMEL_NOEXCEPT + { + return swap_internal(other); + } + + // Swaps this queue's state with the other's. Not thread-safe. + // Swapping two queues does not invalidate their tokens, however + // the tokens that were created for one queue must be used with + // only the swapped queue (i.e. the tokens are tied to the + // queue's movable state, not the object itself). + inline void swap(BlockingConcurrentQueue& other) MOODYCAMEL_NOEXCEPT + { + swap_internal(other); + } + +private: + BlockingConcurrentQueue& swap_internal(BlockingConcurrentQueue& other) + { + if (this == &other) { + return *this; + } + + inner.swap(other.inner); + sema.swap(other.sema); + return *this; + } + +public: + // Enqueues a single item (by copying it). + // Allocates memory if required. Only fails if memory allocation fails (or implicit + // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, + // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(T const& item) + { + if (details::likely(inner.enqueue(item))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by moving it, if possible). + // Allocates memory if required. Only fails if memory allocation fails (or implicit + // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, + // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(T&& item) + { + if (details::likely(inner.enqueue(std::move(item)))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by copying it) using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails (or + // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(producer_token_t const& token, T const& item) + { + if (details::likely(inner.enqueue(token, item))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by moving it, if possible) using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails (or + // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(producer_token_t const& token, T&& item) + { + if (details::likely(inner.enqueue(token, std::move(item)))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues several items. + // Allocates memory if required. Only fails if memory allocation fails (or + // implicit production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE + // is 0, or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Note: Use std::make_move_iterator if the elements should be moved instead of copied. + // Thread-safe. + template + inline bool enqueue_bulk(It itemFirst, size_t count) + { + if (details::likely(inner.enqueue_bulk(std::forward(itemFirst), count))) { + sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); + return true; + } + return false; + } + + // Enqueues several items using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails + // (or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + inline bool enqueue_bulk(producer_token_t const& token, It itemFirst, size_t count) + { + if (details::likely(inner.enqueue_bulk(token, std::forward(itemFirst), count))) { + sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); + return true; + } + return false; + } + + // Enqueues a single item (by copying it). + // Does not allocate memory. Fails if not enough room to enqueue (or implicit + // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE + // is 0). + // Thread-safe. + inline bool try_enqueue(T const& item) + { + if (inner.try_enqueue(item)) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by moving it, if possible). + // Does not allocate memory (except for one-time implicit producer). + // Fails if not enough room to enqueue (or implicit production is + // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). + // Thread-safe. + inline bool try_enqueue(T&& item) + { + if (inner.try_enqueue(std::move(item))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by copying it) using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Thread-safe. + inline bool try_enqueue(producer_token_t const& token, T const& item) + { + if (inner.try_enqueue(token, item)) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by moving it, if possible) using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Thread-safe. + inline bool try_enqueue(producer_token_t const& token, T&& item) + { + if (inner.try_enqueue(token, std::move(item))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues several items. + // Does not allocate memory (except for one-time implicit producer). + // Fails if not enough room to enqueue (or implicit production is + // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + inline bool try_enqueue_bulk(It itemFirst, size_t count) + { + if (inner.try_enqueue_bulk(std::forward(itemFirst), count)) { + sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); + return true; + } + return false; + } + + // Enqueues several items using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + inline bool try_enqueue_bulk(producer_token_t const& token, It itemFirst, size_t count) + { + if (inner.try_enqueue_bulk(token, std::forward(itemFirst), count)) { + sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); + return true; + } + return false; + } + + + // Attempts to dequeue from the queue. + // Returns false if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline bool try_dequeue(U& item) + { + if (sema->tryWait()) { + while (!inner.try_dequeue(item)) { + continue; + } + return true; + } + return false; + } + + // Attempts to dequeue from the queue using an explicit consumer token. + // Returns false if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline bool try_dequeue(consumer_token_t& token, U& item) + { + if (sema->tryWait()) { + while (!inner.try_dequeue(token, item)) { + continue; + } + return true; + } + return false; + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued. + // Returns 0 if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline size_t try_dequeue_bulk(It itemFirst, size_t max) + { + size_t count = 0; + max = (size_t)sema->tryWaitMany((LightweightSemaphore::ssize_t)(ssize_t)max); + while (count != max) { + count += inner.template try_dequeue_bulk(itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue using an explicit consumer token. + // Returns the number of items actually dequeued. + // Returns 0 if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline size_t try_dequeue_bulk(consumer_token_t& token, It itemFirst, size_t max) + { + size_t count = 0; + max = (size_t)sema->tryWaitMany((LightweightSemaphore::ssize_t)(ssize_t)max); + while (count != max) { + count += inner.template try_dequeue_bulk(token, itemFirst, max - count); + } + return count; + } + + + + // Blocks the current thread until there's something to dequeue, then + // dequeues it. + // Never allocates. Thread-safe. + template + inline void wait_dequeue(U& item) + { + sema->wait(); + while (!inner.try_dequeue(item)) { + continue; + } + } + + // Blocks the current thread until either there's something to dequeue + // or the timeout (specified in microseconds) expires. Returns false + // without setting `item` if the timeout expires, otherwise assigns + // to `item` and returns true. + // Using a negative timeout indicates an indefinite timeout, + // and is thus functionally equivalent to calling wait_dequeue. + // Never allocates. Thread-safe. + template + inline bool wait_dequeue_timed(U& item, std::int64_t timeout_usecs) + { + if (!sema->wait(timeout_usecs)) { + return false; + } + while (!inner.try_dequeue(item)) { + continue; + } + return true; + } + + // Blocks the current thread until either there's something to dequeue + // or the timeout expires. Returns false without setting `item` if the + // timeout expires, otherwise assigns to `item` and returns true. + // Never allocates. Thread-safe. + template + inline bool wait_dequeue_timed(U& item, std::chrono::duration const& timeout) + { + return wait_dequeue_timed(item, std::chrono::duration_cast(timeout).count()); + } + + // Blocks the current thread until there's something to dequeue, then + // dequeues it using an explicit consumer token. + // Never allocates. Thread-safe. + template + inline void wait_dequeue(consumer_token_t& token, U& item) + { + sema->wait(); + while (!inner.try_dequeue(token, item)) { + continue; + } + } + + // Blocks the current thread until either there's something to dequeue + // or the timeout (specified in microseconds) expires. Returns false + // without setting `item` if the timeout expires, otherwise assigns + // to `item` and returns true. + // Using a negative timeout indicates an indefinite timeout, + // and is thus functionally equivalent to calling wait_dequeue. + // Never allocates. Thread-safe. + template + inline bool wait_dequeue_timed(consumer_token_t& token, U& item, std::int64_t timeout_usecs) + { + if (!sema->wait(timeout_usecs)) { + return false; + } + while (!inner.try_dequeue(token, item)) { + continue; + } + return true; + } + + // Blocks the current thread until either there's something to dequeue + // or the timeout expires. Returns false without setting `item` if the + // timeout expires, otherwise assigns to `item` and returns true. + // Never allocates. Thread-safe. + template + inline bool wait_dequeue_timed(consumer_token_t& token, U& item, std::chrono::duration const& timeout) + { + return wait_dequeue_timed(token, item, std::chrono::duration_cast(timeout).count()); + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued, which will + // always be at least one (this method blocks until the queue + // is non-empty) and at most max. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk(It itemFirst, size_t max) + { + size_t count = 0; + max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max); + while (count != max) { + count += inner.template try_dequeue_bulk(itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued, which can + // be 0 if the timeout expires while waiting for elements, + // and at most max. + // Using a negative timeout indicates an indefinite timeout, + // and is thus functionally equivalent to calling wait_dequeue_bulk. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk_timed(It itemFirst, size_t max, std::int64_t timeout_usecs) + { + size_t count = 0; + max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max, timeout_usecs); + while (count != max) { + count += inner.template try_dequeue_bulk(itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued, which can + // be 0 if the timeout expires while waiting for elements, + // and at most max. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk_timed(It itemFirst, size_t max, std::chrono::duration const& timeout) + { + return wait_dequeue_bulk_timed(itemFirst, max, std::chrono::duration_cast(timeout).count()); + } + + // Attempts to dequeue several elements from the queue using an explicit consumer token. + // Returns the number of items actually dequeued, which will + // always be at least one (this method blocks until the queue + // is non-empty) and at most max. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk(consumer_token_t& token, It itemFirst, size_t max) + { + size_t count = 0; + max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max); + while (count != max) { + count += inner.template try_dequeue_bulk(token, itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue using an explicit consumer token. + // Returns the number of items actually dequeued, which can + // be 0 if the timeout expires while waiting for elements, + // and at most max. + // Using a negative timeout indicates an indefinite timeout, + // and is thus functionally equivalent to calling wait_dequeue_bulk. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk_timed(consumer_token_t& token, It itemFirst, size_t max, std::int64_t timeout_usecs) + { + size_t count = 0; + max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max, timeout_usecs); + while (count != max) { + count += inner.template try_dequeue_bulk(token, itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue using an explicit consumer token. + // Returns the number of items actually dequeued, which can + // be 0 if the timeout expires while waiting for elements, + // and at most max. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk_timed(consumer_token_t& token, It itemFirst, size_t max, std::chrono::duration const& timeout) + { + return wait_dequeue_bulk_timed(token, itemFirst, max, std::chrono::duration_cast(timeout).count()); + } + + + // Returns an estimate of the total number of elements currently in the queue. This + // estimate is only accurate if the queue has completely stabilized before it is called + // (i.e. all enqueue and dequeue operations have completed and their memory effects are + // visible on the calling thread, and no further operations start while this method is + // being called). + // Thread-safe. + inline size_t size_approx() const + { + return (size_t)sema->availableApprox(); + } + + + // Returns true if the underlying atomic variables used by + // the queue are lock-free (they should be on most platforms). + // Thread-safe. + static bool is_lock_free() + { + return ConcurrentQueue::is_lock_free(); + } + + +private: + template + static inline U* create() + { + auto p = (Traits::malloc)(sizeof(U)); + return p != nullptr ? new (p) U : nullptr; + } + + template + static inline U* create(A1&& a1) + { + auto p = (Traits::malloc)(sizeof(U)); + return p != nullptr ? new (p) U(std::forward(a1)) : nullptr; + } + + template + static inline void destroy(U* p) + { + if (p != nullptr) { + p->~U(); + } + (Traits::free)(p); + } + +private: + ConcurrentQueue inner; + std::unique_ptr sema; +}; + + +template +inline void swap(BlockingConcurrentQueue& a, BlockingConcurrentQueue& b) MOODYCAMEL_NOEXCEPT +{ + a.swap(b); +} + +} // end namespace moodycamel +} // namespace dmlc + +#endif // DMLC_BLOCKINGCONCURRENTQUEUE_H_ +//! \endcond Doxygen_Suppress diff --git a/include/dmlc/common.h b/include/dmlc/common.h new file mode 100644 index 000000000000..9aead8c5b142 --- /dev/null +++ b/include/dmlc/common.h @@ -0,0 +1,85 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file common.h + * \brief defines some common utility function. + */ +#ifndef DMLC_COMMON_H_ +#define DMLC_COMMON_H_ + +#include +#include +#include +#include +#include "./logging.h" + +namespace dmlc { +/*! + * \brief Split a string by delimiter + * \param s String to be splitted. + * \param delim The delimiter. + * \return a splitted vector of strings. + */ +inline std::vector Split(const std::string& s, char delim) { + std::string item; + std::istringstream is(s); + std::vector ret; + while (std::getline(is, item, delim)) { + ret.push_back(item); + } + return ret; +} + +/*! + * \brief hash an object and combines the key with previous keys + */ +template +inline size_t HashCombine(size_t key, const T& value) { + std::hash hash_func; + return key ^ (hash_func(value) + 0x9e3779b9 + (key << 6) + (key >> 2)); +} + +/*! + * \brief specialize for size_t + */ +template<> +inline size_t HashCombine(size_t key, const size_t& value) { + return key ^ (value + 0x9e3779b9 + (key << 6) + (key >> 2)); +} + +/*! + * \brief OMP Exception class catches, saves and rethrows exception from OMP blocks + */ +class OMPException { + private: + // exception_ptr member to store the exception + std::exception_ptr omp_exception_; + // mutex to be acquired during catch to set the exception_ptr + std::mutex mutex_; + + public: + /*! + * \brief Parallel OMP blocks should be placed within Run to save exception + */ + template + void Run(Function f, Parameters... params) { + try { + f(params...); + } catch (dmlc::Error &ex) { + std::lock_guard lock(mutex_); + if (!omp_exception_) { + omp_exception_ = std::current_exception(); + } + } + } + + /*! + * \brief should be called from the main thread to rethrow the exception + */ + void Rethrow() { + if (this->omp_exception_) std::rethrow_exception(this->omp_exception_); + } +}; + +} // namespace dmlc + +#endif // DMLC_COMMON_H_ diff --git a/include/dmlc/concurrency.h b/include/dmlc/concurrency.h new file mode 100644 index 000000000000..754cf5aa286e --- /dev/null +++ b/include/dmlc/concurrency.h @@ -0,0 +1,258 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file concurrency.h + * \brief thread-safe data structures. + * \author Yutian Li + */ +#ifndef DMLC_CONCURRENCY_H_ +#define DMLC_CONCURRENCY_H_ +// this code depends on c++11 +#if DMLC_USE_CXX11 +#include +#include +#include +#include +#include +#include +#include "dmlc/base.h" + +namespace dmlc { + +/*! + * \brief Simple userspace spinlock implementation. + */ +class Spinlock { + public: +#ifdef _MSC_VER + Spinlock() { + lock_.clear(); + } +#else +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wbraced-scalar-init" + Spinlock() : lock_(ATOMIC_FLAG_INIT) { + } +#pragma clang diagnostic pop +#endif + ~Spinlock() = default; + /*! + * \brief Acquire lock. + */ + inline void lock() noexcept(true); + /*! + * \brief Release lock. + */ + inline void unlock() noexcept(true); + + private: + std::atomic_flag lock_; + /*! + * \brief Disable copy and move. + */ + DISALLOW_COPY_AND_ASSIGN(Spinlock); +}; + +/*! \brief type of concurrent queue */ +enum class ConcurrentQueueType { + /*! \brief FIFO queue */ + kFIFO, + /*! \brief queue with priority */ + kPriority +}; + +/*! + * \brief Cocurrent blocking queue. + */ +template +class ConcurrentBlockingQueue { + public: + ConcurrentBlockingQueue(); + ~ConcurrentBlockingQueue() = default; + /*! + * \brief Push element to the end of the queue. + * \param e Element to push into. + * \param priority the priority of the element, only used for priority queue. + * The higher the priority is, the better. + * \tparam E the element type + * + * It will copy or move the element into the queue, depending on the type of + * the parameter. + */ + template + void Push(E&& e, int priority = 0); + + /*! + * \brief Push element to the front of the queue. Only works for FIFO queue. + * For priority queue it is the same as Push. + * \param e Element to push into. + * \param priority the priority of the element, only used for priority queue. + * The higher the priority is, the better. + * \tparam E the element type + * + * It will copy or move the element into the queue, depending on the type of + * the parameter. + */ + template + void PushFront(E&& e, int priority = 0); + /*! + * \brief Pop element from the queue. + * \param rv Element popped. + * \return On false, the queue is exiting. + * + * The element will be copied or moved into the object passed in. + */ + bool Pop(T* rv); + /*! + * \brief Signal the queue for destruction. + * + * After calling this method, all blocking pop call to the queue will return + * false. + */ + void SignalForKill(); + /*! + * \brief Get the size of the queue. + * \return The size of the queue. + */ + size_t Size(); + + private: + struct Entry { + T data; + int priority; + inline bool operator<(const Entry &b) const { + return priority < b.priority; + } + }; + + std::mutex mutex_; + std::condition_variable cv_; + std::atomic exit_now_; + int nwait_consumer_; + // a priority queue + std::vector priority_queue_; + // a FIFO queue + std::deque fifo_queue_; + /*! + * \brief Disable copy and move. + */ + DISALLOW_COPY_AND_ASSIGN(ConcurrentBlockingQueue); +}; + +inline void Spinlock::lock() noexcept(true) { + while (lock_.test_and_set(std::memory_order_acquire)) { + } +} + +inline void Spinlock::unlock() noexcept(true) { + lock_.clear(std::memory_order_release); +} + +template +ConcurrentBlockingQueue::ConcurrentBlockingQueue() + : exit_now_{false}, nwait_consumer_{0} {} + +template +template +void ConcurrentBlockingQueue::Push(E&& e, int priority) { + static_assert(std::is_same::type>::type, + T>::value, + "Types must match."); + bool notify; + { + std::lock_guard lock{mutex_}; + if (type == ConcurrentQueueType::kFIFO) { + fifo_queue_.emplace_back(std::forward(e)); + notify = nwait_consumer_ != 0; + } else { + Entry entry; + entry.data = std::move(e); + entry.priority = priority; + priority_queue_.push_back(std::move(entry)); + std::push_heap(priority_queue_.begin(), priority_queue_.end()); + notify = nwait_consumer_ != 0; + } + } + if (notify) cv_.notify_one(); +} + +template +template +void ConcurrentBlockingQueue::PushFront(E&& e, int priority) { + static_assert(std::is_same::type>::type, + T>::value, + "Types must match."); + bool notify; + { + std::lock_guard lock{mutex_}; + if (type == ConcurrentQueueType::kFIFO) { + fifo_queue_.emplace_front(std::forward(e)); + notify = nwait_consumer_ != 0; + } else { + Entry entry; + entry.data = std::move(e); + entry.priority = priority; + priority_queue_.push_back(std::move(entry)); + std::push_heap(priority_queue_.begin(), priority_queue_.end()); + notify = nwait_consumer_ != 0; + } + } + if (notify) cv_.notify_one(); +} + +template +bool ConcurrentBlockingQueue::Pop(T* rv) { + std::unique_lock lock{mutex_}; + if (type == ConcurrentQueueType::kFIFO) { + ++nwait_consumer_; + cv_.wait(lock, [this] { + return !fifo_queue_.empty() || exit_now_.load(); + }); + --nwait_consumer_; + if (!exit_now_.load()) { + *rv = std::move(fifo_queue_.front()); + fifo_queue_.pop_front(); + return true; + } else { + return false; + } + } else { + ++nwait_consumer_; + cv_.wait(lock, [this] { + return !priority_queue_.empty() || exit_now_.load(); + }); + --nwait_consumer_; + if (!exit_now_.load()) { + std::pop_heap(priority_queue_.begin(), priority_queue_.end()); + *rv = std::move(priority_queue_.back().data); + priority_queue_.pop_back(); + return true; + } else { + return false; + } + } +} + +template +void ConcurrentBlockingQueue::SignalForKill() { + { + std::lock_guard lock{mutex_}; + exit_now_.store(true); + } + cv_.notify_all(); +} + +template +size_t ConcurrentBlockingQueue::Size() { + std::lock_guard lock{mutex_}; + if (type == ConcurrentQueueType::kFIFO) { + return fifo_queue_.size(); + } else { + return priority_queue_.size(); + } +} +} // namespace dmlc +#endif // DMLC_USE_CXX11 +#endif // DMLC_CONCURRENCY_H_ diff --git a/include/dmlc/concurrentqueue.h b/include/dmlc/concurrentqueue.h new file mode 100644 index 000000000000..f9b7d1147dc5 --- /dev/null +++ b/include/dmlc/concurrentqueue.h @@ -0,0 +1,3719 @@ +//! \cond Doxygen_Suppress +// Provides a C++11 implementation of a multi-producer, multi-consumer lock-free queue. +// An overview, including benchmark results, is provided here: +// http://moodycamel.com/blog/2014/a-fast-general-purpose-lock-free-queue-for-c++ +// The full design is also described in excruciating detail at: +// http://moodycamel.com/blog/2014/detailed-design-of-a-lock-free-queue + +// Simplified BSD license: +// Copyright (c) 2013-2016, Cameron Desrochers. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// - Redistributions of source code must retain the above copyright notice, this list of +// conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, this list of +// conditions and the following disclaimer in the documentation and/or other materials +// provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +// OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +// TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +// EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#ifndef DMLC_CONCURRENTQUEUE_H_ +#define DMLC_CONCURRENTQUEUE_H_ +#pragma once + +#if defined(__GNUC__) +// Disable -Wconversion warnings (spuriously triggered when Traits::size_t and +// Traits::index_t are set to < 32 bits, causing integer promotion, causing warnings +// upon assigning any computed values) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" + +#ifdef MCDBGQ_USE_RELACY +#pragma GCC diagnostic ignored "-Wint-to-pointer-cast" +#endif +#endif + +#if defined(_WIN32) || defined(__WINDOWS__) || defined(__WIN32__) || defined(_WIN64) +#include // for GetCurrentThreadId() +#endif + +#if defined(__APPLE__) +#include "TargetConditionals.h" +#endif + +#ifdef MCDBGQ_USE_RELACY +#include "relacy/relacy_std.hpp" +#include "relacy_shims.h" +// We only use malloc/free anyway, and the delete macro messes up `= delete` method declarations. +// We'll override the default trait malloc ourselves without a macro. +#undef new +#undef delete +#undef malloc +#undef free +#else +#include // Requires C++11. Sorry VS2010. +#include +#endif +#include // for max_align_t +#include +#include +#include +#include +#include +#include +#include // for CHAR_BIT +#include +#include // partly for __WINPTHREADS_VERSION if on MinGW-w64 w/ POSIX threading + +namespace dmlc { + +// Platform-specific definitions of a numeric thread ID type and an invalid value +namespace moodycamel { namespace details { +template struct thread_id_converter { + typedef thread_id_t thread_id_numeric_size_t; + typedef thread_id_t thread_id_hash_t; + static thread_id_hash_t prehash(thread_id_t const& x) { return x; } +}; +} } +#if defined(MCDBGQ_USE_RELACY) +namespace moodycamel { namespace details { + typedef std::uint32_t thread_id_t; + static const thread_id_t invalid_thread_id = 0xFFFFFFFFU; + static const thread_id_t invalid_thread_id2 = 0xFFFFFFFEU; + static inline thread_id_t thread_id() { return rl::thread_index(); } +} } +#elif defined(_WIN32) || defined(__WINDOWS__) || defined(__WIN32__) +// No sense pulling in windows.h in a header, we'll manually declare the function +// we use and rely on backwards-compatibility for this not to break +extern "C" __declspec(dllimport) unsigned long __stdcall GetCurrentThreadId(void); +namespace moodycamel { namespace details { + static_assert(sizeof(unsigned long) == sizeof(std::uint32_t), "Expected size of unsigned long to be 32 bits on Windows"); + typedef std::uint32_t thread_id_t; + static const thread_id_t invalid_thread_id = 0; // See http://blogs.msdn.com/b/oldnewthing/archive/2004/02/23/78395.aspx + static const thread_id_t invalid_thread_id2 = 0xFFFFFFFFU; // Not technically guaranteed to be invalid, but is never used in practice. Note that all Win32 thread IDs are presently multiples of 4. + static inline thread_id_t thread_id() { return static_cast(::GetCurrentThreadId()); } +} } +#elif defined(__arm__) || defined(_M_ARM) || defined(__aarch64__) || (defined(__APPLE__) && TARGET_OS_IPHONE) +namespace moodycamel { namespace details { + static_assert(sizeof(std::thread::id) == 4 || sizeof(std::thread::id) == 8, "std::thread::id is expected to be either 4 or 8 bytes"); + + typedef std::thread::id thread_id_t; + static const thread_id_t invalid_thread_id; // Default ctor creates invalid ID + + // Note we don't define a invalid_thread_id2 since std::thread::id doesn't have one; it's + // only used if MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED is defined anyway, which it won't + // be. + static inline thread_id_t thread_id() { return std::this_thread::get_id(); } + + template struct thread_id_size { }; + template<> struct thread_id_size<4> { typedef std::uint32_t numeric_t; }; + template<> struct thread_id_size<8> { typedef std::uint64_t numeric_t; }; + + template<> struct thread_id_converter { + typedef thread_id_size::numeric_t thread_id_numeric_size_t; +#ifndef __APPLE__ + typedef std::size_t thread_id_hash_t; +#else + typedef thread_id_numeric_size_t thread_id_hash_t; +#endif + + static thread_id_hash_t prehash(thread_id_t const& x) + { +#ifndef __APPLE__ + return std::hash()(x); +#else + return *reinterpret_cast(&x); +#endif + } + }; +} } +#else +// Use a nice trick from this answer: http://stackoverflow.com/a/8438730/21475 +// In order to get a numeric thread ID in a platform-independent way, we use a thread-local +// static variable's address as a thread identifier :-) +#if defined(__GNUC__) || defined(__INTEL_COMPILER) +#define MOODYCAMEL_THREADLOCAL __thread +#elif defined(_MSC_VER) +#define MOODYCAMEL_THREADLOCAL __declspec(thread) +#else +// Assume C++11 compliant compiler +#define MOODYCAMEL_THREADLOCAL thread_local +#endif +namespace moodycamel { namespace details { +typedef std::uintptr_t thread_id_t; +static const thread_id_t invalid_thread_id = 0; // Address can't be nullptr +static const thread_id_t invalid_thread_id2 = 1; // Member accesses off a null pointer are also generally invalid. Plus it's not aligned. +static inline thread_id_t thread_id() { static MOODYCAMEL_THREADLOCAL int x; return reinterpret_cast(&x); } +} } +#endif + +// Exceptions +#ifndef MOODYCAMEL_EXCEPTIONS_ENABLED +#if (defined(_MSC_VER) && defined(_CPPUNWIND)) || (defined(__GNUC__) && defined(__EXCEPTIONS)) || (!defined(_MSC_VER) && !defined(__GNUC__)) +#define MOODYCAMEL_EXCEPTIONS_ENABLED +#endif +#endif +#ifdef MOODYCAMEL_EXCEPTIONS_ENABLED +#define MOODYCAMEL_TRY try +#define MOODYCAMEL_CATCH(...) catch(__VA_ARGS__) +#define MOODYCAMEL_RETHROW throw +#define MOODYCAMEL_THROW(expr) throw (expr) +#else +#define MOODYCAMEL_TRY if (true) +#define MOODYCAMEL_CATCH(...) else if (false) +#define MOODYCAMEL_RETHROW +#define MOODYCAMEL_THROW(expr) +#endif + +#ifndef MOODYCAMEL_NOEXCEPT +#if !defined(MOODYCAMEL_EXCEPTIONS_ENABLED) +#define MOODYCAMEL_NOEXCEPT +#define MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr) true +#define MOODYCAMEL_NOEXCEPT_ASSIGN(type, valueType, expr) true +#elif defined(_MSC_VER) && defined(_NOEXCEPT) && _MSC_VER < 1800 +// VS2012's std::is_nothrow_[move_]constructible is broken and returns true when it shouldn't :-( +// We have to assume *all* non-trivial constructors may throw on VS2012! +#define MOODYCAMEL_NOEXCEPT _NOEXCEPT +#define MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr) (std::is_rvalue_reference::value && std::is_move_constructible::value ? std::is_trivially_move_constructible::value : std::is_trivially_copy_constructible::value) +#define MOODYCAMEL_NOEXCEPT_ASSIGN(type, valueType, expr) ((std::is_rvalue_reference::value && std::is_move_assignable::value ? std::is_trivially_move_assignable::value || std::is_nothrow_move_assignable::value : std::is_trivially_copy_assignable::value || std::is_nothrow_copy_assignable::value) && MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr)) +#elif defined(_MSC_VER) && defined(_NOEXCEPT) && _MSC_VER < 1900 +#define MOODYCAMEL_NOEXCEPT _NOEXCEPT +#define MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr) (std::is_rvalue_reference::value && std::is_move_constructible::value ? std::is_trivially_move_constructible::value || std::is_nothrow_move_constructible::value : std::is_trivially_copy_constructible::value || std::is_nothrow_copy_constructible::value) +#define MOODYCAMEL_NOEXCEPT_ASSIGN(type, valueType, expr) ((std::is_rvalue_reference::value && std::is_move_assignable::value ? std::is_trivially_move_assignable::value || std::is_nothrow_move_assignable::value : std::is_trivially_copy_assignable::value || std::is_nothrow_copy_assignable::value) && MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr)) +#else +#define MOODYCAMEL_NOEXCEPT noexcept +#define MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr) noexcept(expr) +#define MOODYCAMEL_NOEXCEPT_ASSIGN(type, valueType, expr) noexcept(expr) +#endif +#endif + +#ifndef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED +#ifdef MCDBGQ_USE_RELACY +#define MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED +#else +// VS2013 doesn't support `thread_local`, and MinGW-w64 w/ POSIX threading has a crippling bug: http://sourceforge.net/p/mingw-w64/bugs/445 +// g++ <=4.7 doesn't support thread_local either. +// Finally, iOS/ARM doesn't have support for it either, and g++/ARM allows it to compile but it's unconfirmed to actually work +#if (!defined(_MSC_VER) || _MSC_VER >= 1900) && (!defined(__MINGW32__) && !defined(__MINGW64__) || !defined(__WINPTHREADS_VERSION)) && (!defined(__GNUC__) || __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8)) && (!defined(__APPLE__) || !TARGET_OS_IPHONE) && !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) +// Assume `thread_local` is fully supported in all other C++11 compilers/platforms +//#define MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED // always disabled for now since several users report having problems with it on +#endif +#endif +#endif + +// VS2012 doesn't support deleted functions. +// In this case, we declare the function normally but don't define it. A link error will be generated if the function is called. +#ifndef MOODYCAMEL_DELETE_FUNCTION +#if defined(_MSC_VER) && _MSC_VER < 1800 +#define MOODYCAMEL_DELETE_FUNCTION +#else +#define MOODYCAMEL_DELETE_FUNCTION = delete +#endif +#endif + +// Compiler-specific likely/unlikely hints +namespace moodycamel { namespace details { +#if defined(__GNUC__) +inline bool likely(bool x) { return __builtin_expect((x), true); } +inline bool unlikely(bool x) { return __builtin_expect((x), false); } +#else +inline bool likely(bool x) { return x; } + inline bool unlikely(bool x) { return x; } +#endif +} } + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG +#include "internal/concurrentqueue_internal_debug.h" +#endif + +namespace moodycamel { +namespace details { +template +struct const_numeric_max { + static_assert(std::is_integral::value, "const_numeric_max can only be used with integers"); + static const T value = std::numeric_limits::is_signed + ? (static_cast(1) << (sizeof(T) * CHAR_BIT - 1)) - static_cast(1) + : static_cast(-1); +}; + +#if defined(__GLIBCXX__) +typedef ::max_align_t std_max_align_t; // libstdc++ forgot to add it to std:: for a while +#else +typedef std::max_align_t std_max_align_t; // Others (e.g. MSVC) insist it can *only* be accessed via std:: +#endif + +// Some platforms have incorrectly set max_align_t to a type with <8 bytes alignment even while supporting +// 8-byte aligned scalar values (*cough* 32-bit iOS). Work around this with our own union. See issue #64. +typedef union { + std_max_align_t x; + long long y; + void* z; +} max_align_t; +} + +// Default traits for the ConcurrentQueue. To change some of the +// traits without re-implementing all of them, inherit from this +// struct and shadow the declarations you wish to be different; +// since the traits are used as a template type parameter, the +// shadowed declarations will be used where defined, and the defaults +// otherwise. +struct ConcurrentQueueDefaultTraits +{ + // General-purpose size type. std::size_t is strongly recommended. + typedef std::size_t size_t; + + // The type used for the enqueue and dequeue indices. Must be at least as + // large as size_t. Should be significantly larger than the number of elements + // you expect to hold at once, especially if you have a high turnover rate; + // for example, on 32-bit x86, if you expect to have over a hundred million + // elements or pump several million elements through your queue in a very + // short space of time, using a 32-bit type *may* trigger a race condition. + // A 64-bit int type is recommended in that case, and in practice will + // prevent a race condition no matter the usage of the queue. Note that + // whether the queue is lock-free with a 64-int type depends on the whether + // std::atomic is lock-free, which is platform-specific. + typedef std::size_t index_t; + + // Internally, all elements are enqueued and dequeued from multi-element + // blocks; this is the smallest controllable unit. If you expect few elements + // but many producers, a smaller block size should be favoured. For few producers + // and/or many elements, a larger block size is preferred. A sane default + // is provided. Must be a power of 2. + static const size_t BLOCK_SIZE = 32; + + // For explicit producers (i.e. when using a producer token), the block is + // checked for being empty by iterating through a list of flags, one per element. + // For large block sizes, this is too inefficient, and switching to an atomic + // counter-based approach is faster. The switch is made for block sizes strictly + // larger than this threshold. + static const size_t EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD = 32; + + // How many full blocks can be expected for a single explicit producer? This should + // reflect that number's maximum for optimal performance. Must be a power of 2. + static const size_t EXPLICIT_INITIAL_INDEX_SIZE = 32; + + // How many full blocks can be expected for a single implicit producer? This should + // reflect that number's maximum for optimal performance. Must be a power of 2. + static const size_t IMPLICIT_INITIAL_INDEX_SIZE = 32; + + // The initial size of the hash table mapping thread IDs to implicit producers. + // Note that the hash is resized every time it becomes half full. + // Must be a power of two, and either 0 or at least 1. If 0, implicit production + // (using the enqueue methods without an explicit producer token) is disabled. + static const size_t INITIAL_IMPLICIT_PRODUCER_HASH_SIZE = 32; + + // Controls the number of items that an explicit consumer (i.e. one with a token) + // must consume before it causes all consumers to rotate and move on to the next + // internal queue. + static const std::uint32_t EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE = 256; + + // The maximum number of elements (inclusive) that can be enqueued to a sub-queue. + // Enqueue operations that would cause this limit to be surpassed will fail. Note + // that this limit is enforced at the block level (for performance reasons), i.e. + // it's rounded up to the nearest block size. + static const size_t MAX_SUBQUEUE_SIZE = details::const_numeric_max::value; + + +#ifndef MCDBGQ_USE_RELACY + // Memory allocation can be customized if needed. + // malloc should return nullptr on failure, and handle alignment like std::malloc. +#if defined(malloc) || defined(free) + // Gah, this is 2015, stop defining macros that break standard code already! + // Work around malloc/free being special macros: + static inline void* WORKAROUND_malloc(size_t size) { return malloc(size); } + static inline void WORKAROUND_free(void* ptr) { return free(ptr); } + static inline void* (malloc)(size_t size) { return WORKAROUND_malloc(size); } + static inline void (free)(void* ptr) { return WORKAROUND_free(ptr); } +#else + static inline void* malloc(size_t size) { return std::malloc(size); } + static inline void free(void* ptr) { return std::free(ptr); } +#endif +#else + // Debug versions when running under the Relacy race detector (ignore + // these in user code) + static inline void* malloc(size_t size) { return rl::rl_malloc(size, $); } + static inline void free(void* ptr) { return rl::rl_free(ptr, $); } +#endif +}; + + +// When producing or consuming many elements, the most efficient way is to: +// 1) Use one of the bulk-operation methods of the queue with a token +// 2) Failing that, use the bulk-operation methods without a token +// 3) Failing that, create a token and use that with the single-item methods +// 4) Failing that, use the single-parameter methods of the queue +// Having said that, don't create tokens willy-nilly -- ideally there should be +// a maximum of one token per thread (of each kind). +struct ProducerToken; +struct ConsumerToken; + +template class ConcurrentQueue; +template class BlockingConcurrentQueue; +class ConcurrentQueueTests; + + +namespace details +{ +struct ConcurrentQueueProducerTypelessBase +{ + ConcurrentQueueProducerTypelessBase* next; + std::atomic inactive; + ProducerToken* token; + + ConcurrentQueueProducerTypelessBase() + : next(nullptr), inactive(false), token(nullptr) + { + } +}; + +template struct _hash_32_or_64 { + static inline std::uint32_t hash(std::uint32_t h) + { + // MurmurHash3 finalizer -- see https://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp + // Since the thread ID is already unique, all we really want to do is propagate that + // uniqueness evenly across all the bits, so that we can use a subset of the bits while + // reducing collisions significantly + h ^= h >> 16; + h *= 0x85ebca6b; + h ^= h >> 13; + h *= 0xc2b2ae35; + return h ^ (h >> 16); + } +}; +template<> struct _hash_32_or_64<1> { + static inline std::uint64_t hash(std::uint64_t h) + { + h ^= h >> 33; + h *= 0xff51afd7ed558ccd; + h ^= h >> 33; + h *= 0xc4ceb9fe1a85ec53; + return h ^ (h >> 33); + } +}; +template struct hash_32_or_64 : public _hash_32_or_64<(size > 4)> { }; + +static inline size_t hash_thread_id(thread_id_t id) +{ + static_assert(sizeof(thread_id_t) <= 8, "Expected a platform where thread IDs are at most 64-bit values"); + return static_cast(hash_32_or_64::thread_id_hash_t)>::hash( + thread_id_converter::prehash(id))); +} + +template +static inline bool circular_less_than(T a, T b) +{ +#ifdef _MSC_VER + #pragma warning(push) +#pragma warning(disable: 4554) +#endif + static_assert(std::is_integral::value && !std::numeric_limits::is_signed, "circular_less_than is intended to be used only with unsigned integer types"); + return static_cast(a - b) > static_cast(static_cast(1) << static_cast(sizeof(T) * CHAR_BIT - 1)); +#ifdef _MSC_VER +#pragma warning(pop) +#endif +} + +template +static inline char* align_for(char* ptr) +{ + const std::size_t alignment = std::alignment_of::value; + return ptr + (alignment - (reinterpret_cast(ptr) % alignment)) % alignment; +} + +template +static inline T ceil_to_pow_2(T x) +{ + static_assert(std::is_integral::value && !std::numeric_limits::is_signed, "ceil_to_pow_2 is intended to be used only with unsigned integer types"); + + // Adapted from http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 + --x; + x |= x >> 1; + x |= x >> 2; + x |= x >> 4; + for (std::size_t i = 1; i < sizeof(T); i <<= 1) { + x |= x >> (i << 3); + } + ++x; + return x; +} + +template +static inline void swap_relaxed(std::atomic& left, std::atomic& right) +{ + T temp = std::move(left.load(std::memory_order_relaxed)); + left.store(std::move(right.load(std::memory_order_relaxed)), std::memory_order_relaxed); + right.store(std::move(temp), std::memory_order_relaxed); +} + +template +static inline T const& nomove(T const& x) +{ + return x; +} + +template +struct nomove_if +{ + template + static inline T const& eval(T const& x) + { + return x; + } +}; + +template<> +struct nomove_if +{ + template + static inline auto eval(U&& x) + -> decltype(std::forward(x)) + { + return std::forward(x); + } +}; + +template +static inline auto deref_noexcept(It& it) MOODYCAMEL_NOEXCEPT -> decltype(*it) +{ + return *it; +} + +#if defined(__clang__) || !defined(__GNUC__) || __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8) +template struct is_trivially_destructible : std::is_trivially_destructible { }; +#else +template struct is_trivially_destructible : std::has_trivial_destructor { }; +#endif + +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED +#ifdef MCDBGQ_USE_RELACY + typedef RelacyThreadExitListener ThreadExitListener; + typedef RelacyThreadExitNotifier ThreadExitNotifier; +#else + struct ThreadExitListener + { + typedef void (*callback_t)(void*); + callback_t callback; + void* userData; + + ThreadExitListener* next; // reserved for use by the ThreadExitNotifier + }; + + + class ThreadExitNotifier + { + public: + static void subscribe(ThreadExitListener* listener) + { + auto& tlsInst = instance(); + listener->next = tlsInst.tail; + tlsInst.tail = listener; + } + + static void unsubscribe(ThreadExitListener* listener) + { + auto& tlsInst = instance(); + ThreadExitListener** prev = &tlsInst.tail; + for (auto ptr = tlsInst.tail; ptr != nullptr; ptr = ptr->next) { + if (ptr == listener) { + *prev = ptr->next; + break; + } + prev = &ptr->next; + } + } + + private: + ThreadExitNotifier() : tail(nullptr) { } + ThreadExitNotifier(ThreadExitNotifier const&) MOODYCAMEL_DELETE_FUNCTION; + ThreadExitNotifier& operator=(ThreadExitNotifier const&) MOODYCAMEL_DELETE_FUNCTION; + + ~ThreadExitNotifier() + { + // This thread is about to exit, let everyone know! + assert(this == &instance() && "If this assert fails, you likely have a buggy compiler! Change the preprocessor conditions such that MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED is no longer defined."); + for (auto ptr = tail; ptr != nullptr; ptr = ptr->next) { + ptr->callback(ptr->userData); + } + } + + // Thread-local + static inline ThreadExitNotifier& instance() + { + static thread_local ThreadExitNotifier notifier; + return notifier; + } + + private: + ThreadExitListener* tail; + }; +#endif +#endif + +template struct static_is_lock_free_num { enum { value = 0 }; }; +template<> struct static_is_lock_free_num { enum { value = ATOMIC_CHAR_LOCK_FREE }; }; +template<> struct static_is_lock_free_num { enum { value = ATOMIC_SHORT_LOCK_FREE }; }; +template<> struct static_is_lock_free_num { enum { value = ATOMIC_INT_LOCK_FREE }; }; +template<> struct static_is_lock_free_num { enum { value = ATOMIC_LONG_LOCK_FREE }; }; +template<> struct static_is_lock_free_num { enum { value = ATOMIC_LLONG_LOCK_FREE }; }; +template struct static_is_lock_free : static_is_lock_free_num::type> { }; +template<> struct static_is_lock_free { enum { value = ATOMIC_BOOL_LOCK_FREE }; }; +template struct static_is_lock_free { enum { value = ATOMIC_POINTER_LOCK_FREE }; }; +} + + +struct ProducerToken +{ + template + explicit ProducerToken(ConcurrentQueue& queue); + + template + explicit ProducerToken(BlockingConcurrentQueue& queue); + + ProducerToken(ProducerToken&& other) MOODYCAMEL_NOEXCEPT + : producer(other.producer) + { + other.producer = nullptr; + if (producer != nullptr) { + producer->token = this; + } + } + + inline ProducerToken& operator=(ProducerToken&& other) MOODYCAMEL_NOEXCEPT + { + swap(other); + return *this; + } + + void swap(ProducerToken& other) MOODYCAMEL_NOEXCEPT + { + std::swap(producer, other.producer); + if (producer != nullptr) { + producer->token = this; + } + if (other.producer != nullptr) { + other.producer->token = &other; + } + } + + // A token is always valid unless: + // 1) Memory allocation failed during construction + // 2) It was moved via the move constructor + // (Note: assignment does a swap, leaving both potentially valid) + // 3) The associated queue was destroyed + // Note that if valid() returns true, that only indicates + // that the token is valid for use with a specific queue, + // but not which one; that's up to the user to track. + inline bool valid() const { return producer != nullptr; } + + ~ProducerToken() + { + if (producer != nullptr) { + producer->token = nullptr; + producer->inactive.store(true, std::memory_order_release); + } + } + + // Disable copying and assignment + ProducerToken(ProducerToken const&) MOODYCAMEL_DELETE_FUNCTION; + ProducerToken& operator=(ProducerToken const&) MOODYCAMEL_DELETE_FUNCTION; + + private: + template friend class ConcurrentQueue; + friend class ConcurrentQueueTests; + + protected: + details::ConcurrentQueueProducerTypelessBase* producer; +}; + + +struct ConsumerToken +{ + template + explicit ConsumerToken(ConcurrentQueue& q); + + template + explicit ConsumerToken(BlockingConcurrentQueue& q); + + ConsumerToken(ConsumerToken&& other) MOODYCAMEL_NOEXCEPT + : initialOffset(other.initialOffset), lastKnownGlobalOffset(other.lastKnownGlobalOffset), itemsConsumedFromCurrent(other.itemsConsumedFromCurrent), currentProducer(other.currentProducer), desiredProducer(other.desiredProducer) + { + } + + inline ConsumerToken& operator=(ConsumerToken&& other) MOODYCAMEL_NOEXCEPT + { + swap(other); + return *this; + } + + void swap(ConsumerToken& other) MOODYCAMEL_NOEXCEPT + { + std::swap(initialOffset, other.initialOffset); + std::swap(lastKnownGlobalOffset, other.lastKnownGlobalOffset); + std::swap(itemsConsumedFromCurrent, other.itemsConsumedFromCurrent); + std::swap(currentProducer, other.currentProducer); + std::swap(desiredProducer, other.desiredProducer); + } + + // Disable copying and assignment + ConsumerToken(ConsumerToken const&) MOODYCAMEL_DELETE_FUNCTION; + ConsumerToken& operator=(ConsumerToken const&) MOODYCAMEL_DELETE_FUNCTION; + + private: + template friend class ConcurrentQueue; + friend class ConcurrentQueueTests; + + private: // but shared with ConcurrentQueue + std::uint32_t initialOffset; + std::uint32_t lastKnownGlobalOffset; + std::uint32_t itemsConsumedFromCurrent; + details::ConcurrentQueueProducerTypelessBase* currentProducer; + details::ConcurrentQueueProducerTypelessBase* desiredProducer; +}; + +// Need to forward-declare this swap because it's in a namespace. +// See http://stackoverflow.com/questions/4492062/why-does-a-c-friend-class-need-a-forward-declaration-only-in-other-namespaces +template +inline void swap(typename ConcurrentQueue::ImplicitProducerKVP& a, typename ConcurrentQueue::ImplicitProducerKVP& b) MOODYCAMEL_NOEXCEPT; + + +template +class ConcurrentQueue { + public: + typedef ::dmlc::moodycamel::ProducerToken producer_token_t; + typedef ::dmlc::moodycamel::ConsumerToken consumer_token_t; + + typedef typename Traits::index_t index_t; + typedef typename Traits::size_t size_t; + + static const size_t BLOCK_SIZE = static_cast(Traits::BLOCK_SIZE); + static const size_t EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD = static_cast(Traits::EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD); + static const size_t EXPLICIT_INITIAL_INDEX_SIZE = static_cast(Traits::EXPLICIT_INITIAL_INDEX_SIZE); + static const size_t IMPLICIT_INITIAL_INDEX_SIZE = static_cast(Traits::IMPLICIT_INITIAL_INDEX_SIZE); + static const size_t INITIAL_IMPLICIT_PRODUCER_HASH_SIZE = static_cast(Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE); + static const std::uint32_t EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE = static_cast(Traits::EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE); +#ifdef _MSC_VER + #pragma warning(push) +#pragma warning(disable: 4307) // + integral constant overflow (that's what the ternary expression is for!) +#pragma warning(disable: 4309) // static_cast: Truncation of constant value +#endif + static const size_t MAX_SUBQUEUE_SIZE = (details::const_numeric_max::value - + static_cast(Traits::MAX_SUBQUEUE_SIZE) < + BLOCK_SIZE) ? details::const_numeric_max::value + : ( + (static_cast(Traits::MAX_SUBQUEUE_SIZE) + + (BLOCK_SIZE - 1)) / BLOCK_SIZE * BLOCK_SIZE); +#ifdef _MSC_VER +#pragma warning(pop) +#endif + + static_assert(!std::numeric_limits::is_signed && std::is_integral::value, + "Traits::size_t must be an unsigned integral type"); + static_assert(!std::numeric_limits::is_signed && std::is_integral::value, + "Traits::index_t must be an unsigned integral type"); + static_assert(sizeof(index_t) >= sizeof(size_t), + "Traits::index_t must be at least as wide as Traits::size_t"); + static_assert((BLOCK_SIZE > 1) && !(BLOCK_SIZE & (BLOCK_SIZE - 1)), + "Traits::BLOCK_SIZE must be a power of 2 (and at least 2)"); + static_assert((EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD > 1) && + !(EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD & + (EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD - 1)), + "Traits::EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD must be a power of 2 (and greater than 1)"); + static_assert((EXPLICIT_INITIAL_INDEX_SIZE > 1) && + !(EXPLICIT_INITIAL_INDEX_SIZE & (EXPLICIT_INITIAL_INDEX_SIZE - 1)), + "Traits::EXPLICIT_INITIAL_INDEX_SIZE must be a power of 2 (and greater than 1)"); + static_assert((IMPLICIT_INITIAL_INDEX_SIZE > 1) && + !(IMPLICIT_INITIAL_INDEX_SIZE & (IMPLICIT_INITIAL_INDEX_SIZE - 1)), + "Traits::IMPLICIT_INITIAL_INDEX_SIZE must be a power of 2 (and greater than 1)"); + static_assert((INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) || + !(INITIAL_IMPLICIT_PRODUCER_HASH_SIZE & (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE - 1)), + "Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE must be a power of 2"); + static_assert( + INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0 || INITIAL_IMPLICIT_PRODUCER_HASH_SIZE >= 1, + "Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE must be at least 1 (or 0 to disable implicit enqueueing)"); + + public: + // Creates a queue with at least `capacity` element slots; note that the + // actual number of elements that can be inserted without additional memory + // allocation depends on the number of producers and the block size (e.g. if + // the block size is equal to `capacity`, only a single block will be allocated + // up-front, which means only a single producer will be able to enqueue elements + // without an extra allocation -- blocks aren't shared between producers). + // This method is not thread safe -- it is up to the user to ensure that the + // queue is fully constructed before it starts being used by other threads (this + // includes making the memory effects of construction visible, possibly with a + // memory barrier). + explicit ConcurrentQueue(size_t capacity = 6 * BLOCK_SIZE) + : producerListTail(nullptr), producerCount(0), initialBlockPoolIndex(0), nextExplicitConsumerId( + 0), globalExplicitConsumerOffset(0) { + implicitProducerHashResizeInProgress.clear(std::memory_order_relaxed); + populate_initial_implicit_producer_hash(); + populate_initial_block_list( + capacity / BLOCK_SIZE + ((capacity & (BLOCK_SIZE - 1)) == 0 ? 0 : 1)); + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + // Track all the producers using a fully-resolved typed list for + // each kind; this makes it possible to debug them starting from + // the root queue object (otherwise wacky casts are needed that + // don't compile in the debugger's expression evaluator). + explicitProducers.store(nullptr, std::memory_order_relaxed); + implicitProducers.store(nullptr, std::memory_order_relaxed); +#endif + } + + // Computes the correct amount of pre-allocated blocks for you based + // on the minimum number of elements you want available at any given + // time, and the maximum concurrent number of each type of producer. + ConcurrentQueue(size_t minCapacity, size_t maxExplicitProducers, size_t maxImplicitProducers) + : producerListTail(nullptr), producerCount(0), initialBlockPoolIndex(0), nextExplicitConsumerId( + 0), globalExplicitConsumerOffset(0) { + implicitProducerHashResizeInProgress.clear(std::memory_order_relaxed); + populate_initial_implicit_producer_hash(); + size_t blocks = + (((minCapacity + BLOCK_SIZE - 1) / BLOCK_SIZE) - 1) * (maxExplicitProducers + 1) + + 2 * (maxExplicitProducers + maxImplicitProducers); + populate_initial_block_list(blocks); + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + explicitProducers.store(nullptr, std::memory_order_relaxed); + implicitProducers.store(nullptr, std::memory_order_relaxed); +#endif + } + + // Note: The queue should not be accessed concurrently while it's + // being deleted. It's up to the user to synchronize this. + // This method is not thread safe. + ~ConcurrentQueue() { + // Destroy producers + auto ptr = producerListTail.load(std::memory_order_relaxed); + while (ptr != nullptr) { + auto next = ptr->next_prod(); + if (ptr->token != nullptr) { + ptr->token->producer = nullptr; + } + destroy(ptr); + ptr = next; + } + + // Destroy implicit producer hash tables + if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE != 0) { + auto hash = implicitProducerHash.load(std::memory_order_relaxed); + while (hash != nullptr) { + auto prev = hash->prev; + if (prev != + nullptr) { // The last hash is part of this object and was not allocated dynamically + for (size_t i = 0; i != hash->capacity; ++i) { + hash->entries[i].~ImplicitProducerKVP(); + } + hash->~ImplicitProducerHash(); + (Traits::free)(hash); + } + hash = prev; + } + } + + // Destroy global free list + auto block = freeList.head_unsafe(); + while (block != nullptr) { + auto next = block->freeListNext.load(std::memory_order_relaxed); + if (block->dynamicallyAllocated) { + destroy(block); + } + block = next; + } + + // Destroy initial free list + destroy_array(initialBlockPool, initialBlockPoolSize); + } + + // Disable copying and copy assignment + ConcurrentQueue(ConcurrentQueue const &) MOODYCAMEL_DELETE_FUNCTION; + + ConcurrentQueue &operator=(ConcurrentQueue const &) MOODYCAMEL_DELETE_FUNCTION; + + // Moving is supported, but note that it is *not* a thread-safe operation. + // Nobody can use the queue while it's being moved, and the memory effects + // of that move must be propagated to other threads before they can use it. + // Note: When a queue is moved, its tokens are still valid but can only be + // used with the destination queue (i.e. semantically they are moved along + // with the queue itself). + ConcurrentQueue(ConcurrentQueue &&other) MOODYCAMEL_NOEXCEPT + : producerListTail(other.producerListTail.load(std::memory_order_relaxed)), producerCount( + other.producerCount.load(std::memory_order_relaxed)), initialBlockPoolIndex( + other.initialBlockPoolIndex.load(std::memory_order_relaxed)), initialBlockPool( + other.initialBlockPool), initialBlockPoolSize(other.initialBlockPoolSize), freeList( + std::move(other.freeList)), nextExplicitConsumerId( + other.nextExplicitConsumerId.load(std::memory_order_relaxed)), globalExplicitConsumerOffset( + other.globalExplicitConsumerOffset.load(std::memory_order_relaxed)) { + // Move the other one into this, and leave the other one as an empty queue + implicitProducerHashResizeInProgress.clear(std::memory_order_relaxed); + populate_initial_implicit_producer_hash(); + swap_implicit_producer_hashes(other); + + other.producerListTail.store(nullptr, std::memory_order_relaxed); + other.producerCount.store(0, std::memory_order_relaxed); + other.nextExplicitConsumerId.store(0, std::memory_order_relaxed); + other.globalExplicitConsumerOffset.store(0, std::memory_order_relaxed); + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + explicitProducers.store(other.explicitProducers.load(std::memory_order_relaxed), std::memory_order_relaxed); + other.explicitProducers.store(nullptr, std::memory_order_relaxed); + implicitProducers.store(other.implicitProducers.load(std::memory_order_relaxed), std::memory_order_relaxed); + other.implicitProducers.store(nullptr, std::memory_order_relaxed); +#endif + + other.initialBlockPoolIndex.store(0, std::memory_order_relaxed); + other.initialBlockPoolSize = 0; + other.initialBlockPool = nullptr; + + reown_producers(); + } + + inline ConcurrentQueue &operator=(ConcurrentQueue &&other) MOODYCAMEL_NOEXCEPT { + return swap_internal(other); + } + + // Swaps this queue's state with the other's. Not thread-safe. + // Swapping two queues does not invalidate their tokens, however + // the tokens that were created for one queue must be used with + // only the swapped queue (i.e. the tokens are tied to the + // queue's movable state, not the object itself). + inline void swap(ConcurrentQueue &other) MOODYCAMEL_NOEXCEPT { + swap_internal(other); + } + + private: + ConcurrentQueue &swap_internal(ConcurrentQueue &other) { + if (this == &other) { + return *this; + } + + details::swap_relaxed(producerListTail, other.producerListTail); + details::swap_relaxed(producerCount, other.producerCount); + details::swap_relaxed(initialBlockPoolIndex, other.initialBlockPoolIndex); + std::swap(initialBlockPool, other.initialBlockPool); + std::swap(initialBlockPoolSize, other.initialBlockPoolSize); + freeList.swap(other.freeList); + details::swap_relaxed(nextExplicitConsumerId, other.nextExplicitConsumerId); + details::swap_relaxed(globalExplicitConsumerOffset, other.globalExplicitConsumerOffset); + + swap_implicit_producer_hashes(other); + + reown_producers(); + other.reown_producers(); + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + details::swap_relaxed(explicitProducers, other.explicitProducers); + details::swap_relaxed(implicitProducers, other.implicitProducers); +#endif + + return *this; + } + + public: + // Enqueues a single item (by copying it). + // Allocates memory if required. Only fails if memory allocation fails (or implicit + // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, + // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(T const &item) { + if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; + return inner_enqueue(item); + } + + // Enqueues a single item (by moving it, if possible). + // Allocates memory if required. Only fails if memory allocation fails (or implicit + // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, + // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(T &&item) { + if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; + return inner_enqueue(std::move(item)); + } + + // Enqueues a single item (by copying it) using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails (or + // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(producer_token_t const &token, T const &item) { + return inner_enqueue(token, item); + } + + // Enqueues a single item (by moving it, if possible) using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails (or + // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(producer_token_t const &token, T &&item) { + return inner_enqueue(token, std::move(item)); + } + + // Enqueues several items. + // Allocates memory if required. Only fails if memory allocation fails (or + // implicit production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE + // is 0, or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Note: Use std::make_move_iterator if the elements should be moved instead of copied. + // Thread-safe. + template + bool enqueue_bulk(It itemFirst, size_t count) { + if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; + return inner_enqueue_bulk(itemFirst, count); + } + + // Enqueues several items using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails + // (or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + bool enqueue_bulk(producer_token_t const &token, It itemFirst, size_t count) { + return inner_enqueue_bulk(token, itemFirst, count); + } + + // Enqueues a single item (by copying it). + // Does not allocate memory. Fails if not enough room to enqueue (or implicit + // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE + // is 0). + // Thread-safe. + inline bool try_enqueue(T const &item) { + if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; + return inner_enqueue(item); + } + + // Enqueues a single item (by moving it, if possible). + // Does not allocate memory (except for one-time implicit producer). + // Fails if not enough room to enqueue (or implicit production is + // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). + // Thread-safe. + inline bool try_enqueue(T &&item) { + if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; + return inner_enqueue(std::move(item)); + } + + // Enqueues a single item (by copying it) using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Thread-safe. + inline bool try_enqueue(producer_token_t const &token, T const &item) { + return inner_enqueue(token, item); + } + + // Enqueues a single item (by moving it, if possible) using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Thread-safe. + inline bool try_enqueue(producer_token_t const &token, T &&item) { + return inner_enqueue(token, std::move(item)); + } + + // Enqueues several items. + // Does not allocate memory (except for one-time implicit producer). + // Fails if not enough room to enqueue (or implicit production is + // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + bool try_enqueue_bulk(It itemFirst, size_t count) { + if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; + return inner_enqueue_bulk(itemFirst, count); + } + + // Enqueues several items using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + bool try_enqueue_bulk(producer_token_t const &token, It itemFirst, size_t count) { + return inner_enqueue_bulk(token, itemFirst, count); + } + + + // Attempts to dequeue from the queue. + // Returns false if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + bool try_dequeue(U &item) { + // Instead of simply trying each producer in turn (which could cause needless contention on the first + // producer), we score them heuristically. + size_t nonEmptyCount = 0; + ProducerBase *best = nullptr; + size_t bestSize = 0; + for (auto ptr = producerListTail.load(std::memory_order_acquire); + nonEmptyCount < 3 && ptr != nullptr; ptr = ptr->next_prod()) { + auto size = ptr->size_approx(); + if (size > 0) { + if (size > bestSize) { + bestSize = size; + best = ptr; + } + ++nonEmptyCount; + } + } + + // If there was at least one non-empty queue but it appears empty at the time + // we try to dequeue from it, we need to make sure every queue's been tried + if (nonEmptyCount > 0) { + if (details::likely(best->dequeue(item))) { + return true; + } + for (auto ptr = producerListTail.load(std::memory_order_acquire); + ptr != nullptr; ptr = ptr->next_prod()) { + if (ptr != best && ptr->dequeue(item)) { + return true; + } + } + } + return false; + } + + // Attempts to dequeue from the queue. + // Returns false if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // This differs from the try_dequeue(item) method in that this one does + // not attempt to reduce contention by interleaving the order that producer + // streams are dequeued from. So, using this method can reduce overall throughput + // under contention, but will give more predictable results in single-threaded + // consumer scenarios. This is mostly only useful for internal unit tests. + // Never allocates. Thread-safe. + template + bool try_dequeue_non_interleaved(U &item) { + for (auto ptr = producerListTail.load(std::memory_order_acquire); + ptr != nullptr; ptr = ptr->next_prod()) { + if (ptr->dequeue(item)) { + return true; + } + } + return false; + } + + // Attempts to dequeue from the queue using an explicit consumer token. + // Returns false if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + bool try_dequeue(consumer_token_t &token, U &item) { + // The idea is roughly as follows: + // Every 256 items from one producer, make everyone rotate (increase the global offset) -> this means the highest efficiency consumer dictates the rotation speed of everyone else, more or less + // If you see that the global offset has changed, you must reset your consumption counter and move to your designated place + // If there's no items where you're supposed to be, keep moving until you find a producer with some items + // If the global offset has not changed but you've run out of items to consume, move over from your current position until you find an producer with something in it + + if (token.desiredProducer == nullptr || token.lastKnownGlobalOffset != + globalExplicitConsumerOffset.load( + std::memory_order_relaxed)) { + if (!update_current_producer_after_rotation(token)) { + return false; + } + } + + // If there was at least one non-empty queue but it appears empty at the time + // we try to dequeue from it, we need to make sure every queue's been tried + if (static_cast(token.currentProducer)->dequeue(item)) { + if (++token.itemsConsumedFromCurrent == EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE) { + globalExplicitConsumerOffset.fetch_add(1, std::memory_order_relaxed); + } + return true; + } + + auto tail = producerListTail.load(std::memory_order_acquire); + auto ptr = static_cast(token.currentProducer)->next_prod(); + if (ptr == nullptr) { + ptr = tail; + } + while (ptr != static_cast(token.currentProducer)) { + if (ptr->dequeue(item)) { + token.currentProducer = ptr; + token.itemsConsumedFromCurrent = 1; + return true; + } + ptr = ptr->next_prod(); + if (ptr == nullptr) { + ptr = tail; + } + } + return false; + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued. + // Returns 0 if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + size_t try_dequeue_bulk(It itemFirst, size_t max) { + size_t count = 0; + for (auto ptr = producerListTail.load(std::memory_order_acquire); + ptr != nullptr; ptr = ptr->next_prod()) { + count += ptr->dequeue_bulk(itemFirst, max - count); + if (count == max) { + break; + } + } + return count; + } + + // Attempts to dequeue several elements from the queue using an explicit consumer token. + // Returns the number of items actually dequeued. + // Returns 0 if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + size_t try_dequeue_bulk(consumer_token_t &token, It itemFirst, size_t max) { + if (token.desiredProducer == nullptr || token.lastKnownGlobalOffset != + globalExplicitConsumerOffset.load( + std::memory_order_relaxed)) { + if (!update_current_producer_after_rotation(token)) { + return 0; + } + } + + size_t count = static_cast(token.currentProducer)->dequeue_bulk(itemFirst, max); + if (count == max) { + if ((token.itemsConsumedFromCurrent += static_cast(max)) >= + EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE) { + globalExplicitConsumerOffset.fetch_add(1, std::memory_order_relaxed); + } + return max; + } + token.itemsConsumedFromCurrent += static_cast(count); + max -= count; + + auto tail = producerListTail.load(std::memory_order_acquire); + auto ptr = static_cast(token.currentProducer)->next_prod(); + if (ptr == nullptr) { + ptr = tail; + } + while (ptr != static_cast(token.currentProducer)) { + auto dequeued = ptr->dequeue_bulk(itemFirst, max); + count += dequeued; + if (dequeued != 0) { + token.currentProducer = ptr; + token.itemsConsumedFromCurrent = static_cast(dequeued); + } + if (dequeued == max) { + break; + } + max -= dequeued; + ptr = ptr->next_prod(); + if (ptr == nullptr) { + ptr = tail; + } + } + return count; + } + + + // Attempts to dequeue from a specific producer's inner queue. + // If you happen to know which producer you want to dequeue from, this + // is significantly faster than using the general-case try_dequeue methods. + // Returns false if the producer's queue appeared empty at the time it + // was checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline bool try_dequeue_from_producer(producer_token_t const &producer, U &item) { + return static_cast(producer.producer)->dequeue(item); + } + + // Attempts to dequeue several elements from a specific producer's inner queue. + // Returns the number of items actually dequeued. + // If you happen to know which producer you want to dequeue from, this + // is significantly faster than using the general-case try_dequeue methods. + // Returns 0 if the producer's queue appeared empty at the time it + // was checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline size_t + try_dequeue_bulk_from_producer(producer_token_t const &producer, It itemFirst, size_t max) { + return static_cast(producer.producer)->dequeue_bulk(itemFirst, max); + } + + + // Returns an estimate of the total number of elements currently in the queue. This + // estimate is only accurate if the queue has completely stabilized before it is called + // (i.e. all enqueue and dequeue operations have completed and their memory effects are + // visible on the calling thread, and no further operations start while this method is + // being called). + // Thread-safe. + size_t size_approx() const { + size_t size = 0; + for (auto ptr = producerListTail.load(std::memory_order_acquire); + ptr != nullptr; ptr = ptr->next_prod()) { + size += ptr->size_approx(); + } + return size; + } + + + // Returns true if the underlying atomic variables used by + // the queue are lock-free (they should be on most platforms). + // Thread-safe. + static bool is_lock_free() { + return + details::static_is_lock_free::value == 2 && + details::static_is_lock_free::value == 2 && + details::static_is_lock_free::value == 2 && + details::static_is_lock_free::value == 2 && + details::static_is_lock_free::value == 2 && + details::static_is_lock_free::thread_id_numeric_size_t>::value == + 2; + } + + + private: + friend struct ProducerToken; + friend struct ConsumerToken; + friend struct ExplicitProducer; + + friend class ConcurrentQueueTests; + + enum AllocationMode { + CanAlloc, CannotAlloc + }; + + + /////////////////////////////// + // Queue methods + /////////////////////////////// + + template + inline bool inner_enqueue(producer_token_t const &token, U &&element) { + return static_cast(token.producer)->ConcurrentQueue::ExplicitProducer::template enqueue( + std::forward(element)); + } + + template + inline bool inner_enqueue(U &&element) { + auto producer = get_or_add_implicit_producer(); + return producer == nullptr ? false + : producer->ConcurrentQueue::ImplicitProducer::template enqueue( + std::forward(element)); + } + + template + inline bool inner_enqueue_bulk(producer_token_t const &token, It itemFirst, size_t count) { + return static_cast(token.producer)->ConcurrentQueue::ExplicitProducer::template enqueue_bulk( + itemFirst, count); + } + + template + inline bool inner_enqueue_bulk(It itemFirst, size_t count) { + auto producer = get_or_add_implicit_producer(); + return producer == nullptr ? false + : producer->ConcurrentQueue::ImplicitProducer::template enqueue_bulk( + itemFirst, count); + } + + inline bool update_current_producer_after_rotation(consumer_token_t &token) { + // Ah, there's been a rotation, figure out where we should be! + auto tail = producerListTail.load(std::memory_order_acquire); + if (token.desiredProducer == nullptr && tail == nullptr) { + return false; + } + auto prodCount = producerCount.load(std::memory_order_relaxed); + auto globalOffset = globalExplicitConsumerOffset.load(std::memory_order_relaxed); + if (details::unlikely(token.desiredProducer == nullptr)) { + // Aha, first time we're dequeueing anything. + // Figure out our local position + // Note: offset is from start, not end, but we're traversing from end -- subtract from count first + std::uint32_t offset = prodCount - 1 - (token.initialOffset % prodCount); + token.desiredProducer = tail; + for (std::uint32_t i = 0; i != offset; ++i) { + token.desiredProducer = static_cast(token.desiredProducer)->next_prod(); + if (token.desiredProducer == nullptr) { + token.desiredProducer = tail; + } + } + } + + std::uint32_t delta = globalOffset - token.lastKnownGlobalOffset; + if (delta >= prodCount) { + delta = delta % prodCount; + } + for (std::uint32_t i = 0; i != delta; ++i) { + token.desiredProducer = static_cast(token.desiredProducer)->next_prod(); + if (token.desiredProducer == nullptr) { + token.desiredProducer = tail; + } + } + + token.lastKnownGlobalOffset = globalOffset; + token.currentProducer = token.desiredProducer; + token.itemsConsumedFromCurrent = 0; + return true; + } + + + /////////////////////////// + // Free list + /////////////////////////// + + template + struct FreeListNode { + FreeListNode() + : freeListRefs(0), freeListNext(nullptr) {} + + std::atomic freeListRefs; + std::atomic freeListNext; + }; + + // A simple CAS-based lock-free free list. Not the fastest thing in the world under heavy contention, but + // simple and correct (assuming nodes are never freed until after the free list is destroyed), and fairly + // speedy under low contention. + template // N must inherit FreeListNode or have the same fields (and initialization of them) + struct FreeList { + FreeList() + : freeListHead(nullptr) {} + + FreeList(FreeList &&other) + : freeListHead(other.freeListHead.load(std::memory_order_relaxed)) { + other.freeListHead.store(nullptr, std::memory_order_relaxed); + } + + void swap(FreeList &other) { details::swap_relaxed(freeListHead, other.freeListHead); } + + FreeList(FreeList const &) MOODYCAMEL_DELETE_FUNCTION; + + FreeList &operator=(FreeList const &) MOODYCAMEL_DELETE_FUNCTION; + + inline void add(N *node) { +#if MCDBGQ_NOLOCKFREE_FREELIST + debug::DebugLock lock(mutex); +#endif + // We know that the should-be-on-freelist bit is 0 at this point, so it's safe to + // set it using a fetch_add + if (node->freeListRefs.fetch_add(SHOULD_BE_ON_FREELIST, std::memory_order_acq_rel) == 0) { + // Oh look! We were the last ones referencing this node, and we know + // we want to add it to the free list, so let's do it! + add_knowing_refcount_is_zero(node); + } + } + + inline N *try_get() { +#if MCDBGQ_NOLOCKFREE_FREELIST + debug::DebugLock lock(mutex); +#endif + auto head = freeListHead.load(std::memory_order_acquire); + while (head != nullptr) { + auto prevHead = head; + auto refs = head->freeListRefs.load(std::memory_order_relaxed); + if ((refs & REFS_MASK) == 0 || + !head->freeListRefs.compare_exchange_strong(refs, refs + 1, std::memory_order_acquire, + std::memory_order_relaxed)) { + head = freeListHead.load(std::memory_order_acquire); + continue; + } + + // Good, reference count has been incremented (it wasn't at zero), which means we can read the + // next and not worry about it changing between now and the time we do the CAS + auto next = head->freeListNext.load(std::memory_order_relaxed); + if (freeListHead.compare_exchange_strong(head, next, std::memory_order_acquire, + std::memory_order_relaxed)) { + // Yay, got the node. This means it was on the list, which means shouldBeOnFreeList must be false no + // matter the refcount (because nobody else knows it's been taken off yet, it can't have been put back on). + assert((head->freeListRefs.load(std::memory_order_relaxed) & SHOULD_BE_ON_FREELIST) == 0); + + // Decrease refcount twice, once for our ref, and once for the list's ref + head->freeListRefs.fetch_sub(2, std::memory_order_release); + return head; + } + + // OK, the head must have changed on us, but we still need to decrease the refcount we increased. + // Note that we don't need to release any memory effects, but we do need to ensure that the reference + // count decrement happens-after the CAS on the head. + refs = prevHead->freeListRefs.fetch_sub(1, std::memory_order_acq_rel); + if (refs == SHOULD_BE_ON_FREELIST + 1) { + add_knowing_refcount_is_zero(prevHead); + } + } + + return nullptr; + } + + // Useful for traversing the list when there's no contention (e.g. to destroy remaining nodes) + N *head_unsafe() const { return freeListHead.load(std::memory_order_relaxed); } + + private: + inline void add_knowing_refcount_is_zero(N *node) { + // Since the refcount is zero, and nobody can increase it once it's zero (except us, and we run + // only one copy of this method per node at a time, i.e. the single thread case), then we know + // we can safely change the next pointer of the node; however, once the refcount is back above + // zero, then other threads could increase it (happens under heavy contention, when the refcount + // goes to zero in between a load and a refcount increment of a node in try_get, then back up to + // something non-zero, then the refcount increment is done by the other thread) -- so, if the CAS + // to add the node to the actual list fails, decrease the refcount and leave the add operation to + // the next thread who puts the refcount back at zero (which could be us, hence the loop). + auto head = freeListHead.load(std::memory_order_relaxed); + while (true) { + node->freeListNext.store(head, std::memory_order_relaxed); + node->freeListRefs.store(1, std::memory_order_release); + if (!freeListHead.compare_exchange_strong(head, node, std::memory_order_release, + std::memory_order_relaxed)) { + // Hmm, the add failed, but we can only try again when the refcount goes back to zero + if (node->freeListRefs.fetch_add(SHOULD_BE_ON_FREELIST - 1, std::memory_order_release) == + 1) { + continue; + } + } + return; + } + } + + private: + // Implemented like a stack, but where node order doesn't matter (nodes are inserted out of order under contention) + std::atomic freeListHead; + + static const std::uint32_t REFS_MASK = 0x7FFFFFFF; + static const std::uint32_t SHOULD_BE_ON_FREELIST = 0x80000000; + +#if MCDBGQ_NOLOCKFREE_FREELIST + debug::DebugMutex mutex; +#endif + }; + + + /////////////////////////// + // Block + /////////////////////////// + + enum InnerQueueContext { + implicit_context = 0, explicit_context = 1 + }; + + struct Block { + Block() + : next(nullptr), elementsCompletelyDequeued(0), freeListRefs(0), freeListNext(nullptr) + , shouldBeOnFreeList(false), dynamicallyAllocated(true) { +#if MCDBGQ_TRACKMEM + owner = nullptr; +#endif + } + + template + inline bool is_empty() const { + if (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { + // Check flags + for (size_t i = 0; i < BLOCK_SIZE; ++i) { + if (!emptyFlags[i].load(std::memory_order_relaxed)) { + return false; + } + } + + // Aha, empty; make sure we have all other memory effects that happened before the empty flags were set + std::atomic_thread_fence(std::memory_order_acquire); + return true; + } else { + // Check counter + if (elementsCompletelyDequeued.load(std::memory_order_relaxed) == BLOCK_SIZE) { + std::atomic_thread_fence(std::memory_order_acquire); + return true; + } + assert(elementsCompletelyDequeued.load(std::memory_order_relaxed) <= BLOCK_SIZE); + return false; + } + } + + // Returns true if the block is now empty (does not apply in explicit context) + template + inline bool set_empty(index_t i) { + if (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { + // Set flag + assert(!emptyFlags[BLOCK_SIZE - 1 - + static_cast(i & static_cast(BLOCK_SIZE - 1))].load( + std::memory_order_relaxed)); + emptyFlags[BLOCK_SIZE - 1 - + static_cast(i & static_cast(BLOCK_SIZE - 1))].store(true, + std::memory_order_release); + return false; + } else { + // Increment counter + auto prevVal = elementsCompletelyDequeued.fetch_add(1, std::memory_order_release); + assert(prevVal < BLOCK_SIZE); + return prevVal == BLOCK_SIZE - 1; + } + } + + // Sets multiple contiguous item statuses to 'empty' (assumes no wrapping and count > 0). + // Returns true if the block is now empty (does not apply in explicit context). + template + inline bool set_many_empty(index_t i, size_t count) { + if (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { + // Set flags + std::atomic_thread_fence(std::memory_order_release); + i = BLOCK_SIZE - 1 - static_cast(i & static_cast(BLOCK_SIZE - 1)) - count + + 1; + for (size_t j = 0; j != count; ++j) { + assert(!emptyFlags[i + j].load(std::memory_order_relaxed)); + emptyFlags[i + j].store(true, std::memory_order_relaxed); + } + return false; + } else { + // Increment counter + auto prevVal = elementsCompletelyDequeued.fetch_add(count, std::memory_order_release); + assert(prevVal + count <= BLOCK_SIZE); + return prevVal + count == BLOCK_SIZE; + } + } + + template + inline void set_all_empty() { + if (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { + // Set all flags + for (size_t i = 0; i != BLOCK_SIZE; ++i) { + emptyFlags[i].store(true, std::memory_order_relaxed); + } + } else { + // Reset counter + elementsCompletelyDequeued.store(BLOCK_SIZE, std::memory_order_relaxed); + } + } + + template + inline void reset_empty() { + if (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { + // Reset flags + for (size_t i = 0; i != BLOCK_SIZE; ++i) { + emptyFlags[i].store(false, std::memory_order_relaxed); + } + } else { + // Reset counter + elementsCompletelyDequeued.store(0, std::memory_order_relaxed); + } + } + + inline T *operator[](index_t idx) MOODYCAMEL_NOEXCEPT { + return static_cast(static_cast(elements)) + + static_cast(idx & static_cast(BLOCK_SIZE - 1)); + } + + inline T const *operator[](index_t idx) const MOODYCAMEL_NOEXCEPT { + return static_cast(static_cast(elements)) + + static_cast(idx & static_cast(BLOCK_SIZE - 1)); + } + + private: + // IMPORTANT: This must be the first member in Block, so that if T depends on the alignment of + // addresses returned by malloc, that alignment will be preserved. Apparently clang actually + // generates code that uses this assumption for AVX instructions in some cases. Ideally, we + // should also align Block to the alignment of T in case it's higher than malloc's 16-byte + // alignment, but this is hard to do in a cross-platform way. Assert for this case: + static_assert(std::alignment_of::value <= std::alignment_of::value, + "The queue does not support super-aligned types at this time"); + // Additionally, we need the alignment of Block itself to be a multiple of max_align_t since + // otherwise the appropriate padding will not be added at the end of Block in order to make + // arrays of Blocks all be properly aligned (not just the first one). We use a union to force + // this. + union { + char elements[sizeof(T) * BLOCK_SIZE]; + details::max_align_t dummy; + }; + public: + Block *next; + std::atomic elementsCompletelyDequeued; + std::atomic emptyFlags[ + BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD ? BLOCK_SIZE : 1]; + public: + std::atomic freeListRefs; + std::atomic freeListNext; + std::atomic shouldBeOnFreeList; + bool dynamicallyAllocated; // Perhaps a better name for this would be 'isNotPartOfInitialBlockPool' + +#if MCDBGQ_TRACKMEM + void* owner; +#endif + }; + + static_assert(std::alignment_of::value >= std::alignment_of::value, + "Internal error: Blocks must be at least as aligned as the type they are wrapping"); + + +#if MCDBGQ_TRACKMEM + public: + struct MemStats; + private: +#endif + + /////////////////////////// + // Producer base + /////////////////////////// + + struct ProducerBase : public details::ConcurrentQueueProducerTypelessBase { + ProducerBase(ConcurrentQueue *parent_, bool isExplicit_) + : + tailIndex(0), headIndex(0), dequeueOptimisticCount(0), dequeueOvercommit(0), tailBlock( + nullptr), isExplicit(isExplicit_), parent(parent_) { + } + + virtual ~ProducerBase() {}; + + template + inline bool dequeue(U &element) { + if (isExplicit) { + return static_cast(this)->dequeue(element); + } else { + return static_cast(this)->dequeue(element); + } + } + + template + inline size_t dequeue_bulk(It &itemFirst, size_t max) { + if (isExplicit) { + return static_cast(this)->dequeue_bulk(itemFirst, max); + } else { + return static_cast(this)->dequeue_bulk(itemFirst, max); + } + } + + inline ProducerBase *next_prod() const { return static_cast(next); } + + inline size_t size_approx() const { + auto tail = tailIndex.load(std::memory_order_relaxed); + auto head = headIndex.load(std::memory_order_relaxed); + return details::circular_less_than(head, tail) ? static_cast(tail - head) : 0; + } + + inline index_t getTail() const { return tailIndex.load(std::memory_order_relaxed); } + + protected: + std::atomic tailIndex; // Where to enqueue to next + std::atomic headIndex; // Where to dequeue from next + + std::atomic dequeueOptimisticCount; + std::atomic dequeueOvercommit; + + Block *tailBlock; + + public: + bool isExplicit; + ConcurrentQueue *parent; + + protected: +#if MCDBGQ_TRACKMEM + friend struct MemStats; +#endif + }; + + + /////////////////////////// + // Explicit queue + /////////////////////////// + + struct ExplicitProducer : public ProducerBase { + explicit ExplicitProducer(ConcurrentQueue *parent) + : + ProducerBase(parent, true), blockIndex(nullptr), pr_blockIndexSlotsUsed(0), pr_blockIndexSize( + EXPLICIT_INITIAL_INDEX_SIZE >> 1), pr_blockIndexFront(0), pr_blockIndexEntries(nullptr) + , pr_blockIndexRaw(nullptr) { + size_t poolBasedIndexSize = details::ceil_to_pow_2(parent->initialBlockPoolSize) >> 1; + if (poolBasedIndexSize > pr_blockIndexSize) { + pr_blockIndexSize = poolBasedIndexSize; + } + + new_block_index( + 0); // This creates an index with double the number of current entries, i.e. EXPLICIT_INITIAL_INDEX_SIZE + } + + ~ExplicitProducer() { + // Destruct any elements not yet dequeued. + // Since we're in the destructor, we can assume all elements + // are either completely dequeued or completely not (no halfways). + if (this->tailBlock != nullptr) { // Note this means there must be a block index too + // First find the block that's partially dequeued, if any + Block *halfDequeuedBlock = nullptr; + if ((this->headIndex.load(std::memory_order_relaxed) & + static_cast(BLOCK_SIZE - 1)) != 0) { + // The head's not on a block boundary, meaning a block somewhere is partially dequeued + // (or the head block is the tail block and was fully dequeued, but the head/tail are still not on a boundary) + size_t i = (pr_blockIndexFront - pr_blockIndexSlotsUsed) & (pr_blockIndexSize - 1); + while (details::circular_less_than(pr_blockIndexEntries[i].base + BLOCK_SIZE, + this->headIndex.load( + std::memory_order_relaxed))) { + i = (i + 1) & (pr_blockIndexSize - 1); + } + assert(details::circular_less_than(pr_blockIndexEntries[i].base, + this->headIndex.load( + std::memory_order_relaxed))); + halfDequeuedBlock = pr_blockIndexEntries[i].block; + } + + // Start at the head block (note the first line in the loop gives us the head from the tail on the first iteration) + auto block = this->tailBlock; + do { + block = block->next; + if (block->ConcurrentQueue::Block::template is_empty()) { + continue; + } + + size_t i = 0; // Offset into block + if (block == halfDequeuedBlock) { + i = static_cast(this->headIndex.load(std::memory_order_relaxed) & + static_cast(BLOCK_SIZE - 1)); + } + + // Walk through all the items in the block; if this is the tail block, we need to stop when we reach the tail index + auto lastValidIndex = (this->tailIndex.load(std::memory_order_relaxed) & + static_cast(BLOCK_SIZE - 1)) == 0 ? BLOCK_SIZE + : static_cast( + this->tailIndex.load(std::memory_order_relaxed) & + static_cast(BLOCK_SIZE - 1)); + while (i != BLOCK_SIZE && (block != this->tailBlock || i != lastValidIndex)) { + (*block)[i++]->~T(); + } + } while (block != this->tailBlock); + } + + // Destroy all blocks that we own + if (this->tailBlock != nullptr) { + auto block = this->tailBlock; + do { + auto nextBlock = block->next; + if (block->dynamicallyAllocated) { + destroy(block); + } else { + this->parent->add_block_to_free_list(block); + } + block = nextBlock; + } while (block != this->tailBlock); + } + + // Destroy the block indices + auto header = static_cast(pr_blockIndexRaw); + while (header != nullptr) { + auto prev = static_cast(header->prev); + header->~BlockIndexHeader(); + (Traits::free)(header); + header = prev; + } + } + + template + inline bool enqueue(U &&element) { + index_t currentTailIndex = this->tailIndex.load(std::memory_order_relaxed); + index_t newTailIndex = 1 + currentTailIndex; + if ((currentTailIndex & static_cast(BLOCK_SIZE - 1)) == 0) { + // We reached the end of a block, start a new one + auto startBlock = this->tailBlock; + auto originalBlockIndexSlotsUsed = pr_blockIndexSlotsUsed; + if (this->tailBlock != nullptr && + this->tailBlock->next->ConcurrentQueue::Block::template is_empty()) { + // We can re-use the block ahead of us, it's empty! + this->tailBlock = this->tailBlock->next; + this->tailBlock->ConcurrentQueue::Block::template reset_empty(); + + // We'll put the block on the block index (guaranteed to be room since we're conceptually removing the + // last block from it first -- except instead of removing then adding, we can just overwrite). + // Note that there must be a valid block index here, since even if allocation failed in the ctor, + // it would have been re-attempted when adding the first block to the queue; since there is such + // a block, a block index must have been successfully allocated. + } else { + // Whatever head value we see here is >= the last value we saw here (relatively), + // and <= its current value. Since we have the most recent tail, the head must be + // <= to it. + auto head = this->headIndex.load(std::memory_order_relaxed); + assert(!details::circular_less_than(currentTailIndex, head)); + if (!details::circular_less_than(head, currentTailIndex + BLOCK_SIZE) + || (MAX_SUBQUEUE_SIZE != details::const_numeric_max::value && + (MAX_SUBQUEUE_SIZE == 0 || + MAX_SUBQUEUE_SIZE - BLOCK_SIZE < currentTailIndex - head))) { + // We can't enqueue in another block because there's not enough leeway -- the + // tail could surpass the head by the time the block fills up! (Or we'll exceed + // the size limit, if the second part of the condition was true.) + return false; + } + // We're going to need a new block; check that the block index has room + if (pr_blockIndexRaw == nullptr || pr_blockIndexSlotsUsed == pr_blockIndexSize) { + // Hmm, the circular block index is already full -- we'll need + // to allocate a new index. Note pr_blockIndexRaw can only be nullptr if + // the initial allocation failed in the constructor. + + if (allocMode == CannotAlloc || !new_block_index(pr_blockIndexSlotsUsed)) { + return false; + } + } + + // Insert a new block in the circular linked list + auto newBlock = this->parent->ConcurrentQueue::template requisition_block(); + if (newBlock == nullptr) { + return false; + } +#if MCDBGQ_TRACKMEM + newBlock->owner = this; +#endif + newBlock->ConcurrentQueue::Block::template reset_empty(); + if (this->tailBlock == nullptr) { + newBlock->next = newBlock; + } else { + newBlock->next = this->tailBlock->next; + this->tailBlock->next = newBlock; + } + this->tailBlock = newBlock; + ++pr_blockIndexSlotsUsed; + } + + if (!MOODYCAMEL_NOEXCEPT_CTOR(T, U, new(nullptr) T(std::forward(element)))) { + // The constructor may throw. We want the element not to appear in the queue in + // that case (without corrupting the queue): + MOODYCAMEL_TRY { + new((*this->tailBlock)[currentTailIndex]) T(std::forward(element)); + } + MOODYCAMEL_CATCH (...) { + // Revert change to the current block, but leave the new block available + // for next time + pr_blockIndexSlotsUsed = originalBlockIndexSlotsUsed; + this->tailBlock = startBlock == nullptr ? this->tailBlock : startBlock; + MOODYCAMEL_RETHROW; + } + } else { + (void) startBlock; + (void) originalBlockIndexSlotsUsed; + } + + // Add block to block index + auto &entry = blockIndex.load(std::memory_order_relaxed)->entries[pr_blockIndexFront]; + entry.base = currentTailIndex; + entry.block = this->tailBlock; + blockIndex.load(std::memory_order_relaxed)->front.store(pr_blockIndexFront, + std::memory_order_release); + pr_blockIndexFront = (pr_blockIndexFront + 1) & (pr_blockIndexSize - 1); + + if (!MOODYCAMEL_NOEXCEPT_CTOR(T, U, new(nullptr) T(std::forward(element)))) { + this->tailIndex.store(newTailIndex, std::memory_order_release); + return true; + } + } + + // Enqueue + new((*this->tailBlock)[currentTailIndex]) T(std::forward(element)); + + this->tailIndex.store(newTailIndex, std::memory_order_release); + return true; + } + + template + bool dequeue(U &element) { + auto tail = this->tailIndex.load(std::memory_order_relaxed); + auto overcommit = this->dequeueOvercommit.load(std::memory_order_relaxed); + if (details::circular_less_than( + this->dequeueOptimisticCount.load(std::memory_order_relaxed) - overcommit, tail)) { + // Might be something to dequeue, let's give it a try + + // Note that this if is purely for performance purposes in the common case when the queue is + // empty and the values are eventually consistent -- we may enter here spuriously. + + // Note that whatever the values of overcommit and tail are, they are not going to change (unless we + // change them) and must be the same value at this point (inside the if) as when the if condition was + // evaluated. + + // We insert an acquire fence here to synchronize-with the release upon incrementing dequeueOvercommit below. + // This ensures that whatever the value we got loaded into overcommit, the load of dequeueOptisticCount in + // the fetch_add below will result in a value at least as recent as that (and therefore at least as large). + // Note that I believe a compiler (signal) fence here would be sufficient due to the nature of fetch_add (all + // read-modify-write operations are guaranteed to work on the latest value in the modification order), but + // unfortunately that can't be shown to be correct using only the C++11 standard. + // See http://stackoverflow.com/questions/18223161/what-are-the-c11-memory-ordering-guarantees-in-this-corner-case + std::atomic_thread_fence(std::memory_order_acquire); + + // Increment optimistic counter, then check if it went over the boundary + auto myDequeueCount = this->dequeueOptimisticCount.fetch_add(1, std::memory_order_relaxed); + + // Note that since dequeueOvercommit must be <= dequeueOptimisticCount (because dequeueOvercommit is only ever + // incremented after dequeueOptimisticCount -- this is enforced in the `else` block below), and since we now + // have a version of dequeueOptimisticCount that is at least as recent as overcommit (due to the release upon + // incrementing dequeueOvercommit and the acquire above that synchronizes with it), overcommit <= myDequeueCount. + assert(overcommit <= myDequeueCount); + + // Note that we reload tail here in case it changed; it will be the same value as before or greater, since + // this load is sequenced after (happens after) the earlier load above. This is supported by read-read + // coherency (as defined in the standard), explained here: http://en.cppreference.com/w/cpp/atomic/memory_order + tail = this->tailIndex.load(std::memory_order_acquire); + if (details::likely( + details::circular_less_than(myDequeueCount - overcommit, tail))) { + // Guaranteed to be at least one element to dequeue! + + // Get the index. Note that since there's guaranteed to be at least one element, this + // will never exceed tail. We need to do an acquire-release fence here since it's possible + // that whatever condition got us to this point was for an earlier enqueued element (that + // we already see the memory effects for), but that by the time we increment somebody else + // has incremented it, and we need to see the memory effects for *that* element, which is + // in such a case is necessarily visible on the thread that incremented it in the first + // place with the more current condition (they must have acquired a tail that is at least + // as recent). + auto index = this->headIndex.fetch_add(1, std::memory_order_acq_rel); + + + // Determine which block the element is in + + auto localBlockIndex = blockIndex.load(std::memory_order_acquire); + auto localBlockIndexHead = localBlockIndex->front.load(std::memory_order_acquire); + + // We need to be careful here about subtracting and dividing because of index wrap-around. + // When an index wraps, we need to preserve the sign of the offset when dividing it by the + // block size (in order to get a correct signed block count offset in all cases): + auto headBase = localBlockIndex->entries[localBlockIndexHead].base; + auto blockBaseIndex = index & ~static_cast(BLOCK_SIZE - 1); + auto offset = static_cast( + static_cast::type>(blockBaseIndex - headBase) / + BLOCK_SIZE); + auto block = localBlockIndex->entries[(localBlockIndexHead + offset) & + (localBlockIndex->size - 1)].block; + + // Dequeue + auto &el = *((*block)[index]); + if (!MOODYCAMEL_NOEXCEPT_ASSIGN(T, T &&, element = std::move(el))) { + // Make sure the element is still fully dequeued and destroyed even if the assignment + // throws + struct Guard { + Block *block; + index_t index; + + ~Guard() { + (*block)[index]->~T(); + block->ConcurrentQueue::Block::template set_empty(index); + } + } guard = {block, index}; + + element = std::move(el); + } else { + element = std::move(el); + el.~T(); + block->ConcurrentQueue::Block::template set_empty(index); + } + + return true; + } else { + // Wasn't anything to dequeue after all; make the effective dequeue count eventually consistent + this->dequeueOvercommit.fetch_add(1, + std::memory_order_release); // Release so that the fetch_add on dequeueOptimisticCount is guaranteed to happen before this write + } + } + + return false; + } + + template + bool enqueue_bulk(It itemFirst, size_t count) { + // First, we need to make sure we have enough room to enqueue all of the elements; + // this means pre-allocating blocks and putting them in the block index (but only if + // all the allocations succeeded). + index_t startTailIndex = this->tailIndex.load(std::memory_order_relaxed); + auto startBlock = this->tailBlock; + auto originalBlockIndexFront = pr_blockIndexFront; + auto originalBlockIndexSlotsUsed = pr_blockIndexSlotsUsed; + + Block *firstAllocatedBlock = nullptr; + + // Figure out how many blocks we'll need to allocate, and do so + size_t blockBaseDiff = + ((startTailIndex + count - 1) & ~static_cast(BLOCK_SIZE - 1)) - + ((startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1)); + index_t currentTailIndex = (startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1); + if (blockBaseDiff > 0) { + // Allocate as many blocks as possible from ahead + while (blockBaseDiff > 0 && this->tailBlock != nullptr && + this->tailBlock->next != firstAllocatedBlock && + this->tailBlock->next->ConcurrentQueue::Block::template is_empty()) { + blockBaseDiff -= static_cast(BLOCK_SIZE); + currentTailIndex += static_cast(BLOCK_SIZE); + + this->tailBlock = this->tailBlock->next; + firstAllocatedBlock = + firstAllocatedBlock == nullptr ? this->tailBlock : firstAllocatedBlock; + + auto &entry = blockIndex.load(std::memory_order_relaxed)->entries[pr_blockIndexFront]; + entry.base = currentTailIndex; + entry.block = this->tailBlock; + pr_blockIndexFront = (pr_blockIndexFront + 1) & (pr_blockIndexSize - 1); + } + + // Now allocate as many blocks as necessary from the block pool + while (blockBaseDiff > 0) { + blockBaseDiff -= static_cast(BLOCK_SIZE); + currentTailIndex += static_cast(BLOCK_SIZE); + + auto head = this->headIndex.load(std::memory_order_relaxed); + assert(!details::circular_less_than(currentTailIndex, head)); + bool full = !details::circular_less_than(head, currentTailIndex + BLOCK_SIZE) || + (MAX_SUBQUEUE_SIZE != details::const_numeric_max::value && + (MAX_SUBQUEUE_SIZE == 0 || + MAX_SUBQUEUE_SIZE - BLOCK_SIZE < currentTailIndex - head)); + if (pr_blockIndexRaw == nullptr || pr_blockIndexSlotsUsed == pr_blockIndexSize || full) { + if (allocMode == CannotAlloc || full || !new_block_index(originalBlockIndexSlotsUsed)) { + // Failed to allocate, undo changes (but keep injected blocks) + pr_blockIndexFront = originalBlockIndexFront; + pr_blockIndexSlotsUsed = originalBlockIndexSlotsUsed; + this->tailBlock = startBlock == nullptr ? firstAllocatedBlock : startBlock; + return false; + } + + // pr_blockIndexFront is updated inside new_block_index, so we need to + // update our fallback value too (since we keep the new index even if we + // later fail) + originalBlockIndexFront = originalBlockIndexSlotsUsed; + } + + // Insert a new block in the circular linked list + auto newBlock = this->parent->ConcurrentQueue::template requisition_block(); + if (newBlock == nullptr) { + pr_blockIndexFront = originalBlockIndexFront; + pr_blockIndexSlotsUsed = originalBlockIndexSlotsUsed; + this->tailBlock = startBlock == nullptr ? firstAllocatedBlock : startBlock; + return false; + } + +#if MCDBGQ_TRACKMEM + newBlock->owner = this; +#endif + newBlock->ConcurrentQueue::Block::template set_all_empty(); + if (this->tailBlock == nullptr) { + newBlock->next = newBlock; + } else { + newBlock->next = this->tailBlock->next; + this->tailBlock->next = newBlock; + } + this->tailBlock = newBlock; + firstAllocatedBlock = + firstAllocatedBlock == nullptr ? this->tailBlock : firstAllocatedBlock; + + ++pr_blockIndexSlotsUsed; + + auto &entry = blockIndex.load(std::memory_order_relaxed)->entries[pr_blockIndexFront]; + entry.base = currentTailIndex; + entry.block = this->tailBlock; + pr_blockIndexFront = (pr_blockIndexFront + 1) & (pr_blockIndexSize - 1); + } + + // Excellent, all allocations succeeded. Reset each block's emptiness before we fill them up, and + // publish the new block index front + auto block = firstAllocatedBlock; + while (true) { + block->ConcurrentQueue::Block::template reset_empty(); + if (block == this->tailBlock) { + break; + } + block = block->next; + } + + if (MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), + new(nullptr) T(details::deref_noexcept(itemFirst)))) { + blockIndex.load(std::memory_order_relaxed)->front.store( + (pr_blockIndexFront - 1) & (pr_blockIndexSize - 1), std::memory_order_release); + } + } + + // Enqueue, one block at a time + index_t newTailIndex = startTailIndex + static_cast(count); + currentTailIndex = startTailIndex; + auto endBlock = this->tailBlock; + this->tailBlock = startBlock; + assert((startTailIndex & static_cast(BLOCK_SIZE - 1)) != 0 || + firstAllocatedBlock != nullptr || count == 0); + if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) == 0 && + firstAllocatedBlock != nullptr) { + this->tailBlock = firstAllocatedBlock; + } + while (true) { + auto stopIndex = (currentTailIndex & ~static_cast(BLOCK_SIZE - 1)) + + static_cast(BLOCK_SIZE); + if (details::circular_less_than(newTailIndex, stopIndex)) { + stopIndex = newTailIndex; + } + if (MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), + new(nullptr) T(details::deref_noexcept(itemFirst)))) { + while (currentTailIndex != stopIndex) { + new((*this->tailBlock)[currentTailIndex++]) T(*itemFirst++); + } + } else { + MOODYCAMEL_TRY { + while (currentTailIndex != stopIndex) { + // Must use copy constructor even if move constructor is available + // because we may have to revert if there's an exception. + // Sorry about the horrible templated next line, but it was the only way + // to disable moving *at compile time*, which is important because a type + // may only define a (noexcept) move constructor, and so calls to the + // cctor will not compile, even if they are in an if branch that will never + // be executed + new((*this->tailBlock)[currentTailIndex]) T( + details::nomove_if<(bool) !MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), + new(nullptr) T( + details::deref_noexcept( + itemFirst)))>::eval( + *itemFirst)); + ++currentTailIndex; + ++itemFirst; + } + } + MOODYCAMEL_CATCH (...) { + // Oh dear, an exception's been thrown -- destroy the elements that + // were enqueued so far and revert the entire bulk operation (we'll keep + // any allocated blocks in our linked list for later, though). + auto constructedStopIndex = currentTailIndex; + auto lastBlockEnqueued = this->tailBlock; + + pr_blockIndexFront = originalBlockIndexFront; + pr_blockIndexSlotsUsed = originalBlockIndexSlotsUsed; + this->tailBlock = startBlock == nullptr ? firstAllocatedBlock : startBlock; + + if (!details::is_trivially_destructible::value) { + auto block = startBlock; + if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) == 0) { + block = firstAllocatedBlock; + } + currentTailIndex = startTailIndex; + while (true) { + stopIndex = (currentTailIndex & ~static_cast(BLOCK_SIZE - 1)) + + static_cast(BLOCK_SIZE); + if (details::circular_less_than(constructedStopIndex, stopIndex)) { + stopIndex = constructedStopIndex; + } + while (currentTailIndex != stopIndex) { + (*block)[currentTailIndex++]->~T(); + } + if (block == lastBlockEnqueued) { + break; + } + block = block->next; + } + } + MOODYCAMEL_RETHROW; + } + } + + if (this->tailBlock == endBlock) { + assert(currentTailIndex == newTailIndex); + break; + } + this->tailBlock = this->tailBlock->next; + } + + if (!MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), + new(nullptr) T(details::deref_noexcept(itemFirst))) && + firstAllocatedBlock != nullptr) { + blockIndex.load(std::memory_order_relaxed)->front.store( + (pr_blockIndexFront - 1) & (pr_blockIndexSize - 1), std::memory_order_release); + } + + this->tailIndex.store(newTailIndex, std::memory_order_release); + return true; + } + + template + size_t dequeue_bulk(It &itemFirst, size_t max) { + auto tail = this->tailIndex.load(std::memory_order_relaxed); + auto overcommit = this->dequeueOvercommit.load(std::memory_order_relaxed); + auto desiredCount = static_cast(tail - (this->dequeueOptimisticCount.load( + std::memory_order_relaxed) - overcommit)); + if (details::circular_less_than(0, desiredCount)) { + desiredCount = desiredCount < max ? desiredCount : max; + std::atomic_thread_fence(std::memory_order_acquire); + + auto myDequeueCount = this->dequeueOptimisticCount.fetch_add(desiredCount, + std::memory_order_relaxed); + assert(overcommit <= myDequeueCount); + + tail = this->tailIndex.load(std::memory_order_acquire); + auto actualCount = static_cast(tail - (myDequeueCount - overcommit)); + if (details::circular_less_than(0, actualCount)) { + actualCount = desiredCount < actualCount ? desiredCount : actualCount; + if (actualCount < desiredCount) { + this->dequeueOvercommit.fetch_add(desiredCount - actualCount, + std::memory_order_release); + } + + // Get the first index. Note that since there's guaranteed to be at least actualCount elements, this + // will never exceed tail. + auto firstIndex = this->headIndex.fetch_add(actualCount, std::memory_order_acq_rel); + + // Determine which block the first element is in + auto localBlockIndex = blockIndex.load(std::memory_order_acquire); + auto localBlockIndexHead = localBlockIndex->front.load(std::memory_order_acquire); + + auto headBase = localBlockIndex->entries[localBlockIndexHead].base; + auto firstBlockBaseIndex = firstIndex & ~static_cast(BLOCK_SIZE - 1); + auto offset = static_cast( + static_cast::type>(firstBlockBaseIndex - headBase) / + BLOCK_SIZE); + auto indexIndex = (localBlockIndexHead + offset) & (localBlockIndex->size - 1); + + // Iterate the blocks and dequeue + auto index = firstIndex; + do { + auto firstIndexInBlock = index; + auto endIndex = + (index & ~static_cast(BLOCK_SIZE - 1)) + static_cast(BLOCK_SIZE); + endIndex = details::circular_less_than( + firstIndex + static_cast(actualCount), endIndex) ? firstIndex + + static_cast(actualCount) + : endIndex; + auto block = localBlockIndex->entries[indexIndex].block; + if (MOODYCAMEL_NOEXCEPT_ASSIGN(T, T &&, details::deref_noexcept(itemFirst) = std::move( + (*(*block)[index])))) { + while (index != endIndex) { + auto &el = *((*block)[index]); + *itemFirst++ = std::move(el); + el.~T(); + ++index; + } + } else { + MOODYCAMEL_TRY { + while (index != endIndex) { + auto &el = *((*block)[index]); + *itemFirst = std::move(el); + ++itemFirst; + el.~T(); + ++index; + } + } + MOODYCAMEL_CATCH (...) { + // It's too late to revert the dequeue, but we can make sure that all + // the dequeued objects are properly destroyed and the block index + // (and empty count) are properly updated before we propagate the exception + do { + block = localBlockIndex->entries[indexIndex].block; + while (index != endIndex) { + (*block)[index++]->~T(); + } + block->ConcurrentQueue::Block::template set_many_empty( + firstIndexInBlock, static_cast(endIndex - firstIndexInBlock)); + indexIndex = (indexIndex + 1) & (localBlockIndex->size - 1); + + firstIndexInBlock = index; + endIndex = (index & ~static_cast(BLOCK_SIZE - 1)) + + static_cast(BLOCK_SIZE); + endIndex = details::circular_less_than( + firstIndex + static_cast(actualCount), endIndex) ? firstIndex + + static_cast(actualCount) + : endIndex; + } while (index != firstIndex + actualCount); + + MOODYCAMEL_RETHROW; + } + } + block->ConcurrentQueue::Block::template set_many_empty( + firstIndexInBlock, static_cast(endIndex - firstIndexInBlock)); + indexIndex = (indexIndex + 1) & (localBlockIndex->size - 1); + } while (index != firstIndex + actualCount); + + return actualCount; + } else { + // Wasn't anything to dequeue after all; make the effective dequeue count eventually consistent + this->dequeueOvercommit.fetch_add(desiredCount, std::memory_order_release); + } + } + + return 0; + } + + private: + struct BlockIndexEntry { + index_t base; + Block *block; + }; + + struct BlockIndexHeader { + size_t size; + std::atomic front; // Current slot (not next, like pr_blockIndexFront) + BlockIndexEntry *entries; + void *prev; + }; + + + bool new_block_index(size_t numberOfFilledSlotsToExpose) { + auto prevBlockSizeMask = pr_blockIndexSize - 1; + + // Create the new block + pr_blockIndexSize <<= 1; + auto newRawPtr = static_cast((Traits::malloc)( + sizeof(BlockIndexHeader) + std::alignment_of::value - 1 + + sizeof(BlockIndexEntry) * pr_blockIndexSize)); + if (newRawPtr == nullptr) { + pr_blockIndexSize >>= 1; // Reset to allow graceful retry + return false; + } + + auto newBlockIndexEntries = reinterpret_cast(details::align_for( + newRawPtr + sizeof(BlockIndexHeader))); + + // Copy in all the old indices, if any + size_t j = 0; + if (pr_blockIndexSlotsUsed != 0) { + auto i = (pr_blockIndexFront - pr_blockIndexSlotsUsed) & prevBlockSizeMask; + do { + newBlockIndexEntries[j++] = pr_blockIndexEntries[i]; + i = (i + 1) & prevBlockSizeMask; + } while (i != pr_blockIndexFront); + } + + // Update everything + auto header = new(newRawPtr) BlockIndexHeader; + header->size = pr_blockIndexSize; + header->front.store(numberOfFilledSlotsToExpose - 1, std::memory_order_relaxed); + header->entries = newBlockIndexEntries; + header->prev = pr_blockIndexRaw; // we link the new block to the old one so we can free it later + + pr_blockIndexFront = j; + pr_blockIndexEntries = newBlockIndexEntries; + pr_blockIndexRaw = newRawPtr; + blockIndex.store(header, std::memory_order_release); + + return true; + } + + private: + std::atomic blockIndex; + + // To be used by producer only -- consumer must use the ones in referenced by blockIndex + size_t pr_blockIndexSlotsUsed; + size_t pr_blockIndexSize; + size_t pr_blockIndexFront; // Next slot (not current) + BlockIndexEntry *pr_blockIndexEntries; + void *pr_blockIndexRaw; + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + public: + ExplicitProducer* nextExplicitProducer; + private: +#endif + +#if MCDBGQ_TRACKMEM + friend struct MemStats; +#endif + }; + + + ////////////////////////////////// + // Implicit queue + ////////////////////////////////// + + struct ImplicitProducer : public ProducerBase { + ImplicitProducer(ConcurrentQueue *parent) + : + ProducerBase(parent, false), nextBlockIndexCapacity(IMPLICIT_INITIAL_INDEX_SIZE), blockIndex( + nullptr) { + new_block_index(); + } + + ~ImplicitProducer() { + // Note that since we're in the destructor we can assume that all enqueue/dequeue operations + // completed already; this means that all undequeued elements are placed contiguously across + // contiguous blocks, and that only the first and last remaining blocks can be only partially + // empty (all other remaining blocks must be completely full). + +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED + // Unregister ourselves for thread termination notification + if (!this->inactive.load(std::memory_order_relaxed)) { + details::ThreadExitNotifier::unsubscribe(&threadExitListener); + } +#endif + + // Destroy all remaining elements! + auto tail = this->tailIndex.load(std::memory_order_relaxed); + auto index = this->headIndex.load(std::memory_order_relaxed); + Block *block = nullptr; + assert(index == tail || details::circular_less_than(index, tail)); + bool forceFreeLastBlock = + index != tail; // If we enter the loop, then the last (tail) block will not be freed + while (index != tail) { + if ((index & static_cast(BLOCK_SIZE - 1)) == 0 || block == nullptr) { + if (block != nullptr) { + // Free the old block + this->parent->add_block_to_free_list(block); + } + + block = get_block_index_entry_for_index(index)->value.load(std::memory_order_relaxed); + } + + ((*block)[index])->~T(); + ++index; + } + // Even if the queue is empty, there's still one block that's not on the free list + // (unless the head index reached the end of it, in which case the tail will be poised + // to create a new block). + if (this->tailBlock != nullptr && + (forceFreeLastBlock || (tail & static_cast(BLOCK_SIZE - 1)) != 0)) { + this->parent->add_block_to_free_list(this->tailBlock); + } + + // Destroy block index + auto localBlockIndex = blockIndex.load(std::memory_order_relaxed); + if (localBlockIndex != nullptr) { + for (size_t i = 0; i != localBlockIndex->capacity; ++i) { + localBlockIndex->index[i]->~BlockIndexEntry(); + } + do { + auto prev = localBlockIndex->prev; + localBlockIndex->~BlockIndexHeader(); + (Traits::free)(localBlockIndex); + localBlockIndex = prev; + } while (localBlockIndex != nullptr); + } + } + + template + inline bool enqueue(U &&element) { + index_t currentTailIndex = this->tailIndex.load(std::memory_order_relaxed); + index_t newTailIndex = 1 + currentTailIndex; + if ((currentTailIndex & static_cast(BLOCK_SIZE - 1)) == 0) { + // We reached the end of a block, start a new one + auto head = this->headIndex.load(std::memory_order_relaxed); + assert(!details::circular_less_than(currentTailIndex, head)); + if (!details::circular_less_than(head, currentTailIndex + BLOCK_SIZE) || + (MAX_SUBQUEUE_SIZE != details::const_numeric_max::value && + (MAX_SUBQUEUE_SIZE == 0 || + MAX_SUBQUEUE_SIZE - BLOCK_SIZE < currentTailIndex - head))) { + return false; + } +#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + debug::DebugLock lock(mutex); +#endif + // Find out where we'll be inserting this block in the block index + BlockIndexEntry *idxEntry; + if (!insert_block_index_entry(idxEntry, currentTailIndex)) { + return false; + } + + // Get ahold of a new block + auto newBlock = this->parent->ConcurrentQueue::template requisition_block(); + if (newBlock == nullptr) { + rewind_block_index_tail(); + idxEntry->value.store(nullptr, std::memory_order_relaxed); + return false; + } +#if MCDBGQ_TRACKMEM + newBlock->owner = this; +#endif + newBlock->ConcurrentQueue::Block::template reset_empty(); + + if (!MOODYCAMEL_NOEXCEPT_CTOR(T, U, new(nullptr) T(std::forward(element)))) { + // May throw, try to insert now before we publish the fact that we have this new block + MOODYCAMEL_TRY { + new((*newBlock)[currentTailIndex]) T(std::forward(element)); + } + MOODYCAMEL_CATCH (...) { + rewind_block_index_tail(); + idxEntry->value.store(nullptr, std::memory_order_relaxed); + this->parent->add_block_to_free_list(newBlock); + MOODYCAMEL_RETHROW; + } + } + + // Insert the new block into the index + idxEntry->value.store(newBlock, std::memory_order_relaxed); + + this->tailBlock = newBlock; + + if (!MOODYCAMEL_NOEXCEPT_CTOR(T, U, new(nullptr) T(std::forward(element)))) { + this->tailIndex.store(newTailIndex, std::memory_order_release); + return true; + } + } + + // Enqueue + new((*this->tailBlock)[currentTailIndex]) T(std::forward(element)); + + this->tailIndex.store(newTailIndex, std::memory_order_release); + return true; + } + + template + bool dequeue(U &element) { + // See ExplicitProducer::dequeue for rationale and explanation + index_t tail = this->tailIndex.load(std::memory_order_relaxed); + index_t overcommit = this->dequeueOvercommit.load(std::memory_order_relaxed); + if (details::circular_less_than( + this->dequeueOptimisticCount.load(std::memory_order_relaxed) - overcommit, tail)) { + std::atomic_thread_fence(std::memory_order_acquire); + + index_t myDequeueCount = this->dequeueOptimisticCount.fetch_add(1, + std::memory_order_relaxed); + assert(overcommit <= myDequeueCount); + tail = this->tailIndex.load(std::memory_order_acquire); + if (details::likely( + details::circular_less_than(myDequeueCount - overcommit, tail))) { + index_t index = this->headIndex.fetch_add(1, std::memory_order_acq_rel); + + // Determine which block the element is in + auto entry = get_block_index_entry_for_index(index); + + // Dequeue + auto block = entry->value.load(std::memory_order_relaxed); + auto &el = *((*block)[index]); + + if (!MOODYCAMEL_NOEXCEPT_ASSIGN(T, T &&, element = std::move(el))) { +#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + // Note: Acquiring the mutex with every dequeue instead of only when a block + // is released is very sub-optimal, but it is, after all, purely debug code. + debug::DebugLock lock(producer->mutex); +#endif + struct Guard { + Block *block; + index_t index; + BlockIndexEntry *entry; + ConcurrentQueue *parent; + + ~Guard() { + (*block)[index]->~T(); + if (block->ConcurrentQueue::Block::template set_empty(index)) { + entry->value.store(nullptr, std::memory_order_relaxed); + parent->add_block_to_free_list(block); + } + } + } guard = {block, index, entry, this->parent}; + + element = std::move(el); + } else { + element = std::move(el); + el.~T(); + + if (block->ConcurrentQueue::Block::template set_empty(index)) { + { +#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + debug::DebugLock lock(mutex); +#endif + // Add the block back into the global free pool (and remove from block index) + entry->value.store(nullptr, std::memory_order_relaxed); + } + this->parent->add_block_to_free_list(block); // releases the above store + } + } + + return true; + } else { + this->dequeueOvercommit.fetch_add(1, std::memory_order_release); + } + } + + return false; + } + + template + bool enqueue_bulk(It itemFirst, size_t count) { + // First, we need to make sure we have enough room to enqueue all of the elements; + // this means pre-allocating blocks and putting them in the block index (but only if + // all the allocations succeeded). + + // Note that the tailBlock we start off with may not be owned by us any more; + // this happens if it was filled up exactly to the top (setting tailIndex to + // the first index of the next block which is not yet allocated), then dequeued + // completely (putting it on the free list) before we enqueue again. + + index_t startTailIndex = this->tailIndex.load(std::memory_order_relaxed); + auto startBlock = this->tailBlock; + Block *firstAllocatedBlock = nullptr; + auto endBlock = this->tailBlock; + + // Figure out how many blocks we'll need to allocate, and do so + size_t blockBaseDiff = + ((startTailIndex + count - 1) & ~static_cast(BLOCK_SIZE - 1)) - + ((startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1)); + index_t currentTailIndex = (startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1); + if (blockBaseDiff > 0) { +#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + debug::DebugLock lock(mutex); +#endif + do { + blockBaseDiff -= static_cast(BLOCK_SIZE); + currentTailIndex += static_cast(BLOCK_SIZE); + + // Find out where we'll be inserting this block in the block index + BlockIndexEntry *idxEntry = nullptr; // initialization here unnecessary but compiler can't always tell + Block *newBlock; + bool indexInserted = false; + auto head = this->headIndex.load(std::memory_order_relaxed); + assert(!details::circular_less_than(currentTailIndex, head)); + bool full = !details::circular_less_than(head, currentTailIndex + BLOCK_SIZE) || + (MAX_SUBQUEUE_SIZE != details::const_numeric_max::value && + (MAX_SUBQUEUE_SIZE == 0 || + MAX_SUBQUEUE_SIZE - BLOCK_SIZE < currentTailIndex - head)); + if (full || + !(indexInserted = insert_block_index_entry(idxEntry, currentTailIndex)) || + (newBlock = this->parent->ConcurrentQueue::template requisition_block()) == + nullptr) { + // Index allocation or block allocation failed; revert any other allocations + // and index insertions done so far for this operation + if (indexInserted) { + rewind_block_index_tail(); + idxEntry->value.store(nullptr, std::memory_order_relaxed); + } + currentTailIndex = (startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1); + for (auto block = firstAllocatedBlock; block != nullptr; block = block->next) { + currentTailIndex += static_cast(BLOCK_SIZE); + idxEntry = get_block_index_entry_for_index(currentTailIndex); + idxEntry->value.store(nullptr, std::memory_order_relaxed); + rewind_block_index_tail(); + } + this->parent->add_blocks_to_free_list(firstAllocatedBlock); + this->tailBlock = startBlock; + + return false; + } + +#if MCDBGQ_TRACKMEM + newBlock->owner = this; +#endif + newBlock->ConcurrentQueue::Block::template reset_empty(); + newBlock->next = nullptr; + + // Insert the new block into the index + idxEntry->value.store(newBlock, std::memory_order_relaxed); + + // Store the chain of blocks so that we can undo if later allocations fail, + // and so that we can find the blocks when we do the actual enqueueing + if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) != 0 || + firstAllocatedBlock != nullptr) { + assert(this->tailBlock != nullptr); + this->tailBlock->next = newBlock; + } + this->tailBlock = newBlock; + endBlock = newBlock; + firstAllocatedBlock = firstAllocatedBlock == nullptr ? newBlock : firstAllocatedBlock; + } while (blockBaseDiff > 0); + } + + // Enqueue, one block at a time + index_t newTailIndex = startTailIndex + static_cast(count); + currentTailIndex = startTailIndex; + this->tailBlock = startBlock; + assert((startTailIndex & static_cast(BLOCK_SIZE - 1)) != 0 || + firstAllocatedBlock != nullptr || count == 0); + if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) == 0 && + firstAllocatedBlock != nullptr) { + this->tailBlock = firstAllocatedBlock; + } + while (true) { + auto stopIndex = (currentTailIndex & ~static_cast(BLOCK_SIZE - 1)) + + static_cast(BLOCK_SIZE); + if (details::circular_less_than(newTailIndex, stopIndex)) { + stopIndex = newTailIndex; + } + if (MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), + new(nullptr) T(details::deref_noexcept(itemFirst)))) { + while (currentTailIndex != stopIndex) { + new((*this->tailBlock)[currentTailIndex++]) T(*itemFirst++); + } + } else { + MOODYCAMEL_TRY { + while (currentTailIndex != stopIndex) { + new((*this->tailBlock)[currentTailIndex]) T( + details::nomove_if<(bool) !MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), + new(nullptr) T( + details::deref_noexcept( + itemFirst)))>::eval( + *itemFirst)); + ++currentTailIndex; + ++itemFirst; + } + } + MOODYCAMEL_CATCH (...) { + auto constructedStopIndex = currentTailIndex; + auto lastBlockEnqueued = this->tailBlock; + + if (!details::is_trivially_destructible::value) { + auto block = startBlock; + if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) == 0) { + block = firstAllocatedBlock; + } + currentTailIndex = startTailIndex; + while (true) { + stopIndex = (currentTailIndex & ~static_cast(BLOCK_SIZE - 1)) + + static_cast(BLOCK_SIZE); + if (details::circular_less_than(constructedStopIndex, stopIndex)) { + stopIndex = constructedStopIndex; + } + while (currentTailIndex != stopIndex) { + (*block)[currentTailIndex++]->~T(); + } + if (block == lastBlockEnqueued) { + break; + } + block = block->next; + } + } + + currentTailIndex = (startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1); + for (auto block = firstAllocatedBlock; block != nullptr; block = block->next) { + currentTailIndex += static_cast(BLOCK_SIZE); + auto idxEntry = get_block_index_entry_for_index(currentTailIndex); + idxEntry->value.store(nullptr, std::memory_order_relaxed); + rewind_block_index_tail(); + } + this->parent->add_blocks_to_free_list(firstAllocatedBlock); + this->tailBlock = startBlock; + MOODYCAMEL_RETHROW; + } + } + + if (this->tailBlock == endBlock) { + assert(currentTailIndex == newTailIndex); + break; + } + this->tailBlock = this->tailBlock->next; + } + this->tailIndex.store(newTailIndex, std::memory_order_release); + return true; + } + + template + size_t dequeue_bulk(It &itemFirst, size_t max) { + auto tail = this->tailIndex.load(std::memory_order_relaxed); + auto overcommit = this->dequeueOvercommit.load(std::memory_order_relaxed); + auto desiredCount = static_cast(tail - (this->dequeueOptimisticCount.load( + std::memory_order_relaxed) - overcommit)); + if (details::circular_less_than(0, desiredCount)) { + desiredCount = desiredCount < max ? desiredCount : max; + std::atomic_thread_fence(std::memory_order_acquire); + + auto myDequeueCount = this->dequeueOptimisticCount.fetch_add(desiredCount, + std::memory_order_relaxed); + assert(overcommit <= myDequeueCount); + + tail = this->tailIndex.load(std::memory_order_acquire); + auto actualCount = static_cast(tail - (myDequeueCount - overcommit)); + if (details::circular_less_than(0, actualCount)) { + actualCount = desiredCount < actualCount ? desiredCount : actualCount; + if (actualCount < desiredCount) { + this->dequeueOvercommit.fetch_add(desiredCount - actualCount, + std::memory_order_release); + } + + // Get the first index. Note that since there's guaranteed to be at least actualCount elements, this + // will never exceed tail. + auto firstIndex = this->headIndex.fetch_add(actualCount, std::memory_order_acq_rel); + + // Iterate the blocks and dequeue + auto index = firstIndex; + BlockIndexHeader *localBlockIndex; + auto indexIndex = get_block_index_index_for_index(index, localBlockIndex); + do { + auto blockStartIndex = index; + auto endIndex = + (index & ~static_cast(BLOCK_SIZE - 1)) + static_cast(BLOCK_SIZE); + endIndex = details::circular_less_than( + firstIndex + static_cast(actualCount), endIndex) ? firstIndex + + static_cast(actualCount) + : endIndex; + + auto entry = localBlockIndex->index[indexIndex]; + auto block = entry->value.load(std::memory_order_relaxed); + if (MOODYCAMEL_NOEXCEPT_ASSIGN(T, T &&, details::deref_noexcept(itemFirst) = std::move( + (*(*block)[index])))) { + while (index != endIndex) { + auto &el = *((*block)[index]); + *itemFirst++ = std::move(el); + el.~T(); + ++index; + } + } else { + MOODYCAMEL_TRY { + while (index != endIndex) { + auto &el = *((*block)[index]); + *itemFirst = std::move(el); + ++itemFirst; + el.~T(); + ++index; + } + } + MOODYCAMEL_CATCH (...) { + do { + entry = localBlockIndex->index[indexIndex]; + block = entry->value.load(std::memory_order_relaxed); + while (index != endIndex) { + (*block)[index++]->~T(); + } + + if (block->ConcurrentQueue::Block::template set_many_empty( + blockStartIndex, static_cast(endIndex - blockStartIndex))) { +#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + debug::DebugLock lock(mutex); +#endif + entry->value.store(nullptr, std::memory_order_relaxed); + this->parent->add_block_to_free_list(block); + } + indexIndex = (indexIndex + 1) & (localBlockIndex->capacity - 1); + + blockStartIndex = index; + endIndex = (index & ~static_cast(BLOCK_SIZE - 1)) + + static_cast(BLOCK_SIZE); + endIndex = details::circular_less_than( + firstIndex + static_cast(actualCount), endIndex) ? firstIndex + + static_cast(actualCount) + : endIndex; + } while (index != firstIndex + actualCount); + + MOODYCAMEL_RETHROW; + } + } + if (block->ConcurrentQueue::Block::template set_many_empty( + blockStartIndex, static_cast(endIndex - blockStartIndex))) { + { +#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + debug::DebugLock lock(mutex); +#endif + // Note that the set_many_empty above did a release, meaning that anybody who acquires the block + // we're about to free can use it safely since our writes (and reads!) will have happened-before then. + entry->value.store(nullptr, std::memory_order_relaxed); + } + this->parent->add_block_to_free_list(block); // releases the above store + } + indexIndex = (indexIndex + 1) & (localBlockIndex->capacity - 1); + } while (index != firstIndex + actualCount); + + return actualCount; + } else { + this->dequeueOvercommit.fetch_add(desiredCount, std::memory_order_release); + } + } + + return 0; + } + + private: + // The block size must be > 1, so any number with the low bit set is an invalid block base index + static const index_t INVALID_BLOCK_BASE = 1; + + struct BlockIndexEntry { + std::atomic key; + std::atomic value; + }; + + struct BlockIndexHeader { + size_t capacity; + std::atomic tail; + BlockIndexEntry *entries; + BlockIndexEntry **index; + BlockIndexHeader *prev; + }; + + template + inline bool insert_block_index_entry(BlockIndexEntry *&idxEntry, index_t blockStartIndex) { + auto localBlockIndex = blockIndex.load( + std::memory_order_relaxed); // We're the only writer thread, relaxed is OK + if (localBlockIndex == nullptr) { + return false; // this can happen if new_block_index failed in the constructor + } + auto newTail = (localBlockIndex->tail.load(std::memory_order_relaxed) + 1) & + (localBlockIndex->capacity - 1); + idxEntry = localBlockIndex->index[newTail]; + if (idxEntry->key.load(std::memory_order_relaxed) == INVALID_BLOCK_BASE || + idxEntry->value.load(std::memory_order_relaxed) == nullptr) { + + idxEntry->key.store(blockStartIndex, std::memory_order_relaxed); + localBlockIndex->tail.store(newTail, std::memory_order_release); + return true; + } + + // No room in the old block index, try to allocate another one! + if (allocMode == CannotAlloc || !new_block_index()) { + return false; + } + localBlockIndex = blockIndex.load(std::memory_order_relaxed); + newTail = (localBlockIndex->tail.load(std::memory_order_relaxed) + 1) & + (localBlockIndex->capacity - 1); + idxEntry = localBlockIndex->index[newTail]; + assert(idxEntry->key.load(std::memory_order_relaxed) == INVALID_BLOCK_BASE); + idxEntry->key.store(blockStartIndex, std::memory_order_relaxed); + localBlockIndex->tail.store(newTail, std::memory_order_release); + return true; + } + + inline void rewind_block_index_tail() { + auto localBlockIndex = blockIndex.load(std::memory_order_relaxed); + localBlockIndex->tail.store((localBlockIndex->tail.load(std::memory_order_relaxed) - 1) & + (localBlockIndex->capacity - 1), std::memory_order_relaxed); + } + + inline BlockIndexEntry *get_block_index_entry_for_index(index_t index) const { + BlockIndexHeader *localBlockIndex; + auto idx = get_block_index_index_for_index(index, localBlockIndex); + return localBlockIndex->index[idx]; + } + + inline size_t + get_block_index_index_for_index(index_t index, BlockIndexHeader *&localBlockIndex) const { +#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + debug::DebugLock lock(mutex); +#endif + index &= ~static_cast(BLOCK_SIZE - 1); + localBlockIndex = blockIndex.load(std::memory_order_acquire); + auto tail = localBlockIndex->tail.load(std::memory_order_acquire); + auto tailBase = localBlockIndex->index[tail]->key.load(std::memory_order_relaxed); + assert(tailBase != INVALID_BLOCK_BASE); + // Note: Must use division instead of shift because the index may wrap around, causing a negative + // offset, whose negativity we want to preserve + auto offset = static_cast( + static_cast::type>(index - tailBase) / BLOCK_SIZE); + size_t idx = (tail + offset) & (localBlockIndex->capacity - 1); + assert(localBlockIndex->index[idx]->key.load(std::memory_order_relaxed) == index && + localBlockIndex->index[idx]->value.load(std::memory_order_relaxed) != nullptr); + return idx; + } + + bool new_block_index() { + auto prev = blockIndex.load(std::memory_order_relaxed); + size_t prevCapacity = prev == nullptr ? 0 : prev->capacity; + auto entryCount = prev == nullptr ? nextBlockIndexCapacity : prevCapacity; + auto raw = static_cast((Traits::malloc)( + sizeof(BlockIndexHeader) + + std::alignment_of::value - 1 + sizeof(BlockIndexEntry) * entryCount + + std::alignment_of::value - 1 + + sizeof(BlockIndexEntry * ) * nextBlockIndexCapacity)); + if (raw == nullptr) { + return false; + } + + auto header = new(raw) BlockIndexHeader; + auto entries = reinterpret_cast(details::align_for( + raw + sizeof(BlockIndexHeader))); + auto index = reinterpret_cast(details::align_for( + reinterpret_cast(entries) + sizeof(BlockIndexEntry) * entryCount)); + if (prev != nullptr) { + auto prevTail = prev->tail.load(std::memory_order_relaxed); + auto prevPos = prevTail; + size_t i = 0; + do { + prevPos = (prevPos + 1) & (prev->capacity - 1); + index[i++] = prev->index[prevPos]; + } while (prevPos != prevTail); + assert(i == prevCapacity); + } + for (size_t i = 0; i != entryCount; ++i) { + new(entries + i) BlockIndexEntry; + entries[i].key.store(INVALID_BLOCK_BASE, std::memory_order_relaxed); + index[prevCapacity + i] = entries + i; + } + header->prev = prev; + header->entries = entries; + header->index = index; + header->capacity = nextBlockIndexCapacity; + header->tail.store((prevCapacity - 1) & (nextBlockIndexCapacity - 1), + std::memory_order_relaxed); + + blockIndex.store(header, std::memory_order_release); + + nextBlockIndexCapacity <<= 1; + + return true; + } + + private: + size_t nextBlockIndexCapacity; + std::atomic blockIndex; + +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED + public: + details::ThreadExitListener threadExitListener; + private: +#endif + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + public: + ImplicitProducer* nextImplicitProducer; + private: +#endif + +#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX + mutable debug::DebugMutex mutex; +#endif +#if MCDBGQ_TRACKMEM + friend struct MemStats; +#endif + }; + + + ////////////////////////////////// + // Block pool manipulation + ////////////////////////////////// + + void populate_initial_block_list(size_t blockCount) { + initialBlockPoolSize = blockCount; + if (initialBlockPoolSize == 0) { + initialBlockPool = nullptr; + return; + } + + initialBlockPool = create_array(blockCount); + if (initialBlockPool == nullptr) { + initialBlockPoolSize = 0; + } + for (size_t i = 0; i < initialBlockPoolSize; ++i) { + initialBlockPool[i].dynamicallyAllocated = false; + } + } + + inline Block *try_get_block_from_initial_pool() { + if (initialBlockPoolIndex.load(std::memory_order_relaxed) >= initialBlockPoolSize) { + return nullptr; + } + + auto index = initialBlockPoolIndex.fetch_add(1, std::memory_order_relaxed); + + return index < initialBlockPoolSize ? (initialBlockPool + index) : nullptr; + } + + inline void add_block_to_free_list(Block *block) { +#if MCDBGQ_TRACKMEM + block->owner = nullptr; +#endif + freeList.add(block); + } + + inline void add_blocks_to_free_list(Block *block) { + while (block != nullptr) { + auto next = block->next; + add_block_to_free_list(block); + block = next; + } + } + + inline Block *try_get_block_from_free_list() { + return freeList.try_get(); + } + + // Gets a free block from one of the memory pools, or allocates a new one (if applicable) + template + Block *requisition_block() { + auto block = try_get_block_from_initial_pool(); + if (block != nullptr) { + return block; + } + + block = try_get_block_from_free_list(); + if (block != nullptr) { + return block; + } + + if (canAlloc == CanAlloc) { + return create(); + } + + return nullptr; + } + + +#if MCDBGQ_TRACKMEM + public: + struct MemStats { + size_t allocatedBlocks; + size_t usedBlocks; + size_t freeBlocks; + size_t ownedBlocksExplicit; + size_t ownedBlocksImplicit; + size_t implicitProducers; + size_t explicitProducers; + size_t elementsEnqueued; + size_t blockClassBytes; + size_t queueClassBytes; + size_t implicitBlockIndexBytes; + size_t explicitBlockIndexBytes; + + friend class ConcurrentQueue; + + private: + static MemStats getFor(ConcurrentQueue* q) + { + MemStats stats = { 0 }; + + stats.elementsEnqueued = q->size_approx(); + + auto block = q->freeList.head_unsafe(); + while (block != nullptr) { + ++stats.allocatedBlocks; + ++stats.freeBlocks; + block = block->freeListNext.load(std::memory_order_relaxed); + } + + for (auto ptr = q->producerListTail.load(std::memory_order_acquire); ptr != nullptr; ptr = ptr->next_prod()) { + bool implicit = dynamic_cast(ptr) != nullptr; + stats.implicitProducers += implicit ? 1 : 0; + stats.explicitProducers += implicit ? 0 : 1; + + if (implicit) { + auto prod = static_cast(ptr); + stats.queueClassBytes += sizeof(ImplicitProducer); + auto head = prod->headIndex.load(std::memory_order_relaxed); + auto tail = prod->tailIndex.load(std::memory_order_relaxed); + auto hash = prod->blockIndex.load(std::memory_order_relaxed); + if (hash != nullptr) { + for (size_t i = 0; i != hash->capacity; ++i) { + if (hash->index[i]->key.load(std::memory_order_relaxed) != ImplicitProducer::INVALID_BLOCK_BASE && hash->index[i]->value.load(std::memory_order_relaxed) != nullptr) { + ++stats.allocatedBlocks; + ++stats.ownedBlocksImplicit; + } + } + stats.implicitBlockIndexBytes += hash->capacity * sizeof(typename ImplicitProducer::BlockIndexEntry); + for (; hash != nullptr; hash = hash->prev) { + stats.implicitBlockIndexBytes += sizeof(typename ImplicitProducer::BlockIndexHeader) + hash->capacity * sizeof(typename ImplicitProducer::BlockIndexEntry*); + } + } + for (; details::circular_less_than(head, tail); head += BLOCK_SIZE) { + //auto block = prod->get_block_index_entry_for_index(head); + ++stats.usedBlocks; + } + } + else { + auto prod = static_cast(ptr); + stats.queueClassBytes += sizeof(ExplicitProducer); + auto tailBlock = prod->tailBlock; + bool wasNonEmpty = false; + if (tailBlock != nullptr) { + auto block = tailBlock; + do { + ++stats.allocatedBlocks; + if (!block->ConcurrentQueue::Block::template is_empty() || wasNonEmpty) { + ++stats.usedBlocks; + wasNonEmpty = wasNonEmpty || block != tailBlock; + } + ++stats.ownedBlocksExplicit; + block = block->next; + } while (block != tailBlock); + } + auto index = prod->blockIndex.load(std::memory_order_relaxed); + while (index != nullptr) { + stats.explicitBlockIndexBytes += sizeof(typename ExplicitProducer::BlockIndexHeader) + index->size * sizeof(typename ExplicitProducer::BlockIndexEntry); + index = static_cast(index->prev); + } + } + } + + auto freeOnInitialPool = q->initialBlockPoolIndex.load(std::memory_order_relaxed) >= q->initialBlockPoolSize ? 0 : q->initialBlockPoolSize - q->initialBlockPoolIndex.load(std::memory_order_relaxed); + stats.allocatedBlocks += freeOnInitialPool; + stats.freeBlocks += freeOnInitialPool; + + stats.blockClassBytes = sizeof(Block) * stats.allocatedBlocks; + stats.queueClassBytes += sizeof(ConcurrentQueue); + + return stats; + } + }; + + // For debugging only. Not thread-safe. + MemStats getMemStats() + { + return MemStats::getFor(this); + } + private: + friend struct MemStats; +#endif + + + ////////////////////////////////// + // Producer list manipulation + ////////////////////////////////// + + ProducerBase *recycle_or_create_producer(bool isExplicit) { + bool recycled; + return recycle_or_create_producer(isExplicit, recycled); + } + + ProducerBase *recycle_or_create_producer(bool isExplicit, bool &recycled) { +#if MCDBGQ_NOLOCKFREE_IMPLICITPRODHASH + debug::DebugLock lock(implicitProdMutex); +#endif + // Try to re-use one first + for (auto ptr = producerListTail.load(std::memory_order_acquire); + ptr != nullptr; ptr = ptr->next_prod()) { + if (ptr->inactive.load(std::memory_order_relaxed) && ptr->isExplicit == isExplicit) { + bool expected = true; + if (ptr->inactive.compare_exchange_strong(expected, /* desired */ false, + std::memory_order_acquire, + std::memory_order_relaxed)) { + // We caught one! It's been marked as activated, the caller can have it + recycled = true; + return ptr; + } + } + } + + recycled = false; + return add_producer(isExplicit ? static_cast(create(this)) + : create(this)); + } + + ProducerBase *add_producer(ProducerBase *producer) { + // Handle failed memory allocation + if (producer == nullptr) { + return nullptr; + } + + producerCount.fetch_add(1, std::memory_order_relaxed); + + // Add it to the lock-free list + auto prevTail = producerListTail.load(std::memory_order_relaxed); + do { + producer->next = prevTail; + } while (!producerListTail.compare_exchange_weak(prevTail, producer, std::memory_order_release, + std::memory_order_relaxed)); + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + if (producer->isExplicit) { + auto prevTailExplicit = explicitProducers.load(std::memory_order_relaxed); + do { + static_cast(producer)->nextExplicitProducer = prevTailExplicit; + } while (!explicitProducers.compare_exchange_weak(prevTailExplicit, static_cast(producer), std::memory_order_release, std::memory_order_relaxed)); + } + else { + auto prevTailImplicit = implicitProducers.load(std::memory_order_relaxed); + do { + static_cast(producer)->nextImplicitProducer = prevTailImplicit; + } while (!implicitProducers.compare_exchange_weak(prevTailImplicit, static_cast(producer), std::memory_order_release, std::memory_order_relaxed)); + } +#endif + + return producer; + } + + void reown_producers() { + // After another instance is moved-into/swapped-with this one, all the + // producers we stole still think their parents are the other queue. + // So fix them up! + for (auto ptr = producerListTail.load(std::memory_order_relaxed); + ptr != nullptr; ptr = ptr->next_prod()) { + ptr->parent = this; + } + } + + + ////////////////////////////////// + // Implicit producer hash + ////////////////////////////////// + + struct ImplicitProducerKVP { + std::atomic key; + ImplicitProducer *value; // No need for atomicity since it's only read by the thread that sets it in the first place + + ImplicitProducerKVP() + : value(nullptr) {} + + ImplicitProducerKVP(ImplicitProducerKVP &&other) MOODYCAMEL_NOEXCEPT { + key.store(other.key.load(std::memory_order_relaxed), std::memory_order_relaxed); + value = other.value; + } + + inline ImplicitProducerKVP &operator=(ImplicitProducerKVP &&other) MOODYCAMEL_NOEXCEPT { + swap(other); + return *this; + } + + inline void swap(ImplicitProducerKVP &other) MOODYCAMEL_NOEXCEPT { + if (this != &other) { + details::swap_relaxed(key, other.key); + std::swap(value, other.value); + } + } + }; + + template + friend void moodycamel::swap(typename ConcurrentQueue::ImplicitProducerKVP &, + typename ConcurrentQueue::ImplicitProducerKVP &) MOODYCAMEL_NOEXCEPT; + + struct ImplicitProducerHash { + size_t capacity; + ImplicitProducerKVP *entries; + ImplicitProducerHash *prev; + }; + + inline void populate_initial_implicit_producer_hash() { + if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return; + + implicitProducerHashCount.store(0, std::memory_order_relaxed); + auto hash = &initialImplicitProducerHash; + hash->capacity = INITIAL_IMPLICIT_PRODUCER_HASH_SIZE; + hash->entries = &initialImplicitProducerHashEntries[0]; + for (size_t i = 0; i != INITIAL_IMPLICIT_PRODUCER_HASH_SIZE; ++i) { + initialImplicitProducerHashEntries[i].key.store(details::invalid_thread_id, + std::memory_order_relaxed); + } + hash->prev = nullptr; + implicitProducerHash.store(hash, std::memory_order_relaxed); + } + + void swap_implicit_producer_hashes(ConcurrentQueue &other) { + if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return; + + // Swap (assumes our implicit producer hash is initialized) + initialImplicitProducerHashEntries.swap(other.initialImplicitProducerHashEntries); + initialImplicitProducerHash.entries = &initialImplicitProducerHashEntries[0]; + other.initialImplicitProducerHash.entries = &other.initialImplicitProducerHashEntries[0]; + + details::swap_relaxed(implicitProducerHashCount, other.implicitProducerHashCount); + + details::swap_relaxed(implicitProducerHash, other.implicitProducerHash); + if (implicitProducerHash.load(std::memory_order_relaxed) == + &other.initialImplicitProducerHash) { + implicitProducerHash.store(&initialImplicitProducerHash, std::memory_order_relaxed); + } else { + ImplicitProducerHash *hash; + for (hash = implicitProducerHash.load(std::memory_order_relaxed); + hash->prev != &other.initialImplicitProducerHash; hash = hash->prev) { + continue; + } + hash->prev = &initialImplicitProducerHash; + } + if (other.implicitProducerHash.load(std::memory_order_relaxed) == + &initialImplicitProducerHash) { + other.implicitProducerHash.store(&other.initialImplicitProducerHash, + std::memory_order_relaxed); + } else { + ImplicitProducerHash *hash; + for (hash = other.implicitProducerHash.load(std::memory_order_relaxed); + hash->prev != &initialImplicitProducerHash; hash = hash->prev) { + continue; + } + hash->prev = &other.initialImplicitProducerHash; + } + } + + // Only fails (returns nullptr) if memory allocation fails + ImplicitProducer *get_or_add_implicit_producer() { + // Note that since the data is essentially thread-local (key is thread ID), + // there's a reduced need for fences (memory ordering is already consistent + // for any individual thread), except for the current table itself. + + // Start by looking for the thread ID in the current and all previous hash tables. + // If it's not found, it must not be in there yet, since this same thread would + // have added it previously to one of the tables that we traversed. + + // Code and algorithm adapted from http://preshing.com/20130605/the-worlds-simplest-lock-free-hash-table + +#if MCDBGQ_NOLOCKFREE_IMPLICITPRODHASH + debug::DebugLock lock(implicitProdMutex); +#endif + + auto id = details::thread_id(); + auto hashedId = details::hash_thread_id(id); + + auto mainHash = implicitProducerHash.load(std::memory_order_acquire); + for (auto hash = mainHash; hash != nullptr; hash = hash->prev) { + // Look for the id in this hash + auto index = hashedId; + while (true) { // Not an infinite loop because at least one slot is free in the hash table + index &= hash->capacity - 1; + + auto probedKey = hash->entries[index].key.load(std::memory_order_relaxed); + if (probedKey == id) { + // Found it! If we had to search several hashes deep, though, we should lazily add it + // to the current main hash table to avoid the extended search next time. + // Note there's guaranteed to be room in the current hash table since every subsequent + // table implicitly reserves space for all previous tables (there's only one + // implicitProducerHashCount). + auto value = hash->entries[index].value; + if (hash != mainHash) { + index = hashedId; + while (true) { + index &= mainHash->capacity - 1; + probedKey = mainHash->entries[index].key.load(std::memory_order_relaxed); + auto empty = details::invalid_thread_id; +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED + auto reusable = details::invalid_thread_id2; + if ((probedKey == empty && mainHash->entries[index].key.compare_exchange_strong(empty, id, std::memory_order_relaxed, std::memory_order_relaxed)) || + (probedKey == reusable && mainHash->entries[index].key.compare_exchange_strong(reusable, id, std::memory_order_acquire, std::memory_order_acquire))) { +#else + if ((probedKey == empty && + mainHash->entries[index].key.compare_exchange_strong(empty, id, + std::memory_order_relaxed, + std::memory_order_relaxed))) { +#endif + mainHash->entries[index].value = value; + break; + } + ++index; + } + } + + return value; + } + if (probedKey == details::invalid_thread_id) { + break; // Not in this hash table + } + ++index; + } + } + + // Insert! + auto newCount = 1 + implicitProducerHashCount.fetch_add(1, std::memory_order_relaxed); + while (true) { + if (newCount >= (mainHash->capacity >> 1) && + !implicitProducerHashResizeInProgress.test_and_set(std::memory_order_acquire)) { + // We've acquired the resize lock, try to allocate a bigger hash table. + // Note the acquire fence synchronizes with the release fence at the end of this block, and hence when + // we reload implicitProducerHash it must be the most recent version (it only gets changed within this + // locked block). + mainHash = implicitProducerHash.load(std::memory_order_acquire); + if (newCount >= (mainHash->capacity >> 1)) { + auto newCapacity = mainHash->capacity << 1; + while (newCount >= (newCapacity >> 1)) { + newCapacity <<= 1; + } + auto raw = static_cast((Traits::malloc)( + sizeof(ImplicitProducerHash) + std::alignment_of::value - 1 + + sizeof(ImplicitProducerKVP) * newCapacity)); + if (raw == nullptr) { + // Allocation failed + implicitProducerHashCount.fetch_sub(1, std::memory_order_relaxed); + implicitProducerHashResizeInProgress.clear(std::memory_order_relaxed); + return nullptr; + } + + auto newHash = new(raw) ImplicitProducerHash; + newHash->capacity = newCapacity; + newHash->entries = reinterpret_cast(details::align_for( + raw + sizeof(ImplicitProducerHash))); + for (size_t i = 0; i != newCapacity; ++i) { + new(newHash->entries + i) ImplicitProducerKVP; + newHash->entries[i].key.store(details::invalid_thread_id, std::memory_order_relaxed); + } + newHash->prev = mainHash; + implicitProducerHash.store(newHash, std::memory_order_release); + implicitProducerHashResizeInProgress.clear(std::memory_order_release); + mainHash = newHash; + } else { + implicitProducerHashResizeInProgress.clear(std::memory_order_release); + } + } + + // If it's < three-quarters full, add to the old one anyway so that we don't have to wait for the next table + // to finish being allocated by another thread (and if we just finished allocating above, the condition will + // always be true) + if (newCount < (mainHash->capacity >> 1) + (mainHash->capacity >> 2)) { + bool recycled; + auto producer = static_cast(recycle_or_create_producer(false, + recycled)); + if (producer == nullptr) { + implicitProducerHashCount.fetch_sub(1, std::memory_order_relaxed); + return nullptr; + } + if (recycled) { + implicitProducerHashCount.fetch_sub(1, std::memory_order_relaxed); + } + +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED + producer->threadExitListener.callback = &ConcurrentQueue::implicit_producer_thread_exited_callback; + producer->threadExitListener.userData = producer; + details::ThreadExitNotifier::subscribe(&producer->threadExitListener); +#endif + + auto index = hashedId; + while (true) { + index &= mainHash->capacity - 1; + auto probedKey = mainHash->entries[index].key.load(std::memory_order_relaxed); + + auto empty = details::invalid_thread_id; +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED + auto reusable = details::invalid_thread_id2; + if ((probedKey == empty && mainHash->entries[index].key.compare_exchange_strong(empty, id, std::memory_order_relaxed, std::memory_order_relaxed)) || + (probedKey == reusable && mainHash->entries[index].key.compare_exchange_strong(reusable, id, std::memory_order_acquire, std::memory_order_acquire))) { +#else + if ((probedKey == empty && mainHash->entries[index].key.compare_exchange_strong(empty, id, + std::memory_order_relaxed, + std::memory_order_relaxed))) { +#endif + mainHash->entries[index].value = producer; + break; + } + ++index; + } + return producer; + } + + // Hmm, the old hash is quite full and somebody else is busy allocating a new one. + // We need to wait for the allocating thread to finish (if it succeeds, we add, if not, + // we try to allocate ourselves). + mainHash = implicitProducerHash.load(std::memory_order_acquire); + } + } + +#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED + void implicit_producer_thread_exited(ImplicitProducer* producer) + { + // Remove from thread exit listeners + details::ThreadExitNotifier::unsubscribe(&producer->threadExitListener); + + // Remove from hash +#if MCDBGQ_NOLOCKFREE_IMPLICITPRODHASH + debug::DebugLock lock(implicitProdMutex); +#endif + auto hash = implicitProducerHash.load(std::memory_order_acquire); + assert(hash != nullptr); // The thread exit listener is only registered if we were added to a hash in the first place + auto id = details::thread_id(); + auto hashedId = details::hash_thread_id(id); + details::thread_id_t probedKey; + + // We need to traverse all the hashes just in case other threads aren't on the current one yet and are + // trying to add an entry thinking there's a free slot (because they reused a producer) + for (; hash != nullptr; hash = hash->prev) { + auto index = hashedId; + do { + index &= hash->capacity - 1; + probedKey = hash->entries[index].key.load(std::memory_order_relaxed); + if (probedKey == id) { + hash->entries[index].key.store(details::invalid_thread_id2, std::memory_order_release); + break; + } + ++index; + } while (probedKey != details::invalid_thread_id); // Can happen if the hash has changed but we weren't put back in it yet, or if we weren't added to this hash in the first place + } + + // Mark the queue as being recyclable + producer->inactive.store(true, std::memory_order_release); + } + + static void implicit_producer_thread_exited_callback(void* userData) + { + auto producer = static_cast(userData); + auto queue = producer->parent; + queue->implicit_producer_thread_exited(producer); + } +#endif + + ////////////////////////////////// + // Utility functions + ////////////////////////////////// + + template + static inline U *create_array(size_t count) { + assert(count > 0); + auto p = static_cast((Traits::malloc)(sizeof(U) * count)); + if (p == nullptr) { + return nullptr; + } + + for (size_t i = 0; i != count; ++i) { + new(p + i) U(); + } + return p; + } + + template + static inline void destroy_array(U *p, size_t count) { + if (p != nullptr) { + assert(count > 0); + for (size_t i = count; i != 0;) { + (p + --i)->~U(); + } + (Traits::free)(p); + } + } + + template + static inline U *create() { + auto p = (Traits::malloc)(sizeof(U)); + return p != nullptr ? new(p) U : nullptr; + } + + template + static inline U *create(A1 &&a1) { + auto p = (Traits::malloc)(sizeof(U)); + return p != nullptr ? new(p) U(std::forward(a1)) : nullptr; + } + + template + static inline void destroy(U *p) { + if (p != nullptr) { + p->~U(); + } + (Traits::free)(p); + } + + private: + std::atomic producerListTail; + std::atomic producerCount; + + std::atomic initialBlockPoolIndex; + Block *initialBlockPool; + size_t initialBlockPoolSize; + +#if !MCDBGQ_USEDEBUGFREELIST + FreeList freeList; +#else + debug::DebugFreeList freeList; +#endif + + std::atomic implicitProducerHash; + std::atomic implicitProducerHashCount; // Number of slots logically used + ImplicitProducerHash initialImplicitProducerHash; + std::array initialImplicitProducerHashEntries; + std::atomic_flag implicitProducerHashResizeInProgress; + + std::atomic nextExplicitConsumerId; + std::atomic globalExplicitConsumerOffset; + +#if MCDBGQ_NOLOCKFREE_IMPLICITPRODHASH + debug::DebugMutex implicitProdMutex; +#endif + +#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG + std::atomic explicitProducers; + std::atomic implicitProducers; +#endif +}; + + +template +ProducerToken::ProducerToken(ConcurrentQueue &queue) + : producer(queue.recycle_or_create_producer(true)) { + if (producer != nullptr) { + producer->token = this; + } +} + +template +ProducerToken::ProducerToken(BlockingConcurrentQueue &queue) + : producer( + reinterpret_cast *>(&queue)->recycle_or_create_producer(true)) { + if (producer != nullptr) { + producer->token = this; + } +} + +template +ConsumerToken::ConsumerToken(ConcurrentQueue &queue) + : itemsConsumedFromCurrent(0), currentProducer(nullptr), desiredProducer(nullptr) { + initialOffset = queue.nextExplicitConsumerId.fetch_add(1, std::memory_order_release); + lastKnownGlobalOffset = -1; +} + +template +ConsumerToken::ConsumerToken(BlockingConcurrentQueue &queue) + : itemsConsumedFromCurrent(0), currentProducer(nullptr), desiredProducer(nullptr) { + initialOffset = reinterpret_cast *>(&queue)->nextExplicitConsumerId.fetch_add( + 1, std::memory_order_release); + lastKnownGlobalOffset = -1; +} + +template +inline void swap(ConcurrentQueue &a, ConcurrentQueue &b) MOODYCAMEL_NOEXCEPT { + a.swap(b); +} + +inline void swap(ProducerToken &a, ProducerToken &b) MOODYCAMEL_NOEXCEPT { + a.swap(b); +} + +inline void swap(ConsumerToken &a, ConsumerToken &b) MOODYCAMEL_NOEXCEPT { + a.swap(b); +} + +template +inline void swap(typename ConcurrentQueue::ImplicitProducerKVP &a, + typename ConcurrentQueue::ImplicitProducerKVP &b) MOODYCAMEL_NOEXCEPT { + a.swap(b); +} + +} + +} // namespace dmlc + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +#endif // DMLC_CONCURRENTQUEUE_H_ +//! \endcond Doxygen_Suppress diff --git a/include/dmlc/config.h b/include/dmlc/config.h new file mode 100644 index 000000000000..a4c5b53d827d --- /dev/null +++ b/include/dmlc/config.h @@ -0,0 +1,186 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file config.h + * \brief defines config parser class + */ +#ifndef DMLC_CONFIG_H_ +#define DMLC_CONFIG_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +/*! \brief namespace for dmlc */ +namespace dmlc { + +/*! + * \brief class for config parser + * + * Two modes are supported: + * 1. non-multi value mode: if two same keys in the configure file, the later one will replace the + * ealier one; when using iterator, the order will be the "last effective insersion" order + * 2. multi value mode: multiple values with the same key could co-exist; when using iterator, the + * order will be the insersion order. + * + * [Basic usage] + * + * Config cfg(file_input_stream); + * for(Config::ConfigIterator iter = cfg.begin(); iter != cfg.end(); ++iter) { + * ConfigEntry ent = *iter; + * std::string key = ent.first; + * std::string value = ent.second; + * do_something_with(key, value); + * } + */ +class Config { + public: + /*! + * \brief type when extracting from iterator + */ + typedef std::pair ConfigEntry; + + /*! + * \brief iterator class + */ + class ConfigIterator; + + /*! + * \brief create empty config + * \param multi_value whether the config supports multi value + */ + explicit Config(bool multi_value = false); + /*! + * \brief create config and load content from the given stream + * \param is input stream + * \param multi_value whether the config supports multi value + */ + explicit Config(std::istream& is, bool multi_value = false); // NOLINT(*) + /*! + * \brief clear all the values + */ + void Clear(void); + /*! + * \brief load the contents from the stream + * \param is the stream as input + */ + void LoadFromStream(std::istream& is); // NOLINT(*) + /*! + * \brief set a key-value pair into the config; if the key already exists in the configure file, + * it will either replace the old value with the given one (in non-multi value mode) or + * store it directly (in multi-value mode); + * \param key key + * \param value value + * \param is_string whether the value should be wrapped by quotes in proto string + */ + template + void SetParam(const std::string& key, const T& value, bool is_string = false); + + /*! + * \brief get the config under the key; if multiple values exist for the same key, + * return the last inserted one. + * \param key key + * \return config value + */ + const std::string& GetParam(const std::string& key) const; + + /*! + * \brief check whether the configure value given by the key should be wrapped by quotes + * \param key key + * \return whether the configure value is represented by string + */ + bool IsGenuineString(const std::string& key) const; + + /*! + * \brief transform all the configuration into string recognizable to protobuf + * \return string that could be parsed directly by protobuf + */ + std::string ToProtoString(void) const; + + /*! + * \brief get begin iterator + * \return begin iterator + */ + ConfigIterator begin() const; + + /*! + * \brief get end iterator + * \return end iterator + */ + ConfigIterator end() const; + + public: + /*! + * \brief iterator class + */ + class ConfigIterator : public std::iterator< std::input_iterator_tag, ConfigEntry > { + friend class Config; + public: + /*! + * \brief copy constructor + */ + ConfigIterator(const ConfigIterator& other); + /*! + * \brief uni-increment operators + * \return the reference of current config + */ + ConfigIterator& operator++(); + /*! + * \brief uni-increment operators + * \return the reference of current config + */ + ConfigIterator operator++(int); // NOLINT(*) + /*! + * \brief compare operators + * \param rhs the other config to compare against + * \return the compared result + */ + bool operator == (const ConfigIterator& rhs) const; + /*! + * \brief compare operators not equal + * \param rhs the other config to compare against + * \return the compared result + */ + bool operator != (const ConfigIterator& rhs) const; + /*! + * \brief retrieve value from operator + */ + ConfigEntry operator * () const; + + private: + ConfigIterator(size_t index, const Config* config); + void FindNextIndex(); + + private: + size_t index_; + const Config* config_; + }; + + private: + struct ConfigValue { + std::vector val; + std::vector insert_index; + bool is_string; + }; + void Insert(const std::string& key, const std::string& value, bool is_string); + + private: + std::map config_map_; + std::vector > order_; + const bool multi_value_; +}; + +template +void Config::SetParam(const std::string& key, const T& value, bool is_string) { + std::ostringstream oss; + oss << value; + Insert(key, oss.str(), is_string); +} + +} // namespace dmlc + +#endif // DMLC_CONFIG_H_ diff --git a/include/dmlc/data.h b/include/dmlc/data.h new file mode 100644 index 000000000000..16e0667322fb --- /dev/null +++ b/include/dmlc/data.h @@ -0,0 +1,397 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file data.h + * \brief defines common input data structure, + * and interface for handling the input data + */ +#ifndef DMLC_DATA_H_ +#define DMLC_DATA_H_ + +#include +#include +#include +#include "./base.h" +#include "./io.h" +#include "./logging.h" +#include "./registry.h" + +// To help C Preprocessor with processing c++ templated types +#define __DMLC_COMMA , + +namespace dmlc { +/*! + * \brief this defines the float point + * that will be used to store feature values + */ +typedef float real_t; + +/*! + * \brief this defines the unsigned integer type + * that can normally be used to store feature index + */ +typedef unsigned index_t; + +// This file describes common data structure that can be used +// for large-scale machine learning, this may not be a complete list +// But we will keep the most common and useful ones, and keep adding new ones +/*! + * \brief data iterator interface + * this is not a C++ style iterator, but nice for data pulling:) + * This interface is used to pull in the data + * The system can do some useful tricks for you like pre-fetching + * from disk and pre-computation. + * + * Usage example: + * \code + * + * itr->BeforeFirst(); + * while (itr->Next()) { + * const DType &batch = itr->Value(); + * // some computations + * } + * \endcode + * \tparam DType the data type + */ +template +class DataIter { + public: + /*! \brief destructor */ + virtual ~DataIter(void) {} + /*! \brief set before first of the item */ + virtual void BeforeFirst(void) = 0; + /*! \brief move to next item */ + virtual bool Next(void) = 0; + /*! \brief get current data */ + virtual const DType &Value(void) const = 0; +}; + +/*! + * \brief one row of training instance + * \tparam IndexType type of index + * \tparam DType type of data (both label and value will be of DType + */ +template +class Row { + public: + /*! \brief label of the instance */ + const DType *label; + /*! \brief weight of the instance */ + const real_t *weight; + /*! \brief session-id of the instance */ + const uint64_t *qid; + /*! \brief length of the sparse vector */ + size_t length; + /*! + * \brief field of each instance + */ + const IndexType *field; + /*! + * \brief index of each instance + */ + const IndexType *index; + /*! + * \brief array value of each instance, this can be NULL + * indicating every value is set to be 1 + */ + const DType *value; + /*! + * \param i the input index + * \return field for i-th feature + */ + inline IndexType get_field(size_t i) const { + return field[i]; + } + /*! + * \param i the input index + * \return i-th feature + */ + inline IndexType get_index(size_t i) const { + return index[i]; + } + /*! + * \param i the input index + * \return i-th feature value, this function is always + * safe even when value == NULL + */ + inline DType get_value(size_t i) const { + return value == NULL ? DType(1.0f) : value[i]; + } + /*! + * \return the label of the instance + */ + inline DType get_label() const { + return *label; + } + /*! + * \return the weight of the instance, this function is always + * safe even when weight == NULL + */ + inline real_t get_weight() const { + return weight == NULL ? 1.0f : *weight; + } + /*! + * \return the qid of the instance, this function is always + * safe even when qid == NULL + */ + inline uint64_t get_qid() const { + return qid == NULL ? 0 : *qid; + } + /*! + * \brief helper function to compute dot product of current + * \param weight the dense array of weight we want to product + * \param size the size of the weight vector + * \tparam V type of the weight vector + * \return the result of dot product + */ + template + inline V SDot(const V *weight, size_t size) const { + V sum = static_cast(0); + if (value == NULL) { + for (size_t i = 0; i < length; ++i) { + CHECK(index[i] < size) << "feature index exceed bound"; + sum += weight[index[i]]; + } + } else { + for (size_t i = 0; i < length; ++i) { + CHECK(index[i] < size) << "feature index exceed bound"; + sum += weight[index[i]] * value[i]; + } + } + return sum; + } +}; + +/*! + * \brief a block of data, containing several rows in sparse matrix + * This is useful for (streaming-sxtyle) algorithms that scans through rows of data + * examples include: SGD, GD, L-BFGS, kmeans + * + * The size of batch is usually large enough so that parallelizing over the rows + * can give significant speedup + * \tparam IndexType type to store the index used in row batch + * \tparam DType type to store the label and value used in row batch + */ +template +struct RowBlock { + /*! \brief batch size */ + size_t size; + /*! \brief array[size+1], row pointer to beginning of each rows */ + const size_t *offset; + /*! \brief array[size] label of each instance */ + const DType *label; + /*! \brief With weight: array[size] label of each instance, otherwise nullptr */ + const real_t *weight; + /*! \brief With qid: array[size] session id of each instance, otherwise nullptr */ + const uint64_t *qid; + /*! \brief field id*/ + const IndexType *field; + /*! \brief feature index */ + const IndexType *index; + /*! \brief feature value, can be NULL, indicating all values are 1 */ + const DType *value; + /*! + * \brief get specific rows in the batch + * \param rowid the rowid in that row + * \return the instance corresponding to the row + */ + inline Row operator[](size_t rowid) const; + /*! \return memory cost of the block in bytes */ + inline size_t MemCostBytes(void) const { + size_t cost = size * (sizeof(size_t) + sizeof(DType)); + if (weight != NULL) cost += size * sizeof(real_t); + if (qid != NULL) cost += size * sizeof(size_t); + size_t ndata = offset[size] - offset[0]; + if (field != NULL) cost += ndata * sizeof(IndexType); + if (index != NULL) cost += ndata * sizeof(IndexType); + if (value != NULL) cost += ndata * sizeof(DType); + return cost; + } + /*! + * \brief slice a RowBlock to get rows in [begin, end) + * \param begin the begin row index + * \param end the end row index + * \return the sliced RowBlock + */ + inline RowBlock Slice(size_t begin, size_t end) const { + CHECK(begin <= end && end <= size); + RowBlock ret; + ret.size = end - begin; + ret.label = label + begin; + if (weight != NULL) { + ret.weight = weight + begin; + } else { + ret.weight = NULL; + } + if (qid != NULL) { + ret.qid = qid + begin; + } else { + ret.qid = NULL; + } + ret.offset = offset + begin; + ret.field = field; + ret.index = index; + ret.value = value; + return ret; + } +}; + +/*! + * \brief Data structure that holds the data + * Row block iterator interface that gets RowBlocks + * Difference between RowBlockIter and Parser: + * RowBlockIter caches the data internally that can be used + * to iterate the dataset multiple times, + * Parser holds very limited internal state and was usually + * used to read data only once + * + * \sa Parser + * \tparam IndexType type of index in RowBlock + * \tparam DType type of label and value in RowBlock + * Create function was only implemented for IndexType uint64_t and uint32_t + * and DType real_t and int + */ +template +class RowBlockIter : public DataIter > { + public: + /*! + * \brief create a new instance of iterator that returns rowbatch + * by default, a in-memory based iterator will be returned + * + * \param uri the uri of the input, can contain hdfs prefix + * \param part_index the part id of current input + * \param num_parts total number of splits + * \param type type of dataset can be: "libsvm", ... + * + * \return the created data iterator + */ + static RowBlockIter * + Create(const char *uri, + unsigned part_index, + unsigned num_parts, + const char *type); + /*! \return maximum feature dimension in the dataset */ + virtual size_t NumCol() const = 0; +}; + +/*! + * \brief parser interface that parses input data + * used to load dmlc data format into your own data format + * Difference between RowBlockIter and Parser: + * RowBlockIter caches the data internally that can be used + * to iterate the dataset multiple times, + * Parser holds very limited internal state and was usually + * used to read data only once + * + * + * \sa RowBlockIter + * \tparam IndexType type of index in RowBlock + * \tparam DType type of label and value in RowBlock + * Create function was only implemented for IndexType uint64_t and uint32_t + * and DType real_t and int + */ +template +class Parser : public DataIter > { + public: + /*! + * \brief create a new instance of parser based on the "type" + * + * \param uri_ the uri of the input, can contain hdfs prefix + * \param part_index the part id of current input + * \param num_parts total number of splits + * \param type type of dataset can be: "libsvm", "auto", ... + * + * When "auto" is passed, the type is decided by format argument string in URI. + * + * \return the created parser + */ + static Parser * + Create(const char *uri_, + unsigned part_index, + unsigned num_parts, + const char *type); + /*! \return size of bytes read so far */ + virtual size_t BytesRead(void) const = 0; + /*! \brief Factory type of the parser*/ + typedef Parser* (*Factory) + (const std::string& path, + const std::map& args, + unsigned part_index, + unsigned num_parts); +}; + +/*! + * \brief registry entry of parser factory + * \tparam IndexType The type of index + * \tparam DType The type of label and value + */ +template +struct ParserFactoryReg + : public FunctionRegEntryBase, + typename Parser::Factory> {}; + +/*! + * \brief Register a new distributed parser to dmlc-core. + * + * \param IndexType The type of Batch index, can be uint32_t or uint64_t + * \param DataType The type of Batch label and value, can be real_t or int + * \param TypeName The typename of of the data. + * \param FactoryFunction The factory function that creates the parser. + * + * \begincode + * + * // define the factory function + * template + * Parser* + * CreateLibSVMParser(const char* uri, unsigned part_index, unsigned num_parts) { + * return new LibSVMParser(uri, part_index, num_parts); + * } + * + * // Register it to DMLC + * // Then we can use Parser::Create(uri, part_index, num_parts, "libsvm"); + * // to create the parser + * + * DMLC_REGISTER_DATA_PARSER(uint32_t, real_t, libsvm, CreateLibSVMParser); + * DMLC_REGISTER_DATA_PARSER(uint64_t, real_t, libsvm, CreateLibSVMParser); + * + * \endcode + */ +#define DMLC_REGISTER_DATA_PARSER(IndexType, DataType, TypeName, FactoryFunction) \ + DMLC_REGISTRY_REGISTER(ParserFactoryReg, \ + ParserFactoryReg ## _ ## IndexType ## _ ## DataType, TypeName) \ + .set_body(FactoryFunction) + + +// implementation of operator[] +template +inline Row +RowBlock::operator[](size_t rowid) const { + CHECK(rowid < size); + Row inst; + inst.label = label + rowid; + if (weight != NULL) { + inst.weight = weight + rowid; + } else { + inst.weight = NULL; + } + if (qid != NULL) { + inst.qid = qid + rowid; + } else { + inst.qid = NULL; + } + inst.length = offset[rowid + 1] - offset[rowid]; + if (field != NULL) { + inst.field = field + offset[rowid]; + } else { + inst.field = NULL; + } + inst.index = index + offset[rowid]; + if (value == NULL) { + inst.value = NULL; + } else { + inst.value = value + offset[rowid]; + } + return inst; +} + +} // namespace dmlc +#endif // DMLC_DATA_H_ diff --git a/include/dmlc/endian.h b/include/dmlc/endian.h new file mode 100644 index 000000000000..e7deeaa49034 --- /dev/null +++ b/include/dmlc/endian.h @@ -0,0 +1,44 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file endian.h + * \brief Endian testing, need c++11 + */ +#ifndef DMLC_ENDIAN_H_ +#define DMLC_ENDIAN_H_ + +#include "./base.h" + +#if defined(__APPLE__) || defined(_WIN32) +#define DMLC_LITTLE_ENDIAN 1 +#else +#include +#define DMLC_LITTLE_ENDIAN (__BYTE_ORDER == __LITTLE_ENDIAN) +#endif + +/*! \brief whether serialize using little endian */ +#define DMLC_IO_NO_ENDIAN_SWAP (DMLC_LITTLE_ENDIAN == DMLC_IO_USE_LITTLE_ENDIAN) + +namespace dmlc { + +/*! + * \brief A generic inplace byte swapping function. + * \param data The data pointer. + * \param elem_bytes The number of bytes of the data elements + * \param num_elems Number of elements in the data. + * \note Always try pass in constant elem_bytes to enable + * compiler optimization + */ +inline void ByteSwap(void* data, size_t elem_bytes, size_t num_elems) { + for (size_t i = 0; i < num_elems; ++i) { + uint8_t* bptr = reinterpret_cast(data) + elem_bytes * i; + for (size_t j = 0; j < elem_bytes / 2; ++j) { + uint8_t v = bptr[elem_bytes - 1 - j]; + bptr[elem_bytes - 1 - j] = bptr[j]; + bptr[j] = v; + } + } +} + +} // namespace dmlc +#endif // DMLC_ENDIAN_H_ + diff --git a/include/dmlc/input_split_shuffle.h b/include/dmlc/input_split_shuffle.h new file mode 100644 index 000000000000..fc2c65e0a91e --- /dev/null +++ b/include/dmlc/input_split_shuffle.h @@ -0,0 +1,168 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file input_split_shuffle.h + * \brief base class to construct input split with global shuffling + * \author Yifeng Geng + */ +#ifndef DMLC_INPUT_SPLIT_SHUFFLE_H_ +#define DMLC_INPUT_SPLIT_SHUFFLE_H_ + +#include +#include +#include +#include +#include +#include + +namespace dmlc { +/*! \brief class to construct input split with global shuffling */ +class InputSplitShuffle : public InputSplit { + public: + // destructor + virtual ~InputSplitShuffle(void) { source_.reset(); } + // implement BeforeFirst + virtual void BeforeFirst(void) { + if (num_shuffle_parts_ > 1) { + std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_); + int idx = shuffle_indexes_[0] + part_index_ * num_shuffle_parts_; + source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_); + cur_shuffle_idx_ = 0; + } else { + source_->BeforeFirst(); + } + } + virtual void HintChunkSize(size_t chunk_size) { + source_->HintChunkSize(chunk_size); + } + virtual size_t GetTotalSize(void) { + return source_->GetTotalSize(); + } + // implement next record + virtual bool NextRecord(Blob *out_rec) { + if (num_shuffle_parts_ > 1) { + if (!source_->NextRecord(out_rec)) { + if (cur_shuffle_idx_ == num_shuffle_parts_ - 1) { + return false; + } + ++cur_shuffle_idx_; + int idx = + shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_; + source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_); + return NextRecord(out_rec); + } else { + return true; + } + } else { + return source_->NextRecord(out_rec); + } + } + // implement next chunk + virtual bool NextChunk(Blob* out_chunk) { + if (num_shuffle_parts_ > 1) { + if (!source_->NextChunk(out_chunk)) { + if (cur_shuffle_idx_ == num_shuffle_parts_ - 1) { + return false; + } + ++cur_shuffle_idx_; + int idx = + shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_; + source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_); + return NextChunk(out_chunk); + } else { + return true; + } + } else { + return source_->NextChunk(out_chunk); + } + } + // implement ResetPartition. + virtual void ResetPartition(unsigned rank, unsigned nsplit) { + CHECK(nsplit == num_parts_) << "num_parts is not consistent!"; + int idx = shuffle_indexes_[0] + rank * num_shuffle_parts_; + source_->ResetPartition(idx, nsplit * num_shuffle_parts_); + cur_shuffle_idx_ = 0; + } + /*! + * \brief constructor + * \param uri the uri of the input, can contain hdfs prefix + * \param part_index the part id of current input + * \param num_parts total number of splits + * \param type type of record + * List of possible types: "text", "recordio" + * - "text": + * text file, each line is treated as a record + * input split will split on '\\n' or '\\r' + * - "recordio": + * binary recordio file, see recordio.h + * \param num_shuffle_parts number of shuffle chunks for each split + * \param shuffle_seed shuffle seed for chunk shuffling + */ + InputSplitShuffle(const char* uri, + unsigned part_index, + unsigned num_parts, + const char* type, + unsigned num_shuffle_parts, + int shuffle_seed) + : part_index_(part_index), + num_parts_(num_parts), + num_shuffle_parts_(num_shuffle_parts), + cur_shuffle_idx_(0) { + for (unsigned i = 0; i < num_shuffle_parts_; i++) { + shuffle_indexes_.push_back(i); + } + trnd_.seed(kRandMagic_ + part_index_ + num_parts_ + num_shuffle_parts_ + + shuffle_seed); + std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_); + int idx = shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_; + source_.reset( + InputSplit::Create(uri, idx , num_parts_ * num_shuffle_parts_, type)); + } + /*! + * \brief factory function: + * create input split with chunk shuffling given a uri + * \param uri the uri of the input, can contain hdfs prefix + * \param part_index the part id of current input + * \param num_parts total number of splits + * \param type type of record + * List of possible types: "text", "recordio" + * - "text": + * text file, each line is treated as a record + * input split will split on '\\n' or '\\r' + * - "recordio": + * binary recordio file, see recordio.h + * \param num_shuffle_parts number of shuffle chunks for each split + * \param shuffle_seed shuffle seed for chunk shuffling + * \return a new input split + * \sa InputSplit::Type + */ + static InputSplit* Create(const char* uri, + unsigned part_index, + unsigned num_parts, + const char* type, + unsigned num_shuffle_parts, + int shuffle_seed) { + CHECK(num_shuffle_parts > 0) << "number of shuffle parts should be greater than zero!"; + return new InputSplitShuffle( + uri, part_index, num_parts, type, num_shuffle_parts, shuffle_seed); + } + + private: + // magic nyumber for seed + static const int kRandMagic_ = 666; + /*! \brief random engine */ + std::mt19937 trnd_; + /*! \brief inner inputsplit */ + std::unique_ptr source_; + /*! \brief part index */ + unsigned part_index_; + /*! \brief number of parts */ + unsigned num_parts_; + /*! \brief the number of block for shuffling*/ + unsigned num_shuffle_parts_; + /*! \brief current shuffle block index */ + unsigned cur_shuffle_idx_; + /*! \brief shuffled indexes */ + std::vector shuffle_indexes_; +}; +} // namespace dmlc +#endif // DMLC_INPUT_SPLIT_SHUFFLE_H_ diff --git a/include/dmlc/io.h b/include/dmlc/io.h new file mode 100644 index 000000000000..5e76e4c6e24c --- /dev/null +++ b/include/dmlc/io.h @@ -0,0 +1,522 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file io.h + * \brief defines serializable interface of dmlc + */ +#ifndef DMLC_IO_H_ +#define DMLC_IO_H_ +#include +#include +#include +#include +#include +#include +#include "./logging.h" + +// include uint64_t only to make io standalone +#ifdef _MSC_VER +/*! \brief uint64 */ +typedef unsigned __int64 uint64_t; +#else +#include +#endif + +/*! \brief namespace for dmlc */ +namespace dmlc { +/*! + * \brief interface of stream I/O for serialization + */ +class Stream { // NOLINT(*) + public: + /*! + * \brief reads data from a stream + * \param ptr pointer to a memory buffer + * \param size block size + * \return the size of data read + */ + virtual size_t Read(void *ptr, size_t size) = 0; + /*! + * \brief writes data to a stream + * \param ptr pointer to a memory buffer + * \param size block size + */ + virtual void Write(const void *ptr, size_t size) = 0; + /*! \brief virtual destructor */ + virtual ~Stream(void) {} + /*! + * \brief generic factory function + * create an stream, the stream will close the underlying files upon deletion + * + * \param uri the uri of the input currently we support + * hdfs://, s3://, and file:// by default file:// will be used + * \param flag can be "w", "r", "a" + * \param allow_null whether NULL can be returned, or directly report error + * \return the created stream, can be NULL when allow_null == true and file do not exist + */ + static Stream *Create(const char *uri, + const char* const flag, + bool allow_null = false); + // helper functions to write/read different data structures + /*! + * \brief writes a data to stream. + * + * dmlc::Stream support Write/Read of most STL composites and base types. + * If the data type is not supported, a compile time error will be issued. + * + * This function is endian-aware, + * the output endian defined by DMLC_IO_USE_LITTLE_ENDIAN + * + * \param data data to be written + * \tparam T the data type to be written + */ + template + inline void Write(const T &data); + /*! + * \brief loads a data from stream. + * + * dmlc::Stream support Write/Read of most STL composites and base types. + * If the data type is not supported, a compile time error will be issued. + * + * This function is endian-aware, + * the input endian defined by DMLC_IO_USE_LITTLE_ENDIAN + * + * \param out_data place holder of data to be deserialized + * \return whether the load was successful + */ + template + inline bool Read(T *out_data); + /*! + * \brief Endian aware write array of data. + * \param data The data pointer + * \param num_elems Number of elements + * \tparam T the data type. + */ + template + inline void WriteArray(const T* data, size_t num_elems); + /*! + * \brief Endian aware read array of data. + * \param data The data pointer + * \param num_elems Number of elements + * \tparam T the data type. + * \return whether the load was successful + */ + template + inline bool ReadArray(T* data, size_t num_elems); +}; + +/*! \brief interface of i/o stream that support seek */ +class SeekStream: public Stream { + public: + // virtual destructor + virtual ~SeekStream(void) {} + /*! \brief seek to certain position of the file */ + virtual void Seek(size_t pos) = 0; + /*! \brief tell the position of the stream */ + virtual size_t Tell(void) = 0; + /*! + * \brief generic factory function + * create an SeekStream for read only, + * the stream will close the underlying files upon deletion + * error will be reported and the system will exit when create failed + * \param uri the uri of the input currently we support + * hdfs://, s3://, and file:// by default file:// will be used + * \param allow_null whether NULL can be returned, or directly report error + * \return the created stream, can be NULL when allow_null == true and file do not exist + */ + static SeekStream *CreateForRead(const char *uri, + bool allow_null = false); +}; + +/*! \brief interface for serializable objects */ +class Serializable { + public: + /*! \brief virtual destructor */ + virtual ~Serializable() {} + /*! + * \brief load the model from a stream + * \param fi stream where to load the model from + */ + virtual void Load(Stream *fi) = 0; + /*! + * \brief saves the model to a stream + * \param fo stream where to save the model to + */ + virtual void Save(Stream *fo) const = 0; +}; + +/*! + * \brief input split creates that allows reading + * of records from split of data, + * independent part that covers all the dataset + * + * see InputSplit::Create for definition of record + */ +class InputSplit { + public: + /*! \brief a blob of memory region */ + struct Blob { + /*! \brief points to start of the memory region */ + void *dptr; + /*! \brief size of the memory region */ + size_t size; + }; + /*! + * \brief hint the inputsplit how large the chunk size + * it should return when implementing NextChunk + * this is a hint so may not be enforced, + * but InputSplit will try adjust its internal buffer + * size to the hinted value + * \param chunk_size the chunk size + */ + virtual void HintChunkSize(size_t chunk_size) {} + /*! \brief get the total size of the InputSplit */ + virtual size_t GetTotalSize(void) = 0; + /*! \brief reset the position of InputSplit to beginning */ + virtual void BeforeFirst(void) = 0; + /*! + * \brief get the next record, the returning value + * is valid until next call to NextRecord, NextChunk or NextBatch + * caller can modify the memory content of out_rec + * + * For text, out_rec contains a single line + * For recordio, out_rec contains one record content(with header striped) + * + * \param out_rec used to store the result + * \return true if we can successfully get next record + * false if we reached end of split + * \sa InputSplit::Create for definition of record + */ + virtual bool NextRecord(Blob *out_rec) = 0; + /*! + * \brief get a chunk of memory that can contain multiple records, + * the caller needs to parse the content of the resulting chunk, + * for text file, out_chunk can contain data of multiple lines + * for recordio, out_chunk can contain multiple records(including headers) + * + * This function ensures there won't be partial record in the chunk + * caller can modify the memory content of out_chunk, + * the memory is valid until next call to NextRecord, NextChunk or NextBatch + * + * Usually NextRecord is sufficient, NextChunk can be used by some + * multi-threaded parsers to parse the input content + * + * \param out_chunk used to store the result + * \return true if we can successfully get next record + * false if we reached end of split + * \sa InputSplit::Create for definition of record + * \sa RecordIOChunkReader to parse recordio content from out_chunk + */ + virtual bool NextChunk(Blob *out_chunk) = 0; + /*! + * \brief get a chunk of memory that can contain multiple records, + * with hint for how many records is needed, + * the caller needs to parse the content of the resulting chunk, + * for text file, out_chunk can contain data of multiple lines + * for recordio, out_chunk can contain multiple records(including headers) + * + * This function ensures there won't be partial record in the chunk + * caller can modify the memory content of out_chunk, + * the memory is valid until next call to NextRecord, NextChunk or NextBatch + * + * + * \param out_chunk used to store the result + * \param n_records used as a hint for how many records should be returned, may be ignored + * \return true if we can successfully get next record + * false if we reached end of split + * \sa InputSplit::Create for definition of record + * \sa RecordIOChunkReader to parse recordio content from out_chunk + */ + virtual bool NextBatch(Blob *out_chunk, size_t n_records) { + return NextChunk(out_chunk); + } + /*! \brief destructor*/ + virtual ~InputSplit(void) {} + /*! + * \brief reset the Input split to a certain part id, + * The InputSplit will be pointed to the head of the new specified segment. + * This feature may not be supported by every implementation of InputSplit. + * \param part_index The part id of the new input. + * \param num_parts The total number of parts. + */ + virtual void ResetPartition(unsigned part_index, unsigned num_parts) = 0; + /*! + * \brief factory function: + * create input split given a uri + * \param uri the uri of the input, can contain hdfs prefix + * \param part_index the part id of current input + * \param num_parts total number of splits + * \param type type of record + * List of possible types: "text", "recordio", "indexed_recordio" + * - "text": + * text file, each line is treated as a record + * input split will split on '\\n' or '\\r' + * - "recordio": + * binary recordio file, see recordio.h + * - "indexed_recordio": + * binary recordio file with index, see recordio.h + * \return a new input split + * \sa InputSplit::Type + */ + static InputSplit* Create(const char *uri, + unsigned part_index, + unsigned num_parts, + const char *type); + /*! + * \brief factory function: + * create input split given a uri for input and index + * \param uri the uri of the input, can contain hdfs prefix + * \param index_uri the uri of the index, can contain hdfs prefix + * \param part_index the part id of current input + * \param num_parts total number of splits + * \param type type of record + * List of possible types: "text", "recordio", "indexed_recordio" + * - "text": + * text file, each line is treated as a record + * input split will split on '\\n' or '\\r' + * - "recordio": + * binary recordio file, see recordio.h + * - "indexed_recordio": + * binary recordio file with index, see recordio.h + * \param shuffle whether to shuffle the output from the InputSplit, + * supported only by "indexed_recordio" type. + * Defaults to "false" + * \param seed random seed to use in conjunction with the "shuffle" + * option. Defaults to 0 + * \param batch_size a hint to InputSplit what is the intended number + * of examples return per batch. Used only by + * "indexed_recordio" type + * \param recurse_directories whether to recursively traverse directories + * \return a new input split + * \sa InputSplit::Type + */ + static InputSplit* Create(const char *uri, + const char *index_uri, + unsigned part_index, + unsigned num_parts, + const char *type, + const bool shuffle = false, + const int seed = 0, + const size_t batch_size = 256, + const bool recurse_directories = false); +}; + +#ifndef _LIBCPP_SGX_NO_IOSTREAMS +/*! + * \brief a std::ostream class that can can wrap Stream objects, + * can use ostream with that output to underlying Stream + * + * Usage example: + * \code + * + * Stream *fs = Stream::Create("hdfs:///test.txt", "w"); + * dmlc::ostream os(fs); + * os << "hello world" << std::endl; + * delete fs; + * \endcode + */ +class ostream : public std::basic_ostream { + public: + /*! + * \brief construct std::ostream type + * \param stream the Stream output to be used + * \param buffer_size internal streambuf size + */ + explicit ostream(Stream *stream, + size_t buffer_size = (1 << 10)) + : std::basic_ostream(NULL), buf_(buffer_size) { + this->set_stream(stream); + } + // explictly synchronize the buffer + virtual ~ostream() DMLC_NO_EXCEPTION { + buf_.pubsync(); + } + /*! + * \brief set internal stream to be stream, reset states + * \param stream new stream as output + */ + inline void set_stream(Stream *stream) { + buf_.set_stream(stream); + this->rdbuf(&buf_); + } + + /*! \return how many bytes we written so far */ + inline size_t bytes_written(void) const { + return buf_.bytes_out(); + } + + private: + // internal streambuf + class OutBuf : public std::streambuf { + public: + explicit OutBuf(size_t buffer_size) + : stream_(NULL), buffer_(buffer_size), bytes_out_(0) { + if (buffer_size == 0) buffer_.resize(2); + } + // set stream to the buffer + inline void set_stream(Stream *stream); + + inline size_t bytes_out() const { return bytes_out_; } + private: + /*! \brief internal stream by StreamBuf */ + Stream *stream_; + /*! \brief internal buffer */ + std::vector buffer_; + /*! \brief number of bytes written so far */ + size_t bytes_out_; + // override sync + inline int_type sync(void); + // override overflow + inline int_type overflow(int c); + }; + /*! \brief buffer of the stream */ + OutBuf buf_; +}; + +/*! + * \brief a std::istream class that can can wrap Stream objects, + * can use istream with that output to underlying Stream + * + * Usage example: + * \code + * + * Stream *fs = Stream::Create("hdfs:///test.txt", "r"); + * dmlc::istream is(fs); + * is >> mydata; + * delete fs; + * \endcode + */ +class istream : public std::basic_istream { + public: + /*! + * \brief construct std::ostream type + * \param stream the Stream output to be used + * \param buffer_size internal buffer size + */ + explicit istream(Stream *stream, + size_t buffer_size = (1 << 10)) + : std::basic_istream(NULL), buf_(buffer_size) { + this->set_stream(stream); + } + virtual ~istream() DMLC_NO_EXCEPTION {} + /*! + * \brief set internal stream to be stream, reset states + * \param stream new stream as output + */ + inline void set_stream(Stream *stream) { + buf_.set_stream(stream); + this->rdbuf(&buf_); + } + /*! \return how many bytes we read so far */ + inline size_t bytes_read(void) const { + return buf_.bytes_read(); + } + + private: + // internal streambuf + class InBuf : public std::streambuf { + public: + explicit InBuf(size_t buffer_size) + : stream_(NULL), bytes_read_(0), + buffer_(buffer_size) { + if (buffer_size == 0) buffer_.resize(2); + } + // set stream to the buffer + inline void set_stream(Stream *stream); + // return how many bytes read so far + inline size_t bytes_read(void) const { + return bytes_read_; + } + private: + /*! \brief internal stream by StreamBuf */ + Stream *stream_; + /*! \brief how many bytes we read so far */ + size_t bytes_read_; + /*! \brief internal buffer */ + std::vector buffer_; + // override underflow + inline int_type underflow(); + }; + /*! \brief input buffer */ + InBuf buf_; +}; +#endif +} // namespace dmlc + +#include "./serializer.h" + +namespace dmlc { +// implementations of inline functions +template +inline void Stream::Write(const T &data) { + serializer::Handler::Write(this, data); +} +template +inline bool Stream::Read(T *out_data) { + return serializer::Handler::Read(this, out_data); +} + +template +inline void Stream::WriteArray(const T* data, size_t num_elems) { + for (size_t i = 0; i < num_elems; ++i) { + this->Write(data[i]); + } +} + +template +inline bool Stream::ReadArray(T* data, size_t num_elems) { + for (size_t i = 0; i < num_elems; ++i) { + if (!this->Read(data + i)) return false; + } + return true; +} + +#ifndef _LIBCPP_SGX_NO_IOSTREAMS +// implementations for ostream +inline void ostream::OutBuf::set_stream(Stream *stream) { + if (stream_ != NULL) this->pubsync(); + this->stream_ = stream; + this->setp(&buffer_[0], &buffer_[0] + buffer_.size() - 1); +} +inline int ostream::OutBuf::sync(void) { + if (stream_ == NULL) return -1; + std::ptrdiff_t n = pptr() - pbase(); + stream_->Write(pbase(), n); + this->pbump(-static_cast(n)); + bytes_out_ += n; + return 0; +} +inline int ostream::OutBuf::overflow(int c) { + *(this->pptr()) = c; + std::ptrdiff_t n = pptr() - pbase(); + this->pbump(-static_cast(n)); + if (c == EOF) { + stream_->Write(pbase(), n); + bytes_out_ += n; + } else { + stream_->Write(pbase(), n + 1); + bytes_out_ += n + 1; + } + return c; +} + +// implementations for istream +inline void istream::InBuf::set_stream(Stream *stream) { + stream_ = stream; + this->setg(&buffer_[0], &buffer_[0], &buffer_[0]); +} +inline int istream::InBuf::underflow() { + char *bhead = &buffer_[0]; + if (this->gptr() == this->egptr()) { + size_t sz = stream_->Read(bhead, buffer_.size()); + this->setg(bhead, bhead, bhead + sz); + bytes_read_ += sz; + } + if (this->gptr() == this->egptr()) { + return traits_type::eof(); + } else { + return traits_type::to_int_type(*gptr()); + } +} +#endif +} // namespace dmlc +#endif // DMLC_IO_H_ diff --git a/include/dmlc/json.h b/include/dmlc/json.h new file mode 100644 index 000000000000..ef82dfb57aa7 --- /dev/null +++ b/include/dmlc/json.h @@ -0,0 +1,981 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file json.h + * \brief Lightweight JSON Reader/Writer that read save into C++ data structs. + * This includes STL composites and structures. + */ +#ifndef DMLC_JSON_H_ +#define DMLC_JSON_H_ + +// This code requires C++11 to compile +#include +#ifndef _LIBCPP_SGX_NO_IOSTREAMS +#include +#endif +#include +#include +#include +#include +#include +#include + +#include "./base.h" +#include "./logging.h" +#include "./type_traits.h" + +#if DMLC_USE_CXX11 +#include +#include +#include +#if DMLC_STRICT_CXX11 +#if DMLC_ENABLE_RTTI +#include "./any.h" +#endif // DMLC_ENABLE_RTTI +#endif // DMLC_STRICT_CXX11 +#endif // DMLC_USE_CXX11 + +namespace dmlc { +/*! + * \brief Lightweight JSON Reader to read any STL compositions and structs. + * The user need to know the schema of the + * + */ +class JSONReader { + public: + /*! + * \brief Constructor. + * \param is the input source. + */ +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + explicit JSONReader(std::istream *is) +#else + explicit JSONReader(std::string *is) +#endif + : is_(is), + line_count_r_(0), + line_count_n_(0) {} + /*! + * \brief Parse next JSON string. + * \param out_str the output string. + * \throw dmlc::Error when next token is not string + */ + inline void ReadString(std::string *out_str); + /*! + * \brief Read Number. + * \param out_value output value; + * \throw dmlc::Error when next token is not number of ValueType. + * \tparam ValueType type of the number + */ + template + inline void ReadNumber(ValueType *out_value); + /*! + * \brief Begin parsing an object. + * \code + * std::string key; + * // value can be any type that is json serializable. + * std::string value; + * reader->BeginObject(); + * while (reader->NextObjectItem(&key)) { + * // do somthing to key value + * reader->Read(&value); + * } + * \endcode + */ + inline void BeginObject(); + /*! + * \brief Begin parsing an array. + * \code + * // value can be any type that is json serializable. + * std::string value; + * reader->BeginArray(); + * while (reader->NextObjectArrayItem(&value)) { + * // do somthing to value + * } + * \endcode + */ + inline void BeginArray(); + /*! + * \brief Try to move to next object item. + * If this call is successful, user can proceed to call + * reader->Read to read in the value. + * \param out_key the key to the next object. + * \return true if the read is successful, false if we are at end of the object. + */ + inline bool NextObjectItem(std::string *out_key); + /*! + * \brief Try to read the next element in the array. + * If this call is successful, user can proceed to call + * reader->Read to read in the value. + * \return true if the read is successful, false if we are at end of the array. + */ + inline bool NextArrayItem(); + /*! + * \brief Read next ValueType. + * \param out_value any STL or json readable type to be read + * \throw dmlc::Error when the read of ValueType is not successful. + * \tparam ValueType the data type to be read. + */ + template + inline void Read(ValueType *out_value); + + /*! \return current line count */ + inline std::string line_info() const { +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + char temp[64]; + std::ostringstream os; + os << " Line " << std::max(line_count_r_, line_count_n_); + is_->getline(temp, 64); + os << ", around ^`" << temp << "`"; + return os.str(); +#else + std::string info = " Line "; + info += std::to_string(std::max(line_count_r_, line_count_n_)); + + // string getline + size_t end_pos = is_->find('\n'); + end_pos = std::min((size_t)64, + end_pos == std::string::npos ? is_->size() : end_pos); + std::string line = is_->substr(0, end_pos); + is_->erase(0, line.size() + 1); // +1 for \n + + info += ", around ^`" + line + "`"; + return info; +#endif + } + + private: +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + /*! \brief internal reader stream */ + std::istream *is_; +#else + /*! \brief internal reader string */ + std::string *is_; +#endif + /*! \brief "\\r" counter */ + size_t line_count_r_; + /*! \brief "\\n" counter */ + size_t line_count_n_; + /*! + * \brief record how many element processed in + * current array/object scope. + */ + std::vector scope_counter_; + /*! + * \brief Read next nonspace character. + * \return the next nonspace character. + */ + inline int NextNonSpace(); + /*! + * \brief Read just before next nonspace but not read that. + * \return the next nonspace character. + */ + inline int PeekNextNonSpace(); + /*! + * \brief Takes the next char from the input source. + * \return the next character. + */ + inline int NextChar(); + /*! + * \brief Returns the next char from the input source. + * \return the next character. + */ + inline int PeekNextChar(); +}; + +/*! + * \brief Lightweight json to write any STL compositions. + */ +class JSONWriter { + public: + /*! + * \brief Constructor. + * \param os the output reciever. + */ +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + explicit JSONWriter(std::ostream *os) +#else + explicit JSONWriter(std::string *os) +#endif + : os_(os) {} + /*! + * \brief Write a string that do not contain escape characters. + * \param s the string to be written. + */ + inline void WriteNoEscape(const std::string &s); + /*! + * \brief Write a string that can contain escape characters. + * \param s the string to be written. + */ + inline void WriteString(const std::string &s); + /*! + * \brief Write a string that can contain escape characters. + * \param v the value to be written. + * \tparam ValueType The value type to be written. + */ + template + inline void WriteNumber(const ValueType &v); + /*! + * \brief Start beginning of array. + * \param multi_line whether to start an multi_line array. + * \code + * writer->BeginArray(); + * for (auto& v : vdata) { + * writer->WriteArrayItem(v); + * } + * writer->EndArray(); + * \endcode + */ + inline void BeginArray(bool multi_line = true); + /*! \brief Finish writing an array. */ + inline void EndArray(); + /*! + * \brief Start beginning of array. + * \param multi_line whether to start an multi_line array. + * \code + * writer->BeginObject(); + * for (auto& kv : vmap) { + * writer->WriteObjectKeyValue(kv.first, kv.second); + * } + * writer->EndObject(); + * \endcode + */ + inline void BeginObject(bool multi_line = true); + /*! \brief Finish writing object. */ + inline void EndObject(); + /*! + * \brief Write key value pair in the object. + * \param key the key of the object. + * \param value the value of to be written. + * \tparam ValueType The value type to be written. + */ + template + inline void WriteObjectKeyValue(const std::string &key, + const ValueType &value); + /*! + * \brief Write seperator of array, before writing next element. + * User can proceed to call writer->Write to write next item + */ + inline void WriteArraySeperator(); + /*! + * \brief Write value into array. + * \param value The value of to be written. + * \tparam ValueType The value type to be written. + */ + template + inline void WriteArrayItem(const ValueType &value); + /*! + * \brief Write value to json. + * \param value any STL or json readable that can be written. + * \tparam ValueType the data type to be write. + */ + template + inline void Write(const ValueType &value); + + private: +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + /*! \brief Output stream */ + std::ostream *os_; +#else + std::string *os_; +#endif + /*! + * \brief record how many element processed in + * current array/object scope. + */ + std::vector scope_counter_; + /*! \brief Record whether current is a multiline scope */ + std::vector scope_multi_line_; + /*! + * \brief Write seperating space and newlines + */ + inline void WriteSeperator(); +}; + +/*! + * \brief Helper class to read JSON into a class or struct object. + * \code + * struct Param { + * std::string name; + * int value; + * // define load function from JSON + * inline void Load(dmlc::JSONReader *reader) { + * dmlc::JSONStructReadHelper helper; + * helper.DeclareField("name", &name); + * helper.DeclareField("value", &value); + * helper.ReadAllFields(reader); + * } + * }; + * \endcode + */ +class JSONObjectReadHelper { + public: + /*! + * \brief Declare field of type T + * \param key the key of the of field. + * \param addr address of the data type. + * \tparam T the data type to be read, must be STL composition of JSON serializable. + */ + template + inline void DeclareField(const std::string &key, T *addr) { + DeclareFieldInternal(key, addr, false); + } + /*! + * \brief Declare optional field of type T + * \param key the key of the of field. + * \param addr address of the data type. + * \tparam T the data type to be read, must be STL composition of JSON serializable. + */ + template + inline void DeclareOptionalField(const std::string &key, T *addr) { + DeclareFieldInternal(key, addr, true); + } + /*! + * \brief Read in all the declared fields. + * \param reader the JSONReader to read the json. + */ + inline void ReadAllFields(JSONReader *reader); + + private: + /*! + * \brief Internal function to declare field. + * \param key the key of the of field. + * \param addr address of the data type. + * \param optional if set to true, no error will be reported if the key is not presented. + * \tparam T the data type to be read, must be STL composition of JSON serializable. + */ + template + inline void DeclareFieldInternal(const std::string &key, T *addr, bool optional); + /*! + * \brief The internal reader function. + * \param reader The reader to read. + * \param addr The memory address to read. + */ + template + inline static void ReaderFunction(JSONReader *reader, void *addr); + /*! \brief callback type to reader function */ + typedef void (*ReadFunction)(JSONReader *reader, void *addr); + /*! \brief internal data entry */ + struct Entry { + /*! \brief the reader function */ + ReadFunction func; + /*! \brief the address to read */ + void *addr; + /*! \brief whether it is optional */ + bool optional; + }; + /*! \brief the internal map of reader callbacks */ + std::map map_; +}; + +#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \ + static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \ + __make_AnyJSONType ## _ ## KeyName ## __ + +/*! + * \def DMLC_JSON_ENABLE_ANY + * \brief Macro to enable save/load JSON of dmlc:: whose actual type is Type. + * Any type will be saved as json array [KeyName, content] + * + * \param Type The type to be registered. + * \param KeyName The Type key assigned to the type, must be same during load. + */ +#define DMLC_JSON_ENABLE_ANY(Type, KeyName) \ + DMLC_STR_CONCAT(DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName), __COUNTER__) = \ + ::dmlc::json::AnyJSONManager::Global()->EnableType(#KeyName) \ + +//! \cond Doxygen_Suppress +namespace json { + +/*! + * \brief generic serialization handler + * \tparam T the type to be serialized + */ +template +struct Handler; + +template +struct NumericHandler { + inline static void Write(JSONWriter *writer, const ValueType &value) { + writer->WriteNumber(value); + } + inline static void Read(JSONReader *reader, ValueType *value) { + reader->ReadNumber(value); + } +}; + +template +struct ArrayHandler { + inline static void Write(JSONWriter *writer, const ContainerType &array) { + typedef typename ContainerType::value_type ElemType; + writer->BeginArray(array.size() > 10 || !dmlc::is_pod::value); + for (typename ContainerType::const_iterator it = array.begin(); + it != array.end(); ++it) { + writer->WriteArrayItem(*it); + } + writer->EndArray(); + } + inline static void Read(JSONReader *reader, ContainerType *array) { + typedef typename ContainerType::value_type ElemType; + array->clear(); + reader->BeginArray(); + while (reader->NextArrayItem()) { + ElemType value; + Handler::Read(reader, &value); + array->insert(array->end(), value); + } + } +}; + +template +struct MapHandler{ + inline static void Write(JSONWriter *writer, const ContainerType &map) { + writer->BeginObject(map.size() > 1); + for (typename ContainerType::const_iterator it = map.begin(); it != map.end(); ++it) { + writer->WriteObjectKeyValue(it->first, it->second); + } + writer->EndObject(); + } + inline static void Read(JSONReader *reader, ContainerType *map) { + typedef typename ContainerType::mapped_type ElemType; + map->clear(); + reader->BeginObject(); + std::string key; + while (reader->NextObjectItem(&key)) { + ElemType value; + reader->Read(&value); + (*map)[key] = value; + } + } +}; + +template +struct CommonJSONSerializer { + inline static void Write(JSONWriter *writer, const T &value) { + value.Save(writer); + } + inline static void Read(JSONReader *reader, T *value) { + value->Load(reader); + } +}; + +template<> +struct Handler { + inline static void Write(JSONWriter *writer, const std::string &value) { + writer->WriteString(value); + } + inline static void Read(JSONReader *reader, std::string *str) { + reader->ReadString(str); + } +}; + +template +struct Handler > : public ArrayHandler > { +}; + +template +struct Handler > { + inline static void Write(JSONWriter *writer, const std::pair &kv) { + writer->BeginArray(); + writer->WriteArrayItem(kv.first); + writer->WriteArrayItem(kv.second); + writer->EndArray(); + } + inline static void Read(JSONReader *reader, std::pair *kv) { + reader->BeginArray(); + CHECK(reader->NextArrayItem()) + << "Expect array of length 2"; + Handler::Read(reader, &(kv->first)); + CHECK(reader->NextArrayItem()) + << "Expect array of length 2"; + Handler::Read(reader, &(kv->second)); + CHECK(!reader->NextArrayItem()) + << "Expect array of length 2"; + } +}; + +template +struct Handler > : public ArrayHandler > { +}; + +template +struct Handler > : public MapHandler > { +}; + +#if DMLC_USE_CXX11 +template +struct Handler > + : public MapHandler > { +}; +#endif // DMLC_USE_CXX11 + +template +struct Handler { + inline static void Write(JSONWriter *writer, const T &data) { + typedef typename dmlc::IfThenElseType::value, + NumericHandler, + CommonJSONSerializer >::Type THandler; + THandler::Write(writer, data); + } + inline static void Read(JSONReader *reader, T *data) { + typedef typename dmlc::IfThenElseType::value, + NumericHandler, + CommonJSONSerializer >::Type THandler; + THandler::Read(reader, data); + } +}; + +#if DMLC_STRICT_CXX11 +#if DMLC_ENABLE_RTTI +// Manager to store json serialization strategy. +class AnyJSONManager { + public: + template + inline AnyJSONManager& EnableType(const std::string& type_name) { // NOLINT(*) + std::type_index tp = std::type_index(typeid(T)); + if (type_name_.count(tp) != 0) { + CHECK(type_name_.at(tp) == type_name) + << "Type has already been registered as another typename " << type_name_.at(tp); + return *this; + } + CHECK(type_map_.count(type_name) == 0) + << "Type name " << type_name << " already registered in registry"; + Entry e; + e.read = ReadAny; + e.write = WriteAny; + type_name_[tp] = type_name; + type_map_[type_name] = e; + return *this; + } + // return global singleton + inline static AnyJSONManager* Global() { + static AnyJSONManager inst; + return &inst; + } + + private: + AnyJSONManager() {} + + template + inline static void WriteAny(JSONWriter *writer, const any &data) { + writer->Write(dmlc::get(data)); + } + template + inline static void ReadAny(JSONReader *reader, any* data) { + T temp; + reader->Read(&temp); + *data = std::move(temp); + } + // data entry to store vtable for any type + struct Entry { + void (*read)(JSONReader* reader, any *data); + void (*write)(JSONWriter* reader, const any& data); + }; + + template + friend struct Handler; + + std::unordered_map type_name_; + std::unordered_map type_map_; +}; + +template<> +struct Handler { + inline static void Write(JSONWriter *writer, const any &data) { + std::unordered_map& + nmap = AnyJSONManager::Global()->type_name_; + std::type_index id = std::type_index(data.type()); + auto it = nmap.find(id); + CHECK(it != nmap.end() && it->first == id) + << "Type " << id.name() << " has not been registered via DMLC_JSON_ENABLE_ANY"; + std::string type_name = it->second; + AnyJSONManager::Entry e = AnyJSONManager::Global()->type_map_.at(type_name); + writer->BeginArray(false); + writer->WriteArrayItem(type_name); + writer->WriteArraySeperator(); + e.write(writer, data); + writer->EndArray(); + } + inline static void Read(JSONReader *reader, any *data) { + std::string type_name; + reader->BeginArray(); + CHECK(reader->NextArrayItem()) << "invalid any json format"; + Handler::Read(reader, &type_name); + std::unordered_map& + tmap = AnyJSONManager::Global()->type_map_; + auto it = tmap.find(type_name); + CHECK(it != tmap.end() && it->first == type_name) + << "Typename " << type_name << " has not been registered via DMLC_JSON_ENABLE_ANY"; + AnyJSONManager::Entry e = it->second; + CHECK(reader->NextArrayItem()) << "invalid any json format"; + e.read(reader, data); + CHECK(!reader->NextArrayItem()) << "invalid any json format"; + } +}; +#endif // DMLC_ENABLE_RTTI +#endif // DMLC_STRICT_CXX11 + +} // namespace json + +// implementations of JSONReader/Writer +inline int JSONReader::NextChar() { +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + return is_->get(); +#else + int ch = is_->at(0); + is_->erase(0, 1); + return ch; +#endif +} + +inline int JSONReader::PeekNextChar() { +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + return is_->peek(); +#else + return is_->at(0); +#endif +} + +inline int JSONReader::NextNonSpace() { + int ch; + do { + ch = NextChar(); + if (ch == '\n') ++line_count_n_; + if (ch == '\r') ++line_count_r_; + } while (isspace(ch)); + return ch; +} + +inline int JSONReader::PeekNextNonSpace() { + int ch; + while (true) { + ch = PeekNextChar(); + if (ch == '\n') ++line_count_n_; + if (ch == '\r') ++line_count_r_; + if (!isspace(ch)) break; + NextChar(); + } + return ch; +} + +namespace { + template +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + void Extend(std::ostream *os, T item) { + *os << item; + } +#else + void Extend(std::string *ostr, T item) { + *ostr += item; + } +#endif +} // namespace + +inline void JSONReader::ReadString(std::string *out_str) { + int ch = NextNonSpace(); + CHECK_EQ(ch, '\"') + << "Error at" << line_info() + << ", Expect \'\"\' but get \'" << static_cast(ch) << '\''; +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + std::ostringstream output; +#else + std::string output = ""; +#endif + while (true) { + ch = NextChar(); + if (ch == '\\') { + char sch = static_cast(NextChar()); + switch (sch) { + case 'r': Extend(&output, "\r"); break; + case 'n': Extend(&output, "\n"); break; + case '\\': Extend(&output, "\\"); break; + case 't': Extend(&output, "\t"); break; + case '\"': Extend(&output, "\""); break; + default: LOG(FATAL) << "unknown string escape \\" << sch; + } + } else { + if (ch == '\"') break; + Extend(&output, static_cast(ch)); + } + if (ch == EOF || ch == '\r' || ch == '\n') { + LOG(FATAL) + << "Error at" << line_info() + << ", Expect \'\"\' but reach end of line "; + } + } +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + *out_str = output.str(); +#else + *out_str = output; +#endif +} + +template +inline void JSONReader::ReadNumber(ValueType *out_value) { +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + *is_ >> *out_value; + CHECK(!is_->fail()) + << "Error at" << line_info() + << ", Expect number"; +#else + char* endptr; + const char* icstr = is_->c_str(); + unsigned number = strtol(icstr, &endptr, 10); + is_->erase(0, endptr - icstr); + *out_value = static_cast(number); +#endif +} + +inline void JSONReader::BeginObject() { + int ch = NextNonSpace(); + CHECK_EQ(ch, '{') + << "Error at" << line_info() + << ", Expect \'{\' but get \'" << static_cast(ch) << '\''; + scope_counter_.push_back(0); +} + +inline void JSONReader::BeginArray() { + int ch = NextNonSpace(); + CHECK_EQ(ch, '[') + << "Error at" << line_info() + << ", Expect \'{\' but get \'" << static_cast(ch) << '\''; + scope_counter_.push_back(0); +} + +inline bool JSONReader::NextObjectItem(std::string *out_key) { + bool next = true; + if (scope_counter_.back() != 0) { + int ch = NextNonSpace(); + if (ch == EOF) { + next = false; + } else if (ch == '}') { + next = false; + } else { + CHECK_EQ(ch, ',') + << "Error at" << line_info() + << ", JSON object expect \'}\' or \',\' \'" << static_cast(ch) << '\''; + } + } else { + int ch = PeekNextNonSpace(); + if (ch == '}') { + NextChar(); + next = false; + } + } + if (!next) { + scope_counter_.pop_back(); + return false; + } else { + scope_counter_.back() += 1; + ReadString(out_key); + int ch = NextNonSpace(); + CHECK_EQ(ch, ':') + << "Error at" << line_info() + << ", Expect \':\' but get \'" << static_cast(ch) << '\''; + return true; + } +} + +inline bool JSONReader::NextArrayItem() { + bool next = true; + if (scope_counter_.back() != 0) { + int ch = NextNonSpace(); + if (ch == EOF) { + next = false; + } else if (ch == ']') { + next = false; + } else { + CHECK_EQ(ch, ',') + << "Error at" << line_info() + << ", JSON array expect \']\' or \',\'. Get \'" << static_cast(ch) << "\' instead"; + } + } else { + int ch = PeekNextNonSpace(); + if (ch == ']') { + NextChar(); + next = false; + } + } + if (!next) { + scope_counter_.pop_back(); + return false; + } else { + scope_counter_.back() += 1; + return true; + } +} + +template +inline void JSONReader::Read(ValueType *out_value) { + json::Handler::Read(this, out_value); +} + +inline void JSONWriter::WriteNoEscape(const std::string &s) { + Extend(os_, '\"'); + Extend(os_, s); + Extend(os_, '\"'); +} + +inline void JSONWriter::WriteString(const std::string &s) { + Extend(os_, '\"'); + for (size_t i = 0; i < s.length(); ++i) { + char ch = s[i]; + switch (ch) { + case '\r': Extend(os_, "\\r"); break; + case '\n': Extend(os_, "\\n"); break; + case '\\': Extend(os_, "\\\\"); break; + case '\t': Extend(os_, "\\t"); break; + case '\"': Extend(os_, "\\\""); break; + default: Extend(os_, ch); + } + } + Extend(os_, '\"'); +} + +template +inline void JSONWriter::WriteNumber(const ValueType &v) { +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + Extend(os_, v); +#else + Extend(os_, std::to_string(v)); +#endif +} + +inline void JSONWriter::BeginArray(bool multi_line) { + Extend(os_, '['); + scope_multi_line_.push_back(multi_line); + scope_counter_.push_back(0); +} + +inline void JSONWriter::EndArray() { + CHECK_NE(scope_multi_line_.size(), 0U); + CHECK_NE(scope_counter_.size(), 0U); + bool newline = scope_multi_line_.back(); + size_t nelem = scope_counter_.back(); + scope_multi_line_.pop_back(); + scope_counter_.pop_back(); + if (newline && nelem != 0) WriteSeperator(); + Extend(os_, ']'); +} + +inline void JSONWriter::BeginObject(bool multi_line) { + Extend(os_, '{'); + scope_multi_line_.push_back(multi_line); + scope_counter_.push_back(0); +} + +inline void JSONWriter::EndObject() { + CHECK_NE(scope_multi_line_.size(), 0U); + CHECK_NE(scope_counter_.size(), 0U); + bool newline = scope_multi_line_.back(); + size_t nelem = scope_counter_.back(); + scope_multi_line_.pop_back(); + scope_counter_.pop_back(); + if (newline && nelem != 0) WriteSeperator(); + Extend(os_, '}'); +} + +template +inline void JSONWriter::WriteObjectKeyValue(const std::string &key, + const ValueType &value) { + if (scope_counter_.back() > 0) { + Extend(os_, ", "); + } + WriteSeperator(); + Extend(os_, '\"'); + Extend(os_, key); + Extend(os_, "\": "); + scope_counter_.back() += 1; + json::Handler::Write(this, value); +} + +inline void JSONWriter::WriteArraySeperator() { + if (scope_counter_.back() != 0) { + Extend(os_, ", "); + } + scope_counter_.back() += 1; + WriteSeperator(); +} + +template +inline void JSONWriter::WriteArrayItem(const ValueType &value) { + this->WriteArraySeperator(); + json::Handler::Write(this, value); +} + +template +inline void JSONWriter::Write(const ValueType &value) { + size_t nscope = scope_multi_line_.size(); + json::Handler::Write(this, value); + CHECK_EQ(nscope, scope_multi_line_.size()) + << "Uneven scope, did you call EndArray/EndObject after each BeginObject/Array?"; +} + +inline void JSONWriter::WriteSeperator() { + if (scope_multi_line_.size() == 0 || scope_multi_line_.back()) { + Extend(os_, '\n'); + Extend(os_, std::string(scope_multi_line_.size() * 2, ' ')); + } +} + +inline void JSONObjectReadHelper::ReadAllFields(JSONReader *reader) { + reader->BeginObject(); + std::map visited; + std::string key; + while (reader->NextObjectItem(&key)) { + if (map_.count(key) != 0) { + Entry e = map_[key]; + (*e.func)(reader, e.addr); + visited[key] = 0; + } else { +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + std::ostringstream err; +#else + std::string err(""); +#endif + Extend(&err, "JSONReader: Unknown field "); + Extend(&err, key); + Extend(&err, ", candidates are: \n"); + for (std::map::iterator + it = map_.begin(); it != map_.end(); ++it) { + Extend(&err, '\"'); + Extend(&err, it->first); + Extend(&err, "\"\n"); + } +#ifndef _LIBCPP_SGX_NO_IOSTREAMS + LOG(FATAL) << err.str(); +#else + LOG(FATAL) << err; +#endif + } + } + if (visited.size() != map_.size()) { + for (std::map::iterator + it = map_.begin(); it != map_.end(); ++it) { + if (it->second.optional) continue; + CHECK_NE(visited.count(it->first), 0U) + << "JSONReader: Missing field \"" << it->first << "\"\n At " + << reader->line_info(); + } + } +} + +template +inline void JSONObjectReadHelper::ReaderFunction(JSONReader *reader, void *addr) { + json::Handler::Read(reader, static_cast(addr)); +} + +template +inline void JSONObjectReadHelper:: +DeclareFieldInternal(const std::string &key, T *addr, bool optional) { + CHECK_EQ(map_.count(key), 0U) + << "Adding duplicate field " << key; + Entry e; + e.func = ReaderFunction; + e.addr = static_cast(addr); + e.optional = optional; + map_[key] = e; +} + +//! \endcond +} // namespace dmlc +#endif // DMLC_JSON_H_ diff --git a/include/dmlc/logging.h b/include/dmlc/logging.h new file mode 100644 index 000000000000..8e7878bd41d3 --- /dev/null +++ b/include/dmlc/logging.h @@ -0,0 +1,424 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file logging.h + * \brief defines logging macros of dmlc + * allows use of GLOG, fall back to internal + * implementation when disabled + */ +#ifndef DMLC_LOGGING_H_ +#define DMLC_LOGGING_H_ +#include +#include +#include +#include +#include +#include +#include "./base.h" + +#if DMLC_LOG_STACK_TRACE +#include +#endif + +#if DMLC_LOG_STACK_TRACE +#include +#endif + +namespace dmlc { +/*! + * \brief exception class that will be thrown by + * default logger if DMLC_LOG_FATAL_THROW == 1 + */ +struct Error : public std::runtime_error { + /*! + * \brief constructor + * \param s the error message + */ + explicit Error(const std::string &s) : std::runtime_error(s) {} +}; +} // namespace dmlc + +#if DMLC_USE_GLOG +#include + +namespace dmlc { +/*! + * \brief optionally redirect to google's init log + * \param argv0 The arguments. + */ +inline void InitLogging(const char* argv0) { + google::InitGoogleLogging(argv0); +} +} // namespace dmlc + +#else +// use a light version of glog +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable : 4722) +#pragma warning(disable : 4068) +#endif + +namespace dmlc { +inline void InitLogging(const char*) { + // DO NOTHING +} + +class LogCheckError { + public: + LogCheckError() : str(nullptr) {} + explicit LogCheckError(const std::string& str_) : str(new std::string(str_)) {} + ~LogCheckError() { if (str != nullptr) delete str; } + operator bool() {return str != nullptr; } + std::string* str; +}; + +#ifndef DMLC_GLOG_DEFINED + +#ifndef _LIBCPP_SGX_NO_IOSTREAMS +#define DEFINE_CHECK_FUNC(name, op) \ + template \ + inline LogCheckError LogCheck##name(const X& x, const Y& y) { \ + if (x op y) return LogCheckError(); \ + std::ostringstream os; \ + os << " (" << x << " vs. " << y << ") "; /* CHECK_XX(x, y) requires x and y can be serialized to string. Use CHECK(x OP y) otherwise. NOLINT(*) */ \ + return LogCheckError(os.str()); \ + } \ + inline LogCheckError LogCheck##name(int x, int y) { \ + return LogCheck##name(x, y); \ + } +#else +#define DEFINE_CHECK_FUNC(name, op) \ + template \ + inline LogCheckError LogCheck##name(const X& x, const Y& y) { \ + if (x op y) return LogCheckError(); \ + return LogCheckError("Error."); \ + } \ + inline LogCheckError LogCheck##name(int x, int y) { \ + return LogCheck##name(x, y); \ + } +#endif + +#define CHECK_BINARY_OP(name, op, x, y) \ + if (dmlc::LogCheckError _check_err = dmlc::LogCheck##name(x, y)) \ + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ + << "Check failed: " << #x " " #op " " #y << *(_check_err.str) + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wsign-compare" +DEFINE_CHECK_FUNC(_LT, <) +DEFINE_CHECK_FUNC(_GT, >) +DEFINE_CHECK_FUNC(_LE, <=) +DEFINE_CHECK_FUNC(_GE, >=) +DEFINE_CHECK_FUNC(_EQ, ==) +DEFINE_CHECK_FUNC(_NE, !=) +#pragma GCC diagnostic pop + +// Always-on checking +#define CHECK(x) \ + if (!(x)) \ + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ + << "Check failed: " #x << ' ' +#define CHECK_LT(x, y) CHECK_BINARY_OP(_LT, <, x, y) +#define CHECK_GT(x, y) CHECK_BINARY_OP(_GT, >, x, y) +#define CHECK_LE(x, y) CHECK_BINARY_OP(_LE, <=, x, y) +#define CHECK_GE(x, y) CHECK_BINARY_OP(_GE, >=, x, y) +#define CHECK_EQ(x, y) CHECK_BINARY_OP(_EQ, ==, x, y) +#define CHECK_NE(x, y) CHECK_BINARY_OP(_NE, !=, x, y) +#define CHECK_NOTNULL(x) \ + ((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*) +// Debug-only checking. +#ifdef NDEBUG +#define DCHECK(x) \ + while (false) CHECK(x) +#define DCHECK_LT(x, y) \ + while (false) CHECK((x) < (y)) +#define DCHECK_GT(x, y) \ + while (false) CHECK((x) > (y)) +#define DCHECK_LE(x, y) \ + while (false) CHECK((x) <= (y)) +#define DCHECK_GE(x, y) \ + while (false) CHECK((x) >= (y)) +#define DCHECK_EQ(x, y) \ + while (false) CHECK((x) == (y)) +#define DCHECK_NE(x, y) \ + while (false) CHECK((x) != (y)) +#else +#define DCHECK(x) CHECK(x) +#define DCHECK_LT(x, y) CHECK((x) < (y)) +#define DCHECK_GT(x, y) CHECK((x) > (y)) +#define DCHECK_LE(x, y) CHECK((x) <= (y)) +#define DCHECK_GE(x, y) CHECK((x) >= (y)) +#define DCHECK_EQ(x, y) CHECK((x) == (y)) +#define DCHECK_NE(x, y) CHECK((x) != (y)) +#endif // NDEBUG + +#if DMLC_LOG_CUSTOMIZE +#define LOG_INFO dmlc::CustomLogMessage(__FILE__, __LINE__) +#else +#define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__) +#endif +#define LOG_ERROR LOG_INFO +#define LOG_WARNING LOG_INFO +#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__) +#define LOG_QFATAL LOG_FATAL + +// Poor man version of VLOG +#define VLOG(x) LOG_INFO.stream() + +#define LOG(severity) LOG_##severity.stream() +#define LG LOG_INFO.stream() +#define LOG_IF(severity, condition) \ + !(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) + +#ifdef NDEBUG +#define LOG_DFATAL LOG_ERROR +#define DFATAL ERROR +#define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) +#define DLOG_IF(severity, condition) \ + (true || !(condition)) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) +#else +#define LOG_DFATAL LOG_FATAL +#define DFATAL FATAL +#define DLOG(severity) LOG(severity) +#define DLOG_IF(severity, condition) LOG_IF(severity, condition) +#endif + +// Poor man version of LOG_EVERY_N +#define LOG_EVERY_N(severity, n) LOG(severity) + +#endif // DMLC_GLOG_DEFINED + +class DateLogger { + public: + DateLogger() { +#if defined(_MSC_VER) + _tzset(); +#endif + } + const char* HumanDate() { +#ifndef _LIBCPP_SGX_CONFIG +#if defined(_MSC_VER) + _strtime_s(buffer_, sizeof(buffer_)); +#else + time_t time_value = time(NULL); + struct tm *pnow; +#if !defined(_WIN32) + struct tm now; + pnow = localtime_r(&time_value, &now); +#else + pnow = localtime(&time_value); // NOLINT(*) +#endif + snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d", + pnow->tm_hour, pnow->tm_min, pnow->tm_sec); +#endif +#endif // _LIBCPP_SGX_CONFIG + return buffer_; + } + + private: + char buffer_[9]; +}; + +#ifndef _LIBCPP_SGX_NO_IOSTREAMS +class LogMessage { + public: + LogMessage(const char* file, int line) + : +#ifdef __ANDROID__ + log_stream_(std::cout) +#else + log_stream_(std::cerr) +#endif + { + log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" + << line << ": "; + } + ~LogMessage() { log_stream_ << '\n'; } + std::ostream& stream() { return log_stream_; } + + protected: + std::ostream& log_stream_; + + private: + DateLogger pretty_date_; + LogMessage(const LogMessage&); + void operator=(const LogMessage&); +}; + +// customized logger that can allow user to define where to log the message. +class CustomLogMessage { + public: + CustomLogMessage(const char* file, int line) { + log_stream_ << "[" << DateLogger().HumanDate() << "] " << file << ":" + << line << ": "; + } + ~CustomLogMessage() { + Log(log_stream_.str()); + } + std::ostream& stream() { return log_stream_; } + /*! + * \brief customized logging of the message. + * This function won't be implemented by libdmlc + * \param msg The message to be logged. + */ + static void Log(const std::string& msg); + + private: + std::ostringstream log_stream_; +}; +#else +class DummyOStream { + public: + template + DummyOStream& operator<<(T _) { return *this; } + inline std::string str() { return ""; } +}; +class LogMessage { + public: + LogMessage(const char* file, int line) : log_stream_() {} + DummyOStream& stream() { return log_stream_; } + + protected: + DummyOStream log_stream_; + + private: + LogMessage(const LogMessage&); + void operator=(const LogMessage&); +}; +#endif + + + +#if DMLC_LOG_STACK_TRACE +inline std::string Demangle(char const *msg_str) { + using std::string; + string msg(msg_str); + size_t symbol_start = string::npos; + size_t symbol_end = string::npos; + if ( ((symbol_start = msg.find("_Z")) != string::npos) + && (symbol_end = msg.find_first_of(" +", symbol_start)) ) { + string left_of_symbol(msg, 0, symbol_start); + string symbol(msg, symbol_start, symbol_end - symbol_start); + string right_of_symbol(msg, symbol_end); + + int status = 0; + size_t length = string::npos; + std::unique_ptr demangled_symbol = + {abi::__cxa_demangle(symbol.c_str(), 0, &length, &status), &std::free}; + if (demangled_symbol && status == 0 && length > 0) { + string symbol_str(demangled_symbol.get()); + std::ostringstream os; + os << left_of_symbol << symbol_str << right_of_symbol; + return os.str(); + } + } + return string(msg_str); +} + +inline std::string StackTrace() { + using std::string; + std::ostringstream stacktrace_os; + const int MAX_STACK_SIZE = DMLC_LOG_STACK_TRACE_SIZE; + void *stack[MAX_STACK_SIZE]; + int nframes = backtrace(stack, MAX_STACK_SIZE); + stacktrace_os << "Stack trace returned " << nframes << " entries:" << std::endl; + char **msgs = backtrace_symbols(stack, nframes); + if (msgs != nullptr) { + for (int frameno = 0; frameno < nframes; ++frameno) { + string msg = dmlc::Demangle(msgs[frameno]); + stacktrace_os << "[bt] (" << frameno << ") " << msg << "\n"; + } + } + free(msgs); + string stack_trace = stacktrace_os.str(); + return stack_trace; +} + +#else // DMLC_LOG_STACK_TRACE is off + +inline std::string demangle(char const* msg_str) { + return std::string(); +} + +inline std::string StackTrace() { + return std::string("stack traces not available when " + "DMLC_LOG_STACK_TRACE is disabled at compile time."); +} + +#endif // DMLC_LOG_STACK_TRACE + +#if defined(_LIBCPP_SGX_NO_IOSTREAMS) +class LogMessageFatal : public LogMessage { + public: + LogMessageFatal(const char* file, int line) : LogMessage(file, line) {} + ~LogMessageFatal() { + abort(); + } + private: + LogMessageFatal(const LogMessageFatal&); + void operator=(const LogMessageFatal&); +}; +#elif DMLC_LOG_FATAL_THROW == 0 +class LogMessageFatal : public LogMessage { + public: + LogMessageFatal(const char* file, int line) : LogMessage(file, line) {} + ~LogMessageFatal() { + log_stream_ << "\n\n" << StackTrace() << "\n"; + abort(); + } + + private: + LogMessageFatal(const LogMessageFatal&); + void operator=(const LogMessageFatal&); +}; +#else +class LogMessageFatal { + public: + LogMessageFatal(const char* file, int line) { + log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" + << line << ": "; + } + std::ostringstream &stream() { return log_stream_; } + ~LogMessageFatal() DMLC_THROW_EXCEPTION { +#if DMLC_LOG_STACK_TRACE + log_stream_ << "\n\n" << StackTrace() << "\n"; +#endif + + // throwing out of destructor is evil + // hopefully we can do it here + // also log the message before throw +#if DMLC_LOG_BEFORE_THROW + LOG(ERROR) << log_stream_.str(); +#endif + throw Error(log_stream_.str()); + } + + private: + std::ostringstream log_stream_; + DateLogger pretty_date_; + LogMessageFatal(const LogMessageFatal&); + void operator=(const LogMessageFatal&); +}; +#endif + +// This class is used to explicitly ignore values in the conditional +// logging macros. This avoids compiler warnings like "value computed +// is not used" and "statement has no effect". +class LogMessageVoidify { + public: + LogMessageVoidify() {} + // This has to be an operator with a precedence lower than << but + // higher than "?:". See its usage. +#if !defined(_LIBCPP_SGX_NO_IOSTREAMS) + void operator&(std::ostream&) {} +#endif +}; + +} // namespace dmlc + +#endif +#endif // DMLC_LOGGING_H_ diff --git a/include/dmlc/lua.h b/include/dmlc/lua.h new file mode 100644 index 000000000000..13aa7b73d269 --- /dev/null +++ b/include/dmlc/lua.h @@ -0,0 +1,739 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file lua.h + * \brief C++11 header only interface to easily interact with Lua and Torch. + * This code is evolved from torch plugin code for MXNet. + * + * This header will require Torch and Lua to be presented, do not include. + * + * \author Junyuan Xie, Min Lin, Tianqi Chen + * + * \code + * + * // Example code to use the lua module. + * dmlc::LuaState* lua = dmlc::LuaState::ThreadLocalState(); + * // vectors converts automatically to lua table. + * auto tbl = lua->Convert(std::vector{1,2,3}); + * // use eval to get lua reference, this is a function + * auto print = lua->Eval("return function(x) print(x) end"); + * // lua function can be directly called from c++, arguments are converted. + * print(100); + * + * // set field in the table. + * tbl.SetField("square", lua->Eval("return function(x) x*x end")); + * // call the function, covert back to C++ values. + * int x = tbl["square"](100).Get(); + * + * \endcode + */ +#ifndef DMLC_LUA_H_ +#define DMLC_LUA_H_ + +extern "C" { +#include +#include +#include +} + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "./base.h" +#include "./logging.h" +#include "./thread_local.h" + +namespace dmlc { + +// forward declare torch state +class LuaState; + +namespace lua_stack { +template +struct Handler; +}; + +/*! \brief an reference to lua object */ +class LuaRef { + public: + /*! \brief construct an nil ref */ + LuaRef() = default; + /*! + * \brief move constructor from another LuaRef + * \param other The other LuaRef to be moved + */ + inline LuaRef(LuaRef&& other); // NOLINT(*) + /*! + * \brief copy constructor + * \param other The other LuaRef to be copied + */ + inline LuaRef(const LuaRef& other); // NOLINT(*) + /*! + * \brief assign operator from other + * \param other The other LuaRef to be copy or moved. + * \return self + */ + inline LuaRef& operator=(LuaRef&& other); + /*! + * \brief assign operator from other + * \param other The other LuaRef to be copy or moved. + * \return self + */ + inline LuaRef& operator=(const LuaRef& other); + /*! \brief destructor */ + inline ~LuaRef(); + /*! + * \brief swap content with another ref + * \param other another LuaRef to be swaped. + */ + inline void swap(LuaRef& other); // NOLINT(*) + /*! + * \brief Get content out as type T. + * + * \tparam T the type to be fetched. + * \return the corresponding c type. + */ + template + inline T Get() const; + /*! + * \brief Get user data pointer from LuaRef + * + * CAREFUL when getting userdata(e.g. pointer to Tensor's storage) from LuaRef. + * Remember they are managed by Lua, and can get deleted when all the + * LuaRef to the userdata destructs. A good practice is always use a LuaRef to keep + * the userdata alive when you need them from C++ side. + * + * \tparam T the type of pointer to be fetched. + * \return the corresponding c type. + */ + template + inline T* GetUDataPtr() const; + /*! \return whether the value is nil */ + inline bool is_nil() const; + /*! + * \brief invoke the LuaRef as function + * \param args Arguments to be passed. + * \tparam Args arguments to be passed. + * \return The first return value. + */ + template + inline LuaRef operator()(Args&& ...args) const; + /*! + * \brief Get field from the lua table. + * The reference must be a table + * \param key The key to the table + * \return a new ref to the corresponding field. + */ + inline LuaRef operator[](const std::string& key) const; + /*! + * \brief Get field from the lua array + * The reference must be a array + * \param index The index to the array, + * Note: the index convention follows lua table, starts from 1 + * \return a new ref to the corresponding field. + */ + inline LuaRef operator[](size_t index) const; + /*! + * \brief Set field of lua table. + * The reference must be a table + * \param key The key to the table + * \param value Lua convertable value to be setted. + * \return self. + */ + template + inline LuaRef& SetField(const std::string& key, const T& value); // NOLINT(*) + /*! + * \brief Set LuaRef to the value on top of the stack. + * This state must be nil. + * This is API used by developer. + * + * \param s the corresponding lua state. + */ + inline void SetByPopStack_(LuaState* s); + + private: + // friend with luastate + friend struct lua_stack::Handler; + friend class LuaState; + friend std::ostream &operator<<(std::ostream &os, const LuaRef &r); + /*! \brief pointer to the state */ + LuaState* state_{nullptr}; + /*! \brief reference index */ + int ref_; +}; + +/*! \brief A Lua state */ +class LuaState { + public: + /*! \brief options to be provided in lua state */ + enum Option { + kNoThreadProtect, + kThreadLocal, + kLocking, + }; + /*! \brief destructor */ + inline ~LuaState(); + /*! + * \brief evaluate a piece of lua code, return the first result. + * \param lua_code Lua code + * \return A LuaRef object of the first returned result, + * Can be nil if the code did not return LuaRefthing. + */ + inline LuaRef Eval(const char* lua_code); + /*! + * \brief evaluate a piece of lua code, return the first result. + * \param lua_code Lua code + * \return A LuaRef object of the first returned result, + * Can be nil if the code did not return anything. + */ + inline LuaRef Eval(const std::string& lua_code) { + return this->Eval(lua_code.c_str()); + } + /*! + * \brief convert a C++ type to lua type + * \param value The data to be converted. + * vector, map will be converted to table. + * \return a converted value. + * \tparam T the type to be converted. + */ + template + inline LuaRef Convert(const T& value); + /*! + * \brief get global field from the state + * \param key The key to the global field. + * \return The global field value. + */ + inline LuaRef operator[](const std::string& key); + /*! + * \brief Set the value to the global table. + * \param key The key of the global field. + * \param value The value to the set. + */ + inline void SetGlobalField(const std::string& key, const LuaRef& value); + /*! + * Get a thread local version of lua state. + * The LuaState runs in thread local mode, + * all the LuaRef can only be run on the current thread. + * This is the recommended behavior when invoking Lua. + * + * \return a threadlocal version of lua state. + */ + static inline LuaState* ThreadLocalState(); + /*! + * Create a new lua state. + * \note It is highly recommended to use ThreadLocalState instead. + * + * Most Lua program assumes it only runs from the same thread. + * Some Lua code that wraps C library(e.g. Torch) could rely + * on thread_local storage to store global state such as random number generator. + * This means if the code is invoked by another thread, the thread_local + * might become inavailable, depending on the implementation. + * + * If the global state is stored only in Lua's global table, then + * it is safe to use kLocking mode and call the code from multiple thread. + * Never-the-less, using ThreadLocalState removes the need to lock, + * and is the desirable usecase in most times. + * + * \sa ThreadLocalState + * \param option The option to use the state. + * \return a newly created lua state + */ + static inline LuaState* Create_(Option option); + + /*! + * \brief protected run f, this is used by API developers. + * always call this to access lua state + * f must not destruct LuaRef, or access the mutex + * + * \param f the function to be called. + * \tparam F the function to be called, signiture (lua_State *L) + */ + template + inline void PRun_(F f); + /*! + * \param L the other lua state. + * \return if the internal lua state is same as L + */ + inline bool SameLuaState(lua_State *L) const { + return L_ == L; + } + + protected: + struct StackReset; + friend class LuaRef; + friend struct ThreadLocalStore; + /*! + * \brief constructor + */ + inline LuaState(); + + /*! \brief internal option, default to thread local */ + Option option_{kThreadLocal}; + /*! \brief internal lua state */ + lua_State* L_; + /*! \brief internal lock about the state */ + std::mutex mutex_; +}; + +// implementations after this line +//! \cond Doxygen_Suppress +/*! \brief macro to check error during lua call */ +#define LUA_CALL(x) \ + if ((x)) { \ + LOG(FATAL) << "Lua Call Error:" << lua_tostring(L, -1); \ + } + +/*! + * \brief namespace to handle conversions between lua and c++ + * User can provide an specialization of dmlc::lua_stack::Handler + * to allow customized c++ data types to interact with Lua. + * + * By default basic data types, composition of vector, and unordered_map is supported. + * The conversion rules + * - basic types(string, int, float) to corresponding lua types. + * - unordered_map to Lua table. + * - vector to lua indexed table. + */ +namespace lua_stack { +inline int lua_abs_index(lua_State* L, int index) { + if (index > 0 || index <= LUA_REGISTRYINDEX) return index; + return lua_gettop(L) + index + 1; +} + +template +struct Handler; + +template +struct NumberHandler { + static inline T Get(lua_State* L, int index, LuaState* s) { + CHECK_EQ(lua_type(L, index), LUA_TNUMBER) + << "Attempt to get number but type is \'" + << lua_typename(L, lua_type(L, index)) << '\''; + if (std::is_integral::value) { + return static_cast(lua_tointeger(L, index)); + } else { + return static_cast(lua_tonumber(L, index)); + } + } + static inline void Push(lua_State* L, const T& v) { + if (std::is_integral::value) { + lua_pushinteger(L, static_cast(v)); + } else { + lua_pushnumber(L, static_cast(v)); + } + } +}; + +template +struct MapHandler { + using K = typename ContainerType::key_type; + using V = typename ContainerType::mapped_type; + static inline ContainerType Get(lua_State* L, int index, LuaState* s) { + ContainerType ret; + CHECK(lua_istable(L, index)) + << "Expected a table but get " + << lua_typename(L, lua_type(L, index)) << '\''; + int tid = lua_abs_index(L, index); + lua_pushnil(L); + while (lua_next(L, -2)) { + ret[Handler::Get(L, -2, s)] = Handler::Pop(L, -1, s); + lua_pop(L, 1); + } + lua_settop(L, tid); + return ret; + } + static inline void Push(lua_State* L, const ContainerType& v) { + lua_createtable(L, v.size(), 0); + for (const auto& kv : v) { + Handler::Push(L, kv.first); + Handler::Push(L, kv.second); + lua_settable(L, -3); + } + } +}; + +struct UndefinedHandler { +}; + +template +struct Handler + : public std::conditional::value, + NumberHandler, + UndefinedHandler>::type { +}; + +template<> +struct Handler { + static inline std::string Get(lua_State* L, int index, LuaState* s) { + CHECK_EQ(lua_type(L, index), LUA_TSTRING); + return std::string(lua_tostring(L, index)); + } + static inline void Push(lua_State* L, const std::string& v) { + lua_pushstring(L, v.c_str()); + } +}; + +template +struct Handler > { + static inline std::vector Get(lua_State* L, int index, LuaState* s) { + std::vector ret; + CHECK(lua_istable(L, index)) + << "Expected a table but get " + << lua_typename(L, lua_type(L, index)) << '\''; + int tid = lua_abs_index(L, index); + lua_pushnil(L); + while (lua_next(L, tid)) { + CHECK_EQ(Handler::Get(L, -2, s), ret.size() + 1) + << "Target table is not an array"; + ret.push_back(Handler::Get(L, -1, s)); + lua_pop(L, 1); + } + lua_settop(L, tid); + return ret; + } + static inline void Push(lua_State* L, const std::vector& v) { + lua_createtable(L, v.size(), 0); + for (size_t i = 0; i < v.size(); ++i) { + Handler::Push(L, v[i]); + lua_rawseti(L, -2, i + 1); + } + } +}; + +template +struct Handler > + : public MapHandler > { +}; + +template<> +struct Handler { + static inline LuaRef Get(lua_State* L, int index, LuaState* s) { + LuaRef ret; + lua_pushvalue(L, index); + ret.SetByPopStack_(s); + return ret; + } + + static inline void Push(lua_State* L, const LuaRef& v) { + if (v.is_nil()) { + lua_pushnil(L); + } else { + CHECK(v.state_->SameLuaState(L)) + << "Cannot pass LuaRef on a different LuaState's function"; + lua_rawgeti(L, LUA_REGISTRYINDEX, v.ref_); + } + } +}; + +template<> +struct Handler { + static inline LuaRef Get(lua_State* L, int index, LuaState* s) { + LOG(FATAL) << "not supported"; + return LuaRef(); + } + static inline void Push(lua_State* L, const std::nullptr_t& v) { + lua_pushnil(L); + } +}; + +// generic functor to call push the arguments. +struct PushArg { + lua_State* L; + template + inline void operator()(const T& v) const { + Handler::Push(L, v); + } +}; + +} // namespace lua_stack + +inline LuaState::LuaState() { + L_ = luaL_newstate(); + CHECK(L_ != nullptr) + << "Failed to create new lua state"; + luaL_openlibs(L_); +} + +inline LuaState::~LuaState() { + if (option_ != kThreadLocal && L_ != nullptr) { + // never close threadlocal, for save destruction. + lua_close(L_); + } +} + +inline LuaState* LuaState::Create_(Option opt) { + LuaState* s = new LuaState(); + s->option_ = opt; + CHECK_NE(opt, kThreadLocal) + << "use LuaState::ThreadLocalState() to get the thread local state"; + return s; +} + +inline void LuaRef::SetByPopStack_(LuaState* s) { + CHECK(state_ == nullptr); + lua_State* L = s->L_; + if (!lua_isnil(L, -1)) { + ref_ = lua_ref(L, LUA_REGISTRYINDEX); + state_ = s; + } else { + lua_pop(L, 1); + } +} + +// RAII guard to reset stack +struct LuaState::StackReset { + lua_State* L; + int top; + ~StackReset() { + lua_settop(L, top); + } +}; + +template +inline void LuaState::PRun_(F f) { + if (option_ != kLocking) { + StackReset reset{L_, lua_gettop(L_)}; + if (option_ == kThreadLocal) { + CHECK_EQ(ThreadLocalState(), this) + << "Invoke lua from a different thread in ThreadLocal mode."; + } + f(L_); + CHECK_EQ(reset.top, lua_gettop(L_)); + } else { + std::lock_guard lock(mutex_); + StackReset reset{L_, lua_gettop(L_)}; + f(L_); + CHECK_EQ(reset.top, lua_gettop(L_)); + } +} + +inline LuaState* LuaState::ThreadLocalState() { + return ThreadLocalStore::Get(); +} + +inline LuaRef LuaState::Eval(const char* lua_code) { + LuaRef ret; + this->PRun_([this, lua_code, &ret](lua_State* L) { + luaL_loadstring(L, lua_code); + CHECK_EQ(lua_pcall(L, 0, 1, 0), 0) + << "Lua call error: " << lua_tostring(L, -1) << '\n' + << "---------\n" + << lua_code + << "\n----------"; + ret.SetByPopStack_(this); + }); + return ret; +} + +template +inline LuaRef LuaState::Convert(const T& value) { + LuaRef ret; + this->PRun_([this, &value, &ret](lua_State* L) { + lua_stack::Handler::Push(L, value); + ret.SetByPopStack_(this); + }); + return ret; +} + +inline LuaRef LuaState::operator[](const std::string& key) { + LuaRef ret; + this->PRun_([this, &key, &ret](lua_State* L) { + lua_getglobal(L, key.c_str()); + ret.SetByPopStack_(this); + }); + return ret; +} + +inline void LuaState::SetGlobalField( + const std::string& key, const LuaRef& value) { + this->PRun_([this, &key, &value](lua_State* L) { + lua_rawgeti(L, LUA_REGISTRYINDEX, value.ref_); + lua_setglobal(L, key.c_str()); + }); +} + +inline LuaRef::LuaRef(const LuaRef& other) { + if (other.state_ != nullptr) { + state_ = other.state_; + state_->PRun_([this, &other](lua_State* L) { + lua_rawgeti(L, LUA_REGISTRYINDEX, other.ref_); + ref_ = luaL_ref(L, LUA_REGISTRYINDEX); + }); + } +} + +inline LuaRef::LuaRef(LuaRef&& other) { + ref_ = other.ref_; + state_ = other.state_; + other.state_ = nullptr; +} + +inline LuaRef& LuaRef::operator=(LuaRef&& other) { + LuaRef(std::move(other)).swap(*this); + return *this; +} + +inline LuaRef& LuaRef::operator=(const LuaRef& other) { + LuaRef(other).swap(*this); + return *this; +} + +inline void LuaRef::swap(LuaRef& other) { // NOLINT(*) + std::swap(state_, other.state_); + std::swap(ref_, other.ref_); +} + +inline LuaRef::~LuaRef() { + if (state_ != nullptr) { + state_->PRun_([this](lua_State* L) { + luaL_unref(L, LUA_REGISTRYINDEX, ref_); + }); + } +} + +inline bool LuaRef::is_nil() const { + return state_ == nullptr; +} + +std::ostream &operator<<(std::ostream &os, const LuaRef &r) { + if (!r.is_nil()) { + r.state_->PRun_([&os, &r](lua_State* L) { + lua_rawgeti(L, LUA_REGISTRYINDEX, r.ref_); + int type = lua_type(L, -1); + switch (type) { + case LUA_TSTRING: + os << "lua_string:'" << lua_tostring(L, -1) << "'"; break; + case LUA_TBOOLEAN: + os << "lua_bool:" << (lua_toboolean(L, -1) ? "true" : "false"); break; + case LUA_TNUMBER: + os << "lua_number:" << lua_tonumber(L, -1); break; + default: + os << "lua[ref=" << r.ref_ << ']' << lua_typename(L, type); break; + } + lua_pop(L, 1); + }); + } else { + os << "lua_nil"; + } + return os; +} + +template +inline T LuaRef::Get() const { + CHECK(state_ != nullptr) << "Get:: LuaRef is nil"; + T ret; + state_->PRun_([&ret, this](lua_State* L) { + lua_rawgeti(L, LUA_REGISTRYINDEX, ref_); + ret = lua_stack::Handler::Get(L, -1, state_); + lua_pop(L, 1); + }); + return ret; +} + +template +inline T* LuaRef::GetUDataPtr() const { + CHECK(state_ != nullptr) << "Get:: LuaRef is nil"; + T* ret; + state_->PRun_([&ret, this](lua_State* L) { + lua_rawgeti(L, LUA_REGISTRYINDEX, ref_); + ret = reinterpret_cast(lua_touserdata(L, -1)); + lua_pop(L, 1); + }); + return ret; +} + +// helper function to dispatch varg foreach +template +struct for_each_dispatcher_ { + static inline void run(const std::tuple& args, F f) { + f(std::get(args)); + for_each_dispatcher_<(I + 1) == sizeof...(Args), (I+1), F, Args...>::run(args, f); + } +}; +// helper function to run foreach +template +struct for_each_dispatcher_ { + static inline void run(const std::tuple& args, F f) { + } +}; + +// template function to iterate over tuples +template +inline void for_each(const std::tuple& args, F f) { + for_each_dispatcher_::run(args, f); +} + +template +inline LuaRef LuaRef::operator()(Args&& ...args) const { + CHECK(state_ != nullptr) << "LuaRef is nil"; + auto targ = std::make_tuple(std::forward(args)...); + size_t nargs = sizeof...(Args); + LuaRef ret; + state_->PRun_([this, nargs, &targ, &ret](lua_State* L) { + lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_); + CHECK(lua_isfunction(L, -1)) + << "Expect to invoke a function but type='" + << lua_typename(L, lua_type(L, -1)) << '\''; + for_each(targ, lua_stack::PushArg{L}); + LUA_CALL(lua_pcall(L, nargs, 1, 0)); + ret.SetByPopStack_(state_); + }); + return ret; +} + +template +inline LuaRef& LuaRef::SetField(const std::string& key, const T& value) { // NOLINT(*) + CHECK(state_ != nullptr) << "LuaRef is nil"; + state_->PRun_([this, &key, &value](lua_State* L) { + lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_); + CHECK(lua_istable(L, -1)) + << "Expect a table but type='" + << lua_typename(L, lua_type(L, -1)) << '\''; + lua_stack::Handler::Push(L, value); + lua_setfield(L, -2, key.c_str()); + lua_pop(L, 1); + }); + return *this; +} + +inline LuaRef LuaRef::operator[](const std::string& key) const { + CHECK(state_ != nullptr) << "LuaRef is nil"; + LuaRef ret; + state_->PRun_([this, &key, &ret](lua_State* L) { + lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_); + CHECK(lua_istable(L, -1)) + << "Expect a table but type='" + << lua_typename(L, lua_type(L, -1)) << '\''; + lua_getfield(L, -1, key.c_str()); + ret.SetByPopStack_(state_); + lua_pop(L, 1); + }); + return ret; +} + +inline LuaRef LuaRef::operator[](size_t index) const { + CHECK(state_ != nullptr) << "LuaRef is nil"; + LuaRef ret; + state_->PRun_([this, index, &ret](lua_State* L) { + lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_); + CHECK(lua_istable(L, -1)) + << "Expect a table but type='" + << lua_typename(L, lua_type(L, -1)) << '\''; + lua_rawgeti(L, -1, index); + ret.SetByPopStack_(state_); + lua_pop(L, 1); + }); + return ret; +} + +//! \endcond +} // namespace dmlc + +#endif // DMLC_LUA_H_ diff --git a/include/dmlc/memory.h b/include/dmlc/memory.h new file mode 100644 index 000000000000..3a2b9b07988f --- /dev/null +++ b/include/dmlc/memory.h @@ -0,0 +1,261 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file memory.h + * \brief Additional memory hanlding utilities. + */ +#ifndef DMLC_MEMORY_H_ +#define DMLC_MEMORY_H_ + +#include +#include "./base.h" +#include "./logging.h" +#include "./thread_local.h" + +namespace dmlc { + +/*! + * \brief A memory pool that allocate memory of fixed size and alignment. + * \tparam size The size of each piece. + * \tparam align The alignment requirement of the memory. + */ +template +class MemoryPool { + public: + /*! \brief constructor */ + MemoryPool() { + static_assert(align % alignof(LinkedList) == 0, + "alignment requirement failed."); + curr_page_.reset(new Page()); + } + /*! \brief allocate a new memory of size */ + inline void* allocate() { + if (head_ != nullptr) { + LinkedList* ret = head_; + head_ = head_->next; + return ret; + } else { + if (page_ptr_ < kPageSize) { + return &(curr_page_->data[page_ptr_++]); + } else { + allocated_.push_back(std::move(curr_page_)); + curr_page_.reset(new Page()); + page_ptr_ = 1; + return &(curr_page_->data[0]); + } + } + } + /*! + * \brief deallocate a piece of memory + * \param p The pointer to the memory to be de-allocated. + */ + inline void deallocate(void* p) { + LinkedList* ptr = static_cast(p); + ptr->next = head_; + head_ = ptr; + } + + private: + // page size of each member + static const int kPageSize = ((1 << 22) / size); + // page to be requested. + struct Page { + typename std::aligned_storage::type data[kPageSize]; + }; + // internal linked list structure. + struct LinkedList { + LinkedList* next{nullptr}; + }; + // head of free list + LinkedList* head_{nullptr}; + // current free page + std::unique_ptr curr_page_; + // pointer to the current free page position. + size_t page_ptr_{0}; + // allocated pages. + std::vector > allocated_; +}; + + +/*! + * \brief A thread local allocator that get memory from a threadlocal memory pool. + * This is suitable to allocate objects that do not cross thread. + * \tparam T the type of the data to be allocated. + */ +template +class ThreadlocalAllocator { + public: + /*! \brief pointer type */ + typedef T* pointer; + /*! \brief const pointer type */ + typedef const T* const_ptr; + /*! \brief value type */ + typedef T value_type; + /*! \brief default constructor */ + ThreadlocalAllocator() {} + /*! + * \brief constructor from another allocator + * \param other another allocator + * \tparam U another type + */ + template + ThreadlocalAllocator(const ThreadlocalAllocator& other) {} + /*! + * \brief allocate memory + * \param n number of blocks + * \return an uninitialized memory of type T. + */ + inline T* allocate(size_t n) { + CHECK_EQ(n, 1); + typedef ThreadLocalStore > Store; + return static_cast(Store::Get()->allocate()); + } + /*! + * \brief deallocate memory + * \param p a memory to be returned. + * \param n number of blocks + */ + inline void deallocate(T* p, size_t n) { + CHECK_EQ(n, 1); + typedef ThreadLocalStore > Store; + Store::Get()->deallocate(p); + } +}; + + +/*! + * \brief a shared pointer like type that allocate object + * from a threadlocal object pool. This object is not thread-safe + * but can be faster than shared_ptr in certain usecases. + * \tparam T the data type. + */ +template +struct ThreadlocalSharedPtr { + public: + /*! \brief default constructor */ + ThreadlocalSharedPtr() : block_(nullptr) {} + /*! + * \brief constructor from nullptr + * \param other the nullptr type + */ + ThreadlocalSharedPtr(std::nullptr_t other) : block_(nullptr) {} // NOLINT(*) + /*! + * \brief copy constructor + * \param other another pointer. + */ + ThreadlocalSharedPtr(const ThreadlocalSharedPtr& other) + : block_(other.block_) { + IncRef(block_); + } + /*! + * \brief move constructor + * \param other another pointer. + */ + ThreadlocalSharedPtr(ThreadlocalSharedPtr&& other) + : block_(other.block_) { + other.block_ = nullptr; + } + /*! + * \brief destructor + */ + ~ThreadlocalSharedPtr() { + DecRef(block_); + } + /*! + * \brief move assignment + * \param other another object to be assigned. + * \return self. + */ + inline ThreadlocalSharedPtr& operator=(ThreadlocalSharedPtr&& other) { + DecRef(block_); + block_ = other.block_; + other.block_ = nullptr; + return *this; + } + /*! + * \brief copy assignment + * \param other another object to be assigned. + * \return self. + */ + inline ThreadlocalSharedPtr &operator=(const ThreadlocalSharedPtr& other) { + DecRef(block_); + block_ = other.block_; + IncRef(block_); + return *this; + } + /*! \brief check if nullptr */ + inline bool operator==(std::nullptr_t other) const { + return block_ == nullptr; + } + /*! + * \return get the pointer content. + */ + inline T* get() const { + if (block_ == nullptr) return nullptr; + return reinterpret_cast(&(block_->data)); + } + /*! + * \brief reset the pointer to nullptr. + */ + inline void reset() { + DecRef(block_); + block_ = nullptr; + } + /*! \return if use_count == 1*/ + inline bool unique() const { + if (block_ == nullptr) return false; + return block_->use_count_ == 1; + } + /*! \return dereference pointer */ + inline T* operator*() const { + return reinterpret_cast(&(block_->data)); + } + /*! \return dereference pointer */ + inline T* operator->() const { + return reinterpret_cast(&(block_->data)); + } + /*! + * \brief create a new space from threadlocal storage and return it. + * \tparam Args the arguments. + * \param args The input argument + * \return the allocated pointer. + */ + template + inline static ThreadlocalSharedPtr Create(Args&&... args) { + ThreadlocalAllocator arena; + ThreadlocalSharedPtr p; + p.block_ = arena.allocate(1); + p.block_->use_count_ = 1; + new (&(p.block_->data)) T(std::forward(args)...); + return p; + } + + private: + // internal reference block + struct RefBlock { + typename std::aligned_storage::type data; + unsigned use_count_; + }; + // decrease ref counter + inline static void DecRef(RefBlock* block) { + if (block != nullptr) { + if (--block->use_count_ == 0) { + ThreadlocalAllocator arena; + T* dptr = reinterpret_cast(&(block->data)); + dptr->~T(); + arena.deallocate(block, 1); + } + } + } + // increase ref counter + inline static void IncRef(RefBlock* block) { + if (block != nullptr) { + ++block->use_count_; + } + } + // internal block + RefBlock *block_; +}; + +} // namespace dmlc + +#endif // DMLC_MEMORY_H_ diff --git a/include/dmlc/memory_io.h b/include/dmlc/memory_io.h new file mode 100644 index 000000000000..4e807585cc31 --- /dev/null +++ b/include/dmlc/memory_io.h @@ -0,0 +1,105 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file memory_io.h + * \brief defines binary serialization class to serialize things into/from memory region. + */ +#ifndef DMLC_MEMORY_IO_H_ +#define DMLC_MEMORY_IO_H_ + +#include +#include +#include +#include "./base.h" +#include "./io.h" +#include "./logging.h" + +namespace dmlc { +/*! + * \brief A Stream that operates on fixed region of memory + * This class allows us to read/write from/to a fixed memory region. + */ +struct MemoryFixedSizeStream : public SeekStream { + public: + /*! + * \brief constructor + * \param p_buffer the head pointer of the memory region. + * \param buffer_size the size of the memorybuffer + */ + MemoryFixedSizeStream(void *p_buffer, size_t buffer_size) + : p_buffer_(reinterpret_cast(p_buffer)), + buffer_size_(buffer_size) { + curr_ptr_ = 0; + } + virtual size_t Read(void *ptr, size_t size) { + CHECK(curr_ptr_ + size <= buffer_size_); + size_t nread = std::min(buffer_size_ - curr_ptr_, size); + if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread); + curr_ptr_ += nread; + return nread; + } + virtual void Write(const void *ptr, size_t size) { + if (size == 0) return; + CHECK(curr_ptr_ + size <= buffer_size_); + std::memcpy(p_buffer_ + curr_ptr_, ptr, size); + curr_ptr_ += size; + } + virtual void Seek(size_t pos) { + curr_ptr_ = static_cast(pos); + } + virtual size_t Tell(void) { + return curr_ptr_; + } + + private: + /*! \brief in memory buffer */ + char *p_buffer_; + /*! \brief current pointer */ + size_t buffer_size_; + /*! \brief current pointer */ + size_t curr_ptr_; +}; // class MemoryFixedSizeStream + +/*! + * \brief A in memory stream that is backed by std::string. + * This class allows us to read/write from/to a std::string. + */ +struct MemoryStringStream : public dmlc::SeekStream { + public: + /*! + * \brief constructor + * \param p_buffer the pointer to the string. + */ + explicit MemoryStringStream(std::string *p_buffer) + : p_buffer_(p_buffer) { + curr_ptr_ = 0; + } + virtual size_t Read(void *ptr, size_t size) { + CHECK(curr_ptr_ <= p_buffer_->length()); + size_t nread = std::min(p_buffer_->length() - curr_ptr_, size); + if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread); + curr_ptr_ += nread; + return nread; + } + virtual void Write(const void *ptr, size_t size) { + if (size == 0) return; + if (curr_ptr_ + size > p_buffer_->length()) { + p_buffer_->resize(curr_ptr_+size); + } + std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size); + curr_ptr_ += size; + } + virtual void Seek(size_t pos) { + curr_ptr_ = static_cast(pos); + } + virtual size_t Tell(void) { + return curr_ptr_; + } + + private: + /*! \brief in memory buffer */ + std::string *p_buffer_; + /*! \brief current pointer */ + size_t curr_ptr_; +}; // class MemoryStringStream +} // namespace dmlc +#endif // DMLC_MEMORY_IO_H_ diff --git a/include/dmlc/omp.h b/include/dmlc/omp.h new file mode 100644 index 000000000000..8b8e506b5430 --- /dev/null +++ b/include/dmlc/omp.h @@ -0,0 +1,47 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file omp.h + * \brief header to handle OpenMP compatibility issues + */ +#ifndef DMLC_OMP_H_ +#define DMLC_OMP_H_ + + +#if defined(_OPENMP) +#include +#else + +#if defined(__ANDROID__) +#define __GOMP_NOTHROW +#elif defined(__cplusplus) +#define __GOMP_NOTHROW throw() +#else +#define __GOMP_NOTHROW __attribute__((__nothrow__)) +#endif + +//! \cond Doxygen_Suppress +#ifdef __cplusplus +extern "C" { +#endif +inline int omp_get_thread_num() __GOMP_NOTHROW { return 0; } +inline int omp_get_num_threads() __GOMP_NOTHROW { return 1; } +inline int omp_get_max_threads() __GOMP_NOTHROW { return 1; } +inline int omp_get_num_procs() __GOMP_NOTHROW { return 1; } +inline void omp_set_num_threads(int nthread) __GOMP_NOTHROW {} +#ifdef __cplusplus +} +#endif // __cplusplus +#endif // _OPENMP + +// loop variable used in openmp +namespace dmlc { +#ifdef _MSC_VER +typedef int omp_uint; +typedef long omp_ulong; // NOLINT(*) +#else +typedef unsigned omp_uint; +typedef unsigned long omp_ulong; // NOLINT(*) +#endif +//! \endcond +} // namespace dmlc +#endif // DMLC_OMP_H_ diff --git a/include/dmlc/optional.h b/include/dmlc/optional.h new file mode 100644 index 000000000000..dedbc7478102 --- /dev/null +++ b/include/dmlc/optional.h @@ -0,0 +1,261 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file optional.h + * \brief Container to hold optional data. + */ +#ifndef DMLC_OPTIONAL_H_ +#define DMLC_OPTIONAL_H_ + +#include +#include +#include +#include + +#include "./base.h" +#include "./common.h" +#include "./logging.h" +#include "./type_traits.h" + +namespace dmlc { + +/*! \brief dummy type for assign null to optional */ +struct nullopt_t { +#if defined(_MSC_VER) && _MSC_VER < 1900 + /*! \brief dummy constructor */ + explicit nullopt_t(int a) {} +#else + /*! \brief dummy constructor */ + constexpr nullopt_t(int a) {} +#endif +}; + +/*! Assign null to optional: optional x = nullopt; */ +constexpr const nullopt_t nullopt = nullopt_t(0); + +/*! + * \brief c++17 compatible optional class. + * + * At any time an optional instance either + * hold no value (string representation "None") + * or hold a value of type T. + */ +template +class optional { + public: + /*! \brief construct an optional object that contains no value */ + optional() : is_none(true) {} + /*! \brief construct an optional object with value */ + explicit optional(const T& value) { + is_none = false; + new (&val) T(value); + } + /*! \brief construct an optional object with another optional object */ + optional(const optional& other) { + is_none = other.is_none; + if (!is_none) { + new (&val) T(other.value()); + } + } + /*! \brief deconstructor */ + ~optional() { + if (!is_none) { + reinterpret_cast(&val)->~T(); + } + } + /*! \brief swap two optional */ + void swap(optional& other) { + std::swap(val, other.val); + std::swap(is_none, other.is_none); + } + /*! \brief set this object to hold value + * \param value the value to hold + * \return return self to support chain assignment + */ + optional& operator=(const T& value) { + (optional(value)).swap(*this); + return *this; + } + /*! \brief set this object to hold the same value with other + * \param other the other object + * \return return self to support chain assignment + */ + optional& operator=(const optional &other) { + (optional(other)).swap(*this); + return *this; + } + /*! \brief clear the value this object is holding. + * optional x = nullopt; + */ + optional& operator=(nullopt_t) { + (optional()).swap(*this); + return *this; + } + /*! \brief non-const dereference operator */ + T& operator*() { // NOLINT(*) + return *reinterpret_cast(&val); + } + /*! \brief const dereference operator */ + const T& operator*() const { + return *reinterpret_cast(&val); + } + /*! \brief equal comparison */ + bool operator==(const optional& other) const { + return this->is_none == other.is_none && + (this->is_none == true || this->value() == other.value()); + } + /*! \brief return the holded value. + * throws std::logic_error if holding no value + */ + const T& value() const { + if (is_none) { + throw std::logic_error("bad optional access"); + } + return *reinterpret_cast(&val); + } + /*! \brief whether this object is holding a value */ + explicit operator bool() const { return !is_none; } + /*! \brief whether this object is holding a value (alternate form). */ + bool has_value() const { return operator bool(); } + + private: + // whether this is none + bool is_none; + // on stack storage of value + typename std::aligned_storage::type val; +}; + +/*! \brief serialize an optional object to string. + * + * \code + * dmlc::optional x; + * std::cout << x; // None + * x = 0; + * std::cout << x; // 0 + * \endcode + * + * \param os output stream + * \param t source optional object + * \return output stream + */ +template +std::ostream &operator<<(std::ostream &os, const optional &t) { + if (t) { + os << *t; + } else { + os << "None"; + } + return os; +} + +/*! \brief parse a string object into optional + * + * \code + * dmlc::optional x; + * std::string s1 = "1"; + * std::istringstream is1(s1); + * s1 >> x; // x == optional(1) + * + * std::string s2 = "None"; + * std::istringstream is2(s2); + * s2 >> x; // x == optional() + * \endcode + * + * \param is input stream + * \param t target optional object + * \return input stream + */ +template +std::istream &operator>>(std::istream &is, optional &t) { + char buf[4]; + std::streampos origin = is.tellg(); + is.read(buf, 4); + if (is.fail() || buf[0] != 'N' || buf[1] != 'o' || + buf[2] != 'n' || buf[3] != 'e') { + is.clear(); + is.seekg(origin); + T x; + is >> x; + t = x; + if (std::is_integral::value && !is.eof() && is.peek() == 'L') is.get(); + } else { + t = nullopt; + } + return is; +} +/*! \brief specialization of '>>' istream parsing for optional + * + * Permits use of generic parameter FieldEntry class to create + * FieldEntry> without explicit specialization. + * + * \code + * dmlc::optional x; + * std::string s1 = "true"; + * std::istringstream is1(s1); + * s1 >> x; // x == optional(true) + * + * std::string s2 = "None"; + * std::istringstream is2(s2); + * s2 >> x; // x == optional() + * \endcode + * + * \param is input stream + * \param t target optional object + * \return input stream + */ +inline std::istream &operator>>(std::istream &is, optional &t) { + // Discard initial whitespace + while (isspace(is.peek())) + is.get(); + // Extract chars that might be valid into a separate string, stopping + // on whitespace or other non-alphanumerics such as ",)]". + std::string s; + while (isalnum(is.peek())) + s.push_back(is.get()); + + if (!is.fail()) { + std::transform(s.begin(), s.end(), s.begin(), ::tolower); + if (s == "1" || s == "true") + t = true; + else if (s == "0" || s == "false") + t = false; + else if (s == "none") + t = nullopt; + else + is.setstate(std::ios::failbit); + } + + return is; +} + +/*! \brief description for optional int */ +DMLC_DECLARE_TYPE_NAME(optional, "int or None"); +/*! \brief description for optional bool */ +DMLC_DECLARE_TYPE_NAME(optional, "boolean or None"); +/*! \brief description for optional float */ +DMLC_DECLARE_TYPE_NAME(optional, "float or None"); +/*! \brief description for optional double */ +DMLC_DECLARE_TYPE_NAME(optional, "double or None"); + +} // namespace dmlc + +namespace std { +/*! \brief std hash function for optional */ +template +struct hash > { + /*! + * \brief returns hash of the optional value. + * \param val value. + * \return hash code. + */ + size_t operator()(const dmlc::optional& val) const { + std::hash hash_bool; + size_t res = hash_bool(val.has_value()); + if (val.has_value()) { + res = dmlc::HashCombine(res, val.value()); + } + return res; + } +}; +} // namespace std + +#endif // DMLC_OPTIONAL_H_ diff --git a/include/dmlc/parameter.h b/include/dmlc/parameter.h new file mode 100644 index 000000000000..0830cb99cd19 --- /dev/null +++ b/include/dmlc/parameter.h @@ -0,0 +1,1065 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file parameter.h + * \brief Provide lightweight util to do parameter setup and checking. + */ +#ifndef DMLC_PARAMETER_H_ +#define DMLC_PARAMETER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "./base.h" +#include "./json.h" +#include "./logging.h" +#include "./type_traits.h" +#include "./optional.h" + +namespace dmlc { +// this file is backward compatible with non-c++11 +/*! \brief Error throwed by parameter checking */ +struct ParamError : public dmlc::Error { + /*! + * \brief constructor + * \param msg error message + */ + explicit ParamError(const std::string &msg) + : dmlc::Error(msg) {} +}; + +/*! + * \brief Get environment variable with default. + * \param key the name of environment variable. + * \param default_value the default value of environment vriable. + * \return The value received + */ +template +inline ValueType GetEnv(const char *key, + ValueType default_value); +/*! + * \brief Set environment variable. + * \param key the name of environment variable. + * \param value the new value for key. + * \return The value received + */ +template +inline void SetEnv(const char *key, + ValueType value); + +/*! \brief internal namespace for parameter manangement */ +namespace parameter { +// forward declare ParamManager +class ParamManager; +// forward declare FieldAccessEntry +class FieldAccessEntry; +// forward declare FieldEntry +template +class FieldEntry; +// forward declare ParamManagerSingleton +template +struct ParamManagerSingleton; + +/*! \brief option in parameter initialization */ +enum ParamInitOption { + /*! \brief allow unknown parameters */ + kAllowUnknown, + /*! \brief need to match exact parameters */ + kAllMatch, + /*! \brief allow unmatched hidden field with format __*__ */ + kAllowHidden +}; +} // namespace parameter +/*! + * \brief Information about a parameter field in string representations. + */ +struct ParamFieldInfo { + /*! \brief name of the field */ + std::string name; + /*! \brief type of the field in string format */ + std::string type; + /*! + * \brief detailed type information string + * This include the default value, enum constran and typename. + */ + std::string type_info_str; + /*! \brief detailed description of the type */ + std::string description; +}; + +/*! + * \brief Parameter is the base type every parameter struct should inheritate from + * The following code is a complete example to setup parameters. + * \code + * struct Param : public dmlc::Parameter { + * float learning_rate; + * int num_hidden; + * std::string name; + * // declare parameters in header file + * DMLC_DECLARE_PARAMETER(Param) { + * DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000); + * DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f); + * DMLC_DECLARE_FIELD(name).set_default("hello"); + * } + * }; + * // register it in cc file + * DMLC_REGISTER_PARAMETER(Param); + * \endcode + * + * After that, the Param struct will get all the functions defined in Parameter. + * \tparam PType the type of parameter struct + * + * \sa DMLC_DECLARE_FIELD, DMLC_REGISTER_PARAMETER, DMLC_DECLARE_PARAMETER + */ +template +struct Parameter { + public: + /*! + * \brief initialize the parameter by keyword arguments. + * This function will initialize the parameter struct, check consistency + * and throw error if something wrong happens. + * + * \param kwargs map of keyword arguments, or vector of pairs + * \parma option The option on initialization. + * \tparam Container container type + * \throw ParamError when something go wrong. + */ + template + inline void Init(const Container &kwargs, + parameter::ParamInitOption option = parameter::kAllowHidden) { + PType::__MANAGER__()->RunInit(static_cast(this), + kwargs.begin(), kwargs.end(), + NULL, + option); + } + /*! + * \brief initialize the parameter by keyword arguments. + * This is same as Init, but allow unknown arguments. + * + * \param kwargs map of keyword arguments, or vector of pairs + * \tparam Container container type + * \throw ParamError when something go wrong. + * \return vector of pairs of unknown arguments. + */ + template + inline std::vector > + InitAllowUnknown(const Container &kwargs) { + std::vector > unknown; + PType::__MANAGER__()->RunInit(static_cast(this), + kwargs.begin(), kwargs.end(), + &unknown, parameter::kAllowUnknown); + return unknown; + } + + /*! + * \brief Update the dict with values stored in parameter. + * + * \param dict The dictionary to be updated. + * \tparam Container container type + */ + template + inline void UpdateDict(Container *dict) const { + PType::__MANAGER__()->UpdateDict(this->head(), dict); + } + /*! + * \brief Return a dictionary representation of the parameters + * \return A dictionary that maps key -> value + */ + inline std::map __DICT__() const { + std::vector > vec + = PType::__MANAGER__()->GetDict(this->head()); + return std::map(vec.begin(), vec.end()); + } + /*! + * \brief Write the parameters in JSON format. + * \param writer JSONWriter used for writing. + */ + inline void Save(dmlc::JSONWriter *writer) const { + writer->Write(this->__DICT__()); + } + /*! + * \brief Load the parameters from JSON. + * \param reader JSONReader used for loading. + * \throw ParamError when something go wrong. + */ + inline void Load(dmlc::JSONReader *reader) { + std::map kwargs; + reader->Read(&kwargs); + this->Init(kwargs); + } + /*! + * \brief Get the fields of the parameters. + * \return List of ParamFieldInfo of each field. + */ + inline static std::vector __FIELDS__() { + return PType::__MANAGER__()->GetFieldInfo(); + } + /*! + * \brief Print docstring of the parameter + * \return the printed docstring + */ + inline static std::string __DOC__() { + std::ostringstream os; + PType::__MANAGER__()->PrintDocString(os); + return os.str(); + } + + protected: + /*! + * \brief internal function to allow declare of a parameter memember + * \param manager the parameter manager + * \param key the key name of the parameter + * \param ref the reference to the parameter in the struct. + */ + template + inline parameter::FieldEntry& DECLARE( + parameter::ParamManagerSingleton *manager, + const std::string &key, DType &ref) { // NOLINT(*) + parameter::FieldEntry *e = + new parameter::FieldEntry(); + e->Init(key, this->head(), ref); + manager->manager.AddEntry(key, e); + return *e; + } + + private: + /*! \return Get head pointer of child structure */ + inline PType *head() const { + return static_cast(const_cast*>(this)); + } +}; + +//! \cond Doxygen_Suppress +/*! + * \brief macro used to declare parameter + * + * Example: + * \code + * struct Param : public dmlc::Parameter { + * // declare parameters in header file + * DMLC_DECLARE_PARAMETER(Param) { + * // details of declarations + * } + * }; + * \endcode + * + * This macro need to be put in a source file so that registeration only happens once. + * Refer to example code in Parameter for details + * + * \param PType the name of parameter struct. + * \sa Parameter + */ +#define DMLC_DECLARE_PARAMETER(PType) \ + static ::dmlc::parameter::ParamManager *__MANAGER__(); \ + inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton *manager) \ + +/*! + * \brief macro to declare fields + * \param FieldName the name of the field. + */ +#define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName) + +/*! + * \brief macro to declare alias of a fields + * \param FieldName the name of the field. + * \param AliasName the name of the alias, must be declared after the field is declared. + */ +#define DMLC_DECLARE_ALIAS(FieldName, AliasName) manager->manager.AddAlias(#FieldName, #AliasName) + +/*! + * \brief Macro used to register parameter. + * + * This macro need to be put in a source file so that registeration only happens once. + * Refer to example code in Parameter for details + * \param PType the type of parameter struct. + * \sa Parameter + */ +#define DMLC_REGISTER_PARAMETER(PType) \ + ::dmlc::parameter::ParamManager *PType::__MANAGER__() { \ + static ::dmlc::parameter::ParamManagerSingleton inst(#PType); \ + return &inst.manager; \ + } \ + static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \ + __make__ ## PType ## ParamManager__ = \ + (*PType::__MANAGER__()) \ + +//! \endcond +/*! + * \brief internal namespace for parameter manangement + * There is no need to use it directly in normal case + */ +namespace parameter { +/*! + * \brief FieldAccessEntry interface to help manage the parameters + * Each entry can be used to access one parameter in the Parameter struct. + * + * This is an internal interface used that is used to manage parameters + */ +class FieldAccessEntry { + public: + FieldAccessEntry() + : has_default_(false) {} + /*! \brief destructor */ + virtual ~FieldAccessEntry() {} + /*! + * \brief set the default value. + * \param head the pointer to the head of the struct + * \throw error if no default is presented + */ + virtual void SetDefault(void *head) const = 0; + /*! + * \brief set the parameter by string value + * \param head the pointer to the head of the struct + * \param value the value to be set + */ + virtual void Set(void *head, const std::string &value) const = 0; + // check if value is OK + virtual void Check(void *head) const {} + /*! + * \brief get the string representation of value. + * \param head the pointer to the head of the struct + */ + virtual std::string GetStringValue(void *head) const = 0; + /*! + * \brief Get field information + * \return the corresponding field information + */ + virtual ParamFieldInfo GetFieldInfo() const = 0; + + protected: + /*! \brief whether this parameter have default value */ + bool has_default_; + /*! \brief positional index of parameter in struct */ + size_t index_; + /*! \brief parameter key name */ + std::string key_; + /*! \brief parameter type */ + std::string type_; + /*! \brief description of the parameter */ + std::string description_; + /*! + * \brief print string representation of default value + * \parma os the stream to print the docstring to. + */ + virtual void PrintDefaultValueString(std::ostream &os) const = 0; // NOLINT(*) + // allow ParamManager to modify self + friend class ParamManager; +}; + +/*! + * \brief manager class to handle parameter structure for each type + * An manager will be created for each parameter structure. + */ +class ParamManager { + public: + /*! \brief destructor */ + ~ParamManager() { + for (size_t i = 0; i < entry_.size(); ++i) { + delete entry_[i]; + } + } + /*! + * \brief find the access entry by parameter key + * \param key the key of the parameter. + * \return pointer to FieldAccessEntry, NULL if nothing is found. + */ + inline FieldAccessEntry *Find(const std::string &key) const { + std::map::const_iterator it = + entry_map_.find(key); + if (it == entry_map_.end()) return NULL; + return it->second; + } + /*! + * \brief set parameter by keyword arguments. + * \param head head to the parameter field. + * \param begin begin iterator of original kwargs + * \param end end iterator of original kwargs + * \param unknown_args optional, used to hold unknown arguments + * When it is specified, unknown arguments will be stored into here, instead of raise an error + * \tparam RandomAccessIterator iterator type + * \throw ParamError when there is unknown argument and unknown_args == NULL, or required argument is missing. + */ + template + inline void RunInit(void *head, + RandomAccessIterator begin, + RandomAccessIterator end, + std::vector > *unknown_args, + parameter::ParamInitOption option) const { + std::set selected_args; + for (RandomAccessIterator it = begin; it != end; ++it) { + FieldAccessEntry *e = Find(it->first); + if (e != NULL) { + e->Set(head, it->second); + e->Check(head); + selected_args.insert(e); + } else { + if (unknown_args != NULL) { + unknown_args->push_back(*it); + } else { + if (option != parameter::kAllowUnknown) { + if (option == parameter::kAllowHidden && + it->first.length() > 4 && + it->first.find("__") == 0 && + it->first.rfind("__") == it->first.length()-2) { + continue; + } + std::ostringstream os; + os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n"; + os << "----------------\n"; + PrintDocString(os); + throw dmlc::ParamError(os.str()); + } + } + } + } + + for (std::map::const_iterator it = entry_map_.begin(); + it != entry_map_.end(); ++it) { + if (selected_args.count(it->second) == 0) { + it->second->SetDefault(head); + } + } + } + /*! + * \brief internal function to add entry to manager, + * The manager will take ownership of the entry. + * \param key the key to the parameters + * \param e the pointer to the new entry. + */ + inline void AddEntry(const std::string &key, FieldAccessEntry *e) { + e->index_ = entry_.size(); + // TODO(bing) better error message + if (entry_map_.count(key) != 0) { + LOG(FATAL) << "key " << key << " has already been registered in " << name_; + } + entry_.push_back(e); + entry_map_[key] = e; + } + /*! + * \brief internal function to add entry to manager, + * The manager will take ownership of the entry. + * \param key the key to the parameters + * \param e the pointer to the new entry. + */ + inline void AddAlias(const std::string& field, const std::string& alias) { + if (entry_map_.count(field) == 0) { + LOG(FATAL) << "key " << field << " has not been registered in " << name_; + } + if (entry_map_.count(alias) != 0) { + LOG(FATAL) << "Alias " << alias << " has already been registered in " << name_; + } + entry_map_[alias] = entry_map_[field]; + } + /*! + * \brief set the name of parameter manager + * \param name the name to set + */ + inline void set_name(const std::string &name) { + name_ = name; + } + /*! + * \brief get field information of each field. + * \return field information + */ + inline std::vector GetFieldInfo() const { + std::vector ret(entry_.size()); + for (size_t i = 0; i < entry_.size(); ++i) { + ret[i] = entry_[i]->GetFieldInfo(); + } + return ret; + } + /*! + * \brief Print readible docstring to ostream, add newline. + * \parma os the stream to print the docstring to. + */ + inline void PrintDocString(std::ostream &os) const { // NOLINT(*) + for (size_t i = 0; i < entry_.size(); ++i) { + ParamFieldInfo info = entry_[i]->GetFieldInfo(); + os << info.name << " : " << info.type_info_str << '\n'; + if (info.description.length() != 0) { + os << " " << info.description << '\n'; + } + } + } + /*! + * \brief Get internal parameters in vector of pairs. + * \param head the head of the struct. + * \param skip_default skip the values that equals default value. + * \return the parameter dictionary. + */ + inline std::vector > GetDict(void * head) const { + std::vector > ret; + for (std::map::const_iterator + it = entry_map_.begin(); it != entry_map_.end(); ++it) { + ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head))); + } + return ret; + } + /*! + * \brief Update the dictionary with values in parameter. + * \param head the head of the struct. + * \tparam Container The container type + * \return the parameter dictionary. + */ + template + inline void UpdateDict(void * head, Container* dict) const { + for (std::map::const_iterator + it = entry_map_.begin(); it != entry_map_.end(); ++it) { + (*dict)[it->first] = it->second->GetStringValue(head); + } + } + + private: + /*! \brief parameter struct name */ + std::string name_; + /*! \brief positional list of entries */ + std::vector entry_; + /*! \brief map from key to entry */ + std::map entry_map_; +}; + +//! \cond Doxygen_Suppress + +// The following piece of code will be template heavy and less documented +// singleton parameter manager for certain type, used for initialization +template +struct ParamManagerSingleton { + ParamManager manager; + explicit ParamManagerSingleton(const std::string ¶m_name) { + PType param; + manager.set_name(param_name); + param.__DECLARE__(this); + } +}; + +// Base class of FieldEntry +// implement set_default +template +class FieldEntryBase : public FieldAccessEntry { + public: + // entry type + typedef TEntry EntryType; + // implement set value + virtual void Set(void *head, const std::string &value) const { + std::istringstream is(value); + is >> this->Get(head); + if (!is.fail()) { + while (!is.eof()) { + int ch = is.get(); + if (ch == EOF) { + is.clear(); break; + } + if (!isspace(ch)) { + is.setstate(std::ios::failbit); break; + } + } + } + + if (is.fail()) { + std::ostringstream os; + os << "Invalid Parameter format for " << key_ + << " expect " << type_ << " but value=\'" << value<< '\''; + throw dmlc::ParamError(os.str()); + } + } + virtual std::string GetStringValue(void *head) const { + std::ostringstream os; + PrintValue(os, this->Get(head)); + return os.str(); + } + virtual ParamFieldInfo GetFieldInfo() const { + ParamFieldInfo info; + std::ostringstream os; + info.name = key_; + info.type = type_; + os << type_; + if (has_default_) { + os << ',' << " optional, default="; + PrintDefaultValueString(os); + } else { + os << ", required"; + } + info.type_info_str = os.str(); + info.description = description_; + return info; + } + // implement set head to default value + virtual void SetDefault(void *head) const { + if (!has_default_) { + std::ostringstream os; + os << "Required parameter " << key_ + << " of " << type_ << " is not presented"; + throw dmlc::ParamError(os.str()); + } else { + this->Get(head) = default_value_; + } + } + // return reference of self as derived type + inline TEntry &self() { + return *(static_cast(this)); + } + // implement set_default + inline TEntry &set_default(const DType &default_value) { + default_value_ = default_value; + has_default_ = true; + // return self to allow chaining + return this->self(); + } + // implement describe + inline TEntry &describe(const std::string &description) { + description_ = description; + // return self to allow chaining + return this->self(); + } + // initialization function + inline void Init(const std::string &key, + void *head, DType &ref) { // NOLINT(*) + this->key_ = key; + if (this->type_.length() == 0) { + this->type_ = dmlc::type_name(); + } + this->offset_ = ((char*)&ref) - ((char*)head); // NOLINT(*) + } + + protected: + // print the value + virtual void PrintValue(std::ostream &os, DType value) const { // NOLINT(*) + os << value; + } + virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) + PrintValue(os, default_value_); + } + // get the internal representation of parameter + // for example if this entry corresponds field param.learning_rate + // then Get(¶m) will return reference to param.learning_rate + inline DType &Get(void *head) const { + return *(DType*)((char*)(head) + offset_); // NOLINT(*) + } + // internal offset of the field + ptrdiff_t offset_; + // default value of field + DType default_value_; +}; + +// parameter base for numeric types that have range +template +class FieldEntryNumeric + : public FieldEntryBase { + public: + FieldEntryNumeric() + : has_begin_(false), has_end_(false) {} + // implement set_range + virtual TEntry &set_range(DType begin, DType end) { + begin_ = begin; end_ = end; + has_begin_ = true; has_end_ = true; + return this->self(); + } + // implement set_range + virtual TEntry &set_lower_bound(DType begin) { + begin_ = begin; has_begin_ = true; + return this->self(); + } + // consistency check for numeric ranges + virtual void Check(void *head) const { + FieldEntryBase::Check(head); + DType v = this->Get(head); + if (has_begin_ && has_end_) { + if (v < begin_ || v > end_) { + std::ostringstream os; + os << "value " << v << " for Parameter " << this->key_ + << " exceed bound [" << begin_ << ',' << end_ <<']'; + throw dmlc::ParamError(os.str()); + } + } else if (has_begin_ && v < begin_) { + std::ostringstream os; + os << "value " << v << " for Parameter " << this->key_ + << " should be greater equal to " << begin_; + throw dmlc::ParamError(os.str()); + } else if (has_end_ && v > end_) { + std::ostringstream os; + os << "value " << v << " for Parameter " << this->key_ + << " should be smaller equal to " << end_; + throw dmlc::ParamError(os.str()); + } + } + + protected: + // whether it have begin and end range + bool has_begin_, has_end_; + // data bound + DType begin_, end_; +}; + +/*! + * \brief FieldEntry defines parsing and checking behavior of DType. + * This class can be specialized to implement specific behavior of more settings. + * \tparam DType the data type of the entry. + */ +template +class FieldEntry : + public IfThenElseType::value, + FieldEntryNumeric, DType>, + FieldEntryBase, DType> >::Type { +}; + +// specialize define for int(enum) +template<> +class FieldEntry + : public FieldEntryNumeric, int> { + public: + // construct + FieldEntry() : is_enum_(false) {} + // parent + typedef FieldEntryNumeric, int> Parent; + // override set + virtual void Set(void *head, const std::string &value) const { + if (is_enum_) { + std::map::const_iterator it = enum_map_.find(value); + std::ostringstream os; + if (it == enum_map_.end()) { + os << "Invalid Input: \'" << value; + os << "\', valid values are: "; + PrintEnums(os); + throw dmlc::ParamError(os.str()); + } else { + os << it->second; + Parent::Set(head, os.str()); + } + } else { + Parent::Set(head, value); + } + } + virtual ParamFieldInfo GetFieldInfo() const { + if (is_enum_) { + ParamFieldInfo info; + std::ostringstream os; + info.name = key_; + info.type = type_; + PrintEnums(os); + if (has_default_) { + os << ',' << "optional, default="; + PrintDefaultValueString(os); + } else { + os << ", required"; + } + info.type_info_str = os.str(); + info.description = description_; + return info; + } else { + return Parent::GetFieldInfo(); + } + } + // add enum + inline FieldEntry &add_enum(const std::string &key, int value) { + if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \ + enum_back_map_.count(value) != 0) { + std::ostringstream os; + os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n"; + os << "Enums: "; + for (std::map::const_iterator it = enum_map_.begin(); + it != enum_map_.end(); ++it) { + os << "(" << it->first << ": " << it->second << "), "; + } + throw dmlc::ParamError(os.str()); + } + enum_map_[key] = value; + enum_back_map_[value] = key; + is_enum_ = true; + return this->self(); + } + + protected: + // enum flag + bool is_enum_; + // enum map + std::map enum_map_; + // enum map + std::map enum_back_map_; + // override print behavior + virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) + os << '\''; + PrintValue(os, default_value_); + os << '\''; + } + // override print default + virtual void PrintValue(std::ostream &os, int value) const { // NOLINT(*) + if (is_enum_) { + CHECK_NE(enum_back_map_.count(value), 0U) + << "Value not found in enum declared"; + os << enum_back_map_.at(value); + } else { + os << value; + } + } + + + private: + inline void PrintEnums(std::ostream &os) const { // NOLINT(*) + os << '{'; + for (std::map::const_iterator + it = enum_map_.begin(); it != enum_map_.end(); ++it) { + if (it != enum_map_.begin()) { + os << ", "; + } + os << "\'" << it->first << '\''; + } + os << '}'; + } +}; + + +// specialize define for optional(enum) +template<> +class FieldEntry > + : public FieldEntryBase >, optional > { + public: + // construct + FieldEntry >() : is_enum_(false) {} + // parent + typedef FieldEntryBase >, optional > Parent; + // override set + virtual void Set(void *head, const std::string &value) const { + if (is_enum_ && value != "None") { + std::map::const_iterator it = enum_map_.find(value); + std::ostringstream os; + if (it == enum_map_.end()) { + os << "Invalid Input: \'" << value; + os << "\', valid values are: "; + PrintEnums(os); + throw dmlc::ParamError(os.str()); + } else { + os << it->second; + Parent::Set(head, os.str()); + } + } else { + Parent::Set(head, value); + } + } + virtual ParamFieldInfo GetFieldInfo() const { + if (is_enum_) { + ParamFieldInfo info; + std::ostringstream os; + info.name = key_; + info.type = type_; + PrintEnums(os); + if (has_default_) { + os << ',' << "optional, default="; + PrintDefaultValueString(os); + } else { + os << ", required"; + } + info.type_info_str = os.str(); + info.description = description_; + return info; + } else { + return Parent::GetFieldInfo(); + } + } + // add enum + inline FieldEntry > &add_enum(const std::string &key, int value) { + CHECK_NE(key, "None") << "None is reserved for empty optional"; + if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \ + enum_back_map_.count(value) != 0) { + std::ostringstream os; + os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n"; + os << "Enums: "; + for (std::map::const_iterator it = enum_map_.begin(); + it != enum_map_.end(); ++it) { + os << "(" << it->first << ": " << it->second << "), "; + } + throw dmlc::ParamError(os.str()); + } + enum_map_[key] = value; + enum_back_map_[value] = key; + is_enum_ = true; + return this->self(); + } + + protected: + // enum flag + bool is_enum_; + // enum map + std::map enum_map_; + // enum map + std::map enum_back_map_; + // override print behavior + virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) + os << '\''; + PrintValue(os, default_value_); + os << '\''; + } + // override print default + virtual void PrintValue(std::ostream &os, optional value) const { // NOLINT(*) + if (is_enum_) { + if (!value) { + os << "None"; + } else { + CHECK_NE(enum_back_map_.count(value.value()), 0U) + << "Value not found in enum declared"; + os << enum_back_map_.at(value.value()); + } + } else { + os << value; + } + } + + + private: + inline void PrintEnums(std::ostream &os) const { // NOLINT(*) + os << "{None"; + for (std::map::const_iterator + it = enum_map_.begin(); it != enum_map_.end(); ++it) { + os << ", "; + os << "\'" << it->first << '\''; + } + os << '}'; + } +}; + +// specialize define for string +template<> +class FieldEntry + : public FieldEntryBase, std::string> { + public: + // parent class + typedef FieldEntryBase, std::string> Parent; + // override set + virtual void Set(void *head, const std::string &value) const { + this->Get(head) = value; + } + // override print default + virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) + os << '\'' << default_value_ << '\''; + } +}; + +// specialize define for bool +template<> +class FieldEntry + : public FieldEntryBase, bool> { + public: + // parent class + typedef FieldEntryBase, bool> Parent; + // override set + virtual void Set(void *head, const std::string &value) const { + std::string lower_case; lower_case.resize(value.length()); + std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower); + bool &ref = this->Get(head); + if (lower_case == "true") { + ref = true; + } else if (lower_case == "false") { + ref = false; + } else if (lower_case == "1") { + ref = true; + } else if (lower_case == "0") { + ref = false; + } else { + std::ostringstream os; + os << "Invalid Parameter format for " << key_ + << " expect " << type_ << " but value=\'" << value<< '\''; + throw dmlc::ParamError(os.str()); + } + } + + protected: + // print default string + virtual void PrintValue(std::ostream &os, bool value) const { // NOLINT(*) + os << static_cast(value); + } +}; + + +// specialize define for float. Uses stof for platform independent handling of +// INF, -INF, NAN, etc. +#if DMLC_USE_CXX11 +template <> +class FieldEntry : public FieldEntryNumeric, float> { + public: + // parent + typedef FieldEntryNumeric, float> Parent; + // override set + virtual void Set(void *head, const std::string &value) const { + try { + this->Get(head) = std::stof(value); + } catch (const std::invalid_argument &) { + std::ostringstream os; + os << "Invalid Parameter format for " << key_ << " expect " << type_ + << " but value=\'" << value << '\''; + throw dmlc::ParamError(os.str()); + } catch (const std::out_of_range&) { + std::ostringstream os; + os << "Out of range value for " << key_ << ", value=\'" << value << '\''; + throw dmlc::ParamError(os.str()); + } + } +}; + +// specialize define for double. Uses stod for platform independent handling of +// INF, -INF, NAN, etc. +template <> +class FieldEntry + : public FieldEntryNumeric, double> { + public: + // parent + typedef FieldEntryNumeric, double> Parent; + // override set + virtual void Set(void *head, const std::string &value) const { + try { + this->Get(head) = std::stod(value); + } catch (const std::invalid_argument &) { + std::ostringstream os; + os << "Invalid Parameter format for " << key_ << " expect " << type_ + << " but value=\'" << value << '\''; + throw dmlc::ParamError(os.str()); + } catch (const std::out_of_range&) { + std::ostringstream os; + os << "Out of range value for " << key_ << ", value=\'" << value << '\''; + throw dmlc::ParamError(os.str()); + } + } +}; +#endif // DMLC_USE_CXX11 + +} // namespace parameter +//! \endcond + +// implement GetEnv +template +inline ValueType GetEnv(const char *key, + ValueType default_value) { + const char *val = getenv(key); + // On some implementations, if the var is set to a blank string (i.e. "FOO="), then + // a blank string will be returned instead of NULL. In order to be consistent, if + // the environment var is a blank string, then also behave as if a null was returned. + if (val == nullptr || !*val) { + return default_value; + } + ValueType ret; + parameter::FieldEntry e; + e.Init(key, &ret, ret); + e.Set(&ret, val); + return ret; +} + +// implement SetEnv +template +inline void SetEnv(const char *key, + ValueType value) { + parameter::FieldEntry e; + e.Init(key, &value, value); +#ifdef _WIN32 + _putenv(key, e.GetStringValue(&value).c_str()); +#else + setenv(key, e.GetStringValue(&value).c_str(), 1); +#endif // _WIN32 +} +} // namespace dmlc +#endif // DMLC_PARAMETER_H_ diff --git a/include/dmlc/recordio.h b/include/dmlc/recordio.h new file mode 100644 index 000000000000..6220780acadc --- /dev/null +++ b/include/dmlc/recordio.h @@ -0,0 +1,196 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file recordio.h + * \brief recordio that is able to pack binary data into a splittable + * format, useful to exchange data in binary serialization, + * such as binary raw data or protobuf + */ +#ifndef DMLC_RECORDIO_H_ +#define DMLC_RECORDIO_H_ +#include +#include +#include "./io.h" +#include "./logging.h" + +namespace dmlc { +/*! + * \brief writer of binary recordio + * binary format for recordio + * recordio format: magic lrecord data pad + * + * - magic is magic number + * - pad is simply a padding space to make record align to 4 bytes + * - lrecord encodes length and continue bit + * - data.length() = (lrecord & (1U<<29U - 1)); + * - cflag == (lrecord >> 29U) & 7; + * + * cflag was used to handle (rare) special case when magic number + * occured in the data sequence. + * + * In such case, the data is splitted into multiple records by + * the cells of magic number + * + * (1) cflag == 0: this is a complete record; + * (2) cflag == 1: start of a multiple-rec; + * cflag == 2: middle of multiple-rec; + * cflag == 3: end of multiple-rec + */ +class RecordIOWriter { + public: + /*! + * \brief magic number of recordio + * note: (kMagic >> 29U) & 7 > 3 + * this ensures lrec will not be kMagic + */ + static const uint32_t kMagic = 0xced7230a; + /*! + * \brief encode the lrecord + * \param cflag cflag part of the lrecord + * \param length length part of lrecord + * \return the encoded data + */ + inline static uint32_t EncodeLRec(uint32_t cflag, uint32_t length) { + return (cflag << 29U) | length; + } + /*! + * \brief decode the flag part of lrecord + * \param rec the lrecord + * \return the flag + */ + inline static uint32_t DecodeFlag(uint32_t rec) { + return (rec >> 29U) & 7U; + } + /*! + * \brief decode the length part of lrecord + * \param rec the lrecord + * \return the length + */ + inline static uint32_t DecodeLength(uint32_t rec) { + return rec & ((1U << 29U) - 1U); + } + /*! + * \brief constructor + * \param stream the stream to be constructed + */ + explicit RecordIOWriter(Stream *stream) + : stream_(stream), seek_stream_(dynamic_cast(stream)), + except_counter_(0) { + CHECK(sizeof(uint32_t) == 4) << "uint32_t needs to be 4 bytes"; + } + /*! + * \brief write record to the stream + * \param buf the buffer of memory region + * \param size the size of record to write out + */ + void WriteRecord(const void *buf, size_t size); + /*! + * \brief write record to the stream + * \param data the data to write out + */ + inline void WriteRecord(const std::string &data) { + this->WriteRecord(data.c_str(), data.length()); + } + /*! + * \return number of exceptions(occurance of magic number) + * during the writing process + */ + inline size_t except_counter(void) const { + return except_counter_; + } + + /*! \brief tell the current position of the input stream */ + inline size_t Tell(void) { + CHECK(seek_stream_ != NULL) << "The input stream is not seekable"; + return seek_stream_->Tell(); + } + + private: + /*! \brief output stream */ + Stream *stream_; + /*! \brief seekable stream */ + SeekStream *seek_stream_; + /*! \brief counts the number of exceptions */ + size_t except_counter_; +}; +/*! + * \brief reader of binary recordio to reads in record from stream + * \sa RecordIOWriter + */ +class RecordIOReader { + public: + /*! + * \brief constructor + * \param stream the stream to be constructed + */ + explicit RecordIOReader(Stream *stream) + : stream_(stream), seek_stream_(dynamic_cast(stream)), + end_of_stream_(false) { + CHECK(sizeof(uint32_t) == 4) << "uint32_t needs to be 4 bytes"; + } + /*! + * \brief read next complete record from stream + * \param out_rec used to store output record in string + * \return true of read was successful, false if end of stream was reached + */ + bool NextRecord(std::string *out_rec); + + /*! \brief seek to certain position of the input stream */ + inline void Seek(size_t pos) { + CHECK(seek_stream_ != NULL) << "The input stream is not seekable"; + seek_stream_->Seek(pos); + } + + /*! \brief tell the current position of the input stream */ + inline size_t Tell(void) { + CHECK(seek_stream_ != NULL) << "The input stream is not seekable"; + return seek_stream_->Tell(); + } + + private: + /*! \brief output stream */ + Stream *stream_; + SeekStream *seek_stream_; + /*! \brief whether we are at end of stream */ + bool end_of_stream_; +}; + +/*! + * \brief reader of binary recordio from Blob returned by InputSplit + * This class divides the blob into several independent parts specified by caller, + * and read from one segment. + * The part reading can be used together with InputSplit::NextChunk for + * multi-threaded parsing(each thread take a RecordIOChunkReader) + * + * \sa RecordIOWriter, InputSplit + */ +class RecordIOChunkReader { + public: + /*! + * \brief constructor + * \param chunk source data returned by InputSplit + * \param part_index which part we want to reado + * \param num_parts number of total segments + */ + explicit RecordIOChunkReader(InputSplit::Blob chunk, + unsigned part_index = 0, + unsigned num_parts = 1); + /*! + * \brief read next complete record from stream + * the blob contains the memory content + * NOTE: this function is not threadsafe, use one + * RecordIOChunkReader per thread + * \param out_rec used to store output blob, the header is already + * removed and out_rec only contains the memory content + * \return true of read was successful, false if end was reached + */ + bool NextRecord(InputSplit::Blob *out_rec); + + private: + /*! \brief internal temporal data */ + std::string temp_; + /*! \brief internal data pointer */ + char *pbegin_, *pend_; +}; + +} // namespace dmlc +#endif // DMLC_RECORDIO_H_ diff --git a/include/dmlc/registry.h b/include/dmlc/registry.h new file mode 100644 index 000000000000..d68b57597250 --- /dev/null +++ b/include/dmlc/registry.h @@ -0,0 +1,306 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file registry.h + * \brief Registry utility that helps to build registry singletons. + */ +#ifndef DMLC_REGISTRY_H_ +#define DMLC_REGISTRY_H_ + +#include +#include +#include +#include "./base.h" +#include "./logging.h" +#include "./parameter.h" +#include "./type_traits.h" + +namespace dmlc { +/*! + * \brief Registry class. + * Registry can be used to register global singletons. + * The most commonly use case are factory functions. + * + * \tparam EntryType Type of Registry entries, + * EntryType need to name a name field. + */ +template +class Registry { + public: + /*! \return list of entries in the registry(excluding alias) */ + inline static const std::vector& List() { + return Get()->const_list_; + } + /*! \return list all names registered in the registry, including alias */ + inline static std::vector ListAllNames() { + const std::map &fmap = Get()->fmap_; + typename std::map::const_iterator p; + std::vector names; + for (p = fmap.begin(); p !=fmap.end(); ++p) { + names.push_back(p->first); + } + return names; + } + /*! + * \brief Find the entry with corresponding name. + * \param name name of the function + * \return the corresponding function, can be NULL + */ + inline static const EntryType *Find(const std::string &name) { + const std::map &fmap = Get()->fmap_; + typename std::map::const_iterator p = fmap.find(name); + if (p != fmap.end()) { + return p->second; + } else { + return NULL; + } + } + /*! + * \brief Add alias to the key_name + * \param key_name The original entry key + * \param alias The alias key. + */ + inline void AddAlias(const std::string& key_name, + const std::string& alias) { + EntryType* e = fmap_.at(key_name); + if (fmap_.count(alias)) { + CHECK_EQ(e, fmap_.at(alias)) + << "Trying to register alias " << alias << " for key " << key_name + << " but " << alias << " is already taken"; + } else { + fmap_[alias] = e; + } + } + /*! + * \brief Internal function to register a name function under name. + * \param name name of the function + * \return ref to the registered entry, used to set properties + */ + inline EntryType &__REGISTER__(const std::string& name) { + CHECK_EQ(fmap_.count(name), 0U) + << name << " already registered"; + EntryType *e = new EntryType(); + e->name = name; + fmap_[name] = e; + const_list_.push_back(e); + entry_list_.push_back(e); + return *e; + } + /*! + * \brief Internal function to either register or get registered entry + * \param name name of the function + * \return ref to the registered entry, used to set properties + */ + inline EntryType &__REGISTER_OR_GET__(const std::string& name) { + if (fmap_.count(name) == 0) { + return __REGISTER__(name); + } else { + return *fmap_.at(name); + } + } + /*! + * \brief get a singleton of the Registry. + * This function can be defined by DMLC_REGISTRY_ENABLE. + * \return get a singleton + */ + static Registry *Get(); + + private: + /*! \brief list of entry types */ + std::vector entry_list_; + /*! \brief list of entry types */ + std::vector const_list_; + /*! \brief map of name->function */ + std::map fmap_; + /*! \brief constructor */ + Registry() {} + /*! \brief destructor */ + ~Registry() { + for (size_t i = 0; i < entry_list_.size(); ++i) { + delete entry_list_[i]; + } + } +}; + +/*! + * \brief Common base class for function registry. + * + * \code + * // This example demonstrates how to use Registry to create a factory of trees. + * struct TreeFactory : + * public FunctionRegEntryBase > { + * }; + * + * // in a independent cc file + * namespace dmlc { + * DMLC_REGISTRY_ENABLE(TreeFactory); + * } + * // register binary tree constructor into the registry. + * DMLC_REGISTRY_REGISTER(TreeFactory, TreeFactory, BinaryTree) + * .describe("Constructor of BinaryTree") + * .set_body([]() { return new BinaryTree(); }); + * \endcode + * + * \tparam EntryType The type of subclass that inheritate the base. + * \tparam FunctionType The function type this registry is registerd. + */ +template +class FunctionRegEntryBase { + public: + /*! \brief name of the entry */ + std::string name; + /*! \brief description of the entry */ + std::string description; + /*! \brief additional arguments to the factory function */ + std::vector arguments; + /*! \brief Function body to create ProductType */ + FunctionType body; + /*! \brief Return type of the function */ + std::string return_type; + + /*! + * \brief Set the function body. + * \param body Function body to set. + * \return reference to self. + */ + inline EntryType &set_body(FunctionType body) { + this->body = body; + return this->self(); + } + /*! + * \brief Describe the function. + * \param description The description of the factory function. + * \return reference to self. + */ + inline EntryType &describe(const std::string &description) { + this->description = description; + return this->self(); + } + /*! + * \brief Add argument information to the function. + * \param name Name of the argument. + * \param type Type of the argument. + * \param description Description of the argument. + * \return reference to self. + */ + inline EntryType &add_argument(const std::string &name, + const std::string &type, + const std::string &description) { + ParamFieldInfo info; + info.name = name; + info.type = type; + info.type_info_str = info.type; + info.description = description; + arguments.push_back(info); + return this->self(); + } + /*! + * \brief Append list if arguments to the end. + * \param args Additional list of arguments. + * \return reference to self. + */ + inline EntryType &add_arguments(const std::vector &args) { + arguments.insert(arguments.end(), args.begin(), args.end()); + return this->self(); + } + /*! + * \brief Set the return type. + * \param type Return type of the function, could be Symbol or Symbol[] + * \return reference to self. + */ + inline EntryType &set_return_type(const std::string &type) { + return_type = type; + return this->self(); + } + + protected: + /*! + * \return reference of self as derived type + */ + inline EntryType &self() { + return *(static_cast(this)); + } +}; + +/*! + * \def DMLC_REGISTRY_ENABLE + * \brief Macro to enable the registry of EntryType. + * This macro must be used under namespace dmlc, and only used once in cc file. + * \param EntryType Type of registry entry + */ +#define DMLC_REGISTRY_ENABLE(EntryType) \ + template<> \ + Registry *Registry::Get() { \ + static Registry inst; \ + return &inst; \ + } \ + +/*! + * \brief Generic macro to register an EntryType + * There is a complete example in FactoryRegistryEntryBase. + * + * \param EntryType The type of registry entry. + * \param EntryTypeName The typename of EntryType, must do not contain namespace :: . + * \param Name The name to be registered. + * \sa FactoryRegistryEntryBase + */ +#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \ + static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \ + ::dmlc::Registry::Get()->__REGISTER__(#Name) \ + +/*! + * \brief (Optional) Declare a file tag to current file that contains object registrations. + * + * This will declare a dummy function that will be called by register file to + * incur a link dependency. + * + * \param UniqueTag The unique tag used to represent. + * \sa DMLC_REGISTRY_LINK_TAG + */ +#define DMLC_REGISTRY_FILE_TAG(UniqueTag) \ + int __dmlc_registry_file_tag_ ## UniqueTag ## __() { return 0; } + +/*! + * \brief (Optional) Force link to all the objects registered in file tag. + * + * This macro must be used in the same file as DMLC_REGISTRY_ENABLE and + * in the same namespace as DMLC_REGISTRY_FILE_TAG + * + * DMLC_REGISTRY_FILE_TAG and DMLC_REGISTRY_LINK_TAG are optional macros for registration. + * They are used to encforce link of certain file into during static linking. + * + * This is mainly used to solve problem during statically link a library which contains backward registration. + * Specifically, this avoids the objects in these file tags to be ignored by compiler. + * + * For dynamic linking, this problem won't occur as everything is loaded by default. + * + * Use of this is optional as it will create an error when a file tag do not exist. + * An alternative solution is always ask user to enable --whole-archieve during static link. + * + * \begincode + * // in file objective_registry.cc + * DMLC_REGISTRY_ENABLE(MyObjective); + * DMLC_REGISTRY_LINK_TAG(regression_op); + * DMLC_REGISTRY_LINK_TAG(rank_op); + * + * // in file regression_op.cc + * // declare tag of this file. + * DMLC_REGISTRY_FILE_TAG(regression_op); + * DMLC_REGISTRY_REGISTER(MyObjective, logistic_reg, logistic_reg); + * // ... + * + * // in file rank_op.cc + * // declare tag of this file. + * DMLC_REGISTRY_FILE_TAG(rank_op); + * DMLC_REGISTRY_REGISTER(MyObjective, pairwiserank, pairwiserank); + * + * \endcode + * + * \param UniqueTag The unique tag used to represent. + * \sa DMLC_REGISTRY_ENABLE, DMLC_REGISTRY_FILE_TAG + */ +#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \ + int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \ + static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \ + __dmlc_registry_file_tag_ ## UniqueTag ## __(); +} // namespace dmlc +#endif // DMLC_REGISTRY_H_ diff --git a/include/dmlc/serializer.h b/include/dmlc/serializer.h new file mode 100644 index 000000000000..4bede4a3b416 --- /dev/null +++ b/include/dmlc/serializer.h @@ -0,0 +1,410 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file serializer.h + * \brief serializer template class that helps serialization. + * This file do not need to be directly used by most user. + */ +#ifndef DMLC_SERIALIZER_H_ +#define DMLC_SERIALIZER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "./base.h" +#include "./io.h" +#include "./logging.h" +#include "./type_traits.h" +#include "./endian.h" + +#if DMLC_USE_CXX11 +#include +#include +#endif + +namespace dmlc { +/*! \brief internal namespace for serializers */ +namespace serializer { +/*! + * \brief generic serialization handler + * \tparam T the type to be serialized + * \tparam need_endian_swap Whether use little endian + */ +template +struct Handler; + +//! \cond Doxygen_Suppress +/*! + * \brief Serializer that redirect calls by condition + * \tparam cond the condition + * \tparam Then the serializer used for then condition + * \tparam Else the serializer used for else condition + * \tparam Return the type of data the serializer handles + */ +template +struct IfThenElse; + +template +struct IfThenElse { + inline static void Write(Stream *strm, const T &data) { + Then::Write(strm, data); + } + inline static bool Read(Stream *strm, T *data) { + return Then::Read(strm, data); + } +}; +template +struct IfThenElse { + inline static void Write(Stream *strm, const T &data) { + Else::Write(strm, data); + } + inline static bool Read(Stream *strm, T *data) { + return Else::Read(strm, data); + } +}; + +/*! \brief Serializer for POD(plain-old-data) data */ +template +struct NativePODHandler { + inline static void Write(Stream *strm, const T &data) { + strm->Write(&data, sizeof(T)); + } + inline static bool Read(Stream *strm, T *dptr) { + return strm->Read((void*)dptr, sizeof(T)) == sizeof(T); // NOLINT(*) + } +}; + +/*! \brief Serializer for arithmetic data, handle endianness */ +template +struct ArithmeticHandler { + inline static void Write(Stream *strm, const T &data) { + if (DMLC_IO_NO_ENDIAN_SWAP) { + strm->Write(&data, sizeof(T)); + } else { + T copy = data; + ByteSwap(©, sizeof(T), 1); + strm->Write(©, sizeof(T)); + } + } + inline static bool Read(Stream *strm, T *dptr) { + bool ret = strm->Read((void*)dptr, sizeof(T)) == sizeof(T); // NOLINT(*) + if (!DMLC_IO_NO_ENDIAN_SWAP) { + ByteSwap(dptr, sizeof(T), 1); + } + return ret; + } +}; + +// serializer for class that have save/load function +template +struct SaveLoadClassHandler { + inline static void Write(Stream *strm, const T &data) { + data.Save(strm); + } + inline static bool Read(Stream *strm, T *data) { + return data->Load(strm); + } +}; + +/*! + * \brief dummy class for undefined serialization. + * This is used to generate error message when user tries to + * serialize something that is not supported. + * \tparam T the type to be serialized + */ +template +struct UndefinedSerializerFor { +}; + +/*! + * \brief Serializer handler for std::vector where T is POD type. + * \tparam T element type + */ +template +struct NativePODVectorHandler { + inline static void Write(Stream *strm, const std::vector &vec) { + uint64_t sz = static_cast(vec.size()); + strm->Write(sz); + if (sz != 0) { + strm->Write(&vec[0], sizeof(T) * vec.size()); + } + } + inline static bool Read(Stream *strm, std::vector *out_vec) { + uint64_t sz; + if (!strm->Read(&sz)) return false; + size_t size = static_cast(sz); + out_vec->resize(size); + if (sz != 0) { + size_t nbytes = sizeof(T) * size; + return strm->Read(&(*out_vec)[0], nbytes) == nbytes; + } + return true; + } +}; + +/*! + * \brief Serializer handler for std::vector where T can be composed type + * \tparam T element type + */ +template +struct ComposeVectorHandler { + inline static void Write(Stream *strm, const std::vector &vec) { + uint64_t sz = static_cast(vec.size()); + strm->Write(sz); + strm->WriteArray(dmlc::BeginPtr(vec), vec.size()); + } + inline static bool Read(Stream *strm, std::vector *out_vec) { + uint64_t sz; + if (!strm->Read(&sz)) return false; + size_t size = static_cast(sz); + out_vec->resize(size); + return strm->ReadArray(dmlc::BeginPtr(*out_vec), size); + } +}; + +/*! + * \brief Serializer handler for std::basic_string where T is POD type. + * \tparam T element type + */ +template +struct NativePODStringHandler { + inline static void Write(Stream *strm, const std::basic_string &vec) { + uint64_t sz = static_cast(vec.length()); + strm->Write(sz); + if (sz != 0) { + strm->Write(&vec[0], sizeof(T) * vec.length()); + } + } + inline static bool Read(Stream *strm, std::basic_string *out_vec) { + uint64_t sz; + if (!strm->Read(&sz)) return false; + size_t size = static_cast(sz); + out_vec->resize(size); + if (sz != 0) { + size_t nbytes = sizeof(T) * size; + return strm->Read(&(*out_vec)[0], nbytes) == nbytes; + } + return true; + } +}; + +/*! \brief Serializer for std::pair */ +template +struct PairHandler { + inline static void Write(Stream *strm, const std::pair &data) { + Handler::Write(strm, data.first); + Handler::Write(strm, data.second); + } + inline static bool Read(Stream *strm, std::pair *data) { + return Handler::Read(strm, &(data->first)) && + Handler::Read(strm, &(data->second)); + } +}; + +// set type handler that can handle most collection type case +template +struct CollectionHandler { + inline static void Write(Stream *strm, const ContainerType &data) { + // dump data to vector + std::vector vdata(data.begin(), data.end()); + // serialize the vector + Handler >::Write(strm, vdata); + } + inline static bool Read(Stream *strm, ContainerType *data) { + std::vector vdata; + if (!Handler >::Read(strm, &vdata)) return false; + data->clear(); + data->insert(vdata.begin(), vdata.end()); + return true; + } +}; + + +// handler that can handle most list type case +// this type insert function takes additional iterator +template +struct ListHandler { + inline static void Write(Stream *strm, const ListType &data) { + typedef typename ListType::value_type ElemType; + // dump data to vector + std::vector vdata(data.begin(), data.end()); + // serialize the vector + Handler >::Write(strm, vdata); + } + inline static bool Read(Stream *strm, ListType *data) { + typedef typename ListType::value_type ElemType; + std::vector vdata; + if (!Handler >::Read(strm, &vdata)) return false; + data->clear(); + data->insert(data->begin(), vdata.begin(), vdata.end()); + return true; + } +}; + +//! \endcond + +/*! + * \brief generic serialization handler for type T + * + * User can define specialization of this class to support + * composite serialization of their own class. + * + * \tparam T the type to be serialized + */ +template +struct Handler { + /*! + * \brief write data to stream + * \param strm the stream we write the data. + * \param data the data obeject to be serialized + */ + inline static void Write(Stream *strm, const T &data) { + IfThenElse::value, + ArithmeticHandler, + IfThenElse::value && DMLC_IO_NO_ENDIAN_SWAP, + NativePODHandler, + IfThenElse::value, + SaveLoadClassHandler, + UndefinedSerializerFor, T>, + T>, + T> + ::Write(strm, data); + } + /*! + * \brief read data to stream + * \param strm the stream to read the data. + * \param data the pointer to the data obeject to read + * \return whether the read is successful + */ + inline static bool Read(Stream *strm, T *data) { + return + IfThenElse::value, + ArithmeticHandler, + IfThenElse::value && DMLC_IO_NO_ENDIAN_SWAP, + NativePODHandler, + IfThenElse::value, + SaveLoadClassHandler, + UndefinedSerializerFor, T>, + T>, + T> + ::Read(strm, data); + } +}; + +//! \cond Doxygen_Suppress +template +struct Handler > { + inline static void Write(Stream *strm, const std::vector &data) { + IfThenElse::value && DMLC_IO_NO_ENDIAN_SWAP, + NativePODVectorHandler, + ComposeVectorHandler, std::vector > + ::Write(strm, data); + } + inline static bool Read(Stream *strm, std::vector *data) { + return IfThenElse::value && DMLC_IO_NO_ENDIAN_SWAP, + NativePODVectorHandler, + ComposeVectorHandler, + std::vector > + ::Read(strm, data); + } +}; + +template +struct Handler > { + inline static void Write(Stream *strm, const std::basic_string &data) { + IfThenElse::value && (DMLC_IO_NO_ENDIAN_SWAP || sizeof(T) == 1), + NativePODStringHandler, + UndefinedSerializerFor, + std::basic_string > + ::Write(strm, data); + } + inline static bool Read(Stream *strm, std::basic_string *data) { + return IfThenElse::value && (DMLC_IO_NO_ENDIAN_SWAP || sizeof(T) == 1), + NativePODStringHandler, + UndefinedSerializerFor, + std::basic_string > + ::Read(strm, data); + } +}; + +template +struct Handler > { + inline static void Write(Stream *strm, const std::pair &data) { + IfThenElse::value && + dmlc::is_pod::value && + DMLC_IO_NO_ENDIAN_SWAP, + NativePODHandler >, + PairHandler, + std::pair > + ::Write(strm, data); + } + inline static bool Read(Stream *strm, std::pair *data) { + return IfThenElse::value && + dmlc::is_pod::value && + DMLC_IO_NO_ENDIAN_SWAP, + NativePODHandler >, + PairHandler, + std::pair > + ::Read(strm, data); + } +}; + +template +struct Handler > + : public CollectionHandler, std::pair > { +}; + +template +struct Handler > + : public CollectionHandler, std::pair > { +}; + +template +struct Handler > + : public CollectionHandler, T> { +}; + +template +struct Handler > + : public CollectionHandler, T> { +}; + +template +struct Handler > + : public ListHandler > { +}; + +template +struct Handler > + : public ListHandler > { +}; + +#if DMLC_USE_CXX11 +template +struct Handler > + : public CollectionHandler, std::pair > { +}; + +template +struct Handler > + : public CollectionHandler, std::pair > { +}; + +template +struct Handler > + : public CollectionHandler, T> { +}; + +template +struct Handler > + : public CollectionHandler, T> { +}; +#endif +//! \endcond +} // namespace serializer +} // namespace dmlc +#endif // DMLC_SERIALIZER_H_ diff --git a/include/dmlc/thread_group.h b/include/dmlc/thread_group.h new file mode 100644 index 000000000000..626142f30284 --- /dev/null +++ b/include/dmlc/thread_group.h @@ -0,0 +1,808 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file thread_group.h + * \brief Thread and synchronization primitives and lifecycle management + */ +#ifndef DMLC_THREAD_GROUP_H_ +#define DMLC_THREAD_GROUP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(DMLC_USE_CXX14) || __cplusplus > 201103L /* C++14 */ +#include +#endif +#include +#ifdef __linux__ +#include +#include +#endif + +namespace dmlc { + +/*! + * \brief Simple manual-reset event gate which remains open after signalled + */ +class ManualEvent { + public: + ManualEvent() : signaled_(false) {} + + /*! + * \brief Wait for the object to become signaled. If the object + * is already in the signaled state and reset() has not been called, then no wait will occur + */ + void wait() { + std::unique_lock lock(mutex_); + if (!signaled_) { + condition_variable_.wait(lock); + } + } + + /*! + * \brief Set this object's state to signaled (wait() will release or pass through) + */ + void signal() { + signaled_ = true; + std::unique_lock lk(mutex_); + condition_variable_.notify_all(); + } + + /*! + * \brief Manually reset this object's state to unsignaled (wait() will block) + */ + void reset() { + std::unique_lock lk(mutex_); + signaled_ = false; + } + + private: + /*! \brief Internal mutex to protect condition variable and signaled_ variable */ + std::mutex mutex_; + /*! \brief Internal condition variable */ + std::condition_variable condition_variable_; + /*! \brief lockfree signal state check */ + std::atomic signaled_; +}; + +#if defined(DMLC_USE_CXX14) || __cplusplus > 201103L /* C++14 */ +/*! \brief Mutex which can be read-locked and write-locked */ +using SharedMutex = std::shared_timed_mutex; +/*! \brief Write lock, disallows both reads and writes */ +using WriteLock = std::unique_lock; +/*! \brief Read lock, allows concurrent data reads */ +using ReadLock = std::shared_lock; +#else +/*! \brief Standard mutex for C++ < 14 */ +using SharedMutex = std::recursive_mutex; +/*! \brief Standard unique lock for C++ < 14 */ +using WriteLock = std::unique_lock; +/*! \brief Standard unique lock for C++ < 14 */ +using ReadLock = std::unique_lock; +#endif + +/*! + * \brief Thread lifecycle management group + * \note See gtest unit tests Syc.* for a usage examples + */ +class ThreadGroup { + public: + /*! + * \brief Lifecycle-managed thread (used by ThreadGroup) + * \note See gtest unit tests Syc.* for a usage examples + */ + class Thread { + public: + /*! \brief Shared pointer type for readability */ + using SharedPtr = std::shared_ptr; + + /*! + * \brief Constructor + * \param threadName User-defined name of the thread. must be unique per ThreadGroup + * \param owner The ThreadGroup object managing the lifecycle of this thread + * \param thrd Optionally-assigned std::thread object associated with this Thread class + */ + Thread(std::string threadName, ThreadGroup *owner, std::thread *thrd = nullptr) + : name_(std::move(threadName)) + , thread_(thrd) + , ready_event_(std::make_shared()) + , start_event_(std::make_shared()) + , owner_(owner) + , shutdown_requested_(false) + , auto_remove_(false) { + CHECK_NOTNULL(owner); + } + + /*! + * \brief Destructor with cleanup + */ + virtual ~Thread() { + const bool self_delete = is_current_thread(); + if (!self_delete) { + request_shutdown(); + internal_join(true); + } + WriteLock guard(thread_mutex_); + if (thread_.load()) { + std::thread *thrd = thread_.load(); + thread_ = nullptr; + if (self_delete) { + thrd->detach(); + } + delete thrd; + } + } + + /*! + * \brief Name of the thread + * \return Pointer to the thread name's string + * \note This shoul ndly be used as immediate for the sacope of the + * shared pointer pointing to this object + */ + const char *name() const { + return name_.c_str(); + } + + /*! + * \brief Launch the given Thread object + * \tparam StartFunction Function type for the thread 'main' function + * \tparam Args Arguments to pass to the thread 'main' function + * \param pThis Shared pointer for the managed thread to launch + * \param autoRemove if true, automatically remove this Thread object from the + * ThreadGroup owner upon exit + * \param start_function The Thread's 'main' function + * \param args Arguments to pass to the Thread's 'main' function + * \return true if the thread was successfully created and added to the ThreadGroup + * If false is returned, the thread may have already been started, but if something + * went wrong (ie duplicte thread name for the ThreadGroup), then request_shutdown() + * will have been been called on the running thread + */ + template + static bool launch(std::shared_ptr pThis, + bool autoRemove, + StartFunction start_function, + Args ...args); + + /*! + * \brief Check if this class represents the currently running thread (self) + * \return true if the current running thread belongs to this class + */ + bool is_current_thread() const { + ReadLock guard(thread_mutex_); + return thread_.load() ? (thread_.load()->get_id() == std::this_thread::get_id()) : false; + } + + /*! + * \brief Signal to this thread that a thread shutdown/exit is requested. + * \note This is a candidate for overrise in a derived class which may trigger shutdown + * by means other than a boolean (ie condition variable, SimpleManualkEvent, etc). + */ + virtual void request_shutdown() { + shutdown_requested_ = true; + } + + /*! + * \brief Check whether shutdown has been requested (request_shutdown() was called) + * \return true if shutdown was requested. + * \note This may be overriden to match an overriden to match an overriden 'request_shutdown()', + * for instance. + */ + virtual bool is_shutdown_requested() const { + return shutdown_requested_.load(); + } + + /*! + * \brief Check whether the thread is set to auto-remove itself from the ThreadGroup owner + * when exiting + * \return true if the thread will auto-remove itself from the ThreadGroup owner + * when exiting + */ + bool is_auto_remove() const { + return auto_remove_; + } + + /*! + * \brief Make the thread joinable (by removing the auto_remove flag) + * \warning Care should be taken not to cause a race condition between this call + * and parallel execution of this thread auto-removing itself + */ + void make_joinable() { + auto_remove_ = false; + } + + /*! + * \brief Check whether the thread is joinable + * \return true if the thread is joinable + */ + bool joinable() const { + ReadLock guard(thread_mutex_); + if (thread_.load()) { + CHECK_EQ(auto_remove_, false); + // be checked by searching the group or exit event. + return thread_.load()->joinable(); + } + return false; + } + + /*! + * \brief Thread join + * \note join() may not be called on auto-remove threads + */ + void join() { + internal_join(false); + } + + /*! + * \brief Get this thread's id + * \return this thread's id + */ + std::thread::id get_id() const { + ReadLock guard(thread_mutex_); + return thread_.load()->get_id(); + } + + private: + /*! + * \brief Internal join function + * \param auto_remove_ok Whether to allow join on an auto-remove thread + */ + void internal_join(bool auto_remove_ok) { + ReadLock guard(thread_mutex_); + // should be careful calling (or any function externally) this when in + // auto-remove mode + if (thread_.load() && thread_.load()->get_id() != std::thread::id()) { + std::thread::id someId; + if (!auto_remove_ok) { + CHECK_EQ(auto_remove_, false); + } + CHECK_NOTNULL(thread_.load()); + if (thread_.load()->joinable()) { + thread_.load()->join(); + } else { + LOG(WARNING) << "Thread " << name_ << " ( " + << thread_.load()->get_id() << " ) not joinable"; + } + } + } + + /*! + * \brief Thread bootstrapping and teardown wrapper + * \tparam StartFunction Thread's "main" function + * \tparam Args Argument types to be passed to the start_function + * \param pThis Shared pointer to the Thread object to operate upon + * \param start_function Thread's "main" function (i.e. passed to launch()) + * \param args Arguments to be passed to the start_function + * \return The thread's return code + */ + template + static int entry_and_exit_f(std::shared_ptr pThis, + StartFunction start_function, + Args... args); + /*! \brief Thread name */ + std::string name_; + /*! \brief Shared mutex for some thread operations */ + mutable SharedMutex thread_mutex_; + /*! \brief Pointer to the stl thread object */ + std::atomic thread_; + /*! \brief Signaled when the thread is started and ready to execute user code */ + std::shared_ptr ready_event_; + /*! \brief Thread will block after setting ready_event_ until start_event_ is signaled */ + std::shared_ptr start_event_; + /*! \brief The ThreadGroup ownber managing this thread's lifecycle */ + ThreadGroup *owner_; + /*! \brief Flag to determine if shutdown was requested. */ + std::atomic shutdown_requested_; + /*! + * \brief Whether to automatically remove this thread's object from the ThreadGroup when the + * thread exists (perform its own cleanup) + */ + volatile bool auto_remove_; + }; + + /*! + * \brief Constructor + */ + inline ThreadGroup() + : evEmpty_(std::make_shared()) { + evEmpty_->signal(); // Starts out empty + } + + /*! + * \brief Destructor, perform cleanup. All child threads will be exited when this + * destructor completes + */ + virtual ~ThreadGroup() { + request_shutdown_all(); + join_all(); + } + + /*! + * \brief Check if the current thread a member if this ThreadGroup + * \return true if the current thread is a member of this thread group + * \note This lookup involved a linear search, so for a large number of threads, + * is it not advised to call this function in a performance-sensitive area + */ + inline bool is_this_thread_in() const { + std::thread::id id = std::this_thread::get_id(); + ReadLock guard(m_); + for (auto it = threads_.begin(), end = threads_.end(); it != end; ++it) { + std::shared_ptr thrd = *it; + if (thrd->get_id() == id) + return true; + } + return false; + } + + /*! + * \brief Check if the current thread is a member of this ThreadGroup + * \param thrd The thread to search for + * \return true if the given thread is a member of this ThreadGroup + */ + inline bool is_thread_in(std::shared_ptr thrd) const { + if (thrd) { + std::thread::id id = thrd->get_id(); + ReadLock guard(m_); + for (auto it = threads_.begin(), end = threads_.end(); it != end; ++it) { + std::shared_ptr thrd = *it; + if (thrd->get_id() == id) + return true; + } + return false; + } else { + return false; + } + } + + /*! + * \brief Add a Thread object to this thread group + * \param thrd The thread to add to this ThreadGroup object + * \return true if the given thread was added to this ThreadGroup + */ + inline bool add_thread(std::shared_ptr thrd) { + if (thrd) { + WriteLock guard(m_); + auto iter = name_to_thread_.find(thrd->name()); + if (iter == name_to_thread_.end()) { + name_to_thread_.emplace(std::make_pair(thrd->name(), thrd)); + CHECK_EQ(threads_.insert(thrd).second, true); + evEmpty_->reset(); + return true; + } + } + return false; + } + + /*! + * \brief Remove a Thread object from this thread group + * \param thrd The thread to remove from this ThreadGroup object + * \return true if the given thread was removed from this ThreadGroup + */ + inline bool remove_thread(std::shared_ptr thrd) { + if (thrd) { + WriteLock guard(m_); + auto iter = threads_.find(thrd); + if (iter != threads_.end()) { + name_to_thread_.erase(thrd->name()); + threads_.erase(iter); + if (threads_.empty()) { + evEmpty_->signal(); + } + return true; + } + } + return false; + } + + /*! + * \brief Join all threads in this ThreadGroup + * \note While it is not valid to call 'join' on an auto-remove thread, this function will + * wait for auto-remove threads to exit (waits for the ThreadGroup to become empty) + */ + inline void join_all() { + CHECK_EQ(!is_this_thread_in(), true); + do { + std::unique_lock lk(join_all_mtx_); + std::unordered_set> working_set; + { + ReadLock guard(m_); + for (auto iter = threads_.begin(), e_iter = threads_.end(); iter != e_iter; ++iter) { + if (!(*iter)->is_auto_remove()) { + working_set.emplace(*iter); + } + } + } + // Where possible, prefer to do a proper join rather than simply waiting for empty + // (easier to troubleshoot) + while (!working_set.empty()) { + std::shared_ptr thrd; + thrd = *working_set.begin(); + if (thrd->joinable()) { + thrd->join(); + } + remove_thread(thrd); + working_set.erase(working_set.begin()); + thrd.reset(); + } + // Wait for auto-remove threads (if any) to complete + } while (0); + evEmpty_->wait(); + CHECK_EQ(threads_.size(), 0); + } + + /*! + * \brief Call request_shutdown() on all threads in this ThreadGroup + * \param make_all_joinable If true, remove all auto_remove flags from child threads + */ + inline void request_shutdown_all(const bool make_all_joinable = true) { + std::unique_lock lk(join_all_mtx_); + ReadLock guard(m_); + for (auto &thread : threads_) { + if (make_all_joinable) { + thread->make_joinable(); + } + thread->request_shutdown(); + } + } + + /*! + * \brief Return the number of threads in this thread group + * \return Number of threads in this thread group + */ + inline size_t size() const { + ReadLock guard(m_); + return threads_.size(); + } + + /*! + * \brief Check if the ThreadGroup is empty + * \return true if the ThreadGroup is empty + */ + inline bool empty() const { + ReadLock guard(m_); + return threads_.size() == 0; + } + + /*! + * \brief Create and launch a new Thread object which will be owned by this ThreadGroup + * \tparam StartFunction Function type for the thread 'main' function + * \tparam ThreadType managedThreadclass type (in case it's derived, for instance) + * \tparam Args Arguments to pass to the thread 'main' function + * \param threadName Name if the thread. Must be unique for a ThreadGroup object + * \param auto_remove If true, automatically remove this Thread object from the + * ThreadGroup owner upon exit + * \param start_function The Thread's 'main' function + * \param args Arguments to pass to the Thread's 'main' function + * \return true if the thread was successfully created and added to the ThreadGroup + * If false is returned, the thread may have already been started, but if something + * went wrong (ie duplicte thread name for the ThreadGroup), then request_shutdown() + * will have been been called on the running thread + */ + template + inline bool create(const std::string &threadName, + bool auto_remove, + StartFunction start_function, + Args... args) { + typename ThreadType::SharedPtr newThread(new ThreadType(threadName, this)); + return Thread::launch(newThread, auto_remove, start_function, args...); + } + + /*! + * \brief Lookup Thread object by name + * \param name Name of the thread to look up + * \return A shared pointer to the Thread object + */ + inline std::shared_ptr thread_by_name(const std::string& name) { + ReadLock guard(m_); + auto iter = name_to_thread_.find(name); + if (iter != name_to_thread_.end()) { + return iter->second; + } + return nullptr; + } + + private: + /*! \brief ThreadGroup synchronization mutex */ + mutable SharedMutex m_; + /*! \brief join_all/auto_remove synchronization mutex */ + mutable std::mutex join_all_mtx_; + /*! \brief Set of threads owned and managed by this ThreadGroup object */ + std::unordered_set> threads_; + /*! \brief Manual event which is signaled when the thread group is empty */ + std::shared_ptr evEmpty_; + /*! \brief name->thread mapping */ + std::unordered_map> name_to_thread_; +}; + +/*! + * \brief Blocking queue thread class + * \tparam ObjectType Object type to queue + * \tparam quit_item Object value to signify queue shutdown (ie nullptr for pointer type is common) + * \note See gtest unit test Syc.ManagedThreadLaunchQueueThread for a usage example + */ +template +class BlockingQueueThread : public ThreadGroup::Thread { + using BQT = BlockingQueueThread; + + public: + /*! + * \brief Constructor + * \param name Name for the blockin g queue thread. Must be unique for a specific ThreadGroup + * \param owner ThreadGroup lifecycle manafger/owner + * \param thrd Optionally attach an existing stl thread object + */ + BlockingQueueThread(const std::string& name, + dmlc::ThreadGroup *owner, + std::thread *thrd = nullptr) + : ThreadGroup::Thread(std::move(name), owner, thrd) + , shutdown_in_progress_(false) { + } + + + /*! + * \brief Destructor + */ + ~BlockingQueueThread() override { + // Call to parent first because we don't want to wait for the queue to empty + ThreadGroup::Thread::request_shutdown(); + request_shutdown(); + } + + /*! + * \brief Signal the thread that a shutdown is desired + * \note Since consumer doesn't necessarily get items in order, we must wait for + * the queue to empty. + * This is generally a shutdown procedure and should not be called from + * a performance-sensitive area + */ + void request_shutdown() override { + shutdown_in_progress_ = true; + while (queue_->size_approx() > 0 && !ThreadGroup::Thread::is_shutdown_requested()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + ThreadGroup::Thread::request_shutdown(); + queue_->enqueue(quit_item); + } + + /*! + * \brief Enqueue and item + * \param item The item to enqueue + */ + void enqueue(const ObjectType& item) { + if (!shutdown_in_progress_) { + queue_->enqueue(item); + } + } + + /*! + * \brief Get the approximate size of the queue + * \return The approximate size of the queue + */ + size_t size_approx() const { return queue_->size_approx(); } + + /*! + * \brief Launch to the 'run' function which will, in turn, call the class' + * 'run' function, passing it the given 'secondary_function' + * for it to call as needed + * \tparam SecondaryFunction Type of the secondary function for 'run' override + * to call as needed + * \param pThis Pointer to the managed thread to launch + * \param secondary_function secondary function for 'run' override to call as needed + * \return true if thread is launched successfully and added to the ThreadGroup + */ + template + static bool launch_run(std::shared_ptr pThis, + SecondaryFunction secondary_function) { + return ThreadGroup::Thread::launch(pThis, true, [](std::shared_ptr pThis, + SecondaryFunction secondary_function) { + return pThis->run(secondary_function); + }, + pThis, secondary_function); + } + + /*! + * \brief Thread's main queue processing function + * \tparam OnItemFunction Function type to call when an item is dequeued + * \param on_item_function Function to call when an item is dequeued + * \return 0 if completed through a `quit_item`, nonzero if on_item_function requested an exit + */ + template + inline int run(OnItemFunction on_item_function) { + int rc = 0; + do { + ObjectType item; + queue_->wait_dequeue(item); + if (item == quit_item) { + break; + } + rc = on_item_function(item); + if (rc) { + break; + } + } while (true); + return rc; + } + + private: + /*! \brief The blocking queue associated with this thread */ + std::shared_ptr> queue_ = + std::make_shared>(); + /*! \brief Whether shutdown request is in progress */ + std::atomic shutdown_in_progress_; +}; + +/*! + * \brief Managed timer thread + * \tparam Duration Duration type (ie seconds, microseconds, etc) + */ +template +class TimerThread : public ThreadGroup::Thread { + using ThreadGroup::Thread::is_shutdown_requested; + + public: + /*! + * \brief Constructor + * \param name Name of the timer thread + * \param owner ThreadGroup owner if the timer thread + */ + TimerThread(const std::string& name, ThreadGroup *owner) + : Thread(name, owner) { + } + + /*! + * \brief Destructor + */ + ~TimerThread() override { + request_shutdown(); + } + + /*! + * \brief Launch to the 'run' function which will, in turn, call the class' + * 'run' function, passing it the given 'secondary_function' + * for it to call as needed + * \tparam SecondaryFunction Type of the secondary function for 'run' override + * to call as needed + * \param pThis Pointer to the managed thread to launch + * \param secondary_function secondary function for 'run' override to call as needed + * \return true if thread is launched successfully and added to the ThreadGroup + */ + template + static bool launch_run(std::shared_ptr> pThis, + SecondaryFunction secondary_function) { + return ThreadGroup::Thread::launch(pThis, true, [](std::shared_ptr> pThis, + SecondaryFunction secondary_function) { + return pThis->run(secondary_function); + }, + pThis, secondary_function); + } + + /*! + * \brief Start a given timer thread + * \tparam Function Type of the timer function + * \param timer_thread Thread object to perform the timer events + * \param duration Duration between the end end of the timer function and the next timer event + * \param function Function to call when the timer expires + * \note Calling shutdown_requested() will cause the thread to exit the next time that the timer + * expires. + */ + template + static void start(std::shared_ptr timer_thread, + Duration duration, + Function function) { + timer_thread->duration_ = duration; + launch_run(timer_thread, function); + } + + /*! + * \brief Internal timer execution function + * \tparam OnTimerFunction Type of function to call each time the timer expires + * \param on_timer_function Function to call each time the timer expires + * \return Exit code of the thread + */ + template + inline int run(OnTimerFunction on_timer_function) { + int rc = 0; + while (!is_shutdown_requested()) { + std::this_thread::sleep_for(duration_); + if (!is_shutdown_requested()) { + rc = on_timer_function(); + } + } + return rc; + } + + private: + Duration duration_; +}; + +/* + * Inline functions - see declarations for usage + */ +template +inline int ThreadGroup::Thread::entry_and_exit_f(std::shared_ptr pThis, + StartFunction start_function, + Args... args) { + int rc; + if (pThis) { + // Signal launcher that we're up and running + pThis->ready_event_->signal(); + // Wait for launcher to be ready for us to start + pThis->start_event_->wait(); + // Reset start_event_ for possible reuse + pThis->start_event_->reset(); // Reset in case it needs to be reused + // If we haven't been requested to shut down prematurely, then run the desired function + if (!pThis->is_shutdown_requested()) { + rc = start_function(args...); + } else { + rc = -1; + } + // If we're set up as auto-remove, then remove this thread from the thread group + if (pThis->is_auto_remove()) { + pThis->owner_->remove_thread(pThis); + } + // Release this thread shared pinter. May or may not be the last reference. + pThis.reset(); + } else { + LOG(ERROR) << "Null pThis thread pointer"; + rc = EINVAL; + } + return rc; +} + +template +inline bool ThreadGroup::Thread::launch(std::shared_ptr pThis, + bool autoRemove, + StartFunction start_function, + Args ...args) { + WriteLock guard(pThis->thread_mutex_); + CHECK_EQ(!pThis->thread_.load(), true); + CHECK_NOTNULL(pThis->owner_); + // Set auto remove + pThis->auto_remove_ = autoRemove; + // Create the actual stl thread object + pThis->thread_ = new std::thread(Thread::template entry_and_exit_f< + StartFunction, Args...>, + pThis, + start_function, + args...); + // Attempt to add the thread to the thread group (after started, since in case + // something goes wrong, there's not a zombie thread in the thread group) + if (!pThis->owner_->add_thread(pThis)) { + pThis->request_shutdown(); + LOG(ERROR) << "Duplicate thread name within the same thread group is not allowed"; + } + // Wait for the thread to spin up + pThis->ready_event_->wait(); + // Signal the thgread to continue (it will check its shutdown status) + pThis->start_event_->signal(); + // Return if successful + return pThis->thread_.load() != nullptr; +} + +/*! + * \brief Utility function to easily create a timer + * \tparam Duration Duration type (i.e. std::chrono::milliseconds) + * \tparam TimerFunction Function to call each time the timer expires + * \param timer_name Name of the timer. Must be unique per ThreadGroup object + * \param duration Duration of the timer between calls to timer_function + * \param owner ThreadGroup owner of the timer + * \param timer_function Function to call each time the timer expires + * \return true if the timer was successfully created + */ +template +inline bool CreateTimer(const std::string& timer_name, + const Duration& duration, + ThreadGroup *owner, + TimerFunction timer_function) { + std::shared_ptr> timer_thread = + std::make_shared>(timer_name, owner); + dmlc::TimerThread::start(timer_thread, duration, timer_function); + return timer_thread != nullptr; +} +} // namespace dmlc + +#endif // DMLC_THREAD_GROUP_H_ diff --git a/include/dmlc/thread_local.h b/include/dmlc/thread_local.h new file mode 100644 index 000000000000..fecaef8686de --- /dev/null +++ b/include/dmlc/thread_local.h @@ -0,0 +1,83 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file thread_local.h + * \brief Portable thread local storage. + */ +#ifndef DMLC_THREAD_LOCAL_H_ +#define DMLC_THREAD_LOCAL_H_ + +#include +#include +#include +#include "./base.h" + +namespace dmlc { + +// macro hanlding for threadlocal variables +#ifdef __GNUC__ + #define MX_THREAD_LOCAL __thread +#elif __STDC_VERSION__ >= 201112L + #define MX_THREAD_LOCAL _Thread_local +#elif defined(_MSC_VER) + #define MX_THREAD_LOCAL __declspec(thread) +#endif + +#if DMLC_CXX11_THREAD_LOCAL == 0 +#pragma message("Warning: CXX11 thread_local is not formally supported") +#endif + +/*! + * \brief A threadlocal store to store threadlocal variables. + * Will return a thread local singleton of type T + * \tparam T the type we like to store + */ +template +class ThreadLocalStore { + public: + /*! \return get a thread local singleton */ + static T* Get() { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local T inst; + return &inst; +#else + static MX_THREAD_LOCAL T* ptr = nullptr; + if (ptr == nullptr) { + ptr = new T(); + Singleton()->RegisterDelete(ptr); + } + return ptr; +#endif + } + + private: + /*! \brief constructor */ + ThreadLocalStore() {} + /*! \brief destructor */ + ~ThreadLocalStore() { + for (size_t i = 0; i < data_.size(); ++i) { + delete data_[i]; + } + } + /*! \return singleton of the store */ + static ThreadLocalStore *Singleton() { + static ThreadLocalStore inst; + return &inst; + } + /*! + * \brief register str for internal deletion + * \param str the string pointer + */ + void RegisterDelete(T *str) { + std::unique_lock lock(mutex_); + data_.push_back(str); + lock.unlock(); + } + /*! \brief internal mutex */ + std::mutex mutex_; + /*!\brief internal data */ + std::vector data_; +}; + +} // namespace dmlc + +#endif // DMLC_THREAD_LOCAL_H_ diff --git a/include/dmlc/threadediter.h b/include/dmlc/threadediter.h new file mode 100644 index 000000000000..c920156b2331 --- /dev/null +++ b/include/dmlc/threadediter.h @@ -0,0 +1,475 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file threadediter.h + * \brief thread backed iterator that can be used to implement + * general thread-based pipeline such as prefetch and pre-computation + * To use the functions in this header, C++11 is required + * \author Tianqi Chen + */ +#ifndef DMLC_THREADEDITER_H_ +#define DMLC_THREADEDITER_H_ +// defines DMLC_USE_CXX11 +#include "./base.h" +// this code depends on c++11 +#if DMLC_ENABLE_STD_THREAD +#include +#include +#include +#include +#include +#include "./data.h" +#include "./logging.h" + +namespace dmlc { +/*! + * \brief a iterator that was backed by a thread + * to pull data eagerly from a single producer into a bounded buffer + * the consumer can pull the data at its own rate + * + * NOTE: thread concurrency cost time, make sure to store big blob of data in DType + * + * Usage example: + * \code + * ThreadedIter iter; + * iter.Init(&producer); + * // the following code can be in parallel + * DType *dptr; + * while (iter.Next(&dptr)) { + * // do something on dptr + * // recycle the space + * iter.Recycle(&dptr); + * } + * \endcode + * \tparam DType the type of data blob we support + */ +template +class ThreadedIter : public DataIter { + public: + /*! + * \brief producer class interface + * that threaditer used as source to + * preduce the content + */ + class Producer { + public: + // virtual destructor + virtual ~Producer() {} + /*! \brief reset the producer to beginning */ + virtual void BeforeFirst(void) { + NotImplemented(); + } + /*! + * \brief load the data content into DType, + * the caller can pass in NULL or an existing address + * when inout_dptr is NULL: + * producer need to allocate a DType and fill the content + * when inout_dptr is specified + * producer takes need to fill the content into address + * specified inout_dptr, or delete the one and create a new one + * + * \param inout_dptr used to pass in the data holder cell + * and return the address of the cell filled + * \return true if there is next record, false if we reach the end + */ + virtual bool Next(DType **inout_dptr) = 0; + }; + /*! + * \brief constructor + * \param max_capacity maximum capacity of the queue + */ + explicit ThreadedIter(size_t max_capacity = 8) + : producer_owned_(NULL), + producer_thread_(NULL), + max_capacity_(max_capacity), + nwait_consumer_(0), + nwait_producer_(0), + out_data_(NULL) {} + /*! \brief destructor */ + virtual ~ThreadedIter(void) { + this->Destroy(); + } + /*! + * \brief destroy all the related resources + * this is equivalent to destructor, can be used + * to destroy the threaditer when user think it is + * appropriate, it is safe to call this multiple times + */ + inline void Destroy(void); + /*! + * \brief set maximum capacity of the queue + * \param max_capacity maximum capacity of the queue + */ + inline void set_max_capacity(size_t max_capacity) { + max_capacity_ = max_capacity; + } + /*! + * \brief initialize the producer and start the thread + * can only be called once + * \param producer pointer to the producer + * \param pass_ownership whether pass the ownership to the iter + * if this is true, the threaditer will delete the producer + * when destructed + */ + inline void Init(Producer *producer, bool pass_ownership = false); + /*! + * \brief initialize the producer and start the thread + * pass in two function(closure) of producer to represent the producer + * the beforefirst function is optional, and defaults to not implemented + * NOTE: the closure must remain valid until the ThreadedIter destructs + * \param next the function called to get next element, see Producer.Next + * \param beforefirst the function to call to reset the producer, see Producer.BeforeFirst + */ + inline void Init(std::function next, + std::function beforefirst = NotImplemented); + /*! + * \brief get the next data, this function is threadsafe + * \param out_dptr used to hold the pointer to the record + * after the function call, the caller takes ownership of the pointer + * the caller can call recycle to return ownership back to the threaditer + * so that the pointer can be re-used + * \return true if there is next record, false if we reach the end + * \sa Recycle + */ + inline bool Next(DType **out_dptr); + /*! + * \brief recycle the data cell, this function is threadsafe + * the threaditer can reuse the data cell for future data loading + * \param inout_dptr pointer to the dptr to recycle, after the function call + * the content of inout_dptr will be set to NULL + */ + inline void Recycle(DType **inout_dptr); + + /*! + * \brief Rethrows exception which is set by the producer + */ + inline void ThrowExceptionIfSet(void); + + /*! + * \brief clears exception_ptr, called from Init + */ + inline void ClearException(void); + + /*! + * \brief adapt the iterator interface's Next + * NOTE: the call to this function is not threadsafe + * use the other Next instead + * \return true if there is next record, false if we reach the end + */ + virtual bool Next(void) { + if (out_data_ != NULL) { + this->Recycle(&out_data_); + } + if (Next(&out_data_)) { + return true; + } else { + return false; + } + } + /*! + * \brief adapt the iterator interface's Value + * NOTE: the call to this function is not threadsafe + * use the other Next instead + */ + virtual const DType &Value(void) const { + CHECK(out_data_ != NULL) << "Calling Value at beginning or end?"; + return *out_data_; + } + /*! \brief set the iterator before first location */ + virtual void BeforeFirst(void) { + ThrowExceptionIfSet(); + std::unique_lock lock(mutex_); + if (out_data_ != NULL) { + free_cells_.push(out_data_); + out_data_ = NULL; + } + if (producer_sig_ == kDestroy) return; + + producer_sig_ = kBeforeFirst; + CHECK(!producer_sig_processed_); + if (nwait_producer_ != 0) { + producer_cond_.notify_one(); + } + CHECK(!producer_sig_processed_); + // wait until the request has been processed + consumer_cond_.wait(lock, [this]() { + return producer_sig_processed_; + }); + producer_sig_processed_ = false; + bool notify = nwait_producer_ != 0 && !produce_end_; + lock.unlock(); + // notify producer, in case they are waiting for the condition. + if (notify) producer_cond_.notify_one(); + ThrowExceptionIfSet(); + } + + private: + /*! \brief not support BeforeFirst */ + inline static void NotImplemented(void) { + LOG(FATAL) << "BeforeFirst is not supported"; + } + /*! \brief signals send to producer */ + enum Signal { + kProduce, + kBeforeFirst, + kDestroy + }; + /*! \brief producer class */ + Producer *producer_owned_; + /*! \brief signal to producer */ + Signal producer_sig_; + /*! \brief whether the special signal other than kProduce is procssed */ + bool producer_sig_processed_; + /*! \brief thread that runs the producer */ + std::thread *producer_thread_; + /*! \brief whether produce ends */ + bool produce_end_; + /*! \brief maximum queue size */ + size_t max_capacity_; + /*! \brief internal mutex */ + std::mutex mutex_; + /*! brief internal mutex for exceptions */ + std::mutex mutex_exception_; + /*! \brief number of consumer waiting */ + unsigned nwait_consumer_; + /*! \brief number of consumer waiting */ + unsigned nwait_producer_; + /*! \brief conditional variable for producer thread */ + std::condition_variable producer_cond_; + /*! \brief conditional variable for consumer threads */ + std::condition_variable consumer_cond_; + /*! \brief the current output cell */ + DType *out_data_; + /*! \brief internal queue of producer */ + std::queue queue_; + /*! \brief free cells that can be used */ + std::queue free_cells_; + /*! \brief holds a reference to iterator exception thrown in spawned threads */ + std::exception_ptr iter_exception_{nullptr}; +}; + +// implementation of functions +template inline void ThreadedIter::Destroy(void) { + if (producer_thread_ != NULL) { + { + // lock the mutex + std::lock_guard lock(mutex_); + // send destroy signal + producer_sig_ = kDestroy; + if (nwait_producer_ != 0) { + producer_cond_.notify_one(); + } + } + producer_thread_->join(); + delete producer_thread_; + producer_thread_ = NULL; + } + // end of critical region + // now the slave thread should exit + while (free_cells_.size() != 0) { + delete free_cells_.front(); + free_cells_.pop(); + } + while (queue_.size() != 0) { + delete queue_.front(); + queue_.pop(); + } + if (producer_owned_ != NULL) { + delete producer_owned_; + } + if (out_data_ != NULL) { + delete out_data_; + out_data_ = NULL; + } +} + +template +inline void ThreadedIter:: +Init(Producer *producer, bool pass_ownership) { + CHECK(producer_owned_ == NULL) << "can only call Init once"; + if (pass_ownership) producer_owned_ = producer; + auto next = [producer](DType **dptr) { + return producer->Next(dptr); + }; + auto beforefirst = [producer]() { + producer->BeforeFirst(); + }; + this->Init(next, beforefirst); +} + +template +inline void ThreadedIter::Init(std::function next, + std::function beforefirst) { + producer_sig_ = kProduce; + producer_sig_processed_ = false; + produce_end_ = false; + ClearException(); + // procedure running in prodcuer + // run producer thread + auto producer_fun = [this, next, beforefirst]() { + while (true) { + try { + DType *cell = NULL; + { + // lockscope + std::unique_lock lock(mutex_); + ++this->nwait_producer_; + producer_cond_.wait(lock, [this]() { + if (producer_sig_ == kProduce) { + bool ret = !produce_end_ && (queue_.size() < max_capacity_ || + free_cells_.size() != 0); + return ret; + } else { + return true; + } + }); + --this->nwait_producer_; + if (producer_sig_ == kProduce) { + if (free_cells_.size() != 0) { + cell = free_cells_.front(); + free_cells_.pop(); + } + } else if (producer_sig_ == kBeforeFirst) { + // reset the producer + beforefirst(); + // cleanup the queue + while (queue_.size() != 0) { + free_cells_.push(queue_.front()); + queue_.pop(); + } + // reset the state + produce_end_ = false; + producer_sig_processed_ = true; + producer_sig_ = kProduce; + // notify consumer that all the process as been done. + lock.unlock(); + consumer_cond_.notify_all(); + continue; + } else { + // destroy the thread + DCHECK(producer_sig_ == kDestroy); + producer_sig_processed_ = true; + produce_end_ = true; + consumer_cond_.notify_all(); + return; + } + } // end of lock scope + // now without lock + produce_end_ = !next(&cell); + DCHECK(cell != NULL || produce_end_); + bool notify; + { + // lockscope + std::lock_guard lock(mutex_); + if (!produce_end_) { + queue_.push(cell); + } else { + if (cell != NULL) + free_cells_.push(cell); + } + // put things into queue + notify = nwait_consumer_ != 0; + } + if (notify) + consumer_cond_.notify_all(); + } catch (dmlc::Error &e) { + // Shouldn't throw exception in destructor + DCHECK(producer_sig_ != kDestroy); + { + std::lock_guard lock(mutex_exception_); + if (!iter_exception_) { + iter_exception_ = std::current_exception(); + } + } + bool next_notify = false; + { + std::unique_lock lock(mutex_); + if (producer_sig_ == kBeforeFirst) { + while (queue_.size() != 0) { + free_cells_.push(queue_.front()); + queue_.pop(); + } + produce_end_ = true; + producer_sig_processed_ = true; + lock.unlock(); + consumer_cond_.notify_all(); + } else if (producer_sig_ == kProduce) { + produce_end_ = true; + next_notify = nwait_consumer_ != 0; + lock.unlock(); + if (next_notify) + consumer_cond_.notify_all(); + } + } + return; + } + } + }; + producer_thread_ = new std::thread(producer_fun); +} + +template +inline bool ThreadedIter::Next(DType **out_dptr) { + if (producer_sig_ == kDestroy) + return false; + ThrowExceptionIfSet(); + std::unique_lock lock(mutex_); + CHECK(producer_sig_ == kProduce) + << "Make sure you call BeforeFirst not inconcurrent with Next!"; + ++nwait_consumer_; + consumer_cond_.wait(lock, + [this]() { return queue_.size() != 0 || produce_end_; }); + --nwait_consumer_; + if (queue_.size() != 0) { + *out_dptr = queue_.front(); + queue_.pop(); + bool notify = nwait_producer_ != 0 && !produce_end_; + lock.unlock(); + if (notify) + producer_cond_.notify_one(); + + ThrowExceptionIfSet(); + return true; + } else { + CHECK(produce_end_); + lock.unlock(); + + ThrowExceptionIfSet(); + return false; + } +} + +template +inline void ThreadedIter::Recycle(DType **inout_dptr) { + bool notify; + ThrowExceptionIfSet(); + { + std::lock_guard lock(mutex_); + free_cells_.push(*inout_dptr); + *inout_dptr = NULL; + notify = nwait_producer_ != 0 && !produce_end_; + } + if (notify) + producer_cond_.notify_one(); + ThrowExceptionIfSet(); +} + +template inline void ThreadedIter::ThrowExceptionIfSet(void) { + std::exception_ptr tmp_exception{nullptr}; + { + std::lock_guard lock(mutex_exception_); + if (iter_exception_) { + tmp_exception = iter_exception_; + } + } + if (tmp_exception) + std::rethrow_exception(tmp_exception); +} + +template inline void ThreadedIter::ClearException(void) { + std::lock_guard lock(mutex_exception_); + iter_exception_ = nullptr; +} + +} // namespace dmlc +#endif // DMLC_USE_CXX11 +#endif // DMLC_THREADEDITER_H_ diff --git a/include/dmlc/timer.h b/include/dmlc/timer.h new file mode 100644 index 000000000000..c97059f97812 --- /dev/null +++ b/include/dmlc/timer.h @@ -0,0 +1,49 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file timer.h + * \brief cross platform timer for timing + * \author Tianqi Chen + */ +#ifndef DMLC_TIMER_H_ +#define DMLC_TIMER_H_ + +#include "base.h" + +#if DMLC_USE_CXX11 +#include +#endif + +#include +#ifdef __MACH__ +#include +#include +#endif +#include "./logging.h" + +namespace dmlc { +/*! + * \brief return time in seconds + */ +inline double GetTime(void) { + #if DMLC_USE_CXX11 + return std::chrono::duration( + std::chrono::high_resolution_clock::now().time_since_epoch()).count(); + #elif defined __MACH__ + clock_serv_t cclock; + mach_timespec_t mts; + host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock); + CHECK(clock_get_time(cclock, &mts) == 0) << "failed to get time"; + mach_port_deallocate(mach_task_self(), cclock); + return static_cast(mts.tv_sec) + static_cast(mts.tv_nsec) * 1e-9; + #else + #if defined(__unix__) || defined(__linux__) + timespec ts; + CHECK(clock_gettime(CLOCK_REALTIME, &ts) == 0) << "failed to get time"; + return static_cast(ts.tv_sec) + static_cast(ts.tv_nsec) * 1e-9; + #else + return static_cast(time(NULL)); + #endif + #endif +} +} // namespace dmlc +#endif // DMLC_TIMER_H_ diff --git a/include/dmlc/type_traits.h b/include/dmlc/type_traits.h new file mode 100644 index 000000000000..c528903499e3 --- /dev/null +++ b/include/dmlc/type_traits.h @@ -0,0 +1,191 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file type_traits.h + * \brief type traits information header + */ +#ifndef DMLC_TYPE_TRAITS_H_ +#define DMLC_TYPE_TRAITS_H_ + +#include "./base.h" +#if DMLC_USE_CXX11 +#include +#endif +#include + +namespace dmlc { +/*! + * \brief whether a type is pod type + * \tparam T the type to query + */ +template +struct is_pod { +#if DMLC_USE_CXX11 + /*! \brief the value of the traits */ + static const bool value = std::is_pod::value; +#else + /*! \brief the value of the traits */ + static const bool value = false; +#endif +}; + + +/*! + * \brief whether a type is integer type + * \tparam T the type to query + */ +template +struct is_integral { +#if DMLC_USE_CXX11 + /*! \brief the value of the traits */ + static const bool value = std::is_integral::value; +#else + /*! \brief the value of the traits */ + static const bool value = false; +#endif +}; + +/*! + * \brief whether a type is floating point type + * \tparam T the type to query + */ +template +struct is_floating_point { +#if DMLC_USE_CXX11 + /*! \brief the value of the traits */ + static const bool value = std::is_floating_point::value; +#else + /*! \brief the value of the traits */ + static const bool value = false; +#endif +}; + +/*! + * \brief whether a type is arithemetic type + * \tparam T the type to query + */ +template +struct is_arithmetic { +#if DMLC_USE_CXX11 + /*! \brief the value of the traits */ + static const bool value = std::is_arithmetic::value; +#else + /*! \brief the value of the traits */ + static const bool value = (dmlc::is_integral::value || + dmlc::is_floating_point::value); +#endif +}; + +/*! + * \brief helper class to construct a string that represents type name + * + * Specialized this class to defined type name of custom types + * + * \tparam T the type to query + */ +template +struct type_name_helper { + /*! + * \return a string of typename. + */ + static inline std::string value() { + return ""; + } +}; + +/*! + * \brief the string representation of type name + * \tparam T the type to query + * \return a const string of typename. + */ +template +inline std::string type_name() { + return type_name_helper::value(); +} + +/*! + * \brief whether a type have save/load function + * \tparam T the type to query + */ +template +struct has_saveload { + /*! \brief the value of the traits */ + static const bool value = false; +}; + +/*! + * \brief template to select type based on condition + * For example, IfThenElseType::Type will give int + * \tparam cond the condition + * \tparam Then the typename to be returned if cond is true + * \tparam Else typename to be returned if cond is false +*/ +template +struct IfThenElseType; + +/*! \brief macro to quickly declare traits information */ +#define DMLC_DECLARE_TRAITS(Trait, Type, Value) \ + template<> \ + struct Trait { \ + static const bool value = Value; \ + } + +/*! \brief macro to quickly declare traits information */ +#define DMLC_DECLARE_TYPE_NAME(Type, Name) \ + template<> \ + struct type_name_helper { \ + static inline std::string value() { \ + return Name; \ + } \ + } + +//! \cond Doxygen_Suppress +// declare special traits when C++11 is not available +#if DMLC_USE_CXX11 == 0 +DMLC_DECLARE_TRAITS(is_pod, char, true); +DMLC_DECLARE_TRAITS(is_pod, int8_t, true); +DMLC_DECLARE_TRAITS(is_pod, int16_t, true); +DMLC_DECLARE_TRAITS(is_pod, int32_t, true); +DMLC_DECLARE_TRAITS(is_pod, int64_t, true); +DMLC_DECLARE_TRAITS(is_pod, uint8_t, true); +DMLC_DECLARE_TRAITS(is_pod, uint16_t, true); +DMLC_DECLARE_TRAITS(is_pod, uint32_t, true); +DMLC_DECLARE_TRAITS(is_pod, uint64_t, true); +DMLC_DECLARE_TRAITS(is_pod, float, true); +DMLC_DECLARE_TRAITS(is_pod, double, true); + +DMLC_DECLARE_TRAITS(is_integral, char, true); +DMLC_DECLARE_TRAITS(is_integral, int8_t, true); +DMLC_DECLARE_TRAITS(is_integral, int16_t, true); +DMLC_DECLARE_TRAITS(is_integral, int32_t, true); +DMLC_DECLARE_TRAITS(is_integral, int64_t, true); +DMLC_DECLARE_TRAITS(is_integral, uint8_t, true); +DMLC_DECLARE_TRAITS(is_integral, uint16_t, true); +DMLC_DECLARE_TRAITS(is_integral, uint32_t, true); +DMLC_DECLARE_TRAITS(is_integral, uint64_t, true); + +DMLC_DECLARE_TRAITS(is_floating_point, float, true); +DMLC_DECLARE_TRAITS(is_floating_point, double, true); + +#endif + +DMLC_DECLARE_TYPE_NAME(float, "float"); +DMLC_DECLARE_TYPE_NAME(double, "double"); +DMLC_DECLARE_TYPE_NAME(int, "int"); +DMLC_DECLARE_TYPE_NAME(uint32_t, "int (non-negative)"); +DMLC_DECLARE_TYPE_NAME(uint64_t, "long (non-negative)"); +DMLC_DECLARE_TYPE_NAME(std::string, "string"); +DMLC_DECLARE_TYPE_NAME(bool, "boolean"); +DMLC_DECLARE_TYPE_NAME(void*, "ptr"); + +template +struct IfThenElseType { + typedef Then Type; +}; + +template +struct IfThenElseType { + typedef Else Type; +}; +//! \endcond +} // namespace dmlc +#endif // DMLC_TYPE_TRAITS_H_ diff --git a/include/mshadow/README.md b/include/mshadow/README.md new file mode 100644 index 000000000000..86276af013e2 --- /dev/null +++ b/include/mshadow/README.md @@ -0,0 +1,8 @@ +Code Guide +==== +This readme contains notes about code in mshadow. MShadow generally follows Google's C++ Style. + +Convention +==== +* Basically, all the files ends in ```-inl.h, -inl.cuh``` are implementations, and can be ignored if only using mshadow +* The files ends in ```.h``` are heavily commented with [doxyen format](http://www.doxygen.org/), and can be used to generate the corresponding document. diff --git a/include/mshadow/base.h b/include/mshadow/base.h new file mode 100755 index 000000000000..4cdab74d6a74 --- /dev/null +++ b/include/mshadow/base.h @@ -0,0 +1,1106 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file base.h + * \brief definitions of base types, operators, macros functions + * + * \author Bing Xu, Tianqi Chen + */ +#ifndef MSHADOW_BASE_H_ +#define MSHADOW_BASE_H_ +#ifdef _MSC_VER +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif +#ifndef _CRT_SECURE_NO_DEPRECATE +#define _CRT_SECURE_NO_DEPRECATE +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#endif +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +//! \cond Doxygen_Suppress +typedef signed char int8_t; +typedef __int16 int16_t; +typedef __int32 int32_t; +typedef __int64 int64_t; +typedef unsigned char uint8_t; +typedef unsigned __int16 uint16_t; +typedef unsigned __int32 uint32_t; +typedef unsigned __int64 uint64_t; +//! \endcond +#else +#include +#endif +// macro defintiions +/*! + * \brief if this macro is define to be 1, + * mshadow should compile without any of other libs + */ +#ifndef MSHADOW_STAND_ALONE +#define MSHADOW_STAND_ALONE 0 +#endif +/*! \brief whether do padding during allocation */ +#ifndef MSHADOW_ALLOC_PAD +#define MSHADOW_ALLOC_PAD true +#endif +/*! + * \brief + * x dimension of data must be bigger pad_size * ratio to be alloced padded memory, + * otherwise use tide allocation + * for example, if pad_ratio=2, GPU memory alignement size is 32, + * then we will only allocate padded memory if x dimension > 64 + * set it to 0 then we will always allocate padded memory + */ +#ifndef MSHADOW_MIN_PAD_RATIO + #define MSHADOW_MIN_PAD_RATIO 2 +#endif + +#if MSHADOW_STAND_ALONE + #define MSHADOW_USE_CBLAS 0 + #define MSHADOW_USE_MKL 0 + #define MSHADOW_USE_CUDA 0 +#endif + +/*! + * \brief force user to use GPU stream during computation + * error will be shot when default stream NULL is used + */ +#ifndef MSHADOW_FORCE_STREAM +#define MSHADOW_FORCE_STREAM 1 +#endif + +/*! \brief use CBLAS for CBLAS */ +#ifndef MSHADOW_USE_CBLAS + #define MSHADOW_USE_CBLAS 0 +#endif +/*! \brief use MKL for BLAS */ +#ifndef MSHADOW_USE_MKL + #define MSHADOW_USE_MKL 1 +#endif + +/*! + * \brief use CUDA support, must ensure that the cuda include path is correct, + * or directly compile using nvcc + */ +#ifndef MSHADOW_USE_CUDA + #define MSHADOW_USE_CUDA 1 +#endif + +/*! + * \brief use CUDNN support, must ensure that the cudnn include path is correct + */ +#ifndef MSHADOW_USE_CUDNN + #define MSHADOW_USE_CUDNN 0 +#endif + +/*! + * \brief use CUSOLVER support + */ +#ifndef MSHADOW_USE_CUSOLVER + #define MSHADOW_USE_CUSOLVER MSHADOW_USE_CUDA +#endif + +/*! + * \brief seems CUDAARCH is deprecated in future NVCC + * set this to 1 if you want to use CUDA version smaller than 2.0 + */ +#ifndef MSHADOW_OLD_CUDA +#define MSHADOW_OLD_CUDA 0 +#endif + +/*! + * \brief macro to decide existence of c++11 compiler + */ +#ifndef MSHADOW_IN_CXX11 + #if (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\ + __cplusplus >= 201103L || defined(_MSC_VER)) + #define MSHADOW_IN_CXX11 1 + #else + #define MSHADOW_IN_CXX11 0 + #endif +#endif + +/*! \brief whether use SSE */ +#ifndef MSHADOW_USE_SSE + #define MSHADOW_USE_SSE 1 +#endif + +/*! \brief whether use F16C instruction set architecture extension */ +#ifndef MSHADOW_USE_F16C + #if defined(_MSC_VER) || defined(__CUDACC__) + #define MSHADOW_USE_F16C 0 + #elif defined(__clang__) && \ + ((__clang_major__ < 8) || ((__clang_major__ == 8) && (__clang_minor__ < 1))) + #define MSHADOW_USE_F16C 0 + #else + #define MSHADOW_USE_F16C 1 + #endif +#endif + +/*! \brief whether use NVML to get dynamic info */ +#ifndef MSHADOW_USE_NVML + #define MSHADOW_USE_NVML 0 +#endif +// SSE is conflict with cudacc +#ifdef __CUDACC__ + #undef MSHADOW_USE_SSE + #define MSHADOW_USE_SSE 0 +#endif + +#if MSHADOW_USE_CBLAS +extern "C" { + #include +} +#elif MSHADOW_USE_MKL + #include + #include + #include + #include + #include +#endif + +#if MSHADOW_USE_CUDA + #include + #include + #include +#endif + +#if MSHADOW_USE_CUDNN == 1 + #include +#endif + +#if MSHADOW_USE_CUSOLVER == 1 + #include +#endif + +#if MSHADOW_USE_NVML + #include +#endif + +// -------------------------------- +// MSHADOW_XINLINE is used for inlining template code for both CUDA and CPU code +#ifdef MSHADOW_XINLINE + #error "MSHADOW_XINLINE must not be defined" +#endif +#ifdef _MSC_VER +#define MSHADOW_FORCE_INLINE __forceinline +#pragma warning(disable : 4068) +#else +#define MSHADOW_FORCE_INLINE inline __attribute__((always_inline)) +#endif +#ifdef __CUDACC__ + #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE __device__ __host__ +#else + #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE +#endif +/*! \brief cpu force inline */ +#define MSHADOW_CINLINE MSHADOW_FORCE_INLINE + +#if defined(__GXX_EXPERIMENTAL_CXX0X) ||\ + defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L + #define MSHADOW_CONSTEXPR constexpr +#else + #define MSHADOW_CONSTEXPR const +#endif + +/*! + * \brief default data type for tensor string + * in code release, change it to default_real_t + * during development, change it to empty string so that missing + * template arguments can be detected + */ +#ifndef MSHADOW_DEFAULT_DTYPE +#define MSHADOW_DEFAULT_DTYPE = ::mshadow::default_real_t +#endif + +/*! + * \brief DMLC marco for logging + */ +#ifndef MSHADOW_USE_GLOG +#define MSHADOW_USE_GLOG DMLC_USE_GLOG +#endif // MSHADOW_USE_GLOG + +#if DMLC_USE_CXX11 +#define MSHADOW_THROW_EXCEPTION noexcept(false) +#define MSHADOW_NO_EXCEPTION noexcept(true) +#else +#define MSHADOW_THROW_EXCEPTION +#define MSHADOW_NO_EXCEPTION +#endif + +#if defined(_MSC_VER) +#define MSHADOW_ALIGNED(x) __declspec(align(x)) +#else +#define MSHADOW_ALIGNED(x) __attribute__ ((aligned(x))) +#endif + +/*! + * \brief Protected cuda call in mshadow + * \param func Expression to call. + * It checks for CUDA errors after invocation of the expression. + */ +#define MSHADOW_CUDA_CALL(func) \ + { \ + cudaError_t e = (func); \ + if (e == cudaErrorCudartUnloading) { \ + throw dmlc::Error(cudaGetErrorString(e)); \ + } \ + CHECK(e == cudaSuccess) \ + << "CUDA: " << cudaGetErrorString(e); \ + } + +/*! + * \brief Run function and catch error, log unknown error. + * \param func Expression to call. + */ +#define MSHADOW_CATCH_ERROR(func) \ + { \ + try { \ + (func); \ + } catch (const dmlc::Error &e) { \ + std::string what = e.what(); \ + if (what.find("driver shutting down") == std::string::npos) { \ + LOG(ERROR) << "Ignore CUDA Error " << what; \ + } \ + } \ + } + +#include "./half.h" +#include "./half2.h" +#include "./logging.h" +/*! \brief namespace for mshadow */ +namespace mshadow { +/*! \brief buffer size for each random number generator */ +const unsigned kRandBufferSize = 1000000; +/*! \brief pi */ +const float kPi = 3.1415926f; +/*! \brief type that will be used for index */ +typedef int64_t index_t; + +#ifdef _WIN32 + /*! \brief openmp index for windows */ + typedef int64_t openmp_index_t; +#else + /*! \brief openmp index for linux */ + typedef index_t openmp_index_t; +#endif + +/*! \brief float point type that will be used in default by mshadow */ +typedef float default_real_t; + +/*! \brief data type flag */ +enum TypeFlag { + kFloat32 = 0, + kFloat64 = 1, + kFloat16 = 2, + kUint8 = 3, + kInt32 = 4, + kInt8 = 5, + kInt64 = 6, +}; + +template +struct DataType; + +template<> +struct DataType { + static const int kFlag = kFloat32; + static const int kLanes = 1; +#if MSHADOW_USE_CUDA +#if (CUDA_VERSION >= 8000) + static const cudaDataType_t kCudaFlag = CUDA_R_32F; +#endif +#if MSHADOW_USE_CUDNN + static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_FLOAT; + typedef float ScaleType; +#endif +#endif +}; +template<> +struct DataType { + static const int kFlag = kFloat64; + static const int kLanes = 1; +#if MSHADOW_USE_CUDA +#if (CUDA_VERSION >= 8000) + static const cudaDataType_t kCudaFlag = CUDA_R_64F; +#endif +#if MSHADOW_USE_CUDNN + static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_DOUBLE; + typedef double ScaleType; +#endif +#endif +}; +template<> +struct DataType { + static const int kFlag = kFloat16; + static const int kLanes = 1; +#if MSHADOW_USE_CUDA +#if (CUDA_VERSION >= 8000) + static const cudaDataType_t kCudaFlag = CUDA_R_16F; +#endif +#if MSHADOW_USE_CUDNN + static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_HALF; + typedef float ScaleType; +#endif +#endif +}; +template<> +struct DataType { + static const int kFlag = kFloat16; + static const int kLanes = 2; +}; +template<> +struct DataType { + static const int kFlag = kUint8; + static const int kLanes = 1; +#if MSHADOW_USE_CUDA +#if (CUDA_VERSION >= 8000) + static const cudaDataType_t kCudaFlag = CUDA_R_8U; +#endif +#if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6) + // no uint8 in cudnn for now + static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8; + typedef uint8_t ScaleType; +#endif +#endif +}; +template<> +struct DataType { + static const int kFlag = kInt8; + static const int kLanes = 1; +#if MSHADOW_USE_CUDA +#if (CUDA_VERSION >= 8000) + static const cudaDataType_t kCudaFlag = CUDA_R_8I; +#endif +#if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6) + static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8; + typedef int8_t ScaleType; +#endif +#endif +}; +template<> +struct DataType { + static const int kFlag = kInt32; + static const int kLanes = 1; +#if MSHADOW_USE_CUDA +#if (CUDA_VERSION >= 8000) + static const cudaDataType_t kCudaFlag = CUDA_R_32I; +#endif +#if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6) + static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT32; + typedef int32_t ScaleType; +#endif +#endif +}; +template<> +struct DataType { + static const int kFlag = kInt64; + static const int kLanes = 1; +}; + +/*! \brief type enum value for default real type */ +const int default_type_flag = DataType::kFlag; + +/*! layout flag */ +enum LayoutFlag { + kNCHW = 0, + kNHWC, + kCHWN, + + kNCW = 1 << 3, + kNWC, + kCWN, + + kNCDHW = 1 << 5, + kNDHWC, + kCDHWN +}; + +template +struct LayoutType; + +template<> +struct LayoutType { + static const index_t kNdim = 4; +#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) + static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW; +#else + static const int kCudnnFlag = -1; +#endif +}; + +template<> +struct LayoutType { + static const index_t kNdim = 4; +#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) + static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC; +#else + static const int kCudnnFlag = -1; +#endif +}; + +/*! \brief default layout for 4d tensor */ +const int default_layout = kNCHW; + +template<> +struct LayoutType { + static const index_t kNdim = 5; +#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) + static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW; +#else + static const int kCudnnFlag = -1; +#endif +}; + +template<> +struct LayoutType { + static const index_t kNdim = 5; +#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) + static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC; +#else + static const int kCudnnFlag = -1; +#endif +}; + +/*! \brief default layout for 5d tensor */ +const int default_layout_5d = kNCDHW; + +/*! \brief namespace for operators */ +namespace op { +// binary operator +/*! \brief mul operator */ +struct mul{ + /*! \brief map a, b to result using defined operation */ + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + return a * b; + } +}; +/*! \brief plus operator */ +struct plus { + /*! \brief map a, b to result using defined operation */ + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + return a + b; + } +}; +/*! \brief minus operator */ +struct minus { + /*! \brief map a, b to result using defined operation */ + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + return a - b; + } +}; +/*! \brief divide operator */ +struct div { + /*! \brief map a, b to result using defined operation */ + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + return a / b; + } +}; +/*! \brief get rhs */ +struct right { + /*! \brief map a, b to result using defined operation */ + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + return b; + } +}; +// unary operator/ function: example +// these operators can be defined by user, +// in the same style as binary and unary operator +// to use, simply write F( src ) +/*! \brief identity function that maps a real number to it self */ +struct identity{ + /*! \brief map a to result using defined operation */ + template + MSHADOW_XINLINE static DType Map(DType a) { + return a; + } +}; +} // namespace op +/*! \brief namespace for savers */ +namespace sv { +/*! \brief save to saver: = */ +struct saveto { + /*! \brief save b to a using save method */ + template + MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) + a = b; + } + /*! \brief helper constant to use BLAS, alpha */ + inline static default_real_t AlphaBLAS(void) { return 1.0f; } + /*! \brief helper constant to use BLAS, beta */ + inline static default_real_t BetaBLAS(void) { return 0.0f; } + /*! \brief corresponding binary operator type */ + typedef op::right OPType; +}; +/*! \brief save to saver: += */ +struct plusto { + /*! \brief save b to a using save method */ + template + MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) + a += b; + } + /*! \brief helper constant to use BLAS, alpha */ + inline static default_real_t AlphaBLAS(void) { return 1.0f; } + /*! \brief helper constant to use BLAS, beta */ + inline static default_real_t BetaBLAS(void) { return 1.0f; } + /*! \brief corresponding binary operator type */ + typedef op::plus OPType; +}; +/*! \brief minus to saver: -= */ +struct minusto { + /*! \brief save b to a using save method */ + template + MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) + a -= b; + } + /*! \brief helper constant to use BLAS, alpha */ + inline static default_real_t AlphaBLAS(void) { return -1.0f; } + /*! \brief helper constant to use BLAS, beta */ + inline static default_real_t BetaBLAS(void) { return 1.0f; } + /*! \brief corresponding binary operator type */ + typedef op::minus OPType; +}; +/*! \brief multiply to saver: *= */ +struct multo { + /*! \brief save b to a using save method */ + template + MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) + a *= b; + } + /*! \brief corresponding binary operator type */ + typedef op::mul OPType; +}; +/*! \brief divide to saver: /= */ +struct divto { + /*! \brief save b to a using save method */ + template + MSHADOW_XINLINE static void Save(DType& a, DType b) { // NOLINT(*) + a /= b; + } + /*! \brief corresponding binary operator type */ + typedef op::div OPType; +}; +} // namespace sv +/*! \brief namespace for potential reducer operations */ +namespace red { +namespace limits { +/*! + * \brief minimum value of certain types + * \tparam DType data type + */ +template +MSHADOW_XINLINE DType MinValue(void); +/*! \brief minimum value of float */ +template<> +MSHADOW_XINLINE float MinValue(void) { + return -FLT_MAX; +} +/*! \brief minimum value of double */ +template<> +MSHADOW_XINLINE double MinValue(void) { + return -DBL_MAX; +} +/*! \brief minimum value of half */ +template<> +MSHADOW_XINLINE half::half_t MinValue(void) { + return MSHADOW_HALF_MIN; +} +/*! \brief minimum value of uint8_t */ +template<> +MSHADOW_XINLINE uint8_t MinValue(void) { + return 0; +} +/*! \brief minimum value of int8_t */ +template<> +MSHADOW_XINLINE int8_t MinValue(void) { + return SCHAR_MIN; +} +/*! \brief minimum value of int32_t */ +template<> +MSHADOW_XINLINE int MinValue(void) { + return INT_MIN; +} +/*! \brief minimum value of int64_t */ +template<> +MSHADOW_XINLINE int64_t MinValue(void) { + return LLONG_MIN; +} + +/*! + * \brief maximum value of certain types + * \tparam DType data type + */ +template +MSHADOW_XINLINE DType MaxValue(void); +/*! \brief maximum value of float */ +template<> +MSHADOW_XINLINE float MaxValue(void) { + return FLT_MAX; +} +/*! \brief maximum value of double */ +template<> +MSHADOW_XINLINE double MaxValue(void) { + return DBL_MAX; +} +/*! \brief maximum value of half */ +template<> +MSHADOW_XINLINE half::half_t MaxValue(void) { + return MSHADOW_HALF_MAX; +} +/*! \brief maximum value of uint8_t */ +template<> +MSHADOW_XINLINE uint8_t MaxValue(void) { + return UCHAR_MAX; +} +/*! \brief maximum value of int8_t */ +template<> +MSHADOW_XINLINE int8_t MaxValue(void) { + return SCHAR_MAX; +} +/*! \brief maximum value of int32_t */ +template<> +MSHADOW_XINLINE int MaxValue(void) { + return INT_MAX; +} +/*! \brief maximum value of int64_t */ +template<> +MSHADOW_XINLINE int64_t MaxValue(void) { + return LLONG_MAX; +} +} // namespace limits + +/*! \brief sum reducer */ +struct sum { + /*! \brief do reduction into dst */ + template + MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) + dst += src; + } + /*! \brief do stable reduction into dst */ + template + MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*) + DType y = src - residual; + DType t = dst + y; + residual = (t - dst) - y; + dst = t; + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*) + DType t1 = dst_val + src_val; + DType e = t1 - dst_val; + DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual; + dst_val = t1 + t2; + dst_residual = t2 - (dst_val - t1); + } + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*) + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*) + /*! + *\brief calculate gradient of redres with respect to redsrc, + * redres: reduced result, redsrc: one of reduction element + */ + template + MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { + return 1; + } + /*! + *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) + initv = 0; + } + /*! + *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*) + SetInitValue(initv); + residual = 0; + } +}; +/*! \brief maximum reducer */ +struct maximum { + /*! \brief do reduction into dst */ + template + MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) + using namespace std; +#ifdef __CUDACC__ + dst = ::max(dst, src); +#else + dst = max(dst, src); +#endif // __CUDACC__ + } + /*! \brief do reduction into dst */ + template + MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType &none) { // NOLINT(*) + Reduce(dst, src); + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*) + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*) + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*) + /*! + * \brief calculate gradient of redres with respect to redsrc, + * redres: reduced result, redsrc: one of reduction element + */ + template + MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { + return redres == redsrc ? 1: 0; + } + /*! + *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) + initv = limits::MinValue(); + } + /*! + *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*) + SetInitValue(initv); + } +}; +/*! \brief minimum reducer */ +struct minimum { + /*! \brief do reduction into dst */ + template + MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) + using namespace std; +#ifdef __CUDACC__ + dst = ::min(dst, src); +#else + dst = min(dst, src); +#endif // __CUDACC__ + } + /*! \brief do reduction into dst */ + template + MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType &none) { // NOLINT(*) + Reduce(dst, src); + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) + Reduce(dst_val, src_val); + } + /*! \brief combine the results of two reducers */ + template + MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*) + Reduce(dst_val, src_val); + } + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*) + /*! \brief finalize reduction */ + template + MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*) + /*! + * \brief calculate gradient of redres with respect to redsrc, + * redres: reduced result, redsrc: one of reduction element + */ + template + MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { + return redres == redsrc ? 1: 0; + } + /*! + *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) + initv = limits::MaxValue(); + } + /*! + *\brief set the initial value during reduction + */ + template + MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*) + SetInitValue(initv); + } +}; +} // namespace red + +#define MSHADOW_TYPE_SWITCH(type, DType, ...) \ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kUint8: \ + { \ + typedef uint8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt8: \ + { \ + typedef int8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt32: \ + { \ + typedef int32_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt64: \ + { \ + typedef int64_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + +#define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...) \ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half2_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kUint8: \ + { \ + typedef uint8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt32: \ + { \ + typedef int32_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt64: \ + { \ + typedef int64_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + +#define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...) \ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + LOG(FATAL) << "This operation only supports " \ + "32-bit and 64-bit floating point"; \ + } + +#define MSHADOW_REAL_TYPE_SWITCH(type, DType, ...) \ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kUint8: \ + LOG(FATAL) << "This operation only support " \ + "floating point types not uint8"; \ + break; \ + case mshadow::kInt8: \ + LOG(FATAL) << "This operation only support " \ + "floating point types not int8"; \ + break; \ + case mshadow::kInt32: \ + LOG(FATAL) << "This operation only support " \ + "floating point types, not int32";\ + break; \ + case mshadow::kInt64: \ + LOG(FATAL) << "This operation only support " \ + "floating point types, not int64";\ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + +#define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \ + switch (type$) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType$; \ + typedef float DLargeType$; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType$; \ + typedef double DLargeType$; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half_t DType$; \ + typedef float DLargeType$; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kUint8: \ + LOG(FATAL) << "This operation only support " \ + "floating point types not uint8"; \ + break; \ + case mshadow::kInt8: \ + LOG(FATAL) << "This operation only support " \ + "floating point types not int8"; \ + break; \ + case mshadow::kInt32: \ + LOG(FATAL) << "This operation only support " \ + "floating point types, not int32";\ + break; \ + case mshadow::kInt64: \ + LOG(FATAL) << "This operation only support " \ + "floating point types, not int64";\ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type$; \ + } + +#define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \ + switch (layout) { \ + case mshadow::kNCHW: \ + { \ + const int Layout = kNCHW; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kNHWC: \ + { \ + const int Layout = kNHWC; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kNCDHW: \ + { \ + const int Layout = kNCDHW; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kNDHWC: \ + { \ + const int Layout = kNDHWC; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown layout enum " << layout; \ + } + +/*! + * \brief Only supports int64 index type for aux_data + * in NDArray class fow now. + */ +#define MSHADOW_IDX_TYPE_SWITCH(type, DType, ...) \ + switch (type) { \ + case mshadow::kInt64: \ + { \ + typedef int64_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + +/*! \brief get data type size from type enum */ +inline size_t mshadow_sizeof(int type) { + int size = 0; + MSHADOW_TYPE_SWITCH(type, DType, size = sizeof(DType);); + return size; +} + +} // namespace mshadow +#endif // MSHADOW_BASE_H_ diff --git a/include/mshadow/cuda/reduce.cuh b/include/mshadow/cuda/reduce.cuh new file mode 100644 index 000000000000..921d5ad5e0c0 --- /dev/null +++ b/include/mshadow/cuda/reduce.cuh @@ -0,0 +1,120 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file reduce.cuh + * \brief helper functions to do reduction + * \author Tianqi Chen + */ +#ifndef MSHADOW_CUDA_REDUCE_CUH_ +#define MSHADOW_CUDA_REDUCE_CUH_ + +namespace mshadow { +namespace cuda { +/* + * \brief reduce over the dimension x + * \tparam Reducer reducer + * \tparam x_bits dimension = 1< +inline __device__ void Reduce1D(volatile DType buf[1 << x_bits]); +/* + * \brief reduce over the dimension x + * \tparam Reducer reducer + * \tparam xmax_bits maximum size of buffer + * \tparam DType content data type + * \param xsize size of x dimension, not sure if aligned + */ +template +inline __device__ void +Reduce1DNotAlign(volatile DType buf[1 << xmax_bits], int xsize); +// ===============================================x=== +// implementations afterwards, +// no need to read if only use the functions +// -------------------------------------------------- +#ifdef __DEVICE_EMULATION__ +#define __syncwarp() __syncthreads() +#else +#if CUDA_VERSION < 9000 +#define __syncwarp() +#endif +#endif + +template +inline __device__ void ReduceX(volatile DType buf[], int tid) { + if (x_bits >= 10) { + if (tid < 512) Reducer::Reduce(buf[tid] , buf[tid + 512]); + __syncthreads(); + } + if (x_bits >= 9) { + if (tid < 256) Reducer::Reduce(buf[tid] , buf[tid + 256]); + __syncthreads(); + } + if (x_bits >= 8) { + if (tid < 128) Reducer::Reduce(buf[tid] , buf[tid + 128]); + __syncthreads(); + } + if (x_bits >= 7) { + if (tid < 64) Reducer::Reduce(buf[tid] , buf[tid + 64]); + __syncthreads(); + } + if (x_bits >= 6) { + if (tid < 32) Reducer::Reduce(buf[tid] , buf[tid + 32]); + __syncthreads(); + } + // in warp optimization + if (x_bits >= 5) { + if (tid < 16) Reducer::Reduce(buf[tid] , buf[tid + 16]); +#if MSHADOW_OLD_CUDA + __syncthreads(); +#else + __syncwarp(); +#endif + } + if (x_bits >= 4) { + if (tid < 8) Reducer::Reduce(buf[tid] , buf[tid + 8]); + __syncwarp(); + } + if (x_bits >= 3) { + if (tid < 4) Reducer::Reduce(buf[tid] , buf[tid + 4]); + __syncwarp(); + } + if (x_bits >= 2) { + if (tid < 2) Reducer::Reduce(buf[tid] , buf[tid + 2]); + __syncwarp(); + } + if (x_bits >= 1) { + if (tid < 1) Reducer::Reduce(buf[tid] , buf[tid + 1]); + __syncwarp(); + } +} +template +inline __device__ void Reduce1D(volatile DType buf[1 << x_bits]) { + ReduceX(buf, threadIdx.x); +} +// reduce with a upper bound +#define __RD_NON_ALIGN(els, x_bits) \ + els \ + if (xmax_bits >= x_bits && x_size >= (1 << x_bits)) { \ + if (tid < (1 << x_bits) && tid + (1 << x_bits) < x_size) { \ + Reducer::Reduce(buf[tid] , buf[tid + (1 << x_bits)]); \ + } \ + __syncthreads(); \ + ReduceX(buf, tid); \ + } \ + +template +inline __device__ void Reduce1DNotAlign(volatile DType buf[], int x_size) { + int tid = threadIdx.x; + __RD_NON_ALIGN(, 8) + __RD_NON_ALIGN(else, 7) + __RD_NON_ALIGN(else, 6) + __RD_NON_ALIGN(else, 5) + __RD_NON_ALIGN(else, 4) + __RD_NON_ALIGN(else, 3) + __RD_NON_ALIGN(else, 2) + __RD_NON_ALIGN(else, 1) +} +} // namespace cuda +} // namespace mshadow +#endif // MSHADOW_CUDA_REDUCE_CUH_ + diff --git a/include/mshadow/cuda/tensor_gpu-inl.cuh b/include/mshadow/cuda/tensor_gpu-inl.cuh new file mode 100755 index 000000000000..72e4b7eb9ee9 --- /dev/null +++ b/include/mshadow/cuda/tensor_gpu-inl.cuh @@ -0,0 +1,828 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file tensor_gpu-inl.cuh + * \brief implementation of GPU code using CUDA + * \author Bing Xu, Tianqi Chen + */ +#ifndef MSHADOW_CUDA_TENSOR_GPU_INL_CUH_ +#define MSHADOW_CUDA_TENSOR_GPU_INL_CUH_ +#include +#include +#if CUDA_VERSION >= 7000 +#include +#endif +#include "../tensor.h" +#include "./reduce.cuh" +#define MSHADOW_CUDA_POST_KERNEL_CHECK(x) \ + /* Code block avoids redefinition of cudaError_t err */ \ + do { \ + cudaError err = cudaPeekAtLastError(); \ + CHECK_EQ(err, cudaSuccess) << "Name: " << #x << " ErrStr:" << cudaGetErrorString(err); \ + } while (0) +namespace mshadow { +namespace cuda { +/* load unit for memory access, if CUDAARCH not defined, this is advanced nvcc */ +#if MSHADOW_OLD_CUDA +const int kMemUnitBits = 4; +const int kMaxThreadsPerBlock = 512; +#else +const int kMemUnitBits = 5; +const int kMaxThreadsPerBlock = 1024; +#endif +/*! \brief number of units that can do synchronized update, half warp size */ +const int kMemUnit = 1 << kMemUnitBits; +/*! \brief mask that could be helpful sometime */ +const int kMemUnitMask = kMemUnit - 1; +/*! \brief suggested thread number(logscale) for mapping kernel */ +const int kBaseThreadBits = 8; +/*! \brief suggested thread number for mapping kernel */ +const int kBaseThreadNum = 1 << kBaseThreadBits; +/*! \brief maximum value of grid */ +const int kMaxGridNum = 65535; +/*! \brief maximum value of grid within each dimension */ +const int kMaxGridDim = 65535; +/*! \brief suggested grid number for mapping kernel */ +const int kBaseGridNum = 1024; +/*! \brief get align stride for given size in x dimension */ +inline index_t GetAlignStride(index_t xsize) { + if (xsize >= MSHADOW_MIN_PAD_RATIO * 32) { + return ((xsize + kMemUnit - 1) >> kMemUnitBits) << kMemUnitBits; + } else { + // if originally space is not aligned, no necessary to to alligned thread allocation + return xsize; + } +} +inline void CheckLaunchParam(dim3 dimGrid, dim3 dimBlock, const char *estr = "") { + if (dimBlock.x * dimBlock.y * dimBlock.z > static_cast(kMaxThreadsPerBlock) || + dimGrid.x > kMaxGridDim || dimGrid.y > kMaxGridDim) { + LOG(FATAL) << "too large launch parameter: " + << estr << "[" + << dimGrid.x << "," + << dimGrid.y << "], [" + << dimBlock.x << "," + << dimBlock.y << "," + << dimBlock.z << "]"; + } +} +template +__device__ void MapPlanProc(DstPlan dst, index_t xstride, + Shape<2> dshape, const Plan plan, int block_idx) { + const index_t tid = (block_idx << block_dim_bits) + threadIdx.x; + const int y = tid / xstride; + const int x = tid % xstride; + if (y < dshape[0] && x < dshape[1]) { + Saver::Save(dst.REval(y, x), plan.Eval(y, x)); + } +} +template +__global__ void MapPlanKernel(DstPlan dst, index_t xstride, + Shape<2> dshape, const Plan plan) { + MapPlanProc + (dst, xstride, dshape, plan, blockIdx.x); +} +template +__global__ void MapPlanLargeKernel(DstPlan dst, index_t xstride, + Shape<2> dshape, const Plan plan, int repeat) { + for (int i = 0; i < repeat; ++i) { + MapPlanProc + (dst, xstride, dshape, plan, blockIdx.x + i * grid_size); + } +} + +template +inline void MapPlan(expr::Plan dst, + const expr::Plan &plan, + Shape<2> dshape, + cudaStream_t stream) { + const index_t xstride = GetAlignStride(dshape[1]); + const int num_block = (dshape[0] * xstride + kBaseThreadNum-1) / kBaseThreadNum; + dim3 dimBlock(kBaseThreadNum, 1, 1); + + if (num_block < kMaxGridNum) { + dim3 dimGrid(num_block, 1, 1); + MapPlanKernel, + expr::Plan > + <<>>(dst, xstride, dshape, plan); + MSHADOW_CUDA_POST_KERNEL_CHECK(MapPlanKernel); + } else { + int repeat = (num_block + kBaseGridNum-1) / kBaseGridNum; + dim3 dimGrid(kBaseGridNum, 1 , 1); + MapPlanLargeKernel, + expr::Plan > + <<>>(dst, xstride, dshape, plan, repeat); + MSHADOW_CUDA_POST_KERNEL_CHECK(MapPlanLargeKernel); + } +} + +template +__global__ void +__launch_bounds__(kMemUnit*kMemUnit, 1) +MapRedKeepLowestKernel(DstPlan dst, Plan plan, + DType scale, Shape<2> eshape) { + const unsigned warp_size = 1 << warp_bits; + const unsigned x = (blockIdx.x << warp_bits) + threadIdx.x; + // to avoid bank conflict + __shared__ DType s_res[warp_size][warp_size + 1]; + // note: reverse store [y][x], so that we can reduce over threadIdx.x, use warp optimization + if (threadIdx.y < eshape[0] && x < eshape[1]) { + s_res[threadIdx.x][threadIdx.y] = plan.Eval(threadIdx.y, x); + } + for (unsigned y = warp_size; y < eshape[0]; y += warp_size) { + if (threadIdx.y + y < eshape[0] && x < eshape[1]) { + Reducer::Reduce(s_res[threadIdx.x][threadIdx.y], plan.Eval(threadIdx.y + y, x)); + } + } + __syncthreads(); + if (eshape[0] >= warp_size) { + Reduce1D(s_res[threadIdx.y]); + } else { + Reduce1DNotAlign(s_res[threadIdx.y], eshape[0]); + } + __syncthreads(); + + if (threadIdx.y == 0 && x < eshape[1]) { + Saver::Save(dst.REval(0, x), DType(s_res[threadIdx.x][0] * scale)); + } +} + +template +inline void MapReduceKeepLowest(expr::Plan dst, + const expr::Plan &plan, + DType scale, Shape<2> eshape, + cudaStream_t stream) { + dim3 dimBlock(kMemUnit, kMemUnit); + dim3 dimGrid((eshape[1] + kMemUnit - 1) >> kMemUnitBits); + CheckLaunchParam(dimGrid, dimBlock, "MapRedKeepLowestKernel"); + MapRedKeepLowestKernel, + expr::Plan > + <<>>(dst, plan, scale, eshape); + MSHADOW_CUDA_POST_KERNEL_CHECK(MapRedKeepLowestKernel); +} + +template +__global__ void MapReduceKeepDim1Kernel(DstPlan dst, Plan plan, DType scale, Shape<4> pshape) { + const int block_size = 1 << block_dim_bits; + __shared__ DType s_rec[block_size]; + const int c = blockIdx.x + blockIdx.y * gridDim.x; + const index_t tot = pshape[3] * pshape[2] * pshape[0]; + + if (c < pshape[1]) { + DType res; Reducer::SetInitValue(res); + for (index_t i_offset = 0; i_offset < tot; i_offset += block_size) { + index_t i = i_offset + threadIdx.x; + if (i< tot) { + const index_t x = i % pshape[3]; + i /= pshape[3]; + const index_t y = i % pshape[2]; + const index_t n = i / pshape[2]; + Reducer::Reduce(res, plan.Eval((n * pshape[1] + c) * pshape[2] + y, x)); + } + } + s_rec[threadIdx.x] = res; + __syncthreads(); + Reduce1D(s_rec); + if (threadIdx.x == 0) { + Saver::Save(dst.REval(0, c), DType(s_rec[0] * scale)); + } + } +} + +template +inline void MapReduceKeepDim1(expr::Plan dst, + const expr::Plan &plan, + DType scale, Shape<4> pshape, + cudaStream_t stream) { + dim3 dimBlock(kBaseThreadNum); + const int grid_dim_x = (pshape[1] > kMaxGridNum) ? kMaxGridNum : pshape[1]; + const int grid_dim_y = (pshape[1] > kMaxGridNum) ? (pshape[1] + kMaxGridNum - 1) / kMaxGridNum + : 1; + dim3 dimGrid(grid_dim_x, grid_dim_y); + CheckLaunchParam(dimGrid, dimBlock, "MapReduceKeepDim1"); + MapReduceKeepDim1Kernel, + expr::Plan > + <<>>(dst, plan, scale, pshape); + MSHADOW_CUDA_POST_KERNEL_CHECK(MapReduceKeepDim1Kernel); +} + +template +__global__ void GetBatchedViewKernel(DType **dst, DType *src, int num, int stride) { + const int x_size = 1 << x_bits; + const int start = threadIdx.x; + // Copy the addresses of src to dst every stride steps + for (int i = start; i < num; i += x_size) { + dst[i] = src + i * stride; + } +} + +template +inline void GetBatchedView(DType **dst, DType *src, int num, int stride, + Stream *stream) { + cudaStream_t stream_ = Stream::GetStream(stream); + dim3 dimBlock(kBaseThreadNum); + dim3 dimGrid(1); + CheckLaunchParam(dimGrid, dimBlock, "GetBatchedView"); + GetBatchedViewKernel + <<>> (dst, src, num, stride); + MSHADOW_CUDA_POST_KERNEL_CHECK(GetBatchedViewKernel); +} + +template +__global__ void SoftmaxGradKernel(DstPlan dst, SrcPlan1 src, SrcPlan2 label, index_t xmax) { + const unsigned x_size = 1 << x_bits; + const int y = blockIdx.x; + const int k = static_cast(label.Eval(0, y)); + + // calculate normalizer, with writeback + for (unsigned x = 0; x < xmax; x += x_size) { + const unsigned xindex = x + threadIdx.x; + if (xindex < xmax) { + if (xindex == k) { + dst.REval(y, xindex) = src.Eval(y, xindex) - 1.0f; + } else { + dst.REval(y, xindex) = src.Eval(y, xindex); + } + } + } +} + +template +__global__ void SmoothSoftmaxGradKernel(DstPlan dst, SrcPlan1 src, SrcPlan2 label, index_t xmax, + float alpha) { + const unsigned x_size = 1 << x_bits; + const int y = blockIdx.x; + const int k = static_cast(label.Eval(0, y)); + // xmax is the number of classes in our distribution + const float smooth_grad = (alpha / (xmax - 1)); + + // calculate normalizer, with writeback + for (unsigned x = 0; x < xmax; x += x_size) { + const unsigned xindex = x + threadIdx.x; + if (xindex < xmax) { + if (xindex == k) { + dst.REval(y, xindex) = src.Eval(y, xindex) - 1.0f + alpha; + } else { + dst.REval(y, xindex) = src.Eval(y, xindex) - smooth_grad; + } + } + } +} + +template +__global__ void SoftmaxGradKernel(DstPlan dst, SrcPlan1 src, SrcPlan2 label, index_t xmax, + DType ignore_label) { + const unsigned x_size = 1 << x_bits; + const int y = blockIdx.x; + const int k = static_cast(label.Eval(0, y)); + + // calculate normalizer, with writeback + for (unsigned x = 0; x < xmax; x += x_size) { + const unsigned xindex = x + threadIdx.x; + if (xindex < xmax) { + if (static_cast(ignore_label) == k) { + dst.REval(y, xindex) = 0.0f; + } else { + if (xindex == k) { + dst.REval(y, xindex) = src.Eval(y, xindex) - 1.0f; + } else { + dst.REval(y, xindex) = src.Eval(y, xindex); + } + } + } + } +} + +template +__global__ void SmoothSoftmaxGradKernel(DstPlan dst, SrcPlan1 src, SrcPlan2 label, index_t xmax, + DType ignore_label, float alpha) { + const unsigned x_size = 1 << x_bits; + const int y = blockIdx.x; + const int k = static_cast(label.Eval(0, y)); + // xmax is the number of classes in our distribution + const float smooth_grad = (alpha / (xmax - 1)); + + // calculate normalizer, with writeback + for (unsigned x = 0; x < xmax; x += x_size) { + const unsigned xindex = x + threadIdx.x; + if (xindex < xmax) { + if (static_cast(ignore_label) == k) { + dst.REval(y, xindex) = 0.0f; + } else { + if (xindex == k) { + dst.REval(y, xindex) = src.Eval(y, xindex) - 1.0f + alpha; + } else { + dst.REval(y, xindex) = src.Eval(y, xindex) - smooth_grad; + } + } + } + } +} + +template +__global__ void SoftmaxKernel(DstPlan dst, SrcPlan src, index_t xmax) { + const unsigned x_size = 1 << x_bits; + const int y = blockIdx.x; + __shared__ DType s_rec[x_size]; + // step 1: get max + if (threadIdx.x < xmax) { + s_rec[threadIdx.x] = src.Eval(y, threadIdx.x); + } + for (unsigned x = x_size; x < xmax; x += x_size) { + if (x + threadIdx.x < xmax) { + DType a = src.Eval(y, x + threadIdx.x); + s_rec[threadIdx.x] = max(a, s_rec[threadIdx.x]); + } + } + __syncthreads(); + if (threadIdx.x >= xmax) { + s_rec[threadIdx.x] = s_rec[0]; + } + __syncthreads(); + Reduce1D(s_rec); + __syncthreads(); + DType smax = s_rec[0]; + __syncthreads(); + s_rec[threadIdx.x] = 0.0f; + __syncthreads(); + + // calculate normalizer, with writeback + for (unsigned x = 0; x < xmax; x += x_size) { + if (x + threadIdx.x < xmax) { + DType p = expf(src.Eval(y, x + threadIdx.x) - smax); + s_rec[threadIdx.x] += p; + // write back first, will fetch later + dst.REval(y, x + threadIdx.x) = p; + } + } + // calculate normalizer + __syncthreads(); + Reduce1D(s_rec); + __syncthreads(); + DType ssum = s_rec[0]; + + for (unsigned x = 0; x < xmax; x += x_size) { + if (x + threadIdx.x < xmax) { + dst.REval(y, x + threadIdx.x) /= ssum; + } + } +} + +template +inline void Softmax(const Tensor &dst, + const Tensor &src) { + dim3 dimBlock(kBaseThreadNum); + dim3 dimGrid(dst.size(0)); + CHECK_EQ(dst.shape_, src.shape_) << "Softmax: shape mismatch"; + CheckLaunchParam(dimGrid, dimBlock, "Softmax"); + cudaStream_t stream = Stream::GetStream(dst.stream_); + SoftmaxKernel + <<>> + (expr::MakePlan(dst), + expr::MakePlan(src), + dst.size(1)); + MSHADOW_CUDA_POST_KERNEL_CHECK(SoftmaxKernel); +} + +template +inline void SoftmaxGrad(const Tensor &dst, + const Tensor &src, + const Tensor &label) { + dim3 dimBlock(kBaseThreadNum); + dim3 dimGrid(dst.size(0)); + CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch"; + CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch"; + CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad"); + cudaStream_t stream = Stream::GetStream(dst.stream_); + SoftmaxGradKernel + <<>> + (expr::MakePlan(dst), + expr::MakePlan(src), + expr::MakePlan(label), + dst.size(1)); + MSHADOW_CUDA_POST_KERNEL_CHECK(SoftmaxGradKernel); +} + +template +inline void SmoothSoftmaxGrad(const Tensor &dst, + const Tensor &src, + const Tensor &label, + const float alpha) { + dim3 dimBlock(kBaseThreadNum); + dim3 dimGrid(dst.size(0)); + CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch"; + CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch"; + CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad"); + cudaStream_t stream = Stream::GetStream(dst.stream_); + SmoothSoftmaxGradKernel + <<>> + (expr::MakePlan(dst), + expr::MakePlan(src), + expr::MakePlan(label), + dst.size(1), + alpha); + MSHADOW_CUDA_POST_KERNEL_CHECK(SoftmaxGradKernel); +} + +template +inline void SoftmaxGrad(const Tensor &dst, + const Tensor &src, + const Tensor &label, + const DType &ignore_label) { + dim3 dimBlock(kBaseThreadNum); + dim3 dimGrid(dst.size(0)); + CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch"; + CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch"; + CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad"); + cudaStream_t stream = Stream::GetStream(dst.stream_); + SoftmaxGradKernel + <<>> + (expr::MakePlan(dst), + expr::MakePlan(src), + expr::MakePlan(label), + dst.size(1), + ignore_label); + MSHADOW_CUDA_POST_KERNEL_CHECK(SoftmaxGradKernel); +} + +template +inline void SmoothSoftmaxGrad(const Tensor &dst, + const Tensor &src, + const Tensor &label, + const DType &ignore_label, + const float alpha) { + dim3 dimBlock(kBaseThreadNum); + dim3 dimGrid(dst.size(0)); + CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch"; + CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch"; + CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad"); + cudaStream_t stream = Stream::GetStream(dst.stream_); + SmoothSoftmaxGradKernel + <<>> + (expr::MakePlan(dst), + expr::MakePlan(src), + expr::MakePlan(label), + dst.size(1), + ignore_label, + alpha); + MSHADOW_CUDA_POST_KERNEL_CHECK(SoftmaxGradKernel); +} + +template +__global__ void Softmax3DGradKernel(Tensor dst, + const Tensor src, + const Tensor label) { + const index_t xmax = dst.size(1); + const index_t nmax = dst.size(2); + const unsigned n_size = 1 << n_bits; + const int y = blockIdx.x; + const int n = threadIdx.x; + + for (index_t n_index = n; n_index < nmax; n_index += n_size) { + const int k = static_cast(label[y][n_index]); + for (index_t i = 0; i < xmax; ++i) { + if (i == k) { + dst[y][i][n_index] = src[y][i][n_index] - 1.0f; + } else { + dst[y][i][n_index] = src[y][i][n_index]; + } + } + } +} + +template +__global__ void Softmax3DGradKernel(Tensor dst, + const Tensor src, + const Tensor label, + DType ignore_label) { + const index_t xmax = dst.size(1); + const index_t nmax = dst.size(2); + const unsigned n_size = 1 << n_bits; + const int y = blockIdx.x; + const int n = threadIdx.x; + for (index_t n_index = n; n_index < nmax; n_index += n_size) { + int k = static_cast(label[y][n_index]); + if (k == static_cast(ignore_label)) { + for (index_t i = 0; i < xmax; ++i) { + dst[y][i][n_index] = 0.0f; + } + } else { + for (index_t i = 0; i < xmax; ++i) { + if (i == k) { + dst[y][i][n_index] = src[y][i][n_index] - 1.0f; + } else { + dst[y][i][n_index] = src[y][i][n_index]; + } + } + } + } +} + +template +__global__ void Softmax3DKernel(Tensor dst, + const Tensor src) { + const index_t xmax = dst.size(1); + const index_t nmax = dst.size(2); + const unsigned n_size = 1 << n_bits; + const int y = blockIdx.x; + const int n = threadIdx.x; + + for (index_t n_index = n; n_index < nmax; n_index += n_size) { + DType smax = src[y][0][n_index]; + for (index_t i = 1; i < xmax; ++i) { + smax = max(smax, src[y][i][n_index]); // NOLINT(*) + } + DType ssum = 0.0f; + for (index_t i = 0; i < xmax; ++i) { + DType p = expf(src[y][i][n_index] - smax); + ssum += p; + dst[y][i][n_index] = p; + } + for (index_t i = 0; i < xmax; ++i) { + dst[y][i][n_index] /= ssum; + } + } +} + +template +inline void Softmax(const Tensor &dst, + const Tensor &src) { + dim3 dimBlock(kBaseThreadNum); + dim3 dimGrid(dst.size(0)); + CHECK_EQ(dst.shape_, src.shape_) << "Softmax: shape mismatch"; + CheckLaunchParam(dimGrid, dimBlock, "Softmax"); + cudaStream_t stream = Stream::GetStream(dst.stream_); + Softmax3DKernel<<>>(dst, src); + MSHADOW_CUDA_POST_KERNEL_CHECK(Softmax3DKernel); +} + +template +inline void SoftmaxGrad(const Tensor &dst, + const Tensor &src, + const Tensor &label) { + dim3 dimBlock(kBaseThreadNum); + dim3 dimGrid(dst.size(0)); + CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch"; + CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch"; + CHECK_EQ(dst.size(2), label.size(1)) << "SoftmaxGrad: label shape mismatch"; + CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad"); + cudaStream_t stream = Stream::GetStream(dst.stream_); + Softmax3DGradKernel<<>>(dst, src, label); + MSHADOW_CUDA_POST_KERNEL_CHECK(Softmax3DGradKernel); +} + +template +inline void SoftmaxGrad(const Tensor &dst, + const Tensor &src, + const Tensor &label, + const DType &ignore_label) { + dim3 dimBlock(kBaseThreadNum); + dim3 dimGrid(dst.size(0)); + CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch"; + CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch"; + CHECK_EQ(dst.size(2), label.size(1)) << "SoftmaxGrad: label shape mismatch"; + CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad"); + cudaStream_t stream = Stream::GetStream(dst.stream_); + Softmax3DGradKernel<<>>( + dst, src, label, ignore_label); + MSHADOW_CUDA_POST_KERNEL_CHECK(Softmax3DGradKernel); +} + +template +__global__ void AddTakeGradKernel(DstPlan dst, + SrcPlan1 index, SrcPlan2 src, + index_t ymax, index_t xmax, const int K) { + const unsigned x_size = 1 << x_bits; + const int xindex = blockIdx.x * x_size + threadIdx.x; + __shared__ int ptr; + for (unsigned y = 0; y < ymax; ++y) { + if (threadIdx.x == 0) { + ptr = index.Eval(0, y); + if (ptr <= 0) ptr = 0; + else if (ptr >= K) ptr = K - 1; + } + __syncthreads(); + if (xindex < xmax) { + dst.REval(ptr, xindex) += src.Eval(y, xindex); + } + } +} + +template +__global__ void AddTakeGradLargeBatchKernel(DType* dst, + const IdxType *sorted, const IdxType *index, + const DType *src, + int ymax, int xmax) { + // Based on Torch's Version /~https://github.com/torch/cunn/blob/master/lib/THCUNN/LookupTable.cu + // Each warp is responsible for an input into the LookupTable. + // If the preceeding input has the same as this input, then the warp + // exits immediately. The warp also processes subsequent inputs with the + // same value. + // + // Input Warp + // 1 + // 1 ( exits without doing any work) + // 5 + // 8 + // Also, all warp will loop for SZ times to increase the throughput. + + const int warp_size = 1 << warp_bits; + int idx = blockIdx.x * blockDim.y + threadIdx.y; + + if (idx < ymax + && (idx == 0 || sorted[idx] != sorted[idx - 1])) { + do { + const int start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ; + const int dst_row = static_cast(sorted[idx]) * xmax; + const int src_row = static_cast(index[idx]) * xmax; + float grad_out[SZ]; + float grad_weight[SZ]; + #pragma unroll + for (int ii = 0; ii < SZ; ii++) { + int feature_dim = start_feature + ii * warp_size; + if (feature_dim < xmax) { + grad_out[ii] = src[src_row + feature_dim]; + grad_weight[ii] = dst[dst_row + feature_dim]; + } + } + + #pragma unroll + for (int ii = 0; ii < SZ; ii++) { + grad_weight[ii] += grad_out[ii]; + } + + #pragma unroll + for (int ii = 0; ii < SZ; ii++) { + int feature_dim = start_feature + ii * warp_size; + if (feature_dim < xmax) { + dst[dst_row + feature_dim] = grad_weight[ii]; + } + } + idx++; + } while (idx < ymax && (sorted[idx] == sorted[idx - 1])); + } +} + +template +inline void AddTakeGrad(Tensor dst, + const Tensor& index, + const Tensor &src) { + CHECK_EQ(dst.CheckContiguous(), true); + CHECK_EQ(index.CheckContiguous(), true); + CHECK_EQ(src.CheckContiguous(), true); + const int kUnitBits = kMemUnitBits + 1; + dim3 dimBlock(1 << kUnitBits); + dim3 dimGrid((dst.size(1) + (1 << kUnitBits) - 1) >> kUnitBits); + + CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGrad: shape mismatch"; + CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGrad: shape mismatch"; + CheckLaunchParam(dimGrid, dimBlock, "AddTakeGrad"); + cudaStream_t stream = Stream::GetStream(dst.stream_); + const int K = dst.shape_[0]; + + AddTakeGradKernel + <<>> + (expr::MakePlan(dst), + expr::MakePlan(index), + expr::MakePlan(src), + src.size(0), + src.size(1), K); + MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel); +} + +template +inline void AddTakeGradLargeBatch(Tensor dst, + const Tensor& sorted, + const Tensor& index, + const Tensor &src) { + CHECK_EQ(dst.CheckContiguous(), true); + CHECK_EQ(sorted.CheckContiguous(), true); + CHECK_EQ(index.CheckContiguous(), true); + CHECK_EQ(src.CheckContiguous(), true); + const int kWarpBits = kMemUnitBits; + const int SZ = 4; + const int block_dim_x = 1 << kWarpBits; + const int block_dim_y = 4; + const int grid_dim_x = (src.size(0) + block_dim_y - 1) / block_dim_y; + const int grid_dim_y = (src.size(1) + block_dim_x * SZ - 1) / (block_dim_x * SZ); + dim3 dimBlock(block_dim_x, block_dim_y); + dim3 dimGrid(grid_dim_x, grid_dim_y); + + CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGradLargeBatch: shape mismatch"; + CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGradLargeBatch: shape mismatch"; + CheckLaunchParam(dimGrid, dimBlock, "AddTakeGradLargeBatch"); + cudaStream_t stream = Stream::GetStream(dst.stream_); + + AddTakeGradLargeBatchKernel + <<>> + (dst.dptr_, + sorted.dptr_, + index.dptr_, + src.dptr_, + static_cast(src.size(0)), + static_cast(src.size(1))); + MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradLargeBatchKernel); +} + +template +__global__ void IndexFillKernel(DstPlan dst, + const IndexPlan index, + const SrcPlan src, + const int ymax, + const int xmax) { + int bid = blockIdx.y * blockDim.x + blockIdx.x; + int tid = bid * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x; + if (tid < ymax * xmax) { + int i = tid / xmax; + int j = tid % xmax; + int k = static_cast(index.Eval(0, i)); + dst.REval(k, j) = src.Eval(i, j); + } +} + +template +inline void IndexFill(Tensor dst, + const Tensor& index, + const Tensor &src) { + CHECK_EQ(dst.CheckContiguous(), true); + CHECK_EQ(index.CheckContiguous(), true); + CHECK_EQ(src.CheckContiguous(), true); + CHECK_EQ(dst.size(1), src.size(1)) << "IndexFill: shape mismatch"; + CHECK_EQ(index.size(0), src.size(0)) << "IndexFill: shape mismatch"; + const int block_dim_x = 1 << kMemUnitBits; + const int block_dim_y = 1 << kMemUnitBits; + const int block_size = block_dim_x * block_dim_y; + int grid_dim_x = (src.size(0) * src.size(1) + block_size - 1) / block_size; + int grid_dim_y = 1; + while (grid_dim_x > kMaxGridDim) { + grid_dim_x = (grid_dim_x + 1) / 2; + grid_dim_y *= 2; + } + dim3 dimBlock(block_dim_x, block_dim_y); + dim3 dimGrid(grid_dim_x, grid_dim_y); + CheckLaunchParam(dimGrid, dimBlock, "IndexFill"); + cudaStream_t stream = Stream::GetStream(dst.stream_); + + IndexFillKernel + <<>> + (expr::MakePlan(dst), + expr::MakePlan(index), + expr::MakePlan(src), + src.size(0), + src.size(1)); + MSHADOW_CUDA_POST_KERNEL_CHECK(IndexFillKernel); +} + +template +inline void SortByKey(Tensor keys, Tensor values, + bool is_ascend) { + CHECK_EQ(keys.CheckContiguous(), true); + CHECK_EQ(values.CheckContiguous(), true); +#if CUDA_VERSION >= 7000 + cudaStream_t stream = Stream::GetStream(keys.stream_); + thrust::device_ptr key_iter = thrust::device_pointer_cast(keys.dptr_); + thrust::device_ptr value_iter = thrust::device_pointer_cast(values.dptr_); + if (is_ascend) { + thrust::stable_sort_by_key( + thrust::cuda::par.on(stream), + key_iter, key_iter + keys.size(0), value_iter, thrust::less()); // NOLINT(*) + } else { + thrust::stable_sort_by_key( + thrust::cuda::par.on(stream), + key_iter, key_iter + keys.size(0), value_iter, thrust::greater()); // NOLINT(*) + } + MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey); +#else + LOG(FATAL) << "SortByKey is only supported for CUDA version >=7.0!"; +#endif +} + +template +inline void SortByKey(Tensor keys, Tensor values, + bool is_ascend) { + LOG(FATAL) << "SortByKey for half_t is not implemented!"; +} + +template +inline void SortByKey(Tensor keys, Tensor values, + bool is_ascend) { + LOG(FATAL) << "SortByKey for half_t is not implemented!"; +} + +// break ambiguous template deduction for +inline void SortByKey(Tensor keys, + Tensor values, + bool is_ascend) { + LOG(FATAL) << "SortByKey for half_t is not implemented!"; +} +} // namespace cuda +} // namespace mshadow +#endif // MSHADOW_CUDA_TENSOR_GPU_INL_CUH_ diff --git a/include/mshadow/dot_engine-inl.h b/include/mshadow/dot_engine-inl.h new file mode 100644 index 000000000000..5363974fc941 --- /dev/null +++ b/include/mshadow/dot_engine-inl.h @@ -0,0 +1,906 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file dot_engine-inl.h + * \brief definitions of how Matrix Multiplications can be evaluated + * \author Tianqi Chen + */ +#ifndef MSHADOW_DOT_ENGINE_INL_H_ +#define MSHADOW_DOT_ENGINE_INL_H_ + +#include +#include "./base.h" +#include "./extension/implicit_gemm.h" + +#ifdef __CUDACC__ +#include "./cuda/tensor_gpu-inl.cuh" +#endif // #ifdef __CUDACC__ + +namespace mshadow { + /*! +* \brief CPU/GPU: Get a batched view of the src array. dst[i] = src + i * stride +* \param dst 2D pointer +* \param src 1D pointer +* \param num number of batches +* \param stride size of each batch +* \param stream +*/ +template +inline void GetBatchedView(DType **dst, DType *src, int num, int stride, + Stream *stream); +template +inline void GetBatchedView(DType **dst, DType *src, int num, int stride, + Stream *stream) { + for (int i = 0; i < num; i++) { + dst[i] = src + i * stride; + } +} +#ifdef __CUDACC__ +namespace cuda {}; +template +inline void GetBatchedView(DType **dst, DType *src, int num, int stride, + Stream *stream) { + cuda::GetBatchedView(dst, src, num, stride, stream); +} +#endif // #ifdef __CUDACC__ + +namespace expr { +//--------------------------------------------------------------------- +// Matrix Multiplications, depends on BLAS Engine +//--------------------------------------------------------------------- +template +struct DotEngine { + inline static void Eval(Tensor *p_dst, + const Tensor &lhs, + const Tensor &rhs, + DType scale); +}; +// handles the dot, use CblasColMajor +template +struct BLASEngine { + inline static bool GetT(bool t) { + return t ? true : false; + } + inline static void SetStream(Stream *stream) { + } + inline static void gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, DType alpha, + const DType *A, int lda, const DType *B, int ldb, + DType beta, DType *C, int ldc) { + LOG(FATAL) << "Not implmented!"; + } + inline static void batched_gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, DType alpha, + const DType *A, int lda, const DType *B, int ldb, + DType beta, DType *C, int ldc, int batch_count, + DType **workspace) { + LOG(FATAL) << "Not implmented!"; + } + inline static void gemv(Stream *stream, + bool trans, int m, int n, + DType alpha, const DType *A, int lda, + const DType *X, int incX, + DType beta, DType *Y, int incY) { + LOG(FATAL) << "Not implmented!"; + } + inline static void batched_gemv(Stream *stream, + bool trans, int m, int n, + DType alpha, const DType *A, int lda, + const DType *X, int incX, + DType beta, DType *Y, int incY, int batch_count) { + LOG(FATAL) << "Not implmented!"; + } + inline static void ger(Stream *stream, + int m, int n, DType alpha, + const DType *X, int incX, + const DType *Y, int incY, DType *A, int lda) { + LOG(FATAL) << "Not implmented!"; + } + inline static void batched_ger(Stream *stream, + int m, int n, DType alpha, + const DType *X, int incX, + const DType *Y, int incY, DType *A, int lda, int batch_count) { + LOG(FATAL) << "Not implmented!"; + } + inline static void dot(Stream *stream, + int n, + const DType* X, int incX, + const DType* Y, int incY, + DType* ret) { + LOG(FATAL) << "Not implmented!"; + } +}; + +#if MSHADOW_STAND_ALONE +template<> +struct BLASEngine { + inline static bool GetT(bool t) { + return t ? true : false; + } + inline static void SetStream(Stream *stream) { + } + inline static void gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc) { + if (alpha == 1.0f && beta == 0.0f) { + bool transpose_left = transb; + bool transpose_right = transa; + Tensor lhs((float*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*) + Tensor rhs((float*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*) + Tensor dst(C, Shape2(m, n)); + if (!transpose_left && !transpose_right) { + dst = expr::implicit_dot(lhs, rhs); return; + } else if (!transpose_left && transpose_right) { + dst = expr::implicit_dot(lhs, rhs.T()); return; + } else if (transpose_left && !transpose_right) { + dst = expr::implicit_dot(lhs.T(), rhs); return; + } else { + LOG(FATAL) << "Not implmented!"; + } + } else { + LOG(FATAL) << "Not implmented!"; + } + } + inline static void batched_gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc, int batch_count, + float **workspace) { + for (int i = 0; i < batch_count; ++i) { + gemm(stream, transa, transb, m, n, k, alpha, + A + i * m * k, lda, B + i * k * n, ldb, + beta, C + i * m * n, ldc); + } + } + inline static void gemv(Stream *stream, + bool trans, int m, int n, + float alpha, const float *A, int lda, + const float *X, int incX, + float beta, float *Y, int incY) { + LOG(FATAL) << "Not implmented!"; + } + inline static void batched_gemv(Stream *stream, + bool trans, int m, int n, + float alpha, const float *A, int lda, + const float *X, int incX, + float beta, float *Y, int incY, int batch_count) { + LOG(FATAL) << "Not implmented!"; + } + inline static void ger(Stream *stream, + int m, int n, float alpha, + const float *X, int incX, + const float *Y, int incY, float *A, int lda) { + LOG(FATAL) << "Not implmented!"; + } + inline static void batched_ger(Stream *stream, + int m, int n, float alpha, + const float *X, int incX, + const float *Y, int incY, float *A, int lda, int batch_count) { + LOG(FATAL) << "Not implmented!"; + } + inline static void dot(Stream *stream, + int n, + const float* X, int incX, + const float* Y, int incY, + float* ret) { + LOG(FATAL) << "Not implmented!"; + } +}; + +template<> +struct BLASEngine { + inline static bool GetT(bool t) { + return t ? true : false; + } + inline static void SetStream(Stream *stream) { + } + inline static void gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, double alpha, + const double *A, int lda, const double *B, int ldb, + double beta, double *C, int ldc) { + if (alpha == 1.0f && beta == 0.0f) { + bool transpose_left = transb; + bool transpose_right = transa; + Tensor lhs((double*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*) + Tensor rhs((double*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*) + Tensor dst(C, Shape2(m, n)); + if (!transpose_left && !transpose_right) { + dst = expr::implicit_dot(lhs, rhs); return; + } else if (!transpose_left && transpose_right) { + dst = expr::implicit_dot(lhs, rhs.T()); return; + } else if (transpose_left && !transpose_right) { + dst = expr::implicit_dot(lhs.T(), rhs); return; + } else { + LOG(FATAL) << "Not implmented!"; + } + } else { + LOG(FATAL) << "Not implmented!"; + } + } + inline static void batched_gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, double alpha, + const double *A, int lda, const double *B, int ldb, + double beta, double *C, int ldc, int batch_count, + double **workspace) { + for (int i = 0; i < batch_count; ++i) { + gemm(stream, transa, transb, m, n, k, alpha, + A + i * m * k, lda, B + i * k * n, ldb, + beta, C + i * m * n, ldc); + } + } + inline static void gemv(Stream *stream, + bool trans, int m, int n, + double alpha, const double *A, int lda, + const double *X, int incX, + double beta, double *Y, int incY) { + LOG(FATAL) << "Not implmented!"; + } + inline static void batched_gemv(Stream *stream, + bool trans, int m, int n, + double alpha, const double *A, int lda, + const double *X, int incX, + double beta, double *Y, int incY, int batch_count) { + LOG(FATAL) << "Not implmented!"; + } + inline static void ger(Stream *stream, + int m, int n, double alpha, + const double *X, int incX, + const double *Y, int incY, double *A, int lda) { + LOG(FATAL) << "Not implmented!"; + } + inline static void batched_ger(Stream *stream, + int m, int n, double alpha, + const double *X, int incX, + const double *Y, int incY, double *A, int lda, int batch_count) { + LOG(FATAL) << "Not implmented!"; + } + inline static void dot(Stream *stream, + int n, + const double* X, int incX, + const double* Y, int incY, + double* ret) { + LOG(FATAL) << "Not implmented!"; + } +}; + +#elif (MSHADOW_USE_MKL || MSHADOW_USE_CBLAS) // NOLINT(*) +template<> +struct BLASEngine { + inline static CBLAS_TRANSPOSE GetT(bool t) { + return t ? CblasTrans : CblasNoTrans; + } + inline static void SetStream(Stream *stream) { + } + inline static void gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc) { + cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb), + m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + } + inline static void batched_gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc, int batch_count, + float **workspace) { +#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) + std::vector p_m(batch_count, m); + std::vector p_n(batch_count, n); + std::vector p_k(batch_count, k); + std::vector p_lda(batch_count, lda); + std::vector p_ldb(batch_count, ldb); + std::vector p_ldc(batch_count, ldc); + std::vector p_alpha(batch_count, alpha); + std::vector p_beta(batch_count, beta); + std::vector pp_A; + std::vector pp_B; + std::vector pp_C; + + CBLAS_TRANSPOSE cblas_a_trans = GetT(transa); + CBLAS_TRANSPOSE cblas_b_trans = GetT(transb); + + std::vector p_group_sizeb(batch_count, batch_count); + std::vector p_transa(batch_count, cblas_a_trans); + std::vector p_transb(batch_count, cblas_b_trans); + + auto m_k = m * k; + auto k_n = k * n; + auto m_n = m * n; + + for (int i = 0; i < batch_count; i++) { + pp_A.push_back(A + i * m_k); + pp_B.push_back(B + i * k_n); + pp_C.push_back(C + i * m_n); + } + + cblas_sgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(), + p_m.data(), p_n.data(), p_k.data(), + p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(), + p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(), + 1, p_group_sizeb.data()); +#else + for (int i = 0; i < batch_count; ++i) { + gemm(stream, transa, transb, m, n, k, alpha, + A + i * m * k, lda, B + i * k * n, ldb, + beta, C + i * m * n, ldc); + } +#endif + } + inline static void gemv(Stream *stream, + bool trans, int m, int n, + float alpha, const float *A, int lda, + const float *X, int incX, + float beta, float *Y, int incY) { + cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha, + A, lda, X, incX, beta, Y, incY); + } + inline static void batched_gemv(Stream *stream, + bool trans, int m, int n, + float alpha, const float *A, int lda, + const float *X, int incX, + float beta, float *Y, int incY, int batch_count) { + for (int i = 0; i < batch_count; ++i) { + gemv(stream, trans, m, n, alpha, A + i * m * n, lda, + X + i * (trans ? m : n) * incX, incX, + beta, Y + i * (trans ? n : m) * incY, incY); + } + } + inline static void ger(Stream *stream, + int m, int n, float alpha, + const float *X, int incX, + const float *Y, int incY, float *A, int lda) { + cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); + } + inline static void batched_ger(Stream *stream, + int m, int n, float alpha, + const float *X, int incX, + const float *Y, int incY, float *A, int lda, int batch_count) { + for (int i = 0; i < batch_count; ++i) { + ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, + A + i * lda * n, lda); + } + } + inline static void dot(Stream *stream, + int n, + const float* X, int incX, + const float* Y, int incY, + float* ret) { + *ret = cblas_sdot(n, X, incX, Y, incY); + } +}; + +template<> +struct BLASEngine { + inline static CBLAS_TRANSPOSE GetT(bool t) { + return t ? CblasTrans : CblasNoTrans; + } + inline static void SetStream(Stream *stream) { + } + inline static void gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, double alpha, + const double *A, int lda, const double *B, int ldb, + double beta, double *C, int ldc) { + cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb), + m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + } + inline static void batched_gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, double alpha, + const double *A, int lda, const double *B, int ldb, + double beta, double *C, int ldc, int batch_count, + double **workspace) { +#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) + std::vector p_m(batch_count, m); + std::vector p_n(batch_count, n); + std::vector p_k(batch_count, k); + std::vector p_lda(batch_count, lda); + std::vector p_ldb(batch_count, ldb); + std::vector p_ldc(batch_count, ldc); + std::vector p_alpha(batch_count, alpha); + std::vector p_beta(batch_count, beta); + std::vector pp_A; + std::vector pp_B; + std::vector pp_C; + + CBLAS_TRANSPOSE cblas_a_trans = GetT(transa); + CBLAS_TRANSPOSE cblas_b_trans = GetT(transb); + + std::vector p_group_sizeb(batch_count, batch_count); + std::vector p_transa(batch_count, cblas_a_trans); + std::vector p_transb(batch_count, cblas_b_trans); + + auto m_k = m * k; + auto k_n = k * n; + auto m_n = m * n; + + for (int i = 0; i < batch_count; i++) { + pp_A.push_back(A + i * m_k); + pp_B.push_back(B + i * k_n); + pp_C.push_back(C + i * m_n); + } + + cblas_dgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(), + p_m.data(), p_n.data(), p_k.data(), + p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(), + p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(), + 1, p_group_sizeb.data()); +#else + for (int i = 0; i < batch_count; ++i) { + gemm(stream, transa, transb, m, n, k, alpha, + A + i * m * k, lda, B + i * k * n, ldb, + beta, C + i * m * n, ldc); + } +#endif + } + inline static void gemv(Stream *stream, + bool trans, int m, int n, double alpha, + const double *A, int lda, + const double *X, int incX, + double beta, double *Y, int incY) { + cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha, + A, lda, X, incX, beta, Y, incY); + } + inline static void batched_gemv(Stream *stream, + bool trans, int m, int n, + double alpha, const double *A, int lda, + const double *X, int incX, + double beta, double *Y, int incY, int batch_count) { + for (int i = 0; i < batch_count; ++i) { + gemv(stream, trans, m, n, alpha, A + i * m * n, lda, + X + i * (trans ? m : n) * incX, incX, + beta, Y + i * (trans ? n : m) * incY, incY); + } + } + inline static void ger(Stream *stream, + int m, int n, double alpha, + const double *X, int incX, + const double *Y, int incY, double *A, int lda) { + cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); + } + inline static void batched_ger(Stream *stream, + int m, int n, double alpha, + const double *X, int incX, + const double *Y, int incY, double *A, int lda, int batch_count) { + for (int i = 0; i < batch_count; ++i) { + ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, + A + i * lda * n, lda); + } + } + inline static void dot(Stream *stream, + int n, + const double* X, int incX, + const double* Y, int incY, + double* ret) { + *ret = cblas_ddot(n, X, incX, Y, incY); + } +}; +#endif // MSHADOW_USE_CBLAS || MSHADOW_USE_MKL || MSHADOW_STAND_ALONE +// CuBLAS redirect code +#if MSHADOW_USE_CUDA +// All CuBLAS goes to here, use legacy API: not threadsafe +template<> +struct BLASEngine { + inline static cublasOperation_t GetT(bool t) { + return t ? CUBLAS_OP_T : CUBLAS_OP_N; + } + inline static void SetStream(Stream *stream) { + cublasStatus_t err = cublasSetStream(Stream::GetBlasHandle(stream), + Stream::GetStream(stream)); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas set stream fail"; + } + inline static void gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, half::half_t alpha, + const half::half_t *A, int lda, + const half::half_t *B, int ldb, half::half_t beta, + half::half_t *C, int ldc) { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 7050 + // Always use pseudo-fp16: fp32 compute with fp16 I/O. + float alpha_f = float(alpha); // NOLINT(*) + float beta_f = float(beta); // NOLINT(*) + #if CUDA_VERSION >= 8000 + cublasStatus_t err = cublasSgemmEx(Stream::GetBlasHandle(stream), + GetT(transa), GetT(transb), m, n, k, &alpha_f, + A, CUDA_R_16F, lda, B, CUDA_R_16F, + ldb, &beta_f, C, CUDA_R_16F, ldc); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail"; + #else + cublasStatus_t err = cublasSgemmEx(Stream::GetBlasHandle(stream), + GetT(transa), GetT(transb), m, n, k, &alpha_f, + A, CUBLAS_DATA_HALF, lda, B, CUBLAS_DATA_HALF, + ldb, &beta_f, C, CUBLAS_DATA_HALF, ldc); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail"; + #endif // CUDA_VERSION >= 8000 +#else + LOG(FATAL) << "Require CUDA version >= 7.5!"; +#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 7050 + } + inline static void batched_gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, half::half_t alpha, + const half::half_t *A, int lda, const half::half_t *B, int ldb, + half::half_t beta, half::half_t *C, int ldc, int batch_count, + half::half_t **workspace) { + for (int i = 0; i < batch_count; ++i) { + gemm(stream, transa, transb, m, n, k, alpha, + A + i * m * k, lda, B + i * k * n, ldb, + beta, C + i * m * n, ldc); + } + } + inline static void gemv(Stream *stream, + bool trans, int m, int n, half::half_t alpha, + const half::half_t *A, int lda, + const half::half_t *X, int incX, half::half_t beta, + half::half_t *Y, int incY) { + LOG(FATAL) << "Not implmented!"; + } + inline static void batched_gemv(Stream *stream, + bool trans, int m, int n, + half::half_t alpha, const half::half_t *A, int lda, + const half::half_t *X, int incX, + half::half_t beta, half::half_t *Y, int incY, int batch_count) { + LOG(FATAL) << "Not implmented!"; + } + inline static void ger(Stream *stream, + int m, int n, half::half_t alpha, + const half::half_t *X, int incX, + const half::half_t *Y, int incY, half::half_t *A, int lda) { + LOG(FATAL) << "Not implmented!"; + } + inline static void batched_ger(Stream *stream, + int m, int n, half::half_t alpha, + const half::half_t *X, int incX, const half::half_t *Y, int incY, + half::half_t *A, int lda, int batch_count) { + LOG(FATAL) << "Not implmented!"; + } + inline static void dot(Stream *stream, + int n, + const half::half_t* X, int incX, + const half::half_t* Y, int incY, + half::half_t *ret) { + LOG(FATAL) << "Not implmented!"; + } +}; + +template<> +struct BLASEngine { + inline static cublasOperation_t GetT(bool t) { + return t ? CUBLAS_OP_T : CUBLAS_OP_N; + } + inline static void SetStream(Stream *stream) { + cublasStatus_t err = cublasSetStream(Stream::GetBlasHandle(stream), + Stream::GetStream(stream)); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail"; + } + inline static void gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, float alpha, + const float *A, int lda, + const float *B, int ldb, float beta, + float *C, int ldc) { + cublasStatus_t err = cublasSgemm(Stream::GetBlasHandle(stream), + GetT(transa), GetT(transb), m, n, k, &alpha, + A, lda, B, ldb, &beta, C, ldc); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemm fail"; + } + inline static void batched_gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc, int batch_count, + float **workspace) { +#if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 + // Cast DType* to DType** using workspace as a buffer + bool alloc_workspace = false; + if (workspace == NULL) { + // Allocate the workspace if it's NULL. + // TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe. + cudaMalloc(reinterpret_cast(&workspace), 3 * batch_count * sizeof(float*)); + alloc_workspace = true; + } + GetBatchedView(workspace, const_cast(A), batch_count, m * k, stream); + GetBatchedView(workspace + batch_count, + const_cast(B), batch_count, k * n, stream); + GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream); + cublasStatus_t err = cublasSgemmBatched(Stream::GetBlasHandle(stream), + GetT(transa), GetT(transb), m, n, k, &alpha, + (const float**)workspace, lda, + (const float**)(workspace + batch_count), ldb, + &beta, workspace + 2 * batch_count, ldc, batch_count); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmBatched fail"; + if (alloc_workspace) { + cudaFree(workspace); + } +#elif defined(__CUDACC__) && CUDA_VERSION >= 8000 + cublasStatus_t err = cublasSgemmStridedBatched(Stream::GetBlasHandle(stream), + GetT(transa), GetT(transb), m, n, k, &alpha, + A, lda, m * k, + B, ldb, k * n, + &beta, C, ldc, m * n, + batch_count); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmStridedBatched fail"; +#else + for (int i = 0; i < batch_count; ++i) { + gemm(stream, transa, transb, m, n, k, alpha, + A + i * m * k, lda, B + i * k * n, ldb, + beta, C + i * m * n, ldc); + } +#endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 + } + inline static void gemv(Stream *stream, + bool trans, int m, int n, float alpha, + const float *A, int lda, + const float *X, int incX, float beta, + float *Y, int incY) { + cublasStatus_t err = cublasSgemv(Stream::GetBlasHandle(stream), + GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemv fail"; + } + inline static void batched_gemv(Stream *stream, + bool trans, int m, int n, + float alpha, const float *A, int lda, + const float *X, int incX, + float beta, float *Y, int incY, int batch_count) { + for (int i = 0; i < batch_count; ++i) { + gemv(stream, trans, m, n, alpha, A + i * m * n, lda, + X + i * (trans ? m : n) * incX, incX, + beta, Y + i * (trans ? n : m) * incY, incY); + } + } + inline static void ger(Stream *stream, + int m, int n, float alpha, + const float *X, int incX, + const float *Y, int incY, float *A, int lda) { + cublasStatus_t err = cublasSger(Stream::GetBlasHandle(stream), + m, n, &alpha, X, incX, Y, incY, A, lda); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sger fail"; + } + inline static void batched_ger(Stream *stream, + int m, int n, float alpha, + const float *X, int incX, + const float *Y, int incY, float *A, int lda, int batch_count) { + for (int i = 0; i < batch_count; ++i) { + ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, + A + i * lda * n, lda); + } + } + inline static void dot(Stream *stream, + int n, + const float* X, int incX, + const float* Y, int incY, + float *ret) { + cublasSetPointerMode(Stream::GetBlasHandle(stream), + CUBLAS_POINTER_MODE_DEVICE); + cublasStatus_t err = cublasSdot(Stream::GetBlasHandle(stream), + n, X, incX, Y, incY, ret); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail"; + cublasSetPointerMode(Stream::GetBlasHandle(stream), + CUBLAS_POINTER_MODE_HOST); + } +}; + +template<> +struct BLASEngine { + inline static cublasOperation_t GetT(bool t) { + return t ? CUBLAS_OP_T : CUBLAS_OP_N; + } + inline static void SetStream(Stream *stream) { + cublasStatus_t err = cublasSetStream(Stream::GetBlasHandle(stream), + Stream::GetStream(stream)); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail"; + } + inline static void gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, double alpha, + const double *A, int lda, + const double *B, int ldb, + double beta, double *C, int ldc) { + cublasStatus_t err = cublasDgemm(Stream::GetBlasHandle(stream), + GetT(transa), GetT(transb), m, n, k, &alpha, + A, lda, B, ldb, &beta, C, ldc); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemm fail"; + } + inline static void batched_gemm(Stream *stream, + bool transa, bool transb, + int m, int n, int k, double alpha, + const double *A, int lda, const double *B, int ldb, + double beta, double *C, int ldc, int batch_count, + double **workspace) { +#if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 + // Cast DType* to DType** using workspace as a buffer + bool alloc_workspace = false; + if (workspace == NULL) { + // Allocate the workspace if it's NULL. + // TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe. + cudaMalloc(reinterpret_cast(&workspace), 3 * batch_count * sizeof(double*)); + alloc_workspace = true; + } + GetBatchedView(workspace, const_cast(A), batch_count, m * k, stream); + GetBatchedView(workspace + batch_count, + const_cast(B), batch_count, k * n, stream); + GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream); + cublasStatus_t err = cublasDgemmBatched(Stream::GetBlasHandle(stream), + GetT(transa), GetT(transb), m, n, k, &alpha, + (const double**)workspace, lda, + (const double**)(workspace + batch_count), ldb, + &beta, workspace + 2 * batch_count, ldc, batch_count); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmBatched fail"; + if (alloc_workspace) { + cudaFree(workspace); + } +#elif defined(__CUDACC__) && CUDA_VERSION >= 8000 + cublasStatus_t err = cublasDgemmStridedBatched(Stream::GetBlasHandle(stream), + GetT(transa), GetT(transb), m, n, k, &alpha, + A, lda, m * k, + B, ldb, k * n, + &beta, C, ldc, m * n, + batch_count); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmStridedBatched fail"; +#else + for (int i = 0; i < batch_count; ++i) { + gemm(stream, transa, transb, m, n, k, alpha, + A + i * m * k, lda, B + i * k * n, ldb, + beta, C + i * m * n, ldc); + } +#endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 + } + inline static void gemv(Stream *stream, + bool trans, int m, int n, double alpha, + const double *A, int lda, + const double *X, int incX, + double beta, double *Y, int incY) { + cublasStatus_t err = cublasDgemv(Stream::GetBlasHandle(stream), + GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemv fail"; + } + inline static void batched_gemv(Stream *stream, + bool trans, int m, int n, + double alpha, const double *A, int lda, + const double *X, int incX, + double beta, double *Y, int incY, int batch_count) { + for (int i = 0; i < batch_count; ++i) { + gemv(stream, trans, m, n, alpha, A + i * m * n, lda, + X + i * (trans ? m : n) * incX, incX, + beta, Y + i * (trans ? n : m) * incY, incY); + } + } + inline static void ger(Stream *stream, + int m, int n, double alpha, + const double *X, int incX, + const double *Y, int incY, double *A, int lda) { + cublasStatus_t err = cublasDger(Stream::GetBlasHandle(stream), + m, n, &alpha, X, incX, Y, incY, A, lda); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dger fail"; + } + inline static void batched_ger(Stream *stream, + int m, int n, double alpha, + const double *X, int incX, + const double *Y, int incY, double *A, int lda, int batch_count) { + for (int i = 0; i < batch_count; ++i) { + ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, + A + i * lda * n, lda); + } + } + inline static void dot(Stream *stream, + int n, + const double* X, int incX, + const double* Y, int incY, + double *ret) { + cublasSetPointerMode(Stream::GetBlasHandle(stream), + CUBLAS_POINTER_MODE_DEVICE); + cublasStatus_t err = cublasDdot(Stream::GetBlasHandle(stream), + n, X, incX, Y, incY, ret); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail"; + cublasSetPointerMode(Stream::GetBlasHandle(stream), + CUBLAS_POINTER_MODE_HOST); + } +}; +#endif // MSHADOW_USE_CUDA +// helper function to decide which shape we are in +inline Shape<2> GetShape(const Shape<2> &shape, bool transpose) { + return transpose ? Shape2(shape[1], shape[0]) : shape; +} +// dst = dot(lhs[.T], rhs[.T]) +template +struct DotEngine { + inline static void Eval(Tensor *p_dst, + const Tensor &lhs, + const Tensor &rhs, + DType scale) { + Tensor &dst = *p_dst; +#if MSHADOW_STAND_ALONE + if (xpu::kDevMask == cpu::kDevMask && scale == 1.0f) { + if (!transpose_left && !transpose_right) { + dst = expr::implicit_dot(lhs, rhs); return; + } else if (!transpose_left && transpose_right) { + dst = expr::implicit_dot(lhs, rhs.T()); return; + } else if (transpose_left && !transpose_right) { + dst = expr::implicit_dot(lhs.T(), rhs); return; + } + } +#endif + // set kernel stream + // if there is no stream, crush + BLASEngine::SetStream(dst.stream_); + Shape<2> sleft = GetShape(lhs.shape_, transpose_left); + Shape<2> sright = GetShape(rhs.shape_, transpose_right); + CHECK(dst.size(0) == sleft[0] && dst.size(1) == sright[1] && sleft[1] == sright[0]) + << "dot-gemm: matrix shape mismatch"; + // use column major argument to compatible with most BLAS + BLASEngine::gemm + (dst.stream_, + transpose_right , transpose_left, + transpose_right ? rhs.size(0) : rhs.size(1), + transpose_left ? lhs.size(1) : lhs.size(0), + transpose_right ? rhs.size(1) : rhs.size(0), + DType(scale * SV::AlphaBLAS()), + rhs.dptr_, rhs.stride_, + lhs.dptr_, lhs.stride_, + DType(SV::BetaBLAS()), + dst.dptr_, dst.stride_); + } +}; +template +struct DotEngine { + inline static void Eval(Tensor *p_dst, + const Tensor &lhs, + const Tensor &rhs, + DType scale) { + Tensor &dst = *p_dst; + // set kernel stream + // if there is no stream, crush + BLASEngine::SetStream(dst.stream_); + Shape<2> sright = GetShape(rhs.shape_, transpose_right); + CHECK(dst.size(0) == sright[1] && lhs.size(0) == sright[0]) + << "dot-gemv: matrix shape mismatch" + << "dst: " << dst.shape_ << "\n" + << "lhs: " << lhs.shape_ << "\n" + << "rhs: " << sright << "\n"; + BLASEngine::gemv + (dst.stream_, + transpose_right, + rhs.size(1), rhs.size(0), scale * SV::AlphaBLAS(), + rhs.dptr_, rhs.stride_, + lhs.dptr_, 1, SV::BetaBLAS(), + dst.dptr_, 1); + } +}; +template +struct DotEngine { + inline static void Eval(Tensor *p_dst, + const Tensor &lhs, + const Tensor &rhs, + DType scale) { + Tensor &dst = *p_dst; + // set kernel stream + // if there is no stream, crush + BLASEngine::SetStream(dst.stream_); + CHECK(dst.size(0) == lhs.size(0) && dst.size(1) == rhs.size(0)) + << "dot-ger: matrix shape mismatch" + << "dst: " << dst.shape_ << "\n" + << "lhs: " << lhs.shape_ << "\n" + << "rhs: " << rhs.shape_; + if (SV::BetaBLAS() == 0.0f) { + BLASEngine::ger + (dst.stream_, rhs.size(0), lhs.size(0), scale * SV::AlphaBLAS(), + rhs.dptr_, 1, lhs.dptr_, 1, dst.dptr_, dst.stride_); + } else { + DotEngine::Eval(p_dst, lhs.FlatTo2D(), rhs.FlatTo2D(), scale); + } + } +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_DOT_ENGINE_INL_H_ diff --git a/include/mshadow/expr_engine-inl.h b/include/mshadow/expr_engine-inl.h new file mode 100644 index 000000000000..6421ebcff812 --- /dev/null +++ b/include/mshadow/expr_engine-inl.h @@ -0,0 +1,482 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file expr_engine-inl.h + * \brief definitions of how expressions should be evaluated + * \author Tianqi Chen, Bing Xu + */ +#ifndef MSHADOW_EXPR_ENGINE_INL_H_ +#define MSHADOW_EXPR_ENGINE_INL_H_ +#include +#include +#include "./logging.h" +#include "./expression.h" +#include "./tensor.h" + +namespace mshadow { +namespace expr { +/*! + * \brief a general class that allows extension that makes tensors of some shape + * \tparam SubType type of subclass + * \tparam SrcExp source expression of the MakeTensorExp, the source of operation + * \tparam dim dimension of the expression + * \tparam DType the type of elements + */ +template +struct MakeTensorExp + : public Exp, + DType, type::kChainer> { + /*! \brief the shape of this expression */ + Shape shape_; + /*! \brief true self of subtype */ + inline const SubType& real_self(void) const{ + return *static_cast(this); + } +}; +//---------------------------------------------------------------------- +// This part of code gives plan that can be used to carry out execution +//--------------------------------------------------------------------- +// Declarations of plans +template +class Plan { + public: + /*! + * \brief evaluate the expression at index [y][x] + * to be implemented by SubType, for RValue, the return type will be DType & + */ + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const; +}; +// tensor plan +template +class Plan, DType> { + public: + explicit Plan(const Tensor &t) + : dptr_(t.dptr_), stride_(t.stride_) {} + // for RValue, the return type should be reference + MSHADOW_XINLINE DType &REval(index_t y, index_t x) { + return dptr_[y * stride_ + x]; + } + // const evaluation + MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const { + return dptr_[y * stride_ + x]; + } + + private: + DType *dptr_; + index_t stride_; +}; +// special evaluation case for 1d tensor, no stride +template +class Plan, DType> { + public: + explicit Plan(const Tensor &t) : dptr_(t.dptr_) {} + MSHADOW_XINLINE DType &REval(index_t y, index_t x) { + return dptr_[x]; + } + MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const { + return dptr_[x]; + } + + private: + DType *dptr_; +}; +// scalar +template +class Plan, DType> { + public: + explicit Plan(DType scalar) : scalar_(scalar) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return scalar_; + } + + private: + DType scalar_; +}; +// unary expression +template +class Plan, DstDType> { + public: + explicit Plan(const Plan &src) : src_(src) {} + MSHADOW_XINLINE DstDType Eval(index_t y, index_t x) const { + return DstDType(src_.Eval(y, x)); // NOLINT(*) + } + + private: + Plan src_; +}; + +// ternary expression +template +class Plan, DType> { + public: + explicit Plan(const Plan &item1, const Plan &item2, + const Plan &item3) + : item1_(item1), item2_(item2), item3_(item3) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return OP::Map(item1_.Eval(y, x), item2_.Eval(y, x), item3_.Eval(y, x)); + } + + private: + Plan item1_; + Plan item2_; + Plan item3_; +}; +// binary expression +template +class Plan, DType> { + public: + explicit Plan(const Plan &lhs, const Plan &rhs) + : lhs_(lhs), rhs_(rhs) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return OP::Map(lhs_.Eval(y, x), rhs_.Eval(y, x)); + } + + private: + Plan lhs_; + Plan rhs_; +}; +// unary expression +template +class Plan, DType> { + public: + explicit Plan(const Plan &src) : src_(src) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return OP::Map(src_.Eval(y, x)); + } + + private: + Plan src_; +}; +// remaps map tensor expression to subtype's plan +template +struct Plan, DType> { + public: + Plan(const Plan &src) : src_(src) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return src_.Eval(y, x); + } + + private: + Plan src_; +}; +// tranpsoe +template +class Plan, DType> { + public: + explicit Plan(const Plan &src) : src_(src) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return src_.Eval(x, y); + } + + private: + Plan src_; +}; +//---------------------------------------------------------------------- +// Mappings from expression to plans +//--------------------------------------------------------------------- +template +inline Plan, DType> +MakePlan(const BinaryMapExp &e); + +template +inline Plan, DType> +MakePlan(const TernaryMapExp &e); + +template +inline Plan, DType> MakePlan(const ScalarExp &e) { + return Plan, DType>(e.scalar_); +} + +template +inline Plan, DstDType> +MakePlan(const TypecastExp &e) { + return Plan, DstDType>(MakePlan(e.exp)); +} + +template +inline Plan MakePlan(const RValueExp &e) { + return Plan(e.self()); +} + +template +inline Plan, DType> +MakePlan(const TransposeExp &e) { + return Plan, DType>(MakePlan(e.exp)); +} + +template +inline Plan +MakePlan(const MakeTensorExp &e) { + return Plan(e.real_self()); +} + +template +inline Plan, DType> +MakePlan(const UnaryMapExp &e) { + return Plan, DType>(MakePlan(e.src_)); +} + +template +inline Plan, DType> +MakePlan(const BinaryMapExp &e) { + return Plan, + DType>(MakePlan(e.lhs_), MakePlan(e.rhs_)); +} + +// Ternary +template +inline Plan, DType> +MakePlan(const TernaryMapExp &e) { + return Plan, + DType>(MakePlan(e.item1_), MakePlan(e.item2_), MakePlan(e.item3_)); +} +//---------------------------------------------------------------- +// Static Type inference and Type Checking +//---------------------------------------------------------------- +/*! + * \brief static type inference template, + * used to get the dimension of each expression, + * if ExpInfo::kDim == -1, this means here are mismatch in expression + * if (ExpInfo::kDevMask & cpu::kDevMask) != 0, this means this expression can be assigned to cpu + * \tparam E expression + */ +template +struct ExpInfo { + static const int kDim = -1; + static const int kDevMask = 0; +}; +template +struct ExpInfo< ScalarExp > { + static const int kDim = 0; + static const int kDevMask = 0xffff; +}; +template +struct ExpInfo > { + static const int kDim = ExpInfo::kDim; + static const int kDevMask = ExpInfo::kDevMask; +}; +template +struct ExpInfo > { + static const int kDim = ExpInfo::kDim; + static const int kDevMask = ExpInfo::kDevMask; +}; +template +struct ExpInfo > { + static const int kDim = dim; + static const int kDevMask = Device::kDevMask; +}; +template +struct ExpInfo > { + static const int kDimSrc = ExpInfo::kDim; + static const int kDim = kDimSrc >= 0 ? dim : -1; + static const int kDevMask = ExpInfo::kDevMask; +}; +template +struct ExpInfo > { + static const int kDim = ExpInfo::kDim; + static const int kDevMask = ExpInfo::kDevMask; +}; +template +struct ExpInfo > { + static const int kDimLhs = ExpInfo::kDim; + static const int kDimRhs = ExpInfo::kDim; + static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ?\ + (kDimLhs == 0 ?\ + kDimRhs :\ + ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1; + static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; +}; +template +struct ExpInfo > { + static const int kDimItem1 = ExpInfo::kDim; + static const int kDimItem2 = ExpInfo::kDim; + static const int kDimItem3 = ExpInfo::kDim; + static const int kDim = kDimItem1; + static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask & ExpInfo::kDevMask; +}; + +/*! \brief template to do type check */ +template +struct TypeCheck { + /*! \brief dimension of expression*/ + static const int kExpDim = ExpInfo::kDim; + /*! \brief whether the expression device type matches */ + static const bool kDevPass = (ExpInfo::kDevMask & Device::kDevMask) != 0; + /*! \brief whether the expression can be mapped to expression of dim */ + static const bool kMapPass = (kExpDim == 0 || kExpDim == dim) && kDevPass; + /*! \brief whether the expression can be reduced to expression of dim */ + static const bool kRedPass = (kExpDim > dim) && kDevPass; +}; +/*! \brief used to help static type check*/ +template +struct TypeCheckPass; +// Todo : add static assert using C++11 +template<> +struct TypeCheckPass {}; +template<> +struct TypeCheckPass { + inline static void Error_All_Tensor_in_Exp_Must_Have_Same_Type(void) {} + inline static void Error_TypeCheck_Not_Pass_For_Reduce_Exp(void) {} + inline static void Error_Expression_Does_Not_Meet_Dimension_Req(void) {} +}; + +//---------------------------------------------------------------- +// Runtime Stream Getting +//---------------------------------------------------------------- +template +struct StreamInfo { + inline static Stream *Get(const E &t); +}; +template +struct StreamInfo > { + inline static Stream *Get(const Tensor &t) { + return t.stream_; + } +}; +//---------------------------------------------------------------- +// Runtime Shape Checking +//---------------------------------------------------------------- +/*! + * \brief runtime shape checking template + * get the shape of an expression, report error if shape mismatch + * \tparam dim the dimension of the shape + * \tparam E expression + */ +template +struct ShapeCheck { + inline static Shape Check(const E &t); +}; +template +struct ShapeCheck > { + inline static Shape Check(const ScalarExp &exp) { + // use lowest dimension to mark scalar exp + Shape shape; + for (int i = 0; i < dim; ++i) { + shape[i] = 0; + } + return shape; + } +}; +template +struct ShapeCheck > { + inline static Shape + Check(const TypecastExp &exp) { + return ShapeCheck::Check(exp.exp); + } +}; +template +struct ShapeCheck > { + inline static Shape Check(const TransposeExp &e) { + // swap the lowest two dimensions + Shape s = ShapeCheck::Check(e.exp); + std::swap(s[0], s[1]); + return s; + } +}; +template +struct ShapeCheck > { + inline static Shape Check(const Tensor &t) { + return t.shape_; + } +}; +template +struct ShapeCheck > { + inline static Shape + Check(const MakeTensorExp &t) { + return t.shape_; + } +}; +template +struct ShapeCheck > { + inline static Shape Check(const UnaryMapExp &t) { + Shape s = ShapeCheck::Check(t.src_); + return s; + } +}; + +template +struct ShapeCheck > { + inline static Shape + Check(const BinaryMapExp &t) { + Shape shape1 = ShapeCheck::Check(t.lhs_); + Shape shape2 = ShapeCheck::Check(t.rhs_); + if (shape1[0] == 0) return shape2; + if (shape2[0] == 0) return shape1; + CHECK_EQ(shape1, shape2) << "BinaryMapExp: Shapes of operands are not the same, " << + "Shape1=" << shape1 << ", Shape2=" << shape2; + return shape1; + } +}; + +template +struct ShapeCheck > { + inline static Shape + Check(const TernaryMapExp &t) { + Shape shape1 = ShapeCheck::Check(t.item1_); + Shape shape2 = ShapeCheck::Check(t.item2_); + Shape shape3 = ShapeCheck::Check(t.item3_); + bool same = (shape1 == shape2) && (shape2 == shape3); + CHECK(same) << "TernaryMapExp: Shapes of operands are not the same, " << + "Shape1=" << shape1 << ", Shape2=" << shape2 << ", Shape3=" << shape3; + + return shape1; + } +}; +} // namespace expr + +} // namespace mshadow +// include definition of dot engine +#include "./dot_engine-inl.h" + +namespace mshadow { +namespace expr { +/*! \brief some engine that evaluate complex expression */ +template +struct ExpComplexEngine { + inline static void Eval(RV *dst, const E &exp); +}; +/*! \brief the engine that dispatches simple operations*/ +template +struct ExpEngine { + template + inline static void Eval(RV *dst, + const Exp &exp) { + MapExp(dst, exp); + } + template + inline static void Eval(RV *dst, + const Exp &exp) { + MapExp(dst, exp); + } + template + inline static void Eval(RV *dst, + const Exp &exp) { + MapExp(dst, exp); + } + template + inline static void Eval(RV *dst, + const Exp &exp) { + ExpComplexEngine::Eval(dst->ptrself(), exp.self()); + } +}; +template +struct ExpComplexEngine, + DotExp, + Tensor, + ltrans, rtrans, DType>, + DType> { + inline static void Eval(Tensor *dst, + const DotExp, + Tensor, + ltrans, rtrans, DType> &exp) { + DotEngine::Eval(dst, exp.lhs_, exp.rhs_, exp.scale_); + } +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXPR_ENGINE_INL_H_ diff --git a/include/mshadow/expr_scalar-inl.h b/include/mshadow/expr_scalar-inl.h new file mode 100644 index 000000000000..1ddaba412543 --- /dev/null +++ b/include/mshadow/expr_scalar-inl.h @@ -0,0 +1,165 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file expr_scalar-inl.h + * \brief definitions of operators in expression with respect to scalar + * this file will be included several times, each time with MACRO MSHADOW_SCALAR_ to be different types + * + * DO NOT add pragma once or macro guard + * \author Tianqi Chen, Bing Xu + */ +// macro guard is harmful, used to pass the cpplint +#ifndef MSHADOW_EXPR_SCALAR_INL_H_ +#define MSHADOW_EXPR_SCALAR_INL_H_ +// undef the guard so it can be included multiple times +#undef MSHADOW_EXPR_SCALAR_INL_H_ + +namespace mshadow { +namespace expr { +// DotExp +/*! \brief dot operator def */ +template +inline DotExp +operator*(const DotExp &lhs, + MSHADOW_SCALAR_ rhs) { + return DotExp(lhs.lhs_, lhs.rhs_, lhs.scale_ * rhs); +} +/*! \brief scale of dot operation */ +template +inline DotExp +operator*(MSHADOW_SCALAR_ lhs, + const DotExp &rhs) { + return DotExp(rhs.lhs_, rhs.rhs_, rhs.scale_ * lhs); +} + +/*! \brief operator overload */ +template +inline ReduceTo1DExp +operator*(const ReduceTo1DExp &e, MSHADOW_SCALAR_ scale) { + return ReduceTo1DExp(e.src_, e.scale_ * scale); +} +/*! \brief operator overload */ +template +inline ReduceTo1DExp +operator*(MSHADOW_SCALAR_ scale, const ReduceTo1DExp &e) { + return ReduceTo1DExp(e.src_, e.scale_ * scale); +} + +/*! \brief operator overload for const */ +template +inline BinaryMapExp, + MSHADOW_SCALAR_, (ta|type::kMapper)> +F(const Exp &lhs, const ScalarExp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload for const */ +template +inline BinaryMapExp, TB, + MSHADOW_SCALAR_, (tb|type::kMapper)> +F(const ScalarExp &lhs, const Exp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload for const */ +template +inline BinaryMapExp, ScalarExp, + MSHADOW_SCALAR_, (1|type::kMapper)> +F(const ScalarExp &lhs, const ScalarExp &rhs) { + return MakeExp(lhs, rhs); +} +// constant operators +/*! \brief operator overload */ +template +inline BinaryMapExp, + MSHADOW_SCALAR_, (ta|type::kMapper)> +operator+(const Exp &lhs, + const ScalarExp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload */ +template +inline BinaryMapExp, + MSHADOW_SCALAR_, (ta|type::kMapper)> +operator-(const Exp &lhs, + const ScalarExp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload */ +template +inline BinaryMapExp, + MSHADOW_SCALAR_, (ta|type::kMapper)> +operator*(const Exp &lhs, + const ScalarExp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload */ +template +inline BinaryMapExp, + MSHADOW_SCALAR_, (ta|type::kMapper)> +operator/(const Exp &lhs, + const ScalarExp &rhs) { + return MakeExp(lhs, rhs); +} +// constant operators 2 +/*! \brief operator overload */ +template +inline BinaryMapExp, TB, + MSHADOW_SCALAR_, (tb|type::kMapper)> +operator+(const ScalarExp &lhs, + const Exp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload */ +template +inline BinaryMapExp, TB, + MSHADOW_SCALAR_, (tb|type::kMapper)> +operator-(const ScalarExp &lhs, + const Exp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload */ +template +inline BinaryMapExp, TB, + MSHADOW_SCALAR_, (tb|type::kMapper)> +operator*(const ScalarExp &lhs, + const Exp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload */ +template +inline BinaryMapExp, TB, + MSHADOW_SCALAR_, (tb|type::kMapper)> +operator/(const ScalarExp &lhs, const Exp &rhs) { + return MakeExp(lhs, rhs); +} +// constant operators 3 +/*! \brief operator overload */ +inline BinaryMapExp, ScalarExp, + MSHADOW_SCALAR_, (1|type::kMapper)> +operator+(const ScalarExp &lhs, + const ScalarExp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload */ +inline BinaryMapExp, ScalarExp, + MSHADOW_SCALAR_, (1|type::kMapper)> +operator-(const ScalarExp &lhs, + const ScalarExp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload */ +inline BinaryMapExp, ScalarExp, + MSHADOW_SCALAR_, (1|type::kMapper)> +operator*(const ScalarExp &lhs, + const ScalarExp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload */ +inline BinaryMapExp, ScalarExp, + MSHADOW_SCALAR_, (1|type::kMapper)> +operator/(const ScalarExp &lhs, const ScalarExp &rhs) { + return MakeExp(lhs, rhs); +} +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXPR_SCALAR_INL_H_ diff --git a/include/mshadow/expression.h b/include/mshadow/expression.h new file mode 100644 index 000000000000..77f943165088 --- /dev/null +++ b/include/mshadow/expression.h @@ -0,0 +1,416 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file expression.h + * \brief definitions of abstract expressions and expressions template + * \author Tianqi Chen, Bing Xu + */ +#ifndef MSHADOW_EXPRESSION_H_ +#define MSHADOW_EXPRESSION_H_ +#include "./base.h" + +namespace mshadow { +/*! + * \brief namespace for abstract expressions and expressions template, + * have no dependency on tensor.h, + * These data structure takes no charge in computations, + * they are only used to define operations and represent expression in a symbolic way + */ +namespace expr { +/*! \brief type of expressions */ +namespace type { +// type expression type are defined as bitmask +// subtype relationshop kRValue < kMapper < kPull < kComplex +/*! + * \brief this expression directly correspnds to a data class, + * can be used to assign data + */ +const int kRValue = 0; +/*! + * \brief expression contains element-wise tensor operations, + * map a expression to same shape + */ +const int kMapper = 1; +/*! + * \brief expression that can be chained with other expressiones + * Usually it have function Eval(i,j) defined, which pulls the result (i, j) from input + * expression and output the result at certain position. + */ +const int kChainer = 3; +/*! \brief othercase: e.g dot product */ +const int kComplex = 7; +} // namespace type +/*! + * \brief expression engine that actually interprets these expressions + * this is a function template that needed to be implemented for specific expressions + * \tparam Saver the save method + * \tparam RValue the type of RValue to be saved + * \sa namespace sv + */ +template +struct ExpEngine; +/*! \brief defines how expression exp can be evaluated and stored into dst */ +// template +// inline static void Eval(RValue *dst, const EType &exp); +/*! + * \brief base class for expression + * \tparam SubType inheritated class must put their type into this parameter + * \tparam DType the data type of each element in the expression + * \tparam exp_type expression type, see namespace type + */ +template +struct Exp { + public: + /*! \return subtype instance of current class */ + inline const SubType& self(void) const { + return *static_cast(this); + } + /*! \return reference of subtype instance of current class */ + inline SubType* ptrself(void) { + return static_cast(this); + } +}; +/*! + * \brief scalar expression + * \tparam DType the data type of the scalar + */ +template +struct ScalarExp: public Exp, DType, type::kMapper> { + /*! \brief scalar value */ + DType scalar_; + /*! \brief implicit constructor, MUST NOT BE explicit */ + ScalarExp(DType scalar) : scalar_(scalar) {} // NOLINT(*) +}; +/*! \brief create an scalar expression */ +template +inline ScalarExp scalar(DType s) { + return ScalarExp(s); +} +/*! + * \brief typecast expression, cast the type of elements + * \tparam DstDType the target type we want to cast into + * \tparam SrcDType the target type we want to cast from + * \tparam EType the type of the source expression + * \tparam etype the type of expression after cast + */ +template +struct TypecastExp: + public Exp, + DstDType, etype> { + /*! \brief expression to be typecasted */ + const EType &exp; + /*! \brief constructor */ + explicit TypecastExp(const EType &e) : exp(e) {} +}; +/*! \brief create an scalar expression */ +template +inline TypecastExp +tcast(const Exp &exp) { + return TypecastExp(exp.self()); +} +/*! \brief represent a transpose expression of a container */ +template +struct TransposeExp: public Exp, + DType, type::kChainer> { + /*! \brief expression to be transposed */ + const EType &exp; + /*! \brief constructor */ + explicit TransposeExp(const EType &e) : exp(e) {} + /*! \brief transpose expression */ + inline const EType &T(void) const { + return exp; + } +}; +/*! + * \brief base class of all rvalues + * \tparam Container the actually class of data container, e.g. Tensor1D + * \tparam DataType the element data type of each element in the container + */ +template +class RValueExp: public Exp { + public: + /*! + *\brief transpose of a matrix + *\return transpose of current expression + */ + inline const TransposeExp T(void) const { + return TransposeExp(this->self()); + } + /*! \brief operator overload */ + inline Container &operator+=(DType s) { + ExpEngine::Eval(this->ptrself(), scalar(s)); + return *(this->ptrself()); + } + /*! \brief operator overload */ + inline Container &operator-=(DType s) { + ExpEngine::Eval(this->ptrself(), scalar(s)); + return *(this->ptrself()); + } + /*! \brief operator overload */ + inline Container &operator*=(DType s) { + ExpEngine::Eval(this->ptrself(), scalar(s)); + return *(this->ptrself()); + } + /*! \brief operator overload */ + inline Container &operator/=(DType s) { + ExpEngine::Eval(this->ptrself(), scalar(s)); + return *(this->ptrself()); + } + /*! \brief operator overload */ + inline Container &__assign(DType s) { + ExpEngine::Eval(this->ptrself(), scalar(s)); + return *(this->ptrself()); + } + /*! \brief we can not define container = container */ + template + inline Container &__assign(const Exp &exp) { + ExpEngine::Eval(this->ptrself(), exp.self()); + return *(this->ptrself()); + } + /*! \brief operator overload, assign */ + inline Container &__assign(const Exp &exp); + /*! \brief implementation of operator+= */ + template + inline Container &operator+=(const Exp &exp) { + ExpEngine::Eval(this->ptrself(), exp.self()); + return *(this->ptrself()); + } + /*! \brief implementation of operator-= */ + template + inline Container &operator-=(const Exp &exp) { + ExpEngine::Eval(this->ptrself(), exp.self()); + return *(this->ptrself()); + } + /*! \brief implementation of operator*= */ + template + inline Container &operator*=(const Exp &exp) { + ExpEngine::Eval(this->ptrself(), exp.self()); + return *(this->ptrself()); + } + /*! \brief implementation of operator/= */ + template + inline Container &operator/=(const Exp &exp) { + ExpEngine::Eval(this->ptrself(), exp.self()); + return *(this->ptrself()); + } +}; +/*! + * \brief matrix multiplication expression dot(lhs[.T], rhs[.T]) + * \tparam TA type of lhs + * \tparam TB type of rhs + * \tparam ltrans whether lhs is transposed + * \tparam rtrans whether rhs is transposed + * \tparam DType the data type of the scalar + */ +template +struct DotExp: public Exp, + DType, type::kComplex> { + /*! \brief left operand */ + const TA &lhs_; + /*! \brief right operand */ + const TB &rhs_; + /*! \brief scale over result */ + DType scale_; + /*! \brief constructor */ + explicit DotExp(const TA &lhs, const TB &rhs, DType scale) + : lhs_(lhs), rhs_(rhs), scale_(scale) {} +}; +// definition of dot expression +/*! \brief dot operator def */ +template +inline DotExp +dot(const RValueExp &lhs, const RValueExp &rhs) { + return DotExp(lhs.self(), rhs.self(), DType(1.0f)); +} +/*! \brief dot operator def */ +template +inline DotExp +dot(const TransposeExp &lhs, const RValueExp &rhs) { + return DotExp(lhs.exp, rhs.self(), DType(1.0f)); +} +/*! \brief dot operator def */ +template +inline DotExp +dot(const RValueExp &lhs, const TransposeExp &rhs) { + return DotExp(lhs.self(), rhs.exp, DType(1.0f)); +} +/*! \brief dot operator def */ +template +inline DotExp +dot(const TransposeExp &lhs, const TransposeExp &rhs) { + return DotExp(lhs.exp, rhs.exp, DType(1.0f)); +} +/*! \brief batch_dot operator def */ +template +inline DotExp +batch_dot(const RValueExp &lhs, const RValueExp &rhs) { + return DotExp( + lhs.self(), rhs.self(), DType(1.0f)); +} +//--------------- +// TernaryMapExp +// -------------- +/*! + * \brief ternary map expression + * \tparam OP operator + * \tparam TA type of item1 + * \tparam TB type of item2 + * \tparam etype expression type, sa namespace::type + */ +template +struct TernaryMapExp: public Exp, + DType, etype> { + /*! \brief first operand */ + const TA &item1_; + /*! \brief second operand */ + const TB &item2_; + /*! \brief third operand */ + const TC &item3_; + /*! \brief constructor */ + explicit TernaryMapExp(const TA &item1, const TB &item2, const TC &item3) + :item1_(item1), item2_(item2), item3_(item3) {} +}; + +/*! \brief make expression */ +template +inline TernaryMapExp +MakeExp(const Exp &item1, const Exp &item2, + const Exp &item3) { + return TernaryMapExp(item1.self(), item2.self(), item3.self()); +} +/*! + * \brief short hand for MakeExp, usage F(item1,item2,item3). create a ternary operation expression + * \param item1 first operand + * \param item2 second operand + * \param item3 third operand + * \return the result expression + * \tparam ternary operator + * \tparam TA item1 expression + * \tparam ta item1 expression type + * \tparam TB item2 expression + * \tparam tb item2 expression type + * \tparam TC item3 expression + * \tparam tc item3 expression type + * \sa mshadow::op + */ + +// Ternary +template +inline TernaryMapExp +F(const Exp &item1, const Exp &item2, + const Exp &item3) { + return MakeExp(item1, item2, item3); +} +//--------------- +// BinaryMapExp +// -------------- +/*! + * \brief binary map expression lhs [op] rhs + * \tparam OP operator + * \tparam TA type of lhs + * \tparam TB type of rhs + * \tparam etype expression type, sa namespace::type + */ +template +struct BinaryMapExp: public Exp, + DType, etype> { + /*! \brief left operand */ + const TA &lhs_; + /*! \brief right operand */ + const TB &rhs_; + /*! \brief constructor */ + explicit BinaryMapExp(const TA &lhs, const TB &rhs) + :lhs_(lhs), rhs_(rhs) {} +}; + +/*! \brief make expression */ +template +inline BinaryMapExp +MakeExp(const Exp &lhs, const Exp &rhs) { + return BinaryMapExp(lhs.self(), rhs.self()); +} +/*! + * \brief short hand for MakeExp, usage F(lhs, rhs). create a binary operation expression + * \param lhs left operand + * \param rhs right operand + * \return the result expression + * \tparam binary operator + * \tparam TA lhs expression + * \tparam ta lhs expression type + * \tparam TB rhs expression + * \tparam tb rhs expression type + * \sa mshadow::op + */ +template +inline BinaryMapExp +F(const Exp &lhs, const Exp &rhs) { + return MakeExp(lhs, rhs); +} +// operator rules +/*! \brief operator overload */ +template +inline BinaryMapExp +operator+(const Exp &lhs, const Exp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload */ +template +inline BinaryMapExp +operator-(const Exp &lhs, const Exp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload */ +template +inline BinaryMapExp +operator*(const Exp &lhs, const Exp &rhs) { + return MakeExp(lhs, rhs); +} +/*! \brief operator overload */ +template +inline BinaryMapExp +operator/(const Exp &lhs, const Exp &rhs) { + return MakeExp(lhs, rhs); +} +//--------------- +// UnaryMapExp +// -------------- +/*! + * \brief unary map expression op(src) + * \tparam OP operator + * \tparam TA type of src + * \tparam etype expression type, sa namespace::type + */ +template +struct UnaryMapExp: public Exp, + DType, etype> { + /*! \brief source expression */ + const TA &src_; + /*! \brief constructor */ + explicit UnaryMapExp(const TA &src) : src_(src) {} +}; + +/*! \brief make expression */ +template +inline UnaryMapExp +MakeExp(const Exp &src) { + return UnaryMapExp(src.self()); +} +/*! + * \brief short hand for MakeExp, usage F(src), create a unary operation expression + * \param src source expression + * \return the result expression + * \tparam operator + * \tparam TA source expression + * \tparam ta source expression type + * \sa mshadow::op + */ +template +inline UnaryMapExp +F(const Exp &src) { + return MakeExp(src); +} +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXPRESSION_H_ diff --git a/include/mshadow/extension.h b/include/mshadow/extension.h new file mode 100644 index 000000000000..7af0f56f7699 --- /dev/null +++ b/include/mshadow/extension.h @@ -0,0 +1,41 @@ +/*! + * Copyright by Contributors + * \file extension.h + * \brief some extension of expressions, + * used to support something beyond elementwise op + * \author Tianqi Chen, Bing Xu + */ +#ifndef MSHADOW_EXTENSION_H_ +#define MSHADOW_EXTENSION_H_ +#include "./expr_engine-inl.h" +#include "./extension/broadcast.h" +#include "./extension/unpack_patch2col.h" +#include "./extension/pack_col2patch.h" +#include "./extension/reshape.h" +#include "./extension/swapaxis.h" +#include "./extension/reduceto1d.h" +#include "./extension/spatial_pool.h" +#include "./extension/spatial_unpool.h" +#include "./extension/channel_pool.h" +#include "./extension/channel_unpool.h" +#include "./extension/pad.h" +#include "./extension/crop.h" +#include "./extension/mirror.h" +#include "./extension/concat.h" +#include "./extension/implicit_gemm.h" +#include "./extension/choose.h" +#include "./extension/fill.h" +#include "./extension/one_hot.h" +#include "./extension/slice.h" +#include "./extension/slice_ex.h" +#include "./extension/take.h" +#include "./extension/take_grad.h" +#include "./extension/reduce_with_axis.h" +#include "./extension/broadcast_with_axis.h" +#include "./extension/spatial_upsampling_nearest.h" +#include "./extension/transpose.h" +#include "./extension/flip.h" +#include "./extension/complex.h" +#include "./extension/range.h" +#include "./extension/mask.h" +#endif // MSHADOW_EXTENSION_H_ diff --git a/include/mshadow/extension/broadcast.h b/include/mshadow/extension/broadcast.h new file mode 100644 index 000000000000..ea138ccd9e4d --- /dev/null +++ b/include/mshadow/extension/broadcast.h @@ -0,0 +1,165 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file broadcast.h + * \brief support for broadcast and repmat + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_BROADCAST_H_ +#define MSHADOW_EXTENSION_BROADCAST_H_ +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief broadcast Tensor1D into a higher dimension Tensor + * input: Tensor: ishape[0] + * output: Tensor : oshape[dimcast] = ishape[0] + * \tparam SrcExp type of input expression + * \tparam DType the type of elements + * \tparam dimdst target tensor dimension + * \tparam dimcast_m_dst dimdst - dimcast + */ +template +struct Broadcast1DExp: + public MakeTensorExp, + SrcExp, dimdst, DType> { + /*! \brief source operand */ + const SrcExp &src_; + /*! \brief constructor */ + Broadcast1DExp(const SrcExp &src, Shape shape) + : src_(src) { + this->shape_ = shape; + } +}; + +/*! + * \brief broadcast scalar into a higher dimension Tensor + * input: Tensor: ishape = {1} + * output: Tensor : oshape[dimcast] = ishape[0] + * \tparam SrcExp type of input expression + * \tparam DType the type of elements + * \tparam dimdst target tensor dimension + */ +template +struct BroadcastScalarExp: + public MakeTensorExp, + SrcExp, dimdst, DType> { + /*! \brief source operand */ + const SrcExp &src_; + /*! \brief constructor */ + BroadcastScalarExp(const SrcExp &src, Shape shape) + : src_(src) { + this->shape_ = shape; + } +}; + +/*! + * \brief a expression that replicate a 1 dimension tensor in dimension dimcast + * \param src Tensor: shape[0] + * \param shape shape of output + * \return a expresion with type Tensor + * \tparam dimcast target dimension where the 1D tensor will be broadcasted + * \tparam SrcExp type of input expression + * \tparam DType the type of elements + * \tparam dimdst dimension of destination tensor + * \tparam dimcast_lowest the dimension we want to cast the data into + */ +template +inline Broadcast1DExp +broadcast(const expr::Exp &src, Shape shape) { + TypeCheckPass::kDim == 1> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp; + CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], shape[dimcast]) + << "broadcast, shape mismatch"; + return Broadcast1DExp(src.self(), shape); +} + +/*! + * \brief a expression that replicate a scalar tensor to target dimension. + * \param src Tensor: shape[0] == 1 + * \param shape shape of output + * \return a expresion with type Tensor + * \tparam dimcast target dimension where the 1D tensor will be broadcasted + * \tparam SrcExp type of input expression + * \tparam DType the type of elements + * \tparam dimdst dimension of destination tensor + */ +template +inline BroadcastScalarExp +broadcast_scalar(const expr::Exp &src, Shape shape) { + TypeCheckPass::kDim == 1> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp; + CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], 1U) + << "broadcast_scalar, source need to be scalar expression"; + return BroadcastScalarExp(src.self(), shape); +} +// short cut functions +/*! + * \brief a expression that replicate a 1 dimension tensor for nrow times + * \param src Tensor: shape[0] + * \param nrow number of rows to replicate + * \return a expresion with type Tensor size(1), size(0) = nrow + * \tparam Device which device it lies + */ +template +inline Broadcast1DExp +repmat(const expr::Exp &src, index_t nrow) { + return broadcast<1> + (src, Shape2(nrow, ShapeCheck<1, SrcExp>::Check(src.self())[0])); +} +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + static const int dimcast = dimdst - dimdst_m_cast; + explicit Plan(const Broadcast1DExp &e) + : src_(MakePlan(e.src_)), + ystride_(e.shape_.ProdShape(dimcast + 1, dimdst - 1)), + length_(e.shape_[dimcast]) { + TypeCheckPass + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + } + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return src_.Eval(0, (y / ystride_) % length_); + } + + private: + expr::Plan src_; + const index_t ystride_, length_; +}; + +/*! \brief execution plan of Broadcast1DExp */ +template +struct Plan, DType>{ + public: + explicit Plan(const Broadcast1DExp &e) + : src_(MakePlan(e.src_)) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return src_.Eval(0, x); + } + + private: + expr::Plan src_; +}; + +/*! \brief execution plan of Broadcast1DExp */ +template +struct Plan, DType>{ + public: + explicit Plan(const BroadcastScalarExp &e) + : src_(MakePlan(e.src_)) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return src_.Eval(0, 0); + } + + private: + expr::Plan src_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_BROADCAST_H_ diff --git a/include/mshadow/extension/broadcast_with_axis.h b/include/mshadow/extension/broadcast_with_axis.h new file mode 100644 index 000000000000..49605af67d32 --- /dev/null +++ b/include/mshadow/extension/broadcast_with_axis.h @@ -0,0 +1,258 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file broadcast_with_axis.h + * \brief + * \author Junyuan Xie, Xingjian Shi +*/ +#ifndef MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ +#define MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ + +#include +#include "../extension.h" + +namespace mshadow { +namespace expr { + + /*! + * \brief Broadcasting the tensor in the given axis. If keepdim is off, insert the broadcasting dim after axis. Otherwise broadcasting axis. + * \tparam SrcExp source expression + * \tparam DType data type + * \tparam dimsrc source dimension + * \tparam dimdst destination dimension + */ +template +struct BroadcastWithAxisExp: + public MakeTensorExp, + SrcExp, dimdst, DType> { + /*! \brief data oprand */ + const SrcExp &src_; + /*! \brief size of the last dimension of dst */ + index_t dst_last_; + /*! \brief product of the dimensions after the broadcasting axis */ + index_t trailing_; + /*! \brief new dimension of the broadcasting axis*/ + index_t size_; + /*! \brief size of the last dimension of src*/ + index_t last_; + /*! constructor */ + BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size) + : src_(src), size_(size) { + bool keepdim = (dimsrc == dimdst); + Shape src_shape = ShapeCheck::Check(src_); + this->trailing_ = 1; + + if (!keepdim) { + CHECK(dimsrc > axis && axis >= -1) << "broadcast axis (no keepdim) out of bound, " << + "axis must be between -1 and" << dimsrc - 1 << ", given=" << axis << "."; + for (int i = 0; i <= axis; ++i) { + this->shape_[i] = src_shape[i]; + } + this->shape_[axis + 1] = size_; + for (int i = axis + 1; i < dimsrc; ++i) { + this->trailing_ *= src_shape[i]; + this->shape_[i + 1] = src_shape[i]; + } + } else { + CHECK(dimdst > axis && axis >= 0) << "broadcast axis (keepdim) out of bound, " << + "axis must be between 0 and" << dimdst - 1 << ", given=" << axis << "."; + CHECK_EQ(src_shape[axis], 1U) << "Size of the dimension of the broadcasting axis must be 1" << + " when keepdim is on, src_shape[" << axis << "]=" << src_shape[axis] << "."; + for (int i = 0; i <= axis - 1; ++i) { + this->shape_[i] = src_shape[i]; + } + this->shape_[axis] = size_; + for (int i = axis + 1; i < dimdst; ++i) { + this->trailing_ *= src_shape[i]; + this->shape_[i] = src_shape[i]; + } + } + + this->last_ = src_shape[dimsrc - 1]; + this->dst_last_ = this->shape_[dimdst - 1]; + } +}; // struct BroadcastWithAxisExp + +/*! + * \brief Broadcasting the tensor after given axis. + * \tparam SrcExp source expression + * \tparam DType data type + * \tparam etype type of the expression + */ +template +inline BroadcastWithAxisExp::kDim, + ExpInfo::kDim + 1> +broadcast_with_axis(const Exp &src, const int axis, const index_t size) { + return BroadcastWithAxisExp::kDim, + ExpInfo::kDim + 1>(src.self(), axis, size); +} + +/*! +* \brief Broadcasting the tensor in the given axis (keepdim turned on) +* \tparam SrcExp source expression +* \tparam DType data type +* \tparam etype type of the expression +*/ +template +inline BroadcastWithAxisExp::kDim, + ExpInfo::kDim> + broadcast_keepdim(const Exp &src, const int axis, const index_t size) { + return BroadcastWithAxisExp::kDim, + ExpInfo::kDim>(src.self(), axis, size); +} + +/*! +* \brief Broadcasting the tensor in multiple axes. The dimension of the source tensor + in the given axes must be 1. +* \tparam SrcExp source expression +* \tparam DType data type +* \tparam dimsrc source dimension +* \tparam axesnum number of broadcasting dimensions +*/ +template +struct BroadcastWithMultiAxesExp : + public MakeTensorExp, + SrcExp, dimsrc, DType> { + /*! \brief data oprand */ + const SrcExp &src_; + /*! \brief size of the last dimension of dst */ + index_t dst_last_; + /*! \brief number of broadcasting axes*/ + index_t axesnum_; + /*! \brief product of the dimensions after the broadcasting axses */ + Shape trailings_; + /*! \brief new dimension of the broadcasting axes*/ + Shape sizes_; + /*! \brief size of the last dimension of src*/ + index_t last_; + /*! constructor */ + template + BroadcastWithMultiAxesExp(const SrcExp &src, const TShape& axes, const TShape& sizes) + : src_(src) { + Shape src_shape = ShapeCheck::Check(src_); + CHECK(axes.ndim() == sizes.ndim()) << "ndim of axes and sizes must be equal."; + this->axesnum_ = axes.ndim(); + CHECK(this->axesnum_ <= dimsrc) << "Number of broadcasting axes must be smaller than" + "the source ndim, number of axes=" << this->axesnum_ << " dimsrc=" << dimsrc; + for (index_t i = 0; i < this->axesnum_; i++) { + CHECK(dimsrc > axes[i]) << "broadcast axis (keepdim) out of bound, " << + "all axes must be between 0 and" << dimsrc - 1 << ", given axes[" << i << "] = " << axes[i] + << "."; + CHECK_EQ(src_shape[axes[i]], 1U) << "Size of the dimension of the broadcasting axis must be 1" + << ", src_shape[" << axes[i] << "]=" << src_shape[axes[i]] << "."; + if (i < this->axesnum_ - 1) { + CHECK(axes[i] < axes[i + 1]) << "The given axes must be in increasing order."; + } + } + for (index_t i = 0; i < dimsrc; i++) { + this->shape_[i] = src_shape[i]; + this->sizes_[i] = 1; + this->trailings_[i] = 1; + } + for (index_t i = 0; i < this->axesnum_; i++) { + this->shape_[axes[i]] = sizes[i]; + this->sizes_[i] = sizes[i]; + } + for (index_t i = 0; i < this->axesnum_; i++) { + this->trailings_[i] = 1; + for (index_t j = axes[i] + 1; j < dimsrc; ++j) { + this->trailings_[i] *= this->shape_[j]; + } + } + this->last_ = src_shape[dimsrc - 1]; + this->dst_last_ = this->shape_[dimsrc - 1]; + } +}; // struct BroadcastWithMultiAxesExp + +/*! +* \brief Broadcasting the tensor in the given axis (keepdim turned on) +* \param src source +* \param axes broadcasting axes +* \param sizes sizes of the broadcasting axes +* \tparam SrcExp source expression +* \tparam DType data type +* \tparam etype type of the expression +* \tparam TShape the flexible shape type +*/ +template +inline BroadcastWithMultiAxesExp::kDim> +broadcast_multi_axes(const Exp &src, +const TShape &axes, const TShape &sizes) { + return BroadcastWithMultiAxesExp::kDim>(src.self(), axes, sizes); +} + +/*! +* \brief Broadcasting the tensor to the target shape, + dimension of different sizes must be 1 in the original tensor. +* \param src source +* \param target_shape shape of the target broadcasting tensor +* \tparam SrcExp source expression +* \tparam DType data type +* \tparam etype type of the expression +* \tparam TShape the flexible shape type +*/ +template +inline BroadcastWithMultiAxesExp::kDim> +broadcast_to(const Exp &src, const TShape &target_shape) { + static const size_t dimsrc = ExpInfo::kDim; + CHECK_EQ(target_shape.ndim(), dimsrc); + std::vector axes_vec, sizes_vec; + Shape src_shape = ShapeCheck::Check(src.self()); + for (size_t i = 0; i < dimsrc; ++i) { + if (src_shape[i] != target_shape[i]) { + CHECK_EQ(src_shape[i], 1U) << "broadcasting axis must have size 1, received shape=" + << src_shape << " target_shape=" << target_shape; + axes_vec.push_back(i); + sizes_vec.push_back(target_shape[i]); + } + } + TShape axes = TShape(axes_vec.begin(), axes_vec.end()); + TShape sizes = TShape(sizes_vec.begin(), sizes_vec.end()); + return BroadcastWithMultiAxesExp::kDim>(src.self(), axes, sizes); +} + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const BroadcastWithAxisExp &e) + : src_(MakePlan(e.src_)), dst_last_(e.dst_last_), + trailing_(e.trailing_), size_(e.size_), last_(e.last_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + index_t x = (i * dst_last_ + j) / trailing_ / size_; + index_t y = (i * dst_last_ + j) % trailing_; + index_t z = x * trailing_ + y; + return src_.Eval(z / last_, z % last_); + } + + private: + Plan src_; + const index_t dst_last_, trailing_, size_, last_; +}; + +template +struct Plan, DType> { + public: + explicit Plan(const BroadcastWithMultiAxesExp &e) + : src_(MakePlan(e.src_)), dst_last_(e.dst_last_), last_(e.last_), axesnum_(e.axesnum_), + trailings_(e.trailings_), sizes_(e.sizes_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + index_t indx = i * dst_last_ + j; + for (index_t p = 0; p < dimsrc; ++p) { + if (p >= axesnum_) { + break; + } + indx = (indx / trailings_[p] / sizes_[p]) * trailings_[p] + (indx % trailings_[p]); + } + return src_.Eval(indx / last_, indx % last_); + } + + private: + Plan src_; + const index_t dst_last_, last_, axesnum_; + const Shape trailings_, sizes_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ diff --git a/include/mshadow/extension/channel_pool.h b/include/mshadow/extension/channel_pool.h new file mode 100644 index 000000000000..60d1112f4a61 --- /dev/null +++ b/include/mshadow/extension/channel_pool.h @@ -0,0 +1,108 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file channel_pool.h + * \brief support for chpool + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_CHANNEL_POOL_H_ +#define MSHADOW_EXTENSION_CHANNEL_POOL_H_ +#include +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief channel pooling expression, do reduction over (local nearby) channels, + * used to implement local response normalization + * \tparam Reducer reduction method during pooling + * \tparam SrcExp source expression to be pooled from + * \tparam DType the type of elements + * \tparam srcdim dimension of src + */ +template +struct ChannelPoolingExp: + public MakeTensorExp, + SrcExp, srcdim, DType> { + /*! \brief source operand */ + const SrcExp &src_; + /*! \brief neighbor size */ + index_t nsize_; + /*! \brief stride of pooling */ + index_t stride_; + /*! \brief pad of pooling of each side */ + index_t pad_; + index_t src_channel_; + /*! \brief constructor */ + ChannelPoolingExp(const SrcExp &src, index_t nsize, index_t stride, index_t pad) + : src_(src), nsize_(nsize), stride_(stride), pad_(pad) { + this->shape_ = ShapeCheck::Check(src_); + this->src_channel_ = this->shape_[srcdim - 3]; + CHECK_GE(this->shape_[srcdim - 3], nsize_) + << "chpool: local size must be smaller than nchannels"; + this->shape_[srcdim - 3] = (this->src_channel_ - nsize + pad * 2 + 1) / stride; + } +}; +/*! + * \brief channel pooling, do reduction over (local nearby) channels, + * used to implement local response normalization + * \param src source data + * \param nsize neighbor size + * \return expression of pooled result + * \tparam Reducer reducer type + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype type of expression + */ +template +inline ChannelPoolingExp::kDim> +chpool(const Exp &src, index_t nsize) { + TypeCheckPass::kDim >= 3> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + CHECK_EQ(nsize % 2, 1U) << "chpool: if no pad is specified, local size must be odd"; + return ChannelPoolingExp::kDim>(src.self(), nsize, 1, nsize / 2); +} + +template +inline ChannelPoolingExp::kDim> +chpool(const Exp &src, index_t nsize, index_t stride, index_t pad) { + TypeCheckPass::kDim >= 3> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return ChannelPoolingExp::kDim>(src.self(), nsize, stride, pad); +} + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const ChannelPoolingExp &e) + : src_(MakePlan(e.src_)), channel_(e.shape_[srcdim - 3]), + height_(e.shape_[srcdim - 2]), width_(e.shape_[srcdim - 1]), + hnsize_(e.nsize_), stride_(e.stride_), pad_(e.pad_), + src_channel_(e.src_channel_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + using namespace std; + const index_t y = i % height_; + i /= height_; + const index_t c = i % channel_; + const index_t n = i / channel_; + const index_t x = j; + const index_t cstart = c * stride_ < pad_ ? 0 : c * stride_ - pad_; + const index_t cend = min(c * stride_ - pad_ + hnsize_, channel_); + DType res; Reducer::SetInitValue(res); + for (index_t cc = cstart; cc < cend; ++cc) { + Reducer::Reduce(res, src_.Eval((n * src_channel_ + cc) * height_ + y, x)); + } + return res; + } + + private: + Plan src_; + const index_t channel_, height_, width_, hnsize_, stride_, pad_, src_channel_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_CHANNEL_POOL_H_ + diff --git a/include/mshadow/extension/channel_unpool.h b/include/mshadow/extension/channel_unpool.h new file mode 100644 index 000000000000..00ba279c1760 --- /dev/null +++ b/include/mshadow/extension/channel_unpool.h @@ -0,0 +1,137 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file channel_pool.h + * \brief support for chpool + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_ +#define MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_ +#include +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief channel pooling expression, do reduction over (local nearby) channels, + * used to implement local response normalization + * \tparam Reducer reduction method during pooling + * \tparam SrcExp source expression to be pooled from + * \tparam DType the type of elements + * \tparam srcdim dimension of src + */ +template +struct ChannelUnpoolingExp: + public MakeTensorExp, + SrcExp, srcdim, DType> { + /*! \brief source input, corresponds to src in pooling */ + const SrcExp &data_src_; + /*! \brief result of pooled data, corresponds to result of pooling */ + const SrcExp &data_pooled_; + /*! \brief gradient data of pooled part, to be propgate down */ + const SrcExp &grad_pooled_; + /*! \brief channel of pooled expression */ + index_t pchannel_; + /*! \brief kernel size in height */ + index_t nsize_; + /*! \brief kernel size in width */ + index_t kstride_; + /*! \brief pad */ + index_t pad_; + /*! \brief constructor */ + ChannelUnpoolingExp(const SrcExp &data_src, + const SrcExp &data_pooled, + const SrcExp &grad_pooled, + index_t nsize, index_t kstride, index_t pad) + : data_src_(data_src), data_pooled_(data_pooled), + grad_pooled_(grad_pooled), + nsize_(nsize), kstride_(kstride), pad_(pad) { + Shape pshape = ShapeCheck::Check(grad_pooled); + typedef ShapeCheck ShapeCheckSrcDimSrcExp; + CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled)) + << "ChannelUnPoolingExp: data and grad shape mismatch"; + Shape sshape = ShapeCheck::Check(data_src); + for (int k = 0; k < srcdim; ++k) { + if (k == 1) { + continue; + } + CHECK_EQ(pshape[k], sshape[k]) + << "ChannelUnPoolingExp: pooled tensor and src tensor shape mismatch" + << pshape[k] + << " vs " + << sshape[k]; + } + pchannel_ = pshape[1]; + this->shape_ = sshape; + } +}; +/*! + * \brief channel unpooling, do unroll over (local nearby) channels + * \param src source data + * \param nsize neighbor size + * \param stride stride of the pooling + * \param pad number of padding at each side + * \return expression of pooled result + * \tparam Reducer reducer type + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype type of expression + */ +template +inline ChannelUnpoolingExp::kDim> +ch_unpool(const Exp &data_src, + const Exp &data_pooled, + const Exp &grad_pooled, + index_t nsize, index_t stride, index_t pad) { + TypeCheckPass::kDim >= 3> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return ChannelUnpoolingExp::kDim> + (data_src.self(), data_pooled.self(), grad_pooled.self(), nsize, stride, pad); +} + +template +inline ChannelUnpoolingExp::kDim> +ch_unpool(const Exp &data_src, + const Exp &data_pooled, + const Exp &grad_pooled, index_t nsize) { + return ch_unpool(data_src, data_pooled, grad_pooled, nsize, 1, nsize / 2); +} + + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const ChannelUnpoolingExp &e) + : data_src_(e.data_src_), data_pooled_(e.data_pooled_), + grad_pooled_(e.grad_pooled_), channel_(e.shape_[srcdim - 3]), + height_(e.shape_[srcdim - 2]), pchannel_(e.pchannel_), + hnsize_(e.nsize_), stride_(e.kstride_), pad_(e.pad_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + using namespace std; + const DType vsrc = data_src_.Eval(i, j); + const index_t y = i % height_; + i /= height_; + const index_t c = i % channel_; + const index_t n = i / channel_; + const index_t x = j; + const index_t cstart = c < hnsize_ - pad_ ? 0 + : (c - (hnsize_ - pad_) + stride_) / stride_; + const index_t cend = min((c + pad_ + stride_) / stride_, channel_); + DType val = static_cast(0); + for (index_t cc = cstart; cc < cend; ++cc) { + val += Reducer::PartialGrad(vsrc, + data_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x)) * + grad_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x); + } + return val; + } + + private: + Plan data_src_, data_pooled_, grad_pooled_; + const index_t channel_, height_, pchannel_, hnsize_, stride_, pad_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_ + diff --git a/include/mshadow/extension/choose.h b/include/mshadow/extension/choose.h new file mode 100644 index 000000000000..b1391724d400 --- /dev/null +++ b/include/mshadow/extension/choose.h @@ -0,0 +1,90 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file choose.h + * \brief support for implicit array selection operation + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_CHOOSE_H_ +#define MSHADOW_EXTENSION_CHOOSE_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { +/*! + * \brief Make a choice of index in the lowest changing dimension. + * \tparam SrcExp type of lhs expression + * \tparam IndexExp type of index expression + * \tparam DType the type of elements + */ +template +struct MatChooseRowElementExp: + public Exp, + DType, type::kChainer> { + /*! \brief source operand */ + const SrcExp &src_; + /*! \brief index operand */ + const IndexExp &index_; + /*! \brief constructor */ + MatChooseRowElementExp(const SrcExp &src, const IndexExp &index) + : src_(src), index_(index) {} +}; + +template +inline MatChooseRowElementExp +mat_choose_row_element(const Exp &src, + const Exp &index) { + TypeCheckPass::kDim == 2 && ExpInfo::kDim == 1> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return MatChooseRowElementExp(src.self(), index.self()); +} + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const MatChooseRowElementExp &e) + : src_(MakePlan(e.src_)), + index_(MakePlan(e.index_)) { + } + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + index_t idx = static_cast(index_.Eval(0, x)); + return src_.Eval(x, idx); + } + + private: + expr::Plan src_; + expr::Plan index_; +}; + +template +inline Plan, DType> +MakePlan(const MatChooseRowElementExp &exp) { + return Plan, DType>(exp); +} + +template +struct ShapeCheck > { + inline static Shape + Check(const MatChooseRowElementExp &t) { + CHECK(dim == 1) + << "MatChooseRowElementExp only support 1 dimension output"; + Shape<2> shape1 = ShapeCheck<2, SrcExp>::Check(t.src_); + Shape shape2 = ShapeCheck::Check(t.index_); + CHECK_EQ(shape1[0], shape2[0]) + << "mat_choose_row_element index length and number of rows in matrix"; + return shape2; + } +}; + +template +struct ExpInfo > { + static const int kDim = 1; + static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_CHOOSE_H_ diff --git a/include/mshadow/extension/complex.h b/include/mshadow/extension/complex.h new file mode 100644 index 000000000000..8e79b7eb819c --- /dev/null +++ b/include/mshadow/extension/complex.h @@ -0,0 +1,525 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file complex.h + * \brief support for complex operations + * \author Xingjian Shi + */ +#ifndef MSHADOW_EXTENSION_COMPLEX_H_ +#define MSHADOW_EXTENSION_COMPLEX_H_ +#include +#include "../extension.h" + +namespace mshadow { +namespace op { +namespace complex { +enum BinaryCalculationType { kBinaryCC, kBinaryCR, kBinaryRC}; +enum UnitaryCalculationType { kUnitaryC2R, kUnitaryC2C, kUnitaryR2C }; +struct mul { + /*! \brief map a_real, a_imag, b_real, b_imag to result using defined operation */ + template + MSHADOW_XINLINE static DType RealMap(DType a_real, DType a_imag, + DType b_real, DType b_imag) { + return a_real * b_real - a_imag * b_imag; + } + template + MSHADOW_XINLINE static DType ImagMap(DType a_real, DType a_imag, + DType b_real, DType b_imag) { + return a_real * b_imag + b_real * a_imag; + } +}; + +struct div { + /*! \brief map a_real, a_imag, b_real, b_imag to result using defined operation */ + template + MSHADOW_XINLINE static DType RealMap(DType a_real, DType a_imag, + DType b_real, DType b_imag) { + return (a_real * b_real + a_imag * b_imag) / (b_real * b_real + b_imag * b_imag); + } + template + MSHADOW_XINLINE static DType ImagMap(DType a_real, DType a_imag, + DType b_real, DType b_imag) { + return (b_real * a_imag - a_real * b_imag) / (b_real * b_real + b_imag * b_imag); + } +}; + +struct conjugate { + template + MSHADOW_XINLINE static DType RealMap(const expr::Plan &src_, + index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { + return src_.Eval(real_i, real_j); + } + template + MSHADOW_XINLINE static DType ImagMap(const expr::Plan &src_, + index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { + return -src_.Eval(imag_i, imag_j); + } +}; + +struct exchange { + template + MSHADOW_XINLINE static DType RealMap(const expr::Plan &src_, + index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { + return src_.Eval(imag_i, imag_j); + } + template + MSHADOW_XINLINE static DType ImagMap(const expr::Plan &src_, + index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { + return src_.Eval(real_i, real_j); + } +}; + +// r2c operator +struct pad_imag { + template + MSHADOW_XINLINE static DType RealMap(const expr::Plan &src_, + index_t real_i, index_t real_j) { + return src_.Eval(real_i, real_j); + } + template + MSHADOW_XINLINE static DType ImagMap(const expr::Plan &src_, + index_t real_i, index_t real_j) { + return 0; + } +}; + +// c2r operator +struct toreal { + template + MSHADOW_XINLINE static DType RealMap(const expr::Plan &src_, + index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { + DType real_val = src_.Eval(real_i, real_j); + return real_val; + } +}; + +struct abs_square { + template + MSHADOW_XINLINE static DType RealMap(const expr::Plan &src_, + index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { + DType real_val = src_.Eval(real_i, real_j); + DType image_val = src_.Eval(imag_i, imag_j); + return real_val * real_val + image_val * image_val; + } +}; + +struct sum_real_imag { + template + MSHADOW_XINLINE static DType RealMap(const expr::Plan &src_, + index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { + DType real_val = src_.Eval(real_i, real_j); + DType image_val = src_.Eval(imag_i, imag_j); + return real_val + image_val; + } +}; +} // namespace complex +} // namespace op + +namespace expr { +//-------------------- +// ComplexBinaryMapExp +//-------------------- + /*! +* \brief binary map expression lhs [op] rhs where lhs and rhs are complex tensors +* \tparam OP operator +* \tparam calctype type of the calculation +* \tparam TA type of lhs +* \tparam TB type of rhs +* \tparam etype expression type, sa namespace::type +*/ +template +struct ComplexBinaryMapExp : public Exp, + DType, etype> { + /*! \brief left operand */ + const TA &lhs_; + /*! \brief right operand */ + const TB &rhs_; + /*! \brief constructor */ + explicit ComplexBinaryMapExp(const TA &lhs, const TB &rhs) + :lhs_(lhs), rhs_(rhs) {} +}; + +//------------------- +// ComplexConjExp +//------------------- +/*! +* \brief compute conj(src) where src is a complex tensor +* \tparam TA type of src +* \tparam etype expression type, sa namespace::type +*/ +template +struct ComplexUnitaryExp : public Exp, + DType, etype> { + /*! \brief source expression */ + const TA &src_; + /*! \brief constructor */ + explicit ComplexUnitaryExp(const TA &src) : src_(src) {} +}; + + + +template +inline ComplexBinaryMapExp +ComplexF(const Exp &lhs, const Exp &rhs) { + return ComplexBinaryMapExp(lhs.self(), rhs.self()); +} + +/*! +* \brief conj Negation the imaginary part of A where A is a complex tensor +* \param src source tensor +* \tparam e1 type of source expression +*/ +template +inline ComplexUnitaryExp +ComplexF(const Exp &src) { + return ComplexUnitaryExp(src.self()); +} + +/*! +* \brief complex_mul_cc Complex multipilication two complex tensors, A * B +*/ +template +inline ComplexBinaryMapExp +complex_mul_cc(const Exp &lhs, const Exp &rhs) { + return ComplexF(lhs, rhs); +} + +/*! +* \brief complex_mul_cr Complex multipilication a complex tensor A and a real tensor B +*/ +template +inline ComplexBinaryMapExp +complex_mul_cr(const Exp &lhs, const Exp &rhs) { + return ComplexF(lhs, rhs); +} + +/*! +* \brief complex_mul_rc Complex multipilication of a real tensor B and a complex tensor A +*/ +template +inline ComplexBinaryMapExp +complex_mul_rc(const Exp &lhs, const Exp &rhs) { + return ComplexF(lhs, rhs); +} + +/*! +* \brief complex_mul_cc Complex multipilication two complex tensors, A * B +*/ +template +inline ComplexBinaryMapExp +complex_div_cc(const Exp &lhs, const Exp &rhs) { + return ComplexF(lhs, rhs); +} + +/*! +* \brief complex_mul_cr Complex multipilication a complex tensor A and a real tensor B +*/ +template +inline ComplexBinaryMapExp +complex_div_cr(const Exp &lhs, const Exp &rhs) { + return ComplexF(lhs, rhs); +} + +/*! +* \brief complex_mul_rc Complex multipilication of a real tensor A and a complex tensor B +*/ +template +inline ComplexBinaryMapExp +complex_div_rc(const Exp &lhs, const Exp &rhs) { + return ComplexF(lhs, rhs); +} + +/*! +* \brief conj Negation the imaginary part of A where A is a complex tensor +* \param src source tensor +* \tparam e1 type of source expression +*/ +template +inline ComplexUnitaryExp +conj(const Exp &src) { + return ComplexF(src); +} + +/*! +* \brief complex_exchange Exchange the real and imaginary part of A where A is a complex tensor +* \param src source tensor +* \tparam e1 type of source expression +*/ +template +inline ComplexUnitaryExp +complex_exchange(const Exp &src) { + return ComplexF(src); +} + +/*! +* \brief complex_pad_imag Transform real matrix into complex matrix +* \param src source tensor +* \tparam e1 type of source expression +*/ +template +inline ComplexUnitaryExp +complex_pad_imag(const Exp &src) { + return ComplexF(src); +} + +/*! +* \brief complex_toreal convert complex matrix to real matrix, keep only real part +* \param src source tensor +* \tparam e1 type of source expression +*/ +template +inline ComplexUnitaryExp +complex_toreal(const Exp &src) { + return ComplexF(src); +} + +/*! +* \brief complex_abs_square calculate the square of the modulus of A where A is a complex tensor +* \param src source tensor +* \tparam e1 type of source expression +*/ +template +inline ComplexUnitaryExp +complex_abs_square(const Exp &src) { + return ComplexF(src); +} + +template +inline ComplexUnitaryExp +complex_sum_real_imag(const Exp &src) { + return ComplexF(src); +} + +template +struct ShapeCheck > { + inline static Shape + Check(const ComplexBinaryMapExp &t) { + Shape shape1 = ShapeCheck::Check(t.lhs_); + Shape shape2 = ShapeCheck::Check(t.rhs_); + if (shape1[0] == 0) return shape2; + if (shape2[0] == 0) return shape1; + if (calctype == op::complex::kBinaryCC) { + CHECK_EQ(shape1, shape2) << "ComplexBinaryMapExp (CC): Shapes of operands are not the same."; + CHECK_EQ(shape1[dim - 1] % 2, 0) << + "ComplexBinaryMapExp (CC): Shape of the last dimension is not even. " + "We must have real part + imaginary part."; + return shape1; + } else if (calctype == op::complex::kBinaryCR) { + for (int i = 0; i < dim - 1; ++i) { + CHECK_EQ(shape1.shape_[i], shape2.shape_[i]) << + "ComplexBinaryMapExp (CR): Shapes of operands are not the same."; + } + CHECK_EQ(shape1[dim - 1], shape2[dim - 1] * 2) << + "ComplexBinaryMapExp (CR): Shapes of operands do not match."; + return shape1; + } else if (calctype == op::complex::kBinaryRC) { + for (int i = 0; i < dim - 1; ++i) { + CHECK_EQ(shape1.shape_[i], shape2.shape_[i]) << + "ComplexBinaryMapExp (RC): Shapes of operands are not the same."; + } + CHECK_EQ(shape2[dim - 1], shape1[dim - 1] * 2) << + "ComplexBinaryMapExp (RC): Shapes of operands do not match."; + return shape2; + } else { + LOG(FATAL) << "ComplexBinaryMapExp: Unexpected Calculation Type!"; + return shape1; + } + } +}; + +template +struct ShapeCheck > { + inline static Shape Check(const ComplexUnitaryExp &t) { + Shape s = ShapeCheck::Check(t.src_); + CHECK_EQ(s[dim - 1] % 2, 0) << "ComplexUnitaryExp: Shape of the last dimension is not even. " + "We must have real + imaginary."; + if (calctype == op::complex::kUnitaryC2C) { + return s; + } else if (calctype == op::complex::kUnitaryC2R) { + Shape s_ret = s; + s_ret[dim - 1] /= 2; + return s_ret; + } else if (calctype == op::complex::kUnitaryR2C) { + Shape s_ret = s; + s_ret[dim-1] *= 2; + return s_ret; + } else { + LOG(FATAL) << "ComplexUnitaryExp: Unexpected Calculation Type!"; + return s; + } + } +}; + + + +// complex binary expression (cc) +template +class Plan, DType> { + public: + explicit Plan(const Plan &lhs, const Plan &rhs) + : lhs_(lhs), rhs_(rhs) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + const index_t base_x = static_cast(x / 2) * 2; + if (x % 2 == 0) { + return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1), + rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1)); + } else { + return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1), + rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1)); + } + } + + private: + Plan lhs_; + Plan rhs_; +}; + +// complex binary expression (cr) +template +class Plan, DType> { + public: + explicit Plan(const Plan &lhs, const Plan &rhs) + : lhs_(lhs), rhs_(rhs) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + const index_t base_x = static_cast(x / 2) * 2; + if (x % 2 == 0) { + return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1), + rhs_.Eval(y, base_x / 2), static_cast(0)); + } else { + return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1), + rhs_.Eval(y, base_x / 2), static_cast(0)); + } + } + + private: + Plan lhs_; + Plan rhs_; +}; + + +// complex binary expression (rc) +template +class Plan, DType> { + public: + explicit Plan(const Plan &lhs, const Plan &rhs) + : lhs_(lhs), rhs_(rhs) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + const index_t base_x = static_cast(x / 2) * 2; + if (x % 2 == 0) { + return OP::RealMap(lhs_.Eval(y, base_x / 2), static_cast(0), + rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1)); + } else { + return OP::ImagMap(lhs_.Eval(y, base_x / 2), static_cast(0), + rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1)); + } + } + + private: + Plan lhs_; + Plan rhs_; +}; + + +// complex unitary expression (c2c) +template +class Plan, DType> { + public: + explicit Plan(const Plan &src) : src_(src) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + const index_t base_x = static_cast(x / 2) * 2; + if (0 == x % 2) { + return OP::RealMap(src_, y, base_x, y, base_x + 1); + } else { + return OP::ImagMap(src_, y, base_x, y, base_x + 1); + } + } + + private: + Plan src_; +}; + +// complex unitary expression (r2c) +template +class Plan, DType> { + public: + explicit Plan(const Plan &src) : src_(src) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + const index_t real_x = static_cast(x / 2); + if (0 == x%2) { + // x,y should be coordinates in the complex matrix + // this defines how we will give value to the real part from the real matrix src_, + // thus the index has only 2 dimensions + return OP::RealMap(src_, y, real_x); + } else { + return OP::ImagMap(src_, y, real_x); + } + } + + private: + Plan src_; +}; + +// complex unitary expression (c2r) +template +class Plan, DType> { + public: + explicit Plan(const Plan &src) : src_(src) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return OP::RealMap(src_, y, x * 2, y, x * 2 + 1); + } + + private: + Plan src_; +}; + + + +template +inline Plan, DType> +MakePlan(const ComplexBinaryMapExp &e) { + return Plan, + DType>(MakePlan(e.lhs_), MakePlan(e.rhs_)); +} + +template +inline Plan, DType> +MakePlan(const ComplexUnitaryExp &e) { + return Plan, + DType>(MakePlan(e.src_)); +} + + + +template +struct ExpInfo > { + static const int kDimLhs = ExpInfo::kDim; + static const int kDimRhs = ExpInfo::kDim; + static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ? \ + (kDimLhs == 0 ? \ + kDimRhs : \ + ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1; + static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; +}; + +template +struct ExpInfo > { + static const int kDim = ExpInfo::kDim; + static const int kDevMask = ExpInfo::kDevMask; +}; + +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_COMPLEX_H_ diff --git a/include/mshadow/extension/concat.h b/include/mshadow/extension/concat.h new file mode 100644 index 000000000000..c51b1dcb0a26 --- /dev/null +++ b/include/mshadow/extension/concat.h @@ -0,0 +1,194 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file concat.h + * \brief support for concatenation + */ +#ifndef MSHADOW_EXTENSION_CONCAT_H_ +#define MSHADOW_EXTENSION_CONCAT_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { +/*! + * \brief concat expression, concat two tensor's channel + * \tparam LhsExp left expression + * \tparam RhsExp right expression + * \tparam DType the type of elements + * \tparam srcdim dimension of src + * \tparam dimsrc_m_cat dimsrc - dimcat + */ +template +struct ConcatExp : public TRValue, + Device, srcdim, DType> { + static const int dimcat = srcdim - dimsrc_m_cat; + const LhsExp &src1_; + const RhsExp &src2_; + index_t dcat_src1_; + index_t dcat_src2_; + Shape<4> shape_; + ConcatExp(const LhsExp &src1, const RhsExp &src2) : src1_(src1), src2_(src2) { + Shape sshape1 = ShapeCheck::Check(src1_); + Shape sshape2 = ShapeCheck::Check(src2_); + #pragma unroll + for (int i = 0; i < srcdim; ++i) { + if (i != dimcat) { + CHECK_EQ(sshape1[i], sshape2[i]) << "ConcatExp: shape mismatch"; + } + } + this->shape_ = sshape1; + this->shape_[dimcat] = sshape1[dimcat] + sshape2[dimcat]; + this->dcat_src1_ = sshape1[dimcat]; + this->dcat_src2_ = sshape2[dimcat]; + } + template + inline void + operator=(const expr::Exp &exp) { + this->__assign(exp); + } + inline void + operator=(const DType &exp) { + this->__assign(exp); + } +}; // struct ConcatExp +/*! + * \brief concat two 4D tensor + * \param src1 source tensor1 + * \param src2 source tensor2 + * \return concated 4D tensor + * \tparam cdim the dimension to concatnate on + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype type of expression + */ +template +inline ConcatExp +concat(const TRValue &src1, + const TRValue &src2) { + TypeCheckPass::kDim == ExpInfo::kDim> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + TypeCheckPass::kDim == srcdim> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return ConcatExp + (src1.self(), src2.self()); +} +//------------------------ +// engine plugin +//------------------------ +// runtime shapecheck +template +struct ShapeCheck >{ + inline static Shape Check(const ConcatExp &t) { + return t.shape_; + } +}; +template +struct StreamInfo >{ + inline static Stream * + Get(const ConcatExp &t) { + Stream *lhs = StreamInfo::Get(t.src1_); + Stream *rhs = StreamInfo::Get(t.src2_); + if (lhs != rhs) return NULL; + return lhs; + } +}; +// static typecheck +template +struct ExpInfo >{ + static const int kDimLhs = ExpInfo::kDim; + static const int kDimRhs = ExpInfo::kDim; + // copy from binarymap + static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ?\ + (kDimLhs == 0 ?\ + kDimRhs :\ + ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1; + static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; +}; +//---------------------- +// Execution plan +//--------------------- +template +struct Plan, DType> { + public: + static const int dimcat = srcdim - dimsrc_m_cat; + explicit Plan(const ConcatExp &e) + : src1_(MakePlan(e.src1_)), src2_(MakePlan(e.src2_)), + height_(e.shape_.ProdShape(dimcat + 1, srcdim - 1)), + ch_src1_(e.dcat_src1_), ch_src2_(e.dcat_src2_), ch_(e.shape_[dimcat]) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + const index_t y = i % height_; + i /= height_; + const index_t c = i % ch_; + const index_t b = i / ch_; + const index_t x = j; + if (c < ch_src1_) { + return src1_.Eval((b * ch_src1_ + c) * height_ + y, x); + } else { + return src2_.Eval((b * ch_src2_ + c - ch_src1_) * height_ + y, x); + } + } + MSHADOW_XINLINE DType &REval(index_t i, index_t j) { + const index_t y = i % height_; + i /= height_; + const index_t c = i % ch_; + const index_t b = i / ch_; + const index_t x = j; + if (c < ch_src1_) { + return src1_.REval((b * ch_src1_ + c) * height_ + y, x); + } else { + return src2_.REval((b * ch_src2_ + c - ch_src1_) * height_ + y, x); + } + } + + private: + Plan src1_; + Plan src2_; + const index_t height_, ch_src1_, ch_src2_, ch_; +}; // struct Plan + +// specialize for concat in x +template +struct Plan, DType> { + public: + explicit Plan(const ConcatExp &e) + : src1_(MakePlan(e.src1_)), src2_(MakePlan(e.src2_)), + width_src1_(e.dcat_src1_) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + if (x < width_src1_) { + return src1_.Eval(y, x); + } else { + return src2_.Eval(y, x - width_src1_); + } + } + MSHADOW_XINLINE DType &REval(index_t y, index_t x) { + if (x < width_src1_) { + return src1_.REval(y, x); + } else { + return src2_.REval(y, x - width_src1_); + } + } + + private: + Plan src1_; + Plan src2_; + const index_t width_src1_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_CONCAT_H_ diff --git a/include/mshadow/extension/crop.h b/include/mshadow/extension/crop.h new file mode 100644 index 000000000000..80096a2d22d3 --- /dev/null +++ b/include/mshadow/extension/crop.h @@ -0,0 +1,119 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file crop.h + * \brief support for crop + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_CROP_H_ +#define MSHADOW_EXTENSION_CROP_H_ +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief crop expression, cut off the boundary region, reverse operation of padding + * \tparam SrcExp source expression to be pooled from + * \tparam DType the type of elements + * \tparam srcdim dimension of src + */ +template +struct CroppingExp: + public MakeTensorExp, + SrcExp, srcdim, DType> { + /*! \brief source operand */ + const SrcExp &src_; + /*! \brief pad height */ + index_t pad_height_; + /*! \brief pad height */ + index_t pad_width_; + /*! \brief src height */ + index_t src_height_; + /*! \brief constructor */ + explicit CroppingExp(const SrcExp &src, Shape<2> cshape) + : src_(src) { + this->shape_ = ShapeCheck::Check(src_); + CHECK_GE(this->shape_[srcdim - 2], cshape[0]) << "CroppingExp: height requirement not met"; + CHECK_GE(this->shape_[srcdim - 1], cshape[1]) << "CroppingExp: width requirement not met"; + pad_height_ = (this->shape_[srcdim - 2] - cshape[0]) / 2; + pad_width_ = (this->shape_[srcdim - 1] - cshape[1]) / 2; + src_height_ = this->shape_[srcdim - 2]; + this->shape_[srcdim - 2] = cshape[0]; // height + this->shape_[srcdim - 1] = cshape[1]; // width + } + /*! \brief constructor */ + explicit CroppingExp(const SrcExp &src, Shape<2> cshape, + index_t start_height, index_t start_width) + : src_(src), pad_height_(start_height), pad_width_(start_width) { + this->shape_ = ShapeCheck::Check(src_); + CHECK_GE(this->shape_[srcdim - 2], cshape[0] + start_height) + << "CroppingExp: height requirement not met"; + CHECK_GE(this->shape_[srcdim - 1], cshape[1] + start_width) + << "CroppingExp: width requirement not met"; + src_height_ = this->shape_[srcdim - 2]; + this->shape_[srcdim - 2] = cshape[0]; // height + this->shape_[srcdim - 1] = cshape[1]; // width + } +}; // struct CroppingExp +/*! + * \brief revserse operationg of padding, cut off boundaries, + * crop output from center of input + * \param src original image batches + * \param oshape output shape to be cropped + * \return expression corresponding to padded result + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype type of expression + */ +template +inline CroppingExp::kDim> +crop(const Exp &src, Shape<2> oshape) { + TypeCheckPass::kDim >= 2> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return CroppingExp::kDim>(src.self(), oshape); +} +/*! + * \brief same as crop, but can specify starting position to do cropping + * \param src original image batches + * \param oshape output shape to be cropped + * \param start_height start height position to do cropping + * \param start_width start width position to do cropping + * \return expression corresponding to padded result + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype type of expression + */ +template +inline CroppingExp::kDim> +crop(const Exp &src, Shape<2> oshape, + index_t start_height, index_t start_width) { + TypeCheckPass::kDim >= 2> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return CroppingExp::kDim> + (src.self(), oshape, start_height, start_width); +} +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const CroppingExp &e) + : src_(MakePlan(e.src_)), + pad_height_(e.pad_height_), pad_width_(e.pad_width_), + new_height_(e.shape_[srcdim - 2]), src_height_(e.src_height_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + const index_t x = j; + const index_t y = i % new_height_; + const index_t c = i / new_height_; + const index_t h = y + pad_height_; + const index_t w = x + pad_width_; + return src_.Eval(c * src_height_ + h, w); + } + private: + Plan src_; + const index_t pad_height_, pad_width_; + const index_t new_height_; + const index_t src_height_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_CROP_H_ diff --git a/include/mshadow/extension/fill.h b/include/mshadow/extension/fill.h new file mode 100644 index 000000000000..4ac62c1673e5 --- /dev/null +++ b/include/mshadow/extension/fill.h @@ -0,0 +1,103 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file fill.h + * \brief support for implicit array filling operation + * \author Xingjian Shi + */ +#ifndef MSHADOW_EXTENSION_FILL_H_ +#define MSHADOW_EXTENSION_FILL_H_ + +#include "../extension.h" + + +namespace mshadow { +namespace expr { +/*! + * \brief Set value of a specific element in each line of the data matrix. + * \tparam SrcExp type of src expression + * \tparam ValExp type of val expression + * \tparam IndexExp type of index expression + * \tparam DType the type of ret expression + */ +template +struct MatFillRowElementExp: + public Exp, + DType, type::kChainer> { + /*! \brief src operand */ + const SrcExp &src_; + const ValExp &val_; + /*! \brief index operand */ + const IndexExp &index_; + /*! \brief constructor */ + MatFillRowElementExp(const SrcExp &src, const ValExp &val, const IndexExp &index) + : src_(src), val_(val), index_(index) {} +}; + +template +inline MatFillRowElementExp +mat_fill_row_element(const Exp &src, + const Exp &val, + const Exp &index) { + TypeCheckPass::kDim == 2 && ExpInfo::kDim == 1 + && ExpInfo::kDim == 1>::Error_Expression_Does_Not_Meet_Dimension_Req(); + return MatFillRowElementExp(src.self(), + val.self(), index.self()); +} + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const MatFillRowElementExp &e) + : src_(MakePlan(e.src_)), + val_(MakePlan(e.val_)), + index_(MakePlan(e.index_)) { + } + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + index_t idx = static_cast(index_.Eval(0, y)); + if (idx == x) { + return static_cast(val_.Eval(0, y)); + } else { + return static_cast(src_.Eval(y, x)); + } + } + + private: + expr::Plan src_; + expr::Plan val_; + expr::Plan index_; +}; + +template +inline Plan, DType> +MakePlan(const MatFillRowElementExp &exp) { + return Plan, DType>(exp); +} + +template +struct ShapeCheck > { + inline static Shape + Check(const MatFillRowElementExp &t) { + CHECK(dim == 2) + << "MatFillRowElementExp only support 2 dimension output"; + Shape<2> shape_src = ShapeCheck<2, SrcExp>::Check(t.src_); + Shape<1> shape_val = ShapeCheck<1, ValExp>::Check(t.val_); + Shape<1> shape_index = ShapeCheck<1, IndexExp>::Check(t.index_); + CHECK((shape_src[0] == shape_index[0]) && (shape_index[0] == shape_val[0])) + << "mat_fill_row_element index length, val length and number of rows in matrix"; + return shape_src; + } +}; + +template +struct ExpInfo > { + static const int kDim = 2; + static const int kDevMask = + ExpInfo::kDevMask & ExpInfo::kDevMask & ExpInfo::kDevMask; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_FILL_H_ diff --git a/include/mshadow/extension/flip.h b/include/mshadow/extension/flip.h new file mode 100644 index 000000000000..17d1894530fc --- /dev/null +++ b/include/mshadow/extension/flip.h @@ -0,0 +1,132 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file flip.h + * \brief support for flip a certain dimension. + * \author Junyuan Xie + */ +#ifndef MSHADOW_EXTENSION_FLIP_H_ +#define MSHADOW_EXTENSION_FLIP_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { +/*! + * \brief slice expression, slice a tensor's channel + * \tparam SrcExp left expression + * \tparam DType the type of elements + * \tparam srcdim dimension of src + * \tparam dimsrc_m_cat dimsrc - dimcat + */ +template +struct FlipExp : public TRValue, + Device, srcdim, DType> { + const SrcExp &src_; + index_t trailing_; + index_t stride_; + index_t stride_j_; + Shape shape_; + FlipExp(const SrcExp &src, int dim) + : src_(src) { + shape_ = ShapeCheck::Check(src_); + stride_ = shape_[dim]; + stride_j_ = shape_[srcdim-1]; + trailing_ = 1; + for (int i = dim + 1; i < srcdim; ++i) { + trailing_ *= shape_[i]; + } + } + template + inline void + operator=(const expr::Exp &exp) { + this->__assign(exp); + } + inline void + operator=(const DType &exp) { + this->__assign(exp); + } +}; // struct Flip + +/*! + * \brief Flip a Tensor + * \param src source tensor + * \param begin The beginning slice. + * \param end The end slice. + * \return sliced tensor + * \tparam sdim the dimension to slice on + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype type of expression + */ +template +inline FlipExp +flip(const TRValue &src, int dim) { + return FlipExp(src.self(), dim); +} +//------------------------ +// engine plugin +//------------------------ +// runtime shapecheck +template +struct ShapeCheck >{ + inline static Shape Check(const FlipExp &t) { + return t.shape_; + } +}; +template +struct StreamInfo >{ + inline static Stream * + Get(const FlipExp &t) { + return StreamInfo::Get(t.src_); + } +}; +// static typecheck +template +struct ExpInfo >{ + static const int kDim = ExpInfo::kDim; + static const int kDevMask = ExpInfo::kDevMask; +}; +//---------------------- +// Execution plan +//--------------------- +template +struct Plan, DType> { + public: + explicit Plan(const FlipExp &e) + : src_(MakePlan(e.src_)), stride_j_(e.stride_j_), + trailing_(e.trailing_), stride_(e.stride_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + index_t idx = i*stride_j_+j; + const index_t low = idx%trailing_; + index_t high = idx/trailing_; + const index_t x = high%stride_; + high /= stride_; + idx = (high*stride_+stride_-1-x)*trailing_+low; + return src_.Eval(idx/stride_j_, idx%stride_j_); + } + MSHADOW_XINLINE DType &REval(index_t i, index_t j) const { + index_t idx = i*stride_j_+j; + const index_t low = idx%trailing_; + index_t high = idx/trailing_; + const index_t x = high%stride_; + high /= stride_; + idx = (high*stride_+stride_-1-x)*trailing_+low; + return src_.REval(idx/stride_j_, idx%stride_j_); + } + + private: + Plan src_; + const index_t stride_j_, trailing_, stride_; +}; // struct Plan +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_FLIP_H_ diff --git a/include/mshadow/extension/implicit_gemm.h b/include/mshadow/extension/implicit_gemm.h new file mode 100644 index 000000000000..b4b88ea326c8 --- /dev/null +++ b/include/mshadow/extension/implicit_gemm.h @@ -0,0 +1,128 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file implicit_gemm.h + * \brief support for implicit GEMM operation + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ +#define MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ + +#include "../extension.h" +#include "../packet-inl.h" + +namespace mshadow { +namespace expr { +/*! + * \brief Matrix multiplication. + * \tparam LhsExp type of lhs expression + * \tparam LhsExp type of rhs expression + * \tparam DType the type of elements + */ +template +struct ImplicitGEMMExp: + public Exp, + DType, type::kChainer> { + /*! \brief lhs operand */ + const LhsExp &lhs_; + /*! \brief rhs operand */ + const RhsExp &rhs_; + /*! \brief internal production size*/ + index_t prod_size_; + /*! \brief the shape of this expression */ + Shape<2> shape_; + /*! \brief constructor */ + ImplicitGEMMExp(const LhsExp &lhs, const RhsExp &rhs) + : lhs_(lhs), rhs_(rhs) { + Shape<2> slhs = ShapeCheck<2, LhsExp>::Check(lhs_); + Shape<2> srhs = ShapeCheck<2, RhsExp>::Check(rhs_); + this->shape_ = mshadow::Shape2(slhs[0], srhs[1]); + prod_size_ = slhs[1]; + } +}; + + +template +inline ImplicitGEMMExp +implicit_dot(const Exp &lhs, + const Exp &rhs) { + TypeCheckPass::kDim == 2 && ExpInfo::kDim == 2> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return ImplicitGEMMExp(lhs.self(), rhs.self()); +} + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const ImplicitGEMMExp &e) + : lhs_(MakePlan(e.lhs_)), + rhs_(MakePlan(e.rhs_)), + prod_size_(e.prod_size_), + prod_size_lower_align_(packet::LowerAlign(e.prod_size_)) { + } + + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + typedef packet::Packet Packet; + Packet sum = Packet::Fill(0); + + const size_t packetSize = Packet::size; + DType lhs_temp[packetSize], rhs_temp[packetSize]; + + for (index_t i = 0; i < prod_size_lower_align_; i += packetSize) { + // unroll + for (index_t j = 0; j < packetSize; ++j) { + lhs_temp[j] = lhs_.Eval(y, i + j); + } + for (index_t j = 0; j < packetSize; ++j) { + rhs_temp[j] = rhs_.Eval(i + j, x); + } + sum = sum + Packet::LoadUnAligned(lhs_temp) * Packet::LoadUnAligned(rhs_temp); + } + DType ret_result = sum.Sum(); + + for (index_t i = prod_size_lower_align_; i < prod_size_; ++i) { + ret_result += lhs_.Eval(y, i) * rhs_.Eval(i, x); + } + return ret_result; + } + + private: + expr::Plan lhs_; + expr::Plan rhs_; + const index_t prod_size_; + const index_t prod_size_lower_align_; +}; + +template +inline Plan, DType> +MakePlan(const ImplicitGEMMExp &exp) { + return Plan, DType>(exp); +} + + +template +struct ShapeCheck > { + inline static Shape + Check(const ImplicitGEMMExp &t) { + CHECK(dim == 2) + << "ImplicitGEMMExp only support 2 dimension"; + Shape shape1 = ShapeCheck::Check(t.lhs_); + Shape shape2 = ShapeCheck::Check(t.rhs_); + CHECK_EQ(shape1[1], shape2[0]) + << "implicit_dot The matrix shape do not match"; + return t.shape_; + } +}; + +template +struct ExpInfo > { + static const int kDim = 2; + static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; +}; + +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ + diff --git a/include/mshadow/extension/mask.h b/include/mshadow/extension/mask.h new file mode 100644 index 000000000000..0fd4cc6db72e --- /dev/null +++ b/include/mshadow/extension/mask.h @@ -0,0 +1,97 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file mask.h + * \brief + * \author Bing Xu +*/ +#ifndef MSHADOW_EXTENSION_MASK_H_ +#define MSHADOW_EXTENSION_MASK_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { + +/*! \brief Broadcast a mask and do element-wise multiplication + * \tparam IndexExp type of index expression + * \tparam SrcExp type of src expression + * \tparam DType data type + */ +template +struct MaskExp: public Exp, + DType, type::kChainer> { + /*! \brief index oprand */ + const IndexExp &index_; + /*! \brief matrix oprand */ + const SrcExp &src_; + /*! constructor */ + MaskExp(const IndexExp &index, const SrcExp &src) + : index_(index), src_(src) {} +}; // struct MaskExp + + + +template +inline MaskExp +mask(const Exp &index, + const Exp &src) { + return MaskExp(index.self(), src.self()); +} + + +//---------------------- +// Execution plan +//---------------------- + +template +struct Plan, DType> { + public: + explicit Plan(const MaskExp &e) + : index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) { + } + + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return static_cast(src_.Eval(y, x) * index_.Eval(0, y)); + } + + private: + expr::Plan index_; + expr::Plan src_; +}; // struct Plan + +template +inline Plan, DType> +MakePlan(const MaskExp &exp) { + return Plan, DType>(exp); +} + +template +struct ShapeCheck > { + inline static Shape + Check(const MaskExp &t) { + CHECK(dim == 2) + << "MaskExp only support 2D output"; + Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_); + Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_); + CHECK_EQ(dshape[0], wshape[0]) << "MaskExp require inputs in same first dimention"; + Shape ret; + ret[0] = wshape[0]; + ret[1] = wshape[1]; + return ret; + } +}; + + +template +struct ExpInfo > { + static const int kDim = 2; + static const int kDevMask = ExpInfo::kDevMask; +}; + +} // namespace expr +} // namespace mshadow + +#endif // MSHADOW_EXTENSION_MASK_H_ diff --git a/include/mshadow/extension/mirror.h b/include/mshadow/extension/mirror.h new file mode 100644 index 000000000000..9e9edc9b6f70 --- /dev/null +++ b/include/mshadow/extension/mirror.h @@ -0,0 +1,62 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file mirror.h + * \brief support for mirror + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_MIRROR_H_ +#define MSHADOW_EXTENSION_MIRROR_H_ +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief mirror expression, mirror a image in width + * \tparam SrcExp source expression to be mirrored + * \tparam DType the type of elements + * \tparam srcdim dimension of src + */ +template +struct MirroringExp: + public MakeTensorExp, + SrcExp, srcdim, DType> { + /*! \brief source operand */ + const SrcExp &src_; + /*! \brief constructor */ + explicit MirroringExp(const SrcExp &src) : src_(src) { + this->shape_ = ShapeCheck::Check(src_); + } +}; +/*! + * \brief mirroring expression, mirror images in width + * \param src original image batches + * \return expression corresponding to mirrored result + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype type of expression + */ +template +inline MirroringExp::kDim> +mirror(const Exp &src) { + TypeCheckPass::kDim >= 2> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return MirroringExp::kDim>(src.self()); +} +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const MirroringExp &e) + : src_(MakePlan(e.src_)), width_(e.shape_[srcdim - 1]) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + return src_.Eval(i, width_ - j - 1); + } + + private: + Plan src_; + const index_t width_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_MIRROR_H_ diff --git a/include/mshadow/extension/one_hot.h b/include/mshadow/extension/one_hot.h new file mode 100644 index 000000000000..326d4c3560eb --- /dev/null +++ b/include/mshadow/extension/one_hot.h @@ -0,0 +1,87 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file one_hot.h + * \brief Create one-hot indicator array based on the index. + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_ONE_HOT_H_ +#define MSHADOW_EXTENSION_ONE_HOT_H_ + +#include "../extension.h" + + +namespace mshadow { +namespace expr { +/*! + * \brief Create a one-hot indicator array. + * \tparam IndexExp type of index expression + * \tparam DType the type of elements + */ +template +struct OneHotEncodeExp: + public Exp, + DType, type::kChainer> { + /*! \brief index operand */ + const IndexExp &index_; + /*! \brief number of choices we can have. */ + index_t num_choices_; + /*! \brief constructor */ + OneHotEncodeExp(const IndexExp &index, index_t num_choices) + : index_(index), num_choices_(num_choices) {} +}; + +template +inline OneHotEncodeExp +one_hot_encode(const Exp &index, index_t num_choices) { + TypeCheckPass::kDim == 1> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return OneHotEncodeExp(index.self(), num_choices); +} + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const OneHotEncodeExp &e) + : index_(MakePlan(e.index_)) { + } + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + index_t idx = static_cast(index_.Eval(0, y)); + return static_cast(x == idx); + } + + private: + expr::Plan index_; +}; + +template +inline Plan, DType> +MakePlan(const OneHotEncodeExp &exp) { + return Plan, DType>(exp); +} + +template +struct ShapeCheck > { + inline static Shape + Check(const OneHotEncodeExp &t) { + CHECK(dim == 2) + << "OneHotEncodeExp only support 2 dimension output"; + Shape<1> shape = ShapeCheck<1, IndexExp>::Check(t.index_); + Shape ret; + ret[0] = shape[0]; + ret[1] = t.num_choices_; + return ret; + } +}; + +template +struct ExpInfo > { + static const int kDim = 2; + static const int kDevMask = ExpInfo::kDevMask; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_ONE_HOT_H_ diff --git a/include/mshadow/extension/pack_col2patch.h b/include/mshadow/extension/pack_col2patch.h new file mode 100644 index 000000000000..37f1a699ead5 --- /dev/null +++ b/include/mshadow/extension/pack_col2patch.h @@ -0,0 +1,154 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file pack_col2patch.h + * \brief support for pack + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_PACK_COL2PATCH_H_ +#define MSHADOW_EXTENSION_PACK_COL2PATCH_H_ +#include +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief reverse operation of UnpackPatchToCol, + * used to backprop gradient back + * this is a version supporting multiple images + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam dstdim destination dimension + */ +template +struct PackColToPatchXExp: + public MakeTensorExp, + SrcExp, dstdim, DType> { + /*! \brief source operand */ + const SrcExp &src_; + /*! \brief patch height */ + index_t psize_y_; + /*! \brief patch height */ + index_t psize_x_; + /*! \brief patch stride */ + index_t pstride_y_; + index_t pstride_x_; + /*! \brief patch dilate */ + index_t pdilate_y_; + index_t pdilate_x_; + /*! \brief constructor */ + PackColToPatchXExp(const SrcExp &src, Shape imshape, + index_t psize_y, index_t psize_x, + index_t pstride_y, index_t pstride_x, + index_t pdilate_y, index_t pdilate_x) + :src_(src), psize_y_(psize_y), psize_x_(psize_x), + pstride_y_(pstride_y), pstride_x_(pstride_x), + pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){ + this->shape_ = imshape; + const index_t o_height = (imshape[dstdim - 2] - + (pdilate_y * (psize_y - 1)+ 1))/pstride_y + 1; + const index_t o_width = (imshape[dstdim - 1] - + (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1; + Shape<2> sshape = ShapeCheck<2, SrcExp>::Check(src_); + CHECK_EQ(sshape[1], o_height * o_width * imshape.ProdShape(0, dstdim - 3)) + << "PackColToPatchExp: src.size(1) mismatch"; + CHECK_EQ(sshape[0], psize_y * psize_x * imshape[dstdim - 3]) + << "PackColToPatchExp: src.size(0) mismatch"; + } +}; +/*! + * \brief reverse operation of pack_col2patch, can be used to implement deconvolution + * \return packed img expression + * \param mat source matrix + * \param imshape shape of target img + * \param psize_y height of each patch + * \param psize_x height of each patch + * \param pstride stride of each patch + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam dstdim destination dimension + * \tparam etype type of expression + */ +template +inline PackColToPatchXExp +pack_col2patch(const expr::Exp &src, + Shape imshape, index_t psize_y, + index_t psize_x, index_t pstride, index_t pdilate) { + TypeCheckPass::kDim == 2> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y) + << "PackColToPatch:image shape smaller than patch size"; + return PackColToPatchXExp(src.self(), imshape, + psize_y, psize_x, pstride, pstride, + pdilate, pdilate); +} +/*! + *if you want to specify kstride_y and kstride_x + */ +template +inline PackColToPatchXExp +pack_col2patch(const expr::Exp &src, + Shape imshape, index_t psize_y, + index_t psize_x, index_t pstride_y, index_t pstride_x, + index_t pdilate_y, index_t pdilate_x) { + TypeCheckPass::kDim == 2> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y) + << "PackColToPatch:image shape smaller than patch size"; + return PackColToPatchXExp(src.self(), imshape, + psize_y, psize_x, pstride_y, pstride_x, + pdilate_y, pdilate_x); +} + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const PackColToPatchXExp &e) + :src_(MakePlan(e.src_)), psize_y_(e.psize_y_), + psize_x_(e.psize_x_), pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_), + i_channel_(e.shape_[dstdim - 3]), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_), + i_height_(e.shape_[dstdim - 2]), + o_height_((e.shape_[dstdim - 2] - (pdilate_y_ * (psize_y_ - 1) + 1)) / + pstride_y_ + 1), + o_width_((e.shape_[dstdim - 1] - (pdilate_x_ * (psize_x_ - 1) + 1)) / + pstride_x_ + 1) { + // note: i/o convention are same as unpack + } + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + using namespace std; + const index_t y = i % i_height_; + const index_t idivh = i / i_height_; + const index_t c = idivh % i_channel_; + const index_t n = idivh / i_channel_; + const index_t x = j; + + const index_t psize_y_dilate = (pdilate_y_ * (psize_y_ - 1) + 1); + const index_t psize_x_dilate = (pdilate_x_ * (psize_x_ - 1) + 1); + + const index_t py_min = + y < psize_y_dilate ? y % pdilate_y_ : (y-psize_y_dilate + pstride_y_) / pstride_y_; + const index_t px_min = + x < psize_x_dilate ? x % pdilate_x_ : (x-psize_x_dilate + pstride_x_) / pstride_x_; + const index_t py_max = min((y + pstride_y_) / pstride_y_, o_height_); + const index_t px_max = min((x + pstride_x_) / pstride_x_, o_width_); + DType res = static_cast(0); + for (index_t py = py_min; py < py_max; py += pdilate_y_) { + for (index_t px = px_min; px < px_max; px += pdilate_x_) { + res += src_.Eval(((c * psize_y_ + (y - py*pstride_y_) / pdilate_y_) * psize_x_ + + (x - px * pstride_x_) / pdilate_x_), + (n * o_height_ + py) * o_width_ + px); + } + } + return res; + } + + private: + Plan src_; + const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_; + const index_t pdilate_y_, pdilate_x_; + const index_t i_height_, o_height_, o_width_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_PACK_COL2PATCH_H_ diff --git a/include/mshadow/extension/pad.h b/include/mshadow/extension/pad.h new file mode 100644 index 000000000000..6622a022acc8 --- /dev/null +++ b/include/mshadow/extension/pad.h @@ -0,0 +1,111 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file pad.h + * \brief support for pad + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_PAD_H_ +#define MSHADOW_EXTENSION_PAD_H_ +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief padding expression, pad a image with zeros + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam srcdim dimension of src + */ +template +struct PaddingExp: + public MakeTensorExp, + SrcExp, srcdim, DType> { + /*! \brief source operand */ + const SrcExp &src_; + /*! \brief pad size in y */ + index_t pad_y_; + /*! \brief pad size in x */ + index_t pad_x_; + /*! \brief source tensor height */ + index_t src_height_; + /*! \brief source tensor width */ + index_t src_width_; + /*! \brief constructor */ + PaddingExp(const SrcExp &src, index_t pad_y, index_t pad_x) + : src_(src), pad_y_(pad_y), pad_x_(pad_x) { + this->shape_ = ShapeCheck::Check(src_); + src_height_ = this->shape_[srcdim - 2]; + src_width_ = this->shape_[srcdim - 1]; + this->shape_[srcdim - 2] += pad_y * 2; // height + this->shape_[srcdim - 1] += pad_x * 2; // width + } +}; +/*! + * \brief padding expression, pad a image with zeros on boundaries, padding affects shape[0], and shape[1] + * \param src original image batches + * \param pad padding size + * \return expression corresponding to padded result + * \tparam SrcExp source expression + * \tparam DType the content data type + * \tparam etype type of expression + */ +template +inline PaddingExp::kDim> +pad(const Exp &src, index_t pad) { + TypeCheckPass::kDim >= 2> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return PaddingExp::kDim>(src.self(), pad, pad); +} +/*! + * \brief padding expression, pad a image with zeros on boundaries, padding affects shape[0], and shape[1] + * \param src original image batches + * \param pad_y padding size in y + * \param pad_x padding size in x + * \return expression corresponding to padded result + * \tparam SrcExp source expression + * \tparam DType the content data type + * \tparam etype type of expression + */ +template +inline PaddingExp::kDim> +pad(const Exp &src, index_t pad_y, index_t pad_x) { + TypeCheckPass::kDim >= 2> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return PaddingExp::kDim> + (src.self(), pad_y, pad_x); +} +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const PaddingExp &e) + : src_(MakePlan(e.src_)), + pad_y_(e.pad_y_), pad_x_(e.pad_x_), + new_height_(e.shape_[srcdim - 2]), + src_height_(e.src_height_), src_width_(e.src_width_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + const index_t x = j; + const index_t y = i % new_height_; + const index_t c = i / new_height_; + if (y < pad_y_ || x < pad_x_) return static_cast(0); + const index_t h = y - pad_y_; + const index_t w = x - pad_x_; + if (h < src_height_ && w < src_width_) { + return src_.Eval(c * src_height_ + h, w); + } else { + return static_cast(0); + } + } + + private: + Plan src_; + const index_t pad_y_; + const index_t pad_x_; + const index_t new_height_; + const index_t src_height_; + const index_t src_width_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_PAD_H_ diff --git a/include/mshadow/extension/range.h b/include/mshadow/extension/range.h new file mode 100644 index 000000000000..ab49b6e3cf18 --- /dev/null +++ b/include/mshadow/extension/range.h @@ -0,0 +1,118 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file range.h + * \brief support generating a range vector + * \author Xingjian Shi + */ +#ifndef MSHADOW_EXTENSION_RANGE_H_ +#define MSHADOW_EXTENSION_RANGE_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { +/*! + * \brief Generate a range vector similar to python: range(start, stop[, step][, repeat]). + If step is positive, the last element is the largest start + i * step less than stop + If step is negative, the last element is the smallest start + i * step greater than stop. + All elements are repeated for `repeat` times, e.g range(0, 4, 2, 3) --> 0, 0, 0, 2, 2, 2 + * \tparam SrcExp type of lhs expression + * \tparam IndexExp type of index expression + * \tparam DType the type of elements + */ +template +struct RangeExp: + public Exp, DType, type::kMapper> { + const DType start_; + const DType stop_; + const DType step_; + const int repeat_; + /*! \brief constructor */ + RangeExp(DType start, DType stop, DType step, int repeat) + : start_(start), stop_(stop), step_(step), repeat_(repeat) {} +}; + +template +inline RangeExp +range(DType start, DType stop, DType step = 1, int repeat = 1) { + return RangeExp(start, stop, step, repeat); +} + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const RangeExp &e) + : start_(e.start_), + stop_(e.stop_), + step_(e.step_), + repeat_(e.repeat_) { + } + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return start_ + static_cast((static_cast(x) / repeat_)) * step_; + } + + private: + const DType start_; + const DType stop_; + const DType step_; + const int repeat_; +}; + +template +inline Plan, DType> +MakePlan(const RangeExp &exp) { + return Plan, DType>(exp); +} + + +template +inline int RangeOutSize(DType start, DType stop, DType step, int repeat) { + return repeat * ((stop - start - 1) / step + 1); +} + +template<> +inline int RangeOutSize(float start, float stop, float step, int repeat) { + double d_start = static_cast(start); + double d_stop = static_cast(stop); + double d_step = static_cast(step); + return repeat * static_cast(ceil((d_stop - d_start) / d_step)); +} + +template<> +inline int RangeOutSize(double start, double stop, double step, int repeat) { + return repeat * static_cast(ceil((stop - start) / step)); +} + + +template +struct ShapeCheck > { + inline static Shape + Check(const RangeExp &t) { + CHECK(dim == 1) + << "RangeExp only support 1 dimension output, received " << dim; + CHECK(t.step_ != 0) + << "RangeExp does not support step=0, received " << t.step_; + CHECK(t.repeat_ > 0) + << "RangeExp only supports repeat > 0, received " << t.repeat_; + if (t.step_ > 0) { + CHECK(t.start_ < t.stop_) << "RangeExp does not support (start, stop, step) = " + << "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")"; + } else { + CHECK(t.start_ > t.stop_) << "RangeExp does not support (start, stop, step)= " + << "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")"; + } + return Shape1(RangeOutSize(t.start_, t.stop_, t.step_, t.repeat_)); + } +}; + +template +struct ExpInfo > { + static const int kDim = 1; + static const int kDevMask = 0xffff; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_RANGE_H_ diff --git a/include/mshadow/extension/reduce_with_axis.h b/include/mshadow/extension/reduce_with_axis.h new file mode 100644 index 000000000000..54bcc750cfc5 --- /dev/null +++ b/include/mshadow/extension/reduce_with_axis.h @@ -0,0 +1,136 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file reduce_with_axis.h + * \brief + * \author Junyuan Xie +*/ +#ifndef MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ +#define MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { + +/*! \brief reduce out the dimension of src labeled by axis. + * \tparam Reducer type of reducer + * \tparam SrcExp type of source expression + * \tparam DType data type + */ +template +struct ReduceWithAxisExp: + public MakeTensorExp, + SrcExp, dimdst, DType> { + /*! \brief source oprand */ + const SrcExp &src_; + /*! \brief size of last destination dimension */ + index_t last_dst_dim_; + /*! \brief size of trailing dimensions */ + index_t trailing_; + /*! \brief size of axis dimension */ + index_t size_; + /*! \brief size of last src dimension */ + index_t last_; + /*! constructor */ + explicit ReduceWithAxisExp(const SrcExp &src, int axis) + : src_(src) { + bool keepdim = (dimsrc == dimdst); + CHECK(dimsrc > axis) << "reduce axis out of bound"; + Shape src_shape = ShapeCheck::Check(src_); + for (int i = 0; i < axis; ++i) { + this->shape_[i] = src_shape[i]; + } + this->size_ = src_shape[axis]; + this->trailing_ = 1; + if (!keepdim) { + for (int i = axis + 1; i < dimsrc; ++i) { + this->trailing_ *= src_shape[i]; + this->shape_[i - 1] = src_shape[i]; + } + } else { + this->shape_[axis] = 1; + for (index_t i = axis + 1; i < dimsrc; ++i) { + this->trailing_ *= src_shape[i]; + this->shape_[i] = src_shape[i]; + } + } + + this->last_ = src_shape[dimsrc - 1]; + this->last_dst_dim_ = this->shape_[dimdst - 1]; + } +}; // struct ReduceWithAxisExp + +/*! + * \brief reduce out the dimension of src labeled by axis. + * \param Reducer type of the reducing operation + * \param mask whether to output the unmask indices + * \tparam SrcExp source expression + * \tparam DType data type + * \tparam etype type of the expression + */ +template +inline ReduceWithAxisExp::kDim, mask, + ExpInfo::kDim - 1> +reduce_with_axis(const Exp &src, int axis) { + return ReduceWithAxisExp::kDim, mask, + ExpInfo::kDim- 1>(src.self(), axis); +} + +/*! +* \brief reduce out the dimension of src labeled by axis, keepdim turned on. +* \param Reducer type of the reducing operation +* \param mask whether to output the unmask indices +* \tparam SrcExp source expression +* \tparam DType data type +* \tparam etype type of the expression +*/ +template +inline ReduceWithAxisExp::kDim, mask, + ExpInfo::kDim> + reduce_keepdim(const Exp &src, int axis) { + return ReduceWithAxisExp::kDim, mask, + ExpInfo::kDim>(src.self(), axis); +} + +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const ReduceWithAxisExp &e) + : src_(MakePlan(e.src_)), last_dst_dim_(e.last_dst_dim_), trailing_(e.trailing_), + size_(e.size_), last_(e.last_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + index_t x = (i*last_dst_dim_ + j)/trailing_; + index_t y = (i*last_dst_dim_ + j)%trailing_; + + if (mask) { + index_t idx = 0; + DType res; Reducer::SetInitValue(res); + for (index_t k = 0; k < size_; ++k) { + index_t z = (x*size_+k)*trailing_+y; + DType tmp = res; + Reducer::Reduce(res, src_.Eval(z/last_, z%last_)); + if (tmp != res) { + idx = k; + } + } + return static_cast(static_cast(idx)); + } else { + DType res; Reducer::SetInitValue(res); + for (index_t k = 0; k < size_; ++k) { + index_t z = (x*size_+k)*trailing_+y; + Reducer::Reduce(res, src_.Eval(z/last_, z%last_)); + } + return res; + } + } + + private: + Plan src_; + const index_t last_dst_dim_, trailing_, size_, last_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ diff --git a/include/mshadow/extension/reduceto1d.h b/include/mshadow/extension/reduceto1d.h new file mode 100644 index 000000000000..09a478ab311e --- /dev/null +++ b/include/mshadow/extension/reduceto1d.h @@ -0,0 +1,104 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file reduceto1d.h + * \brief support for sum_rows and sumall_except_dim + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_REDUCETO1D_H_ +#define MSHADOW_EXTENSION_REDUCETO1D_H_ +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief reduction to 1 dimension tensor + * input: Tensor: ishape + * output: Tensor shape[0] = ishape[dimkeep]; + * + * \tparam SrcExp type of expression to be reduced + * \tparam DType the data type of the scalar + * \tparam Reducer which reducer to use + * \tparam m_dimkeep which dimension to be kept, encoded with dimsrc - dimkeep + */ +template +struct ReduceTo1DExp: + public Exp, + DType, type::kComplex> { + /*! \brief source operand */ + const SrcExp &src_; + /*! \brief source operand, scale of the */ + DType scale_; + /*! \brief construct a repmat expression from src and nrow */ + ReduceTo1DExp(const SrcExp& src, DType scale) : src_(src), scale_(scale) {} +}; +/*! + * \brief a sum over all dimensions, except dimkeep + * \param exp input expression that must be a matrix Tensor + * \return a expresion with type Tensor + * \tparam dimkeep the dimension that will be kept + * \tparam SrcExp expression + * \tparam etype type of expression + */ +template +inline ReduceTo1DExp::kDim - dimkeep> +sumall_except_dim(const Exp &exp) { + return ReduceTo1DExp::kDim - dimkeep>(exp.self(), DType(1)); +} +/*! + * \brief reduce over all dimensions, except dimkeep + * \param exp input expression that must be a matrix Tensor + * \return a expresion with type Tensor + * \tparam dimkeep the dimension that will be kept + * \tparam SrcExp expression + * \tparam etype type of expression + */ +template +inline ReduceTo1DExp::kDim - dimkeep> +reduce_except_dim(const Exp &exp) { + return ReduceTo1DExp::kDim - dimkeep>(exp.self(), DType(1)); +} +/*! + * \brief a expression that sum over rows of a matrix + * \param exp input expression that must be a matrix Tensor + * \return a expresion with type Tensor + * \tparam SrcExp expression + * \tparam etype type of expression + */ +template +inline ReduceTo1DExp +sum_rows(const Exp &exp) { + TypeCheckPass::kDim ==2> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return sumall_except_dim<1>(exp); +} +template +struct ExpComplexEngine, + ReduceTo1DExp, + DType> { + static const int dimkeep = ExpInfo::kDim - m_dimkeep; + inline static void Eval(Tensor *dst, + const ReduceTo1DExp &exp) { + TypeCheckPass + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + MapReduceKeepHighDim(dst, exp.src_, exp.scale_); + } +}; +template +struct ExpComplexEngine, + ReduceTo1DExp, DType> { + inline static void Eval(Tensor *dst, + const ReduceTo1DExp &exp) { + MapReduceKeepLowest(dst, exp.src_, exp.scale_); + } +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_REDUCETO1D_H_ diff --git a/include/mshadow/extension/reshape.h b/include/mshadow/extension/reshape.h new file mode 100644 index 000000000000..b310fe69291a --- /dev/null +++ b/include/mshadow/extension/reshape.h @@ -0,0 +1,87 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file reshape.h + * \brief support for reshape + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_RESHAPE_H_ +#define MSHADOW_EXTENSION_RESHAPE_H_ +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief reshape the content to another shape + * input: Tensor: ishape + * output: Tensor ishape.Size() == oshape.Size() + * \tparam SrcExp source expression + * \tparam dimdst target dimension + * \tparam dimsrc source dimension + */ +template +struct ReshapeExp: + public MakeTensorExp, + SrcExp, dimdst, DType> { + /*! \brief source expression */ + const SrcExp &src_; + /*! \brief smallest dimension of input */ + index_t ishapex_; + /*! \brief constructor */ + ReshapeExp(const SrcExp &src, Shape shape) + : src_(src) { + Shape ishape = ShapeCheck::Check(src_); + CHECK_EQ(ishape.Size(), shape.Size()) << "reshape size must match"; + ishapex_ = ishape[dimsrc - 1]; + this->shape_ = shape; + } +}; +/*! + * \brief a expression that reshapes a tensor to another shape + * \param src Tensor: + * \param oshape target shape + * \return a expresion with type Tensor + * \tparam SrcExp source expression + * \tparam etype source expression type + * \tparam dimdst target dimension + */ +template +inline ReshapeExp::kDim> +reshape(const Exp &src, Shape oshape) { + return ReshapeExp::kDim> + (src.self(), oshape); +} +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const ReshapeExp &e) + : src_(MakePlan(e.src_)), + oshapex_(e.shape_[dimdst - 1]), ishapex_(e.ishapex_) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + const index_t idx = y * oshapex_ + x; + return src_.Eval(idx / ishapex_, idx % ishapex_); + } + + private: + Plan src_; + const index_t oshapex_, ishapex_; +}; +// special work plan for 1 dimensional data +template +struct Plan, DType> { + public: + explicit Plan(const ReshapeExp &e) + : src_(MakePlan(e.src_)), oshapex_(e.shape_[dimdst - 1]) { + } + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return src_.Eval(0, y * oshapex_ + x); + } + + private: + Plan src_; + const index_t oshapex_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_RESHAPE_H_ diff --git a/include/mshadow/extension/slice.h b/include/mshadow/extension/slice.h new file mode 100644 index 000000000000..cb2eff4548aa --- /dev/null +++ b/include/mshadow/extension/slice.h @@ -0,0 +1,156 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file slice.h + * \brief support for slice a certain dimension. + */ +#ifndef MSHADOW_EXTENSION_SLICE_H_ +#define MSHADOW_EXTENSION_SLICE_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { +/*! + * \brief slice expression, slice a tensor's channel + * \tparam SrcExp left expression + * \tparam DType the type of elements + * \tparam srcdim dimension of src + * \tparam dimsrc_m_cat dimsrc - dimcat + */ +template +struct SliceExp : public TRValue, + Device, srcdim, DType> { + static const int dimslice = srcdim - dimsrc_m_slice; + const SrcExp &src_; + index_t ch_begin_; + index_t ch_old_; + Shape shape_; + SliceExp(const SrcExp &src, index_t begin, index_t end) + : src_(src), ch_begin_(begin) { + shape_ = ShapeCheck::Check(src_); + ch_old_ = shape_[dimslice]; + CHECK(begin < shape_[dimslice] && end <= shape_[dimslice]) + << "The slice went out of range"; + shape_[dimslice] = end - begin; + } + template + inline void + operator=(const expr::Exp &exp) { + this->__assign(exp); + } + inline void + operator=(const DType &exp) { + this->__assign(exp); + } +}; // struct Slice + +/*! + * \brief Slice a Tensor + * \param src source tensor + * \param begin The beginning slice. + * \param end The end slice. + * \return sliced tensor + * \tparam sdim the dimension to slice on + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype type of expression + */ +template +inline SliceExp +slice(const TRValue &src, index_t begin, index_t end) { + TypeCheckPass::kDim == srcdim> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return SliceExp(src.self(), begin, end); +} +//------------------------ +// engine plugin +//------------------------ +// runtime shapecheck +template +struct ShapeCheck >{ + inline static Shape Check(const SliceExp &t) { + return t.shape_; + } +}; +template +struct StreamInfo >{ + inline static Stream * + Get(const SliceExp &t) { + return StreamInfo::Get(t.src_); + } +}; +// static typecheck +template +struct ExpInfo >{ + static const int kDim = ExpInfo::kDim; + static const int kDevMask = ExpInfo::kDevMask; +}; +//---------------------- +// Execution plan +//--------------------- +template +struct Plan, DType> { + public: + static const int dimslice = srcdim - dimsrc_m_slice; + explicit Plan(const SliceExp &e) + : src_(MakePlan(e.src_)), + height_(e.shape_.ProdShape(dimslice + 1, srcdim - 1)), + ch_begin_(e.ch_begin_), ch_old_(e.ch_old_), ch_(e.shape_[dimslice]) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + const index_t y = i % height_; + i /= height_; + const index_t c = i % ch_ + ch_begin_; + const index_t b = i / ch_; + const index_t x = j; + return src_.Eval((b * ch_old_ + c) * height_ + y, x); + } + MSHADOW_XINLINE DType &REval(index_t i, index_t j) { + const index_t y = i % height_; + i /= height_; + const index_t c = i % ch_ + ch_begin_; + const index_t b = i / ch_; + const index_t x = j; + return src_.REval((b * ch_old_ + c) * height_ + y, x); + } + + private: + Plan src_; + const index_t height_, ch_begin_, ch_old_, ch_; +}; // struct Plan + +template +struct Plan, DType> { + public: + explicit Plan(const SliceExp &e) + : src_(MakePlan(e.src_)), + ch_begin_(e.ch_begin_) {} + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + return src_.Eval(y, x + ch_begin_); + } + MSHADOW_XINLINE DType &REval(index_t y, index_t x) { + return src_.REval(y, x + ch_begin_); + } + + private: + Plan src_; + const index_t ch_begin_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_SLICE_H_ diff --git a/include/mshadow/extension/slice_ex.h b/include/mshadow/extension/slice_ex.h new file mode 100644 index 000000000000..7f464097fb3b --- /dev/null +++ b/include/mshadow/extension/slice_ex.h @@ -0,0 +1,135 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file slice.h + * \brief support for slice a certain dimension. + */ +#ifndef MSHADOW_EXTENSION_SLICE_EX_H_ +#define MSHADOW_EXTENSION_SLICE_EX_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { +/*! + * \brief slice expression, slice a tensor's channel + * \tparam SrcExp left expression + * \tparam DType the type of elements + * \tparam srcdim dimension of src + * \tparam dimsrc_m_cat dimsrc - dimcat + */ +template +struct SliceExExp : public TRValue, + Device, srcdim, DType> { + const SrcExp &src_; + Shape src_shape_; + Shape shape_; + const Shape begin_; + const Shape end_; + SliceExExp(const SrcExp &src, Shape begin, Shape end) + : src_(src), begin_(begin), end_(end) { + src_shape_ = ShapeCheck::Check(src_); + for (int i = 0; i < srcdim; ++i) { + shape_[i] = end_[i] - begin_[i]; + } + } + template + inline void + operator=(const expr::Exp &exp) { + this->__assign(exp); + } + inline void + operator=(const DType &exp) { + this->__assign(exp); + } +}; // struct SliceEx + +/*! + * \brief SliceEx a Tensor + * \param src source tensor + * \param begin The beginning slice. + * \param end The end slice. + * \return sliced tensor + * \tparam sdim the dimension to slice on + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype type of expression + */ +template +inline SliceExExp +slice(const TRValue &src, Shape begin, Shape end) { + TypeCheckPass::kDim == srcdim> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return SliceExExp(src.self(), begin, end); +} +//------------------------ +// engine plugin +//------------------------ +// runtime shapecheck +template +struct ShapeCheck >{ + inline static Shape Check(const SliceExExp &t) { + return t.shape_; + } +}; + +template +struct StreamInfo >{ + inline static Stream * + Get(const SliceExExp &t) { + return StreamInfo::Get(t.src_); + } +}; +// static typecheck +template +struct ExpInfo >{ + static const int kDim = ExpInfo::kDim; + static const int kDevMask = ExpInfo::kDevMask; +}; +//---------------------- +// Execution plan +//--------------------- +template +struct Plan, DType> { + public: + explicit Plan(const SliceExExp &e) + : src_(MakePlan(e.src_)), begin_(e.begin_), + src_shape_(e.src_shape_), shape_(e.shape_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + index_t idx = 0; + index_t stride = 1; + #pragma unroll + for (int k = srcdim-2; k >= 0; --k) { + idx += stride * (i%shape_[k] + begin_[k]); + i /= shape_[k]; + stride *= src_shape_[k]; + } + return src_.Eval(idx, j + begin_[srcdim-1]); + } + MSHADOW_XINLINE DType &REval(index_t i, index_t j) { + index_t idx = 0; + index_t stride = 1; + #pragma unroll + for (int k = srcdim-2; k >= 0; --k) { + idx += stride * (i%shape_[k] + begin_[k]); + i /= shape_[k]; + stride *= src_shape_[k]; + } + return src_.REval(idx, j + begin_[srcdim-1]); + } + + private: + Plan src_; + const Shape begin_, src_shape_, shape_; +}; // struct Plan +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_SLICE_EX_H_ diff --git a/include/mshadow/extension/spatial_pool.h b/include/mshadow/extension/spatial_pool.h new file mode 100644 index 000000000000..c833fb40ad58 --- /dev/null +++ b/include/mshadow/extension/spatial_pool.h @@ -0,0 +1,152 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file spatial_pool.h + * \brief support for spatial pooling + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_SPATIAL_POOL_H_ +#define MSHADOW_EXTENSION_SPATIAL_POOL_H_ +#include +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief pooling expression, do reduction over local patches of a image + * \tparam Reducer reduction method during pooling + * \tparam SrcExp source expression to be pooled from + * \tparam DType the content data type + * \tparam srcdim dimension of src + */ +template +struct PoolingExp: + public MakeTensorExp, + SrcExp, srcdim, DType> { + /*! \brief source operand */ + const SrcExp &src_; + /*! \brief kernel size in height */ + index_t ksize_y_; + /*! \brief kernel size in width */ + index_t ksize_x_; + /*! \brief kernel stride in y directory */ + index_t kstride_y_; + /*! \brief kernel stride in x directory */ + index_t kstride_x_; + /*! \brief source height shape[1] */ + index_t src_height_; + /*! \brief source width shape[0] */ + index_t src_width_; + /*! \brief constructor */ + PoolingExp(const SrcExp &src, + index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) + : src_(src), ksize_y_(ksize_y), ksize_x_(ksize_x), + kstride_y_(kstride_y), kstride_x_(kstride_x) { + Shape sshape = ShapeCheck::Check(src_); + CHECK(sshape[srcdim - 1] >= ksize_x && sshape[srcdim - 2] >= ksize_y) + << "PoolingExp: kernel must be smaller than image"; + this->src_height_ = sshape[srcdim - 2]; + this->src_width_ = sshape[srcdim - 1]; + this->shape_ = sshape; + this->shape_[srcdim - 2] = (src_height_ - ksize_y) / kstride_y + 1; + this->shape_[srcdim - 1] = (src_width_ - ksize_x) / kstride_x + 1; + } + /*! \brief constructor, specify shape */ + PoolingExp(const SrcExp &src, Shape<2> pshape, + index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) + : src_(src), ksize_y_(ksize_y), ksize_x_(ksize_x), + kstride_y_(kstride_y), kstride_x_(kstride_x) { + Shape sshape = ShapeCheck::Check(src_); + CHECK(sshape[srcdim - 1] >= ksize_x && sshape[srcdim - 2] >= ksize_y) + << "PoolingExp: kernel must be smaller than image"; + this->src_height_ = sshape[srcdim - 2]; + this->src_width_ = sshape[srcdim - 1]; + this->shape_ = sshape; + this->shape_[srcdim - 2] = pshape[0]; + this->shape_[srcdim - 1] = pshape[1]; + } +}; +/*! + * \brief pooling subregion results together + * \param src source image, shape: (batch, channel, height, width) + * \param ksize_y kernel size in height + * \param ksize_x kernel size in width + * \param kstride_y stride in y directory + * \param kstride_x stride in x directory + * \return expression of pooled result + * \tparam Reducer reducer type + * \tparam SrcExp source expression + * \tparam DType the content data type + * \tparam etype type of expression + */ +template +inline PoolingExp::kDim> +pool(const Exp &src, + index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) { + TypeCheckPass::kDim >= 2> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return PoolingExp::kDim> + (src.self(), ksize_y, ksize_x, kstride_y, kstride_x); +} +/*! + * \brief same as pool, except the output shape is specified by pshape + * \param src source image + * \param pshape ouput shape + * \param ksize_y kernel size in y + * \param ksize_x kernel size in x + * \param kstride_y stride in y directory + * \param kstride_x stride in x directory + * \return expression of pooled result + * \tparam Reducer reducer type + * \tparam SrcExp source expression + * \tparam DType the content data type + * \tparam etype type of expression + */ +template +inline PoolingExp::kDim> +pool(const Exp &src, Shape<2> pshape, + index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) { + TypeCheckPass::kDim >= 2> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return PoolingExp::kDim> + (src.self(), pshape, ksize_y, ksize_x, kstride_y, kstride_x); +} +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const PoolingExp &e) + : src_(MakePlan(e.src_)), + ksize_y_(e.ksize_y_), ksize_x_(e.ksize_x_), + kstride_y_(e.kstride_y_), kstride_x_(e.kstride_x_), + src_height_(e.src_height_), src_width_(e.src_width_), + new_height_(e.shape_[srcdim - 2]) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + using namespace std; + const index_t py = i % new_height_; + const index_t y_start = py * kstride_y_; + const index_t y_end = min(y_start + ksize_y_, src_height_); + const index_t px = j; + const index_t x_start = px * kstride_x_; + const index_t x_end = min(x_start + ksize_x_, src_width_); + const index_t c = i / new_height_; + + DType res; Reducer::SetInitValue(res); + for (index_t y = y_start; y < y_end; ++y) { + for (index_t x = x_start; x < x_end; ++x) { + Reducer::Reduce(res, src_.Eval(c * src_height_ + y, x)); + } + } + return res; + } + + private: + Plan src_; + const index_t ksize_y_, ksize_x_, kstride_y_, kstride_x_; + const index_t src_height_, src_width_; + const index_t new_height_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_SPATIAL_POOL_H_ diff --git a/include/mshadow/extension/spatial_unpool.h b/include/mshadow/extension/spatial_unpool.h new file mode 100644 index 000000000000..e9ca2dfd035b --- /dev/null +++ b/include/mshadow/extension/spatial_unpool.h @@ -0,0 +1,135 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file spatial_unpool.h + * \brief support for unpool + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_ +#define MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_ +#include +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief unpooling expr reverse operation of pooling, used to pass gradient back + * \tparam Reducer reduction method during pooling + * \tparam SrcExp source expression to be pooled from + * \tparam DType the content data type + * \tparam srcdim dimension of src + */ +template +struct UnPoolingExp: + public MakeTensorExp, + SrcExp, srcdim, DType> { + /*! \brief source input, corresponds to src in pooling */ + const SrcExp &data_src_; + /*! \brief result of pooled data, corresponds to result of pooling */ + const SrcExp &data_pooled_; + /*! \brief gradient data of pooled part, to be propgate down */ + const SrcExp &grad_pooled_; + /*! \brief shape of pooled expression */ + index_t pshape_y_; + /*! \brief shape of pooled expression */ + index_t pshape_x_; + /*! \brief kernel size in height */ + index_t ksize_y_; + /*! \brief kernel size in width */ + index_t ksize_x_; + /*! \brief kernel stride in y directory */ + index_t kstride_y_; + /*! \brief kernel stride in x directory */ + index_t kstride_x_; + /*! \brief constructor */ + UnPoolingExp(const SrcExp &data_src, + const SrcExp &data_pooled, + const SrcExp &grad_pooled, + index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) + : data_src_(data_src), data_pooled_(data_pooled), + grad_pooled_(grad_pooled), + ksize_y_(ksize_y), ksize_x_(ksize_x), + kstride_y_(kstride_y), kstride_x_(kstride_x) { + Shape pshape = ShapeCheck::Check(grad_pooled); + typedef ShapeCheck ShapeCheckSrcDimSrcExp; + CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled)) + << "UnPoolingExp: pooled shape mismatch"; + Shape sshape = ShapeCheck::Check(data_src); + for (int k = 0; k < srcdim - 2; ++k) { + CHECK_EQ(pshape[k], sshape[k]) << "UnPoolingExp: pool and src shape mismatch"; + } + pshape_x_ = pshape[srcdim - 1]; + pshape_y_ = pshape[srcdim - 2]; + this->shape_ = sshape; + } +}; +/*! + * \brief unpooling gradient for 4D, backprop gradient value back, revserse operation of pooling, + * same as unpooling, but allows unequal size of kernel + * \param data_src source input, corresponds to src in pooling + * \param data_pooled result of pooled data, corresponds to result of pooling + * \param grad_pooled gradient data of pooled part, to be propgate down + * \param ksize_y kernel height + * \param ksize_x kernel width + * \param kstride_y stride in y directory + * \param kstride_x stride in x directory + * \return expression corresponding to unpooled 4D Tensor, storing backproped gradient + * \tparam Reducer reducer type + * \tparam SrcExp source expression + * \tparam DType the content data type + * \tparam etype type of expression + */ +template +inline UnPoolingExp::kDim> +unpool(const Exp &data_src, + const Exp &data_pooled, + const Exp &grad_pooled, + index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) { + return UnPoolingExp::kDim> + (data_src.self(), data_pooled.self(), grad_pooled.self(), + ksize_y, ksize_x, kstride_y, kstride_x); +} +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const UnPoolingExp &e) + : data_src_(MakePlan(e.data_src_)), data_pooled_(MakePlan(e.data_pooled_)), + grad_pooled_(MakePlan(e.grad_pooled_)), sshape_y_(e.shape_[srcdim - 2]), + pshape_y_(e.pshape_y_), pshape_x_(e.pshape_x_), + ksize_y_(e.ksize_y_), ksize_x_(e.ksize_x_), + kstride_y_(e.kstride_y_), kstride_x_(e.kstride_x_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + using namespace std; + const index_t x = j; + const index_t y = i % sshape_y_; + const index_t c = i / sshape_y_; + const DType vsrc = data_src_.Eval(i, j); + const index_t py_min = + y < ksize_y_ ? 0 : (y - ksize_y_ + kstride_y_) / kstride_y_; + const index_t px_min = + x < ksize_x_ ? 0 : (x - ksize_x_ + kstride_x_) / kstride_x_; + const index_t py_max = min((y + kstride_y_) / kstride_y_, pshape_y_); + const index_t px_max = min((x + kstride_x_) / kstride_x_, pshape_x_); + + DType val = static_cast(0); + for (index_t py = py_min; py < py_max; ++py) { + for (index_t px = px_min; px < px_max; ++px) { + val += Reducer::PartialGrad(vsrc, + data_pooled_.Eval(c * pshape_y_ + py, px)) * + grad_pooled_.Eval(c * pshape_y_ + py, px); + } + } + + return val; + } + + private: + Plan data_src_, data_pooled_, grad_pooled_; + const index_t sshape_y_, pshape_y_, pshape_x_; + const index_t ksize_y_, ksize_x_; + const index_t kstride_y_, kstride_x_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_ diff --git a/include/mshadow/extension/spatial_upsampling_nearest.h b/include/mshadow/extension/spatial_upsampling_nearest.h new file mode 100644 index 000000000000..534fbdd9ebe0 --- /dev/null +++ b/include/mshadow/extension/spatial_upsampling_nearest.h @@ -0,0 +1,71 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file spatial_upsampling.h + * \brief + * \author Bing Xu +*/ +#ifndef MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_ +#define MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_ +#include "../extension.h" + +namespace mshadow { +namespace expr { + +/*! \brief nearest neighboor upsampling + * out(x, y) = in(int(x / scale_x), int(y / scale_y)) + * \tparam SrcExp source expression + * \tparam DType data type + * \tparam srcdim source dimension + */ +template +struct UpSamplingNearestExp : + public MakeTensorExp, + SrcExp, srcdim, DType> { + /*! \brief source oprand */ + const SrcExp &src_; + /*! \brief up sampling scale */ + index_t scale_; + /*! \brief constructor */ + UpSamplingNearestExp(const SrcExp &src, index_t scale) + : src_(src), scale_(scale) { + this->shape_ = ShapeCheck::Check(src_); + this->shape_[srcdim - 2] *= scale_; + this->shape_[srcdim - 1] *= scale_; + } +}; + + +template +inline UpSamplingNearestExp::kDim> +upsampling_nearest(const Exp &src, index_t scale) { + TypeCheckPass::kDim >= 2> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return UpSamplingNearestExp::kDim>(src.self(), scale); +} + +template +struct Plan, DType> { + public: + explicit Plan(const UpSamplingNearestExp &e) + : src_(MakePlan(e.src_)), + scale_(e.scale_), + new_height_(e.shape_[srcdim - 2]), + src_height_(static_cast(e.shape_[srcdim - 2] / e.scale_)) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + const index_t x = j; + const index_t y = i % new_height_; + const index_t c = i / new_height_; + const index_t h = static_cast(y / scale_); + const index_t w = static_cast(x / scale_); + return src_.Eval(c * src_height_ + h, w); + } + + private: + Plan src_; + const index_t scale_; + const index_t new_height_; + const index_t src_height_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_ diff --git a/include/mshadow/extension/swapaxis.h b/include/mshadow/extension/swapaxis.h new file mode 100644 index 000000000000..b79aba441175 --- /dev/null +++ b/include/mshadow/extension/swapaxis.h @@ -0,0 +1,110 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file swapaxis.h + * \brief support for swapaxis + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_SWAPAXIS_H_ +#define MSHADOW_EXTENSION_SWAPAXIS_H_ +#include +#include +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief swap two axis of a tensor + * input: Tensor: ishape + * output: Tensor oshape[a1],oshape[a2] = ishape[a2],oshape[a1] + * + * \tparam SrcExp type of source expression + * \tparam DType the type of elements + * \tparam dimsrc source dimension, assert a1 > a2 + * \tparam m_a1 one dimension to be swapped, encoded by dimsrc - a1 + * \tparam a2 second dimension to be swapped, encoded by a2 + */ +template +struct SwapAxisExp: + public MakeTensorExp, + SrcExp, dimsrc, DType> { + // decode the a1, a2 + static const int a1 = dimsrc - m_a1; + /*! \brief source expression */ + const SrcExp &src_; + /*! \brief constructor */ + explicit SwapAxisExp(const SrcExp &src) : src_(src) { + this->shape_ = ShapeCheck::Check(src); + std::swap(this->shape_[a1], this->shape_[a2]); + } +}; +/*! + * \brief a expression that reshapes a tensor to another shape + * \param src Tensor: + * \return a expresion with type Tensor + * \tparam a1 higher dimension to be swapped, assert a1 > a2 + * \tparam a2 lower dimension to be swapped + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype source expression type + */ +template +inline SwapAxisExp::kDim, + ExpInfo::kDim - a1, a2> +swapaxis(const Exp &src) { + typedef ExpInfo Info; + TypeCheckPass= a1 + 1 && Info::kDim >= a2 + 1 && + a2 < a1>::Error_Expression_Does_Not_Meet_Dimension_Req(); + return SwapAxisExp::kDim, + ExpInfo::kDim - a1, a2>(src.self()); +} +template +struct Plan, DType> { + public: + // decode the a1 + static const int a1 = dimsrc - m_a1; + explicit Plan(const SwapAxisExp &e) + : src_(MakePlan(e.src_)), + shapey_(e.shape_.ProdShape(a1 + 1, dimsrc - 1)), + shapez_(e.shape_[a1]), + shapec_(e.shape_.ProdShape(a2 + 1, a1)), + shapen_(e.shape_[a2]) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + const index_t y = i % shapey_; + i /= shapey_; + const index_t z = i % shapez_; + i /= shapez_; + const index_t c = i % shapec_; + i /= shapec_; + const index_t n = i % shapen_; + // swap z and n + return src_.Eval(((((i / shapen_) * shapez_ + z) * shapec_ + + c) * shapen_ + n) * shapey_ + y, j); + } + + private: + Plan src_; + const index_t shapey_, shapez_, shapec_, shapen_; +}; +template +struct Plan, DType> { + public: + explicit Plan(const SwapAxisExp &e) + : src_(MakePlan(e.src_)), + shapex_(e.shape_[dimsrc - 1]), + shapey_(e.shape_.ProdShape(a2 + 1, dimsrc - 1)), + shapez_(e.shape_[a2]) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t x) const { + // swap x and z + const index_t y = i % shapey_; + i /= shapey_; + const index_t z = i % shapez_; + const index_t n = i / shapez_; + return src_.Eval((n * shapex_ + x) * shapey_ + y , z); + } + + private: + Plan src_; + const index_t shapex_, shapey_, shapez_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_SWAPAXIS_H_ diff --git a/include/mshadow/extension/take.h b/include/mshadow/extension/take.h new file mode 100644 index 000000000000..76c4f4729491 --- /dev/null +++ b/include/mshadow/extension/take.h @@ -0,0 +1,99 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file take.h + * \brief + * \author Bing Xu +*/ +#ifndef MSHADOW_EXTENSION_TAKE_H_ +#define MSHADOW_EXTENSION_TAKE_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { + +/*! \brief Take a column from a matrix + * \tparam IndexExp type of index expression + * \tparam SrcExp type of src expression + * \tparam DType data type + */ +template +struct TakeExp: public Exp, + DType, type::kChainer> { + /*! \brief index oprand */ + const IndexExp &index_; + /*! \brief embediing oprand */ + const SrcExp &src_; + /*! constructor */ + TakeExp(const IndexExp &index, const SrcExp &src) + : index_(index), src_(src) {} +}; // struct TakeExp + + + +template +inline TakeExp +take(const Exp &index, + const Exp &src) { + return TakeExp(index.self(), src.self()); +} + + +//---------------------- +// Execution plan +//---------------------- + +template +struct Plan, DType> { + public: + explicit Plan(const TakeExp &e) + : index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) { + } + + // TODO(xx): discuss W shape: in * out or out * in + // Now I use in * out + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + index_t idx = static_cast(index_.Eval(0, y)); + return static_cast(src_.Eval(idx, x)); + } + + private: + expr::Plan index_; + expr::Plan src_; +}; // struct Plan + +template +inline Plan, DType> +MakePlan(const TakeExp &exp) { + return Plan, DType>(exp); +} + +template +struct ShapeCheck > { + inline static Shape + Check(const TakeExp &t) { + CHECK(dim == 2) + << "TakeExp only support 2D output"; + Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_); + Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_); + Shape ret; + ret[0] = dshape[0]; + ret[1] = wshape[1]; + return ret; + } +}; + + +template +struct ExpInfo > { + static const int kDim = 2; + static const int kDevMask = ExpInfo::kDevMask; +}; + +} // namespace expr +} // namespace mshadow + +#endif // MSHADOW_EXTENSION_TAKE_H_ diff --git a/include/mshadow/extension/take_grad.h b/include/mshadow/extension/take_grad.h new file mode 100644 index 000000000000..4479b3e0cd9d --- /dev/null +++ b/include/mshadow/extension/take_grad.h @@ -0,0 +1,111 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file take_grad.h + * \brief + * \author Bing Xu +*/ +#ifndef MSHADOW_EXTENSION_TAKE_GRAD_H_ +#define MSHADOW_EXTENSION_TAKE_GRAD_H_ + +#include "../extension.h" + +namespace mshadow { +namespace expr { + +/*! \brief Calculate embedding gradient + * \tparam IndexExp type of index expression + * \tparam SrcExp type of src expression + * \tparam DType data type + */ + +template +struct TakeGradExp : public Exp, + DType, type::kChainer> { + /*! \brief index oprand */ + const IndexExp &index_; + /*! \brief out gradient oprand */ + const SrcExp &src_; + /*! \brief batch size */ + const index_t input_dim_; + /*! \brief constructor */ + TakeGradExp(const IndexExp &index, const SrcExp &src, const index_t input_dim) + : index_(index), src_(src), input_dim_(input_dim) {} +}; // struct TakeGradExp + + +template +inline TakeGradExp +take_grad(const Exp &index, + const Exp &src, + const index_t input_dim) { + return TakeGradExp(index.self(), + src.self(), + input_dim); +} + +//---------------------- +// Execution plan +//---------------------- + +template +struct Plan, DType> { + public: + explicit Plan(const TakeGradExp &e) + : index_(MakePlan(e.index_)), + src_(MakePlan(e.src_)), + batch_size_(ShapeCheck<1, IndexExp>::Check(e.index_)[0]) { + } + + // now return shape: in * out + MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { + DType ret = 0.f; + for (index_t i = 0; i < batch_size_; ++i) { + index_t idx = static_cast(index_.Eval(0, i)); + if (idx == y) { + ret += static_cast(src_.Eval(i, x)); + } + } + return ret; + } + + private: + expr::Plan index_; + expr::Plan src_; + const index_t batch_size_; +}; // struct Plan + + +template +inline Plan, DType> +MakePlan(const TakeGradExp &exp) { + return Plan, DType>(exp); +} + +template +struct ShapeCheck > { + inline static Shape + Check(const TakeGradExp &t) { + CHECK(dim == 2) + << "TakeGradExp only support 2D output"; + // Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_); + Shape<2> gshape = ShapeCheck<2, SrcExp>::Check(t.src_); + Shape ret; + ret[0] = t.input_dim_; + ret[1] = gshape[1]; + return ret; + } +}; // struct ShapeCheck + +template +struct ExpInfo > { + static const int kDim = 2; + static const int kDevMask = ExpInfo::kDevMask; +}; + +} // namespace expr +} // namespace mshadow + +#endif // MSHADOW_EXTENSION_TAKE_GRAD_H_ diff --git a/include/mshadow/extension/transpose.h b/include/mshadow/extension/transpose.h new file mode 100644 index 000000000000..6640153f2100 --- /dev/null +++ b/include/mshadow/extension/transpose.h @@ -0,0 +1,200 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file transpose.h + * \brief support for transpose + * \author Junyuan Xie + */ +#ifndef MSHADOW_EXTENSION_TRANSPOSE_H_ +#define MSHADOW_EXTENSION_TRANSPOSE_H_ +#include +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief transpose axes of a tensor + * input: Tensor: ishape + * output: Tensor oshape[a1],oshape[a2] = ishape[a2],oshape[a1] + * + * \tparam SrcExp type of source expression + * \tparam DType the type of elements + * \tparam dimsrc source dimension, assert a1 > a2 + * \tparam m_a1 one dimension to be swapped, encoded by dimsrc - a1 + * \tparam a2 second dimension to be swapped, encoded by a2 + */ +template +struct TransposeExExp: + public MakeTensorExp, + SrcExp, dimsrc, DType> { + /*! \brief source expression */ + const SrcExp &src_; + const Shape axes_; + Shape dst_in_src_stride_; // Holds the corresponding stride of the dst axes in src + index_t src_stride_; + /*! \brief constructor */ + explicit TransposeExExp(const SrcExp &src, Shape axes) : src_(src), axes_(axes) { + Shape src_shape = ShapeCheck::Check(src); + src_stride_ = src_shape[dimsrc - 1]; + Shape src_stride; + src_stride[dimsrc-1] = 1; + for (int i = dimsrc-2; i >= 0; --i) src_stride[i] = src_shape[i+1]*src_stride[i+1]; + for (int i = 0; i < dimsrc; ++i) { + dst_in_src_stride_[i] = src_stride[axes[i]]; + this->shape_[i] = src_shape[axes[i]]; + } + } +}; +/*! + * \brief a expression that reshapes a tensor to another shape + * \param src Tensor: + * \return a expresion with type Tensor + * \tparam a1 higher dimension to be swapped, assert a1 > a2 + * \tparam a2 lower dimension to be swapped + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype source expression type + */ +template +inline TransposeExExp::kDim> +transpose(const Exp &src, Shape::kDim> axes) { + return TransposeExExp::kDim>(src.self(), axes); +} + +template +struct Plan, DType> { + public: + explicit Plan(const TransposeExExp &e) + : src_(MakePlan(e.src_)), + src_stride_(e.src_stride_), + dst_in_src_stride_(e.dst_in_src_stride_), + dst_shape_(e.shape_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + index_t idx = j * dst_in_src_stride_[dimsrc - 1]; + #pragma unroll + for (int k = dimsrc-2; k >= 0; --k) { + idx += (i % dst_shape_[k]) * dst_in_src_stride_[k]; + i /= dst_shape_[k]; + } + return src_.Eval(idx/src_stride_, idx%src_stride_); + } + + private: + Plan src_; + const index_t src_stride_; + const Shape dst_in_src_stride_, dst_shape_; +}; + +/*! + * \brief transform contiguous indices of the source tensor to indices of the transposed tensor. + * input: Tensor: ishape + * output: Tensor: oshape = ishape + * + * \tparam SrcExp type of source expression + * \tparam DType the type of elements + * \tparam dimsrc source dimension + * \tparam etype source type + */ +template +struct TransposeIndicesExp: + public Exp, DType, etype> { + /*! \brief source expression */ + const SrcExp &src_indices_; // Expression of the source indices + Shape src_shape_; // Holds the corresponding stride of the source axes in dst + const Shape axes_; // The transpose axes + Shape src_in_dst_stride_; // Holds the corresponding stride of the source axes in dst + /*! \brief constructor */ + explicit TransposeIndicesExp(const SrcExp &src_indices, + Shape src_shape, + Shape axes) : src_indices_(src_indices), + src_shape_(src_shape), axes_(axes) { + Shape dst_shape_; + Shape dst_stride_; + bool axes_checking_flag[dimsrc] = { 0 }; + for (int i = 0; i < dimsrc; ++i) { + CHECK_LT(static_cast(axes[i]), dimsrc) + << "Invalid axes input! All elements of axes must be between 0 and " << dimsrc + << ", find axes=" << axes; + dst_shape_[i] = src_shape[axes[i]]; + axes_checking_flag[axes[i]] = true; + } + // check if the input axes is valid + for (int i = 0; i < dimsrc; ++i) { + CHECK_EQ(axes_checking_flag[i], true) + << "Invalid axes input! All elements of axes must be between 0 and " << dimsrc + << ", find axes=" << axes; + } + dst_stride_[dimsrc - 1] = 1; + for (int i = dimsrc - 2; i >= 0; --i) dst_stride_[i] = dst_shape_[i+1] * dst_stride_[i+1]; + for (int i = 0; i < dimsrc; ++i) { + src_in_dst_stride_[axes[i]] = dst_stride_[i]; + } + } +}; + +/*! + * \brief a expression that reshapes a tensor to another shape + * \param src Tensor: + * \return a expresion with type Tensor + * \tparam a1 higher dimension to be swapped, assert a1 > a2 + * \tparam a2 lower dimension to be swapped + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype source expression type + */ +template +inline TransposeIndicesExp +transpose_indices(const Exp &src_indices, + Shape src_shape, + Shape axes) { + return TransposeIndicesExp(src_indices.self(), src_shape, axes); +} + +template +struct Plan, DType> { + public: + explicit Plan(const TransposeIndicesExp &e) + : src_indices_(MakePlan(e.src_indices_)), + src_in_dst_stride_(e.src_in_dst_stride_), + src_shape_(e.src_shape_) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + index_t src_idx = static_cast(src_indices_.Eval(i, j)); + index_t dst_idx = 0; + #pragma unroll + for (int k = dimsrc - 1; k >= 0; --k) { + dst_idx += (src_idx % src_shape_[k]) * src_in_dst_stride_[k]; + src_idx /= src_shape_[k]; + } + return static_cast(dst_idx); + } + + private: + Plan src_indices_; + const Shape src_in_dst_stride_, src_shape_; +}; + +//---------------------- +// Execution plan +//---------------------- +/*! \brief make expression */ +template +inline Plan, DType> +MakePlan(const TransposeIndicesExp &e) { + return Plan, DType>(e); +} + +template +struct ShapeCheck > { + inline static Shape + Check(const TransposeIndicesExp &t) { + Shape s = ShapeCheck::Check(t.src_indices_); + return s; + } +}; + +template +struct ExpInfo > { + static const int kDim = ExpInfo::kDim; + static const int kDevMask = ExpInfo::kDevMask; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_TRANSPOSE_H_ diff --git a/include/mshadow/extension/unpack_patch2col.h b/include/mshadow/extension/unpack_patch2col.h new file mode 100644 index 000000000000..ed473f81d496 --- /dev/null +++ b/include/mshadow/extension/unpack_patch2col.h @@ -0,0 +1,151 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file unpack_patch2col.h + * \brief support for unpack + * \author Tianqi Chen + */ +#ifndef MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ +#define MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ +#include "../extension.h" +namespace mshadow { +namespace expr { +/*! + * \brief unpack local (overlap) patches of image to column of mat, + * can be used to implement convolution, this expression allow unpack of a batch + * this is a version support unpacking multiple images + * after getting unpacked mat, we can use: output = dot(weight, mat) to get covolved results, the relations: + * \tparam SrcExp source expression + * \tparam dstdim destination dimension + */ +template +struct UnpackPatchToColXExp: + public MakeTensorExp, + SrcExp, 2, DType>{ + /*! \brief source operand */ + const SrcExp &img_; + /*! \brief patch height */ + index_t psize_y_; + /*! \brief patch width */ + index_t psize_x_; + /*! \brief patch stride */ + index_t pstride_y_; + index_t pstride_x_; + /*! \brief patch dilate */ + index_t pdilate_y_; + index_t pdilate_x_; + /*! \brief number of input channel */ + index_t i_channel_; + /*! \brief height of img */ + index_t i_height_; + /*! \brief width of img */ + index_t i_width_; + /*! \brief constructor */ + UnpackPatchToColXExp(const SrcExp &img, + index_t psize_y, + index_t psize_x, + index_t pstride_y, + index_t pstride_x, + index_t pdilate_y, + index_t pdilate_x) + : img_(img), psize_y_(psize_y), psize_x_(psize_x), + pstride_y_(pstride_y), pstride_x_(pstride_x), + pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){ + Shape imshape = ShapeCheck::Check(img_); + CHECK(imshape[srcdim - 1] >= psize_x && imshape[srcdim - 2] >= psize_y) + << "UnpackPatchToCol:image shape smaller than patch size"; + this->i_channel_ = imshape[srcdim - 3]; + this->i_height_ = imshape[srcdim - 2]; + this->i_width_ = imshape[srcdim - 1]; + // calculate number of batches + const index_t num = imshape.ProdShape(0, srcdim - 3); + const index_t o_height = (i_height_ - + (pdilate_y * (psize_y - 1) + 1)) / pstride_y + 1; + const index_t o_width = (i_width_ - + (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1; + this->shape_[1] = o_height * o_width * num; + this->shape_[0] = psize_y * psize_x * i_channel_; + } +}; + +/*! + * \brief unpack local (overlap) patches of image to column of mat, can be used to implement convolution + * after getting unpacked mat, we can use: output = dot(weight, mat) to get covolved results, the relations: + * + * weight; shape[0]: out_channel, shape[1]: ichannel * psize_y * psize_x + * output; shape[0]: out_channel, shape[1]: out_height * out_width * num_of_images + * out_height = (in_height - psize_y) / pstride + 1, this means we pad inperfect patch with 0 + * out_width = (in_width - psize_x) / pstride + 1 + * + * \return mat target matrix; shape[0]: in_channel*psize_y*psize_x shape[1]: out_height*out_width * num_of_images + * \param img source image; shape[-3]: in_channels, shape[-2]: in_height, shape[-1]: in_width, can be 3D or 4D tensor(multiple images) + * \param psize_y height of each patch + * \param psize_x width of each patch + * \param pstride stride of each patch + * \param pdilate dilate of each patch + * \tparam SrcExp source expression + * \tparam DType the type of elements + * \tparam etype type of expression + */ +template +inline UnpackPatchToColXExp::kDim> +unpack_patch2col(const Exp &img, + index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate) { + TypeCheckPass::kDim >= 3> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return UnpackPatchToColXExp::kDim> + (img.self(), psize_y, psize_x, pstride, pstride, pdilate, pdilate); +} + +/*! + *if you want to specify stride_x and stride_y + */ +template +inline UnpackPatchToColXExp::kDim> +unpack_patch2col(const Exp &img, + index_t psize_y, index_t psize_x, index_t pstride_y_, index_t pstride_x_, + index_t pdilate_y_, index_t pdilate_x_) { + TypeCheckPass::kDim >= 3> + ::Error_Expression_Does_Not_Meet_Dimension_Req(); + return UnpackPatchToColXExp::kDim> + (img.self(), psize_y, psize_x, pstride_y_, pstride_x_, pdilate_y_, pdilate_x_); +} +//---------------------- +// Execution plan +//---------------------- +template +struct Plan, DType> { + public: + explicit Plan(const UnpackPatchToColXExp &e) + :src_(MakePlan(e.img_)), + psize_y_(e.psize_y_), psize_x_(e.psize_x_), + pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_), + i_channel_(e.i_channel_), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_), + i_height_(e.i_height_), i_width_(e.i_width_), + o_height_((i_height_ - (pdilate_y_ * (psize_y_ - 1) + 1)) / pstride_y_ + 1), + o_width_((i_width_ - (pdilate_x_ * (psize_x_ - 1) + 1)) / pstride_x_ + 1) {} + MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { + const index_t x_offset = i % psize_x_ * pdilate_x_; + const index_t idivp = i / psize_x_; + const index_t y_offset = idivp % psize_y_ * pdilate_y_; + const index_t c = idivp / psize_y_; + const index_t x = (j % o_width_) * pstride_x_ + x_offset; + const index_t jdivw = j / o_width_; + const index_t y = (jdivw % o_height_) * pstride_y_ + y_offset; + const index_t n = jdivw / o_height_; + + if (x < i_width_ && y < i_height_) { + return src_.Eval((n * i_channel_ + c) * i_height_ + y, x); + } else { + return DType(0.0f); + } + } + + private: + Plan src_; + const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_; + const index_t pdilate_y_, pdilate_x_; + const index_t i_height_, i_width_, o_height_, o_width_; +}; +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ diff --git a/include/mshadow/half.h b/include/mshadow/half.h new file mode 100644 index 000000000000..75d8e5d09d2f --- /dev/null +++ b/include/mshadow/half.h @@ -0,0 +1,288 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file half.h + * \brief definition of half (float16) type. + * + * \author Junyuan Xie + */ +#ifndef MSHADOW_HALF_H_ +#define MSHADOW_HALF_H_ +#include "./base.h" + +#if MSHADOW_USE_F16C + #include +#endif // MSHADOW_USE_F16C + +#if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050) + #define MSHADOW_CUDA_HALF 1 + #include + #if defined(__CUDA_ARCH__) + /*! \brief __half2float_warp */ + __host__ __device__ float __half2float_warp(const volatile __half& h) { /* NOLINT(*) */ + __half val; +#if CUDA_VERSION >= 9000 + val = const_cast<__half&>(h); +#else + val.x = h.x; +#endif + return __half2float(val); + } + #endif +#else + #define MSHADOW_CUDA_HALF 0 +#endif + +/*! \brief namespace for mshadow */ +namespace mshadow { +/* \brief name space for host/device portable half-precision floats */ +namespace half { +#define MSHADOW_HALF_OPERATOR(RTYPE, OP) \ + MSHADOW_XINLINE RTYPE operator OP (half_t a, half_t b) { \ + return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \ + } \ + template \ + MSHADOW_XINLINE RTYPE operator OP (half_t a, T b) { \ + return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \ + } \ + template \ + MSHADOW_XINLINE RTYPE operator OP (T a, half_t b) { \ + return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \ + } + +#define MSHADOW_HALF_ASSIGNOP(AOP, OP) \ + template \ + MSHADOW_XINLINE half_t operator AOP (const T& a) { \ + return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \ + } \ + template \ + MSHADOW_XINLINE half_t operator AOP (const volatile T& a) volatile { \ + return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \ + } + +#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__)) +#define MSHADOW_HALF_CONVERSIONOP(T) \ + MSHADOW_XINLINE operator T() const { \ + return T(__half2float(cuhalf_)); /* NOLINT(*)*/ \ + } \ + MSHADOW_XINLINE operator T() const volatile { \ + return T(__half2float_warp(cuhalf_)); /* NOLINT(*)*/ \ + } +#elif(MSHADOW_USE_F16C) +#define MSHADOW_HALF_CONVERSIONOP(T) \ + MSHADOW_XINLINE operator T() const { \ + return T(_cvtsh_ss(half_)); /* NOLINT(*)*/ \ + } \ + MSHADOW_XINLINE operator T() const volatile { \ + return T(_cvtsh_ss(half_)); /* NOLINT(*)*/ \ + } +#else +#define MSHADOW_HALF_CONVERSIONOP(T) \ + MSHADOW_XINLINE operator T() const { \ + return T(half2float(half_)); /* NOLINT(*)*/ \ + } \ + MSHADOW_XINLINE operator T() const volatile { \ + return T(half2float(half_)); /* NOLINT(*)*/ \ + } +#endif // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__)) + +class MSHADOW_ALIGNED(2) half_t { + public: + union { + uint16_t half_; +#if MSHADOW_CUDA_HALF + __half cuhalf_; +#endif // MSHADOW_CUDA_HALF + }; + + static MSHADOW_XINLINE half_t Binary(uint16_t value) { + half_t res; + res.half_ = value; + return res; + } + + MSHADOW_XINLINE half_t() {} + +#if MSHADOW_CUDA_HALF + MSHADOW_XINLINE explicit half_t(const __half& value) { + cuhalf_ = value; + } +#endif // MSHADOW_CUDA_HALF + + MSHADOW_XINLINE half_t(const float& value) { constructor(value); } + MSHADOW_XINLINE explicit half_t(const double& value) { constructor(value); } + MSHADOW_XINLINE explicit half_t(const int8_t& value) { constructor(value); } + MSHADOW_XINLINE explicit half_t(const uint8_t& value) { constructor(value); } + MSHADOW_XINLINE explicit half_t(const int32_t& value) { constructor(value); } + MSHADOW_XINLINE explicit half_t(const uint32_t& value) { constructor(value); } + MSHADOW_XINLINE explicit half_t(const int64_t& value) { constructor(value); } + MSHADOW_XINLINE explicit half_t(const uint64_t& value) { constructor(value); } + + MSHADOW_HALF_CONVERSIONOP(float) + + MSHADOW_HALF_ASSIGNOP(+=, +) + MSHADOW_HALF_ASSIGNOP(-=, -) + MSHADOW_HALF_ASSIGNOP(*=, *) + MSHADOW_HALF_ASSIGNOP(/=, /) + + MSHADOW_XINLINE half_t operator+() { + return *this; + } + + MSHADOW_XINLINE half_t operator-() { + return half_t(-float(*this)); // NOLINT(*) + } + + MSHADOW_XINLINE half_t operator=(const half_t& a) { + half_ = a.half_; + return a; + } + + template + MSHADOW_XINLINE half_t operator=(const T& a) { + return *this = half_t(a); /* NOLINT(*)*/ + } + + MSHADOW_XINLINE half_t operator=(const half_t& a) volatile { + half_ = a.half_; + return a; + } + + template + MSHADOW_XINLINE half_t operator=(const T& a) volatile { + return *this = half_t(a); /* NOLINT(*)*/ + } + + private: + union Bits { + float f; + int32_t si; + uint32_t ui; + }; + + static int const shift = 13; + static int const shiftSign = 16; + + static int32_t const infN = 0x7F800000; // flt32 infinity + static int32_t const maxN = 0x477FE000; // max flt16 normal as a flt32 + static int32_t const minN = 0x38800000; // min flt16 normal as a flt32 + static int32_t const signN = 0x80000000; // flt32 sign bit + + static int32_t const infC = infN >> shift; + static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32 + static int32_t const maxC = maxN >> shift; + static int32_t const minC = minN >> shift; + static int32_t const signC = signN >> shiftSign; // flt16 sign bit + + static int32_t const mulN = 0x52000000; // (1 << 23) / minN + static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift)) + + static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted + static int32_t const norC = 0x00400; // min flt32 normal down shifted + + static int32_t const maxD = infC - maxC - 1; + static int32_t const minD = minC - subC - 1; + + MSHADOW_XINLINE uint16_t float2half(const float& value) const { + Bits v, s; + v.f = value; + uint32_t sign = v.si & signN; + v.si ^= sign; + sign >>= shiftSign; // logical shift + s.si = mulN; + s.si = s.f * v.f; // correct subnormals + v.si ^= (s.si ^ v.si) & -(minN > v.si); + v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN)); + v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN)); + v.ui >>= shift; // logical shift + v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC); + v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC); + return v.ui | sign; + } + + MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile { // NOLINT (*) + Bits v, s; + v.f = value; + uint32_t sign = v.si & signN; + v.si ^= sign; + sign >>= shiftSign; // logical shift + s.si = mulN; + s.si = s.f * v.f; // correct subnormals + v.si ^= (s.si ^ v.si) & -(minN > v.si); + v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN)); + v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN)); + v.ui >>= shift; // logical shift + v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC); + v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC); + return v.ui | sign; + } + + MSHADOW_XINLINE float half2float(const uint16_t& value) const { + Bits v; + v.ui = value; + int32_t sign = v.si & signC; + v.si ^= sign; + sign <<= shiftSign; + v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); + v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); + Bits s; + s.si = mulC; + s.f *= v.si; + int32_t mask = -(norC > v.si); + v.si <<= shift; + v.si ^= (s.si ^ v.si) & mask; + v.si |= sign; + return v.f; + } + + MSHADOW_XINLINE float half2float(const volatile uint16_t& value) const volatile { // NOLINT(*) + Bits v; + v.ui = value; + int32_t sign = v.si & signC; + v.si ^= sign; + sign <<= shiftSign; + v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); + v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); + Bits s; + s.si = mulC; + s.f *= v.si; + int32_t mask = -(norC > v.si); + v.si <<= shift; + v.si ^= (s.si ^ v.si) & mask; + v.si |= sign; + return v.f; + } + + template + MSHADOW_XINLINE void constructor(const T& value) { +#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__)) + cuhalf_ = __float2half(float(value)); // NOLINT(*) +#elif(MSHADOW_USE_F16C) + half_ = _cvtss_sh(static_cast(value), 0); +#else /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */ + half_ = float2half(float(value)); // NOLINT(*) +#endif /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */ + } +}; + +/*! \brief overloaded + operator for half_t */ +MSHADOW_HALF_OPERATOR(half_t, +) +/*! \brief overloaded - operator for half_t */ +MSHADOW_HALF_OPERATOR(half_t, -) +/*! \brief overloaded * operator for half_t */ +MSHADOW_HALF_OPERATOR(half_t, *) +/*! \brief overloaded / operator for half_t */ +MSHADOW_HALF_OPERATOR(half_t, /) +/*! \brief overloaded > operator for half_t */ +MSHADOW_HALF_OPERATOR(bool, >) +/*! \brief overloaded < operator for half_t */ +MSHADOW_HALF_OPERATOR(bool, <) +/*! \brief overloaded >= operator for half_t */ +MSHADOW_HALF_OPERATOR(bool, >=) +/*! \brief overloaded <= operator for half_t */ +MSHADOW_HALF_OPERATOR(bool, <=) + +#define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0xFBFF); +#define MSHADOW_HALF_MAX mshadow::half::half_t::Binary(0x7BFF); +} // namespace half +} // namespace mshadow +#endif // MSHADOW_HALF_H_ diff --git a/include/mshadow/half2.h b/include/mshadow/half2.h new file mode 100755 index 000000000000..3e130c85ba63 --- /dev/null +++ b/include/mshadow/half2.h @@ -0,0 +1,143 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file half2.h + * \brief definition of vector float16, half2 type. + * + * \author Antti-Pekka Hynninen + */ +#ifndef MSHADOW_HALF2_H_ +#define MSHADOW_HALF2_H_ + +#if (defined(__CUDACC__) && __CUDA_ARCH__ >= 530 && MSHADOW_USE_CUDA && CUDA_VERSION >= 7050) + #define MSHADOW_CUDA_HALF2 1 + #include +#else + #define MSHADOW_CUDA_HALF2 0 +#endif + +#include + +/*! \brief namespace for mshadow */ +namespace mshadow { +/* \brief name space for host/device portable half-precision floats */ +namespace half { + +#define MSHADOW_HALF2_ASSIGNOP(AOP, OP) \ + template \ + MSHADOW_XINLINE half2_t operator AOP (const T& a) { \ + return *this = half2_t(*this OP a); /* NOLINT(*)*/ \ + } \ + +class MSHADOW_ALIGNED(4) half2_t { + public: +#if MSHADOW_CUDA_HALF2 + half2 half2_; +#else + half_t half_t2[2]; +#endif + + MSHADOW_XINLINE half2_t() {} + +#if MSHADOW_CUDA_HALF2 + MSHADOW_XINLINE explicit half2_t(half2 a) : half2_(a) {} +#else + MSHADOW_XINLINE explicit half2_t(half_t a, half_t b) { + half_t2[0] = a; + half_t2[1] = b; + } +#endif + + MSHADOW_XINLINE explicit half2_t(int a) { +#if MSHADOW_CUDA_HALF2 + half2_ = __half2half2(__int2half_rz(a)); +#else + half_t2[0] = (half_t)a; + half_t2[1] = (half_t)a; +#endif + } + + MSHADOW_XINLINE half2_t operator+() { + return *this; + } + + MSHADOW_XINLINE half2_t operator-() { +#if MSHADOW_CUDA_HALF2 + return half2_t(__hneg2(half2_)); +#else + return half2_t(-half_t2[0], -half_t2[1]); +#endif + } + + MSHADOW_XINLINE half2_t operator=(const half2_t& a) { +#if MSHADOW_CUDA_HALF2 + half2_ = a.half2_; +#else + half_t2[0] = a.half_t2[0]; + half_t2[1] = a.half_t2[1]; +#endif + return a; + } + + MSHADOW_HALF2_ASSIGNOP(+=, +) + MSHADOW_HALF2_ASSIGNOP(-=, -) + MSHADOW_HALF2_ASSIGNOP(*=, *) + MSHADOW_HALF2_ASSIGNOP(/=, /) +}; + +/*! \brief overloaded + operator for half2_t */ +MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b) { +#if MSHADOW_CUDA_HALF2 + return half2_t(__floats2half2_rn(__low2float(a.half2_) + __low2float(b.half2_), + __high2float(a.half2_) + __high2float(b.half2_))); +#else + return half2_t(a.half_t2[0] + b.half_t2[0], a.half_t2[1] + b.half_t2[1]); +#endif +} +/*! \brief overloaded - operator for half2_t */ +MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b) { +#if MSHADOW_CUDA_HALF2 + return half2_t(__floats2half2_rn(__low2float(a.half2_) - __low2float(b.half2_), + __high2float(a.half2_) - __high2float(b.half2_))); +#else + return half2_t(a.half_t2[0] - b.half_t2[0], a.half_t2[1] - b.half_t2[1]); +#endif +} +/*! \brief overloaded * operator for half2_t */ +MSHADOW_XINLINE half2_t operator*(half2_t a, half2_t b) { +#if MSHADOW_CUDA_HALF2 + return half2_t(__floats2half2_rn(__low2float(a.half2_) * __low2float(b.half2_), + __high2float(a.half2_) * __high2float(b.half2_))); +#else + return half2_t(a.half_t2[0] * b.half_t2[0], a.half_t2[1] * b.half_t2[1]); +#endif +} +/*! \brief overloaded / operator for half2_t */ +MSHADOW_XINLINE half2_t operator/(half2_t a, half2_t b) { +#if MSHADOW_CUDA_HALF2 + return half2_t(__floats2half2_rn(__low2float(a.half2_) / __low2float(b.half2_), + __high2float(a.half2_) / __high2float(b.half2_))); +#else + return half2_t(a.half_t2[0] / b.half_t2[0], a.half_t2[1] / b.half_t2[1]); +#endif +} +/*! \brief overloaded % operator for half2_t */ +MSHADOW_XINLINE half2_t operator%(half2_t a, half2_t b) { +#if MSHADOW_CUDA_HALF2 + return half2_t(__floats2half2_rn(::fmod(__low2float(a.half2_), __low2float(b.half2_)), + ::fmod(__high2float(a.half2_), __high2float(b.half2_)))); +#else + return half2_t(::fmod(a.half_t2[0], b.half_t2[0]), ::fmod(a.half_t2[1], b.half_t2[1])); +#endif +} +/*! \brief overloaded == operator for half2_t */ +MSHADOW_XINLINE bool operator==(half2_t a, half2_t b) { +#if MSHADOW_CUDA_HALF2 + return __hbeq2(a.half2_, b.half2_); +#else + return (a.half_t2[0] == b.half_t2[0] && a.half_t2[1] == b.half_t2[1]); +#endif +} + +} // namespace half +} // namespace mshadow +#endif // MSHADOW_HALF2_H_ diff --git a/include/mshadow/io.h b/include/mshadow/io.h new file mode 100644 index 000000000000..2d0efc3aa56b --- /dev/null +++ b/include/mshadow/io.h @@ -0,0 +1,137 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file io.h + * \brief definitions of I/O functions for mshadow tensor + * \author Tianqi Chen + */ +#ifndef MSHADOW_IO_H_ +#define MSHADOW_IO_H_ +#include "./tensor.h" + +namespace mshadow { +namespace utils { +/*! + * \brief interface of stream I/O, used to serialize data, + * mshadow does not restricted to only this interface in SaveBinary/LoadBinary + * mshadow accept all class that implements Read and Write + */ +class IStream { + public: + /*! + * \brief read data from stream + * \param ptr pointer to memory buffer + * \param size size of block + * \return usually is the size of data readed + */ + virtual size_t Read(void *ptr, size_t size) = 0; + /*! + * \brief write data to stream + * \param ptr pointer to memory buffer + * \param size size of block + */ + virtual void Write(const void *ptr, size_t size) = 0; + /*! \brief virtual destructor */ + virtual ~IStream(void) {} +}; +} // namespace utils +/*! + * \brief CPU/GPU: save a tensor by binary format, for GPU version, a temp Tensor storage will be allocated + * \param fo output binary stream + * \param src source data file + * \tparam dim dimension of tensor + * \tparam DType type of element in tensor + * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. + */ +template +inline void SaveBinary(TStream &fo, const Tensor &src); // NOLINT(*) +/*! + * \brief CPU/GPU: save a tensor by binary format, for GPU version, a temp Tensor storage will be allocated + * \param fo output binary stream + * \param src source data file + * \tparam dim dimension of tensor + * \tparam DType type of element in tensor + * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. + */ +template +inline void SaveBinary(TStream &fo, const Tensor &src); // NOLINT(*) +/*! + * \brief CPU/GPU: load a tensor by binary format, for GPU version, a temp Tensor storage will be allocated + * if pre_alloc is true , then space in dst is preallocated, and must have same shape of the tensor loaded + * if pre_alloc is false, then dst originally does not have space allocated, LoadBinary will allocate space for dst + * \param fi output binary stream + * \param dst destination file + * \param pre_alloc whether space is pre-allocated, if false, space allocation will happen + * \tparam dim dimension of tensor + * \tparam DType type of element in tensor + * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. + */ +template +inline void LoadBinary(TStream &fi, // NOLINT(*) + Tensor *dst, bool pre_alloc); +/*! + * \brief CPU/GPU: load a tensor by binary format, for GPU version, a temp Tensor storage will be allocated + * if pre_alloc is true , then space in dst is preallocated, and must have same shape of the tensor loaded + * if pre_alloc is false, then dst originally does not have space allocated, LoadBinary will allocate space for dst + * \param fi output binary stream + * \param dst destination file + * \param pre_alloc whether space is pre-allocated, if false, space allocation will happen + * \tparam dim dimension of tensor + * \tparam DType type of element in tensor + * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. + */ + +template +inline void LoadBinary(TStream &fi, // NOLINT(*) + Tensor *dst, bool pre_alloc); + +// implementations +template +inline void SaveBinary(TStream &fo, const Tensor &src_) { // NOLINT(*) + fo.Write(&src_.shape_, sizeof(src_.shape_)); + Tensor src = src_.FlatTo2D(); + for (index_t i = 0; i < src.size(0); ++i) { + fo.Write(src[i].dptr_, sizeof(DType) * src.size(1)); + } +} +template +inline void SaveBinary(TStream &fo, const Tensor &src) { // NOLINT(*) + // copy to CPU, then save + Tensor tmp(src.shape_); + AllocSpace(&tmp); + Stream stream; + Copy(tmp, src, &stream); + SaveBinary(fo, tmp); + FreeSpace(&tmp); +} +template +inline void LoadBinary(TStream &fi, // NOLINT(*) + Tensor *dst_, bool pre_alloc) { + Shape shape; + CHECK_NE(fi.Read(&shape, sizeof(shape)), 0) << "mshadow::LoadBinary"; + if (pre_alloc) { + CHECK_EQ(shape, dst_->shape_) << "LoadBinary, shape do not match pre-allocated shape"; + } else { + dst_->shape_ = shape; AllocSpace(dst_); + } + Tensor dst = dst_->FlatTo2D(); + if (dst.size(0) == 0) return; + for (index_t i = 0; i < dst.size(0); ++i) { + CHECK_NE(fi.Read(dst[i].dptr_, sizeof(DType) * dst.size(1)), 0) << "mshadow::LoadBinary"; + } +} +template +inline void LoadBinary(TStream &fi, // NOLINT(*) + Tensor *dst, bool pre_alloc) { + Tensor tmp; + LoadBinary(fi, &tmp, false); + if (pre_alloc) { + CHECK_EQ(tmp.shape, dst->shape_) << "LoadBinary, shape do not match pre-allocated shape"; + } else { + dst->shape = tmp.shape; AllocSpace(dst); + } + Stream stream; + Copy(*dst, tmp, &stream); + FreeSpace(&tmp); +} +} // namespace mshadow +#endif // MSHADOW_IO_H_ diff --git a/include/mshadow/logging.h b/include/mshadow/logging.h new file mode 100644 index 000000000000..002b90097595 --- /dev/null +++ b/include/mshadow/logging.h @@ -0,0 +1,234 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file logging.h + * \brief defines logging macros of dmlc + * allows use of GLOG, fall back to internal + * implementation when disabled + */ +#ifndef MSHADOW_LOGGING_H_ +#define MSHADOW_LOGGING_H_ +#ifndef DMLC_LOGGING_H_ +#define DMLC_LOGGING_H_ + +#include +#include +#include +#include +#include +#include "./base.h" + +namespace dmlc { +/*! \brief taken from DMLC directly */ + +/*! + * \brief exception class that will be thrown by + * default logger if DMLC_LOG_FATAL_THROW == 1 + */ +struct Error : public std::runtime_error { + /*! + * \brief constructor + * \param s the error message + */ + explicit Error(const std::string &s) : std::runtime_error(s) {} +}; +} // namespace dmlc + +#if defined(_MSC_VER) && _MSC_VER < 1900 +#define noexcept(a) +#endif + +#if DMLC_USE_GLOG +#include + +namespace dmlc { +/*! \brief taken from DMLC directly */ +inline void InitLogging(const char* argv0) { + google::InitGoogleLogging(argv0); +} +} // namespace dmlc + +#else +// use a light version of glog +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable : 4722) +#endif + +namespace dmlc { +inline void InitLogging(const char* argv0) { + // DO NOTHING +} + +// Always-on checking +#define CHECK(x) \ + if (!(x)) \ + dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check " \ + "failed: " #x << ' ' +#define CHECK_LT(x, y) CHECK((x) < (y)) +#define CHECK_GT(x, y) CHECK((x) > (y)) +#define CHECK_LE(x, y) CHECK((x) <= (y)) +#define CHECK_GE(x, y) CHECK((x) >= (y)) +#define CHECK_EQ(x, y) CHECK((x) == (y)) +#define CHECK_NE(x, y) CHECK((x) != (y)) +#define CHECK_NOTNULL(x) \ + ((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*) +// Debug-only checking. +#ifdef NDEBUG +#define DCHECK(x) \ + while (false) CHECK(x) +#define DCHECK_LT(x, y) \ + while (false) CHECK((x) < (y)) +#define DCHECK_GT(x, y) \ + while (false) CHECK((x) > (y)) +#define DCHECK_LE(x, y) \ + while (false) CHECK((x) <= (y)) +#define DCHECK_GE(x, y) \ + while (false) CHECK((x) >= (y)) +#define DCHECK_EQ(x, y) \ + while (false) CHECK((x) == (y)) +#define DCHECK_NE(x, y) \ + while (false) CHECK((x) != (y)) +#else +#define DCHECK(x) CHECK(x) +#define DCHECK_LT(x, y) CHECK((x) < (y)) +#define DCHECK_GT(x, y) CHECK((x) > (y)) +#define DCHECK_LE(x, y) CHECK((x) <= (y)) +#define DCHECK_GE(x, y) CHECK((x) >= (y)) +#define DCHECK_EQ(x, y) CHECK((x) == (y)) +#define DCHECK_NE(x, y) CHECK((x) != (y)) +#endif // NDEBUG + +#define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__) +#define LOG_ERROR LOG_INFO +#define LOG_WARNING LOG_INFO +#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__) +#define LOG_QFATAL LOG_FATAL + +// Poor man version of VLOG +#define VLOG(x) LOG_INFO.stream() + +#define LOG(severity) LOG_##severity.stream() +#define LG LOG_INFO.stream() +#define LOG_IF(severity, condition) \ + !(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) + +#ifdef NDEBUG +#define LOG_DFATAL LOG_ERROR +#define DFATAL ERROR +#define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) +#define DLOG_IF(severity, condition) \ + (true || !(condition)) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) +#else +#define LOG_DFATAL LOG_FATAL +#define DFATAL FATAL +#define DLOG(severity) LOG(severity) +#define DLOG_IF(severity, condition) LOG_IF(severity, condition) +#endif + +// Poor man version of LOG_EVERY_N +#define LOG_EVERY_N(severity, n) LOG(severity) + +class DateLogger { + public: + DateLogger() { +#if defined(_MSC_VER) + _tzset(); +#endif + } + const char* HumanDate() { +#if defined(_MSC_VER) + _strtime_s(buffer_, sizeof(buffer_)); +#else + time_t time_value = time(NULL); + struct tm now; + localtime_r(&time_value, &now); + snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d", now.tm_hour, + now.tm_min, now.tm_sec); +#endif + return buffer_; + } + private: + char buffer_[9]; +}; + +class LogMessage { + public: + LogMessage(const char* file, int line) + : +#ifdef __ANDROID__ + log_stream_(std::cout) +#else + log_stream_(std::cerr) +#endif + { + log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" + << line << ": "; + } + ~LogMessage() { log_stream_ << "\n"; } + std::ostream& stream() { return log_stream_; } + + protected: + std::ostream& log_stream_; + + private: + DateLogger pretty_date_; + LogMessage(const LogMessage&); + void operator=(const LogMessage&); +}; + +#if DMLC_LOG_FATAL_THROW == 0 +class LogMessageFatal : public LogMessage { + public: + LogMessageFatal(const char* file, int line) : LogMessage(file, line) {} + ~LogMessageFatal() { + log_stream_ << "\n"; + abort(); + } + + private: + LogMessageFatal(const LogMessageFatal&); + void operator=(const LogMessageFatal&); +}; +#else +class LogMessageFatal { + public: + LogMessageFatal(const char* file, int line) { + log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" + << line << ": "; + } + std::ostringstream &stream() { return log_stream_; } + ~LogMessageFatal() DMLC_THROW_EXCEPTION { + // throwing out of destructor is evil + // hopefully we can do it here + throw Error(log_stream_.str()); + } + + private: + std::ostringstream log_stream_; + DateLogger pretty_date_; + LogMessageFatal(const LogMessageFatal&); + void operator=(const LogMessageFatal&); +}; +#endif + +// This class is used to explicitly ignore values in the conditional +// logging macros. This avoids compiler warnings like "value computed +// is not used" and "statement has no effect". +class LogMessageVoidify { + public: + LogMessageVoidify() {} + // This has to be an operator with a precedence lower than << but + // higher than "?:". See its usage. + void operator&(std::ostream&) {} +}; + +} // namespace dmlc + +#endif +#endif // DMLC_LOGGING_H_ +#endif // MSHADOW_LOGGING_H_ + diff --git a/include/mshadow/packet-inl.h b/include/mshadow/packet-inl.h new file mode 100644 index 000000000000..f5a89bfa8421 --- /dev/null +++ b/include/mshadow/packet-inl.h @@ -0,0 +1,413 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file packet-inl.h + * \brief Generic packet vectorization code + */ +#ifndef MSHADOW_PACKET_INL_H_ +#define MSHADOW_PACKET_INL_H_ + +#ifdef __APPLE__ +#include +#else +#include +#endif +#include "./base.h" +#include "./tensor.h" +#include "./expression.h" + + +namespace mshadow { +/*! \brief namespace of packet math*/ +namespace packet { + +enum PacketArch { + kPlain, + kSSE2, +}; + +#if MSHADOW_USE_SSE +#define MSHADOW_DEFAULT_PACKET ::mshadow::packet::kSSE2 +#else +#define MSHADOW_DEFAULT_PACKET ::mshadow::packet::kPlain +#endif + +// whether packet operator is enabled. +/*! + * \brief Generic packet type + * \tparam DType The data type of the packet. + * \tparam Arch the Arch of the packet. + */ +template +struct Packet; + +template +struct AlignBytes { + static const index_t value = 4; +}; + +} // namespace packet +} // namespace mshadow + +namespace mshadow { +namespace packet { +/*! + * \brief analog to cudaMallocPitch, allocate a aligned space with num_line * lspace cells + * \param out_pitch output parameter, the actuall space allocated for each line + * \param lspace number of cells required for each line + * \param num_line number of lines to be allocated + */ +inline void* AlignedMallocPitch(size_t *out_pitch, + size_t lspace, + size_t num_line) { + const index_t bits = AlignBytes::value; + const index_t mask = (1 << bits) - 1; + + size_t pitch = ((lspace + mask) >> bits) << bits; + *out_pitch = pitch; +#ifdef _MSC_VER + void *res = _aligned_malloc(pitch * num_line, 1 << bits); +#else + void *res; + int ret = posix_memalign(&res, 1 << bits, pitch * num_line); + CHECK_EQ(ret, 0) << "AlignedMallocPitch failed"; +#endif + if (res == NULL) { + LOG(FATAL) << "AlignedMallocPitch failed"; + } + return res; +} + +/*! + * \brief free aligned space + * \param ptr pointer to space to be freed + */ +inline void AlignedFree(void *ptr) { +#ifdef _MSC_VER + _aligned_free(ptr); +#else + free(ptr); +#endif +} + +/*! \brief check if a pointer is aligned */ +template +inline bool CheckAlign(size_t pitch) { + const index_t bits = AlignBytes::value; + return !(pitch & ((1 << bits) - 1)); +} + +/*! \brief check if a pointer is aligned */ +template +inline bool CheckAlign(void *ptr) { + return CheckAlign(reinterpret_cast(ptr)); +} + +/*! + * \brief get upper bound of aligned index of size + * \param size size of the array + * \param fsize size of float + */ +template +inline index_t UpperAlign(index_t size) { + const index_t bits = AlignBytes::value; + const index_t mask = (1 << bits) - 1; + const index_t fsize = sizeof(DType); + return (((size * fsize + mask) >> bits) << bits) / fsize; +} + +/*! + * \brief get lower bound of aligned index of size + * \param size size of the array + * \param fsize size of float + */ +template +inline index_t LowerAlign(index_t size) { + const index_t bits = AlignBytes::value; + const index_t fsize = sizeof(DType); + return (((size * fsize) >> bits) << bits) / fsize; +} + +/*! + * \brief generic Packet operator + * \tparam OP The operator + * \tparam DType The data type + * \tparam Arch The architecture. + */ +template +struct PacketOp { + static const bool kEnabled = false; +}; +// specialization of operators +template +struct PacketOp { + static const bool kEnabled = true; + MSHADOW_CINLINE static Packet Map(const Packet& lhs, + const Packet& rhs) { + return lhs + rhs; + } +}; +template +struct PacketOp { + static const bool kEnabled = true; + MSHADOW_CINLINE static Packet Map(const Packet& lhs, + const Packet& rhs) { + return lhs - rhs; + } +}; +template +struct PacketOp { + static const bool kEnabled = true; + MSHADOW_CINLINE static Packet Map(const Packet& lhs, + const Packet& rhs) { + return lhs * rhs; + } +}; +template +struct PacketOp { + static const bool kEnabled = true; + MSHADOW_CINLINE static Packet Map(const Packet& lhs, + const Packet& rhs) { + return lhs / rhs; + } +}; + +template +struct PacketOp { + static const bool kEnabled = true; + MSHADOW_CINLINE static Packet Map(const Packet& src) { + return src; + } +}; + + +// savers to do storage +template +struct Saver{ + MSHADOW_CINLINE static void Save(TFloat *dst, const Packet& src) { + Packet lhs = Packet::Load(dst); + Packet ans = PacketOp::Map(lhs, src); + ans.Store(dst); + } +}; +template +struct Saver { + MSHADOW_CINLINE static void Save(TFloat *dst, const Packet& src) { + src.Store(dst); + } +}; +} // namespace packet +} // namespace mshadow + +#include "packet/plain-inl.h" +#if MSHADOW_USE_SSE && !defined(__CUDACC__) +#include "packet/sse-inl.h" +#endif + +namespace mshadow { +namespace expr { + +typedef packet::PacketArch PacketArch; + +// same as plan, but use packet +template +class PacketPlan { + public: + /*! + * \brief evaluate the expression at index [y][x], + * x will be aligned to Packet::Size() + */ + MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const; + MSHADOW_CINLINE DType Eval(index_t y, index_t x) const; +}; + +template +class PacketPlan, DType, Arch> { + public: + explicit PacketPlan(const Tensor &t) + :dptr_(t.dptr_), stride_(t.stride_) {} + MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const { + return packet::Packet::Load(&dptr_[y * stride_ + x]); + } + MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { + return dptr_[y * stride_ + x]; + } + + private: + const DType *dptr_; + index_t stride_; +}; + +template +class PacketPlan, DType, Arch> { + public: + explicit PacketPlan(DType scalar) : scalar_(scalar) {} + MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const { + return packet::Packet::Fill(scalar_); + } + MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { + return scalar_; + } + + private: + DType scalar_; +}; + +template +class PacketPlan, DType, Arch> { + public: + PacketPlan(const PacketPlan &lhs, const PacketPlan &rhs) + : lhs_(lhs), rhs_(rhs) {} + MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const { + return packet::PacketOp::Map(lhs_.EvalPacket(y, x), rhs_.EvalPacket(y, x)); + } + MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { + return OP::Map(lhs_.Eval(y, x), rhs_.Eval(y, x)); + } + + private: + PacketPlan lhs_; + PacketPlan rhs_; +}; + +template +class PacketPlan, DType, Arch> { + public: + PacketPlan(const PacketPlan &src) : src_(src) {} + MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const { + return packet::PacketOp::Map(src_.EvalPacket(y, x)); + } + MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { + return OP::Map(src_.Eval(y, x)); + } + + private: + PacketPlan src_; +}; + +template +inline PacketPlan, DType, Arch> +MakePacketPlan(const BinaryMapExp &e); + +template +inline PacketPlan, DType, Arch> MakePacketPlan(const ScalarExp &e) { + return PacketPlan, DType, Arch>(e.scalar_); +} +template +inline PacketPlan MakePacketPlan(const RValueExp &e) { + return PacketPlan(e.self()); +} +template +inline PacketPlan +MakePacketPlan(const MakeTensorExp &e) { + return PacketPlan(e.real_self()); +} +template +inline PacketPlan, DType, Arch> +MakePacketPlan(const UnaryMapExp &e) { + return PacketPlan, DType, Arch>(MakePacketPlan(e.src_)); +} +template +inline PacketPlan, DType, Arch> +MakePacketPlan(const BinaryMapExp &e) { + return PacketPlan, + DType, Arch>(MakePacketPlan(e.lhs_), MakePacketPlan(e.rhs_)); +} + +/*! + * \brief static check packet enable + * + * \tparam Device the type of Device + * \tparam dim dimension of the tensor + * \tparam E expression + */ +template +struct PacketCheck{ + static const bool kPass = false; +}; +template +struct PacketCheck { + static const bool kPass = true; +}; +template +struct PacketCheck { + static const bool kPass = true; +}; +template +struct PacketCheck, Arch> { + static const bool kPass = PacketCheck::kPass; +}; +template +struct PacketCheck, Arch> { + static const bool kPass = PacketCheck::kPass; +}; +template +struct PacketCheck, Arch> { + static const bool kPass = PacketCheck::kPass && + packet::PacketOp::kEnabled; +}; +template +struct PacketCheck< BinaryMapExp, Arch> { + static const bool kPass = packet::PacketOp::kEnabled && + PacketCheck::kPass && PacketCheck::kPass; +}; +//---------------------------------------------------- +// Check if data is aligned and allow packet operation +//---------------------------------------------------- +template +struct PacketAlignCheck { + inline static bool Check(const E &exp) { + return false; + } +}; +template +struct PacketAlignCheck, Arch> { + inline static bool Check(const ScalarExp &exp) { + return true; + } +}; +template +struct PacketAlignCheck, Arch> { + inline static bool Check(const Tensor &t) { + return packet::CheckAlign(t.dptr_) && + packet::CheckAlign(t.stride_ * sizeof(DType)); + } +}; +template +struct PacketAlignCheck, Arch> { + inline static bool Check(const UnaryMapExp &t) { + return PacketAlignCheck::Check(t.src_); + } +}; +template +struct PacketAlignCheck, Arch> { + inline static bool Check(const BinaryMapExp &t) { + return PacketAlignCheck::Check(t.lhs_) && + PacketAlignCheck::Check(t.rhs_); + } +}; + +/*! + * \brief use PacketPlan to compute result + */ +template +inline void MapPacketPlan(Tensor _dst, + const expr::PacketPlan& plan) { + Tensor dst = _dst.FlatTo2D(); + const index_t xlen = packet::LowerAlign(dst.size(1)); + const size_t packetSize = packet::Packet::size; +#ifndef __CUDACC__ + #pragma omp parallel for +#endif + for (openmp_index_t y = 0; y < dst.size(0); ++y) { + for (index_t x = 0; x < xlen; x += packetSize) { + packet::Saver::Save(&dst[y][x], plan.EvalPacket(y, x)); + } + for (index_t x = xlen; x < dst.size(1); ++x) { + SV::Save(dst[y][x], plan.Eval(y, x)); + } + } +} +} // namespace expr +} // namespace mshadow +#endif // MSHADOW_PACKET_INL_H_ diff --git a/include/mshadow/packet/plain-inl.h b/include/mshadow/packet/plain-inl.h new file mode 100644 index 000000000000..de28ad7b4894 --- /dev/null +++ b/include/mshadow/packet/plain-inl.h @@ -0,0 +1,76 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file plain-inl.h + * \brief support of plain packet that use the plain datatype. + */ +#ifndef MSHADOW_PACKET_PLAIN_INL_H_ +#define MSHADOW_PACKET_PLAIN_INL_H_ + +#include "../base.h" +#include "../packet-inl.h" + +namespace mshadow { +namespace packet { +template +struct Packet { + public: + /*! \brief number of float in vector */ + static constexpr index_t size = 1; + /*! \brief The internal data */ + DType data_; + // enable default copy constructor + Packet(void) {} + // constructor from the intrinsic type + explicit Packet(DType data) : data_(data) {} + // create a fill with the target value s + MSHADOW_CINLINE static Packet Fill(DType s) { + return Packet(s); + } + // load from address + MSHADOW_CINLINE static Packet Load(const DType* src) { + return Packet(*src); + } + // load from address + MSHADOW_CINLINE static Packet LoadUnAligned(const DType* src) { + return Packet(*src); + } + // fill it with value s + MSHADOW_CINLINE Packet& operator=(DType s) { + data_ = s; + return *this; + } + // store data into dst + MSHADOW_CINLINE void Store(DType* dst) const { + *dst = data_; + } + // get the sum of all contents + MSHADOW_CINLINE DType Sum() const { + return data_; + } +}; + +template +MSHADOW_CINLINE Packet operator+(const Packet& lhs, + const Packet& rhs) { + return Packet(lhs.data_ + rhs.data_); +} + +template +MSHADOW_CINLINE Packet operator-(const Packet& lhs, + const Packet& rhs) { + return Packet(lhs.data_ - rhs.data_); +} +template +MSHADOW_CINLINE Packet operator*(const Packet& lhs, + const Packet& rhs) { + return Packet(lhs.data_ * rhs.data_); +} + +template +MSHADOW_CINLINE Packet operator/(const Packet& lhs, + const Packet& rhs) { + return Packet(lhs.data_ / rhs.data_); +} +} // namespace packet +} // namespace mshadow +#endif // MSHADOW_PACKET_PLAIN_INL_H_ diff --git a/include/mshadow/packet/sse-inl.h b/include/mshadow/packet/sse-inl.h new file mode 100644 index 000000000000..923a5f60de38 --- /dev/null +++ b/include/mshadow/packet/sse-inl.h @@ -0,0 +1,147 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file sse-inl.h + * \brief support of sse2 packet optimization of some operations + * \author Tianqi Chen + */ +#ifndef MSHADOW_PACKET_SSE_INL_H_ +#define MSHADOW_PACKET_SSE_INL_H_ + +#include +#include "../base.h" +#include "../packet-inl.h" + +namespace mshadow { +namespace packet { +template<> +struct Packet { + public: + /*! \brief number of float in vector */ + static constexpr index_t size = 4; + /*! \brief The internal data */ + __m128 data_; + // enable default copy constructor + Packet(void) {} + // constructor from the intrinsic type + explicit Packet(__m128 data) : data_(data) {} + // create a fill with the target value s + MSHADOW_CINLINE static Packet Fill(float s) { + return Packet(_mm_set1_ps(s)); + } + // load from address + MSHADOW_CINLINE static Packet Load(const float* src) { + return Packet(_mm_load_ps(src)); + } + // load from address + MSHADOW_CINLINE static Packet LoadUnAligned(const float* src) { + return Packet(_mm_loadu_ps(src)); + } + // fill it with value s + MSHADOW_CINLINE Packet& operator=(float s) { + data_ = _mm_set1_ps(s); + return *this; + } + // store data into dst + MSHADOW_CINLINE void Store(float* dst) const { + _mm_store_ps(dst, data_); + } + // get the sum of all contents + MSHADOW_CINLINE float Sum() const { + __m128 ans = _mm_add_ps(data_, _mm_movehl_ps(data_, data_)); + __m128 rst = _mm_add_ss(ans, _mm_shuffle_ps(ans, ans, 1)); +#if defined(_MSC_VER) && (_MSC_VER <= 1500) && defined(_WIN64) + return rst.m128_f32[0]; +#else + float rr = _mm_cvtss_f32(rst); + return rr; +#endif + } +}; + + +/*! \brief vector real type for float */ +template<> +struct Packet { + /*! \brief number of float in vector */ + static constexpr index_t size = 2; + // internal data + __m128d data_; + // constructor + Packet(void) {} + explicit Packet(__m128d data) : data_(data) {} + // create a fill with the target value s + MSHADOW_CINLINE static Packet Fill(double s) { + return Packet(_mm_set1_pd(s)); + } + // load from address + MSHADOW_CINLINE static Packet Load(const double* src) { + return Packet(_mm_load_pd(src)); + } + MSHADOW_CINLINE static Packet LoadUnAligned(const double* src) { + return Packet(_mm_loadu_pd(src)); + } + // fill it with value s + MSHADOW_CINLINE Packet& operator=(double s) { + data_ = _mm_set1_pd(s); + return *this; + } + // store data into dst + MSHADOW_CINLINE void Store(double* dst) const { + _mm_store_pd(dst, data_); + } + // get sum of all content + inline double Sum(void) const { + __m128d tmp = _mm_add_sd(data_, _mm_unpackhi_pd(data_, data_)); +#if defined(_MSC_VER) && (_MSC_VER <= 1500) && defined(_WIN64) + return tmp.m128d_f64[0]; +#else + double ans = _mm_cvtsd_f64(tmp); + return ans; +#endif + } +}; + +MSHADOW_CINLINE Packet operator+(const Packet& lhs, + const Packet& rhs) { + return Packet(_mm_add_ps(lhs.data_, rhs.data_)); +} + +MSHADOW_CINLINE Packet operator+(const Packet& lhs, + const Packet& rhs) { + return Packet(_mm_add_pd(lhs.data_, rhs.data_)); +} + +MSHADOW_CINLINE Packet operator-(const Packet& lhs, + const Packet& rhs) { + return Packet(_mm_sub_ps(lhs.data_, rhs.data_)); +} + +MSHADOW_CINLINE Packet operator-(const Packet& lhs, + const Packet& rhs) { + return Packet(_mm_sub_pd(lhs.data_, rhs.data_)); +} + +MSHADOW_CINLINE Packet operator*(const Packet& lhs, + const Packet& rhs) { + return Packet(_mm_mul_ps(lhs.data_, rhs.data_)); +} + +MSHADOW_CINLINE Packet operator*(const Packet& lhs, + const Packet& rhs) { + return Packet(_mm_mul_pd(lhs.data_, rhs.data_)); +} + + +MSHADOW_CINLINE Packet operator/(const Packet& lhs, + const Packet& rhs) { + return Packet(_mm_div_ps(lhs.data_, rhs.data_)); +} + +MSHADOW_CINLINE Packet operator/(const Packet& lhs, + const Packet& rhs) { + return Packet(_mm_div_pd(lhs.data_, rhs.data_)); +} + +} // namespace packet +} // namespace mshadow +#endif // MSHADOW_PACKET_SSE_INL_H_ diff --git a/include/mshadow/random.h b/include/mshadow/random.h new file mode 100644 index 000000000000..c136f4f67809 --- /dev/null +++ b/include/mshadow/random.h @@ -0,0 +1,570 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file random.h + * \brief Random inline functions for tensor. + * \author Bing Xu, Tianqi Chen + * Based on curand|MKL|stdlib + */ +#ifndef MSHADOW_RANDOM_H_ +#define MSHADOW_RANDOM_H_ + +#include +#include +#include +#include "./base.h" +#include "./tensor.h" +#include "./tensor_container.h" + +#if MSHADOW_IN_CXX11 +#include // use cxx11 random by default +#endif + +#if _MSC_VER +#define rand_r(x) rand() +#endif + + +namespace mshadow { +/*! + * \brief random number generator + * \tparam Device the device of random number generator + * \tparam DType the target data type of random number can be float for double + */ +template +class Random {}; + +/*! \brief CPU random number generator */ +template +class Random { + public: + /*! + * \brief constructor of random engine + * \param seed random number seed + */ + explicit Random(int seed) { + this->Seed(seed); + buffer_.Resize(Shape1(kRandBufferSize)); + } + ~Random(void) { + } + /*! + * \brief seed random number generator using this seed + * \param seed seed of prng + */ + inline void Seed(int seed) { +#if MSHADOW_IN_CXX11 + rnd_engine_.seed(seed); +#endif + this->rseed_ = static_cast(seed); + } + /*! + * \brief get random seed used in random generator + * \return seed in unsigned + */ + inline unsigned GetSeed() const { + return rseed_; + } + /*! + * \brief set the stream of computation + * \param stream computation stream + */ + inline void set_stream(Stream *stream) { + } + +// These samplers are only avail in C++11. +#if MSHADOW_IN_CXX11 + + /*! + * \brief get some random integer + * \return integer as unsigned + */ + inline unsigned GetRandInt() { + return rnd_engine_(); + } + + /*! + * \brief get a set of random integers + */ + inline void GetRandInt(const Tensor& dst) { + std::generate_n(dst.dptr_, dst.size(0), [&](){ return rnd_engine_(); }); + } + + /*! + * \brief generate data from a distribution + * \param dst destination + * \tparam dim dimension of tensor + * \param sampler sampler of the distribution + */ + template + inline void SampleDistribution(Tensor *dst, Sampler sampler) { + if (dst->CheckContiguous()) { + std::generate_n(dst->dptr_, dst->shape_.Size(), sampler); + } else { + Tensor mat = dst->FlatTo2D(); + for (index_t i = 0; i < mat.size(0); ++i) { + std::generate_n(mat[i].dptr_, mat.size(1), sampler); + } + } + } + + /*! + * \brief generate data from uniform [a,b) + * \param dst destination + * \param a lower bound of uniform + * \param b upper bound of uniform + * \tparam dim dimension of tensor + */ + template + inline void SampleUniform(Tensor *dst, + PType a = 0.0f , PType b = 1.0f ) { + // Ensure that half_t is handled correctly. + typedef typename std::conditional::value, + DType, double>::type FType; + typedef typename std::conditional::value, + std::uniform_int_distribution, + std::uniform_real_distribution>::type GType; + GType dist_uniform(a, b); + SampleDistribution(dst, [&](){ return dist_uniform(rnd_engine_);}); + } + + /*! + * \brief generate data from standard gaussian + * \param dst destination + * \param mu mean variable + * \param sigma standard deviation + * \tparam dim dimension of tensor + */ + template + inline void SampleGaussian(Tensor *dst, + PType mu = 0.0f, PType sigma = 1.0f ) { + if (sigma <= 0) { + *dst = mu; return; + } + typedef typename std::conditional::value, + DType, double>::type GType; + std::normal_distribution dist_normal(mu, sigma); + SampleDistribution(dst, [&](){ return dist_normal(rnd_engine_);}); + } + + /*! + * \brief generate data from a gamma distribution + * \param dst destination + * \param alpha (shape) parameter + * \param beta (scale) parameter + * \tparam dim dimension of tensor + */ + template + inline void SampleGamma(Tensor *dst, + PType alpha, PType beta) { + typedef typename std::conditional::value, + DType, double>::type GType; + std::gamma_distribution dist_gamma(alpha, beta); + SampleDistribution(dst, [&](){ return dist_gamma(rnd_engine_);}); + } + + /*! + * \brief generate data from an exponential distribution + * \param dst destination + * \param lambda parameter (rate) of the distribution + * \tparam dim dimension of tensor + */ + template + inline void SampleExponential(Tensor *dst, PType lambda ) { + typedef typename std::conditional::value, + DType, double>::type GType; + std::exponential_distribution dist_exp(lambda); + SampleDistribution(dst, [&](){ return dist_exp(rnd_engine_);}); + } + + /*! + * \brief generate data from a poisson distribution + * \param dst destination + * \param lambda parameter (rate) of the distribution + * \tparam dim dimension of tensor + */ + template + inline void SamplePoisson(Tensor *dst, PType lambda) { + typedef typename std::conditional::value, DType, int>::type GType; + std::poisson_distribution dist_poisson(lambda); + SampleDistribution(dst, [&](){ return static_cast(dist_poisson(rnd_engine_));}); + } + + /*! + * \brief generate data from a negative binomial distribution + * \param dst destination + * \param k limit on number of failures + * \param p success probability + * \tparam dim dimension of tensor + */ + template + inline void SampleNegativeBinomial(Tensor *dst, PType1 k, PType2 p) { + typedef typename std::conditional::value, DType, int>::type GType; + std::negative_binomial_distribution dist_negbinomial(k, p); + SampleDistribution(dst, [&](){ return static_cast(dist_negbinomial(rnd_engine_));}); + } + + /*! + * \brief generate data from a generalized negative binomial distribution + * \param dst destination + * \param mu parameter (mean) of the distribution + * \param alpha parameter (over dispersion) of the distribution + * (for alpha=0 this gives a Poisson) + * \tparam dim dimension of tensor + */ + template + inline void SampleGeneralizedNegativeBinomial(Tensor *dst, + PType mu, PType alpha) { + if (alpha == PType(0)) { + SamplePoisson(dst, mu); // limit of Poisson + } else { + PType r(PType(1) / alpha); + PType beta = mu * alpha; + std::gamma_distribution<> dist_gamma(r, beta); + typedef typename std::conditional::value, DType, int>::type GType; + SampleDistribution(dst, + [&](){ std::poisson_distribution dist_poisson(dist_gamma(rnd_engine_)); + return static_cast(dist_poisson(rnd_engine_));}); + } + } +#endif + + /*! + * \brief return a temporal expression storing standard gaussian random variables + * the temporal tensor is only valid before next call of gaussian or uniform + * can be used as part of expression + * Caution: this means expression such as A = gaussian(s1) * gaussian(s2) will give invalid result, + * since second call of gaussian(s2) makes gaussian(s1) invalid + * A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression + * \param shape shape of the tensor + * \return a temporal expression storing standard gaussian random variables + * \tparam dim dimension of tensor + */ + template + inline expr::ReshapeExp, DType, dim, 1> + gaussian(Shape shape) { + buffer_.Resize(Shape1(shape.Size())); + this->SampleGaussian(&buffer_, 0.0f, 1.0f); + return expr::reshape(buffer_, shape); + } + /*! + * \brief return a temporal expression storing standard uniform [0,1) + * the temporal tensor is only valid before next call of gaussian or uniform + * can be used as part of expression + * Caution: this means expression such as A = uniform(s1) * uniform(s2) will give invalid result, + * since second call of gaussian(s2) makes gaussian(s1) invalid + * A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression + * \param shape shape of the tensor + * \return a temporal expression storing standard uniform [0,1) + * \tparam dim dimension of tensor + */ + template + inline expr::ReshapeExp, DType, dim, 1> + uniform(Shape shape) { + buffer_.Resize(Shape1(shape.Size())); + this->SampleUniform(&buffer_, 0.0f, 1.0f); + return expr::reshape(buffer_, shape); + } + + std::mt19937 &GetRndEngine() { + return rnd_engine_; + } + + private: +#if MSHADOW_IN_CXX11 + /*! \brief use c++11 random engine. */ + std::mt19937 rnd_engine_; + /*! \brief random number seed used in random engine */ + unsigned rseed_; + +#else + + /*! \brief random number seed used by PRNG */ + unsigned rseed_; + // functions + template + inline void SampleUniform(Tensor *dst, + DType a = 0.0f, DType b = 1.0f) { + if (dst->CheckContiguous()) { + this->GenUniform(dst->dptr_, dst->shape_.Size(), a, b); + } else { + Tensor mat = dst->FlatTo2D(); + for (index_t i = 0; i < mat.size(0); ++i) { + this->GenUniform(mat[i].dptr_, mat.size(1), a, b); + } + } + } + template + inline void SampleGaussian(Tensor *dst, + DType mu = 0.0f, DType sigma = 1.0f) { + if (sigma <= 0.0f) { + *dst = mu; return; + } + if (dst->CheckContiguous()) { + this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma); + } else { + Tensor mat = dst->FlatTo2D(); + for (index_t i = 0; i < mat.size(0); ++i) { + this->GenGaussian(mat[i].dptr_, mat.size(1), mu, sigma); + } + } + } + inline void GenUniform(float *dptr, index_t size, float a, float b) { + for (index_t j = 0; j < size; ++j) { + dptr[j] = static_cast(RandNext()) * (b - a) + a; + } + } + inline void GenUniform(double *dptr, index_t size, double a, double b) { + for (index_t j = 0; j < size; ++j) { + dptr[j] = static_cast(RandNext()) * (b - a) + a; + } + } + inline void GenGaussian(float *dptr, index_t size, float mu, float sigma) { + this->GenGaussianX(dptr, size, mu, sigma); + } + inline void GenGaussian(double *dptr, index_t size, double mu, double sigma) { + this->GenGaussianX(dptr, size, mu, sigma); + } + inline void GenGaussianX(DType *dptr, index_t size, DType mu, DType sigma) { + DType g1 = 0.0f, g2 = 0.0f; + for (index_t j = 0; j < size; ++j) { + if ((j & 1) == 0) { + this->SampleNormal2D(&g1, &g2); + dptr[j] = mu + g1 * sigma; + } else { + dptr[j] = mu + g2 * sigma; + } + } + } + /*! \brief get next random number from rand */ + inline DType RandNext(void) { + return static_cast(rand_r(&rseed_)) / + (static_cast(RAND_MAX) + 1.0f); + } + /*! \brief return a real numer uniform in (0,1) */ + inline DType RandNext2(void) { + return (static_cast(rand_r(&rseed_)) + 1.0f) / + (static_cast(RAND_MAX) + 2.0f); + } + /*! + * \brief sample iid xx,yy ~N(0,1) + * \param xx first gaussian output + * \param yy second gaussian output + */ + inline void SampleNormal2D(DType *xx_, DType *yy_) { + DType &xx = *xx_, &yy = *yy_; + DType x, y, s; + do { + x = 2.0f * RandNext2() - 1.0f; + y = 2.0f * RandNext2() - 1.0f; + s = x * x + y * y; + } while (s >= 1.0f || s == 0.0f); + DType t = std::sqrt(-2.0f * std::log(s) / s); + xx = x * t; yy = y * t; + } +#endif + /*! \brief temporal space used to store random numbers */ + TensorContainer buffer_; +}; // class Random + +// only allow GPU PRNG when cuda is enabled +#if MSHADOW_USE_CUDA +/*! \brief GPU random number generator */ +template +class Random { + public: + /*! + * \brief constructor of random engine + * \param seed random number seed + */ + explicit Random(int seed) : gen_(NULL) { + this->Seed(seed); + buffer_.Resize(Shape1(kRandBufferSize)); + } + ~Random(void) MSHADOW_THROW_EXCEPTION { + DeleteGenerator(); + } + /*! + * \brief set the stream of computation + * \param stream computation stream + */ + inline void set_stream(Stream *stream) { + curandStatus_t status; + status = curandSetStream(gen_, Stream::GetStream(stream)); + + CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "set_stream CURAND failed"; + } + /*! + * \brief seed random number generator using this seed + * \param seed seed of prng + */ + inline void Seed(int seed) { + // Create a new rng, either initially or if the RNG type can't reset its offset. + if (gen_ == NULL || (curandSetGeneratorOffset(gen_, 0ULL) != CURAND_STATUS_SUCCESS)) + CreateGenerator(); + // Now set the seed. + curandStatus_t status; + status = curandSetPseudoRandomGeneratorSeed(gen_, static_cast(seed)); + CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Set CURAND seed failed."; + } + /*! + * \brief get a set of random integers + */ + inline void GetRandInt(const Tensor& dst) { + curandStatus_t status = curandGenerate(gen_, dst.dptr_, dst.size(0)); + CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen rand ints failed."; + } + /*! + * \brief generate data from uniform [a,b) + * \param dst destination + * \param a lower bound of uniform + * \param b upper bound of uniform + * \tparam dim dimension of tensor + */ + template + inline void SampleUniform(Tensor *dst, + DType a = 0.0f, DType b = 1.0f); + + /*! + * \brief generate data from standard gaussian + * \param dst destination + * \param mu mean variable + * \param sigma standard deviation + * \tparam dim dimension of tensor + */ + template + inline void SampleGaussian(Tensor *dst, + DType mu = 0.0f, DType sigma = 1.0f); + /*! + * \brief return a temporal expression storing standard gaussian random variables + * the temporal tensor is only valid before next call of gaussian or uniform + * can be used as part of expression + * Caution: this means expression such as A = gaussian(s1) * gaussian(s2) will give invalid result, + * since second call of gaussian(s2) makes gaussian(s1) invalid + * A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression + * \param shape shape of the tensor + * \param mu mean + * \param sigma variance + * \return a temporal expression storing standard gaussian random variables + * \tparam dim dimension of tensor + */ + template + inline expr::ReshapeExp, DType, dim, 1> + gaussian(Shape shape, DType mu = 0.0f, DType sigma = 1.0f); + /*! + * \brief return a temporal expression storing standard uniform [0,1) + * the temporal tensor is only valid before next call of gaussian or uniform + * can be used as part of expression + * Caution: this means expression such as A = gaussian(s1) * gaussian(s2) will give invalid result, + * since second call of gaussian(s2) makes gaussian(s1) invalid + * A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression + * \param shape shape of the tensor + * \return a temporal expression storing standard uniform [0,1) + * \tparam dim dimension of tensor + */ + template + inline expr::ReshapeExp, DType, dim, 1> + uniform(Shape shape); + + private: + inline void GenGaussian(float *dptr, size_t size, float mu, float sigma) { + curandStatus_t status; + status = curandGenerateNormal(gen_, dptr, size, mu, sigma); + CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal float failed." + << " size = " << size + << ",mu = " << mu + << ",sigma = " << sigma; + } + inline void GenGaussian(double *dptr, size_t size, double mu, double sigma) { + curandStatus_t status; + status = curandGenerateNormalDouble(gen_, dptr, size, mu, sigma); + CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal double failed." + << " size = " << size + << ",mu = " << mu + << ",sigma = " << sigma; + } + inline void GenUniform(float *dptr, size_t size) { + curandStatus_t status; + status = curandGenerateUniform(gen_, dptr, size); + CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform float failed." + << " size = " << size; + } + inline void GenUniform(double *dptr, size_t size) { + curandStatus_t status; + status = curandGenerateUniformDouble(gen_, dptr, size); + CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform double failed." + << " size = " << size; + } + inline void CreateGenerator() { + if (gen_ != NULL) + DeleteGenerator(); + curandStatus_t status; + status = curandCreateGenerator(&gen_, CURAND_RNG_PSEUDO_DEFAULT); + CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Cannot create CURAND Generator"; + } + inline void DeleteGenerator() { + if (gen_ != NULL) { + curandStatus_t status; + status = curandDestroyGenerator(gen_); + CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Destory CURAND Gen failed"; + gen_ = NULL; + } + } + /*! \brief random number generator */ + curandGenerator_t gen_; + /*! \brief templ buffer */ + TensorContainer buffer_; +}; // class Random +#endif // MSHADOW_USE_CUDA + +#ifdef __CUDACC__ +// implementations that depends on cuda kernels +template +template +inline void Random::SampleUniform( + Tensor *dst, DType a, DType b) { + if (a == 0.0f && b == 1.0f) { + if (dst->CheckContiguous()) { + this->GenUniform(dst->dptr_, dst->shape_.Size()); + } else { + *dst = this->uniform(dst->shape_); + } + } else { + *dst = this->uniform(dst->shape_) * (b - a) + a; + } +} +template +template +inline void Random::SampleGaussian( + Tensor *dst, DType mu, DType sigma) { + // We need to check whether the shape size is even since CuRand supports only normal distribution + // generation of even number of elements. + if (dst->CheckContiguous() && (dst->shape_.Size() % 2 == 0)) { + this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma); + } else { + *dst = this->gaussian(dst->shape_, mu, sigma); + } +} + +template +template +inline expr::ReshapeExp, DType, dim, 1> +Random::gaussian(Shape shape, DType mu, DType sigma) { + size_t aligned_sz = ((shape.Size() + 1UL) >> 1) << 1; + // allocate alligned size + buffer_.Resize(Shape1(aligned_sz)); + buffer_.Resize(Shape1(shape.Size())); + this->GenGaussian(buffer_.dptr_, aligned_sz, mu, sigma); + return expr::reshape(buffer_, shape); +} + +template +template +inline expr::ReshapeExp, DType, dim, 1> +Random::uniform(Shape shape) { + buffer_.Resize(Shape1(shape.Size())); + this->GenUniform(buffer_.dptr_, buffer_.size(0)); + return expr::reshape(buffer_, shape); +} +#endif // __CUDACC__ +} // namespace mshadow +#endif // MSHADOW_RANDOM_H_ diff --git a/include/mshadow/stream_gpu-inl.h b/include/mshadow/stream_gpu-inl.h new file mode 100644 index 000000000000..d20d2d788526 --- /dev/null +++ b/include/mshadow/stream_gpu-inl.h @@ -0,0 +1,212 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file stream_gpu-inl.h + * \brief implementation of GPU code + * \author Bing Xu, Tianqi Chen + */ +#ifndef MSHADOW_STREAM_GPU_INL_H_ +#define MSHADOW_STREAM_GPU_INL_H_ +#include +#include "./base.h" +#include "./tensor.h" +#include "./logging.h" + +namespace mshadow { +#if MSHADOW_USE_CUDA == 1 +// Stream alocation +// actual implementation of GPU stream in CUDA +template<> +struct Stream { + /*! \brief handle state */ + enum HandleState { + NoHandle = 0, + OwnHandle = 1, + }; + /*! \brief cudaStream */ + cudaStream_t stream_; + /*! \brief cublas handle */ + cublasHandle_t blas_handle_; + /*! \brief cusolver handle */ + #if MSHADOW_USE_CUSOLVER == 1 + cusolverDnHandle_t solver_handle_; + #endif + /*! \brief cudnn handle */ + #if MSHADOW_USE_CUDNN == 1 + cudnnHandle_t dnn_handle_; + #endif + /*! \brief cublas handle ownership */ + HandleState blas_handle_ownership_; + /*! \brief cusolver handle ownership */ + HandleState solver_handle_ownership_; + /*! \brief cudnn handle ownership */ + HandleState dnn_handle_ownership_; + /*! \brief cudaDeviceProp */ + cudaDeviceProp prop; + /*! \brief dev id */ + int dev_id; + + Stream(void) + : stream_(0) + , blas_handle_(0) +#if MSHADOW_USE_CUDNN == 1 + , dnn_handle_(0) +#endif + , blas_handle_ownership_(NoHandle) + , solver_handle_ownership_(NoHandle) + , dnn_handle_ownership_(NoHandle) {} + /*! + * \brief wait for all the computation associated + * with this stream to complete + */ + inline void Wait(void) { + MSHADOW_CUDA_CALL(cudaStreamSynchronize(stream_)); + } + /*! + * \brief query whether the the stream is idle + * \return true if the stream is idle and all the job have been completed + */ + inline bool CheckIdle(void) { + cudaError_t err = cudaStreamQuery(stream_); + if (err == cudaSuccess) return true; + if (err == cudaErrorNotReady) return false; + LOG(FATAL) << cudaGetErrorString(err); + return false; + } + /*! + * \brief returns actual cudaStream_t given an input GPU stream pointer + * \param stream pointer to GPU stream + */ + inline static cudaStream_t GetStream(Stream *stream) { + if (stream == NULL) { +#if MSHADOW_FORCE_STREAM + LOG(FATAL) << "Default GPU stream was used when MSHADOW_FORCE_STREAM was on"; +#endif + return 0; + } else { + return stream->stream_; + } + } + /*! + * \brief return actual cublasHandle + * \param pointer to GPU stream + */ + inline static cublasHandle_t GetBlasHandle(Stream *stream) { + if (stream == NULL) { + return 0; + } else { + CHECK_NE(stream->blas_handle_ownership_, NoHandle) + << "No handle exist in source stream"; + return stream->blas_handle_; + } + } + /*! \brief Destory cublas handle if own it */ + inline void DestroyBlasHandle() { + if (blas_handle_ownership_ == OwnHandle) { + cublasStatus_t err = cublasDestroy(blas_handle_); + blas_handle_ownership_ = NoHandle; + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Destory cublas handle failed"; + } + } + /*! \brief Destory original blas handle and create a new one */ + inline void CreateBlasHandle() { + this->DestroyBlasHandle(); + cublasStatus_t err = cublasCreate(&blas_handle_); + blas_handle_ownership_ = OwnHandle; + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Create cublas handle failed"; + } +#if MSHADOW_USE_CUSOLVER == 1 + inline static cusolverDnHandle_t GetSolverHandle(Stream *stream) { + if (stream == NULL) { + return 0; + } else { + CHECK_NE(stream->solver_handle_ownership_, NoHandle) << "No handle exist in source stream"; + return stream->solver_handle_; + } + } +#endif + inline void DestroySolverHandle() { +#if MSHADOW_USE_CUSOLVER == 1 + if (solver_handle_ownership_ == OwnHandle) { + cusolverStatus_t err = cusolverDnDestroy(solver_handle_); + CHECK_EQ(err, CUSOLVER_STATUS_SUCCESS) << "Destory cusolver handle failed"; + } +#endif + } + inline void CreateSolverHandle() { +#if MSHADOW_USE_CUSOLVER == 1 + this->DestroySolverHandle(); + cusolverStatus_t err = cusolverDnCreate(&solver_handle_); + CHECK_EQ(err, CUSOLVER_STATUS_SUCCESS) << "Create cusolver handle failed"; + err = cusolverDnSetStream(solver_handle_, stream_); + CHECK_EQ(err, CUSOLVER_STATUS_SUCCESS) << "Setting cusolver stream failed"; + this->solver_handle_ownership_ = OwnHandle; +#endif + } +// #if MSHADOW_USE_CUDNN && defined(__CUDACC__) +#if MSHADOW_USE_CUDNN == 1 + inline static cudnnHandle_t GetDnnHandle(Stream *stream) { + if (stream == NULL) { + return 0; + } else { + CHECK_NE(stream->dnn_handle_ownership_, NoHandle) << "No handle exist in source stream"; + return stream->dnn_handle_; + } + } +#endif + inline void DestroyDnnHandle() { +// #if MSHADOW_USE_CUDNN && defined(__CUDACC__) +#if MSHADOW_USE_CUDNN == 1 + if (dnn_handle_ownership_ == OwnHandle) { + cudnnStatus_t err = cudnnDestroy(dnn_handle_); + this->dnn_handle_ownership_ = NoHandle; + CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err); + } +#endif + } + inline void CreateDnnHandle() { +// #if MSHADOW_USE_CUDNN == 1 && defined(__CUDACC__) +#if MSHADOW_USE_CUDNN == 1 + this->DestroyDnnHandle(); + cudnnStatus_t err = cudnnCreate(&dnn_handle_); + CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err); + // At this point, we have the resource which may need to be freed + this->dnn_handle_ownership_ = OwnHandle; + err = cudnnSetStream(dnn_handle_, stream_); + CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err); +#endif + } +}; +template<> +inline void DeleteStream(Stream *stream) { + if (stream) { + MSHADOW_CUDA_CALL(cudaStreamDestroy(stream->stream_)); + stream->DestroyBlasHandle(); + stream->DestroySolverHandle(); + stream->DestroyDnnHandle(); + delete stream; + } +} +template<> +inline Stream *NewStream(bool create_blas_handle, + bool create_dnn_handle, + int dev_id) { + // RAII on Cuda exception + struct StreamDeleter { void operator()(Stream *ptr) const { DeleteStream(ptr); } }; + std::unique_ptr, StreamDeleter> st(new Stream()); + MSHADOW_CUDA_CALL(cudaStreamCreate(&st->stream_)); + if (create_blas_handle) { + st->CreateBlasHandle(); + st->CreateSolverHandle(); + } + if (create_dnn_handle) { + st->CreateDnnHandle(); + } + st->dev_id = dev_id; + if (dev_id != -1) { + MSHADOW_CUDA_CALL(cudaGetDeviceProperties(&st->prop, dev_id)); + } + return st.release(); +} +#endif +} // namespace mshadow +#endif // MSHADOW_STREAM_GPU_INL_H_ diff --git a/include/mshadow/tensor.h b/include/mshadow/tensor.h new file mode 100755 index 000000000000..f74281d36693 --- /dev/null +++ b/include/mshadow/tensor.h @@ -0,0 +1,1078 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file tensor.h + * \brief header file of tensor data structure and functions + * This lib requires explicit memory allocation and de-allocation + * all the data structure Tensor, Tensor are like handles(pointers), + * no memory allocation is happening during calculation + * + * For STL style tensor, see tensor_container.h + * \author Bing Xu, Tianqi Chen + */ +#ifndef MSHADOW_TENSOR_H_ +#define MSHADOW_TENSOR_H_ +#include +#include +#include "./base.h" +#include "./expression.h" + +namespace mshadow { +/*! \brief device name CPU */ +struct cpu { + /*! \brief whether this device is CPU or not */ + static const bool kDevCPU = true; + /*! \brief device flag number, identifies this device */ + static const int kDevMask = 1 << 0; +}; +/*! \brief device name GPU */ +struct gpu { + /*! \brief whether this device is CPU or not */ + static const bool kDevCPU = false; + /*! \brief device flag number, identifies this device */ + static const int kDevMask = 1 << 1; +}; +template +struct Shape; + +/*! + * \brief allow string printing of the shape + * \param os the output stream + * \param shape the shape + * \return the ostream + */ +template +inline std::ostream &operator<<(std::ostream &os, const Shape &shape); // NOLINT(*) + +/*! + * \brief shape of a tensor + * \tparam dimension dimension of tensor + */ +template +struct Shape { + /*! \brief dimension of current shape */ + static const int kDimension = dimension; + /*! \brief dimension of current shape minus one */ + static const int kSubdim = dimension - 1; + /*! \brief storing the dimension information */ + index_t shape_[kDimension]; + /*! \brief default constructor, do nothing */ + MSHADOW_XINLINE Shape(void) {} + /*! \brief constuctor */ + MSHADOW_XINLINE Shape(const Shape &s) { + #pragma unroll + for (int i = 0; i < kDimension; ++i) { + this->shape_[i] = s[i]; + } + } + /*! + * \brief get corresponding index + * \param idx dimension index + * \return the corresponding dimension size + */ + MSHADOW_XINLINE index_t &operator[](index_t idx) { + return shape_[idx]; + } + /*! + * \brief get corresponding index + * \param idx dimension index + * \return the corresponding dimension size + */ + MSHADOW_XINLINE const index_t &operator[](index_t idx) const { + return shape_[idx]; + } + /*! + * \return whether two shape equals + * \param s the shape to compare against + */ + MSHADOW_XINLINE bool operator==(const Shape &s) const { + #pragma unroll + for (int i = 0; i < kDimension; ++i) { + if (s.shape_[i] != this->shape_[i]) return false; + } + return true; + } + /*! + * \return whether two shape not equal + * \param s the shape to compare against + */ + MSHADOW_XINLINE bool operator!=(const Shape &s) const { + return !(*this == s); + } + /*! + * flatten the tensor, return a 1D shape + * \return the flat 1d shape + */ + MSHADOW_XINLINE Shape<1> FlatTo1D(void) const { + Shape<1> s; + s[0] = this->Size(); + return s; + } + /*! + * flatten the higher dimension to second dimension, return a 2D shape + * \return the flat 2d shape + */ + MSHADOW_XINLINE Shape<2> FlatTo2D(void) const { + Shape<2> s; + s.shape_[1] = this->shape_[kDimension - 1]; + index_t ymax = 1; + #pragma unroll + for (int i = 0; i < kDimension - 1; ++i) { + ymax *= this->shape_[i]; + } + s.shape_[0] = ymax; + return s; + } + /*! \return number of valid elements */ + MSHADOW_XINLINE index_t Size(void) const { + index_t size = this->shape_[0]; + #pragma unroll + for (int i = 1; i < kDimension; ++i) { + size *= this->shape_[i]; + } + return size; + } + /*! + * \return product shape in [dimstart,dimend) + * \param dimstart start dimension + * \param dimend end dimension + */ + MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const { + index_t num = 1; + #pragma unroll + for (int i = dimstart; i < dimend; ++i) { + num *= this->shape_[i]; + } + return num; + } + /*! + * \brief get subshape that takes off largest dimension +v * \return subshape + */ + MSHADOW_XINLINE Shape SubShape(void) const { + Shape s; + // for cuda + #pragma unroll + for (int i = 0; i < kSubdim; ++i) { + s.shape_[i] = this->shape_[i + 1]; + } + return s; + } + /*! + * \brief slice the shape from start to end + * \tparam dimstart start dimension + * \tparam dimend end dimension + * \return the sliced shape + */ + template + MSHADOW_XINLINE Shape Slice(void) const { + Shape s; + #pragma unroll + for (int i = dimstart; i < dimend; ++i) { + s[i - dimstart] = this->shape_[i]; + } + return s; + } + //! \cond Doxygen_Suppress + template + friend std::ostream &operator<<(std::ostream &os, const Shape &shape); // NOLINT(*) + //! \endcond +}; // Shape +//------------------------------------------------ +// useful construction functions to generate shape +//------------------------------------------------- +/*! + * \brief construct a one dimension shape, stride will equal s0 + * \param s0 size of dimension 0 + * \return the shape construction + */ +MSHADOW_XINLINE Shape<1> Shape1(index_t s0) { + Shape<1> s; s[0] = s0; + return s; +} +/*! + * \brief construct a two dimension shape, stride will equal s0 + * \param s0 size of dimension 0 + * \param s1 size of dimension 1 + * \return the shape construction + */ +MSHADOW_XINLINE Shape<2> Shape2(index_t s0, index_t s1) { + Shape<2> s; s[0] = s0; s[1] = s1; + return s; +} +/*! + * \brief construct a three dimension shape, stride will equal s0 + * \param s0 size of dimension 0 + * \param s1 size of dimension 1 + * \param s2 size of dimension 2 + * \return the shape construction + */ +MSHADOW_XINLINE Shape<3> Shape3(index_t s0, index_t s1, index_t s2) { + Shape<3> s; + s[0] = s0; s[1] = s1; s[2] = s2; + return s; +} +/*! + * \brief construct a four dimension shape, stride will equal s0 + * \param s0 size of dimension 0 + * \param s1 size of dimension 1 + * \param s2 size of dimension 2 + * \param s3 size of dimension 3 + * \return the shape construction + */ +MSHADOW_XINLINE Shape<4> Shape4(index_t s0, index_t s1, + index_t s2, index_t s3) { + Shape<4> s; + s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3; + return s; +} +/*! +* \brief construct a five dimension shape, stride will equal s0 +* \param s0 size of dimension 0 +* \param s1 size of dimension 1 +* \param s2 size of dimension 2 +* \param s3 size of dimension 3 +* \param s4 size of dimension 4 +* \return the shape construction +*/ +MSHADOW_XINLINE Shape<5> Shape5(index_t s0, index_t s1, index_t s2, + index_t s3, index_t s4) { + Shape<5> s; + s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3; s[4] = s4; + return s; +} + +/*! +* \brief Convert shape in src_layout to shape in dst_layout +* \param src original shape +* \param src_layout layout of original shape +* \param dst_layout target layout +* \return shape in target layout +*/ +inline Shape<3> ConvertLayout(const Shape<3>& src, int src_layout, int dst_layout) { + Shape<3> dst; + switch (src_layout) { + case kNCW: + dst = src; + break; + case kNWC: + dst[0] = src[0]; + dst[1] = src[2]; + dst[2] = src[1]; + break; + default: + LOG(FATAL) << "Invalid layout for 3d shape " << src_layout; + } + switch (dst_layout) { + case kNCW: + return dst; + case kNWC: + { + index_t tmp = dst[1]; + dst[1] = dst[2]; + dst[2] = tmp; + } + break; + default: + LOG(FATAL) << "Invalid layout for 3d shape " << src_layout; + } + return dst; +} + +/*! +* \brief Convert shape in src_layout to shape in dst_layout +* \param src original shape +* \param src_layout layout of original shape +* \param dst_layout target layout +* \return shape in target layout +*/ +inline Shape<4> ConvertLayout(const Shape<4>& src, int src_layout, int dst_layout) { + Shape<4> dst; + switch (src_layout) { + case kNCHW: + dst = src; + break; + case kNHWC: + dst[0] = src[0]; + dst[2] = src[1]; + dst[3] = src[2]; + dst[1] = src[3]; + break; + default: + LOG(FATAL) << "Invalid layout for 4d shape " << src_layout; + dst = src; // fixes compiler warning + } + Shape<4> dst2; + switch (dst_layout) { + case kNCHW: + return dst; + case kNHWC: + dst2[0] = dst[0]; + dst2[1] = dst[2]; + dst2[2] = dst[3]; + dst2[3] = dst[1]; + break; + default: + LOG(FATAL) << "Invalid layout for 4d shape " << src_layout; + dst2 = src; // fixes compiler warning + } + return dst2; +} + +/*! +* \brief Convert shape in src_layout to shape in dst_layout +* \param src original shape +* \param src_layout layout of original shape +* \param dst_layout target layout +* \return shape in target layout +*/ +inline Shape<5> ConvertLayout(const Shape<5>& src, int src_layout, int dst_layout) { + Shape<5> dst; + switch (src_layout) { + case kNCDHW: + dst = src; + break; + case kNDHWC: + dst[0] = src[0]; + dst[2] = src[1]; + dst[3] = src[2]; + dst[4] = src[3]; + dst[1] = src[4]; + break; + default: + LOG(FATAL) << "Invalid layout for 5d shape " << src_layout; + } + Shape<5> dst2; + switch (dst_layout) { + case kNCDHW: + return dst; + case kNDHWC: + dst2[0] = dst[0]; + dst2[1] = dst[2]; + dst2[2] = dst[3]; + dst2[3] = dst[4]; + dst2[4] = dst[1]; + break; + default: + LOG(FATAL) << "Invalid layout for 5d shape " << src_layout; + } + return dst2; +} + +/*! + * \brief computaion stream structure, used for asynchronous computations + */ +template +struct Stream { + // this is only a dummy implementation for CPU + // for GPU, the actual implementation will be specialized in tensor_gpu-inl.h + /*! + * \brief wait for all the computations associated + * with this stream to complete + */ + inline void Wait(void) {} + /*! + * \brief query whether the the stream is idle + * \return true if the stream is idle and all the jobs have been completed + */ + inline bool CheckIdle(void) { + return true; + } + /*! \brief create a blas handle */ + inline void CreateBlasHandle() {} +}; +/*! + * \brief Tensor RValue, this is the super type of all kinds of possible tensors + * \tparam Container the tensor type + * \tparam Device which device the tensor is on + * \tparam dimension dimension of the tensor + * \tparam DType the type of elements in the tensor + */ +template +struct TRValue: public expr::RValueExp { +}; +// more compact template +/*! + * \brief general tensor + * \tparam Device which device the tensor is on + * \tparam dimension dimension of the tensor + * \tparam DType the type of elements in the tensor + */ +template +struct Tensor: public TRValue, + Device, dimension, DType> { + public: + //-------------------------------- + // struct memembers + //-------------------------------- + /*! \brief whether current type lies in cpu */ + static const bool kDevCPU = Device::kDevCPU; + /*! \brief dimension of subtype */ + static const int kSubdim = dimension - 1; + //-------------------------------- + // struct memembers + //-------------------------------- + /*! \brief pointer to the data */ + DType *dptr_; + /*! \brief shape of the tensor */ + Shape shape_; + /*! + * \brief storing the stride information in x dimension + * this is used to deal with pitch allocation in gpu or sse(align x dimension to 64bit) for efficiency + */ + index_t stride_; + /*! + * \brief stream where the computation lies + * stream is a device dependency concept where each computation + */ + Stream *stream_; + //-------------------------------- + // functions + //-------------------------------- + /*! \brief default constructor */ + MSHADOW_XINLINE Tensor(void) : stream_(NULL) {} + /*! \brief constructor from shape */ + MSHADOW_XINLINE Tensor(const Shape &shape) + : shape_(shape), stream_(NULL) {} + /*! \brief constructor from data pointer and shape, without stride */ + MSHADOW_XINLINE Tensor(DType *dptr, const Shape &shape) + : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(NULL) {} + /*! \brief constructor from data pointer and shape, without stride */ + MSHADOW_XINLINE Tensor(DType *dptr, const Shape &shape, + Stream *stream) + : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(stream) {} + /*! \brief constructor from data pointer and shape */ + MSHADOW_XINLINE Tensor(DType *dptr, + const Shape &shape, + index_t stride, Stream *stream) + : dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {} + /*! + * \brief set the stream to do computation of current tensor + * \param stream the computation stream + */ + inline void set_stream(Stream *stream) { + this->stream_ = stream; + } + /*! + * \return memory cost of the tensor, including the aligned x dimension + * \tparam startdim the starting dimension + */ + template + MSHADOW_XINLINE index_t MemSize(void) const { + index_t memsz = this->stride_; + #pragma unroll + for (int i = startdim; i < kSubdim; ++i) { + memsz *= this->shape_[i]; + } + return memsz; + } + /*! + * \return whether the tensor's memory is continuous + * x dimension same as stride + */ + MSHADOW_XINLINE bool CheckContiguous(void) const { + return this->shape_[dimension - 1] == stride_; + } + /*! + * \return memory cost of the tensor, including the aligned x dimension + */ + MSHADOW_XINLINE index_t MSize(void) const { + return this->MemSize<0>(); + } + /*! + * \brief return size of i-th dimension, start counting from highest dimension + * \param idx the dimension count from the highest dimensin + * \return the size + */ + MSHADOW_XINLINE index_t size(index_t idx) const { + return shape_[idx]; + } + /*! + * \brief flatten the tensor to 1 dimension + * \return tensor after flatten + */ + MSHADOW_XINLINE Tensor FlatTo1D(void) const { + return Tensor(dptr_, shape_.FlatTo1D(), stride_, stream_); + } + /*! + * \brief flatten the tensor to 2 dimension, collapse the higher dimensions together + * \return tensor after flatten + */ + MSHADOW_XINLINE Tensor FlatTo2D(void) const { + return Tensor(dptr_, shape_.FlatTo2D(), stride_, stream_); + } + /*! + * \brief get a element of dimension - 1 + * \param idx index + * \return the result tensor + */ + MSHADOW_XINLINE Tensor operator[](index_t idx) const { + return Tensor(dptr_ + this->MemSize<1>() * idx, + shape_.SubShape(), stride_, stream_); + } + /*! + * \brief slice the tensor in highest dimension [begin,end) + * \param begin begin position of slice + * \param end end position of slice + * \return tensor after slice + */ + MSHADOW_XINLINE Tensor + Slice(index_t begin, index_t end) const { + Shape s = this->shape_; + s[0] = end - begin; + return Tensor(dptr_ + this->MemSize<1>() * begin, + s, stride_, stream_); + } + /*!\brief implement the assignment of same type */ + inline Tensor & + operator=(const Tensor &exp) { + dptr_ = exp.dptr_; + shape_ = exp.shape_; + stride_ = exp.stride_; + stream_ = exp.stream_; + return *this; + } + /*!\brief functions to fit expression template */ + template + inline Tensor & + operator=(const expr::Exp &exp) { + return this->__assign(exp); + } + /*!\brief functions to fit expression template */ + inline Tensor &operator=(const DType &exp) { + return this->__assign(exp); + } +}; +/* + * respecialized class Tensor1D, thei is due to different implementation in operator[] + */ +template +struct Tensor: + public TRValue, Device, 1, DType> { + public: + DType *dptr_; + Shape<1> shape_; + index_t stride_; + Stream *stream_; + // constructor + MSHADOW_XINLINE Tensor(void) : stream_(NULL) {} + MSHADOW_XINLINE Tensor(const Shape<1> &shape) + : shape_(shape), stream_(NULL) {} + MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape) + : dptr_(dptr), shape_(shape), stride_(shape[0]), stream_(NULL) {} + MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape, Stream *stream) + : dptr_(dptr), shape_(shape), stride_(shape[0]), stream_(stream) {} + MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape, + index_t stride, Stream *stream) + : dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {} + inline void set_stream(Stream *stream) { + this->stream_ = stream; + } + MSHADOW_XINLINE Tensor FlatTo1D(void) const { + return *this; + } + MSHADOW_XINLINE Tensor FlatTo2D(void) const { + return Tensor(dptr_, shape_.FlatTo2D(), stride_, stream_); + } + MSHADOW_XINLINE Tensor Slice(index_t begin, index_t end) const { + Shape<1> s; + s[0] = end - begin; + return Tensor(dptr_ + begin, s, s[0], stream_); + } + MSHADOW_XINLINE bool CheckContiguous(void) const { + return true; + } + MSHADOW_XINLINE index_t MSize(void) const { + return shape_[0]; + } + MSHADOW_XINLINE index_t size(index_t i) const { + return shape_[0]; + } + MSHADOW_XINLINE DType &operator[](index_t idx) { + return dptr_[idx]; + } + MSHADOW_XINLINE const DType &operator[](index_t idx) const { + return dptr_[idx]; + } + /*!\brief implement the assignment of same type */ + inline Tensor & + operator=(const Tensor &exp) { + dptr_ = exp.dptr_; + shape_ = exp.shape_; + stride_ = exp.stride_; + stream_ = exp.stream_; + return *this; + } + template + inline Tensor & + operator=(const expr::Exp &exp) { + return this->__assign(exp); + } + inline Tensor &operator=(const DType &exp) { + return this->__assign(exp); + } +}; +//------------------------ +// Function Declarations +//----------------------- +/*! + * \brief initialize tensor engine, used to call intialization functions of dependent libs + * this function should be called before all GPU tensor operations, + * for using tensors in CPU, this call is actually not needed + * \param device_id GPU device id to be choosed + * \tparam Device the device type + */ +template +inline void InitTensorEngine(int device_id = 0); +/*! + * \brief Shutdown tensor engine on current device + * this function should be called after all GPU tensor operations, + * for using tensors in CPU, this call is actually not needed + * \tparam Device the device type + */ +template +inline void ShutdownTensorEngine(void); +/*! + * \brief set the device of current thread to work on + * \param devid the device id + * \tparam Device the device type + */ +template +inline void SetDevice(int devid); +/*! + * \brief create a new stream from system + * \param create_blas_handle whether create blas & cusolver handle in stream + * \param create_dnn_handle whether create cudnn handle in stream + * \param dev_id device id + * \return a pointer to the created stream + * \tparam Device the device type + */ +template +inline Stream *NewStream(bool create_blas_handle, + bool create_dnn_handle, + int dev_id = -1); +/*! \brief default behavior: create cublas handle + * \param dev_id device id + * \return a pointer to the created stream + */ +template +inline Stream *NewStream(int dev_id) { + return NewStream(true, false, dev_id); +} +/*! + * \brief delete the computing stream + * \param stream the stream parameter to be deleted + */ +template +inline void DeleteStream(Stream *stream); +/*! + * \brief CPU/CPU: allocate space for CTensor, according to the shape in the obj + * this function is responsible to set the stride_ in each obj.shape + * \param obj the tensor object, with shape specified + * \param pad whether padding dimension 0, to make last dimension aligned, + * padding may help improve efficiency of matrix multiplications + * if true, will allocate space with stride_ that may not equals shape[0] + * if false, will allocate continuous space + * \tparam dim specify the dim of tensor + * \tparam DType type of element in tensor + */ +template +inline void AllocSpace(Tensor *obj, + bool pad = MSHADOW_ALLOC_PAD); +/*! + * \brief CPU/CPU: allocate space for CTensor, according to the shape in the obj + * this function is responsible to set the stride_ in each obj.shape + * \param obj the tensor object, with shape specified + * \param pad whether padding dimension 0, to make last dimension aligned, + * padding may help improve efficiency of matrix multiplications + * if true, will allocate space with stride_ that may not equals shape[0] + * if false, will allocate continuous space + * \tparam dim specify the dim of tensor + * \tparam DType type of element in tensor + */ +template +inline void AllocSpace(Tensor *obj, + bool pad = MSHADOW_ALLOC_PAD); +/*! + * \brief CPU/GPU: free the space of tensor, will set obj.dptr to NULL + * \param obj the tensor object + * \tparam dim specify the dim of tensor + * \tparam DType type of element in tensor + */ +template +inline void FreeSpace(Tensor *obj); +/*! + * \brief CPU/GPU: free the space of tensor, will set obj.dptr to NULL + * \param obj the tensor object + * \tparam dim specify the dim of tensor + * \tparam DType type of element in tensor + */ +template +inline void FreeSpace(Tensor *obj); +/*! + * \brief CPU/GPU: short cut to allocate and initialize a Tensor + * \param shape: shape of tensor + * \param initv: initialization value + * \param pad : padding option + * \param stream : stream of tensor + * \tparam Device device of tensor + * \tparam DType type of element in tensor + * \tparam dim dimention of tensor + * \return a new allocated tensor + * \sa AllocSpace + */ +template +inline Tensor NewTensor(const Shape &shape, + DType initv, + bool pad = MSHADOW_ALLOC_PAD, + Stream *stream = NULL); +/*! + * \brief copy data from one tensor to another, with same shape + * \param dst target tensor + * \param src source tensor + * \param stream the stream, when specified, the copy can exhibit asynchronize behavior + * \tparam dim specify the dim of tensor + * \tparam DType type of element in tensor + */ +template +inline void Copy(Tensor dst, + const Tensor &src, + Stream *stream = NULL); +/*! + * \brief copy data from one tensor to another, with same shape + * \param dst target tensor + * \param src source tensor + * \param stream the stream, when specified, the copy can exhibit asynchronize behavior + * \tparam dim specify the dim of tensor + * \tparam DType type of element in tensor + */ +template +inline void Copy(Tensor dst, + const Tensor &src, + Stream *stream = NULL); +/*! + * \brief copy data from one tensor to another, with same shape + * \param dst target tensor + * \param src source tensor + * \param stream the stream, when specified, the copy can exhibit asynchronize behavior + * \tparam dim specify the dim of tensor + * \tparam DType type of element in tensor + */ +template +inline void Copy(Tensor dst, + const Tensor &src, + Stream *stream = NULL); +/*! + * \brief copy data from one tensor to another, with same shape + * \param dst target tensor + * \param src source tensor + * \param stream the stream, when specified, the copy can exhibit asynchronize behavior + * \tparam dim specify the dim of tensor + * \tparam DType type of element in tensor + */ +template +inline void Copy(Tensor dst, + const Tensor &src, + Stream *stream = NULL); +/*! + * \brief CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) /(sum_j exp(energy[i][j])) + * \param dst destination + * \param energy input energy + */ +template +inline void Softmax(Tensor dst, const Tensor &energy); +/*! + * \brief CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) /(sum_j exp(energy[i][j])) + * \param dst destination + * \param energy input energy + */ +template +inline void Softmax(Tensor dst, const Tensor &energy); + +/*! + * \brief CPU/GPU: softmax gradient + * \param dst destination + * \param src source output + * \param label label info + */ +template +inline void SoftmaxGrad(Tensor dst, + const Tensor &src, + const Tensor &label); +/*! + * \brief CPU/GPU: softmax gradient + * \param dst destination + * \param src source output + * \param label label info + */ +template +inline void SoftmaxGrad(const Tensor &dst, + const Tensor &src, + const Tensor &label); +/*! + * \brief CPU/GPU: Gradient accumulate of embedding matrix. + dst[index[i]] += src[i] + Called when the featuredim of src is much larger than the batchsize + * \param dst destination + * \param index index to take + * \param src source output + */ +template +inline void AddTakeGrad(Tensor dst, + const Tensor& index, + const Tensor &src); +/*! + * \brief CPU/GPU: Gradient accumulate of embedding matrix. + dst[index[i]] += src[i] + Called when the featuredim of src is much larger than the batchsize + * \param dst destination + * \param index index to take + * \param src source output + */ +template +inline void AddTakeGrad(Tensor dst, + const Tensor& index, + const Tensor &src); +/*! + * \brief CPU/GPU: Gradient accumulate of embedding matrix. + dst[sorted[i]] += src[index[i]] + Called when the batchsize of src is larger than the featuredim + * \param dst destination + * \param sorted the sorted indices + * \param index original index of the sorted indices + * \param src source output + */ +template +inline void AddTakeGradLargeBatch(Tensor dst, + const Tensor& sorted, + const Tensor& index, + const Tensor &src); +/*! + * \brief CPU/GPU: Gradient accumulate of embedding matrix. + dst[sorted[i]] += src[index[i]] + Called when the batchsize of src is larger than the featuredim + * \param dst destination + * \param sorted the sorted indices + * \param index original index of the sorted indices + * \param src source output + */ +template +inline void AddTakeGradLargeBatch(Tensor dst, + const Tensor& sorted, + const Tensor& index, + const Tensor &src); +/*! + * \brief CPU/GPU: Fill the values of the destination matrix to specific rows in the source matrix. + dst[index[i]] = src[i] + Will use atomicAdd in the inner implementation and the result may not be deterministic. + * \param dst destination + * \param index the index to accumulate value + * \param src source output + */ +template +inline void IndexFill(Tensor dst, + const Tensor& index, + const Tensor &src); +/*! + * \brief CPU/GPU: Fill the values of the destination matrix to specific rows in the source matrix. + dst[index[i]] = src[i] + Will use atomicAdd in the inner implementation and the result may not be deterministic. + * \param dst destination + * \param index the index to accumulate value + * \param src source output + */ +template +inline void IndexFill(Tensor dst, + const Tensor& index, + const Tensor &src); +/*! + * \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!) + * \param keys the keys to sort + * \param values the values that sorts w.r.t the key + * \param is_ascend whether to sort key in ascending order + */ +template +inline void SortByKey(Tensor keys, Tensor values, + bool is_ascend = true); +/*! + * \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!) + * \param keys the keys to sort + * \param values the values that sorts w.r.t the key + * \param is_ascend whether to sort key in ascending order + */ +template +inline void SortByKey(Tensor keys, Tensor values, + bool is_ascend = true); +/*! + * \brief CPU/GPU: Sort the keys within each segment. (Stable sort is performed!) + Segments is defined as an ascending ordered vector like [0, 0, 0, 1, 1, 2, 3, 3, 3,...] + We sort separately the keys labeled by 0 and 1, 2, 3, etc. + Currently only supports sorting in ascending order !! + * \param values the data to sort + * \param segments segment indicator + */ +template +inline void VectorizedSort(Tensor values, Tensor segments); + +// function declarations to support expression, no need to understand them +// these functions do not need to be directly used +/*! + * \brief CPU/GPU: map a expression to a tensor, this function calls MapPlan + * \tparam Saver specify storage method + * \tparam R specifies the storage type of the tensor + * \tparam dim dim of the tensor, during usage, there is no need to specify this parameter + * \tparam DType the type of elements in the tensor + * \tparam E specifies the expression type, not need to specify this parameter during usage + * \tparam etype expression type + * \param dst destination + * \param exp expression + * \sa namespace mshadow:sv, mshadow::op, mshadow::expr + */ +template +inline void MapExp(TRValue *dst, + const expr::Exp &exp); +/*! + * \brief CPU/GPU: map a expression to a tensor, this function calls MapPlan + * \tparam Saver specify storage method + * \tparam R specifies the storage type of the tensor + * \tparam dim dim of the tensor, during usage, there is no need to specify this parameter + * \tparam DType the type of elements in the tensor + * \tparam E specifies the expression type, not need to specify this parameter during usage + * \tparam etype expression type + * \param dst destination + * \param exp expression + * \sa namespace mshadow:sv, mshadow::op, mshadow::expr + */ +template +inline void MapExp(TRValue *dst, + const expr::Exp &exp); +/*! + * \brief CPU/GPU: map a expression, do reduction to 1D Tensor in lowest dimension (dimension 0) + * \tparam Saver specify storage method + * \tparam Reducer specify a reducer method + * \tparam R specifies the storage type of the tensor + * \tparam DType the type of elements in the tensor + * \tparam E specifies the expression type, not need to specify this parameter during usage + * \tparam etype expression type + * \param dst destination + * \param exp expression + * \param scale scale the result before save + * \sa namespace mshadow:sv, mshadow::op, mshadow::red, mshadow::expr + */ +template +inline void MapReduceKeepLowest(TRValue *dst, + const expr::Exp &exp, + DType scale = 1); +/*! + * \brief CPU/GPU: map a expression, do reduction to 1D Tensor in lowest dimension (dimension 0) + * \tparam Saver specify storage method + * \tparam Reducer specify a reducer method + * \tparam R specifies the storage type of the tensor + * \tparam DType the type of elements in the tensor + * \tparam E specifies the expression type, not need to specify this parameter during usage + * \tparam etype expression type + * \param dst destination + * \param exp expression + * \param scale scale the result before save + * \sa namespace mshadow:sv, mshadow::op, mshadow::red, mshadow::expr + */ +template +inline void MapReduceKeepLowest(TRValue *dst, + const expr::Exp &exp, + DType scale = 1); +/*! + * \brief CPU/GPU: map a expression, do reduction to 1D Tensor in third dimension (dimension 2) + * \tparam Saver specify storage method + * \tparam Reducer specify a reducer method + * \tparam R specifies the storage type of the tensor + * \tparam DType the type of elements in the tensor + * \tparam dimkeep the target dimension to be kept, should be larger than 0, for 0, use MapReduceKeepLowest + * \tparam E specifies the expression type, not need to specify this parameter during usage + * \tparam etype expression type + * \param dst destination + * \param exp expression + * \param scale scale the result before save + * \sa namespace mshadow:sv, mshadow::op, mshadow::red, mshadow::expr + */ +template +inline void MapReduceKeepHighDim(TRValue *dst, + const expr::Exp &exp, + DType scale = 1); +/*! + * \brief CPU/GPU: map a expression, do reduction to 1D Tensor in third dimension (dimension 2) + * \tparam Saver specify storage method + * \tparam Reducer specify a reducer method + * \tparam R specifies the storage type of the tensor + * \tparam DType the type of elements in the tensor + * \tparam dimkeep the target dimension to be kept, should be larger than 0, for 0, use MapReduceKeepLowest + * \tparam E specifies the expression type, not need to specify this parameter during usage + * \tparam etype expression type + * \param dst destination + * \param exp expression + * \param scale scale the result before save + * \sa namespace mshadow:sv, mshadow::op, mshadow::red, mshadow::expr + */ +template +inline void MapReduceKeepHighDim(TRValue *dst, + const expr::Exp &exp, + DType scale = 1); +/*! + * \brief CPU/GPU: 1 dimension vector dot + * \param dst Length 1 vector, used to hold the result. + * \param lhs Left operand vector + * \param rhs Right operand vector + */ +template +inline void VectorDot(Tensor dst, + const Tensor &lhs, + const Tensor &rhs); +/*! + * \brief CPU/GPU: dst = alpha * op(lhs) op(rhs) + beta * dst + * \param dst Length 3 tensor, used to hold the result + * \param lhs Left operand vector + * \param rhs Right operand vector + * \param alpha multiplier of op(lhs)op(rhs) + * \param beta multiplier of dst + * \param workspace Workspace for casting DType* to DType** (batched-view), must have size >= 3 * batch_size + */ +template +inline void BatchGEMM(Tensor dst, + const Tensor &lhs, + const Tensor &rhs, + DType alpha, + DType beta, + Tensor workspace); +} // namespace mshadow +// include headers +#include "./stream_gpu-inl.h" +#include "./extension.h" +#include "./expr_engine-inl.h" +#include "./tensor_cpu-inl.h" +#include "./tensor_gpu-inl.h" +#include "./io.h" +#include "./tensor_container.h" +#include "./random.h" +// add definition of scalar related operators +#ifdef MSHADOW_SCALAR_ + #error "MSHADOW_SCALAR_ must not be defined" +#endif +// enumerate all the scalar data type we aim to be good at +#define MSHADOW_SCALAR_ float +#include "./expr_scalar-inl.h" +#undef MSHADOW_SCALAR_ +#define MSHADOW_SCALAR_ double +#include "./expr_scalar-inl.h" +#undef MSHADOW_SCALAR_ +#define MSHADOW_SCALAR_ int +#include "./expr_scalar-inl.h" +#undef MSHADOW_SCALAR_ +#define MSHADOW_SCALAR_ mshadow::half::half_t +#include "./expr_scalar-inl.h" +#undef MSHADOW_SCALAR_ +#endif // MSHADOW_TENSOR_H_ diff --git a/include/mshadow/tensor_container.h b/include/mshadow/tensor_container.h new file mode 100644 index 000000000000..b4df68e8e3a5 --- /dev/null +++ b/include/mshadow/tensor_container.h @@ -0,0 +1,208 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file tensor_container.h + * \brief tensor container that does memory allocation and resize like STL + * \author Tianqi Chen + */ +#ifndef MSHADOW_TENSOR_CONTAINER_H_ +#define MSHADOW_TENSOR_CONTAINER_H_ +#include "./tensor.h" +#include "./io.h" + +namespace mshadow { +/*! + * \brief tensor container that does memory allocation and resize like STL, + * use it to save the lines of FreeSpace in class. + * Do not abuse it, efficiency can come from pre-allocation and no re-allocation + * + * \tparam Device which device the tensor is on + * \tparam dimension dimension of the tensor + */ +template +class TensorContainer: public Tensor { + public: + /*! + * \brief constructor + * \param pad whether use padding alignment in space allocation + */ + explicit TensorContainer(bool pad = MSHADOW_ALLOC_PAD) { + this->pad_ = pad; + this->dptr_ = data_.dptr_ = NULL; + this->shape_[0] = 0; + this->stride_ = 0; + this->data_.stride_ = 0; + this->data_.shape_[0] = 0; + } + /*! + * \brief constructor + * \param shape intial shape + */ + explicit TensorContainer(const Shape &shape) { + this->pad_ = MSHADOW_ALLOC_PAD; + data_.dptr_ = NULL; + this->AllocByShape(shape); + } + /*! + * \brief constructor + * \param shape intial shape + * \param initv intial value + */ + explicit TensorContainer(const Shape &shape, DType initv) { + this->pad_ = MSHADOW_ALLOC_PAD; + data_.dptr_ = NULL; + this->AllocByShape(shape); + (*this) = initv; + } + /*! + * \brief copy constructor + * \param src source value + */ + TensorContainer + (const TensorContainer &src) + : pad_(src.pad_) { + this->dptr_ = data_.dptr_ = NULL; + this->shape_[0] = 0; + this->stride_ = 0; + this->data_.stride_ = 0; + this->data_.shape_[0] = 0; + this->stream_ = src.stream_; + if (src.dptr_ != NULL) { + this->AllocByShape(src.shape_); + mshadow::Copy(*this, src, this->stream_); + } + } + ~TensorContainer(void) { + this->Release(); + } + /*! + * \brief resize the container to given shape, content is NOT preserved + * \param shape target shape + */ + inline void Resize(const Shape &shape) { + Shape<2> s2 = shape.FlatTo2D(); + if (s2.shape_[1] > data_.stride_ || s2.shape_[0] > data_.size(0)) { + this->AllocByShape(shape); + } else { + this->shape_ = shape; + if (this->pad_) { + this->stride_ = data_.stride_; + } else { + this->stride_ = s2.shape_[1]; + } + } + } + /*! + * \brief resize the container to given shape, and initialize, content is NOT preserved + * \param shape target shape + * \param initv initialization value + */ + inline void Resize(const Shape &shape, DType initv) { + this->Resize(shape); + (*this) = initv; + } + /*! \brief set whether padding is allowed in tensor */ + inline void set_pad(bool pad) { + this->pad_ = pad; + } + /*! + * \brief save by binary format + * \param fo output binary stream + * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. + */ + template + inline void SaveBinary(TStream &fo) const { // NOLINT(*) + mshadow::SaveBinary(fo, *this); + } + /*! + * \brief load by binary format, a temp Tensor storage will be allocated + * \param fi input binary stream + * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. + */ + template + inline void LoadBinary(TStream &fi) { // NOLINT(*) + Tensor tmp; + mshadow::LoadBinary(fi, &tmp, false); + this->Resize(tmp.shape_); + Stream stream; + Copy(*this, tmp, &stream); + mshadow::FreeSpace(&tmp); + } + /*! + * \brief assign operator from TensorContainer + * \param src source value + * \return reference of self + */ + inline TensorContainer &operator= + (const TensorContainer &src) { + this->pad_ = src.pad_; + this->stream_ = src.stream_; + if (src.dptr_ != NULL) { + this->Resize(src.shape_); + mshadow::Copy(*this, src, this->stream_); + } + return *this; + } + /*!\brief functions to fit expression template */ + inline Tensor &operator=(DType s) { + return this->__assign(s); + } + /*!\brief functions to fit expression template */ + template + inline Tensor & + operator=(const expr::Exp &exp) { + return this->__assign(exp); + } + /*!\brief functions to fit expression template */ + template + inline Tensor & + operator=(const expr::Exp &exp) { + return this->__assign(exp); + } + /*!\brief functions to fit expression template */ + template + inline Tensor & + operator=(const expr::Exp &exp) { + return this->__assign(exp); + } + /*! + * \brief Release the llocated space, + * The TensorContainer is still functionable, + * but will restart allocating space when Resize is called. + */ + inline void Release(void) { + if (data_.dptr_ != NULL) { + this->shape_[0] = 0; + this->stride_ = 0; + this->data_.stride_ = 0; + this->data_.shape_[0] = 0; + try { + mshadow::FreeSpace(&data_); + } catch (const dmlc::Error &e) { + this->dptr_ = data_.dptr_ = NULL; + throw e; + } + this->dptr_ = data_.dptr_ = NULL; + } + } + + private: + /*! \brief whether we do padding in the space */ + bool pad_; + /*! \brief the shape of data_ is actually current data space */ + Tensor data_; + + inline void AllocByShape(const Shape& shape) { + if (data_.dptr_ != NULL) this->Release(); + data_.shape_ = shape.FlatTo2D(); + mshadow::AllocSpace(&data_, pad_); + this->dptr_ = data_.dptr_; + this->shape_ = shape; + if (this->pad_) { + this->stride_ = data_.stride_; + } else { + this->stride_ = data_.size(1); + } + } +}; +} // namespace mshadow +#endif // MSHADOW_TENSOR_CONTAINER_H_ diff --git a/include/mshadow/tensor_cpu-inl.h b/include/mshadow/tensor_cpu-inl.h new file mode 100755 index 000000000000..ab5f9a68df14 --- /dev/null +++ b/include/mshadow/tensor_cpu-inl.h @@ -0,0 +1,627 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file tensor_cpu-inl.h + * \brief implementation of CPU host code + * \author Bing Xu, Tianqi Chen + */ +#ifndef MSHADOW_TENSOR_CPU_INL_H_ +#define MSHADOW_TENSOR_CPU_INL_H_ +#include +#include +#include +#include +#include "./base.h" +#include "./tensor.h" +#include "./packet-inl.h" +#include "./dot_engine-inl.h" + +namespace mshadow { +template<> +inline void InitTensorEngine(int dev_id) { +} +template<> +inline void ShutdownTensorEngine(void) { +} + +template<> +inline void SetDevice(int devid) { +} +template<> +inline Stream *NewStream(bool create_blas_handle, + bool create_dnn_handle, + int dev_id) { + return new Stream(); +} +template<> +inline void DeleteStream(Stream *stream) { + delete stream; +} + +template +inline std::ostream &operator<<(std::ostream &os, const Shape &shape) { // NOLINT(*) + os << '('; + for (int i = 0; i < ndim; ++i) { + if (i != 0) os << ','; + os << shape[i]; + } + // python style tuple + if (ndim == 1) os << ','; + os << ')'; + return os; +} + +template +inline void *AllocHost_(size_t size); +template +inline void FreeHost_(void * dptr); + +#ifdef __CUDACC__ +template<> +inline void *AllocHost_(size_t size) { + void *dptr; + MSHADOW_CUDA_CALL(cudaMallocHost(&dptr, size, cudaHostAllocPortable)); + return dptr; +} +template<> +inline void FreeHost_(void *dptr) { + MSHADOW_CUDA_CALL(cudaFreeHost(dptr)); +} +#endif + +template<> +inline void *AllocHost_(size_t size) { + size_t pitch; + return packet::AlignedMallocPitch(&pitch, size, 1); +} +template<> +inline void FreeHost_(void *dptr) { + packet::AlignedFree(dptr); +} + +template +inline void AllocHost(Tensor *obj) { + obj->stride_ = obj->size(dim - 1); + CHECK_EQ(obj->CheckContiguous(), true) << "AllocHost"; + void *dptr = AllocHost_(obj->MSize() * sizeof(DType)); + obj->dptr_ = reinterpret_cast(dptr); +} +template +inline void FreeHost(Tensor *obj) { + if (obj->dptr_ == NULL) { + LOG(FATAL) << "FreeHost:: double free"; + } + FreeHost_(obj->dptr_); + obj->dptr_ = NULL; +} + +template +inline void AllocSpace(Tensor *obj, bool pad) { + size_t pitch; + void *dptr; + if (pad) { + dptr = packet::AlignedMallocPitch + (&pitch, obj->size(dim - 1) * sizeof(DType), obj->shape_.FlatTo2D()[0]); + obj->stride_ = static_cast(pitch / sizeof(DType)); + } else { + obj->stride_ = obj->size(dim - 1); + dptr = packet::AlignedMallocPitch + (&pitch, obj->shape_.Size() * sizeof(DType), 1); + } + obj->dptr_ = reinterpret_cast(dptr); +} +template +inline Tensor +NewTensor(const Shape &shape, DType initv, bool pad, Stream *stream_) { + Tensor obj(shape); + obj.stream_ = stream_; + AllocSpace(&obj, pad); + MapExp(&obj, expr::ScalarExp(initv)); + return obj; +} +template +inline void FreeSpace(Tensor *obj) { + packet::AlignedFree(obj->dptr_); + obj->dptr_ = NULL; +} +template +inline void Copy(Tensor _dst, + const Tensor &_src, + Stream *stream) { + CHECK_EQ(_dst.shape_, _src.shape_) + << "Copy:shape mismatch:" << _dst.shape_ << " vs " << _src.shape_; + if (_dst.CheckContiguous() && _src.CheckContiguous()) { + memcpy(_dst.dptr_, _src.dptr_, sizeof(DType) * _dst.shape_.Size()); + } else { + Tensor dst = _dst.FlatTo2D(); + Tensor src = _src.FlatTo2D(); + for (index_t y = 0; y < dst.size(0); ++y) { + memcpy(dst[y].dptr_, src[y].dptr_, sizeof(DType) * dst.size(1)); + } + } +} + +template +inline void MapPlan(TRValue *dst, + const expr::Plan &plan) { + Shape<2> shape = expr::ShapeCheck::Check(dst->self()).FlatTo2D(); + expr::Plan dplan = expr::MakePlan(dst->self()); +#ifndef __CUDACC__ + #pragma omp parallel for +#endif + // temp remove openmp, as default setting throttles CPU + for (openmp_index_t y = 0; y < shape[0]; ++y) { + for (index_t x = 0; x < shape[1]; ++x) { + // trust your compiler! -_- they will optimize it + Saver::template Save(dplan.REval(y, x), plan.Eval(y, x)); + } + } +} +// code to handle SSE optimization +template +struct MapExpCPUEngine { + inline static void Map(TRValue *dst, + const expr::Exp &exp) { + MapPlan(dst, MakePlan(exp.self())); + } +}; + +template +struct MapExpCPUEngine, + dim, DType, E, etype> { + inline static void Map(Tensor *dst, + const expr::Exp &exp) { + if (expr::PacketAlignCheck::Check(exp.self()) && + expr::PacketAlignCheck, MSHADOW_DEFAULT_PACKET>::Check(*dst)) { + expr::MapPacketPlan(dst->self(), + expr::MakePacketPlan(exp.self())); + } else { + MapPlan(dst, MakePlan(exp.self())); + } + } +}; + + +template +inline void MapExp(TRValue *dst, + const expr::Exp &exp) { + expr::TypeCheckPass::kMapPass> + ::Error_All_Tensor_in_Exp_Must_Have_Same_Type(); + Shape eshape = expr::ShapeCheck::Check(exp.self()); + Shape dshape = expr::ShapeCheck::Check(dst->self()); + CHECK(eshape[0] == 0 || eshape == dshape) + << "Assignment: Shape of Tensors are not consistent with target, " + << "eshape: " << eshape << " dshape:" << dshape; + MapExpCPUEngine::kPass, + Saver, R, dim, DType, E, etype> + ::Map(dst->ptrself(), exp); +} + +template +inline void MapReduceKeepLowest(TRValue *dst, + const expr::Exp &exp, + DType scale) { + expr::TypeCheckPass::kRedPass> + ::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); + Shape<2> eshape = expr::ShapeCheck::kDim, E> + ::Check(exp.self()).FlatTo2D(); + Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); + CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension do not match"; + CHECK_NE(eshape[0], 0U) << "can not reduce over empty tensor"; + // execution + expr::Plan dplan = MakePlan(dst->self()); + expr::Plan splan = MakePlan(exp.self()); +#ifndef __CUDACC__ + #pragma omp parallel for +#endif + for (openmp_index_t x = 0; x < eshape[1]; ++x) { + DType res = splan.Eval(0, x); + for (index_t y = 1; y < eshape[0]; ++y) { + Reducer::Reduce(res, splan.Eval(y, x)); + } + Saver::template Save(dplan.REval(0, x), res * scale); + } +} + +template +inline void MapReduceKeepHighDim(TRValue *dst, + const expr::Exp &exp, + DType scale) { + expr::TypeCheckPass::kRedPass> + ::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); + typedef Shape::kDim> EShape; + EShape eshape = expr::ShapeCheck::kDim, E> + ::Check(exp.self()); + Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); + CHECK_EQ(eshape[dimkeep], dshape[0]) + << "MapReduceKeepHighDim::reduction dimension do not match"; + // use equvalent form + Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep), + eshape[dimkeep], + eshape.ProdShape(dimkeep + 1, EShape::kSubdim), + eshape[EShape::kSubdim]); + // execution + expr::Plan dplan = MakePlan(dst->self()); + expr::Plan splan = MakePlan(exp.self()); +#ifndef __CUDACC__ + #pragma omp parallel for +#endif + for (openmp_index_t c = 0; c < pshape[1]; ++c) { + DType res; Reducer::SetInitValue(res); + for (index_t n = 0; n < pshape[0]; ++n) { + DType tres; Reducer::SetInitValue(tres); + for (index_t y = 0; y < pshape[2]; ++y) { + for (index_t x = 0; x < pshape[3]; ++x) { + Reducer::Reduce(tres, + splan.Eval((n * pshape[1] + c) * pshape[2] + y, x)); + } + } + Reducer::Reduce(res, tres); + } + Saver::template Save(dplan.REval(0, c), DType(res * scale)); + } +} + +template +inline void Softmax(Tensor dst, + const Tensor &energy) { + DType mmax = energy[0]; + for (index_t x = 1; x < dst.size(0); ++x) { + if (mmax < energy[x]) mmax = energy[x]; + } + DType sum = DType(0.0f); + for (index_t x = 0; x < dst.size(0); ++x) { + dst[x] = std::exp(energy[x] - mmax); + sum += dst[x]; + } + for (index_t x = 0; x < dst.size(0); ++x) { + dst[x] /= sum; + } +} + +template +inline void SoftmaxGrad(Tensor dst, + const Tensor &src, + const Tensor &label) { +#pragma omp parallel for + for (openmp_index_t y = 0; y < dst.size(0); ++y) { + const index_t k = static_cast(label[y]); + for (index_t x = 0; x < dst.size(1); ++x) { + if (x == k) { + dst[y][k] = src[y][k] - 1.0f; + } else { + dst[y][x] = src[y][x]; + } + } + } +} + +template +inline void SmoothSoftmaxGrad(Tensor dst, + const Tensor &src, + const Tensor &label, + const float alpha) { + const float smooth_grad = (alpha / (dst.size(1) - 1)); +#pragma omp parallel for + for (openmp_index_t y = 0; y < dst.size(0); ++y) { + const index_t k = static_cast(label[y]); + for (index_t x = 0; x < dst.size(1); ++x) { + if (x == k) { + dst[y][k] = src[y][k] - 1.0f + alpha; + } else { + dst[y][x] = src[y][x] - smooth_grad; + } + } + } +} + + +template +inline void SoftmaxGrad(Tensor dst, + const Tensor &src, + const Tensor &label, + const DType &ignore_label) { +#pragma omp parallel for + for (openmp_index_t y = 0; y < dst.size(0); ++y) { + const int k = static_cast(label[y]); + for (int x = 0; x < static_cast(dst.size(1)); ++x) { + if (static_cast(ignore_label) == k) { + dst[y][x] = 0.0f; + } else { + if (x == k) { + dst[y][k] = src[y][k] - 1.0f; + } else { + dst[y][x] = src[y][x]; + } + } + } + } +} + +template +inline void SmoothSoftmaxGrad(Tensor dst, + const Tensor &src, + const Tensor &label, + const DType &ignore_label, + const float alpha) { + const float smooth_grad = (alpha / (dst.size(1) - 1)); +#pragma omp parallel for + for (openmp_index_t y = 0; y < dst.size(0); ++y) { + const int k = static_cast(label[y]); + for (int x = 0; x < static_cast(dst.size(1)); ++x) { + if (static_cast(ignore_label) == k) { + dst[y][x] = 0.0f; + } else { + if (x == k) { + dst[y][k] = src[y][k] - 1.0f + alpha; + } else { + dst[y][x] = src[y][x] - smooth_grad; + } + } + } + } +} + +template +inline void SoftmaxGrad(Tensor dst, + const Tensor &src, + const Tensor &label) { +#pragma omp parallel for + for (openmp_index_t n = 0; n < dst.size(2); ++n) { + for (index_t y = 0; y < dst.size(0); ++y) { + const int k = static_cast(label[y][n]); + for (int x = 0; x < static_cast(dst.size(1)); ++x) { + if (x == k) { + dst[y][k][n] = src[y][k][n] - 1.0f; + } else { + dst[y][x][n] = src[y][x][n]; + } + } + } + } +} + +template +inline void SmoothSoftmaxGrad(Tensor dst, + const Tensor &src, + const Tensor &label, + const float alpha) { + const float smooth_grad = (alpha / (dst.size(1) - 1)); +#pragma omp parallel for + for (openmp_index_t n = 0; n < dst.size(2); ++n) { + for (index_t y = 0; y < dst.size(0); ++y) { + const int k = static_cast(label[y][n]); + for (int x = 0; x < static_cast(dst.size(1)); ++x) { + if (x == k) { + dst[y][k][n] = src[y][k][n] - 1.0f + alpha; + } else { + dst[y][x][n] = src[y][x][n] - smooth_grad; + } + } + } + } +} + +template +inline void SoftmaxGrad(Tensor dst, + const Tensor &src, + const Tensor &label, + const DType &ignore_label) { +#pragma omp parallel for + for (openmp_index_t n = 0; n < dst.size(2); ++n) { + for (index_t y = 0; y < dst.size(0); ++y) { + const int k = static_cast(label[y][n]); + if (k == static_cast(ignore_label)) { + for (int x = 0; x < static_cast(dst.size(1)); ++x) { + dst[y][x][n] = DType(0.0f); + } + } else { + for (int x = 0; x < static_cast(dst.size(1)); ++x) { + if (x == k) { + dst[y][k][n] = src[y][k][n] - 1.0f; + } else { + dst[y][x][n] = src[y][x][n]; + } + } + } + } + } +} + +template +inline void SmoothSoftmaxGrad(Tensor dst, + const Tensor &src, + const Tensor &label, + const DType &ignore_label, + const float alpha) { + const float smooth_grad = (alpha / (dst.size(1) - 1)); +#pragma omp parallel for + for (openmp_index_t n = 0; n < dst.size(2); ++n) { + for (index_t y = 0; y < dst.size(0); ++y) { + const int k = static_cast(label[y][n]); + if (k == static_cast(ignore_label)) { + for (int x = 0; x < static_cast(dst.size(1)); ++x) { + dst[y][x][n] = DType(0.0f); + } + } else { + for (int x = 0; x < static_cast(dst.size(1)); ++x) { + if (x == k) { + dst[y][k][n] = src[y][k][n] - 1.0f + alpha; + } else { + dst[y][x][n] = src[y][x][n] - smooth_grad; + } + } + } + } + } +} + +template +inline void Softmax(Tensor dst, + const Tensor &energy) { + CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch"; +#pragma omp parallel for + for (openmp_index_t y = 0; y < dst.size(0); ++y) { + Softmax(dst[y], energy[y]); + } +} + +template +inline void Softmax(Tensor dst, + const Tensor &energy) { + CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch"; +#pragma omp parallel for + for (openmp_index_t y = 0; y < dst.size(0); ++y) { + for (index_t n = 0; n < dst.size(2); ++n) { + DType mmax = energy[y][0][n]; + for (index_t x = 1; x < dst.size(1); ++x) { + if (mmax < energy[y][x][n]) mmax = energy[y][x][n]; + } + DType sum = DType(0.0f); + for (index_t x = 0; x < dst.size(1); ++x) { + dst[y][x][n] = std::exp(energy[y][x][n] - mmax); + sum += dst[y][x][n]; + } + for (index_t x = 0; x < dst.size(1); ++x) { + dst[y][x][n] /= sum; + } + } + } +} + +template +inline void AddTakeGrad(Tensor dst, + const Tensor& index, + const Tensor &src) { + const int K = dst.shape_[0]; + for (index_t y = 0; y < index.size(0); ++y) { + int j = index[y]; + if (j <= 0) j = 0; + else if (j >= K) j = K - 1; + dst[j] += src[y]; + } +} + +template +inline void AddTakeGradLargeBatch(Tensor dst, + const Tensor& sorted, + const Tensor& index, + const Tensor &src) { + for (index_t y = 0; y < sorted.size(0); ++y) { + dst[sorted[y]] += src[index[y]]; + } +} + +template +inline void IndexFill(Tensor dst, + const Tensor& index, + const Tensor &src) { + for (index_t y = 0; y < index.size(0); ++y) { + for (index_t j = 0; j < src.size(1); j++) { + dst[index[y]][j] = src[y][j]; + } + } +} + +template +inline void SortByKey(Tensor keys, Tensor values, + bool is_ascend) { + CHECK_EQ(keys.CheckContiguous(), true); + CHECK_EQ(values.CheckContiguous(), true); + CHECK_EQ(keys.size(0), values.size(0)) + << "The sizes of key/value are not equal! keys_size: " << keys.size(0) + << "values_size: " << values.size(0); + std::vector idx(keys.size(0)); + std::vector keys_vec(keys.size(0)); + std::vector values_vec(values.size(0)); + for (int i = 0; i < keys.size(0); i++) { + idx[i] = i; + keys_vec[i] = keys[i]; + values_vec[i] = values[i]; + } + if (is_ascend) { + std::stable_sort(idx.begin(), idx.end(), + [&keys_vec](size_t i1, size_t i2) + {return keys_vec[i1] < keys_vec[i2]; }); + } else { + std::stable_sort(idx.begin(), idx.end(), + [&keys_vec](size_t i1, size_t i2) + {return keys_vec[i1] > keys_vec[i2]; }); + } + for (index_t i = 0; i < values.size(0); i++) { + keys[i] = keys_vec[idx[i]]; + values[i] = values_vec[idx[i]]; + } +} + +template +inline void VectorizedSort(Tensor values, Tensor segments) { + // We can sort each segments using two stable sorts + SortByKey(values, segments, true); + SortByKey(segments, values, true); +} + +// blas related +template +inline void VectorDot(Tensor dst, + const Tensor &lhs, + const Tensor &rhs) { + CHECK_EQ(lhs.size(0), rhs.size(0)) + << "VectorDot: Shape mismatch"; + CHECK_EQ(dst.size(0), 1U) + << "VectorDot: expect dst to be scalar"; + expr::BLASEngine::SetStream(lhs.stream_); + mshadow::expr::BLASEngine::dot( + lhs.stream_, lhs.size(0), lhs.dptr_, 1, rhs.dptr_, 1, dst.dptr_); +} + +template +inline void BatchGEMM(Tensor dst, + const Tensor &lhs, + const Tensor &rhs, + DType alpha, + DType beta, + Tensor workspace) { + index_t batch_size = dst.shape_[0]; + expr::BLASEngine::SetStream(dst.stream_); + Shape<3> sleft = transpose_left ? Shape3(lhs.shape_[0], lhs.shape_[2], lhs.shape_[1]) + : lhs.shape_; + Shape<3> sright = transpose_right ? Shape3(rhs.shape_[0], rhs.shape_[2], rhs.shape_[1]) + : rhs.shape_; + CHECK_EQ(dst.CheckContiguous(), true); + CHECK_EQ(lhs.CheckContiguous(), true); + CHECK_EQ(rhs.CheckContiguous(), true); + CHECK(sleft[0] == batch_size && sright[0] == batch_size) + << "BatchGEMM: batchsize must be equal." + << "dst: " << dst.shape_ << "\n" + << "lhs: " << sleft << "\n" + << "rhs: " << sright << "\n"; + CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1]) + << "BatchGEMM: matrix shape mismatch" + << "dst: " << dst.shape_ << "\n" + << "lhs: " << sleft << "\n" + << "rhs: " << sright << "\n"; + CHECK(workspace.size(0) >= 3 * batch_size) + << "Workspace Size must be bigger than " << 3 * batch_size; + CHECK_EQ(workspace.CheckContiguous(), true); + // use column major argument to compatible with most BLAS + expr::BLASEngine::batched_gemm + (dst.stream_, + transpose_right, transpose_left, + transpose_right ? rhs.size(1) : rhs.size(2), + transpose_left ? lhs.size(2) : lhs.size(1), + transpose_right ? rhs.size(2) : rhs.size(1), + alpha, + rhs.dptr_, rhs.stride_, + lhs.dptr_, lhs.stride_, + beta, + dst.dptr_, dst.stride_, batch_size, + workspace.dptr_); +} +} // namespace mshadow +#endif // MSHADOW_TENSOR_CPU_INL_H_ diff --git a/include/mshadow/tensor_gpu-inl.h b/include/mshadow/tensor_gpu-inl.h new file mode 100755 index 000000000000..94fdb0527e72 --- /dev/null +++ b/include/mshadow/tensor_gpu-inl.h @@ -0,0 +1,245 @@ +/*! + * Copyright (c) 2014 by Contributors + * \file tensor_gpu-inl.h + * \brief implementation of GPU host code + * \author Bing Xu, Tianqi Chen + */ +#ifndef MSHADOW_TENSOR_GPU_INL_H_ +#define MSHADOW_TENSOR_GPU_INL_H_ +#include "./base.h" +#include "./tensor.h" + +namespace mshadow { +#if MSHADOW_USE_CUDA +template<> +inline void InitTensorEngine(int dev_id) { + cudaDeviceProp prop; + int device_id = 0; + int device_count = 0; + cudaGetDeviceCount(&device_count); + CHECK_GT(device_count, 0) << "Cannot find CUDA device. Please check CUDA-Configuration"; + if (dev_id < 0) { + device_id = 0; + } else { + device_id = dev_id; + } + CHECK_LT(device_id, device_count) << "Incorrect Device ID"; + MSHADOW_CUDA_CALL(cudaSetDevice(device_id)); + MSHADOW_CUDA_CALL(cudaGetDeviceProperties(&prop, device_id)); +} +template<> +inline void ShutdownTensorEngine(void) { +} +template<> +inline void SetDevice(int devid) { + MSHADOW_CUDA_CALL(cudaSetDevice(devid)); +} +template +inline void AllocSpace(Tensor *obj, bool pad) { + size_t pitch; + // common choice for cuda mem align unit is 32 + if (pad && obj->size(dim - 1) >= MSHADOW_MIN_PAD_RATIO * 32) { + MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast(&(obj->dptr_)), &pitch, + obj->size(dim - 1) * sizeof(DType), + obj->shape_.FlatTo2D()[0])); + obj->stride_ = static_cast(pitch / sizeof(DType)); + } else { + obj->stride_ = obj->size(dim - 1); + MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast(&(obj->dptr_)), &pitch, + obj->shape_.Size() * sizeof(DType), 1)); + } +} +template +inline void FreeSpace(Tensor *obj) { + MSHADOW_CUDA_CALL(cudaFree(obj->dptr_)); + obj->dptr_ = NULL; +} +template +inline void Copy(Tensor _dst, + Tensor _src, + cudaMemcpyKind kind, + Stream *stream) { + CHECK_EQ(_dst.shape_, _src.shape_) << "Copy:shape mismatch"; + Tensor dst = _dst.FlatTo2D(); + Tensor src = _src.FlatTo2D(); + MSHADOW_CUDA_CALL(cudaMemcpy2DAsync(dst.dptr_, dst.stride_ * sizeof(DType), + src.dptr_, src.stride_ * sizeof(DType), + dst.size(1) * sizeof(DType), + dst.size(0), kind, + Stream::GetStream(stream))); + // use synchronize call behavior for zero stream + if (stream == NULL) { + MSHADOW_CUDA_CALL(cudaStreamSynchronize(0)); + } +} +template +inline void Copy(Tensor dst, + const Tensor &src, + Stream *stream) { + Copy(dst, src, cudaMemcpyDeviceToHost, stream); +} +template +inline void Copy(Tensor dst, + const Tensor &src, + Stream *stream) { + Copy(dst, src, cudaMemcpyDeviceToDevice, stream); +} +template +inline void Copy(Tensor dst, + const Tensor &src, + Stream *stream) { + Copy(dst, src, cudaMemcpyHostToDevice, stream); +} +#endif // MSHADOW_USE_CUDA +} // namespace mshadow + +// the following part is included only if compiler is nvcc +#ifdef __CUDACC__ +#include "./cuda/tensor_gpu-inl.cuh" + +namespace mshadow { +template +inline void MapExp(TRValue *dst, + const expr::Exp &exp) { + expr::TypeCheckPass::kMapPass> + ::Error_All_Tensor_in_Exp_Must_Have_Same_Type(); + Shape eshape = expr::ShapeCheck::Check(exp.self()); + Shape dshape = expr::ShapeCheck::Check(dst->self()); + CHECK(eshape[0] == 0 || eshape == dshape) + << "Assignment: Shape of Tensors are not consistent with target, " + << "eshape: " << eshape << " dshape:" << dshape; + cuda::MapPlan(MakePlan(dst->self()), + MakePlan(exp.self()), + dshape.FlatTo2D(), + Stream::GetStream(expr::StreamInfo::Get(dst->self()))); +} + +template +inline void MapReduceKeepLowest(TRValue *dst, + const expr::Exp &exp, + DType scale) { + expr::TypeCheckPass::kRedPass> + ::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); + Shape<2> eshape = expr::ShapeCheck::kDim, E> + ::Check(exp.self()).FlatTo2D(); + Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); + CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension do not match"; + CHECK_NE(eshape[0], 0U) << "can not reduce over empty tensor"; + cuda::MapReduceKeepLowest + (MakePlan(dst->self()), MakePlan(exp.self()), scale, eshape, + Stream::GetStream(expr::StreamInfo::Get(dst->self()))); +} + +template +inline void MapReduceKeepHighDim(TRValue *dst, + const expr::Exp &exp, + DType scale) { + expr::TypeCheckPass::kRedPass> + ::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); + typedef Shape::kDim> EShape; + EShape eshape = expr::ShapeCheck::kDim, E> + ::Check(exp.self()); + Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); + CHECK_EQ(eshape[dimkeep], dshape[0]) << "MapReduceKeepHighDim::reduction dimension do not match"; + // use equvalent form + Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep), + eshape[dimkeep], + eshape.ProdShape(dimkeep + 1, EShape::kSubdim), + eshape[EShape::kSubdim]); + // call equavalent map red dim 2 + cuda::MapReduceKeepDim1 + (MakePlan(dst->self()), MakePlan(exp.self()), scale, pshape, + Stream::GetStream(expr::StreamInfo::Get(dst->self()))); +} +template +inline void Softmax(Tensor dst, + const Tensor& src) { + cuda::Softmax(dst, src); +} + +template +inline void Softmax(Tensor dst, + const Tensor& src) { + cuda::Softmax(dst, src); +} + +template +inline void SoftmaxGrad(const Tensor &dst, + const Tensor &src, + const Tensor &label) { + cuda::SoftmaxGrad(dst, src, label); +} + +template +inline void SmoothSoftmaxGrad(const Tensor &dst, + const Tensor &src, + const Tensor &label, + const float alpha) { + cuda::SmoothSoftmaxGrad(dst, src, label, alpha); +} + +template +inline void SoftmaxGrad(const Tensor &dst, + const Tensor &src, + const Tensor &label, + const DType &ignore_label) { + cuda::SoftmaxGrad(dst, src, label, ignore_label); +} + +template +inline void SmoothSoftmaxGrad(const Tensor &dst, + const Tensor &src, + const Tensor &label, + const DType &ignore_label, + const float alpha) { + cuda::SmoothSoftmaxGrad(dst, src, label, ignore_label, alpha); +} + +template +inline void SoftmaxGrad(const Tensor &dst, + const Tensor &src, + const Tensor &label) { + cuda::SoftmaxGrad(dst, src, label); +} + +template +inline void SoftmaxGrad(const Tensor &dst, + const Tensor &src, + const Tensor &label, + const DType &ignore_label) { + cuda::SoftmaxGrad(dst, src, label, ignore_label); +} + +template +inline void AddTakeGrad(Tensor dst, + const Tensor& index, + const Tensor &src) { + cuda::AddTakeGrad(dst, index, src); +} + +template +inline void AddTakeGradLargeBatch(Tensor dst, + const Tensor& sorted, + const Tensor& index, + const Tensor &src) { + cuda::AddTakeGradLargeBatch(dst, sorted, index, src); +} + +template +inline void SortByKey(Tensor keys, Tensor values, + bool is_ascend) { + cuda::SortByKey(keys, values, is_ascend); +} + +template +inline void IndexFill(Tensor dst, + const Tensor& index, + const Tensor &src) { + cuda::IndexFill(dst, index, src); +} +} // namespace mshadow +#endif // __CUDACC__ +#endif // MSHADOW_TENSOR_GPU_INL_H_ diff --git a/include/nnvm/base.h b/include/nnvm/base.h new file mode 100644 index 000000000000..449bd2f4626e --- /dev/null +++ b/include/nnvm/base.h @@ -0,0 +1,35 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file nnvm/base.h + * \brief Configuration of nnvm as well as basic data structure. + */ +#ifndef NNVM_BASE_H_ +#define NNVM_BASE_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace nnvm { + +/*! \brief any type */ +using dmlc::any; + +/*! \brief array_veiw type */ +using dmlc::array_view; + +/*!\brief getter function of any type */ +using dmlc::get; + +} // namespace nnvm + +// describe op registration point +#define NNVM_STRINGIZE_DETAIL(x) #x +#define NNVM_STRINGIZE(x) NNVM_STRINGIZE_DETAIL(x) +#define NNVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" NNVM_STRINGIZE(__LINE__)) +#define NNVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" NNVM_STRINGIZE(__LINE__) +#endif // NNVM_BASE_H_ diff --git a/include/nnvm/c_api.h b/include/nnvm/c_api.h new file mode 100644 index 000000000000..daf9b564f3fa --- /dev/null +++ b/include/nnvm/c_api.h @@ -0,0 +1,388 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file nnvm/c_api.h + * \brief C API of NNVM symbolic construction and pass. + * Enables construction and transformation of Graph + * in any other host languages. + */ +#ifndef NNVM_C_API_H_ +#define NNVM_C_API_H_ + +/*! \brief NNVM_DLL prefix for windows */ +#ifdef _WIN32 +#ifdef NNVM_EXPORTS +#define NNVM_DLL __declspec(dllexport) +#else +#define NNVM_DLL __declspec(dllimport) +#endif +#else +#define NNVM_DLL +#endif + +/*! \brief manually define unsigned int */ +typedef unsigned int nn_uint; + +/*! \brief handle to a function that takes param and creates symbol */ +typedef void *OpHandle; +/*! \brief handle to a symbol that can be bind as operator */ +typedef void *SymbolHandle; +/*! \brief handle to Graph */ +typedef void *GraphHandle; + +#ifdef __cplusplus +extern "C" { +#endif +/*! + * \brief Set the last error message needed by C API + * \param msg The error message to set. + */ +NNVM_DLL void NNAPISetLastError(const char* msg); + +/*! + * \brief return str message of the last error + * all function in this file will return 0 when success + * and -1 when an error occured, + * NNGetLastError can be called to retrieve the error + * + * this function is threadsafe and can be called by different thread + * \return error info + */ +NNVM_DLL const char *NNGetLastError(void); + +/*! + * \brief list all the available operator names, include entries. + * \param out_size the size of returned array + * \param out_array the output operator name array. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNListAllOpNames(nn_uint *out_size, + const char*** out_array); + +/*! + * \brief Get operator handle given name. + * \param op_name The name of the operator. + * \param op_out The returnning op handle. + */ +NNVM_DLL int NNGetOpHandle(const char* op_name, + OpHandle* op_out); + +/*! + * \brief list all the available operators. + * This won't include the alias, use ListAllNames + * instead to get all alias names. + * + * \param out_size the size of returned array + * \param out_array the output AtomicSymbolCreator array + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNListUniqueOps(nn_uint *out_size, + OpHandle **out_array); + +/*! + * \brief Get the detailed information about atomic symbol. + * \param op The operator handle. + * \param real_name The returned name of the creator. + * This name is not the alias name of the atomic symbol. + * \param description The returned description of the symbol. + * \param num_doc_args Number of arguments that contain documents. + * \param arg_names Name of the arguments of doc args + * \param arg_type_infos Type informations about the arguments. + * \param arg_descriptions Description information about the arguments. + * \param return_type Return type of the function, if any. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNGetOpInfo(OpHandle op, + const char **real_name, + const char **description, + nn_uint *num_doc_args, + const char ***arg_names, + const char ***arg_type_infos, + const char ***arg_descriptions, + const char **return_type); +/*! + * \brief Create an AtomicSymbol functor. + * \param op The operator handle + * \param num_param the number of parameters + * \param keys the keys to the params + * \param vals the vals of the params + * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op, + nn_uint num_param, + const char **keys, + const char **vals, + SymbolHandle *out); +/*! + * \brief Create a Variable Symbol. + * \param name name of the variable + * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolCreateVariable(const char *name, SymbolHandle *out); +/*! + * \brief Create a Symbol by grouping list of symbols together + * \param num_symbols number of symbols to be grouped + * \param symbols array of symbol handles + * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols, + SymbolHandle *symbols, + SymbolHandle *out); +/*! + * \brief Add src_dep to the handle as control dep. + * \param handle The symbol to add dependency edges on. + * \param src_dep the source handles. + */ +NNVM_DLL int NNAddControlDeps(SymbolHandle handle, + SymbolHandle src_dep); +/*! + * \brief Free the symbol handle. + * \param symbol the symbol + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolFree(SymbolHandle symbol); +/*! + * \brief Copy the symbol to another handle + * \param symbol the source symbol + * \param out used to hold the result of copy + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out); +/*! + * \brief Print the content of symbol, used for debug. + * \param symbol the symbol + * \param out_str pointer to hold the output string of the printing. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str); +/*! + * \brief Get string attribute from symbol + * \param symbol the source symbol + * \param key The key of the symbol. + * \param out The result attribute, can be NULL if the attribute do not exist. + * \param success Whether the result is contained in out. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol, + const char* key, + const char** out, + int *success); +/*! + * \brief Set string attribute from symbol. + * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph. + * + * Safe recommendaton: use immutable graph + * - Only allow set attributes during creation of new symbol as optional parameter + * + * Mutable graph (be careful about the semantics): + * - Allow set attr at any point. + * - Mutating an attribute of some common node of two graphs can cause confusion from user. + * + * \param symbol the source symbol + * \param num_param Number of parameters to set. + * \param keys The keys of the attribute + * \param values The value to be set + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, + nn_uint num_param, + const char** keys, + const char** values); +/*! + * \brief Get all attributes from symbol, including all descendents. + * \param symbol the source symbol + * \param recursive_option 0 for recursive, 1 for shallow. + * \param out_size The number of output attributes + * \param out 2*out_size strings representing key value pairs. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, + int recursive_option, + nn_uint *out_size, + const char*** out); + +/*! + * \brief List inputs variables in the symbol. + * \param symbol the symbol + * \param option The option to list the inputs + * option=0 means list all arguments. + * option=1 means list arguments that are readed only by the graph. + * option=2 means list arguments that are mutated by the graph. + * \param out_size output size + * \param out_sym_array the output array. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, + int option, + nn_uint *out_size, + SymbolHandle** out_sym_array); + +/*! + * \brief List input names in the symbol. + * \param symbol the symbol + * \param option The option to list the inputs + * option=0 means list all arguments. + * option=1 means list arguments that are readed only by the graph. + * option=2 means list arguments that are mutated by the graph. + * \param out_size output size + * \param out_str_array pointer to hold the output string array + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, + int option, + nn_uint *out_size, + const char ***out_str_array); +/*! + * \brief List returns names in the symbol. + * \param symbol the symbol + * \param out_size output size + * \param out_str_array pointer to hold the output string array + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, + nn_uint *out_size, + const char ***out_str_array); + + +/*! + * \brief Supply number of outputs of the symbol. + * \param symbol the symbol + * \param output_count number of outputs + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, + nn_uint *output_count); + +/*! + * \brief Get a symbol that contains all the internals. + * \param symbol The symbol + * \param out The output symbol whose outputs are all the internals. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol, + SymbolHandle *out); +/*! + * \brief Get a symbol that contains only direct children. + * \param symbol The symbol + * \param out The output symbol whose outputs are the direct children. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol, + SymbolHandle *out); +/*! + * \brief Get index-th outputs of the symbol. + * \param symbol The symbol + * \param index the Index of the output. + * \param out The output symbol whose outputs are the index-th symbol. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, + nn_uint index, + SymbolHandle *out); + +/*! + * \brief Compose the symbol on other symbols. + * + * This function will change the sym hanlde. + * To achieve function apply behavior, copy the symbol first + * before apply. + * + * \param sym the symbol to apply + * \param name the name of symbol + * \param num_args number of arguments + * \param keys the key of keyword args (optional) + * \param args arguments to sym + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNSymbolCompose(SymbolHandle sym, + const char* name, + nn_uint num_args, + const char** keys, + SymbolHandle* args); + +// Graph IR API +/*! + * \brief create a graph handle from symbol + * \param symbol The symbol representing the graph. + * \param graph The graph handle created. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph); +/*! + * \brief free the graph handle + * \param handle The handle to be freed. + */ +NNVM_DLL int NNGraphFree(GraphHandle handle); +/*! + * \brief Get a new symbol from the graph. + * \param graph The graph handle. + * \param symbol The corresponding symbol + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol); + +/*! + * \brief Get Set a attribute in json format. + * This feature allows pass graph attributes back and forth in reasonable speed. + * + * \param handle The graph handle. + * \param key The key to the attribute. + * \param json_value The value need to be in format [type_name, value], + * Where type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, + const char* key, + const char* json_value); + +/*! + * \brief Get a serialized attrirbute from graph. + * This feature allows pass graph attributes back and forth in reasonable speed. + * + * \param handle The graph handle. + * \param key The key to the attribute. + * \param json_out The result attribute, can be NULL if the attribute do not exist. + * The json_out is an array of [type_name, value]. + * Where the type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY. + * \param success Whether the result is contained in out. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle, + const char* key, + const char** json_out, + int *success); + +/*! + * \brief Set a attribute whose type is std::vector in c++ + * This feature allows pass List of symbolic variables for gradient request. + * + * \note This is beta feature only used for test purpos + * + * \param handle The graph handle. + * \param key The key to the attribute. + * \param list The symbol whose outputs represents the list of NodeEntry to be passed. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, + const char* key, + SymbolHandle list); +/*! + * \brief Apply passes on the src graph. + * \param src The source graph handle. + * \param num_pass The number of pass to be applied. + * \param pass_names The names of the pass. + * \param dst The result graph. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNGraphApplyPasses(GraphHandle src, + nn_uint num_pass, + const char** pass_names, + GraphHandle *dst); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // NNVM_C_API_H_ diff --git a/include/nnvm/compiler/op_attr_types.h b/include/nnvm/compiler/op_attr_types.h new file mode 100644 index 000000000000..497a520db78e --- /dev/null +++ b/include/nnvm/compiler/op_attr_types.h @@ -0,0 +1,101 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file nnvm/compiler/op_attr_types.h + * \brief The Expr and related elements in DataFlow construction. + */ +#ifndef NNVM_COMPILER_OP_ATTR_TYPES_H_ +#define NNVM_COMPILER_OP_ATTR_TYPES_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "packed_func_ext.h" + +namespace nnvm { +namespace compiler { + +using ::tvm::Array; +using ::tvm::Tensor; +using ::tvm::Schedule; + +/*! \brief operator pattern used in graph fusion */ +enum OpPatternKind { + // Elementwise operation + kElemWise = 0, + // Broadcasting operator, can always map output axis to the input in order. + // for example :code:`out[i, ax1, j, ax2] = input[i, j]`. + // Note that the axis need to be in order so transpose is not a bcast operator. + kBroadcast = 1, + // Injective operator, can always injectively map output axis to a single input axis. + // All injective operator can still be safely fused to injective and reduction. + kInjective = 2, + // Communicative reduction operator. + kCommReduce = 3, + // Complex operation, can still fuse elemwise operations into its output. + // but cannot chain another complex op + kOutEWiseFusable = 4, + // Opaque operation, cannot fuse anything. + kOpaque = 8 +}; + +/*! \brief the operator pattern */ +using TOpPattern = int; + +/*! + * \brief Computation description interface + * \param attrs The attribute of the node. + * \param inputs The input tensors(placeholders) + * \param out_info Tensors holding shape/type information about output, + & these are always placeholders. + * \return The output description of the tensor. + */ +using FTVMCompute = std::function< + Array(const NodeAttrs& attrs, + const Array& inputs, + const Array& out_info)>; + +/*! + * \brief Build the computation schedule for + * op whose root is at current op. + * \param attrs The attribute of the node. + * \param outs The output tensors. + * \param target The build target. + * \return schedule The computation schedule. + */ +using FTVMSchedule = std::function< + Schedule(const NodeAttrs& attrs, + const Array& outs, + const std::string& target)>; + +/*! + * \brief Modify the op node to alter its input layout. + * it is invoked in AlterOpLayout pass. + * \param attrs The attribute of the original node. + * \param inputs The input symbols of the original node. + * \param tinfos The inferred shape and dtype of the inputs. + * \param ret The replaced operator. + * \return Whether to replace current operator. + */ +using FTVMAlterOpLayout = std::function< + bool(const NodeAttrs& attrs, + const Symbol& inputs, + const Array& tinfos, + Symbol* ret)>; + +/*! + * \brief Transform from normal operator to vectorized operator + * \param node The source node. + * \return Transformed vectorized op. + */ +using FTVMVectorizedOp = std::function; + +} // namespace compiler +} // namespace nnvm +#endif // NNVM_COMPILER_OP_ATTR_TYPES_H_ diff --git a/include/nnvm/compiler/packed_func_ext.h b/include/nnvm/compiler/packed_func_ext.h new file mode 100644 index 000000000000..e289fd4efa59 --- /dev/null +++ b/include/nnvm/compiler/packed_func_ext.h @@ -0,0 +1,59 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file nnvm/compiler/packed_func_ext.h + * \brief Extension to enable packed functionn for nnvm types + */ +#ifndef NNVM_COMPILER_PACKED_FUNC_EXT_H_ +#define NNVM_COMPILER_PACKED_FUNC_EXT_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace nnvm { +namespace compiler { + +using tvm::runtime::PackedFunc; + +using AttrDict = std::unordered_map; + +/*! + * \brief Get PackedFunction from global registry and + * report error if it does not exist + * \param name The name of the function. + * \return The created PackedFunc. + */ +inline const PackedFunc& GetPackedFunc(const std::string& name) { + const PackedFunc* pf = tvm::runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find function " << name << " in registry"; + return *pf; +} +} // namespace compiler +} // namespace nnvm + +// Enable the graph and symbol object exchange. +namespace tvm { +namespace runtime { + +template<> +struct extension_class_info { + static const int code = 16; +}; + +template<> +struct extension_class_info { + static const int code = 17; +}; + +template<> +struct extension_class_info { + static const int code = 18; +}; + +} // namespace runtime +} // namespace tvm +#endif // NNVM_COMPILER_PACKED_FUNC_EXT_H_ diff --git a/include/nnvm/compiler/util.h b/include/nnvm/compiler/util.h new file mode 100644 index 000000000000..5d5bc4478530 --- /dev/null +++ b/include/nnvm/compiler/util.h @@ -0,0 +1,33 @@ +/*! +* Copyright (c) 2016 by Contributors +* \file nnvm/compiler/util.h +* \brief Utility functions for nnvm compiler +*/ +#ifndef NNVM_COMPILER_UTIL_H_ +#define NNVM_COMPILER_UTIL_H_ + +#include +#include + +namespace nnvm { +namespace compiler { + +/* + * \brief Helper function to convert TShape to TVM array. Useful for + * passing data from NNVM param structures to TOPI ops. + * + * \param shape The shape to convert + * + * \return An Array of Expr, where each element is a constant int32 + */ +inline tvm::Array ShapeToArray(TShape shape) { + tvm::Array result; + for (auto i : shape) { + result.push_back(tvm::make_const(tvm::Int(32), i)); + } + return result; +} + +} // namespace compiler +} // namespace nnvm +#endif // NNVM_COMPILER_UTIL_H_ diff --git a/include/nnvm/graph.h b/include/nnvm/graph.h new file mode 100644 index 000000000000..3f8a2a3642b1 --- /dev/null +++ b/include/nnvm/graph.h @@ -0,0 +1,315 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file nnvm/graph.h + * \brief Configuation of nnvm as well as basic data structure. + */ +#ifndef NNVM_GRAPH_H_ +#define NNVM_GRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include "base.h" +#include "node.h" +#include "symbolic.h" + +namespace nnvm { + +class IndexedGraph; + +/*! + * \brief Symbolic computation graph. + * This is the intermediate representation for optimization pass. + */ +class Graph { + public: + /*! \brief outputs of the computation graph. */ + std::vector outputs; + /*! + * \brief attributes of a graph + * Note that attribute is shared pointer and can be shared across graphs. + * + * It is highly recommended to keep each attribute immutable. + * It is also safe to implement an copy-on-write semnatics. + * + * Copy when shared_ptr.unique is not true, while reuse original space + * when shared_ptr.unique is true. + */ + std::unordered_map > attrs; + /*! + * \brief Get the immutable attribute from attrs. + * \param attr_name the name of the attribute + * \return the reference to corresponding attribute + * \tparam T the type of the attribute. + */ + template + inline const T& GetAttr(const std::string& attr_name) const; + /*! + * \brief Check whether has a specific attribute. + * \param attr_name the name of the attribute + * \return a boolean result + */ + inline bool HasAttr(const std::string& attr_name) const; + /*! + * \brief Get a move copy of the attribute, implement copy on write semantics. + * The content is moved if the reference counter of shared_ptr is 1. + * The attribute is erased from attrs after the call. + * + * \param attr_name the name of the attribute + * \return a new copy of the corresponding attribute. + * \tparam T the type of the attribute. + */ + template + inline T MoveCopyAttr(const std::string& attr_name); + /*! + * \brief get a indexed graph of current graph, if not exist, create it on demand + * \return The indexed graph. + * \sa IndexedGraph + */ + const IndexedGraph& indexed_graph() const; + + private: + // internal structure of indexed graph + mutable std::shared_ptr indexed_graph_; +}; + +/*! + * \brief Auxiliary data structure to index a graph. + * It maps Nodes in the graph to consecutive integers node_id. + * It also maps IndexedGraph::NodeEntry to consecutive integer entry_id. + * This allows storing properties of Node and NodeEntry into + * compact vector and quickly access them without resorting to hashmap. + * + * The node_id and entry_rptr are the same as the JSON graph produced by SaveJSON Pass. + */ +class IndexedGraph { + public: + /*! \brief represents a data in the graph */ + struct NodeEntry { + /*! \brief the source node id in the computation graph */ + uint32_t node_id; + /*! \brief index of output from the source. */ + uint32_t index; + /*! \brief version of the node */ + uint32_t version; + }; + /*! \brief Node data structure in IndexedGraph */ + struct Node { + /*! \brief pointer to the source node */ + const nnvm::Node* source; + /*! \brief inputs to the node */ + array_view inputs; + /*! \brief control flow dependencies to the node */ + array_view control_deps; + /*! \brief weak reference to node */ + std::weak_ptr weak_ref; + }; + /*! \return number of nodes in the graph */ + inline size_t num_nodes() const { + return nodes_.size(); + } + /*! \return total number of NodeEntry in the graph */ + inline size_t num_node_entries() const { + return entry_rptr_.back(); + } + /*! + * \brief Get a unique entry id between 0 to num_node_entries() + * for a given IndexedGraph::NodeEntry + * \param node_id The node index + * \param index the output index + * \return the unique index. + */ + inline uint32_t entry_id(uint32_t node_id, uint32_t index) const { + return entry_rptr_[node_id] + index; + } + /*! + * \brief Get a unique entry id between 0 to num_node_entries() + * for a given IndexedGraph::NodeEntry + * \param e The entry to query for index. + * \return the unique index. + */ + inline uint32_t entry_id(const NodeEntry& e) const { + return entry_rptr_[e.node_id] + e.index; + } + /*! + * \brief Get a unique entry id between 0 to num_node_entries() + * for a given NodeEntry. + * \param e The entry to query for index. + * \return the unique index. + */ + inline uint32_t entry_id(const nnvm::NodeEntry& e) const { + return entry_rptr_[node_id(e.node.get())] + e.index; + } + /*! + * \brief Get the corresponding node id for a given Node in the IndexedGraph. + * \param node The Node to query for index. + * \return the node index. + */ + inline uint32_t node_id(const nnvm::Node* node) const { + return node2index_.at(node); + } + /*! + * \brief Get the corresponding Node structure for a given node_id. + * \param node_id The node id + * \return const reference to the corresponding IndexedGraph::Node + */ + inline const Node& operator[](uint32_t node_id) const { + return nodes_[node_id]; + } + /*! + * \brief Get the corresponding Node structure + * \param node The pointer to the Node structure + * \return const reference to the corresponding IndexedGraph::Node + */ + inline const Node& operator[](const nnvm::Node* node) const { + return nodes_[node_id(node)]; + } + /*! \return list of argument nodes */ + inline const std::vector& input_nodes() const { + return input_nodes_; + } + /*! \return list of mutable nodes */ + inline const std::unordered_set& mutable_input_nodes() const { + return mutable_input_nodes_; + } + /*! \return list of output entries */ + inline const std::vector& outputs() const { + return outputs_; + } + + /*! \return whether a node is existed in the indexed graph */ + inline bool exist(const nnvm::Node* node) const { + return node2index_.count(node); + } + + // disalllow copy assign + IndexedGraph(const IndexedGraph&) = delete; + + private: + friend class Graph; + /*! + * \brief Constructor an IndexedGraph from normal Graph + * \param other The source graph. + */ + explicit IndexedGraph(const Graph& other); + // Node pointers in CSR structure. + std::vector nodes_; + // Index to all input nodes. + std::vector input_nodes_; + // Index to all mutable input nodes. + std::unordered_set mutable_input_nodes_; + // space to store the outputs entries + std::vector outputs_; + // mapping from node to index. + std::unordered_map node2index_; + // CSR pointer of node entries + std::vector entry_rptr_; + // space to store input entries of each + std::vector input_entries_; + // control flow dependencies + std::vector control_deps_; +}; + +/*! + * \brief perform a Post Order DFS visit to each node in the graph. + * This order is deterministic and is also topoligical sorted. + * \param heads The heads in the graph. + * \param fvisit a function of type std::function&)> + * \tparam FVisit The function type to perform the visit. + */ +template +inline void DFSVisit(const std::vector& heads, FVisit fvisit); + +// inline function implementations +template +inline const T& Graph::GetAttr(const std::string& attr_name) const { + auto it = attrs.find(attr_name); + CHECK(it != attrs.end()) + << "Cannot find attribute " << attr_name << " in the graph"; + return nnvm::get(*it->second); +} + +inline bool Graph::HasAttr(const std::string& attr_name) const { + auto it = attrs.find(attr_name); + return it != attrs.end(); +} + +template +inline T Graph::MoveCopyAttr(const std::string& attr_name) { + auto it = attrs.find(attr_name); + CHECK(it != attrs.end()) + << "Cannot find attribute " << attr_name << " in the graph"; + std::shared_ptr sptr = it->second; + attrs.erase(it); + if (sptr.unique()) { + return std::move(nnvm::get(*sptr)); + } else { + return nnvm::get(*sptr); + } +} + +template +void PostOrderDFSVisit(const std::vector& heads, + FVisit fvisit, + HashFunc hash, + InDegree indegree, + GetInput getinput) { + std::vector > stack; + std::unordered_set visited; + for (auto& head : heads) { + HashType head_hash = hash(head); + if (visited.count(head_hash) == 0) { + stack.push_back(std::make_pair(head, 0)); + visited.insert(head_hash); + } + while (!stack.empty()) { + std::pair& back = stack.back(); + if (back.second == indegree(back.first)) { + fvisit(back.first); + stack.pop_back(); + } else { + const GNode& input = getinput(back.first, back.second++); + HashType input_hash = hash(input); + if (visited.count(input_hash) == 0) { + stack.push_back(std::make_pair(input, 0)); + visited.insert(input_hash); + } + } + } + } +} + +template +inline void DFSVisit(const std::vector& heads, + FVisit fvisit) { + typedef const NodePtr* GNode; + std::vector head_nodes(heads.size()); + std::transform(heads.begin(), heads.end(), head_nodes.begin(), + [](const NodeEntry& e)->GNode { + return &e.node; + }); + PostOrderDFSVisit( + head_nodes, + [fvisit](GNode n) { fvisit(*n); }, // FVisit + [](GNode n)->Node* { return n->get(); }, // HashFunc + [](GNode n)->uint32_t { // InDegree + if (!(*n)) return 0; + return (*n)->inputs.size() + (*n)->control_deps.size(); + }, + [](GNode n, uint32_t index)->GNode { // GetInput + if (index < (*n)->inputs.size()) { + return &(*n)->inputs.at(index).node; + } else { + return &(*n)->control_deps.at(index - (*n)->inputs.size()); + } + }); +} + +} // namespace nnvm + +#endif // NNVM_GRAPH_H_ diff --git a/include/nnvm/graph_attr_types.h b/include/nnvm/graph_attr_types.h new file mode 100644 index 000000000000..2fe82c9a7de0 --- /dev/null +++ b/include/nnvm/graph_attr_types.h @@ -0,0 +1,112 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file nnvm/graph_attr_types.h + * \brief Data structures that can appear in graph attributes. + */ +#ifndef NNVM_GRAPH_ATTR_TYPES_H_ +#define NNVM_GRAPH_ATTR_TYPES_H_ + +#include +#include +#include "tuple.h" +#include "layout.h" + +namespace nnvm { + +/*! + * \brief The result holder of JSON serializer + * + * \note Stored under ret.attrs["json"], provided by Pass "SaveJSON" + + * \code + * Graph ret = ApplyPass(src_graph, "SaveJSON"); + * const JSONString& json = ret.GetAttr("shape"); + * \endcode + */ +using JSONString = std::string; + +/*! + * \brief The result holder of shape of each NodeEntry in the graph. + * \note Stored under graph.attrs["shape"], provided by Pass "InferShape" + * + * \code + * Graph g = ApplyPass(src_graph, "InferShape"); + * const ShapeVector& shapes = g.GetAttr("shape"); + * // get shape by entry id + * TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)]; + * \endcode + * + * \sa FInferShape + */ +using ShapeVector = std::vector; + +/*! + * \brief The result holder of type of each NodeEntry in the graph. + * \note Stored under graph.attrs["dtype"], provided by Pass "InferType" + * + * \code + * Graph g = ApplyPass(src_graph, "InferType"); + * const DTypeVector& types = g.GetAttr("dtype"); + * // get type by entry id + * int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)]; + * \endcode + * + * \sa FInferType + */ +using DTypeVector = std::vector; + +/*! + * \brief The result holder of layout of each NodeEntry in the graph. + * \note Stored under graph.attrs["layout"], provided by Pass "InferType" + * + * \code + * Graph g = ApplyPass(src_graph, "LayoutTransform"); + * const LayoutVector& layouts = g.GetAttr("layout"); + * // get layout by entry id + * int entry_layout = layouts[g.indexed_graph().entry_id(my_entry)]; + * \endcode + * + * \sa FCorrectLayout + */ +using LayoutVector = std::vector; + +/*! + * \brief The result holder of device of each operator in the graph. + * \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice" + * + * \code + * Graph g = ApplyPass(src_graph, "PlaceDevice"); + * const &device = g.GetAttr("device"); + * // get device by node_id + * int device_type = device[g.indexed_graph().node_id(my_node)]; + * \endcode + */ +using DeviceVector = std::vector; + +/*! + * \brief The result holder of device of each operator in the graph. + * + * \note Stored under graph.attrs["device_assign_map"], needed by Pass "PlaceDevice" + * -1 means unknown device + */ +using DeviceAssignMap = std::unordered_map; + +/*! + * \brief The result holder of storage id of each NodeEntry in the graph. + * + * \note Stored under graph.attrs["storage"], provided by Pass "PlanMemory" + * Storage id is a continuous integer. + * If the storage id is -1 then the storage is not assigned. + * + * \code + * Graph g = ApplyPass(src_graph, "PlanMemory"); + * const &storage = g.GetAttr("storage"); + * // get storage id by entry + * int storage_id = storage[g.indexed_graph().entry_id(my_entry)]; + * \endcode + */ +using StorageVector = std::vector; + +} // namespace nnvm + +#endif // NNVM_GRAPH_ATTR_TYPES_H_ diff --git a/include/nnvm/layout.h b/include/nnvm/layout.h new file mode 100644 index 000000000000..94813f5323f8 --- /dev/null +++ b/include/nnvm/layout.h @@ -0,0 +1,455 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file nnvm/layout.h + * \brief Layout expression. + * The layout is composed of upper cases, lower cases and numbers, + * where upper case indicates a (super-)dimension and + * the corresponding lower case with factor size indicates the split (sub-)dimension. + * For example, NCHW16c can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. + * Here sub-dimension channel_block=16 is the split of super-dimension C (channel). + */ +#ifndef NNVM_LAYOUT_H_ +#define NNVM_LAYOUT_H_ + +#include +#include +#include +#include +#include +#include + +namespace nnvm { + +class Layout { + public: + using LayoutDim = char; + + /*! \brief default constructor */ + Layout() : name_("__undef__") {} // NOLINT(*) + + /*! + * \brief construct from a string. + * \param layout input in layout convention: + * upper case indicates a dimension and + * the corresponding lower case with factor size + * indicates the split dimension. + * return undefined layout if "__undef__" is passed. + */ + inline Layout(const std::string& layout) { // NOLINT(*) + parse(layout); + } + /*! + * \brief copy constructor from another layout + * \param s the source layout + */ + inline Layout(const Layout& s) { // NOLINT(*) + this->parse(s.name_); + } + /*! + * \brief move constructor from Layout + * \param src the source layout + */ + inline Layout(Layout&& src) { // NOLINT(*) + this->swap(src); + } + /*! + * \brief assignment from another layout. + * \param src source layout + * \return reference of self + */ + inline Layout& operator=(const Layout& src) { + this->parse(src.name_); + return *this; + } + /*! + * \brief assignment from rvalue of another layout. + * \param src source layout + * \return reference of self + */ + inline Layout& operator=(Layout&& src) { + Layout(std::move(src)).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief assignment from string. + * \param src source layout + * \return reference of self + */ + inline Layout& operator=(const std::string& src) { + this->parse(src); + return *this; + } + /*! + * \return whether two layout equals + * \param s the layout to compare against + */ + inline bool operator==(const Layout& s) const { + return name_ == s.name_; + } + /*! + * \return whether two layout not equal + * \param s the layout to compare against + */ + inline bool operator!=(const Layout& s) const { + return !(*this == s); + } + + /*! + * \brief Append the current layout by another. + * @param other the layout to be appended + * @return a new layout + */ + inline Layout operator+(const Layout& other) const { + if (!this->defined() && !other.defined()) { + return Layout::Undef(); + } else if (!this->defined()) { + return other; + } else if (!other.defined()) { + return *this; + } + return Layout(this->name_ + other.name_); + } + + /*! + * \brief Check whether a given dimension is a super-dimension. + * \param dim input dimension + * \return Whether a given dimension is a super-dimension. + */ + static inline bool is_superdim(LayoutDim dim) { + return dim >= 'A' && dim <= 'Z'; + } + + /*! + * \brief Check whether a given dimension is a sub-dimension. + * \param dim input dimension + * \return Whether a given dimension is a sub-dimension. + */ + static inline bool is_subdim(LayoutDim dim) { + return dim >= 'a' && dim <= 'z'; + } + + /*! + * \brief Convert a given dimension to super-dimension. + * \param dim input dimension + * \return The converted description. + */ + static inline LayoutDim to_superdim(LayoutDim dim) { + if (is_subdim(dim)) { + return dim - 'a' + 'A'; + } + return dim; + } + + /*! + * \brief Convert a given dimension to sub-dimension. + * \param dim input dimension + * \return The converted description. + */ + static inline LayoutDim to_subdim(LayoutDim dim) { + if (is_superdim(dim)) { + return dim - 'A' + 'a'; + } + return dim; + } + + /*! + * \brief Return an undefined layout. + * \return a (global) undefined layout. + */ + static inline const Layout& Undef() { + static Layout undef; + return undef; + } + + /*! + * \brief Swap current object with other + * \param other another object to be swapped. + */ + inline void swap(Layout& other) { // NOLINT(*) + std::swap(name_, other.name_); + std::swap(superdim_pos_, other.superdim_pos_); + std::swap(subdim_pos_, other.subdim_pos_); + std::swap(subdim_size_, other.subdim_size_); + std::swap(layout_simplified_, other.layout_simplified_); + } + + /*! + * \brief Two layouts are convertible only if + * they have same set of super-dimensions. + * e.g., NCHW, NCHW16c, NHWC are convertible between each other, + * but NCHW, CHW, OIHW are not. + * \param dst the target layout + * \return Whether can be converted to dst layout. + */ + inline bool convertible(const Layout &dst) const { + if (!this->defined() || !dst.defined()) return false; + for (size_t i = 0; i < kUniqueDim; ++i) { + if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) || + (superdim_pos_[i] < 0 && dst.superdim_pos_[i] >= 0)) { + return false; + } + } + return true; + } + + /*! + * \brief Returns a sublayout which is the portion of the object + * that starts at dimension \p pos and spans \p len dimensions + * (or until the end of the layout, whichever comes first). + * \param pos The start position. + * \param len The length of the sub-layout. + * \return A newly constructed Layout object. + */ + inline Layout sublayout(size_t pos, size_t len) const { + if (pos > ndim()) return Layout::Undef(); + if (pos + len > ndim()) len = ndim() - pos; + if (len == 0) return Layout::Undef(); + std::ostringstream new_layout; + for (size_t i = pos; i < pos + len; ++i) { + if (is_subdim(layout_simplified_[i])) { + auto block_size = this->subsizeof(layout_simplified_[i]); + CHECK_GT(block_size, 0); + new_layout << block_size; + } + new_layout << layout_simplified_[i]; + } + return Layout(new_layout.str()); + } + + /*! \return A newly constructed reversed Layout object. */ + inline Layout reverse() const { + if (!this->defined()) return Layout::Undef(); + std::ostringstream new_layout; + for (int64_t i = this->ndim() - 1; i >= 0; --i) { + if (is_subdim(layout_simplified_[i])) { + auto block_size = this->subsizeof(layout_simplified_[i]); + CHECK_GT(block_size, 0); + new_layout << block_size; + } + new_layout << layout_simplified_[i]; + } + return Layout(new_layout.str()); + } + + /*! + * \brief Split \p dim by \p size and put the sub-dimension to position \p target_pos. + * \param dim The source dimension to be split. It must be a super-dimension. + * \param target_pos The target position of the newly split sub-dimension. + * \param size size of the sub-dimension. + * \return A newly constructed Layout object. + */ + inline Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const { + CHECK(target_pos <= this->ndim()) << "Invalid split position " + << target_pos << " for layout " << name_; + CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim; + CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_; + CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim + << " has already been split in " + << name_; + CHECK(size > 0) << "Invalid split size " << size; + std::ostringstream new_layout; + for (size_t i = 0; i <= this->ndim(); ++i) { + if (i == target_pos) { + new_layout << size << Layout::to_subdim(dim); + } + if (i == this->ndim()) break; + new_layout << this->at(i); + } + Layout x(new_layout.str()); + return x; + } + + using iterator = std::vector::const_iterator; + using reverse_iterator = std::vector::const_reverse_iterator; + + /*! \return begin iterator */ + inline iterator begin() const { + return layout_simplified_.begin(); + } + /*! \return end iterator */ + inline iterator end() const { + return layout_simplified_.end(); + } + /*! \return rbegin iterator */ + inline reverse_iterator rbegin() const { + return layout_simplified_.rbegin(); + } + /*! \return rend iterator */ + inline reverse_iterator rend() const { + return layout_simplified_.rend(); + } + + /*! \return number of dimensions */ + inline size_t ndim() const { + return layout_simplified_.size(); + } + + /*! + * \brief The description of the \p i-th dimension. + * If it is a sub-dimension, the size will be returned as well, + * e.g., 16c. Otherwise a single character is returned, e.g., C. + * \param i The position + * \return the description of the dimension. + */ + inline std::string at(size_t i) const { + CHECK_LT(i, this->ndim()) << "position " << i + << " exceeds ndim=" << this->ndim(); + std::ostringstream repr; + if (is_subdim(layout_simplified_[i])) { + auto factor = subsizeof(layout_simplified_[i]); + CHECK_GT(factor, 0); + repr << factor; + } + repr << layout_simplified_[i]; + return repr.str(); + } + + /*! + * \brief return the index of the input dimension. + * If it is not found in the layout or the layout is undefined, + * return -1. + * \param dim the input dimension. + * \return the index or -1 if not found. + */ + inline int32_t indexof(LayoutDim dim) const { + if (!this->defined()) return -1; + else if (is_superdim(dim)) return superdim_pos_[dim - 'A']; + else if (is_subdim(dim)) return subdim_pos_[dim - 'a']; + return -1; + } + + /*! + * \param dim the input super-dimension or sub-dimension. + * \return the size of the sub-dimension of \p dim (if \p dim is a super-dimension), + * or the size of \p dim itself (if \p dim is a sub-dimension). + * Return -1 if \p dim is not in the layout or the layout is undefined. + */ + inline int64_t subsizeof(LayoutDim dim) const { + CHECK(is_superdim(dim) || is_subdim(dim)) << "Invalid dim " << dim; + if (!this->defined() || !this->contains(to_subdim(dim))) { + return -1; + } + int idx = to_subdim(dim) - 'a'; + return subdim_size_[idx]; + } + + /*! + * \brief Whether the layout contains a dimension. + * \param dim dimension to be checked. + * \return Whether the layout contains the dimension. + */ + inline bool contains(LayoutDim dim) const { + if (is_superdim(dim)) { + return superdim_pos_[dim-'A'] >= 0; + } else if (is_subdim(dim)) { + return subdim_pos_[dim-'a'] >= 0; + } + return false; + } + + inline LayoutDim operator[](size_t i) const { + return layout_simplified_[i]; + } + + /*! \return whether the layout is defined */ + inline bool defined() const { + return name_ != "__undef__"; + } + + /*! \return the string description of the layout */ + inline const std::string& name() const { + return name_; + } + + /*! + * \brief Write layout in JSON format. + * \param writer JSONWriter + */ + inline void Save(dmlc::JSONWriter* writer) const { + writer->Write(name_); + } + + /*! + * \brief Load layout from JSON. + * \param reader JSONReader + */ + inline void Load(dmlc::JSONReader* reader) { + std::string tmp; + reader->Read(&tmp); + this->parse(tmp); + } + + /*! + * \brief allow output string of layout to ostream + * \param os the output stream + * \param l the layout + * \return the ostream + */ + friend std::ostream& operator<<(std::ostream& os, const Layout& l) { + os << l.name_; + return os; + } + + private: + static const uint32_t kUniqueDim = 26; + + std::string name_; + int32_t superdim_pos_[kUniqueDim]; + int32_t subdim_pos_[kUniqueDim]; + int64_t subdim_size_[kUniqueDim]; + std::vector layout_simplified_; + + void parse(const std::string& layout) { + name_ = layout; + std::fill_n(superdim_pos_, kUniqueDim, -1); + std::fill_n(subdim_pos_, kUniqueDim, -1); + std::fill_n(subdim_size_, kUniqueDim, -1); + layout_simplified_.clear(); + + if (layout == "__undef__") return; + + int32_t factor = 0; + uint32_t curr = 0; + for (size_t i = 0; i < layout.size(); ++i) { + const LayoutDim c = layout.at(i); + if (is_superdim(c)) { + int pos = c - 'A'; + CHECK_EQ(factor, 0) << "Invalid layout " << layout + << ": invalid factor size " << factor + << " before dimension " << c; + CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout + << ": duplicate dimension " << c; + superdim_pos_[pos] = curr++; + layout_simplified_.push_back(c); + } else if (is_subdim(c)) { + int pos = c - 'a'; + CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size " + << factor << " for dimension " << c; + CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout + << ": duplicate dimension " << c; + CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout + << ": duplicate dimension " << c; + subdim_pos_[pos] = curr++; + subdim_size_[pos] = factor; + layout_simplified_.push_back(c); + factor = 0; + } else if (c >= '0' && c <= '9') { + CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number."; + factor = factor * 10 + c - '0'; + } else { + LOG(FATAL) << "Invalid layout " << layout; + } + } + CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout; + for (LayoutDim dim : layout_simplified_) { + CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0) + << "Invalid layout " << layout << ": missing axis " + << static_cast(dim - 'a' + 'A'); + } + } +}; + +} // namespace nnvm + +#endif // NNVM_LAYOUT_H_ diff --git a/include/nnvm/node.h b/include/nnvm/node.h new file mode 100644 index 000000000000..ae782f04965e --- /dev/null +++ b/include/nnvm/node.h @@ -0,0 +1,201 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file nnvm/node.h + * \brief Graph node data structure. + */ +#ifndef NNVM_NODE_H_ +#define NNVM_NODE_H_ + +#include +#include +#include +#include +#include "base.h" +#include "op.h" +#include "c_api.h" + +namespace nnvm { + +// Forward declare node. +class Node; +class Symbol; + +/*! + * \brief we always used NodePtr for a reference pointer + * to the node, so this alias can be changed in case. + * + * By default, NodePtr is a std::shared_ptr of node + */ +using NodePtr = std::shared_ptr; + +/*! \brief an entry that represents output data from a node */ +struct NodeEntry { + /*! \brief the source node of this data */ + NodePtr node; + /*! \brief index of output from the source. */ + uint32_t index; + /*! + * \brief version of input Variable. + * This field can only be nonzero when this->node is a Variable node. + * version is increased by one each time a Variable get composed to a mutation Op. + * This information can be helpful to decide order of operations when sequence of mutation happens. + */ + uint32_t version; +}; + +/*! + * \brief This lets you use a NodeEntry as a key in a unordered_map of the form + * unordered_map + */ +struct NodeEntryHash { + size_t operator()(const NodeEntry& e) const { + return std::hash()(e.node.get()) ^ + (std::hash()(e.index) << 1 >> 1) ^ + (std::hash()(e.version) << 1); + } +}; + +/*! + * \brief This lets you use a NodeEntry as a key in a unordered_map of the form + * unordered_map + */ +struct NodeEntryEqual { + size_t operator()(const NodeEntry& a, const NodeEntry& b) const { + return (a.node.get() == b.node.get()) && + (a.index == b.index) && + (a.version == b.version); + } +}; + +/*! use NodeEntry as key in unordered_map */ +template +using NodeEntryMap = std::unordered_map; + +/*! + * \brief The attributes of the current operation node. + * Usually are additional parameters like axis, + */ +struct NodeAttrs { + /*! + * \brief The operator this node uses. + * For place holder variable, op == nullptr. + */ + const Op *op{nullptr}; + /*! \brief name of the node */ + std::string name; + /*! \brief The dictionary representation of attributes */ + std::unordered_map dict; + /*! + * \brief A parsed version of attributes, + * This is generated if OpProperty.attr_parser is registered. + * The object can be used to quickly access attributes. + */ + any parsed; + /*! + * \brief Some operators take graphs as input. These operators include + * control flow operators and high-order functions. + * These graphs don't change when the operators are invoked for different + * mini-batches. In this sense, the subgraphs are kind of similar to + * the parameters and show be kept as node attributes. + * + * Users need to make sure the subgraphs are disjoint with the main graph. + * If a graph shares nodes with subgraphs, loading the graph from LoadJSON + * may generate a graph that has a different structure from the original graph + * (some of the nodes are duplicated). If nodes are shared between two graphs, + * shared nodes might be executed multiple times, which can be a problem for + * stateful operators. + */ + std::vector > subgraphs; +}; + +/*! + * \brief Node represents an operation in a computation graph. + */ +class NNVM_DLL Node { + public: + /*! \brief The attributes in the node. */ + NodeAttrs attrs; + /*! \brief inputs to this node */ + std::vector inputs; + /*! + * \brief Optional control flow dependencies + * Gives operation must be performed before this operation. + */ + std::vector control_deps; + /*! \brief additional fields for this node */ + any info; + /*! \brief destructor of node */ + ~Node(); + /*! \return operator in this node */ + inline const Op* op() const; + /*! + * \brief return whether node is placeholder variable. + * This is equivalent to op == nullptr + * \return whether node is placeholder input variable + */ + inline bool is_variable() const; + /*! \return number of outputs from this node */ + inline uint32_t num_outputs() const; + /*! \return number of inputs from this node */ + inline uint32_t num_inputs() const; + /*! + * \brief create a new empty shared_ptr of Node. + * \return a created empty node. + */ + static NodePtr Create(); +}; + +/*! + * \brief Quick utilities make node. + * \param op_name The name of operator + * \param node_name The name of the node + * \param inputs The input entries + * \param attrs The attributes + * \return The created node entry. + */ +inline NodeEntry MakeNode( + const char* op_name, + std::string node_name, + std::vector inputs, + std::unordered_map attrs = + std::unordered_map()) { + NodePtr p = Node::Create(); + p->attrs.op = nnvm::Op::Get(op_name); + p->attrs.name = std::move(node_name); + p->attrs.dict = attrs; + if (p->attrs.op->attr_parser) { + p->attrs.op->attr_parser(&(p->attrs)); + } + p->inputs = std::move(inputs); + return NodeEntry{p, 0, 0}; +} + +// implementation of functions. +inline const Op* Node::op() const { + return this->attrs.op; +} +inline bool Node::is_variable() const { + return this->op() == nullptr; +} + +inline uint32_t Node::num_outputs() const { + if (is_variable()) return 1; + if (this->op()->get_num_outputs == nullptr) { + return this->op()->num_outputs; + } else { + return this->op()->get_num_outputs(this->attrs); + } +} + +inline uint32_t Node::num_inputs() const { + if (is_variable()) return 1; + if (this->op()->get_num_inputs == nullptr) { + return this->op()->num_inputs; + } else { + return this->op()->get_num_inputs(this->attrs); + } +} + +} // namespace nnvm + +#endif // NNVM_NODE_H_ diff --git a/include/nnvm/op.h b/include/nnvm/op.h new file mode 100644 index 000000000000..9d171bbdb2bc --- /dev/null +++ b/include/nnvm/op.h @@ -0,0 +1,562 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file nnvm/op.h + * \brief Operator information structor. + */ +#ifndef NNVM_OP_H_ +#define NNVM_OP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "base.h" +#include "c_api.h" + +namespace nnvm { + +// forward declarations +class Node; +struct NodeAttrs; +template +class OpMap; +class OpGroup; +class OpRegistryEntry; +using dmlc::ParamFieldInfo; + +/*! \brief constant to indicate it take any length of positional inputs */ +static const uint32_t kVarg = std::numeric_limits::max(); + +/*! + * \brief Operator structure. + * + * Besides the fields in the structure, + * arbitary additional information can be associated with each op. + * See function GetAttr for details. + * + * \code + * // Example usage of Op + * + * // registeration of oeprators + * // NOTE that the attr function can register any + * // additional attributes to the operator + * NNVM_REGISTER_OP(add) + * .describe("add two inputs together") + * .set_num_inputs(2) + * .set_attr("OpKernel", AddKernel) + * .include("ElementwiseOpAttr"); + * + * // can register attribute by group + * // all the ops that include the group get the attribute. + * NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr) + * .set_attr("FInferShape", ElementwiseInferShape); + * + * NNVM_REGISTER_OP(sub) + * .describe("substract one tensor from another") + * .set_num_inputs(2); + * + * // Can call regster multiple times in different files + * // to register different part of information + * NNVM_REGISTER_OP(sub) + * .set_attr("OpKernel", SubKernel); + * .include("ElementwiseOpAttr"); + * + * // get operators from registry. + * void my_function() { + * const Op* add = Op::Get("add"); + * const Op* sub = Op::Get("sub"); + * // query basic information about each operator. + * assert(op->name == "plus"); + * assert(op->num_inputs == 2); + * + * // get additional registered information, + * // Assume user registered a OpKernel type attribute as gpu_kernel on each operator. + * const OpMap& kernel = Op::GetAttr("OpKernel"); + * // we can get the kernel functions by using operator as key. + * auto add_kernel = kernel[add]; + * auto sub_kernel = kernel[sub]; + * // subsequent code can make use of the queried kernel functions. + * } + * \endcode + */ +class NNVM_DLL Op { + public: + /*! \brief name of the operator */ + std::string name; + /*! + * \brief detailed description of the operator + * This can be used to generate docstring automatically for the operator. + */ + std::string description; + /* \brief description of inputs and keyword arguments*/ + std::vector arguments; + /*! + * \brief number of inputs to the operator, + * -1 means it is variable length + * When get_num_inputs is presented, + * the number will be decided by get_num_inputs instead. + * \sa get_num_inputs + */ + uint32_t num_inputs = 1; + /*! + * \brief number of outputs of the operator + * When get_num_outputs is presented. + * The number of outputs will be decided by + * get_num_outputs function + * \sa get_num_outputs + */ + uint32_t num_outputs = 1; + /*! + * \brief support level of the operator, + * The lower the more priority it contains. + * This is in analogies to BLAS levels. + */ + uint32_t support_level = 10; + /*! + * \brief get number of outputs given information about the node. + * \param attrs The attribute of the node + * \return number of outputs. + */ + std::function get_num_outputs = nullptr; + /*! + * \brief get number of inputs given information about the node. + * \param attrs The attribute of the node + * \return number of inputs + */ + std::function get_num_inputs = nullptr; + /*! + * \brief Attribute parser to parse the NodeAttrs information. + * + * This can help to get quick access to a parsed attribute + * object + * + * \code + * // Example usage of attr_parser. + * + * // Suppose we want to register operator sum. + * // The parameters about sum operator + * struct SumParam { + * int axis; + * }; + * // The parser function + * void SumAttrParser(NodeAttrs* attrs) { + * // This will be invoked during node construction. + * SumParam param; + * // parse axis string to integer + * param.axis = atoi(attrs->dict["axis"].c_str()); + * // set the parsed parameter + * attrs->parsed = std::move(param); + * } + * // The other function that can utilize the parsed result. + * TShape SumInferShape(const NodeAttrs& attrs, + * const std::vector& ishapes) { + * // we can use the parsed version of param + * // without repeatively parsing the parameter + * const SumParam& param = nnvm::get(attrs.parsed); + * } + * \endcode + */ + std::function attr_parser = nullptr; + // function fields. + /*! + * \brief setter function during registration + * Set the description of operator + * \param descr the description string. + * \return reference to self. + */ + inline Op& describe(const std::string& descr); // NOLINT(*) + /*! + * \brief Add argument information to the function. + * \param name Name of the argument. + * \param type Type of the argument. + * \param description Description of the argument. + * \return reference to self. + */ + inline Op& add_argument(const std::string &name, + const std::string &type, + const std::string &description); + /*! + * \brief Append list if arguments to the end. + * \param args Additional list of arguments. + * \return reference to self. + */ + inline Op& add_arguments(const std::vector &args); + /*! + * \brief Set the num_inputs + * \param n The number of inputs to be set. + * \return reference to self. + */ + inline Op& set_num_inputs(uint32_t n); // NOLINT(*) + /*! + * \brief Set the support level of op. + * \param level The support level. + * \return reference to self. + */ + inline Op& set_support_level(uint32_t level); // NOLINT(*) + /*! + * \brief Set the get_num_outputs function. + * \param fn The function to be set. + * \return reference to self. + */ + inline Op& set_num_inputs(std::function fn); // NOLINT(*) + /*! + * \brief Set the num_outputs + * \param n The number of outputs to be set. + * \return reference to self. + */ + inline Op& set_num_outputs(uint32_t n); // NOLINT(*) + /*! + * \brief Set the get_num_outputs function. + * \param fn The function to be set. + * \return reference to self. + */ + inline Op& set_num_outputs(std::function fn); // NOLINT(*) + /*! + * \brief Set the attr_parser function. + * \param fn The number of outputs to be set. + * \return reference to self. + */ + inline Op& set_attr_parser(std::function fn); // NOLINT(*) + /*! + * \brief Register additional attributes to operator. + * \param attr_name The name of the attribute. + * \param value The value to be set. + * \param plevel The priority level of this set, + * an higher priority level attribute + * will replace lower priority level attribute. + * Must be bigger than 0. + * + * Cannot set with same plevel twice in the code. + * + * \tparam ValueType The type of the value to be set. + */ + template + inline Op& set_attr(const std::string& attr_name, // NOLINT(*) + const ValueType& value, + int plevel = 10); + /*! + * \brief Add another alias to this operator. + * The same Op can be queried with Op::Get(alias) + * \param alias The alias of the operator. + * \return reference to self. + */ + Op& add_alias(const std::string& alias); // NOLINT(*) + /*! + * \brief Include all the attributes from an registered op group. + * \param group_name The name of the group. + * \return reference to self. + * + * \sa NNVM_REGISTER_OP_GROUP + */ + Op& include(const std::string& group_name); + /*! + * \brief Get an Op for a given operator name. + * Will raise an error if the op has not been registered. + * \param op_name Name of the operator. + * \return Pointer to a Op, valid throughout program lifetime. + */ + static const Op* Get(const std::string& op_name); + /*! + * \brief Get additional registered attribute about operators. + * If nothing has been registered, an empty OpMap will be returned. + * \param attr_name The name of the attribute. + * \return An OpMap of specified attr_name. + * \tparam ValueType The type of the attribute. + */ + template + static const OpMap& GetAttr(const std::string& attr_name); + + private: + template + friend class OpMap; + friend class OpGroup; + friend class dmlc::Registry; + // Program internal unique index of operator. + // Used to help index the program. + uint32_t index_{0}; + // internal constructor + Op(); + // get const reference to certain attribute + static const any* GetAttrMap(const std::string& key); + // update the attribute OpMap + static void UpdateAttrMap(const std::string& key, + std::function updater); + // add a trigger based on tag matching on certain tag attribute + // This will apply trigger on all the op such that + // include the corresponding group. + // The trigger will also be applied to all future registrations + // that calls include + static void AddGroupTrigger(const std::string& group_name, + std::function trigger); +}; + +/*! + * \brief A map data structure that takes Op* as key + * and returns ValueType + * \tparam ValueType The type of the value stored in map. + */ +template +class OpMap { + public: + /*! + * \brief get the corresponding value element at op + * \param op The key to the map + * \return the const reference to the content value. + */ + inline const ValueType& operator[](const Op* op) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param op The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + */ + inline const ValueType& get(const Op* op, const ValueType& def_value) const; + /*! + * \brief Check if the map has op as key. + * \param op The key to the map + * \return 1 if op is contained in map, 0 otherwise. + */ + inline int count(const Op* op) const; + + private: + friend class Op; + // internal attribute name + std::string attr_name_; + // internal data + std::vector > data_; + OpMap() = default; +}; + +/*! + * \brief auxiliary data structure used to + * set attributes to a group of operators + */ +class OpGroup { + public: + /*! \brief the tag key to be matched */ + std::string group_name; + /*! + * \brief Register additional attributes to operator group. + * \param attr_name The name of the attribute. + * \param value The value to be set. + * \param plevel The priority level of this set, + * an higher priority level attribute + * will replace lower priority level attribute. + * Must be bigger than 0. + * + * Cannot set with same plevel twice in the code. + * + * \tparam ValueType The type of the value to be set. + */ + template + inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*) + const ValueType& value, + int plevel = 1); +}; + +// internal macros to make +#define NNVM_REGISTER_VAR_DEF(OpName) \ + static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName + +#define NNVM_REGISTER_GVAR_DEF(TagName) \ + static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName + +/*! + * \def NNVM_REGISTER_OP + * \brief Register a new operator, or set attribute of the corresponding op. + * + * \param OpName The name of registry + * + * \code + * + * NNVM_REGISTER_OP(add) + * .describe("add two inputs together") + * .set_num_inputs(2) + * .set_attr("gpu_kernel", AddKernel); + * + * \endcode + */ +#define NNVM_REGISTER_OP(OpName) \ + DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ + ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName) + +/*! + * \def NNVM_REGISTER_OP_GROUP + * \brief Register attribute to a group of operators. + * These attributes will be registered to Op that include the group. + * + * \param GroupName The name of the group. + * + * \code + * + * NNVM_REGISTER_OP(add) + * .include("ElementwiseOpAttr"); + * + * // register same attributes to all the ops that include the group + * NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr) + * .set_attr("FInferShape", ElementwiseInferShape); + * + * NNVM_REGISTER_OP(mul) + * .include("ElementwiseOpAttr"); + * + * \endcode + */ +#define NNVM_REGISTER_OP_GROUP(GroupName) \ + DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \ + ::nnvm::OpGroup {#GroupName} + +// implementations of template functions after this. +// member function of Op +template +inline const OpMap& Op::GetAttr(const std::string& key) { + const any* ref = GetAttrMap(key); + if (ref == nullptr) { + // update the attribute map of the key by creating new empty OpMap + UpdateAttrMap(key, [key](any* pmap) { + // use callback so it is in lockscope + if (pmap->empty()) { + OpMap pm; + pm.attr_name_ = key; + *pmap = std::move(pm); + } + }); + ref = GetAttrMap(key); + } + return nnvm::get >(*ref); +} + +template +inline Op& Op::set_attr( // NOLINT(*) + const std::string& attr_name, + const ValueType& value, + int plevel) { + CHECK_GT(plevel, 0) + << "plevel in set_attr must be greater than 0"; + // update the attribute map of the key by creating new empty if needed. + UpdateAttrMap(attr_name, + [this, attr_name, value, plevel](any* pmap) { + // the callback is in lockscope so is threadsafe. + if (pmap->empty()) { + OpMap pm; + pm.attr_name_ = attr_name; + *pmap = std::move(pm); + } + CHECK(pmap->type() == typeid(OpMap)) + << "Attribute " << attr_name + << " of operator " << this->name + << " is registered as inconsistent types" + << " previously " << pmap->type().name() + << " current " << typeid(OpMap).name(); + std::vector >& vec = + nnvm::get >(*pmap).data_; + // resize the value type. + if (vec.size() <= index_) { + vec.resize(index_ + 1, + std::make_pair(ValueType(), 0)); + } + std::pair& p = vec[index_]; + CHECK(p.second != plevel) + << "Attribute " << attr_name + << " of operator " << this->name + << " is already registered with same plevel=" << plevel; + if (p.second < plevel) { + vec[index_] = std::make_pair(value, plevel); + } + }); + return *this; +} + + +inline Op& Op::describe(const std::string& descr) { // NOLINT(*) + this->description = descr; + return *this; +} + +inline Op& Op::add_argument(const std::string &name, + const std::string &type, + const std::string &description) { + arguments.push_back({name, type, type, description}); + return *this; +} + +inline Op& Op::add_arguments(const std::vector &args) { + this->arguments.insert(arguments.end(), args.begin(), args.end()); + return *this; +} + +inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*) + this->num_inputs = n; + return *this; +} + +inline Op& Op::set_support_level(uint32_t n) { // NOLINT(*) + this->support_level = n; + return *this; +} + +inline Op& Op::set_num_inputs(std::function fn) { // NOLINT(*) + this->get_num_inputs = fn; + return *this; +} + +inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*) + this->num_outputs = n; + return *this; +} + +inline Op& Op::set_num_outputs(std::function fn) { // NOLINT(*) + this->get_num_outputs = fn; + return *this; +} + +inline Op& Op::set_attr_parser(std::function fn) { // NOLINT(*) + this->attr_parser = fn; + return *this; +} + +// member functions of OpMap +template +inline int OpMap::count(const Op* op) const { + if (op == nullptr) return 0; + const uint32_t idx = op->index_; + return idx < data_.size() ? (data_[idx].second != 0) : 0; +} + +template +inline const ValueType& OpMap::operator[](const Op* op) const { + CHECK(op != nullptr); + const uint32_t idx = op->index_; + CHECK(idx < data_.size() && data_[idx].second) + << "Attribute " << attr_name_ + << " has not been registered for Operator " << op->name; + return data_[idx].first; +} + +template +inline const ValueType& OpMap::get(const Op* op, const ValueType& def_value) const { + if (op == nullptr) return def_value; + const uint32_t idx = op->index_; + if (idx < data_.size() && data_[idx].second) { + return data_[idx].first; + } else { + return def_value; + } +} + +template +inline OpGroup& OpGroup::set_attr(const std::string& attr_name, + const ValueType& value, + int plevel) { + auto trigger = [attr_name, value, plevel](Op* op) { + op->set_attr(attr_name, value, plevel); + }; + Op::AddGroupTrigger(group_name, trigger); + return *this; +} + +} // namespace nnvm + +#endif // NNVM_OP_H_ diff --git a/include/nnvm/op_attr_types.h b/include/nnvm/op_attr_types.h new file mode 100644 index 000000000000..abed19f9bc7d --- /dev/null +++ b/include/nnvm/op_attr_types.h @@ -0,0 +1,219 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file nnvm/op_attr_types.h + * \brief Data structures that can appear in operator attributes. + */ +#ifndef NNVM_OP_ATTR_TYPES_H_ +#define NNVM_OP_ATTR_TYPES_H_ + +#include +#include +#include +#include +#include "base.h" +#include "node.h" +#include "tuple.h" +#include "layout.h" + +namespace nnvm { + +// These types are optional attributes in each operator. +// Each attribute can be required by some passes. + +/*! + * \brief Return list of input arguments names of each operator. + * + * \param attrs The attributes of the node. + * \return list of inputs + * \note Register under "FListInputNames", default return {"data"}. + * + * FListInputNames enables automatic variable creation for missing arguments. + */ +using FListInputNames = std::function (const NodeAttrs& attrs)>; + +/*! + * \brief Return number of visible outputs by the user. + * + * \param attrs The attributes of the node. + * + * \note Register under "FNumVisibleOutputs", default not registered. + * This can be used to hide certain output from the user, + * but the additional outputs can be used to pass information from + * forward to gradient pass. + */ +using FNumVisibleOutputs = std::function; + +/*! + * \brief Return list of output arguments names of each operator. + * + * \param attrs The attributes of the node. + * \return list of inputs + * \note Register under "FListOutputNames", default return {"outputs"}. + * + * FListOutputNames customized naming for operator outputs. + */ +using FListOutputNames = std::function (const NodeAttrs& attrs)>; + +/*! + * \brief Check whether operator will mutate k-th input. + * \param attrs The attributes of the node. + * \return list of input indices it mutates. + * + * \note Register under "FMutateInputs", default return false + * FMutateInputs enables mutation order handling correctly. + */ +using FMutateInputs = std::function (const NodeAttrs& attrs)>; + +/*! + * \brief Inference function of certain type. + * \tparam AttrType The type of the attribute to be infered. + * \return whether all attributes are inferred. + */ +template +using FInferNodeEntryAttr = std::function *in_attrs, + std::vector *out_attrs)>; + +/*! + * \brief Get attribute dictionary from node. + * + * \param attrs The attributes of the node. + * \return The attribute dict. + * \note Register under "FUpdateAttrDict" + */ +using FGetAttrDict = std::function< + std::unordered_map + (const NodeAttrs& attrs)>; + +/*! + * \brief Shape inference function. + * Update the shapes given the input shape information. + * TShape.ndim() == 0 means the shape is still unknown. + * + * \note Register under "FInferShape", + * by default do not update any shapes. + * + * FInferShape is needed by shape inference + */ +using FInferShape = FInferNodeEntryAttr; + +/*! + * \brief Type inference function. + * Update the type given the known type information. + * + * \note Register under "FInferType", + * by default set all the output types to 0. + */ +using FInferType = FInferNodeEntryAttr; + +/*! + * \brief Whether this op is an explicit backward operator, + * If TIsBackward is true: + * - The first control_deps of the node points to the corresponding forward operator. + * + * \note Register under "TIsBackward" + * This enables easier shape/type inference for backward operators. + */ +using TIsBackward = bool; + +/*! + * \brief Get possible inplace options. + * This function enables optimization to reuse memory of inputs in output. + * \param attrs The attributes of the node + * \return list of pair of that maps input->output, + * indicating possible in place operations. + * + * \note Register under "FInplaceOption", by default no inplace can happen. + */ +using FInplaceOption = std::function< + std::vector > (const NodeAttrs& attrs)>; + +/*! + * \brief Get if the inplace option is an identity + * This function enables inplace optimization even when input reference count + * is greater than one. + * \param attrs The attributes of the node + * \return list of bool indicating whether corresponding pair from FInplaceOption + * is an identity + * + * \note Register under "FInplaceIdentity", by default no identities. + */ +using FInplaceIdentity = std::function (const NodeAttrs& attrs)>; + +/*! + * \brief Get list of inputs in the op whose content are actually not used by the operator + * These are dummy input that can be used for example in zeros_like, ones_like. + * + * \param attrs The attributes of the node + * \return list input index that are not used by the operator. + * + * \note Register under "FIgnoreInputs". + */ +using FIgnoreInputs = std::function< + std::vector (const NodeAttrs& attrs)>; + +/*! + * \brief Get the gradient node of the op node + * This function generates the backward graph of the node + * \param nodeptr The node to take gradient + * \param out_grads Gradient of current node's outputs + * \return gradients of the inputs + * + * \note Register under "FGradient" + */ +using FGradient = std::function( + const NodePtr& nodeptr, + const std::vector& out_grads)>; + +/*! + * \brief Set the attributes of input variable. + * Usually used for setting initialization or weight decay. + * \param attrs The attributes of this node. + * \param var the input variable + * \param index index of var in all inputs + */ +using FSetInputVarAttrOnCompose = std::function; + +/*! + * \brief Infer & correct function of node layout. See \p Layout for layout convention + * \param attrs The attribute of the node. + * \param ilayouts Given the input layouts produced by ancestor nodes, + * it should be filled by layouts that the node requests. + * If the requested layout is different from what ancestor produces, + * a __layout_transform__ operator will be inserted automatically. + * \param last_ilayouts The input layouts requested by the node + * at the last infer pass (if any). + * This can be useful when an operator wants to keep + * the input layout the same as the original one. + * For example, after the pass of AlterOpLayout, + * transpose(input, axis=[1, 2, 3, 0]) may receive an input of NCHW16c layout, + * with which it cannot calculate with axis=[1, 2, 3, 0]. + * Last input layouts allow it to know what the layout it originally inferred, + * i.e., the layout in the imported model. + * \param olayouts Inferred output layouts. + * \return success flag. + */ +using FCorrectLayout = std::function *ilayouts, + const std::vector *last_ilayouts, + std::vector *olayouts)>; + +/*! + * \brief Get a list of inputs that represent graphs instead of data. + * Normally, input symbols are considered as data to the operator. However, + * control flow operators and high-order functions need to interpret symbols + * as graphs. + * \param attrs The attributes of this node. + * \return a list of input index that are interpreted as symbols by the operator. + * + * \note Register under "FInputGraph". + */ +using FInputGraph = std::function(const NodeAttrs& attrs)>; + +} // namespace nnvm + +#endif // NNVM_OP_ATTR_TYPES_H_ diff --git a/include/nnvm/pass.h b/include/nnvm/pass.h new file mode 100644 index 000000000000..2e8db6111887 --- /dev/null +++ b/include/nnvm/pass.h @@ -0,0 +1,128 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file nnvm/pass.h + * \brief Pass that can be applied to a graph. + */ +#ifndef NNVM_PASS_H_ +#define NNVM_PASS_H_ + +#include +#include +#include "base.h" +#include "graph.h" + +namespace nnvm { + +/*! + * \brief A PassFunction is an "Operator on Graph". + * It takes a source graph and return a graph that may or may + * not be the same as the input one. + * + * A pass function can either change the graph structure (thus, + * generating a new Graph), or add new attributes to the graph. + * + * \param src The graph to be transformed. + * \return The generated graph. + */ +typedef std::function PassFunction; + +/*! + * \brief Apply a series of pass transformations on the input graph. + * \param src The graph to be transformed. + * \param passes A list of pass names to be applied. + * \return The transformed graph + */ +Graph ApplyPasses(Graph src, + const std::vector& passes); + +/*! + * \brief Apply one pass to the graph. + * \param src The graph to be transformed. + * \param pass The name of pass to be applied. + * \return The transformed graph. + */ +inline Graph ApplyPass(Graph src, const std::string& pass) { + return ApplyPasses(src, {pass}); +} + + +/*! + * \brief Registry entry for pass functions. + */ +struct PassFunctionReg + : public dmlc::FunctionRegEntryBase { + /*! + * \brief Whether the pass will change graph structure + * If this is false, the pass will only change attributes. + */ + bool change_graph{false}; + /*! \brief dependencies on operator attributes */ + std::vector op_attr_dependency; + /*! \brief dependencies on attributes in the graph */ + std::vector graph_attr_dependency; + /*! \brief generated targets of graph attributes */ + std::vector graph_attr_targets; + /*! + * \brief Set whether this pass will change graph structure. + * \param v If true, the pass will change graph structure. + * \return Reference to self. + */ + PassFunctionReg& set_change_graph(bool v) { // NOLINT(*) + change_graph = v; + return *this; + } + /*! + * \brief Declare that this pass will generate the given graph attribute name + * once it is applied on the graph. + * \param attr_name Name of the graph attribute. + * \return Reference to self. + */ + PassFunctionReg& provide_graph_attr(const std::string& attr_name) { // NOLINT(*) + graph_attr_targets.push_back(attr_name); + return *this; + } + /*! + * \brief Declare this pass requires the given operator attribute to be + * available before being applied on the graph. + * \param attr_name Name of the attribute. + * \return Reference to self. + */ + PassFunctionReg& depend_op_attr(const std::string& attr_name) { // NOLINT(*) + op_attr_dependency.push_back(attr_name); + return *this; + } + /*! + * \brief Declare this pass requires the given graph attribute to be + * available before being applied on the graph. + * \param attr_name Name of the attribute. + * \return Reference to self. + */ + PassFunctionReg& depend_graph_attr(const std::string& attr_name) { // NOLINT(*) + graph_attr_dependency.push_back(attr_name); + return *this; + } +}; + +/*! + * \def NNVM_REGISTER_PASS + * \brief Macro to register pass fuctions. + * + * \code + * // example of registering a shape inference pass + * NNVM_REGISTER_PASS(InferShape) + * .describe("Shape Inference function, generate graph attributes") + * .provide_graph_attr("data_shape") + * .depend_graph_attr("indexed_graph") + * .depend_op_attr("infer_shape") + * .set_body([](const Graph& g) { + * // shape inference logic + * }); + * \endcode + */ +#define NNVM_REGISTER_PASS(name) \ + DMLC_REGISTRY_REGISTER(::nnvm::PassFunctionReg, PassFunctionReg, name) + +} // namespace nnvm + +#endif // NNVM_PASS_H_ diff --git a/include/nnvm/pass_functions.h b/include/nnvm/pass_functions.h new file mode 100644 index 000000000000..5a98dd456fb2 --- /dev/null +++ b/include/nnvm/pass_functions.h @@ -0,0 +1,190 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file nnvm/pass_functions.h + * \brief Pass functions that simply redirect the calls to ApplyPass + * + * This file serves as documentation on how to use functions implemented in "src/pass". + * It is totally optional to add these functions when you add a new pass, since + * ApplyPass can be directly called. + */ +#ifndef NNVM_PASS_FUNCTIONS_H_ +#define NNVM_PASS_FUNCTIONS_H_ + +#include +#include +#include +#include "base.h" +#include "pass.h" +#include "graph_attr_types.h" + +namespace nnvm { +namespace pass { + +/*! + * \brief Load a graph from JSON string, redirects to "LoadJSON" pass. + * \param json_str The json string. + * \return Loaded graph. + */ +inline Graph LoadJSON(const std::string& json_str) { + Graph ret; + ret.attrs["json"] = std::make_shared(json_str); + return ApplyPass(ret, "LoadJSON"); +} + +/*! + * \brief Save a graph to json, redirects to "SaveJSON" pass. + * \param graph The graph to be saved as json format. + * \return The json string. + */ +inline std::string SaveJSON(Graph graph) { + Graph ret = ApplyPass(std::move(graph), "SaveJSON"); + return ret.GetAttr("json"); +} + + +/*! + * \brief Print graph ir + * \param graph The graph to be printed + * \return The graph ir string. + */ +inline std::string PrintGraphIR(Graph graph) { + Graph ret = ApplyPass(std::move(graph), "PrintGraphIR"); + return ret.GetAttr("graphir"); +} + +/*! + * \brief Add control flow dependencies between nodes. + * + * This function will enforce the correct order between + * write (mutable operators) and read (immutable operators) + * to sovle write-after-read and read-after-write problems. + * + * \param src The input graph. + * \return A graph with proper control flow dependencies added. + */ +inline Graph OrderMutation(Graph src) { + return ApplyPass(std::move(src), "OrderMutation"); +} + +/*! + * \brief Infer shapes in the graph given the information. + * \param graph The input graph. + * \param shape_inputs The shapes of input symbols to the graph. + * \param shape_attr_key The key to the node attribute that can indicate shape. This is + * the place where manual hint for shapes could be injected. + * \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. + * The index of ShapeVector is given by graph.indexed_graph().entry_id. + */ +inline Graph InferShape(Graph graph, + ShapeVector shape_inputs, + std::string shape_attr_key = "") { + if (shape_inputs.size() != 0) { + graph.attrs["shape_inputs"] = std::make_shared(std::move(shape_inputs)); + } + if (shape_attr_key.length() != 0) { + graph.attrs["shape_attr_key"] = std::make_shared(std::move(shape_attr_key)); + } + return ApplyPass(std::move(graph), "InferShape"); +} + +/*! + * \brief Infer types in the graph given the information. + * \param graph The input graph. + * \param dtype_inputs The types of input symbols to the graph. + * \param dtype_attr_key The key to the node attribute that can indicate types. This is + * the place where manual hint for types could be injected. + * \return A graph with new attribute "dtype" containing inferred type of each NodeEntry. + * The index of ShapeVector is given by graph.indexed_graph().entry_id. + */ +inline Graph InferType(Graph graph, + DTypeVector dtype_inputs, + std::string dtype_attr_key = "") { + if (dtype_inputs.size() != 0) { + graph.attrs["dtype_inputs"] = std::make_shared(std::move(dtype_inputs)); + } + if (dtype_attr_key.length() != 0) { + graph.attrs["dtype_attr_key"] = std::make_shared(std::move(dtype_attr_key)); + } + return ApplyPass(std::move(graph), "InferType"); +} + +/*! + * \brief Place the devices for each operator in the graph. + * + * Current device placement is quite simple. Each operator is assigned to a "group" (stored + * in `device_group_attr_key` attribute). Each group is assigned to a device (stored in + * `device_assign_map` attribute). Operators will be placed to the device assigned to its + * group. Copy operators will be injected if cross device reference happens. + * + * \param graph The input graph. + * \param device_group_attr_key The attribute name for hints of device group. + * \param device_assign_map The assignment map of device. + * \param device_copy_op The name of copy op to be inserted when cross device copy happened. + * \return A graph with new attribute "device", cotaining device information of each node. + */ +inline Graph PlaceDevice(Graph graph, + std::string device_group_attr_key, + DeviceAssignMap device_assign_map, + std::string device_copy_op) { + graph.attrs["device_group_attr_key"] = std::make_shared(std::move(device_group_attr_key)); + graph.attrs["device_assign_map"] = std::make_shared(std::move(device_assign_map)); + graph.attrs["device_copy_op"] = std::make_shared(std::move(device_copy_op)); + return ApplyPass(std::move(graph), "PlaceDevice"); +} + +/*! + * \brief Get the gradient graph whose outputs are gradients of xs wrt to ys. + * \param graph The input graph. + * \param ys The entries we want to take gradient from. + * \param xs The input to take gradient with respect to. + * \param ys_out_grad The symbol for additional gradient to be propagate back to y. + * \param aggregate_fun Aggregation function applied to aggregate the inputs. + * \param mirror_fun Optional mirror function to do mirror optimization and save memory. + * \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same as like. + * \param zero_ops Optional, list of operators that outputs a single zero array. The first one + * must be zeros_like. + * \param copy_op_str Optional, name of the copy operation required to handle duplicates + * on the edge of the graph + * \return A new graph, whose outputs correspond to inputs of xs. + */ +inline Graph Gradient( + Graph graph, + std::vector ys, + std::vector xs, + std::vector ys_out_grad, + std::function&& inputs)> aggregate_fun = nullptr, + std::function mirror_fun = nullptr, + std::function + attr_hint_fun = nullptr, + std::vector zero_ops = std::vector(), + std::string copy_op_str = std::string()) { + graph.attrs["grad_ys"] = std::make_shared(std::move(ys)); + + graph.attrs["grad_xs"] = std::make_shared(std::move(xs)); + graph.attrs["grad_ys_out_grad"] = std::make_shared(std::move(ys_out_grad)); + if (aggregate_fun != nullptr) { + graph.attrs["grad_aggregate_fun"] = std::make_shared(aggregate_fun); + } + + if (mirror_fun != nullptr) { + graph.attrs["grad_mirror_fun"] = std::make_shared(mirror_fun); + } + + if (attr_hint_fun != nullptr) { + graph.attrs["attr_hint_fun"] = std::make_shared(attr_hint_fun); + } + + if (zero_ops.size()) { + graph.attrs["zero_ops"] = std::make_shared(std::move(zero_ops)); + } + + if (copy_op_str != std::string()) { + graph.attrs["copy_op"] = std::make_shared(std::move(copy_op_str)); + } + + return ApplyPass(std::move(graph), "Gradient"); +} + +} // namespace pass +} // namespace nnvm +#endif // NNVM_PASS_FUNCTIONS_H_ diff --git a/include/nnvm/symbolic.h b/include/nnvm/symbolic.h new file mode 100644 index 000000000000..42cf5dd775c2 --- /dev/null +++ b/include/nnvm/symbolic.h @@ -0,0 +1,217 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file nnvm/symbolic.h + * \brief Symbolic graph construction API + * + * This API is optional, but useful to allow user + * to construct NNVM Graph easily, and quickly create + * front-end host languages. + */ +#ifndef NNVM_SYMBOLIC_H_ +#define NNVM_SYMBOLIC_H_ + +#include +#include +#include +#include + +#include "base.h" +#include "node.h" + +namespace nnvm { +/*! + * \brief Symbol is help class used to represent the operator node in Graph. + * + * Symbol acts as an interface for building graphs from different components + * like Variable, Functor and Group. Symbol is also exported to python front-end + * (while Graph is not) to enable quick test and deployment. Conceptually, + * symbol is the final operation of a graph and thus including all the information + * required (the graph) to evaluate its output value. + */ +class NNVM_DLL Symbol { + public: + /*! \brief option passed to ListAttr */ + enum ListAttrOption { + /*! \brief recursively list all attributes */ + kRecursive = 0, + /*! \brief only list attributes in current node */ + kShallow = 1 + }; + /*! \brief option passed to ListInputNames */ + enum ListInputOption { + /*! \brief list all the arguments */ + kAll = 0, + /*! \brief list only read only arguments */ + kReadOnlyArgs = 1, + /*! + * \brief List auxiliary states that can be mutated by the graph. + * This excludes the ReadOnly arguments + */ + kAuxiliaryStates = 2 + }; + + /*! \brief output entries contained in the symbol */ + std::vector outputs; + + /*! + * \brief Copy the symbol. + * \return A deep copy of this symbol. + */ + Symbol Copy() const; + /*! + * \brief Print the symbol info to output stream. + * \param os The output stream to print to. + */ + void Print(std::ostream &os) const; // NOLINT(*) + /*! + * \brief Get the index-th element from the returned tuple. + * \param index Index of multi output. + * \return The symbol corresponds to the indexed element. + */ + Symbol operator[] (size_t index) const; + /*! + * \brief List the input variable nodes. + * + * The order of the returned list is the same as the order of the input list to `operator()`. + * + * \param option The options to list the arguments. + * \return The arguments list of this symbol, they can be either named or unnamed (empty string). + * \sa ListInputOption + */ + std::vector ListInputs(ListInputOption option) const; + /*! + * \brief List the input names. + * + * The order of the returned list is the same as the order of the input list to `operator()`. + * + * \param option The options to list the arguments. + * \return The arguments list of this symbol, they can be either named or unnamed (empty string). + * \sa ListInputOption + */ + std::vector ListInputNames(ListInputOption option) const; + /*! + * \brief List the names of outputs for this symbol. + * + * For normal operators, it is usually symbol node name + "_output". + * + * \return get the descriptions of outputs for this symbol. + */ + std::vector ListOutputNames() const; + /*! + * \brief Compose the symbol with arguments, this changes the current symbol. + * The kwargs passed in can be in-complete, + * + * The rest of the symbols will remain the same name. + * + * \param args Positional arguments. + * \param kwargs Keyword arguments for the symbol. + * \param name Name of returned symbol. + */ + void Compose(const array_view& args, + const std::unordered_map& kwargs, + const std::string& name); + /*! + * \brief Apply the symbol as a function, compose with arguments + * + * This is equivalent to Copy then Compose. + * + * \param args Positional arguments for the symbol. + * \param kwargs Keyword arguments for the symbol. + * \param name Name of returned symbol. + * \return A new Symbol which is the composition of current symbol with its arguments. + */ + Symbol operator () (const array_view& args, + const std::unordered_map& kwargs, + const std::string& name) const; + /*! + * \brief Add control flow dependencies to the operators in symbols. + * + * For grouped symbol, an error will be raised. This mutates current symbolic Node. + * + * \param src The symbols to depend on. + */ + void AddControlDeps(const Symbol& src); + /* + * \brief Get all the internal nodes of the symbol. + * \return symbol A new symbol whose output contains all the outputs of the symbols + * including input variables and intermediate outputs. + */ + Symbol GetInternals() const; + /* + * \brief Get the direct inputs of the head node(s) of this symbol. + * \return symbol A new symbol whose output contains all the inputs of the head + * node(s). + */ + Symbol GetChildren() const; + /*! + * \brief Set additional attributes to current node. + * + * This only works for symbol with outputs from single operators. + * For grouped symbol, an error will be raised. + * + * This function mutates the node's symbol and is not recommended. + * + * \param attrs The attributes to set. + */ + void SetAttrs(const std::vector >& attrs); + /*! + * \brief Get attributes from the symbol. + * + * This only works for symbol with outputs from single operators. + * For grouped symbol, an error will be raised. + * + * \param key Key of the attribute. When key == "name", it returns the name attirbute. + * \param out The output value of the attribute. + * \return true If the attribute exists, false if the attribute does not exist. + */ + bool GetAttr(const std::string& key, std::string* out) const; + /*! + * \brief Get attribute dictionary from the symbol. + * + * For grouped symbol, an error will be raised. + * + * \param option If recursive flag is set, the attributes of all children are retrieved. + * The name of symbol will be pre-pended to each key. + * \return The created attribute. + */ + std::unordered_map ListAttrs(ListAttrOption option) const; + /*! + * \brief Get attribute dictionary from the symbol and all children. + * + * For grouped symbol, an error will be raised. + * + * \return The created attribute in format . + */ + std::vector > + ListAttrsRecursive() const; + /*! + * \brief Create symbolic functor(AtomicSymbol) by given operator and attributes. + * \param op The operator. + * \param attrs The additional attributes. + * \return Symbol that can be used to call compose further. + */ + static Symbol CreateFunctor(const Op* op, + std::unordered_map attrs); + /*! + * \brief Create symbolic functor(AtomicSymbol) by given node attributes. + * \param attrs pre-initialized Node attributes. + * \return Symbol that can be used to call compose further. + */ + static Symbol CreateFunctor(const NodeAttrs& attrs); + /*! + * \brief Create symbol node representing variable. + * \param name Name of the variable. + * \return The symbol. + */ + static Symbol CreateVariable(const std::string& name); + /*! + * \brief Create equivalence of symbol by grouping the symbols together. + * \param symbols A list of symbols to be grouped. + * \return The grouped symbol. + */ + static Symbol CreateGroup(const std::vector& symbols); +}; + +} // namespace nnvm + +#endif // NNVM_SYMBOLIC_H_ diff --git a/include/nnvm/top/README b/include/nnvm/top/README new file mode 100644 index 000000000000..09a4d6fc387f --- /dev/null +++ b/include/nnvm/top/README @@ -0,0 +1 @@ +NNVM Core Operator and Compiler diff --git a/include/nnvm/top/nn.h b/include/nnvm/top/nn.h new file mode 100644 index 000000000000..143a9548f18a --- /dev/null +++ b/include/nnvm/top/nn.h @@ -0,0 +1,498 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file nnvm/top/nn.h + * \brief Auxiliary param for tensor primitive. + */ +#ifndef NNVM_TOP_NN_H_ +#define NNVM_TOP_NN_H_ + +#include +#include +#include +#include +#include +#include "tensor.h" + +namespace nnvm { +namespace top { + +struct DenseParam : public dmlc::Parameter { + int units; + bool use_bias; + + DMLC_DECLARE_PARAMETER(DenseParam) { + DMLC_DECLARE_FIELD(units).set_lower_bound(1) + .describe("Number of hidden units of the dense transformation."); + DMLC_DECLARE_FIELD(use_bias).set_default(true) + .describe("Whether to use bias parameter"); + } + // constants + static const constexpr int kData = 0; + static const constexpr int kWeight = 1; + static const constexpr int kBias = 2; +}; + +struct DropoutParam : public dmlc::Parameter { + float rate; + + DMLC_DECLARE_PARAMETER(DropoutParam) { + DMLC_DECLARE_FIELD(rate).set_default(0.5) + .set_range(0, 1) + .describe("Fraction of the input that gets dropped out during training time."); + } +}; + +struct BatchNormParam : public dmlc::Parameter { + int axis; + double epsilon; + double momentum; + bool center; + bool scale; + + DMLC_DECLARE_PARAMETER(BatchNormParam) { + DMLC_DECLARE_FIELD(axis).set_default(1) + .describe("Specify which shape axis the channel is specified."); + DMLC_DECLARE_FIELD(epsilon).set_default(1e-5) + .describe("Small float added to variance to avoid dividing by zero."); + DMLC_DECLARE_FIELD(center).set_default(true) + .describe("If True, add offset of `beta` to normalized tensor." + "If False, `beta` is ignored."); + DMLC_DECLARE_FIELD(scale).set_default(true) + .describe("If True, multiply by `gamma`. If False, `gamma` is not used." + "When the next layer is piecewise linear (also e.g. `nn.relu`)," + "this can be disabled since the scaling" + "will be done by the next layer."); + } + // constants + static const constexpr int kData = 0; + static const constexpr int kGamma = 1; + static const constexpr int kBeta = 2; + static const constexpr int kMovingMean = 3; + static const constexpr int kMovingVariance = 4; +}; + + +// Shared by softmax and log_softmax +struct SoftmaxParam : public dmlc::Parameter { + int axis; + + DMLC_DECLARE_PARAMETER(SoftmaxParam) { + DMLC_DECLARE_FIELD(axis).set_default(-1) + .describe("The axis to sum over when computing softmax."); + } +}; + +struct LeakyReLUParam : public dmlc::Parameter { + double alpha; + + DMLC_DECLARE_PARAMETER(LeakyReLUParam) { + DMLC_DECLARE_FIELD(alpha).set_lower_bound(0.0).set_default(0.25) + .describe("slope coefficient for the negative half axis."); + } +}; + +struct PReLUParam : public dmlc::Parameter { + int axis; + DMLC_DECLARE_PARAMETER(PReLUParam) { + DMLC_DECLARE_FIELD(axis).set_default(1) + .describe("Specify which shape axis the channel is specified."); + } +}; + +struct PadParam : public dmlc::Parameter { + float pad_value; + Tuple > pad_width; + + DMLC_DECLARE_PARAMETER(PadParam) { + DMLC_DECLARE_FIELD(pad_value).set_default(0.0) + .describe("The value to be padded."); + DMLC_DECLARE_FIELD(pad_width) + .describe("Number of values padded to the edges of each axis, " + "in the format of ((before_1, after_1), ... (before_N, after_N))"); + } +}; + + +struct Conv2DParam : public dmlc::Parameter { + int channels; + TShape kernel_size; + TShape strides; + TShape padding; + TShape dilation; + int groups; + std::string layout; + std::string kernel_layout; + std::string out_layout; + int out_dtype; + bool use_bias; + + DMLC_DECLARE_PARAMETER(Conv2DParam) { + DMLC_DECLARE_FIELD(channels) + .describe("The dimensionality of the output space" + "i.e. the number of output channels in the convolution."); + DMLC_DECLARE_FIELD(kernel_size) + .describe("Specifies the dimensions of the convolution window."); + DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) + .describe("Specifies the strides of the convolution."); + DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) + .describe("If padding is non-zero, then the input is implicitly zero-padded" + "on both sides for padding number of points"); + DMLC_DECLARE_FIELD(dilation).set_default(TShape({1, 1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + DMLC_DECLARE_FIELD(groups).set_default(1) + .describe("Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + DMLC_DECLARE_FIELD(layout).set_default("NCHW") + .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + DMLC_DECLARE_FIELD(out_layout).set_default("__undef__") + .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); + DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW") + .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + DMLC_DECLARE_DTYPE_FIELD(out_dtype) + .add_enum("same", -1) + .set_default(-1) + .describe("Output data type, set to explicit type under mixed precision setting"); + + DMLC_DECLARE_FIELD(use_bias).set_default(true) + .describe("Whether the layer uses a bias vector."); + } + // constants + static const constexpr int kData = 0; + static const constexpr int kWeight = 1; + static const constexpr int kBias = 2; +}; + +struct WinogradWeightTransformParam : public dmlc::Parameter { + int tile_size; + + DMLC_DECLARE_PARAMETER(WinogradWeightTransformParam) { + DMLC_DECLARE_FIELD(tile_size) + .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); + } + + static const constexpr int kWeight = 0; +}; + +struct WinogradConv2DParam : public dmlc::Parameter { + int channels; + TShape kernel_size; + TShape strides; + TShape padding; + TShape dilation; + int groups; + std::string layout; + std::string kernel_layout; + std::string out_layout; + int out_dtype; + bool use_bias; + int tile_size; + + DMLC_DECLARE_PARAMETER(WinogradConv2DParam) { + DMLC_DECLARE_FIELD(channels) + .describe("The dimensionality of the output space" + "i.e. the number of output channels in the convolution."); + DMLC_DECLARE_FIELD(kernel_size) + .describe("Specifies the dimensions of the convolution window."); + DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) + .describe("Specifies the strides of the convolution."); + DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) + .describe("If padding is non-zero, then the input is implicitly zero-padded" + "on both sides for padding number of points"); + DMLC_DECLARE_FIELD(dilation).set_default(TShape({1, 1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + DMLC_DECLARE_FIELD(groups).set_default(1) + .describe("Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + DMLC_DECLARE_FIELD(layout).set_default("NCHW") + .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + DMLC_DECLARE_FIELD(out_layout).set_default("__undef__") + .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); + DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW") + .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + DMLC_DECLARE_DTYPE_FIELD(out_dtype) + .add_enum("same", -1) + .set_default(-1) + .describe("Output data type, set to explicit type under mixed precision setting"); + DMLC_DECLARE_FIELD(use_bias).set_default(true) + .describe("Whether the layer uses a bias vector."); + DMLC_DECLARE_FIELD(tile_size) + .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); + } + // constants + static const constexpr int kData = 0; + static const constexpr int kWeight = 1; + static const constexpr int kBias = 2; +}; + +struct Conv2DTransposeParam : public dmlc::Parameter { + int channels; + TShape kernel_size; + TShape strides; + TShape padding; + TShape output_padding; + TShape dilation; + int groups; + std::string layout; + std::string kernel_layout; + int out_dtype; + bool use_bias; + + DMLC_DECLARE_PARAMETER(Conv2DTransposeParam) { + DMLC_DECLARE_FIELD(channels) + .describe("The dimensionality of the output space" + "i.e. the number of output channels in the convolution."); + DMLC_DECLARE_FIELD(kernel_size) + .describe("Specifies the dimensions of the convolution window."); + DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) + .describe("Specifies the strides of the convolution."); + DMLC_DECLARE_FIELD(output_padding).set_default(TShape({0, 0})) + .describe("Zero-padding added to one side of the output."); + DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) + .describe("If padding is non-zero, then the input is implicitly zero-padded" + "on both sides for padding number of points"); + DMLC_DECLARE_FIELD(dilation).set_default(TShape({1, 1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + DMLC_DECLARE_FIELD(groups).set_default(1) + .describe("Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + DMLC_DECLARE_FIELD(layout).set_default("NCHW") + .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW") + .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + DMLC_DECLARE_DTYPE_FIELD(out_dtype) + .add_enum("same", -1) + .set_default(-1) + .describe("Output data type, set to explicit type under mixed precision setting"); + DMLC_DECLARE_FIELD(use_bias).set_default(true) + .describe("Whether the layer uses a bias vector."); + } + // constants + static const constexpr int kData = 0; + static const constexpr int kWeight = 1; + static const constexpr int kBias = 2; +}; + + +struct MaxPool2DParam : public dmlc::Parameter { + TShape pool_size; + TShape strides; + TShape padding; + std::string layout; + bool ceil_mode; + + DMLC_DECLARE_PARAMETER(MaxPool2DParam) { + DMLC_DECLARE_FIELD(pool_size) + .describe("Size of the pooling windows.."); + DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) + .describe("Specifies the strides of the convolution."); + DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) + .describe("If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + DMLC_DECLARE_FIELD(layout).set_default("NCHW") + .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + DMLC_DECLARE_FIELD(ceil_mode).set_default(false) + .describe("When true, will use ceil instead of floor to compute the output shape."); + } +}; + + +struct AvgPool2DParam : public dmlc::Parameter { + TShape pool_size; + TShape strides; + TShape padding; + std::string layout; + bool ceil_mode; + bool count_include_pad; + + DMLC_DECLARE_PARAMETER(AvgPool2DParam) { + DMLC_DECLARE_FIELD(pool_size) + .describe("Size of the pooling windows.."); + DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) + .describe("Specifies the strides of the convolution."); + DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) + .describe("If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + DMLC_DECLARE_FIELD(layout).set_default("NCHW") + .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + DMLC_DECLARE_FIELD(ceil_mode).set_default(false) + .describe("When true, will use ceil instead of floor to compute the output shape."); + DMLC_DECLARE_FIELD(count_include_pad).set_default(false) + .describe("When true, will include padding to compute the average"); + } +}; + + +struct GlobalPool2DParam : public dmlc::Parameter { + std::string layout; + + DMLC_DECLARE_PARAMETER(GlobalPool2DParam) { + DMLC_DECLARE_FIELD(layout).set_default("NCHW") + .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + } +}; + +struct UpSamplingParam : public dmlc::Parameter { + int scale; + std::string layout; + std::string method; + + DMLC_DECLARE_PARAMETER(UpSamplingParam) { + DMLC_DECLARE_FIELD(scale) + .describe("upsampling scaling factor"); + DMLC_DECLARE_FIELD(layout) + .set_default("NCHW") + .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Upsampling is applied on the 'H' and" + "'W' dimensions."); + DMLC_DECLARE_FIELD(method) + .set_default("NEAREST_NEIGHBOR") + .describe("Specify the mode to use for scaling." + "NEAREST_NEIGHBOR - Nearest Neighbor" + "BILINEAR - Bilinear Interpolation"); + } +}; + +struct LayoutTransformParam : public dmlc::Parameter { + std::string src_layout; + std::string dst_layout; + + DMLC_DECLARE_PARAMETER(LayoutTransformParam) { + DMLC_DECLARE_FIELD(src_layout).set_default("__undef__") + .describe("Dimension ordering of data"); + DMLC_DECLARE_FIELD(dst_layout).set_default("__undef__") + .describe("Dimension ordering of data."); + } +}; + +struct MultiBoxPriorParam : public dmlc::Parameter { + Tuple sizes; + Tuple ratios; + Tuple steps; + Tuple offsets; + bool clip; + + DMLC_DECLARE_PARAMETER(MultiBoxPriorParam) { + DMLC_DECLARE_FIELD(sizes).set_default(Tuple({1.0})) + .describe("List of sizes of generated MultiBoxPriores."); + DMLC_DECLARE_FIELD(ratios).set_default(Tuple({1.0})) + .describe("List of aspect ratios of generated MultiBoxPriores."); + DMLC_DECLARE_FIELD(steps).set_default(Tuple({-1.0, -1.0})) + .describe("Priorbox step across y and x, -1 for auto calculation."); + DMLC_DECLARE_FIELD(offsets).set_default(Tuple({0.5, 0.5})) + .describe("Priorbox center offsets, y and x respectively."); + DMLC_DECLARE_FIELD(clip).set_default(false) + .describe("Whether to clip out-of-boundary boxes."); + } +}; + +struct MultiBoxTransformLocParam : public dmlc::Parameter { + bool clip; + float threshold; + Tuple variances; + DMLC_DECLARE_PARAMETER(MultiBoxTransformLocParam) { + DMLC_DECLARE_FIELD(clip).set_default(true) + .describe("Clip out-of-boundary boxes."); + DMLC_DECLARE_FIELD(threshold).set_default(0.01) + .describe("Threshold to be a positive prediction."); + DMLC_DECLARE_FIELD(variances).set_default(Tuple({0.1f, 0.1f, 0.2f, 0.2f})) + .describe("Variances to be decoded from box regression output."); + } +}; + +struct NMSParam : public dmlc::Parameter { + float nms_threshold; + bool force_suppress; + int nms_topk; + DMLC_DECLARE_PARAMETER(NMSParam) { + DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5) + .describe("Non-maximum suppression threshold."); + DMLC_DECLARE_FIELD(force_suppress).set_default(false) + .describe("Suppress all detections regardless of class_id."); + DMLC_DECLARE_FIELD(nms_topk).set_default(-1) + .describe("Keep maximum top k detections before nms, -1 for no limit."); + } +}; + +struct LRNParam : public dmlc::Parameter { + int size; + int axis; + float alpha; + float beta; + float bias; + + DMLC_DECLARE_PARAMETER(LRNParam) { + DMLC_DECLARE_FIELD(size) + .describe("The size of the local region to be considered for normalization."); + DMLC_DECLARE_FIELD(axis) + .describe("input data layout channel axis"); + DMLC_DECLARE_FIELD(alpha) + .describe("The scaling parameter."); + DMLC_DECLARE_FIELD(beta) + .describe("The exponent parameter."); + DMLC_DECLARE_FIELD(bias) + .describe("The offset parameter."); + } + // constants + static const constexpr int kData = 0; +}; + +struct L2NormalizeParam : public dmlc::Parameter { + float eps; + Tuple axis; + + DMLC_DECLARE_PARAMETER(L2NormalizeParam) { + DMLC_DECLARE_FIELD(eps) + .describe("float type epsilon value."); + DMLC_DECLARE_FIELD(axis) + .describe("axis over the normalization applied"); + } +}; + +} // namespace top +} // namespace nnvm + +#endif // NNVM_TOP_NN_H_ diff --git a/include/nnvm/top/tensor.h b/include/nnvm/top/tensor.h new file mode 100644 index 000000000000..53ed5b3b0a22 --- /dev/null +++ b/include/nnvm/top/tensor.h @@ -0,0 +1,301 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file nnvm/top/tensor.h + * \brief Auxiliary param for tensor primitive. + */ +#ifndef NNVM_TOP_TENSOR_H_ +#define NNVM_TOP_TENSOR_H_ + +#include +#include +#include + +namespace nnvm { +namespace top { + +struct ConcatenateParam : public dmlc::Parameter { + int axis; + DMLC_DECLARE_PARAMETER(ConcatenateParam) { + DMLC_DECLARE_FIELD(axis).set_default(1) + .describe("the axis to be concated."); + } +}; + +struct ExpandDimsParam : public dmlc::Parameter { + int axis; + int num_newaxis; + DMLC_DECLARE_PARAMETER(ExpandDimsParam) { + DMLC_DECLARE_FIELD(axis) + .describe("the axis to be expanded."); + DMLC_DECLARE_FIELD(num_newaxis).set_lower_bound(1).set_default(1) + .describe("Number of new axis to be inserted."); + } +}; + +struct SplitParam : public dmlc::Parameter { + // numpy convention, only support indices, not support list. + Tuple indices_or_sections; + int axis; + // additional hint whether it is equal_split mode + // deduced from indices_or_sections + bool equal_split; + + DMLC_DECLARE_PARAMETER(SplitParam) { + DMLC_DECLARE_FIELD(indices_or_sections) + .describe("Number of outputs to be splitted"); + DMLC_DECLARE_FIELD(axis).set_lower_bound(0).set_default(1) + .describe("the axis to be splitted."); + } +}; + + +struct TakeParam : public dmlc::Parameter { + dmlc::optional axis; + + DMLC_DECLARE_PARAMETER(TakeParam) { + DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional()) + .describe("the axis over which to select values."); + } +}; + +struct StridedSliceParam : public dmlc::Parameter { + // numpy convention, only support indices, not support list. + Tuple begin; + Tuple end; + Tuple stride; + + DMLC_DECLARE_PARAMETER(StridedSliceParam) { + DMLC_DECLARE_FIELD(begin) + .describe("Indices for begin of slice"); + DMLC_DECLARE_FIELD(end) + .describe("Indices for end of the slice"); + DMLC_DECLARE_FIELD(stride).set_default(Tuple()) + .describe("Stride values of the slice"); + } +}; + +enum TypeFlag { + kFloat32 = 0, + kFloat64 = 1, + kFloat16 = 2, + kUint8 = 3, + kInt32 = 4, + kInt8 = 5, + kInt64 = 6, + kInt16 = 7, + kUint16 = 8, + kUint32 = 9, + kUint64 = 10, +}; + +enum IndicatorRuleFlag { + kGT0 = 0, + kLT0 = 1, + kMax = 2, + kMin = 3, +}; + +#define DMLC_DECLARE_DTYPE_FIELD(name) \ + DMLC_DECLARE_FIELD(name) \ + .add_enum("float16", kFloat16) \ + .add_enum("float32", kFloat32) \ + .add_enum("float64", kFloat64) \ + .add_enum("uint8", kUint8) \ + .add_enum("uint16", kUint16) \ + .add_enum("uint32", kUint32) \ + .add_enum("uint64", kUint64) \ + .add_enum("int8", kInt8) \ + .add_enum("int16", kInt16) \ + .add_enum("int32", kInt32) \ + .add_enum("int64", kInt64) + +struct CastParam : public dmlc::Parameter { + int dtype; + DMLC_DECLARE_PARAMETER(CastParam) { + DMLC_DECLARE_DTYPE_FIELD(dtype) + .describe("Output data type."); + } +}; + +struct IndicatorParam : public dmlc::Parameter { + TShape axis; + bool exclude; + DMLC_DECLARE_PARAMETER(IndicatorParam) { + DMLC_DECLARE_FIELD(axis).set_default(TShape()) + .describe(R"code(The axis or axes along which to perform the indicator rule. + + The default, `axis=()`, will compute over all elements into a + scalar array with shape `(1,)`. + + If `axis` is int, rule is applied on a particular axis. + + If `axis` is a tuple of ints, rule is applied on all the axes + specified in the tuple. + + If `exclude` is true, rule will be applied on the axes that are + NOT in axis instead.)code"); + DMLC_DECLARE_FIELD(exclude).set_default(false) + .describe("Whether to apply rule on axis that are NOT in axis instead."); + } +}; + +struct ReshapeParam : public dmlc::Parameter { + Tuple shape; + + DMLC_DECLARE_PARAMETER(ReshapeParam) { + DMLC_DECLARE_FIELD(shape); + } +}; + +struct SqueezeParam : public dmlc::Parameter { + TShape axis; + + DMLC_DECLARE_PARAMETER(SqueezeParam) { + DMLC_DECLARE_FIELD(axis).set_default(TShape()) + .describe("The axis to squeeze in the input tensor."); + } +}; + +struct ScalarParam : public dmlc::Parameter { + double scalar; + + DMLC_DECLARE_PARAMETER(ScalarParam) { + DMLC_DECLARE_FIELD(scalar); + } +}; + +struct FillValueParam : public dmlc::Parameter { + double fill_value; + + DMLC_DECLARE_PARAMETER(FillValueParam) { + DMLC_DECLARE_FIELD(fill_value) + .describe("Scalar value to be filled"); + } +}; + +struct TransposeParam : public dmlc::Parameter { + TShape axes; + + DMLC_DECLARE_PARAMETER(TransposeParam) { + DMLC_DECLARE_FIELD(axes).set_default(TShape()) + .describe("Target axis order. By default the axes will be inverted."); + } +}; + +struct FlipParam : public dmlc::Parameter { + int axis; + DMLC_DECLARE_PARAMETER(FlipParam) { + DMLC_DECLARE_FIELD(axis).set_default(0) + .describe("the axis to be reveresed."); + } +}; + +struct BroadcastToParam : public dmlc::Parameter { + TShape shape; + + DMLC_DECLARE_PARAMETER(BroadcastToParam) { + DMLC_DECLARE_FIELD(shape).set_default(TShape()) + .describe("The shape of the desired array." + " We can set the dim to zero if it's same as the original." + " E.g `A = broadcast_to(B, shape=(10, 0, 0))` "); + } +}; + +struct ReduceParam : public dmlc::Parameter { + TShape axis; + bool keepdims; + bool exclude; + + DMLC_DECLARE_PARAMETER(ReduceParam) { + DMLC_DECLARE_FIELD(axis).set_default(TShape()) + .describe(R"code(The axis or axes along which to perform the reduction. + + The default, `axis=()`, will compute over all elements into a + scalar array with shape `(1,)`. + + If `axis` is int, a reduction is performed on a particular axis. + + If `axis` is a tuple of ints, a reduction is performed on all the axes + specified in the tuple. + + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead.)code"); + + DMLC_DECLARE_FIELD(keepdims).set_default(false) + .describe("If this is set to `True`, the reduced axes are left " + "in the result as dimension with size one."); + DMLC_DECLARE_FIELD(exclude).set_default(false) + .describe("Whether to perform reduction on axis that are NOT in axis instead."); + } +}; + +struct InitOpWithScalarParam : public dmlc::Parameter { + TShape shape; + int dtype; + double fill_value; + + DMLC_DECLARE_PARAMETER(InitOpWithScalarParam) { + DMLC_DECLARE_FIELD(shape).set_default(TShape()); + DMLC_DECLARE_DTYPE_FIELD(dtype).set_default(kFloat32) + .describe("Target data type."); + DMLC_DECLARE_FIELD(fill_value).describe("Scalar value to fill"); + } +}; + +struct InitOpParam : public dmlc::Parameter { + TShape shape; + int dtype; + + DMLC_DECLARE_PARAMETER(InitOpParam) { + DMLC_DECLARE_FIELD(shape).set_default(TShape()); + DMLC_DECLARE_DTYPE_FIELD(dtype).set_default(kFloat32) + .describe("Target data type."); + } +}; + +struct ElementWiseReduceParam : public dmlc::Parameter { + int num_args; + DMLC_DECLARE_PARAMETER(ElementWiseReduceParam) { + DMLC_DECLARE_FIELD(num_args).set_lower_bound(1) + .describe("Number of inputs to be reduced."); + } +}; + +struct MatMulParam : public dmlc::Parameter { + bool transpose_a; + bool transpose_b; + + DMLC_DECLARE_PARAMETER(MatMulParam) { + DMLC_DECLARE_FIELD(transpose_a) + .describe("If true then transpose the first input before dot.") + .set_default(false); + DMLC_DECLARE_FIELD(transpose_b) + .describe("If true then transpose the second input before dot.") + .set_default(false); + } +}; + +struct ClipParam : public dmlc::Parameter { + double a_min, a_max; + DMLC_DECLARE_PARAMETER(ClipParam) { + DMLC_DECLARE_FIELD(a_min) + .describe("Minimum value such that value smaller then this will be clipped."); + DMLC_DECLARE_FIELD(a_max) + .describe("Maximum value such that value larger then this will be clipped."); + } +}; + +struct SliceLikeParam : public dmlc::Parameter { + Tuple axis; + DMLC_DECLARE_PARAMETER(SliceLikeParam) { + DMLC_DECLARE_FIELD(axis).set_default(Tuple()) + .describe("List of axes on which input data will be sliced according to the " + "corresponding size of the second input. By default will slice " + "on all axes. Negative axes are supported."); + } +}; + +} // namespace top +} // namespace nnvm + +#endif // NNVM_TOP_TENSOR_H_ diff --git a/include/nnvm/tuple.h b/include/nnvm/tuple.h new file mode 100644 index 000000000000..36b8ef13c74a --- /dev/null +++ b/include/nnvm/tuple.h @@ -0,0 +1,633 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file nnvm/tuple.h + * \brief Data structure Tuple and TShape to store dynamic sized shapes. + */ +#ifndef NNVM_TUPLE_H_ +#define NNVM_TUPLE_H_ + +#include +#include +#include +#include +#include +#include +#include "base.h" + +namespace nnvm { + +/*! \brief data type to store dim size */ +typedef int64_t dim_t; + +/*! + * \brief A dynamic sized array data structure that is optimized for storing + * small number of elements with same type. + * + * Data will be stored in stack when number of elements is small. + * It is suitable to hold shape of Tensor. + * + * \tparam ValueType The type of data stored inside tuple. + * \sa TShape + */ +template +class Tuple { + public: + /*! \brief default constructor */ + Tuple() = default; + /*! \brief destructor */ + inline ~Tuple() { + delete [] data_heap_; + } + /*! + * \brief copy constructor from another tuple + * \param s the source tuple + */ + inline Tuple(const Tuple& s) { + this->assign(s.begin(), s.end()); + } + /*! + * \brief constructor from initializer list + * \param init the initializer_list + */ + inline Tuple(std::initializer_list init) { + this->assign(init.begin(), init.end()); + } + /*! + * \brief constructor from vector + * \param init the vector + */ + inline Tuple(std::vector init) { // NOLINT(runtime/explicit) + this->assign(init.begin(), init.end()); + } + /*! + * \brief move constructor from Tuple + * \param src the source shape + */ + + inline Tuple(Tuple&& src) { // NOLINT(runtime/explicit) + this->swap(src); + } + /*! + * \brief construct the Tuple from content of iterator + * \param begin the beginning of iterator + * \param end end the end of the iterator + * \tparam RandomAccessIterator iterator type + */ + template + inline Tuple(RandomAccessIterator begin, + RandomAccessIterator end) { + this->assign(begin, end); + } + /*! + * \brief Assign content to tuple from iterator. + * \param begin the beginning of iterator + * \param end end the end of the iterator + * \tparam RandomAccessIterator iterator type + */ + template + inline void assign(RandomAccessIterator begin, + RandomAccessIterator end) { + this->SetDim(end - begin); + std::copy(begin, end, this->begin()); + } + /*! + * \brief Swap current object with other + * \param other another object to be swapped. + */ + inline void swap(Tuple& other) { // NOLINT(*) + std::swap(ndim_, other.ndim_); + std::swap(num_heap_allocated_, other.num_heap_allocated_); + std::swap(data_stack_, other.data_stack_); + std::swap(data_heap_, other.data_heap_); + } + /*! + * \brief assignment from another tuple. + * \param src source tuple + * \return reference of self + */ + inline Tuple& operator=(const Tuple& src) { + this->assign(src.begin(), src.end()); + return *this; + } + /*! + * \brief assignment from rvalue of another tuple. + * \param src source tuple + * \return reference of self + */ + inline Tuple& operator=(Tuple&& src) { + Tuple(std::move(src)).swap(*this); + return *this; + } + /*! + * \brief assignment from initializer list + * \param init the source initializer list + * \return reference of self + */ + inline Tuple &operator=(std::initializer_list init) { + this->assign(init.begin(), init.end()); + return *this; + } + /*! + * \return whether two tuple equals + * \param s the tuple to compare against + */ + inline bool operator==(const Tuple &s) const { + if (ndim_ != s.ndim_) return false; + return std::equal(begin(), end(), s.begin()); + } + /*! + * \return whether two tuple not equal + * \param s the tuple to compare against + */ + inline bool operator!=(const Tuple &s) const { + return !(*this == s); + } + /*! \return the begin data pointer to content of the tuple */ + inline const ValueType *begin() const { + return ndim_ <= kStackCache ? data_stack_ : data_heap_; + } + /*! \return the begin data pointer to content of the tuple */ + inline ValueType *begin() { + return ndim_ <= kStackCache ? data_stack_ : data_heap_; + } + /*! \return the data pointer to end of the tuple */ + inline const ValueType* end() const { + return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_); + } + /*! \return the data pointer to end the tuple */ + inline ValueType* end() { + return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_); + } + /*! \return number of dimension of the tuple */ + inline uint32_t ndim() const { + return ndim_; + } + /*! + * \brief get corresponding index + * \param i dimension index + * \return the corresponding dimension size + */ + inline ValueType& operator[](size_t i) { + return begin()[i]; + } + /*! + * \brief get corresponding index + * \param i dimension index + * \return the corresponding dimension size + */ + inline const ValueType& operator[](size_t i) const { + return begin()[i]; + } + /*! + * \brief Save Tuple to JSON. + * \param writer JSONWriter + */ + inline void Save(dmlc::JSONWriter* writer) const { + std::vector tmp(begin(), end()); + writer->Write(tmp); + } + /*! + * \brief Load Tuple from JSON. + * \param reader JSONReader + */ + inline void Load(dmlc::JSONReader* reader) { + std::vector tmp; + reader->Read(&tmp); + this->assign(tmp.begin(), tmp.end()); + } + /*! + * \brief allow output string of tuple to ostream + * \param os the output stream + * \param t the tuple + * \return the ostream + */ + friend std::ostream &operator<<(std::ostream &os, const Tuple &t) { + os << '['; + const ValueType* begin = t.begin(); + const ValueType* end = t.end(); + for (const ValueType* it = begin; it != end; ++it) { + if (it != begin) os << ','; + os << *it; + } + os << ']'; + return os; + } + /*! + * \brief read tuple from the istream + * \param is the input stream + * \param t The tuple + * \return the istream + */ + friend std::istream &operator>>(std::istream &is, Tuple &t) { + // get ( + while (true) { + char ch = is.peek(); + if (isdigit(ch) || ch == '-') { + ValueType idx; + if (is >> idx) { + t.assign(&idx, &idx + 1); + } + return is; + } + is.get(); + if (ch == '(' || ch == '[') break; + if (!isspace(ch)) { + is.setstate(std::ios::failbit); + return is; + } + } + // Handle empty tuple + while (isspace(is.peek())) { + is.get(); + } + if (is.peek() == ')' || is.peek() == ']') { + is.get(); + return is; + } + // Handle non-empty tuple + ValueType idx; + std::vector tmp; + while (is >> idx) { + tmp.push_back(idx); + char ch; + do { + ch = is.get(); + } while (isspace(ch)); + if (std::is_integral::value && ch == 'L') { + ch = is.get(); + } + if (ch == ',') { + while (true) { + ch = is.peek(); + if (isspace(ch)) { + is.get(); continue; + } + if (ch == ')' || ch == ']') { + is.get(); break; + } + break; + } + if (ch == ')' || ch == ']') break; + } else if (ch == ')' || ch == ']') { + break; + } else { + is.setstate(std::ios::failbit); + return is; + } + } + t.assign(tmp.begin(), tmp.end()); + return is; + } + /*! + * \brief save the content into binary stream + * \param strm the output stream + * \tparam DType data type that save to + * \tparam TStream any stream type that have write + */ + template + inline void Save(TStream *strm) const; + /*! + * \brief load the content from binary stream + * \param strm the output stream + * \tparam DType data type that load from + * \tparam TStream any stream type that have write + * \return whether the load is successful + */ + template + inline bool Load(TStream *strm); + + protected: + // stack cache size + static const uint32_t kStackCache = 4; + /*! \brief number of dimension of the tuple */ + uint32_t ndim_{0}; + /*! \brief number of cells allocated in data_heap_ */ + uint32_t num_heap_allocated_{0}; + /*! \brief in stack space used to store shape when it is small */ + ValueType data_stack_[kStackCache]; + /*! \brief space to store shape when dimension is big*/ + ValueType* data_heap_{nullptr}; + // internal function to change the dimension + inline void SetDim(uint32_t ndim) { + if (ndim > kStackCache && + ndim > num_heap_allocated_) { + delete [] data_heap_; + data_heap_ = new ValueType[ndim]; + num_heap_allocated_ = ndim; + } + ndim_ = ndim; + } +}; + +/*! + * \brief A Shape class that is used to represent shape of each tensor. + */ +class TShape : public Tuple { + public: + /*! \brief default constructor */ + TShape() = default; + /*! + * constructor to construct a shape with all 1. + * \param ndim the number of dimension + */ + inline TShape(uint32_t ndim) { // NOLINT(*) + this->SetDim(ndim); + std::fill_n(begin(), ndim, 1); + } + /*! + * \brief copy constructor of TShape + * \param s source shape. + */ + inline TShape(const Tuple& s) { // NOLINT(*) + this->assign(s.begin(), s.end()); + } + /*! + * \brief constructor from initializer list + * \param init the initializer_list + */ + inline TShape(std::initializer_list init) { + this->assign(init.begin(), init.end()); + } + /*! + * \brief move constructor. + * \param s source shape. + */ + inline TShape(Tuple&& s) { // NOLINT(*) + this->swap(s); + } + /*! + * \brief construct the Tuple from content of iterator + * \param begin the beginning of iterator + * \param end end the end of the iterator + * \tparam RandomAccessIterator iterator type + */ + template + inline TShape(RandomAccessIterator begin, + RandomAccessIterator end) { + this->assign(begin, end); + } + /*! + * \brief assignment function from tshape + * \param src source shape. + * \return self. + */ + inline TShape& operator=(const Tuple& src) { + this->assign(src.begin(), src.end()); + return *this; + } + /*! + * \brief move assignment function from tshape + * \param src source shape. + * \return self. + */ + inline TShape& operator=(Tuple&& src) { // NOLINT(*) + TShape(std::move(src)).swap(*this); // NOLINT(*) + return *this; + } + /*! \return total number of elements in the shape */ + inline size_t Size() const { + dim_t size = 1; + const dim_t* start = begin(), *fin = end(); + for (const dim_t* it = start; it != fin; ++it) { + size *= *it; + } + return size; + } + /*! + * \return product shape in [dimstart,dimend) + * \param dimstart start dimension + * \param dimend end dimension + */ + inline size_t ProdShape(int dimstart, int dimend) const { + dim_t num = 1; + const dim_t *d = this->data(); + for (int i = dimstart; i < dimend; ++i) { + num *= d[i]; + } + return num; + } + /*! \return the begin data pointer to content of the tuple */ + inline const dim_t *data() const { + return begin(); + } + /*! \return the begin data pointer to content of the tuple */ + inline dim_t *data() { + return begin(); + } +#ifdef MSHADOW_XINLINE + template + inline TShape(const mshadow::Shape &s) {// NOLINT(*) + this->assign(s.shape_, s.shape_ + dim); + } + + template + inline TShape(mshadow::Shape &&s) {// NOLINT(*) + this->assign(s.shape_, s.shape_ + dim); + } + /*! + * \brief assignment from shape + * \param shape source shape + * \tparam dim shape dimension + * \return reference of self + */ + template + inline TShape &operator=(const mshadow::Shape &shape) { + this->assign(shape.shape_, shape.shape_ + dim); + return *this; + } + /*! + * \brief get the shape of tensor specifying dim + * \return the shape requested + * \tparam dim dimension of the tensor + */ + template + inline mshadow::Shape get() const { + CHECK_EQ(dim, static_cast(ndim())) + << "dimension do not match target dimension " << dim << " vs " << ndim(); + const dim_t *d = this->data(); + mshadow::Shape s; + for (int i = 0; i < dim; ++i) { + s[i] = d[i]; + } + return s; + } + /*! + * flatten the higher dimension to second dimension, return a 2D shape + * \return the flat 2d shape + */ + inline mshadow::Shape<2> FlatTo2D(void) const { + mshadow::Shape<2> s; + if (ndim() == 0) return mshadow::Shape2(0, 0); + const dim_t *d = this->data(); + s.shape_[1] = d[ndim() - 1]; + dim_t ymax = 1; + for (size_t i = 1; i < ndim(); ++i) { + ymax *= d[i - 1]; + } + s.shape_[0] = ymax; + return s; + } + /*! + * flatten the shape into three parts: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim) + * \param axis_begin The beginning axis specified. + * \param axis_end The ending axis specified. + * \return the flat 3d shape + */ + inline mshadow::Shape<3> FlatTo3D(size_t axis_begin, size_t axis_end) const { + CHECK(axis_end >= axis_begin); + mshadow::Shape<3> s; + if (ndim() == 0) return mshadow::Shape3(0, 0, 0); + const dim_t *d = this->data(); + s.shape_[0] = 1; + s.shape_[1] = 1; + s.shape_[2] = 1; + + for (size_t i = 0; i < axis_begin; ++i) { + s.shape_[0] *= d[i]; + } + for (size_t i = axis_begin; i <= axis_end; ++i) { + s.shape_[1] *= d[i]; + } + for (size_t i = axis_end + 1; i < ndim(); ++i) { + s.shape_[2] *= d[i]; + } + return s; + } + /*! + * flatten the axis before and after the specified axis, so it becomes 3D tensor + * \param axis The axis specified. + * \return the flat 3d shape + */ + inline mshadow::Shape<3> FlatTo3D(size_t axis) const { + return FlatTo3D(axis, axis); + } + inline bool operator==(const TShape &s) const { + if (ndim() != s.ndim()) return false; + return std::equal(begin(), end(), s.begin()); + } + inline bool operator!=(const TShape &s) const { + return !(*this == s); + } + /*! + * \return whether two shape equals + * \param s the shape to compare against + * \tparam dim dimension of the shape + */ + template + inline bool operator==(const mshadow::Shape &s) const { + if (ndim_ != dim) return false; + const dim_t *d = dim <= kStackCache ? data_stack_ : data_heap_; + for (size_t i = 0; i < dim; ++i) { + if (d[i] != s.shape_[i]) return false; + } + return true; + } + /*! + * \return whether two shape not equals + * \param s the shape to compare against + * \tparam dim dimension of the shape + */ + template + inline bool operator!=(const mshadow::Shape &s) const { + return !(*this == s); + } +#endif +}; + +/*! \brief helper function to cast type of container elements */ +template +inline DstIter ShapeTypeCast(const SrcIter begin, + const SrcIter end, + DstIter dst_begin) { + typedef typename std::iterator_traits::value_type SrcDType; + typedef typename std::iterator_traits::value_type DstDType; + auto cast = [](const SrcDType& dim) { return static_cast(dim); }; + return std::transform(begin, end, dst_begin, cast); +} + +/*! \brief helper function to transform a container to TShape with type cast */ +template +inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) { + size_t ndim = std::distance(begin, end); + TShape res(ndim); + ShapeTypeCast(begin, end, res.begin()); + return res; +} + +/*! \tparam ValueType The type of data stored inside tuple. */ +template +template +inline void Tuple::Save(TStream *strm) const { + strm->Write(&ndim_, sizeof(ndim_)); + if (typeid(DType) == typeid(ValueType)) { + strm->Write(begin(), sizeof(ValueType) * ndim_); + } else { + std::vector buffer(ndim_); + ShapeTypeCast(begin(), end(), buffer.data()); + strm->Write(buffer.data(), sizeof(DType) * ndim_); + } +} + +/*! \tparam ValueType The type of data stored inside tuple. */ +template +template +inline bool Tuple::Load(TStream *strm) { + if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false; + this->SetDim(ndim_); + size_t nread = sizeof(DType) * ndim_; + if (typeid(DType) == typeid(ValueType)) { + if (strm->Read(begin(), nread) != nread) return false; + } else { + std::vector buffer(ndim_); + if (strm->Read(buffer.data(), nread) != nread) return false; + ShapeTypeCast(buffer.begin(), buffer.end(), begin()); + } + return true; +} + +} // namespace nnvm + +namespace std { +/*! \brief hash function for Tuple. */ +template +struct hash > { + /*! \brief hash a Tuple into unsigned int */ + size_t operator()(const nnvm::Tuple& val) const { + std::hash hash_uint; + size_t res = hash_uint(val.ndim()); + for (uint32_t i = 0; i < val.ndim(); ++i) { + res = dmlc::HashCombine(res, val[i]); + } + return res; + } +}; + +/*! \brief hash function for TShape. */ +template<> +struct hash { + /*! \brief hash a TShape into unsigned int */ + size_t operator()(const nnvm::TShape& val) const { + std::hash hash_uint; + size_t res = hash_uint(val.ndim()); + for (uint32_t i = 0; i < val.ndim(); ++i) { + res = dmlc::HashCombine(res, val[i]); + } + return res; + } +}; +} // namespace std + +namespace dmlc { +/*! \brief description for optional TShape */ +DMLC_DECLARE_TYPE_NAME(optional, "Shape or None"); +// avoid low version of MSVC +#if !defined(_MSC_VER) +template +struct type_name_helper > { + static inline std::string value() { + return "tuple of <" + type_name() + ">"; + } +}; +#endif +} // namespace dmlc +#endif // NNVM_TUPLE_H_ From df6d33f3fc34973c3ea0c14a00e976280a4665e5 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Wed, 31 Oct 2018 19:46:11 +0000 Subject: [PATCH 02/12] Add symbolic link and cherry picked required header --- include/dlpack/dlpack.h | 142 +- include/dmlc | 1 + include/dmlc/any.h | 371 -- include/dmlc/array_view.h | 128 - include/dmlc/base.h | 291 -- include/dmlc/blockingconcurrentqueue.h | 991 ----- include/dmlc/common.h | 85 - include/dmlc/concurrency.h | 258 -- include/dmlc/concurrentqueue.h | 3719 ----------------- include/dmlc/config.h | 186 - include/dmlc/data.h | 397 -- include/dmlc/endian.h | 44 - include/dmlc/input_split_shuffle.h | 168 - include/dmlc/io.h | 522 --- include/dmlc/json.h | 981 ----- include/dmlc/logging.h | 424 -- include/dmlc/lua.h | 739 ---- include/dmlc/memory.h | 261 -- include/dmlc/memory_io.h | 105 - include/dmlc/omp.h | 47 - include/dmlc/optional.h | 261 -- include/dmlc/parameter.h | 1065 ----- include/dmlc/recordio.h | 196 - include/dmlc/registry.h | 306 -- include/dmlc/serializer.h | 410 -- include/dmlc/thread_group.h | 808 ---- include/dmlc/thread_local.h | 83 - include/dmlc/threadediter.h | 475 --- include/dmlc/timer.h | 49 - include/dmlc/type_traits.h | 191 - include/mshadow | 1 + include/mshadow/README.md | 8 - include/mshadow/base.h | 1106 ----- include/mshadow/cuda/reduce.cuh | 120 - include/mshadow/cuda/tensor_gpu-inl.cuh | 828 ---- include/mshadow/dot_engine-inl.h | 906 ---- include/mshadow/expr_engine-inl.h | 482 --- include/mshadow/expr_scalar-inl.h | 165 - include/mshadow/expression.h | 416 -- include/mshadow/extension.h | 41 - include/mshadow/extension/broadcast.h | 165 - .../mshadow/extension/broadcast_with_axis.h | 258 -- include/mshadow/extension/channel_pool.h | 108 - include/mshadow/extension/channel_unpool.h | 137 - include/mshadow/extension/choose.h | 90 - include/mshadow/extension/complex.h | 525 --- include/mshadow/extension/concat.h | 194 - include/mshadow/extension/crop.h | 119 - include/mshadow/extension/fill.h | 103 - include/mshadow/extension/flip.h | 132 - include/mshadow/extension/implicit_gemm.h | 128 - include/mshadow/extension/mask.h | 97 - include/mshadow/extension/mirror.h | 62 - include/mshadow/extension/one_hot.h | 87 - include/mshadow/extension/pack_col2patch.h | 154 - include/mshadow/extension/pad.h | 111 - include/mshadow/extension/range.h | 118 - include/mshadow/extension/reduce_with_axis.h | 136 - include/mshadow/extension/reduceto1d.h | 104 - include/mshadow/extension/reshape.h | 87 - include/mshadow/extension/slice.h | 156 - include/mshadow/extension/slice_ex.h | 135 - include/mshadow/extension/spatial_pool.h | 152 - include/mshadow/extension/spatial_unpool.h | 135 - .../extension/spatial_upsampling_nearest.h | 71 - include/mshadow/extension/swapaxis.h | 110 - include/mshadow/extension/take.h | 99 - include/mshadow/extension/take_grad.h | 111 - include/mshadow/extension/transpose.h | 200 - include/mshadow/extension/unpack_patch2col.h | 151 - include/mshadow/half.h | 288 -- include/mshadow/half2.h | 143 - include/mshadow/io.h | 137 - include/mshadow/logging.h | 234 -- include/mshadow/packet-inl.h | 413 -- include/mshadow/packet/plain-inl.h | 76 - include/mshadow/packet/sse-inl.h | 147 - include/mshadow/random.h | 570 --- include/mshadow/stream_gpu-inl.h | 212 - include/mshadow/tensor.h | 1078 ----- include/mshadow/tensor_container.h | 208 - include/mshadow/tensor_cpu-inl.h | 627 --- include/mshadow/tensor_gpu-inl.h | 245 -- include/nnvm | 1 + include/nnvm/base.h | 35 - include/nnvm/c_api.h | 388 -- include/nnvm/compiler/op_attr_types.h | 101 - include/nnvm/compiler/packed_func_ext.h | 59 - include/nnvm/compiler/util.h | 33 - include/nnvm/graph.h | 315 -- include/nnvm/graph_attr_types.h | 112 - include/nnvm/layout.h | 455 -- include/nnvm/node.h | 201 - include/nnvm/op.h | 562 --- include/nnvm/op_attr_types.h | 219 - include/nnvm/pass.h | 128 - include/nnvm/pass_functions.h | 190 - include/nnvm/symbolic.h | 217 - include/nnvm/top/README | 1 - include/nnvm/top/nn.h | 498 --- include/nnvm/top/tensor.h | 301 -- include/nnvm/tuple.h | 633 --- 102 files changed, 4 insertions(+), 30835 deletions(-) mode change 100644 => 120000 include/dlpack/dlpack.h create mode 120000 include/dmlc delete mode 100644 include/dmlc/any.h delete mode 100644 include/dmlc/array_view.h delete mode 100644 include/dmlc/base.h delete mode 100644 include/dmlc/blockingconcurrentqueue.h delete mode 100644 include/dmlc/common.h delete mode 100644 include/dmlc/concurrency.h delete mode 100644 include/dmlc/concurrentqueue.h delete mode 100644 include/dmlc/config.h delete mode 100644 include/dmlc/data.h delete mode 100644 include/dmlc/endian.h delete mode 100644 include/dmlc/input_split_shuffle.h delete mode 100644 include/dmlc/io.h delete mode 100644 include/dmlc/json.h delete mode 100644 include/dmlc/logging.h delete mode 100644 include/dmlc/lua.h delete mode 100644 include/dmlc/memory.h delete mode 100644 include/dmlc/memory_io.h delete mode 100644 include/dmlc/omp.h delete mode 100644 include/dmlc/optional.h delete mode 100644 include/dmlc/parameter.h delete mode 100644 include/dmlc/recordio.h delete mode 100644 include/dmlc/registry.h delete mode 100644 include/dmlc/serializer.h delete mode 100644 include/dmlc/thread_group.h delete mode 100644 include/dmlc/thread_local.h delete mode 100644 include/dmlc/threadediter.h delete mode 100644 include/dmlc/timer.h delete mode 100644 include/dmlc/type_traits.h create mode 120000 include/mshadow delete mode 100644 include/mshadow/README.md delete mode 100755 include/mshadow/base.h delete mode 100644 include/mshadow/cuda/reduce.cuh delete mode 100755 include/mshadow/cuda/tensor_gpu-inl.cuh delete mode 100644 include/mshadow/dot_engine-inl.h delete mode 100644 include/mshadow/expr_engine-inl.h delete mode 100644 include/mshadow/expr_scalar-inl.h delete mode 100644 include/mshadow/expression.h delete mode 100644 include/mshadow/extension.h delete mode 100644 include/mshadow/extension/broadcast.h delete mode 100644 include/mshadow/extension/broadcast_with_axis.h delete mode 100644 include/mshadow/extension/channel_pool.h delete mode 100644 include/mshadow/extension/channel_unpool.h delete mode 100644 include/mshadow/extension/choose.h delete mode 100644 include/mshadow/extension/complex.h delete mode 100644 include/mshadow/extension/concat.h delete mode 100644 include/mshadow/extension/crop.h delete mode 100644 include/mshadow/extension/fill.h delete mode 100644 include/mshadow/extension/flip.h delete mode 100644 include/mshadow/extension/implicit_gemm.h delete mode 100644 include/mshadow/extension/mask.h delete mode 100644 include/mshadow/extension/mirror.h delete mode 100644 include/mshadow/extension/one_hot.h delete mode 100644 include/mshadow/extension/pack_col2patch.h delete mode 100644 include/mshadow/extension/pad.h delete mode 100644 include/mshadow/extension/range.h delete mode 100644 include/mshadow/extension/reduce_with_axis.h delete mode 100644 include/mshadow/extension/reduceto1d.h delete mode 100644 include/mshadow/extension/reshape.h delete mode 100644 include/mshadow/extension/slice.h delete mode 100644 include/mshadow/extension/slice_ex.h delete mode 100644 include/mshadow/extension/spatial_pool.h delete mode 100644 include/mshadow/extension/spatial_unpool.h delete mode 100644 include/mshadow/extension/spatial_upsampling_nearest.h delete mode 100644 include/mshadow/extension/swapaxis.h delete mode 100644 include/mshadow/extension/take.h delete mode 100644 include/mshadow/extension/take_grad.h delete mode 100644 include/mshadow/extension/transpose.h delete mode 100644 include/mshadow/extension/unpack_patch2col.h delete mode 100644 include/mshadow/half.h delete mode 100755 include/mshadow/half2.h delete mode 100644 include/mshadow/io.h delete mode 100644 include/mshadow/logging.h delete mode 100644 include/mshadow/packet-inl.h delete mode 100644 include/mshadow/packet/plain-inl.h delete mode 100644 include/mshadow/packet/sse-inl.h delete mode 100644 include/mshadow/random.h delete mode 100644 include/mshadow/stream_gpu-inl.h delete mode 100755 include/mshadow/tensor.h delete mode 100644 include/mshadow/tensor_container.h delete mode 100755 include/mshadow/tensor_cpu-inl.h delete mode 100755 include/mshadow/tensor_gpu-inl.h create mode 120000 include/nnvm delete mode 100644 include/nnvm/base.h delete mode 100644 include/nnvm/c_api.h delete mode 100644 include/nnvm/compiler/op_attr_types.h delete mode 100644 include/nnvm/compiler/packed_func_ext.h delete mode 100644 include/nnvm/compiler/util.h delete mode 100644 include/nnvm/graph.h delete mode 100644 include/nnvm/graph_attr_types.h delete mode 100644 include/nnvm/layout.h delete mode 100644 include/nnvm/node.h delete mode 100644 include/nnvm/op.h delete mode 100644 include/nnvm/op_attr_types.h delete mode 100644 include/nnvm/pass.h delete mode 100644 include/nnvm/pass_functions.h delete mode 100644 include/nnvm/symbolic.h delete mode 100644 include/nnvm/top/README delete mode 100644 include/nnvm/top/nn.h delete mode 100644 include/nnvm/top/tensor.h delete mode 100644 include/nnvm/tuple.h diff --git a/include/dlpack/dlpack.h b/include/dlpack/dlpack.h deleted file mode 100644 index f8dc8fcd2cdf..000000000000 --- a/include/dlpack/dlpack.h +++ /dev/null @@ -1,141 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file dlpack.h - * \brief The common header of DLPack. - */ -#ifndef DLPACK_DLPACK_H_ -#define DLPACK_DLPACK_H_ - -#ifdef __cplusplus -#define DLPACK_EXTERN_C extern "C" -#else -#define DLPACK_EXTERN_C -#endif - -/*! \brief The current version of dlpack */ -#define DLPACK_VERSION 010 - -/*! \brief DLPACK_DLL prefix for windows */ -#ifdef _WIN32 -#ifdef DLPACK_EXPORTS -#define DLPACK_DLL __declspec(dllexport) -#else -#define DLPACK_DLL __declspec(dllimport) -#endif -#else -#define DLPACK_DLL -#endif - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif -/*! - * \brief The device type in DLContext. - */ -typedef enum { - kDLCPU = 1, - kDLGPU = 2, - // kDLCPUPinned = kDLCPU | kDLGPU - kDLCPUPinned = 3, - kDLOpenCL = 4, - kDLMetal = 8, - kDLVPI = 9, - kDLROCM = 10, -} DLDeviceType; - -/*! - * \brief A Device context for Tensor and operator. - */ -typedef struct { - /*! \brief The device type used in the device. */ - DLDeviceType device_type; - /*! \brief The device index */ - int device_id; -} DLContext; - -/*! - * \brief The type code options DLDataType. - */ -typedef enum { - kDLInt = 0U, - kDLUInt = 1U, - kDLFloat = 2U, -} DLDataTypeCode; - -/*! - * \brief The data type the tensor can hold. - * - * Examples - * - float: type_code = 2, bits = 32, lanes=1 - * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 - * - int8: type_code = 0, bits = 8, lanes=1 - */ -typedef struct { - /*! - * \brief Type code of base types. - * We keep it uint8_t instead of DLDataTypeCode for minimal memory - * footprint, but the value should be one of DLDataTypeCode enum values. - * */ - uint8_t code; - /*! - * \brief Number of bits, common choices are 8, 16, 32. - */ - uint8_t bits; - /*! \brief Number of lanes in the type, used for vector types. */ - uint16_t lanes; -} DLDataType; - -/*! - * \brief Plain C Tensor object, does not manage memory. - */ -typedef struct { - /*! - * \brief The opaque data pointer points to the allocated data. - * This will be CUDA device pointer or cl_mem handle in OpenCL. - * This pointer is always aligns to 256 bytes as in CUDA. - */ - void* data; - /*! \brief The device context of the tensor */ - DLContext ctx; - /*! \brief Number of dimensions */ - int ndim; - /*! \brief The data type of the pointer*/ - DLDataType dtype; - /*! \brief The shape of the tensor */ - int64_t* shape; - /*! - * \brief strides of the tensor, - * can be NULL, indicating tensor is compact. - */ - int64_t* strides; - /*! \brief The offset in bytes to the beginning pointer to data */ - uint64_t byte_offset; -} DLTensor; - -/*! - * \brief C Tensor object, manage memory of DLTensor. This data structure is - * intended to faciliate the borrowing of DLTensor by another framework. It is - * not meant to transfer the tensor. When the borrowing framework doesn't need - * the tensor, it should call the deleter to notify the host that the resource - * is no longer needed. - */ -typedef struct DLManagedTensor { - /*! \brief DLTensor which is being memory managed */ - DLTensor dl_tensor; - /*! \brief the context of the original host framework of DLManagedTensor in - * which DLManagedTensor is used in the framework. It can also be NULL. - */ - void * manager_ctx; - /*! \brief Destructor signature void (*)(void*) - this should be called - * to destruct manager_ctx which holds the DLManagedTensor. It can be NULL - * if there is no way for the caller to provide a reasonable destructor. - */ - void (*deleter)(struct DLManagedTensor * self); -} DLManagedTensor; -#ifdef __cplusplus -} // DLPACK_EXTERN_C -#endif -#endif // DLPACK_DLPACK_H_ diff --git a/include/dlpack/dlpack.h b/include/dlpack/dlpack.h new file mode 120000 index 000000000000..119855e7cd94 --- /dev/null +++ b/include/dlpack/dlpack.h @@ -0,0 +1 @@ +../../3rdparty/dlpack/include/dlpack/dlpack.h \ No newline at end of file diff --git a/include/dmlc b/include/dmlc new file mode 120000 index 000000000000..869c40b0e502 --- /dev/null +++ b/include/dmlc @@ -0,0 +1 @@ +../3rdparty/dmlc-core/include/dmlc \ No newline at end of file diff --git a/include/dmlc/any.h b/include/dmlc/any.h deleted file mode 100644 index 8041bf7ee53a..000000000000 --- a/include/dmlc/any.h +++ /dev/null @@ -1,371 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file any.h - * \brief Container to hold any data type. - */ -#ifndef DMLC_ANY_H_ -#define DMLC_ANY_H_ - -// This code need c++11 to compile -#include -#include -#include -#include - -#include "./base.h" -#include "./logging.h" - -namespace dmlc { -// forward declare any; -class any; - -/*! - * Get a reference to content stored in the any as type T. - * This will cause an error if - * T does not match the type stored. - * This function is not part of std::any standard. - * - * \param src The source source any container. - * \return The reference of content - * \tparam T The type of the value to be fetched. - */ -template -inline T& get(any& src); // NOLINT(*) - -/*! - * Get the const reference content stored in the any as type T. - * This will cause an error if - * T does not match the type stored. - * This function is not part of std::any standard. - * - * \param src The source source any container. - * \return The reference of content - * \tparam T The type of the value to be fetched. - */ -template -inline const T& get(const any& src); - -/*! - * \brief An any class that is compatible to std::any in c++17. - * - * \code - * dmlc::any a = std::string("mydear"), b = 1; - * // get reference out and add it - * dmlc::get(b) += 1; - * // a is now string - * LOG(INFO) << dmlc::get(a); - * // a is now 2, the string stored will be properly destructed - * a = std::move(b); - * LOG(INFO) << dmlc::get(a); - * \endcode - * \sa get - */ -class any { - public: - /*! \brief default constructor */ - inline any() = default; - /*! - * \brief move constructor from another any - * \param other The other any to be moved - */ - inline any(any&& other); // NOLINT(*) - /*! - * \brief copy constructor - * \param other The other any to be copied - */ - inline any(const any& other); // NOLINT(*) - /*! - * \brief constructor from any types - * \param other The other types to be constructed into any. - * \tparam T The value type of other. - */ - template - inline any(T&& other); // NOLINT(*) - /*! \brief destructor */ - inline ~any(); - /*! - * \brief assign operator from other - * \param other The other any to be copy or moved. - * \return self - */ - inline any& operator=(any&& other); - /*! - * \brief assign operator from other - * \param other The other any to be copy or moved. - * \return self - */ - inline any& operator=(const any& other); - /*! - * \brief assign operator from any type. - * \param other The other any to be copy or moved. - * \tparam T The value type of other. - * \return self - */ - template - inline any& operator=(T&& other); - /*! - * \return whether the container is empty. - */ - inline bool empty() const; - /*! - * \brief clear the content of container - */ - inline void clear(); - /*! - * swap current content with other - * \param other The other data to be swapped. - */ - inline void swap(any& other); // NOLINT(*) - /*! - * \return The type_info about the stored type. - */ - inline const std::type_info& type() const; - /*! \brief Construct value of type T inplace */ - template - inline void construct(Args&&... args); - - private: - //! \cond Doxygen_Suppress - // declare of helper class - template - class TypeOnHeap; - template - class TypeOnStack; - template - class TypeInfo; - // size of stack space, it takes 32 bytes for one any type. - static const size_t kStack = sizeof(void*) * 3; - static const size_t kAlign = sizeof(void*); - // container use dynamic storage only when space runs lager - union Data { - // stack space - std::aligned_storage::type stack; - // pointer to heap space - void* pheap; - }; - // type specific information - struct Type { - // destructor function - void (*destroy)(Data* data); - // copy constructor - void (*create_from_data)(Data* dst, const Data& src); - // the type info function - const std::type_info* ptype_info; - }; - // constant to check if data can be stored on heap. - template - struct data_on_stack { - static const bool value = alignof(T) <= kAlign && sizeof(T) <= kStack; - }; - // declare friend with - template - friend T& get(any& src); // NOLINT(*) - template - friend const T& get(const any& src); - // internal construct function - inline void construct(any&& other); - // internal construct function - inline void construct(const any& other); - // internal function to check if type is correct. - template - inline void check_type() const; - // internal type specific information - const Type* type_{nullptr}; - // internal data - Data data_; -}; - -template -inline any::any(T&& other) { - typedef typename std::decay::type DT; - if (std::is_same::value) { - this->construct(std::forward(other)); - } else { - static_assert(std::is_copy_constructible
::value, - "Any can only hold value that is copy constructable"); - type_ = TypeInfo
::get_type(); - if (data_on_stack
::value) { -#pragma GCC diagnostic push -#if 6 <= __GNUC__ -#pragma GCC diagnostic ignored "-Wplacement-new" -#endif - new (&(data_.stack)) DT(std::forward(other)); -#pragma GCC diagnostic pop - } else { - data_.pheap = new DT(std::forward(other)); - } - } -} - -inline any::any(any&& other) { - this->construct(std::move(other)); -} - -inline any::any(const any& other) { - this->construct(other); -} - -inline void any::construct(any&& other) { - type_ = other.type_; - data_ = other.data_; - other.type_ = nullptr; -} - -inline void any::construct(const any& other) { - type_ = other.type_; - if (type_ != nullptr) { - type_->create_from_data(&data_, other.data_); - } -} - -template -inline void any::construct(Args&&... args) { - clear(); - typedef typename std::decay::type DT; - type_ = TypeInfo
::get_type(); - if (data_on_stack
::value) { -#pragma GCC diagnostic push -#if 6 <= __GNUC__ -#pragma GCC diagnostic ignored "-Wplacement-new" -#endif - new (&(data_.stack)) DT(std::forward(args)...); -#pragma GCC diagnostic pop - } else { - data_.pheap = new DT(std::forward(args)...); - } -} - -inline any::~any() { - this->clear(); -} - -inline any& any::operator=(any&& other) { - any(std::move(other)).swap(*this); - return *this; -} - -inline any& any::operator=(const any& other) { - any(other).swap(*this); - return *this; -} - -template -inline any& any::operator=(T&& other) { - any(std::forward(other)).swap(*this); - return *this; -} - -inline void any::swap(any& other) { // NOLINT(*) - std::swap(type_, other.type_); - std::swap(data_, other.data_); -} - -inline void any::clear() { - if (type_ != nullptr) { - if (type_->destroy != nullptr) { - type_->destroy(&data_); - } - type_ = nullptr; - } -} - -inline bool any::empty() const { - return type_ == nullptr; -} - -inline const std::type_info& any::type() const { - if (type_ != nullptr) { - return *(type_->ptype_info); - } else { - return typeid(void); - } -} - -template -inline void any::check_type() const { - CHECK(type_ != nullptr) - << "The any container is empty" - << " requested=" << typeid(T).name(); - CHECK(*(type_->ptype_info) == typeid(T)) - << "The stored type mismatch" - << " stored=" << type_->ptype_info->name() - << " requested=" << typeid(T).name(); -} - -template -inline const T& get(const any& src) { - src.check_type(); - return *any::TypeInfo::get_ptr(&(src.data_)); -} - -template -inline T& get(any& src) { // NOLINT(*) - src.check_type(); - return *any::TypeInfo::get_ptr(&(src.data_)); -} - -template -class any::TypeOnHeap { - public: - inline static T* get_ptr(any::Data* data) { - return static_cast(data->pheap); - } - inline static const T* get_ptr(const any::Data* data) { - return static_cast(data->pheap); - } - inline static void create_from_data(any::Data* dst, const any::Data& data) { - dst->pheap = new T(*get_ptr(&data)); - } - inline static void destroy(Data* data) { - delete static_cast(data->pheap); - } -}; - -template -class any::TypeOnStack { - public: - inline static T* get_ptr(any::Data* data) { - return reinterpret_cast(&(data->stack)); - } - inline static const T* get_ptr(const any::Data* data) { - return reinterpret_cast(&(data->stack)); - } - inline static void create_from_data(any::Data* dst, const any::Data& data) { - new (&(dst->stack)) T(*get_ptr(&data)); - } - inline static void destroy(Data* data) { - T* dptr = reinterpret_cast(&(data->stack)); - dptr->~T(); - } -}; - -template -class any::TypeInfo - : public std::conditional::value, - any::TypeOnStack, - any::TypeOnHeap >::type { - public: - inline static const Type* get_type() { - static TypeInfo tp; - return &(tp.type_); - } - - private: - // local type - Type type_; - // constructor - TypeInfo() { - if (std::is_pod::value && data_on_stack::value) { - type_.destroy = nullptr; - } else { - type_.destroy = TypeInfo::destroy; - } - type_.create_from_data = TypeInfo::create_from_data; - type_.ptype_info = &typeid(T); - } -}; -//! \endcond - -} // namespace dmlc - -#endif // DMLC_ANY_H_ diff --git a/include/dmlc/array_view.h b/include/dmlc/array_view.h deleted file mode 100644 index 5e01a78cc53d..000000000000 --- a/include/dmlc/array_view.h +++ /dev/null @@ -1,128 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file array_view.h - * \brief Read only data structure to reference array - */ -#ifndef DMLC_ARRAY_VIEW_H_ -#define DMLC_ARRAY_VIEW_H_ - -#include -#include - -namespace dmlc { - -/*! - * \brief Read only data structure to reference continuous memory region of array. - * Provide unified view for vector, array and C style array. - * This data structure do not guarantee aliveness of referenced array. - * - * Make sure do not use array_view to record data in async function closures. - * Also do not use array_view to create reference to temporary data structure. - * - * \tparam ValueType The value - * - * \code - * std::vector myvec{1,2,3}; - * dmlc::array_view view(myvec); - * // indexed visit to the view. - * LOG(INFO) << view[0]; - * - * for (int v : view) { - * // visit each element in the view - * } - * \endcode - */ -template -class array_view { - public: - /*! \brief default constructor */ - array_view() = default; - /*! - * \brief default copy constructor - * \param other another array view. - */ - array_view(const array_view &other) = default; // NOLINT(*) -#ifndef _MSC_VER - /*! - * \brief default move constructor - * \param other another array view. - */ - array_view(array_view&& other) = default; // NOLINT(*) -#else - /*! - * \brief default move constructor - * \param other another array view. - */ - array_view(array_view&& other) { // NOLINT(*) - begin_ = other.begin_; - size_ = other.size_; - other.begin_ = nullptr; - } -#endif - /*! - * \brief default assign constructor - * \param other another array view. - * \return self. - */ - array_view& operator=(const array_view& other) = default; // NOLINT(*) - /*! - * \brief construct array view std::vector - * \param other vector container - */ - array_view(const std::vector& other) { // NOLINT(*) - if (other.size() != 0) { - begin_ = &other[0]; size_ = other.size(); - } - } - /*! - * \brief construct array std::array - * \param other another array view. - */ - template - array_view(const std::array& other) { // NOLINT(*) - if (size != 0) { - begin_ = &other[0]; size_ = size; - } - } - /*! - * \brief construct array view from continuous segment - * \param begin beginning pointre - * \param end end pointer - */ - array_view(const ValueType* begin, const ValueType* end) { - if (begin < end) { - begin_ = begin; - size_ = end - begin; - } - } - /*! \return size of the array */ - inline size_t size() const { - return size_; - } - /*! \return begin of the array */ - inline const ValueType* begin() const { - return begin_; - } - /*! \return end point of the array */ - inline const ValueType* end() const { - return begin_ + size_; - } - /*! - * \brief get i-th element from the view - * \param i The index. - * \return const reference to i-th element. - */ - inline const ValueType& operator[](size_t i) const { - return begin_[i]; - } - - private: - /*! \brief the begin of the view */ - const ValueType* begin_{nullptr}; - /*! \brief The size of the view */ - size_t size_{0}; -}; - -} // namespace dmlc - -#endif // DMLC_ARRAY_VIEW_H_ diff --git a/include/dmlc/base.h b/include/dmlc/base.h deleted file mode 100644 index 1caf487e9365..000000000000 --- a/include/dmlc/base.h +++ /dev/null @@ -1,291 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file base.h - * \brief defines configuration macros - */ -#ifndef DMLC_BASE_H_ -#define DMLC_BASE_H_ - -/*! \brief whether use glog for logging */ -#ifndef DMLC_USE_GLOG -#define DMLC_USE_GLOG 0 -#endif - -/*! - * \brief whether throw dmlc::Error instead of - * directly calling abort when FATAL error occured - * NOTE: this may still not be perfect. - * do not use FATAL and CHECK in destructors - */ -#ifndef DMLC_LOG_FATAL_THROW -#define DMLC_LOG_FATAL_THROW 1 -#endif - -/*! - * \brief whether always log a message before throw - * This can help identify the error that cannot be catched. - */ -#ifndef DMLC_LOG_BEFORE_THROW -#define DMLC_LOG_BEFORE_THROW 0 -#endif - -/*! - * \brief Whether to use customized logger, - * whose output can be decided by other libraries. - */ -#ifndef DMLC_LOG_CUSTOMIZE -#define DMLC_LOG_CUSTOMIZE 0 -#endif - -/*! - * \brief Whether to print stack trace for fatal error, - * enabled on linux when using gcc. - */ -#if (defined(__GNUC__) && !defined(__MINGW32__)\ - && !defined(__sun) && !defined(__SVR4)\ - && !(defined __MINGW64__) && !(defined __ANDROID__)) -#if (!defined(DMLC_LOG_STACK_TRACE)) -#define DMLC_LOG_STACK_TRACE 1 -#endif -#if (!defined(DMLC_LOG_STACK_TRACE_SIZE)) -#define DMLC_LOG_STACK_TRACE_SIZE 10 -#endif -#endif - -/*! \brief whether compile with hdfs support */ -#ifndef DMLC_USE_HDFS -#define DMLC_USE_HDFS 0 -#endif - -/*! \brief whether compile with s3 support */ -#ifndef DMLC_USE_S3 -#define DMLC_USE_S3 0 -#endif - -/*! \brief whether or not use parameter server */ -#ifndef DMLC_USE_PS -#define DMLC_USE_PS 0 -#endif - -/*! \brief whether or not use c++11 support */ -#ifndef DMLC_USE_CXX11 -#if defined(__GXX_EXPERIMENTAL_CXX0X__) || defined(_MSC_VER) -#define DMLC_USE_CXX11 1 -#else -#define DMLC_USE_CXX11 (__cplusplus >= 201103L) -#endif -#endif - -/*! \brief strict CXX11 support */ -#ifndef DMLC_STRICT_CXX11 -#if defined(_MSC_VER) -#define DMLC_STRICT_CXX11 1 -#else -#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L) -#endif -#endif - -/*! \brief Whether cxx11 thread local is supported */ -#ifndef DMLC_CXX11_THREAD_LOCAL -#if defined(_MSC_VER) -#define DMLC_CXX11_THREAD_LOCAL (_MSC_VER >= 1900) -#elif defined(__clang__) -#define DMLC_CXX11_THREAD_LOCAL (__has_feature(cxx_thread_local)) -#else -#define DMLC_CXX11_THREAD_LOCAL (__cplusplus >= 201103L) -#endif -#endif - - -/*! \brief whether RTTI is enabled */ -#ifndef DMLC_ENABLE_RTTI -#define DMLC_ENABLE_RTTI 1 -#endif - -/*! \brief whether use fopen64 */ -#ifndef DMLC_USE_FOPEN64 -#define DMLC_USE_FOPEN64 1 -#endif - -/// check if g++ is before 4.6 -#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__) -#if __GNUC__ == 4 && __GNUC_MINOR__ < 6 -#pragma message("Will need g++-4.6 or higher to compile all" \ - "the features in dmlc-core, " \ - "compile without c++0x, some features may be disabled") -#undef DMLC_USE_CXX11 -#define DMLC_USE_CXX11 0 -#endif -#endif - -/*! - * \brief Use little endian for binary serialization - * if this is set to 0, use big endian. - */ -#ifndef DMLC_IO_USE_LITTLE_ENDIAN -#define DMLC_IO_USE_LITTLE_ENDIAN 1 -#endif - -/*! - * \brief Enable std::thread related modules, - * Used to disable some module in mingw compile. - */ -#ifndef DMLC_ENABLE_STD_THREAD -#define DMLC_ENABLE_STD_THREAD DMLC_USE_CXX11 -#endif - -/*! \brief whether enable regex support, actually need g++-4.9 or higher*/ -#ifndef DMLC_USE_REGEX -#define DMLC_USE_REGEX DMLC_STRICT_CXX11 -#endif - -/*! \brief helper macro to supress unused warning */ -#if defined(__GNUC__) -#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused)) -#else -#define DMLC_ATTRIBUTE_UNUSED -#endif - -/*! \brief helper macro to generate string concat */ -#define DMLC_STR_CONCAT_(__x, __y) __x##__y -#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y) - -/*! - * \brief Disable copy constructor and assignment operator. - * - * If C++11 is supported, both copy and move constructors and - * assignment operators are deleted explicitly. Otherwise, they are - * only declared but not implemented. Place this macro in private - * section if C++11 is not available. - */ -#ifndef DISALLOW_COPY_AND_ASSIGN -# if DMLC_USE_CXX11 -# define DISALLOW_COPY_AND_ASSIGN(T) \ - T(T const&) = delete; \ - T(T&&) = delete; \ - T& operator=(T const&) = delete; \ - T& operator=(T&&) = delete -# else -# define DISALLOW_COPY_AND_ASSIGN(T) \ - T(T const&); \ - T& operator=(T const&) -# endif -#endif - -#if DMLC_USE_FOPEN64 && \ - (!defined(__GNUC__) || (defined __ANDROID__) || ((defined __MINGW32__) && !(defined __MINGW64__))) -#define fopen64 std::fopen -#endif - -#ifdef __APPLE__ -# define off64_t off_t -# if DMLC_USE_FOPEN64 -# define fopen64 std::fopen -# endif -#endif - -#ifdef _MSC_VER -#if _MSC_VER < 1900 -// NOTE: sprintf_s is not equivalent to snprintf, -// they are equivalent when success, which is sufficient for our case -#define snprintf sprintf_s -#define vsnprintf vsprintf_s -#endif -#else -#ifdef _FILE_OFFSET_BITS -#if _FILE_OFFSET_BITS == 32 -#pragma message("Warning: FILE OFFSET BITS defined to be 32 bit") -#endif -#endif - - -extern "C" { -#include -} -#endif - -#ifdef _MSC_VER -//! \cond Doxygen_Suppress -typedef signed char int8_t; -typedef __int16 int16_t; -typedef __int32 int32_t; -typedef __int64 int64_t; -typedef unsigned char uint8_t; -typedef unsigned __int16 uint16_t; -typedef unsigned __int32 uint32_t; -typedef unsigned __int64 uint64_t; -//! \endcond -#else -#include -#endif -#include -#include - -#if defined(_MSC_VER) && _MSC_VER < 1900 -#define noexcept_true throw () -#define noexcept_false -#define noexcept(a) noexcept_##a -#endif - -#if DMLC_USE_CXX11 -#define DMLC_THROW_EXCEPTION noexcept(false) -#define DMLC_NO_EXCEPTION noexcept(true) -#else -#define DMLC_THROW_EXCEPTION -#define DMLC_NO_EXCEPTION -#endif - -/*! \brief namespace for dmlc */ -namespace dmlc { -/*! - * \brief safely get the beginning address of a vector - * \param vec input vector - * \return beginning address of a vector - */ -template -inline T *BeginPtr(std::vector &vec) { // NOLINT(*) - if (vec.size() == 0) { - return NULL; - } else { - return &vec[0]; - } -} -/*! - * \brief get the beginning address of a const vector - * \param vec input vector - * \return beginning address of a vector - */ -template -inline const T *BeginPtr(const std::vector &vec) { - if (vec.size() == 0) { - return NULL; - } else { - return &vec[0]; - } -} -/*! - * \brief get the beginning address of a string - * \param str input string - * \return beginning address of a string - */ -inline char* BeginPtr(std::string &str) { // NOLINT(*) - if (str.length() == 0) return NULL; - return &str[0]; -} -/*! - * \brief get the beginning address of a const string - * \param str input string - * \return beginning address of a string - */ -inline const char* BeginPtr(const std::string &str) { - if (str.length() == 0) return NULL; - return &str[0]; -} -} // namespace dmlc - -#if defined(_MSC_VER) && _MSC_VER < 1900 -#define constexpr const -#define alignof __alignof -#endif - -#endif // DMLC_BASE_H_ diff --git a/include/dmlc/blockingconcurrentqueue.h b/include/dmlc/blockingconcurrentqueue.h deleted file mode 100644 index 9d249430289b..000000000000 --- a/include/dmlc/blockingconcurrentqueue.h +++ /dev/null @@ -1,991 +0,0 @@ -//! \cond Doxygen_Suppress -// Provides an efficient blocking version of moodycamel::ConcurrentQueue. -// ©2015-2016 Cameron Desrochers. Distributed under the terms of the simplified -// BSD license, available at the top of concurrentqueue.h. -// Uses Jeff Preshing's semaphore implementation (under the terms of its -// separate zlib license, embedded below). - -#ifndef DMLC_BLOCKINGCONCURRENTQUEUE_H_ -#define DMLC_BLOCKINGCONCURRENTQUEUE_H_ - -#pragma once - -#include "concurrentqueue.h" -#include -#include -#include -#include -#include - -#if defined(_WIN32) -// Avoid including windows.h in a header; we only need a handful of -// items, so we'll redeclare them here (this is relatively safe since -// the API generally has to remain stable between Windows versions). -// I know this is an ugly hack but it still beats polluting the global -// namespace with thousands of generic names or adding a .cpp for nothing. -extern "C" { - struct _SECURITY_ATTRIBUTES; - __declspec(dllimport) void* __stdcall CreateSemaphoreW(_SECURITY_ATTRIBUTES* lpSemaphoreAttributes, long lInitialCount, long lMaximumCount, const wchar_t* lpName); - __declspec(dllimport) int __stdcall CloseHandle(void* hObject); - __declspec(dllimport) unsigned long __stdcall WaitForSingleObject(void* hHandle, unsigned long dwMilliseconds); - __declspec(dllimport) int __stdcall ReleaseSemaphore(void* hSemaphore, long lReleaseCount, long* lpPreviousCount); -} -#elif defined(__MACH__) -#include -#elif defined(__unix__) -#include -#endif - -namespace dmlc { - -namespace moodycamel -{ -namespace details -{ - // Code in the mpmc_sema namespace below is an adaptation of Jeff Preshing's - // portable + lightweight semaphore implementations, originally from - // /~https://github.com/preshing/cpp11-on-multicore/blob/master/common/sema.h - // LICENSE: - // Copyright (c) 2015 Jeff Preshing - // - // This software is provided 'as-is', without any express or implied - // warranty. In no event will the authors be held liable for any damages - // arising from the use of this software. - // - // Permission is granted to anyone to use this software for any purpose, - // including commercial applications, and to alter it and redistribute it - // freely, subject to the following restrictions: - // - // 1. The origin of this software must not be misrepresented; you must not - // claim that you wrote the original software. If you use this software - // in a product, an acknowledgement in the product documentation would be - // appreciated but is not required. - // 2. Altered source versions must be plainly marked as such, and must not be - // misrepresented as being the original software. - // 3. This notice may not be removed or altered from any source distribution. - namespace mpmc_sema - { -#if defined(_WIN32) - class Semaphore - { - private: - void* m_hSema; - - Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; - Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; - - public: - Semaphore(int initialCount = 0) - { - assert(initialCount >= 0); - const long maxLong = 0x7fffffff; - m_hSema = CreateSemaphoreW(nullptr, initialCount, maxLong, nullptr); - } - - ~Semaphore() - { - CloseHandle(m_hSema); - } - - void wait() - { - const unsigned long infinite = 0xffffffff; - WaitForSingleObject(m_hSema, infinite); - } - - bool try_wait() - { - const unsigned long RC_WAIT_TIMEOUT = 0x00000102; - return WaitForSingleObject(m_hSema, 0) != RC_WAIT_TIMEOUT; - } - - bool timed_wait(std::uint64_t usecs) - { - const unsigned long RC_WAIT_TIMEOUT = 0x00000102; - return WaitForSingleObject(m_hSema, (unsigned long)(usecs / 1000)) != RC_WAIT_TIMEOUT; - } - - void signal(int count = 1) - { - ReleaseSemaphore(m_hSema, count, nullptr); - } - }; -#elif defined(__MACH__) - //--------------------------------------------------------- - // Semaphore (Apple iOS and OSX) - // Can't use POSIX semaphores due to http://lists.apple.com/archives/darwin-kernel/2009/Apr/msg00010.html - //--------------------------------------------------------- - class Semaphore - { - private: - semaphore_t m_sema; - - Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; - Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; - - public: - Semaphore(int initialCount = 0) - { - assert(initialCount >= 0); - semaphore_create(mach_task_self(), &m_sema, SYNC_POLICY_FIFO, initialCount); - } - - ~Semaphore() - { - semaphore_destroy(mach_task_self(), m_sema); - } - - void wait() - { - semaphore_wait(m_sema); - } - - bool try_wait() - { - return timed_wait(0); - } - - bool timed_wait(std::uint64_t timeout_usecs) - { - mach_timespec_t ts; - ts.tv_sec = static_cast(timeout_usecs / 1000000); - ts.tv_nsec = (timeout_usecs % 1000000) * 1000; - - // added in OSX 10.10: https://developer.apple.com/library/prerelease/mac/documentation/General/Reference/APIDiffsMacOSX10_10SeedDiff/modules/Darwin.html - kern_return_t rc = semaphore_timedwait(m_sema, ts); - - return rc != KERN_OPERATION_TIMED_OUT; - } - - void signal() - { - semaphore_signal(m_sema); - } - - void signal(int count) - { - while (count-- > 0) - { - semaphore_signal(m_sema); - } - } - }; -#elif defined(__unix__) - //--------------------------------------------------------- - // Semaphore (POSIX, Linux) - //--------------------------------------------------------- - class Semaphore - { - private: - sem_t m_sema; - - Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; - Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; - - public: - Semaphore(int initialCount = 0) - { - assert(initialCount >= 0); - sem_init(&m_sema, 0, initialCount); - } - - ~Semaphore() - { - sem_destroy(&m_sema); - } - - void wait() - { - // http://stackoverflow.com/questions/2013181/gdb-causes-sem-wait-to-fail-with-eintr-error - int rc; - do { - rc = sem_wait(&m_sema); - } while (rc == -1 && errno == EINTR); - } - - bool try_wait() - { - int rc; - do { - rc = sem_trywait(&m_sema); - } while (rc == -1 && errno == EINTR); - return !(rc == -1 && errno == EAGAIN); - } - - bool timed_wait(std::uint64_t usecs) - { - struct timespec ts; - const int usecs_in_1_sec = 1000000; - const int nsecs_in_1_sec = 1000000000; - clock_gettime(CLOCK_REALTIME, &ts); - ts.tv_sec += usecs / usecs_in_1_sec; - ts.tv_nsec += (usecs % usecs_in_1_sec) * 1000; - // sem_timedwait bombs if you have more than 1e9 in tv_nsec - // so we have to clean things up before passing it in - if (ts.tv_nsec >= nsecs_in_1_sec) { - ts.tv_nsec -= nsecs_in_1_sec; - ++ts.tv_sec; - } - - int rc; - do { - rc = sem_timedwait(&m_sema, &ts); - } while (rc == -1 && errno == EINTR); - return !(rc == -1 && errno == ETIMEDOUT); - } - - void signal() - { - sem_post(&m_sema); - } - - void signal(int count) - { - while (count-- > 0) - { - sem_post(&m_sema); - } - } - }; -#else -#error Unsupported platform! (No semaphore wrapper available) -#endif - - //--------------------------------------------------------- - // LightweightSemaphore - //--------------------------------------------------------- - class LightweightSemaphore - { - public: - typedef std::make_signed::type ssize_t; - - private: - std::atomic m_count; - Semaphore m_sema; - - bool waitWithPartialSpinning(std::int64_t timeout_usecs = -1) - { - ssize_t oldCount; - // Is there a better way to set the initial spin count? - // If we lower it to 1000, testBenaphore becomes 15x slower on my Core i7-5930K Windows PC, - // as threads start hitting the kernel semaphore. - int spin = 10000; - while (--spin >= 0) - { - oldCount = m_count.load(std::memory_order_relaxed); - if ((oldCount > 0) && m_count.compare_exchange_strong(oldCount, oldCount - 1, std::memory_order_acquire, std::memory_order_relaxed)) - return true; - std::atomic_signal_fence(std::memory_order_acquire); // Prevent the compiler from collapsing the loop. - } - oldCount = m_count.fetch_sub(1, std::memory_order_acquire); - if (oldCount > 0) - return true; - if (timeout_usecs < 0) - { - m_sema.wait(); - return true; - } - if (m_sema.timed_wait((std::uint64_t)timeout_usecs)) - return true; - // At this point, we've timed out waiting for the semaphore, but the - // count is still decremented indicating we may still be waiting on - // it. So we have to re-adjust the count, but only if the semaphore - // wasn't signaled enough times for us too since then. If it was, we - // need to release the semaphore too. - while (true) - { - oldCount = m_count.load(std::memory_order_acquire); - if (oldCount >= 0 && m_sema.try_wait()) - return true; - if (oldCount < 0 && m_count.compare_exchange_strong(oldCount, oldCount + 1, std::memory_order_relaxed, std::memory_order_relaxed)) - return false; - } - } - - ssize_t waitManyWithPartialSpinning(ssize_t max, std::int64_t timeout_usecs = -1) - { - assert(max > 0); - ssize_t oldCount; - int spin = 10000; - while (--spin >= 0) - { - oldCount = m_count.load(std::memory_order_relaxed); - if (oldCount > 0) - { - ssize_t newCount = oldCount > max ? oldCount - max : 0; - if (m_count.compare_exchange_strong(oldCount, newCount, std::memory_order_acquire, std::memory_order_relaxed)) - return oldCount - newCount; - } - std::atomic_signal_fence(std::memory_order_acquire); - } - oldCount = m_count.fetch_sub(1, std::memory_order_acquire); - if (oldCount <= 0) - { - if (timeout_usecs < 0) - m_sema.wait(); - else if (!m_sema.timed_wait((std::uint64_t)timeout_usecs)) - { - while (true) - { - oldCount = m_count.load(std::memory_order_acquire); - if (oldCount >= 0 && m_sema.try_wait()) - break; - if (oldCount < 0 && m_count.compare_exchange_strong(oldCount, oldCount + 1, std::memory_order_relaxed, std::memory_order_relaxed)) - return 0; - } - } - } - if (max > 1) - return 1 + tryWaitMany(max - 1); - return 1; - } - - public: - LightweightSemaphore(ssize_t initialCount = 0) : m_count(initialCount) - { - assert(initialCount >= 0); - } - - bool tryWait() - { - ssize_t oldCount = m_count.load(std::memory_order_relaxed); - while (oldCount > 0) - { - if (m_count.compare_exchange_weak(oldCount, oldCount - 1, std::memory_order_acquire, std::memory_order_relaxed)) - return true; - } - return false; - } - - void wait() - { - if (!tryWait()) - waitWithPartialSpinning(); - } - - bool wait(std::int64_t timeout_usecs) - { - return tryWait() || waitWithPartialSpinning(timeout_usecs); - } - - // Acquires between 0 and (greedily) max, inclusive - ssize_t tryWaitMany(ssize_t max) - { - assert(max >= 0); - ssize_t oldCount = m_count.load(std::memory_order_relaxed); - while (oldCount > 0) - { - ssize_t newCount = oldCount > max ? oldCount - max : 0; - if (m_count.compare_exchange_weak(oldCount, newCount, std::memory_order_acquire, std::memory_order_relaxed)) - return oldCount - newCount; - } - return 0; - } - - // Acquires at least one, and (greedily) at most max - ssize_t waitMany(ssize_t max, std::int64_t timeout_usecs) - { - assert(max >= 0); - ssize_t result = tryWaitMany(max); - if (result == 0 && max > 0) - result = waitManyWithPartialSpinning(max, timeout_usecs); - return result; - } - - ssize_t waitMany(ssize_t max) - { - ssize_t result = waitMany(max, -1); - assert(result > 0); - return result; - } - - void signal(ssize_t count = 1) - { - assert(count >= 0); - ssize_t oldCount = m_count.fetch_add(count, std::memory_order_release); - ssize_t toRelease = -oldCount < count ? -oldCount : count; - if (toRelease > 0) - { - m_sema.signal((int)toRelease); - } - } - - ssize_t availableApprox() const - { - ssize_t count = m_count.load(std::memory_order_relaxed); - return count > 0 ? count : 0; - } - }; - } // end namespace mpmc_sema -} // end namespace details - - -// This is a blocking version of the queue. It has an almost identical interface to -// the normal non-blocking version, with the addition of various wait_dequeue() methods -// and the removal of producer-specific dequeue methods. -template -class BlockingConcurrentQueue -{ -private: - typedef ::dmlc::moodycamel::ConcurrentQueue ConcurrentQueue; - typedef details::mpmc_sema::LightweightSemaphore LightweightSemaphore; - -public: - typedef typename ConcurrentQueue::producer_token_t producer_token_t; - typedef typename ConcurrentQueue::consumer_token_t consumer_token_t; - - typedef typename ConcurrentQueue::index_t index_t; - typedef typename ConcurrentQueue::size_t size_t; - typedef typename std::make_signed::type ssize_t; - - static const size_t BLOCK_SIZE = ConcurrentQueue::BLOCK_SIZE; - static const size_t EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD = ConcurrentQueue::EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD; - static const size_t EXPLICIT_INITIAL_INDEX_SIZE = ConcurrentQueue::EXPLICIT_INITIAL_INDEX_SIZE; - static const size_t IMPLICIT_INITIAL_INDEX_SIZE = ConcurrentQueue::IMPLICIT_INITIAL_INDEX_SIZE; - static const size_t INITIAL_IMPLICIT_PRODUCER_HASH_SIZE = ConcurrentQueue::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE; - static const std::uint32_t EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE = ConcurrentQueue::EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE; - static const size_t MAX_SUBQUEUE_SIZE = ConcurrentQueue::MAX_SUBQUEUE_SIZE; - -public: - // Creates a queue with at least `capacity` element slots; note that the - // actual number of elements that can be inserted without additional memory - // allocation depends on the number of producers and the block size (e.g. if - // the block size is equal to `capacity`, only a single block will be allocated - // up-front, which means only a single producer will be able to enqueue elements - // without an extra allocation -- blocks aren't shared between producers). - // This method is not thread safe -- it is up to the user to ensure that the - // queue is fully constructed before it starts being used by other threads (this - // includes making the memory effects of construction visible, possibly with a - // memory barrier). - explicit BlockingConcurrentQueue(size_t capacity = 6 * BLOCK_SIZE) - : inner(capacity), sema(create(), &BlockingConcurrentQueue::template destroy) - { - assert(reinterpret_cast((BlockingConcurrentQueue*)1) == &((BlockingConcurrentQueue*)1)->inner && "BlockingConcurrentQueue must have ConcurrentQueue as its first member"); - if (!sema) { - MOODYCAMEL_THROW(std::bad_alloc()); - } - } - - BlockingConcurrentQueue(size_t minCapacity, size_t maxExplicitProducers, size_t maxImplicitProducers) - : inner(minCapacity, maxExplicitProducers, maxImplicitProducers), sema(create(), &BlockingConcurrentQueue::template destroy) - { - assert(reinterpret_cast((BlockingConcurrentQueue*)1) == &((BlockingConcurrentQueue*)1)->inner && "BlockingConcurrentQueue must have ConcurrentQueue as its first member"); - if (!sema) { - MOODYCAMEL_THROW(std::bad_alloc()); - } - } - - // Disable copying and copy assignment - BlockingConcurrentQueue(BlockingConcurrentQueue const&) MOODYCAMEL_DELETE_FUNCTION; - BlockingConcurrentQueue& operator=(BlockingConcurrentQueue const&) MOODYCAMEL_DELETE_FUNCTION; - - // Moving is supported, but note that it is *not* a thread-safe operation. - // Nobody can use the queue while it's being moved, and the memory effects - // of that move must be propagated to other threads before they can use it. - // Note: When a queue is moved, its tokens are still valid but can only be - // used with the destination queue (i.e. semantically they are moved along - // with the queue itself). - BlockingConcurrentQueue(BlockingConcurrentQueue&& other) MOODYCAMEL_NOEXCEPT - : inner(std::move(other.inner)), sema(std::move(other.sema)) - { } - - inline BlockingConcurrentQueue& operator=(BlockingConcurrentQueue&& other) MOODYCAMEL_NOEXCEPT - { - return swap_internal(other); - } - - // Swaps this queue's state with the other's. Not thread-safe. - // Swapping two queues does not invalidate their tokens, however - // the tokens that were created for one queue must be used with - // only the swapped queue (i.e. the tokens are tied to the - // queue's movable state, not the object itself). - inline void swap(BlockingConcurrentQueue& other) MOODYCAMEL_NOEXCEPT - { - swap_internal(other); - } - -private: - BlockingConcurrentQueue& swap_internal(BlockingConcurrentQueue& other) - { - if (this == &other) { - return *this; - } - - inner.swap(other.inner); - sema.swap(other.sema); - return *this; - } - -public: - // Enqueues a single item (by copying it). - // Allocates memory if required. Only fails if memory allocation fails (or implicit - // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, - // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Thread-safe. - inline bool enqueue(T const& item) - { - if (details::likely(inner.enqueue(item))) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues a single item (by moving it, if possible). - // Allocates memory if required. Only fails if memory allocation fails (or implicit - // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, - // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Thread-safe. - inline bool enqueue(T&& item) - { - if (details::likely(inner.enqueue(std::move(item)))) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues a single item (by copying it) using an explicit producer token. - // Allocates memory if required. Only fails if memory allocation fails (or - // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Thread-safe. - inline bool enqueue(producer_token_t const& token, T const& item) - { - if (details::likely(inner.enqueue(token, item))) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues a single item (by moving it, if possible) using an explicit producer token. - // Allocates memory if required. Only fails if memory allocation fails (or - // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Thread-safe. - inline bool enqueue(producer_token_t const& token, T&& item) - { - if (details::likely(inner.enqueue(token, std::move(item)))) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues several items. - // Allocates memory if required. Only fails if memory allocation fails (or - // implicit production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE - // is 0, or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Note: Use std::make_move_iterator if the elements should be moved instead of copied. - // Thread-safe. - template - inline bool enqueue_bulk(It itemFirst, size_t count) - { - if (details::likely(inner.enqueue_bulk(std::forward(itemFirst), count))) { - sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); - return true; - } - return false; - } - - // Enqueues several items using an explicit producer token. - // Allocates memory if required. Only fails if memory allocation fails - // (or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Note: Use std::make_move_iterator if the elements should be moved - // instead of copied. - // Thread-safe. - template - inline bool enqueue_bulk(producer_token_t const& token, It itemFirst, size_t count) - { - if (details::likely(inner.enqueue_bulk(token, std::forward(itemFirst), count))) { - sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); - return true; - } - return false; - } - - // Enqueues a single item (by copying it). - // Does not allocate memory. Fails if not enough room to enqueue (or implicit - // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE - // is 0). - // Thread-safe. - inline bool try_enqueue(T const& item) - { - if (inner.try_enqueue(item)) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues a single item (by moving it, if possible). - // Does not allocate memory (except for one-time implicit producer). - // Fails if not enough room to enqueue (or implicit production is - // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). - // Thread-safe. - inline bool try_enqueue(T&& item) - { - if (inner.try_enqueue(std::move(item))) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues a single item (by copying it) using an explicit producer token. - // Does not allocate memory. Fails if not enough room to enqueue. - // Thread-safe. - inline bool try_enqueue(producer_token_t const& token, T const& item) - { - if (inner.try_enqueue(token, item)) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues a single item (by moving it, if possible) using an explicit producer token. - // Does not allocate memory. Fails if not enough room to enqueue. - // Thread-safe. - inline bool try_enqueue(producer_token_t const& token, T&& item) - { - if (inner.try_enqueue(token, std::move(item))) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues several items. - // Does not allocate memory (except for one-time implicit producer). - // Fails if not enough room to enqueue (or implicit production is - // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). - // Note: Use std::make_move_iterator if the elements should be moved - // instead of copied. - // Thread-safe. - template - inline bool try_enqueue_bulk(It itemFirst, size_t count) - { - if (inner.try_enqueue_bulk(std::forward(itemFirst), count)) { - sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); - return true; - } - return false; - } - - // Enqueues several items using an explicit producer token. - // Does not allocate memory. Fails if not enough room to enqueue. - // Note: Use std::make_move_iterator if the elements should be moved - // instead of copied. - // Thread-safe. - template - inline bool try_enqueue_bulk(producer_token_t const& token, It itemFirst, size_t count) - { - if (inner.try_enqueue_bulk(token, std::forward(itemFirst), count)) { - sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); - return true; - } - return false; - } - - - // Attempts to dequeue from the queue. - // Returns false if all producer streams appeared empty at the time they - // were checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. - template - inline bool try_dequeue(U& item) - { - if (sema->tryWait()) { - while (!inner.try_dequeue(item)) { - continue; - } - return true; - } - return false; - } - - // Attempts to dequeue from the queue using an explicit consumer token. - // Returns false if all producer streams appeared empty at the time they - // were checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. - template - inline bool try_dequeue(consumer_token_t& token, U& item) - { - if (sema->tryWait()) { - while (!inner.try_dequeue(token, item)) { - continue; - } - return true; - } - return false; - } - - // Attempts to dequeue several elements from the queue. - // Returns the number of items actually dequeued. - // Returns 0 if all producer streams appeared empty at the time they - // were checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. - template - inline size_t try_dequeue_bulk(It itemFirst, size_t max) - { - size_t count = 0; - max = (size_t)sema->tryWaitMany((LightweightSemaphore::ssize_t)(ssize_t)max); - while (count != max) { - count += inner.template try_dequeue_bulk(itemFirst, max - count); - } - return count; - } - - // Attempts to dequeue several elements from the queue using an explicit consumer token. - // Returns the number of items actually dequeued. - // Returns 0 if all producer streams appeared empty at the time they - // were checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. - template - inline size_t try_dequeue_bulk(consumer_token_t& token, It itemFirst, size_t max) - { - size_t count = 0; - max = (size_t)sema->tryWaitMany((LightweightSemaphore::ssize_t)(ssize_t)max); - while (count != max) { - count += inner.template try_dequeue_bulk(token, itemFirst, max - count); - } - return count; - } - - - - // Blocks the current thread until there's something to dequeue, then - // dequeues it. - // Never allocates. Thread-safe. - template - inline void wait_dequeue(U& item) - { - sema->wait(); - while (!inner.try_dequeue(item)) { - continue; - } - } - - // Blocks the current thread until either there's something to dequeue - // or the timeout (specified in microseconds) expires. Returns false - // without setting `item` if the timeout expires, otherwise assigns - // to `item` and returns true. - // Using a negative timeout indicates an indefinite timeout, - // and is thus functionally equivalent to calling wait_dequeue. - // Never allocates. Thread-safe. - template - inline bool wait_dequeue_timed(U& item, std::int64_t timeout_usecs) - { - if (!sema->wait(timeout_usecs)) { - return false; - } - while (!inner.try_dequeue(item)) { - continue; - } - return true; - } - - // Blocks the current thread until either there's something to dequeue - // or the timeout expires. Returns false without setting `item` if the - // timeout expires, otherwise assigns to `item` and returns true. - // Never allocates. Thread-safe. - template - inline bool wait_dequeue_timed(U& item, std::chrono::duration const& timeout) - { - return wait_dequeue_timed(item, std::chrono::duration_cast(timeout).count()); - } - - // Blocks the current thread until there's something to dequeue, then - // dequeues it using an explicit consumer token. - // Never allocates. Thread-safe. - template - inline void wait_dequeue(consumer_token_t& token, U& item) - { - sema->wait(); - while (!inner.try_dequeue(token, item)) { - continue; - } - } - - // Blocks the current thread until either there's something to dequeue - // or the timeout (specified in microseconds) expires. Returns false - // without setting `item` if the timeout expires, otherwise assigns - // to `item` and returns true. - // Using a negative timeout indicates an indefinite timeout, - // and is thus functionally equivalent to calling wait_dequeue. - // Never allocates. Thread-safe. - template - inline bool wait_dequeue_timed(consumer_token_t& token, U& item, std::int64_t timeout_usecs) - { - if (!sema->wait(timeout_usecs)) { - return false; - } - while (!inner.try_dequeue(token, item)) { - continue; - } - return true; - } - - // Blocks the current thread until either there's something to dequeue - // or the timeout expires. Returns false without setting `item` if the - // timeout expires, otherwise assigns to `item` and returns true. - // Never allocates. Thread-safe. - template - inline bool wait_dequeue_timed(consumer_token_t& token, U& item, std::chrono::duration const& timeout) - { - return wait_dequeue_timed(token, item, std::chrono::duration_cast(timeout).count()); - } - - // Attempts to dequeue several elements from the queue. - // Returns the number of items actually dequeued, which will - // always be at least one (this method blocks until the queue - // is non-empty) and at most max. - // Never allocates. Thread-safe. - template - inline size_t wait_dequeue_bulk(It itemFirst, size_t max) - { - size_t count = 0; - max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max); - while (count != max) { - count += inner.template try_dequeue_bulk(itemFirst, max - count); - } - return count; - } - - // Attempts to dequeue several elements from the queue. - // Returns the number of items actually dequeued, which can - // be 0 if the timeout expires while waiting for elements, - // and at most max. - // Using a negative timeout indicates an indefinite timeout, - // and is thus functionally equivalent to calling wait_dequeue_bulk. - // Never allocates. Thread-safe. - template - inline size_t wait_dequeue_bulk_timed(It itemFirst, size_t max, std::int64_t timeout_usecs) - { - size_t count = 0; - max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max, timeout_usecs); - while (count != max) { - count += inner.template try_dequeue_bulk(itemFirst, max - count); - } - return count; - } - - // Attempts to dequeue several elements from the queue. - // Returns the number of items actually dequeued, which can - // be 0 if the timeout expires while waiting for elements, - // and at most max. - // Never allocates. Thread-safe. - template - inline size_t wait_dequeue_bulk_timed(It itemFirst, size_t max, std::chrono::duration const& timeout) - { - return wait_dequeue_bulk_timed(itemFirst, max, std::chrono::duration_cast(timeout).count()); - } - - // Attempts to dequeue several elements from the queue using an explicit consumer token. - // Returns the number of items actually dequeued, which will - // always be at least one (this method blocks until the queue - // is non-empty) and at most max. - // Never allocates. Thread-safe. - template - inline size_t wait_dequeue_bulk(consumer_token_t& token, It itemFirst, size_t max) - { - size_t count = 0; - max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max); - while (count != max) { - count += inner.template try_dequeue_bulk(token, itemFirst, max - count); - } - return count; - } - - // Attempts to dequeue several elements from the queue using an explicit consumer token. - // Returns the number of items actually dequeued, which can - // be 0 if the timeout expires while waiting for elements, - // and at most max. - // Using a negative timeout indicates an indefinite timeout, - // and is thus functionally equivalent to calling wait_dequeue_bulk. - // Never allocates. Thread-safe. - template - inline size_t wait_dequeue_bulk_timed(consumer_token_t& token, It itemFirst, size_t max, std::int64_t timeout_usecs) - { - size_t count = 0; - max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max, timeout_usecs); - while (count != max) { - count += inner.template try_dequeue_bulk(token, itemFirst, max - count); - } - return count; - } - - // Attempts to dequeue several elements from the queue using an explicit consumer token. - // Returns the number of items actually dequeued, which can - // be 0 if the timeout expires while waiting for elements, - // and at most max. - // Never allocates. Thread-safe. - template - inline size_t wait_dequeue_bulk_timed(consumer_token_t& token, It itemFirst, size_t max, std::chrono::duration const& timeout) - { - return wait_dequeue_bulk_timed(token, itemFirst, max, std::chrono::duration_cast(timeout).count()); - } - - - // Returns an estimate of the total number of elements currently in the queue. This - // estimate is only accurate if the queue has completely stabilized before it is called - // (i.e. all enqueue and dequeue operations have completed and their memory effects are - // visible on the calling thread, and no further operations start while this method is - // being called). - // Thread-safe. - inline size_t size_approx() const - { - return (size_t)sema->availableApprox(); - } - - - // Returns true if the underlying atomic variables used by - // the queue are lock-free (they should be on most platforms). - // Thread-safe. - static bool is_lock_free() - { - return ConcurrentQueue::is_lock_free(); - } - - -private: - template - static inline U* create() - { - auto p = (Traits::malloc)(sizeof(U)); - return p != nullptr ? new (p) U : nullptr; - } - - template - static inline U* create(A1&& a1) - { - auto p = (Traits::malloc)(sizeof(U)); - return p != nullptr ? new (p) U(std::forward(a1)) : nullptr; - } - - template - static inline void destroy(U* p) - { - if (p != nullptr) { - p->~U(); - } - (Traits::free)(p); - } - -private: - ConcurrentQueue inner; - std::unique_ptr sema; -}; - - -template -inline void swap(BlockingConcurrentQueue& a, BlockingConcurrentQueue& b) MOODYCAMEL_NOEXCEPT -{ - a.swap(b); -} - -} // end namespace moodycamel -} // namespace dmlc - -#endif // DMLC_BLOCKINGCONCURRENTQUEUE_H_ -//! \endcond Doxygen_Suppress diff --git a/include/dmlc/common.h b/include/dmlc/common.h deleted file mode 100644 index 9aead8c5b142..000000000000 --- a/include/dmlc/common.h +++ /dev/null @@ -1,85 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file common.h - * \brief defines some common utility function. - */ -#ifndef DMLC_COMMON_H_ -#define DMLC_COMMON_H_ - -#include -#include -#include -#include -#include "./logging.h" - -namespace dmlc { -/*! - * \brief Split a string by delimiter - * \param s String to be splitted. - * \param delim The delimiter. - * \return a splitted vector of strings. - */ -inline std::vector Split(const std::string& s, char delim) { - std::string item; - std::istringstream is(s); - std::vector ret; - while (std::getline(is, item, delim)) { - ret.push_back(item); - } - return ret; -} - -/*! - * \brief hash an object and combines the key with previous keys - */ -template -inline size_t HashCombine(size_t key, const T& value) { - std::hash hash_func; - return key ^ (hash_func(value) + 0x9e3779b9 + (key << 6) + (key >> 2)); -} - -/*! - * \brief specialize for size_t - */ -template<> -inline size_t HashCombine(size_t key, const size_t& value) { - return key ^ (value + 0x9e3779b9 + (key << 6) + (key >> 2)); -} - -/*! - * \brief OMP Exception class catches, saves and rethrows exception from OMP blocks - */ -class OMPException { - private: - // exception_ptr member to store the exception - std::exception_ptr omp_exception_; - // mutex to be acquired during catch to set the exception_ptr - std::mutex mutex_; - - public: - /*! - * \brief Parallel OMP blocks should be placed within Run to save exception - */ - template - void Run(Function f, Parameters... params) { - try { - f(params...); - } catch (dmlc::Error &ex) { - std::lock_guard lock(mutex_); - if (!omp_exception_) { - omp_exception_ = std::current_exception(); - } - } - } - - /*! - * \brief should be called from the main thread to rethrow the exception - */ - void Rethrow() { - if (this->omp_exception_) std::rethrow_exception(this->omp_exception_); - } -}; - -} // namespace dmlc - -#endif // DMLC_COMMON_H_ diff --git a/include/dmlc/concurrency.h b/include/dmlc/concurrency.h deleted file mode 100644 index 754cf5aa286e..000000000000 --- a/include/dmlc/concurrency.h +++ /dev/null @@ -1,258 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file concurrency.h - * \brief thread-safe data structures. - * \author Yutian Li - */ -#ifndef DMLC_CONCURRENCY_H_ -#define DMLC_CONCURRENCY_H_ -// this code depends on c++11 -#if DMLC_USE_CXX11 -#include -#include -#include -#include -#include -#include -#include "dmlc/base.h" - -namespace dmlc { - -/*! - * \brief Simple userspace spinlock implementation. - */ -class Spinlock { - public: -#ifdef _MSC_VER - Spinlock() { - lock_.clear(); - } -#else -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wbraced-scalar-init" - Spinlock() : lock_(ATOMIC_FLAG_INIT) { - } -#pragma clang diagnostic pop -#endif - ~Spinlock() = default; - /*! - * \brief Acquire lock. - */ - inline void lock() noexcept(true); - /*! - * \brief Release lock. - */ - inline void unlock() noexcept(true); - - private: - std::atomic_flag lock_; - /*! - * \brief Disable copy and move. - */ - DISALLOW_COPY_AND_ASSIGN(Spinlock); -}; - -/*! \brief type of concurrent queue */ -enum class ConcurrentQueueType { - /*! \brief FIFO queue */ - kFIFO, - /*! \brief queue with priority */ - kPriority -}; - -/*! - * \brief Cocurrent blocking queue. - */ -template -class ConcurrentBlockingQueue { - public: - ConcurrentBlockingQueue(); - ~ConcurrentBlockingQueue() = default; - /*! - * \brief Push element to the end of the queue. - * \param e Element to push into. - * \param priority the priority of the element, only used for priority queue. - * The higher the priority is, the better. - * \tparam E the element type - * - * It will copy or move the element into the queue, depending on the type of - * the parameter. - */ - template - void Push(E&& e, int priority = 0); - - /*! - * \brief Push element to the front of the queue. Only works for FIFO queue. - * For priority queue it is the same as Push. - * \param e Element to push into. - * \param priority the priority of the element, only used for priority queue. - * The higher the priority is, the better. - * \tparam E the element type - * - * It will copy or move the element into the queue, depending on the type of - * the parameter. - */ - template - void PushFront(E&& e, int priority = 0); - /*! - * \brief Pop element from the queue. - * \param rv Element popped. - * \return On false, the queue is exiting. - * - * The element will be copied or moved into the object passed in. - */ - bool Pop(T* rv); - /*! - * \brief Signal the queue for destruction. - * - * After calling this method, all blocking pop call to the queue will return - * false. - */ - void SignalForKill(); - /*! - * \brief Get the size of the queue. - * \return The size of the queue. - */ - size_t Size(); - - private: - struct Entry { - T data; - int priority; - inline bool operator<(const Entry &b) const { - return priority < b.priority; - } - }; - - std::mutex mutex_; - std::condition_variable cv_; - std::atomic exit_now_; - int nwait_consumer_; - // a priority queue - std::vector priority_queue_; - // a FIFO queue - std::deque fifo_queue_; - /*! - * \brief Disable copy and move. - */ - DISALLOW_COPY_AND_ASSIGN(ConcurrentBlockingQueue); -}; - -inline void Spinlock::lock() noexcept(true) { - while (lock_.test_and_set(std::memory_order_acquire)) { - } -} - -inline void Spinlock::unlock() noexcept(true) { - lock_.clear(std::memory_order_release); -} - -template -ConcurrentBlockingQueue::ConcurrentBlockingQueue() - : exit_now_{false}, nwait_consumer_{0} {} - -template -template -void ConcurrentBlockingQueue::Push(E&& e, int priority) { - static_assert(std::is_same::type>::type, - T>::value, - "Types must match."); - bool notify; - { - std::lock_guard lock{mutex_}; - if (type == ConcurrentQueueType::kFIFO) { - fifo_queue_.emplace_back(std::forward(e)); - notify = nwait_consumer_ != 0; - } else { - Entry entry; - entry.data = std::move(e); - entry.priority = priority; - priority_queue_.push_back(std::move(entry)); - std::push_heap(priority_queue_.begin(), priority_queue_.end()); - notify = nwait_consumer_ != 0; - } - } - if (notify) cv_.notify_one(); -} - -template -template -void ConcurrentBlockingQueue::PushFront(E&& e, int priority) { - static_assert(std::is_same::type>::type, - T>::value, - "Types must match."); - bool notify; - { - std::lock_guard lock{mutex_}; - if (type == ConcurrentQueueType::kFIFO) { - fifo_queue_.emplace_front(std::forward(e)); - notify = nwait_consumer_ != 0; - } else { - Entry entry; - entry.data = std::move(e); - entry.priority = priority; - priority_queue_.push_back(std::move(entry)); - std::push_heap(priority_queue_.begin(), priority_queue_.end()); - notify = nwait_consumer_ != 0; - } - } - if (notify) cv_.notify_one(); -} - -template -bool ConcurrentBlockingQueue::Pop(T* rv) { - std::unique_lock lock{mutex_}; - if (type == ConcurrentQueueType::kFIFO) { - ++nwait_consumer_; - cv_.wait(lock, [this] { - return !fifo_queue_.empty() || exit_now_.load(); - }); - --nwait_consumer_; - if (!exit_now_.load()) { - *rv = std::move(fifo_queue_.front()); - fifo_queue_.pop_front(); - return true; - } else { - return false; - } - } else { - ++nwait_consumer_; - cv_.wait(lock, [this] { - return !priority_queue_.empty() || exit_now_.load(); - }); - --nwait_consumer_; - if (!exit_now_.load()) { - std::pop_heap(priority_queue_.begin(), priority_queue_.end()); - *rv = std::move(priority_queue_.back().data); - priority_queue_.pop_back(); - return true; - } else { - return false; - } - } -} - -template -void ConcurrentBlockingQueue::SignalForKill() { - { - std::lock_guard lock{mutex_}; - exit_now_.store(true); - } - cv_.notify_all(); -} - -template -size_t ConcurrentBlockingQueue::Size() { - std::lock_guard lock{mutex_}; - if (type == ConcurrentQueueType::kFIFO) { - return fifo_queue_.size(); - } else { - return priority_queue_.size(); - } -} -} // namespace dmlc -#endif // DMLC_USE_CXX11 -#endif // DMLC_CONCURRENCY_H_ diff --git a/include/dmlc/concurrentqueue.h b/include/dmlc/concurrentqueue.h deleted file mode 100644 index f9b7d1147dc5..000000000000 --- a/include/dmlc/concurrentqueue.h +++ /dev/null @@ -1,3719 +0,0 @@ -//! \cond Doxygen_Suppress -// Provides a C++11 implementation of a multi-producer, multi-consumer lock-free queue. -// An overview, including benchmark results, is provided here: -// http://moodycamel.com/blog/2014/a-fast-general-purpose-lock-free-queue-for-c++ -// The full design is also described in excruciating detail at: -// http://moodycamel.com/blog/2014/detailed-design-of-a-lock-free-queue - -// Simplified BSD license: -// Copyright (c) 2013-2016, Cameron Desrochers. -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without modification, -// are permitted provided that the following conditions are met: -// -// - Redistributions of source code must retain the above copyright notice, this list of -// conditions and the following disclaimer. -// - Redistributions in binary form must reproduce the above copyright notice, this list of -// conditions and the following disclaimer in the documentation and/or other materials -// provided with the distribution. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY -// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL -// THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT -// OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR -// TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, -// EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#ifndef DMLC_CONCURRENTQUEUE_H_ -#define DMLC_CONCURRENTQUEUE_H_ -#pragma once - -#if defined(__GNUC__) -// Disable -Wconversion warnings (spuriously triggered when Traits::size_t and -// Traits::index_t are set to < 32 bits, causing integer promotion, causing warnings -// upon assigning any computed values) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wconversion" - -#ifdef MCDBGQ_USE_RELACY -#pragma GCC diagnostic ignored "-Wint-to-pointer-cast" -#endif -#endif - -#if defined(_WIN32) || defined(__WINDOWS__) || defined(__WIN32__) || defined(_WIN64) -#include // for GetCurrentThreadId() -#endif - -#if defined(__APPLE__) -#include "TargetConditionals.h" -#endif - -#ifdef MCDBGQ_USE_RELACY -#include "relacy/relacy_std.hpp" -#include "relacy_shims.h" -// We only use malloc/free anyway, and the delete macro messes up `= delete` method declarations. -// We'll override the default trait malloc ourselves without a macro. -#undef new -#undef delete -#undef malloc -#undef free -#else -#include // Requires C++11. Sorry VS2010. -#include -#endif -#include // for max_align_t -#include -#include -#include -#include -#include -#include -#include // for CHAR_BIT -#include -#include // partly for __WINPTHREADS_VERSION if on MinGW-w64 w/ POSIX threading - -namespace dmlc { - -// Platform-specific definitions of a numeric thread ID type and an invalid value -namespace moodycamel { namespace details { -template struct thread_id_converter { - typedef thread_id_t thread_id_numeric_size_t; - typedef thread_id_t thread_id_hash_t; - static thread_id_hash_t prehash(thread_id_t const& x) { return x; } -}; -} } -#if defined(MCDBGQ_USE_RELACY) -namespace moodycamel { namespace details { - typedef std::uint32_t thread_id_t; - static const thread_id_t invalid_thread_id = 0xFFFFFFFFU; - static const thread_id_t invalid_thread_id2 = 0xFFFFFFFEU; - static inline thread_id_t thread_id() { return rl::thread_index(); } -} } -#elif defined(_WIN32) || defined(__WINDOWS__) || defined(__WIN32__) -// No sense pulling in windows.h in a header, we'll manually declare the function -// we use and rely on backwards-compatibility for this not to break -extern "C" __declspec(dllimport) unsigned long __stdcall GetCurrentThreadId(void); -namespace moodycamel { namespace details { - static_assert(sizeof(unsigned long) == sizeof(std::uint32_t), "Expected size of unsigned long to be 32 bits on Windows"); - typedef std::uint32_t thread_id_t; - static const thread_id_t invalid_thread_id = 0; // See http://blogs.msdn.com/b/oldnewthing/archive/2004/02/23/78395.aspx - static const thread_id_t invalid_thread_id2 = 0xFFFFFFFFU; // Not technically guaranteed to be invalid, but is never used in practice. Note that all Win32 thread IDs are presently multiples of 4. - static inline thread_id_t thread_id() { return static_cast(::GetCurrentThreadId()); } -} } -#elif defined(__arm__) || defined(_M_ARM) || defined(__aarch64__) || (defined(__APPLE__) && TARGET_OS_IPHONE) -namespace moodycamel { namespace details { - static_assert(sizeof(std::thread::id) == 4 || sizeof(std::thread::id) == 8, "std::thread::id is expected to be either 4 or 8 bytes"); - - typedef std::thread::id thread_id_t; - static const thread_id_t invalid_thread_id; // Default ctor creates invalid ID - - // Note we don't define a invalid_thread_id2 since std::thread::id doesn't have one; it's - // only used if MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED is defined anyway, which it won't - // be. - static inline thread_id_t thread_id() { return std::this_thread::get_id(); } - - template struct thread_id_size { }; - template<> struct thread_id_size<4> { typedef std::uint32_t numeric_t; }; - template<> struct thread_id_size<8> { typedef std::uint64_t numeric_t; }; - - template<> struct thread_id_converter { - typedef thread_id_size::numeric_t thread_id_numeric_size_t; -#ifndef __APPLE__ - typedef std::size_t thread_id_hash_t; -#else - typedef thread_id_numeric_size_t thread_id_hash_t; -#endif - - static thread_id_hash_t prehash(thread_id_t const& x) - { -#ifndef __APPLE__ - return std::hash()(x); -#else - return *reinterpret_cast(&x); -#endif - } - }; -} } -#else -// Use a nice trick from this answer: http://stackoverflow.com/a/8438730/21475 -// In order to get a numeric thread ID in a platform-independent way, we use a thread-local -// static variable's address as a thread identifier :-) -#if defined(__GNUC__) || defined(__INTEL_COMPILER) -#define MOODYCAMEL_THREADLOCAL __thread -#elif defined(_MSC_VER) -#define MOODYCAMEL_THREADLOCAL __declspec(thread) -#else -// Assume C++11 compliant compiler -#define MOODYCAMEL_THREADLOCAL thread_local -#endif -namespace moodycamel { namespace details { -typedef std::uintptr_t thread_id_t; -static const thread_id_t invalid_thread_id = 0; // Address can't be nullptr -static const thread_id_t invalid_thread_id2 = 1; // Member accesses off a null pointer are also generally invalid. Plus it's not aligned. -static inline thread_id_t thread_id() { static MOODYCAMEL_THREADLOCAL int x; return reinterpret_cast(&x); } -} } -#endif - -// Exceptions -#ifndef MOODYCAMEL_EXCEPTIONS_ENABLED -#if (defined(_MSC_VER) && defined(_CPPUNWIND)) || (defined(__GNUC__) && defined(__EXCEPTIONS)) || (!defined(_MSC_VER) && !defined(__GNUC__)) -#define MOODYCAMEL_EXCEPTIONS_ENABLED -#endif -#endif -#ifdef MOODYCAMEL_EXCEPTIONS_ENABLED -#define MOODYCAMEL_TRY try -#define MOODYCAMEL_CATCH(...) catch(__VA_ARGS__) -#define MOODYCAMEL_RETHROW throw -#define MOODYCAMEL_THROW(expr) throw (expr) -#else -#define MOODYCAMEL_TRY if (true) -#define MOODYCAMEL_CATCH(...) else if (false) -#define MOODYCAMEL_RETHROW -#define MOODYCAMEL_THROW(expr) -#endif - -#ifndef MOODYCAMEL_NOEXCEPT -#if !defined(MOODYCAMEL_EXCEPTIONS_ENABLED) -#define MOODYCAMEL_NOEXCEPT -#define MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr) true -#define MOODYCAMEL_NOEXCEPT_ASSIGN(type, valueType, expr) true -#elif defined(_MSC_VER) && defined(_NOEXCEPT) && _MSC_VER < 1800 -// VS2012's std::is_nothrow_[move_]constructible is broken and returns true when it shouldn't :-( -// We have to assume *all* non-trivial constructors may throw on VS2012! -#define MOODYCAMEL_NOEXCEPT _NOEXCEPT -#define MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr) (std::is_rvalue_reference::value && std::is_move_constructible::value ? std::is_trivially_move_constructible::value : std::is_trivially_copy_constructible::value) -#define MOODYCAMEL_NOEXCEPT_ASSIGN(type, valueType, expr) ((std::is_rvalue_reference::value && std::is_move_assignable::value ? std::is_trivially_move_assignable::value || std::is_nothrow_move_assignable::value : std::is_trivially_copy_assignable::value || std::is_nothrow_copy_assignable::value) && MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr)) -#elif defined(_MSC_VER) && defined(_NOEXCEPT) && _MSC_VER < 1900 -#define MOODYCAMEL_NOEXCEPT _NOEXCEPT -#define MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr) (std::is_rvalue_reference::value && std::is_move_constructible::value ? std::is_trivially_move_constructible::value || std::is_nothrow_move_constructible::value : std::is_trivially_copy_constructible::value || std::is_nothrow_copy_constructible::value) -#define MOODYCAMEL_NOEXCEPT_ASSIGN(type, valueType, expr) ((std::is_rvalue_reference::value && std::is_move_assignable::value ? std::is_trivially_move_assignable::value || std::is_nothrow_move_assignable::value : std::is_trivially_copy_assignable::value || std::is_nothrow_copy_assignable::value) && MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr)) -#else -#define MOODYCAMEL_NOEXCEPT noexcept -#define MOODYCAMEL_NOEXCEPT_CTOR(type, valueType, expr) noexcept(expr) -#define MOODYCAMEL_NOEXCEPT_ASSIGN(type, valueType, expr) noexcept(expr) -#endif -#endif - -#ifndef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED -#ifdef MCDBGQ_USE_RELACY -#define MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED -#else -// VS2013 doesn't support `thread_local`, and MinGW-w64 w/ POSIX threading has a crippling bug: http://sourceforge.net/p/mingw-w64/bugs/445 -// g++ <=4.7 doesn't support thread_local either. -// Finally, iOS/ARM doesn't have support for it either, and g++/ARM allows it to compile but it's unconfirmed to actually work -#if (!defined(_MSC_VER) || _MSC_VER >= 1900) && (!defined(__MINGW32__) && !defined(__MINGW64__) || !defined(__WINPTHREADS_VERSION)) && (!defined(__GNUC__) || __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8)) && (!defined(__APPLE__) || !TARGET_OS_IPHONE) && !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) -// Assume `thread_local` is fully supported in all other C++11 compilers/platforms -//#define MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED // always disabled for now since several users report having problems with it on -#endif -#endif -#endif - -// VS2012 doesn't support deleted functions. -// In this case, we declare the function normally but don't define it. A link error will be generated if the function is called. -#ifndef MOODYCAMEL_DELETE_FUNCTION -#if defined(_MSC_VER) && _MSC_VER < 1800 -#define MOODYCAMEL_DELETE_FUNCTION -#else -#define MOODYCAMEL_DELETE_FUNCTION = delete -#endif -#endif - -// Compiler-specific likely/unlikely hints -namespace moodycamel { namespace details { -#if defined(__GNUC__) -inline bool likely(bool x) { return __builtin_expect((x), true); } -inline bool unlikely(bool x) { return __builtin_expect((x), false); } -#else -inline bool likely(bool x) { return x; } - inline bool unlikely(bool x) { return x; } -#endif -} } - -#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG -#include "internal/concurrentqueue_internal_debug.h" -#endif - -namespace moodycamel { -namespace details { -template -struct const_numeric_max { - static_assert(std::is_integral::value, "const_numeric_max can only be used with integers"); - static const T value = std::numeric_limits::is_signed - ? (static_cast(1) << (sizeof(T) * CHAR_BIT - 1)) - static_cast(1) - : static_cast(-1); -}; - -#if defined(__GLIBCXX__) -typedef ::max_align_t std_max_align_t; // libstdc++ forgot to add it to std:: for a while -#else -typedef std::max_align_t std_max_align_t; // Others (e.g. MSVC) insist it can *only* be accessed via std:: -#endif - -// Some platforms have incorrectly set max_align_t to a type with <8 bytes alignment even while supporting -// 8-byte aligned scalar values (*cough* 32-bit iOS). Work around this with our own union. See issue #64. -typedef union { - std_max_align_t x; - long long y; - void* z; -} max_align_t; -} - -// Default traits for the ConcurrentQueue. To change some of the -// traits without re-implementing all of them, inherit from this -// struct and shadow the declarations you wish to be different; -// since the traits are used as a template type parameter, the -// shadowed declarations will be used where defined, and the defaults -// otherwise. -struct ConcurrentQueueDefaultTraits -{ - // General-purpose size type. std::size_t is strongly recommended. - typedef std::size_t size_t; - - // The type used for the enqueue and dequeue indices. Must be at least as - // large as size_t. Should be significantly larger than the number of elements - // you expect to hold at once, especially if you have a high turnover rate; - // for example, on 32-bit x86, if you expect to have over a hundred million - // elements or pump several million elements through your queue in a very - // short space of time, using a 32-bit type *may* trigger a race condition. - // A 64-bit int type is recommended in that case, and in practice will - // prevent a race condition no matter the usage of the queue. Note that - // whether the queue is lock-free with a 64-int type depends on the whether - // std::atomic is lock-free, which is platform-specific. - typedef std::size_t index_t; - - // Internally, all elements are enqueued and dequeued from multi-element - // blocks; this is the smallest controllable unit. If you expect few elements - // but many producers, a smaller block size should be favoured. For few producers - // and/or many elements, a larger block size is preferred. A sane default - // is provided. Must be a power of 2. - static const size_t BLOCK_SIZE = 32; - - // For explicit producers (i.e. when using a producer token), the block is - // checked for being empty by iterating through a list of flags, one per element. - // For large block sizes, this is too inefficient, and switching to an atomic - // counter-based approach is faster. The switch is made for block sizes strictly - // larger than this threshold. - static const size_t EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD = 32; - - // How many full blocks can be expected for a single explicit producer? This should - // reflect that number's maximum for optimal performance. Must be a power of 2. - static const size_t EXPLICIT_INITIAL_INDEX_SIZE = 32; - - // How many full blocks can be expected for a single implicit producer? This should - // reflect that number's maximum for optimal performance. Must be a power of 2. - static const size_t IMPLICIT_INITIAL_INDEX_SIZE = 32; - - // The initial size of the hash table mapping thread IDs to implicit producers. - // Note that the hash is resized every time it becomes half full. - // Must be a power of two, and either 0 or at least 1. If 0, implicit production - // (using the enqueue methods without an explicit producer token) is disabled. - static const size_t INITIAL_IMPLICIT_PRODUCER_HASH_SIZE = 32; - - // Controls the number of items that an explicit consumer (i.e. one with a token) - // must consume before it causes all consumers to rotate and move on to the next - // internal queue. - static const std::uint32_t EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE = 256; - - // The maximum number of elements (inclusive) that can be enqueued to a sub-queue. - // Enqueue operations that would cause this limit to be surpassed will fail. Note - // that this limit is enforced at the block level (for performance reasons), i.e. - // it's rounded up to the nearest block size. - static const size_t MAX_SUBQUEUE_SIZE = details::const_numeric_max::value; - - -#ifndef MCDBGQ_USE_RELACY - // Memory allocation can be customized if needed. - // malloc should return nullptr on failure, and handle alignment like std::malloc. -#if defined(malloc) || defined(free) - // Gah, this is 2015, stop defining macros that break standard code already! - // Work around malloc/free being special macros: - static inline void* WORKAROUND_malloc(size_t size) { return malloc(size); } - static inline void WORKAROUND_free(void* ptr) { return free(ptr); } - static inline void* (malloc)(size_t size) { return WORKAROUND_malloc(size); } - static inline void (free)(void* ptr) { return WORKAROUND_free(ptr); } -#else - static inline void* malloc(size_t size) { return std::malloc(size); } - static inline void free(void* ptr) { return std::free(ptr); } -#endif -#else - // Debug versions when running under the Relacy race detector (ignore - // these in user code) - static inline void* malloc(size_t size) { return rl::rl_malloc(size, $); } - static inline void free(void* ptr) { return rl::rl_free(ptr, $); } -#endif -}; - - -// When producing or consuming many elements, the most efficient way is to: -// 1) Use one of the bulk-operation methods of the queue with a token -// 2) Failing that, use the bulk-operation methods without a token -// 3) Failing that, create a token and use that with the single-item methods -// 4) Failing that, use the single-parameter methods of the queue -// Having said that, don't create tokens willy-nilly -- ideally there should be -// a maximum of one token per thread (of each kind). -struct ProducerToken; -struct ConsumerToken; - -template class ConcurrentQueue; -template class BlockingConcurrentQueue; -class ConcurrentQueueTests; - - -namespace details -{ -struct ConcurrentQueueProducerTypelessBase -{ - ConcurrentQueueProducerTypelessBase* next; - std::atomic inactive; - ProducerToken* token; - - ConcurrentQueueProducerTypelessBase() - : next(nullptr), inactive(false), token(nullptr) - { - } -}; - -template struct _hash_32_or_64 { - static inline std::uint32_t hash(std::uint32_t h) - { - // MurmurHash3 finalizer -- see https://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp - // Since the thread ID is already unique, all we really want to do is propagate that - // uniqueness evenly across all the bits, so that we can use a subset of the bits while - // reducing collisions significantly - h ^= h >> 16; - h *= 0x85ebca6b; - h ^= h >> 13; - h *= 0xc2b2ae35; - return h ^ (h >> 16); - } -}; -template<> struct _hash_32_or_64<1> { - static inline std::uint64_t hash(std::uint64_t h) - { - h ^= h >> 33; - h *= 0xff51afd7ed558ccd; - h ^= h >> 33; - h *= 0xc4ceb9fe1a85ec53; - return h ^ (h >> 33); - } -}; -template struct hash_32_or_64 : public _hash_32_or_64<(size > 4)> { }; - -static inline size_t hash_thread_id(thread_id_t id) -{ - static_assert(sizeof(thread_id_t) <= 8, "Expected a platform where thread IDs are at most 64-bit values"); - return static_cast(hash_32_or_64::thread_id_hash_t)>::hash( - thread_id_converter::prehash(id))); -} - -template -static inline bool circular_less_than(T a, T b) -{ -#ifdef _MSC_VER - #pragma warning(push) -#pragma warning(disable: 4554) -#endif - static_assert(std::is_integral::value && !std::numeric_limits::is_signed, "circular_less_than is intended to be used only with unsigned integer types"); - return static_cast(a - b) > static_cast(static_cast(1) << static_cast(sizeof(T) * CHAR_BIT - 1)); -#ifdef _MSC_VER -#pragma warning(pop) -#endif -} - -template -static inline char* align_for(char* ptr) -{ - const std::size_t alignment = std::alignment_of::value; - return ptr + (alignment - (reinterpret_cast(ptr) % alignment)) % alignment; -} - -template -static inline T ceil_to_pow_2(T x) -{ - static_assert(std::is_integral::value && !std::numeric_limits::is_signed, "ceil_to_pow_2 is intended to be used only with unsigned integer types"); - - // Adapted from http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 - --x; - x |= x >> 1; - x |= x >> 2; - x |= x >> 4; - for (std::size_t i = 1; i < sizeof(T); i <<= 1) { - x |= x >> (i << 3); - } - ++x; - return x; -} - -template -static inline void swap_relaxed(std::atomic& left, std::atomic& right) -{ - T temp = std::move(left.load(std::memory_order_relaxed)); - left.store(std::move(right.load(std::memory_order_relaxed)), std::memory_order_relaxed); - right.store(std::move(temp), std::memory_order_relaxed); -} - -template -static inline T const& nomove(T const& x) -{ - return x; -} - -template -struct nomove_if -{ - template - static inline T const& eval(T const& x) - { - return x; - } -}; - -template<> -struct nomove_if -{ - template - static inline auto eval(U&& x) - -> decltype(std::forward(x)) - { - return std::forward(x); - } -}; - -template -static inline auto deref_noexcept(It& it) MOODYCAMEL_NOEXCEPT -> decltype(*it) -{ - return *it; -} - -#if defined(__clang__) || !defined(__GNUC__) || __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8) -template struct is_trivially_destructible : std::is_trivially_destructible { }; -#else -template struct is_trivially_destructible : std::has_trivial_destructor { }; -#endif - -#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED -#ifdef MCDBGQ_USE_RELACY - typedef RelacyThreadExitListener ThreadExitListener; - typedef RelacyThreadExitNotifier ThreadExitNotifier; -#else - struct ThreadExitListener - { - typedef void (*callback_t)(void*); - callback_t callback; - void* userData; - - ThreadExitListener* next; // reserved for use by the ThreadExitNotifier - }; - - - class ThreadExitNotifier - { - public: - static void subscribe(ThreadExitListener* listener) - { - auto& tlsInst = instance(); - listener->next = tlsInst.tail; - tlsInst.tail = listener; - } - - static void unsubscribe(ThreadExitListener* listener) - { - auto& tlsInst = instance(); - ThreadExitListener** prev = &tlsInst.tail; - for (auto ptr = tlsInst.tail; ptr != nullptr; ptr = ptr->next) { - if (ptr == listener) { - *prev = ptr->next; - break; - } - prev = &ptr->next; - } - } - - private: - ThreadExitNotifier() : tail(nullptr) { } - ThreadExitNotifier(ThreadExitNotifier const&) MOODYCAMEL_DELETE_FUNCTION; - ThreadExitNotifier& operator=(ThreadExitNotifier const&) MOODYCAMEL_DELETE_FUNCTION; - - ~ThreadExitNotifier() - { - // This thread is about to exit, let everyone know! - assert(this == &instance() && "If this assert fails, you likely have a buggy compiler! Change the preprocessor conditions such that MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED is no longer defined."); - for (auto ptr = tail; ptr != nullptr; ptr = ptr->next) { - ptr->callback(ptr->userData); - } - } - - // Thread-local - static inline ThreadExitNotifier& instance() - { - static thread_local ThreadExitNotifier notifier; - return notifier; - } - - private: - ThreadExitListener* tail; - }; -#endif -#endif - -template struct static_is_lock_free_num { enum { value = 0 }; }; -template<> struct static_is_lock_free_num { enum { value = ATOMIC_CHAR_LOCK_FREE }; }; -template<> struct static_is_lock_free_num { enum { value = ATOMIC_SHORT_LOCK_FREE }; }; -template<> struct static_is_lock_free_num { enum { value = ATOMIC_INT_LOCK_FREE }; }; -template<> struct static_is_lock_free_num { enum { value = ATOMIC_LONG_LOCK_FREE }; }; -template<> struct static_is_lock_free_num { enum { value = ATOMIC_LLONG_LOCK_FREE }; }; -template struct static_is_lock_free : static_is_lock_free_num::type> { }; -template<> struct static_is_lock_free { enum { value = ATOMIC_BOOL_LOCK_FREE }; }; -template struct static_is_lock_free { enum { value = ATOMIC_POINTER_LOCK_FREE }; }; -} - - -struct ProducerToken -{ - template - explicit ProducerToken(ConcurrentQueue& queue); - - template - explicit ProducerToken(BlockingConcurrentQueue& queue); - - ProducerToken(ProducerToken&& other) MOODYCAMEL_NOEXCEPT - : producer(other.producer) - { - other.producer = nullptr; - if (producer != nullptr) { - producer->token = this; - } - } - - inline ProducerToken& operator=(ProducerToken&& other) MOODYCAMEL_NOEXCEPT - { - swap(other); - return *this; - } - - void swap(ProducerToken& other) MOODYCAMEL_NOEXCEPT - { - std::swap(producer, other.producer); - if (producer != nullptr) { - producer->token = this; - } - if (other.producer != nullptr) { - other.producer->token = &other; - } - } - - // A token is always valid unless: - // 1) Memory allocation failed during construction - // 2) It was moved via the move constructor - // (Note: assignment does a swap, leaving both potentially valid) - // 3) The associated queue was destroyed - // Note that if valid() returns true, that only indicates - // that the token is valid for use with a specific queue, - // but not which one; that's up to the user to track. - inline bool valid() const { return producer != nullptr; } - - ~ProducerToken() - { - if (producer != nullptr) { - producer->token = nullptr; - producer->inactive.store(true, std::memory_order_release); - } - } - - // Disable copying and assignment - ProducerToken(ProducerToken const&) MOODYCAMEL_DELETE_FUNCTION; - ProducerToken& operator=(ProducerToken const&) MOODYCAMEL_DELETE_FUNCTION; - - private: - template friend class ConcurrentQueue; - friend class ConcurrentQueueTests; - - protected: - details::ConcurrentQueueProducerTypelessBase* producer; -}; - - -struct ConsumerToken -{ - template - explicit ConsumerToken(ConcurrentQueue& q); - - template - explicit ConsumerToken(BlockingConcurrentQueue& q); - - ConsumerToken(ConsumerToken&& other) MOODYCAMEL_NOEXCEPT - : initialOffset(other.initialOffset), lastKnownGlobalOffset(other.lastKnownGlobalOffset), itemsConsumedFromCurrent(other.itemsConsumedFromCurrent), currentProducer(other.currentProducer), desiredProducer(other.desiredProducer) - { - } - - inline ConsumerToken& operator=(ConsumerToken&& other) MOODYCAMEL_NOEXCEPT - { - swap(other); - return *this; - } - - void swap(ConsumerToken& other) MOODYCAMEL_NOEXCEPT - { - std::swap(initialOffset, other.initialOffset); - std::swap(lastKnownGlobalOffset, other.lastKnownGlobalOffset); - std::swap(itemsConsumedFromCurrent, other.itemsConsumedFromCurrent); - std::swap(currentProducer, other.currentProducer); - std::swap(desiredProducer, other.desiredProducer); - } - - // Disable copying and assignment - ConsumerToken(ConsumerToken const&) MOODYCAMEL_DELETE_FUNCTION; - ConsumerToken& operator=(ConsumerToken const&) MOODYCAMEL_DELETE_FUNCTION; - - private: - template friend class ConcurrentQueue; - friend class ConcurrentQueueTests; - - private: // but shared with ConcurrentQueue - std::uint32_t initialOffset; - std::uint32_t lastKnownGlobalOffset; - std::uint32_t itemsConsumedFromCurrent; - details::ConcurrentQueueProducerTypelessBase* currentProducer; - details::ConcurrentQueueProducerTypelessBase* desiredProducer; -}; - -// Need to forward-declare this swap because it's in a namespace. -// See http://stackoverflow.com/questions/4492062/why-does-a-c-friend-class-need-a-forward-declaration-only-in-other-namespaces -template -inline void swap(typename ConcurrentQueue::ImplicitProducerKVP& a, typename ConcurrentQueue::ImplicitProducerKVP& b) MOODYCAMEL_NOEXCEPT; - - -template -class ConcurrentQueue { - public: - typedef ::dmlc::moodycamel::ProducerToken producer_token_t; - typedef ::dmlc::moodycamel::ConsumerToken consumer_token_t; - - typedef typename Traits::index_t index_t; - typedef typename Traits::size_t size_t; - - static const size_t BLOCK_SIZE = static_cast(Traits::BLOCK_SIZE); - static const size_t EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD = static_cast(Traits::EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD); - static const size_t EXPLICIT_INITIAL_INDEX_SIZE = static_cast(Traits::EXPLICIT_INITIAL_INDEX_SIZE); - static const size_t IMPLICIT_INITIAL_INDEX_SIZE = static_cast(Traits::IMPLICIT_INITIAL_INDEX_SIZE); - static const size_t INITIAL_IMPLICIT_PRODUCER_HASH_SIZE = static_cast(Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE); - static const std::uint32_t EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE = static_cast(Traits::EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE); -#ifdef _MSC_VER - #pragma warning(push) -#pragma warning(disable: 4307) // + integral constant overflow (that's what the ternary expression is for!) -#pragma warning(disable: 4309) // static_cast: Truncation of constant value -#endif - static const size_t MAX_SUBQUEUE_SIZE = (details::const_numeric_max::value - - static_cast(Traits::MAX_SUBQUEUE_SIZE) < - BLOCK_SIZE) ? details::const_numeric_max::value - : ( - (static_cast(Traits::MAX_SUBQUEUE_SIZE) + - (BLOCK_SIZE - 1)) / BLOCK_SIZE * BLOCK_SIZE); -#ifdef _MSC_VER -#pragma warning(pop) -#endif - - static_assert(!std::numeric_limits::is_signed && std::is_integral::value, - "Traits::size_t must be an unsigned integral type"); - static_assert(!std::numeric_limits::is_signed && std::is_integral::value, - "Traits::index_t must be an unsigned integral type"); - static_assert(sizeof(index_t) >= sizeof(size_t), - "Traits::index_t must be at least as wide as Traits::size_t"); - static_assert((BLOCK_SIZE > 1) && !(BLOCK_SIZE & (BLOCK_SIZE - 1)), - "Traits::BLOCK_SIZE must be a power of 2 (and at least 2)"); - static_assert((EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD > 1) && - !(EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD & - (EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD - 1)), - "Traits::EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD must be a power of 2 (and greater than 1)"); - static_assert((EXPLICIT_INITIAL_INDEX_SIZE > 1) && - !(EXPLICIT_INITIAL_INDEX_SIZE & (EXPLICIT_INITIAL_INDEX_SIZE - 1)), - "Traits::EXPLICIT_INITIAL_INDEX_SIZE must be a power of 2 (and greater than 1)"); - static_assert((IMPLICIT_INITIAL_INDEX_SIZE > 1) && - !(IMPLICIT_INITIAL_INDEX_SIZE & (IMPLICIT_INITIAL_INDEX_SIZE - 1)), - "Traits::IMPLICIT_INITIAL_INDEX_SIZE must be a power of 2 (and greater than 1)"); - static_assert((INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) || - !(INITIAL_IMPLICIT_PRODUCER_HASH_SIZE & (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE - 1)), - "Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE must be a power of 2"); - static_assert( - INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0 || INITIAL_IMPLICIT_PRODUCER_HASH_SIZE >= 1, - "Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE must be at least 1 (or 0 to disable implicit enqueueing)"); - - public: - // Creates a queue with at least `capacity` element slots; note that the - // actual number of elements that can be inserted without additional memory - // allocation depends on the number of producers and the block size (e.g. if - // the block size is equal to `capacity`, only a single block will be allocated - // up-front, which means only a single producer will be able to enqueue elements - // without an extra allocation -- blocks aren't shared between producers). - // This method is not thread safe -- it is up to the user to ensure that the - // queue is fully constructed before it starts being used by other threads (this - // includes making the memory effects of construction visible, possibly with a - // memory barrier). - explicit ConcurrentQueue(size_t capacity = 6 * BLOCK_SIZE) - : producerListTail(nullptr), producerCount(0), initialBlockPoolIndex(0), nextExplicitConsumerId( - 0), globalExplicitConsumerOffset(0) { - implicitProducerHashResizeInProgress.clear(std::memory_order_relaxed); - populate_initial_implicit_producer_hash(); - populate_initial_block_list( - capacity / BLOCK_SIZE + ((capacity & (BLOCK_SIZE - 1)) == 0 ? 0 : 1)); - -#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG - // Track all the producers using a fully-resolved typed list for - // each kind; this makes it possible to debug them starting from - // the root queue object (otherwise wacky casts are needed that - // don't compile in the debugger's expression evaluator). - explicitProducers.store(nullptr, std::memory_order_relaxed); - implicitProducers.store(nullptr, std::memory_order_relaxed); -#endif - } - - // Computes the correct amount of pre-allocated blocks for you based - // on the minimum number of elements you want available at any given - // time, and the maximum concurrent number of each type of producer. - ConcurrentQueue(size_t minCapacity, size_t maxExplicitProducers, size_t maxImplicitProducers) - : producerListTail(nullptr), producerCount(0), initialBlockPoolIndex(0), nextExplicitConsumerId( - 0), globalExplicitConsumerOffset(0) { - implicitProducerHashResizeInProgress.clear(std::memory_order_relaxed); - populate_initial_implicit_producer_hash(); - size_t blocks = - (((minCapacity + BLOCK_SIZE - 1) / BLOCK_SIZE) - 1) * (maxExplicitProducers + 1) + - 2 * (maxExplicitProducers + maxImplicitProducers); - populate_initial_block_list(blocks); - -#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG - explicitProducers.store(nullptr, std::memory_order_relaxed); - implicitProducers.store(nullptr, std::memory_order_relaxed); -#endif - } - - // Note: The queue should not be accessed concurrently while it's - // being deleted. It's up to the user to synchronize this. - // This method is not thread safe. - ~ConcurrentQueue() { - // Destroy producers - auto ptr = producerListTail.load(std::memory_order_relaxed); - while (ptr != nullptr) { - auto next = ptr->next_prod(); - if (ptr->token != nullptr) { - ptr->token->producer = nullptr; - } - destroy(ptr); - ptr = next; - } - - // Destroy implicit producer hash tables - if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE != 0) { - auto hash = implicitProducerHash.load(std::memory_order_relaxed); - while (hash != nullptr) { - auto prev = hash->prev; - if (prev != - nullptr) { // The last hash is part of this object and was not allocated dynamically - for (size_t i = 0; i != hash->capacity; ++i) { - hash->entries[i].~ImplicitProducerKVP(); - } - hash->~ImplicitProducerHash(); - (Traits::free)(hash); - } - hash = prev; - } - } - - // Destroy global free list - auto block = freeList.head_unsafe(); - while (block != nullptr) { - auto next = block->freeListNext.load(std::memory_order_relaxed); - if (block->dynamicallyAllocated) { - destroy(block); - } - block = next; - } - - // Destroy initial free list - destroy_array(initialBlockPool, initialBlockPoolSize); - } - - // Disable copying and copy assignment - ConcurrentQueue(ConcurrentQueue const &) MOODYCAMEL_DELETE_FUNCTION; - - ConcurrentQueue &operator=(ConcurrentQueue const &) MOODYCAMEL_DELETE_FUNCTION; - - // Moving is supported, but note that it is *not* a thread-safe operation. - // Nobody can use the queue while it's being moved, and the memory effects - // of that move must be propagated to other threads before they can use it. - // Note: When a queue is moved, its tokens are still valid but can only be - // used with the destination queue (i.e. semantically they are moved along - // with the queue itself). - ConcurrentQueue(ConcurrentQueue &&other) MOODYCAMEL_NOEXCEPT - : producerListTail(other.producerListTail.load(std::memory_order_relaxed)), producerCount( - other.producerCount.load(std::memory_order_relaxed)), initialBlockPoolIndex( - other.initialBlockPoolIndex.load(std::memory_order_relaxed)), initialBlockPool( - other.initialBlockPool), initialBlockPoolSize(other.initialBlockPoolSize), freeList( - std::move(other.freeList)), nextExplicitConsumerId( - other.nextExplicitConsumerId.load(std::memory_order_relaxed)), globalExplicitConsumerOffset( - other.globalExplicitConsumerOffset.load(std::memory_order_relaxed)) { - // Move the other one into this, and leave the other one as an empty queue - implicitProducerHashResizeInProgress.clear(std::memory_order_relaxed); - populate_initial_implicit_producer_hash(); - swap_implicit_producer_hashes(other); - - other.producerListTail.store(nullptr, std::memory_order_relaxed); - other.producerCount.store(0, std::memory_order_relaxed); - other.nextExplicitConsumerId.store(0, std::memory_order_relaxed); - other.globalExplicitConsumerOffset.store(0, std::memory_order_relaxed); - -#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG - explicitProducers.store(other.explicitProducers.load(std::memory_order_relaxed), std::memory_order_relaxed); - other.explicitProducers.store(nullptr, std::memory_order_relaxed); - implicitProducers.store(other.implicitProducers.load(std::memory_order_relaxed), std::memory_order_relaxed); - other.implicitProducers.store(nullptr, std::memory_order_relaxed); -#endif - - other.initialBlockPoolIndex.store(0, std::memory_order_relaxed); - other.initialBlockPoolSize = 0; - other.initialBlockPool = nullptr; - - reown_producers(); - } - - inline ConcurrentQueue &operator=(ConcurrentQueue &&other) MOODYCAMEL_NOEXCEPT { - return swap_internal(other); - } - - // Swaps this queue's state with the other's. Not thread-safe. - // Swapping two queues does not invalidate their tokens, however - // the tokens that were created for one queue must be used with - // only the swapped queue (i.e. the tokens are tied to the - // queue's movable state, not the object itself). - inline void swap(ConcurrentQueue &other) MOODYCAMEL_NOEXCEPT { - swap_internal(other); - } - - private: - ConcurrentQueue &swap_internal(ConcurrentQueue &other) { - if (this == &other) { - return *this; - } - - details::swap_relaxed(producerListTail, other.producerListTail); - details::swap_relaxed(producerCount, other.producerCount); - details::swap_relaxed(initialBlockPoolIndex, other.initialBlockPoolIndex); - std::swap(initialBlockPool, other.initialBlockPool); - std::swap(initialBlockPoolSize, other.initialBlockPoolSize); - freeList.swap(other.freeList); - details::swap_relaxed(nextExplicitConsumerId, other.nextExplicitConsumerId); - details::swap_relaxed(globalExplicitConsumerOffset, other.globalExplicitConsumerOffset); - - swap_implicit_producer_hashes(other); - - reown_producers(); - other.reown_producers(); - -#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG - details::swap_relaxed(explicitProducers, other.explicitProducers); - details::swap_relaxed(implicitProducers, other.implicitProducers); -#endif - - return *this; - } - - public: - // Enqueues a single item (by copying it). - // Allocates memory if required. Only fails if memory allocation fails (or implicit - // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, - // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Thread-safe. - inline bool enqueue(T const &item) { - if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; - return inner_enqueue(item); - } - - // Enqueues a single item (by moving it, if possible). - // Allocates memory if required. Only fails if memory allocation fails (or implicit - // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, - // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Thread-safe. - inline bool enqueue(T &&item) { - if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; - return inner_enqueue(std::move(item)); - } - - // Enqueues a single item (by copying it) using an explicit producer token. - // Allocates memory if required. Only fails if memory allocation fails (or - // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Thread-safe. - inline bool enqueue(producer_token_t const &token, T const &item) { - return inner_enqueue(token, item); - } - - // Enqueues a single item (by moving it, if possible) using an explicit producer token. - // Allocates memory if required. Only fails if memory allocation fails (or - // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Thread-safe. - inline bool enqueue(producer_token_t const &token, T &&item) { - return inner_enqueue(token, std::move(item)); - } - - // Enqueues several items. - // Allocates memory if required. Only fails if memory allocation fails (or - // implicit production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE - // is 0, or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Note: Use std::make_move_iterator if the elements should be moved instead of copied. - // Thread-safe. - template - bool enqueue_bulk(It itemFirst, size_t count) { - if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; - return inner_enqueue_bulk(itemFirst, count); - } - - // Enqueues several items using an explicit producer token. - // Allocates memory if required. Only fails if memory allocation fails - // (or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Note: Use std::make_move_iterator if the elements should be moved - // instead of copied. - // Thread-safe. - template - bool enqueue_bulk(producer_token_t const &token, It itemFirst, size_t count) { - return inner_enqueue_bulk(token, itemFirst, count); - } - - // Enqueues a single item (by copying it). - // Does not allocate memory. Fails if not enough room to enqueue (or implicit - // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE - // is 0). - // Thread-safe. - inline bool try_enqueue(T const &item) { - if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; - return inner_enqueue(item); - } - - // Enqueues a single item (by moving it, if possible). - // Does not allocate memory (except for one-time implicit producer). - // Fails if not enough room to enqueue (or implicit production is - // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). - // Thread-safe. - inline bool try_enqueue(T &&item) { - if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; - return inner_enqueue(std::move(item)); - } - - // Enqueues a single item (by copying it) using an explicit producer token. - // Does not allocate memory. Fails if not enough room to enqueue. - // Thread-safe. - inline bool try_enqueue(producer_token_t const &token, T const &item) { - return inner_enqueue(token, item); - } - - // Enqueues a single item (by moving it, if possible) using an explicit producer token. - // Does not allocate memory. Fails if not enough room to enqueue. - // Thread-safe. - inline bool try_enqueue(producer_token_t const &token, T &&item) { - return inner_enqueue(token, std::move(item)); - } - - // Enqueues several items. - // Does not allocate memory (except for one-time implicit producer). - // Fails if not enough room to enqueue (or implicit production is - // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). - // Note: Use std::make_move_iterator if the elements should be moved - // instead of copied. - // Thread-safe. - template - bool try_enqueue_bulk(It itemFirst, size_t count) { - if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return false; - return inner_enqueue_bulk(itemFirst, count); - } - - // Enqueues several items using an explicit producer token. - // Does not allocate memory. Fails if not enough room to enqueue. - // Note: Use std::make_move_iterator if the elements should be moved - // instead of copied. - // Thread-safe. - template - bool try_enqueue_bulk(producer_token_t const &token, It itemFirst, size_t count) { - return inner_enqueue_bulk(token, itemFirst, count); - } - - - // Attempts to dequeue from the queue. - // Returns false if all producer streams appeared empty at the time they - // were checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. - template - bool try_dequeue(U &item) { - // Instead of simply trying each producer in turn (which could cause needless contention on the first - // producer), we score them heuristically. - size_t nonEmptyCount = 0; - ProducerBase *best = nullptr; - size_t bestSize = 0; - for (auto ptr = producerListTail.load(std::memory_order_acquire); - nonEmptyCount < 3 && ptr != nullptr; ptr = ptr->next_prod()) { - auto size = ptr->size_approx(); - if (size > 0) { - if (size > bestSize) { - bestSize = size; - best = ptr; - } - ++nonEmptyCount; - } - } - - // If there was at least one non-empty queue but it appears empty at the time - // we try to dequeue from it, we need to make sure every queue's been tried - if (nonEmptyCount > 0) { - if (details::likely(best->dequeue(item))) { - return true; - } - for (auto ptr = producerListTail.load(std::memory_order_acquire); - ptr != nullptr; ptr = ptr->next_prod()) { - if (ptr != best && ptr->dequeue(item)) { - return true; - } - } - } - return false; - } - - // Attempts to dequeue from the queue. - // Returns false if all producer streams appeared empty at the time they - // were checked (so, the queue is likely but not guaranteed to be empty). - // This differs from the try_dequeue(item) method in that this one does - // not attempt to reduce contention by interleaving the order that producer - // streams are dequeued from. So, using this method can reduce overall throughput - // under contention, but will give more predictable results in single-threaded - // consumer scenarios. This is mostly only useful for internal unit tests. - // Never allocates. Thread-safe. - template - bool try_dequeue_non_interleaved(U &item) { - for (auto ptr = producerListTail.load(std::memory_order_acquire); - ptr != nullptr; ptr = ptr->next_prod()) { - if (ptr->dequeue(item)) { - return true; - } - } - return false; - } - - // Attempts to dequeue from the queue using an explicit consumer token. - // Returns false if all producer streams appeared empty at the time they - // were checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. - template - bool try_dequeue(consumer_token_t &token, U &item) { - // The idea is roughly as follows: - // Every 256 items from one producer, make everyone rotate (increase the global offset) -> this means the highest efficiency consumer dictates the rotation speed of everyone else, more or less - // If you see that the global offset has changed, you must reset your consumption counter and move to your designated place - // If there's no items where you're supposed to be, keep moving until you find a producer with some items - // If the global offset has not changed but you've run out of items to consume, move over from your current position until you find an producer with something in it - - if (token.desiredProducer == nullptr || token.lastKnownGlobalOffset != - globalExplicitConsumerOffset.load( - std::memory_order_relaxed)) { - if (!update_current_producer_after_rotation(token)) { - return false; - } - } - - // If there was at least one non-empty queue but it appears empty at the time - // we try to dequeue from it, we need to make sure every queue's been tried - if (static_cast(token.currentProducer)->dequeue(item)) { - if (++token.itemsConsumedFromCurrent == EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE) { - globalExplicitConsumerOffset.fetch_add(1, std::memory_order_relaxed); - } - return true; - } - - auto tail = producerListTail.load(std::memory_order_acquire); - auto ptr = static_cast(token.currentProducer)->next_prod(); - if (ptr == nullptr) { - ptr = tail; - } - while (ptr != static_cast(token.currentProducer)) { - if (ptr->dequeue(item)) { - token.currentProducer = ptr; - token.itemsConsumedFromCurrent = 1; - return true; - } - ptr = ptr->next_prod(); - if (ptr == nullptr) { - ptr = tail; - } - } - return false; - } - - // Attempts to dequeue several elements from the queue. - // Returns the number of items actually dequeued. - // Returns 0 if all producer streams appeared empty at the time they - // were checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. - template - size_t try_dequeue_bulk(It itemFirst, size_t max) { - size_t count = 0; - for (auto ptr = producerListTail.load(std::memory_order_acquire); - ptr != nullptr; ptr = ptr->next_prod()) { - count += ptr->dequeue_bulk(itemFirst, max - count); - if (count == max) { - break; - } - } - return count; - } - - // Attempts to dequeue several elements from the queue using an explicit consumer token. - // Returns the number of items actually dequeued. - // Returns 0 if all producer streams appeared empty at the time they - // were checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. - template - size_t try_dequeue_bulk(consumer_token_t &token, It itemFirst, size_t max) { - if (token.desiredProducer == nullptr || token.lastKnownGlobalOffset != - globalExplicitConsumerOffset.load( - std::memory_order_relaxed)) { - if (!update_current_producer_after_rotation(token)) { - return 0; - } - } - - size_t count = static_cast(token.currentProducer)->dequeue_bulk(itemFirst, max); - if (count == max) { - if ((token.itemsConsumedFromCurrent += static_cast(max)) >= - EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE) { - globalExplicitConsumerOffset.fetch_add(1, std::memory_order_relaxed); - } - return max; - } - token.itemsConsumedFromCurrent += static_cast(count); - max -= count; - - auto tail = producerListTail.load(std::memory_order_acquire); - auto ptr = static_cast(token.currentProducer)->next_prod(); - if (ptr == nullptr) { - ptr = tail; - } - while (ptr != static_cast(token.currentProducer)) { - auto dequeued = ptr->dequeue_bulk(itemFirst, max); - count += dequeued; - if (dequeued != 0) { - token.currentProducer = ptr; - token.itemsConsumedFromCurrent = static_cast(dequeued); - } - if (dequeued == max) { - break; - } - max -= dequeued; - ptr = ptr->next_prod(); - if (ptr == nullptr) { - ptr = tail; - } - } - return count; - } - - - // Attempts to dequeue from a specific producer's inner queue. - // If you happen to know which producer you want to dequeue from, this - // is significantly faster than using the general-case try_dequeue methods. - // Returns false if the producer's queue appeared empty at the time it - // was checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. - template - inline bool try_dequeue_from_producer(producer_token_t const &producer, U &item) { - return static_cast(producer.producer)->dequeue(item); - } - - // Attempts to dequeue several elements from a specific producer's inner queue. - // Returns the number of items actually dequeued. - // If you happen to know which producer you want to dequeue from, this - // is significantly faster than using the general-case try_dequeue methods. - // Returns 0 if the producer's queue appeared empty at the time it - // was checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. - template - inline size_t - try_dequeue_bulk_from_producer(producer_token_t const &producer, It itemFirst, size_t max) { - return static_cast(producer.producer)->dequeue_bulk(itemFirst, max); - } - - - // Returns an estimate of the total number of elements currently in the queue. This - // estimate is only accurate if the queue has completely stabilized before it is called - // (i.e. all enqueue and dequeue operations have completed and their memory effects are - // visible on the calling thread, and no further operations start while this method is - // being called). - // Thread-safe. - size_t size_approx() const { - size_t size = 0; - for (auto ptr = producerListTail.load(std::memory_order_acquire); - ptr != nullptr; ptr = ptr->next_prod()) { - size += ptr->size_approx(); - } - return size; - } - - - // Returns true if the underlying atomic variables used by - // the queue are lock-free (they should be on most platforms). - // Thread-safe. - static bool is_lock_free() { - return - details::static_is_lock_free::value == 2 && - details::static_is_lock_free::value == 2 && - details::static_is_lock_free::value == 2 && - details::static_is_lock_free::value == 2 && - details::static_is_lock_free::value == 2 && - details::static_is_lock_free::thread_id_numeric_size_t>::value == - 2; - } - - - private: - friend struct ProducerToken; - friend struct ConsumerToken; - friend struct ExplicitProducer; - - friend class ConcurrentQueueTests; - - enum AllocationMode { - CanAlloc, CannotAlloc - }; - - - /////////////////////////////// - // Queue methods - /////////////////////////////// - - template - inline bool inner_enqueue(producer_token_t const &token, U &&element) { - return static_cast(token.producer)->ConcurrentQueue::ExplicitProducer::template enqueue( - std::forward(element)); - } - - template - inline bool inner_enqueue(U &&element) { - auto producer = get_or_add_implicit_producer(); - return producer == nullptr ? false - : producer->ConcurrentQueue::ImplicitProducer::template enqueue( - std::forward(element)); - } - - template - inline bool inner_enqueue_bulk(producer_token_t const &token, It itemFirst, size_t count) { - return static_cast(token.producer)->ConcurrentQueue::ExplicitProducer::template enqueue_bulk( - itemFirst, count); - } - - template - inline bool inner_enqueue_bulk(It itemFirst, size_t count) { - auto producer = get_or_add_implicit_producer(); - return producer == nullptr ? false - : producer->ConcurrentQueue::ImplicitProducer::template enqueue_bulk( - itemFirst, count); - } - - inline bool update_current_producer_after_rotation(consumer_token_t &token) { - // Ah, there's been a rotation, figure out where we should be! - auto tail = producerListTail.load(std::memory_order_acquire); - if (token.desiredProducer == nullptr && tail == nullptr) { - return false; - } - auto prodCount = producerCount.load(std::memory_order_relaxed); - auto globalOffset = globalExplicitConsumerOffset.load(std::memory_order_relaxed); - if (details::unlikely(token.desiredProducer == nullptr)) { - // Aha, first time we're dequeueing anything. - // Figure out our local position - // Note: offset is from start, not end, but we're traversing from end -- subtract from count first - std::uint32_t offset = prodCount - 1 - (token.initialOffset % prodCount); - token.desiredProducer = tail; - for (std::uint32_t i = 0; i != offset; ++i) { - token.desiredProducer = static_cast(token.desiredProducer)->next_prod(); - if (token.desiredProducer == nullptr) { - token.desiredProducer = tail; - } - } - } - - std::uint32_t delta = globalOffset - token.lastKnownGlobalOffset; - if (delta >= prodCount) { - delta = delta % prodCount; - } - for (std::uint32_t i = 0; i != delta; ++i) { - token.desiredProducer = static_cast(token.desiredProducer)->next_prod(); - if (token.desiredProducer == nullptr) { - token.desiredProducer = tail; - } - } - - token.lastKnownGlobalOffset = globalOffset; - token.currentProducer = token.desiredProducer; - token.itemsConsumedFromCurrent = 0; - return true; - } - - - /////////////////////////// - // Free list - /////////////////////////// - - template - struct FreeListNode { - FreeListNode() - : freeListRefs(0), freeListNext(nullptr) {} - - std::atomic freeListRefs; - std::atomic freeListNext; - }; - - // A simple CAS-based lock-free free list. Not the fastest thing in the world under heavy contention, but - // simple and correct (assuming nodes are never freed until after the free list is destroyed), and fairly - // speedy under low contention. - template // N must inherit FreeListNode or have the same fields (and initialization of them) - struct FreeList { - FreeList() - : freeListHead(nullptr) {} - - FreeList(FreeList &&other) - : freeListHead(other.freeListHead.load(std::memory_order_relaxed)) { - other.freeListHead.store(nullptr, std::memory_order_relaxed); - } - - void swap(FreeList &other) { details::swap_relaxed(freeListHead, other.freeListHead); } - - FreeList(FreeList const &) MOODYCAMEL_DELETE_FUNCTION; - - FreeList &operator=(FreeList const &) MOODYCAMEL_DELETE_FUNCTION; - - inline void add(N *node) { -#if MCDBGQ_NOLOCKFREE_FREELIST - debug::DebugLock lock(mutex); -#endif - // We know that the should-be-on-freelist bit is 0 at this point, so it's safe to - // set it using a fetch_add - if (node->freeListRefs.fetch_add(SHOULD_BE_ON_FREELIST, std::memory_order_acq_rel) == 0) { - // Oh look! We were the last ones referencing this node, and we know - // we want to add it to the free list, so let's do it! - add_knowing_refcount_is_zero(node); - } - } - - inline N *try_get() { -#if MCDBGQ_NOLOCKFREE_FREELIST - debug::DebugLock lock(mutex); -#endif - auto head = freeListHead.load(std::memory_order_acquire); - while (head != nullptr) { - auto prevHead = head; - auto refs = head->freeListRefs.load(std::memory_order_relaxed); - if ((refs & REFS_MASK) == 0 || - !head->freeListRefs.compare_exchange_strong(refs, refs + 1, std::memory_order_acquire, - std::memory_order_relaxed)) { - head = freeListHead.load(std::memory_order_acquire); - continue; - } - - // Good, reference count has been incremented (it wasn't at zero), which means we can read the - // next and not worry about it changing between now and the time we do the CAS - auto next = head->freeListNext.load(std::memory_order_relaxed); - if (freeListHead.compare_exchange_strong(head, next, std::memory_order_acquire, - std::memory_order_relaxed)) { - // Yay, got the node. This means it was on the list, which means shouldBeOnFreeList must be false no - // matter the refcount (because nobody else knows it's been taken off yet, it can't have been put back on). - assert((head->freeListRefs.load(std::memory_order_relaxed) & SHOULD_BE_ON_FREELIST) == 0); - - // Decrease refcount twice, once for our ref, and once for the list's ref - head->freeListRefs.fetch_sub(2, std::memory_order_release); - return head; - } - - // OK, the head must have changed on us, but we still need to decrease the refcount we increased. - // Note that we don't need to release any memory effects, but we do need to ensure that the reference - // count decrement happens-after the CAS on the head. - refs = prevHead->freeListRefs.fetch_sub(1, std::memory_order_acq_rel); - if (refs == SHOULD_BE_ON_FREELIST + 1) { - add_knowing_refcount_is_zero(prevHead); - } - } - - return nullptr; - } - - // Useful for traversing the list when there's no contention (e.g. to destroy remaining nodes) - N *head_unsafe() const { return freeListHead.load(std::memory_order_relaxed); } - - private: - inline void add_knowing_refcount_is_zero(N *node) { - // Since the refcount is zero, and nobody can increase it once it's zero (except us, and we run - // only one copy of this method per node at a time, i.e. the single thread case), then we know - // we can safely change the next pointer of the node; however, once the refcount is back above - // zero, then other threads could increase it (happens under heavy contention, when the refcount - // goes to zero in between a load and a refcount increment of a node in try_get, then back up to - // something non-zero, then the refcount increment is done by the other thread) -- so, if the CAS - // to add the node to the actual list fails, decrease the refcount and leave the add operation to - // the next thread who puts the refcount back at zero (which could be us, hence the loop). - auto head = freeListHead.load(std::memory_order_relaxed); - while (true) { - node->freeListNext.store(head, std::memory_order_relaxed); - node->freeListRefs.store(1, std::memory_order_release); - if (!freeListHead.compare_exchange_strong(head, node, std::memory_order_release, - std::memory_order_relaxed)) { - // Hmm, the add failed, but we can only try again when the refcount goes back to zero - if (node->freeListRefs.fetch_add(SHOULD_BE_ON_FREELIST - 1, std::memory_order_release) == - 1) { - continue; - } - } - return; - } - } - - private: - // Implemented like a stack, but where node order doesn't matter (nodes are inserted out of order under contention) - std::atomic freeListHead; - - static const std::uint32_t REFS_MASK = 0x7FFFFFFF; - static const std::uint32_t SHOULD_BE_ON_FREELIST = 0x80000000; - -#if MCDBGQ_NOLOCKFREE_FREELIST - debug::DebugMutex mutex; -#endif - }; - - - /////////////////////////// - // Block - /////////////////////////// - - enum InnerQueueContext { - implicit_context = 0, explicit_context = 1 - }; - - struct Block { - Block() - : next(nullptr), elementsCompletelyDequeued(0), freeListRefs(0), freeListNext(nullptr) - , shouldBeOnFreeList(false), dynamicallyAllocated(true) { -#if MCDBGQ_TRACKMEM - owner = nullptr; -#endif - } - - template - inline bool is_empty() const { - if (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { - // Check flags - for (size_t i = 0; i < BLOCK_SIZE; ++i) { - if (!emptyFlags[i].load(std::memory_order_relaxed)) { - return false; - } - } - - // Aha, empty; make sure we have all other memory effects that happened before the empty flags were set - std::atomic_thread_fence(std::memory_order_acquire); - return true; - } else { - // Check counter - if (elementsCompletelyDequeued.load(std::memory_order_relaxed) == BLOCK_SIZE) { - std::atomic_thread_fence(std::memory_order_acquire); - return true; - } - assert(elementsCompletelyDequeued.load(std::memory_order_relaxed) <= BLOCK_SIZE); - return false; - } - } - - // Returns true if the block is now empty (does not apply in explicit context) - template - inline bool set_empty(index_t i) { - if (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { - // Set flag - assert(!emptyFlags[BLOCK_SIZE - 1 - - static_cast(i & static_cast(BLOCK_SIZE - 1))].load( - std::memory_order_relaxed)); - emptyFlags[BLOCK_SIZE - 1 - - static_cast(i & static_cast(BLOCK_SIZE - 1))].store(true, - std::memory_order_release); - return false; - } else { - // Increment counter - auto prevVal = elementsCompletelyDequeued.fetch_add(1, std::memory_order_release); - assert(prevVal < BLOCK_SIZE); - return prevVal == BLOCK_SIZE - 1; - } - } - - // Sets multiple contiguous item statuses to 'empty' (assumes no wrapping and count > 0). - // Returns true if the block is now empty (does not apply in explicit context). - template - inline bool set_many_empty(index_t i, size_t count) { - if (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { - // Set flags - std::atomic_thread_fence(std::memory_order_release); - i = BLOCK_SIZE - 1 - static_cast(i & static_cast(BLOCK_SIZE - 1)) - count + - 1; - for (size_t j = 0; j != count; ++j) { - assert(!emptyFlags[i + j].load(std::memory_order_relaxed)); - emptyFlags[i + j].store(true, std::memory_order_relaxed); - } - return false; - } else { - // Increment counter - auto prevVal = elementsCompletelyDequeued.fetch_add(count, std::memory_order_release); - assert(prevVal + count <= BLOCK_SIZE); - return prevVal + count == BLOCK_SIZE; - } - } - - template - inline void set_all_empty() { - if (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { - // Set all flags - for (size_t i = 0; i != BLOCK_SIZE; ++i) { - emptyFlags[i].store(true, std::memory_order_relaxed); - } - } else { - // Reset counter - elementsCompletelyDequeued.store(BLOCK_SIZE, std::memory_order_relaxed); - } - } - - template - inline void reset_empty() { - if (context == explicit_context && BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD) { - // Reset flags - for (size_t i = 0; i != BLOCK_SIZE; ++i) { - emptyFlags[i].store(false, std::memory_order_relaxed); - } - } else { - // Reset counter - elementsCompletelyDequeued.store(0, std::memory_order_relaxed); - } - } - - inline T *operator[](index_t idx) MOODYCAMEL_NOEXCEPT { - return static_cast(static_cast(elements)) + - static_cast(idx & static_cast(BLOCK_SIZE - 1)); - } - - inline T const *operator[](index_t idx) const MOODYCAMEL_NOEXCEPT { - return static_cast(static_cast(elements)) + - static_cast(idx & static_cast(BLOCK_SIZE - 1)); - } - - private: - // IMPORTANT: This must be the first member in Block, so that if T depends on the alignment of - // addresses returned by malloc, that alignment will be preserved. Apparently clang actually - // generates code that uses this assumption for AVX instructions in some cases. Ideally, we - // should also align Block to the alignment of T in case it's higher than malloc's 16-byte - // alignment, but this is hard to do in a cross-platform way. Assert for this case: - static_assert(std::alignment_of::value <= std::alignment_of::value, - "The queue does not support super-aligned types at this time"); - // Additionally, we need the alignment of Block itself to be a multiple of max_align_t since - // otherwise the appropriate padding will not be added at the end of Block in order to make - // arrays of Blocks all be properly aligned (not just the first one). We use a union to force - // this. - union { - char elements[sizeof(T) * BLOCK_SIZE]; - details::max_align_t dummy; - }; - public: - Block *next; - std::atomic elementsCompletelyDequeued; - std::atomic emptyFlags[ - BLOCK_SIZE <= EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD ? BLOCK_SIZE : 1]; - public: - std::atomic freeListRefs; - std::atomic freeListNext; - std::atomic shouldBeOnFreeList; - bool dynamicallyAllocated; // Perhaps a better name for this would be 'isNotPartOfInitialBlockPool' - -#if MCDBGQ_TRACKMEM - void* owner; -#endif - }; - - static_assert(std::alignment_of::value >= std::alignment_of::value, - "Internal error: Blocks must be at least as aligned as the type they are wrapping"); - - -#if MCDBGQ_TRACKMEM - public: - struct MemStats; - private: -#endif - - /////////////////////////// - // Producer base - /////////////////////////// - - struct ProducerBase : public details::ConcurrentQueueProducerTypelessBase { - ProducerBase(ConcurrentQueue *parent_, bool isExplicit_) - : - tailIndex(0), headIndex(0), dequeueOptimisticCount(0), dequeueOvercommit(0), tailBlock( - nullptr), isExplicit(isExplicit_), parent(parent_) { - } - - virtual ~ProducerBase() {}; - - template - inline bool dequeue(U &element) { - if (isExplicit) { - return static_cast(this)->dequeue(element); - } else { - return static_cast(this)->dequeue(element); - } - } - - template - inline size_t dequeue_bulk(It &itemFirst, size_t max) { - if (isExplicit) { - return static_cast(this)->dequeue_bulk(itemFirst, max); - } else { - return static_cast(this)->dequeue_bulk(itemFirst, max); - } - } - - inline ProducerBase *next_prod() const { return static_cast(next); } - - inline size_t size_approx() const { - auto tail = tailIndex.load(std::memory_order_relaxed); - auto head = headIndex.load(std::memory_order_relaxed); - return details::circular_less_than(head, tail) ? static_cast(tail - head) : 0; - } - - inline index_t getTail() const { return tailIndex.load(std::memory_order_relaxed); } - - protected: - std::atomic tailIndex; // Where to enqueue to next - std::atomic headIndex; // Where to dequeue from next - - std::atomic dequeueOptimisticCount; - std::atomic dequeueOvercommit; - - Block *tailBlock; - - public: - bool isExplicit; - ConcurrentQueue *parent; - - protected: -#if MCDBGQ_TRACKMEM - friend struct MemStats; -#endif - }; - - - /////////////////////////// - // Explicit queue - /////////////////////////// - - struct ExplicitProducer : public ProducerBase { - explicit ExplicitProducer(ConcurrentQueue *parent) - : - ProducerBase(parent, true), blockIndex(nullptr), pr_blockIndexSlotsUsed(0), pr_blockIndexSize( - EXPLICIT_INITIAL_INDEX_SIZE >> 1), pr_blockIndexFront(0), pr_blockIndexEntries(nullptr) - , pr_blockIndexRaw(nullptr) { - size_t poolBasedIndexSize = details::ceil_to_pow_2(parent->initialBlockPoolSize) >> 1; - if (poolBasedIndexSize > pr_blockIndexSize) { - pr_blockIndexSize = poolBasedIndexSize; - } - - new_block_index( - 0); // This creates an index with double the number of current entries, i.e. EXPLICIT_INITIAL_INDEX_SIZE - } - - ~ExplicitProducer() { - // Destruct any elements not yet dequeued. - // Since we're in the destructor, we can assume all elements - // are either completely dequeued or completely not (no halfways). - if (this->tailBlock != nullptr) { // Note this means there must be a block index too - // First find the block that's partially dequeued, if any - Block *halfDequeuedBlock = nullptr; - if ((this->headIndex.load(std::memory_order_relaxed) & - static_cast(BLOCK_SIZE - 1)) != 0) { - // The head's not on a block boundary, meaning a block somewhere is partially dequeued - // (or the head block is the tail block and was fully dequeued, but the head/tail are still not on a boundary) - size_t i = (pr_blockIndexFront - pr_blockIndexSlotsUsed) & (pr_blockIndexSize - 1); - while (details::circular_less_than(pr_blockIndexEntries[i].base + BLOCK_SIZE, - this->headIndex.load( - std::memory_order_relaxed))) { - i = (i + 1) & (pr_blockIndexSize - 1); - } - assert(details::circular_less_than(pr_blockIndexEntries[i].base, - this->headIndex.load( - std::memory_order_relaxed))); - halfDequeuedBlock = pr_blockIndexEntries[i].block; - } - - // Start at the head block (note the first line in the loop gives us the head from the tail on the first iteration) - auto block = this->tailBlock; - do { - block = block->next; - if (block->ConcurrentQueue::Block::template is_empty()) { - continue; - } - - size_t i = 0; // Offset into block - if (block == halfDequeuedBlock) { - i = static_cast(this->headIndex.load(std::memory_order_relaxed) & - static_cast(BLOCK_SIZE - 1)); - } - - // Walk through all the items in the block; if this is the tail block, we need to stop when we reach the tail index - auto lastValidIndex = (this->tailIndex.load(std::memory_order_relaxed) & - static_cast(BLOCK_SIZE - 1)) == 0 ? BLOCK_SIZE - : static_cast( - this->tailIndex.load(std::memory_order_relaxed) & - static_cast(BLOCK_SIZE - 1)); - while (i != BLOCK_SIZE && (block != this->tailBlock || i != lastValidIndex)) { - (*block)[i++]->~T(); - } - } while (block != this->tailBlock); - } - - // Destroy all blocks that we own - if (this->tailBlock != nullptr) { - auto block = this->tailBlock; - do { - auto nextBlock = block->next; - if (block->dynamicallyAllocated) { - destroy(block); - } else { - this->parent->add_block_to_free_list(block); - } - block = nextBlock; - } while (block != this->tailBlock); - } - - // Destroy the block indices - auto header = static_cast(pr_blockIndexRaw); - while (header != nullptr) { - auto prev = static_cast(header->prev); - header->~BlockIndexHeader(); - (Traits::free)(header); - header = prev; - } - } - - template - inline bool enqueue(U &&element) { - index_t currentTailIndex = this->tailIndex.load(std::memory_order_relaxed); - index_t newTailIndex = 1 + currentTailIndex; - if ((currentTailIndex & static_cast(BLOCK_SIZE - 1)) == 0) { - // We reached the end of a block, start a new one - auto startBlock = this->tailBlock; - auto originalBlockIndexSlotsUsed = pr_blockIndexSlotsUsed; - if (this->tailBlock != nullptr && - this->tailBlock->next->ConcurrentQueue::Block::template is_empty()) { - // We can re-use the block ahead of us, it's empty! - this->tailBlock = this->tailBlock->next; - this->tailBlock->ConcurrentQueue::Block::template reset_empty(); - - // We'll put the block on the block index (guaranteed to be room since we're conceptually removing the - // last block from it first -- except instead of removing then adding, we can just overwrite). - // Note that there must be a valid block index here, since even if allocation failed in the ctor, - // it would have been re-attempted when adding the first block to the queue; since there is such - // a block, a block index must have been successfully allocated. - } else { - // Whatever head value we see here is >= the last value we saw here (relatively), - // and <= its current value. Since we have the most recent tail, the head must be - // <= to it. - auto head = this->headIndex.load(std::memory_order_relaxed); - assert(!details::circular_less_than(currentTailIndex, head)); - if (!details::circular_less_than(head, currentTailIndex + BLOCK_SIZE) - || (MAX_SUBQUEUE_SIZE != details::const_numeric_max::value && - (MAX_SUBQUEUE_SIZE == 0 || - MAX_SUBQUEUE_SIZE - BLOCK_SIZE < currentTailIndex - head))) { - // We can't enqueue in another block because there's not enough leeway -- the - // tail could surpass the head by the time the block fills up! (Or we'll exceed - // the size limit, if the second part of the condition was true.) - return false; - } - // We're going to need a new block; check that the block index has room - if (pr_blockIndexRaw == nullptr || pr_blockIndexSlotsUsed == pr_blockIndexSize) { - // Hmm, the circular block index is already full -- we'll need - // to allocate a new index. Note pr_blockIndexRaw can only be nullptr if - // the initial allocation failed in the constructor. - - if (allocMode == CannotAlloc || !new_block_index(pr_blockIndexSlotsUsed)) { - return false; - } - } - - // Insert a new block in the circular linked list - auto newBlock = this->parent->ConcurrentQueue::template requisition_block(); - if (newBlock == nullptr) { - return false; - } -#if MCDBGQ_TRACKMEM - newBlock->owner = this; -#endif - newBlock->ConcurrentQueue::Block::template reset_empty(); - if (this->tailBlock == nullptr) { - newBlock->next = newBlock; - } else { - newBlock->next = this->tailBlock->next; - this->tailBlock->next = newBlock; - } - this->tailBlock = newBlock; - ++pr_blockIndexSlotsUsed; - } - - if (!MOODYCAMEL_NOEXCEPT_CTOR(T, U, new(nullptr) T(std::forward(element)))) { - // The constructor may throw. We want the element not to appear in the queue in - // that case (without corrupting the queue): - MOODYCAMEL_TRY { - new((*this->tailBlock)[currentTailIndex]) T(std::forward(element)); - } - MOODYCAMEL_CATCH (...) { - // Revert change to the current block, but leave the new block available - // for next time - pr_blockIndexSlotsUsed = originalBlockIndexSlotsUsed; - this->tailBlock = startBlock == nullptr ? this->tailBlock : startBlock; - MOODYCAMEL_RETHROW; - } - } else { - (void) startBlock; - (void) originalBlockIndexSlotsUsed; - } - - // Add block to block index - auto &entry = blockIndex.load(std::memory_order_relaxed)->entries[pr_blockIndexFront]; - entry.base = currentTailIndex; - entry.block = this->tailBlock; - blockIndex.load(std::memory_order_relaxed)->front.store(pr_blockIndexFront, - std::memory_order_release); - pr_blockIndexFront = (pr_blockIndexFront + 1) & (pr_blockIndexSize - 1); - - if (!MOODYCAMEL_NOEXCEPT_CTOR(T, U, new(nullptr) T(std::forward(element)))) { - this->tailIndex.store(newTailIndex, std::memory_order_release); - return true; - } - } - - // Enqueue - new((*this->tailBlock)[currentTailIndex]) T(std::forward(element)); - - this->tailIndex.store(newTailIndex, std::memory_order_release); - return true; - } - - template - bool dequeue(U &element) { - auto tail = this->tailIndex.load(std::memory_order_relaxed); - auto overcommit = this->dequeueOvercommit.load(std::memory_order_relaxed); - if (details::circular_less_than( - this->dequeueOptimisticCount.load(std::memory_order_relaxed) - overcommit, tail)) { - // Might be something to dequeue, let's give it a try - - // Note that this if is purely for performance purposes in the common case when the queue is - // empty and the values are eventually consistent -- we may enter here spuriously. - - // Note that whatever the values of overcommit and tail are, they are not going to change (unless we - // change them) and must be the same value at this point (inside the if) as when the if condition was - // evaluated. - - // We insert an acquire fence here to synchronize-with the release upon incrementing dequeueOvercommit below. - // This ensures that whatever the value we got loaded into overcommit, the load of dequeueOptisticCount in - // the fetch_add below will result in a value at least as recent as that (and therefore at least as large). - // Note that I believe a compiler (signal) fence here would be sufficient due to the nature of fetch_add (all - // read-modify-write operations are guaranteed to work on the latest value in the modification order), but - // unfortunately that can't be shown to be correct using only the C++11 standard. - // See http://stackoverflow.com/questions/18223161/what-are-the-c11-memory-ordering-guarantees-in-this-corner-case - std::atomic_thread_fence(std::memory_order_acquire); - - // Increment optimistic counter, then check if it went over the boundary - auto myDequeueCount = this->dequeueOptimisticCount.fetch_add(1, std::memory_order_relaxed); - - // Note that since dequeueOvercommit must be <= dequeueOptimisticCount (because dequeueOvercommit is only ever - // incremented after dequeueOptimisticCount -- this is enforced in the `else` block below), and since we now - // have a version of dequeueOptimisticCount that is at least as recent as overcommit (due to the release upon - // incrementing dequeueOvercommit and the acquire above that synchronizes with it), overcommit <= myDequeueCount. - assert(overcommit <= myDequeueCount); - - // Note that we reload tail here in case it changed; it will be the same value as before or greater, since - // this load is sequenced after (happens after) the earlier load above. This is supported by read-read - // coherency (as defined in the standard), explained here: http://en.cppreference.com/w/cpp/atomic/memory_order - tail = this->tailIndex.load(std::memory_order_acquire); - if (details::likely( - details::circular_less_than(myDequeueCount - overcommit, tail))) { - // Guaranteed to be at least one element to dequeue! - - // Get the index. Note that since there's guaranteed to be at least one element, this - // will never exceed tail. We need to do an acquire-release fence here since it's possible - // that whatever condition got us to this point was for an earlier enqueued element (that - // we already see the memory effects for), but that by the time we increment somebody else - // has incremented it, and we need to see the memory effects for *that* element, which is - // in such a case is necessarily visible on the thread that incremented it in the first - // place with the more current condition (they must have acquired a tail that is at least - // as recent). - auto index = this->headIndex.fetch_add(1, std::memory_order_acq_rel); - - - // Determine which block the element is in - - auto localBlockIndex = blockIndex.load(std::memory_order_acquire); - auto localBlockIndexHead = localBlockIndex->front.load(std::memory_order_acquire); - - // We need to be careful here about subtracting and dividing because of index wrap-around. - // When an index wraps, we need to preserve the sign of the offset when dividing it by the - // block size (in order to get a correct signed block count offset in all cases): - auto headBase = localBlockIndex->entries[localBlockIndexHead].base; - auto blockBaseIndex = index & ~static_cast(BLOCK_SIZE - 1); - auto offset = static_cast( - static_cast::type>(blockBaseIndex - headBase) / - BLOCK_SIZE); - auto block = localBlockIndex->entries[(localBlockIndexHead + offset) & - (localBlockIndex->size - 1)].block; - - // Dequeue - auto &el = *((*block)[index]); - if (!MOODYCAMEL_NOEXCEPT_ASSIGN(T, T &&, element = std::move(el))) { - // Make sure the element is still fully dequeued and destroyed even if the assignment - // throws - struct Guard { - Block *block; - index_t index; - - ~Guard() { - (*block)[index]->~T(); - block->ConcurrentQueue::Block::template set_empty(index); - } - } guard = {block, index}; - - element = std::move(el); - } else { - element = std::move(el); - el.~T(); - block->ConcurrentQueue::Block::template set_empty(index); - } - - return true; - } else { - // Wasn't anything to dequeue after all; make the effective dequeue count eventually consistent - this->dequeueOvercommit.fetch_add(1, - std::memory_order_release); // Release so that the fetch_add on dequeueOptimisticCount is guaranteed to happen before this write - } - } - - return false; - } - - template - bool enqueue_bulk(It itemFirst, size_t count) { - // First, we need to make sure we have enough room to enqueue all of the elements; - // this means pre-allocating blocks and putting them in the block index (but only if - // all the allocations succeeded). - index_t startTailIndex = this->tailIndex.load(std::memory_order_relaxed); - auto startBlock = this->tailBlock; - auto originalBlockIndexFront = pr_blockIndexFront; - auto originalBlockIndexSlotsUsed = pr_blockIndexSlotsUsed; - - Block *firstAllocatedBlock = nullptr; - - // Figure out how many blocks we'll need to allocate, and do so - size_t blockBaseDiff = - ((startTailIndex + count - 1) & ~static_cast(BLOCK_SIZE - 1)) - - ((startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1)); - index_t currentTailIndex = (startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1); - if (blockBaseDiff > 0) { - // Allocate as many blocks as possible from ahead - while (blockBaseDiff > 0 && this->tailBlock != nullptr && - this->tailBlock->next != firstAllocatedBlock && - this->tailBlock->next->ConcurrentQueue::Block::template is_empty()) { - blockBaseDiff -= static_cast(BLOCK_SIZE); - currentTailIndex += static_cast(BLOCK_SIZE); - - this->tailBlock = this->tailBlock->next; - firstAllocatedBlock = - firstAllocatedBlock == nullptr ? this->tailBlock : firstAllocatedBlock; - - auto &entry = blockIndex.load(std::memory_order_relaxed)->entries[pr_blockIndexFront]; - entry.base = currentTailIndex; - entry.block = this->tailBlock; - pr_blockIndexFront = (pr_blockIndexFront + 1) & (pr_blockIndexSize - 1); - } - - // Now allocate as many blocks as necessary from the block pool - while (blockBaseDiff > 0) { - blockBaseDiff -= static_cast(BLOCK_SIZE); - currentTailIndex += static_cast(BLOCK_SIZE); - - auto head = this->headIndex.load(std::memory_order_relaxed); - assert(!details::circular_less_than(currentTailIndex, head)); - bool full = !details::circular_less_than(head, currentTailIndex + BLOCK_SIZE) || - (MAX_SUBQUEUE_SIZE != details::const_numeric_max::value && - (MAX_SUBQUEUE_SIZE == 0 || - MAX_SUBQUEUE_SIZE - BLOCK_SIZE < currentTailIndex - head)); - if (pr_blockIndexRaw == nullptr || pr_blockIndexSlotsUsed == pr_blockIndexSize || full) { - if (allocMode == CannotAlloc || full || !new_block_index(originalBlockIndexSlotsUsed)) { - // Failed to allocate, undo changes (but keep injected blocks) - pr_blockIndexFront = originalBlockIndexFront; - pr_blockIndexSlotsUsed = originalBlockIndexSlotsUsed; - this->tailBlock = startBlock == nullptr ? firstAllocatedBlock : startBlock; - return false; - } - - // pr_blockIndexFront is updated inside new_block_index, so we need to - // update our fallback value too (since we keep the new index even if we - // later fail) - originalBlockIndexFront = originalBlockIndexSlotsUsed; - } - - // Insert a new block in the circular linked list - auto newBlock = this->parent->ConcurrentQueue::template requisition_block(); - if (newBlock == nullptr) { - pr_blockIndexFront = originalBlockIndexFront; - pr_blockIndexSlotsUsed = originalBlockIndexSlotsUsed; - this->tailBlock = startBlock == nullptr ? firstAllocatedBlock : startBlock; - return false; - } - -#if MCDBGQ_TRACKMEM - newBlock->owner = this; -#endif - newBlock->ConcurrentQueue::Block::template set_all_empty(); - if (this->tailBlock == nullptr) { - newBlock->next = newBlock; - } else { - newBlock->next = this->tailBlock->next; - this->tailBlock->next = newBlock; - } - this->tailBlock = newBlock; - firstAllocatedBlock = - firstAllocatedBlock == nullptr ? this->tailBlock : firstAllocatedBlock; - - ++pr_blockIndexSlotsUsed; - - auto &entry = blockIndex.load(std::memory_order_relaxed)->entries[pr_blockIndexFront]; - entry.base = currentTailIndex; - entry.block = this->tailBlock; - pr_blockIndexFront = (pr_blockIndexFront + 1) & (pr_blockIndexSize - 1); - } - - // Excellent, all allocations succeeded. Reset each block's emptiness before we fill them up, and - // publish the new block index front - auto block = firstAllocatedBlock; - while (true) { - block->ConcurrentQueue::Block::template reset_empty(); - if (block == this->tailBlock) { - break; - } - block = block->next; - } - - if (MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), - new(nullptr) T(details::deref_noexcept(itemFirst)))) { - blockIndex.load(std::memory_order_relaxed)->front.store( - (pr_blockIndexFront - 1) & (pr_blockIndexSize - 1), std::memory_order_release); - } - } - - // Enqueue, one block at a time - index_t newTailIndex = startTailIndex + static_cast(count); - currentTailIndex = startTailIndex; - auto endBlock = this->tailBlock; - this->tailBlock = startBlock; - assert((startTailIndex & static_cast(BLOCK_SIZE - 1)) != 0 || - firstAllocatedBlock != nullptr || count == 0); - if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) == 0 && - firstAllocatedBlock != nullptr) { - this->tailBlock = firstAllocatedBlock; - } - while (true) { - auto stopIndex = (currentTailIndex & ~static_cast(BLOCK_SIZE - 1)) + - static_cast(BLOCK_SIZE); - if (details::circular_less_than(newTailIndex, stopIndex)) { - stopIndex = newTailIndex; - } - if (MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), - new(nullptr) T(details::deref_noexcept(itemFirst)))) { - while (currentTailIndex != stopIndex) { - new((*this->tailBlock)[currentTailIndex++]) T(*itemFirst++); - } - } else { - MOODYCAMEL_TRY { - while (currentTailIndex != stopIndex) { - // Must use copy constructor even if move constructor is available - // because we may have to revert if there's an exception. - // Sorry about the horrible templated next line, but it was the only way - // to disable moving *at compile time*, which is important because a type - // may only define a (noexcept) move constructor, and so calls to the - // cctor will not compile, even if they are in an if branch that will never - // be executed - new((*this->tailBlock)[currentTailIndex]) T( - details::nomove_if<(bool) !MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), - new(nullptr) T( - details::deref_noexcept( - itemFirst)))>::eval( - *itemFirst)); - ++currentTailIndex; - ++itemFirst; - } - } - MOODYCAMEL_CATCH (...) { - // Oh dear, an exception's been thrown -- destroy the elements that - // were enqueued so far and revert the entire bulk operation (we'll keep - // any allocated blocks in our linked list for later, though). - auto constructedStopIndex = currentTailIndex; - auto lastBlockEnqueued = this->tailBlock; - - pr_blockIndexFront = originalBlockIndexFront; - pr_blockIndexSlotsUsed = originalBlockIndexSlotsUsed; - this->tailBlock = startBlock == nullptr ? firstAllocatedBlock : startBlock; - - if (!details::is_trivially_destructible::value) { - auto block = startBlock; - if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) == 0) { - block = firstAllocatedBlock; - } - currentTailIndex = startTailIndex; - while (true) { - stopIndex = (currentTailIndex & ~static_cast(BLOCK_SIZE - 1)) + - static_cast(BLOCK_SIZE); - if (details::circular_less_than(constructedStopIndex, stopIndex)) { - stopIndex = constructedStopIndex; - } - while (currentTailIndex != stopIndex) { - (*block)[currentTailIndex++]->~T(); - } - if (block == lastBlockEnqueued) { - break; - } - block = block->next; - } - } - MOODYCAMEL_RETHROW; - } - } - - if (this->tailBlock == endBlock) { - assert(currentTailIndex == newTailIndex); - break; - } - this->tailBlock = this->tailBlock->next; - } - - if (!MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), - new(nullptr) T(details::deref_noexcept(itemFirst))) && - firstAllocatedBlock != nullptr) { - blockIndex.load(std::memory_order_relaxed)->front.store( - (pr_blockIndexFront - 1) & (pr_blockIndexSize - 1), std::memory_order_release); - } - - this->tailIndex.store(newTailIndex, std::memory_order_release); - return true; - } - - template - size_t dequeue_bulk(It &itemFirst, size_t max) { - auto tail = this->tailIndex.load(std::memory_order_relaxed); - auto overcommit = this->dequeueOvercommit.load(std::memory_order_relaxed); - auto desiredCount = static_cast(tail - (this->dequeueOptimisticCount.load( - std::memory_order_relaxed) - overcommit)); - if (details::circular_less_than(0, desiredCount)) { - desiredCount = desiredCount < max ? desiredCount : max; - std::atomic_thread_fence(std::memory_order_acquire); - - auto myDequeueCount = this->dequeueOptimisticCount.fetch_add(desiredCount, - std::memory_order_relaxed); - assert(overcommit <= myDequeueCount); - - tail = this->tailIndex.load(std::memory_order_acquire); - auto actualCount = static_cast(tail - (myDequeueCount - overcommit)); - if (details::circular_less_than(0, actualCount)) { - actualCount = desiredCount < actualCount ? desiredCount : actualCount; - if (actualCount < desiredCount) { - this->dequeueOvercommit.fetch_add(desiredCount - actualCount, - std::memory_order_release); - } - - // Get the first index. Note that since there's guaranteed to be at least actualCount elements, this - // will never exceed tail. - auto firstIndex = this->headIndex.fetch_add(actualCount, std::memory_order_acq_rel); - - // Determine which block the first element is in - auto localBlockIndex = blockIndex.load(std::memory_order_acquire); - auto localBlockIndexHead = localBlockIndex->front.load(std::memory_order_acquire); - - auto headBase = localBlockIndex->entries[localBlockIndexHead].base; - auto firstBlockBaseIndex = firstIndex & ~static_cast(BLOCK_SIZE - 1); - auto offset = static_cast( - static_cast::type>(firstBlockBaseIndex - headBase) / - BLOCK_SIZE); - auto indexIndex = (localBlockIndexHead + offset) & (localBlockIndex->size - 1); - - // Iterate the blocks and dequeue - auto index = firstIndex; - do { - auto firstIndexInBlock = index; - auto endIndex = - (index & ~static_cast(BLOCK_SIZE - 1)) + static_cast(BLOCK_SIZE); - endIndex = details::circular_less_than( - firstIndex + static_cast(actualCount), endIndex) ? firstIndex + - static_cast(actualCount) - : endIndex; - auto block = localBlockIndex->entries[indexIndex].block; - if (MOODYCAMEL_NOEXCEPT_ASSIGN(T, T &&, details::deref_noexcept(itemFirst) = std::move( - (*(*block)[index])))) { - while (index != endIndex) { - auto &el = *((*block)[index]); - *itemFirst++ = std::move(el); - el.~T(); - ++index; - } - } else { - MOODYCAMEL_TRY { - while (index != endIndex) { - auto &el = *((*block)[index]); - *itemFirst = std::move(el); - ++itemFirst; - el.~T(); - ++index; - } - } - MOODYCAMEL_CATCH (...) { - // It's too late to revert the dequeue, but we can make sure that all - // the dequeued objects are properly destroyed and the block index - // (and empty count) are properly updated before we propagate the exception - do { - block = localBlockIndex->entries[indexIndex].block; - while (index != endIndex) { - (*block)[index++]->~T(); - } - block->ConcurrentQueue::Block::template set_many_empty( - firstIndexInBlock, static_cast(endIndex - firstIndexInBlock)); - indexIndex = (indexIndex + 1) & (localBlockIndex->size - 1); - - firstIndexInBlock = index; - endIndex = (index & ~static_cast(BLOCK_SIZE - 1)) + - static_cast(BLOCK_SIZE); - endIndex = details::circular_less_than( - firstIndex + static_cast(actualCount), endIndex) ? firstIndex + - static_cast(actualCount) - : endIndex; - } while (index != firstIndex + actualCount); - - MOODYCAMEL_RETHROW; - } - } - block->ConcurrentQueue::Block::template set_many_empty( - firstIndexInBlock, static_cast(endIndex - firstIndexInBlock)); - indexIndex = (indexIndex + 1) & (localBlockIndex->size - 1); - } while (index != firstIndex + actualCount); - - return actualCount; - } else { - // Wasn't anything to dequeue after all; make the effective dequeue count eventually consistent - this->dequeueOvercommit.fetch_add(desiredCount, std::memory_order_release); - } - } - - return 0; - } - - private: - struct BlockIndexEntry { - index_t base; - Block *block; - }; - - struct BlockIndexHeader { - size_t size; - std::atomic front; // Current slot (not next, like pr_blockIndexFront) - BlockIndexEntry *entries; - void *prev; - }; - - - bool new_block_index(size_t numberOfFilledSlotsToExpose) { - auto prevBlockSizeMask = pr_blockIndexSize - 1; - - // Create the new block - pr_blockIndexSize <<= 1; - auto newRawPtr = static_cast((Traits::malloc)( - sizeof(BlockIndexHeader) + std::alignment_of::value - 1 + - sizeof(BlockIndexEntry) * pr_blockIndexSize)); - if (newRawPtr == nullptr) { - pr_blockIndexSize >>= 1; // Reset to allow graceful retry - return false; - } - - auto newBlockIndexEntries = reinterpret_cast(details::align_for( - newRawPtr + sizeof(BlockIndexHeader))); - - // Copy in all the old indices, if any - size_t j = 0; - if (pr_blockIndexSlotsUsed != 0) { - auto i = (pr_blockIndexFront - pr_blockIndexSlotsUsed) & prevBlockSizeMask; - do { - newBlockIndexEntries[j++] = pr_blockIndexEntries[i]; - i = (i + 1) & prevBlockSizeMask; - } while (i != pr_blockIndexFront); - } - - // Update everything - auto header = new(newRawPtr) BlockIndexHeader; - header->size = pr_blockIndexSize; - header->front.store(numberOfFilledSlotsToExpose - 1, std::memory_order_relaxed); - header->entries = newBlockIndexEntries; - header->prev = pr_blockIndexRaw; // we link the new block to the old one so we can free it later - - pr_blockIndexFront = j; - pr_blockIndexEntries = newBlockIndexEntries; - pr_blockIndexRaw = newRawPtr; - blockIndex.store(header, std::memory_order_release); - - return true; - } - - private: - std::atomic blockIndex; - - // To be used by producer only -- consumer must use the ones in referenced by blockIndex - size_t pr_blockIndexSlotsUsed; - size_t pr_blockIndexSize; - size_t pr_blockIndexFront; // Next slot (not current) - BlockIndexEntry *pr_blockIndexEntries; - void *pr_blockIndexRaw; - -#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG - public: - ExplicitProducer* nextExplicitProducer; - private: -#endif - -#if MCDBGQ_TRACKMEM - friend struct MemStats; -#endif - }; - - - ////////////////////////////////// - // Implicit queue - ////////////////////////////////// - - struct ImplicitProducer : public ProducerBase { - ImplicitProducer(ConcurrentQueue *parent) - : - ProducerBase(parent, false), nextBlockIndexCapacity(IMPLICIT_INITIAL_INDEX_SIZE), blockIndex( - nullptr) { - new_block_index(); - } - - ~ImplicitProducer() { - // Note that since we're in the destructor we can assume that all enqueue/dequeue operations - // completed already; this means that all undequeued elements are placed contiguously across - // contiguous blocks, and that only the first and last remaining blocks can be only partially - // empty (all other remaining blocks must be completely full). - -#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED - // Unregister ourselves for thread termination notification - if (!this->inactive.load(std::memory_order_relaxed)) { - details::ThreadExitNotifier::unsubscribe(&threadExitListener); - } -#endif - - // Destroy all remaining elements! - auto tail = this->tailIndex.load(std::memory_order_relaxed); - auto index = this->headIndex.load(std::memory_order_relaxed); - Block *block = nullptr; - assert(index == tail || details::circular_less_than(index, tail)); - bool forceFreeLastBlock = - index != tail; // If we enter the loop, then the last (tail) block will not be freed - while (index != tail) { - if ((index & static_cast(BLOCK_SIZE - 1)) == 0 || block == nullptr) { - if (block != nullptr) { - // Free the old block - this->parent->add_block_to_free_list(block); - } - - block = get_block_index_entry_for_index(index)->value.load(std::memory_order_relaxed); - } - - ((*block)[index])->~T(); - ++index; - } - // Even if the queue is empty, there's still one block that's not on the free list - // (unless the head index reached the end of it, in which case the tail will be poised - // to create a new block). - if (this->tailBlock != nullptr && - (forceFreeLastBlock || (tail & static_cast(BLOCK_SIZE - 1)) != 0)) { - this->parent->add_block_to_free_list(this->tailBlock); - } - - // Destroy block index - auto localBlockIndex = blockIndex.load(std::memory_order_relaxed); - if (localBlockIndex != nullptr) { - for (size_t i = 0; i != localBlockIndex->capacity; ++i) { - localBlockIndex->index[i]->~BlockIndexEntry(); - } - do { - auto prev = localBlockIndex->prev; - localBlockIndex->~BlockIndexHeader(); - (Traits::free)(localBlockIndex); - localBlockIndex = prev; - } while (localBlockIndex != nullptr); - } - } - - template - inline bool enqueue(U &&element) { - index_t currentTailIndex = this->tailIndex.load(std::memory_order_relaxed); - index_t newTailIndex = 1 + currentTailIndex; - if ((currentTailIndex & static_cast(BLOCK_SIZE - 1)) == 0) { - // We reached the end of a block, start a new one - auto head = this->headIndex.load(std::memory_order_relaxed); - assert(!details::circular_less_than(currentTailIndex, head)); - if (!details::circular_less_than(head, currentTailIndex + BLOCK_SIZE) || - (MAX_SUBQUEUE_SIZE != details::const_numeric_max::value && - (MAX_SUBQUEUE_SIZE == 0 || - MAX_SUBQUEUE_SIZE - BLOCK_SIZE < currentTailIndex - head))) { - return false; - } -#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX - debug::DebugLock lock(mutex); -#endif - // Find out where we'll be inserting this block in the block index - BlockIndexEntry *idxEntry; - if (!insert_block_index_entry(idxEntry, currentTailIndex)) { - return false; - } - - // Get ahold of a new block - auto newBlock = this->parent->ConcurrentQueue::template requisition_block(); - if (newBlock == nullptr) { - rewind_block_index_tail(); - idxEntry->value.store(nullptr, std::memory_order_relaxed); - return false; - } -#if MCDBGQ_TRACKMEM - newBlock->owner = this; -#endif - newBlock->ConcurrentQueue::Block::template reset_empty(); - - if (!MOODYCAMEL_NOEXCEPT_CTOR(T, U, new(nullptr) T(std::forward(element)))) { - // May throw, try to insert now before we publish the fact that we have this new block - MOODYCAMEL_TRY { - new((*newBlock)[currentTailIndex]) T(std::forward(element)); - } - MOODYCAMEL_CATCH (...) { - rewind_block_index_tail(); - idxEntry->value.store(nullptr, std::memory_order_relaxed); - this->parent->add_block_to_free_list(newBlock); - MOODYCAMEL_RETHROW; - } - } - - // Insert the new block into the index - idxEntry->value.store(newBlock, std::memory_order_relaxed); - - this->tailBlock = newBlock; - - if (!MOODYCAMEL_NOEXCEPT_CTOR(T, U, new(nullptr) T(std::forward(element)))) { - this->tailIndex.store(newTailIndex, std::memory_order_release); - return true; - } - } - - // Enqueue - new((*this->tailBlock)[currentTailIndex]) T(std::forward(element)); - - this->tailIndex.store(newTailIndex, std::memory_order_release); - return true; - } - - template - bool dequeue(U &element) { - // See ExplicitProducer::dequeue for rationale and explanation - index_t tail = this->tailIndex.load(std::memory_order_relaxed); - index_t overcommit = this->dequeueOvercommit.load(std::memory_order_relaxed); - if (details::circular_less_than( - this->dequeueOptimisticCount.load(std::memory_order_relaxed) - overcommit, tail)) { - std::atomic_thread_fence(std::memory_order_acquire); - - index_t myDequeueCount = this->dequeueOptimisticCount.fetch_add(1, - std::memory_order_relaxed); - assert(overcommit <= myDequeueCount); - tail = this->tailIndex.load(std::memory_order_acquire); - if (details::likely( - details::circular_less_than(myDequeueCount - overcommit, tail))) { - index_t index = this->headIndex.fetch_add(1, std::memory_order_acq_rel); - - // Determine which block the element is in - auto entry = get_block_index_entry_for_index(index); - - // Dequeue - auto block = entry->value.load(std::memory_order_relaxed); - auto &el = *((*block)[index]); - - if (!MOODYCAMEL_NOEXCEPT_ASSIGN(T, T &&, element = std::move(el))) { -#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX - // Note: Acquiring the mutex with every dequeue instead of only when a block - // is released is very sub-optimal, but it is, after all, purely debug code. - debug::DebugLock lock(producer->mutex); -#endif - struct Guard { - Block *block; - index_t index; - BlockIndexEntry *entry; - ConcurrentQueue *parent; - - ~Guard() { - (*block)[index]->~T(); - if (block->ConcurrentQueue::Block::template set_empty(index)) { - entry->value.store(nullptr, std::memory_order_relaxed); - parent->add_block_to_free_list(block); - } - } - } guard = {block, index, entry, this->parent}; - - element = std::move(el); - } else { - element = std::move(el); - el.~T(); - - if (block->ConcurrentQueue::Block::template set_empty(index)) { - { -#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX - debug::DebugLock lock(mutex); -#endif - // Add the block back into the global free pool (and remove from block index) - entry->value.store(nullptr, std::memory_order_relaxed); - } - this->parent->add_block_to_free_list(block); // releases the above store - } - } - - return true; - } else { - this->dequeueOvercommit.fetch_add(1, std::memory_order_release); - } - } - - return false; - } - - template - bool enqueue_bulk(It itemFirst, size_t count) { - // First, we need to make sure we have enough room to enqueue all of the elements; - // this means pre-allocating blocks and putting them in the block index (but only if - // all the allocations succeeded). - - // Note that the tailBlock we start off with may not be owned by us any more; - // this happens if it was filled up exactly to the top (setting tailIndex to - // the first index of the next block which is not yet allocated), then dequeued - // completely (putting it on the free list) before we enqueue again. - - index_t startTailIndex = this->tailIndex.load(std::memory_order_relaxed); - auto startBlock = this->tailBlock; - Block *firstAllocatedBlock = nullptr; - auto endBlock = this->tailBlock; - - // Figure out how many blocks we'll need to allocate, and do so - size_t blockBaseDiff = - ((startTailIndex + count - 1) & ~static_cast(BLOCK_SIZE - 1)) - - ((startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1)); - index_t currentTailIndex = (startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1); - if (blockBaseDiff > 0) { -#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX - debug::DebugLock lock(mutex); -#endif - do { - blockBaseDiff -= static_cast(BLOCK_SIZE); - currentTailIndex += static_cast(BLOCK_SIZE); - - // Find out where we'll be inserting this block in the block index - BlockIndexEntry *idxEntry = nullptr; // initialization here unnecessary but compiler can't always tell - Block *newBlock; - bool indexInserted = false; - auto head = this->headIndex.load(std::memory_order_relaxed); - assert(!details::circular_less_than(currentTailIndex, head)); - bool full = !details::circular_less_than(head, currentTailIndex + BLOCK_SIZE) || - (MAX_SUBQUEUE_SIZE != details::const_numeric_max::value && - (MAX_SUBQUEUE_SIZE == 0 || - MAX_SUBQUEUE_SIZE - BLOCK_SIZE < currentTailIndex - head)); - if (full || - !(indexInserted = insert_block_index_entry(idxEntry, currentTailIndex)) || - (newBlock = this->parent->ConcurrentQueue::template requisition_block()) == - nullptr) { - // Index allocation or block allocation failed; revert any other allocations - // and index insertions done so far for this operation - if (indexInserted) { - rewind_block_index_tail(); - idxEntry->value.store(nullptr, std::memory_order_relaxed); - } - currentTailIndex = (startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1); - for (auto block = firstAllocatedBlock; block != nullptr; block = block->next) { - currentTailIndex += static_cast(BLOCK_SIZE); - idxEntry = get_block_index_entry_for_index(currentTailIndex); - idxEntry->value.store(nullptr, std::memory_order_relaxed); - rewind_block_index_tail(); - } - this->parent->add_blocks_to_free_list(firstAllocatedBlock); - this->tailBlock = startBlock; - - return false; - } - -#if MCDBGQ_TRACKMEM - newBlock->owner = this; -#endif - newBlock->ConcurrentQueue::Block::template reset_empty(); - newBlock->next = nullptr; - - // Insert the new block into the index - idxEntry->value.store(newBlock, std::memory_order_relaxed); - - // Store the chain of blocks so that we can undo if later allocations fail, - // and so that we can find the blocks when we do the actual enqueueing - if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) != 0 || - firstAllocatedBlock != nullptr) { - assert(this->tailBlock != nullptr); - this->tailBlock->next = newBlock; - } - this->tailBlock = newBlock; - endBlock = newBlock; - firstAllocatedBlock = firstAllocatedBlock == nullptr ? newBlock : firstAllocatedBlock; - } while (blockBaseDiff > 0); - } - - // Enqueue, one block at a time - index_t newTailIndex = startTailIndex + static_cast(count); - currentTailIndex = startTailIndex; - this->tailBlock = startBlock; - assert((startTailIndex & static_cast(BLOCK_SIZE - 1)) != 0 || - firstAllocatedBlock != nullptr || count == 0); - if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) == 0 && - firstAllocatedBlock != nullptr) { - this->tailBlock = firstAllocatedBlock; - } - while (true) { - auto stopIndex = (currentTailIndex & ~static_cast(BLOCK_SIZE - 1)) + - static_cast(BLOCK_SIZE); - if (details::circular_less_than(newTailIndex, stopIndex)) { - stopIndex = newTailIndex; - } - if (MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), - new(nullptr) T(details::deref_noexcept(itemFirst)))) { - while (currentTailIndex != stopIndex) { - new((*this->tailBlock)[currentTailIndex++]) T(*itemFirst++); - } - } else { - MOODYCAMEL_TRY { - while (currentTailIndex != stopIndex) { - new((*this->tailBlock)[currentTailIndex]) T( - details::nomove_if<(bool) !MOODYCAMEL_NOEXCEPT_CTOR(T, decltype(*itemFirst), - new(nullptr) T( - details::deref_noexcept( - itemFirst)))>::eval( - *itemFirst)); - ++currentTailIndex; - ++itemFirst; - } - } - MOODYCAMEL_CATCH (...) { - auto constructedStopIndex = currentTailIndex; - auto lastBlockEnqueued = this->tailBlock; - - if (!details::is_trivially_destructible::value) { - auto block = startBlock; - if ((startTailIndex & static_cast(BLOCK_SIZE - 1)) == 0) { - block = firstAllocatedBlock; - } - currentTailIndex = startTailIndex; - while (true) { - stopIndex = (currentTailIndex & ~static_cast(BLOCK_SIZE - 1)) + - static_cast(BLOCK_SIZE); - if (details::circular_less_than(constructedStopIndex, stopIndex)) { - stopIndex = constructedStopIndex; - } - while (currentTailIndex != stopIndex) { - (*block)[currentTailIndex++]->~T(); - } - if (block == lastBlockEnqueued) { - break; - } - block = block->next; - } - } - - currentTailIndex = (startTailIndex - 1) & ~static_cast(BLOCK_SIZE - 1); - for (auto block = firstAllocatedBlock; block != nullptr; block = block->next) { - currentTailIndex += static_cast(BLOCK_SIZE); - auto idxEntry = get_block_index_entry_for_index(currentTailIndex); - idxEntry->value.store(nullptr, std::memory_order_relaxed); - rewind_block_index_tail(); - } - this->parent->add_blocks_to_free_list(firstAllocatedBlock); - this->tailBlock = startBlock; - MOODYCAMEL_RETHROW; - } - } - - if (this->tailBlock == endBlock) { - assert(currentTailIndex == newTailIndex); - break; - } - this->tailBlock = this->tailBlock->next; - } - this->tailIndex.store(newTailIndex, std::memory_order_release); - return true; - } - - template - size_t dequeue_bulk(It &itemFirst, size_t max) { - auto tail = this->tailIndex.load(std::memory_order_relaxed); - auto overcommit = this->dequeueOvercommit.load(std::memory_order_relaxed); - auto desiredCount = static_cast(tail - (this->dequeueOptimisticCount.load( - std::memory_order_relaxed) - overcommit)); - if (details::circular_less_than(0, desiredCount)) { - desiredCount = desiredCount < max ? desiredCount : max; - std::atomic_thread_fence(std::memory_order_acquire); - - auto myDequeueCount = this->dequeueOptimisticCount.fetch_add(desiredCount, - std::memory_order_relaxed); - assert(overcommit <= myDequeueCount); - - tail = this->tailIndex.load(std::memory_order_acquire); - auto actualCount = static_cast(tail - (myDequeueCount - overcommit)); - if (details::circular_less_than(0, actualCount)) { - actualCount = desiredCount < actualCount ? desiredCount : actualCount; - if (actualCount < desiredCount) { - this->dequeueOvercommit.fetch_add(desiredCount - actualCount, - std::memory_order_release); - } - - // Get the first index. Note that since there's guaranteed to be at least actualCount elements, this - // will never exceed tail. - auto firstIndex = this->headIndex.fetch_add(actualCount, std::memory_order_acq_rel); - - // Iterate the blocks and dequeue - auto index = firstIndex; - BlockIndexHeader *localBlockIndex; - auto indexIndex = get_block_index_index_for_index(index, localBlockIndex); - do { - auto blockStartIndex = index; - auto endIndex = - (index & ~static_cast(BLOCK_SIZE - 1)) + static_cast(BLOCK_SIZE); - endIndex = details::circular_less_than( - firstIndex + static_cast(actualCount), endIndex) ? firstIndex + - static_cast(actualCount) - : endIndex; - - auto entry = localBlockIndex->index[indexIndex]; - auto block = entry->value.load(std::memory_order_relaxed); - if (MOODYCAMEL_NOEXCEPT_ASSIGN(T, T &&, details::deref_noexcept(itemFirst) = std::move( - (*(*block)[index])))) { - while (index != endIndex) { - auto &el = *((*block)[index]); - *itemFirst++ = std::move(el); - el.~T(); - ++index; - } - } else { - MOODYCAMEL_TRY { - while (index != endIndex) { - auto &el = *((*block)[index]); - *itemFirst = std::move(el); - ++itemFirst; - el.~T(); - ++index; - } - } - MOODYCAMEL_CATCH (...) { - do { - entry = localBlockIndex->index[indexIndex]; - block = entry->value.load(std::memory_order_relaxed); - while (index != endIndex) { - (*block)[index++]->~T(); - } - - if (block->ConcurrentQueue::Block::template set_many_empty( - blockStartIndex, static_cast(endIndex - blockStartIndex))) { -#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX - debug::DebugLock lock(mutex); -#endif - entry->value.store(nullptr, std::memory_order_relaxed); - this->parent->add_block_to_free_list(block); - } - indexIndex = (indexIndex + 1) & (localBlockIndex->capacity - 1); - - blockStartIndex = index; - endIndex = (index & ~static_cast(BLOCK_SIZE - 1)) + - static_cast(BLOCK_SIZE); - endIndex = details::circular_less_than( - firstIndex + static_cast(actualCount), endIndex) ? firstIndex + - static_cast(actualCount) - : endIndex; - } while (index != firstIndex + actualCount); - - MOODYCAMEL_RETHROW; - } - } - if (block->ConcurrentQueue::Block::template set_many_empty( - blockStartIndex, static_cast(endIndex - blockStartIndex))) { - { -#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX - debug::DebugLock lock(mutex); -#endif - // Note that the set_many_empty above did a release, meaning that anybody who acquires the block - // we're about to free can use it safely since our writes (and reads!) will have happened-before then. - entry->value.store(nullptr, std::memory_order_relaxed); - } - this->parent->add_block_to_free_list(block); // releases the above store - } - indexIndex = (indexIndex + 1) & (localBlockIndex->capacity - 1); - } while (index != firstIndex + actualCount); - - return actualCount; - } else { - this->dequeueOvercommit.fetch_add(desiredCount, std::memory_order_release); - } - } - - return 0; - } - - private: - // The block size must be > 1, so any number with the low bit set is an invalid block base index - static const index_t INVALID_BLOCK_BASE = 1; - - struct BlockIndexEntry { - std::atomic key; - std::atomic value; - }; - - struct BlockIndexHeader { - size_t capacity; - std::atomic tail; - BlockIndexEntry *entries; - BlockIndexEntry **index; - BlockIndexHeader *prev; - }; - - template - inline bool insert_block_index_entry(BlockIndexEntry *&idxEntry, index_t blockStartIndex) { - auto localBlockIndex = blockIndex.load( - std::memory_order_relaxed); // We're the only writer thread, relaxed is OK - if (localBlockIndex == nullptr) { - return false; // this can happen if new_block_index failed in the constructor - } - auto newTail = (localBlockIndex->tail.load(std::memory_order_relaxed) + 1) & - (localBlockIndex->capacity - 1); - idxEntry = localBlockIndex->index[newTail]; - if (idxEntry->key.load(std::memory_order_relaxed) == INVALID_BLOCK_BASE || - idxEntry->value.load(std::memory_order_relaxed) == nullptr) { - - idxEntry->key.store(blockStartIndex, std::memory_order_relaxed); - localBlockIndex->tail.store(newTail, std::memory_order_release); - return true; - } - - // No room in the old block index, try to allocate another one! - if (allocMode == CannotAlloc || !new_block_index()) { - return false; - } - localBlockIndex = blockIndex.load(std::memory_order_relaxed); - newTail = (localBlockIndex->tail.load(std::memory_order_relaxed) + 1) & - (localBlockIndex->capacity - 1); - idxEntry = localBlockIndex->index[newTail]; - assert(idxEntry->key.load(std::memory_order_relaxed) == INVALID_BLOCK_BASE); - idxEntry->key.store(blockStartIndex, std::memory_order_relaxed); - localBlockIndex->tail.store(newTail, std::memory_order_release); - return true; - } - - inline void rewind_block_index_tail() { - auto localBlockIndex = blockIndex.load(std::memory_order_relaxed); - localBlockIndex->tail.store((localBlockIndex->tail.load(std::memory_order_relaxed) - 1) & - (localBlockIndex->capacity - 1), std::memory_order_relaxed); - } - - inline BlockIndexEntry *get_block_index_entry_for_index(index_t index) const { - BlockIndexHeader *localBlockIndex; - auto idx = get_block_index_index_for_index(index, localBlockIndex); - return localBlockIndex->index[idx]; - } - - inline size_t - get_block_index_index_for_index(index_t index, BlockIndexHeader *&localBlockIndex) const { -#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX - debug::DebugLock lock(mutex); -#endif - index &= ~static_cast(BLOCK_SIZE - 1); - localBlockIndex = blockIndex.load(std::memory_order_acquire); - auto tail = localBlockIndex->tail.load(std::memory_order_acquire); - auto tailBase = localBlockIndex->index[tail]->key.load(std::memory_order_relaxed); - assert(tailBase != INVALID_BLOCK_BASE); - // Note: Must use division instead of shift because the index may wrap around, causing a negative - // offset, whose negativity we want to preserve - auto offset = static_cast( - static_cast::type>(index - tailBase) / BLOCK_SIZE); - size_t idx = (tail + offset) & (localBlockIndex->capacity - 1); - assert(localBlockIndex->index[idx]->key.load(std::memory_order_relaxed) == index && - localBlockIndex->index[idx]->value.load(std::memory_order_relaxed) != nullptr); - return idx; - } - - bool new_block_index() { - auto prev = blockIndex.load(std::memory_order_relaxed); - size_t prevCapacity = prev == nullptr ? 0 : prev->capacity; - auto entryCount = prev == nullptr ? nextBlockIndexCapacity : prevCapacity; - auto raw = static_cast((Traits::malloc)( - sizeof(BlockIndexHeader) + - std::alignment_of::value - 1 + sizeof(BlockIndexEntry) * entryCount + - std::alignment_of::value - 1 + - sizeof(BlockIndexEntry * ) * nextBlockIndexCapacity)); - if (raw == nullptr) { - return false; - } - - auto header = new(raw) BlockIndexHeader; - auto entries = reinterpret_cast(details::align_for( - raw + sizeof(BlockIndexHeader))); - auto index = reinterpret_cast(details::align_for( - reinterpret_cast(entries) + sizeof(BlockIndexEntry) * entryCount)); - if (prev != nullptr) { - auto prevTail = prev->tail.load(std::memory_order_relaxed); - auto prevPos = prevTail; - size_t i = 0; - do { - prevPos = (prevPos + 1) & (prev->capacity - 1); - index[i++] = prev->index[prevPos]; - } while (prevPos != prevTail); - assert(i == prevCapacity); - } - for (size_t i = 0; i != entryCount; ++i) { - new(entries + i) BlockIndexEntry; - entries[i].key.store(INVALID_BLOCK_BASE, std::memory_order_relaxed); - index[prevCapacity + i] = entries + i; - } - header->prev = prev; - header->entries = entries; - header->index = index; - header->capacity = nextBlockIndexCapacity; - header->tail.store((prevCapacity - 1) & (nextBlockIndexCapacity - 1), - std::memory_order_relaxed); - - blockIndex.store(header, std::memory_order_release); - - nextBlockIndexCapacity <<= 1; - - return true; - } - - private: - size_t nextBlockIndexCapacity; - std::atomic blockIndex; - -#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED - public: - details::ThreadExitListener threadExitListener; - private: -#endif - -#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG - public: - ImplicitProducer* nextImplicitProducer; - private: -#endif - -#if MCDBGQ_NOLOCKFREE_IMPLICITPRODBLOCKINDEX - mutable debug::DebugMutex mutex; -#endif -#if MCDBGQ_TRACKMEM - friend struct MemStats; -#endif - }; - - - ////////////////////////////////// - // Block pool manipulation - ////////////////////////////////// - - void populate_initial_block_list(size_t blockCount) { - initialBlockPoolSize = blockCount; - if (initialBlockPoolSize == 0) { - initialBlockPool = nullptr; - return; - } - - initialBlockPool = create_array(blockCount); - if (initialBlockPool == nullptr) { - initialBlockPoolSize = 0; - } - for (size_t i = 0; i < initialBlockPoolSize; ++i) { - initialBlockPool[i].dynamicallyAllocated = false; - } - } - - inline Block *try_get_block_from_initial_pool() { - if (initialBlockPoolIndex.load(std::memory_order_relaxed) >= initialBlockPoolSize) { - return nullptr; - } - - auto index = initialBlockPoolIndex.fetch_add(1, std::memory_order_relaxed); - - return index < initialBlockPoolSize ? (initialBlockPool + index) : nullptr; - } - - inline void add_block_to_free_list(Block *block) { -#if MCDBGQ_TRACKMEM - block->owner = nullptr; -#endif - freeList.add(block); - } - - inline void add_blocks_to_free_list(Block *block) { - while (block != nullptr) { - auto next = block->next; - add_block_to_free_list(block); - block = next; - } - } - - inline Block *try_get_block_from_free_list() { - return freeList.try_get(); - } - - // Gets a free block from one of the memory pools, or allocates a new one (if applicable) - template - Block *requisition_block() { - auto block = try_get_block_from_initial_pool(); - if (block != nullptr) { - return block; - } - - block = try_get_block_from_free_list(); - if (block != nullptr) { - return block; - } - - if (canAlloc == CanAlloc) { - return create(); - } - - return nullptr; - } - - -#if MCDBGQ_TRACKMEM - public: - struct MemStats { - size_t allocatedBlocks; - size_t usedBlocks; - size_t freeBlocks; - size_t ownedBlocksExplicit; - size_t ownedBlocksImplicit; - size_t implicitProducers; - size_t explicitProducers; - size_t elementsEnqueued; - size_t blockClassBytes; - size_t queueClassBytes; - size_t implicitBlockIndexBytes; - size_t explicitBlockIndexBytes; - - friend class ConcurrentQueue; - - private: - static MemStats getFor(ConcurrentQueue* q) - { - MemStats stats = { 0 }; - - stats.elementsEnqueued = q->size_approx(); - - auto block = q->freeList.head_unsafe(); - while (block != nullptr) { - ++stats.allocatedBlocks; - ++stats.freeBlocks; - block = block->freeListNext.load(std::memory_order_relaxed); - } - - for (auto ptr = q->producerListTail.load(std::memory_order_acquire); ptr != nullptr; ptr = ptr->next_prod()) { - bool implicit = dynamic_cast(ptr) != nullptr; - stats.implicitProducers += implicit ? 1 : 0; - stats.explicitProducers += implicit ? 0 : 1; - - if (implicit) { - auto prod = static_cast(ptr); - stats.queueClassBytes += sizeof(ImplicitProducer); - auto head = prod->headIndex.load(std::memory_order_relaxed); - auto tail = prod->tailIndex.load(std::memory_order_relaxed); - auto hash = prod->blockIndex.load(std::memory_order_relaxed); - if (hash != nullptr) { - for (size_t i = 0; i != hash->capacity; ++i) { - if (hash->index[i]->key.load(std::memory_order_relaxed) != ImplicitProducer::INVALID_BLOCK_BASE && hash->index[i]->value.load(std::memory_order_relaxed) != nullptr) { - ++stats.allocatedBlocks; - ++stats.ownedBlocksImplicit; - } - } - stats.implicitBlockIndexBytes += hash->capacity * sizeof(typename ImplicitProducer::BlockIndexEntry); - for (; hash != nullptr; hash = hash->prev) { - stats.implicitBlockIndexBytes += sizeof(typename ImplicitProducer::BlockIndexHeader) + hash->capacity * sizeof(typename ImplicitProducer::BlockIndexEntry*); - } - } - for (; details::circular_less_than(head, tail); head += BLOCK_SIZE) { - //auto block = prod->get_block_index_entry_for_index(head); - ++stats.usedBlocks; - } - } - else { - auto prod = static_cast(ptr); - stats.queueClassBytes += sizeof(ExplicitProducer); - auto tailBlock = prod->tailBlock; - bool wasNonEmpty = false; - if (tailBlock != nullptr) { - auto block = tailBlock; - do { - ++stats.allocatedBlocks; - if (!block->ConcurrentQueue::Block::template is_empty() || wasNonEmpty) { - ++stats.usedBlocks; - wasNonEmpty = wasNonEmpty || block != tailBlock; - } - ++stats.ownedBlocksExplicit; - block = block->next; - } while (block != tailBlock); - } - auto index = prod->blockIndex.load(std::memory_order_relaxed); - while (index != nullptr) { - stats.explicitBlockIndexBytes += sizeof(typename ExplicitProducer::BlockIndexHeader) + index->size * sizeof(typename ExplicitProducer::BlockIndexEntry); - index = static_cast(index->prev); - } - } - } - - auto freeOnInitialPool = q->initialBlockPoolIndex.load(std::memory_order_relaxed) >= q->initialBlockPoolSize ? 0 : q->initialBlockPoolSize - q->initialBlockPoolIndex.load(std::memory_order_relaxed); - stats.allocatedBlocks += freeOnInitialPool; - stats.freeBlocks += freeOnInitialPool; - - stats.blockClassBytes = sizeof(Block) * stats.allocatedBlocks; - stats.queueClassBytes += sizeof(ConcurrentQueue); - - return stats; - } - }; - - // For debugging only. Not thread-safe. - MemStats getMemStats() - { - return MemStats::getFor(this); - } - private: - friend struct MemStats; -#endif - - - ////////////////////////////////// - // Producer list manipulation - ////////////////////////////////// - - ProducerBase *recycle_or_create_producer(bool isExplicit) { - bool recycled; - return recycle_or_create_producer(isExplicit, recycled); - } - - ProducerBase *recycle_or_create_producer(bool isExplicit, bool &recycled) { -#if MCDBGQ_NOLOCKFREE_IMPLICITPRODHASH - debug::DebugLock lock(implicitProdMutex); -#endif - // Try to re-use one first - for (auto ptr = producerListTail.load(std::memory_order_acquire); - ptr != nullptr; ptr = ptr->next_prod()) { - if (ptr->inactive.load(std::memory_order_relaxed) && ptr->isExplicit == isExplicit) { - bool expected = true; - if (ptr->inactive.compare_exchange_strong(expected, /* desired */ false, - std::memory_order_acquire, - std::memory_order_relaxed)) { - // We caught one! It's been marked as activated, the caller can have it - recycled = true; - return ptr; - } - } - } - - recycled = false; - return add_producer(isExplicit ? static_cast(create(this)) - : create(this)); - } - - ProducerBase *add_producer(ProducerBase *producer) { - // Handle failed memory allocation - if (producer == nullptr) { - return nullptr; - } - - producerCount.fetch_add(1, std::memory_order_relaxed); - - // Add it to the lock-free list - auto prevTail = producerListTail.load(std::memory_order_relaxed); - do { - producer->next = prevTail; - } while (!producerListTail.compare_exchange_weak(prevTail, producer, std::memory_order_release, - std::memory_order_relaxed)); - -#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG - if (producer->isExplicit) { - auto prevTailExplicit = explicitProducers.load(std::memory_order_relaxed); - do { - static_cast(producer)->nextExplicitProducer = prevTailExplicit; - } while (!explicitProducers.compare_exchange_weak(prevTailExplicit, static_cast(producer), std::memory_order_release, std::memory_order_relaxed)); - } - else { - auto prevTailImplicit = implicitProducers.load(std::memory_order_relaxed); - do { - static_cast(producer)->nextImplicitProducer = prevTailImplicit; - } while (!implicitProducers.compare_exchange_weak(prevTailImplicit, static_cast(producer), std::memory_order_release, std::memory_order_relaxed)); - } -#endif - - return producer; - } - - void reown_producers() { - // After another instance is moved-into/swapped-with this one, all the - // producers we stole still think their parents are the other queue. - // So fix them up! - for (auto ptr = producerListTail.load(std::memory_order_relaxed); - ptr != nullptr; ptr = ptr->next_prod()) { - ptr->parent = this; - } - } - - - ////////////////////////////////// - // Implicit producer hash - ////////////////////////////////// - - struct ImplicitProducerKVP { - std::atomic key; - ImplicitProducer *value; // No need for atomicity since it's only read by the thread that sets it in the first place - - ImplicitProducerKVP() - : value(nullptr) {} - - ImplicitProducerKVP(ImplicitProducerKVP &&other) MOODYCAMEL_NOEXCEPT { - key.store(other.key.load(std::memory_order_relaxed), std::memory_order_relaxed); - value = other.value; - } - - inline ImplicitProducerKVP &operator=(ImplicitProducerKVP &&other) MOODYCAMEL_NOEXCEPT { - swap(other); - return *this; - } - - inline void swap(ImplicitProducerKVP &other) MOODYCAMEL_NOEXCEPT { - if (this != &other) { - details::swap_relaxed(key, other.key); - std::swap(value, other.value); - } - } - }; - - template - friend void moodycamel::swap(typename ConcurrentQueue::ImplicitProducerKVP &, - typename ConcurrentQueue::ImplicitProducerKVP &) MOODYCAMEL_NOEXCEPT; - - struct ImplicitProducerHash { - size_t capacity; - ImplicitProducerKVP *entries; - ImplicitProducerHash *prev; - }; - - inline void populate_initial_implicit_producer_hash() { - if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return; - - implicitProducerHashCount.store(0, std::memory_order_relaxed); - auto hash = &initialImplicitProducerHash; - hash->capacity = INITIAL_IMPLICIT_PRODUCER_HASH_SIZE; - hash->entries = &initialImplicitProducerHashEntries[0]; - for (size_t i = 0; i != INITIAL_IMPLICIT_PRODUCER_HASH_SIZE; ++i) { - initialImplicitProducerHashEntries[i].key.store(details::invalid_thread_id, - std::memory_order_relaxed); - } - hash->prev = nullptr; - implicitProducerHash.store(hash, std::memory_order_relaxed); - } - - void swap_implicit_producer_hashes(ConcurrentQueue &other) { - if (INITIAL_IMPLICIT_PRODUCER_HASH_SIZE == 0) return; - - // Swap (assumes our implicit producer hash is initialized) - initialImplicitProducerHashEntries.swap(other.initialImplicitProducerHashEntries); - initialImplicitProducerHash.entries = &initialImplicitProducerHashEntries[0]; - other.initialImplicitProducerHash.entries = &other.initialImplicitProducerHashEntries[0]; - - details::swap_relaxed(implicitProducerHashCount, other.implicitProducerHashCount); - - details::swap_relaxed(implicitProducerHash, other.implicitProducerHash); - if (implicitProducerHash.load(std::memory_order_relaxed) == - &other.initialImplicitProducerHash) { - implicitProducerHash.store(&initialImplicitProducerHash, std::memory_order_relaxed); - } else { - ImplicitProducerHash *hash; - for (hash = implicitProducerHash.load(std::memory_order_relaxed); - hash->prev != &other.initialImplicitProducerHash; hash = hash->prev) { - continue; - } - hash->prev = &initialImplicitProducerHash; - } - if (other.implicitProducerHash.load(std::memory_order_relaxed) == - &initialImplicitProducerHash) { - other.implicitProducerHash.store(&other.initialImplicitProducerHash, - std::memory_order_relaxed); - } else { - ImplicitProducerHash *hash; - for (hash = other.implicitProducerHash.load(std::memory_order_relaxed); - hash->prev != &initialImplicitProducerHash; hash = hash->prev) { - continue; - } - hash->prev = &other.initialImplicitProducerHash; - } - } - - // Only fails (returns nullptr) if memory allocation fails - ImplicitProducer *get_or_add_implicit_producer() { - // Note that since the data is essentially thread-local (key is thread ID), - // there's a reduced need for fences (memory ordering is already consistent - // for any individual thread), except for the current table itself. - - // Start by looking for the thread ID in the current and all previous hash tables. - // If it's not found, it must not be in there yet, since this same thread would - // have added it previously to one of the tables that we traversed. - - // Code and algorithm adapted from http://preshing.com/20130605/the-worlds-simplest-lock-free-hash-table - -#if MCDBGQ_NOLOCKFREE_IMPLICITPRODHASH - debug::DebugLock lock(implicitProdMutex); -#endif - - auto id = details::thread_id(); - auto hashedId = details::hash_thread_id(id); - - auto mainHash = implicitProducerHash.load(std::memory_order_acquire); - for (auto hash = mainHash; hash != nullptr; hash = hash->prev) { - // Look for the id in this hash - auto index = hashedId; - while (true) { // Not an infinite loop because at least one slot is free in the hash table - index &= hash->capacity - 1; - - auto probedKey = hash->entries[index].key.load(std::memory_order_relaxed); - if (probedKey == id) { - // Found it! If we had to search several hashes deep, though, we should lazily add it - // to the current main hash table to avoid the extended search next time. - // Note there's guaranteed to be room in the current hash table since every subsequent - // table implicitly reserves space for all previous tables (there's only one - // implicitProducerHashCount). - auto value = hash->entries[index].value; - if (hash != mainHash) { - index = hashedId; - while (true) { - index &= mainHash->capacity - 1; - probedKey = mainHash->entries[index].key.load(std::memory_order_relaxed); - auto empty = details::invalid_thread_id; -#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED - auto reusable = details::invalid_thread_id2; - if ((probedKey == empty && mainHash->entries[index].key.compare_exchange_strong(empty, id, std::memory_order_relaxed, std::memory_order_relaxed)) || - (probedKey == reusable && mainHash->entries[index].key.compare_exchange_strong(reusable, id, std::memory_order_acquire, std::memory_order_acquire))) { -#else - if ((probedKey == empty && - mainHash->entries[index].key.compare_exchange_strong(empty, id, - std::memory_order_relaxed, - std::memory_order_relaxed))) { -#endif - mainHash->entries[index].value = value; - break; - } - ++index; - } - } - - return value; - } - if (probedKey == details::invalid_thread_id) { - break; // Not in this hash table - } - ++index; - } - } - - // Insert! - auto newCount = 1 + implicitProducerHashCount.fetch_add(1, std::memory_order_relaxed); - while (true) { - if (newCount >= (mainHash->capacity >> 1) && - !implicitProducerHashResizeInProgress.test_and_set(std::memory_order_acquire)) { - // We've acquired the resize lock, try to allocate a bigger hash table. - // Note the acquire fence synchronizes with the release fence at the end of this block, and hence when - // we reload implicitProducerHash it must be the most recent version (it only gets changed within this - // locked block). - mainHash = implicitProducerHash.load(std::memory_order_acquire); - if (newCount >= (mainHash->capacity >> 1)) { - auto newCapacity = mainHash->capacity << 1; - while (newCount >= (newCapacity >> 1)) { - newCapacity <<= 1; - } - auto raw = static_cast((Traits::malloc)( - sizeof(ImplicitProducerHash) + std::alignment_of::value - 1 + - sizeof(ImplicitProducerKVP) * newCapacity)); - if (raw == nullptr) { - // Allocation failed - implicitProducerHashCount.fetch_sub(1, std::memory_order_relaxed); - implicitProducerHashResizeInProgress.clear(std::memory_order_relaxed); - return nullptr; - } - - auto newHash = new(raw) ImplicitProducerHash; - newHash->capacity = newCapacity; - newHash->entries = reinterpret_cast(details::align_for( - raw + sizeof(ImplicitProducerHash))); - for (size_t i = 0; i != newCapacity; ++i) { - new(newHash->entries + i) ImplicitProducerKVP; - newHash->entries[i].key.store(details::invalid_thread_id, std::memory_order_relaxed); - } - newHash->prev = mainHash; - implicitProducerHash.store(newHash, std::memory_order_release); - implicitProducerHashResizeInProgress.clear(std::memory_order_release); - mainHash = newHash; - } else { - implicitProducerHashResizeInProgress.clear(std::memory_order_release); - } - } - - // If it's < three-quarters full, add to the old one anyway so that we don't have to wait for the next table - // to finish being allocated by another thread (and if we just finished allocating above, the condition will - // always be true) - if (newCount < (mainHash->capacity >> 1) + (mainHash->capacity >> 2)) { - bool recycled; - auto producer = static_cast(recycle_or_create_producer(false, - recycled)); - if (producer == nullptr) { - implicitProducerHashCount.fetch_sub(1, std::memory_order_relaxed); - return nullptr; - } - if (recycled) { - implicitProducerHashCount.fetch_sub(1, std::memory_order_relaxed); - } - -#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED - producer->threadExitListener.callback = &ConcurrentQueue::implicit_producer_thread_exited_callback; - producer->threadExitListener.userData = producer; - details::ThreadExitNotifier::subscribe(&producer->threadExitListener); -#endif - - auto index = hashedId; - while (true) { - index &= mainHash->capacity - 1; - auto probedKey = mainHash->entries[index].key.load(std::memory_order_relaxed); - - auto empty = details::invalid_thread_id; -#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED - auto reusable = details::invalid_thread_id2; - if ((probedKey == empty && mainHash->entries[index].key.compare_exchange_strong(empty, id, std::memory_order_relaxed, std::memory_order_relaxed)) || - (probedKey == reusable && mainHash->entries[index].key.compare_exchange_strong(reusable, id, std::memory_order_acquire, std::memory_order_acquire))) { -#else - if ((probedKey == empty && mainHash->entries[index].key.compare_exchange_strong(empty, id, - std::memory_order_relaxed, - std::memory_order_relaxed))) { -#endif - mainHash->entries[index].value = producer; - break; - } - ++index; - } - return producer; - } - - // Hmm, the old hash is quite full and somebody else is busy allocating a new one. - // We need to wait for the allocating thread to finish (if it succeeds, we add, if not, - // we try to allocate ourselves). - mainHash = implicitProducerHash.load(std::memory_order_acquire); - } - } - -#ifdef MOODYCAMEL_CPP11_THREAD_LOCAL_SUPPORTED - void implicit_producer_thread_exited(ImplicitProducer* producer) - { - // Remove from thread exit listeners - details::ThreadExitNotifier::unsubscribe(&producer->threadExitListener); - - // Remove from hash -#if MCDBGQ_NOLOCKFREE_IMPLICITPRODHASH - debug::DebugLock lock(implicitProdMutex); -#endif - auto hash = implicitProducerHash.load(std::memory_order_acquire); - assert(hash != nullptr); // The thread exit listener is only registered if we were added to a hash in the first place - auto id = details::thread_id(); - auto hashedId = details::hash_thread_id(id); - details::thread_id_t probedKey; - - // We need to traverse all the hashes just in case other threads aren't on the current one yet and are - // trying to add an entry thinking there's a free slot (because they reused a producer) - for (; hash != nullptr; hash = hash->prev) { - auto index = hashedId; - do { - index &= hash->capacity - 1; - probedKey = hash->entries[index].key.load(std::memory_order_relaxed); - if (probedKey == id) { - hash->entries[index].key.store(details::invalid_thread_id2, std::memory_order_release); - break; - } - ++index; - } while (probedKey != details::invalid_thread_id); // Can happen if the hash has changed but we weren't put back in it yet, or if we weren't added to this hash in the first place - } - - // Mark the queue as being recyclable - producer->inactive.store(true, std::memory_order_release); - } - - static void implicit_producer_thread_exited_callback(void* userData) - { - auto producer = static_cast(userData); - auto queue = producer->parent; - queue->implicit_producer_thread_exited(producer); - } -#endif - - ////////////////////////////////// - // Utility functions - ////////////////////////////////// - - template - static inline U *create_array(size_t count) { - assert(count > 0); - auto p = static_cast((Traits::malloc)(sizeof(U) * count)); - if (p == nullptr) { - return nullptr; - } - - for (size_t i = 0; i != count; ++i) { - new(p + i) U(); - } - return p; - } - - template - static inline void destroy_array(U *p, size_t count) { - if (p != nullptr) { - assert(count > 0); - for (size_t i = count; i != 0;) { - (p + --i)->~U(); - } - (Traits::free)(p); - } - } - - template - static inline U *create() { - auto p = (Traits::malloc)(sizeof(U)); - return p != nullptr ? new(p) U : nullptr; - } - - template - static inline U *create(A1 &&a1) { - auto p = (Traits::malloc)(sizeof(U)); - return p != nullptr ? new(p) U(std::forward(a1)) : nullptr; - } - - template - static inline void destroy(U *p) { - if (p != nullptr) { - p->~U(); - } - (Traits::free)(p); - } - - private: - std::atomic producerListTail; - std::atomic producerCount; - - std::atomic initialBlockPoolIndex; - Block *initialBlockPool; - size_t initialBlockPoolSize; - -#if !MCDBGQ_USEDEBUGFREELIST - FreeList freeList; -#else - debug::DebugFreeList freeList; -#endif - - std::atomic implicitProducerHash; - std::atomic implicitProducerHashCount; // Number of slots logically used - ImplicitProducerHash initialImplicitProducerHash; - std::array initialImplicitProducerHashEntries; - std::atomic_flag implicitProducerHashResizeInProgress; - - std::atomic nextExplicitConsumerId; - std::atomic globalExplicitConsumerOffset; - -#if MCDBGQ_NOLOCKFREE_IMPLICITPRODHASH - debug::DebugMutex implicitProdMutex; -#endif - -#ifdef MOODYCAMEL_QUEUE_INTERNAL_DEBUG - std::atomic explicitProducers; - std::atomic implicitProducers; -#endif -}; - - -template -ProducerToken::ProducerToken(ConcurrentQueue &queue) - : producer(queue.recycle_or_create_producer(true)) { - if (producer != nullptr) { - producer->token = this; - } -} - -template -ProducerToken::ProducerToken(BlockingConcurrentQueue &queue) - : producer( - reinterpret_cast *>(&queue)->recycle_or_create_producer(true)) { - if (producer != nullptr) { - producer->token = this; - } -} - -template -ConsumerToken::ConsumerToken(ConcurrentQueue &queue) - : itemsConsumedFromCurrent(0), currentProducer(nullptr), desiredProducer(nullptr) { - initialOffset = queue.nextExplicitConsumerId.fetch_add(1, std::memory_order_release); - lastKnownGlobalOffset = -1; -} - -template -ConsumerToken::ConsumerToken(BlockingConcurrentQueue &queue) - : itemsConsumedFromCurrent(0), currentProducer(nullptr), desiredProducer(nullptr) { - initialOffset = reinterpret_cast *>(&queue)->nextExplicitConsumerId.fetch_add( - 1, std::memory_order_release); - lastKnownGlobalOffset = -1; -} - -template -inline void swap(ConcurrentQueue &a, ConcurrentQueue &b) MOODYCAMEL_NOEXCEPT { - a.swap(b); -} - -inline void swap(ProducerToken &a, ProducerToken &b) MOODYCAMEL_NOEXCEPT { - a.swap(b); -} - -inline void swap(ConsumerToken &a, ConsumerToken &b) MOODYCAMEL_NOEXCEPT { - a.swap(b); -} - -template -inline void swap(typename ConcurrentQueue::ImplicitProducerKVP &a, - typename ConcurrentQueue::ImplicitProducerKVP &b) MOODYCAMEL_NOEXCEPT { - a.swap(b); -} - -} - -} // namespace dmlc - -#if defined(__GNUC__) -#pragma GCC diagnostic pop -#endif - -#endif // DMLC_CONCURRENTQUEUE_H_ -//! \endcond Doxygen_Suppress diff --git a/include/dmlc/config.h b/include/dmlc/config.h deleted file mode 100644 index a4c5b53d827d..000000000000 --- a/include/dmlc/config.h +++ /dev/null @@ -1,186 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file config.h - * \brief defines config parser class - */ -#ifndef DMLC_CONFIG_H_ -#define DMLC_CONFIG_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -/*! \brief namespace for dmlc */ -namespace dmlc { - -/*! - * \brief class for config parser - * - * Two modes are supported: - * 1. non-multi value mode: if two same keys in the configure file, the later one will replace the - * ealier one; when using iterator, the order will be the "last effective insersion" order - * 2. multi value mode: multiple values with the same key could co-exist; when using iterator, the - * order will be the insersion order. - * - * [Basic usage] - * - * Config cfg(file_input_stream); - * for(Config::ConfigIterator iter = cfg.begin(); iter != cfg.end(); ++iter) { - * ConfigEntry ent = *iter; - * std::string key = ent.first; - * std::string value = ent.second; - * do_something_with(key, value); - * } - */ -class Config { - public: - /*! - * \brief type when extracting from iterator - */ - typedef std::pair ConfigEntry; - - /*! - * \brief iterator class - */ - class ConfigIterator; - - /*! - * \brief create empty config - * \param multi_value whether the config supports multi value - */ - explicit Config(bool multi_value = false); - /*! - * \brief create config and load content from the given stream - * \param is input stream - * \param multi_value whether the config supports multi value - */ - explicit Config(std::istream& is, bool multi_value = false); // NOLINT(*) - /*! - * \brief clear all the values - */ - void Clear(void); - /*! - * \brief load the contents from the stream - * \param is the stream as input - */ - void LoadFromStream(std::istream& is); // NOLINT(*) - /*! - * \brief set a key-value pair into the config; if the key already exists in the configure file, - * it will either replace the old value with the given one (in non-multi value mode) or - * store it directly (in multi-value mode); - * \param key key - * \param value value - * \param is_string whether the value should be wrapped by quotes in proto string - */ - template - void SetParam(const std::string& key, const T& value, bool is_string = false); - - /*! - * \brief get the config under the key; if multiple values exist for the same key, - * return the last inserted one. - * \param key key - * \return config value - */ - const std::string& GetParam(const std::string& key) const; - - /*! - * \brief check whether the configure value given by the key should be wrapped by quotes - * \param key key - * \return whether the configure value is represented by string - */ - bool IsGenuineString(const std::string& key) const; - - /*! - * \brief transform all the configuration into string recognizable to protobuf - * \return string that could be parsed directly by protobuf - */ - std::string ToProtoString(void) const; - - /*! - * \brief get begin iterator - * \return begin iterator - */ - ConfigIterator begin() const; - - /*! - * \brief get end iterator - * \return end iterator - */ - ConfigIterator end() const; - - public: - /*! - * \brief iterator class - */ - class ConfigIterator : public std::iterator< std::input_iterator_tag, ConfigEntry > { - friend class Config; - public: - /*! - * \brief copy constructor - */ - ConfigIterator(const ConfigIterator& other); - /*! - * \brief uni-increment operators - * \return the reference of current config - */ - ConfigIterator& operator++(); - /*! - * \brief uni-increment operators - * \return the reference of current config - */ - ConfigIterator operator++(int); // NOLINT(*) - /*! - * \brief compare operators - * \param rhs the other config to compare against - * \return the compared result - */ - bool operator == (const ConfigIterator& rhs) const; - /*! - * \brief compare operators not equal - * \param rhs the other config to compare against - * \return the compared result - */ - bool operator != (const ConfigIterator& rhs) const; - /*! - * \brief retrieve value from operator - */ - ConfigEntry operator * () const; - - private: - ConfigIterator(size_t index, const Config* config); - void FindNextIndex(); - - private: - size_t index_; - const Config* config_; - }; - - private: - struct ConfigValue { - std::vector val; - std::vector insert_index; - bool is_string; - }; - void Insert(const std::string& key, const std::string& value, bool is_string); - - private: - std::map config_map_; - std::vector > order_; - const bool multi_value_; -}; - -template -void Config::SetParam(const std::string& key, const T& value, bool is_string) { - std::ostringstream oss; - oss << value; - Insert(key, oss.str(), is_string); -} - -} // namespace dmlc - -#endif // DMLC_CONFIG_H_ diff --git a/include/dmlc/data.h b/include/dmlc/data.h deleted file mode 100644 index 16e0667322fb..000000000000 --- a/include/dmlc/data.h +++ /dev/null @@ -1,397 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file data.h - * \brief defines common input data structure, - * and interface for handling the input data - */ -#ifndef DMLC_DATA_H_ -#define DMLC_DATA_H_ - -#include -#include -#include -#include "./base.h" -#include "./io.h" -#include "./logging.h" -#include "./registry.h" - -// To help C Preprocessor with processing c++ templated types -#define __DMLC_COMMA , - -namespace dmlc { -/*! - * \brief this defines the float point - * that will be used to store feature values - */ -typedef float real_t; - -/*! - * \brief this defines the unsigned integer type - * that can normally be used to store feature index - */ -typedef unsigned index_t; - -// This file describes common data structure that can be used -// for large-scale machine learning, this may not be a complete list -// But we will keep the most common and useful ones, and keep adding new ones -/*! - * \brief data iterator interface - * this is not a C++ style iterator, but nice for data pulling:) - * This interface is used to pull in the data - * The system can do some useful tricks for you like pre-fetching - * from disk and pre-computation. - * - * Usage example: - * \code - * - * itr->BeforeFirst(); - * while (itr->Next()) { - * const DType &batch = itr->Value(); - * // some computations - * } - * \endcode - * \tparam DType the data type - */ -template -class DataIter { - public: - /*! \brief destructor */ - virtual ~DataIter(void) {} - /*! \brief set before first of the item */ - virtual void BeforeFirst(void) = 0; - /*! \brief move to next item */ - virtual bool Next(void) = 0; - /*! \brief get current data */ - virtual const DType &Value(void) const = 0; -}; - -/*! - * \brief one row of training instance - * \tparam IndexType type of index - * \tparam DType type of data (both label and value will be of DType - */ -template -class Row { - public: - /*! \brief label of the instance */ - const DType *label; - /*! \brief weight of the instance */ - const real_t *weight; - /*! \brief session-id of the instance */ - const uint64_t *qid; - /*! \brief length of the sparse vector */ - size_t length; - /*! - * \brief field of each instance - */ - const IndexType *field; - /*! - * \brief index of each instance - */ - const IndexType *index; - /*! - * \brief array value of each instance, this can be NULL - * indicating every value is set to be 1 - */ - const DType *value; - /*! - * \param i the input index - * \return field for i-th feature - */ - inline IndexType get_field(size_t i) const { - return field[i]; - } - /*! - * \param i the input index - * \return i-th feature - */ - inline IndexType get_index(size_t i) const { - return index[i]; - } - /*! - * \param i the input index - * \return i-th feature value, this function is always - * safe even when value == NULL - */ - inline DType get_value(size_t i) const { - return value == NULL ? DType(1.0f) : value[i]; - } - /*! - * \return the label of the instance - */ - inline DType get_label() const { - return *label; - } - /*! - * \return the weight of the instance, this function is always - * safe even when weight == NULL - */ - inline real_t get_weight() const { - return weight == NULL ? 1.0f : *weight; - } - /*! - * \return the qid of the instance, this function is always - * safe even when qid == NULL - */ - inline uint64_t get_qid() const { - return qid == NULL ? 0 : *qid; - } - /*! - * \brief helper function to compute dot product of current - * \param weight the dense array of weight we want to product - * \param size the size of the weight vector - * \tparam V type of the weight vector - * \return the result of dot product - */ - template - inline V SDot(const V *weight, size_t size) const { - V sum = static_cast(0); - if (value == NULL) { - for (size_t i = 0; i < length; ++i) { - CHECK(index[i] < size) << "feature index exceed bound"; - sum += weight[index[i]]; - } - } else { - for (size_t i = 0; i < length; ++i) { - CHECK(index[i] < size) << "feature index exceed bound"; - sum += weight[index[i]] * value[i]; - } - } - return sum; - } -}; - -/*! - * \brief a block of data, containing several rows in sparse matrix - * This is useful for (streaming-sxtyle) algorithms that scans through rows of data - * examples include: SGD, GD, L-BFGS, kmeans - * - * The size of batch is usually large enough so that parallelizing over the rows - * can give significant speedup - * \tparam IndexType type to store the index used in row batch - * \tparam DType type to store the label and value used in row batch - */ -template -struct RowBlock { - /*! \brief batch size */ - size_t size; - /*! \brief array[size+1], row pointer to beginning of each rows */ - const size_t *offset; - /*! \brief array[size] label of each instance */ - const DType *label; - /*! \brief With weight: array[size] label of each instance, otherwise nullptr */ - const real_t *weight; - /*! \brief With qid: array[size] session id of each instance, otherwise nullptr */ - const uint64_t *qid; - /*! \brief field id*/ - const IndexType *field; - /*! \brief feature index */ - const IndexType *index; - /*! \brief feature value, can be NULL, indicating all values are 1 */ - const DType *value; - /*! - * \brief get specific rows in the batch - * \param rowid the rowid in that row - * \return the instance corresponding to the row - */ - inline Row operator[](size_t rowid) const; - /*! \return memory cost of the block in bytes */ - inline size_t MemCostBytes(void) const { - size_t cost = size * (sizeof(size_t) + sizeof(DType)); - if (weight != NULL) cost += size * sizeof(real_t); - if (qid != NULL) cost += size * sizeof(size_t); - size_t ndata = offset[size] - offset[0]; - if (field != NULL) cost += ndata * sizeof(IndexType); - if (index != NULL) cost += ndata * sizeof(IndexType); - if (value != NULL) cost += ndata * sizeof(DType); - return cost; - } - /*! - * \brief slice a RowBlock to get rows in [begin, end) - * \param begin the begin row index - * \param end the end row index - * \return the sliced RowBlock - */ - inline RowBlock Slice(size_t begin, size_t end) const { - CHECK(begin <= end && end <= size); - RowBlock ret; - ret.size = end - begin; - ret.label = label + begin; - if (weight != NULL) { - ret.weight = weight + begin; - } else { - ret.weight = NULL; - } - if (qid != NULL) { - ret.qid = qid + begin; - } else { - ret.qid = NULL; - } - ret.offset = offset + begin; - ret.field = field; - ret.index = index; - ret.value = value; - return ret; - } -}; - -/*! - * \brief Data structure that holds the data - * Row block iterator interface that gets RowBlocks - * Difference between RowBlockIter and Parser: - * RowBlockIter caches the data internally that can be used - * to iterate the dataset multiple times, - * Parser holds very limited internal state and was usually - * used to read data only once - * - * \sa Parser - * \tparam IndexType type of index in RowBlock - * \tparam DType type of label and value in RowBlock - * Create function was only implemented for IndexType uint64_t and uint32_t - * and DType real_t and int - */ -template -class RowBlockIter : public DataIter > { - public: - /*! - * \brief create a new instance of iterator that returns rowbatch - * by default, a in-memory based iterator will be returned - * - * \param uri the uri of the input, can contain hdfs prefix - * \param part_index the part id of current input - * \param num_parts total number of splits - * \param type type of dataset can be: "libsvm", ... - * - * \return the created data iterator - */ - static RowBlockIter * - Create(const char *uri, - unsigned part_index, - unsigned num_parts, - const char *type); - /*! \return maximum feature dimension in the dataset */ - virtual size_t NumCol() const = 0; -}; - -/*! - * \brief parser interface that parses input data - * used to load dmlc data format into your own data format - * Difference between RowBlockIter and Parser: - * RowBlockIter caches the data internally that can be used - * to iterate the dataset multiple times, - * Parser holds very limited internal state and was usually - * used to read data only once - * - * - * \sa RowBlockIter - * \tparam IndexType type of index in RowBlock - * \tparam DType type of label and value in RowBlock - * Create function was only implemented for IndexType uint64_t and uint32_t - * and DType real_t and int - */ -template -class Parser : public DataIter > { - public: - /*! - * \brief create a new instance of parser based on the "type" - * - * \param uri_ the uri of the input, can contain hdfs prefix - * \param part_index the part id of current input - * \param num_parts total number of splits - * \param type type of dataset can be: "libsvm", "auto", ... - * - * When "auto" is passed, the type is decided by format argument string in URI. - * - * \return the created parser - */ - static Parser * - Create(const char *uri_, - unsigned part_index, - unsigned num_parts, - const char *type); - /*! \return size of bytes read so far */ - virtual size_t BytesRead(void) const = 0; - /*! \brief Factory type of the parser*/ - typedef Parser* (*Factory) - (const std::string& path, - const std::map& args, - unsigned part_index, - unsigned num_parts); -}; - -/*! - * \brief registry entry of parser factory - * \tparam IndexType The type of index - * \tparam DType The type of label and value - */ -template -struct ParserFactoryReg - : public FunctionRegEntryBase, - typename Parser::Factory> {}; - -/*! - * \brief Register a new distributed parser to dmlc-core. - * - * \param IndexType The type of Batch index, can be uint32_t or uint64_t - * \param DataType The type of Batch label and value, can be real_t or int - * \param TypeName The typename of of the data. - * \param FactoryFunction The factory function that creates the parser. - * - * \begincode - * - * // define the factory function - * template - * Parser* - * CreateLibSVMParser(const char* uri, unsigned part_index, unsigned num_parts) { - * return new LibSVMParser(uri, part_index, num_parts); - * } - * - * // Register it to DMLC - * // Then we can use Parser::Create(uri, part_index, num_parts, "libsvm"); - * // to create the parser - * - * DMLC_REGISTER_DATA_PARSER(uint32_t, real_t, libsvm, CreateLibSVMParser); - * DMLC_REGISTER_DATA_PARSER(uint64_t, real_t, libsvm, CreateLibSVMParser); - * - * \endcode - */ -#define DMLC_REGISTER_DATA_PARSER(IndexType, DataType, TypeName, FactoryFunction) \ - DMLC_REGISTRY_REGISTER(ParserFactoryReg, \ - ParserFactoryReg ## _ ## IndexType ## _ ## DataType, TypeName) \ - .set_body(FactoryFunction) - - -// implementation of operator[] -template -inline Row -RowBlock::operator[](size_t rowid) const { - CHECK(rowid < size); - Row inst; - inst.label = label + rowid; - if (weight != NULL) { - inst.weight = weight + rowid; - } else { - inst.weight = NULL; - } - if (qid != NULL) { - inst.qid = qid + rowid; - } else { - inst.qid = NULL; - } - inst.length = offset[rowid + 1] - offset[rowid]; - if (field != NULL) { - inst.field = field + offset[rowid]; - } else { - inst.field = NULL; - } - inst.index = index + offset[rowid]; - if (value == NULL) { - inst.value = NULL; - } else { - inst.value = value + offset[rowid]; - } - return inst; -} - -} // namespace dmlc -#endif // DMLC_DATA_H_ diff --git a/include/dmlc/endian.h b/include/dmlc/endian.h deleted file mode 100644 index e7deeaa49034..000000000000 --- a/include/dmlc/endian.h +++ /dev/null @@ -1,44 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file endian.h - * \brief Endian testing, need c++11 - */ -#ifndef DMLC_ENDIAN_H_ -#define DMLC_ENDIAN_H_ - -#include "./base.h" - -#if defined(__APPLE__) || defined(_WIN32) -#define DMLC_LITTLE_ENDIAN 1 -#else -#include -#define DMLC_LITTLE_ENDIAN (__BYTE_ORDER == __LITTLE_ENDIAN) -#endif - -/*! \brief whether serialize using little endian */ -#define DMLC_IO_NO_ENDIAN_SWAP (DMLC_LITTLE_ENDIAN == DMLC_IO_USE_LITTLE_ENDIAN) - -namespace dmlc { - -/*! - * \brief A generic inplace byte swapping function. - * \param data The data pointer. - * \param elem_bytes The number of bytes of the data elements - * \param num_elems Number of elements in the data. - * \note Always try pass in constant elem_bytes to enable - * compiler optimization - */ -inline void ByteSwap(void* data, size_t elem_bytes, size_t num_elems) { - for (size_t i = 0; i < num_elems; ++i) { - uint8_t* bptr = reinterpret_cast(data) + elem_bytes * i; - for (size_t j = 0; j < elem_bytes / 2; ++j) { - uint8_t v = bptr[elem_bytes - 1 - j]; - bptr[elem_bytes - 1 - j] = bptr[j]; - bptr[j] = v; - } - } -} - -} // namespace dmlc -#endif // DMLC_ENDIAN_H_ - diff --git a/include/dmlc/input_split_shuffle.h b/include/dmlc/input_split_shuffle.h deleted file mode 100644 index fc2c65e0a91e..000000000000 --- a/include/dmlc/input_split_shuffle.h +++ /dev/null @@ -1,168 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file input_split_shuffle.h - * \brief base class to construct input split with global shuffling - * \author Yifeng Geng - */ -#ifndef DMLC_INPUT_SPLIT_SHUFFLE_H_ -#define DMLC_INPUT_SPLIT_SHUFFLE_H_ - -#include -#include -#include -#include -#include -#include - -namespace dmlc { -/*! \brief class to construct input split with global shuffling */ -class InputSplitShuffle : public InputSplit { - public: - // destructor - virtual ~InputSplitShuffle(void) { source_.reset(); } - // implement BeforeFirst - virtual void BeforeFirst(void) { - if (num_shuffle_parts_ > 1) { - std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_); - int idx = shuffle_indexes_[0] + part_index_ * num_shuffle_parts_; - source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_); - cur_shuffle_idx_ = 0; - } else { - source_->BeforeFirst(); - } - } - virtual void HintChunkSize(size_t chunk_size) { - source_->HintChunkSize(chunk_size); - } - virtual size_t GetTotalSize(void) { - return source_->GetTotalSize(); - } - // implement next record - virtual bool NextRecord(Blob *out_rec) { - if (num_shuffle_parts_ > 1) { - if (!source_->NextRecord(out_rec)) { - if (cur_shuffle_idx_ == num_shuffle_parts_ - 1) { - return false; - } - ++cur_shuffle_idx_; - int idx = - shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_; - source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_); - return NextRecord(out_rec); - } else { - return true; - } - } else { - return source_->NextRecord(out_rec); - } - } - // implement next chunk - virtual bool NextChunk(Blob* out_chunk) { - if (num_shuffle_parts_ > 1) { - if (!source_->NextChunk(out_chunk)) { - if (cur_shuffle_idx_ == num_shuffle_parts_ - 1) { - return false; - } - ++cur_shuffle_idx_; - int idx = - shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_; - source_->ResetPartition(idx, num_parts_ * num_shuffle_parts_); - return NextChunk(out_chunk); - } else { - return true; - } - } else { - return source_->NextChunk(out_chunk); - } - } - // implement ResetPartition. - virtual void ResetPartition(unsigned rank, unsigned nsplit) { - CHECK(nsplit == num_parts_) << "num_parts is not consistent!"; - int idx = shuffle_indexes_[0] + rank * num_shuffle_parts_; - source_->ResetPartition(idx, nsplit * num_shuffle_parts_); - cur_shuffle_idx_ = 0; - } - /*! - * \brief constructor - * \param uri the uri of the input, can contain hdfs prefix - * \param part_index the part id of current input - * \param num_parts total number of splits - * \param type type of record - * List of possible types: "text", "recordio" - * - "text": - * text file, each line is treated as a record - * input split will split on '\\n' or '\\r' - * - "recordio": - * binary recordio file, see recordio.h - * \param num_shuffle_parts number of shuffle chunks for each split - * \param shuffle_seed shuffle seed for chunk shuffling - */ - InputSplitShuffle(const char* uri, - unsigned part_index, - unsigned num_parts, - const char* type, - unsigned num_shuffle_parts, - int shuffle_seed) - : part_index_(part_index), - num_parts_(num_parts), - num_shuffle_parts_(num_shuffle_parts), - cur_shuffle_idx_(0) { - for (unsigned i = 0; i < num_shuffle_parts_; i++) { - shuffle_indexes_.push_back(i); - } - trnd_.seed(kRandMagic_ + part_index_ + num_parts_ + num_shuffle_parts_ + - shuffle_seed); - std::shuffle(shuffle_indexes_.begin(), shuffle_indexes_.end(), trnd_); - int idx = shuffle_indexes_[cur_shuffle_idx_] + part_index_ * num_shuffle_parts_; - source_.reset( - InputSplit::Create(uri, idx , num_parts_ * num_shuffle_parts_, type)); - } - /*! - * \brief factory function: - * create input split with chunk shuffling given a uri - * \param uri the uri of the input, can contain hdfs prefix - * \param part_index the part id of current input - * \param num_parts total number of splits - * \param type type of record - * List of possible types: "text", "recordio" - * - "text": - * text file, each line is treated as a record - * input split will split on '\\n' or '\\r' - * - "recordio": - * binary recordio file, see recordio.h - * \param num_shuffle_parts number of shuffle chunks for each split - * \param shuffle_seed shuffle seed for chunk shuffling - * \return a new input split - * \sa InputSplit::Type - */ - static InputSplit* Create(const char* uri, - unsigned part_index, - unsigned num_parts, - const char* type, - unsigned num_shuffle_parts, - int shuffle_seed) { - CHECK(num_shuffle_parts > 0) << "number of shuffle parts should be greater than zero!"; - return new InputSplitShuffle( - uri, part_index, num_parts, type, num_shuffle_parts, shuffle_seed); - } - - private: - // magic nyumber for seed - static const int kRandMagic_ = 666; - /*! \brief random engine */ - std::mt19937 trnd_; - /*! \brief inner inputsplit */ - std::unique_ptr source_; - /*! \brief part index */ - unsigned part_index_; - /*! \brief number of parts */ - unsigned num_parts_; - /*! \brief the number of block for shuffling*/ - unsigned num_shuffle_parts_; - /*! \brief current shuffle block index */ - unsigned cur_shuffle_idx_; - /*! \brief shuffled indexes */ - std::vector shuffle_indexes_; -}; -} // namespace dmlc -#endif // DMLC_INPUT_SPLIT_SHUFFLE_H_ diff --git a/include/dmlc/io.h b/include/dmlc/io.h deleted file mode 100644 index 5e76e4c6e24c..000000000000 --- a/include/dmlc/io.h +++ /dev/null @@ -1,522 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file io.h - * \brief defines serializable interface of dmlc - */ -#ifndef DMLC_IO_H_ -#define DMLC_IO_H_ -#include -#include -#include -#include -#include -#include -#include "./logging.h" - -// include uint64_t only to make io standalone -#ifdef _MSC_VER -/*! \brief uint64 */ -typedef unsigned __int64 uint64_t; -#else -#include -#endif - -/*! \brief namespace for dmlc */ -namespace dmlc { -/*! - * \brief interface of stream I/O for serialization - */ -class Stream { // NOLINT(*) - public: - /*! - * \brief reads data from a stream - * \param ptr pointer to a memory buffer - * \param size block size - * \return the size of data read - */ - virtual size_t Read(void *ptr, size_t size) = 0; - /*! - * \brief writes data to a stream - * \param ptr pointer to a memory buffer - * \param size block size - */ - virtual void Write(const void *ptr, size_t size) = 0; - /*! \brief virtual destructor */ - virtual ~Stream(void) {} - /*! - * \brief generic factory function - * create an stream, the stream will close the underlying files upon deletion - * - * \param uri the uri of the input currently we support - * hdfs://, s3://, and file:// by default file:// will be used - * \param flag can be "w", "r", "a" - * \param allow_null whether NULL can be returned, or directly report error - * \return the created stream, can be NULL when allow_null == true and file do not exist - */ - static Stream *Create(const char *uri, - const char* const flag, - bool allow_null = false); - // helper functions to write/read different data structures - /*! - * \brief writes a data to stream. - * - * dmlc::Stream support Write/Read of most STL composites and base types. - * If the data type is not supported, a compile time error will be issued. - * - * This function is endian-aware, - * the output endian defined by DMLC_IO_USE_LITTLE_ENDIAN - * - * \param data data to be written - * \tparam T the data type to be written - */ - template - inline void Write(const T &data); - /*! - * \brief loads a data from stream. - * - * dmlc::Stream support Write/Read of most STL composites and base types. - * If the data type is not supported, a compile time error will be issued. - * - * This function is endian-aware, - * the input endian defined by DMLC_IO_USE_LITTLE_ENDIAN - * - * \param out_data place holder of data to be deserialized - * \return whether the load was successful - */ - template - inline bool Read(T *out_data); - /*! - * \brief Endian aware write array of data. - * \param data The data pointer - * \param num_elems Number of elements - * \tparam T the data type. - */ - template - inline void WriteArray(const T* data, size_t num_elems); - /*! - * \brief Endian aware read array of data. - * \param data The data pointer - * \param num_elems Number of elements - * \tparam T the data type. - * \return whether the load was successful - */ - template - inline bool ReadArray(T* data, size_t num_elems); -}; - -/*! \brief interface of i/o stream that support seek */ -class SeekStream: public Stream { - public: - // virtual destructor - virtual ~SeekStream(void) {} - /*! \brief seek to certain position of the file */ - virtual void Seek(size_t pos) = 0; - /*! \brief tell the position of the stream */ - virtual size_t Tell(void) = 0; - /*! - * \brief generic factory function - * create an SeekStream for read only, - * the stream will close the underlying files upon deletion - * error will be reported and the system will exit when create failed - * \param uri the uri of the input currently we support - * hdfs://, s3://, and file:// by default file:// will be used - * \param allow_null whether NULL can be returned, or directly report error - * \return the created stream, can be NULL when allow_null == true and file do not exist - */ - static SeekStream *CreateForRead(const char *uri, - bool allow_null = false); -}; - -/*! \brief interface for serializable objects */ -class Serializable { - public: - /*! \brief virtual destructor */ - virtual ~Serializable() {} - /*! - * \brief load the model from a stream - * \param fi stream where to load the model from - */ - virtual void Load(Stream *fi) = 0; - /*! - * \brief saves the model to a stream - * \param fo stream where to save the model to - */ - virtual void Save(Stream *fo) const = 0; -}; - -/*! - * \brief input split creates that allows reading - * of records from split of data, - * independent part that covers all the dataset - * - * see InputSplit::Create for definition of record - */ -class InputSplit { - public: - /*! \brief a blob of memory region */ - struct Blob { - /*! \brief points to start of the memory region */ - void *dptr; - /*! \brief size of the memory region */ - size_t size; - }; - /*! - * \brief hint the inputsplit how large the chunk size - * it should return when implementing NextChunk - * this is a hint so may not be enforced, - * but InputSplit will try adjust its internal buffer - * size to the hinted value - * \param chunk_size the chunk size - */ - virtual void HintChunkSize(size_t chunk_size) {} - /*! \brief get the total size of the InputSplit */ - virtual size_t GetTotalSize(void) = 0; - /*! \brief reset the position of InputSplit to beginning */ - virtual void BeforeFirst(void) = 0; - /*! - * \brief get the next record, the returning value - * is valid until next call to NextRecord, NextChunk or NextBatch - * caller can modify the memory content of out_rec - * - * For text, out_rec contains a single line - * For recordio, out_rec contains one record content(with header striped) - * - * \param out_rec used to store the result - * \return true if we can successfully get next record - * false if we reached end of split - * \sa InputSplit::Create for definition of record - */ - virtual bool NextRecord(Blob *out_rec) = 0; - /*! - * \brief get a chunk of memory that can contain multiple records, - * the caller needs to parse the content of the resulting chunk, - * for text file, out_chunk can contain data of multiple lines - * for recordio, out_chunk can contain multiple records(including headers) - * - * This function ensures there won't be partial record in the chunk - * caller can modify the memory content of out_chunk, - * the memory is valid until next call to NextRecord, NextChunk or NextBatch - * - * Usually NextRecord is sufficient, NextChunk can be used by some - * multi-threaded parsers to parse the input content - * - * \param out_chunk used to store the result - * \return true if we can successfully get next record - * false if we reached end of split - * \sa InputSplit::Create for definition of record - * \sa RecordIOChunkReader to parse recordio content from out_chunk - */ - virtual bool NextChunk(Blob *out_chunk) = 0; - /*! - * \brief get a chunk of memory that can contain multiple records, - * with hint for how many records is needed, - * the caller needs to parse the content of the resulting chunk, - * for text file, out_chunk can contain data of multiple lines - * for recordio, out_chunk can contain multiple records(including headers) - * - * This function ensures there won't be partial record in the chunk - * caller can modify the memory content of out_chunk, - * the memory is valid until next call to NextRecord, NextChunk or NextBatch - * - * - * \param out_chunk used to store the result - * \param n_records used as a hint for how many records should be returned, may be ignored - * \return true if we can successfully get next record - * false if we reached end of split - * \sa InputSplit::Create for definition of record - * \sa RecordIOChunkReader to parse recordio content from out_chunk - */ - virtual bool NextBatch(Blob *out_chunk, size_t n_records) { - return NextChunk(out_chunk); - } - /*! \brief destructor*/ - virtual ~InputSplit(void) {} - /*! - * \brief reset the Input split to a certain part id, - * The InputSplit will be pointed to the head of the new specified segment. - * This feature may not be supported by every implementation of InputSplit. - * \param part_index The part id of the new input. - * \param num_parts The total number of parts. - */ - virtual void ResetPartition(unsigned part_index, unsigned num_parts) = 0; - /*! - * \brief factory function: - * create input split given a uri - * \param uri the uri of the input, can contain hdfs prefix - * \param part_index the part id of current input - * \param num_parts total number of splits - * \param type type of record - * List of possible types: "text", "recordio", "indexed_recordio" - * - "text": - * text file, each line is treated as a record - * input split will split on '\\n' or '\\r' - * - "recordio": - * binary recordio file, see recordio.h - * - "indexed_recordio": - * binary recordio file with index, see recordio.h - * \return a new input split - * \sa InputSplit::Type - */ - static InputSplit* Create(const char *uri, - unsigned part_index, - unsigned num_parts, - const char *type); - /*! - * \brief factory function: - * create input split given a uri for input and index - * \param uri the uri of the input, can contain hdfs prefix - * \param index_uri the uri of the index, can contain hdfs prefix - * \param part_index the part id of current input - * \param num_parts total number of splits - * \param type type of record - * List of possible types: "text", "recordio", "indexed_recordio" - * - "text": - * text file, each line is treated as a record - * input split will split on '\\n' or '\\r' - * - "recordio": - * binary recordio file, see recordio.h - * - "indexed_recordio": - * binary recordio file with index, see recordio.h - * \param shuffle whether to shuffle the output from the InputSplit, - * supported only by "indexed_recordio" type. - * Defaults to "false" - * \param seed random seed to use in conjunction with the "shuffle" - * option. Defaults to 0 - * \param batch_size a hint to InputSplit what is the intended number - * of examples return per batch. Used only by - * "indexed_recordio" type - * \param recurse_directories whether to recursively traverse directories - * \return a new input split - * \sa InputSplit::Type - */ - static InputSplit* Create(const char *uri, - const char *index_uri, - unsigned part_index, - unsigned num_parts, - const char *type, - const bool shuffle = false, - const int seed = 0, - const size_t batch_size = 256, - const bool recurse_directories = false); -}; - -#ifndef _LIBCPP_SGX_NO_IOSTREAMS -/*! - * \brief a std::ostream class that can can wrap Stream objects, - * can use ostream with that output to underlying Stream - * - * Usage example: - * \code - * - * Stream *fs = Stream::Create("hdfs:///test.txt", "w"); - * dmlc::ostream os(fs); - * os << "hello world" << std::endl; - * delete fs; - * \endcode - */ -class ostream : public std::basic_ostream { - public: - /*! - * \brief construct std::ostream type - * \param stream the Stream output to be used - * \param buffer_size internal streambuf size - */ - explicit ostream(Stream *stream, - size_t buffer_size = (1 << 10)) - : std::basic_ostream(NULL), buf_(buffer_size) { - this->set_stream(stream); - } - // explictly synchronize the buffer - virtual ~ostream() DMLC_NO_EXCEPTION { - buf_.pubsync(); - } - /*! - * \brief set internal stream to be stream, reset states - * \param stream new stream as output - */ - inline void set_stream(Stream *stream) { - buf_.set_stream(stream); - this->rdbuf(&buf_); - } - - /*! \return how many bytes we written so far */ - inline size_t bytes_written(void) const { - return buf_.bytes_out(); - } - - private: - // internal streambuf - class OutBuf : public std::streambuf { - public: - explicit OutBuf(size_t buffer_size) - : stream_(NULL), buffer_(buffer_size), bytes_out_(0) { - if (buffer_size == 0) buffer_.resize(2); - } - // set stream to the buffer - inline void set_stream(Stream *stream); - - inline size_t bytes_out() const { return bytes_out_; } - private: - /*! \brief internal stream by StreamBuf */ - Stream *stream_; - /*! \brief internal buffer */ - std::vector buffer_; - /*! \brief number of bytes written so far */ - size_t bytes_out_; - // override sync - inline int_type sync(void); - // override overflow - inline int_type overflow(int c); - }; - /*! \brief buffer of the stream */ - OutBuf buf_; -}; - -/*! - * \brief a std::istream class that can can wrap Stream objects, - * can use istream with that output to underlying Stream - * - * Usage example: - * \code - * - * Stream *fs = Stream::Create("hdfs:///test.txt", "r"); - * dmlc::istream is(fs); - * is >> mydata; - * delete fs; - * \endcode - */ -class istream : public std::basic_istream { - public: - /*! - * \brief construct std::ostream type - * \param stream the Stream output to be used - * \param buffer_size internal buffer size - */ - explicit istream(Stream *stream, - size_t buffer_size = (1 << 10)) - : std::basic_istream(NULL), buf_(buffer_size) { - this->set_stream(stream); - } - virtual ~istream() DMLC_NO_EXCEPTION {} - /*! - * \brief set internal stream to be stream, reset states - * \param stream new stream as output - */ - inline void set_stream(Stream *stream) { - buf_.set_stream(stream); - this->rdbuf(&buf_); - } - /*! \return how many bytes we read so far */ - inline size_t bytes_read(void) const { - return buf_.bytes_read(); - } - - private: - // internal streambuf - class InBuf : public std::streambuf { - public: - explicit InBuf(size_t buffer_size) - : stream_(NULL), bytes_read_(0), - buffer_(buffer_size) { - if (buffer_size == 0) buffer_.resize(2); - } - // set stream to the buffer - inline void set_stream(Stream *stream); - // return how many bytes read so far - inline size_t bytes_read(void) const { - return bytes_read_; - } - private: - /*! \brief internal stream by StreamBuf */ - Stream *stream_; - /*! \brief how many bytes we read so far */ - size_t bytes_read_; - /*! \brief internal buffer */ - std::vector buffer_; - // override underflow - inline int_type underflow(); - }; - /*! \brief input buffer */ - InBuf buf_; -}; -#endif -} // namespace dmlc - -#include "./serializer.h" - -namespace dmlc { -// implementations of inline functions -template -inline void Stream::Write(const T &data) { - serializer::Handler::Write(this, data); -} -template -inline bool Stream::Read(T *out_data) { - return serializer::Handler::Read(this, out_data); -} - -template -inline void Stream::WriteArray(const T* data, size_t num_elems) { - for (size_t i = 0; i < num_elems; ++i) { - this->Write(data[i]); - } -} - -template -inline bool Stream::ReadArray(T* data, size_t num_elems) { - for (size_t i = 0; i < num_elems; ++i) { - if (!this->Read(data + i)) return false; - } - return true; -} - -#ifndef _LIBCPP_SGX_NO_IOSTREAMS -// implementations for ostream -inline void ostream::OutBuf::set_stream(Stream *stream) { - if (stream_ != NULL) this->pubsync(); - this->stream_ = stream; - this->setp(&buffer_[0], &buffer_[0] + buffer_.size() - 1); -} -inline int ostream::OutBuf::sync(void) { - if (stream_ == NULL) return -1; - std::ptrdiff_t n = pptr() - pbase(); - stream_->Write(pbase(), n); - this->pbump(-static_cast(n)); - bytes_out_ += n; - return 0; -} -inline int ostream::OutBuf::overflow(int c) { - *(this->pptr()) = c; - std::ptrdiff_t n = pptr() - pbase(); - this->pbump(-static_cast(n)); - if (c == EOF) { - stream_->Write(pbase(), n); - bytes_out_ += n; - } else { - stream_->Write(pbase(), n + 1); - bytes_out_ += n + 1; - } - return c; -} - -// implementations for istream -inline void istream::InBuf::set_stream(Stream *stream) { - stream_ = stream; - this->setg(&buffer_[0], &buffer_[0], &buffer_[0]); -} -inline int istream::InBuf::underflow() { - char *bhead = &buffer_[0]; - if (this->gptr() == this->egptr()) { - size_t sz = stream_->Read(bhead, buffer_.size()); - this->setg(bhead, bhead, bhead + sz); - bytes_read_ += sz; - } - if (this->gptr() == this->egptr()) { - return traits_type::eof(); - } else { - return traits_type::to_int_type(*gptr()); - } -} -#endif -} // namespace dmlc -#endif // DMLC_IO_H_ diff --git a/include/dmlc/json.h b/include/dmlc/json.h deleted file mode 100644 index ef82dfb57aa7..000000000000 --- a/include/dmlc/json.h +++ /dev/null @@ -1,981 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file json.h - * \brief Lightweight JSON Reader/Writer that read save into C++ data structs. - * This includes STL composites and structures. - */ -#ifndef DMLC_JSON_H_ -#define DMLC_JSON_H_ - -// This code requires C++11 to compile -#include -#ifndef _LIBCPP_SGX_NO_IOSTREAMS -#include -#endif -#include -#include -#include -#include -#include -#include - -#include "./base.h" -#include "./logging.h" -#include "./type_traits.h" - -#if DMLC_USE_CXX11 -#include -#include -#include -#if DMLC_STRICT_CXX11 -#if DMLC_ENABLE_RTTI -#include "./any.h" -#endif // DMLC_ENABLE_RTTI -#endif // DMLC_STRICT_CXX11 -#endif // DMLC_USE_CXX11 - -namespace dmlc { -/*! - * \brief Lightweight JSON Reader to read any STL compositions and structs. - * The user need to know the schema of the - * - */ -class JSONReader { - public: - /*! - * \brief Constructor. - * \param is the input source. - */ -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - explicit JSONReader(std::istream *is) -#else - explicit JSONReader(std::string *is) -#endif - : is_(is), - line_count_r_(0), - line_count_n_(0) {} - /*! - * \brief Parse next JSON string. - * \param out_str the output string. - * \throw dmlc::Error when next token is not string - */ - inline void ReadString(std::string *out_str); - /*! - * \brief Read Number. - * \param out_value output value; - * \throw dmlc::Error when next token is not number of ValueType. - * \tparam ValueType type of the number - */ - template - inline void ReadNumber(ValueType *out_value); - /*! - * \brief Begin parsing an object. - * \code - * std::string key; - * // value can be any type that is json serializable. - * std::string value; - * reader->BeginObject(); - * while (reader->NextObjectItem(&key)) { - * // do somthing to key value - * reader->Read(&value); - * } - * \endcode - */ - inline void BeginObject(); - /*! - * \brief Begin parsing an array. - * \code - * // value can be any type that is json serializable. - * std::string value; - * reader->BeginArray(); - * while (reader->NextObjectArrayItem(&value)) { - * // do somthing to value - * } - * \endcode - */ - inline void BeginArray(); - /*! - * \brief Try to move to next object item. - * If this call is successful, user can proceed to call - * reader->Read to read in the value. - * \param out_key the key to the next object. - * \return true if the read is successful, false if we are at end of the object. - */ - inline bool NextObjectItem(std::string *out_key); - /*! - * \brief Try to read the next element in the array. - * If this call is successful, user can proceed to call - * reader->Read to read in the value. - * \return true if the read is successful, false if we are at end of the array. - */ - inline bool NextArrayItem(); - /*! - * \brief Read next ValueType. - * \param out_value any STL or json readable type to be read - * \throw dmlc::Error when the read of ValueType is not successful. - * \tparam ValueType the data type to be read. - */ - template - inline void Read(ValueType *out_value); - - /*! \return current line count */ - inline std::string line_info() const { -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - char temp[64]; - std::ostringstream os; - os << " Line " << std::max(line_count_r_, line_count_n_); - is_->getline(temp, 64); - os << ", around ^`" << temp << "`"; - return os.str(); -#else - std::string info = " Line "; - info += std::to_string(std::max(line_count_r_, line_count_n_)); - - // string getline - size_t end_pos = is_->find('\n'); - end_pos = std::min((size_t)64, - end_pos == std::string::npos ? is_->size() : end_pos); - std::string line = is_->substr(0, end_pos); - is_->erase(0, line.size() + 1); // +1 for \n - - info += ", around ^`" + line + "`"; - return info; -#endif - } - - private: -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - /*! \brief internal reader stream */ - std::istream *is_; -#else - /*! \brief internal reader string */ - std::string *is_; -#endif - /*! \brief "\\r" counter */ - size_t line_count_r_; - /*! \brief "\\n" counter */ - size_t line_count_n_; - /*! - * \brief record how many element processed in - * current array/object scope. - */ - std::vector scope_counter_; - /*! - * \brief Read next nonspace character. - * \return the next nonspace character. - */ - inline int NextNonSpace(); - /*! - * \brief Read just before next nonspace but not read that. - * \return the next nonspace character. - */ - inline int PeekNextNonSpace(); - /*! - * \brief Takes the next char from the input source. - * \return the next character. - */ - inline int NextChar(); - /*! - * \brief Returns the next char from the input source. - * \return the next character. - */ - inline int PeekNextChar(); -}; - -/*! - * \brief Lightweight json to write any STL compositions. - */ -class JSONWriter { - public: - /*! - * \brief Constructor. - * \param os the output reciever. - */ -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - explicit JSONWriter(std::ostream *os) -#else - explicit JSONWriter(std::string *os) -#endif - : os_(os) {} - /*! - * \brief Write a string that do not contain escape characters. - * \param s the string to be written. - */ - inline void WriteNoEscape(const std::string &s); - /*! - * \brief Write a string that can contain escape characters. - * \param s the string to be written. - */ - inline void WriteString(const std::string &s); - /*! - * \brief Write a string that can contain escape characters. - * \param v the value to be written. - * \tparam ValueType The value type to be written. - */ - template - inline void WriteNumber(const ValueType &v); - /*! - * \brief Start beginning of array. - * \param multi_line whether to start an multi_line array. - * \code - * writer->BeginArray(); - * for (auto& v : vdata) { - * writer->WriteArrayItem(v); - * } - * writer->EndArray(); - * \endcode - */ - inline void BeginArray(bool multi_line = true); - /*! \brief Finish writing an array. */ - inline void EndArray(); - /*! - * \brief Start beginning of array. - * \param multi_line whether to start an multi_line array. - * \code - * writer->BeginObject(); - * for (auto& kv : vmap) { - * writer->WriteObjectKeyValue(kv.first, kv.second); - * } - * writer->EndObject(); - * \endcode - */ - inline void BeginObject(bool multi_line = true); - /*! \brief Finish writing object. */ - inline void EndObject(); - /*! - * \brief Write key value pair in the object. - * \param key the key of the object. - * \param value the value of to be written. - * \tparam ValueType The value type to be written. - */ - template - inline void WriteObjectKeyValue(const std::string &key, - const ValueType &value); - /*! - * \brief Write seperator of array, before writing next element. - * User can proceed to call writer->Write to write next item - */ - inline void WriteArraySeperator(); - /*! - * \brief Write value into array. - * \param value The value of to be written. - * \tparam ValueType The value type to be written. - */ - template - inline void WriteArrayItem(const ValueType &value); - /*! - * \brief Write value to json. - * \param value any STL or json readable that can be written. - * \tparam ValueType the data type to be write. - */ - template - inline void Write(const ValueType &value); - - private: -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - /*! \brief Output stream */ - std::ostream *os_; -#else - std::string *os_; -#endif - /*! - * \brief record how many element processed in - * current array/object scope. - */ - std::vector scope_counter_; - /*! \brief Record whether current is a multiline scope */ - std::vector scope_multi_line_; - /*! - * \brief Write seperating space and newlines - */ - inline void WriteSeperator(); -}; - -/*! - * \brief Helper class to read JSON into a class or struct object. - * \code - * struct Param { - * std::string name; - * int value; - * // define load function from JSON - * inline void Load(dmlc::JSONReader *reader) { - * dmlc::JSONStructReadHelper helper; - * helper.DeclareField("name", &name); - * helper.DeclareField("value", &value); - * helper.ReadAllFields(reader); - * } - * }; - * \endcode - */ -class JSONObjectReadHelper { - public: - /*! - * \brief Declare field of type T - * \param key the key of the of field. - * \param addr address of the data type. - * \tparam T the data type to be read, must be STL composition of JSON serializable. - */ - template - inline void DeclareField(const std::string &key, T *addr) { - DeclareFieldInternal(key, addr, false); - } - /*! - * \brief Declare optional field of type T - * \param key the key of the of field. - * \param addr address of the data type. - * \tparam T the data type to be read, must be STL composition of JSON serializable. - */ - template - inline void DeclareOptionalField(const std::string &key, T *addr) { - DeclareFieldInternal(key, addr, true); - } - /*! - * \brief Read in all the declared fields. - * \param reader the JSONReader to read the json. - */ - inline void ReadAllFields(JSONReader *reader); - - private: - /*! - * \brief Internal function to declare field. - * \param key the key of the of field. - * \param addr address of the data type. - * \param optional if set to true, no error will be reported if the key is not presented. - * \tparam T the data type to be read, must be STL composition of JSON serializable. - */ - template - inline void DeclareFieldInternal(const std::string &key, T *addr, bool optional); - /*! - * \brief The internal reader function. - * \param reader The reader to read. - * \param addr The memory address to read. - */ - template - inline static void ReaderFunction(JSONReader *reader, void *addr); - /*! \brief callback type to reader function */ - typedef void (*ReadFunction)(JSONReader *reader, void *addr); - /*! \brief internal data entry */ - struct Entry { - /*! \brief the reader function */ - ReadFunction func; - /*! \brief the address to read */ - void *addr; - /*! \brief whether it is optional */ - bool optional; - }; - /*! \brief the internal map of reader callbacks */ - std::map map_; -}; - -#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \ - static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \ - __make_AnyJSONType ## _ ## KeyName ## __ - -/*! - * \def DMLC_JSON_ENABLE_ANY - * \brief Macro to enable save/load JSON of dmlc:: whose actual type is Type. - * Any type will be saved as json array [KeyName, content] - * - * \param Type The type to be registered. - * \param KeyName The Type key assigned to the type, must be same during load. - */ -#define DMLC_JSON_ENABLE_ANY(Type, KeyName) \ - DMLC_STR_CONCAT(DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName), __COUNTER__) = \ - ::dmlc::json::AnyJSONManager::Global()->EnableType(#KeyName) \ - -//! \cond Doxygen_Suppress -namespace json { - -/*! - * \brief generic serialization handler - * \tparam T the type to be serialized - */ -template -struct Handler; - -template -struct NumericHandler { - inline static void Write(JSONWriter *writer, const ValueType &value) { - writer->WriteNumber(value); - } - inline static void Read(JSONReader *reader, ValueType *value) { - reader->ReadNumber(value); - } -}; - -template -struct ArrayHandler { - inline static void Write(JSONWriter *writer, const ContainerType &array) { - typedef typename ContainerType::value_type ElemType; - writer->BeginArray(array.size() > 10 || !dmlc::is_pod::value); - for (typename ContainerType::const_iterator it = array.begin(); - it != array.end(); ++it) { - writer->WriteArrayItem(*it); - } - writer->EndArray(); - } - inline static void Read(JSONReader *reader, ContainerType *array) { - typedef typename ContainerType::value_type ElemType; - array->clear(); - reader->BeginArray(); - while (reader->NextArrayItem()) { - ElemType value; - Handler::Read(reader, &value); - array->insert(array->end(), value); - } - } -}; - -template -struct MapHandler{ - inline static void Write(JSONWriter *writer, const ContainerType &map) { - writer->BeginObject(map.size() > 1); - for (typename ContainerType::const_iterator it = map.begin(); it != map.end(); ++it) { - writer->WriteObjectKeyValue(it->first, it->second); - } - writer->EndObject(); - } - inline static void Read(JSONReader *reader, ContainerType *map) { - typedef typename ContainerType::mapped_type ElemType; - map->clear(); - reader->BeginObject(); - std::string key; - while (reader->NextObjectItem(&key)) { - ElemType value; - reader->Read(&value); - (*map)[key] = value; - } - } -}; - -template -struct CommonJSONSerializer { - inline static void Write(JSONWriter *writer, const T &value) { - value.Save(writer); - } - inline static void Read(JSONReader *reader, T *value) { - value->Load(reader); - } -}; - -template<> -struct Handler { - inline static void Write(JSONWriter *writer, const std::string &value) { - writer->WriteString(value); - } - inline static void Read(JSONReader *reader, std::string *str) { - reader->ReadString(str); - } -}; - -template -struct Handler > : public ArrayHandler > { -}; - -template -struct Handler > { - inline static void Write(JSONWriter *writer, const std::pair &kv) { - writer->BeginArray(); - writer->WriteArrayItem(kv.first); - writer->WriteArrayItem(kv.second); - writer->EndArray(); - } - inline static void Read(JSONReader *reader, std::pair *kv) { - reader->BeginArray(); - CHECK(reader->NextArrayItem()) - << "Expect array of length 2"; - Handler::Read(reader, &(kv->first)); - CHECK(reader->NextArrayItem()) - << "Expect array of length 2"; - Handler::Read(reader, &(kv->second)); - CHECK(!reader->NextArrayItem()) - << "Expect array of length 2"; - } -}; - -template -struct Handler > : public ArrayHandler > { -}; - -template -struct Handler > : public MapHandler > { -}; - -#if DMLC_USE_CXX11 -template -struct Handler > - : public MapHandler > { -}; -#endif // DMLC_USE_CXX11 - -template -struct Handler { - inline static void Write(JSONWriter *writer, const T &data) { - typedef typename dmlc::IfThenElseType::value, - NumericHandler, - CommonJSONSerializer >::Type THandler; - THandler::Write(writer, data); - } - inline static void Read(JSONReader *reader, T *data) { - typedef typename dmlc::IfThenElseType::value, - NumericHandler, - CommonJSONSerializer >::Type THandler; - THandler::Read(reader, data); - } -}; - -#if DMLC_STRICT_CXX11 -#if DMLC_ENABLE_RTTI -// Manager to store json serialization strategy. -class AnyJSONManager { - public: - template - inline AnyJSONManager& EnableType(const std::string& type_name) { // NOLINT(*) - std::type_index tp = std::type_index(typeid(T)); - if (type_name_.count(tp) != 0) { - CHECK(type_name_.at(tp) == type_name) - << "Type has already been registered as another typename " << type_name_.at(tp); - return *this; - } - CHECK(type_map_.count(type_name) == 0) - << "Type name " << type_name << " already registered in registry"; - Entry e; - e.read = ReadAny; - e.write = WriteAny; - type_name_[tp] = type_name; - type_map_[type_name] = e; - return *this; - } - // return global singleton - inline static AnyJSONManager* Global() { - static AnyJSONManager inst; - return &inst; - } - - private: - AnyJSONManager() {} - - template - inline static void WriteAny(JSONWriter *writer, const any &data) { - writer->Write(dmlc::get(data)); - } - template - inline static void ReadAny(JSONReader *reader, any* data) { - T temp; - reader->Read(&temp); - *data = std::move(temp); - } - // data entry to store vtable for any type - struct Entry { - void (*read)(JSONReader* reader, any *data); - void (*write)(JSONWriter* reader, const any& data); - }; - - template - friend struct Handler; - - std::unordered_map type_name_; - std::unordered_map type_map_; -}; - -template<> -struct Handler { - inline static void Write(JSONWriter *writer, const any &data) { - std::unordered_map& - nmap = AnyJSONManager::Global()->type_name_; - std::type_index id = std::type_index(data.type()); - auto it = nmap.find(id); - CHECK(it != nmap.end() && it->first == id) - << "Type " << id.name() << " has not been registered via DMLC_JSON_ENABLE_ANY"; - std::string type_name = it->second; - AnyJSONManager::Entry e = AnyJSONManager::Global()->type_map_.at(type_name); - writer->BeginArray(false); - writer->WriteArrayItem(type_name); - writer->WriteArraySeperator(); - e.write(writer, data); - writer->EndArray(); - } - inline static void Read(JSONReader *reader, any *data) { - std::string type_name; - reader->BeginArray(); - CHECK(reader->NextArrayItem()) << "invalid any json format"; - Handler::Read(reader, &type_name); - std::unordered_map& - tmap = AnyJSONManager::Global()->type_map_; - auto it = tmap.find(type_name); - CHECK(it != tmap.end() && it->first == type_name) - << "Typename " << type_name << " has not been registered via DMLC_JSON_ENABLE_ANY"; - AnyJSONManager::Entry e = it->second; - CHECK(reader->NextArrayItem()) << "invalid any json format"; - e.read(reader, data); - CHECK(!reader->NextArrayItem()) << "invalid any json format"; - } -}; -#endif // DMLC_ENABLE_RTTI -#endif // DMLC_STRICT_CXX11 - -} // namespace json - -// implementations of JSONReader/Writer -inline int JSONReader::NextChar() { -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - return is_->get(); -#else - int ch = is_->at(0); - is_->erase(0, 1); - return ch; -#endif -} - -inline int JSONReader::PeekNextChar() { -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - return is_->peek(); -#else - return is_->at(0); -#endif -} - -inline int JSONReader::NextNonSpace() { - int ch; - do { - ch = NextChar(); - if (ch == '\n') ++line_count_n_; - if (ch == '\r') ++line_count_r_; - } while (isspace(ch)); - return ch; -} - -inline int JSONReader::PeekNextNonSpace() { - int ch; - while (true) { - ch = PeekNextChar(); - if (ch == '\n') ++line_count_n_; - if (ch == '\r') ++line_count_r_; - if (!isspace(ch)) break; - NextChar(); - } - return ch; -} - -namespace { - template -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - void Extend(std::ostream *os, T item) { - *os << item; - } -#else - void Extend(std::string *ostr, T item) { - *ostr += item; - } -#endif -} // namespace - -inline void JSONReader::ReadString(std::string *out_str) { - int ch = NextNonSpace(); - CHECK_EQ(ch, '\"') - << "Error at" << line_info() - << ", Expect \'\"\' but get \'" << static_cast(ch) << '\''; -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - std::ostringstream output; -#else - std::string output = ""; -#endif - while (true) { - ch = NextChar(); - if (ch == '\\') { - char sch = static_cast(NextChar()); - switch (sch) { - case 'r': Extend(&output, "\r"); break; - case 'n': Extend(&output, "\n"); break; - case '\\': Extend(&output, "\\"); break; - case 't': Extend(&output, "\t"); break; - case '\"': Extend(&output, "\""); break; - default: LOG(FATAL) << "unknown string escape \\" << sch; - } - } else { - if (ch == '\"') break; - Extend(&output, static_cast(ch)); - } - if (ch == EOF || ch == '\r' || ch == '\n') { - LOG(FATAL) - << "Error at" << line_info() - << ", Expect \'\"\' but reach end of line "; - } - } -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - *out_str = output.str(); -#else - *out_str = output; -#endif -} - -template -inline void JSONReader::ReadNumber(ValueType *out_value) { -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - *is_ >> *out_value; - CHECK(!is_->fail()) - << "Error at" << line_info() - << ", Expect number"; -#else - char* endptr; - const char* icstr = is_->c_str(); - unsigned number = strtol(icstr, &endptr, 10); - is_->erase(0, endptr - icstr); - *out_value = static_cast(number); -#endif -} - -inline void JSONReader::BeginObject() { - int ch = NextNonSpace(); - CHECK_EQ(ch, '{') - << "Error at" << line_info() - << ", Expect \'{\' but get \'" << static_cast(ch) << '\''; - scope_counter_.push_back(0); -} - -inline void JSONReader::BeginArray() { - int ch = NextNonSpace(); - CHECK_EQ(ch, '[') - << "Error at" << line_info() - << ", Expect \'{\' but get \'" << static_cast(ch) << '\''; - scope_counter_.push_back(0); -} - -inline bool JSONReader::NextObjectItem(std::string *out_key) { - bool next = true; - if (scope_counter_.back() != 0) { - int ch = NextNonSpace(); - if (ch == EOF) { - next = false; - } else if (ch == '}') { - next = false; - } else { - CHECK_EQ(ch, ',') - << "Error at" << line_info() - << ", JSON object expect \'}\' or \',\' \'" << static_cast(ch) << '\''; - } - } else { - int ch = PeekNextNonSpace(); - if (ch == '}') { - NextChar(); - next = false; - } - } - if (!next) { - scope_counter_.pop_back(); - return false; - } else { - scope_counter_.back() += 1; - ReadString(out_key); - int ch = NextNonSpace(); - CHECK_EQ(ch, ':') - << "Error at" << line_info() - << ", Expect \':\' but get \'" << static_cast(ch) << '\''; - return true; - } -} - -inline bool JSONReader::NextArrayItem() { - bool next = true; - if (scope_counter_.back() != 0) { - int ch = NextNonSpace(); - if (ch == EOF) { - next = false; - } else if (ch == ']') { - next = false; - } else { - CHECK_EQ(ch, ',') - << "Error at" << line_info() - << ", JSON array expect \']\' or \',\'. Get \'" << static_cast(ch) << "\' instead"; - } - } else { - int ch = PeekNextNonSpace(); - if (ch == ']') { - NextChar(); - next = false; - } - } - if (!next) { - scope_counter_.pop_back(); - return false; - } else { - scope_counter_.back() += 1; - return true; - } -} - -template -inline void JSONReader::Read(ValueType *out_value) { - json::Handler::Read(this, out_value); -} - -inline void JSONWriter::WriteNoEscape(const std::string &s) { - Extend(os_, '\"'); - Extend(os_, s); - Extend(os_, '\"'); -} - -inline void JSONWriter::WriteString(const std::string &s) { - Extend(os_, '\"'); - for (size_t i = 0; i < s.length(); ++i) { - char ch = s[i]; - switch (ch) { - case '\r': Extend(os_, "\\r"); break; - case '\n': Extend(os_, "\\n"); break; - case '\\': Extend(os_, "\\\\"); break; - case '\t': Extend(os_, "\\t"); break; - case '\"': Extend(os_, "\\\""); break; - default: Extend(os_, ch); - } - } - Extend(os_, '\"'); -} - -template -inline void JSONWriter::WriteNumber(const ValueType &v) { -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - Extend(os_, v); -#else - Extend(os_, std::to_string(v)); -#endif -} - -inline void JSONWriter::BeginArray(bool multi_line) { - Extend(os_, '['); - scope_multi_line_.push_back(multi_line); - scope_counter_.push_back(0); -} - -inline void JSONWriter::EndArray() { - CHECK_NE(scope_multi_line_.size(), 0U); - CHECK_NE(scope_counter_.size(), 0U); - bool newline = scope_multi_line_.back(); - size_t nelem = scope_counter_.back(); - scope_multi_line_.pop_back(); - scope_counter_.pop_back(); - if (newline && nelem != 0) WriteSeperator(); - Extend(os_, ']'); -} - -inline void JSONWriter::BeginObject(bool multi_line) { - Extend(os_, '{'); - scope_multi_line_.push_back(multi_line); - scope_counter_.push_back(0); -} - -inline void JSONWriter::EndObject() { - CHECK_NE(scope_multi_line_.size(), 0U); - CHECK_NE(scope_counter_.size(), 0U); - bool newline = scope_multi_line_.back(); - size_t nelem = scope_counter_.back(); - scope_multi_line_.pop_back(); - scope_counter_.pop_back(); - if (newline && nelem != 0) WriteSeperator(); - Extend(os_, '}'); -} - -template -inline void JSONWriter::WriteObjectKeyValue(const std::string &key, - const ValueType &value) { - if (scope_counter_.back() > 0) { - Extend(os_, ", "); - } - WriteSeperator(); - Extend(os_, '\"'); - Extend(os_, key); - Extend(os_, "\": "); - scope_counter_.back() += 1; - json::Handler::Write(this, value); -} - -inline void JSONWriter::WriteArraySeperator() { - if (scope_counter_.back() != 0) { - Extend(os_, ", "); - } - scope_counter_.back() += 1; - WriteSeperator(); -} - -template -inline void JSONWriter::WriteArrayItem(const ValueType &value) { - this->WriteArraySeperator(); - json::Handler::Write(this, value); -} - -template -inline void JSONWriter::Write(const ValueType &value) { - size_t nscope = scope_multi_line_.size(); - json::Handler::Write(this, value); - CHECK_EQ(nscope, scope_multi_line_.size()) - << "Uneven scope, did you call EndArray/EndObject after each BeginObject/Array?"; -} - -inline void JSONWriter::WriteSeperator() { - if (scope_multi_line_.size() == 0 || scope_multi_line_.back()) { - Extend(os_, '\n'); - Extend(os_, std::string(scope_multi_line_.size() * 2, ' ')); - } -} - -inline void JSONObjectReadHelper::ReadAllFields(JSONReader *reader) { - reader->BeginObject(); - std::map visited; - std::string key; - while (reader->NextObjectItem(&key)) { - if (map_.count(key) != 0) { - Entry e = map_[key]; - (*e.func)(reader, e.addr); - visited[key] = 0; - } else { -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - std::ostringstream err; -#else - std::string err(""); -#endif - Extend(&err, "JSONReader: Unknown field "); - Extend(&err, key); - Extend(&err, ", candidates are: \n"); - for (std::map::iterator - it = map_.begin(); it != map_.end(); ++it) { - Extend(&err, '\"'); - Extend(&err, it->first); - Extend(&err, "\"\n"); - } -#ifndef _LIBCPP_SGX_NO_IOSTREAMS - LOG(FATAL) << err.str(); -#else - LOG(FATAL) << err; -#endif - } - } - if (visited.size() != map_.size()) { - for (std::map::iterator - it = map_.begin(); it != map_.end(); ++it) { - if (it->second.optional) continue; - CHECK_NE(visited.count(it->first), 0U) - << "JSONReader: Missing field \"" << it->first << "\"\n At " - << reader->line_info(); - } - } -} - -template -inline void JSONObjectReadHelper::ReaderFunction(JSONReader *reader, void *addr) { - json::Handler::Read(reader, static_cast(addr)); -} - -template -inline void JSONObjectReadHelper:: -DeclareFieldInternal(const std::string &key, T *addr, bool optional) { - CHECK_EQ(map_.count(key), 0U) - << "Adding duplicate field " << key; - Entry e; - e.func = ReaderFunction; - e.addr = static_cast(addr); - e.optional = optional; - map_[key] = e; -} - -//! \endcond -} // namespace dmlc -#endif // DMLC_JSON_H_ diff --git a/include/dmlc/logging.h b/include/dmlc/logging.h deleted file mode 100644 index 8e7878bd41d3..000000000000 --- a/include/dmlc/logging.h +++ /dev/null @@ -1,424 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file logging.h - * \brief defines logging macros of dmlc - * allows use of GLOG, fall back to internal - * implementation when disabled - */ -#ifndef DMLC_LOGGING_H_ -#define DMLC_LOGGING_H_ -#include -#include -#include -#include -#include -#include -#include "./base.h" - -#if DMLC_LOG_STACK_TRACE -#include -#endif - -#if DMLC_LOG_STACK_TRACE -#include -#endif - -namespace dmlc { -/*! - * \brief exception class that will be thrown by - * default logger if DMLC_LOG_FATAL_THROW == 1 - */ -struct Error : public std::runtime_error { - /*! - * \brief constructor - * \param s the error message - */ - explicit Error(const std::string &s) : std::runtime_error(s) {} -}; -} // namespace dmlc - -#if DMLC_USE_GLOG -#include - -namespace dmlc { -/*! - * \brief optionally redirect to google's init log - * \param argv0 The arguments. - */ -inline void InitLogging(const char* argv0) { - google::InitGoogleLogging(argv0); -} -} // namespace dmlc - -#else -// use a light version of glog -#include -#include -#include -#include - -#if defined(_MSC_VER) -#pragma warning(disable : 4722) -#pragma warning(disable : 4068) -#endif - -namespace dmlc { -inline void InitLogging(const char*) { - // DO NOTHING -} - -class LogCheckError { - public: - LogCheckError() : str(nullptr) {} - explicit LogCheckError(const std::string& str_) : str(new std::string(str_)) {} - ~LogCheckError() { if (str != nullptr) delete str; } - operator bool() {return str != nullptr; } - std::string* str; -}; - -#ifndef DMLC_GLOG_DEFINED - -#ifndef _LIBCPP_SGX_NO_IOSTREAMS -#define DEFINE_CHECK_FUNC(name, op) \ - template \ - inline LogCheckError LogCheck##name(const X& x, const Y& y) { \ - if (x op y) return LogCheckError(); \ - std::ostringstream os; \ - os << " (" << x << " vs. " << y << ") "; /* CHECK_XX(x, y) requires x and y can be serialized to string. Use CHECK(x OP y) otherwise. NOLINT(*) */ \ - return LogCheckError(os.str()); \ - } \ - inline LogCheckError LogCheck##name(int x, int y) { \ - return LogCheck##name(x, y); \ - } -#else -#define DEFINE_CHECK_FUNC(name, op) \ - template \ - inline LogCheckError LogCheck##name(const X& x, const Y& y) { \ - if (x op y) return LogCheckError(); \ - return LogCheckError("Error."); \ - } \ - inline LogCheckError LogCheck##name(int x, int y) { \ - return LogCheck##name(x, y); \ - } -#endif - -#define CHECK_BINARY_OP(name, op, x, y) \ - if (dmlc::LogCheckError _check_err = dmlc::LogCheck##name(x, y)) \ - dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ - << "Check failed: " << #x " " #op " " #y << *(_check_err.str) - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wsign-compare" -DEFINE_CHECK_FUNC(_LT, <) -DEFINE_CHECK_FUNC(_GT, >) -DEFINE_CHECK_FUNC(_LE, <=) -DEFINE_CHECK_FUNC(_GE, >=) -DEFINE_CHECK_FUNC(_EQ, ==) -DEFINE_CHECK_FUNC(_NE, !=) -#pragma GCC diagnostic pop - -// Always-on checking -#define CHECK(x) \ - if (!(x)) \ - dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \ - << "Check failed: " #x << ' ' -#define CHECK_LT(x, y) CHECK_BINARY_OP(_LT, <, x, y) -#define CHECK_GT(x, y) CHECK_BINARY_OP(_GT, >, x, y) -#define CHECK_LE(x, y) CHECK_BINARY_OP(_LE, <=, x, y) -#define CHECK_GE(x, y) CHECK_BINARY_OP(_GE, >=, x, y) -#define CHECK_EQ(x, y) CHECK_BINARY_OP(_EQ, ==, x, y) -#define CHECK_NE(x, y) CHECK_BINARY_OP(_NE, !=, x, y) -#define CHECK_NOTNULL(x) \ - ((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*) -// Debug-only checking. -#ifdef NDEBUG -#define DCHECK(x) \ - while (false) CHECK(x) -#define DCHECK_LT(x, y) \ - while (false) CHECK((x) < (y)) -#define DCHECK_GT(x, y) \ - while (false) CHECK((x) > (y)) -#define DCHECK_LE(x, y) \ - while (false) CHECK((x) <= (y)) -#define DCHECK_GE(x, y) \ - while (false) CHECK((x) >= (y)) -#define DCHECK_EQ(x, y) \ - while (false) CHECK((x) == (y)) -#define DCHECK_NE(x, y) \ - while (false) CHECK((x) != (y)) -#else -#define DCHECK(x) CHECK(x) -#define DCHECK_LT(x, y) CHECK((x) < (y)) -#define DCHECK_GT(x, y) CHECK((x) > (y)) -#define DCHECK_LE(x, y) CHECK((x) <= (y)) -#define DCHECK_GE(x, y) CHECK((x) >= (y)) -#define DCHECK_EQ(x, y) CHECK((x) == (y)) -#define DCHECK_NE(x, y) CHECK((x) != (y)) -#endif // NDEBUG - -#if DMLC_LOG_CUSTOMIZE -#define LOG_INFO dmlc::CustomLogMessage(__FILE__, __LINE__) -#else -#define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__) -#endif -#define LOG_ERROR LOG_INFO -#define LOG_WARNING LOG_INFO -#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__) -#define LOG_QFATAL LOG_FATAL - -// Poor man version of VLOG -#define VLOG(x) LOG_INFO.stream() - -#define LOG(severity) LOG_##severity.stream() -#define LG LOG_INFO.stream() -#define LOG_IF(severity, condition) \ - !(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) - -#ifdef NDEBUG -#define LOG_DFATAL LOG_ERROR -#define DFATAL ERROR -#define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) -#define DLOG_IF(severity, condition) \ - (true || !(condition)) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) -#else -#define LOG_DFATAL LOG_FATAL -#define DFATAL FATAL -#define DLOG(severity) LOG(severity) -#define DLOG_IF(severity, condition) LOG_IF(severity, condition) -#endif - -// Poor man version of LOG_EVERY_N -#define LOG_EVERY_N(severity, n) LOG(severity) - -#endif // DMLC_GLOG_DEFINED - -class DateLogger { - public: - DateLogger() { -#if defined(_MSC_VER) - _tzset(); -#endif - } - const char* HumanDate() { -#ifndef _LIBCPP_SGX_CONFIG -#if defined(_MSC_VER) - _strtime_s(buffer_, sizeof(buffer_)); -#else - time_t time_value = time(NULL); - struct tm *pnow; -#if !defined(_WIN32) - struct tm now; - pnow = localtime_r(&time_value, &now); -#else - pnow = localtime(&time_value); // NOLINT(*) -#endif - snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d", - pnow->tm_hour, pnow->tm_min, pnow->tm_sec); -#endif -#endif // _LIBCPP_SGX_CONFIG - return buffer_; - } - - private: - char buffer_[9]; -}; - -#ifndef _LIBCPP_SGX_NO_IOSTREAMS -class LogMessage { - public: - LogMessage(const char* file, int line) - : -#ifdef __ANDROID__ - log_stream_(std::cout) -#else - log_stream_(std::cerr) -#endif - { - log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" - << line << ": "; - } - ~LogMessage() { log_stream_ << '\n'; } - std::ostream& stream() { return log_stream_; } - - protected: - std::ostream& log_stream_; - - private: - DateLogger pretty_date_; - LogMessage(const LogMessage&); - void operator=(const LogMessage&); -}; - -// customized logger that can allow user to define where to log the message. -class CustomLogMessage { - public: - CustomLogMessage(const char* file, int line) { - log_stream_ << "[" << DateLogger().HumanDate() << "] " << file << ":" - << line << ": "; - } - ~CustomLogMessage() { - Log(log_stream_.str()); - } - std::ostream& stream() { return log_stream_; } - /*! - * \brief customized logging of the message. - * This function won't be implemented by libdmlc - * \param msg The message to be logged. - */ - static void Log(const std::string& msg); - - private: - std::ostringstream log_stream_; -}; -#else -class DummyOStream { - public: - template - DummyOStream& operator<<(T _) { return *this; } - inline std::string str() { return ""; } -}; -class LogMessage { - public: - LogMessage(const char* file, int line) : log_stream_() {} - DummyOStream& stream() { return log_stream_; } - - protected: - DummyOStream log_stream_; - - private: - LogMessage(const LogMessage&); - void operator=(const LogMessage&); -}; -#endif - - - -#if DMLC_LOG_STACK_TRACE -inline std::string Demangle(char const *msg_str) { - using std::string; - string msg(msg_str); - size_t symbol_start = string::npos; - size_t symbol_end = string::npos; - if ( ((symbol_start = msg.find("_Z")) != string::npos) - && (symbol_end = msg.find_first_of(" +", symbol_start)) ) { - string left_of_symbol(msg, 0, symbol_start); - string symbol(msg, symbol_start, symbol_end - symbol_start); - string right_of_symbol(msg, symbol_end); - - int status = 0; - size_t length = string::npos; - std::unique_ptr demangled_symbol = - {abi::__cxa_demangle(symbol.c_str(), 0, &length, &status), &std::free}; - if (demangled_symbol && status == 0 && length > 0) { - string symbol_str(demangled_symbol.get()); - std::ostringstream os; - os << left_of_symbol << symbol_str << right_of_symbol; - return os.str(); - } - } - return string(msg_str); -} - -inline std::string StackTrace() { - using std::string; - std::ostringstream stacktrace_os; - const int MAX_STACK_SIZE = DMLC_LOG_STACK_TRACE_SIZE; - void *stack[MAX_STACK_SIZE]; - int nframes = backtrace(stack, MAX_STACK_SIZE); - stacktrace_os << "Stack trace returned " << nframes << " entries:" << std::endl; - char **msgs = backtrace_symbols(stack, nframes); - if (msgs != nullptr) { - for (int frameno = 0; frameno < nframes; ++frameno) { - string msg = dmlc::Demangle(msgs[frameno]); - stacktrace_os << "[bt] (" << frameno << ") " << msg << "\n"; - } - } - free(msgs); - string stack_trace = stacktrace_os.str(); - return stack_trace; -} - -#else // DMLC_LOG_STACK_TRACE is off - -inline std::string demangle(char const* msg_str) { - return std::string(); -} - -inline std::string StackTrace() { - return std::string("stack traces not available when " - "DMLC_LOG_STACK_TRACE is disabled at compile time."); -} - -#endif // DMLC_LOG_STACK_TRACE - -#if defined(_LIBCPP_SGX_NO_IOSTREAMS) -class LogMessageFatal : public LogMessage { - public: - LogMessageFatal(const char* file, int line) : LogMessage(file, line) {} - ~LogMessageFatal() { - abort(); - } - private: - LogMessageFatal(const LogMessageFatal&); - void operator=(const LogMessageFatal&); -}; -#elif DMLC_LOG_FATAL_THROW == 0 -class LogMessageFatal : public LogMessage { - public: - LogMessageFatal(const char* file, int line) : LogMessage(file, line) {} - ~LogMessageFatal() { - log_stream_ << "\n\n" << StackTrace() << "\n"; - abort(); - } - - private: - LogMessageFatal(const LogMessageFatal&); - void operator=(const LogMessageFatal&); -}; -#else -class LogMessageFatal { - public: - LogMessageFatal(const char* file, int line) { - log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" - << line << ": "; - } - std::ostringstream &stream() { return log_stream_; } - ~LogMessageFatal() DMLC_THROW_EXCEPTION { -#if DMLC_LOG_STACK_TRACE - log_stream_ << "\n\n" << StackTrace() << "\n"; -#endif - - // throwing out of destructor is evil - // hopefully we can do it here - // also log the message before throw -#if DMLC_LOG_BEFORE_THROW - LOG(ERROR) << log_stream_.str(); -#endif - throw Error(log_stream_.str()); - } - - private: - std::ostringstream log_stream_; - DateLogger pretty_date_; - LogMessageFatal(const LogMessageFatal&); - void operator=(const LogMessageFatal&); -}; -#endif - -// This class is used to explicitly ignore values in the conditional -// logging macros. This avoids compiler warnings like "value computed -// is not used" and "statement has no effect". -class LogMessageVoidify { - public: - LogMessageVoidify() {} - // This has to be an operator with a precedence lower than << but - // higher than "?:". See its usage. -#if !defined(_LIBCPP_SGX_NO_IOSTREAMS) - void operator&(std::ostream&) {} -#endif -}; - -} // namespace dmlc - -#endif -#endif // DMLC_LOGGING_H_ diff --git a/include/dmlc/lua.h b/include/dmlc/lua.h deleted file mode 100644 index 13aa7b73d269..000000000000 --- a/include/dmlc/lua.h +++ /dev/null @@ -1,739 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file lua.h - * \brief C++11 header only interface to easily interact with Lua and Torch. - * This code is evolved from torch plugin code for MXNet. - * - * This header will require Torch and Lua to be presented, do not include. - * - * \author Junyuan Xie, Min Lin, Tianqi Chen - * - * \code - * - * // Example code to use the lua module. - * dmlc::LuaState* lua = dmlc::LuaState::ThreadLocalState(); - * // vectors converts automatically to lua table. - * auto tbl = lua->Convert(std::vector{1,2,3}); - * // use eval to get lua reference, this is a function - * auto print = lua->Eval("return function(x) print(x) end"); - * // lua function can be directly called from c++, arguments are converted. - * print(100); - * - * // set field in the table. - * tbl.SetField("square", lua->Eval("return function(x) x*x end")); - * // call the function, covert back to C++ values. - * int x = tbl["square"](100).Get(); - * - * \endcode - */ -#ifndef DMLC_LUA_H_ -#define DMLC_LUA_H_ - -extern "C" { -#include -#include -#include -} - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "./base.h" -#include "./logging.h" -#include "./thread_local.h" - -namespace dmlc { - -// forward declare torch state -class LuaState; - -namespace lua_stack { -template -struct Handler; -}; - -/*! \brief an reference to lua object */ -class LuaRef { - public: - /*! \brief construct an nil ref */ - LuaRef() = default; - /*! - * \brief move constructor from another LuaRef - * \param other The other LuaRef to be moved - */ - inline LuaRef(LuaRef&& other); // NOLINT(*) - /*! - * \brief copy constructor - * \param other The other LuaRef to be copied - */ - inline LuaRef(const LuaRef& other); // NOLINT(*) - /*! - * \brief assign operator from other - * \param other The other LuaRef to be copy or moved. - * \return self - */ - inline LuaRef& operator=(LuaRef&& other); - /*! - * \brief assign operator from other - * \param other The other LuaRef to be copy or moved. - * \return self - */ - inline LuaRef& operator=(const LuaRef& other); - /*! \brief destructor */ - inline ~LuaRef(); - /*! - * \brief swap content with another ref - * \param other another LuaRef to be swaped. - */ - inline void swap(LuaRef& other); // NOLINT(*) - /*! - * \brief Get content out as type T. - * - * \tparam T the type to be fetched. - * \return the corresponding c type. - */ - template - inline T Get() const; - /*! - * \brief Get user data pointer from LuaRef - * - * CAREFUL when getting userdata(e.g. pointer to Tensor's storage) from LuaRef. - * Remember they are managed by Lua, and can get deleted when all the - * LuaRef to the userdata destructs. A good practice is always use a LuaRef to keep - * the userdata alive when you need them from C++ side. - * - * \tparam T the type of pointer to be fetched. - * \return the corresponding c type. - */ - template - inline T* GetUDataPtr() const; - /*! \return whether the value is nil */ - inline bool is_nil() const; - /*! - * \brief invoke the LuaRef as function - * \param args Arguments to be passed. - * \tparam Args arguments to be passed. - * \return The first return value. - */ - template - inline LuaRef operator()(Args&& ...args) const; - /*! - * \brief Get field from the lua table. - * The reference must be a table - * \param key The key to the table - * \return a new ref to the corresponding field. - */ - inline LuaRef operator[](const std::string& key) const; - /*! - * \brief Get field from the lua array - * The reference must be a array - * \param index The index to the array, - * Note: the index convention follows lua table, starts from 1 - * \return a new ref to the corresponding field. - */ - inline LuaRef operator[](size_t index) const; - /*! - * \brief Set field of lua table. - * The reference must be a table - * \param key The key to the table - * \param value Lua convertable value to be setted. - * \return self. - */ - template - inline LuaRef& SetField(const std::string& key, const T& value); // NOLINT(*) - /*! - * \brief Set LuaRef to the value on top of the stack. - * This state must be nil. - * This is API used by developer. - * - * \param s the corresponding lua state. - */ - inline void SetByPopStack_(LuaState* s); - - private: - // friend with luastate - friend struct lua_stack::Handler; - friend class LuaState; - friend std::ostream &operator<<(std::ostream &os, const LuaRef &r); - /*! \brief pointer to the state */ - LuaState* state_{nullptr}; - /*! \brief reference index */ - int ref_; -}; - -/*! \brief A Lua state */ -class LuaState { - public: - /*! \brief options to be provided in lua state */ - enum Option { - kNoThreadProtect, - kThreadLocal, - kLocking, - }; - /*! \brief destructor */ - inline ~LuaState(); - /*! - * \brief evaluate a piece of lua code, return the first result. - * \param lua_code Lua code - * \return A LuaRef object of the first returned result, - * Can be nil if the code did not return LuaRefthing. - */ - inline LuaRef Eval(const char* lua_code); - /*! - * \brief evaluate a piece of lua code, return the first result. - * \param lua_code Lua code - * \return A LuaRef object of the first returned result, - * Can be nil if the code did not return anything. - */ - inline LuaRef Eval(const std::string& lua_code) { - return this->Eval(lua_code.c_str()); - } - /*! - * \brief convert a C++ type to lua type - * \param value The data to be converted. - * vector, map will be converted to table. - * \return a converted value. - * \tparam T the type to be converted. - */ - template - inline LuaRef Convert(const T& value); - /*! - * \brief get global field from the state - * \param key The key to the global field. - * \return The global field value. - */ - inline LuaRef operator[](const std::string& key); - /*! - * \brief Set the value to the global table. - * \param key The key of the global field. - * \param value The value to the set. - */ - inline void SetGlobalField(const std::string& key, const LuaRef& value); - /*! - * Get a thread local version of lua state. - * The LuaState runs in thread local mode, - * all the LuaRef can only be run on the current thread. - * This is the recommended behavior when invoking Lua. - * - * \return a threadlocal version of lua state. - */ - static inline LuaState* ThreadLocalState(); - /*! - * Create a new lua state. - * \note It is highly recommended to use ThreadLocalState instead. - * - * Most Lua program assumes it only runs from the same thread. - * Some Lua code that wraps C library(e.g. Torch) could rely - * on thread_local storage to store global state such as random number generator. - * This means if the code is invoked by another thread, the thread_local - * might become inavailable, depending on the implementation. - * - * If the global state is stored only in Lua's global table, then - * it is safe to use kLocking mode and call the code from multiple thread. - * Never-the-less, using ThreadLocalState removes the need to lock, - * and is the desirable usecase in most times. - * - * \sa ThreadLocalState - * \param option The option to use the state. - * \return a newly created lua state - */ - static inline LuaState* Create_(Option option); - - /*! - * \brief protected run f, this is used by API developers. - * always call this to access lua state - * f must not destruct LuaRef, or access the mutex - * - * \param f the function to be called. - * \tparam F the function to be called, signiture (lua_State *L) - */ - template - inline void PRun_(F f); - /*! - * \param L the other lua state. - * \return if the internal lua state is same as L - */ - inline bool SameLuaState(lua_State *L) const { - return L_ == L; - } - - protected: - struct StackReset; - friend class LuaRef; - friend struct ThreadLocalStore; - /*! - * \brief constructor - */ - inline LuaState(); - - /*! \brief internal option, default to thread local */ - Option option_{kThreadLocal}; - /*! \brief internal lua state */ - lua_State* L_; - /*! \brief internal lock about the state */ - std::mutex mutex_; -}; - -// implementations after this line -//! \cond Doxygen_Suppress -/*! \brief macro to check error during lua call */ -#define LUA_CALL(x) \ - if ((x)) { \ - LOG(FATAL) << "Lua Call Error:" << lua_tostring(L, -1); \ - } - -/*! - * \brief namespace to handle conversions between lua and c++ - * User can provide an specialization of dmlc::lua_stack::Handler - * to allow customized c++ data types to interact with Lua. - * - * By default basic data types, composition of vector, and unordered_map is supported. - * The conversion rules - * - basic types(string, int, float) to corresponding lua types. - * - unordered_map to Lua table. - * - vector to lua indexed table. - */ -namespace lua_stack { -inline int lua_abs_index(lua_State* L, int index) { - if (index > 0 || index <= LUA_REGISTRYINDEX) return index; - return lua_gettop(L) + index + 1; -} - -template -struct Handler; - -template -struct NumberHandler { - static inline T Get(lua_State* L, int index, LuaState* s) { - CHECK_EQ(lua_type(L, index), LUA_TNUMBER) - << "Attempt to get number but type is \'" - << lua_typename(L, lua_type(L, index)) << '\''; - if (std::is_integral::value) { - return static_cast(lua_tointeger(L, index)); - } else { - return static_cast(lua_tonumber(L, index)); - } - } - static inline void Push(lua_State* L, const T& v) { - if (std::is_integral::value) { - lua_pushinteger(L, static_cast(v)); - } else { - lua_pushnumber(L, static_cast(v)); - } - } -}; - -template -struct MapHandler { - using K = typename ContainerType::key_type; - using V = typename ContainerType::mapped_type; - static inline ContainerType Get(lua_State* L, int index, LuaState* s) { - ContainerType ret; - CHECK(lua_istable(L, index)) - << "Expected a table but get " - << lua_typename(L, lua_type(L, index)) << '\''; - int tid = lua_abs_index(L, index); - lua_pushnil(L); - while (lua_next(L, -2)) { - ret[Handler::Get(L, -2, s)] = Handler::Pop(L, -1, s); - lua_pop(L, 1); - } - lua_settop(L, tid); - return ret; - } - static inline void Push(lua_State* L, const ContainerType& v) { - lua_createtable(L, v.size(), 0); - for (const auto& kv : v) { - Handler::Push(L, kv.first); - Handler::Push(L, kv.second); - lua_settable(L, -3); - } - } -}; - -struct UndefinedHandler { -}; - -template -struct Handler - : public std::conditional::value, - NumberHandler, - UndefinedHandler>::type { -}; - -template<> -struct Handler { - static inline std::string Get(lua_State* L, int index, LuaState* s) { - CHECK_EQ(lua_type(L, index), LUA_TSTRING); - return std::string(lua_tostring(L, index)); - } - static inline void Push(lua_State* L, const std::string& v) { - lua_pushstring(L, v.c_str()); - } -}; - -template -struct Handler > { - static inline std::vector Get(lua_State* L, int index, LuaState* s) { - std::vector ret; - CHECK(lua_istable(L, index)) - << "Expected a table but get " - << lua_typename(L, lua_type(L, index)) << '\''; - int tid = lua_abs_index(L, index); - lua_pushnil(L); - while (lua_next(L, tid)) { - CHECK_EQ(Handler::Get(L, -2, s), ret.size() + 1) - << "Target table is not an array"; - ret.push_back(Handler::Get(L, -1, s)); - lua_pop(L, 1); - } - lua_settop(L, tid); - return ret; - } - static inline void Push(lua_State* L, const std::vector& v) { - lua_createtable(L, v.size(), 0); - for (size_t i = 0; i < v.size(); ++i) { - Handler::Push(L, v[i]); - lua_rawseti(L, -2, i + 1); - } - } -}; - -template -struct Handler > - : public MapHandler > { -}; - -template<> -struct Handler { - static inline LuaRef Get(lua_State* L, int index, LuaState* s) { - LuaRef ret; - lua_pushvalue(L, index); - ret.SetByPopStack_(s); - return ret; - } - - static inline void Push(lua_State* L, const LuaRef& v) { - if (v.is_nil()) { - lua_pushnil(L); - } else { - CHECK(v.state_->SameLuaState(L)) - << "Cannot pass LuaRef on a different LuaState's function"; - lua_rawgeti(L, LUA_REGISTRYINDEX, v.ref_); - } - } -}; - -template<> -struct Handler { - static inline LuaRef Get(lua_State* L, int index, LuaState* s) { - LOG(FATAL) << "not supported"; - return LuaRef(); - } - static inline void Push(lua_State* L, const std::nullptr_t& v) { - lua_pushnil(L); - } -}; - -// generic functor to call push the arguments. -struct PushArg { - lua_State* L; - template - inline void operator()(const T& v) const { - Handler::Push(L, v); - } -}; - -} // namespace lua_stack - -inline LuaState::LuaState() { - L_ = luaL_newstate(); - CHECK(L_ != nullptr) - << "Failed to create new lua state"; - luaL_openlibs(L_); -} - -inline LuaState::~LuaState() { - if (option_ != kThreadLocal && L_ != nullptr) { - // never close threadlocal, for save destruction. - lua_close(L_); - } -} - -inline LuaState* LuaState::Create_(Option opt) { - LuaState* s = new LuaState(); - s->option_ = opt; - CHECK_NE(opt, kThreadLocal) - << "use LuaState::ThreadLocalState() to get the thread local state"; - return s; -} - -inline void LuaRef::SetByPopStack_(LuaState* s) { - CHECK(state_ == nullptr); - lua_State* L = s->L_; - if (!lua_isnil(L, -1)) { - ref_ = lua_ref(L, LUA_REGISTRYINDEX); - state_ = s; - } else { - lua_pop(L, 1); - } -} - -// RAII guard to reset stack -struct LuaState::StackReset { - lua_State* L; - int top; - ~StackReset() { - lua_settop(L, top); - } -}; - -template -inline void LuaState::PRun_(F f) { - if (option_ != kLocking) { - StackReset reset{L_, lua_gettop(L_)}; - if (option_ == kThreadLocal) { - CHECK_EQ(ThreadLocalState(), this) - << "Invoke lua from a different thread in ThreadLocal mode."; - } - f(L_); - CHECK_EQ(reset.top, lua_gettop(L_)); - } else { - std::lock_guard lock(mutex_); - StackReset reset{L_, lua_gettop(L_)}; - f(L_); - CHECK_EQ(reset.top, lua_gettop(L_)); - } -} - -inline LuaState* LuaState::ThreadLocalState() { - return ThreadLocalStore::Get(); -} - -inline LuaRef LuaState::Eval(const char* lua_code) { - LuaRef ret; - this->PRun_([this, lua_code, &ret](lua_State* L) { - luaL_loadstring(L, lua_code); - CHECK_EQ(lua_pcall(L, 0, 1, 0), 0) - << "Lua call error: " << lua_tostring(L, -1) << '\n' - << "---------\n" - << lua_code - << "\n----------"; - ret.SetByPopStack_(this); - }); - return ret; -} - -template -inline LuaRef LuaState::Convert(const T& value) { - LuaRef ret; - this->PRun_([this, &value, &ret](lua_State* L) { - lua_stack::Handler::Push(L, value); - ret.SetByPopStack_(this); - }); - return ret; -} - -inline LuaRef LuaState::operator[](const std::string& key) { - LuaRef ret; - this->PRun_([this, &key, &ret](lua_State* L) { - lua_getglobal(L, key.c_str()); - ret.SetByPopStack_(this); - }); - return ret; -} - -inline void LuaState::SetGlobalField( - const std::string& key, const LuaRef& value) { - this->PRun_([this, &key, &value](lua_State* L) { - lua_rawgeti(L, LUA_REGISTRYINDEX, value.ref_); - lua_setglobal(L, key.c_str()); - }); -} - -inline LuaRef::LuaRef(const LuaRef& other) { - if (other.state_ != nullptr) { - state_ = other.state_; - state_->PRun_([this, &other](lua_State* L) { - lua_rawgeti(L, LUA_REGISTRYINDEX, other.ref_); - ref_ = luaL_ref(L, LUA_REGISTRYINDEX); - }); - } -} - -inline LuaRef::LuaRef(LuaRef&& other) { - ref_ = other.ref_; - state_ = other.state_; - other.state_ = nullptr; -} - -inline LuaRef& LuaRef::operator=(LuaRef&& other) { - LuaRef(std::move(other)).swap(*this); - return *this; -} - -inline LuaRef& LuaRef::operator=(const LuaRef& other) { - LuaRef(other).swap(*this); - return *this; -} - -inline void LuaRef::swap(LuaRef& other) { // NOLINT(*) - std::swap(state_, other.state_); - std::swap(ref_, other.ref_); -} - -inline LuaRef::~LuaRef() { - if (state_ != nullptr) { - state_->PRun_([this](lua_State* L) { - luaL_unref(L, LUA_REGISTRYINDEX, ref_); - }); - } -} - -inline bool LuaRef::is_nil() const { - return state_ == nullptr; -} - -std::ostream &operator<<(std::ostream &os, const LuaRef &r) { - if (!r.is_nil()) { - r.state_->PRun_([&os, &r](lua_State* L) { - lua_rawgeti(L, LUA_REGISTRYINDEX, r.ref_); - int type = lua_type(L, -1); - switch (type) { - case LUA_TSTRING: - os << "lua_string:'" << lua_tostring(L, -1) << "'"; break; - case LUA_TBOOLEAN: - os << "lua_bool:" << (lua_toboolean(L, -1) ? "true" : "false"); break; - case LUA_TNUMBER: - os << "lua_number:" << lua_tonumber(L, -1); break; - default: - os << "lua[ref=" << r.ref_ << ']' << lua_typename(L, type); break; - } - lua_pop(L, 1); - }); - } else { - os << "lua_nil"; - } - return os; -} - -template -inline T LuaRef::Get() const { - CHECK(state_ != nullptr) << "Get:: LuaRef is nil"; - T ret; - state_->PRun_([&ret, this](lua_State* L) { - lua_rawgeti(L, LUA_REGISTRYINDEX, ref_); - ret = lua_stack::Handler::Get(L, -1, state_); - lua_pop(L, 1); - }); - return ret; -} - -template -inline T* LuaRef::GetUDataPtr() const { - CHECK(state_ != nullptr) << "Get:: LuaRef is nil"; - T* ret; - state_->PRun_([&ret, this](lua_State* L) { - lua_rawgeti(L, LUA_REGISTRYINDEX, ref_); - ret = reinterpret_cast(lua_touserdata(L, -1)); - lua_pop(L, 1); - }); - return ret; -} - -// helper function to dispatch varg foreach -template -struct for_each_dispatcher_ { - static inline void run(const std::tuple& args, F f) { - f(std::get(args)); - for_each_dispatcher_<(I + 1) == sizeof...(Args), (I+1), F, Args...>::run(args, f); - } -}; -// helper function to run foreach -template -struct for_each_dispatcher_ { - static inline void run(const std::tuple& args, F f) { - } -}; - -// template function to iterate over tuples -template -inline void for_each(const std::tuple& args, F f) { - for_each_dispatcher_::run(args, f); -} - -template -inline LuaRef LuaRef::operator()(Args&& ...args) const { - CHECK(state_ != nullptr) << "LuaRef is nil"; - auto targ = std::make_tuple(std::forward(args)...); - size_t nargs = sizeof...(Args); - LuaRef ret; - state_->PRun_([this, nargs, &targ, &ret](lua_State* L) { - lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_); - CHECK(lua_isfunction(L, -1)) - << "Expect to invoke a function but type='" - << lua_typename(L, lua_type(L, -1)) << '\''; - for_each(targ, lua_stack::PushArg{L}); - LUA_CALL(lua_pcall(L, nargs, 1, 0)); - ret.SetByPopStack_(state_); - }); - return ret; -} - -template -inline LuaRef& LuaRef::SetField(const std::string& key, const T& value) { // NOLINT(*) - CHECK(state_ != nullptr) << "LuaRef is nil"; - state_->PRun_([this, &key, &value](lua_State* L) { - lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_); - CHECK(lua_istable(L, -1)) - << "Expect a table but type='" - << lua_typename(L, lua_type(L, -1)) << '\''; - lua_stack::Handler::Push(L, value); - lua_setfield(L, -2, key.c_str()); - lua_pop(L, 1); - }); - return *this; -} - -inline LuaRef LuaRef::operator[](const std::string& key) const { - CHECK(state_ != nullptr) << "LuaRef is nil"; - LuaRef ret; - state_->PRun_([this, &key, &ret](lua_State* L) { - lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_); - CHECK(lua_istable(L, -1)) - << "Expect a table but type='" - << lua_typename(L, lua_type(L, -1)) << '\''; - lua_getfield(L, -1, key.c_str()); - ret.SetByPopStack_(state_); - lua_pop(L, 1); - }); - return ret; -} - -inline LuaRef LuaRef::operator[](size_t index) const { - CHECK(state_ != nullptr) << "LuaRef is nil"; - LuaRef ret; - state_->PRun_([this, index, &ret](lua_State* L) { - lua_rawgeti(L, LUA_REGISTRYINDEX, this->ref_); - CHECK(lua_istable(L, -1)) - << "Expect a table but type='" - << lua_typename(L, lua_type(L, -1)) << '\''; - lua_rawgeti(L, -1, index); - ret.SetByPopStack_(state_); - lua_pop(L, 1); - }); - return ret; -} - -//! \endcond -} // namespace dmlc - -#endif // DMLC_LUA_H_ diff --git a/include/dmlc/memory.h b/include/dmlc/memory.h deleted file mode 100644 index 3a2b9b07988f..000000000000 --- a/include/dmlc/memory.h +++ /dev/null @@ -1,261 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file memory.h - * \brief Additional memory hanlding utilities. - */ -#ifndef DMLC_MEMORY_H_ -#define DMLC_MEMORY_H_ - -#include -#include "./base.h" -#include "./logging.h" -#include "./thread_local.h" - -namespace dmlc { - -/*! - * \brief A memory pool that allocate memory of fixed size and alignment. - * \tparam size The size of each piece. - * \tparam align The alignment requirement of the memory. - */ -template -class MemoryPool { - public: - /*! \brief constructor */ - MemoryPool() { - static_assert(align % alignof(LinkedList) == 0, - "alignment requirement failed."); - curr_page_.reset(new Page()); - } - /*! \brief allocate a new memory of size */ - inline void* allocate() { - if (head_ != nullptr) { - LinkedList* ret = head_; - head_ = head_->next; - return ret; - } else { - if (page_ptr_ < kPageSize) { - return &(curr_page_->data[page_ptr_++]); - } else { - allocated_.push_back(std::move(curr_page_)); - curr_page_.reset(new Page()); - page_ptr_ = 1; - return &(curr_page_->data[0]); - } - } - } - /*! - * \brief deallocate a piece of memory - * \param p The pointer to the memory to be de-allocated. - */ - inline void deallocate(void* p) { - LinkedList* ptr = static_cast(p); - ptr->next = head_; - head_ = ptr; - } - - private: - // page size of each member - static const int kPageSize = ((1 << 22) / size); - // page to be requested. - struct Page { - typename std::aligned_storage::type data[kPageSize]; - }; - // internal linked list structure. - struct LinkedList { - LinkedList* next{nullptr}; - }; - // head of free list - LinkedList* head_{nullptr}; - // current free page - std::unique_ptr curr_page_; - // pointer to the current free page position. - size_t page_ptr_{0}; - // allocated pages. - std::vector > allocated_; -}; - - -/*! - * \brief A thread local allocator that get memory from a threadlocal memory pool. - * This is suitable to allocate objects that do not cross thread. - * \tparam T the type of the data to be allocated. - */ -template -class ThreadlocalAllocator { - public: - /*! \brief pointer type */ - typedef T* pointer; - /*! \brief const pointer type */ - typedef const T* const_ptr; - /*! \brief value type */ - typedef T value_type; - /*! \brief default constructor */ - ThreadlocalAllocator() {} - /*! - * \brief constructor from another allocator - * \param other another allocator - * \tparam U another type - */ - template - ThreadlocalAllocator(const ThreadlocalAllocator& other) {} - /*! - * \brief allocate memory - * \param n number of blocks - * \return an uninitialized memory of type T. - */ - inline T* allocate(size_t n) { - CHECK_EQ(n, 1); - typedef ThreadLocalStore > Store; - return static_cast(Store::Get()->allocate()); - } - /*! - * \brief deallocate memory - * \param p a memory to be returned. - * \param n number of blocks - */ - inline void deallocate(T* p, size_t n) { - CHECK_EQ(n, 1); - typedef ThreadLocalStore > Store; - Store::Get()->deallocate(p); - } -}; - - -/*! - * \brief a shared pointer like type that allocate object - * from a threadlocal object pool. This object is not thread-safe - * but can be faster than shared_ptr in certain usecases. - * \tparam T the data type. - */ -template -struct ThreadlocalSharedPtr { - public: - /*! \brief default constructor */ - ThreadlocalSharedPtr() : block_(nullptr) {} - /*! - * \brief constructor from nullptr - * \param other the nullptr type - */ - ThreadlocalSharedPtr(std::nullptr_t other) : block_(nullptr) {} // NOLINT(*) - /*! - * \brief copy constructor - * \param other another pointer. - */ - ThreadlocalSharedPtr(const ThreadlocalSharedPtr& other) - : block_(other.block_) { - IncRef(block_); - } - /*! - * \brief move constructor - * \param other another pointer. - */ - ThreadlocalSharedPtr(ThreadlocalSharedPtr&& other) - : block_(other.block_) { - other.block_ = nullptr; - } - /*! - * \brief destructor - */ - ~ThreadlocalSharedPtr() { - DecRef(block_); - } - /*! - * \brief move assignment - * \param other another object to be assigned. - * \return self. - */ - inline ThreadlocalSharedPtr& operator=(ThreadlocalSharedPtr&& other) { - DecRef(block_); - block_ = other.block_; - other.block_ = nullptr; - return *this; - } - /*! - * \brief copy assignment - * \param other another object to be assigned. - * \return self. - */ - inline ThreadlocalSharedPtr &operator=(const ThreadlocalSharedPtr& other) { - DecRef(block_); - block_ = other.block_; - IncRef(block_); - return *this; - } - /*! \brief check if nullptr */ - inline bool operator==(std::nullptr_t other) const { - return block_ == nullptr; - } - /*! - * \return get the pointer content. - */ - inline T* get() const { - if (block_ == nullptr) return nullptr; - return reinterpret_cast(&(block_->data)); - } - /*! - * \brief reset the pointer to nullptr. - */ - inline void reset() { - DecRef(block_); - block_ = nullptr; - } - /*! \return if use_count == 1*/ - inline bool unique() const { - if (block_ == nullptr) return false; - return block_->use_count_ == 1; - } - /*! \return dereference pointer */ - inline T* operator*() const { - return reinterpret_cast(&(block_->data)); - } - /*! \return dereference pointer */ - inline T* operator->() const { - return reinterpret_cast(&(block_->data)); - } - /*! - * \brief create a new space from threadlocal storage and return it. - * \tparam Args the arguments. - * \param args The input argument - * \return the allocated pointer. - */ - template - inline static ThreadlocalSharedPtr Create(Args&&... args) { - ThreadlocalAllocator arena; - ThreadlocalSharedPtr p; - p.block_ = arena.allocate(1); - p.block_->use_count_ = 1; - new (&(p.block_->data)) T(std::forward(args)...); - return p; - } - - private: - // internal reference block - struct RefBlock { - typename std::aligned_storage::type data; - unsigned use_count_; - }; - // decrease ref counter - inline static void DecRef(RefBlock* block) { - if (block != nullptr) { - if (--block->use_count_ == 0) { - ThreadlocalAllocator arena; - T* dptr = reinterpret_cast(&(block->data)); - dptr->~T(); - arena.deallocate(block, 1); - } - } - } - // increase ref counter - inline static void IncRef(RefBlock* block) { - if (block != nullptr) { - ++block->use_count_; - } - } - // internal block - RefBlock *block_; -}; - -} // namespace dmlc - -#endif // DMLC_MEMORY_H_ diff --git a/include/dmlc/memory_io.h b/include/dmlc/memory_io.h deleted file mode 100644 index 4e807585cc31..000000000000 --- a/include/dmlc/memory_io.h +++ /dev/null @@ -1,105 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file memory_io.h - * \brief defines binary serialization class to serialize things into/from memory region. - */ -#ifndef DMLC_MEMORY_IO_H_ -#define DMLC_MEMORY_IO_H_ - -#include -#include -#include -#include "./base.h" -#include "./io.h" -#include "./logging.h" - -namespace dmlc { -/*! - * \brief A Stream that operates on fixed region of memory - * This class allows us to read/write from/to a fixed memory region. - */ -struct MemoryFixedSizeStream : public SeekStream { - public: - /*! - * \brief constructor - * \param p_buffer the head pointer of the memory region. - * \param buffer_size the size of the memorybuffer - */ - MemoryFixedSizeStream(void *p_buffer, size_t buffer_size) - : p_buffer_(reinterpret_cast(p_buffer)), - buffer_size_(buffer_size) { - curr_ptr_ = 0; - } - virtual size_t Read(void *ptr, size_t size) { - CHECK(curr_ptr_ + size <= buffer_size_); - size_t nread = std::min(buffer_size_ - curr_ptr_, size); - if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread); - curr_ptr_ += nread; - return nread; - } - virtual void Write(const void *ptr, size_t size) { - if (size == 0) return; - CHECK(curr_ptr_ + size <= buffer_size_); - std::memcpy(p_buffer_ + curr_ptr_, ptr, size); - curr_ptr_ += size; - } - virtual void Seek(size_t pos) { - curr_ptr_ = static_cast(pos); - } - virtual size_t Tell(void) { - return curr_ptr_; - } - - private: - /*! \brief in memory buffer */ - char *p_buffer_; - /*! \brief current pointer */ - size_t buffer_size_; - /*! \brief current pointer */ - size_t curr_ptr_; -}; // class MemoryFixedSizeStream - -/*! - * \brief A in memory stream that is backed by std::string. - * This class allows us to read/write from/to a std::string. - */ -struct MemoryStringStream : public dmlc::SeekStream { - public: - /*! - * \brief constructor - * \param p_buffer the pointer to the string. - */ - explicit MemoryStringStream(std::string *p_buffer) - : p_buffer_(p_buffer) { - curr_ptr_ = 0; - } - virtual size_t Read(void *ptr, size_t size) { - CHECK(curr_ptr_ <= p_buffer_->length()); - size_t nread = std::min(p_buffer_->length() - curr_ptr_, size); - if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread); - curr_ptr_ += nread; - return nread; - } - virtual void Write(const void *ptr, size_t size) { - if (size == 0) return; - if (curr_ptr_ + size > p_buffer_->length()) { - p_buffer_->resize(curr_ptr_+size); - } - std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size); - curr_ptr_ += size; - } - virtual void Seek(size_t pos) { - curr_ptr_ = static_cast(pos); - } - virtual size_t Tell(void) { - return curr_ptr_; - } - - private: - /*! \brief in memory buffer */ - std::string *p_buffer_; - /*! \brief current pointer */ - size_t curr_ptr_; -}; // class MemoryStringStream -} // namespace dmlc -#endif // DMLC_MEMORY_IO_H_ diff --git a/include/dmlc/omp.h b/include/dmlc/omp.h deleted file mode 100644 index 8b8e506b5430..000000000000 --- a/include/dmlc/omp.h +++ /dev/null @@ -1,47 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file omp.h - * \brief header to handle OpenMP compatibility issues - */ -#ifndef DMLC_OMP_H_ -#define DMLC_OMP_H_ - - -#if defined(_OPENMP) -#include -#else - -#if defined(__ANDROID__) -#define __GOMP_NOTHROW -#elif defined(__cplusplus) -#define __GOMP_NOTHROW throw() -#else -#define __GOMP_NOTHROW __attribute__((__nothrow__)) -#endif - -//! \cond Doxygen_Suppress -#ifdef __cplusplus -extern "C" { -#endif -inline int omp_get_thread_num() __GOMP_NOTHROW { return 0; } -inline int omp_get_num_threads() __GOMP_NOTHROW { return 1; } -inline int omp_get_max_threads() __GOMP_NOTHROW { return 1; } -inline int omp_get_num_procs() __GOMP_NOTHROW { return 1; } -inline void omp_set_num_threads(int nthread) __GOMP_NOTHROW {} -#ifdef __cplusplus -} -#endif // __cplusplus -#endif // _OPENMP - -// loop variable used in openmp -namespace dmlc { -#ifdef _MSC_VER -typedef int omp_uint; -typedef long omp_ulong; // NOLINT(*) -#else -typedef unsigned omp_uint; -typedef unsigned long omp_ulong; // NOLINT(*) -#endif -//! \endcond -} // namespace dmlc -#endif // DMLC_OMP_H_ diff --git a/include/dmlc/optional.h b/include/dmlc/optional.h deleted file mode 100644 index dedbc7478102..000000000000 --- a/include/dmlc/optional.h +++ /dev/null @@ -1,261 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file optional.h - * \brief Container to hold optional data. - */ -#ifndef DMLC_OPTIONAL_H_ -#define DMLC_OPTIONAL_H_ - -#include -#include -#include -#include - -#include "./base.h" -#include "./common.h" -#include "./logging.h" -#include "./type_traits.h" - -namespace dmlc { - -/*! \brief dummy type for assign null to optional */ -struct nullopt_t { -#if defined(_MSC_VER) && _MSC_VER < 1900 - /*! \brief dummy constructor */ - explicit nullopt_t(int a) {} -#else - /*! \brief dummy constructor */ - constexpr nullopt_t(int a) {} -#endif -}; - -/*! Assign null to optional: optional x = nullopt; */ -constexpr const nullopt_t nullopt = nullopt_t(0); - -/*! - * \brief c++17 compatible optional class. - * - * At any time an optional instance either - * hold no value (string representation "None") - * or hold a value of type T. - */ -template -class optional { - public: - /*! \brief construct an optional object that contains no value */ - optional() : is_none(true) {} - /*! \brief construct an optional object with value */ - explicit optional(const T& value) { - is_none = false; - new (&val) T(value); - } - /*! \brief construct an optional object with another optional object */ - optional(const optional& other) { - is_none = other.is_none; - if (!is_none) { - new (&val) T(other.value()); - } - } - /*! \brief deconstructor */ - ~optional() { - if (!is_none) { - reinterpret_cast(&val)->~T(); - } - } - /*! \brief swap two optional */ - void swap(optional& other) { - std::swap(val, other.val); - std::swap(is_none, other.is_none); - } - /*! \brief set this object to hold value - * \param value the value to hold - * \return return self to support chain assignment - */ - optional& operator=(const T& value) { - (optional(value)).swap(*this); - return *this; - } - /*! \brief set this object to hold the same value with other - * \param other the other object - * \return return self to support chain assignment - */ - optional& operator=(const optional &other) { - (optional(other)).swap(*this); - return *this; - } - /*! \brief clear the value this object is holding. - * optional x = nullopt; - */ - optional& operator=(nullopt_t) { - (optional()).swap(*this); - return *this; - } - /*! \brief non-const dereference operator */ - T& operator*() { // NOLINT(*) - return *reinterpret_cast(&val); - } - /*! \brief const dereference operator */ - const T& operator*() const { - return *reinterpret_cast(&val); - } - /*! \brief equal comparison */ - bool operator==(const optional& other) const { - return this->is_none == other.is_none && - (this->is_none == true || this->value() == other.value()); - } - /*! \brief return the holded value. - * throws std::logic_error if holding no value - */ - const T& value() const { - if (is_none) { - throw std::logic_error("bad optional access"); - } - return *reinterpret_cast(&val); - } - /*! \brief whether this object is holding a value */ - explicit operator bool() const { return !is_none; } - /*! \brief whether this object is holding a value (alternate form). */ - bool has_value() const { return operator bool(); } - - private: - // whether this is none - bool is_none; - // on stack storage of value - typename std::aligned_storage::type val; -}; - -/*! \brief serialize an optional object to string. - * - * \code - * dmlc::optional x; - * std::cout << x; // None - * x = 0; - * std::cout << x; // 0 - * \endcode - * - * \param os output stream - * \param t source optional object - * \return output stream - */ -template -std::ostream &operator<<(std::ostream &os, const optional &t) { - if (t) { - os << *t; - } else { - os << "None"; - } - return os; -} - -/*! \brief parse a string object into optional - * - * \code - * dmlc::optional x; - * std::string s1 = "1"; - * std::istringstream is1(s1); - * s1 >> x; // x == optional(1) - * - * std::string s2 = "None"; - * std::istringstream is2(s2); - * s2 >> x; // x == optional() - * \endcode - * - * \param is input stream - * \param t target optional object - * \return input stream - */ -template -std::istream &operator>>(std::istream &is, optional &t) { - char buf[4]; - std::streampos origin = is.tellg(); - is.read(buf, 4); - if (is.fail() || buf[0] != 'N' || buf[1] != 'o' || - buf[2] != 'n' || buf[3] != 'e') { - is.clear(); - is.seekg(origin); - T x; - is >> x; - t = x; - if (std::is_integral::value && !is.eof() && is.peek() == 'L') is.get(); - } else { - t = nullopt; - } - return is; -} -/*! \brief specialization of '>>' istream parsing for optional - * - * Permits use of generic parameter FieldEntry class to create - * FieldEntry> without explicit specialization. - * - * \code - * dmlc::optional x; - * std::string s1 = "true"; - * std::istringstream is1(s1); - * s1 >> x; // x == optional(true) - * - * std::string s2 = "None"; - * std::istringstream is2(s2); - * s2 >> x; // x == optional() - * \endcode - * - * \param is input stream - * \param t target optional object - * \return input stream - */ -inline std::istream &operator>>(std::istream &is, optional &t) { - // Discard initial whitespace - while (isspace(is.peek())) - is.get(); - // Extract chars that might be valid into a separate string, stopping - // on whitespace or other non-alphanumerics such as ",)]". - std::string s; - while (isalnum(is.peek())) - s.push_back(is.get()); - - if (!is.fail()) { - std::transform(s.begin(), s.end(), s.begin(), ::tolower); - if (s == "1" || s == "true") - t = true; - else if (s == "0" || s == "false") - t = false; - else if (s == "none") - t = nullopt; - else - is.setstate(std::ios::failbit); - } - - return is; -} - -/*! \brief description for optional int */ -DMLC_DECLARE_TYPE_NAME(optional, "int or None"); -/*! \brief description for optional bool */ -DMLC_DECLARE_TYPE_NAME(optional, "boolean or None"); -/*! \brief description for optional float */ -DMLC_DECLARE_TYPE_NAME(optional, "float or None"); -/*! \brief description for optional double */ -DMLC_DECLARE_TYPE_NAME(optional, "double or None"); - -} // namespace dmlc - -namespace std { -/*! \brief std hash function for optional */ -template -struct hash > { - /*! - * \brief returns hash of the optional value. - * \param val value. - * \return hash code. - */ - size_t operator()(const dmlc::optional& val) const { - std::hash hash_bool; - size_t res = hash_bool(val.has_value()); - if (val.has_value()) { - res = dmlc::HashCombine(res, val.value()); - } - return res; - } -}; -} // namespace std - -#endif // DMLC_OPTIONAL_H_ diff --git a/include/dmlc/parameter.h b/include/dmlc/parameter.h deleted file mode 100644 index 0830cb99cd19..000000000000 --- a/include/dmlc/parameter.h +++ /dev/null @@ -1,1065 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file parameter.h - * \brief Provide lightweight util to do parameter setup and checking. - */ -#ifndef DMLC_PARAMETER_H_ -#define DMLC_PARAMETER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "./base.h" -#include "./json.h" -#include "./logging.h" -#include "./type_traits.h" -#include "./optional.h" - -namespace dmlc { -// this file is backward compatible with non-c++11 -/*! \brief Error throwed by parameter checking */ -struct ParamError : public dmlc::Error { - /*! - * \brief constructor - * \param msg error message - */ - explicit ParamError(const std::string &msg) - : dmlc::Error(msg) {} -}; - -/*! - * \brief Get environment variable with default. - * \param key the name of environment variable. - * \param default_value the default value of environment vriable. - * \return The value received - */ -template -inline ValueType GetEnv(const char *key, - ValueType default_value); -/*! - * \brief Set environment variable. - * \param key the name of environment variable. - * \param value the new value for key. - * \return The value received - */ -template -inline void SetEnv(const char *key, - ValueType value); - -/*! \brief internal namespace for parameter manangement */ -namespace parameter { -// forward declare ParamManager -class ParamManager; -// forward declare FieldAccessEntry -class FieldAccessEntry; -// forward declare FieldEntry -template -class FieldEntry; -// forward declare ParamManagerSingleton -template -struct ParamManagerSingleton; - -/*! \brief option in parameter initialization */ -enum ParamInitOption { - /*! \brief allow unknown parameters */ - kAllowUnknown, - /*! \brief need to match exact parameters */ - kAllMatch, - /*! \brief allow unmatched hidden field with format __*__ */ - kAllowHidden -}; -} // namespace parameter -/*! - * \brief Information about a parameter field in string representations. - */ -struct ParamFieldInfo { - /*! \brief name of the field */ - std::string name; - /*! \brief type of the field in string format */ - std::string type; - /*! - * \brief detailed type information string - * This include the default value, enum constran and typename. - */ - std::string type_info_str; - /*! \brief detailed description of the type */ - std::string description; -}; - -/*! - * \brief Parameter is the base type every parameter struct should inheritate from - * The following code is a complete example to setup parameters. - * \code - * struct Param : public dmlc::Parameter { - * float learning_rate; - * int num_hidden; - * std::string name; - * // declare parameters in header file - * DMLC_DECLARE_PARAMETER(Param) { - * DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000); - * DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f); - * DMLC_DECLARE_FIELD(name).set_default("hello"); - * } - * }; - * // register it in cc file - * DMLC_REGISTER_PARAMETER(Param); - * \endcode - * - * After that, the Param struct will get all the functions defined in Parameter. - * \tparam PType the type of parameter struct - * - * \sa DMLC_DECLARE_FIELD, DMLC_REGISTER_PARAMETER, DMLC_DECLARE_PARAMETER - */ -template -struct Parameter { - public: - /*! - * \brief initialize the parameter by keyword arguments. - * This function will initialize the parameter struct, check consistency - * and throw error if something wrong happens. - * - * \param kwargs map of keyword arguments, or vector of pairs - * \parma option The option on initialization. - * \tparam Container container type - * \throw ParamError when something go wrong. - */ - template - inline void Init(const Container &kwargs, - parameter::ParamInitOption option = parameter::kAllowHidden) { - PType::__MANAGER__()->RunInit(static_cast(this), - kwargs.begin(), kwargs.end(), - NULL, - option); - } - /*! - * \brief initialize the parameter by keyword arguments. - * This is same as Init, but allow unknown arguments. - * - * \param kwargs map of keyword arguments, or vector of pairs - * \tparam Container container type - * \throw ParamError when something go wrong. - * \return vector of pairs of unknown arguments. - */ - template - inline std::vector > - InitAllowUnknown(const Container &kwargs) { - std::vector > unknown; - PType::__MANAGER__()->RunInit(static_cast(this), - kwargs.begin(), kwargs.end(), - &unknown, parameter::kAllowUnknown); - return unknown; - } - - /*! - * \brief Update the dict with values stored in parameter. - * - * \param dict The dictionary to be updated. - * \tparam Container container type - */ - template - inline void UpdateDict(Container *dict) const { - PType::__MANAGER__()->UpdateDict(this->head(), dict); - } - /*! - * \brief Return a dictionary representation of the parameters - * \return A dictionary that maps key -> value - */ - inline std::map __DICT__() const { - std::vector > vec - = PType::__MANAGER__()->GetDict(this->head()); - return std::map(vec.begin(), vec.end()); - } - /*! - * \brief Write the parameters in JSON format. - * \param writer JSONWriter used for writing. - */ - inline void Save(dmlc::JSONWriter *writer) const { - writer->Write(this->__DICT__()); - } - /*! - * \brief Load the parameters from JSON. - * \param reader JSONReader used for loading. - * \throw ParamError when something go wrong. - */ - inline void Load(dmlc::JSONReader *reader) { - std::map kwargs; - reader->Read(&kwargs); - this->Init(kwargs); - } - /*! - * \brief Get the fields of the parameters. - * \return List of ParamFieldInfo of each field. - */ - inline static std::vector __FIELDS__() { - return PType::__MANAGER__()->GetFieldInfo(); - } - /*! - * \brief Print docstring of the parameter - * \return the printed docstring - */ - inline static std::string __DOC__() { - std::ostringstream os; - PType::__MANAGER__()->PrintDocString(os); - return os.str(); - } - - protected: - /*! - * \brief internal function to allow declare of a parameter memember - * \param manager the parameter manager - * \param key the key name of the parameter - * \param ref the reference to the parameter in the struct. - */ - template - inline parameter::FieldEntry& DECLARE( - parameter::ParamManagerSingleton *manager, - const std::string &key, DType &ref) { // NOLINT(*) - parameter::FieldEntry *e = - new parameter::FieldEntry(); - e->Init(key, this->head(), ref); - manager->manager.AddEntry(key, e); - return *e; - } - - private: - /*! \return Get head pointer of child structure */ - inline PType *head() const { - return static_cast(const_cast*>(this)); - } -}; - -//! \cond Doxygen_Suppress -/*! - * \brief macro used to declare parameter - * - * Example: - * \code - * struct Param : public dmlc::Parameter { - * // declare parameters in header file - * DMLC_DECLARE_PARAMETER(Param) { - * // details of declarations - * } - * }; - * \endcode - * - * This macro need to be put in a source file so that registeration only happens once. - * Refer to example code in Parameter for details - * - * \param PType the name of parameter struct. - * \sa Parameter - */ -#define DMLC_DECLARE_PARAMETER(PType) \ - static ::dmlc::parameter::ParamManager *__MANAGER__(); \ - inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton *manager) \ - -/*! - * \brief macro to declare fields - * \param FieldName the name of the field. - */ -#define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName) - -/*! - * \brief macro to declare alias of a fields - * \param FieldName the name of the field. - * \param AliasName the name of the alias, must be declared after the field is declared. - */ -#define DMLC_DECLARE_ALIAS(FieldName, AliasName) manager->manager.AddAlias(#FieldName, #AliasName) - -/*! - * \brief Macro used to register parameter. - * - * This macro need to be put in a source file so that registeration only happens once. - * Refer to example code in Parameter for details - * \param PType the type of parameter struct. - * \sa Parameter - */ -#define DMLC_REGISTER_PARAMETER(PType) \ - ::dmlc::parameter::ParamManager *PType::__MANAGER__() { \ - static ::dmlc::parameter::ParamManagerSingleton inst(#PType); \ - return &inst.manager; \ - } \ - static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \ - __make__ ## PType ## ParamManager__ = \ - (*PType::__MANAGER__()) \ - -//! \endcond -/*! - * \brief internal namespace for parameter manangement - * There is no need to use it directly in normal case - */ -namespace parameter { -/*! - * \brief FieldAccessEntry interface to help manage the parameters - * Each entry can be used to access one parameter in the Parameter struct. - * - * This is an internal interface used that is used to manage parameters - */ -class FieldAccessEntry { - public: - FieldAccessEntry() - : has_default_(false) {} - /*! \brief destructor */ - virtual ~FieldAccessEntry() {} - /*! - * \brief set the default value. - * \param head the pointer to the head of the struct - * \throw error if no default is presented - */ - virtual void SetDefault(void *head) const = 0; - /*! - * \brief set the parameter by string value - * \param head the pointer to the head of the struct - * \param value the value to be set - */ - virtual void Set(void *head, const std::string &value) const = 0; - // check if value is OK - virtual void Check(void *head) const {} - /*! - * \brief get the string representation of value. - * \param head the pointer to the head of the struct - */ - virtual std::string GetStringValue(void *head) const = 0; - /*! - * \brief Get field information - * \return the corresponding field information - */ - virtual ParamFieldInfo GetFieldInfo() const = 0; - - protected: - /*! \brief whether this parameter have default value */ - bool has_default_; - /*! \brief positional index of parameter in struct */ - size_t index_; - /*! \brief parameter key name */ - std::string key_; - /*! \brief parameter type */ - std::string type_; - /*! \brief description of the parameter */ - std::string description_; - /*! - * \brief print string representation of default value - * \parma os the stream to print the docstring to. - */ - virtual void PrintDefaultValueString(std::ostream &os) const = 0; // NOLINT(*) - // allow ParamManager to modify self - friend class ParamManager; -}; - -/*! - * \brief manager class to handle parameter structure for each type - * An manager will be created for each parameter structure. - */ -class ParamManager { - public: - /*! \brief destructor */ - ~ParamManager() { - for (size_t i = 0; i < entry_.size(); ++i) { - delete entry_[i]; - } - } - /*! - * \brief find the access entry by parameter key - * \param key the key of the parameter. - * \return pointer to FieldAccessEntry, NULL if nothing is found. - */ - inline FieldAccessEntry *Find(const std::string &key) const { - std::map::const_iterator it = - entry_map_.find(key); - if (it == entry_map_.end()) return NULL; - return it->second; - } - /*! - * \brief set parameter by keyword arguments. - * \param head head to the parameter field. - * \param begin begin iterator of original kwargs - * \param end end iterator of original kwargs - * \param unknown_args optional, used to hold unknown arguments - * When it is specified, unknown arguments will be stored into here, instead of raise an error - * \tparam RandomAccessIterator iterator type - * \throw ParamError when there is unknown argument and unknown_args == NULL, or required argument is missing. - */ - template - inline void RunInit(void *head, - RandomAccessIterator begin, - RandomAccessIterator end, - std::vector > *unknown_args, - parameter::ParamInitOption option) const { - std::set selected_args; - for (RandomAccessIterator it = begin; it != end; ++it) { - FieldAccessEntry *e = Find(it->first); - if (e != NULL) { - e->Set(head, it->second); - e->Check(head); - selected_args.insert(e); - } else { - if (unknown_args != NULL) { - unknown_args->push_back(*it); - } else { - if (option != parameter::kAllowUnknown) { - if (option == parameter::kAllowHidden && - it->first.length() > 4 && - it->first.find("__") == 0 && - it->first.rfind("__") == it->first.length()-2) { - continue; - } - std::ostringstream os; - os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n"; - os << "----------------\n"; - PrintDocString(os); - throw dmlc::ParamError(os.str()); - } - } - } - } - - for (std::map::const_iterator it = entry_map_.begin(); - it != entry_map_.end(); ++it) { - if (selected_args.count(it->second) == 0) { - it->second->SetDefault(head); - } - } - } - /*! - * \brief internal function to add entry to manager, - * The manager will take ownership of the entry. - * \param key the key to the parameters - * \param e the pointer to the new entry. - */ - inline void AddEntry(const std::string &key, FieldAccessEntry *e) { - e->index_ = entry_.size(); - // TODO(bing) better error message - if (entry_map_.count(key) != 0) { - LOG(FATAL) << "key " << key << " has already been registered in " << name_; - } - entry_.push_back(e); - entry_map_[key] = e; - } - /*! - * \brief internal function to add entry to manager, - * The manager will take ownership of the entry. - * \param key the key to the parameters - * \param e the pointer to the new entry. - */ - inline void AddAlias(const std::string& field, const std::string& alias) { - if (entry_map_.count(field) == 0) { - LOG(FATAL) << "key " << field << " has not been registered in " << name_; - } - if (entry_map_.count(alias) != 0) { - LOG(FATAL) << "Alias " << alias << " has already been registered in " << name_; - } - entry_map_[alias] = entry_map_[field]; - } - /*! - * \brief set the name of parameter manager - * \param name the name to set - */ - inline void set_name(const std::string &name) { - name_ = name; - } - /*! - * \brief get field information of each field. - * \return field information - */ - inline std::vector GetFieldInfo() const { - std::vector ret(entry_.size()); - for (size_t i = 0; i < entry_.size(); ++i) { - ret[i] = entry_[i]->GetFieldInfo(); - } - return ret; - } - /*! - * \brief Print readible docstring to ostream, add newline. - * \parma os the stream to print the docstring to. - */ - inline void PrintDocString(std::ostream &os) const { // NOLINT(*) - for (size_t i = 0; i < entry_.size(); ++i) { - ParamFieldInfo info = entry_[i]->GetFieldInfo(); - os << info.name << " : " << info.type_info_str << '\n'; - if (info.description.length() != 0) { - os << " " << info.description << '\n'; - } - } - } - /*! - * \brief Get internal parameters in vector of pairs. - * \param head the head of the struct. - * \param skip_default skip the values that equals default value. - * \return the parameter dictionary. - */ - inline std::vector > GetDict(void * head) const { - std::vector > ret; - for (std::map::const_iterator - it = entry_map_.begin(); it != entry_map_.end(); ++it) { - ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head))); - } - return ret; - } - /*! - * \brief Update the dictionary with values in parameter. - * \param head the head of the struct. - * \tparam Container The container type - * \return the parameter dictionary. - */ - template - inline void UpdateDict(void * head, Container* dict) const { - for (std::map::const_iterator - it = entry_map_.begin(); it != entry_map_.end(); ++it) { - (*dict)[it->first] = it->second->GetStringValue(head); - } - } - - private: - /*! \brief parameter struct name */ - std::string name_; - /*! \brief positional list of entries */ - std::vector entry_; - /*! \brief map from key to entry */ - std::map entry_map_; -}; - -//! \cond Doxygen_Suppress - -// The following piece of code will be template heavy and less documented -// singleton parameter manager for certain type, used for initialization -template -struct ParamManagerSingleton { - ParamManager manager; - explicit ParamManagerSingleton(const std::string ¶m_name) { - PType param; - manager.set_name(param_name); - param.__DECLARE__(this); - } -}; - -// Base class of FieldEntry -// implement set_default -template -class FieldEntryBase : public FieldAccessEntry { - public: - // entry type - typedef TEntry EntryType; - // implement set value - virtual void Set(void *head, const std::string &value) const { - std::istringstream is(value); - is >> this->Get(head); - if (!is.fail()) { - while (!is.eof()) { - int ch = is.get(); - if (ch == EOF) { - is.clear(); break; - } - if (!isspace(ch)) { - is.setstate(std::ios::failbit); break; - } - } - } - - if (is.fail()) { - std::ostringstream os; - os << "Invalid Parameter format for " << key_ - << " expect " << type_ << " but value=\'" << value<< '\''; - throw dmlc::ParamError(os.str()); - } - } - virtual std::string GetStringValue(void *head) const { - std::ostringstream os; - PrintValue(os, this->Get(head)); - return os.str(); - } - virtual ParamFieldInfo GetFieldInfo() const { - ParamFieldInfo info; - std::ostringstream os; - info.name = key_; - info.type = type_; - os << type_; - if (has_default_) { - os << ',' << " optional, default="; - PrintDefaultValueString(os); - } else { - os << ", required"; - } - info.type_info_str = os.str(); - info.description = description_; - return info; - } - // implement set head to default value - virtual void SetDefault(void *head) const { - if (!has_default_) { - std::ostringstream os; - os << "Required parameter " << key_ - << " of " << type_ << " is not presented"; - throw dmlc::ParamError(os.str()); - } else { - this->Get(head) = default_value_; - } - } - // return reference of self as derived type - inline TEntry &self() { - return *(static_cast(this)); - } - // implement set_default - inline TEntry &set_default(const DType &default_value) { - default_value_ = default_value; - has_default_ = true; - // return self to allow chaining - return this->self(); - } - // implement describe - inline TEntry &describe(const std::string &description) { - description_ = description; - // return self to allow chaining - return this->self(); - } - // initialization function - inline void Init(const std::string &key, - void *head, DType &ref) { // NOLINT(*) - this->key_ = key; - if (this->type_.length() == 0) { - this->type_ = dmlc::type_name(); - } - this->offset_ = ((char*)&ref) - ((char*)head); // NOLINT(*) - } - - protected: - // print the value - virtual void PrintValue(std::ostream &os, DType value) const { // NOLINT(*) - os << value; - } - virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) - PrintValue(os, default_value_); - } - // get the internal representation of parameter - // for example if this entry corresponds field param.learning_rate - // then Get(¶m) will return reference to param.learning_rate - inline DType &Get(void *head) const { - return *(DType*)((char*)(head) + offset_); // NOLINT(*) - } - // internal offset of the field - ptrdiff_t offset_; - // default value of field - DType default_value_; -}; - -// parameter base for numeric types that have range -template -class FieldEntryNumeric - : public FieldEntryBase { - public: - FieldEntryNumeric() - : has_begin_(false), has_end_(false) {} - // implement set_range - virtual TEntry &set_range(DType begin, DType end) { - begin_ = begin; end_ = end; - has_begin_ = true; has_end_ = true; - return this->self(); - } - // implement set_range - virtual TEntry &set_lower_bound(DType begin) { - begin_ = begin; has_begin_ = true; - return this->self(); - } - // consistency check for numeric ranges - virtual void Check(void *head) const { - FieldEntryBase::Check(head); - DType v = this->Get(head); - if (has_begin_ && has_end_) { - if (v < begin_ || v > end_) { - std::ostringstream os; - os << "value " << v << " for Parameter " << this->key_ - << " exceed bound [" << begin_ << ',' << end_ <<']'; - throw dmlc::ParamError(os.str()); - } - } else if (has_begin_ && v < begin_) { - std::ostringstream os; - os << "value " << v << " for Parameter " << this->key_ - << " should be greater equal to " << begin_; - throw dmlc::ParamError(os.str()); - } else if (has_end_ && v > end_) { - std::ostringstream os; - os << "value " << v << " for Parameter " << this->key_ - << " should be smaller equal to " << end_; - throw dmlc::ParamError(os.str()); - } - } - - protected: - // whether it have begin and end range - bool has_begin_, has_end_; - // data bound - DType begin_, end_; -}; - -/*! - * \brief FieldEntry defines parsing and checking behavior of DType. - * This class can be specialized to implement specific behavior of more settings. - * \tparam DType the data type of the entry. - */ -template -class FieldEntry : - public IfThenElseType::value, - FieldEntryNumeric, DType>, - FieldEntryBase, DType> >::Type { -}; - -// specialize define for int(enum) -template<> -class FieldEntry - : public FieldEntryNumeric, int> { - public: - // construct - FieldEntry() : is_enum_(false) {} - // parent - typedef FieldEntryNumeric, int> Parent; - // override set - virtual void Set(void *head, const std::string &value) const { - if (is_enum_) { - std::map::const_iterator it = enum_map_.find(value); - std::ostringstream os; - if (it == enum_map_.end()) { - os << "Invalid Input: \'" << value; - os << "\', valid values are: "; - PrintEnums(os); - throw dmlc::ParamError(os.str()); - } else { - os << it->second; - Parent::Set(head, os.str()); - } - } else { - Parent::Set(head, value); - } - } - virtual ParamFieldInfo GetFieldInfo() const { - if (is_enum_) { - ParamFieldInfo info; - std::ostringstream os; - info.name = key_; - info.type = type_; - PrintEnums(os); - if (has_default_) { - os << ',' << "optional, default="; - PrintDefaultValueString(os); - } else { - os << ", required"; - } - info.type_info_str = os.str(); - info.description = description_; - return info; - } else { - return Parent::GetFieldInfo(); - } - } - // add enum - inline FieldEntry &add_enum(const std::string &key, int value) { - if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \ - enum_back_map_.count(value) != 0) { - std::ostringstream os; - os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n"; - os << "Enums: "; - for (std::map::const_iterator it = enum_map_.begin(); - it != enum_map_.end(); ++it) { - os << "(" << it->first << ": " << it->second << "), "; - } - throw dmlc::ParamError(os.str()); - } - enum_map_[key] = value; - enum_back_map_[value] = key; - is_enum_ = true; - return this->self(); - } - - protected: - // enum flag - bool is_enum_; - // enum map - std::map enum_map_; - // enum map - std::map enum_back_map_; - // override print behavior - virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) - os << '\''; - PrintValue(os, default_value_); - os << '\''; - } - // override print default - virtual void PrintValue(std::ostream &os, int value) const { // NOLINT(*) - if (is_enum_) { - CHECK_NE(enum_back_map_.count(value), 0U) - << "Value not found in enum declared"; - os << enum_back_map_.at(value); - } else { - os << value; - } - } - - - private: - inline void PrintEnums(std::ostream &os) const { // NOLINT(*) - os << '{'; - for (std::map::const_iterator - it = enum_map_.begin(); it != enum_map_.end(); ++it) { - if (it != enum_map_.begin()) { - os << ", "; - } - os << "\'" << it->first << '\''; - } - os << '}'; - } -}; - - -// specialize define for optional(enum) -template<> -class FieldEntry > - : public FieldEntryBase >, optional > { - public: - // construct - FieldEntry >() : is_enum_(false) {} - // parent - typedef FieldEntryBase >, optional > Parent; - // override set - virtual void Set(void *head, const std::string &value) const { - if (is_enum_ && value != "None") { - std::map::const_iterator it = enum_map_.find(value); - std::ostringstream os; - if (it == enum_map_.end()) { - os << "Invalid Input: \'" << value; - os << "\', valid values are: "; - PrintEnums(os); - throw dmlc::ParamError(os.str()); - } else { - os << it->second; - Parent::Set(head, os.str()); - } - } else { - Parent::Set(head, value); - } - } - virtual ParamFieldInfo GetFieldInfo() const { - if (is_enum_) { - ParamFieldInfo info; - std::ostringstream os; - info.name = key_; - info.type = type_; - PrintEnums(os); - if (has_default_) { - os << ',' << "optional, default="; - PrintDefaultValueString(os); - } else { - os << ", required"; - } - info.type_info_str = os.str(); - info.description = description_; - return info; - } else { - return Parent::GetFieldInfo(); - } - } - // add enum - inline FieldEntry > &add_enum(const std::string &key, int value) { - CHECK_NE(key, "None") << "None is reserved for empty optional"; - if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \ - enum_back_map_.count(value) != 0) { - std::ostringstream os; - os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n"; - os << "Enums: "; - for (std::map::const_iterator it = enum_map_.begin(); - it != enum_map_.end(); ++it) { - os << "(" << it->first << ": " << it->second << "), "; - } - throw dmlc::ParamError(os.str()); - } - enum_map_[key] = value; - enum_back_map_[value] = key; - is_enum_ = true; - return this->self(); - } - - protected: - // enum flag - bool is_enum_; - // enum map - std::map enum_map_; - // enum map - std::map enum_back_map_; - // override print behavior - virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) - os << '\''; - PrintValue(os, default_value_); - os << '\''; - } - // override print default - virtual void PrintValue(std::ostream &os, optional value) const { // NOLINT(*) - if (is_enum_) { - if (!value) { - os << "None"; - } else { - CHECK_NE(enum_back_map_.count(value.value()), 0U) - << "Value not found in enum declared"; - os << enum_back_map_.at(value.value()); - } - } else { - os << value; - } - } - - - private: - inline void PrintEnums(std::ostream &os) const { // NOLINT(*) - os << "{None"; - for (std::map::const_iterator - it = enum_map_.begin(); it != enum_map_.end(); ++it) { - os << ", "; - os << "\'" << it->first << '\''; - } - os << '}'; - } -}; - -// specialize define for string -template<> -class FieldEntry - : public FieldEntryBase, std::string> { - public: - // parent class - typedef FieldEntryBase, std::string> Parent; - // override set - virtual void Set(void *head, const std::string &value) const { - this->Get(head) = value; - } - // override print default - virtual void PrintDefaultValueString(std::ostream &os) const { // NOLINT(*) - os << '\'' << default_value_ << '\''; - } -}; - -// specialize define for bool -template<> -class FieldEntry - : public FieldEntryBase, bool> { - public: - // parent class - typedef FieldEntryBase, bool> Parent; - // override set - virtual void Set(void *head, const std::string &value) const { - std::string lower_case; lower_case.resize(value.length()); - std::transform(value.begin(), value.end(), lower_case.begin(), ::tolower); - bool &ref = this->Get(head); - if (lower_case == "true") { - ref = true; - } else if (lower_case == "false") { - ref = false; - } else if (lower_case == "1") { - ref = true; - } else if (lower_case == "0") { - ref = false; - } else { - std::ostringstream os; - os << "Invalid Parameter format for " << key_ - << " expect " << type_ << " but value=\'" << value<< '\''; - throw dmlc::ParamError(os.str()); - } - } - - protected: - // print default string - virtual void PrintValue(std::ostream &os, bool value) const { // NOLINT(*) - os << static_cast(value); - } -}; - - -// specialize define for float. Uses stof for platform independent handling of -// INF, -INF, NAN, etc. -#if DMLC_USE_CXX11 -template <> -class FieldEntry : public FieldEntryNumeric, float> { - public: - // parent - typedef FieldEntryNumeric, float> Parent; - // override set - virtual void Set(void *head, const std::string &value) const { - try { - this->Get(head) = std::stof(value); - } catch (const std::invalid_argument &) { - std::ostringstream os; - os << "Invalid Parameter format for " << key_ << " expect " << type_ - << " but value=\'" << value << '\''; - throw dmlc::ParamError(os.str()); - } catch (const std::out_of_range&) { - std::ostringstream os; - os << "Out of range value for " << key_ << ", value=\'" << value << '\''; - throw dmlc::ParamError(os.str()); - } - } -}; - -// specialize define for double. Uses stod for platform independent handling of -// INF, -INF, NAN, etc. -template <> -class FieldEntry - : public FieldEntryNumeric, double> { - public: - // parent - typedef FieldEntryNumeric, double> Parent; - // override set - virtual void Set(void *head, const std::string &value) const { - try { - this->Get(head) = std::stod(value); - } catch (const std::invalid_argument &) { - std::ostringstream os; - os << "Invalid Parameter format for " << key_ << " expect " << type_ - << " but value=\'" << value << '\''; - throw dmlc::ParamError(os.str()); - } catch (const std::out_of_range&) { - std::ostringstream os; - os << "Out of range value for " << key_ << ", value=\'" << value << '\''; - throw dmlc::ParamError(os.str()); - } - } -}; -#endif // DMLC_USE_CXX11 - -} // namespace parameter -//! \endcond - -// implement GetEnv -template -inline ValueType GetEnv(const char *key, - ValueType default_value) { - const char *val = getenv(key); - // On some implementations, if the var is set to a blank string (i.e. "FOO="), then - // a blank string will be returned instead of NULL. In order to be consistent, if - // the environment var is a blank string, then also behave as if a null was returned. - if (val == nullptr || !*val) { - return default_value; - } - ValueType ret; - parameter::FieldEntry e; - e.Init(key, &ret, ret); - e.Set(&ret, val); - return ret; -} - -// implement SetEnv -template -inline void SetEnv(const char *key, - ValueType value) { - parameter::FieldEntry e; - e.Init(key, &value, value); -#ifdef _WIN32 - _putenv(key, e.GetStringValue(&value).c_str()); -#else - setenv(key, e.GetStringValue(&value).c_str(), 1); -#endif // _WIN32 -} -} // namespace dmlc -#endif // DMLC_PARAMETER_H_ diff --git a/include/dmlc/recordio.h b/include/dmlc/recordio.h deleted file mode 100644 index 6220780acadc..000000000000 --- a/include/dmlc/recordio.h +++ /dev/null @@ -1,196 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file recordio.h - * \brief recordio that is able to pack binary data into a splittable - * format, useful to exchange data in binary serialization, - * such as binary raw data or protobuf - */ -#ifndef DMLC_RECORDIO_H_ -#define DMLC_RECORDIO_H_ -#include -#include -#include "./io.h" -#include "./logging.h" - -namespace dmlc { -/*! - * \brief writer of binary recordio - * binary format for recordio - * recordio format: magic lrecord data pad - * - * - magic is magic number - * - pad is simply a padding space to make record align to 4 bytes - * - lrecord encodes length and continue bit - * - data.length() = (lrecord & (1U<<29U - 1)); - * - cflag == (lrecord >> 29U) & 7; - * - * cflag was used to handle (rare) special case when magic number - * occured in the data sequence. - * - * In such case, the data is splitted into multiple records by - * the cells of magic number - * - * (1) cflag == 0: this is a complete record; - * (2) cflag == 1: start of a multiple-rec; - * cflag == 2: middle of multiple-rec; - * cflag == 3: end of multiple-rec - */ -class RecordIOWriter { - public: - /*! - * \brief magic number of recordio - * note: (kMagic >> 29U) & 7 > 3 - * this ensures lrec will not be kMagic - */ - static const uint32_t kMagic = 0xced7230a; - /*! - * \brief encode the lrecord - * \param cflag cflag part of the lrecord - * \param length length part of lrecord - * \return the encoded data - */ - inline static uint32_t EncodeLRec(uint32_t cflag, uint32_t length) { - return (cflag << 29U) | length; - } - /*! - * \brief decode the flag part of lrecord - * \param rec the lrecord - * \return the flag - */ - inline static uint32_t DecodeFlag(uint32_t rec) { - return (rec >> 29U) & 7U; - } - /*! - * \brief decode the length part of lrecord - * \param rec the lrecord - * \return the length - */ - inline static uint32_t DecodeLength(uint32_t rec) { - return rec & ((1U << 29U) - 1U); - } - /*! - * \brief constructor - * \param stream the stream to be constructed - */ - explicit RecordIOWriter(Stream *stream) - : stream_(stream), seek_stream_(dynamic_cast(stream)), - except_counter_(0) { - CHECK(sizeof(uint32_t) == 4) << "uint32_t needs to be 4 bytes"; - } - /*! - * \brief write record to the stream - * \param buf the buffer of memory region - * \param size the size of record to write out - */ - void WriteRecord(const void *buf, size_t size); - /*! - * \brief write record to the stream - * \param data the data to write out - */ - inline void WriteRecord(const std::string &data) { - this->WriteRecord(data.c_str(), data.length()); - } - /*! - * \return number of exceptions(occurance of magic number) - * during the writing process - */ - inline size_t except_counter(void) const { - return except_counter_; - } - - /*! \brief tell the current position of the input stream */ - inline size_t Tell(void) { - CHECK(seek_stream_ != NULL) << "The input stream is not seekable"; - return seek_stream_->Tell(); - } - - private: - /*! \brief output stream */ - Stream *stream_; - /*! \brief seekable stream */ - SeekStream *seek_stream_; - /*! \brief counts the number of exceptions */ - size_t except_counter_; -}; -/*! - * \brief reader of binary recordio to reads in record from stream - * \sa RecordIOWriter - */ -class RecordIOReader { - public: - /*! - * \brief constructor - * \param stream the stream to be constructed - */ - explicit RecordIOReader(Stream *stream) - : stream_(stream), seek_stream_(dynamic_cast(stream)), - end_of_stream_(false) { - CHECK(sizeof(uint32_t) == 4) << "uint32_t needs to be 4 bytes"; - } - /*! - * \brief read next complete record from stream - * \param out_rec used to store output record in string - * \return true of read was successful, false if end of stream was reached - */ - bool NextRecord(std::string *out_rec); - - /*! \brief seek to certain position of the input stream */ - inline void Seek(size_t pos) { - CHECK(seek_stream_ != NULL) << "The input stream is not seekable"; - seek_stream_->Seek(pos); - } - - /*! \brief tell the current position of the input stream */ - inline size_t Tell(void) { - CHECK(seek_stream_ != NULL) << "The input stream is not seekable"; - return seek_stream_->Tell(); - } - - private: - /*! \brief output stream */ - Stream *stream_; - SeekStream *seek_stream_; - /*! \brief whether we are at end of stream */ - bool end_of_stream_; -}; - -/*! - * \brief reader of binary recordio from Blob returned by InputSplit - * This class divides the blob into several independent parts specified by caller, - * and read from one segment. - * The part reading can be used together with InputSplit::NextChunk for - * multi-threaded parsing(each thread take a RecordIOChunkReader) - * - * \sa RecordIOWriter, InputSplit - */ -class RecordIOChunkReader { - public: - /*! - * \brief constructor - * \param chunk source data returned by InputSplit - * \param part_index which part we want to reado - * \param num_parts number of total segments - */ - explicit RecordIOChunkReader(InputSplit::Blob chunk, - unsigned part_index = 0, - unsigned num_parts = 1); - /*! - * \brief read next complete record from stream - * the blob contains the memory content - * NOTE: this function is not threadsafe, use one - * RecordIOChunkReader per thread - * \param out_rec used to store output blob, the header is already - * removed and out_rec only contains the memory content - * \return true of read was successful, false if end was reached - */ - bool NextRecord(InputSplit::Blob *out_rec); - - private: - /*! \brief internal temporal data */ - std::string temp_; - /*! \brief internal data pointer */ - char *pbegin_, *pend_; -}; - -} // namespace dmlc -#endif // DMLC_RECORDIO_H_ diff --git a/include/dmlc/registry.h b/include/dmlc/registry.h deleted file mode 100644 index d68b57597250..000000000000 --- a/include/dmlc/registry.h +++ /dev/null @@ -1,306 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file registry.h - * \brief Registry utility that helps to build registry singletons. - */ -#ifndef DMLC_REGISTRY_H_ -#define DMLC_REGISTRY_H_ - -#include -#include -#include -#include "./base.h" -#include "./logging.h" -#include "./parameter.h" -#include "./type_traits.h" - -namespace dmlc { -/*! - * \brief Registry class. - * Registry can be used to register global singletons. - * The most commonly use case are factory functions. - * - * \tparam EntryType Type of Registry entries, - * EntryType need to name a name field. - */ -template -class Registry { - public: - /*! \return list of entries in the registry(excluding alias) */ - inline static const std::vector& List() { - return Get()->const_list_; - } - /*! \return list all names registered in the registry, including alias */ - inline static std::vector ListAllNames() { - const std::map &fmap = Get()->fmap_; - typename std::map::const_iterator p; - std::vector names; - for (p = fmap.begin(); p !=fmap.end(); ++p) { - names.push_back(p->first); - } - return names; - } - /*! - * \brief Find the entry with corresponding name. - * \param name name of the function - * \return the corresponding function, can be NULL - */ - inline static const EntryType *Find(const std::string &name) { - const std::map &fmap = Get()->fmap_; - typename std::map::const_iterator p = fmap.find(name); - if (p != fmap.end()) { - return p->second; - } else { - return NULL; - } - } - /*! - * \brief Add alias to the key_name - * \param key_name The original entry key - * \param alias The alias key. - */ - inline void AddAlias(const std::string& key_name, - const std::string& alias) { - EntryType* e = fmap_.at(key_name); - if (fmap_.count(alias)) { - CHECK_EQ(e, fmap_.at(alias)) - << "Trying to register alias " << alias << " for key " << key_name - << " but " << alias << " is already taken"; - } else { - fmap_[alias] = e; - } - } - /*! - * \brief Internal function to register a name function under name. - * \param name name of the function - * \return ref to the registered entry, used to set properties - */ - inline EntryType &__REGISTER__(const std::string& name) { - CHECK_EQ(fmap_.count(name), 0U) - << name << " already registered"; - EntryType *e = new EntryType(); - e->name = name; - fmap_[name] = e; - const_list_.push_back(e); - entry_list_.push_back(e); - return *e; - } - /*! - * \brief Internal function to either register or get registered entry - * \param name name of the function - * \return ref to the registered entry, used to set properties - */ - inline EntryType &__REGISTER_OR_GET__(const std::string& name) { - if (fmap_.count(name) == 0) { - return __REGISTER__(name); - } else { - return *fmap_.at(name); - } - } - /*! - * \brief get a singleton of the Registry. - * This function can be defined by DMLC_REGISTRY_ENABLE. - * \return get a singleton - */ - static Registry *Get(); - - private: - /*! \brief list of entry types */ - std::vector entry_list_; - /*! \brief list of entry types */ - std::vector const_list_; - /*! \brief map of name->function */ - std::map fmap_; - /*! \brief constructor */ - Registry() {} - /*! \brief destructor */ - ~Registry() { - for (size_t i = 0; i < entry_list_.size(); ++i) { - delete entry_list_[i]; - } - } -}; - -/*! - * \brief Common base class for function registry. - * - * \code - * // This example demonstrates how to use Registry to create a factory of trees. - * struct TreeFactory : - * public FunctionRegEntryBase > { - * }; - * - * // in a independent cc file - * namespace dmlc { - * DMLC_REGISTRY_ENABLE(TreeFactory); - * } - * // register binary tree constructor into the registry. - * DMLC_REGISTRY_REGISTER(TreeFactory, TreeFactory, BinaryTree) - * .describe("Constructor of BinaryTree") - * .set_body([]() { return new BinaryTree(); }); - * \endcode - * - * \tparam EntryType The type of subclass that inheritate the base. - * \tparam FunctionType The function type this registry is registerd. - */ -template -class FunctionRegEntryBase { - public: - /*! \brief name of the entry */ - std::string name; - /*! \brief description of the entry */ - std::string description; - /*! \brief additional arguments to the factory function */ - std::vector arguments; - /*! \brief Function body to create ProductType */ - FunctionType body; - /*! \brief Return type of the function */ - std::string return_type; - - /*! - * \brief Set the function body. - * \param body Function body to set. - * \return reference to self. - */ - inline EntryType &set_body(FunctionType body) { - this->body = body; - return this->self(); - } - /*! - * \brief Describe the function. - * \param description The description of the factory function. - * \return reference to self. - */ - inline EntryType &describe(const std::string &description) { - this->description = description; - return this->self(); - } - /*! - * \brief Add argument information to the function. - * \param name Name of the argument. - * \param type Type of the argument. - * \param description Description of the argument. - * \return reference to self. - */ - inline EntryType &add_argument(const std::string &name, - const std::string &type, - const std::string &description) { - ParamFieldInfo info; - info.name = name; - info.type = type; - info.type_info_str = info.type; - info.description = description; - arguments.push_back(info); - return this->self(); - } - /*! - * \brief Append list if arguments to the end. - * \param args Additional list of arguments. - * \return reference to self. - */ - inline EntryType &add_arguments(const std::vector &args) { - arguments.insert(arguments.end(), args.begin(), args.end()); - return this->self(); - } - /*! - * \brief Set the return type. - * \param type Return type of the function, could be Symbol or Symbol[] - * \return reference to self. - */ - inline EntryType &set_return_type(const std::string &type) { - return_type = type; - return this->self(); - } - - protected: - /*! - * \return reference of self as derived type - */ - inline EntryType &self() { - return *(static_cast(this)); - } -}; - -/*! - * \def DMLC_REGISTRY_ENABLE - * \brief Macro to enable the registry of EntryType. - * This macro must be used under namespace dmlc, and only used once in cc file. - * \param EntryType Type of registry entry - */ -#define DMLC_REGISTRY_ENABLE(EntryType) \ - template<> \ - Registry *Registry::Get() { \ - static Registry inst; \ - return &inst; \ - } \ - -/*! - * \brief Generic macro to register an EntryType - * There is a complete example in FactoryRegistryEntryBase. - * - * \param EntryType The type of registry entry. - * \param EntryTypeName The typename of EntryType, must do not contain namespace :: . - * \param Name The name to be registered. - * \sa FactoryRegistryEntryBase - */ -#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \ - static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \ - ::dmlc::Registry::Get()->__REGISTER__(#Name) \ - -/*! - * \brief (Optional) Declare a file tag to current file that contains object registrations. - * - * This will declare a dummy function that will be called by register file to - * incur a link dependency. - * - * \param UniqueTag The unique tag used to represent. - * \sa DMLC_REGISTRY_LINK_TAG - */ -#define DMLC_REGISTRY_FILE_TAG(UniqueTag) \ - int __dmlc_registry_file_tag_ ## UniqueTag ## __() { return 0; } - -/*! - * \brief (Optional) Force link to all the objects registered in file tag. - * - * This macro must be used in the same file as DMLC_REGISTRY_ENABLE and - * in the same namespace as DMLC_REGISTRY_FILE_TAG - * - * DMLC_REGISTRY_FILE_TAG and DMLC_REGISTRY_LINK_TAG are optional macros for registration. - * They are used to encforce link of certain file into during static linking. - * - * This is mainly used to solve problem during statically link a library which contains backward registration. - * Specifically, this avoids the objects in these file tags to be ignored by compiler. - * - * For dynamic linking, this problem won't occur as everything is loaded by default. - * - * Use of this is optional as it will create an error when a file tag do not exist. - * An alternative solution is always ask user to enable --whole-archieve during static link. - * - * \begincode - * // in file objective_registry.cc - * DMLC_REGISTRY_ENABLE(MyObjective); - * DMLC_REGISTRY_LINK_TAG(regression_op); - * DMLC_REGISTRY_LINK_TAG(rank_op); - * - * // in file regression_op.cc - * // declare tag of this file. - * DMLC_REGISTRY_FILE_TAG(regression_op); - * DMLC_REGISTRY_REGISTER(MyObjective, logistic_reg, logistic_reg); - * // ... - * - * // in file rank_op.cc - * // declare tag of this file. - * DMLC_REGISTRY_FILE_TAG(rank_op); - * DMLC_REGISTRY_REGISTER(MyObjective, pairwiserank, pairwiserank); - * - * \endcode - * - * \param UniqueTag The unique tag used to represent. - * \sa DMLC_REGISTRY_ENABLE, DMLC_REGISTRY_FILE_TAG - */ -#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \ - int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \ - static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \ - __dmlc_registry_file_tag_ ## UniqueTag ## __(); -} // namespace dmlc -#endif // DMLC_REGISTRY_H_ diff --git a/include/dmlc/serializer.h b/include/dmlc/serializer.h deleted file mode 100644 index 4bede4a3b416..000000000000 --- a/include/dmlc/serializer.h +++ /dev/null @@ -1,410 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file serializer.h - * \brief serializer template class that helps serialization. - * This file do not need to be directly used by most user. - */ -#ifndef DMLC_SERIALIZER_H_ -#define DMLC_SERIALIZER_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "./base.h" -#include "./io.h" -#include "./logging.h" -#include "./type_traits.h" -#include "./endian.h" - -#if DMLC_USE_CXX11 -#include -#include -#endif - -namespace dmlc { -/*! \brief internal namespace for serializers */ -namespace serializer { -/*! - * \brief generic serialization handler - * \tparam T the type to be serialized - * \tparam need_endian_swap Whether use little endian - */ -template -struct Handler; - -//! \cond Doxygen_Suppress -/*! - * \brief Serializer that redirect calls by condition - * \tparam cond the condition - * \tparam Then the serializer used for then condition - * \tparam Else the serializer used for else condition - * \tparam Return the type of data the serializer handles - */ -template -struct IfThenElse; - -template -struct IfThenElse { - inline static void Write(Stream *strm, const T &data) { - Then::Write(strm, data); - } - inline static bool Read(Stream *strm, T *data) { - return Then::Read(strm, data); - } -}; -template -struct IfThenElse { - inline static void Write(Stream *strm, const T &data) { - Else::Write(strm, data); - } - inline static bool Read(Stream *strm, T *data) { - return Else::Read(strm, data); - } -}; - -/*! \brief Serializer for POD(plain-old-data) data */ -template -struct NativePODHandler { - inline static void Write(Stream *strm, const T &data) { - strm->Write(&data, sizeof(T)); - } - inline static bool Read(Stream *strm, T *dptr) { - return strm->Read((void*)dptr, sizeof(T)) == sizeof(T); // NOLINT(*) - } -}; - -/*! \brief Serializer for arithmetic data, handle endianness */ -template -struct ArithmeticHandler { - inline static void Write(Stream *strm, const T &data) { - if (DMLC_IO_NO_ENDIAN_SWAP) { - strm->Write(&data, sizeof(T)); - } else { - T copy = data; - ByteSwap(©, sizeof(T), 1); - strm->Write(©, sizeof(T)); - } - } - inline static bool Read(Stream *strm, T *dptr) { - bool ret = strm->Read((void*)dptr, sizeof(T)) == sizeof(T); // NOLINT(*) - if (!DMLC_IO_NO_ENDIAN_SWAP) { - ByteSwap(dptr, sizeof(T), 1); - } - return ret; - } -}; - -// serializer for class that have save/load function -template -struct SaveLoadClassHandler { - inline static void Write(Stream *strm, const T &data) { - data.Save(strm); - } - inline static bool Read(Stream *strm, T *data) { - return data->Load(strm); - } -}; - -/*! - * \brief dummy class for undefined serialization. - * This is used to generate error message when user tries to - * serialize something that is not supported. - * \tparam T the type to be serialized - */ -template -struct UndefinedSerializerFor { -}; - -/*! - * \brief Serializer handler for std::vector where T is POD type. - * \tparam T element type - */ -template -struct NativePODVectorHandler { - inline static void Write(Stream *strm, const std::vector &vec) { - uint64_t sz = static_cast(vec.size()); - strm->Write(sz); - if (sz != 0) { - strm->Write(&vec[0], sizeof(T) * vec.size()); - } - } - inline static bool Read(Stream *strm, std::vector *out_vec) { - uint64_t sz; - if (!strm->Read(&sz)) return false; - size_t size = static_cast(sz); - out_vec->resize(size); - if (sz != 0) { - size_t nbytes = sizeof(T) * size; - return strm->Read(&(*out_vec)[0], nbytes) == nbytes; - } - return true; - } -}; - -/*! - * \brief Serializer handler for std::vector where T can be composed type - * \tparam T element type - */ -template -struct ComposeVectorHandler { - inline static void Write(Stream *strm, const std::vector &vec) { - uint64_t sz = static_cast(vec.size()); - strm->Write(sz); - strm->WriteArray(dmlc::BeginPtr(vec), vec.size()); - } - inline static bool Read(Stream *strm, std::vector *out_vec) { - uint64_t sz; - if (!strm->Read(&sz)) return false; - size_t size = static_cast(sz); - out_vec->resize(size); - return strm->ReadArray(dmlc::BeginPtr(*out_vec), size); - } -}; - -/*! - * \brief Serializer handler for std::basic_string where T is POD type. - * \tparam T element type - */ -template -struct NativePODStringHandler { - inline static void Write(Stream *strm, const std::basic_string &vec) { - uint64_t sz = static_cast(vec.length()); - strm->Write(sz); - if (sz != 0) { - strm->Write(&vec[0], sizeof(T) * vec.length()); - } - } - inline static bool Read(Stream *strm, std::basic_string *out_vec) { - uint64_t sz; - if (!strm->Read(&sz)) return false; - size_t size = static_cast(sz); - out_vec->resize(size); - if (sz != 0) { - size_t nbytes = sizeof(T) * size; - return strm->Read(&(*out_vec)[0], nbytes) == nbytes; - } - return true; - } -}; - -/*! \brief Serializer for std::pair */ -template -struct PairHandler { - inline static void Write(Stream *strm, const std::pair &data) { - Handler::Write(strm, data.first); - Handler::Write(strm, data.second); - } - inline static bool Read(Stream *strm, std::pair *data) { - return Handler::Read(strm, &(data->first)) && - Handler::Read(strm, &(data->second)); - } -}; - -// set type handler that can handle most collection type case -template -struct CollectionHandler { - inline static void Write(Stream *strm, const ContainerType &data) { - // dump data to vector - std::vector vdata(data.begin(), data.end()); - // serialize the vector - Handler >::Write(strm, vdata); - } - inline static bool Read(Stream *strm, ContainerType *data) { - std::vector vdata; - if (!Handler >::Read(strm, &vdata)) return false; - data->clear(); - data->insert(vdata.begin(), vdata.end()); - return true; - } -}; - - -// handler that can handle most list type case -// this type insert function takes additional iterator -template -struct ListHandler { - inline static void Write(Stream *strm, const ListType &data) { - typedef typename ListType::value_type ElemType; - // dump data to vector - std::vector vdata(data.begin(), data.end()); - // serialize the vector - Handler >::Write(strm, vdata); - } - inline static bool Read(Stream *strm, ListType *data) { - typedef typename ListType::value_type ElemType; - std::vector vdata; - if (!Handler >::Read(strm, &vdata)) return false; - data->clear(); - data->insert(data->begin(), vdata.begin(), vdata.end()); - return true; - } -}; - -//! \endcond - -/*! - * \brief generic serialization handler for type T - * - * User can define specialization of this class to support - * composite serialization of their own class. - * - * \tparam T the type to be serialized - */ -template -struct Handler { - /*! - * \brief write data to stream - * \param strm the stream we write the data. - * \param data the data obeject to be serialized - */ - inline static void Write(Stream *strm, const T &data) { - IfThenElse::value, - ArithmeticHandler, - IfThenElse::value && DMLC_IO_NO_ENDIAN_SWAP, - NativePODHandler, - IfThenElse::value, - SaveLoadClassHandler, - UndefinedSerializerFor, T>, - T>, - T> - ::Write(strm, data); - } - /*! - * \brief read data to stream - * \param strm the stream to read the data. - * \param data the pointer to the data obeject to read - * \return whether the read is successful - */ - inline static bool Read(Stream *strm, T *data) { - return - IfThenElse::value, - ArithmeticHandler, - IfThenElse::value && DMLC_IO_NO_ENDIAN_SWAP, - NativePODHandler, - IfThenElse::value, - SaveLoadClassHandler, - UndefinedSerializerFor, T>, - T>, - T> - ::Read(strm, data); - } -}; - -//! \cond Doxygen_Suppress -template -struct Handler > { - inline static void Write(Stream *strm, const std::vector &data) { - IfThenElse::value && DMLC_IO_NO_ENDIAN_SWAP, - NativePODVectorHandler, - ComposeVectorHandler, std::vector > - ::Write(strm, data); - } - inline static bool Read(Stream *strm, std::vector *data) { - return IfThenElse::value && DMLC_IO_NO_ENDIAN_SWAP, - NativePODVectorHandler, - ComposeVectorHandler, - std::vector > - ::Read(strm, data); - } -}; - -template -struct Handler > { - inline static void Write(Stream *strm, const std::basic_string &data) { - IfThenElse::value && (DMLC_IO_NO_ENDIAN_SWAP || sizeof(T) == 1), - NativePODStringHandler, - UndefinedSerializerFor, - std::basic_string > - ::Write(strm, data); - } - inline static bool Read(Stream *strm, std::basic_string *data) { - return IfThenElse::value && (DMLC_IO_NO_ENDIAN_SWAP || sizeof(T) == 1), - NativePODStringHandler, - UndefinedSerializerFor, - std::basic_string > - ::Read(strm, data); - } -}; - -template -struct Handler > { - inline static void Write(Stream *strm, const std::pair &data) { - IfThenElse::value && - dmlc::is_pod::value && - DMLC_IO_NO_ENDIAN_SWAP, - NativePODHandler >, - PairHandler, - std::pair > - ::Write(strm, data); - } - inline static bool Read(Stream *strm, std::pair *data) { - return IfThenElse::value && - dmlc::is_pod::value && - DMLC_IO_NO_ENDIAN_SWAP, - NativePODHandler >, - PairHandler, - std::pair > - ::Read(strm, data); - } -}; - -template -struct Handler > - : public CollectionHandler, std::pair > { -}; - -template -struct Handler > - : public CollectionHandler, std::pair > { -}; - -template -struct Handler > - : public CollectionHandler, T> { -}; - -template -struct Handler > - : public CollectionHandler, T> { -}; - -template -struct Handler > - : public ListHandler > { -}; - -template -struct Handler > - : public ListHandler > { -}; - -#if DMLC_USE_CXX11 -template -struct Handler > - : public CollectionHandler, std::pair > { -}; - -template -struct Handler > - : public CollectionHandler, std::pair > { -}; - -template -struct Handler > - : public CollectionHandler, T> { -}; - -template -struct Handler > - : public CollectionHandler, T> { -}; -#endif -//! \endcond -} // namespace serializer -} // namespace dmlc -#endif // DMLC_SERIALIZER_H_ diff --git a/include/dmlc/thread_group.h b/include/dmlc/thread_group.h deleted file mode 100644 index 626142f30284..000000000000 --- a/include/dmlc/thread_group.h +++ /dev/null @@ -1,808 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file thread_group.h - * \brief Thread and synchronization primitives and lifecycle management - */ -#ifndef DMLC_THREAD_GROUP_H_ -#define DMLC_THREAD_GROUP_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#if defined(DMLC_USE_CXX14) || __cplusplus > 201103L /* C++14 */ -#include -#endif -#include -#ifdef __linux__ -#include -#include -#endif - -namespace dmlc { - -/*! - * \brief Simple manual-reset event gate which remains open after signalled - */ -class ManualEvent { - public: - ManualEvent() : signaled_(false) {} - - /*! - * \brief Wait for the object to become signaled. If the object - * is already in the signaled state and reset() has not been called, then no wait will occur - */ - void wait() { - std::unique_lock lock(mutex_); - if (!signaled_) { - condition_variable_.wait(lock); - } - } - - /*! - * \brief Set this object's state to signaled (wait() will release or pass through) - */ - void signal() { - signaled_ = true; - std::unique_lock lk(mutex_); - condition_variable_.notify_all(); - } - - /*! - * \brief Manually reset this object's state to unsignaled (wait() will block) - */ - void reset() { - std::unique_lock lk(mutex_); - signaled_ = false; - } - - private: - /*! \brief Internal mutex to protect condition variable and signaled_ variable */ - std::mutex mutex_; - /*! \brief Internal condition variable */ - std::condition_variable condition_variable_; - /*! \brief lockfree signal state check */ - std::atomic signaled_; -}; - -#if defined(DMLC_USE_CXX14) || __cplusplus > 201103L /* C++14 */ -/*! \brief Mutex which can be read-locked and write-locked */ -using SharedMutex = std::shared_timed_mutex; -/*! \brief Write lock, disallows both reads and writes */ -using WriteLock = std::unique_lock; -/*! \brief Read lock, allows concurrent data reads */ -using ReadLock = std::shared_lock; -#else -/*! \brief Standard mutex for C++ < 14 */ -using SharedMutex = std::recursive_mutex; -/*! \brief Standard unique lock for C++ < 14 */ -using WriteLock = std::unique_lock; -/*! \brief Standard unique lock for C++ < 14 */ -using ReadLock = std::unique_lock; -#endif - -/*! - * \brief Thread lifecycle management group - * \note See gtest unit tests Syc.* for a usage examples - */ -class ThreadGroup { - public: - /*! - * \brief Lifecycle-managed thread (used by ThreadGroup) - * \note See gtest unit tests Syc.* for a usage examples - */ - class Thread { - public: - /*! \brief Shared pointer type for readability */ - using SharedPtr = std::shared_ptr; - - /*! - * \brief Constructor - * \param threadName User-defined name of the thread. must be unique per ThreadGroup - * \param owner The ThreadGroup object managing the lifecycle of this thread - * \param thrd Optionally-assigned std::thread object associated with this Thread class - */ - Thread(std::string threadName, ThreadGroup *owner, std::thread *thrd = nullptr) - : name_(std::move(threadName)) - , thread_(thrd) - , ready_event_(std::make_shared()) - , start_event_(std::make_shared()) - , owner_(owner) - , shutdown_requested_(false) - , auto_remove_(false) { - CHECK_NOTNULL(owner); - } - - /*! - * \brief Destructor with cleanup - */ - virtual ~Thread() { - const bool self_delete = is_current_thread(); - if (!self_delete) { - request_shutdown(); - internal_join(true); - } - WriteLock guard(thread_mutex_); - if (thread_.load()) { - std::thread *thrd = thread_.load(); - thread_ = nullptr; - if (self_delete) { - thrd->detach(); - } - delete thrd; - } - } - - /*! - * \brief Name of the thread - * \return Pointer to the thread name's string - * \note This shoul ndly be used as immediate for the sacope of the - * shared pointer pointing to this object - */ - const char *name() const { - return name_.c_str(); - } - - /*! - * \brief Launch the given Thread object - * \tparam StartFunction Function type for the thread 'main' function - * \tparam Args Arguments to pass to the thread 'main' function - * \param pThis Shared pointer for the managed thread to launch - * \param autoRemove if true, automatically remove this Thread object from the - * ThreadGroup owner upon exit - * \param start_function The Thread's 'main' function - * \param args Arguments to pass to the Thread's 'main' function - * \return true if the thread was successfully created and added to the ThreadGroup - * If false is returned, the thread may have already been started, but if something - * went wrong (ie duplicte thread name for the ThreadGroup), then request_shutdown() - * will have been been called on the running thread - */ - template - static bool launch(std::shared_ptr pThis, - bool autoRemove, - StartFunction start_function, - Args ...args); - - /*! - * \brief Check if this class represents the currently running thread (self) - * \return true if the current running thread belongs to this class - */ - bool is_current_thread() const { - ReadLock guard(thread_mutex_); - return thread_.load() ? (thread_.load()->get_id() == std::this_thread::get_id()) : false; - } - - /*! - * \brief Signal to this thread that a thread shutdown/exit is requested. - * \note This is a candidate for overrise in a derived class which may trigger shutdown - * by means other than a boolean (ie condition variable, SimpleManualkEvent, etc). - */ - virtual void request_shutdown() { - shutdown_requested_ = true; - } - - /*! - * \brief Check whether shutdown has been requested (request_shutdown() was called) - * \return true if shutdown was requested. - * \note This may be overriden to match an overriden to match an overriden 'request_shutdown()', - * for instance. - */ - virtual bool is_shutdown_requested() const { - return shutdown_requested_.load(); - } - - /*! - * \brief Check whether the thread is set to auto-remove itself from the ThreadGroup owner - * when exiting - * \return true if the thread will auto-remove itself from the ThreadGroup owner - * when exiting - */ - bool is_auto_remove() const { - return auto_remove_; - } - - /*! - * \brief Make the thread joinable (by removing the auto_remove flag) - * \warning Care should be taken not to cause a race condition between this call - * and parallel execution of this thread auto-removing itself - */ - void make_joinable() { - auto_remove_ = false; - } - - /*! - * \brief Check whether the thread is joinable - * \return true if the thread is joinable - */ - bool joinable() const { - ReadLock guard(thread_mutex_); - if (thread_.load()) { - CHECK_EQ(auto_remove_, false); - // be checked by searching the group or exit event. - return thread_.load()->joinable(); - } - return false; - } - - /*! - * \brief Thread join - * \note join() may not be called on auto-remove threads - */ - void join() { - internal_join(false); - } - - /*! - * \brief Get this thread's id - * \return this thread's id - */ - std::thread::id get_id() const { - ReadLock guard(thread_mutex_); - return thread_.load()->get_id(); - } - - private: - /*! - * \brief Internal join function - * \param auto_remove_ok Whether to allow join on an auto-remove thread - */ - void internal_join(bool auto_remove_ok) { - ReadLock guard(thread_mutex_); - // should be careful calling (or any function externally) this when in - // auto-remove mode - if (thread_.load() && thread_.load()->get_id() != std::thread::id()) { - std::thread::id someId; - if (!auto_remove_ok) { - CHECK_EQ(auto_remove_, false); - } - CHECK_NOTNULL(thread_.load()); - if (thread_.load()->joinable()) { - thread_.load()->join(); - } else { - LOG(WARNING) << "Thread " << name_ << " ( " - << thread_.load()->get_id() << " ) not joinable"; - } - } - } - - /*! - * \brief Thread bootstrapping and teardown wrapper - * \tparam StartFunction Thread's "main" function - * \tparam Args Argument types to be passed to the start_function - * \param pThis Shared pointer to the Thread object to operate upon - * \param start_function Thread's "main" function (i.e. passed to launch()) - * \param args Arguments to be passed to the start_function - * \return The thread's return code - */ - template - static int entry_and_exit_f(std::shared_ptr pThis, - StartFunction start_function, - Args... args); - /*! \brief Thread name */ - std::string name_; - /*! \brief Shared mutex for some thread operations */ - mutable SharedMutex thread_mutex_; - /*! \brief Pointer to the stl thread object */ - std::atomic thread_; - /*! \brief Signaled when the thread is started and ready to execute user code */ - std::shared_ptr ready_event_; - /*! \brief Thread will block after setting ready_event_ until start_event_ is signaled */ - std::shared_ptr start_event_; - /*! \brief The ThreadGroup ownber managing this thread's lifecycle */ - ThreadGroup *owner_; - /*! \brief Flag to determine if shutdown was requested. */ - std::atomic shutdown_requested_; - /*! - * \brief Whether to automatically remove this thread's object from the ThreadGroup when the - * thread exists (perform its own cleanup) - */ - volatile bool auto_remove_; - }; - - /*! - * \brief Constructor - */ - inline ThreadGroup() - : evEmpty_(std::make_shared()) { - evEmpty_->signal(); // Starts out empty - } - - /*! - * \brief Destructor, perform cleanup. All child threads will be exited when this - * destructor completes - */ - virtual ~ThreadGroup() { - request_shutdown_all(); - join_all(); - } - - /*! - * \brief Check if the current thread a member if this ThreadGroup - * \return true if the current thread is a member of this thread group - * \note This lookup involved a linear search, so for a large number of threads, - * is it not advised to call this function in a performance-sensitive area - */ - inline bool is_this_thread_in() const { - std::thread::id id = std::this_thread::get_id(); - ReadLock guard(m_); - for (auto it = threads_.begin(), end = threads_.end(); it != end; ++it) { - std::shared_ptr thrd = *it; - if (thrd->get_id() == id) - return true; - } - return false; - } - - /*! - * \brief Check if the current thread is a member of this ThreadGroup - * \param thrd The thread to search for - * \return true if the given thread is a member of this ThreadGroup - */ - inline bool is_thread_in(std::shared_ptr thrd) const { - if (thrd) { - std::thread::id id = thrd->get_id(); - ReadLock guard(m_); - for (auto it = threads_.begin(), end = threads_.end(); it != end; ++it) { - std::shared_ptr thrd = *it; - if (thrd->get_id() == id) - return true; - } - return false; - } else { - return false; - } - } - - /*! - * \brief Add a Thread object to this thread group - * \param thrd The thread to add to this ThreadGroup object - * \return true if the given thread was added to this ThreadGroup - */ - inline bool add_thread(std::shared_ptr thrd) { - if (thrd) { - WriteLock guard(m_); - auto iter = name_to_thread_.find(thrd->name()); - if (iter == name_to_thread_.end()) { - name_to_thread_.emplace(std::make_pair(thrd->name(), thrd)); - CHECK_EQ(threads_.insert(thrd).second, true); - evEmpty_->reset(); - return true; - } - } - return false; - } - - /*! - * \brief Remove a Thread object from this thread group - * \param thrd The thread to remove from this ThreadGroup object - * \return true if the given thread was removed from this ThreadGroup - */ - inline bool remove_thread(std::shared_ptr thrd) { - if (thrd) { - WriteLock guard(m_); - auto iter = threads_.find(thrd); - if (iter != threads_.end()) { - name_to_thread_.erase(thrd->name()); - threads_.erase(iter); - if (threads_.empty()) { - evEmpty_->signal(); - } - return true; - } - } - return false; - } - - /*! - * \brief Join all threads in this ThreadGroup - * \note While it is not valid to call 'join' on an auto-remove thread, this function will - * wait for auto-remove threads to exit (waits for the ThreadGroup to become empty) - */ - inline void join_all() { - CHECK_EQ(!is_this_thread_in(), true); - do { - std::unique_lock lk(join_all_mtx_); - std::unordered_set> working_set; - { - ReadLock guard(m_); - for (auto iter = threads_.begin(), e_iter = threads_.end(); iter != e_iter; ++iter) { - if (!(*iter)->is_auto_remove()) { - working_set.emplace(*iter); - } - } - } - // Where possible, prefer to do a proper join rather than simply waiting for empty - // (easier to troubleshoot) - while (!working_set.empty()) { - std::shared_ptr thrd; - thrd = *working_set.begin(); - if (thrd->joinable()) { - thrd->join(); - } - remove_thread(thrd); - working_set.erase(working_set.begin()); - thrd.reset(); - } - // Wait for auto-remove threads (if any) to complete - } while (0); - evEmpty_->wait(); - CHECK_EQ(threads_.size(), 0); - } - - /*! - * \brief Call request_shutdown() on all threads in this ThreadGroup - * \param make_all_joinable If true, remove all auto_remove flags from child threads - */ - inline void request_shutdown_all(const bool make_all_joinable = true) { - std::unique_lock lk(join_all_mtx_); - ReadLock guard(m_); - for (auto &thread : threads_) { - if (make_all_joinable) { - thread->make_joinable(); - } - thread->request_shutdown(); - } - } - - /*! - * \brief Return the number of threads in this thread group - * \return Number of threads in this thread group - */ - inline size_t size() const { - ReadLock guard(m_); - return threads_.size(); - } - - /*! - * \brief Check if the ThreadGroup is empty - * \return true if the ThreadGroup is empty - */ - inline bool empty() const { - ReadLock guard(m_); - return threads_.size() == 0; - } - - /*! - * \brief Create and launch a new Thread object which will be owned by this ThreadGroup - * \tparam StartFunction Function type for the thread 'main' function - * \tparam ThreadType managedThreadclass type (in case it's derived, for instance) - * \tparam Args Arguments to pass to the thread 'main' function - * \param threadName Name if the thread. Must be unique for a ThreadGroup object - * \param auto_remove If true, automatically remove this Thread object from the - * ThreadGroup owner upon exit - * \param start_function The Thread's 'main' function - * \param args Arguments to pass to the Thread's 'main' function - * \return true if the thread was successfully created and added to the ThreadGroup - * If false is returned, the thread may have already been started, but if something - * went wrong (ie duplicte thread name for the ThreadGroup), then request_shutdown() - * will have been been called on the running thread - */ - template - inline bool create(const std::string &threadName, - bool auto_remove, - StartFunction start_function, - Args... args) { - typename ThreadType::SharedPtr newThread(new ThreadType(threadName, this)); - return Thread::launch(newThread, auto_remove, start_function, args...); - } - - /*! - * \brief Lookup Thread object by name - * \param name Name of the thread to look up - * \return A shared pointer to the Thread object - */ - inline std::shared_ptr thread_by_name(const std::string& name) { - ReadLock guard(m_); - auto iter = name_to_thread_.find(name); - if (iter != name_to_thread_.end()) { - return iter->second; - } - return nullptr; - } - - private: - /*! \brief ThreadGroup synchronization mutex */ - mutable SharedMutex m_; - /*! \brief join_all/auto_remove synchronization mutex */ - mutable std::mutex join_all_mtx_; - /*! \brief Set of threads owned and managed by this ThreadGroup object */ - std::unordered_set> threads_; - /*! \brief Manual event which is signaled when the thread group is empty */ - std::shared_ptr evEmpty_; - /*! \brief name->thread mapping */ - std::unordered_map> name_to_thread_; -}; - -/*! - * \brief Blocking queue thread class - * \tparam ObjectType Object type to queue - * \tparam quit_item Object value to signify queue shutdown (ie nullptr for pointer type is common) - * \note See gtest unit test Syc.ManagedThreadLaunchQueueThread for a usage example - */ -template -class BlockingQueueThread : public ThreadGroup::Thread { - using BQT = BlockingQueueThread; - - public: - /*! - * \brief Constructor - * \param name Name for the blockin g queue thread. Must be unique for a specific ThreadGroup - * \param owner ThreadGroup lifecycle manafger/owner - * \param thrd Optionally attach an existing stl thread object - */ - BlockingQueueThread(const std::string& name, - dmlc::ThreadGroup *owner, - std::thread *thrd = nullptr) - : ThreadGroup::Thread(std::move(name), owner, thrd) - , shutdown_in_progress_(false) { - } - - - /*! - * \brief Destructor - */ - ~BlockingQueueThread() override { - // Call to parent first because we don't want to wait for the queue to empty - ThreadGroup::Thread::request_shutdown(); - request_shutdown(); - } - - /*! - * \brief Signal the thread that a shutdown is desired - * \note Since consumer doesn't necessarily get items in order, we must wait for - * the queue to empty. - * This is generally a shutdown procedure and should not be called from - * a performance-sensitive area - */ - void request_shutdown() override { - shutdown_in_progress_ = true; - while (queue_->size_approx() > 0 && !ThreadGroup::Thread::is_shutdown_requested()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - ThreadGroup::Thread::request_shutdown(); - queue_->enqueue(quit_item); - } - - /*! - * \brief Enqueue and item - * \param item The item to enqueue - */ - void enqueue(const ObjectType& item) { - if (!shutdown_in_progress_) { - queue_->enqueue(item); - } - } - - /*! - * \brief Get the approximate size of the queue - * \return The approximate size of the queue - */ - size_t size_approx() const { return queue_->size_approx(); } - - /*! - * \brief Launch to the 'run' function which will, in turn, call the class' - * 'run' function, passing it the given 'secondary_function' - * for it to call as needed - * \tparam SecondaryFunction Type of the secondary function for 'run' override - * to call as needed - * \param pThis Pointer to the managed thread to launch - * \param secondary_function secondary function for 'run' override to call as needed - * \return true if thread is launched successfully and added to the ThreadGroup - */ - template - static bool launch_run(std::shared_ptr pThis, - SecondaryFunction secondary_function) { - return ThreadGroup::Thread::launch(pThis, true, [](std::shared_ptr pThis, - SecondaryFunction secondary_function) { - return pThis->run(secondary_function); - }, - pThis, secondary_function); - } - - /*! - * \brief Thread's main queue processing function - * \tparam OnItemFunction Function type to call when an item is dequeued - * \param on_item_function Function to call when an item is dequeued - * \return 0 if completed through a `quit_item`, nonzero if on_item_function requested an exit - */ - template - inline int run(OnItemFunction on_item_function) { - int rc = 0; - do { - ObjectType item; - queue_->wait_dequeue(item); - if (item == quit_item) { - break; - } - rc = on_item_function(item); - if (rc) { - break; - } - } while (true); - return rc; - } - - private: - /*! \brief The blocking queue associated with this thread */ - std::shared_ptr> queue_ = - std::make_shared>(); - /*! \brief Whether shutdown request is in progress */ - std::atomic shutdown_in_progress_; -}; - -/*! - * \brief Managed timer thread - * \tparam Duration Duration type (ie seconds, microseconds, etc) - */ -template -class TimerThread : public ThreadGroup::Thread { - using ThreadGroup::Thread::is_shutdown_requested; - - public: - /*! - * \brief Constructor - * \param name Name of the timer thread - * \param owner ThreadGroup owner if the timer thread - */ - TimerThread(const std::string& name, ThreadGroup *owner) - : Thread(name, owner) { - } - - /*! - * \brief Destructor - */ - ~TimerThread() override { - request_shutdown(); - } - - /*! - * \brief Launch to the 'run' function which will, in turn, call the class' - * 'run' function, passing it the given 'secondary_function' - * for it to call as needed - * \tparam SecondaryFunction Type of the secondary function for 'run' override - * to call as needed - * \param pThis Pointer to the managed thread to launch - * \param secondary_function secondary function for 'run' override to call as needed - * \return true if thread is launched successfully and added to the ThreadGroup - */ - template - static bool launch_run(std::shared_ptr> pThis, - SecondaryFunction secondary_function) { - return ThreadGroup::Thread::launch(pThis, true, [](std::shared_ptr> pThis, - SecondaryFunction secondary_function) { - return pThis->run(secondary_function); - }, - pThis, secondary_function); - } - - /*! - * \brief Start a given timer thread - * \tparam Function Type of the timer function - * \param timer_thread Thread object to perform the timer events - * \param duration Duration between the end end of the timer function and the next timer event - * \param function Function to call when the timer expires - * \note Calling shutdown_requested() will cause the thread to exit the next time that the timer - * expires. - */ - template - static void start(std::shared_ptr timer_thread, - Duration duration, - Function function) { - timer_thread->duration_ = duration; - launch_run(timer_thread, function); - } - - /*! - * \brief Internal timer execution function - * \tparam OnTimerFunction Type of function to call each time the timer expires - * \param on_timer_function Function to call each time the timer expires - * \return Exit code of the thread - */ - template - inline int run(OnTimerFunction on_timer_function) { - int rc = 0; - while (!is_shutdown_requested()) { - std::this_thread::sleep_for(duration_); - if (!is_shutdown_requested()) { - rc = on_timer_function(); - } - } - return rc; - } - - private: - Duration duration_; -}; - -/* - * Inline functions - see declarations for usage - */ -template -inline int ThreadGroup::Thread::entry_and_exit_f(std::shared_ptr pThis, - StartFunction start_function, - Args... args) { - int rc; - if (pThis) { - // Signal launcher that we're up and running - pThis->ready_event_->signal(); - // Wait for launcher to be ready for us to start - pThis->start_event_->wait(); - // Reset start_event_ for possible reuse - pThis->start_event_->reset(); // Reset in case it needs to be reused - // If we haven't been requested to shut down prematurely, then run the desired function - if (!pThis->is_shutdown_requested()) { - rc = start_function(args...); - } else { - rc = -1; - } - // If we're set up as auto-remove, then remove this thread from the thread group - if (pThis->is_auto_remove()) { - pThis->owner_->remove_thread(pThis); - } - // Release this thread shared pinter. May or may not be the last reference. - pThis.reset(); - } else { - LOG(ERROR) << "Null pThis thread pointer"; - rc = EINVAL; - } - return rc; -} - -template -inline bool ThreadGroup::Thread::launch(std::shared_ptr pThis, - bool autoRemove, - StartFunction start_function, - Args ...args) { - WriteLock guard(pThis->thread_mutex_); - CHECK_EQ(!pThis->thread_.load(), true); - CHECK_NOTNULL(pThis->owner_); - // Set auto remove - pThis->auto_remove_ = autoRemove; - // Create the actual stl thread object - pThis->thread_ = new std::thread(Thread::template entry_and_exit_f< - StartFunction, Args...>, - pThis, - start_function, - args...); - // Attempt to add the thread to the thread group (after started, since in case - // something goes wrong, there's not a zombie thread in the thread group) - if (!pThis->owner_->add_thread(pThis)) { - pThis->request_shutdown(); - LOG(ERROR) << "Duplicate thread name within the same thread group is not allowed"; - } - // Wait for the thread to spin up - pThis->ready_event_->wait(); - // Signal the thgread to continue (it will check its shutdown status) - pThis->start_event_->signal(); - // Return if successful - return pThis->thread_.load() != nullptr; -} - -/*! - * \brief Utility function to easily create a timer - * \tparam Duration Duration type (i.e. std::chrono::milliseconds) - * \tparam TimerFunction Function to call each time the timer expires - * \param timer_name Name of the timer. Must be unique per ThreadGroup object - * \param duration Duration of the timer between calls to timer_function - * \param owner ThreadGroup owner of the timer - * \param timer_function Function to call each time the timer expires - * \return true if the timer was successfully created - */ -template -inline bool CreateTimer(const std::string& timer_name, - const Duration& duration, - ThreadGroup *owner, - TimerFunction timer_function) { - std::shared_ptr> timer_thread = - std::make_shared>(timer_name, owner); - dmlc::TimerThread::start(timer_thread, duration, timer_function); - return timer_thread != nullptr; -} -} // namespace dmlc - -#endif // DMLC_THREAD_GROUP_H_ diff --git a/include/dmlc/thread_local.h b/include/dmlc/thread_local.h deleted file mode 100644 index fecaef8686de..000000000000 --- a/include/dmlc/thread_local.h +++ /dev/null @@ -1,83 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file thread_local.h - * \brief Portable thread local storage. - */ -#ifndef DMLC_THREAD_LOCAL_H_ -#define DMLC_THREAD_LOCAL_H_ - -#include -#include -#include -#include "./base.h" - -namespace dmlc { - -// macro hanlding for threadlocal variables -#ifdef __GNUC__ - #define MX_THREAD_LOCAL __thread -#elif __STDC_VERSION__ >= 201112L - #define MX_THREAD_LOCAL _Thread_local -#elif defined(_MSC_VER) - #define MX_THREAD_LOCAL __declspec(thread) -#endif - -#if DMLC_CXX11_THREAD_LOCAL == 0 -#pragma message("Warning: CXX11 thread_local is not formally supported") -#endif - -/*! - * \brief A threadlocal store to store threadlocal variables. - * Will return a thread local singleton of type T - * \tparam T the type we like to store - */ -template -class ThreadLocalStore { - public: - /*! \return get a thread local singleton */ - static T* Get() { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local T inst; - return &inst; -#else - static MX_THREAD_LOCAL T* ptr = nullptr; - if (ptr == nullptr) { - ptr = new T(); - Singleton()->RegisterDelete(ptr); - } - return ptr; -#endif - } - - private: - /*! \brief constructor */ - ThreadLocalStore() {} - /*! \brief destructor */ - ~ThreadLocalStore() { - for (size_t i = 0; i < data_.size(); ++i) { - delete data_[i]; - } - } - /*! \return singleton of the store */ - static ThreadLocalStore *Singleton() { - static ThreadLocalStore inst; - return &inst; - } - /*! - * \brief register str for internal deletion - * \param str the string pointer - */ - void RegisterDelete(T *str) { - std::unique_lock lock(mutex_); - data_.push_back(str); - lock.unlock(); - } - /*! \brief internal mutex */ - std::mutex mutex_; - /*!\brief internal data */ - std::vector data_; -}; - -} // namespace dmlc - -#endif // DMLC_THREAD_LOCAL_H_ diff --git a/include/dmlc/threadediter.h b/include/dmlc/threadediter.h deleted file mode 100644 index c920156b2331..000000000000 --- a/include/dmlc/threadediter.h +++ /dev/null @@ -1,475 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file threadediter.h - * \brief thread backed iterator that can be used to implement - * general thread-based pipeline such as prefetch and pre-computation - * To use the functions in this header, C++11 is required - * \author Tianqi Chen - */ -#ifndef DMLC_THREADEDITER_H_ -#define DMLC_THREADEDITER_H_ -// defines DMLC_USE_CXX11 -#include "./base.h" -// this code depends on c++11 -#if DMLC_ENABLE_STD_THREAD -#include -#include -#include -#include -#include -#include "./data.h" -#include "./logging.h" - -namespace dmlc { -/*! - * \brief a iterator that was backed by a thread - * to pull data eagerly from a single producer into a bounded buffer - * the consumer can pull the data at its own rate - * - * NOTE: thread concurrency cost time, make sure to store big blob of data in DType - * - * Usage example: - * \code - * ThreadedIter iter; - * iter.Init(&producer); - * // the following code can be in parallel - * DType *dptr; - * while (iter.Next(&dptr)) { - * // do something on dptr - * // recycle the space - * iter.Recycle(&dptr); - * } - * \endcode - * \tparam DType the type of data blob we support - */ -template -class ThreadedIter : public DataIter { - public: - /*! - * \brief producer class interface - * that threaditer used as source to - * preduce the content - */ - class Producer { - public: - // virtual destructor - virtual ~Producer() {} - /*! \brief reset the producer to beginning */ - virtual void BeforeFirst(void) { - NotImplemented(); - } - /*! - * \brief load the data content into DType, - * the caller can pass in NULL or an existing address - * when inout_dptr is NULL: - * producer need to allocate a DType and fill the content - * when inout_dptr is specified - * producer takes need to fill the content into address - * specified inout_dptr, or delete the one and create a new one - * - * \param inout_dptr used to pass in the data holder cell - * and return the address of the cell filled - * \return true if there is next record, false if we reach the end - */ - virtual bool Next(DType **inout_dptr) = 0; - }; - /*! - * \brief constructor - * \param max_capacity maximum capacity of the queue - */ - explicit ThreadedIter(size_t max_capacity = 8) - : producer_owned_(NULL), - producer_thread_(NULL), - max_capacity_(max_capacity), - nwait_consumer_(0), - nwait_producer_(0), - out_data_(NULL) {} - /*! \brief destructor */ - virtual ~ThreadedIter(void) { - this->Destroy(); - } - /*! - * \brief destroy all the related resources - * this is equivalent to destructor, can be used - * to destroy the threaditer when user think it is - * appropriate, it is safe to call this multiple times - */ - inline void Destroy(void); - /*! - * \brief set maximum capacity of the queue - * \param max_capacity maximum capacity of the queue - */ - inline void set_max_capacity(size_t max_capacity) { - max_capacity_ = max_capacity; - } - /*! - * \brief initialize the producer and start the thread - * can only be called once - * \param producer pointer to the producer - * \param pass_ownership whether pass the ownership to the iter - * if this is true, the threaditer will delete the producer - * when destructed - */ - inline void Init(Producer *producer, bool pass_ownership = false); - /*! - * \brief initialize the producer and start the thread - * pass in two function(closure) of producer to represent the producer - * the beforefirst function is optional, and defaults to not implemented - * NOTE: the closure must remain valid until the ThreadedIter destructs - * \param next the function called to get next element, see Producer.Next - * \param beforefirst the function to call to reset the producer, see Producer.BeforeFirst - */ - inline void Init(std::function next, - std::function beforefirst = NotImplemented); - /*! - * \brief get the next data, this function is threadsafe - * \param out_dptr used to hold the pointer to the record - * after the function call, the caller takes ownership of the pointer - * the caller can call recycle to return ownership back to the threaditer - * so that the pointer can be re-used - * \return true if there is next record, false if we reach the end - * \sa Recycle - */ - inline bool Next(DType **out_dptr); - /*! - * \brief recycle the data cell, this function is threadsafe - * the threaditer can reuse the data cell for future data loading - * \param inout_dptr pointer to the dptr to recycle, after the function call - * the content of inout_dptr will be set to NULL - */ - inline void Recycle(DType **inout_dptr); - - /*! - * \brief Rethrows exception which is set by the producer - */ - inline void ThrowExceptionIfSet(void); - - /*! - * \brief clears exception_ptr, called from Init - */ - inline void ClearException(void); - - /*! - * \brief adapt the iterator interface's Next - * NOTE: the call to this function is not threadsafe - * use the other Next instead - * \return true if there is next record, false if we reach the end - */ - virtual bool Next(void) { - if (out_data_ != NULL) { - this->Recycle(&out_data_); - } - if (Next(&out_data_)) { - return true; - } else { - return false; - } - } - /*! - * \brief adapt the iterator interface's Value - * NOTE: the call to this function is not threadsafe - * use the other Next instead - */ - virtual const DType &Value(void) const { - CHECK(out_data_ != NULL) << "Calling Value at beginning or end?"; - return *out_data_; - } - /*! \brief set the iterator before first location */ - virtual void BeforeFirst(void) { - ThrowExceptionIfSet(); - std::unique_lock lock(mutex_); - if (out_data_ != NULL) { - free_cells_.push(out_data_); - out_data_ = NULL; - } - if (producer_sig_ == kDestroy) return; - - producer_sig_ = kBeforeFirst; - CHECK(!producer_sig_processed_); - if (nwait_producer_ != 0) { - producer_cond_.notify_one(); - } - CHECK(!producer_sig_processed_); - // wait until the request has been processed - consumer_cond_.wait(lock, [this]() { - return producer_sig_processed_; - }); - producer_sig_processed_ = false; - bool notify = nwait_producer_ != 0 && !produce_end_; - lock.unlock(); - // notify producer, in case they are waiting for the condition. - if (notify) producer_cond_.notify_one(); - ThrowExceptionIfSet(); - } - - private: - /*! \brief not support BeforeFirst */ - inline static void NotImplemented(void) { - LOG(FATAL) << "BeforeFirst is not supported"; - } - /*! \brief signals send to producer */ - enum Signal { - kProduce, - kBeforeFirst, - kDestroy - }; - /*! \brief producer class */ - Producer *producer_owned_; - /*! \brief signal to producer */ - Signal producer_sig_; - /*! \brief whether the special signal other than kProduce is procssed */ - bool producer_sig_processed_; - /*! \brief thread that runs the producer */ - std::thread *producer_thread_; - /*! \brief whether produce ends */ - bool produce_end_; - /*! \brief maximum queue size */ - size_t max_capacity_; - /*! \brief internal mutex */ - std::mutex mutex_; - /*! brief internal mutex for exceptions */ - std::mutex mutex_exception_; - /*! \brief number of consumer waiting */ - unsigned nwait_consumer_; - /*! \brief number of consumer waiting */ - unsigned nwait_producer_; - /*! \brief conditional variable for producer thread */ - std::condition_variable producer_cond_; - /*! \brief conditional variable for consumer threads */ - std::condition_variable consumer_cond_; - /*! \brief the current output cell */ - DType *out_data_; - /*! \brief internal queue of producer */ - std::queue queue_; - /*! \brief free cells that can be used */ - std::queue free_cells_; - /*! \brief holds a reference to iterator exception thrown in spawned threads */ - std::exception_ptr iter_exception_{nullptr}; -}; - -// implementation of functions -template inline void ThreadedIter::Destroy(void) { - if (producer_thread_ != NULL) { - { - // lock the mutex - std::lock_guard lock(mutex_); - // send destroy signal - producer_sig_ = kDestroy; - if (nwait_producer_ != 0) { - producer_cond_.notify_one(); - } - } - producer_thread_->join(); - delete producer_thread_; - producer_thread_ = NULL; - } - // end of critical region - // now the slave thread should exit - while (free_cells_.size() != 0) { - delete free_cells_.front(); - free_cells_.pop(); - } - while (queue_.size() != 0) { - delete queue_.front(); - queue_.pop(); - } - if (producer_owned_ != NULL) { - delete producer_owned_; - } - if (out_data_ != NULL) { - delete out_data_; - out_data_ = NULL; - } -} - -template -inline void ThreadedIter:: -Init(Producer *producer, bool pass_ownership) { - CHECK(producer_owned_ == NULL) << "can only call Init once"; - if (pass_ownership) producer_owned_ = producer; - auto next = [producer](DType **dptr) { - return producer->Next(dptr); - }; - auto beforefirst = [producer]() { - producer->BeforeFirst(); - }; - this->Init(next, beforefirst); -} - -template -inline void ThreadedIter::Init(std::function next, - std::function beforefirst) { - producer_sig_ = kProduce; - producer_sig_processed_ = false; - produce_end_ = false; - ClearException(); - // procedure running in prodcuer - // run producer thread - auto producer_fun = [this, next, beforefirst]() { - while (true) { - try { - DType *cell = NULL; - { - // lockscope - std::unique_lock lock(mutex_); - ++this->nwait_producer_; - producer_cond_.wait(lock, [this]() { - if (producer_sig_ == kProduce) { - bool ret = !produce_end_ && (queue_.size() < max_capacity_ || - free_cells_.size() != 0); - return ret; - } else { - return true; - } - }); - --this->nwait_producer_; - if (producer_sig_ == kProduce) { - if (free_cells_.size() != 0) { - cell = free_cells_.front(); - free_cells_.pop(); - } - } else if (producer_sig_ == kBeforeFirst) { - // reset the producer - beforefirst(); - // cleanup the queue - while (queue_.size() != 0) { - free_cells_.push(queue_.front()); - queue_.pop(); - } - // reset the state - produce_end_ = false; - producer_sig_processed_ = true; - producer_sig_ = kProduce; - // notify consumer that all the process as been done. - lock.unlock(); - consumer_cond_.notify_all(); - continue; - } else { - // destroy the thread - DCHECK(producer_sig_ == kDestroy); - producer_sig_processed_ = true; - produce_end_ = true; - consumer_cond_.notify_all(); - return; - } - } // end of lock scope - // now without lock - produce_end_ = !next(&cell); - DCHECK(cell != NULL || produce_end_); - bool notify; - { - // lockscope - std::lock_guard lock(mutex_); - if (!produce_end_) { - queue_.push(cell); - } else { - if (cell != NULL) - free_cells_.push(cell); - } - // put things into queue - notify = nwait_consumer_ != 0; - } - if (notify) - consumer_cond_.notify_all(); - } catch (dmlc::Error &e) { - // Shouldn't throw exception in destructor - DCHECK(producer_sig_ != kDestroy); - { - std::lock_guard lock(mutex_exception_); - if (!iter_exception_) { - iter_exception_ = std::current_exception(); - } - } - bool next_notify = false; - { - std::unique_lock lock(mutex_); - if (producer_sig_ == kBeforeFirst) { - while (queue_.size() != 0) { - free_cells_.push(queue_.front()); - queue_.pop(); - } - produce_end_ = true; - producer_sig_processed_ = true; - lock.unlock(); - consumer_cond_.notify_all(); - } else if (producer_sig_ == kProduce) { - produce_end_ = true; - next_notify = nwait_consumer_ != 0; - lock.unlock(); - if (next_notify) - consumer_cond_.notify_all(); - } - } - return; - } - } - }; - producer_thread_ = new std::thread(producer_fun); -} - -template -inline bool ThreadedIter::Next(DType **out_dptr) { - if (producer_sig_ == kDestroy) - return false; - ThrowExceptionIfSet(); - std::unique_lock lock(mutex_); - CHECK(producer_sig_ == kProduce) - << "Make sure you call BeforeFirst not inconcurrent with Next!"; - ++nwait_consumer_; - consumer_cond_.wait(lock, - [this]() { return queue_.size() != 0 || produce_end_; }); - --nwait_consumer_; - if (queue_.size() != 0) { - *out_dptr = queue_.front(); - queue_.pop(); - bool notify = nwait_producer_ != 0 && !produce_end_; - lock.unlock(); - if (notify) - producer_cond_.notify_one(); - - ThrowExceptionIfSet(); - return true; - } else { - CHECK(produce_end_); - lock.unlock(); - - ThrowExceptionIfSet(); - return false; - } -} - -template -inline void ThreadedIter::Recycle(DType **inout_dptr) { - bool notify; - ThrowExceptionIfSet(); - { - std::lock_guard lock(mutex_); - free_cells_.push(*inout_dptr); - *inout_dptr = NULL; - notify = nwait_producer_ != 0 && !produce_end_; - } - if (notify) - producer_cond_.notify_one(); - ThrowExceptionIfSet(); -} - -template inline void ThreadedIter::ThrowExceptionIfSet(void) { - std::exception_ptr tmp_exception{nullptr}; - { - std::lock_guard lock(mutex_exception_); - if (iter_exception_) { - tmp_exception = iter_exception_; - } - } - if (tmp_exception) - std::rethrow_exception(tmp_exception); -} - -template inline void ThreadedIter::ClearException(void) { - std::lock_guard lock(mutex_exception_); - iter_exception_ = nullptr; -} - -} // namespace dmlc -#endif // DMLC_USE_CXX11 -#endif // DMLC_THREADEDITER_H_ diff --git a/include/dmlc/timer.h b/include/dmlc/timer.h deleted file mode 100644 index c97059f97812..000000000000 --- a/include/dmlc/timer.h +++ /dev/null @@ -1,49 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file timer.h - * \brief cross platform timer for timing - * \author Tianqi Chen - */ -#ifndef DMLC_TIMER_H_ -#define DMLC_TIMER_H_ - -#include "base.h" - -#if DMLC_USE_CXX11 -#include -#endif - -#include -#ifdef __MACH__ -#include -#include -#endif -#include "./logging.h" - -namespace dmlc { -/*! - * \brief return time in seconds - */ -inline double GetTime(void) { - #if DMLC_USE_CXX11 - return std::chrono::duration( - std::chrono::high_resolution_clock::now().time_since_epoch()).count(); - #elif defined __MACH__ - clock_serv_t cclock; - mach_timespec_t mts; - host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock); - CHECK(clock_get_time(cclock, &mts) == 0) << "failed to get time"; - mach_port_deallocate(mach_task_self(), cclock); - return static_cast(mts.tv_sec) + static_cast(mts.tv_nsec) * 1e-9; - #else - #if defined(__unix__) || defined(__linux__) - timespec ts; - CHECK(clock_gettime(CLOCK_REALTIME, &ts) == 0) << "failed to get time"; - return static_cast(ts.tv_sec) + static_cast(ts.tv_nsec) * 1e-9; - #else - return static_cast(time(NULL)); - #endif - #endif -} -} // namespace dmlc -#endif // DMLC_TIMER_H_ diff --git a/include/dmlc/type_traits.h b/include/dmlc/type_traits.h deleted file mode 100644 index c528903499e3..000000000000 --- a/include/dmlc/type_traits.h +++ /dev/null @@ -1,191 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file type_traits.h - * \brief type traits information header - */ -#ifndef DMLC_TYPE_TRAITS_H_ -#define DMLC_TYPE_TRAITS_H_ - -#include "./base.h" -#if DMLC_USE_CXX11 -#include -#endif -#include - -namespace dmlc { -/*! - * \brief whether a type is pod type - * \tparam T the type to query - */ -template -struct is_pod { -#if DMLC_USE_CXX11 - /*! \brief the value of the traits */ - static const bool value = std::is_pod::value; -#else - /*! \brief the value of the traits */ - static const bool value = false; -#endif -}; - - -/*! - * \brief whether a type is integer type - * \tparam T the type to query - */ -template -struct is_integral { -#if DMLC_USE_CXX11 - /*! \brief the value of the traits */ - static const bool value = std::is_integral::value; -#else - /*! \brief the value of the traits */ - static const bool value = false; -#endif -}; - -/*! - * \brief whether a type is floating point type - * \tparam T the type to query - */ -template -struct is_floating_point { -#if DMLC_USE_CXX11 - /*! \brief the value of the traits */ - static const bool value = std::is_floating_point::value; -#else - /*! \brief the value of the traits */ - static const bool value = false; -#endif -}; - -/*! - * \brief whether a type is arithemetic type - * \tparam T the type to query - */ -template -struct is_arithmetic { -#if DMLC_USE_CXX11 - /*! \brief the value of the traits */ - static const bool value = std::is_arithmetic::value; -#else - /*! \brief the value of the traits */ - static const bool value = (dmlc::is_integral::value || - dmlc::is_floating_point::value); -#endif -}; - -/*! - * \brief helper class to construct a string that represents type name - * - * Specialized this class to defined type name of custom types - * - * \tparam T the type to query - */ -template -struct type_name_helper { - /*! - * \return a string of typename. - */ - static inline std::string value() { - return ""; - } -}; - -/*! - * \brief the string representation of type name - * \tparam T the type to query - * \return a const string of typename. - */ -template -inline std::string type_name() { - return type_name_helper::value(); -} - -/*! - * \brief whether a type have save/load function - * \tparam T the type to query - */ -template -struct has_saveload { - /*! \brief the value of the traits */ - static const bool value = false; -}; - -/*! - * \brief template to select type based on condition - * For example, IfThenElseType::Type will give int - * \tparam cond the condition - * \tparam Then the typename to be returned if cond is true - * \tparam Else typename to be returned if cond is false -*/ -template -struct IfThenElseType; - -/*! \brief macro to quickly declare traits information */ -#define DMLC_DECLARE_TRAITS(Trait, Type, Value) \ - template<> \ - struct Trait { \ - static const bool value = Value; \ - } - -/*! \brief macro to quickly declare traits information */ -#define DMLC_DECLARE_TYPE_NAME(Type, Name) \ - template<> \ - struct type_name_helper { \ - static inline std::string value() { \ - return Name; \ - } \ - } - -//! \cond Doxygen_Suppress -// declare special traits when C++11 is not available -#if DMLC_USE_CXX11 == 0 -DMLC_DECLARE_TRAITS(is_pod, char, true); -DMLC_DECLARE_TRAITS(is_pod, int8_t, true); -DMLC_DECLARE_TRAITS(is_pod, int16_t, true); -DMLC_DECLARE_TRAITS(is_pod, int32_t, true); -DMLC_DECLARE_TRAITS(is_pod, int64_t, true); -DMLC_DECLARE_TRAITS(is_pod, uint8_t, true); -DMLC_DECLARE_TRAITS(is_pod, uint16_t, true); -DMLC_DECLARE_TRAITS(is_pod, uint32_t, true); -DMLC_DECLARE_TRAITS(is_pod, uint64_t, true); -DMLC_DECLARE_TRAITS(is_pod, float, true); -DMLC_DECLARE_TRAITS(is_pod, double, true); - -DMLC_DECLARE_TRAITS(is_integral, char, true); -DMLC_DECLARE_TRAITS(is_integral, int8_t, true); -DMLC_DECLARE_TRAITS(is_integral, int16_t, true); -DMLC_DECLARE_TRAITS(is_integral, int32_t, true); -DMLC_DECLARE_TRAITS(is_integral, int64_t, true); -DMLC_DECLARE_TRAITS(is_integral, uint8_t, true); -DMLC_DECLARE_TRAITS(is_integral, uint16_t, true); -DMLC_DECLARE_TRAITS(is_integral, uint32_t, true); -DMLC_DECLARE_TRAITS(is_integral, uint64_t, true); - -DMLC_DECLARE_TRAITS(is_floating_point, float, true); -DMLC_DECLARE_TRAITS(is_floating_point, double, true); - -#endif - -DMLC_DECLARE_TYPE_NAME(float, "float"); -DMLC_DECLARE_TYPE_NAME(double, "double"); -DMLC_DECLARE_TYPE_NAME(int, "int"); -DMLC_DECLARE_TYPE_NAME(uint32_t, "int (non-negative)"); -DMLC_DECLARE_TYPE_NAME(uint64_t, "long (non-negative)"); -DMLC_DECLARE_TYPE_NAME(std::string, "string"); -DMLC_DECLARE_TYPE_NAME(bool, "boolean"); -DMLC_DECLARE_TYPE_NAME(void*, "ptr"); - -template -struct IfThenElseType { - typedef Then Type; -}; - -template -struct IfThenElseType { - typedef Else Type; -}; -//! \endcond -} // namespace dmlc -#endif // DMLC_TYPE_TRAITS_H_ diff --git a/include/mshadow b/include/mshadow new file mode 120000 index 000000000000..0ff1a4b9e3b4 --- /dev/null +++ b/include/mshadow @@ -0,0 +1 @@ +../3rdparty/mshadow/mshadow \ No newline at end of file diff --git a/include/mshadow/README.md b/include/mshadow/README.md deleted file mode 100644 index 86276af013e2..000000000000 --- a/include/mshadow/README.md +++ /dev/null @@ -1,8 +0,0 @@ -Code Guide -==== -This readme contains notes about code in mshadow. MShadow generally follows Google's C++ Style. - -Convention -==== -* Basically, all the files ends in ```-inl.h, -inl.cuh``` are implementations, and can be ignored if only using mshadow -* The files ends in ```.h``` are heavily commented with [doxyen format](http://www.doxygen.org/), and can be used to generate the corresponding document. diff --git a/include/mshadow/base.h b/include/mshadow/base.h deleted file mode 100755 index 4cdab74d6a74..000000000000 --- a/include/mshadow/base.h +++ /dev/null @@ -1,1106 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file base.h - * \brief definitions of base types, operators, macros functions - * - * \author Bing Xu, Tianqi Chen - */ -#ifndef MSHADOW_BASE_H_ -#define MSHADOW_BASE_H_ -#ifdef _MSC_VER -#ifndef _CRT_SECURE_NO_WARNINGS -#define _CRT_SECURE_NO_WARNINGS -#endif -#ifndef _CRT_SECURE_NO_DEPRECATE -#define _CRT_SECURE_NO_DEPRECATE -#endif -#ifndef NOMINMAX -#define NOMINMAX -#endif -#endif -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef _MSC_VER -//! \cond Doxygen_Suppress -typedef signed char int8_t; -typedef __int16 int16_t; -typedef __int32 int32_t; -typedef __int64 int64_t; -typedef unsigned char uint8_t; -typedef unsigned __int16 uint16_t; -typedef unsigned __int32 uint32_t; -typedef unsigned __int64 uint64_t; -//! \endcond -#else -#include -#endif -// macro defintiions -/*! - * \brief if this macro is define to be 1, - * mshadow should compile without any of other libs - */ -#ifndef MSHADOW_STAND_ALONE -#define MSHADOW_STAND_ALONE 0 -#endif -/*! \brief whether do padding during allocation */ -#ifndef MSHADOW_ALLOC_PAD -#define MSHADOW_ALLOC_PAD true -#endif -/*! - * \brief - * x dimension of data must be bigger pad_size * ratio to be alloced padded memory, - * otherwise use tide allocation - * for example, if pad_ratio=2, GPU memory alignement size is 32, - * then we will only allocate padded memory if x dimension > 64 - * set it to 0 then we will always allocate padded memory - */ -#ifndef MSHADOW_MIN_PAD_RATIO - #define MSHADOW_MIN_PAD_RATIO 2 -#endif - -#if MSHADOW_STAND_ALONE - #define MSHADOW_USE_CBLAS 0 - #define MSHADOW_USE_MKL 0 - #define MSHADOW_USE_CUDA 0 -#endif - -/*! - * \brief force user to use GPU stream during computation - * error will be shot when default stream NULL is used - */ -#ifndef MSHADOW_FORCE_STREAM -#define MSHADOW_FORCE_STREAM 1 -#endif - -/*! \brief use CBLAS for CBLAS */ -#ifndef MSHADOW_USE_CBLAS - #define MSHADOW_USE_CBLAS 0 -#endif -/*! \brief use MKL for BLAS */ -#ifndef MSHADOW_USE_MKL - #define MSHADOW_USE_MKL 1 -#endif - -/*! - * \brief use CUDA support, must ensure that the cuda include path is correct, - * or directly compile using nvcc - */ -#ifndef MSHADOW_USE_CUDA - #define MSHADOW_USE_CUDA 1 -#endif - -/*! - * \brief use CUDNN support, must ensure that the cudnn include path is correct - */ -#ifndef MSHADOW_USE_CUDNN - #define MSHADOW_USE_CUDNN 0 -#endif - -/*! - * \brief use CUSOLVER support - */ -#ifndef MSHADOW_USE_CUSOLVER - #define MSHADOW_USE_CUSOLVER MSHADOW_USE_CUDA -#endif - -/*! - * \brief seems CUDAARCH is deprecated in future NVCC - * set this to 1 if you want to use CUDA version smaller than 2.0 - */ -#ifndef MSHADOW_OLD_CUDA -#define MSHADOW_OLD_CUDA 0 -#endif - -/*! - * \brief macro to decide existence of c++11 compiler - */ -#ifndef MSHADOW_IN_CXX11 - #if (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\ - __cplusplus >= 201103L || defined(_MSC_VER)) - #define MSHADOW_IN_CXX11 1 - #else - #define MSHADOW_IN_CXX11 0 - #endif -#endif - -/*! \brief whether use SSE */ -#ifndef MSHADOW_USE_SSE - #define MSHADOW_USE_SSE 1 -#endif - -/*! \brief whether use F16C instruction set architecture extension */ -#ifndef MSHADOW_USE_F16C - #if defined(_MSC_VER) || defined(__CUDACC__) - #define MSHADOW_USE_F16C 0 - #elif defined(__clang__) && \ - ((__clang_major__ < 8) || ((__clang_major__ == 8) && (__clang_minor__ < 1))) - #define MSHADOW_USE_F16C 0 - #else - #define MSHADOW_USE_F16C 1 - #endif -#endif - -/*! \brief whether use NVML to get dynamic info */ -#ifndef MSHADOW_USE_NVML - #define MSHADOW_USE_NVML 0 -#endif -// SSE is conflict with cudacc -#ifdef __CUDACC__ - #undef MSHADOW_USE_SSE - #define MSHADOW_USE_SSE 0 -#endif - -#if MSHADOW_USE_CBLAS -extern "C" { - #include -} -#elif MSHADOW_USE_MKL - #include - #include - #include - #include - #include -#endif - -#if MSHADOW_USE_CUDA - #include - #include - #include -#endif - -#if MSHADOW_USE_CUDNN == 1 - #include -#endif - -#if MSHADOW_USE_CUSOLVER == 1 - #include -#endif - -#if MSHADOW_USE_NVML - #include -#endif - -// -------------------------------- -// MSHADOW_XINLINE is used for inlining template code for both CUDA and CPU code -#ifdef MSHADOW_XINLINE - #error "MSHADOW_XINLINE must not be defined" -#endif -#ifdef _MSC_VER -#define MSHADOW_FORCE_INLINE __forceinline -#pragma warning(disable : 4068) -#else -#define MSHADOW_FORCE_INLINE inline __attribute__((always_inline)) -#endif -#ifdef __CUDACC__ - #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE __device__ __host__ -#else - #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE -#endif -/*! \brief cpu force inline */ -#define MSHADOW_CINLINE MSHADOW_FORCE_INLINE - -#if defined(__GXX_EXPERIMENTAL_CXX0X) ||\ - defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L - #define MSHADOW_CONSTEXPR constexpr -#else - #define MSHADOW_CONSTEXPR const -#endif - -/*! - * \brief default data type for tensor string - * in code release, change it to default_real_t - * during development, change it to empty string so that missing - * template arguments can be detected - */ -#ifndef MSHADOW_DEFAULT_DTYPE -#define MSHADOW_DEFAULT_DTYPE = ::mshadow::default_real_t -#endif - -/*! - * \brief DMLC marco for logging - */ -#ifndef MSHADOW_USE_GLOG -#define MSHADOW_USE_GLOG DMLC_USE_GLOG -#endif // MSHADOW_USE_GLOG - -#if DMLC_USE_CXX11 -#define MSHADOW_THROW_EXCEPTION noexcept(false) -#define MSHADOW_NO_EXCEPTION noexcept(true) -#else -#define MSHADOW_THROW_EXCEPTION -#define MSHADOW_NO_EXCEPTION -#endif - -#if defined(_MSC_VER) -#define MSHADOW_ALIGNED(x) __declspec(align(x)) -#else -#define MSHADOW_ALIGNED(x) __attribute__ ((aligned(x))) -#endif - -/*! - * \brief Protected cuda call in mshadow - * \param func Expression to call. - * It checks for CUDA errors after invocation of the expression. - */ -#define MSHADOW_CUDA_CALL(func) \ - { \ - cudaError_t e = (func); \ - if (e == cudaErrorCudartUnloading) { \ - throw dmlc::Error(cudaGetErrorString(e)); \ - } \ - CHECK(e == cudaSuccess) \ - << "CUDA: " << cudaGetErrorString(e); \ - } - -/*! - * \brief Run function and catch error, log unknown error. - * \param func Expression to call. - */ -#define MSHADOW_CATCH_ERROR(func) \ - { \ - try { \ - (func); \ - } catch (const dmlc::Error &e) { \ - std::string what = e.what(); \ - if (what.find("driver shutting down") == std::string::npos) { \ - LOG(ERROR) << "Ignore CUDA Error " << what; \ - } \ - } \ - } - -#include "./half.h" -#include "./half2.h" -#include "./logging.h" -/*! \brief namespace for mshadow */ -namespace mshadow { -/*! \brief buffer size for each random number generator */ -const unsigned kRandBufferSize = 1000000; -/*! \brief pi */ -const float kPi = 3.1415926f; -/*! \brief type that will be used for index */ -typedef int64_t index_t; - -#ifdef _WIN32 - /*! \brief openmp index for windows */ - typedef int64_t openmp_index_t; -#else - /*! \brief openmp index for linux */ - typedef index_t openmp_index_t; -#endif - -/*! \brief float point type that will be used in default by mshadow */ -typedef float default_real_t; - -/*! \brief data type flag */ -enum TypeFlag { - kFloat32 = 0, - kFloat64 = 1, - kFloat16 = 2, - kUint8 = 3, - kInt32 = 4, - kInt8 = 5, - kInt64 = 6, -}; - -template -struct DataType; - -template<> -struct DataType { - static const int kFlag = kFloat32; - static const int kLanes = 1; -#if MSHADOW_USE_CUDA -#if (CUDA_VERSION >= 8000) - static const cudaDataType_t kCudaFlag = CUDA_R_32F; -#endif -#if MSHADOW_USE_CUDNN - static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_FLOAT; - typedef float ScaleType; -#endif -#endif -}; -template<> -struct DataType { - static const int kFlag = kFloat64; - static const int kLanes = 1; -#if MSHADOW_USE_CUDA -#if (CUDA_VERSION >= 8000) - static const cudaDataType_t kCudaFlag = CUDA_R_64F; -#endif -#if MSHADOW_USE_CUDNN - static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_DOUBLE; - typedef double ScaleType; -#endif -#endif -}; -template<> -struct DataType { - static const int kFlag = kFloat16; - static const int kLanes = 1; -#if MSHADOW_USE_CUDA -#if (CUDA_VERSION >= 8000) - static const cudaDataType_t kCudaFlag = CUDA_R_16F; -#endif -#if MSHADOW_USE_CUDNN - static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_HALF; - typedef float ScaleType; -#endif -#endif -}; -template<> -struct DataType { - static const int kFlag = kFloat16; - static const int kLanes = 2; -}; -template<> -struct DataType { - static const int kFlag = kUint8; - static const int kLanes = 1; -#if MSHADOW_USE_CUDA -#if (CUDA_VERSION >= 8000) - static const cudaDataType_t kCudaFlag = CUDA_R_8U; -#endif -#if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6) - // no uint8 in cudnn for now - static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8; - typedef uint8_t ScaleType; -#endif -#endif -}; -template<> -struct DataType { - static const int kFlag = kInt8; - static const int kLanes = 1; -#if MSHADOW_USE_CUDA -#if (CUDA_VERSION >= 8000) - static const cudaDataType_t kCudaFlag = CUDA_R_8I; -#endif -#if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6) - static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8; - typedef int8_t ScaleType; -#endif -#endif -}; -template<> -struct DataType { - static const int kFlag = kInt32; - static const int kLanes = 1; -#if MSHADOW_USE_CUDA -#if (CUDA_VERSION >= 8000) - static const cudaDataType_t kCudaFlag = CUDA_R_32I; -#endif -#if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6) - static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT32; - typedef int32_t ScaleType; -#endif -#endif -}; -template<> -struct DataType { - static const int kFlag = kInt64; - static const int kLanes = 1; -}; - -/*! \brief type enum value for default real type */ -const int default_type_flag = DataType::kFlag; - -/*! layout flag */ -enum LayoutFlag { - kNCHW = 0, - kNHWC, - kCHWN, - - kNCW = 1 << 3, - kNWC, - kCWN, - - kNCDHW = 1 << 5, - kNDHWC, - kCDHWN -}; - -template -struct LayoutType; - -template<> -struct LayoutType { - static const index_t kNdim = 4; -#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) - static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW; -#else - static const int kCudnnFlag = -1; -#endif -}; - -template<> -struct LayoutType { - static const index_t kNdim = 4; -#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) - static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC; -#else - static const int kCudnnFlag = -1; -#endif -}; - -/*! \brief default layout for 4d tensor */ -const int default_layout = kNCHW; - -template<> -struct LayoutType { - static const index_t kNdim = 5; -#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) - static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW; -#else - static const int kCudnnFlag = -1; -#endif -}; - -template<> -struct LayoutType { - static const index_t kNdim = 5; -#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) - static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC; -#else - static const int kCudnnFlag = -1; -#endif -}; - -/*! \brief default layout for 5d tensor */ -const int default_layout_5d = kNCDHW; - -/*! \brief namespace for operators */ -namespace op { -// binary operator -/*! \brief mul operator */ -struct mul{ - /*! \brief map a, b to result using defined operation */ - template - MSHADOW_XINLINE static DType Map(DType a, DType b) { - return a * b; - } -}; -/*! \brief plus operator */ -struct plus { - /*! \brief map a, b to result using defined operation */ - template - MSHADOW_XINLINE static DType Map(DType a, DType b) { - return a + b; - } -}; -/*! \brief minus operator */ -struct minus { - /*! \brief map a, b to result using defined operation */ - template - MSHADOW_XINLINE static DType Map(DType a, DType b) { - return a - b; - } -}; -/*! \brief divide operator */ -struct div { - /*! \brief map a, b to result using defined operation */ - template - MSHADOW_XINLINE static DType Map(DType a, DType b) { - return a / b; - } -}; -/*! \brief get rhs */ -struct right { - /*! \brief map a, b to result using defined operation */ - template - MSHADOW_XINLINE static DType Map(DType a, DType b) { - return b; - } -}; -// unary operator/ function: example -// these operators can be defined by user, -// in the same style as binary and unary operator -// to use, simply write F( src ) -/*! \brief identity function that maps a real number to it self */ -struct identity{ - /*! \brief map a to result using defined operation */ - template - MSHADOW_XINLINE static DType Map(DType a) { - return a; - } -}; -} // namespace op -/*! \brief namespace for savers */ -namespace sv { -/*! \brief save to saver: = */ -struct saveto { - /*! \brief save b to a using save method */ - template - MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) - a = b; - } - /*! \brief helper constant to use BLAS, alpha */ - inline static default_real_t AlphaBLAS(void) { return 1.0f; } - /*! \brief helper constant to use BLAS, beta */ - inline static default_real_t BetaBLAS(void) { return 0.0f; } - /*! \brief corresponding binary operator type */ - typedef op::right OPType; -}; -/*! \brief save to saver: += */ -struct plusto { - /*! \brief save b to a using save method */ - template - MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) - a += b; - } - /*! \brief helper constant to use BLAS, alpha */ - inline static default_real_t AlphaBLAS(void) { return 1.0f; } - /*! \brief helper constant to use BLAS, beta */ - inline static default_real_t BetaBLAS(void) { return 1.0f; } - /*! \brief corresponding binary operator type */ - typedef op::plus OPType; -}; -/*! \brief minus to saver: -= */ -struct minusto { - /*! \brief save b to a using save method */ - template - MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) - a -= b; - } - /*! \brief helper constant to use BLAS, alpha */ - inline static default_real_t AlphaBLAS(void) { return -1.0f; } - /*! \brief helper constant to use BLAS, beta */ - inline static default_real_t BetaBLAS(void) { return 1.0f; } - /*! \brief corresponding binary operator type */ - typedef op::minus OPType; -}; -/*! \brief multiply to saver: *= */ -struct multo { - /*! \brief save b to a using save method */ - template - MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*) - a *= b; - } - /*! \brief corresponding binary operator type */ - typedef op::mul OPType; -}; -/*! \brief divide to saver: /= */ -struct divto { - /*! \brief save b to a using save method */ - template - MSHADOW_XINLINE static void Save(DType& a, DType b) { // NOLINT(*) - a /= b; - } - /*! \brief corresponding binary operator type */ - typedef op::div OPType; -}; -} // namespace sv -/*! \brief namespace for potential reducer operations */ -namespace red { -namespace limits { -/*! - * \brief minimum value of certain types - * \tparam DType data type - */ -template -MSHADOW_XINLINE DType MinValue(void); -/*! \brief minimum value of float */ -template<> -MSHADOW_XINLINE float MinValue(void) { - return -FLT_MAX; -} -/*! \brief minimum value of double */ -template<> -MSHADOW_XINLINE double MinValue(void) { - return -DBL_MAX; -} -/*! \brief minimum value of half */ -template<> -MSHADOW_XINLINE half::half_t MinValue(void) { - return MSHADOW_HALF_MIN; -} -/*! \brief minimum value of uint8_t */ -template<> -MSHADOW_XINLINE uint8_t MinValue(void) { - return 0; -} -/*! \brief minimum value of int8_t */ -template<> -MSHADOW_XINLINE int8_t MinValue(void) { - return SCHAR_MIN; -} -/*! \brief minimum value of int32_t */ -template<> -MSHADOW_XINLINE int MinValue(void) { - return INT_MIN; -} -/*! \brief minimum value of int64_t */ -template<> -MSHADOW_XINLINE int64_t MinValue(void) { - return LLONG_MIN; -} - -/*! - * \brief maximum value of certain types - * \tparam DType data type - */ -template -MSHADOW_XINLINE DType MaxValue(void); -/*! \brief maximum value of float */ -template<> -MSHADOW_XINLINE float MaxValue(void) { - return FLT_MAX; -} -/*! \brief maximum value of double */ -template<> -MSHADOW_XINLINE double MaxValue(void) { - return DBL_MAX; -} -/*! \brief maximum value of half */ -template<> -MSHADOW_XINLINE half::half_t MaxValue(void) { - return MSHADOW_HALF_MAX; -} -/*! \brief maximum value of uint8_t */ -template<> -MSHADOW_XINLINE uint8_t MaxValue(void) { - return UCHAR_MAX; -} -/*! \brief maximum value of int8_t */ -template<> -MSHADOW_XINLINE int8_t MaxValue(void) { - return SCHAR_MAX; -} -/*! \brief maximum value of int32_t */ -template<> -MSHADOW_XINLINE int MaxValue(void) { - return INT_MAX; -} -/*! \brief maximum value of int64_t */ -template<> -MSHADOW_XINLINE int64_t MaxValue(void) { - return LLONG_MAX; -} -} // namespace limits - -/*! \brief sum reducer */ -struct sum { - /*! \brief do reduction into dst */ - template - MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) - dst += src; - } - /*! \brief do stable reduction into dst */ - template - MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*) - DType y = src - residual; - DType t = dst + y; - residual = (t - dst) - y; - dst = t; - } - /*! \brief combine the results of two reducers */ - template - MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) - Reduce(dst_val, src_val); - } - /*! \brief combine the results of two reducers */ - template - MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*) - DType t1 = dst_val + src_val; - DType e = t1 - dst_val; - DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual; - dst_val = t1 + t2; - dst_residual = t2 - (dst_val - t1); - } - /*! \brief finalize reduction */ - template - MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*) - /*! \brief finalize reduction */ - template - MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*) - /*! - *\brief calculate gradient of redres with respect to redsrc, - * redres: reduced result, redsrc: one of reduction element - */ - template - MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { - return 1; - } - /*! - *\brief set the initial value during reduction - */ - template - MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) - initv = 0; - } - /*! - *\brief set the initial value during reduction - */ - template - MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*) - SetInitValue(initv); - residual = 0; - } -}; -/*! \brief maximum reducer */ -struct maximum { - /*! \brief do reduction into dst */ - template - MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) - using namespace std; -#ifdef __CUDACC__ - dst = ::max(dst, src); -#else - dst = max(dst, src); -#endif // __CUDACC__ - } - /*! \brief do reduction into dst */ - template - MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType &none) { // NOLINT(*) - Reduce(dst, src); - } - /*! \brief combine the results of two reducers */ - template - MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) - Reduce(dst_val, src_val); - } - /*! \brief combine the results of two reducers */ - template - MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*) - Reduce(dst_val, src_val); - } - /*! \brief finalize reduction */ - template - MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*) - /*! \brief finalize reduction */ - template - MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*) - /*! - * \brief calculate gradient of redres with respect to redsrc, - * redres: reduced result, redsrc: one of reduction element - */ - template - MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { - return redres == redsrc ? 1: 0; - } - /*! - *\brief set the initial value during reduction - */ - template - MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) - initv = limits::MinValue(); - } - /*! - *\brief set the initial value during reduction - */ - template - MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*) - SetInitValue(initv); - } -}; -/*! \brief minimum reducer */ -struct minimum { - /*! \brief do reduction into dst */ - template - MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*) - using namespace std; -#ifdef __CUDACC__ - dst = ::min(dst, src); -#else - dst = min(dst, src); -#endif // __CUDACC__ - } - /*! \brief do reduction into dst */ - template - MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType &none) { // NOLINT(*) - Reduce(dst, src); - } - /*! \brief combine the results of two reducers */ - template - MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*) - Reduce(dst_val, src_val); - } - /*! \brief combine the results of two reducers */ - template - MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*) - Reduce(dst_val, src_val); - } - /*! \brief finalize reduction */ - template - MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*) - /*! \brief finalize reduction */ - template - MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*) - /*! - * \brief calculate gradient of redres with respect to redsrc, - * redres: reduced result, redsrc: one of reduction element - */ - template - MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) { - return redres == redsrc ? 1: 0; - } - /*! - *\brief set the initial value during reduction - */ - template - MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*) - initv = limits::MaxValue(); - } - /*! - *\brief set the initial value during reduction - */ - template - MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*) - SetInitValue(initv); - } -}; -} // namespace red - -#define MSHADOW_TYPE_SWITCH(type, DType, ...) \ - switch (type) { \ - case mshadow::kFloat32: \ - { \ - typedef float DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat64: \ - { \ - typedef double DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat16: \ - { \ - typedef mshadow::half::half_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kUint8: \ - { \ - typedef uint8_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kInt8: \ - { \ - typedef int8_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kInt32: \ - { \ - typedef int32_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kInt64: \ - { \ - typedef int64_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - default: \ - LOG(FATAL) << "Unknown type enum " << type; \ - } - -#define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...) \ - switch (type) { \ - case mshadow::kFloat32: \ - { \ - typedef float DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat64: \ - { \ - typedef double DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat16: \ - { \ - typedef mshadow::half::half2_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kUint8: \ - { \ - typedef uint8_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kInt32: \ - { \ - typedef int32_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kInt64: \ - { \ - typedef int64_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - default: \ - LOG(FATAL) << "Unknown type enum " << type; \ - } - -#define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...) \ - switch (type) { \ - case mshadow::kFloat32: \ - { \ - typedef float DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat64: \ - { \ - typedef double DType; \ - {__VA_ARGS__} \ - } \ - break; \ - default: \ - LOG(FATAL) << "This operation only supports " \ - "32-bit and 64-bit floating point"; \ - } - -#define MSHADOW_REAL_TYPE_SWITCH(type, DType, ...) \ - switch (type) { \ - case mshadow::kFloat32: \ - { \ - typedef float DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat64: \ - { \ - typedef double DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat16: \ - { \ - typedef mshadow::half::half_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kUint8: \ - LOG(FATAL) << "This operation only support " \ - "floating point types not uint8"; \ - break; \ - case mshadow::kInt8: \ - LOG(FATAL) << "This operation only support " \ - "floating point types not int8"; \ - break; \ - case mshadow::kInt32: \ - LOG(FATAL) << "This operation only support " \ - "floating point types, not int32";\ - break; \ - case mshadow::kInt64: \ - LOG(FATAL) << "This operation only support " \ - "floating point types, not int64";\ - break; \ - default: \ - LOG(FATAL) << "Unknown type enum " << type; \ - } - -#define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \ - switch (type$) { \ - case mshadow::kFloat32: \ - { \ - typedef float DType$; \ - typedef float DLargeType$; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat64: \ - { \ - typedef double DType$; \ - typedef double DLargeType$; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kFloat16: \ - { \ - typedef mshadow::half::half_t DType$; \ - typedef float DLargeType$; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kUint8: \ - LOG(FATAL) << "This operation only support " \ - "floating point types not uint8"; \ - break; \ - case mshadow::kInt8: \ - LOG(FATAL) << "This operation only support " \ - "floating point types not int8"; \ - break; \ - case mshadow::kInt32: \ - LOG(FATAL) << "This operation only support " \ - "floating point types, not int32";\ - break; \ - case mshadow::kInt64: \ - LOG(FATAL) << "This operation only support " \ - "floating point types, not int64";\ - break; \ - default: \ - LOG(FATAL) << "Unknown type enum " << type$; \ - } - -#define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \ - switch (layout) { \ - case mshadow::kNCHW: \ - { \ - const int Layout = kNCHW; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kNHWC: \ - { \ - const int Layout = kNHWC; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kNCDHW: \ - { \ - const int Layout = kNCDHW; \ - {__VA_ARGS__} \ - } \ - break; \ - case mshadow::kNDHWC: \ - { \ - const int Layout = kNDHWC; \ - {__VA_ARGS__} \ - } \ - break; \ - default: \ - LOG(FATAL) << "Unknown layout enum " << layout; \ - } - -/*! - * \brief Only supports int64 index type for aux_data - * in NDArray class fow now. - */ -#define MSHADOW_IDX_TYPE_SWITCH(type, DType, ...) \ - switch (type) { \ - case mshadow::kInt64: \ - { \ - typedef int64_t DType; \ - {__VA_ARGS__} \ - } \ - break; \ - default: \ - LOG(FATAL) << "Unknown type enum " << type; \ - } - -/*! \brief get data type size from type enum */ -inline size_t mshadow_sizeof(int type) { - int size = 0; - MSHADOW_TYPE_SWITCH(type, DType, size = sizeof(DType);); - return size; -} - -} // namespace mshadow -#endif // MSHADOW_BASE_H_ diff --git a/include/mshadow/cuda/reduce.cuh b/include/mshadow/cuda/reduce.cuh deleted file mode 100644 index 921d5ad5e0c0..000000000000 --- a/include/mshadow/cuda/reduce.cuh +++ /dev/null @@ -1,120 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file reduce.cuh - * \brief helper functions to do reduction - * \author Tianqi Chen - */ -#ifndef MSHADOW_CUDA_REDUCE_CUH_ -#define MSHADOW_CUDA_REDUCE_CUH_ - -namespace mshadow { -namespace cuda { -/* - * \brief reduce over the dimension x - * \tparam Reducer reducer - * \tparam x_bits dimension = 1< -inline __device__ void Reduce1D(volatile DType buf[1 << x_bits]); -/* - * \brief reduce over the dimension x - * \tparam Reducer reducer - * \tparam xmax_bits maximum size of buffer - * \tparam DType content data type - * \param xsize size of x dimension, not sure if aligned - */ -template -inline __device__ void -Reduce1DNotAlign(volatile DType buf[1 << xmax_bits], int xsize); -// ===============================================x=== -// implementations afterwards, -// no need to read if only use the functions -// -------------------------------------------------- -#ifdef __DEVICE_EMULATION__ -#define __syncwarp() __syncthreads() -#else -#if CUDA_VERSION < 9000 -#define __syncwarp() -#endif -#endif - -template -inline __device__ void ReduceX(volatile DType buf[], int tid) { - if (x_bits >= 10) { - if (tid < 512) Reducer::Reduce(buf[tid] , buf[tid + 512]); - __syncthreads(); - } - if (x_bits >= 9) { - if (tid < 256) Reducer::Reduce(buf[tid] , buf[tid + 256]); - __syncthreads(); - } - if (x_bits >= 8) { - if (tid < 128) Reducer::Reduce(buf[tid] , buf[tid + 128]); - __syncthreads(); - } - if (x_bits >= 7) { - if (tid < 64) Reducer::Reduce(buf[tid] , buf[tid + 64]); - __syncthreads(); - } - if (x_bits >= 6) { - if (tid < 32) Reducer::Reduce(buf[tid] , buf[tid + 32]); - __syncthreads(); - } - // in warp optimization - if (x_bits >= 5) { - if (tid < 16) Reducer::Reduce(buf[tid] , buf[tid + 16]); -#if MSHADOW_OLD_CUDA - __syncthreads(); -#else - __syncwarp(); -#endif - } - if (x_bits >= 4) { - if (tid < 8) Reducer::Reduce(buf[tid] , buf[tid + 8]); - __syncwarp(); - } - if (x_bits >= 3) { - if (tid < 4) Reducer::Reduce(buf[tid] , buf[tid + 4]); - __syncwarp(); - } - if (x_bits >= 2) { - if (tid < 2) Reducer::Reduce(buf[tid] , buf[tid + 2]); - __syncwarp(); - } - if (x_bits >= 1) { - if (tid < 1) Reducer::Reduce(buf[tid] , buf[tid + 1]); - __syncwarp(); - } -} -template -inline __device__ void Reduce1D(volatile DType buf[1 << x_bits]) { - ReduceX(buf, threadIdx.x); -} -// reduce with a upper bound -#define __RD_NON_ALIGN(els, x_bits) \ - els \ - if (xmax_bits >= x_bits && x_size >= (1 << x_bits)) { \ - if (tid < (1 << x_bits) && tid + (1 << x_bits) < x_size) { \ - Reducer::Reduce(buf[tid] , buf[tid + (1 << x_bits)]); \ - } \ - __syncthreads(); \ - ReduceX(buf, tid); \ - } \ - -template -inline __device__ void Reduce1DNotAlign(volatile DType buf[], int x_size) { - int tid = threadIdx.x; - __RD_NON_ALIGN(, 8) - __RD_NON_ALIGN(else, 7) - __RD_NON_ALIGN(else, 6) - __RD_NON_ALIGN(else, 5) - __RD_NON_ALIGN(else, 4) - __RD_NON_ALIGN(else, 3) - __RD_NON_ALIGN(else, 2) - __RD_NON_ALIGN(else, 1) -} -} // namespace cuda -} // namespace mshadow -#endif // MSHADOW_CUDA_REDUCE_CUH_ - diff --git a/include/mshadow/cuda/tensor_gpu-inl.cuh b/include/mshadow/cuda/tensor_gpu-inl.cuh deleted file mode 100755 index 72e4b7eb9ee9..000000000000 --- a/include/mshadow/cuda/tensor_gpu-inl.cuh +++ /dev/null @@ -1,828 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file tensor_gpu-inl.cuh - * \brief implementation of GPU code using CUDA - * \author Bing Xu, Tianqi Chen - */ -#ifndef MSHADOW_CUDA_TENSOR_GPU_INL_CUH_ -#define MSHADOW_CUDA_TENSOR_GPU_INL_CUH_ -#include -#include -#if CUDA_VERSION >= 7000 -#include -#endif -#include "../tensor.h" -#include "./reduce.cuh" -#define MSHADOW_CUDA_POST_KERNEL_CHECK(x) \ - /* Code block avoids redefinition of cudaError_t err */ \ - do { \ - cudaError err = cudaPeekAtLastError(); \ - CHECK_EQ(err, cudaSuccess) << "Name: " << #x << " ErrStr:" << cudaGetErrorString(err); \ - } while (0) -namespace mshadow { -namespace cuda { -/* load unit for memory access, if CUDAARCH not defined, this is advanced nvcc */ -#if MSHADOW_OLD_CUDA -const int kMemUnitBits = 4; -const int kMaxThreadsPerBlock = 512; -#else -const int kMemUnitBits = 5; -const int kMaxThreadsPerBlock = 1024; -#endif -/*! \brief number of units that can do synchronized update, half warp size */ -const int kMemUnit = 1 << kMemUnitBits; -/*! \brief mask that could be helpful sometime */ -const int kMemUnitMask = kMemUnit - 1; -/*! \brief suggested thread number(logscale) for mapping kernel */ -const int kBaseThreadBits = 8; -/*! \brief suggested thread number for mapping kernel */ -const int kBaseThreadNum = 1 << kBaseThreadBits; -/*! \brief maximum value of grid */ -const int kMaxGridNum = 65535; -/*! \brief maximum value of grid within each dimension */ -const int kMaxGridDim = 65535; -/*! \brief suggested grid number for mapping kernel */ -const int kBaseGridNum = 1024; -/*! \brief get align stride for given size in x dimension */ -inline index_t GetAlignStride(index_t xsize) { - if (xsize >= MSHADOW_MIN_PAD_RATIO * 32) { - return ((xsize + kMemUnit - 1) >> kMemUnitBits) << kMemUnitBits; - } else { - // if originally space is not aligned, no necessary to to alligned thread allocation - return xsize; - } -} -inline void CheckLaunchParam(dim3 dimGrid, dim3 dimBlock, const char *estr = "") { - if (dimBlock.x * dimBlock.y * dimBlock.z > static_cast(kMaxThreadsPerBlock) || - dimGrid.x > kMaxGridDim || dimGrid.y > kMaxGridDim) { - LOG(FATAL) << "too large launch parameter: " - << estr << "[" - << dimGrid.x << "," - << dimGrid.y << "], [" - << dimBlock.x << "," - << dimBlock.y << "," - << dimBlock.z << "]"; - } -} -template -__device__ void MapPlanProc(DstPlan dst, index_t xstride, - Shape<2> dshape, const Plan plan, int block_idx) { - const index_t tid = (block_idx << block_dim_bits) + threadIdx.x; - const int y = tid / xstride; - const int x = tid % xstride; - if (y < dshape[0] && x < dshape[1]) { - Saver::Save(dst.REval(y, x), plan.Eval(y, x)); - } -} -template -__global__ void MapPlanKernel(DstPlan dst, index_t xstride, - Shape<2> dshape, const Plan plan) { - MapPlanProc - (dst, xstride, dshape, plan, blockIdx.x); -} -template -__global__ void MapPlanLargeKernel(DstPlan dst, index_t xstride, - Shape<2> dshape, const Plan plan, int repeat) { - for (int i = 0; i < repeat; ++i) { - MapPlanProc - (dst, xstride, dshape, plan, blockIdx.x + i * grid_size); - } -} - -template -inline void MapPlan(expr::Plan dst, - const expr::Plan &plan, - Shape<2> dshape, - cudaStream_t stream) { - const index_t xstride = GetAlignStride(dshape[1]); - const int num_block = (dshape[0] * xstride + kBaseThreadNum-1) / kBaseThreadNum; - dim3 dimBlock(kBaseThreadNum, 1, 1); - - if (num_block < kMaxGridNum) { - dim3 dimGrid(num_block, 1, 1); - MapPlanKernel, - expr::Plan > - <<>>(dst, xstride, dshape, plan); - MSHADOW_CUDA_POST_KERNEL_CHECK(MapPlanKernel); - } else { - int repeat = (num_block + kBaseGridNum-1) / kBaseGridNum; - dim3 dimGrid(kBaseGridNum, 1 , 1); - MapPlanLargeKernel, - expr::Plan > - <<>>(dst, xstride, dshape, plan, repeat); - MSHADOW_CUDA_POST_KERNEL_CHECK(MapPlanLargeKernel); - } -} - -template -__global__ void -__launch_bounds__(kMemUnit*kMemUnit, 1) -MapRedKeepLowestKernel(DstPlan dst, Plan plan, - DType scale, Shape<2> eshape) { - const unsigned warp_size = 1 << warp_bits; - const unsigned x = (blockIdx.x << warp_bits) + threadIdx.x; - // to avoid bank conflict - __shared__ DType s_res[warp_size][warp_size + 1]; - // note: reverse store [y][x], so that we can reduce over threadIdx.x, use warp optimization - if (threadIdx.y < eshape[0] && x < eshape[1]) { - s_res[threadIdx.x][threadIdx.y] = plan.Eval(threadIdx.y, x); - } - for (unsigned y = warp_size; y < eshape[0]; y += warp_size) { - if (threadIdx.y + y < eshape[0] && x < eshape[1]) { - Reducer::Reduce(s_res[threadIdx.x][threadIdx.y], plan.Eval(threadIdx.y + y, x)); - } - } - __syncthreads(); - if (eshape[0] >= warp_size) { - Reduce1D(s_res[threadIdx.y]); - } else { - Reduce1DNotAlign(s_res[threadIdx.y], eshape[0]); - } - __syncthreads(); - - if (threadIdx.y == 0 && x < eshape[1]) { - Saver::Save(dst.REval(0, x), DType(s_res[threadIdx.x][0] * scale)); - } -} - -template -inline void MapReduceKeepLowest(expr::Plan dst, - const expr::Plan &plan, - DType scale, Shape<2> eshape, - cudaStream_t stream) { - dim3 dimBlock(kMemUnit, kMemUnit); - dim3 dimGrid((eshape[1] + kMemUnit - 1) >> kMemUnitBits); - CheckLaunchParam(dimGrid, dimBlock, "MapRedKeepLowestKernel"); - MapRedKeepLowestKernel, - expr::Plan > - <<>>(dst, plan, scale, eshape); - MSHADOW_CUDA_POST_KERNEL_CHECK(MapRedKeepLowestKernel); -} - -template -__global__ void MapReduceKeepDim1Kernel(DstPlan dst, Plan plan, DType scale, Shape<4> pshape) { - const int block_size = 1 << block_dim_bits; - __shared__ DType s_rec[block_size]; - const int c = blockIdx.x + blockIdx.y * gridDim.x; - const index_t tot = pshape[3] * pshape[2] * pshape[0]; - - if (c < pshape[1]) { - DType res; Reducer::SetInitValue(res); - for (index_t i_offset = 0; i_offset < tot; i_offset += block_size) { - index_t i = i_offset + threadIdx.x; - if (i< tot) { - const index_t x = i % pshape[3]; - i /= pshape[3]; - const index_t y = i % pshape[2]; - const index_t n = i / pshape[2]; - Reducer::Reduce(res, plan.Eval((n * pshape[1] + c) * pshape[2] + y, x)); - } - } - s_rec[threadIdx.x] = res; - __syncthreads(); - Reduce1D(s_rec); - if (threadIdx.x == 0) { - Saver::Save(dst.REval(0, c), DType(s_rec[0] * scale)); - } - } -} - -template -inline void MapReduceKeepDim1(expr::Plan dst, - const expr::Plan &plan, - DType scale, Shape<4> pshape, - cudaStream_t stream) { - dim3 dimBlock(kBaseThreadNum); - const int grid_dim_x = (pshape[1] > kMaxGridNum) ? kMaxGridNum : pshape[1]; - const int grid_dim_y = (pshape[1] > kMaxGridNum) ? (pshape[1] + kMaxGridNum - 1) / kMaxGridNum - : 1; - dim3 dimGrid(grid_dim_x, grid_dim_y); - CheckLaunchParam(dimGrid, dimBlock, "MapReduceKeepDim1"); - MapReduceKeepDim1Kernel, - expr::Plan > - <<>>(dst, plan, scale, pshape); - MSHADOW_CUDA_POST_KERNEL_CHECK(MapReduceKeepDim1Kernel); -} - -template -__global__ void GetBatchedViewKernel(DType **dst, DType *src, int num, int stride) { - const int x_size = 1 << x_bits; - const int start = threadIdx.x; - // Copy the addresses of src to dst every stride steps - for (int i = start; i < num; i += x_size) { - dst[i] = src + i * stride; - } -} - -template -inline void GetBatchedView(DType **dst, DType *src, int num, int stride, - Stream *stream) { - cudaStream_t stream_ = Stream::GetStream(stream); - dim3 dimBlock(kBaseThreadNum); - dim3 dimGrid(1); - CheckLaunchParam(dimGrid, dimBlock, "GetBatchedView"); - GetBatchedViewKernel - <<>> (dst, src, num, stride); - MSHADOW_CUDA_POST_KERNEL_CHECK(GetBatchedViewKernel); -} - -template -__global__ void SoftmaxGradKernel(DstPlan dst, SrcPlan1 src, SrcPlan2 label, index_t xmax) { - const unsigned x_size = 1 << x_bits; - const int y = blockIdx.x; - const int k = static_cast(label.Eval(0, y)); - - // calculate normalizer, with writeback - for (unsigned x = 0; x < xmax; x += x_size) { - const unsigned xindex = x + threadIdx.x; - if (xindex < xmax) { - if (xindex == k) { - dst.REval(y, xindex) = src.Eval(y, xindex) - 1.0f; - } else { - dst.REval(y, xindex) = src.Eval(y, xindex); - } - } - } -} - -template -__global__ void SmoothSoftmaxGradKernel(DstPlan dst, SrcPlan1 src, SrcPlan2 label, index_t xmax, - float alpha) { - const unsigned x_size = 1 << x_bits; - const int y = blockIdx.x; - const int k = static_cast(label.Eval(0, y)); - // xmax is the number of classes in our distribution - const float smooth_grad = (alpha / (xmax - 1)); - - // calculate normalizer, with writeback - for (unsigned x = 0; x < xmax; x += x_size) { - const unsigned xindex = x + threadIdx.x; - if (xindex < xmax) { - if (xindex == k) { - dst.REval(y, xindex) = src.Eval(y, xindex) - 1.0f + alpha; - } else { - dst.REval(y, xindex) = src.Eval(y, xindex) - smooth_grad; - } - } - } -} - -template -__global__ void SoftmaxGradKernel(DstPlan dst, SrcPlan1 src, SrcPlan2 label, index_t xmax, - DType ignore_label) { - const unsigned x_size = 1 << x_bits; - const int y = blockIdx.x; - const int k = static_cast(label.Eval(0, y)); - - // calculate normalizer, with writeback - for (unsigned x = 0; x < xmax; x += x_size) { - const unsigned xindex = x + threadIdx.x; - if (xindex < xmax) { - if (static_cast(ignore_label) == k) { - dst.REval(y, xindex) = 0.0f; - } else { - if (xindex == k) { - dst.REval(y, xindex) = src.Eval(y, xindex) - 1.0f; - } else { - dst.REval(y, xindex) = src.Eval(y, xindex); - } - } - } - } -} - -template -__global__ void SmoothSoftmaxGradKernel(DstPlan dst, SrcPlan1 src, SrcPlan2 label, index_t xmax, - DType ignore_label, float alpha) { - const unsigned x_size = 1 << x_bits; - const int y = blockIdx.x; - const int k = static_cast(label.Eval(0, y)); - // xmax is the number of classes in our distribution - const float smooth_grad = (alpha / (xmax - 1)); - - // calculate normalizer, with writeback - for (unsigned x = 0; x < xmax; x += x_size) { - const unsigned xindex = x + threadIdx.x; - if (xindex < xmax) { - if (static_cast(ignore_label) == k) { - dst.REval(y, xindex) = 0.0f; - } else { - if (xindex == k) { - dst.REval(y, xindex) = src.Eval(y, xindex) - 1.0f + alpha; - } else { - dst.REval(y, xindex) = src.Eval(y, xindex) - smooth_grad; - } - } - } - } -} - -template -__global__ void SoftmaxKernel(DstPlan dst, SrcPlan src, index_t xmax) { - const unsigned x_size = 1 << x_bits; - const int y = blockIdx.x; - __shared__ DType s_rec[x_size]; - // step 1: get max - if (threadIdx.x < xmax) { - s_rec[threadIdx.x] = src.Eval(y, threadIdx.x); - } - for (unsigned x = x_size; x < xmax; x += x_size) { - if (x + threadIdx.x < xmax) { - DType a = src.Eval(y, x + threadIdx.x); - s_rec[threadIdx.x] = max(a, s_rec[threadIdx.x]); - } - } - __syncthreads(); - if (threadIdx.x >= xmax) { - s_rec[threadIdx.x] = s_rec[0]; - } - __syncthreads(); - Reduce1D(s_rec); - __syncthreads(); - DType smax = s_rec[0]; - __syncthreads(); - s_rec[threadIdx.x] = 0.0f; - __syncthreads(); - - // calculate normalizer, with writeback - for (unsigned x = 0; x < xmax; x += x_size) { - if (x + threadIdx.x < xmax) { - DType p = expf(src.Eval(y, x + threadIdx.x) - smax); - s_rec[threadIdx.x] += p; - // write back first, will fetch later - dst.REval(y, x + threadIdx.x) = p; - } - } - // calculate normalizer - __syncthreads(); - Reduce1D(s_rec); - __syncthreads(); - DType ssum = s_rec[0]; - - for (unsigned x = 0; x < xmax; x += x_size) { - if (x + threadIdx.x < xmax) { - dst.REval(y, x + threadIdx.x) /= ssum; - } - } -} - -template -inline void Softmax(const Tensor &dst, - const Tensor &src) { - dim3 dimBlock(kBaseThreadNum); - dim3 dimGrid(dst.size(0)); - CHECK_EQ(dst.shape_, src.shape_) << "Softmax: shape mismatch"; - CheckLaunchParam(dimGrid, dimBlock, "Softmax"); - cudaStream_t stream = Stream::GetStream(dst.stream_); - SoftmaxKernel - <<>> - (expr::MakePlan(dst), - expr::MakePlan(src), - dst.size(1)); - MSHADOW_CUDA_POST_KERNEL_CHECK(SoftmaxKernel); -} - -template -inline void SoftmaxGrad(const Tensor &dst, - const Tensor &src, - const Tensor &label) { - dim3 dimBlock(kBaseThreadNum); - dim3 dimGrid(dst.size(0)); - CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch"; - CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch"; - CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad"); - cudaStream_t stream = Stream::GetStream(dst.stream_); - SoftmaxGradKernel - <<>> - (expr::MakePlan(dst), - expr::MakePlan(src), - expr::MakePlan(label), - dst.size(1)); - MSHADOW_CUDA_POST_KERNEL_CHECK(SoftmaxGradKernel); -} - -template -inline void SmoothSoftmaxGrad(const Tensor &dst, - const Tensor &src, - const Tensor &label, - const float alpha) { - dim3 dimBlock(kBaseThreadNum); - dim3 dimGrid(dst.size(0)); - CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch"; - CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch"; - CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad"); - cudaStream_t stream = Stream::GetStream(dst.stream_); - SmoothSoftmaxGradKernel - <<>> - (expr::MakePlan(dst), - expr::MakePlan(src), - expr::MakePlan(label), - dst.size(1), - alpha); - MSHADOW_CUDA_POST_KERNEL_CHECK(SoftmaxGradKernel); -} - -template -inline void SoftmaxGrad(const Tensor &dst, - const Tensor &src, - const Tensor &label, - const DType &ignore_label) { - dim3 dimBlock(kBaseThreadNum); - dim3 dimGrid(dst.size(0)); - CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch"; - CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch"; - CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad"); - cudaStream_t stream = Stream::GetStream(dst.stream_); - SoftmaxGradKernel - <<>> - (expr::MakePlan(dst), - expr::MakePlan(src), - expr::MakePlan(label), - dst.size(1), - ignore_label); - MSHADOW_CUDA_POST_KERNEL_CHECK(SoftmaxGradKernel); -} - -template -inline void SmoothSoftmaxGrad(const Tensor &dst, - const Tensor &src, - const Tensor &label, - const DType &ignore_label, - const float alpha) { - dim3 dimBlock(kBaseThreadNum); - dim3 dimGrid(dst.size(0)); - CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch"; - CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch"; - CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad"); - cudaStream_t stream = Stream::GetStream(dst.stream_); - SmoothSoftmaxGradKernel - <<>> - (expr::MakePlan(dst), - expr::MakePlan(src), - expr::MakePlan(label), - dst.size(1), - ignore_label, - alpha); - MSHADOW_CUDA_POST_KERNEL_CHECK(SoftmaxGradKernel); -} - -template -__global__ void Softmax3DGradKernel(Tensor dst, - const Tensor src, - const Tensor label) { - const index_t xmax = dst.size(1); - const index_t nmax = dst.size(2); - const unsigned n_size = 1 << n_bits; - const int y = blockIdx.x; - const int n = threadIdx.x; - - for (index_t n_index = n; n_index < nmax; n_index += n_size) { - const int k = static_cast(label[y][n_index]); - for (index_t i = 0; i < xmax; ++i) { - if (i == k) { - dst[y][i][n_index] = src[y][i][n_index] - 1.0f; - } else { - dst[y][i][n_index] = src[y][i][n_index]; - } - } - } -} - -template -__global__ void Softmax3DGradKernel(Tensor dst, - const Tensor src, - const Tensor label, - DType ignore_label) { - const index_t xmax = dst.size(1); - const index_t nmax = dst.size(2); - const unsigned n_size = 1 << n_bits; - const int y = blockIdx.x; - const int n = threadIdx.x; - for (index_t n_index = n; n_index < nmax; n_index += n_size) { - int k = static_cast(label[y][n_index]); - if (k == static_cast(ignore_label)) { - for (index_t i = 0; i < xmax; ++i) { - dst[y][i][n_index] = 0.0f; - } - } else { - for (index_t i = 0; i < xmax; ++i) { - if (i == k) { - dst[y][i][n_index] = src[y][i][n_index] - 1.0f; - } else { - dst[y][i][n_index] = src[y][i][n_index]; - } - } - } - } -} - -template -__global__ void Softmax3DKernel(Tensor dst, - const Tensor src) { - const index_t xmax = dst.size(1); - const index_t nmax = dst.size(2); - const unsigned n_size = 1 << n_bits; - const int y = blockIdx.x; - const int n = threadIdx.x; - - for (index_t n_index = n; n_index < nmax; n_index += n_size) { - DType smax = src[y][0][n_index]; - for (index_t i = 1; i < xmax; ++i) { - smax = max(smax, src[y][i][n_index]); // NOLINT(*) - } - DType ssum = 0.0f; - for (index_t i = 0; i < xmax; ++i) { - DType p = expf(src[y][i][n_index] - smax); - ssum += p; - dst[y][i][n_index] = p; - } - for (index_t i = 0; i < xmax; ++i) { - dst[y][i][n_index] /= ssum; - } - } -} - -template -inline void Softmax(const Tensor &dst, - const Tensor &src) { - dim3 dimBlock(kBaseThreadNum); - dim3 dimGrid(dst.size(0)); - CHECK_EQ(dst.shape_, src.shape_) << "Softmax: shape mismatch"; - CheckLaunchParam(dimGrid, dimBlock, "Softmax"); - cudaStream_t stream = Stream::GetStream(dst.stream_); - Softmax3DKernel<<>>(dst, src); - MSHADOW_CUDA_POST_KERNEL_CHECK(Softmax3DKernel); -} - -template -inline void SoftmaxGrad(const Tensor &dst, - const Tensor &src, - const Tensor &label) { - dim3 dimBlock(kBaseThreadNum); - dim3 dimGrid(dst.size(0)); - CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch"; - CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch"; - CHECK_EQ(dst.size(2), label.size(1)) << "SoftmaxGrad: label shape mismatch"; - CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad"); - cudaStream_t stream = Stream::GetStream(dst.stream_); - Softmax3DGradKernel<<>>(dst, src, label); - MSHADOW_CUDA_POST_KERNEL_CHECK(Softmax3DGradKernel); -} - -template -inline void SoftmaxGrad(const Tensor &dst, - const Tensor &src, - const Tensor &label, - const DType &ignore_label) { - dim3 dimBlock(kBaseThreadNum); - dim3 dimGrid(dst.size(0)); - CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch"; - CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch"; - CHECK_EQ(dst.size(2), label.size(1)) << "SoftmaxGrad: label shape mismatch"; - CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad"); - cudaStream_t stream = Stream::GetStream(dst.stream_); - Softmax3DGradKernel<<>>( - dst, src, label, ignore_label); - MSHADOW_CUDA_POST_KERNEL_CHECK(Softmax3DGradKernel); -} - -template -__global__ void AddTakeGradKernel(DstPlan dst, - SrcPlan1 index, SrcPlan2 src, - index_t ymax, index_t xmax, const int K) { - const unsigned x_size = 1 << x_bits; - const int xindex = blockIdx.x * x_size + threadIdx.x; - __shared__ int ptr; - for (unsigned y = 0; y < ymax; ++y) { - if (threadIdx.x == 0) { - ptr = index.Eval(0, y); - if (ptr <= 0) ptr = 0; - else if (ptr >= K) ptr = K - 1; - } - __syncthreads(); - if (xindex < xmax) { - dst.REval(ptr, xindex) += src.Eval(y, xindex); - } - } -} - -template -__global__ void AddTakeGradLargeBatchKernel(DType* dst, - const IdxType *sorted, const IdxType *index, - const DType *src, - int ymax, int xmax) { - // Based on Torch's Version /~https://github.com/torch/cunn/blob/master/lib/THCUNN/LookupTable.cu - // Each warp is responsible for an input into the LookupTable. - // If the preceeding input has the same as this input, then the warp - // exits immediately. The warp also processes subsequent inputs with the - // same value. - // - // Input Warp - // 1 - // 1 ( exits without doing any work) - // 5 - // 8 - // Also, all warp will loop for SZ times to increase the throughput. - - const int warp_size = 1 << warp_bits; - int idx = blockIdx.x * blockDim.y + threadIdx.y; - - if (idx < ymax - && (idx == 0 || sorted[idx] != sorted[idx - 1])) { - do { - const int start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ; - const int dst_row = static_cast(sorted[idx]) * xmax; - const int src_row = static_cast(index[idx]) * xmax; - float grad_out[SZ]; - float grad_weight[SZ]; - #pragma unroll - for (int ii = 0; ii < SZ; ii++) { - int feature_dim = start_feature + ii * warp_size; - if (feature_dim < xmax) { - grad_out[ii] = src[src_row + feature_dim]; - grad_weight[ii] = dst[dst_row + feature_dim]; - } - } - - #pragma unroll - for (int ii = 0; ii < SZ; ii++) { - grad_weight[ii] += grad_out[ii]; - } - - #pragma unroll - for (int ii = 0; ii < SZ; ii++) { - int feature_dim = start_feature + ii * warp_size; - if (feature_dim < xmax) { - dst[dst_row + feature_dim] = grad_weight[ii]; - } - } - idx++; - } while (idx < ymax && (sorted[idx] == sorted[idx - 1])); - } -} - -template -inline void AddTakeGrad(Tensor dst, - const Tensor& index, - const Tensor &src) { - CHECK_EQ(dst.CheckContiguous(), true); - CHECK_EQ(index.CheckContiguous(), true); - CHECK_EQ(src.CheckContiguous(), true); - const int kUnitBits = kMemUnitBits + 1; - dim3 dimBlock(1 << kUnitBits); - dim3 dimGrid((dst.size(1) + (1 << kUnitBits) - 1) >> kUnitBits); - - CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGrad: shape mismatch"; - CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGrad: shape mismatch"; - CheckLaunchParam(dimGrid, dimBlock, "AddTakeGrad"); - cudaStream_t stream = Stream::GetStream(dst.stream_); - const int K = dst.shape_[0]; - - AddTakeGradKernel - <<>> - (expr::MakePlan(dst), - expr::MakePlan(index), - expr::MakePlan(src), - src.size(0), - src.size(1), K); - MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradKernel); -} - -template -inline void AddTakeGradLargeBatch(Tensor dst, - const Tensor& sorted, - const Tensor& index, - const Tensor &src) { - CHECK_EQ(dst.CheckContiguous(), true); - CHECK_EQ(sorted.CheckContiguous(), true); - CHECK_EQ(index.CheckContiguous(), true); - CHECK_EQ(src.CheckContiguous(), true); - const int kWarpBits = kMemUnitBits; - const int SZ = 4; - const int block_dim_x = 1 << kWarpBits; - const int block_dim_y = 4; - const int grid_dim_x = (src.size(0) + block_dim_y - 1) / block_dim_y; - const int grid_dim_y = (src.size(1) + block_dim_x * SZ - 1) / (block_dim_x * SZ); - dim3 dimBlock(block_dim_x, block_dim_y); - dim3 dimGrid(grid_dim_x, grid_dim_y); - - CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGradLargeBatch: shape mismatch"; - CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGradLargeBatch: shape mismatch"; - CheckLaunchParam(dimGrid, dimBlock, "AddTakeGradLargeBatch"); - cudaStream_t stream = Stream::GetStream(dst.stream_); - - AddTakeGradLargeBatchKernel - <<>> - (dst.dptr_, - sorted.dptr_, - index.dptr_, - src.dptr_, - static_cast(src.size(0)), - static_cast(src.size(1))); - MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradLargeBatchKernel); -} - -template -__global__ void IndexFillKernel(DstPlan dst, - const IndexPlan index, - const SrcPlan src, - const int ymax, - const int xmax) { - int bid = blockIdx.y * blockDim.x + blockIdx.x; - int tid = bid * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x; - if (tid < ymax * xmax) { - int i = tid / xmax; - int j = tid % xmax; - int k = static_cast(index.Eval(0, i)); - dst.REval(k, j) = src.Eval(i, j); - } -} - -template -inline void IndexFill(Tensor dst, - const Tensor& index, - const Tensor &src) { - CHECK_EQ(dst.CheckContiguous(), true); - CHECK_EQ(index.CheckContiguous(), true); - CHECK_EQ(src.CheckContiguous(), true); - CHECK_EQ(dst.size(1), src.size(1)) << "IndexFill: shape mismatch"; - CHECK_EQ(index.size(0), src.size(0)) << "IndexFill: shape mismatch"; - const int block_dim_x = 1 << kMemUnitBits; - const int block_dim_y = 1 << kMemUnitBits; - const int block_size = block_dim_x * block_dim_y; - int grid_dim_x = (src.size(0) * src.size(1) + block_size - 1) / block_size; - int grid_dim_y = 1; - while (grid_dim_x > kMaxGridDim) { - grid_dim_x = (grid_dim_x + 1) / 2; - grid_dim_y *= 2; - } - dim3 dimBlock(block_dim_x, block_dim_y); - dim3 dimGrid(grid_dim_x, grid_dim_y); - CheckLaunchParam(dimGrid, dimBlock, "IndexFill"); - cudaStream_t stream = Stream::GetStream(dst.stream_); - - IndexFillKernel - <<>> - (expr::MakePlan(dst), - expr::MakePlan(index), - expr::MakePlan(src), - src.size(0), - src.size(1)); - MSHADOW_CUDA_POST_KERNEL_CHECK(IndexFillKernel); -} - -template -inline void SortByKey(Tensor keys, Tensor values, - bool is_ascend) { - CHECK_EQ(keys.CheckContiguous(), true); - CHECK_EQ(values.CheckContiguous(), true); -#if CUDA_VERSION >= 7000 - cudaStream_t stream = Stream::GetStream(keys.stream_); - thrust::device_ptr key_iter = thrust::device_pointer_cast(keys.dptr_); - thrust::device_ptr value_iter = thrust::device_pointer_cast(values.dptr_); - if (is_ascend) { - thrust::stable_sort_by_key( - thrust::cuda::par.on(stream), - key_iter, key_iter + keys.size(0), value_iter, thrust::less()); // NOLINT(*) - } else { - thrust::stable_sort_by_key( - thrust::cuda::par.on(stream), - key_iter, key_iter + keys.size(0), value_iter, thrust::greater()); // NOLINT(*) - } - MSHADOW_CUDA_POST_KERNEL_CHECK(SortByKey); -#else - LOG(FATAL) << "SortByKey is only supported for CUDA version >=7.0!"; -#endif -} - -template -inline void SortByKey(Tensor keys, Tensor values, - bool is_ascend) { - LOG(FATAL) << "SortByKey for half_t is not implemented!"; -} - -template -inline void SortByKey(Tensor keys, Tensor values, - bool is_ascend) { - LOG(FATAL) << "SortByKey for half_t is not implemented!"; -} - -// break ambiguous template deduction for -inline void SortByKey(Tensor keys, - Tensor values, - bool is_ascend) { - LOG(FATAL) << "SortByKey for half_t is not implemented!"; -} -} // namespace cuda -} // namespace mshadow -#endif // MSHADOW_CUDA_TENSOR_GPU_INL_CUH_ diff --git a/include/mshadow/dot_engine-inl.h b/include/mshadow/dot_engine-inl.h deleted file mode 100644 index 5363974fc941..000000000000 --- a/include/mshadow/dot_engine-inl.h +++ /dev/null @@ -1,906 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file dot_engine-inl.h - * \brief definitions of how Matrix Multiplications can be evaluated - * \author Tianqi Chen - */ -#ifndef MSHADOW_DOT_ENGINE_INL_H_ -#define MSHADOW_DOT_ENGINE_INL_H_ - -#include -#include "./base.h" -#include "./extension/implicit_gemm.h" - -#ifdef __CUDACC__ -#include "./cuda/tensor_gpu-inl.cuh" -#endif // #ifdef __CUDACC__ - -namespace mshadow { - /*! -* \brief CPU/GPU: Get a batched view of the src array. dst[i] = src + i * stride -* \param dst 2D pointer -* \param src 1D pointer -* \param num number of batches -* \param stride size of each batch -* \param stream -*/ -template -inline void GetBatchedView(DType **dst, DType *src, int num, int stride, - Stream *stream); -template -inline void GetBatchedView(DType **dst, DType *src, int num, int stride, - Stream *stream) { - for (int i = 0; i < num; i++) { - dst[i] = src + i * stride; - } -} -#ifdef __CUDACC__ -namespace cuda {}; -template -inline void GetBatchedView(DType **dst, DType *src, int num, int stride, - Stream *stream) { - cuda::GetBatchedView(dst, src, num, stride, stream); -} -#endif // #ifdef __CUDACC__ - -namespace expr { -//--------------------------------------------------------------------- -// Matrix Multiplications, depends on BLAS Engine -//--------------------------------------------------------------------- -template -struct DotEngine { - inline static void Eval(Tensor *p_dst, - const Tensor &lhs, - const Tensor &rhs, - DType scale); -}; -// handles the dot, use CblasColMajor -template -struct BLASEngine { - inline static bool GetT(bool t) { - return t ? true : false; - } - inline static void SetStream(Stream *stream) { - } - inline static void gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, DType alpha, - const DType *A, int lda, const DType *B, int ldb, - DType beta, DType *C, int ldc) { - LOG(FATAL) << "Not implmented!"; - } - inline static void batched_gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, DType alpha, - const DType *A, int lda, const DType *B, int ldb, - DType beta, DType *C, int ldc, int batch_count, - DType **workspace) { - LOG(FATAL) << "Not implmented!"; - } - inline static void gemv(Stream *stream, - bool trans, int m, int n, - DType alpha, const DType *A, int lda, - const DType *X, int incX, - DType beta, DType *Y, int incY) { - LOG(FATAL) << "Not implmented!"; - } - inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - DType alpha, const DType *A, int lda, - const DType *X, int incX, - DType beta, DType *Y, int incY, int batch_count) { - LOG(FATAL) << "Not implmented!"; - } - inline static void ger(Stream *stream, - int m, int n, DType alpha, - const DType *X, int incX, - const DType *Y, int incY, DType *A, int lda) { - LOG(FATAL) << "Not implmented!"; - } - inline static void batched_ger(Stream *stream, - int m, int n, DType alpha, - const DType *X, int incX, - const DType *Y, int incY, DType *A, int lda, int batch_count) { - LOG(FATAL) << "Not implmented!"; - } - inline static void dot(Stream *stream, - int n, - const DType* X, int incX, - const DType* Y, int incY, - DType* ret) { - LOG(FATAL) << "Not implmented!"; - } -}; - -#if MSHADOW_STAND_ALONE -template<> -struct BLASEngine { - inline static bool GetT(bool t) { - return t ? true : false; - } - inline static void SetStream(Stream *stream) { - } - inline static void gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc) { - if (alpha == 1.0f && beta == 0.0f) { - bool transpose_left = transb; - bool transpose_right = transa; - Tensor lhs((float*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*) - Tensor rhs((float*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*) - Tensor dst(C, Shape2(m, n)); - if (!transpose_left && !transpose_right) { - dst = expr::implicit_dot(lhs, rhs); return; - } else if (!transpose_left && transpose_right) { - dst = expr::implicit_dot(lhs, rhs.T()); return; - } else if (transpose_left && !transpose_right) { - dst = expr::implicit_dot(lhs.T(), rhs); return; - } else { - LOG(FATAL) << "Not implmented!"; - } - } else { - LOG(FATAL) << "Not implmented!"; - } - } - inline static void batched_gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc, int batch_count, - float **workspace) { - for (int i = 0; i < batch_count; ++i) { - gemm(stream, transa, transb, m, n, k, alpha, - A + i * m * k, lda, B + i * k * n, ldb, - beta, C + i * m * n, ldc); - } - } - inline static void gemv(Stream *stream, - bool trans, int m, int n, - float alpha, const float *A, int lda, - const float *X, int incX, - float beta, float *Y, int incY) { - LOG(FATAL) << "Not implmented!"; - } - inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - float alpha, const float *A, int lda, - const float *X, int incX, - float beta, float *Y, int incY, int batch_count) { - LOG(FATAL) << "Not implmented!"; - } - inline static void ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda) { - LOG(FATAL) << "Not implmented!"; - } - inline static void batched_ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda, int batch_count) { - LOG(FATAL) << "Not implmented!"; - } - inline static void dot(Stream *stream, - int n, - const float* X, int incX, - const float* Y, int incY, - float* ret) { - LOG(FATAL) << "Not implmented!"; - } -}; - -template<> -struct BLASEngine { - inline static bool GetT(bool t) { - return t ? true : false; - } - inline static void SetStream(Stream *stream) { - } - inline static void gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc) { - if (alpha == 1.0f && beta == 0.0f) { - bool transpose_left = transb; - bool transpose_right = transa; - Tensor lhs((double*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*) - Tensor rhs((double*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*) - Tensor dst(C, Shape2(m, n)); - if (!transpose_left && !transpose_right) { - dst = expr::implicit_dot(lhs, rhs); return; - } else if (!transpose_left && transpose_right) { - dst = expr::implicit_dot(lhs, rhs.T()); return; - } else if (transpose_left && !transpose_right) { - dst = expr::implicit_dot(lhs.T(), rhs); return; - } else { - LOG(FATAL) << "Not implmented!"; - } - } else { - LOG(FATAL) << "Not implmented!"; - } - } - inline static void batched_gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc, int batch_count, - double **workspace) { - for (int i = 0; i < batch_count; ++i) { - gemm(stream, transa, transb, m, n, k, alpha, - A + i * m * k, lda, B + i * k * n, ldb, - beta, C + i * m * n, ldc); - } - } - inline static void gemv(Stream *stream, - bool trans, int m, int n, - double alpha, const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY) { - LOG(FATAL) << "Not implmented!"; - } - inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - double alpha, const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY, int batch_count) { - LOG(FATAL) << "Not implmented!"; - } - inline static void ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda) { - LOG(FATAL) << "Not implmented!"; - } - inline static void batched_ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda, int batch_count) { - LOG(FATAL) << "Not implmented!"; - } - inline static void dot(Stream *stream, - int n, - const double* X, int incX, - const double* Y, int incY, - double* ret) { - LOG(FATAL) << "Not implmented!"; - } -}; - -#elif (MSHADOW_USE_MKL || MSHADOW_USE_CBLAS) // NOLINT(*) -template<> -struct BLASEngine { - inline static CBLAS_TRANSPOSE GetT(bool t) { - return t ? CblasTrans : CblasNoTrans; - } - inline static void SetStream(Stream *stream) { - } - inline static void gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc) { - cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb), - m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); - } - inline static void batched_gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc, int batch_count, - float **workspace) { -#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) - std::vector p_m(batch_count, m); - std::vector p_n(batch_count, n); - std::vector p_k(batch_count, k); - std::vector p_lda(batch_count, lda); - std::vector p_ldb(batch_count, ldb); - std::vector p_ldc(batch_count, ldc); - std::vector p_alpha(batch_count, alpha); - std::vector p_beta(batch_count, beta); - std::vector pp_A; - std::vector pp_B; - std::vector pp_C; - - CBLAS_TRANSPOSE cblas_a_trans = GetT(transa); - CBLAS_TRANSPOSE cblas_b_trans = GetT(transb); - - std::vector p_group_sizeb(batch_count, batch_count); - std::vector p_transa(batch_count, cblas_a_trans); - std::vector p_transb(batch_count, cblas_b_trans); - - auto m_k = m * k; - auto k_n = k * n; - auto m_n = m * n; - - for (int i = 0; i < batch_count; i++) { - pp_A.push_back(A + i * m_k); - pp_B.push_back(B + i * k_n); - pp_C.push_back(C + i * m_n); - } - - cblas_sgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(), - p_m.data(), p_n.data(), p_k.data(), - p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(), - p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(), - 1, p_group_sizeb.data()); -#else - for (int i = 0; i < batch_count; ++i) { - gemm(stream, transa, transb, m, n, k, alpha, - A + i * m * k, lda, B + i * k * n, ldb, - beta, C + i * m * n, ldc); - } -#endif - } - inline static void gemv(Stream *stream, - bool trans, int m, int n, - float alpha, const float *A, int lda, - const float *X, int incX, - float beta, float *Y, int incY) { - cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha, - A, lda, X, incX, beta, Y, incY); - } - inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - float alpha, const float *A, int lda, - const float *X, int incX, - float beta, float *Y, int incY, int batch_count) { - for (int i = 0; i < batch_count; ++i) { - gemv(stream, trans, m, n, alpha, A + i * m * n, lda, - X + i * (trans ? m : n) * incX, incX, - beta, Y + i * (trans ? n : m) * incY, incY); - } - } - inline static void ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda) { - cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); - } - inline static void batched_ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda, int batch_count) { - for (int i = 0; i < batch_count; ++i) { - ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, - A + i * lda * n, lda); - } - } - inline static void dot(Stream *stream, - int n, - const float* X, int incX, - const float* Y, int incY, - float* ret) { - *ret = cblas_sdot(n, X, incX, Y, incY); - } -}; - -template<> -struct BLASEngine { - inline static CBLAS_TRANSPOSE GetT(bool t) { - return t ? CblasTrans : CblasNoTrans; - } - inline static void SetStream(Stream *stream) { - } - inline static void gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc) { - cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb), - m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); - } - inline static void batched_gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc, int batch_count, - double **workspace) { -#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) - std::vector p_m(batch_count, m); - std::vector p_n(batch_count, n); - std::vector p_k(batch_count, k); - std::vector p_lda(batch_count, lda); - std::vector p_ldb(batch_count, ldb); - std::vector p_ldc(batch_count, ldc); - std::vector p_alpha(batch_count, alpha); - std::vector p_beta(batch_count, beta); - std::vector pp_A; - std::vector pp_B; - std::vector pp_C; - - CBLAS_TRANSPOSE cblas_a_trans = GetT(transa); - CBLAS_TRANSPOSE cblas_b_trans = GetT(transb); - - std::vector p_group_sizeb(batch_count, batch_count); - std::vector p_transa(batch_count, cblas_a_trans); - std::vector p_transb(batch_count, cblas_b_trans); - - auto m_k = m * k; - auto k_n = k * n; - auto m_n = m * n; - - for (int i = 0; i < batch_count; i++) { - pp_A.push_back(A + i * m_k); - pp_B.push_back(B + i * k_n); - pp_C.push_back(C + i * m_n); - } - - cblas_dgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(), - p_m.data(), p_n.data(), p_k.data(), - p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(), - p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(), - 1, p_group_sizeb.data()); -#else - for (int i = 0; i < batch_count; ++i) { - gemm(stream, transa, transb, m, n, k, alpha, - A + i * m * k, lda, B + i * k * n, ldb, - beta, C + i * m * n, ldc); - } -#endif - } - inline static void gemv(Stream *stream, - bool trans, int m, int n, double alpha, - const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY) { - cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha, - A, lda, X, incX, beta, Y, incY); - } - inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - double alpha, const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY, int batch_count) { - for (int i = 0; i < batch_count; ++i) { - gemv(stream, trans, m, n, alpha, A + i * m * n, lda, - X + i * (trans ? m : n) * incX, incX, - beta, Y + i * (trans ? n : m) * incY, incY); - } - } - inline static void ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda) { - cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); - } - inline static void batched_ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda, int batch_count) { - for (int i = 0; i < batch_count; ++i) { - ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, - A + i * lda * n, lda); - } - } - inline static void dot(Stream *stream, - int n, - const double* X, int incX, - const double* Y, int incY, - double* ret) { - *ret = cblas_ddot(n, X, incX, Y, incY); - } -}; -#endif // MSHADOW_USE_CBLAS || MSHADOW_USE_MKL || MSHADOW_STAND_ALONE -// CuBLAS redirect code -#if MSHADOW_USE_CUDA -// All CuBLAS goes to here, use legacy API: not threadsafe -template<> -struct BLASEngine { - inline static cublasOperation_t GetT(bool t) { - return t ? CUBLAS_OP_T : CUBLAS_OP_N; - } - inline static void SetStream(Stream *stream) { - cublasStatus_t err = cublasSetStream(Stream::GetBlasHandle(stream), - Stream::GetStream(stream)); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas set stream fail"; - } - inline static void gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, half::half_t alpha, - const half::half_t *A, int lda, - const half::half_t *B, int ldb, half::half_t beta, - half::half_t *C, int ldc) { -#if defined(CUDA_VERSION) && CUDA_VERSION >= 7050 - // Always use pseudo-fp16: fp32 compute with fp16 I/O. - float alpha_f = float(alpha); // NOLINT(*) - float beta_f = float(beta); // NOLINT(*) - #if CUDA_VERSION >= 8000 - cublasStatus_t err = cublasSgemmEx(Stream::GetBlasHandle(stream), - GetT(transa), GetT(transb), m, n, k, &alpha_f, - A, CUDA_R_16F, lda, B, CUDA_R_16F, - ldb, &beta_f, C, CUDA_R_16F, ldc); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail"; - #else - cublasStatus_t err = cublasSgemmEx(Stream::GetBlasHandle(stream), - GetT(transa), GetT(transb), m, n, k, &alpha_f, - A, CUBLAS_DATA_HALF, lda, B, CUBLAS_DATA_HALF, - ldb, &beta_f, C, CUBLAS_DATA_HALF, ldc); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail"; - #endif // CUDA_VERSION >= 8000 -#else - LOG(FATAL) << "Require CUDA version >= 7.5!"; -#endif // defined(CUDA_VERSION) && CUDA_VERSION >= 7050 - } - inline static void batched_gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, half::half_t alpha, - const half::half_t *A, int lda, const half::half_t *B, int ldb, - half::half_t beta, half::half_t *C, int ldc, int batch_count, - half::half_t **workspace) { - for (int i = 0; i < batch_count; ++i) { - gemm(stream, transa, transb, m, n, k, alpha, - A + i * m * k, lda, B + i * k * n, ldb, - beta, C + i * m * n, ldc); - } - } - inline static void gemv(Stream *stream, - bool trans, int m, int n, half::half_t alpha, - const half::half_t *A, int lda, - const half::half_t *X, int incX, half::half_t beta, - half::half_t *Y, int incY) { - LOG(FATAL) << "Not implmented!"; - } - inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - half::half_t alpha, const half::half_t *A, int lda, - const half::half_t *X, int incX, - half::half_t beta, half::half_t *Y, int incY, int batch_count) { - LOG(FATAL) << "Not implmented!"; - } - inline static void ger(Stream *stream, - int m, int n, half::half_t alpha, - const half::half_t *X, int incX, - const half::half_t *Y, int incY, half::half_t *A, int lda) { - LOG(FATAL) << "Not implmented!"; - } - inline static void batched_ger(Stream *stream, - int m, int n, half::half_t alpha, - const half::half_t *X, int incX, const half::half_t *Y, int incY, - half::half_t *A, int lda, int batch_count) { - LOG(FATAL) << "Not implmented!"; - } - inline static void dot(Stream *stream, - int n, - const half::half_t* X, int incX, - const half::half_t* Y, int incY, - half::half_t *ret) { - LOG(FATAL) << "Not implmented!"; - } -}; - -template<> -struct BLASEngine { - inline static cublasOperation_t GetT(bool t) { - return t ? CUBLAS_OP_T : CUBLAS_OP_N; - } - inline static void SetStream(Stream *stream) { - cublasStatus_t err = cublasSetStream(Stream::GetBlasHandle(stream), - Stream::GetStream(stream)); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail"; - } - inline static void gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, - const float *B, int ldb, float beta, - float *C, int ldc) { - cublasStatus_t err = cublasSgemm(Stream::GetBlasHandle(stream), - GetT(transa), GetT(transb), m, n, k, &alpha, - A, lda, B, ldb, &beta, C, ldc); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemm fail"; - } - inline static void batched_gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc, int batch_count, - float **workspace) { -#if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 - // Cast DType* to DType** using workspace as a buffer - bool alloc_workspace = false; - if (workspace == NULL) { - // Allocate the workspace if it's NULL. - // TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe. - cudaMalloc(reinterpret_cast(&workspace), 3 * batch_count * sizeof(float*)); - alloc_workspace = true; - } - GetBatchedView(workspace, const_cast(A), batch_count, m * k, stream); - GetBatchedView(workspace + batch_count, - const_cast(B), batch_count, k * n, stream); - GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream); - cublasStatus_t err = cublasSgemmBatched(Stream::GetBlasHandle(stream), - GetT(transa), GetT(transb), m, n, k, &alpha, - (const float**)workspace, lda, - (const float**)(workspace + batch_count), ldb, - &beta, workspace + 2 * batch_count, ldc, batch_count); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmBatched fail"; - if (alloc_workspace) { - cudaFree(workspace); - } -#elif defined(__CUDACC__) && CUDA_VERSION >= 8000 - cublasStatus_t err = cublasSgemmStridedBatched(Stream::GetBlasHandle(stream), - GetT(transa), GetT(transb), m, n, k, &alpha, - A, lda, m * k, - B, ldb, k * n, - &beta, C, ldc, m * n, - batch_count); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmStridedBatched fail"; -#else - for (int i = 0; i < batch_count; ++i) { - gemm(stream, transa, transb, m, n, k, alpha, - A + i * m * k, lda, B + i * k * n, ldb, - beta, C + i * m * n, ldc); - } -#endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 - } - inline static void gemv(Stream *stream, - bool trans, int m, int n, float alpha, - const float *A, int lda, - const float *X, int incX, float beta, - float *Y, int incY) { - cublasStatus_t err = cublasSgemv(Stream::GetBlasHandle(stream), - GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemv fail"; - } - inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - float alpha, const float *A, int lda, - const float *X, int incX, - float beta, float *Y, int incY, int batch_count) { - for (int i = 0; i < batch_count; ++i) { - gemv(stream, trans, m, n, alpha, A + i * m * n, lda, - X + i * (trans ? m : n) * incX, incX, - beta, Y + i * (trans ? n : m) * incY, incY); - } - } - inline static void ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda) { - cublasStatus_t err = cublasSger(Stream::GetBlasHandle(stream), - m, n, &alpha, X, incX, Y, incY, A, lda); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sger fail"; - } - inline static void batched_ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda, int batch_count) { - for (int i = 0; i < batch_count; ++i) { - ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, - A + i * lda * n, lda); - } - } - inline static void dot(Stream *stream, - int n, - const float* X, int incX, - const float* Y, int incY, - float *ret) { - cublasSetPointerMode(Stream::GetBlasHandle(stream), - CUBLAS_POINTER_MODE_DEVICE); - cublasStatus_t err = cublasSdot(Stream::GetBlasHandle(stream), - n, X, incX, Y, incY, ret); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail"; - cublasSetPointerMode(Stream::GetBlasHandle(stream), - CUBLAS_POINTER_MODE_HOST); - } -}; - -template<> -struct BLASEngine { - inline static cublasOperation_t GetT(bool t) { - return t ? CUBLAS_OP_T : CUBLAS_OP_N; - } - inline static void SetStream(Stream *stream) { - cublasStatus_t err = cublasSetStream(Stream::GetBlasHandle(stream), - Stream::GetStream(stream)); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail"; - } - inline static void gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, - const double *B, int ldb, - double beta, double *C, int ldc) { - cublasStatus_t err = cublasDgemm(Stream::GetBlasHandle(stream), - GetT(transa), GetT(transb), m, n, k, &alpha, - A, lda, B, ldb, &beta, C, ldc); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemm fail"; - } - inline static void batched_gemm(Stream *stream, - bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc, int batch_count, - double **workspace) { -#if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 - // Cast DType* to DType** using workspace as a buffer - bool alloc_workspace = false; - if (workspace == NULL) { - // Allocate the workspace if it's NULL. - // TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe. - cudaMalloc(reinterpret_cast(&workspace), 3 * batch_count * sizeof(double*)); - alloc_workspace = true; - } - GetBatchedView(workspace, const_cast(A), batch_count, m * k, stream); - GetBatchedView(workspace + batch_count, - const_cast(B), batch_count, k * n, stream); - GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream); - cublasStatus_t err = cublasDgemmBatched(Stream::GetBlasHandle(stream), - GetT(transa), GetT(transb), m, n, k, &alpha, - (const double**)workspace, lda, - (const double**)(workspace + batch_count), ldb, - &beta, workspace + 2 * batch_count, ldc, batch_count); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmBatched fail"; - if (alloc_workspace) { - cudaFree(workspace); - } -#elif defined(__CUDACC__) && CUDA_VERSION >= 8000 - cublasStatus_t err = cublasDgemmStridedBatched(Stream::GetBlasHandle(stream), - GetT(transa), GetT(transb), m, n, k, &alpha, - A, lda, m * k, - B, ldb, k * n, - &beta, C, ldc, m * n, - batch_count); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmStridedBatched fail"; -#else - for (int i = 0; i < batch_count; ++i) { - gemm(stream, transa, transb, m, n, k, alpha, - A + i * m * k, lda, B + i * k * n, ldb, - beta, C + i * m * n, ldc); - } -#endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 - } - inline static void gemv(Stream *stream, - bool trans, int m, int n, double alpha, - const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY) { - cublasStatus_t err = cublasDgemv(Stream::GetBlasHandle(stream), - GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemv fail"; - } - inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - double alpha, const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY, int batch_count) { - for (int i = 0; i < batch_count; ++i) { - gemv(stream, trans, m, n, alpha, A + i * m * n, lda, - X + i * (trans ? m : n) * incX, incX, - beta, Y + i * (trans ? n : m) * incY, incY); - } - } - inline static void ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda) { - cublasStatus_t err = cublasDger(Stream::GetBlasHandle(stream), - m, n, &alpha, X, incX, Y, incY, A, lda); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dger fail"; - } - inline static void batched_ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda, int batch_count) { - for (int i = 0; i < batch_count; ++i) { - ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, - A + i * lda * n, lda); - } - } - inline static void dot(Stream *stream, - int n, - const double* X, int incX, - const double* Y, int incY, - double *ret) { - cublasSetPointerMode(Stream::GetBlasHandle(stream), - CUBLAS_POINTER_MODE_DEVICE); - cublasStatus_t err = cublasDdot(Stream::GetBlasHandle(stream), - n, X, incX, Y, incY, ret); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail"; - cublasSetPointerMode(Stream::GetBlasHandle(stream), - CUBLAS_POINTER_MODE_HOST); - } -}; -#endif // MSHADOW_USE_CUDA -// helper function to decide which shape we are in -inline Shape<2> GetShape(const Shape<2> &shape, bool transpose) { - return transpose ? Shape2(shape[1], shape[0]) : shape; -} -// dst = dot(lhs[.T], rhs[.T]) -template -struct DotEngine { - inline static void Eval(Tensor *p_dst, - const Tensor &lhs, - const Tensor &rhs, - DType scale) { - Tensor &dst = *p_dst; -#if MSHADOW_STAND_ALONE - if (xpu::kDevMask == cpu::kDevMask && scale == 1.0f) { - if (!transpose_left && !transpose_right) { - dst = expr::implicit_dot(lhs, rhs); return; - } else if (!transpose_left && transpose_right) { - dst = expr::implicit_dot(lhs, rhs.T()); return; - } else if (transpose_left && !transpose_right) { - dst = expr::implicit_dot(lhs.T(), rhs); return; - } - } -#endif - // set kernel stream - // if there is no stream, crush - BLASEngine::SetStream(dst.stream_); - Shape<2> sleft = GetShape(lhs.shape_, transpose_left); - Shape<2> sright = GetShape(rhs.shape_, transpose_right); - CHECK(dst.size(0) == sleft[0] && dst.size(1) == sright[1] && sleft[1] == sright[0]) - << "dot-gemm: matrix shape mismatch"; - // use column major argument to compatible with most BLAS - BLASEngine::gemm - (dst.stream_, - transpose_right , transpose_left, - transpose_right ? rhs.size(0) : rhs.size(1), - transpose_left ? lhs.size(1) : lhs.size(0), - transpose_right ? rhs.size(1) : rhs.size(0), - DType(scale * SV::AlphaBLAS()), - rhs.dptr_, rhs.stride_, - lhs.dptr_, lhs.stride_, - DType(SV::BetaBLAS()), - dst.dptr_, dst.stride_); - } -}; -template -struct DotEngine { - inline static void Eval(Tensor *p_dst, - const Tensor &lhs, - const Tensor &rhs, - DType scale) { - Tensor &dst = *p_dst; - // set kernel stream - // if there is no stream, crush - BLASEngine::SetStream(dst.stream_); - Shape<2> sright = GetShape(rhs.shape_, transpose_right); - CHECK(dst.size(0) == sright[1] && lhs.size(0) == sright[0]) - << "dot-gemv: matrix shape mismatch" - << "dst: " << dst.shape_ << "\n" - << "lhs: " << lhs.shape_ << "\n" - << "rhs: " << sright << "\n"; - BLASEngine::gemv - (dst.stream_, - transpose_right, - rhs.size(1), rhs.size(0), scale * SV::AlphaBLAS(), - rhs.dptr_, rhs.stride_, - lhs.dptr_, 1, SV::BetaBLAS(), - dst.dptr_, 1); - } -}; -template -struct DotEngine { - inline static void Eval(Tensor *p_dst, - const Tensor &lhs, - const Tensor &rhs, - DType scale) { - Tensor &dst = *p_dst; - // set kernel stream - // if there is no stream, crush - BLASEngine::SetStream(dst.stream_); - CHECK(dst.size(0) == lhs.size(0) && dst.size(1) == rhs.size(0)) - << "dot-ger: matrix shape mismatch" - << "dst: " << dst.shape_ << "\n" - << "lhs: " << lhs.shape_ << "\n" - << "rhs: " << rhs.shape_; - if (SV::BetaBLAS() == 0.0f) { - BLASEngine::ger - (dst.stream_, rhs.size(0), lhs.size(0), scale * SV::AlphaBLAS(), - rhs.dptr_, 1, lhs.dptr_, 1, dst.dptr_, dst.stride_); - } else { - DotEngine::Eval(p_dst, lhs.FlatTo2D(), rhs.FlatTo2D(), scale); - } - } -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_DOT_ENGINE_INL_H_ diff --git a/include/mshadow/expr_engine-inl.h b/include/mshadow/expr_engine-inl.h deleted file mode 100644 index 6421ebcff812..000000000000 --- a/include/mshadow/expr_engine-inl.h +++ /dev/null @@ -1,482 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file expr_engine-inl.h - * \brief definitions of how expressions should be evaluated - * \author Tianqi Chen, Bing Xu - */ -#ifndef MSHADOW_EXPR_ENGINE_INL_H_ -#define MSHADOW_EXPR_ENGINE_INL_H_ -#include -#include -#include "./logging.h" -#include "./expression.h" -#include "./tensor.h" - -namespace mshadow { -namespace expr { -/*! - * \brief a general class that allows extension that makes tensors of some shape - * \tparam SubType type of subclass - * \tparam SrcExp source expression of the MakeTensorExp, the source of operation - * \tparam dim dimension of the expression - * \tparam DType the type of elements - */ -template -struct MakeTensorExp - : public Exp, - DType, type::kChainer> { - /*! \brief the shape of this expression */ - Shape shape_; - /*! \brief true self of subtype */ - inline const SubType& real_self(void) const{ - return *static_cast(this); - } -}; -//---------------------------------------------------------------------- -// This part of code gives plan that can be used to carry out execution -//--------------------------------------------------------------------- -// Declarations of plans -template -class Plan { - public: - /*! - * \brief evaluate the expression at index [y][x] - * to be implemented by SubType, for RValue, the return type will be DType & - */ - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const; -}; -// tensor plan -template -class Plan, DType> { - public: - explicit Plan(const Tensor &t) - : dptr_(t.dptr_), stride_(t.stride_) {} - // for RValue, the return type should be reference - MSHADOW_XINLINE DType &REval(index_t y, index_t x) { - return dptr_[y * stride_ + x]; - } - // const evaluation - MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const { - return dptr_[y * stride_ + x]; - } - - private: - DType *dptr_; - index_t stride_; -}; -// special evaluation case for 1d tensor, no stride -template -class Plan, DType> { - public: - explicit Plan(const Tensor &t) : dptr_(t.dptr_) {} - MSHADOW_XINLINE DType &REval(index_t y, index_t x) { - return dptr_[x]; - } - MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const { - return dptr_[x]; - } - - private: - DType *dptr_; -}; -// scalar -template -class Plan, DType> { - public: - explicit Plan(DType scalar) : scalar_(scalar) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return scalar_; - } - - private: - DType scalar_; -}; -// unary expression -template -class Plan, DstDType> { - public: - explicit Plan(const Plan &src) : src_(src) {} - MSHADOW_XINLINE DstDType Eval(index_t y, index_t x) const { - return DstDType(src_.Eval(y, x)); // NOLINT(*) - } - - private: - Plan src_; -}; - -// ternary expression -template -class Plan, DType> { - public: - explicit Plan(const Plan &item1, const Plan &item2, - const Plan &item3) - : item1_(item1), item2_(item2), item3_(item3) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return OP::Map(item1_.Eval(y, x), item2_.Eval(y, x), item3_.Eval(y, x)); - } - - private: - Plan item1_; - Plan item2_; - Plan item3_; -}; -// binary expression -template -class Plan, DType> { - public: - explicit Plan(const Plan &lhs, const Plan &rhs) - : lhs_(lhs), rhs_(rhs) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return OP::Map(lhs_.Eval(y, x), rhs_.Eval(y, x)); - } - - private: - Plan lhs_; - Plan rhs_; -}; -// unary expression -template -class Plan, DType> { - public: - explicit Plan(const Plan &src) : src_(src) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return OP::Map(src_.Eval(y, x)); - } - - private: - Plan src_; -}; -// remaps map tensor expression to subtype's plan -template -struct Plan, DType> { - public: - Plan(const Plan &src) : src_(src) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return src_.Eval(y, x); - } - - private: - Plan src_; -}; -// tranpsoe -template -class Plan, DType> { - public: - explicit Plan(const Plan &src) : src_(src) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return src_.Eval(x, y); - } - - private: - Plan src_; -}; -//---------------------------------------------------------------------- -// Mappings from expression to plans -//--------------------------------------------------------------------- -template -inline Plan, DType> -MakePlan(const BinaryMapExp &e); - -template -inline Plan, DType> -MakePlan(const TernaryMapExp &e); - -template -inline Plan, DType> MakePlan(const ScalarExp &e) { - return Plan, DType>(e.scalar_); -} - -template -inline Plan, DstDType> -MakePlan(const TypecastExp &e) { - return Plan, DstDType>(MakePlan(e.exp)); -} - -template -inline Plan MakePlan(const RValueExp &e) { - return Plan(e.self()); -} - -template -inline Plan, DType> -MakePlan(const TransposeExp &e) { - return Plan, DType>(MakePlan(e.exp)); -} - -template -inline Plan -MakePlan(const MakeTensorExp &e) { - return Plan(e.real_self()); -} - -template -inline Plan, DType> -MakePlan(const UnaryMapExp &e) { - return Plan, DType>(MakePlan(e.src_)); -} - -template -inline Plan, DType> -MakePlan(const BinaryMapExp &e) { - return Plan, - DType>(MakePlan(e.lhs_), MakePlan(e.rhs_)); -} - -// Ternary -template -inline Plan, DType> -MakePlan(const TernaryMapExp &e) { - return Plan, - DType>(MakePlan(e.item1_), MakePlan(e.item2_), MakePlan(e.item3_)); -} -//---------------------------------------------------------------- -// Static Type inference and Type Checking -//---------------------------------------------------------------- -/*! - * \brief static type inference template, - * used to get the dimension of each expression, - * if ExpInfo::kDim == -1, this means here are mismatch in expression - * if (ExpInfo::kDevMask & cpu::kDevMask) != 0, this means this expression can be assigned to cpu - * \tparam E expression - */ -template -struct ExpInfo { - static const int kDim = -1; - static const int kDevMask = 0; -}; -template -struct ExpInfo< ScalarExp > { - static const int kDim = 0; - static const int kDevMask = 0xffff; -}; -template -struct ExpInfo > { - static const int kDim = ExpInfo::kDim; - static const int kDevMask = ExpInfo::kDevMask; -}; -template -struct ExpInfo > { - static const int kDim = ExpInfo::kDim; - static const int kDevMask = ExpInfo::kDevMask; -}; -template -struct ExpInfo > { - static const int kDim = dim; - static const int kDevMask = Device::kDevMask; -}; -template -struct ExpInfo > { - static const int kDimSrc = ExpInfo::kDim; - static const int kDim = kDimSrc >= 0 ? dim : -1; - static const int kDevMask = ExpInfo::kDevMask; -}; -template -struct ExpInfo > { - static const int kDim = ExpInfo::kDim; - static const int kDevMask = ExpInfo::kDevMask; -}; -template -struct ExpInfo > { - static const int kDimLhs = ExpInfo::kDim; - static const int kDimRhs = ExpInfo::kDim; - static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ?\ - (kDimLhs == 0 ?\ - kDimRhs :\ - ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1; - static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; -}; -template -struct ExpInfo > { - static const int kDimItem1 = ExpInfo::kDim; - static const int kDimItem2 = ExpInfo::kDim; - static const int kDimItem3 = ExpInfo::kDim; - static const int kDim = kDimItem1; - static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask & ExpInfo::kDevMask; -}; - -/*! \brief template to do type check */ -template -struct TypeCheck { - /*! \brief dimension of expression*/ - static const int kExpDim = ExpInfo::kDim; - /*! \brief whether the expression device type matches */ - static const bool kDevPass = (ExpInfo::kDevMask & Device::kDevMask) != 0; - /*! \brief whether the expression can be mapped to expression of dim */ - static const bool kMapPass = (kExpDim == 0 || kExpDim == dim) && kDevPass; - /*! \brief whether the expression can be reduced to expression of dim */ - static const bool kRedPass = (kExpDim > dim) && kDevPass; -}; -/*! \brief used to help static type check*/ -template -struct TypeCheckPass; -// Todo : add static assert using C++11 -template<> -struct TypeCheckPass {}; -template<> -struct TypeCheckPass { - inline static void Error_All_Tensor_in_Exp_Must_Have_Same_Type(void) {} - inline static void Error_TypeCheck_Not_Pass_For_Reduce_Exp(void) {} - inline static void Error_Expression_Does_Not_Meet_Dimension_Req(void) {} -}; - -//---------------------------------------------------------------- -// Runtime Stream Getting -//---------------------------------------------------------------- -template -struct StreamInfo { - inline static Stream *Get(const E &t); -}; -template -struct StreamInfo > { - inline static Stream *Get(const Tensor &t) { - return t.stream_; - } -}; -//---------------------------------------------------------------- -// Runtime Shape Checking -//---------------------------------------------------------------- -/*! - * \brief runtime shape checking template - * get the shape of an expression, report error if shape mismatch - * \tparam dim the dimension of the shape - * \tparam E expression - */ -template -struct ShapeCheck { - inline static Shape Check(const E &t); -}; -template -struct ShapeCheck > { - inline static Shape Check(const ScalarExp &exp) { - // use lowest dimension to mark scalar exp - Shape shape; - for (int i = 0; i < dim; ++i) { - shape[i] = 0; - } - return shape; - } -}; -template -struct ShapeCheck > { - inline static Shape - Check(const TypecastExp &exp) { - return ShapeCheck::Check(exp.exp); - } -}; -template -struct ShapeCheck > { - inline static Shape Check(const TransposeExp &e) { - // swap the lowest two dimensions - Shape s = ShapeCheck::Check(e.exp); - std::swap(s[0], s[1]); - return s; - } -}; -template -struct ShapeCheck > { - inline static Shape Check(const Tensor &t) { - return t.shape_; - } -}; -template -struct ShapeCheck > { - inline static Shape - Check(const MakeTensorExp &t) { - return t.shape_; - } -}; -template -struct ShapeCheck > { - inline static Shape Check(const UnaryMapExp &t) { - Shape s = ShapeCheck::Check(t.src_); - return s; - } -}; - -template -struct ShapeCheck > { - inline static Shape - Check(const BinaryMapExp &t) { - Shape shape1 = ShapeCheck::Check(t.lhs_); - Shape shape2 = ShapeCheck::Check(t.rhs_); - if (shape1[0] == 0) return shape2; - if (shape2[0] == 0) return shape1; - CHECK_EQ(shape1, shape2) << "BinaryMapExp: Shapes of operands are not the same, " << - "Shape1=" << shape1 << ", Shape2=" << shape2; - return shape1; - } -}; - -template -struct ShapeCheck > { - inline static Shape - Check(const TernaryMapExp &t) { - Shape shape1 = ShapeCheck::Check(t.item1_); - Shape shape2 = ShapeCheck::Check(t.item2_); - Shape shape3 = ShapeCheck::Check(t.item3_); - bool same = (shape1 == shape2) && (shape2 == shape3); - CHECK(same) << "TernaryMapExp: Shapes of operands are not the same, " << - "Shape1=" << shape1 << ", Shape2=" << shape2 << ", Shape3=" << shape3; - - return shape1; - } -}; -} // namespace expr - -} // namespace mshadow -// include definition of dot engine -#include "./dot_engine-inl.h" - -namespace mshadow { -namespace expr { -/*! \brief some engine that evaluate complex expression */ -template -struct ExpComplexEngine { - inline static void Eval(RV *dst, const E &exp); -}; -/*! \brief the engine that dispatches simple operations*/ -template -struct ExpEngine { - template - inline static void Eval(RV *dst, - const Exp &exp) { - MapExp(dst, exp); - } - template - inline static void Eval(RV *dst, - const Exp &exp) { - MapExp(dst, exp); - } - template - inline static void Eval(RV *dst, - const Exp &exp) { - MapExp(dst, exp); - } - template - inline static void Eval(RV *dst, - const Exp &exp) { - ExpComplexEngine::Eval(dst->ptrself(), exp.self()); - } -}; -template -struct ExpComplexEngine, - DotExp, - Tensor, - ltrans, rtrans, DType>, - DType> { - inline static void Eval(Tensor *dst, - const DotExp, - Tensor, - ltrans, rtrans, DType> &exp) { - DotEngine::Eval(dst, exp.lhs_, exp.rhs_, exp.scale_); - } -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXPR_ENGINE_INL_H_ diff --git a/include/mshadow/expr_scalar-inl.h b/include/mshadow/expr_scalar-inl.h deleted file mode 100644 index 1ddaba412543..000000000000 --- a/include/mshadow/expr_scalar-inl.h +++ /dev/null @@ -1,165 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file expr_scalar-inl.h - * \brief definitions of operators in expression with respect to scalar - * this file will be included several times, each time with MACRO MSHADOW_SCALAR_ to be different types - * - * DO NOT add pragma once or macro guard - * \author Tianqi Chen, Bing Xu - */ -// macro guard is harmful, used to pass the cpplint -#ifndef MSHADOW_EXPR_SCALAR_INL_H_ -#define MSHADOW_EXPR_SCALAR_INL_H_ -// undef the guard so it can be included multiple times -#undef MSHADOW_EXPR_SCALAR_INL_H_ - -namespace mshadow { -namespace expr { -// DotExp -/*! \brief dot operator def */ -template -inline DotExp -operator*(const DotExp &lhs, - MSHADOW_SCALAR_ rhs) { - return DotExp(lhs.lhs_, lhs.rhs_, lhs.scale_ * rhs); -} -/*! \brief scale of dot operation */ -template -inline DotExp -operator*(MSHADOW_SCALAR_ lhs, - const DotExp &rhs) { - return DotExp(rhs.lhs_, rhs.rhs_, rhs.scale_ * lhs); -} - -/*! \brief operator overload */ -template -inline ReduceTo1DExp -operator*(const ReduceTo1DExp &e, MSHADOW_SCALAR_ scale) { - return ReduceTo1DExp(e.src_, e.scale_ * scale); -} -/*! \brief operator overload */ -template -inline ReduceTo1DExp -operator*(MSHADOW_SCALAR_ scale, const ReduceTo1DExp &e) { - return ReduceTo1DExp(e.src_, e.scale_ * scale); -} - -/*! \brief operator overload for const */ -template -inline BinaryMapExp, - MSHADOW_SCALAR_, (ta|type::kMapper)> -F(const Exp &lhs, const ScalarExp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload for const */ -template -inline BinaryMapExp, TB, - MSHADOW_SCALAR_, (tb|type::kMapper)> -F(const ScalarExp &lhs, const Exp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload for const */ -template -inline BinaryMapExp, ScalarExp, - MSHADOW_SCALAR_, (1|type::kMapper)> -F(const ScalarExp &lhs, const ScalarExp &rhs) { - return MakeExp(lhs, rhs); -} -// constant operators -/*! \brief operator overload */ -template -inline BinaryMapExp, - MSHADOW_SCALAR_, (ta|type::kMapper)> -operator+(const Exp &lhs, - const ScalarExp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload */ -template -inline BinaryMapExp, - MSHADOW_SCALAR_, (ta|type::kMapper)> -operator-(const Exp &lhs, - const ScalarExp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload */ -template -inline BinaryMapExp, - MSHADOW_SCALAR_, (ta|type::kMapper)> -operator*(const Exp &lhs, - const ScalarExp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload */ -template -inline BinaryMapExp, - MSHADOW_SCALAR_, (ta|type::kMapper)> -operator/(const Exp &lhs, - const ScalarExp &rhs) { - return MakeExp(lhs, rhs); -} -// constant operators 2 -/*! \brief operator overload */ -template -inline BinaryMapExp, TB, - MSHADOW_SCALAR_, (tb|type::kMapper)> -operator+(const ScalarExp &lhs, - const Exp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload */ -template -inline BinaryMapExp, TB, - MSHADOW_SCALAR_, (tb|type::kMapper)> -operator-(const ScalarExp &lhs, - const Exp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload */ -template -inline BinaryMapExp, TB, - MSHADOW_SCALAR_, (tb|type::kMapper)> -operator*(const ScalarExp &lhs, - const Exp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload */ -template -inline BinaryMapExp, TB, - MSHADOW_SCALAR_, (tb|type::kMapper)> -operator/(const ScalarExp &lhs, const Exp &rhs) { - return MakeExp(lhs, rhs); -} -// constant operators 3 -/*! \brief operator overload */ -inline BinaryMapExp, ScalarExp, - MSHADOW_SCALAR_, (1|type::kMapper)> -operator+(const ScalarExp &lhs, - const ScalarExp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload */ -inline BinaryMapExp, ScalarExp, - MSHADOW_SCALAR_, (1|type::kMapper)> -operator-(const ScalarExp &lhs, - const ScalarExp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload */ -inline BinaryMapExp, ScalarExp, - MSHADOW_SCALAR_, (1|type::kMapper)> -operator*(const ScalarExp &lhs, - const ScalarExp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload */ -inline BinaryMapExp, ScalarExp, - MSHADOW_SCALAR_, (1|type::kMapper)> -operator/(const ScalarExp &lhs, const ScalarExp &rhs) { - return MakeExp(lhs, rhs); -} -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXPR_SCALAR_INL_H_ diff --git a/include/mshadow/expression.h b/include/mshadow/expression.h deleted file mode 100644 index 77f943165088..000000000000 --- a/include/mshadow/expression.h +++ /dev/null @@ -1,416 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file expression.h - * \brief definitions of abstract expressions and expressions template - * \author Tianqi Chen, Bing Xu - */ -#ifndef MSHADOW_EXPRESSION_H_ -#define MSHADOW_EXPRESSION_H_ -#include "./base.h" - -namespace mshadow { -/*! - * \brief namespace for abstract expressions and expressions template, - * have no dependency on tensor.h, - * These data structure takes no charge in computations, - * they are only used to define operations and represent expression in a symbolic way - */ -namespace expr { -/*! \brief type of expressions */ -namespace type { -// type expression type are defined as bitmask -// subtype relationshop kRValue < kMapper < kPull < kComplex -/*! - * \brief this expression directly correspnds to a data class, - * can be used to assign data - */ -const int kRValue = 0; -/*! - * \brief expression contains element-wise tensor operations, - * map a expression to same shape - */ -const int kMapper = 1; -/*! - * \brief expression that can be chained with other expressiones - * Usually it have function Eval(i,j) defined, which pulls the result (i, j) from input - * expression and output the result at certain position. - */ -const int kChainer = 3; -/*! \brief othercase: e.g dot product */ -const int kComplex = 7; -} // namespace type -/*! - * \brief expression engine that actually interprets these expressions - * this is a function template that needed to be implemented for specific expressions - * \tparam Saver the save method - * \tparam RValue the type of RValue to be saved - * \sa namespace sv - */ -template -struct ExpEngine; -/*! \brief defines how expression exp can be evaluated and stored into dst */ -// template -// inline static void Eval(RValue *dst, const EType &exp); -/*! - * \brief base class for expression - * \tparam SubType inheritated class must put their type into this parameter - * \tparam DType the data type of each element in the expression - * \tparam exp_type expression type, see namespace type - */ -template -struct Exp { - public: - /*! \return subtype instance of current class */ - inline const SubType& self(void) const { - return *static_cast(this); - } - /*! \return reference of subtype instance of current class */ - inline SubType* ptrself(void) { - return static_cast(this); - } -}; -/*! - * \brief scalar expression - * \tparam DType the data type of the scalar - */ -template -struct ScalarExp: public Exp, DType, type::kMapper> { - /*! \brief scalar value */ - DType scalar_; - /*! \brief implicit constructor, MUST NOT BE explicit */ - ScalarExp(DType scalar) : scalar_(scalar) {} // NOLINT(*) -}; -/*! \brief create an scalar expression */ -template -inline ScalarExp scalar(DType s) { - return ScalarExp(s); -} -/*! - * \brief typecast expression, cast the type of elements - * \tparam DstDType the target type we want to cast into - * \tparam SrcDType the target type we want to cast from - * \tparam EType the type of the source expression - * \tparam etype the type of expression after cast - */ -template -struct TypecastExp: - public Exp, - DstDType, etype> { - /*! \brief expression to be typecasted */ - const EType &exp; - /*! \brief constructor */ - explicit TypecastExp(const EType &e) : exp(e) {} -}; -/*! \brief create an scalar expression */ -template -inline TypecastExp -tcast(const Exp &exp) { - return TypecastExp(exp.self()); -} -/*! \brief represent a transpose expression of a container */ -template -struct TransposeExp: public Exp, - DType, type::kChainer> { - /*! \brief expression to be transposed */ - const EType &exp; - /*! \brief constructor */ - explicit TransposeExp(const EType &e) : exp(e) {} - /*! \brief transpose expression */ - inline const EType &T(void) const { - return exp; - } -}; -/*! - * \brief base class of all rvalues - * \tparam Container the actually class of data container, e.g. Tensor1D - * \tparam DataType the element data type of each element in the container - */ -template -class RValueExp: public Exp { - public: - /*! - *\brief transpose of a matrix - *\return transpose of current expression - */ - inline const TransposeExp T(void) const { - return TransposeExp(this->self()); - } - /*! \brief operator overload */ - inline Container &operator+=(DType s) { - ExpEngine::Eval(this->ptrself(), scalar(s)); - return *(this->ptrself()); - } - /*! \brief operator overload */ - inline Container &operator-=(DType s) { - ExpEngine::Eval(this->ptrself(), scalar(s)); - return *(this->ptrself()); - } - /*! \brief operator overload */ - inline Container &operator*=(DType s) { - ExpEngine::Eval(this->ptrself(), scalar(s)); - return *(this->ptrself()); - } - /*! \brief operator overload */ - inline Container &operator/=(DType s) { - ExpEngine::Eval(this->ptrself(), scalar(s)); - return *(this->ptrself()); - } - /*! \brief operator overload */ - inline Container &__assign(DType s) { - ExpEngine::Eval(this->ptrself(), scalar(s)); - return *(this->ptrself()); - } - /*! \brief we can not define container = container */ - template - inline Container &__assign(const Exp &exp) { - ExpEngine::Eval(this->ptrself(), exp.self()); - return *(this->ptrself()); - } - /*! \brief operator overload, assign */ - inline Container &__assign(const Exp &exp); - /*! \brief implementation of operator+= */ - template - inline Container &operator+=(const Exp &exp) { - ExpEngine::Eval(this->ptrself(), exp.self()); - return *(this->ptrself()); - } - /*! \brief implementation of operator-= */ - template - inline Container &operator-=(const Exp &exp) { - ExpEngine::Eval(this->ptrself(), exp.self()); - return *(this->ptrself()); - } - /*! \brief implementation of operator*= */ - template - inline Container &operator*=(const Exp &exp) { - ExpEngine::Eval(this->ptrself(), exp.self()); - return *(this->ptrself()); - } - /*! \brief implementation of operator/= */ - template - inline Container &operator/=(const Exp &exp) { - ExpEngine::Eval(this->ptrself(), exp.self()); - return *(this->ptrself()); - } -}; -/*! - * \brief matrix multiplication expression dot(lhs[.T], rhs[.T]) - * \tparam TA type of lhs - * \tparam TB type of rhs - * \tparam ltrans whether lhs is transposed - * \tparam rtrans whether rhs is transposed - * \tparam DType the data type of the scalar - */ -template -struct DotExp: public Exp, - DType, type::kComplex> { - /*! \brief left operand */ - const TA &lhs_; - /*! \brief right operand */ - const TB &rhs_; - /*! \brief scale over result */ - DType scale_; - /*! \brief constructor */ - explicit DotExp(const TA &lhs, const TB &rhs, DType scale) - : lhs_(lhs), rhs_(rhs), scale_(scale) {} -}; -// definition of dot expression -/*! \brief dot operator def */ -template -inline DotExp -dot(const RValueExp &lhs, const RValueExp &rhs) { - return DotExp(lhs.self(), rhs.self(), DType(1.0f)); -} -/*! \brief dot operator def */ -template -inline DotExp -dot(const TransposeExp &lhs, const RValueExp &rhs) { - return DotExp(lhs.exp, rhs.self(), DType(1.0f)); -} -/*! \brief dot operator def */ -template -inline DotExp -dot(const RValueExp &lhs, const TransposeExp &rhs) { - return DotExp(lhs.self(), rhs.exp, DType(1.0f)); -} -/*! \brief dot operator def */ -template -inline DotExp -dot(const TransposeExp &lhs, const TransposeExp &rhs) { - return DotExp(lhs.exp, rhs.exp, DType(1.0f)); -} -/*! \brief batch_dot operator def */ -template -inline DotExp -batch_dot(const RValueExp &lhs, const RValueExp &rhs) { - return DotExp( - lhs.self(), rhs.self(), DType(1.0f)); -} -//--------------- -// TernaryMapExp -// -------------- -/*! - * \brief ternary map expression - * \tparam OP operator - * \tparam TA type of item1 - * \tparam TB type of item2 - * \tparam etype expression type, sa namespace::type - */ -template -struct TernaryMapExp: public Exp, - DType, etype> { - /*! \brief first operand */ - const TA &item1_; - /*! \brief second operand */ - const TB &item2_; - /*! \brief third operand */ - const TC &item3_; - /*! \brief constructor */ - explicit TernaryMapExp(const TA &item1, const TB &item2, const TC &item3) - :item1_(item1), item2_(item2), item3_(item3) {} -}; - -/*! \brief make expression */ -template -inline TernaryMapExp -MakeExp(const Exp &item1, const Exp &item2, - const Exp &item3) { - return TernaryMapExp(item1.self(), item2.self(), item3.self()); -} -/*! - * \brief short hand for MakeExp, usage F(item1,item2,item3). create a ternary operation expression - * \param item1 first operand - * \param item2 second operand - * \param item3 third operand - * \return the result expression - * \tparam ternary operator - * \tparam TA item1 expression - * \tparam ta item1 expression type - * \tparam TB item2 expression - * \tparam tb item2 expression type - * \tparam TC item3 expression - * \tparam tc item3 expression type - * \sa mshadow::op - */ - -// Ternary -template -inline TernaryMapExp -F(const Exp &item1, const Exp &item2, - const Exp &item3) { - return MakeExp(item1, item2, item3); -} -//--------------- -// BinaryMapExp -// -------------- -/*! - * \brief binary map expression lhs [op] rhs - * \tparam OP operator - * \tparam TA type of lhs - * \tparam TB type of rhs - * \tparam etype expression type, sa namespace::type - */ -template -struct BinaryMapExp: public Exp, - DType, etype> { - /*! \brief left operand */ - const TA &lhs_; - /*! \brief right operand */ - const TB &rhs_; - /*! \brief constructor */ - explicit BinaryMapExp(const TA &lhs, const TB &rhs) - :lhs_(lhs), rhs_(rhs) {} -}; - -/*! \brief make expression */ -template -inline BinaryMapExp -MakeExp(const Exp &lhs, const Exp &rhs) { - return BinaryMapExp(lhs.self(), rhs.self()); -} -/*! - * \brief short hand for MakeExp, usage F(lhs, rhs). create a binary operation expression - * \param lhs left operand - * \param rhs right operand - * \return the result expression - * \tparam binary operator - * \tparam TA lhs expression - * \tparam ta lhs expression type - * \tparam TB rhs expression - * \tparam tb rhs expression type - * \sa mshadow::op - */ -template -inline BinaryMapExp -F(const Exp &lhs, const Exp &rhs) { - return MakeExp(lhs, rhs); -} -// operator rules -/*! \brief operator overload */ -template -inline BinaryMapExp -operator+(const Exp &lhs, const Exp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload */ -template -inline BinaryMapExp -operator-(const Exp &lhs, const Exp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload */ -template -inline BinaryMapExp -operator*(const Exp &lhs, const Exp &rhs) { - return MakeExp(lhs, rhs); -} -/*! \brief operator overload */ -template -inline BinaryMapExp -operator/(const Exp &lhs, const Exp &rhs) { - return MakeExp(lhs, rhs); -} -//--------------- -// UnaryMapExp -// -------------- -/*! - * \brief unary map expression op(src) - * \tparam OP operator - * \tparam TA type of src - * \tparam etype expression type, sa namespace::type - */ -template -struct UnaryMapExp: public Exp, - DType, etype> { - /*! \brief source expression */ - const TA &src_; - /*! \brief constructor */ - explicit UnaryMapExp(const TA &src) : src_(src) {} -}; - -/*! \brief make expression */ -template -inline UnaryMapExp -MakeExp(const Exp &src) { - return UnaryMapExp(src.self()); -} -/*! - * \brief short hand for MakeExp, usage F(src), create a unary operation expression - * \param src source expression - * \return the result expression - * \tparam operator - * \tparam TA source expression - * \tparam ta source expression type - * \sa mshadow::op - */ -template -inline UnaryMapExp -F(const Exp &src) { - return MakeExp(src); -} -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXPRESSION_H_ diff --git a/include/mshadow/extension.h b/include/mshadow/extension.h deleted file mode 100644 index 7af0f56f7699..000000000000 --- a/include/mshadow/extension.h +++ /dev/null @@ -1,41 +0,0 @@ -/*! - * Copyright by Contributors - * \file extension.h - * \brief some extension of expressions, - * used to support something beyond elementwise op - * \author Tianqi Chen, Bing Xu - */ -#ifndef MSHADOW_EXTENSION_H_ -#define MSHADOW_EXTENSION_H_ -#include "./expr_engine-inl.h" -#include "./extension/broadcast.h" -#include "./extension/unpack_patch2col.h" -#include "./extension/pack_col2patch.h" -#include "./extension/reshape.h" -#include "./extension/swapaxis.h" -#include "./extension/reduceto1d.h" -#include "./extension/spatial_pool.h" -#include "./extension/spatial_unpool.h" -#include "./extension/channel_pool.h" -#include "./extension/channel_unpool.h" -#include "./extension/pad.h" -#include "./extension/crop.h" -#include "./extension/mirror.h" -#include "./extension/concat.h" -#include "./extension/implicit_gemm.h" -#include "./extension/choose.h" -#include "./extension/fill.h" -#include "./extension/one_hot.h" -#include "./extension/slice.h" -#include "./extension/slice_ex.h" -#include "./extension/take.h" -#include "./extension/take_grad.h" -#include "./extension/reduce_with_axis.h" -#include "./extension/broadcast_with_axis.h" -#include "./extension/spatial_upsampling_nearest.h" -#include "./extension/transpose.h" -#include "./extension/flip.h" -#include "./extension/complex.h" -#include "./extension/range.h" -#include "./extension/mask.h" -#endif // MSHADOW_EXTENSION_H_ diff --git a/include/mshadow/extension/broadcast.h b/include/mshadow/extension/broadcast.h deleted file mode 100644 index ea138ccd9e4d..000000000000 --- a/include/mshadow/extension/broadcast.h +++ /dev/null @@ -1,165 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file broadcast.h - * \brief support for broadcast and repmat - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_BROADCAST_H_ -#define MSHADOW_EXTENSION_BROADCAST_H_ -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief broadcast Tensor1D into a higher dimension Tensor - * input: Tensor: ishape[0] - * output: Tensor : oshape[dimcast] = ishape[0] - * \tparam SrcExp type of input expression - * \tparam DType the type of elements - * \tparam dimdst target tensor dimension - * \tparam dimcast_m_dst dimdst - dimcast - */ -template -struct Broadcast1DExp: - public MakeTensorExp, - SrcExp, dimdst, DType> { - /*! \brief source operand */ - const SrcExp &src_; - /*! \brief constructor */ - Broadcast1DExp(const SrcExp &src, Shape shape) - : src_(src) { - this->shape_ = shape; - } -}; - -/*! - * \brief broadcast scalar into a higher dimension Tensor - * input: Tensor: ishape = {1} - * output: Tensor : oshape[dimcast] = ishape[0] - * \tparam SrcExp type of input expression - * \tparam DType the type of elements - * \tparam dimdst target tensor dimension - */ -template -struct BroadcastScalarExp: - public MakeTensorExp, - SrcExp, dimdst, DType> { - /*! \brief source operand */ - const SrcExp &src_; - /*! \brief constructor */ - BroadcastScalarExp(const SrcExp &src, Shape shape) - : src_(src) { - this->shape_ = shape; - } -}; - -/*! - * \brief a expression that replicate a 1 dimension tensor in dimension dimcast - * \param src Tensor: shape[0] - * \param shape shape of output - * \return a expresion with type Tensor - * \tparam dimcast target dimension where the 1D tensor will be broadcasted - * \tparam SrcExp type of input expression - * \tparam DType the type of elements - * \tparam dimdst dimension of destination tensor - * \tparam dimcast_lowest the dimension we want to cast the data into - */ -template -inline Broadcast1DExp -broadcast(const expr::Exp &src, Shape shape) { - TypeCheckPass::kDim == 1> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp; - CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], shape[dimcast]) - << "broadcast, shape mismatch"; - return Broadcast1DExp(src.self(), shape); -} - -/*! - * \brief a expression that replicate a scalar tensor to target dimension. - * \param src Tensor: shape[0] == 1 - * \param shape shape of output - * \return a expresion with type Tensor - * \tparam dimcast target dimension where the 1D tensor will be broadcasted - * \tparam SrcExp type of input expression - * \tparam DType the type of elements - * \tparam dimdst dimension of destination tensor - */ -template -inline BroadcastScalarExp -broadcast_scalar(const expr::Exp &src, Shape shape) { - TypeCheckPass::kDim == 1> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp; - CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], 1U) - << "broadcast_scalar, source need to be scalar expression"; - return BroadcastScalarExp(src.self(), shape); -} -// short cut functions -/*! - * \brief a expression that replicate a 1 dimension tensor for nrow times - * \param src Tensor: shape[0] - * \param nrow number of rows to replicate - * \return a expresion with type Tensor size(1), size(0) = nrow - * \tparam Device which device it lies - */ -template -inline Broadcast1DExp -repmat(const expr::Exp &src, index_t nrow) { - return broadcast<1> - (src, Shape2(nrow, ShapeCheck<1, SrcExp>::Check(src.self())[0])); -} -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - static const int dimcast = dimdst - dimdst_m_cast; - explicit Plan(const Broadcast1DExp &e) - : src_(MakePlan(e.src_)), - ystride_(e.shape_.ProdShape(dimcast + 1, dimdst - 1)), - length_(e.shape_[dimcast]) { - TypeCheckPass - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - } - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return src_.Eval(0, (y / ystride_) % length_); - } - - private: - expr::Plan src_; - const index_t ystride_, length_; -}; - -/*! \brief execution plan of Broadcast1DExp */ -template -struct Plan, DType>{ - public: - explicit Plan(const Broadcast1DExp &e) - : src_(MakePlan(e.src_)) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return src_.Eval(0, x); - } - - private: - expr::Plan src_; -}; - -/*! \brief execution plan of Broadcast1DExp */ -template -struct Plan, DType>{ - public: - explicit Plan(const BroadcastScalarExp &e) - : src_(MakePlan(e.src_)) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return src_.Eval(0, 0); - } - - private: - expr::Plan src_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_BROADCAST_H_ diff --git a/include/mshadow/extension/broadcast_with_axis.h b/include/mshadow/extension/broadcast_with_axis.h deleted file mode 100644 index 49605af67d32..000000000000 --- a/include/mshadow/extension/broadcast_with_axis.h +++ /dev/null @@ -1,258 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file broadcast_with_axis.h - * \brief - * \author Junyuan Xie, Xingjian Shi -*/ -#ifndef MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ -#define MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ - -#include -#include "../extension.h" - -namespace mshadow { -namespace expr { - - /*! - * \brief Broadcasting the tensor in the given axis. If keepdim is off, insert the broadcasting dim after axis. Otherwise broadcasting axis. - * \tparam SrcExp source expression - * \tparam DType data type - * \tparam dimsrc source dimension - * \tparam dimdst destination dimension - */ -template -struct BroadcastWithAxisExp: - public MakeTensorExp, - SrcExp, dimdst, DType> { - /*! \brief data oprand */ - const SrcExp &src_; - /*! \brief size of the last dimension of dst */ - index_t dst_last_; - /*! \brief product of the dimensions after the broadcasting axis */ - index_t trailing_; - /*! \brief new dimension of the broadcasting axis*/ - index_t size_; - /*! \brief size of the last dimension of src*/ - index_t last_; - /*! constructor */ - BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size) - : src_(src), size_(size) { - bool keepdim = (dimsrc == dimdst); - Shape src_shape = ShapeCheck::Check(src_); - this->trailing_ = 1; - - if (!keepdim) { - CHECK(dimsrc > axis && axis >= -1) << "broadcast axis (no keepdim) out of bound, " << - "axis must be between -1 and" << dimsrc - 1 << ", given=" << axis << "."; - for (int i = 0; i <= axis; ++i) { - this->shape_[i] = src_shape[i]; - } - this->shape_[axis + 1] = size_; - for (int i = axis + 1; i < dimsrc; ++i) { - this->trailing_ *= src_shape[i]; - this->shape_[i + 1] = src_shape[i]; - } - } else { - CHECK(dimdst > axis && axis >= 0) << "broadcast axis (keepdim) out of bound, " << - "axis must be between 0 and" << dimdst - 1 << ", given=" << axis << "."; - CHECK_EQ(src_shape[axis], 1U) << "Size of the dimension of the broadcasting axis must be 1" << - " when keepdim is on, src_shape[" << axis << "]=" << src_shape[axis] << "."; - for (int i = 0; i <= axis - 1; ++i) { - this->shape_[i] = src_shape[i]; - } - this->shape_[axis] = size_; - for (int i = axis + 1; i < dimdst; ++i) { - this->trailing_ *= src_shape[i]; - this->shape_[i] = src_shape[i]; - } - } - - this->last_ = src_shape[dimsrc - 1]; - this->dst_last_ = this->shape_[dimdst - 1]; - } -}; // struct BroadcastWithAxisExp - -/*! - * \brief Broadcasting the tensor after given axis. - * \tparam SrcExp source expression - * \tparam DType data type - * \tparam etype type of the expression - */ -template -inline BroadcastWithAxisExp::kDim, - ExpInfo::kDim + 1> -broadcast_with_axis(const Exp &src, const int axis, const index_t size) { - return BroadcastWithAxisExp::kDim, - ExpInfo::kDim + 1>(src.self(), axis, size); -} - -/*! -* \brief Broadcasting the tensor in the given axis (keepdim turned on) -* \tparam SrcExp source expression -* \tparam DType data type -* \tparam etype type of the expression -*/ -template -inline BroadcastWithAxisExp::kDim, - ExpInfo::kDim> - broadcast_keepdim(const Exp &src, const int axis, const index_t size) { - return BroadcastWithAxisExp::kDim, - ExpInfo::kDim>(src.self(), axis, size); -} - -/*! -* \brief Broadcasting the tensor in multiple axes. The dimension of the source tensor - in the given axes must be 1. -* \tparam SrcExp source expression -* \tparam DType data type -* \tparam dimsrc source dimension -* \tparam axesnum number of broadcasting dimensions -*/ -template -struct BroadcastWithMultiAxesExp : - public MakeTensorExp, - SrcExp, dimsrc, DType> { - /*! \brief data oprand */ - const SrcExp &src_; - /*! \brief size of the last dimension of dst */ - index_t dst_last_; - /*! \brief number of broadcasting axes*/ - index_t axesnum_; - /*! \brief product of the dimensions after the broadcasting axses */ - Shape trailings_; - /*! \brief new dimension of the broadcasting axes*/ - Shape sizes_; - /*! \brief size of the last dimension of src*/ - index_t last_; - /*! constructor */ - template - BroadcastWithMultiAxesExp(const SrcExp &src, const TShape& axes, const TShape& sizes) - : src_(src) { - Shape src_shape = ShapeCheck::Check(src_); - CHECK(axes.ndim() == sizes.ndim()) << "ndim of axes and sizes must be equal."; - this->axesnum_ = axes.ndim(); - CHECK(this->axesnum_ <= dimsrc) << "Number of broadcasting axes must be smaller than" - "the source ndim, number of axes=" << this->axesnum_ << " dimsrc=" << dimsrc; - for (index_t i = 0; i < this->axesnum_; i++) { - CHECK(dimsrc > axes[i]) << "broadcast axis (keepdim) out of bound, " << - "all axes must be between 0 and" << dimsrc - 1 << ", given axes[" << i << "] = " << axes[i] - << "."; - CHECK_EQ(src_shape[axes[i]], 1U) << "Size of the dimension of the broadcasting axis must be 1" - << ", src_shape[" << axes[i] << "]=" << src_shape[axes[i]] << "."; - if (i < this->axesnum_ - 1) { - CHECK(axes[i] < axes[i + 1]) << "The given axes must be in increasing order."; - } - } - for (index_t i = 0; i < dimsrc; i++) { - this->shape_[i] = src_shape[i]; - this->sizes_[i] = 1; - this->trailings_[i] = 1; - } - for (index_t i = 0; i < this->axesnum_; i++) { - this->shape_[axes[i]] = sizes[i]; - this->sizes_[i] = sizes[i]; - } - for (index_t i = 0; i < this->axesnum_; i++) { - this->trailings_[i] = 1; - for (index_t j = axes[i] + 1; j < dimsrc; ++j) { - this->trailings_[i] *= this->shape_[j]; - } - } - this->last_ = src_shape[dimsrc - 1]; - this->dst_last_ = this->shape_[dimsrc - 1]; - } -}; // struct BroadcastWithMultiAxesExp - -/*! -* \brief Broadcasting the tensor in the given axis (keepdim turned on) -* \param src source -* \param axes broadcasting axes -* \param sizes sizes of the broadcasting axes -* \tparam SrcExp source expression -* \tparam DType data type -* \tparam etype type of the expression -* \tparam TShape the flexible shape type -*/ -template -inline BroadcastWithMultiAxesExp::kDim> -broadcast_multi_axes(const Exp &src, -const TShape &axes, const TShape &sizes) { - return BroadcastWithMultiAxesExp::kDim>(src.self(), axes, sizes); -} - -/*! -* \brief Broadcasting the tensor to the target shape, - dimension of different sizes must be 1 in the original tensor. -* \param src source -* \param target_shape shape of the target broadcasting tensor -* \tparam SrcExp source expression -* \tparam DType data type -* \tparam etype type of the expression -* \tparam TShape the flexible shape type -*/ -template -inline BroadcastWithMultiAxesExp::kDim> -broadcast_to(const Exp &src, const TShape &target_shape) { - static const size_t dimsrc = ExpInfo::kDim; - CHECK_EQ(target_shape.ndim(), dimsrc); - std::vector axes_vec, sizes_vec; - Shape src_shape = ShapeCheck::Check(src.self()); - for (size_t i = 0; i < dimsrc; ++i) { - if (src_shape[i] != target_shape[i]) { - CHECK_EQ(src_shape[i], 1U) << "broadcasting axis must have size 1, received shape=" - << src_shape << " target_shape=" << target_shape; - axes_vec.push_back(i); - sizes_vec.push_back(target_shape[i]); - } - } - TShape axes = TShape(axes_vec.begin(), axes_vec.end()); - TShape sizes = TShape(sizes_vec.begin(), sizes_vec.end()); - return BroadcastWithMultiAxesExp::kDim>(src.self(), axes, sizes); -} - -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const BroadcastWithAxisExp &e) - : src_(MakePlan(e.src_)), dst_last_(e.dst_last_), - trailing_(e.trailing_), size_(e.size_), last_(e.last_) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - index_t x = (i * dst_last_ + j) / trailing_ / size_; - index_t y = (i * dst_last_ + j) % trailing_; - index_t z = x * trailing_ + y; - return src_.Eval(z / last_, z % last_); - } - - private: - Plan src_; - const index_t dst_last_, trailing_, size_, last_; -}; - -template -struct Plan, DType> { - public: - explicit Plan(const BroadcastWithMultiAxesExp &e) - : src_(MakePlan(e.src_)), dst_last_(e.dst_last_), last_(e.last_), axesnum_(e.axesnum_), - trailings_(e.trailings_), sizes_(e.sizes_) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - index_t indx = i * dst_last_ + j; - for (index_t p = 0; p < dimsrc; ++p) { - if (p >= axesnum_) { - break; - } - indx = (indx / trailings_[p] / sizes_[p]) * trailings_[p] + (indx % trailings_[p]); - } - return src_.Eval(indx / last_, indx % last_); - } - - private: - Plan src_; - const index_t dst_last_, last_, axesnum_; - const Shape trailings_, sizes_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ diff --git a/include/mshadow/extension/channel_pool.h b/include/mshadow/extension/channel_pool.h deleted file mode 100644 index 60d1112f4a61..000000000000 --- a/include/mshadow/extension/channel_pool.h +++ /dev/null @@ -1,108 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file channel_pool.h - * \brief support for chpool - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_CHANNEL_POOL_H_ -#define MSHADOW_EXTENSION_CHANNEL_POOL_H_ -#include -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief channel pooling expression, do reduction over (local nearby) channels, - * used to implement local response normalization - * \tparam Reducer reduction method during pooling - * \tparam SrcExp source expression to be pooled from - * \tparam DType the type of elements - * \tparam srcdim dimension of src - */ -template -struct ChannelPoolingExp: - public MakeTensorExp, - SrcExp, srcdim, DType> { - /*! \brief source operand */ - const SrcExp &src_; - /*! \brief neighbor size */ - index_t nsize_; - /*! \brief stride of pooling */ - index_t stride_; - /*! \brief pad of pooling of each side */ - index_t pad_; - index_t src_channel_; - /*! \brief constructor */ - ChannelPoolingExp(const SrcExp &src, index_t nsize, index_t stride, index_t pad) - : src_(src), nsize_(nsize), stride_(stride), pad_(pad) { - this->shape_ = ShapeCheck::Check(src_); - this->src_channel_ = this->shape_[srcdim - 3]; - CHECK_GE(this->shape_[srcdim - 3], nsize_) - << "chpool: local size must be smaller than nchannels"; - this->shape_[srcdim - 3] = (this->src_channel_ - nsize + pad * 2 + 1) / stride; - } -}; -/*! - * \brief channel pooling, do reduction over (local nearby) channels, - * used to implement local response normalization - * \param src source data - * \param nsize neighbor size - * \return expression of pooled result - * \tparam Reducer reducer type - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam etype type of expression - */ -template -inline ChannelPoolingExp::kDim> -chpool(const Exp &src, index_t nsize) { - TypeCheckPass::kDim >= 3> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - CHECK_EQ(nsize % 2, 1U) << "chpool: if no pad is specified, local size must be odd"; - return ChannelPoolingExp::kDim>(src.self(), nsize, 1, nsize / 2); -} - -template -inline ChannelPoolingExp::kDim> -chpool(const Exp &src, index_t nsize, index_t stride, index_t pad) { - TypeCheckPass::kDim >= 3> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return ChannelPoolingExp::kDim>(src.self(), nsize, stride, pad); -} - -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const ChannelPoolingExp &e) - : src_(MakePlan(e.src_)), channel_(e.shape_[srcdim - 3]), - height_(e.shape_[srcdim - 2]), width_(e.shape_[srcdim - 1]), - hnsize_(e.nsize_), stride_(e.stride_), pad_(e.pad_), - src_channel_(e.src_channel_) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - using namespace std; - const index_t y = i % height_; - i /= height_; - const index_t c = i % channel_; - const index_t n = i / channel_; - const index_t x = j; - const index_t cstart = c * stride_ < pad_ ? 0 : c * stride_ - pad_; - const index_t cend = min(c * stride_ - pad_ + hnsize_, channel_); - DType res; Reducer::SetInitValue(res); - for (index_t cc = cstart; cc < cend; ++cc) { - Reducer::Reduce(res, src_.Eval((n * src_channel_ + cc) * height_ + y, x)); - } - return res; - } - - private: - Plan src_; - const index_t channel_, height_, width_, hnsize_, stride_, pad_, src_channel_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_CHANNEL_POOL_H_ - diff --git a/include/mshadow/extension/channel_unpool.h b/include/mshadow/extension/channel_unpool.h deleted file mode 100644 index 00ba279c1760..000000000000 --- a/include/mshadow/extension/channel_unpool.h +++ /dev/null @@ -1,137 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file channel_pool.h - * \brief support for chpool - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_ -#define MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_ -#include -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief channel pooling expression, do reduction over (local nearby) channels, - * used to implement local response normalization - * \tparam Reducer reduction method during pooling - * \tparam SrcExp source expression to be pooled from - * \tparam DType the type of elements - * \tparam srcdim dimension of src - */ -template -struct ChannelUnpoolingExp: - public MakeTensorExp, - SrcExp, srcdim, DType> { - /*! \brief source input, corresponds to src in pooling */ - const SrcExp &data_src_; - /*! \brief result of pooled data, corresponds to result of pooling */ - const SrcExp &data_pooled_; - /*! \brief gradient data of pooled part, to be propgate down */ - const SrcExp &grad_pooled_; - /*! \brief channel of pooled expression */ - index_t pchannel_; - /*! \brief kernel size in height */ - index_t nsize_; - /*! \brief kernel size in width */ - index_t kstride_; - /*! \brief pad */ - index_t pad_; - /*! \brief constructor */ - ChannelUnpoolingExp(const SrcExp &data_src, - const SrcExp &data_pooled, - const SrcExp &grad_pooled, - index_t nsize, index_t kstride, index_t pad) - : data_src_(data_src), data_pooled_(data_pooled), - grad_pooled_(grad_pooled), - nsize_(nsize), kstride_(kstride), pad_(pad) { - Shape pshape = ShapeCheck::Check(grad_pooled); - typedef ShapeCheck ShapeCheckSrcDimSrcExp; - CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled)) - << "ChannelUnPoolingExp: data and grad shape mismatch"; - Shape sshape = ShapeCheck::Check(data_src); - for (int k = 0; k < srcdim; ++k) { - if (k == 1) { - continue; - } - CHECK_EQ(pshape[k], sshape[k]) - << "ChannelUnPoolingExp: pooled tensor and src tensor shape mismatch" - << pshape[k] - << " vs " - << sshape[k]; - } - pchannel_ = pshape[1]; - this->shape_ = sshape; - } -}; -/*! - * \brief channel unpooling, do unroll over (local nearby) channels - * \param src source data - * \param nsize neighbor size - * \param stride stride of the pooling - * \param pad number of padding at each side - * \return expression of pooled result - * \tparam Reducer reducer type - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam etype type of expression - */ -template -inline ChannelUnpoolingExp::kDim> -ch_unpool(const Exp &data_src, - const Exp &data_pooled, - const Exp &grad_pooled, - index_t nsize, index_t stride, index_t pad) { - TypeCheckPass::kDim >= 3> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return ChannelUnpoolingExp::kDim> - (data_src.self(), data_pooled.self(), grad_pooled.self(), nsize, stride, pad); -} - -template -inline ChannelUnpoolingExp::kDim> -ch_unpool(const Exp &data_src, - const Exp &data_pooled, - const Exp &grad_pooled, index_t nsize) { - return ch_unpool(data_src, data_pooled, grad_pooled, nsize, 1, nsize / 2); -} - - -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const ChannelUnpoolingExp &e) - : data_src_(e.data_src_), data_pooled_(e.data_pooled_), - grad_pooled_(e.grad_pooled_), channel_(e.shape_[srcdim - 3]), - height_(e.shape_[srcdim - 2]), pchannel_(e.pchannel_), - hnsize_(e.nsize_), stride_(e.kstride_), pad_(e.pad_) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - using namespace std; - const DType vsrc = data_src_.Eval(i, j); - const index_t y = i % height_; - i /= height_; - const index_t c = i % channel_; - const index_t n = i / channel_; - const index_t x = j; - const index_t cstart = c < hnsize_ - pad_ ? 0 - : (c - (hnsize_ - pad_) + stride_) / stride_; - const index_t cend = min((c + pad_ + stride_) / stride_, channel_); - DType val = static_cast(0); - for (index_t cc = cstart; cc < cend; ++cc) { - val += Reducer::PartialGrad(vsrc, - data_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x)) * - grad_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x); - } - return val; - } - - private: - Plan data_src_, data_pooled_, grad_pooled_; - const index_t channel_, height_, pchannel_, hnsize_, stride_, pad_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_ - diff --git a/include/mshadow/extension/choose.h b/include/mshadow/extension/choose.h deleted file mode 100644 index b1391724d400..000000000000 --- a/include/mshadow/extension/choose.h +++ /dev/null @@ -1,90 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file choose.h - * \brief support for implicit array selection operation - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_CHOOSE_H_ -#define MSHADOW_EXTENSION_CHOOSE_H_ - -#include "../extension.h" - -namespace mshadow { -namespace expr { -/*! - * \brief Make a choice of index in the lowest changing dimension. - * \tparam SrcExp type of lhs expression - * \tparam IndexExp type of index expression - * \tparam DType the type of elements - */ -template -struct MatChooseRowElementExp: - public Exp, - DType, type::kChainer> { - /*! \brief source operand */ - const SrcExp &src_; - /*! \brief index operand */ - const IndexExp &index_; - /*! \brief constructor */ - MatChooseRowElementExp(const SrcExp &src, const IndexExp &index) - : src_(src), index_(index) {} -}; - -template -inline MatChooseRowElementExp -mat_choose_row_element(const Exp &src, - const Exp &index) { - TypeCheckPass::kDim == 2 && ExpInfo::kDim == 1> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return MatChooseRowElementExp(src.self(), index.self()); -} - -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const MatChooseRowElementExp &e) - : src_(MakePlan(e.src_)), - index_(MakePlan(e.index_)) { - } - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - index_t idx = static_cast(index_.Eval(0, x)); - return src_.Eval(x, idx); - } - - private: - expr::Plan src_; - expr::Plan index_; -}; - -template -inline Plan, DType> -MakePlan(const MatChooseRowElementExp &exp) { - return Plan, DType>(exp); -} - -template -struct ShapeCheck > { - inline static Shape - Check(const MatChooseRowElementExp &t) { - CHECK(dim == 1) - << "MatChooseRowElementExp only support 1 dimension output"; - Shape<2> shape1 = ShapeCheck<2, SrcExp>::Check(t.src_); - Shape shape2 = ShapeCheck::Check(t.index_); - CHECK_EQ(shape1[0], shape2[0]) - << "mat_choose_row_element index length and number of rows in matrix"; - return shape2; - } -}; - -template -struct ExpInfo > { - static const int kDim = 1; - static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_CHOOSE_H_ diff --git a/include/mshadow/extension/complex.h b/include/mshadow/extension/complex.h deleted file mode 100644 index 8e79b7eb819c..000000000000 --- a/include/mshadow/extension/complex.h +++ /dev/null @@ -1,525 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file complex.h - * \brief support for complex operations - * \author Xingjian Shi - */ -#ifndef MSHADOW_EXTENSION_COMPLEX_H_ -#define MSHADOW_EXTENSION_COMPLEX_H_ -#include -#include "../extension.h" - -namespace mshadow { -namespace op { -namespace complex { -enum BinaryCalculationType { kBinaryCC, kBinaryCR, kBinaryRC}; -enum UnitaryCalculationType { kUnitaryC2R, kUnitaryC2C, kUnitaryR2C }; -struct mul { - /*! \brief map a_real, a_imag, b_real, b_imag to result using defined operation */ - template - MSHADOW_XINLINE static DType RealMap(DType a_real, DType a_imag, - DType b_real, DType b_imag) { - return a_real * b_real - a_imag * b_imag; - } - template - MSHADOW_XINLINE static DType ImagMap(DType a_real, DType a_imag, - DType b_real, DType b_imag) { - return a_real * b_imag + b_real * a_imag; - } -}; - -struct div { - /*! \brief map a_real, a_imag, b_real, b_imag to result using defined operation */ - template - MSHADOW_XINLINE static DType RealMap(DType a_real, DType a_imag, - DType b_real, DType b_imag) { - return (a_real * b_real + a_imag * b_imag) / (b_real * b_real + b_imag * b_imag); - } - template - MSHADOW_XINLINE static DType ImagMap(DType a_real, DType a_imag, - DType b_real, DType b_imag) { - return (b_real * a_imag - a_real * b_imag) / (b_real * b_real + b_imag * b_imag); - } -}; - -struct conjugate { - template - MSHADOW_XINLINE static DType RealMap(const expr::Plan &src_, - index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { - return src_.Eval(real_i, real_j); - } - template - MSHADOW_XINLINE static DType ImagMap(const expr::Plan &src_, - index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { - return -src_.Eval(imag_i, imag_j); - } -}; - -struct exchange { - template - MSHADOW_XINLINE static DType RealMap(const expr::Plan &src_, - index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { - return src_.Eval(imag_i, imag_j); - } - template - MSHADOW_XINLINE static DType ImagMap(const expr::Plan &src_, - index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { - return src_.Eval(real_i, real_j); - } -}; - -// r2c operator -struct pad_imag { - template - MSHADOW_XINLINE static DType RealMap(const expr::Plan &src_, - index_t real_i, index_t real_j) { - return src_.Eval(real_i, real_j); - } - template - MSHADOW_XINLINE static DType ImagMap(const expr::Plan &src_, - index_t real_i, index_t real_j) { - return 0; - } -}; - -// c2r operator -struct toreal { - template - MSHADOW_XINLINE static DType RealMap(const expr::Plan &src_, - index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { - DType real_val = src_.Eval(real_i, real_j); - return real_val; - } -}; - -struct abs_square { - template - MSHADOW_XINLINE static DType RealMap(const expr::Plan &src_, - index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { - DType real_val = src_.Eval(real_i, real_j); - DType image_val = src_.Eval(imag_i, imag_j); - return real_val * real_val + image_val * image_val; - } -}; - -struct sum_real_imag { - template - MSHADOW_XINLINE static DType RealMap(const expr::Plan &src_, - index_t real_i, index_t real_j, index_t imag_i, index_t imag_j) { - DType real_val = src_.Eval(real_i, real_j); - DType image_val = src_.Eval(imag_i, imag_j); - return real_val + image_val; - } -}; -} // namespace complex -} // namespace op - -namespace expr { -//-------------------- -// ComplexBinaryMapExp -//-------------------- - /*! -* \brief binary map expression lhs [op] rhs where lhs and rhs are complex tensors -* \tparam OP operator -* \tparam calctype type of the calculation -* \tparam TA type of lhs -* \tparam TB type of rhs -* \tparam etype expression type, sa namespace::type -*/ -template -struct ComplexBinaryMapExp : public Exp, - DType, etype> { - /*! \brief left operand */ - const TA &lhs_; - /*! \brief right operand */ - const TB &rhs_; - /*! \brief constructor */ - explicit ComplexBinaryMapExp(const TA &lhs, const TB &rhs) - :lhs_(lhs), rhs_(rhs) {} -}; - -//------------------- -// ComplexConjExp -//------------------- -/*! -* \brief compute conj(src) where src is a complex tensor -* \tparam TA type of src -* \tparam etype expression type, sa namespace::type -*/ -template -struct ComplexUnitaryExp : public Exp, - DType, etype> { - /*! \brief source expression */ - const TA &src_; - /*! \brief constructor */ - explicit ComplexUnitaryExp(const TA &src) : src_(src) {} -}; - - - -template -inline ComplexBinaryMapExp -ComplexF(const Exp &lhs, const Exp &rhs) { - return ComplexBinaryMapExp(lhs.self(), rhs.self()); -} - -/*! -* \brief conj Negation the imaginary part of A where A is a complex tensor -* \param src source tensor -* \tparam e1 type of source expression -*/ -template -inline ComplexUnitaryExp -ComplexF(const Exp &src) { - return ComplexUnitaryExp(src.self()); -} - -/*! -* \brief complex_mul_cc Complex multipilication two complex tensors, A * B -*/ -template -inline ComplexBinaryMapExp -complex_mul_cc(const Exp &lhs, const Exp &rhs) { - return ComplexF(lhs, rhs); -} - -/*! -* \brief complex_mul_cr Complex multipilication a complex tensor A and a real tensor B -*/ -template -inline ComplexBinaryMapExp -complex_mul_cr(const Exp &lhs, const Exp &rhs) { - return ComplexF(lhs, rhs); -} - -/*! -* \brief complex_mul_rc Complex multipilication of a real tensor B and a complex tensor A -*/ -template -inline ComplexBinaryMapExp -complex_mul_rc(const Exp &lhs, const Exp &rhs) { - return ComplexF(lhs, rhs); -} - -/*! -* \brief complex_mul_cc Complex multipilication two complex tensors, A * B -*/ -template -inline ComplexBinaryMapExp -complex_div_cc(const Exp &lhs, const Exp &rhs) { - return ComplexF(lhs, rhs); -} - -/*! -* \brief complex_mul_cr Complex multipilication a complex tensor A and a real tensor B -*/ -template -inline ComplexBinaryMapExp -complex_div_cr(const Exp &lhs, const Exp &rhs) { - return ComplexF(lhs, rhs); -} - -/*! -* \brief complex_mul_rc Complex multipilication of a real tensor A and a complex tensor B -*/ -template -inline ComplexBinaryMapExp -complex_div_rc(const Exp &lhs, const Exp &rhs) { - return ComplexF(lhs, rhs); -} - -/*! -* \brief conj Negation the imaginary part of A where A is a complex tensor -* \param src source tensor -* \tparam e1 type of source expression -*/ -template -inline ComplexUnitaryExp -conj(const Exp &src) { - return ComplexF(src); -} - -/*! -* \brief complex_exchange Exchange the real and imaginary part of A where A is a complex tensor -* \param src source tensor -* \tparam e1 type of source expression -*/ -template -inline ComplexUnitaryExp -complex_exchange(const Exp &src) { - return ComplexF(src); -} - -/*! -* \brief complex_pad_imag Transform real matrix into complex matrix -* \param src source tensor -* \tparam e1 type of source expression -*/ -template -inline ComplexUnitaryExp -complex_pad_imag(const Exp &src) { - return ComplexF(src); -} - -/*! -* \brief complex_toreal convert complex matrix to real matrix, keep only real part -* \param src source tensor -* \tparam e1 type of source expression -*/ -template -inline ComplexUnitaryExp -complex_toreal(const Exp &src) { - return ComplexF(src); -} - -/*! -* \brief complex_abs_square calculate the square of the modulus of A where A is a complex tensor -* \param src source tensor -* \tparam e1 type of source expression -*/ -template -inline ComplexUnitaryExp -complex_abs_square(const Exp &src) { - return ComplexF(src); -} - -template -inline ComplexUnitaryExp -complex_sum_real_imag(const Exp &src) { - return ComplexF(src); -} - -template -struct ShapeCheck > { - inline static Shape - Check(const ComplexBinaryMapExp &t) { - Shape shape1 = ShapeCheck::Check(t.lhs_); - Shape shape2 = ShapeCheck::Check(t.rhs_); - if (shape1[0] == 0) return shape2; - if (shape2[0] == 0) return shape1; - if (calctype == op::complex::kBinaryCC) { - CHECK_EQ(shape1, shape2) << "ComplexBinaryMapExp (CC): Shapes of operands are not the same."; - CHECK_EQ(shape1[dim - 1] % 2, 0) << - "ComplexBinaryMapExp (CC): Shape of the last dimension is not even. " - "We must have real part + imaginary part."; - return shape1; - } else if (calctype == op::complex::kBinaryCR) { - for (int i = 0; i < dim - 1; ++i) { - CHECK_EQ(shape1.shape_[i], shape2.shape_[i]) << - "ComplexBinaryMapExp (CR): Shapes of operands are not the same."; - } - CHECK_EQ(shape1[dim - 1], shape2[dim - 1] * 2) << - "ComplexBinaryMapExp (CR): Shapes of operands do not match."; - return shape1; - } else if (calctype == op::complex::kBinaryRC) { - for (int i = 0; i < dim - 1; ++i) { - CHECK_EQ(shape1.shape_[i], shape2.shape_[i]) << - "ComplexBinaryMapExp (RC): Shapes of operands are not the same."; - } - CHECK_EQ(shape2[dim - 1], shape1[dim - 1] * 2) << - "ComplexBinaryMapExp (RC): Shapes of operands do not match."; - return shape2; - } else { - LOG(FATAL) << "ComplexBinaryMapExp: Unexpected Calculation Type!"; - return shape1; - } - } -}; - -template -struct ShapeCheck > { - inline static Shape Check(const ComplexUnitaryExp &t) { - Shape s = ShapeCheck::Check(t.src_); - CHECK_EQ(s[dim - 1] % 2, 0) << "ComplexUnitaryExp: Shape of the last dimension is not even. " - "We must have real + imaginary."; - if (calctype == op::complex::kUnitaryC2C) { - return s; - } else if (calctype == op::complex::kUnitaryC2R) { - Shape s_ret = s; - s_ret[dim - 1] /= 2; - return s_ret; - } else if (calctype == op::complex::kUnitaryR2C) { - Shape s_ret = s; - s_ret[dim-1] *= 2; - return s_ret; - } else { - LOG(FATAL) << "ComplexUnitaryExp: Unexpected Calculation Type!"; - return s; - } - } -}; - - - -// complex binary expression (cc) -template -class Plan, DType> { - public: - explicit Plan(const Plan &lhs, const Plan &rhs) - : lhs_(lhs), rhs_(rhs) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - const index_t base_x = static_cast(x / 2) * 2; - if (x % 2 == 0) { - return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1), - rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1)); - } else { - return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1), - rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1)); - } - } - - private: - Plan lhs_; - Plan rhs_; -}; - -// complex binary expression (cr) -template -class Plan, DType> { - public: - explicit Plan(const Plan &lhs, const Plan &rhs) - : lhs_(lhs), rhs_(rhs) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - const index_t base_x = static_cast(x / 2) * 2; - if (x % 2 == 0) { - return OP::RealMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1), - rhs_.Eval(y, base_x / 2), static_cast(0)); - } else { - return OP::ImagMap(lhs_.Eval(y, base_x), lhs_.Eval(y, base_x + 1), - rhs_.Eval(y, base_x / 2), static_cast(0)); - } - } - - private: - Plan lhs_; - Plan rhs_; -}; - - -// complex binary expression (rc) -template -class Plan, DType> { - public: - explicit Plan(const Plan &lhs, const Plan &rhs) - : lhs_(lhs), rhs_(rhs) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - const index_t base_x = static_cast(x / 2) * 2; - if (x % 2 == 0) { - return OP::RealMap(lhs_.Eval(y, base_x / 2), static_cast(0), - rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1)); - } else { - return OP::ImagMap(lhs_.Eval(y, base_x / 2), static_cast(0), - rhs_.Eval(y, base_x), rhs_.Eval(y, base_x + 1)); - } - } - - private: - Plan lhs_; - Plan rhs_; -}; - - -// complex unitary expression (c2c) -template -class Plan, DType> { - public: - explicit Plan(const Plan &src) : src_(src) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - const index_t base_x = static_cast(x / 2) * 2; - if (0 == x % 2) { - return OP::RealMap(src_, y, base_x, y, base_x + 1); - } else { - return OP::ImagMap(src_, y, base_x, y, base_x + 1); - } - } - - private: - Plan src_; -}; - -// complex unitary expression (r2c) -template -class Plan, DType> { - public: - explicit Plan(const Plan &src) : src_(src) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - const index_t real_x = static_cast(x / 2); - if (0 == x%2) { - // x,y should be coordinates in the complex matrix - // this defines how we will give value to the real part from the real matrix src_, - // thus the index has only 2 dimensions - return OP::RealMap(src_, y, real_x); - } else { - return OP::ImagMap(src_, y, real_x); - } - } - - private: - Plan src_; -}; - -// complex unitary expression (c2r) -template -class Plan, DType> { - public: - explicit Plan(const Plan &src) : src_(src) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return OP::RealMap(src_, y, x * 2, y, x * 2 + 1); - } - - private: - Plan src_; -}; - - - -template -inline Plan, DType> -MakePlan(const ComplexBinaryMapExp &e) { - return Plan, - DType>(MakePlan(e.lhs_), MakePlan(e.rhs_)); -} - -template -inline Plan, DType> -MakePlan(const ComplexUnitaryExp &e) { - return Plan, - DType>(MakePlan(e.src_)); -} - - - -template -struct ExpInfo > { - static const int kDimLhs = ExpInfo::kDim; - static const int kDimRhs = ExpInfo::kDim; - static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ? \ - (kDimLhs == 0 ? \ - kDimRhs : \ - ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1; - static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; -}; - -template -struct ExpInfo > { - static const int kDim = ExpInfo::kDim; - static const int kDevMask = ExpInfo::kDevMask; -}; - -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_COMPLEX_H_ diff --git a/include/mshadow/extension/concat.h b/include/mshadow/extension/concat.h deleted file mode 100644 index c51b1dcb0a26..000000000000 --- a/include/mshadow/extension/concat.h +++ /dev/null @@ -1,194 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file concat.h - * \brief support for concatenation - */ -#ifndef MSHADOW_EXTENSION_CONCAT_H_ -#define MSHADOW_EXTENSION_CONCAT_H_ - -#include "../extension.h" - -namespace mshadow { -namespace expr { -/*! - * \brief concat expression, concat two tensor's channel - * \tparam LhsExp left expression - * \tparam RhsExp right expression - * \tparam DType the type of elements - * \tparam srcdim dimension of src - * \tparam dimsrc_m_cat dimsrc - dimcat - */ -template -struct ConcatExp : public TRValue, - Device, srcdim, DType> { - static const int dimcat = srcdim - dimsrc_m_cat; - const LhsExp &src1_; - const RhsExp &src2_; - index_t dcat_src1_; - index_t dcat_src2_; - Shape<4> shape_; - ConcatExp(const LhsExp &src1, const RhsExp &src2) : src1_(src1), src2_(src2) { - Shape sshape1 = ShapeCheck::Check(src1_); - Shape sshape2 = ShapeCheck::Check(src2_); - #pragma unroll - for (int i = 0; i < srcdim; ++i) { - if (i != dimcat) { - CHECK_EQ(sshape1[i], sshape2[i]) << "ConcatExp: shape mismatch"; - } - } - this->shape_ = sshape1; - this->shape_[dimcat] = sshape1[dimcat] + sshape2[dimcat]; - this->dcat_src1_ = sshape1[dimcat]; - this->dcat_src2_ = sshape2[dimcat]; - } - template - inline void - operator=(const expr::Exp &exp) { - this->__assign(exp); - } - inline void - operator=(const DType &exp) { - this->__assign(exp); - } -}; // struct ConcatExp -/*! - * \brief concat two 4D tensor - * \param src1 source tensor1 - * \param src2 source tensor2 - * \return concated 4D tensor - * \tparam cdim the dimension to concatnate on - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam etype type of expression - */ -template -inline ConcatExp -concat(const TRValue &src1, - const TRValue &src2) { - TypeCheckPass::kDim == ExpInfo::kDim> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - TypeCheckPass::kDim == srcdim> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return ConcatExp - (src1.self(), src2.self()); -} -//------------------------ -// engine plugin -//------------------------ -// runtime shapecheck -template -struct ShapeCheck >{ - inline static Shape Check(const ConcatExp &t) { - return t.shape_; - } -}; -template -struct StreamInfo >{ - inline static Stream * - Get(const ConcatExp &t) { - Stream *lhs = StreamInfo::Get(t.src1_); - Stream *rhs = StreamInfo::Get(t.src2_); - if (lhs != rhs) return NULL; - return lhs; - } -}; -// static typecheck -template -struct ExpInfo >{ - static const int kDimLhs = ExpInfo::kDim; - static const int kDimRhs = ExpInfo::kDim; - // copy from binarymap - static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ?\ - (kDimLhs == 0 ?\ - kDimRhs :\ - ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1; - static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; -}; -//---------------------- -// Execution plan -//--------------------- -template -struct Plan, DType> { - public: - static const int dimcat = srcdim - dimsrc_m_cat; - explicit Plan(const ConcatExp &e) - : src1_(MakePlan(e.src1_)), src2_(MakePlan(e.src2_)), - height_(e.shape_.ProdShape(dimcat + 1, srcdim - 1)), - ch_src1_(e.dcat_src1_), ch_src2_(e.dcat_src2_), ch_(e.shape_[dimcat]) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - const index_t y = i % height_; - i /= height_; - const index_t c = i % ch_; - const index_t b = i / ch_; - const index_t x = j; - if (c < ch_src1_) { - return src1_.Eval((b * ch_src1_ + c) * height_ + y, x); - } else { - return src2_.Eval((b * ch_src2_ + c - ch_src1_) * height_ + y, x); - } - } - MSHADOW_XINLINE DType &REval(index_t i, index_t j) { - const index_t y = i % height_; - i /= height_; - const index_t c = i % ch_; - const index_t b = i / ch_; - const index_t x = j; - if (c < ch_src1_) { - return src1_.REval((b * ch_src1_ + c) * height_ + y, x); - } else { - return src2_.REval((b * ch_src2_ + c - ch_src1_) * height_ + y, x); - } - } - - private: - Plan src1_; - Plan src2_; - const index_t height_, ch_src1_, ch_src2_, ch_; -}; // struct Plan - -// specialize for concat in x -template -struct Plan, DType> { - public: - explicit Plan(const ConcatExp &e) - : src1_(MakePlan(e.src1_)), src2_(MakePlan(e.src2_)), - width_src1_(e.dcat_src1_) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - if (x < width_src1_) { - return src1_.Eval(y, x); - } else { - return src2_.Eval(y, x - width_src1_); - } - } - MSHADOW_XINLINE DType &REval(index_t y, index_t x) { - if (x < width_src1_) { - return src1_.REval(y, x); - } else { - return src2_.REval(y, x - width_src1_); - } - } - - private: - Plan src1_; - Plan src2_; - const index_t width_src1_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_CONCAT_H_ diff --git a/include/mshadow/extension/crop.h b/include/mshadow/extension/crop.h deleted file mode 100644 index 80096a2d22d3..000000000000 --- a/include/mshadow/extension/crop.h +++ /dev/null @@ -1,119 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file crop.h - * \brief support for crop - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_CROP_H_ -#define MSHADOW_EXTENSION_CROP_H_ -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief crop expression, cut off the boundary region, reverse operation of padding - * \tparam SrcExp source expression to be pooled from - * \tparam DType the type of elements - * \tparam srcdim dimension of src - */ -template -struct CroppingExp: - public MakeTensorExp, - SrcExp, srcdim, DType> { - /*! \brief source operand */ - const SrcExp &src_; - /*! \brief pad height */ - index_t pad_height_; - /*! \brief pad height */ - index_t pad_width_; - /*! \brief src height */ - index_t src_height_; - /*! \brief constructor */ - explicit CroppingExp(const SrcExp &src, Shape<2> cshape) - : src_(src) { - this->shape_ = ShapeCheck::Check(src_); - CHECK_GE(this->shape_[srcdim - 2], cshape[0]) << "CroppingExp: height requirement not met"; - CHECK_GE(this->shape_[srcdim - 1], cshape[1]) << "CroppingExp: width requirement not met"; - pad_height_ = (this->shape_[srcdim - 2] - cshape[0]) / 2; - pad_width_ = (this->shape_[srcdim - 1] - cshape[1]) / 2; - src_height_ = this->shape_[srcdim - 2]; - this->shape_[srcdim - 2] = cshape[0]; // height - this->shape_[srcdim - 1] = cshape[1]; // width - } - /*! \brief constructor */ - explicit CroppingExp(const SrcExp &src, Shape<2> cshape, - index_t start_height, index_t start_width) - : src_(src), pad_height_(start_height), pad_width_(start_width) { - this->shape_ = ShapeCheck::Check(src_); - CHECK_GE(this->shape_[srcdim - 2], cshape[0] + start_height) - << "CroppingExp: height requirement not met"; - CHECK_GE(this->shape_[srcdim - 1], cshape[1] + start_width) - << "CroppingExp: width requirement not met"; - src_height_ = this->shape_[srcdim - 2]; - this->shape_[srcdim - 2] = cshape[0]; // height - this->shape_[srcdim - 1] = cshape[1]; // width - } -}; // struct CroppingExp -/*! - * \brief revserse operationg of padding, cut off boundaries, - * crop output from center of input - * \param src original image batches - * \param oshape output shape to be cropped - * \return expression corresponding to padded result - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam etype type of expression - */ -template -inline CroppingExp::kDim> -crop(const Exp &src, Shape<2> oshape) { - TypeCheckPass::kDim >= 2> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return CroppingExp::kDim>(src.self(), oshape); -} -/*! - * \brief same as crop, but can specify starting position to do cropping - * \param src original image batches - * \param oshape output shape to be cropped - * \param start_height start height position to do cropping - * \param start_width start width position to do cropping - * \return expression corresponding to padded result - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam etype type of expression - */ -template -inline CroppingExp::kDim> -crop(const Exp &src, Shape<2> oshape, - index_t start_height, index_t start_width) { - TypeCheckPass::kDim >= 2> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return CroppingExp::kDim> - (src.self(), oshape, start_height, start_width); -} -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const CroppingExp &e) - : src_(MakePlan(e.src_)), - pad_height_(e.pad_height_), pad_width_(e.pad_width_), - new_height_(e.shape_[srcdim - 2]), src_height_(e.src_height_) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - const index_t x = j; - const index_t y = i % new_height_; - const index_t c = i / new_height_; - const index_t h = y + pad_height_; - const index_t w = x + pad_width_; - return src_.Eval(c * src_height_ + h, w); - } - private: - Plan src_; - const index_t pad_height_, pad_width_; - const index_t new_height_; - const index_t src_height_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_CROP_H_ diff --git a/include/mshadow/extension/fill.h b/include/mshadow/extension/fill.h deleted file mode 100644 index 4ac62c1673e5..000000000000 --- a/include/mshadow/extension/fill.h +++ /dev/null @@ -1,103 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file fill.h - * \brief support for implicit array filling operation - * \author Xingjian Shi - */ -#ifndef MSHADOW_EXTENSION_FILL_H_ -#define MSHADOW_EXTENSION_FILL_H_ - -#include "../extension.h" - - -namespace mshadow { -namespace expr { -/*! - * \brief Set value of a specific element in each line of the data matrix. - * \tparam SrcExp type of src expression - * \tparam ValExp type of val expression - * \tparam IndexExp type of index expression - * \tparam DType the type of ret expression - */ -template -struct MatFillRowElementExp: - public Exp, - DType, type::kChainer> { - /*! \brief src operand */ - const SrcExp &src_; - const ValExp &val_; - /*! \brief index operand */ - const IndexExp &index_; - /*! \brief constructor */ - MatFillRowElementExp(const SrcExp &src, const ValExp &val, const IndexExp &index) - : src_(src), val_(val), index_(index) {} -}; - -template -inline MatFillRowElementExp -mat_fill_row_element(const Exp &src, - const Exp &val, - const Exp &index) { - TypeCheckPass::kDim == 2 && ExpInfo::kDim == 1 - && ExpInfo::kDim == 1>::Error_Expression_Does_Not_Meet_Dimension_Req(); - return MatFillRowElementExp(src.self(), - val.self(), index.self()); -} - -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const MatFillRowElementExp &e) - : src_(MakePlan(e.src_)), - val_(MakePlan(e.val_)), - index_(MakePlan(e.index_)) { - } - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - index_t idx = static_cast(index_.Eval(0, y)); - if (idx == x) { - return static_cast(val_.Eval(0, y)); - } else { - return static_cast(src_.Eval(y, x)); - } - } - - private: - expr::Plan src_; - expr::Plan val_; - expr::Plan index_; -}; - -template -inline Plan, DType> -MakePlan(const MatFillRowElementExp &exp) { - return Plan, DType>(exp); -} - -template -struct ShapeCheck > { - inline static Shape - Check(const MatFillRowElementExp &t) { - CHECK(dim == 2) - << "MatFillRowElementExp only support 2 dimension output"; - Shape<2> shape_src = ShapeCheck<2, SrcExp>::Check(t.src_); - Shape<1> shape_val = ShapeCheck<1, ValExp>::Check(t.val_); - Shape<1> shape_index = ShapeCheck<1, IndexExp>::Check(t.index_); - CHECK((shape_src[0] == shape_index[0]) && (shape_index[0] == shape_val[0])) - << "mat_fill_row_element index length, val length and number of rows in matrix"; - return shape_src; - } -}; - -template -struct ExpInfo > { - static const int kDim = 2; - static const int kDevMask = - ExpInfo::kDevMask & ExpInfo::kDevMask & ExpInfo::kDevMask; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_FILL_H_ diff --git a/include/mshadow/extension/flip.h b/include/mshadow/extension/flip.h deleted file mode 100644 index 17d1894530fc..000000000000 --- a/include/mshadow/extension/flip.h +++ /dev/null @@ -1,132 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file flip.h - * \brief support for flip a certain dimension. - * \author Junyuan Xie - */ -#ifndef MSHADOW_EXTENSION_FLIP_H_ -#define MSHADOW_EXTENSION_FLIP_H_ - -#include "../extension.h" - -namespace mshadow { -namespace expr { -/*! - * \brief slice expression, slice a tensor's channel - * \tparam SrcExp left expression - * \tparam DType the type of elements - * \tparam srcdim dimension of src - * \tparam dimsrc_m_cat dimsrc - dimcat - */ -template -struct FlipExp : public TRValue, - Device, srcdim, DType> { - const SrcExp &src_; - index_t trailing_; - index_t stride_; - index_t stride_j_; - Shape shape_; - FlipExp(const SrcExp &src, int dim) - : src_(src) { - shape_ = ShapeCheck::Check(src_); - stride_ = shape_[dim]; - stride_j_ = shape_[srcdim-1]; - trailing_ = 1; - for (int i = dim + 1; i < srcdim; ++i) { - trailing_ *= shape_[i]; - } - } - template - inline void - operator=(const expr::Exp &exp) { - this->__assign(exp); - } - inline void - operator=(const DType &exp) { - this->__assign(exp); - } -}; // struct Flip - -/*! - * \brief Flip a Tensor - * \param src source tensor - * \param begin The beginning slice. - * \param end The end slice. - * \return sliced tensor - * \tparam sdim the dimension to slice on - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam etype type of expression - */ -template -inline FlipExp -flip(const TRValue &src, int dim) { - return FlipExp(src.self(), dim); -} -//------------------------ -// engine plugin -//------------------------ -// runtime shapecheck -template -struct ShapeCheck >{ - inline static Shape Check(const FlipExp &t) { - return t.shape_; - } -}; -template -struct StreamInfo >{ - inline static Stream * - Get(const FlipExp &t) { - return StreamInfo::Get(t.src_); - } -}; -// static typecheck -template -struct ExpInfo >{ - static const int kDim = ExpInfo::kDim; - static const int kDevMask = ExpInfo::kDevMask; -}; -//---------------------- -// Execution plan -//--------------------- -template -struct Plan, DType> { - public: - explicit Plan(const FlipExp &e) - : src_(MakePlan(e.src_)), stride_j_(e.stride_j_), - trailing_(e.trailing_), stride_(e.stride_) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - index_t idx = i*stride_j_+j; - const index_t low = idx%trailing_; - index_t high = idx/trailing_; - const index_t x = high%stride_; - high /= stride_; - idx = (high*stride_+stride_-1-x)*trailing_+low; - return src_.Eval(idx/stride_j_, idx%stride_j_); - } - MSHADOW_XINLINE DType &REval(index_t i, index_t j) const { - index_t idx = i*stride_j_+j; - const index_t low = idx%trailing_; - index_t high = idx/trailing_; - const index_t x = high%stride_; - high /= stride_; - idx = (high*stride_+stride_-1-x)*trailing_+low; - return src_.REval(idx/stride_j_, idx%stride_j_); - } - - private: - Plan src_; - const index_t stride_j_, trailing_, stride_; -}; // struct Plan -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_FLIP_H_ diff --git a/include/mshadow/extension/implicit_gemm.h b/include/mshadow/extension/implicit_gemm.h deleted file mode 100644 index b4b88ea326c8..000000000000 --- a/include/mshadow/extension/implicit_gemm.h +++ /dev/null @@ -1,128 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file implicit_gemm.h - * \brief support for implicit GEMM operation - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ -#define MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ - -#include "../extension.h" -#include "../packet-inl.h" - -namespace mshadow { -namespace expr { -/*! - * \brief Matrix multiplication. - * \tparam LhsExp type of lhs expression - * \tparam LhsExp type of rhs expression - * \tparam DType the type of elements - */ -template -struct ImplicitGEMMExp: - public Exp, - DType, type::kChainer> { - /*! \brief lhs operand */ - const LhsExp &lhs_; - /*! \brief rhs operand */ - const RhsExp &rhs_; - /*! \brief internal production size*/ - index_t prod_size_; - /*! \brief the shape of this expression */ - Shape<2> shape_; - /*! \brief constructor */ - ImplicitGEMMExp(const LhsExp &lhs, const RhsExp &rhs) - : lhs_(lhs), rhs_(rhs) { - Shape<2> slhs = ShapeCheck<2, LhsExp>::Check(lhs_); - Shape<2> srhs = ShapeCheck<2, RhsExp>::Check(rhs_); - this->shape_ = mshadow::Shape2(slhs[0], srhs[1]); - prod_size_ = slhs[1]; - } -}; - - -template -inline ImplicitGEMMExp -implicit_dot(const Exp &lhs, - const Exp &rhs) { - TypeCheckPass::kDim == 2 && ExpInfo::kDim == 2> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return ImplicitGEMMExp(lhs.self(), rhs.self()); -} - -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const ImplicitGEMMExp &e) - : lhs_(MakePlan(e.lhs_)), - rhs_(MakePlan(e.rhs_)), - prod_size_(e.prod_size_), - prod_size_lower_align_(packet::LowerAlign(e.prod_size_)) { - } - - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - typedef packet::Packet Packet; - Packet sum = Packet::Fill(0); - - const size_t packetSize = Packet::size; - DType lhs_temp[packetSize], rhs_temp[packetSize]; - - for (index_t i = 0; i < prod_size_lower_align_; i += packetSize) { - // unroll - for (index_t j = 0; j < packetSize; ++j) { - lhs_temp[j] = lhs_.Eval(y, i + j); - } - for (index_t j = 0; j < packetSize; ++j) { - rhs_temp[j] = rhs_.Eval(i + j, x); - } - sum = sum + Packet::LoadUnAligned(lhs_temp) * Packet::LoadUnAligned(rhs_temp); - } - DType ret_result = sum.Sum(); - - for (index_t i = prod_size_lower_align_; i < prod_size_; ++i) { - ret_result += lhs_.Eval(y, i) * rhs_.Eval(i, x); - } - return ret_result; - } - - private: - expr::Plan lhs_; - expr::Plan rhs_; - const index_t prod_size_; - const index_t prod_size_lower_align_; -}; - -template -inline Plan, DType> -MakePlan(const ImplicitGEMMExp &exp) { - return Plan, DType>(exp); -} - - -template -struct ShapeCheck > { - inline static Shape - Check(const ImplicitGEMMExp &t) { - CHECK(dim == 2) - << "ImplicitGEMMExp only support 2 dimension"; - Shape shape1 = ShapeCheck::Check(t.lhs_); - Shape shape2 = ShapeCheck::Check(t.rhs_); - CHECK_EQ(shape1[1], shape2[0]) - << "implicit_dot The matrix shape do not match"; - return t.shape_; - } -}; - -template -struct ExpInfo > { - static const int kDim = 2; - static const int kDevMask = ExpInfo::kDevMask & ExpInfo::kDevMask; -}; - -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_IMPLICIT_GEMM_H_ - diff --git a/include/mshadow/extension/mask.h b/include/mshadow/extension/mask.h deleted file mode 100644 index 0fd4cc6db72e..000000000000 --- a/include/mshadow/extension/mask.h +++ /dev/null @@ -1,97 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file mask.h - * \brief - * \author Bing Xu -*/ -#ifndef MSHADOW_EXTENSION_MASK_H_ -#define MSHADOW_EXTENSION_MASK_H_ - -#include "../extension.h" - -namespace mshadow { -namespace expr { - -/*! \brief Broadcast a mask and do element-wise multiplication - * \tparam IndexExp type of index expression - * \tparam SrcExp type of src expression - * \tparam DType data type - */ -template -struct MaskExp: public Exp, - DType, type::kChainer> { - /*! \brief index oprand */ - const IndexExp &index_; - /*! \brief matrix oprand */ - const SrcExp &src_; - /*! constructor */ - MaskExp(const IndexExp &index, const SrcExp &src) - : index_(index), src_(src) {} -}; // struct MaskExp - - - -template -inline MaskExp -mask(const Exp &index, - const Exp &src) { - return MaskExp(index.self(), src.self()); -} - - -//---------------------- -// Execution plan -//---------------------- - -template -struct Plan, DType> { - public: - explicit Plan(const MaskExp &e) - : index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) { - } - - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return static_cast(src_.Eval(y, x) * index_.Eval(0, y)); - } - - private: - expr::Plan index_; - expr::Plan src_; -}; // struct Plan - -template -inline Plan, DType> -MakePlan(const MaskExp &exp) { - return Plan, DType>(exp); -} - -template -struct ShapeCheck > { - inline static Shape - Check(const MaskExp &t) { - CHECK(dim == 2) - << "MaskExp only support 2D output"; - Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_); - Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_); - CHECK_EQ(dshape[0], wshape[0]) << "MaskExp require inputs in same first dimention"; - Shape ret; - ret[0] = wshape[0]; - ret[1] = wshape[1]; - return ret; - } -}; - - -template -struct ExpInfo > { - static const int kDim = 2; - static const int kDevMask = ExpInfo::kDevMask; -}; - -} // namespace expr -} // namespace mshadow - -#endif // MSHADOW_EXTENSION_MASK_H_ diff --git a/include/mshadow/extension/mirror.h b/include/mshadow/extension/mirror.h deleted file mode 100644 index 9e9edc9b6f70..000000000000 --- a/include/mshadow/extension/mirror.h +++ /dev/null @@ -1,62 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file mirror.h - * \brief support for mirror - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_MIRROR_H_ -#define MSHADOW_EXTENSION_MIRROR_H_ -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief mirror expression, mirror a image in width - * \tparam SrcExp source expression to be mirrored - * \tparam DType the type of elements - * \tparam srcdim dimension of src - */ -template -struct MirroringExp: - public MakeTensorExp, - SrcExp, srcdim, DType> { - /*! \brief source operand */ - const SrcExp &src_; - /*! \brief constructor */ - explicit MirroringExp(const SrcExp &src) : src_(src) { - this->shape_ = ShapeCheck::Check(src_); - } -}; -/*! - * \brief mirroring expression, mirror images in width - * \param src original image batches - * \return expression corresponding to mirrored result - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam etype type of expression - */ -template -inline MirroringExp::kDim> -mirror(const Exp &src) { - TypeCheckPass::kDim >= 2> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return MirroringExp::kDim>(src.self()); -} -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const MirroringExp &e) - : src_(MakePlan(e.src_)), width_(e.shape_[srcdim - 1]) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - return src_.Eval(i, width_ - j - 1); - } - - private: - Plan src_; - const index_t width_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_MIRROR_H_ diff --git a/include/mshadow/extension/one_hot.h b/include/mshadow/extension/one_hot.h deleted file mode 100644 index 326d4c3560eb..000000000000 --- a/include/mshadow/extension/one_hot.h +++ /dev/null @@ -1,87 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file one_hot.h - * \brief Create one-hot indicator array based on the index. - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_ONE_HOT_H_ -#define MSHADOW_EXTENSION_ONE_HOT_H_ - -#include "../extension.h" - - -namespace mshadow { -namespace expr { -/*! - * \brief Create a one-hot indicator array. - * \tparam IndexExp type of index expression - * \tparam DType the type of elements - */ -template -struct OneHotEncodeExp: - public Exp, - DType, type::kChainer> { - /*! \brief index operand */ - const IndexExp &index_; - /*! \brief number of choices we can have. */ - index_t num_choices_; - /*! \brief constructor */ - OneHotEncodeExp(const IndexExp &index, index_t num_choices) - : index_(index), num_choices_(num_choices) {} -}; - -template -inline OneHotEncodeExp -one_hot_encode(const Exp &index, index_t num_choices) { - TypeCheckPass::kDim == 1> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return OneHotEncodeExp(index.self(), num_choices); -} - -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const OneHotEncodeExp &e) - : index_(MakePlan(e.index_)) { - } - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - index_t idx = static_cast(index_.Eval(0, y)); - return static_cast(x == idx); - } - - private: - expr::Plan index_; -}; - -template -inline Plan, DType> -MakePlan(const OneHotEncodeExp &exp) { - return Plan, DType>(exp); -} - -template -struct ShapeCheck > { - inline static Shape - Check(const OneHotEncodeExp &t) { - CHECK(dim == 2) - << "OneHotEncodeExp only support 2 dimension output"; - Shape<1> shape = ShapeCheck<1, IndexExp>::Check(t.index_); - Shape ret; - ret[0] = shape[0]; - ret[1] = t.num_choices_; - return ret; - } -}; - -template -struct ExpInfo > { - static const int kDim = 2; - static const int kDevMask = ExpInfo::kDevMask; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_ONE_HOT_H_ diff --git a/include/mshadow/extension/pack_col2patch.h b/include/mshadow/extension/pack_col2patch.h deleted file mode 100644 index 37f1a699ead5..000000000000 --- a/include/mshadow/extension/pack_col2patch.h +++ /dev/null @@ -1,154 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file pack_col2patch.h - * \brief support for pack - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_PACK_COL2PATCH_H_ -#define MSHADOW_EXTENSION_PACK_COL2PATCH_H_ -#include -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief reverse operation of UnpackPatchToCol, - * used to backprop gradient back - * this is a version supporting multiple images - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam dstdim destination dimension - */ -template -struct PackColToPatchXExp: - public MakeTensorExp, - SrcExp, dstdim, DType> { - /*! \brief source operand */ - const SrcExp &src_; - /*! \brief patch height */ - index_t psize_y_; - /*! \brief patch height */ - index_t psize_x_; - /*! \brief patch stride */ - index_t pstride_y_; - index_t pstride_x_; - /*! \brief patch dilate */ - index_t pdilate_y_; - index_t pdilate_x_; - /*! \brief constructor */ - PackColToPatchXExp(const SrcExp &src, Shape imshape, - index_t psize_y, index_t psize_x, - index_t pstride_y, index_t pstride_x, - index_t pdilate_y, index_t pdilate_x) - :src_(src), psize_y_(psize_y), psize_x_(psize_x), - pstride_y_(pstride_y), pstride_x_(pstride_x), - pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){ - this->shape_ = imshape; - const index_t o_height = (imshape[dstdim - 2] - - (pdilate_y * (psize_y - 1)+ 1))/pstride_y + 1; - const index_t o_width = (imshape[dstdim - 1] - - (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1; - Shape<2> sshape = ShapeCheck<2, SrcExp>::Check(src_); - CHECK_EQ(sshape[1], o_height * o_width * imshape.ProdShape(0, dstdim - 3)) - << "PackColToPatchExp: src.size(1) mismatch"; - CHECK_EQ(sshape[0], psize_y * psize_x * imshape[dstdim - 3]) - << "PackColToPatchExp: src.size(0) mismatch"; - } -}; -/*! - * \brief reverse operation of pack_col2patch, can be used to implement deconvolution - * \return packed img expression - * \param mat source matrix - * \param imshape shape of target img - * \param psize_y height of each patch - * \param psize_x height of each patch - * \param pstride stride of each patch - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam dstdim destination dimension - * \tparam etype type of expression - */ -template -inline PackColToPatchXExp -pack_col2patch(const expr::Exp &src, - Shape imshape, index_t psize_y, - index_t psize_x, index_t pstride, index_t pdilate) { - TypeCheckPass::kDim == 2> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y) - << "PackColToPatch:image shape smaller than patch size"; - return PackColToPatchXExp(src.self(), imshape, - psize_y, psize_x, pstride, pstride, - pdilate, pdilate); -} -/*! - *if you want to specify kstride_y and kstride_x - */ -template -inline PackColToPatchXExp -pack_col2patch(const expr::Exp &src, - Shape imshape, index_t psize_y, - index_t psize_x, index_t pstride_y, index_t pstride_x, - index_t pdilate_y, index_t pdilate_x) { - TypeCheckPass::kDim == 2> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y) - << "PackColToPatch:image shape smaller than patch size"; - return PackColToPatchXExp(src.self(), imshape, - psize_y, psize_x, pstride_y, pstride_x, - pdilate_y, pdilate_x); -} - -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const PackColToPatchXExp &e) - :src_(MakePlan(e.src_)), psize_y_(e.psize_y_), - psize_x_(e.psize_x_), pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_), - i_channel_(e.shape_[dstdim - 3]), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_), - i_height_(e.shape_[dstdim - 2]), - o_height_((e.shape_[dstdim - 2] - (pdilate_y_ * (psize_y_ - 1) + 1)) / - pstride_y_ + 1), - o_width_((e.shape_[dstdim - 1] - (pdilate_x_ * (psize_x_ - 1) + 1)) / - pstride_x_ + 1) { - // note: i/o convention are same as unpack - } - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - using namespace std; - const index_t y = i % i_height_; - const index_t idivh = i / i_height_; - const index_t c = idivh % i_channel_; - const index_t n = idivh / i_channel_; - const index_t x = j; - - const index_t psize_y_dilate = (pdilate_y_ * (psize_y_ - 1) + 1); - const index_t psize_x_dilate = (pdilate_x_ * (psize_x_ - 1) + 1); - - const index_t py_min = - y < psize_y_dilate ? y % pdilate_y_ : (y-psize_y_dilate + pstride_y_) / pstride_y_; - const index_t px_min = - x < psize_x_dilate ? x % pdilate_x_ : (x-psize_x_dilate + pstride_x_) / pstride_x_; - const index_t py_max = min((y + pstride_y_) / pstride_y_, o_height_); - const index_t px_max = min((x + pstride_x_) / pstride_x_, o_width_); - DType res = static_cast(0); - for (index_t py = py_min; py < py_max; py += pdilate_y_) { - for (index_t px = px_min; px < px_max; px += pdilate_x_) { - res += src_.Eval(((c * psize_y_ + (y - py*pstride_y_) / pdilate_y_) * psize_x_ + - (x - px * pstride_x_) / pdilate_x_), - (n * o_height_ + py) * o_width_ + px); - } - } - return res; - } - - private: - Plan src_; - const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_; - const index_t pdilate_y_, pdilate_x_; - const index_t i_height_, o_height_, o_width_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_PACK_COL2PATCH_H_ diff --git a/include/mshadow/extension/pad.h b/include/mshadow/extension/pad.h deleted file mode 100644 index 6622a022acc8..000000000000 --- a/include/mshadow/extension/pad.h +++ /dev/null @@ -1,111 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file pad.h - * \brief support for pad - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_PAD_H_ -#define MSHADOW_EXTENSION_PAD_H_ -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief padding expression, pad a image with zeros - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam srcdim dimension of src - */ -template -struct PaddingExp: - public MakeTensorExp, - SrcExp, srcdim, DType> { - /*! \brief source operand */ - const SrcExp &src_; - /*! \brief pad size in y */ - index_t pad_y_; - /*! \brief pad size in x */ - index_t pad_x_; - /*! \brief source tensor height */ - index_t src_height_; - /*! \brief source tensor width */ - index_t src_width_; - /*! \brief constructor */ - PaddingExp(const SrcExp &src, index_t pad_y, index_t pad_x) - : src_(src), pad_y_(pad_y), pad_x_(pad_x) { - this->shape_ = ShapeCheck::Check(src_); - src_height_ = this->shape_[srcdim - 2]; - src_width_ = this->shape_[srcdim - 1]; - this->shape_[srcdim - 2] += pad_y * 2; // height - this->shape_[srcdim - 1] += pad_x * 2; // width - } -}; -/*! - * \brief padding expression, pad a image with zeros on boundaries, padding affects shape[0], and shape[1] - * \param src original image batches - * \param pad padding size - * \return expression corresponding to padded result - * \tparam SrcExp source expression - * \tparam DType the content data type - * \tparam etype type of expression - */ -template -inline PaddingExp::kDim> -pad(const Exp &src, index_t pad) { - TypeCheckPass::kDim >= 2> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return PaddingExp::kDim>(src.self(), pad, pad); -} -/*! - * \brief padding expression, pad a image with zeros on boundaries, padding affects shape[0], and shape[1] - * \param src original image batches - * \param pad_y padding size in y - * \param pad_x padding size in x - * \return expression corresponding to padded result - * \tparam SrcExp source expression - * \tparam DType the content data type - * \tparam etype type of expression - */ -template -inline PaddingExp::kDim> -pad(const Exp &src, index_t pad_y, index_t pad_x) { - TypeCheckPass::kDim >= 2> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return PaddingExp::kDim> - (src.self(), pad_y, pad_x); -} -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const PaddingExp &e) - : src_(MakePlan(e.src_)), - pad_y_(e.pad_y_), pad_x_(e.pad_x_), - new_height_(e.shape_[srcdim - 2]), - src_height_(e.src_height_), src_width_(e.src_width_) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - const index_t x = j; - const index_t y = i % new_height_; - const index_t c = i / new_height_; - if (y < pad_y_ || x < pad_x_) return static_cast(0); - const index_t h = y - pad_y_; - const index_t w = x - pad_x_; - if (h < src_height_ && w < src_width_) { - return src_.Eval(c * src_height_ + h, w); - } else { - return static_cast(0); - } - } - - private: - Plan src_; - const index_t pad_y_; - const index_t pad_x_; - const index_t new_height_; - const index_t src_height_; - const index_t src_width_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_PAD_H_ diff --git a/include/mshadow/extension/range.h b/include/mshadow/extension/range.h deleted file mode 100644 index ab49b6e3cf18..000000000000 --- a/include/mshadow/extension/range.h +++ /dev/null @@ -1,118 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file range.h - * \brief support generating a range vector - * \author Xingjian Shi - */ -#ifndef MSHADOW_EXTENSION_RANGE_H_ -#define MSHADOW_EXTENSION_RANGE_H_ - -#include "../extension.h" - -namespace mshadow { -namespace expr { -/*! - * \brief Generate a range vector similar to python: range(start, stop[, step][, repeat]). - If step is positive, the last element is the largest start + i * step less than stop - If step is negative, the last element is the smallest start + i * step greater than stop. - All elements are repeated for `repeat` times, e.g range(0, 4, 2, 3) --> 0, 0, 0, 2, 2, 2 - * \tparam SrcExp type of lhs expression - * \tparam IndexExp type of index expression - * \tparam DType the type of elements - */ -template -struct RangeExp: - public Exp, DType, type::kMapper> { - const DType start_; - const DType stop_; - const DType step_; - const int repeat_; - /*! \brief constructor */ - RangeExp(DType start, DType stop, DType step, int repeat) - : start_(start), stop_(stop), step_(step), repeat_(repeat) {} -}; - -template -inline RangeExp -range(DType start, DType stop, DType step = 1, int repeat = 1) { - return RangeExp(start, stop, step, repeat); -} - -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const RangeExp &e) - : start_(e.start_), - stop_(e.stop_), - step_(e.step_), - repeat_(e.repeat_) { - } - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return start_ + static_cast((static_cast(x) / repeat_)) * step_; - } - - private: - const DType start_; - const DType stop_; - const DType step_; - const int repeat_; -}; - -template -inline Plan, DType> -MakePlan(const RangeExp &exp) { - return Plan, DType>(exp); -} - - -template -inline int RangeOutSize(DType start, DType stop, DType step, int repeat) { - return repeat * ((stop - start - 1) / step + 1); -} - -template<> -inline int RangeOutSize(float start, float stop, float step, int repeat) { - double d_start = static_cast(start); - double d_stop = static_cast(stop); - double d_step = static_cast(step); - return repeat * static_cast(ceil((d_stop - d_start) / d_step)); -} - -template<> -inline int RangeOutSize(double start, double stop, double step, int repeat) { - return repeat * static_cast(ceil((stop - start) / step)); -} - - -template -struct ShapeCheck > { - inline static Shape - Check(const RangeExp &t) { - CHECK(dim == 1) - << "RangeExp only support 1 dimension output, received " << dim; - CHECK(t.step_ != 0) - << "RangeExp does not support step=0, received " << t.step_; - CHECK(t.repeat_ > 0) - << "RangeExp only supports repeat > 0, received " << t.repeat_; - if (t.step_ > 0) { - CHECK(t.start_ < t.stop_) << "RangeExp does not support (start, stop, step) = " - << "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")"; - } else { - CHECK(t.start_ > t.stop_) << "RangeExp does not support (start, stop, step)= " - << "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")"; - } - return Shape1(RangeOutSize(t.start_, t.stop_, t.step_, t.repeat_)); - } -}; - -template -struct ExpInfo > { - static const int kDim = 1; - static const int kDevMask = 0xffff; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_RANGE_H_ diff --git a/include/mshadow/extension/reduce_with_axis.h b/include/mshadow/extension/reduce_with_axis.h deleted file mode 100644 index 54bcc750cfc5..000000000000 --- a/include/mshadow/extension/reduce_with_axis.h +++ /dev/null @@ -1,136 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file reduce_with_axis.h - * \brief - * \author Junyuan Xie -*/ -#ifndef MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ -#define MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ - -#include "../extension.h" - -namespace mshadow { -namespace expr { - -/*! \brief reduce out the dimension of src labeled by axis. - * \tparam Reducer type of reducer - * \tparam SrcExp type of source expression - * \tparam DType data type - */ -template -struct ReduceWithAxisExp: - public MakeTensorExp, - SrcExp, dimdst, DType> { - /*! \brief source oprand */ - const SrcExp &src_; - /*! \brief size of last destination dimension */ - index_t last_dst_dim_; - /*! \brief size of trailing dimensions */ - index_t trailing_; - /*! \brief size of axis dimension */ - index_t size_; - /*! \brief size of last src dimension */ - index_t last_; - /*! constructor */ - explicit ReduceWithAxisExp(const SrcExp &src, int axis) - : src_(src) { - bool keepdim = (dimsrc == dimdst); - CHECK(dimsrc > axis) << "reduce axis out of bound"; - Shape src_shape = ShapeCheck::Check(src_); - for (int i = 0; i < axis; ++i) { - this->shape_[i] = src_shape[i]; - } - this->size_ = src_shape[axis]; - this->trailing_ = 1; - if (!keepdim) { - for (int i = axis + 1; i < dimsrc; ++i) { - this->trailing_ *= src_shape[i]; - this->shape_[i - 1] = src_shape[i]; - } - } else { - this->shape_[axis] = 1; - for (index_t i = axis + 1; i < dimsrc; ++i) { - this->trailing_ *= src_shape[i]; - this->shape_[i] = src_shape[i]; - } - } - - this->last_ = src_shape[dimsrc - 1]; - this->last_dst_dim_ = this->shape_[dimdst - 1]; - } -}; // struct ReduceWithAxisExp - -/*! - * \brief reduce out the dimension of src labeled by axis. - * \param Reducer type of the reducing operation - * \param mask whether to output the unmask indices - * \tparam SrcExp source expression - * \tparam DType data type - * \tparam etype type of the expression - */ -template -inline ReduceWithAxisExp::kDim, mask, - ExpInfo::kDim - 1> -reduce_with_axis(const Exp &src, int axis) { - return ReduceWithAxisExp::kDim, mask, - ExpInfo::kDim- 1>(src.self(), axis); -} - -/*! -* \brief reduce out the dimension of src labeled by axis, keepdim turned on. -* \param Reducer type of the reducing operation -* \param mask whether to output the unmask indices -* \tparam SrcExp source expression -* \tparam DType data type -* \tparam etype type of the expression -*/ -template -inline ReduceWithAxisExp::kDim, mask, - ExpInfo::kDim> - reduce_keepdim(const Exp &src, int axis) { - return ReduceWithAxisExp::kDim, mask, - ExpInfo::kDim>(src.self(), axis); -} - -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const ReduceWithAxisExp &e) - : src_(MakePlan(e.src_)), last_dst_dim_(e.last_dst_dim_), trailing_(e.trailing_), - size_(e.size_), last_(e.last_) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - index_t x = (i*last_dst_dim_ + j)/trailing_; - index_t y = (i*last_dst_dim_ + j)%trailing_; - - if (mask) { - index_t idx = 0; - DType res; Reducer::SetInitValue(res); - for (index_t k = 0; k < size_; ++k) { - index_t z = (x*size_+k)*trailing_+y; - DType tmp = res; - Reducer::Reduce(res, src_.Eval(z/last_, z%last_)); - if (tmp != res) { - idx = k; - } - } - return static_cast(static_cast(idx)); - } else { - DType res; Reducer::SetInitValue(res); - for (index_t k = 0; k < size_; ++k) { - index_t z = (x*size_+k)*trailing_+y; - Reducer::Reduce(res, src_.Eval(z/last_, z%last_)); - } - return res; - } - } - - private: - Plan src_; - const index_t last_dst_dim_, trailing_, size_, last_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ diff --git a/include/mshadow/extension/reduceto1d.h b/include/mshadow/extension/reduceto1d.h deleted file mode 100644 index 09a478ab311e..000000000000 --- a/include/mshadow/extension/reduceto1d.h +++ /dev/null @@ -1,104 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file reduceto1d.h - * \brief support for sum_rows and sumall_except_dim - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_REDUCETO1D_H_ -#define MSHADOW_EXTENSION_REDUCETO1D_H_ -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief reduction to 1 dimension tensor - * input: Tensor: ishape - * output: Tensor shape[0] = ishape[dimkeep]; - * - * \tparam SrcExp type of expression to be reduced - * \tparam DType the data type of the scalar - * \tparam Reducer which reducer to use - * \tparam m_dimkeep which dimension to be kept, encoded with dimsrc - dimkeep - */ -template -struct ReduceTo1DExp: - public Exp, - DType, type::kComplex> { - /*! \brief source operand */ - const SrcExp &src_; - /*! \brief source operand, scale of the */ - DType scale_; - /*! \brief construct a repmat expression from src and nrow */ - ReduceTo1DExp(const SrcExp& src, DType scale) : src_(src), scale_(scale) {} -}; -/*! - * \brief a sum over all dimensions, except dimkeep - * \param exp input expression that must be a matrix Tensor - * \return a expresion with type Tensor - * \tparam dimkeep the dimension that will be kept - * \tparam SrcExp expression - * \tparam etype type of expression - */ -template -inline ReduceTo1DExp::kDim - dimkeep> -sumall_except_dim(const Exp &exp) { - return ReduceTo1DExp::kDim - dimkeep>(exp.self(), DType(1)); -} -/*! - * \brief reduce over all dimensions, except dimkeep - * \param exp input expression that must be a matrix Tensor - * \return a expresion with type Tensor - * \tparam dimkeep the dimension that will be kept - * \tparam SrcExp expression - * \tparam etype type of expression - */ -template -inline ReduceTo1DExp::kDim - dimkeep> -reduce_except_dim(const Exp &exp) { - return ReduceTo1DExp::kDim - dimkeep>(exp.self(), DType(1)); -} -/*! - * \brief a expression that sum over rows of a matrix - * \param exp input expression that must be a matrix Tensor - * \return a expresion with type Tensor - * \tparam SrcExp expression - * \tparam etype type of expression - */ -template -inline ReduceTo1DExp -sum_rows(const Exp &exp) { - TypeCheckPass::kDim ==2> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return sumall_except_dim<1>(exp); -} -template -struct ExpComplexEngine, - ReduceTo1DExp, - DType> { - static const int dimkeep = ExpInfo::kDim - m_dimkeep; - inline static void Eval(Tensor *dst, - const ReduceTo1DExp &exp) { - TypeCheckPass - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - MapReduceKeepHighDim(dst, exp.src_, exp.scale_); - } -}; -template -struct ExpComplexEngine, - ReduceTo1DExp, DType> { - inline static void Eval(Tensor *dst, - const ReduceTo1DExp &exp) { - MapReduceKeepLowest(dst, exp.src_, exp.scale_); - } -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_REDUCETO1D_H_ diff --git a/include/mshadow/extension/reshape.h b/include/mshadow/extension/reshape.h deleted file mode 100644 index b310fe69291a..000000000000 --- a/include/mshadow/extension/reshape.h +++ /dev/null @@ -1,87 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file reshape.h - * \brief support for reshape - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_RESHAPE_H_ -#define MSHADOW_EXTENSION_RESHAPE_H_ -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief reshape the content to another shape - * input: Tensor: ishape - * output: Tensor ishape.Size() == oshape.Size() - * \tparam SrcExp source expression - * \tparam dimdst target dimension - * \tparam dimsrc source dimension - */ -template -struct ReshapeExp: - public MakeTensorExp, - SrcExp, dimdst, DType> { - /*! \brief source expression */ - const SrcExp &src_; - /*! \brief smallest dimension of input */ - index_t ishapex_; - /*! \brief constructor */ - ReshapeExp(const SrcExp &src, Shape shape) - : src_(src) { - Shape ishape = ShapeCheck::Check(src_); - CHECK_EQ(ishape.Size(), shape.Size()) << "reshape size must match"; - ishapex_ = ishape[dimsrc - 1]; - this->shape_ = shape; - } -}; -/*! - * \brief a expression that reshapes a tensor to another shape - * \param src Tensor: - * \param oshape target shape - * \return a expresion with type Tensor - * \tparam SrcExp source expression - * \tparam etype source expression type - * \tparam dimdst target dimension - */ -template -inline ReshapeExp::kDim> -reshape(const Exp &src, Shape oshape) { - return ReshapeExp::kDim> - (src.self(), oshape); -} -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const ReshapeExp &e) - : src_(MakePlan(e.src_)), - oshapex_(e.shape_[dimdst - 1]), ishapex_(e.ishapex_) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - const index_t idx = y * oshapex_ + x; - return src_.Eval(idx / ishapex_, idx % ishapex_); - } - - private: - Plan src_; - const index_t oshapex_, ishapex_; -}; -// special work plan for 1 dimensional data -template -struct Plan, DType> { - public: - explicit Plan(const ReshapeExp &e) - : src_(MakePlan(e.src_)), oshapex_(e.shape_[dimdst - 1]) { - } - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return src_.Eval(0, y * oshapex_ + x); - } - - private: - Plan src_; - const index_t oshapex_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_RESHAPE_H_ diff --git a/include/mshadow/extension/slice.h b/include/mshadow/extension/slice.h deleted file mode 100644 index cb2eff4548aa..000000000000 --- a/include/mshadow/extension/slice.h +++ /dev/null @@ -1,156 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file slice.h - * \brief support for slice a certain dimension. - */ -#ifndef MSHADOW_EXTENSION_SLICE_H_ -#define MSHADOW_EXTENSION_SLICE_H_ - -#include "../extension.h" - -namespace mshadow { -namespace expr { -/*! - * \brief slice expression, slice a tensor's channel - * \tparam SrcExp left expression - * \tparam DType the type of elements - * \tparam srcdim dimension of src - * \tparam dimsrc_m_cat dimsrc - dimcat - */ -template -struct SliceExp : public TRValue, - Device, srcdim, DType> { - static const int dimslice = srcdim - dimsrc_m_slice; - const SrcExp &src_; - index_t ch_begin_; - index_t ch_old_; - Shape shape_; - SliceExp(const SrcExp &src, index_t begin, index_t end) - : src_(src), ch_begin_(begin) { - shape_ = ShapeCheck::Check(src_); - ch_old_ = shape_[dimslice]; - CHECK(begin < shape_[dimslice] && end <= shape_[dimslice]) - << "The slice went out of range"; - shape_[dimslice] = end - begin; - } - template - inline void - operator=(const expr::Exp &exp) { - this->__assign(exp); - } - inline void - operator=(const DType &exp) { - this->__assign(exp); - } -}; // struct Slice - -/*! - * \brief Slice a Tensor - * \param src source tensor - * \param begin The beginning slice. - * \param end The end slice. - * \return sliced tensor - * \tparam sdim the dimension to slice on - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam etype type of expression - */ -template -inline SliceExp -slice(const TRValue &src, index_t begin, index_t end) { - TypeCheckPass::kDim == srcdim> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return SliceExp(src.self(), begin, end); -} -//------------------------ -// engine plugin -//------------------------ -// runtime shapecheck -template -struct ShapeCheck >{ - inline static Shape Check(const SliceExp &t) { - return t.shape_; - } -}; -template -struct StreamInfo >{ - inline static Stream * - Get(const SliceExp &t) { - return StreamInfo::Get(t.src_); - } -}; -// static typecheck -template -struct ExpInfo >{ - static const int kDim = ExpInfo::kDim; - static const int kDevMask = ExpInfo::kDevMask; -}; -//---------------------- -// Execution plan -//--------------------- -template -struct Plan, DType> { - public: - static const int dimslice = srcdim - dimsrc_m_slice; - explicit Plan(const SliceExp &e) - : src_(MakePlan(e.src_)), - height_(e.shape_.ProdShape(dimslice + 1, srcdim - 1)), - ch_begin_(e.ch_begin_), ch_old_(e.ch_old_), ch_(e.shape_[dimslice]) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - const index_t y = i % height_; - i /= height_; - const index_t c = i % ch_ + ch_begin_; - const index_t b = i / ch_; - const index_t x = j; - return src_.Eval((b * ch_old_ + c) * height_ + y, x); - } - MSHADOW_XINLINE DType &REval(index_t i, index_t j) { - const index_t y = i % height_; - i /= height_; - const index_t c = i % ch_ + ch_begin_; - const index_t b = i / ch_; - const index_t x = j; - return src_.REval((b * ch_old_ + c) * height_ + y, x); - } - - private: - Plan src_; - const index_t height_, ch_begin_, ch_old_, ch_; -}; // struct Plan - -template -struct Plan, DType> { - public: - explicit Plan(const SliceExp &e) - : src_(MakePlan(e.src_)), - ch_begin_(e.ch_begin_) {} - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - return src_.Eval(y, x + ch_begin_); - } - MSHADOW_XINLINE DType &REval(index_t y, index_t x) { - return src_.REval(y, x + ch_begin_); - } - - private: - Plan src_; - const index_t ch_begin_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_SLICE_H_ diff --git a/include/mshadow/extension/slice_ex.h b/include/mshadow/extension/slice_ex.h deleted file mode 100644 index 7f464097fb3b..000000000000 --- a/include/mshadow/extension/slice_ex.h +++ /dev/null @@ -1,135 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file slice.h - * \brief support for slice a certain dimension. - */ -#ifndef MSHADOW_EXTENSION_SLICE_EX_H_ -#define MSHADOW_EXTENSION_SLICE_EX_H_ - -#include "../extension.h" - -namespace mshadow { -namespace expr { -/*! - * \brief slice expression, slice a tensor's channel - * \tparam SrcExp left expression - * \tparam DType the type of elements - * \tparam srcdim dimension of src - * \tparam dimsrc_m_cat dimsrc - dimcat - */ -template -struct SliceExExp : public TRValue, - Device, srcdim, DType> { - const SrcExp &src_; - Shape src_shape_; - Shape shape_; - const Shape begin_; - const Shape end_; - SliceExExp(const SrcExp &src, Shape begin, Shape end) - : src_(src), begin_(begin), end_(end) { - src_shape_ = ShapeCheck::Check(src_); - for (int i = 0; i < srcdim; ++i) { - shape_[i] = end_[i] - begin_[i]; - } - } - template - inline void - operator=(const expr::Exp &exp) { - this->__assign(exp); - } - inline void - operator=(const DType &exp) { - this->__assign(exp); - } -}; // struct SliceEx - -/*! - * \brief SliceEx a Tensor - * \param src source tensor - * \param begin The beginning slice. - * \param end The end slice. - * \return sliced tensor - * \tparam sdim the dimension to slice on - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam etype type of expression - */ -template -inline SliceExExp -slice(const TRValue &src, Shape begin, Shape end) { - TypeCheckPass::kDim == srcdim> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return SliceExExp(src.self(), begin, end); -} -//------------------------ -// engine plugin -//------------------------ -// runtime shapecheck -template -struct ShapeCheck >{ - inline static Shape Check(const SliceExExp &t) { - return t.shape_; - } -}; - -template -struct StreamInfo >{ - inline static Stream * - Get(const SliceExExp &t) { - return StreamInfo::Get(t.src_); - } -}; -// static typecheck -template -struct ExpInfo >{ - static const int kDim = ExpInfo::kDim; - static const int kDevMask = ExpInfo::kDevMask; -}; -//---------------------- -// Execution plan -//--------------------- -template -struct Plan, DType> { - public: - explicit Plan(const SliceExExp &e) - : src_(MakePlan(e.src_)), begin_(e.begin_), - src_shape_(e.src_shape_), shape_(e.shape_) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - index_t idx = 0; - index_t stride = 1; - #pragma unroll - for (int k = srcdim-2; k >= 0; --k) { - idx += stride * (i%shape_[k] + begin_[k]); - i /= shape_[k]; - stride *= src_shape_[k]; - } - return src_.Eval(idx, j + begin_[srcdim-1]); - } - MSHADOW_XINLINE DType &REval(index_t i, index_t j) { - index_t idx = 0; - index_t stride = 1; - #pragma unroll - for (int k = srcdim-2; k >= 0; --k) { - idx += stride * (i%shape_[k] + begin_[k]); - i /= shape_[k]; - stride *= src_shape_[k]; - } - return src_.REval(idx, j + begin_[srcdim-1]); - } - - private: - Plan src_; - const Shape begin_, src_shape_, shape_; -}; // struct Plan -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_SLICE_EX_H_ diff --git a/include/mshadow/extension/spatial_pool.h b/include/mshadow/extension/spatial_pool.h deleted file mode 100644 index c833fb40ad58..000000000000 --- a/include/mshadow/extension/spatial_pool.h +++ /dev/null @@ -1,152 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file spatial_pool.h - * \brief support for spatial pooling - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_SPATIAL_POOL_H_ -#define MSHADOW_EXTENSION_SPATIAL_POOL_H_ -#include -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief pooling expression, do reduction over local patches of a image - * \tparam Reducer reduction method during pooling - * \tparam SrcExp source expression to be pooled from - * \tparam DType the content data type - * \tparam srcdim dimension of src - */ -template -struct PoolingExp: - public MakeTensorExp, - SrcExp, srcdim, DType> { - /*! \brief source operand */ - const SrcExp &src_; - /*! \brief kernel size in height */ - index_t ksize_y_; - /*! \brief kernel size in width */ - index_t ksize_x_; - /*! \brief kernel stride in y directory */ - index_t kstride_y_; - /*! \brief kernel stride in x directory */ - index_t kstride_x_; - /*! \brief source height shape[1] */ - index_t src_height_; - /*! \brief source width shape[0] */ - index_t src_width_; - /*! \brief constructor */ - PoolingExp(const SrcExp &src, - index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) - : src_(src), ksize_y_(ksize_y), ksize_x_(ksize_x), - kstride_y_(kstride_y), kstride_x_(kstride_x) { - Shape sshape = ShapeCheck::Check(src_); - CHECK(sshape[srcdim - 1] >= ksize_x && sshape[srcdim - 2] >= ksize_y) - << "PoolingExp: kernel must be smaller than image"; - this->src_height_ = sshape[srcdim - 2]; - this->src_width_ = sshape[srcdim - 1]; - this->shape_ = sshape; - this->shape_[srcdim - 2] = (src_height_ - ksize_y) / kstride_y + 1; - this->shape_[srcdim - 1] = (src_width_ - ksize_x) / kstride_x + 1; - } - /*! \brief constructor, specify shape */ - PoolingExp(const SrcExp &src, Shape<2> pshape, - index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) - : src_(src), ksize_y_(ksize_y), ksize_x_(ksize_x), - kstride_y_(kstride_y), kstride_x_(kstride_x) { - Shape sshape = ShapeCheck::Check(src_); - CHECK(sshape[srcdim - 1] >= ksize_x && sshape[srcdim - 2] >= ksize_y) - << "PoolingExp: kernel must be smaller than image"; - this->src_height_ = sshape[srcdim - 2]; - this->src_width_ = sshape[srcdim - 1]; - this->shape_ = sshape; - this->shape_[srcdim - 2] = pshape[0]; - this->shape_[srcdim - 1] = pshape[1]; - } -}; -/*! - * \brief pooling subregion results together - * \param src source image, shape: (batch, channel, height, width) - * \param ksize_y kernel size in height - * \param ksize_x kernel size in width - * \param kstride_y stride in y directory - * \param kstride_x stride in x directory - * \return expression of pooled result - * \tparam Reducer reducer type - * \tparam SrcExp source expression - * \tparam DType the content data type - * \tparam etype type of expression - */ -template -inline PoolingExp::kDim> -pool(const Exp &src, - index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) { - TypeCheckPass::kDim >= 2> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return PoolingExp::kDim> - (src.self(), ksize_y, ksize_x, kstride_y, kstride_x); -} -/*! - * \brief same as pool, except the output shape is specified by pshape - * \param src source image - * \param pshape ouput shape - * \param ksize_y kernel size in y - * \param ksize_x kernel size in x - * \param kstride_y stride in y directory - * \param kstride_x stride in x directory - * \return expression of pooled result - * \tparam Reducer reducer type - * \tparam SrcExp source expression - * \tparam DType the content data type - * \tparam etype type of expression - */ -template -inline PoolingExp::kDim> -pool(const Exp &src, Shape<2> pshape, - index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) { - TypeCheckPass::kDim >= 2> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return PoolingExp::kDim> - (src.self(), pshape, ksize_y, ksize_x, kstride_y, kstride_x); -} -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const PoolingExp &e) - : src_(MakePlan(e.src_)), - ksize_y_(e.ksize_y_), ksize_x_(e.ksize_x_), - kstride_y_(e.kstride_y_), kstride_x_(e.kstride_x_), - src_height_(e.src_height_), src_width_(e.src_width_), - new_height_(e.shape_[srcdim - 2]) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - using namespace std; - const index_t py = i % new_height_; - const index_t y_start = py * kstride_y_; - const index_t y_end = min(y_start + ksize_y_, src_height_); - const index_t px = j; - const index_t x_start = px * kstride_x_; - const index_t x_end = min(x_start + ksize_x_, src_width_); - const index_t c = i / new_height_; - - DType res; Reducer::SetInitValue(res); - for (index_t y = y_start; y < y_end; ++y) { - for (index_t x = x_start; x < x_end; ++x) { - Reducer::Reduce(res, src_.Eval(c * src_height_ + y, x)); - } - } - return res; - } - - private: - Plan src_; - const index_t ksize_y_, ksize_x_, kstride_y_, kstride_x_; - const index_t src_height_, src_width_; - const index_t new_height_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_SPATIAL_POOL_H_ diff --git a/include/mshadow/extension/spatial_unpool.h b/include/mshadow/extension/spatial_unpool.h deleted file mode 100644 index e9ca2dfd035b..000000000000 --- a/include/mshadow/extension/spatial_unpool.h +++ /dev/null @@ -1,135 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file spatial_unpool.h - * \brief support for unpool - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_ -#define MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_ -#include -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief unpooling expr reverse operation of pooling, used to pass gradient back - * \tparam Reducer reduction method during pooling - * \tparam SrcExp source expression to be pooled from - * \tparam DType the content data type - * \tparam srcdim dimension of src - */ -template -struct UnPoolingExp: - public MakeTensorExp, - SrcExp, srcdim, DType> { - /*! \brief source input, corresponds to src in pooling */ - const SrcExp &data_src_; - /*! \brief result of pooled data, corresponds to result of pooling */ - const SrcExp &data_pooled_; - /*! \brief gradient data of pooled part, to be propgate down */ - const SrcExp &grad_pooled_; - /*! \brief shape of pooled expression */ - index_t pshape_y_; - /*! \brief shape of pooled expression */ - index_t pshape_x_; - /*! \brief kernel size in height */ - index_t ksize_y_; - /*! \brief kernel size in width */ - index_t ksize_x_; - /*! \brief kernel stride in y directory */ - index_t kstride_y_; - /*! \brief kernel stride in x directory */ - index_t kstride_x_; - /*! \brief constructor */ - UnPoolingExp(const SrcExp &data_src, - const SrcExp &data_pooled, - const SrcExp &grad_pooled, - index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) - : data_src_(data_src), data_pooled_(data_pooled), - grad_pooled_(grad_pooled), - ksize_y_(ksize_y), ksize_x_(ksize_x), - kstride_y_(kstride_y), kstride_x_(kstride_x) { - Shape pshape = ShapeCheck::Check(grad_pooled); - typedef ShapeCheck ShapeCheckSrcDimSrcExp; - CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled)) - << "UnPoolingExp: pooled shape mismatch"; - Shape sshape = ShapeCheck::Check(data_src); - for (int k = 0; k < srcdim - 2; ++k) { - CHECK_EQ(pshape[k], sshape[k]) << "UnPoolingExp: pool and src shape mismatch"; - } - pshape_x_ = pshape[srcdim - 1]; - pshape_y_ = pshape[srcdim - 2]; - this->shape_ = sshape; - } -}; -/*! - * \brief unpooling gradient for 4D, backprop gradient value back, revserse operation of pooling, - * same as unpooling, but allows unequal size of kernel - * \param data_src source input, corresponds to src in pooling - * \param data_pooled result of pooled data, corresponds to result of pooling - * \param grad_pooled gradient data of pooled part, to be propgate down - * \param ksize_y kernel height - * \param ksize_x kernel width - * \param kstride_y stride in y directory - * \param kstride_x stride in x directory - * \return expression corresponding to unpooled 4D Tensor, storing backproped gradient - * \tparam Reducer reducer type - * \tparam SrcExp source expression - * \tparam DType the content data type - * \tparam etype type of expression - */ -template -inline UnPoolingExp::kDim> -unpool(const Exp &data_src, - const Exp &data_pooled, - const Exp &grad_pooled, - index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) { - return UnPoolingExp::kDim> - (data_src.self(), data_pooled.self(), grad_pooled.self(), - ksize_y, ksize_x, kstride_y, kstride_x); -} -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const UnPoolingExp &e) - : data_src_(MakePlan(e.data_src_)), data_pooled_(MakePlan(e.data_pooled_)), - grad_pooled_(MakePlan(e.grad_pooled_)), sshape_y_(e.shape_[srcdim - 2]), - pshape_y_(e.pshape_y_), pshape_x_(e.pshape_x_), - ksize_y_(e.ksize_y_), ksize_x_(e.ksize_x_), - kstride_y_(e.kstride_y_), kstride_x_(e.kstride_x_) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - using namespace std; - const index_t x = j; - const index_t y = i % sshape_y_; - const index_t c = i / sshape_y_; - const DType vsrc = data_src_.Eval(i, j); - const index_t py_min = - y < ksize_y_ ? 0 : (y - ksize_y_ + kstride_y_) / kstride_y_; - const index_t px_min = - x < ksize_x_ ? 0 : (x - ksize_x_ + kstride_x_) / kstride_x_; - const index_t py_max = min((y + kstride_y_) / kstride_y_, pshape_y_); - const index_t px_max = min((x + kstride_x_) / kstride_x_, pshape_x_); - - DType val = static_cast(0); - for (index_t py = py_min; py < py_max; ++py) { - for (index_t px = px_min; px < px_max; ++px) { - val += Reducer::PartialGrad(vsrc, - data_pooled_.Eval(c * pshape_y_ + py, px)) * - grad_pooled_.Eval(c * pshape_y_ + py, px); - } - } - - return val; - } - - private: - Plan data_src_, data_pooled_, grad_pooled_; - const index_t sshape_y_, pshape_y_, pshape_x_; - const index_t ksize_y_, ksize_x_; - const index_t kstride_y_, kstride_x_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_ diff --git a/include/mshadow/extension/spatial_upsampling_nearest.h b/include/mshadow/extension/spatial_upsampling_nearest.h deleted file mode 100644 index 534fbdd9ebe0..000000000000 --- a/include/mshadow/extension/spatial_upsampling_nearest.h +++ /dev/null @@ -1,71 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file spatial_upsampling.h - * \brief - * \author Bing Xu -*/ -#ifndef MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_ -#define MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_ -#include "../extension.h" - -namespace mshadow { -namespace expr { - -/*! \brief nearest neighboor upsampling - * out(x, y) = in(int(x / scale_x), int(y / scale_y)) - * \tparam SrcExp source expression - * \tparam DType data type - * \tparam srcdim source dimension - */ -template -struct UpSamplingNearestExp : - public MakeTensorExp, - SrcExp, srcdim, DType> { - /*! \brief source oprand */ - const SrcExp &src_; - /*! \brief up sampling scale */ - index_t scale_; - /*! \brief constructor */ - UpSamplingNearestExp(const SrcExp &src, index_t scale) - : src_(src), scale_(scale) { - this->shape_ = ShapeCheck::Check(src_); - this->shape_[srcdim - 2] *= scale_; - this->shape_[srcdim - 1] *= scale_; - } -}; - - -template -inline UpSamplingNearestExp::kDim> -upsampling_nearest(const Exp &src, index_t scale) { - TypeCheckPass::kDim >= 2> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return UpSamplingNearestExp::kDim>(src.self(), scale); -} - -template -struct Plan, DType> { - public: - explicit Plan(const UpSamplingNearestExp &e) - : src_(MakePlan(e.src_)), - scale_(e.scale_), - new_height_(e.shape_[srcdim - 2]), - src_height_(static_cast(e.shape_[srcdim - 2] / e.scale_)) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - const index_t x = j; - const index_t y = i % new_height_; - const index_t c = i / new_height_; - const index_t h = static_cast(y / scale_); - const index_t w = static_cast(x / scale_); - return src_.Eval(c * src_height_ + h, w); - } - - private: - Plan src_; - const index_t scale_; - const index_t new_height_; - const index_t src_height_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_SPATIAL_UPSAMPLING_NEAREST_H_ diff --git a/include/mshadow/extension/swapaxis.h b/include/mshadow/extension/swapaxis.h deleted file mode 100644 index b79aba441175..000000000000 --- a/include/mshadow/extension/swapaxis.h +++ /dev/null @@ -1,110 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file swapaxis.h - * \brief support for swapaxis - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_SWAPAXIS_H_ -#define MSHADOW_EXTENSION_SWAPAXIS_H_ -#include -#include -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief swap two axis of a tensor - * input: Tensor: ishape - * output: Tensor oshape[a1],oshape[a2] = ishape[a2],oshape[a1] - * - * \tparam SrcExp type of source expression - * \tparam DType the type of elements - * \tparam dimsrc source dimension, assert a1 > a2 - * \tparam m_a1 one dimension to be swapped, encoded by dimsrc - a1 - * \tparam a2 second dimension to be swapped, encoded by a2 - */ -template -struct SwapAxisExp: - public MakeTensorExp, - SrcExp, dimsrc, DType> { - // decode the a1, a2 - static const int a1 = dimsrc - m_a1; - /*! \brief source expression */ - const SrcExp &src_; - /*! \brief constructor */ - explicit SwapAxisExp(const SrcExp &src) : src_(src) { - this->shape_ = ShapeCheck::Check(src); - std::swap(this->shape_[a1], this->shape_[a2]); - } -}; -/*! - * \brief a expression that reshapes a tensor to another shape - * \param src Tensor: - * \return a expresion with type Tensor - * \tparam a1 higher dimension to be swapped, assert a1 > a2 - * \tparam a2 lower dimension to be swapped - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam etype source expression type - */ -template -inline SwapAxisExp::kDim, - ExpInfo::kDim - a1, a2> -swapaxis(const Exp &src) { - typedef ExpInfo Info; - TypeCheckPass= a1 + 1 && Info::kDim >= a2 + 1 && - a2 < a1>::Error_Expression_Does_Not_Meet_Dimension_Req(); - return SwapAxisExp::kDim, - ExpInfo::kDim - a1, a2>(src.self()); -} -template -struct Plan, DType> { - public: - // decode the a1 - static const int a1 = dimsrc - m_a1; - explicit Plan(const SwapAxisExp &e) - : src_(MakePlan(e.src_)), - shapey_(e.shape_.ProdShape(a1 + 1, dimsrc - 1)), - shapez_(e.shape_[a1]), - shapec_(e.shape_.ProdShape(a2 + 1, a1)), - shapen_(e.shape_[a2]) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - const index_t y = i % shapey_; - i /= shapey_; - const index_t z = i % shapez_; - i /= shapez_; - const index_t c = i % shapec_; - i /= shapec_; - const index_t n = i % shapen_; - // swap z and n - return src_.Eval(((((i / shapen_) * shapez_ + z) * shapec_ + - c) * shapen_ + n) * shapey_ + y, j); - } - - private: - Plan src_; - const index_t shapey_, shapez_, shapec_, shapen_; -}; -template -struct Plan, DType> { - public: - explicit Plan(const SwapAxisExp &e) - : src_(MakePlan(e.src_)), - shapex_(e.shape_[dimsrc - 1]), - shapey_(e.shape_.ProdShape(a2 + 1, dimsrc - 1)), - shapez_(e.shape_[a2]) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t x) const { - // swap x and z - const index_t y = i % shapey_; - i /= shapey_; - const index_t z = i % shapez_; - const index_t n = i / shapez_; - return src_.Eval((n * shapex_ + x) * shapey_ + y , z); - } - - private: - Plan src_; - const index_t shapex_, shapey_, shapez_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_SWAPAXIS_H_ diff --git a/include/mshadow/extension/take.h b/include/mshadow/extension/take.h deleted file mode 100644 index 76c4f4729491..000000000000 --- a/include/mshadow/extension/take.h +++ /dev/null @@ -1,99 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file take.h - * \brief - * \author Bing Xu -*/ -#ifndef MSHADOW_EXTENSION_TAKE_H_ -#define MSHADOW_EXTENSION_TAKE_H_ - -#include "../extension.h" - -namespace mshadow { -namespace expr { - -/*! \brief Take a column from a matrix - * \tparam IndexExp type of index expression - * \tparam SrcExp type of src expression - * \tparam DType data type - */ -template -struct TakeExp: public Exp, - DType, type::kChainer> { - /*! \brief index oprand */ - const IndexExp &index_; - /*! \brief embediing oprand */ - const SrcExp &src_; - /*! constructor */ - TakeExp(const IndexExp &index, const SrcExp &src) - : index_(index), src_(src) {} -}; // struct TakeExp - - - -template -inline TakeExp -take(const Exp &index, - const Exp &src) { - return TakeExp(index.self(), src.self()); -} - - -//---------------------- -// Execution plan -//---------------------- - -template -struct Plan, DType> { - public: - explicit Plan(const TakeExp &e) - : index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) { - } - - // TODO(xx): discuss W shape: in * out or out * in - // Now I use in * out - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - index_t idx = static_cast(index_.Eval(0, y)); - return static_cast(src_.Eval(idx, x)); - } - - private: - expr::Plan index_; - expr::Plan src_; -}; // struct Plan - -template -inline Plan, DType> -MakePlan(const TakeExp &exp) { - return Plan, DType>(exp); -} - -template -struct ShapeCheck > { - inline static Shape - Check(const TakeExp &t) { - CHECK(dim == 2) - << "TakeExp only support 2D output"; - Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_); - Shape<2> wshape = ShapeCheck<2, SrcExp>::Check(t.src_); - Shape ret; - ret[0] = dshape[0]; - ret[1] = wshape[1]; - return ret; - } -}; - - -template -struct ExpInfo > { - static const int kDim = 2; - static const int kDevMask = ExpInfo::kDevMask; -}; - -} // namespace expr -} // namespace mshadow - -#endif // MSHADOW_EXTENSION_TAKE_H_ diff --git a/include/mshadow/extension/take_grad.h b/include/mshadow/extension/take_grad.h deleted file mode 100644 index 4479b3e0cd9d..000000000000 --- a/include/mshadow/extension/take_grad.h +++ /dev/null @@ -1,111 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file take_grad.h - * \brief - * \author Bing Xu -*/ -#ifndef MSHADOW_EXTENSION_TAKE_GRAD_H_ -#define MSHADOW_EXTENSION_TAKE_GRAD_H_ - -#include "../extension.h" - -namespace mshadow { -namespace expr { - -/*! \brief Calculate embedding gradient - * \tparam IndexExp type of index expression - * \tparam SrcExp type of src expression - * \tparam DType data type - */ - -template -struct TakeGradExp : public Exp, - DType, type::kChainer> { - /*! \brief index oprand */ - const IndexExp &index_; - /*! \brief out gradient oprand */ - const SrcExp &src_; - /*! \brief batch size */ - const index_t input_dim_; - /*! \brief constructor */ - TakeGradExp(const IndexExp &index, const SrcExp &src, const index_t input_dim) - : index_(index), src_(src), input_dim_(input_dim) {} -}; // struct TakeGradExp - - -template -inline TakeGradExp -take_grad(const Exp &index, - const Exp &src, - const index_t input_dim) { - return TakeGradExp(index.self(), - src.self(), - input_dim); -} - -//---------------------- -// Execution plan -//---------------------- - -template -struct Plan, DType> { - public: - explicit Plan(const TakeGradExp &e) - : index_(MakePlan(e.index_)), - src_(MakePlan(e.src_)), - batch_size_(ShapeCheck<1, IndexExp>::Check(e.index_)[0]) { - } - - // now return shape: in * out - MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { - DType ret = 0.f; - for (index_t i = 0; i < batch_size_; ++i) { - index_t idx = static_cast(index_.Eval(0, i)); - if (idx == y) { - ret += static_cast(src_.Eval(i, x)); - } - } - return ret; - } - - private: - expr::Plan index_; - expr::Plan src_; - const index_t batch_size_; -}; // struct Plan - - -template -inline Plan, DType> -MakePlan(const TakeGradExp &exp) { - return Plan, DType>(exp); -} - -template -struct ShapeCheck > { - inline static Shape - Check(const TakeGradExp &t) { - CHECK(dim == 2) - << "TakeGradExp only support 2D output"; - // Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_); - Shape<2> gshape = ShapeCheck<2, SrcExp>::Check(t.src_); - Shape ret; - ret[0] = t.input_dim_; - ret[1] = gshape[1]; - return ret; - } -}; // struct ShapeCheck - -template -struct ExpInfo > { - static const int kDim = 2; - static const int kDevMask = ExpInfo::kDevMask; -}; - -} // namespace expr -} // namespace mshadow - -#endif // MSHADOW_EXTENSION_TAKE_GRAD_H_ diff --git a/include/mshadow/extension/transpose.h b/include/mshadow/extension/transpose.h deleted file mode 100644 index 6640153f2100..000000000000 --- a/include/mshadow/extension/transpose.h +++ /dev/null @@ -1,200 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file transpose.h - * \brief support for transpose - * \author Junyuan Xie - */ -#ifndef MSHADOW_EXTENSION_TRANSPOSE_H_ -#define MSHADOW_EXTENSION_TRANSPOSE_H_ -#include -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief transpose axes of a tensor - * input: Tensor: ishape - * output: Tensor oshape[a1],oshape[a2] = ishape[a2],oshape[a1] - * - * \tparam SrcExp type of source expression - * \tparam DType the type of elements - * \tparam dimsrc source dimension, assert a1 > a2 - * \tparam m_a1 one dimension to be swapped, encoded by dimsrc - a1 - * \tparam a2 second dimension to be swapped, encoded by a2 - */ -template -struct TransposeExExp: - public MakeTensorExp, - SrcExp, dimsrc, DType> { - /*! \brief source expression */ - const SrcExp &src_; - const Shape axes_; - Shape dst_in_src_stride_; // Holds the corresponding stride of the dst axes in src - index_t src_stride_; - /*! \brief constructor */ - explicit TransposeExExp(const SrcExp &src, Shape axes) : src_(src), axes_(axes) { - Shape src_shape = ShapeCheck::Check(src); - src_stride_ = src_shape[dimsrc - 1]; - Shape src_stride; - src_stride[dimsrc-1] = 1; - for (int i = dimsrc-2; i >= 0; --i) src_stride[i] = src_shape[i+1]*src_stride[i+1]; - for (int i = 0; i < dimsrc; ++i) { - dst_in_src_stride_[i] = src_stride[axes[i]]; - this->shape_[i] = src_shape[axes[i]]; - } - } -}; -/*! - * \brief a expression that reshapes a tensor to another shape - * \param src Tensor: - * \return a expresion with type Tensor - * \tparam a1 higher dimension to be swapped, assert a1 > a2 - * \tparam a2 lower dimension to be swapped - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam etype source expression type - */ -template -inline TransposeExExp::kDim> -transpose(const Exp &src, Shape::kDim> axes) { - return TransposeExExp::kDim>(src.self(), axes); -} - -template -struct Plan, DType> { - public: - explicit Plan(const TransposeExExp &e) - : src_(MakePlan(e.src_)), - src_stride_(e.src_stride_), - dst_in_src_stride_(e.dst_in_src_stride_), - dst_shape_(e.shape_) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - index_t idx = j * dst_in_src_stride_[dimsrc - 1]; - #pragma unroll - for (int k = dimsrc-2; k >= 0; --k) { - idx += (i % dst_shape_[k]) * dst_in_src_stride_[k]; - i /= dst_shape_[k]; - } - return src_.Eval(idx/src_stride_, idx%src_stride_); - } - - private: - Plan src_; - const index_t src_stride_; - const Shape dst_in_src_stride_, dst_shape_; -}; - -/*! - * \brief transform contiguous indices of the source tensor to indices of the transposed tensor. - * input: Tensor: ishape - * output: Tensor: oshape = ishape - * - * \tparam SrcExp type of source expression - * \tparam DType the type of elements - * \tparam dimsrc source dimension - * \tparam etype source type - */ -template -struct TransposeIndicesExp: - public Exp, DType, etype> { - /*! \brief source expression */ - const SrcExp &src_indices_; // Expression of the source indices - Shape src_shape_; // Holds the corresponding stride of the source axes in dst - const Shape axes_; // The transpose axes - Shape src_in_dst_stride_; // Holds the corresponding stride of the source axes in dst - /*! \brief constructor */ - explicit TransposeIndicesExp(const SrcExp &src_indices, - Shape src_shape, - Shape axes) : src_indices_(src_indices), - src_shape_(src_shape), axes_(axes) { - Shape dst_shape_; - Shape dst_stride_; - bool axes_checking_flag[dimsrc] = { 0 }; - for (int i = 0; i < dimsrc; ++i) { - CHECK_LT(static_cast(axes[i]), dimsrc) - << "Invalid axes input! All elements of axes must be between 0 and " << dimsrc - << ", find axes=" << axes; - dst_shape_[i] = src_shape[axes[i]]; - axes_checking_flag[axes[i]] = true; - } - // check if the input axes is valid - for (int i = 0; i < dimsrc; ++i) { - CHECK_EQ(axes_checking_flag[i], true) - << "Invalid axes input! All elements of axes must be between 0 and " << dimsrc - << ", find axes=" << axes; - } - dst_stride_[dimsrc - 1] = 1; - for (int i = dimsrc - 2; i >= 0; --i) dst_stride_[i] = dst_shape_[i+1] * dst_stride_[i+1]; - for (int i = 0; i < dimsrc; ++i) { - src_in_dst_stride_[axes[i]] = dst_stride_[i]; - } - } -}; - -/*! - * \brief a expression that reshapes a tensor to another shape - * \param src Tensor: - * \return a expresion with type Tensor - * \tparam a1 higher dimension to be swapped, assert a1 > a2 - * \tparam a2 lower dimension to be swapped - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam etype source expression type - */ -template -inline TransposeIndicesExp -transpose_indices(const Exp &src_indices, - Shape src_shape, - Shape axes) { - return TransposeIndicesExp(src_indices.self(), src_shape, axes); -} - -template -struct Plan, DType> { - public: - explicit Plan(const TransposeIndicesExp &e) - : src_indices_(MakePlan(e.src_indices_)), - src_in_dst_stride_(e.src_in_dst_stride_), - src_shape_(e.src_shape_) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - index_t src_idx = static_cast(src_indices_.Eval(i, j)); - index_t dst_idx = 0; - #pragma unroll - for (int k = dimsrc - 1; k >= 0; --k) { - dst_idx += (src_idx % src_shape_[k]) * src_in_dst_stride_[k]; - src_idx /= src_shape_[k]; - } - return static_cast(dst_idx); - } - - private: - Plan src_indices_; - const Shape src_in_dst_stride_, src_shape_; -}; - -//---------------------- -// Execution plan -//---------------------- -/*! \brief make expression */ -template -inline Plan, DType> -MakePlan(const TransposeIndicesExp &e) { - return Plan, DType>(e); -} - -template -struct ShapeCheck > { - inline static Shape - Check(const TransposeIndicesExp &t) { - Shape s = ShapeCheck::Check(t.src_indices_); - return s; - } -}; - -template -struct ExpInfo > { - static const int kDim = ExpInfo::kDim; - static const int kDevMask = ExpInfo::kDevMask; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_TRANSPOSE_H_ diff --git a/include/mshadow/extension/unpack_patch2col.h b/include/mshadow/extension/unpack_patch2col.h deleted file mode 100644 index ed473f81d496..000000000000 --- a/include/mshadow/extension/unpack_patch2col.h +++ /dev/null @@ -1,151 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file unpack_patch2col.h - * \brief support for unpack - * \author Tianqi Chen - */ -#ifndef MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ -#define MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ -#include "../extension.h" -namespace mshadow { -namespace expr { -/*! - * \brief unpack local (overlap) patches of image to column of mat, - * can be used to implement convolution, this expression allow unpack of a batch - * this is a version support unpacking multiple images - * after getting unpacked mat, we can use: output = dot(weight, mat) to get covolved results, the relations: - * \tparam SrcExp source expression - * \tparam dstdim destination dimension - */ -template -struct UnpackPatchToColXExp: - public MakeTensorExp, - SrcExp, 2, DType>{ - /*! \brief source operand */ - const SrcExp &img_; - /*! \brief patch height */ - index_t psize_y_; - /*! \brief patch width */ - index_t psize_x_; - /*! \brief patch stride */ - index_t pstride_y_; - index_t pstride_x_; - /*! \brief patch dilate */ - index_t pdilate_y_; - index_t pdilate_x_; - /*! \brief number of input channel */ - index_t i_channel_; - /*! \brief height of img */ - index_t i_height_; - /*! \brief width of img */ - index_t i_width_; - /*! \brief constructor */ - UnpackPatchToColXExp(const SrcExp &img, - index_t psize_y, - index_t psize_x, - index_t pstride_y, - index_t pstride_x, - index_t pdilate_y, - index_t pdilate_x) - : img_(img), psize_y_(psize_y), psize_x_(psize_x), - pstride_y_(pstride_y), pstride_x_(pstride_x), - pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){ - Shape imshape = ShapeCheck::Check(img_); - CHECK(imshape[srcdim - 1] >= psize_x && imshape[srcdim - 2] >= psize_y) - << "UnpackPatchToCol:image shape smaller than patch size"; - this->i_channel_ = imshape[srcdim - 3]; - this->i_height_ = imshape[srcdim - 2]; - this->i_width_ = imshape[srcdim - 1]; - // calculate number of batches - const index_t num = imshape.ProdShape(0, srcdim - 3); - const index_t o_height = (i_height_ - - (pdilate_y * (psize_y - 1) + 1)) / pstride_y + 1; - const index_t o_width = (i_width_ - - (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1; - this->shape_[1] = o_height * o_width * num; - this->shape_[0] = psize_y * psize_x * i_channel_; - } -}; - -/*! - * \brief unpack local (overlap) patches of image to column of mat, can be used to implement convolution - * after getting unpacked mat, we can use: output = dot(weight, mat) to get covolved results, the relations: - * - * weight; shape[0]: out_channel, shape[1]: ichannel * psize_y * psize_x - * output; shape[0]: out_channel, shape[1]: out_height * out_width * num_of_images - * out_height = (in_height - psize_y) / pstride + 1, this means we pad inperfect patch with 0 - * out_width = (in_width - psize_x) / pstride + 1 - * - * \return mat target matrix; shape[0]: in_channel*psize_y*psize_x shape[1]: out_height*out_width * num_of_images - * \param img source image; shape[-3]: in_channels, shape[-2]: in_height, shape[-1]: in_width, can be 3D or 4D tensor(multiple images) - * \param psize_y height of each patch - * \param psize_x width of each patch - * \param pstride stride of each patch - * \param pdilate dilate of each patch - * \tparam SrcExp source expression - * \tparam DType the type of elements - * \tparam etype type of expression - */ -template -inline UnpackPatchToColXExp::kDim> -unpack_patch2col(const Exp &img, - index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate) { - TypeCheckPass::kDim >= 3> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return UnpackPatchToColXExp::kDim> - (img.self(), psize_y, psize_x, pstride, pstride, pdilate, pdilate); -} - -/*! - *if you want to specify stride_x and stride_y - */ -template -inline UnpackPatchToColXExp::kDim> -unpack_patch2col(const Exp &img, - index_t psize_y, index_t psize_x, index_t pstride_y_, index_t pstride_x_, - index_t pdilate_y_, index_t pdilate_x_) { - TypeCheckPass::kDim >= 3> - ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return UnpackPatchToColXExp::kDim> - (img.self(), psize_y, psize_x, pstride_y_, pstride_x_, pdilate_y_, pdilate_x_); -} -//---------------------- -// Execution plan -//---------------------- -template -struct Plan, DType> { - public: - explicit Plan(const UnpackPatchToColXExp &e) - :src_(MakePlan(e.img_)), - psize_y_(e.psize_y_), psize_x_(e.psize_x_), - pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_), - i_channel_(e.i_channel_), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_), - i_height_(e.i_height_), i_width_(e.i_width_), - o_height_((i_height_ - (pdilate_y_ * (psize_y_ - 1) + 1)) / pstride_y_ + 1), - o_width_((i_width_ - (pdilate_x_ * (psize_x_ - 1) + 1)) / pstride_x_ + 1) {} - MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { - const index_t x_offset = i % psize_x_ * pdilate_x_; - const index_t idivp = i / psize_x_; - const index_t y_offset = idivp % psize_y_ * pdilate_y_; - const index_t c = idivp / psize_y_; - const index_t x = (j % o_width_) * pstride_x_ + x_offset; - const index_t jdivw = j / o_width_; - const index_t y = (jdivw % o_height_) * pstride_y_ + y_offset; - const index_t n = jdivw / o_height_; - - if (x < i_width_ && y < i_height_) { - return src_.Eval((n * i_channel_ + c) * i_height_ + y, x); - } else { - return DType(0.0f); - } - } - - private: - Plan src_; - const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_; - const index_t pdilate_y_, pdilate_x_; - const index_t i_height_, i_width_, o_height_, o_width_; -}; -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_ diff --git a/include/mshadow/half.h b/include/mshadow/half.h deleted file mode 100644 index 75d8e5d09d2f..000000000000 --- a/include/mshadow/half.h +++ /dev/null @@ -1,288 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file half.h - * \brief definition of half (float16) type. - * - * \author Junyuan Xie - */ -#ifndef MSHADOW_HALF_H_ -#define MSHADOW_HALF_H_ -#include "./base.h" - -#if MSHADOW_USE_F16C - #include -#endif // MSHADOW_USE_F16C - -#if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050) - #define MSHADOW_CUDA_HALF 1 - #include - #if defined(__CUDA_ARCH__) - /*! \brief __half2float_warp */ - __host__ __device__ float __half2float_warp(const volatile __half& h) { /* NOLINT(*) */ - __half val; -#if CUDA_VERSION >= 9000 - val = const_cast<__half&>(h); -#else - val.x = h.x; -#endif - return __half2float(val); - } - #endif -#else - #define MSHADOW_CUDA_HALF 0 -#endif - -/*! \brief namespace for mshadow */ -namespace mshadow { -/* \brief name space for host/device portable half-precision floats */ -namespace half { -#define MSHADOW_HALF_OPERATOR(RTYPE, OP) \ - MSHADOW_XINLINE RTYPE operator OP (half_t a, half_t b) { \ - return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \ - } \ - template \ - MSHADOW_XINLINE RTYPE operator OP (half_t a, T b) { \ - return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \ - } \ - template \ - MSHADOW_XINLINE RTYPE operator OP (T a, half_t b) { \ - return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \ - } - -#define MSHADOW_HALF_ASSIGNOP(AOP, OP) \ - template \ - MSHADOW_XINLINE half_t operator AOP (const T& a) { \ - return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \ - } \ - template \ - MSHADOW_XINLINE half_t operator AOP (const volatile T& a) volatile { \ - return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \ - } - -#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__)) -#define MSHADOW_HALF_CONVERSIONOP(T) \ - MSHADOW_XINLINE operator T() const { \ - return T(__half2float(cuhalf_)); /* NOLINT(*)*/ \ - } \ - MSHADOW_XINLINE operator T() const volatile { \ - return T(__half2float_warp(cuhalf_)); /* NOLINT(*)*/ \ - } -#elif(MSHADOW_USE_F16C) -#define MSHADOW_HALF_CONVERSIONOP(T) \ - MSHADOW_XINLINE operator T() const { \ - return T(_cvtsh_ss(half_)); /* NOLINT(*)*/ \ - } \ - MSHADOW_XINLINE operator T() const volatile { \ - return T(_cvtsh_ss(half_)); /* NOLINT(*)*/ \ - } -#else -#define MSHADOW_HALF_CONVERSIONOP(T) \ - MSHADOW_XINLINE operator T() const { \ - return T(half2float(half_)); /* NOLINT(*)*/ \ - } \ - MSHADOW_XINLINE operator T() const volatile { \ - return T(half2float(half_)); /* NOLINT(*)*/ \ - } -#endif // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__)) - -class MSHADOW_ALIGNED(2) half_t { - public: - union { - uint16_t half_; -#if MSHADOW_CUDA_HALF - __half cuhalf_; -#endif // MSHADOW_CUDA_HALF - }; - - static MSHADOW_XINLINE half_t Binary(uint16_t value) { - half_t res; - res.half_ = value; - return res; - } - - MSHADOW_XINLINE half_t() {} - -#if MSHADOW_CUDA_HALF - MSHADOW_XINLINE explicit half_t(const __half& value) { - cuhalf_ = value; - } -#endif // MSHADOW_CUDA_HALF - - MSHADOW_XINLINE half_t(const float& value) { constructor(value); } - MSHADOW_XINLINE explicit half_t(const double& value) { constructor(value); } - MSHADOW_XINLINE explicit half_t(const int8_t& value) { constructor(value); } - MSHADOW_XINLINE explicit half_t(const uint8_t& value) { constructor(value); } - MSHADOW_XINLINE explicit half_t(const int32_t& value) { constructor(value); } - MSHADOW_XINLINE explicit half_t(const uint32_t& value) { constructor(value); } - MSHADOW_XINLINE explicit half_t(const int64_t& value) { constructor(value); } - MSHADOW_XINLINE explicit half_t(const uint64_t& value) { constructor(value); } - - MSHADOW_HALF_CONVERSIONOP(float) - - MSHADOW_HALF_ASSIGNOP(+=, +) - MSHADOW_HALF_ASSIGNOP(-=, -) - MSHADOW_HALF_ASSIGNOP(*=, *) - MSHADOW_HALF_ASSIGNOP(/=, /) - - MSHADOW_XINLINE half_t operator+() { - return *this; - } - - MSHADOW_XINLINE half_t operator-() { - return half_t(-float(*this)); // NOLINT(*) - } - - MSHADOW_XINLINE half_t operator=(const half_t& a) { - half_ = a.half_; - return a; - } - - template - MSHADOW_XINLINE half_t operator=(const T& a) { - return *this = half_t(a); /* NOLINT(*)*/ - } - - MSHADOW_XINLINE half_t operator=(const half_t& a) volatile { - half_ = a.half_; - return a; - } - - template - MSHADOW_XINLINE half_t operator=(const T& a) volatile { - return *this = half_t(a); /* NOLINT(*)*/ - } - - private: - union Bits { - float f; - int32_t si; - uint32_t ui; - }; - - static int const shift = 13; - static int const shiftSign = 16; - - static int32_t const infN = 0x7F800000; // flt32 infinity - static int32_t const maxN = 0x477FE000; // max flt16 normal as a flt32 - static int32_t const minN = 0x38800000; // min flt16 normal as a flt32 - static int32_t const signN = 0x80000000; // flt32 sign bit - - static int32_t const infC = infN >> shift; - static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32 - static int32_t const maxC = maxN >> shift; - static int32_t const minC = minN >> shift; - static int32_t const signC = signN >> shiftSign; // flt16 sign bit - - static int32_t const mulN = 0x52000000; // (1 << 23) / minN - static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift)) - - static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted - static int32_t const norC = 0x00400; // min flt32 normal down shifted - - static int32_t const maxD = infC - maxC - 1; - static int32_t const minD = minC - subC - 1; - - MSHADOW_XINLINE uint16_t float2half(const float& value) const { - Bits v, s; - v.f = value; - uint32_t sign = v.si & signN; - v.si ^= sign; - sign >>= shiftSign; // logical shift - s.si = mulN; - s.si = s.f * v.f; // correct subnormals - v.si ^= (s.si ^ v.si) & -(minN > v.si); - v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN)); - v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN)); - v.ui >>= shift; // logical shift - v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC); - v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC); - return v.ui | sign; - } - - MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile { // NOLINT (*) - Bits v, s; - v.f = value; - uint32_t sign = v.si & signN; - v.si ^= sign; - sign >>= shiftSign; // logical shift - s.si = mulN; - s.si = s.f * v.f; // correct subnormals - v.si ^= (s.si ^ v.si) & -(minN > v.si); - v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN)); - v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN)); - v.ui >>= shift; // logical shift - v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC); - v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC); - return v.ui | sign; - } - - MSHADOW_XINLINE float half2float(const uint16_t& value) const { - Bits v; - v.ui = value; - int32_t sign = v.si & signC; - v.si ^= sign; - sign <<= shiftSign; - v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); - v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); - Bits s; - s.si = mulC; - s.f *= v.si; - int32_t mask = -(norC > v.si); - v.si <<= shift; - v.si ^= (s.si ^ v.si) & mask; - v.si |= sign; - return v.f; - } - - MSHADOW_XINLINE float half2float(const volatile uint16_t& value) const volatile { // NOLINT(*) - Bits v; - v.ui = value; - int32_t sign = v.si & signC; - v.si ^= sign; - sign <<= shiftSign; - v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); - v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); - Bits s; - s.si = mulC; - s.f *= v.si; - int32_t mask = -(norC > v.si); - v.si <<= shift; - v.si ^= (s.si ^ v.si) & mask; - v.si |= sign; - return v.f; - } - - template - MSHADOW_XINLINE void constructor(const T& value) { -#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__)) - cuhalf_ = __float2half(float(value)); // NOLINT(*) -#elif(MSHADOW_USE_F16C) - half_ = _cvtss_sh(static_cast(value), 0); -#else /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */ - half_ = float2half(float(value)); // NOLINT(*) -#endif /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */ - } -}; - -/*! \brief overloaded + operator for half_t */ -MSHADOW_HALF_OPERATOR(half_t, +) -/*! \brief overloaded - operator for half_t */ -MSHADOW_HALF_OPERATOR(half_t, -) -/*! \brief overloaded * operator for half_t */ -MSHADOW_HALF_OPERATOR(half_t, *) -/*! \brief overloaded / operator for half_t */ -MSHADOW_HALF_OPERATOR(half_t, /) -/*! \brief overloaded > operator for half_t */ -MSHADOW_HALF_OPERATOR(bool, >) -/*! \brief overloaded < operator for half_t */ -MSHADOW_HALF_OPERATOR(bool, <) -/*! \brief overloaded >= operator for half_t */ -MSHADOW_HALF_OPERATOR(bool, >=) -/*! \brief overloaded <= operator for half_t */ -MSHADOW_HALF_OPERATOR(bool, <=) - -#define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0xFBFF); -#define MSHADOW_HALF_MAX mshadow::half::half_t::Binary(0x7BFF); -} // namespace half -} // namespace mshadow -#endif // MSHADOW_HALF_H_ diff --git a/include/mshadow/half2.h b/include/mshadow/half2.h deleted file mode 100755 index 3e130c85ba63..000000000000 --- a/include/mshadow/half2.h +++ /dev/null @@ -1,143 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file half2.h - * \brief definition of vector float16, half2 type. - * - * \author Antti-Pekka Hynninen - */ -#ifndef MSHADOW_HALF2_H_ -#define MSHADOW_HALF2_H_ - -#if (defined(__CUDACC__) && __CUDA_ARCH__ >= 530 && MSHADOW_USE_CUDA && CUDA_VERSION >= 7050) - #define MSHADOW_CUDA_HALF2 1 - #include -#else - #define MSHADOW_CUDA_HALF2 0 -#endif - -#include - -/*! \brief namespace for mshadow */ -namespace mshadow { -/* \brief name space for host/device portable half-precision floats */ -namespace half { - -#define MSHADOW_HALF2_ASSIGNOP(AOP, OP) \ - template \ - MSHADOW_XINLINE half2_t operator AOP (const T& a) { \ - return *this = half2_t(*this OP a); /* NOLINT(*)*/ \ - } \ - -class MSHADOW_ALIGNED(4) half2_t { - public: -#if MSHADOW_CUDA_HALF2 - half2 half2_; -#else - half_t half_t2[2]; -#endif - - MSHADOW_XINLINE half2_t() {} - -#if MSHADOW_CUDA_HALF2 - MSHADOW_XINLINE explicit half2_t(half2 a) : half2_(a) {} -#else - MSHADOW_XINLINE explicit half2_t(half_t a, half_t b) { - half_t2[0] = a; - half_t2[1] = b; - } -#endif - - MSHADOW_XINLINE explicit half2_t(int a) { -#if MSHADOW_CUDA_HALF2 - half2_ = __half2half2(__int2half_rz(a)); -#else - half_t2[0] = (half_t)a; - half_t2[1] = (half_t)a; -#endif - } - - MSHADOW_XINLINE half2_t operator+() { - return *this; - } - - MSHADOW_XINLINE half2_t operator-() { -#if MSHADOW_CUDA_HALF2 - return half2_t(__hneg2(half2_)); -#else - return half2_t(-half_t2[0], -half_t2[1]); -#endif - } - - MSHADOW_XINLINE half2_t operator=(const half2_t& a) { -#if MSHADOW_CUDA_HALF2 - half2_ = a.half2_; -#else - half_t2[0] = a.half_t2[0]; - half_t2[1] = a.half_t2[1]; -#endif - return a; - } - - MSHADOW_HALF2_ASSIGNOP(+=, +) - MSHADOW_HALF2_ASSIGNOP(-=, -) - MSHADOW_HALF2_ASSIGNOP(*=, *) - MSHADOW_HALF2_ASSIGNOP(/=, /) -}; - -/*! \brief overloaded + operator for half2_t */ -MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b) { -#if MSHADOW_CUDA_HALF2 - return half2_t(__floats2half2_rn(__low2float(a.half2_) + __low2float(b.half2_), - __high2float(a.half2_) + __high2float(b.half2_))); -#else - return half2_t(a.half_t2[0] + b.half_t2[0], a.half_t2[1] + b.half_t2[1]); -#endif -} -/*! \brief overloaded - operator for half2_t */ -MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b) { -#if MSHADOW_CUDA_HALF2 - return half2_t(__floats2half2_rn(__low2float(a.half2_) - __low2float(b.half2_), - __high2float(a.half2_) - __high2float(b.half2_))); -#else - return half2_t(a.half_t2[0] - b.half_t2[0], a.half_t2[1] - b.half_t2[1]); -#endif -} -/*! \brief overloaded * operator for half2_t */ -MSHADOW_XINLINE half2_t operator*(half2_t a, half2_t b) { -#if MSHADOW_CUDA_HALF2 - return half2_t(__floats2half2_rn(__low2float(a.half2_) * __low2float(b.half2_), - __high2float(a.half2_) * __high2float(b.half2_))); -#else - return half2_t(a.half_t2[0] * b.half_t2[0], a.half_t2[1] * b.half_t2[1]); -#endif -} -/*! \brief overloaded / operator for half2_t */ -MSHADOW_XINLINE half2_t operator/(half2_t a, half2_t b) { -#if MSHADOW_CUDA_HALF2 - return half2_t(__floats2half2_rn(__low2float(a.half2_) / __low2float(b.half2_), - __high2float(a.half2_) / __high2float(b.half2_))); -#else - return half2_t(a.half_t2[0] / b.half_t2[0], a.half_t2[1] / b.half_t2[1]); -#endif -} -/*! \brief overloaded % operator for half2_t */ -MSHADOW_XINLINE half2_t operator%(half2_t a, half2_t b) { -#if MSHADOW_CUDA_HALF2 - return half2_t(__floats2half2_rn(::fmod(__low2float(a.half2_), __low2float(b.half2_)), - ::fmod(__high2float(a.half2_), __high2float(b.half2_)))); -#else - return half2_t(::fmod(a.half_t2[0], b.half_t2[0]), ::fmod(a.half_t2[1], b.half_t2[1])); -#endif -} -/*! \brief overloaded == operator for half2_t */ -MSHADOW_XINLINE bool operator==(half2_t a, half2_t b) { -#if MSHADOW_CUDA_HALF2 - return __hbeq2(a.half2_, b.half2_); -#else - return (a.half_t2[0] == b.half_t2[0] && a.half_t2[1] == b.half_t2[1]); -#endif -} - -} // namespace half -} // namespace mshadow -#endif // MSHADOW_HALF2_H_ diff --git a/include/mshadow/io.h b/include/mshadow/io.h deleted file mode 100644 index 2d0efc3aa56b..000000000000 --- a/include/mshadow/io.h +++ /dev/null @@ -1,137 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file io.h - * \brief definitions of I/O functions for mshadow tensor - * \author Tianqi Chen - */ -#ifndef MSHADOW_IO_H_ -#define MSHADOW_IO_H_ -#include "./tensor.h" - -namespace mshadow { -namespace utils { -/*! - * \brief interface of stream I/O, used to serialize data, - * mshadow does not restricted to only this interface in SaveBinary/LoadBinary - * mshadow accept all class that implements Read and Write - */ -class IStream { - public: - /*! - * \brief read data from stream - * \param ptr pointer to memory buffer - * \param size size of block - * \return usually is the size of data readed - */ - virtual size_t Read(void *ptr, size_t size) = 0; - /*! - * \brief write data to stream - * \param ptr pointer to memory buffer - * \param size size of block - */ - virtual void Write(const void *ptr, size_t size) = 0; - /*! \brief virtual destructor */ - virtual ~IStream(void) {} -}; -} // namespace utils -/*! - * \brief CPU/GPU: save a tensor by binary format, for GPU version, a temp Tensor storage will be allocated - * \param fo output binary stream - * \param src source data file - * \tparam dim dimension of tensor - * \tparam DType type of element in tensor - * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. - */ -template -inline void SaveBinary(TStream &fo, const Tensor &src); // NOLINT(*) -/*! - * \brief CPU/GPU: save a tensor by binary format, for GPU version, a temp Tensor storage will be allocated - * \param fo output binary stream - * \param src source data file - * \tparam dim dimension of tensor - * \tparam DType type of element in tensor - * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. - */ -template -inline void SaveBinary(TStream &fo, const Tensor &src); // NOLINT(*) -/*! - * \brief CPU/GPU: load a tensor by binary format, for GPU version, a temp Tensor storage will be allocated - * if pre_alloc is true , then space in dst is preallocated, and must have same shape of the tensor loaded - * if pre_alloc is false, then dst originally does not have space allocated, LoadBinary will allocate space for dst - * \param fi output binary stream - * \param dst destination file - * \param pre_alloc whether space is pre-allocated, if false, space allocation will happen - * \tparam dim dimension of tensor - * \tparam DType type of element in tensor - * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. - */ -template -inline void LoadBinary(TStream &fi, // NOLINT(*) - Tensor *dst, bool pre_alloc); -/*! - * \brief CPU/GPU: load a tensor by binary format, for GPU version, a temp Tensor storage will be allocated - * if pre_alloc is true , then space in dst is preallocated, and must have same shape of the tensor loaded - * if pre_alloc is false, then dst originally does not have space allocated, LoadBinary will allocate space for dst - * \param fi output binary stream - * \param dst destination file - * \param pre_alloc whether space is pre-allocated, if false, space allocation will happen - * \tparam dim dimension of tensor - * \tparam DType type of element in tensor - * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. - */ - -template -inline void LoadBinary(TStream &fi, // NOLINT(*) - Tensor *dst, bool pre_alloc); - -// implementations -template -inline void SaveBinary(TStream &fo, const Tensor &src_) { // NOLINT(*) - fo.Write(&src_.shape_, sizeof(src_.shape_)); - Tensor src = src_.FlatTo2D(); - for (index_t i = 0; i < src.size(0); ++i) { - fo.Write(src[i].dptr_, sizeof(DType) * src.size(1)); - } -} -template -inline void SaveBinary(TStream &fo, const Tensor &src) { // NOLINT(*) - // copy to CPU, then save - Tensor tmp(src.shape_); - AllocSpace(&tmp); - Stream stream; - Copy(tmp, src, &stream); - SaveBinary(fo, tmp); - FreeSpace(&tmp); -} -template -inline void LoadBinary(TStream &fi, // NOLINT(*) - Tensor *dst_, bool pre_alloc) { - Shape shape; - CHECK_NE(fi.Read(&shape, sizeof(shape)), 0) << "mshadow::LoadBinary"; - if (pre_alloc) { - CHECK_EQ(shape, dst_->shape_) << "LoadBinary, shape do not match pre-allocated shape"; - } else { - dst_->shape_ = shape; AllocSpace(dst_); - } - Tensor dst = dst_->FlatTo2D(); - if (dst.size(0) == 0) return; - for (index_t i = 0; i < dst.size(0); ++i) { - CHECK_NE(fi.Read(dst[i].dptr_, sizeof(DType) * dst.size(1)), 0) << "mshadow::LoadBinary"; - } -} -template -inline void LoadBinary(TStream &fi, // NOLINT(*) - Tensor *dst, bool pre_alloc) { - Tensor tmp; - LoadBinary(fi, &tmp, false); - if (pre_alloc) { - CHECK_EQ(tmp.shape, dst->shape_) << "LoadBinary, shape do not match pre-allocated shape"; - } else { - dst->shape = tmp.shape; AllocSpace(dst); - } - Stream stream; - Copy(*dst, tmp, &stream); - FreeSpace(&tmp); -} -} // namespace mshadow -#endif // MSHADOW_IO_H_ diff --git a/include/mshadow/logging.h b/include/mshadow/logging.h deleted file mode 100644 index 002b90097595..000000000000 --- a/include/mshadow/logging.h +++ /dev/null @@ -1,234 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file logging.h - * \brief defines logging macros of dmlc - * allows use of GLOG, fall back to internal - * implementation when disabled - */ -#ifndef MSHADOW_LOGGING_H_ -#define MSHADOW_LOGGING_H_ -#ifndef DMLC_LOGGING_H_ -#define DMLC_LOGGING_H_ - -#include -#include -#include -#include -#include -#include "./base.h" - -namespace dmlc { -/*! \brief taken from DMLC directly */ - -/*! - * \brief exception class that will be thrown by - * default logger if DMLC_LOG_FATAL_THROW == 1 - */ -struct Error : public std::runtime_error { - /*! - * \brief constructor - * \param s the error message - */ - explicit Error(const std::string &s) : std::runtime_error(s) {} -}; -} // namespace dmlc - -#if defined(_MSC_VER) && _MSC_VER < 1900 -#define noexcept(a) -#endif - -#if DMLC_USE_GLOG -#include - -namespace dmlc { -/*! \brief taken from DMLC directly */ -inline void InitLogging(const char* argv0) { - google::InitGoogleLogging(argv0); -} -} // namespace dmlc - -#else -// use a light version of glog -#include -#include -#include -#include - -#if defined(_MSC_VER) -#pragma warning(disable : 4722) -#endif - -namespace dmlc { -inline void InitLogging(const char* argv0) { - // DO NOTHING -} - -// Always-on checking -#define CHECK(x) \ - if (!(x)) \ - dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check " \ - "failed: " #x << ' ' -#define CHECK_LT(x, y) CHECK((x) < (y)) -#define CHECK_GT(x, y) CHECK((x) > (y)) -#define CHECK_LE(x, y) CHECK((x) <= (y)) -#define CHECK_GE(x, y) CHECK((x) >= (y)) -#define CHECK_EQ(x, y) CHECK((x) == (y)) -#define CHECK_NE(x, y) CHECK((x) != (y)) -#define CHECK_NOTNULL(x) \ - ((x) == NULL ? dmlc::LogMessageFatal(__FILE__, __LINE__).stream() << "Check notnull: " #x << ' ', (x) : (x)) // NOLINT(*) -// Debug-only checking. -#ifdef NDEBUG -#define DCHECK(x) \ - while (false) CHECK(x) -#define DCHECK_LT(x, y) \ - while (false) CHECK((x) < (y)) -#define DCHECK_GT(x, y) \ - while (false) CHECK((x) > (y)) -#define DCHECK_LE(x, y) \ - while (false) CHECK((x) <= (y)) -#define DCHECK_GE(x, y) \ - while (false) CHECK((x) >= (y)) -#define DCHECK_EQ(x, y) \ - while (false) CHECK((x) == (y)) -#define DCHECK_NE(x, y) \ - while (false) CHECK((x) != (y)) -#else -#define DCHECK(x) CHECK(x) -#define DCHECK_LT(x, y) CHECK((x) < (y)) -#define DCHECK_GT(x, y) CHECK((x) > (y)) -#define DCHECK_LE(x, y) CHECK((x) <= (y)) -#define DCHECK_GE(x, y) CHECK((x) >= (y)) -#define DCHECK_EQ(x, y) CHECK((x) == (y)) -#define DCHECK_NE(x, y) CHECK((x) != (y)) -#endif // NDEBUG - -#define LOG_INFO dmlc::LogMessage(__FILE__, __LINE__) -#define LOG_ERROR LOG_INFO -#define LOG_WARNING LOG_INFO -#define LOG_FATAL dmlc::LogMessageFatal(__FILE__, __LINE__) -#define LOG_QFATAL LOG_FATAL - -// Poor man version of VLOG -#define VLOG(x) LOG_INFO.stream() - -#define LOG(severity) LOG_##severity.stream() -#define LG LOG_INFO.stream() -#define LOG_IF(severity, condition) \ - !(condition) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) - -#ifdef NDEBUG -#define LOG_DFATAL LOG_ERROR -#define DFATAL ERROR -#define DLOG(severity) true ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) -#define DLOG_IF(severity, condition) \ - (true || !(condition)) ? (void)0 : dmlc::LogMessageVoidify() & LOG(severity) -#else -#define LOG_DFATAL LOG_FATAL -#define DFATAL FATAL -#define DLOG(severity) LOG(severity) -#define DLOG_IF(severity, condition) LOG_IF(severity, condition) -#endif - -// Poor man version of LOG_EVERY_N -#define LOG_EVERY_N(severity, n) LOG(severity) - -class DateLogger { - public: - DateLogger() { -#if defined(_MSC_VER) - _tzset(); -#endif - } - const char* HumanDate() { -#if defined(_MSC_VER) - _strtime_s(buffer_, sizeof(buffer_)); -#else - time_t time_value = time(NULL); - struct tm now; - localtime_r(&time_value, &now); - snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d", now.tm_hour, - now.tm_min, now.tm_sec); -#endif - return buffer_; - } - private: - char buffer_[9]; -}; - -class LogMessage { - public: - LogMessage(const char* file, int line) - : -#ifdef __ANDROID__ - log_stream_(std::cout) -#else - log_stream_(std::cerr) -#endif - { - log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" - << line << ": "; - } - ~LogMessage() { log_stream_ << "\n"; } - std::ostream& stream() { return log_stream_; } - - protected: - std::ostream& log_stream_; - - private: - DateLogger pretty_date_; - LogMessage(const LogMessage&); - void operator=(const LogMessage&); -}; - -#if DMLC_LOG_FATAL_THROW == 0 -class LogMessageFatal : public LogMessage { - public: - LogMessageFatal(const char* file, int line) : LogMessage(file, line) {} - ~LogMessageFatal() { - log_stream_ << "\n"; - abort(); - } - - private: - LogMessageFatal(const LogMessageFatal&); - void operator=(const LogMessageFatal&); -}; -#else -class LogMessageFatal { - public: - LogMessageFatal(const char* file, int line) { - log_stream_ << "[" << pretty_date_.HumanDate() << "] " << file << ":" - << line << ": "; - } - std::ostringstream &stream() { return log_stream_; } - ~LogMessageFatal() DMLC_THROW_EXCEPTION { - // throwing out of destructor is evil - // hopefully we can do it here - throw Error(log_stream_.str()); - } - - private: - std::ostringstream log_stream_; - DateLogger pretty_date_; - LogMessageFatal(const LogMessageFatal&); - void operator=(const LogMessageFatal&); -}; -#endif - -// This class is used to explicitly ignore values in the conditional -// logging macros. This avoids compiler warnings like "value computed -// is not used" and "statement has no effect". -class LogMessageVoidify { - public: - LogMessageVoidify() {} - // This has to be an operator with a precedence lower than << but - // higher than "?:". See its usage. - void operator&(std::ostream&) {} -}; - -} // namespace dmlc - -#endif -#endif // DMLC_LOGGING_H_ -#endif // MSHADOW_LOGGING_H_ - diff --git a/include/mshadow/packet-inl.h b/include/mshadow/packet-inl.h deleted file mode 100644 index f5a89bfa8421..000000000000 --- a/include/mshadow/packet-inl.h +++ /dev/null @@ -1,413 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file packet-inl.h - * \brief Generic packet vectorization code - */ -#ifndef MSHADOW_PACKET_INL_H_ -#define MSHADOW_PACKET_INL_H_ - -#ifdef __APPLE__ -#include -#else -#include -#endif -#include "./base.h" -#include "./tensor.h" -#include "./expression.h" - - -namespace mshadow { -/*! \brief namespace of packet math*/ -namespace packet { - -enum PacketArch { - kPlain, - kSSE2, -}; - -#if MSHADOW_USE_SSE -#define MSHADOW_DEFAULT_PACKET ::mshadow::packet::kSSE2 -#else -#define MSHADOW_DEFAULT_PACKET ::mshadow::packet::kPlain -#endif - -// whether packet operator is enabled. -/*! - * \brief Generic packet type - * \tparam DType The data type of the packet. - * \tparam Arch the Arch of the packet. - */ -template -struct Packet; - -template -struct AlignBytes { - static const index_t value = 4; -}; - -} // namespace packet -} // namespace mshadow - -namespace mshadow { -namespace packet { -/*! - * \brief analog to cudaMallocPitch, allocate a aligned space with num_line * lspace cells - * \param out_pitch output parameter, the actuall space allocated for each line - * \param lspace number of cells required for each line - * \param num_line number of lines to be allocated - */ -inline void* AlignedMallocPitch(size_t *out_pitch, - size_t lspace, - size_t num_line) { - const index_t bits = AlignBytes::value; - const index_t mask = (1 << bits) - 1; - - size_t pitch = ((lspace + mask) >> bits) << bits; - *out_pitch = pitch; -#ifdef _MSC_VER - void *res = _aligned_malloc(pitch * num_line, 1 << bits); -#else - void *res; - int ret = posix_memalign(&res, 1 << bits, pitch * num_line); - CHECK_EQ(ret, 0) << "AlignedMallocPitch failed"; -#endif - if (res == NULL) { - LOG(FATAL) << "AlignedMallocPitch failed"; - } - return res; -} - -/*! - * \brief free aligned space - * \param ptr pointer to space to be freed - */ -inline void AlignedFree(void *ptr) { -#ifdef _MSC_VER - _aligned_free(ptr); -#else - free(ptr); -#endif -} - -/*! \brief check if a pointer is aligned */ -template -inline bool CheckAlign(size_t pitch) { - const index_t bits = AlignBytes::value; - return !(pitch & ((1 << bits) - 1)); -} - -/*! \brief check if a pointer is aligned */ -template -inline bool CheckAlign(void *ptr) { - return CheckAlign(reinterpret_cast(ptr)); -} - -/*! - * \brief get upper bound of aligned index of size - * \param size size of the array - * \param fsize size of float - */ -template -inline index_t UpperAlign(index_t size) { - const index_t bits = AlignBytes::value; - const index_t mask = (1 << bits) - 1; - const index_t fsize = sizeof(DType); - return (((size * fsize + mask) >> bits) << bits) / fsize; -} - -/*! - * \brief get lower bound of aligned index of size - * \param size size of the array - * \param fsize size of float - */ -template -inline index_t LowerAlign(index_t size) { - const index_t bits = AlignBytes::value; - const index_t fsize = sizeof(DType); - return (((size * fsize) >> bits) << bits) / fsize; -} - -/*! - * \brief generic Packet operator - * \tparam OP The operator - * \tparam DType The data type - * \tparam Arch The architecture. - */ -template -struct PacketOp { - static const bool kEnabled = false; -}; -// specialization of operators -template -struct PacketOp { - static const bool kEnabled = true; - MSHADOW_CINLINE static Packet Map(const Packet& lhs, - const Packet& rhs) { - return lhs + rhs; - } -}; -template -struct PacketOp { - static const bool kEnabled = true; - MSHADOW_CINLINE static Packet Map(const Packet& lhs, - const Packet& rhs) { - return lhs - rhs; - } -}; -template -struct PacketOp { - static const bool kEnabled = true; - MSHADOW_CINLINE static Packet Map(const Packet& lhs, - const Packet& rhs) { - return lhs * rhs; - } -}; -template -struct PacketOp { - static const bool kEnabled = true; - MSHADOW_CINLINE static Packet Map(const Packet& lhs, - const Packet& rhs) { - return lhs / rhs; - } -}; - -template -struct PacketOp { - static const bool kEnabled = true; - MSHADOW_CINLINE static Packet Map(const Packet& src) { - return src; - } -}; - - -// savers to do storage -template -struct Saver{ - MSHADOW_CINLINE static void Save(TFloat *dst, const Packet& src) { - Packet lhs = Packet::Load(dst); - Packet ans = PacketOp::Map(lhs, src); - ans.Store(dst); - } -}; -template -struct Saver { - MSHADOW_CINLINE static void Save(TFloat *dst, const Packet& src) { - src.Store(dst); - } -}; -} // namespace packet -} // namespace mshadow - -#include "packet/plain-inl.h" -#if MSHADOW_USE_SSE && !defined(__CUDACC__) -#include "packet/sse-inl.h" -#endif - -namespace mshadow { -namespace expr { - -typedef packet::PacketArch PacketArch; - -// same as plan, but use packet -template -class PacketPlan { - public: - /*! - * \brief evaluate the expression at index [y][x], - * x will be aligned to Packet::Size() - */ - MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const; - MSHADOW_CINLINE DType Eval(index_t y, index_t x) const; -}; - -template -class PacketPlan, DType, Arch> { - public: - explicit PacketPlan(const Tensor &t) - :dptr_(t.dptr_), stride_(t.stride_) {} - MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const { - return packet::Packet::Load(&dptr_[y * stride_ + x]); - } - MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { - return dptr_[y * stride_ + x]; - } - - private: - const DType *dptr_; - index_t stride_; -}; - -template -class PacketPlan, DType, Arch> { - public: - explicit PacketPlan(DType scalar) : scalar_(scalar) {} - MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const { - return packet::Packet::Fill(scalar_); - } - MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { - return scalar_; - } - - private: - DType scalar_; -}; - -template -class PacketPlan, DType, Arch> { - public: - PacketPlan(const PacketPlan &lhs, const PacketPlan &rhs) - : lhs_(lhs), rhs_(rhs) {} - MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const { - return packet::PacketOp::Map(lhs_.EvalPacket(y, x), rhs_.EvalPacket(y, x)); - } - MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { - return OP::Map(lhs_.Eval(y, x), rhs_.Eval(y, x)); - } - - private: - PacketPlan lhs_; - PacketPlan rhs_; -}; - -template -class PacketPlan, DType, Arch> { - public: - PacketPlan(const PacketPlan &src) : src_(src) {} - MSHADOW_CINLINE packet::Packet EvalPacket(index_t y, index_t x) const { - return packet::PacketOp::Map(src_.EvalPacket(y, x)); - } - MSHADOW_CINLINE DType Eval(index_t y, index_t x) const { - return OP::Map(src_.Eval(y, x)); - } - - private: - PacketPlan src_; -}; - -template -inline PacketPlan, DType, Arch> -MakePacketPlan(const BinaryMapExp &e); - -template -inline PacketPlan, DType, Arch> MakePacketPlan(const ScalarExp &e) { - return PacketPlan, DType, Arch>(e.scalar_); -} -template -inline PacketPlan MakePacketPlan(const RValueExp &e) { - return PacketPlan(e.self()); -} -template -inline PacketPlan -MakePacketPlan(const MakeTensorExp &e) { - return PacketPlan(e.real_self()); -} -template -inline PacketPlan, DType, Arch> -MakePacketPlan(const UnaryMapExp &e) { - return PacketPlan, DType, Arch>(MakePacketPlan(e.src_)); -} -template -inline PacketPlan, DType, Arch> -MakePacketPlan(const BinaryMapExp &e) { - return PacketPlan, - DType, Arch>(MakePacketPlan(e.lhs_), MakePacketPlan(e.rhs_)); -} - -/*! - * \brief static check packet enable - * - * \tparam Device the type of Device - * \tparam dim dimension of the tensor - * \tparam E expression - */ -template -struct PacketCheck{ - static const bool kPass = false; -}; -template -struct PacketCheck { - static const bool kPass = true; -}; -template -struct PacketCheck { - static const bool kPass = true; -}; -template -struct PacketCheck, Arch> { - static const bool kPass = PacketCheck::kPass; -}; -template -struct PacketCheck, Arch> { - static const bool kPass = PacketCheck::kPass; -}; -template -struct PacketCheck, Arch> { - static const bool kPass = PacketCheck::kPass && - packet::PacketOp::kEnabled; -}; -template -struct PacketCheck< BinaryMapExp, Arch> { - static const bool kPass = packet::PacketOp::kEnabled && - PacketCheck::kPass && PacketCheck::kPass; -}; -//---------------------------------------------------- -// Check if data is aligned and allow packet operation -//---------------------------------------------------- -template -struct PacketAlignCheck { - inline static bool Check(const E &exp) { - return false; - } -}; -template -struct PacketAlignCheck, Arch> { - inline static bool Check(const ScalarExp &exp) { - return true; - } -}; -template -struct PacketAlignCheck, Arch> { - inline static bool Check(const Tensor &t) { - return packet::CheckAlign(t.dptr_) && - packet::CheckAlign(t.stride_ * sizeof(DType)); - } -}; -template -struct PacketAlignCheck, Arch> { - inline static bool Check(const UnaryMapExp &t) { - return PacketAlignCheck::Check(t.src_); - } -}; -template -struct PacketAlignCheck, Arch> { - inline static bool Check(const BinaryMapExp &t) { - return PacketAlignCheck::Check(t.lhs_) && - PacketAlignCheck::Check(t.rhs_); - } -}; - -/*! - * \brief use PacketPlan to compute result - */ -template -inline void MapPacketPlan(Tensor _dst, - const expr::PacketPlan& plan) { - Tensor dst = _dst.FlatTo2D(); - const index_t xlen = packet::LowerAlign(dst.size(1)); - const size_t packetSize = packet::Packet::size; -#ifndef __CUDACC__ - #pragma omp parallel for -#endif - for (openmp_index_t y = 0; y < dst.size(0); ++y) { - for (index_t x = 0; x < xlen; x += packetSize) { - packet::Saver::Save(&dst[y][x], plan.EvalPacket(y, x)); - } - for (index_t x = xlen; x < dst.size(1); ++x) { - SV::Save(dst[y][x], plan.Eval(y, x)); - } - } -} -} // namespace expr -} // namespace mshadow -#endif // MSHADOW_PACKET_INL_H_ diff --git a/include/mshadow/packet/plain-inl.h b/include/mshadow/packet/plain-inl.h deleted file mode 100644 index de28ad7b4894..000000000000 --- a/include/mshadow/packet/plain-inl.h +++ /dev/null @@ -1,76 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file plain-inl.h - * \brief support of plain packet that use the plain datatype. - */ -#ifndef MSHADOW_PACKET_PLAIN_INL_H_ -#define MSHADOW_PACKET_PLAIN_INL_H_ - -#include "../base.h" -#include "../packet-inl.h" - -namespace mshadow { -namespace packet { -template -struct Packet { - public: - /*! \brief number of float in vector */ - static constexpr index_t size = 1; - /*! \brief The internal data */ - DType data_; - // enable default copy constructor - Packet(void) {} - // constructor from the intrinsic type - explicit Packet(DType data) : data_(data) {} - // create a fill with the target value s - MSHADOW_CINLINE static Packet Fill(DType s) { - return Packet(s); - } - // load from address - MSHADOW_CINLINE static Packet Load(const DType* src) { - return Packet(*src); - } - // load from address - MSHADOW_CINLINE static Packet LoadUnAligned(const DType* src) { - return Packet(*src); - } - // fill it with value s - MSHADOW_CINLINE Packet& operator=(DType s) { - data_ = s; - return *this; - } - // store data into dst - MSHADOW_CINLINE void Store(DType* dst) const { - *dst = data_; - } - // get the sum of all contents - MSHADOW_CINLINE DType Sum() const { - return data_; - } -}; - -template -MSHADOW_CINLINE Packet operator+(const Packet& lhs, - const Packet& rhs) { - return Packet(lhs.data_ + rhs.data_); -} - -template -MSHADOW_CINLINE Packet operator-(const Packet& lhs, - const Packet& rhs) { - return Packet(lhs.data_ - rhs.data_); -} -template -MSHADOW_CINLINE Packet operator*(const Packet& lhs, - const Packet& rhs) { - return Packet(lhs.data_ * rhs.data_); -} - -template -MSHADOW_CINLINE Packet operator/(const Packet& lhs, - const Packet& rhs) { - return Packet(lhs.data_ / rhs.data_); -} -} // namespace packet -} // namespace mshadow -#endif // MSHADOW_PACKET_PLAIN_INL_H_ diff --git a/include/mshadow/packet/sse-inl.h b/include/mshadow/packet/sse-inl.h deleted file mode 100644 index 923a5f60de38..000000000000 --- a/include/mshadow/packet/sse-inl.h +++ /dev/null @@ -1,147 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file sse-inl.h - * \brief support of sse2 packet optimization of some operations - * \author Tianqi Chen - */ -#ifndef MSHADOW_PACKET_SSE_INL_H_ -#define MSHADOW_PACKET_SSE_INL_H_ - -#include -#include "../base.h" -#include "../packet-inl.h" - -namespace mshadow { -namespace packet { -template<> -struct Packet { - public: - /*! \brief number of float in vector */ - static constexpr index_t size = 4; - /*! \brief The internal data */ - __m128 data_; - // enable default copy constructor - Packet(void) {} - // constructor from the intrinsic type - explicit Packet(__m128 data) : data_(data) {} - // create a fill with the target value s - MSHADOW_CINLINE static Packet Fill(float s) { - return Packet(_mm_set1_ps(s)); - } - // load from address - MSHADOW_CINLINE static Packet Load(const float* src) { - return Packet(_mm_load_ps(src)); - } - // load from address - MSHADOW_CINLINE static Packet LoadUnAligned(const float* src) { - return Packet(_mm_loadu_ps(src)); - } - // fill it with value s - MSHADOW_CINLINE Packet& operator=(float s) { - data_ = _mm_set1_ps(s); - return *this; - } - // store data into dst - MSHADOW_CINLINE void Store(float* dst) const { - _mm_store_ps(dst, data_); - } - // get the sum of all contents - MSHADOW_CINLINE float Sum() const { - __m128 ans = _mm_add_ps(data_, _mm_movehl_ps(data_, data_)); - __m128 rst = _mm_add_ss(ans, _mm_shuffle_ps(ans, ans, 1)); -#if defined(_MSC_VER) && (_MSC_VER <= 1500) && defined(_WIN64) - return rst.m128_f32[0]; -#else - float rr = _mm_cvtss_f32(rst); - return rr; -#endif - } -}; - - -/*! \brief vector real type for float */ -template<> -struct Packet { - /*! \brief number of float in vector */ - static constexpr index_t size = 2; - // internal data - __m128d data_; - // constructor - Packet(void) {} - explicit Packet(__m128d data) : data_(data) {} - // create a fill with the target value s - MSHADOW_CINLINE static Packet Fill(double s) { - return Packet(_mm_set1_pd(s)); - } - // load from address - MSHADOW_CINLINE static Packet Load(const double* src) { - return Packet(_mm_load_pd(src)); - } - MSHADOW_CINLINE static Packet LoadUnAligned(const double* src) { - return Packet(_mm_loadu_pd(src)); - } - // fill it with value s - MSHADOW_CINLINE Packet& operator=(double s) { - data_ = _mm_set1_pd(s); - return *this; - } - // store data into dst - MSHADOW_CINLINE void Store(double* dst) const { - _mm_store_pd(dst, data_); - } - // get sum of all content - inline double Sum(void) const { - __m128d tmp = _mm_add_sd(data_, _mm_unpackhi_pd(data_, data_)); -#if defined(_MSC_VER) && (_MSC_VER <= 1500) && defined(_WIN64) - return tmp.m128d_f64[0]; -#else - double ans = _mm_cvtsd_f64(tmp); - return ans; -#endif - } -}; - -MSHADOW_CINLINE Packet operator+(const Packet& lhs, - const Packet& rhs) { - return Packet(_mm_add_ps(lhs.data_, rhs.data_)); -} - -MSHADOW_CINLINE Packet operator+(const Packet& lhs, - const Packet& rhs) { - return Packet(_mm_add_pd(lhs.data_, rhs.data_)); -} - -MSHADOW_CINLINE Packet operator-(const Packet& lhs, - const Packet& rhs) { - return Packet(_mm_sub_ps(lhs.data_, rhs.data_)); -} - -MSHADOW_CINLINE Packet operator-(const Packet& lhs, - const Packet& rhs) { - return Packet(_mm_sub_pd(lhs.data_, rhs.data_)); -} - -MSHADOW_CINLINE Packet operator*(const Packet& lhs, - const Packet& rhs) { - return Packet(_mm_mul_ps(lhs.data_, rhs.data_)); -} - -MSHADOW_CINLINE Packet operator*(const Packet& lhs, - const Packet& rhs) { - return Packet(_mm_mul_pd(lhs.data_, rhs.data_)); -} - - -MSHADOW_CINLINE Packet operator/(const Packet& lhs, - const Packet& rhs) { - return Packet(_mm_div_ps(lhs.data_, rhs.data_)); -} - -MSHADOW_CINLINE Packet operator/(const Packet& lhs, - const Packet& rhs) { - return Packet(_mm_div_pd(lhs.data_, rhs.data_)); -} - -} // namespace packet -} // namespace mshadow -#endif // MSHADOW_PACKET_SSE_INL_H_ diff --git a/include/mshadow/random.h b/include/mshadow/random.h deleted file mode 100644 index c136f4f67809..000000000000 --- a/include/mshadow/random.h +++ /dev/null @@ -1,570 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file random.h - * \brief Random inline functions for tensor. - * \author Bing Xu, Tianqi Chen - * Based on curand|MKL|stdlib - */ -#ifndef MSHADOW_RANDOM_H_ -#define MSHADOW_RANDOM_H_ - -#include -#include -#include -#include "./base.h" -#include "./tensor.h" -#include "./tensor_container.h" - -#if MSHADOW_IN_CXX11 -#include // use cxx11 random by default -#endif - -#if _MSC_VER -#define rand_r(x) rand() -#endif - - -namespace mshadow { -/*! - * \brief random number generator - * \tparam Device the device of random number generator - * \tparam DType the target data type of random number can be float for double - */ -template -class Random {}; - -/*! \brief CPU random number generator */ -template -class Random { - public: - /*! - * \brief constructor of random engine - * \param seed random number seed - */ - explicit Random(int seed) { - this->Seed(seed); - buffer_.Resize(Shape1(kRandBufferSize)); - } - ~Random(void) { - } - /*! - * \brief seed random number generator using this seed - * \param seed seed of prng - */ - inline void Seed(int seed) { -#if MSHADOW_IN_CXX11 - rnd_engine_.seed(seed); -#endif - this->rseed_ = static_cast(seed); - } - /*! - * \brief get random seed used in random generator - * \return seed in unsigned - */ - inline unsigned GetSeed() const { - return rseed_; - } - /*! - * \brief set the stream of computation - * \param stream computation stream - */ - inline void set_stream(Stream *stream) { - } - -// These samplers are only avail in C++11. -#if MSHADOW_IN_CXX11 - - /*! - * \brief get some random integer - * \return integer as unsigned - */ - inline unsigned GetRandInt() { - return rnd_engine_(); - } - - /*! - * \brief get a set of random integers - */ - inline void GetRandInt(const Tensor& dst) { - std::generate_n(dst.dptr_, dst.size(0), [&](){ return rnd_engine_(); }); - } - - /*! - * \brief generate data from a distribution - * \param dst destination - * \tparam dim dimension of tensor - * \param sampler sampler of the distribution - */ - template - inline void SampleDistribution(Tensor *dst, Sampler sampler) { - if (dst->CheckContiguous()) { - std::generate_n(dst->dptr_, dst->shape_.Size(), sampler); - } else { - Tensor mat = dst->FlatTo2D(); - for (index_t i = 0; i < mat.size(0); ++i) { - std::generate_n(mat[i].dptr_, mat.size(1), sampler); - } - } - } - - /*! - * \brief generate data from uniform [a,b) - * \param dst destination - * \param a lower bound of uniform - * \param b upper bound of uniform - * \tparam dim dimension of tensor - */ - template - inline void SampleUniform(Tensor *dst, - PType a = 0.0f , PType b = 1.0f ) { - // Ensure that half_t is handled correctly. - typedef typename std::conditional::value, - DType, double>::type FType; - typedef typename std::conditional::value, - std::uniform_int_distribution, - std::uniform_real_distribution>::type GType; - GType dist_uniform(a, b); - SampleDistribution(dst, [&](){ return dist_uniform(rnd_engine_);}); - } - - /*! - * \brief generate data from standard gaussian - * \param dst destination - * \param mu mean variable - * \param sigma standard deviation - * \tparam dim dimension of tensor - */ - template - inline void SampleGaussian(Tensor *dst, - PType mu = 0.0f, PType sigma = 1.0f ) { - if (sigma <= 0) { - *dst = mu; return; - } - typedef typename std::conditional::value, - DType, double>::type GType; - std::normal_distribution dist_normal(mu, sigma); - SampleDistribution(dst, [&](){ return dist_normal(rnd_engine_);}); - } - - /*! - * \brief generate data from a gamma distribution - * \param dst destination - * \param alpha (shape) parameter - * \param beta (scale) parameter - * \tparam dim dimension of tensor - */ - template - inline void SampleGamma(Tensor *dst, - PType alpha, PType beta) { - typedef typename std::conditional::value, - DType, double>::type GType; - std::gamma_distribution dist_gamma(alpha, beta); - SampleDistribution(dst, [&](){ return dist_gamma(rnd_engine_);}); - } - - /*! - * \brief generate data from an exponential distribution - * \param dst destination - * \param lambda parameter (rate) of the distribution - * \tparam dim dimension of tensor - */ - template - inline void SampleExponential(Tensor *dst, PType lambda ) { - typedef typename std::conditional::value, - DType, double>::type GType; - std::exponential_distribution dist_exp(lambda); - SampleDistribution(dst, [&](){ return dist_exp(rnd_engine_);}); - } - - /*! - * \brief generate data from a poisson distribution - * \param dst destination - * \param lambda parameter (rate) of the distribution - * \tparam dim dimension of tensor - */ - template - inline void SamplePoisson(Tensor *dst, PType lambda) { - typedef typename std::conditional::value, DType, int>::type GType; - std::poisson_distribution dist_poisson(lambda); - SampleDistribution(dst, [&](){ return static_cast(dist_poisson(rnd_engine_));}); - } - - /*! - * \brief generate data from a negative binomial distribution - * \param dst destination - * \param k limit on number of failures - * \param p success probability - * \tparam dim dimension of tensor - */ - template - inline void SampleNegativeBinomial(Tensor *dst, PType1 k, PType2 p) { - typedef typename std::conditional::value, DType, int>::type GType; - std::negative_binomial_distribution dist_negbinomial(k, p); - SampleDistribution(dst, [&](){ return static_cast(dist_negbinomial(rnd_engine_));}); - } - - /*! - * \brief generate data from a generalized negative binomial distribution - * \param dst destination - * \param mu parameter (mean) of the distribution - * \param alpha parameter (over dispersion) of the distribution - * (for alpha=0 this gives a Poisson) - * \tparam dim dimension of tensor - */ - template - inline void SampleGeneralizedNegativeBinomial(Tensor *dst, - PType mu, PType alpha) { - if (alpha == PType(0)) { - SamplePoisson(dst, mu); // limit of Poisson - } else { - PType r(PType(1) / alpha); - PType beta = mu * alpha; - std::gamma_distribution<> dist_gamma(r, beta); - typedef typename std::conditional::value, DType, int>::type GType; - SampleDistribution(dst, - [&](){ std::poisson_distribution dist_poisson(dist_gamma(rnd_engine_)); - return static_cast(dist_poisson(rnd_engine_));}); - } - } -#endif - - /*! - * \brief return a temporal expression storing standard gaussian random variables - * the temporal tensor is only valid before next call of gaussian or uniform - * can be used as part of expression - * Caution: this means expression such as A = gaussian(s1) * gaussian(s2) will give invalid result, - * since second call of gaussian(s2) makes gaussian(s1) invalid - * A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression - * \param shape shape of the tensor - * \return a temporal expression storing standard gaussian random variables - * \tparam dim dimension of tensor - */ - template - inline expr::ReshapeExp, DType, dim, 1> - gaussian(Shape shape) { - buffer_.Resize(Shape1(shape.Size())); - this->SampleGaussian(&buffer_, 0.0f, 1.0f); - return expr::reshape(buffer_, shape); - } - /*! - * \brief return a temporal expression storing standard uniform [0,1) - * the temporal tensor is only valid before next call of gaussian or uniform - * can be used as part of expression - * Caution: this means expression such as A = uniform(s1) * uniform(s2) will give invalid result, - * since second call of gaussian(s2) makes gaussian(s1) invalid - * A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression - * \param shape shape of the tensor - * \return a temporal expression storing standard uniform [0,1) - * \tparam dim dimension of tensor - */ - template - inline expr::ReshapeExp, DType, dim, 1> - uniform(Shape shape) { - buffer_.Resize(Shape1(shape.Size())); - this->SampleUniform(&buffer_, 0.0f, 1.0f); - return expr::reshape(buffer_, shape); - } - - std::mt19937 &GetRndEngine() { - return rnd_engine_; - } - - private: -#if MSHADOW_IN_CXX11 - /*! \brief use c++11 random engine. */ - std::mt19937 rnd_engine_; - /*! \brief random number seed used in random engine */ - unsigned rseed_; - -#else - - /*! \brief random number seed used by PRNG */ - unsigned rseed_; - // functions - template - inline void SampleUniform(Tensor *dst, - DType a = 0.0f, DType b = 1.0f) { - if (dst->CheckContiguous()) { - this->GenUniform(dst->dptr_, dst->shape_.Size(), a, b); - } else { - Tensor mat = dst->FlatTo2D(); - for (index_t i = 0; i < mat.size(0); ++i) { - this->GenUniform(mat[i].dptr_, mat.size(1), a, b); - } - } - } - template - inline void SampleGaussian(Tensor *dst, - DType mu = 0.0f, DType sigma = 1.0f) { - if (sigma <= 0.0f) { - *dst = mu; return; - } - if (dst->CheckContiguous()) { - this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma); - } else { - Tensor mat = dst->FlatTo2D(); - for (index_t i = 0; i < mat.size(0); ++i) { - this->GenGaussian(mat[i].dptr_, mat.size(1), mu, sigma); - } - } - } - inline void GenUniform(float *dptr, index_t size, float a, float b) { - for (index_t j = 0; j < size; ++j) { - dptr[j] = static_cast(RandNext()) * (b - a) + a; - } - } - inline void GenUniform(double *dptr, index_t size, double a, double b) { - for (index_t j = 0; j < size; ++j) { - dptr[j] = static_cast(RandNext()) * (b - a) + a; - } - } - inline void GenGaussian(float *dptr, index_t size, float mu, float sigma) { - this->GenGaussianX(dptr, size, mu, sigma); - } - inline void GenGaussian(double *dptr, index_t size, double mu, double sigma) { - this->GenGaussianX(dptr, size, mu, sigma); - } - inline void GenGaussianX(DType *dptr, index_t size, DType mu, DType sigma) { - DType g1 = 0.0f, g2 = 0.0f; - for (index_t j = 0; j < size; ++j) { - if ((j & 1) == 0) { - this->SampleNormal2D(&g1, &g2); - dptr[j] = mu + g1 * sigma; - } else { - dptr[j] = mu + g2 * sigma; - } - } - } - /*! \brief get next random number from rand */ - inline DType RandNext(void) { - return static_cast(rand_r(&rseed_)) / - (static_cast(RAND_MAX) + 1.0f); - } - /*! \brief return a real numer uniform in (0,1) */ - inline DType RandNext2(void) { - return (static_cast(rand_r(&rseed_)) + 1.0f) / - (static_cast(RAND_MAX) + 2.0f); - } - /*! - * \brief sample iid xx,yy ~N(0,1) - * \param xx first gaussian output - * \param yy second gaussian output - */ - inline void SampleNormal2D(DType *xx_, DType *yy_) { - DType &xx = *xx_, &yy = *yy_; - DType x, y, s; - do { - x = 2.0f * RandNext2() - 1.0f; - y = 2.0f * RandNext2() - 1.0f; - s = x * x + y * y; - } while (s >= 1.0f || s == 0.0f); - DType t = std::sqrt(-2.0f * std::log(s) / s); - xx = x * t; yy = y * t; - } -#endif - /*! \brief temporal space used to store random numbers */ - TensorContainer buffer_; -}; // class Random - -// only allow GPU PRNG when cuda is enabled -#if MSHADOW_USE_CUDA -/*! \brief GPU random number generator */ -template -class Random { - public: - /*! - * \brief constructor of random engine - * \param seed random number seed - */ - explicit Random(int seed) : gen_(NULL) { - this->Seed(seed); - buffer_.Resize(Shape1(kRandBufferSize)); - } - ~Random(void) MSHADOW_THROW_EXCEPTION { - DeleteGenerator(); - } - /*! - * \brief set the stream of computation - * \param stream computation stream - */ - inline void set_stream(Stream *stream) { - curandStatus_t status; - status = curandSetStream(gen_, Stream::GetStream(stream)); - - CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "set_stream CURAND failed"; - } - /*! - * \brief seed random number generator using this seed - * \param seed seed of prng - */ - inline void Seed(int seed) { - // Create a new rng, either initially or if the RNG type can't reset its offset. - if (gen_ == NULL || (curandSetGeneratorOffset(gen_, 0ULL) != CURAND_STATUS_SUCCESS)) - CreateGenerator(); - // Now set the seed. - curandStatus_t status; - status = curandSetPseudoRandomGeneratorSeed(gen_, static_cast(seed)); - CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Set CURAND seed failed."; - } - /*! - * \brief get a set of random integers - */ - inline void GetRandInt(const Tensor& dst) { - curandStatus_t status = curandGenerate(gen_, dst.dptr_, dst.size(0)); - CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen rand ints failed."; - } - /*! - * \brief generate data from uniform [a,b) - * \param dst destination - * \param a lower bound of uniform - * \param b upper bound of uniform - * \tparam dim dimension of tensor - */ - template - inline void SampleUniform(Tensor *dst, - DType a = 0.0f, DType b = 1.0f); - - /*! - * \brief generate data from standard gaussian - * \param dst destination - * \param mu mean variable - * \param sigma standard deviation - * \tparam dim dimension of tensor - */ - template - inline void SampleGaussian(Tensor *dst, - DType mu = 0.0f, DType sigma = 1.0f); - /*! - * \brief return a temporal expression storing standard gaussian random variables - * the temporal tensor is only valid before next call of gaussian or uniform - * can be used as part of expression - * Caution: this means expression such as A = gaussian(s1) * gaussian(s2) will give invalid result, - * since second call of gaussian(s2) makes gaussian(s1) invalid - * A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression - * \param shape shape of the tensor - * \param mu mean - * \param sigma variance - * \return a temporal expression storing standard gaussian random variables - * \tparam dim dimension of tensor - */ - template - inline expr::ReshapeExp, DType, dim, 1> - gaussian(Shape shape, DType mu = 0.0f, DType sigma = 1.0f); - /*! - * \brief return a temporal expression storing standard uniform [0,1) - * the temporal tensor is only valid before next call of gaussian or uniform - * can be used as part of expression - * Caution: this means expression such as A = gaussian(s1) * gaussian(s2) will give invalid result, - * since second call of gaussian(s2) makes gaussian(s1) invalid - * A = gaussian(s1)*B+C; is correct; use one gaussian/uniform in each expression - * \param shape shape of the tensor - * \return a temporal expression storing standard uniform [0,1) - * \tparam dim dimension of tensor - */ - template - inline expr::ReshapeExp, DType, dim, 1> - uniform(Shape shape); - - private: - inline void GenGaussian(float *dptr, size_t size, float mu, float sigma) { - curandStatus_t status; - status = curandGenerateNormal(gen_, dptr, size, mu, sigma); - CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal float failed." - << " size = " << size - << ",mu = " << mu - << ",sigma = " << sigma; - } - inline void GenGaussian(double *dptr, size_t size, double mu, double sigma) { - curandStatus_t status; - status = curandGenerateNormalDouble(gen_, dptr, size, mu, sigma); - CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Normal double failed." - << " size = " << size - << ",mu = " << mu - << ",sigma = " << sigma; - } - inline void GenUniform(float *dptr, size_t size) { - curandStatus_t status; - status = curandGenerateUniform(gen_, dptr, size); - CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform float failed." - << " size = " << size; - } - inline void GenUniform(double *dptr, size_t size) { - curandStatus_t status; - status = curandGenerateUniformDouble(gen_, dptr, size); - CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform double failed." - << " size = " << size; - } - inline void CreateGenerator() { - if (gen_ != NULL) - DeleteGenerator(); - curandStatus_t status; - status = curandCreateGenerator(&gen_, CURAND_RNG_PSEUDO_DEFAULT); - CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Cannot create CURAND Generator"; - } - inline void DeleteGenerator() { - if (gen_ != NULL) { - curandStatus_t status; - status = curandDestroyGenerator(gen_); - CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Destory CURAND Gen failed"; - gen_ = NULL; - } - } - /*! \brief random number generator */ - curandGenerator_t gen_; - /*! \brief templ buffer */ - TensorContainer buffer_; -}; // class Random -#endif // MSHADOW_USE_CUDA - -#ifdef __CUDACC__ -// implementations that depends on cuda kernels -template -template -inline void Random::SampleUniform( - Tensor *dst, DType a, DType b) { - if (a == 0.0f && b == 1.0f) { - if (dst->CheckContiguous()) { - this->GenUniform(dst->dptr_, dst->shape_.Size()); - } else { - *dst = this->uniform(dst->shape_); - } - } else { - *dst = this->uniform(dst->shape_) * (b - a) + a; - } -} -template -template -inline void Random::SampleGaussian( - Tensor *dst, DType mu, DType sigma) { - // We need to check whether the shape size is even since CuRand supports only normal distribution - // generation of even number of elements. - if (dst->CheckContiguous() && (dst->shape_.Size() % 2 == 0)) { - this->GenGaussian(dst->dptr_, dst->shape_.Size(), mu, sigma); - } else { - *dst = this->gaussian(dst->shape_, mu, sigma); - } -} - -template -template -inline expr::ReshapeExp, DType, dim, 1> -Random::gaussian(Shape shape, DType mu, DType sigma) { - size_t aligned_sz = ((shape.Size() + 1UL) >> 1) << 1; - // allocate alligned size - buffer_.Resize(Shape1(aligned_sz)); - buffer_.Resize(Shape1(shape.Size())); - this->GenGaussian(buffer_.dptr_, aligned_sz, mu, sigma); - return expr::reshape(buffer_, shape); -} - -template -template -inline expr::ReshapeExp, DType, dim, 1> -Random::uniform(Shape shape) { - buffer_.Resize(Shape1(shape.Size())); - this->GenUniform(buffer_.dptr_, buffer_.size(0)); - return expr::reshape(buffer_, shape); -} -#endif // __CUDACC__ -} // namespace mshadow -#endif // MSHADOW_RANDOM_H_ diff --git a/include/mshadow/stream_gpu-inl.h b/include/mshadow/stream_gpu-inl.h deleted file mode 100644 index d20d2d788526..000000000000 --- a/include/mshadow/stream_gpu-inl.h +++ /dev/null @@ -1,212 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file stream_gpu-inl.h - * \brief implementation of GPU code - * \author Bing Xu, Tianqi Chen - */ -#ifndef MSHADOW_STREAM_GPU_INL_H_ -#define MSHADOW_STREAM_GPU_INL_H_ -#include -#include "./base.h" -#include "./tensor.h" -#include "./logging.h" - -namespace mshadow { -#if MSHADOW_USE_CUDA == 1 -// Stream alocation -// actual implementation of GPU stream in CUDA -template<> -struct Stream { - /*! \brief handle state */ - enum HandleState { - NoHandle = 0, - OwnHandle = 1, - }; - /*! \brief cudaStream */ - cudaStream_t stream_; - /*! \brief cublas handle */ - cublasHandle_t blas_handle_; - /*! \brief cusolver handle */ - #if MSHADOW_USE_CUSOLVER == 1 - cusolverDnHandle_t solver_handle_; - #endif - /*! \brief cudnn handle */ - #if MSHADOW_USE_CUDNN == 1 - cudnnHandle_t dnn_handle_; - #endif - /*! \brief cublas handle ownership */ - HandleState blas_handle_ownership_; - /*! \brief cusolver handle ownership */ - HandleState solver_handle_ownership_; - /*! \brief cudnn handle ownership */ - HandleState dnn_handle_ownership_; - /*! \brief cudaDeviceProp */ - cudaDeviceProp prop; - /*! \brief dev id */ - int dev_id; - - Stream(void) - : stream_(0) - , blas_handle_(0) -#if MSHADOW_USE_CUDNN == 1 - , dnn_handle_(0) -#endif - , blas_handle_ownership_(NoHandle) - , solver_handle_ownership_(NoHandle) - , dnn_handle_ownership_(NoHandle) {} - /*! - * \brief wait for all the computation associated - * with this stream to complete - */ - inline void Wait(void) { - MSHADOW_CUDA_CALL(cudaStreamSynchronize(stream_)); - } - /*! - * \brief query whether the the stream is idle - * \return true if the stream is idle and all the job have been completed - */ - inline bool CheckIdle(void) { - cudaError_t err = cudaStreamQuery(stream_); - if (err == cudaSuccess) return true; - if (err == cudaErrorNotReady) return false; - LOG(FATAL) << cudaGetErrorString(err); - return false; - } - /*! - * \brief returns actual cudaStream_t given an input GPU stream pointer - * \param stream pointer to GPU stream - */ - inline static cudaStream_t GetStream(Stream *stream) { - if (stream == NULL) { -#if MSHADOW_FORCE_STREAM - LOG(FATAL) << "Default GPU stream was used when MSHADOW_FORCE_STREAM was on"; -#endif - return 0; - } else { - return stream->stream_; - } - } - /*! - * \brief return actual cublasHandle - * \param pointer to GPU stream - */ - inline static cublasHandle_t GetBlasHandle(Stream *stream) { - if (stream == NULL) { - return 0; - } else { - CHECK_NE(stream->blas_handle_ownership_, NoHandle) - << "No handle exist in source stream"; - return stream->blas_handle_; - } - } - /*! \brief Destory cublas handle if own it */ - inline void DestroyBlasHandle() { - if (blas_handle_ownership_ == OwnHandle) { - cublasStatus_t err = cublasDestroy(blas_handle_); - blas_handle_ownership_ = NoHandle; - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Destory cublas handle failed"; - } - } - /*! \brief Destory original blas handle and create a new one */ - inline void CreateBlasHandle() { - this->DestroyBlasHandle(); - cublasStatus_t err = cublasCreate(&blas_handle_); - blas_handle_ownership_ = OwnHandle; - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Create cublas handle failed"; - } -#if MSHADOW_USE_CUSOLVER == 1 - inline static cusolverDnHandle_t GetSolverHandle(Stream *stream) { - if (stream == NULL) { - return 0; - } else { - CHECK_NE(stream->solver_handle_ownership_, NoHandle) << "No handle exist in source stream"; - return stream->solver_handle_; - } - } -#endif - inline void DestroySolverHandle() { -#if MSHADOW_USE_CUSOLVER == 1 - if (solver_handle_ownership_ == OwnHandle) { - cusolverStatus_t err = cusolverDnDestroy(solver_handle_); - CHECK_EQ(err, CUSOLVER_STATUS_SUCCESS) << "Destory cusolver handle failed"; - } -#endif - } - inline void CreateSolverHandle() { -#if MSHADOW_USE_CUSOLVER == 1 - this->DestroySolverHandle(); - cusolverStatus_t err = cusolverDnCreate(&solver_handle_); - CHECK_EQ(err, CUSOLVER_STATUS_SUCCESS) << "Create cusolver handle failed"; - err = cusolverDnSetStream(solver_handle_, stream_); - CHECK_EQ(err, CUSOLVER_STATUS_SUCCESS) << "Setting cusolver stream failed"; - this->solver_handle_ownership_ = OwnHandle; -#endif - } -// #if MSHADOW_USE_CUDNN && defined(__CUDACC__) -#if MSHADOW_USE_CUDNN == 1 - inline static cudnnHandle_t GetDnnHandle(Stream *stream) { - if (stream == NULL) { - return 0; - } else { - CHECK_NE(stream->dnn_handle_ownership_, NoHandle) << "No handle exist in source stream"; - return stream->dnn_handle_; - } - } -#endif - inline void DestroyDnnHandle() { -// #if MSHADOW_USE_CUDNN && defined(__CUDACC__) -#if MSHADOW_USE_CUDNN == 1 - if (dnn_handle_ownership_ == OwnHandle) { - cudnnStatus_t err = cudnnDestroy(dnn_handle_); - this->dnn_handle_ownership_ = NoHandle; - CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err); - } -#endif - } - inline void CreateDnnHandle() { -// #if MSHADOW_USE_CUDNN == 1 && defined(__CUDACC__) -#if MSHADOW_USE_CUDNN == 1 - this->DestroyDnnHandle(); - cudnnStatus_t err = cudnnCreate(&dnn_handle_); - CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err); - // At this point, we have the resource which may need to be freed - this->dnn_handle_ownership_ = OwnHandle; - err = cudnnSetStream(dnn_handle_, stream_); - CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err); -#endif - } -}; -template<> -inline void DeleteStream(Stream *stream) { - if (stream) { - MSHADOW_CUDA_CALL(cudaStreamDestroy(stream->stream_)); - stream->DestroyBlasHandle(); - stream->DestroySolverHandle(); - stream->DestroyDnnHandle(); - delete stream; - } -} -template<> -inline Stream *NewStream(bool create_blas_handle, - bool create_dnn_handle, - int dev_id) { - // RAII on Cuda exception - struct StreamDeleter { void operator()(Stream *ptr) const { DeleteStream(ptr); } }; - std::unique_ptr, StreamDeleter> st(new Stream()); - MSHADOW_CUDA_CALL(cudaStreamCreate(&st->stream_)); - if (create_blas_handle) { - st->CreateBlasHandle(); - st->CreateSolverHandle(); - } - if (create_dnn_handle) { - st->CreateDnnHandle(); - } - st->dev_id = dev_id; - if (dev_id != -1) { - MSHADOW_CUDA_CALL(cudaGetDeviceProperties(&st->prop, dev_id)); - } - return st.release(); -} -#endif -} // namespace mshadow -#endif // MSHADOW_STREAM_GPU_INL_H_ diff --git a/include/mshadow/tensor.h b/include/mshadow/tensor.h deleted file mode 100755 index f74281d36693..000000000000 --- a/include/mshadow/tensor.h +++ /dev/null @@ -1,1078 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file tensor.h - * \brief header file of tensor data structure and functions - * This lib requires explicit memory allocation and de-allocation - * all the data structure Tensor, Tensor are like handles(pointers), - * no memory allocation is happening during calculation - * - * For STL style tensor, see tensor_container.h - * \author Bing Xu, Tianqi Chen - */ -#ifndef MSHADOW_TENSOR_H_ -#define MSHADOW_TENSOR_H_ -#include -#include -#include "./base.h" -#include "./expression.h" - -namespace mshadow { -/*! \brief device name CPU */ -struct cpu { - /*! \brief whether this device is CPU or not */ - static const bool kDevCPU = true; - /*! \brief device flag number, identifies this device */ - static const int kDevMask = 1 << 0; -}; -/*! \brief device name GPU */ -struct gpu { - /*! \brief whether this device is CPU or not */ - static const bool kDevCPU = false; - /*! \brief device flag number, identifies this device */ - static const int kDevMask = 1 << 1; -}; -template -struct Shape; - -/*! - * \brief allow string printing of the shape - * \param os the output stream - * \param shape the shape - * \return the ostream - */ -template -inline std::ostream &operator<<(std::ostream &os, const Shape &shape); // NOLINT(*) - -/*! - * \brief shape of a tensor - * \tparam dimension dimension of tensor - */ -template -struct Shape { - /*! \brief dimension of current shape */ - static const int kDimension = dimension; - /*! \brief dimension of current shape minus one */ - static const int kSubdim = dimension - 1; - /*! \brief storing the dimension information */ - index_t shape_[kDimension]; - /*! \brief default constructor, do nothing */ - MSHADOW_XINLINE Shape(void) {} - /*! \brief constuctor */ - MSHADOW_XINLINE Shape(const Shape &s) { - #pragma unroll - for (int i = 0; i < kDimension; ++i) { - this->shape_[i] = s[i]; - } - } - /*! - * \brief get corresponding index - * \param idx dimension index - * \return the corresponding dimension size - */ - MSHADOW_XINLINE index_t &operator[](index_t idx) { - return shape_[idx]; - } - /*! - * \brief get corresponding index - * \param idx dimension index - * \return the corresponding dimension size - */ - MSHADOW_XINLINE const index_t &operator[](index_t idx) const { - return shape_[idx]; - } - /*! - * \return whether two shape equals - * \param s the shape to compare against - */ - MSHADOW_XINLINE bool operator==(const Shape &s) const { - #pragma unroll - for (int i = 0; i < kDimension; ++i) { - if (s.shape_[i] != this->shape_[i]) return false; - } - return true; - } - /*! - * \return whether two shape not equal - * \param s the shape to compare against - */ - MSHADOW_XINLINE bool operator!=(const Shape &s) const { - return !(*this == s); - } - /*! - * flatten the tensor, return a 1D shape - * \return the flat 1d shape - */ - MSHADOW_XINLINE Shape<1> FlatTo1D(void) const { - Shape<1> s; - s[0] = this->Size(); - return s; - } - /*! - * flatten the higher dimension to second dimension, return a 2D shape - * \return the flat 2d shape - */ - MSHADOW_XINLINE Shape<2> FlatTo2D(void) const { - Shape<2> s; - s.shape_[1] = this->shape_[kDimension - 1]; - index_t ymax = 1; - #pragma unroll - for (int i = 0; i < kDimension - 1; ++i) { - ymax *= this->shape_[i]; - } - s.shape_[0] = ymax; - return s; - } - /*! \return number of valid elements */ - MSHADOW_XINLINE index_t Size(void) const { - index_t size = this->shape_[0]; - #pragma unroll - for (int i = 1; i < kDimension; ++i) { - size *= this->shape_[i]; - } - return size; - } - /*! - * \return product shape in [dimstart,dimend) - * \param dimstart start dimension - * \param dimend end dimension - */ - MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const { - index_t num = 1; - #pragma unroll - for (int i = dimstart; i < dimend; ++i) { - num *= this->shape_[i]; - } - return num; - } - /*! - * \brief get subshape that takes off largest dimension -v * \return subshape - */ - MSHADOW_XINLINE Shape SubShape(void) const { - Shape s; - // for cuda - #pragma unroll - for (int i = 0; i < kSubdim; ++i) { - s.shape_[i] = this->shape_[i + 1]; - } - return s; - } - /*! - * \brief slice the shape from start to end - * \tparam dimstart start dimension - * \tparam dimend end dimension - * \return the sliced shape - */ - template - MSHADOW_XINLINE Shape Slice(void) const { - Shape s; - #pragma unroll - for (int i = dimstart; i < dimend; ++i) { - s[i - dimstart] = this->shape_[i]; - } - return s; - } - //! \cond Doxygen_Suppress - template - friend std::ostream &operator<<(std::ostream &os, const Shape &shape); // NOLINT(*) - //! \endcond -}; // Shape -//------------------------------------------------ -// useful construction functions to generate shape -//------------------------------------------------- -/*! - * \brief construct a one dimension shape, stride will equal s0 - * \param s0 size of dimension 0 - * \return the shape construction - */ -MSHADOW_XINLINE Shape<1> Shape1(index_t s0) { - Shape<1> s; s[0] = s0; - return s; -} -/*! - * \brief construct a two dimension shape, stride will equal s0 - * \param s0 size of dimension 0 - * \param s1 size of dimension 1 - * \return the shape construction - */ -MSHADOW_XINLINE Shape<2> Shape2(index_t s0, index_t s1) { - Shape<2> s; s[0] = s0; s[1] = s1; - return s; -} -/*! - * \brief construct a three dimension shape, stride will equal s0 - * \param s0 size of dimension 0 - * \param s1 size of dimension 1 - * \param s2 size of dimension 2 - * \return the shape construction - */ -MSHADOW_XINLINE Shape<3> Shape3(index_t s0, index_t s1, index_t s2) { - Shape<3> s; - s[0] = s0; s[1] = s1; s[2] = s2; - return s; -} -/*! - * \brief construct a four dimension shape, stride will equal s0 - * \param s0 size of dimension 0 - * \param s1 size of dimension 1 - * \param s2 size of dimension 2 - * \param s3 size of dimension 3 - * \return the shape construction - */ -MSHADOW_XINLINE Shape<4> Shape4(index_t s0, index_t s1, - index_t s2, index_t s3) { - Shape<4> s; - s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3; - return s; -} -/*! -* \brief construct a five dimension shape, stride will equal s0 -* \param s0 size of dimension 0 -* \param s1 size of dimension 1 -* \param s2 size of dimension 2 -* \param s3 size of dimension 3 -* \param s4 size of dimension 4 -* \return the shape construction -*/ -MSHADOW_XINLINE Shape<5> Shape5(index_t s0, index_t s1, index_t s2, - index_t s3, index_t s4) { - Shape<5> s; - s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3; s[4] = s4; - return s; -} - -/*! -* \brief Convert shape in src_layout to shape in dst_layout -* \param src original shape -* \param src_layout layout of original shape -* \param dst_layout target layout -* \return shape in target layout -*/ -inline Shape<3> ConvertLayout(const Shape<3>& src, int src_layout, int dst_layout) { - Shape<3> dst; - switch (src_layout) { - case kNCW: - dst = src; - break; - case kNWC: - dst[0] = src[0]; - dst[1] = src[2]; - dst[2] = src[1]; - break; - default: - LOG(FATAL) << "Invalid layout for 3d shape " << src_layout; - } - switch (dst_layout) { - case kNCW: - return dst; - case kNWC: - { - index_t tmp = dst[1]; - dst[1] = dst[2]; - dst[2] = tmp; - } - break; - default: - LOG(FATAL) << "Invalid layout for 3d shape " << src_layout; - } - return dst; -} - -/*! -* \brief Convert shape in src_layout to shape in dst_layout -* \param src original shape -* \param src_layout layout of original shape -* \param dst_layout target layout -* \return shape in target layout -*/ -inline Shape<4> ConvertLayout(const Shape<4>& src, int src_layout, int dst_layout) { - Shape<4> dst; - switch (src_layout) { - case kNCHW: - dst = src; - break; - case kNHWC: - dst[0] = src[0]; - dst[2] = src[1]; - dst[3] = src[2]; - dst[1] = src[3]; - break; - default: - LOG(FATAL) << "Invalid layout for 4d shape " << src_layout; - dst = src; // fixes compiler warning - } - Shape<4> dst2; - switch (dst_layout) { - case kNCHW: - return dst; - case kNHWC: - dst2[0] = dst[0]; - dst2[1] = dst[2]; - dst2[2] = dst[3]; - dst2[3] = dst[1]; - break; - default: - LOG(FATAL) << "Invalid layout for 4d shape " << src_layout; - dst2 = src; // fixes compiler warning - } - return dst2; -} - -/*! -* \brief Convert shape in src_layout to shape in dst_layout -* \param src original shape -* \param src_layout layout of original shape -* \param dst_layout target layout -* \return shape in target layout -*/ -inline Shape<5> ConvertLayout(const Shape<5>& src, int src_layout, int dst_layout) { - Shape<5> dst; - switch (src_layout) { - case kNCDHW: - dst = src; - break; - case kNDHWC: - dst[0] = src[0]; - dst[2] = src[1]; - dst[3] = src[2]; - dst[4] = src[3]; - dst[1] = src[4]; - break; - default: - LOG(FATAL) << "Invalid layout for 5d shape " << src_layout; - } - Shape<5> dst2; - switch (dst_layout) { - case kNCDHW: - return dst; - case kNDHWC: - dst2[0] = dst[0]; - dst2[1] = dst[2]; - dst2[2] = dst[3]; - dst2[3] = dst[4]; - dst2[4] = dst[1]; - break; - default: - LOG(FATAL) << "Invalid layout for 5d shape " << src_layout; - } - return dst2; -} - -/*! - * \brief computaion stream structure, used for asynchronous computations - */ -template -struct Stream { - // this is only a dummy implementation for CPU - // for GPU, the actual implementation will be specialized in tensor_gpu-inl.h - /*! - * \brief wait for all the computations associated - * with this stream to complete - */ - inline void Wait(void) {} - /*! - * \brief query whether the the stream is idle - * \return true if the stream is idle and all the jobs have been completed - */ - inline bool CheckIdle(void) { - return true; - } - /*! \brief create a blas handle */ - inline void CreateBlasHandle() {} -}; -/*! - * \brief Tensor RValue, this is the super type of all kinds of possible tensors - * \tparam Container the tensor type - * \tparam Device which device the tensor is on - * \tparam dimension dimension of the tensor - * \tparam DType the type of elements in the tensor - */ -template -struct TRValue: public expr::RValueExp { -}; -// more compact template -/*! - * \brief general tensor - * \tparam Device which device the tensor is on - * \tparam dimension dimension of the tensor - * \tparam DType the type of elements in the tensor - */ -template -struct Tensor: public TRValue, - Device, dimension, DType> { - public: - //-------------------------------- - // struct memembers - //-------------------------------- - /*! \brief whether current type lies in cpu */ - static const bool kDevCPU = Device::kDevCPU; - /*! \brief dimension of subtype */ - static const int kSubdim = dimension - 1; - //-------------------------------- - // struct memembers - //-------------------------------- - /*! \brief pointer to the data */ - DType *dptr_; - /*! \brief shape of the tensor */ - Shape shape_; - /*! - * \brief storing the stride information in x dimension - * this is used to deal with pitch allocation in gpu or sse(align x dimension to 64bit) for efficiency - */ - index_t stride_; - /*! - * \brief stream where the computation lies - * stream is a device dependency concept where each computation - */ - Stream *stream_; - //-------------------------------- - // functions - //-------------------------------- - /*! \brief default constructor */ - MSHADOW_XINLINE Tensor(void) : stream_(NULL) {} - /*! \brief constructor from shape */ - MSHADOW_XINLINE Tensor(const Shape &shape) - : shape_(shape), stream_(NULL) {} - /*! \brief constructor from data pointer and shape, without stride */ - MSHADOW_XINLINE Tensor(DType *dptr, const Shape &shape) - : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(NULL) {} - /*! \brief constructor from data pointer and shape, without stride */ - MSHADOW_XINLINE Tensor(DType *dptr, const Shape &shape, - Stream *stream) - : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(stream) {} - /*! \brief constructor from data pointer and shape */ - MSHADOW_XINLINE Tensor(DType *dptr, - const Shape &shape, - index_t stride, Stream *stream) - : dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {} - /*! - * \brief set the stream to do computation of current tensor - * \param stream the computation stream - */ - inline void set_stream(Stream *stream) { - this->stream_ = stream; - } - /*! - * \return memory cost of the tensor, including the aligned x dimension - * \tparam startdim the starting dimension - */ - template - MSHADOW_XINLINE index_t MemSize(void) const { - index_t memsz = this->stride_; - #pragma unroll - for (int i = startdim; i < kSubdim; ++i) { - memsz *= this->shape_[i]; - } - return memsz; - } - /*! - * \return whether the tensor's memory is continuous - * x dimension same as stride - */ - MSHADOW_XINLINE bool CheckContiguous(void) const { - return this->shape_[dimension - 1] == stride_; - } - /*! - * \return memory cost of the tensor, including the aligned x dimension - */ - MSHADOW_XINLINE index_t MSize(void) const { - return this->MemSize<0>(); - } - /*! - * \brief return size of i-th dimension, start counting from highest dimension - * \param idx the dimension count from the highest dimensin - * \return the size - */ - MSHADOW_XINLINE index_t size(index_t idx) const { - return shape_[idx]; - } - /*! - * \brief flatten the tensor to 1 dimension - * \return tensor after flatten - */ - MSHADOW_XINLINE Tensor FlatTo1D(void) const { - return Tensor(dptr_, shape_.FlatTo1D(), stride_, stream_); - } - /*! - * \brief flatten the tensor to 2 dimension, collapse the higher dimensions together - * \return tensor after flatten - */ - MSHADOW_XINLINE Tensor FlatTo2D(void) const { - return Tensor(dptr_, shape_.FlatTo2D(), stride_, stream_); - } - /*! - * \brief get a element of dimension - 1 - * \param idx index - * \return the result tensor - */ - MSHADOW_XINLINE Tensor operator[](index_t idx) const { - return Tensor(dptr_ + this->MemSize<1>() * idx, - shape_.SubShape(), stride_, stream_); - } - /*! - * \brief slice the tensor in highest dimension [begin,end) - * \param begin begin position of slice - * \param end end position of slice - * \return tensor after slice - */ - MSHADOW_XINLINE Tensor - Slice(index_t begin, index_t end) const { - Shape s = this->shape_; - s[0] = end - begin; - return Tensor(dptr_ + this->MemSize<1>() * begin, - s, stride_, stream_); - } - /*!\brief implement the assignment of same type */ - inline Tensor & - operator=(const Tensor &exp) { - dptr_ = exp.dptr_; - shape_ = exp.shape_; - stride_ = exp.stride_; - stream_ = exp.stream_; - return *this; - } - /*!\brief functions to fit expression template */ - template - inline Tensor & - operator=(const expr::Exp &exp) { - return this->__assign(exp); - } - /*!\brief functions to fit expression template */ - inline Tensor &operator=(const DType &exp) { - return this->__assign(exp); - } -}; -/* - * respecialized class Tensor1D, thei is due to different implementation in operator[] - */ -template -struct Tensor: - public TRValue, Device, 1, DType> { - public: - DType *dptr_; - Shape<1> shape_; - index_t stride_; - Stream *stream_; - // constructor - MSHADOW_XINLINE Tensor(void) : stream_(NULL) {} - MSHADOW_XINLINE Tensor(const Shape<1> &shape) - : shape_(shape), stream_(NULL) {} - MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape) - : dptr_(dptr), shape_(shape), stride_(shape[0]), stream_(NULL) {} - MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape, Stream *stream) - : dptr_(dptr), shape_(shape), stride_(shape[0]), stream_(stream) {} - MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape, - index_t stride, Stream *stream) - : dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {} - inline void set_stream(Stream *stream) { - this->stream_ = stream; - } - MSHADOW_XINLINE Tensor FlatTo1D(void) const { - return *this; - } - MSHADOW_XINLINE Tensor FlatTo2D(void) const { - return Tensor(dptr_, shape_.FlatTo2D(), stride_, stream_); - } - MSHADOW_XINLINE Tensor Slice(index_t begin, index_t end) const { - Shape<1> s; - s[0] = end - begin; - return Tensor(dptr_ + begin, s, s[0], stream_); - } - MSHADOW_XINLINE bool CheckContiguous(void) const { - return true; - } - MSHADOW_XINLINE index_t MSize(void) const { - return shape_[0]; - } - MSHADOW_XINLINE index_t size(index_t i) const { - return shape_[0]; - } - MSHADOW_XINLINE DType &operator[](index_t idx) { - return dptr_[idx]; - } - MSHADOW_XINLINE const DType &operator[](index_t idx) const { - return dptr_[idx]; - } - /*!\brief implement the assignment of same type */ - inline Tensor & - operator=(const Tensor &exp) { - dptr_ = exp.dptr_; - shape_ = exp.shape_; - stride_ = exp.stride_; - stream_ = exp.stream_; - return *this; - } - template - inline Tensor & - operator=(const expr::Exp &exp) { - return this->__assign(exp); - } - inline Tensor &operator=(const DType &exp) { - return this->__assign(exp); - } -}; -//------------------------ -// Function Declarations -//----------------------- -/*! - * \brief initialize tensor engine, used to call intialization functions of dependent libs - * this function should be called before all GPU tensor operations, - * for using tensors in CPU, this call is actually not needed - * \param device_id GPU device id to be choosed - * \tparam Device the device type - */ -template -inline void InitTensorEngine(int device_id = 0); -/*! - * \brief Shutdown tensor engine on current device - * this function should be called after all GPU tensor operations, - * for using tensors in CPU, this call is actually not needed - * \tparam Device the device type - */ -template -inline void ShutdownTensorEngine(void); -/*! - * \brief set the device of current thread to work on - * \param devid the device id - * \tparam Device the device type - */ -template -inline void SetDevice(int devid); -/*! - * \brief create a new stream from system - * \param create_blas_handle whether create blas & cusolver handle in stream - * \param create_dnn_handle whether create cudnn handle in stream - * \param dev_id device id - * \return a pointer to the created stream - * \tparam Device the device type - */ -template -inline Stream *NewStream(bool create_blas_handle, - bool create_dnn_handle, - int dev_id = -1); -/*! \brief default behavior: create cublas handle - * \param dev_id device id - * \return a pointer to the created stream - */ -template -inline Stream *NewStream(int dev_id) { - return NewStream(true, false, dev_id); -} -/*! - * \brief delete the computing stream - * \param stream the stream parameter to be deleted - */ -template -inline void DeleteStream(Stream *stream); -/*! - * \brief CPU/CPU: allocate space for CTensor, according to the shape in the obj - * this function is responsible to set the stride_ in each obj.shape - * \param obj the tensor object, with shape specified - * \param pad whether padding dimension 0, to make last dimension aligned, - * padding may help improve efficiency of matrix multiplications - * if true, will allocate space with stride_ that may not equals shape[0] - * if false, will allocate continuous space - * \tparam dim specify the dim of tensor - * \tparam DType type of element in tensor - */ -template -inline void AllocSpace(Tensor *obj, - bool pad = MSHADOW_ALLOC_PAD); -/*! - * \brief CPU/CPU: allocate space for CTensor, according to the shape in the obj - * this function is responsible to set the stride_ in each obj.shape - * \param obj the tensor object, with shape specified - * \param pad whether padding dimension 0, to make last dimension aligned, - * padding may help improve efficiency of matrix multiplications - * if true, will allocate space with stride_ that may not equals shape[0] - * if false, will allocate continuous space - * \tparam dim specify the dim of tensor - * \tparam DType type of element in tensor - */ -template -inline void AllocSpace(Tensor *obj, - bool pad = MSHADOW_ALLOC_PAD); -/*! - * \brief CPU/GPU: free the space of tensor, will set obj.dptr to NULL - * \param obj the tensor object - * \tparam dim specify the dim of tensor - * \tparam DType type of element in tensor - */ -template -inline void FreeSpace(Tensor *obj); -/*! - * \brief CPU/GPU: free the space of tensor, will set obj.dptr to NULL - * \param obj the tensor object - * \tparam dim specify the dim of tensor - * \tparam DType type of element in tensor - */ -template -inline void FreeSpace(Tensor *obj); -/*! - * \brief CPU/GPU: short cut to allocate and initialize a Tensor - * \param shape: shape of tensor - * \param initv: initialization value - * \param pad : padding option - * \param stream : stream of tensor - * \tparam Device device of tensor - * \tparam DType type of element in tensor - * \tparam dim dimention of tensor - * \return a new allocated tensor - * \sa AllocSpace - */ -template -inline Tensor NewTensor(const Shape &shape, - DType initv, - bool pad = MSHADOW_ALLOC_PAD, - Stream *stream = NULL); -/*! - * \brief copy data from one tensor to another, with same shape - * \param dst target tensor - * \param src source tensor - * \param stream the stream, when specified, the copy can exhibit asynchronize behavior - * \tparam dim specify the dim of tensor - * \tparam DType type of element in tensor - */ -template -inline void Copy(Tensor dst, - const Tensor &src, - Stream *stream = NULL); -/*! - * \brief copy data from one tensor to another, with same shape - * \param dst target tensor - * \param src source tensor - * \param stream the stream, when specified, the copy can exhibit asynchronize behavior - * \tparam dim specify the dim of tensor - * \tparam DType type of element in tensor - */ -template -inline void Copy(Tensor dst, - const Tensor &src, - Stream *stream = NULL); -/*! - * \brief copy data from one tensor to another, with same shape - * \param dst target tensor - * \param src source tensor - * \param stream the stream, when specified, the copy can exhibit asynchronize behavior - * \tparam dim specify the dim of tensor - * \tparam DType type of element in tensor - */ -template -inline void Copy(Tensor dst, - const Tensor &src, - Stream *stream = NULL); -/*! - * \brief copy data from one tensor to another, with same shape - * \param dst target tensor - * \param src source tensor - * \param stream the stream, when specified, the copy can exhibit asynchronize behavior - * \tparam dim specify the dim of tensor - * \tparam DType type of element in tensor - */ -template -inline void Copy(Tensor dst, - const Tensor &src, - Stream *stream = NULL); -/*! - * \brief CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) /(sum_j exp(energy[i][j])) - * \param dst destination - * \param energy input energy - */ -template -inline void Softmax(Tensor dst, const Tensor &energy); -/*! - * \brief CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) /(sum_j exp(energy[i][j])) - * \param dst destination - * \param energy input energy - */ -template -inline void Softmax(Tensor dst, const Tensor &energy); - -/*! - * \brief CPU/GPU: softmax gradient - * \param dst destination - * \param src source output - * \param label label info - */ -template -inline void SoftmaxGrad(Tensor dst, - const Tensor &src, - const Tensor &label); -/*! - * \brief CPU/GPU: softmax gradient - * \param dst destination - * \param src source output - * \param label label info - */ -template -inline void SoftmaxGrad(const Tensor &dst, - const Tensor &src, - const Tensor &label); -/*! - * \brief CPU/GPU: Gradient accumulate of embedding matrix. - dst[index[i]] += src[i] - Called when the featuredim of src is much larger than the batchsize - * \param dst destination - * \param index index to take - * \param src source output - */ -template -inline void AddTakeGrad(Tensor dst, - const Tensor& index, - const Tensor &src); -/*! - * \brief CPU/GPU: Gradient accumulate of embedding matrix. - dst[index[i]] += src[i] - Called when the featuredim of src is much larger than the batchsize - * \param dst destination - * \param index index to take - * \param src source output - */ -template -inline void AddTakeGrad(Tensor dst, - const Tensor& index, - const Tensor &src); -/*! - * \brief CPU/GPU: Gradient accumulate of embedding matrix. - dst[sorted[i]] += src[index[i]] - Called when the batchsize of src is larger than the featuredim - * \param dst destination - * \param sorted the sorted indices - * \param index original index of the sorted indices - * \param src source output - */ -template -inline void AddTakeGradLargeBatch(Tensor dst, - const Tensor& sorted, - const Tensor& index, - const Tensor &src); -/*! - * \brief CPU/GPU: Gradient accumulate of embedding matrix. - dst[sorted[i]] += src[index[i]] - Called when the batchsize of src is larger than the featuredim - * \param dst destination - * \param sorted the sorted indices - * \param index original index of the sorted indices - * \param src source output - */ -template -inline void AddTakeGradLargeBatch(Tensor dst, - const Tensor& sorted, - const Tensor& index, - const Tensor &src); -/*! - * \brief CPU/GPU: Fill the values of the destination matrix to specific rows in the source matrix. - dst[index[i]] = src[i] - Will use atomicAdd in the inner implementation and the result may not be deterministic. - * \param dst destination - * \param index the index to accumulate value - * \param src source output - */ -template -inline void IndexFill(Tensor dst, - const Tensor& index, - const Tensor &src); -/*! - * \brief CPU/GPU: Fill the values of the destination matrix to specific rows in the source matrix. - dst[index[i]] = src[i] - Will use atomicAdd in the inner implementation and the result may not be deterministic. - * \param dst destination - * \param index the index to accumulate value - * \param src source output - */ -template -inline void IndexFill(Tensor dst, - const Tensor& index, - const Tensor &src); -/*! - * \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!) - * \param keys the keys to sort - * \param values the values that sorts w.r.t the key - * \param is_ascend whether to sort key in ascending order - */ -template -inline void SortByKey(Tensor keys, Tensor values, - bool is_ascend = true); -/*! - * \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!) - * \param keys the keys to sort - * \param values the values that sorts w.r.t the key - * \param is_ascend whether to sort key in ascending order - */ -template -inline void SortByKey(Tensor keys, Tensor values, - bool is_ascend = true); -/*! - * \brief CPU/GPU: Sort the keys within each segment. (Stable sort is performed!) - Segments is defined as an ascending ordered vector like [0, 0, 0, 1, 1, 2, 3, 3, 3,...] - We sort separately the keys labeled by 0 and 1, 2, 3, etc. - Currently only supports sorting in ascending order !! - * \param values the data to sort - * \param segments segment indicator - */ -template -inline void VectorizedSort(Tensor values, Tensor segments); - -// function declarations to support expression, no need to understand them -// these functions do not need to be directly used -/*! - * \brief CPU/GPU: map a expression to a tensor, this function calls MapPlan - * \tparam Saver specify storage method - * \tparam R specifies the storage type of the tensor - * \tparam dim dim of the tensor, during usage, there is no need to specify this parameter - * \tparam DType the type of elements in the tensor - * \tparam E specifies the expression type, not need to specify this parameter during usage - * \tparam etype expression type - * \param dst destination - * \param exp expression - * \sa namespace mshadow:sv, mshadow::op, mshadow::expr - */ -template -inline void MapExp(TRValue *dst, - const expr::Exp &exp); -/*! - * \brief CPU/GPU: map a expression to a tensor, this function calls MapPlan - * \tparam Saver specify storage method - * \tparam R specifies the storage type of the tensor - * \tparam dim dim of the tensor, during usage, there is no need to specify this parameter - * \tparam DType the type of elements in the tensor - * \tparam E specifies the expression type, not need to specify this parameter during usage - * \tparam etype expression type - * \param dst destination - * \param exp expression - * \sa namespace mshadow:sv, mshadow::op, mshadow::expr - */ -template -inline void MapExp(TRValue *dst, - const expr::Exp &exp); -/*! - * \brief CPU/GPU: map a expression, do reduction to 1D Tensor in lowest dimension (dimension 0) - * \tparam Saver specify storage method - * \tparam Reducer specify a reducer method - * \tparam R specifies the storage type of the tensor - * \tparam DType the type of elements in the tensor - * \tparam E specifies the expression type, not need to specify this parameter during usage - * \tparam etype expression type - * \param dst destination - * \param exp expression - * \param scale scale the result before save - * \sa namespace mshadow:sv, mshadow::op, mshadow::red, mshadow::expr - */ -template -inline void MapReduceKeepLowest(TRValue *dst, - const expr::Exp &exp, - DType scale = 1); -/*! - * \brief CPU/GPU: map a expression, do reduction to 1D Tensor in lowest dimension (dimension 0) - * \tparam Saver specify storage method - * \tparam Reducer specify a reducer method - * \tparam R specifies the storage type of the tensor - * \tparam DType the type of elements in the tensor - * \tparam E specifies the expression type, not need to specify this parameter during usage - * \tparam etype expression type - * \param dst destination - * \param exp expression - * \param scale scale the result before save - * \sa namespace mshadow:sv, mshadow::op, mshadow::red, mshadow::expr - */ -template -inline void MapReduceKeepLowest(TRValue *dst, - const expr::Exp &exp, - DType scale = 1); -/*! - * \brief CPU/GPU: map a expression, do reduction to 1D Tensor in third dimension (dimension 2) - * \tparam Saver specify storage method - * \tparam Reducer specify a reducer method - * \tparam R specifies the storage type of the tensor - * \tparam DType the type of elements in the tensor - * \tparam dimkeep the target dimension to be kept, should be larger than 0, for 0, use MapReduceKeepLowest - * \tparam E specifies the expression type, not need to specify this parameter during usage - * \tparam etype expression type - * \param dst destination - * \param exp expression - * \param scale scale the result before save - * \sa namespace mshadow:sv, mshadow::op, mshadow::red, mshadow::expr - */ -template -inline void MapReduceKeepHighDim(TRValue *dst, - const expr::Exp &exp, - DType scale = 1); -/*! - * \brief CPU/GPU: map a expression, do reduction to 1D Tensor in third dimension (dimension 2) - * \tparam Saver specify storage method - * \tparam Reducer specify a reducer method - * \tparam R specifies the storage type of the tensor - * \tparam DType the type of elements in the tensor - * \tparam dimkeep the target dimension to be kept, should be larger than 0, for 0, use MapReduceKeepLowest - * \tparam E specifies the expression type, not need to specify this parameter during usage - * \tparam etype expression type - * \param dst destination - * \param exp expression - * \param scale scale the result before save - * \sa namespace mshadow:sv, mshadow::op, mshadow::red, mshadow::expr - */ -template -inline void MapReduceKeepHighDim(TRValue *dst, - const expr::Exp &exp, - DType scale = 1); -/*! - * \brief CPU/GPU: 1 dimension vector dot - * \param dst Length 1 vector, used to hold the result. - * \param lhs Left operand vector - * \param rhs Right operand vector - */ -template -inline void VectorDot(Tensor dst, - const Tensor &lhs, - const Tensor &rhs); -/*! - * \brief CPU/GPU: dst = alpha * op(lhs) op(rhs) + beta * dst - * \param dst Length 3 tensor, used to hold the result - * \param lhs Left operand vector - * \param rhs Right operand vector - * \param alpha multiplier of op(lhs)op(rhs) - * \param beta multiplier of dst - * \param workspace Workspace for casting DType* to DType** (batched-view), must have size >= 3 * batch_size - */ -template -inline void BatchGEMM(Tensor dst, - const Tensor &lhs, - const Tensor &rhs, - DType alpha, - DType beta, - Tensor workspace); -} // namespace mshadow -// include headers -#include "./stream_gpu-inl.h" -#include "./extension.h" -#include "./expr_engine-inl.h" -#include "./tensor_cpu-inl.h" -#include "./tensor_gpu-inl.h" -#include "./io.h" -#include "./tensor_container.h" -#include "./random.h" -// add definition of scalar related operators -#ifdef MSHADOW_SCALAR_ - #error "MSHADOW_SCALAR_ must not be defined" -#endif -// enumerate all the scalar data type we aim to be good at -#define MSHADOW_SCALAR_ float -#include "./expr_scalar-inl.h" -#undef MSHADOW_SCALAR_ -#define MSHADOW_SCALAR_ double -#include "./expr_scalar-inl.h" -#undef MSHADOW_SCALAR_ -#define MSHADOW_SCALAR_ int -#include "./expr_scalar-inl.h" -#undef MSHADOW_SCALAR_ -#define MSHADOW_SCALAR_ mshadow::half::half_t -#include "./expr_scalar-inl.h" -#undef MSHADOW_SCALAR_ -#endif // MSHADOW_TENSOR_H_ diff --git a/include/mshadow/tensor_container.h b/include/mshadow/tensor_container.h deleted file mode 100644 index b4df68e8e3a5..000000000000 --- a/include/mshadow/tensor_container.h +++ /dev/null @@ -1,208 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file tensor_container.h - * \brief tensor container that does memory allocation and resize like STL - * \author Tianqi Chen - */ -#ifndef MSHADOW_TENSOR_CONTAINER_H_ -#define MSHADOW_TENSOR_CONTAINER_H_ -#include "./tensor.h" -#include "./io.h" - -namespace mshadow { -/*! - * \brief tensor container that does memory allocation and resize like STL, - * use it to save the lines of FreeSpace in class. - * Do not abuse it, efficiency can come from pre-allocation and no re-allocation - * - * \tparam Device which device the tensor is on - * \tparam dimension dimension of the tensor - */ -template -class TensorContainer: public Tensor { - public: - /*! - * \brief constructor - * \param pad whether use padding alignment in space allocation - */ - explicit TensorContainer(bool pad = MSHADOW_ALLOC_PAD) { - this->pad_ = pad; - this->dptr_ = data_.dptr_ = NULL; - this->shape_[0] = 0; - this->stride_ = 0; - this->data_.stride_ = 0; - this->data_.shape_[0] = 0; - } - /*! - * \brief constructor - * \param shape intial shape - */ - explicit TensorContainer(const Shape &shape) { - this->pad_ = MSHADOW_ALLOC_PAD; - data_.dptr_ = NULL; - this->AllocByShape(shape); - } - /*! - * \brief constructor - * \param shape intial shape - * \param initv intial value - */ - explicit TensorContainer(const Shape &shape, DType initv) { - this->pad_ = MSHADOW_ALLOC_PAD; - data_.dptr_ = NULL; - this->AllocByShape(shape); - (*this) = initv; - } - /*! - * \brief copy constructor - * \param src source value - */ - TensorContainer - (const TensorContainer &src) - : pad_(src.pad_) { - this->dptr_ = data_.dptr_ = NULL; - this->shape_[0] = 0; - this->stride_ = 0; - this->data_.stride_ = 0; - this->data_.shape_[0] = 0; - this->stream_ = src.stream_; - if (src.dptr_ != NULL) { - this->AllocByShape(src.shape_); - mshadow::Copy(*this, src, this->stream_); - } - } - ~TensorContainer(void) { - this->Release(); - } - /*! - * \brief resize the container to given shape, content is NOT preserved - * \param shape target shape - */ - inline void Resize(const Shape &shape) { - Shape<2> s2 = shape.FlatTo2D(); - if (s2.shape_[1] > data_.stride_ || s2.shape_[0] > data_.size(0)) { - this->AllocByShape(shape); - } else { - this->shape_ = shape; - if (this->pad_) { - this->stride_ = data_.stride_; - } else { - this->stride_ = s2.shape_[1]; - } - } - } - /*! - * \brief resize the container to given shape, and initialize, content is NOT preserved - * \param shape target shape - * \param initv initialization value - */ - inline void Resize(const Shape &shape, DType initv) { - this->Resize(shape); - (*this) = initv; - } - /*! \brief set whether padding is allowed in tensor */ - inline void set_pad(bool pad) { - this->pad_ = pad; - } - /*! - * \brief save by binary format - * \param fo output binary stream - * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. - */ - template - inline void SaveBinary(TStream &fo) const { // NOLINT(*) - mshadow::SaveBinary(fo, *this); - } - /*! - * \brief load by binary format, a temp Tensor storage will be allocated - * \param fi input binary stream - * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream. - */ - template - inline void LoadBinary(TStream &fi) { // NOLINT(*) - Tensor tmp; - mshadow::LoadBinary(fi, &tmp, false); - this->Resize(tmp.shape_); - Stream stream; - Copy(*this, tmp, &stream); - mshadow::FreeSpace(&tmp); - } - /*! - * \brief assign operator from TensorContainer - * \param src source value - * \return reference of self - */ - inline TensorContainer &operator= - (const TensorContainer &src) { - this->pad_ = src.pad_; - this->stream_ = src.stream_; - if (src.dptr_ != NULL) { - this->Resize(src.shape_); - mshadow::Copy(*this, src, this->stream_); - } - return *this; - } - /*!\brief functions to fit expression template */ - inline Tensor &operator=(DType s) { - return this->__assign(s); - } - /*!\brief functions to fit expression template */ - template - inline Tensor & - operator=(const expr::Exp &exp) { - return this->__assign(exp); - } - /*!\brief functions to fit expression template */ - template - inline Tensor & - operator=(const expr::Exp &exp) { - return this->__assign(exp); - } - /*!\brief functions to fit expression template */ - template - inline Tensor & - operator=(const expr::Exp &exp) { - return this->__assign(exp); - } - /*! - * \brief Release the llocated space, - * The TensorContainer is still functionable, - * but will restart allocating space when Resize is called. - */ - inline void Release(void) { - if (data_.dptr_ != NULL) { - this->shape_[0] = 0; - this->stride_ = 0; - this->data_.stride_ = 0; - this->data_.shape_[0] = 0; - try { - mshadow::FreeSpace(&data_); - } catch (const dmlc::Error &e) { - this->dptr_ = data_.dptr_ = NULL; - throw e; - } - this->dptr_ = data_.dptr_ = NULL; - } - } - - private: - /*! \brief whether we do padding in the space */ - bool pad_; - /*! \brief the shape of data_ is actually current data space */ - Tensor data_; - - inline void AllocByShape(const Shape& shape) { - if (data_.dptr_ != NULL) this->Release(); - data_.shape_ = shape.FlatTo2D(); - mshadow::AllocSpace(&data_, pad_); - this->dptr_ = data_.dptr_; - this->shape_ = shape; - if (this->pad_) { - this->stride_ = data_.stride_; - } else { - this->stride_ = data_.size(1); - } - } -}; -} // namespace mshadow -#endif // MSHADOW_TENSOR_CONTAINER_H_ diff --git a/include/mshadow/tensor_cpu-inl.h b/include/mshadow/tensor_cpu-inl.h deleted file mode 100755 index ab5f9a68df14..000000000000 --- a/include/mshadow/tensor_cpu-inl.h +++ /dev/null @@ -1,627 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file tensor_cpu-inl.h - * \brief implementation of CPU host code - * \author Bing Xu, Tianqi Chen - */ -#ifndef MSHADOW_TENSOR_CPU_INL_H_ -#define MSHADOW_TENSOR_CPU_INL_H_ -#include -#include -#include -#include -#include "./base.h" -#include "./tensor.h" -#include "./packet-inl.h" -#include "./dot_engine-inl.h" - -namespace mshadow { -template<> -inline void InitTensorEngine(int dev_id) { -} -template<> -inline void ShutdownTensorEngine(void) { -} - -template<> -inline void SetDevice(int devid) { -} -template<> -inline Stream *NewStream(bool create_blas_handle, - bool create_dnn_handle, - int dev_id) { - return new Stream(); -} -template<> -inline void DeleteStream(Stream *stream) { - delete stream; -} - -template -inline std::ostream &operator<<(std::ostream &os, const Shape &shape) { // NOLINT(*) - os << '('; - for (int i = 0; i < ndim; ++i) { - if (i != 0) os << ','; - os << shape[i]; - } - // python style tuple - if (ndim == 1) os << ','; - os << ')'; - return os; -} - -template -inline void *AllocHost_(size_t size); -template -inline void FreeHost_(void * dptr); - -#ifdef __CUDACC__ -template<> -inline void *AllocHost_(size_t size) { - void *dptr; - MSHADOW_CUDA_CALL(cudaMallocHost(&dptr, size, cudaHostAllocPortable)); - return dptr; -} -template<> -inline void FreeHost_(void *dptr) { - MSHADOW_CUDA_CALL(cudaFreeHost(dptr)); -} -#endif - -template<> -inline void *AllocHost_(size_t size) { - size_t pitch; - return packet::AlignedMallocPitch(&pitch, size, 1); -} -template<> -inline void FreeHost_(void *dptr) { - packet::AlignedFree(dptr); -} - -template -inline void AllocHost(Tensor *obj) { - obj->stride_ = obj->size(dim - 1); - CHECK_EQ(obj->CheckContiguous(), true) << "AllocHost"; - void *dptr = AllocHost_(obj->MSize() * sizeof(DType)); - obj->dptr_ = reinterpret_cast(dptr); -} -template -inline void FreeHost(Tensor *obj) { - if (obj->dptr_ == NULL) { - LOG(FATAL) << "FreeHost:: double free"; - } - FreeHost_(obj->dptr_); - obj->dptr_ = NULL; -} - -template -inline void AllocSpace(Tensor *obj, bool pad) { - size_t pitch; - void *dptr; - if (pad) { - dptr = packet::AlignedMallocPitch - (&pitch, obj->size(dim - 1) * sizeof(DType), obj->shape_.FlatTo2D()[0]); - obj->stride_ = static_cast(pitch / sizeof(DType)); - } else { - obj->stride_ = obj->size(dim - 1); - dptr = packet::AlignedMallocPitch - (&pitch, obj->shape_.Size() * sizeof(DType), 1); - } - obj->dptr_ = reinterpret_cast(dptr); -} -template -inline Tensor -NewTensor(const Shape &shape, DType initv, bool pad, Stream *stream_) { - Tensor obj(shape); - obj.stream_ = stream_; - AllocSpace(&obj, pad); - MapExp(&obj, expr::ScalarExp(initv)); - return obj; -} -template -inline void FreeSpace(Tensor *obj) { - packet::AlignedFree(obj->dptr_); - obj->dptr_ = NULL; -} -template -inline void Copy(Tensor _dst, - const Tensor &_src, - Stream *stream) { - CHECK_EQ(_dst.shape_, _src.shape_) - << "Copy:shape mismatch:" << _dst.shape_ << " vs " << _src.shape_; - if (_dst.CheckContiguous() && _src.CheckContiguous()) { - memcpy(_dst.dptr_, _src.dptr_, sizeof(DType) * _dst.shape_.Size()); - } else { - Tensor dst = _dst.FlatTo2D(); - Tensor src = _src.FlatTo2D(); - for (index_t y = 0; y < dst.size(0); ++y) { - memcpy(dst[y].dptr_, src[y].dptr_, sizeof(DType) * dst.size(1)); - } - } -} - -template -inline void MapPlan(TRValue *dst, - const expr::Plan &plan) { - Shape<2> shape = expr::ShapeCheck::Check(dst->self()).FlatTo2D(); - expr::Plan dplan = expr::MakePlan(dst->self()); -#ifndef __CUDACC__ - #pragma omp parallel for -#endif - // temp remove openmp, as default setting throttles CPU - for (openmp_index_t y = 0; y < shape[0]; ++y) { - for (index_t x = 0; x < shape[1]; ++x) { - // trust your compiler! -_- they will optimize it - Saver::template Save(dplan.REval(y, x), plan.Eval(y, x)); - } - } -} -// code to handle SSE optimization -template -struct MapExpCPUEngine { - inline static void Map(TRValue *dst, - const expr::Exp &exp) { - MapPlan(dst, MakePlan(exp.self())); - } -}; - -template -struct MapExpCPUEngine, - dim, DType, E, etype> { - inline static void Map(Tensor *dst, - const expr::Exp &exp) { - if (expr::PacketAlignCheck::Check(exp.self()) && - expr::PacketAlignCheck, MSHADOW_DEFAULT_PACKET>::Check(*dst)) { - expr::MapPacketPlan(dst->self(), - expr::MakePacketPlan(exp.self())); - } else { - MapPlan(dst, MakePlan(exp.self())); - } - } -}; - - -template -inline void MapExp(TRValue *dst, - const expr::Exp &exp) { - expr::TypeCheckPass::kMapPass> - ::Error_All_Tensor_in_Exp_Must_Have_Same_Type(); - Shape eshape = expr::ShapeCheck::Check(exp.self()); - Shape dshape = expr::ShapeCheck::Check(dst->self()); - CHECK(eshape[0] == 0 || eshape == dshape) - << "Assignment: Shape of Tensors are not consistent with target, " - << "eshape: " << eshape << " dshape:" << dshape; - MapExpCPUEngine::kPass, - Saver, R, dim, DType, E, etype> - ::Map(dst->ptrself(), exp); -} - -template -inline void MapReduceKeepLowest(TRValue *dst, - const expr::Exp &exp, - DType scale) { - expr::TypeCheckPass::kRedPass> - ::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); - Shape<2> eshape = expr::ShapeCheck::kDim, E> - ::Check(exp.self()).FlatTo2D(); - Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); - CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension do not match"; - CHECK_NE(eshape[0], 0U) << "can not reduce over empty tensor"; - // execution - expr::Plan dplan = MakePlan(dst->self()); - expr::Plan splan = MakePlan(exp.self()); -#ifndef __CUDACC__ - #pragma omp parallel for -#endif - for (openmp_index_t x = 0; x < eshape[1]; ++x) { - DType res = splan.Eval(0, x); - for (index_t y = 1; y < eshape[0]; ++y) { - Reducer::Reduce(res, splan.Eval(y, x)); - } - Saver::template Save(dplan.REval(0, x), res * scale); - } -} - -template -inline void MapReduceKeepHighDim(TRValue *dst, - const expr::Exp &exp, - DType scale) { - expr::TypeCheckPass::kRedPass> - ::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); - typedef Shape::kDim> EShape; - EShape eshape = expr::ShapeCheck::kDim, E> - ::Check(exp.self()); - Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); - CHECK_EQ(eshape[dimkeep], dshape[0]) - << "MapReduceKeepHighDim::reduction dimension do not match"; - // use equvalent form - Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep), - eshape[dimkeep], - eshape.ProdShape(dimkeep + 1, EShape::kSubdim), - eshape[EShape::kSubdim]); - // execution - expr::Plan dplan = MakePlan(dst->self()); - expr::Plan splan = MakePlan(exp.self()); -#ifndef __CUDACC__ - #pragma omp parallel for -#endif - for (openmp_index_t c = 0; c < pshape[1]; ++c) { - DType res; Reducer::SetInitValue(res); - for (index_t n = 0; n < pshape[0]; ++n) { - DType tres; Reducer::SetInitValue(tres); - for (index_t y = 0; y < pshape[2]; ++y) { - for (index_t x = 0; x < pshape[3]; ++x) { - Reducer::Reduce(tres, - splan.Eval((n * pshape[1] + c) * pshape[2] + y, x)); - } - } - Reducer::Reduce(res, tres); - } - Saver::template Save(dplan.REval(0, c), DType(res * scale)); - } -} - -template -inline void Softmax(Tensor dst, - const Tensor &energy) { - DType mmax = energy[0]; - for (index_t x = 1; x < dst.size(0); ++x) { - if (mmax < energy[x]) mmax = energy[x]; - } - DType sum = DType(0.0f); - for (index_t x = 0; x < dst.size(0); ++x) { - dst[x] = std::exp(energy[x] - mmax); - sum += dst[x]; - } - for (index_t x = 0; x < dst.size(0); ++x) { - dst[x] /= sum; - } -} - -template -inline void SoftmaxGrad(Tensor dst, - const Tensor &src, - const Tensor &label) { -#pragma omp parallel for - for (openmp_index_t y = 0; y < dst.size(0); ++y) { - const index_t k = static_cast(label[y]); - for (index_t x = 0; x < dst.size(1); ++x) { - if (x == k) { - dst[y][k] = src[y][k] - 1.0f; - } else { - dst[y][x] = src[y][x]; - } - } - } -} - -template -inline void SmoothSoftmaxGrad(Tensor dst, - const Tensor &src, - const Tensor &label, - const float alpha) { - const float smooth_grad = (alpha / (dst.size(1) - 1)); -#pragma omp parallel for - for (openmp_index_t y = 0; y < dst.size(0); ++y) { - const index_t k = static_cast(label[y]); - for (index_t x = 0; x < dst.size(1); ++x) { - if (x == k) { - dst[y][k] = src[y][k] - 1.0f + alpha; - } else { - dst[y][x] = src[y][x] - smooth_grad; - } - } - } -} - - -template -inline void SoftmaxGrad(Tensor dst, - const Tensor &src, - const Tensor &label, - const DType &ignore_label) { -#pragma omp parallel for - for (openmp_index_t y = 0; y < dst.size(0); ++y) { - const int k = static_cast(label[y]); - for (int x = 0; x < static_cast(dst.size(1)); ++x) { - if (static_cast(ignore_label) == k) { - dst[y][x] = 0.0f; - } else { - if (x == k) { - dst[y][k] = src[y][k] - 1.0f; - } else { - dst[y][x] = src[y][x]; - } - } - } - } -} - -template -inline void SmoothSoftmaxGrad(Tensor dst, - const Tensor &src, - const Tensor &label, - const DType &ignore_label, - const float alpha) { - const float smooth_grad = (alpha / (dst.size(1) - 1)); -#pragma omp parallel for - for (openmp_index_t y = 0; y < dst.size(0); ++y) { - const int k = static_cast(label[y]); - for (int x = 0; x < static_cast(dst.size(1)); ++x) { - if (static_cast(ignore_label) == k) { - dst[y][x] = 0.0f; - } else { - if (x == k) { - dst[y][k] = src[y][k] - 1.0f + alpha; - } else { - dst[y][x] = src[y][x] - smooth_grad; - } - } - } - } -} - -template -inline void SoftmaxGrad(Tensor dst, - const Tensor &src, - const Tensor &label) { -#pragma omp parallel for - for (openmp_index_t n = 0; n < dst.size(2); ++n) { - for (index_t y = 0; y < dst.size(0); ++y) { - const int k = static_cast(label[y][n]); - for (int x = 0; x < static_cast(dst.size(1)); ++x) { - if (x == k) { - dst[y][k][n] = src[y][k][n] - 1.0f; - } else { - dst[y][x][n] = src[y][x][n]; - } - } - } - } -} - -template -inline void SmoothSoftmaxGrad(Tensor dst, - const Tensor &src, - const Tensor &label, - const float alpha) { - const float smooth_grad = (alpha / (dst.size(1) - 1)); -#pragma omp parallel for - for (openmp_index_t n = 0; n < dst.size(2); ++n) { - for (index_t y = 0; y < dst.size(0); ++y) { - const int k = static_cast(label[y][n]); - for (int x = 0; x < static_cast(dst.size(1)); ++x) { - if (x == k) { - dst[y][k][n] = src[y][k][n] - 1.0f + alpha; - } else { - dst[y][x][n] = src[y][x][n] - smooth_grad; - } - } - } - } -} - -template -inline void SoftmaxGrad(Tensor dst, - const Tensor &src, - const Tensor &label, - const DType &ignore_label) { -#pragma omp parallel for - for (openmp_index_t n = 0; n < dst.size(2); ++n) { - for (index_t y = 0; y < dst.size(0); ++y) { - const int k = static_cast(label[y][n]); - if (k == static_cast(ignore_label)) { - for (int x = 0; x < static_cast(dst.size(1)); ++x) { - dst[y][x][n] = DType(0.0f); - } - } else { - for (int x = 0; x < static_cast(dst.size(1)); ++x) { - if (x == k) { - dst[y][k][n] = src[y][k][n] - 1.0f; - } else { - dst[y][x][n] = src[y][x][n]; - } - } - } - } - } -} - -template -inline void SmoothSoftmaxGrad(Tensor dst, - const Tensor &src, - const Tensor &label, - const DType &ignore_label, - const float alpha) { - const float smooth_grad = (alpha / (dst.size(1) - 1)); -#pragma omp parallel for - for (openmp_index_t n = 0; n < dst.size(2); ++n) { - for (index_t y = 0; y < dst.size(0); ++y) { - const int k = static_cast(label[y][n]); - if (k == static_cast(ignore_label)) { - for (int x = 0; x < static_cast(dst.size(1)); ++x) { - dst[y][x][n] = DType(0.0f); - } - } else { - for (int x = 0; x < static_cast(dst.size(1)); ++x) { - if (x == k) { - dst[y][k][n] = src[y][k][n] - 1.0f + alpha; - } else { - dst[y][x][n] = src[y][x][n] - smooth_grad; - } - } - } - } - } -} - -template -inline void Softmax(Tensor dst, - const Tensor &energy) { - CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch"; -#pragma omp parallel for - for (openmp_index_t y = 0; y < dst.size(0); ++y) { - Softmax(dst[y], energy[y]); - } -} - -template -inline void Softmax(Tensor dst, - const Tensor &energy) { - CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch"; -#pragma omp parallel for - for (openmp_index_t y = 0; y < dst.size(0); ++y) { - for (index_t n = 0; n < dst.size(2); ++n) { - DType mmax = energy[y][0][n]; - for (index_t x = 1; x < dst.size(1); ++x) { - if (mmax < energy[y][x][n]) mmax = energy[y][x][n]; - } - DType sum = DType(0.0f); - for (index_t x = 0; x < dst.size(1); ++x) { - dst[y][x][n] = std::exp(energy[y][x][n] - mmax); - sum += dst[y][x][n]; - } - for (index_t x = 0; x < dst.size(1); ++x) { - dst[y][x][n] /= sum; - } - } - } -} - -template -inline void AddTakeGrad(Tensor dst, - const Tensor& index, - const Tensor &src) { - const int K = dst.shape_[0]; - for (index_t y = 0; y < index.size(0); ++y) { - int j = index[y]; - if (j <= 0) j = 0; - else if (j >= K) j = K - 1; - dst[j] += src[y]; - } -} - -template -inline void AddTakeGradLargeBatch(Tensor dst, - const Tensor& sorted, - const Tensor& index, - const Tensor &src) { - for (index_t y = 0; y < sorted.size(0); ++y) { - dst[sorted[y]] += src[index[y]]; - } -} - -template -inline void IndexFill(Tensor dst, - const Tensor& index, - const Tensor &src) { - for (index_t y = 0; y < index.size(0); ++y) { - for (index_t j = 0; j < src.size(1); j++) { - dst[index[y]][j] = src[y][j]; - } - } -} - -template -inline void SortByKey(Tensor keys, Tensor values, - bool is_ascend) { - CHECK_EQ(keys.CheckContiguous(), true); - CHECK_EQ(values.CheckContiguous(), true); - CHECK_EQ(keys.size(0), values.size(0)) - << "The sizes of key/value are not equal! keys_size: " << keys.size(0) - << "values_size: " << values.size(0); - std::vector idx(keys.size(0)); - std::vector keys_vec(keys.size(0)); - std::vector values_vec(values.size(0)); - for (int i = 0; i < keys.size(0); i++) { - idx[i] = i; - keys_vec[i] = keys[i]; - values_vec[i] = values[i]; - } - if (is_ascend) { - std::stable_sort(idx.begin(), idx.end(), - [&keys_vec](size_t i1, size_t i2) - {return keys_vec[i1] < keys_vec[i2]; }); - } else { - std::stable_sort(idx.begin(), idx.end(), - [&keys_vec](size_t i1, size_t i2) - {return keys_vec[i1] > keys_vec[i2]; }); - } - for (index_t i = 0; i < values.size(0); i++) { - keys[i] = keys_vec[idx[i]]; - values[i] = values_vec[idx[i]]; - } -} - -template -inline void VectorizedSort(Tensor values, Tensor segments) { - // We can sort each segments using two stable sorts - SortByKey(values, segments, true); - SortByKey(segments, values, true); -} - -// blas related -template -inline void VectorDot(Tensor dst, - const Tensor &lhs, - const Tensor &rhs) { - CHECK_EQ(lhs.size(0), rhs.size(0)) - << "VectorDot: Shape mismatch"; - CHECK_EQ(dst.size(0), 1U) - << "VectorDot: expect dst to be scalar"; - expr::BLASEngine::SetStream(lhs.stream_); - mshadow::expr::BLASEngine::dot( - lhs.stream_, lhs.size(0), lhs.dptr_, 1, rhs.dptr_, 1, dst.dptr_); -} - -template -inline void BatchGEMM(Tensor dst, - const Tensor &lhs, - const Tensor &rhs, - DType alpha, - DType beta, - Tensor workspace) { - index_t batch_size = dst.shape_[0]; - expr::BLASEngine::SetStream(dst.stream_); - Shape<3> sleft = transpose_left ? Shape3(lhs.shape_[0], lhs.shape_[2], lhs.shape_[1]) - : lhs.shape_; - Shape<3> sright = transpose_right ? Shape3(rhs.shape_[0], rhs.shape_[2], rhs.shape_[1]) - : rhs.shape_; - CHECK_EQ(dst.CheckContiguous(), true); - CHECK_EQ(lhs.CheckContiguous(), true); - CHECK_EQ(rhs.CheckContiguous(), true); - CHECK(sleft[0] == batch_size && sright[0] == batch_size) - << "BatchGEMM: batchsize must be equal." - << "dst: " << dst.shape_ << "\n" - << "lhs: " << sleft << "\n" - << "rhs: " << sright << "\n"; - CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1]) - << "BatchGEMM: matrix shape mismatch" - << "dst: " << dst.shape_ << "\n" - << "lhs: " << sleft << "\n" - << "rhs: " << sright << "\n"; - CHECK(workspace.size(0) >= 3 * batch_size) - << "Workspace Size must be bigger than " << 3 * batch_size; - CHECK_EQ(workspace.CheckContiguous(), true); - // use column major argument to compatible with most BLAS - expr::BLASEngine::batched_gemm - (dst.stream_, - transpose_right, transpose_left, - transpose_right ? rhs.size(1) : rhs.size(2), - transpose_left ? lhs.size(2) : lhs.size(1), - transpose_right ? rhs.size(2) : rhs.size(1), - alpha, - rhs.dptr_, rhs.stride_, - lhs.dptr_, lhs.stride_, - beta, - dst.dptr_, dst.stride_, batch_size, - workspace.dptr_); -} -} // namespace mshadow -#endif // MSHADOW_TENSOR_CPU_INL_H_ diff --git a/include/mshadow/tensor_gpu-inl.h b/include/mshadow/tensor_gpu-inl.h deleted file mode 100755 index 94fdb0527e72..000000000000 --- a/include/mshadow/tensor_gpu-inl.h +++ /dev/null @@ -1,245 +0,0 @@ -/*! - * Copyright (c) 2014 by Contributors - * \file tensor_gpu-inl.h - * \brief implementation of GPU host code - * \author Bing Xu, Tianqi Chen - */ -#ifndef MSHADOW_TENSOR_GPU_INL_H_ -#define MSHADOW_TENSOR_GPU_INL_H_ -#include "./base.h" -#include "./tensor.h" - -namespace mshadow { -#if MSHADOW_USE_CUDA -template<> -inline void InitTensorEngine(int dev_id) { - cudaDeviceProp prop; - int device_id = 0; - int device_count = 0; - cudaGetDeviceCount(&device_count); - CHECK_GT(device_count, 0) << "Cannot find CUDA device. Please check CUDA-Configuration"; - if (dev_id < 0) { - device_id = 0; - } else { - device_id = dev_id; - } - CHECK_LT(device_id, device_count) << "Incorrect Device ID"; - MSHADOW_CUDA_CALL(cudaSetDevice(device_id)); - MSHADOW_CUDA_CALL(cudaGetDeviceProperties(&prop, device_id)); -} -template<> -inline void ShutdownTensorEngine(void) { -} -template<> -inline void SetDevice(int devid) { - MSHADOW_CUDA_CALL(cudaSetDevice(devid)); -} -template -inline void AllocSpace(Tensor *obj, bool pad) { - size_t pitch; - // common choice for cuda mem align unit is 32 - if (pad && obj->size(dim - 1) >= MSHADOW_MIN_PAD_RATIO * 32) { - MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast(&(obj->dptr_)), &pitch, - obj->size(dim - 1) * sizeof(DType), - obj->shape_.FlatTo2D()[0])); - obj->stride_ = static_cast(pitch / sizeof(DType)); - } else { - obj->stride_ = obj->size(dim - 1); - MSHADOW_CUDA_CALL(cudaMallocPitch(reinterpret_cast(&(obj->dptr_)), &pitch, - obj->shape_.Size() * sizeof(DType), 1)); - } -} -template -inline void FreeSpace(Tensor *obj) { - MSHADOW_CUDA_CALL(cudaFree(obj->dptr_)); - obj->dptr_ = NULL; -} -template -inline void Copy(Tensor _dst, - Tensor _src, - cudaMemcpyKind kind, - Stream *stream) { - CHECK_EQ(_dst.shape_, _src.shape_) << "Copy:shape mismatch"; - Tensor dst = _dst.FlatTo2D(); - Tensor src = _src.FlatTo2D(); - MSHADOW_CUDA_CALL(cudaMemcpy2DAsync(dst.dptr_, dst.stride_ * sizeof(DType), - src.dptr_, src.stride_ * sizeof(DType), - dst.size(1) * sizeof(DType), - dst.size(0), kind, - Stream::GetStream(stream))); - // use synchronize call behavior for zero stream - if (stream == NULL) { - MSHADOW_CUDA_CALL(cudaStreamSynchronize(0)); - } -} -template -inline void Copy(Tensor dst, - const Tensor &src, - Stream *stream) { - Copy(dst, src, cudaMemcpyDeviceToHost, stream); -} -template -inline void Copy(Tensor dst, - const Tensor &src, - Stream *stream) { - Copy(dst, src, cudaMemcpyDeviceToDevice, stream); -} -template -inline void Copy(Tensor dst, - const Tensor &src, - Stream *stream) { - Copy(dst, src, cudaMemcpyHostToDevice, stream); -} -#endif // MSHADOW_USE_CUDA -} // namespace mshadow - -// the following part is included only if compiler is nvcc -#ifdef __CUDACC__ -#include "./cuda/tensor_gpu-inl.cuh" - -namespace mshadow { -template -inline void MapExp(TRValue *dst, - const expr::Exp &exp) { - expr::TypeCheckPass::kMapPass> - ::Error_All_Tensor_in_Exp_Must_Have_Same_Type(); - Shape eshape = expr::ShapeCheck::Check(exp.self()); - Shape dshape = expr::ShapeCheck::Check(dst->self()); - CHECK(eshape[0] == 0 || eshape == dshape) - << "Assignment: Shape of Tensors are not consistent with target, " - << "eshape: " << eshape << " dshape:" << dshape; - cuda::MapPlan(MakePlan(dst->self()), - MakePlan(exp.self()), - dshape.FlatTo2D(), - Stream::GetStream(expr::StreamInfo::Get(dst->self()))); -} - -template -inline void MapReduceKeepLowest(TRValue *dst, - const expr::Exp &exp, - DType scale) { - expr::TypeCheckPass::kRedPass> - ::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); - Shape<2> eshape = expr::ShapeCheck::kDim, E> - ::Check(exp.self()).FlatTo2D(); - Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); - CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension do not match"; - CHECK_NE(eshape[0], 0U) << "can not reduce over empty tensor"; - cuda::MapReduceKeepLowest - (MakePlan(dst->self()), MakePlan(exp.self()), scale, eshape, - Stream::GetStream(expr::StreamInfo::Get(dst->self()))); -} - -template -inline void MapReduceKeepHighDim(TRValue *dst, - const expr::Exp &exp, - DType scale) { - expr::TypeCheckPass::kRedPass> - ::Error_TypeCheck_Not_Pass_For_Reduce_Exp(); - typedef Shape::kDim> EShape; - EShape eshape = expr::ShapeCheck::kDim, E> - ::Check(exp.self()); - Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self()); - CHECK_EQ(eshape[dimkeep], dshape[0]) << "MapReduceKeepHighDim::reduction dimension do not match"; - // use equvalent form - Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep), - eshape[dimkeep], - eshape.ProdShape(dimkeep + 1, EShape::kSubdim), - eshape[EShape::kSubdim]); - // call equavalent map red dim 2 - cuda::MapReduceKeepDim1 - (MakePlan(dst->self()), MakePlan(exp.self()), scale, pshape, - Stream::GetStream(expr::StreamInfo::Get(dst->self()))); -} -template -inline void Softmax(Tensor dst, - const Tensor& src) { - cuda::Softmax(dst, src); -} - -template -inline void Softmax(Tensor dst, - const Tensor& src) { - cuda::Softmax(dst, src); -} - -template -inline void SoftmaxGrad(const Tensor &dst, - const Tensor &src, - const Tensor &label) { - cuda::SoftmaxGrad(dst, src, label); -} - -template -inline void SmoothSoftmaxGrad(const Tensor &dst, - const Tensor &src, - const Tensor &label, - const float alpha) { - cuda::SmoothSoftmaxGrad(dst, src, label, alpha); -} - -template -inline void SoftmaxGrad(const Tensor &dst, - const Tensor &src, - const Tensor &label, - const DType &ignore_label) { - cuda::SoftmaxGrad(dst, src, label, ignore_label); -} - -template -inline void SmoothSoftmaxGrad(const Tensor &dst, - const Tensor &src, - const Tensor &label, - const DType &ignore_label, - const float alpha) { - cuda::SmoothSoftmaxGrad(dst, src, label, ignore_label, alpha); -} - -template -inline void SoftmaxGrad(const Tensor &dst, - const Tensor &src, - const Tensor &label) { - cuda::SoftmaxGrad(dst, src, label); -} - -template -inline void SoftmaxGrad(const Tensor &dst, - const Tensor &src, - const Tensor &label, - const DType &ignore_label) { - cuda::SoftmaxGrad(dst, src, label, ignore_label); -} - -template -inline void AddTakeGrad(Tensor dst, - const Tensor& index, - const Tensor &src) { - cuda::AddTakeGrad(dst, index, src); -} - -template -inline void AddTakeGradLargeBatch(Tensor dst, - const Tensor& sorted, - const Tensor& index, - const Tensor &src) { - cuda::AddTakeGradLargeBatch(dst, sorted, index, src); -} - -template -inline void SortByKey(Tensor keys, Tensor values, - bool is_ascend) { - cuda::SortByKey(keys, values, is_ascend); -} - -template -inline void IndexFill(Tensor dst, - const Tensor& index, - const Tensor &src) { - cuda::IndexFill(dst, index, src); -} -} // namespace mshadow -#endif // __CUDACC__ -#endif // MSHADOW_TENSOR_GPU_INL_H_ diff --git a/include/nnvm b/include/nnvm new file mode 120000 index 000000000000..779dd4459a3c --- /dev/null +++ b/include/nnvm @@ -0,0 +1 @@ +../3rdparty/tvm/nnvm/include/nnvm \ No newline at end of file diff --git a/include/nnvm/base.h b/include/nnvm/base.h deleted file mode 100644 index 449bd2f4626e..000000000000 --- a/include/nnvm/base.h +++ /dev/null @@ -1,35 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file nnvm/base.h - * \brief Configuration of nnvm as well as basic data structure. - */ -#ifndef NNVM_BASE_H_ -#define NNVM_BASE_H_ - -#include -#include -#include -#include -#include -#include -#include - -namespace nnvm { - -/*! \brief any type */ -using dmlc::any; - -/*! \brief array_veiw type */ -using dmlc::array_view; - -/*!\brief getter function of any type */ -using dmlc::get; - -} // namespace nnvm - -// describe op registration point -#define NNVM_STRINGIZE_DETAIL(x) #x -#define NNVM_STRINGIZE(x) NNVM_STRINGIZE_DETAIL(x) -#define NNVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" NNVM_STRINGIZE(__LINE__)) -#define NNVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" NNVM_STRINGIZE(__LINE__) -#endif // NNVM_BASE_H_ diff --git a/include/nnvm/c_api.h b/include/nnvm/c_api.h deleted file mode 100644 index daf9b564f3fa..000000000000 --- a/include/nnvm/c_api.h +++ /dev/null @@ -1,388 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file nnvm/c_api.h - * \brief C API of NNVM symbolic construction and pass. - * Enables construction and transformation of Graph - * in any other host languages. - */ -#ifndef NNVM_C_API_H_ -#define NNVM_C_API_H_ - -/*! \brief NNVM_DLL prefix for windows */ -#ifdef _WIN32 -#ifdef NNVM_EXPORTS -#define NNVM_DLL __declspec(dllexport) -#else -#define NNVM_DLL __declspec(dllimport) -#endif -#else -#define NNVM_DLL -#endif - -/*! \brief manually define unsigned int */ -typedef unsigned int nn_uint; - -/*! \brief handle to a function that takes param and creates symbol */ -typedef void *OpHandle; -/*! \brief handle to a symbol that can be bind as operator */ -typedef void *SymbolHandle; -/*! \brief handle to Graph */ -typedef void *GraphHandle; - -#ifdef __cplusplus -extern "C" { -#endif -/*! - * \brief Set the last error message needed by C API - * \param msg The error message to set. - */ -NNVM_DLL void NNAPISetLastError(const char* msg); - -/*! - * \brief return str message of the last error - * all function in this file will return 0 when success - * and -1 when an error occured, - * NNGetLastError can be called to retrieve the error - * - * this function is threadsafe and can be called by different thread - * \return error info - */ -NNVM_DLL const char *NNGetLastError(void); - -/*! - * \brief list all the available operator names, include entries. - * \param out_size the size of returned array - * \param out_array the output operator name array. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNListAllOpNames(nn_uint *out_size, - const char*** out_array); - -/*! - * \brief Get operator handle given name. - * \param op_name The name of the operator. - * \param op_out The returnning op handle. - */ -NNVM_DLL int NNGetOpHandle(const char* op_name, - OpHandle* op_out); - -/*! - * \brief list all the available operators. - * This won't include the alias, use ListAllNames - * instead to get all alias names. - * - * \param out_size the size of returned array - * \param out_array the output AtomicSymbolCreator array - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNListUniqueOps(nn_uint *out_size, - OpHandle **out_array); - -/*! - * \brief Get the detailed information about atomic symbol. - * \param op The operator handle. - * \param real_name The returned name of the creator. - * This name is not the alias name of the atomic symbol. - * \param description The returned description of the symbol. - * \param num_doc_args Number of arguments that contain documents. - * \param arg_names Name of the arguments of doc args - * \param arg_type_infos Type informations about the arguments. - * \param arg_descriptions Description information about the arguments. - * \param return_type Return type of the function, if any. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNGetOpInfo(OpHandle op, - const char **real_name, - const char **description, - nn_uint *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type); -/*! - * \brief Create an AtomicSymbol functor. - * \param op The operator handle - * \param num_param the number of parameters - * \param keys the keys to the params - * \param vals the vals of the params - * \param out pointer to the created symbol handle - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op, - nn_uint num_param, - const char **keys, - const char **vals, - SymbolHandle *out); -/*! - * \brief Create a Variable Symbol. - * \param name name of the variable - * \param out pointer to the created symbol handle - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolCreateVariable(const char *name, SymbolHandle *out); -/*! - * \brief Create a Symbol by grouping list of symbols together - * \param num_symbols number of symbols to be grouped - * \param symbols array of symbol handles - * \param out pointer to the created symbol handle - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols, - SymbolHandle *symbols, - SymbolHandle *out); -/*! - * \brief Add src_dep to the handle as control dep. - * \param handle The symbol to add dependency edges on. - * \param src_dep the source handles. - */ -NNVM_DLL int NNAddControlDeps(SymbolHandle handle, - SymbolHandle src_dep); -/*! - * \brief Free the symbol handle. - * \param symbol the symbol - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolFree(SymbolHandle symbol); -/*! - * \brief Copy the symbol to another handle - * \param symbol the source symbol - * \param out used to hold the result of copy - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out); -/*! - * \brief Print the content of symbol, used for debug. - * \param symbol the symbol - * \param out_str pointer to hold the output string of the printing. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str); -/*! - * \brief Get string attribute from symbol - * \param symbol the source symbol - * \param key The key of the symbol. - * \param out The result attribute, can be NULL if the attribute do not exist. - * \param success Whether the result is contained in out. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol, - const char* key, - const char** out, - int *success); -/*! - * \brief Set string attribute from symbol. - * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph. - * - * Safe recommendaton: use immutable graph - * - Only allow set attributes during creation of new symbol as optional parameter - * - * Mutable graph (be careful about the semantics): - * - Allow set attr at any point. - * - Mutating an attribute of some common node of two graphs can cause confusion from user. - * - * \param symbol the source symbol - * \param num_param Number of parameters to set. - * \param keys The keys of the attribute - * \param values The value to be set - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, - nn_uint num_param, - const char** keys, - const char** values); -/*! - * \brief Get all attributes from symbol, including all descendents. - * \param symbol the source symbol - * \param recursive_option 0 for recursive, 1 for shallow. - * \param out_size The number of output attributes - * \param out 2*out_size strings representing key value pairs. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, - int recursive_option, - nn_uint *out_size, - const char*** out); - -/*! - * \brief List inputs variables in the symbol. - * \param symbol the symbol - * \param option The option to list the inputs - * option=0 means list all arguments. - * option=1 means list arguments that are readed only by the graph. - * option=2 means list arguments that are mutated by the graph. - * \param out_size output size - * \param out_sym_array the output array. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, - int option, - nn_uint *out_size, - SymbolHandle** out_sym_array); - -/*! - * \brief List input names in the symbol. - * \param symbol the symbol - * \param option The option to list the inputs - * option=0 means list all arguments. - * option=1 means list arguments that are readed only by the graph. - * option=2 means list arguments that are mutated by the graph. - * \param out_size output size - * \param out_str_array pointer to hold the output string array - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, - int option, - nn_uint *out_size, - const char ***out_str_array); -/*! - * \brief List returns names in the symbol. - * \param symbol the symbol - * \param out_size output size - * \param out_str_array pointer to hold the output string array - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, - nn_uint *out_size, - const char ***out_str_array); - - -/*! - * \brief Supply number of outputs of the symbol. - * \param symbol the symbol - * \param output_count number of outputs - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, - nn_uint *output_count); - -/*! - * \brief Get a symbol that contains all the internals. - * \param symbol The symbol - * \param out The output symbol whose outputs are all the internals. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol, - SymbolHandle *out); -/*! - * \brief Get a symbol that contains only direct children. - * \param symbol The symbol - * \param out The output symbol whose outputs are the direct children. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol, - SymbolHandle *out); -/*! - * \brief Get index-th outputs of the symbol. - * \param symbol The symbol - * \param index the Index of the output. - * \param out The output symbol whose outputs are the index-th symbol. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, - nn_uint index, - SymbolHandle *out); - -/*! - * \brief Compose the symbol on other symbols. - * - * This function will change the sym hanlde. - * To achieve function apply behavior, copy the symbol first - * before apply. - * - * \param sym the symbol to apply - * \param name the name of symbol - * \param num_args number of arguments - * \param keys the key of keyword args (optional) - * \param args arguments to sym - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNSymbolCompose(SymbolHandle sym, - const char* name, - nn_uint num_args, - const char** keys, - SymbolHandle* args); - -// Graph IR API -/*! - * \brief create a graph handle from symbol - * \param symbol The symbol representing the graph. - * \param graph The graph handle created. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph); -/*! - * \brief free the graph handle - * \param handle The handle to be freed. - */ -NNVM_DLL int NNGraphFree(GraphHandle handle); -/*! - * \brief Get a new symbol from the graph. - * \param graph The graph handle. - * \param symbol The corresponding symbol - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol); - -/*! - * \brief Get Set a attribute in json format. - * This feature allows pass graph attributes back and forth in reasonable speed. - * - * \param handle The graph handle. - * \param key The key to the attribute. - * \param json_value The value need to be in format [type_name, value], - * Where type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, - const char* key, - const char* json_value); - -/*! - * \brief Get a serialized attrirbute from graph. - * This feature allows pass graph attributes back and forth in reasonable speed. - * - * \param handle The graph handle. - * \param key The key to the attribute. - * \param json_out The result attribute, can be NULL if the attribute do not exist. - * The json_out is an array of [type_name, value]. - * Where the type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY. - * \param success Whether the result is contained in out. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle, - const char* key, - const char** json_out, - int *success); - -/*! - * \brief Set a attribute whose type is std::vector in c++ - * This feature allows pass List of symbolic variables for gradient request. - * - * \note This is beta feature only used for test purpos - * - * \param handle The graph handle. - * \param key The key to the attribute. - * \param list The symbol whose outputs represents the list of NodeEntry to be passed. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, - const char* key, - SymbolHandle list); -/*! - * \brief Apply passes on the src graph. - * \param src The source graph handle. - * \param num_pass The number of pass to be applied. - * \param pass_names The names of the pass. - * \param dst The result graph. - * \return 0 when success, -1 when failure happens - */ -NNVM_DLL int NNGraphApplyPasses(GraphHandle src, - nn_uint num_pass, - const char** pass_names, - GraphHandle *dst); - -#ifdef __cplusplus -} /* end extern "C" */ -#endif - -#endif // NNVM_C_API_H_ diff --git a/include/nnvm/compiler/op_attr_types.h b/include/nnvm/compiler/op_attr_types.h deleted file mode 100644 index 497a520db78e..000000000000 --- a/include/nnvm/compiler/op_attr_types.h +++ /dev/null @@ -1,101 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file nnvm/compiler/op_attr_types.h - * \brief The Expr and related elements in DataFlow construction. - */ -#ifndef NNVM_COMPILER_OP_ATTR_TYPES_H_ -#define NNVM_COMPILER_OP_ATTR_TYPES_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "packed_func_ext.h" - -namespace nnvm { -namespace compiler { - -using ::tvm::Array; -using ::tvm::Tensor; -using ::tvm::Schedule; - -/*! \brief operator pattern used in graph fusion */ -enum OpPatternKind { - // Elementwise operation - kElemWise = 0, - // Broadcasting operator, can always map output axis to the input in order. - // for example :code:`out[i, ax1, j, ax2] = input[i, j]`. - // Note that the axis need to be in order so transpose is not a bcast operator. - kBroadcast = 1, - // Injective operator, can always injectively map output axis to a single input axis. - // All injective operator can still be safely fused to injective and reduction. - kInjective = 2, - // Communicative reduction operator. - kCommReduce = 3, - // Complex operation, can still fuse elemwise operations into its output. - // but cannot chain another complex op - kOutEWiseFusable = 4, - // Opaque operation, cannot fuse anything. - kOpaque = 8 -}; - -/*! \brief the operator pattern */ -using TOpPattern = int; - -/*! - * \brief Computation description interface - * \param attrs The attribute of the node. - * \param inputs The input tensors(placeholders) - * \param out_info Tensors holding shape/type information about output, - & these are always placeholders. - * \return The output description of the tensor. - */ -using FTVMCompute = std::function< - Array(const NodeAttrs& attrs, - const Array& inputs, - const Array& out_info)>; - -/*! - * \brief Build the computation schedule for - * op whose root is at current op. - * \param attrs The attribute of the node. - * \param outs The output tensors. - * \param target The build target. - * \return schedule The computation schedule. - */ -using FTVMSchedule = std::function< - Schedule(const NodeAttrs& attrs, - const Array& outs, - const std::string& target)>; - -/*! - * \brief Modify the op node to alter its input layout. - * it is invoked in AlterOpLayout pass. - * \param attrs The attribute of the original node. - * \param inputs The input symbols of the original node. - * \param tinfos The inferred shape and dtype of the inputs. - * \param ret The replaced operator. - * \return Whether to replace current operator. - */ -using FTVMAlterOpLayout = std::function< - bool(const NodeAttrs& attrs, - const Symbol& inputs, - const Array& tinfos, - Symbol* ret)>; - -/*! - * \brief Transform from normal operator to vectorized operator - * \param node The source node. - * \return Transformed vectorized op. - */ -using FTVMVectorizedOp = std::function; - -} // namespace compiler -} // namespace nnvm -#endif // NNVM_COMPILER_OP_ATTR_TYPES_H_ diff --git a/include/nnvm/compiler/packed_func_ext.h b/include/nnvm/compiler/packed_func_ext.h deleted file mode 100644 index e289fd4efa59..000000000000 --- a/include/nnvm/compiler/packed_func_ext.h +++ /dev/null @@ -1,59 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file nnvm/compiler/packed_func_ext.h - * \brief Extension to enable packed functionn for nnvm types - */ -#ifndef NNVM_COMPILER_PACKED_FUNC_EXT_H_ -#define NNVM_COMPILER_PACKED_FUNC_EXT_H_ - -#include -#include -#include -#include -#include -#include -#include - -namespace nnvm { -namespace compiler { - -using tvm::runtime::PackedFunc; - -using AttrDict = std::unordered_map; - -/*! - * \brief Get PackedFunction from global registry and - * report error if it does not exist - * \param name The name of the function. - * \return The created PackedFunc. - */ -inline const PackedFunc& GetPackedFunc(const std::string& name) { - const PackedFunc* pf = tvm::runtime::Registry::Get(name); - CHECK(pf != nullptr) << "Cannot find function " << name << " in registry"; - return *pf; -} -} // namespace compiler -} // namespace nnvm - -// Enable the graph and symbol object exchange. -namespace tvm { -namespace runtime { - -template<> -struct extension_class_info { - static const int code = 16; -}; - -template<> -struct extension_class_info { - static const int code = 17; -}; - -template<> -struct extension_class_info { - static const int code = 18; -}; - -} // namespace runtime -} // namespace tvm -#endif // NNVM_COMPILER_PACKED_FUNC_EXT_H_ diff --git a/include/nnvm/compiler/util.h b/include/nnvm/compiler/util.h deleted file mode 100644 index 5d5bc4478530..000000000000 --- a/include/nnvm/compiler/util.h +++ /dev/null @@ -1,33 +0,0 @@ -/*! -* Copyright (c) 2016 by Contributors -* \file nnvm/compiler/util.h -* \brief Utility functions for nnvm compiler -*/ -#ifndef NNVM_COMPILER_UTIL_H_ -#define NNVM_COMPILER_UTIL_H_ - -#include -#include - -namespace nnvm { -namespace compiler { - -/* - * \brief Helper function to convert TShape to TVM array. Useful for - * passing data from NNVM param structures to TOPI ops. - * - * \param shape The shape to convert - * - * \return An Array of Expr, where each element is a constant int32 - */ -inline tvm::Array ShapeToArray(TShape shape) { - tvm::Array result; - for (auto i : shape) { - result.push_back(tvm::make_const(tvm::Int(32), i)); - } - return result; -} - -} // namespace compiler -} // namespace nnvm -#endif // NNVM_COMPILER_UTIL_H_ diff --git a/include/nnvm/graph.h b/include/nnvm/graph.h deleted file mode 100644 index 3f8a2a3642b1..000000000000 --- a/include/nnvm/graph.h +++ /dev/null @@ -1,315 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file nnvm/graph.h - * \brief Configuation of nnvm as well as basic data structure. - */ -#ifndef NNVM_GRAPH_H_ -#define NNVM_GRAPH_H_ - -#include -#include -#include -#include -#include -#include -#include "base.h" -#include "node.h" -#include "symbolic.h" - -namespace nnvm { - -class IndexedGraph; - -/*! - * \brief Symbolic computation graph. - * This is the intermediate representation for optimization pass. - */ -class Graph { - public: - /*! \brief outputs of the computation graph. */ - std::vector outputs; - /*! - * \brief attributes of a graph - * Note that attribute is shared pointer and can be shared across graphs. - * - * It is highly recommended to keep each attribute immutable. - * It is also safe to implement an copy-on-write semnatics. - * - * Copy when shared_ptr.unique is not true, while reuse original space - * when shared_ptr.unique is true. - */ - std::unordered_map > attrs; - /*! - * \brief Get the immutable attribute from attrs. - * \param attr_name the name of the attribute - * \return the reference to corresponding attribute - * \tparam T the type of the attribute. - */ - template - inline const T& GetAttr(const std::string& attr_name) const; - /*! - * \brief Check whether has a specific attribute. - * \param attr_name the name of the attribute - * \return a boolean result - */ - inline bool HasAttr(const std::string& attr_name) const; - /*! - * \brief Get a move copy of the attribute, implement copy on write semantics. - * The content is moved if the reference counter of shared_ptr is 1. - * The attribute is erased from attrs after the call. - * - * \param attr_name the name of the attribute - * \return a new copy of the corresponding attribute. - * \tparam T the type of the attribute. - */ - template - inline T MoveCopyAttr(const std::string& attr_name); - /*! - * \brief get a indexed graph of current graph, if not exist, create it on demand - * \return The indexed graph. - * \sa IndexedGraph - */ - const IndexedGraph& indexed_graph() const; - - private: - // internal structure of indexed graph - mutable std::shared_ptr indexed_graph_; -}; - -/*! - * \brief Auxiliary data structure to index a graph. - * It maps Nodes in the graph to consecutive integers node_id. - * It also maps IndexedGraph::NodeEntry to consecutive integer entry_id. - * This allows storing properties of Node and NodeEntry into - * compact vector and quickly access them without resorting to hashmap. - * - * The node_id and entry_rptr are the same as the JSON graph produced by SaveJSON Pass. - */ -class IndexedGraph { - public: - /*! \brief represents a data in the graph */ - struct NodeEntry { - /*! \brief the source node id in the computation graph */ - uint32_t node_id; - /*! \brief index of output from the source. */ - uint32_t index; - /*! \brief version of the node */ - uint32_t version; - }; - /*! \brief Node data structure in IndexedGraph */ - struct Node { - /*! \brief pointer to the source node */ - const nnvm::Node* source; - /*! \brief inputs to the node */ - array_view inputs; - /*! \brief control flow dependencies to the node */ - array_view control_deps; - /*! \brief weak reference to node */ - std::weak_ptr weak_ref; - }; - /*! \return number of nodes in the graph */ - inline size_t num_nodes() const { - return nodes_.size(); - } - /*! \return total number of NodeEntry in the graph */ - inline size_t num_node_entries() const { - return entry_rptr_.back(); - } - /*! - * \brief Get a unique entry id between 0 to num_node_entries() - * for a given IndexedGraph::NodeEntry - * \param node_id The node index - * \param index the output index - * \return the unique index. - */ - inline uint32_t entry_id(uint32_t node_id, uint32_t index) const { - return entry_rptr_[node_id] + index; - } - /*! - * \brief Get a unique entry id between 0 to num_node_entries() - * for a given IndexedGraph::NodeEntry - * \param e The entry to query for index. - * \return the unique index. - */ - inline uint32_t entry_id(const NodeEntry& e) const { - return entry_rptr_[e.node_id] + e.index; - } - /*! - * \brief Get a unique entry id between 0 to num_node_entries() - * for a given NodeEntry. - * \param e The entry to query for index. - * \return the unique index. - */ - inline uint32_t entry_id(const nnvm::NodeEntry& e) const { - return entry_rptr_[node_id(e.node.get())] + e.index; - } - /*! - * \brief Get the corresponding node id for a given Node in the IndexedGraph. - * \param node The Node to query for index. - * \return the node index. - */ - inline uint32_t node_id(const nnvm::Node* node) const { - return node2index_.at(node); - } - /*! - * \brief Get the corresponding Node structure for a given node_id. - * \param node_id The node id - * \return const reference to the corresponding IndexedGraph::Node - */ - inline const Node& operator[](uint32_t node_id) const { - return nodes_[node_id]; - } - /*! - * \brief Get the corresponding Node structure - * \param node The pointer to the Node structure - * \return const reference to the corresponding IndexedGraph::Node - */ - inline const Node& operator[](const nnvm::Node* node) const { - return nodes_[node_id(node)]; - } - /*! \return list of argument nodes */ - inline const std::vector& input_nodes() const { - return input_nodes_; - } - /*! \return list of mutable nodes */ - inline const std::unordered_set& mutable_input_nodes() const { - return mutable_input_nodes_; - } - /*! \return list of output entries */ - inline const std::vector& outputs() const { - return outputs_; - } - - /*! \return whether a node is existed in the indexed graph */ - inline bool exist(const nnvm::Node* node) const { - return node2index_.count(node); - } - - // disalllow copy assign - IndexedGraph(const IndexedGraph&) = delete; - - private: - friend class Graph; - /*! - * \brief Constructor an IndexedGraph from normal Graph - * \param other The source graph. - */ - explicit IndexedGraph(const Graph& other); - // Node pointers in CSR structure. - std::vector nodes_; - // Index to all input nodes. - std::vector input_nodes_; - // Index to all mutable input nodes. - std::unordered_set mutable_input_nodes_; - // space to store the outputs entries - std::vector outputs_; - // mapping from node to index. - std::unordered_map node2index_; - // CSR pointer of node entries - std::vector entry_rptr_; - // space to store input entries of each - std::vector input_entries_; - // control flow dependencies - std::vector control_deps_; -}; - -/*! - * \brief perform a Post Order DFS visit to each node in the graph. - * This order is deterministic and is also topoligical sorted. - * \param heads The heads in the graph. - * \param fvisit a function of type std::function&)> - * \tparam FVisit The function type to perform the visit. - */ -template -inline void DFSVisit(const std::vector& heads, FVisit fvisit); - -// inline function implementations -template -inline const T& Graph::GetAttr(const std::string& attr_name) const { - auto it = attrs.find(attr_name); - CHECK(it != attrs.end()) - << "Cannot find attribute " << attr_name << " in the graph"; - return nnvm::get(*it->second); -} - -inline bool Graph::HasAttr(const std::string& attr_name) const { - auto it = attrs.find(attr_name); - return it != attrs.end(); -} - -template -inline T Graph::MoveCopyAttr(const std::string& attr_name) { - auto it = attrs.find(attr_name); - CHECK(it != attrs.end()) - << "Cannot find attribute " << attr_name << " in the graph"; - std::shared_ptr sptr = it->second; - attrs.erase(it); - if (sptr.unique()) { - return std::move(nnvm::get(*sptr)); - } else { - return nnvm::get(*sptr); - } -} - -template -void PostOrderDFSVisit(const std::vector& heads, - FVisit fvisit, - HashFunc hash, - InDegree indegree, - GetInput getinput) { - std::vector > stack; - std::unordered_set visited; - for (auto& head : heads) { - HashType head_hash = hash(head); - if (visited.count(head_hash) == 0) { - stack.push_back(std::make_pair(head, 0)); - visited.insert(head_hash); - } - while (!stack.empty()) { - std::pair& back = stack.back(); - if (back.second == indegree(back.first)) { - fvisit(back.first); - stack.pop_back(); - } else { - const GNode& input = getinput(back.first, back.second++); - HashType input_hash = hash(input); - if (visited.count(input_hash) == 0) { - stack.push_back(std::make_pair(input, 0)); - visited.insert(input_hash); - } - } - } - } -} - -template -inline void DFSVisit(const std::vector& heads, - FVisit fvisit) { - typedef const NodePtr* GNode; - std::vector head_nodes(heads.size()); - std::transform(heads.begin(), heads.end(), head_nodes.begin(), - [](const NodeEntry& e)->GNode { - return &e.node; - }); - PostOrderDFSVisit( - head_nodes, - [fvisit](GNode n) { fvisit(*n); }, // FVisit - [](GNode n)->Node* { return n->get(); }, // HashFunc - [](GNode n)->uint32_t { // InDegree - if (!(*n)) return 0; - return (*n)->inputs.size() + (*n)->control_deps.size(); - }, - [](GNode n, uint32_t index)->GNode { // GetInput - if (index < (*n)->inputs.size()) { - return &(*n)->inputs.at(index).node; - } else { - return &(*n)->control_deps.at(index - (*n)->inputs.size()); - } - }); -} - -} // namespace nnvm - -#endif // NNVM_GRAPH_H_ diff --git a/include/nnvm/graph_attr_types.h b/include/nnvm/graph_attr_types.h deleted file mode 100644 index 2fe82c9a7de0..000000000000 --- a/include/nnvm/graph_attr_types.h +++ /dev/null @@ -1,112 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file nnvm/graph_attr_types.h - * \brief Data structures that can appear in graph attributes. - */ -#ifndef NNVM_GRAPH_ATTR_TYPES_H_ -#define NNVM_GRAPH_ATTR_TYPES_H_ - -#include -#include -#include "tuple.h" -#include "layout.h" - -namespace nnvm { - -/*! - * \brief The result holder of JSON serializer - * - * \note Stored under ret.attrs["json"], provided by Pass "SaveJSON" - - * \code - * Graph ret = ApplyPass(src_graph, "SaveJSON"); - * const JSONString& json = ret.GetAttr("shape"); - * \endcode - */ -using JSONString = std::string; - -/*! - * \brief The result holder of shape of each NodeEntry in the graph. - * \note Stored under graph.attrs["shape"], provided by Pass "InferShape" - * - * \code - * Graph g = ApplyPass(src_graph, "InferShape"); - * const ShapeVector& shapes = g.GetAttr("shape"); - * // get shape by entry id - * TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)]; - * \endcode - * - * \sa FInferShape - */ -using ShapeVector = std::vector; - -/*! - * \brief The result holder of type of each NodeEntry in the graph. - * \note Stored under graph.attrs["dtype"], provided by Pass "InferType" - * - * \code - * Graph g = ApplyPass(src_graph, "InferType"); - * const DTypeVector& types = g.GetAttr("dtype"); - * // get type by entry id - * int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)]; - * \endcode - * - * \sa FInferType - */ -using DTypeVector = std::vector; - -/*! - * \brief The result holder of layout of each NodeEntry in the graph. - * \note Stored under graph.attrs["layout"], provided by Pass "InferType" - * - * \code - * Graph g = ApplyPass(src_graph, "LayoutTransform"); - * const LayoutVector& layouts = g.GetAttr("layout"); - * // get layout by entry id - * int entry_layout = layouts[g.indexed_graph().entry_id(my_entry)]; - * \endcode - * - * \sa FCorrectLayout - */ -using LayoutVector = std::vector; - -/*! - * \brief The result holder of device of each operator in the graph. - * \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice" - * - * \code - * Graph g = ApplyPass(src_graph, "PlaceDevice"); - * const &device = g.GetAttr("device"); - * // get device by node_id - * int device_type = device[g.indexed_graph().node_id(my_node)]; - * \endcode - */ -using DeviceVector = std::vector; - -/*! - * \brief The result holder of device of each operator in the graph. - * - * \note Stored under graph.attrs["device_assign_map"], needed by Pass "PlaceDevice" - * -1 means unknown device - */ -using DeviceAssignMap = std::unordered_map; - -/*! - * \brief The result holder of storage id of each NodeEntry in the graph. - * - * \note Stored under graph.attrs["storage"], provided by Pass "PlanMemory" - * Storage id is a continuous integer. - * If the storage id is -1 then the storage is not assigned. - * - * \code - * Graph g = ApplyPass(src_graph, "PlanMemory"); - * const &storage = g.GetAttr("storage"); - * // get storage id by entry - * int storage_id = storage[g.indexed_graph().entry_id(my_entry)]; - * \endcode - */ -using StorageVector = std::vector; - -} // namespace nnvm - -#endif // NNVM_GRAPH_ATTR_TYPES_H_ diff --git a/include/nnvm/layout.h b/include/nnvm/layout.h deleted file mode 100644 index 94813f5323f8..000000000000 --- a/include/nnvm/layout.h +++ /dev/null @@ -1,455 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file nnvm/layout.h - * \brief Layout expression. - * The layout is composed of upper cases, lower cases and numbers, - * where upper case indicates a (super-)dimension and - * the corresponding lower case with factor size indicates the split (sub-)dimension. - * For example, NCHW16c can describe a 5-D tensor of - * [batch_size, channel, height, width, channel_block]. - * Here sub-dimension channel_block=16 is the split of super-dimension C (channel). - */ -#ifndef NNVM_LAYOUT_H_ -#define NNVM_LAYOUT_H_ - -#include -#include -#include -#include -#include -#include - -namespace nnvm { - -class Layout { - public: - using LayoutDim = char; - - /*! \brief default constructor */ - Layout() : name_("__undef__") {} // NOLINT(*) - - /*! - * \brief construct from a string. - * \param layout input in layout convention: - * upper case indicates a dimension and - * the corresponding lower case with factor size - * indicates the split dimension. - * return undefined layout if "__undef__" is passed. - */ - inline Layout(const std::string& layout) { // NOLINT(*) - parse(layout); - } - /*! - * \brief copy constructor from another layout - * \param s the source layout - */ - inline Layout(const Layout& s) { // NOLINT(*) - this->parse(s.name_); - } - /*! - * \brief move constructor from Layout - * \param src the source layout - */ - inline Layout(Layout&& src) { // NOLINT(*) - this->swap(src); - } - /*! - * \brief assignment from another layout. - * \param src source layout - * \return reference of self - */ - inline Layout& operator=(const Layout& src) { - this->parse(src.name_); - return *this; - } - /*! - * \brief assignment from rvalue of another layout. - * \param src source layout - * \return reference of self - */ - inline Layout& operator=(Layout&& src) { - Layout(std::move(src)).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief assignment from string. - * \param src source layout - * \return reference of self - */ - inline Layout& operator=(const std::string& src) { - this->parse(src); - return *this; - } - /*! - * \return whether two layout equals - * \param s the layout to compare against - */ - inline bool operator==(const Layout& s) const { - return name_ == s.name_; - } - /*! - * \return whether two layout not equal - * \param s the layout to compare against - */ - inline bool operator!=(const Layout& s) const { - return !(*this == s); - } - - /*! - * \brief Append the current layout by another. - * @param other the layout to be appended - * @return a new layout - */ - inline Layout operator+(const Layout& other) const { - if (!this->defined() && !other.defined()) { - return Layout::Undef(); - } else if (!this->defined()) { - return other; - } else if (!other.defined()) { - return *this; - } - return Layout(this->name_ + other.name_); - } - - /*! - * \brief Check whether a given dimension is a super-dimension. - * \param dim input dimension - * \return Whether a given dimension is a super-dimension. - */ - static inline bool is_superdim(LayoutDim dim) { - return dim >= 'A' && dim <= 'Z'; - } - - /*! - * \brief Check whether a given dimension is a sub-dimension. - * \param dim input dimension - * \return Whether a given dimension is a sub-dimension. - */ - static inline bool is_subdim(LayoutDim dim) { - return dim >= 'a' && dim <= 'z'; - } - - /*! - * \brief Convert a given dimension to super-dimension. - * \param dim input dimension - * \return The converted description. - */ - static inline LayoutDim to_superdim(LayoutDim dim) { - if (is_subdim(dim)) { - return dim - 'a' + 'A'; - } - return dim; - } - - /*! - * \brief Convert a given dimension to sub-dimension. - * \param dim input dimension - * \return The converted description. - */ - static inline LayoutDim to_subdim(LayoutDim dim) { - if (is_superdim(dim)) { - return dim - 'A' + 'a'; - } - return dim; - } - - /*! - * \brief Return an undefined layout. - * \return a (global) undefined layout. - */ - static inline const Layout& Undef() { - static Layout undef; - return undef; - } - - /*! - * \brief Swap current object with other - * \param other another object to be swapped. - */ - inline void swap(Layout& other) { // NOLINT(*) - std::swap(name_, other.name_); - std::swap(superdim_pos_, other.superdim_pos_); - std::swap(subdim_pos_, other.subdim_pos_); - std::swap(subdim_size_, other.subdim_size_); - std::swap(layout_simplified_, other.layout_simplified_); - } - - /*! - * \brief Two layouts are convertible only if - * they have same set of super-dimensions. - * e.g., NCHW, NCHW16c, NHWC are convertible between each other, - * but NCHW, CHW, OIHW are not. - * \param dst the target layout - * \return Whether can be converted to dst layout. - */ - inline bool convertible(const Layout &dst) const { - if (!this->defined() || !dst.defined()) return false; - for (size_t i = 0; i < kUniqueDim; ++i) { - if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) || - (superdim_pos_[i] < 0 && dst.superdim_pos_[i] >= 0)) { - return false; - } - } - return true; - } - - /*! - * \brief Returns a sublayout which is the portion of the object - * that starts at dimension \p pos and spans \p len dimensions - * (or until the end of the layout, whichever comes first). - * \param pos The start position. - * \param len The length of the sub-layout. - * \return A newly constructed Layout object. - */ - inline Layout sublayout(size_t pos, size_t len) const { - if (pos > ndim()) return Layout::Undef(); - if (pos + len > ndim()) len = ndim() - pos; - if (len == 0) return Layout::Undef(); - std::ostringstream new_layout; - for (size_t i = pos; i < pos + len; ++i) { - if (is_subdim(layout_simplified_[i])) { - auto block_size = this->subsizeof(layout_simplified_[i]); - CHECK_GT(block_size, 0); - new_layout << block_size; - } - new_layout << layout_simplified_[i]; - } - return Layout(new_layout.str()); - } - - /*! \return A newly constructed reversed Layout object. */ - inline Layout reverse() const { - if (!this->defined()) return Layout::Undef(); - std::ostringstream new_layout; - for (int64_t i = this->ndim() - 1; i >= 0; --i) { - if (is_subdim(layout_simplified_[i])) { - auto block_size = this->subsizeof(layout_simplified_[i]); - CHECK_GT(block_size, 0); - new_layout << block_size; - } - new_layout << layout_simplified_[i]; - } - return Layout(new_layout.str()); - } - - /*! - * \brief Split \p dim by \p size and put the sub-dimension to position \p target_pos. - * \param dim The source dimension to be split. It must be a super-dimension. - * \param target_pos The target position of the newly split sub-dimension. - * \param size size of the sub-dimension. - * \return A newly constructed Layout object. - */ - inline Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const { - CHECK(target_pos <= this->ndim()) << "Invalid split position " - << target_pos << " for layout " << name_; - CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim; - CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_; - CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim - << " has already been split in " - << name_; - CHECK(size > 0) << "Invalid split size " << size; - std::ostringstream new_layout; - for (size_t i = 0; i <= this->ndim(); ++i) { - if (i == target_pos) { - new_layout << size << Layout::to_subdim(dim); - } - if (i == this->ndim()) break; - new_layout << this->at(i); - } - Layout x(new_layout.str()); - return x; - } - - using iterator = std::vector::const_iterator; - using reverse_iterator = std::vector::const_reverse_iterator; - - /*! \return begin iterator */ - inline iterator begin() const { - return layout_simplified_.begin(); - } - /*! \return end iterator */ - inline iterator end() const { - return layout_simplified_.end(); - } - /*! \return rbegin iterator */ - inline reverse_iterator rbegin() const { - return layout_simplified_.rbegin(); - } - /*! \return rend iterator */ - inline reverse_iterator rend() const { - return layout_simplified_.rend(); - } - - /*! \return number of dimensions */ - inline size_t ndim() const { - return layout_simplified_.size(); - } - - /*! - * \brief The description of the \p i-th dimension. - * If it is a sub-dimension, the size will be returned as well, - * e.g., 16c. Otherwise a single character is returned, e.g., C. - * \param i The position - * \return the description of the dimension. - */ - inline std::string at(size_t i) const { - CHECK_LT(i, this->ndim()) << "position " << i - << " exceeds ndim=" << this->ndim(); - std::ostringstream repr; - if (is_subdim(layout_simplified_[i])) { - auto factor = subsizeof(layout_simplified_[i]); - CHECK_GT(factor, 0); - repr << factor; - } - repr << layout_simplified_[i]; - return repr.str(); - } - - /*! - * \brief return the index of the input dimension. - * If it is not found in the layout or the layout is undefined, - * return -1. - * \param dim the input dimension. - * \return the index or -1 if not found. - */ - inline int32_t indexof(LayoutDim dim) const { - if (!this->defined()) return -1; - else if (is_superdim(dim)) return superdim_pos_[dim - 'A']; - else if (is_subdim(dim)) return subdim_pos_[dim - 'a']; - return -1; - } - - /*! - * \param dim the input super-dimension or sub-dimension. - * \return the size of the sub-dimension of \p dim (if \p dim is a super-dimension), - * or the size of \p dim itself (if \p dim is a sub-dimension). - * Return -1 if \p dim is not in the layout or the layout is undefined. - */ - inline int64_t subsizeof(LayoutDim dim) const { - CHECK(is_superdim(dim) || is_subdim(dim)) << "Invalid dim " << dim; - if (!this->defined() || !this->contains(to_subdim(dim))) { - return -1; - } - int idx = to_subdim(dim) - 'a'; - return subdim_size_[idx]; - } - - /*! - * \brief Whether the layout contains a dimension. - * \param dim dimension to be checked. - * \return Whether the layout contains the dimension. - */ - inline bool contains(LayoutDim dim) const { - if (is_superdim(dim)) { - return superdim_pos_[dim-'A'] >= 0; - } else if (is_subdim(dim)) { - return subdim_pos_[dim-'a'] >= 0; - } - return false; - } - - inline LayoutDim operator[](size_t i) const { - return layout_simplified_[i]; - } - - /*! \return whether the layout is defined */ - inline bool defined() const { - return name_ != "__undef__"; - } - - /*! \return the string description of the layout */ - inline const std::string& name() const { - return name_; - } - - /*! - * \brief Write layout in JSON format. - * \param writer JSONWriter - */ - inline void Save(dmlc::JSONWriter* writer) const { - writer->Write(name_); - } - - /*! - * \brief Load layout from JSON. - * \param reader JSONReader - */ - inline void Load(dmlc::JSONReader* reader) { - std::string tmp; - reader->Read(&tmp); - this->parse(tmp); - } - - /*! - * \brief allow output string of layout to ostream - * \param os the output stream - * \param l the layout - * \return the ostream - */ - friend std::ostream& operator<<(std::ostream& os, const Layout& l) { - os << l.name_; - return os; - } - - private: - static const uint32_t kUniqueDim = 26; - - std::string name_; - int32_t superdim_pos_[kUniqueDim]; - int32_t subdim_pos_[kUniqueDim]; - int64_t subdim_size_[kUniqueDim]; - std::vector layout_simplified_; - - void parse(const std::string& layout) { - name_ = layout; - std::fill_n(superdim_pos_, kUniqueDim, -1); - std::fill_n(subdim_pos_, kUniqueDim, -1); - std::fill_n(subdim_size_, kUniqueDim, -1); - layout_simplified_.clear(); - - if (layout == "__undef__") return; - - int32_t factor = 0; - uint32_t curr = 0; - for (size_t i = 0; i < layout.size(); ++i) { - const LayoutDim c = layout.at(i); - if (is_superdim(c)) { - int pos = c - 'A'; - CHECK_EQ(factor, 0) << "Invalid layout " << layout - << ": invalid factor size " << factor - << " before dimension " << c; - CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; - superdim_pos_[pos] = curr++; - layout_simplified_.push_back(c); - } else if (is_subdim(c)) { - int pos = c - 'a'; - CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size " - << factor << " for dimension " << c; - CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; - CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; - subdim_pos_[pos] = curr++; - subdim_size_[pos] = factor; - layout_simplified_.push_back(c); - factor = 0; - } else if (c >= '0' && c <= '9') { - CHECK(factor >= 0) << "Invalid layout " << layout << ": _ is adjacent to a number."; - factor = factor * 10 + c - '0'; - } else { - LOG(FATAL) << "Invalid layout " << layout; - } - } - CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout; - for (LayoutDim dim : layout_simplified_) { - CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0) - << "Invalid layout " << layout << ": missing axis " - << static_cast(dim - 'a' + 'A'); - } - } -}; - -} // namespace nnvm - -#endif // NNVM_LAYOUT_H_ diff --git a/include/nnvm/node.h b/include/nnvm/node.h deleted file mode 100644 index ae782f04965e..000000000000 --- a/include/nnvm/node.h +++ /dev/null @@ -1,201 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file nnvm/node.h - * \brief Graph node data structure. - */ -#ifndef NNVM_NODE_H_ -#define NNVM_NODE_H_ - -#include -#include -#include -#include -#include "base.h" -#include "op.h" -#include "c_api.h" - -namespace nnvm { - -// Forward declare node. -class Node; -class Symbol; - -/*! - * \brief we always used NodePtr for a reference pointer - * to the node, so this alias can be changed in case. - * - * By default, NodePtr is a std::shared_ptr of node - */ -using NodePtr = std::shared_ptr; - -/*! \brief an entry that represents output data from a node */ -struct NodeEntry { - /*! \brief the source node of this data */ - NodePtr node; - /*! \brief index of output from the source. */ - uint32_t index; - /*! - * \brief version of input Variable. - * This field can only be nonzero when this->node is a Variable node. - * version is increased by one each time a Variable get composed to a mutation Op. - * This information can be helpful to decide order of operations when sequence of mutation happens. - */ - uint32_t version; -}; - -/*! - * \brief This lets you use a NodeEntry as a key in a unordered_map of the form - * unordered_map - */ -struct NodeEntryHash { - size_t operator()(const NodeEntry& e) const { - return std::hash()(e.node.get()) ^ - (std::hash()(e.index) << 1 >> 1) ^ - (std::hash()(e.version) << 1); - } -}; - -/*! - * \brief This lets you use a NodeEntry as a key in a unordered_map of the form - * unordered_map - */ -struct NodeEntryEqual { - size_t operator()(const NodeEntry& a, const NodeEntry& b) const { - return (a.node.get() == b.node.get()) && - (a.index == b.index) && - (a.version == b.version); - } -}; - -/*! use NodeEntry as key in unordered_map */ -template -using NodeEntryMap = std::unordered_map; - -/*! - * \brief The attributes of the current operation node. - * Usually are additional parameters like axis, - */ -struct NodeAttrs { - /*! - * \brief The operator this node uses. - * For place holder variable, op == nullptr. - */ - const Op *op{nullptr}; - /*! \brief name of the node */ - std::string name; - /*! \brief The dictionary representation of attributes */ - std::unordered_map dict; - /*! - * \brief A parsed version of attributes, - * This is generated if OpProperty.attr_parser is registered. - * The object can be used to quickly access attributes. - */ - any parsed; - /*! - * \brief Some operators take graphs as input. These operators include - * control flow operators and high-order functions. - * These graphs don't change when the operators are invoked for different - * mini-batches. In this sense, the subgraphs are kind of similar to - * the parameters and show be kept as node attributes. - * - * Users need to make sure the subgraphs are disjoint with the main graph. - * If a graph shares nodes with subgraphs, loading the graph from LoadJSON - * may generate a graph that has a different structure from the original graph - * (some of the nodes are duplicated). If nodes are shared between two graphs, - * shared nodes might be executed multiple times, which can be a problem for - * stateful operators. - */ - std::vector > subgraphs; -}; - -/*! - * \brief Node represents an operation in a computation graph. - */ -class NNVM_DLL Node { - public: - /*! \brief The attributes in the node. */ - NodeAttrs attrs; - /*! \brief inputs to this node */ - std::vector inputs; - /*! - * \brief Optional control flow dependencies - * Gives operation must be performed before this operation. - */ - std::vector control_deps; - /*! \brief additional fields for this node */ - any info; - /*! \brief destructor of node */ - ~Node(); - /*! \return operator in this node */ - inline const Op* op() const; - /*! - * \brief return whether node is placeholder variable. - * This is equivalent to op == nullptr - * \return whether node is placeholder input variable - */ - inline bool is_variable() const; - /*! \return number of outputs from this node */ - inline uint32_t num_outputs() const; - /*! \return number of inputs from this node */ - inline uint32_t num_inputs() const; - /*! - * \brief create a new empty shared_ptr of Node. - * \return a created empty node. - */ - static NodePtr Create(); -}; - -/*! - * \brief Quick utilities make node. - * \param op_name The name of operator - * \param node_name The name of the node - * \param inputs The input entries - * \param attrs The attributes - * \return The created node entry. - */ -inline NodeEntry MakeNode( - const char* op_name, - std::string node_name, - std::vector inputs, - std::unordered_map attrs = - std::unordered_map()) { - NodePtr p = Node::Create(); - p->attrs.op = nnvm::Op::Get(op_name); - p->attrs.name = std::move(node_name); - p->attrs.dict = attrs; - if (p->attrs.op->attr_parser) { - p->attrs.op->attr_parser(&(p->attrs)); - } - p->inputs = std::move(inputs); - return NodeEntry{p, 0, 0}; -} - -// implementation of functions. -inline const Op* Node::op() const { - return this->attrs.op; -} -inline bool Node::is_variable() const { - return this->op() == nullptr; -} - -inline uint32_t Node::num_outputs() const { - if (is_variable()) return 1; - if (this->op()->get_num_outputs == nullptr) { - return this->op()->num_outputs; - } else { - return this->op()->get_num_outputs(this->attrs); - } -} - -inline uint32_t Node::num_inputs() const { - if (is_variable()) return 1; - if (this->op()->get_num_inputs == nullptr) { - return this->op()->num_inputs; - } else { - return this->op()->get_num_inputs(this->attrs); - } -} - -} // namespace nnvm - -#endif // NNVM_NODE_H_ diff --git a/include/nnvm/op.h b/include/nnvm/op.h deleted file mode 100644 index 9d171bbdb2bc..000000000000 --- a/include/nnvm/op.h +++ /dev/null @@ -1,562 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file nnvm/op.h - * \brief Operator information structor. - */ -#ifndef NNVM_OP_H_ -#define NNVM_OP_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "base.h" -#include "c_api.h" - -namespace nnvm { - -// forward declarations -class Node; -struct NodeAttrs; -template -class OpMap; -class OpGroup; -class OpRegistryEntry; -using dmlc::ParamFieldInfo; - -/*! \brief constant to indicate it take any length of positional inputs */ -static const uint32_t kVarg = std::numeric_limits::max(); - -/*! - * \brief Operator structure. - * - * Besides the fields in the structure, - * arbitary additional information can be associated with each op. - * See function GetAttr for details. - * - * \code - * // Example usage of Op - * - * // registeration of oeprators - * // NOTE that the attr function can register any - * // additional attributes to the operator - * NNVM_REGISTER_OP(add) - * .describe("add two inputs together") - * .set_num_inputs(2) - * .set_attr("OpKernel", AddKernel) - * .include("ElementwiseOpAttr"); - * - * // can register attribute by group - * // all the ops that include the group get the attribute. - * NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr) - * .set_attr("FInferShape", ElementwiseInferShape); - * - * NNVM_REGISTER_OP(sub) - * .describe("substract one tensor from another") - * .set_num_inputs(2); - * - * // Can call regster multiple times in different files - * // to register different part of information - * NNVM_REGISTER_OP(sub) - * .set_attr("OpKernel", SubKernel); - * .include("ElementwiseOpAttr"); - * - * // get operators from registry. - * void my_function() { - * const Op* add = Op::Get("add"); - * const Op* sub = Op::Get("sub"); - * // query basic information about each operator. - * assert(op->name == "plus"); - * assert(op->num_inputs == 2); - * - * // get additional registered information, - * // Assume user registered a OpKernel type attribute as gpu_kernel on each operator. - * const OpMap& kernel = Op::GetAttr("OpKernel"); - * // we can get the kernel functions by using operator as key. - * auto add_kernel = kernel[add]; - * auto sub_kernel = kernel[sub]; - * // subsequent code can make use of the queried kernel functions. - * } - * \endcode - */ -class NNVM_DLL Op { - public: - /*! \brief name of the operator */ - std::string name; - /*! - * \brief detailed description of the operator - * This can be used to generate docstring automatically for the operator. - */ - std::string description; - /* \brief description of inputs and keyword arguments*/ - std::vector arguments; - /*! - * \brief number of inputs to the operator, - * -1 means it is variable length - * When get_num_inputs is presented, - * the number will be decided by get_num_inputs instead. - * \sa get_num_inputs - */ - uint32_t num_inputs = 1; - /*! - * \brief number of outputs of the operator - * When get_num_outputs is presented. - * The number of outputs will be decided by - * get_num_outputs function - * \sa get_num_outputs - */ - uint32_t num_outputs = 1; - /*! - * \brief support level of the operator, - * The lower the more priority it contains. - * This is in analogies to BLAS levels. - */ - uint32_t support_level = 10; - /*! - * \brief get number of outputs given information about the node. - * \param attrs The attribute of the node - * \return number of outputs. - */ - std::function get_num_outputs = nullptr; - /*! - * \brief get number of inputs given information about the node. - * \param attrs The attribute of the node - * \return number of inputs - */ - std::function get_num_inputs = nullptr; - /*! - * \brief Attribute parser to parse the NodeAttrs information. - * - * This can help to get quick access to a parsed attribute - * object - * - * \code - * // Example usage of attr_parser. - * - * // Suppose we want to register operator sum. - * // The parameters about sum operator - * struct SumParam { - * int axis; - * }; - * // The parser function - * void SumAttrParser(NodeAttrs* attrs) { - * // This will be invoked during node construction. - * SumParam param; - * // parse axis string to integer - * param.axis = atoi(attrs->dict["axis"].c_str()); - * // set the parsed parameter - * attrs->parsed = std::move(param); - * } - * // The other function that can utilize the parsed result. - * TShape SumInferShape(const NodeAttrs& attrs, - * const std::vector& ishapes) { - * // we can use the parsed version of param - * // without repeatively parsing the parameter - * const SumParam& param = nnvm::get(attrs.parsed); - * } - * \endcode - */ - std::function attr_parser = nullptr; - // function fields. - /*! - * \brief setter function during registration - * Set the description of operator - * \param descr the description string. - * \return reference to self. - */ - inline Op& describe(const std::string& descr); // NOLINT(*) - /*! - * \brief Add argument information to the function. - * \param name Name of the argument. - * \param type Type of the argument. - * \param description Description of the argument. - * \return reference to self. - */ - inline Op& add_argument(const std::string &name, - const std::string &type, - const std::string &description); - /*! - * \brief Append list if arguments to the end. - * \param args Additional list of arguments. - * \return reference to self. - */ - inline Op& add_arguments(const std::vector &args); - /*! - * \brief Set the num_inputs - * \param n The number of inputs to be set. - * \return reference to self. - */ - inline Op& set_num_inputs(uint32_t n); // NOLINT(*) - /*! - * \brief Set the support level of op. - * \param level The support level. - * \return reference to self. - */ - inline Op& set_support_level(uint32_t level); // NOLINT(*) - /*! - * \brief Set the get_num_outputs function. - * \param fn The function to be set. - * \return reference to self. - */ - inline Op& set_num_inputs(std::function fn); // NOLINT(*) - /*! - * \brief Set the num_outputs - * \param n The number of outputs to be set. - * \return reference to self. - */ - inline Op& set_num_outputs(uint32_t n); // NOLINT(*) - /*! - * \brief Set the get_num_outputs function. - * \param fn The function to be set. - * \return reference to self. - */ - inline Op& set_num_outputs(std::function fn); // NOLINT(*) - /*! - * \brief Set the attr_parser function. - * \param fn The number of outputs to be set. - * \return reference to self. - */ - inline Op& set_attr_parser(std::function fn); // NOLINT(*) - /*! - * \brief Register additional attributes to operator. - * \param attr_name The name of the attribute. - * \param value The value to be set. - * \param plevel The priority level of this set, - * an higher priority level attribute - * will replace lower priority level attribute. - * Must be bigger than 0. - * - * Cannot set with same plevel twice in the code. - * - * \tparam ValueType The type of the value to be set. - */ - template - inline Op& set_attr(const std::string& attr_name, // NOLINT(*) - const ValueType& value, - int plevel = 10); - /*! - * \brief Add another alias to this operator. - * The same Op can be queried with Op::Get(alias) - * \param alias The alias of the operator. - * \return reference to self. - */ - Op& add_alias(const std::string& alias); // NOLINT(*) - /*! - * \brief Include all the attributes from an registered op group. - * \param group_name The name of the group. - * \return reference to self. - * - * \sa NNVM_REGISTER_OP_GROUP - */ - Op& include(const std::string& group_name); - /*! - * \brief Get an Op for a given operator name. - * Will raise an error if the op has not been registered. - * \param op_name Name of the operator. - * \return Pointer to a Op, valid throughout program lifetime. - */ - static const Op* Get(const std::string& op_name); - /*! - * \brief Get additional registered attribute about operators. - * If nothing has been registered, an empty OpMap will be returned. - * \param attr_name The name of the attribute. - * \return An OpMap of specified attr_name. - * \tparam ValueType The type of the attribute. - */ - template - static const OpMap& GetAttr(const std::string& attr_name); - - private: - template - friend class OpMap; - friend class OpGroup; - friend class dmlc::Registry; - // Program internal unique index of operator. - // Used to help index the program. - uint32_t index_{0}; - // internal constructor - Op(); - // get const reference to certain attribute - static const any* GetAttrMap(const std::string& key); - // update the attribute OpMap - static void UpdateAttrMap(const std::string& key, - std::function updater); - // add a trigger based on tag matching on certain tag attribute - // This will apply trigger on all the op such that - // include the corresponding group. - // The trigger will also be applied to all future registrations - // that calls include - static void AddGroupTrigger(const std::string& group_name, - std::function trigger); -}; - -/*! - * \brief A map data structure that takes Op* as key - * and returns ValueType - * \tparam ValueType The type of the value stored in map. - */ -template -class OpMap { - public: - /*! - * \brief get the corresponding value element at op - * \param op The key to the map - * \return the const reference to the content value. - */ - inline const ValueType& operator[](const Op* op) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param op The key to the map - * \param def_value The default value when the key does not exist. - * \return the const reference to the content value. - */ - inline const ValueType& get(const Op* op, const ValueType& def_value) const; - /*! - * \brief Check if the map has op as key. - * \param op The key to the map - * \return 1 if op is contained in map, 0 otherwise. - */ - inline int count(const Op* op) const; - - private: - friend class Op; - // internal attribute name - std::string attr_name_; - // internal data - std::vector > data_; - OpMap() = default; -}; - -/*! - * \brief auxiliary data structure used to - * set attributes to a group of operators - */ -class OpGroup { - public: - /*! \brief the tag key to be matched */ - std::string group_name; - /*! - * \brief Register additional attributes to operator group. - * \param attr_name The name of the attribute. - * \param value The value to be set. - * \param plevel The priority level of this set, - * an higher priority level attribute - * will replace lower priority level attribute. - * Must be bigger than 0. - * - * Cannot set with same plevel twice in the code. - * - * \tparam ValueType The type of the value to be set. - */ - template - inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*) - const ValueType& value, - int plevel = 1); -}; - -// internal macros to make -#define NNVM_REGISTER_VAR_DEF(OpName) \ - static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName - -#define NNVM_REGISTER_GVAR_DEF(TagName) \ - static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName - -/*! - * \def NNVM_REGISTER_OP - * \brief Register a new operator, or set attribute of the corresponding op. - * - * \param OpName The name of registry - * - * \code - * - * NNVM_REGISTER_OP(add) - * .describe("add two inputs together") - * .set_num_inputs(2) - * .set_attr("gpu_kernel", AddKernel); - * - * \endcode - */ -#define NNVM_REGISTER_OP(OpName) \ - DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ - ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName) - -/*! - * \def NNVM_REGISTER_OP_GROUP - * \brief Register attribute to a group of operators. - * These attributes will be registered to Op that include the group. - * - * \param GroupName The name of the group. - * - * \code - * - * NNVM_REGISTER_OP(add) - * .include("ElementwiseOpAttr"); - * - * // register same attributes to all the ops that include the group - * NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr) - * .set_attr("FInferShape", ElementwiseInferShape); - * - * NNVM_REGISTER_OP(mul) - * .include("ElementwiseOpAttr"); - * - * \endcode - */ -#define NNVM_REGISTER_OP_GROUP(GroupName) \ - DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \ - ::nnvm::OpGroup {#GroupName} - -// implementations of template functions after this. -// member function of Op -template -inline const OpMap& Op::GetAttr(const std::string& key) { - const any* ref = GetAttrMap(key); - if (ref == nullptr) { - // update the attribute map of the key by creating new empty OpMap - UpdateAttrMap(key, [key](any* pmap) { - // use callback so it is in lockscope - if (pmap->empty()) { - OpMap pm; - pm.attr_name_ = key; - *pmap = std::move(pm); - } - }); - ref = GetAttrMap(key); - } - return nnvm::get >(*ref); -} - -template -inline Op& Op::set_attr( // NOLINT(*) - const std::string& attr_name, - const ValueType& value, - int plevel) { - CHECK_GT(plevel, 0) - << "plevel in set_attr must be greater than 0"; - // update the attribute map of the key by creating new empty if needed. - UpdateAttrMap(attr_name, - [this, attr_name, value, plevel](any* pmap) { - // the callback is in lockscope so is threadsafe. - if (pmap->empty()) { - OpMap pm; - pm.attr_name_ = attr_name; - *pmap = std::move(pm); - } - CHECK(pmap->type() == typeid(OpMap)) - << "Attribute " << attr_name - << " of operator " << this->name - << " is registered as inconsistent types" - << " previously " << pmap->type().name() - << " current " << typeid(OpMap).name(); - std::vector >& vec = - nnvm::get >(*pmap).data_; - // resize the value type. - if (vec.size() <= index_) { - vec.resize(index_ + 1, - std::make_pair(ValueType(), 0)); - } - std::pair& p = vec[index_]; - CHECK(p.second != plevel) - << "Attribute " << attr_name - << " of operator " << this->name - << " is already registered with same plevel=" << plevel; - if (p.second < plevel) { - vec[index_] = std::make_pair(value, plevel); - } - }); - return *this; -} - - -inline Op& Op::describe(const std::string& descr) { // NOLINT(*) - this->description = descr; - return *this; -} - -inline Op& Op::add_argument(const std::string &name, - const std::string &type, - const std::string &description) { - arguments.push_back({name, type, type, description}); - return *this; -} - -inline Op& Op::add_arguments(const std::vector &args) { - this->arguments.insert(arguments.end(), args.begin(), args.end()); - return *this; -} - -inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*) - this->num_inputs = n; - return *this; -} - -inline Op& Op::set_support_level(uint32_t n) { // NOLINT(*) - this->support_level = n; - return *this; -} - -inline Op& Op::set_num_inputs(std::function fn) { // NOLINT(*) - this->get_num_inputs = fn; - return *this; -} - -inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*) - this->num_outputs = n; - return *this; -} - -inline Op& Op::set_num_outputs(std::function fn) { // NOLINT(*) - this->get_num_outputs = fn; - return *this; -} - -inline Op& Op::set_attr_parser(std::function fn) { // NOLINT(*) - this->attr_parser = fn; - return *this; -} - -// member functions of OpMap -template -inline int OpMap::count(const Op* op) const { - if (op == nullptr) return 0; - const uint32_t idx = op->index_; - return idx < data_.size() ? (data_[idx].second != 0) : 0; -} - -template -inline const ValueType& OpMap::operator[](const Op* op) const { - CHECK(op != nullptr); - const uint32_t idx = op->index_; - CHECK(idx < data_.size() && data_[idx].second) - << "Attribute " << attr_name_ - << " has not been registered for Operator " << op->name; - return data_[idx].first; -} - -template -inline const ValueType& OpMap::get(const Op* op, const ValueType& def_value) const { - if (op == nullptr) return def_value; - const uint32_t idx = op->index_; - if (idx < data_.size() && data_[idx].second) { - return data_[idx].first; - } else { - return def_value; - } -} - -template -inline OpGroup& OpGroup::set_attr(const std::string& attr_name, - const ValueType& value, - int plevel) { - auto trigger = [attr_name, value, plevel](Op* op) { - op->set_attr(attr_name, value, plevel); - }; - Op::AddGroupTrigger(group_name, trigger); - return *this; -} - -} // namespace nnvm - -#endif // NNVM_OP_H_ diff --git a/include/nnvm/op_attr_types.h b/include/nnvm/op_attr_types.h deleted file mode 100644 index abed19f9bc7d..000000000000 --- a/include/nnvm/op_attr_types.h +++ /dev/null @@ -1,219 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file nnvm/op_attr_types.h - * \brief Data structures that can appear in operator attributes. - */ -#ifndef NNVM_OP_ATTR_TYPES_H_ -#define NNVM_OP_ATTR_TYPES_H_ - -#include -#include -#include -#include -#include "base.h" -#include "node.h" -#include "tuple.h" -#include "layout.h" - -namespace nnvm { - -// These types are optional attributes in each operator. -// Each attribute can be required by some passes. - -/*! - * \brief Return list of input arguments names of each operator. - * - * \param attrs The attributes of the node. - * \return list of inputs - * \note Register under "FListInputNames", default return {"data"}. - * - * FListInputNames enables automatic variable creation for missing arguments. - */ -using FListInputNames = std::function (const NodeAttrs& attrs)>; - -/*! - * \brief Return number of visible outputs by the user. - * - * \param attrs The attributes of the node. - * - * \note Register under "FNumVisibleOutputs", default not registered. - * This can be used to hide certain output from the user, - * but the additional outputs can be used to pass information from - * forward to gradient pass. - */ -using FNumVisibleOutputs = std::function; - -/*! - * \brief Return list of output arguments names of each operator. - * - * \param attrs The attributes of the node. - * \return list of inputs - * \note Register under "FListOutputNames", default return {"outputs"}. - * - * FListOutputNames customized naming for operator outputs. - */ -using FListOutputNames = std::function (const NodeAttrs& attrs)>; - -/*! - * \brief Check whether operator will mutate k-th input. - * \param attrs The attributes of the node. - * \return list of input indices it mutates. - * - * \note Register under "FMutateInputs", default return false - * FMutateInputs enables mutation order handling correctly. - */ -using FMutateInputs = std::function (const NodeAttrs& attrs)>; - -/*! - * \brief Inference function of certain type. - * \tparam AttrType The type of the attribute to be infered. - * \return whether all attributes are inferred. - */ -template -using FInferNodeEntryAttr = std::function *in_attrs, - std::vector *out_attrs)>; - -/*! - * \brief Get attribute dictionary from node. - * - * \param attrs The attributes of the node. - * \return The attribute dict. - * \note Register under "FUpdateAttrDict" - */ -using FGetAttrDict = std::function< - std::unordered_map - (const NodeAttrs& attrs)>; - -/*! - * \brief Shape inference function. - * Update the shapes given the input shape information. - * TShape.ndim() == 0 means the shape is still unknown. - * - * \note Register under "FInferShape", - * by default do not update any shapes. - * - * FInferShape is needed by shape inference - */ -using FInferShape = FInferNodeEntryAttr; - -/*! - * \brief Type inference function. - * Update the type given the known type information. - * - * \note Register under "FInferType", - * by default set all the output types to 0. - */ -using FInferType = FInferNodeEntryAttr; - -/*! - * \brief Whether this op is an explicit backward operator, - * If TIsBackward is true: - * - The first control_deps of the node points to the corresponding forward operator. - * - * \note Register under "TIsBackward" - * This enables easier shape/type inference for backward operators. - */ -using TIsBackward = bool; - -/*! - * \brief Get possible inplace options. - * This function enables optimization to reuse memory of inputs in output. - * \param attrs The attributes of the node - * \return list of pair of that maps input->output, - * indicating possible in place operations. - * - * \note Register under "FInplaceOption", by default no inplace can happen. - */ -using FInplaceOption = std::function< - std::vector > (const NodeAttrs& attrs)>; - -/*! - * \brief Get if the inplace option is an identity - * This function enables inplace optimization even when input reference count - * is greater than one. - * \param attrs The attributes of the node - * \return list of bool indicating whether corresponding pair from FInplaceOption - * is an identity - * - * \note Register under "FInplaceIdentity", by default no identities. - */ -using FInplaceIdentity = std::function (const NodeAttrs& attrs)>; - -/*! - * \brief Get list of inputs in the op whose content are actually not used by the operator - * These are dummy input that can be used for example in zeros_like, ones_like. - * - * \param attrs The attributes of the node - * \return list input index that are not used by the operator. - * - * \note Register under "FIgnoreInputs". - */ -using FIgnoreInputs = std::function< - std::vector (const NodeAttrs& attrs)>; - -/*! - * \brief Get the gradient node of the op node - * This function generates the backward graph of the node - * \param nodeptr The node to take gradient - * \param out_grads Gradient of current node's outputs - * \return gradients of the inputs - * - * \note Register under "FGradient" - */ -using FGradient = std::function( - const NodePtr& nodeptr, - const std::vector& out_grads)>; - -/*! - * \brief Set the attributes of input variable. - * Usually used for setting initialization or weight decay. - * \param attrs The attributes of this node. - * \param var the input variable - * \param index index of var in all inputs - */ -using FSetInputVarAttrOnCompose = std::function; - -/*! - * \brief Infer & correct function of node layout. See \p Layout for layout convention - * \param attrs The attribute of the node. - * \param ilayouts Given the input layouts produced by ancestor nodes, - * it should be filled by layouts that the node requests. - * If the requested layout is different from what ancestor produces, - * a __layout_transform__ operator will be inserted automatically. - * \param last_ilayouts The input layouts requested by the node - * at the last infer pass (if any). - * This can be useful when an operator wants to keep - * the input layout the same as the original one. - * For example, after the pass of AlterOpLayout, - * transpose(input, axis=[1, 2, 3, 0]) may receive an input of NCHW16c layout, - * with which it cannot calculate with axis=[1, 2, 3, 0]. - * Last input layouts allow it to know what the layout it originally inferred, - * i.e., the layout in the imported model. - * \param olayouts Inferred output layouts. - * \return success flag. - */ -using FCorrectLayout = std::function *ilayouts, - const std::vector *last_ilayouts, - std::vector *olayouts)>; - -/*! - * \brief Get a list of inputs that represent graphs instead of data. - * Normally, input symbols are considered as data to the operator. However, - * control flow operators and high-order functions need to interpret symbols - * as graphs. - * \param attrs The attributes of this node. - * \return a list of input index that are interpreted as symbols by the operator. - * - * \note Register under "FInputGraph". - */ -using FInputGraph = std::function(const NodeAttrs& attrs)>; - -} // namespace nnvm - -#endif // NNVM_OP_ATTR_TYPES_H_ diff --git a/include/nnvm/pass.h b/include/nnvm/pass.h deleted file mode 100644 index 2e8db6111887..000000000000 --- a/include/nnvm/pass.h +++ /dev/null @@ -1,128 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file nnvm/pass.h - * \brief Pass that can be applied to a graph. - */ -#ifndef NNVM_PASS_H_ -#define NNVM_PASS_H_ - -#include -#include -#include "base.h" -#include "graph.h" - -namespace nnvm { - -/*! - * \brief A PassFunction is an "Operator on Graph". - * It takes a source graph and return a graph that may or may - * not be the same as the input one. - * - * A pass function can either change the graph structure (thus, - * generating a new Graph), or add new attributes to the graph. - * - * \param src The graph to be transformed. - * \return The generated graph. - */ -typedef std::function PassFunction; - -/*! - * \brief Apply a series of pass transformations on the input graph. - * \param src The graph to be transformed. - * \param passes A list of pass names to be applied. - * \return The transformed graph - */ -Graph ApplyPasses(Graph src, - const std::vector& passes); - -/*! - * \brief Apply one pass to the graph. - * \param src The graph to be transformed. - * \param pass The name of pass to be applied. - * \return The transformed graph. - */ -inline Graph ApplyPass(Graph src, const std::string& pass) { - return ApplyPasses(src, {pass}); -} - - -/*! - * \brief Registry entry for pass functions. - */ -struct PassFunctionReg - : public dmlc::FunctionRegEntryBase { - /*! - * \brief Whether the pass will change graph structure - * If this is false, the pass will only change attributes. - */ - bool change_graph{false}; - /*! \brief dependencies on operator attributes */ - std::vector op_attr_dependency; - /*! \brief dependencies on attributes in the graph */ - std::vector graph_attr_dependency; - /*! \brief generated targets of graph attributes */ - std::vector graph_attr_targets; - /*! - * \brief Set whether this pass will change graph structure. - * \param v If true, the pass will change graph structure. - * \return Reference to self. - */ - PassFunctionReg& set_change_graph(bool v) { // NOLINT(*) - change_graph = v; - return *this; - } - /*! - * \brief Declare that this pass will generate the given graph attribute name - * once it is applied on the graph. - * \param attr_name Name of the graph attribute. - * \return Reference to self. - */ - PassFunctionReg& provide_graph_attr(const std::string& attr_name) { // NOLINT(*) - graph_attr_targets.push_back(attr_name); - return *this; - } - /*! - * \brief Declare this pass requires the given operator attribute to be - * available before being applied on the graph. - * \param attr_name Name of the attribute. - * \return Reference to self. - */ - PassFunctionReg& depend_op_attr(const std::string& attr_name) { // NOLINT(*) - op_attr_dependency.push_back(attr_name); - return *this; - } - /*! - * \brief Declare this pass requires the given graph attribute to be - * available before being applied on the graph. - * \param attr_name Name of the attribute. - * \return Reference to self. - */ - PassFunctionReg& depend_graph_attr(const std::string& attr_name) { // NOLINT(*) - graph_attr_dependency.push_back(attr_name); - return *this; - } -}; - -/*! - * \def NNVM_REGISTER_PASS - * \brief Macro to register pass fuctions. - * - * \code - * // example of registering a shape inference pass - * NNVM_REGISTER_PASS(InferShape) - * .describe("Shape Inference function, generate graph attributes") - * .provide_graph_attr("data_shape") - * .depend_graph_attr("indexed_graph") - * .depend_op_attr("infer_shape") - * .set_body([](const Graph& g) { - * // shape inference logic - * }); - * \endcode - */ -#define NNVM_REGISTER_PASS(name) \ - DMLC_REGISTRY_REGISTER(::nnvm::PassFunctionReg, PassFunctionReg, name) - -} // namespace nnvm - -#endif // NNVM_PASS_H_ diff --git a/include/nnvm/pass_functions.h b/include/nnvm/pass_functions.h deleted file mode 100644 index 5a98dd456fb2..000000000000 --- a/include/nnvm/pass_functions.h +++ /dev/null @@ -1,190 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file nnvm/pass_functions.h - * \brief Pass functions that simply redirect the calls to ApplyPass - * - * This file serves as documentation on how to use functions implemented in "src/pass". - * It is totally optional to add these functions when you add a new pass, since - * ApplyPass can be directly called. - */ -#ifndef NNVM_PASS_FUNCTIONS_H_ -#define NNVM_PASS_FUNCTIONS_H_ - -#include -#include -#include -#include "base.h" -#include "pass.h" -#include "graph_attr_types.h" - -namespace nnvm { -namespace pass { - -/*! - * \brief Load a graph from JSON string, redirects to "LoadJSON" pass. - * \param json_str The json string. - * \return Loaded graph. - */ -inline Graph LoadJSON(const std::string& json_str) { - Graph ret; - ret.attrs["json"] = std::make_shared(json_str); - return ApplyPass(ret, "LoadJSON"); -} - -/*! - * \brief Save a graph to json, redirects to "SaveJSON" pass. - * \param graph The graph to be saved as json format. - * \return The json string. - */ -inline std::string SaveJSON(Graph graph) { - Graph ret = ApplyPass(std::move(graph), "SaveJSON"); - return ret.GetAttr("json"); -} - - -/*! - * \brief Print graph ir - * \param graph The graph to be printed - * \return The graph ir string. - */ -inline std::string PrintGraphIR(Graph graph) { - Graph ret = ApplyPass(std::move(graph), "PrintGraphIR"); - return ret.GetAttr("graphir"); -} - -/*! - * \brief Add control flow dependencies between nodes. - * - * This function will enforce the correct order between - * write (mutable operators) and read (immutable operators) - * to sovle write-after-read and read-after-write problems. - * - * \param src The input graph. - * \return A graph with proper control flow dependencies added. - */ -inline Graph OrderMutation(Graph src) { - return ApplyPass(std::move(src), "OrderMutation"); -} - -/*! - * \brief Infer shapes in the graph given the information. - * \param graph The input graph. - * \param shape_inputs The shapes of input symbols to the graph. - * \param shape_attr_key The key to the node attribute that can indicate shape. This is - * the place where manual hint for shapes could be injected. - * \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. - * The index of ShapeVector is given by graph.indexed_graph().entry_id. - */ -inline Graph InferShape(Graph graph, - ShapeVector shape_inputs, - std::string shape_attr_key = "") { - if (shape_inputs.size() != 0) { - graph.attrs["shape_inputs"] = std::make_shared(std::move(shape_inputs)); - } - if (shape_attr_key.length() != 0) { - graph.attrs["shape_attr_key"] = std::make_shared(std::move(shape_attr_key)); - } - return ApplyPass(std::move(graph), "InferShape"); -} - -/*! - * \brief Infer types in the graph given the information. - * \param graph The input graph. - * \param dtype_inputs The types of input symbols to the graph. - * \param dtype_attr_key The key to the node attribute that can indicate types. This is - * the place where manual hint for types could be injected. - * \return A graph with new attribute "dtype" containing inferred type of each NodeEntry. - * The index of ShapeVector is given by graph.indexed_graph().entry_id. - */ -inline Graph InferType(Graph graph, - DTypeVector dtype_inputs, - std::string dtype_attr_key = "") { - if (dtype_inputs.size() != 0) { - graph.attrs["dtype_inputs"] = std::make_shared(std::move(dtype_inputs)); - } - if (dtype_attr_key.length() != 0) { - graph.attrs["dtype_attr_key"] = std::make_shared(std::move(dtype_attr_key)); - } - return ApplyPass(std::move(graph), "InferType"); -} - -/*! - * \brief Place the devices for each operator in the graph. - * - * Current device placement is quite simple. Each operator is assigned to a "group" (stored - * in `device_group_attr_key` attribute). Each group is assigned to a device (stored in - * `device_assign_map` attribute). Operators will be placed to the device assigned to its - * group. Copy operators will be injected if cross device reference happens. - * - * \param graph The input graph. - * \param device_group_attr_key The attribute name for hints of device group. - * \param device_assign_map The assignment map of device. - * \param device_copy_op The name of copy op to be inserted when cross device copy happened. - * \return A graph with new attribute "device", cotaining device information of each node. - */ -inline Graph PlaceDevice(Graph graph, - std::string device_group_attr_key, - DeviceAssignMap device_assign_map, - std::string device_copy_op) { - graph.attrs["device_group_attr_key"] = std::make_shared(std::move(device_group_attr_key)); - graph.attrs["device_assign_map"] = std::make_shared(std::move(device_assign_map)); - graph.attrs["device_copy_op"] = std::make_shared(std::move(device_copy_op)); - return ApplyPass(std::move(graph), "PlaceDevice"); -} - -/*! - * \brief Get the gradient graph whose outputs are gradients of xs wrt to ys. - * \param graph The input graph. - * \param ys The entries we want to take gradient from. - * \param xs The input to take gradient with respect to. - * \param ys_out_grad The symbol for additional gradient to be propagate back to y. - * \param aggregate_fun Aggregation function applied to aggregate the inputs. - * \param mirror_fun Optional mirror function to do mirror optimization and save memory. - * \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same as like. - * \param zero_ops Optional, list of operators that outputs a single zero array. The first one - * must be zeros_like. - * \param copy_op_str Optional, name of the copy operation required to handle duplicates - * on the edge of the graph - * \return A new graph, whose outputs correspond to inputs of xs. - */ -inline Graph Gradient( - Graph graph, - std::vector ys, - std::vector xs, - std::vector ys_out_grad, - std::function&& inputs)> aggregate_fun = nullptr, - std::function mirror_fun = nullptr, - std::function - attr_hint_fun = nullptr, - std::vector zero_ops = std::vector(), - std::string copy_op_str = std::string()) { - graph.attrs["grad_ys"] = std::make_shared(std::move(ys)); - - graph.attrs["grad_xs"] = std::make_shared(std::move(xs)); - graph.attrs["grad_ys_out_grad"] = std::make_shared(std::move(ys_out_grad)); - if (aggregate_fun != nullptr) { - graph.attrs["grad_aggregate_fun"] = std::make_shared(aggregate_fun); - } - - if (mirror_fun != nullptr) { - graph.attrs["grad_mirror_fun"] = std::make_shared(mirror_fun); - } - - if (attr_hint_fun != nullptr) { - graph.attrs["attr_hint_fun"] = std::make_shared(attr_hint_fun); - } - - if (zero_ops.size()) { - graph.attrs["zero_ops"] = std::make_shared(std::move(zero_ops)); - } - - if (copy_op_str != std::string()) { - graph.attrs["copy_op"] = std::make_shared(std::move(copy_op_str)); - } - - return ApplyPass(std::move(graph), "Gradient"); -} - -} // namespace pass -} // namespace nnvm -#endif // NNVM_PASS_FUNCTIONS_H_ diff --git a/include/nnvm/symbolic.h b/include/nnvm/symbolic.h deleted file mode 100644 index 42cf5dd775c2..000000000000 --- a/include/nnvm/symbolic.h +++ /dev/null @@ -1,217 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file nnvm/symbolic.h - * \brief Symbolic graph construction API - * - * This API is optional, but useful to allow user - * to construct NNVM Graph easily, and quickly create - * front-end host languages. - */ -#ifndef NNVM_SYMBOLIC_H_ -#define NNVM_SYMBOLIC_H_ - -#include -#include -#include -#include - -#include "base.h" -#include "node.h" - -namespace nnvm { -/*! - * \brief Symbol is help class used to represent the operator node in Graph. - * - * Symbol acts as an interface for building graphs from different components - * like Variable, Functor and Group. Symbol is also exported to python front-end - * (while Graph is not) to enable quick test and deployment. Conceptually, - * symbol is the final operation of a graph and thus including all the information - * required (the graph) to evaluate its output value. - */ -class NNVM_DLL Symbol { - public: - /*! \brief option passed to ListAttr */ - enum ListAttrOption { - /*! \brief recursively list all attributes */ - kRecursive = 0, - /*! \brief only list attributes in current node */ - kShallow = 1 - }; - /*! \brief option passed to ListInputNames */ - enum ListInputOption { - /*! \brief list all the arguments */ - kAll = 0, - /*! \brief list only read only arguments */ - kReadOnlyArgs = 1, - /*! - * \brief List auxiliary states that can be mutated by the graph. - * This excludes the ReadOnly arguments - */ - kAuxiliaryStates = 2 - }; - - /*! \brief output entries contained in the symbol */ - std::vector outputs; - - /*! - * \brief Copy the symbol. - * \return A deep copy of this symbol. - */ - Symbol Copy() const; - /*! - * \brief Print the symbol info to output stream. - * \param os The output stream to print to. - */ - void Print(std::ostream &os) const; // NOLINT(*) - /*! - * \brief Get the index-th element from the returned tuple. - * \param index Index of multi output. - * \return The symbol corresponds to the indexed element. - */ - Symbol operator[] (size_t index) const; - /*! - * \brief List the input variable nodes. - * - * The order of the returned list is the same as the order of the input list to `operator()`. - * - * \param option The options to list the arguments. - * \return The arguments list of this symbol, they can be either named or unnamed (empty string). - * \sa ListInputOption - */ - std::vector ListInputs(ListInputOption option) const; - /*! - * \brief List the input names. - * - * The order of the returned list is the same as the order of the input list to `operator()`. - * - * \param option The options to list the arguments. - * \return The arguments list of this symbol, they can be either named or unnamed (empty string). - * \sa ListInputOption - */ - std::vector ListInputNames(ListInputOption option) const; - /*! - * \brief List the names of outputs for this symbol. - * - * For normal operators, it is usually symbol node name + "_output". - * - * \return get the descriptions of outputs for this symbol. - */ - std::vector ListOutputNames() const; - /*! - * \brief Compose the symbol with arguments, this changes the current symbol. - * The kwargs passed in can be in-complete, - * - * The rest of the symbols will remain the same name. - * - * \param args Positional arguments. - * \param kwargs Keyword arguments for the symbol. - * \param name Name of returned symbol. - */ - void Compose(const array_view& args, - const std::unordered_map& kwargs, - const std::string& name); - /*! - * \brief Apply the symbol as a function, compose with arguments - * - * This is equivalent to Copy then Compose. - * - * \param args Positional arguments for the symbol. - * \param kwargs Keyword arguments for the symbol. - * \param name Name of returned symbol. - * \return A new Symbol which is the composition of current symbol with its arguments. - */ - Symbol operator () (const array_view& args, - const std::unordered_map& kwargs, - const std::string& name) const; - /*! - * \brief Add control flow dependencies to the operators in symbols. - * - * For grouped symbol, an error will be raised. This mutates current symbolic Node. - * - * \param src The symbols to depend on. - */ - void AddControlDeps(const Symbol& src); - /* - * \brief Get all the internal nodes of the symbol. - * \return symbol A new symbol whose output contains all the outputs of the symbols - * including input variables and intermediate outputs. - */ - Symbol GetInternals() const; - /* - * \brief Get the direct inputs of the head node(s) of this symbol. - * \return symbol A new symbol whose output contains all the inputs of the head - * node(s). - */ - Symbol GetChildren() const; - /*! - * \brief Set additional attributes to current node. - * - * This only works for symbol with outputs from single operators. - * For grouped symbol, an error will be raised. - * - * This function mutates the node's symbol and is not recommended. - * - * \param attrs The attributes to set. - */ - void SetAttrs(const std::vector >& attrs); - /*! - * \brief Get attributes from the symbol. - * - * This only works for symbol with outputs from single operators. - * For grouped symbol, an error will be raised. - * - * \param key Key of the attribute. When key == "name", it returns the name attirbute. - * \param out The output value of the attribute. - * \return true If the attribute exists, false if the attribute does not exist. - */ - bool GetAttr(const std::string& key, std::string* out) const; - /*! - * \brief Get attribute dictionary from the symbol. - * - * For grouped symbol, an error will be raised. - * - * \param option If recursive flag is set, the attributes of all children are retrieved. - * The name of symbol will be pre-pended to each key. - * \return The created attribute. - */ - std::unordered_map ListAttrs(ListAttrOption option) const; - /*! - * \brief Get attribute dictionary from the symbol and all children. - * - * For grouped symbol, an error will be raised. - * - * \return The created attribute in format . - */ - std::vector > - ListAttrsRecursive() const; - /*! - * \brief Create symbolic functor(AtomicSymbol) by given operator and attributes. - * \param op The operator. - * \param attrs The additional attributes. - * \return Symbol that can be used to call compose further. - */ - static Symbol CreateFunctor(const Op* op, - std::unordered_map attrs); - /*! - * \brief Create symbolic functor(AtomicSymbol) by given node attributes. - * \param attrs pre-initialized Node attributes. - * \return Symbol that can be used to call compose further. - */ - static Symbol CreateFunctor(const NodeAttrs& attrs); - /*! - * \brief Create symbol node representing variable. - * \param name Name of the variable. - * \return The symbol. - */ - static Symbol CreateVariable(const std::string& name); - /*! - * \brief Create equivalence of symbol by grouping the symbols together. - * \param symbols A list of symbols to be grouped. - * \return The grouped symbol. - */ - static Symbol CreateGroup(const std::vector& symbols); -}; - -} // namespace nnvm - -#endif // NNVM_SYMBOLIC_H_ diff --git a/include/nnvm/top/README b/include/nnvm/top/README deleted file mode 100644 index 09a4d6fc387f..000000000000 --- a/include/nnvm/top/README +++ /dev/null @@ -1 +0,0 @@ -NNVM Core Operator and Compiler diff --git a/include/nnvm/top/nn.h b/include/nnvm/top/nn.h deleted file mode 100644 index 143a9548f18a..000000000000 --- a/include/nnvm/top/nn.h +++ /dev/null @@ -1,498 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file nnvm/top/nn.h - * \brief Auxiliary param for tensor primitive. - */ -#ifndef NNVM_TOP_NN_H_ -#define NNVM_TOP_NN_H_ - -#include -#include -#include -#include -#include -#include "tensor.h" - -namespace nnvm { -namespace top { - -struct DenseParam : public dmlc::Parameter { - int units; - bool use_bias; - - DMLC_DECLARE_PARAMETER(DenseParam) { - DMLC_DECLARE_FIELD(units).set_lower_bound(1) - .describe("Number of hidden units of the dense transformation."); - DMLC_DECLARE_FIELD(use_bias).set_default(true) - .describe("Whether to use bias parameter"); - } - // constants - static const constexpr int kData = 0; - static const constexpr int kWeight = 1; - static const constexpr int kBias = 2; -}; - -struct DropoutParam : public dmlc::Parameter { - float rate; - - DMLC_DECLARE_PARAMETER(DropoutParam) { - DMLC_DECLARE_FIELD(rate).set_default(0.5) - .set_range(0, 1) - .describe("Fraction of the input that gets dropped out during training time."); - } -}; - -struct BatchNormParam : public dmlc::Parameter { - int axis; - double epsilon; - double momentum; - bool center; - bool scale; - - DMLC_DECLARE_PARAMETER(BatchNormParam) { - DMLC_DECLARE_FIELD(axis).set_default(1) - .describe("Specify which shape axis the channel is specified."); - DMLC_DECLARE_FIELD(epsilon).set_default(1e-5) - .describe("Small float added to variance to avoid dividing by zero."); - DMLC_DECLARE_FIELD(center).set_default(true) - .describe("If True, add offset of `beta` to normalized tensor." - "If False, `beta` is ignored."); - DMLC_DECLARE_FIELD(scale).set_default(true) - .describe("If True, multiply by `gamma`. If False, `gamma` is not used." - "When the next layer is piecewise linear (also e.g. `nn.relu`)," - "this can be disabled since the scaling" - "will be done by the next layer."); - } - // constants - static const constexpr int kData = 0; - static const constexpr int kGamma = 1; - static const constexpr int kBeta = 2; - static const constexpr int kMovingMean = 3; - static const constexpr int kMovingVariance = 4; -}; - - -// Shared by softmax and log_softmax -struct SoftmaxParam : public dmlc::Parameter { - int axis; - - DMLC_DECLARE_PARAMETER(SoftmaxParam) { - DMLC_DECLARE_FIELD(axis).set_default(-1) - .describe("The axis to sum over when computing softmax."); - } -}; - -struct LeakyReLUParam : public dmlc::Parameter { - double alpha; - - DMLC_DECLARE_PARAMETER(LeakyReLUParam) { - DMLC_DECLARE_FIELD(alpha).set_lower_bound(0.0).set_default(0.25) - .describe("slope coefficient for the negative half axis."); - } -}; - -struct PReLUParam : public dmlc::Parameter { - int axis; - DMLC_DECLARE_PARAMETER(PReLUParam) { - DMLC_DECLARE_FIELD(axis).set_default(1) - .describe("Specify which shape axis the channel is specified."); - } -}; - -struct PadParam : public dmlc::Parameter { - float pad_value; - Tuple > pad_width; - - DMLC_DECLARE_PARAMETER(PadParam) { - DMLC_DECLARE_FIELD(pad_value).set_default(0.0) - .describe("The value to be padded."); - DMLC_DECLARE_FIELD(pad_width) - .describe("Number of values padded to the edges of each axis, " - "in the format of ((before_1, after_1), ... (before_N, after_N))"); - } -}; - - -struct Conv2DParam : public dmlc::Parameter { - int channels; - TShape kernel_size; - TShape strides; - TShape padding; - TShape dilation; - int groups; - std::string layout; - std::string kernel_layout; - std::string out_layout; - int out_dtype; - bool use_bias; - - DMLC_DECLARE_PARAMETER(Conv2DParam) { - DMLC_DECLARE_FIELD(channels) - .describe("The dimensionality of the output space" - "i.e. the number of output channels in the convolution."); - DMLC_DECLARE_FIELD(kernel_size) - .describe("Specifies the dimensions of the convolution window."); - DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) - .describe("Specifies the strides of the convolution."); - DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points"); - DMLC_DECLARE_FIELD(dilation).set_default(TShape({1, 1})) - .describe("Specifies the dilation rate to use for dilated convolution."); - DMLC_DECLARE_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); - DMLC_DECLARE_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - DMLC_DECLARE_FIELD(out_layout).set_default("__undef__") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); - DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - DMLC_DECLARE_DTYPE_FIELD(out_dtype) - .add_enum("same", -1) - .set_default(-1) - .describe("Output data type, set to explicit type under mixed precision setting"); - - DMLC_DECLARE_FIELD(use_bias).set_default(true) - .describe("Whether the layer uses a bias vector."); - } - // constants - static const constexpr int kData = 0; - static const constexpr int kWeight = 1; - static const constexpr int kBias = 2; -}; - -struct WinogradWeightTransformParam : public dmlc::Parameter { - int tile_size; - - DMLC_DECLARE_PARAMETER(WinogradWeightTransformParam) { - DMLC_DECLARE_FIELD(tile_size) - .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); - } - - static const constexpr int kWeight = 0; -}; - -struct WinogradConv2DParam : public dmlc::Parameter { - int channels; - TShape kernel_size; - TShape strides; - TShape padding; - TShape dilation; - int groups; - std::string layout; - std::string kernel_layout; - std::string out_layout; - int out_dtype; - bool use_bias; - int tile_size; - - DMLC_DECLARE_PARAMETER(WinogradConv2DParam) { - DMLC_DECLARE_FIELD(channels) - .describe("The dimensionality of the output space" - "i.e. the number of output channels in the convolution."); - DMLC_DECLARE_FIELD(kernel_size) - .describe("Specifies the dimensions of the convolution window."); - DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) - .describe("Specifies the strides of the convolution."); - DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points"); - DMLC_DECLARE_FIELD(dilation).set_default(TShape({1, 1})) - .describe("Specifies the dilation rate to use for dilated convolution."); - DMLC_DECLARE_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); - DMLC_DECLARE_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - DMLC_DECLARE_FIELD(out_layout).set_default("__undef__") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); - DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - DMLC_DECLARE_DTYPE_FIELD(out_dtype) - .add_enum("same", -1) - .set_default(-1) - .describe("Output data type, set to explicit type under mixed precision setting"); - DMLC_DECLARE_FIELD(use_bias).set_default(true) - .describe("Whether the layer uses a bias vector."); - DMLC_DECLARE_FIELD(tile_size) - .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); - } - // constants - static const constexpr int kData = 0; - static const constexpr int kWeight = 1; - static const constexpr int kBias = 2; -}; - -struct Conv2DTransposeParam : public dmlc::Parameter { - int channels; - TShape kernel_size; - TShape strides; - TShape padding; - TShape output_padding; - TShape dilation; - int groups; - std::string layout; - std::string kernel_layout; - int out_dtype; - bool use_bias; - - DMLC_DECLARE_PARAMETER(Conv2DTransposeParam) { - DMLC_DECLARE_FIELD(channels) - .describe("The dimensionality of the output space" - "i.e. the number of output channels in the convolution."); - DMLC_DECLARE_FIELD(kernel_size) - .describe("Specifies the dimensions of the convolution window."); - DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) - .describe("Specifies the strides of the convolution."); - DMLC_DECLARE_FIELD(output_padding).set_default(TShape({0, 0})) - .describe("Zero-padding added to one side of the output."); - DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points"); - DMLC_DECLARE_FIELD(dilation).set_default(TShape({1, 1})) - .describe("Specifies the dilation rate to use for dilated convolution."); - DMLC_DECLARE_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); - DMLC_DECLARE_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - DMLC_DECLARE_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - DMLC_DECLARE_DTYPE_FIELD(out_dtype) - .add_enum("same", -1) - .set_default(-1) - .describe("Output data type, set to explicit type under mixed precision setting"); - DMLC_DECLARE_FIELD(use_bias).set_default(true) - .describe("Whether the layer uses a bias vector."); - } - // constants - static const constexpr int kData = 0; - static const constexpr int kWeight = 1; - static const constexpr int kBias = 2; -}; - - -struct MaxPool2DParam : public dmlc::Parameter { - TShape pool_size; - TShape strides; - TShape padding; - std::string layout; - bool ceil_mode; - - DMLC_DECLARE_PARAMETER(MaxPool2DParam) { - DMLC_DECLARE_FIELD(pool_size) - .describe("Size of the pooling windows.."); - DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) - .describe("Specifies the strides of the convolution."); - DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - DMLC_DECLARE_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - DMLC_DECLARE_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); - } -}; - - -struct AvgPool2DParam : public dmlc::Parameter { - TShape pool_size; - TShape strides; - TShape padding; - std::string layout; - bool ceil_mode; - bool count_include_pad; - - DMLC_DECLARE_PARAMETER(AvgPool2DParam) { - DMLC_DECLARE_FIELD(pool_size) - .describe("Size of the pooling windows.."); - DMLC_DECLARE_FIELD(strides).set_default(TShape({1, 1})) - .describe("Specifies the strides of the convolution."); - DMLC_DECLARE_FIELD(padding).set_default(TShape({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - DMLC_DECLARE_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - DMLC_DECLARE_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); - DMLC_DECLARE_FIELD(count_include_pad).set_default(false) - .describe("When true, will include padding to compute the average"); - } -}; - - -struct GlobalPool2DParam : public dmlc::Parameter { - std::string layout; - - DMLC_DECLARE_PARAMETER(GlobalPool2DParam) { - DMLC_DECLARE_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - } -}; - -struct UpSamplingParam : public dmlc::Parameter { - int scale; - std::string layout; - std::string method; - - DMLC_DECLARE_PARAMETER(UpSamplingParam) { - DMLC_DECLARE_FIELD(scale) - .describe("upsampling scaling factor"); - DMLC_DECLARE_FIELD(layout) - .set_default("NCHW") - .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Upsampling is applied on the 'H' and" - "'W' dimensions."); - DMLC_DECLARE_FIELD(method) - .set_default("NEAREST_NEIGHBOR") - .describe("Specify the mode to use for scaling." - "NEAREST_NEIGHBOR - Nearest Neighbor" - "BILINEAR - Bilinear Interpolation"); - } -}; - -struct LayoutTransformParam : public dmlc::Parameter { - std::string src_layout; - std::string dst_layout; - - DMLC_DECLARE_PARAMETER(LayoutTransformParam) { - DMLC_DECLARE_FIELD(src_layout).set_default("__undef__") - .describe("Dimension ordering of data"); - DMLC_DECLARE_FIELD(dst_layout).set_default("__undef__") - .describe("Dimension ordering of data."); - } -}; - -struct MultiBoxPriorParam : public dmlc::Parameter { - Tuple sizes; - Tuple ratios; - Tuple steps; - Tuple offsets; - bool clip; - - DMLC_DECLARE_PARAMETER(MultiBoxPriorParam) { - DMLC_DECLARE_FIELD(sizes).set_default(Tuple({1.0})) - .describe("List of sizes of generated MultiBoxPriores."); - DMLC_DECLARE_FIELD(ratios).set_default(Tuple({1.0})) - .describe("List of aspect ratios of generated MultiBoxPriores."); - DMLC_DECLARE_FIELD(steps).set_default(Tuple({-1.0, -1.0})) - .describe("Priorbox step across y and x, -1 for auto calculation."); - DMLC_DECLARE_FIELD(offsets).set_default(Tuple({0.5, 0.5})) - .describe("Priorbox center offsets, y and x respectively."); - DMLC_DECLARE_FIELD(clip).set_default(false) - .describe("Whether to clip out-of-boundary boxes."); - } -}; - -struct MultiBoxTransformLocParam : public dmlc::Parameter { - bool clip; - float threshold; - Tuple variances; - DMLC_DECLARE_PARAMETER(MultiBoxTransformLocParam) { - DMLC_DECLARE_FIELD(clip).set_default(true) - .describe("Clip out-of-boundary boxes."); - DMLC_DECLARE_FIELD(threshold).set_default(0.01) - .describe("Threshold to be a positive prediction."); - DMLC_DECLARE_FIELD(variances).set_default(Tuple({0.1f, 0.1f, 0.2f, 0.2f})) - .describe("Variances to be decoded from box regression output."); - } -}; - -struct NMSParam : public dmlc::Parameter { - float nms_threshold; - bool force_suppress; - int nms_topk; - DMLC_DECLARE_PARAMETER(NMSParam) { - DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5) - .describe("Non-maximum suppression threshold."); - DMLC_DECLARE_FIELD(force_suppress).set_default(false) - .describe("Suppress all detections regardless of class_id."); - DMLC_DECLARE_FIELD(nms_topk).set_default(-1) - .describe("Keep maximum top k detections before nms, -1 for no limit."); - } -}; - -struct LRNParam : public dmlc::Parameter { - int size; - int axis; - float alpha; - float beta; - float bias; - - DMLC_DECLARE_PARAMETER(LRNParam) { - DMLC_DECLARE_FIELD(size) - .describe("The size of the local region to be considered for normalization."); - DMLC_DECLARE_FIELD(axis) - .describe("input data layout channel axis"); - DMLC_DECLARE_FIELD(alpha) - .describe("The scaling parameter."); - DMLC_DECLARE_FIELD(beta) - .describe("The exponent parameter."); - DMLC_DECLARE_FIELD(bias) - .describe("The offset parameter."); - } - // constants - static const constexpr int kData = 0; -}; - -struct L2NormalizeParam : public dmlc::Parameter { - float eps; - Tuple axis; - - DMLC_DECLARE_PARAMETER(L2NormalizeParam) { - DMLC_DECLARE_FIELD(eps) - .describe("float type epsilon value."); - DMLC_DECLARE_FIELD(axis) - .describe("axis over the normalization applied"); - } -}; - -} // namespace top -} // namespace nnvm - -#endif // NNVM_TOP_NN_H_ diff --git a/include/nnvm/top/tensor.h b/include/nnvm/top/tensor.h deleted file mode 100644 index 53ed5b3b0a22..000000000000 --- a/include/nnvm/top/tensor.h +++ /dev/null @@ -1,301 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file nnvm/top/tensor.h - * \brief Auxiliary param for tensor primitive. - */ -#ifndef NNVM_TOP_TENSOR_H_ -#define NNVM_TOP_TENSOR_H_ - -#include -#include -#include - -namespace nnvm { -namespace top { - -struct ConcatenateParam : public dmlc::Parameter { - int axis; - DMLC_DECLARE_PARAMETER(ConcatenateParam) { - DMLC_DECLARE_FIELD(axis).set_default(1) - .describe("the axis to be concated."); - } -}; - -struct ExpandDimsParam : public dmlc::Parameter { - int axis; - int num_newaxis; - DMLC_DECLARE_PARAMETER(ExpandDimsParam) { - DMLC_DECLARE_FIELD(axis) - .describe("the axis to be expanded."); - DMLC_DECLARE_FIELD(num_newaxis).set_lower_bound(1).set_default(1) - .describe("Number of new axis to be inserted."); - } -}; - -struct SplitParam : public dmlc::Parameter { - // numpy convention, only support indices, not support list. - Tuple indices_or_sections; - int axis; - // additional hint whether it is equal_split mode - // deduced from indices_or_sections - bool equal_split; - - DMLC_DECLARE_PARAMETER(SplitParam) { - DMLC_DECLARE_FIELD(indices_or_sections) - .describe("Number of outputs to be splitted"); - DMLC_DECLARE_FIELD(axis).set_lower_bound(0).set_default(1) - .describe("the axis to be splitted."); - } -}; - - -struct TakeParam : public dmlc::Parameter { - dmlc::optional axis; - - DMLC_DECLARE_PARAMETER(TakeParam) { - DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional()) - .describe("the axis over which to select values."); - } -}; - -struct StridedSliceParam : public dmlc::Parameter { - // numpy convention, only support indices, not support list. - Tuple begin; - Tuple end; - Tuple stride; - - DMLC_DECLARE_PARAMETER(StridedSliceParam) { - DMLC_DECLARE_FIELD(begin) - .describe("Indices for begin of slice"); - DMLC_DECLARE_FIELD(end) - .describe("Indices for end of the slice"); - DMLC_DECLARE_FIELD(stride).set_default(Tuple()) - .describe("Stride values of the slice"); - } -}; - -enum TypeFlag { - kFloat32 = 0, - kFloat64 = 1, - kFloat16 = 2, - kUint8 = 3, - kInt32 = 4, - kInt8 = 5, - kInt64 = 6, - kInt16 = 7, - kUint16 = 8, - kUint32 = 9, - kUint64 = 10, -}; - -enum IndicatorRuleFlag { - kGT0 = 0, - kLT0 = 1, - kMax = 2, - kMin = 3, -}; - -#define DMLC_DECLARE_DTYPE_FIELD(name) \ - DMLC_DECLARE_FIELD(name) \ - .add_enum("float16", kFloat16) \ - .add_enum("float32", kFloat32) \ - .add_enum("float64", kFloat64) \ - .add_enum("uint8", kUint8) \ - .add_enum("uint16", kUint16) \ - .add_enum("uint32", kUint32) \ - .add_enum("uint64", kUint64) \ - .add_enum("int8", kInt8) \ - .add_enum("int16", kInt16) \ - .add_enum("int32", kInt32) \ - .add_enum("int64", kInt64) - -struct CastParam : public dmlc::Parameter { - int dtype; - DMLC_DECLARE_PARAMETER(CastParam) { - DMLC_DECLARE_DTYPE_FIELD(dtype) - .describe("Output data type."); - } -}; - -struct IndicatorParam : public dmlc::Parameter { - TShape axis; - bool exclude; - DMLC_DECLARE_PARAMETER(IndicatorParam) { - DMLC_DECLARE_FIELD(axis).set_default(TShape()) - .describe(R"code(The axis or axes along which to perform the indicator rule. - - The default, `axis=()`, will compute over all elements into a - scalar array with shape `(1,)`. - - If `axis` is int, rule is applied on a particular axis. - - If `axis` is a tuple of ints, rule is applied on all the axes - specified in the tuple. - - If `exclude` is true, rule will be applied on the axes that are - NOT in axis instead.)code"); - DMLC_DECLARE_FIELD(exclude).set_default(false) - .describe("Whether to apply rule on axis that are NOT in axis instead."); - } -}; - -struct ReshapeParam : public dmlc::Parameter { - Tuple shape; - - DMLC_DECLARE_PARAMETER(ReshapeParam) { - DMLC_DECLARE_FIELD(shape); - } -}; - -struct SqueezeParam : public dmlc::Parameter { - TShape axis; - - DMLC_DECLARE_PARAMETER(SqueezeParam) { - DMLC_DECLARE_FIELD(axis).set_default(TShape()) - .describe("The axis to squeeze in the input tensor."); - } -}; - -struct ScalarParam : public dmlc::Parameter { - double scalar; - - DMLC_DECLARE_PARAMETER(ScalarParam) { - DMLC_DECLARE_FIELD(scalar); - } -}; - -struct FillValueParam : public dmlc::Parameter { - double fill_value; - - DMLC_DECLARE_PARAMETER(FillValueParam) { - DMLC_DECLARE_FIELD(fill_value) - .describe("Scalar value to be filled"); - } -}; - -struct TransposeParam : public dmlc::Parameter { - TShape axes; - - DMLC_DECLARE_PARAMETER(TransposeParam) { - DMLC_DECLARE_FIELD(axes).set_default(TShape()) - .describe("Target axis order. By default the axes will be inverted."); - } -}; - -struct FlipParam : public dmlc::Parameter { - int axis; - DMLC_DECLARE_PARAMETER(FlipParam) { - DMLC_DECLARE_FIELD(axis).set_default(0) - .describe("the axis to be reveresed."); - } -}; - -struct BroadcastToParam : public dmlc::Parameter { - TShape shape; - - DMLC_DECLARE_PARAMETER(BroadcastToParam) { - DMLC_DECLARE_FIELD(shape).set_default(TShape()) - .describe("The shape of the desired array." - " We can set the dim to zero if it's same as the original." - " E.g `A = broadcast_to(B, shape=(10, 0, 0))` "); - } -}; - -struct ReduceParam : public dmlc::Parameter { - TShape axis; - bool keepdims; - bool exclude; - - DMLC_DECLARE_PARAMETER(ReduceParam) { - DMLC_DECLARE_FIELD(axis).set_default(TShape()) - .describe(R"code(The axis or axes along which to perform the reduction. - - The default, `axis=()`, will compute over all elements into a - scalar array with shape `(1,)`. - - If `axis` is int, a reduction is performed on a particular axis. - - If `axis` is a tuple of ints, a reduction is performed on all the axes - specified in the tuple. - - If `exclude` is true, reduction will be performed on the axes that are - NOT in axis instead.)code"); - - DMLC_DECLARE_FIELD(keepdims).set_default(false) - .describe("If this is set to `True`, the reduced axes are left " - "in the result as dimension with size one."); - DMLC_DECLARE_FIELD(exclude).set_default(false) - .describe("Whether to perform reduction on axis that are NOT in axis instead."); - } -}; - -struct InitOpWithScalarParam : public dmlc::Parameter { - TShape shape; - int dtype; - double fill_value; - - DMLC_DECLARE_PARAMETER(InitOpWithScalarParam) { - DMLC_DECLARE_FIELD(shape).set_default(TShape()); - DMLC_DECLARE_DTYPE_FIELD(dtype).set_default(kFloat32) - .describe("Target data type."); - DMLC_DECLARE_FIELD(fill_value).describe("Scalar value to fill"); - } -}; - -struct InitOpParam : public dmlc::Parameter { - TShape shape; - int dtype; - - DMLC_DECLARE_PARAMETER(InitOpParam) { - DMLC_DECLARE_FIELD(shape).set_default(TShape()); - DMLC_DECLARE_DTYPE_FIELD(dtype).set_default(kFloat32) - .describe("Target data type."); - } -}; - -struct ElementWiseReduceParam : public dmlc::Parameter { - int num_args; - DMLC_DECLARE_PARAMETER(ElementWiseReduceParam) { - DMLC_DECLARE_FIELD(num_args).set_lower_bound(1) - .describe("Number of inputs to be reduced."); - } -}; - -struct MatMulParam : public dmlc::Parameter { - bool transpose_a; - bool transpose_b; - - DMLC_DECLARE_PARAMETER(MatMulParam) { - DMLC_DECLARE_FIELD(transpose_a) - .describe("If true then transpose the first input before dot.") - .set_default(false); - DMLC_DECLARE_FIELD(transpose_b) - .describe("If true then transpose the second input before dot.") - .set_default(false); - } -}; - -struct ClipParam : public dmlc::Parameter { - double a_min, a_max; - DMLC_DECLARE_PARAMETER(ClipParam) { - DMLC_DECLARE_FIELD(a_min) - .describe("Minimum value such that value smaller then this will be clipped."); - DMLC_DECLARE_FIELD(a_max) - .describe("Maximum value such that value larger then this will be clipped."); - } -}; - -struct SliceLikeParam : public dmlc::Parameter { - Tuple axis; - DMLC_DECLARE_PARAMETER(SliceLikeParam) { - DMLC_DECLARE_FIELD(axis).set_default(Tuple()) - .describe("List of axes on which input data will be sliced according to the " - "corresponding size of the second input. By default will slice " - "on all axes. Negative axes are supported."); - } -}; - -} // namespace top -} // namespace nnvm - -#endif // NNVM_TOP_TENSOR_H_ diff --git a/include/nnvm/tuple.h b/include/nnvm/tuple.h deleted file mode 100644 index 36b8ef13c74a..000000000000 --- a/include/nnvm/tuple.h +++ /dev/null @@ -1,633 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file nnvm/tuple.h - * \brief Data structure Tuple and TShape to store dynamic sized shapes. - */ -#ifndef NNVM_TUPLE_H_ -#define NNVM_TUPLE_H_ - -#include -#include -#include -#include -#include -#include -#include "base.h" - -namespace nnvm { - -/*! \brief data type to store dim size */ -typedef int64_t dim_t; - -/*! - * \brief A dynamic sized array data structure that is optimized for storing - * small number of elements with same type. - * - * Data will be stored in stack when number of elements is small. - * It is suitable to hold shape of Tensor. - * - * \tparam ValueType The type of data stored inside tuple. - * \sa TShape - */ -template -class Tuple { - public: - /*! \brief default constructor */ - Tuple() = default; - /*! \brief destructor */ - inline ~Tuple() { - delete [] data_heap_; - } - /*! - * \brief copy constructor from another tuple - * \param s the source tuple - */ - inline Tuple(const Tuple& s) { - this->assign(s.begin(), s.end()); - } - /*! - * \brief constructor from initializer list - * \param init the initializer_list - */ - inline Tuple(std::initializer_list init) { - this->assign(init.begin(), init.end()); - } - /*! - * \brief constructor from vector - * \param init the vector - */ - inline Tuple(std::vector init) { // NOLINT(runtime/explicit) - this->assign(init.begin(), init.end()); - } - /*! - * \brief move constructor from Tuple - * \param src the source shape - */ - - inline Tuple(Tuple&& src) { // NOLINT(runtime/explicit) - this->swap(src); - } - /*! - * \brief construct the Tuple from content of iterator - * \param begin the beginning of iterator - * \param end end the end of the iterator - * \tparam RandomAccessIterator iterator type - */ - template - inline Tuple(RandomAccessIterator begin, - RandomAccessIterator end) { - this->assign(begin, end); - } - /*! - * \brief Assign content to tuple from iterator. - * \param begin the beginning of iterator - * \param end end the end of the iterator - * \tparam RandomAccessIterator iterator type - */ - template - inline void assign(RandomAccessIterator begin, - RandomAccessIterator end) { - this->SetDim(end - begin); - std::copy(begin, end, this->begin()); - } - /*! - * \brief Swap current object with other - * \param other another object to be swapped. - */ - inline void swap(Tuple& other) { // NOLINT(*) - std::swap(ndim_, other.ndim_); - std::swap(num_heap_allocated_, other.num_heap_allocated_); - std::swap(data_stack_, other.data_stack_); - std::swap(data_heap_, other.data_heap_); - } - /*! - * \brief assignment from another tuple. - * \param src source tuple - * \return reference of self - */ - inline Tuple& operator=(const Tuple& src) { - this->assign(src.begin(), src.end()); - return *this; - } - /*! - * \brief assignment from rvalue of another tuple. - * \param src source tuple - * \return reference of self - */ - inline Tuple& operator=(Tuple&& src) { - Tuple(std::move(src)).swap(*this); - return *this; - } - /*! - * \brief assignment from initializer list - * \param init the source initializer list - * \return reference of self - */ - inline Tuple &operator=(std::initializer_list init) { - this->assign(init.begin(), init.end()); - return *this; - } - /*! - * \return whether two tuple equals - * \param s the tuple to compare against - */ - inline bool operator==(const Tuple &s) const { - if (ndim_ != s.ndim_) return false; - return std::equal(begin(), end(), s.begin()); - } - /*! - * \return whether two tuple not equal - * \param s the tuple to compare against - */ - inline bool operator!=(const Tuple &s) const { - return !(*this == s); - } - /*! \return the begin data pointer to content of the tuple */ - inline const ValueType *begin() const { - return ndim_ <= kStackCache ? data_stack_ : data_heap_; - } - /*! \return the begin data pointer to content of the tuple */ - inline ValueType *begin() { - return ndim_ <= kStackCache ? data_stack_ : data_heap_; - } - /*! \return the data pointer to end of the tuple */ - inline const ValueType* end() const { - return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_); - } - /*! \return the data pointer to end the tuple */ - inline ValueType* end() { - return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_); - } - /*! \return number of dimension of the tuple */ - inline uint32_t ndim() const { - return ndim_; - } - /*! - * \brief get corresponding index - * \param i dimension index - * \return the corresponding dimension size - */ - inline ValueType& operator[](size_t i) { - return begin()[i]; - } - /*! - * \brief get corresponding index - * \param i dimension index - * \return the corresponding dimension size - */ - inline const ValueType& operator[](size_t i) const { - return begin()[i]; - } - /*! - * \brief Save Tuple to JSON. - * \param writer JSONWriter - */ - inline void Save(dmlc::JSONWriter* writer) const { - std::vector tmp(begin(), end()); - writer->Write(tmp); - } - /*! - * \brief Load Tuple from JSON. - * \param reader JSONReader - */ - inline void Load(dmlc::JSONReader* reader) { - std::vector tmp; - reader->Read(&tmp); - this->assign(tmp.begin(), tmp.end()); - } - /*! - * \brief allow output string of tuple to ostream - * \param os the output stream - * \param t the tuple - * \return the ostream - */ - friend std::ostream &operator<<(std::ostream &os, const Tuple &t) { - os << '['; - const ValueType* begin = t.begin(); - const ValueType* end = t.end(); - for (const ValueType* it = begin; it != end; ++it) { - if (it != begin) os << ','; - os << *it; - } - os << ']'; - return os; - } - /*! - * \brief read tuple from the istream - * \param is the input stream - * \param t The tuple - * \return the istream - */ - friend std::istream &operator>>(std::istream &is, Tuple &t) { - // get ( - while (true) { - char ch = is.peek(); - if (isdigit(ch) || ch == '-') { - ValueType idx; - if (is >> idx) { - t.assign(&idx, &idx + 1); - } - return is; - } - is.get(); - if (ch == '(' || ch == '[') break; - if (!isspace(ch)) { - is.setstate(std::ios::failbit); - return is; - } - } - // Handle empty tuple - while (isspace(is.peek())) { - is.get(); - } - if (is.peek() == ')' || is.peek() == ']') { - is.get(); - return is; - } - // Handle non-empty tuple - ValueType idx; - std::vector tmp; - while (is >> idx) { - tmp.push_back(idx); - char ch; - do { - ch = is.get(); - } while (isspace(ch)); - if (std::is_integral::value && ch == 'L') { - ch = is.get(); - } - if (ch == ',') { - while (true) { - ch = is.peek(); - if (isspace(ch)) { - is.get(); continue; - } - if (ch == ')' || ch == ']') { - is.get(); break; - } - break; - } - if (ch == ')' || ch == ']') break; - } else if (ch == ')' || ch == ']') { - break; - } else { - is.setstate(std::ios::failbit); - return is; - } - } - t.assign(tmp.begin(), tmp.end()); - return is; - } - /*! - * \brief save the content into binary stream - * \param strm the output stream - * \tparam DType data type that save to - * \tparam TStream any stream type that have write - */ - template - inline void Save(TStream *strm) const; - /*! - * \brief load the content from binary stream - * \param strm the output stream - * \tparam DType data type that load from - * \tparam TStream any stream type that have write - * \return whether the load is successful - */ - template - inline bool Load(TStream *strm); - - protected: - // stack cache size - static const uint32_t kStackCache = 4; - /*! \brief number of dimension of the tuple */ - uint32_t ndim_{0}; - /*! \brief number of cells allocated in data_heap_ */ - uint32_t num_heap_allocated_{0}; - /*! \brief in stack space used to store shape when it is small */ - ValueType data_stack_[kStackCache]; - /*! \brief space to store shape when dimension is big*/ - ValueType* data_heap_{nullptr}; - // internal function to change the dimension - inline void SetDim(uint32_t ndim) { - if (ndim > kStackCache && - ndim > num_heap_allocated_) { - delete [] data_heap_; - data_heap_ = new ValueType[ndim]; - num_heap_allocated_ = ndim; - } - ndim_ = ndim; - } -}; - -/*! - * \brief A Shape class that is used to represent shape of each tensor. - */ -class TShape : public Tuple { - public: - /*! \brief default constructor */ - TShape() = default; - /*! - * constructor to construct a shape with all 1. - * \param ndim the number of dimension - */ - inline TShape(uint32_t ndim) { // NOLINT(*) - this->SetDim(ndim); - std::fill_n(begin(), ndim, 1); - } - /*! - * \brief copy constructor of TShape - * \param s source shape. - */ - inline TShape(const Tuple& s) { // NOLINT(*) - this->assign(s.begin(), s.end()); - } - /*! - * \brief constructor from initializer list - * \param init the initializer_list - */ - inline TShape(std::initializer_list init) { - this->assign(init.begin(), init.end()); - } - /*! - * \brief move constructor. - * \param s source shape. - */ - inline TShape(Tuple&& s) { // NOLINT(*) - this->swap(s); - } - /*! - * \brief construct the Tuple from content of iterator - * \param begin the beginning of iterator - * \param end end the end of the iterator - * \tparam RandomAccessIterator iterator type - */ - template - inline TShape(RandomAccessIterator begin, - RandomAccessIterator end) { - this->assign(begin, end); - } - /*! - * \brief assignment function from tshape - * \param src source shape. - * \return self. - */ - inline TShape& operator=(const Tuple& src) { - this->assign(src.begin(), src.end()); - return *this; - } - /*! - * \brief move assignment function from tshape - * \param src source shape. - * \return self. - */ - inline TShape& operator=(Tuple&& src) { // NOLINT(*) - TShape(std::move(src)).swap(*this); // NOLINT(*) - return *this; - } - /*! \return total number of elements in the shape */ - inline size_t Size() const { - dim_t size = 1; - const dim_t* start = begin(), *fin = end(); - for (const dim_t* it = start; it != fin; ++it) { - size *= *it; - } - return size; - } - /*! - * \return product shape in [dimstart,dimend) - * \param dimstart start dimension - * \param dimend end dimension - */ - inline size_t ProdShape(int dimstart, int dimend) const { - dim_t num = 1; - const dim_t *d = this->data(); - for (int i = dimstart; i < dimend; ++i) { - num *= d[i]; - } - return num; - } - /*! \return the begin data pointer to content of the tuple */ - inline const dim_t *data() const { - return begin(); - } - /*! \return the begin data pointer to content of the tuple */ - inline dim_t *data() { - return begin(); - } -#ifdef MSHADOW_XINLINE - template - inline TShape(const mshadow::Shape &s) {// NOLINT(*) - this->assign(s.shape_, s.shape_ + dim); - } - - template - inline TShape(mshadow::Shape &&s) {// NOLINT(*) - this->assign(s.shape_, s.shape_ + dim); - } - /*! - * \brief assignment from shape - * \param shape source shape - * \tparam dim shape dimension - * \return reference of self - */ - template - inline TShape &operator=(const mshadow::Shape &shape) { - this->assign(shape.shape_, shape.shape_ + dim); - return *this; - } - /*! - * \brief get the shape of tensor specifying dim - * \return the shape requested - * \tparam dim dimension of the tensor - */ - template - inline mshadow::Shape get() const { - CHECK_EQ(dim, static_cast(ndim())) - << "dimension do not match target dimension " << dim << " vs " << ndim(); - const dim_t *d = this->data(); - mshadow::Shape s; - for (int i = 0; i < dim; ++i) { - s[i] = d[i]; - } - return s; - } - /*! - * flatten the higher dimension to second dimension, return a 2D shape - * \return the flat 2d shape - */ - inline mshadow::Shape<2> FlatTo2D(void) const { - mshadow::Shape<2> s; - if (ndim() == 0) return mshadow::Shape2(0, 0); - const dim_t *d = this->data(); - s.shape_[1] = d[ndim() - 1]; - dim_t ymax = 1; - for (size_t i = 1; i < ndim(); ++i) { - ymax *= d[i - 1]; - } - s.shape_[0] = ymax; - return s; - } - /*! - * flatten the shape into three parts: [0, axis_begin), [axis_begin, axis_end], (axis_end, ndim) - * \param axis_begin The beginning axis specified. - * \param axis_end The ending axis specified. - * \return the flat 3d shape - */ - inline mshadow::Shape<3> FlatTo3D(size_t axis_begin, size_t axis_end) const { - CHECK(axis_end >= axis_begin); - mshadow::Shape<3> s; - if (ndim() == 0) return mshadow::Shape3(0, 0, 0); - const dim_t *d = this->data(); - s.shape_[0] = 1; - s.shape_[1] = 1; - s.shape_[2] = 1; - - for (size_t i = 0; i < axis_begin; ++i) { - s.shape_[0] *= d[i]; - } - for (size_t i = axis_begin; i <= axis_end; ++i) { - s.shape_[1] *= d[i]; - } - for (size_t i = axis_end + 1; i < ndim(); ++i) { - s.shape_[2] *= d[i]; - } - return s; - } - /*! - * flatten the axis before and after the specified axis, so it becomes 3D tensor - * \param axis The axis specified. - * \return the flat 3d shape - */ - inline mshadow::Shape<3> FlatTo3D(size_t axis) const { - return FlatTo3D(axis, axis); - } - inline bool operator==(const TShape &s) const { - if (ndim() != s.ndim()) return false; - return std::equal(begin(), end(), s.begin()); - } - inline bool operator!=(const TShape &s) const { - return !(*this == s); - } - /*! - * \return whether two shape equals - * \param s the shape to compare against - * \tparam dim dimension of the shape - */ - template - inline bool operator==(const mshadow::Shape &s) const { - if (ndim_ != dim) return false; - const dim_t *d = dim <= kStackCache ? data_stack_ : data_heap_; - for (size_t i = 0; i < dim; ++i) { - if (d[i] != s.shape_[i]) return false; - } - return true; - } - /*! - * \return whether two shape not equals - * \param s the shape to compare against - * \tparam dim dimension of the shape - */ - template - inline bool operator!=(const mshadow::Shape &s) const { - return !(*this == s); - } -#endif -}; - -/*! \brief helper function to cast type of container elements */ -template -inline DstIter ShapeTypeCast(const SrcIter begin, - const SrcIter end, - DstIter dst_begin) { - typedef typename std::iterator_traits::value_type SrcDType; - typedef typename std::iterator_traits::value_type DstDType; - auto cast = [](const SrcDType& dim) { return static_cast(dim); }; - return std::transform(begin, end, dst_begin, cast); -} - -/*! \brief helper function to transform a container to TShape with type cast */ -template -inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) { - size_t ndim = std::distance(begin, end); - TShape res(ndim); - ShapeTypeCast(begin, end, res.begin()); - return res; -} - -/*! \tparam ValueType The type of data stored inside tuple. */ -template -template -inline void Tuple::Save(TStream *strm) const { - strm->Write(&ndim_, sizeof(ndim_)); - if (typeid(DType) == typeid(ValueType)) { - strm->Write(begin(), sizeof(ValueType) * ndim_); - } else { - std::vector buffer(ndim_); - ShapeTypeCast(begin(), end(), buffer.data()); - strm->Write(buffer.data(), sizeof(DType) * ndim_); - } -} - -/*! \tparam ValueType The type of data stored inside tuple. */ -template -template -inline bool Tuple::Load(TStream *strm) { - if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false; - this->SetDim(ndim_); - size_t nread = sizeof(DType) * ndim_; - if (typeid(DType) == typeid(ValueType)) { - if (strm->Read(begin(), nread) != nread) return false; - } else { - std::vector buffer(ndim_); - if (strm->Read(buffer.data(), nread) != nread) return false; - ShapeTypeCast(buffer.begin(), buffer.end(), begin()); - } - return true; -} - -} // namespace nnvm - -namespace std { -/*! \brief hash function for Tuple. */ -template -struct hash > { - /*! \brief hash a Tuple into unsigned int */ - size_t operator()(const nnvm::Tuple& val) const { - std::hash hash_uint; - size_t res = hash_uint(val.ndim()); - for (uint32_t i = 0; i < val.ndim(); ++i) { - res = dmlc::HashCombine(res, val[i]); - } - return res; - } -}; - -/*! \brief hash function for TShape. */ -template<> -struct hash { - /*! \brief hash a TShape into unsigned int */ - size_t operator()(const nnvm::TShape& val) const { - std::hash hash_uint; - size_t res = hash_uint(val.ndim()); - for (uint32_t i = 0; i < val.ndim(); ++i) { - res = dmlc::HashCombine(res, val[i]); - } - return res; - } -}; -} // namespace std - -namespace dmlc { -/*! \brief description for optional TShape */ -DMLC_DECLARE_TYPE_NAME(optional, "Shape or None"); -// avoid low version of MSVC -#if !defined(_MSC_VER) -template -struct type_name_helper > { - static inline std::string value() { - return "tuple of <" + type_name() + ">"; - } -}; -#endif -} // namespace dmlc -#endif // NNVM_TUPLE_H_ From 1a182cc66e553de129557fe7c9ff731a5ba67b6e Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 19 Nov 2018 16:56:36 -0800 Subject: [PATCH 03/12] add python API to return include path --- python/mxnet/libinfo.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/python/mxnet/libinfo.py b/python/mxnet/libinfo.py index b4450510a4c4..4d7a8e71b0fb 100644 --- a/python/mxnet/libinfo.py +++ b/python/mxnet/libinfo.py @@ -77,5 +77,36 @@ def find_lib_path(): return lib_path +def find_include_path(): + """Find MXNet dynamic library files. + + Returns + ------- + incl_path : string + Path to the header files. + """ + incl_from_env = os.environ.get('MXNET_INCLUDE_PATH') + if incl_from_env: + if os.path.isfile(incl_from_env): + if not os.path.isabs(incl_from_env): + logging.warning("MXNET_INCLUDE_PATH should be an absolute path, instead of: %s", + incl_from_env) + else: + if os.name == 'nt': + os.environ['PATH'] = os.environ['PATH'] + ';' + os.path.dirname(incl_from_env) + return [incl_from_env] + else: + logging.warning("MXNET_INCLUDE_PATH '%s' doesn't exist", incl_from_env) + + curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + incl_path = os.path.join(curr_path, '../../include/') + if len(incl_path) == 0: + raise RuntimeError('Cannot find the MXNet include path.\n') + + if os.name == 'nt': + os.environ['PATH'] = os.environ['PATH'] + ';' + os.path.dirname(incl_path) + return incl_path + + # current version __version__ = "1.3.1" From ae6ba032ce8fc32cf424ba8ce9308c0a013931bd Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 20 Nov 2018 12:32:16 -0800 Subject: [PATCH 04/12] update link --- include/dlpack | 1 + include/dlpack/dlpack.h | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) create mode 120000 include/dlpack delete mode 120000 include/dlpack/dlpack.h diff --git a/include/dlpack b/include/dlpack new file mode 120000 index 000000000000..4e14a36ed7fd --- /dev/null +++ b/include/dlpack @@ -0,0 +1 @@ +../../3rdparty/dlpack/include/dlpack \ No newline at end of file diff --git a/include/dlpack/dlpack.h b/include/dlpack/dlpack.h deleted file mode 120000 index 119855e7cd94..000000000000 --- a/include/dlpack/dlpack.h +++ /dev/null @@ -1 +0,0 @@ -../../3rdparty/dlpack/include/dlpack/dlpack.h \ No newline at end of file From 9bc7ea2af8e4c86aea0c9efa3226c5ef792b5e49 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 20 Nov 2018 14:40:57 -0800 Subject: [PATCH 05/12] fix windows CI --- ci/build_windows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/build_windows.py b/ci/build_windows.py index 56769f7cdaf0..b060dfc1a091 100755 --- a/ci/build_windows.py +++ b/ci/build_windows.py @@ -160,7 +160,7 @@ def windows_package(args): copy_tree('python', j(pkgdir, 'python')) logging.info('packing headers') copy_tree('include', j(pkgdir, 'include')) - copy_tree(j('3rdparty','dmlc-core','include'), j(pkgdir, 'include')) + copy_tree(j('3rdparty','dmlc-core','include'), j(pkgdir, 'include'), update=1) copy_tree(j('3rdparty','mshadow', 'mshadow'), j(pkgdir, 'include', 'mshadow')) copy_tree(j('3rdparty','tvm','nnvm', 'include'), j(pkgdir,'include', 'nnvm', 'include')) logging.info("Compressing package: %s", pkgfile) From c4a1a0ab264882ac3aafdaf0ab6e4812f3ac7a32 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 20 Nov 2018 15:32:33 -0800 Subject: [PATCH 06/12] fix windows build --- ci/build_windows.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ci/build_windows.py b/ci/build_windows.py index b060dfc1a091..c4b59762fbc2 100755 --- a/ci/build_windows.py +++ b/ci/build_windows.py @@ -160,9 +160,9 @@ def windows_package(args): copy_tree('python', j(pkgdir, 'python')) logging.info('packing headers') copy_tree('include', j(pkgdir, 'include')) - copy_tree(j('3rdparty','dmlc-core','include'), j(pkgdir, 'include'), update=1) - copy_tree(j('3rdparty','mshadow', 'mshadow'), j(pkgdir, 'include', 'mshadow')) - copy_tree(j('3rdparty','tvm','nnvm', 'include'), j(pkgdir,'include', 'nnvm', 'include')) + #copy_tree(j('3rdparty','dmlc-core','include'), j(pkgdir, 'include')) + #copy_tree(j('3rdparty','mshadow', 'mshadow'), j(pkgdir, 'include', 'mshadow')) + #copy_tree(j('3rdparty','tvm','nnvm', 'include'), j(pkgdir,'include', 'nnvm', 'include')) logging.info("Compressing package: %s", pkgfile) check_call(['7z', 'a', pkgfile, pkgdir]) From ae2a138577e6533a8f3637c375b1811fd7b2141f Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Wed, 21 Nov 2018 06:14:33 +0000 Subject: [PATCH 07/12] fix dlpack link --- include/dlpack | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/dlpack b/include/dlpack index 4e14a36ed7fd..e19164b88516 120000 --- a/include/dlpack +++ b/include/dlpack @@ -1 +1 @@ -../../3rdparty/dlpack/include/dlpack \ No newline at end of file +../3rdparty/dlpack/include/dlpack \ No newline at end of file From 7506fc22acba906ca50f9bac9b1e76c887b3b1f4 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 3 Dec 2018 22:51:02 -0800 Subject: [PATCH 08/12] merge with master --- python/mxnet/libinfo.py | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/python/mxnet/libinfo.py b/python/mxnet/libinfo.py index 31537c5da4f6..57c73e5943af 100644 --- a/python/mxnet/libinfo.py +++ b/python/mxnet/libinfo.py @@ -110,36 +110,5 @@ def find_include_path(): ' or ' + src_incl_path + '\n') -def find_include_path(): - """Find MXNet dynamic library files. - - Returns - ------- - incl_path : string - Path to the header files. - """ - incl_from_env = os.environ.get('MXNET_INCLUDE_PATH') - if incl_from_env: - if os.path.isfile(incl_from_env): - if not os.path.isabs(incl_from_env): - logging.warning("MXNET_INCLUDE_PATH should be an absolute path, instead of: %s", - incl_from_env) - else: - if os.name == 'nt': - os.environ['PATH'] = os.environ['PATH'] + ';' + os.path.dirname(incl_from_env) - return [incl_from_env] - else: - logging.warning("MXNET_INCLUDE_PATH '%s' doesn't exist", incl_from_env) - - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - incl_path = os.path.join(curr_path, '../../include/') - if len(incl_path) == 0: - raise RuntimeError('Cannot find the MXNet include path.\n') - - if os.name == 'nt': - os.environ['PATH'] = os.environ['PATH'] + ';' + os.path.dirname(incl_path) - return incl_path - - # current version __version__ = "1.4.0" From f9d4fb02d7089f51342512e9e456bb57f577e113 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 4 Dec 2018 11:08:49 -0800 Subject: [PATCH 09/12] exclude 3rd party header files from license check --- tests/nightly/apache_rat_license_check/rat-excludes | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/nightly/apache_rat_license_check/rat-excludes b/tests/nightly/apache_rat_license_check/rat-excludes index 0c305f498b34..c88dcae6a589 100755 --- a/tests/nightly/apache_rat_license_check/rat-excludes +++ b/tests/nightly/apache_rat_license_check/rat-excludes @@ -58,3 +58,7 @@ moderngpu/* deformable_im2col.cuh deformable_im2col.h REQUIRE +include/dlpack +include/dmlc +include/mshadow +include/nnvm \ No newline at end of file From 5b8870729ac98f548ab30353aab846ac8480ee77 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 4 Dec 2018 11:16:32 -0800 Subject: [PATCH 10/12] exclude license check --- tests/nightly/apache_rat_license_check/rat-excludes | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/nightly/apache_rat_license_check/rat-excludes b/tests/nightly/apache_rat_license_check/rat-excludes index c88dcae6a589..1b889782e80f 100755 --- a/tests/nightly/apache_rat_license_check/rat-excludes +++ b/tests/nightly/apache_rat_license_check/rat-excludes @@ -58,7 +58,7 @@ moderngpu/* deformable_im2col.cuh deformable_im2col.h REQUIRE -include/dlpack -include/dmlc -include/mshadow -include/nnvm \ No newline at end of file +include/dlpack/* +include/dmlc/* +include/mshadow/* +include/nnvm/* \ No newline at end of file From 77835c1809ab94f8ee109545f4cc8103b21d1b13 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 4 Dec 2018 11:44:15 -0800 Subject: [PATCH 11/12] exclude include directory --- tests/nightly/apache_rat_license_check/rat-excludes | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/nightly/apache_rat_license_check/rat-excludes b/tests/nightly/apache_rat_license_check/rat-excludes index 1b889782e80f..0d95792efc15 100755 --- a/tests/nightly/apache_rat_license_check/rat-excludes +++ b/tests/nightly/apache_rat_license_check/rat-excludes @@ -58,7 +58,4 @@ moderngpu/* deformable_im2col.cuh deformable_im2col.h REQUIRE -include/dlpack/* -include/dmlc/* -include/mshadow/* -include/nnvm/* \ No newline at end of file +include/* \ No newline at end of file From 11d36ef747a96cf03932d7d2f331814b7ca5ac95 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 4 Dec 2018 13:16:01 -0800 Subject: [PATCH 12/12] remove commented lines --- ci/build_windows.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/ci/build_windows.py b/ci/build_windows.py index c4b59762fbc2..b7d47fb1fde1 100755 --- a/ci/build_windows.py +++ b/ci/build_windows.py @@ -160,9 +160,6 @@ def windows_package(args): copy_tree('python', j(pkgdir, 'python')) logging.info('packing headers') copy_tree('include', j(pkgdir, 'include')) - #copy_tree(j('3rdparty','dmlc-core','include'), j(pkgdir, 'include')) - #copy_tree(j('3rdparty','mshadow', 'mshadow'), j(pkgdir, 'include', 'mshadow')) - #copy_tree(j('3rdparty','tvm','nnvm', 'include'), j(pkgdir,'include', 'nnvm', 'include')) logging.info("Compressing package: %s", pkgfile) check_call(['7z', 'a', pkgfile, pkgdir])