Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: safe exponentiation #4637

Merged
merged 1 commit into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/Lean/CoreM.lean
Original file line number Diff line number Diff line change
Expand Up @@ -519,4 +519,16 @@ instance : MonadRuntimeException CoreM where
@[inline] def mapCoreM [MonadControlT CoreM m] [Monad m] (f : forall {α}, CoreM α → CoreM α) {α} (x : m α) : m α :=
controlAt CoreM fun runInBase => f <| runInBase x

/--
Records `kind` in the set of message kinds already reported in the message log.

Returns `true` when `kind` was not yet in the set (it is inserted by this call),
and `false` when it was already present (the state is left untouched).
Callers use the result to avoid emitting the same kind of warning more than once.
-/
def reportMessageKind (kind : Name) : CoreM Bool := do
  let alreadyReported := (← get).messages.reportedKinds.contains kind
  if alreadyReported then
    return false
  -- First occurrence: remember the kind so later calls return `false`.
  modify fun s => { s with messages.reportedKinds := s.messages.reportedKinds.insert kind }
  return true

end Lean
9 changes: 8 additions & 1 deletion src/Lean/Message.lean
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,13 @@ structure MessageLog where
hadErrors : Bool := false
/-- The list of messages not already reported, in insertion order. -/
unreported : PersistentArray Message := {}
/--
Set of message kinds that have been added to the log.
For example, we have the kind `unsafe.exponentiation` for warning messages associated with
the configuration option `exponentiation.threshold`.
We don't produce a warning if the kind is already in the following set.
-/
reportedKinds : NameSet := {}
deriving Inhabited

namespace MessageLog
Expand Down Expand Up @@ -403,7 +410,7 @@ def indentExpr (e : Expr) : MessageData :=
indentD e

class AddMessageContext (m : Type → Type) where
/--
/--
Without context, a `MessageData` object may be missing information
(e.g. hover info) for pretty printing, or may print an error. Hence,
`addMessageContext` should be called on all constructed `MessageData`
Expand Down
14 changes: 10 additions & 4 deletions src/Lean/Meta/Offset.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ prelude
import Lean.Data.LBool
import Lean.Meta.InferType
import Lean.Meta.NatInstTesters
import Lean.Meta.NatInstTesters
import Lean.Util.SafeExponentiation

namespace Lean.Meta

Expand All @@ -29,6 +31,10 @@ partial def evalNat (e : Expr) : OptionT MetaM Nat := do
| .mvar .. => visit e
| _ => failure
where
-- Evaluates `b ^ n`. The exponent is evaluated and checked with `checkExponent`
-- first, so the computation fails before the base is even evaluated when the
-- exponent is rejected.
evalPow (b n : Expr) : OptionT MetaM Nat := do
  let exponent ← evalNat n
  guard (← checkExponent exponent)
  let base ← evalNat b
  return base ^ exponent
visit e := do
match_expr e with
| OfNat.ofNat _ n i => guard (← isInstOfNatNat i); evalNat n
Expand All @@ -48,10 +54,10 @@ where
| Nat.mod a b => return (← evalNat a) % (← evalNat b)
| Mod.mod _ i a b => guard (← isInstModNat i); return (← evalNat a) % (← evalNat b)
| HMod.hMod _ _ _ i a b => guard (← isInstHModNat i); return (← evalNat a) % (← evalNat b)
| Nat.pow a b => return (← evalNat a) ^ (← evalNat b)
| NatPow.pow _ i a b => guard (← isInstNatPowNat i); return (← evalNat a) ^ (← evalNat b)
| Pow.pow _ _ i a b => guard (← isInstPowNat i); return (← evalNat a) ^ (← evalNat b)
| HPow.hPow _ _ _ i a b => guard (← isInstHPowNat i); return (← evalNat a) ^ (← evalNat b)
| Nat.pow a b => evalPow a b
| NatPow.pow _ i a b => guard (← isInstNatPowNat i); evalPow a b
| Pow.pow _ _ i a b => guard (← isInstPowNat i); evalPow a b
| HPow.hPow _ _ _ i a b => guard (← isInstHPowNat i); evalPow a b
| _ => failure

/--
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Int.lean
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ builtin_dsimproc [simp, seval] reducePow ((_ : Int) ^ (_ : Nat)) := fun e => do
let_expr HPow.hPow _ _ _ _ a b ← e | return .continue
let some v₁ ← fromExpr? a | return .continue
let some v₂ ← Nat.fromExpr? b | return .continue
unless (← checkExponent v₂) do return .continue
return .done <| toExpr (v₁ ^ v₂)

builtin_simproc [simp, seval] reduceLT (( _ : Int) < _) := reduceBinPred ``LT.lt 4 (. < .)
Expand Down
9 changes: 8 additions & 1 deletion src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Nat.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Init.Simproc
import Init.Data.Nat.Simproc
import Lean.Util.SafeExponentiation
import Lean.Meta.LitValues
import Lean.Meta.Offset
import Lean.Meta.Tactic.Simp.Simproc
Expand Down Expand Up @@ -52,7 +53,13 @@ builtin_dsimproc [simp, seval] reduceMul ((_ * _ : Nat)) := reduceBin ``HMul.hMu
builtin_dsimproc [simp, seval] reduceSub ((_ - _ : Nat)) := reduceBin ``HSub.hSub 6 (· - ·)
builtin_dsimproc [simp, seval] reduceDiv ((_ / _ : Nat)) := reduceBin ``HDiv.hDiv 6 (· / ·)
builtin_dsimproc [simp, seval] reduceMod ((_ % _ : Nat)) := reduceBin ``HMod.hMod 6 (· % ·)
builtin_dsimproc [simp, seval] reducePow ((_ ^ _ : Nat)) := reduceBin ``HPow.hPow 6 (· ^ ·)

-- Simproc for `_ ^ _ : Nat`: folds the power into a literal when both the base
-- and the exponent are literals and the exponent passes `checkExponent`.
-- In every other case the term is left as is (`.continue`).
builtin_dsimproc [simp, seval] reducePow ((_ ^ _ : Nat)) := fun e => do
  -- The base is the second-to-last argument of the application.
  let some n ← fromExpr? e.appFn!.appArg!
    | return .continue
  -- The exponent is the last argument.
  let some m ← fromExpr? e.appArg!
    | return .continue
  if (← checkExponent m) then
    return .done <| toExpr (n ^ m)
  else
    return .continue

builtin_dsimproc [simp, seval] reduceGcd (gcd _ _) := reduceBin ``gcd 2 gcd

builtin_simproc [simp, seval] reduceLT (( _ : Nat) < _) := reduceBinPred ``LT.lt 4 (. < .)
Expand Down
10 changes: 9 additions & 1 deletion src/Lean/Meta/WHNF.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Lean.Structure
import Lean.Util.Recognizers
import Lean.Util.SafeExponentiation
import Lean.Meta.GetUnfoldableConst
import Lean.Meta.FunInfo
import Lean.Meta.Offset
Expand Down Expand Up @@ -885,6 +886,13 @@ def reduceBinNatOp (f : Nat → Nat → Nat) (a b : Expr) : MetaM (Option Expr)
trace[Meta.isDefEq.whnf.reduceBinOp] "{a} op {b}"
return mkRawNatLit <| f a b

/--
Reduces `a ^ b` to a raw natural number literal when `withNatValue` can extract
a value from both arguments and the exponent passes `checkExponent`.
Returns `none` when the exponent is rejected.
-/
def reducePow (a b : Expr) : MetaM (Option Expr) :=
  withNatValue a fun a =>
  withNatValue b fun b => do
    -- Bail out before computing anything if the exponent is rejected.
    unless (← checkExponent b) do
      return none
    trace[Meta.isDefEq.whnf.reduceBinOp] "{a} ^ {b}"
    return some (mkRawNatLit (a ^ b))

def reduceBinNatPred (f : Nat → Nat → Bool) (a b : Expr) : MetaM (Option Expr) := do
withNatValue a fun a =>
withNatValue b fun b =>
Expand All @@ -904,7 +912,7 @@ def reduceNat? (e : Expr) : MetaM (Option Expr) :=
| ``Nat.mul => reduceBinNatOp Nat.mul a1 a2
| ``Nat.div => reduceBinNatOp Nat.div a1 a2
| ``Nat.mod => reduceBinNatOp Nat.mod a1 a2
| ``Nat.pow => reduceBinNatOp Nat.pow a1 a2
| ``Nat.pow => reducePow a1 a2
| ``Nat.gcd => reduceBinNatOp Nat.gcd a1 a2
| ``Nat.beq => reduceBinNatPred Nat.beq a1 a2
| ``Nat.ble => reduceBinNatPred Nat.ble a1 a2
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ import Lean.Util.OccursCheck
import Lean.Util.HasConstCache
import Lean.Util.FileSetupInfo
import Lean.Util.Heartbeats
import Lean.Util.SafeExponentiation
34 changes: 34 additions & 0 deletions src/Lean/Util/SafeExponentiation.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.CoreM

namespace Lean

-- Option `exponentiation.threshold`: a `Nat` bound on exponents.
-- It is read by `checkExponent`, which rejects any exponent strictly greater than this value.
register_builtin_option exponentiation.threshold : Nat := {
-- Value used when the user has not set the option.
defValue := 256
descr := "maximum value for \
which exponentiation operations are safe to evaluate. When an exponent \
is a value greater than this threshold, the exponentiation will not be evaluated, \
and a warning will be logged. This helps to prevent the system from becoming \
unresponsive due to excessively large computations."
}

/--
Checks whether the exponent `n` is small enough to be evaluated.

Returns `true` when `n ≤ exponentiation.threshold`. Otherwise returns `false`,
and logs a warning unless a message of kind `unsafe.exponentiation` has already
been reported (see `reportMessageKind`), so the message log contains at most one
warning of this kind.
-/
def checkExponent (n : Nat) : CoreM Bool := do
  let threshold := exponentiation.threshold.get (← getOptions)
  -- Common case: the exponent is within the limit.
  if n ≤ threshold then
    return true
  -- Only the first offending exponent produces a warning.
  if (← reportMessageKind `unsafe.exponentiation) then
    logWarning s!"exponent {n} exceeds the threshold {threshold}, exponentiation operation was not evaluated, use `set_option {exponentiation.threshold.name} <num>` to set a new threshold"
  return false

end Lean
14 changes: 13 additions & 1 deletion src/kernel/type_checker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,18 @@ template<typename F> optional<expr> type_checker::reduce_bin_nat_op(F const & f,
return some_expr(mk_lit(literal(nat(f(v1.raw(), v2.raw())))));
}

// Largest exponent for which `reduce_pow` evaluates the power.
// Parenthesized so the macro expands safely inside larger expressions.
#define ReducePowMaxExp (1<<24) // TODO: make it configurable

/*
Reduce an application `f a b` (dispatched here for `Nat.pow`) to a literal.

Returns `none_expr()` when either argument does not reduce (via `whnf`) to a
natural-number literal, or when the exponent is greater than `ReducePowMaxExp`.
Otherwise returns the literal for `a ^ b`.
*/
optional<expr> type_checker::reduce_pow(expr const & e) {
    expr arg1 = whnf(app_arg(app_fn(e)));
    // The base must be a literal too: `get_nat_val` is only meaningful on
    // literals, and `reduce_bin_nat_pred` performs the same check.
    if (!is_nat_lit_ext(arg1)) return none_expr();
    expr arg2 = whnf(app_arg(e));
    if (!is_nat_lit_ext(arg2)) return none_expr();
    nat v1 = get_nat_val(arg1);
    nat v2 = get_nat_val(arg2);
    // Refuse to evaluate powers with very large exponents.
    if (v2 > nat(ReducePowMaxExp)) return none_expr();
    return some_expr(mk_lit(literal(nat(nat_pow(v1.raw(), v2.raw())))));
}

template<typename F> optional<expr> type_checker::reduce_bin_nat_pred(F const & f, expr const & e) {
expr arg1 = whnf(app_arg(app_fn(e)));
if (!is_nat_lit_ext(arg1)) return none_expr();
Expand Down Expand Up @@ -622,7 +634,7 @@ optional<expr> type_checker::reduce_nat(expr const & e) {
if (f == *g_nat_add) return reduce_bin_nat_op(nat_add, e);
if (f == *g_nat_sub) return reduce_bin_nat_op(nat_sub, e);
if (f == *g_nat_mul) return reduce_bin_nat_op(nat_mul, e);
if (f == *g_nat_pow) return reduce_bin_nat_op(nat_pow, e);
if (f == *g_nat_pow) return reduce_pow(e);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little worried that this doesn't actually help; I think when I tried this patch, instead of:

  • Lean asks GMP to compute a very large natural, using lots of CPU or perhaps running out of heap memory

The result is now:

  • Lean refuses to expand a ^ n via nat reduction
  • a ^ n is expanded to a * a ^ (n - 1), allocating small objects and filling up the stack
  • Lean refuses to expand a ^ (n - 1) via nat reduction
  • ...
  • Recursion uses lots of CPU, and hits a stack overflow

Copy link
Contributor

@eric-wieser eric-wieser Aug 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's the smallest test I could construct:

import Lean
import Qq

open Qq

theorem foo : 2 ^ 100000000 = (2 ^ 10000) ^ 10000 := by
  run_tac do
    (← Lean.Elab.Tactic.getMainGoal).assign
      q(Eq.refl ((2 ^ 10000) ^ 10000))

Apologies for the Qq use

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was fixed by #4934!

if (f == *g_nat_gcd) return reduce_bin_nat_op(nat_gcd, e);
if (f == *g_nat_mod) return reduce_bin_nat_op(nat_mod, e);
if (f == *g_nat_div) return reduce_bin_nat_op(nat_div, e);
Expand Down
1 change: 1 addition & 0 deletions src/kernel/type_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class type_checker {

template<typename F> optional<expr> reduce_bin_nat_op(F const & f, expr const & e);
template<typename F> optional<expr> reduce_bin_nat_pred(F const & f, expr const & e);
optional<expr> reduce_pow(expr const & e);
optional<expr> reduce_nat(expr const & e);
public:
type_checker(state & st, local_ctx const & lctx, definition_safety ds = definition_safety::safe);
Expand Down
1 change: 1 addition & 0 deletions tests/lean/run/lean_nat_gcd.lean
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def p_31 := 216091
def p_32 := 756839
def p_33 := 859433

set_option exponentiation.threshold 10000000
/- GCD with large prime factors on one side, and small primes on the other. -/
example : Nat.gcd (p_29 * p_30 * p_31 * p_32 * p_33) 2^(2^20) = 1 := rfl
/- GCD with two prime factors on both sides, including one in common. -/
Expand Down
47 changes: 47 additions & 0 deletions tests/lean/run/safeExp.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/--
warning: exponent 10000000 exceeds the threshold 256, exponentiation operation was not evaluated, use `set_option exponentiation.threshold <num>` to set a new threshold
---
error: maximum recursion depth has been reached
use `set_option maxRecDepth <num>` to increase limit
use `set_option diagnostics true` to get diagnostic information
-/
#guard_msgs in
example : 2^2^8000000 = 3^3^10000000 :=
rfl

/--
-/
#guard_msgs in
set_option exponentiation.threshold 258 in
example : 2^257 = 2*2^256 :=
rfl

/--
warning: exponent 2008 exceeds the threshold 256, exponentiation operation was not evaluated, use `set_option exponentiation.threshold <num>` to set a new threshold
---
warning: declaration uses 'sorry'
---
error: (kernel) deep recursion detected
---
info: k : Nat
h : k = 2008 ^ 2 + 2 ^ 2008
⊢ ((4032064 + 2 ^ 2008) ^ 2 + 2 ^ (4032064 + 2 ^ 2008)) % 10 = 6
-/
#guard_msgs in
example (k : Nat) (h : k = 2008^2 + 2^2008) : (k^2 + 2^k)%10 = 6 := by
simp [h]
trace_state
sorry

/--
warning: declaration uses 'sorry'
---
info: k : Nat
h : k = 2008 ^ 2 + 2 ^ 2008
⊢ ((2008 ^ 2 + 2 ^ 2008) ^ 2 + 2 ^ (2008 ^ 2 + 2 ^ 2008)) % 10 = 6
-/
#guard_msgs in
example (k : Nat) (h : k = 2008^2 + 2^2008) : (k^2 + 2^k)%10 = 6 := by
rw [h]
trace_state
sorry
Loading