nip44: make the api less classy.

This commit is contained in:
fiatjaf 2024-05-19 14:40:23 -03:00
parent 44efd49bc0
commit 5876acd67a
2 changed files with 113 additions and 116 deletions

View File

@ -7,7 +7,7 @@ const v2vec = vec.v2
test('get_conversation_key', () => { test('get_conversation_key', () => {
for (const v of v2vec.valid.get_conversation_key) { for (const v of v2vec.valid.get_conversation_key) {
const key = v2.utils.getConversationKey(v.sec1, v.pub2) const key = v2.utils.getConversationKey(hexToBytes(v.sec1), v.pub2)
expect(bytesToHex(key)).toEqual(v.conversation_key) expect(bytesToHex(key)).toEqual(v.conversation_key)
} }
}) })
@ -15,7 +15,7 @@ test('get_conversation_key', () => {
test('encrypt_decrypt', () => { test('encrypt_decrypt', () => {
for (const v of v2vec.valid.encrypt_decrypt) { for (const v of v2vec.valid.encrypt_decrypt) {
const pub2 = bytesToHex(schnorr.getPublicKey(v.sec2)) const pub2 = bytesToHex(schnorr.getPublicKey(v.sec2))
const key = v2.utils.getConversationKey(v.sec1, pub2) const key = v2.utils.getConversationKey(hexToBytes(v.sec1), pub2)
expect(bytesToHex(key)).toEqual(v.conversation_key) expect(bytesToHex(key)).toEqual(v.conversation_key)
const ciphertext = v2.encrypt(v.plaintext, key, hexToBytes(v.nonce)) const ciphertext = v2.encrypt(v.plaintext, key, hexToBytes(v.nonce))
expect(ciphertext).toEqual(v.payload) expect(ciphertext).toEqual(v.payload)
@ -39,6 +39,8 @@ test('decrypt', async () => {
test('get_conversation_key', async () => { test('get_conversation_key', async () => {
for (const v of v2vec.invalid.get_conversation_key) { for (const v of v2vec.invalid.get_conversation_key) {
expect(() => v2.utils.getConversationKey(v.sec1, v.pub2)).toThrow(/(Point is not on curve|Cannot find square root)/) expect(() => v2.utils.getConversationKey(hexToBytes(v.sec1), v.pub2)).toThrow(
/(Point is not on curve|Cannot find square root)/,
)
} }
}) })

109
nip44.ts
View File

@ -4,88 +4,81 @@ import { secp256k1 } from '@noble/curves/secp256k1'
import { extract as hkdf_extract, expand as hkdf_expand } from '@noble/hashes/hkdf' import { extract as hkdf_extract, expand as hkdf_expand } from '@noble/hashes/hkdf'
import { hmac } from '@noble/hashes/hmac' import { hmac } from '@noble/hashes/hmac'
import { sha256 } from '@noble/hashes/sha256' import { sha256 } from '@noble/hashes/sha256'
import { concatBytes, randomBytes, utf8ToBytes } from '@noble/hashes/utils' import { concatBytes, randomBytes } from '@noble/hashes/utils'
import { base64 } from '@scure/base' import { base64 } from '@scure/base'
const decoder = new TextDecoder() import { utf8Decoder, utf8Encoder } from './utils.ts'
class u { const minPlaintextSize = 0x0001 // 1b msg => padded to 32b
static minPlaintextSize = 0x0001 // 1b msg => padded to 32b const maxPlaintextSize = 0xffff // 65535 (64kb-1) => padded to 64kb
static maxPlaintextSize = 0xffff // 65535 (64kb-1) => padded to 64kb
static utf8Encode = utf8ToBytes export function getConversationKey(privkeyA: Uint8Array, pubkeyB: string): Uint8Array {
static utf8Decode(bytes: Uint8Array): string {
return decoder.decode(bytes)
}
static getConversationKey(privkeyA: string, pubkeyB: string): Uint8Array {
const sharedX = secp256k1.getSharedSecret(privkeyA, '02' + pubkeyB).subarray(1, 33) const sharedX = secp256k1.getSharedSecret(privkeyA, '02' + pubkeyB).subarray(1, 33)
return hkdf_extract(sha256, sharedX, 'nip44-v2') return hkdf_extract(sha256, sharedX, 'nip44-v2')
} }
static getMessageKeys( function getMessageKeys(
conversationKey: Uint8Array, conversationKey: Uint8Array,
nonce: Uint8Array, nonce: Uint8Array,
): { chacha_key: Uint8Array; chacha_nonce: Uint8Array; hmac_key: Uint8Array } { ): { chacha_key: Uint8Array; chacha_nonce: Uint8Array; hmac_key: Uint8Array } {
const keys = hkdf_expand(sha256, conversationKey, nonce, 76) const keys = hkdf_expand(sha256, conversationKey, nonce, 76)
return { return {
chacha_key: keys.subarray(0, 32), chacha_key: keys.subarray(0, 32),
chacha_nonce: keys.subarray(32, 44), chacha_nonce: keys.subarray(32, 44),
hmac_key: keys.subarray(44, 76), hmac_key: keys.subarray(44, 76),
} }
} }
static calcPaddedLen(len: number): number { function calcPaddedLen(len: number): number {
if (!Number.isSafeInteger(len) || len < 1) throw new Error('expected positive integer') if (!Number.isSafeInteger(len) || len < 1) throw new Error('expected positive integer')
if (len <= 32) return 32 if (len <= 32) return 32
const nextPower = 1 << (Math.floor(Math.log2(len - 1)) + 1) const nextPower = 1 << (Math.floor(Math.log2(len - 1)) + 1)
const chunk = nextPower <= 256 ? 32 : nextPower / 8 const chunk = nextPower <= 256 ? 32 : nextPower / 8
return chunk * (Math.floor((len - 1) / chunk) + 1) return chunk * (Math.floor((len - 1) / chunk) + 1)
} }
static writeU16BE(num: number): Uint8Array { function writeU16BE(num: number): Uint8Array {
if (!Number.isSafeInteger(num) || num < u.minPlaintextSize || num > u.maxPlaintextSize) if (!Number.isSafeInteger(num) || num < minPlaintextSize || num > maxPlaintextSize)
throw new Error('invalid plaintext size: must be between 1 and 65535 bytes') throw new Error('invalid plaintext size: must be between 1 and 65535 bytes')
const arr = new Uint8Array(2) const arr = new Uint8Array(2)
new DataView(arr.buffer).setUint16(0, num, false) new DataView(arr.buffer).setUint16(0, num, false)
return arr return arr
} }
static pad(plaintext: string): Uint8Array { function pad(plaintext: string): Uint8Array {
const unpadded = u.utf8Encode(plaintext) const unpadded = utf8Encoder.encode(plaintext)
const unpaddedLen = unpadded.length const unpaddedLen = unpadded.length
const prefix = u.writeU16BE(unpaddedLen) const prefix = writeU16BE(unpaddedLen)
const suffix = new Uint8Array(u.calcPaddedLen(unpaddedLen) - unpaddedLen) const suffix = new Uint8Array(calcPaddedLen(unpaddedLen) - unpaddedLen)
return concatBytes(prefix, unpadded, suffix) return concatBytes(prefix, unpadded, suffix)
} }
static unpad(padded: Uint8Array): string { function unpad(padded: Uint8Array): string {
const unpaddedLen = new DataView(padded.buffer).getUint16(0) const unpaddedLen = new DataView(padded.buffer).getUint16(0)
const unpadded = padded.subarray(2, 2 + unpaddedLen) const unpadded = padded.subarray(2, 2 + unpaddedLen)
if ( if (
unpaddedLen < u.minPlaintextSize || unpaddedLen < minPlaintextSize ||
unpaddedLen > u.maxPlaintextSize || unpaddedLen > maxPlaintextSize ||
unpadded.length !== unpaddedLen || unpadded.length !== unpaddedLen ||
padded.length !== 2 + u.calcPaddedLen(unpaddedLen) padded.length !== 2 + calcPaddedLen(unpaddedLen)
) )
throw new Error('invalid padding') throw new Error('invalid padding')
return u.utf8Decode(unpadded) return utf8Decoder.decode(unpadded)
} }
static hmacAad(key: Uint8Array, message: Uint8Array, aad: Uint8Array): Uint8Array { function hmacAad(key: Uint8Array, message: Uint8Array, aad: Uint8Array): Uint8Array {
if (aad.length !== 32) throw new Error('AAD associated data must be 32 bytes') if (aad.length !== 32) throw new Error('AAD associated data must be 32 bytes')
const combined = concatBytes(aad, message) const combined = concatBytes(aad, message)
return hmac(sha256, key, combined) return hmac(sha256, key, combined)
} }
// metadata: always 65b (version: 1b, nonce: 32b, max: 32b) // metadata: always 65b (version: 1b, nonce: 32b, max: 32b)
// plaintext: 1b to 0xffff // plaintext: 1b to 0xffff
// padded plaintext: 32b to 0xffff // padded plaintext: 32b to 0xffff
// ciphertext: 32b+2 to 0xffff+2 // ciphertext: 32b+2 to 0xffff+2
// raw payload: 99 (65+32+2) to 65603 (65+0xffff+2) // raw payload: 99 (65+32+2) to 65603 (65+0xffff+2)
// compressed payload (base64): 132b to 87472b // compressed payload (base64): 132b to 87472b
static decodePayload(payload: string): { nonce: Uint8Array; ciphertext: Uint8Array; mac: Uint8Array } { function decodePayload(payload: string): { nonce: Uint8Array; ciphertext: Uint8Array; mac: Uint8Array } {
if (typeof payload !== 'string') throw new Error('payload must be a valid string') if (typeof payload !== 'string') throw new Error('payload must be a valid string')
const plen = payload.length const plen = payload.length
if (plen < 132 || plen > 87472) throw new Error('invalid payload length: ' + plen) if (plen < 132 || plen > 87472) throw new Error('invalid payload length: ' + plen)
@ -105,28 +98,30 @@ class u {
ciphertext: data.subarray(33, -32), ciphertext: data.subarray(33, -32),
mac: data.subarray(-32), mac: data.subarray(-32),
} }
}
} }
export class v2 { export function encrypt(plaintext: string, conversationKey: Uint8Array, nonce: Uint8Array = randomBytes(32)): string {
static utils = u const { chacha_key, chacha_nonce, hmac_key } = getMessageKeys(conversationKey, nonce)
const padded = pad(plaintext)
static encrypt(plaintext: string, conversationKey: Uint8Array, nonce: Uint8Array = randomBytes(32)): string {
const { chacha_key, chacha_nonce, hmac_key } = u.getMessageKeys(conversationKey, nonce)
const padded = u.pad(plaintext)
const ciphertext = chacha20(chacha_key, chacha_nonce, padded) const ciphertext = chacha20(chacha_key, chacha_nonce, padded)
const mac = u.hmacAad(hmac_key, ciphertext, nonce) const mac = hmacAad(hmac_key, ciphertext, nonce)
return base64.encode(concatBytes(new Uint8Array([2]), nonce, ciphertext, mac)) return base64.encode(concatBytes(new Uint8Array([2]), nonce, ciphertext, mac))
} }
static decrypt(payload: string, conversationKey: Uint8Array): string { export function decrypt(payload: string, conversationKey: Uint8Array): string {
const { nonce, ciphertext, mac } = u.decodePayload(payload) const { nonce, ciphertext, mac } = decodePayload(payload)
const { chacha_key, chacha_nonce, hmac_key } = u.getMessageKeys(conversationKey, nonce) const { chacha_key, chacha_nonce, hmac_key } = getMessageKeys(conversationKey, nonce)
const calculatedMac = u.hmacAad(hmac_key, ciphertext, nonce) const calculatedMac = hmacAad(hmac_key, ciphertext, nonce)
if (!equalBytes(calculatedMac, mac)) throw new Error('invalid MAC') if (!equalBytes(calculatedMac, mac)) throw new Error('invalid MAC')
const padded = chacha20(chacha_key, chacha_nonce, ciphertext) const padded = chacha20(chacha_key, chacha_nonce, ciphertext)
return u.unpad(padded) return unpad(padded)
}
} }
export default { v2 } export const v2 = {
utils: {
getConversationKey,
calcPaddedLen,
},
encrypt,
decrypt,
}