From e153be0be2fb8df6656292daab3fa59963c76737 Mon Sep 17 00:00:00 2001
From: 3gg <3gg@shellblade.net>
Date: Tue, 13 Feb 2024 17:51:51 -0800
Subject: Let memory allocators trap by default when attempting to allocate
 beyond capacity.

---
 CMakeLists.txt              |  1 +
 cassert/CMakeLists.txt      |  8 ++++++++
 cassert/include/cassert.h   | 29 +++++++++++++++++++++++++++++
 mem/CMakeLists.txt          |  1 +
 mem/include/mem.h           | 11 +++++++++--
 mem/src/mem.c               | 12 ++++++++++++
 mem/test/mem_test.c         |  1 +
 mempool/CMakeLists.txt      |  3 +++
 mempool/include/mempool.h   |  9 +++++++++
 mempool/src/mempool.c       | 11 +++++++++++
 mempool/test/mempool_test.c |  1 +
 11 files changed, 85 insertions(+), 2 deletions(-)
 create mode 100644 cassert/CMakeLists.txt
 create mode 100644 cassert/include/cassert.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 868268d..e2206ac 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -2,6 +2,7 @@ cmake_minimum_required(VERSION 3.0)
 
 project(clib)
 
+add_subdirectory(cassert)
 add_subdirectory(cstring)
 add_subdirectory(error)
 add_subdirectory(filesystem)
diff --git a/cassert/CMakeLists.txt b/cassert/CMakeLists.txt
new file mode 100644
index 0000000..855f261
--- /dev/null
+++ b/cassert/CMakeLists.txt
@@ -0,0 +1,8 @@
+cmake_minimum_required(VERSION 3.0)
+
+project(cassert)
+
+add_library(cassert INTERFACE)
+
+target_include_directories(cassert INTERFACE
+  include)
diff --git a/cassert/include/cassert.h b/cassert/include/cassert.h
new file mode 100644
index 0000000..3349b55
--- /dev/null
+++ b/cassert/include/cassert.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#include <assert.h> // For convenience, bring in soft assertions with assert().
+#include <signal.h>
+
+// Allow the client to define their own LOGE() macro.
+#ifndef LOGE
+#include <stdio.h>
+#define LOGE(...)                                            \
+  {                                                          \
+    fprintf(stderr, "[ASSERT] %s:%d: ", __FILE__, __LINE__); \
+    fprintf(stderr, __VA_ARGS__);                            \
+    fprintf(stderr, "\n");                                   \
+  }
+#endif // LOGE
+
+#define TRAP() raise(SIGTRAP)
+
+/// Unconditional hard assert.
+#define FAIL(message) \
+  LOGE(message);      \
+  TRAP();
+
+/// Conditional hard assert.
+#define ASSERT(condition)                 \
+  if (!condition) {                       \
+    LOGE("Assertion failed: " #condition) \
+    TRAP();                               \
+  }
diff --git a/mem/CMakeLists.txt b/mem/CMakeLists.txt
index 233d2be..e4b28c3 100644
--- a/mem/CMakeLists.txt
+++ b/mem/CMakeLists.txt
@@ -11,6 +11,7 @@ target_include_directories(mem PUBLIC
   include)
 
 target_link_libraries(mem
+  cassert
   list)
 
 target_compile_options(mem PRIVATE -Wall -Wextra)
diff --git a/mem/include/mem.h b/mem/include/mem.h
index bcff39f..892ea4f 100644
--- a/mem/include/mem.h
+++ b/mem/include/mem.h
@@ -66,8 +66,10 @@
 #define mem_clear(MEM) mem_clear_(&(MEM)->mem)
 
 /// Allocate a new chunk of N blocks.
-/// Return a pointer to the first block of the chunk, or 0 if there is no memory
-/// left.
+/// Return a pointer to the first block of the chunk.
+/// When there is no space left in the allocator, allocation can either trap
+/// (default) or gracefully return 0. Call mem_enable_traps() to toggle this
+/// behaviour.
 /// New chunks are conveniently zeroed out.
 #define mem_alloc(MEM, num_blocks) mem_alloc_(&(MEM)->mem, num_blocks)
 
@@ -87,6 +89,9 @@
 /// Return the total capacity of the allocator in bytes.
 #define mem_capacity(MEM) mem_capacity_(&(MEM)->mem)
 
+/// Set whether to trap when attempting to allocate beyond capacity.
+#define mem_enable_traps(MEM, enable) mem_enable_traps_(&(MEM)->mem, enable)
+
 /// Iterate over the used chunks of the allocator.
 ///
 /// The caller can use 'i' as the index of the current chunk.
@@ -134,6 +139,7 @@ typedef struct Memory {
   size_t   num_blocks;
   size_t   next_free_chunk;
   bool     dynamic; /// True if blocks and chunks are dynamically-allocated.
+  bool     trap;    /// Whether to trap when allocating beyond capacity.
   Chunk*   chunks;  /// Array of chunk information.
   uint8_t* blocks;  /// Array of blocks;
 } Memory;
@@ -159,3 +165,4 @@ void   mem_free_(Memory*, void** chunk_ptr);
 void*  mem_get_chunk_(const Memory*, size_t chunk_handle);
 size_t mem_get_chunk_handle_(const Memory*, const void* chunk);
 size_t mem_capacity_(const Memory*);
+void   mem_enable_traps_(Memory*, bool);
diff --git a/mem/src/mem.c b/mem/src/mem.c
index 056d947..2904035 100644
--- a/mem/src/mem.c
+++ b/mem/src/mem.c
@@ -1,5 +1,7 @@
 #include "mem.h"
 
+#include <cassert.h>
+
 #include <stdlib.h>
 #include <string.h>
 
@@ -13,6 +15,7 @@ bool mem_make_(
   mem->block_size_bytes = block_size_bytes;
   mem->num_blocks       = num_blocks;
   mem->next_free_chunk  = 0;
+  mem->trap             = true;
 
   // Allocate chunks and blocks if necessary and zero them out.
   if (!chunks) {
@@ -107,6 +110,10 @@ void* mem_alloc_(Memory* mem, size_t num_blocks) {
     mem->next_free_chunk = mem->chunks[chunk_idx].next;
     return &mem->blocks[chunk_idx * mem->block_size_bytes];
   } else {
+    if (mem->trap) {
+      FAIL("Memory allocation failed, increase the allocator's capacity or "
+           "avoid fragmentation.");
+    }
     return 0; // Large-enough free chunk not found.
   }
 }
@@ -186,3 +193,8 @@ size_t mem_capacity_(const Memory* mem) {
   assert(mem);
   return mem->num_blocks * mem->block_size_bytes;
 }
+
+void mem_enable_traps_(Memory* mem, bool enable) {
+  assert(mem);
+  mem->trap = enable;
+}
diff --git a/mem/test/mem_test.c b/mem/test/mem_test.c
index 2f242c3..d3c43b9 100644
--- a/mem/test/mem_test.c
+++ b/mem/test/mem_test.c
@@ -67,6 +67,7 @@ TEST_CASE(mem_fill_then_free) {
 TEST_CASE(mem_allocate_beyond_max_size) {
   test_mem mem;
   mem_make(&mem);
+  mem_enable_traps(&mem, false);
 
   // Fully allocate the mem.
   for (int i = 0; i < NUM_BLOCKS; ++i) {
diff --git a/mempool/CMakeLists.txt b/mempool/CMakeLists.txt
index fe3e2a5..8c9dd30 100644
--- a/mempool/CMakeLists.txt
+++ b/mempool/CMakeLists.txt
@@ -10,6 +10,9 @@ add_library(mempool
 target_include_directories(mempool PUBLIC
   include)
 
+target_link_libraries(mempool PRIVATE
+  cassert)
+
 target_compile_options(mempool PRIVATE -Wall -Wextra)
 
 # Test
diff --git a/mempool/include/mempool.h b/mempool/include/mempool.h
index bd4d4dd..de9ea4f 100644
--- a/mempool/include/mempool.h
+++ b/mempool/include/mempool.h
@@ -65,6 +65,9 @@
 
 /// Allocate a new block.
 /// Return 0 if there is no memory left.
+/// When there is no space left in the pool, allocation can either trap
+/// (default) or gracefully return 0. Call mem_enable_traps() to toggle this
+/// behaviour.
 /// New blocks are conveniently zeroed out.
 #define mempool_alloc(POOL) mempool_alloc_(&(POOL)->pool)
 
@@ -86,6 +89,10 @@
 /// Return the total capacity of the mempool in bytes.
 #define mempool_capacity(POOL) mempool_capacity_(&(POOL)->pool)
 
+/// Set whether to trap when attempting to allocate beyond capacity.
+#define mempool_enable_traps(POOL, enable) \
+  mempool_enable_traps_(&(POOL)->pool, enable)
+
 /// Iterate over the used blocks of the pool.
 ///
 /// The caller can use 'i' as the index of the current block.
@@ -129,6 +136,7 @@ typedef struct mempool {
   size_t     head;    /// Points to the first block in the free list.
   size_t     used;    /// Points to the first block in the used list.
   bool       dynamic; /// True if blocks and info are dynamically-allocated.
+  bool       trap;    /// Whether to trap when allocating beyond capacity.
   BlockInfo* block_info;
   uint8_t*   blocks;
 } mempool;
@@ -154,3 +162,4 @@ void   mempool_free_(mempool*, void** block_ptr);
 void*  mempool_get_block_(const mempool*, size_t block_index);
 size_t mempool_get_block_index_(const mempool*, const void* block);
 size_t mempool_capacity_(const mempool*);
+void   mempool_enable_traps_(mempool*, bool);
diff --git a/mempool/src/mempool.c b/mempool/src/mempool.c
index 1100dad..b09038b 100644
--- a/mempool/src/mempool.c
+++ b/mempool/src/mempool.c
@@ -1,5 +1,7 @@
 #include "mempool.h"
 
+#include <cassert.h>
+
 #include <stdlib.h>
 #include <string.h>
 
@@ -24,6 +26,7 @@ bool mempool_make_(
   pool->num_blocks       = num_blocks;
   pool->head             = 0;
   pool->used             = 0;
+  pool->trap             = true;
 
   // Initialize blocks and block info.
   if (!block_info) {
@@ -74,6 +77,9 @@ void* mempool_alloc_(mempool* pool) {
 
   BlockInfo* head = &pool->block_info[pool->head];
   if (head->used) {
+    if (pool->trap) {
+      FAIL("mempool allocation failed, increase the pool's capacity.");
+    }
     return 0; // Pool is full.
   }
 
@@ -134,3 +140,8 @@ size_t mempool_capacity_(const mempool* pool) {
   assert(pool);
   return pool->num_blocks * pool->block_size_bytes;
 }
+
+void mempool_enable_traps_(mempool* pool, bool enable) {
+  assert(pool);
+  pool->trap = enable;
+}
diff --git a/mempool/test/mempool_test.c b/mempool/test/mempool_test.c
index d5ed1ea..6c48a2a 100644
--- a/mempool/test/mempool_test.c
+++ b/mempool/test/mempool_test.c
@@ -67,6 +67,7 @@ TEST_CASE(mempool_fill_then_free) {
 TEST_CASE(mempool_allocate_beyond_max_size) {
   test_pool pool;
   mempool_make(&pool);
+  mempool_enable_traps(&pool, false);
 
   // Fully allocate the pool.
   for (int i = 0; i < NUM_BLOCKS; ++i) {
-- 
cgit v1.2.3