Optimisation of recursive algorithm in Java

Background

I have an ordered set of data points stored as a TreeSet<DataPoint>. Each data point has a position and a Set of Event objects (HashSet<Event>).

There are four possible Event objects: A, B, C and D. Every DataPoint has two of these (e.g. A and C), except the first and last DataPoint objects in the set, whose event set T has size 1.

My algorithm finds the probability that a new DataPoint Q at position x in this set has Event q.

I do this by calculating a value S for this data set, then adding Q to the set and calculating S again. I then divide the second S by the first to isolate the probability for the new DataPoint Q.

Algorithm

The formula for calculating S is:

http://mathbin.net/equations/105225_0.png

where

http://mathbin.net/equations/105225_1.png

http://mathbin.net/equations/105225_2.png

for http://mathbin.net/equations/105225_3.png

and

http://mathbin.net/equations/105225_4.png

http://mathbin.net/equations/105225_5.png is an expensive probability function that depends only on its arguments (and http://mathbin.net/equations/105225_6.png), http://mathbin.net/equations/105225_7.png is the last DataPoint in the set (the righthand node), http://mathbin.net/equations/105225_8.png is the first DataPoint (the lefthand node), http://mathbin.net/equations/105225_9.png is the rightmost DataPoint that isn't a node, http://mathbin.net/equations/105225_10.png is a DataPoint, and http://mathbin.net/equations/105225_12.png is the Set of events for this DataPoint.

So the probability for Q with Event q is:

http://mathbin.net/equations/105225_11.png
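Because the equation images above may not render, here is the recurrence as it can be read back from the Java implementation below. This is a reconstruction, so the symbols are my own labels and may differ from the original images: N_L and N_R are the lefthand and righthand nodes, t_R is the righthand node's Event (`getRightNodeEvent()`), ℓ(x) is the DataPoint immediately to the left of x (`points.lower(x)`), T(x) is the Set of Events of x, and ∅ stands for the `null` Event passed for the lefthand node.

```latex
\begin{aligned}
S &= f(N_R,\, t_R) \\[4pt]
f(x, e) &=
\begin{cases}
\tfrac{1}{4}\, p(x, e, N_L, \varnothing) & \text{if } \ell(x) = N_L \\
p(x, e, Q, q)\, f(Q, q) & \text{if } \ell(x) = Q \\
\displaystyle\sum_{e' \in T(\ell(x))} p\big(x, e, \ell(x), e'\big)\, f\big(\ell(x), e'\big) & \text{otherwise}
\end{cases} \\[4pt]
\Pr(Q \text{ has } q) &= \frac{S \text{ computed with } Q}{S \text{ computed without } Q}
\end{aligned}
```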

Implementation

I implemented this algorithm in Java like so:

import java.util.NavigableSet;

public class ProbabilityCalculator {
    // Expensive probability function; depends only on its arguments.
    private Double p(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
        Double result = 0.0;
        // do some stuff (the expensive calculation that sets result)
        return result;
    }
    
    // Sums over the Events of the DataPoint to the left of right, recursing leftwards.
    private Double f(DataPoint right, Event rightEvent, NavigableSet<DataPoint> points) {
        DataPoint left = points.lower(right);
        
        Double result = 0.0;
        
        if(left.isLefthandNode()) {
            result = 0.25 * p(right, rightEvent, left, null);
        } else if(left.isQ()) {
            result = p(right, rightEvent, left, left.getQEvent()) * f(left, left.getQEvent(), points);
        } else { // if M_k
            for(Event leftEvent : left.getEvents())
                result += p(right, rightEvent, left, leftEvent) * f(left, leftEvent, points);
        }
        
        return result;
    }
    
    public Double S(NavigableSet<DataPoint> points) {
        return f(points.last(), points.last().getRightNodeEvent(), points);
    }
}

So to find the probability of Q at x with q:

Double S1 = S(points);
points.add(Q);
Double S2 = S(points);
Double probability = S2/S1;

Problem

As it stands, the implementation follows the mathematical algorithm closely. In practice, however, this turns out to be a bad idea, because f calls itself twice for each DataPoint (once for each of its two Events). So for http://mathbin.net/equations/105225_9.png, f is called twice; for the (n-1)th DataPoint, f is called twice again for each of those calls, and so on. This gives a complexity of O(2^n), which is terrible considering there can be over 1000 DataPoints in each Set.

Because p() depends only on its parameters, I have added a cache: if p() has already been calculated for these parameters, it just returns the previous result. But this doesn't solve the inherent complexity problem. Am I missing something here with regard to repeated computations, or is the complexity unavoidable in this algorithm?
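To see the O(2^n) growth concretely, here is a self-contained sketch that mirrors the structure of f above and counts its calls. Positions are plain ints (0 is the lefthand node), every interior point has the two Events A and C, and p() returns a constant; these are all hypothetical stand-ins for the real DataPoint, Event and p().

```java
public class NaiveRecursionCost {

    enum Event { A, B, C, D }

    static long calls = 0;

    // Hypothetical stand-in for the expensive p(); any pure function works for counting.
    static double p(int right, Event rightEvent, int left, Event leftEvent) {
        return 0.5;
    }

    // Mirrors the question's f(): position 0 is the lefthand node,
    // and every interior position has the two Events A and C.
    static double f(int right, Event rightEvent) {
        calls++;
        int left = right - 1;
        if (left == 0) {
            return 0.25 * p(right, rightEvent, left, null);
        }
        double result = 0.0;
        for (Event leftEvent : new Event[] { Event.A, Event.C }) {
            result += p(right, rightEvent, left, leftEvent) * f(left, leftEvent);
        }
        return result;
    }

    // Number of f() calls needed for S with n interior points (righthand node at n + 1).
    static long countCalls(int n) {
        calls = 0;
        f(n + 1, Event.A);
        return calls;
    }

    public static void main(String[] args) {
        for (int n = 5; n <= 20; n += 5) {
            System.out.println(n + " points -> " + countCalls(n) + " calls to f");
        }
    }
}
```

With n interior points this makes 2^(n+1) - 1 calls: 63 for n = 5, 2047 for n = 10 and 2097151 for n = 20.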

Belly asked 15/8, 2012 at 11:4
Comments:
Why not cache f as well? Just move the parameter points from a function parameter to a class member. – Tarrance
@Tarrance I think this would work even better if I stored the subset of points to the left of right; that way the cache would still be used after I add Q to the points, once Q has been passed in the process. – Belly
Yes, so the main function would run these operations: clear P cache, clear F cache, get S1, add Q, clear F cache, get S2. – Tarrance
I don't think it would need to clear the whole F cache, just the section to the right of Q. – Belly
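Belly's last point can be sketched with a NavigableMap keyed by DataPoint position. Since f for a DataPoint only looks at points to its left, inserting Q can only change cached values to its right, and `tailMap(q, false).clear()` removes exactly those (the view is backed by the map, so clearing it deletes from the cache). Integer positions and a single cached value per position are simplifying assumptions; the class and method names are hypothetical.

```java
import java.util.NavigableMap;
import java.util.TreeMap;

public class PartialInvalidation {

    // Removes every cached f value for DataPoints strictly to the right of Q.
    // tailMap() returns a view backed by fCache, so clear() deletes from fCache itself.
    static void invalidateRightOf(NavigableMap<Integer, Double> fCache, int qPosition) {
        fCache.tailMap(qPosition, false).clear();
    }

    public static void main(String[] args) {
        // Hypothetical f cache keyed by DataPoint position (one value per position for brevity).
        NavigableMap<Integer, Double> fCache = new TreeMap<>();
        for (int pos = 10; pos <= 50; pos += 10) {
            fCache.put(pos, pos / 100.0);
        }

        invalidateRightOf(fCache, 25); // Q is inserted at position 25

        System.out.println(fCache.keySet()); // [10, 20]
    }
}
```

Values for positions 10 and 20 survive and can be reused when S is recomputed with Q.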

Thanks for all your suggestions. I implemented my solution by creating new nested classes, P and F, to serve as keys for the already-calculated values of p and f, and used a HashMap for each to store the results. The HashMap is queried before any computation takes place: if the result is present, it is returned directly; if not, it is computed and added to the HashMap.

The final product looks a bit like this:

import java.util.HashMap;
import java.util.Map;
import java.util.NavigableSet;
import java.util.TreeSet;

public class ProbabilityCalculator {

    private final NavigableSet<DataPoint> points;

    private ProbabilityCalculator(NavigableSet<DataPoint> points) {
        this.points = points;
    }

    // Cache key for p(): two DataPoints and their (possibly null) Events.
    private static class P {
        public final DataPoint left;
        public final Event leftEvent;
        public final DataPoint right;
        public final Event rightEvent;

        public P(DataPoint left, Event leftEvent, DataPoint right, Event rightEvent) {
            this.left = left;
            this.leftEvent = leftEvent;
            this.right = right;
            this.rightEvent = rightEvent;
        }

        @Override
        public boolean equals(Object o) {
            if(!(o instanceof P)) return false;
            P p = (P) o;

            if(!(this.leftEvent == null ? p.leftEvent == null : this.leftEvent.equals(p.leftEvent)))
                return false;
            if(!(this.rightEvent == null ? p.rightEvent == null : this.rightEvent.equals(p.rightEvent)))
                return false;

            return this.left.equals(p.left) && this.right.equals(p.right);
        }

        @Override
        public int hashCode() {
            int result = 93;

            result = 31 * result + this.left.hashCode();
            result = 31 * result + this.right.hashCode();
            result = this.leftEvent != null ? 31 * result + this.leftEvent.hashCode() : 31 * result;
            result = this.rightEvent != null ? 31 * result + this.rightEvent.hashCode() : 31 * result;

            return result;
        }
    }

    private final Map<P, Double> usedPs = new HashMap<P, Double>();

    // Cache key for f(): a DataPoint, its Event and the DataPoints to its left.
    private static class F {
        public final DataPoint dataPoint;
        public final Event dataPointEvent;
        public final NavigableSet<DataPoint> dataPointsToLeft;

        public F(DataPoint dataPoint, Event dataPointEvent, NavigableSet<DataPoint> dataPointsToLeft) {
            this.dataPoint = dataPoint;
            this.dataPointEvent = dataPointEvent;
            this.dataPointsToLeft = dataPointsToLeft;
        }

        @Override
        public boolean equals(Object o) {
            if(!(o instanceof F)) return false;
            F f = (F) o;
            return this.dataPoint.equals(f.dataPoint) && this.dataPointEvent.equals(f.dataPointEvent) && this.dataPointsToLeft.equals(f.dataPointsToLeft);
        }

        @Override
        public int hashCode() {
            int result = 7;

            result = 31 * result + this.dataPoint.hashCode();
            result = 31 * result + this.dataPointEvent.hashCode();
            result = 31 * result + this.dataPointsToLeft.hashCode();

            return result;
        }
    }

    private final Map<F, Double> usedFs = new HashMap<F, Double>();

    private Double p(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
        P newP = new P(left, leftEvent, right, rightEvent);

        if(usedPs.containsKey(newP)) return usedPs.get(newP);

        Double result = 0.0;
        // do some stuff (the expensive calculation that sets result)

        usedPs.put(newP, result);
        return result;
    }

    private Double f(DataPoint right, Event rightEvent) {
        // headSet() returns a live view backed by points; copy it so the cache key
        // does not change when Q is added later.
        NavigableSet<DataPoint> dataPointsToLeft = new TreeSet<DataPoint>(points.headSet(right, false));

        F newF = new F(right, rightEvent, dataPointsToLeft);

        if(usedFs.containsKey(newF)) return usedFs.get(newF);

        DataPoint left = points.lower(right);

        Double result = 0.0;

        if(left.isLefthandNode()) {
            result = 0.25 * p(right, rightEvent, left, null);
        } else if(left.isQ()) {
            result = p(right, rightEvent, left, left.getQEvent()) * f(left, left.getQEvent());
        } else { // if M_k
            for(Event leftEvent : left.getEvents())
                result += p(right, rightEvent, left, leftEvent) * f(left, leftEvent);
        }

        usedFs.put(newF, result);

        return result;
    }

    public Double S() {
        return f(points.last(), points.last().getRightNodeEvent());
    }

    public static Double probabilityOfQ(DataPoint q, NavigableSet<DataPoint> points) {
        ProbabilityCalculator pc = new ProbabilityCalculator(points);

        Double S1 = pc.S();

        points.add(q);

        // Cached f values for points left of q keep the same key and are reused here.
        Double S2 = pc.S();

        return S2/S1;
    }
}
Belly answered 15/8, 2012 at 15:30
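The F cache key in this answer is built from `points.headSet(right, false)`. `headSet()` returns a view backed by the original set rather than a copy, so adding Q changes the contents, and therefore the `hashCode()`, of any key whose view now includes Q; the `java.util.Map` documentation leaves map behaviour unspecified when a key changes in a way that affects `equals`. Copying the view into a new `TreeSet` gives a fixed snapshot. A small demonstration, with Integer positions standing in for DataPoints (an assumption for illustration):

```java
import java.util.NavigableSet;
import java.util.TreeSet;

public class HeadSetViewDemo {
    public static void main(String[] args) {
        // Integers stand in for DataPoint positions.
        NavigableSet<Integer> points = new TreeSet<>();
        points.add(10);
        points.add(30);
        points.add(50);

        NavigableSet<Integer> view = points.headSet(50, false); // live view of [10, 30]
        NavigableSet<Integer> snapshot = new TreeSet<>(view);   // independent copy of [10, 30]
        int viewHashBefore = view.hashCode();

        points.add(20); // "add Q" to the left of 50

        System.out.println(view);                               // [10, 20, 30]
        System.out.println(snapshot);                           // [10, 30]
        System.out.println(view.hashCode() == viewHashBefore);  // false
    }
}
```

The snapshot costs an O(n) copy per key, but hashing or comparing the view is already O(n), so the asymptotic cost of the F cache does not change.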

You also need to memoize f on the first 2 arguments (the 3rd is always passed through, so you don't need to worry about that). This will reduce the time complexity of your code from O(2^n) to O(n).

Velez answered 15/8, 2012 at 15:3
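A minimal sketch of memoizing f on its first two arguments, using hypothetical stand-ins: int positions with 0 as the lefthand node, a toy Event enum where every interior point has A and C, and a constant p(). The Key class plays the role of a composite (DataPoint, Event) key; only the caching pattern is meant to carry over to the real classes.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

public class MemoizedRecursion {

    enum Event { A, B, C, D }

    // Memo key: the first two arguments of f (DataPoint position and its Event).
    static final class Key {
        final int position;
        final Event event;

        Key(int position, Event event) {
            this.position = position;
            this.event = event;
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof Key)) return false;
            Key k = (Key) o;
            return position == k.position && event == k.event;
        }

        @Override
        public int hashCode() {
            return Objects.hash(position, event);
        }
    }

    final Map<Key, Double> fCache = new HashMap<>();
    long evaluations = 0; // f() calls that were not answered from the cache

    // Hypothetical stand-in for the expensive p().
    double p(int right, Event rightEvent, int left, Event leftEvent) {
        return 0.5;
    }

    // Same recurrence as the question's f(), but each (position, event) pair is computed once.
    double f(int right, Event rightEvent) {
        Key key = new Key(right, rightEvent);
        Double cached = fCache.get(key);
        if (cached != null) return cached;

        evaluations++;
        int left = right - 1;
        double result;
        if (left == 0) {
            result = 0.25 * p(right, rightEvent, left, null);
        } else {
            result = 0.0;
            for (Event leftEvent : new Event[] { Event.A, Event.C }) {
                result += p(right, rightEvent, left, leftEvent) * f(left, leftEvent);
            }
        }
        fCache.put(key, result);
        return result;
    }

    public static void main(String[] args) {
        MemoizedRecursion m = new MemoizedRecursion();
        double s = m.f(1001, Event.A); // 1000 interior points plus the righthand node
        System.out.println("S = " + s + " after " + m.evaluations + " evaluations of f");
        // prints "S = 0.125 after 2001 evaluations of f"
    }
}
```

For n interior points only 2n + 1 evaluations of f are needed, instead of 2^(n+1) - 1. The explicit get/put is used rather than `HashMap.computeIfAbsent`, because since Java 9 `HashMap` may throw `ConcurrentModificationException` when the mapping function (here, the recursive f) modifies the same map.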

UPDATED:

As the comment below points out, the ordering can't be used to help optimize this, so another approach is needed. Since most of the p values will be calculated multiple times (and, as noted, p is expensive), one optimization is to cache them. I am not sure what the best key would be, but you could change the code to something like:

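// ... (rest of ProbabilityCalculator unchanged)
private Map<String, Double> previousResultMap = new HashMap<String, Double>();

private Double p(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
   // calculate a unique key from the inputs; this assumes the toString() of each
   // DataPoint and Event identifies it uniquely (a null Event becomes "null")
   String key = right + "|" + rightEvent + "|" + left + "|" + leftEvent;
   Double previousResult = previousResultMap.get(key);
   if (previousResult != null) {
      return previousResult;
   }

   Double result = 0.0;
   // do some stuff (the expensive calculation that sets result)
   previousResultMap.put(key, result);
   return result;
}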

This approach should eliminate a lot of the redundant calculations. However, as you know the data much better than I do, you will need to decide the best way to build the key (and whether String is even the best representation for it).

Bandwagon answered 15/8, 2012 at 11:32
Comments:
If I understand you correctly, I don't think this will work, as S does not depend only on the number of points. If I put Q between points x and y, then S without Q will call p(x, xEvent, y, yEvent), whereas S with Q will call p(x, xEvent, q, qEvent) and then p(q, qEvent, y, yEvent). But I could run both S calculations at the same time and only diverge when one of them reaches Q. – Belly
You should look at my own answer! – Belly