-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathenn.erl
executable file
·166 lines (148 loc) · 5.97 KB
/
enn.erl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
%%%-------------------------------------------------------------------
%%% @author borja
%%% @doc
%%% Public API for neural networks: start and stop them, query their
%%% shape, neurons and status, and run predictions or supervised
%%% training through the `training' module.
%%% @end
%%%-------------------------------------------------------------------
-module(enn).
-author("borja").
-compile({no_auto_import,[link/1]}).
%% API
-export([start/1, start_link/1, stop/1, predict/2, fit/3]).
-export([run/4, inputs/1, outputs/1, neurons/1]).
-export([status/1, cortex/1]).
%% The attribute is `export_type' (singular). Exporting these types lets
%% other modules write enn:network(), enn:neuron() and enn:model() in
%% their specs.
-export_type([network/0, neuron/0, model/0]).
-type network() :: nnet:id().
-type neuron() :: nnet:neuron().
-type model() :: nnet:model().
%%====================================================================
%% API
%%====================================================================
%%--------------------------------------------------------------------
%% @doc Number of inputs the network expects.
%%
%% For a model map, the count is read from `inputs => #{units => N}'.
%% For a network id, the system inputs are the nnet outputs, so the
%% count is the length of `nnet:out(Network)'.
%% Should run inside an mnesia transaction.
%% @end
%%--------------------------------------------------------------------
-spec inputs(model() | network()) -> NumberOfInputs::integer().
inputs(#{} = Model) ->
    #{inputs := #{units := Units}} = Model,
    Units;
inputs(Network) ->
    % System inputs are the nnet outputs.
    NnetOutputs = nnet:out(Network),
    length(NnetOutputs).
%%--------------------------------------------------------------------
%% @doc Returns the number of outputs a network produces.
%%
%% For a model map, the count is read from `outputs => #{units => N}'.
%% For a network id, the system outputs are the nnet inputs, so the
%% count is the length of `nnet:in(Network)'.
%% Should run inside an mnesia transaction.
%% @end
%%--------------------------------------------------------------------
-spec outputs(model() | network()) -> NumberOfOutputs::integer().
outputs(Model) when is_map(Model) ->
    #{outputs := #{units := N_Outputs}} = Model,
    N_Outputs;
outputs(Network) -> % System outputs are the nnet inputs
    length(nnet:in(Network)).
%%-------------------------------------------------------------------
%% @doc Lists the neurons of the network, i.e. the keys of the node
%% map returned by `nnet:nodes/1'.
%% Should run inside an mnesia transaction.
%% @end
%%-------------------------------------------------------------------
-spec neurons(Network::network()) -> Neurons::[neuron()].
neurons(Network) ->
    NodeMap = nnet:nodes(Network),
    maps:keys(NodeMap).
%%--------------------------------------------------------------------
%% @doc Starts a neural network, ready to receive inputs or training.
%%
%% A model map is first compiled into a network with `nnet:compile/1'
%% inside an mnesia transaction. A failed transaction crashes with a
%% badmatch on `{aborted, Reason}'. The resulting network id is then
%% started through `enn_sup:start_nn/1'.
%%
%% Returns the network id. Raises `error:broken_nn' when the supervisor
%% reports a broken network; any other start error crashes with a
%% `case_clause'.
%% @end
%%--------------------------------------------------------------------
-spec start(Model | Network) -> Network when
      Model :: model(),
      Network :: network().
start(Model) when is_map(Model) ->
    CompileFun = fun() -> nnet:compile(Model) end,
    {atomic, Network} = mnesia:transaction(CompileFun),
    start(Network);
start(Network) ->
    case enn_sup:start_nn(Network) of
        {ok, _Pid} ->
            Network;
        {error, {{_, {_, _, broken_nn}}, _}} ->
            error(broken_nn)
    end.
%%--------------------------------------------------------------------
%% @doc Starts a neural network (exactly as start/1 does) and links the
%% calling process to the network's cortex process.
%% Returns the network id.
%% @end
%%--------------------------------------------------------------------
-spec start_link(Model | Network) -> Network when
      Model :: model(),
      Network :: network().
start_link(NetworkTerm) ->
    Network = start(NetworkTerm),
    CortexPid = cortex(Network),
    erlang:link(CortexPid),
    Network.
%%--------------------------------------------------------------------
%% @doc Stops a neural network.
%%
%% If the network is still running, the caller is first unlinked from
%% its cortex. The network is then terminated through
%% `enn_sup:terminate_nn/1', whose result is returned.
%% @end
%%--------------------------------------------------------------------
-spec stop(Network::network()) -> Result when
      Result :: 'ok' | {'error', Error},
      Error :: 'not_found'.
stop(Network) ->
    % cortex/1 calls enn_pool:info/1, which raises error:badarg for a
    % network that is not running (status/1 maps the same error to
    % not_running). Skip the unlink in that case instead of crashing, so
    % terminate_nn is still reached and can report the error, presumably
    % the {error, not_found} in the spec. The `of' branch is outside the
    % catch, so only the lookup is guarded.
    try cortex(Network) of
        CortexPid -> erlang:unlink(CortexPid)
    catch
        error:badarg -> ok
    end,
    enn_sup:terminate_nn(Network).
%%--------------------------------------------------------------------
%% @doc Returns the status of the specified network id: the info
%% reported by `enn_pool:info/1', or `not_running' when that lookup
%% raises `error:badarg'.
%% @end
%%--------------------------------------------------------------------
-spec status(Network::network()) -> Status when
      Status :: enn_pool:info() | not_running.
status(Network) ->
    try
        enn_pool:info(Network)
    catch
        error:badarg -> not_running
    end.
%%--------------------------------------------------------------------
%% @doc Returns the pid of the network's cortex, taken from the
%% `cortex' key of the `enn_pool:info/1' map. Crashes if the network
%% is not running or the info map has no `cortex' key.
%% @end
%%--------------------------------------------------------------------
-spec cortex(Network::network()) -> pid().
cortex(Network) ->
    Info = enn_pool:info(Network),
    #{cortex := CortexPid} = Info,
    CortexPid.
%%--------------------------------------------------------------------
%% @doc Uses an ANN to create predictions. The ANN is referred to by
%% its network id. The network is run over `InputsList' with no
%% optima, and only the `prediction' result is requested and
%% returned.
%% @end
%%--------------------------------------------------------------------
-spec predict(Network, Inputs) -> Predictions when
      Network :: network(),
      Inputs :: [[float()]],
      Predictions :: [[float()]].
predict(Network, InputsList) ->
    [Prediction] = run(Network, InputsList, [], [{return, [prediction]}]),
    Prediction.
%%--------------------------------------------------------------------
%% @doc Supervised ANN training. Runs the network over `InputsList'
%% against `OptimaList' and returns the `loss' result reported by the
%% training run.
%%
%% The options always include `{print, 10}'; what that controls is
%% defined by the `training' module (presumably a progress print
%% interval; confirm there).
%% @end
%%--------------------------------------------------------------------
-spec fit(Network, Inputs, Optima) -> Errors when
      Network :: network(),
      Inputs :: [[float()]],
      Optima :: [[float()]],
      Errors :: [[float()]].
fit(Network, InputsList, OptimaList) ->
    [Loss] = run(Network, InputsList, OptimaList,
                 [{return, [loss]}, {print, 10}]),
    Loss.
%%--------------------------------------------------------------------
%% @doc Runs an ANN with the criteria defined in `Options'.
%%
%% Looks up the network's cortex pid and delegates to
%% `training:start_link/4' with the inputs, optima and options
%% unchanged. Returns whatever that call returns.
%% @end
%%--------------------------------------------------------------------
-spec run(Network, Inputs, Optima, Options) -> Results when
      Network :: network(),
      Inputs :: [[float()]],
      Optima :: [[float()]],
      Options :: [training:option()],
      Results :: [term()].
run(Network, InputsList, OptimaList, Options) ->
    training:start_link(cortex(Network), InputsList, OptimaList, Options).
%%%===================================================================
%%% Internal functions
%%%===================================================================