Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions lib/trajectory.ex
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ defmodule LangChain.Trajectory do
%{name: "get_forecast", arguments: nil}
])

# Assert relative ordering: search happened before answer
Trajectory.called_before?(trajectory, "search", "answer")

# Filter tool calls by name
Trajectory.calls_by_name(trajectory, "search")

Expand Down Expand Up @@ -339,6 +342,76 @@ defmodule LangChain.Trajectory do
end
end

@doc """
Return `true` when tool `name_a` was called before tool `name_b`.

This asserts *relative ordering* of two tool calls regardless of how many
other calls happen in between — the common middle ground between `:superset`
(containment, order-independent) and `:strict` (whole sequence, exact count).
It answers "did the agent call `search` at some point before it called
`answer`?".

`trajectory` can be a `Trajectory` struct, an `LLMChain`, or a bare list of
`%{name: ..., arguments: ...}` tool call maps.

## Semantics

Ordering is evaluated over the flat, ordered `tool_calls` list (the same
order `calls_by_turn/1` exposes). Calls emitted in the same assistant message
are ordered by their position in that list.

Returns `true` when *any* `name_a` call precedes *any* `name_b` call — i.e.
`min(index of name_a) < max(index of name_b)`. This matches the natural
reading of "A happened before B" and tolerates interleaving.

## Missing tools

By default, if either tool was never called this returns `false` (so
`refute_called_before` passes vacuously). Pass `require_both: true` to instead
raise an `ArgumentError` when either tool is absent — use this when a missing
tool indicates a broken eval and should be surfaced rather than silently
collapsing to `false`.

## Options

* `:require_both` — when `true`, raise `ArgumentError` if either tool was
never called (default `false`)

## Examples

Trajectory.called_before?(trajectory, "search", "answer")
#=> true

# Detect a missing tool instead of passing silently
Trajectory.called_before?(trajectory, "search", "answer", require_both: true)
"""
@spec called_before?(t() | LLMChain.t() | [tool_call_map()], String.t(), String.t(), keyword()) ::
boolean()
def called_before?(trajectory, name_a, name_b, opts \\ [])

def called_before?(%LLMChain{} = chain, name_a, name_b, opts) do
called_before?(from_chain(chain), name_a, name_b, opts)
end

def called_before?(%Trajectory{tool_calls: calls}, name_a, name_b, opts) do
called_before?(calls, name_a, name_b, opts)
end

def called_before?(calls, name_a, name_b, opts) when is_list(calls) do
first_a = Enum.find_index(calls, &(&1.name == name_a))
last_b = find_last_index(calls, &(&1.name == name_b))

if Keyword.get(opts, :require_both, false) do
ensure_both_present!(name_a, first_a, name_b, last_b)
end

case {first_a, last_b} do
{nil, _} -> false
{_, nil} -> false
{ia, ib} -> ia < ib
end
end

@doc """
Return all tool calls matching the given tool `name`.

Expand Down Expand Up @@ -382,6 +455,29 @@ defmodule LangChain.Trajectory do

# --- Private helpers ---

# Last index in `list` for which `fun` returns true, or nil. Mirrors
# Enum.find_index/2 but scanning from the right, used for max(index of B).
defp find_last_index(list, fun) do
list
|> Enum.with_index()
|> Enum.reduce(nil, fn {item, idx}, acc -> if fun.(item), do: idx, else: acc end)
end

defp ensure_both_present!(name_a, first_a, name_b, last_b) do
missing =
cond do
is_nil(first_a) and is_nil(last_b) -> "#{inspect(name_a)} and #{inspect(name_b)} were"
is_nil(first_a) -> "#{inspect(name_a)} was"
is_nil(last_b) -> "#{inspect(name_b)} was"
true -> nil
end

if missing do
raise ArgumentError,
":require_both is set but #{missing} never called in the trajectory"
end
end

defp extract_metadata(llm) when is_struct(llm) do
%{
model: Map.get(llm, :model),
Expand Down
61 changes: 61 additions & 0 deletions lib/trajectory/assertions.ex
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,67 @@ defmodule LangChain.Trajectory.Assertions do
end
end

@doc """
Assert that tool `name_a` was called before tool `name_b`.

Wraps `LangChain.Trajectory.called_before?/4` and accepts the same options:

* `:require_both` — when `true`, a missing tool raises rather than failing
the ordering check (see `LangChain.Trajectory.called_before?/4`)

## Examples

assert_called_before trajectory, "search", "answer"
assert_called_before trajectory, "search", "answer", require_both: true
"""
defmacro assert_called_before(actual, name_a, name_b, opts \\ []) do
quote do
actual_val = unquote(actual)
name_a_val = unquote(name_a)
name_b_val = unquote(name_b)
opts_val = unquote(opts)

unless LangChain.Trajectory.called_before?(actual_val, name_a_val, name_b_val, opts_val) do
actual_calls = LangChain.Trajectory.Assertions.extract_tool_calls(actual_val)

raise ExUnit.AssertionError,
left: actual_calls,
right: [name_a_val, name_b_val],
message: "Expected #{inspect(name_a_val)} to be called before #{inspect(name_b_val)}"
end
end
end

@doc """
Assert that tool `name_a` was NOT called before tool `name_b`.

Accepts the same options as `assert_called_before/4`. Note that with the
default options a missing tool makes this pass vacuously; pass
`require_both: true` to surface a missing tool as an error instead.

## Examples

refute_called_before trajectory, "write_file", "read_file"
"""
defmacro refute_called_before(actual, name_a, name_b, opts \\ []) do
quote do
actual_val = unquote(actual)
name_a_val = unquote(name_a)
name_b_val = unquote(name_b)
opts_val = unquote(opts)

if LangChain.Trajectory.called_before?(actual_val, name_a_val, name_b_val, opts_val) do
actual_calls = LangChain.Trajectory.Assertions.extract_tool_calls(actual_val)

raise ExUnit.AssertionError,
left: actual_calls,
right: [name_a_val, name_b_val],
message:
"Did not expect #{inspect(name_a_val)} to be called before #{inspect(name_b_val)}"
end
end
end

@doc false
def extract_tool_calls(%LangChain.Chains.LLMChain{} = chain) do
chain |> LangChain.Trajectory.from_chain() |> extract_tool_calls()
Expand Down
Loading