diff --git a/ddtrace/internal/ipc.py b/ddtrace/internal/ipc.py
index 710336ca2df..f17e08788ca 100644
--- a/ddtrace/internal/ipc.py
+++ b/ddtrace/internal/ipc.py
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 import os
 import secrets
 import tempfile
@@ -99,8 +100,18 @@ def open_file(path, mode):  # type: ignore
 class SharedStringFile:
     """A simple shared-file implementation for multiprocess communication."""
 
-    def __init__(self) -> None:
-        self.filename: typing.Optional[str] = str(TMPDIR / secrets.token_hex(8)) if TMPDIR is not None else None
+    def __init__(self, name: typing.Optional[str] = None) -> None:
+        self.filename: typing.Optional[str] = (
+            str(TMPDIR / (name or secrets.token_hex(8))) if TMPDIR is not None else None
+        )
+        if self.filename is not None:
+            Path(self.filename).touch(exist_ok=True)
+
+    def put_unlocked(self, f: typing.BinaryIO, data: str) -> None:
+        f.seek(0, os.SEEK_END)
+        dt = (data + "\x00").encode()
+        if f.tell() + len(dt) <= MAX_FILE_SIZE:
+            f.write(dt)
 
     def put(self, data: str) -> None:
         """Put a string into the file."""
@@ -108,23 +119,23 @@ def put(self, data: str) -> None:
             return
 
         try:
-            with open_file(self.filename, "ab") as f, WriteLock(f):
-                f.seek(0, os.SEEK_END)
-                dt = (data + "\x00").encode()
-                if f.tell() + len(dt) <= MAX_FILE_SIZE:
-                    f.write(dt)
+            with self.lock_exclusive() as f:
+                self.put_unlocked(f, data)
         except Exception:  # nosec
             pass
 
+    def peekall_unlocked(self, f: typing.BinaryIO) -> typing.List[str]:
+        f.seek(0)
+        return data.decode().split("\x00") if (data := f.read().strip(b"\x00")) else []
+
     def peekall(self) -> typing.List[str]:
         """Peek at all strings from the file."""
         if self.filename is None:
             return []
 
         try:
-            with open_file(self.filename, "r+b") as f, ReadLock(f):
-                f.seek(0)
-                return f.read().strip(b"\x00").decode().split("\x00")
+            with self.lock_shared() as f:
+                return self.peekall_unlocked(f)
         except Exception:  # nosec
             return []
 
@@ -134,13 +145,39 @@ def snatchall(self) -> typing.List[str]:
             return []
 
         try:
-            with open_file(self.filename, "r+b") as f, WriteLock(f):
-                f.seek(0)
-                strings = f.read().strip(b"\x00").decode().split("\x00")
+            with self.lock_exclusive() as f:
+                try:
+                    return self.peekall_unlocked(f)
+                finally:
+                    self.clear_unlocked(f)
+        except Exception:  # nosec
+            return []
 
-                f.seek(0)
-                f.truncate()
+    def clear_unlocked(self, f: typing.BinaryIO) -> None:
+        f.seek(0)
+        f.truncate()
+
+    def clear(self) -> None:
+        """Clear all strings from the file."""
+        if self.filename is None:
+            return
 
-                return strings
+        try:
+            with self.lock_exclusive() as f:
+                self.clear_unlocked(f)
         except Exception:  # nosec
-            return []
+            pass
+
+    @contextmanager
+    def lock_shared(self):
+        """Context manager to acquire a shared/read lock on the file."""
+        with open_file(self.filename, "rb") as f, ReadLock(f):
+            yield f
+
+    @contextmanager
+    def lock_exclusive(self):
+        """Context manager to acquire an exclusive/write lock on the file."""
+        if self.filename is None:
+            return
+        with open_file(self.filename, "r+b") as f, WriteLock(f):
+            yield f
diff --git a/ddtrace/internal/symbol_db/remoteconfig.py b/ddtrace/internal/symbol_db/remoteconfig.py
index 78eca1a087d..d753228ec62 100644
--- a/ddtrace/internal/symbol_db/remoteconfig.py
+++ b/ddtrace/internal/symbol_db/remoteconfig.py
@@ -2,6 +2,7 @@
 import typing as t
 
 from ddtrace.internal.forksafe import has_forked
+from ddtrace.internal.ipc import SharedStringFile
 from ddtrace.internal.logger import get_logger
 from ddtrace.internal.products import manager as product_manager
 from ddtrace.internal.remoteconfig import Payload
@@ -18,20 +19,34 @@
 log = get_logger(__name__)
 
 
+# Use a shared file to keep track of which PIDs have Symbol DB enabled. This way
+# we can ensure that at most two processes are emitting symbols under a large
+# range of scenarios.
+shared_pid_file = SharedStringFile(f"{os.getppid()}-symdb-pids")
+
+MAX_CHILD_UPLOADERS = 1  # max one child
+
 def _rc_callback(data: t.Sequence[Payload]):
-    if get_ancestor_runtime_id() is not None and has_forked():
-        log.debug("[PID %d] SymDB: Disabling Symbol DB in forked process", os.getpid())
-        # We assume that forking is being used for spawning child worker
-        # processes. Therefore, we avoid uploading the same symbols from each
-        # child process. We restrict the enablement of Symbol DB to just the
-        # parent process and the first fork child.
-        remoteconfig_poller.unregister("LIVE_DEBUGGING_SYMBOL_DB")
-
-        if SymbolDatabaseUploader.is_installed():
-            SymbolDatabaseUploader.uninstall()
-
-        return
+    with shared_pid_file.lock_exclusive() as f:
+        if (get_ancestor_runtime_id() is not None and has_forked()) or len(
+            set(shared_pid_file.peekall_unlocked(f))
+        ) >= MAX_CHILD_UPLOADERS:
+            log.debug("[PID %d] SymDB: Disabling Symbol DB in child process", os.getpid())
+            # We assume that forking is being used for spawning child worker
+            # processes. Therefore, we avoid uploading the same symbols from each
+            # child process. We restrict the enablement of Symbol DB to just the
+            # parent process and the first fork child.
+            remoteconfig_poller.unregister("LIVE_DEBUGGING_SYMBOL_DB")
+
+            if SymbolDatabaseUploader.is_installed():
+                SymbolDatabaseUploader.uninstall()
+
+            return
+
+        # Store the PID of the current process so that we know which processes
+        # have Symbol DB enabled.
+        shared_pid_file.put_unlocked(f, str(os.getpid()))
 
     for payload in data:
         if payload.metadata is None:
             continue
diff --git a/tests/internal/symbol_db/test_symbols.py b/tests/internal/symbol_db/test_symbols.py
index 27425debc3f..7d0c68019f1 100644
--- a/tests/internal/symbol_db/test_symbols.py
+++ b/tests/internal/symbol_db/test_symbols.py
@@ -15,6 +15,15 @@
 from ddtrace.internal.symbol_db.symbols import SymbolType
 
 
+@pytest.fixture(autouse=True, scope="function")
+def pid_file_teardown():
+    from ddtrace.internal.symbol_db.remoteconfig import shared_pid_file
+
+    yield
+
+    shared_pid_file.clear()
+
+
 def test_symbol_from_code():
     def foo(a, b, c=None):
         loc = 42
@@ -320,3 +329,39 @@ def test_symbols_fork_uploads():
 
     for pid in pids:
         os.waitpid(pid, 0)
+
+
+@pytest.mark.subprocess(run_module=True, err=None)
+def test_symbols_spawn_uploads():
+    def spawn_target(results):
+        from ddtrace.internal.remoteconfig import ConfigMetadata
+        from ddtrace.internal.remoteconfig import Payload
+        from ddtrace.internal.symbol_db.remoteconfig import _rc_callback
+        from ddtrace.internal.symbol_db.symbols import SymbolDatabaseUploader
+
+        SymbolDatabaseUploader.install()
+
+        rc_data = [Payload(ConfigMetadata("test", "symdb", "hash", 0, 0), "test", None)]
+        _rc_callback(rc_data)
+        results.append(SymbolDatabaseUploader.is_installed())
+
+    if __name__ == "__main__":
+        import multiprocessing
+
+        multiprocessing.freeze_support()
+
+        multiprocessing.set_start_method("spawn", force=True)
+        mc_context = multiprocessing.get_context("spawn")
+        manager = multiprocessing.Manager()
+        returns = manager.list()
+        jobs = []
+
+        for _ in range(10):
+            p = mc_context.Process(target=spawn_target, args=(returns,))
+            p.start()
+            jobs.append(p)
+
+        for p in jobs:
+            p.join()
+
+        assert sum(returns) == 1, returns