diff --git a/CMakeLists.txt b/CMakeLists.txt index b49d47a..f59f257 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,6 +63,7 @@ endif() # Add test executable add_executable(test_counter tests/test_counter.c src/counter.c) add_executable(test_string_utils tests/test_string_utils.c src/string_utils.c) +add_executable(test_hashtable tests/test_hashtable.c src/core/hashtable.c) add_executable(test_response_writer tests/test_response_writer.c src/client.c src/commands/common/command_registry.c) add_executable(test_server_lifecycle tests/test_server_lifecycle.c src/server_lifecycle.c src/client.c src/core/list.c src/core/hashtable.c) add_executable(test_server_config tests/test_server_config.c src/config.c src/numeric_parse.c) @@ -72,6 +73,7 @@ target_compile_definitions(test_server_config PRIVATE SERVER) target_compile_definitions(test_integration PRIVATE SERVER) fkvs_configure_target(test_counter) fkvs_configure_target(test_string_utils) +fkvs_configure_target(test_hashtable) fkvs_configure_target(test_response_writer) fkvs_configure_target(test_server_lifecycle) fkvs_configure_target(test_server_config) @@ -79,6 +81,7 @@ fkvs_configure_target(test_server_limits) fkvs_configure_target(test_integration) target_compile_options(test_counter PRIVATE -UNDEBUG) target_compile_options(test_string_utils PRIVATE -UNDEBUG) +target_compile_options(test_hashtable PRIVATE -UNDEBUG) target_compile_options(test_response_writer PRIVATE -UNDEBUG) target_compile_options(test_server_lifecycle PRIVATE -UNDEBUG) target_compile_options(test_server_config PRIVATE -UNDEBUG) @@ -86,6 +89,7 @@ target_compile_options(test_server_limits PRIVATE -UNDEBUG) target_compile_options(test_integration PRIVATE -UNDEBUG) target_link_libraries(test_counter) target_link_libraries(test_string_utils) +target_link_libraries(test_hashtable) target_link_libraries(test_response_writer) target_link_libraries(test_server_lifecycle) target_link_libraries(test_server_config) @@ -96,6 +100,7 @@ target_link_libraries(test_integration) enable_testing() add_test(NAME CounterTest COMMAND test_counter) add_test(NAME StringUtilsTest COMMAND test_string_utils) +add_test(NAME HashtableTest COMMAND test_hashtable) add_test(NAME ResponseWriterTest COMMAND test_response_writer) add_test(NAME ServerLifecycleTest COMMAND test_server_lifecycle) add_test(NAME ServerConfigTest COMMAND test_server_config) diff --git a/src/commands/server/server_command_handlers.c b/src/commands/server/server_command_handlers.c index da50011..b38e968 100644 --- a/src/commands/server/server_command_handlers.c +++ b/src/commands/server/server_command_handlers.c @@ -16,15 +16,6 @@ static hashtable_t *table = NULL; static hashtable_t *expires = NULL; -static void free_value_copy(value_entry_t *value) -{ - if (!value) - return; - - free(value->ptr); - free(value); -} - static bool check_and_expire(const unsigned char *key, size_t key_len) { if (is_expired(expires, key, key_len)) { @@ -198,8 +189,8 @@ void handle_set_command(client_t *client, unsigned char *buffer, size_t bytes_re value_len, value_encoding)) { send_error(client); fprintf(stderr, "Unable to store SET value\n"); - free_value_copy(old_value); - free_value_copy(old_expiry); + free_value_entry(old_value); + free_value_entry(old_expiry); free(data); return; } @@ -222,8 +213,8 @@ void handle_set_command(client_t *client, unsigned char *buffer, size_t bytes_re } else { delete_value(expires, &buffer[pos_key], key_len); } - free_value_copy(old_value); - free_value_copy(old_expiry); + free_value_entry(old_value); + free_value_entry(old_expiry); free(data); return; } @@ -233,8 +224,8 @@ void handle_set_command(client_t *client, unsigned char *buffer, size_t bytes_re } send_reply(client, &buffer[pos_value], value_len); - free_value_copy(old_value); - free_value_copy(old_expiry); + free_value_entry(old_value); + free_value_entry(old_expiry); free(data); } @@ -268,7 +259,7 @@ void handle_get_command(client_t *client, unsigned char *buffer, size_t bytes_re if (!resp_buffer) { send_error(client); perror("malloc failed"); - free_value_copy(value); + free_value_entry(value); return; } @@ -277,7 +268,7 @@ void handle_get_command(client_t *client, unsigned char *buffer, size_t bytes_re resp_buffer[value_len] = '\0'; send_reply(client, resp_buffer, value_len); free(resp_buffer); - free_value_copy(value); + free_value_entry(value); } else { send_error(client); } @@ -339,7 +330,7 @@ void handle_incr_command(client_t *client, unsigned char *buffer, if (value->encoding != VALUE_ENTRY_TYPE_INT) { fprintf(stderr, "Stored value is not an integer.\n"); send_error(client); - free_value_copy(value); + free_value_entry(value); return; } @@ -349,7 +340,7 @@ void handle_incr_command(client_t *client, unsigned char *buffer, current == INT64_MAX) { fprintf(stderr, "Stored integer is out of range.\n"); send_error(client); - free_value_copy(value); + free_value_entry(value); return; } const int64_t sum = current + 1; @@ -361,7 +352,7 @@ void handle_incr_command(client_t *client, unsigned char *buffer, char *reply = int64_to_string(sum); if (!reply) { send_error(client); - free_value_copy(value); + free_value_entry(value); return; } const size_t reply_len = strlen(reply); @@ -371,13 +362,13 @@ void handle_incr_command(client_t *client, unsigned char *buffer, fprintf(stderr, "Unable to set incremented value.\n"); send_error(client); free(reply); - free_value_copy(value); + free_value_entry(value); return; } send_reply(client, (const unsigned char *)reply, reply_len); free(reply); - free_value_copy(value); + free_value_entry(value); } void handle_incr_by_command(client_t *client, unsigned char *buffer, @@ -455,7 +446,7 @@ void handle_incr_by_command(client_t *client, unsigned char *buffer, if (old_value->encoding != VALUE_ENTRY_TYPE_INT) { fprintf(stderr, "Stored value is not an integer.\n"); send_error(client); - free_value_copy(old_value); + free_value_entry(old_value); free(incr_str); return; } @@ -470,7 +461,7 @@ void handle_incr_by_command(client_t *client, unsigned char *buffer, (increment < 0 && current < INT64_MIN - increment)) { fprintf(stderr, "Integer increment is out of range.\n"); send_error(client); - free_value_copy(old_value); + free_value_entry(old_value); free(incr_str); return; } @@ -483,7 +474,7 @@ void handle_incr_by_command(client_t *client, unsigned char *buffer, char *result = int64_to_string(sum); if (!result) { send_error(client); - free_value_copy(old_value); + free_value_entry(old_value); free(incr_str); return; } @@ -493,14 +484,14 @@ void handle_incr_by_command(client_t *client, unsigned char *buffer, result_len, VALUE_ENTRY_TYPE_INT)) { fprintf(stderr, "Unable to set incremented value.\n"); send_error(client); - free_value_copy(old_value); + free_value_entry(old_value); free(incr_str); free(result); return; } send_reply(client, (unsigned char *)result, result_len); - free_value_copy(old_value); + free_value_entry(old_value); free(incr_str); free(result); } @@ -581,7 +572,7 @@ void handle_decr_by_command(client_t *client, unsigned char *buffer, if (old_value->encoding != VALUE_ENTRY_TYPE_INT) { fprintf(stderr, "Stored value is not an integer.\n"); send_error(client); - free_value_copy(old_value); + free_value_entry(old_value); free(decr_str); return; } @@ -596,7 +587,7 @@ void handle_decr_by_command(client_t *client, unsigned char *buffer, (decrement < 0 && current > INT64_MAX + decrement)) { fprintf(stderr, "Integer decrement is out of range.\n"); send_error(client); - free_value_copy(old_value); + free_value_entry(old_value); free(decr_str); return; } @@ -609,7 +600,7 @@ void handle_decr_by_command(client_t *client, unsigned char *buffer, char *result = int64_to_string(result_val); if (!result) { send_error(client); - free_value_copy(old_value); + free_value_entry(old_value); free(decr_str); return; } @@ -619,14 +610,14 @@ void handle_decr_by_command(client_t *client, unsigned char *buffer, result_len, VALUE_ENTRY_TYPE_INT)) { fprintf(stderr, "Unable to set decremented value.\n"); send_error(client); - free_value_copy(old_value); + free_value_entry(old_value); free(decr_str); free(result); return; } send_reply(client, (unsigned char *)result, result_len); - free_value_copy(old_value); + free_value_entry(old_value); free(decr_str); free(result); } @@ -762,7 +753,7 @@ void handle_decr_command(client_t *client, unsigned char *buffer, if (value->encoding != VALUE_ENTRY_TYPE_INT) { fprintf(stderr, "Stored value is not an integer.\n"); send_error(client); - free_value_copy(value); + free_value_entry(value); return; } @@ -772,7 +763,7 @@ void handle_decr_command(client_t *client, unsigned char *buffer, current == INT64_MIN) { fprintf(stderr, "Stored integer is out of range.\n"); send_error(client); - free_value_copy(value); + free_value_entry(value); return; } const int64_t decrement = current - 1; @@ -780,7 +771,7 @@ void handle_decr_command(client_t *client, unsigned char *buffer, char *result_str = int64_to_string(decrement); if (!result_str) { send_error(client); - free_value_copy(value); + free_value_entry(value); return; } const size_t result_length = strlen(result_str); @@ -789,13 +780,13 @@ void handle_decr_command(client_t *client, unsigned char *buffer, result_length, VALUE_ENTRY_TYPE_INT)) { fprintf(stderr, "Unable to set decremented value.\n"); send_error(client); - free_value_copy(value); + free_value_entry(value); free(result_str); return; } send_reply(client, (unsigned char *)result_str, result_length); - free_value_copy(value); + free_value_entry(value); free(result_str); } @@ -872,12 +863,12 @@ void handle_expire_command(client_t *client, unsigned char *buffer, send_error(client); return; } - free(val->ptr); - free(val); + free_value_entry(val); // Parse seconds string char sec_buf[32]; - size_t copy_len = ttl_str_len < sizeof(sec_buf) - 1 ? ttl_str_len : sizeof(sec_buf) - 1; + size_t copy_len = + ttl_str_len < sizeof(sec_buf) - 1 ? ttl_str_len : sizeof(sec_buf) - 1; memcpy(sec_buf, &buffer[pos_ttl], copy_len); sec_buf[copy_len] = '\0'; @@ -927,8 +918,7 @@ void handle_ttl_command(client_t *client, unsigned char *buffer, size_t val_len; bool key_exists = get_value(table, &buffer[5], key_len, &val, &val_len); if (key_exists) { - free(val->ptr); - free(val); + free_value_entry(val); } int64_t ttl; @@ -936,7 +926,8 @@ void handle_ttl_command(client_t *client, unsigned char *buffer, ttl = -2; } else { ttl = get_ttl(expires, &buffer[5], key_len); - // get_ttl returns -2 if not in expires table; for existing key with no TTL, return -1 + // get_ttl returns -2 if not in expires table; for existing key with + // no TTL, return -1. if (ttl == -2) ttl = -1; } diff --git a/src/core/hashtable.c b/src/core/hashtable.c index 3b2f9e8..22c1e45 100644 --- a/src/core/hashtable.c +++ b/src/core/hashtable.c @@ -8,6 +8,9 @@ size_t hash_function(const unsigned char *key, const size_t key_len, const size_t table_size) { + if (table_size == 0 || (!key && key_len > 0)) + return 0; + size_t hash = 5381; for (size_t i = 0; i < key_len; i++) { hash = ((hash << 5) + hash) + key[i]; @@ -18,6 +21,9 @@ size_t hash_function(const unsigned char *key, const size_t key_len, // Create a new hash table hashtable_t *create_hash_table(const size_t size) { + if (size == 0) + return NULL; + hashtable_t *table = malloc(sizeof(hashtable_t)); if (!table) return NULL; @@ -30,20 +36,31 @@ hashtable_t *create_hash_table(const size_t size) return table; } +void free_value_entry(value_entry_t *value) +{ + if (!value) + return; + + free(value->ptr); + free(value); +} + void free_hash_table(hashtable_t *table) { if (!table) return; + if (!table->buckets) { + free(table); + return; + } + for (size_t i = 0; i < table->size; i++) { hash_table_entry_t *entry = table->buckets[i]; while (entry) { hash_table_entry_t *next = entry->next; free(entry->key); - if (entry->value) { - free(entry->value->ptr); - free(entry->value); - } + free_value_entry(entry->value); free(entry); entry = next; } @@ -56,6 +73,10 @@ bool set_value(const hashtable_t *table, const unsigned char *key, size_t key_len, const void *value, size_t value_len, int value_type_encoding) { + if (!table || !table->buckets || table->size == 0 || !key || + (!value && value_len > 0)) + return false; + const size_t index = hash_function(key, key_len, table->size); hash_table_entry_t *current = table->buckets[index]; @@ -72,7 +93,7 @@ bool set_value(const hashtable_t *table, const unsigned char *key, if (!current) return false; - current->key = malloc(key_len); + current->key = malloc(key_len == 0 ? 1 : key_len); if (!current->key) { free(current); return false; @@ -122,8 +143,7 @@ bool set_value(const hashtable_t *table, const unsigned char *key, // We are now safe to free old value if (!is_new_entry && current->value) { - free(current->value->ptr); - free(current->value); + free_value_entry(current->value); } current->value = new_val; @@ -132,7 +152,7 @@ bool set_value(const hashtable_t *table, const unsigned char *key, bool delete_value(hashtable_t *table, const unsigned char *key, size_t key_len) { - if (!table || !key) + if (!table || !table->buckets || table->size == 0 || !key) return false; const size_t index = hash_function(key, key_len, table->size); @@ -149,10 +169,7 @@ bool delete_value(hashtable_t *table, const unsigned char *key, size_t key_len) table->buckets[index] = current->next; } free(current->key); - if (current->value) { - free(current->value->ptr); - free(current->value); - } + free_value_entry(current->value); free(current); return true; } @@ -163,10 +180,16 @@ bool delete_value(hashtable_t *table, const unsigned char *key, size_t key_len) return false; } -bool get_value(hashtable_t *table, unsigned char *key, size_t key_len, +bool get_value(hashtable_t *table, const unsigned char *key, size_t key_len, value_entry_t **value, size_t *value_len) { - if (!table || !key || !value || !value_len) + if (value) + *value = NULL; + if (value_len) + *value_len = 0; + + if (!table || !table->buckets || table->size == 0 || !key || !value || + !value_len) return false; const size_t index = hash_function(key, key_len, table->size); @@ -175,6 +198,9 @@ bool get_value(hashtable_t *table, unsigned char *key, size_t key_len, current = current->next) { if (current->key_len == key_len && memcmp(current->key, key, key_len) == 0) { + if (!current->value) + return false; + // allocate full value_entry_t value_entry_t *out = malloc(sizeof(value_entry_t)); if (!out) @@ -188,7 +214,14 @@ bool get_value(hashtable_t *table, unsigned char *key, size_t key_len, return false; } - memcpy(out->ptr, current->value->ptr, current->value->value_len); + if (current->value->value_len > 0) { + if (!current->value->ptr) { + free_value_entry(out); + return false; + } + memcpy(out->ptr, current->value->ptr, + current->value->value_len); + } ((unsigned char *)out->ptr)[current->value->value_len] = '\0'; // copy metadata diff --git a/src/core/hashtable.h b/src/core/hashtable.h index ca27587..82b40f7 100644 --- a/src/core/hashtable.h +++ b/src/core/hashtable.h @@ -29,10 +29,11 @@ typedef struct hashtable { hashtable_t *create_hash_table(size_t size); void free_hash_table(hashtable_t *table); +void free_value_entry(value_entry_t *value); bool set_value(const hashtable_t *table, const unsigned char *key, size_t key_len, const void *value, size_t value_len, int value_type); -bool get_value(hashtable_t *table, unsigned char *key, size_t key_len, +bool get_value(hashtable_t *table, const unsigned char *key, size_t key_len, value_entry_t **value, size_t *value_len); bool delete_value(hashtable_t *table, const unsigned char *key, size_t key_len); size_t hash_function(const unsigned char *key, size_t key_len, diff --git a/src/ttl.c b/src/ttl.c index f261191..3cf98b3 100644 --- a/src/ttl.c +++ b/src/ttl.c @@ -25,12 +25,11 @@ static bool get_deadline(hashtable_t *expires, const unsigned char *key, value_entry_t *val = NULL; size_t val_len = 0; - if (!get_value(expires, (unsigned char *)key, key_len, &val, &val_len)) + if (!get_value(expires, key, key_len, &val, &val_len)) return false; if (val_len != 8) { - free(val->ptr); - free(val); + free_value_entry(val); return false; } @@ -40,8 +39,7 @@ static bool get_deadline(hashtable_t *expires, const unsigned char *key, ((int64_t)b[4] << 24) | ((int64_t)b[5] << 16) | ((int64_t)b[6] << 8) | (int64_t)b[7]; - free(val->ptr); - free(val); + free_value_entry(val); return true; } diff --git a/tests/test_hashtable.c b/tests/test_hashtable.c new file mode 100644 index 0000000..a0df60e --- /dev/null +++ b/tests/test_hashtable.c @@ -0,0 +1,98 @@ +#include "../src/core/hashtable.h" + +#include +#include +#include + +static void test_zero_length_value_roundtrip_is_freeable(void) +{ + hashtable_t *table = create_hash_table(8); + assert(table != NULL); + + const unsigned char key[] = "empty"; + assert(set_value(table, key, strlen((const char *)key), NULL, 0, + VALUE_ENTRY_TYPE_RAW)); + + value_entry_t *value = NULL; + size_t value_len = 99; + assert(get_value(table, key, strlen((const char *)key), &value, + &value_len)); + assert(value != NULL); + assert(value_len == 0); + assert(value->value_len == 0); + assert(value->ptr != NULL); + assert(((unsigned char *)value->ptr)[0] == '\0'); + + free_value_entry(value); + free_hash_table(table); + + printf("test_zero_length_value_roundtrip_is_freeable passed.\n"); +} + +static void test_invalid_inputs_are_rejected(void) +{ + hashtable_t *table = create_hash_table(4); + assert(table != NULL); + + const unsigned char key[] = "key"; + const unsigned char value[] = "value"; + value_entry_t *out = NULL; + size_t out_len = 0; + + assert(create_hash_table(0) == NULL); + assert(!set_value(NULL, key, sizeof(key) - 1, value, sizeof(value) - 1, + VALUE_ENTRY_TYPE_RAW)); + assert(!set_value(table, NULL, sizeof(key) - 1, value, sizeof(value) - 1, + VALUE_ENTRY_TYPE_RAW)); + assert(!set_value(table, key, sizeof(key) - 1, NULL, 1, + VALUE_ENTRY_TYPE_RAW)); + assert(!get_value(NULL, key, sizeof(key) - 1, &out, &out_len)); + assert(!get_value(table, NULL, sizeof(key) - 1, &out, &out_len)); + assert(!get_value(table, key, sizeof(key) - 1, NULL, &out_len)); + assert(!get_value(table, key, sizeof(key) - 1, &out, NULL)); + assert(!delete_value(NULL, key, sizeof(key) - 1)); + assert(!delete_value(table, NULL, sizeof(key) - 1)); + + free_hash_table(table); + + printf("test_invalid_inputs_are_rejected passed.\n"); +} + +static void test_replace_delete_and_free_are_sanitizer_clean(void) +{ + hashtable_t *table = create_hash_table(2); + assert(table != NULL); + + const unsigned char key[] = "name"; + const unsigned char first[] = "alice"; + const unsigned char second[] = "alexandre"; + + assert(set_value(table, key, sizeof(key) - 1, first, sizeof(first) - 1, + VALUE_ENTRY_TYPE_RAW)); + assert(set_value(table, key, sizeof(key) - 1, second, sizeof(second) - 1, + VALUE_ENTRY_TYPE_RAW)); + + value_entry_t *value = NULL; + size_t value_len = 0; + assert(get_value(table, key, sizeof(key) - 1, &value, &value_len)); + assert(value_len == sizeof(second) - 1); + assert(memcmp(value->ptr, second, value_len) == 0); + free_value_entry(value); + + assert(delete_value(table, key, sizeof(key) - 1)); + assert(!get_value(table, key, sizeof(key) - 1, &value, &value_len)); + + assert(set_value(table, key, sizeof(key) - 1, first, sizeof(first) - 1, + VALUE_ENTRY_TYPE_RAW)); + free_hash_table(table); + + printf("test_replace_delete_and_free_are_sanitizer_clean passed.\n"); +} + +int main(void) +{ + test_zero_length_value_roundtrip_is_freeable(); + test_invalid_inputs_are_rejected(); + test_replace_delete_and_free_are_sanitizer_clean(); + return 0; +}