From c7d17531d6faa898b2d7fb71a2f46bff867a728c Mon Sep 17 00:00:00 2001 From: Valentin Sulzer Date: Thu, 29 Oct 2020 13:32:42 -0400 Subject: [PATCH] #1129 use mul! and cache variables --- .../operations/evaluate_julia.py | 64 ++++++++++++++----- .../test_operations/quick_julia_test.py | 6 +- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/pybamm/expression_tree/operations/evaluate_julia.py b/pybamm/expression_tree/operations/evaluate_julia.py index d3b2607e91..4e8b19d73f 100644 --- a/pybamm/expression_tree/operations/evaluate_julia.py +++ b/pybamm/expression_tree/operations/evaluate_julia.py @@ -5,6 +5,7 @@ import numpy as np import scipy.sparse +import re from collections import OrderedDict import numbers @@ -19,13 +20,13 @@ def id_to_julia_variable(symbol_id, constant=False): if constant: var_format = "const_{:05d}" else: - var_format = "var_{:05d}" + var_format = "cache_{:05d}" # Need to replace "-" character to make them valid julia variable names return var_format.format(symbol_id).replace("-", "m") -def find_symbols(symbol, constant_symbols, variable_symbols): +def find_symbols(symbol, constant_symbols, variable_symbols, variable_symbol_sizes): """ This function converts an expression tree to a dictionary of node id's and strings specifying valid julia code to calculate that nodes value, given y and t. @@ -51,6 +52,10 @@ def find_symbols(symbol, constant_symbols, variable_symbols): variable_symbol : collections.OrderedDict The output dictionary of variable (with y or t) symbol ids to lines of code + variable_symbol_sizes : collections.OrderedDict + The output dictionary of variable (with y or t) symbol ids to size of that + variable, for caching + """ if symbol.is_constant(): value = symbol.evaluate() @@ -94,7 +99,7 @@ def find_symbols(symbol, constant_symbols, variable_symbols): # process children recursively for child in symbol.children: - find_symbols(child, constant_symbols, variable_symbols) + find_symbols(child, constant_symbols, variable_symbols, variable_symbol_sizes) # calculate the variable names that will hold the result of calculating the # children variables @@ -223,6 +228,12 @@ def find_symbols(symbol, constant_symbols, variable_symbols): variable_symbols[symbol.id] = symbol_str + # Save the size of the variable + symbol_shape = symbol.shape + if symbol_shape[1] != 1: + raise ValueError("expected column vector") + variable_symbol_sizes[symbol.id] = symbol_shape[0] + def to_julia(symbol, debug=False): """ @@ -246,9 +257,10 @@ def to_julia(symbol, debug=False): constant_values = OrderedDict() variable_symbols = OrderedDict() - find_symbols(symbol, constant_values, variable_symbols) + variable_symbol_sizes = OrderedDict() + find_symbols(symbol, constant_values, variable_symbols, variable_symbol_sizes) - line_format = "{} = {}" + line_format = "{} .= {}" if debug: variable_lines = [ @@ -267,7 +279,7 @@ def to_julia(symbol, debug=False): for symbol_id, symbol_line in variable_symbols.items() ] - return constant_values, "\n".join(variable_lines) + return constant_values, "\n".join(variable_lines), variable_symbol_sizes def get_julia_function(symbol, funcname="f"): @@ -287,26 +299,40 @@ def get_julia_function(symbol, funcname="f"): """ - constants, var_str = to_julia(symbol, debug=False) + constants, var_str, var_symbol_sizes = to_julia(symbol, debug=False) # extract constants in generated function - const_str = "const cs=(\n" + const_and_cache_str = "const cs=(\n" for symbol_id, const_value in constants.items(): const_name = id_to_julia_variable(symbol_id, True) - const_str += " {} = {},\n".format(const_name, const_value) - const_str += ")\n" + const_and_cache_str += " {} = {},\n".format(const_name, const_value) + # add "c." to constant and cache names + var_str = var_str.replace("const", "c.const") + var_str = var_str.replace("cache", "c.cache") + # replace matrix multiplications with mul! (requires LinearAlgebra library) + var_str = re.sub("(.+) .= (.+) \* (.+)", r"mul!(\1, \2, \3)", var_str) # indent code var_str = " " + var_str var_str = var_str.replace("\n", "\n ") - # add "c." to constant names - var_str = var_str.replace("const", "c.const") + + # add the cache variables to the cache NamedTuple + for var_symbol_id, var_symbol_size in var_symbol_sizes.items(): + # Skip caching the result variable since this is provided as dy + if var_symbol_id != symbol.id: + cache_name = id_to_julia_variable(var_symbol_id, False) + const_and_cache_str += " {} = zeros({}),\n".format( + cache_name, var_symbol_size + ) + + # close the constants and cache string + const_and_cache_str += ")\n" # add function def and sparse arrays to first line - imports = "begin\nusing SparseArrays, LinearAlgebra\n" + imports = "begin\nusing SparseArrays, LinearAlgebra\n\n" julia_str = ( imports - + const_str + + const_and_cache_str + f"\nfunction {funcname}_with_consts(dy, y, p, t, c)\n" + var_str ) @@ -316,12 +342,16 @@ def get_julia_function(symbol, funcname="f"): if symbol.is_constant(): result_value = symbol.evaluate() - # add return line + # assign the return variable if symbol.is_constant() and isinstance(result_value, numbers.Number): - julia_str = julia_str + "\n return " + str(result_value) + "\nend\n" + julia_str = julia_str + "\n return " + str(result_value) else: - julia_str = julia_str + "\n return " + result_var + "\nend\n" + julia_str = julia_str.replace("c." + result_var, "dy") + + # close the function + julia_str += "\nend\n\n" + # Return the function with the cached variables passed in julia_str += f"{funcname}(dy, y, p, t) = {funcname}_with_consts(dy, y, p, t, cs)\n" # close the "begin" diff --git a/tests/unit/test_expression_tree/test_operations/quick_julia_test.py b/tests/unit/test_expression_tree/test_operations/quick_julia_test.py index 2629f773af..a61a0d453a 100644 --- a/tests/unit/test_expression_tree/test_operations/quick_julia_test.py +++ b/tests/unit/test_expression_tree/test_operations/quick_julia_test.py @@ -33,7 +33,11 @@ expr = A @ pybamm.StateVector(slice(0, 2)) evaluator_str = pybamm.get_julia_function(expr) print(evaluator_str) - +Main.eval(evaluator_str) +Main.dy = [0, 0] +Main.y = [2, 3] +print(Main.eval("f(dy,y,0,0)")) +print(Main.dy) # # test something with a heaviside # a = pybamm.Vector([1, 2]) # expr = a <= pybamm.StateVector(slice(0, 2))