Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make derivative 5-8 times faster #3322

Merged
merged 9 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -257,5 +257,6 @@ gauravchawhan <[email protected]>
Akki <[email protected]>
Neeraj Kumawat <[email protected]>
Emmanuel Ferdman <[email protected]>
Paul K <[email protected]>

# Generated by tools/update-authors.js
148 changes: 72 additions & 76 deletions src/function/algebra/derivative.js
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,19 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
* @return {ConstantNode | SymbolNode | ParenthesisNode | FunctionNode | OperatorNode} The derivative of `expr`
*/
function plainDerivative (expr, variable, options = { simplify: true }) {
const constNodes = {}
constTag(constNodes, expr, variable.name)
const res = _derivative(expr, constNodes)
const cache = new Map()
const variableName = variable.name
function isConstCached (node) {
const cached = cache.get(node)
if (cached !== undefined) {
return cached
}
const res = _isConst(isConstCached, node, variableName)
cache.set(node, res)
return res
}

const res = _derivative(expr, isConstCached)
return options.simplify ? simplify(res) : res
}

Expand All @@ -96,9 +106,8 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
'Node, SymbolNode, ConstantNode': function (expr, variable, {order}) {
let res = expr
for (let i = 0; i < order; i++) {
let constNodes = {}
constTag(constNodes, expr, variable.name)
res = _derivative(res, constNodes)
<create caching isConst>
res = _derivative(res, isConst)
}
return res
}
Expand Down Expand Up @@ -143,60 +152,51 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
})

/**
* Does a depth-first search on the expression tree to identify what Nodes
* are constants (e.g. 2 + 2), and stores the ones that are constants in
* constNodes. Classification is done as follows:
* Checks if a node is constants (e.g. 2 + 2).
* Accepts (usually memoized) version of self as the first parameter for recursive calls.
* Classification is done as follows:
*
* 1. ConstantNodes are constants.
* 2. If there exists a SymbolNode, of which we are differentiating over,
* in the subtree it is not constant.
*
* @param {Object} constNodes Holds the nodes that are constant
* @param {function} isConst Function that tells whether sub-expression is a constant
* @param {ConstantNode | SymbolNode | ParenthesisNode | FunctionNode | OperatorNode} node
* @param {string} varName Variable that we are differentiating
* @return {boolean} if node is constant
*/
// TODO: can we rewrite constTag into a pure function?
const constTag = typed('constTag', {
'Object, ConstantNode, string': function (constNodes, node) {
constNodes[node] = true
const _isConst = typed('_isConst', {
'function, ConstantNode, string': function () {
return true
},

'Object, SymbolNode, string': function (constNodes, node, varName) {
'function, SymbolNode, string': function (isConst, node, varName) {
// Treat other variables like constants. For reasoning, see:
// https://en.wikipedia.org/wiki/Partial_derivative
if (node.name !== varName) {
constNodes[node] = true
return true
}
return false
return node.name !== varName
},

'Object, ParenthesisNode, string': function (constNodes, node, varName) {
return constTag(constNodes, node.content, varName)
'function, ParenthesisNode, string': function (isConst, node, varName) {
return isConst(node.content, varName)
},

'Object, FunctionAssignmentNode, string': function (constNodes, node, varName) {
'function, FunctionAssignmentNode, string': function (isConst, node, varName) {
if (!node.params.includes(varName)) {
constNodes[node] = true
return true
}
return constTag(constNodes, node.expr, varName)
return isConst(node.expr, varName)
},

'Object, FunctionNode | OperatorNode, string': function (constNodes, node, varName) {
'function, FunctionNode | OperatorNode, string': function (isConst, node, varName) {
if (node.args.length > 0) {
josdejong marked this conversation as resolved.
Show resolved Hide resolved
let isConst = constTag(constNodes, node.args[0], varName)
for (let i = 1; i < node.args.length; ++i) {
isConst = constTag(constNodes, node.args[i], varName) && isConst
let allConst = true
for (let i = 0; i < node.args.length && allConst; ++i) {
allConst = isConst(node.args[i], varName) && allConst
}

if (isConst) {
constNodes[node] = true
return true
}
return allConst
}
// TODO: add a comment explaining why false (and why is this reachable)
paulftw marked this conversation as resolved.
Show resolved Hide resolved
return false
}
})
Expand All @@ -205,34 +205,34 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
* Applies differentiation rules.
*
* @param {ConstantNode | SymbolNode | ParenthesisNode | FunctionNode | OperatorNode} node
* @param {Object} constNodes Holds the nodes that are constant
* @param {function} isConst Function that tells if a node is constant
* @return {ConstantNode | SymbolNode | ParenthesisNode | FunctionNode | OperatorNode} The derivative of `expr`
*/
const _derivative = typed('_derivative', {
'ConstantNode, Object': function (node) {
'ConstantNode, function': function () {
return createConstantNode(0)
},

'SymbolNode, Object': function (node, constNodes) {
if (constNodes[node] !== undefined) {
'SymbolNode, function': function (node, isConst) {
if (isConst(node)) {
return createConstantNode(0)
}
return createConstantNode(1)
},

'ParenthesisNode, Object': function (node, constNodes) {
return new ParenthesisNode(_derivative(node.content, constNodes))
'ParenthesisNode, function': function (node, isConst) {
return new ParenthesisNode(_derivative(node.content, isConst))
},

'FunctionAssignmentNode, Object': function (node, constNodes) {
if (constNodes[node] !== undefined) {
'FunctionAssignmentNode, function': function (node, isConst) {
if (isConst(node)) {
return createConstantNode(0)
}
return _derivative(node.expr, constNodes)
return _derivative(node.expr, isConst)
},

'FunctionNode, Object': function (node, constNodes) {
if (constNodes[node] !== undefined) {
'FunctionNode, function': function (node, isConst) {
if (isConst(node)) {
return createConstantNode(0)
}

Expand Down Expand Up @@ -274,10 +274,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
node.args[1]
])

// Is a variable?
constNodes[arg1] = constNodes[node.args[1]]

return _derivative(new OperatorNode('^', 'pow', [arg0, arg1]), constNodes)
return _derivative(new OperatorNode('^', 'pow', [arg0, arg1]), isConst)
}
break
case 'log10':
Expand All @@ -289,7 +286,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
funcDerivative = arg0.clone()
div = true
} else if ((node.args.length === 1 && arg1) ||
(node.args.length === 2 && constNodes[node.args[1]] !== undefined)) {
(node.args.length === 2 && isConst(node.args[1]))) {
// d/dx(log(x, c)) = 1 / (x*ln(c))
funcDerivative = new OperatorNode('*', 'multiply', [
arg0.clone(),
Expand All @@ -301,14 +298,13 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
return _derivative(new OperatorNode('/', 'divide', [
new FunctionNode('log', [arg0]),
new FunctionNode('log', [node.args[1]])
]), constNodes)
]), isConst)
}
break
case 'pow':
if (node.args.length === 2) {
constNodes[arg1] = constNodes[node.args[1]]
// Pass to pow operator node parser
return _derivative(new OperatorNode('^', 'pow', [arg0, node.args[1]]), constNodes)
return _derivative(new OperatorNode('^', 'pow', [arg0, node.args[1]]), isConst)
}
break
case 'exp':
Expand Down Expand Up @@ -585,58 +581,58 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
/* Apply chain rule to all functions:
F(x) = f(g(x))
F'(x) = g'(x)*f'(g(x)) */
let chainDerivative = _derivative(arg0, constNodes)
let chainDerivative = _derivative(arg0, isConst)
if (negative) {
chainDerivative = new OperatorNode('-', 'unaryMinus', [chainDerivative])
}
return new OperatorNode(op, func, [chainDerivative, funcDerivative])
},

'OperatorNode, Object': function (node, constNodes) {
if (constNodes[node] !== undefined) {
'OperatorNode, function': function (node, isConst) {
if (isConst(node)) {
return createConstantNode(0)
}

if (node.op === '+') {
// d/dx(sum(f(x)) = sum(f'(x))
return new OperatorNode(node.op, node.fn, node.args.map(function (arg) {
return _derivative(arg, constNodes)
return _derivative(arg, isConst)
}))
}

if (node.op === '-') {
// d/dx(+/-f(x)) = +/-f'(x)
if (node.isUnary()) {
return new OperatorNode(node.op, node.fn, [
_derivative(node.args[0], constNodes)
_derivative(node.args[0], isConst)
])
}

// Linearity of differentiation, d/dx(f(x) +/- g(x)) = f'(x) +/- g'(x)
if (node.isBinary()) {
return new OperatorNode(node.op, node.fn, [
_derivative(node.args[0], constNodes),
_derivative(node.args[1], constNodes)
_derivative(node.args[0], isConst),
_derivative(node.args[1], isConst)
])
}
}

if (node.op === '*') {
// d/dx(c*f(x)) = c*f'(x)
const constantTerms = node.args.filter(function (arg) {
return constNodes[arg] !== undefined
return isConst(arg)
})

if (constantTerms.length > 0) {
const nonConstantTerms = node.args.filter(function (arg) {
return constNodes[arg] === undefined
return !isConst(arg)
})

const nonConstantNode = nonConstantTerms.length === 1
? nonConstantTerms[0]
: new OperatorNode('*', 'multiply', nonConstantTerms)

const newArgs = constantTerms.concat(_derivative(nonConstantNode, constNodes))
const newArgs = constantTerms.concat(_derivative(nonConstantNode, isConst))

return new OperatorNode('*', 'multiply', newArgs)
}
Expand All @@ -645,7 +641,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
return new OperatorNode('+', 'add', node.args.map(function (argOuter) {
return new OperatorNode('*', 'multiply', node.args.map(function (argInner) {
return (argInner === argOuter)
? _derivative(argInner, constNodes)
? _derivative(argInner, isConst)
: argInner.clone()
}))
}))
Expand All @@ -656,16 +652,16 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
const arg1 = node.args[1]

// d/dx(f(x) / c) = f'(x) / c
if (constNodes[arg1] !== undefined) {
return new OperatorNode('/', 'divide', [_derivative(arg0, constNodes), arg1])
if (isConst(arg1)) {
return new OperatorNode('/', 'divide', [_derivative(arg0, isConst), arg1])
}

// Reciprocal Rule, d/dx(c / f(x)) = -c(f'(x)/f(x)^2)
if (constNodes[arg0] !== undefined) {
if (isConst(arg0)) {
return new OperatorNode('*', 'multiply', [
new OperatorNode('-', 'unaryMinus', [arg0]),
new OperatorNode('/', 'divide', [
_derivative(arg1, constNodes),
_derivative(arg1, isConst),
new OperatorNode('^', 'pow', [arg1.clone(), createConstantNode(2)])
])
])
Expand All @@ -674,8 +670,8 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
// Quotient rule, d/dx(f(x) / g(x)) = (f'(x)g(x) - f(x)g'(x)) / g(x)^2
return new OperatorNode('/', 'divide', [
new OperatorNode('-', 'subtract', [
new OperatorNode('*', 'multiply', [_derivative(arg0, constNodes), arg1.clone()]),
new OperatorNode('*', 'multiply', [arg0.clone(), _derivative(arg1, constNodes)])
new OperatorNode('*', 'multiply', [_derivative(arg0, isConst), arg1.clone()]),
new OperatorNode('*', 'multiply', [arg0.clone(), _derivative(arg1, isConst)])
]),
new OperatorNode('^', 'pow', [arg1.clone(), createConstantNode(2)])
])
Expand All @@ -685,7 +681,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
const arg0 = node.args[0]
const arg1 = node.args[1]

if (constNodes[arg0] !== undefined) {
if (isConst(arg0)) {
// If is secretly constant; 0^f(x) = 1 (in JS), 1^f(x) = 1
if (isConstantNode(arg0) && (isZero(arg0.value) || equal(arg0.value, 1))) {
return createConstantNode(0)
Expand All @@ -696,20 +692,20 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
node,
new OperatorNode('*', 'multiply', [
new FunctionNode('log', [arg0.clone()]),
_derivative(arg1.clone(), constNodes)
_derivative(arg1.clone(), isConst)
])
])
}

if (constNodes[arg1] !== undefined) {
if (isConst(arg1)) {
if (isConstantNode(arg1)) {
// If is secretly constant; f(x)^0 = 1 -> d/dx(1) = 0
if (isZero(arg1.value)) {
return createConstantNode(0)
}
// Ignore exponent; f(x)^1 = f(x)
if (equal(arg1.value, 1)) {
return _derivative(arg0, constNodes)
return _derivative(arg0, isConst)
}
}

Expand All @@ -725,7 +721,7 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
return new OperatorNode('*', 'multiply', [
arg1.clone(),
new OperatorNode('*', 'multiply', [
_derivative(arg0, constNodes),
_derivative(arg0, isConst),
powMinusOne
])
])
Expand All @@ -736,11 +732,11 @@ export const createDerivative = /* #__PURE__ */ factory(name, dependencies, ({
new OperatorNode('^', 'pow', [arg0.clone(), arg1.clone()]),
new OperatorNode('+', 'add', [
new OperatorNode('*', 'multiply', [
_derivative(arg0, constNodes),
_derivative(arg0, isConst),
new OperatorNode('/', 'divide', [arg1.clone(), arg0.clone()])
]),
new OperatorNode('*', 'multiply', [
_derivative(arg1, constNodes),
_derivative(arg1, isConst),
new FunctionNode('log', [arg0.clone()])
])
])
Expand Down
Loading