diff --git a/build_static.sh b/build_static.sh
index 366a69a..1095cc3 100755
--- a/build_static.sh
+++ b/build_static.sh
@@ -95,16 +95,43 @@ else
     echo "Installing additional static libraries..."
     sudo apt install -y libcap-dev libuv1-dev libev-dev
 
+    # Build SQLite with JSON1 extension if not available
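+    # JSON1/FTS5 are needed by the relay's tag-filter queries, which use
+    # json_each()/json_extract() (see src/main.c and src/websockets.c); the
+    # distro's static libsqlite3.a may have been built without these extensions.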
+    echo "Building SQLite with JSON1 extension..."
+    SQLITE_BUILD_DIR="/tmp/sqlite-build-$$"
+    mkdir -p "$SQLITE_BUILD_DIR"
+    cd "$SQLITE_BUILD_DIR"
+
+    wget https://www.sqlite.org/2024/sqlite-autoconf-3460000.tar.gz
+    tar xzf sqlite-autoconf-3460000.tar.gz
+    cd sqlite-autoconf-3460000
+
+    ./configure \
+        --enable-static \
+        --disable-shared \
+        --enable-json1 \
+        --enable-fts5 \
+        --prefix="$SQLITE_BUILD_DIR/install" \
+        CFLAGS="-DSQLITE_ENABLE_JSON1=1 -DSQLITE_ENABLE_FTS5=1"
+
+    make && make install
+
+    # Return to project directory
+    cd "$SCRIPT_DIR"
+
     # Try building with regular gcc and static linking
     echo "Compiling with gcc -static..."
 
     # Use the same approach as the regular Makefile but with static linking
     gcc -static -O2 -Wall -Wextra -std=c99 -g \
         -I. -Inostr_core_lib -Inostr_core_lib/nostr_core -Inostr_core_lib/cjson -Inostr_core_lib/nostr_websocket \
+        -I"$SQLITE_BUILD_DIR/install/include" \
         src/main.c src/config.c src/dm_admin.c src/request_validator.c src/nip009.c src/nip011.c src/nip013.c src/nip040.c src/nip042.c src/websockets.c src/subscriptions.c src/api.c src/embedded_web_content.c \
         -o "$BUILD_DIR/c_relay_static_x86_64" \
         nostr_core_lib/libnostr_core_x64.a \
-        -lsqlite3 -lwebsockets -lz -ldl -lpthread -lm -L/usr/local/lib -lsecp256k1 -lssl -lcrypto -L/usr/local/lib -lcurl -lcap -luv_a -lev
+        "$SQLITE_BUILD_DIR/install/lib/libsqlite3.a" -lwebsockets -lz -ldl -lpthread -lm -L/usr/local/lib -lsecp256k1 -lssl -lcrypto -L/usr/local/lib -lcurl -lcap -luv_a -lev
+
+    # Preserve gcc's exit status: the cleanup below would otherwise overwrite $?
+    GCC_STATUS=$?
+
+    # Clean up SQLite build directory
+    rm -rf "$SQLITE_BUILD_DIR"
+
+    # Restore the compile status for the check below
+    test "$GCC_STATUS" -eq 0
 
     if [ $? -eq 0 ]; then
         echo "x86_64 static binary created: $BUILD_DIR/c_relay_static_x86_64"
diff --git a/deploy_static.sh b/deploy_static.sh
new file mode 100755
index 0000000..1e7c3f0
--- /dev/null
+++ b/deploy_static.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# C-Relay Static Binary Deployment Script
+# Deploys build/c_relay_static_x86_64 to the server via ssh/scp
+
+set -e
+
+# Configuration
+LOCAL_BINARY="build/c_relay_static_x86_64"
+REMOTE_BINARY_PATH="/usr/local/bin/c_relay/c_relay"
+SERVICE_NAME="c-relay"
+
+# Create backup
+ssh ubuntu@laantungir.com "sudo cp '$REMOTE_BINARY_PATH' '${REMOTE_BINARY_PATH}.backup.$(date +%Y%m%d_%H%M%S)'" 2>/dev/null || true
+
+# Upload binary to temp location
+scp "$LOCAL_BINARY" "ubuntu@laantungir.com:/tmp/c_relay.tmp"
+
+# Install binary
+ssh ubuntu@laantungir.com "sudo mv '/tmp/c_relay.tmp' '$REMOTE_BINARY_PATH'"
+ssh ubuntu@laantungir.com "sudo chown c-relay:c-relay '$REMOTE_BINARY_PATH'"
+ssh ubuntu@laantungir.com "sudo chmod +x '$REMOTE_BINARY_PATH'"
+
+# Restart service
+ssh ubuntu@laantungir.com "sudo systemctl restart '$SERVICE_NAME'"
+
+echo "Deployment complete!"
\ No newline at end of file
diff --git a/examples/deployment/static-builder.Dockerfile b/examples/deployment/static-builder.Dockerfile
index 9e5f822..ddac50e 100644
--- a/examples/deployment/static-builder.Dockerfile
+++ b/examples/deployment/static-builder.Dockerfile
@@ -69,6 +69,20 @@ RUN cd /tmp && \
     ./Configure linux-x86_64 no-shared --prefix=/usr && \
     make && make install_sw
 
+# Build SQLite with JSON1 extension enabled
+RUN cd /tmp && \
+    wget https://www.sqlite.org/2024/sqlite-autoconf-3460000.tar.gz && \
+    tar xzf sqlite-autoconf-3460000.tar.gz && \
+    cd sqlite-autoconf-3460000 && \
+    ./configure \
+        --enable-static \
+        --disable-shared \
+        --enable-json1 \
+        --enable-fts5 \
+        --prefix=/usr \
+        CFLAGS="-DSQLITE_ENABLE_JSON1=1 -DSQLITE_ENABLE_FTS5=1" && \
+    make && make install
+
 # Build libsecp256k1 static
 RUN cd /tmp && \
     git clone https://github.com/bitcoin-core/secp256k1.git && \
diff --git a/make_and_restart_relay.sh b/make_and_restart_relay.sh
index f30cf8c..2fb6115 100755
--- a/make_and_restart_relay.sh
+++ b/make_and_restart_relay.sh
@@ -163,9 +163,15 @@ rm -f db/c_nostr_relay.db* 2>/dev/null
 echo "Embedding web files..."
 ./embed_web_files.sh
 
-# Build the project first
-echo "Building project..."
-make clean all
+# Build the project first - use static build by default
+echo "Building project (static binary with SQLite JSON1 extension)..."
+./build_static.sh
+
+# Fallback to regular build if static build fails
+if [ $? -ne 0 ]; then
+    echo "Static build failed, falling back to regular build..."
+    make clean all
+fi
 
 # Restore database files if preserving
 if [ "$PRESERVE_DATABASE" = true ] && [ -d "/tmp/relay_backup_$$" ]; then
@@ -181,22 +187,34 @@ if [ $? -ne 0 ]; then
     exit 1
 fi
 
-# Check if relay binary exists after build - detect architecture
+# Check if relay binary exists after build - prefer static binary, fallback to regular
 ARCH=$(uname -m)
 case "$ARCH" in
     x86_64)
-        BINARY_PATH="./build/c_relay_x86"
+        STATIC_BINARY="./build/c_relay_static_x86_64"
+        REGULAR_BINARY="./build/c_relay_x86"
         ;;
     aarch64|arm64)
-        BINARY_PATH="./build/c_relay_arm64"
+        STATIC_BINARY="./build/c_relay_static_arm64"
+        REGULAR_BINARY="./build/c_relay_arm64"
        ;;
     *)
-        BINARY_PATH="./build/c_relay_$ARCH"
+        STATIC_BINARY="./build/c_relay_static_$ARCH"
+        REGULAR_BINARY="./build/c_relay_$ARCH"
        ;;
 esac
 
-if [ ! -f "$BINARY_PATH" ]; then
-    echo "ERROR: Relay binary not found at $BINARY_PATH after build. Build may have failed."
+# Prefer static binary if available
+if [ -f "$STATIC_BINARY" ]; then
+    BINARY_PATH="$STATIC_BINARY"
+    echo "Using static binary: $BINARY_PATH"
+elif [ -f "$REGULAR_BINARY" ]; then
+    BINARY_PATH="$REGULAR_BINARY"
+    echo "Using regular binary: $BINARY_PATH"
+else
+    echo "ERROR: No relay binary found. Checked:"
+    echo "  - $STATIC_BINARY"
+    echo "  - $REGULAR_BINARY"
     exit 1
 fi
diff --git a/relay.pid b/relay.pid
index c050afe..319b126 100644
--- a/relay.pid
+++ b/relay.pid
@@ -1 +1 @@
-2442403
+2875464
diff --git a/src/main.c b/src/main.c
index b249b5d..2471f82 100644
--- a/src/main.c
+++ b/src/main.c
@@ -126,6 +126,22 @@ int process_admin_event_in_config(cJSON* event, char* error_message, size_t erro
 // Forward declaration for NIP-45 COUNT message handling
 int handle_count_message(const char* sub_id, cJSON* filters, struct lws *wsi, struct per_session_data *pss);
 
+// Parameter binding helpers for SQL queries
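+// Values are collected in the same order that '?' placeholders are appended to
+// the SQL, strdup'd into a growable array (16 slots initially, doubling), bound
+// with sqlite3_bind_text(), and released with free_bind_params() once the
+// filter's query has run.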
+static void add_bind_param(char*** params, int* count, int* capacity, const char* value) {
+    if (*count >= *capacity) {
+        *capacity = *capacity == 0 ? 16 : *capacity * 2;
+        *params = realloc(*params, *capacity * sizeof(char*));
+    }
+    (*params)[(*count)++] = strdup(value);
+}
+
+static void free_bind_params(char** params, int count) {
+    for (int i = 0; i < count; i++) {
+        free(params[i]);
+    }
+    free(params);
+}
+
 // Forward declaration for enhanced admin event authorization
 int is_authorized_admin_event(cJSON* event, char* error_message, size_t error_size);
 
@@ -726,7 +742,95 @@ int handle_req_message(const char* sub_id, cJSON* filters, struct lws *wsi, stru
         log_error("REQ filters is not an array");
         return 0;
     }
-    
+
+    // EARLY SUBSCRIPTION LIMIT CHECK - Check limits BEFORE any processing
+    if (pss) {
+        time_t current_time = time(NULL);
+
+        // Check if client is currently rate limited due to excessive failed attempts
+        if (pss->rate_limit_until > current_time) {
+            char rate_limit_msg[256];
+            int remaining_seconds = (int)(pss->rate_limit_until - current_time);
+            snprintf(rate_limit_msg, sizeof(rate_limit_msg),
+                     "Rate limited due to excessive failed subscription attempts. Try again in %d seconds.", remaining_seconds);
+
+            // Send CLOSED notice for rate limiting
+            cJSON* closed_msg = cJSON_CreateArray();
+            cJSON_AddItemToArray(closed_msg, cJSON_CreateString("CLOSED"));
+            cJSON_AddItemToArray(closed_msg, cJSON_CreateString(sub_id));
+            cJSON_AddItemToArray(closed_msg, cJSON_CreateString("error: rate limited"));
+            cJSON_AddItemToArray(closed_msg, cJSON_CreateString(rate_limit_msg));
+
+            char* closed_str = cJSON_Print(closed_msg);
+            if (closed_str) {
+                size_t closed_len = strlen(closed_str);
+                unsigned char* buf = malloc(LWS_PRE + closed_len);
+                if (buf) {
+                    memcpy(buf + LWS_PRE, closed_str, closed_len);
+                    lws_write(wsi, buf + LWS_PRE, closed_len, LWS_WRITE_TEXT);
+                    free(buf);
+                }
+                free(closed_str);
+            }
+            cJSON_Delete(closed_msg);
+
+            // Update rate limiting counters
+            pss->failed_subscription_attempts++;
+            pss->last_failed_attempt = current_time;
+
+            return 0;
+        }
+
+        // Check session subscription limits
+        if (pss->subscription_count >= g_subscription_manager.max_subscriptions_per_client) {
+            log_error("Maximum subscriptions per client exceeded");
+
+            // Update rate limiting counters for failed attempt
+            pss->failed_subscription_attempts++;
+            pss->last_failed_attempt = current_time;
+            pss->consecutive_failures++;
+
+            // Implement progressive backoff: 1s, 5s, 30s, 300s (5min) based on consecutive failures
+            int backoff_seconds = 1;
+            if (pss->consecutive_failures >= 10) backoff_seconds = 300;     // 5 minutes
+            else if (pss->consecutive_failures >= 5) backoff_seconds = 30;  // 30 seconds
+            else if (pss->consecutive_failures >= 3) backoff_seconds = 5;   // 5 seconds
+
+            pss->rate_limit_until = current_time + backoff_seconds;
+
+            // Send CLOSED notice with backoff information
+            cJSON* closed_msg = cJSON_CreateArray();
+            cJSON_AddItemToArray(closed_msg, cJSON_CreateString("CLOSED"));
+            cJSON_AddItemToArray(closed_msg, cJSON_CreateString(sub_id));
+            cJSON_AddItemToArray(closed_msg, cJSON_CreateString("error: too many subscriptions"));
+
+            char backoff_msg[256];
+            snprintf(backoff_msg, sizeof(backoff_msg),
+                     "Maximum subscriptions per client exceeded. Backoff for %d seconds.", backoff_seconds);
+            cJSON_AddItemToArray(closed_msg, cJSON_CreateString(backoff_msg));
+
+            char* closed_str = cJSON_Print(closed_msg);
+            if (closed_str) {
+                size_t closed_len = strlen(closed_str);
+                unsigned char* buf = malloc(LWS_PRE + closed_len);
+                if (buf) {
+                    memcpy(buf + LWS_PRE, closed_str, closed_len);
+                    lws_write(wsi, buf + LWS_PRE, closed_len, LWS_WRITE_TEXT);
+                    free(buf);
+                }
+                free(closed_str);
+            }
+            cJSON_Delete(closed_msg);
+
+            return 0;
+        }
+    }
+
+    // Parameter binding helpers
+    char** bind_params = NULL;
+    int bind_param_count = 0;
+    int bind_param_capacity = 0;
+
     // Check for kind 33334 configuration event requests BEFORE creating subscription
     int config_events_sent = 0;
     int has_config_request = 0;
@@ -770,32 +874,6 @@ int handle_req_message(const char* sub_id, cJSON* filters, struct lws *wsi, stru
     // If only config events were requested, we can return early after sending EOSE
     // But still create the subscription for future config updates
 
-    // Check session subscription limits
-    if (pss && pss->subscription_count >= g_subscription_manager.max_subscriptions_per_client) {
-        log_error("Maximum subscriptions per client exceeded");
-
-        // Send CLOSED notice
-        cJSON* closed_msg = cJSON_CreateArray();
-        cJSON_AddItemToArray(closed_msg, cJSON_CreateString("CLOSED"));
-        cJSON_AddItemToArray(closed_msg, cJSON_CreateString(sub_id));
-        cJSON_AddItemToArray(closed_msg, cJSON_CreateString("error: too many subscriptions"));
-
-        char* closed_str = cJSON_Print(closed_msg);
-        if (closed_str) {
-            size_t closed_len = strlen(closed_str);
-            unsigned char* buf = malloc(LWS_PRE + closed_len);
-            if (buf) {
-                memcpy(buf + LWS_PRE, closed_str, closed_len);
-                lws_write(wsi, buf + LWS_PRE, closed_len, LWS_WRITE_TEXT);
-                free(buf);
-            }
-            free(closed_str);
-        }
-        cJSON_Delete(closed_msg);
-
-        return has_config_request ? config_events_sent : 0;
-    }
-
     // Create persistent subscription
     subscription_t* subscription = create_subscription(sub_id, wsi, filters, pss ? pss->client_ip : "unknown");
     if (!subscription) {
@@ -807,13 +885,13 @@ int handle_req_message(const char* sub_id, cJSON* filters, struct lws *wsi, stru
     if (add_subscription_to_manager(subscription) != 0) {
         log_error("Failed to add subscription to global manager");
         free_subscription(subscription);
-        
+
         // Send CLOSED notice
         cJSON* closed_msg = cJSON_CreateArray();
         cJSON_AddItemToArray(closed_msg, cJSON_CreateString("CLOSED"));
         cJSON_AddItemToArray(closed_msg, cJSON_CreateString(sub_id));
         cJSON_AddItemToArray(closed_msg, cJSON_CreateString("error: subscription limit reached"));
-        
+
         char* closed_str = cJSON_Print(closed_msg);
         if (closed_str) {
             size_t closed_len = strlen(closed_str);
@@ -826,7 +904,15 @@ int handle_req_message(const char* sub_id, cJSON* filters, struct lws *wsi, stru
             free(closed_str);
         }
         cJSON_Delete(closed_msg);
-        
+
+        // Update rate limiting counters for failed attempt (global limit reached)
+        if (pss) {
+            time_t current_time = time(NULL);
+            pss->failed_subscription_attempts++;
+            pss->last_failed_attempt = current_time;
+            pss->consecutive_failures++;
+        }
+
         return has_config_request ? config_events_sent : 0;
     }
 
@@ -848,7 +934,13 @@ int handle_req_message(const char* sub_id, cJSON* filters, struct lws *wsi, stru
             log_warning("Invalid filter object");
             continue;
         }
-        
+
+        // Reset bind params for this filter
+        free_bind_params(bind_params, bind_param_count);
+        bind_params = NULL;
+        bind_param_count = 0;
+        bind_param_capacity = 0;
+
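+        // Example: a filter {"authors":["<pubkey>"],"ids":["<id>"]} yields
+        //   ... AND pubkey IN (?) AND id IN (?)
+        // with the pubkey and then the id appended to bind_params in that order.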
         // Build SQL query based on filter - exclude ephemeral events (kinds 20000-29999) from historical queries
         char sql[1024] = "SELECT id, pubkey, created_at, kind, content, sig, tags FROM events WHERE 1=1 AND (kind < 20000 OR kind >= 30000)";
         char* sql_ptr = sql + strlen(sql);
@@ -888,56 +980,80 @@ int handle_req_message(const char* sub_id, cJSON* filters, struct lws *wsi, stru
         // Handle authors filter
         cJSON* authors = cJSON_GetObjectItem(filter, "authors");
         if (authors && cJSON_IsArray(authors)) {
-            int author_count = cJSON_GetArraySize(authors);
+            int author_count = 0;
+            // Count valid authors
+            for (int a = 0; a < cJSON_GetArraySize(authors); a++) {
+                cJSON* author = cJSON_GetArrayItem(authors, a);
+                if (cJSON_IsString(author)) {
+                    author_count++;
+                }
+            }
             if (author_count > 0) {
                 snprintf(sql_ptr, remaining, " AND pubkey IN (");
                 sql_ptr += strlen(sql_ptr);
                 remaining = sizeof(sql) - strlen(sql);
 
                 for (int a = 0; a < author_count; a++) {
-                    cJSON* author = cJSON_GetArrayItem(authors, a);
-                    if (cJSON_IsString(author)) {
-                        if (a > 0) {
-                            snprintf(sql_ptr, remaining, ",");
-                            sql_ptr++;
-                            remaining--;
-                        }
-                        snprintf(sql_ptr, remaining, "'%s'", cJSON_GetStringValue(author));
-                        sql_ptr += strlen(sql_ptr);
-                        remaining = sizeof(sql) - strlen(sql);
+                    if (a > 0) {
+                        snprintf(sql_ptr, remaining, ",");
+                        sql_ptr++;
+                        remaining--;
                     }
+                    snprintf(sql_ptr, remaining, "?");
+                    sql_ptr += strlen(sql_ptr);
+                    remaining = sizeof(sql) - strlen(sql);
                 }
                 snprintf(sql_ptr, remaining, ")");
                 sql_ptr += strlen(sql_ptr);
                 remaining = sizeof(sql) - strlen(sql);
+
+                // Add author values to bind params
+                for (int a = 0; a < cJSON_GetArraySize(authors); a++) {
+                    cJSON* author = cJSON_GetArrayItem(authors, a);
+                    if (cJSON_IsString(author)) {
+                        add_bind_param(&bind_params, &bind_param_count, &bind_param_capacity, cJSON_GetStringValue(author));
+                    }
+                }
             }
         }
 
         // Handle ids filter
         cJSON* ids = cJSON_GetObjectItem(filter, "ids");
         if (ids && cJSON_IsArray(ids)) {
-            int id_count = cJSON_GetArraySize(ids);
+            int id_count = 0;
+            // Count valid ids
+            for (int i = 0; i < cJSON_GetArraySize(ids); i++) {
+                cJSON* id = cJSON_GetArrayItem(ids, i);
+                if (cJSON_IsString(id)) {
+                    id_count++;
+                }
+            }
             if (id_count > 0) {
                 snprintf(sql_ptr, remaining, " AND id IN (");
                 sql_ptr += strlen(sql_ptr);
                 remaining = sizeof(sql) - strlen(sql);
 
                 for (int i = 0; i < id_count; i++) {
-                    cJSON* id = cJSON_GetArrayItem(ids, i);
-                    if (cJSON_IsString(id)) {
-                        if (i > 0) {
-                            snprintf(sql_ptr, remaining, ",");
-                            sql_ptr++;
-                            remaining--;
-                        }
-                        snprintf(sql_ptr, remaining, "'%s'", cJSON_GetStringValue(id));
-                        sql_ptr += strlen(sql_ptr);
-                        remaining = sizeof(sql) - strlen(sql);
+                    if (i > 0) {
+                        snprintf(sql_ptr, remaining, ",");
+                        sql_ptr++;
+                        remaining--;
                     }
+                    snprintf(sql_ptr, remaining, "?");
+                    sql_ptr += strlen(sql_ptr);
+                    remaining = sizeof(sql) - strlen(sql);
                 }
                 snprintf(sql_ptr, remaining, ")");
                 sql_ptr += strlen(sql_ptr);
                 remaining = sizeof(sql) - strlen(sql);
+
+                // Add id values to bind params
+                for (int i = 0; i < cJSON_GetArraySize(ids); i++) {
+                    cJSON* id = cJSON_GetArrayItem(ids, i);
+                    if (cJSON_IsString(id)) {
+                        add_bind_param(&bind_params, &bind_param_count, &bind_param_capacity, cJSON_GetStringValue(id));
+                    }
+                }
             }
         }
 
@@ -950,29 +1066,42 @@ int handle_req_message(const char* sub_id, cJSON* filters, struct lws *wsi, stru
             const char* tag_name = filter_key + 1; // Get the tag name (e, p, t, type, etc.)
 
             if (cJSON_IsArray(filter_item)) {
-                int tag_value_count = cJSON_GetArraySize(filter_item);
+                int tag_value_count = 0;
+                // Count valid tag values
+                for (int i = 0; i < cJSON_GetArraySize(filter_item); i++) {
+                    cJSON* tag_value = cJSON_GetArrayItem(filter_item, i);
+                    if (cJSON_IsString(tag_value)) {
+                        tag_value_count++;
+                    }
+                }
                 if (tag_value_count > 0) {
-                    // Use EXISTS with LIKE to check for matching tags
-                    snprintf(sql_ptr, remaining, " AND EXISTS (SELECT 1 FROM json_each(json(tags)) WHERE json_extract(value, '$[0]') = '%s' AND json_extract(value, '$[1]') IN (", tag_name);
+                    // Use EXISTS with parameterized query
+                    snprintf(sql_ptr, remaining, " AND EXISTS (SELECT 1 FROM json_each(json(tags)) WHERE json_extract(value, '$[0]') = ? AND json_extract(value, '$[1]') IN (");
                     sql_ptr += strlen(sql_ptr);
                     remaining = sizeof(sql) - strlen(sql);
 
                     for (int i = 0; i < tag_value_count; i++) {
-                        cJSON* tag_value = cJSON_GetArrayItem(filter_item, i);
-                        if (cJSON_IsString(tag_value)) {
-                            if (i > 0) {
-                                snprintf(sql_ptr, remaining, ",");
-                                sql_ptr++;
-                                remaining--;
-                            }
-                            snprintf(sql_ptr, remaining, "'%s'", cJSON_GetStringValue(tag_value));
-                            sql_ptr += strlen(sql_ptr);
-                            remaining = sizeof(sql) - strlen(sql);
+                        if (i > 0) {
+                            snprintf(sql_ptr, remaining, ",");
+                            sql_ptr++;
+                            remaining--;
                         }
+                        snprintf(sql_ptr, remaining, "?");
+                        sql_ptr += strlen(sql_ptr);
+                        remaining = sizeof(sql) - strlen(sql);
                     }
                     snprintf(sql_ptr, remaining, "))");
                     sql_ptr += strlen(sql_ptr);
                     remaining = sizeof(sql) - strlen(sql);
+
+                    // Add tag name and values to bind params
+                    add_bind_param(&bind_params, &bind_param_count, &bind_param_capacity, tag_name);
+                    for (int i = 0; i < cJSON_GetArraySize(filter_item); i++) {
+                        cJSON* tag_value = cJSON_GetArrayItem(filter_item, i);
+                        if (cJSON_IsString(tag_value)) {
+                            add_bind_param(&bind_params, &bind_param_count, &bind_param_capacity, cJSON_GetStringValue(tag_value));
+                        }
+                    }
                 }
             }
         }
@@ -1048,6 +1177,11 @@ int handle_req_message(const char* sub_id, cJSON* filters, struct lws *wsi, stru
             log_error(error_msg);
             continue;
         }
+
+        // Bind parameters
+        for (int i = 0; i < bind_param_count; i++) {
+            sqlite3_bind_text(stmt, i + 1, bind_params[i], -1, SQLITE_TRANSIENT);
+        }
 
         int row_count = 0;
         while (sqlite3_step(stmt) == SQLITE_ROW) {
@@ -1112,7 +1246,10 @@ int handle_req_message(const char* sub_id, cJSON* filters, struct lws *wsi, stru
 
         sqlite3_finalize(stmt);
     }
-    
+
+    // Cleanup bind params
+    free_bind_params(bind_params, bind_param_count);
+
     return events_sent;
 }
 
 /////////////////////////////////////////////////////////////////////////////////////////
@@ -1614,9 +1751,27 @@ int main(int argc, char* argv[]) {
     // Initialize NIP-40 expiration configuration
     init_expiration_config();
 
-    // Update subscription manager configuration
     update_subscription_manager_config();
+
+    // Initialize subscription manager mutexes
+    if (pthread_mutex_init(&g_subscription_manager.subscriptions_lock, NULL) != 0) {
+        log_error("Failed to initialize subscription manager subscriptions lock");
+        cleanup_configuration_system();
+        nostr_cleanup();
+        close_database();
+        return 1;
+    }
+
+    if (pthread_mutex_init(&g_subscription_manager.ip_tracking_lock, NULL) != 0) {
+        log_error("Failed to initialize subscription manager IP tracking lock");
+        pthread_mutex_destroy(&g_subscription_manager.subscriptions_lock);
+        cleanup_configuration_system();
+        nostr_cleanup();
+        close_database();
+        return 1;
+    }
+
     // Start WebSocket Nostr relay server (port from configuration)
@@ -1626,6 +1781,11 @@ int main(int argc, char* argv[]) {
     cleanup_relay_info();
     ginxsom_request_validator_cleanup();
     cleanup_configuration_system();
+
+    // Cleanup subscription manager mutexes
+    pthread_mutex_destroy(&g_subscription_manager.subscriptions_lock);
+    pthread_mutex_destroy(&g_subscription_manager.ip_tracking_lock);
+
     nostr_cleanup();
     close_database();
diff --git a/src/subscriptions.c b/src/subscriptions.c
index 1679b11..7bd0c50 100644
--- a/src/subscriptions.c
+++ b/src/subscriptions.c
@@ -472,52 +472,102 @@ int broadcast_event_to_subscriptions(cJSON* event) {
     }
 
     int broadcasts = 0;
+
+    // Create a temporary list of matching subscriptions to avoid holding lock during I/O
+    typedef struct temp_sub {
+        struct lws* wsi;
+        char id[SUBSCRIPTION_ID_MAX_LENGTH];
+        char client_ip[CLIENT_IP_MAX_LENGTH];
+        struct temp_sub* next;
+    } temp_sub_t;
+
+    temp_sub_t* matching_subs = NULL;
+    int matching_count = 0;
+
+    // First pass: collect matching subscriptions while holding lock
     pthread_mutex_lock(&g_subscription_manager.subscriptions_lock);
 
     subscription_t* sub = g_subscription_manager.active_subscriptions;
     while (sub) {
-        if (sub->active && event_matches_subscription(event, sub)) {
-            // Create EVENT message for this subscription
-            cJSON* event_msg = cJSON_CreateArray();
-            cJSON_AddItemToArray(event_msg, cJSON_CreateString("EVENT"));
-            cJSON_AddItemToArray(event_msg, cJSON_CreateString(sub->id));
-            cJSON_AddItemToArray(event_msg, cJSON_Duplicate(event, 1));
-
-            char* msg_str = cJSON_Print(event_msg);
-            if (msg_str) {
-                size_t msg_len = strlen(msg_str);
-                unsigned char* buf = malloc(LWS_PRE + msg_len);
-                if (buf) {
-                    memcpy(buf + LWS_PRE, msg_str, msg_len);
-
-                    // Send to WebSocket connection
-                    int write_result = lws_write(sub->wsi, buf + LWS_PRE, msg_len, LWS_WRITE_TEXT);
-                    if (write_result >= 0) {
-                        sub->events_sent++;
-                        broadcasts++;
-
-                        // Log event broadcast to database (optional - can be disabled for performance)
-                        cJSON* event_id_obj = cJSON_GetObjectItem(event, "id");
-                        if (event_id_obj && cJSON_IsString(event_id_obj)) {
-                            log_event_broadcast(cJSON_GetStringValue(event_id_obj), sub->id, sub->client_ip);
-                        }
-                    }
-
-                    free(buf);
-                }
-                free(msg_str);
+        if (sub->active && sub->wsi && event_matches_subscription(event, sub)) {
+            temp_sub_t* temp = malloc(sizeof(temp_sub_t));
+            if (temp) {
+                temp->wsi = sub->wsi;
+                strncpy(temp->id, sub->id, SUBSCRIPTION_ID_MAX_LENGTH - 1);
+                temp->id[SUBSCRIPTION_ID_MAX_LENGTH - 1] = '\0';
+                strncpy(temp->client_ip, sub->client_ip, CLIENT_IP_MAX_LENGTH - 1);
+                temp->client_ip[CLIENT_IP_MAX_LENGTH - 1] = '\0';
+                temp->next = matching_subs;
+                matching_subs = temp;
+                matching_count++;
             }
-
-            cJSON_Delete(event_msg);
+        }
+        sub = sub->next;
+    }
+
+    pthread_mutex_unlock(&g_subscription_manager.subscriptions_lock);
+
+    // Second pass: send messages without holding lock
+    temp_sub_t* current_temp = matching_subs;
+    while (current_temp) {
+        // Create EVENT message for this subscription
+        cJSON* event_msg = cJSON_CreateArray();
+        cJSON_AddItemToArray(event_msg, cJSON_CreateString("EVENT"));
+        cJSON_AddItemToArray(event_msg, cJSON_CreateString(current_temp->id));
+        cJSON_AddItemToArray(event_msg, cJSON_Duplicate(event, 1));
+
+        char* msg_str = cJSON_Print(event_msg);
+        if (msg_str) {
+            size_t msg_len = strlen(msg_str);
+            unsigned char* buf = malloc(LWS_PRE + msg_len);
+            if (buf) {
+                memcpy(buf + LWS_PRE, msg_str, msg_len);
+
+                // Send to WebSocket connection with error checking
+                // Note: lws_write can fail if connection is closed, but won't crash
+                int write_result = lws_write(current_temp->wsi, buf + LWS_PRE, msg_len, LWS_WRITE_TEXT);
+                if (write_result >= 0) {
+                    broadcasts++;
+
+                    // Update events sent counter for this subscription
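+                    // (The snapshot holds copies, not pointers, so the live
+                    // subscription is re-located under the lock before its
+                    // counter is touched; it may already have been removed.)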
+                    pthread_mutex_lock(&g_subscription_manager.subscriptions_lock);
+                    subscription_t* update_sub = g_subscription_manager.active_subscriptions;
+                    while (update_sub) {
+                        if (update_sub->wsi == current_temp->wsi &&
+                            strcmp(update_sub->id, current_temp->id) == 0) {
+                            update_sub->events_sent++;
+                            break;
+                        }
+                        update_sub = update_sub->next;
+                    }
+                    pthread_mutex_unlock(&g_subscription_manager.subscriptions_lock);
+
+                    // Log event broadcast to database (optional - can be disabled for performance)
+                    cJSON* event_id_obj = cJSON_GetObjectItem(event, "id");
+                    if (event_id_obj && cJSON_IsString(event_id_obj)) {
+                        log_event_broadcast(cJSON_GetStringValue(event_id_obj), current_temp->id, current_temp->client_ip);
+                    }
+                }
+
+                free(buf);
+            }
+            free(msg_str);
         }
-        sub = sub->next;
+        cJSON_Delete(event_msg);
+        current_temp = current_temp->next;
+    }
+
+    // Clean up temporary subscription list
+    while (matching_subs) {
+        temp_sub_t* next = matching_subs->next;
+        free(matching_subs);
+        matching_subs = next;
     }
 
     // Update global statistics
+    pthread_mutex_lock(&g_subscription_manager.subscriptions_lock);
     g_subscription_manager.total_events_broadcast += broadcasts;
     pthread_mutex_unlock(&g_subscription_manager.subscriptions_lock);
 
     return broadcasts;
@@ -688,19 +738,149 @@ void log_event_broadcast(const char* event_id, const char* sub_id, const char* c
 // Update events sent counter for a subscription
 void update_subscription_events_sent(const char* sub_id, int events_sent) {
     if (!g_db || !sub_id) return;
-    
+
     const char* sql = "UPDATE subscription_events "
                       "SET events_sent = ? "
                       "WHERE subscription_id = ? AND event_type = 'created'";
-    
+
     sqlite3_stmt* stmt;
     int rc = sqlite3_prepare_v2(g_db, sql, -1, &stmt, NULL);
     if (rc == SQLITE_OK) {
         sqlite3_bind_int(stmt, 1, events_sent);
         sqlite3_bind_text(stmt, 2, sub_id, -1, SQLITE_STATIC);
-        
+
         sqlite3_step(stmt);
         sqlite3_finalize(stmt);
     }
 }
+
+
+///////////////////////////////////////////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////////////
+// PER-IP CONNECTION TRACKING
+///////////////////////////////////////////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////////////
+
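+// These helpers are not yet called from websockets.c in this change; per the
+// header declarations the intended wiring is presumably to call
+// get_or_create_ip_connection() when a connection is established,
+// update_ip_connection_activity() on traffic, and remove_ip_connection()
+// when the last connection from an IP closes.
+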
+// Get or create IP connection info (thread-safe)
+ip_connection_info_t* get_or_create_ip_connection(const char* client_ip) {
+    if (!client_ip) return NULL;
+
+    pthread_mutex_lock(&g_subscription_manager.ip_tracking_lock);
+
+    // Look for existing IP connection info
+    ip_connection_info_t* current = g_subscription_manager.ip_connections;
+    while (current) {
+        if (strcmp(current->ip_address, client_ip) == 0) {
+            // Found existing entry, update activity
+            current->last_activity = time(NULL);
+            pthread_mutex_unlock(&g_subscription_manager.ip_tracking_lock);
+            return current;
+        }
+        current = current->next;
+    }
+
+    // Create new IP connection info
+    ip_connection_info_t* new_ip = calloc(1, sizeof(ip_connection_info_t));
+    if (!new_ip) {
+        pthread_mutex_unlock(&g_subscription_manager.ip_tracking_lock);
+        return NULL;
+    }
+
+    // Copy IP address safely
+    strncpy(new_ip->ip_address, client_ip, CLIENT_IP_MAX_LENGTH - 1);
+    new_ip->ip_address[CLIENT_IP_MAX_LENGTH - 1] = '\0';
+
+    // Initialize tracking data
+    time_t now = time(NULL);
+    new_ip->active_connections = 1;
+    new_ip->total_subscriptions = 0;
+    new_ip->first_connection = now;
+    new_ip->last_activity = now;
+
+    // Add to linked list
+    new_ip->next = g_subscription_manager.ip_connections;
+    g_subscription_manager.ip_connections = new_ip;
+
+    pthread_mutex_unlock(&g_subscription_manager.ip_tracking_lock);
+    return new_ip;
+}
+
+// Update IP connection activity timestamp
+void update_ip_connection_activity(const char* client_ip) {
+    if (!client_ip) return;
+
+    pthread_mutex_lock(&g_subscription_manager.ip_tracking_lock);
+
+    ip_connection_info_t* current = g_subscription_manager.ip_connections;
+    while (current) {
+        if (strcmp(current->ip_address, client_ip) == 0) {
+            current->last_activity = time(NULL);
+            break;
+        }
+        current = current->next;
+    }
+
+    pthread_mutex_unlock(&g_subscription_manager.ip_tracking_lock);
+}
+
+// Remove IP connection (when last connection from IP closes)
+void remove_ip_connection(const char* client_ip) {
+    if (!client_ip) return;
+
+    pthread_mutex_lock(&g_subscription_manager.ip_tracking_lock);
+
+    ip_connection_info_t** current = &g_subscription_manager.ip_connections;
+    while (*current) {
+        ip_connection_info_t* entry = *current;
+        if (strcmp(entry->ip_address, client_ip) == 0) {
+            // Remove from list
+            *current = entry->next;
+            free(entry);
+            break;
+        }
+        current = &((*current)->next);
+    }
+
+    pthread_mutex_unlock(&g_subscription_manager.ip_tracking_lock);
+}
+
+// Get total subscriptions for an IP address
+int get_total_subscriptions_for_ip(const char* client_ip) {
+    if (!client_ip) return 0;
+
+    pthread_mutex_lock(&g_subscription_manager.ip_tracking_lock);
+
+    ip_connection_info_t* current = g_subscription_manager.ip_connections;
+    while (current) {
+        if (strcmp(current->ip_address, client_ip) == 0) {
+            int total = current->total_subscriptions;
+            pthread_mutex_unlock(&g_subscription_manager.ip_tracking_lock);
+            return total;
+        }
+        current = current->next;
+    }
+
+    pthread_mutex_unlock(&g_subscription_manager.ip_tracking_lock);
+    return 0;
+}
+
+// Get active connections for an IP address
+int get_active_connections_for_ip(const char* client_ip) {
+    if (!client_ip) return 0;
+
+    pthread_mutex_lock(&g_subscription_manager.ip_tracking_lock);
+
+    ip_connection_info_t* current = g_subscription_manager.ip_connections;
+    while (current) {
+        if (strcmp(current->ip_address, client_ip) == 0) {
+            int active = current->active_connections;
+            pthread_mutex_unlock(&g_subscription_manager.ip_tracking_lock);
+            return active;
+        }
+        current = current->next;
+    }
+
+    pthread_mutex_unlock(&g_subscription_manager.ip_tracking_lock);
+    return 0;
+}
diff --git a/src/subscriptions.h b/src/subscriptions.h
index 11a7ba3..2d120fc 100644
--- a/src/subscriptions.h
+++ b/src/subscriptions.h
@@ -55,6 +55,16 @@ struct subscription {
     struct subscription* session_next;      // Next subscription for this session
 };
 
+// Per-IP connection tracking
+typedef struct ip_connection_info {
+    char ip_address[CLIENT_IP_MAX_LENGTH];  // IP address
+    int active_connections;                 // Number of active connections from this IP
+    int total_subscriptions;                // Total subscriptions across all connections from this IP
+    time_t first_connection;                // When first connection from this IP was established
+    time_t last_activity;                   // Last activity timestamp from this IP
+    struct ip_connection_info* next;        // Next in linked list
+} ip_connection_info_t;
+
 // Global subscription manager
 struct subscription_manager {
     subscription_t* active_subscriptions;   // Head of global subscription list
@@ -65,6 +75,10 @@ struct subscription_manager {
     int max_subscriptions_per_client;       // Default: 20
     int max_total_subscriptions;            // Default: 5000
 
+    // Per-IP connection tracking
+    ip_connection_info_t* ip_connections;   // Head of per-IP connection list
+    pthread_mutex_t ip_tracking_lock;       // Thread safety for IP tracking
+
     // Statistics
     uint64_t total_created;                 // Lifetime subscription count
     uint64_t total_events_broadcast;        // Lifetime event broadcast count
@@ -81,6 +95,13 @@ int event_matches_filter(cJSON* event, subscription_filter_t* filter);
 int event_matches_subscription(cJSON* event, subscription_t* subscription);
 int broadcast_event_to_subscriptions(cJSON* event);
 
+// Per-IP connection tracking functions
+ip_connection_info_t* get_or_create_ip_connection(const char* client_ip);
+void update_ip_connection_activity(const char* client_ip);
+void remove_ip_connection(const char* client_ip);
+int get_total_subscriptions_for_ip(const char* client_ip);
+int get_active_connections_for_ip(const char* client_ip);
+
 // Database logging functions
 void log_subscription_created(const subscription_t* sub);
 void log_subscription_closed(const char* sub_id, const char* client_ip, const char* reason);
diff --git a/src/websockets.c b/src/websockets.c
index 0ee7a70..a5c4ae3 100644
--- a/src/websockets.c
+++ b/src/websockets.c
@@ -1200,6 +1200,11 @@ int handle_count_message(const char* sub_id, cJSON* filters, struct lws *wsi, st
         return 0;
     }
 
+    // Parameter binding helpers
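+    // (Same pattern as handle_req_message(); the add_bind_param() helper in
+    // src/main.c is file-static, so the grow-and-strdup logic is repeated
+    // inline below.)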
+    char** bind_params = NULL;
+    int bind_param_count = 0;
+    int bind_param_capacity = 0;
+
     int total_count = 0;
 
     // Process each filter in the array
@@ -1210,6 +1215,15 @@ int handle_count_message(const char* sub_id, cJSON* filters, struct lws *wsi, st
             continue;
         }
 
+        // Reset bind params for this filter
+        for (int j = 0; j < bind_param_count; j++) {
+            free(bind_params[j]);
+        }
+        free(bind_params);
+        bind_params = NULL;
+        bind_param_count = 0;
+        bind_param_capacity = 0;
+
         // Build SQL COUNT query based on filter - exclude ephemeral events (kinds 20000-29999) from historical queries
         char sql[1024] = "SELECT COUNT(*) FROM events WHERE 1=1 AND (kind < 20000 OR kind >= 30000)";
         char* sql_ptr = sql + strlen(sql);
@@ -1249,56 +1263,88 @@ int handle_count_message(const char* sub_id, cJSON* filters, struct lws *wsi, st
         // Handle authors filter
         cJSON* authors = cJSON_GetObjectItem(filter, "authors");
         if (authors && cJSON_IsArray(authors)) {
-            int author_count = cJSON_GetArraySize(authors);
+            int author_count = 0;
+            // Count valid authors
+            for (int a = 0; a < cJSON_GetArraySize(authors); a++) {
+                cJSON* author = cJSON_GetArrayItem(authors, a);
+                if (cJSON_IsString(author)) {
+                    author_count++;
+                }
+            }
             if (author_count > 0) {
                 snprintf(sql_ptr, remaining, " AND pubkey IN (");
                 sql_ptr += strlen(sql_ptr);
                 remaining = sizeof(sql) - strlen(sql);
 
                 for (int a = 0; a < author_count; a++) {
-                    cJSON* author = cJSON_GetArrayItem(authors, a);
-                    if (cJSON_IsString(author)) {
-                        if (a > 0) {
-                            snprintf(sql_ptr, remaining, ",");
-                            sql_ptr++;
-                            remaining--;
-                        }
-                        snprintf(sql_ptr, remaining, "'%s'", cJSON_GetStringValue(author));
-                        sql_ptr += strlen(sql_ptr);
-                        remaining = sizeof(sql) - strlen(sql);
+                    if (a > 0) {
+                        snprintf(sql_ptr, remaining, ",");
+                        sql_ptr++;
+                        remaining--;
                     }
+                    snprintf(sql_ptr, remaining, "?");
+                    sql_ptr += strlen(sql_ptr);
+                    remaining = sizeof(sql) - strlen(sql);
                 }
                 snprintf(sql_ptr, remaining, ")");
                 sql_ptr += strlen(sql_ptr);
                 remaining = sizeof(sql) - strlen(sql);
+
+                // Add author values to bind params
+                for (int a = 0; a < cJSON_GetArraySize(authors); a++) {
+                    cJSON* author = cJSON_GetArrayItem(authors, a);
+                    if (cJSON_IsString(author)) {
+                        if (bind_param_count >= bind_param_capacity) {
+                            bind_param_capacity = bind_param_capacity == 0 ? 16 : bind_param_capacity * 2;
+                            bind_params = realloc(bind_params, bind_param_capacity * sizeof(char*));
+                        }
+                        bind_params[bind_param_count++] = strdup(cJSON_GetStringValue(author));
+                    }
+                }
             }
         }
 
         // Handle ids filter
         cJSON* ids = cJSON_GetObjectItem(filter, "ids");
         if (ids && cJSON_IsArray(ids)) {
-            int id_count = cJSON_GetArraySize(ids);
+            int id_count = 0;
+            // Count valid ids
+            for (int i = 0; i < cJSON_GetArraySize(ids); i++) {
+                cJSON* id = cJSON_GetArrayItem(ids, i);
+                if (cJSON_IsString(id)) {
+                    id_count++;
+                }
+            }
             if (id_count > 0) {
                 snprintf(sql_ptr, remaining, " AND id IN (");
                 sql_ptr += strlen(sql_ptr);
                 remaining = sizeof(sql) - strlen(sql);
 
                 for (int i = 0; i < id_count; i++) {
-                    cJSON* id = cJSON_GetArrayItem(ids, i);
-                    if (cJSON_IsString(id)) {
-                        if (i > 0) {
-                            snprintf(sql_ptr, remaining, ",");
-                            sql_ptr++;
-                            remaining--;
-                        }
-                        snprintf(sql_ptr, remaining, "'%s'", cJSON_GetStringValue(id));
-                        sql_ptr += strlen(sql_ptr);
-                        remaining = sizeof(sql) - strlen(sql);
+                    if (i > 0) {
+                        snprintf(sql_ptr, remaining, ",");
+                        sql_ptr++;
+                        remaining--;
                     }
+                    snprintf(sql_ptr, remaining, "?");
+                    sql_ptr += strlen(sql_ptr);
+                    remaining = sizeof(sql) - strlen(sql);
                 }
                 snprintf(sql_ptr, remaining, ")");
                 sql_ptr += strlen(sql_ptr);
                 remaining = sizeof(sql) - strlen(sql);
+
+                // Add id values to bind params
+                for (int i = 0; i < cJSON_GetArraySize(ids); i++) {
+                    cJSON* id = cJSON_GetArrayItem(ids, i);
+                    if (cJSON_IsString(id)) {
+                        if (bind_param_count >= bind_param_capacity) {
+                            bind_param_capacity = bind_param_capacity == 0 ? 16 : bind_param_capacity * 2;
+                            bind_params = realloc(bind_params, bind_param_capacity * sizeof(char*));
+                        }
+                        bind_params[bind_param_count++] = strdup(cJSON_GetStringValue(id));
+                    }
+                }
             }
         }
 
@@ -1311,29 +1357,50 @@ int handle_count_message(const char* sub_id, cJSON* filters, struct lws *wsi, st
             const char* tag_name = filter_key + 1; // Get the tag name (e, p, t, type, etc.)
 
             if (cJSON_IsArray(filter_item)) {
-                int tag_value_count = cJSON_GetArraySize(filter_item);
+                int tag_value_count = 0;
+                // Count valid tag values
+                for (int i = 0; i < cJSON_GetArraySize(filter_item); i++) {
+                    cJSON* tag_value = cJSON_GetArrayItem(filter_item, i);
+                    if (cJSON_IsString(tag_value)) {
+                        tag_value_count++;
+                    }
+                }
                 if (tag_value_count > 0) {
-                    // Use EXISTS with JSON extraction to check for matching tags
-                    snprintf(sql_ptr, remaining, " AND EXISTS (SELECT 1 FROM json_each(json(tags)) WHERE json_extract(value, '$[0]') = '%s' AND json_extract(value, '$[1]') IN (", tag_name);
+                    // Use EXISTS with parameterized query
+                    snprintf(sql_ptr, remaining, " AND EXISTS (SELECT 1 FROM json_each(json(tags)) WHERE json_extract(value, '$[0]') = ? AND json_extract(value, '$[1]') IN (");
                     sql_ptr += strlen(sql_ptr);
                     remaining = sizeof(sql) - strlen(sql);
 
                     for (int i = 0; i < tag_value_count; i++) {
-                        cJSON* tag_value = cJSON_GetArrayItem(filter_item, i);
-                        if (cJSON_IsString(tag_value)) {
-                            if (i > 0) {
-                                snprintf(sql_ptr, remaining, ",");
-                                sql_ptr++;
-                                remaining--;
-                            }
-                            snprintf(sql_ptr, remaining, "'%s'", cJSON_GetStringValue(tag_value));
-                            sql_ptr += strlen(sql_ptr);
-                            remaining = sizeof(sql) - strlen(sql);
+                        if (i > 0) {
+                            snprintf(sql_ptr, remaining, ",");
+                            sql_ptr++;
+                            remaining--;
                         }
+                        snprintf(sql_ptr, remaining, "?");
+                        sql_ptr += strlen(sql_ptr);
+                        remaining = sizeof(sql) - strlen(sql);
                     }
                     snprintf(sql_ptr, remaining, "))");
                     sql_ptr += strlen(sql_ptr);
                     remaining = sizeof(sql) - strlen(sql);
+
+                    // Add tag name and values to bind params
+                    if (bind_param_count >= bind_param_capacity) {
+                        bind_param_capacity = bind_param_capacity == 0 ? 16 : bind_param_capacity * 2;
+                        bind_params = realloc(bind_params, bind_param_capacity * sizeof(char*));
+                    }
+                    bind_params[bind_param_count++] = strdup(tag_name);
+                    for (int i = 0; i < cJSON_GetArraySize(filter_item); i++) {
+                        cJSON* tag_value = cJSON_GetArrayItem(filter_item, i);
+                        if (cJSON_IsString(tag_value)) {
+                            if (bind_param_count >= bind_param_capacity) {
+                                bind_param_capacity = bind_param_capacity == 0 ? 16 : bind_param_capacity * 2;
+                                bind_params = realloc(bind_params, bind_param_capacity * sizeof(char*));
+                            }
+                            bind_params[bind_param_count++] = strdup(cJSON_GetStringValue(tag_value));
+                        }
+                    }
                 }
             }
         }
@@ -1395,6 +1462,11 @@ int handle_count_message(const char* sub_id, cJSON* filters, struct lws *wsi, st
             continue;
         }
 
+        // Bind parameters
+        for (int i = 0; i < bind_param_count; i++) {
+            sqlite3_bind_text(stmt, i + 1, bind_params[i], -1, SQLITE_TRANSIENT);
+        }
+
         int filter_count = 0;
         if (sqlite3_step(stmt) == SQLITE_ROW) {
             filter_count = sqlite3_column_int(stmt, 0);
@@ -1431,5 +1503,11 @@ int handle_count_message(const char* sub_id, cJSON* filters, struct lws *wsi, st
     }
     cJSON_Delete(count_response);
 
+    // Cleanup bind params
+    for (int i = 0; i < bind_param_count; i++) {
+        free(bind_params[i]);
+    }
+    free(bind_params);
+
     return total_count;
 }
diff --git a/src/websockets.h b/src/websockets.h
index 27b44de..cee4b20 100644
--- a/src/websockets.h
+++ b/src/websockets.h
@@ -14,7 +14,7 @@
 #define CHALLENGE_MAX_LENGTH 128
 #define AUTHENTICATED_PUBKEY_MAX_LENGTH 65  // 64 hex + null
 
-// Enhanced per-session data with subscription management and NIP-42 authentication
+// Enhanced per-session data with subscription management, NIP-42 authentication, and rate limiting
 struct per_session_data {
     int authenticated;
     struct subscription* subscriptions;     // Head of this session's subscription list
@@ -30,6 +30,12 @@ struct per_session_data {
     int nip42_auth_required_events;         // Whether NIP-42 auth is required for EVENT submission
     int nip42_auth_required_subscriptions;  // Whether NIP-42 auth is required for REQ operations
     int auth_challenge_sent;                // Whether challenge has been sent (0/1)
+
+    // Rate limiting for subscription attempts
+    int failed_subscription_attempts;       // Count of failed subscription attempts
+    time_t last_failed_attempt;             // Timestamp of last failed attempt
+    time_t rate_limit_until;                // Time until rate limiting expires
+    int consecutive_failures;               // Consecutive failed attempts for backoff
 };
 
 // NIP-11 HTTP session data structure for managing buffer lifetime
diff --git a/tests/subscription_limits.sh b/tests/subscription_limits.sh
new file mode 100755
index 0000000..3ef5f53
--- /dev/null
+++ b/tests/subscription_limits.sh
@@ -0,0 +1,63 @@
+#!/bin/bash
+
+# Simple test script to verify subscription limit enforcement and rate limiting
+# This script tests that subscription limits are enforced early
+
+set -e
+
+RELAY_URL="ws://127.0.0.1:8888"
+
+echo "=== Subscription Limit Test ==="
+echo "[INFO] Testing relay at: $RELAY_URL"
+echo "[INFO] Note: This test assumes the default per-client subscription limit (20) is in effect"
+echo ""
+
+# Test basic connectivity first
+echo "=== Test 1: Basic Connectivity ==="
+echo "[INFO] Testing basic WebSocket connection..."
+
+# Send a simple REQ message
+response=$(echo '["REQ","basic_test",{}]' | timeout 5 websocat -n1 "$RELAY_URL" 2>/dev/null || echo "TIMEOUT")
+
+if echo "$response" | grep -q "EOSE\|EVENT\|NOTICE"; then
+    echo "[PASS] Basic connectivity works"
+else
+    echo "[FAIL] Basic connectivity failed. Response: $response"
+    exit 1
+fi
+echo ""
+
+# Test subscription limits
+echo "=== Test 2: Subscription Limit Enforcement ==="
+echo "[INFO] Testing subscription limits by creating multiple subscriptions..."
+
+success_count=0
+limit_hit=false
+
+# Create multiple subscriptions in sequence (each in its own connection)
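+# Note: each REQ below runs in its own websocat process (a fresh WebSocket
+# connection), so the per-session counter checked in handle_req_message()
+# restarts at zero on every attempt and the per-client limit is not expected
+# to be hit this way.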
+for i in {1..30}; do
+    echo "[INFO] Creating subscription $i..."
+    sub_id="limit_test_${i}_$(date +%s%N)"
+    response=$(echo "[\"REQ\",\"$sub_id\",{}]" | timeout 5 websocat -n1 "$RELAY_URL" 2>/dev/null || echo "TIMEOUT")
+
+    if echo "$response" | grep -q "CLOSED.*$sub_id.*exceeded"; then
+        echo "[INFO] Hit subscription limit at subscription $i"
+        limit_hit=true
+        break
+    elif echo "$response" | grep -q "EOSE\|EVENT"; then
+        # Plain ((success_count++)) returns a non-zero status when the count is 0,
+        # which would abort the script under 'set -e'
+        success_count=$((success_count + 1))
+    else
+        echo "[WARN] Unexpected response for subscription $i: $response"
+    fi
+
+    sleep 0.1
+done
+
+if [ "$limit_hit" = true ]; then
+    echo "[PASS] Subscription limit enforcement working (limit hit after $success_count subscriptions)"
+else
+    echo "[WARN] Subscription limit not hit after 30 attempts"
+fi
+echo ""
+
+echo "=== Test Complete ==="
\ No newline at end of file