Skip to content

Commit

Permalink
Add Save Checkpoint button, update checkpoint selector
Browse files Browse the repository at this point in the history
  • Loading branch information
vgeorge committed Feb 21, 2024
1 parent 5424b82 commit e7c1216
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import React, { useState } from 'react';
import { Dropdown, DropdownBody } from '../../../../styles/dropdown';
import InfoButton from '../../../common/info-button';
import { Form, FormInput } from '@devseed-ui/form';
import { glsp } from '@devseed-ui/theme-provider';

import { ProjectMachineContext } from '../../../../fsm/project';
import { Subheading } from '../../../../styles/type/heading';
import { LocalButton } from '../../../../styles/local-button';
import styled from 'styled-components';

const SaveCheckpoint = styled(DropdownBody)`
padding: ${glsp()};
`;

const SaveCheckpointButton = () => {
const actorRef = ProjectMachineContext.useActorRef();
const currentCheckpoint = ProjectMachineContext.useSelector(
(s) => s.context.currentCheckpoint
);

const [localCheckpointName, setLocalCheckpointName] = useState('');
return (
<Dropdown
alignment='center'
direction='up'
triggerElement={(triggerProps) => (
<InfoButton
data-cy='save-checkpoint-button'
variation='primary-plain'
size='medium'
useIcon='save-disk'
useLocalButton
style={{
gridColumn: '1 / -1',
}}
id='rename-button-trigger'
{...triggerProps}
>
Save Checkpoint
</InfoButton>
)}
>
<SaveCheckpoint>
<Subheading>Checkpoint name:</Subheading>
<Form
onSubmit={(evt) => {
evt.preventDefault();
const name = evt.target.elements.checkpointName.value;
actorRef.send('Save checkpoint', {
data: {
checkpoint: { ...currentCheckpoint, name, bookmarked: true },
},
});
}}
>
<FormInput
name='checkpointName'
placeholder='Set Checkpoint Name'
value={localCheckpointName}
onKeyDown={(e) => {
e.stopPropagation();
}}
onChange={(e) => setLocalCheckpointName(e.target.value)}
autoFocus
/>
<LocalButton
type='submit'
variation='primary-plain'
useIcon='save-disk'
title='Rename checkpoint'
data-dropdown='click.close'
>
Save
</LocalButton>
</Form>
</SaveCheckpoint>
</Dropdown>
);
};

export default SaveCheckpointButton;
5 changes: 5 additions & 0 deletions app/assets/scripts/components/project/prime-panel/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import RetrainTab from './tabs/retrain';
import { UploadAoiModal } from './upload-aoi-modal';
import { PrimeButton } from './footer/prime-button';
import { BatchPredictionPanel } from './footer/batch-prediction-panel';
import SaveCheckpointButton from './footer/save-checkpoint-button';
import { ProjectMachineContext } from '../../../fsm/project';
import { SESSION_MODES } from '../../../fsm/project/constants';
import selectors from '../../../fsm/project/selectors';
Expand Down Expand Up @@ -58,6 +59,9 @@ export function PrimePanel() {
const retrainModeEnabled = ProjectMachineContext.useSelector(({ context }) =>
guards.retrainModeEnabled(context)
);
const canSaveCheckpoint = ProjectMachineContext.useSelector(
selectors.canSaveCheckpoint
);

return (
<>
Expand Down Expand Up @@ -104,6 +108,7 @@ export function PrimePanel() {
</PanelBlockBody>
<PanelControls>
<PrimeButton />
{canSaveCheckpoint && <SaveCheckpointButton />}
<BatchPredictionPanel />
</PanelControls>
</StyledPanelBlock>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ function CheckpointSelector() {
);

function getCheckpointLabel(c) {
return `${c?.name || currentModel.name} (${c?.id || '(Base Model)'})`;
return c.parent === null
? `${currentModel.name} (Base Model)`
: `${c.name}`;
}

let selectedOptionLabel;
Expand All @@ -97,6 +99,13 @@ function CheckpointSelector() {
selectedOptionLabel = getCheckpointLabel(currentCheckpoint);
}

const selectableCheckpoints = checkpointList?.filter(
(c) =>
(c.parent === null || c.bookmarked) &&
currentCheckpoint &&
currentCheckpoint.id !== c.id
);

return (
<>
<HeadOption hasSubtitle>
Expand All @@ -117,22 +126,19 @@ function CheckpointSelector() {
<CheckpointOption selected data-cy='selected-checkpoint-header'>
<Heading size='xsmall'>{selectedOptionLabel}</Heading>
</CheckpointOption>
{!!checkpointList?.length &&
checkpointList
.filter((c) => c.id != currentCheckpoint?.id)
.map((c) => (
<CheckpointOption
key={c.id}
onClick={async () => {
actorRef.send({
type: 'Apply checkpoint',
data: { checkpoint: { ...c } },
});
}}
>
<Heading size='xsmall'>{getCheckpointLabel(c)}</Heading>
</CheckpointOption>
))}
{selectableCheckpoints?.map((c) => (
<CheckpointOption
key={c.id}
onClick={async () => {
actorRef.send({
type: 'Apply checkpoint',
data: { checkpoint: { ...c } },
});
}}
>
<Heading size='xsmall'>{getCheckpointLabel(c)}</Heading>
</CheckpointOption>
))}
</ShadowScrollbar>
</HeadOption>
</>
Expand Down
Loading

0 comments on commit e7c1216

Please sign in to comment.