かまたま日記3

プログラミングメイン、たまに日常

ForkJoinPoolについて

ForkJoinPoolJava 7から導入された新しいExecutorのフレームワークです。 旧来のExecutorと違うのは、タスクのスケジュールのアルゴリズムとして、work-stealingを採用していることです。これは再帰処理やタスクの中で更に細かな子タスクが生成されるような計算処理に適しています(例えばWebクローラなど)

ForkJoinkPoolに登録されている各ワーカースレッドは、それぞれワーカーキュー(実際はLIFO型のスタック)を持っていて、ForkJoinTaskを積むことができます。ForkJoinTaskは外部からForkJoinPoolの execute, invoke, submit メソッドを使って登録したり、もしくはタスクの中で直接別タスクを生成しその fork メソッドを呼ぶことで登録することもできます。forkされたタスクは join メソッドを使い、計算結果を待ちます。

ということで、WikipediaのWork stealingにあるモデルをForkJoinPoolとForkJoinTaskを使って実装してみます。ForkJoinTaskにはいくつかの抽象サブクラスがあり、今回はその中のRecursiveTaskを使います。

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class ForkJoinPoolExample {
    public static void main(String[] args) {
        int poolSize = Integer.parseInt(args[0]);

        ForkJoinPool pool = new ForkJoinPool(poolSize);
        int result = pool.invoke(new F(1, 2));
        log("Result is " + result);
    }

    static class F extends RecursiveTask<Integer> {
        private final int a, b;

        F(int a, int b) {
            this.a = a;
            this.b = b;
        }

        @Override
        protected Integer compute() {
            log(String.format("Start compute of f(%d, %d) = g(%d) + h(%d)", a, b, a, b));
            G g = new G(a);
            g.fork();
            sleep(1000);
            H h = new H(b);
            final int result = h.compute() + g.join();
            log(String.format("f(%d, %d) = %d", a, b, result));
            return result;
        }
    }

    static class G extends RecursiveTask<Integer> {
        private final int a;

        G(int a) {
            this.a = a;
        }

        @Override
        protected Integer compute() {
            log(String.format("Start compute of g(%d) = %<d * 2", a));
            final int result = a * 2;
            log(String.format("g(%d) = %d", a, result));
            return result;
        }
    }

    static class H extends RecursiveTask<Integer> {
        private final int a;

        H(int a) {
            this.a = a;
        }

        @Override
        protected Integer compute() {
            log(String.format("Start compute of h(%d) = g(%<d) + (%<d + 1)", a));
            G g = new G(a);
            g.fork();
            sleep(1000);
            int c = a + 1;
            final int result = c + g.join();
            log(String.format("h(%d) = %d", a, result));
            return result;
        }
    }

    private static void log(String message) {
        System.out.println(String.format("%tT.%<tL [%s] %s", System.currentTimeMillis(), Thread.currentThread().getName(), message));
    }

    private static void sleep(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}

最初は処理順を確認するため、プールのサイズを1にしてみます。

$ java ForkJoinPoolExample 1
01:37:23.239 [ForkJoinPool-1-worker-1] Start compute of f(1, 2) = g(1) + h(2)
01:37:24.293 [ForkJoinPool-1-worker-1] Start compute of h(2) = g(2) + (2 + 1)
01:37:25.295 [ForkJoinPool-1-worker-1] Start compute of g(2) = 2 * 2
01:37:25.295 [ForkJoinPool-1-worker-1] g(2) = 4
01:37:25.295 [ForkJoinPool-1-worker-1] h(2) = 7
01:37:25.296 [ForkJoinPool-1-worker-1] Start compute of g(1) = 1 * 2
01:37:25.296 [ForkJoinPool-1-worker-1] g(1) = 2
01:37:25.297 [ForkJoinPool-1-worker-1] f(1, 2) = 9
01:37:25.297 [main] Result is 9

プールサイズが2以上の場合はforkされたタスクは空きスレッドがあれば順次消費されていきます。

$ java ForkJoinPoolExample 2
01:37:43.282 [ForkJoinPool-1-worker-1] Start compute of f(1, 2) = g(1) + h(2)
01:37:43.292 [ForkJoinPool-1-worker-0] Start compute of g(1) = 1 * 2
01:37:43.293 [ForkJoinPool-1-worker-0] g(1) = 2
01:37:44.300 [ForkJoinPool-1-worker-1] Start compute of h(2) = g(2) + (2 + 1)
01:37:44.300 [ForkJoinPool-1-worker-0] Start compute of g(2) = 2 * 2
01:37:44.301 [ForkJoinPool-1-worker-0] g(2) = 4
01:37:45.305 [ForkJoinPool-1-worker-1] h(2) = 7
01:37:45.306 [ForkJoinPool-1-worker-1] f(1, 2) = 9
01:37:45.306 [main] Result is 9