Sunday, June 24, 2012

Tail recursion

In my previous post I compared two solutions, one imperative, one functional. You might have thought that this proves that functional solutions are inherently slower than imperative ones. That's not true as such. It is more accurate to say that with heavy use of tail recursion you can be just as fast, but your code won't be any more readable. Here's a pure solution that finishes in about the same time as the imperative one:
import scala.annotation.tailrec
def solve(a:Int, b:Int) = {
  // the outer loop over n, with the running sum as a parameter
  @tailrec def solve(sum:Int, n:Int):Int = {
    // multiplier for the last digit, e.g. 10000 for 12345
    @tailrec def tens(n:Int, t:Int):Int =
      if (n < 10) t else tens(n/10, t*10)
    val t = tens(n, 1)
    // the inner loop over the rotations m of n, until we are back at n
    @tailrec def recycle(sum:Int, m:Int):Int =
      if (n == m) sum
      else recycle(
        if (n < m && m <= b) sum+1 else sum,
        m/10 + m%10*t)
    val newSum = recycle(0, n/10 + n%10*t) + sum
    if (n == b) newSum
    else solve(newSum, n+1)
  }
  solve(0, a)
}

This shows the rather mechanical way in which you can convert any imperative loop into a tail-recursive pure function. Of course, the result is still nothing more than an imperative function, obfuscated to fit the functional rules, giving you the worst of both worlds.

Note: the @tailrec annotations are not necessary; they are only there to make sure that the code I wrote is actually tail-recursive. If it isn't, the compiler reports an error. This is similar to the @Override annotation in Java.
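To make the mechanical conversion concrete, here is a minimal sketch on a deliberately tiny, hypothetical loop: summing the decimal digits of a number (digitSumImp and digitSumRec are illustrative names, not part of the solution above). Every var becomes a parameter of the inner function, the loop exit becomes the base case, and one pass through the loop body becomes one recursive call:

```scala
import scala.annotation.tailrec

// Hypothetical example loop: sum the decimal digits of n0 using two vars.
def digitSumImp(n0: Int): Int = {
  var n = n0
  var acc = 0
  while (n > 0) {
    acc += n % 10
    n /= 10
  }
  acc
}

// The same loop, converted mechanically: every var becomes a parameter of
// loop, the exit condition becomes the base case, and one iteration of the
// while body becomes one tail call with the updated values.
def digitSumRec(n0: Int): Int = {
  @tailrec def loop(n: Int, acc: Int): Int =
    if (n <= 0) acc
    else loop(n / 10, acc + n % 10)
  loop(n0, 0)
}

println(digitSumImp(12345)) // 15
println(digitSumRec(12345)) // 15
```

Thanks to @tailrec, the compiler confirms that loop is rewritten into a plain jump instead of growing the stack, exactly as with the nested functions in solve.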

Google Code Jam in Scala : Recycled Numbers

Recycled Numbers is the first problem in the Code Jam where we have to think about computational complexity. The task is to count the pairs (n, m) with a <= n < m <= b where m can be obtained by moving some digits from the back of n to the front. Any solution that uses for(n<-a to b; m<-a to b) is quadratic in the size of the range, so it will not finish in time for the extreme a=1, b=2000000 case. And believe me, Google does stress-test your code with the extreme case. So first we need a function that recycles 12345 into 51234 by moving the last digit to the front. It's basic math; the only extra thing needed is the multiplier for the last digit, here 10000.
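// multiplier for the last digit: tens(12345) == 10000
def tens(n:Int):Int = if (n < 10) 1 else tens(n/10)*10
// move the last digit to the front: recycle(12345, 10000) == 51234
def recycle(n:Int, t:Int): Int = n/10 + n%10*t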
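A quick worked check of the arithmetic, as a sketch: the last digit is n%10, the rest is n/10, and tens gives the place value the last digit moves to. The trailing-zero case is worth seeing too: the zero lands in front as a leading zero, so the rotation comes out shorter, and therefore smaller than n, which means the n < m check will discard it later.

```scala
// tens and recycle as defined above
def tens(n: Int): Int = if (n < 10) 1 else tens(n / 10) * 10
def recycle(n: Int, t: Int): Int = n / 10 + n % 10 * t

// 12345 has 5 digits, so the last digit moves to the 10000s place
println(tens(12345))           // 10000
// 12345 / 10 = 1234, 12345 % 10 * 10000 = 50000, sum = 51234
println(recycle(12345, 10000)) // 51234

// a trailing zero moves to the front as a leading zero: 120 + 0 * 1000 = 120
println(recycle(1200, tens(1200))) // 120
```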

Now we have the two basic functions, but how do we generate all recycles? What you need is the iterate function of the various collection classes. It applies a function repeatedly and gives you back the series x, f(x), f(f(x)), f(f(f(x))), .... Since we don't know beforehand how many steps we want to take, let's use a Stream, which is evaluated lazily and can therefore be infinite:
for(m<-Stream.iterate(12345)(recycle(_,10000))) {
  println(m)
}
12345
51234
45123
34512
23451
12345
51234
45123
34512
23451
12345
...
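Because the Stream is infinite, the for loop above never terminates. A minimal sketch of pulling out a finite prefix instead, with take; on Scala 2.13 and later, Stream is deprecated in favour of LazyList, whose iterate has the same shape:

```scala
def recycle(n: Int, t: Int): Int = n / 10 + n % 10 * t

// The Stream is infinite, but lazy: take(6) asks for only six elements,
// and toList forces exactly those.
val firstSix = Stream.iterate(12345)(recycle(_, 10000)).take(6).toList
println(firstSix) // List(12345, 51234, 45123, 34512, 23451, 12345)
```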

This is cool: we have an infinite list of recycles, but we only need them until they start repeating, and the first element is 12345 itself, which we don't want either. Fortunately the API has all the methods we need (tail and takeWhile), so we arrive at this function:
def recycled(n:Int):Stream[Int] = {
  val t = tens(n)
  Stream.iterate(n)(recycle(_,t)).tail.takeWhile(_!=n)
}
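A few sample calls, as a sanity check of recycled (the definitions are repeated so the snippet runs on its own). Note the number 1212: it comes back to itself after only two rotations, so takeWhile stops early and each recycle is produced only once. A single-digit number has no other rotation at all.

```scala
def tens(n: Int): Int = if (n < 10) 1 else tens(n / 10) * 10
def recycle(n: Int, t: Int): Int = n / 10 + n % 10 * t

def recycled(n: Int): Stream[Int] = {
  val t = tens(n)
  // drop n itself, then stop as soon as the rotation comes back to n
  Stream.iterate(n)(recycle(_, t)).tail.takeWhile(_ != n)
}

println(recycled(12345).toList) // List(51234, 45123, 34512, 23451)
println(recycled(1212).toList)  // List(2121)
println(recycled(7).toList)     // List()
```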

Now, how do we use our cool new method to generate all valid (n,m) pairs for given a and b? With a for comprehension, of course. Once you have the list of all valid pairs, its length is the answer you seek.
def solveFunc(a:Int, b:Int):Int = {
  val pairs = for(
      n <- a to b;
      m <- recycled(n);
      if (n < m && m <= b))
    yield (n,m)
  pairs.length
}
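To try solveFunc, here is a self-contained sketch that runs it on a few small ranges whose answers can be worked out by hand. In 10..40 the only recycled pairs are (12,21), (13,31) and (23,32); a pair like (14,41) is excluded because 41 > b, and single-digit numbers have no other rotation.

```scala
def tens(n: Int): Int = if (n < 10) 1 else tens(n / 10) * 10
def recycle(n: Int, t: Int): Int = n / 10 + n % 10 * t

def recycled(n: Int): Stream[Int] = {
  val t = tens(n)
  Stream.iterate(n)(recycle(_, t)).tail.takeWhile(_ != n)
}

def solveFunc(a: Int, b: Int): Int = {
  val pairs = for (
      n <- a to b;
      m <- recycled(n);
      if (n < m && m <= b))
    yield (n, m)
  pairs.length
}

println(solveFunc(1, 9))     // 0
println(solveFunc(10, 40))   // 3
println(solveFunc(100, 500)) // 156
```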

Neat, eh? Easy to read and understand, not much you can screw up. As a comparison here's the imperative solution to the problem:
def solveImp(a:Int, b:Int) = {
  var count = 0
  var n = a
  while(n <= b) {
    val tens = {
      var y = n/10
      var m = 1
      while(y>0) { y/=10; m*=10 }
      m
    }
    var temp = n/10 + (n%10)*tens
    while (temp != n) {
      if (temp > n && temp <=b) {
          count+=1
      }
      temp = temp/10 + (temp%10)*tens
    }
    n+=1
  }
  count
}

Much harder to understand. It has one big advantage, though: its runtime is about 4% of the functional solution's, and it uses constant memory, while solveFunc uses O(pairs.length) memory. So the questions are: can we speed up the functional solution, and why does it use up that much memory?

To understand the memory issue, you have to know that a for comprehension is nothing more than syntactic sugar over map, flatMap and withFilter (a lazy variant of filter). Written out by hand, solveFunc becomes roughly the following, and when you compute pairs.length, you're forcing the whole dataset to materialise:
def solve(a:Int, b:Int):Int = {
  val pairs = (a to b).flatMap(n=> {
    recycled(n)
    .filter(
        m=>n < m && m <= b
    )
    .map(
        m=>(n,m)
    )
  })
  pairs.length
}
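The same translation on a toy example, as a sketch: two generators and a guard turn into flatMap, withFilter and map. Since Scala 2.8 the compiler uses withFilter rather than filter for guards, which avoids building an intermediate filtered collection.

```scala
// A for comprehension with two generators and a guard...
val viaFor = for {
  n <- 1 to 3
  m <- 1 to 3
  if n < m
} yield (n, m)

// ...and the method calls the compiler rewrites it into
val viaMethods = (1 to 3).flatMap(n =>
  (1 to 3).withFilter(m => n < m).map(m => (n, m)))

println(viaFor.toList)        // List((1,2), (1,3), (2,3))
println(viaMethods == viaFor) // true
```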
But since you don't actually need all that data, you might as well use a view. In Scala you can switch back and forth between lazy and strict evaluation with view and force.
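A minimal sketch of what a view buys you. The names evaluated and lazySquares are illustrative; the counter is only there to observe how often the mapping function runs. Creating the view computes nothing, and forcing it later computes only the elements that are actually requested. Here toList is used to force the result, which works on both older and newer Scala collection libraries.

```scala
// counts how many times the mapping function actually runs
var evaluated = 0

// creating the view computes nothing yet
val lazySquares = (1 to 1000000).view.map { n => evaluated += 1; n.toLong * n }
println(evaluated) // 0

// toList forces only the three elements that take(3) asks for
val firstThree = lazySquares.take(3).toList
println(firstThree) // List(1, 4, 9)
println(evaluated)  // a handful, nowhere near 1000000
```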

Generating the recycled numbers with Stream is also slow: its lazy evaluation and memoisation cost a lot of CPU cycles. A List would be much faster, but of course you can't have an infinite list. Fortunately there is an upper limit of b=2000000, so every number has at most 7 digits, and after at most 7 rotations we always arrive back at the original. Computing a few excess recycles is so much less work than the Stream overhead that it's worth switching.
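Here is what the bounded version computes, as a worked example: for a 5-digit number the seven-element list wraps around and contains two excess entries, which the takeWhile then cuts off. For a 7-digit number all six rotations after the first element are genuine recycles.

```scala
def recycle(n: Int, t: Int): Int = n / 10 + n % 10 * t

// seven elements: 12345 itself plus six rotations; a 5-digit number cycles
// after five rotations, so the last two entries are excess recycles
val bounded = List.iterate(12345, 7)(recycle(_, 10000))
println(bounded) // List(12345, 51234, 45123, 34512, 23451, 12345, 51234)

// tail drops 12345 itself, takeWhile cuts off everything from the first repeat
println(bounded.tail.takeWhile(_ != 12345)) // List(51234, 45123, 34512, 23451)

// a 7-digit number uses all six rotations
println(List.iterate(1234567, 7)(recycle(_, 1000000)).tail.takeWhile(_ != 1234567).length) // 6
```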

Last, there's no need to generate the pairs at all: we're dropping them on the floor anyway and only use their count.

With these, the optimized-but-still-pure-functional solution looks like this:
def recycled(n:Int) = {
  val t = tens(n)
  List.iterate(n,7)(recycle(_,t)).tail.takeWhile(_!=n)
}

def solve(a:Int, b:Int) = {
  val recycles = for(
      n <- (a to b).view;
      ms = recycled(n))
    yield ms.count(m=> n < m && m <= b)
  
  recycles.sum
}

It uses constant memory, and the run time is 20% of the original. That's still 5 times slower than the imperative solution, but well, you can't have everything.
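Optimisations like these are easy to get subtly wrong, so a cheap safeguard is to cross-check the optimised solve against solveImp on a few ranges. A self-contained sketch, with both solutions copied from above; the list of ranges is an arbitrary choice.

```scala
def tens(n: Int): Int = if (n < 10) 1 else tens(n / 10) * 10
def recycle(n: Int, t: Int): Int = n / 10 + n % 10 * t

// optimised functional version, as above
def recycled(n: Int) = {
  val t = tens(n)
  List.iterate(n, 7)(recycle(_, t)).tail.takeWhile(_ != n)
}

def solve(a: Int, b: Int) = {
  val recycles = for (
      n <- (a to b).view;
      ms = recycled(n))
    yield ms.count(m => n < m && m <= b)
  recycles.sum
}

// imperative version, as above
def solveImp(a: Int, b: Int) = {
  var count = 0
  var n = a
  while (n <= b) {
    val tens = {
      var y = n / 10
      var m = 1
      while (y > 0) { y /= 10; m *= 10 }
      m
    }
    var temp = n / 10 + (n % 10) * tens
    while (temp != n) {
      if (temp > n && temp <= b) count += 1
      temp = temp / 10 + (temp % 10) * tens
    }
    n += 1
  }
  count
}

// both versions must agree on every range we try
val ranges = List((1, 9), (10, 40), (100, 500), (1111, 2222), (1000, 9999))
for ((a, b) <- ranges)
  println(s"$a..$b: functional ${solve(a, b)}, imperative ${solveImp(a, b)}")
```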

Friday, June 22, 2012

Google Code Jam in Scala : Dancing With the Googlers

This is a rather trivial task to solve, and shouldn't take more than a few lines in any language. The tricky bits are not in the coding, but in getting the calculations right. So the only Scala-specific things here are
  • how to use min and max
  • how to count the number of items that satisfy a condition

import scala.math.{min,max}
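def solve(s:Int,p:Int,ts:List[Int]):Int = {
  // smallest total that reaches p with a normal triplet like (p, p-1, p-1)
  val normal_limit = p + max(0, p-1) + max(0, p-1)
  // smallest total that reaches p with a surprising triplet like (p, p-2, p-2)
  val surprising_limit = p + max(0, p-2) + max(0, p-2)
  val normals = ts.count(_>=normal_limit)
  val surprisings = ts.count(t=> t>=surprising_limit && t<normal_limit)
  // at most s of the totals may be surprising
  min(s,surprisings) + normals
}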
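A worked example, as a sketch (the function is repeated so it runs standalone). With p = 5, a normal triplet needs a total of at least 5 + 4 + 4 = 13, while a surprising one like (5, 3, 3) needs only 11; a total between the two limits can reach p only by spending one of the s allowed surprises.

```scala
import scala.math.{min, max}

def solve(s: Int, p: Int, ts: List[Int]): Int = {
  val normal_limit = p + max(0, p - 1) + max(0, p - 1)
  val surprising_limit = p + max(0, p - 2) + max(0, p - 2)
  val normals = ts.count(_ >= normal_limit)
  val surprisings = ts.count(t => t >= surprising_limit && t < normal_limit)
  min(s, surprisings) + normals
}

// p = 5: limits 13 and 11; 15 and 13 are normal, 11 needs the one surprise
println(solve(1, 5, List(15, 13, 11)))  // 3
// p = 8: limits 22 and 20; 21 would need a surprise, but s = 0
println(solve(0, 8, List(23, 22, 21)))  // 2
// p = 8, s = 2: 29 is normal, 20 and 21 both use a surprise
println(solve(2, 8, List(29, 20, 8, 18, 18, 21))) // 3
```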

Sidenote: usually it's bad practice to name your function parameters with one letter, like "s" or "p". But if you're solving a Code Jam problem, you should always use the same naming convention as the problem text. It will save lots of additional thinking cycles.

Thursday, June 21, 2012

Google Code Jam in Scala : Speaking in Tongues

Part One of the series "let's solve Google Code Jam problems in Scala, without ever using the keyword var".

The first problem of the 2012 qualification round gives you the task of breaking a monoalphabetic substitution cipher, given sufficient plaintext/ciphertext. At the bottom of the problem description there is some nice seed data, which you can put in a map:

val stringMap = Map(
  "ejp mysljylc kd kxveddknmc re jsicpdrysi"
         -> "our language is impossible to understand",
  "rbcpc ypc rtcsra dkh wyfrepkym veddknkmkrkcd"
         -> "there are twenty six factorial possibilities",
  "de kr kd eoya kw aej tysr re ujdr lkgc jv"
         -> "so it is okay if you want to just give up"
)
This is a string-to-string map, but you need a character-to-character map. Converting to that is easy; you just have to know a few facts about the Scala collections:

  1. every map can be seen as a sequence of (key,value) pairs - you can iterate over them
  2. a string is also a Seq[Char], and as such, it has a zip method. So if you zip two strings together, you get a sequence of character pairs
  3. the reverse of (1) also holds: if you have a sequence of (key,value) pairs, you can create a map
val charMap = for(
    (keyString,valueString) <- stringMap;
    (key,value) <- keyString zip valueString)
  yield (key,value)

Since stringMap is of type Map[String,String] and the for comprehension yields pairs, the collections library builds a Map again, so the type of charMap will be Map[Char,Char].
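The three facts above on a toy input, as a quick sketch; the small map and the names pairs, small and chars are hypothetical, chosen only to keep the output readable. The for comprehension has the same shape as the one that builds charMap.

```scala
// (2) zipping two strings gives a sequence of character pairs
val pairs = "abc" zip "xyz"
println(pairs.toList) // List((a,x), (b,y), (c,z))

// (1) a Map can be iterated as (key, value) pairs, and (3) yielding pairs
// from it builds a Map again
val small = Map("ab" -> "xy", "c" -> "z")
val chars = for {
  (k, v) <- small
  (kc, vc) <- k zip v
} yield (kc, vc)
println(chars == Map('a' -> 'x', 'b' -> 'y', 'c' -> 'z')) // true
```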
Now we can create our cipher decode function and test it out. Here are a few equivalent versions; the differences are only syntactic sugar, so pick any one of them:
def decode(s:String) = for(c<-s) yield charMap(c)
def decode(s:String) = s.map(c=>charMap(c))
def decode(s:String) = s.map(charMap(_))
println(decode("rbcpc ypc rtcsra dkh wyfrepkym veddknkmkrkcd"))
there are twenty six factorial possibilities
Before you download the first input file to solve, you might ask one question: "I know a Map throws an exception for unknown keys. What if the sample we have so far doesn't cover all letters of the alphabet?" Good question. Let's ask our code for the answer! Again, the extremely versatile collection classes come to the rescue:

  1. "to" is not a keyword but a method (added through an implicit conversion) that creates a range; it works for integral types, including Char, where the result is a NumericRange
  2. basically every collection has a toSet method, that creates an immutable set
  3. sets have a difference operation "--" that removes from a set everything contained in any collection you can iterate over
  4. maps have methods to get the keys and the values as Iterables
val alphabet = ('a' to 'z').toSet
println(alphabet -- charMap.keys)
Set(q, z)
println(alphabet -- charMap.values)
Set(q, z)
Oops, there are two characters that don't have a mapping. Fortunately, if you read the beginning of the problem, there are some hints...
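Until the mapping is complete, decode will throw on the missing letters. One possible workaround, sketched on a hypothetical toy map rather than the real charMap (partial, letters and decodeOrMark are illustrative names): the same set difference finds the gaps, and getOrElse lets a decoder mark unknown letters with a placeholder instead of throwing.

```scala
// hypothetical partial mapping, missing 'd' and 'e' as keys
val partial = Map('a' -> 'y', 'b' -> 'h', 'c' -> 'e')
val letters = ('a' to 'e').toSet

println(letters -- partial.keys)   // the letters we cannot decode yet: d and e
println(letters -- partial.values) // letters never produced: a, b, c and d

// instead of throwing NoSuchElementException, mark unknown letters with '?'
def decodeOrMark(s: String): String = s.map(c => partial.getOrElse(c, '?'))
println(decodeOrMark("abcd")) // yhe?
```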