Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ We recommend using [uv](https://docs.astral.sh/uv/). It's super fast.
```bash
uv python install 3.9.19
uv python pin 3.9.19
uv venv env
uv venv
source env/bin/activate
uv sync --extra dev --extra test
pre-commit install
Expand Down
45 changes: 45 additions & 0 deletions posthog/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
to_values,
)
from posthog.utils import (
FlagCache,
SizeLimitedDict,
clean,
guess_timezone,
Expand Down Expand Up @@ -126,6 +127,9 @@ def __init__(
project_root=None,
privacy_mode=False,
before_send=None,
enable_flag_cache=True,
flag_cache_size=10000,
flag_cache_ttl=300,
):
self.queue = queue.Queue(max_queue_size)

Expand All @@ -151,6 +155,10 @@ def __init__(
)
self.poller = None
self.distinct_ids_feature_flags_reported = SizeLimitedDict(MAX_DICT_SIZE, set)
self.flag_cache = (
FlagCache(flag_cache_size, flag_cache_ttl) if enable_flag_cache else None
)
self.flag_definition_version = 0
self.disabled = disabled
self.disable_geoip = disable_geoip
self.historical_migration = historical_migration
Expand Down Expand Up @@ -707,6 +715,9 @@ def shutdown(self):

def _load_feature_flags(self):
try:
# Store old flags to detect changes
old_flags_by_key = self.feature_flags_by_key or {}

response = get(
self.personal_api_key,
f"/api/feature_flag/local_evaluation/?token={self.api_key}&send_cohorts",
Expand All @@ -718,6 +729,12 @@ def _load_feature_flags(self):
self.group_type_mapping = response["group_type_mapping"] or {}
self.cohorts = response["cohorts"] or {}

# Check if flag definitions changed and update version
if self.flag_cache and old_flags_by_key != self.feature_flags_by_key:
old_version = self.flag_definition_version
self.flag_definition_version += 1
self.flag_cache.invalidate_version(old_version)
Comment thread
dmarticus marked this conversation as resolved.

except APIError as e:
if e.status == 401:
self.log.error(
Expand All @@ -739,6 +756,10 @@ def _load_feature_flags(self):
self.group_type_mapping = {}
self.cohorts = {}

# Clear flag cache when quota limited
if self.flag_cache:
self.flag_cache.clear()

if self.debug:
raise APIError(
status=402,
Expand Down Expand Up @@ -889,6 +910,12 @@ def _get_feature_flag_result(
flag_result = FeatureFlagResult.from_value_and_payload(
key, lookup_match_value, payload
)

# Cache successful local evaluation
if self.flag_cache and flag_result:
self.flag_cache.set_cached_flag(
distinct_id, key, flag_result, self.flag_definition_version
)
elif not only_evaluate_locally:
try:
flag_details, request_id = self._get_feature_flag_details_from_decide(
Expand All @@ -902,12 +929,30 @@ def _get_feature_flag_result(
flag_result = FeatureFlagResult.from_flag_details(
flag_details, override_match_value
)

# Cache successful remote evaluation
if self.flag_cache and flag_result:
self.flag_cache.set_cached_flag(
distinct_id, key, flag_result, self.flag_definition_version
)

self.log.debug(
f"Successfully computed flag remotely: #{key} -> #{flag_result}"
)
except Exception as e:
self.log.exception(f"[FEATURE FLAGS] Unable to get flag remotely: {e}")

# Fallback to cached value if remote evaluation fails
if self.flag_cache:
stale_result = self.flag_cache.get_stale_cached_flag(
distinct_id, key
)
if stale_result:
self.log.info(
f"[FEATURE FLAGS] Using stale cached value for flag {key}"
)
flag_result = stale_result

if send_feature_flag_events:
self._capture_feature_flag_called(
distinct_id,
Expand Down
123 changes: 123 additions & 0 deletions posthog/test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
import unittest
from dataclasses import dataclass
from datetime import date, datetime, timedelta
Expand All @@ -12,6 +13,7 @@
from pydantic.v1 import BaseModel as BaseModelV1

from posthog import utils
from posthog.types import FeatureFlagResult

TEST_API_KEY = "kOOlRy2QlMY9jHZQv0bKz0FZyazBUoY8Arj0lFVNjs4"
FAKE_TEST_API_KEY = "random_key"
Expand Down Expand Up @@ -173,3 +175,124 @@ class TestDataClass:
"inner_optional": None,
},
}


class TestFlagCache(unittest.TestCase):
def setUp(self):
self.cache = utils.FlagCache(max_size=3, default_ttl=1)
self.flag_result = FeatureFlagResult.from_value_and_payload(
"test-flag", True, None
)

def test_cache_basic_operations(self):
distinct_id = "user123"
flag_key = "test-flag"
flag_version = 1

# Test cache miss
result = self.cache.get_cached_flag(distinct_id, flag_key, flag_version)
assert result is None

# Test cache set and hit
self.cache.set_cached_flag(
distinct_id, flag_key, self.flag_result, flag_version
)
result = self.cache.get_cached_flag(distinct_id, flag_key, flag_version)
assert result is not None
assert result.get_value()

def test_cache_ttl_expiration(self):
distinct_id = "user123"
flag_key = "test-flag"
flag_version = 1

# Set flag in cache
self.cache.set_cached_flag(
distinct_id, flag_key, self.flag_result, flag_version
)

# Should be available immediately
result = self.cache.get_cached_flag(distinct_id, flag_key, flag_version)
assert result is not None

# Wait for TTL to expire (1 second + buffer)
time.sleep(1.1)

# Should be expired
result = self.cache.get_cached_flag(distinct_id, flag_key, flag_version)
assert result is None

def test_cache_version_invalidation(self):
distinct_id = "user123"
flag_key = "test-flag"
old_version = 1
new_version = 2

# Set flag with old version
self.cache.set_cached_flag(distinct_id, flag_key, self.flag_result, old_version)

# Should hit with old version
result = self.cache.get_cached_flag(distinct_id, flag_key, old_version)
assert result is not None

# Should miss with new version
result = self.cache.get_cached_flag(distinct_id, flag_key, new_version)
assert result is None

# Invalidate old version
self.cache.invalidate_version(old_version)

# Should miss even with old version after invalidation
result = self.cache.get_cached_flag(distinct_id, flag_key, old_version)
assert result is None

def test_stale_cache_functionality(self):
distinct_id = "user123"
flag_key = "test-flag"
flag_version = 1

# Set flag in cache
self.cache.set_cached_flag(
distinct_id, flag_key, self.flag_result, flag_version
)

# Wait for TTL to expire
time.sleep(1.1)

# Should not get fresh cache
result = self.cache.get_cached_flag(distinct_id, flag_key, flag_version)
assert result is None

# Should get stale cache (within 1 hour default)
stale_result = self.cache.get_stale_cached_flag(distinct_id, flag_key)
assert stale_result is not None
assert stale_result.get_value()

def test_lru_eviction(self):
# Cache has max_size=3, so adding 4 users should evict the LRU one
flag_version = 1

# Add 3 users
for i in range(3):
user_id = f"user{i}"
self.cache.set_cached_flag(
user_id, "test-flag", self.flag_result, flag_version
)

# Access user0 to make it recently used
self.cache.get_cached_flag("user0", "test-flag", flag_version)

# Add 4th user, should evict user1 (least recently used)
self.cache.set_cached_flag("user3", "test-flag", self.flag_result, flag_version)

# user0 should still be there (was recently accessed)
result = self.cache.get_cached_flag("user0", "test-flag", flag_version)
assert result is not None

# user2 should still be there (was recently added)
result = self.cache.get_cached_flag("user2", "test-flag", flag_version)
assert result is not None

# user3 should be there (just added)
result = self.cache.get_cached_flag("user3", "test-flag", flag_version)
assert result is not None
117 changes: 117 additions & 0 deletions posthog/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import numbers
import re
import time
from collections import defaultdict
from dataclasses import asdict, is_dataclass
from datetime import date, datetime, timezone
Expand Down Expand Up @@ -157,6 +158,122 @@ def __setitem__(self, key, value):
super().__setitem__(key, value)


class FlagCacheEntry:
def __init__(self, flag_result, flag_definition_version, timestamp=None):
self.flag_result = flag_result
self.flag_definition_version = flag_definition_version
self.timestamp = timestamp or time.time()

def is_valid(self, current_time, ttl, current_flag_version):
time_valid = (current_time - self.timestamp) < ttl
version_valid = self.flag_definition_version == current_flag_version
return time_valid and version_valid

def is_stale_but_usable(self, current_time, max_stale_age=3600):
return (current_time - self.timestamp) < max_stale_age


class FlagCache:
def __init__(self, max_size=10000, default_ttl=300):
self.cache = {} # distinct_id -> {flag_key: FlagCacheEntry}
self.access_times = {} # distinct_id -> last_access_time
self.max_size = max_size
self.default_ttl = default_ttl

def get_cached_flag(self, distinct_id, flag_key, current_flag_version):
current_time = time.time()

if distinct_id not in self.cache:
return None

user_flags = self.cache[distinct_id]
if flag_key not in user_flags:
return None

entry = user_flags[flag_key]
if entry.is_valid(current_time, self.default_ttl, current_flag_version):
self.access_times[distinct_id] = current_time
return entry.flag_result

return None

def get_stale_cached_flag(self, distinct_id, flag_key, max_stale_age=3600):
current_time = time.time()

if distinct_id not in self.cache:
return None

user_flags = self.cache[distinct_id]
if flag_key not in user_flags:
return None

entry = user_flags[flag_key]
if entry.is_stale_but_usable(current_time, max_stale_age):
return entry.flag_result

return None

def set_cached_flag(
self, distinct_id, flag_key, flag_result, flag_definition_version
):
current_time = time.time()

# Evict LRU users if we're at capacity
if distinct_id not in self.cache and len(self.cache) >= self.max_size:
self._evict_lru()

# Initialize user cache if needed
if distinct_id not in self.cache:
self.cache[distinct_id] = {}

# Store the flag result
self.cache[distinct_id][flag_key] = FlagCacheEntry(
flag_result, flag_definition_version, current_time
)
self.access_times[distinct_id] = current_time

def invalidate_version(self, old_version):
users_to_remove = []

for distinct_id, user_flags in self.cache.items():
flags_to_remove = []
for flag_key, entry in user_flags.items():
if entry.flag_definition_version == old_version:
flags_to_remove.append(flag_key)

# Remove invalidated flags
for flag_key in flags_to_remove:
del user_flags[flag_key]

# Remove user entirely if no flags remain
if not user_flags:
users_to_remove.append(distinct_id)

# Clean up empty users
for distinct_id in users_to_remove:
del self.cache[distinct_id]
if distinct_id in self.access_times:
del self.access_times[distinct_id]

def _evict_lru(self):
if not self.access_times:
return

# Remove 20% of least recently used entries
sorted_users = sorted(self.access_times.items(), key=lambda x: x[1])
to_remove = max(1, len(sorted_users) // 5)

for distinct_id, _ in sorted_users[:to_remove]:
if distinct_id in self.cache:
del self.cache[distinct_id]
if distinct_id in self.access_times:
del self.access_times[distinct_id]

def clear(self):
self.cache.clear()
self.access_times.clear()


def convert_to_datetime_aware(date_obj):
if date_obj.tzinfo is None:
date_obj = date_obj.replace(tzinfo=timezone.utc)
Expand Down
Loading
Loading