Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

import { useState } from 'react';
import { useEffect } from 'react';

import { Button, Flex } from '@geti/ui';

Expand All @@ -13,13 +13,29 @@ import { getRecommendedArchitectures } from './utils';
const SHOW_MORE_THRESHOLD = 4;

export const ModelArchitecturesList = () => {
const [showMore, setShowMore] = useState<boolean>(false);
const { modelArchitectures, selectedModelArchitectureId, onSelectModelArchitectureId } = useTrainModelState();
const {
modelArchitectures,
selectedModelArchitectureId,
onSelectModelArchitectureId,
showMoreModelArchitectures: showMore,
setShowMoreModelArchitectures: setShowMore,
} = useTrainModelState();

const recommendedArchitectures = getRecommendedArchitectures(modelArchitectures);
const collapsedArchitectures = recommendedArchitectures.slice(0, SHOW_MORE_THRESHOLD);
const canToggleArchitecturesList = modelArchitectures.length > SHOW_MORE_THRESHOLD;

useEffect(() => {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think we need this useEffect anymore. This was what i mentioned previously. We dont need to open/close any show more. Now that the state is on the provider, the list should not collapse

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think we need this useEffect anymore. This was what i mentioned previously. We dont need to open/close any show more. Now that the state is on the provider, the list should not collapse

@jpggvilaca I have reverted it

if (selectedModelArchitectureId !== null && !showMore) {
const isSelectedInCollapsed = collapsedArchitectures.some(
(arch) => arch.id === selectedModelArchitectureId
);
if (!isSelectedInCollapsed) {
setShowMore(true);
}
}
Comment on lines +22 to +28
}, [selectedModelArchitectureId, collapsedArchitectures, setShowMore, showMore]);

return (
<Flex direction={'column'} minHeight={0} gap={'size-300'}>
{showMore ? (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ import { render } from 'test-utils/render';
import { ModelArchitecturesList } from './model-architectures-list.component';

const mockOnSelectModelArchitectureId = vi.hoisted(() => vi.fn());
const mockSelectedModelArchitectureId = vi.hoisted(() => ({ value: null as string | null }));
const mockShowMore = vi.hoisted(() => ({ value: false as boolean }));
const mockSetShowMore = vi.hoisted(() =>
vi.fn((val: boolean) => {
mockShowMore.value = val;
})
);

const mockModelArchitectures = [
getMockedModelArchitecture({ id: 'arch-1', name: 'Alpha Model' }),
Expand All @@ -20,14 +27,19 @@ const mockModelArchitectures = [
vi.mock('../train-model-provider.component', () => ({
useTrainModelState: () => ({
modelArchitectures: mockModelArchitectures,
selectedModelArchitectureId: null,
selectedModelArchitectureId: mockSelectedModelArchitectureId.value,
onSelectModelArchitectureId: mockOnSelectModelArchitectureId,
showMoreModelArchitectures: mockShowMore.value,
setShowMoreModelArchitectures: mockSetShowMore,
}),
}));

describe('ModelArchitecturesList', () => {
beforeEach(() => {
mockOnSelectModelArchitectureId.mockReset();
mockSelectedModelArchitectureId.value = null;
mockShowMore.value = false;
mockSetShowMore.mockClear();
});

it('does not render the sort control in the default (recommended) view', () => {
Expand All @@ -36,21 +48,27 @@ describe('ModelArchitecturesList', () => {
expect(screen.queryByText('Sort Models by:')).not.toBeInTheDocument();
});

it('renders the sort control after clicking "Show more"', async () => {
it('calls setShowMore with true when clicking "Show more"', async () => {
render(<ModelArchitecturesList />);

fireEvent.click(screen.getByRole('button', { name: 'Show more' }));

expect(screen.getByRole('button', { name: /Sort Models by:/i })).toBeVisible();
expect(mockSetShowMore).toHaveBeenCalledWith(true);
});

it('hides the sort control again after clicking "Show less"', async () => {
it('calls setShowMore with false when clicking "Show less"', async () => {
mockShowMore.value = true;
render(<ModelArchitecturesList />);

fireEvent.click(screen.getByRole('button', { name: 'Show more' }));
expect(screen.getByRole('button', { name: /Sort Models by:/i })).toBeVisible();

fireEvent.click(screen.getByRole('button', { name: 'Show less' }));
expect(screen.queryByRole('button', { name: /Sort Models by:/i })).not.toBeInTheDocument();

expect(mockSetShowMore).toHaveBeenCalledWith(false);
});

it('calls setShowMore to true on mount if a non-default architecture is already selected', () => {
mockSelectedModelArchitectureId.value = mockModelArchitectures[mockModelArchitectures.length - 1].id;
render(<ModelArchitecturesList />);

expect(mockSetShowMore).toHaveBeenCalledWith(true);
});
Comment on lines +59 to +65
});
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ export type TrainModelContextProps = {
isAdvancedSettingsMode: boolean;
onToggleAdvancedSettingsMode: (isAdvancedSettingsMode: boolean) => void;

showMoreModelArchitectures: boolean;
setShowMoreModelArchitectures: (showMore: boolean) => void;

trainingConfiguration: TrainingConfiguration | undefined;
defaultTrainingConfiguration: TrainingConfiguration | undefined;
onTrainingConfigurationChange: Dispatch<SetStateAction<TrainingConfiguration | undefined>>;
Expand Down Expand Up @@ -127,6 +130,7 @@ export const TrainModelProvider = ({ children }: TrainModelProviderProps) => {
);

const [isAdvancedSettingsMode, setIsAdvancedSettingsMode] = useState<boolean>(false);
const [showMoreModelArchitectures, setShowMoreModelArchitectures] = useState<boolean>(false);

const modelRevisions = useMemo(() => {
return getModelRevisionsForArchitecture(allModelRevisions, selectedModelArchitectureId);
Expand Down Expand Up @@ -167,6 +171,9 @@ export const TrainModelProvider = ({ children }: TrainModelProviderProps) => {
isAdvancedSettingsMode,
onToggleAdvancedSettingsMode: setIsAdvancedSettingsMode,

showMoreModelArchitectures,
setShowMoreModelArchitectures,

trainingConfiguration,
defaultTrainingConfiguration,
onTrainingConfigurationChange: setTrainingConfiguration,
Expand Down
48 changes: 47 additions & 1 deletion application/ui/tests/models/model-training.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,23 @@ import { expect, http, test } from '../fixtures';
import { MOCKED_MODEL_TRAINING_CONFIGURATION, MOCKED_TRAINING_CONFIGURATION } from './mocks';

const mockedModelArchitectures = [
getMockedModelArchitecture({ id: 'Object_Detection_SSD', name: 'Object_Detection_SSD' }),
getMockedModelArchitecture({
id: 'Object_Detection_SSD',
name: 'Object_Detection_SSD',
performanceCategory: 'balance',
}),
getMockedModelArchitecture({
id: 'Custom_Object_Detection_Gen3_ATSS',
name: 'Custom_Object_Detection_Gen3_ATSS',
performanceCategory: 'speed',
}),
getMockedModelArchitecture({ id: 'arch-3', name: 'Gamma Model', performanceCategory: 'accuracy' }),
getMockedModelArchitecture({ id: 'arch-4', name: 'Delta Model', performanceCategory: undefined }),
getMockedModelArchitecture({ id: 'arch-5', name: 'Epsilon Model', performanceCategory: undefined }),
getMockedModelArchitecture({
id: 'Object_Detection_YOLOX_X',
name: 'Object_Detection_YOLOX_X',
performanceCategory: undefined,
}),
];

Expand Down Expand Up @@ -265,6 +278,39 @@ test.describe('Model training flow', () => {
);
});

test('Keeps model architectures list expanded when navigating back from advanced settings if non-default model is selected', async ({
network,
modelsPage,
page,
}) => {
setupNetworkMocks(network);

await modelsPage.goto();
await modelsPage.openTrainModelDialog();

await test.step('Expand architectures list', async () => {
await page.getByRole('button', { name: 'Show more' }).click();
});

await test.step('Select a non-default model architecture', async () => {
await modelsPage.selectModelArchitecture('Object_Detection_YOLOX_X');
});

await test.step('Open advanced settings', async () => {
await modelsPage.openAdvancedSettings();
await expect(page.getByLabel('Advanced settings tabs').getByText('Data management')).toBeVisible();
});

await test.step('Navigate back', async () => {
await page.getByRole('button', { name: 'Back' }).click();
});

await test.step('Verify that the list is still expanded', async () => {
await expect(page.getByRole('button', { name: 'Show less' })).toBeVisible();
await expect(page.getByRole('button', { name: /Sort Models by:/i })).toBeVisible();
});
});

test('Advanced model training flow with training from scratch', async ({ network, modelsPage }) => {
const state = setupNetworkMocks(network);

Expand Down