diff --git a/doc/context.md b/doc/context.md index 627158d3..ed1462e1 100644 --- a/doc/context.md +++ b/doc/context.md @@ -29,6 +29,8 @@ Since [`bash -l`](https://www.gnu.org/software/bash/manual/bash.html#Invoking-Ba Files will be copied to the remote directory via SSH channels before jobs start and copied back after jobs finish. To use SSH, one needs to provide necessary parameters in {dargs:argument}`remote_profile `, such as {dargs:argument}`username ` and {dargs:argument}`hostname `. +By default, DPDispatcher does not create missing parent directories of {dargs:argument}`remote_root `, so they must already exist on the remote machine; this helps catch typos in remote paths. If you want DPDispatcher to recursively create the whole directory tree for you, set {dargs:argument}`create_remote_root ` to `true`. + It's suggested to generate [SSH keys](https://help.ubuntu.com/community/SSH/OpenSSH/Keys) and transfer the public key to the remote server in advance, which is more secure than password authentication. Note that `SSH` context is [non-login](https://www.gnu.org/software/bash/manual/html_node/Bash-Startup-Files.html), so `bash_profile` files will not be executed outside the submission script. 
diff --git a/doc/getting-started.md b/doc/getting-started.md index 46dc4eb7..43730d58 100644 --- a/doc/getting-started.md +++ b/doc/getting-started.md @@ -72,6 +72,7 @@ where `machine.json` is "context_type": "SSHContext", "local_root": "/home/user123/workplace/22_new_project/", "remote_root": "/home/user123/dpdispatcher_work_dir/", + "create_remote_root": true, "remote_profile": { "hostname": "39.106.xx.xxx", "username": "user123", diff --git a/dpdispatcher/contexts/ssh_context.py b/dpdispatcher/contexts/ssh_context.py index 69bb46d8..9c35f79a 100644 --- a/dpdispatcher/contexts/ssh_context.py +++ b/dpdispatcher/contexts/ssh_context.py @@ -464,6 +464,7 @@ def __init__( remote_root, remote_profile, clean_asynchronously=False, + create_remote_root=False, *args, **kwargs, ): @@ -480,6 +481,7 @@ def __init__( # self.job_uuid = None self.clean_asynchronously = clean_asynchronously + self.create_remote_root = create_remote_root # self.job_uuid = job_uuid # if job_uuid: # self.job_uuid=job_uuid @@ -488,10 +490,7 @@ def __init__( self.ssh_session = SSHSession(**remote_profile) # self.temp_remote_root = os.path.join(self.ssh_session.get_session_root()) self.ssh_session.ensure_alive() - try: - self.sftp.mkdir(self.temp_remote_root) - except OSError: - pass + self._mkdir(self.temp_remote_root, recursive=self.create_remote_root) @classmethod def load_from_dict(cls, context_dict): @@ -511,12 +510,14 @@ def load_from_dict(cls, context_dict): remote_root = context_dict["remote_root"] remote_profile = context_dict["remote_profile"] clean_asynchronously = context_dict.get("clean_asynchronously", False) + create_remote_root = context_dict.get("create_remote_root", False) ssh_context = cls( local_root=local_root, remote_root=remote_root, remote_profile=remote_profile, clean_asynchronously=clean_asynchronously, + create_remote_root=create_remote_root, ) # local_root = jdata['local_root'] # ssh_session = SSHSession(**input) @@ -541,6 +542,28 @@ def close(self): def get_job_root(self): 
return self.remote_root + def _mkdir(self, remote_dir, recursive=False): + if not remote_dir: + return + + sftp = self.sftp + if not recursive: + try: + sftp.mkdir(remote_dir) + except OSError: + pass + return + + path = pathlib.PurePosixPath(remote_dir) + current = path.root if path.is_absolute() else "" + parts = path.parts[1:] if path.is_absolute() else path.parts + for part in parts: + current = pathlib.PurePosixPath(current, part).as_posix() + try: + sftp.mkdir(current) + except OSError: + pass + def bind_submission(self, submission): assert self.ssh_session is not None assert self.ssh_session.ssh is not None @@ -572,11 +595,7 @@ def bind_submission(self, submission): # if the new directory exists and the old directory does not contain files, then move the old directory self._rmtree(old_remote_root) - sftp = self.ssh_session.ssh.open_sftp() - try: - sftp.mkdir(self.remote_root) - except OSError: - pass + self._mkdir(self.remote_root, recursive=self.create_remote_root) # self.job_uuid = submission.submission_hash # dlog.debug("debug:SSHContext.bind_submission" @@ -1013,8 +1032,22 @@ def machine_subfields(cls) -> List[Argument]: list[Argument] machine subfields """ + doc_create_remote_root = ( + "Whether DPDispatcher should recursively create the configured SSH remote_root " + "when parent directories do not already exist. Keep this disabled by default " + "to avoid silently creating directories for a mistyped path." + ) doc_remote_profile = "SSH connection settings for the remote machine, including authentication, timeouts, and optional proxy/jump-host behavior." 
remote_profile_format = SSHSession.arginfo() remote_profile_format.name = "remote_profile" remote_profile_format.doc = doc_remote_profile - return [remote_profile_format] + return [ + Argument( + "create_remote_root", + bool, + optional=True, + default=False, + doc=doc_create_remote_root, + ), + remote_profile_format, + ] diff --git a/dpdispatcher/machine.py b/dpdispatcher/machine.py index 578b97bf..68d5284c 100644 --- a/dpdispatcher/machine.py +++ b/dpdispatcher/machine.py @@ -175,6 +175,10 @@ def serialize(self, if_empty_remote_profile=False): machine_dict["context_type"] = self.context.__class__.__name__ machine_dict["local_root"] = self.context.init_local_root machine_dict["remote_root"] = self.context.init_remote_root + if hasattr(self.context, "clean_asynchronously"): + machine_dict["clean_asynchronously"] = self.context.clean_asynchronously + if hasattr(self.context, "create_remote_root"): + machine_dict["create_remote_root"] = self.context.create_remote_root if not if_empty_remote_profile: machine_dict["remote_profile"] = self.context.remote_profile else: diff --git a/examples/machine/ssh_proxy_command.json b/examples/machine/ssh_proxy_command.json index 8310397a..a3c7996a 100644 --- a/examples/machine/ssh_proxy_command.json +++ b/examples/machine/ssh_proxy_command.json @@ -3,6 +3,7 @@ "context_type": "SSHContext", "local_root": "./", "remote_root": "/home/user/work", + "create_remote_root": true, "remote_profile": { "hostname": "internal-server.company.com", "username": "user", diff --git a/tests/test_argcheck.py b/tests/test_argcheck.py index 637c5254..330c3b05 100644 --- a/tests/test_argcheck.py +++ b/tests/test_argcheck.py @@ -31,6 +31,69 @@ def test_machine_argcheck(self): } self.assertDictEqual(norm_dict, expected_dict) + def test_ssh_machine_argcheck(self): + from .context import SSHContext + + original_init = SSHContext.__init__ + + def fake_init( + self, + local_root, + remote_root, + remote_profile, + clean_asynchronously=False, + 
create_remote_root=False, + *args, + **kwargs, + ): + self.init_local_root = local_root + self.init_remote_root = remote_root + self.remote_profile = remote_profile + self.clean_asynchronously = clean_asynchronously + self.create_remote_root = create_remote_root + + SSHContext.__init__ = fake_init + try: + norm_dict = Machine.load_from_dict( + { + "batch_type": "slurm", + "context_type": "ssh", + "local_root": "./", + "remote_root": "/some/path", + "remote_profile": { + "hostname": "host", + "username": "user", + }, + "create_remote_root": True, + } + ).serialize() + finally: + SSHContext.__init__ = original_init + + expected_dict = { + "batch_type": "Slurm", + "context_type": "SSHContext", + "local_root": "./", + "remote_root": "/some/path", + "remote_profile": { + "hostname": "host", + "username": "user", + "port": 22, + "key_filename": None, + "passphrase": None, + "timeout": 10, + "totp_secret": None, + "tar_compress": True, + "look_for_keys": True, + "execute_command": None, + "proxy_command": None, + }, + "clean_asynchronously": False, + "create_remote_root": True, + "retry_count": 3, + } + self.assertDictEqual(norm_dict, expected_dict) + def test_resources_argcheck(self): norm_dict = Resources.load_from_dict( { diff --git a/tests/test_ssh_create_remote_root.py b/tests/test_ssh_create_remote_root.py new file mode 100644 index 00000000..8fc9d015 --- /dev/null +++ b/tests/test_ssh_create_remote_root.py @@ -0,0 +1,93 @@ +import os +import sys +import unittest +from unittest.mock import MagicMock + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +__package__ = "tests" + +from .context import SSHContext, setUpModule # noqa: F401 + + +class TestSSHCreateRemoteRoot(unittest.TestCase): + def test_recursive_mkdir_disabled_by_default(self): + calls = [] + context = SSHContext.__new__(SSHContext) + context.ssh_session = MagicMock() + context.ssh_session.sftp = MagicMock() + context.ssh_session.sftp.mkdir.side_effect = lambda path: 
calls.append(path) + + context._mkdir("/data/home/user/work", recursive=False) + + self.assertEqual(calls, ["/data/home/user/work"]) + + def test_recursive_mkdir_creates_missing_parents(self): + calls = [] + context = SSHContext.__new__(SSHContext) + context.ssh_session = MagicMock() + context.ssh_session.sftp = MagicMock() + + def mkdir(path): + calls.append(path) + if path in {"/data", "/data/home/user/work"}: + raise OSError("already exists") + + context.ssh_session.sftp.mkdir.side_effect = mkdir + + context._mkdir("/data/home/user/work", recursive=True) + + self.assertEqual( + calls, + [ + "/data", + "/data/home", + "/data/home/user", + "/data/home/user/work", + ], + ) + + def test_machine_roundtrip_keeps_create_remote_root(self): + machine_dict = { + "batch_type": "Shell", + "context_type": "SSHContext", + "local_root": "./", + "remote_root": "/some/path", + "clean_asynchronously": False, + "create_remote_root": True, + "remote_profile": { + "hostname": "example.com", + "username": "alice", + }, + } + + from .context import Machine + + original_init = SSHContext.__init__ + + def fake_init( + self, + local_root, + remote_root, + remote_profile, + clean_asynchronously=False, + create_remote_root=False, + *args, + **kwargs, + ): + self.init_local_root = local_root + self.init_remote_root = remote_root + self.remote_profile = remote_profile + self.clean_asynchronously = clean_asynchronously + self.create_remote_root = create_remote_root + + SSHContext.__init__ = fake_init + try: + machine = Machine.load_from_dict(machine_dict) + serialized = machine.serialize() + finally: + SSHContext.__init__ = original_init + + self.assertTrue(serialized["create_remote_root"]) + self.assertFalse(serialized["clean_asynchronously"]) + self.assertEqual(serialized["remote_root"], "/some/path") + self.assertEqual(serialized["remote_profile"]["hostname"], "example.com")