CountDownLatch源码解读

我会带着你远行 2021-06-10 20:40 643阅读 0赞

CountDownLatch

本意是倒数计时,它是一个线程同步的辅助工具。它允许一个或多个线程阻塞等待直到其他线程执行完一系列的操作之后再继续执行。
几个主要的api:

  • await()方法:阻塞等待,直到state等于0。(它是基于AbstractQueuedSynchronizer实现的)
  • await(long timeout, TimeUnit unit):定时等待,直到state等于0或超时。
  • countDown():将state减1,当state减到0之后,await阻塞的线程可以继续执行。
    说起来还是比较简单的。

使用场景

当一个操作需要依赖于其他多个耗时的条件时(比如调接口获取),可以采用CountDownLatch实现并发

举个例子

  1. import org.junit.Test;
  2. import java.util.Random;
  3. import java.util.concurrent.*;
  4. import java.util.concurrent.atomic.AtomicInteger;
  5. public class CountDownLatchTest {
  6. private final AtomicInteger poolNumber = new AtomicInteger(1);
  7. /** * 定义一个线程池 */
  8. private ThreadPoolExecutor executor = new ThreadPoolExecutor(20, 50, 60, TimeUnit.SECONDS,
  9. new LinkedBlockingDeque<>(2000), r -> new Thread(r,"工人" + poolNumber.incrementAndGet()));
  10. class Worker implements Runnable{
  11. private CountDownLatch countDownLatch;
  12. Worker(CountDownLatch countDownLatch) {
  13. this.countDownLatch = countDownLatch;
  14. }
  15. @Override
  16. public void run() {
  17. System.out.println(Thread.currentThread().getName() + "开始执行任务");
  18. try {
  19. Thread.sleep(new Random().nextInt(10000));
  20. } catch (InterruptedException e) {
  21. e.printStackTrace();
  22. }
  23. System.out.println(Thread.currentThread().getName() + "执行任务结束");
  24. countDownLatch.countDown();
  25. }
  26. }
  27. @Test
  28. public void testCountDownLatch() {
  29. CountDownLatch countDownLatch = new CountDownLatch(5);
  30. for (int i = 0; i < 5; i++) {
  31. executor.execute(new Worker(countDownLatch));
  32. }
  33. try {
  34. countDownLatch.await();
  35. } catch (InterruptedException e) {
  36. e.printStackTrace();
  37. }
  38. System.out.println("任务全部执行结束");
  39. }
  40. }

以上示例将5个任务提交给线程池并发执行,每个子线程持有同一个countDownLatch对象,每个子线程执行结束之后执行countDown,主线程await等待。当子线程全部执行完毕之后,打印“任务全部执行结束”。

" class="reference-link">执行结果:在这里插入图片描述

源码

以上示例比较好理解,主要问题就是countDownLatch.countDown();countDownLatch.await();两个方法,下面我们来看看他们是如何实现的。
countDownLatch.countDown()方法:

  1. public void countDown() {
  2. sync.releaseShared(1);
  3. }
  4. public final boolean releaseShared(int arg) {
  5. // 释放共享信号量(如果减到0,doReleaseShared)
  6. if (tryReleaseShared(arg)) {
  7. // 唤醒所有阻塞的线程
  8. doReleaseShared();
  9. return true;
  10. }
  11. return false;
  12. }

tryReleaseShared是AbstractQueuedSynchronizer的抽象方法,在CountDownLatch的内部类Sync实现:

  1. protected boolean tryReleaseShared(int releases) {
  2. // Decrement count; signal when transition to zero
  3. for (;;) {
  4. // 在自旋内获取信号量
  5. int c = getState();
  6. // 信号量已经为0了,就不操作了,正常代码不会走到这
  7. if (c == 0)
  8. return false;
  9. int nextc = c-1;
  10. // 将减1的信号量更新
  11. if (compareAndSetState(c, nextc))
  12. // 只有当信号量为0,返回true,否则false。
  13. return nextc == 0;
  14. }
  15. }

当state减到0之后,执行doReleaseShared方法,唤醒因为await阻塞的线程。

  1. private void doReleaseShared() {
  2. for (;;) {
  3. Node h = head;
  4. if (h != null && h != tail) {
  5. int ws = h.waitStatus;
  6. if (ws == Node.SIGNAL) {
  7. // 正常挂起的线程状态是SIGNAL,将waitStatus置为0,并唤醒线程
  8. if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
  9. continue; // loop to recheck cases
  10. unparkSuccessor(h);
  11. }
  12. // 本来就是0(无状态)的话,则置为无条件传播PROPAGATE
  13. else if (ws == 0 &&
  14. !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
  15. continue; // loop on failed CAS
  16. }
  17. if (h == head) // loop if head changed
  18. break;
  19. }
  20. }

再看countDownLatch.await():

  1. public void await() throws InterruptedException {
  2. sync.acquireSharedInterruptibly(1);
  3. }
  4. // AbstractQueuedSynchronizer
  5. public final void acquireSharedInterruptibly(int arg)
  6. throws InterruptedException {
  7. if (Thread.interrupted())
  8. throw new InterruptedException();
  9. if (tryAcquireShared(arg) < 0)
  10. doAcquireSharedInterruptibly(arg);
  11. }

tryAcquireShared也是AbstractQueuedSynchronizer的抽象方法,在CountDownLatch的内部类Sync实现:

  1. protected int tryAcquireShared(int acquires) {
  2. // 只要state不等于0,就会阻塞线程
  3. return (getState() == 0) ? 1 : -1;
  4. }

阻塞线程doAcquireSharedInterruptibly:

  1. private void doAcquireSharedInterruptibly(int arg)
  2. throws InterruptedException {
  3. // 创建阻塞线程节点
  4. final Node node = addWaiter(Node.SHARED);
  5. boolean failed = true;
  6. try {
  7. for (;;) {
  8. // 如果node的前驱阶段是阻塞队列的头结点,再尝试一次获取信号量
  9. final Node p = node.predecessor();
  10. if (p == head) {
  11. int r = tryAcquireShared(arg);
  12. // 成功获取就出队了
  13. if (r >= 0) {
  14. setHeadAndPropagate(node, r);
  15. p.next = null; // help GC
  16. failed = false;
  17. return;
  18. }
  19. }
  20. // 线程挂起前判断
  21. // 1、如果线程已经是SIGNAL,但还没拿到信号量,就能够安全挂起
  22. // 2、如果任务被取消CANCELLED,则移除该节点
  23. // 3、否则在挂起前再将状态置为SIGNAL,再尝试一次获取信号量
  24. if (shouldParkAfterFailedAcquire(p, node) &&
  25. // 挂起线程
  26. parkAndCheckInterrupt())
  27. throw new InterruptedException();
  28. }
  29. } finally {
  30. if (failed)
  31. cancelAcquire(node);
  32. }
  33. }

小结

1、await():当state不等于0则尝试挂起线程
2、countDown():当state等于0时,则唤醒阻塞线程

发表评论

表情:
评论列表 (有 0 条评论,643人围观)

还没有评论,来说两句吧...

相关阅读