diff --git a/web/migrations/versions/add_user_id_to_debugger_func_args_.py b/web/migrations/versions/add_user_id_to_debugger_func_args_.py index d8cd046da5e..07f74a9f4c9 100644 --- a/web/migrations/versions/add_user_id_to_debugger_func_args_.py +++ b/web/migrations/versions/add_user_id_to_debugger_func_args_.py @@ -120,10 +120,18 @@ def upgrade(): if inspector.has_table(table_name): op.execute(stmt) - # --- Unique constraint on SharedServer(osid, user_id) --- - # Prevents duplicate SharedServer records from TOCTOU race. - # First remove duplicates (keep lowest id per osid+user_id). + # --- SharedServer cleanup and constraints --- if inspector.has_table('sharedserver'): + # Clean up orphaned SharedServer records whose osid + # references a Server that no longer exists. + conn.execute(sa.text( + "DELETE FROM sharedserver WHERE osid NOT IN " + "(SELECT id FROM server)" + )) + + # Deduplicate SharedServer records that would violate + # the unique constraint. Keep the record with the + # lowest id (oldest). if dialect == 'sqlite': op.execute( 'DELETE FROM sharedserver WHERE id NOT IN ' @@ -137,11 +145,51 @@ def upgrade(): 'AND s1.user_id = s2.user_id ' 'AND s1.id > s2.id' ) - with op.batch_alter_table('sharedserver') as batch: - batch.create_unique_constraint( - 'uq_sharedserver_osid_user', - ['osid', 'user_id'] - ) + + # Add missing columns to sharedserver (guard against + # re-runs where columns may already exist). + existing_cols = { + c['name'] for c in inspector.get_columns('sharedserver') + } + new_columns = [ + ('passexec_cmd', + sa.Column('passexec_cmd', sa.Text(), + nullable=True)), + ('passexec_expiration', + sa.Column('passexec_expiration', sa.Integer(), + nullable=True)), + ('kerberos_conn', + sa.Column('kerberos_conn', sa.Boolean(), + nullable=False, + server_default='false')), + ('tags', + sa.Column('tags', sa.JSON(), nullable=True)), + ('post_connection_sql', + sa.Column('post_connection_sql', sa.String(), + nullable=True)), + ] + cols_to_add = [ + col for name, col in new_columns + if name not in existing_cols + ] + if cols_to_add: + with op.batch_alter_table('sharedserver') as batch_op: + for col in cols_to_add: + batch_op.add_column(col) + + # Unique constraint prevents duplicate SharedServer + # records from TOCTOU race conditions. + existing_ucs = { + uc['name'] + for uc in inspector.get_unique_constraints( + 'sharedserver') + } + if 'uq_sharedserver_osid_user' not in existing_ucs: + with op.batch_alter_table('sharedserver') as batch: + batch.create_unique_constraint( + 'uq_sharedserver_osid_user', + ['osid', 'user_id'] + ) def downgrade(): diff --git a/web/pgadmin/browser/server_groups/servers/__init__.py b/web/pgadmin/browser/server_groups/servers/__init__.py index cbcf79a3cea..ffda5915195 100644 --- a/web/pgadmin/browser/server_groups/servers/__init__.py +++ b/web/pgadmin/browser/server_groups/servers/__init__.py @@ -172,9 +172,7 @@ def get_shared_server_properties(server, sharedserver): Return shared server properties. Overlays per-user SharedServer values onto the owner's Server - object. Security-sensitive fields that are absent from the - SharedServer model (passexec_cmd, post_connection_sql) are - suppressed for non-owners. + object so each non-owner sees their own customizations. The server is expunged from the SQLAlchemy session before mutation so that the owner's record is never dirtied. @@ -182,6 +180,9 @@ def get_shared_server_properties(server, sharedserver): :param sharedserver: :return: shared server (detached) """ + if sharedserver is None: + return server + # Detach from session so in-place mutations are never # flushed back to the owner's Server row. sess = object_session(server) @@ -224,13 +225,11 @@ def get_shared_server_properties(server, sharedserver): server.server_owner = sharedserver.server_owner server.password = sharedserver.password server.prepare_threshold = sharedserver.prepare_threshold - - # Suppress owner-only fields that are absent from SharedServer - # and dangerous when inherited (privilege escalation / code - # execution). - server.passexec_cmd = None - server.passexec_expiration = None - server.post_connection_sql = None + server.passexec_cmd = sharedserver.passexec_cmd + server.passexec_expiration = sharedserver.passexec_expiration + server.kerberos_conn = sharedserver.kerberos_conn + server.tags = sharedserver.tags + server.post_connection_sql = sharedserver.post_connection_sql return server @@ -477,7 +476,12 @@ def create_shared_server(data, gid): tunnel_prompt_password=0, shared=True, connection_params=safe_conn_params, - prepare_threshold=data.prepare_threshold + prepare_threshold=data.prepare_threshold, + passexec_cmd=None, + passexec_expiration=None, + kerberos_conn=False, + tags=None, + post_connection_sql=None ) db.session.add(shared_server) db.session.commit() @@ -998,8 +1002,21 @@ def _set_valid_attr_value(self, gid, data, config_param_map, server, if not crypt_key_present: raise CryptKeyMissing + # Fields that non-owners must never set on their + # SharedServer — they enable command/SQL execution + # or are owner-level concepts not on SharedServer. + _owner_only_fields = frozenset({ + 'passexec_cmd', 'passexec_expiration', + 'post_connection_sql', + 'db_res', 'db_res_type', + }) + for arg in config_param_map: if arg in data: + # Non-owners cannot set dangerous fields. + if _is_non_owner(server) and \ + arg in _owner_only_fields: + continue value = data[arg] if arg == 'password': value = encrypt(data[arg], crypt_key) @@ -1161,12 +1178,10 @@ def properties(self, gid, sid): 'db_res_type': server.db_res_type, 'passexec_cmd': server.passexec_cmd - if server.passexec_cmd and - not _is_non_owner(server) else None, + if server.passexec_cmd else None, 'passexec_expiration': server.passexec_expiration - if server.passexec_expiration and - not _is_non_owner(server) else None, + if server.passexec_expiration else None, 'service': server.service if server.service else None, 'use_ssh_tunnel': use_ssh_tunnel, 'tunnel_host': tunnel_host, @@ -1186,8 +1201,7 @@ def properties(self, gid, sid): 'connection_string': display_connection_str, 'prepare_threshold': server.prepare_threshold, 'tags': tags, - 'post_connection_sql': server.post_connection_sql - if not _is_non_owner(server) else None, + 'post_connection_sql': server.post_connection_sql, } return ajax_response(response) @@ -1605,6 +1619,13 @@ def connect(self, gid, sid, is_qt=False, server=None): # the API call is not made from SQL Editor or View/Edit Data tool if not manager.connection().connected() and not is_qt: manager.update(server) + # Re-suppress owner-only fields after update() which + # rebuilds them from the (overlaid) server object. + # Belt-and-suspenders: the overlay already defaults + # these to None, but this guards against direct DB edits. + if _is_non_owner(server): + manager.passexec = None + manager.post_connection_sql = None conn = manager.connection() # Get enc key diff --git a/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py b/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py index 7b41af9055c..4ec3e594aec 100644 --- a/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py +++ b/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py @@ -84,6 +84,11 @@ def _make_shared_server(**overrides): 'sslcert': '/home/nonowner/.ssl/cert.pem', 'connect_timeout': '10', }, + passexec_cmd=None, + passexec_expiration=None, + kerberos_conn=False, + tags=None, + post_connection_sql=None, ) defaults.update(overrides) ss = MagicMock() @@ -97,10 +102,12 @@ class TestGetSharedServerProperties(BaseTestGenerator): using mock objects.""" scenarios = [ - ('Merge suppresses passexec_cmd', - dict(test_method='test_suppresses_passexec')), - ('Merge suppresses post_connection_sql', - dict(test_method='test_suppresses_post_sql')), + ('Merge overlays passexec_cmd from SharedServer', + dict(test_method='test_overlays_passexec')), + ('Merge overlays post_connection_sql from SharedServer', + dict(test_method='test_overlays_post_sql')), + ('Merge overlays kerberos_conn and tags', + dict(test_method='test_overlays_kerberos_tags')), ('Merge strips owner SSL paths not in SharedServer', dict(test_method='test_strips_owner_ssl_paths')), ('Merge applies SharedServer SSL paths', @@ -111,6 +118,8 @@ class TestGetSharedServerProperties(BaseTestGenerator): dict(test_method='test_overrides_tunnel')), ('Merge handles None connection_params', dict(test_method='test_none_conn_params')), + ('Merge returns server unchanged when sharedserver is None', + dict(test_method='test_null_guard')), ] @patch('pgadmin.browser.server_groups.servers.' @@ -128,14 +137,41 @@ def _merge(self, server=None, ss=None): return ServerModule.get_shared_server_properties( server, ss) - def test_suppresses_passexec(self): + def test_overlays_passexec(self): + # SharedServer defaults have None - overlay copies that. result = self._merge() self.assertIsNone(result.passexec_cmd) self.assertIsNone(result.passexec_expiration) - - def test_suppresses_post_sql(self): + # If SharedServer has a value, it should appear. + ss = _make_shared_server( + passexec_cmd='/usr/bin/get-pw', + passexec_expiration=120) + result = self._merge(ss=ss) + self.assertEqual(result.passexec_cmd, '/usr/bin/get-pw') + self.assertEqual(result.passexec_expiration, 120) + + def test_overlays_post_sql(self): + # SharedServer defaults have None - overlay copies that. result = self._merge() self.assertIsNone(result.post_connection_sql) + # If SharedServer has a value, it should appear. + ss = _make_shared_server( + post_connection_sql='SET role reader;') + result = self._merge(ss=ss) + self.assertEqual( + result.post_connection_sql, 'SET role reader;') + + def test_overlays_kerberos_tags(self): + result = self._merge() + self.assertFalse(result.kerberos_conn) + self.assertIsNone(result.tags) + # With values set on SharedServer + ss = _make_shared_server( + kerberos_conn=True, + tags=[{'text': 'prod', 'color': '#f00'}]) + result = self._merge(ss=ss) + self.assertTrue(result.kerberos_conn) + self.assertEqual(len(result.tags), 1) def test_strips_owner_ssl_paths(self): result = self._merge() @@ -180,6 +216,18 @@ def test_none_conn_params(self): # Should not crash; connection_params becomes {} self.assertEqual(result.connection_params, {}) + def test_null_guard(self): + from pgadmin.browser.server_groups.servers import \ + ServerModule + server = _make_server() + # Call directly to bypass _merge's None replacement + result = ServerModule.get_shared_server_properties( + server, None) + # Should return server unchanged + self.assertEqual(result.name, 'OwnerServer') + self.assertEqual(result.passexec_cmd, + '/usr/bin/vault-get-secret') + class TestCreateSharedServerSanitization(BaseTestGenerator): """Verify create_shared_server() strips sensitive @@ -291,6 +339,7 @@ def test_no_session(self): # Should not crash result = ServerModule.get_shared_server_properties( server, ss) + # SharedServer defaults passexec_cmd to None self.assertIsNone(result.passexec_cmd) @@ -457,6 +506,79 @@ def test_owner_deletes(self, mock_cu, mock_ck): 1, server.id) +class TestOwnerOnlyFieldsGuard(BaseTestGenerator): + """Verify _set_valid_attr_value skips owner-only fields + for non-owners.""" + + scenarios = [ + ('Non-owner cannot set passexec_cmd', + dict(test_method='test_nonowner_passexec_blocked')), + ('Owner can set passexec_cmd', + dict(test_method='test_owner_passexec_allowed')), + ] + + def runTest(self): + getattr(self, self.test_method)() + + @patch(SRV_MODULE + '.get_crypt_key', + return_value=(True, b'key')) + @patch(SRV_MODULE + '.current_user') + def test_nonowner_passexec_blocked(self, mock_cu, mock_ck): + mock_cu.id = 200 # Non-owner + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server() + ss = _make_shared_server() + node = ServerNode.__new__(ServerNode) + node.delete_shared_server = MagicMock() + + data = { + 'passexec_cmd': '/evil/cmd', + 'post_connection_sql': 'DROP TABLE x;', + } + config_map = { + 'passexec_cmd': 'passexec_cmd', + 'post_connection_sql': 'post_connection_sql', + } + + node._set_valid_attr_value( + 1, data, config_map, server, ss) + + # SharedServer should NOT have these set + self.assertIsNone(ss.passexec_cmd) + self.assertIsNone(ss.post_connection_sql) + + @patch(SRV_MODULE + '.get_crypt_key', + return_value=(True, b'key')) + @patch(SRV_MODULE + '.current_user') + def test_owner_passexec_allowed(self, mock_cu, mock_ck): + mock_cu.id = 100 # Owner + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server() + node = ServerNode.__new__(ServerNode) + node.delete_shared_server = MagicMock() + + data = { + 'passexec_cmd': '/usr/bin/new-cmd', + 'post_connection_sql': 'SET role dba;', + } + config_map = { + 'passexec_cmd': 'passexec_cmd', + 'post_connection_sql': 'post_connection_sql', + } + + node._set_valid_attr_value( + 1, data, config_map, server, None) + + # Owner should have these set + self.assertEqual(server.passexec_cmd, '/usr/bin/new-cmd') + self.assertEqual( + server.post_connection_sql, 'SET role dba;') + + class TestGetSharedServerRaisesOnNone(BaseTestGenerator): """Verify get_shared_server() raises if SharedServer cannot be created.""" diff --git a/web/pgadmin/model/__init__.py b/web/pgadmin/model/__init__.py index 62d89ca9412..633b433af41 100644 --- a/web/pgadmin/model/__init__.py +++ b/web/pgadmin/model/__init__.py @@ -560,6 +560,13 @@ class SharedServer(db.Model, UserScopedMixin): shared = db.Column(db.Boolean(), nullable=False) connection_params = db.Column(MutableDict.as_mutable(types.JSON)) prepare_threshold = db.Column(db.Integer(), nullable=True) + passexec_cmd = db.Column(db.Text(), nullable=True) + passexec_expiration = db.Column(db.Integer(), nullable=True) + kerberos_conn = db.Column( + db.Boolean(), nullable=False, default=0 + ) + tags = db.Column(types.JSON) + post_connection_sql = db.Column(db.String(), nullable=True) class Macros(db.Model): diff --git a/web/pgadmin/utils/driver/psycopg3/__init__.py b/web/pgadmin/utils/driver/psycopg3/__init__.py index 0695e83f2a7..e65ae070394 100644 --- a/web/pgadmin/utils/driver/psycopg3/__init__.py +++ b/web/pgadmin/utils/driver/psycopg3/__init__.py @@ -85,8 +85,9 @@ def _restore_connections_from_session(self): manager = managers[str(server.id)] = \ ServerManager(server) # Suppress owner-only fields for non-owners - # of shared servers so passexec_cmd and - # post_connection_sql don't leak. + # of shared servers at the connection layer. + # The UI overlay handles the API layer; this + # handles ServerManager used by tools directly. if server.shared and \ server.user_id != current_user.id: manager.passexec = None @@ -154,7 +155,7 @@ def connection_manager(self, sid=None): # it cannot be None at this point. manager = ServerManager(server_data) # Suppress owner-only fields for non-owners of - # shared servers. + # shared servers at the connection layer. if config.SERVER_MODE and server_data.shared and \ server_data.user_id != current_user.id: manager.passexec = None