几种生产者-消费者模型

张天宇 on 2020-08-01

总结几种生产者-消费者模型的实现方式,最后给出一个基于该模式的简易线程池实现。

生产者消费者模型

生产者把数据放入缓冲区,消费者从缓冲区中取出数据。

如果缓冲区满了,生产者线程阻塞。

如果缓冲区为空,那么消费者线程阻塞。

采用 wait() / notifyAll() 实现

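  1. Resource 资源类:包含资源池容量 size 和当前已有的资源数量 num。生产者线程和消费者线程都持有一个 Resource 成员变量,由 main 方法通过构造函数传入;线程的 run 方法中调用 add() / remove() 来增加或减少资源。注意:wait() 返回时条件未必成立(可能是虚假唤醒,也可能资源已被其他线程抢先处理),因此应在 while 循环中检查条件,条件不满足时继续等待。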

    /**
     * Shared resource pool guarded by this object's monitor (wait/notifyAll version).
     *
     * <p>{@code num} is the current number of resources and {@code size} is the pool capacity.
     * Producers block in {@link #add()} while the pool is full; consumers block in
     * {@link #remove()} while it is empty.
     */
    public class Resource {
        // current number of resources in the pool
        private int num = 0;
        // pool capacity
        private final int size = 10;

        /**
         * Takes one resource, waiting until at least one is available.
         *
         * <p>If the calling thread is interrupted while waiting, its interrupt flag is restored and
         * the method returns without taking a resource.
         */
        public synchronized void remove() {
            // Re-check in a loop: wait() may return spuriously, and another consumer woken by the
            // same notifyAll() may already have taken the resource.
            while (num <= 0) {
                System.out.println("消费者" + Thread.currentThread().getName() + "等待中");
                try {
                    wait();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
            }
            num--;
            System.out.println("消费者" + Thread.currentThread().getName() + "消耗了一个资源,还剩" + num);
            // Wakes every waiter on this monitor; producers now have space, consumers just re-check.
            notifyAll();
        }

        /**
         * Adds one resource, waiting until the pool has free space.
         *
         * <p>If the calling thread is interrupted while waiting, its interrupt flag is restored and
         * the method returns without adding a resource.
         */
        public synchronized void add() {
            // Loop for the same reasons as in remove().
            while (num >= size) {
                System.out.println("生产者" + Thread.currentThread().getName() + "等待中");
                try {
                    wait();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
            }
            num++;
            System.out.println("生产者" + Thread.currentThread().getName() + "生产了一个资源,当前有" + num);
            notifyAll();
        }
    }
  2. 消费者线程

    /**
     * Consumer thread: pauses about one second, then takes one resource from the shared
     * {@link Resource}, repeatedly.
     *
     * <p>The loop stops when the thread is interrupted during its pause, or when its interrupt
     * flag is still set after {@code remove()} returns. Previously an interrupt was only printed
     * and the thread could never be stopped.
     */
    public class ConsumerThread extends Thread {
        private final Resource resource;

        public ConsumerThread(Resource resource) {
            this.resource = resource;
        }

        @Override
        public void run() {
            while (!isInterrupted()) {
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    // An interrupt is a stop request; sleep() has already cleared the flag.
                    return;
                }
                resource.remove();
            }
        }
    }
  3. 生产者线程

    /**
     * Producer thread: pauses about one second, then adds one resource to the shared
     * {@link Resource}, repeatedly.
     *
     * <p>The loop stops when the thread is interrupted during its pause, or when its interrupt
     * flag is still set after {@code add()} returns. Previously an interrupt was only printed
     * and the thread could never be stopped.
     */
    public class ProducerThread extends Thread {
        private final Resource resource;

        public ProducerThread(Resource resource) {
            this.resource = resource;
        }

        @Override
        public void run() {
            while (!isInterrupted()) {
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    // An interrupt is a stop request; sleep() has already cleared the flag.
                    return;
                }
                resource.add();
            }
        }
    }
  4. 测试类

    /**
     * Demo entry point: three producers and two consumers share a single {@link Resource}.
     * Threads are constructed and started in the order producer, producer, producer,
     * consumer, consumer.
     */
    public static void main(String[] args) {
        Resource shared = new Resource();
        Thread[] participants = {
                new ProducerThread(shared),
                new ProducerThread(shared),
                new ProducerThread(shared),
                new ConsumerThread(shared),
                new ConsumerThread(shared),
        };
        for (Thread participant : participants) {
            participant.start();
        }
    }

采用 Lock 和 Condition

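修改 Resource 如下:用一把 ReentrantLock 配合两个 Condition,生产者在 producerCondition 上等待,消费者在 consumerCondition 上等待,这样消费后只需唤醒生产者,生产后只需唤醒消费者: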

/**
 * Shared resource pool guarded by an explicit {@link Lock} (Lock/Condition version).
 *
 * <p>{@code num} is the current number of resources and {@code size} is the pool capacity.
 * Consumers wait on {@code consumerCondition} while the pool is empty; producers wait on
 * {@code producerCondition} while it is full.
 */
public class Resource {
    // current number of resources in the pool
    private int num = 0;
    // pool capacity
    private final int size = 10;
    private final Lock lock;
    private final Condition consumerCondition;
    private final Condition producerCondition;

    /**
     * @param lock              lock guarding the pool state
     * @param consumerCondition condition consumers wait on; must be created by {@code lock}
     * @param producerCondition condition producers wait on; must be created by {@code lock}
     */
    public Resource(Lock lock, Condition consumerCondition, Condition producerCondition) {
        this.lock = lock;
        this.consumerCondition = consumerCondition;
        this.producerCondition = producerCondition;
    }

    /**
     * Takes one resource, waiting until at least one is available.
     *
     * <p>If the calling thread is interrupted while waiting, its interrupt flag is restored and
     * the method returns without taking a resource.
     */
    public void remove() {
        lock.lock();
        try {
            // Re-check in a loop: await() may return spuriously, and another consumer woken by the
            // same signalAll() may already have taken the resource.
            while (num <= 0) {
                System.out.println("消费者" + Thread.currentThread().getName() + "等待中");
                consumerCondition.await();
            }
            num--;
            System.out.println("消费者" + Thread.currentThread().getName() + "消耗了一个资源,现在有" + num);
            producerCondition.signalAll();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Adds one resource, waiting until the pool has free space.
     *
     * <p>If the calling thread is interrupted while waiting, its interrupt flag is restored and
     * the method returns without adding a resource.
     */
    public void add() {
        lock.lock();
        try {
            // Loop for the same reasons as in remove().
            while (num >= size) {
                System.out.println("生产者" + Thread.currentThread().getName() + "进入等待");
                producerCondition.await();
            }
            num++;
            System.out.println("生产者" + Thread.currentThread().getName() + "生产了一个资源,现在有" + num);
            consumerCondition.signalAll();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            lock.unlock();
        }
    }
}

测试:

// A single ReentrantLock supplies both conditions: the first one created becomes the consumer
// condition, the second the producer condition.
Lock sharedLock = new ReentrantLock();
Resource resource = new Resource(sharedLock, sharedLock.newCondition(), sharedLock.newCondition());
// Three producers and one consumer, constructed and started in that order.
Thread[] participants = {
        new ProducerThread(resource),
        new ProducerThread(resource),
        new ProducerThread(resource),
        new ConsumerThread(resource),
};
for (Thread participant : participants) {
    participant.start();
}

使用 BlockingQueue

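BlockingQueue 的四组操作方法:

| 操作 | 抛出异常 | 返回特殊值 | 阻塞 | 超时 |
| --- | --- | --- | --- | --- |
| 插入 | add(e) | offer(e) | put(e) | offer(e, time, unit) |
| 移除 | remove() | poll() | take() | poll(time, unit) |
| 检查 | element() | peek() | 不可用 | 不可用 |

下面的实现使用会阻塞的 put() / take(),队列满或空时由队列自身负责让线程等待。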
/**
 * Shared resource pool backed by a bounded {@link BlockingQueue} with capacity 10.
 *
 * <p>The queue does all the waiting: {@code put} blocks while it is full and {@code take}
 * blocks while it is empty. If the calling thread is interrupted while blocked, the operation
 * is abandoned and the thread's interrupt flag is restored so callers can still see it.
 */
public class Resource {
    private final BlockingQueue<Integer> blockingQueue = new LinkedBlockingDeque<>(10);

    /** Adds one resource, waiting while the pool is full. */
    public void add() {
        try {
            blockingQueue.put(1);
            // size() is a snapshot; other threads may have changed the queue in the meantime.
            System.out.println("生产者" + Thread.currentThread().getName() + "生产了一个资源,当前有" + blockingQueue.size());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Takes one resource, waiting while the pool is empty. */
    public void remove() {
        try {
            blockingQueue.take();
            System.out.println("消费者" + Thread.currentThread().getName() + "消费了一个资源,当前有" + blockingQueue.size());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}

BlockingQueue 是泛型的,队列元素不仅可以是 Integer,也可以是 Runnable 任务。因此可以基于阻塞队列手写一个自定义线程池:提交任务的线程是生产者,工作线程是消费者,从队列中取出任务并执行。

手写线程池

/**
 * Bounded FIFO blocking queue built on a {@link ReentrantLock} and two {@link Condition}s.
 *
 * <p>Producers wait on {@code producerCondition} while the queue is full; consumers wait on
 * {@code consumerCondition} while it is empty. An interrupt never aborts a blocking call: the
 * untimed methods keep waiting, and the timed methods keep waiting only until their original
 * deadline. In both cases the thread's interrupt flag is set again before the method returns.
 *
 * @param <T> element type
 */
@Slf4j
class BlockingQueue<T> {
    // task queue
    private final Deque<T> queue = new ArrayDeque<>();
    // guards queue and both conditions
    private final ReentrantLock lock = new ReentrantLock();
    // producers (adding threads) wait here while the queue is full
    private final Condition producerCondition = lock.newCondition();
    // consumers (executing threads) wait here while the queue is empty
    private final Condition consumerCondition = lock.newCondition();
    // maximum number of queued elements
    private final int capacity;

    /**
     * @param capacity maximum number of queued elements; must be positive
     * @throws IllegalArgumentException if {@code capacity <= 0}
     */
    public BlockingQueue(int capacity) {
        if (capacity <= 0) {
            throw new IllegalArgumentException("capacity must be positive: " + capacity);
        }
        this.capacity = capacity;
    }

    /** Removes and returns the head element, waiting as long as necessary for one. */
    public T take() {
        lock.lock();
        try {
            while (queue.isEmpty()) {
                // Keeps waiting if interrupted and leaves the interrupt flag set on return.
                consumerCondition.awaitUninterruptibly();
            }
            T t = queue.removeFirst();
            producerCondition.signal();
            return t;
        } finally {
            lock.unlock();
        }
    }

    /**
     * Same as {@link #take()}; kept for callers of the original, misspelled name.
     *
     * @deprecated use {@link #take()}
     */
    @Deprecated
    public T task() {
        return take();
    }

    /**
     * Removes and returns the head element, waiting up to {@code timeout} for one.
     *
     * @return the head element, or {@code null} if the timeout elapsed with the queue still empty
     */
    public T poll(long timeout, TimeUnit unit) {
        lock.lock();
        boolean interrupted = false;
        try {
            long nanos = unit.toNanos(timeout);
            final long deadline = System.nanoTime() + nanos;
            while (queue.isEmpty()) {
                if (nanos <= 0) {
                    return null;
                }
                try {
                    nanos = consumerCondition.awaitNanos(nanos);
                } catch (InterruptedException e) {
                    // Remember the interrupt, but keep the original deadline instead of
                    // restarting the wait with a stale remaining time.
                    interrupted = true;
                    nanos = deadline - System.nanoTime();
                }
            }
            T t = queue.removeFirst();
            producerCondition.signal();
            return t;
        } finally {
            lock.unlock();
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Appends {@code task}, waiting as long as necessary for free space. */
    public void put(T task) {
        lock.lock();
        try {
            while (queue.size() >= capacity) {
                // Logged before waiting, so the message appears while the caller is blocked.
                log.debug("等待任务加入队列{}", task);
                // Keeps waiting if interrupted and leaves the interrupt flag set on return.
                producerCondition.awaitUninterruptibly();
            }
            log.debug("加入任务队列{}", task);
            queue.addLast(task);
            consumerCondition.signal();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Appends {@code task}, waiting up to {@code timeout} for free space.
     *
     * @return {@code true} if the task was added, {@code false} if the timeout elapsed first
     */
    public boolean offer(T task, long timeout, TimeUnit timeUnit) {
        lock.lock();
        boolean interrupted = false;
        try {
            long nanos = timeUnit.toNanos(timeout);
            final long deadline = System.nanoTime() + nanos;
            while (queue.size() >= capacity) {
                if (nanos <= 0) {
                    return false;
                }
                log.debug("等待加入队列{}", task);
                try {
                    nanos = producerCondition.awaitNanos(nanos);
                } catch (InterruptedException e) {
                    // Same deadline-preserving handling as in poll().
                    interrupted = true;
                    nanos = deadline - System.nanoTime();
                }
            }
            log.debug("加入任务队列{}", task);
            queue.addLast(task);
            consumerCondition.signal();
            return true;
        } finally {
            lock.unlock();
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Returns the current number of queued elements. */
    public int size() {
        lock.lock();
        try {
            return queue.size();
        } finally {
            lock.unlock();
        }
    }
}
/**
 * Minimal thread pool built on the custom {@code BlockingQueue}.
 *
 * <p>{@link #execute} starts a new {@link Worker} for each task until {@code coreSize} workers
 * exist; after that, tasks are handed to the task queue via {@code tryPut}. A worker retires once
 * a queue poll has waited {@code timeout} without getting a task and the queue is still empty.
 *
 * <p>NOTE(review): {@code tryPut} and {@code RejectPolicy} are not defined in the code shown here.
 * Presumably {@code tryPut} enqueues the task, or hands it to {@code rejectPolicy} when the queue
 * is full. Confirm this against their definitions.
 */
@Slf4j
class ThreadPool {
    // task queue shared by all workers
    private final BlockingQueue<Runnable> taskQueue;
    // live workers; its monitor also serialises execute() against worker retirement
    private final HashSet<Worker> workers = new HashSet<>();
    // maximum number of workers
    private final int coreSize;
    // how long an idle worker waits for a task before trying to retire
    private final long timeout;
    private final TimeUnit timeUnit;
    // passed to tryPut for tasks submitted once all workers exist
    private final RejectPolicy<Runnable> rejectPolicy;

    public ThreadPool(int coreSize, long timeout, TimeUnit timeUnit, RejectPolicy<Runnable> rejectPolicy, int queueCapacity) {
        this.coreSize = coreSize;
        this.timeout = timeout;
        this.timeUnit = timeUnit;
        this.rejectPolicy = rejectPolicy;
        this.taskQueue = new BlockingQueue<>(queueCapacity);
    }

    /** Runs {@code task} on a new worker if below {@code coreSize}, otherwise queues it via tryPut. */
    public void execute(Runnable task) {
        synchronized (workers) {
            if (workers.size() < coreSize) {
                Worker worker = new Worker(task);
                log.debug("新增worker{}", worker);
                workers.add(worker);
                worker.start();
            } else {
                taskQueue.tryPut(rejectPolicy, task);
            }
        }
    }

    class Worker extends Thread {
        // the initial task; afterwards holds the task currently being run, or null
        private Runnable task;

        public Worker(Runnable task) {
            this.task = task;
        }

        @Override
        public void run() {
            while (true) {
                // Run the initial task, then keep taking queued tasks until a poll times out.
                while (task != null || (task = taskQueue.poll(timeout, timeUnit)) != null) {
                    try {
                        log.debug("正在执行{}", task);
                        task.run();
                    } catch (Exception e) {
                        e.printStackTrace();
                    } finally {
                        task = null;
                    }
                }
                synchronized (workers) {
                    // execute() calls tryPut while holding this monitor. Re-checking the queue here
                    // closes the window where a task queued after our last poll would be stranded
                    // with no worker left to run it. Retire only if the queue is still empty.
                    if (taskQueue.size() == 0) {
                        log.debug("worker被移除{}", this);
                        workers.remove(this);
                        return;
                    }
                }
            }
        }
    }
}