Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions invokeai/backend/model_manager/configs/lora.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,72 @@
from abc import ABC
from pathlib import Path
from typing import (
Any,
Literal,
Self,
)

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator

from invokeai.backend.model_manager.configs.base import (
Config_Base,
)
from invokeai.backend.model_manager.configs.controlnet import ControlAdapterDefaultSettings
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_override_fields,
raise_if_not_dir,
raise_if_not_file,
state_dict_has_any_keys_ending_with,
state_dict_has_any_keys_starting_with,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
Flux2VariantType,
FluxLoRAFormat,
ModelFormat,
ModelType,
ZImageVariantType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
from invokeai.backend.patches.lora_conversions.anima_lora_constants import (
has_cosmos_dit_kohya_keys,
has_cosmos_dit_peft_keys,
)
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control

Check failure on line 38 in invokeai/backend/model_manager/configs/lora.py

View workflow job for this annotation

GitHub Actions / python-checks

Ruff (I001)

invokeai/backend/model_manager/configs/lora.py:1:1: I001 Import block is un-sorted or un-formatted


# Defaults used to compute the effective slider range when one or both bounds
# are unset. These intentionally mirror the frontend's DEFAULT_LORA_WEIGHT_CONFIG
# in invokeai/frontend/web/src/features/controlLayers/store/lorasSlice.ts so that
# bound/weight validation produces the same result whether it runs in the form
# or in this pydantic model.
_DEFAULT_LORA_WEIGHT_SLIDER_MIN = -1.0
_DEFAULT_LORA_WEIGHT_SLIDER_MAX = 2.0


class LoraModelDefaultSettings(BaseModel):
weight: float | None = Field(default=None, ge=-1, le=2, description="Default weight for this model")
weight: float | None = Field(default=None, description="Default weight for this model")
weight_min: float | None = Field(default=None, description="Minimum weight slider value for this model")
weight_max: float | None = Field(default=None, description="Maximum weight slider value for this model")
model_config = ConfigDict(extra="forbid")

@model_validator(mode="after")
def _validate_weight_bounds(self) -> "LoraModelDefaultSettings":
effective_min = self.weight_min if self.weight_min is not None else _DEFAULT_LORA_WEIGHT_SLIDER_MIN
effective_max = self.weight_max if self.weight_max is not None else _DEFAULT_LORA_WEIGHT_SLIDER_MAX
if effective_min >= effective_max:
raise ValueError(
f"effective weight range is invalid: min ({effective_min}) must be less than max ({effective_max})"
)
if self.weight is not None and not (effective_min <= self.weight <= effective_max):
raise ValueError(
f"weight ({self.weight}) must be within the effective range [{effective_min}, {effective_max}]"
)
return self


class LoRA_Config_Base(ABC, BaseModel):
"""Base class for LoRA models."""
Expand Down
5 changes: 5 additions & 0 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,11 @@
},
"lora": {
"weight": "Weight",
"startingWeight": "Starting Weight",
"weightMin": "Min Weight",
"weightMax": "Max Weight",
"weightMinMustBeLessThanMax": "Min weight must be less than max weight",
"weightOutOfRange": "Starting weight must be within the min/max range",
"removeLoRA": "Remove LoRA"
},
"metadata": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ export const zLoRA = z.object({
id: z.string(),
isEnabled: z.boolean(),
model: zModelIdentifierField,
weight: z.number().gte(-10).lte(10),
weight: z.number(),
});
export type LoRA = z.infer<typeof zLoRA>;

Expand Down
27 changes: 20 additions & 7 deletions invokeai/frontend/web/src/features/lora/components/LoRACard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';

const MARKS = [-1, 0, 1, 2];
import { isLoRAModelConfig } from 'services/api/types';

export const LoRACard = memo((props: { id: string }) => {
const selectLoRA = useMemo(() => buildSelectLoRA(props.id), [props.id]);
Expand Down Expand Up @@ -58,6 +57,20 @@ const LoRAContent = memo(({ lora }: { lora: LoRA }) => {
dispatch(loraDeleted({ id: lora.id }));
}, [dispatch, lora.id]);

const loraDefaults = loraConfig && isLoRAModelConfig(loraConfig) ? loraConfig.default_settings : null;
const sliderMin = loraDefaults?.weight_min ?? DEFAULT_LORA_WEIGHT_CONFIG.sliderMin;
const sliderMax = loraDefaults?.weight_max ?? DEFAULT_LORA_WEIGHT_CONFIG.sliderMax;
const numberInputMin = Math.min(sliderMin, DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin);
const numberInputMax = Math.max(sliderMax, DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax);

const marks = useMemo(() => {
if (sliderMin >= sliderMax) {
return [sliderMin, sliderMax];
}
const mid = (sliderMin + sliderMax) / 2;
return [sliderMin, mid, sliderMax];
}, [sliderMin, sliderMax]);

return (
<Card variant="lora">
<CardHeader>
Expand All @@ -82,19 +95,19 @@ const LoRAContent = memo(({ lora }: { lora: LoRA }) => {
<CompositeSlider
value={lora.weight}
onChange={handleChange}
min={DEFAULT_LORA_WEIGHT_CONFIG.sliderMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.sliderMax}
min={sliderMin}
max={sliderMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
marks={MARKS}
marks={marks}
defaultValue={DEFAULT_LORA_WEIGHT_CONFIG.initial}
isDisabled={!lora.isEnabled}
/>
<CompositeNumberInput
value={lora.weight}
onChange={handleChange}
min={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax}
min={numberInputMin}
max={numberInputMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
w={20}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ export const useLoRAModelDefaultSettings = (modelConfig: LoRAModelConfig) => {
isEnabled: !isNil(modelConfig?.default_settings?.weight),
value: modelConfig?.default_settings?.weight ?? DEFAULT_LORA_WEIGHT_CONFIG.initial,
},
weightMin: {
isEnabled: !isNil(modelConfig?.default_settings?.weight_min),
value: modelConfig?.default_settings?.weight_min ?? DEFAULT_LORA_WEIGHT_CONFIG.sliderMin,
},
weightMax: {
isEnabled: !isNil(modelConfig?.default_settings?.weight_max),
value: modelConfig?.default_settings?.weight_max ?? DEFAULT_LORA_WEIGHT_CONFIG.sliderMax,
},
};
}, [modelConfig?.default_settings]);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { DEFAULT_LORA_WEIGHT_CONFIG } from 'features/controlLayers/store/lorasSlice';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { memo, useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useController, useWatch } from 'react-hook-form';
import { useTranslation } from 'react-i18next';

import type { LoRAModelDefaultSettingsFormData } from './LoRAModelDefaultSettings';

const MARKS = [-1, 0, 1, 2];

type DefaultWeight = LoRAModelDefaultSettingsFormData['weight'];

export const DefaultWeight = memo((props: UseControllerProps<LoRAModelDefaultSettingsFormData, 'weight'>) => {
const { field } = useController(props);
const { t } = useTranslation();

const weightMin = useWatch({ control: props.control, name: 'weightMin' });
const weightMax = useWatch({ control: props.control, name: 'weightMax' });

const sliderMin = weightMin?.isEnabled ? weightMin.value : DEFAULT_LORA_WEIGHT_CONFIG.sliderMin;
const sliderMax = weightMax?.isEnabled ? weightMax.value : DEFAULT_LORA_WEIGHT_CONFIG.sliderMax;
const numberInputMin = Math.min(sliderMin, DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin);
const numberInputMax = Math.max(sliderMax, DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax);

const onChange = useCallback(
(v: number) => {
const updatedValue = {
Expand All @@ -40,26 +44,16 @@ export const DefaultWeight = memo((props: UseControllerProps<LoRAModelDefaultSet
<FormControl flexDir="column" gap={2} alignItems="flex-start">
<Flex justifyContent="space-between" w="full">
<InformationalPopover feature="loraWeight">
<FormLabel>{t('lora.weight')}</FormLabel>
<FormLabel>{t('lora.startingWeight')}</FormLabel>
</InformationalPopover>
<SettingToggle control={props.control} name="weight" />
</Flex>

<Flex w="full" gap={4}>
<CompositeSlider
value={value}
min={DEFAULT_LORA_WEIGHT_CONFIG.sliderMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.sliderMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
onChange={onChange}
marks={MARKS}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax}
min={numberInputMin}
max={numberInputMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
onChange={onChange}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { DEFAULT_LORA_WEIGHT_CONFIG } from 'features/controlLayers/store/lorasSlice';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { memo, useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';

import type { LoRAModelDefaultSettingsFormData } from './LoRAModelDefaultSettings';

export const DefaultWeightMax = memo((props: UseControllerProps<LoRAModelDefaultSettingsFormData, 'weightMax'>) => {
const { field } = useController(props);
const { t } = useTranslation();

const onChange = useCallback(
(v: number) => {
field.onChange({ ...field.value, value: v });
},
[field]
);

const value = useMemo(() => field.value.value, [field.value]);
const isDisabled = useMemo(() => !field.value.isEnabled, [field.value]);

return (
<FormControl flexDir="column" gap={2} alignItems="flex-start">
<Flex justifyContent="space-between" w="full">
<FormLabel>{t('lora.weightMax')}</FormLabel>
<SettingToggle control={props.control} name="weightMax" />
</Flex>
<Flex w="full" gap={4}>
<CompositeNumberInput
value={value}
min={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
});

DefaultWeightMax.displayName = 'DefaultWeightMax';
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { DEFAULT_LORA_WEIGHT_CONFIG } from 'features/controlLayers/store/lorasSlice';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { memo, useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';

import type { LoRAModelDefaultSettingsFormData } from './LoRAModelDefaultSettings';

export const DefaultWeightMin = memo((props: UseControllerProps<LoRAModelDefaultSettingsFormData, 'weightMin'>) => {
const { field } = useController(props);
const { t } = useTranslation();

const onChange = useCallback(
(v: number) => {
field.onChange({ ...field.value, value: v });
},
[field]
);

const value = useMemo(() => field.value.value, [field.value]);
const isDisabled = useMemo(() => !field.value.isEnabled, [field.value]);

return (
<FormControl flexDir="column" gap={2} alignItems="flex-start">
<Flex justifyContent="space-between" w="full">
<FormLabel>{t('lora.weightMin')}</FormLabel>
<SettingToggle control={props.control} name="weightMin" />
</Flex>
<Flex w="full" gap={4}>
<CompositeNumberInput
value={value}
min={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
});

DefaultWeightMin.displayName = 'DefaultWeightMin';
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
import { DEFAULT_LORA_WEIGHT_CONFIG } from 'features/controlLayers/store/lorasSlice';
import { useIsModelManagerEnabled } from 'features/modelManagerV2/hooks/useIsModelManagerEnabled';
import { useLoRAModelDefaultSettings } from 'features/modelManagerV2/hooks/useLoRAModelDefaultSettings';
import { DefaultWeight } from 'features/modelManagerV2/subpanels/ModelPanel/LoRAModelDefaultSettings/DefaultWeight';
import { DefaultWeightMax } from 'features/modelManagerV2/subpanels/ModelPanel/LoRAModelDefaultSettings/DefaultWeightMax';
import { DefaultWeightMin } from 'features/modelManagerV2/subpanels/ModelPanel/LoRAModelDefaultSettings/DefaultWeightMin';
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useEffect } from 'react';
Expand All @@ -14,6 +17,8 @@ import type { LoRAModelConfig } from 'services/api/types';

export type LoRAModelDefaultSettingsFormData = {
weight: FormField<number>;
weightMin: FormField<number>;
weightMax: FormField<number>;
};

type Props = {
Expand All @@ -28,7 +33,7 @@ export const LoRAModelDefaultSettings = memo(({ modelConfig }: Props) => {

const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();

const { handleSubmit, control, formState, reset } = useForm<LoRAModelDefaultSettingsFormData>({
const { handleSubmit, control, formState, reset, setError } = useForm<LoRAModelDefaultSettingsFormData>({
defaultValues: defaultSettingsDefaults,
});

Expand All @@ -38,8 +43,40 @@ export const LoRAModelDefaultSettings = memo(({ modelConfig }: Props) => {

const onSubmit = useCallback<SubmitHandler<LoRAModelDefaultSettingsFormData>>(
(data) => {
const weightMin = data.weightMin.isEnabled ? data.weightMin.value : null;
const weightMax = data.weightMax.isEnabled ? data.weightMax.value : null;
const weight = data.weight.isEnabled ? data.weight.value : null;

// Compute effective bounds the same way LoRACard does when rendering the slider,
// so partial bounds (only min or only max) get validated against the fallback
// default for the other side.
const effectiveMin = weightMin ?? DEFAULT_LORA_WEIGHT_CONFIG.sliderMin;
const effectiveMax = weightMax ?? DEFAULT_LORA_WEIGHT_CONFIG.sliderMax;

if (effectiveMin >= effectiveMax) {
setError('weightMin', { type: 'manual', message: t('lora.weightMinMustBeLessThanMax') });
toast({
id: 'DEFAULT_SETTINGS_SAVE_FAILED',
title: t('lora.weightMinMustBeLessThanMax'),
status: 'error',
});
return;
}

if (weight !== null && (weight < effectiveMin || weight > effectiveMax)) {
setError('weight', { type: 'manual', message: t('lora.weightOutOfRange') });
toast({
id: 'DEFAULT_SETTINGS_SAVE_FAILED',
title: t('lora.weightOutOfRange'),
status: 'error',
});
return;
}

const body = {
weight: data.weight.isEnabled ? data.weight.value : null,
weight,
weight_min: weightMin,
weight_max: weightMax,
};

updateModel({
Expand All @@ -65,7 +102,7 @@ export const LoRAModelDefaultSettings = memo(({ modelConfig }: Props) => {
}
});
},
[updateModel, modelConfig.key, t, reset]
[updateModel, modelConfig.key, t, reset, setError]
);

return (
Expand All @@ -86,8 +123,10 @@ export const LoRAModelDefaultSettings = memo(({ modelConfig }: Props) => {
)}
</Flex>

<SimpleGrid columns={2} gap={8}>
<SimpleGrid columns={3} gap={8}>
<DefaultWeight control={control} name="weight" />
<DefaultWeightMin control={control} name="weightMin" />
<DefaultWeightMax control={control} name="weightMax" />
</SimpleGrid>
</>
);
Expand Down
Loading
Loading