Fast complex number arithmetic in Clojure

I was implementing some basic complex number arithmetic in Clojure, and noticed that it was about 10 times slower than roughly equivalent Java code, even with type hints.

Compare:

(defn plus [[^double x1 ^double y1] [^double x2 ^double y2]]
    [(+ x1 x2) (+ y1 y2)])

(defn times [[^double x1 ^double y1] [^double x2 ^double y2]]
    [(- (* x1 x2) (* y1 y2)) (+ (* x1 y2) (* y1 x2))])

(time (dorun (repeatedly 100000 #(plus [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(times [1 0] [0 1])))) 

output:

"Elapsed time: 69.429796 msecs"
"Elapsed time: 72.232479 msecs"

with:

public static void main( String[] args ) {
  double[] z1 = new double[] { 1, 0 };
  double[] z2 = new double[] { 0, 1 };
  double[] z3 = null;

  long l_StartTimeMillis = System.currentTimeMillis();
  for ( int i = 0; i < 100000; i++ ) {
    z3 = plus( z1, z2 ); // assign result to dummy var to stop compiler from optimising the loop away
  }
  long l_EndTimeMillis = System.currentTimeMillis();
  long l_TimeTakenMillis = l_EndTimeMillis - l_StartTimeMillis;
  System.out.format( "Time taken: %d millis\n", l_TimeTakenMillis );


  l_StartTimeMillis = System.currentTimeMillis();
  for ( int i = 0; i < 100000; i++ ) {
    z3 = times( z1, z2 );
  }
  l_EndTimeMillis = System.currentTimeMillis();
  l_TimeTakenMillis = l_EndTimeMillis - l_StartTimeMillis;
  System.out.format( "Time taken: %d millis\n", l_TimeTakenMillis );

  doNothing( z3 );
}

private static void doNothing( double[] z ) {

}

public static double[] plus (double[] z1, double[] z2) {
  return new double[] { z1[0] + z2[0], z1[1] + z2[1] };
}

public static double[] times (double[] z1, double[] z2) {
  return new double[] { z1[0]*z2[0] - z1[1]*z2[1], z1[0]*z2[1] + z1[1]*z2[0] };
}

output:

Time taken: 6 millis
Time taken: 6 millis

In fact, the type hints don't seem to make a difference: if I remove them I get approximately the same result. What's really strange is that if I run the Clojure script without a REPL, I get slower results:

"Elapsed time: 137.337782 msecs"
"Elapsed time: 214.213993 msecs"

So my questions are: how can I get close to the performance of the Java code? And why on Earth do the expressions take longer to evaluate when running clojure without a REPL?

UPDATE ==============

Great, using deftype with type hints in the deftype and in the defns, and using dotimes rather than repeatedly gives performance as good as or better than the Java version. Thanks to both of you.

(deftype complex [^double real ^double imag])

(defn plus [^complex z1 ^complex z2]
  (let [x1 (double (.real z1))
        y1 (double (.imag z1))
        x2 (double (.real z2))
        y2 (double (.imag z2))]
    (complex. (+ x1 x2) (+ y1 y2))))

(defn times [^complex z1 ^complex z2]
  (let [x1 (double (.real z1))
        y1 (double (.imag z1))
        x2 (double (.real z2))
        y2 (double (.imag z2))]
    (complex. (- (* x1 x2) (* y1 y2)) (+ (* x1 y2) (* y1 x2)))))

(println "Warm up")
(time (dorun (repeatedly 100000 #(plus (complex. 1 0) (complex. 0 1)))))
(time (dorun (repeatedly 100000 #(times (complex. 1 0) (complex. 0 1)))))
(time (dorun (repeatedly 100000 #(plus (complex. 1 0) (complex. 0 1)))))
(time (dorun (repeatedly 100000 #(times (complex. 1 0) (complex. 0 1)))))
(time (dorun (repeatedly 100000 #(plus (complex. 1 0) (complex. 0 1)))))
(time (dorun (repeatedly 100000 #(times (complex. 1 0) (complex. 0 1)))))

(println "Try with dorun")
(time (dorun (repeatedly 100000 #(plus (complex. 1 0) (complex. 0 1)))))
(time (dorun (repeatedly 100000 #(times (complex. 1 0) (complex. 0 1)))))

(println "Try with dotimes")
(time (dotimes [_ 100000]
        (plus (complex. 1 0) (complex. 0 1))))

(time (dotimes [_ 100000]
        (times (complex. 1 0) (complex. 0 1))))

Output:

Warm up
"Elapsed time: 92.805664 msecs"
"Elapsed time: 164.929421 msecs"
"Elapsed time: 23.799012 msecs"
"Elapsed time: 32.841624 msecs"
"Elapsed time: 20.886101 msecs"
"Elapsed time: 18.872783 msecs"
Try with dorun
"Elapsed time: 19.238403 msecs"
"Elapsed time: 17.856938 msecs"
Try with dotimes
"Elapsed time: 5.165658 msecs"
"Elapsed time: 5.209027 msecs"
Jezebel asked 6/8, 2012 at 8:27
Comments:
Have you tried setting *warn-on-reflection* to see if there's any reflection sneaking in? - Trouvaille
@DaoWen: no, I've never used that setting. I've just run the script again with (set! *warn-on-reflection* true) at the top of it, and there are no warnings printed to stdout, so that means there's no reflection being used, right? Just want to make sure I'm using it correctly. - Jezebel
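A note on what that setting checks. A reflection warning means the compiler could not determine the class of a Java interop target at compile time (for example an unhinted (.real z1)). In that case it emits a runtime lookup through clojure.lang.Reflector instead of a direct field or method access. The warnings are written to *err* (stderr), not stdout. Seeing none is plausible here, because the vector version makes no interop calls at all. Its arithmetic on boxed numbers goes through clojure.lang.Numbers, which is slow for other reasons but is not reflection.

As a loose Java analogy, the sketch below compares a direct field read with a reflective one. It is not Clojure's actual Reflector code, and the class names ComplexBox and FieldAccess are made up for this sketch.

```java
import java.lang.reflect.Field;

// Hypothetical stand-in for the deftype from the question: two public final doubles.
final class ComplexBox {
    public final double real;
    public final double imag;

    ComplexBox(double real, double imag) {
        this.real = real;
        this.imag = imag;
    }
}

final class FieldAccess {
    // Direct access: the field is resolved at compile time, which is what a
    // type-hinted (.real z) can compile down to.
    static double realDirect(ComplexBox z) {
        return z.real;
    }

    // Reflective access: the field is looked up by name at runtime on every call,
    // the general kind of work done when the target's class is unknown.
    static double realReflective(Object z) throws ReflectiveOperationException {
        Field f = z.getClass().getField("real");
        return f.getDouble(z);
    }
}
```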

The likely reasons for your slow performance are:

  • Clojure vectors are intrinsically more heavyweight data structures than Java double[] arrays, so you pay noticeable extra overhead every time you create or read one.
  • You are boxing numbers. The values stored in the vectors are boxed objects (the literals in [1 0] are boxed Longs), and every + and * on them produces another boxed result. Boxing and unboxing are relatively expensive in this kind of low-level numerical code.
  • The type hints (^double) are not helping you. Primitive type hints do work on normal Clojure function arguments, but a ^double hint on a name destructured out of a vector cannot make it primitive, because the vector only holds boxed objects.
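In Java terms, a Clojure vector holds its elements as objects, so a two-element vector of numbers behaves more like a List<Double> than a double[]. The sketch below writes the same addition both ways; BoxingDemo is a hypothetical name. The boxed version unboxes every component it reads and boxes every sum it produces. The primitive version, like the question's Java plus, never leaves double.

```java
import java.util.List;

final class BoxingDemo {
    // Boxed representation: each component is a java.lang.Double object,
    // similar to the boxed numbers inside a Clojure vector.
    static List<Double> plusBoxed(List<Double> z1, List<Double> z2) {
        // get() returns Double; '+' unboxes both operands, and List.of re-boxes the sums.
        return List.of(z1.get(0) + z2.get(0), z1.get(1) + z2.get(1));
    }

    // Primitive representation: components live directly in a double[],
    // as in the Java version of plus from the question.
    static double[] plusPrimitive(double[] z1, double[] z2) {
        return new double[] { z1[0] + z2[0], z1[1] + z2[1] };
    }
}
```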

See this blog post on accelerating primitive arithmetic for some more details.

If you really want fast complex numbers in Clojure, you will probably need to implement them using deftype, something like:

(deftype Complex [^double real ^double imag])

And then define all your complex functions using this type. This will enable you to use primitive arithmetic throughout, and should be roughly equivalent to the performance of well-written Java code.
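To picture what deftype buys you: it generates a real JVM class with final fields, and ^double hints make those fields primitive doubles. The Java class below is a hand-written approximation for illustration only, not the bytecode Clojure emits. Its plus and times read primitive fields and allocate exactly one object per result, with no vectors and no boxed Doubles.

```java
// Hand-written Java approximation of (deftype Complex [^double real ^double imag]).
// The class Clojure generates also implements clojure.lang.IType and any
// protocols/interfaces you list; those details are omitted here.
final class Complex {
    public final double real;
    public final double imag;

    Complex(double real, double imag) {
        this.real = real;
        this.imag = imag;
    }

    // (a + bi) + (c + di) = (a + c) + (b + d)i
    static Complex plus(Complex z1, Complex z2) {
        return new Complex(z1.real + z2.real, z1.imag + z2.imag);
    }

    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    static Complex times(Complex z1, Complex z2) {
        return new Complex(z1.real * z2.real - z1.imag * z2.imag,
                           z1.real * z2.imag + z1.imag * z2.real);
    }
}
```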

Climax answered 6/8, 2012 at 8:49
Comments:
I think defrecord is recommended over deftype for simple types like this. - Trouvaille
@Trouvaille - I may be wrong but I believe you will get better performance from deftype - it has (slightly) less overhead than defrecord. defrecord implements full map-like behaviour and is more suited for "business object data" whereas deftype is more suited for slightly lower level data types. - Climax
Thanks. I wondered about deftype/defrecord but thought they might introduce even more overhead; I'll give deftype a try (and the stuff in that blog post) and report back. - Jezebel
@Climax - You're definitely right about deftype being lower-level, but I don't think that using defrecord in place of deftype will necessarily introduce any extra overhead. Implementing extra interfaces (e.g. IPersistentMap) won't hurt you if you don't call any of the methods. Using deftype in place of defrecord would prevent you from doing keyword lookups and destructuring on instances, which might be useful in less performance-critical parts of the code. - Trouvaille
^:static isn't relevant to typehinting, and hasn't been since at least 1.2. As of 1.3, you can have primitive typehints as function arguments; however, OP doesn't have that because he is accepting vectors, not primitives, and doubles have to be boxed to fit in. All that aside, I agree with your eventual recommendation to use deftype. - Anglocatholic
Thanks Alan - my info on ^:static was somewhat out of date, have updated the answer. Do you know where this kind of stuff is documented? - Climax
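The distinction in these comments has a Java counterpart. Since Clojure 1.3, a function like (defn add ^double [^double a ^double b] (+ a b)) is compiled with an additional primitive entry point (invokePrim on an interface such as clojure.lang.IFn$DDD), so hinted callers can pass raw doubles. A function that receives vectors only ever sees objects. The sketch uses the standard java.util.function interfaces as stand-ins, not what Clojure itself uses, and PrimitiveFnDemo is a hypothetical name. BinaryOperator<Double> boxes its arguments and result; DoubleBinaryOperator does not.

```java
import java.util.function.BinaryOperator;
import java.util.function.DoubleBinaryOperator;

final class PrimitiveFnDemo {
    // Boxed: arguments and result are Double objects, like a Clojure fn
    // without primitive hints (or one that gets its numbers out of a vector).
    static final BinaryOperator<Double> ADD_BOXED = (a, b) -> a + b;

    // Primitive: arguments and result stay unboxed doubles, loosely like a
    // Clojure 1.3+ fn whose args and return value carry ^double hints.
    static final DoubleBinaryOperator ADD_PRIMITIVE = (a, b) -> a + b;

    // Sums 0 + 1 + ... + (n - 1); each call boxes its inputs and its result.
    static double sumBoxed(int n) {
        Double acc = 0.0;
        for (int i = 0; i < n; i++) {
            acc = ADD_BOXED.apply(acc, (double) i);
        }
        return acc;
    }

    // Same sum, but every value stays a primitive double.
    static double sumPrimitive(int n) {
        double acc = 0.0;
        for (int i = 0; i < n; i++) {
            acc = ADD_PRIMITIVE.applyAsDouble(acc, i);
        }
        return acc;
    }
}
```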
  • I don't know much about benchmarking, but it seems that you need to warm up the JVM before you start measuring. When you run in a REPL it's already warmed up; when you run as a script it isn't yet.

  • In Java you run both loops inside one method, and no methods other than plus and times are called. In Clojure you create an anonymous function and use repeatedly to call it, which builds a lazy sequence of results that dorun then walks. That takes time. You can replace it with dotimes.
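The same warm-up advice applies to the Java side of the comparison. HotSpot compiles hot methods only after they have run for a while, and System.currentTimeMillis is coarse for runs lasting a few milliseconds. A minimal Java sketch of a fairer measurement follows. WarmupBench, the five warm-up rounds and the checksum are illustrative choices, not a rigorous harness; a tool such as JMH (Java) or Criterium (Clojure) handles this properly. Its plain for loop is the Java analogue of dotimes: there is no lazy sequence of results to build and walk.

```java
final class WarmupBench {
    // Same operation as the question's Java times(); copied so the sketch is self-contained.
    static double[] times(double[] z1, double[] z2) {
        return new double[] { z1[0] * z2[0] - z1[1] * z2[1],
                              z1[0] * z2[1] + z1[1] * z2[0] };
    }

    // Runs `iterations` multiplications in a plain loop and returns a checksum,
    // which makes it harder for the JIT to discard the work as dead code.
    static double run(int iterations) {
        double[] z1 = { 1, 0 };
        double[] z2 = { 0, 1 };
        double checksum = 0;
        for (int i = 0; i < iterations; i++) {
            checksum += times(z1, z2)[1];
        }
        return checksum;
    }

    public static void main(String[] args) {
        int iterations = 100_000;
        // Untimed warm-up rounds give the JIT a chance to compile run() and times().
        for (int round = 0; round < 5; round++) {
            run(iterations);
        }
        long start = System.nanoTime();
        double checksum = run(iterations);
        long elapsed = System.nanoTime() - start;
        System.out.printf("checksum %.1f, elapsed %.3f ms%n", checksum, elapsed / 1e6);
    }
}
```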

My try:

(println "Warm up")
(time (dorun (repeatedly 100000 #(plus [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(times [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(plus [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(times [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(plus [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(times [1 0] [0 1]))))

(println "Try with dorun")
(time (dorun (repeatedly 100000 #(plus [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(times [1 0] [0 1]))))

(println "Try with dotimes")
(time (dotimes [_ 100000]
        (plus [1 0] [0 1])))

(time (dotimes [_ 100000]
        (times [1 0] [0 1])))

Results:

Warm up
"Elapsed time: 367.569195 msecs"
"Elapsed time: 493.547628 msecs"
"Elapsed time: 116.832979 msecs"
"Elapsed time: 46.862176 msecs"
"Elapsed time: 27.805174 msecs"
"Elapsed time: 28.584179 msecs"
Try with dorun
"Elapsed time: 26.540489 msecs"
"Elapsed time: 27.64626 msecs"
Try with dotimes
"Elapsed time: 7.3792 msecs"
"Elapsed time: 5.940705 msecs"
Percheron answered 6/8, 2012 at 8:56
Comments:
Thanks, that makes sense. I get similar results. - Jezebel
