Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/khoj/processor/tools/online_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ async def search_with_exa(query: str, location: LocationData) -> Tuple[str, Dict
# Set up API endpoint and headers
exa_api_base = os.getenv("EXA_API_URL", "https://api.exa.ai")
exa_search_api_endpoint = f"{exa_api_base}/search"
headers = {"Content-Type": "application/json", "x-api-key": EXA_API_KEY}
headers = {"Content-Type": "application/json", "x-api-key": EXA_API_KEY, "x-exa-integration": "khoj"}

# Prepare request payload
country_code = location.country_code.upper() if location and location.country_code else "US"
Expand All @@ -210,8 +210,7 @@ async def search_with_exa(query: str, location: LocationData) -> Tuple[str, Dict
"contents": {
"text": False,
"highlights": {
"numSentences": 5,
"highlightsPerUrl": 1,
"maxCharacters": 500,
},
},
}
Expand Down Expand Up @@ -613,11 +612,10 @@ async def read_webpage_with_olostep(web_url: str, api_key: str, api_url: str) ->

async def read_webpage_with_exa(web_url: str, api_key: str, api_url: str) -> str:
exa_api_url = f"{api_url}/contents"
headers = {"Content-Type": "application/json", "x-api-key": api_key}
headers = {"Content-Type": "application/json", "x-api-key": api_key, "x-exa-integration": "khoj"}
params = {
"urls": [web_url],
"text": True,
"livecrawl": "fallback",
"livecrawlTimeout": 15000,
}

Expand Down
284 changes: 284 additions & 0 deletions tests/test_exa_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from khoj.processor.tools.online_search import read_webpage_with_exa, search_with_exa
from khoj.utils.rawconfig import LocationData


# ----------------------------------------------------------------------------------------------------
# Fixtures
# ----------------------------------------------------------------------------------------------------
EXA_SEARCH_RESPONSE = {
"results": [
{
"title": "Example Article",
"url": "https://example.com/article",
"highlights": ["This is a highlighted sentence from the article."],
"highlightScores": [0.95],
},
{
"title": "Another Result",
"url": "https://example.com/another",
"highlights": ["Another relevant snippet."],
"highlightScores": [0.88],
"text": "Full text content of the page.",
},
{
"title": "No Highlights Result",
"url": "https://example.com/no-highlights",
},
]
}

EXA_CONTENTS_RESPONSE = {
"results": [
{
"title": "Example Page",
"url": "https://example.com/page",
"text": "Full text content of the web page for reading.",
}
]
}

LOCATION = LocationData(city="San Francisco", region="California", country="United States", country_code="us")


# ----------------------------------------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------------------------------------
def _mock_aiohttp_response(status: int, json_data: dict):
"""Create a mock aiohttp response with the given status and JSON data."""
response = AsyncMock()
response.status = status
response.json = AsyncMock(return_value=json_data)
response.text = AsyncMock(return_value=json.dumps(json_data))
response.raise_for_status = MagicMock()
return response


def _mock_aiohttp_session(response):
"""Create a mock aiohttp.ClientSession that returns the given response on post."""
session = AsyncMock()
session.post = MagicMock(return_value=AsyncMock(__aenter__=AsyncMock(return_value=response)))
ctx = AsyncMock()
ctx.__aenter__ = AsyncMock(return_value=session)
ctx.__aexit__ = AsyncMock(return_value=False)
return ctx


# ----------------------------------------------------------------------------------------------------
# Tests: search_with_exa
# ----------------------------------------------------------------------------------------------------
class TestSearchWithExa:
@pytest.mark.anyio
@patch("khoj.processor.tools.online_search.EXA_API_KEY", "test-api-key")
@patch("khoj.processor.tools.online_search.aiohttp.ClientSession")
async def test_search_returns_organic_results(self, mock_session_cls):
"""Exa search should return results in the expected organic format."""
response = _mock_aiohttp_response(200, EXA_SEARCH_RESPONSE)
mock_session_cls.return_value = _mock_aiohttp_session(response)

query, results = await search_with_exa("test query", LOCATION)

assert query == "test query"
assert "organic" in results
assert len(results["organic"]) == 3

@pytest.mark.anyio
@patch("khoj.processor.tools.online_search.EXA_API_KEY", "test-api-key")
@patch("khoj.processor.tools.online_search.aiohttp.ClientSession")
async def test_search_result_fields(self, mock_session_cls):
"""Each organic result should have title, link, snippet, and content fields."""
response = _mock_aiohttp_response(200, EXA_SEARCH_RESPONSE)
mock_session_cls.return_value = _mock_aiohttp_session(response)

_, results = await search_with_exa("test query", LOCATION)
first = results["organic"][0]

assert first["title"] == "Example Article"
assert first["link"] == "https://example.com/article"
assert first["snippet"] == "This is a highlighted sentence from the article."
assert first["content"] is None # text was not requested

@pytest.mark.anyio
@patch("khoj.processor.tools.online_search.EXA_API_KEY", "test-api-key")
@patch("khoj.processor.tools.online_search.aiohttp.ClientSession")
async def test_search_result_with_text_content(self, mock_session_cls):
"""Results that include text should populate the content field."""
response = _mock_aiohttp_response(200, EXA_SEARCH_RESPONSE)
mock_session_cls.return_value = _mock_aiohttp_session(response)

_, results = await search_with_exa("test query", LOCATION)
second = results["organic"][1]

assert second["content"] == "Full text content of the page."

@pytest.mark.anyio
@patch("khoj.processor.tools.online_search.EXA_API_KEY", "test-api-key")
@patch("khoj.processor.tools.online_search.aiohttp.ClientSession")
async def test_search_snippet_fallback_when_no_highlights(self, mock_session_cls):
"""Results without highlights should have None snippet."""
response = _mock_aiohttp_response(200, EXA_SEARCH_RESPONSE)
mock_session_cls.return_value = _mock_aiohttp_session(response)

_, results = await search_with_exa("test query", LOCATION)
third = results["organic"][2]

assert third["snippet"] is None

@pytest.mark.anyio
@patch("khoj.processor.tools.online_search.EXA_API_KEY", "test-api-key")
@patch("khoj.processor.tools.online_search.aiohttp.ClientSession")
async def test_search_empty_results(self, mock_session_cls):
"""Empty results from Exa should return empty dict."""
response = _mock_aiohttp_response(200, {"results": []})
mock_session_cls.return_value = _mock_aiohttp_session(response)

query, results = await search_with_exa("test query", LOCATION)

assert query == "test query"
assert results == {}

@pytest.mark.anyio
@patch("khoj.processor.tools.online_search.EXA_API_KEY", "test-api-key")
@patch("khoj.processor.tools.online_search.aiohttp.ClientSession")
async def test_search_api_error_returns_empty(self, mock_session_cls):
"""Non-200 response should return empty dict."""
response = _mock_aiohttp_response(500, {"error": "Internal server error"})
mock_session_cls.return_value = _mock_aiohttp_session(response)

query, results = await search_with_exa("test query", LOCATION)

assert query == "test query"
assert results == {}

@pytest.mark.anyio
@patch("khoj.processor.tools.online_search.EXA_API_KEY", "test-api-key")
@patch("khoj.processor.tools.online_search.aiohttp.ClientSession")
async def test_search_sends_integration_header(self, mock_session_cls):
"""Request should include the x-exa-integration tracking header."""
response = _mock_aiohttp_response(200, EXA_SEARCH_RESPONSE)
mock_session = AsyncMock()
mock_post_ctx = AsyncMock()
mock_post_ctx.__aenter__ = AsyncMock(return_value=response)
mock_post_ctx.__aexit__ = AsyncMock(return_value=False)
mock_session.post = MagicMock(return_value=mock_post_ctx)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_ctx.__aexit__ = AsyncMock(return_value=False)
mock_session_cls.return_value = mock_session_ctx

await search_with_exa("test query", LOCATION)

call_kwargs = mock_session.post.call_args
headers = call_kwargs.kwargs.get("headers") or call_kwargs[1].get("headers")
assert headers["x-exa-integration"] == "khoj"

@pytest.mark.anyio
@patch("khoj.processor.tools.online_search.EXA_API_KEY", "test-api-key")
@patch("khoj.processor.tools.online_search.aiohttp.ClientSession")
async def test_search_uses_maxcharacters_not_deprecated_params(self, mock_session_cls):
"""Payload should use maxCharacters for highlights, not deprecated numSentences/highlightsPerUrl."""
response = _mock_aiohttp_response(200, EXA_SEARCH_RESPONSE)
mock_session = AsyncMock()
mock_post_ctx = AsyncMock()
mock_post_ctx.__aenter__ = AsyncMock(return_value=response)
mock_post_ctx.__aexit__ = AsyncMock(return_value=False)
mock_session.post = MagicMock(return_value=mock_post_ctx)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_ctx.__aexit__ = AsyncMock(return_value=False)
mock_session_cls.return_value = mock_session_ctx

await search_with_exa("test query", LOCATION)

call_kwargs = mock_session.post.call_args
payload = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
highlights_config = payload["contents"]["highlights"]
assert "maxCharacters" in highlights_config
assert "numSentences" not in highlights_config
assert "highlightsPerUrl" not in highlights_config

@pytest.mark.anyio
@patch("khoj.processor.tools.online_search.EXA_API_KEY", "test-api-key")
@patch("khoj.processor.tools.online_search.aiohttp.ClientSession")
async def test_search_location_defaults_to_us(self, mock_session_cls):
"""When location has no country_code, should default to US."""
response = _mock_aiohttp_response(200, EXA_SEARCH_RESPONSE)
mock_session = AsyncMock()
mock_post_ctx = AsyncMock()
mock_post_ctx.__aenter__ = AsyncMock(return_value=response)
mock_post_ctx.__aexit__ = AsyncMock(return_value=False)
mock_session.post = MagicMock(return_value=mock_post_ctx)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_ctx.__aexit__ = AsyncMock(return_value=False)
mock_session_cls.return_value = mock_session_ctx

empty_location = LocationData(city=None, region=None, country=None, country_code=None)
await search_with_exa("test query", empty_location)

call_kwargs = mock_session.post.call_args
payload = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
assert payload["userLocation"] == "US"


# ----------------------------------------------------------------------------------------------------
# Tests: read_webpage_with_exa
# ----------------------------------------------------------------------------------------------------
class TestReadWebpageWithExa:
@pytest.mark.anyio
@patch("khoj.processor.tools.online_search.aiohttp.ClientSession")
async def test_read_webpage_returns_text(self, mock_session_cls):
"""Should return the text content of the page."""
response = _mock_aiohttp_response(200, EXA_CONTENTS_RESPONSE)
mock_session_cls.return_value = _mock_aiohttp_session(response)

result = await read_webpage_with_exa("https://example.com/page", "test-key", "https://api.exa.ai")

assert result == "Full text content of the web page for reading."

@pytest.mark.anyio
@patch("khoj.processor.tools.online_search.aiohttp.ClientSession")
async def test_read_webpage_sends_integration_header(self, mock_session_cls):
"""Request should include the x-exa-integration tracking header."""
response = _mock_aiohttp_response(200, EXA_CONTENTS_RESPONSE)
mock_session = AsyncMock()
mock_post_ctx = AsyncMock()
mock_post_ctx.__aenter__ = AsyncMock(return_value=response)
mock_post_ctx.__aexit__ = AsyncMock(return_value=False)
mock_session.post = MagicMock(return_value=mock_post_ctx)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_ctx.__aexit__ = AsyncMock(return_value=False)
mock_session_cls.return_value = mock_session_ctx

await read_webpage_with_exa("https://example.com/page", "test-key", "https://api.exa.ai")

call_kwargs = mock_session.post.call_args
headers = call_kwargs.kwargs.get("headers") or call_kwargs[1].get("headers")
assert headers["x-exa-integration"] == "khoj"

@pytest.mark.anyio
@patch("khoj.processor.tools.online_search.aiohttp.ClientSession")
async def test_read_webpage_does_not_send_deprecated_livecrawl(self, mock_session_cls):
"""Payload should not include the deprecated livecrawl field."""
response = _mock_aiohttp_response(200, EXA_CONTENTS_RESPONSE)
mock_session = AsyncMock()
mock_post_ctx = AsyncMock()
mock_post_ctx.__aenter__ = AsyncMock(return_value=response)
mock_post_ctx.__aexit__ = AsyncMock(return_value=False)
mock_session.post = MagicMock(return_value=mock_post_ctx)
mock_session_ctx = AsyncMock()
mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_ctx.__aexit__ = AsyncMock(return_value=False)
mock_session_cls.return_value = mock_session_ctx

await read_webpage_with_exa("https://example.com/page", "test-key", "https://api.exa.ai")

call_kwargs = mock_session.post.call_args
payload = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
assert "livecrawl" not in payload
assert payload["text"] is True