Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 28 additions & 26 deletions api/integrations/launch_darkly/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from projects.tags.models import Tag
from segments.models import Condition, Segment, SegmentRule
from users.models import FFAdminUser
from util.db import closing_stale_connections
from util.util import iter_chunked_concat, truncate

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1113,34 +1114,35 @@ def process_import_request(

ld_client = LaunchDarklyClient(ld_token)

try:
ld_environments = ld_client.get_environments(project_key=ld_project_key)
ld_flags = ld_client.get_flags_by_envs(
project_key=ld_project_key,
environment_keys=[env["key"] for env in ld_environments],
)
ld_flag_tags = ld_client.get_flag_tags()
# ld_segment_tags = ld_client.get_segment_tags()
# Keyed by (segment, environment)
ld_segments: list[tuple[ld_types.UserSegment, str]] = []
for env in ld_environments:
ld_segments_for_env = ld_client.get_segments(
with closing_stale_connections():
try:
ld_environments = ld_client.get_environments(project_key=ld_project_key)
ld_flags = ld_client.get_flags_by_envs(
project_key=ld_project_key,
environment_key=env["key"],
environment_keys=[env["key"] for env in ld_environments],
)
for segment in ld_segments_for_env:
ld_segments.append((segment, env["key"]))

except RequestException as exc:
_log_error(
import_request=import_request,
error_message=(
f"{exc.__class__.__name__} "
f"{str(exc.response.status_code) + ' ' if exc.response else ''}"
+ f"when requesting {getattr(exc.request, 'path_url', 'unknown')}"
),
)
raise
ld_flag_tags = ld_client.get_flag_tags()
# ld_segment_tags = ld_client.get_segment_tags()
# Keyed by (segment, environment)
ld_segments: list[tuple[ld_types.UserSegment, str]] = []
for env in ld_environments:
ld_segments_for_env = ld_client.get_segments(
project_key=ld_project_key,
environment_key=env["key"],
)
for segment in ld_segments_for_env:
ld_segments.append((segment, env["key"]))

except RequestException as exc:
_log_error(
import_request=import_request,
error_message=(
f"{exc.__class__.__name__} "
f"{str(exc.response.status_code) + ' ' if exc.response else ''}"
+ f"when requesting {getattr(exc.request, 'path_url', 'unknown')}"
),
)
raise

# Create environments
environments_by_ld_environment_key = _create_environments_from_ld(
Expand Down
5 changes: 5 additions & 0 deletions api/tests/unit/integrations/launch_darkly/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def test_create_import_request__valid_project__returns_expected(
(Timeout(), "Timeout when requesting /expected_path"),
],
)
@pytest.mark.django_db(transaction=True)
def test_process_import_request__api_error__expected_status(
ld_client_mock: MagicMock,
ld_client_class_mock: MagicMock,
Expand All @@ -100,6 +101,7 @@ def test_process_import_request__api_error__expected_status(
assert import_request.status["error_messages"] == [expected_error_message]


@pytest.mark.django_db(transaction=True)
def test_process_import_request__success__expected_status( # type: ignore[no-untyped-def]
project: Project,
import_request: LaunchDarklyImportRequest,
Expand Down Expand Up @@ -259,6 +261,7 @@ def test_process_import_request__success__expected_status( # type: ignore[no-un
[tag.label for tag in tagged_feature.tags.all()] == ["testtag", "testtag2"]


@pytest.mark.django_db(transaction=True)
def test_process_import_request__valid_segments__imports_correctly( # type: ignore[no-untyped-def]
project: Project,
import_request: LaunchDarklyImportRequest,
Expand Down Expand Up @@ -459,6 +462,7 @@ def test_process_import_request__valid_segments__imports_correctly( # type: ign
assert trait_value == identity.identifier


@pytest.mark.django_db(transaction=True)
def test_process_import_request__valid_rules__imports_correctly( # type: ignore[no-untyped-def]
project: Project,
import_request: LaunchDarklyImportRequest,
Expand Down Expand Up @@ -555,6 +559,7 @@ def test_process_import_request__valid_rules__imports_correctly( # type: ignore
}


@pytest.mark.django_db(transaction=True)
def test_process_import_request__large_segments__correctly_imported(
request: pytest.FixtureRequest,
ld_client_class_mock: MagicMock,
Expand Down
17 changes: 17 additions & 0 deletions api/tests/unit/util/test_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from pytest_mock import MockerFixture

from util.db import closing_stale_connections


def test_closing_stale_connections__exit__calls_close_old_connections(
mocker: MockerFixture,
) -> None:
# Given
mock_close_old_connections = mocker.patch("util.db.close_old_connections")

# When
with closing_stale_connections():
pass

# Then
mock_close_old_connections.assert_called_once_with()
19 changes: 19 additions & 0 deletions api/util/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from collections.abc import Iterator
from contextlib import contextmanager

from django.db import close_old_connections


@contextmanager
def closing_stale_connections() -> Iterator[None]:
"""
Close any stale DB connections when the wrapped block exits.

Intended for blocks that may hold a DB connection idle for long enough
that the DB server (or an intermediate proxy) terminates it — e.g. an
HTTP call to a slow third-party API preceding a write phase.
"""
try:
yield
finally:
close_old_connections()
Loading