Skip to content

Commit

Permalink
#1129 use mul! and cache variables
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Oct 29, 2020
1 parent 60db480 commit c7d1753
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 18 deletions.
64 changes: 47 additions & 17 deletions pybamm/expression_tree/operations/evaluate_julia.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import scipy.sparse
import re
from collections import OrderedDict

import numbers
Expand All @@ -19,13 +20,13 @@ def id_to_julia_variable(symbol_id, constant=False):
if constant:
var_format = "const_{:05d}"
else:
var_format = "var_{:05d}"
var_format = "cache_{:05d}"

# Need to replace "-" character to make them valid julia variable names
return var_format.format(symbol_id).replace("-", "m")


def find_symbols(symbol, constant_symbols, variable_symbols):
def find_symbols(symbol, constant_symbols, variable_symbols, variable_symbol_sizes):
"""
This function converts an expression tree to a dictionary of node id's and strings
specifying valid julia code to calculate that nodes value, given y and t.
Expand All @@ -51,6 +52,10 @@ def find_symbols(symbol, constant_symbols, variable_symbols):
variable_symbol : collections.OrderedDict
The output dictionary of variable (with y or t) symbol ids to lines of code
variable_symbol_sizes : collections.OrderedDict
The output dictionary of variable (with y or t) symbol ids to size of that
variable, for caching
"""
if symbol.is_constant():
value = symbol.evaluate()
Expand Down Expand Up @@ -94,7 +99,7 @@ def find_symbols(symbol, constant_symbols, variable_symbols):

# process children recursively
for child in symbol.children:
find_symbols(child, constant_symbols, variable_symbols)
find_symbols(child, constant_symbols, variable_symbols, variable_symbol_sizes)

# calculate the variable names that will hold the result of calculating the
# children variables
Expand Down Expand Up @@ -223,6 +228,12 @@ def find_symbols(symbol, constant_symbols, variable_symbols):

variable_symbols[symbol.id] = symbol_str

# Save the size of the variable
symbol_shape = symbol.shape
if symbol_shape[1] != 1:
raise ValueError("expected column vector")
variable_symbol_sizes[symbol.id] = symbol_shape[0]


def to_julia(symbol, debug=False):
"""
Expand All @@ -246,9 +257,10 @@ def to_julia(symbol, debug=False):

constant_values = OrderedDict()
variable_symbols = OrderedDict()
find_symbols(symbol, constant_values, variable_symbols)
variable_symbol_sizes = OrderedDict()
find_symbols(symbol, constant_values, variable_symbols, variable_symbol_sizes)

line_format = "{} = {}"
line_format = "{} .= {}"

if debug:
variable_lines = [
Expand All @@ -267,7 +279,7 @@ def to_julia(symbol, debug=False):
for symbol_id, symbol_line in variable_symbols.items()
]

return constant_values, "\n".join(variable_lines)
return constant_values, "\n".join(variable_lines), variable_symbol_sizes


def get_julia_function(symbol, funcname="f"):
Expand All @@ -287,26 +299,40 @@ def get_julia_function(symbol, funcname="f"):
"""

constants, var_str = to_julia(symbol, debug=False)
constants, var_str, var_symbol_sizes = to_julia(symbol, debug=False)

# extract constants in generated function
const_str = "const cs=(\n"
const_and_cache_str = "const cs=(\n"
for symbol_id, const_value in constants.items():
const_name = id_to_julia_variable(symbol_id, True)
const_str += " {} = {},\n".format(const_name, const_value)
const_str += ")\n"
const_and_cache_str += " {} = {},\n".format(const_name, const_value)

# add "c." to constant and cache names
var_str = var_str.replace("const", "c.const")
var_str = var_str.replace("cache", "c.cache")
# replace matrix multiplications with mul! (requires LinearAlgebra library)
var_str = re.sub("(.+) .= (.+) \* (.+)", r"mul!(\1, \2, \3)", var_str)
# indent code
var_str = " " + var_str
var_str = var_str.replace("\n", "\n ")
# add "c." to constant names
var_str = var_str.replace("const", "c.const")

# add the cache variables to the cache NamedTuple
for var_symbol_id, var_symbol_size in var_symbol_sizes.items():
# Skip caching the result variable since this is provided as dy
if var_symbol_id != symbol.id:
cache_name = id_to_julia_variable(var_symbol_id, False)
const_and_cache_str += " {} = zeros({}),\n".format(
cache_name, var_symbol_size
)

# close the constants and cache string
const_and_cache_str += ")\n"

# add function def and sparse arrays to first line
imports = "begin\nusing SparseArrays, LinearAlgebra\n"
imports = "begin\nusing SparseArrays, LinearAlgebra\n\n"
julia_str = (
imports
+ const_str
+ const_and_cache_str
+ f"\nfunction {funcname}_with_consts(dy, y, p, t, c)\n"
+ var_str
)
Expand All @@ -316,12 +342,16 @@ def get_julia_function(symbol, funcname="f"):
if symbol.is_constant():
result_value = symbol.evaluate()

# add return line
# assign the return variable
if symbol.is_constant() and isinstance(result_value, numbers.Number):
julia_str = julia_str + "\n return " + str(result_value) + "\nend\n"
julia_str = julia_str + "\n return " + str(result_value)
else:
julia_str = julia_str + "\n return " + result_var + "\nend\n"
julia_str = julia_str.replace("c." + result_var, "dy")

# close the function
julia_str += "\nend\n\n"

# Return the function with the cached variables passed in
julia_str += f"{funcname}(dy, y, p, t) = {funcname}_with_consts(dy, y, p, t, cs)\n"

# close the "begin"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
expr = A @ pybamm.StateVector(slice(0, 2))
evaluator_str = pybamm.get_julia_function(expr)
print(evaluator_str)

Main.eval(evaluator_str)
Main.dy = [0, 0]
Main.y = [2, 3]
print(Main.eval("f(dy,y,0,0)"))
print(Main.dy)
# # test something with a heaviside
# a = pybamm.Vector([1, 2])
# expr = a <= pybamm.StateVector(slice(0, 2))
Expand Down

0 comments on commit c7d1753

Please sign in to comment.