/* SPDX-License-Identifier: 0BSD */


#include "openssl.hpp"
#include "main.hpp"


// GET_AES: statement macro that fetches the "AES-256-GCM" cipher and leaves it in a local named `aes`.
//
// Declares two locals in the expanding scope:
//   aes          result of EVP_CIPHER_fetch(nullptr, "AES-256-GCM", nullptr), passed through TRY_PTR
//   aes_deleter  quickscope_wrapper holding a lambda that calls EVP_CIPHER_free(aes)
//                (presumably run when the enclosing scope exits; quickscope_wrapper is defined elsewhere)
//
// Requires a variable named `iv` to be visible at the expansion site: its sizeof() is compared against the cipher's IV length.
// The two assert()s check that the key length is 256 / 8 bytes and that the cipher's IV length does not exceed sizeof(iv);
// they are runtime checks (the "static_" prefix is commented out) and vanish under NDEBUG.
// NOTE(review): TRY_PTR is defined elsewhere; it presumably bails out of the enclosing function when the fetch yields null -- confirm.
#define GET_AES                                                                         \
	auto aes = TRY_PTR("AES-256-GCM", EVP_CIPHER_fetch(nullptr, "AES-256-GCM", nullptr)); \
	quickscope_wrapper aes_deleter{[&] { EVP_CIPHER_free(aes); }};                        \
	/*static_*/ assert(EVP_CIPHER_get_key_length(aes) == 256 / 8);                        \
	/*static_*/ assert(static_cast<size_t>(EVP_CIPHER_get_iv_length(aes)) <= sizeof(iv));

// PREP_CTX(initfn): statement macro that creates a cipher context and initialises it with `initfn`.
//
// Declares two locals in the expanding scope:
//   ctx          result of EVP_CIPHER_CTX_new(), passed through TRY_PTR
//   ctx_deleter  quickscope_wrapper holding a lambda that calls EVP_CIPHER_CTX_free(ctx)
//                (presumably run when the enclosing scope exits; quickscope_wrapper is defined elsewhere)
//
// Requires `aes`, `sym_key` and `iv` to be visible at the expansion site; they are forwarded as
// initfn(ctx, aes, sym_key, iv, params), where params carries a single unsigned "padding" entry set to false (0).
// `padding` and `params` live in their own inner block, so they do not leak into the caller's scope.
// The stringified name of initfn (#initfn) is what gets handed to TRY_OSSL as its first argument.
// NOTE(review): TRY_PTR/TRY_OSSL are defined elsewhere; they presumably return early from the enclosing function on failure -- confirm.
#define PREP_CTX(initfn)                                                           \
	auto ctx = TRY_PTR("EVP_CIPHER_CTX_new", EVP_CIPHER_CTX_new());                  \
	quickscope_wrapper ctx_deleter{[&] { EVP_CIPHER_CTX_free(ctx); }};               \
	{                                                                                \
		unsigned int padding = false;                                                  \
		OSSL_PARAM params[]  = {OSSL_PARAM_uint("padding", &padding), OSSL_PARAM_END}; \
		TRY_OSSL(#initfn, initfn(ctx, aes, sym_key, iv, params));                      \
	}


/// Encrypts `wrap_key` with AES-256-GCM under `sym_key` and `iv`, storing the ciphertext in `encrypted`.
///
/// @param encrypted  output buffer; the single EVP_EncryptUpdate call is expected to fill all ENCRYPTED_LEN bytes
/// @param sym_key    symmetric key handed to EVP_EncryptInit_ex2
/// @param iv         initialisation vector handed to EVP_EncryptInit_ex2
/// @param wrap_key   plaintext to encrypt; all sizeof(wrap_key) bytes are fed in one update
/// @return 0 on success.
///         NOTE(review): failures are reported through TRY_PTR/TRY_OSSL, which are defined elsewhere and
///         presumably return a non-zero value from this function -- confirm.
///
/// No authentication tag is requested from the context here; only the ciphertext is produced.
int encrypt_backup(uint8_t (&encrypted)[ENCRYPTED_LEN], const uint8_t (&sym_key)[AES_KEY_LEN], const uint8_t (&iv)[ENCRYPTED_AES_IV_LEN],
                   const uint8_t (&wrap_key)[WRAPPING_KEY_LEN]) {
	GET_AES
	PREP_CTX(EVP_EncryptInit_ex2)

	int encrypted_len;
	TRY_OSSL("EVP_EncryptUpdate", EVP_EncryptUpdate(ctx, encrypted, &encrypted_len, wrap_key, sizeof(wrap_key)));
	/*static_*/ assert(encrypted_len == ENCRYPTED_LEN);

	// Finalisation is a real operation, not an invariant: keeping it inside assert() would compile it out under NDEBUG
	// and abort (instead of reporting an error) in debug builds. Route it through TRY_OSSL like the update above.
	TRY_OSSL("EVP_EncryptFinal_ex", EVP_EncryptFinal_ex(ctx, encrypted, &encrypted_len));
	// Finalisation is expected to emit no further bytes, so `encrypted` still holds exactly the update's output.
	/*static_*/ assert(!encrypted_len);
	return 0;
}


/// Decrypts `encrypted` with AES-256-GCM under `sym_key` and `iv`, storing the plaintext in `wrap_key`.
///
/// @param wrap_key   output buffer; the single EVP_DecryptUpdate call is expected to fill all sizeof(wrap_key) bytes
/// @param sym_key    symmetric key handed to EVP_DecryptInit_ex2
/// @param iv         initialisation vector handed to EVP_DecryptInit_ex2
/// @param encrypted  ciphertext to decrypt; all sizeof(encrypted) bytes are fed in one update
/// @return 0 on success.
///         NOTE(review): failures are reported through TRY_PTR/TRY_OSSL, which are defined elsewhere and
///         presumably return a non-zero value from this function -- confirm.
///
/// This function does not call EVP_DecryptFinal_ex; it returns right after the update step.
int decrypt_backup(uint8_t (&wrap_key)[WRAPPING_KEY_LEN], const uint8_t (&sym_key)[AES_KEY_LEN], const uint8_t (&iv)[ENCRYPTED_AES_IV_LEN],
                   const uint8_t (&encrypted)[ENCRYPTED_LEN]) {
	GET_AES
	PREP_CTX(EVP_DecryptInit_ex2)

	// Number of plaintext bytes the update step reports having written.
	int decrypted_len;
	TRY_OSSL("EVP_DecryptUpdate", EVP_DecryptUpdate(ctx, wrap_key, &decrypted_len, encrypted, sizeof(encrypted)));

	// The whole output buffer must have been produced in that one call.
	/*static_*/ assert(sizeof(wrap_key) == decrypted_len);

	return 0;
}
