Cargo.toml:
[package]
name = "tokio_usecase"
version = "0.1.0"
edition = "2021"
[dependencies]
tokio = { version = "1.25.0", features = ["full"] }
async-trait = "0.1.64"
thiserror = "1.0.35"
src/main.rs:
xxxxxxxxxx
use std::{collections::HashMap};
/// 任务管理器可能引发的错误
pub enum TaskManagerError {
/// 在等待所有任务完成时,Tokio JoinSet 可能发生的错误
JoinError( tokio::task::JoinError),
/// 在关闭任务管理器时,Tokio Broadcast Channel 可能出现的发送错误
SendError( tokio::sync::broadcast::error::SendError<()>),
}
/// 每个 Task 实例代表一个异步任务
pub trait Task: Send {
/// 关联类型 R 是 start 方法的返回值
type R: Send;
/// 启动任务。在任务中,可以通过 shutdown_rx 监听退出状态
async fn start(mut self, mut shutdown_rx: tokio::sync::broadcast::Receiver<()>) -> Self::R;
}
/// 任务 ID
pub type TaskID = u64;
/// 任务管理器用于同时管理多个异步任务
pub struct TaskManager<T: Send + 'static> {
/// 通过 Tokio JoinSet 启动所有任务
join_set: tokio::task::JoinSet<()>,
/// 用于通知 Tokio JoinSet 中的所有任务,任务管理器正在关闭
shutdown_tx: tokio::sync::broadcast::Sender<()>,
/// 为任务分配的最大 ID
current_id: TaskID,
/// 保存任务 ID 与任务返回值之间的映射。
/// 任务完成时,将返回值发送到无界的 MPSC 管道中。
/// TaskManager 按需将管道中的数据读取到哈希表中。如此做可以避免维护单独的消费任务带来的复杂性。
/// 使用管道的原因是使任务的返回类型不必是 Sync 的。如果使用带锁的哈希表,那么任务的返回类型必须是 Sync 的。
id_to_return_value: HashMap<TaskID, T>,
id_to_return_value_tx: tokio::sync::mpsc::UnboundedSender<(TaskID, T)>,
id_to_return_value_rx: tokio::sync::mpsc::UnboundedReceiver<(TaskID, T)>,
/// 保存任务 ID 与 Tokio AbortHandle 之间的映射。AbortHandle 可以用于判断任务是否已经完成、中止任务等
id_to_abort_handle: HashMap<TaskID, tokio::task::AbortHandle>,
/// 当该字段达到特定的阈值时,将收缩哈希表占用的空间,以防内存溢出,同时将该字段重置为 0
consumed_count: u64,
}
impl<T: Send + 'static> Drop for TaskManager<T> {
fn drop(&mut self) {
self.id_to_return_value.clear();
self.id_to_return_value.shrink_to_fit();
self.id_to_return_value_rx.close();
while let Ok(_) = self.id_to_return_value_rx.try_recv() {}
self.id_to_abort_handle.clear();
self.id_to_abort_handle.shrink_to_fit();
}
}
impl<T: Send + 'static> TaskManager<T> {
/// 创建任务管理器实例
pub fn new() -> Self {
let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);
let (id_to_return_value_tx, id_to_return_value_rx) = tokio::sync::mpsc::unbounded_channel();
Self {
join_set: tokio::task::JoinSet::new(),
shutdown_tx,
current_id: 0,
id_to_return_value: HashMap::new(),
id_to_return_value_tx,
id_to_return_value_rx,
id_to_abort_handle: HashMap::new(),
consumed_count: 0,
}
}
/// 启动任务,返回任务 ID
pub fn spawn<F>(&mut self, task: F) -> TaskID where F: 'static + Task<R=T> {
let shutdown_rx = self.shutdown_tx.subscribe();
let task_id = self.current_id + 1;
let id_to_return_value_tx = self.id_to_return_value_tx.clone();
let abort_handle = self.join_set.spawn(async move {
let return_value = task.start(shutdown_rx).await;
if let Err(_) = id_to_return_value_tx.send((task_id, return_value)) {
// 接收已关闭
}
drop(id_to_return_value_tx);
});
self.current_id = task_id;
self.id_to_abort_handle.insert(task_id, abort_handle);
task_id
}
/// 从管道中读取任务的返回值,存储到内部的哈希表中
fn consume_return_value_channel(&mut self) {
loop {
if let Ok((task_id, return_value)) = self.id_to_return_value_rx.try_recv() {
self.id_to_return_value.insert(task_id, return_value);
} else {
break;
}
}
}
/// 等待所有任务完成
pub async fn join(&mut self) -> Result<(), TaskManagerError> {
loop {
match self.join_set.join_next().await {
// 所有任务都已完成
None => {
self.consume_return_value_channel();
break Ok(());
}
Some(join_result) => {
if let Err(e) = join_result {
self.consume_return_value_channel();
break Err(TaskManagerError::JoinError(e));
}
}
}
}
}
/// 关闭任务管理器。对于任意任务而言,如果达到指定的超时时间,任务仍然没有完成,那么中止它
pub async fn shutdown(&mut self, timeout: u64) -> Result<(), TaskManagerError> {
// 发送关闭信号
if let Err(_) = self.shutdown_tx.send(()) {
// 只有当广播队列没有消费者时,才会报错。这意味着不存在任何任务,所以返回成功
return Ok(());
}
let sleep_fut = tokio::time::sleep(tokio::time::Duration::from_secs(timeout));
let join_fut = self.join();
tokio::select! {
// 如果达到超时时间,那么中止所有未完成的任务
_ = sleep_fut => {
for (_, abort_handle) in self.id_to_abort_handle.iter() {
if !abort_handle.is_finished() {
abort_handle.abort();
}
}
self.consume_return_value_channel();
Ok(())
},
// 在指定的时间内,所有任务已经完成
_ = join_fut => {
Ok(())
},
}
}
/// 使用任务 ID 获取任务的返回值。如果任务不存在,或返回值已经被取走,或尚未完成,那么返回 None。
/// 使用与该方法类似的方式,实现使用任务 ID 取消任务之类的功能。
/// 务必择机收缩内部保存状态的哈希表,以防在启动大量任务时,内存溢出
pub fn get_return_value(&mut self, task_id: &TaskID) -> Option<T> {
if let Some(abort_handle) = self.id_to_abort_handle.get(task_id) {
// 任务已经完成
if abort_handle.is_finished() {
// 保证任务的返回值一定在哈希表中
self.consume_return_value_channel();
self.consumed_count += 1;
self.id_to_abort_handle.remove(task_id);
// 除非两个哈希表的状态不一致,否则 unwrap 不会导致 Panic
let return_value = self.id_to_return_value.remove(task_id).unwrap();
// 判断是否需要收缩哈希表的存储空间。这里暂时硬编码成每被消费 1000 次,收缩 1 次
if self.consumed_count >= 1000 {
self.id_to_abort_handle.shrink_to_fit();
self.id_to_return_value.shrink_to_fit();
self.consumed_count = 0;
}
return Some(return_value);
}
}
None
}
}
// 下面是测试代码
struct AsyncSleepTask {
// 休眠时间
duration: tokio::time::Duration,
// 任务名称
name: String,
}
impl AsyncSleepTask {
pub fn new(sleep_seconds: u64, name: String) -> Self {
Self {
duration: tokio::time::Duration::from_secs(sleep_seconds),
name,
}
}
}
impl Task for AsyncSleepTask {
type R = Option<()>;
async fn start(mut self, mut shutdown_rx: tokio::sync::broadcast::Receiver<()>) -> Self::R {
println!("`{}` begins to do work, it takes about {} seconds", self.name, self.duration.as_secs());
tokio::select! {
// 使用休眠模拟耗时任务
_ = tokio::time::sleep(self.duration) => {
println!("`{}` has finished work", self.name);
}
_ = shutdown_rx.recv() => {
println!("`{}` has got signal", self.name);
println!("`{}` will do cleaning work", self.name);
}
}
// 模拟清理工作
let cleaning_up_seconds = 1u64;
println!(
"`{}` begins to clean up all resources, it takes about {} seconds",
self.name, cleaning_up_seconds
);
tokio::time::sleep(tokio::time::Duration::from_secs(cleaning_up_seconds)).await;
println!("`{}` has cleaned up all resources, exiting", self.name);
Some(())
}
}
async fn main() {
let mut task_manager = TaskManager::new();
let mut task_id_list = Vec::new();
for idx in 0..1000 {
let task = AsyncSleepTask::new(idx % 10, format!("task_{}", idx));
let task_id = task_manager.spawn(task);
task_id_list.push(task_id);
}
task_manager.shutdown(10).await.unwrap();
for task_id in task_id_list.iter() {
task_manager.get_return_value(task_id).unwrap();
}
}
在多个并发分支上等待,当第一个分支完成时返回,取消剩余分支。
必须在 async 函数、闭包、块的内部使用 select! 宏。
select! 宏接受具有如下模式的一或多个分支:
<pattern> = <async expression> (, if <precondition>)? => <handler>,
另外,select! 宏可以包含单个可选的 else 分支,如果其它分支都不匹配它们的模式,那么计算该分支:
else => <expression>
该宏聚合所有 <async expression> 表达式,然后在当前任务中,并发地运行它们。一旦第一个表达式完成,并且其值匹配它的 <pattern>,select! 宏返回计算已完成分支的 <handler> 表达式的结果。
另外,每个分支可以包含可选的 if 前置条件。如果前置条件返回 false,那么禁用分支。仍然计算 <async expression>,但是不轮询结果 Future。当在循环中使用 select! 时,该功能很有用。
select! 表达式的完整生命周期如下所示:
通过在当前任务上运行所有异步表达式的方式,能够并发地,而非并行地运行表达式。这意味着在相同线程上运行所有表达式,如果一个分支阻塞线程,那么所有其它表达式都将不能继续。如果需要并行,那么使用 tokio::spawn
启动每个异步表达式,然后将 JoinHandle 传递给 select!。
默认情况下,select! 首先随机地选择分支检查。当在循环中调用 select!,并且拥有已经就绪的分支时,这样提供一定程度的公平性。
通过将 biased; 添加到宏使用的开头的方式,重写该行为。查看下面的示例,获取细节。这将导致 select 按照 Future 从上到下出现的顺序轮询它们。这样做的原因包括:
但是使用 biased 模式时,有一个重要的注意事项。你需要确保 Future 的轮询顺序是公平的。比如,如果在流和关闭 Future 之间 select,而该流拥有大量消息,并且它们之间的时间几乎为 0,那么应该将关闭 Future 放到 select! 列表的前面,以确保它始终被轮询,不会因为流不断地准备就绪而被忽略。
如果所有分支都被禁用,并且未提供 else 分支,那么 select! 宏将 Panic。当提供的 if 前置条件返回 false ,或者模式不匹配 <async expression> 的结果时,分支被禁用。
当使用 select! 宏循环接收来自多个源的消息时,应该确保接收调用是取消安全的,以避免丢失消息。本节将介绍各种常见的方法,并且描述它们是否具备取消安全性。这里列出的列表并不全面,仅供参考。
如下方法是取消安全的:
以下方法不具备取消安全性,可能导致数据丢失:
以下方法不具备取消安全性,因为它们使用队列保证公平性,而取消操作将使你在队列中丢失位置:
为确定自己的方法是否具备取消安全性,请查看使用 .await 的位置。这是因为当异步方法被取消时,始终发生在 .await。如果你的函数在等待 .await 时,即使被重启,也能正确运行,那么它是取消安全的。
可以用如下方式定义取消安全性:如果你有一个尚未完成的 Future,那么删除并重新创建该 Future 必须是空操作。该定义基于在循环中使用 select! 的情况。如无此保证,当另一个分支完成,并且通过绕过循环的方式,重启 select!,那么将失去进度。
注意,取消不具备取消安全性的操作不一定是错误的。比如,如果取消任务是因为应用程序正在关闭,那么可能不关心部分读取数据丢失。
带两个分支的基础 select:
async fn do_stuff_async() {
// async work
}
async fn more_async_work() {
// more here
}
async fn main() {
tokio::select! {
_ = do_stuff_async() => {
println!("do_stuff_async() completed first")
}
_ = more_async_work() => {
println!("more_async_work() completed first")
}
};
}
基础 Stream select:
use tokio_stream::{self as stream, StreamExt};
async fn main() {
let mut stream1 = stream::iter(vec![1, 2, 3]);
let mut stream2 = stream::iter(vec![4, 5, 6]);
let next = tokio::select! {
v = stream1.next() => v.unwrap(),
v = stream2.next() => v.unwrap(),
};
assert!(next == 1 || next == 4);
}
收集两个流的内容。在该示例中,我们依赖模式匹配,以及 stream::iter 是 "fused" 的事实,即一旦流完成,所有对 next() 的调用返回 None:
use tokio_stream::{self as stream, StreamExt};
async fn main() {
let mut stream1 = stream::iter(vec![1, 2, 3]);
let mut stream2 = stream::iter(vec![4, 5, 6]);
let mut values = vec![];
loop {
tokio::select! {
Some(v) = stream1.next() => values.push(v),
Some(v) = stream2.next() => values.push(v),
else => break,
}
}
values.sort();
assert_eq!(&[1, 2, 3, 4, 5, 6], &values[..]);
}
通过传递 Future 引用的方式,在多个 select! 表达式中使用相同的 Future。这样做需要 Future 是 Unpin 的。通过使用 Box::pin 或栈固定(stack pinning)的方式,使 Future 变成 Unpin 的。
在下面的例子中,流最多被消费 1 秒钟的时间:
use tokio_stream::{self as stream, StreamExt};
use tokio::time::{self, Duration};
async fn main() {
let mut stream = stream::iter(vec![1, 2, 3]);
let sleep = time::sleep(Duration::from_secs(1));
tokio::pin!(sleep);
loop {
tokio::select! {
maybe_v = stream.next() => {
if let Some(v) = maybe_v {
println!("got = {}", v);
} else {
break;
}
}
_ = &mut sleep => {
println!("timeout");
break;
}
}
}
}
使用 select! 连接两个值:
xxxxxxxxxx
use tokio::sync::oneshot;
async fn main() {
let (tx1, mut rx1) = oneshot::channel();
let (tx2, mut rx2) = oneshot::channel();
tokio::spawn(async move {
tx1.send("first").unwrap();
});
tokio::spawn(async move {
tx2.send("second").unwrap();
});
let mut a = None;
let mut b = None;
while a.is_none() || b.is_none() {
tokio::select! {
v1 = (&mut rx1), if a.is_none() => a = Some(v1.unwrap()),
v2 = (&mut rx2), if b.is_none() => b = Some(v2.unwrap()),
}
}
let res = (a.unwrap(), b.unwrap());
assert_eq!(res.0, "first");
assert_eq!(res.1, "second");
}
使用 biased; 模式控制轮询顺序:
xxxxxxxxxx
async fn main() {
let mut count = 0u8;
loop {
tokio::select! {
// If you run this example without `biased;`, the polling order is
// pseudo-random, and the assertions on the value of count will
// (probably) fail.
biased;
_ = async {}, if count < 1 => {
count += 1;
assert_eq!(count, 1);
}
_ = async {}, if count < 2 => {
count += 1;
assert_eq!(count, 2);
}
_ = async {}, if count < 3 => {
count += 1;
assert_eq!(count, 3);
}
_ = async {}, if count < 4 => {
count += 1;
assert_eq!(count, 4);
}
else => {
break;
}
};
}
}
考虑到 if 前置条件用于禁用 select! 分支,必须谨慎使用,以避免丢失值。
比如,下面是带有 if 的 sleep 的不正确用法。目标是重复地运行异步任务,最多运行 50 毫秒。然而,存在错过 sleep 完成的可能性。
xxxxxxxxxx
use tokio::time::{self, Duration};
async fn some_async_work() {
// do work
}
async fn main() {
let sleep = time::sleep(Duration::from_millis(50));
tokio::pin!(sleep);
while !sleep.is_elapsed() {
tokio::select! {
_ = &mut sleep, if !sleep.is_elapsed() => {
println!("operation timed out");
}
_ = some_async_work() => {
println!("operation completed");
}
}
}
panic!("This example shows how not to do it!");
}
在上面的示例中,即使 sleep.poll() 从未返回 Ready,sleep.is_elapsed() 可能返回 true。这将引起潜在的竞态条件,即在 while !sleep.is_elapsed() 检查和调用 select! 之间,sleep 到期,将导致 some_async_work() 调用在睡眠时间已过期的情况下仍然运行,而不被中断。
以下是一种避免竞态条件的重写示例:
xxxxxxxxxx
use tokio::time::{self, Duration};
async fn some_async_work() {
// do work
}
async fn main() {
let sleep = time::sleep(Duration::from_millis(50));
tokio::pin!(sleep);
loop {
tokio::select! {
_ = &mut sleep => {
println!("operation timed out");
break;
}
_ = some_async_work() => {
println!("operation completed");
}
}
}
}