Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions contentcuration/kolibri_public/tests/test_channelmetadata_viewset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from kolibri_public.tests.utils.mixer import KolibriPublicMixer
from le_utils.constants.labels.subjects import SUBJECTSLIST

from contentcuration.models import Channel
from contentcuration.models import ChannelVersion
from contentcuration.models import ContentNode
from contentcuration.models import Country
from contentcuration.models import SecretToken
from contentcuration.tests import testdata
from contentcuration.tests.base import StudioAPITestCase
from contentcuration.tests.helpers import reverse_with_query
Expand Down Expand Up @@ -159,3 +163,193 @@ def test_filter_by_countries(self):

self.assertCountEqual(response1["countries"], ["C1", "C3"])
self.assertCountEqual(response2["countries"], ["C1", "C2", "C3"])


class ChannelMetadataTokenFilterTestCase(StudioAPITestCase):
    """
    Test cases for token-based filtering in ChannelMetadataViewSet.

    Covers lookup via a Channel's primary secret token, lookup via a
    ChannelVersion token (which overlays version-specific metadata),
    interaction with the other query filters, token normalization,
    and the empty result for unknown tokens.
    """

    def setUp(self):
        super().setUp()
        self.user = testdata.user("any@user.com")
        self.client.force_authenticate(self.user)
        # Two known-valid subject labels to use as version categories.
        self.categories = [
            SUBJECTSLIST[0],
            SUBJECTSLIST[1],
        ]

    def _create_channel_with_main_tree(self, mixer):
        """
        Helper method to create a private, non-deleted Channel with a
        published main_tree, plus a kolibri_public ContentNode suitable
        for use as a ChannelMetadata root.

        Returns:
            tuple: (Channel, kolibri_public ContentNode)
        """
        root_node = ContentNode.objects.create(published=True)
        channel = Channel.objects.create(
            actor_id=self.user.id,
            deleted=False,
            public=False,
            main_tree=root_node,
        )
        public_root_node = mixer.blend("kolibri_public.ContentNode")
        return channel, public_root_node

    def _create_channel_with_token(self, mixer, token_string):
        """
        Helper method to create a Channel with a primary secret token and
        a matching public ChannelMetadata record.

        Factors out the setup shared by the channel-token tests below.

        Returns:
            tuple: (Channel, ChannelMetadata)
        """
        channel, public_root_node = self._create_channel_with_main_tree(mixer)
        token = SecretToken.objects.create(token=token_string, is_primary=True)
        channel.secret_tokens.add(token)
        metadata = mixer.blend(
            ChannelMetadata, id=channel.id, root=public_root_node, public=False
        )
        return channel, metadata

    def test_filter_by_channel_token(self):
        """
        Test that filtering by a channel's secret_token returns the correct channel.
        """
        mixer = KolibriPublicMixer()

        _, metadata = self._create_channel_with_token(mixer, "testchanneltokenabc")

        response = self.client.get(
            reverse_with_query(
                "publicchannel-list",
                query={"token": "testchanneltokenabc"},
            ),
        )

        self.assertEqual(response.status_code, 200, response.content)
        self.assertEqual(len(response.data), 1)
        self.assertEqual(UUID(response.data[0]["id"]), UUID(metadata.id))
        # Token lookups never attach country data.
        self.assertEqual(response.data[0]["countries"], [])

    def test_filter_by_channel_version_token(self):
        """
        Test that filtering by a ChannelVersion's secret_token returns the correct channel
        with version-specific data.
        """
        mixer = KolibriPublicMixer()

        channel, public_root_node = self._create_channel_with_main_tree(mixer)
        channel.version = 5
        channel.save()

        token = SecretToken.objects.create(
            token="testversiontokenxyz", is_primary=False
        )
        ChannelVersion.objects.create(
            channel=channel,
            version=3,
            secret_token=token,
            size=123456789,
            resource_count=100,
            included_languages=["en", "es"],
            included_categories=self.categories,
        )

        # Channel-level values deliberately differ from the version snapshot
        # so the assertions can tell which source the endpoint returned.
        metadata = mixer.blend(
            ChannelMetadata,
            id=channel.id,
            root=public_root_node,
            published_size=999999999,
            total_resource_count=200,
            public=False,
        )

        response = self.client.get(
            reverse_with_query(
                "publicchannel-list",
                query={"token": "testversiontokenxyz"},
            ),
        )

        self.assertEqual(response.status_code, 200, response.content)
        self.assertEqual(len(response.data), 1)
        self.assertEqual(UUID(response.data[0]["id"]), UUID(metadata.id))
        # The version snapshot's values win over the channel-level values.
        self.assertEqual(response.data[0]["published_size"], 123456789)
        self.assertEqual(response.data[0]["total_resource_count"], 100)
        self.assertCountEqual(response.data[0]["included_languages"], ["en", "es"])
        self.assertCountEqual(response.data[0]["categories"], self.categories)
        self.assertEqual(response.data[0]["countries"], [])

    def test_token_filter_disabled_when_token_not_provided(self):
        """
        Test that regular filters still work when no token is provided.
        """
        mixer = KolibriPublicMixer()

        metadata1 = mixer.blend(ChannelMetadata, public=True)
        mixer.blend(ChannelMetadata, public=False)

        response = self.client.get(
            reverse_with_query(
                "publicchannel-list",
                query={"public": "true"},
            ),
        )

        self.assertEqual(response.status_code, 200, response.content)
        self.assertEqual(len(response.data), 1)
        self.assertEqual(str(UUID(response.data[0]["id"])), str(metadata1.id))

    def test_token_filter_disables_other_filters(self):
        """
        Test that when a token is provided, other query parameters are ignored.
        """
        mixer = KolibriPublicMixer()

        _, metadata = self._create_channel_with_token(
            mixer, "testignorefilterstoken"
        )

        response = self.client.get(
            reverse_with_query(
                "publicchannel-list",
                # public=true would normally exclude this private channel;
                # the token should make the endpoint ignore it.
                query={"token": "testignorefilterstoken", "public": "true"},
            ),
        )

        self.assertEqual(response.status_code, 200, response.content)
        self.assertEqual(len(response.data), 1)
        self.assertEqual(UUID(response.data[0]["id"]), UUID(metadata.id))

    def test_token_normalization_removes_dashes(self):
        """
        Test that tokens are normalized by removing dashes.
        """
        mixer = KolibriPublicMixer()

        _, metadata = self._create_channel_with_token(mixer, "abcd1234efgh5678")

        response = self.client.get(
            reverse_with_query(
                "publicchannel-list",
                # The stored token has no dashes; the query value does.
                query={"token": "abcd-1234-efgh-5678"},
            ),
        )

        self.assertEqual(response.status_code, 200, response.content)
        self.assertEqual(len(response.data), 1)
        self.assertEqual(UUID(response.data[0]["id"]), UUID(metadata.id))

    def test_nonexistent_token_returns_empty_list(self):
        """
        Test that a non-existent token returns an empty list.
        """
        response = self.client.get(
            reverse_with_query(
                "publicchannel-list",
                query={"token": "nonexistent-token-12345"},
            ),
        )

        self.assertEqual(response.status_code, 200, response.content)
        self.assertEqual(len(response.data), 0)
113 changes: 107 additions & 6 deletions contentcuration/kolibri_public/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

from contentcuration.middleware.locale import locale_exempt
from contentcuration.middleware.session import session_exempt
from contentcuration.models import Channel
from contentcuration.models import ChannelVersion
from contentcuration.models import Country
from contentcuration.models import generate_storage_url
from contentcuration.utils.pagination import ValuesViewsetCursorPagination
Expand Down Expand Up @@ -176,14 +178,115 @@ class ChannelMetadataViewSet(ReadOnlyValuesViewset):
"lang_name": "root__lang__native_name",
}

def get_queryset_from_token(self, token):
    """
    Retrieve a queryset of channels based on a token.

    This method checks both Channel.secret_tokens and ChannelVersion.secret_token
    to find matching channels. It returns an annotated queryset from the
    ChannelMetadata model.

    NOTE(review): the lookups below hit the private contentcuration models,
    but the returned queryset comes from the public ChannelMetadata table.
    Private channels that were never published to the public models will
    therefore not be found — confirm against the import-by-token spec,
    which expects the data to come from the private models.

    Args:
        token: The secret token string to look up

    Returns:
        tuple: (QuerySet, dict or None)
            - QuerySet: A queryset of ChannelMetadata objects
            - dict or None: Version-specific data for ChannelVersion tokens,
              or None for Channel tokens
    """
    # Tokens are displayed with dashes (e.g. "abcd-1234"); strip surrounding
    # whitespace and remove dashes before matching the stored value.
    normalized_token = token.replace("-", "").strip()

    # Primary lookup: live, published channels carrying this secret token.
    # Fetch the ids in a single query rather than exists() + a second query.
    channel_ids = list(
        Channel.objects.filter(
            secret_tokens__token=normalized_token,
            deleted=False,
            main_tree__published=True,
        ).values_list("id", flat=True)
    )
    if channel_ids:
        return models.ChannelMetadata.objects.filter(id__in=channel_ids), None

    # Fallback lookup: version-specific tokens. Collect per-channel snapshot
    # overrides so consolidate() can overlay them onto the serialized items.
    channel_versions = list(
        ChannelVersion.objects.filter(
            secret_token__token=normalized_token
        ).select_related("channel")
    )

    if channel_versions:
        channel_ids = [cv.channel_id for cv in channel_versions]

        version_data = {}
        for cv in channel_versions:
            version_data[str(cv.channel_id)] = {
                "published_size": cv.size,
                "total_resource_count": cv.resource_count,
                "last_updated": cv.date_published,
                "included_languages": cv.included_languages or [],
                "categories": cv.included_categories or [],
                "version": cv.version,
            }

        queryset = models.ChannelMetadata.objects.filter(id__in=channel_ids)

        return queryset, version_data

    # No channel or channel version matched the token.
    return models.ChannelMetadata.objects.none(), None

def get_queryset(self):
    """
    Build the base queryset for the viewset.

    When a ``token`` query parameter is supplied, the queryset is restricted
    to channels matching that token, and any version-specific overrides are
    stashed on the instance for ``consolidate()``. Otherwise every channel
    is returned.
    """
    token = self.request.query_params.get("token")
    if not token:
        self._version_data = None
        return models.ChannelMetadata.objects.all()
    queryset, version_data = self.get_queryset_from_token(token)
    self._token_queryset = queryset
    self._version_data = version_data
    return queryset

def filter_queryset(self, queryset):
    """
    Apply the standard filter backends unless a token was supplied.

    A ``token`` query parameter short-circuits filtering entirely, so token
    lookups are never narrowed by the other query parameters.
    """
    if self.request.query_params.get("token"):
        return queryset
    return super().filter_queryset(queryset)

def consolidate(self, items, queryset):
    """
    Post-process serialized channel items.

    Deduplicates by channel id (filtering can emit the same channel more
    than once; the last occurrence wins, as with a dict build), then
    delegates to the token-specific or regular consolidation path.
    """
    deduped = OrderedDict()
    for item in items:
        deduped[item["id"]] = item
    items = list(deduped.values())

    version_data = getattr(self, "_version_data", None)
    if version_data:
        return self._consolidate_token_items(items, version_data)
    return self._consolidate_regular_items(items, queryset)

def _consolidate_token_items(self, items, version_data):
for item in items:
channel_id = str(item["id"])
data = version_data.get(channel_id)
if data:
if data["published_size"] is not None:
item["published_size"] = data["published_size"]
if data["total_resource_count"] is not None:
item["total_resource_count"] = data["total_resource_count"]
if data["last_updated"] is not None:
item["last_updated"] = data["last_updated"]
if data["categories"]:
item["categories"] = data["categories"]
item["included_languages"] = data["included_languages"] or []
else:
item["included_languages"] = []
item["last_published"] = item["last_updated"]
item["countries"] = []
return items

def _consolidate_regular_items(self, items, queryset):
included_languages = {}
for (
channel_id,
Expand All @@ -196,9 +299,6 @@ def consolidate(self, items, queryset):
if channel_id not in included_languages:
included_languages[channel_id] = []
included_languages[channel_id].append(language_id)
for item in items:
item["included_languages"] = included_languages.get(item["id"], [])
item["last_published"] = item["last_updated"]

countries = {}
for (channel_id, country_code) in Country.objects.filter(
Expand All @@ -209,8 +309,9 @@ def consolidate(self, items, queryset):
countries[channel_id].append(country_code)

for item in items:
item["included_languages"] = included_languages.get(item["id"], [])
item["last_published"] = item["last_updated"]
item["countries"] = countries.get(item["id"], [])

return items


Expand Down