diff --git a/Lib/profiling/sampling/_control.py b/Lib/profiling/sampling/_control.py new file mode 100644 index 00000000000000..7d8738b1d67b1b --- /dev/null +++ b/Lib/profiling/sampling/_control.py @@ -0,0 +1,238 @@ +"""Control runtime for the sampling profiler.""" + +import os +import selectors +import socket +import warnings + +from .errors import ControlError, ControlURIError + + +class ProfilerControl: + def __init__(self): + self.enabled = True + self.running = True + + +def parse_control_uri(uri, *, allowed_schemes=("unix",)): + if ":" not in uri: + raise ControlURIError("control URI must include a scheme") + + scheme, path = uri.split(":", 1) + if scheme not in allowed_schemes: + expected = ", ".join(allowed_schemes) + raise ControlURIError( + f"unsupported control URI scheme {scheme!r}; " + f"expected one of: {expected}" + ) + if not path: + raise ControlURIError("control URI path must not be empty") + return scheme, path + + +_MAX_OUTBUF_BYTES = 64 * 1024 +_MAX_INBUF_BYTES = 4 * 1024 +_MAX_CONNECTIONS = 8 +_SOCKET_PERMISSIONS = 0o600 + + +class _Connection: + def __init__(self, sock): + self.sock = sock + self.inbuf = bytearray() + self.outbuf = bytearray() + self.close_after_write = False + + +class ControlServer: + def __init__(self, uri): + self.uri = uri + self.control = ProfilerControl() + _, self._path = parse_control_uri(uri) + self.selector = selectors.DefaultSelector() + self._connections = set() + self._listener = None + self._created_stat = None + + def start(self): + self._listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + self._listener.bind(self._path) + os.chmod(self._path, _SOCKET_PERMISSIONS) + self._created_stat = os.lstat(self._path) + self._listener.listen(socket.SOMAXCONN) + self._listener.setblocking(False) + self.selector.register(self._listener, selectors.EVENT_READ, None) + except OSError as exc: + self._close_listener() + raise ControlError( + f"failed to start control socket {self._path!r}: {exc}" + ) from exc 
+ + def stop(self): + for conn in list(self._connections): + self._close_connection(conn) + self.selector.close() + self._close_listener() + + def _close_listener(self): + listener, self._listener = self._listener, None + if listener is not None: + listener.close() + + created_stat, self._created_stat = self._created_stat, None + if created_stat is None: + return + try: + current_stat = os.lstat(self._path) + if (current_stat.st_ino, current_stat.st_dev) == ( + created_stat.st_ino, + created_stat.st_dev, + ): + os.unlink(self._path) + except OSError: + pass + + def poll(self, timeout): + try: + ready = self.selector.select(timeout) + except OSError as exc: + warnings.warn( + f"control selector.select() failed: {exc}", + RuntimeWarning, + stacklevel=2, + ) + return + + for key, events in ready: + if key.fileobj is self._listener: + self._accept_connection() + else: + self._handle_connection(key.data, events) + + def _accept_connection(self): + try: + sock, _addr = self._listener.accept() + except BlockingIOError: + return + except OSError as exc: + warnings.warn( + f"control accept failed: {exc}", + RuntimeWarning, + stacklevel=2, + ) + return + + if len(self._connections) >= _MAX_CONNECTIONS: + sock.close() + return + + try: + sock.setblocking(False) + conn = _Connection(sock) + self.selector.register(sock, selectors.EVENT_READ, conn) + except OSError: + sock.close() + return + + self._connections.add(conn) + + def _handle_connection(self, conn, events): + if events & selectors.EVENT_READ: + self._read_connection(conn) + if conn in self._connections and events & selectors.EVENT_WRITE: + self._flush_connection(conn) + + def _read_connection(self, conn): + try: + chunk = conn.sock.recv(_MAX_INBUF_BYTES) + except (BlockingIOError, InterruptedError): + return + except OSError: + self._close_connection(conn) + return + + if not chunk: + self._close_connection(conn) + return + + conn.inbuf.extend(chunk) + if len(conn.inbuf) > _MAX_INBUF_BYTES: + 
self._close_connection(conn) + return + + while True: + newline = conn.inbuf.find(b"\n") + if newline == -1: + break + raw = conn.inbuf.take_bytes(newline + 1) + line = raw[:-1].decode("utf-8", "replace").strip() + self._dispatch(conn, line) + if conn not in self._connections or conn.close_after_write: + break + + if conn in self._connections: + self._flush_connection(conn) + + def _dispatch(self, conn, command): + match command: + case "enable": + self.control.enabled = True + reply = "ok\n" + case "disable": + self.control.enabled = False + reply = "ok\n" + case "ping": + reply = "ok\n" + case "status": + reply = f"ok enabled={self.control.enabled}\n" + case "quit": + self.control.running = False + conn.close_after_write = True + reply = "ok\n" + case _: + reply = "err unknown_command\n" + + conn.outbuf.extend(reply.encode("ascii")) + if len(conn.outbuf) > _MAX_OUTBUF_BYTES: + self._close_connection(conn) + + def _flush_connection(self, conn): + while conn.outbuf: + try: + sent = conn.sock.send(conn.outbuf) + except (BlockingIOError, InterruptedError): + break + except OSError: + self._close_connection(conn) + return + + if sent == 0: + self._close_connection(conn) + return + + del conn.outbuf[:sent] + + if not conn.outbuf and conn.close_after_write: + self._close_connection(conn) + return + + events = selectors.EVENT_READ + if conn.outbuf: + events |= selectors.EVENT_WRITE + try: + self.selector.modify(conn.sock, events, conn) + except (KeyError, OSError): + self._close_connection(conn) + + def _close_connection(self, conn): + if conn not in self._connections: + return + self._connections.discard(conn) + + try: + self.selector.unregister(conn.sock) + except (KeyError, OSError): + pass + + conn.sock.close() diff --git a/Lib/profiling/sampling/cli.py b/Lib/profiling/sampling/cli.py index a5d9573ae6b6dd..0b7f24ebbe113b 100644 --- a/Lib/profiling/sampling/cli.py +++ b/Lib/profiling/sampling/cli.py @@ -11,9 +11,15 @@ import sys import time import webbrowser -from 
def _add_control_options(parser):
    """Add the control socket option (``--control URI``) to a parser."""
    option_settings = {
        "default": None,
        "metavar": "URI",
        "help": "control socket URI (unix:)",
    }
    group = parser.add_argument_group("Control socket")
    group.add_argument("--control", **option_settings)
-1064,6 +1095,7 @@ def _main(): help="Interactive TUI profiler (top-like interface, press 'q' to quit, 's' to cycle sort)", ) _add_sampling_options(attach_parser) + _add_control_options(attach_parser) _add_mode_options(attach_parser) _add_format_options(attach_parser) _add_pstats_options(attach_parser) @@ -1168,7 +1200,16 @@ def _handle_attach(args): diff_baseline=args.diff_baseline ) - with _get_child_monitor_context(args, args.pid): + server = None + with ExitStack() as stack: + if args.control: + server = ControlServer(args.control) + try: + server.start() + except ControlError as exc: + sys.exit(f"Error: {exc}") + stack.callback(server.stop) + stack.enter_context(_get_child_monitor_context(args, args.pid)) collector = sample( args.pid, collector, @@ -1181,6 +1222,7 @@ def _handle_attach(args): gc=args.gc, opcodes=args.opcodes, blocking=args.blocking, + control_server=server, ) _handle_output(collector, args, args.pid, mode) @@ -1275,23 +1317,11 @@ def _handle_run(args): diff_baseline=args.diff_baseline ) - with _get_child_monitor_context(args, process.pid): - try: - collector = sample( - process.pid, - collector, - duration_sec=args.duration, - all_threads=args.all_threads, - realtime_stats=args.realtime_stats, - mode=mode, - async_aware=args.async_mode if args.async_aware else None, - native=args.native, - gc=args.gc, - opcodes=args.opcodes, - blocking=args.blocking, - ) - _handle_output(collector, args, process.pid, mode) - finally: + server = None + with ExitStack() as stack: + stack.enter_context(_get_child_monitor_context(args, process.pid)) + + def _terminate_main_subprocess(): # Terminate the main subprocess - child profilers finish when their # target processes exit if process.poll() is None: @@ -1301,6 +1331,31 @@ def _handle_run(args): except subprocess.TimeoutExpired: process.kill() process.wait() + stack.callback(_terminate_main_subprocess) + + if args.control: + server = ControlServer(args.control) + try: + server.start() + except ControlError as 
exc: + sys.exit(f"Error: {exc}") + stack.callback(server.stop) + + collector = sample( + process.pid, + collector, + duration_sec=args.duration, + all_threads=args.all_threads, + realtime_stats=args.realtime_stats, + mode=mode, + async_aware=args.async_mode if args.async_aware else None, + native=args.native, + gc=args.gc, + opcodes=args.opcodes, + blocking=args.blocking, + control_server=server, + ) + _handle_output(collector, args, process.pid, mode) def _handle_live_attach(args, pid): @@ -1323,19 +1378,29 @@ def _handle_live_attach(args, pid): ) # Sample in live mode - sample_live( - pid, - collector, - duration_sec=args.duration, - all_threads=args.all_threads, - realtime_stats=args.realtime_stats, - mode=mode, - async_aware=args.async_mode if args.async_aware else None, - native=args.native, - gc=args.gc, - opcodes=args.opcodes, - blocking=args.blocking, - ) + server = None + with ExitStack() as stack: + if args.control: + server = ControlServer(args.control) + try: + server.start() + except ControlError as exc: + sys.exit(f"Error: {exc}") + stack.callback(server.stop) + sample_live( + pid, + collector, + duration_sec=args.duration, + all_threads=args.all_threads, + realtime_stats=args.realtime_stats, + mode=mode, + async_aware=args.async_mode if args.async_aware else None, + native=args.native, + gc=args.gc, + opcodes=args.opcodes, + blocking=args.blocking, + control_server=server, + ) def _handle_live_run(args): @@ -1370,20 +1435,30 @@ def _handle_live_run(args): ) # Profile the subprocess in live mode + server = None try: - sample_live( - process.pid, - collector, - duration_sec=args.duration, - all_threads=args.all_threads, - realtime_stats=args.realtime_stats, - mode=mode, - async_aware=args.async_mode if args.async_aware else None, - native=args.native, - gc=args.gc, - opcodes=args.opcodes, - blocking=args.blocking, - ) + with ExitStack() as stack: + if args.control: + server = ControlServer(args.control) + try: + server.start() + except ControlError as 
exc: + sys.exit(f"Error: {exc}") + stack.callback(server.stop) + sample_live( + process.pid, + collector, + duration_sec=args.duration, + all_threads=args.all_threads, + realtime_stats=args.realtime_stats, + mode=mode, + async_aware=args.async_mode if args.async_aware else None, + native=args.native, + gc=args.gc, + opcodes=args.opcodes, + blocking=args.blocking, + control_server=server, + ) finally: # Clean up the subprocess and get any error output returncode = process.poll() diff --git a/Lib/profiling/sampling/errors.py b/Lib/profiling/sampling/errors.py index 0832ad2d4381e0..6d6fc4c432ce13 100644 --- a/Lib/profiling/sampling/errors.py +++ b/Lib/profiling/sampling/errors.py @@ -17,3 +17,9 @@ class SamplingModuleNotFoundError(SamplingProfilerError): def __init__(self, module_name): self.module_name = module_name super().__init__(f"Module '{module_name}' not found.") + +class ControlError(SamplingProfilerError): + """Base exception for profiler control channel errors.""" + +class ControlURIError(ControlError): + """Raised when a control URI is malformed or has an unsupported scheme.""" diff --git a/Lib/profiling/sampling/sample.py b/Lib/profiling/sampling/sample.py index b9e7e2625d09e4..840ea97aeee031 100644 --- a/Lib/profiling/sampling/sample.py +++ b/Lib/profiling/sampling/sample.py @@ -102,21 +102,30 @@ def dump_stack(self, *, async_aware=None): """Return a single stack snapshot from the target process.""" return self._get_stack_trace(async_aware=async_aware) - def sample(self, collector, duration_sec=None, *, async_aware=False): + def sample( + self, + collector, + duration_sec=None, + *, + async_aware=False, + control_server=None, + ): sample_interval_sec = self.sample_interval_usec / 1_000_000 num_samples = 0 errors = 0 interrupted = False running_time_sec = 0 - start_time = next_time = time.perf_counter() - last_sample_time = start_time + enabled_time_sec = 0 realtime_update_interval = 1.0 # Update every second - last_realtime_update = start_time + 
start_time = next_time = time.perf_counter() + enabled_since = last_sample_time = last_realtime_update = next_control_poll = start_time aggregating = getattr(collector, 'aggregating', False) is True prev_stack = None pending_count = 0 pending_timestamps = [] if aggregating else None + control = control_server.control if control_server is not None else None + def flush_pending(): nonlocal pending_count, pending_timestamps if pending_count == 0: @@ -128,11 +137,29 @@ def flush_pending(): try: while duration_sec is None or running_time_sec < duration_sec: - # Check if live collector wants to stop - if hasattr(collector, 'running') and not collector.running: + current_time = time.perf_counter() + if control_server is not None and current_time >= next_control_poll: + control_server.poll(timeout=0) + next_control_poll = current_time + 0.001 + # Check if live collector or runtime control wants to stop + if not getattr(control, 'running', True) or not getattr(collector, 'running', True): break + enabled = getattr(control, 'enabled', True) + + if not enabled: + if enabled_since is not None: + enabled_time_sec += current_time - enabled_since + enabled_since = None + time.sleep(sample_interval_sec) + running_time_sec = time.perf_counter() - start_time + continue + + if enabled_since is None: + enabled_since = current_time + next_time = current_time + sample_interval_sec + last_sample_time = current_time + last_realtime_update = current_time - current_time = time.perf_counter() current_time_us = int(current_time * 1_000_000) if next_time > current_time: sleep_time = (next_time - current_time) * 0.9 @@ -190,18 +217,27 @@ def flush_pending(): running_time_sec = time.perf_counter() - start_time except KeyboardInterrupt: interrupted = True - running_time_sec = time.perf_counter() - start_time + now = time.perf_counter() + running_time_sec = now - start_time + if enabled_since is not None: + enabled_time_sec += now - enabled_since print("Interrupted by user.") finally: 
flush_pending() + if not interrupted: + final_time = time.perf_counter() + if enabled_since is not None: + enabled_time_sec += final_time - enabled_since + running_time_sec = final_time - start_time + # Clear real-time stats line if it was being displayed if self.realtime_stats and len(self.sample_intervals) > 0: print() # Add newline after real-time stats - sample_rate = num_samples / running_time_sec if running_time_sec > 0 else 0 + sample_rate = num_samples / enabled_time_sec if enabled_time_sec > 0 else 0 error_rate = (errors / num_samples) * 100 if num_samples > 0 else 0 - expected_samples = int(running_time_sec / sample_interval_sec) + expected_samples = int(enabled_time_sec / sample_interval_sec) missed_samples = (expected_samples - num_samples) / expected_samples * 100 if expected_samples > 0 else 0 # Don't print stats for live mode (curses is handling display) @@ -427,6 +463,7 @@ def sample( gc=True, opcodes=False, blocking=False, + control_server=None, ): """Sample a process using the provided collector. @@ -474,7 +511,12 @@ def sample( profiler.realtime_stats = realtime_stats # Run the sampling - profiler.sample(collector, duration_sec, async_aware=async_aware) + profiler.sample( + collector, + duration_sec, + async_aware=async_aware, + control_server=control_server, + ) return collector @@ -523,6 +565,7 @@ def sample_live( gc=True, opcodes=False, blocking=False, + control_server=None, ): """Sample a process in live/interactive mode with curses TUI. 
@@ -545,6 +588,8 @@ def sample_live( """ import curses + control = control_server.control if control_server is not None else None + # Check if process is alive before doing any heavy initialization if not _is_process_running(pid): print(f"No samples collected - process {pid} exited before profiling could begin.", file=sys.stderr) @@ -576,7 +621,12 @@ def sample_live( def curses_wrapper_func(stdscr): collector.init_curses(stdscr) try: - profiler.sample(collector, duration_sec, async_aware=async_aware) + profiler.sample( + collector, + duration_sec, + async_aware=async_aware, + control_server=control_server, + ) # If too few samples were collected, exit cleanly without showing TUI if collector.successful_samples < MIN_SAMPLES_FOR_TUI: # Clear screen before exiting to avoid visual artifacts @@ -586,7 +636,7 @@ def curses_wrapper_func(stdscr): # Mark as finished and keep the TUI running until user presses 'q' collector.mark_finished() # Keep processing input until user quits - while collector.running: + while collector.running and (control is None or control.running): collector._handle_input() time.sleep(0.05) # Small sleep to avoid busy waiting finally: diff --git a/Lib/test/test_profiling/test_sampling_profiler/test_control.py b/Lib/test/test_profiling/test_sampling_profiler/test_control.py new file mode 100644 index 00000000000000..5ca2715a313b30 --- /dev/null +++ b/Lib/test/test_profiling/test_sampling_profiler/test_control.py @@ -0,0 +1,230 @@ +"""Tests for the sampling profiler control socket.""" + +import io +import os +import socket +import time +import unittest +from unittest import mock + +from test.support import SHORT_TIMEOUT, os_helper, socket_helper + +try: + from profiling.sampling._control import ( + ControlServer, + _MAX_INBUF_BYTES, + parse_control_uri, + ) + from profiling.sampling.cli import LiveStatsCollector, main + from profiling.sampling.errors import ControlError, ControlURIError +except ImportError: + raise unittest.SkipTest( + "Test only runs 
@socket_helper.skip_unless_bind_unix_socket
class ControlServerTests(unittest.TestCase):
    """Tests for ControlServer protocol, lifecycle and CLI integration."""

    def setUp(self):
        self.path = socket_helper.create_unix_domain_name()
        self.addCleanup(os_helper.unlink, self.path)

    def start_server(self):
        """Return a started ControlServer that is stopped on cleanup."""
        server = ControlServer(f"unix:{self.path}")
        # Register the cleanup first so the selector is released even if
        # start() raises.
        self.addCleanup(server.stop)
        server.start()
        return server

    def connect(self):
        """Return a non-blocking client connected to the control socket."""
        client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        # Register the cleanup first so a failing connect() does not leak
        # the socket.
        self.addCleanup(client.close)
        client.connect(self.path)
        client.setblocking(False)
        return client

    def request(self, server, client, command):
        """Send *command* and poll the server until a reply can be read."""
        client.sendall(command)
        deadline = time.monotonic() + SHORT_TIMEOUT
        while time.monotonic() < deadline:
            server.poll(timeout=0.05)
            try:
                return client.recv(4096)
            except BlockingIOError:
                pass
        self.fail("timed out waiting for control reply")

    def test_parse_control_uri_valid_unix(self):
        """parse_control_uri returns (scheme, path) for unix: URIs."""
        self.assertEqual(
            parse_control_uri("unix:/tmp/tachyon.sock"),
            ("unix", "/tmp/tachyon.sock"),
        )

    def test_parse_control_uri_rejects_invalids(self):
        """parse_control_uri raises ControlURIError on malformed URIs."""
        cases = [
            ("/tmp/x", "must include a scheme"),
            ("unix:", "path must not be empty"),
            ("fifo:/tmp/x", "unsupported control URI scheme"),
            ("", "must include a scheme"),
        ]
        for uri, expected in cases:
            with self.subTest(uri=uri):
                with self.assertRaisesRegex(ControlURIError, expected):
                    parse_control_uri(uri)

    def test_start_creates_and_stop_unlinks_socket(self):
        """start() binds the socket on disk; stop() removes it."""
        server = ControlServer(f"unix:{self.path}")
        server.start()
        try:
            self.assertTrue(os.path.exists(self.path))
        finally:
            server.stop()
        self.assertFalse(os.path.exists(self.path))

    def test_start_fails_on_occupied_path(self):
        """start() raises ControlError when the path is already bound."""
        squatter = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.addCleanup(squatter.close)
        socket_helper.bind_unix_socket(squatter, self.path)
        server = ControlServer(f"unix:{self.path}")
        # The failed server still owns a selector; release it on cleanup.
        self.addCleanup(server.stop)
        with self.assertRaisesRegex(ControlError, "failed to start"):
            server.start()

    def test_dispatch_basic_commands(self):
        """enable/disable/ping/unknown produce the documented replies."""
        server = self.start_server()
        client = self.connect()
        cases = [
            (b"enable\n", b"ok\n", True),
            (b"disable\n", b"ok\n", False),
            (b"ping\n", b"ok\n", False),
            (b"pingu nut nut\n", b"err unknown_command\n", False),
        ]
        for command, reply, expected_enabled in cases:
            with self.subTest(command=command):
                self.assertEqual(self.request(server, client, command), reply)
                if command in (b"enable\n", b"disable\n"):
                    self.assertEqual(server.control.enabled, expected_enabled)

    def test_dispatch_status_format(self):
        """status reply exposes the enabled flag."""
        server = self.start_server()
        client = self.connect()
        self.assertEqual(
            self.request(server, client, b"status\n"),
            b"ok enabled=True\n",
        )

    def test_dispatch_quit_sets_running_and_closes(self):
        """quit replies ok, sets running=False, and closes the connection."""
        server = self.start_server()
        client = self.connect()
        self.assertEqual(self.request(server, client, b"quit\n"), b"ok\n")
        self.assertFalse(server.control.running)
        deadline = time.monotonic() + SHORT_TIMEOUT
        while time.monotonic() < deadline:
            server.poll(timeout=0.05)
            try:
                chunk = client.recv(4096)
            except BlockingIOError:
                continue
            self.assertEqual(chunk, b"")
            return
        self.fail("server did not close the connection after quit")

    def test_inbuf_overflow_drops_connection(self):
        """A client filling the input buffer without a newline is dropped."""
        server = self.start_server()
        # Deliberately a blocking client (unlike connect()) so that
        # sendall() pushes the whole payload.
        client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.addCleanup(client.close)
        client.connect(self.path)
        client.sendall(b"x" * (_MAX_INBUF_BYTES + 1))
        deadline = time.monotonic() + SHORT_TIMEOUT
        while time.monotonic() < deadline:
            server.poll(timeout=0.05)
            if not server._connections:
                return
        self.fail("server did not drop client over the inbuf cap")

    def test_close_listener_preserves_replaced_path(self):
        """stop() refuses to unlink a different file at the same path."""
        server = self.start_server()
        os.unlink(self.path)
        replacement = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.addCleanup(replacement.close)
        socket_helper.bind_unix_socket(replacement, self.path)
        server.stop()
        self.assertTrue(os.path.exists(self.path))

    def test_cli_rejects_control_with_subprocesses(self):
        """--control and --subprocesses are mutually exclusive."""
        argv = [
            "profiling.sampling.cli",
            "attach",
            "--subprocesses",
            "--control",
            "unix:/tmp/x.sock",
            "123",
        ]
        with (
            mock.patch("sys.argv", argv),
            mock.patch("sys.stderr", io.StringIO()) as stderr,
            self.assertRaises(SystemExit) as cm,
        ):
            main()
        self.assertEqual(cm.exception.code, 2)
        self.assertIn(
            "--control is incompatible with --subprocesses", stderr.getvalue()
        )

    def test_cli_rejects_bad_uri(self):
        """An unsupported control URI scheme is rejected during validation."""
        argv = [
            "profiling.sampling.cli",
            "attach",
            "--control",
            "fifo:/tmp/x",
            "123",
        ]
        with (
            mock.patch("sys.argv", argv),
            mock.patch("sys.stderr", io.StringIO()) as stderr,
            self.assertRaises(SystemExit) as cm,
        ):
            main()
        self.assertEqual(cm.exception.code, 2)
        self.assertIn("unsupported control URI scheme", stderr.getvalue())

    @unittest.skipUnless(LiveStatsCollector is not None,
                         "requires curses for --live")
    def test_cli_accepts_control_with_live(self):
        """--control and --live coexist after the mutex was dropped."""
        argv = [
            "profiling.sampling.cli",
            "attach",
            "--live",
            "--control",
            "unix:/tmp/ignored.sock",
            "123",
        ]
        with (
            mock.patch("sys.argv", argv),
            mock.patch("sys.stderr", io.StringIO()) as stderr,
            mock.patch(
                "profiling.sampling.cli._is_process_running", return_value=True
            ),
            mock.patch("profiling.sampling.cli._handle_live_attach") as live,
        ):
            main()
        self.assertNotIn("incompatible with --live", stderr.getvalue())
        live.assert_called_once()
        (forwarded_args, forwarded_pid), _ = live.call_args
        self.assertEqual(forwarded_args.control, "unix:/tmp/ignored.sock")
        self.assertTrue(forwarded_args.live)
@requires_remote_subprocess_debugging()
@socket_helper.skip_unless_bind_unix_socket
@unittest.skipIf(
    sys.platform == "darwin" and os.geteuid() != 0,
    "macOS profiling requires elevated permissions",
)
class TestControlSocketIntegration(unittest.TestCase):
    """End-to-end tests for the --control socket via in-process sample()."""

    def _send_recv(self, client, request, expected_reply):
        """Send *request* and assert on the bytes returned by one recv()."""
        client.sendall(request)
        self.assertEqual(client.recv(4096), expected_reply)

    def test_control_socket_disable_enable_quit_cycle(self):
        """Drive disable/enable/quit through a real ControlServer."""
        script = '''
import time
_test_sock.sendall(b"working")
for _ in range(200):
    sum(i * i for i in range(50_000))
    time.sleep(0.05)
'''
        with test_subprocess(script, wait_for_working=True) as target:
            socket_path = socket_helper.create_unix_domain_name()
            self.addCleanup(os_helper.unlink, socket_path)

            server = ControlServer(f"unix:{socket_path}")
            # Register the cleanup first so a failing start() cannot leak
            # the server's selector.
            self.addCleanup(server.stop)
            server.start()

            exception = [None]

            collector = PstatsCollector(sample_interval_usec=1000, skip_idle=False)

            def sample_worker():
                try:
                    with contextlib.redirect_stdout(io.StringIO()):
                        profiling.sampling.sample.sample(
                            target.process.pid,
                            collector,
                            duration_sec=SHORT_TIMEOUT,
                            control_server=server,
                        )
                except Exception as exc:
                    exception[0] = exc

            thread = threading.Thread(target=sample_worker, daemon=True)
            thread.start()

            try:
                with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
                    client.settimeout(SHORT_TIMEOUT)
                    client.connect(socket_path)
                    self._send_recv(client, b"disable\n", b"ok\n")
                    time.sleep(0.2)
                    self._send_recv(client, b"enable\n", b"ok\n")
                    self._send_recv(client, b"quit\n", b"ok\n")
                # The reply to quit is queued only after the flag is cleared.
                self.assertFalse(
                    server.control.running,
                    "quit did not clear the running flag",
                )
            finally:
                # If anything above failed before quit was processed, still
                # release the sampler so the thread does not keep polling
                # the server while cleanups run.  After a successful quit
                # this assignment is a no-op.
                server.control.running = False
                thread.join(timeout=SHORT_TIMEOUT)
                # A failure inside the sampler is the root cause; surface it
                # in preference to a secondary client-side timeout.
                if exception[0] is not None:
                    raise exception[0]

            self.assertFalse(thread.is_alive(), "sample() did not exit on quit")
all_threads=False ) mock_collector = mock.MagicMock() - times = [0.0, 0.01, 0.02, 0.03, 0.04] + times = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06] with mock.patch("time.perf_counter", side_effect=times): with io.StringIO() as output: with mock.patch("sys.stdout", output):
@@ -429,6 +432,138 @@ def test_sample_profiler_keyboard_interrupt(self):
         self.assertIn("samples", result)
         self.assertNotIn("Warning: missed", result)
 
+    @staticmethod
+    def _fake_control_server(*, enabled=True, running=True):
+        """Build a minimal control_server stub for sample() integration."""
+        control = types.SimpleNamespace(enabled=enabled, running=running)
+        server = mock.MagicMock()
+        server.control = control
+        return server
+
+    def test_sample_polls_control_server_periodically(self):
+        """Test that sample() drives control_server.poll() during the loop."""
+        control_server = self._fake_control_server()
+        mock_collector = mock.MagicMock()
+        times = [1000.0 + i * 0.01 for i in range(60)]
+        with self._patched_unwinder() as u:
+            u.instance.get_stack_trace.return_value = []
+            profiler = SampleProfiler(
+                pid=12345, sample_interval_usec=10000, all_threads=False
+            )
+            with mock.patch("time.perf_counter", side_effect=times):
+                with io.StringIO() as output:
+                    with mock.patch("sys.stdout", output):
+                        profiler.sample(
+                            mock_collector,
+                            duration_sec=0.2,
+                            control_server=control_server,
+                        )
+        control_server.poll.assert_called()
+
+    def test_sample_breaks_when_control_running_false(self):
+        """Test that sample() exits without sampling when not running."""
+        control_server = self._fake_control_server(running=False)
+        mock_collector = mock.MagicMock()
+        with self._patched_unwinder() as u:
+            u.instance.get_stack_trace.return_value = []
+            profiler = SampleProfiler(
+                pid=12345, sample_interval_usec=10000, all_threads=False
+            )
+            with io.StringIO() as output:
+                with mock.patch("sys.stdout", output):
+                    profiler.sample(
+                        mock_collector,
+                        duration_sec=60,
+                        control_server=control_server,
+                    )
+        mock_collector.collect.assert_not_called()
+
+    def test_sample_pauses_when_control_disabled(self):
+        """Test that disabled state stops sample collection."""
+        control_server = self._fake_control_server(enabled=False)
+        mock_collector = mock.MagicMock()
+        times = [1000.0 + i * 0.01 for i in range(60)]
+        with self._patched_unwinder() as u:
+            u.instance.get_stack_trace.return_value = []
+            profiler = SampleProfiler(
+                pid=12345, sample_interval_usec=10000, all_threads=False
+            )
+            with mock.patch("time.perf_counter", side_effect=times):
+                with io.StringIO() as output:
+                    with mock.patch("sys.stdout", output):
+                        profiler.sample(
+                            mock_collector,
+                            duration_sec=0.2,
+                            control_server=control_server,
+                        )
+        mock_collector.collect.assert_not_called()
+
+    def test_sample_resumes_after_re_enable(self):
+        """Test that sampling resumes once control is re-enabled."""
+        control_server = self._fake_control_server(enabled=False)
+        control = control_server.control
+
+        # Start disabled; re-enable once poll() has been called three times.
+        def poll_side_effect(timeout=0):
+            if control_server.poll.call_count >= 3:
+                control.enabled = True
+        control_server.poll.side_effect = poll_side_effect
+
+        mock_collector = mock.MagicMock()
+        times = [1000.0 + i * 0.01 for i in range(80)]
+        with self._patched_unwinder() as u:
+            u.instance.get_stack_trace.return_value = [
+                (1, [mock.MagicMock(filename="t.py", lineno=1, funcname="f")])
+            ]
+            profiler = SampleProfiler(
+                pid=12345, sample_interval_usec=10000, all_threads=False
+            )
+            with mock.patch("time.perf_counter", side_effect=times):
+                with io.StringIO() as output:
+                    with mock.patch("sys.stdout", output):
+                        profiler.sample(
+                            mock_collector,
+                            duration_sec=0.3,
+                            control_server=control_server,
+                        )
+        self.assertGreater(mock_collector.collect.call_count, 0)
+
+    def test_sample_rate_reflects_enabled_time(self):
+        """Test that Sample rate divides by enabled time, not wall time."""
+        control_server = self._fake_control_server(enabled=False)
+        control = control_server.control
+
+        # Stay disabled until the 10th poll() so part of the run is paused.
+        def poll_side_effect(timeout=0):
+            if control_server.poll.call_count >= 10:
+                control.enabled = True
+        control_server.poll.side_effect = poll_side_effect
+
+        mock_collector = mock.MagicMock()
+        times = [1000.0 + i * 0.01 for i in range(120)]
+        with self._patched_unwinder() as u:
+            u.instance.get_stack_trace.return_value = [
+                (1, [mock.MagicMock(filename="t.py", lineno=1, funcname="f")])
+            ]
+            profiler = SampleProfiler(
+                pid=12345, sample_interval_usec=10000, all_threads=False
+            )
+            with mock.patch("time.perf_counter", side_effect=times):
+                with io.StringIO() as output:
+                    with mock.patch("sys.stdout", output):
+                        profiler.sample(
+                            mock_collector,
+                            duration_sec=0.5,
+                            control_server=control_server,
+                        )
+        mock_collector.set_stats.assert_called_once()
+        args = mock_collector.set_stats.call_args.args
+        running_time_sec, sample_rate = args[1], args[2]
+        samples = mock_collector.collect.call_count
+        # Require real samples first so the rate comparison is meaningful.
+        self.assertGreater(samples, 0)
+        self.assertGreater(sample_rate, samples / running_time_sec)
+
 
 @force_not_colorized_test_class
 class TestPrintSampledStats(unittest.TestCase):
diff --git a/Misc/NEWS.d/next/Library/2026-05-17-21-48-02.gh-issue-149958.6rG3lT.rst b/Misc/NEWS.d/next/Library/2026-05-17-21-48-02.gh-issue-149958.6rG3lT.rst
new file mode 100644
index 00000000000000..2e1fa4d49b8e63
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2026-05-17-21-48-02.gh-issue-149958.6rG3lT.rst
@@ -0,0 +1,3 @@
+The Tachyon module now supports runtime control over a Unix domain socket,
+via ``--control unix:PATH``. Supported commands are ``enable``, ``disable``,
+``ping``, ``status`` and ``quit``. Patch by Maurycy Pawłowski-Wieroński.