Last active
October 4, 2023 22:56
-
-
Save sergera/d36c9ef42b8a8703a566ef0870612e95 to your computer and use it in GitHub Desktop.
Generic rate limiter fed through a channel and meant to run in its own task or thread; it makes request throttling easy.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::VecDeque; | |
use std::future::Future; | |
use std::pin::Pin; | |
use eyre::*; | |
use tokio::sync::mpsc::UnboundedReceiver; | |
use tokio::time::{Duration, Instant}; | |
use tracing::*; | |
pub enum ThrottledQueueCallbackType<T> { | |
Sync(Box<dyn Fn(T) -> Result<()> + Send + Sync>), | |
Async(Box<dyn Fn(T) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>), | |
} | |
/*********** | |
* ThrottledQueueRunner: provides a controlled, rate-limited runner for managing queued operations. | |
* | |
* This runner provides an interface for external threads to communicate with it using channels. | |
* | |
* WARNING: | |
* - DO NOT access the client concurrently. Only ONE task/thread should run the client at a time. | |
* - Multiple threads can SAFELY send alerts using the provided channel, but the client itself should not be accessed or run concurrently. | |
* | |
*/ | |
pub struct ThrottledQueueRunner<T> { | |
queue_input_rx: UnboundedReceiver<T>, | |
queue: ThrottledQueueProcessor<T>, | |
} | |
impl<T: Send + 'static> ThrottledQueueRunner<T> { | |
pub fn new( | |
queue_input_rx: UnboundedReceiver<T>, | |
limit: u32, | |
period: Duration, | |
process_callback: ThrottledQueueCallbackType<T>, | |
) -> Self { | |
Self { | |
queue_input_rx, | |
queue: ThrottledQueueProcessor::new(process_callback, limit, period), | |
} | |
} | |
pub async fn run(&mut self) -> Result<()> { | |
loop { | |
match self.queue_input_rx.try_recv() { | |
Ok(item) => { | |
/* if we receive an item from the channel, push it in the queue */ | |
self.queue.push(item); | |
} | |
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => { | |
/* the channel is closed, process remaining items in the queue and exit */ | |
self.queue.process_until_empty().await; | |
break; | |
} | |
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => { | |
if !self.queue.empty() { | |
/* if no item is available, process the queue */ | |
self.queue.process_n_periods(1).await; | |
} else { | |
/* if no item available and nothing to process, wait until next item available */ | |
match self.queue_input_rx.recv().await { | |
Some(item) => { | |
/* when an item becomes available again, push into the queue and continue the loop */ | |
self.queue.push(item); | |
} | |
None => { | |
/* if the channel was closed while waiting, process remaining items in the queue and exit */ | |
self.queue.process_until_empty().await; | |
break; | |
} | |
} | |
} | |
} | |
} | |
} | |
Ok(()) | |
} | |
} | |
/*********** | |
* ThrottledQueueProcessor: provides a controlled, rate-limited queue processor. | |
* | |
* WARNING: | |
* - This struct is NOT thread-safe or designed for concurrent task/thread access. | |
* - It maintains internal mutable state, and if accessed concurrently it can lead to race conditions, undefined behavior, or panics. | |
* - Always ensure that you only access and mutate this struct from a single task/thread at a time. | |
* - If concurrent access is desired in the future, it would require a redesign with proper synchronization mechanisms. | |
* | |
*/ | |
pub struct ThrottledQueueProcessor<T> { | |
queue: VecDeque<T>, | |
process: ThrottledQueueCallbackType<T>, | |
limit: u32, | |
period: Duration, | |
last_reset: Instant, | |
processed_in_current_period: u32, | |
} | |
impl<T> ThrottledQueueProcessor<T> | |
where | |
T: Send, | |
{ | |
pub fn new( | |
process_callback: ThrottledQueueCallbackType<T>, | |
limit: u32, | |
period: Duration, | |
) -> Self { | |
Self { | |
queue: VecDeque::new(), | |
process: process_callback, | |
limit, | |
period, | |
last_reset: Instant::now(), | |
processed_in_current_period: 0, | |
} | |
} | |
pub fn new_sync( | |
process_callback: impl Fn(T) -> Result<()> + Send + Sync + 'static, | |
limit: u32, | |
period: Duration, | |
) -> Self { | |
Self { | |
queue: VecDeque::new(), | |
process: ThrottledQueueCallbackType::Sync(Box::new(process_callback)), | |
limit, | |
period, | |
last_reset: Instant::now(), | |
processed_in_current_period: 0, | |
} | |
} | |
pub fn new_async( | |
process_callback: impl Fn(T) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> | |
+ Send | |
+ Sync | |
+ 'static, | |
limit: u32, | |
period: Duration, | |
) -> Self { | |
Self { | |
queue: VecDeque::new(), | |
process: ThrottledQueueCallbackType::Async(Box::new(process_callback)), | |
limit, | |
period, | |
last_reset: Instant::now(), | |
processed_in_current_period: 0, | |
} | |
} | |
pub fn empty(&self) -> bool { | |
self.queue.is_empty() | |
} | |
pub fn push(&mut self, item: T) { | |
self.queue.push_back(item); | |
} | |
pub async fn process_n_periods(&mut self, n: u32) { | |
let mut periods_elapsed = 0; | |
while periods_elapsed < n { | |
self.check_and_reset_period(); | |
/* process as long as it's possible and there are items in the queue */ | |
while self.can_process() && !self.queue.is_empty() { | |
let item = self.queue.pop_front().unwrap(); | |
match &self.process { | |
ThrottledQueueCallbackType::Sync(process) => { | |
match process(item) { | |
Ok(_) => {} | |
Err(e) => { | |
error!("error processing item: {}", e); | |
} | |
}; | |
self.update_processed(); | |
} | |
ThrottledQueueCallbackType::Async(process) => { | |
match process(item).await { | |
Ok(_) => {} | |
Err(e) => { | |
error!("error processing item: {}", e); | |
} | |
}; | |
self.update_processed(); | |
} | |
} | |
} | |
if !self.can_process() { | |
/* if the rate limit is hit, calculate the remaining time in the current period and sleep */ | |
let time_to_wait = self.period.saturating_sub(self.last_reset.elapsed()); | |
tokio::time::sleep(time_to_wait).await; | |
} else { | |
/* if can process and there are no items in the queue, yield control */ | |
break; | |
} | |
periods_elapsed += 1; | |
} | |
} | |
pub async fn process_until_empty(&mut self) { | |
while !self.queue.is_empty() { | |
self.check_and_reset_period(); | |
/* process as long as it's possible and there are items in the queue */ | |
while self.can_process() && !self.queue.is_empty() { | |
let item = self.queue.pop_front().unwrap(); | |
match &self.process { | |
ThrottledQueueCallbackType::Sync(process) => { | |
match process(item) { | |
Ok(_) => {} | |
Err(e) => { | |
error!("error processing item: {}", e); | |
} | |
}; | |
self.update_processed(); | |
} | |
ThrottledQueueCallbackType::Async(process) => { | |
match process(item).await { | |
Ok(_) => {} | |
Err(e) => { | |
error!("error processing item: {}", e); | |
} | |
}; | |
self.update_processed(); | |
} | |
} | |
} | |
if !self.can_process() { | |
/* if the rate limit is hit, calculate the remaining time in the current period and sleep */ | |
let time_to_wait = self.period.saturating_sub(self.last_reset.elapsed()); | |
tokio::time::sleep(time_to_wait).await; | |
} | |
} | |
} | |
fn can_process(&mut self) -> bool { | |
self.processed_in_current_period < self.limit | |
} | |
fn check_and_reset_period(&mut self) -> () { | |
if self.last_reset.elapsed() > self.period { | |
self.reset_period(); | |
} | |
} | |
fn reset_period(&mut self) { | |
self.last_reset = Instant::now(); | |
self.processed_in_current_period = 0; | |
} | |
fn update_processed(&mut self) { | |
self.processed_in_current_period += 1; | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::time::Instant;
    use tokio::time::Duration;

    /// Twenty items with a limit of 10 per 10 ms window must take at least two full windows.
    #[tokio::test]
    async fn test_rate_limiting() {
        /* processor with a counting callback, a limit of 10 items and a 10 millisecond period */
        let processed_count = Arc::new(Mutex::new(0usize));
        let count_clone = processed_count.clone();
        let mut processor = ThrottledQueueProcessor::new_sync(
            move |_| {
                let mut count = count_clone.lock().unwrap();
                *count += 1;
                Ok(())
            },
            10,
            Duration::from_millis(10),
        );
        /* add 20 items to the queue */
        for _ in 0..20 {
            processor.push(());
        }
        /* process 2 periods */
        let start = Instant::now();
        processor.process_n_periods(2).await;
        /* two rate-limited windows must take 20 milliseconds or more */
        let elapsed_time = start.elapsed();
        assert!(elapsed_time >= Duration::from_millis(20));
        /* check all items were processed */
        assert_eq!(*processed_count.lock().unwrap(), 20);
    }

    /// The async-callback path processes every queued item when draining.
    #[tokio::test]
    async fn test_process_until_empty_async_callback() {
        let processed_count = Arc::new(Mutex::new(0usize));
        let count_clone = processed_count.clone();
        let mut processor = ThrottledQueueProcessor::new_async(
            move |_: ()| -> std::pin::Pin<
                Box<dyn std::future::Future<Output = eyre::Result<()>> + Send>,
            > {
                let counter = count_clone.clone();
                Box::pin(async move {
                    *counter.lock().unwrap() += 1;
                    Ok::<(), eyre::Report>(())
                })
            },
            3,
            Duration::from_millis(5),
        );
        for _ in 0..7 {
            processor.push(());
        }
        processor.process_until_empty().await;
        assert!(processor.empty());
        assert_eq!(*processed_count.lock().unwrap(), 7);
    }

    /// Items waiting in the channel are all processed, in order, once every sender is dropped.
    #[tokio::test]
    async fn test_runner_drains_queue_on_channel_close() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<u32>();
        let processed = Arc::new(Mutex::new(Vec::new()));
        let sink = processed.clone();
        let mut runner = ThrottledQueueRunner::new(
            rx,
            2,
            Duration::from_millis(5),
            ThrottledQueueCallbackType::Sync(Box::new(move |item: u32| -> eyre::Result<()> {
                sink.lock().unwrap().push(item);
                Ok(())
            })),
        );
        for item in 0..5 {
            tx.send(item).unwrap();
        }
        drop(tx);
        runner.run().await.unwrap();
        assert_eq!(*processed.lock().unwrap(), vec![0, 1, 2, 3, 4]);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment