What determines the number of threads a Java ForkJoinPool creates?
As far as I understood ForkJoinPool, that pool creates a fixed number of threads (default: number of cores) and never creates more threads (unless the application indicates a need for them by using managedBlock).

However, using ForkJoinPool.getPoolSize() I discovered that in a program that creates 30,000 tasks (RecursiveTask), the ForkJoinPool executing those tasks uses 700 threads on average (threads counted each time a task is created). The tasks don't do I/O, but pure computation; the only inter-task synchronization is calling ForkJoinTask.join() and accessing AtomicBooleans, i.e. there are no thread-blocking operations.

Since join() does not block the calling thread, as I understand it, there is no reason why any thread in the pool should ever block, and so (I had assumed) there should be no reason to create any further threads (which is obviously happening nevertheless).

So, why does ForkJoinPool create so many threads? What factors determine the number of threads created?

I had hoped that this question could be answered without posting code, but here it comes upon request. This code is an excerpt from a program of four times the size, reduced to the essential parts; it does not compile as it is. If desired, I can of course post the full program, too.

The program searches a maze for a path from a given start point to a given end point using depth-first search. A solution is guaranteed to exist. The main logic is in the compute() method of SolverTask: a RecursiveTask that starts at some given point and continues with all neighbor points reachable from the current point. Rather than creating a new SolverTask at each branching point (which would create far too many tasks), it pushes all neighbors except one onto a backtracking stack to be processed later and continues with only the one neighbor not pushed to the stack. Once it reaches a dead end that way, the point most recently pushed to the backtracking stack is popped, and the search continues from there (cutting back the path built from the task's starting point accordingly).

A new task is created once a task finds its backtracking stack larger than a certain threshold; from that time on, the task, while continuing to pop from its backtracking stack until that is exhausted, will not push any further points onto its stack when reaching a branching point, but instead create a new task for each such point. Thus, the size of the tasks can be adjusted using the stack limit threshold.

The numbers I quoted above ("30,000 tasks, 700 threads on average") are from searching a maze of 5000x5000 cells. So, here is the essential code:

class SolverTask extends RecursiveTask<ArrayDeque<Point>> {
// Once the backtrack stack has reached this size, the current task
// will never add another cell to it, but create a new task for each
// newly discovered branch:
private static final int MAX_BACKTRACK_CELLS = 100*1000;

/**
 * Tries to compute a path through the maze from localStart to end.
 * @return the path found, or null if no such path was found
 */
@Override
public ArrayDeque<Point>  compute() {
    // Is this task still accepting new branches for processing on its own,
    // or will it create new tasks to handle those?
    boolean stillAcceptingNewBranches = true;
    Point current = localStart;
    ArrayDeque<Point> pathFromLocalStart = new ArrayDeque<Point>();  // Path from localStart to (including) current
    ArrayDeque<PointAndDirection> backtrackStack = new ArrayDeque<PointAndDirection>();
    // Used as a stack: Branches not yet taken; solver will backtrack to these branching points later

    Direction[] allDirections = Direction.values();

    while (!current.equals(end)) {
        pathFromLocalStart.addLast(current);
        // Collect current's unvisited neighbors in random order: 
        ArrayDeque<PointAndDirection> neighborsToVisit = new ArrayDeque<PointAndDirection>(allDirections.length);  
        for (Direction directionToNeighbor: allDirections) {
            Point neighbor = current.getNeighbor(directionToNeighbor);

            // contains() and hasPassage() are read-only methods and thus need no synchronization
            if (maze.contains(neighbor) && maze.hasPassage(current, neighbor) && maze.visit(neighbor))
                neighborsToVisit.add(new PointAndDirection(neighbor, directionToNeighbor.opposite));
        }
        // Process unvisited neighbors
        if (neighborsToVisit.size() == 1) {
            // Current node is no branch: Continue with that neighbor
            current = neighborsToVisit.getFirst().getPoint();
            continue;
        }
        if (neighborsToVisit.size() >= 2) {
            // Current node is a branch
            if (stillAcceptingNewBranches) {
                current = neighborsToVisit.removeLast().getPoint();
                // Push all neighbors except one on the backtrack stack for later processing
                for(PointAndDirection neighborAndDirection: neighborsToVisit) 
                    backtrackStack.push(neighborAndDirection);
                if (backtrackStack.size() > MAX_BACKTRACK_CELLS)
                    stillAcceptingNewBranches = false;
                // Continue with the one neighbor that was not pushed onto the backtrack stack
                continue;
            } else {
                // Current node is a branch point, but this task does not accept new branches any more: 
                // Create new task for each neighbor to visit and wait for the end of those tasks
                SolverTask[] subTasks = new SolverTask[neighborsToVisit.size()];
                int t = 0;
                for(PointAndDirection neighborAndDirection: neighborsToVisit)  {
                    SolverTask task = new SolverTask(neighborAndDirection.getPoint(), end, maze);
                    task.fork();
                    subTasks[t++] = task;
                }
                for (SolverTask task: subTasks) {
                    ArrayDeque<Point> subTaskResult = null;
                    try {
                        subTaskResult = task.join();
                    } catch (CancellationException e) {
                        // Nothing to do here: Another task has found the solution and cancelled all other tasks
                    }
                    catch (Exception e) {
                        e.printStackTrace();
                    }
                    if (subTaskResult != null) { // subtask found solution
                        pathFromLocalStart.addAll(subTaskResult);
                        // No need to wait for the other subtasks once a solution has been found
                        return pathFromLocalStart;
                    }
                } // for subTasks
            } // else (not accepting any more branches) 
        } // if (current node is a branch)
        // Current node is dead end or all its neighbors lead to dead ends:
        // Continue with a node from the backtracking stack, if any is left:
        if (backtrackStack.isEmpty()) {
            return null; // No more backtracking available: No solution exists => end of this task
        }
        // Backtrack: Continue with cell saved at latest branching point:
        PointAndDirection pd = backtrackStack.pop();
        current = pd.getPoint();
        Point branchingPoint = current.getNeighbor(pd.getDirectionToBranchingPoint());
        // DEBUG System.out.println("Backtracking to " +  branchingPoint);
        // Remove the dead end from the top of pathSoFar, i.e. all cells after branchingPoint:
        while (!pathFromLocalStart.peekLast().equals(branchingPoint)) {
            // DEBUG System.out.println("    Going back before " + pathSoFar.peekLast());
            pathFromLocalStart.removeLast();
        }
        // continue while loop with newly popped current
    } // while (current ...
    if (!current.equals(end)) {         
        // this task was interrupted by another one that already found the solution 
        // and should end now therefore:
        return null;
    } else {
        // Found the solution path:
        pathFromLocalStart.addLast(current);
        return pathFromLocalStart;
    }
} // compute()
} // class SolverTask

@SuppressWarnings("serial")
public class ParallelMaze  {

// for each cell in the maze: Has the solver visited it yet?
private final AtomicBoolean[][] visited;

/**
 * Atomically marks this point as visited unless visited before
 * @return whether the point was visited for the first time, i.e. whether it could be marked
 */
boolean visit(Point p) {
    return  visited[p.getX()][p.getY()].compareAndSet(false, true);
}

public static void main(String[] args) throws InterruptedException {
    ForkJoinPool pool = new ForkJoinPool();
    ParallelMaze maze = new ParallelMaze(width, height, new Point(width-1, 0), new Point(0, height-1));
    // Start initial task
    long startTime = System.currentTimeMillis();
    // Since SolverTask.compute() expects its starting point already visited,
    // we must do that explicitly for the global starting point:
    maze.visit(maze.start);
    maze.solution = pool.invoke(new SolverTask(maze.start, maze.end, maze));
    // One solution is enough: Stop all tasks that are still running
    pool.shutdownNow();
    pool.awaitTermination(Integer.MAX_VALUE, TimeUnit.DAYS);
    long endTime = System.currentTimeMillis();
    System.out.println("Computed solution of length " + maze.solution.size() + " to maze of size " + 
            width + "x" + height + " in " + ((float)(endTime - startTime))/1000 + "s.");
} // main()

// Fields width, height, start, end, solution and the constructor are omitted in this excerpt.
} // class ParallelMaze
Lorenelorens answered 29/5, 2012 at 10:44

There are related questions on Stack Overflow:

ForkJoinPool stalls during invokeAll/join

ForkJoinPool seems to waste a thread

I made a runnable, stripped-down version of what is happening (JVM arguments I used: -Xms256m -Xmx1024m -Xss8m):

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

public class Test1 {

    private static ForkJoinPool pool = new ForkJoinPool(2);

    private static class SomeAction extends RecursiveAction {

        private int counter;         //recursive counter
        private int childrenCount=80;//amount of children to spawn
        private int idx;             // just for displaying

        private SomeAction(int counter, int idx) {
            this.counter = counter;
            this.idx = idx;
        }

        @Override
        protected void compute() {

            System.out.println(
                "counter=" + counter + "." + idx +
                " activeThreads=" + pool.getActiveThreadCount() +
                " runningThreads=" + pool.getRunningThreadCount() +
                " poolSize=" + pool.getPoolSize() +
                " queuedTasks=" + pool.getQueuedTaskCount() +
                " queuedSubmissions=" + pool.getQueuedSubmissionCount() +
                " parallelism=" + pool.getParallelism() +
                " stealCount=" + pool.getStealCount());
            if (counter <= 0) return;

            List<SomeAction> list = new ArrayList<>(childrenCount);
            for (int i=0;i<childrenCount;i++){
                SomeAction next = new SomeAction(counter-1,i);
                list.add(next);
                next.fork();
            }


            for (SomeAction action:list){
                action.join();
            }
        }
    }

    public static void main(String[] args) throws Exception{
        pool.invoke(new SomeAction(2,0));
    }
}

Apparently, when you perform a join, the current thread sees that the required task is not yet completed and takes another task for itself to do.

It happens in java.util.concurrent.ForkJoinWorkerThread#joinTask.

However, this new task spawns more of the same tasks, and they cannot find free threads in the pool, because the existing threads are blocked in join. And since the pool has no way of knowing how long it will take for them to be released (a thread could be in an infinite loop or deadlocked forever), new threads are spawned (compensating for the blocked joiners, as Louis Wasserman mentioned): java.util.concurrent.ForkJoinPool#signalWork

So, to prevent such a scenario, you need to avoid recursive spawning of tasks.

For example, if in the above code you set the initial counter parameter to 1, the number of active threads will stay at 2, even if you increase childrenCount tenfold.
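That is, changing only the invocation in main() of the Test1 class above (a minimal illustration, not a general fix):

// With counter=1 the forked children (counter=0) never fork further tasks,
// so the joining thread can simply run the queued children itself and the
// pool stays at its configured parallelism of 2.
pool.invoke(new SomeAction(1, 0));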

Also note that, while the number of active threads increases, the number of running threads stays less than or equal to parallelism.

Divulgate answered 21/11, 2013 at 17:54

From the source comments:

Compensating: Unless there are already enough live threads, method tryPreBlock() may create or re-activate a spare thread to compensate for blocked joiners until they unblock.

I think what's happening is that you're not finishing any of the tasks very quickly, and since there aren't available worker threads when you submit a new task, a new thread gets created.
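The managedBlock mechanism the question mentions is the explicit form of the same compensation: a task that knows it is about to block wraps the wait in a ForkJoinPool.ManagedBlocker, and the pool may then activate or create a spare thread, just as it does internally for joins it cannot satisfy by running queued work. A minimal sketch (the LatchBlocker class and the CountDownLatch are made up for illustration, not taken from the question's code):

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;

// Lets a ForkJoinPool task wait on an external latch while telling the pool
// that it is blocked, so the pool may compensate with a spare thread instead
// of silently losing parallelism.
class LatchBlocker implements ForkJoinPool.ManagedBlocker {

    private final CountDownLatch latch;

    LatchBlocker(CountDownLatch latch) {
        this.latch = latch;
    }

    @Override
    public boolean block() throws InterruptedException {
        latch.await();                 // actually block
        return true;                   // no further blocking necessary
    }

    @Override
    public boolean isReleasable() {
        return latch.getCount() == 0;  // true => would not need to block at all
    }
}

// Inside a task running in the pool:
// ForkJoinPool.managedBlock(new LatchBlocker(latch));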

Bullion answered 29/5, 2012 at 10:59

Strict, full-strict, and terminally-strict have to do with processing a directed acyclic graph (DAG). You can google those terms to get a full understanding of them. That is the type of processing the framework was designed for. Look at the code in the API for RecursiveTask/RecursiveAction: the framework relies on your compute() code to fork other compute() calls and then do a join(). Each task does a single join(), just like processing a DAG.

You are not doing DAG processing. You are forking many new tasks and waiting (join()) on each. Have a read of the source code. It's horrendously complex, but you may be able to figure it out. The framework does not do proper task management. Where is it going to put the waiting task when it does a join()? There is no suspended queue; that would require a monitor thread to constantly look at the queue to see what is finished. This is why the framework uses "continuation threads". When one task does join(), the framework assumes it is waiting for a single lower task to finish. When many join() calls are pending, the thread cannot continue, so a helper or continuation thread needs to exist.

As noted above, you need a scatter-gather type fork-join process. There you can fork as many tasks as you like and gather the results without each pool thread blocking in a join().
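A rough sketch of what such a scatter-gather shape could look like with a plain ForkJoinPool (the solveBranch method and the branch count are placeholders, not taken from the question's code): each branch becomes an independent task submitted from outside the pool, and only the caller waits for the results, so no pool thread ever blocks in join().

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;

public class ScatterGatherSketch {

    // Placeholder for the real per-branch work; it must not join() other tasks.
    static List<String> solveBranch(int branch) {
        return List.of("result of branch " + branch);
    }

    public static void main(String[] args) throws Exception {
        ForkJoinPool pool = new ForkJoinPool();
        List<Future<List<String>>> futures = new ArrayList<>();

        // Scatter: one independent task per branch, submitted by the caller.
        for (int branch = 0; branch < 8; branch++) {
            final int b = branch;
            futures.add(pool.submit((Callable<List<String>>) () -> solveBranch(b)));
        }

        // Gather: only the calling (non-pool) thread waits for the results.
        for (Future<List<String>> f : futures) {
            System.out.println(f.get());
        }
        pool.shutdown();
    }
}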

Trebuchet answered 1/6, 2012 at 14:9
Where would one find the scatter-gather type fork-join process? Have you had a chance to look at Scala, and does it do any better? Just about to dig into this article, perhaps I'll find my answer there: A Java™ Fork-Join Conqueror – Gudgeon
Question remains on Scala ... on the other question, just downloaded TymeacSE for Java and will give it a try. Thanks, interesting article in my previous comment. – Gudgeon

Both code snippets posted by Holger Peine and by elusive-code don't actually follow the recommended practice that appeared in the javadoc for the 1.8 version:

In the most typical usages, a fork-join pair act like a call (fork) and return (join) from a parallel recursive function. As is the case with other forms of recursive calls, returns (joins) should be performed innermost-first. For example, a.fork(); b.fork(); b.join(); a.join(); is likely to be substantially more efficient than joining code a before code b.

In both cases the ForkJoinPool was instantiated via the default constructor. This leads to construction of the pool with asyncMode = false, which is the default:

@param asyncMode if true, establishes local first-in-first-out scheduling mode for forked tasks that are never joined. This mode may be more appropriate than default locally stack-based mode in applications in which worker threads only process event-style asynchronous tasks. For default value, use false.

That way the work queue is actually LIFO:
head -> | t4 | t3 | t2 | t1 | ... | <- tail

So in the snippets they fork() all tasks, pushing them onto that stack, and then join() them in the same order, that is, from the deepest task (t1) to the topmost (t4), effectively blocking until some other thread steals t1, then t2, and so on. Since there are enough tasks to block all pool threads (task_count >> pool.getParallelism()), compensation kicks in, as Louis Wasserman described.
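A rough sketch of the recommended ordering, applied to the Test1 example above (this merely follows the javadoc advice; it reduces the chance of blocking joins and the compensation threads they trigger, but does not guarantee a fixed pool size):

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

public class Test2 {

    private static final ForkJoinPool pool = new ForkJoinPool(2);

    private static class SomeAction extends RecursiveAction {

        private final int counter;               // remaining recursion depth
        private final int childrenCount = 80;     // children to spawn per level

        private SomeAction(int counter) {
            this.counter = counter;
        }

        @Override
        protected void compute() {
            System.out.println("counter=" + counter + " poolSize=" + pool.getPoolSize());
            if (counter <= 0) return;

            List<SomeAction> children = new ArrayList<>(childrenCount);
            for (int i = 0; i < childrenCount; i++) {
                SomeAction next = new SomeAction(counter - 1);
                children.add(next);
                next.fork();
            }
            // Join innermost-first, i.e. in reverse order of forking: the task
            // forked last is the one most likely to still sit on top of this
            // worker's LIFO deque, so join() can pop and run it directly
            // instead of blocking on a task that another thread has stolen.
            for (int i = children.size() - 1; i >= 0; i--) {
                children.get(i).join();
            }
        }
    }

    public static void main(String[] args) {
        pool.invoke(new SomeAction(2));
    }
}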

Thermosetting answered 12/7, 2017 at 21:14

It is worth noting that the output of the code posted by elusive-code depends on the Java version. Running the code on Java 8, I see this output:

...
counter=0.73 activeThreads=45 runningThreads=5 poolSize=49 queuedTasks=105 queuedSubmissions=0 parallelism=2 stealCount=3056
counter=0.75 activeThreads=46 runningThreads=1 poolSize=51 queuedTasks=0 queuedSubmissions=0 parallelism=2 stealCount=3158
counter=0.77 activeThreads=47 runningThreads=3 poolSize=51 queuedTasks=0 queuedSubmissions=0 parallelism=2 stealCount=3157
counter=0.74 activeThreads=45 runningThreads=3 poolSize=51 queuedTasks=5 queuedSubmissions=0 parallelism=2 stealCount=3153

But running the same code on Java 11, the output is different:

...
counter=0.75 activeThreads=1 runningThreads=1 poolSize=2 queuedTasks=4 queuedSubmissions=0 parallelism=2 stealCount=0
counter=0.76 activeThreads=1 runningThreads=1 poolSize=2 queuedTasks=3 queuedSubmissions=0 parallelism=2 stealCount=0
counter=0.77 activeThreads=1 runningThreads=1 poolSize=2 queuedTasks=2 queuedSubmissions=0 parallelism=2 stealCount=0
counter=0.78 activeThreads=1 runningThreads=1 poolSize=2 queuedTasks=1 queuedSubmissions=0 parallelism=2 stealCount=0
counter=0.79 activeThreads=1 runningThreads=1 poolSize=2 queuedTasks=0 queuedSubmissions=0 parallelism=2 stealCount=0
Philosophical answered 6/12, 2019 at 11:59
