Saturday, October 17, 2015

Matching objects with patterns in Scala

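Some code I wrote while reading the paper "Matching Objects with Patterns" (Emir, Odersky and Williams, ECOOP 2007), which compares several ways of decomposing objects in Scala, using expression simplification (x * 1 => x) as the running example.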

Expression simplification using object-oriented decomposition

This approach is not extensible: every new subclass of Expr forces a change to the Expr trait itself, which must gain a new isXxx test method plus default accessors for the new class's fields.

/**
 * Created by IDEA on 27/08/15.
 */

/**
 * Expression tree encoded by object-oriented decomposition: the base trait declares one
 * `isXxx` test per node kind plus the accessors of every node kind, and each concrete
 * node overrides only the members that apply to it.
 */
object Dummy {
  /**
   * Base of all expression nodes.
   *
   * The accessors `value`, `name`, `left` and `right` throw `NoSuchMethodError` unless the
   * concrete node overrides them, so callers must guard each accessor with the matching
   * `isXxx` test (e.g. `e.isNum && e.value == 1`).
   * NOTE(review): `NoSuchMethodError` is a `LinkageError`, which `scala.util.control.NonFatal`
   * (and therefore `Try`) treats as fatal; kept as-is because callers may already catch it.
   */
  trait Expr {
    def isVar: Boolean = false
    def isNum: Boolean = false
    def isMul: Boolean = false
    def value: Int = throw new NoSuchMethodError()
    def name: String = throw new NoSuchMethodError()
    def left: Expr = throw new NoSuchMethodError()
    def right: Expr = throw new NoSuchMethodError()
    // Deliberately no toString override here: a throwing toString makes any node that
    // forgets to override it crash println, string interpolation and debugger views.
  }

  /** Integer literal; prints as its value. */
  class Num(override val value: Int) extends Expr {
    override def isNum: Boolean = true
    override def toString: String = value.toString
  }

  /** Named variable; prints as its name. */
  class Var(override val name: String) extends Expr {
    override def isVar: Boolean = true
    override def toString: String = name
  }

  /** Product of two sub-expressions; prints as "(left) times (right)". */
  class Mul(override val left: Expr, override val right: Expr) extends Expr {
    override def isMul: Boolean = true
    override def toString: String = "(%s) times (%s)".format(left.toString, right.toString)
  }
}

/**
 * Demo driver for the object-oriented decomposition encoding.
 *
 * NOTE(review): `main` is defined on a `class`, so the JVM does not treat it as a program
 * entry point; run it as `new Main().main(args)` or turn `Main` into an `object`.
 */
class Main {
  def main(args: Array[String]): Unit = {
    import Dummy._

    // True only for a Num node holding 1. The isNum guard must come first because
    // `value` throws NoSuchMethodError on every node that is not a Num.
    def isOne(x: Expr): Boolean = x.isNum && x.value == 1

    /**
     * Removes a multiplication by 1 on either side (right operand checked first).
     * The operands may be any kind of node; anything that cannot be simplified is
     * returned unchanged.
     */
    def simplify(e: Expr): Expr =
      if (!e.isMul) e
      else if (isOne(e.right)) e.left
      else if (isOne(e.left)) e.right
      else e

    val mul = new Mul(new Num(12), new Num(1))
    println(mul)
    println(simplify(mul))
  }
}

Using the visitor pattern

This approach is quite verbose, and the `otherwise` fallbacks need particular care: what should happen when the right (or left) operand of a Mul instance is not a Num?

/**
 * Created by IDEA on 27/08/15.
 */

/**
 * Expression tree encoded with the visitor pattern: every node hands itself to the
 * visitor method for its own kind through `matchWith`.
 */
object Dummy {
  /**
   * Visitor over expression nodes that produces a `T`.
   *
   * `otherwise` throws `NoSuchMethodError`, and each `caseXxx` falls back to it by default,
   * so a concrete visitor must override every case it wants to handle (or `otherwise`).
   */
  trait Visitor[T] {
    def otherwise(t: Expr): T = throw new NoSuchMethodError()
    def caseNum(t: Num): T = otherwise(t)
    def caseVar(t: Var): T = otherwise(t)
    def caseMul(t: Mul): T = otherwise(t)
  }

  /** Base of all expression nodes: double-dispatches into a [[Visitor]]. */
  trait Expr {
    def matchWith[T](v: Visitor[T]): T
  }

  /** Integer literal; prints as its value. */
  class Num(val value: Int) extends Expr {
    def matchWith[T](v: Visitor[T]): T = v.caseNum(this)
    override def toString: String = s"$value"
  }

  /** Variable with a name and a value; note that it prints as its value, not its name. */
  class Var(val name: String, val value: Int) extends Expr {
    def matchWith[T](v: Visitor[T]): T = v.caseVar(this)
    override def toString: String = s"$value"
  }

  /** Product of two sub-expressions; prints as "(left) times (right)". */
  class Mul(val left: Expr, val right: Expr) extends Expr {
    def matchWith[T](v: Visitor[T]): T = v.caseMul(this)
    override def toString: String = s"(${left.toString}) times (${right.toString})"
  }
}

/**
 * Demo driver for the visitor-pattern encoding.
 *
 * NOTE(review): `main` is defined on a `class`, so the JVM does not treat it as a program
 * entry point; run it as `new Main().main(args)` or turn `Main` into an `object`.
 */
class Main {
  def main(args: Array[String]): Unit = {
    import Dummy._

    // Answers "is this node the literal 1?". Only a Num can be; every other node
    // reaches `otherwise` and yields false, instead of being returned as a result.
    val isOne: Visitor[Boolean] = new Visitor[Boolean] {
      override def caseNum(n: Num): Boolean = n.value == 1
      override def otherwise(other: Expr): Boolean = false
    }

    /**
     * Removes a multiplication by 1 on either side (right operand checked first).
     * The operands may be any kind of node. Non-Mul expressions, and Mul nodes with no
     * literal-1 operand, are returned unchanged (the whole node, never just an operand).
     */
    def simplify(e: Expr): Expr =
      e.matchWith {
        new Visitor[Expr] {
          override def caseMul(m: Mul): Expr =
            if (m.right.matchWith(isOne)) m.left
            else if (m.left.matchWith(isOne)) m.right
            else m
          // not a Mul: return it unchanged
          override def otherwise(other: Expr): Expr = other
        }
      }

    val mul = new Mul(new Num(12), new Num(1))
    println(mul)
    println(simplify(mul))
  }
}

Using type-test/type-cast

This is the most direct form of object matching and is very easy to understand, although every type test has to be followed by an explicit cast:

/**
 * Created by IDEA on 27/08/15.
 */

/**
 * Expression tree for the type-test/type-cast encoding: the nodes carry only data and
 * offer no matching support, so clients inspect them with isInstanceOf/asInstanceOf.
 */
object Dummy {
  /** Marker trait for expression nodes. */
  trait Expr

  /** Integer literal; prints as its value. */
  class Num(val value: Int) extends Expr {
    override def toString: String = s"$value"
  }

  /** Variable with a name and a value; note that it prints as its value, not its name. */
  class Var(val name: String, val value: Int) extends Expr {
    override def toString: String = s"$value"
  }

  /** Product of two sub-expressions; prints as "(left) times (right)". */
  class Mul(val left: Expr, val right: Expr) extends Expr {
    override def toString: String = s"(${left.toString}) times (${right.toString})"
  }
}

/**
 * Demo driver for the type-test/type-cast encoding.
 *
 * NOTE(review): `main` is defined on a `class`, so the JVM does not treat it as a program
 * entry point; run it as `new Main().main(args)` or turn `Main` into an `object`.
 */
class Main {
  def main(args: Array[String]): Unit = {
    import Dummy._

    /**
     * Removes a multiplication by 1 on either side (left operand checked first), using
     * explicit isInstanceOf tests, each followed by the matching asInstanceOf cast.
     * Any other expression is returned unchanged.
     */
    def simplify(e: Expr): Expr =
      if (!e.isInstanceOf[Mul]) e
      else {
        val mul = e.asInstanceOf[Mul]
        val lhs = mul.left
        val rhs = mul.right
        val lhsIsOne = lhs.isInstanceOf[Num] && lhs.asInstanceOf[Num].value == 1
        if (lhsIsOne) rhs
        else if (rhs.isInstanceOf[Num] && rhs.asInstanceOf[Num].value == 1) lhs
        else mul
      }

    val sample = new Mul(new Num(12), new Num(1))
    println(sample)
    println(simplify(sample))
  }
}

Using typecase

This is similar to type-test/type-cast, but more concise: a typed pattern such as `case m: Mul` performs the type test and the cast in a single step.

/**
 * Created by IDEA on 27/08/15.
 */

/**
 * Expression tree for the typecase encoding: plain data classes that clients take apart
 * with typed patterns (`case n: Num => ...`).
 */
object Dummy {
  /** Marker trait for expression nodes. */
  trait Expr

  /** Integer literal; prints as its value. */
  class Num(val value: Int) extends Expr {
    override def toString: String = value.toString
  }

  /** Variable with a name and a value; note that it prints as its value, not its name. */
  class Var(val name: String, val value: Int) extends Expr {
    override def toString: String = value.toString
  }

  /** Product of two sub-expressions; prints as "(left) times (right)". */
  class Mul(val left: Expr, val right: Expr) extends Expr {
    override def toString: String = "(" + left.toString + ") times (" + right.toString + ")"
  }
}

/**
 * Demo driver for the typecase encoding.
 *
 * NOTE(review): `main` is defined on a `class`, so the JVM does not treat it as a program
 * entry point; run it as `new Main().main(args)` or turn `Main` into an `object`.
 */
class Main {
  def main(args: Array[String]): Unit = {
    import Dummy._

    // Typecase test for the literal 1: only a Num node can qualify.
    def isOne(x: Expr): Boolean = x match {
      case n: Num => n.value == 1
      case _ => false
    }

    /**
     * Removes a multiplication by 1 on either side. The right operand is checked first,
     * then the left one, each independently of the other's kind, so `1 * x` is simplified
     * even when `x` is not a Num. Any other expression is returned unchanged.
     */
    def simplify(e: Expr): Expr = e match {
      case m: Mul if isOne(m.right) => m.left
      case m: Mul if isOne(m.left) => m.right
      case _ => e
    }

    val mul1 = new Mul(new Num(12), new Num(1))
    println(mul1)
    println(simplify(mul1))
    val mul2 = new Mul(new Num(1), new Num(12))
    println(mul2)
    println(simplify(mul2))
  }
}

Using case classes

Case classes allow pattern matching on constructor arguments, so nested patterns such as `Mul(x, Num(1))` can be written directly:

/**
 * Created by IDEA on 27/08/15.
 */

/**
 * Expression tree built from case classes: the compiler supplies apply/unapply, equals
 * and hashCode, so nodes can be matched with constructor patterns like `Mul(x, Num(1))`.
 */
object Dummy {
  /** Marker trait for expression nodes. */
  trait Expr

  /** Integer literal; prints as its value (overriding the case-class default). */
  case class Num(value: Int) extends Expr {
    override def toString: String = value.toString
  }

  /** Variable with a name and a value; note that it prints as its value, not its name. */
  case class Var(name: String, value: Int) extends Expr {
    override def toString: String = value.toString
  }

  /** Product of two sub-expressions; prints as "(left) times (right)". */
  case class Mul(left: Expr, right: Expr) extends Expr {
    override def toString: String = s"(${left.toString}) times (${right.toString})"
  }
}

/**
 * Demo driver for the case-class encoding.
 *
 * NOTE(review): `main` is defined on a `class`, so the JVM does not treat it as a program
 * entry point; run it as `new Main().main(args)` or turn `Main` into an `object`.
 */
class Main {
  def main(args: Array[String]): Unit = {
    import Dummy._

    /** Drops a multiplication by the literal 1 on either side; anything else is returned as-is. */
    def simplify(e: Expr): Expr = e match {
      case Mul(lhs, Num(1)) => lhs
      case Mul(Num(1), rhs) => rhs
      case unchanged => unchanged
    }

    for (expr <- Seq(Mul(Num(12), Num(1)), Mul(Num(1), Num(12)))) {
      println(expr)
      println(simplify(expr))
    }
  }
}

Using extractors

/**
 * Created by IDEA on 27/08/15.
 */

/**
 * Expression tree built from plain classes plus hand-written extractors: each companion
 * object supplies `apply` (construction without `new`) and `unapply` (constructor-style
 * patterns such as `Mul(x, Num(1))`), and each class defines structural equals/hashCode.
 */
object Dummy {
  /** Marker trait for expression nodes. */
  trait Expr

  /** Integer literal; equal to any other Num holding the same value. */
  class Num(val value: Int) extends Expr {
    override def toString: String = value.toString
    override def equals(other: Any): Boolean = other match {
      case o: Num => value == o.value
      case _ => false
    }
    override def hashCode: Int = 41 * (41 + value)
  }

  object Num {
    def apply(value: Int): Num = new Num(value)
    def unapply(n: Num): Some[Int] = Some(n.value)
  }

  /** Variable with a name and a value; prints as its value. Equality compares both fields. */
  class Var(val name: String, val value: Int) extends Expr {
    override def toString: String = value.toString
    // `==` on the name is null-safe, unlike calling `.equals` on it.
    override def equals(other: Any): Boolean = other match {
      case m: Var => m.name == name && m.value == value
      case _ => false
    }
    // Mix in the name too, so Vars differing only by name do not all share a hash;
    // `##` is null-safe, so a null name still hashes without throwing.
    override def hashCode: Int = 41 * (41 + name.##) + value
  }

  object Var {
    def apply(name: String, value: Int): Var = new Var(name, value)
    // Explicit tuple instead of relying on deprecated argument auto-tupling.
    def unapply(v: Var): Some[(String, Int)] = Some((v.name, v.value))
  }

  /** Product of two sub-expressions; prints as "Mul(left, right)". */
  class Mul(val left: Expr, val right: Expr) extends Expr {
    override def toString: String = "Mul(%s, %s)".format(left.toString, right.toString)
    override def equals(other: Any): Boolean = other match {
      case m: Mul => left == m.left && right == m.right
      case _ => false
    }
    // Order-sensitive combination, so Mul(a, b) and Mul(b, a) do not collide by construction.
    override def hashCode: Int = 41 * (41 + left.hashCode) + right.hashCode
  }

  object Mul {
    def apply(left: Expr, right: Expr): Mul = new Mul(left, right)
    def unapply(m: Mul): Some[(Expr, Expr)] = Some((m.left, m.right))
  }
}


/**
 * Demo driver for the extractor encoding.
 *
 * NOTE(review): `main` is defined on a `class`, so the JVM does not treat it as a program
 * entry point; run it as `new Main().main(args)` or turn `Main` into an `object`.
 */
class Main {
  def main(args: Array[String]): Unit = {
    import Dummy._

    /**
     * Applies one algebraic rule at the root: `x * 1` and `1 * x` become `x`, and a product
     * with a literal 0 on either side becomes `Num(0)`. Other expressions are returned as-is.
     */
    def simplify(e: Expr): Expr = e match {
      case Mul(x, Num(1))                  => x
      case Mul(Num(1), x)                  => x
      case Mul(Num(0), _) | Mul(_, Num(0)) => Num(0)
      case unchanged                       => unchanged
    }

    val samples = List(Mul(Num(12), Num(1)), Mul(Num(1), Num(12)))
    samples.foreach { mul =>
      println(mul)
      println(simplify(mul))
    }
  }
}

0 comments: