/*
 * nostr_chacha20.c - ChaCha20 stream cipher implementation
 *
 * Implementation based on RFC 8439 "ChaCha20 and Poly1305 for IETF Protocols"
 *
 * This implementation is adapted from the RFC 8439 reference specification.
 * It prioritizes correctness and clarity over performance optimization.
 */

#include "nostr_chacha20.h"
#include <string.h>

/*
 * ============================================================================
 * UTILITY MACROS AND FUNCTIONS
 * ============================================================================
 */

/*
 * Left rotate a 32-bit value by b bits.
 *
 * Both arguments are evaluated more than once, so they must be free of side
 * effects. b must be in the range 1..31: a shift by 32 (which b == 0 would
 * produce on the right-hand side) is undefined behavior in C.
 */
#define ROTLEFT(a, b) (((a) << (b)) | ((a) >> (32 - (b))))

/*
 * Assemble a 32-bit word from 4 bytes stored in little-endian order
 * (bytes[0] is the least significant byte).
 */
static uint32_t bytes_to_u32_le(const uint8_t *bytes)
{
    return ((uint32_t)bytes[0]) |
           ((uint32_t)bytes[1] << 8) |
           ((uint32_t)bytes[2] << 16) |
           ((uint32_t)bytes[3] << 24);
}

/*
 * Store a 32-bit word into 4 bytes in little-endian order
 * (bytes[0] receives the least significant byte).
 */
static void u32_to_bytes_le(uint32_t val, uint8_t *bytes)
{
    bytes[0] = (uint8_t)(val & 0xff);
    bytes[1] = (uint8_t)((val >> 8) & 0xff);
    bytes[2] = (uint8_t)((val >> 16) & 0xff);
    bytes[3] = (uint8_t)((val >> 24) & 0xff);
}

/*
 * ============================================================================
 * CHACHA20 CORE FUNCTIONS
 * ============================================================================
 */

/*
 * Apply one ChaCha quarter round in place to the four state words selected
 * by the indices a, b, c and d.
 *
 * Each of the four steps is an add, an xor and a left rotation, with rotation
 * amounts 16, 12, 8 and 7. Additions are on uint32_t and therefore wrap
 * modulo 2^32. The indices are not validated; each must be in 0..15.
 */
void chacha20_quarter_round(uint32_t state[16], int a, int b, int c, int d)
{
    state[a] += state[b];
    state[d] ^= state[a];
    state[d] = ROTLEFT(state[d], 16);

    state[c] += state[d];
    state[b] ^= state[c];
    state[b] = ROTLEFT(state[b], 12);

    state[a] += state[b];
    state[d] ^= state[a];
    state[d] = ROTLEFT(state[d], 8);

    state[c] += state[d];
    state[b] ^= state[c];
    state[b] = ROTLEFT(state[b], 7);
}

/*
 * Fill the 16-word ChaCha20 state.
 *
 * Layout:
 *   words  0..3   constants (ASCII "expand 32-byte k" read as little-endian)
 *   words  4..11  the 32-byte key, as eight little-endian words
 *   word   12     the 32-bit block counter
 *   words 13..15  the 12-byte nonce, as three little-endian words
 */
void chacha20_init_state(uint32_t state[16], const uint8_t key[32],
                         uint32_t counter, const uint8_t nonce[12])
{
    /* ChaCha20 constants "expand 32-byte k" */
    state[0] = 0x61707865;
    state[1] = 0x3320646e;
    state[2] = 0x79622d32;
    state[3] = 0x6b206574;

    /* Key (8 words) */
    state[4] = bytes_to_u32_le(key + 0);
    state[5] = bytes_to_u32_le(key + 4);
    state[6] = bytes_to_u32_le(key + 8);
    state[7] = bytes_to_u32_le(key + 12);
    state[8] = bytes_to_u32_le(key + 16);
    state[9] = bytes_to_u32_le(key + 20);
    state[10] = bytes_to_u32_le(key + 24);
    state[11] = bytes_to_u32_le(key + 28);

    /* Counter (1 word) */
    state[12] = counter;

    /* Nonce (3 words) */
    state[13] = bytes_to_u32_le(nonce + 0);
    state[14] = bytes_to_u32_le(nonce + 4);
    state[15] = bytes_to_u32_le(nonce + 8);
}

/*
 * Write the 16 state words to output as 64 bytes, each word in
 * little-endian byte order, word 0 first.
 */
void chacha20_serialize_state(const uint32_t state[16], uint8_t output[64])
{
    for (int i = 0; i < 16; i++) {
        u32_to_bytes_le(state[i], output + (i * 4));
    }
}

/*
 * Compute one 64-byte ChaCha20 keystream block for the given key, block
 * counter and nonce, and write it to output.
 *
 * Returns 0. No failure path exists in this implementation.
 */
int chacha20_block(const uint8_t key[32], uint32_t counter,
                   const uint8_t nonce[12], uint8_t output[64])
{
    uint32_t state[16];
    uint32_t initial_state[16];

    /* Initialize state */
    chacha20_init_state(state, key, counter, nonce);

    /* Save initial state for the final addition */
    memcpy(initial_state, state, sizeof(initial_state));

    /* Perform 20 rounds: 10 iterations of 4 column + 4 diagonal quarter rounds */
    for (int i = 0; i < 10; i++) {
        /* Column rounds */
        chacha20_quarter_round(state, 0, 4, 8, 12);
        chacha20_quarter_round(state, 1, 5, 9, 13);
        chacha20_quarter_round(state, 2, 6, 10, 14);
        chacha20_quarter_round(state, 3, 7, 11, 15);

        /* Diagonal rounds */
        chacha20_quarter_round(state, 0, 5, 10, 15);
        chacha20_quarter_round(state, 1, 6, 11, 12);
        chacha20_quarter_round(state, 2, 7, 8, 13);
        chacha20_quarter_round(state, 3, 4, 9, 14);
    }

    /* Add the initial state words to the permuted words (mod 2^32) */
    for (int i = 0; i < 16; i++) {
        state[i] += initial_state[i];
    }

    /* Serialize to output bytes */
    chacha20_serialize_state(state, output);

    return 0;
}

/*
 * XOR length bytes of input with the ChaCha20 keystream and write the result
 * to output. Because the operation is an XOR with the keystream, the same
 * call both encrypts and decrypts.
 *
 * counter is the block counter used for the first 64-byte block; it is
 * incremented by one for each subsequent block. Each input byte is read
 * before the output byte at the same offset is written.
 *
 * Returns:
 *    0  on success (including length == 0, in which case nothing is written)
 *   -1  if the 32-bit block counter would wrap to 0 while input remains;
 *       the blocks before the wrap have already been written to output
 *   any non-zero value returned by chacha20_block() is passed through
 */
int chacha20_encrypt(const uint8_t key[32], uint32_t counter,
                     const uint8_t nonce[12], const uint8_t *input,
                     uint8_t *output, size_t length)
{
    uint8_t keystream[CHACHA20_BLOCK_SIZE];
    size_t offset = 0;

    while (length > 0) {
        /* Generate keystream block */
        int ret = chacha20_block(key, counter, nonce, keystream);
        if (ret != 0) {
            return ret;
        }

        /* XOR with input to produce output; the last block may be partial */
        size_t block_len = (length < CHACHA20_BLOCK_SIZE) ? length
                                                          : CHACHA20_BLOCK_SIZE;
        for (size_t i = 0; i < block_len; i++) {
            output[offset + i] = input[offset + i] ^ keystream[i];
        }

        /* Move to next block */
        offset += block_len;
        length -= block_len;
        counter++;

        /*
         * A wrapped counter is only an error if another block is still
         * needed: reusing counter 0 would repeat keystream. When the block
         * just processed was the last one, counter 0xFFFFFFFF was a valid
         * value and the whole message has been handled successfully.
         */
        if (counter == 0 && length > 0) {
            return -1; /* Counter wrapped around with data remaining */
        }
    }

    return 0;
}