From e65ed8877f05a344002201e4bb6ea55ae7082821 Mon Sep 17 00:00:00 2001 From: skysneakers Date: Tue, 14 Apr 2026 16:55:24 -0400 Subject: [PATCH 01/16] Adding Logit Lens --- package.json | 1 + scripts/api.sh | 2 +- workbench/_web/next.config.js | 2 +- workbench/_web/package.json | 1 + .../src/app/dev/logit-lens-intro/page.tsx | 66 ++++ .../[workspaceId]/components/ChartCard.tsx | 11 +- .../components/ChartCardsSidebar.tsx | 47 ++- .../components/LogitLensIntroArea.tsx | 127 ++++++ .../components/LogitLensIntroControls.tsx | 368 ++++++++++++++++++ .../components/LogitLensIntroDisplay.tsx | 133 +++++++ .../logit-lens-intro/[chartId]/layout.tsx | 10 + .../logit-lens-intro/[chartId]/page.tsx | 52 +++ .../src/app/workbench/[workspaceId]/page.tsx | 2 + workbench/_web/src/lib/api/chartApi.ts | 28 ++ .../_web/src/lib/api/logitLensIntroApi.ts | 91 +++++ .../_web/src/lib/queries/chartQueries.ts | 9 +- workbench/_web/src/types/charts.ts | 9 +- workbench/_web/src/types/logitLensIntro.ts | 16 + workbench/package-lock.json | 6 + 19 files changed, 970 insertions(+), 11 deletions(-) create mode 100644 package.json create mode 100644 workbench/_web/src/app/dev/logit-lens-intro/page.tsx create mode 100644 workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/components/LogitLensIntroArea.tsx create mode 100644 workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/components/LogitLensIntroControls.tsx create mode 100644 workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/components/LogitLensIntroDisplay.tsx create mode 100644 workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/layout.tsx create mode 100644 workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/page.tsx create mode 100644 workbench/_web/src/lib/api/logitLensIntroApi.ts create mode 100644 workbench/_web/src/types/logitLensIntro.ts create mode 100644 workbench/package-lock.json diff --git a/package.json b/package.json new file mode 100644 index 00000000..668e1889 --- /dev/null +++ b/package.json @@ -0,0 +1 @@ +{"dependencies": {}} \ No newline at end of file diff --git a/scripts/api.sh b/scripts/api.sh index b69f0e34..9b0b0cf4 100644 --- a/scripts/api.sh +++ b/scripts/api.sh @@ -1,4 +1,4 @@ #!/bin/bash cd workbench -uvicorn _api.main:app --host 0.0.0.0 --port 8000 --reload \ No newline at end of file +python -m uvicorn _api.main:app --host 0.0.0.0 --port 8000 --reload \ No newline at end of file diff --git a/workbench/_web/next.config.js b/workbench/_web/next.config.js index ed00fba2..5f7d1d7a 100644 --- a/workbench/_web/next.config.js +++ b/workbench/_web/next.config.js @@ -8,7 +8,7 @@ const __dirname = path.dirname(fileURLToPath(import.meta.url)); const nextConfig = { reactStrictMode: true, - transpilePackages: ["nnsightful"], + transpilePackages: ["nnsightful", "edulogitlens"], turbopack: { // Expand root so Turbopack can resolve the symlinked nnsightful package // which lives at ../../nnsightful (outside the default _web/ root) diff --git a/workbench/_web/package.json b/workbench/_web/package.json index e1e9129f..99a06af7 100644 --- a/workbench/_web/package.json +++ b/workbench/_web/package.json @@ -46,6 +46,7 @@ "d3-delaunay": "^6.0.4", "dotenv": "^17.2.1", "drizzle-orm": "^0.44.4", + "edulogitlens": "github:skysneakers/edulogitlens", "framer-motion": "^12.23.22", "html-to-image": "^1.11.13", "lexical": "^0.34.0", diff --git a/workbench/_web/src/app/dev/logit-lens-intro/page.tsx b/workbench/_web/src/app/dev/logit-lens-intro/page.tsx new file mode 100644 index 00000000..7d45a852 --- /dev/null +++ b/workbench/_web/src/app/dev/logit-lens-intro/page.tsx @@ -0,0 +1,66 @@ +"use client"; + +import { LogitLensGrid } from "edulogitlens"; +import type { LogitLensData, LogitCell } from "edulogitlens"; + +function generateMockData(): LogitLensData { + const tokens = [ + "The", "E", "iff", "el", "Tower", "is", "in", "the", + "city", "of", "Paris", ",", "France", ".", + ]; + + const layers = Array.from({ length: 12 }, (_, i) => i); + + const vocab = [ + "t", "bow", "illi", "Tower", "el", "France", "Paris", + "tower", "city", "of", "the", "in", "is", "a", "and", + "Eiff", "to", "built", "was", "meters", "at", "by", + "from", "with", "on", "for", "an", "stands", "tall", + ]; + + const data: LogitCell[][] = tokens.map((token) => { + return layers.map((_, layerIdx) => { + const convergence = layerIdx / layers.length; + + let primaryToken = token; + let prob: number; + + if (convergence < 0.3) { + primaryToken = vocab[Math.floor(Math.random() * vocab.length)]; + prob = 0.05 + Math.random() * 0.15; + } else if (convergence < 0.6) { + primaryToken = Math.random() > 0.5 ? token : vocab[Math.floor(Math.random() * vocab.length)]; + prob = 0.2 + Math.random() * 0.3; + } else { + primaryToken = token; + prob = 0.5 + convergence * 0.4 + Math.random() * 0.1; + } + prob = Math.min(prob, 0.95); + + const topTokens: { token: string; prob: number }[] = [{ token: primaryToken, prob }]; + let remaining = (1 - prob) * 0.4; + for (let i = 0; i < 14; i++) { + const candidate = vocab[Math.floor(Math.random() * vocab.length)]; + topTokens.push({ token: candidate, prob: remaining }); + remaining *= 0.7; + } + + return { token: primaryToken, probability: prob, topTokens }; + }); + }); + + return { tokens, layers, data }; +} + +const MOCK_DATA = generateMockData(); + +export default function DevLogitLensIntroPage() { + return ( +
+

Logit Lens Intro — Dev Preview

+
+ +
+
+ ); +} diff --git a/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCard.tsx b/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCard.tsx index 7960f394..31e1de26 100644 --- a/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCard.tsx +++ b/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCard.tsx @@ -2,7 +2,7 @@ import React from "react"; import { useParams, useRouter } from "next/navigation"; -import { Grid3X3, ChartLine, Trash2, Copy, MoreVertical, GitBranch } from "lucide-react"; +import { Grid3X3, ChartLine, Trash2, Copy, MoreVertical, GitBranch, GraduationCap } from "lucide-react"; import Image from "next/image"; import { ChartMetadata } from "@/types/charts"; import { cn } from "@/lib/utils"; @@ -37,6 +37,8 @@ export default function ChartCard({ metadata, handleDelete, canDelete }: ChartCa router.push(`/workbench/${workspaceId}/lens2/${chart.id}`); } else if (chart.toolType === "activation-patching" || chart.chartType === "activation-patching") { router.push(`/workbench/${workspaceId}/activation-patching/${chart.id}`); + } else if (chart.toolType === "logit-lens-intro" || chart.chartType === "logit-lens-intro") { + router.push(`/workbench/${workspaceId}/logit-lens-intro/${chart.id}`); } else { router.push(`/workbench/${workspaceId}/${chart.id}`); } @@ -91,6 +93,13 @@ export default function ChartCard({ metadata, handleDelete, canDelete }: ChartCa Act. Patching ); + if (chartType === "logit-lens-intro") + return ( + + + LL Intro + + ); return ( diff --git a/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCardsSidebar.tsx b/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCardsSidebar.tsx index bd5ca8a6..77e56f93 100644 --- a/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCardsSidebar.tsx +++ b/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCardsSidebar.tsx @@ -5,6 +5,7 @@ import { getChartsMetadata } from "@/lib/queries/chartQueries"; import { useParams, useRouter } from "next/navigation"; import { useCreateLens2ChartPair, + useCreateLogitLensIntroChartPair, useCreatePatchChartPair, useCreateActivationPatchingChartPair, useDeleteChart, @@ -21,7 +22,7 @@ import ReportCard from "./ReportCard"; import { SortableEntry, entryKey, type SidebarEntry } from "./SortableEntry"; import { ChartMetadata } from "@/types/charts"; import type { DocumentListItem } from "@/lib/queries/documentQueries"; -import { Loader2, Plus, PanelLeftClose, PanelLeft, FileText, Layers, GitBranch } from "lucide-react"; +import { Loader2, Plus, PanelLeftClose, PanelLeft, Search, FileText, Layers, GitBranch, GraduationCap } from "lucide-react"; import { Button } from "@/components/ui/button"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { @@ -56,6 +57,7 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b ); const { mutate: createLens2Pair, isPending: isCreatingLens2 } = useCreateLens2ChartPair(); + const { mutate: createLogitLensIntroPair, isPending: isCreatingLogitLensIntro } = useCreateLogitLensIntroChartPair(); const { mutate: createPatchPair, isPending: isCreatingPatch } = useCreatePatchChartPair(); const { mutate: createActivationPatchingPair, isPending: isCreatingActivationPatching } = useCreateActivationPatchingChartPair(); @@ -156,6 +158,8 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b router.push(`/workbench/${workspaceId}/lens2/${chartId}`); } else if (toolType === "activation-patching") { router.push(`/workbench/${workspaceId}/activation-patching/${chartId}`); + } else if (toolType === "logit-lens-intro") { + router.push(`/workbench/${workspaceId}/logit-lens-intro/${chartId}`); } else { router.push(`/workbench/${workspaceId}/${chartId}`); } @@ -165,7 +169,7 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b router.push(`/workbench/${workspaceId}/overview/${documentId}`); }; - const handleCreate = (toolType: "lens2" | "patch" | "activation-patching") => { + const handleCreate = (toolType: "lens2" | "patch" | "activation-patching" | "logit-lens-intro") => { if (toolType === "lens2") { createLens2Pair( { workspaceId: workspaceId as string }, @@ -175,6 +179,15 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b ); return; } + if (toolType === "logit-lens-intro") { + createLogitLensIntroPair( + { workspaceId: workspaceId as string }, + { + onSuccess: ({ chart }) => navigateToChart(chart.id, "logit-lens-intro"), + }, + ); + return; + } if (toolType === "activation-patching") { createActivationPatchingPair( { workspaceId: workspaceId as string }, @@ -239,7 +252,7 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b ); }; - const isCreatingAny = isCreatingLens2 || isCreatingPatch || isCreatingActivationPatching || isCreatingDocument; + const isCreatingAny = isCreatingLens2 || isCreatingLogitLensIntro || isCreatingPatch || isCreatingActivationPatching || isCreatingDocument; const actionButtons = (
@@ -257,6 +270,20 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b )} Logit Lens + +