Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ set(PYBIND11_HEADERS
include/pybind11/detail/dynamic_raw_ptr_cast_if_possible.h
include/pybind11/detail/exception_translation.h
include/pybind11/detail/function_record_pyobject.h
include/pybind11/detail/function_ref.h
include/pybind11/detail/holder_caster_foreign_helpers.h
include/pybind11/detail/init.h
include/pybind11/detail/internals.h
Expand Down
8 changes: 8 additions & 0 deletions include/pybind11/detail/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,14 @@
# define PYBIND11_NOINLINE __attribute__((noinline)) inline
#endif

#if defined(_MSC_VER)
# define PYBIND11_ALWAYS_INLINE __forceinline
#elif defined(__GNUC__)
# define PYBIND11_ALWAYS_INLINE __attribute__((__always_inline__)) inline
#else
# define PYBIND11_ALWAYS_INLINE inline
#endif

#if defined(__MINGW32__)
// For unknown reasons all PYBIND11_DEPRECATED member trigger a warning when declared
// whether it is used or not
Expand Down
101 changes: 101 additions & 0 deletions include/pybind11/detail/function_ref.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

//===- llvm/ADT/STLFunctionalExtras.h - Extras for <functional> -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a header-only class template that provides functionality
// similar to std::function but with non-owning semantics. It is a template-only
// implementation that requires no additional library linking.
//
//===----------------------------------------------------------------------===//

/// An efficient, type-erasing, non-owning reference to a callable. This is
/// intended for use as the type of a function parameter that is not used
/// after the function in question returns.
///
/// This class does not own the callable, so it is not in general safe to store
/// a FunctionRef.

// pybind11: modified again from executorch::runtime::FunctionRef
// - renamed back to function_ref
// - use pybind11 enable_if_t, remove_cvref_t, and remove_reference_t
// - lint suppressions

// torch::executor: modified from llvm::function_ref
// - renamed to FunctionRef
// - removed LLVM_GSL_POINTER and LLVM_LIFETIME_BOUND macro uses
// - use namespaced internal::remove_cvref_t

#pragma once

#include <pybind11/detail/common.h>

#include <cstdint>
#include <type_traits>
#include <utility>

PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
PYBIND11_NAMESPACE_BEGIN(detail)

//===----------------------------------------------------------------------===//
// Features from C++20
//===----------------------------------------------------------------------===//

template <typename Fn>
class function_ref;

template <typename Ret, typename... Params>
class function_ref<Ret(Params...)> {
Ret (*callback)(intptr_t callable, Params... params) = nullptr;
intptr_t callable;

template <typename Callable>
// NOLINTNEXTLINE(performance-unnecessary-value-param)
static Ret callback_fn(intptr_t callable, Params... params) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return (*reinterpret_cast<Callable *>(callable))(std::forward<Params>(params)...);
}

public:
function_ref() = default;
// NOLINTNEXTLINE(google-explicit-constructor)
function_ref(std::nullptr_t) {}

template <typename Callable>
// NOLINTNEXTLINE(google-explicit-constructor)
function_ref(
Callable &&callable,
// This is not the copy-constructor.
enable_if_t<!std::is_same<remove_cvref_t<Callable>, function_ref>::value> * = nullptr,
// Functor must be callable and return a suitable type.
enable_if_t<
std::is_void<Ret>::value
|| std::is_convertible<decltype(std::declval<Callable>()(std::declval<Params>()...)),
Ret>::value> * = nullptr)
: callback(callback_fn<remove_reference_t<Callable>>),
callable(reinterpret_cast<intptr_t>(&callable)) {}

// NOLINTNEXTLINE(performance-unnecessary-value-param)
Ret operator()(Params... params) const {
return callback(callable, std::forward<Params>(params)...);
}

explicit operator bool() const { return callback; }

bool operator==(const function_ref<Ret(Params...)> &Other) const {
return callable == Other.callable;
}
};
PYBIND11_NAMESPACE_END(detail)
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
161 changes: 110 additions & 51 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "detail/dynamic_raw_ptr_cast_if_possible.h"
#include "detail/exception_translation.h"
#include "detail/function_record_pyobject.h"
#include "detail/function_ref.h"
#include "detail/init.h"
#include "detail/native_enum_data.h"
#include "detail/using_smart_holder.h"
Expand Down Expand Up @@ -379,6 +380,46 @@ class cpp_function : public function {
return unique_function_record(new detail::function_record());
}

private:
// This is outlined from the dispatch lambda in initialize to save
// on code size. Crucially, we use function_ref to type-erase the
// actual function lambda so that we can get code reuse for
// functions with the same Return, Args, and Guard.
template <typename Return, typename Guard, typename ArgsConverter, typename... Args>
static handle call_impl(detail::function_call &call, detail::function_ref<Return(Args...)> f) {
using namespace detail;
// Static assertion: function_ref must be trivially copyable to ensure safe pass-by-value.
// Lifetime safety: The function_ref is created from cap->f which lives in the capture
// object stored in the function record, and is only used synchronously within this
// function call. It is never stored beyond the scope of call_impl.
static_assert(std::is_trivially_copyable<detail::function_ref<Return(Args...)>>::value,
"function_ref must be trivially copyable for safe pass-by-value usage");
using cast_out
= make_caster<conditional_t<std::is_void<Return>::value, void_type, Return>>;

ArgsConverter args_converter;
if (!args_converter.load_args(call)) {
return PYBIND11_TRY_NEXT_OVERLOAD;
}

/* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */
return_value_policy policy
= return_value_policy_override<Return>::policy(call.func.policy);

/* Perform the function call */
handle result;
if (call.func.is_setter) {
(void) std::move(args_converter).template call<Return, Guard>(f);
result = none().release();
} else {
result = cast_out::cast(
std::move(args_converter).template call<Return, Guard>(f), policy, call.parent);
}

return result;
}

protected:
/// Special internal constructor for functors, lambda functions, etc.
template <typename Func, typename Return, typename... Args, typename... Extra>
void initialize(Func &&f, Return (*)(Args...), const Extra &...extra) {
Expand Down Expand Up @@ -441,13 +482,6 @@ class cpp_function : public function {

/* Dispatch code which converts function arguments and performs the actual function call */
rec->impl = [](function_call &call) -> handle {
cast_in args_converter;

/* Try to cast the function arguments into the C++ domain */
if (!args_converter.load_args(call)) {
return PYBIND11_TRY_NEXT_OVERLOAD;
}

/* Invoke call policy pre-call hook */
process_attributes<Extra...>::precall(call);

Expand All @@ -456,24 +490,11 @@ class cpp_function : public function {
: call.func.data[0]);
auto *cap = const_cast<capture *>(reinterpret_cast<const capture *>(data));

/* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */
return_value_policy policy
= return_value_policy_override<Return>::policy(call.func.policy);

/* Function scope guard -- defaults to the compile-to-nothing `void_type` */
using Guard = extract_guard_t<Extra...>;

/* Perform the function call */
handle result;
if (call.func.is_setter) {
(void) std::move(args_converter).template call<Return, Guard>(cap->f);
result = none().release();
} else {
result = cast_out::cast(
std::move(args_converter).template call<Return, Guard>(cap->f),
policy,
call.parent);
}
auto result = call_impl<Return,
/* Function scope guard -- defaults to the compile-to-nothing
`void_type` */
extract_guard_t<Extra...>,
cast_in>(call, detail::function_ref<Return(Args...)>(cap->f));

/* Invoke call policy post-call hook */
process_attributes<Extra...>::postcall(call, result);
Expand Down Expand Up @@ -2218,7 +2239,7 @@ class class_ : public detail::generic_type {
static void add_base(detail::type_record &) {}

template <typename Func, typename... Extra>
class_ &def(const char *name_, Func &&f, const Extra &...extra) {
PYBIND11_ALWAYS_INLINE class_ &def(const char *name_, Func &&f, const Extra &...extra) {
cpp_function cf(method_adaptor<type>(std::forward<Func>(f)),
name(name_),
is_method(*this),
Expand Down Expand Up @@ -2797,38 +2818,13 @@ struct enum_base {
pos_only())

if (is_convertible) {
PYBIND11_ENUM_OP_CONV_LHS("__eq__", !b.is_none() && a.equal(b));
PYBIND11_ENUM_OP_CONV_LHS("__ne__", b.is_none() || !a.equal(b));

if (is_arithmetic) {
PYBIND11_ENUM_OP_CONV("__lt__", a < b);
PYBIND11_ENUM_OP_CONV("__gt__", a > b);
PYBIND11_ENUM_OP_CONV("__le__", a <= b);
PYBIND11_ENUM_OP_CONV("__ge__", a >= b);
PYBIND11_ENUM_OP_CONV("__and__", a & b);
PYBIND11_ENUM_OP_CONV("__rand__", a & b);
PYBIND11_ENUM_OP_CONV("__or__", a | b);
PYBIND11_ENUM_OP_CONV("__ror__", a | b);
PYBIND11_ENUM_OP_CONV("__xor__", a ^ b);
PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b);
m_base.attr("__invert__")
= cpp_function([](const object &arg) { return ~(int_(arg)); },
name("__invert__"),
is_method(m_base),
pos_only());
}
} else {
PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false);
PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)), return true);

if (is_arithmetic) {
#define PYBIND11_THROW throw type_error("Expected an enumeration of matching type!");
PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b), PYBIND11_THROW);
PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b), PYBIND11_THROW);
PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b), PYBIND11_THROW);
PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b), PYBIND11_THROW);
#undef PYBIND11_THROW
}
}

#undef PYBIND11_ENUM_OP_CONV_LHS
Expand Down Expand Up @@ -2944,6 +2940,69 @@ class enum_ : public class_<Type> {

def(init([](Scalar i) { return static_cast<Type>(i); }), arg("value"));
def_property_readonly("value", [](Type value) { return (Scalar) value; }, pos_only());
#define PYBIND11_ENUM_OP_SAME_TYPE(op, expr) \
def(op, [](Type a, Type b) { return expr; }, pybind11::name(op), arg("other"), pos_only())
#define PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE(op, expr) \
def(op, [](Type a, Type *b_ptr) { return expr; }, pybind11::name(op), arg("other"), pos_only())
#define PYBIND11_ENUM_OP_SCALAR(op, op_expr) \
def( \
op, \
[](Type a, Scalar b) { return static_cast<Scalar>(a) op_expr b; }, \
pybind11::name(op), \
arg("other"), \
pos_only())
#define PYBIND11_ENUM_OP_CONV_ARITHMETIC(op, op_expr) \
/* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \
PYBIND11_ENUM_OP_SAME_TYPE(op, static_cast<Scalar>(a) op_expr static_cast<Scalar>(b)); \
PYBIND11_ENUM_OP_SCALAR(op, op_expr)
#define PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE(op, strict_behavior) \
def( \
op, \
[](Type, const object &) { strict_behavior; }, \
pybind11::name(op), \
arg("other"), \
pos_only())
#define PYBIND11_ENUM_OP_STRICT_ARITHMETIC(op, op_expr, strict_behavior) \
/* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \
PYBIND11_ENUM_OP_SAME_TYPE(op, static_cast<Scalar>(a) op_expr static_cast<Scalar>(b)); \
PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE(op, strict_behavior);

PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE("__eq__", b_ptr && a == *b_ptr);
PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE("__ne__", !b_ptr || a != *b_ptr);
if (std::is_convertible<Type, Scalar>::value) {
PYBIND11_ENUM_OP_SCALAR("__eq__", ==);
PYBIND11_ENUM_OP_SCALAR("__ne__", !=);
if (is_arithmetic) {
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__lt__", <);
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__gt__", >);
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__le__", <=);
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__ge__", >=);
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__and__", &);
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__rand__", &);
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__or__", |);
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__ror__", |);
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__xor__", ^);
PYBIND11_ENUM_OP_CONV_ARITHMETIC("__rxor__", ^);
}
} else if (is_arithmetic) {
#define PYBIND11_ENUM_OP_THROW_TYPE_ERROR \
throw type_error("Expected an enumeration of matching type!");
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__lt__", <, PYBIND11_ENUM_OP_THROW_TYPE_ERROR);
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__gt__", >, PYBIND11_ENUM_OP_THROW_TYPE_ERROR);
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__le__", <=, PYBIND11_ENUM_OP_THROW_TYPE_ERROR);
PYBIND11_ENUM_OP_STRICT_ARITHMETIC("__ge__", >=, PYBIND11_ENUM_OP_THROW_TYPE_ERROR);
#undef PYBIND11_ENUM_OP_THROW_TYPE_ERROR
}
PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE("__eq__", return false);
PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE("__ne__", return true);

#undef PYBIND11_ENUM_OP_SAME_TYPE
#undef PYBIND11_ENUM_OP_SAME_TYPE_RHS_MAY_BE_NONE
#undef PYBIND11_ENUM_OP_SCALAR
#undef PYBIND11_ENUM_OP_CONV_ARITHMETIC
#undef PYBIND11_ENUM_OP_REJECT_UNRELATED_TYPE
#undef PYBIND11_ENUM_OP_STRICT_ARITHMETIC

def("__int__", [](Type value) { return (Scalar) value; }, pos_only());
def("__index__", [](Type value) { return (Scalar) value; }, pos_only());
attr("__setstate__") = cpp_function(
Expand Down
1 change: 1 addition & 0 deletions tests/extra_python_package/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
"include/pybind11/detail/descr.h",
"include/pybind11/detail/dynamic_raw_ptr_cast_if_possible.h",
"include/pybind11/detail/function_record_pyobject.h",
"include/pybind11/detail/function_ref.h",
"include/pybind11/detail/holder_caster_foreign_helpers.h",
"include/pybind11/detail/init.h",
"include/pybind11/detail/internals.h",
Expand Down
12 changes: 6 additions & 6 deletions tests/test_copy_move.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def test_move_and_copy_loads():

assert c_m.copy_assignments + c_m.copy_constructions == 0
assert c_m.move_assignments == 6
assert c_m.move_constructions == 9
assert c_m.move_constructions == 21
assert c_mc.copy_assignments + c_mc.copy_constructions == 0
assert c_mc.move_assignments == 5
assert c_mc.move_constructions == 8
assert c_mc.move_constructions == 18
assert c_c.copy_assignments == 4
assert c_c.copy_constructions == 6
assert c_c.copy_constructions == 14
assert c_m.alive() + c_mc.alive() + c_c.alive() == 0


Expand Down Expand Up @@ -103,12 +103,12 @@ def test_move_and_copy_load_optional():

assert c_m.copy_assignments + c_m.copy_constructions == 0
assert c_m.move_assignments == 2
assert c_m.move_constructions == 5
assert c_m.move_constructions == 9
assert c_mc.copy_assignments + c_mc.copy_constructions == 0
assert c_mc.move_assignments == 2
assert c_mc.move_constructions == 5
assert c_mc.move_constructions == 9
assert c_c.copy_assignments == 2
assert c_c.copy_constructions == 5
assert c_c.copy_constructions == 9
assert c_m.alive() + c_mc.alive() + c_c.alive() == 0


Expand Down
Loading
Loading