diff --git a/lib/trajectory.ex b/lib/trajectory.ex index 680e1556..f6117cfe 100644 --- a/lib/trajectory.ex +++ b/lib/trajectory.ex @@ -45,6 +45,9 @@ defmodule LangChain.Trajectory do %{name: "get_forecast", arguments: nil} ]) + # Assert relative ordering: search happened before answer + Trajectory.called_before?(trajectory, "search", "answer") + # Filter tool calls by name Trajectory.calls_by_name(trajectory, "search") @@ -339,6 +342,76 @@ defmodule LangChain.Trajectory do end end + @doc """ + Return `true` when tool `name_a` was called before tool `name_b`. + + This asserts *relative ordering* of two tool calls regardless of how many + other calls happen in between — the common middle ground between `:superset` + (containment, order-independent) and `:strict` (whole sequence, exact count). + It answers "did the agent call `search` at some point before it called + `answer`?". + + `trajectory` can be a `Trajectory` struct, an `LLMChain`, or a bare list of + `%{name: ..., arguments: ...}` tool call maps. + + ## Semantics + + Ordering is evaluated over the flat, ordered `tool_calls` list (the same + order `calls_by_turn/1` exposes). Calls emitted in the same assistant message + are ordered by their position in that list. + + Returns `true` when *any* `name_a` call precedes *any* `name_b` call — i.e. + `min(index of name_a) < max(index of name_b)`. With interleaved calls both + argument orders can return `true`; a single call is never before itself. + + ## Missing tools + + By default, if either tool was never called this returns `false` (so + `refute_called_before` passes vacuously). Pass `require_both: true` to instead + raise an `ArgumentError` when either tool is absent — use this when a missing + tool indicates a broken eval and should be surfaced rather than silently + collapsing to `false`. 
 + + ## Options + + * `:require_both` — when `true`, raise `ArgumentError` if either tool was + never called (default `false`) + + ## Examples + + Trajectory.called_before?(trajectory, "search", "answer") + #=> true + + # Detect a missing tool instead of passing silently + Trajectory.called_before?(trajectory, "search", "answer", require_both: true) + """ + @spec called_before?(t() | LLMChain.t() | [tool_call_map()], String.t(), String.t(), keyword()) :: + boolean() + def called_before?(trajectory, name_a, name_b, opts \\ []) + + def called_before?(%LLMChain{} = chain, name_a, name_b, opts) do + called_before?(from_chain(chain), name_a, name_b, opts) + end + + def called_before?(%Trajectory{tool_calls: calls}, name_a, name_b, opts) do + called_before?(calls, name_a, name_b, opts) + end + + def called_before?(calls, name_a, name_b, opts) when is_list(calls) do + first_a = Enum.find_index(calls, &(&1.name == name_a)) + last_b = find_last_index(calls, &(&1.name == name_b)) + + if Keyword.get(opts, :require_both, false) do + ensure_both_present!(name_a, first_a, name_b, last_b) + end + + case {first_a, last_b} do + {nil, _} -> false + {_, nil} -> false + {ia, ib} -> ia < ib + end + end + @doc """ Return all tool calls matching the given tool `name`. @@ -382,6 +455,29 @@ defmodule LangChain.Trajectory do # --- Private helpers --- + # Last index in `list` for which `fun` returns true, or nil. Counterpart to + # Enum.find_index/2 that keeps the last match; used for max(index of B). 
 + defp find_last_index(list, fun) do + list + |> Enum.with_index() + |> Enum.reduce(nil, fn {item, idx}, acc -> if fun.(item), do: idx, else: acc end) + end + + defp ensure_both_present!(name_a, first_a, name_b, last_b) do + missing = + cond do + is_nil(first_a) and is_nil(last_b) -> "#{inspect(name_a)} and #{inspect(name_b)} were" + is_nil(first_a) -> "#{inspect(name_a)} was" + is_nil(last_b) -> "#{inspect(name_b)} was" + true -> nil + end + + if missing do + raise ArgumentError, + ":require_both is set but #{missing} never called in the trajectory" + end + end + defp extract_metadata(llm) when is_struct(llm) do %{ model: Map.get(llm, :model), diff --git a/lib/trajectory/assertions.ex b/lib/trajectory/assertions.ex index 0a89d413..7dd690a5 100644 --- a/lib/trajectory/assertions.ex +++ b/lib/trajectory/assertions.ex @@ -83,6 +83,67 @@ defmodule LangChain.Trajectory.Assertions do end end + @doc """ + Assert that tool `name_a` was called before tool `name_b`. + + Wraps `LangChain.Trajectory.called_before?/4` and accepts the same options: + + * `:require_both` — when `true`, a missing tool raises `ArgumentError` instead + of failing the assertion (see `LangChain.Trajectory.called_before?/4`) + + ## Examples + + assert_called_before trajectory, "search", "answer" + assert_called_before trajectory, "search", "answer", require_both: true + """ + defmacro assert_called_before(actual, name_a, name_b, opts \\ []) do + quote do + actual_val = unquote(actual) + name_a_val = unquote(name_a) + name_b_val = unquote(name_b) + opts_val = unquote(opts) + + unless LangChain.Trajectory.called_before?(actual_val, name_a_val, name_b_val, opts_val) do + actual_calls = LangChain.Trajectory.Assertions.extract_tool_calls(actual_val) + + raise ExUnit.AssertionError, + left: actual_calls, + right: [name_a_val, name_b_val], + message: "Expected #{inspect(name_a_val)} to be called before #{inspect(name_b_val)}" + end + end + end + + @doc """ + Assert that tool `name_a` was NOT called before tool 
`name_b`. + + Accepts the same options as `assert_called_before/4`. Note that with the + default options a missing tool makes this pass vacuously; pass + `require_both: true` to surface a missing tool as an error instead. + + ## Examples + + refute_called_before trajectory, "write_file", "read_file" + """ + defmacro refute_called_before(actual, name_a, name_b, opts \\ []) do + quote do + actual_val = unquote(actual) + name_a_val = unquote(name_a) + name_b_val = unquote(name_b) + opts_val = unquote(opts) + + if LangChain.Trajectory.called_before?(actual_val, name_a_val, name_b_val, opts_val) do + actual_calls = LangChain.Trajectory.Assertions.extract_tool_calls(actual_val) + + raise ExUnit.AssertionError, + left: actual_calls, + right: [name_a_val, name_b_val], + message: + "Did not expect #{inspect(name_a_val)} to be called before #{inspect(name_b_val)}" + end + end + end + @doc false def extract_tool_calls(%LangChain.Chains.LLMChain{} = chain) do chain |> LangChain.Trajectory.from_chain() |> extract_tool_calls() diff --git a/test/trajectory_test.exs b/test/trajectory_test.exs index 1620ac09..c74e9fb7 100644 --- a/test/trajectory_test.exs +++ b/test/trajectory_test.exs @@ -670,6 +670,134 @@ defmodule LangChain.TrajectoryTest do end end + describe "called_before?/3,4" do + setup do + trajectory = %Trajectory{ + messages: [], + tool_calls: [ + %{name: "search", arguments: %{"query" => "weather"}}, + %{name: "summarize", arguments: nil}, + %{name: "answer", arguments: nil} + ], + token_usage: nil + } + + %{trajectory: trajectory} + end + + test "true when A precedes B", %{trajectory: trajectory} do + assert Trajectory.called_before?(trajectory, "search", "answer") + end + + test "false when A follows B", %{trajectory: trajectory} do + refute Trajectory.called_before?(trajectory, "answer", "search") + end + + test "false when A is missing", %{trajectory: trajectory} do + refute Trajectory.called_before?(trajectory, "missing", "answer") + end + + test "false when B is 
missing", %{trajectory: trajectory} do + refute Trajectory.called_before?(trajectory, "search", "missing") + end + + test "false when both are missing", %{trajectory: trajectory} do + refute Trajectory.called_before?(trajectory, "nope", "nada") + end + + test "false when A and B are the same single call", %{trajectory: trajectory} do + # A single call to a tool is not 'before' itself. + refute Trajectory.called_before?(trajectory, "answer", "answer") + end + + test "uses min(index A) < max(index B): any A before any B" do + # Interleaved: search, answer, search, answer. min(search)=0, max(answer)=3. + trajectory = %Trajectory{ + messages: [], + tool_calls: [ + %{name: "search", arguments: nil}, + %{name: "answer", arguments: nil}, + %{name: "search", arguments: nil}, + %{name: "answer", arguments: nil} + ], + token_usage: nil + } + + assert Trajectory.called_before?(trajectory, "search", "answer") + # And the reverse: min(answer)=1 < max(search)=2, so this also holds. + assert Trajectory.called_before?(trajectory, "answer", "search") + end + + test "repeated A and B where every B precedes every A" do + trajectory = %Trajectory{ + messages: [], + tool_calls: [ + %{name: "answer", arguments: nil}, + %{name: "answer", arguments: nil}, + %{name: "search", arguments: nil} + ], + token_usage: nil + } + + # min(search)=2, max(answer)=1 -> 2 < 1 is false + refute Trajectory.called_before?(trajectory, "search", "answer") + end + + test "accepts a bare list of tool calls" do + calls = [ + %{name: "search", arguments: nil}, + %{name: "answer", arguments: nil} + ] + + assert Trajectory.called_before?(calls, "search", "answer") + refute Trajectory.called_before?(calls, "answer", "search") + end + + test "accepts an LLMChain directly" do + tc1 = make_tool_call("search", %{"q" => "x"}, "call_1") + tc2 = make_tool_call("answer", nil, "call_2") + tr1 = make_tool_result("search", "result", "call_1") + + messages = [ + user_msg("go"), + assistant_msg(nil, tool_calls: [tc1]), + 
tool_msg([tr1]), + assistant_msg(nil, tool_calls: [tc2]) + ] + + chain = chain_with_messages(messages) + + assert Trajectory.called_before?(chain, "search", "answer") + end + + # :require_both — ensures a missing tool is detected rather than silently + # collapsing to false. + test "require_both raises when A was never called", %{trajectory: trajectory} do + assert_raise ArgumentError, ~r/"missing".*never called/, fn -> + Trajectory.called_before?(trajectory, "missing", "answer", require_both: true) + end + end + + test "require_both raises when B was never called", %{trajectory: trajectory} do + assert_raise ArgumentError, ~r/"missing".*never called/, fn -> + Trajectory.called_before?(trajectory, "search", "missing", require_both: true) + end + end + + test "require_both raises when both were never called", %{trajectory: trajectory} do + assert_raise ArgumentError, ~r/never called/, fn -> + Trajectory.called_before?(trajectory, "nope", "nada", require_both: true) + end + end + + test "require_both does not raise when both present, returns ordering", %{ + trajectory: trajectory + } do + assert Trajectory.called_before?(trajectory, "search", "answer", require_both: true) + refute Trajectory.called_before?(trajectory, "answer", "search", require_both: true) + end + end + describe "calls_by_name/2" do test "returns matching tool calls" do trajectory = %Trajectory{ @@ -897,4 +1025,85 @@ defmodule LangChain.TrajectoryTest do end end end + + describe "assert_called_before/3,4" do + setup do + trajectory = %Trajectory{ + messages: [], + tool_calls: [ + %{name: "search", arguments: nil}, + %{name: "answer", arguments: nil} + ], + token_usage: nil + } + + %{trajectory: trajectory} + end + + test "passes when A precedes B", %{trajectory: trajectory} do + assert_called_before(trajectory, "search", "answer") + end + + test "raises ExUnit.AssertionError when ordering does not hold", %{trajectory: trajectory} do + assert_raise ExUnit.AssertionError, ~r/before/, fn -> + 
assert_called_before(trajectory, "answer", "search") + end + end + + test "accepts an LLMChain directly" do + tc1 = make_tool_call("search", nil, "call_1") + tc2 = make_tool_call("answer", nil, "call_2") + + messages = [ + user_msg("go"), + assistant_msg(nil, tool_calls: [tc1, tc2]) + ] + + chain = chain_with_messages(messages) + assert_called_before(chain, "search", "answer") + end + + test "require_both surfaces a missing tool as a failure", %{trajectory: trajectory} do + assert_raise ArgumentError, ~r/never called/, fn -> + assert_called_before(trajectory, "search", "missing", require_both: true) + end + end + end + + describe "refute_called_before/3,4" do + setup do + trajectory = %Trajectory{ + messages: [], + tool_calls: [ + %{name: "search", arguments: nil}, + %{name: "answer", arguments: nil} + ], + token_usage: nil + } + + %{trajectory: trajectory} + end + + test "passes when ordering does not hold", %{trajectory: trajectory} do + refute_called_before(trajectory, "answer", "search") + end + + test "passes when a tool is missing (default)", %{trajectory: trajectory} do + refute_called_before(trajectory, "search", "missing") + end + + test "raises ExUnit.AssertionError when ordering unexpectedly holds", %{ + trajectory: trajectory + } do + assert_raise ExUnit.AssertionError, ~r/before/, fn -> + refute_called_before(trajectory, "search", "answer") + end + end + + test "require_both surfaces a missing tool instead of passing", %{trajectory: trajectory} do + assert_raise ArgumentError, ~r/never called/, fn -> + refute_called_before(trajectory, "search", "missing", require_both: true) + end + end + end end