Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 85 additions & 69 deletions src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,12 @@ export interface ServerOptions {
* execute the subscription operation
* upon.
*/
execute: (args: ExecutionArgs) => Promise<ExecutionResult> | ExecutionResult;
execute: (
args: ExecutionArgs,
) =>
| Promise<ExecutionResult>
| ExecutionResult
| AsyncIterableIterator<ExecutionResult>;
/**
* Is the `subscribe` function
* from GraphQL which is used to
Expand Down Expand Up @@ -457,59 +462,65 @@ export function createServer(
});
}

// perform
if (operationAST.operation === 'subscription') {
const subscriptionOrResult = await subscribe(execArgs);
if (isAsyncIterable(subscriptionOrResult)) {
// iterable subscriptions are distinct on ID
if (ctx.subscriptions[message.id]) {
return ctx.socket.close(
4409,
`Subscriber for ${message.id} already exists`,
);
}
ctx.subscriptions[message.id] = subscriptionOrResult;

try {
for await (let result of subscriptionOrResult) {
// use the root formater first
if (formatExecutionResult) {
result = await formatExecutionResult(ctx, result);
}
// then use the subscription specific formatter
if (onSubscribeFormatter) {
result = await onSubscribeFormatter(ctx, result);
}
await sendMessage<MessageType.Next>(ctx, {
id: message.id,
type: MessageType.Next,
payload: result,
});
}
const asyncIterableHandler = async (
asyncIterable: AsyncIterableIterator<ExecutionResult>,
) => {
// iterable subscriptions are distinct on ID
if (ctx.subscriptions[message.id]) {
return ctx.socket.close(
4409,
`Subscriber for ${message.id} already exists`,
);
}
ctx.subscriptions[message.id] = asyncIterable;

const completeMessage: CompleteMessage = {
id: message.id,
type: MessageType.Complete,
};
await sendMessage<MessageType.Complete>(ctx, completeMessage);
if (onComplete) {
onComplete(ctx, completeMessage);
try {
for await (let result of asyncIterable) {
// use the root formater first
if (formatExecutionResult) {
result = await formatExecutionResult(ctx, result);
}
} catch (err) {
await sendMessage<MessageType.Error>(ctx, {
// then use the subscription specific formatter
if (onSubscribeFormatter) {
result = await onSubscribeFormatter(ctx, result);
}
await sendMessage<MessageType.Next>(ctx, {
id: message.id,
type: MessageType.Error,
payload: [
new GraphQLError(
err instanceof Error
? err.message
: new Error(err).message,
),
],
type: MessageType.Next,
payload: result,
});
} finally {
delete ctx.subscriptions[message.id];
}

const completeMessage: CompleteMessage = {
id: message.id,
type: MessageType.Complete,
};
await sendMessage<MessageType.Complete>(ctx, completeMessage);
if (onComplete) {
onComplete(ctx, completeMessage);
}
} catch (err) {
await sendMessage<MessageType.Error>(ctx, {
id: message.id,
type: MessageType.Error,
payload: [
new GraphQLError(
err instanceof Error
? err.message
: new Error(err).message,
),
],
});
} finally {
delete ctx.subscriptions[message.id];
}
};

// perform
if (operationAST.operation === 'subscription') {
const subscriptionOrResult = await subscribe(execArgs);
if (isAsyncIterable(subscriptionOrResult)) {
await asyncIterableHandler(subscriptionOrResult);
} else {
let result = subscriptionOrResult;
// use the root formater first
Expand Down Expand Up @@ -539,27 +550,32 @@ export function createServer(
// operationAST.operation === 'query' || 'mutation'

let result = await execute(execArgs);
// use the root formater first
if (formatExecutionResult) {
result = await formatExecutionResult(ctx, result);
}
// then use the subscription specific formatter
if (onSubscribeFormatter) {
result = await onSubscribeFormatter(ctx, result);
}
await sendMessage<MessageType.Next>(ctx, {
id: message.id,
type: MessageType.Next,
payload: result,
});

const completeMessage: CompleteMessage = {
id: message.id,
type: MessageType.Complete,
};
await sendMessage<MessageType.Complete>(ctx, completeMessage);
if (onComplete) {
onComplete(ctx, completeMessage);
if (isAsyncIterable(result)) {
await asyncIterableHandler(result);
} else {
// use the root formater first
if (formatExecutionResult) {
result = await formatExecutionResult(ctx, result);
}
// then use the subscription specific formatter
if (onSubscribeFormatter) {
result = await onSubscribeFormatter(ctx, result);
}
await sendMessage<MessageType.Next>(ctx, {
id: message.id,
type: MessageType.Next,
payload: result,
});

const completeMessage: CompleteMessage = {
id: message.id,
type: MessageType.Complete,
};
await sendMessage<MessageType.Complete>(ctx, completeMessage);
if (onComplete) {
onComplete(ctx, completeMessage);
}
}
}
break;
Expand Down
81 changes: 81 additions & 0 deletions src/tests/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,87 @@ describe('Subscribe', () => {
await wait(20);
});

it('should execute a query operation with custom execute that returns a AsyncIterableIterator, "next" the results and then "complete"', async () => {
expect.assertions(5);

await makeServer({
schema,
execute: async function* () {
for (const value of ['Hi', 'Hello', 'Sup']) {
yield {
data: {
getValue: value,
},
};
}
},
});

const client = new WebSocket(url, GRAPHQL_TRANSPORT_WS_PROTOCOL);
client.onopen = () => {
client.send(
stringifyMessage<MessageType.ConnectionInit>({
type: MessageType.ConnectionInit,
}),
);
};

let receivedNextCount = 0;
client.onmessage = ({ data }) => {
const message = parseMessage(data);
switch (message.type) {
case MessageType.ConnectionAck:
client.send(
stringifyMessage<MessageType.Subscribe>({
id: '1',
type: MessageType.Subscribe,
payload: {
operationName: 'TestString',
query: `query TestString {
getValue
}`,
variables: {},
},
}),
);
break;
case MessageType.Next:
receivedNextCount++;
if (receivedNextCount === 1) {
expect(message).toEqual({
id: '1',
type: MessageType.Next,
payload: { data: { getValue: 'Hi' } },
});
} else if (receivedNextCount === 2) {
expect(message).toEqual({
id: '1',
type: MessageType.Next,
payload: { data: { getValue: 'Hello' } },
});
} else if (receivedNextCount === 3) {
expect(message).toEqual({
id: '1',
type: MessageType.Next,
payload: { data: { getValue: 'Sup' } },
});
}
break;
case MessageType.Complete:
expect(receivedNextCount).toEqual(3);
expect(message).toEqual({
id: '1',
type: MessageType.Complete,
});
break;
default:
fail(`Not supposed to receive a message of type ${message.type}`);
}
};

await wait(20);
});

it('should execute the query of `DocumentNode` type, "next" the result and then "complete"', async () => {
expect.assertions(3);

Expand Down