-
Notifications
You must be signed in to change notification settings - Fork 918
/
print_trace.py
50 lines (34 loc) · 1.37 KB
/
print_trace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""Demonstrates how to use the tracer module, independent of autodiff, by
creating a trace that prints out functions and their arguments as they're being
evaluated"""
import autograd.numpy as np # autograd has already wrapped numpy for us
from autograd.tracer import Node, trace
class PrintNode(Node):
    """Trace node that prints every primitive call as it is evaluated.

    Each node takes the next name from a generator that the root node creates
    and that every later node of the same trace shares. It then prints a line
    of the form ``NAME = fun(arg, ..., key=value) = result``.
    """

    def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
        """Name this node and print the primitive call that produced ``value``.

        Args:
            value: the result of calling ``fun``; printed after the call.
            fun: the primitive that was called; its ``__name__`` is printed.
            args: positional arguments passed to ``fun``.
            kwargs: keyword arguments passed to ``fun``; printed as
                ``key=value`` after the positional arguments.
            parent_argnums: indices into ``args`` of the traced arguments.
            parents: the nodes of those traced arguments, aligned with
                ``parent_argnums``. Assumed non-empty (only ``parents[0]`` is
                read to find the shared name generator) -- TODO confirm
                against autograd's tracer.
        """
        # Reuse the generator created by the root so names are unique within
        # this trace.
        self.varname_generator = parents[0].varname_generator
        self.varname = next(self.varname_generator)

        # Traced arguments are shown by the name of the node that produced
        # them; untraced arguments are shown by their value.
        args_or_vars = list(args)
        for argnum, parent in zip(parent_argnums, parents):
            args_or_vars[argnum] = parent.varname

        # Include keyword arguments too, otherwise calls such as
        # ``sum(x, axis=0)`` would be printed as ``sum(A)``.
        arg_strs = [str(arg) for arg in args_or_vars]
        arg_strs += [f"{key}={val}" for key, val in kwargs.items()]
        print(f"{self.varname} = {fun.__name__}({','.join(arg_strs)}) = {value}")

    def initialize_root(self, x):
        """Start a new trace rooted at ``x``: create the name generator and
        print the root's name and value."""
        self.varname_generator = make_varname_generator()
        self.varname = next(self.varname_generator)
        print(f"{self.varname} = {x}")
def make_varname_generator():
    """Yield an endless sequence of variable names for a trace.

    Names are spreadsheet-style: ``A`` .. ``Z``, then ``AA`` .. ``ZZ``, then
    ``AAA`` and so on. The first 26 names are single uppercase letters. Longer
    traces get multi-letter names instead of failing once the alphabet is used
    up.
    """
    import itertools
    import string

    for length in itertools.count(1):
        for letters in itertools.product(string.ascii_uppercase, repeat=length):
            yield "".join(letters)
def print_trace(f, x):
    """Evaluate ``f(x)`` under a fresh PrintNode trace.

    Every primitive evaluated while computing ``f(x)`` is printed by the trace
    nodes. A blank line follows, which separates consecutive traces.
    """
    root = PrintNode.new_root(x)
    trace(root, f, x)
    print()
def avg(x, y):
    """Return the arithmetic mean of ``x`` and ``y``, computed as ``(x + y) / 2``."""
    total = x + y
    return total / 2
def fun(x):
    """Toy function to trace: ``avg(sin(x + x), sin(x + x))``.

    The sine is computed once and passed to ``avg`` twice, so the trace shows
    one node being used as both arguments.
    """
    doubled = x + x
    sine = np.sin(doubled)
    return avg(sine, sine)
# Forward trace: print each primitive evaluated while computing fun(1.23).
print_trace(fun, 1.23)

# Traces can be nested, so the same printer can also follow the computation
# performed by grad(fun).
from autograd import grad

print_trace(grad(fun), 1.0)