Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ The frontend fetches `/api/v1/admin-auth/config`, shows a sign-in screen, redire

Public (auth-protected when `[[auth]]` is configured):

- `POST /mcp/{server}` — MCP JSON-RPC. `tools/list` annotates each tool with its policy disposition for the calling agent; a synthetic `atryum.rules.get` tool lets an agent inspect its applicable rules before deciding what to call.
- `POST /mcp/{server}` — MCP JSON-RPC. `tools/list` annotates each tool with its policy disposition for the calling agent; a synthetic `atryum_rules_get` tool lets an agent inspect its applicable rules before deciding what to call. The underscore name is intentional: dotted MCP tool names are spec-compliant, but some common harness implementations reject them.
- `GET /mcp/{server}` — Streamable HTTP / legacy SSE channel for MCP clients that need a long-lived event stream.
- `POST /api/v1/invocations` — direct invocation (Atryum executes).
- `POST /api/v1/external/invocations`, `PATCH /api/v1/external/invocations/{id}` — hook path (harness executes, Atryum gates and records).
Expand Down
80 changes: 77 additions & 3 deletions examples/amp-plugin/atryum.ts
Original file line number Diff line number Diff line change
Expand Up @@ -347,12 +347,86 @@ function describe(input: Record<string, unknown>): string {
return parts.join(" | ") || "(no string params)";
}

const RULES_CACHE_TTL_MS = 5 * 60 * 1000;
const rulesCache = new Map<string, { value: string; expiresAt: number }>();

function formatRulesContext(rules: unknown): string {
if (!rules || typeof rules !== "object") return "";
const record = rules as Record<string, unknown>;
const lines = [
"Atryum advisory rules visible to this harness before the gated call:",
`- server: ${String(record.server || SOURCE)}`,
`- tool: ${String(record.tool || "unknown")}`,
`- effective action: ${String(record.action || record.default_action || "unknown")}`,
];
if (record.matched_rule_id) {
lines.push(`- matched rule: ${String(record.matched_rule_id)}`);
}
if (record.generated_at) {
lines.push(`- as of: ${String(record.generated_at)}`);
}
if (Array.isArray(record.items) && record.items.length > 0) {
lines.push("- visible rules:");
for (const item of record.items.slice(0, 20)) {
const rule = item as Record<string, unknown>;
const guidance = rule.guidance ? ` (${String(rule.guidance)})` : "";
lines.push(
` - ${String(rule.id || "(unnamed)")}: ${String(rule.action)}${guidance}`
);
}
if (record.items.length > 20) {
lines.push(` - ...${record.items.length - 20} more`);
}
}
lines.push("- advisory only; Atryum re-checks policy during the actual gated call.");
return lines.join("\n");
}

async function rulesContext(tool: string): Promise<string> {
const cacheKey = [SOURCE, tool, ACCESS_TOKEN ? "auth" : "no-auth", AGENT_ID].join("\x00");
const cached = rulesCache.get(cacheKey);
if (cached !== undefined && cached.expiresAt > Date.now()) return cached.value;
if (cached !== undefined) rulesCache.delete(cacheKey);
const url = new URL("/api/v1/agent/rules", API);
url.searchParams.set("server", SOURCE);
url.searchParams.set("tool", tool);
if (AGENT_ID && !ACCESS_TOKEN) {
url.searchParams.set("agent_id", AGENT_ID);
}
const controller = new AbortController();
const timer = setTimeout(() => controller.abort(), 3000);
try {
const res = await fetch(url, { headers: atryumHeaders(), signal: controller.signal });
if (!res.ok) return "";
const result = formatRulesContext(await res.json());
rulesCache.set(cacheKey, {
value: result,
expiresAt: Date.now() + RULES_CACHE_TTL_MS,
});
return result;
} catch {
return "";
} finally {
clearTimeout(timer);
}
}

function combineContext(
rules: string,
chat: { context: string; count: number } | undefined
): { context: string; count: number | undefined } | undefined {
const context = [rules, chat?.context].filter(Boolean).join("\n\n");
if (!context) return undefined;
return { context, count: chat?.count };
}

async function submit(
tool: string,
toolUseID: string,
input: Record<string, unknown>,
chat: { context: string; count: number } | undefined
): Promise<InvocationResponse> {
const context = combineContext(await rulesContext(tool), chat);
const res = await fetch(`${API}/api/v1/external/invocations`, {
method: "POST",
headers: atryumHeaders(true),
Expand All @@ -363,9 +437,9 @@ async function submit(
input,
request_id: toolUseID,
thread_id: activeThreadID() || undefined,
chat_context: chat?.context,
chat_context_messages: chat?.count,
context: chat?.context,
chat_context: context?.context,
chat_context_messages: context?.count,
context: context?.context,
client_name: CLIENT_NAME,
client_version: CLIENT_VERSION || undefined,
agent_id: AGENT_ID || undefined,
Expand Down
80 changes: 77 additions & 3 deletions examples/pi-extension/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,79 @@ function describe(input: ToolInput): string {
return parts.join(" | ") || "(no string params)";
}

const RULES_CACHE_TTL_MS = 5 * 60 * 1000;
const rulesCache = new Map<string, { value: string; expiresAt: number }>();

function formatRulesContext(rules: unknown): string {
if (!rules || typeof rules !== "object") return "";
const record = rules as Record<string, unknown>;
const lines = [
"Atryum advisory rules visible to this harness before the gated call:",
`- server: ${String(record.server || SOURCE)}`,
`- tool: ${String(record.tool || "unknown")}`,
`- effective action: ${String(record.action || record.default_action || "unknown")}`,
];
if (record.matched_rule_id) {
lines.push(`- matched rule: ${String(record.matched_rule_id)}`);
}
if (record.generated_at) {
lines.push(`- as of: ${String(record.generated_at)}`);
}
if (Array.isArray(record.items) && record.items.length > 0) {
lines.push("- visible rules:");
for (const item of record.items.slice(0, 20)) {
const rule = item as Record<string, unknown>;
const guidance = rule.guidance ? ` (${String(rule.guidance)})` : "";
lines.push(
` - ${String(rule.id || "(unnamed)")}: ${String(rule.action)}${guidance}`
);
}
if (record.items.length > 20) {
lines.push(` - ...${record.items.length - 20} more`);
}
}
lines.push("- advisory only; Atryum re-checks policy during the actual gated call.");
return lines.join("\n");
}

async function rulesContext(tool: string): Promise<string> {
const cacheKey = [SOURCE, tool, ACCESS_TOKEN ? "auth" : "no-auth", AGENT_ID].join("\x00");
const cached = rulesCache.get(cacheKey);
if (cached !== undefined && cached.expiresAt > Date.now()) return cached.value;
if (cached !== undefined) rulesCache.delete(cacheKey);
const url = new URL("/api/v1/agent/rules", API);
url.searchParams.set("server", SOURCE);
url.searchParams.set("tool", tool);
if (AGENT_ID && !ACCESS_TOKEN) {
url.searchParams.set("agent_id", AGENT_ID);
}
const controller = new AbortController();
const timer = setTimeout(() => controller.abort(), 3000);
try {
const res = await fetch(url, { headers: atryumHeaders(), signal: controller.signal });
if (!res.ok) return "";
const result = formatRulesContext(await res.json());
rulesCache.set(cacheKey, {
value: result,
expiresAt: Date.now() + RULES_CACHE_TTL_MS,
});
return result;
} catch {
return "";
} finally {
clearTimeout(timer);
}
}

function combineContext(
rules: string,
chat: { context: string; count: number } | undefined
): { context: string; count: number | undefined } | undefined {
const context = [rules, chat?.context].filter(Boolean).join("\n\n");
if (!context) return undefined;
return { context, count: chat?.count };
}

function sessionID(ctx: unknown): string | undefined {
const manager = (ctx as { sessionManager?: unknown }).sessionManager as
| { getSessionFile?: () => string; sessionId?: string; id?: string }
Expand Down Expand Up @@ -245,6 +318,7 @@ async function submit(
threadID: string | undefined,
chat: { context: string; count: number } | undefined
): Promise<InvocationResponse> {
const context = combineContext(await rulesContext(tool), chat);
const res = await fetch(`${API}/api/v1/external/invocations`, {
method: "POST",
headers: atryumHeaders(true),
Expand All @@ -255,9 +329,9 @@ async function submit(
input,
request_id: toolCallID,
thread_id: threadID,
chat_context: chat?.context,
chat_context_messages: chat?.count,
context: chat?.context,
chat_context: context?.context,
chat_context_messages: context?.count,
context: context?.context,
agent_id: AGENT_ID || undefined,
client_name: CLIENT_NAME,
client_version: CLIENT_VERSION || undefined,
Expand Down
73 changes: 70 additions & 3 deletions examples/shared-agent-hook/atryum-hook.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,72 @@ function describe(input) {
return parts.join(" | ") || "(no string params)";
}

const RULES_CACHE_TTL_MS = 5 * 60 * 1000;
const rulesCache = new Map();

function formatRulesContext(rules) {
if (!rules || typeof rules !== "object") return "";
const lines = [
"Atryum advisory rules visible to this harness before the gated call:",
`- server: ${rules.server || SOURCE}`,
`- tool: ${rules.tool || "unknown"}`,
`- effective action: ${rules.action || rules.default_action || "unknown"}`,
];
if (rules.matched_rule_id) {
lines.push(`- matched rule: ${rules.matched_rule_id}`);
}
if (rules.generated_at) {
lines.push(`- as of: ${rules.generated_at}`);
}
if (Array.isArray(rules.items) && rules.items.length > 0) {
lines.push("- visible rules:");
for (const rule of rules.items.slice(0, 20)) {
const guidance = rule.guidance ? ` (${rule.guidance})` : "";
lines.push(` - ${rule.id || "(unnamed)"}: ${rule.action}${guidance}`);
}
if (rules.items.length > 20) {
lines.push(` - ...${rules.items.length - 20} more`);
}
}
lines.push("- advisory only; Atryum re-checks policy during the actual gated call.");
return lines.join("\n");
}

async function rulesContext(tool) {
const cacheKey = [SOURCE, tool, ACCESS_TOKEN ? "auth" : "no-auth", AGENT_ID].join("\x00");
const cached = rulesCache.get(cacheKey);
if (cached !== undefined && cached.expiresAt > Date.now()) return cached.value;
if (cached !== undefined) rulesCache.delete(cacheKey);
const url = new URL("/api/v1/agent/rules", API);
url.searchParams.set("server", SOURCE);
url.searchParams.set("tool", tool);
if (AGENT_ID && !ACCESS_TOKEN) {
url.searchParams.set("agent_id", AGENT_ID);
}
const controller = new AbortController();
const timer = setTimeout(() => controller.abort(), 3000);
try {
const res = await fetch(url, { headers: atryumHeaders(), signal: controller.signal });
if (!res.ok) return "";
const result = formatRulesContext(await res.json());
rulesCache.set(cacheKey, {
value: result,
expiresAt: Date.now() + RULES_CACHE_TTL_MS,
});
return result;
} catch {
return "";
} finally {
clearTimeout(timer);
}
}

function combineContext(rules, chat) {
const context = [rules, chat?.context].filter(Boolean).join("\n\n");
if (!context) return undefined;
return { context, count: chat?.count };
}

function normalizeRole(value) {
const role = String(value || "").toLowerCase();
if (role === "human") return "user";
Expand Down Expand Up @@ -462,6 +528,7 @@ async function submit(event) {
const input = toolInput(event);
const id = toolUseID(event);
const chat = await chatContext(event);
const context = combineContext(await rulesContext(name), chat);
const res = await fetch(`${API}/api/v1/external/invocations`, {
method: "POST",
headers: atryumHeaders(true),
Expand All @@ -472,9 +539,9 @@ async function submit(event) {
input,
request_id: id,
thread_id: sessionId(event) || undefined,
chat_context: chat?.context,
chat_context_messages: chat?.count,
context: chat?.context,
chat_context: context?.context,
chat_context_messages: context?.count,
context: context?.context,
client_name: CLIENT_NAME,
client_version: CLIENT_VERSION || undefined,
agent_id: AGENT_ID || undefined,
Expand Down
11 changes: 9 additions & 2 deletions internal/api/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,14 @@ func TestMCPAcceptsValidTokenAndPlumbsAgentID(t *testing.T) {

func TestAgentRulesRequiresAuthAndUsesTokenAgentID(t *testing.T) {
rig := newAuthTestRig(t)
tokenAgent := store.AgentRecord{ID: "agent-cuid-007", AgentIDs: `["agent-007"]`}
otherAgent := store.AgentRecord{ID: "agent-cuid-other", AgentIDs: `["other"]`}
rules := &stubRulesRepo{rules: []store.Rule{
{ID: "auto-rule", Action: invocation.RuleActionAutoApprove, ServerPatterns: []string{"amp"}, ToolPatterns: []string{"Read"}, Enabled: true, Order: 0},
{ID: "other-deny", Action: invocation.RuleActionAutoDeny, ServerPatterns: []string{"amp"}, ToolPatterns: []string{"Read"}, AgentCUIDs: []string{otherAgent.ID}, Enabled: true, Order: 0},
{ID: "auto-rule", Action: invocation.RuleActionAutoApprove, ServerPatterns: []string{"amp"}, ToolPatterns: []string{"Read"}, AgentCUIDs: []string{tokenAgent.ID}, Enabled: true, Order: 1},
}}
h := NewHandler(&stubService{}, stubServerService{}, nil, rules, nil, nil, nil, nil, nil, nil)
agents := &stubAgentsRepo{records: []store.AgentRecord{tokenAgent, otherAgent}}
h := NewHandler(&stubService{}, stubServerService{}, nil, rules, agents, nil, nil, nil, nil, nil)
h.SetAuthValidator(rig.v)
handler := h.Routes()

Expand Down Expand Up @@ -217,6 +221,9 @@ func TestAgentRulesRequiresAuthAndUsesTokenAgentID(t *testing.T) {
if resp.Action != invocation.RuleActionAutoApprove {
t.Fatalf("expected auto_approve action, got %q", resp.Action)
}
if len(resp.Items) != 1 || resp.Items[0].ID != "auto-rule" {
t.Fatalf("expected only token agent rule to be visible, got %#v", resp.Items)
}
}

func TestAgentRuntimeEndpointsRequireAuthWhenValidatorConfigured(t *testing.T) {
Expand Down
Loading
Loading