Specialization of generic functions in Scala (or Java)

Is it possible to specialize generic functions (or classes) in Scala? For example, I want to write a generic function that writes data into a ByteBuffer:

def writeData[T](buffer: ByteBuffer, data: T) = buffer.put(data)

But as the put method takes only a byte and puts it into the buffer, I need to specialize it for Ints and Longs as follows:

def writeData[Int](buffer: ByteBuffer, data: Int) = buffer.putInt(data)
def writeData[Long](buffer: ByteBuffer, data: Long) = buffer.putLong(data)

and it won't compile. Of course, I could instead write 3 different functions writeByte, writeInt and writeLong respectively, but let's say there is another function for an array:

def writeArray[T](buffer: ByteBuffer, array: Array[T]) {
  for (elem <- array) writeData(buffer, elem)
}

and this wouldn't work without the specialized writeData functions: I'll have to deploy another set of functions writeByteArray, writeIntArray, writeLongArray. Having to deal with the situation this way whenever I need to use type-dependent write functions is not cool. I did some research and one possible workaround is to test the type of the parameter:

def writeArray[T](buffer: ByteBuffer, array: Array[T]) {
  if (array.isInstanceOf[Array[Byte]])
    for (elem <- array) writeByte(buffer, elem)
  else if (array.isInstanceOf[Array[Int]])
    for (elem <- array) writeInt(buffer, elem)
  ...
}

This might work, but it's less efficient because the type checking is done at runtime, unlike in the specialized-function version.

So my question is, what is the most desirable and preferred way to solve this kind of problem in Scala or Java? I appreciate your help in advance!

Viburnum answered 1/11, 2012 at 8:4 Comment(4)
Waiting for the type class answer in 3... 2... 1.. – Pestilent
Added the specialized tag, because it turns out that this means something specific in Scala and it is exactly what you need. – Medea
@RexKerr: I don't see how @specialized is relevant here; it's not going to help with picking which of putInt, putLong, etc. is needed. – Blab
@TravisBrown - No, but it will keep you from being sorry that you did it this way instead of implementing everything from scratch each time. When adding bytes to a buffer, you typically don't want to invoke boxing. – Medea

Wouldn't it be nice if you could have both a compact and efficient solution? It turns out that you can, given Scala's @specialized feature. First a warning: the feature is somewhat buggy, and may break if you try to use it for something too complicated. But for this case, it's almost perfect.

The @specialized annotation creates separate classes and/or methods for each primitive type, and then calls those instead of the generic version whenever the compiler knows for sure what the primitive type is. The only drawback is that it does all of this completely automatically: you don't get to fill in your own method. That's kind of a shame, but you can overcome the problem using type classes.

Let's look at some code:

import java.nio.ByteBuffer
trait BufferWriter[@specialized(Byte,Int) A]{
  def write(b: ByteBuffer, a: A): Unit
}
class ByteWriter extends BufferWriter[Byte] {
  def write(b: ByteBuffer, a: Byte) { b.put(a) }
}
class IntWriter extends BufferWriter[Int] {
  def write(b: ByteBuffer, a: Int) { b.putInt(a) }
}
object BufferWriters {
  implicit val byteWriter = new ByteWriter
  implicit val intWriter = new IntWriter
}

This gives us a BufferWriter trait which is generic, but we override each of the specific primitive types that we want (in this case Byte and Int) with an appropriate implementation. Specialization is smart enough to link up this explicit version with the hidden one it normally uses for specialization. So you've got your custom code, but how do you use it? This is where the implicit vals come in (I've done it this way for speed and clarity):

import BufferWriters._
def write[@specialized(Byte,Int) A: BufferWriter](b: ByteBuffer, ar: Array[A]) {
  val writer = implicitly[BufferWriter[A]]
  var i = 0
  while (i < ar.length) {
    writer.write(b, ar(i))
    i += 1
  }
}

The A: BufferWriter notation means that in order to call this write method, you need to have an implicit BufferWriter[A] handy. We've supplied them with the vals in BufferWriters, so we should be set. Let's see if this works.
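The context bound is shorthand for an extra implicit parameter list, which also means an instance can be passed by hand. A minimal sketch of the two equivalent spellings, using a separate, hypothetical Writer type class (its instance is placed in the companion object, an assumption of this sketch, so no import is needed):

```scala
import java.nio.ByteBuffer

// Hypothetical type class for this sketch (not the answer's BufferWriter).
trait Writer[A] {
  def write(b: ByteBuffer, a: A): Unit
}

object Writer {
  // Instances in the companion object are found by implicit search without an import.
  implicit object IntWriter extends Writer[Int] {
    def write(b: ByteBuffer, a: Int): Unit = b.putInt(a)
  }
}

object ContextBoundDemo {
  // Context-bound form: the caller must have an implicit Writer[A] in scope.
  def writeAllBound[A: Writer](b: ByteBuffer, ar: Array[A]): Unit = {
    val w = implicitly[Writer[A]]
    ar.foreach(a => w.write(b, a))
  }

  // Roughly what the context bound desugars to: one extra implicit parameter,
  // which can also be supplied explicitly.
  def writeAllExplicit[A](b: ByteBuffer, ar: Array[A])(implicit w: Writer[A]): Unit =
    ar.foreach(a => w.write(b, a))
}
```

For example, `ContextBoundDemo.writeAllExplicit(buf, Array(1, 2))(Writer.IntWriter)` passes the instance manually; both methods write the same bytes.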

val b = ByteBuffer.allocate(6)
write(b, Array[Byte](1,2))
write(b, Array[Int](0x03040506))
scala> b.array
res3: Array[Byte] = Array(1, 2, 3, 4, 5, 6)

If you put these things in a file and start poking around the classes with javap -c -private you'll see that the appropriate primitive methods are being used.

(Note that if you didn't use specialization, this strategy would still work, but it would have to box values inside the loop to copy the array out.)
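A sketch of the same pattern extended to Long and to single values (as asked in the comments), written without procedure syntax, which Scala 3 drops. The names Writes, writeOne and writeArr, and putting the instances in the BufferWriter companion object instead of a separate BufferWriters object, are assumptions of this sketch. The @specialized annotations only take effect under Scala 2; without them the code still works, just with boxing:

```scala
import java.nio.ByteBuffer

// Type class: how to write one value of type A into a ByteBuffer.
trait BufferWriter[@specialized(Byte, Int, Long) A] {
  def write(b: ByteBuffer, a: A): Unit
}

class ByteWriter extends BufferWriter[Byte] {
  def write(b: ByteBuffer, a: Byte): Unit = b.put(a)
}
class IntWriter extends BufferWriter[Int] {
  def write(b: ByteBuffer, a: Int): Unit = b.putInt(a)
}
class LongWriter extends BufferWriter[Long] {
  def write(b: ByteBuffer, a: Long): Unit = b.putLong(a)
}

// Instances in the companion are found by implicit search without an import.
object BufferWriter {
  implicit val byteWriter: BufferWriter[Byte] = new ByteWriter
  implicit val intWriter: BufferWriter[Int] = new IntWriter
  implicit val longWriter: BufferWriter[Long] = new LongWriter
}

// Hypothetical helper object for this sketch.
object Writes {
  // Write a single value.
  def writeOne[@specialized(Byte, Int, Long) A](b: ByteBuffer, a: A)(implicit w: BufferWriter[A]): Unit =
    w.write(b, a)

  // Write every element of an array, using a while loop as in the answer.
  def writeArr[@specialized(Byte, Int, Long) A](b: ByteBuffer, ar: Array[A])(implicit w: BufferWriter[A]): Unit = {
    var i = 0
    while (i < ar.length) {
      w.write(b, ar(i))
      i += 1
    }
  }
}
```

With a 14-byte buffer, writing `Array[Byte](1, 2)`, then `0x03040506`, then `7L` fills it completely (2 + 4 + 8 bytes).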

Medea answered 1/11, 2012 at 13:24 Comment(6)
being a novice to Scala, I did not know this feature. +1, thanks – Dinin
Awesome, exactly what I was looking for! It'd be better with less emphasis put on the @specialized feature, though. Although it makes the code more efficient and thus perfect for real usage, it's not directly related to the problem, so I'm afraid it might slightly confuse people looking for the answer. The type class pattern is the key point here. But besides that, the answer is perfect. Thanks, Rex! – Viburnum
@KJ - If you're using ByteBuffer at all instead of serializing to a string, you probably care about performance. So I am not convinced that the focus on specialized is misplaced. Type classes are just syntactic sugar; you could always do the same in Scala or Java but have to pass the BufferWriter in manually. specialized is sugar too, but it's a lot of sugar. – Medea
Good answer, that fixes the problems with mine. A question though: why not put the specialized write method inside the BufferWriters object? – Pestilent
Also, forgive my noobiness, is there any way to define the write method so it can act on single A's (not Array[A])? The best I could come up with was this: gist.github.com/3997580 – Pestilent
@DominicBou-Samra - You did it correctly, I think; it's just the obvious approach. But your test as posted tries to put three ints into 6 bytes, which is not going to work. And the only reason I pulled the write method out was to show that you can (as you may have yet more methods to define later), and so I could break up the code and explanation more. – Medea

Use a typeclass pattern. It has the advantage over the isInstanceOf checking (or pattern matching) of being typesafe.

import java.nio.ByteBuffer
import scala.language.implicitConversions

trait BufferWriter[A] {
  def write(buffer: ByteBuffer, a: A): Unit
}

class BuffPimp(buffer: ByteBuffer) {
  def writeData[A: BufferWriter](data: A): Unit = {
    implicitly[BufferWriter[A]].write(buffer, data)
  }
}

object BuffPimp {
  // ": Unit" is needed: putInt etc. return the buffer, which would not conform to write's Unit
  implicit def intWriter: BufferWriter[Int] = new BufferWriter[Int] {
    def write(buffer: ByteBuffer, a: Int): Unit = buffer.putInt(a)
  }
  implicit def doubleWriter: BufferWriter[Double] = new BufferWriter[Double] {
    def write(buffer: ByteBuffer, a: Double): Unit = buffer.putDouble(a)
  }
  implicit def longWriter: BufferWriter[Long] = new BufferWriter[Long] {
    def write(buffer: ByteBuffer, a: Long): Unit = buffer.putLong(a)
  }
  implicit def wrap(buffer: ByteBuffer): BuffPimp = new BuffPimp(buffer)
}

object Test {
  import BuffPimp._
  // An object cannot have an abstract val; 4 (Int) + 8 (Double) + 8 (Long) = 20 bytes
  val someByteBuffer: ByteBuffer = ByteBuffer.allocate(20)
  someByteBuffer.writeData(1)
  someByteBuffer.writeData(1.0)
  someByteBuffer.writeData(1L)
}

So this code isn't the best demonstration of typeclasses. I am still very new to them. This video gives a really solid overview of their benefits and how you can use them: http://www.youtube.com/watch?v=sVMES4RZF-8
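One of the comments suggests adding array support, which is where the pattern pays off over plain overloading. A minimal sketch of that suggestion, with hypothetical names (BufWriter, BufOps) so it stands apart from the code above; it also uses implicit object for the non-generic instances, as another comment recommends. A single implicit def derives a writer for Array[A] from the writer for A:

```scala
import java.nio.ByteBuffer

// Hypothetical type class for this sketch.
trait BufWriter[A] {
  def write(buffer: ByteBuffer, a: A): Unit
}

object BufWriter {
  // Non-generic instances as implicit objects.
  implicit object IntW extends BufWriter[Int] {
    def write(buffer: ByteBuffer, a: Int): Unit = buffer.putInt(a)
  }
  implicit object DoubleW extends BufWriter[Double] {
    def write(buffer: ByteBuffer, a: Double): Unit = buffer.putDouble(a)
  }

  // One generic instance covers Array[A] for every A that has a writer,
  // so no writeIntArray / writeDoubleArray variants are needed.
  implicit def arrayW[A](implicit elem: BufWriter[A]): BufWriter[Array[A]] =
    new BufWriter[Array[A]] {
      def write(buffer: ByteBuffer, arr: Array[A]): Unit =
        arr.foreach(a => elem.write(buffer, a))
    }
}

object BufOps {
  // The instance is resolved at compile time from the static type of data.
  def writeData[A](buffer: ByteBuffer, data: A)(implicit w: BufWriter[A]): Unit =
    w.write(buffer, data)
}
```

Calling `BufOps.writeData(buf, Array(1, 2))` resolves `BufWriter[Array[Int]]` through `arrayW`, which in turn picks up `IntW` for the elements.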

Pestilent answered 1/11, 2012 at 9:1 Comment(9)
If one of the Scala experts could clean this up for me I'd greatly appreciate it. I am super new to typeclasses. – Pestilent
You should add support for arrays in your example. As it stands now, it looks like a needlessly complex way of doing what good old overloading could achieve. As soon as you try to handle composite structures, the type class pattern shows its advantage: with this pattern, there is no need to define writeIntArray, writeLongArray and so on. Just write one generic writeArray method which takes an implicit BufferWriter instance for the element type. – Luthuli
+1, with a very minor observation: you can use implicit object intWriter extends BufferWriter[Int] { ... } for instances that aren't generic; it makes the intent a little clearer in my view, and is potentially a touch more efficient. – Blab
+1 - This is a great answer, but there's an extra wrinkle that lets it work at full primitive speed also, even with arrays. – Medea
@Travis Brown: funny, I made an edit to do just that, seems like it did not pass review for some reason. – Luthuli
@RégisJean-Gilles That's because you are not supposed to do major changes to someone else's answer. – Semitics
It didn't strike me as a major change (it's just slightly more idiomatic, and more efficient), especially when he explicitly asked for cleanup. His thoughtful answer was intact, just slightly enhanced. Adding support for arrays, as I suggested above, is more like an actual change of the answer, which is why I didn't make it myself and suggested the change as a comment. – Luthuli
Thanks and +1 for the amazing answer, Dominic! It was a little bit hard to follow the code at first glance, but you were the first one who gave the hint of the type class pattern. ;) – Viburnum
No probs! Typeclasses are amazing, and this is not a very good example of their power. As for the guys touching up the code - go for it, but it seems the limitations of StackOverflow are preventing it? – Pestilent
  1. The declarations

    def writeData[Int](buffer: ByteBuffer, data: Int) 
    def writeData[Long](buffer: ByteBuffer, data: Long)
    

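do not compile: the Int and Long inside the square brackets are names of formal type parameters, not the standard Scala types (they shadow scala.Int and scala.Long), so the two declarations are equivalent, each being a generic method over an arbitrary type. To define functions on the standard Scala types, just write: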

def writeData(buffer: ByteBuffer, data: Int) = buffer.putInt(data)
def writeData(buffer: ByteBuffer, data: Long) = buffer.putLong(data)

This way you declare different functions with the same name.

  2. Since they are different functions, you cannot apply them to elements of a List of statically unknown type. You first have to determine the type of the List. Note that the type of the List may turn out to be AnyRef; then you have to determine the type of each element dynamically. The determination can be done with isInstanceOf, as in your original code, or with pattern matching, as rolve suggested. I think this would produce the same bytecode.

  3. In sum, you have to choose:

    • fast code with multiple functions like writeByteArray, writeIntArray etc. They all can have the same name writeArray but can be statically distinguished by their actual parameters. The variant suggested by Dominic Bou-Sa is of this kind.

    • concise but slow code with run-time type determination

Unfortunately, you cannot have both fast and concise code.
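A small sketch of the first point, using a hypothetical ShadowDemo object: a type parameter named Int is only a name, so it shadows scala.Int inside that one method, while genuine overloads differ in their parameter types and are resolved at compile time. A generic writeArray[T] still cannot call either overload on an element of type T, which is the second point.

```scala
import java.nio.ByteBuffer

object ShadowDemo {
  // "Int" here is just the name of a type parameter and shadows scala.Int,
  // so this is the same as def echo[T](x: T): T = x.
  def echo[Int](x: Int): Int = x

  // Genuine overloads: the parameter types differ, and the compiler picks
  // one statically from the argument's type.
  def writeData(buffer: ByteBuffer, data: Int): Unit = buffer.putInt(data)
  def writeData(buffer: ByteBuffer, data: Long): Unit = buffer.putLong(data)
}
```

`ShadowDemo.echo("hello")` type-checks and returns `"hello"`, which shows that the bracketed `Int` accepts any type.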

Dinin answered 1/11, 2012 at 9:53 Comment(7)
The two definitions of writeData are not equivalent (they are just wrong): Int and Long are different types and thus we have two different method signatures, which Scala's overloading rules handle just fine. Your explanation about formal parameters is confusing. +1 anyway for stressing the difference between the static and dynamic approaches. – Luthuli
@Régis Jean-Gilles what "two definitions" are you talking about? There are two pairs of definitions in my answer. The first pair, taken from the OP, are equivalent because Int and Long there are names of type parameters, introduced in the definitions. What confuses you in my explanation? – Dinin
I am talking about the first pair (from the OP). When you say they are equivalent, it can be read as "same signature", especially given that you talk about "formal parameters" and not "formal type parameters" (better to be extra clear, as there is also a formal parameter "data" of type Int/Long here). After rereading it, I think you actually point exactly at the right thing, but it would probably be clearer to point out right away that the Int and Long identifiers inside the square brackets are type parameters that could as well be changed to Foo and Bar without any semantic change. – Luthuli
Oh and yes, I was wrong when I said in my earlier comment that the two definitions are not equivalent. I was confused by your mention of "formal parameter" as explained above. – Luthuli
Saying "formal parameters" I meant "formal type parameters", and yes that was confusing. I corrected the answer. – Dinin
@AlexeiKaigorodov - You are incorrect that you cannot have both fast and concise code. The compiler can write the multiple functions for you, as long as you supply the key details. I.e. it does most of (1) for you. See my answer! – Medea
Thanks for the great answer, Alexei. Your answer was really helpful in understanding the entire context along with those of Dominic's and RexKerr's. +1! – Viburnum

How about this:

def writeData(buffer: ByteBuffer, data: AnyVal) {
  data match {
    case d: Byte => buffer put d
    case d: Int  => buffer putInt d
    case d: Long => buffer putLong d
    ...
  }
}

Here, you make the case distinction in the writeData method, which makes all further methods very simple:

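// Array[Int] is not an Array[AnyVal] (arrays are invariant), so accept arrays of any AnyVal subtype
def writeArray(buffer: ByteBuffer, array: Array[_ <: AnyVal]) {
  for (elem <- array) writeData(buffer, elem)
}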

Advantages: Simple, short, easy to understand.

Disadvantages: Not completely type-safe if you don't handle all AnyVal types: someone may call writeData(buffer, ()) (the second argument being of type Unit), which would then throw a MatchError at runtime. But you can also make the handling of () a no-op, which solves the problem. The complete method would look like this:

def writeData(buffer: ByteBuffer, data: AnyVal) {
  data match {
    case d: Byte   => buffer put d
    case d: Short  => buffer putShort d
    case d: Int    => buffer putInt d
    case d: Long   => buffer putLong d
    case d: Float  => buffer putFloat d
    case d: Double => buffer putDouble d
    case d: Char   => buffer putChar d
    case true      => buffer put 1.toByte
    case false     => buffer put 0.toByte
    case ()        =>
  }
}
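A usage sketch of the complete method, assuming a hypothetical writeAll helper that takes a Seq[AnyVal]. Because Seq is covariant, a Seq[Int] (for example `Array(1, 2).toSeq`) can be passed where a Seq[AnyVal] is expected, whereas an Array[Int] is not an Array[AnyVal]. Each value is boxed when it is viewed as AnyVal and dispatched on its runtime type:

```scala
import java.nio.ByteBuffer

object MatchDemo {
  // The answer's complete method, with an explicit Unit result type.
  def writeData(buffer: ByteBuffer, data: AnyVal): Unit = data match {
    case d: Byte   => buffer.put(d)
    case d: Short  => buffer.putShort(d)
    case d: Int    => buffer.putInt(d)
    case d: Long   => buffer.putLong(d)
    case d: Float  => buffer.putFloat(d)
    case d: Double => buffer.putDouble(d)
    case d: Char   => buffer.putChar(d)
    case true      => buffer.put(1.toByte)
    case false     => buffer.put(0.toByte)
    case ()        => () // writing Unit is a no-op
  }

  // Hypothetical helper: dispatches on each element's runtime type.
  def writeAll(buffer: ByteBuffer, values: Seq[AnyVal]): Unit =
    values.foreach(v => writeData(buffer, v))
}
```

Writing `Seq[AnyVal](1.toByte, 2, 3L, true)` uses 1 + 4 + 8 + 1 = 14 bytes.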

By the way, this only works so easily because of Scala's strict object-oriented nature, in which Byte, Int and the other primitives are all subtypes of AnyVal (under the hood, each value is still boxed when it is passed as AnyVal). In Java, where primitive types are not objects, this would be much more cumbersome. There, you would actually have to create a separate method for each primitive type, unless you want to do some ugly boxing and unboxing.

Chrissie answered 1/11, 2012 at 8:57 Comment(1)
Thanks for the elaborate answer, rolve! Your pattern-matching version looks cleaner than my isInstanceOf version. :) But it seems less efficient than the type class pattern Dominic and Rex Kerr introduced, as it still does type checking dynamically at runtime. – Viburnum
