Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { selectEdges, selectNodes } from 'features/nodes/store/selectors';
import type { Templates } from 'features/nodes/store/types';
import { resolveConnectorSource, resolveConnectorSourceFieldType } from 'features/nodes/store/util/connectorTopology';
import { resolveConnectorDisplayFieldType, resolveConnectorSource } from 'features/nodes/store/util/connectorTopology';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { isConnectorNode } from 'features/nodes/types/invocation';

Expand Down Expand Up @@ -35,7 +35,7 @@ export const buildSelectEdgeColor = (
}

const sourceType = isConnectorNode(sourceNode)
? resolveConnectorSourceFieldType(sourceNode.id, nodes, edges, templates)
? resolveConnectorDisplayFieldType(sourceNode.id, nodes, edges, templates)
: templates[sourceNode.data.type]?.outputs[sourceHandleId]?.type;

return sourceType ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
Expand Down
Original file line number Diff line number Diff line change
@@ -1,138 +1,190 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Icon } from '@invoke-ai/ui-library';
import type { Node, NodeProps } from '@xyflow/react';
import { Handle, Position } from '@xyflow/react';
import { Box, Icon, Tooltip } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { Handle, type HandleType, type Node, type NodeProps, Position } from '@xyflow/react';
import { useAppSelector } from 'app/store/storeHooks';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import {
NODE_IO_HANDLE_HITBOX_INPUT,
NODE_IO_HANDLE_HITBOX_OUTPUT,
NODE_IO_HANDLE_INNER_SX,
} from 'features/nodes/components/flow/nodes/common/nodeIOHandle';
import NonInvocationNodeWrapper from 'features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper';
import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology';
import { NO_DRAG_CLASS } from 'features/nodes/types/constants';
import {
useConnectionErrorTKey,
useIsConnectionInProgress,
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectEdges, selectNodes } from 'features/nodes/store/selectors';
import {
CONNECTOR_INPUT_HANDLE,
CONNECTOR_OUTPUT_HANDLE,
resolveConnectorDisplayFieldType,
} from 'features/nodes/store/util/connectorTopology';
import { HANDLE_TOOLTIP_OPEN_DELAY, NO_DRAG_CLASS } from 'features/nodes/types/constants';
import type { FieldType } from 'features/nodes/types/field';
import { isModelFieldType } from 'features/nodes/types/field';
import type { ConnectorNodeData } from 'features/nodes/types/invocation';
import type { CSSProperties } from 'react';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiDotOutlineFill } from 'react-icons/pi';

const CONNECTOR_NODE_SIZE = 35;
const CONNECTOR_HANDLE_SIZE = 12;
const CONNECTOR_HANDLE_OFFSET = -CONNECTOR_HANDLE_SIZE / 2;

const handleVisualSx = {
position: 'absolute',
/** AnyField-shaped fallback for tooltips when display type is unknown; same shape as connector stubs in `useConnection`. */
const CONNECTOR_FALLBACK_FIELD_TYPE = {
name: 'AnyField',
cardinality: 'SINGLE',
batch: false,
} as const satisfies FieldType;

const CONNECTOR_HANDLE_VERTICAL_ALIGN: CSSProperties = {
top: '50%',
w: 4,
h: 4,
borderRadius: 'full',
borderWidth: 2,
borderColor: 'base.900',
bg: 'base.100',
pointerEvents: 'none',
} satisfies SystemStyleObject;

const inputHandleVisualSx = {
...handleVisualSx,
left: 0,
transform: 'translate(-50%, -50%)',
} satisfies SystemStyleObject;

const outputHandleVisualSx = {
...handleVisualSx,
right: 0,
transform: 'translate(50%, -50%)',
} satisfies SystemStyleObject;

const connectorSx = {
'& .connector-border': {
pointerEvents: 'none',
position: 'absolute',
inset: 0,
borderRadius: 'inherit',
shadow: '0 0 0 1px var(--invoke-colors-base-500)',
},
_hover: {
'& .connector-border': {
shadow: '0 0 0 1px var(--invoke-colors-blue-300)',
},
'&[data-is-selected="true"] .connector-border': {
shadow: '0 0 0 2px var(--invoke-colors-blue-300)',
},
},
'&[data-is-selected="true"] .connector-border': {
shadow: '0 0 0 2px var(--invoke-colors-blue-300)',
},
} satisfies SystemStyleObject;

const handleStyles = {
position: 'absolute',
width: `${CONNECTOR_HANDLE_SIZE}px`,
height: `${CONNECTOR_HANDLE_SIZE}px`,
top: `calc(50% + ${CONNECTOR_HANDLE_OFFSET}px)`,
zIndex: 1,
background: 'none',
border: 'none',
} satisfies CSSProperties;
transform: 'translateY(-50%)',
};

const inputHandleStyles = {
...handleStyles,
left: 0,
transform: 'none',
const CONNECTOR_HANDLE_INPUT_STYLE = {
...NODE_IO_HANDLE_HITBOX_INPUT,
...CONNECTOR_HANDLE_VERTICAL_ALIGN,
} satisfies CSSProperties;

const outputHandleStyles = {
...handleStyles,
right: 0,
transform: 'none',
const CONNECTOR_HANDLE_OUTPUT_STYLE = {
...NODE_IO_HANDLE_HITBOX_OUTPUT,
...CONNECTOR_HANDLE_VERTICAL_ALIGN,
} satisfies CSSProperties;

type PassthroughHandleProps = {
nodeId: string;
rfHandleType: HandleType;
handleId: string;
position: Position;
hitboxStyle: CSSProperties;
displayFieldType: FieldType | null;
fieldColor: string;
fieldTypeName: string;
};

const ConnectorPassthroughHandle = memo(
({
nodeId,
rfHandleType,
handleId,
position,
hitboxStyle,
displayFieldType,
fieldColor,
fieldTypeName,
}: PassthroughHandleProps) => {
const { t } = useTranslation();
const isConnectionInProgress = useIsConnectionInProgress();
const isConnectionStartField = useIsConnectionStartField(nodeId, handleId, rfHandleType);
const connectionError = useConnectionErrorTKey(nodeId, handleId, rfHandleType);

const tooltipLabel = useMemo(() => {
if (isConnectionInProgress && connectionError !== null) {
return t(connectionError);
}
return fieldTypeName;
}, [connectionError, fieldTypeName, isConnectionInProgress, t]);

const innerProps = useMemo(() => {
const shape =
displayFieldType !== null
? {
'data-cardinality': displayFieldType.cardinality,
'data-is-batch-field': displayFieldType.batch,
'data-is-model-field': isModelFieldType(displayFieldType),
}
: {
'data-cardinality': 'SINGLE' as const,
'data-is-batch-field': false,
'data-is-model-field': false,
};

return {
sx: NODE_IO_HANDLE_INNER_SX,
...shape,
'data-is-connection-in-progress': isConnectionInProgress,
'data-is-connection-start-field': isConnectionInProgress ? isConnectionStartField : false,
'data-is-connection-valid': isConnectionInProgress ? connectionError === null : false,
};
}, [connectionError, displayFieldType, isConnectionInProgress, isConnectionStartField]);

const innerBackgroundColor =
displayFieldType !== null && displayFieldType.cardinality !== 'SINGLE' ? 'base.900' : fieldColor;

return (
<Tooltip label={tooltipLabel} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<Handle className={NO_DRAG_CLASS} type={rfHandleType} id={handleId} position={position} style={hitboxStyle}>
<Box {...innerProps} backgroundColor={innerBackgroundColor} borderColor={fieldColor} />
</Handle>
</Tooltip>
);
}
);

ConnectorPassthroughHandle.displayName = 'ConnectorPassthroughHandle';

const ConnectorNode = ({ id, selected }: NodeProps<Node<ConnectorNodeData>>) => {
const templates = useStore($templates);
const nodes = useAppSelector(selectNodes);
const edges = useAppSelector(selectEdges);

const displayFieldType = useMemo(
() => resolveConnectorDisplayFieldType(id, nodes, edges, templates),
[id, nodes, edges, templates]
);

const fieldColor = useMemo(() => getFieldColor(displayFieldType), [displayFieldType]);

const fieldTypeLabel = useFieldTypeName(displayFieldType ?? CONNECTOR_FALLBACK_FIELD_TYPE);

return (
<NonInvocationNodeWrapper
nodeId={id}
selected={selected}
width={CONNECTOR_NODE_SIZE}
borderRadius="full"
withChrome={false}
>
<NonInvocationNodeWrapper nodeId={id} selected={selected} width={CONNECTOR_NODE_SIZE} borderRadius="full">
<Box
data-connector-node-context-menu="true"
data-connector-node-id={id}
data-is-selected={selected}
position="relative"
w={CONNECTOR_NODE_SIZE}
h={CONNECTOR_NODE_SIZE}
display="flex"
alignItems="center"
justifyContent="center"
sx={connectorSx}
>
<Handle
className={NO_DRAG_CLASS}
type="target"
id={CONNECTOR_INPUT_HANDLE}
<ConnectorPassthroughHandle
nodeId={id}
rfHandleType="target"
handleId={CONNECTOR_INPUT_HANDLE}
position={Position.Left}
style={inputHandleStyles}
>
<Box sx={inputHandleVisualSx} />
</Handle>
hitboxStyle={CONNECTOR_HANDLE_INPUT_STYLE}
displayFieldType={displayFieldType}
fieldColor={fieldColor}
fieldTypeName={fieldTypeLabel}
/>
<Box
layerStyle="nodeBody"
position="relative"
w={CONNECTOR_NODE_SIZE}
h={CONNECTOR_NODE_SIZE}
display="flex"
alignItems="center"
justifyContent="center"
borderRadius="full"
bg={selected ? 'base.650' : 'base.700'}
>
<Box className="connector-border" />
<Icon as={PiDotOutlineFill} boxSize={5} color={selected ? 'base.50' : 'base.100'} />
<Icon as={PiDotOutlineFill} boxSize={5} color="base.300" />
</Box>
<Handle
className={NO_DRAG_CLASS}
type="source"
id={CONNECTOR_OUTPUT_HANDLE}
<ConnectorPassthroughHandle
nodeId={id}
rfHandleType="source"
handleId={CONNECTOR_OUTPUT_HANDLE}
position={Position.Right}
style={outputHandleStyles}
>
<Box sx={outputHandleVisualSx} />
</Handle>
hitboxStyle={CONNECTOR_HANDLE_OUTPUT_STYLE}
displayFieldType={displayFieldType}
fieldColor={fieldColor}
fieldTypeName={fieldTypeLabel}
/>
</Box>
</NonInvocationNodeWrapper>
);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Tooltip } from '@invoke-ai/ui-library';
import { Handle, Position } from '@xyflow/react';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import {
NODE_IO_HANDLE_HITBOX_INPUT,
NODE_IO_HANDLE_INNER_SX,
} from 'features/nodes/components/flow/nodes/common/nodeIOHandle';
import {
useConnectionErrorTKey,
useIsConnectionInProgress,
Expand All @@ -12,7 +15,6 @@ import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import { isModelFieldType } from 'features/nodes/types/field';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';

Expand All @@ -21,46 +23,6 @@ type Props = {
fieldName: string;
};

const sx = {
position: 'relative',
width: 'full',
height: 'full',
borderStyle: 'solid',
borderWidth: 4,
pointerEvents: 'none',
'&[data-cardinality="SINGLE"]': {
borderWidth: 0,
},
borderRadius: '100%',
'&[data-is-model-field="true"], &[data-is-batch-field="true"]': {
borderRadius: 4,
},
'&[data-is-batch-field="true"]': {
transform: 'rotate(45deg)',
},
'&[data-is-connection-in-progress="true"][data-is-connection-start-field="false"][data-is-connection-valid="false"]':
{
filter: 'opacity(0.4) grayscale(0.7)',
cursor: 'not-allowed',
},
'&[data-is-connection-in-progress="true"][data-is-connection-start-field="true"][data-is-connection-valid="false"]': {
cursor: 'grab',
},
'&[data-is-connection-in-progress="false"] &[data-is-connection-valid="true"]': {
cursor: 'crosshair',
},
} satisfies SystemStyleObject;

const handleStyles = {
position: 'absolute',
width: '1rem',
height: '1rem',
zIndex: 1,
background: 'none',
border: 'none',
insetInlineStart: '-0.5rem',
} satisfies CSSProperties;

export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
const fieldTemplate = useInputFieldTemplateOrThrow(fieldName);
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
Expand Down Expand Up @@ -107,9 +69,9 @@ type HandleCommonProps = {
const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
return (
<Tooltip label={fieldTypeName} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={NODE_IO_HANDLE_HITBOX_INPUT}>
<Box
sx={sx}
sx={NODE_IO_HANDLE_INNER_SX}
data-cardinality={fieldTemplate.type.cardinality}
data-is-batch-field={fieldTemplate.type.batch}
data-is-model-field={isModelField}
Expand Down Expand Up @@ -140,9 +102,9 @@ const ConnectionInProgressHandle = memo(

return (
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={NODE_IO_HANDLE_HITBOX_INPUT}>
<Box
sx={sx}
sx={NODE_IO_HANDLE_INNER_SX}
data-cardinality={fieldTemplate.type.cardinality}
data-is-batch-field={fieldTemplate.type.batch}
data-is-model-field={isModelField}
Expand Down
Loading
Loading