import stainless.annotation.*
import stainless.lang.{ghost => ghostExpr, *}
object ConstFold:

  /** Arithmetic expression tree over `BigInt` literals and named variables. */
  sealed abstract class Expr
  /** Integer literal; evaluates to `value`. */
  case class Number(value: BigInt) extends Expr
  /** Variable reference; `evaluate` looks `name` up in the environment. */
  case class Var(name: String) extends Expr
  /** Sum `e1 + e2` of two sub-expressions. */
  case class Add(e1: Expr, e2: Expr) extends Expr
  /** Difference `e1 - e2` of two sub-expressions. */
  case class Minus(e1: Expr, e2: Expr) extends Expr

  /** An environment: a function giving the `BigInt` value bound to each variable name. */
  type Env = String => BigInt

  /** The environment that binds every variable name to zero. */
  val zeroEnv: Env =
    _ => BigInt(0)

  /** Computes the value denoted by `e`.
    *
    * @param e   expression to evaluate
    * @param ctx environment consulted for the value of every `Var`
    * @return the literal's value, the variable's binding in `ctx`, or the sum /
    *         difference of the recursively evaluated operands
    */
  def evaluate(e: Expr)(using ctx: Env): BigInt =
    e match
      case Add(lhs, rhs) =>
        val left = evaluate(lhs)
        val right = evaluate(rhs)
        left + right
      case Minus(lhs, rhs) =>
        val left = evaluate(lhs)
        val right = evaluate(rhs)
        left - right
      case Var(id) => ctx(id)
      case Number(n) => n

  /** Tells whether `e` is built only from the literal `0` combined through `Add` and `Minus`.
    *
    * Any variable or non-zero literal anywhere in the tree makes the result `false`.
    */
  def zeroExpr(e: Expr): Boolean =
    e match
      case Var(_) => false
      case Number(n) => n == BigInt(0)
      case Add(lhs, rhs) =>
        if zeroExpr(lhs) then zeroExpr(rhs) else false
      case Minus(lhs, rhs) =>
        if zeroExpr(lhs) then zeroExpr(rhs) else false

  /** Lemma: an expression accepted by `zeroExpr` evaluates to `0` under any environment.
    *
    * The body is just `()`; the statement lives entirely in the pre- and postcondition.
    * The `@induct` annotation on `e` asks Stainless to attempt the proof by structural
    * induction on that argument.
    *
    * @param ctx environment under which `e` is evaluated in the postcondition
    * @param e   expression, required to satisfy `zeroExpr`
    */
  def lemma(ctx: Env, @induct e: Expr): Unit = {
    require(zeroExpr(e))
    ()
  }.ensuring(_ => evaluate(e)(using ctx) == 0)

  /** Builds a copy of `e` in which the two operands of every `Add` are exchanged.
    *
    * `Minus` nodes keep their operand order (the operands are only mirrored recursively)
    * and leaves are returned as they are. The postcondition states that the result
    * evaluates to the same value as `e` under `anyCtx`.
    */
  def mirror(e: Expr)(using anyCtx: Env): Expr = {
    e match
      case Add(lhs, rhs) =>
        // Swap: the mirrored right operand becomes the new left operand.
        val newLeft = mirror(rhs)
        val newRight = mirror(lhs)
        Add(newLeft, newRight)
      case Minus(lhs, rhs) =>
        Minus(mirror(lhs), mirror(rhs))
      case Var(_) => e
      case Number(_) => e
  }.ensuring(res => evaluate(res) == evaluate(e))

  /** A rewriting step on expressions whose contract demands that evaluation is preserved.
    *
    * Concrete simplifiers in this file are created as anonymous subclasses overriding `apply`.
    */
  abstract class SoundSimplifier:
    /** Rewrites `e`.
      *
      * The body here is only a `???` placeholder; what matters is the postcondition, which
      * states that the result evaluates to the same value as `e` under `anyCtx`.
      * NOTE(review): overrides are presumably checked by Stainless against this
      * postcondition — confirm against the Stainless documentation.
      *
      * @param e      expression to rewrite
      * @param anyCtx environment used to compare the values of input and result
      */
    def apply(e: Expr, anyCtx: Env): Expr = {
      (??? : Expr)
    }.ensuring(evaluate(_)(using anyCtx) == evaluate(e)(using anyCtx))
  
  /** Simplifier whose `apply` returns `mirror` of the expression it receives. */
  val mirSimp: SoundSimplifier = new SoundSimplifier:
    override def apply(e: Expr, anyCtx: Env): Expr =
      mirror(e)(using anyCtx)
  
  /** Folds a single top-level operation on two literals into one literal.
    *
    * `Add(Number(a), Number(b))` becomes `Number(a + b)` and `Minus(Number(a), Number(b))`
    * becomes `Number(a - b)`; every other expression is returned unchanged. Sub-expressions
    * are not visited. The postcondition states that the result evaluates to the same value
    * as `e` under `anyCtx`.
    *
    * @param e      expression whose root node may be folded
    * @param anyCtx environment used by the postcondition
    * @return the folded literal, or `e` itself when the root is not foldable
    */
  def constfold1(e: Expr)(using anyCtx: Env): Expr = {
    e match
      case Add(Number(n1), Number(n2))   => Number(n1 + n2)
      case Minus(Number(n1), Number(n2)) => Number(n1 - n2)
      // Wildcard instead of a fresh binder, so the parameter `e` is not shadowed.
      case _                             => e
  }.ensuring(res => evaluate(res) == evaluate(e))

  /** Simplifier whose `apply` performs one `constfold1` step on the expression it receives. */
  val constFold1Simp: SoundSimplifier = new SoundSimplifier:
    override def apply(e: Expr, anyCtx: Env): Expr =
      constfold1(e)(using anyCtx)

  /** Rebuilds `e` bottom-up, passing every node through `f` after its children were processed.
    *
    * Leaves are handed to `f` as they are. For `Add` and `Minus` both operands are mapped
    * first, the node is reassembled from the mapped operands, and the reassembled node is
    * then given to `f`. The postcondition states that the result evaluates to the same
    * value as `e` under `anyCtx`.
    *
    * @param e      expression to traverse
    * @param f      simplifier applied at every node
    * @param anyCtx environment forwarded to `f` and used by the postcondition
    */
  def mapExpr(e: Expr, f: SoundSimplifier)(using anyCtx: Env): Expr = {
    e match
      case Add(lhs, rhs) =>
        f.apply(Add(mapExpr(lhs, f), mapExpr(rhs, f)), anyCtx)
      case Minus(lhs, rhs) =>
        f.apply(Minus(mapExpr(lhs, f), mapExpr(rhs, f)), anyCtx)
      case Var(_) =>
        f.apply(e, anyCtx)
      case Number(_) =>
        f.apply(e, anyCtx)
  }.ensuring(res => evaluate(res) == evaluate(e))

  /** Constant-folds `e` throughout by running `constFold1Simp` at every node via `mapExpr`.
    *
    * The postcondition states that the result evaluates to the same value as `e` under `anyCtx`.
    */
  def constfold(e: Expr)(using anyCtx: Env): Expr = {
    val folded = mapExpr(e, constFold1Simp)
    folded
  }.ensuring(res => evaluate(res) == evaluate(e))

end ConstFold
