Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions invokeai/app/invocations/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,15 +279,28 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
title="Image Collection Primitive",
tags=["primitives", "image", "collection"],
category="primitives",
version="1.0.1",
version="1.0.2",
)
class ImageCollectionInvocation(BaseInvocation):
"""A collection of image primitive values"""

collection: list[ImageField] = InputField(description="The collection of image values")
collection: Optional[list[ImageField]] = InputField(
default=None,
description="An optional image collection to append to",
input=Input.Connection,
title="Collection",
ui_order=0,
)
images: Optional[list[ImageField]] = InputField(
default=None,
description="The images to append to the collection",
input=Input.Direct,
title="Images",
ui_order=1,
)

def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.collection)
return ImageCollectionOutput(collection=[*(self.collection or []), *(self.images or [])])


# endregion
Expand Down
33 changes: 28 additions & 5 deletions invokeai/frontend/web/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -32336,11 +32336,34 @@
}
],
"default": null,
"description": "The collection of image values",
"description": "An optional image collection to append to",
"field_kind": "input",
"input": "any",
"orig_required": true,
"title": "Collection"
"input": "connection",
"orig_default": null,
"orig_required": false,
"title": "Collection",
"ui_order": 0
},
"images": {
"anyOf": [
{
"items": {
"$ref": "#/components/schemas/ImageField"
},
"type": "array"
},
{
"type": "null"
}
],
"default": null,
"description": "The images to append to the collection",
"field_kind": "input",
"input": "direct",
"orig_default": null,
"orig_required": false,
"title": "Images",
"ui_order": 1
},
"type": {
"const": "image_collection",
Expand All @@ -32354,7 +32377,7 @@
"tags": ["primitives", "image", "collection"],
"title": "Image Collection Primitive",
"type": "object",
"version": "1.0.1",
"version": "1.0.2",
"output": {
"$ref": "#/components/schemas/ImageCollectionOutput"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { logger } from 'app/logging/logger';
import { useAppStore } from 'app/store/storeHooks';
import { useGetNodesNeedUpdate } from 'features/nodes/hooks/useGetNodesNeedUpdate';
import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice';
import { selectNodes } from 'features/nodes/store/selectors';
import { selectEdges, selectNodes } from 'features/nodes/store/selectors';
import { NodeUpdateError } from 'features/nodes/types/error';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
Expand All @@ -20,6 +20,7 @@ const useUpdateNodes = () => {

const updateNodes = useCallback(() => {
const nodes = selectNodes(store.getState());
const edges = selectEdges(store.getState());
const templates = $templates.get();

let unableToUpdateCount = 0;
Expand All @@ -35,7 +36,12 @@ const useUpdateNodes = () => {
return;
}
try {
const updatedNode = updateNode(node, template);
const connectedInputNames = new Set(
edges.flatMap((edge) =>
edge.type === 'default' && edge.target === node.id && edge.targetHandle ? [edge.targetHandle] : []
)
);
const updatedNode = updateNode(node, template, { connectedInputNames });
store.dispatch(
nodesChanged([
{ type: 'remove', id: updatedNode.id },
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { updateNode } from 'features/nodes/util/node/nodeUpdate';
import { describe, expect, it } from 'vitest';

const imageCollectionOutput = {
collection: {
fieldKind: 'output',
name: 'collection',
title: 'Collection',
description: 'The output images',
type: {
name: 'ImageField',
cardinality: 'COLLECTION',
batch: false,
},
ui_hidden: false,
},
} satisfies InvocationTemplate['outputs'];

const oldImageCollectionTemplate = {
title: 'Image Collection Primitive',
type: 'image_collection',
version: '1.0.1',
tags: ['primitives', 'image', 'collection'],
description: 'A collection of image primitive values',
outputType: 'image_collection_output',
inputs: {
collection: {
name: 'collection',
title: 'Collection',
required: false,
description: 'The collection of image values',
fieldKind: 'input',
input: 'any',
ui_hidden: false,
type: {
name: 'ImageField',
cardinality: 'COLLECTION',
batch: false,
},
default: undefined,
},
},
outputs: imageCollectionOutput,
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
category: 'primitives',
} satisfies InvocationTemplate;

const currentImageCollectionTemplate = {
...oldImageCollectionTemplate,
version: '1.0.2',
inputs: {
collection: {
name: 'collection',
title: 'Collection',
required: false,
description: 'An optional image collection to append to',
fieldKind: 'input',
input: 'connection',
ui_hidden: false,
type: {
name: 'ImageField',
cardinality: 'COLLECTION',
batch: false,
},
default: [],
},
images: {
name: 'images',
title: 'Images',
required: false,
description: 'The images to append to the collection',
fieldKind: 'input',
input: 'direct',
ui_hidden: false,
type: {
name: 'ImageField',
cardinality: 'COLLECTION',
batch: false,
},
default: undefined,
},
},
} satisfies InvocationTemplate;

describe('updateNode', () => {
it('moves old image_collection direct collection values to the new images field', () => {
const node = buildInvocationNode({ x: 0, y: 0 }, oldImageCollectionTemplate);
const images = [{ image_name: 'first' }, { image_name: 'second' }];
const collectionInput = node.data.inputs.collection;
if (!collectionInput) {
throw new Error('Expected collection input');
}
collectionInput.value = images;

const updated = updateNode(node, currentImageCollectionTemplate);

expect(updated.data.version).toBe('1.0.2');
expect(updated.data.inputs.images?.value).toEqual(images);
expect(updated.data.inputs.collection?.value).toEqual([]);
});

it('does not move old image_collection direct collection values when collection is connected', () => {
const node = buildInvocationNode({ x: 0, y: 0 }, oldImageCollectionTemplate);
const images = [{ image_name: 'stale' }];
const collectionInput = node.data.inputs.collection;
if (!collectionInput) {
throw new Error('Expected collection input');
}
collectionInput.value = images;

const updated = updateNode(node, currentImageCollectionTemplate, {
connectedInputNames: new Set(['collection']),
});

expect(updated.data.inputs.images?.value).toBeUndefined();
expect(updated.data.inputs.collection?.value).toEqual([]);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ import { zParsedSemver } from 'features/nodes/types/semver';

import { buildInvocationNode } from './buildInvocationNode';

type UpdateNodeOptions = {
connectedInputNames?: Set<string>;
};

export const getNeedsUpdate = (data: InvocationNodeData, template: InvocationTemplate): boolean => {
if (data.type !== template.type) {
return true;
Expand All @@ -29,6 +33,26 @@ const getMayUpdateNode = (node: InvocationNode, template: InvocationTemplate): b
return satisfies(node.data.version, `^${templateMajor}`);
};

const migrateImageCollectionInputValues = (node: InvocationNode, options?: UpdateNodeOptions) => {
if (node.data.type !== 'image_collection') {
return;
}

const collection = node.data.inputs.collection;
const images = node.data.inputs.images;
if (!collection || !images || !Array.isArray(collection.value)) {
return;
}
if (Array.isArray(images.value) && images.value.length > 0) {
return;
}

if (!options?.connectedInputNames?.has('collection')) {
images.value = collection.value;
}
collection.value = [];
};

/**
* Updates a node to the latest version of its template:
* - Create a new node data object with the latest version of the template.
Expand All @@ -40,7 +64,11 @@ const getMayUpdateNode = (node: InvocationNode, template: InvocationTemplate): b
* @param template The invocation template to update to.
* @throws {NodeUpdateError} If the node is not an invocation node.
*/
export const updateNode = (node: InvocationNode, template: InvocationTemplate): InvocationNode => {
export const updateNode = (
node: InvocationNode,
template: InvocationTemplate,
options?: UpdateNodeOptions
): InvocationNode => {
const mayUpdate = getMayUpdateNode(node, template);

if (!mayUpdate || node.data.type !== template.type) {
Expand All @@ -56,6 +84,7 @@ export const updateNode = (node: InvocationNode, template: InvocationTemplate):
const clone = deepClone(node);
clone.data.version = template.version;
defaultsDeep(clone, defaults); // mutates!
migrateImageCollectionInputValues(clone, options);

// Remove any fields that are not in the template
clone.data.inputs = pick(clone.data.inputs, keys(defaults.data.inputs));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
// This node needs to be updated, based on comparison of its version to the template version
if (getNeedsUpdate(node.data, template)) {
try {
const updatedNode = updateNode(node, template);
const connectedInputNames = new Set(
edges.flatMap((edge) =>
edge.type === 'default' && edge.target === node.id && edge.targetHandle ? [edge.targetHandle] : []
)
);
const updatedNode = updateNode(node, template, { connectedInputNames });
node.data = updatedNode.data;
} catch {
const message = t('nodes.unableToUpdateNode', {
Expand Down
8 changes: 7 additions & 1 deletion invokeai/frontend/web/src/services/api/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13760,10 +13760,16 @@ export type components = {
use_cache?: boolean;
/**
* Collection
* @description The collection of image values
* @description An optional image collection to append to
* @default null
*/
collection?: components["schemas"]["ImageField"][] | null;
/**
* Images
* @description The images to append to the collection
* @default null
*/
images?: components["schemas"]["ImageField"][] | null;
/**
* type
* @default image_collection
Expand Down
32 changes: 32 additions & 0 deletions tests/app/invocations/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from PIL import Image, ImageFilter

from invokeai.app.invocations.image import ImageField, OklabUnsharpMaskInvocation, OklchImageHueAdjustmentInvocation
from invokeai.app.invocations.primitives import ImageCollectionInvocation
from invokeai.backend.image_util.color_conversion import (
linear_srgb_from_oklab,
linear_srgb_from_oklch,
Expand Down Expand Up @@ -47,6 +48,37 @@ def _max_abs_diff_uint8(left: Image.Image, right: Image.Image) -> int:
return int(numpy.abs(left_arr - right_arr).max())


def test_image_collection_invocation_preserves_existing_collection_values() -> None:
images = [ImageField(image_name="first"), ImageField(image_name="second")]

output = ImageCollectionInvocation(collection=images).invoke(MagicMock())

assert output.collection == images


def test_image_collection_invocation_appends_direct_images_after_chained_collection() -> None:
chained_images = [ImageField(image_name="chained")]
direct_images = [ImageField(image_name="direct_1"), ImageField(image_name="direct_2")]

output = ImageCollectionInvocation(collection=chained_images, images=direct_images).invoke(MagicMock())

assert output.collection == [*chained_images, *direct_images]


def test_image_collection_invocation_supports_empty_direct_images() -> None:
chained_images = [ImageField(image_name="chained")]

output = ImageCollectionInvocation(collection=chained_images, images=None).invoke(MagicMock())

assert output.collection == chained_images


def test_image_collection_invocation_outputs_empty_collection_when_inputs_are_empty() -> None:
output = ImageCollectionInvocation(collection=None, images=None).invoke(MagicMock())

assert output.collection == []


def test_oklab_unsharp_mask_invocation_preserves_alpha_and_sharpens_lightness_only() -> None:
input_image = Image.new("RGBA", (3, 1))
input_image.putdata(
Expand Down
Loading