diff --git a/doc/usage.rst b/doc/usage.rst index 37527fb40..bbcccfa8b 100644 --- a/doc/usage.rst +++ b/doc/usage.rst @@ -839,3 +839,14 @@ like this: $ labgrid-client -p example allow sirius/john To remove the allow it is currently necessary to unlock and lock the place. + +Internal console +^^^^^^^^^^^^^^^^ + +Labgrid uses microcom as its console by default. For situations where this is +not suitable, an internal console is provided. To use this, provide the +``--internal`` flag to the ``labgrid client`` command. + +When the internal console is used, the console transitions cleanly between use +within a strategy or driver, and interactive use for the user. The console is +not closed and therefore there is no loss of data. diff --git a/labgrid/remote/client.py b/labgrid/remote/client.py index 4d2eb0bfa..e06de92e1 100755 --- a/labgrid/remote/client.py +++ b/labgrid/remote/client.py @@ -14,7 +14,6 @@ import signal import sys import shlex -import shutil import json import itertools import ipaddress @@ -46,7 +45,7 @@ from ..exceptions import NoDriverFoundError, NoResourceFoundError, InvalidConfigError from .generated import labgrid_coordinator_pb2, labgrid_coordinator_pb2_grpc from ..resource.remote import RemotePlaceManager, RemotePlace -from ..util import diff_dict, flat_dict, dump, atomic_replace, labgrid_version, Timeout +from ..util import diff_dict, flat_dict, dump, atomic_replace, labgrid_version, Timeout, term from ..util.proxy import proxymanager from ..util.helper import processwrapper from ..driver import Mode, ExecutionError @@ -468,19 +467,28 @@ def _match_places(self, pattern): result.add(name) return list(result) - def _check_allowed(self, place): + def is_allowed(self, place): + """Check if a place is acquired + + Args: + place (str): Place name to check + + Returns: + str: None if acquired, else error message + """ if not place.acquired: - raise UserError(f"place {place.name} is not acquired") + return f"place {place.name} is not acquired" if f"{self.gethostname()}/{self.getuser()}" not in place.allowed: host, user = place.acquired.split("/") if user != self.getuser(): - raise UserError( - f"place {place.name} is not acquired by your user, acquired by {user}. To work simultaneously, {user} can execute labgrid-client -p {place.name} allow {self.gethostname()}/{self.getuser()}" - ) + return f"place {place.name} is not acquired by your user, acquired by {user}. To work simultaneously, {user} can execute labgrid-client -p {place.name} allow {self.gethostname()}/{self.getuser()}" if host != self.gethostname(): - raise UserError( - f"place {place.name} is not acquired on this computer, acquired on {host}. To allow this host, use labgrid-client -p {place.name} allow {self.gethostname()}/{self.getuser()} on the other host" - ) + return f"place {place.name} is not acquired on this computer, acquired on {host}. To allow this host, use labgrid-client -p {place.name} allow {self.gethostname()}/{self.getuser()} on the other host" + + def _check_allowed(self, place): + err = self.is_allowed(place) + if err: + raise UserError(err) def get_place(self, place=None): pattern = place or self.args.place @@ -890,12 +898,6 @@ def _get_target(self, place): strategy.force(self.args.initial_state) print(f"Transitioning into state {self.args.state}") strategy.transition(self.args.state) - # deactivate console drivers so we are able to connect with microcom later - try: - con = target.get_active_driver("ConsoleProtocol") - target.deactivate(con) - except NoDriverFoundError: - pass else: target = Target(place.name, env=self.env) RemotePlace(target, name=place.name) @@ -1027,78 +1029,53 @@ def digital_io(self): drv.set(False) async def _console(self, place, target, timeout, *, logfile=None, loop=False, listen_only=False): - name = self.args.name - from ..resource import NetworkSerialPort - - resource = target.get_resource(NetworkSerialPort, name=name, wait_avail=False) + from ..protocol import ConsoleProtocol - # async await resources - timeout = Timeout(timeout) - while True: - target.update_resources() - if resource.avail or (not loop and timeout.expired): - break - await asyncio.sleep(0.1) - - # use zero timeout to prevent blocking sleeps - target.await_resources([resource], timeout=0.0) + name = self.args.name if not place.acquired: print("place released") return 255 - host, port = proxymanager.get_host_and_port(resource) - - # check for valid resources - assert port is not None, "Port is not set" - - microcom_bin = shutil.which("microcom") - - if microcom_bin is not None: - call = [microcom_bin, "-s", str(resource.speed), "-t", f"{host}:{port}"] - - if listen_only: - call.append("--listenonly") - - if logfile: - call.append(f"--logfile={logfile}") + if self.args.internal or os.environ.get("LG_CONSOLE") == "internal": + console = target.get_driver(ConsoleProtocol, name=name) + returncode = await term.internal(lambda: self.is_allowed(place), console, logfile, listen_only) else: - call = ["telnet", host, str(port)] + from ..resource import NetworkSerialPort - logging.info("microcom not available, using telnet instead") + # deactivate console drivers so we are able to connect with microcom + try: + con = target.get_active_driver("ConsoleProtocol") + target.deactivate(con) + except NoDriverFoundError: + pass - if listen_only: - logging.warning("--listenonly option not supported by telnet, ignoring") + resource = target.get_resource(NetworkSerialPort, name=name, wait_avail=False) - if logfile: - logging.warning("--logfile option not supported by telnet, ignoring") + # async await resources + timeout = Timeout(timeout) + while True: + target.update_resources() + if resource.avail or (not loop and timeout.expired): + break + await asyncio.sleep(0.1) - print(f"connecting to {resource} calling {' '.join(call)}") - try: - p = await asyncio.create_subprocess_exec(*call) - except FileNotFoundError as e: - raise ServerError(f"failed to execute remote console command: {e}") - while p.returncode is None: - try: - await asyncio.wait_for(p.wait(), 1.0) - except asyncio.TimeoutError: - # subprocess is still running - pass + # use zero timeout to prevent blocking sleeps + target.await_resources([resource], timeout=0.0) + host, port = proxymanager.get_host_and_port(resource) + # check for valid resources + assert port is not None, "Port is not set" try: - self._check_allowed(place) - except UserError: - p.terminate() - try: - await asyncio.wait_for(p.wait(), 1.0) - except asyncio.TimeoutError: - # try harder - p.kill() - await asyncio.wait_for(p.wait(), 1.0) - raise - if p.returncode: - print("connection lost", file=sys.stderr) - return p.returncode + returncode = await term.external( + lambda: self.is_allowed(place), host, port, resource, logfile, listen_only + ) + except FileNotFoundError as e: + raise ServerError(f"failed to execute remote console command: {e}") + + # Raise an exception if the place was released + self._check_allowed(place) + return returncode async def console(self, place, target): while True: @@ -1110,7 +1087,7 @@ async def console(self, place, target): break if not self.args.loop: if res: - raise InteractiveCommandError("microcom error", res) + raise InteractiveCommandError("console error", res) break await asyncio.sleep(1.0) @@ -1995,6 +1972,7 @@ def get_parser(auto_doc_mode=False) -> "argparse.ArgumentParser | AutoProgramArg subparser.set_defaults(func=ClientSession.digital_io) subparser = subparsers.add_parser("console", aliases=("con",), help="connect to the console") + subparser.add_argument("-i", "--internal", action="store_true", help="use an internal console instead of microcom") subparser.add_argument( "-l", "--loop", action="store_true", help="keep trying to connect if the console is unavailable" ) diff --git a/labgrid/util/term.py b/labgrid/util/term.py new file mode 100644 index 000000000..74ed19f65 --- /dev/null +++ b/labgrid/util/term.py @@ -0,0 +1,196 @@ +"""Terminal handling, using microcom, telnet or an internal function""" + +import asyncio +import collections +import logging +import os +import sys +import shutil +import termios +import time + +from pexpect import TIMEOUT +from serial.serialutil import SerialException + +EXIT_CHAR = 0x1d # FS (Ctrl + ]) + +async def external(check_allowed, host, port, resource, logfile, listen_only): + """Start an external terminal sessions + + This uses microcom if available, otherwise falls back to telnet. + + Args: + check_allowed (lambda): Function to call to make sure the terminal is + still accessible. No args. Returns True if allowed, False if not. + host (str): Host name to connect to + port (int): Port number to connect to + resource (str): Serial resource to connect to (used to get speed / name) + logfile (str): Logfile to write output too, or None. This is ignored if + telnet is used + listen_only (bool): True to ignore keyboard input (ignored with telnet) + + Returns: + int: Return code from tool + """ + microcom_bin = shutil.which("microcom") + + if microcom_bin is not None: + call = [microcom_bin, "-s", str(resource.speed), "-t", f"{host}:{port}"] + + if listen_only: + call.append("--listenonly") + + if logfile: + call.append(f"--logfile={logfile}") + else: + call = ["telnet", host, str(port)] + + logging.info("microcom not available, using telnet instead") + + if listen_only: + logging.warning("--listenonly option not supported by telnet, ignoring") + + if logfile: + logging.warning("--logfile option not supported by telnet, ignoring") + + logging.info("connecting to %s calling %s", resource, " ".join(call)) + p = await asyncio.create_subprocess_exec(*call) + while p.returncode is None: + try: + await asyncio.wait_for(p.wait(), 1.0) + except asyncio.TimeoutError: + # subprocess is still running + pass + + if check_allowed(): + p.terminate() + try: + await asyncio.wait_for(p.wait(), 1.0) + except asyncio.TimeoutError: + # try harder + p.kill() + await asyncio.wait_for(p.wait(), 1.0) + break + if p.returncode: + print("connection lost", file=sys.stderr) + return p.returncode + + +BUF_SIZE = 1024 + +async def run(check_allowed, cons, log_fd, listen_only): + prev = collections.deque(maxlen=2) + + deadline = None + to_cons = b'' + next_cons = time.monotonic() + txdelay = cons.txdelay + + # Show a message to indicate we are waiting for output from the board + msg = 'Terminal ready...press Ctrl-] twice to exit' + sys.stdout.write(msg) + sys.stdout.flush() + erase_msg = '\b' * len(msg) + ' ' * len(msg) + '\b' * len(msg) + have_output = False + + while True: + activity = bool(to_cons) + try: + data = cons.read(size=BUF_SIZE, timeout=0.001) + if data: + activity = True + if not have_output: + # Erase our message + sys.stdout.write(erase_msg) + sys.stdout.flush() + have_output = True + sys.stdout.buffer.write(data) + sys.stdout.buffer.flush() + if log_fd: + log_fd.write(data) + log_fd.flush() + + except TIMEOUT: + pass + + except SerialException: + break + + if not listen_only: + data = os.read(sys.stdin.fileno(), BUF_SIZE) + if data: + activity = True + if not deadline: + deadline = time.monotonic() + .5 # seconds + prev.extend(data) + count = prev.count(EXIT_CHAR) + if count == 2: + break + + to_cons += data + + # Drain one byte at a time, honouring txdelay between bytes + if to_cons and time.monotonic() > next_cons: + cons._write(to_cons[:1]) + to_cons = to_cons[1:] + if txdelay: + next_cons += txdelay + + if deadline and time.monotonic() > deadline: + prev.clear() + deadline = None + if check_allowed(): + break + if not activity: + time.sleep(.001) + + # Blank line to move past any partial output + print() + + +async def internal(check_allowed, cons, logfile, listen_only): + """Start an external terminal sessions + + This uses microcom if available, otherwise falls back to telnet. + + Args: + check_allowed (lambda): Function to call to make sure the terminal is + still accessible. No args. Returns True if allowed, False if not. + cons (str): ConsoleProtocol device to read/write + logfile (str): Logfile to write output too, or None + listen_only (bool): True to ignore keyboard input + + Return: + int: Result code + """ + returncode = 0 + old = None + log_fd = None + try: + if not listen_only and os.isatty(sys.stdout.fileno()): + fd = sys.stdin.fileno() + old = termios.tcgetattr(fd) + new = termios.tcgetattr(fd) + new[3] = new[3] & ~(termios.ICANON | termios.ECHO | termios.ISIG) + new[6][termios.VMIN] = 0 + new[6][termios.VTIME] = 0 + termios.tcsetattr(fd, termios.TCSANOW, new) + + log_fd = None + if logfile: + log_fd = open(logfile, 'wb') + + logging.info('Console start:') + await run(check_allowed, cons, log_fd, listen_only) + + except OSError as err: + print('error', err) + returncode = 1 + + finally: + if old: + termios.tcsetattr(fd, termios.TCSAFLUSH, old) + if log_fd: + log_fd.close() + + return returncode diff --git a/man/labgrid-client.1 b/man/labgrid-client.1 index dfc95b2e8..e84495197 100644 --- a/man/labgrid-client.1 +++ b/man/labgrid-client.1 @@ -233,7 +233,7 @@ connect to the console .INDENT 3.5 .sp .EX -usage: labgrid\-client console|con [\-l] [\-o] [\-\-logfile FILE] [name] +usage: labgrid\-client console|con [\-i] [\-l] [\-o] [\-\-logfile FILE] [name] .EE .UNINDENT .UNINDENT @@ -244,6 +244,11 @@ optional resource name .UNINDENT .INDENT 0.0 .TP +.B \-i, \-\-internal +use an internal console instead of microcom +.UNINDENT +.INDENT 0.0 +.TP .B \-l, \-\-loop keep trying to connect if the console is unavailable .UNINDENT diff --git a/tests/test_client_unit.py b/tests/test_client_unit.py new file mode 100644 index 000000000..7f00ce1f2 --- /dev/null +++ b/tests/test_client_unit.py @@ -0,0 +1,94 @@ +"""Unit tests for labgrid.remote.client""" + +import argparse +from unittest.mock import MagicMock, patch + +import pytest + +from labgrid.remote.client import ClientSession, UserError, get_parser + + +# --- is_allowed() tests --- + +@pytest.fixture +def session(): + """Create a minimal ClientSession-like object for testing""" + s = object.__new__(ClientSession) + s.args = argparse.Namespace() + return s + + +@pytest.fixture +def mock_place(): + place = MagicMock() + place.name = "testplace" + place.acquired = "myhost/myuser" + place.allowed = {"myhost/myuser"} + return place + + +class TestIsAllowed: + def test_place_not_acquired(self, session, mock_place): + mock_place.acquired = None + with patch.object(session, "gethostname", return_value="myhost"), \ + patch.object(session, "getuser", return_value="myuser"): + result = session.is_allowed(mock_place) + assert "not acquired" in result + + def test_place_acquired_by_us(self, session, mock_place): + with patch.object(session, "gethostname", return_value="myhost"), \ + patch.object(session, "getuser", return_value="myuser"): + result = session.is_allowed(mock_place) + assert result is None + + def test_place_acquired_by_different_user(self, session, mock_place): + mock_place.acquired = "myhost/otheruser" + mock_place.allowed = {"myhost/otheruser"} + with patch.object(session, "gethostname", return_value="myhost"), \ + patch.object(session, "getuser", return_value="myuser"): + result = session.is_allowed(mock_place) + assert "not acquired by your user" in result + assert "otheruser" in result + + def test_place_acquired_on_different_host(self, session, mock_place): + mock_place.acquired = "otherhost/myuser" + mock_place.allowed = {"otherhost/myuser"} + with patch.object(session, "gethostname", return_value="myhost"), \ + patch.object(session, "getuser", return_value="myuser"): + result = session.is_allowed(mock_place) + assert "not acquired on this computer" in result + assert "otherhost" in result + + def test_place_acquired_elsewhere_but_allowed(self, session, mock_place): + """User is in the allowed set even though place was acquired elsewhere""" + mock_place.acquired = "otherhost/otheruser" + mock_place.allowed = {"otherhost/otheruser", "myhost/myuser"} + with patch.object(session, "gethostname", return_value="myhost"), \ + patch.object(session, "getuser", return_value="myuser"): + result = session.is_allowed(mock_place) + assert result is None + + +# --- _check_allowed() tests --- + +class TestCheckAllowed: + def test_raises_on_not_allowed(self, session, mock_place): + mock_place.acquired = None + with patch.object(session, "gethostname", return_value="myhost"), \ + patch.object(session, "getuser", return_value="myuser"): + with pytest.raises(UserError, match="not acquired"): + session._check_allowed(mock_place) + + def test_no_raise_when_allowed(self, session, mock_place): + with patch.object(session, "gethostname", return_value="myhost"), \ + patch.object(session, "getuser", return_value="myuser"): + session._check_allowed(mock_place) # should not raise + + +# --- get_parser() tests --- + +class TestGetParser: + def test_console_internal_argument(self): + parser = get_parser() + args = parser.parse_args(["console", "--internal"]) + assert args.internal is True diff --git a/tests/test_term.py b/tests/test_term.py new file mode 100644 index 000000000..8b9ae8c17 --- /dev/null +++ b/tests/test_term.py @@ -0,0 +1,504 @@ +"""Tests for labgrid.util.term — terminal handling""" + +import asyncio +import io +import logging +import os +import sys +import termios +import threading +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from labgrid.util.term import external, run, internal, EXIT_CHAR +from pexpect import TIMEOUT +from serial.serialutil import SerialException + + +@pytest.fixture +def event_loop(): + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture +def mock_resource(): + resource = MagicMock() + resource.speed = 115200 + return resource + + +@pytest.fixture +def mock_console(): + cons = MagicMock() + cons.txdelay = 0 + cons.read = MagicMock(side_effect=TIMEOUT("timeout")) + cons._write = MagicMock() + return cons + + +class FakeConsole: + """Minimal console for pipe-based tests. + + Args: + txdelay: per-byte transmit delay in seconds + on_write: optional callback invoked with each byte written + """ + def __init__(self, txdelay=0, on_write=None): + self.txdelay = txdelay + self.written = [] + self._on_write = on_write + + def read(self, size=1024, timeout=0.001): + raise TIMEOUT("timeout") + + def _write(self, data): + self.written.append(data) + if self._on_write: + self._on_write(data) + + +@pytest.fixture +def stdin_pipe(): + """Create a pipe and yield (read_file, write_fd). + + The read side is wrapped in a file object suitable for patching + sys.stdin. Both ends are closed on cleanup. + """ + read_fd, write_fd = os.pipe() + read_file = os.fdopen(read_fd, 'rb', 0) + yield read_file, write_fd + read_file.close() + try: + os.close(write_fd) + except OSError: + pass # already closed by the test + + +# --- external() tests --- + +class TestExternal: + def test_microcom_basic(self, event_loop, mock_resource): + """Test that external() launches microcom when available""" + proc = AsyncMock() + proc.returncode = 0 + proc.wait = AsyncMock(return_value=0) + proc.terminate = MagicMock() + + with patch("labgrid.util.term.shutil.which", return_value="/usr/bin/microcom"), \ + patch("labgrid.util.term.asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=proc) as mock_exec: + result = event_loop.run_until_complete( + external(lambda: None, "host1", 1234, mock_resource, None, False)) + + args = mock_exec.call_args[0] + assert args[0] == "/usr/bin/microcom" + assert "-s" in args + assert "115200" in args + assert "-t" in args + assert "host1:1234" in args + assert result == 0 + + def test_microcom_listen_only(self, event_loop, mock_resource): + """Test that --listenonly is passed to microcom""" + proc = AsyncMock() + proc.returncode = 0 + proc.wait = AsyncMock(return_value=0) + + with patch("labgrid.util.term.shutil.which", return_value="/usr/bin/microcom"), \ + patch("labgrid.util.term.asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=proc) as mock_exec: + event_loop.run_until_complete( + external(lambda: None, "host1", 1234, mock_resource, None, True)) + + args = mock_exec.call_args[0] + assert "--listenonly" in args + + def test_telnet_fallback(self, event_loop, mock_resource): + """Test fallback to telnet when microcom is not available""" + proc = AsyncMock() + proc.returncode = 0 + proc.wait = AsyncMock(return_value=0) + + with patch("labgrid.util.term.shutil.which", return_value=None), \ + patch("labgrid.util.term.asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=proc) as mock_exec: + event_loop.run_until_complete( + external(lambda: None, "host1", 1234, mock_resource, None, False)) + + args = mock_exec.call_args[0] + assert args[0] == "telnet" + assert "host1" in args + assert "1234" in args + + def test_telnet_listen_only_warns(self, event_loop, mock_resource, caplog): + """Test that telnet with listen_only logs a warning""" + proc = AsyncMock() + proc.returncode = 0 + proc.wait = AsyncMock(return_value=0) + + with patch("labgrid.util.term.shutil.which", return_value=None), \ + patch("labgrid.util.term.asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=proc), \ + caplog.at_level(logging.WARNING): + event_loop.run_until_complete( + external(lambda: None, "host1", 1234, mock_resource, None, True)) + + assert "--listenonly option not supported by telnet" in caplog.text + + def test_telnet_logfile_warns(self, event_loop, mock_resource, caplog): + """Test that telnet with logfile logs a warning""" + proc = AsyncMock() + proc.returncode = 0 + proc.wait = AsyncMock(return_value=0) + + with patch("labgrid.util.term.shutil.which", return_value=None), \ + patch("labgrid.util.term.asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=proc), \ + caplog.at_level(logging.WARNING): + event_loop.run_until_complete( + external(lambda: None, "host1", 1234, mock_resource, "/tmp/log", False)) + + assert "--logfile option not supported by telnet" in caplog.text + + def test_check_allowed_terminates(self, event_loop, mock_resource): + """Test that check_allowed returning truthy terminates the process""" + call_count = [0] + + def check(): + call_count[0] += 1 + return "not allowed" if call_count[0] >= 2 else None + + proc = AsyncMock() + proc.returncode = None + + def do_terminate(): + proc.returncode = -15 + proc.terminate = MagicMock(side_effect=do_terminate) + + wait_count = [0] + async def fake_wait(): + wait_count[0] += 1 + if wait_count[0] == 1: + # First call: simulate poll timeout + await asyncio.sleep(10) + # Subsequent calls: return immediately (process terminated) + proc.wait = fake_wait + + with patch("labgrid.util.term.shutil.which", return_value="/usr/bin/microcom"), \ + patch("labgrid.util.term.asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=proc): + event_loop.run_until_complete( + external(check, "host1", 1234, mock_resource, None, False)) + + proc.terminate.assert_called_once() + + def test_check_allowed_kills_after_terminate_timeout(self, event_loop, mock_resource): + """Test that kill is used when terminate does not stop the process. + + This test takes ~3s because three asyncio.wait_for(timeout=1.0) + calls must time out (two poll loops + one after terminate). + """ + call_count = [0] + + def check(): + call_count[0] += 1 + return "not allowed" if call_count[0] >= 2 else None + + proc = MagicMock() + proc.returncode = None + proc.terminate = MagicMock() # terminate does NOT set returncode + + def do_kill(): + proc.returncode = -9 + proc.kill = MagicMock(side_effect=do_kill) + + wait_count = [0] + async def fake_wait(): + wait_count[0] += 1 + if wait_count[0] <= 3: + # First three calls hang: two poll loops + after terminate + await asyncio.sleep(10) + # Fourth call (after kill): return immediately + proc.wait = fake_wait + + with patch("labgrid.util.term.shutil.which", return_value="/usr/bin/microcom"), \ + patch("labgrid.util.term.asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=proc): + event_loop.run_until_complete( + external(check, "host1", 1234, mock_resource, None, False)) + + proc.terminate.assert_called_once() + proc.kill.assert_called_once() + + def test_microcom_logfile_not_duplicated(self, event_loop, mock_resource): + """Test that --logfile is not appended twice when using microcom""" + proc = AsyncMock() + proc.returncode = 0 + proc.wait = AsyncMock(return_value=0) + + with patch("labgrid.util.term.shutil.which", return_value="/usr/bin/microcom"), \ + patch("labgrid.util.term.asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=proc) as mock_exec: + event_loop.run_until_complete( + external(lambda: None, "host1", 1234, mock_resource, "/tmp/log", False)) + + args = mock_exec.call_args[0] + logfile_args = [a for a in args if "logfile" in str(a)] + assert len(logfile_args) == 1, f"--logfile appended {len(logfile_args)} times: {args}" + + def test_nonzero_return(self, event_loop, mock_resource, capsys): + """Test that non-zero return code prints connection lost""" + proc = AsyncMock() + proc.returncode = 1 + proc.wait = AsyncMock(return_value=1) + + with patch("labgrid.util.term.shutil.which", return_value="/usr/bin/microcom"), \ + patch("labgrid.util.term.asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=proc): + result = event_loop.run_until_complete( + external(lambda: None, "host1", 1234, mock_resource, None, False)) + assert result == 1 + assert "connection lost" in capsys.readouterr().err + + +# --- run() tests --- + +class TestRun: + def test_exit_on_double_ctrl_bracket(self, event_loop, mock_console): + """Test that double Ctrl+] exits the loop""" + exit_data = bytes([EXIT_CHAR, EXIT_CHAR]) + mock_stdin = MagicMock() + mock_stdin.fileno.return_value = 0 + + with patch("os.read", return_value=exit_data), \ + patch("sys.stdin", mock_stdin), \ + patch("sys.stdout", new_callable=lambda: MagicMock(spec=sys.stdout)): + event_loop.run_until_complete( + run(lambda: None, mock_console, None, False)) + + def test_listen_only_no_stdin_read(self, event_loop, mock_console): + """Test that listen_only mode does not read from stdin""" + call_count = [0] + def check(): + call_count[0] += 1 + return "done" if call_count[0] >= 2 else None + + with patch("os.read") as mock_read, \ + patch("sys.stdout", new_callable=lambda: MagicMock(spec=sys.stdout)): + event_loop.run_until_complete( + run(check, mock_console, None, True)) + mock_read.assert_not_called() + + def test_console_output_written_to_stdout(self, event_loop, mock_console): + """Test that console output is written to stdout""" + read_count = [0] + def mock_read(size=1024, timeout=0.001): + read_count[0] += 1 + if read_count[0] == 1: + return b"Hello from board\n" + raise TIMEOUT("timeout") + + mock_console.read = mock_read + + check_count = [0] + def check(): + check_count[0] += 1 + return "done" if check_count[0] >= 3 else None + + stdout_buffer = io.BytesIO() + mock_stdout = MagicMock() + mock_stdout.buffer = stdout_buffer + mock_stdout.write = MagicMock() + mock_stdout.flush = MagicMock() + + with patch("sys.stdout", mock_stdout): + event_loop.run_until_complete( + run(check, mock_console, None, True)) + + stdout_buffer.seek(0) + assert b"Hello from board\n" in stdout_buffer.getvalue() + + def test_logfile_written(self, event_loop, mock_console): + """Test that console output is written to the logfile""" + read_count = [0] + def mock_read(size=1024, timeout=0.001): + read_count[0] += 1 + if read_count[0] == 1: + return b"log data\n" + raise TIMEOUT("timeout") + + mock_console.read = mock_read + + check_count = [0] + def check(): + check_count[0] += 1 + return "done" if check_count[0] >= 3 else None + + log_fd = io.BytesIO() + + with patch("sys.stdout", new_callable=lambda: MagicMock(spec=sys.stdout)): + event_loop.run_until_complete( + run(check, mock_console, log_fd, True)) + + log_fd.seek(0) + assert b"log data\n" in log_fd.getvalue() + + def test_serial_exception_exits(self, event_loop, mock_console): + """Test that SerialException breaks out of the loop""" + mock_console.read = MagicMock(side_effect=SerialException("disconnected")) + + with patch("sys.stdout", new_callable=lambda: MagicMock(spec=sys.stdout)): + event_loop.run_until_complete( + run(lambda: None, mock_console, None, True)) + + def test_stdin_written_to_console(self, event_loop, stdin_pipe): + """Test that stdin data is written to the console one byte at a time, + using a pipe for stdin rather than mocking os.read""" + read_file, write_fd = stdin_pipe + os.write(write_fd, b"Hi") + os.close(write_fd) + + cons = FakeConsole() + + # os.read on a pipe returns b"" at EOF, which is falsy, so + # the loop will just keep going. Exit once both bytes are written. + def check(): + return "done" if len(cons.written) >= 2 else None + + with patch("sys.stdin", read_file), \ + patch("sys.stdout", new_callable=lambda: MagicMock(spec=sys.stdout)): + event_loop.run_until_complete( + run(check, cons, None, False)) + + assert cons.written == [b"H", b"i"] + + def test_stdin_txdelay(self, event_loop, stdin_pipe): + """Test that txdelay throttles bytes written to the console""" + read_file, write_fd = stdin_pipe + os.write(write_fd, b"AB") + os.close(write_fd) + + timestamps = [] + def record_time(data): + timestamps.append(time.monotonic()) + + cons = FakeConsole(txdelay=0.05, on_write=record_time) + + def check(): + return "done" if len(timestamps) >= 2 else None + + with patch("sys.stdin", read_file), \ + patch("sys.stdout", new_callable=lambda: MagicMock(spec=sys.stdout)): + event_loop.run_until_complete( + run(check, cons, None, False)) + + assert len(timestamps) == 2 + gap = timestamps[1] - timestamps[0] + # Allow 10ms margin below the 50ms txdelay for scheduling jitter + assert gap >= 0.04 + + def test_exit_char_deadline_resets(self, event_loop, stdin_pipe): + """Test that a single Ctrl+] is forgotten after the 0.5s deadline. + + Send Ctrl+] then wait for the deadline to expire, then send + normal data. The normal data should be written to the console + (proving the exit-char was cleared) rather than combined with + the stale Ctrl+] to trigger exit. + + To avoid a brittle fixed sleep, the feeder thread uses a + threading.Event set by _write() when the Ctrl+] byte arrives. + This way the 0.6s deadline-expiry sleep only starts once we + know the loop has processed the keystroke and set its internal + deadline, removing any race between the pipe write and the + main loop. + """ + read_file, write_fd = stdin_pipe + got_exit_char = threading.Event() + + def on_write(data): + if data == bytes([EXIT_CHAR]): + got_exit_char.set() + + cons = FakeConsole(on_write=on_write) + + def feed_stdin(): + os.write(write_fd, bytes([EXIT_CHAR])) + # Wait until the loop has processed the Ctrl+] (deadline is set) + got_exit_char.wait(timeout=5) + time.sleep(0.6) # exceed the 0.5s deadline + os.write(write_fd, b"X") + os.close(write_fd) + + threading.Thread(target=feed_stdin, daemon=True).start() + + # Safety: also exit after 3s in case the feeder thread fails + start = time.monotonic() + def check(): + if time.monotonic() - start > 3: + return "timeout" + return "done" if any(d == b"X" for d in cons.written) else None + + with patch("sys.stdin", read_file), \ + patch("sys.stdout", new_callable=lambda: MagicMock(spec=sys.stdout)): + event_loop.run_until_complete( + run(check, cons, None, False)) + + assert b"X" in cons.written + + def test_check_allowed_exits(self, event_loop, mock_console): + """Test that check_allowed returning truthy exits the loop""" + with patch("sys.stdout", new_callable=lambda: MagicMock(spec=sys.stdout)): + event_loop.run_until_complete( + run(lambda: "not allowed", mock_console, None, True)) + + +# --- internal() tests --- + +class TestInternal: + def test_listen_only_no_termios(self, event_loop, mock_console): + """Test that listen_only mode skips terminal setup""" + with patch("labgrid.util.term.run", new_callable=AsyncMock) as mock_run, \ + patch("labgrid.util.term.termios") as mock_termios: + result = event_loop.run_until_complete( + internal(lambda: None, mock_console, None, True)) + + mock_termios.tcgetattr.assert_not_called() + mock_run.assert_awaited_once() + assert result == 0 + + def test_with_logfile(self, event_loop, mock_console, tmp_path): + """Test that a logfile is opened and closed""" + logfile = str(tmp_path / "test.log") + + with patch("labgrid.util.term.run", new_callable=AsyncMock): + result = event_loop.run_until_complete( + internal(lambda: None, mock_console, logfile, True)) + + assert result == 0 + assert os.path.exists(logfile) + + def test_os_error_returns_1(self, event_loop, mock_console): + """Test that OSError during run returns exitcode 1""" + with patch("labgrid.util.term.run", new_callable=AsyncMock, + side_effect=OSError("test error")): + result = event_loop.run_until_complete( + internal(lambda: None, mock_console, None, True)) + assert result == 1 + + def test_terminal_restored_on_exit(self, event_loop, mock_console): + """Test that terminal attributes are restored after exit""" + old_attrs = [0, 0, 0, 0, 0, 0, [0] * 32] + + mock_stdin = MagicMock() + mock_stdin.fileno.return_value = 0 + + with patch("labgrid.util.term.run", new_callable=AsyncMock), \ + patch("labgrid.util.term.os.isatty", return_value=True), \ + patch("labgrid.util.term.sys.stdin", mock_stdin), \ + patch("labgrid.util.term.termios.tcgetattr", return_value=old_attrs.copy()), \ + patch("labgrid.util.term.termios.tcsetattr") as mock_set: + event_loop.run_until_complete( + internal(lambda: None, mock_console, None, False)) + + assert mock_set.call_count == 2 + # First call: setup (TCSANOW), second call: restore (TCSAFLUSH) + setup_call = mock_set.call_args_list[0] + assert setup_call[0][1] == termios.TCSANOW + restore_call = mock_set.call_args_list[1] + assert restore_call[0][1] == termios.TCSAFLUSH + assert restore_call[0][2] == old_attrs