多线程求和的两种解法

本文讨论一个简单的设计题:

利用多线程求1-10000的和

首先说,求和这是一个非常非常简单的需求,类似于交替打印奇偶数的感觉,该类型的题目的唯一使用场景是面试,不能直接在生产代码使用。但同时,又能够很直接的反应一个人对编程语言里常用工具的掌握情况和基本的代码设计能力,所以,值得仔细思考其场景。

题目分析

对于连续数组求和,如果从数学的角度考虑则是:

求一个等差数列的和

如果按照这个逻辑来说,等差数列求和公式是最好的解决方案。不过,在实际项目中不可能是这种逻辑,比如是在单机上并行的1000个用户的访问数据进行统计然后求和之类的,那么显然,出于性能的角度需要并行的处理。
该类问题需要有如下的特点:

  • 可分解。 也就是要有基础的分治的逻辑
  • 结果可聚合。分解的子问题的结果要可以聚合成最终结果

比如对于上面说的统计每天的用户的访问网页数,系统内部记录了每个用户的访问记录,那么我们先分别统计每一个用户的访问量,最后对所有用户的访问量求和就满足上述规则。
所以说回来,本质上我们需要的是一个 并行的分治算法。在具体的工程实现上,既可以是单机的多线程,也可以是分布式的任务调度,这取决于数据规模和问题的复杂性。本文着重考虑单机的情况下的多线程解法。

一般来说,Java里的线程池我们都会使用ThreadPoolExecutor这个类,它提供最经典的线程池的实现:

  • 核心参数为 常驻线程数、最大线程数
  • 初始化状态可能初始化某些数量的线程也可能延迟初始化,后面则根据任务的情况决定是否增加线程
  • 线程数量在达到最大线程数后,拒绝或是其他策略

ThreadPoolExecutor解法

对于线程池的实现来说,我们重点关注如下两点:

  1. 分治策略或者最小分片单元
    在实际场景中,我们可以用时间、hash、范围等策略来分片,最小分片可以是固定数量,也可以是对最小单元的计算进行性能测算过后的动态选择。对于求和这个例子,我们暂定为 按照线程数量分治,最小单元为求和总数除以线程数,比如对于求10000的和,采用4个线程来计算,分别为[1,2500), [2500, 5000), [5000, 7500), [7500, 10000]。每个线程内直接循环。
  2. 聚合方式
    聚合方式是如何对每个线程计算的结果累计的策略,我们可选的有每个线程返回自己的值或者多个线程在某个变量上累加。 对于后者,累加就总是会涉及到变量被多个线程修改对应的可见性和互斥性的问题,该场景下我们暂不关注。我们已每个线程返回结果然后再聚合的方式来计算

上面两点已经阐述了需要实现的核心逻辑,还有一个实现上的细节需要注意:

  • 怎么判断所有线程都计算结束了?

这里的问题其实是:如何等待多个线程完成。毫无疑问的,CountDownLatch是我们的选择:

A synchronization aid that allows one or more threads to wait until a set of operations being performed in other threads completes.

我们按照线程数量初始化这个等待器,然后在线程内部计算完成后进行递减操作,在外层等待这个等待器。
所以算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
public class ConcurrentSumCalculator implements SumOfNumber{
@Override
public int sumOfNumber(int num) {
return getSum(0, num, 4);
}

private class Sum implements Callable<Integer> {

private int subMin;
private int subMax;
private CountDownLatch threadLatch;

public Sum(int subMin, int subMax, CountDownLatch latch) {
this.subMin = subMin;
this.subMax = subMax;
this.threadLatch = latch;
}

@Override
public Integer call() throws Exception {
int sum = 0;
try {
for (int i = subMin; i <= subMax; i++) {
sum += i;
//System.out.println(Thread.currentThread().getName() + "-------" + sum);
}
} finally {
threadLatch.countDown();
}
return sum;
}
}

public Integer getSum(int min, int max, int threadNum){
if(min > max){
throw new IllegalArgumentException("非法范围值");
}
int subMin;
int subMax;
List<FutureTask<Integer>> taskList = new ArrayList<>();
int sumCounts = max - min + 1;
int subCounts = sumCounts/threadNum;
int remainder = sumCounts%threadNum;
CountDownLatch lock = new CountDownLatch(threadNum);
int mark = min;
for(int i = 0;i<threadNum;i++){
subMin = mark;
if(remainder!=0 && remainder>i){
subMax = subMin + subCounts;
}else{
subMax = mark + subCounts - 1;
}
mark = subMax + 1;

FutureTask<Integer> task = new FutureTask<Integer>(new Sum(subMin, subMax, lock));
taskList.add(task);
new Thread(task).start();
}
try{
//TODO time out
lock.await();
}catch(InterruptedException e){
System.out.println("wait sub thread time out!");
}
int sum = taskListSum(taskList);
return sum;
}

//考虑AutomicInteger.但是在实际场景中往往需要返回值。
private Integer taskListSum(List<FutureTask<Integer>> taskList){
int sum = 0;
for(FutureTask<Integer> task : taskList){
try {
sum += task.get();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (ExecutionException e) {
e.printStackTrace();
}
}
return sum;
}


public static void main(String[] args){
ConcurrentSumCalculator sumCalculator = new ConcurrentSumCalculator();
int sum = sumCalculator.getSum(1, 1000, 5);
System.out.println("多线程结果:"+sum);
int tmp = 0;
for(int i=1;i<=1000;i++){
tmp+=i;
}
System.out.println("单线程结果:"+tmp);
}

}

ForkJoinPoll解法

这个是jdk7新增的线程池,它的核心逻辑有:

  • 每个线程都有自己的任务队列
  • 每个线程首先会从自己的任务队列拉取任务执行,然后在没事情做的时候会从其他线程的队列里窃取任务来执行

具体可以参考网上例子,不能作为参考文档

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
public class ForkJoinSumOfNumber implements SumOfNumber {

@Override
public int sumOfNumber(int num) {
ForkJoinPool forkJoinPool = new ForkJoinPool();
return forkJoinPool.invoke(new CountTaskTmp(0, num));
}

public class CountTaskTmp extends RecursiveTask<Integer> {
private static final int THRESHOLD = 2;
private int start;
private int end;

public CountTaskTmp(int start, int end) {
this.start = start;
this.end = end;
}

//实现compute 方法来实现任务切分和计算
protected Integer compute() {
int sum = 0;
boolean canCompute = (end - start) <= THRESHOLD;
if (canCompute) {
for (int i = start; i <= end; i++)
sum += i;
} else {
//如果任务大于阀值,就分裂成两个子任务计算
int mid = (start + end) / 2;
CountTaskTmp leftTask = new CountTaskTmp(start, mid);
CountTaskTmp rightTask = new CountTaskTmp(mid + 1, end);

//执行子任务
leftTask.fork();
rightTask.fork();

//等待子任务执行完,并得到结果
int leftResult = (int) leftTask.join();
int rightResult = (int) rightTask.join();

sum = leftResult + rightResult;
}

return sum;
}
}
}