/* SPDX-License-Identifier: 0BSD */


#include "fido2.hpp"
#include "fd.hpp"
#include "main.hpp"
#include <algorithm>
#include <dirent.h>
#include <iterator>
#include <numeric>
#include <stdio.h>
#include <string.h>
#include <string_view>
#include <sys/mman.h>
#include <sys/types.h>
#include <utility>

using namespace std::literals;


// Device whose pending request cancel_dev() aborts; set by with_cancelled() for the duration
// of its callable and reset to nullptr afterwards.
static fido_dev_t * cancel_dev_dev;
// SIGINT disposition that was in effect before with_cancelled() installed cancel_dev().
struct sigaction cancel_dev_restore;
// SIGINT handler installed by with_cancelled(): ask the authenticator to abandon the pending
// request, reinstate the previous SIGINT disposition, then re-raise the signal so that
// disposition handles it.
// NOTE(review): fido_dev_cancel() is not documented as async-signal-safe — confirm this is acceptable.
static void cancel_dev(int sig) {
	fido_dev_cancel(cancel_dev_dev);
	sigaction(SIGINT, &cancel_dev_restore, nullptr);
	raise(sig);
}
// Run f() with SIGINT routed to cancel_dev(), so an interrupt cancels dev's pending request
// instead of leaving the authenticator waiting. The previous SIGINT disposition is restored
// afterwards. Returns whatever f() returned.
template <class F>
static int with_cancelled(fido2_device & dev, F && f) {
	cancel_dev_dev = dev;

	struct sigaction on_interrupt{};
	on_interrupt.sa_handler = cancel_dev;
	sigemptyset(&on_interrupt.sa_mask);
	sigaction(SIGINT, &on_interrupt, &cancel_dev_restore);

	const int result = f();

	sigaction(SIGINT, &cancel_dev_restore, nullptr);
	cancel_dev_dev = nullptr;
	return result;
}


// Context handed through fido2_enumerate_devices()'s void * to the visitor in fido2_find_device().
struct fido2_find_device_dt {
	fido2_device & ret;
	int (*shop)(fido2_device & ret, bool & ret_ret, void * dt);
	void * dt;
};
// Select a device: offer each eligible device to shop(), and keep the first one it accepts
// (returns 0 for) in ret. Rejected devices are released by the enumerator. shop may set ret_ret
// to abort the search with its non-zero return. A null shop accepts the first eligible device.
// Returns 0 on success, non-zero (after printing to stderr) if no device was chosen or enumeration failed.
int fido2_find_device(fido2_device & ret, int (*shop)(fido2_device & ret, bool & ret_ret, void * dt), void * dt) {
	if(shop == nullptr)
		shop = [](fido2_device &, bool &, void *) { return 0; };

	ret.dev = nullptr;
	fido2_find_device_dt ctx{ret, shop, dt};
	TRY_MAIN(fido2_enumerate_devices(
	    [](fido2_device & candidate, bool & ret_ret, bool & deldev, void * ctx_p) {
		    auto & ctx        = *static_cast<fido2_find_device_dt *>(ctx_p);
		    const int verdict = ctx.shop(candidate, ret_ret, ctx.dt);
		    if(verdict == 0)
			    ctx.ret = candidate;  // accepted: the caller's ret now owns the device
		    else
			    deldev = true;  // rejected: let the enumerator close and free it
		    return verdict;
	    },
	    &ctx));

	if(ret.dev == nullptr)
		return fputs(gettext("No eligible FIDO2 devices!\n"), stderr), __LINE__;
	return 0;
}

// Return str, or a translated "unknown" when str is null or empty.
static const char * or_unknown(const char * str) {
	if(str != nullptr && *str != '\0')
		return str;
	// part of device branding
	return gettext("unknown");
}
// Offer every attached FIDO2 device that can be opened and advertises the "hmac-secret" extension to visit().
// Each offered fido2_device has dev opened, has_del taken from the "credMgmt"/"credentialMgmtPreview"
// option (false if neither is present), and a malloc()ed "manufacturer product (path)" name.
// visit() may:
//  * set deldev to have the device's name freed and the device closed and freed here;
//    otherwise visit() takes ownership of both,
//  * return 0, or set ret_ret, to stop enumerating; that return value is then passed through.
//    A non-zero return without ret_ret moves on to the next device.
// Devices that fail to open, to report CBOR info, or that lack hmac-secret are skipped silently.
// Returns 0 once all devices were visited or skipped; non-zero (after printing to stderr) when no
// devices are attached or an allocation fails; otherwise whatever visit() stopped with.
int fido2_enumerate_devices(int (*visit)(fido2_device & device, bool & ret_ret, bool & deldev, void * dt), void * dt) {
	// Lock current and future pages in RAM (presumably to keep PINs/secrets out of swap); the result is ignored.
	mlockall(MCL_CURRENT | MCL_FUTURE);

	auto devices = TRY_PTR("allocate devices for discovery", fido_dev_info_new(FIDO_MAXDEVS));
	quickscope_wrapper devices_deleter{[&] { fido_dev_info_free(&devices, FIDO_MAXDEVS); }};

	size_t found{};
	fido_dev_info_manifest(devices, FIDO_MAXDEVS, &found);
	if(!found)
		return fputs(gettext("No FIDO2 devices!\n"), stderr), __LINE__;
	for(size_t i = 0; i < found; ++i) {
		auto devinfo = fido_dev_info_ptr(devices, i);

		fido2_device ret{};
		ret.dev = TRY_PTR("allocate device", fido_dev_new_with_info(devinfo));
		if(fido_dev_open_with_info(ret.dev) != FIDO_OK) {
			// Shared "skip this device" path, also entered via goto from the checks below.
			// fido_dev_close() is called here even when opening failed.
		next:
			fido_dev_close(ret.dev);
			fido_dev_free(&ret.dev);
			continue;
		}

		fido_cbor_info_t * dev_info;
		// NOTE(review): TRY_PTR presumably returns from the function on null, making this goto unreachable.
		if(!(dev_info = TRY_PTR("allocate device info", fido_cbor_info_new())))
			goto next;
		// Jumping back to `next` leaves this scope, so dev_info is freed before the device is.
		quickscope_wrapper dev_info_deleter{[&] { fido_cbor_info_free(&dev_info); }};
		if(fido_dev_get_cbor_info(ret, dev_info) != FIDO_OK)
			goto next;
		// Only devices advertising the hmac-secret extension can derive keys.
		auto exts = fido_cbor_info_extensions_ptr(dev_info);
		auto len  = fido_cbor_info_extensions_len(dev_info);
		if(!exts || !len || std::none_of(exts, exts + len, [](auto && e) { return !strcmp(e, "hmac-secret"); }))
			goto next;

		// has_del mirrors the value of the (first) credential-management option the device reports.
		ret.has_del = false;
		{
			auto beg = fido_cbor_info_options_name_ptr(dev_info);
			auto end = beg + fido_cbor_info_options_len(dev_info);
			if(auto itr = std::find_if(beg, end, [](std::string_view n) { return n == "credMgmt"sv || n == "credentialMgmtPreview"sv; }); itr != end)
				ret.has_del = fido_cbor_info_options_value_ptr(dev_info)[itr - beg];
		}

		if(asprintf(&ret.name, "%s %s (%s)", or_unknown(fido_dev_info_manufacturer_string(devinfo)), or_unknown(fido_dev_info_product_string(devinfo)),
		            or_unknown(fido_dev_info_path(devinfo))) == -1)
			ret.name = strdup(gettext("unknown device"));

		bool ret_ret{}, deldev{};
		auto dret = visit(ret, ret_ret, deldev, dt);
		// Without deldev the visitor now owns ret.name and ret.dev.
		if(deldev) {
			free(ret.name);
			fido_dev_close(ret.dev);
			fido_dev_free(&ret.dev);
		}
		if(!dret || ret_ret)
			return dret;
	}
	return 0;
}


// Create a non-resident ES256 credential with the hmac-secret extension on dev and store it in ret.
// The credential uses relying party ID "fzifdso" (name: dataset), a random user ID (name: dataset)
// and a random client-data hash.
// The user is asked on stderr to confirm on the device. If the device answers FIDO_ERR_PIN_REQUIRED
// and supports a PIN, the PIN is prompted for and the request is repeated.
// Every blocking device request runs under with_cancelled(), so SIGINT cancels it on the device.
// On success ret.cred_id and ret.pubkey point into the libfido2 credential object.
// Returns 0 on success, non-zero (after printing to stderr) on failure.
int fido2_genkey(fido2_cred_bundle & ret, const char * dataset, bool backup, fido2_device & dev) {
	// NOTE(review): cred is never freed. ret.cred_id/ret.pubkey point into it on success,
	// but it also leaks on every error path below.
	auto cred = TRY_PTR("allocate FIDO2 credential", fido_cred_new());

	// https://docs.yubico.com/yesdk/users-manual/application-fido2/hmac-secret.html
	TRY_FIDO2("set credential extensions", fido_cred_set_extensions(cred, FIDO_EXT_HMAC_SECRET));

	// Random user ID and client-data hash, filled by a single getentropy() call.
	struct {
		uint8_t uid[FIDO2_UID_LEN];
		uint8_t cdh[FIDO2_CLIENTDATAHASH_LEN];
	} blob;
	TRY("getentropy", getentropy(&blob, sizeof(blob)));

	// https://developers.yubico.com/libfido2/Manuals/fido_dev_make_cred.html
	TRY_FIDO2("set type", fido_cred_set_type(cred, COSE_ES256));
	TRY_FIDO2("set data hash", fido_cred_set_clientdata_hash(cred, blob.cdh, sizeof(blob.cdh)));
	TRY_FIDO2("set relying party", fido_cred_set_rp(cred, "fzifdso", dataset));
	TRY_FIDO2("set user attributes", fido_cred_set_user(cred, blob.uid, sizeof(blob.uid), dataset, nullptr, nullptr));
	// "list of excluded credential IDs" has no matches on fido_cred_set_authdata(3)
	TRY_FIDO2("set resident key", fido_cred_set_rk(cred, FIDO_OPT_FALSE));


	// %s=dataset, then device name
	fprintf(stderr, backup ? gettext("Confirm backup credential generation for %s on %s\n") : gettext("Confirm credential generation for %s on %s\n"), dataset,
	        dev.name);
	auto credres = with_cancelled(dev, [&] { return fido_dev_make_cred(dev, cred, nullptr); });
	if(credres == FIDO_ERR_PIN_REQUIRED && fido_dev_supports_pin(dev)) {
		char * pin;
		TRY_MAIN(fido2_prompt_pin(pin, dev));
		quickscope_wrapper pin_deleter{[&] { free(pin); }};

		// The retry is another blocking device request, so make it SIGINT-cancellable like the first attempt.
		credres = with_cancelled(dev, [&] { return fido_dev_make_cred(dev, cred, pin); });
	}
	TRY_FIDO2("generate credential", credres);
	TRY_FIDO2("verify credential", fido_cred_verify(cred));


	ret.cred_id     = fido_cred_id_ptr(cred);
	ret.pubkey      = fido_cred_pubkey_ptr(cred);
	ret.cred_id_len = fido_cred_id_len(cred);
	ret.pubkey_len  = fido_cred_pubkey_len(cred);
	return 0;
}


// Obtain the hmac-secret for bundle's credential (salted with bundle.salt) from dev, writing it to ret.
// The user is asked on stderr to confirm on the device. SIGINT cancels the pending request.
// If fatal is non-null, it is set to (assertion result != FIDO_ERR_NO_CREDENTIALS). That is,
// after a failure it is false only when this device doesn't hold the credential.
// The assertion is verified against bundle.pubkey before the secret is used.
// Returns 0 on success, non-zero (after printing to stderr) on failure.
int fido2_loadkey(uint8_t (&ret)[FIDO2_HMAC_SECRET_RESULT_LEN], const char * dataset, bool backup, fido2_device & dev, const fido2_cred_bundle & bundle,
                  bool * fatal) {
	auto assertion = TRY_PTR("allocate assertion", fido_assert_new());
	quickscope_wrapper assert_deleter{[&] { fido_assert_free(&assertion); }};
	TRY_FIDO2("set assertion count", fido_assert_set_count(assertion, 1));

	uint8_t cdh[FIDO2_CLIENTDATAHASH_LEN];
	TRY("getentropy", getentropy(cdh, sizeof(cdh)));


	// https://developers.yubico.com/libfido2/Manuals/fido2-cred.html#INPUT_FORMAT
	TRY_FIDO2("set assertion client data hash", fido_assert_set_clientdata_hash(assertion, cdh, sizeof(cdh)));
	TRY_FIDO2("set assertion relying party", fido_assert_set_rp(assertion, "fzifdso"));
	// authenticator data (base64 blob);
	TRY_FIDO2("set assertion credential ID", fido_assert_allow_cred(assertion, bundle.cred_id, bundle.cred_id_len));
	// attestation signature (base64 blob);
	// attestation certificate (optional, base64 blob).

	// https://docs.yubico.com/yesdk/users-manual/application-fido2/hmac-secret.html
	TRY_FIDO2("set assertion extensions", fido_assert_set_extensions(assertion, FIDO_EXT_HMAC_SECRET));
	TRY_FIDO2("set hmac salt", fido_assert_set_hmac_salt(assertion, bundle.salt, sizeof(bundle.salt)));


	// %s=dataset, then device name
	fprintf(stderr, backup ? gettext("Confirm unlock of %s backup on %s\n") : gettext("Confirm unlock of %s on %s\n"), dataset, dev.name);
	auto err = with_cancelled(dev, [&] { return fido_dev_get_assert(dev, assertion, nullptr); });
	if(fatal)
		*fatal = err != FIDO_ERR_NO_CREDENTIALS;
	TRY_FIDO2("get assert", err);


	auto pubkey_parsed = TRY_PTR("allocate pubkey", es256_pk_new());
	quickscope_wrapper pubkey_parsed_deleter{[&] { es256_pk_free(&pubkey_parsed); }};
	TRY_FIDO2("parse pubkey", es256_pk_from_ptr(pubkey_parsed, bundle.pubkey, bundle.pubkey_len));
	TRY_FIDO2("assert", fido_assert_verify(assertion, 0, COSE_ES256, pubkey_parsed));  // credential format (UTF-8 string);


	// The hmac-secret output comes from the authenticator's reply, so check its presence and length
	// at run time. An assert() vanishes under NDEBUG, and memcpy() would then read through a null
	// or too-short buffer. TRY() fires on -1 (as returned by getentropy()/unb64()), hence the literal -1.
	auto secret = fido_assert_hmac_secret_ptr(assertion, 0);
	if(!secret || fido_assert_hmac_secret_len(assertion, 0) != FIDO2_HMAC_SECRET_RESULT_LEN)
		// verb for "Couldn't" format
		return errno = EPROTO, TRY(gettext("get hmac-secret"), -1);
	memcpy(ret, secret, sizeof(ret));
	return 0;
}


// URL-safe base64 alphabet (RFC 4648 §5). The trailing NUL is not part of the alphabet.
static const constexpr char alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
// Number of bits encoded by one alphabet character.
static const constexpr uint8_t alphabet_bits = 6;
static_assert((1u << alphabet_bits) == sizeof(alphabet) - 1, "alphabet_bits must match the alphabet size");

// Append the unpadded base64url encoding of [beg, end) at cur and advance cur past it.
// No NUL terminator is written. Exactly ceil(8 * (end - beg) / 6) characters are produced.
static void b64(char *& cur, const uint8_t * beg, const uint8_t * end) {
	// Not-yet-emitted input bits, left-aligned in a 16-bit window; `pending` counts them.
	uint16_t window{};
	uint8_t pending{};
	while(true) {
		// Top up with a whole input byte whenever less than one character's worth is buffered.
		if(pending < alphabet_bits) {
			if(beg == end)
				break;
			const uint8_t byte = *beg;
			++beg;
			window |= static_cast<uint16_t>(byte) << (8 - pending);
			pending += 8;
		}

		*cur++ = alphabet[window >> (16 - alphabet_bits)];
		window <<= alphabet_bits;
		pending -= alphabet_bits;
	}

	// Flush the leftover (< alphabet_bits) bits, zero-filled on the right.
	if(pending != 0)
		*cur++ = alphabet[window >> (16 - alphabet_bits)];
}

// Decode the NUL-terminated, unpadded base64url string cur into out.
// Trailing bits that don't make up a whole byte are dropped.
// Returns the number of bytes written, or -1 with errno = EINVAL on a character outside the alphabet.
static ssize_t unb64(uint8_t * out, const char * cur) {
	const auto start = out;
	const auto abeg  = std::begin(alphabet);
	const auto aend  = std::end(alphabet) - 1;  // excludes the NUL terminator
	uint16_t window{};
	uint8_t pending{};
	for(;;) {
		// Gather characters until at least one full byte is buffered.
		for(; pending < 8; pending += alphabet_bits) {
			if(*cur == '\0')
				return out - start;

			const auto hit = std::find(abeg, aend, *cur++);
			if(hit == aend)
				return errno = EINVAL, -1;

			window |= static_cast<uint16_t>(hit - abeg) << ((16 - alphabet_bits) - pending);
		}

		*out++ = window >> (16 - 8);
		window <<= 8;
		pending -= 8;
	}
}

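// Property format: every field is base64url-encoded with the trailing '=' padding removed,
// and fields are delimited with ':' (salt:cred_id:pubkey).
//
// Backup bundles follow the main one, each introduced by '.';
// each is laid out like the main bundle but with the IV and the encrypted data appended.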
// Buffer length reserved for one encoded field of `bytes` bytes: the padded base64 length,
// 4 * ceil(bytes / 3). This is never less than the unpadded ceil(4 * bytes / 3) characters b64() emits.
static constexpr size_t len_for_blob(size_t bytes) {
	const size_t chars = 4 * bytes / 3;
	return (chars + 3) / 4 * 4;
}
// Upper bound on the characters unparse_bundle() writes for bundle ("salt:cred_id:pubkey"),
// excluding any NUL terminator.
static size_t len_for_bundle(const fido2_cred_bundle & bundle) {
	size_t total = len_for_blob(sizeof(bundle.salt));
	total += 1 /*:*/ + len_for_blob(bundle.cred_id_len);
	total += 1 /*:*/ + len_for_blob(bundle.pubkey_len);
	return total;
}
// Write bundle at cur as "salt:cred_id:pubkey" (each field unpadded base64url) and advance cur.
// No NUL is appended. The caller must have room for len_for_bundle(bundle) characters.
static void unparse_bundle(char *& cur, const fido2_cred_bundle & bundle) {
	const uint8_t * const fields[][2] = {
	    {std::begin(bundle.salt), std::end(bundle.salt)},
	    {bundle.cred_id, bundle.cred_id + bundle.cred_id_len},
	    {bundle.pubkey, bundle.pubkey + bundle.pubkey_len},
	};
	for(size_t i = 0; i < std::size(fields); ++i) {
		if(i != 0)
			*cur++ = ':';
		b64(cur, fields[i][0], fields[i][1]);
	}
}
int fido2_unparse_prop(char *& prop, const fido2_cred_bundle & bundle, const fido2_backup_cred_bundle * backups, size_t backups_len) {
	prop = TRY_PTR("allocate property value", reinterpret_cast<char *>(malloc(std::accumulate(backups, backups + backups_len, len_for_bundle(bundle),
	                                                                                          [](auto acc, auto && b) {
		                                                                                          return acc + 1 /*.*/ +  //
		                                                                                                 len_for_bundle(b.cred) + 1 /*:*/ +
		                                                                                                 len_for_blob(sizeof(b.iv)) + 1 /*:*/ +
		                                                                                                 len_for_blob(sizeof(b.encrypted));
	                                                                                          }) +
	                                                                          1)));

	auto cur = prop;
	unparse_bundle(cur, bundle);
	for(size_t i = 0; i < backups_len; ++i) {
		*cur++ = '.';
		unparse_bundle(cur, backups[i].cred);
		*cur++ = ':';
		b64(cur, std::begin(backups[i].iv), std::end(backups[i].iv));
		*cur++ = ':';
		b64(cur, std::begin(backups[i].encrypted), std::end(backups[i].encrypted));
	}
	*cur = '\0';
	return 0;
}

// Parse one "salt:cred_id:pubkey" triple (unpadded base64url fields) from handle_s into bundle.
// handle_s is tokenised in place with strtok_r(). rest receives the tokeniser state so the caller
// can read any further ':'-separated fields. A null handle_s is reported as missing data.
// backup_num is 0 for the main bundle, otherwise the 1-based backup index; it only changes
// the wording of error messages.
// On success bundle.cred_id and bundle.pubkey are fresh malloc()ed buffers owned by the caller.
// A decoded salt of the wrong length is reported via TRY(…, -1): TRY() fires on -1 (as returned
// by unb64()), so passing a positive value such as __LINE__ would return without any message.
// Returns 0 on success, non-zero (after printing to stderr) on failure.
static int parse_baseline(fido2_cred_bundle & bundle, char *& rest, const char * dataset_name, char * handle_s, size_t backup_num) {
	// With a null handle_s, strtok_r(nullptr, …, &rest) would resume from the caller's
	// uninitialised rest, so don't tokenise at all in that case.
	char * salt_s    = handle_s ? strtok_r(handle_s, ":", &rest) : nullptr;
	char * cred_id_s = handle_s ? strtok_r(nullptr, ":", &rest) : nullptr;
	char * pubkey_s  = handle_s ? strtok_r(nullptr, ":", &rest) : nullptr;

	if(!salt_s || !cred_id_s || !pubkey_s)
		return backup_num ? fprintf(stderr, gettext("Dataset %s backup %zu's handle: %s.\n"), dataset_name, backup_num, strerror(ENODATA))
		                  : fprintf(stderr, gettext("Dataset %s's handle: %s.\n"), dataset_name, strerror(ENODATA)),
		       __LINE__;

	if(cred_id_s - salt_s != ((sizeof(bundle.salt) * 4 + 2) / 3 + 1))
		return backup_num ? fprintf(stderr, gettext("Dataset %s backup %zu's salt %s: %s.\n"), dataset_name, backup_num, salt_s, strerror(ERANGE))
		                  : fprintf(stderr, gettext("Dataset %s's salt %s: %s.\n"), dataset_name, salt_s, strerror(ERANGE)),
		       __LINE__;

	// verb for "Couldn't format"
	auto salt_len = TRY(gettext("base64-decode salt"), unb64(bundle.salt, salt_s));
	if(salt_len != sizeof(bundle.salt))
		return errno = E2BIG, TRY(gettext("base64-decode salt"), -1);

	// n base64 characters decode to at most 3n/4 bytes. strtok_r() never yields empty tokens,
	// but n == 1 would still give 0, so +1 keeps malloc() from being asked for nothing.
	auto cred_id       = TRY_PTR("allocate credential ID", reinterpret_cast<uint8_t *>(malloc(strlen(cred_id_s) * 3 / 4 + 1)));
	bundle.cred_id     = cred_id;  // verb for "Couldn't format"
	bundle.cred_id_len = TRY(gettext("base64-decode credential ID"), unb64(cred_id, cred_id_s));

	auto pubkey       = TRY_PTR("allocate credential public key", reinterpret_cast<uint8_t *>(malloc(strlen(pubkey_s) * 3 / 4 + 1)));
	bundle.pubkey     = pubkey;  // verb for "Couldn't format"
	bundle.pubkey_len = TRY(gettext("base64-decode credential public key"), unb64(pubkey, pubkey_s));

	return 0;
}
// Parse a property value as written by fido2_unparse_prop(): the main "salt:cred_id:pubkey" bundle,
// then zero or more '.'-separated "salt:cred_id:pubkey:iv:encrypted" backups.
// handle_s is tokenised (modified) in place. backups becomes a malloc()ed array of backups_len
// entries (nullptr/0 when there are none); it and all decoded buffers belong to the caller,
// including on failure.
// backups_len never counts a slot that was not allocated.
// Length mismatches after decoding are reported via TRY(…, -1), since TRY() fires only on -1.
// Returns 0 on success, non-zero (after printing to stderr) on failure.
int fido2_parse_prop(fido2_cred_bundle & bundle, fido2_backup_cred_bundle *& backups, size_t & backups_len, const char * dataset_name, char * handle_s) {
	char *backup_sv, *rest;
	TRY_MAIN(parse_baseline(bundle, rest, dataset_name, strtok_r(handle_s, ".", &backup_sv), 0));
	if(auto s = strtok_r(nullptr, ":", &rest))
		return fprintf(stderr, gettext("Dataset %s's handle %s: %s.\n"), dataset_name, s, strerror(E2BIG)), __LINE__;

	backups     = nullptr;
	backups_len = 0;
	for(char * backup; (backup = strtok_r(nullptr, ".", &backup_sv));) {
		// Grow first, count the new slot only afterwards. If reallocarray() fails, backups still
		// holds the old (valid) array and backups_len still matches it.
		backups = TRY_PTR("allocate backup", reinterpret_cast<fido2_backup_cred_bundle *>(reallocarray(backups, backups_len + 1, sizeof(*backups))));
		++backups_len;
		// Value-initialise, so members not set by a constructor or default member initialiser start
		// zeroed rather than indeterminate if parsing stops early.
		auto & b = *new(&backups[backups_len - 1]) fido2_backup_cred_bundle{};

		TRY_MAIN(parse_baseline(b.cred, rest, dataset_name, backup, backups_len));
		auto iv_s        = strtok_r(nullptr, ":", &rest);
		auto encrypted_s = strtok_r(nullptr, ":", &rest);
		if(!iv_s || !encrypted_s)
			return fprintf(stderr, gettext("Dataset %s backup %zu's handle: %s.\n"), dataset_name, backups_len, strerror(ENODATA)), __LINE__;
		if(auto s = strtok_r(nullptr, ":", &rest))
			return fprintf(stderr, gettext("Dataset %s backup %zu's handle %s: %s.\n"), dataset_name, backups_len, s, strerror(E2BIG)), __LINE__;

		if(encrypted_s - iv_s != ((sizeof(b.iv) * 4 + 2) / 3 + 1))
			return fprintf(stderr, gettext("Dataset %s backup %zu's IV %s: %s.\n"), dataset_name, backups_len, iv_s, strerror(ERANGE)), __LINE__;
		if(strlen(encrypted_s) != (sizeof(b.encrypted) * 4 + 2) / 3)
			return fprintf(stderr, gettext("Dataset %s backup %zu's encrypted %s: %s.\n"), dataset_name, backups_len, encrypted_s, strerror(ERANGE)), __LINE__;

		// verb for "Couldn't" format
		auto iv_len = TRY(gettext("base64-decode IV"), unb64(b.iv, iv_s));
		if(iv_len != sizeof(b.iv))
			return errno = E2BIG, TRY(gettext("base64-decode IV"), -1);

		// verb for "Couldn't" format
		auto encrypted_len = TRY(gettext("base64-decode encrypted"), unb64(b.encrypted, encrypted_s));
		if(encrypted_len != sizeof(b.encrypted))
			return errno = E2BIG, TRY(gettext("base64-decode encrypted"), -1);
	}

	return 0;
}


// Render the credential ID in standard (RFC 4648 §4) base64. b64()'s base64url output is
// mapped back ('-' -> '+', '_' -> '/') and '='-padded to a multiple of 4 characters.
// Returns a malloc()ed NUL-terminated string the caller frees, or nullptr if allocation fails.
char * fido2_cred_bundle::cred_id_str() const noexcept {
	const size_t tail   = this->cred_id_len % 3;
	const size_t padlen = tail == 0 ? 0 : 3 - tail;
	const size_t olen   = (this->cred_id_len * 4 + 2) / 3;

	auto obuf = static_cast<char *>(malloc(olen + padlen + 1));
	if(obuf == nullptr)
		return nullptr;

	char * end = obuf;
	b64(end, this->cred_id, this->cred_id + this->cred_id_len);
	std::transform(obuf, end, obuf, [](char c) { return c == '-' ? '+' : c == '_' ? '/' : c; });

	end  = std::fill_n(end, padlen, '=');
	*end = '\0';
	return obuf;
}


// Convert a raw passphrase buffer into a malloc()ed, NUL-terminated PIN string.
// An empty buffer yields pin = nullptr. Buffers containing a NUL byte are rejected,
// since the PIN is handed on as a C string.
// Returns 0 on success, non-zero (after printing to stderr) on failure.
int fido2ify_passphrase(char *& pin, const uint8_t * buf, size_t len) {
	if(len == 0) {
		pin = nullptr;
		return 0;
	}
	if(memchr(buf, '\0', len) != nullptr)
		return fprintf(stderr, "PIN can't contain NULs!\n"), __LINE__;

	auto copy = TRY_PTR("allocate PIN", reinterpret_cast<char *>(malloc(len + 1)));
	memcpy(copy, buf, len);
	copy[len] = '\0';
	pin       = copy;
	return 0;
}

// Prompt for dev's PIN with read_known_passphrase() and convert it with fido2ify_passphrase().
// The raw passphrase buffer is freed before returning. On success pin is as set by
// fido2ify_passphrase() (malloc()ed string, or nullptr for an empty entry).
// Returns 0 on success, non-zero on failure.
int fido2_prompt_pin(char *& pin, const fido2_device & dev) {
	char prompt[512];
	// %s=device name. noun for "Enter passphrase for" prompt
	snprintf(prompt, sizeof(prompt), gettext("%s PIN"), dev.name);

	uint8_t * raw;
	size_t raw_len;
	TRY_MAIN(read_known_passphrase(prompt, raw, raw_len));
	quickscope_wrapper raw_deleter{[&] { free(raw); }};

	return fido2ify_passphrase(pin, raw, raw_len);
}
