-
Notifications
You must be signed in to change notification settings - Fork 3
/
overload_connectivity.jl
160 lines (139 loc) · 4.74 KB
/
overload_connectivity.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
## 1-to-1
"""
    connectivity_tracer_1_to_1(t, is_influence_zero)

Compute the output tracer of a one-argument operation whose input tracer is `t`.

When the caller reports zero influence (`is_influence_zero == true`), a fresh `empty(T)`
tracer is returned (presumably one with no recorded inputs). Otherwise the input
tracer `t` is passed through unchanged.
"""
function connectivity_tracer_1_to_1(
    t::T, is_influence_zero::Bool
) where {T<:ConnectivityTracer}
    is_influence_zero && return empty(T)
    return t
end
"""
    overload_connectivity_1_to_1(m::Module, fn::Function)

Add two methods to the function `m.fn`, where both names are taken from `nameof`, via
`@eval`:

- one for a bare `ConnectivityTracer` argument;
- one for a `Dual` whose tracer is a `ConnectivityTracer`. It evaluates `fn` on the
  primal value and pairs the result with the propagated tracer.
"""
function overload_connectivity_1_to_1(m::Module, fn::Function)
    modname = nameof(m)
    fname = nameof(fn)
    # Tracer-only input: no primal value is available, so the global predicate is used.
    @eval function $modname.$fname(t::T) where {T<:ConnectivityTracer}
        influence_zero = is_influence_zero_global($modname.$fname)
        return connectivity_tracer_1_to_1(t, influence_zero)
    end
    # Dual input: the primal value is known, so the local predicate receives it.
    @eval function $modname.$fname(d::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
        x = primal(d)
        p_out = $modname.$fname(x)
        t_in = tracer(d)
        influence_zero = is_influence_zero_local($modname.$fname, x)
        t_out = connectivity_tracer_1_to_1(t_in, influence_zero)
        return Dual(p_out, t_out)
    end
end
## 2-to-1
"""
    connectivity_tracer_2_to_1(tx, ty, is_influence_arg1_zero, is_influence_arg2_zero)

Compute the output tracer of a two-argument operation with input tracers `tx` and `ty`.

- Both influences zero: return `empty(T)`.
- Only the first influence zero: return `ty`.
- Only the second influence zero: return `tx`.
- Neither zero: build a new `T` from the union of `inputs(tx)` and `inputs(ty)`.
"""
function connectivity_tracer_2_to_1(
    tx::T, ty::T, is_influence_arg1_zero::Bool, is_influence_arg2_zero::Bool
) where {T<:ConnectivityTracer}
    is_influence_arg1_zero && is_influence_arg2_zero && return empty(T)
    is_influence_arg1_zero && return ty
    is_influence_arg2_zero && return tx
    # Both arguments influence the output: merge their input sets.
    return T(inputs(tx) ∪ inputs(ty))
end
"""
    overload_connectivity_2_to_1(m::Module, fn::Function)

Add two-argument methods to the function `m.fn`, where both names are taken from
`nameof`, via `@eval`:

- `(tracer, tracer)` and `(dual, dual)` with matching types, combined through
  `connectivity_tracer_2_to_1`;
- a tracer or dual first and a plain `Number` second, which propagates only the first
  tracer;
- a plain `Number` first and a tracer or dual second, which propagates only the second
  tracer.

Tracer-only methods query the `*_global` influence predicates, since no primal value is
available. `Dual` methods query the `*_local` predicates with the primal values.
"""
function overload_connectivity_2_to_1(m::Module, fn::Function)
    ms, fns = nameof(m), nameof(fn)
    @eval function $ms.$fns(tx::T, ty::T) where {T<:ConnectivityTracer}
        return connectivity_tracer_2_to_1(
            tx,
            ty,
            is_influence_arg1_zero_global($ms.$fns),
            is_influence_arg2_zero_global($ms.$fns),
        )
    end
    @eval function $ms.$fns(dx::D, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
        x = primal(dx)
        y = primal(dy)
        p_out = $ms.$fns(x, y)
        t_out = connectivity_tracer_2_to_1(
            tracer(dx),
            tracer(dy),
            is_influence_arg1_zero_local($ms.$fns, x, y),
            is_influence_arg2_zero_local($ms.$fns, x, y),
        )
        return Dual(p_out, t_out)
    end
    # The predicates below must receive the module-qualified function, as in every
    # other method here. A bare function symbol would be resolved in the module that
    # runs `@eval`, which need not be the function being overloaded in `m`.
    @eval function $ms.$fns(tx::ConnectivityTracer, ::Number)
        return connectivity_tracer_1_to_1(tx, is_influence_arg1_zero_global($ms.$fns))
    end
    @eval function $ms.$fns(dx::D, y::Number) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
        x = primal(dx)
        p_out = $ms.$fns(x, y)
        t_out = connectivity_tracer_1_to_1(
            tracer(dx), is_influence_arg1_zero_local($ms.$fns, x, y)
        )
        return Dual(p_out, t_out)
    end
    @eval function $ms.$fns(::Number, ty::ConnectivityTracer)
        return connectivity_tracer_1_to_1(ty, is_influence_arg2_zero_global($ms.$fns))
    end
    @eval function $ms.$fns(x::Number, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
        y = primal(dy)
        p_out = $ms.$fns(x, y)
        t_out = connectivity_tracer_1_to_1(
            tracer(dy), is_influence_arg2_zero_local($ms.$fns, x, y)
        )
        return Dual(p_out, t_out)
    end
end
## 1-to-2
"""
    connectivity_tracer_1_to_2(t, is_influence_out1_zero, is_influence_out2_zero)

Compute the two output tracers of a one-input, two-output operation with input tracer
`t`. Each output is handled independently by `connectivity_tracer_1_to_1` with its own
zero-influence flag, and the pair is returned as a tuple.
"""
function connectivity_tracer_1_to_2(
    t::T, is_influence_out1_zero::Bool, is_influence_out2_zero::Bool
) where {T<:ConnectivityTracer}
    return (
        connectivity_tracer_1_to_1(t, is_influence_out1_zero),
        connectivity_tracer_1_to_1(t, is_influence_out2_zero),
    )
end
"""
    overload_connectivity_1_to_2(m::Module, fn::Function)

Add methods for a one-input, two-output function `m.fn` via `@eval`, where both names
are taken from `nameof`. Methods are added for a bare `ConnectivityTracer` and for a
`Dual` whose tracer is a `ConnectivityTracer`. Each method returns a 2-tuple with one
result per output.

The tracer-only method queries the `*_global` influence predicates. The `Dual` method
queries the `*_local` predicates with the primal value.
"""
function overload_connectivity_1_to_2(m::Module, fn::Function)
    ms, fns = nameof(m), nameof(fn)
    @eval function $ms.$fns(t::ConnectivityTracer)
        return connectivity_tracer_1_to_2(
            t,
            is_influence_out1_zero_global($ms.$fns),
            is_influence_out2_zero_global($ms.$fns),
        )
    end
    @eval function $ms.$fns(d::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
        x = primal(d)
        p1_out, p2_out = $ms.$fns(x)
        # Propagate the tracer stored in the dual. This method has no local `t`, and
        # referencing one would fail with an UndefVarError at call time.
        t1_out, t2_out = connectivity_tracer_1_to_2(
            tracer(d),
            is_influence_out1_zero_local($ms.$fns, x),
            is_influence_out2_zero_local($ms.$fns, x),
        )
        return (Dual(p1_out, t1_out), Dual(p2_out, t2_out))
    end
end
## Actual overloads
# Register the tracer/dual methods on `Base` for every operator in each arity list.
foreach(op -> overload_connectivity_1_to_1(Base, op), ops_1_to_1)
foreach(op -> overload_connectivity_2_to_1(Base, op), ops_2_to_1)
foreach(op -> overload_connectivity_1_to_2(Base, op), ops_1_to_2)
## Special cases
## Exponent (requires extra types)
# One set of `^` methods is defined per exponent/base type listed here.
# NOTE(review): the specific types (`Integer`, `Rational`, `Irrational{:ℯ}`) are
# presumably listed to avoid method ambiguities with Base's specialised `^` methods —
# confirm.
# In every method the result carries the tracer of the traced operand unchanged. The
# untraced operand only affects the primal value.
for R in (Real, Integer, Rational, Irrational{:ℯ})
    # Traced base, untraced exponent.
    function Base.:^(t::ConnectivityTracer, ::R)
        return t
    end
    function Base.:^(dx::D, y::R) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
        p_out = primal(dx)^y
        return Dual(p_out, tracer(dx))
    end
    # Untraced base, traced exponent.
    function Base.:^(::R, t::ConnectivityTracer)
        return t
    end
    function Base.:^(x::R, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}}
        p_out = x^primal(dy)
        return Dual(p_out, tracer(dy))
    end
end
## Rounding
# Rounding a tracer returns it unchanged. The rounding mode and any keyword arguments
# are accepted but ignored.
function Base.round(t::ConnectivityTracer, ::RoundingMode; kwargs...)
    return t
end
## Random numbers
# Drawing a random tracer yields `empty(T)`, presumably a tracer with no recorded
# inputs. The RNG is not consulted.
# TODO: was missing Base, add tests
function Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:ConnectivityTracer}
    return empty(T)
end