Skip to content

Commit

Permalink
Merge pull request #29 from SciML/deps
Browse files Browse the repository at this point in the history
Fix symbolic deprecations
  • Loading branch information
ChrisRackauckas authored Jun 7, 2024
2 parents 0e41233 + 8040245 commit 7b0e3d4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PDEBase"
uuid = "a7812802-0625-4b9e-961c-d332478797e5"
authors = ["xtalax <[email protected]>"]
version = "0.1.11"
version = "0.1.12"

[deps]
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
Expand Down
8 changes: 4 additions & 4 deletions src/PDEBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ using SciMLBase: AbstractDiscretization, AbstractDiscretizationMetadata

using ModelingToolkit

using ModelingToolkit: operation, istree, arguments, variable, get_metadata, get_unknowns,
using ModelingToolkit: get_metadata, get_unknowns,
parameters, defaults, varmap_to_vars, get_eqs, get_iv

using Symbolics, SymbolicUtils
using Symbolics: unwrap, solve_for, expand_derivatives, diff2term, setname, rename,
similarterm
using SymbolicUtils: operation, arguments, Chain, Prewalk, Postwalk
using Symbolics: unwrap, solve_for, expand_derivatives, diff2term, setname, rename
using SymbolicUtils: operation, arguments, Chain, Prewalk, Postwalk, maketerm, metadata,
symtype, operation, iscall, arguments, variable
using DomainSets

abstract type AbstractEquationSystemDiscretization <: AbstractDiscretization end
Expand Down
34 changes: 17 additions & 17 deletions src/symbolic_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ the order of a PDE.
function count_differentials(term, x::Symbolics.Symbolic)
S = Symbolics
SU = SymbolicUtils
if !S.istree(term)
if !S.iscall(term)
return 0
else
op = SU.operation(term)
Expand All @@ -24,7 +24,7 @@ function differential_order(eq, x::Symbolics.Symbolic)
S = Symbolics
SU = SymbolicUtils
orders = Set{Int}()
if S.istree(eq)
if S.iscall(eq)
op = SU.operation(eq)
if op isa Differential
push!(orders, count_differentials(eq, x))
Expand All @@ -41,7 +41,7 @@ end
Determine whether a term has a derivative anywhere in it.
"""
function has_derivatives(term)
if istree(term)
if iscall(term)
op = operation(term)
if op isa Differential
return true
Expand All @@ -60,7 +60,7 @@ function find_derivative(term, depvar_op)
S = Symbolics
SU = SymbolicUtils
orders = Set{Int}()
if S.istree(eq)
if S.iscall(eq)
op = SU.operation(term)
if (op isa Differential) | isequal(op, depvar_op)
return term
Expand Down Expand Up @@ -92,7 +92,7 @@ find all the dependent variables given by depvar_ops in an expression
function get_depvars(eq, depvar_ops)
depvars = Set()
eq = safe_unwrap(eq)
if istree(eq)
if iscall(eq)
if any(u -> isequal(operation(eq), u), depvar_ops)
push!(depvars, eq)
else
Expand All @@ -107,7 +107,7 @@ end
function get_indvars(eq, v)
ivs = Set()
eq = safe_unwrap(eq)
if istree(eq)
if iscall(eq)
for o in map(x -> get_indvars(x, v), arguments(eq))
union!(ivs, o)
end
Expand Down Expand Up @@ -140,7 +140,7 @@ function _split_terms(term)
S = Symbolics
SU = SymbolicUtils
# TODO: Update this to be exclusive of derivatives and depvars rather than inclusive of +-/*
if S.istree(term) && ((operation(term) == +) | (operation(term) == -) | (operation(term) == *) | (operation(term) == /))
if S.iscall(term) && ((operation(term) == +) | (operation(term) == -) | (operation(term) == *) | (operation(term) == /))
return mapreduce(_split_terms, vcat, SU.arguments(term))
else
return [term]
Expand All @@ -155,13 +155,13 @@ function _split_terms(term, x̄)
st(t) = _split_terms(t, x̄)
# TODO: Update this to handle more ops e.g. exp sin tanh etc.
# TODO: Handle cases where two nonlinear laplacians are multiplied together
if S.istree(term)
if S.iscall(term)
# Additional handling for upwinding
if (operation(term) == *)
args = SU.arguments(term)
for (i, arg) in enumerate(args)
# Incase of upwinding, we need to keep the original term
if S.istree(arg) && operation(arg) isa Differential
if S.iscall(arg) && operation(arg) isa Differential
# Flatten the arguments of the differential to make nonlinear laplacian work in more cases
try
args[i] = operation(arg)(flatten_division.(SU.arguments(arg))...)
Expand All @@ -176,7 +176,7 @@ function _split_terms(term, x̄)
elseif (operation(term) == /)
args = SU.arguments(term)
# Incase of upwinding or spherical, we need to keep the original term
if S.istree(args[1])
if S.iscall(args[1])
if args[1] isa Differential
try
args[1] = operation(arg)(flatten_division.(SU.arguments(arg))...)
Expand All @@ -190,7 +190,7 @@ function _split_terms(term, x̄)
subargs = SU.arguments(args[1])
# look for a differential in the arguments
for (i, arg) in enumerate(subargs)
if S.istree(arg) && operation(arg) isa Differential
if S.iscall(arg) && operation(arg) isa Differential
# Flatten the arguments of the differential to make nonlinear laplacian/spherical work in more cases
try
subargs[i] = operation(arg)(flatten_division.(SU.arguments(arg))...)
Expand Down Expand Up @@ -226,8 +226,8 @@ end

function split_additive_terms(eq)
# Calling the methods from symbolicutils matches the expressions
rhs_arg = istree(eq.rhs) && (SymbolicUtils.operation(eq.rhs) == +) ? SymbolicUtils.arguments(eq.rhs) : [eq.rhs]
lhs_arg = istree(eq.lhs) && (SymbolicUtils.operation(eq.lhs) == +) ? SymbolicUtils.arguments(eq.lhs) : [eq.lhs]
rhs_arg = iscall(eq.rhs) && (SymbolicUtils.operation(eq.rhs) == +) ? SymbolicUtils.arguments(eq.rhs) : [eq.rhs]
lhs_arg = iscall(eq.lhs) && (SymbolicUtils.operation(eq.lhs) == +) ? SymbolicUtils.arguments(eq.lhs) : [eq.lhs]

return vcat(lhs_arg, rhs_arg)
end
Expand All @@ -252,7 +252,7 @@ function subsmatch(expr, rule)
if isequal(expr, rule.first)
return true
end
if istree(expr)
if iscall(expr)
return any(ex -> subsmatch(ex, rule), arguments(expr))
end
return false
Expand All @@ -268,18 +268,18 @@ not a `Num`.
```
"""
function ex2term(term, v)
istree(term) || return term
iscall(term) || return term
termdvs = collect(get_depvars(term, v.depvar_ops))
symdvs = filter(u -> all(x -> !(safe_unwrap(x) isa Number), arguments(u)), termdvs)
exdv = last(sort(symdvs, by=u -> length(arguments(u))))
name = Symbol("" * string(term) * "")
return setname(similarterm(exdv, rename(operation(exdv), name), arguments(exdv)), name)
return setname(similarterm(typeof(exdv), rename(operation(exdv), name), arguments(exdv), symtype(exdv), metadata(exdv)), name)
end

safe_unwrap(x) = x isa Num ? unwrap(x) : x

function recursive_unwrap(ex)
if !istree(ex)
if !iscall(ex)
return safe_unwrap(ex)
end

Expand Down

0 comments on commit 7b0e3d4

Please sign in to comment.