一、前言 本篇的介绍对象是 CountDownLatch ,它同样是基于 AQS 之上扩展的一款多线程场景下的工具类,它可以使一个或多个线程等待其他线程各自执行完毕后再执行。
对于 CountDownLatch 理解,我们可以将单次拆开为 CountDown 和 Latch 。CountDown 表示倒计时,Latch 表示门闩,当倒计时结束后门闩解除,门就开了。
二、使用场景 要完成一项复杂的任务,任务被划分为子任务1和子任务2,3,4...,为了提高执行任务的效率,采用多线程去完成。
由于子任务1的执行条件依赖于 子任务2,3,4...,需要先执行子任务2,3,4...获取到相应的结果才能执行子任务1,这是 CountDownLatch 就派上用场了。
三、工作原理 给定 CountDownLatch 一个倒计时数,每个线程都能访问 CountDownLatch 实例。当一批线程要协作完成任务,线程 A 可以调用 CountDownLatch 的 await()
进行等待阻塞。其他线程则做其他业务,当业务执行完成后调用 CountDownLatch 的 countDown()
减掉倒计时。最后倒计时减到 0 时,阻塞的线程 A 就会被唤醒执行后续的业务。
由于是 CountDownLatch 是基于 AQS 扩展的,因此引用 AQS 模型图可方便我们理解:
图中,state 用于保存倒计时数,Node 节点用于封装等待阻塞的线程。
四、源码解析 我们先通过案例了解 CountDownLatch 基本使用。
我们将 CountDownLatch 当作餐馆服务员,线程比作客人。当客人来到餐馆吃饭时,餐馆服务员负责记录餐桌、客人吃饭的情况。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 public class CountDownLatchTest { public static void main(String[] args) throws InterruptedException { // (1) CountDownLatch countDownLatch = new CountDownLatch(5); System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 上菜"); for (int i = 1; i <= 5; i++) { new Thread(() -> { try { System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 开始吃饭"); Double time = Math.random() * 3 + 1; TimeUnit.SECONDS.sleep(time.intValue()); System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 吃饭结束,走人"); // (2) 减去倒计时 countDownLatch.countDown(); } catch (InterruptedException e) { e.printStackTrace(); } }, "t" + i).start(); } // (3) 等待阻塞,当倒计时为 0 就放行 System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 等待客人结账"); countDownLatch.await(); System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 客人都走了,开始收摊"); } }
执行结果:
1 2 3 4 5 6 7 8 9 10 11 12 13 2023-03-15T11:40:08.536 -> main 上菜 2023-03-15T11:40:08.542 -> t1 开始吃饭 2023-03-15T11:40:08.542 -> t2 开始吃饭 2023-03-15T11:40:08.542 -> main 等待客人结账 2023-03-15T11:40:08.542 -> t3 开始吃饭 2023-03-15T11:40:08.542 -> t4 开始吃饭 2023-03-15T11:40:08.542 -> t5 开始吃饭 2023-03-15T11:40:09.543 -> t2 吃饭结束,走人 2023-03-15T11:40:10.543 -> t4 吃饭结束,走人 2023-03-15T11:40:11.542 -> t3 吃饭结束,走人 2023-03-15T11:40:11.542 -> t1 吃饭结束,走人 2023-03-15T11:40:11.542 -> t5 吃饭结束,走人 2023-03-15T11:40:11.542 -> main 客人都走了,开始收摊
当服务员上菜给客人后,需要等待(await()
)所有客人吃完饭结账后才能收摊,客人吃完饭需要通知服务员吃完饭结账(countDown()
)。
我们按照例子中的代码执行顺序分析。
首先查看 (1) 处代码,即创建 CountDownLatch 实例,进入构造方法中:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 public class CountDownLatch { private static final class Sync extends AbstractQueuedSynchronizer { Sync(int count) { setState(count); } int getCount() { return getState(); } // (4) 尝试获取资源 protected int tryAcquireShared(int acquires) { return (getState() == 0) ? 1 : -1; } // (5) 尝试释放资源 protected boolean tryReleaseShared(int releases) { for (;;) { int c = getState(); if (c == 0) return false; int nextc = c-1; if (compareAndSetState(c, nextc)) return nextc == 0; } } } private final Sync sync; public CountDownLatch(int count) { if (count < 0) throw new IllegalArgumentException("count < 0"); this.sync = new Sync(count); } // ...省略... }
在构造方法内部创建了 Sync 实例,而 Sync 是一个静态的内部类, 它继承 AbstractQueuedSynchronizer 类,因此 Sync 拥有了 AQS 的能力,CountDownLatch 的所有操作都是通过 Sync 实例完成的。
调用构造方法传入的 count 值(倒计时数)被传入到 Sync 的构造方法中,其内部调用 setState(count)
方法,该方法来自 AQS ,被保存到 AQS 的 state 中。
此时,AQS 的模型图如下:
回到案例代码中,main 线程创建好 CountDownLatch 实例后, 接着执行 for 循环,其方法体中创建新的线程执行其他业务,都是异步操作。我们顺着当前线程直接来到 (3) 处,即 countDownLatch.await()
,跳进源码:
1 2 3 4 5 6 public class CountDownLatch { public void await() throws InterruptedException { sync.acquireSharedInterruptibly(1); } }
await()
方法底层通过 Sync 实例调用了 acquireSharedInterruptibly(1)
方法,该方法来自 AQS :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { // ...省略... public final void acquireSharedInterruptibly(int arg) throws InterruptedException { if (Thread.interrupted()) throw new InterruptedException(); // (6) if (tryAcquireShared(arg) < 0) // (7) doAcquireSharedInterruptibly(arg); } }
进入该方法:先判断 main 线程是否被中断,并没有,然后执行 (6) 处代码,即 tryAcquireShared(arg)
,尝试获取资源权限(判断倒计时是否为 0)。该方法是一个抽象方法,最终通过子类来实现,即上文提到的 Sync 类来实现,跳回 (4) 处:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 public class CountDownLatch { private static final class Sync extends AbstractQueuedSynchronizer { Sync(int count) { setState(count); } int getCount() { return getState(); } // (4) protected int tryAcquireShared(int acquires) { return (getState() == 0) ? 1 : -1; } // ...省略... } // ...省略... }
tryAcquireShared()
方法中判断 state 值(倒计时)是否为 0 ,是则返回 1,否则返回 -1。
从上文案例的执行结果可以看出,main 线程在线程阻塞之后,其他线程才陆续执行完毕,因此 state 值不可能为 0,最终方法返回 -1,然后执行 (7) 处代码,即 doAcquireSharedInterruptibly(arg)
方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { // ...省略... private void doAcquireSharedInterruptibly(int arg) throws InterruptedException { // (8) 线程被封装到 Node 节点中 final Node node = addWaiter(Node.SHARED); boolean failed = true; try { for (;;) { // (9) 获取前驱节点 final Node p = node.predecessor(); if (p == head) { // (10) 再一次尝试获取资源 int r = tryAcquireShared(arg); if (r >= 0) { // (11) 设置头结点 setHeadAndPropagate(node, r); p.next = null; // help GC failed = false; return; } } // (12) 获取资源失败,修改前驱节点的 state 状态 if (shouldParkAfterFailedAcquire(p, node) && // (13) 底层调用 LockSupport.lock() 挂起当前线程 parkAndCheckInterrupt()) throw new InterruptedException(); } } finally { if (failed) cancelAcquire(node); } } // ...省略... }
该方法在 《AQS 源码详解》 文章中详细解说过,源码上已简单注释说明,此处不多赘述。
最终,main 线程执行到 parkAndCheckInterrupt()
方法中被挂起等待。
此时,AQS 的模型图如下:
我们切换到其他线程视角,案例中 t2 线程先执行完业务调用了 countDown()
方法:
1 2 3 4 5 6 7 8 9 public class CountDownLatch { // ...省略... public void countDown() { sync.releaseShared(1); } }
countDown()
方法底层调用 releaseShared(1)
,该方法来自 AQS :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { // ...省略... public final boolean releaseShared(int arg) { // (14) if (tryReleaseShared(arg)) { // (15) doReleaseShared(); return true; } return false; } }
线程 t2 来到 releaseShared(1)
方法中先执行 (14) 处代码,即 tryReleaseShared(arg)
代码,该方法是个抽象方法,通过子类 Sync 来实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 public class CountDownLatch { private static final class Sync extends AbstractQueuedSynchronizer { Sync(int count) { setState(count); } int getCount() { return getState(); } // (5) 尝试释放资源 protected boolean tryReleaseShared(int releases) { for (;;) { int c = getState(); if (c == 0) return false; int nextc = c-1; if (compareAndSetState(c, nextc)) return nextc == 0; } } // ...省略... } // ...省略... }
进到 tryReleaseShared(arg)
方法中,开启一个无限循环:
获取 state 值,当前值为 5。 判断 state 值,如果 为 0 返回 false,否则计算 state 新值(state 旧值 -1),此时新值为 4。 通过 CAS 方式将新值赋给 state 。 如果 state 新值为 0 返回 true,否则返回 false。 t2 线程执行方法最终返回值为 false,线程也跟着结束。
此时,AQS 的模型图如下:
其他条线程的执行步骤与 t2 线程都一样,我们直接跳到最后的 t5 线程视角。当 t5 线程执行 tryReleaseShared(arg)
将 state 值改为 0 后,方法返回 true,开始执行 (15) 处代码,即 doReleaseShared()
:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { // ...省略... private void doReleaseShared() { for (;;) { Node h = head; // (16) if (h != null && h != tail) { int ws = h.waitStatus; // (17) Node.SIGNAL:-1 if (ws == Node.SIGNAL) { if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) continue; // (18) unparkSuccessor(h); } else if (ws == 0 && !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) continue; } // (19) if (h == head) break; } } // ...省略... }
该方法用于修改 CLH 队列中头结点的 waitStatus 值以及唤醒头结点的后继节点中的线程。 开启一个无限循环:
获取 CLH 的头结点 判断头结点(dummy)是否为空,同时头结点是否与尾节点相同。由 AQS 模型图可知,(16) 处的判断是成立的,随后 t5 线程进到 if 方法体中。 判断头结点(dummy)的 waitStatus 状态,当前状态值为 -1,(17) 处判断成立,将头结点的 waitStatus 通过 CAS 方式还原为 0。 修改成功后执行 (18) 处代码,即 unparkSuccessor(h)
,该方法用于查询头结点的后继节点 node1,并通过 LockSupport.unpark(thread)
唤醒节点中的线程(main 线程)。由于该方法在 《AQS 源码详解》 已讲解,此处不多赘述。 t5 线程最后来到 (19) 处,判断成立退出无限循环。 这样 t5 线程释放锁完毕,结束线程,我们转回被唤醒的 main 线程视角:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { // ...省略... private void doAcquireSharedInterruptibly(int arg) throws InterruptedException { // (8) 线程被封装到 Node 节点中 final Node node = addWaiter(Node.SHARED); boolean failed = true; try { for (;;) { // (9) 获取前驱节点 final Node p = node.predecessor(); if (p == head) { // (10) 再一次尝试获取资源 int r = tryAcquireShared(arg); if (r >= 0) { // (11) 设置头结点 setHeadAndPropagate(node, r); p.next = null; // help GC failed = false; return; } } // (12) 获取资源失败,修改前驱节点的 state 状态 if (shouldParkAfterFailedAcquire(p, node) && // (13) 底层调用 LockSupport.lock() 挂起当前线程 parkAndCheckInterrupt()) throw new InterruptedException(); } } finally { if (failed) cancelAcquire(node); } } // ...省略... }
main 线程在执行 (13) 处代码被挂起等待,该方法是在一个无限循环中进行的,当 main 线程被 t5 线程唤醒后开始执行下一轮循环任务:
获取前驱节点,即 dummy 节点,判断是否头结点,由 AQS 模型图可知,判断成立。 调用 tryAcquireShared(arg)
,上文已介绍,由于 state 值被减为 0, 最终该方法返回值为 1。 之后执行 (11) 处代码,即 setHeadAndPropagate(node, r)
,该方法用于将 node1 节点设置为新的头结点,移除节点中的线程 旧的头结点与当前节点解除关系 最终, AQS 的模型图如下:
五、参考资料 CAS 原理新讲
LockSupport 工具介绍
AQS 源码详解