Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions invokeai/backend/model_manager/configs/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Self,
)

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator

from invokeai.backend.model_manager.configs.base import (
Config_Base,
Expand Down Expand Up @@ -39,9 +39,17 @@


class LoraModelDefaultSettings(BaseModel):
weight: float | None = Field(default=None, ge=-1, le=2, description="Default weight for this model")
weight: float | None = Field(default=None, description="Default weight for this model")
weight_min: float | None = Field(default=None, description="Minimum weight slider value for this model")
weight_max: float | None = Field(default=None, description="Maximum weight slider value for this model")
model_config = ConfigDict(extra="forbid")

@model_validator(mode="after")
def _validate_weight_bounds(self) -> "LoraModelDefaultSettings":
if self.weight_min is not None and self.weight_max is not None and self.weight_min >= self.weight_max:
raise ValueError("weight_min must be less than weight_max")
return self


class LoRA_Config_Base(ABC, BaseModel):
"""Base class for LoRA models."""
Expand Down
4 changes: 4 additions & 0 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,10 @@
},
"lora": {
"weight": "Weight",
"startingWeight": "Starting Weight",
"weightMin": "Min Weight",
"weightMax": "Max Weight",
"weightMinMustBeLessThanMax": "Min weight must be less than max weight",
"removeLoRA": "Remove LoRA"
},
"metadata": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ export const zLoRA = z.object({
id: z.string(),
isEnabled: z.boolean(),
model: zModelIdentifierField,
weight: z.number().gte(-10).lte(10),
weight: z.number(),
});
export type LoRA = z.infer<typeof zLoRA>;

Expand Down
27 changes: 20 additions & 7 deletions invokeai/frontend/web/src/features/lora/components/LoRACard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';

const MARKS = [-1, 0, 1, 2];
import { isLoRAModelConfig } from 'services/api/types';

export const LoRACard = memo((props: { id: string }) => {
const selectLoRA = useMemo(() => buildSelectLoRA(props.id), [props.id]);
Expand Down Expand Up @@ -58,6 +57,20 @@ const LoRAContent = memo(({ lora }: { lora: LoRA }) => {
dispatch(loraDeleted({ id: lora.id }));
}, [dispatch, lora.id]);

const loraDefaults = loraConfig && isLoRAModelConfig(loraConfig) ? loraConfig.default_settings : null;
const sliderMin = loraDefaults?.weight_min ?? DEFAULT_LORA_WEIGHT_CONFIG.sliderMin;
const sliderMax = loraDefaults?.weight_max ?? DEFAULT_LORA_WEIGHT_CONFIG.sliderMax;
const numberInputMin = Math.min(sliderMin, DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin);
const numberInputMax = Math.max(sliderMax, DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax);

const marks = useMemo(() => {
if (sliderMin >= sliderMax) {
return [sliderMin, sliderMax];
}
const mid = (sliderMin + sliderMax) / 2;
return [sliderMin, mid, sliderMax];
}, [sliderMin, sliderMax]);

return (
<Card variant="lora">
<CardHeader>
Expand All @@ -82,19 +95,19 @@ const LoRAContent = memo(({ lora }: { lora: LoRA }) => {
<CompositeSlider
value={lora.weight}
onChange={handleChange}
min={DEFAULT_LORA_WEIGHT_CONFIG.sliderMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.sliderMax}
min={sliderMin}
max={sliderMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
marks={MARKS}
marks={marks}
defaultValue={DEFAULT_LORA_WEIGHT_CONFIG.initial}
isDisabled={!lora.isEnabled}
/>
<CompositeNumberInput
value={lora.weight}
onChange={handleChange}
min={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax}
min={numberInputMin}
max={numberInputMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
w={20}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ export const useLoRAModelDefaultSettings = (modelConfig: LoRAModelConfig) => {
isEnabled: !isNil(modelConfig?.default_settings?.weight),
value: modelConfig?.default_settings?.weight ?? DEFAULT_LORA_WEIGHT_CONFIG.initial,
},
weightMin: {
isEnabled: !isNil(modelConfig?.default_settings?.weight_min),
value: modelConfig?.default_settings?.weight_min ?? DEFAULT_LORA_WEIGHT_CONFIG.sliderMin,
},
weightMax: {
isEnabled: !isNil(modelConfig?.default_settings?.weight_max),
value: modelConfig?.default_settings?.weight_max ?? DEFAULT_LORA_WEIGHT_CONFIG.sliderMax,
},
};
}, [modelConfig?.default_settings]);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { DEFAULT_LORA_WEIGHT_CONFIG } from 'features/controlLayers/store/lorasSlice';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { memo, useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useController, useWatch } from 'react-hook-form';
import { useTranslation } from 'react-i18next';

import type { LoRAModelDefaultSettingsFormData } from './LoRAModelDefaultSettings';

const MARKS = [-1, 0, 1, 2];

type DefaultWeight = LoRAModelDefaultSettingsFormData['weight'];

export const DefaultWeight = memo((props: UseControllerProps<LoRAModelDefaultSettingsFormData, 'weight'>) => {
const { field } = useController(props);
const { t } = useTranslation();

const weightMin = useWatch({ control: props.control, name: 'weightMin' });
const weightMax = useWatch({ control: props.control, name: 'weightMax' });

const sliderMin = weightMin?.isEnabled ? weightMin.value : DEFAULT_LORA_WEIGHT_CONFIG.sliderMin;
const sliderMax = weightMax?.isEnabled ? weightMax.value : DEFAULT_LORA_WEIGHT_CONFIG.sliderMax;
const numberInputMin = Math.min(sliderMin, DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin);
const numberInputMax = Math.max(sliderMax, DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax);

const onChange = useCallback(
(v: number) => {
const updatedValue = {
Expand All @@ -40,26 +44,16 @@ export const DefaultWeight = memo((props: UseControllerProps<LoRAModelDefaultSet
<FormControl flexDir="column" gap={2} alignItems="flex-start">
<Flex justifyContent="space-between" w="full">
<InformationalPopover feature="loraWeight">
<FormLabel>{t('lora.weight')}</FormLabel>
<FormLabel>{t('lora.startingWeight')}</FormLabel>
</InformationalPopover>
<SettingToggle control={props.control} name="weight" />
</Flex>

<Flex w="full" gap={4}>
<CompositeSlider
value={value}
min={DEFAULT_LORA_WEIGHT_CONFIG.sliderMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.sliderMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
onChange={onChange}
marks={MARKS}
isDisabled={isDisabled}
/>
<CompositeNumberInput
value={value}
min={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax}
min={numberInputMin}
max={numberInputMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
onChange={onChange}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { DEFAULT_LORA_WEIGHT_CONFIG } from 'features/controlLayers/store/lorasSlice';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { memo, useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';

import type { LoRAModelDefaultSettingsFormData } from './LoRAModelDefaultSettings';

export const DefaultWeightMax = memo((props: UseControllerProps<LoRAModelDefaultSettingsFormData, 'weightMax'>) => {
const { field } = useController(props);
const { t } = useTranslation();

const onChange = useCallback(
(v: number) => {
field.onChange({ ...field.value, value: v });
},
[field]
);

const value = useMemo(() => field.value.value, [field.value]);
const isDisabled = useMemo(() => !field.value.isEnabled, [field.value]);

return (
<FormControl flexDir="column" gap={2} alignItems="flex-start">
<Flex justifyContent="space-between" w="full">
<FormLabel>{t('lora.weightMax')}</FormLabel>
<SettingToggle control={props.control} name="weightMax" />
</Flex>
<Flex w="full" gap={4}>
<CompositeNumberInput
value={value}
min={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
});

DefaultWeightMax.displayName = 'DefaultWeightMax';
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { DEFAULT_LORA_WEIGHT_CONFIG } from 'features/controlLayers/store/lorasSlice';
import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle';
import { memo, useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useTranslation } from 'react-i18next';

import type { LoRAModelDefaultSettingsFormData } from './LoRAModelDefaultSettings';

export const DefaultWeightMin = memo((props: UseControllerProps<LoRAModelDefaultSettingsFormData, 'weightMin'>) => {
const { field } = useController(props);
const { t } = useTranslation();

const onChange = useCallback(
(v: number) => {
field.onChange({ ...field.value, value: v });
},
[field]
);

const value = useMemo(() => field.value.value, [field.value]);
const isDisabled = useMemo(() => !field.value.isEnabled, [field.value]);

return (
<FormControl flexDir="column" gap={2} alignItems="flex-start">
<Flex justifyContent="space-between" w="full">
<FormLabel>{t('lora.weightMin')}</FormLabel>
<SettingToggle control={props.control} name="weightMin" />
</Flex>
<Flex w="full" gap={4}>
<CompositeNumberInput
value={value}
min={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMin}
max={DEFAULT_LORA_WEIGHT_CONFIG.numberInputMax}
step={DEFAULT_LORA_WEIGHT_CONFIG.coarseStep}
fineStep={DEFAULT_LORA_WEIGHT_CONFIG.fineStep}
onChange={onChange}
isDisabled={isDisabled}
/>
</Flex>
</FormControl>
);
});

DefaultWeightMin.displayName = 'DefaultWeightMin';
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import { Button, Flex, Heading, SimpleGrid } from '@invoke-ai/ui-library';
import { useIsModelManagerEnabled } from 'features/modelManagerV2/hooks/useIsModelManagerEnabled';
import { useLoRAModelDefaultSettings } from 'features/modelManagerV2/hooks/useLoRAModelDefaultSettings';
import { DefaultWeight } from 'features/modelManagerV2/subpanels/ModelPanel/LoRAModelDefaultSettings/DefaultWeight';
import { DefaultWeightMax } from 'features/modelManagerV2/subpanels/ModelPanel/LoRAModelDefaultSettings/DefaultWeightMax';
import { DefaultWeightMin } from 'features/modelManagerV2/subpanels/ModelPanel/LoRAModelDefaultSettings/DefaultWeightMin';
import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useEffect } from 'react';
Expand All @@ -14,6 +16,8 @@ import type { LoRAModelConfig } from 'services/api/types';

export type LoRAModelDefaultSettingsFormData = {
weight: FormField<number>;
weightMin: FormField<number>;
weightMax: FormField<number>;
};

type Props = {
Expand All @@ -28,7 +32,7 @@ export const LoRAModelDefaultSettings = memo(({ modelConfig }: Props) => {

const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation();

const { handleSubmit, control, formState, reset } = useForm<LoRAModelDefaultSettingsFormData>({
const { handleSubmit, control, formState, reset, setError } = useForm<LoRAModelDefaultSettingsFormData>({
defaultValues: defaultSettingsDefaults,
});

Expand All @@ -38,8 +42,23 @@ export const LoRAModelDefaultSettings = memo(({ modelConfig }: Props) => {

const onSubmit = useCallback<SubmitHandler<LoRAModelDefaultSettingsFormData>>(
(data) => {
const weightMin = data.weightMin.isEnabled ? data.weightMin.value : null;
const weightMax = data.weightMax.isEnabled ? data.weightMax.value : null;

if (weightMin !== null && weightMax !== null && weightMin >= weightMax) {
setError('weightMin', { type: 'manual', message: t('lora.weightMinMustBeLessThanMax') });
toast({
id: 'DEFAULT_SETTINGS_SAVE_FAILED',
title: t('lora.weightMinMustBeLessThanMax'),
status: 'error',
});
return;
}

const body = {
weight: data.weight.isEnabled ? data.weight.value : null,
weight_min: weightMin,
weight_max: weightMax,
};

updateModel({
Expand All @@ -65,7 +84,7 @@ export const LoRAModelDefaultSettings = memo(({ modelConfig }: Props) => {
}
});
},
[updateModel, modelConfig.key, t, reset]
[updateModel, modelConfig.key, t, reset, setError]
);

return (
Expand All @@ -86,8 +105,10 @@ export const LoRAModelDefaultSettings = memo(({ modelConfig }: Props) => {
)}
</Flex>

<SimpleGrid columns={2} gap={8}>
<SimpleGrid columns={3} gap={8}>
<DefaultWeight control={control} name="weight" />
<DefaultWeightMin control={control} name="weightMin" />
<DefaultWeightMax control={control} name="weightMax" />
</SimpleGrid>
</>
);
Expand Down
10 changes: 10 additions & 0 deletions invokeai/frontend/web/src/services/api/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19346,6 +19346,16 @@ export type components = {
* @description Default weight for this model
*/
weight?: number | null;
/**
* Weight Min
* @description Minimum weight slider value for this model
*/
weight_min?: number | null;
/**
* Weight Max
* @description Maximum weight slider value for this model
*/
weight_max?: number | null;
};
/** MDControlListOutput */
MDControlListOutput: {
Expand Down
Loading