参考文献

Fork/Join

  • Fork/Join框架是Java并发工具包中的一种可以将一个大任务拆分为很多小任务来异步执行的工具,自JDK1.7引入。
  • ForkJoinPool是ThreadPoolExecutor线程池的一种补充,是对计算密集型场景的加强

Fork/Join组成

  • 任务对象: ForkJoinTask (包括RecursiveTaskRecursiveActionCountedCompleter)

    • ForkJoinPool只接收ForkJoinTask 任务(在实际使用中,也可以接收 Runnable/Callable 任务,但在真正运行时,也会把这些任务封装成 ForkJoinTask 类型的任务)
    • RecursiveTaskForkJoinTask的子类,是一个可以递归执行的 ForkJoinTask
    • RecursiveAction 是一个无返回值的 RecursiveTask,CountedCompleter 在任务完成执行后会触发执行一个自定义的钩子函数。
  • 执行Fork/Join任务的线程: ForkJoinWorkerThread

  • 线程池: ForkJoinPool

工作窃取的实现原理

  • ForkJoinPool类中的WorkQueue正是实现工作窃取的队列,大多数操作都发生在工作窃取队列中(在嵌套类工作队列中)。这些是特殊形式的Deques,主要有push,pop,poll操作。
  • Deque是双端队列(double ended queue缩写),头部和尾部任何一端都可以进行插入,删除,获取的操作,即支持FIFO(队列)也支持LIFO(栈)顺序。
  • Deque接口的实现最常见的是LinkedList,除此还有ArrayDeque,ConcurrentLinkedDeque

工作窃取模式主要步骤

  1. 每个线程都有自己的双端队列
  2. 当调用fork方法时,将任务放进队列头部,线程以LIFO顺序,使用push/pop方式处理队列中的任务
  3. 如果自己队列里的任务处理完后,会从其他线程维护的队列尾部使用poll的方式窃取任务,以达到充分利用CPU资源的目的
  4. 从尾部窃取可以减少同原线程的竞争
  5. 当队列中剩最后一个任务时,通过cas解决原线程和窃取线程的竞争

使用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
package cn.holelin.project.task;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.RandomUtil;
import com.google.common.collect.Lists;
import cn.holelin.project.constants.StringConstants;
import cn.holelin.project.domain.DownloadDicom;
import cn.holelin.project.exception.BusinessException;
import cn.holelin.project.service.OssService;
import cn.holelin.project.utils.SnowflakeUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.bouncycastle.util.Arrays;

import java.io.File;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.RecursiveTask;

@Slf4j
public class DownloadDicomTask extends RecursiveTask<List<String>> {

/**
* 最小分割任务
*/
private static final Integer DIRECT_HANDLE_NUM = 20;
private final DownloadDicom downloadDicom;
private final OssService ossService;

public DownloadDicomTask(DownloadDicom downloadDicom, OssService ossService) {
this.downloadDicom = downloadDicom;
this.ossService = ossService;
}

@Override
protected List<String> compute() {
final List<String> objectNameList = downloadDicom.getObjectNameList();
final String downloadFileDir = downloadDicom.getDownloadFileDir();
if (CollUtil.isEmpty(objectNameList)) {
log.error("任务为空,无法下载");
return Collections.emptyList();
}
final ArrayList<String> result = Lists.newArrayList();
final int size = objectNameList.size();
if (size < DIRECT_HANDLE_NUM) {
result.addAll(directHandler(objectNameList, downloadFileDir));
} else {
// 将objectNameList分割
final int middle = size / 2;
final DownloadDicomTask downloadDicomTask1 = new DownloadDicomTask(DownloadDicom.builder()
.downloadFileDir(downloadFileDir)
.objectNameList(objectNameList.subList(0, middle))
.build(), ossService);
final DownloadDicomTask downloadDicomTask2 = new DownloadDicomTask(DownloadDicom.builder()
.downloadFileDir(downloadFileDir)
.objectNameList(objectNameList.subList(middle, size))
.build(), ossService);
downloadDicomTask1.fork();
downloadDicomTask2.fork();
result.addAll(downloadDicomTask1.join());
result.addAll(downloadDicomTask2.join());
}
return result;
}

private List<String> directHandler(List<String> objectNameList, String downloadFileDir) {
final ArrayList<String> filePathList = Lists.newArrayList();
for (String objectName : objectNameList) {
String fileName = getObjectNameFileName(objectName);
if (StringUtils.isEmpty(fileName)) {
fileName = SnowflakeUtil.genId();
}
final File currentDownloadDir = Paths.get(downloadFileDir + File.separator + SnowflakeUtil.genId()).toFile();
if (!currentDownloadDir.exists()) {
currentDownloadDir.mkdir();
}
String filePath = currentDownloadDir.getAbsolutePath() + File.separator + fileName;
final boolean exists = Paths.get(filePath).toFile().exists();
if (exists) {
final String flag = RandomUtil.randomInt(1000) + "_";
filePath = downloadFileDir + File.separator + flag + fileName;
}
try {
ossService.download(objectName, filePath);
} catch (BusinessException e) {
e.printStackTrace();
}
filePathList.add(filePath);
}
return filePathList;
}

public String getObjectNameFileName(String objectName) {
if (StringUtils.isEmpty(objectName)) {
return StringConstants.EMPTY_STRING;
}
final String[] split = objectName.split(StringConstants.SYMBOL_SLASH);
if (Arrays.isNullOrEmpty(split)) {
return StringConstants.EMPTY_STRING;
}
return split[split.length - 1];
}

}