Java Fork/Join 框架使用指南

更新于 2025-12-29

baeldung 2023-09-28

1. 概述

Java 7 引入了 Fork/Join 框架。该框架提供了一套工具,通过尽可能利用所有可用的处理器核心来加速并行处理。它采用“分而治之”(divide and conquer)的方法实现这一目标。

在实际应用中,这意味着框架首先进行“分叉”(fork),递归地将任务拆分为更小的、相互独立的子任务,直到这些子任务足够简单,可以异步执行。

随后进入“合并”(join)阶段:所有子任务的结果被递归地合并为一个最终结果。对于不返回值的任务(void 类型),程序只需等待所有子任务执行完成即可。

为了实现高效的并行执行,Fork/Join 框架使用了一个名为 ForkJoinPool 的线程池。该线程池管理着类型为 ForkJoinWorkerThread 的工作线程。

2. ForkJoinPool

ForkJoinPool 是该框架的核心。它是 ExecutorService 的一种实现,用于管理工作线程,并提供了获取线程池状态和性能信息的工具。

每个工作线程一次只能执行一个任务,但 ForkJoinPool 并不会为每一个子任务都创建一个单独的线程。相反,池中的每个线程都有自己的双端队列(deque,发音为 “deck”),用于存储任务。

这种架构对负载均衡至关重要,因为它支持工作窃取算法(work-stealing algorithm)。

2.1 工作窃取算法

简而言之,空闲的线程会尝试从繁忙线程的 deque 中“窃取”任务。

默认情况下,工作线程从自己 deque 的头部获取任务。当自己的 deque 为空时,线程会从其他繁忙线程 deque 的尾部或全局任务队列中获取任务——因为较大的任务块通常位于这些位置。

这种策略最大限度地减少了线程之间对任务的竞争,也减少了线程寻找新任务的次数,因为它优先处理最大的可用任务块。

2.2 ForkJoinPool 的实例化

在 Java 8 中,获取 ForkJoinPool 实例最便捷的方式是调用其静态方法 commonPool()。这将返回一个公共线程池(common pool)的引用,该池是所有 ForkJoinTask 的默认线程池。

根据 Oracle 官方文档,使用预定义的公共池有助于减少资源消耗,因为它避免了为每个任务单独创建线程池。

ForkJoinPool commonPool = ForkJoinPool.commonPool();

在 Java 7 中,我们可以通过创建 ForkJoinPool 并将其赋值给某个工具类的公共静态字段来实现类似效果:

public static ForkJoinPool forkJoinPool = new ForkJoinPool(2);

然后即可轻松访问:

ForkJoinPool forkJoinPool = PoolUtil.forkJoinPool;

通过 ForkJoinPool 的构造函数,我们可以创建自定义线程池,指定并行度(parallelism)、线程工厂(thread factory)和异常处理器(exception handler)。例如,上面的池设置了并行度为 2,表示该池将使用两个处理器核心。

3. ForkJoinTask

ForkJoinTask 是在 ForkJoinPool 中执行的任务的基类。在实际开发中,通常应继承它的两个子类之一:

  • RecursiveAction:用于不返回结果的任务(void)。
  • RecursiveTask<V>:用于返回结果的任务。

这两个子类都包含一个抽象方法 compute(),任务逻辑就在该方法中定义。

3.1 RecursiveAction

以下示例使用一个名为 workload 的字符串表示待处理的工作单元。为演示目的,任务逻辑很简单:将输入字符串转为大写并记录日志。

为了展示框架的“分叉”行为,当 workload.length() 超过指定阈值时,任务会通过 createSubtask() 方法进行拆分。

字符串被递归地分割为子串,并基于这些子串创建 CustomRecursiveAction 实例。

最终,这些子任务以 List<CustomRecursiveAction> 的形式返回,并通过 invokeAll() 方法提交到 ForkJoinPool

public class CustomRecursiveAction extends RecursiveAction {

    private String workload = "";
    private static final int THRESHOLD = 4;

    private static Logger logger = Logger.getAnonymousLogger();

    public CustomRecursiveAction(String workload) {
        this.workload = workload;
    }

    @Override
    protected void compute() {
        if (workload.length() > THRESHOLD) {
            ForkJoinTask.invokeAll(createSubtasks());
        } else {
            processing(workload);
        }
    }

    private List<CustomRecursiveAction> createSubtasks() {
        List<CustomRecursiveAction> subtasks = new ArrayList<>();

        String partOne = workload.substring(0, workload.length() / 2);
        String partTwo = workload.substring(workload.length() / 2, workload.length());

        subtasks.add(new CustomRecursiveAction(partOne));
        subtasks.add(new CustomRecursiveAction(partTwo));

        return subtasks;
    }

    private void processing(String work) {
        String result = work.toUpperCase();
        logger.info("This result - (" + result + ") - was processed by " 
          + Thread.currentThread().getName());
    }
}

我们可以使用此模式开发自己的 RecursiveAction 类:创建一个表示总工作量的对象,选择合适的阈值,定义拆分工作的方法,以及定义实际执行工作的逻辑。

3.2 RecursiveTask

对于需要返回值的任务,逻辑类似,但不同之处在于:每个子任务的结果会被合并为一个最终结果。

public class CustomRecursiveTask extends RecursiveTask<Integer> {
    private int[] arr;
    private static final int THRESHOLD = 20;

    public CustomRecursiveTask(int[] arr) {
        this.arr = arr;
    }

    @Override
    protected Integer compute() {
        if (arr.length > THRESHOLD) {
            return ForkJoinTask.invokeAll(createSubtasks())
              .stream()
              .mapToInt(ForkJoinTask::join)
              .sum();
        } else {
            return processing(arr);
        }
    }

    private Collection<CustomRecursiveTask> createSubtasks() {
        List<CustomRecursiveTask> dividedTasks = new ArrayList<>();
        dividedTasks.add(new CustomRecursiveTask(
          Arrays.copyOfRange(arr, 0, arr.length / 2)));
        dividedTasks.add(new CustomRecursiveTask(
          Arrays.copyOfRange(arr, arr.length / 2, arr.length)));
        return dividedTasks;
    }

    private Integer processing(int[] arr) {
        return Arrays.stream(arr)
          .filter(a -> a > 10 && a < 27)
          .map(a -> a * 10)
          .sum();
    }
}

在此示例中,我们使用 arr 字段中的整型数组表示工作单元。createSubtasks() 方法递归地将任务拆分为更小的部分,直到每部分小于阈值。然后 invokeAll() 方法将子任务提交到公共池,并返回一个 Future 列表。

为了触发执行,需对每个子任务调用 join() 方法。

这里我们借助 Java 8 的 Stream API 实现了子结果的合并,使用 sum() 方法作为组合子结果的示例。

4. 向 ForkJoinPool 提交任务

有多种方式可以向线程池提交任务。

首先,可以使用 submit()execute() 方法(它们的使用场景基本相同):

forkJoinPool.execute(customRecursiveTask);
int result = customRecursiveTask.join();

invoke() 方法会分叉任务并等待结果,无需手动调用 join()

int result = forkJoinPool.invoke(customRecursiveTask);

invokeAll() 是向 ForkJoinPool 提交多个 ForkJoinTask 最便捷的方式。它接受两个任务、可变参数或一个集合,并按生成顺序返回一个 Future 对象的集合。

另外,也可以分别使用 fork()join() 方法。fork() 将任务提交到池中,但不会立即触发执行;必须调用 join() 才能真正执行并获取结果。

对于 RecursiveActionjoin() 返回 null;而对于 RecursiveTask<V>,它返回任务执行的结果:

customRecursiveTaskFirst.fork();
result = customRecursiveTaskLast.join();

在上面的例子中,我们使用 invokeAll() 提交了一组子任务。当然也可以用 fork()join() 实现相同功能,但这会影响结果的顺序。

为避免混淆,通常建议在提交多个任务时使用 invokeAll() 方法

5. 结论

使用 Fork/Join 框架可以显著加速大型任务的处理,但要达到理想效果,应遵循以下准则:

  • 尽量少用线程池:大多数情况下,每个应用程序或系统使用一个线程池即可。
  • 若无需特殊调优,优先使用默认的公共线程池(common pool)。
  • 为任务拆分设置合理的阈值
  • 避免在 ForkJoinTask 中执行任何阻塞操作