diff --git a/kombu/transport/gcpubsub.py b/kombu/transport/gcpubsub.py index 82453bd443..418aaf8f3a 100644 --- a/kombu/transport/gcpubsub.py +++ b/kombu/transport/gcpubsub.py @@ -63,7 +63,8 @@ from _socket import gethostname from _socket import timeout as socket_timeout from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded, - NotFound, PermissionDenied) + GoogleAPICallError, NotFound, + PermissionDenied) from google.api_core.retry import Retry from google.cloud import monitoring_v3 from google.cloud.monitoring_v3 import query @@ -72,6 +73,8 @@ from google.cloud.pubsub_v1.publisher import exceptions as publisher_exceptions from google.cloud.pubsub_v1.subscriber import \ exceptions as subscriber_exceptions +from google.cloud.pubsub_v1.types import Subscription +from google.protobuf.field_mask_pb2 import FieldMask from google.pubsub_v1 import gapic_version as package_version from kombu.entity import TRANSIENT_DELIVERY_MODE @@ -316,27 +319,51 @@ def _create_subscription( topic_path = topic_path or self.publisher.topic_path( project_id, topic_id ) + msg_retention = msg_retention or self.expiration_seconds + subscription_config = { + "name": subscription_path, + "topic": topic_path, + "ack_deadline_seconds": self.ack_deadline_seconds, + "expiration_policy": {"ttl": f"{self.expiration_seconds}s"}, + "message_retention_duration": f"{msg_retention}s", + "enable_exactly_once_delivery": self.enable_exactly_once_delivery, + **(filter_args or {}), + } try: logger.debug( - 'creating subscription: %s, topic: %s, filter: %s', + "creating subscription: %s, topic: %s, filter: %s", subscription_path, topic_path, filter_args, ) - msg_retention = msg_retention or self.expiration_seconds - self.subscriber.create_subscription( - request={ - "name": subscription_path, - "topic": topic_path, - "ack_deadline_seconds": self.ack_deadline_seconds, - "expiration_policy": {"ttl": f"{self.expiration_seconds}s"}, - "message_retention_duration": f"{msg_retention}s", - "enable_exactly_once_delivery": self.enable_exactly_once_delivery, - **(filter_args or {}), - } - ) + self.subscriber.create_subscription(request=subscription_config) except AlreadyExists: - pass + logger.debug( + "subscription exists, updating: %s", subscription_path + ) + try: + subscription = Subscription(subscription_config) + update_mask = FieldMask(paths=[ + "ack_deadline_seconds", + "expiration_policy.ttl", + "message_retention_duration", + "enable_exactly_once_delivery", + ]) + if filter_args: + update_mask.paths.append("filter") + self.subscriber.update_subscription( + request={ + "subscription": subscription, + "update_mask": update_mask, + } + ) + logger.info("subscription updated: %s", subscription_path) + except GoogleAPICallError as e: + logger.warning( + "failed to update subscription: %s, error: %s", + subscription_path, + e, + ) return subscription_path def _delete(self, queue, *args, **kwargs): diff --git a/t/unit/transport/test_gcpubsub.py b/t/unit/transport/test_gcpubsub.py index 7839e83db7..e6557f89f9 100644 --- a/t/unit/transport/test_gcpubsub.py +++ b/t/unit/transport/test_gcpubsub.py @@ -9,7 +9,8 @@ import pytest from _socket import timeout as socket_timeout from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded, - NotFound, PermissionDenied) + GoogleAPICallError, NotFound, + PermissionDenied) from google.pubsub_v1.types.pubsub import Subscription from kombu.transport.gcpubsub import (_ACK_MODIFY_BATCH_SIZE_DEFAULT, @@ -81,6 +82,8 @@ def channel(): channel.subscriber = MagicMock() channel.publisher = MagicMock() channel.closed = False + channel.ack_deadline_seconds = 240 + channel.expiration_seconds = 86400 with patch.object( Channel, 'conninfo', new_callable=MagicMock ), patch.object( @@ -294,6 +297,113 @@ def test_create_subscription_protobuf_compat(self): } Subscription(request) + def test_create_subscription_updates_when_exists(self, channel): + """Subscription settings are updated when the subscription exists.""" + channel.project_id = "project_id" + topic_id = "topic_id" + subscription_path = "subscription_path" + topic_path = "topic_path" + channel.ack_deadline_seconds = 60 + channel.expiration_seconds = 86400 + channel.enable_exactly_once_delivery = True + + channel.subscriber.subscription_path = MagicMock( + return_value=subscription_path + ) + channel.publisher.topic_path = MagicMock(return_value=topic_path) + channel.subscriber.create_subscription = MagicMock( + side_effect=AlreadyExists("Subscription exists") + ) + channel.subscriber.update_subscription = MagicMock() + + result = channel._create_subscription( + project_id=channel.project_id, + topic_id=topic_id, + subscription_path=subscription_path, + topic_path=topic_path, + ) + + assert result == subscription_path + channel.subscriber.create_subscription.assert_called_once() + channel.subscriber.update_subscription.assert_called_once() + update_call = channel.subscriber.update_subscription.call_args[1] + assert 'subscription' in update_call['request'] + assert 'update_mask' in update_call['request'] + + subscription = update_call['request']['subscription'] + assert subscription.name == subscription_path + assert subscription.topic == topic_path + assert subscription.ack_deadline_seconds == 60 + assert subscription.expiration_policy.ttl.total_seconds() == 86400 + assert subscription.message_retention_duration.total_seconds() == 86400 + assert subscription.enable_exactly_once_delivery is True + + update_mask_paths = update_call['request']['update_mask'].paths + assert 'ack_deadline_seconds' in update_mask_paths + assert 'expiration_policy.ttl' in update_mask_paths + assert 'message_retention_duration' in update_mask_paths + assert 'enable_exactly_once_delivery' in update_mask_paths + + def test_create_subscription_with_filter(self, channel): + """Filter is included in update mask when present.""" + channel.project_id = "project_id" + topic_id = "topic_id" + subscription_path = "subscription_path" + topic_path = "topic_path" + channel.enable_exactly_once_delivery = False + filter_args = {'filter': 'attributes.routing_key="test"'} + + channel.subscriber.subscription_path = MagicMock( + return_value=subscription_path + ) + channel.publisher.topic_path = MagicMock(return_value=topic_path) + channel.subscriber.create_subscription = MagicMock( + side_effect=AlreadyExists("Subscription exists") + ) + channel.subscriber.update_subscription = MagicMock() + + channel._create_subscription( + project_id=channel.project_id, + topic_id=topic_id, + subscription_path=subscription_path, + topic_path=topic_path, + filter_args=filter_args, + ) + + channel.subscriber.update_subscription.assert_called_once() + update_call = channel.subscriber.update_subscription.call_args[1] + update_mask_paths = update_call['request']['update_mask'].paths + assert 'filter' in update_mask_paths + + def test_create_subscription_handles_update_failure_gracefully(self, channel): + """Update failures must not crash the worker.""" + channel.project_id = "project_id" + topic_id = "topic_id" + subscription_path = "subscription_path" + topic_path = "topic_path" + channel.enable_exactly_once_delivery = False + + channel.subscriber.subscription_path = MagicMock( + return_value=subscription_path + ) + channel.publisher.topic_path = MagicMock(return_value=topic_path) + channel.subscriber.create_subscription = MagicMock( + side_effect=AlreadyExists("Subscription exists") + ) + channel.subscriber.update_subscription = MagicMock( + side_effect=GoogleAPICallError("API Error") + ) + + result = channel._create_subscription( + project_id=channel.project_id, + topic_id=topic_id, + subscription_path=subscription_path, + topic_path=topic_path, + ) + + assert result == subscription_path + channel.subscriber.update_subscription.assert_called_once() + def test_delete(self, channel): queue = "test_queue" subscription_path = "projects/project-id/subscriptions/test_queue"