Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support at in ac_nf and use it in bv_normalize #5618

Merged
merged 24 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions src/Init/Tactics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -399,19 +399,6 @@ example (a b c d : Nat) : a + b + c + d = d + (b + c) + a := by ac_rfl
-/
syntax (name := acRfl) "ac_rfl" : tactic

/--
`ac_nf` normalizes equalities up to application of an associative and commutative operator.
```
instance : Associative (α := Nat) (.+.) := ⟨Nat.add_assoc⟩
instance : Commutative (α := Nat) (.+.) := ⟨Nat.add_comm⟩

example (a b c d : Nat) : a + b + c + d = d + (b + c) + a := by
ac_nf
-- goal: a + (b + (c + d)) = a + (b + (c + d))
```
-/
syntax (name := acNf) "ac_nf" : tactic

/--
The `sorry` tactic closes the goal using `sorryAx`. This is intended for stubbing out incomplete
parts of a proof while still having a syntactically correct proof skeleton. Lean will give
Expand Down Expand Up @@ -1172,6 +1159,9 @@ Currently the preprocessor is implemented as `try simp only [bv_toNat] at *`.
-/
macro "bv_omega" : tactic => `(tactic| (try simp only [bv_toNat] at *) <;> omega)

/-- Implementation of `ac_nf` (the full `ac_nf` calls `trivial` afterwards). -/
syntax (name := acNf0) "ac_nf0" (location)? : tactic

/-- Implementation of `norm_cast` (the full `norm_cast` calls `trivial` afterwards). -/
syntax (name := normCast0) "norm_cast0" (location)? : tactic

Expand Down Expand Up @@ -1222,6 +1212,20 @@ See also `push_cast`, which moves casts inwards rather than lifting them outward
macro "norm_cast" loc:(location)? : tactic =>
`(tactic| norm_cast0 $[$loc]? <;> try trivial)

/--
`ac_nf` normalizes equalities up to application of an associative and commutative operator.
```
instance : Associative (α := Nat) (.+.) := ⟨Nat.add_assoc⟩
instance : Commutative (α := Nat) (.+.) := ⟨Nat.add_comm⟩

example (a b c d : Nat) : a + b + c + d = d + (b + c) + a := by
ac_nf
-- goal: a + (b + (c + d)) = a + (b + (c + d))
```
-/
macro "ac_nf" loc:(location)? : tactic =>
`(tactic| ac_nf0 $[$loc]? <;> try trivial)

/--
`push_cast` rewrites the goal to move certain coercions (*casts*) inward, toward the leaf nodes.
This uses `norm_cast` lemmas in the forward direction.
Expand Down
21 changes: 18 additions & 3 deletions src/Lean/Elab/Tactic/BVDecide/Frontend/Normalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Henrik Böving
-/
prelude
import Lean.Meta.AppBuilder
import Lean.Meta.Tactic.AC.Main
import Lean.Elab.Tactic.Simp
import Lean.Elab.Tactic.FalseOrByContra
import Lean.Elab.Tactic.BVDecide.Frontend.Attr
Expand Down Expand Up @@ -112,10 +113,26 @@ def rewriteRulesPass : Pass := fun goal => do
let some (_, newGoal) := result? | return none
return newGoal

/--
Normalize with respect to Associativity and Commutativity.
-/
def acNormalizePass : Pass := fun goal => do
let mut newGoal := goal
for hyp in (← goal.getNondepPropHyps) do
let result ← Lean.Meta.AC.acNfHypMeta newGoal hyp

if let .some x := result then
newGoal := x
continue

return result
bollu marked this conversation as resolved.
Show resolved Hide resolved

return newGoal

/--
The normalization passes used by `bv_normalize` and thus `bv_decide`.
-/
def defaultPipeline : List Pass := [rewriteRulesPass]
def defaultPipeline : List Pass := [rewriteRulesPass, acNormalizePass]

end Pass

Expand All @@ -137,5 +154,3 @@ def evalBVNormalize : Tactic := fun

end Frontend.Normalize
end Lean.Elab.Tactic.BVDecide


68 changes: 47 additions & 21 deletions src/Lean/Meta/Tactic/AC/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -140,18 +140,7 @@ where
| .op l r => mkApp2 preContext.op (convertTarget vars l) (convertTarget vars r)
| .var x => vars[x]!

def rewriteUnnormalized (mvarId : MVarId) : MetaM MVarId := do
let simpCtx :=
{
simpTheorems := {}
congrTheorems := (← getSimpCongrTheorems)
config := Simp.neutralConfig
}
let tgt ← instantiateMVars (← mvarId.getType)
let (res, _) ← Simp.main tgt simpCtx (methods := { post })
applySimpResultToTarget mvarId tgt res
where
post (e : Expr) : SimpM Simp.Step := do
def post (e : Expr) : SimpM Simp.Step := do
tobiasgrosser marked this conversation as resolved.
Show resolved Hide resolved
let ctx ← Simp.getContext
match e, ctx.parent? with
| bin op₁ l r, some (bin op₂ _ _) =>
Expand All @@ -170,21 +159,58 @@ where
| none => return Simp.Step.done { expr := e }
| e, _ => return Simp.Step.done { expr := e }

def rewriteUnnormalizedRefl (goal : MVarId) : MetaM Unit := do
let newGoal ← rewriteUnnormalized goal
newGoal.refl
def rewriteUnnormalized (mvarId : MVarId) : MetaM MVarId := do
let simpCtx :=
{
simpTheorems := {}
congrTheorems := (← getSimpCongrTheorems)
config := Simp.neutralConfig
}
let tgt ← instantiateMVars (← mvarId.getType)
let (res, _) ← Simp.main tgt simpCtx (methods := { post })
applySimpResultToTarget mvarId tgt res

def rewriteUnnormalizedNormalForm (goal : MVarId) : TacticM Unit := do
let newGoal ← rewriteUnnormalized goal
replaceMainGoal [newGoal]
def rewriteUnnormalizedRefl (goal : MVarId) : MetaM Unit := do
(← rewriteUnnormalized goal).refl

@[builtin_tactic acRfl] def acRflTactic : Lean.Elab.Tactic.Tactic := fun _ => do
let goal ← getMainGoal
goal.withContext <| rewriteUnnormalizedRefl goal

@[builtin_tactic acNf] def acNfTactic : Lean.Elab.Tactic.Tactic := fun _ => do
let goal ← getMainGoal
goal.withContext <| rewriteUnnormalizedNormalForm goal
def acNfHypMeta (goal : MVarId) (fvarId : FVarId) : MetaM (Option MVarId) := do
goal.withContext do
let simpCtx :=
{
simpTheorems := {}
congrTheorems := (← getSimpCongrTheorems)
config := Simp.neutralConfig
}
let tgt ← instantiateMVars (← fvarId.getType)
let (res, _) ← Simp.main tgt simpCtx (methods := { post })
return (← applySimpResultToLocalDecl goal fvarId res false).map (·.snd)

/-- Implementation of the `ac_nf` tactic when operating on the main goal. -/
def acNfTargetTactic : TacticM Unit :=
liftMetaTactic1 fun goal => do return (← rewriteUnnormalized goal)
bollu marked this conversation as resolved.
Show resolved Hide resolved

/-- Implementation of the `ac_nf` tactic when operating on a hypothesis. -/
def acNfHypTactic (fvarId : FVarId) : TacticM Unit :=
liftMetaTactic1 fun goal => acNfHypMeta goal fvarId

@[builtin_tactic acNf0]
def evalNf0 : Tactic := fun stx => do
match stx with
| `(tactic| ac_nf0 $[$loc?]?) =>
let loc := if let some loc := loc? then expandLocation loc else Location.targets #[] true
withMainContext do
match loc with
| Location.targets hyps target =>
if target then acNfTargetTactic
(← getFVarIds hyps).forM acNfHypTactic
| Location.wildcard =>
acNfTargetTactic
(← (← getMainGoal).getNondepPropHyps).forM acNfHypTactic
| _ => Lean.Elab.throwUnsupportedSyntax

builtin_initialize
registerTraceClass `Meta.AC
Expand Down
23 changes: 12 additions & 11 deletions tests/lean/run/ac_rfl.lean
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ example : [1, 2] ++ ([] ++ [2+4, 8] ++ [4]) = [1, 2] ++ [4+2, 8] ++ [4] := by ac
example (a b c d : BitVec w) :
a * b * c * d = d * c * b * a := by
ac_nf
rfl

example (a b c d : BitVec w) :
a * b * c * d = d * c * b * a := by
Expand All @@ -52,7 +51,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a + b + c + d = d + c + b + a := by
ac_nf
rfl

example (a b c d : BitVec w) :
a + b + c + d = d + c + b + a := by
Expand All @@ -63,7 +61,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a * (b * (c * d)) = ((a * b) * c) * d := by
ac_nf
rfl

example (a b c d : BitVec w) :
a * (b * (c * d)) = ((a * b) * c) * d := by
Expand All @@ -72,7 +69,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a + (b + (c + d)) = ((a + b) + c) + d := by
ac_nf
rfl

example (a b c d : BitVec w) :
a + (b + (c + d)) = ((a + b) + c) + d := by
Expand All @@ -83,7 +79,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a ^^^ b ^^^ c ^^^ d = d ^^^ c ^^^ b ^^^ a := by
ac_nf
rfl

example (a b c d : BitVec w) :
a ^^^ b ^^^ c ^^^ d = d ^^^ c ^^^ b ^^^ a := by
Expand All @@ -92,7 +87,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a &&& b &&& c &&& d = d &&& c &&& b &&& a := by
ac_nf
rfl

example (a b c d : BitVec w) :
a &&& b &&& c &&& d = d &&& c &&& b &&& a := by
Expand All @@ -101,7 +95,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a ||| b ||| c ||| d = d ||| c ||| b ||| a := by
ac_nf
rfl

example (a b c d : BitVec w) :
a ||| b ||| c ||| d = d ||| c ||| b ||| a := by
Expand All @@ -112,7 +105,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a &&& (b &&& (c &&& d)) = ((a &&& b) &&& c) &&& d := by
ac_nf
rfl

example (a b c d : BitVec w) :
a &&& (b &&& (c &&& d)) = ((a &&& b) &&& c) &&& d := by
Expand All @@ -121,7 +113,6 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a ||| (b ||| (c ||| d)) = ((a ||| b) ||| c) ||| d := by
ac_nf
rfl

example (a b c d : BitVec w) :
a ||| (b ||| (c ||| d)) = ((a ||| b) ||| c) ||| d := by
Expand All @@ -130,12 +121,22 @@ example (a b c d : BitVec w) :
example (a b c d : BitVec w) :
a ^^^ (b ^^^ (c ^^^ d)) = ((a ^^^ b) ^^^ c) ^^^ d := by
ac_nf
rfl

example (a b c d : BitVec w) :
a ^^^ (b ^^^ (c ^^^ d)) = ((a ^^^ b) ^^^ c) ^^^ d := by
ac_rfl

example (a b c d : Nat) : a + b + c + d = d + (b + c) + a := by
ac_nf
rfl

example (a b c d : Nat) (h₁ h₂ : a + b + c + d = d + (b + c) + a) :
a + b + c + d = a + (b + c) + d := by

ac_nf at h₁
guard_hyp h₁ :ₛ a + (b + (c + d)) = a + (b + (c + d))

guard_hyp h₂ :ₛ a + b + c + d = d + (b + c) + a
ac_nf at h₂
guard_hyp h₂ :ₛ a + (b + (c + d)) = a + (b + (c + d))

ac_nf at *
Loading