// SPDX-License-Identifier: MIT


#include "state.hpp"
#include <algorithm>
#include <endian.h>
#include <fmt/format.h>
#if !KLAPKI_NO_ZLIB
#include <zlib.h>
#endif

using namespace std::literals;


// ASCII SUB (0x1A): the on-disk initrd directory name used for an initrd whose directory is a
// default-constructed (falsy) nonbase_dirname_t; parse() maps it back to one.
// A single named object rather than a macro expanding to a new literal at each use, so that
// std::begin(SUB) and std::end(SUB) are guaranteed to delimit the same range.
static constexpr std::string_view SUB{"\x1A", 1};


// Parse a state blob (the format written by serialise()) into `into`.
//
// Format, integers big-endian:
//   u16 boot_position
//   variant names, each NUL-terminated; an empty name ends the list
//   then, until the end of the buffer, wanted entries:
//     u16 bootnum_hint, load_option_sha, version\0, variant\0, kernel_dirname\0, kernel_image_sha,
//     then {initrd_dirname\0, initrd SHA} pairs ended by an empty name (a lone \0).
//   An initrd name equal to SUB (0x1A) is read back as a default-constructed nonbase_dirname_t.
//
// With may_decompress set and boot_position == 0xFFFF, the remainder is first tried as
// "u16 uncompressed length + zlib stream"; if it inflates to exactly that length while consuming
// every remaining byte, the inflated bytes are parsed instead (with decompression disabled).
// Otherwise the bytes are parsed verbatim.
//
// Problems are reported on stderr and never abort: duplicate variants are skipped, and on
// truncation everything fully read so far is kept while a partially-read entry is discarded.
//
// into:           boot_position is overwritten; variants and wanted_entries are appended to
// may_decompress: whether the 0xFFFF compressed form may be recognised
// _data, size:    the input buffer
void klapki::state::stated_config::parse(klapki::state::stated_config & into, bool may_decompress, const void * _data, std::size_t size) {
	if(size < 2) {
		fmt::print(stderr, fgettext("Parsing state: cut off before boot position\n"));
		return;
	}

	auto data = reinterpret_cast<const char *>(_data);
	std::memcpy(&into.boot_position, data, sizeof(into.boot_position));
	into.boot_position = be16toh(into.boot_position);
	data += sizeof(into.boot_position);
	size -= sizeof(into.boot_position);

#if !KLAPKI_NO_ZLIB
	// boot_position == 0xFFFF doubles as the "compressed" flag: u16 uncompressed size, then a zlib stream.
	if(may_decompress && into.boot_position == 0xFFFF && size > sizeof(std::uint16_t)) {
		std::uint16_t uncompressed_size;
		std::memcpy(&uncompressed_size, data, sizeof(uncompressed_size));
		uncompressed_size = be16toh(uncompressed_size);
		data += sizeof(uncompressed_size);
		size -= sizeof(uncompressed_size);

		std::uint8_t uncompressed[0xFFFF];
		uLongf uncompressed_len_out = uncompressed_size;
		uLong compressed_len_out    = size;
		// Only accept a stream that yields exactly the announced size and uses all of the input.
		if(uncompress2(uncompressed, &uncompressed_len_out, reinterpret_cast<const Bytef *>(data), &compressed_len_out) == Z_OK &&
		   uncompressed_len_out == uncompressed_size && compressed_len_out == size) {
			parse(into, false, uncompressed, uncompressed_size);
			return;
		}

		// Not a valid compressed payload: rewind and parse the bytes verbatim.
		data -= sizeof(uncompressed_size);
		size += sizeof(uncompressed_size);
	}
#endif

	// Variant list: NUL-terminated names, ended by an empty name.
	for(;;) {
		const auto variant_end = std::find(data, data + size, '\0');
		if(variant_end == data + size) {
			fmt::print(stderr, fngettext("Parsing state: cut off after {} variant\n", "Parsing state: cut off after {} variants\n", into.variants.size()),
			           into.variants.size());
			data = data + size;
			size = 0;
			break;
		}

		std::string variant{data, static_cast<std::size_t>(variant_end - data)};
		data += variant.size() + 1;  // NUL
		size -= variant.size() + 1;  // NUL

		// Check for the terminator first, so it can never be mistaken for a duplicate.
		if(variant.empty())
			break;

		if(std::find(std::begin(into.variants), std::end(into.variants), variant) != std::end(into.variants)) {
			fmt::print(stderr, fgettext("Duplicate variant {}?"), variant);
			fmt::print(stderr, "\n");  // outside the translated string so existing translations still match
			continue;
		}

		into.variants.emplace_back(std::move(variant));
	}

	// Wanted entries, until the buffer is exhausted.
	while(size != 0) {
		stated_config_entry new_entry{};
		// Strictly more than the fixed header is needed: at least the version's NUL must follow.
		if(size <= sizeof(new_entry.bootnum_hint) + sizeof(new_entry.load_option_sha)) {
			fmt::print(stderr, fngettext("Parsing state: cut off after {} entry\n", "Parsing state: cut off after {} entries\n", into.wanted_entries.size()),
			           into.wanted_entries.size());
			break;
		}

		std::memcpy(&new_entry.bootnum_hint, data, sizeof(new_entry.bootnum_hint));
		new_entry.bootnum_hint = be16toh(new_entry.bootnum_hint);
		data += sizeof(new_entry.bootnum_hint);
		size -= sizeof(new_entry.bootnum_hint);

		std::memcpy(&new_entry.load_option_sha, data, sizeof(new_entry.load_option_sha));
		data += sizeof(new_entry.load_option_sha);
		size -= sizeof(new_entry.load_option_sha);

		const auto kver_end = std::find(data, data + size, '\0');
		if(kver_end == data + size) {
			fmt::print(stderr, fgettext("Parsing state: cut off reading kernel version for entry {:04X}\n"), new_entry.bootnum_hint);
			break;
		}
		new_entry.version.assign(data, kver_end - data);
		data += new_entry.version.size() + 1;  // NUL
		size -= new_entry.version.size() + 1;  // NUL

		const auto kvar_end = std::find(data, data + size, '\0');
		if(kvar_end == data + size) {
			fmt::print(stderr, fgettext("Parsing state: cut off reading kernel variant for entry {:04X}\n"), new_entry.bootnum_hint);
			break;
		}
		new_entry.variant.assign(data, kvar_end - data);
		data += new_entry.variant.size() + 1;  // NUL
		size -= new_entry.variant.size() + 1;  // NUL

		const auto kdir_end = std::find(data, data + size, '\0');
		if(kdir_end == data + size) {
			fmt::print(stderr, fgettext("Parsing state: cut off reading kernel image directory for entry {:04X}\n"), new_entry.bootnum_hint);
			break;
		}
		new_entry.kernel_dirname.assign(data, kdir_end - data);
		data += new_entry.kernel_dirname.size() + 1;  // NUL
		size -= new_entry.kernel_dirname.size() + 1;  // NUL

		// Strictly more than the SHA is needed: at least the initrd list terminator must follow.
		if(size <= sizeof(new_entry.kernel_image_sha)) {
			fmt::print(stderr, fgettext("Parsing state: cut off reading kernel image SHA for entry {:04X}\n"), new_entry.bootnum_hint);
			break;
		}
		std::memcpy(&new_entry.kernel_image_sha, data, sizeof(new_entry.kernel_image_sha));
		data += sizeof(new_entry.kernel_image_sha);
		size -= sizeof(new_entry.kernel_image_sha);

		// Initrds: {name\0, SHA} pairs ended by an empty name.
		for(;;) {
			const auto idir_end = std::find(data, data + size, '\0');
			if(idir_end == data + size) {
				fmt::print(stderr,
				           fngettext("Parsing state: cut off after {} initrd for entry {:04X}\n", "Parsing state: cut off after {} initrds for entry {:04X}\n",
				                     new_entry.initrd_dirnames.size()),
				           new_entry.initrd_dirnames.size(), new_entry.bootnum_hint);
				goto end;  // leave both loops, discarding the partial entry
			}

			std::string idir{data, static_cast<std::size_t>(idir_end - data)};
			data += idir.size() + 1;  // NUL
			size -= idir.size() + 1;  // NUL

			if(idir.empty())
				break;

			shaa_t idir_sha;
			if(size <= idir_sha.size()) {
				fmt::print(stderr, fgettext("Parsing state: cut off reading SHA for initrd {} for entry {:04X}\n"), new_entry.initrd_dirnames.size(),
				           new_entry.bootnum_hint);
				// A plain break would only leave the initrd loop and then keep this half-read entry.
				goto end;
			}
			std::memcpy(&idir_sha[0], data, idir_sha.size());
			data += idir_sha.size();
			size -= idir_sha.size();

			if(idir == SUB)
				new_entry.initrd_dirnames.emplace_back(nonbase_dirname_t{}, idir_sha);
			else
				new_entry.initrd_dirnames.emplace_back(std::move(idir), idir_sha);
		}

		into.wanted_entries.emplace_back(std::move(new_entry));
	}
end:

	return;
}


// Serialise this config into the binary state format read by parse().
//
// Layout (integers big-endian): u16 boot_position; each variant name NUL-terminated, then an
// extra NUL ending the list; then per wanted entry: u16 bootnum_hint, load_option_sha, version\0,
// variant\0, kernel_dirname\0, kernel_image_sha, {initrd_dirname\0, initrd SHA}* and a final \0.
// An initrd whose directory is a falsy nonbase_dirname_t is written under the name SUB (0x1A).
//
// If may_compress (and zlib support is built in), the result may instead be
// 0xFFFF, u16 uncompressed length, zlib stream of the above. The compressed form is used when it
// is strictly smaller, and always (if the raw form fits in 0xFFFF bytes) when boot_position is
// itself 0xFFFF, since that value is the "compressed" marker.
std::vector<std::uint8_t> klapki::state::stated_config::serialise(bool may_compress) const {
	std::vector<std::uint8_t> ret;

	const std::uint16_t bp = htobe16(this->boot_position);
	auto cur = std::copy(reinterpret_cast<const std::uint8_t *>(&bp), reinterpret_cast<const std::uint8_t *>(&bp) + sizeof(bp), std::back_inserter(ret));

	for(auto && var : this->variants) {
		cur    = std::copy(std::begin(var), std::end(var), cur);
		*cur++ = '\0';
	}
	*cur++ = '\0';  // empty name: end of variant list

	// One named object, so begin()/end() are guaranteed to delimit the same range.
	constexpr std::string_view sub_marker = SUB;

	for(auto && went : this->wanted_entries) {
		const std::uint16_t bh = htobe16(went.bootnum_hint);

		cur    = std::copy(reinterpret_cast<const std::uint8_t *>(&bh), reinterpret_cast<const std::uint8_t *>(&bh) + sizeof(bh), cur);
		cur    = std::copy(went.load_option_sha, went.load_option_sha + sizeof(sha_t), cur);
		cur    = std::copy(std::begin(went.version), std::end(went.version), cur);
		*cur++ = '\0';
		cur    = std::copy(std::begin(went.variant), std::end(went.variant), cur);
		*cur++ = '\0';
		cur    = std::copy(std::begin(went.kernel_dirname), std::end(went.kernel_dirname), cur);
		*cur++ = '\0';
		cur    = std::copy(std::begin(went.kernel_image_sha), std::end(went.kernel_image_sha), cur);

		for(auto && [idir, isha] : went.initrd_dirnames) {
			if(idir)
				cur = std::copy(std::begin(*idir), std::end(*idir), cur);
			else
				cur = std::copy(std::begin(sub_marker), std::end(sub_marker), cur);
			*cur++ = '\0';

			cur = std::copy(std::begin(isha), std::end(isha), cur);
		}
		*cur++ = '\0';  // empty name: end of initrd list
	}

#if !KLAPKI_NO_ZLIB
	// 0xFFFF marker + big-endian uncompressed length, in front of the zlib stream.
	constexpr std::size_t header_len = 2 + sizeof(std::uint16_t);
	// boot_position == 0xFFFF is overloaded as the "compressed" flag, so such a config must be compressed.
	const bool must_compress = this->boot_position == 0xFFFF;
	// When compression is optional, require room for the header plus at least one byte of stream;
	// smaller buffers would overflow on the header write and underflow compret_len.
	if(may_compress && ret.size() <= 0xFFFF && (must_compress || ret.size() > header_len + 1)) {
		// When optional, cap the buffer at ret.size() - 1 so compress2() fails with Z_BUF_ERROR
		// unless the compressed form actually saves space.
		std::vector<std::uint8_t> compret(must_compress ? header_len + compressBound(ret.size()) : ret.size() - 1);
		compret[0] = compret[1] = 0xFF;  // boot_position = 0xFFFF

		const std::uint16_t orig_len = htobe16(static_cast<std::uint16_t>(ret.size()));
		std::memcpy(compret.data() + 2, &orig_len, sizeof(orig_len));

		uLongf compret_len = compret.size() - header_len;
		if(compress2(compret.data() + header_len, &compret_len, ret.data(), ret.size(), Z_BEST_COMPRESSION) == Z_OK) {
			compret.resize(header_len + compret_len);
			return compret;
		}
	}
#endif

	return ret;
}
