How can I rewrite this main thread - worker threads synchronization
I have a program that goes something like this:

public class Test implements Runnable
{
    public        int local_counter;
    public static int global_counter;
    // Barrier waits for as many threads as we launch + main thread
    public static CyclicBarrier thread_barrier = new CyclicBarrier (n_threads + 1);

    /* Constructors etc. */

    public void run()
    {
        for (int i=0; i<100; i++)
        {
            thread_barrier.await();
            local_counter = 0;
            for(int j=0 ; j < 20 ; j++)
                local_counter++;
            thread_barrier.await();
        }
    }

    public static void main(String[] args)
    {
        /* Create and launch some threads, stored on thread_array */
        for(int i=0 ; i<100 ; i++)
        {
            thread_barrier.await();
            thread_barrier.await();

            for (int t=1; t<thread_array.length; t++)
            {
                global_counter += thread_array[t].local_counter;
            }
        }
    }
}

Basically, I have a few threads, each with its own local counter, and I'm doing this (in a loop):

        |----|           |           |----|
        |main|           |           |pool|
        |----|           |           |----|
                         |

-------------------------------------------------------
barrier (get local counters before they're overwritten)
-------------------------------------------------------
                         |
                         |   1. reset local counter
                         |   2. do some computations
                         |      involving local counter
                         |
-------------------------------------------------------
             barrier (synchronize all threads)
-------------------------------------------------------
                         |
1. update global counter |
   using each thread's   |
   local counter         |

This should all be fine and dandy, but it turns out it doesn't scale well. On a cluster with 16 physical nodes, the speedup beyond 6-8 threads is negligible, so I need to get rid of one of the awaits. I've tried CyclicBarrier, which scales awfully; Semaphores, which do no better; and a custom library (jbarrier), which works great until there are more threads than physical cores, at which point it performs worse than the sequential version. But I just can't come up with a way of doing this without stopping all threads twice.

EDIT: while I appreciate any and all insight you might have concerning other possible bottlenecks in my program, I'm looking for an answer to this particular issue. I can provide a more specific example if needed.

Atreus asked 11/4, 2018 at 0:12 Comment(13)
Jacksmelt: How complex is the problem you are solving? Would something like CountDownLatch help, since it may reduce the complexity of the solution?
Atreus: @HarisNadeem the problem is that CountDownLatch is designed to be used once, whereas I use this barrier continuously in a loop. I guess I could create a new CountDownLatch every iteration; I haven't tried it, but I don't think it would be efficient.
Jacksmelt: Yes, it wouldn't be efficient to create a new one every time. If you don't mind, I have some questions about the size of the problem and the hardware. Is it possible that the problem you are trying to solve is memory-intensive? If so, adding threads would only increase the memory load and slow down memory access for the other threads. Is it possible that I/O is involved in the threads? If so, that could be a bottleneck that adding threads past a certain point cannot resolve.
Electorate: This looks like a standard producer-consumer problem to me. Why can't the threads compute their results independently and put them in a queue along with their IDs, so that main can consume them? I am assuming the consumer is much faster than the producers.
Atreus: @Electorate It's not a producer-consumer problem. This is an oversimplification for the sake of brevity; in reality the worker threads are computing a cellular automaton, so they all have to be synced up every iteration.
Atreus: @HarisNadeem those are all understandable questions, but I know for a fact that the bottleneck is caused by the CyclicBarrier. What I'm doing is processing a cellular automaton in parallel, with each worker thread working on a portion of the automaton.
Electorate: @Atreus then you need to reflect that in your code. Does any information flow back from the main thread to the workers, or are they required to be in sync so that they can simulate a generation?
Electorate: If the latter, then the problem is that each thread does too little work for the task to take advantage of concurrency. A loop of 20 is too small. What you can do is batch several cells of the automaton into one thread and process them one after another.
Jacksmelt: @Kovalainen: I've been pondering the problem and think I may have pinpointed the speed issue. Is it possible that you are facing a similar problem (refer to the first graph)? The idea is that just a few threads are causing the slowdown, since EVERYONE has to wait until they are finished. If so, can you use await(timeout) in your design? Or just benchmark and check what getNumberWaiting() returns over a time interval? That would help a lot.
Jacksmelt: I do realize that since you are working on a cellular automaton, you would want to preserve state and not reset before all work is done, but I thought it was worth asking. Just out of curiosity (you don't have to entertain this question): what rule are you using in your CA?
Mcshane: Could you share a more detailed example? (It would be best to see the source or a runnable example.)
Pshaw: Having more of your code visible here is key. It would also help to know your performance with 16 threads (the number of cores) versus one, and what your memory settings are. Are you sure you aren't hitting GC limits due to the increased memory use of additional threads?
Brandon: Please share an MCVE reproducing the problem instead of simplified pseudocode. This is not a quiz show where people like to guess; this is SO, where dedicated developers like to help each other. They are not dedicated to wasting their time guessing, though.
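Electorate's suggestion to batch several cells per thread usually means giving each worker a contiguous block of rows of the grid rather than a few cells. A minimal sketch of one way to split the rows evenly, assuming a grid of rows rows shared by nThreads workers; rowRange is a hypothetical helper, not part of any library:

```java
public class RowPartition {
    // Hypothetical helper: half-open range [start, end) of rows given to worker `id`.
    // The first (rows % nThreads) workers each get one extra row.
    static int[] rowRange(int rows, int nThreads, int id) {
        int base = rows / nThreads;
        int extra = rows % nThreads;
        int start = id * base + Math.min(id, extra);
        int end = start + base + (id < extra ? 1 : 0);
        return new int[] {start, end};
    }

    public static void main(String[] args) {
        int rows = 10, nThreads = 4;
        for (int id = 0; id < nThreads; id++) {
            int[] r = rowRange(rows, nThreads, id);
            // each worker would process its whole block of rows per generation
            System.out.println("worker " + id + ": rows " + r[0] + ".." + (r[1] - 1));
        }
    }
}
```

With 10 rows and 4 workers this prints the blocks 0..2, 3..5, 6..7 and 8..9, so each barrier wait is amortized over a whole block of cells instead of a single short loop.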

A few fixes: your iteration over the threads should start with for(int t=0;...), assuming thread_array[0] should participate in the global counter sum. We can guess it's an array of Test, not of threads. local_counter should be volatile; otherwise you may not see the true value across the test threads and the main thread.

OK, now, as far as I can tell, you have a proper two-phase cycle. Anything else, like a Phaser, or a single CyclicBarrier with a new CountDownLatch at every loop, is just a variation on the same theme: getting numerous threads to agree to let the main thread resume, and getting the main thread to resume numerous threads in one shot.

A thinner implementation could use a ReentrantLock, a counter of arrived test threads, a condition to resume all test threads, and a condition to resume the main thread. The test thread that arrives when --count == 0 should signal the main-resume condition. All test threads await the test-resume condition. The main thread should reset the counter to N, call signalAll() on the test-resume condition, then await the main-resume condition. This way, each thread (test and main) awaits only once per loop.

Finally, if the end goal is a sum updated by any thread, you should look at LongAdder (if not AtomicLong) to add to a long concurrently without having to stop all threads (let them contend and add, without involving the main thread).
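A minimal sketch of that idea, assuming N = 4 workers, LOOPS = 1000 iterations, and the question's 20-step counter loop standing in for the real work. Each worker adds its local result straight into a shared LongAdder, so neither a barrier nor a gather by the main thread is needed for the sum itself. This only removes the summation step; if the workers must also stay in lockstep per generation (as in a cellular automaton), some barrier is still required.

```java
import java.util.concurrent.atomic.LongAdder;

public class LongAdderSum {
    static final int N = 4;        // number of worker threads (assumed)
    static final int LOOPS = 1000; // iterations per worker (assumed)

    // Shared sum: every worker adds into it directly, no barrier or main-thread gather
    static final LongAdder globalCounter = new LongAdder();

    public static void main(String[] args) throws InterruptedException {
        Thread[] workers = new Thread[N];
        for (int t = 0; t < N; t++) {
            workers[t] = new Thread(() -> {
                for (int i = 0; i < LOOPS; i++) {
                    int localCounter = 0;
                    for (int j = 0; j < 20; j++)
                        localCounter++;
                    // contended adds are spread over internal cells instead of one hot field
                    globalCounter.add(localCounter);
                }
            });
            workers[t].start();
        }
        for (Thread w : workers)
            w.join();
        // sum() is exact here because no thread is still adding
        System.out.println(globalCounter.sum());
    }
}
```

With these values the printed sum is 4 * 1000 * 20 = 80000.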

Otherwise, you can have the threads deliver their results to a BlockingQueue read by the main thread. There are just too many ways of doing this; I'm having a hard time understanding why you halt all threads to collect data. That's all. The question is oversimplified, and we don't have enough constraints to justify what you are doing.
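A minimal sketch of the queue variant, under the same assumptions (N = 4 workers, LOOPS = 1000 iterations, a counter loop in place of the real work); the main thread simply takes exactly N * LOOPS results off an ArrayBlockingQueue. Unlike the barrier versions, nothing stops a fast worker from running several iterations ahead of a slow one, so this only fits if the iterations are independent of each other.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;

public class QueueGather {
    static final int N = 4;        // number of worker threads (assumed)
    static final int LOOPS = 1000; // iterations per worker (assumed)
    static long globalCounter;

    public static void main(String[] args) throws InterruptedException {
        // Bounded queue: producers block in put() only when it is full
        BlockingQueue<Integer> results = new ArrayBlockingQueue<>(1024);
        for (int t = 0; t < N; t++) {
            new Thread(() -> {
                try {
                    for (int i = 0; i < LOOPS; i++) {
                        int localCounter = 0;
                        for (int j = 0; j < 20; j++)
                            localCounter++;
                        results.put(localCounter); // deliver the result, keep working
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }).start();
        }
        // The main thread consumes exactly as many results as were produced
        globalCounter = 0;
        for (int k = 0; k < N * LOOPS; k++)
            globalCounter += results.take();
        System.out.println(globalCounter);
    }
}
```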

Don't worry about CyclicBarrier itself: it is implemented with a ReentrantLock, a counter, and a Condition whose signalAll() trips all waiting threads. As far as I can tell, it is tightly coded. If you wanted a lock-free version, you would be facing too many busy-spin loops wasting CPU time, especially since you are concerned about scaling when there are more threads than cores.

Meanwhile, is it possible that you actually have 8 hyperthreaded cores that show up as 16 CPUs?
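One quick way to see what the JVM is working with is Runtime.availableProcessors(); it counts logical processors, so on a machine with 8 hyperthreaded cores it would typically report 16. Comparing that number with the physical core count from the hardware specification answers the question. A minimal sketch:

```java
public class CpuCount {
    static int reported;

    public static void main(String[] args) {
        // Logical processors visible to the JVM (each hyperthread counts separately)
        reported = Runtime.getRuntime().availableProcessors();
        System.out.println("JVM reports " + reported + " processors");
    }
}
```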

Once sanitized, your code looks like this:

package tests;

import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.stream.Stream;

public class Test implements Runnable {
    static final int n_threads = 8;
    static final long LOOPS = 10000;
    public static int global_counter;
    public static CyclicBarrier thread_barrier = new CyclicBarrier(n_threads + 1);

    public volatile int local_counter;

    @Override
    public void run() {
        try {
            runImpl();
        } catch (InterruptedException | BrokenBarrierException e) {
            //
        }
    }

    void runImpl() throws InterruptedException, BrokenBarrierException {
        for (int i = 0; i < LOOPS; i++) {
            thread_barrier.await();
            local_counter = 0;
            for (int j=0; j<20; j++)
                local_counter++;
            thread_barrier.await();
        }
    }

    public static void main(String[] args) throws InterruptedException, BrokenBarrierException {
        Test[] ra = new Test[n_threads];
        Thread[] ta = new Thread[n_threads];
        for(int i=0; i<n_threads; i++)
            (ta[i] = new Thread(ra[i] = new Test())).start();

        long nanos = System.nanoTime();
        for (int i = 0; i < LOOPS; i++) {
            thread_barrier.await();
            thread_barrier.await();

            for (int t=0; t<ra.length; t++) {
                global_counter += ra[t].local_counter;
            }
        }

        System.out.println(global_counter+", "+1e-6*(System.nanoTime()-nanos)+" ms");

        Stream.of(ta).forEach(t -> t.interrupt());
    }
}

My version with a single lock looks like this:

package tests;

import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Stream;

public class TwoPhaseCycle implements Runnable {
    static final boolean DEBUG = false;
    static final int N = 8;
    static final int LOOPS = 10000;

    static ReentrantLock lock = new ReentrantLock();
    static Condition testResume = lock.newCondition();
    static volatile long cycle = -1;
    static Condition mainResume = lock.newCondition();
    static volatile int testLeft = 0;

    static void p(Object msg) {
        System.out.println(Thread.currentThread().getName()+"] "+msg);
    }

    //-----
    volatile int local_counter;

    @Override
    public void run() {
        try {
            runImpl();
        } catch (InterruptedException e) {
            p("interrupted; ending.");
        }
    }

    public void runImpl() throws InterruptedException {
        lock.lock();
        try {
            if(DEBUG) p("waiting for 1st testResumed");
            while(cycle<0) {
                testResume.await();
            }
        } finally {
            lock.unlock();
        }

        long localCycle = 0; // number of cycles this worker has completed
        while(true) {
            if(DEBUG) p("working");
            local_counter = 0;
            for (int j = 0; j<20; j++)
                local_counter++;
            localCycle++;

            lock.lock();
            try {
                if(DEBUG) p("done");
                if(--testLeft <=0)
                    mainResume.signalAll(); //could have been just .signal() since only main is waiting, but safety first.

                if(DEBUG) p("waiting for cycle "+localCycle+" testResumed");
                while(cycle < localCycle) {
                    testResume.await();
                }
            } finally {
                lock.unlock();
            }
        }
    }

    public static void main(String[] args) throws InterruptedException {
        TwoPhaseCycle[] ra = new TwoPhaseCycle[N];
        Thread[] ta = new Thread[N];
        for(int i=0; i<N; i++)
            (ta[i] = new Thread(ra[i]=new TwoPhaseCycle(), "\t\t\t\t\t\t\t\t".substring(0, i%8)+"\tT"+i)).start();

        long nanos = System.nanoTime();

        int global_counter = 0;
        for (int i=0; i<LOOPS; i++) {
            lock.lock();
            try {
                if(DEBUG) p("gathering");
                for (int t=0; t<ra.length; t++) {
                    global_counter += ra[t].local_counter;
                }
                testLeft = N;
                cycle = i;
                if(DEBUG) p("resuming cycle "+cycle+" tests");
                testResume.signalAll();

                if(DEBUG) p("waiting for main resume");
                while(testLeft>0) {
                    mainResume.await();
                }
            } finally {
                lock.unlock();
            }
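        }

        // The loop gathers at the start of each cycle, so the last cycle's counters are collected here
        for (int t=0; t<ra.length; t++) {
            global_counter += ra[t].local_counter;
        }

        System.out.println(global_counter+", "+1e-6*(System.nanoTime()-nanos)+" ms");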

        p(global_counter);
        Stream.of(ta).forEach(t -> t.interrupt());
    }
}

Of course, this is by no means a stable microbenchmark, but the trend shows it's faster. Hope you like it. (I left in a few favorite debugging tricks; it's worth setting DEBUG to true.)

Resting answered 16/4, 2018 at 23:56 Comment(1)
Atreus: Sorry for neglecting this question; I found a suitable solution before anyone replied and forgot about it. "If the end goal is a sum updated by any threads, you should look at LongAdder (if not AtomicLong) to perform addition to a long concurrently without having to stop all threads (let them fight and add, not involving the main)": this is exactly what I ended up doing. For that, plus the rest of this very thorough list of improvements, I think it's worth marking this as the best answer.

Well, I'm not sure I fully understand, but I think your main problem is that you try to reuse a predefined set of threads too much. You should let Java take care of this (that's what executors and the fork-join pool are for). To solve your issue, a split/process/merge (or map/reduce) approach seems appropriate to me. Since Java 8, it's a really simple approach to implement (thanks to the Stream, ForkJoinPool and CompletableFuture APIs). I propose two alternatives here:

Java 8 Stream

To me, your problem looks like it can be reduced to a map/reduce problem. And if you can use Java 8 streams, you can delegate the performance concerns to them. What I'd do:
1. Create a parallel stream containing your processing input (you can even use methods to generate inputs on the fly). Note that you can implement your own Spliterator to fully control the traversal and splitting of your input (cells on a grid?).
2. Use a map to process the input.
3. Use a reduce method to merge all previously computed results.

Simple example (based on your example):

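import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

public class StreamExample {
    // Hypothetical stand-in for the real per-cell work: increments the counter `times` times
    static int incrementMultipleTimes(int counter, int times) {
        for (int i = 0; i < times; i++)
            counter++;
        return counter;
    }

    public static void main(String[] args) throws Exception {
        // Create a pool with the wanted number of threads
        final ForkJoinPool pool = new ForkJoinPool(4);
        // We give the entire procedure to the thread pool
        final int result = pool.submit(() -> {
            // Generate a hundred counters, initialized to 0
            return IntStream.generate(() -> 0)
                    .limit(100)
                    // Specify we want it processed in parallel
                    .parallel()
                    // The map registers the processing method
                    .map(in -> incrementMultipleTimes(in, 20))
                    // We ask for the processing results to be merged
                    .reduce((first, second) -> first + second)
                    .orElseThrow(() -> new IllegalArgumentException("Empty dataset"));
        })
                // Wait for the overall result
                .get();

        System.out.println("RESULT: " + result);

        pool.shutdown();
        pool.awaitTermination(10, TimeUnit.SECONDS);
    }
}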

Some things to be aware of:
1. By default, parallel streams execute tasks on the JVM's common fork-join pool, which may have a limited number of worker threads. But there are ways to use your own pool: see this answer.
2. If well configured, I think this is the best method, because the parallelism logic has been taken care of by the JDK developers themselves.

Phaser

If you cannot use Java 8 functionality (or I've misunderstood your problem, or you really want to handle low-level management yourself), the last clue I can give you is the Phaser class. As stated by the docs, it's a reusable mix of a cyclic barrier and a countdown latch. I've used it multiple times. It's complex to use, but it's also really powerful. It can be used as a cyclic barrier, so I think it fits your case.
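A minimal sketch of Phaser used as a cyclic barrier, assuming N = 4 workers, LOOPS = 1000 generations, and the question's counter loop standing in for the real cellular-automaton work. Here the main thread's gather is moved into an overridden onAdvance(), which runs once per phase in the last worker to arrive, so each worker waits only once per generation; returning true from onAdvance() terminates the phaser after the last generation.

```java
import java.util.concurrent.Phaser;

public class PhaserCycle {
    static final int N = 4;        // number of worker threads (assumed)
    static final int LOOPS = 1000; // number of generations (assumed)

    static final int[] localCounters = new int[N];
    static long globalCounter = 0;

    public static void main(String[] args) throws InterruptedException {
        // One party per worker; no separate main-thread phase
        Phaser phaser = new Phaser(N) {
            @Override
            protected boolean onAdvance(int phase, int registeredParties) {
                // Runs once per phase, after every worker has arrived
                for (int c : localCounters)
                    globalCounter += c;
                // true terminates the phaser once LOOPS phases have completed
                return phase + 1 >= LOOPS;
            }
        };

        Thread[] workers = new Thread[N];
        for (int t = 0; t < N; t++) {
            final int id = t;
            workers[t] = new Thread(() -> {
                while (!phaser.isTerminated()) {
                    localCounters[id] = 0;
                    for (int j = 0; j < 20; j++)
                        localCounters[id]++;
                    phaser.arriveAndAwaitAdvance(); // the only wait per generation
                }
            });
            workers[t].start();
        }
        for (Thread w : workers)
            w.join();
        System.out.println(globalCounter);
    }
}
```

Phaser documents that actions before an arrive happen-before the corresponding onAdvance(), which in turn happens-before actions after the phase advance, so the plain int[] is read safely in onAdvance(). With these values the printed sum is 80000.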

Isidor answered 16/4, 2018 at 10:42 Comment(0)

You could really consider following the 'official' example from the CyclicBarrier documentation:

 class Solver {
   final int N;
   final float[][] data;
   final CyclicBarrier barrier;

   class Worker implements Runnable {
     int myRow;
     Worker(int row) { myRow = row; }
     public void run() {
       while (!done()) {
         processRow(myRow);

         try {
           barrier.await();
         } catch (InterruptedException ex) {
           return;
         } catch (BrokenBarrierException ex) {
           return;
         }
       }
     }
   }

   public Solver(float[][] matrix) {
     data = matrix;
     N = matrix.length;
     barrier = new CyclicBarrier(N,
                                 new Runnable() {
                                   public void run() {
                                     mergeRows(...);
                                   }
                                 });
     for (int i = 0; i < N; ++i)
       new Thread(new Worker(i)).start();

     waitUntilDone();
   }
 }

In your case:

  • processRow() would compute part of a generation (the task is divided into N pieces, and the workers can get their number on initialization, or just use the arrival index returned by barrier.await(); in that case the workers should start with an await).
  • mergeRows(), in the anonymous Runnable passed to the barrier at construction, is called when an entire generation is ready: you can print it on the screen or something (and perhaps swap some 'currentGen' and 'nextGen' buffers). When this method (or, more precisely, run()) returns, the barrier.await() calls in the workers also return, and calculation of the next generation starts (or not; see the next bullet point).
  • done() decides when the threads should exit (instead of producing a new generation). It can be a 'real' method, but a static volatile boolean variable would also work.
  • waitUntilDone() could be a loop over all the threads, join()-ing them, or it could just wait for something that you trigger (from mergeRows) when the program should exit.
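A runnable sketch of that adaptation, assuming N = 4 workers, LOOPS = 1000 generations, and the question's counter loop standing in for processRow(). The barrier action plays the role of mergeRows() by summing the local counters, and join() plays the role of waitUntilDone(). Each worker awaits only once per generation, and there is no separate main-thread phase.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

public class BarrierActionCycle {
    static final int N = 4;        // number of worker threads (assumed)
    static final int LOOPS = 1000; // number of generations (assumed)

    static final int[] localCounters = new int[N];
    static long globalCounter = 0;

    public static void main(String[] args) throws InterruptedException {
        // The barrier action runs once per trip, in the last thread to arrive,
        // before any worker is released: it stands in for mergeRows()
        CyclicBarrier barrier = new CyclicBarrier(N, () -> {
            for (int c : localCounters)
                globalCounter += c;
        });

        Thread[] workers = new Thread[N];
        for (int t = 0; t < N; t++) {
            final int id = t;
            workers[t] = new Thread(() -> {
                try {
                    for (int i = 0; i < LOOPS; i++) { // stands in for !done()
                        localCounters[id] = 0;        // stands in for processRow()
                        for (int j = 0; j < 20; j++)
                            localCounters[id]++;
                        barrier.await();              // one wait per generation
                    }
                } catch (InterruptedException | BrokenBarrierException e) {
                    // stop this worker if interrupted or if the barrier breaks
                }
            });
            workers[t].start();
        }
        for (Thread w : workers) // stands in for waitUntilDone()
            w.join();
        System.out.println(globalCounter);
    }
}
```

The CyclicBarrier documentation guarantees that actions before await() happen-before the barrier action, which in turn happens-before the return from await() in the other threads, so the plain int[] is safe here. With these values the printed sum is 80000.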
Morsel answered 19/4, 2018 at 23:4 Comment(0)
