diff --git a/pyo3-derive-backend/src/defs.rs b/pyo3-derive-backend/src/defs.rs index eda87b5f013..6dcf93e12f1 100644 --- a/pyo3-derive-backend/src/defs.rs +++ b/pyo3-derive-backend/src/defs.rs @@ -21,7 +21,7 @@ impl Proto { where Q: PartialEq<&'static str>, { - self.methods.iter().find(|m| query == m.name()) + self.methods.iter().find(|m| query == m.name) } pub(crate) fn get_method(&self, query: Q) -> Option<&'static PyMethod> where @@ -101,52 +101,26 @@ pub const OBJECT: Proto = Proto { name: "Object", extension_trait: "pyo3::class::basic::PyBasicSlots", methods: &[ - MethodProto::Binary { - name: "__getattr__", - arg: "Name", - proto: "pyo3::class::basic::PyObjectGetAttrProtocol", - }, - MethodProto::Ternary { - name: "__setattr__", - arg1: "Name", - arg2: "Value", - proto: "pyo3::class::basic::PyObjectSetAttrProtocol", - }, - MethodProto::Binary { - name: "__delattr__", - arg: "Name", - proto: "pyo3::class::basic::PyObjectDelAttrProtocol", - }, - MethodProto::Unary { - name: "__str__", - proto: "pyo3::class::basic::PyObjectStrProtocol", - }, - MethodProto::Unary { - name: "__repr__", - proto: "pyo3::class::basic::PyObjectReprProtocol", - }, - MethodProto::Binary { - name: "__format__", - arg: "Format", - proto: "pyo3::class::basic::PyObjectFormatProtocol", - }, - MethodProto::Unary { - name: "__hash__", - proto: "pyo3::class::basic::PyObjectHashProtocol", - }, - MethodProto::Unary { - name: "__bytes__", - proto: "pyo3::class::basic::PyObjectBytesProtocol", - }, - MethodProto::Binary { - name: "__richcmp__", - arg: "Other", - proto: "pyo3::class::basic::PyObjectRichcmpProtocol", - }, - MethodProto::Unary { - name: "__bool__", - proto: "pyo3::class::basic::PyObjectBoolProtocol", - }, + MethodProto::new("__getattr__", "pyo3::class::basic::PyObjectGetAttrProtocol") + .args(&["Name"]) + .has_self(), + MethodProto::new("__setattr__", "pyo3::class::basic::PyObjectSetAttrProtocol") + .args(&["Name", "Value"]) + .has_self(), + MethodProto::new("__delattr__", "pyo3::class::basic::PyObjectDelAttrProtocol") + .args(&["Name"]) + .has_self(), + MethodProto::new("__str__", "pyo3::class::basic::PyObjectStrProtocol").has_self(), + MethodProto::new("__repr__", "pyo3::class::basic::PyObjectReprProtocol").has_self(), + MethodProto::new("__format__", "pyo3::class::basic::PyObjectFormatProtocol") + .args(&["Format"]) + .has_self(), + MethodProto::new("__hash__", "pyo3::class::basic::PyObjectHashProtocol").has_self(), + MethodProto::new("__bytes__", "pyo3::class::basic::PyObjectBytesProtocol").has_self(), + MethodProto::new("__richcmp__", "pyo3::class::basic::PyObjectRichcmpProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new("__bool__", "pyo3::class::basic::PyObjectBoolProtocol").has_self(), ], py_methods: &[ PyMethod::new("__format__", "pyo3::class::basic::FormatProtocolImpl"), @@ -170,32 +144,16 @@ pub const ASYNC: Proto = Proto { name: "Async", extension_trait: "pyo3::class::pyasync::PyAsyncSlots", methods: &[ - MethodProto::UnaryS { - name: "__await__", - arg: "Receiver", - proto: "pyo3::class::pyasync::PyAsyncAwaitProtocol", - }, - MethodProto::UnaryS { - name: "__aiter__", - arg: "Receiver", - proto: "pyo3::class::pyasync::PyAsyncAiterProtocol", - }, - MethodProto::UnaryS { - name: "__anext__", - arg: "Receiver", - proto: "pyo3::class::pyasync::PyAsyncAnextProtocol", - }, - MethodProto::Unary { - name: "__aenter__", - proto: "pyo3::class::pyasync::PyAsyncAenterProtocol", - }, - MethodProto::Quaternary { - name: "__aexit__", - arg1: "ExcType", - arg2: "ExcValue", - arg3: "Traceback", - proto: "pyo3::class::pyasync::PyAsyncAexitProtocol", - }, + MethodProto::new("__await__", "pyo3::class::pyasync::PyAsyncAwaitProtocol") + .args(&["Receiver"]), + MethodProto::new("__aiter__", "pyo3::class::pyasync::PyAsyncAiterProtocol") + .args(&["Receiver"]), + MethodProto::new("__anext__", "pyo3::class::pyasync::PyAsyncAnextProtocol") + .args(&["Receiver"]), + MethodProto::new("__aenter__", "pyo3::class::pyasync::PyAsyncAenterProtocol").has_self(), + MethodProto::new("__aexit__", "pyo3::class::pyasync::PyAsyncAexitProtocol") + .args(&["ExcType", "ExcValue", "Traceback"]) + .has_self(), ], py_methods: &[ PyMethod::new( @@ -218,14 +176,16 @@ pub const BUFFER: Proto = Proto { name: "Buffer", extension_trait: "pyo3::class::buffer::PyBufferSlots", methods: &[ - MethodProto::Unary { - name: "bf_getbuffer", - proto: "pyo3::class::buffer::PyBufferGetBufferProtocol", - }, - MethodProto::Unary { - name: "bf_releasebuffer", - proto: "pyo3::class::buffer::PyBufferReleaseBufferProtocol", - }, + MethodProto::new( + "bf_getbuffer", + "pyo3::class::buffer::PyBufferGetBufferProtocol", + ) + .has_self(), + MethodProto::new( + "bf_releasebuffer", + "pyo3::class::buffer::PyBufferReleaseBufferProtocol", + ) + .has_self(), ], py_methods: &[], slot_getters: &[ @@ -238,17 +198,10 @@ pub const CONTEXT: Proto = Proto { name: "Context", extension_trait: "", methods: &[ - MethodProto::Unary { - name: "__enter__", - proto: "pyo3::class::context::PyContextEnterProtocol", - }, - MethodProto::Quaternary { - name: "__exit__", - arg1: "ExcType", - arg2: "ExcValue", - arg3: "Traceback", - proto: "pyo3::class::context::PyContextExitProtocol", - }, + MethodProto::new("__enter__", "pyo3::class::context::PyContextEnterProtocol").has_self(), + MethodProto::new("__exit__", "pyo3::class::context::PyContextExitProtocol") + .args(&["ExcType", "ExcValue", "Traceback"]) + .has_self(), ], py_methods: &[ PyMethod::new( @@ -267,14 +220,12 @@ pub const GC: Proto = Proto { name: "GC", extension_trait: "pyo3::class::gc::PyGCSlots", methods: &[ - MethodProto::Free { - name: "__traverse__", - proto: "pyo3::class::gc::PyGCTraverseProtocol", - }, - MethodProto::Free { - name: "__clear__", - proto: "pyo3::class::gc::PyGCClearProtocol", - }, + MethodProto::new("__traverse__", "pyo3::class::gc::PyGCTraverseProtocol") + .has_self() + .no_result(), + MethodProto::new("__clear__", "pyo3::class::gc::PyGCClearProtocol") + .has_self() + .no_result(), ], py_methods: &[], slot_getters: &[ @@ -287,30 +238,16 @@ pub const DESCR: Proto = Proto { name: "Descriptor", extension_trait: "pyo3::class::descr::PyDescrSlots", methods: &[ - MethodProto::TernaryS { - name: "__get__", - arg1: "Receiver", - arg2: "Inst", - arg3: "Owner", - proto: "pyo3::class::descr::PyDescrGetProtocol", - }, - MethodProto::TernaryS { - name: "__set__", - arg1: "Receiver", - arg2: "Inst", - arg3: "Value", - proto: "pyo3::class::descr::PyDescrSetProtocol", - }, - MethodProto::Binary { - name: "__det__", - arg: "Inst", - proto: "pyo3::class::descr::PyDescrDelProtocol", - }, - MethodProto::Binary { - name: "__set_name__", - arg: "Inst", - proto: "pyo3::class::descr::PyDescrSetNameProtocol", - }, + MethodProto::new("__get__", "pyo3::class::descr::PyDescrGetProtocol") + .args(&["Receiver", "Inst", "Owner"]), + MethodProto::new("__set__", "pyo3::class::descr::PyDescrSetProtocol") + .args(&["Receiver", "Inst", "Value"]), + MethodProto::new("__det__", "pyo3::class::descr::PyDescrDelProtocol") + .args(&["Inst"]) + .has_self(), + MethodProto::new("__set_name__", "pyo3::class::descr::PyDescrSetNameProtocol") + .args(&["Inst"]) + .has_self(), ], py_methods: &[ PyMethod::new("__del__", "pyo3::class::context::PyDescrDelProtocolImpl"), @@ -330,16 +267,8 @@ pub const ITER: Proto = Proto { extension_trait: "pyo3::class::iter::PyIterSlots", py_methods: &[], methods: &[ - MethodProto::UnaryS { - name: "__iter__", - arg: "Receiver", - proto: "pyo3::class::iter::PyIterIterProtocol", - }, - MethodProto::UnaryS { - name: "__next__", - arg: "Receiver", - proto: "pyo3::class::iter::PyIterNextProtocol", - }, + MethodProto::new("__iter__", "pyo3::class::iter::PyIterIterProtocol").args(&["Receiver"]), + MethodProto::new("__next__", "pyo3::class::iter::PyIterNextProtocol").args(&["Receiver"]), ], slot_getters: &[ SlotGetter::new(&["__iter__"], "get_iter"), @@ -351,30 +280,30 @@ pub const MAPPING: Proto = Proto { name: "Mapping", extension_trait: "pyo3::class::mapping::PyMappingSlots", methods: &[ - MethodProto::Unary { - name: "__len__", - proto: "pyo3::class::mapping::PyMappingLenProtocol", - }, - MethodProto::Binary { - name: "__getitem__", - arg: "Key", - proto: "pyo3::class::mapping::PyMappingGetItemProtocol", - }, - MethodProto::Ternary { - name: "__setitem__", - arg1: "Key", - arg2: "Value", - proto: "pyo3::class::mapping::PyMappingSetItemProtocol", - }, - MethodProto::Binary { - name: "__delitem__", - arg: "Key", - proto: "pyo3::class::mapping::PyMappingDelItemProtocol", - }, - MethodProto::Unary { - name: "__reversed__", - proto: "pyo3::class::mapping::PyMappingReversedProtocol", - }, + MethodProto::new("__len__", "pyo3::class::mapping::PyMappingLenProtocol").has_self(), + MethodProto::new( + "__getitem__", + "pyo3::class::mapping::PyMappingGetItemProtocol", + ) + .args(&["Key"]) + .has_self(), + MethodProto::new( + "__setitem__", + "pyo3::class::mapping::PyMappingSetItemProtocol", + ) + .args(&["Key", "Value"]) + .has_self(), + MethodProto::new( + "__delitem__", + "pyo3::class::mapping::PyMappingDelItemProtocol", + ) + .args(&["Key"]) + .has_self(), + MethodProto::new( + "__reversed__", + "pyo3::class::mapping::PyMappingReversedProtocol", + ) + .has_self(), ], py_methods: &[PyMethod::new( "__reversed__", @@ -393,51 +322,55 @@ pub const SEQ: Proto = Proto { name: "Sequence", extension_trait: "pyo3::class::sequence::PySequenceSlots", methods: &[ - MethodProto::Unary { - name: "__len__", - proto: "pyo3::class::sequence::PySequenceLenProtocol", - }, - MethodProto::Binary { - name: "__getitem__", - arg: "Index", - proto: "pyo3::class::sequence::PySequenceGetItemProtocol", - }, - MethodProto::Ternary { - name: "__setitem__", - arg1: "Index", - arg2: "Value", - proto: "pyo3::class::sequence::PySequenceSetItemProtocol", - }, - MethodProto::Binary { - name: "__delitem__", - arg: "Index", - proto: "pyo3::class::sequence::PySequenceDelItemProtocol", - }, - MethodProto::Binary { - name: "__contains__", - arg: "Item", - proto: "pyo3::class::sequence::PySequenceContainsProtocol", - }, - MethodProto::Binary { - name: "__concat__", - arg: "Other", - proto: "pyo3::class::sequence::PySequenceConcatProtocol", - }, - MethodProto::Binary { - name: "__repeat__", - arg: "Index", - proto: "pyo3::class::sequence::PySequenceRepeatProtocol", - }, - MethodProto::Binary { - name: "__inplace_concat__", - arg: "Other", - proto: "pyo3::class::sequence::PySequenceInplaceConcatProtocol", - }, - MethodProto::Binary { - name: "__inplace_repeat__", - arg: "Index", - proto: "pyo3::class::sequence::PySequenceInplaceRepeatProtocol", - }, + MethodProto::new("__len__", "pyo3::class::sequence::PySequenceLenProtocol").has_self(), + MethodProto::new( + "__getitem__", + "pyo3::class::sequence::PySequenceGetItemProtocol", + ) + .args(&["Index"]) + .has_self(), + MethodProto::new( + "__setitem__", + "pyo3::class::sequence::PySequenceSetItemProtocol", + ) + .args(&["Index", "Value"]) + .has_self(), + MethodProto::new( + "__delitem__", + "pyo3::class::sequence::PySequenceDelItemProtocol", + ) + .args(&["Index"]) + .has_self(), + MethodProto::new( + "__contains__", + "pyo3::class::sequence::PySequenceContainsProtocol", + ) + .args(&["Item"]) + .has_self(), + MethodProto::new( + "__concat__", + "pyo3::class::sequence::PySequenceConcatProtocol", + ) + .args(&["Other"]) + .has_self(), + MethodProto::new( + "__repeat__", + "pyo3::class::sequence::PySequenceRepeatProtocol", + ) + .args(&["Index"]) + .has_self(), + MethodProto::new( + "__inplace_concat__", + "pyo3::class::sequence::PySequenceInplaceConcatProtocol", + ) + .args(&["Other"]) + .has_self(), + MethodProto::new( + "__inplace_repeat__", + "pyo3::class::sequence::PySequenceInplaceRepeatProtocol", + ) + .args(&["Index"]) + .has_self(), ], py_methods: &[], slot_getters: &[ @@ -458,264 +391,169 @@ pub const NUM: Proto = Proto { name: "Number", extension_trait: "pyo3::class::number::PyNumberSlots", methods: &[ - MethodProto::BinaryS { - name: "__add__", - arg1: "Left", - arg2: "Right", - proto: "pyo3::class::number::PyNumberAddProtocol", - }, - MethodProto::BinaryS { - name: "__sub__", - arg1: "Left", - arg2: "Right", - proto: "pyo3::class::number::PyNumberSubProtocol", - }, - MethodProto::BinaryS { - name: "__mul__", - arg1: "Left", - arg2: "Right", - proto: "pyo3::class::number::PyNumberMulProtocol", - }, - MethodProto::BinaryS { - name: "__matmul__", - arg1: "Left", - arg2: "Right", - proto: "pyo3::class::number::PyNumberMatmulProtocol", - }, - MethodProto::BinaryS { - name: "__truediv__", - arg1: "Left", - arg2: "Right", - proto: "pyo3::class::number::PyNumberTruedivProtocol", - }, - MethodProto::BinaryS { - name: "__floordiv__", - arg1: "Left", - arg2: "Right", - proto: "pyo3::class::number::PyNumberFloordivProtocol", - }, - MethodProto::BinaryS { - name: "__mod__", - arg1: "Left", - arg2: "Right", - proto: "pyo3::class::number::PyNumberModProtocol", - }, - MethodProto::BinaryS { - name: "__divmod__", - arg1: "Left", - arg2: "Right", - proto: "pyo3::class::number::PyNumberDivmodProtocol", - }, - MethodProto::TernaryS { - name: "__pow__", - arg1: "Left", - arg2: "Right", - arg3: "Modulo", - proto: "pyo3::class::number::PyNumberPowProtocol", - }, - MethodProto::BinaryS { - name: "__lshift__", - arg1: "Left", - arg2: "Right", - proto: "pyo3::class::number::PyNumberLShiftProtocol", - }, - MethodProto::BinaryS { - name: "__rshift__", - arg1: "Left", - arg2: "Right", - proto: "pyo3::class::number::PyNumberRShiftProtocol", - }, - MethodProto::BinaryS { - name: "__and__", - arg1: "Left", - arg2: "Right", - proto: "pyo3::class::number::PyNumberAndProtocol", - }, - MethodProto::BinaryS { - name: "__xor__", - arg1: "Left", - arg2: "Right", - proto: "pyo3::class::number::PyNumberXorProtocol", - }, - MethodProto::BinaryS { - name: "__or__", - arg1: "Left", - arg2: "Right", - proto: "pyo3::class::number::PyNumberOrProtocol", - }, - MethodProto::Binary { - name: "__radd__", - arg: "Other", - proto: "pyo3::class::number::PyNumberRAddProtocol", - }, - MethodProto::Binary { - name: "__rsub__", - arg: "Other", - proto: "pyo3::class::number::PyNumberRSubProtocol", - }, - MethodProto::Binary { - name: "__rmul__", - arg: "Other", - proto: "pyo3::class::number::PyNumberRMulProtocol", - }, - MethodProto::Binary { - name: "__rmatmul__", - arg: "Other", - proto: "pyo3::class::number::PyNumberRMatmulProtocol", - }, - MethodProto::Binary { - name: "__rtruediv__", - arg: "Other", - proto: "pyo3::class::number::PyNumberRTruedivProtocol", - }, - MethodProto::Binary { - name: "__rfloordiv__", - arg: "Other", - proto: "pyo3::class::number::PyNumberRFloordivProtocol", - }, - MethodProto::Binary { - name: "__rmod__", - arg: "Other", - proto: "pyo3::class::number::PyNumberRModProtocol", - }, - MethodProto::Binary { - name: "__rdivmod__", - arg: "Other", - proto: "pyo3::class::number::PyNumberRDivmodProtocol", - }, - MethodProto::Ternary { - name: "__rpow__", - arg1: "Other", - arg2: "Modulo", - proto: "pyo3::class::number::PyNumberRPowProtocol", - }, - MethodProto::Binary { - name: "__rlshift__", - arg: "Other", - proto: "pyo3::class::number::PyNumberRLShiftProtocol", - }, - MethodProto::Binary { - name: "__rrshift__", - arg: "Other", - proto: "pyo3::class::number::PyNumberRRShiftProtocol", - }, - MethodProto::Binary { - name: "__rand__", - arg: "Other", - proto: "pyo3::class::number::PyNumberRAndProtocol", - }, - MethodProto::Binary { - name: "__rxor__", - arg: "Other", - proto: "pyo3::class::number::PyNumberRXorProtocol", - }, - MethodProto::Binary { - name: "__ror__", - arg: "Other", - proto: "pyo3::class::number::PyNumberROrProtocol", - }, - MethodProto::Binary { - name: "__iadd__", - arg: "Other", - proto: "pyo3::class::number::PyNumberIAddProtocol", - }, - MethodProto::Binary { - name: "__isub__", - arg: "Other", - proto: "pyo3::class::number::PyNumberISubProtocol", - }, - MethodProto::Binary { - name: "__imul__", - arg: "Other", - proto: "pyo3::class::number::PyNumberIMulProtocol", - }, - MethodProto::Binary { - name: "__imatmul__", - arg: "Other", - proto: "pyo3::class::number::PyNumberIMatmulProtocol", - }, - MethodProto::Binary { - name: "__itruediv__", - arg: "Other", - proto: "pyo3::class::number::PyNumberITruedivProtocol", - }, - MethodProto::Binary { - name: "__ifloordiv__", - arg: "Other", - proto: "pyo3::class::number::PyNumberIFloordivProtocol", - }, - MethodProto::Binary { - name: "__imod__", - arg: "Other", - proto: "pyo3::class::number::PyNumberIModProtocol", - }, - MethodProto::Binary { - name: "__ipow__", - arg: "Other", - proto: "pyo3::class::number::PyNumberIPowProtocol", - }, - MethodProto::Binary { - name: "__ilshift__", - arg: "Other", - proto: "pyo3::class::number::PyNumberILShiftProtocol", - }, - MethodProto::Binary { - name: "__irshift__", - arg: "Other", - proto: "pyo3::class::number::PyNumberIRShiftProtocol", - }, - MethodProto::Binary { - name: "__iand__", - arg: "Other", - proto: "pyo3::class::number::PyNumberIAndProtocol", - }, - MethodProto::Binary { - name: "__ixor__", - arg: "Other", - proto: "pyo3::class::number::PyNumberIXorProtocol", - }, - MethodProto::Binary { - name: "__ior__", - arg: "Other", - proto: "pyo3::class::number::PyNumberIOrProtocol", - }, - MethodProto::Unary { - name: "__neg__", - proto: "pyo3::class::number::PyNumberNegProtocol", - }, - MethodProto::Unary { - name: "__pos__", - proto: "pyo3::class::number::PyNumberPosProtocol", - }, - MethodProto::Unary { - name: "__abs__", - proto: "pyo3::class::number::PyNumberAbsProtocol", - }, - MethodProto::Unary { - name: "__invert__", - proto: "pyo3::class::number::PyNumberInvertProtocol", - }, - MethodProto::Unary { - name: "__complex__", - proto: "pyo3::class::number::PyNumberComplexProtocol", - }, - MethodProto::Unary { - name: "__int__", - proto: "pyo3::class::number::PyNumberIntProtocol", - }, - MethodProto::Unary { - name: "__float__", - proto: "pyo3::class::number::PyNumberFloatProtocol", - }, - MethodProto::Unary { - name: "__index__", - proto: "pyo3::class::number::PyNumberIndexProtocol", - }, - MethodProto::Binary { - name: "__round__", - arg: "NDigits", - proto: "pyo3::class::number::PyNumberRoundProtocol", - }, + MethodProto::new("__add__", "pyo3::class::number::PyNumberAddProtocol") + .args(&["Left", "Right"]), + MethodProto::new("__sub__", "pyo3::class::number::PyNumberSubProtocol") + .args(&["Left", "Right"]), + MethodProto::new("__mul__", "pyo3::class::number::PyNumberMulProtocol") + .args(&["Left", "Right"]), + MethodProto::new("__matmul__", "pyo3::class::number::PyNumberMatmulProtocol") + .args(&["Left", "Right"]), + MethodProto::new( + "__truediv__", + "pyo3::class::number::PyNumberTruedivProtocol", + ) + .args(&["Left", "Right"]), + MethodProto::new( + "__floordiv__", + "pyo3::class::number::PyNumberFloordivProtocol", + ) + .args(&["Left", "Right"]), + MethodProto::new("__mod__", "pyo3::class::number::PyNumberModProtocol") + .args(&["Left", "Right"]), + MethodProto::new("__divmod__", "pyo3::class::number::PyNumberDivmodProtocol") + .args(&["Left", "Right"]), + MethodProto::new("__pow__", "pyo3::class::number::PyNumberPowProtocol") + .args(&["Left", "Right", "Modulo"]), + MethodProto::new("__lshift__", "pyo3::class::number::PyNumberLShiftProtocol") + .args(&["Left", "Right"]), + MethodProto::new("__rshift__", "pyo3::class::number::PyNumberRShiftProtocol") + .args(&["Left", "Right"]), + MethodProto::new("__and__", "pyo3::class::number::PyNumberAndProtocol") + .args(&["Left", "Right"]), + MethodProto::new("__xor__", "pyo3::class::number::PyNumberXorProtocol") + .args(&["Left", "Right"]), + MethodProto::new("__or__", "pyo3::class::number::PyNumberOrProtocol") + .args(&["Left", "Right"]), + MethodProto::new("__radd__", "pyo3::class::number::PyNumberRAddProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new("__rsub__", "pyo3::class::number::PyNumberRSubProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new("__rmul__", "pyo3::class::number::PyNumberRMulProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new( + "__rmatmul__", + "pyo3::class::number::PyNumberRMatmulProtocol", + ) + .args(&["Other"]) + .has_self(), + MethodProto::new( + "__rtruediv__", + "pyo3::class::number::PyNumberRTruedivProtocol", + ) + .args(&["Other"]) + .has_self(), + MethodProto::new( + "__rfloordiv__", + "pyo3::class::number::PyNumberRFloordivProtocol", + ) + .args(&["Other"]) + .has_self(), + MethodProto::new("__rmod__", "pyo3::class::number::PyNumberRModProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new( + "__rdivmod__", + "pyo3::class::number::PyNumberRDivmodProtocol", + ) + .args(&["Other"]) + .has_self(), + MethodProto::new("__rpow__", "pyo3::class::number::PyNumberRPowProtocol") + .args(&["Other", "Modulo"]) + .has_self(), + MethodProto::new( + "__rlshift__", + "pyo3::class::number::PyNumberRLShiftProtocol", + ) + .args(&["Other"]) + .has_self(), + MethodProto::new( + "__rrshift__", + "pyo3::class::number::PyNumberRRShiftProtocol", + ) + .args(&["Other"]) + .has_self(), + MethodProto::new("__rand__", "pyo3::class::number::PyNumberRAndProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new("__rxor__", "pyo3::class::number::PyNumberRXorProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new("__ror__", "pyo3::class::number::PyNumberROrProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new("__iadd__", "pyo3::class::number::PyNumberIAddProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new("__isub__", "pyo3::class::number::PyNumberISubProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new("__imul__", "pyo3::class::number::PyNumberIMulProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new( + "__imatmul__", + "pyo3::class::number::PyNumberIMatmulProtocol", + ) + .args(&["Other"]) + .has_self(), + MethodProto::new( + "__itruediv__", + "pyo3::class::number::PyNumberITruedivProtocol", + ) + .args(&["Other"]) + .has_self(), + MethodProto::new( + "__ifloordiv__", + "pyo3::class::number::PyNumberIFloordivProtocol", + ) + .args(&["Other"]) + .has_self(), + MethodProto::new("__imod__", "pyo3::class::number::PyNumberIModProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new("__ipow__", "pyo3::class::number::PyNumberIPowProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new( + "__ilshift__", + "pyo3::class::number::PyNumberILShiftProtocol", + ) + .args(&["Other"]) + .has_self(), + MethodProto::new( + "__irshift__", + "pyo3::class::number::PyNumberIRShiftProtocol", + ) + .args(&["Other"]) + .has_self(), + MethodProto::new("__iand__", "pyo3::class::number::PyNumberIAndProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new("__ixor__", "pyo3::class::number::PyNumberIXorProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new("__ior__", "pyo3::class::number::PyNumberIOrProtocol") + .args(&["Other"]) + .has_self(), + MethodProto::new("__neg__", "pyo3::class::number::PyNumberNegProtocol").has_self(), + MethodProto::new("__pos__", "pyo3::class::number::PyNumberPosProtocol").has_self(), + MethodProto::new("__abs__", "pyo3::class::number::PyNumberAbsProtocol").has_self(), + MethodProto::new("__invert__", "pyo3::class::number::PyNumberInvertProtocol").has_self(), + MethodProto::new( + "__complex__", + "pyo3::class::number::PyNumberComplexProtocol", + ) + .has_self(), + MethodProto::new("__int__", "pyo3::class::number::PyNumberIntProtocol").has_self(), + MethodProto::new("__float__", "pyo3::class::number::PyNumberFloatProtocol").has_self(), + MethodProto::new("__index__", "pyo3::class::number::PyNumberIndexProtocol").has_self(), + MethodProto::new("__round__", "pyo3::class::number::PyNumberRoundProtocol") + .args(&["NDigits"]) + .has_self(), ], py_methods: &[ PyMethod::coexist("__radd__", "pyo3::class::number::PyNumberRAddProtocolImpl"), diff --git a/pyo3-derive-backend/src/proto_method.rs b/pyo3-derive-backend/src/proto_method.rs index 77a1e291a45..662797069fa 100644 --- a/pyo3-derive-backend/src/proto_method.rs +++ b/pyo3-derive-backend/src/proto_method.rs @@ -1,5 +1,4 @@ // Copyright (c) 2017-present PyO3 Project and Contributors -use crate::utils::print_err; use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens}; use syn::Token; @@ -7,66 +6,38 @@ use syn::Token; // TODO: // Add lifetime support for args with Rptr #[derive(Debug)] -pub enum MethodProto { - Free { - name: &'static str, - proto: &'static str, - }, - Unary { - name: &'static str, - proto: &'static str, - }, - UnaryS { - name: &'static str, - arg: &'static str, - proto: &'static str, - }, - Binary { - name: &'static str, - arg: &'static str, - proto: &'static str, - }, - BinaryS { - name: &'static str, - arg1: &'static str, - arg2: &'static str, - proto: &'static str, - }, - Ternary { - name: &'static str, - arg1: &'static str, - arg2: &'static str, - proto: &'static str, - }, - TernaryS { - name: &'static str, - arg1: &'static str, - arg2: &'static str, - arg3: &'static str, - proto: &'static str, - }, - Quaternary { - name: &'static str, - arg1: &'static str, - arg2: &'static str, - arg3: &'static str, - proto: &'static str, - }, +pub struct MethodProto { + pub name: &'static str, + pub args: &'static [&'static str], + pub proto: &'static str, + pub with_self: bool, + pub with_result: bool, } impl MethodProto { - pub fn name(&self) -> &str { - match *self { - MethodProto::Free { ref name, .. } => name, - MethodProto::Unary { ref name, .. } => name, - MethodProto::UnaryS { ref name, .. } => name, - MethodProto::Binary { ref name, .. } => name, - MethodProto::BinaryS { ref name, .. } => name, - MethodProto::Ternary { ref name, .. } => name, - MethodProto::TernaryS { ref name, .. } => name, - MethodProto::Quaternary { ref name, .. } => name, + // TODO: workaround for no unsized casts in const fn on Rust 1.45 (stable in 1.46) + const EMPTY_ARGS: &'static [&'static str] = &[]; + pub const fn new(name: &'static str, proto: &'static str) -> Self { + MethodProto { + name, + proto, + args: MethodProto::EMPTY_ARGS, + with_self: false, + with_result: true, } } + pub const fn args(mut self, args: &'static [&'static str]) -> MethodProto { + self.args = args; + self + } + pub const fn has_self(mut self) -> MethodProto { + self.with_self = true; + self + } + pub const fn no_result(mut self) -> MethodProto { + self.with_result = false; + self + } } pub(crate) fn impl_method_proto( @@ -74,255 +45,47 @@ pub(crate) fn impl_method_proto( sig: &mut syn::Signature, meth: &MethodProto, ) -> syn::Result { - let ret_ty = match &sig.output { - syn::ReturnType::Default => quote! { () }, - syn::ReturnType::Type(_, ty) => { - let mut ty = ty.clone(); - insert_lifetime(&mut ty); - ty.to_token_stream() - } - }; - - let toks = match *meth { - MethodProto::Free { proto, .. } => { - let p: syn::Path = syn::parse_str(proto).unwrap(); - quote! { - impl<'p> #p<'p> for #cls {} - } - } - MethodProto::Unary { proto, .. } => { - let p: syn::Path = syn::parse_str(proto).unwrap(); - - let tmp: syn::ItemFn = syn::parse_quote! { - fn test(&self) -> <#cls as #p<'p>>::Result {} - }; - sig.output = tmp.sig.output; - modify_self_ty(sig); - - quote! { - impl<'p> #p<'p> for #cls { - type Result = #ret_ty; - } - } - } - MethodProto::UnaryS { proto, arg, .. } => { - let p: syn::Path = syn::parse_str(proto).unwrap(); - - let slf_name = syn::Ident::new(arg, Span::call_site()); - let slf_ty = get_arg_ty(sig, 0)?; - let tmp: syn::ItemFn = syn::parse_quote! { - fn test(&self) -> <#cls as #p<'p>>::Result {} - }; - sig.output = tmp.sig.output; - modify_self_ty(sig); - - if let syn::FnArg::Typed(ref mut arg) = sig.inputs[0] { - arg.ty = Box::new(syn::parse_quote! { - <#cls as #p<'p>>::#slf_name - }); - } - - quote! { - impl<'p> #p<'p> for #cls { - type #slf_name = #slf_ty; - type Result = #ret_ty; - } - } - } - MethodProto::Binary { name, arg, proto } => { - if sig.inputs.len() <= 1 { - println!("Not enough arguments for {}", name); - return Ok(TokenStream::new()); - } - - let p: syn::Path = syn::parse_str(proto).unwrap(); - let arg_name = syn::Ident::new(arg, Span::call_site()); - let arg_ty = get_arg_ty(sig, 1)?; - - let tmp = extract_decl(syn::parse_quote! { - fn test(&self,arg: <#cls as #p<'p>>::#arg_name)-> <#cls as #p<'p>>::Result {} - }); - - let tmp2 = extract_decl(syn::parse_quote! { - fn test(&self, arg: Option<<#cls as #p<'p>>::#arg_name>) -> <#cls as #p<'p>>::Result {} - }); - - modify_arg_ty(sig, 1, &tmp, &tmp2)?; - modify_self_ty(sig); - - quote! { - impl<'p> #p<'p> for #cls { - type #arg_name = #arg_ty; - type Result = #ret_ty; - } - } - } - MethodProto::BinaryS { - name, - arg1, - arg2, - proto, - } => { - if sig.inputs.len() <= 1 { - print_err(format!("Not enough arguments {}", name), quote!(sig)); - return Ok(TokenStream::new()); - } - let p: syn::Path = syn::parse_str(proto).unwrap(); - let arg1_name = syn::Ident::new(arg1, Span::call_site()); - let arg1_ty = get_arg_ty(sig, 0)?; - let arg2_name = syn::Ident::new(arg2, Span::call_site()); - let arg2_ty = get_arg_ty(sig, 1)?; + let p: syn::Path = syn::parse_str(meth.proto).unwrap(); - // rewrite ty - let tmp = extract_decl(syn::parse_quote! {fn test( - arg1: <#cls as #p<'p>>::#arg1_name, - arg2: <#cls as #p<'p>>::#arg2_name) - -> <#cls as #p<'p>>::Result {}}); - let tmp2 = extract_decl(syn::parse_quote! {fn test( - arg1: Option<<#cls as #p<'p>>::#arg1_name>, - arg2: Option<<#cls as #p<'p>>::#arg2_name>) - -> <#cls as #p<'p>>::Result {}}); - modify_arg_ty(sig, 0, &tmp, &tmp2)?; - modify_arg_ty(sig, 1, &tmp, &tmp2)?; + let mut impl_types = Vec::new(); + for (i, arg) in meth.args.iter().enumerate() { + let idx = if meth.with_self { i + 1 } else { i }; + let arg_name = syn::Ident::new(arg, Span::call_site()); + let arg_ty = get_arg_ty(sig, idx)?; - quote! { - impl<'p> #p<'p> for #cls { - type #arg1_name = #arg1_ty; - type #arg2_name = #arg2_ty; - type Result = #ret_ty; - } - } - } - MethodProto::Ternary { - name, - arg1, - arg2, - proto, - } => { - if sig.inputs.len() <= 2 { - print_err(format!("Not enough arguments {}", name), quote!(sig)); - return Ok(TokenStream::new()); - } - let p: syn::Path = syn::parse_str(proto).unwrap(); - let arg1_name = syn::Ident::new(arg1, Span::call_site()); - let arg1_ty = get_arg_ty(sig, 1)?; - let arg2_name = syn::Ident::new(arg2, Span::call_site()); - let arg2_ty = get_arg_ty(sig, 2)?; - - // rewrite ty - let tmp = extract_decl(syn::parse_quote! {fn test( - &self, - arg1: <#cls as #p<'p>>::#arg1_name, - arg2: <#cls as #p<'p>>::#arg2_name) - -> <#cls as #p<'p>>::Result {}}); - let tmp2 = extract_decl(syn::parse_quote! {fn test( - &self, - arg1: Option<<#cls as #p<'p>>::#arg1_name>, - arg2: Option<<#cls as #p<'p>>::#arg2_name>) - -> <#cls as #p<'p>>::Result {}}); - modify_arg_ty(sig, 1, &tmp, &tmp2)?; - modify_arg_ty(sig, 2, &tmp, &tmp2)?; - modify_self_ty(sig); + impl_types.push(quote! {type #arg_name = #arg_ty;}); - quote! { - impl<'p> #p<'p> for #cls { - type #arg1_name = #arg1_ty; - type #arg2_name = #arg2_ty; - type Result = #ret_ty; - } - } - } - MethodProto::TernaryS { - name, - arg1, - arg2, - arg3, - proto, - } => { - if sig.inputs.len() <= 2 { - print_err(format!("Not enough arguments {}", name), quote!(sig)); - return Ok(TokenStream::new()); - } - let p: syn::Path = syn::parse_str(proto).unwrap(); - let arg1_name = syn::Ident::new(arg1, Span::call_site()); - let arg1_ty = get_arg_ty(sig, 0)?; - let arg2_name = syn::Ident::new(arg2, Span::call_site()); - let arg2_ty = get_arg_ty(sig, 1)?; - let arg3_name = syn::Ident::new(arg3, Span::call_site()); - let arg3_ty = get_arg_ty(sig, 2)?; + let type1 = syn::parse_quote! { arg: <#cls as #p<'p>>::#arg_name}; + let type2 = syn::parse_quote! { arg: Option<<#cls as #p<'p>>::#arg_name>}; + modify_arg_ty(sig, idx, &type1, &type2)?; + } - // rewrite ty - let tmp = extract_decl(syn::parse_quote! {fn test( - arg1: <#cls as #p<'p>>::#arg1_name, - arg2: <#cls as #p<'p>>::#arg2_name, - arg3: <#cls as #p<'p>>::#arg3_name) - -> <#cls as #p<'p>>::Result {}}); - let tmp2 = extract_decl(syn::parse_quote! {fn test( - arg1: Option<<#cls as #p<'p>>::#arg1_name>, - arg2: Option<<#cls as #p<'p>>::#arg2_name>, - arg3: Option<<#cls as #p<'p>>::#arg3_name>) - -> <#cls as #p<'p>>::Result {}}); - modify_arg_ty(sig, 0, &tmp, &tmp2)?; - modify_arg_ty(sig, 1, &tmp, &tmp2)?; - modify_arg_ty(sig, 2, &tmp, &tmp2)?; + if meth.with_self { + modify_self_ty(sig); + } - quote! { - impl<'p> #p<'p> for #cls { - type #arg1_name = #arg1_ty; - type #arg2_name = #arg2_ty; - type #arg3_name = #arg3_ty; - type Result = #ret_ty; - } - } - } - MethodProto::Quaternary { - name, - arg1, - arg2, - arg3, - proto, - } => { - if sig.inputs.len() <= 3 { - print_err(format!("Not enough arguments {}", name), quote!(sig)); - return Ok(TokenStream::new()); + let res_type_def = if meth.with_result { + let ret_ty = match &sig.output { + syn::ReturnType::Default => quote! { () }, + syn::ReturnType::Type(_, ty) => { + let mut ty = ty.clone(); + insert_lifetime(&mut ty); + ty.to_token_stream() } - let p: syn::Path = syn::parse_str(proto).unwrap(); - let arg1_name = syn::Ident::new(arg1, Span::call_site()); - let arg1_ty = get_arg_ty(sig, 1)?; - let arg2_name = syn::Ident::new(arg2, Span::call_site()); - let arg2_ty = get_arg_ty(sig, 2)?; - let arg3_name = syn::Ident::new(arg3, Span::call_site()); - let arg3_ty = get_arg_ty(sig, 3)?; + }; - // rewrite ty - let tmp = extract_decl(syn::parse_quote! {fn test( - &self, - arg1: <#cls as #p<'p>>::#arg1_name, - arg2: <#cls as #p<'p>>::#arg2_name, - arg3: <#cls as #p<'p>>::#arg3_name) - -> <#cls as #p<'p>>::Result {}}); - let tmp2 = extract_decl(syn::parse_quote! {fn test( - &self, - arg1: Option<<#cls as #p<'p>>::#arg1_name>, - arg2: Option<<#cls as #p<'p>>::#arg2_name>, - arg3: Option<<#cls as #p<'p>>::#arg3_name>) - -> <#cls as #p<'p>>::Result {}}); - modify_arg_ty(sig, 1, &tmp, &tmp2)?; - modify_arg_ty(sig, 2, &tmp, &tmp2)?; - modify_arg_ty(sig, 3, &tmp, &tmp2)?; - modify_self_ty(sig); + sig.output = syn::parse_quote! { -> <#cls as #p<'p>>::Result }; + quote! { type Result = #ret_ty; } + } else { + proc_macro2::TokenStream::new() + }; - quote! { - impl<'p> #p<'p> for #cls { - type #arg1_name = #arg1_ty; - type #arg2_name = #arg2_ty; - type #arg3_name = #arg3_ty; - type Result = #ret_ty; - } - } + Ok(quote! { + impl<'p> #p<'p> for #cls { + #(#impl_types)* + #res_type_def } - }; - Ok(toks) + }) } /// Some hacks for arguments: get `T` from `Option` and insert lifetime @@ -388,39 +151,23 @@ fn insert_lifetime(ty: &mut syn::Type) { } } -fn extract_decl(spec: syn::Item) -> syn::Signature { - match spec { - syn::Item::Fn(f) => f.sig, - _ => panic!(), - } -} - -// modify method signature fn modify_arg_ty( sig: &mut syn::Signature, idx: usize, - decl1: &syn::Signature, - decl2: &syn::Signature, + decl1: &syn::FnArg, + decl2: &syn::FnArg, ) -> syn::Result<()> { let arg = sig.inputs[idx].clone(); match arg { - syn::FnArg::Typed(ref cap) => match *cap.ty { - syn::Type::Path(ref typath) => { - let seg = typath.path.segments.last().unwrap().clone(); - if seg.ident == "Option" { - sig.inputs[idx] = fix_name(&cap.pat, &decl2.inputs[idx])?; - } else { - sig.inputs[idx] = fix_name(&cap.pat, &decl1.inputs[idx])?; - } - } - _ => { - sig.inputs[idx] = fix_name(&cap.pat, &decl1.inputs[idx])?; - } - }, + syn::FnArg::Typed(ref cap) if crate::utils::option_type_argument(&*cap.ty).is_some() => { + sig.inputs[idx] = fix_name(&cap.pat, &decl2)?; + } + syn::FnArg::Typed(ref cap) => { + sig.inputs[idx] = fix_name(&cap.pat, &decl1)?; + } _ => return Err(syn::Error::new_spanned(arg, "not supported")), } - sig.output = decl1.output.clone(); Ok(()) } diff --git a/pyo3-derive-backend/src/utils.rs b/pyo3-derive-backend/src/utils.rs index e5080483c8f..8178a629b5d 100644 --- a/pyo3-derive-backend/src/utils.rs +++ b/pyo3-derive-backend/src/utils.rs @@ -1,12 +1,7 @@ // Copyright (c) 2017-present PyO3 Project and Contributors use proc_macro2::Span; -use proc_macro2::TokenStream; use std::fmt::Display; -pub fn print_err(msg: String, t: TokenStream) { - println!("Error: {} in '{}'", msg, t.to_string()); -} - /// Check if the given type `ty` is `pyo3::Python`. pub fn is_python(ty: &syn::Type) -> bool { match ty {