forked from avhz/RustQuant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
automatic_differentiation.rs
111 lines (85 loc) · 3.39 KB
/
automatic_differentiation.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
use RustQuant::autodiff::*;
// The general workflow for using the `autodiff` module is as follows:
//
// 1. Create a new graph.
// 2. Assign variables onto the graph.
// 3. Define an expression using the variables.
// 4. Accumulate (differentiate) the expression.
// 5. Profit.
fn main() {
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // SIMPLE EXPRESSIONS
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    let g = Graph::new();

    // Plain f64 constants can be mixed with graph variables in an expression.
    let a = 1.;
    let b = 2.;

    // Inputs are kept small on purpose: `exp` overflows an f64 to `inf` once its
    // argument exceeds ~709.78, so e.g. x = 69, y = 420 (x * y = 28980) would
    // print `inf` for the value and for both partial derivatives.
    let x = g.var(0.69);
    let y = g.var(4.20);

    // Define a function: z = a + b + exp(x * y).
    let z = a + b + (x * y).exp();

    // Accumulate (differentiate) to obtain the gradient.
    let gradient = z.accumulate();

    println!("z = {}", z.value);
    println!("dz/dx = {}", gradient.wrt(&x));
    println!("dz/dy = {}", gradient.wrt(&y));
    println!("grad = {:?}", gradient.wrt(&[x, y]));

    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // BLOCK EXPRESSIONS
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    let g = Graph::new();

    let x = g.var(69.);
    let y = g.var(420.);

    // Intermediate results may be built inside a block; the block's final
    // expression is the function being differentiated: f = exp(sin(x) + tan(y)).
    let block = {
        let z = x.sin() + y.tan();
        z.exp()
    };

    let grad = block.accumulate();

    println!("f = {}", block.value);
    println!("df/dx = {}", grad.wrt(&x));
    println!("df/dy = {}", grad.wrt(&y));
    println!("grad = {:?}", grad.wrt(&[x, y]));

    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // CLOSURES
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    let g = Graph::new();

    let x = g.var(1.);
    let y = g.var(2.);

    let closure = || (x * y).cosh() / (x.tanh() * y.sinh());

    // Call the closure once and reuse the result. Each call builds the
    // expression anew, so calling it again just to read `.value` would
    // record a duplicate copy of it.
    let z = closure();
    let grad = z.accumulate();

    println!("z = {}", z.value);
    println!("dz/dx = {}", grad.wrt(&x));
    println!("dz/dy = {}", grad.wrt(&y));
    println!("grad = {:?}", grad.wrt(&[x, y]));

    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // PROPER FUNCTIONS.
    //
    // Note that you can also add many variables via vectors, slices, arrays, etc.
    // This is where the `autodiff` module really shines: it lets you
    // differentiate functions of any number of variables, and computing
    // gradients of large functions with AD rather than finite-difference
    // quotients is significantly faster and free of truncation error
    // (results are exact up to floating-point rounding).
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    // Function to differentiate:
    //     f = x^(y + cos(1)) - atanh(z) / 2 + 1
    // at x = 3, y = 2, z = 0.5.
    //
    // z must lie strictly inside (-1, 1). atanh(1) is infinite and so is its
    // derivative 1 / (1 - z^2), so evaluating at z = 1 would give an
    // infinite value and gradient.
    //
    // `variables` is [x, y, z]; `constants` is [1.0, 2.0], where
    // constants[0] is both the argument of cos and the trailing `+ 1`,
    // and constants[1] is the divisor of atanh(z).
    #[rustfmt::skip]
    fn function<'v>(variables: &[Variable<'v>], constants: &[f64]) -> Variable<'v> {
        variables[0].powf(variables[1] + constants[0].cos()) -
        variables[2].atanh() / constants[1] +
        constants[0]
    }

    // New graph.
    let graph = Graph::new();

    // Variables and constants.
    let variables = graph.vars(&[3.0, 2.0, 0.5]);
    let constants = [1., 2.];

    // Evaluate and differentiate the function.
    let result = function(&variables, &constants);
    let gradient = result.accumulate();

    // Print the graph length.
    println!("Graph length: {}", graph.len());
    println!("{:?}", gradient.wrt(&variables));

    // Print the graphviz output.
    // You can copy and paste this into your Graphviz viewer of choice.
    println!("{}", graphviz(&graph, &variables));
}