diff --git a/kombu/transport/gcpubsub.py b/kombu/transport/gcpubsub.py index b47f620837..a5df5a622d 100644 --- a/kombu/transport/gcpubsub.py +++ b/kombu/transport/gcpubsub.py @@ -55,7 +55,7 @@ from _socket import gethostname from _socket import timeout as socket_timeout from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded, - PermissionDenied) + GoogleAPICallError, PermissionDenied) from google.api_core.retry import Retry from google.cloud import monitoring_v3 from google.cloud.monitoring_v3 import query @@ -64,6 +64,8 @@ from google.cloud.pubsub_v1.publisher import exceptions as publisher_exceptions from google.cloud.pubsub_v1.subscriber import \ exceptions as subscriber_exceptions +from google.cloud.pubsub_v1.types import Subscription +from google.protobuf.field_mask_pb2 import FieldMask from google.pubsub_v1 import gapic_version as package_version from kombu.entity import TRANSIENT_DELIVERY_MODE @@ -304,6 +306,18 @@ def _create_subscription( topic_path = topic_path or self.publisher.topic_path( project_id, topic_id ) + msg_retention = msg_retention or self.expiration_seconds + subscription_config = { + "name": subscription_path, + "topic": topic_path, + 'ack_deadline_seconds': self.ack_deadline_seconds, + 'expiration_policy': { + 'ttl': f'{self.expiration_seconds}s' + }, + 'message_retention_duration': f'{msg_retention}s', + **(filter_args or {}), + } + try: logger.debug( 'creating subscription: %s, topic: %s, filter: %s', @@ -311,21 +325,33 @@ def _create_subscription( topic_path, filter_args, ) - msg_retention = msg_retention or self.expiration_seconds - self.subscriber.create_subscription( - request={ - "name": subscription_path, - "topic": topic_path, - 'ack_deadline_seconds': self.ack_deadline_seconds, - 'expiration_policy': { - 'ttl': f'{self.expiration_seconds}s' - }, - 'message_retention_duration': f'{msg_retention}s', - **(filter_args or {}), - } - ) + self.subscriber.create_subscription(request=subscription_config) except AlreadyExists: - pass + # Subscription exists, update with current configuration + logger.debug('subscription exists, updating: %s', subscription_path) + try: + subscription = Subscription(subscription_config) + update_mask = FieldMask(paths=[ + 'ack_deadline_seconds', + 'expiration_policy.ttl', + 'message_retention_duration', + ]) + if filter_args: + update_mask.paths.append('filter') + + self.subscriber.update_subscription( + request={ + 'subscription': subscription, + 'update_mask': update_mask + } + ) + logger.info('subscription updated: %s', subscription_path) + except GoogleAPICallError as e: + logger.warning( + 'failed to update subscription: %s, error: %s', + subscription_path, + e, + ) return subscription_path def _delete(self, queue, *args, **kwargs): diff --git a/t/unit/transport/test_gcpubsub.py b/t/unit/transport/test_gcpubsub.py index 504eb50a4e..3b9a6703a0 100644 --- a/t/unit/transport/test_gcpubsub.py +++ b/t/unit/transport/test_gcpubsub.py @@ -8,7 +8,7 @@ import pytest from _socket import timeout as socket_timeout from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded, - PermissionDenied) + GoogleAPICallError, PermissionDenied) from google.pubsub_v1.types.pubsub import Subscription from kombu.transport.gcpubsub import (AtomicCounter, Channel, QueueDescriptor, @@ -79,6 +79,8 @@ def channel(): channel.subscriber = MagicMock() channel.publisher = MagicMock() channel.closed = False + channel.ack_deadline_seconds = 240 + channel.expiration_seconds = 86400 with patch.object( Channel, 'conninfo', new_callable=MagicMock ), patch.object( @@ -295,6 +297,115 @@ def test_create_subscription_protobuf_compat(self): } Subscription(request) + def test_create_subscription_updates_when_exists(self, channel): + """Test that subscription settings are always updated when subscription exists.""" + channel.project_id = "project_id" + topic_id = "topic_id" + subscription_path = "subscription_path" + topic_path = "topic_path" + channel.ack_deadline_seconds = 60 + channel.expiration_seconds = 86400 + + # Mock subscription already exists + channel.subscriber.subscription_path = MagicMock( + return_value=subscription_path + ) + channel.publisher.topic_path = MagicMock(return_value=topic_path) + channel.subscriber.create_subscription = MagicMock( + side_effect=AlreadyExists("Subscription exists") + ) + channel.subscriber.update_subscription = MagicMock() + + result = channel._create_subscription( + project_id=channel.project_id, + topic_id=topic_id, + subscription_path=subscription_path, + topic_path=topic_path, + ) + + assert result == subscription_path + channel.subscriber.create_subscription.assert_called_once() + # Verify update was called + channel.subscriber.update_subscription.assert_called_once() + update_call = channel.subscriber.update_subscription.call_args[1] + assert 'subscription' in update_call['request'] + assert 'update_mask' in update_call['request'] + + # Verify the subscription object contains correct values + subscription = update_call['request']['subscription'] + assert subscription.name == subscription_path + assert subscription.topic == topic_path + assert subscription.ack_deadline_seconds == 60 + assert subscription.expiration_policy.ttl.total_seconds() == 86400 + assert subscription.message_retention_duration.total_seconds() == 86400 + + # Verify update mask includes all fields + update_mask_paths = update_call['request']['update_mask'].paths + assert 'ack_deadline_seconds' in update_mask_paths + assert 'expiration_policy.ttl' in update_mask_paths + assert 'message_retention_duration' in update_mask_paths + + def test_create_subscription_with_filter(self, channel): + """Test that filter is included in update mask when present.""" + channel.project_id = "project_id" + topic_id = "topic_id" + subscription_path = "subscription_path" + topic_path = "topic_path" + filter_args = {'filter': 'attributes.routing_key="test"'} + + channel.subscriber.subscription_path = MagicMock( + return_value=subscription_path + ) + channel.publisher.topic_path = MagicMock(return_value=topic_path) + channel.subscriber.create_subscription = MagicMock( + side_effect=AlreadyExists("Subscription exists") + ) + channel.subscriber.update_subscription = MagicMock() + + channel._create_subscription( + project_id=channel.project_id, + topic_id=topic_id, + subscription_path=subscription_path, + topic_path=topic_path, + filter_args=filter_args, + ) + + # Verify filter was included in update mask + channel.subscriber.update_subscription.assert_called_once() + update_call = channel.subscriber.update_subscription.call_args[1] + update_mask_paths = update_call['request']['update_mask'].paths + assert 'filter' in update_mask_paths + + def test_create_subscription_handles_update_failure_gracefully(self, channel): + """Test that update failures don't crash the worker.""" + channel.project_id = "project_id" + topic_id = "topic_id" + subscription_path = "subscription_path" + topic_path = "topic_path" + + channel.subscriber.subscription_path = MagicMock( + return_value=subscription_path + ) + channel.publisher.topic_path = MagicMock(return_value=topic_path) + channel.subscriber.create_subscription = MagicMock( + side_effect=AlreadyExists("Subscription exists") + ) + # Mock update fails with Google API error + channel.subscriber.update_subscription = MagicMock( + side_effect=GoogleAPICallError("API Error") + ) + + # Should not raise exception - just log warning + result = channel._create_subscription( + project_id=channel.project_id, + topic_id=topic_id, + subscription_path=subscription_path, + topic_path=topic_path, + ) + + assert result == subscription_path + channel.subscriber.update_subscription.assert_called_once() + def test_delete(self, channel): queue = "test_queue" subscription_path = "projects/project-id/subscriptions/test_queue"