Skip to content

Commit

Permalink
feat: Add/remove HPs when creating experiment through HP search (#9610)
Browse files Browse the repository at this point in the history
  • Loading branch information
gt2345 authored Jul 12, 2024
1 parent e9e4458 commit 8379b13
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
width: 100%;
}
.hyperparameterContainer {
align-items: center;
display: grid;
gap: 8px;
grid-auto-rows: max-content;
Expand All @@ -87,6 +88,9 @@
font-weight: normal;
margin: 0;
}
.delete {
cursor: pointer;
}
}
p {
margin: 0;
Expand Down
82 changes: 65 additions & 17 deletions webui/react/src/components/HyperparameterSearchModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import { Modal, ModalCloseReason } from 'hew/Modal';
import RadioGroup from 'hew/RadioGroup';
import Row from 'hew/Row';
import Select, { Option, RefSelectProps, SelectValue } from 'hew/Select';
import { Label, TypographySize } from 'hew/Typography';
import { Loadable } from 'hew/utils/loadable';
import yaml from 'js-yaml';
import React, { useCallback, useEffect, useId, useMemo, useRef, useState } from 'react';
Expand Down Expand Up @@ -80,6 +79,7 @@ interface HyperparameterRowValues {
min?: number;
type: HyperparameterType;
value?: number | string;
name: string;
}

const HyperparameterSearchModal = ({ closeModal, experiment, trial }: Props): JSX.Element => {
Expand Down Expand Up @@ -107,7 +107,7 @@ const HyperparameterSearchModal = ({ closeModal, experiment, trial }: Props): JS
}) as unknown as Record<string, Primitive>;
}, [trial]);

const hyperparameters = useMemo(() => {
const calculateInitialHyperparameters = useCallback(() => {
return Object.entries(experiment.hyperparameters).map((hp) => {
const hpObject = { hyperparameter: hp[1], name: hp[0] };
if (trialHyperparameters?.[hp[0]]) {
Expand All @@ -117,6 +117,10 @@ const HyperparameterSearchModal = ({ closeModal, experiment, trial }: Props): JS
});
}, [experiment.hyperparameters, trialHyperparameters]);

const [currentHPs, setCurrentHPs] = useState<{ hyperparameter: Hyperparameter; name: string }[]>(
calculateInitialHyperparameters,
);

const submitExperiment = useCallback(async () => {
const fields: Record<string, Primitive | HyperparameterRowValues> = form.getFieldsValue(true);

Expand Down Expand Up @@ -150,10 +154,13 @@ const HyperparameterSearchModal = ({ closeModal, experiment, trial }: Props): JS
}

// Parsing hyperparameters
baseConfig.hyperparameters = {};
Object.entries(fields)
.filter((field) => typeof field[1] === 'object')
.forEach((hp) => {
// hpName is the name at the time of the form rendering, while the name field in hpInfo is the updated name.
const hpName = hp[0];
if (!currentHPs?.some((h) => h.name === hpName)) return;
const hpInfo = hp[1] as HyperparameterRowValues;
if (hpInfo.type === HyperparameterType.Categorical) return;
else if (hpInfo.type === HyperparameterType.Constant) {
Expand All @@ -169,13 +176,13 @@ const HyperparameterSearchModal = ({ closeModal, experiment, trial }: Props): JS
} catch (e) {
parsedVal = hpInfo.value;
}
baseConfig.hyperparameters[hpName] = {
baseConfig.hyperparameters[hpInfo.name] = {
type: hpInfo.type,
val: parsedVal,
};
} else {
const prevBase: number | undefined = baseConfig.hyperparameters[hpName]?.base;
baseConfig.hyperparameters[hpName] = {
const prevBase: number | undefined = baseConfig.hyperparameters[hpInfo.name]?.base;
baseConfig.hyperparameters[hpInfo.name] = {
base: hpInfo.type === HyperparameterType.Log ? prevBase ?? DEFAULT_LOG_BASE : undefined,
count: fields.searcher === SEARCH_METHODS.Grid.id ? hpInfo.count : undefined,
maxval:
Expand All @@ -193,7 +200,6 @@ const HyperparameterSearchModal = ({ closeModal, experiment, trial }: Props): JS

// Unflatten hyperparameters to deal with nesting
baseConfig.hyperparameters = unflattenObject(baseConfig.hyperparameters);

const newConfig = yaml.dump(baseConfig);

try {
Expand Down Expand Up @@ -234,7 +240,7 @@ const HyperparameterSearchModal = ({ closeModal, experiment, trial }: Props): JS
// We throw an error to prevent the modal from closing.
throw new DetError(errorMessage, { publicMessage: errorMessage, silent: true });
}
}, [experiment.configRaw, experiment.id, experiment.projectId, form]);
}, [experiment.configRaw, experiment.id, experiment.projectId, form, currentHPs]);

const handleOk = useCallback(() => {
if (currentPage === 0) {
Expand Down Expand Up @@ -351,13 +357,23 @@ const HyperparameterSearchModal = ({ closeModal, experiment, trial }: Props): JS
[form],
);

const getNextHPName = useCallback((names: string[]) => {
let counter = names.length;
while (names.includes(`hp_${counter}`)) counter++;
return `hp_${counter}`;
}, []);

const hyperparameterPage = useMemo((): React.ReactNode => {
const emptyHP: Hyperparameter = { type: 'const' };
// We always render the form regardless of mode to provide a reference to it.
return (
<div className={css.base}>
{modalError && <Alert message={modalError} type="error" />}
<div className={css.labelWithLink}>
<p>Select hyperparameters and define the search space.</p>
<p>
Select hyperparameters and define the search space. <br />
The experiment code needs to be able to handle hyperparameters for them to take effect.
</p>
<Link
external
path={paths.docs('/training/hyperparameter/configure-hp-ranges.html')}
Expand All @@ -369,7 +385,7 @@ const HyperparameterSearchModal = ({ closeModal, experiment, trial }: Props): JS
className={css.hyperparameterContainer}
style={{
gridTemplateColumns: `180px minmax(100px, 1.4fr)
repeat(${searcher === SEARCH_METHODS.Grid ? 4 : 3}, minmax(60px, 1fr))`,
repeat(${searcher === SEARCH_METHODS.Grid ? 4 : 3}, minmax(60px, 1fr)) 20px`,
}}>
<label id="hyperparameter">
<h2>Hyperparameter</h2>
Expand All @@ -391,13 +407,32 @@ const HyperparameterSearchModal = ({ closeModal, experiment, trial }: Props): JS
<h2>Grid Count</h2>
</label>
)}
{hyperparameters.map((hp) => (
<HyperparameterRow key={hp.name} searcher={searcher} {...hp} />
<label id="delete" />
{currentHPs?.map((hp, idx) => (
<HyperparameterRow
handleDelete={(name: string) =>
setCurrentHPs((prev) => prev?.filter((hp) => hp.name !== name))
}
key={idx}
searcher={searcher}
{...hp}
/>
))}
<label id="add">
<Button
onClick={() =>
setCurrentHPs((prev) => [
...(prev ?? []),
{ hyperparameter: emptyHP, name: getNextHPName(prev.map((p) => p.name)) },
])
}>
Add Hyperparameter
</Button>
</label>
</div>
</div>
);
}, [hyperparameters, modalError, searcher]);
}, [currentHPs, modalError, searcher, getNextHPName]);

const searcherPage = useMemo((): React.ReactNode => {
// We always render the form regardless of mode to provide a reference to it.
Expand Down Expand Up @@ -597,12 +632,18 @@ interface RowProps {
hyperparameter: Hyperparameter;
name: string;
searcher: SearchMethod;
handleDelete: (name: string) => void;
}

const HyperparameterRow: React.FC<RowProps> = ({ hyperparameter, name, searcher }: RowProps) => {
const HyperparameterRow: React.FC<RowProps> = ({
hyperparameter,
name,
searcher,
handleDelete,
}: RowProps) => {
const type: HyperparameterType | undefined = Form.useWatch([name, 'type']);
const typeRef = useRef<RefSelectProps>(null);
const [active, setActive] = useState(hyperparameter.type !== HyperparameterType.Constant);
const [active, setActive] = useState<boolean>(false);
const min: number | undefined = Form.useWatch([name, 'min']);
const max: number | undefined = Form.useWatch([name, 'max']);
const [valError, setValError] = useState<string>();
Expand All @@ -611,6 +652,10 @@ const HyperparameterRow: React.FC<RowProps> = ({ hyperparameter, name, searcher
const [rangeError, setRangeError] = useState<string>();
const [countError, setCountError] = useState<string>();

useEffect(() => {
setActive(type !== HyperparameterType.Constant);
}, [type]);

const handleTypeChange = useCallback((value: SelectValue) => {
setActive(value !== HyperparameterType.Constant);
}, []);
Expand Down Expand Up @@ -668,9 +713,9 @@ const HyperparameterRow: React.FC<RowProps> = ({ hyperparameter, name, searcher
return (
<>
<div className={css.hyperparameterName}>
<Label size={TypographySize.L} truncate={{ tooltip: true }}>
{name}
</Label>
<Form.Item initialValue={name} name={[name, 'name']} rules={[{ required: true }]}>
<Input aria-labelledby="name" onChange={validateValue} />
</Form.Item>
</div>
<Form.Item initialValue={hyperparameter.type} name={[name, 'type']} noStyle>
<Select aria-labelledby="type" ref={typeRef} width={'100%'} onChange={handleTypeChange}>
Expand Down Expand Up @@ -777,6 +822,9 @@ const HyperparameterRow: React.FC<RowProps> = ({ hyperparameter, name, searcher
</Form.Item>
</>
)}
<div className={css.delete} onClick={() => handleDelete(name)}>
<Icon name="close" title="delete" />
</div>
{type === HyperparameterType.Categorical && (
<p className={css.warning}>Categorical hyperparameters are not currently supported.</p>
)}
Expand Down

0 comments on commit 8379b13

Please sign in to comment.