Skip to content

Commit

Permalink
#1219 fix basic tests
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Nov 3, 2020
1 parent 1ec94f5 commit 08a9415
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 139 deletions.
28 changes: 26 additions & 2 deletions pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,20 +378,44 @@ def log10(child):
return log(child, base=10)


class Max(SpecificFunction):
""" Max function """

def __init__(self, child):
super().__init__(np.max, child)

@property
def julia_name(self):
""" See :meth:`pybamm.Function.julia_name` """
return "maximum"


def max(child):
"""
Returns max function of child. Not to be confused with :meth:`pybamm.maximum`, which
returns the larger of two objects.
"""
return pybamm.simplify_if_constant(Function(np.max, child), keep_domains=True)
return pybamm.simplify_if_constant(Max(child), keep_domains=True)


class Min(SpecificFunction):
""" Min function """

def __init__(self, child):
super().__init__(np.min, child)

@property
def julia_name(self):
""" See :meth:`pybamm.Function.julia_name` """
return "minimum"


def min(child):
"""
Returns min function of child. Not to be confused with :meth:`pybamm.minimum`, which
returns the smaller of two objects.
"""
return pybamm.simplify_if_constant(Function(np.min, child), keep_domains=True)
return pybamm.simplify_if_constant(Min(child), keep_domains=True)


def sech(child):
Expand Down
58 changes: 30 additions & 28 deletions pybamm/expression_tree/operations/evaluate_julia.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def find_symbols(symbol, constant_symbols, variable_symbols, variable_symbol_siz
elif isinstance(symbol, pybamm.Inner):
symbol_str = "{0} .* {1}".format(children_vars[0], children_vars[1])
elif isinstance(symbol, pybamm.Minimum):
symbol_str = "np.minimum({},{})".format(children_vars[0], children_vars[1])
symbol_str = "min.({},{})".format(children_vars[0], children_vars[1])
elif isinstance(symbol, pybamm.Maximum):
symbol_str = "np.maximum({},{})".format(children_vars[0], children_vars[1])
symbol_str = "max.({},{})".format(children_vars[0], children_vars[1])
elif isinstance(symbol, pybamm.Power):
# julia uses ^ instead of ** for power
# include dot for elementwise operations
Expand Down Expand Up @@ -156,7 +156,10 @@ def find_symbols(symbol, constant_symbols, variable_symbols, variable_symbol_siz
# write functions directly
julia_name = symbol.julia_name
# add a . to allow elementwise operations
symbol_str = "{}.({})".format(julia_name, children_str)
if isinstance(symbol, (pybamm.Min, pybamm.Max)):
symbol_str = "{}({})".format(julia_name, children_str)
else:
symbol_str = "{}.({})".format(julia_name, children_str)

elif isinstance(symbol, pybamm.Concatenation):

Expand Down Expand Up @@ -238,11 +241,10 @@ def find_symbols(symbol, constant_symbols, variable_symbols, variable_symbol_siz
variable_symbols[symbol.id] = symbol_str

# Save the size of the variable
symbol_shape = symbol.shape
if symbol_shape == ():
if symbol.shape == ():
variable_symbol_sizes[symbol.id] = 1
elif symbol_shape[1] == 1:
variable_symbol_sizes[symbol.id] = symbol_shape[0]
elif symbol.shape[1] == 1:
variable_symbol_sizes[symbol.id] = symbol.shape[0]
else:
raise ValueError("expected scalar or column vector")

Expand Down Expand Up @@ -314,7 +316,7 @@ def get_julia_function(symbol, funcname="f"):
constants, var_symbols, var_symbol_sizes = to_julia(symbol, debug=False)

# extract constants in generated function
const_and_cache_str = "const cs=(\n"
const_and_cache_str = "cs = (\n"
for symbol_id, const_value in constants.items():
const_name = id_to_julia_variable(symbol_id, True)
const_and_cache_str += " {} = {},\n".format(const_name, const_value)
Expand Down Expand Up @@ -357,7 +359,7 @@ def get_julia_function(symbol, funcname="f"):
# first in that case, unless it is a @view in which case we don't
# need to cache
if julia_var in next_symbol_line and not (
" * " in next_symbol_line
[" * " in next_symbol_line or "mul!" in next_symbol_line]
and not symbol_line.startswith("@view")
):
if symbol_line != "t":
Expand All @@ -377,9 +379,9 @@ def get_julia_function(symbol, funcname="f"):
# otherwise assign
else:
var_str += "{} .= {}\n".format(julia_var, symbol_line)
# add "c." to constant and cache names
var_str = var_str.replace("const", "c.const")
var_str = var_str.replace("cache", "c.cache")
# add "cs." to constant and cache names
var_str = var_str.replace("const", "cs.const")
var_str = var_str.replace("cache", "cs.cache")
# indent code
var_str = " " + var_str
var_str = var_str.replace("\n", "\n ")
Expand All @@ -400,19 +402,16 @@ def get_julia_function(symbol, funcname="f"):
const_and_cache_str += ")\n"

# remove the constant and cache sring if it is empty
const_and_cache_str = const_and_cache_str.replace("const cs=(\n)\n", "")
const_and_cache_str = const_and_cache_str.replace("cs = (\n)\n", "")

# add function def and sparse arrays to first line
imports = "begin\nusing SparseArrays, LinearAlgebra\n\n"
if const_and_cache_str == "":
julia_str = imports + f"\nfunction {funcname}(dy, y, p, t)\n" + var_str
else:
julia_str = (
imports
+ const_and_cache_str
+ f"\nfunction {funcname}_with_consts(dy, y, p, t, c)\n"
+ var_str
)
julia_str = (
imports
+ const_and_cache_str
+ f"\nfunction {funcname}_with_consts(dy, y, p, t)\n"
+ var_str
)

# calculate the final variable that will output the result
result_var = id_to_julia_variable(symbol.id, symbol.is_constant())
Expand All @@ -423,18 +422,21 @@ def get_julia_function(symbol, funcname="f"):
if symbol.is_constant() and isinstance(result_value, numbers.Number):
julia_str = julia_str + "\n dy .= " + str(result_value) + "\n"
else:
julia_str = julia_str.replace("c." + result_var, "dy")
julia_str = julia_str.replace("cs." + result_var, "dy")

# close the function
julia_str += "end\n\n"
julia_str = julia_str.replace("\n end", "\nend")
julia_str = julia_str.replace("\n \n", "\n")

# Return the function with the cached variables passed in
if const_and_cache_str != "":
julia_str += (
f"{funcname}(dy, y, p, t) = {funcname}_with_consts(dy, y, p, t, cs)\n"
)
if const_and_cache_str == "":
julia_str += f"{funcname} = {funcname}_with_consts\n"
else:
# Use a let block for the cached variables
# open the let block
julia_str = julia_str.replace("cs = (", f"{funcname} = let cs = (")
# close the let block
julia_str += "end\n"

# close the "begin"
julia_str += "end"
Expand Down
23 changes: 22 additions & 1 deletion pybamm/expression_tree/operations/spm_julia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,25 @@ solve(prob, CVODE_BDF(), reltol=1e-8, abstol=1e-8, saveat=0.15 / 100);
# Juno.profiler()
# Profile.clear()
rand(10)
@btime rand(10);
@btime rand(10);

hh = let es = (x1 = [1,2,3,4,5], x2 = 1)
function g_inner(dy, y, p, t, c)
dy .= c.x1 .* y
end
g(dy, y, p, t) = g_inner(dy, y, p, t, es)
end
u0 = [1,1,1,1,1]
tspan = (0, 1)
prob = ODEProblem(hh, u0, tspan)
@btime solve(prob, KenCarp47(autodiff=false), reltol=1e-6, abstol=1e-6, saveat=0.1);

function aaa()
return 2
end
bbb = aaa
bbb = let xxx = 2
xxx^2
end
xxx
bbb()
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# B = pybamm.Matrix([[11, 12, 13], [13, 14, 15], [16, 17, 18]])
# C = pybamm.Vector([[21], [22], [23]])
# expr = A @ (B @ (C * (C + pybamm.StateVector(slice(0, 3)))) + C)
expr = pybamm.NumpyConcatenation(a, b)
expr = pybamm.Vector([1, 2, 3, 4, 5, 6]) * pybamm.NumpyConcatenation(a, b)
evaluator_str = pybamm.get_julia_function(expr)
print(evaluator_str)
Main.eval(evaluator_str)
Expand Down
Loading

0 comments on commit 08a9415

Please sign in to comment.