utils: add rolling_avg

Signed-off-by: Yuxuan Shui <yshuiv7@gmail.com>
This commit is contained in:
Yuxuan Shui 2022-12-13 08:39:11 +00:00
parent 91e023971d
commit 4594168946
No known key found for this signature in database
GPG Key ID: D3A4405BE6CC17F4
2 changed files with 86 additions and 0 deletions

View File

@ -164,4 +164,83 @@ TEST_CASE(rolling_max_test) {
#undef NELEM
}
/// A rolling average of a stream of integers.
struct rolling_avg {
/// The sum of the elements in the window.
int64_t sum;
/// The elements in the window.
int *elem;
int head, nelem;
int window_size;
};
struct rolling_avg *rolling_avg_new(int size) {
auto rm = ccalloc(1, struct rolling_avg);
if (!rm) {
return NULL;
}
rm->elem = ccalloc(size, int);
rm->window_size = size;
if (!rm->elem) {
free(rm);
return NULL;
}
return rm;
}
void rolling_avg_destroy(struct rolling_avg *rm) {
free(rm->elem);
free(rm);
}
void rolling_avg_reset(struct rolling_avg *ra) {
ra->sum = 0;
ra->nelem = 0;
ra->head = 0;
}
void rolling_avg_push(struct rolling_avg *ra, int val) {
if (ra->nelem == ra->window_size) {
// Discard the oldest element.
// rm->elem.pop_front();
ra->sum -= ra->elem[ra->head % ra->window_size];
ra->nelem--;
ra->head = (ra->head + 1) % ra->window_size;
}
// Add the new element to the queue.
// rm->elem.push_back(val)
ra->elem[(ra->head + ra->nelem) % ra->window_size] = val;
ra->sum += val;
ra->nelem++;
}
double rolling_avg_get_avg(struct rolling_avg *ra) {
if (ra->nelem == 0) {
return 0;
}
return (double)ra->sum / (double)ra->nelem;
}
TEST_CASE(rolling_avg_test) {
#define NELEM 15
auto rm = rolling_avg_new(3);
const int data[NELEM] = {1, 2, 3, 1, 4, 5, 2, 3, 6, 5, 4, 3, 2, 0, 0};
const double expected_avg[NELEM] = {
1, 1.5, 2, 2, 8.0 / 3.0, 10.0 / 3.0, 11.0 / 3.0, 10.0 / 3.0,
11.0 / 3.0, 14.0 / 3.0, 5, 4, 3, 5.0 / 3.0, 2.0 / 3.0};
double avg[NELEM] = {0};
for (int i = 0; i < NELEM; i++) {
rolling_avg_push(rm, data[i]);
avg[i] = rolling_avg_get_avg(rm);
}
for (int i = 0; i < NELEM; i++) {
TEST_EQUAL(avg[i], expected_avg[i]);
}
}
// vim: set noet sw=8 ts=8 :

View File

@ -297,6 +297,13 @@ void rolling_max_reset(struct rolling_max *rm);
void rolling_max_push(struct rolling_max *rm, int val);
int rolling_max_get_max(struct rolling_max *rm);
struct rolling_avg;
struct rolling_avg *rolling_avg_new(int window_size);
void rolling_avg_free(struct rolling_avg *ra);
void rolling_avg_reset(struct rolling_avg *ra);
void rolling_avg_push(struct rolling_avg *ra, int val);
double rolling_avg_get_avg(struct rolling_avg *ra);
// Some versions of the Android libc do not have timespec_get(), use
// clock_gettime() instead.
#ifdef __ANDROID__