// The MIT License (MIT)

// Copyright (c) 2020 наб <nabijaczleweli@nabijaczleweli.xyz>

// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


#include "config.hpp"
#include "context.hpp"
#include "context_detail.hpp"
#include "util.hpp"
#include <algorithm>
#include <cstring>
#include <fmt/format.h>
#include <iterator>
#include <numeric>
#include <openssl/sha.h>


using sha_t = std::uint8_t[20];


// Resolve the stored klapki state against the current boot entries.
//
// * Hashes every boot entry's load option with SHA-1, storing the digest in its load_option_sha.
// * Walks input_state's wanted entries in order:
//   - drops any whose SHA equals that of an earlier wanted entry;
//   - keeps it unchanged if the boot entry at its bootnum_hint has the same SHA;
//   - otherwise keeps it if any boot entry has that SHA, updating bootnum_hint to that entry's bootnum;
//   - otherwise drops it.
// * Converts a flat boot order into a structured one: maximal runs of adjacent bootnums, each flagged by whether
//   its bootnums belong to a kept wanted entry. An already-structured order is passed through as-is.
//
// Diagnostics go to stderr; the per-entry "matches" line goes to stdout and only if cfg.verbose.
// Only the state alternative is ever returned by this implementation; the std::string alternative is unused here.
std::variant<klapki::state::state, std::string> klapki::context::resolve_state_context(const config & cfg, const state::state & input_state) {
	std::map<std::uint16_t, state::boot_entry> entries{input_state.entries};
	// When measured (KVM, amd64 on amd64, Xeon E5645 @ 2.40GHz, 6 vCPUs), this took 23`636ns all told for 14 SHAs (+outlier at 207`817);
	// multithreading is much much more expensive: 2`633`277.86ns on same dataset when spawning a std::thread per.
	std::for_each(std::begin(entries), std::end(entries),
	              [](auto & bent) { SHA1(bent.second.load_option.get(), bent.second.load_option_len, bent.second.load_option_sha); });


	state::stated_config statecfg{input_state.statecfg.boot_position, input_state.statecfg.variants, {}};
	statecfg.wanted_entries.reserve(input_state.statecfg.wanted_entries.size());


	// (bootnum hint, SHA) of every wanted entry accepted past the uniquifying step so far.
	// The pointers refer into input_state's wanted entries. That vector is const here and is never moved or
	// reallocated, so the pointers stay valid. Pointing into a vector that is being grown or compacted would not be safe.
	std::vector<std::pair<std::uint16_t, const std::uint8_t *>> seen_shas;
	seen_shas.reserve(input_state.statecfg.wanted_entries.size());

	for(auto && went : input_state.statecfg.wanted_entries) {
		// Uniquify by SHA: the first occurrence wins, later ones are dropped
		if(auto dupe = std::find_if(std::begin(seen_shas), std::end(seen_shas),
		                            [&](auto && seen) { return !std::memcmp(went.load_option_sha, seen.second, sizeof(sha_t)); });
		   dupe != std::end(seen_shas)) {
			fmt::print(stderr, "Entry {:04X}: duplicate SHA {} with entry {:04X}. Dropping\n", went.bootnum_hint, detail::sha_f{went.load_option_sha},
			           dupe->first);
			continue;
		}
		seen_shas.emplace_back(went.bootnum_hint, went.load_option_sha);


		// Match wanted entry to boot entry: first at its recorded bootnum...
		if(auto bent = entries.find(went.bootnum_hint); bent != std::end(entries)) {
			if(!std::memcmp(bent->second.load_option_sha, went.load_option_sha, sizeof(sha_t))) {
				if(cfg.verbose)
					fmt::print("Entry {:04X} matches\n", went.bootnum_hint);
				statecfg.wanted_entries.emplace_back(went);
				continue;
			} else
				fmt::print(stderr, "Entry {:04X}: mismatched hash, searching elsewhere\n", went.bootnum_hint);
		} else
			fmt::print(stderr, "Entry {:04X} doesn't exist; moved?, searching elsewhere\n", went.bootnum_hint);

		// ...then anywhere, by SHA; on success the kept copy is pointed at its new bootnum
		if(auto bent = std::find_if(std::begin(entries), std::end(entries),
		                            [&](auto && candidate) { return !std::memcmp(candidate.second.load_option_sha, went.load_option_sha, sizeof(sha_t)); });
		   bent != std::end(entries)) {
			fmt::print(stderr, "Found entry formerly {:04X} at {:04X}.\n", went.bootnum_hint, bent->first);

			statecfg.wanted_entries.emplace_back(went);
			statecfg.wanted_entries.back().bootnum_hint = bent->first;
		} else
			fmt::print(stderr, "Entry formerly {:04X} ({}) not found. Abandoning.\n", went.bootnum_hint, detail::sha_f{went.load_option_sha});
	}


	// Structure boot order: split the flat order into maximal runs of bootnums that all are (or all aren't) ours
	auto boot_order = std::visit(overload{[&](const state::boot_order_flat & bof) {
		                                      state::boot_order_structured ret{};

		                                      bool ours = false;
		                                      std::vector<std::uint16_t> bootnums;
		                                      for(std::uint16_t * cur = bof.order.get(); cur < bof.order.get() + bof.order_cnt; ++cur) {
			                                      const bool our = std::any_of(std::begin(statecfg.wanted_entries), std::end(statecfg.wanted_entries),
			                                                                   [&](auto && went) { return went.bootnum_hint == *cur; });

			                                      if(our == ours)
				                                      bootnums.emplace_back(*cur);
			                                      else {
				                                      // Ownership flipped: start a new run with *cur and flush the previous one (if any)
				                                      std::vector<std::uint16_t> bns{*cur};
				                                      bns.swap(bootnums);
				                                      if(!bns.empty())
					                                      ret.order.emplace_back(std::move(bns), ours);
				                                      ours = our;
			                                      }
		                                      }

		                                      if(!bootnums.empty())
			                                      ret.order.emplace_back(std::move(bootnums), ours);

		                                      return ret;
	                                      },
	                                      [](const state::boot_order_structured & bos) { return bos; }},
	                             input_state.order);


	return state::state{std::move(boot_order), std::move(entries), std::move(statecfg)};
}
