Skip to content

Commit e6984c8

Browse files
native_enum: add capsule containing enum information and cleanup logic (#5871)
* native_enum: add capsule containing enum information and cleanup logic
* style: pre-commit fixes
* Updates from code review

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1594396 commit e6984c8

File tree

8 files changed

+174
-52
lines changed

8 files changed

+174
-52
lines changed

include/pybind11/detail/internals.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,24 @@ struct type_info {
382382
bool module_local : 1;
383383
};
384384

385+
/// Information stored in a capsule on py::native_enum() types. Since we don't
386+
/// create a type_info record for native enums, we must store here any
387+
/// information we will need about the enum at runtime.
388+
///
389+
/// If you make backward-incompatible changes to this structure, you must
390+
/// change the `attribute_name()` so that native enums from older versions of
391+
/// pybind11 don't have their records reinterpreted. Better would be to keep
392+
/// the changes backward-compatible (i.e., only add new fields at the end)
393+
/// and detect/indicate their presence using the currently-unused `version`.
394+
struct native_enum_record {
395+
const std::type_info *cpptype;
396+
uint32_t size_bytes;
397+
bool is_signed;
398+
const uint8_t version = 1;
399+
400+
static const char *attribute_name() { return "__pybind11_native_enum__"; }
401+
};
402+
385403
#define PYBIND11_INTERNALS_ID \
386404
"__pybind11_internals_v" PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) \
387405
PYBIND11_COMPILER_TYPE_LEADING_UNDERSCORE PYBIND11_PLATFORM_ABI_ID "__"

include/pybind11/detail/native_enum_data.h

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,36 @@ native_enum_missing_finalize_error_message(const std::string &enum_name_encoded)
2222
return "pybind11::native_enum<...>(\"" + enum_name_encoded + "\", ...): MISSING .finalize()";
2323
}
2424

25+
// Internals for pybind11::native_enum; one native_enum_data object exists
26+
// inside each pybind11::native_enum and lives only for the duration of the
27+
// native_enum binding statement.
2528
class native_enum_data {
2629
public:
27-
native_enum_data(const object &parent_scope,
30+
native_enum_data(handle parent_scope_,
2831
const char *enum_name,
2932
const char *native_type_name,
3033
const char *class_doc,
31-
const std::type_index &enum_type_index)
34+
const native_enum_record &enum_record_)
3235
: enum_name_encoded{enum_name}, native_type_name_encoded{native_type_name},
33-
enum_type_index{enum_type_index}, parent_scope(parent_scope), enum_name{enum_name},
36+
enum_type_index{*enum_record_.cpptype},
37+
parent_scope(reinterpret_borrow<object>(parent_scope_)), enum_name{enum_name},
3438
native_type_name{native_type_name}, class_doc(class_doc), export_values_flag{false},
35-
finalize_needed{false} {}
39+
finalize_needed{false} {
40+
// Create the enum record capsule. It will be installed on the enum
41+
// type object during finalize(). Its destructor removes the enum
42+
// mapping from our internals, so that we won't try to convert to an
43+
// enum type that's been destroyed.
44+
enum_record = capsule(
45+
new native_enum_record{enum_record_},
46+
native_enum_record::attribute_name(),
47+
+[](void *record_) {
48+
auto *record = static_cast<native_enum_record *>(record_);
49+
with_internals([&](internals &internals) {
50+
internals.native_enum_type_map.erase(*record->cpptype);
51+
});
52+
delete record;
53+
});
54+
}
3655

3756
void finalize();
3857

@@ -71,6 +90,7 @@ class native_enum_data {
7190
str enum_name;
7291
str native_type_name;
7392
std::string class_doc;
93+
capsule enum_record;
7494

7595
protected:
7696
list members;
@@ -81,12 +101,6 @@ class native_enum_data {
81101
bool finalize_needed : 1;
82102
};
83103

84-
inline void global_internals_native_enum_type_map_set_item(const std::type_index &enum_type_index,
85-
PyObject *py_enum) {
86-
with_internals(
87-
[&](internals &internals) { internals.native_enum_type_map[enum_type_index] = py_enum; });
88-
}
89-
90104
inline handle
91105
global_internals_native_enum_type_map_get_item(const std::type_index &enum_type_index) {
92106
return with_internals([&](internals &internals) {
@@ -202,7 +216,11 @@ inline void native_enum_data::finalize() {
202216
for (auto doc : member_docs) {
203217
py_enum[doc[int_(0)]].attr("__doc__") = doc[int_(1)];
204218
}
205-
global_internals_native_enum_type_map_set_item(enum_type_index, py_enum.release().ptr());
219+
220+
py_enum.attr(native_enum_record::attribute_name()) = enum_record;
221+
with_internals([&](internals &internals) {
222+
internals.native_enum_type_map[enum_type_index] = py_enum.ptr();
223+
});
206224
}
207225

208226
PYBIND11_NAMESPACE_END(detail)

include/pybind11/native_enum.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ class native_enum : public detail::native_enum_data {
2222
public:
2323
using Underlying = typename std::underlying_type<EnumType>::type;
2424

25-
native_enum(const object &parent_scope,
25+
native_enum(handle parent_scope,
2626
const char *name,
2727
const char *native_type_name,
2828
const char *class_doc = "")
2929
: detail::native_enum_data(
30-
parent_scope, name, native_type_name, class_doc, std::type_index(typeid(EnumType))) {
30+
parent_scope, name, native_type_name, class_doc, make_record()) {
3131
if (detail::get_local_type_info(typeid(EnumType)) != nullptr
3232
|| detail::get_global_type_info(typeid(EnumType)) != nullptr) {
3333
pybind11_fail(
@@ -62,6 +62,15 @@ class native_enum : public detail::native_enum_data {
6262

6363
native_enum(const native_enum &) = delete;
6464
native_enum &operator=(const native_enum &) = delete;
65+
66+
private:
67+
static detail::native_enum_record make_record() {
68+
detail::native_enum_record ret;
69+
ret.cpptype = &typeid(EnumType);
70+
ret.size_bytes = sizeof(EnumType);
71+
ret.is_signed = std::is_signed<Underlying>::value;
72+
return ret;
73+
}
6574
};
6675

6776
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)

tests/conftest.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import sysconfig
1717
import textwrap
1818
import traceback
19+
import weakref
1920
from typing import Callable
2021

2122
import pytest
@@ -208,19 +209,54 @@ def pytest_assertrepr_compare(op, left, right): # noqa: ARG001
208209
return None
209210

210211

212+
# Number of times we think repeatedly collecting garbage might do anything.
213+
# The only reason to do more than one pass is that finalizers executed during
214+
# one GC pass could create garbage that can't be collected until a future one.
215+
# This quickly produces diminishing returns, and GC passes can be slow, so this
216+
# value is a tradeoff between non-flakiness and fast tests. (It errs on the
217+
# side of non-flakiness; many uses of this idiom only do 3 passes.)
218+
num_gc_collect = 5
219+
220+
211221
def gc_collect():
212-
"""Run the garbage collector three times (needed when running
222+
"""Run the garbage collector several times (needed when running
213223
reference counting tests with PyPy)"""
214-
gc.collect()
215-
gc.collect()
216-
gc.collect()
217-
gc.collect()
218-
gc.collect()
224+
for _ in range(num_gc_collect):
225+
gc.collect()
226+
227+
228+
def delattr_and_ensure_destroyed(*specs):
229+
"""For each of the given *specs* (a tuple of the form ``(scope, name)``),
230+
perform ``delattr(scope, name)``, then do enough GC collections that the
231+
deleted reference has actually caused the target to be destroyed. This is
232+
typically used to test what happens when a type object is destroyed; if you
233+
use it for that, you should be aware that extension types, or all types,
234+
are immortal on some Python versions. See ``env.TYPES_ARE_IMMORTAL``.
235+
"""
236+
wrs = []
237+
for mod, name in specs:
238+
wrs.append(weakref.ref(getattr(mod, name)))
239+
delattr(mod, name)
240+
241+
for _ in range(num_gc_collect):
242+
gc.collect()
243+
if all(wr() is None for wr in wrs):
244+
break
245+
else:
246+
# If this fires, most likely something is still holding a reference
247+
# to the object you tried to destroy - for example, it's a type that
248+
# still has some instances alive. Try setting a breakpoint here and
249+
# examining `gc.get_referrers(wrs[0]())`. It's vaguely possible that
250+
# num_gc_collect needs to be increased also.
251+
pytest.fail(
252+
f"Could not delete bindings such as {next(wr for wr in wrs if wr() is not None)!r}"
253+
)
219254

220255

221256
def pytest_configure():
222257
pytest.suppress = contextlib.suppress
223258
pytest.gc_collect = gc_collect
259+
pytest.delattr_and_ensure_destroyed = delattr_and_ensure_destroyed
224260

225261

226262
def pytest_report_header():

tests/env.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@
2424
# Runtime state (what's actually happening now)
2525
sys_is_gil_enabled = getattr(sys, "_is_gil_enabled", lambda: True)
2626

27+
TYPES_ARE_IMMORTAL = (
28+
PYPY
29+
or GRAALPY
30+
or (CPYTHON and PY_GIL_DISABLED and (3, 13) <= sys.version_info < (3, 14))
31+
)
32+
2733

2834
def deprecated_call():
2935
"""

tests/test_class_cross_module_use_after_one_module_dealloc.py

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,16 @@
11
from __future__ import annotations
22

3-
import gc
4-
import sys
5-
import sysconfig
63
import types
7-
import weakref
84

95
import pytest
106

117
import env
128
from pybind11_tests import class_cross_module_use_after_one_module_dealloc as m
139

14-
is_python_3_13_free_threaded = (
15-
env.CPYTHON
16-
and sysconfig.get_config_var("Py_GIL_DISABLED")
17-
and (3, 13) <= sys.version_info < (3, 14)
18-
)
19-
20-
21-
def delattr_and_ensure_destroyed(*specs):
22-
wrs = []
23-
for mod, name in specs:
24-
wrs.append(weakref.ref(getattr(mod, name)))
25-
delattr(mod, name)
2610

27-
for _ in range(5):
28-
gc.collect()
29-
if all(wr() is None for wr in wrs):
30-
break
31-
else:
32-
pytest.fail(
33-
f"Could not delete bindings such as {next(wr for wr in wrs if wr() is not None)!r}"
34-
)
35-
36-
37-
@pytest.mark.skipif("env.PYPY or env.GRAALPY or is_python_3_13_free_threaded")
11+
@pytest.mark.skipif(
12+
env.TYPES_ARE_IMMORTAL, reason="can't GC type objects on this platform"
13+
)
3814
def test_cross_module_use_after_one_module_dealloc():
3915
# This is a regression test for a bug that occurred during development of
4016
# internals::registered_types_cpp_fast (see #5842). registered_types_cpp_fast maps
@@ -58,7 +34,7 @@ def test_cross_module_use_after_one_module_dealloc():
5834
cm.consume_cross_dso_class(instance)
5935

6036
del instance
61-
delattr_and_ensure_destroyed((module_scope, "CrossDSOClass"))
37+
pytest.delattr_and_ensure_destroyed((module_scope, "CrossDSOClass"))
6238

6339
# Make sure that CrossDSOClass gets allocated at a different address.
6440
m.register_unrelated_class(module_scope)

tests/test_native_enum.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,14 @@ TEST_SUBMODULE(native_enum, m) {
9393
.value("blue", color::blue)
9494
.finalize();
9595

96-
py::native_enum<altitude>(m, "altitude", "enum.Enum")
97-
.value("high", altitude::high)
98-
.value("low", altitude::low)
99-
.finalize();
96+
m.def("bind_altitude", [](const py::module_ &mod) {
97+
py::native_enum<altitude>(mod, "altitude", "enum.Enum")
98+
.value("high", altitude::high)
99+
.value("low", altitude::low)
100+
.finalize();
101+
});
102+
m.def("is_high_altitude", [](altitude alt) { return alt == altitude::high; });
103+
m.def("get_altitude", []() -> altitude { return altitude::high; });
100104

101105
py::native_enum<flags_uchar>(m, "flags_uchar", "enum.Flag")
102106
.value("bit0", flags_uchar::bit0)

tests/test_native_enum.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
ENUM_TYPES_AND_MEMBERS = (
6060
(m.smallenum, SMALLENUM_MEMBERS),
6161
(m.color, COLOR_MEMBERS),
62-
(m.altitude, ALTITUDE_MEMBERS),
6362
(m.flags_uchar, FLAGS_UCHAR_MEMBERS),
6463
(m.flags_uint, FLAGS_UINT_MEMBERS),
6564
(m.export_values, EXPORT_VALUES_MEMBERS),
@@ -320,3 +319,59 @@ def test_native_enum_missing_finalize_failure():
320319
if not isinstance(m.native_enum_missing_finalize_failure, str):
321320
m.native_enum_missing_finalize_failure()
322321
pytest.fail("Process termination expected.")
322+
323+
324+
def test_unregister_native_enum_when_destroyed():
325+
# For stability when running tests in parallel, this test should be the
326+
# only one that touches `m.altitude` or calls `m.bind_altitude`.
327+
328+
def test_altitude_enum():
329+
# Logic copied from test_enum_type / test_enum_members.
330+
# We don't test altitude there to avoid possible clashes if
331+
# parallelizing against other tests in this file, and we also
332+
# don't want to hold any references to the enumerators that
333+
# would prevent GCing the enum type below.
334+
assert isinstance(m.altitude, enum.EnumMeta)
335+
assert m.altitude.__module__ == m.__name__
336+
for name, value in ALTITUDE_MEMBERS:
337+
assert m.altitude[name].value == value
338+
339+
def test_altitude_binding():
340+
assert m.is_high_altitude(m.altitude.high)
341+
assert not m.is_high_altitude(m.altitude.low)
342+
assert m.get_altitude() is m.altitude.high
343+
with pytest.raises(TypeError, match="incompatible function arguments"):
344+
m.is_high_altitude("oops")
345+
346+
m.bind_altitude(m)
347+
test_altitude_enum()
348+
test_altitude_binding()
349+
350+
if env.TYPES_ARE_IMMORTAL:
351+
pytest.skip("can't GC type objects on this platform")
352+
353+
# Delete the enum type. Returning an instance from C++ to Python should fail
354+
# rather than accessing a deleted object.
355+
pytest.delattr_and_ensure_destroyed((m, "altitude"))
356+
with pytest.raises(TypeError, match="Unable to convert function return"):
357+
m.get_altitude()
358+
with pytest.raises(TypeError, match="incompatible function arguments"):
359+
m.is_high_altitude("oops")
360+
361+
# Recreate the enum type; should not have any duplicate-binding error
362+
m.bind_altitude(m)
363+
test_altitude_enum()
364+
test_altitude_binding()
365+
366+
# Remove the pybind11 capsule without removing the type; enum is still
367+
# usable but can't be passed to/from bound functions
368+
del m.altitude.__pybind11_native_enum__
369+
pytest.gc_collect()
370+
test_altitude_enum() # enum itself still works
371+
372+
with pytest.raises(TypeError, match="Unable to convert function return"):
373+
m.get_altitude()
374+
with pytest.raises(TypeError, match="incompatible function arguments"):
375+
m.is_high_altitude(m.altitude.high)
376+
377+
del m.altitude

0 commit comments

Comments
 (0)