How to use MDC with ForkJoinPool?

Following up on "How to use MDC with thread pools?", how can one use MDC with a ForkJoinPool? Specifically, how can one wrap a ForkJoinTask so that MDC values are set before the task executes?

Glidden asked 16/3, 2016 at 3:42

The following seems to work for me:

import java.lang.Thread.UncaughtExceptionHandler;
import java.util.Map;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.MDC;

/**
 * A {@link ForkJoinPool} that inherits MDC contexts from the thread that queues a task.
 *
 * @author Gili Tzabari
 */
public final class MdcForkJoinPool extends ForkJoinPool
{
    /**
     * Creates a new MdcForkJoinPool.
     *
     * @param parallelism the parallelism level. For default value, use {@link java.lang.Runtime#availableProcessors}.
     * @param factory     the factory for creating new threads. For default value, use
     *                    {@link #defaultForkJoinWorkerThreadFactory}.
     * @param handler     the handler for internal worker threads that terminate due to unrecoverable errors encountered
     *                    while executing tasks. For default value, use {@code null}.
     * @param asyncMode   if true, establishes local first-in-first-out scheduling mode for forked tasks that are never
     *                    joined. This mode may be more appropriate than default locally stack-based mode in applications
     *                    in which worker threads only process event-style asynchronous tasks. For default value, use
     *                    {@code false}.
     * @throws IllegalArgumentException if parallelism less than or equal to zero, or greater than implementation limit
     * @throws NullPointerException     if the factory is null
     * @throws SecurityException        if a security manager exists and the caller is not permitted to modify threads
     *                                  because it does not hold
     *                                  {@link java.lang.RuntimePermission}{@code ("modifyThread")}
     */
    public MdcForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler,
        boolean asyncMode)
    {
        super(parallelism, factory, handler, asyncMode);
    }

    @Override
    public void execute(ForkJoinTask<?> task)
    {
        // See https://mcmap.net/q/134252/-how-to-use-mdc-with-thread-pools
        super.execute(wrap(task, MDC.getCopyOfContextMap()));
    }

    @Override
    public void execute(Runnable task)
    {
        // See https://mcmap.net/q/134252/-how-to-use-mdc-with-thread-pools
        super.execute(wrap(task, MDC.getCopyOfContextMap()));
    }

    private <T> ForkJoinTask<T> wrap(ForkJoinTask<T> task, Map<String, String> newContext)
    {
        return new ForkJoinTask<T>()
        {
            private static final long serialVersionUID = 1L;
            /**
             * If non-null, overrides the value returned by the underlying task.
             */
            private final AtomicReference<T> override = new AtomicReference<>();

            @Override
            public T getRawResult()
            {
                T result = override.get();
                if (result != null)
                    return result;
                return task.getRawResult();
            }

            @Override
            protected void setRawResult(T value)
            {
                override.set(value);
            }

            @Override
            protected boolean exec()
            {
                // According to ForkJoinTask.fork() "it is a usage error to fork a task more than once unless it has completed
                // and been reinitialized". We therefore assume that this method does not have to be thread-safe.
                Map<String, String> oldContext = beforeExecution(newContext);
                try
                {
                    task.invoke();
                    return true;
                }
                finally
                {
                    afterExecution(oldContext);
                }
            }
        };
    }

    private Runnable wrap(Runnable task, Map<String, String> newContext)
    {
        return () ->
        {
            Map<String, String> oldContext = beforeExecution(newContext);
            try
            {
                task.run();
            }
            finally
            {
                afterExecution(oldContext);
            }
        };
    }

    /**
     * Invoked before running a task.
     *
     * @param newValue the new MDC context
     * @return the old MDC context
     */
    private Map<String, String> beforeExecution(Map<String, String> newValue)
    {
        Map<String, String> previous = MDC.getCopyOfContextMap();
        if (newValue == null)
            MDC.clear();
        else
            MDC.setContextMap(newValue);
        return previous;
    }

    /**
     * Invoked after running a task.
     *
     * @param oldValue the old MDC context
     */
    private void afterExecution(Map<String, String> oldValue)
    {
        if (oldValue == null)
            MDC.clear();
        else
            MDC.setContextMap(oldValue);
    }
}
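To make the mechanism above concrete without SLF4J on the classpath, here is a stripped-down sketch of the same capture-on-queue / restore-after-run idea. It is not the answer's class: the pool name ContextForkJoinPool and the ThreadLocal<String> REQUEST_ID are hypothetical stand-ins for MdcForkJoinPool and org.slf4j.MDC, and only execute(Runnable) is overridden.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

// Sketch only: REQUEST_ID is a hypothetical stand-in for org.slf4j.MDC.
class ContextForkJoinPool extends ForkJoinPool {
    static final ThreadLocal<String> REQUEST_ID = new ThreadLocal<>();

    ContextForkJoinPool(int parallelism) {
        super(parallelism);
    }

    @Override
    public void execute(Runnable task) {
        String captured = REQUEST_ID.get();          // read on the thread that queues the task
        super.execute(() -> {
            String previous = REQUEST_ID.get();      // back up whatever the worker thread had
            install(captured);
            try {
                task.run();
            } finally {
                install(previous);                   // restore so nothing leaks into the next task
            }
        });
    }

    private static void install(String value) {
        if (value == null) REQUEST_ID.remove(); else REQUEST_ID.set(value);
    }

    /** Executes one task on a fresh pool and returns the value the task observed. */
    static String observedInsideTask(String requestId) {
        ContextForkJoinPool pool = new ContextForkJoinPool(2);
        AtomicReference<String> seen = new AtomicReference<>();
        CountDownLatch done = new CountDownLatch(1);
        install(requestId);
        try {
            pool.execute(() -> {
                seen.set(REQUEST_ID.get());
                done.countDown();
            });
            if (!done.await(10, TimeUnit.SECONDS)) {  // wait for the worker to finish
                throw new IllegalStateException("task did not finish");
            }
            return seen.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            REQUEST_ID.remove();
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(observedInsideTask("req-42")); // prints req-42
    }
}
```

Capturing in execute() rather than in a thread factory is what lets a reused worker thread carry a different value for each task it runs.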

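The second piece is a CountedCompleter base class that captures the MDC context of the thread that creates each task: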

import java.util.Map;
import java.util.concurrent.CountedCompleter;
import org.slf4j.MDC;

/**
 * A {@link CountedCompleter} that inherits MDC contexts from the thread that queues a task.
 *
 * @author Gili Tzabari
 * @param <T> The result type returned by this task's {@code get} method
 */
public abstract class MdcCountedCompleter<T> extends CountedCompleter<T>
{
    private static final long serialVersionUID = 1L;
    private final Map<String, String> newContext;

    /**
     * Creates a new MdcCountedCompleter instance using the MDC context of the current thread.
     */
    protected MdcCountedCompleter()
    {
        this(null);
    }

    /**
     * Creates a new MdcCountedCompleter instance using the MDC context of the current thread.
     *
     * @param completer this task's completer; {@code null} if none
     */
    protected MdcCountedCompleter(CountedCompleter<?> completer)
    {
        super(completer);
        this.newContext = MDC.getCopyOfContextMap();
    }

    /**
     * The main computation performed by this task.
     */
    protected abstract void computeWithContext();

    @Override
    public final void compute()
    {
        Map<String, String> oldContext = beforeExecution(newContext);
        try
        {
            computeWithContext();
        }
        finally
        {
            afterExecution(oldContext);
        }
    }

    /**
     * Invoked before running a task.
     *
     * @param newValue the new MDC context
     * @return the old MDC context
     */
    private Map<String, String> beforeExecution(Map<String, String> newValue)
    {
        Map<String, String> previous = MDC.getCopyOfContextMap();
        if (newValue == null)
            MDC.clear();
        else
            MDC.setContextMap(newValue);
        return previous;
    }

    /**
     * Invoked after running a task.
     *
     * @param oldValue the old MDC context
     */
    private void afterExecution(Map<String, String> oldValue)
    {
        if (oldValue == null)
            MDC.clear();
        else
            MDC.setContextMap(oldValue);
    }
}
  1. Run your tasks against MdcForkJoinPool instead of the common ForkJoinPool.
  2. Extend MdcCountedCompleter instead of CountedCompleter.
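Here is a similarly stripped-down sketch of the MdcCountedCompleter idea, again with a hypothetical ThreadLocal<String> (RequestId.VALUE) standing in for MDC; ContextCountedCompleter and RecordingTask are made-up names. What it is meant to show: each subtask is constructed inside the parent's compute(), while the parent's context is installed, so every subtask captures the same value and a plain ForkJoinPool is enough for this part. The splitting logic follows the ForEach example in the CountedCompleter Javadoc.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;

/** Hypothetical stand-in for org.slf4j.MDC: one request id per thread. */
final class RequestId {
    static final ThreadLocal<String> VALUE = new ThreadLocal<>();

    private RequestId() {}

    static void install(String value) {
        if (value == null) VALUE.remove(); else VALUE.set(value);
    }
}

/** Same shape as MdcCountedCompleter, but over the RequestId stand-in. */
abstract class ContextCountedCompleter<T> extends CountedCompleter<T> {
    // Evaluated in the constructing thread, i.e. the thread that creates the task.
    private final String captured = RequestId.VALUE.get();

    ContextCountedCompleter(CountedCompleter<?> completer) {
        super(completer);
    }

    protected abstract void computeWithContext();

    @Override
    public final void compute() {
        String previous = RequestId.VALUE.get();
        RequestId.install(captured);
        try {
            computeWithContext();
        } finally {
            RequestId.install(previous);
        }
    }
}

/** Splits [lo, hi) in halves; each leaf records the request id it ran under. */
final class RecordingTask extends ContextCountedCompleter<Void> {
    private final int lo, hi;
    private final Set<String> seen;

    RecordingTask(CountedCompleter<?> parent, int lo, int hi, Set<String> seen) {
        super(parent);
        this.lo = lo;
        this.hi = hi;
        this.seen = seen;
    }

    @Override
    protected void computeWithContext() {
        if (hi - lo >= 2) {
            int mid = (lo + hi) >>> 1;
            setPendingCount(2);                            // must be set before forking
            new RecordingTask(this, mid, hi, seen).fork(); // created while our context is installed
            new RecordingTask(this, lo, mid, seen).fork();
        } else {
            seen.add(String.valueOf(RequestId.VALUE.get()));
        }
        tryComplete();
    }

    /** Runs 64 leaves under the given id and returns the distinct ids they observed. */
    static Set<String> run(String requestId) {
        ForkJoinPool pool = new ForkJoinPool(4);
        Set<String> seen = ConcurrentHashMap.newKeySet();
        RequestId.install(requestId);
        try {
            pool.invoke(new RecordingTask(null, 0, 64, seen)); // root captures requestId here
            return seen;
        } finally {
            RequestId.VALUE.remove();
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(run("req-7")); // prints [req-7]
    }
}
```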
Glidden answered 18/3, 2016 at 2:34
Comments (15):
Is there a way to override the default ForkJoinPool implementation with a custom one like the one you posted? I do not want to inject my own executor service into every CompletableFuture async call. – Teerell
@IhorM. Not that I'm aware of. – Glidden
I can override the ForkJoinWorkerThreadFactory of a ForkJoinPool with my own, but apparently that is not enough: I set the MDC context on a thread, but it looks like the thread is not re-created when a new task arrives (tasks are added to the worker queue, and the ForkJoinWorkerThread processes them one at a time). So I need to set/unset the MDC context per ForkJoinTask instead of per ForkJoinWorkerThread. – Teerell
@IhorM. Did you apply all of the pieces in the above answer? Specifically, MdcForkJoinPool wraps your tasks and sets the MDC before/after each execution. – Glidden
No, I just wanted to get away with overriding only ForkJoinWorkerThreadFactory, but I do not think that approach will work. I trust that your suggestion works, since you have control over the task's life cycle. – Teerell
@Gili Question: why don't you override the submit() methods of ForkJoinPool? – Teerell
@Glidden I added some details (including a test showing that it works) in case you want to update your answer. I didn't edit it directly because I wanted to make sure you agreed with the changes. – Cai
@Glidden I am trying to use the solution with parallelStream() on a custom ForkJoinPool, but requestId is only populated in the initial thread. More details here: #52840846 – Lanneret
How does one make use of MdcCountedCompleter? Is there a way to tell Java parallel streams to use it? For example, IntStream.range(0, 2).parallel().sum() uses ReduceTask, which extends java.util.concurrent.CountedCompleter, and I see no way to replace it. – Restriction
@Restriction MdcCountedCompleter is used the same way you would use CountedCompleter. If you figure out how to use the latter, you can use the former. – Glidden
This answer does not help at all! :-) – Restriction
@Restriction Sorry, to clarify: you don't have to use MdcCountedCompleter at all. You should be using MdcForkJoinPool with whatever class you were using to schedule tasks on the pool. In my particular case, I wanted to use CountedCompleter on top of a ForkJoinPool without losing MDC values; MdcCountedCompleter reflects the changes I had to make to accomplish this. If you are using a different class, you will have to modify it in a similar manner: (1) back up the MDC before executing your task, (2) update the MDC to the correct value, and (3) restore the MDC to its original value before returning. – Glidden
@Glidden I use new MdcFJPool().submit(() -> IntStream.range(1, 10).parallel().peek(i -> System.out.println(Thread.currentThread().getName() + " mdc: " + MDC.get("blah"))).sum()). The first FJ worker thread has the MDC value, but because .sum() uses ReduceTask, which extends AbstractTask > CountedCompleter, all other threads do NOT have the MDC values. I see no way to plug MdcCountedCompleter in here. – Restriction
@Restriction My answer assumes that you control the implementation running on the fork-join pool. If you do not, you are out of luck... Sorry I could not be of more help. – Glidden
@Restriction I have the same problem. The first thread has the MDC values, but the other threads have an empty MDC context. Have you found a solution for this problem? – Villa
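Glidden's three-step recipe (back up, set, restore) can also be packaged as a decorator around individual tasks. That may help with Teerell's CompletableFuture case when you control the lambda you pass in; it does not help with tasks the JDK creates internally, such as ReduceTask. A sketch, with a hypothetical ThreadLocal<String> REQUEST_ID in place of MDC and a hypothetical helper name withCurrentContext:

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Supplier;

// Sketch only: REQUEST_ID is a hypothetical stand-in for org.slf4j.MDC.
final class ContextDecorators {
    static final ThreadLocal<String> REQUEST_ID = new ThreadLocal<>();

    private ContextDecorators() {}

    /** Captures the caller's value now; the returned supplier installs it, runs, then restores. */
    static <T> Supplier<T> withCurrentContext(Supplier<T> task) {
        String captured = REQUEST_ID.get();
        return () -> {
            String backup = REQUEST_ID.get();   // (1) back up the executing thread's value
            install(captured);                  // (2) install the queuing thread's value
            try {
                return task.get();
            } finally {
                install(backup);                // (3) restore before returning
            }
        };
    }

    private static void install(String value) {
        if (value == null) REQUEST_ID.remove(); else REQUEST_ID.set(value);
    }

    /** Runs a probe through supplyAsync and returns the value it observed. */
    static String observedAsync(String requestId) {
        install(requestId);
        try {
            Supplier<String> probe = withCurrentContext(REQUEST_ID::get);
            return CompletableFuture.supplyAsync(probe, ForkJoinPool.commonPool()).join();
        } finally {
            REQUEST_ID.remove();
        }
    }

    public static void main(String[] args) {
        System.out.println(observedAsync("req-9")); // prints req-9
    }
}
```

With SLF4J, the same shape would use MDC.getCopyOfContextMap() for the capture and MDC.setContextMap(...) or MDC.clear() for the install and restore, as in the answer's beforeExecution/afterExecution.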

I am unfamiliar with ForkJoinPool, but you can pass the MDC key/values of interest to the ForkJoinTask instances you instantiate before submitting them to the ForkJoinPool.

Given that, as of logback version 1.1.5, MDC values are not inherited by child threads, there are only a few options:

  1. pass the relevant MDC key/values to ForkJoinTask instances as you instantiate them
  2. extend ForkJoinPool so that MDC key/values are passed to the newly created threads
  3. create your own ThreadFactory which sets MDC key/values to newly created threads

Please note that I have not actually implemented options 2 or 3.
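A minimal sketch of option 1, where the value travels as an explicit constructor argument instead of being read from the thread that queues the task. ExplicitContextTask and the ThreadLocal REQUEST_ID (standing in for MDC) are hypothetical names. Note that every place that creates a subtask has to remember to pass requestId on.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Sketch of option 1; REQUEST_ID is a hypothetical stand-in for org.slf4j.MDC.
final class ExplicitContextTask extends RecursiveTask<Long> {
    static final ThreadLocal<String> REQUEST_ID = new ThreadLocal<>();

    private final String requestId;   // handed in by the creator, not read from the running thread
    private final int lo, hi;

    ExplicitContextTask(String requestId, int lo, int hi) {
        this.requestId = requestId;
        this.lo = lo;
        this.hi = hi;
    }

    @Override
    protected Long compute() {
        String previous = REQUEST_ID.get();
        REQUEST_ID.set(requestId);                 // log statements in this task now see requestId
        try {
            if (hi - lo <= 16) {
                long sum = 0;
                for (int i = lo; i < hi; i++) sum += i;
                return sum;
            }
            int mid = (lo + hi) >>> 1;
            // Each subtask must be given requestId explicitly; forgetting it loses the context.
            ExplicitContextTask left = new ExplicitContextTask(requestId, lo, mid);
            left.fork();
            long right = new ExplicitContextTask(requestId, mid, hi).compute();
            return right + left.join();
        } finally {
            if (previous == null) REQUEST_ID.remove(); else REQUEST_ID.set(previous);
        }
    }

    /** Sums 0..n-1 on a fresh pool with every task tagged with requestId. */
    static long sum(String requestId, int n) {
        ForkJoinPool pool = new ForkJoinPool(4);
        try {
            return pool.invoke(new ExplicitContextTask(requestId, 0, n));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(sum("req-3", 1000)); // prints 499500
    }
}
```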

Exocentric answered 17/3, 2016 at 5:21
Comments (2):
Ceki, I was asking for an automatic way for tasks to inherit the MDC of the thread that queued them. The problem with option 1 is that users often forget to inherit the MDC manually. The problem with options 2 and 3 is that MDC values are expected to come from the queuing thread, not from the executing thread. A single executing thread will run multiple tasks, each with potentially different MDC values. I hope this explains what I had in mind. Thanks anyway. – Glidden
Regarding your answer: Item 1: it has to be a custom ForkJoinTask implementation, since you cannot inject the MDC context into, for instance, CompletableFuture$AsyncRun. Item 2: you can extend ForkJoinPool, but the MDC context shouldn't be passed when threads are created; it should be passed when new tasks are added to the worker queue, because the same threads are reused to handle multiple tasks. Item 3: not a viable solution, as you would be setting the MDC context for a thread that is constructed once but handles multiple tasks. – Teerell

Here is some additional information to go along with @Gili's answer.

Here is a test showing that the solution works. Note that some lines will have no context, but at least they won't have the WRONG context, which is what happened with a normal ForkJoinPool.

import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.startsWith;
import static org.junit.Assert.assertThat;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

import org.junit.Test;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

import ch.qos.logback.classic.Level;
import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.LoggerContext;
import ch.qos.logback.classic.encoder.PatternLayoutEncoder;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.OutputStreamAppender;

public class MDCForkJoinPoolTest {

    private static final Logger log = (Logger) LoggerFactory.getLogger("mdc-test");

    // You can demonstrate the problem I'm trying to fix by changing the below to a normal ForkJoinPool and running the test.
    private ForkJoinPool threads = new MdcForkJoinPool(16, ForkJoinPool.defaultForkJoinWorkerThreadFactory, null, false);
    // Starts at -99, so a permit only becomes available after all 100 simulated requests have released.
    private Semaphore threadsRunning = new Semaphore(-99);
    private ByteArrayOutputStream bio = new ByteArrayOutputStream();

    @Test
    public void shouldCopyManagedDiagnosticContextWhenUsingForkJoinPool() throws Exception {
        // Set up the appender that grabs the output BEFORE starting the requests; otherwise early lines are lost.
        LoggerContext lc = (LoggerContext) LoggerFactory.getILoggerFactory();
        PatternLayoutEncoder encoder = new PatternLayoutEncoder();
        encoder.setPattern("%X{mdc_val:-}=%m%n");   // "<MDC value or empty>=<message>"
        encoder.setContext(lc);
        encoder.start();
        OutputStreamAppender<ILoggingEvent> appender = new OutputStreamAppender<>();
        appender.setEncoder(encoder);
        appender.setImmediateFlush(true);
        appender.setContext(lc);
        appender.setOutputStream(bio);
        appender.start();
        log.addAppender(appender);
        log.setAdditive(false);
        log.setLevel(Level.INFO);

        for (int i = 0; i < 100; i++) {
            Thread t = new Thread(simulatedRequest(), "MDC-Test-" + i);
            t.setDaemon(true);
            t.start();
        }

        assertThat("timed out waiting for threads to complete.", threadsRunning.tryAcquire(300, TimeUnit.SECONDS), is(true));

        Set<String> ids = new HashSet<>();
        try (BufferedReader r = new BufferedReader(new InputStreamReader(
                new ByteArrayInputStream(bio.toByteArray()), StandardCharsets.UTF_8))) {
            r.lines().forEach(line -> {
                System.out.println(line);
                String[] vals = line.split("=");
                // A line may lack a context, but if it has one it must be its own request's id.
                if (!vals[0].isEmpty()) {
                    ids.add(vals[0]);
                    assertThat(vals[1], startsWith(vals[0]));
                }
            });
        }

        assertThat(ids.size(), is(100));
    }

    private Runnable simulatedRequest() {
        return () -> {
            String id = UUID.randomUUID().toString();
            MDC.put("mdc_val", id);
            Map<String, String> context = MDC.getCopyOfContextMap();
            threads.submit(() -> {
                MDC.setContextMap(context);
                IntStream.range(0, 100).parallel().forEach(i -> log.info("{} - {}", id, i));
            }).join();
            threadsRunning.release();
        };
    }
}

Also, here are additional methods that should be overridden in the original answer's MdcForkJoinPool (they also require import java.util.concurrent.Callable):

    @Override
    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
        return super.submit(wrap(task, MDC.getCopyOfContextMap()));
    }

    @Override
    public <T> ForkJoinTask<T> submit(Callable<T> task) {
        return super.submit(wrap(task, MDC.getCopyOfContextMap()));
    }

    @Override
    public <T> ForkJoinTask<T> submit(Runnable task, T result) {
        return super.submit(wrap(task, MDC.getCopyOfContextMap()), result);
    }

    @Override
    public ForkJoinTask<?> submit(Runnable task) {
        return super.submit(wrap(task, MDC.getCopyOfContextMap()));
    }

    private <T> Callable<T> wrap(Callable<T> task, Map<String, String> newContext)
    {
        return () ->
        {
            Map<String, String> oldContext = beforeExecution(newContext);
            try
            {
                return task.call();
            }
            finally
            {
                afterExecution(oldContext);
            }
        };
    }
Cai answered 12/6, 2018 at 15:1
Comments (1):
@BenL. I tried your test (thanks), but as you noted, there are cases without context. From some basic debugging, it appears that MdcCountedCompleter would need to be used, but I can't see how to use it. Any ideas? – Carrero
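One likely reason some lines still have no context, consistent with the ReduceTask discussion under Gili's answer: MdcForkJoinPool only sees tasks that enter through execute()/submit(). Subtasks created with fork() inside a running task are pushed straight onto the worker's queue, so they are never wrapped, and a worker that steals one runs it with whatever MDC that worker has (after MdcForkJoinPool's cleanup, often none). The sketch below only counts how many tasks pass through submit(ForkJoinTask); CountingPool, SumTask and ForkBypassDemo are hypothetical names.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.atomic.AtomicInteger;

/** Counts tasks entering via submit(ForkJoinTask), where MdcForkJoinPool-style wrapping would happen. */
class CountingPool extends ForkJoinPool {
    final AtomicInteger submitted = new AtomicInteger();

    CountingPool(int parallelism) {
        super(parallelism);
    }

    @Override
    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
        submitted.incrementAndGet();
        return super.submit(task);
    }
}

/** Fork-heavy sum of [lo, hi). */
final class SumTask extends RecursiveTask<Long> {
    private final int lo, hi;

    SumTask(int lo, int hi) {
        this.lo = lo;
        this.hi = hi;
    }

    @Override
    protected Long compute() {
        if (hi - lo <= 8) {
            long sum = 0;
            for (int i = lo; i < hi; i++) sum += i;
            return sum;
        }
        int mid = (lo + hi) >>> 1;
        SumTask left = new SumTask(lo, mid);
        left.fork();   // pushed onto the current worker's queue; no pool method is called
        return new SumTask(mid, hi).compute() + left.join();
    }
}

final class ForkBypassDemo {
    /** Returns {sum, number of submit() calls}. */
    static long[] run() {
        CountingPool pool = new CountingPool(4);
        try {
            long sum = pool.submit(new SumTask(0, 1000)).join();
            return new long[] {sum, pool.submitted.get()};
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        long[] r = run();
        System.out.println("sum=" + r[0] + ", submit() calls=" + r[1]); // prints sum=499500, submit() calls=1
    }
}
```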

I'm stuck with the same problem. Using a custom ForkJoinPool every time you need to run a parallel Java stream is not ideal, as it requires a lot of code.

However, I think I found a smaller solution than the one proposed by the question's author:

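import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

import lombok.extern.slf4j.Slf4j;
import org.slf4j.MDC;

@Slf4j
public class MdcTest {

    public static void main(String[] args) {
        List<Integer> list = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            list.add(i);
        }

        MDC.put("someKey", "iter");

        list.stream()
            .parallel()
            .peek(mdcParallelStreamKeeper())   // copies the caller's MDC onto whichever thread handles the element
            .forEach(i -> log.info("List item={} with MDC={}", i, MDC.getCopyOfContextMap()));
    }

    // Captures the calling thread's MDC once, when the pipeline is built.
    private static Consumer<? super Integer> mdcParallelStreamKeeper() {
        Map<String, String> contextMap = MDC.getCopyOfContextMap();
        return i -> {
            if (contextMap == null) {
                MDC.clear();
            } else {
                MDC.setContextMap(contextMap);  // never cleared afterwards (see UPDATE #1)
            }
        };
    }
}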

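Basically, you just need to define the mdcParallelStreamKeeper method somewhere and call it from peek() before the terminal operation.

UPDATE #1: There is a problem with MDC cleanup in this approach: the context that peek() sets is never removed, so it stays on the worker threads after the stream completes.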
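One way to address that cleanup problem, sketched with a hypothetical ThreadLocal<String> REQUEST_ID in place of MDC and a hypothetical helper named withCallerContext: instead of a peek() that only sets the context, wrap the terminal action so that each element backs up the thread's context, installs the caller's, and restores the backup afterwards (the same before/after pattern as MdcForkJoinPool). With SLF4J, the capture would be MDC.getCopyOfContextMap() and the install/restore MDC.setContextMap(...) or MDC.clear().

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.stream.IntStream;

// Sketch: REQUEST_ID is a hypothetical stand-in for org.slf4j.MDC.
final class ParallelStreamContext {
    static final ThreadLocal<String> REQUEST_ID = new ThreadLocal<>();

    private ParallelStreamContext() {}

    private static void install(String value) {
        if (value == null) REQUEST_ID.remove(); else REQUEST_ID.set(value);
    }

    /** Captures the caller's value now; the returned consumer sets it per element, then restores. */
    static <T> Consumer<T> withCallerContext(Consumer<T> action) {
        String captured = REQUEST_ID.get();
        return item -> {
            String previous = REQUEST_ID.get();
            install(captured);
            try {
                action.accept(item);
            } finally {
                install(previous);   // worker threads end up exactly as they were
            }
        };
    }

    /** Runs a parallel forEach under requestId and returns the distinct ids seen by the action. */
    static Set<String> seenIds(String requestId) {
        Set<String> seen = ConcurrentHashMap.newKeySet();
        install(requestId);
        try {
            IntStream.range(0, 1_000).boxed().parallel()
                .forEach(withCallerContext(i -> seen.add(String.valueOf(REQUEST_ID.get()))));
            return seen;
        } finally {
            REQUEST_ID.remove();
        }
    }

    public static void main(String[] args) {
        System.out.println(seenIds("iter")); // prints [iter]
    }
}
```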

Chastain answered 2/6, 2023 at 11:58
