Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FFI] Add new containers and Implementations #19685

Merged
merged 44 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
95d13ad
[FFI] Add new containers and tests
barry-jin Dec 16, 2020
7384831
add license
barry-jin Dec 16, 2020
4c1ff74
fix sanity
barry-jin Dec 17, 2020
db80bf7
sanity
barry-jin Dec 17, 2020
4086c6c
strlen -> stold
barry-jin Dec 17, 2020
76452d3
update base.pxi
barry-jin Dec 17, 2020
1ef24bf
set @use_np in test
barry-jin Dec 19, 2020
7197c4a
fix clang-tidy
leezu Dec 21, 2020
ede0977
fix clang-tidy
barry-jin Dec 21, 2020
83aaff8
Merge branch 'ffi-container' of /~https://github.com/barry-jin/incubato…
barry-jin Dec 21, 2020
a5c5259
make containers support NDArray
barry-jin Jan 12, 2021
f56acbe
fix sanity
barry-jin Jan 12, 2021
bdbee75
Adopt PackedFunc Based FFI on CachedOp
barry-jin Jan 15, 2021
3f78892
fix pylint
barry-jin Jan 15, 2021
c37de95
fix sanity
barry-jin Jan 15, 2021
a37c9a2
update ndarray_handle.h
barry-jin Jan 15, 2021
ee4d2d6
remove convert.pxi
barry-jin Jan 15, 2021
50f19ef
update
barry-jin Jan 15, 2021
b7933f5
update _internal.py
barry-jin Jan 15, 2021
8c961dc
convert ADT to list
barry-jin Jan 15, 2021
595a84e
Merge remote-tracking branch 'upstream/master' into ffi-container
barry-jin Jan 15, 2021
933be36
fix unix test failures
barry-jin Jan 19, 2021
1601f19
update function.pxi
barry-jin Jan 19, 2021
890483d
Merge remote-tracking branch 'upstream/master' into ffi-container
barry-jin Jan 28, 2021
20c0c9a
update
barry-jin Jan 29, 2021
97fa8f0
update cached_op
barry-jin Feb 2, 2021
b3367b4
fix sanity
barry-jin Feb 2, 2021
c47d283
fix
barry-jin Feb 2, 2021
0154837
fix
barry-jin Feb 2, 2021
f83de2f
update container
barry-jin Feb 2, 2021
164eff9
udpate cached_op_create
barry-jin Feb 2, 2021
8ad9c6e
Merge remote-tracking branch 'upstream/master' into ffi-container
barry-jin Feb 2, 2021
50b58f7
clean packed_func
barry-jin Feb 3, 2021
7d4580e
fix
barry-jin Feb 3, 2021
4b5cd7f
improve performance
barry-jin Feb 9, 2021
2ff4f0a
improve perf
barry-jin Feb 9, 2021
f91111a
improve perf
barry-jin Feb 13, 2021
6cbd771
Merge remote-tracking branch 'upstream/master' into ffi-container
barry-jin Feb 15, 2021
7b4af9a
update
barry-jin Feb 16, 2021
8025fc1
fix
barry-jin Feb 16, 2021
a36f61e
update cached_op.py
barry-jin Feb 16, 2021
713f46e
update packed_func.h
barry-jin Feb 23, 2021
47e3f29
update cached_op_api.cc
barry-jin Feb 23, 2021
3c77022
Merge remote-tracking branch 'upstream/master' into ffi-container
barry-jin Feb 26, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
792 changes: 792 additions & 0 deletions include/mxnet/runtime/container_ext.h

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions include/mxnet/runtime/ndarray_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace mxnet {
class NDArrayHandleObj : public Object {
public:
/*! \brief the Internal value. */
NDArray* value;
NDArray value;

static constexpr const char* _type_key = "MXNet.NDArrayHandle";
MXNET_DECLARE_FINAL_OBJECT_INFO(NDArrayHandleObj, Object)
Expand All @@ -41,12 +41,14 @@ class NDArrayHandle : public ObjectRef {
public:
explicit NDArrayHandle(NDArray* value) {
runtime::ObjectPtr<NDArrayHandleObj> node = make_object<NDArrayHandleObj>();
node->value = value;
node->value = *value;
data_ = std::move(node);
}
MXNET_DEFINE_OBJECT_REF_METHODS(NDArrayHandle, ObjectRef, NDArrayHandleObj)
};

inline NDArray* getArray() const {
return static_cast<NDArray*>(&(static_cast<NDArrayHandleObj*>(data_.get())->value));
}
}; // namespace mxnet

#endif // MXNET_RUNTIME_NDARRAY_HANDLE_H_
40 changes: 32 additions & 8 deletions include/mxnet/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ enum TypeIndex {
kMXNetTensor = 1,
kMXNetClosure = 2,
kMXNetADT = 3,
kRuntimeModule = 4,
leezu marked this conversation as resolved.
Show resolved Hide resolved
kEllipsis = 5,
kSlice = 6,
kInteger = 7,
kFloat = 8,
kMXNetMap = 4,
kMXNetString = 5,
kEllipsis = 6,
kSlice = 7,
kInteger = 8,
kFloat = 9,
kStaticIndexEnd,
/*! \brief Type index is allocated during runtime. */
kDynamic = kStaticIndexEnd
Expand Down Expand Up @@ -567,6 +568,8 @@ class ObjectRef {

/*! \brief type indicate the container type. */
using ContainerType = Object;
// Default type properties for the reference class.
static constexpr bool _type_is_nullable = true;

protected:
/*! \brief Internal pointer that backs the reference. */
Expand Down Expand Up @@ -681,6 +684,11 @@ struct ObjectEqual {
static DMLC_ATTRIBUTE_UNUSED uint32_t __make_Object_tidx ## _ ## TypeName ## __ = \
TypeName::_GetOrAllocRuntimeTypeIndex()

#define MXNET_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
TypeName(const TypeName& other) = default; \
TypeName(TypeName&& other) = default; \
TypeName& operator=(const TypeName& other) = default; \
TypeName& operator=(TypeName&& other) = default;

#define MXNET_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
TypeName() {} \
Expand All @@ -704,6 +712,14 @@ struct ObjectEqual {
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;

#define MXNET_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
explicit TypeName(::mxnet::runtime::ObjectPtr<::mxnet::runtime::Object> n) : ParentType(n) {} \
MXNET_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
const ObjectName* get() const { return operator->(); } \
static constexpr bool _type_is_nullable = false; \
using ContainerType = ObjectName;

// Implementations details below
// Object reference counting.
#if MXNET_OBJECT_ATOMIC_REF_COUNTER
Expand Down Expand Up @@ -794,6 +810,9 @@ template <typename RefType, typename ObjType>
inline RefType GetRef(const ObjType* ptr) {
static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
"Can only cast to the ref of same container type");
if (!RefType::_type_is_nullable) {
CHECK(ptr != nullptr);
}
return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}

Expand All @@ -806,9 +825,14 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {

template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
<< "Downcast from " << ref->GetTypeKey() << " to "
<< SubRef::ContainerType::_type_key << " failed.";
if (ref.defined()) {
CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
<< "Downcast from " << ref->GetTypeKey() << " to "
<< SubRef::ContainerType::_type_key << " failed.";
} else {
CHECK(SubRef::_type_is_nullable) << "Downcast from nullptr to not nullable reference of "
<< SubRef::ContainerType::_type_key;
}
return SubRef(std::move(ref.data_));
}

Expand Down
125 changes: 112 additions & 13 deletions include/mxnet/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <mxnet/runtime/object.h>
#include <mxnet/runtime/ndarray.h>
#include <mxnet/runtime/container.h>
#include <mxnet/runtime/container_ext.h>
#include <mxnet/runtime/ndarray_handle.h>
#include <mxnet/runtime/ffi_helper.h>
#include <mxnet/runtime/data_type.h>
Expand Down Expand Up @@ -382,6 +383,23 @@ struct extension_type_info {
static const int code = 0;
};

/*!
* \brief Type traits for runtime type check during FFI conversion.
* \tparam T the type to be checked.
*/
template <typename T>
struct ObjectTypeChecker {
static bool Check(const Object* ptr) {
using ContainerType = typename T::ContainerType;
if (ptr == nullptr) return T::_type_is_nullable;
return ptr->IsInstance<ContainerType>();
}
static std::string TypeName() {
using ContainerType = typename T::ContainerType;
return ContainerType::_type_key;
}
};

/*!
* \brief Internal base class to
* handle conversion to POD values.
Expand Down Expand Up @@ -433,6 +451,8 @@ class MXNetPODValue_ {
typename = typename std::enable_if<
std::is_class<TObjectRef>::value>::type>
inline bool IsObjectRef() const;
template <typename TObjectRef>
inline TObjectRef AsObjectRef() const;
int type_code() const {
return type_code_;
}
Expand Down Expand Up @@ -487,6 +507,7 @@ class MXNetArgValue : public MXNetPODValue_ {
using MXNetPODValue_::operator void*;
using MXNetPODValue_::operator ObjectRef;
using MXNetPODValue_::IsObjectRef;
using MXNetPODValue_::AsObjectRef;

// conversion operator.
operator std::string() const {
Expand All @@ -497,6 +518,12 @@ class MXNetArgValue : public MXNetPODValue_ {
MXNET_CHECK_TYPE_CODE(type_code_, kStr);
return std::string(value_.v_str);
}
// } else if (type_code_ == kStr) {
// return std::string(value_.v_str);
// } else {
// CHECK(IsObjectRef<tvm::runtime::String>());
// return AsObjectRef<tvm::runtime::String>().operator std::string();
// }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove the unused code?

}
operator DLDataType() const {
if (type_code_ == kStr) {
Expand Down Expand Up @@ -528,9 +555,6 @@ class MXNetArgValue : public MXNetPODValue_ {
const MXNetValue& value() const {
return value_;
}
// Deferred extension handler.
template<typename TObjectRef>
inline TObjectRef AsObjectRef() const;
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
Expand Down Expand Up @@ -571,6 +595,7 @@ class MXNetRetValue : public MXNetPODValue_ {
using MXNetPODValue_::operator void*;
using MXNetPODValue_::operator ObjectRef;
using MXNetPODValue_::IsObjectRef;
using MXNetPODValue_::AsObjectRef;

MXNetRetValue(const MXNetRetValue& other) : MXNetPODValue_() {
this->Assign(other);
Expand Down Expand Up @@ -681,7 +706,8 @@ class MXNetRetValue : public MXNetPODValue_ {
}
MXNetRetValue& operator=(NDArrayHandle value) {
this->SwitchToPOD(kNDArrayHandle);
value_.v_handle = reinterpret_cast<void*>(value->value);
NDArray* arr = new NDArray(value->value);
value_.v_handle = reinterpret_cast<void*>(arr);
return *this;
}
MXNetRetValue& operator=(const PythonArg& value) {
Expand Down Expand Up @@ -725,8 +751,6 @@ class MXNetRetValue : public MXNetPODValue_ {
typename = typename std::enable_if<
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TObjectRef>
inline TObjectRef AsObjectRef() const;

private:
template<typename T>
Expand Down Expand Up @@ -1173,13 +1197,88 @@ struct MXNetValueCast {

} // namespace detail

template<typename T, typename>
inline MXNetRetValue::operator T() const {
return detail::
MXNetValueCast<T, MXNetRetValue,
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)>
::Apply(this);
/*!
* \brief Type trait to specify special value conversion rules from
* MXNetArgValue and MXNetRetValue.
*
* The trait can be specialized to add type specific conversion logic
* from the TVMArgvalue and TVMRetValue.
*
* \tparam TObjectRef the specific ObjectRefType.
*/
template <typename TObjectRef>
struct PackedFuncValueConverter {
/*!
* \brief Convert a TObjectRef from an argument value.
* \param val The argument value.
* \return the converted result.
*/
static TObjectRef From(const MXNetArgValue& val) { return val.AsObjectRef<TObjectRef>(); }
/*!
* \brief Convert a TObjectRef from a return value.
* \param val The argument value.
* \return the converted result.
*/
static TObjectRef From(const MXNetRetValue& val) { return val.AsObjectRef<TObjectRef>(); }
};

template <>
struct PackedFuncValueConverter<::mxnet::runtime::String> {
static String From(const MXNetArgValue& val) {
if (val.IsObjectRef<mxnet::runtime::String>()) {
return val.AsObjectRef<mxnet::runtime::String>();
} else {
return mxnet::runtime::String(val.operator std::string());
}
}

static String From(const MXNetRetValue& val) {
if (val.IsObjectRef<mxnet::runtime::String>()) {
return val.AsObjectRef<mxnet::runtime::String>();
} else {
return mxnet::runtime::String(val.operator std::string());
}
}
};

template <typename TObjectRef>
inline TObjectRef MXNetPODValue_::AsObjectRef() const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Conversion only works for ObjectRef");
using ContainerType = typename TObjectRef::ContainerType;

if (type_code_ == kNull) {
CHECK(TObjectRef::_type_is_nullable)
<< "Expect a not null value of " << ContainerType::_type_key;
return TObjectRef(ObjectPtr<Object>(nullptr));
}
if (type_code_ == kObjectHandle) {
// normal object type check.
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<TObjectRef>::Check(ptr))
<< "Expect " << ObjectTypeChecker<TObjectRef>::TypeName() << " but get "
<< ptr->GetTypeKey();
return TObjectRef(GetObjectPtr<Object>(ptr));
} else {
MXNET_CHECK_TYPE_CODE(type_code_, kObjectHandle);
return TObjectRef(ObjectPtr<Object>(nullptr));
}
}

template <typename T, typename>
inline MXNetArgValue::operator T() const {
return PackedFuncValueConverter<T>::From(*this);
}

template <typename TObjectRef, typename>
inline bool MXNetPODValue_::IsObjectRef() const {
using ContainerType = typename TObjectRef::ContainerType;
return type_code_ == kObjectHandle &&
ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle));
}

inline bool String::CanConvertFrom(const MXNetArgValue& val) {
return val.type_code() == kStr || val.IsObjectRef<mxnet::runtime::String>();
}

} // namespace runtime
Expand Down
22 changes: 22 additions & 0 deletions python/mxnet/_ctypes/_api_internal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""CachedOp APIs exposed from C++."""
szha marked this conversation as resolved.
Show resolved Hide resolved

import mxnet._ffi

mxnet._ffi._init_api("cached_op", __name__)
Loading