September 27, 2012

Quicksort rewritten in tail-recursive form

An example using the Scala programming language


I am currently learning Scala and I became quite intrigued by the compiler's ability to optimize tail-recursive calls.

For those unfamiliar with the concept, a tail-recursive method is one whose recursive call is the very last action it performs, and such a method can be optimized to run in constant stack space. In practice this means that the recursion is not really necessary: the method in question could be implemented as a simple iterative process (a loop)! Nevertheless, it is nice to have the compiler do the trick of converting the recursion into a loop and let programmers express their ideas using recursion.
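Here is a small sketch of my own to make this concrete. The names TailRecDemo, sumTo and sumToLoop are made up for illustration. sumTo keeps its running total in an accumulator, so the recursive call is the last thing it does. sumToLoop shows, roughly, the loop that the optimization amounts to.

```scala
import scala.annotation.tailrec

object TailRecDemo {
  // Tail-recursive: the recursive call is the last thing the method does,
  // so scalac can compile it into a jump back to the top of the method.
  // Assumes n >= 0.
  @tailrec
  def sumTo(n: Int, acc: Long = 0L): Long =
    if (n == 0) acc
    else sumTo(n - 1, acc + n)

  // Roughly what the optimization amounts to: the parameters become
  // variables updated in place, using constant stack space. Assumes n0 >= 0.
  def sumToLoop(n0: Int): Long = {
    var n = n0
    var acc = 0L
    while (n != 0) {
      acc += n
      n -= 1
    }
    acc
  }
}
```

Because of the optimization, sumTo(1000000) runs in constant stack space and returns 500000500000. If the body were written as n + sumTo(n - 1), an addition would still be pending after the call. The @tailrec annotation would then make the compiler reject the method.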

As an exercise, I set out to convert the renowned Quicksort algorithm into tail-recursive form. But first, why would this algorithm need any conversion at all?

Let us see an implementation of the algorithm written in Scala. I am using the following piece of code as it appears in “Scala By Example” by Martin Odersky (draft of May 24, 2011). It is a straightforward implementation of Quicksort for arrays of integers:

  def quickSort_scalaByExample(xs: Array[Int]) {
    def swap(i: Int, j: Int) {
      val t = xs(i); xs(i) = xs(j); xs(j) = t
    }
    def sort1(l: Int, r: Int) {
      val pivot = xs((l + r) / 2)
      var i = l
      var j = r
      while (i <= j) {
        while (xs(i) < pivot) i += 1
        while (xs(j) > pivot) j -= 1
        if (i <= j) {
          swap(i, j)
          i += 1
          j -= 1
        }
      }
      if (l < j) sort1(l, j)
      if (i < r) sort1(i, r)
    }
    sort1(0, xs.length - 1)
  }

The recursive method sort1() may call itself up to two times at the end of each execution, and this is why the method is not tail-recursive. When the first call, sort1(l, j), returns, there is still work left to do (the second call), so that first call is not in tail position. Indeed, if we add the @tailrec annotation (from scala.annotation.tailrec) to the definition of sort1(), the compiler rejects it with this error:

  could not optimize @tailrec annotated method sort1: it contains a recursive call not in tail position

The message refers to the line if (l < j) sort1(l, j).

To enable the optimization, we need to make sure that the recursive call may happen at most once at the end of the execution of sort1().

First, let's rewrite the algorithm so that the “sorting” part of it is factored out of the recursion. This is the part where elements are compared against the pivot and swapped inside the array, so that the segment gets partitioned. Afterwards, the elements from position l to j are all less than or equal to the pivot, and the elements from position i to r are all greater than or equal to it (for an explanation of Quicksort, you may read the Wikipedia article on it). We introduce a new method sortRange(), which takes a lower and an upper bound and performs this “sorting work” on that segment of the array:

  def quickSort_intermediate(xs: Array[Int]): Unit = {
    def swap(i: Int, j: Int): Unit = {
      val t = xs(i); xs(i) = xs(j); xs(j) = t
    }
    def sortRange(l: Int, r: Int): (Int, Int) = {
      val pivot = xs((l + r) / 2)
      var i = l
      var j = r
      while (i <= j) {
        while (xs(i) < pivot) i += 1
        while (xs(j) > pivot) j -= 1
        if (i <= j) {
          swap(i, j)
          i += 1
          j -= 1
        }
      }
      (i, j) // the last expression is the method's result
    }
    def sort1(l: Int, r: Int): Unit = {
      val i_j = sortRange(l, r)
      val i = i_j._1
      val j = i_j._2
      if (l < j) sort1(l, j)
      if (i < r) sort1(i, r)
    }
    sort1(0, xs.length - 1)
  }

This new method does the sorting and then returns the values of “i” and “j” as a tuple, because these are needed to continue the process. The sort1() method still takes the same arguments as before and is in fact, again, not at all tail-recursive! Perhaps annoyingly, it does not even make use of Scala's idiomatic syntax for tuples, so we rewrite it to this more natural Scala form:

    def sort1(l: Int, r: Int): Unit = {
      sortRange(l, r) match {
        case (i, j) => {
          if (l < j) sort1(l, j)
          if (i < r) sort1(i, r)
        }
      }
    }

So, up to this point, we have broken down the initial implementation into a method “sortRange()” that can “do the sorting work” on segments of the array, and a recursive method “sort1()” that calls itself up to two times in every execution, using as arguments new segments of the array.
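It may help to spell out what sortRange() actually guarantees, since the two recursive calls depend on it. The sketch below restates sortRange() as a standalone function that takes the array as a parameter. It adds a hypothetical checker, partitionHolds(), which verifies on a copy of the input the properties the callers rely on. Both names and the object PartitionCheck are my own, for illustration.

```scala
object PartitionCheck {
  // The partitioning loop of sortRange(), with the array passed in explicitly
  // so it can be exercised on its own. Assumes 0 <= l <= r < xs.length.
  def sortRange(xs: Array[Int], l: Int, r: Int): (Int, Int) = {
    val pivot = xs((l + r) / 2)
    var i = l
    var j = r
    while (i <= j) {
      while (xs(i) < pivot) i += 1
      while (xs(j) > pivot) j -= 1
      if (i <= j) {
        val t = xs(i); xs(i) = xs(j); xs(j) = t
        i += 1
        j -= 1
      }
    }
    (i, j)
  }

  // Checks, on a copy of the input, what sort1() relies on after
  // (i, j) = sortRange(l, r): the indices have crossed, both sub-segments
  // are strictly shorter than l..r, xs(l..j) <= pivot, xs(i..r) >= pivot,
  // and anything left between j and i equals the pivot.
  def partitionHolds(input: Array[Int], l: Int, r: Int): Boolean = {
    val xs = input.clone()
    val pivot = xs((l + r) / 2)
    val (i, j) = sortRange(xs, l, r)
    j < i && l < i && j < r &&
      (l to j).forall(k => xs(k) <= pivot) &&
      (i to r).forall(k => xs(k) >= pivot) &&
      (j + 1 until i).forall(k => xs(k) == pivot)
  }
}
```

This partitioning scheme does not always leave the pivot in its final sorted position. For example, sortRange(Array(1, 0, 3, 5, 2), 0, 4) returns (3, 2) and leaves the array as 1, 0, 2, 5, 3, with the pivot 3 in the last slot. What matters is that both sub-segments are strictly shorter than the original one, which is why the recursion terminates.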

The recursive method sort1() takes as arguments the lower and upper bounds of the array segment it is going to work on. We are going to replace this method with one whose only argument is “a list of array segments” that remain to be sorted. More precisely, the argument is a list of tuples containing the lower and upper bounds of those segments:

    def sort2(segments: List[(Int, Int)]): Unit = {
      segments.head match {
        case (l, r) => {
          var newSegments = segments.tail
          sortRange(l, r) match {
            case (i, j) => {
              if (l < j) newSegments = (l, j) :: newSegments
              if (i < r) newSegments = (i, r) :: newSegments
            }
          }
          if (!newSegments.isEmpty) sort2(newSegments)
        }
      }
    }

What does the new method do?
  • It takes a non-empty list of “segments of the array” (i.e. the lower and upper bounds of each segment) and calls sortRange() on the head of the list.
  • Then, according to the results of sortRange(), it may prepend one or two new segments to the tail of the original list, that is, to the segments that were already waiting to be sorted.
  • Finally, if the resulting list is not empty, it calls itself with that list as its argument; this recursive call is the very last thing it does.
This version of the code is tail-recursive, so the compiler accepts the @tailrec annotation (imported from scala.annotation.tailrec) without complaint. Here's the complete implementation:

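  import scala.annotation.tailrec

  def quickSort_tailRecursive(xs: Array[Int]): Unit = {
    def swap(i: Int, j: Int): Unit = {
      val t = xs(i); xs(i) = xs(j); xs(j) = t
    }
    // Partitions xs(l..r) around its middle element and returns (i, j),
    // where j < i, xs(l..j) <= pivot and xs(i..r) >= pivot.
    def sortRange(l: Int, r: Int): (Int, Int) = {
      val pivot = xs((l + r) / 2)
      var i = l
      var j = r
      while (i <= j) {
        while (xs(i) < pivot) i += 1
        while (xs(j) > pivot) j -= 1
        if (i <= j) {
          swap(i, j)
          i += 1
          j -= 1
        }
      }
      (i, j)
    }
    // segments: non-empty list of (lower, upper) bounds still to be sorted.
    // The recursive call is the last action, so scalac compiles it into a loop.
    @tailrec def sort2(segments: List[(Int, Int)]): Unit = {
      segments.head match {
        case (l, r) => {
          var newSegments = segments.tail
          sortRange(l, r) match {
            case (i, j) => {
              if (l < j) newSegments = (l, j) :: newSegments
              if (i < r) newSegments = (i, r) :: newSegments
            }
          }
          if (!newSegments.isEmpty) sort2(newSegments)
        }
      }
    }
    // sortRange would read xs(0) on an empty array, so only start sorting
    // when there are at least two elements (shorter arrays are already sorted).
    if (xs.length > 1) sort2(List((0, xs.length - 1)))
  }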
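To watch the list of segments at work, here is an instrumented copy of the complete implementation. Sort2Trace and trace are hypothetical names, and the log of pending lists is an addition for illustration only. The sorting logic is the same as above, written with val (l, r) = ... instead of match.

```scala
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

object Sort2Trace {
  // Same algorithm as quickSort_tailRecursive, except that sort2 records the
  // list of pending segments it receives on every call. The recorded lists
  // are returned so the evolution of the work list can be inspected.
  def trace(xs: Array[Int]): List[List[(Int, Int)]] = {
    val log = ListBuffer.empty[List[(Int, Int)]]
    def swap(i: Int, j: Int): Unit = {
      val t = xs(i); xs(i) = xs(j); xs(j) = t
    }
    def sortRange(l: Int, r: Int): (Int, Int) = {
      val pivot = xs((l + r) / 2)
      var i = l
      var j = r
      while (i <= j) {
        while (xs(i) < pivot) i += 1
        while (xs(j) > pivot) j -= 1
        if (i <= j) {
          swap(i, j)
          i += 1
          j -= 1
        }
      }
      (i, j)
    }
    @tailrec def sort2(segments: List[(Int, Int)]): Unit = {
      log += segments // record the work list as received by this call
      val (l, r) = segments.head
      var newSegments = segments.tail
      val (i, j) = sortRange(l, r)
      if (l < j) newSegments = (l, j) :: newSegments
      if (i < r) newSegments = (i, r) :: newSegments
      if (newSegments.nonEmpty) sort2(newSegments)
    }
    if (xs.length > 1) sort2(List((0, xs.length - 1)))
    log.toList
  }
}
```

For example, tracing Array(4, 3, 1, 2) records three calls:

1. List((0,3)).
2. List((2,3), (0,1)), after the first partition pushed both halves.
3. List((0,1)).

Afterwards the array holds 1, 2, 3, 4.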

Conclusion


It was not so hard to convert the classic implementation of Quicksort into a form that is eligible for tail-call optimization in Scala. We could even go ahead and turn this form into an explicit loop ourselves, but why not let the compiler do this step?
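For comparison, this is roughly what the explicit loop would look like if we wrote it by hand. QuickSortLoop and quickSortIterative are made-up names. This is a sketch of the idea rather than the exact code scalac generates.

```scala
object QuickSortLoop {
  // A hand-written loop doing what the @tailrec version of sort2 amounts to:
  // the list of pending (lower, upper) segments lives in a var, and the
  // "recursive call" becomes another trip around the while loop.
  def quickSortIterative(xs: Array[Int]): Unit = {
    def swap(i: Int, j: Int): Unit = {
      val t = xs(i); xs(i) = xs(j); xs(j) = t
    }
    def sortRange(l: Int, r: Int): (Int, Int) = {
      val pivot = xs((l + r) / 2)
      var i = l
      var j = r
      while (i <= j) {
        while (xs(i) < pivot) i += 1
        while (xs(j) > pivot) j -= 1
        if (i <= j) {
          swap(i, j)
          i += 1
          j -= 1
        }
      }
      (i, j)
    }
    // Arrays with fewer than two elements are already sorted.
    var segments: List[(Int, Int)] =
      if (xs.length > 1) List((0, xs.length - 1)) else Nil
    while (segments.nonEmpty) {
      val (l, r) = segments.head
      segments = segments.tail
      val (i, j) = sortRange(l, r)
      if (l < j) segments = (l, j) :: segments
      if (i < r) segments = (i, r) :: segments
    }
  }
}
```

The structure maps one to one onto sort2(). Each pass of the while loop corresponds to one call, and reassigning segments corresponds to passing newSegments as the argument.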

The question that arises naturally is: Is this implementation more efficient or better in any way than the classic implementation? I would answer yes AND no.

In the process of replacing the two recursive calls of sort1() with a single one, we introduced a method whose argument is a list of pending segments, while in the original form we had simple arguments of constant size. This list grows as segments are pushed onto it and shrinks as they are sorted, and it lives on the heap. So, in effect, we moved the bookkeeping of pending work from the stack to the heap.
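One possible refinement, which is not part of my implementation above, is the well-known trick of always handling the smaller sub-segment first. The larger one goes into the list and the smaller one on top of it, so the smaller one is processed next. Each segment left waiting then comes from a parent at most half the size of the one below it. The list therefore holds only on the order of log2(n) segments, whereas nothing in sort2() as written keeps the list short. The sketch below uses hypothetical names (QuickSortBounded and quickSortSmallerFirst). It also returns the largest list size it saw, so the bound can be observed.

```scala
import scala.annotation.tailrec

object QuickSortBounded {
  // Hypothetical variant of sort2, not part of the article's code: after each
  // partition the larger sub-segment is placed underneath and the smaller one
  // on top, so the smaller one is processed next. Returns the largest number
  // of pending segments observed, to make the list's growth visible.
  def quickSortSmallerFirst(xs: Array[Int]): Int = {
    def swap(i: Int, j: Int): Unit = {
      val t = xs(i); xs(i) = xs(j); xs(j) = t
    }
    def sortRange(l: Int, r: Int): (Int, Int) = {
      val pivot = xs((l + r) / 2)
      var i = l
      var j = r
      while (i <= j) {
        while (xs(i) < pivot) i += 1
        while (xs(j) > pivot) j -= 1
        if (i <= j) {
          swap(i, j)
          i += 1
          j -= 1
        }
      }
      (i, j)
    }
    @tailrec def loop(segments: List[(Int, Int)], size: Int, maxSize: Int): Int =
      segments match {
        case Nil => maxSize
        case (l, r) :: rest =>
          val (i, j) = sortRange(l, r)
          // Only sub-segments with two or more elements need more work.
          val left = if (l < j) List((l, j)) else Nil
          val right = if (i < r) List((i, r)) else Nil
          // Compare lengths (j - l + 1 vs r - i + 1); the smaller goes on top.
          val (small, large) = if (j - l <= r - i) (left, right) else (right, left)
          val newSize = size - 1 + small.length + large.length
          loop(small ::: large ::: rest, newSize, math.max(maxSize, newSize))
      }
    if (xs.length > 1) loop(List((0, xs.length - 1)), 1, 1) else 0
  }
}
```

The method is still tail-recursive; only the order in which the two sub-segments enter the list has changed.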

Arguably this is a good thing. As far as I know, most platforms (including the Java virtual machine with its default settings) provide much less stack space than heap space, which is sometimes virtually unlimited. With this in mind, I believe that, under normal circumstances, the tail-recursive implementation can operate on larger arrays than the classic method and is less prone to a stack overflow error.
