From 8eb259601c0f31b915120eeb0b8b3cf621cd7254 Mon Sep 17 00:00:00 2001 From: njzjz-bot <48687836+njzjz-bot@users.noreply.github.com> Date: Sun, 29 Mar 2026 14:56:30 +0000 Subject: [PATCH] fix(ssh): add optional remote_root auto-creation Add a create_remote_root switch for SSHContext so DPDispatcher can recursively create the configured remote_root when users opt in. This preserves the current safe default while fixing setups where parent directories do not already exist. Also persist the new option through Machine.serialize(), add unit tests for recursive mkdir and config round-tripping, and document the new knob in the SSH examples. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4) --- doc/context.md | 2 + doc/getting-started.md | 1 + dpdispatcher/contexts/ssh_context.py | 53 +++++++++++--- dpdispatcher/machine.py | 4 ++ examples/machine/ssh_proxy_command.json | 1 + tests/test_argcheck.py | 63 +++++++++++++++++ tests/test_ssh_create_remote_root.py | 93 +++++++++++++++++++++++++ 7 files changed, 207 insertions(+), 10 deletions(-) create mode 100644 tests/test_ssh_create_remote_root.py diff --git a/doc/context.md b/doc/context.md index 627158d3..ed1462e1 100644 --- a/doc/context.md +++ b/doc/context.md @@ -29,6 +29,8 @@ Since [`bash -l`](https://www.gnu.org/software/bash/manual/bash.html#Invoking-Ba Files will be copied to the remote directory via SSH channels before jobs start and copied back after jobs finish. To use SSH, one needs to provide necessary parameters in {dargs:argument}`remote_profile `, such as {dargs:argument}`username ` and {dargs:argument}`hostname `. +By default, DPDispatcher only creates the final directory of {dargs:argument}`remote_root ` and requires its parent directories to already exist on the remote machine, which helps catch typos in remote paths. If you want DPDispatcher to recursively create the whole directory tree for you, set {dargs:argument}`create_remote_root ` to `true`. 
+ It's suggested to generate [SSH keys](https://help.ubuntu.com/community/SSH/OpenSSH/Keys) and transfer the public key to the remote server in advance, which is more secure than password authentication. Note that `SSH` context is [non-login](https://www.gnu.org/software/bash/manual/html_node/Bash-Startup-Files.html), so `bash_profile` files will not be executed outside the submission script. diff --git a/doc/getting-started.md b/doc/getting-started.md index 46dc4eb7..43730d58 100644 --- a/doc/getting-started.md +++ b/doc/getting-started.md @@ -72,6 +72,7 @@ where `machine.json` is "context_type": "SSHContext", "local_root": "/home/user123/workplace/22_new_project/", "remote_root": "/home/user123/dpdispatcher_work_dir/", + "create_remote_root": true, "remote_profile": { "hostname": "39.106.xx.xxx", "username": "user123", diff --git a/dpdispatcher/contexts/ssh_context.py b/dpdispatcher/contexts/ssh_context.py index 69bb46d8..9c35f79a 100644 --- a/dpdispatcher/contexts/ssh_context.py +++ b/dpdispatcher/contexts/ssh_context.py @@ -464,6 +464,7 @@ def __init__( remote_root, remote_profile, clean_asynchronously=False, + create_remote_root=False, *args, **kwargs, ): @@ -480,6 +481,7 @@ def __init__( # self.job_uuid = None self.clean_asynchronously = clean_asynchronously + self.create_remote_root = create_remote_root # self.job_uuid = job_uuid # if job_uuid: # self.job_uuid=job_uuid @@ -488,10 +490,7 @@ def __init__( self.ssh_session = SSHSession(**remote_profile) # self.temp_remote_root = os.path.join(self.ssh_session.get_session_root()) self.ssh_session.ensure_alive() - try: - self.sftp.mkdir(self.temp_remote_root) - except OSError: - pass + self._mkdir(self.temp_remote_root, recursive=self.create_remote_root) @classmethod def load_from_dict(cls, context_dict): @@ -511,12 +510,14 @@ def load_from_dict(cls, context_dict): remote_root = context_dict["remote_root"] remote_profile = context_dict["remote_profile"] clean_asynchronously = 
context_dict.get("clean_asynchronously", False) + create_remote_root = context_dict.get("create_remote_root", False) ssh_context = cls( local_root=local_root, remote_root=remote_root, remote_profile=remote_profile, clean_asynchronously=clean_asynchronously, + create_remote_root=create_remote_root, ) # local_root = jdata['local_root'] # ssh_session = SSHSession(**input) @@ -541,6 +542,28 @@ def close(self): def get_job_root(self): return self.remote_root + def _mkdir(self, remote_dir, recursive=False): + if not remote_dir: + return + + sftp = self.sftp + if not recursive: + try: + sftp.mkdir(remote_dir) + except OSError: + pass + return + + path = pathlib.PurePosixPath(remote_dir) + current = path.root if path.is_absolute() else "" + parts = path.parts[1:] if path.is_absolute() else path.parts + for part in parts: + current = pathlib.PurePosixPath(current, part).as_posix() + try: + sftp.mkdir(current) + except OSError: + pass + def bind_submission(self, submission): assert self.ssh_session is not None assert self.ssh_session.ssh is not None @@ -572,11 +595,7 @@ def bind_submission(self, submission): # if the new directory exists and the old directory does not contain files, then move the old directory self._rmtree(old_remote_root) - sftp = self.ssh_session.ssh.open_sftp() - try: - sftp.mkdir(self.remote_root) - except OSError: - pass + self._mkdir(self.remote_root, recursive=self.create_remote_root) # self.job_uuid = submission.submission_hash # dlog.debug("debug:SSHContext.bind_submission" @@ -1013,8 +1032,22 @@ def machine_subfields(cls) -> List[Argument]: list[Argument] machine subfields """ + doc_create_remote_root = ( + "Whether DPDispatcher should recursively create the configured SSH remote_root " + "when parent directories do not already exist. Keep this disabled by default " + "to avoid silently creating directories for a mistyped path." 
+ ) doc_remote_profile = "SSH connection settings for the remote machine, including authentication, timeouts, and optional proxy/jump-host behavior." remote_profile_format = SSHSession.arginfo() remote_profile_format.name = "remote_profile" remote_profile_format.doc = doc_remote_profile - return [remote_profile_format] + return [ + Argument( + "create_remote_root", + bool, + optional=True, + default=False, + doc=doc_create_remote_root, + ), + remote_profile_format, + ] diff --git a/dpdispatcher/machine.py b/dpdispatcher/machine.py index 578b97bf..68d5284c 100644 --- a/dpdispatcher/machine.py +++ b/dpdispatcher/machine.py @@ -175,6 +175,10 @@ def serialize(self, if_empty_remote_profile=False): machine_dict["context_type"] = self.context.__class__.__name__ machine_dict["local_root"] = self.context.init_local_root machine_dict["remote_root"] = self.context.init_remote_root + if hasattr(self.context, "clean_asynchronously"): + machine_dict["clean_asynchronously"] = self.context.clean_asynchronously + if hasattr(self.context, "create_remote_root"): + machine_dict["create_remote_root"] = self.context.create_remote_root if not if_empty_remote_profile: machine_dict["remote_profile"] = self.context.remote_profile else: diff --git a/examples/machine/ssh_proxy_command.json b/examples/machine/ssh_proxy_command.json index 8310397a..a3c7996a 100644 --- a/examples/machine/ssh_proxy_command.json +++ b/examples/machine/ssh_proxy_command.json @@ -3,6 +3,7 @@ "context_type": "SSHContext", "local_root": "./", "remote_root": "/home/user/work", + "create_remote_root": true, "remote_profile": { "hostname": "internal-server.company.com", "username": "user", diff --git a/tests/test_argcheck.py b/tests/test_argcheck.py index 637c5254..330c3b05 100644 --- a/tests/test_argcheck.py +++ b/tests/test_argcheck.py @@ -31,6 +31,69 @@ def test_machine_argcheck(self): } self.assertDictEqual(norm_dict, expected_dict) + def test_ssh_machine_argcheck(self): + from .context import SSHContext + + 
original_init = SSHContext.__init__ + + def fake_init( + self, + local_root, + remote_root, + remote_profile, + clean_asynchronously=False, + create_remote_root=False, + *args, + **kwargs, + ): + self.init_local_root = local_root + self.init_remote_root = remote_root + self.remote_profile = remote_profile + self.clean_asynchronously = clean_asynchronously + self.create_remote_root = create_remote_root + + SSHContext.__init__ = fake_init + try: + norm_dict = Machine.load_from_dict( + { + "batch_type": "slurm", + "context_type": "ssh", + "local_root": "./", + "remote_root": "/some/path", + "remote_profile": { + "hostname": "host", + "username": "user", + }, + "create_remote_root": True, + } + ).serialize() + finally: + SSHContext.__init__ = original_init + + expected_dict = { + "batch_type": "Slurm", + "context_type": "SSHContext", + "local_root": "./", + "remote_root": "/some/path", + "remote_profile": { + "hostname": "host", + "username": "user", + "port": 22, + "key_filename": None, + "passphrase": None, + "timeout": 10, + "totp_secret": None, + "tar_compress": True, + "look_for_keys": True, + "execute_command": None, + "proxy_command": None, + }, + "clean_asynchronously": False, + "create_remote_root": True, + "retry_count": 3, + } + self.assertDictEqual(norm_dict, expected_dict) + def test_resources_argcheck(self): norm_dict = Resources.load_from_dict( { diff --git a/tests/test_ssh_create_remote_root.py b/tests/test_ssh_create_remote_root.py new file mode 100644 index 00000000..8fc9d015 --- /dev/null +++ b/tests/test_ssh_create_remote_root.py @@ -0,0 +1,93 @@ +import os +import sys +import unittest +from unittest.mock import MagicMock + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +__package__ = "tests" + +from .context import SSHContext, setUpModule # noqa: F401 + + +class TestSSHCreateRemoteRoot(unittest.TestCase): + def test_recursive_mkdir_disabled_by_default(self): + calls = [] + context = 
SSHContext.__new__(SSHContext) + context.ssh_session = MagicMock() + context.ssh_session.sftp = MagicMock() + context.ssh_session.sftp.mkdir.side_effect = lambda path: calls.append(path) + + context._mkdir("/data/home/user/work", recursive=False) + + self.assertEqual(calls, ["/data/home/user/work"]) + + def test_recursive_mkdir_creates_missing_parents(self): + calls = [] + context = SSHContext.__new__(SSHContext) + context.ssh_session = MagicMock() + context.ssh_session.sftp = MagicMock() + + def mkdir(path): + calls.append(path) + if path in {"/data", "/data/home/user/work"}: + raise OSError("already exists") + + context.ssh_session.sftp.mkdir.side_effect = mkdir + + context._mkdir("/data/home/user/work", recursive=True) + + self.assertEqual( + calls, + [ + "/data", + "/data/home", + "/data/home/user", + "/data/home/user/work", + ], + ) + + def test_machine_roundtrip_keeps_create_remote_root(self): + machine_dict = { + "batch_type": "Shell", + "context_type": "SSHContext", + "local_root": "./", + "remote_root": "/some/path", + "clean_asynchronously": False, + "create_remote_root": True, + "remote_profile": { + "hostname": "example.com", + "username": "alice", + }, + } + + from .context import Machine + + original_init = SSHContext.__init__ + + def fake_init( + self, + local_root, + remote_root, + remote_profile, + clean_asynchronously=False, + create_remote_root=False, + *args, + **kwargs, + ): + self.init_local_root = local_root + self.init_remote_root = remote_root + self.remote_profile = remote_profile + self.clean_asynchronously = clean_asynchronously + self.create_remote_root = create_remote_root + + SSHContext.__init__ = fake_init + try: + machine = Machine.load_from_dict(machine_dict) + serialized = machine.serialize() + finally: + SSHContext.__init__ = original_init + + self.assertTrue(serialized["create_remote_root"]) + self.assertFalse(serialized["clean_asynchronously"]) + self.assertEqual(serialized["remote_root"], "/some/path") + 
self.assertEqual(serialized["remote_profile"]["hostname"], "example.com")