diff --git a/src/hooks.client.ts b/src/hooks.client.ts index 4b8886a9e8..d9bc940a2e 100644 --- a/src/hooks.client.ts +++ b/src/hooks.client.ts @@ -1,3 +1,12 @@ +import { getAuthUser } from '$lib/stores/auth-user'; +import { consumeAuthCookies } from '$lib/utilities/auth-user-cookie'; +import { initCoreProvider } from '$lib/utilities/core-provider'; +import { + ossGetDataEncoderEndpoint, + ossPostResponse, + ossPreRequest, +} from '$lib/utilities/oss-provider.svelte'; + if (typeof crypto !== 'undefined' && !crypto.randomUUID) { crypto.randomUUID = function randomUUID() { return '10000000-1000-4000-8000-100000000000'.replace(/[018]/g, (c) => { @@ -9,3 +18,15 @@ if (typeof crypto !== 'undefined' && !crypto.randomUUID) { }) as `${string}-${string}-${string}-${string}-${string}`; }; } + +consumeAuthCookies(); + +initCoreProvider({ + getAccessToken: async () => getAuthUser().accessToken ?? '', + getIdToken: async () => getAuthUser().idToken, + api: { + preRequest: ossPreRequest, + postResponse: ossPostResponse, + }, + getDataEncoderEndpoint: ossGetDataEncoderEndpoint, +}); diff --git a/src/lib/components/event/event-summary-row.svelte b/src/lib/components/event/event-summary-row.svelte index 7bb6dc5cf1..9c9417b833 100644 --- a/src/lib/components/event/event-summary-row.svelte +++ b/src/lib/components/event/event-summary-row.svelte @@ -19,7 +19,6 @@ eventOrGroupIsTerminated, } from '$lib/models/event-groups/get-event-in-group'; import { isCloud } from '$lib/stores/advanced-visibility'; - import { authUser } from '$lib/stores/auth-user'; import type { IterableEvent, WorkflowEvent } from '$lib/types/events'; import { decodeLocalActivity } from '$lib/utilities/decode-local-activity'; import { spaceBetweenCapitalLetters } from '$lib/utilities/format-camel-case'; @@ -187,7 +186,6 @@ primaryLocalAttribute = await decodeLocalActivity(event, { namespace: page.params.namespace, settings: page.data.settings, - accessToken: $authUser.accessToken, }); } else if ( isEventGroup(event) && @@ -196,7 +194,6 @@ primaryLocalAttribute = await decodeLocalActivity(event.initialEvent, { namespace: page.params.namespace, settings: page.data.settings, - accessToken: $authUser.accessToken, }); } }); diff --git a/src/lib/components/event/payload-decoder.svelte b/src/lib/components/event/payload-decoder.svelte index 2e5aaf19f2..9161469c1a 100644 --- a/src/lib/components/event/payload-decoder.svelte +++ b/src/lib/components/event/payload-decoder.svelte @@ -3,7 +3,6 @@ import { page } from '$app/stores'; - import { authUser } from '$lib/stores/auth-user'; import type { Memo } from '$lib/types'; import type { EventAttribute, WorkflowEvent } from '$lib/types/events'; import { @@ -51,7 +50,6 @@ _value, $page.params.namespace, settings, - $authUser.accessToken, ); const decodedAttributes = decodePayloadAttributes( convertedAttributes, diff --git a/src/lib/components/lines-and-dots/svg/timeline-graph-row.svelte b/src/lib/components/lines-and-dots/svg/timeline-graph-row.svelte index fd8e4f5957..079ab39995 100644 --- a/src/lib/components/lines-and-dots/svg/timeline-graph-row.svelte +++ b/src/lib/components/lines-and-dots/svg/timeline-graph-row.svelte @@ -9,7 +9,6 @@ import { translate } from '$lib/i18n/translate'; import type { EventGroup } from '$lib/models/event-groups/event-groups'; import { setActiveGroup } from '$lib/stores/active-events'; - import { authUser } from '$lib/stores/auth-user'; import { decodeLocalActivity, getLocalActivityMarkerEvent, @@ -68,7 +67,6 @@ decodedLocalActivity = await decodeLocalActivity(localActivityEvent, { namespace: page.params.namespace, settings: page.data.settings, - accessToken: $authUser.accessToken, }); if (decodedLocalActivity) { diff --git a/src/lib/components/workflow/metadata/workflow-current-details.svelte b/src/lib/components/workflow/metadata/workflow-current-details.svelte index 8c574e8e06..bfd7353c0b 100644 --- a/src/lib/components/workflow/metadata/workflow-current-details.svelte +++ b/src/lib/components/workflow/metadata/workflow-current-details.svelte @@ -9,7 +9,6 @@ import Markdown from '$lib/holocene/markdown-editor/preview.svelte'; import { translate } from '$lib/i18n/translate'; import { getWorkflowMetadata } from '$lib/services/query-service'; - import { authUser } from '$lib/stores/auth-user'; import { workflowRun } from '$lib/stores/workflow-run'; const { namespace } = $derived(page.params); @@ -34,7 +33,6 @@ }, }, settings, - $authUser?.accessToken ?? '', ); $workflowRun.metadata = metadata; lastFetched = new Date(); diff --git a/src/lib/layouts/workflow-run-layout.svelte b/src/lib/layouts/workflow-run-layout.svelte index 452745e34b..f29c6cceff 100644 --- a/src/lib/layouts/workflow-run-layout.svelte +++ b/src/lib/layouts/workflow-run-layout.svelte @@ -16,7 +16,6 @@ import { getPollers } from '$lib/services/pollers-service'; import { getWorkflowMetadata } from '$lib/services/query-service'; import { fetchWorkflow } from '$lib/services/workflow-service'; - import { authUser } from '$lib/stores/auth-user'; import { resetLastDataEncoderSuccess } from '$lib/stores/data-encoder-config'; import { eventFilterSort, type EventSortOrder } from '$lib/stores/event-view'; import { @@ -117,7 +116,6 @@ }, }, settings, - $authUser?.accessToken, workflowRunController.signal, ).then((metadata) => { $workflowRun.metadata = metadata; diff --git a/src/lib/models/event-history/get-event-attributes.test.ts b/src/lib/models/event-history/get-event-attributes.test.ts index e22a89cd64..7ab150dc67 100644 --- a/src/lib/models/event-history/get-event-attributes.test.ts +++ b/src/lib/models/event-history/get-event-attributes.test.ts @@ -83,7 +83,6 @@ const historyEvent = { const namespace = 'unit-tests'; const settings = settingsFixture as unknown as Settings; -const accessToken = 'xxx.yyy.zzz'; describe('getEventAttributes', () => { beforeEach(() => { @@ -110,7 +109,6 @@ describe('getEventAttributes', () => { historyEvent, namespace, settings, - accessToken, }); expect(event.type).toBe(eventType); }); @@ -123,7 +121,6 @@ describe('getEventAttributes', () => { historyEvent, namespace, settings, - accessToken, }, { convertWithCodec, @@ -141,7 +138,6 @@ describe('getEventAttributes', () => { historyEvent, namespace, settings: { ...settings, codec: { endpoint: 'https://localhost' } }, - accessToken, }, { convertWithCodec, @@ -159,7 +155,6 @@ describe('getEventAttributes', () => { historyEvent, namespace, settings, - accessToken, }, { convertWithCodec, diff --git a/src/lib/models/event-history/index.ts b/src/lib/models/event-history/index.ts index b8ed88a111..c57be0696d 100644 --- a/src/lib/models/event-history/index.ts +++ b/src/lib/models/event-history/index.ts @@ -27,7 +27,7 @@ import { getEventClassification } from './get-event-classification'; import { simplifyAttributes } from './simplify-attributes'; export async function getEventAttributes( - { historyEvent, namespace, settings, accessToken }: EventWithMetadata, + { historyEvent, namespace, settings }: EventWithMetadata, { convertWithCodec = convertPayloadToJsonWithCodec, decodeAttributes = decodePayloadAttributes, @@ -38,7 +38,6 @@ export async function getEventAttributes( attributes, namespace, settings, - accessToken, }); const decodedAttributes = decodeAttributes(convertedAttributes) as object; diff --git a/src/lib/models/event-history/to-event-history.test.ts b/src/lib/models/event-history/to-event-history.test.ts index 6cbe7e3fb5..f062ab7b58 100644 --- a/src/lib/models/event-history/to-event-history.test.ts +++ b/src/lib/models/event-history/to-event-history.test.ts @@ -83,7 +83,6 @@ const historyEvent = { const namespace = 'unit-tests'; const settings = settingsFixture as unknown as Settings; -const accessToken = 'token-test'; describe('getEventAttributes', () => { beforeEach(() => { @@ -110,7 +109,6 @@ describe('getEventAttributes', () => { historyEvent, namespace, settings, - accessToken, }); expect(event.type).toBe(eventType); }); @@ -123,7 +121,6 @@ describe('getEventAttributes', () => { historyEvent, namespace, settings, - accessToken, }, { convertWithCodec, @@ -141,7 +138,6 @@ describe('getEventAttributes', () => { historyEvent, namespace, settings: { ...settings, codec: { endpoint: 'https://localhost' } }, - accessToken, }, { convertWithCodec, @@ -159,7 +155,6 @@ describe('getEventAttributes', () => { historyEvent, namespace, settings, - accessToken, }, { convertWithCodec, diff --git a/src/lib/models/pending-activities/index.ts b/src/lib/models/pending-activities/index.ts index d6c80aabdd..2d56b5c72c 100644 --- a/src/lib/models/pending-activities/index.ts +++ b/src/lib/models/pending-activities/index.ts @@ -2,7 +2,6 @@ import { get } from 'svelte/store'; import { page } from '$app/stores'; -import { authUser } from '$lib/stores/auth-user'; import type { PendingActivity, PendingActivityWithMetadata, @@ -16,7 +15,7 @@ import { } from '$lib/utilities/decode-payload'; export async function getActivityAttributes( - { activity, namespace, settings, accessToken }: PendingActivityWithMetadata, + { activity, namespace, settings }: PendingActivityWithMetadata, { convertWithCodec = convertPayloadToJsonWithCodec, decodeAttributes = decodePayloadAttributes, @@ -26,7 +25,6 @@ export async function getActivityAttributes( attributes: activity, namespace, settings, - accessToken, }); const decodedAttributes = decodeAttributes( @@ -39,13 +37,11 @@ const decodePendingActivity = async ({ activity, namespace, settings, - accessToken, }: PendingActivityWithMetadata): Promise => { const decodedActivity = await getActivityAttributes({ activity, namespace, settings, - accessToken, }); return decodedActivity; }; @@ -54,7 +50,6 @@ export const toDecodedPendingActivities = async ( workflow: WorkflowExecution, namespace: string = get(page).params.namespace, settings: Settings = get(page).data.settings, - accessToken: string = get(authUser).accessToken, ) => { const pendingActivities = workflow?.pendingActivities ?? []; const decodedActivities: PendingActivity[] = []; @@ -63,7 +58,6 @@ export const toDecodedPendingActivities = async ( activity, namespace, settings, - accessToken, }); decodedActivities.push(decodedActivity); } diff --git a/src/lib/models/pending-activities/to-decoded-pending-activities.test.ts b/src/lib/models/pending-activities/to-decoded-pending-activities.test.ts index 5d9b324e2b..00ebd2dcc0 100644 --- a/src/lib/models/pending-activities/to-decoded-pending-activities.test.ts +++ b/src/lib/models/pending-activities/to-decoded-pending-activities.test.ts @@ -10,7 +10,6 @@ import pendingActivityWorkflow from '$fixtures/workflow.pending-activities.json' const namespace = 'unit-tests'; const settings = settingsFixture as unknown as Settings; -const accessToken = 'access-token'; describe('toDecodedPendingActivities', () => { it('should decode heartbeatDetails', async () => { @@ -19,7 +18,6 @@ describe('toDecodedPendingActivities', () => { workflow, namespace, settings, - accessToken, ); expect(decodedHeartbeatDetails[0].heartbeatDetails.payloads[0]).toBe(2); diff --git a/src/lib/pages/workflow-call-stack.svelte b/src/lib/pages/workflow-call-stack.svelte index 1d48553847..2808428543 100644 --- a/src/lib/pages/workflow-call-stack.svelte +++ b/src/lib/pages/workflow-call-stack.svelte @@ -10,7 +10,6 @@ import { translate } from '$lib/i18n/translate'; import type { ParsedQuery } from '$lib/services/query-service'; import { getWorkflowStackTrace } from '$lib/services/query-service'; - import { authUser } from '$lib/stores/auth-user'; import { refresh, workflowRun } from '$lib/stores/workflow-run'; import type { Eventual } from '$lib/types/global'; @@ -31,7 +30,6 @@ namespace, }, page.data?.settings, - $authUser?.accessToken, ); $effect(() => { diff --git a/src/lib/pages/workflow-query.svelte b/src/lib/pages/workflow-query.svelte index d6742584b4..37611cc11f 100644 --- a/src/lib/pages/workflow-query.svelte +++ b/src/lib/pages/workflow-query.svelte @@ -18,7 +18,6 @@ getWorkflowMetadata, type ParsedQuery, } from '$lib/services/query-service'; - import { authUser } from '$lib/stores/auth-user'; import { workflowRun } from '$lib/stores/workflow-run'; import type { Payloads } from '$lib/types'; import type { WorkflowInteractionDefinition } from '$lib/types/workflows'; @@ -84,7 +83,6 @@ }, }, settings, - $authUser?.accessToken, ); $workflowRun.metadata = metadata; }; @@ -121,7 +119,6 @@ queryArgs: payloads ? { payloads } : null, }, $page.data?.settings, - $authUser?.accessToken, ).finally(() => { reset(); }); diff --git a/src/lib/services/data-encoder.test.ts b/src/lib/services/data-encoder.test.ts index 46b1bdbe4e..08e161283c 100644 --- a/src/lib/services/data-encoder.test.ts +++ b/src/lib/services/data-encoder.test.ts @@ -1,7 +1,17 @@ -import { describe, expect, it, vi } from 'vitest'; +import { afterEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('$lib/utilities/core-provider', () => ({ + getAccessToken: vi.fn().mockResolvedValue(''), + getIdToken: vi.fn().mockResolvedValue(undefined), +})); + +import { getAccessToken, getIdToken } from '$lib/utilities/core-provider'; import { codeServerRequest } from './data-encoder'; +const mockGetAccessToken = vi.mocked(getAccessToken); +const mockGetIdToken = vi.mocked(getIdToken); + const settings = { codec: { endpoint: 'http://localcodecserver.com', @@ -91,3 +101,179 @@ describe('Codec Server Requests for Decode and Encode', () => { ).rejects.toThrow(); }); }); + +describe('codecPassAccessToken', () => { + const payloads = { payloads: [{}] }; + const namespace = 'test-namespace'; + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('should attach Authorization and Authorization-Extras headers when passAccessToken is true and endpoint is HTTPS', async () => { + mockGetAccessToken.mockResolvedValue('test-access-token'); + mockGetIdToken.mockResolvedValue('test-id-token'); + + global.fetch = vi.fn(() => + Promise.resolve({ + ok: true, + json: () => Promise.resolve(payloads), + } as Response), + ); + + const httpsSettings = { + codec: { + endpoint: 'https://codecserver.com', + passAccessToken: true, + includeCredentials: false, + }, + }; + + await codeServerRequest({ + type: 'decode', + payloads, + namespace, + settings: httpsSettings, + }); + + expect(mockGetAccessToken).toHaveBeenCalled(); + expect(mockGetIdToken).toHaveBeenCalled(); + + const fetchCall = vi.mocked(global.fetch).mock.calls[0]; + const requestOptions = fetchCall[1] as RequestInit; + const headers = requestOptions.headers as Record; + expect(headers['Authorization']).toBe('Bearer test-access-token'); + expect(headers['Authorization-Extras']).toBe('test-id-token'); + }); + + it('should not attach Authorization header when accessToken is empty', async () => { + mockGetAccessToken.mockResolvedValue(''); + mockGetIdToken.mockResolvedValue(undefined); + + global.fetch = vi.fn(() => + Promise.resolve({ + ok: true, + json: () => Promise.resolve(payloads), + } as Response), + ); + + const httpsSettings = { + codec: { + endpoint: 'https://codecserver.com', + passAccessToken: true, + includeCredentials: false, + }, + }; + + await codeServerRequest({ + type: 'decode', + payloads, + namespace, + settings: httpsSettings, + }); + + const fetchCall = vi.mocked(global.fetch).mock.calls[0]; + const requestOptions = fetchCall[1] as RequestInit; + const headers = requestOptions.headers as Record; + expect(headers['Authorization']).toBeUndefined(); + expect(headers['Authorization-Extras']).toBeUndefined(); + }); + + it('should not make request and return original payloads when passAccessToken is true but endpoint is HTTP', async () => { + global.fetch = vi.fn(); + + const httpSettings = { + codec: { + endpoint: 'http://codecserver.com', + passAccessToken: true, + includeCredentials: false, + }, + }; + + const result = await codeServerRequest({ + type: 'decode', + payloads, + namespace, + settings: httpSettings, + }); + + expect(global.fetch).not.toHaveBeenCalled(); + expect(result).toEqual(payloads); + }); + + it('should not call getAccessToken when passAccessToken is false', async () => { + global.fetch = vi.fn(() => + Promise.resolve({ + ok: true, + json: () => Promise.resolve(payloads), + } as Response), + ); + + await codeServerRequest({ + type: 'decode', + payloads, + namespace, + settings, + }); + + expect(mockGetAccessToken).not.toHaveBeenCalled(); + expect(mockGetIdToken).not.toHaveBeenCalled(); + }); +}); + +describe('codecIncludeCredentials', () => { + const payloads = { payloads: [{}] }; + const namespace = 'test-namespace'; + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('should include credentials in request when includeCredentials is true', async () => { + global.fetch = vi.fn(() => + Promise.resolve({ + ok: true, + json: () => Promise.resolve(payloads), + } as Response), + ); + + const credSettings = { + codec: { + endpoint: 'http://localcodecserver.com', + passAccessToken: false, + includeCredentials: true, + }, + }; + + await codeServerRequest({ + type: 'decode', + payloads, + namespace, + settings: credSettings, + }); + + const fetchCall = vi.mocked(global.fetch).mock.calls[0]; + const requestOptions = fetchCall[1] as RequestInit; + expect(requestOptions.credentials).toBe('include'); + }); + + it('should not include credentials when includeCredentials is false', async () => { + global.fetch = vi.fn(() => + Promise.resolve({ + ok: true, + json: () => Promise.resolve(payloads), + } as Response), + ); + + await codeServerRequest({ + type: 'decode', + payloads, + namespace, + settings, + }); + + const fetchCall = vi.mocked(global.fetch).mock.calls[0]; + const requestOptions = fetchCall[1] as RequestInit; + expect(requestOptions.credentials).toBeUndefined(); + }); +}); diff --git a/src/lib/services/data-encoder.ts b/src/lib/services/data-encoder.ts index 49d368b0b2..4337429b3b 100644 --- a/src/lib/services/data-encoder.ts +++ b/src/lib/services/data-encoder.ts @@ -3,12 +3,12 @@ import { get } from 'svelte/store'; import { page } from '$app/stores'; import { translate } from '$lib/i18n/translate'; -import { authUser } from '$lib/stores/auth-user'; import { setLastDataEncoderFailure, setLastDataEncoderSuccess, } from '$lib/stores/data-encoder-config'; import type { NetworkError, Settings } from '$lib/types/global'; +import { getAccessToken, getIdToken } from '$lib/utilities/core-provider'; import { getCodecEndpoint, getCodecIncludeCredentials, @@ -41,14 +41,14 @@ export async function codeServerRequest({ if (passAccessToken) { if (validateHttps(endpoint)) { - let accessToken = get(authUser).accessToken; - const accessTokenExtras = get(authUser).idToken; - if (globalThis?.AccessToken) { - accessToken = await globalThis?.AccessToken(); - } else if (accessTokenExtras) { - headers['Authorization-Extras'] = accessTokenExtras; + const accessToken = await getAccessToken(); + const idToken = await getIdToken(); + if (accessToken) { + headers['Authorization'] = `Bearer ${accessToken}`; + } + if (idToken) { + headers['Authorization-Extras'] = idToken; } - headers['Authorization'] = `Bearer ${accessToken}`; } else { setLastDataEncoderFailure(); return payloads; diff --git a/src/lib/services/namespaces-service.test.ts b/src/lib/services/namespaces-service.test.ts index fdccee0e61..9829e4e025 100644 --- a/src/lib/services/namespaces-service.test.ts +++ b/src/lib/services/namespaces-service.test.ts @@ -49,7 +49,6 @@ describe('fetchNamespaces', () => { expect(request).toHaveBeenCalledWith( `${origin}${base}/api/v1/namespaces?`, { - credentials: 'include', headers: { 'Caller-Type': 'operator', }, diff --git a/src/lib/services/query-service.ts b/src/lib/services/query-service.ts index c118f64fc1..d192712cf2 100644 --- a/src/lib/services/query-service.ts +++ b/src/lib/services/query-service.ts @@ -84,14 +84,12 @@ async function fetchQuery( export async function getWorkflowMetadata( options: WorkflowParameters, settings: Settings, - accessToken: string, signal?: AbortSignal, ): Promise { try { const metadata = await getQuery( { ...options, queryType: '__temporal_workflow_metadata' }, settings, - accessToken, signal, ); if (!metadata.currentDetails) { @@ -119,7 +117,6 @@ export async function getWorkflowMetadata( export async function getQuery( options: QueryRequestParameters, settings: Settings, - accessToken: string, signal?: AbortSignal, ): Promise { return fetchQuery(options, signal).then(async (execution) => { @@ -132,7 +129,6 @@ export async function getQuery( attributes: queryResult, namespace: options.namespace, settings, - accessToken, }); if ( @@ -156,11 +152,6 @@ export async function getQuery( export async function getWorkflowStackTrace( options: WorkflowParameters, settings: Settings, - accessToken: string, ): Promise { - return getQuery( - { ...options, queryType: '__stack_trace' }, - settings, - accessToken, - ); + return getQuery({ ...options, queryType: '__stack_trace' }, settings); } diff --git a/src/lib/services/workflow-service.ts b/src/lib/services/workflow-service.ts index faa3889b5c..9f8b375df8 100644 --- a/src/lib/services/workflow-service.ts +++ b/src/lib/services/workflow-service.ts @@ -13,7 +13,6 @@ import { toWorkflowExecutions, } from '$lib/models/workflow-execution'; import { isCloud } from '$lib/stores/advanced-visibility'; -import { authUser } from '$lib/stores/auth-user'; import type { SearchAttributeInput, SearchAttributesSchema, @@ -833,7 +832,6 @@ export const fetchInitialValuesForStartWorkflow = async ({ startEvent?.attributes?.input, namespace, get(page).data.settings, - get(authUser).accessToken, 'readable', false, )) as PotentiallyDecodable; diff --git a/src/lib/stores/data-encoder.test.ts b/src/lib/stores/data-encoder.test.ts index f88f0974ea..89c04e5a92 100644 --- a/src/lib/stores/data-encoder.test.ts +++ b/src/lib/stores/data-encoder.test.ts @@ -2,7 +2,6 @@ import { get } from 'svelte/store'; import { beforeEach, describe, expect, it } from 'vitest'; -import { authUser } from './auth-user'; import { dataEncoder } from './data-encoder'; import { codecEndpoint, @@ -20,41 +19,6 @@ describe('dataEncoder', () => { it('should set default values', () => { expect(get(dataEncoder)).toEqual({ - accessToken: undefined, - endpoint: '', - hasError: false, - hasNotRequested: true, - hasSuccess: false, - namespace: 'default', - settingsEndpoint: '', - settingsIncludeCredentials: false, - settingsPassAccessToken: false, - customErrorLink: '', - customErrorMessage: '', - }); - }); - - it('should set access token from authUser', () => { - authUser.set({ accessToken: 'abc' }); - expect(get(dataEncoder)).toEqual({ - accessToken: 'abc', - endpoint: '', - hasError: false, - hasNotRequested: true, - hasSuccess: false, - namespace: 'default', - settingsEndpoint: '', - settingsIncludeCredentials: false, - settingsPassAccessToken: false, - customErrorLink: '', - customErrorMessage: '', - }); - }); - - it('should set access token from authUser', () => { - authUser.set({ accessToken: 'abc' }); - expect(get(dataEncoder)).toEqual({ - accessToken: 'abc', endpoint: '', hasError: false, hasNotRequested: true, @@ -71,7 +35,6 @@ describe('dataEncoder', () => { it('should set codecEndpoint', () => { codecEndpoint.set('https://localhost:8383'); expect(get(dataEncoder)).toEqual({ - accessToken: 'abc', endpoint: 'https://localhost:8383', hasError: false, hasNotRequested: true, diff --git a/src/lib/stores/data-encoder.ts b/src/lib/stores/data-encoder.ts index 3e77bf1cd5..f2d6f58a0f 100644 --- a/src/lib/stores/data-encoder.ts +++ b/src/lib/stores/data-encoder.ts @@ -2,7 +2,6 @@ import { derived } from 'svelte/store'; import { page } from '$app/stores'; -import { authUser } from './auth-user'; import { lastDataConverterStatus } from './data-converter-config'; import { codecEndpoint, @@ -18,7 +17,6 @@ type DataEncoder = { endpoint: string; customErrorMessage: string; customErrorLink: string; - accessToken?: string; hasNotRequested: boolean; hasError: boolean; hasSuccess: boolean; @@ -31,7 +29,6 @@ export const dataEncoder = derived( overrideRemoteCodecConfiguration, lastDataEncoderStatus, lastDataConverterStatus, - authUser, ], ([ $page, @@ -39,7 +36,6 @@ export const dataEncoder = derived( $overrideRemoteCodecConfiguration, $lastDataEncoderStatus, $lastDataConverterStatus, - $authUser, ]): DataEncoder => { const namespace = $page.params.namespace; const settingsEndpoint = $page?.data?.settings?.codec?.endpoint; @@ -57,7 +53,6 @@ export const dataEncoder = derived( const endpoint = $overrideRemoteCodecConfiguration ? $codecEndpoint : settingsEndpoint || $codecEndpoint; - const accessToken = $authUser?.accessToken; const hasNotRequested = endpoint ? $lastDataEncoderStatus === 'notRequested' : $lastDataConverterStatus === 'notRequested'; @@ -74,7 +69,6 @@ export const dataEncoder = derived( settingsPassAccessToken, settingsIncludeCredentials, endpoint, - accessToken, customErrorMessage, customErrorLink, hasNotRequested, diff --git a/src/lib/svelte-mocks/app/state.ts b/src/lib/svelte-mocks/app/state.ts new file mode 100644 index 0000000000..ddff6ac49f --- /dev/null +++ b/src/lib/svelte-mocks/app/state.ts @@ -0,0 +1,48 @@ +import type { Settings } from '$lib/types/global'; + +const settings: Settings = { + auth: { + enabled: false, + options: null, + }, + baseUrl: 'http://localhost:3000', + codec: { + endpoint: '', + passAccessToken: false, + includeCredentials: false, + }, + defaultNamespace: 'default', + disableWriteActions: false, + showTemporalSystemNamespace: false, + batchActionsDisabled: false, + workflowResetDisabled: false, + workflowPauseDisabled: false, + workflowCancelDisabled: false, + workflowSignalDisabled: false, + workflowUpdateDisabled: false, + workflowTerminateDisabled: false, + hideWorkflowQueryErrors: false, + activityCommandsDisabled: false, + feedbackURL: '', + runtimeEnvironment: { + isCloud: false, + isLocal: true, + envOverride: true, + }, + version: '2.28.0', +}; + +export const page = { + error: null, + params: { + namespace: 'default', + }, + routeId: 'namespaces/[namespace]/workflows@root', + status: 200, + data: { + settings, + } as App.PageData, + url: new URL( + 'http://localhost:3000/namespaces/default/workflows?search=basic&query=WorkflowType%3D%22testing%22', + ), +}; diff --git a/src/lib/types/events.ts b/src/lib/types/events.ts index 62f495d512..ed4a2c7b76 100644 --- a/src/lib/types/events.ts +++ b/src/lib/types/events.ts @@ -66,7 +66,6 @@ export type Callbacks = import('$lib/types').CallbackInfo[]; export type EventRequestMetadata = { namespace: string; settings: Settings; - accessToken: string; }; export type EventWithMetadata = { diff --git a/src/lib/utilities/auth-refresh.test.ts b/src/lib/utilities/auth-refresh.test.ts new file mode 100644 index 0000000000..5a6925e7e9 --- /dev/null +++ b/src/lib/utilities/auth-refresh.test.ts @@ -0,0 +1,117 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('$lib/utilities/auth-user-cookie', () => ({ + consumeAuthCookies: vi.fn(), +})); + +vi.mock('$lib/utilities/get-api-origin', () => ({ + getApiOrigin: vi.fn().mockReturnValue('http://localhost:8080'), +})); + +import { consumeAuthCookies } from '$lib/utilities/auth-user-cookie'; + +import { refreshTokens } from './auth-refresh'; + +const mockConsumeAuthCookies = vi.mocked(consumeAuthCookies); + +describe('refreshTokens', () => { + let fetchSpy: ReturnType; + + beforeEach(() => { + vi.clearAllMocks(); + fetchSpy = vi.fn(); + vi.stubGlobal('fetch', fetchSpy); + }); + + afterEach(() => { + vi.unstubAllGlobals(); + }); + + it('should call GET /auth/refresh with credentials include', async () => { + fetchSpy.mockResolvedValue({ ok: true }); + mockConsumeAuthCookies.mockReturnValue(false); + + await refreshTokens(); + + expect(fetchSpy).toHaveBeenCalledWith( + 'http://localhost:8080/auth/refresh', + { + method: 'GET', + credentials: 'include', + }, + ); + }); + + it('should consume cookies and return true on successful refresh', async () => { + fetchSpy.mockResolvedValue({ ok: true }); + mockConsumeAuthCookies.mockReturnValue(true); + + const result = await refreshTokens(); + + expect(result).toBe(true); + expect(mockConsumeAuthCookies).toHaveBeenCalledWith(true); + }); + + it('should return false when refresh endpoint returns non-ok', async () => { + fetchSpy.mockResolvedValue({ + ok: false, + status: 401, + statusText: 'Unauthorized', + }); + + const result = await refreshTokens(); + + expect(result).toBe(false); + expect(mockConsumeAuthCookies).not.toHaveBeenCalled(); + }); + + it('should return false when cookies have no access token after refresh', async () => { + fetchSpy.mockResolvedValue({ ok: true }); + mockConsumeAuthCookies.mockReturnValue(false); + + const result = await refreshTokens(); + + expect(result).toBe(false); + }); + + it('should return false on network error', async () => { + fetchSpy.mockRejectedValue(new Error('Network error')); + + const result = await refreshTokens(); + + expect(result).toBe(false); + expect(mockConsumeAuthCookies).not.toHaveBeenCalled(); + }); + + it('should deduplicate concurrent refresh calls', async () => { + let resolveRefresh: (value: Response) => void; + fetchSpy.mockImplementation( + () => + new Promise((resolve) => { + resolveRefresh = resolve; + }), + ); + mockConsumeAuthCookies.mockReturnValue(true); + + const promise1 = refreshTokens(); + const promise2 = refreshTokens(); + + resolveRefresh!({ ok: true } as Response); + + const [result1, result2] = await Promise.all([promise1, promise2]); + + expect(fetchSpy).toHaveBeenCalledTimes(1); + expect(result1).toBe(true); + expect(result2).toBe(true); + }); + + it('should allow new refresh after previous one completes', async () => { + fetchSpy.mockResolvedValue({ ok: true }); + mockConsumeAuthCookies.mockReturnValue(true); + + await refreshTokens(); + await refreshTokens(); + + expect(fetchSpy).toHaveBeenCalledTimes(2); + }); +}); diff --git a/src/lib/utilities/auth-refresh.ts b/src/lib/utilities/auth-refresh.ts index f4016844c6..0f2c5fc863 100644 --- a/src/lib/utilities/auth-refresh.ts +++ b/src/lib/utilities/auth-refresh.ts @@ -1,14 +1,20 @@ import { BROWSER } from 'esm-env'; -import { setAuthUser } from '$lib/stores/auth-user'; -import { - cleanAuthUserCookie, - getAuthUserCookie, -} from '$lib/utilities/auth-user-cookie'; +import { consumeAuthCookies } from '$lib/utilities/auth-user-cookie'; import { getApiOrigin } from '$lib/utilities/get-api-origin'; let refreshPromise: Promise | null = null; +/** + * Calls the Go server's `/auth/refresh` endpoint, which uses the HttpOnly + * `refresh` cookie to obtain new tokens from the OIDC provider. + * + * The server responds by setting fresh `user*` transport cookies, which + * are then consumed into the auth store via `consumeAuthCookies()`. + * + * Concurrent calls are deduplicated — only one HTTP request is in flight + * at a time, and all callers share the same promise. + */ export const refreshTokens = async (): Promise => { if (!BROWSER) return false; @@ -35,16 +41,11 @@ export const refreshTokens = async (): Promise => { return false; } - const user = getAuthUserCookie(true); - if (user?.accessToken) { - setAuthUser(user); - cleanAuthUserCookie(true); + const consumed = consumeAuthCookies(true); + if (consumed) { const duration = performance.now() - startTime; - const expiryTime = user.expiresAt - ? new Date(user.expiresAt).toISOString() - : 'unknown'; console.info( - `[Auth] Token refresh successful (duration: ${duration.toFixed(2)}ms, expires: ${expiryTime})`, + `[Auth] Token refresh successful (duration: ${duration.toFixed(2)}ms)`, ); return true; } diff --git a/src/lib/utilities/auth-user-cookie.test.ts b/src/lib/utilities/auth-user-cookie.test.ts new file mode 100644 index 0000000000..4705afbed9 --- /dev/null +++ b/src/lib/utilities/auth-user-cookie.test.ts @@ -0,0 +1,62 @@ +import { afterEach, describe, expect, it } from 'vitest'; + +import { getAuthUserCookie } from './auth-user-cookie'; + +const setCookie = (name: string, value: string) => { + Object.defineProperty(document, 'cookie', { + writable: true, + value: `${name}=${value}`, + }); +}; + +const clearCookie = () => { + Object.defineProperty(document, 'cookie', { + writable: true, + value: '', + }); +}; + +afterEach(() => { + clearCookie(); +}); + +describe('getAuthUserCookie', () => { + it('should parse a user cookie', () => { + const payload = JSON.stringify({ + AccessToken: 'access', + IDToken: 'id', + Name: 'Test', + Email: 'test@test.com', + Picture: '', + }); + const encoded = btoa(payload); + setCookie('user0', encoded); + + const user = getAuthUserCookie(true); + expect(user.accessToken).toBe('access'); + expect(user.idToken).toBe('id'); + }); + + it('should correctly parse a base64 cookie value that contains = padding', () => { + const payload = JSON.stringify({ + AccessToken: 'tk', // btoa of JSON with short values produces = padding + IDToken: 'x', + Name: '', + Email: '', + Picture: '', + }); + const encoded = btoa(payload); + // Verify our test fixture actually has = padding + expect(encoded.endsWith('=')).toBe(true); + setCookie('user0', encoded); + + const user = getAuthUserCookie(true); + expect(user.accessToken).toBe('tk'); + }); + + it('should return empty object when no user cookie exists', () => { + clearCookie(); + const user = getAuthUserCookie(true); + expect(user).toEqual({}); + }); +}); diff --git a/src/lib/utilities/auth-user-cookie.ts b/src/lib/utilities/auth-user-cookie.ts index 651d9952ce..132f8cc6dc 100644 --- a/src/lib/utilities/auth-user-cookie.ts +++ b/src/lib/utilities/auth-user-cookie.ts @@ -1,5 +1,6 @@ import { BROWSER } from 'esm-env'; +import { setAuthUser } from '$lib/stores/auth-user'; import type { User } from '$lib/types/global'; import { atob } from '$lib/utilities/atob'; @@ -24,7 +25,7 @@ export const getAuthUserCookie = (isBrowser = BROWSER): User => { let userBase64 = ''; while (next) { - const [_, value] = next.split('='); + const value = next.slice(next.indexOf('=') + 1); userBase64 += value; i++; @@ -65,3 +66,28 @@ export const cleanAuthUserCookie = (isBrowser = BROWSER) => { next = cookies.find((c) => c.includes(cookieName + i)); } }; + +/** + * Reads the Go server's `user*` transport cookies into the auth store, + * then deletes them. + * + * The Go server sets `user0`, `user1`, ... cookies after the OIDC callback + * because it cannot write to localStorage directly. These cookies are + * chunked to work around the ~4KB per-cookie size limit for large JWTs. + * They are a one-time transport mechanism — not persistent storage. + * + * This function should be called once on app init (before the token provider + * reads from the store) and again after each token refresh. + */ +export const consumeAuthCookies = (isBrowser = BROWSER): boolean => { + if (!isBrowser) return false; + + const user = getAuthUserCookie(isBrowser); + if (user?.accessToken) { + setAuthUser(user); + cleanAuthUserCookie(isBrowser); + return true; + } + + return false; +}; diff --git a/src/lib/utilities/core-provider.test.ts b/src/lib/utilities/core-provider.test.ts new file mode 100644 index 0000000000..ae79cf0c11 --- /dev/null +++ b/src/lib/utilities/core-provider.test.ts @@ -0,0 +1,164 @@ +import { describe, expect, it, vi } from 'vitest'; + +import { + getAccessToken, + getIdToken, + initCoreProvider, + runPostResponse, + runPreRequest, +} from './core-provider'; + +describe('core-provider', () => { + describe('getAccessToken', () => { + it('should return token from provided getAccessToken function', async () => { + initCoreProvider({ + getAccessToken: async () => 'my-token', + }); + + const token = await getAccessToken(); + expect(token).toBe('my-token'); + }); + }); + + describe('getIdToken', () => { + it('should return token from provided getIdToken function', async () => { + initCoreProvider({ + getAccessToken: async () => 'access', + getIdToken: async () => 'id-token', + }); + + const idToken = await getIdToken(); + expect(idToken).toBe('id-token'); + }); + + it('should return undefined when getIdToken is not provided', async () => { + initCoreProvider({ + getAccessToken: async () => 'access', + }); + + const idToken = await getIdToken(); + expect(idToken).toBeUndefined(); + }); + }); + + describe('runPreRequest', () => { + it('should call the provided preRequest hook', async () => { + initCoreProvider({ + getAccessToken: async () => 'token', + api: { + preRequest: async (ctx) => ({ + ...ctx, + options: { + ...ctx.options, + headers: { Authorization: 'Bearer injected' }, + }, + }), + }, + }); + + const result = await runPreRequest({ + url: '/api/test', + options: {}, + }); + + expect( + (result.options.headers as Record)['Authorization'], + ).toBe('Bearer injected'); + }); + + it('should pass through when no preRequest hook is provided', async () => { + initCoreProvider({ + getAccessToken: async () => 'token', + }); + + const context = { url: '/api/test', options: { method: 'GET' } }; + const result = await runPreRequest(context); + + expect(result).toEqual(context); + }); + }); + + describe('runPostResponse', () => { + it('should call the provided postResponse hook', async () => { + const retryResponse = new Response('retried', { status: 200 }); + const retry = vi.fn().mockResolvedValue(retryResponse); + + initCoreProvider({ + getAccessToken: async () => 'token', + api: { + postResponse: async (response, context) => { + if (response.status === 401) { + return context.retry(); + } + return response; + }, + }, + }); + + const unauthorizedResponse = new Response('unauthorized', { + status: 401, + }); + const result = await runPostResponse(unauthorizedResponse, { + url: '/api/test', + options: {}, + retry, + }); + + expect(retry).toHaveBeenCalled(); + expect(result.status).toBe(200); + }); + + it('should pass through when no postResponse hook is provided', async () => { + initCoreProvider({ + getAccessToken: async () => 'token', + }); + + const response = new Response('ok', { status: 200 }); + const result = await runPostResponse(response, { + url: '/api/test', + options: {}, + retry: vi.fn(), + }); + + expect(result).toBe(response); + }); + + it('should not retry on non-401 responses when hook checks status', async () => { + const retry = vi.fn(); + + initCoreProvider({ + getAccessToken: async () => 'token', + api: { + postResponse: async (response, context) => { + if (response.status === 401) return context.retry(); + return response; + }, + }, + }); + + const response = new Response('forbidden', { status: 403 }); + const result = await runPostResponse(response, { + url: '/api/test', + options: {}, + retry, + }); + + expect(retry).not.toHaveBeenCalled(); + expect(result.status).toBe(403); + }); + }); + + describe('dependency injection', () => { + it('should allow swapping providers at runtime', async () => { + initCoreProvider({ + getAccessToken: async () => 'first', + }); + expect(await getAccessToken()).toBe('first'); + + initCoreProvider({ + getAccessToken: async () => 'second', + }); + expect(await getAccessToken()).toBe('second'); + }); + }); +}); diff --git a/src/lib/utilities/core-provider.ts b/src/lib/utilities/core-provider.ts new file mode 100644 index 0000000000..0a23a97f94 --- /dev/null +++ b/src/lib/utilities/core-provider.ts @@ -0,0 +1,92 @@ +import { BROWSER } from 'esm-env'; + +export type RequestContext = { + url: string; + options: RequestInit; +}; + +export type PreRequestHook = ( + context: RequestContext, +) => Promise; + +export type PostResponseHook = ( + response: Response, + context: RequestContext & { retry: () => Promise }, +) => Promise; + +export type CoreProvider = { + getAccessToken: () => Promise; + getIdToken: () => Promise; + api: { + preRequest: PreRequestHook; + postResponse: PostResponseHook; + }; + getDataEncoderEndpoint: (namespace: string) => Promise; + searchNamespaces: (query: string) => Promise; +}; + +export type InitOptions = { + getAccessToken: () => Promise; + getIdToken?: () => Promise; + api?: { + preRequest?: PreRequestHook; + postResponse?: PostResponseHook; + }; + getDataEncoderEndpoint?: (namespace: string) => Promise; + searchNamespaces?: (query: string) => Promise; +}; + +let provider: CoreProvider | null = null; + +const passthrough: PreRequestHook = async (ctx) => ctx; +const passthroughResponse: PostResponseHook = async (res) => res; + +export function initCoreProvider(options: InitOptions): void { + provider = { + getAccessToken: options.getAccessToken, + getIdToken: options.getIdToken ?? (async () => undefined), + api: { + preRequest: options.api?.preRequest ?? passthrough, + postResponse: options.api?.postResponse ?? passthroughResponse, + }, + getDataEncoderEndpoint: options.getDataEncoderEndpoint ?? (async () => ''), + searchNamespaces: options.searchNamespaces ?? (async () => []), + }; +} + +export async function getAccessToken(): Promise { + if (!BROWSER || !provider) return ''; + return provider.getAccessToken(); +} + +export async function getIdToken(): Promise { + if (!BROWSER || !provider) return undefined; + return provider.getIdToken(); +} + +export async function getDataEncoderEndpoint( + namespace: string, +): Promise { + if (!BROWSER || !provider) return ''; + return provider.getDataEncoderEndpoint(namespace); +} + +export async function searchNamespaces(query: string): Promise { + if (!BROWSER || !provider) return []; + return provider.searchNamespaces(query); +} + +export async function runPreRequest( + context: RequestContext, +): Promise { + if (!provider) return context; + return provider.api.preRequest(context); +} + +export async function runPostResponse( + response: Response, + context: RequestContext & { retry: () => Promise }, +): Promise { + if (!provider) return response; + return provider.api.postResponse(response, context); +} diff --git a/src/lib/utilities/decode-local-activity.ts b/src/lib/utilities/decode-local-activity.ts index 027d3a18f1..02f641dbdd 100644 --- a/src/lib/utilities/decode-local-activity.ts +++ b/src/lib/utilities/decode-local-activity.ts @@ -32,7 +32,6 @@ export type DecodedLocalActivity = { export type LocalActivityDecodeOptions = { namespace: string; settings: Settings; - accessToken?: string; }; export const decodeLocalActivity = async ( @@ -43,7 +42,7 @@ export const decodeLocalActivity = async ( return undefined; } - const { namespace, settings, accessToken } = options; + const { namespace, settings } = options; const codecSettings = { ...settings, @@ -60,7 +59,6 @@ export const decodeLocalActivity = async ( event.attributes, namespace, codecSettings, - accessToken, ); const payloads = (event.markerRecordedEventAttributes?.details?.data diff --git a/src/lib/utilities/decode-payload.ts b/src/lib/utilities/decode-payload.ts index 53b92caca2..7a6efb125f 100644 --- a/src/lib/utilities/decode-payload.ts +++ b/src/lib/utilities/decode-payload.ts @@ -3,7 +3,6 @@ import { get } from 'svelte/store'; import { page } from '$app/stores'; import { decodePayloadsWithCodec } from '$lib/services/data-encoder'; -import { authUser } from '$lib/stores/auth-user'; import type { codecEndpoint, includeCredentials, @@ -218,7 +217,6 @@ export const decodeAllPotentialPayloadsWithCodec = async ( anyAttributes: EventAttribute | PotentiallyDecodable | Failure, namespace: string = get(page).params.namespace, settings: Settings = get(page).data.settings, - accessToken: string = get(authUser).accessToken, ): Promise => { const decode = decodeReadablePayloads(settings); @@ -237,7 +235,6 @@ export const decodeAllPotentialPayloadsWithCodec = async ( next, namespace, settings, - accessToken, ); } } @@ -264,7 +261,6 @@ export const cloneAllPotentialPayloadsWithCodec = async ( | null, namespace: string, settings: Settings, - accessToken: string, decodeSetting: DownloadEventHistorySetting = 'readable', returnDataOnly: boolean = true, ): Promise< @@ -297,7 +293,6 @@ export const cloneAllPotentialPayloadsWithCodec = async ( next, namespace, settings, - accessToken, decodeSetting, returnDataOnly, ); @@ -313,7 +308,6 @@ export const convertPayloadToJsonWithCodec = async ({ attributes, namespace, settings, - accessToken, }: { attributes: EventAttribute | PotentiallyDecodable | Failure; } & EventRequestMetadata): Promise< @@ -323,7 +317,6 @@ export const convertPayloadToJsonWithCodec = async ({ attributes, namespace, settings, - accessToken, ); return decodedAttributes; }; diff --git a/src/lib/utilities/export-history.ts b/src/lib/utilities/export-history.ts index 12d8993e23..6843bd6d10 100644 --- a/src/lib/utilities/export-history.ts +++ b/src/lib/utilities/export-history.ts @@ -3,7 +3,6 @@ import { get } from 'svelte/store'; import { page } from '$app/stores'; import { fetchRawEvents } from '$lib/services/events-service'; -import { authUser } from '$lib/stores/auth-user'; import type { DownloadEventHistorySetting } from '$lib/stores/events'; import type { HistoryEvent } from '$lib/types/events'; import type { Settings } from '$lib/types/global'; @@ -44,7 +43,6 @@ const decodePayloads = async ( event, get(page).params.namespace, settingsWithLocalConfig, - get(authUser).accessToken, decodeSetting, returnDataOnly, ); diff --git a/src/lib/utilities/oss-provider.svelte.ts b/src/lib/utilities/oss-provider.svelte.ts new file mode 100644 index 0000000000..cab69255c2 --- /dev/null +++ b/src/lib/utilities/oss-provider.svelte.ts @@ -0,0 +1,78 @@ +import { BROWSER } from 'esm-env'; + +import { page } from '$app/state'; + +import { getAuthUser } from '$lib/stores/auth-user'; +import type { + PostResponseHook, + PreRequestHook, +} from '$lib/utilities/core-provider'; + +import { refreshTokens } from './auth-refresh'; +import { getCodecEndpoint } from './get-codec'; +import { routeForLoginPage } from './route-for'; + +export function getCsrfToken(): string | undefined { + try { + const csrfCookie = '_csrf='; + const cookies = document.cookie.split(';'); + const csrf = cookies.find((c) => c.includes(csrfCookie)); + if (csrf) { + return csrf.trim().slice(csrfCookie.length); + } + } catch (error) { + console.error(error); + } + return undefined; +} + +export const ossPreRequest: PreRequestHook = async (context) => { + const headers: Record = + (context.options.headers as Record) ?? {}; + + const user = getAuthUser(); + + if (user.accessToken) { + headers['Authorization'] = `Bearer ${user.accessToken}`; + } + + if (user.idToken) { + headers['Authorization-Extras'] = user.idToken; + } + + const csrf = getCsrfToken(); + if (csrf && !headers['X-CSRF-TOKEN']) { + headers['X-CSRF-TOKEN'] = csrf; + } + + return { + ...context, + options: { + ...context.options, + credentials: 'include' as RequestCredentials, + headers, + }, + }; +}; + +export const ossPostResponse: PostResponseHook = async (response, context) => { + if (response.status !== 401) return response; + + const refreshed = await refreshTokens(); + if (refreshed) { + return context.retry(); + } + + if (BROWSER) { + window.location.assign(routeForLoginPage()); + } + + return response; +}; + +export async function ossGetDataEncoderEndpoint( + _namespace: string, +): Promise { + const settings = page.data?.settings; + return getCodecEndpoint(settings); +} diff --git a/src/lib/utilities/oss-provider.test.ts b/src/lib/utilities/oss-provider.test.ts new file mode 100644 index 0000000000..e382d4a41b --- /dev/null +++ b/src/lib/utilities/oss-provider.test.ts @@ -0,0 +1,229 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('$lib/stores/auth-user', () => ({ + getAuthUser: vi.fn(), +})); + +vi.mock('./auth-refresh', () => ({ + refreshTokens: vi.fn(), +})); + +import { getAuthUser } from '$lib/stores/auth-user'; + +import { refreshTokens } from './auth-refresh'; + +import { + getCsrfToken, + ossPostResponse, + ossPreRequest, +} from './oss-provider.svelte'; + +const mockGetAuthUser = vi.mocked(getAuthUser); +const mockRefreshTokens = vi.mocked(refreshTokens); + +const withCookie = async (cookie: string, fn: () => void | Promise) => { + const currentCookie = document.cookie; + + Object.defineProperty(document, 'cookie', { + writable: true, + value: cookie, + }); + + await fn(); + + Object.defineProperty(document, 'cookie', { + writable: true, + value: currentCookie, + }); +}; + +describe('ossPreRequest', () => { + beforeEach(() => { + vi.clearAllMocks(); + mockGetAuthUser.mockReturnValue({}); + }); + + it('should add credentials include to request', async () => { + const result = await ossPreRequest({ url: '/api/test', options: {} }); + + expect(result.options.credentials).toBe('include'); + }); + + it('should add Authorization header when user has accessToken', async () => { + mockGetAuthUser.mockReturnValue({ accessToken: 'my-token' }); + + const result = await ossPreRequest({ url: '/api/test', options: {} }); + + const headers = result.options.headers as Record; + expect(headers['Authorization']).toBe('Bearer my-token'); + }); + + it('should not add Authorization header when user has no accessToken', async () => { + mockGetAuthUser.mockReturnValue({}); + + const result = await ossPreRequest({ url: '/api/test', options: {} }); + + const headers = result.options.headers as Record; + expect(headers['Authorization']).toBeUndefined(); + }); + + it('should add Authorization-Extras header when user has idToken', async () => { + mockGetAuthUser.mockReturnValue({ + accessToken: 'token', + idToken: 'id-token', + }); + + const result = await ossPreRequest({ url: '/api/test', options: {} }); + + const headers = result.options.headers as Record; + expect(headers['Authorization-Extras']).toBe('id-token'); + }); + + it('should add csrf cookie to headers', async () => { + await withCookie('_csrf=csrf-token-value', async () => { + const result = await ossPreRequest({ url: '/api/test', options: {} }); + + const headers = result.options.headers as Record; + expect(headers['X-CSRF-TOKEN']).toBe('csrf-token-value'); + }); + }); + + it('should not add csrf cookie to headers if not present', async () => { + await withCookie('_nope=token', async () => { + const result = await ossPreRequest({ url: '/api/test', options: {} }); + + const headers = result.options.headers as Record; + expect(headers['X-CSRF-TOKEN']).toBeUndefined(); + }); + }); + + it('should not overwrite existing X-CSRF-TOKEN header', async () => { + await withCookie('_csrf=new-token', async () => { + const result = await ossPreRequest({ + url: '/api/test', + options: { + headers: { 'X-CSRF-TOKEN': 'pre-existing' } as Record, + }, + }); + + const headers = result.options.headers as Record; + expect(headers['X-CSRF-TOKEN']).toBe('pre-existing'); + }); + }); + + it('should preserve existing options while adding credentials', async () => { + const result = await ossPreRequest({ + url: '/api/test', + options: { method: 'POST', body: '{}' }, + }); + + expect(result.options.method).toBe('POST'); + expect(result.options.body).toBe('{}'); + expect(result.options.credentials).toBe('include'); + }); +}); + +describe('getCsrfToken', () => { + it('should return csrf token from cookie', async () => { + await withCookie('_csrf=my-csrf', async () => { + expect(getCsrfToken()).toBe('my-csrf'); + }); + }); + + it('should return undefined when no csrf cookie exists', async () => { + await withCookie('other=value', async () => { + expect(getCsrfToken()).toBeUndefined(); + }); + }); +}); + +describe('ossPostResponse', () => { + let assignSpy: ReturnType; + + beforeEach(() => { + vi.clearAllMocks(); + assignSpy = vi.fn(); + Object.defineProperty(window, 'location', { + value: { + assign: assignSpy, + origin: 'http://localhost', + href: 'http://localhost/', + }, + configurable: true, + writable: true, + }); + }); + + it('should pass through non-401 responses', async () => { + const response = new Response('ok', { status: 200 }); + const retry = vi.fn(); + + const result = await ossPostResponse(response, { + url: '/api/test', + options: {}, + retry, + }); + + expect(result).toBe(response); + expect(mockRefreshTokens).not.toHaveBeenCalled(); + expect(retry).not.toHaveBeenCalled(); + }); + + it('should call refreshTokens and retry on 401', async () => { + mockRefreshTokens.mockResolvedValue(true); + const retryResponse = new Response('retried', { status: 200 }); + const retry = vi.fn().mockResolvedValue(retryResponse); + + const response = new Response('unauthorized', { status: 401 }); + const result = await ossPostResponse(response, { + url: '/api/test', + options: {}, + retry, + }); + + expect(mockRefreshTokens).toHaveBeenCalledTimes(1); + expect(retry).toHaveBeenCalledTimes(1); + expect(result).toBe(retryResponse); + }); + + it('should not retry when refreshTokens returns false', async () => { + mockRefreshTokens.mockResolvedValue(false); + const retry = vi.fn(); + + const response = new Response('unauthorized', { status: 401 }); + const result = await ossPostResponse(response, { + url: '/api/test', + options: {}, + retry, + }); + + expect(mockRefreshTokens).toHaveBeenCalledTimes(1); + expect(retry).not.toHaveBeenCalled(); + expect(result).toBe(response); + }); + + it('should redirect to login when refreshTokens returns false', async () => { + mockRefreshTokens.mockResolvedValue(false); + + const response = new Response('unauthorized', { status: 401 }); + await ossPostResponse(response, { + url: '/api/test', + options: {}, + retry: vi.fn(), + }); + + expect(assignSpy).toHaveBeenCalledWith(expect.stringContaining('login')); + }); + + it('should not redirect for non-401 responses', async () => { + const response = new Response('forbidden', { status: 403 }); + await ossPostResponse(response, { + url: '/api/test', + options: {}, + retry: vi.fn(), + }); + + expect(assignSpy).not.toHaveBeenCalled(); + expect(mockRefreshTokens).not.toHaveBeenCalled(); + }); +}); diff --git a/src/lib/utilities/request-from-api.401-retry.test.ts b/src/lib/utilities/request-from-api.401-retry.test.ts new file mode 100644 index 0000000000..4639beed1e --- /dev/null +++ b/src/lib/utilities/request-from-api.401-retry.test.ts @@ -0,0 +1,355 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('./handle-error', () => ({ + handleError: vi.fn(), +})); + +vi.mock('./core-provider', () => ({ + runPreRequest: vi.fn(), + runPostResponse: vi.fn(), +})); + +import { runPostResponse, runPreRequest } from './core-provider'; +import { handleError } from './handle-error'; +import { requestFromAPI } from './request-from-api'; + +const mockRunPreRequest = vi.mocked(runPreRequest); +const mockRunPostResponse = vi.mocked(runPostResponse); + +type MockResponseConfig = { + body?: unknown; + ok?: boolean; + status?: number; + statusText?: string; +}; + +const createMockFetch = (...responses: MockResponseConfig[]) => { + let callIndex = 0; + return vi.fn(async () => { + const config = responses[Math.min(callIndex++, responses.length - 1)]; + return { + json: () => Promise.resolve(config.body ?? {}), + status: config.status ?? 200, + statusText: config.statusText ?? 'OK', + ok: config.ok ?? true, + }; + }) as unknown as typeof fetch; +}; + +describe('requestFromAPI with hooks', () => { + const endpoint = '/api/endpoint'; + + beforeEach(() => { + vi.clearAllMocks(); + mockRunPreRequest.mockImplementation(async (ctx) => ctx); + mockRunPostResponse.mockImplementation(async (res) => res); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('preRequest hook', () => { + it('should call runPreRequest before making the fetch', async () => { + mockRunPreRequest.mockImplementation(async (ctx) => ({ + ...ctx, + options: { + ...ctx.options, + headers: { + ...(ctx.options.headers as Record), + Authorization: 'Bearer injected-token', + }, + }, + })); + + const request = createMockFetch({ + status: 200, + ok: true, + body: { ok: true }, + }); + await requestFromAPI(endpoint, { request }); + + expect(mockRunPreRequest).toHaveBeenCalledTimes(1); + const fetchCallOptions = (request as ReturnType).mock + .calls[0][1] as RequestInit; + const headers = fetchCallOptions.headers as Record; + expect(headers['Authorization']).toBe('Bearer injected-token'); + }); + + it('should not call runPreRequest when not in browser', async () => { + const request = createMockFetch({ status: 200, ok: true, body: {} }); + await requestFromAPI(endpoint, { request, isBrowser: false }); + + expect(mockRunPreRequest).not.toHaveBeenCalled(); + }); + + it('should allow preRequest to modify the URL', async () => { + mockRunPreRequest.mockImplementation(async (ctx) => ({ + ...ctx, + url: ctx.url + '&injected=true', + })); + + const request = createMockFetch({ + status: 200, + ok: true, + body: { ok: true }, + }); + await requestFromAPI(endpoint, { request }); + + const fetchCallUrl = (request as ReturnType).mock + .calls[0][0] as string; + expect(fetchCallUrl).toContain('&injected=true'); + }); + }); + + describe('postResponse hook', () => { + it('should call runPostResponse after fetch', async () => { + const request = createMockFetch({ + status: 200, + ok: true, + body: { data: 'test' }, + }); + await requestFromAPI(endpoint, { request }); + + expect(mockRunPostResponse).toHaveBeenCalledTimes(1); + }); + + it('should pass retry callback that re-runs preRequest pipeline', async () => { + let retryCallbackCaptured: (() => Promise) | undefined; + + mockRunPostResponse.mockImplementation(async (res, ctx) => { + retryCallbackCaptured = ctx.retry; + return res; + }); + + const request = createMockFetch({ status: 200, ok: true, body: {} }); + await requestFromAPI(endpoint, { request }); + + expect(retryCallbackCaptured).toBeDefined(); + expect(mockRunPreRequest).toHaveBeenCalledTimes(1); + + await retryCallbackCaptured!(); + expect(mockRunPreRequest).toHaveBeenCalledTimes(2); + }); + + it('should use retried response when postResponse calls retry', async () => { + const request = createMockFetch( + { status: 401, ok: false, body: { message: 'unauthorized' } }, + { status: 200, ok: true, body: { data: 'refreshed' } }, + ); + + mockRunPostResponse.mockImplementation(async (res, ctx) => { + if (res.status === 401) { + return ctx.retry(); + } + return res; + }); + + const result = await requestFromAPI(endpoint, { request }); + + expect(request).toHaveBeenCalledTimes(2); + expect(result).toEqual({ data: 'refreshed' }); + }); + + it('should not call runPostResponse when not in browser', async () => { + const request = createMockFetch({ + status: 401, + ok: false, + body: { message: 'unauthorized' }, + }); + + await requestFromAPI(endpoint, { request, isBrowser: false }); + + expect(mockRunPostResponse).not.toHaveBeenCalled(); + }); + + it('should pass correct context to postResponse including url and options', async () => { + const request = createMockFetch({ + status: 200, + ok: true, + body: {}, + }); + + mockRunPreRequest.mockImplementation(async (ctx) => ({ + ...ctx, + options: { + ...ctx.options, + headers: { + ...(ctx.options.headers as Record), + Authorization: 'Bearer test', + }, + }, + })); + + await requestFromAPI(endpoint, { request }); + + const postResponseCall = mockRunPostResponse.mock.calls[0]; + const context = postResponseCall[1]; + expect(context.url).toContain(endpoint); + const headers = context.options.headers as Record; + expect(headers['Authorization']).toBe('Bearer test'); + }); + }); + + describe('retry builds fresh headers', () => { + it('should rebuild options from init.options on retry, not reuse stale headers', async () => { + let callCount = 0; + mockRunPreRequest.mockImplementation(async (ctx) => { + callCount++; + const headers = (ctx.options.headers as Record) ?? {}; + headers['Authorization'] = `Bearer token-${callCount}`; + return { + ...ctx, + options: { ...ctx.options, headers }, + }; + }); + + mockRunPostResponse.mockImplementation(async (res, ctx) => { + if (res.status === 401) { + return ctx.retry(); + } + return res; + }); + + const request = createMockFetch( + { status: 401, ok: false, body: { message: 'unauthorized' } }, + { status: 200, ok: true, body: { data: 'ok' } }, + ); + + await requestFromAPI(endpoint, { request }); + + const firstCallHeaders = ( + (request as ReturnType).mock.calls[0][1] as RequestInit + ).headers as Record; + const retryCallHeaders = ( + (request as ReturnType).mock.calls[1][1] as RequestInit + ).headers as Record; + + expect(firstCallHeaders['Authorization']).toBe('Bearer token-1'); + expect(retryCallHeaders['Authorization']).toBe('Bearer token-2'); + }); + }); + + describe('no infinite retry loop', () => { + it('should not call postResponse on the retried response', async () => { + const request = createMockFetch( + { status: 401, ok: false, body: { message: 'unauthorized' } }, + { status: 401, ok: false, body: { message: 'still unauthorized' } }, + ); + + mockRunPostResponse.mockImplementation(async (res, ctx) => { + if (res.status === 401) { + return ctx.retry(); + } + return res; + }); + + await requestFromAPI(endpoint, { request }); + + expect(mockRunPostResponse).toHaveBeenCalledTimes(1); + expect(request).toHaveBeenCalledTimes(2); + }); + }); + + describe('error in hooks', () => { + it('should propagate preRequest errors to handleError', async () => { + mockRunPreRequest.mockRejectedValue(new Error('preRequest exploded')); + + const request = createMockFetch({ status: 200, ok: true, body: {} }); + await requestFromAPI(endpoint, { request }); + + expect(handleError).toHaveBeenCalled(); + expect(request).not.toHaveBeenCalled(); + }); + + it('should propagate postResponse errors to handleError', async () => { + mockRunPostResponse.mockRejectedValue(new Error('postResponse exploded')); + + const request = createMockFetch({ status: 200, ok: true, body: {} }); + await requestFromAPI(endpoint, { request }); + + expect(handleError).toHaveBeenCalled(); + }); + + it('should throw preRequest error when notifyOnError is false', async () => { + mockRunPreRequest.mockRejectedValue(new Error('preRequest exploded')); + + const request = createMockFetch({ status: 200, ok: true, body: {} }); + const error = await requestFromAPI(endpoint, { + request, + notifyOnError: false, + }).catch((e) => e); + + expect(error).toBeInstanceOf(Error); + expect((error as Error).message).toBe('preRequest exploded'); + }); + }); + + describe('long query bypass', () => { + it('should still call postResponse on synthetic 414 response', async () => { + const longQuery = 'x'.repeat(20000); + const request = createMockFetch({ status: 200, ok: true, body: {} }); + + await requestFromAPI(endpoint, { + request, + params: { query: longQuery }, + }); + + expect(request).not.toHaveBeenCalled(); + expect(mockRunPostResponse).toHaveBeenCalledTimes(1); + + const postResponseCall = mockRunPostResponse.mock.calls[0]; + const response = postResponseCall[0]; + expect(response.status).toBe(414); + }); + }); + + describe('settings endpoint', () => { + it('should still run preRequest for settings endpoint', async () => { + const settingsEndpoint = '/api/v1/settings'; + const request = createMockFetch({ + status: 200, + ok: true, + body: { Auth: { Enabled: true } }, + }); + + await requestFromAPI(settingsEndpoint, { request }); + + expect(mockRunPreRequest).toHaveBeenCalledTimes(1); + }); + }); + + describe('error handling unchanged', () => { + it('should call onError when response is not ok', async () => { + const onError = vi.fn(); + const errorBody = { message: 'forbidden' }; + const request = createMockFetch({ + status: 403, + ok: false, + statusText: 'Forbidden', + body: errorBody, + }); + + await requestFromAPI(endpoint, { request, onError }); + + expect(onError).toHaveBeenCalledWith({ + body: errorBody, + status: 403, + statusText: 'Forbidden', + }); + }); + + it('should call handleError when no onError provided', async () => { + const request = createMockFetch({ + status: 500, + ok: false, + statusText: 'Internal Server Error', + body: { message: 'error' }, + }); + + await requestFromAPI(endpoint, { request }); + + expect(handleError).toHaveBeenCalled(); + }); + }); +}); diff --git a/src/lib/utilities/request-from-api.integration.test.ts b/src/lib/utilities/request-from-api.integration.test.ts new file mode 100644 index 0000000000..4f0f6b54d4 --- /dev/null +++ b/src/lib/utilities/request-from-api.integration.test.ts @@ -0,0 +1,180 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('./handle-error', () => ({ + handleError: vi.fn(), +})); + +vi.mock('$lib/stores/auth-user', () => ({ + getAuthUser: vi.fn(), +})); + +vi.mock('./auth-refresh', () => ({ + refreshTokens: vi.fn(), +})); + +import { getAuthUser } from '$lib/stores/auth-user'; + +import { refreshTokens } from './auth-refresh'; +import { initCoreProvider } from './core-provider'; +import { requestFromAPI } from './request-from-api'; + +import { ossPostResponse, ossPreRequest } from './oss-provider.svelte'; + +const mockGetAuthUser = vi.mocked(getAuthUser); +const mockRefreshTokens = vi.mocked(refreshTokens); + +type MockResponseConfig = { + body?: unknown; + ok?: boolean; + status?: number; + statusText?: string; +}; + +const createMockFetch = (...responses: MockResponseConfig[]) => { + let callIndex = 0; + return vi.fn(async () => { + const config = responses[Math.min(callIndex++, responses.length - 1)]; + return { + json: () => Promise.resolve(config.body ?? {}), + status: config.status ?? 200, + statusText: config.statusText ?? 'OK', + ok: config.ok ?? true, + }; + }) as unknown as typeof fetch; +}; + +describe('request-from-api integration with OSS provider', () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.stubGlobal('window', { + location: { + assign: vi.fn(), + origin: 'http://localhost', + href: 'http://localhost/', + }, + }); + mockGetAuthUser.mockReturnValue({ + accessToken: 'test-access-token', + idToken: 'test-id-token', + }); + + initCoreProvider({ + getAccessToken: async () => mockGetAuthUser().accessToken ?? '', + getIdToken: async () => mockGetAuthUser().idToken, + api: { + preRequest: ossPreRequest, + postResponse: ossPostResponse, + }, + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + vi.unstubAllGlobals(); + }); + + it('should attach auth headers and credentials on a successful request', async () => { + const request = createMockFetch({ + status: 200, + ok: true, + body: { workflows: [] }, + }); + + const result = await requestFromAPI('/api/v1/workflows', { request }); + + expect(result).toEqual({ workflows: [] }); + + const fetchOptions = (request as ReturnType).mock + .calls[0][1] as RequestInit; + const headers = fetchOptions.headers as Record; + expect(headers['Authorization']).toBe('Bearer test-access-token'); + expect(headers['Authorization-Extras']).toBe('test-id-token'); + expect(headers['Caller-Type']).toBe('operator'); + expect(fetchOptions.credentials).toBe('include'); + }); + + it('should refresh tokens and retry on 401', async () => { + mockGetAuthUser.mockReturnValue({ accessToken: 'stale-token' }); + + mockRefreshTokens.mockImplementation(async () => { + mockGetAuthUser.mockReturnValue({ accessToken: 'fresh-token' }); + return true; + }); + + const request = createMockFetch( + { status: 401, ok: false, body: { message: 'unauthorized' } }, + { status: 200, ok: true, body: { data: 'success' } }, + ); + + const result = await requestFromAPI('/api/v1/workflows', { request }); + + expect(result).toEqual({ data: 'success' }); + expect(request).toHaveBeenCalledTimes(2); + expect(mockRefreshTokens).toHaveBeenCalledTimes(1); + + const retryHeaders = ( + (request as ReturnType).mock.calls[1][1] as RequestInit + ).headers as Record; + expect(retryHeaders['Authorization']).toBe('Bearer fresh-token'); + }); + + it('should not retry when refresh fails', async () => { + mockRefreshTokens.mockResolvedValue(false); + + const request = createMockFetch({ + status: 401, + ok: false, + body: { message: 'unauthorized' }, + }); + + const onError = vi.fn(); + await requestFromAPI('/api/v1/workflows', { request, onError }); + + expect(request).toHaveBeenCalledTimes(1); + expect(mockRefreshTokens).toHaveBeenCalledTimes(1); + expect(onError).toHaveBeenCalledWith( + expect.objectContaining({ status: 401 }), + ); + }); + + it('should not add auth headers when user has no tokens', async () => { + mockGetAuthUser.mockReturnValue({}); + + const request = createMockFetch({ + status: 200, + ok: true, + body: {}, + }); + + await requestFromAPI('/api/v1/workflows', { request }); + + const headers = ( + (request as ReturnType).mock.calls[0][1] as RequestInit + ).headers as Record; + expect(headers['Authorization']).toBeUndefined(); + expect(headers['Authorization-Extras']).toBeUndefined(); + expect(headers['Caller-Type']).toBe('operator'); + }); + + it('should handle concurrent requests independently', async () => { + let callCount = 0; + const request = vi.fn(async () => { + callCount++; + return { + json: () => Promise.resolve({ call: callCount }), + status: 200, + statusText: 'OK', + ok: true, + }; + }) as unknown as typeof fetch; + + const [result1, result2] = await Promise.all([ + requestFromAPI('/api/v1/endpoint-a', { request }), + requestFromAPI('/api/v1/endpoint-b', { request }), + ]); + + expect(request).toHaveBeenCalledTimes(2); + expect(result1).toBeDefined(); + expect(result2).toBeDefined(); + }); +}); diff --git a/src/lib/utilities/request-from-api.test.ts b/src/lib/utilities/request-from-api.test.ts index 3f7c3fcf40..a0bccc3a0f 100644 --- a/src/lib/utilities/request-from-api.test.ts +++ b/src/lib/utilities/request-from-api.test.ts @@ -26,22 +26,6 @@ vi.mock('./handle-error', () => { return { handleError: vi.fn() }; }); -const withCookie = async (cookie: string, fn: () => void) => { - const currentCookie = document.cookie; - - Object.defineProperty(document, 'cookie', { - writable: true, - value: cookie, - }); - - fn(); - - Object.defineProperty(document, 'cookie', { - writable: true, - value: currentCookie, - }); -}; - describe('isTemporalAPIError', () => { it('should return false if undefined', () => { expect(isTemporalAPIError(undefined)).toBe(false); @@ -61,7 +45,6 @@ describe('requestFromAPI', () => { const responseBody = listWorkflowResponse; const options = { - credentials: 'include', headers: { 'Caller-Type': 'operator', }, @@ -87,66 +70,12 @@ describe('requestFromAPI', () => { expect(request).toHaveBeenCalledWith(endpoint + '?', options); }); - it('should add credentials to options', async () => { + it('should add Caller-Type header to options', async () => { const request = fetchMock(); await requestFromAPI(endpoint, { request, options: {} }); expect(request).toHaveBeenCalledWith(endpoint + '?', options); }); - it('should add csrf cookie to headers', async () => { - const token = 'token'; - - const request = fetchMock(); - await withCookie(`_csrf=${token}`, async () => { - await requestFromAPI(endpoint, { request }); - }); - - expect(request).toHaveBeenCalledWith(endpoint + '?', { - ...options, - headers: { - 'X-CSRF-TOKEN': token, - 'Caller-Type': 'operator', - }, - }); - }); - - it('should not add csrf cookie to headers if not presdent', async () => { - const token = 'token'; - - const request = fetchMock(); - await withCookie(`_nope=${token}`, async () => { - await requestFromAPI(endpoint, { request }); - }); - - expect(request).toHaveBeenCalledWith(endpoint + '?', options); - }); - - it('should not add csrf cookie to headers if not running in the browser', async () => { - const token = 'token'; - - const request = fetchMock(); - await withCookie(`_csrf=${token}`, async () => { - await requestFromAPI(endpoint, { request, isBrowser: false }); - }); - - expect(request).toHaveBeenCalledWith(endpoint + '?', options); - }); - - it('should not add csrf cookie to headers it already exists', async () => { - const token = 'token'; - const headers = { - 'X-CSRF-TOKEN': 'pre-existing', - }; - const opts = { ...options, headers }; - - const request = fetchMock(); - await withCookie(`_csrf=${token}`, async () => { - await requestFromAPI(endpoint, { request, options: opts as RequestInit }); - }); - - expect(request).toHaveBeenCalledWith(endpoint + '?', opts); - }); - it('should create an empty array of headers if not provided', async () => { const request = fetchMock(); await requestFromAPI(endpoint, { request, options: undefined }); diff --git a/src/lib/utilities/request-from-api.ts b/src/lib/utilities/request-from-api.ts index 9c9194b173..ca5eb574b6 100644 --- a/src/lib/utilities/request-from-api.ts +++ b/src/lib/utilities/request-from-api.ts @@ -1,9 +1,12 @@ import { BROWSER } from 'esm-env'; -import { getAuthUser } from '$lib/stores/auth-user'; import type { NetworkError } from '$lib/types/global'; -import { refreshTokens } from './auth-refresh'; +import { + type RequestContext, + runPostResponse, + runPreRequest, +} from './core-provider'; import { handleError as handleRequestError } from './handle-error'; import { isFunction } from './is-function'; import { toURL } from './to-url'; @@ -70,7 +73,6 @@ export const requestFromAPI = async ( onError, isBrowser = BROWSER, } = init; - let { options } = init; let query = new URLSearchParams(); if (params?.entries) { @@ -78,7 +80,6 @@ export const requestFromAPI = async ( if (token) query.set('nextPageToken', token); } else { const nextPageToken = token ? { next_page_token: token } : {}; - // Filter out undefined values before passing to URLSearchParams const paramsWithoutUndefined = Object.fromEntries( Object.entries({ ...params, ...nextPageToken }).filter( ([_, v]) => v !== undefined, @@ -89,46 +90,52 @@ export const requestFromAPI = async ( const url = toURL(endpoint, query); try { - options = withSecurityOptions(options, isBrowser); - if (!endpoint.endsWith('api/v1/settings')) { - options = await withAuth(options, isBrowser); - } + const baseOptions: RequestInit = { + ...init.options, + headers: withCallerType(init.options?.headers), + }; const queryIsTooLong = [...query.values()].some( (value) => value.length > MAX_QUERY_LENGTH, ); - const makeRequest = async () => + const executeRequest = async (ctx: { + url: string; + options: RequestInit; + }) => queryIsTooLong ? new Response( JSON.stringify({ message: 'Query string is too long' }), - { - status: 414, - statusText: 'URI Too Long', - }, + { status: 414, statusText: 'URI Too Long' }, ) - : await request(url, options); - - let response = await makeRequest(); - let { status, statusText } = response; - - // Shouldn't this check the expiry on the jwt and refresh before we make a request instead of - // doing a 401? If we get a 401 and we have done all of our refreshes shouldn't we send the user to the login - // page? Asking for a friend (claude) - if (isBrowser && status === 401) { - const refreshed = await refreshTokens(); - if (refreshed) { - options = withSecurityOptions(init.options, isBrowser); - if (!endpoint.endsWith('api/v1/settings')) { - options = await withAuth(options, isBrowser); - } - response = await makeRequest(); - status = response.status; - statusText = response.statusText; - } - // If refresh failed, let the error flow to handleError() which will redirect to login + : await request(ctx.url, ctx.options); + + let context = { url, options: baseOptions }; + + if (isBrowser) { + context = await runPreRequest(context); + } + + let response = await executeRequest(context); + + if (isBrowser) { + response = await runPostResponse(response, { + ...context, + retry: async () => { + let retryContext: RequestContext = { + url, + options: { + ...init.options, + headers: withCallerType(init.options?.headers), + }, + }; + retryContext = await runPreRequest(retryContext); + return executeRequest(retryContext); + }, + }); } + const { status, statusText } = response; const body = await response.json(); if (!response.ok) { @@ -154,92 +161,10 @@ export const requestFromAPI = async ( } }; -const withSecurityOptions = ( - options: RequestInit | undefined, - isBrowser = BROWSER, -): RequestInit => { - const opts: RequestInit = { credentials: 'include', ...options }; - opts.headers = withCsrf(opts.headers, isBrowser); - return opts; -}; - -const withAuth = async ( - options: RequestInit, - isBrowser = BROWSER, -): Promise => { - const headers: Record = - (options.headers as Record) ?? {}; - - if ((globalThis as Record)?.AccessToken) { - const accessToken = (globalThis as Record) - .AccessToken as () => Promise; - options.headers = await withBearerToken(headers, accessToken, isBrowser); - } else if (getAuthUser().accessToken) { - options.headers = await withBearerToken( - headers, - async () => getAuthUser().accessToken ?? '', - isBrowser, - ); - options.headers = withIdToken( - options.headers as Record, - getAuthUser().idToken ?? '', - isBrowser, - ); - } - - return options; -}; - -const withBearerToken = async ( - headers: Record, - accessToken: () => Promise, - isBrowser = BROWSER, -): Promise> => { - if (!isBrowser) return headers; - - const token = await accessToken(); - if (token) { - headers['Authorization'] = `Bearer ${token}`; - } - - return headers; -}; - -const withIdToken = ( - headers: Record, - idToken: string, - isBrowser = BROWSER, -): Record => { - if (!isBrowser) return headers; - - if (idToken) { - headers['Authorization-Extras'] = idToken; - } - - return headers; -}; - -const withCsrf = ( +const withCallerType = ( headers: HeadersInit | undefined, - isBrowser = BROWSER, ): Record => { const h: Record = (headers as Record) ?? {}; h['Caller-Type'] = 'operator'; - if (!isBrowser) return h; - - const csrfCookie = '_csrf='; - const csrfHeader = 'X-CSRF-TOKEN'; - try { - const cookies = document.cookie.split(';'); - let csrf = cookies.find((c) => c.includes(csrfCookie)); - if (csrf && !h[csrfHeader]) { - csrf = csrf.trim().slice(csrfCookie.length); - h[csrfHeader] = csrf; - } - /* c8 ignore next 4 */ - } catch (error) { - console.error(error); - } - return h; }; diff --git a/src/lib/utilities/request-from-api.with-access-token.test.ts b/src/lib/utilities/request-from-api.with-access-token.test.ts index 8c8dc73105..ac33b69c90 100644 --- a/src/lib/utilities/request-from-api.with-access-token.test.ts +++ b/src/lib/utilities/request-from-api.with-access-token.test.ts @@ -12,7 +12,6 @@ type MockResponse = { }; const options = { - credentials: 'include', headers: { 'Caller-Type': 'operator', }, diff --git a/src/routes/(app)/+layout.ts b/src/routes/(app)/+layout.ts index 8552c57a0b..be6ed173a3 100644 --- a/src/routes/(app)/+layout.ts +++ b/src/routes/(app)/+layout.ts @@ -5,13 +5,9 @@ import type { LayoutData, LayoutLoad } from './$types'; import { fetchCluster, fetchSystemInfo } from '$lib/services/cluster-service'; import { fetchNamespaces } from '$lib/services/namespaces-service'; import { fetchSettings } from '$lib/services/settings-service'; -import { clearAuthUser, getAuthUser, setAuthUser } from '$lib/stores/auth-user'; +import { clearAuthUser, getAuthUser } from '$lib/stores/auth-user'; import type { GetClusterInfoResponse, GetSystemInfoResponse } from '$lib/types'; import type { Settings } from '$lib/types/global'; -import { - cleanAuthUserCookie, - getAuthUserCookie, -} from '$lib/utilities/auth-user-cookie'; import { isAuthorized } from '$lib/utilities/is-authorized'; import { routeForLoginPage } from '$lib/utilities/route-for'; @@ -23,16 +19,9 @@ export const load: LayoutLoad = async function ({ const settings: Settings = await fetchSettings(fetch); if (!settings.auth.enabled) { - cleanAuthUserCookie(); clearAuthUser(); } - const authUser = getAuthUserCookie(); - if (authUser?.accessToken) { - setAuthUser(authUser); - cleanAuthUserCookie(); - } - const user = getAuthUser(); if (!isAuthorized(settings, user)) { diff --git a/tests/integration/oauth-flow.spec.ts b/tests/integration/oauth-flow.spec.ts index 7ef8a9e8b9..23495f64fe 100644 --- a/tests/integration/oauth-flow.spec.ts +++ b/tests/integration/oauth-flow.spec.ts @@ -123,3 +123,33 @@ test('refreshes token and retries request on 401', async ({ expect(authHeaders[0]).toBe('Bearer expired-token'); expect(authHeaders[1]).toBe('Bearer refreshed-token-456'); }); + +test('redirects to login when session is expired and refresh fails', async ({ + page, + baseURL, +}) => { + await mockAuthApis(page); + + await page.route(WORKFLOWS_API, async (route) => { + await route.fulfill({ status: 401, body: 'Unauthorized' }); + }); + + await page.route('**/auth/refresh**', async (route) => { + await route.fulfill({ status: 401, body: 'session expired' }); + }); + + await page.context().addCookies([ + { + name: 'user0', + value: makeUserCookie('expired-token'), + domain: 'localhost', + path: '/', + }, + ]); + + await page.goto(baseURL); + + await page.waitForURL('**/login**'); + + expect(page.url()).toContain('login'); +});