Wednesday, October 21, 2015

Pitfalls in writing an equality method

Pitfall 1: wrong signature

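== is final, so it cannot be overridden; to customize equality you override equals, which == calls.

Note the type of its parameter, that: it is Any.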

// `==` is final and takes `Any`. A null receiver is equal only to null;
// in every other case the call is forwarded to `equals`.
final def == (that: Any): Boolean =
if (null eq this){null eq that} else {this equals that}

The parameter other of equals must have type Any, exactly as in ==; with any narrower type the method overloads equals instead of overriding it.

Wrong version:

object Example {
  /** A 2D point whose `equals` is deliberately declared with the wrong signature. */
  class Point(val x: Int, val y: Int) {
    // Pitfall: the parameter type is `Point`, not `Any`, so this is an OVERLOAD
    // of `equals` rather than an override. `==` and the collections keep calling
    // the inherited `equals(Any)`, which compares references.
    def equals(other: Point): Boolean =
      x == other.x && y == other.y
  }
}

// Shows the effect of declaring equals(other: Point) instead of equals(other: Any).
object Test {
  import collection.mutable.HashSet
  import Example._
  val p1 = new Point(2, 3)
  val p2 = new Point(2, 3)
  println(p1.equals(p2)) // true: resolves statically to the overload equals(Point)
  println(p1 == p2) // false: == calls equals(Any), which Point does not override
  // A hash set also goes through equals(Any)/hashCode, so it would not find p2 either.
  val coll = HashSet(p1)
}

Better version 1:

object Example {

  /**
   * A 2D point whose `equals` correctly overrides `equals(Any)`.
   * `hashCode` is still inherited, which is the next pitfall.
   */
  class Point(val x: Int, val y: Int) {
    // Equal only to another Point with the same coordinates.
    override def equals(other: Any): Boolean = other match {
      case p: Point => x == p.x && y == p.y
      case _        => false
    }
  }

}

// Exercises the Point that overrides equals(Any) but not hashCode.
object Test {
  import collection.mutable.HashSet
  import Example._
  val p1 = new Point(2, 3)
  val p2 = new Point(2, 3)
  println(p1.equals(p2)) // true
  println(p1 == p2) // true
  // p2a is statically typed as Any, but dynamic dispatch still reaches Point.equals
  val p2a: Any = p2
  p2a == p2 // true
  val coll = HashSet(p1)
  // hashCode is not overridden, so p1 and p2 (almost certainly) hash differently
  // and the set never gets as far as comparing them with equals.
  coll.contains(p2) // false
}

Pitfall 2: changing equals without also changing hashCode

Better version 2 (hashCode and equals should be overridden together, always):


object Example {
  /** A 2D point that overrides `equals` and `hashCode` together, over the same fields. */
  class Point(val x: Int, val y: Int) {
    // Derived from exactly the fields that equals compares, so equal points hash alike.
    override def hashCode = 41 * (x + 41) + 17 * (y + 17)

    // Equal only to another Point with the same coordinates.
    override def equals(other: Any): Boolean = other match {
      case p: Point => x == p.x && y == p.y
      case _        => false
    }
  }
}

// Exercises the Point that overrides both equals(Any) and hashCode.
object Test {
  import collection.mutable.HashSet
  import Example._
  val p1 = new Point(2, 3)
  val p2 = new Point(2, 3)
  println(p1.equals(p2)) // true
  println(p1 == p2) // true
  // p2a is statically typed as Any, but dynamic dispatch still reaches Point.equals
  val p2a: Any = p2
  p2a == p2 // true
  val coll = HashSet(p1)
  // p1 and p2 now have the same hashCode, so the lookup finds p1 and equals confirms it.
  coll.contains(p2) // true
}

Pitfall 3: defining equals in terms of mutable fields

object Example {
  /** A 2D point with MUTABLE coordinates, used to illustrate this pitfall. */
  class Point(var x: Int, var y: Int) {
    // Pitfall: hashCode and equals both read the vars x and y, so mutating a
    // point changes its hash code after it may already be stored in a hashed
    // collection.
    override def hashCode = 41 * (x + 41) + 17 * (y + 17)

    override def equals(other: Any): Boolean = other match {
      case p: Point => x == p.x && y == p.y
      case _        => false
    }
  }
}

Pitfall 4: failing to define equals as an equivalence relation

Equivalence relation (5 rules):

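  1. Reflexive: for any non-null x, x.equals(x) is true.
  2. Symmetric: for any non-null x and y, x.equals(y) is true if and only if y.equals(x) is true.
  3. Transitive: for any non-null x, y and z, if x.equals(y) and y.equals(z) are true, then x.equals(z) is true.
  4. Consistent: repeated calls of x.equals(y) keep returning the same result as long as nothing used in the comparison is modified. Once equal, always equal.
  5. Null: for any non-null x, x.equals(null) is false (null == null is true and is handled by == itself).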

An example that violates symmetry


object Example {
  /** A 2D point compared by its coordinates. */
  class Point(val x: Int, val y: Int) {
    override def hashCode = 41 * (x + 41) + 17 * (y + 17)

    // Accepts any Point, including subclasses, as long as the coordinates match.
    override def equals(other: Any): Boolean = other match {
      case p: Point => x == p.x && y == p.y
      case _        => false
    }
  }

  object Color extends Enumeration {
    val Red, Orange, Yellow, Green = Value
  }

  /** A Point with a color; its equals is stricter than Point.equals. */
  class ColoredPoint(x: Int, y: Int, val color: Color.Value)
      extends Point(x, y) {
    // hashCode is inherited unchanged: two ColoredPoints that are equal here
    // have equal coordinates and therefore equal hash codes.
    //
    // Pitfall: this equals is NOT symmetric with Point.equals. A Point accepts
    // a ColoredPoint with the same coordinates, but a ColoredPoint rejects
    // every plain Point.
    override def equals(other: Any) = other match {
      case cp: ColoredPoint => color == cp.color && super.equals(cp)
      case _                => false
    }
  }

}

// Shows that equality between Point and ColoredPoint depends on the receiver.
object Test {
  import collection.mutable.HashSet
  import Example._
  val p = new Point(1, 2)
  val cp = new ColoredPoint(1, 2, Color.Red)
  p == cp  // true: Point.equals only compares x and y
  cp == p  // false: ColoredPoint.equals rejects anything that is not a ColoredPoint
  // p and cp share a hashCode, so the outcome below hinges on which operand the
  // set calls equals on. NOTE(review): that is an implementation detail of the
  // collection, so confirm these results on the Scala version in use.
  HashSet[Point](p) contains cp // true
  HashSet[Point](cp) contains p // false
}

An example that violates transitivity



object Example {
  /** A 2D point compared by its coordinates. */
  class Point(val x: Int, val y: Int) {
    override def hashCode = 41 * (x + 41) + 17 * (y + 17)

    // Accepts any Point, including subclasses, as long as the coordinates match.
    override def equals(other: Any): Boolean = other match {
      case p: Point => x == p.x && y == p.y
      case _        => false
    }
  }

  object Color extends Enumeration {
    val Red, Orange, Yellow, Green = Value
  }

  /** A Point with a color; equals is made symmetric at the cost of transitivity. */
  class ColoredPoint(x: Int, y: Int, val color: Color.Value)
      extends Point(x, y) {
    // hashCode is inherited unchanged: whenever equals returns true the
    // coordinates match, so the hash codes match as well.
    override def equals(other: Any) = other match {
      case cp: ColoredPoint => color == cp.color && super.equals(cp)
      // This case is added to ensure symmetry: a plain Point decides by
      // coordinates alone. It is also what breaks transitivity, because two
      // differently colored points both equal the same plain Point.
      case p: Point         => p.equals(this)
      case _                => false
    }
  }

}

// Shows that the symmetric ColoredPoint.equals is not transitive.
object Test {
  import collection.mutable.HashSet
  import Example._
  val p1 = new ColoredPoint(1, 2, Color.Red)
  val p2 = new Point(1, 2)
  val p3 = new ColoredPoint(1, 2, Color.Green)
  p1 == p2 // true: p2 is a plain Point, so only the coordinates are compared
  p2 == p3  // true: Point.equals ignores color
  p1 == p3 // false: both are ColoredPoints and their colors differ
}

An example that is too restrictive

object Example {
  /** A 2D point that is equal only to points of exactly the same runtime class. */
  class Point(val x: Int, val y: Int) {
    override def hashCode = 41 * (x + 41) + 17 * (y + 17)

    // The getClass comparison restores symmetry and transitivity, but it also
    // rejects harmless subclasses such as anonymous ones.
    override def equals(other: Any): Boolean = other match {
      case p: Point =>
        x == p.x && y == p.y && getClass == p.getClass
      case _ =>
        false
    }
  }

  object Color extends Enumeration {
    val Red, Orange, Yellow, Green = Value
  }

  /** A Point with a color. */
  class ColoredPoint(x: Int, y: Int, val color: Color.Value)
      extends Point(x, y) {
    // hashCode is inherited unchanged: equal ColoredPoints have equal
    // coordinates and therefore equal hash codes.
    // equals is symmetric here, because Point.equals already refuses objects
    // of a different runtime class.
    override def equals(other: Any) = other match {
      case cp: ColoredPoint => color == cp.color && super.equals(cp)
      case _                => false
    }
  }

}

// Shows that comparing runtime classes is symmetric but too restrictive.
object Test {
  import collection.mutable.HashSet
  import Example._
  // `val p, p1 = ...` evaluates the right-hand side once per name: two distinct instances.
  val p, p1 = new Point(1, 2)
  // symmetry holds
  p == p1
  p1 == p
  val cp = new ColoredPoint(1, 2, Color.Red)
  p == cp  // false
  cp == p  // false
  HashSet[Point](p) contains cp // false
  HashSet[Point](cp) contains p // false

  // An anonymous subclass has its own runtime class, so the getClass check in
  // Point.equals rejects it even though it represents the point (1, 3).
  val anonymous = new Point(1, 1) {
    override val y = 3
  }
  anonymous == (new Point(1, 3)) // false
  val anonymous1 = new Point(1, 1) {
      override val y = 3
    }
  // Two separate anonymous class bodies are two different runtime classes.
  anonymous == anonymous1 // false
}

The final (good) version of the Point and ColoredPoint classes


object Example {
  /** A 2D point; subclasses opt in or out of equality with it through `canEqual`. */
  class Point(val x: Int, val y: Int) {
    override def hashCode = 41 * (x + 41) + y

    // Asking the OTHER object whether it can equal this one is what keeps the
    // relation symmetric across subclasses.
    override def equals(other: Any): Boolean = other match {
      case p: Point => p.canEqual(this) && x == p.x && y == p.y
      case _        => false
    }

    // True for every Point, including subclasses that do not override canEqual.
    def canEqual(other: Any) = other match {
      case _: Point => true
      case _        => false
    }
  }

  object Color extends Enumeration {
    val Red, Orange, Yellow, Green = Value
  }

  /** A Point with a color; it can only equal other ColoredPoints. */
  class ColoredPoint(x: Int, y: Int, val color: Color.Value)
      extends Point(x, y) {
    // hashCode is inherited unchanged: equal ColoredPoints have equal
    // coordinates and therefore equal hash codes.
    // With canEqual in place, equals is symmetric and transitive.
    override def equals(other: Any) = other match {
      case cp: ColoredPoint =>
        cp.canEqual(this) && color == cp.color && super.equals(cp)
      case _ =>
        false
    }

    // Rejects plain Points, so Point.equals(coloredPoint) returns false as well.
    override def canEqual(that: Any) = that match {
      case _: ColoredPoint => true
      case _               => false
    }
  }

}

// Exercises the canEqual-based Point and ColoredPoint.
object Test {
  import collection.mutable.HashSet
  import Example._
  val p1 = new Point(1, 2)
  // An anonymous subclass that does not override canEqual can still equal a plain Point.
  val p2 = new Point(1, 1) {
    override val y = 2
  }
  p1 == p2 // true
  val p3 = new ColoredPoint(1, 2, Color.Green) {
    override val color = Color.Red
  }
  val p4 = new ColoredPoint(1, 2, Color.Red)
  p3 == p4 // true

  // ColoredPoint.canEqual rejects the plain Point p1.
  p1 == p3 // false

  // Set here is the default immutable Set; the HashSet import above is unused.
  val s = Set(p1, p2)
  s.size // 1
  val s1 = Set(p3, p4)
  s1.size // 1
}

Writing equals method for a parameterized class


object Example {

  /** A minimal immutable binary tree, covariant in its element type. */
  trait Tree[+T] {
    def elem: T

    def left: Tree[T]

    def right: Tree[T]
  }

  /** The empty tree; every accessor throws NoSuchElementException. */
  object EmptyTree extends Tree[Nothing] {
    def elem = throw new NoSuchElementException("EmptyTree.elem")

    def left = throw new NoSuchElementException("EmptyTree.left")

    def right = throw new NoSuchElementException("EmptyTree.right")
  }

  /**
   * A non-empty tree node.
   *
   * Equality is structural. The element type parameter is erased at runtime
   * and therefore plays no part in it, which is why `equals` matches on
   * `Branch[_]`.
   */
  class Branch[+T](val elem: T, val left: Tree[T], val right: Tree[T]) extends Tree[T] {
    /**
     * Whether `branch` may be considered equal to this node.
     *
     * The parameter is `Any` so that the check is meaningful for arbitrary
     * objects; with a `Branch[_]` parameter the instance test could never fail
     * for a non-null argument.
     */
    def canEqual(branch: Any): Boolean = branch match {
      case _: Branch[_] => true
      case _            => false
    }

    override def equals(other: Any): Boolean = other match {
      case that: Branch[_] =>
        that.canEqual(this) &&
          this.elem == that.elem &&
          this.left == that.left &&
          this.right == that.right
      case _ => false
    }

    // `##` instead of `hashCode()`: it agrees with `==` on boxed numbers
    // (1 == 1.0 is true, so both must hash alike) and returns 0 for null
    // instead of throwing a NullPointerException.
    override def hashCode: Int =
      41 * (
        41 * (
          41 + elem.##
          ) + left.##
        ) + right.##
  }

}

// Shows that Branch equality ignores the (erased) element type parameter.
object Test {

  import collection.mutable.HashSet
  import Example._

  val b1 = new Branch[List[String]](Nil, EmptyTree, EmptyTree)
  val b2 = new Branch[List[Int]](Nil, EmptyTree, EmptyTree)
  // Both branches hold Nil and the same EmptyTree children; Branch.equals matches
  // on Branch[_], so the differing type arguments are never compared.
  b1 == b2 // true
}

0 comments: