// SPDX-License-Identifier: 0BSD


#pragma once

#include "vore-file"
#include <cstring>
#include <err.h>
#include <errno.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <utility>


#if __linux__
#define UNIX_OR_ANON "unix-socket|@anonymous-socket"
#else
#define UNIX_OR_ANON "unix-socket"
#endif


#if __linux__
using cmsgcred = ucred;
#define SCM_CREDS SCM_CREDENTIALS
#define cmcred_pid pid
#define cmcred_uid uid
#define cmcred_gid gid

#define SOL_CREDS_LEVEL SOL_SOCKET
#elif __NetBSD__
// unix(4) recommends SOCKCREDSIZE(ngroups), but we use none, so we're fine with the default (1)
using cmsgcred = sockcred;
#define cmcred_pid sc_pid
#define cmcred_uid sc_uid
#define cmcred_gid sc_gid

#define SOL_CREDS_LEVEL 0
#define SO_PASSCRED LOCAL_CREDS
#elif __OpenBSD__
using cmsgcred = sockpeercred;
#define cmcred_pid pid
#define cmcred_uid uid
#define cmcred_gid gid
#elif __APPLE__
struct cmsgcred {
	uid_t cmcred_uid;
	gid_t cmcred_gid;
	pid_t cmcred_pid;

	cmsgcred & operator=(const xucred & ucred) {
		this->cmcred_uid = ucred.cr_uid;
		this->cmcred_gid = ucred.cr_gid;
		return *this;
	}
	operator bool() { return true; }
};
#endif


namespace bag {
	namespace {
		template <class = void>
		std::pair<struct sockaddr_un, socklen_t> parse_unix(const char * unix) {
			struct sockaddr_un addr{.sun_family = AF_UNIX};

			auto pathlen = std::strlen(unix);
			if(pathlen > sizeof(addr.sun_path))
				return errno = ENAMETOOLONG, std::pair{addr, 0};

			std::memcpy(addr.sun_path, unix, pathlen);
#if __linux__
			if(addr.sun_path[0] == '@')  // linux abstract socket
				addr.sun_path[0] = '\0';
#endif

			return {addr, sizeof(addr.sun_family) + pathlen};
		}


		template <class = void>
		int consume_cmsg(msghdr & msg, cmsgcred * process_creds, vore::fd & recvdfd) {
			int err{};
			for(auto cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg))
				if(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
					auto intsize = cmsg->cmsg_len - sizeof(struct cmsghdr);

					auto cur = reinterpret_cast<int *>(CMSG_DATA(cmsg));
					for(auto ints = intsize / sizeof(int); ints; --ints, ++cur) {
						int fd;
						std::memcpy(&fd, cur, sizeof(fd));
						if(recvdfd != -1)
							err = EOVERFLOW;
						recvdfd = {fd};
					}
				} else if(cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDS && process_creds) {
					if(process_creds->cmcred_pid != -1) {
						warnx("cmsg: double SCM_CREDS");
						err = EOVERFLOW;
					}
					std::memcpy(process_creds, CMSG_DATA(cmsg), sizeof(*process_creds));
				}
			return err;
		}

		template <std::size_t N>
		void cmsg_maybe_send_fd(msghdr & msg, std::uint8_t (&cmsgbuf)[N], int sendfd) {
			if(sendfd != -1) {
				msg.msg_control    = cmsgbuf;
				msg.msg_controllen = sizeof(cmsgbuf);

				auto cmsg        = CMSG_FIRSTHDR(&msg);
				cmsg->cmsg_level = SOL_SOCKET;
				cmsg->cmsg_type  = SCM_RIGHTS;
				cmsg->cmsg_len   = CMSG_LEN(1 * sizeof(int));
				std::memcpy(CMSG_DATA(cmsg), &sendfd, sizeof(int));

				msg.msg_controllen = cmsg->cmsg_len;  // total size of all control blocks
			} else {
				msg.msg_control    = nullptr;
				msg.msg_controllen = 0;
			}
		}
	}
}
