// SPDX-License-Identifier: 0BSD


#pragma once

#include "bag-b64"
#include "bag-messages"
#include "bag-socket"
#include "linkbag-abi.h"
#include "vore-file"
#include "vore-visit"
#include <cwctype>
#include <iterator>
#include <sys/mman.h>


#if __linux__
#define DEFAULT_CLIENT "@linkbag"
#else
#define DEFAULT_CLIENT "/var/run/linkbag.sock"
#endif


namespace bag {
	namespace {
		vore::fd prepsock() {
			vore::fd sock = socket(AF_UNIX, SOCK_DGRAM | SOCK_CLOEXEC, 0);

			sa_family_t family = AF_UNIX;
			if(bind(sock, reinterpret_cast<struct sockaddr *>(&family), sizeof(family)) == -1)  // unix(7) linux autobind
				return {};

			return sock;
		}

		template <class C>
		bool is_all_print(std::basic_string_view<C> str) {
			std::mbstate_t ctx{};
			for(wchar_t c; !str.empty();)
				switch(auto r = std::mbrtowc(&c, reinterpret_cast<const char *>(str.data()), str.size(), &ctx)) {
					case static_cast<std::size_t>(-2):  // incomplete
					case static_cast<std::size_t>(-1):  // EILSEQ
					case 0:
						return false;
					default:
						if(!std::iswprint(c))
							return false;
						str.remove_prefix(r);
						break;
				}
			return true;
		}

		template <class C>
		void dump_name_blob(std::basic_string_view<C> name, bool escape_all, FILE * into) {
			std::uint8_t buf[255 * 4 / 3];
			if(escape_all || name[0] == '\\' || !is_all_print(name)) {
				auto end = buf;
				bag::b64(end, std::begin(name), std::end(name));
				name = {buf, end};
				std::fputc('\\', into);
			}
			std::fwrite(name.data(), 1, name.size(), into);
		}


		struct pascal_strings {
			using iterator_category = std::input_iterator_tag;
			using difference_type   = void;
			using value_type        = std::basic_string_view<std::uint8_t>;
			using pointer           = std::basic_string_view<std::uint8_t> *;
			using reference         = std::basic_string_view<std::uint8_t> &;

			std::basic_string_view<std::uint8_t> remaining;


			pascal_strings & operator++() noexcept {
				remaining.remove_prefix(this->head_len());
				return *this;
			}

			pascal_strings operator++(int) noexcept {
				const auto ret = *this;
				++(*this);
				return ret;
			}

			constexpr auto operator<=>(const pascal_strings & rhs) const noexcept = default;

			constexpr std::basic_string_view<std::uint8_t> operator*() const noexcept {
				if(auto len = this->head_len())
					return this->remaining.substr(1, len - 1);
				else
					return {};
			}

		private:
			constexpr std::size_t head_len() const noexcept {
				if(!this->remaining.empty())
					return std::min(1 + static_cast<std::size_t>(this->remaining[0]), this->remaining.size());
				else
					return 0;
			}
		};

		class ls {
			friend class ls ls(int sock, const struct sockaddr_un & addr, socklen_t addrlen);
			std::basic_string_view<std::uint8_t> map;


		public:
			ls()                       = default;
			ls(const ls &)             = delete;
			ls & operator=(const ls &) = delete;
			ls(ls && oth) : map(std::exchange(oth.map, {})) {}
			ls & operator=(ls && rhs) {
				std::swap(this->map, rhs.map);
				return *this;
			}

			operator bool() const { return this->map.data(); }
			~ls() {
				if(this->map.size())
					munmap(const_cast<std::uint8_t *>(this->map.data()), this->map.size());
			}

			std::basic_string_view<std::uint8_t> take() {
				auto ret  = this->map;
				this->map = {};
				return ret;
			}

			pascal_strings begin() const { return {this->map}; }
			pascal_strings end() const { return {}; }
		};


		template <class T, class F, class G>
		T transact(int sock, const struct sockaddr_un & addr, socklen_t addrlen, int sendfd, F && init, G && visit) {
			alignas(linkbag_req) alignas(linkbag_res) std::uint8_t buf[600];  // header + 255 * 2; arbitrary
			auto & req              = *reinterpret_cast<linkbag_req *>(buf);
			auto & res              = *reinterpret_cast<linkbag_res *>(buf);
			const std::uint64_t seq = reinterpret_cast<std::uintptr_t>(buf);


			vore::fd madesock;
			if(sock == -1)
				sock = madesock = prepsock();
			if(sock == -1)
				return {};

			req.seq  = seq;
			auto end = reinterpret_cast<std::uint8_t *>(init(buf, req));
			iovec buf_i{buf, static_cast<std::size_t>(end - buf)};
			msghdr msg{.msg_name = const_cast<struct sockaddr_un *>(&addr), .msg_namelen = addrlen, .msg_iov = &buf_i, .msg_iovlen = 1};

			// we only care about the int, but SO_PASSCRED (to trigger autobind) may also give us the creds
			std::uint8_t cmsgbuf[CMSG_SPACE(sizeof(cmsgcred)) + CMSG_SPACE(1 * sizeof(int))];
			bag::cmsg_maybe_send_fd(msg, cmsgbuf, sendfd);

			ssize_t rd, wr;
			while((wr = sendmsg(sock, &msg, MSG_NOSIGNAL)) == -1 && errno == EINTR)
				;
			if(wr == -1)
				return {};
			if(static_cast<std::size_t>(wr) != buf_i.iov_len)
				return errno = EIO, T{};


			msg.msg_name       = nullptr;
			msg.msg_namelen    = 0;
			buf_i.iov_len      = sizeof(buf);
			msg.msg_control    = cmsgbuf;
			msg.msg_controllen = sizeof(cmsgbuf);


			while((rd = recvmsg(sock, &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC)) == -1 && errno == EINTR)
				;
			if(rd == -1)
				return {};


			vore::fd recvdfd;
			if(auto err = bag::consume_cmsg(msg, nullptr, recvdfd))
				return errno = err, T{};


			if(msg.msg_flags & MSG_CTRUNC)
				return errno = EMFILE, T{};
			if(msg.msg_flags & MSG_TRUNC)
				return errno = EMSGSIZE, T{};


			auto parsed = parse_response({buf, static_cast<std::size_t>(rd)});
			if(!parsed)
				return {};
			if(responses::need_fd(res) != (recvdfd != -1))
				return errno = EBADSLT, T{};
			if(res.seq != seq)
				return errno = EBADR, T{};

			return visit(std::move(*parsed), std::move(recvdfd));
		}

		[[maybe_unused]] class ls ls(int sock, const struct sockaddr_un & addr, socklen_t addrlen) {
			return transact<class ls>(
			    sock, addr, addrlen, -1,
			    [](auto buf, auto & req) {
				    req.type = LINKBAG_REQ_LS;
				    return buf + sizeof(std::uint64_t) + sizeof(std::uint8_t);
			    },
			    [&](auto && parsed, auto && recvdfd) {
				    class ls ret;
				    std::visit(vore::overload{
				                   [&](bag::responses::error err) { errno = err ?: EBADE; },
				                   [&](bag::responses::got &) { errno = EBADRQC; },
				                   [&](bag::responses::names) {
					                   if(auto flags = fcntl(recvdfd, F_GET_SEALS);
					                      flags == -1 || !(flags & F_SEAL_GROW) || !(flags & F_SEAL_SHRINK) || !(flags & (F_SEAL_WRITE | F_SEAL_FUTURE_WRITE))) {
						                   if(flags != -1)
							                   errno = EBADFD;
						                   return;
					                   }

					                   auto size = lseek(recvdfd, 0, SEEK_END);
					                   if(size == -1)
						                   return;

					                   if(!size)
						                   errno = 0;
					                   else  //
						                   if(auto addr = mmap(nullptr, size, PROT_READ, MAP_PRIVATE, recvdfd, 0); addr != MAP_FAILED)
							                   ret.map = {reinterpret_cast<std::uint8_t *>(addr), static_cast<std::size_t>(size)};
				                   },
				               },
				               parsed);
				    return ret;
			    });
		}

#define VALIDATE_NAME(inval)      \
	static_assert(sizeof(C) == 1);  \
	if(name.empty())                \
		return errno = EINVAL, inval; \
	if(name.size() > 255)           \
	return errno = ENAMETOOLONG, inval
		template <class C>
		bool del(int sock, const struct sockaddr_un & addr, socklen_t addrlen, std::basic_string_view<C> name, std::optional<std::basic_string_view<C>> ifblob) {
			VALIDATE_NAME(false);
			if(ifblob && ifblob->size() > 255)
				return errno = E2BIG, false;

			return transact<bool>(
			    sock, addr, addrlen, -1,
			    [&](auto buf, auto & req) {
				    req.type = ifblob ? LINKBAG_REQ_DELIF : LINKBAG_REQ_DEL;
				    buf += sizeof(std::uint64_t) + sizeof(std::uint8_t);
				    if(ifblob) {
					    *buf++ = ifblob->size();
					    buf    = reinterpret_cast<std::uint8_t *>(mempcpy(buf, ifblob->data(), ifblob->size()));
				    }
				    return mempcpy(buf, name.data(), name.size());
			    },
			    [&](auto && parsed, auto &&) {
				    std::visit(vore::overload{[&](bag::responses::error err) { errno = err; }, [&](auto &) { errno = EBADRQC; }}, parsed);
				    return !errno;
			    });
		}

		template <class C, class F>
		std::optional<int> get(int sock, const struct sockaddr_un & addr, socklen_t addrlen, std::basic_string_view<C> name, F && blobisitor) {
			VALIDATE_NAME(std::nullopt);

			return transact<std::optional<int>>(
			    sock, addr, addrlen, -1,
			    [&](auto buf, auto & req) {
				    req.type = LINKBAG_REQ_GET;
				    return mempcpy(buf + sizeof(std::uint64_t) + sizeof(std::uint8_t), name.data(), name.size());
			    },
			    [&](auto && parsed, auto && recvdfd) {
				    return std::visit(vore::overload{[&](bag::responses::error err) -> std::optional<int> {
					                                     errno = err ?: EBADE;
					                                     return {};
				                                     },
				                                     [&](bag::responses::got & got) -> std::optional<int> {
					                                     blobisitor(got);
					                                     return recvdfd.take();
				                                     },
				                                     [&](bag::responses::names) -> std::optional<int> {
					                                     errno = EBADRQC;
					                                     return {};
				                                     }},
				                      parsed);
			    });
		}

		template <class C>
		bool put(int sock, const struct sockaddr_un & addr, socklen_t addrlen, std::basic_string_view<C> name, std::basic_string_view<C> blob, int fd, bool force) {
			VALIDATE_NAME(false);
			if(blob.size() > 255)
				return errno = E2BIG, false;

			return transact<bool>(
			    sock, addr, addrlen, fd,
			    [&](auto buf, auto & req) {
				    req.type = LINKBAG_REQ_PUT;
				    buf += sizeof(std::uint64_t) + sizeof(std::uint8_t);
				    *buf++ = (force ? LINKBAG_REQ_PUT_REPLACE : 0);
				    *buf++ = blob.size();
				    return mempcpy(mempcpy(buf, blob.data(), blob.size()), name.data(), name.size());
			    },
			    [&](auto && parsed, auto &&) {
				    std::visit(vore::overload{[&](bag::responses::error err) { errno = err; }, [&](auto &) { errno = EBADRQC; }}, parsed);
				    return !errno;
			    });
		}
	}
}
