Skip to content

Commit

Permalink
feat(presets): work on future presets feature
Browse files Browse the repository at this point in the history
  • Loading branch information
Bamdad Sabbagh committed Nov 18, 2021
1 parent 6b9764d commit 96020f2
Show file tree
Hide file tree
Showing 10 changed files with 153 additions and 7 deletions.
20 changes: 20 additions & 0 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,9 @@ <h2>Credits</h2>

<!-- App Buttons -->
<div id="buttons">
<button class="mdl-button mdl-js-button mdl-button--fab mdl-button--mini-fab mdl-js-ripple-effect presets">
<i class="material-icons">bookmarks</i>
</button>
<button class="mdl-button mdl-js-button mdl-button--fab mdl-button--mini-fab mdl-js-ripple-effect imports-exports">
<i class="material-icons">import_export</i>
</button>
Expand Down Expand Up @@ -851,6 +854,23 @@ <h6>Novation Launch Control XL (Controller)</h6>
</div>
</dialog>

<!-- Presets Dialog -->
<dialog class="mdl-dialog" id="presets">
<h4 class="mdl-dialog__title">
Presets
</h4>
<div class="mdl-dialog__content">
<div>
<span>Please choose a preset to load</span>
</div>
</div>
<div class="mdl-dialog__actions">
<button class="mdl-button close-button" type="button">
close
</button>
</div>
</dialog>

<!-- Select Card -->
<div class="cool-card" id="select-card">
<div class="row header">
Expand Down
4 changes: 4 additions & 0 deletions src/app/ui/buttons.ui.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { devicesUi } from './devices.ui';
import { mappingsUi } from './mappings.ui';
import { helpUi } from './help.ui';
import { importsExportsUi } from './imports-exports.ui';
import { ui } from './ui';

export const buttonsUi = Object.create (null);

Expand All @@ -11,6 +12,7 @@ buttonsUi.nodeSelectors = {
devices: '.devices',
mappings: '.mappings',
help: '.help',
presets: '.presets',
};

buttonsUi.init = function () {
Expand All @@ -19,9 +21,11 @@ buttonsUi.init = function () {
this.devices = this.node.querySelector (this.nodeSelectors.devices);
this.mappings = this.node.querySelector (this.nodeSelectors.mappings);
this.help = this.node.querySelector (this.nodeSelectors.help);
this.presets = this.node.querySelector (this.nodeSelectors.presets);

this.importsExports.onclick = () => importsExportsUi.show ();
this.mappings.onclick = () => mappingsUi.show ();
this.devices.onclick = () => devicesUi.show ();
this.help.onclick = () => helpUi.show ();
this.presets.onclick = () => ui.presetsView.show ();
};
44 changes: 44 additions & 0 deletions src/app/ui/presets.view.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
export class PresetsView {
private selectors: { contentNode: string; node: string; closeButton: string; };

private node: HTMLDialogElement;

private content: Element;

private closeButton: HTMLButtonElement;

constructor () {
this.selectors = {
node: '#presets',
closeButton: '.close-button',
contentNode: '.mdl-dialog__content',
};

this.node = document.querySelector (this.selectors.node);
this.content = this.node.querySelector (this.selectors.contentNode);
this.closeButton = this.node.querySelector (this.selectors.closeButton);

this.attachEvents ();
}

attachEvents (): void {
// clicking close button
this.closeButton.onclick = () => this.hide ();

// clicking outside container
this.node.onclick = (e) => {
// MDL adds `open` HTML attribute to the dialog container (outside) only
if (e.target.open) {
this.hide ();
}
};
}

show (): void {
this.node.showModal ();
}

hide (): void {
this.node.close ();
}
}
2 changes: 2 additions & 0 deletions src/app/ui/ui.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { helpUi } from './help.ui';
import { layerCardUi } from './layer-card.ui';
import { importsExportsUi } from './imports-exports.ui';
import { playgroundUi } from './playground.ui';
import { PresetsView } from './presets.view';

export const ui = Object.create (null);

Expand All @@ -25,6 +26,7 @@ ui.init = async function () {
helpUi.init ();
importsExportsUi.init ();
playgroundUi.init ();
this.presetsView = new PresetsView ();
};

/**
Expand Down
5 changes: 5 additions & 0 deletions src/app/utils/get-network-shape.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
export function getNetworkShape (preset: number[]): number[] {
const n = [];
preset.forEach (() => n.push (8));
return n;
}
2 changes: 2 additions & 0 deletions src/coolearning/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,6 @@ interface Navigator {

interface HTMLDialogElement {
close (): void;

showModal (): void;
}
27 changes: 27 additions & 0 deletions src/playground/nn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ export class Link {
accErrorDer = 0;
/** Number of accumulated derivatives since the last update. */
numAccumulatedDers = 0;
savedWeight: number;

/**
* Constructs a link in the neural network initialized with random weight.
Expand Down Expand Up @@ -222,6 +223,7 @@ export class Link {
* @param activation The activation function of every hidden node.
* @param outputActivation The activation function for the output nodes.
* @param inputIds List of ids for the input nodes.
* @param preset The preset to use.
* @param initZero
*/
export function buildNetwork (
Expand All @@ -230,6 +232,7 @@ export function buildNetwork (
activation: ActivationFunction,
outputActivation: ActivationFunction,
inputIds: string[],
preset,
initZero?: boolean,
): Node[][] {
let numLayers = networkShape.length;
Expand All @@ -241,6 +244,7 @@ export function buildNetwork (
let isInputLayer = layerIdx === 0;
let currentLayer: Node[] = [];
network.push (currentLayer);
// let numNodes = networkShape[layerIdx];
let numNodes = networkShape[layerIdx];
for (let i = 0; i < numNodes; i++) {
let nodeId = id.toString ();
Expand Down Expand Up @@ -276,6 +280,29 @@ export function buildNetwork (
}
}
}

// initial neuron state
const layers = network.slice (1, -1);
layers.forEach ((layer, layerIndex) => {
layer.forEach ((neuron, nodeIndex) => {
if (nodeIndex + 1 > preset[layerIndex]) {
neuron.isEnabled = false;

neuron.inputLinks.forEach ((link: Link) => {
link.isDead = true;
link.savedWeight = link.weight;
link.weight = 0;
});

neuron.outputs.forEach ((link: Link) => {
link.isDead = true;
link.savedWeight = link.weight;
link.weight = 0;
});
}
});
});

return network;
}

Expand Down
24 changes: 21 additions & 3 deletions src/playground/playground.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import { HeatMap, reduceMatrix } from './heatmap';
import {
activations,
datasets,
getKeyFromValue,
getKeyFromValue, presets,
Problem,
problems,
regDatasets,
Expand All @@ -32,6 +32,7 @@ import { Coolearning } from '../coolearning/coolearning';
import { networkUi } from '../app/ui/network.ui';
import { playgroundFacade } from '../app/facades/playground.facade';
import { playgroundUi } from '../app/ui/playground.ui';
import { getNetworkShape } from '../app/utils/get-network-shape';

Coolearning ();

Expand Down Expand Up @@ -525,6 +526,10 @@ function drawNode (cx: number, cy: number, nodeId: string, isInput: boolean,
// })
}

const nodeDisabled = typeof _node?.isEnabled === 'undefined'
? false
: !_node.isEnabled;

// Draw the node's canvas.
let div = d3.select ('#network').insert ('div', ':first-child')
.attr ({
Expand All @@ -537,6 +542,7 @@ function drawNode (cx: number, cy: number, nodeId: string, isInput: boolean,
top: `${y + 3}px`,
})
.style ('cursor', 'pointer')
.classed ('disabled', nodeDisabled)
.on ('mousedown', () => {
mouseTimer = setTimeout (() => {
if (isInput) {
Expand Down Expand Up @@ -986,13 +992,25 @@ function reset (onStartup = false) {
// Make a simple network.
iter = 0;
let numInputs = constructInput (0, 0).length;
let shape = [numInputs].concat (state.networkShape).concat ([1]);
const preset = presets['2-2'];
let shape = [numInputs].concat (getNetworkShape (preset)).concat ([1]);
console.log (shape);

let outputActivation = state.problem === Problem.REGRESSION
? nn.Activations.LINEAR
: nn.Activations.TANH;

network = nn.buildNetwork (shape, state, state.activation, outputActivation, constructInputIds (), state.initZero);
network = nn.buildNetwork (
shape,
state,
state.activation,
outputActivation,
constructInputIds (),
preset,
state.initZero,
);

console.log (network);
lossTrain = getLoss (network, trainData);
lossTest = getLoss (network, testData);
drawNetwork (network);
Expand Down
23 changes: 19 additions & 4 deletions src/playground/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

import * as nn from './nn';
import * as dataset from './dataset';
import { getNetworkShape } from '../app/utils/get-network-shape';

/** Suffix added to the state when storing if a control is hidden or not. */
const HIDE_STATE_SUFFIX = '_hide';
Expand Down Expand Up @@ -48,6 +49,16 @@ export let regDatasets: { [key: string]: dataset.DataGenerator } = {
'reg-gauss': dataset.regressGaussian,
};

export let presets: { [key: string]: number[] } = {
'allOn': [8, 8, 8, 8, 8, 8],
'allOff': [0, 0, 0, 0, 0, 0],
'8-6-2-4-6': [8, 6, 4, 2, 4, 6],
'8-6-4-2': [8, 6, 4, 2],
'2-4-6-8': [2, 4, 6, 8],
'2-2': [2, 2],
'2': [2],
};

export function getKeyFromValue (obj: any, value: any): string {
for (let key in obj) {
if (obj[key] === value) {
Expand Down Expand Up @@ -112,7 +123,8 @@ export class State {
{name: 'learningRate', type: Type.NUMBER},
{name: 'regularizationRate', type: Type.NUMBER},
{name: 'noise', type: Type.NUMBER},
{name: 'networkShape', type: Type.ARRAY_NUMBER},
// {name: 'networkShape', type: Type.ARRAY_NUMBER},
{name: 'networkPreset', type: Type.ARRAY_NUMBER},
{name: 'seed', type: Type.STRING},
{name: 'showTestData', type: Type.BOOLEAN},
{name: 'discretize', type: Type.BOOLEAN},
Expand Down Expand Up @@ -149,7 +161,8 @@ export class State {
collectStats = false;
numHiddenLayers = 1;
hiddenLayerControls: any[] = [];
networkShape: number[] = [8, 8, 8, 8, 8, 8];
// networkShape: number[] = [8, 8, 8, 8, 8, 8];
networkPreset: number[];
x = true;
y = true;
xTimesY = true;
Expand Down Expand Up @@ -229,7 +242,8 @@ export class State {
getHideProps (map).forEach (prop => {
state[prop] = (map[prop] === 'true');
});
state.numHiddenLayers = state.networkShape.length;
// state.numHiddenLayers = state.networkShape.length;
state.numHiddenLayers = getNetworkShape (presets['2-2']).length;
if (state.seed == null) {
state.seed = Math.random ().toFixed (5);
}
Expand All @@ -251,7 +265,8 @@ export class State {
}
if (type === Type.OBJECT) {
value = getKeyFromValue (keyMap, value);
} else if (type === Type.ARRAY_NUMBER ||
}
else if (type === Type.ARRAY_NUMBER ||
type === Type.ARRAY_STRING) {
value = value.join (',');
}
Expand Down
9 changes: 9 additions & 0 deletions styles.css
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,8 @@ App Buttons
display: flex;
justify-content: flex-start;
grid-gap: 6px;

z-index: 1000;
}

/**
Expand Down Expand Up @@ -1094,6 +1096,13 @@ Imports/Exports Dialog
margin-top: 1em;
}

/**
Presets dialog
*/
#presets {
width: 50vw;
}

/**
Cool cards
*/
Expand Down

0 comments on commit 96020f2

Please sign in to comment.