diff --git a/.gitignore b/.gitignore index 126d134..eda0e13 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,4 @@ /examples/raw_receiver /examples/raw_sender /tests/test_main +/tests/test_prod_reaches_buffer_end diff --git a/Makefile.am b/Makefile.am index ca1c50a..f0c4542 100644 --- a/Makefile.am +++ b/Makefile.am @@ -9,7 +9,8 @@ AM_CFLAGS = \ lib_LIBRARIES = libshmemq.a TESTS = \ - tests/test_main + tests/test_main \ + tests/test_prod_reaches_buffer_end noinst_PROGRAMS = \ $(TESTS) \ @@ -29,3 +30,7 @@ examples_raw_sender_SOURCES = \ tests_test_main_SOURCES = \ $(libshmemq_a_SOURCES) \ tests/test_main.c + +tests_test_prod_reaches_buffer_end_SOURCES = \ + $(libshmemq_a_SOURCES) \ + tests/test_prod_reaches_buffer_end.c diff --git a/include/shmemq.h b/include/shmemq.h index 964f0ee..c9e5b25 100644 --- a/include/shmemq.h +++ b/include/shmemq.h @@ -28,6 +28,7 @@ typedef enum ShmemqError { // Bugs in user code. SHMEMQ_ERROR_BUG_POP_END_ON_EMPTY_QUEUE = 50, + SHMEMQ_ERROR_BUG_PUSH_END_ON_FULL_QUEUE = 51, // Failed system calls. SHMEMQ_ERROR_FAILED_MALLOC = 100, diff --git a/src/main.c b/src/main.c index a89ac7b..6e8be0a 100644 --- a/src/main.c +++ b/src/main.c @@ -144,6 +144,13 @@ void shmemq_init( ShmemqFrame shmemq_push_start(const Shmemq shmemq) { + if ( + shmemq->buffer->header.write_frame_index >= + shmemq->buffer->header.frames_count + ) { + return NULL; + } + const ShmemqFrame low_frame = &shmemq->buffer->frames[0]; const ShmemqFrame high_frame = &shmemq->buffer->frames[ shmemq->buffer->header.write_frame_index @@ -167,6 +174,14 @@ void shmemq_push_end( ) { if (error_ptr) *error_ptr = SHMEMQ_ERROR_NONE; + if ( + shmemq->buffer->header.write_frame_index >= + shmemq->buffer->header.frames_count + ) { + if (error_ptr) *error_ptr = SHMEMQ_ERROR_BUG_PUSH_END_ON_FULL_QUEUE; + return; + } + const ShmemqFrame frame = &shmemq->buffer->frames[shmemq->buffer->header.write_frame_index]; @@ -186,7 +201,10 @@ void shmemq_push_end( shmemq->buffer->header.write_frame_index + frame->header.message_frames_count; - if (new_write_frame_index >= shmemq->buffer->header.frames_count) { + if ( + new_write_frame_index >= shmemq->buffer->header.frames_count && + shmemq->buffer->header.read_frame_index > 0 + ) { shmemq->buffer->header.write_frame_index = 0; } else { diff --git a/tests/test_prod_reaches_buffer_end.c b/tests/test_prod_reaches_buffer_end.c new file mode 100644 index 0000000..c1a883a --- /dev/null +++ b/tests/test_prod_reaches_buffer_end.c @@ -0,0 +1,39 @@ +#include + +#include + +static const char name[] = "/foobar"; + +int main() +{ + ShmemqError error; + + const Shmemq consumer = shmemq_new(name, true, &error); + assert(error == SHMEMQ_ERROR_NONE); + + const Shmemq producer = shmemq_new(name, false, &error); + assert(error == SHMEMQ_ERROR_NONE); + + for (unsigned i = 0; i < 100; ++i) { + const ShmemqFrame frame = shmemq_push_start(producer); + assert(frame != NULL); + + *(unsigned*)frame->data = i; + + shmemq_push_end(producer, sizeof(unsigned), &error); + assert(error == SHMEMQ_ERROR_NONE); + } + + assert(shmemq_push_start(producer) == NULL); + + shmemq_push_end(producer, sizeof(unsigned), &error); + assert(error == SHMEMQ_ERROR_BUG_PUSH_END_ON_FULL_QUEUE); + + shmemq_delete(consumer, &error); + assert(error == SHMEMQ_ERROR_NONE); + + shmemq_delete(producer, &error); + assert(error == SHMEMQ_ERROR_NONE); + + return 0; +}