// SPDX-License-Identifier: 0BSD


#pragma once

#include <algorithm>
#include <cstdint>
#include <err.h>
#include <iterator>


// fzifdso src/fido2.cpp; base64 alphabet
namespace bag {
	namespace {
		const constexpr char alphabet[]            = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
		const constexpr std::uint8_t alphabet_bits = 6 /*__builtin_ctz(alphabet.size())*/;

		template <class C>
		void b64(C *& cur, const std::uint8_t * beg, const std::uint8_t * end) {
			std::uint16_t acc{};
			std::uint8_t accsz{};
			for(;;) {
				if(accsz < alphabet_bits) {
					if(beg == end)
						break;
					acc |= static_cast<std::uint16_t>(*beg++) << ((16 - 8) - accsz);
					accsz += 8;
				}

				*cur++ = alphabet[acc >> (16 - alphabet_bits)];
				acc <<= alphabet_bits;
				accsz -= alphabet_bits;
			}

			if(accsz)
				*cur++ = alphabet[acc >> (16 - alphabet_bits)];
		}

		template <class C>
		ssize_t unb64(std::uint8_t * out, std::uint8_t * outend, const C * cur) {
			auto iout = out;
			std::uint16_t acc{};
			std::uint8_t accsz{};
			for(;;) {
				while(accsz < 8) {
					if(*cur == '=') {
						while(*cur == '=')
							++cur;
						if(*cur)
							return errno = EINVAL, -1;
					}
					if(!*cur)
						return out - iout;

					auto val = std::find(std::begin(alphabet), std::end(alphabet) - 1, *cur++);
					if(val == std::end(alphabet) - 1)
						return errno = EINVAL, -1;

					acc |= static_cast<std::uint16_t>(val - std::begin(alphabet)) << ((16 - alphabet_bits) - accsz);
					accsz += alphabet_bits;
				}

				if(out == outend)
					return errno = ENAMETOOLONG, -1;
				*out++ = acc >> (16 - 8);
				acc <<= 8;
				accsz -= 8;
			}
		}


		template <class = void>
		std::basic_string_view<std::uint8_t> parse_arg(const char * arg, std::uint8_t (&decode_buf)[255]) {
			auto arg_a = reinterpret_cast<const std::uint8_t *>(arg);
			if(*arg_a == '\\') {
				if(auto len = bag::unb64(std::begin(decode_buf), std::end(decode_buf), arg_a + 1); len != -1)
					return {decode_buf, static_cast<std::size_t>(len)};
				else
					err(errno, "%s", arg);
			} else
				return arg_a;
		}
	}
}
