diff --git a/application/ui/src/features/models/train-model/model-architectures-list/model-architectures-list.component.tsx b/application/ui/src/features/models/train-model/model-architectures-list/model-architectures-list.component.tsx index f058ce7f756..065b83c3289 100644 --- a/application/ui/src/features/models/train-model/model-architectures-list/model-architectures-list.component.tsx +++ b/application/ui/src/features/models/train-model/model-architectures-list/model-architectures-list.component.tsx @@ -13,13 +13,22 @@ import { getRecommendedArchitectures } from './utils'; const SHOW_MORE_THRESHOLD = 4; export const ModelArchitecturesList = () => { - const [showMore, setShowMore] = useState(false); const { modelArchitectures, selectedModelArchitectureId, onSelectModelArchitectureId } = useTrainModelState(); const recommendedArchitectures = getRecommendedArchitectures(modelArchitectures); const collapsedArchitectures = recommendedArchitectures.slice(0, SHOW_MORE_THRESHOLD); const canToggleArchitecturesList = modelArchitectures.length > SHOW_MORE_THRESHOLD; + const [showMore, setShowMore] = useState(() => { + if (selectedModelArchitectureId !== null) { + const isSelectedInCollapsed = collapsedArchitectures.some( + (arch) => arch.id === selectedModelArchitectureId + ); + return !isSelectedInCollapsed; + } + return false; + }); + return ( {showMore ? ( diff --git a/application/ui/src/features/models/train-model/model-architectures-list/model-architectures-list.test.tsx b/application/ui/src/features/models/train-model/model-architectures-list/model-architectures-list.test.tsx index eea41e9acd7..86510a70526 100644 --- a/application/ui/src/features/models/train-model/model-architectures-list/model-architectures-list.test.tsx +++ b/application/ui/src/features/models/train-model/model-architectures-list/model-architectures-list.test.tsx @@ -8,6 +8,7 @@ import { render } from 'test-utils/render'; import { ModelArchitecturesList } from './model-architectures-list.component'; const mockOnSelectModelArchitectureId = vi.hoisted(() => vi.fn()); +const mockSelectedModelArchitectureId = vi.hoisted(() => ({ value: null as string | null })); const mockModelArchitectures = [ getMockedModelArchitecture({ id: 'arch-1', name: 'Alpha Model' }), @@ -20,7 +21,7 @@ const mockModelArchitectures = [ vi.mock('../train-model-provider.component', () => ({ useTrainModelState: () => ({ modelArchitectures: mockModelArchitectures, - selectedModelArchitectureId: null, + selectedModelArchitectureId: mockSelectedModelArchitectureId.value, onSelectModelArchitectureId: mockOnSelectModelArchitectureId, }), })); @@ -28,6 +29,7 @@ vi.mock('../train-model-provider.component', () => ({ describe('ModelArchitecturesList', () => { beforeEach(() => { mockOnSelectModelArchitectureId.mockReset(); + mockSelectedModelArchitectureId.value = null; }); it('does not render the sort control in the default (recommended) view', () => { @@ -53,4 +55,12 @@ describe('ModelArchitecturesList', () => { fireEvent.click(screen.getByRole('button', { name: 'Show less' })); expect(screen.queryByRole('button', { name: /Sort Models by:/i })).not.toBeInTheDocument(); }); + + it('renders expanded list by default if a non-default architecture is already selected', () => { + mockSelectedModelArchitectureId.value = mockModelArchitectures[mockModelArchitectures.length - 1].id; + render(); + + expect(screen.getByRole('button', { name: 'Show less' })).toBeVisible(); + expect(screen.getByRole('button', { name: /Sort Models by:/i })).toBeVisible(); + }); }); diff --git a/application/ui/src/features/models/train-model/train-model-provider.component.tsx b/application/ui/src/features/models/train-model/train-model-provider.component.tsx index d393091fce0..942b5180d5d 100644 --- a/application/ui/src/features/models/train-model/train-model-provider.component.tsx +++ b/application/ui/src/features/models/train-model/train-model-provider.component.tsx @@ -42,6 +42,9 @@ export type TrainModelContextProps = { isAdvancedSettingsMode: boolean; onToggleAdvancedSettingsMode: (isAdvancedSettingsMode: boolean) => void; + showMoreModelArchitectures: boolean; + setShowMoreModelArchitectures: (showMore: boolean) => void; + trainingConfiguration: TrainingConfiguration | undefined; defaultTrainingConfiguration: TrainingConfiguration | undefined; onTrainingConfigurationChange: Dispatch>; @@ -127,6 +130,7 @@ export const TrainModelProvider = ({ children }: TrainModelProviderProps) => { ); const [isAdvancedSettingsMode, setIsAdvancedSettingsMode] = useState(false); + const [showMoreModelArchitectures, setShowMoreModelArchitectures] = useState(false); const modelRevisions = useMemo(() => { return getModelRevisionsForArchitecture(allModelRevisions, selectedModelArchitectureId); @@ -167,6 +171,9 @@ export const TrainModelProvider = ({ children }: TrainModelProviderProps) => { isAdvancedSettingsMode, onToggleAdvancedSettingsMode: setIsAdvancedSettingsMode, + showMoreModelArchitectures, + setShowMoreModelArchitectures, + trainingConfiguration, defaultTrainingConfiguration, onTrainingConfigurationChange: setTrainingConfiguration, diff --git a/application/ui/tests/models/model-training.spec.ts b/application/ui/tests/models/model-training.spec.ts index 666a18413e8..1e27a514ad3 100644 --- a/application/ui/tests/models/model-training.spec.ts +++ b/application/ui/tests/models/model-training.spec.ts @@ -23,10 +23,23 @@ import { expect, http, test } from '../fixtures'; import { MOCKED_MODEL_TRAINING_CONFIGURATION, MOCKED_TRAINING_CONFIGURATION } from './mocks'; const mockedModelArchitectures = [ - getMockedModelArchitecture({ id: 'Object_Detection_SSD', name: 'Object_Detection_SSD' }), + getMockedModelArchitecture({ + id: 'Object_Detection_SSD', + name: 'Object_Detection_SSD', + performanceCategory: 'balance', + }), getMockedModelArchitecture({ id: 'Custom_Object_Detection_Gen3_ATSS', name: 'Custom_Object_Detection_Gen3_ATSS', + performanceCategory: 'speed', + }), + getMockedModelArchitecture({ id: 'arch-3', name: 'Gamma Model', performanceCategory: 'accuracy' }), + getMockedModelArchitecture({ id: 'arch-4', name: 'Delta Model', performanceCategory: undefined }), + getMockedModelArchitecture({ id: 'arch-5', name: 'Epsilon Model', performanceCategory: undefined }), + getMockedModelArchitecture({ + id: 'Object_Detection_YOLOX_X', + name: 'Object_Detection_YOLOX_X', + performanceCategory: undefined, }), ]; @@ -265,6 +278,39 @@ test.describe('Model training flow', () => { ); }); + test('Keeps model architectures list expanded when navigating back from advanced settings if non-default model is selected', async ({ + network, + modelsPage, + page, + }) => { + setupNetworkMocks(network); + + await modelsPage.goto(); + await modelsPage.openTrainModelDialog(); + + await test.step('Expand architectures list', async () => { + await page.getByRole('button', { name: 'Show more' }).click(); + }); + + await test.step('Select a non-default model architecture', async () => { + await modelsPage.selectModelArchitecture('Object_Detection_YOLOX_X'); + }); + + await test.step('Open advanced settings', async () => { + await modelsPage.openAdvancedSettings(); + await expect(page.getByLabel('Advanced settings tabs').getByText('Data management')).toBeVisible(); + }); + + await test.step('Navigate back', async () => { + await page.getByRole('button', { name: 'Back' }).click(); + }); + + await test.step('Verify that the list is still expanded', async () => { + await expect(page.getByRole('button', { name: 'Show less' })).toBeVisible(); + await expect(page.getByRole('button', { name: /Sort Models by:/i })).toBeVisible(); + }); + }); + test('Advanced model training flow with training from scratch', async ({ network, modelsPage }) => { const state = setupNetworkMocks(network);