// The MIT License (MIT)

// Copyright (c) 2020 наб <nabijaczleweli@nabijaczleweli.xyz>

// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


#include "config.hpp"
#include "context.hpp"
#include "context_detail.hpp"
#include "quickscope_wrapper.hpp"
#include <algorithm>
#include <fcntl.h>
#include <openssl/sha.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <ucs2.h>
#include <unistd.h>
extern "C" {
#include <efivar/efiboot.h>
}


// Evaluates the given expression once, binds the result to `err`, and returns `err` from the enclosing function if it is truthy.
// Presumably meant for calls yielding an optional error (like the std::optional<std::string> returned by context::save()) — confirm at call sites.
// NOTE(review): this expands to a bare `if` statement that carries its own trailing `;`, so using it as the unbraced body of an if/else would mis-bind the `else`.
#define TRY_OPT(...)              \
	if(auto err = __VA_ARGS__; err) \
		return err;


// Raw 20-byte buffer type; sizeof(sha_t) is the length used by save() when comparing and copying the SHA-1 digests of load options.
using sha_t = std::uint8_t[20];


// True for either path separator: the EFI-style backslash or the POSIX-style forward slash.
static constexpr bool isslash(char c) {
	switch(c) {
		case '/':
		case '\\':
			return true;
		default:
			return false;
	}
}


// Writes the textual rendering of the device path `dp` to standard output, followed by a newline.
// `dp_len` is forwarded unchanged to efidp_format_device_path() as its limit argument.
// If the measuring call fails, "couldn't format?" is printed instead.
void klapki::context::detail::print_devpath(const efidp_data * dp, ssize_t dp_len) {
	// A call without a buffer only reports how much room the text needs.
	const auto needed = efidp_format_device_path(nullptr, 0, dp, dp_len);
	if(needed < 0) {
		fmt::print("couldn't format?\n");
		return;
	}

	// NOTE(review): the buffer is exactly `needed` bytes; confirm whether that figure accounts for the terminating NUL.
	std::string rendered(needed, '\0');
	efidp_format_device_path(rendered.data(), rendered.size(), dp, dp_len);
	fmt::print("{}\n", rendered);
}


std::optional<std::string> klapki::context::context::save(const config & cfg, state::state & state) {
	std::vector<std::uint8_t> esp_devpath_raw;
	efidp_data * esp_devpath{};
	if(!this->our_kernels.empty()) {
		do {
			esp_devpath_raw.resize(esp_devpath_raw.size() + 128);

			// extern ssize_t efi_generate_file_device_path(uint8_t *buf, ssize_t size,
			// 	      const char * const filepath,
			// 	      uint32_t options, ...)
			// EFIBOOT_ABBREV_HD matches what's produced by bootctl(1) install, and produces just HD()/File(),
			// however, this funxion requires the File() to exist, so by passing just the ESP root, we can append our potentially-not-yet-existent paths later on.
			if(auto size = efi_generate_file_device_path(esp_devpath_raw.data(), esp_devpath_raw.size(),  //
			                                             fmt::format("{}/", cfg.esp).c_str(),             //
			                                             EFIBOOT_ABBREV_HD);
			   size >= 0)
				esp_devpath_raw.resize(size);
			else if(errno != ENOSPC)
				return fmt::format("Making device path for ESP: {}", strerror(errno));
		} while(errno == ENOSPC);


		esp_devpath = reinterpret_cast<efidp_data *>(esp_devpath_raw.data());
		{  // esp_devpath is currently HD(some path)/File("\"). Trim it to just HD() for appending later
			efidp_data * fnode{};
			if(efidp_next_node(esp_devpath, const_cast<const efidp_data **>(&fnode)) != 1)
				throw __func__;

			fnode->type    = EFIDP_END_TYPE;
			fnode->subtype = EFIDP_END_ENTIRE;
		}


		if(cfg.verbose) {
			fmt::print("ESP devpath: ");
			detail::print_devpath(reinterpret_cast<const efidp_data *>(esp_devpath_raw.data()), esp_devpath_raw.size());
		}
	}


	for(auto && [bootnum, kern] : this->our_kernels) {
		auto skern = std::find_if(std::begin(state.statecfg.wanted_entries), std::end(state.statecfg.wanted_entries),
		                          [bn = bootnum](auto && skern) { return skern.bootnum_hint == bn; });
		if(skern == std::end(state.statecfg.wanted_entries))
			throw __func__;
		auto bent = state.entries.find(bootnum);
		if(bent == std::end(state.entries))
			throw __func__;

		auto image_path = fmt::format("{}{}{}{}", isslash(kern.image_path.first.front()) ? "" : "\\", kern.image_path.first,
		                              isslash(kern.image_path.first.back()) ? "" : "\\", kern.image_path.second);
		image_path.erase(std::remove_if(std::begin(image_path), std::end(image_path),
		                                [prev = false](auto c) mutable {
			                                auto cur = isslash(c);
			                                if(prev && cur)
				                                return true;
			                                else {
				                                prev = cur;
				                                return false;
			                                }
		                                }),
		                 std::end(image_path));
		std::replace(std::begin(image_path), std::end(image_path), '/', '\\');

		std::vector<std::uint8_t> devpath_file_node(efidp_make_file(nullptr, 0, image_path.data()));
		if(efidp_make_file(devpath_file_node.data(), devpath_file_node.size(), image_path.data()) < 0)
			return fmt::format("Entry {:04X}: creating devpath File(): {}", bootnum, strerror(errno));

		efidp_data * devpath;
		if(efidp_append_node(esp_devpath, reinterpret_cast<const efidp_data *>(devpath_file_node.data()), &devpath) < 0)
			return fmt::format("Entry {:04X}: creating appending File(): {}", bootnum, strerror(errno));
		quickscope_wrapper devpath_deleter{[&] { std::free(devpath); }};
		const auto devpath_len = efidp_size(devpath);

		if(cfg.verbose) {
			fmt::print("Entry {:04X} devpath: ", bootnum);
			detail::print_devpath(devpath, devpath_len);
		}

		// Must be at start, we use position in derive() to match extraneous ones from cmdline
		std::string templine{};
		std::string_view prev = kern.image_path.first;
		for(auto && ipath : kern.initrd_paths) {
			if(ipath.first)
				prev = *ipath.first;
			fmt::format_to(std::back_inserter(templine), "initrd={}{}{} ", prev, (prev.back() == '\\' || ipath.second.front() == '\\') ? "" : "\\", ipath.second);
		}
		templine += kern.cmdline;
		if(cfg.verbose)
			fmt::print("Entry {:04X} cmdline: {}\n", bootnum, templine);

		std::vector<std::uint16_t> cmdline(utf8len(reinterpret_cast<std::uint8_t *>(templine.data()), templine.size()));
		if(utf8_to_ucs2(cmdline.data(), cmdline.size() * sizeof(std::uint16_t), false, reinterpret_cast<std::uint8_t *>(templine.data())) < 0)
			return fmt::format("Entry {:04X}: UCS-2ing cmdline: {}", bootnum, strerror(errno));

		// extern ssize_t efi_loadopt_create(uint8_t *buf, ssize_t size,
		//				  uint32_t attributes, efidp dp,
		//				  ssize_t dp_size, unsigned char *description,
		//				  uint8_t *optional_data,
		//				  size_t optional_data_size)
		bent->second.load_option_len = efi_loadopt_create(nullptr, 0,                                                  //
		                                                  bent->second.attributes,                                     //
		                                                  devpath, devpath_len,                                        //
		                                                  reinterpret_cast<unsigned char *>(kern.description.data()),  //
		                                                  reinterpret_cast<std::uint8_t *>(cmdline.data()), cmdline.size() * sizeof(std::uint16_t));

		bent->second.load_option = std::shared_ptr<std::uint8_t[]>{new std::uint8_t[bent->second.load_option_len]};
		if(efi_loadopt_create(bent->second.load_option.get(), bent->second.load_option_len,  //
		                      bent->second.attributes,                                       //
		                      devpath, devpath_len,                                          //
		                      reinterpret_cast<unsigned char *>(kern.description.data()),    //
		                      reinterpret_cast<std::uint8_t *>(cmdline.data()), cmdline.size() * sizeof(std::uint16_t)) < 0)
			return fmt::format("Making load option for {:04X}: {}", bootnum, strerror(errno));

		SHA1(bent->second.load_option.get(), bent->second.load_option_len, bent->second.load_option_sha);
		if(std::memcmp(bent->second.load_option_sha, skern->load_option_sha, sizeof(sha_t)))
			fmt::print("Entry {:04X} changed\n", bootnum);
		std::memcpy(skern->load_option_sha, bent->second.load_option_sha, sizeof(sha_t));
	}


	if(cfg.verbose)
		fmt::print("Bootorder pre : {}\n", state.order);
	state.order = std::visit(
	    klapki::overload{
	        [&](klapki::state::boot_order_flat && bof) {
		        fmt::print(stderr, "wisen(): flat bootorder?\n");  // Weird, but that's what we want anyway
		        return std::move(bof);
	        },
	        [&](klapki::state::boot_order_structured && bos) {
		        std::vector<std::uint16_t> bents[2];
		        for(auto && [cluster, ours] : std::move(bos.order))
			        if(bents[ours].empty())
				        bents[ours] = std::move(cluster);
			        else
				        bents[ours].insert(std::end(bents[ours]), std::begin(cluster), std::end(cluster));

		        const auto target_pos = std::min(bents[false].size(), static_cast<std::size_t>(state.statecfg.boot_position));
		        if(bents[false].size() < state.statecfg.boot_position)
			        fmt::print(stderr, "Not enough entries to be at position {}. Being at {} instead.\n", state.statecfg.boot_position, target_pos);

		        const auto size = bents[false].size() + bents[true].size();
		        const std::shared_ptr<std::uint16_t[]> flat{new std::uint16_t[size]};

		        // By biggest version, then variant index.
		        std::sort(std::begin(bents[true]), std::end(bents[true]), [&](auto && lhs, auto && rhs) {
			        auto lskern = std::find_if(std::begin(state.statecfg.wanted_entries), std::end(state.statecfg.wanted_entries),
			                                   [&](auto && skern) { return skern.bootnum_hint == lhs; });
			        if(lskern == std::end(state.statecfg.wanted_entries))
				        throw __func__;
			        auto rskern = std::find_if(std::begin(state.statecfg.wanted_entries), std::end(state.statecfg.wanted_entries),
			                                   [&](auto && skern) { return skern.bootnum_hint == rhs; });
			        if(rskern == std::end(state.statecfg.wanted_entries))
				        throw __func__;

			        if(lskern->version != rskern->version)
				        return lskern->version > rskern->version;

			        auto lvar = lskern->variant == "" ? -1 : std::find_if(std::begin(state.statecfg.variants), std::end(state.statecfg.variants), [&](auto && var) {
				                                                 return var == lskern->variant;
			                                                 }) - std::begin(state.statecfg.variants);
			        auto rvar = rskern->variant == "" ? -1 : std::find_if(std::begin(state.statecfg.variants), std::end(state.statecfg.variants), [&](auto && var) {
				                                                 return var == rskern->variant;
			                                                 }) - std::begin(state.statecfg.variants);

			        return lvar < rvar;
		        });

		        auto curs = std::copy_n(bents[false].data(), target_pos, flat.get());
		        curs      = std::copy_n(bents[true].data(), bents[true].size(), curs);
		        curs      = std::copy_n(bents[false].data() + target_pos, bents[false].size() - target_pos, curs);
		        if(curs != flat.get() + size)  // This is an assert() but asserts blow ass, so it's a throw instead
			        throw __func__;

		        return state::boot_order_flat{flat, size};
	        },
	    },
	    std::move(state.order));
	if(cfg.verbose)
		fmt::print("Bootorder post: {}\n", state.order);

	return {};
}
