#include "EXTERN.h"
#include "perl.h"
#define NO_XSLOCKS
#include "XSUB.h"

#include <fcntl.h>
#include <errno.h>
#include <stdbool.h>
#include <stdlib.h>
#include <unistd.h>
#include <stdint.h>
#include <sys/mman.h>
#include <sys/stat.h>

/*
 * One set of fixed-length keys stored back to back in an mmap()ed region.
 * A copy of this struct is kept as the payload of the magic attached to
 * each File::Hashset object (see attach_magic() / free_magic()).
 */
typedef struct hashlookup_mapping {
	void *buf;		/* start of the mapping; MAP_FAILED when nothing is mapped */
	char *filename;		/* strdup()ed path set by load(), released with free(); otherwise NULL */
	size_t size;		/* bytes of key data in use (dedup() may shrink it below mapsize) */
	size_t mapsize;		/* length of the mapping, as passed to munmap() */
	size_t hashlen;		/* key length in bytes; 0 when no length is configured */
} hashlookup_mapping_t;
/* Template for "no mapping": every field zero except buf, which is MAP_FAILED. */
static const hashlookup_mapping_t hashlookup_mapping_0 = {.buf = MAP_FAILED};

/*
 * Read cursor over one input of a merge. merge_do() fills one of these per
 * source mapping and keeps pointers to the non-empty ones in its queue.
 */
typedef struct hashmerge_source {
	hashlookup_mapping_t *hm;	/* the input mapping this cursor reads from */
	char *buf;			/* copy of hm->buf */
	size_t off;			/* byte offset of the next unread key */
	size_t end;			/* copy of hm->size; the source is exhausted when off == end */
} hashmerge_source_t;
/* All-zero template used to initialise freshly allocated cursors. */
static const hashmerge_source_t hashmerge_source_0;

/*
 * Working state of one merge run. Everything that merge_cleanup() has to
 * release lives here so it can be freed even when merge_do() croaks.
 */
typedef struct hashmerge_state {
	off_t written;			/* running total of bytes passed to safewrite(); final file length */
	hashmerge_source_t *sources;	/* malloc()ed array of numsources cursors */
	hashmerge_source_t **queue;	/* malloc()ed binary heap of pointers to non-empty cursors */
	char *buf;			/* MERGEBUFSIZE-byte output buffer from mmap(); MAP_FAILED if unset */
	const char *filename;		/* destination path, used in error messages (not owned) */
	size_t fill;			/* bytes currently buffered in buf */
	size_t hashlen;			/* key length in bytes */
	size_t numsources;		/* number of entries in sources */
	size_t queuelen;		/* number of live entries in queue */
	int fd;				/* destination file descriptor; -1 when not open */
} hashmerge_state_t;
/* Template for a state that owns nothing yet. */
static const hashmerge_state_t hashmerge_state_0 = {.fd = -1, .buf = MAP_FAILED};

/* Size in bytes of the merge output buffer (2 MiB); also the flush threshold. */
#define MERGEBUFSIZE (1 << 21)

/*
 * qsort()-compatible comparators, one per supported key length. Each one
 * compares exactly N bytes with memcmp(); the fixed length lets the
 * compiler specialise the comparison.
 */
#define HASHLOOKUP_DEFINE_CMP(N) \
	static int hashcmp##N(const void *a, const void *b) { \
		return memcmp(a, b, (N)); \
	}

HASHLOOKUP_DEFINE_CMP(8)
HASHLOOKUP_DEFINE_CMP(16)
HASHLOOKUP_DEFINE_CMP(32)
HASHLOOKUP_DEFINE_CMP(64)
HASHLOOKUP_DEFINE_CMP(128)
HASHLOOKUP_DEFINE_CMP(256)
HASHLOOKUP_DEFINE_CMP(512)
HASHLOOKUP_DEFINE_CMP(1024)
HASHLOOKUP_DEFINE_CMP(2048)
HASHLOOKUP_DEFINE_CMP(4096)

#undef HASHLOOKUP_DEFINE_CMP

/*
 * Return the comparator matching a key length of `hashlen` bytes.
 *
 * Supported lengths are the powers of two from 8 to 4096. Any other value
 * croaks with a message saying whether it is too small, not a power of two,
 * or too large.
 */
static int (*hashcmp(int hashlen))(const void *a, const void *b) {
	/* Fast path: every supported length has its own comparator. */
	switch(hashlen) {
		case 8: return hashcmp8;
		case 16: return hashcmp16;
		case 32: return hashcmp32;
		case 64: return hashcmp64;
		case 128: return hashcmp128;
		case 256: return hashcmp256;
		case 512: return hashcmp512;
		case 1024: return hashcmp1024;
		case 2048: return hashcmp2048;
		case 4096: return hashcmp4096;
		default: break;
	}

	/* Unsupported: pick the most specific diagnostic. */
	if(hashlen < 8)
		croak("hashlen (%d) is too small", hashlen);
	if(hashlen & (hashlen -1))
		croak("hashlen (%d) is not a power of 2", hashlen);
	croak("hashlen (%d) is too large", hashlen);
}

/*
 * Interpret the 8 bytes at `bytes` as a big-endian (most significant byte
 * first) unsigned 64-bit integer, so that numeric order of the result
 * matches memcmp() order of the leading 8 bytes.
 *
 * `bytes` need not be aligned: the value is loaded with memcpy() rather
 * than through a cast pointer, which avoids misaligned access and strict
 * aliasing violations. The byte-order macros are only trusted when they are
 * actually defined; otherwise the portable shift-based form is used.
 */
static uint64_t hashlookup_msb64(const uint8_t *bytes) {
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
	uint64_t r;
	memcpy(&r, bytes, sizeof r);
	return r;
#elif defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
	uint64_t r;
	memcpy(&r, bytes, sizeof r);
#ifdef __GNUC__
	return __builtin_bswap64(r);
#else
	/* Manual byte swap: swap bytes, then 16-bit pairs, then 32-bit halves. */
	r = ((r & UINT64_C(0x00FF00FF00FF00FF)) << 8) | ((r & UINT64_C(0xFF00FF00FF00FF00)) >> 8);
	r = ((r & UINT64_C(0x0000FFFF0000FFFF)) << 16) | ((r & UINT64_C(0xFFFF0000FFFF0000)) >> 16);
	return (r << 32) | (r >> 32);
#endif
#else
	/* Unknown byte order: assemble the value one byte at a time. */
	return ((uint64_t)bytes[0] << 56)
		| ((uint64_t)bytes[1] << 48)
		| ((uint64_t)bytes[2] << 40)
		| ((uint64_t)bytes[3] << 32)
		| ((uint64_t)bytes[4] << 24)
		| ((uint64_t)bytes[5] << 16)
		| ((uint64_t)bytes[6] << 8)
		| ((uint64_t)bytes[7]);
#endif
}

/*
 * Remove adjacent duplicate keys from a mapping in place.
 *
 * Every record that differs from the record stored directly before it is
 * kept and compacted towards the start of the buffer; hm->size is then
 * lowered to the compacted length. The input is expected to be sorted so
 * that equal keys are adjacent. Croaks if hm->hashlen is 0.
 */
static void dedup(hashlookup_mapping_t *hm) {
	const size_t width = hm->hashlen;
	uint8_t *const base = hm->buf;
	uint8_t *out;
	const uint8_t *in;
	const uint8_t *limit;

	if(!width)
		croak("internal error: hashlen==0 in dedup()");

	if(!hm->size)
		return;

	/* The first record is always kept, so scanning starts at the second. */
	out = base + width;
	limit = base + hm->size;
	for(in = base + width; in < limit; in += width) {
		/* Compare against the record immediately preceding this one. */
		if(memcmp(in - width, in, width) == 0)
			continue;
		if(in != out)
			memcpy(out, in, width);
		out += width;
	}

	hm->size = out - base;
}

/*
 * Interpolation step of the lookup: estimate the index of `target` inside
 * the index range [lower, upper), given the 64-bit key prefixes known to
 * bracket that range (lower_hash <= target <= upper_hash).
 *
 * The result is only a guess; the caller clamps it into [lower, upper).
 *
 * Degenerate inputs are handled without dividing by zero: an empty range
 * yields `lower`, and identical bracketing prefixes (possible when keys
 * longer than 8 bytes share their first 8 bytes) fall back to the midpoint,
 * i.e. plain bisection.
 */
static uint64_t hashlookup_guess(uint64_t lower, uint64_t upper, uint64_t lower_hash, uint64_t upper_hash, uint64_t target) {
#ifdef __SIZEOF_INT128__
	unsigned __int128 res;
	uint64_t num, diff, off;
	num = upper - lower;
	diff = upper_hash - lower_hash;
	off = target - lower_hash;
	if(!diff)
		return lower + num / 2;
	/* off * num can exceed 64 bits, hence the 128-bit intermediate. */
	res = off;
	res *= num;
	res /= diff;
	return (uint64_t)res + lower;
#else
	uint64_t num, prec, div, half;
	num = upper - lower;
	if(!num)
		return lower;
	/* Scale both hash distances down by `div` so the product fits in 64 bits. */
	prec = UINT64_MAX / num;
	div = UINT64_MAX / prec + 1;
	half = num / prec / 2;
	return lower + ((target - lower_hash) / div) * num / ((upper_hash - lower_hash) / div + 1) + half;
#endif
}

/*
 * Report whether `key` (of `len` bytes) is present in the mapping.
 *
 * The mapping is treated as a sorted array of len-byte records and searched
 * by interpolation on the leading 64 bits of each key: the window
 * [lo, hi) is narrowed around a position estimated by hashlookup_guess().
 *
 * Croaks when the key length disagrees with the configured length, is below
 * 8 bytes, or does not divide the data size. Returns false for an empty set.
 */
static bool hashlookup_find(const hashlookup_mapping_t *hm, void *key, size_t len) {
	const uint8_t *base, *entry;
	uint64_t lo, hi, probe, lo_hash, hi_hash, wanted;
	int order;

	if(hm->hashlen && len != hm->hashlen)
		croak("File::Hashset::exists: key does not have the configured length (%ld != %ld) ", (long int)len, (long int)hm->hashlen);
	if(len < 8)
		croak("File::Hashset::exists: key too small (%ld < 8) ", (long int)len);
	if(hm->size % len)
		croak("File::Hashset::exists: file size (%ld) is not a multiple of key length (%ld)", (long int)hm->size, (long int)len);
	if(!hm->size)
		return false;

	base = hm->buf;
	lo = 0;
	hi = hm->size / len;
	lo_hash = 0;
	hi_hash = UINT64_MAX;
	wanted = hashlookup_msb64(key);

	do {
		/* Estimate a position, then force it into the current window. */
		probe = hashlookup_guess(lo, hi, lo_hash, hi_hash, wanted);
		if(probe < lo)
			probe = lo;
		else if(probe >= hi)
			probe = hi - 1;

		entry = base + probe * len;
		order = memcmp(entry, key, len);
		if(order == 0)
			return true;

		if(order < 0) {
			/* Entry sorts before the key: continue to the right of it. */
			lo = probe + 1;
			lo_hash = hashlookup_msb64(entry);
		} else {
			/* Entry sorts after the key: continue to the left of it. */
			hi = probe;
			hi_hash = hashlookup_msb64(entry);
		}
	} while(lo != hi);

	return false;
}

/*
 * Restore the min-heap property below index `i` of state->queue.
 *
 * The entry at `i` is repeatedly swapped with its smaller child (children
 * of index n live at 2n+1 and 2n+2) for as long as that child's current key
 * sorts strictly before the entry's own key. Keys are compared with
 * memcmp() over state->hashlen bytes at each source's current offset.
 */
static void queue_update_up(hashmerge_state_t *state, size_t i) {
	hashmerge_source_t **heap = state->queue;
	const size_t count = state->queuelen;
	const size_t keylen = state->hashlen;
	hashmerge_source_t *moving = heap[i];
	const char *moving_key = moving->buf + moving->off;
	size_t child;

	for(child = i * 2 + 1; child < count; child = i * 2 + 1) {
		hashmerge_source_t *best = heap[child];
		const char *best_key = best->buf + best->off;

		/* Prefer the right child when it exists and sorts first. */
		if(child + 1 < count) {
			hashmerge_source_t *sibling = heap[child + 1];
			const char *sibling_key = sibling->buf + sibling->off;
			if(memcmp(sibling_key, best_key, keylen) < 0) {
				child++;
				best = sibling;
				best_key = sibling_key;
			}
		}

		/* Stop as soon as neither child sorts before the moving entry. */
		if(memcmp(best_key, moving_key, keylen) >= 0)
			break;

		heap[i] = best;
		heap[child] = moving;
		i = child;
	}
}

/*
 * Turn the first state->queuelen entries of state->queue into a min-heap.
 *
 * Works from the middle of the array back to index 0, sifting each entry
 * into place with queue_update_up(). An empty queue is left untouched.
 */
static void queue_init(hashmerge_state_t *state) {
	size_t i;

	if(!state->queuelen)
		return;

	i = state->queuelen / 2;
	do queue_update_up(state, i);
		while(i--);
}

/*
 * Re-establish the heap property for the entry at index `i` after its key
 * changed in either direction.
 *
 * First the entry is moved towards the root for as long as it sorts before
 * its parent. Afterwards queue_update_up() is called to push it towards the
 * leaves, except when the entry was swapped all the way to index 0.
 *
 * Marked PERL_UNUSED_DECL: nothing in the visible part of this file calls it.
 */
PERL_UNUSED_DECL
static void queue_update(hashmerge_state_t *state, size_t i) {
	size_t i1;
	hashmerge_source_t *s, *s1;
	const char *a, *a1;
	hashmerge_source_t **queue = state->queue;
	size_t queuelen = state->queuelen;
	size_t hashlen = state->hashlen;

	s = queue[i];
	a = s->buf + s->off;
	/* Sentinel: stays at queuelen only if the loop below never runs (i == 0). */
	i1 = queuelen;

	/* move towards the root while the entry sorts before its parent */
	while(i) {
		i1 = (i - 1) / 2;
		s1 = queue[i1];
		a1 = s1->buf + s1->off;
		if(memcmp(a, a1, hashlen) < 0) {
			queue[i] = s1;
			queue[i1] = s;
			i = i1;
		} else {
			break;
		}
	}

	/*
	 * i == i1 only when the last swap placed the entry at index 0. In every
	 * other case (loop skipped, or stopped at a parent that sorts first)
	 * the entry may still need to move towards the leaves.
	 */
	if(i != i1)
		queue_update_up(state, i);
}

/*
 * Flush the buffered output (state->fill bytes of state->buf) to state->fd.
 *
 * Short writes are continued until the buffer is empty and writes
 * interrupted by a signal (EINTR) are retried. state->written is advanced
 * by the number of bytes flushed and state->fill is 0 on return.
 *
 * Croaks if write() fails for any other reason or reports 0 bytes written.
 */
static void safewrite(hashmerge_state_t *state) {
	const char *buf = state->buf;
	ssize_t r;

	state->written += state->fill;
	while(state->fill) {
		r = write(state->fd, buf, state->fill);
		if(r < 0) {
			/* A signal arrived before any data was written: try again. */
			if(errno == EINTR)
				continue;
			croak("write(%s): %s\n", state->filename, strerror(errno));
		}
		if(r == 0)
			croak("write(%s): Returned 0\n", state->filename);
		buf += (size_t)r;
		state->fill -= (size_t)r;
	}
}

/* Alignment guaranteed by mortal_malloc(): twice the size of a size_t. */
#define MORTAL_ALLOC_ALIGNMENT (sizeof(size_t) * 2)
#define MORTAL_ALLOC_ALIGNMENT_MASK (MORTAL_ALLOC_ALIGNMENT - 1)
/*
 * Allocate at least `len` bytes inside the string buffer of a mortal SV and
 * return a pointer aligned to MORTAL_ALLOC_ALIGNMENT.
 *
 * The caller never frees the result; the buffer belongs to the SV that was
 * handed to sv_2mortal() (presumably released by perl together with that
 * mortal, including after a croak -- confirm against perlguts).
 *
 * Croaks if the buffer is still misaligned after the adjustment.
 */
static void *mortal_malloc(size_t len) {
	SV *sv;
	char *buf;
	size_t pad;
	/* Over-allocate so that skipping up to MASK bytes still leaves `len`. */
	sv = newSV(len + MORTAL_ALLOC_ALIGNMENT_MASK);
	sv_2mortal(sv);
	SvPOK_on(sv);
	buf = SvPV_nolen(sv);
	/* Number of bytes to skip to reach the next aligned address (0 if aligned). */
	pad = -(size_t)buf & MORTAL_ALLOC_ALIGNMENT_MASK;
	if(pad) {
		/* Advance the SV's string start by `pad` bytes, then re-check. */
		sv_chop(sv, buf + pad);
		buf = SvPV_nolen(sv);
		pad = -(size_t)buf & MORTAL_ALLOC_ALIGNMENT_MASK;
		if(pad)
			croak("internal error: unable to align an allocation\n");
	}
	return buf;
}

/*
 * Merge `numsources` sorted hash sets into the file `destination`.
 *
 * Every resource acquired here (output buffer, file descriptor, queue and
 * source arrays) is recorded in `state` so that merge_cleanup() can release
 * it, including after a croak.
 *
 * state:       merge state; expected to start out as hashmerge_state_0
 * destination: path of the output file (created if missing, not truncated
 *              up front; it is cut to the merged length at the end)
 * hashlen:     key length in bytes; must be non-zero and divide MERGEBUFSIZE
 * sources:     the input mappings, each sorted and a multiple of hashlen
 * numsources:  number of entries in `sources` (may be 0)
 *
 * The inputs are combined with a min-heap of per-source cursors; a key equal
 * to the one just written is skipped, so the output holds each key once.
 * Croaks on any validation or system call failure.
 */
static void merge_do(hashmerge_state_t *state, const char *destination, size_t hashlen, hashlookup_mapping_t **sources, size_t numsources) {
	size_t i;
	hashlookup_mapping_t *hm;
	hashmerge_source_t *src = NULL;
	char *last;
	int fd;

	state->hashlen = hashlen;

	/* hashlen == 0 must be rejected before it is used as a divisor. */
	if(!hashlen || MERGEBUFSIZE % hashlen)
		croak("unsupported hash length (%d)\n", (int)hashlen);

	/* Try a huge-page backed buffer first where available, then a regular one. */
#ifdef MAP_HUGETLB
	state->buf = mmap(NULL, MERGEBUFSIZE, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS|MAP_HUGETLB, -1, 0);
	if(state->buf == MAP_FAILED)
#endif
	state->buf = mmap(NULL, MERGEBUFSIZE, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0);
	if(state->buf == MAP_FAILED)
		croak("mmap(): %s\n", strerror(errno));

	fd = state->fd = open(destination, O_WRONLY|O_CREAT|O_NOCTTY|O_LARGEFILE, 0666);
	if(fd == -1)
		croak("open(%s): %s\n", destination, strerror(errno));
	state->filename = destination;

	/* malloc(0) may legitimately return NULL, so only treat NULL as failure when memory was requested. */
	state->queue = malloc(numsources * sizeof *state->queue);
	if(numsources && !state->queue)
		croak("malloc(): %s\n", strerror(errno));

	state->sources = malloc(numsources * sizeof *state->sources);
	if(numsources && !state->sources)
		croak("malloc(): %s\n", strerror(errno));

	for(i = 0; i < numsources; i++)
		state->sources[i] = hashmerge_source_0;
	state->numsources = numsources;

	/* Set up one cursor per input and queue the non-empty ones. */
	for(i = 0; i < numsources; i++) {
		hm = sources[i];
		src = state->sources + i;
		src->hm = hm;
		src->buf = hm->buf;
		src->end = hm->size;
		if(hm->hashlen && hm->hashlen != hashlen)
			croak("File::Hashset::merge: string input object has a different hash length\n");
		if(src->end % hashlen)
			croak("File::Hashset::merge: input '%s' is not a multiple of the hash length\n", hm->filename);
		if(src->end)
			state->queue[state->queuelen++] = src;
	}

	if(state->queuelen) {
		queue_init(state);
		src = state->queue[0];
	}

	while(state->queuelen) {
		/* Emit the smallest pending key (the one at the head of the heap). */
		last = state->buf + state->fill;
		memcpy(last, src->buf + src->off, state->hashlen);
		state->fill += hashlen;
		src->off += hashlen;
		if(src->off == src->end) {
			/* Source exhausted: replace the head with the last heap entry. */
			if(!--state->queuelen)
				break;
			state->queue[0] = state->queue[state->queuelen];
		}
		/* skip duplicate hashes */
		for(;;) {
			queue_update_up(state, 0);
			src = state->queue[0];
			if(memcmp(last, src->buf + src->off, state->hashlen))
				break;
			src->off += hashlen;
			if(src->off == src->end) {
				if(!--state->queuelen)
					break;
				state->queue[0] = state->queue[state->queuelen];
			}
		}
		if(state->fill == MERGEBUFSIZE)
			safewrite(state);
	}

	if(state->fill)
		safewrite(state);

	/* The file was opened without O_TRUNC; drop anything beyond the merged data. */
	if(ftruncate(fd, state->written) == -1)
		croak("truncate(%s): %s\n", state->filename, strerror(errno));

	if(fdatasync(fd) == -1)
		croak("fsync(%s): %s\n", state->filename, strerror(errno));

	/* Mark the descriptor as released first so cleanup never closes it twice. */
	state->fd = -1;
	if(close(fd) == -1)
		croak("close(%s): %s\n", state->filename, strerror(errno));
}

/*
 * Release everything a merge run may still own: the cursor array, the
 * queue, the destination descriptor (if still open) and the output buffer
 * (if it was mapped). Safe to call on a state that acquired only part of
 * these, since unset members hold NULL, -1 or MAP_FAILED.
 */
static void merge_cleanup(hashmerge_state_t *state) {
	const int outfd = state->fd;
	void *const outbuf = state->buf;

	free(state->sources);
	free(state->queue);

	if(outfd != -1) {
		close(outfd);
	}

	if(outbuf != MAP_FAILED) {
		munmap(outbuf, MERGEBUFSIZE);
	}
}

/*
 * Run merge_do() and release the merge state afterwards in every case.
 *
 * merge_do() is executed inside perl's XCPT_TRY block so that a croak()
 * raised there comes back here instead of unwinding past this function
 * (per the XCPT_* macros documented in perlguts -- confirm for the perl
 * versions being targeted). merge_cleanup() then runs unconditionally, and
 * a caught exception is re-raised for the caller.
 */
static void merge_wrap(hashmerge_state_t *state, const char *destination, size_t hashlen, hashlookup_mapping_t **sources, size_t numsources) {
	dXCPT;

	XCPT_TRY_START {
		merge_do(state, destination, hashlen, sources, numsources);
	} XCPT_TRY_END

	/* Deliberately placed between the try and catch parts: runs on success and on failure. */
	merge_cleanup(state);

	XCPT_CATCH {
		/* Resources are already released; pass the original error on. */
		XCPT_RETHROW;
	}
}

/*
 * Fetch the payload that attach_magic() stored on the object `sv` refers to.
 *
 * `sv` must be a reference whose referent carries PERL_MAGIC_ext magic with
 * the given vtable. Returns a pointer to the string buffer of that magic's
 * object SV, or NULL when `sv` is NULL, is not a reference, or has no
 * matching magic.
 */
static void *find_magic(SV *sv, MGVTBL *vtable) {
	SV *referent;
	MAGIC *found;

	if(!sv || !SvROK(sv))
		return NULL;

	referent = SvRV(sv);
	if(!referent || !SvMAGICAL(referent))
		return NULL;

	found = mg_findext(referent, PERL_MAGIC_ext, vtable);
	if(found)
		return SvPV_nolen(found->mg_obj);

	return NULL;
}

/*
 * Copy `len` bytes from `data` into a new SV and attach that SV to `sv` as
 * PERL_MAGIC_ext magic using `vtable`; `name` is passed through to
 * sv_magicext() with a length of 0.
 *
 * Returns a pointer to the copy held by the magic, i.e. the same buffer
 * find_magic() hands out later.
 */
static void *attach_magic(SV *sv, MGVTBL *vtable, const char *name, void *data, STRLEN len) {
	SV *payload;

	payload = newSVpvn(data, len);
	sv_magicext(sv, payload, PERL_MAGIC_ext, vtable, name, 0);

	return SvPV_nolen(payload);
}

/*
 * svt_free callback of hashlookup_vtable: tear down the mapping stored in
 * the magic's object SV.
 *
 * Unmaps the buffer (if one is mapped), frees the stored filename, resets
 * the struct to hashlookup_mapping_0 and drops one reference to the object
 * SV. Always returns 0.
 */
static int free_magic(pTHX_ SV *sv, MAGIC *mg) {
	hashlookup_mapping_t *mapping = (void *)SvPV_nolen(mg->mg_obj);

	PERL_UNUSED_ARG(sv);

	if(mapping) {
		if(mapping->buf != MAP_FAILED)
			munmap(mapping->buf, mapping->mapsize);
		free(mapping->filename);
		/* Leave the payload in its "nothing mapped" state. */
		*mapping = hashlookup_mapping_0;
	}

	SvREFCNT_dec(mg->mg_obj);
	return 0;
}

/*
 * Magic vtable identifying File::Hashset payloads. Its address is the key
 * used by find_magic() to locate the magic; the only hook installed is the
 * free callback, which releases the mapping.
 */
STATIC MGVTBL hashlookup_vtable = {
	.svt_free = free_magic
};

MODULE = File::Hashset  PACKAGE = File::Hashset

PROTOTYPES: ENABLE

void
sortfile(const char *class, const char *filename, int hashlen)
PREINIT:
	int (*cmp)(const void *, const void *);
	hashlookup_mapping_t hm = hashlookup_mapping_0;
	int fd;
	int err;
	struct stat st;
PPCODE:
	/*
	 * Sort the hashlen-byte records of `filename` in place and remove
	 * duplicates, shrinking the file if any were dropped. Returns nothing;
	 * croaks on an unsupported hashlen or on any system call failure.
	 */
	PERL_UNUSED_ARG(class);
	cmp = hashcmp(hashlen);
	fd = open(filename, O_RDWR|O_NOCTTY|O_LARGEFILE);
	if(fd == -1)
		croak("open(%s): %s\n", filename, strerror(errno));
	if(fstat(fd, &st) == -1) {
		/* Save errno first: close() may overwrite it. */
		err = errno;
		close(fd);
		croak("stat(%s): %s\n", filename, strerror(err));
	}
	if(st.st_size % hashlen) {
		close(fd);
		croak("File::Hashset::sortfile: string size (%ld) is not a multiple of the key length (%d)", (long int)st.st_size, hashlen);
	}
	if(st.st_size <= (off_t)hashlen) {
		/* Zero or one record: already sorted and unique. */
		close(fd);
		/* A bare return here would leave the arguments on the stack as return values. */
		XSRETURN_EMPTY;
	}
	hm.buf = mmap(NULL, st.st_size, PROT_READ|PROT_WRITE, MAP_SHARED, fd, 0);
	if(hm.buf == MAP_FAILED) {
		err = errno;
		close(fd);
		croak("mmap(%s): %s\n", filename, strerror(err));
	}

	hm.size = hm.mapsize = st.st_size;
	hm.hashlen = hashlen;
	qsort(hm.buf, hm.size / hashlen, hashlen, cmp);
	dedup(&hm);

	/* Push the sorted data to the file before the mapping goes away. */
	if(msync(hm.buf, hm.mapsize, MS_SYNC) == -1) {
		err = errno;
		munmap(hm.buf, hm.mapsize);
		close(fd);
		croak("msync(%s, MS_SYNC): %s\n", filename, strerror(err));
	}

	if(munmap(hm.buf, hm.mapsize) == -1) {
		err = errno;
		close(fd);
		croak("munmap(%s): %s\n", filename, strerror(err));
	}

	/* dedup() lowered hm.size if duplicates were removed; cut the file to match. */
	if(hm.size != hm.mapsize && ftruncate(fd, hm.size) == -1) {
		err = errno;
		close(fd);
		croak("ftruncate(%s, %ld): %s\n", filename, (long int)hm.size, strerror(err));
	}
	close(fd);
	XSRETURN_EMPTY;

void
merge(char *class, const char *destination, int hashlen, ...)
PREINIT:
	int i;
	hashlookup_mapping_t **sources;
	int numsources;
	hashmerge_state_t state = hashmerge_state_0;
PPCODE:
	/*
	 * Merge the File::Hashset objects passed after `hashlen` into the file
	 * `destination`. Returns nothing; croaks on invalid arguments or on any
	 * failure inside the merge.
	 */
	PERL_UNUSED_ARG(class);
	/* Reject non-positive lengths here: the merge uses hashlen as a divisor. */
	if(hashlen <= 0)
		croak("File::Hashset::merge: invalid hash length (%d)", hashlen);
	numsources = items - 3;
	/* Mortal allocation: needs no explicit free, even if a croak follows. */
	sources = mortal_malloc(numsources * sizeof *sources);
	for(i = 0; i < numsources; i++) {
		sources[i] = find_magic(ST(i + 3), &hashlookup_vtable);
		if(!sources[i])
			croak("invalid File::Hashset object");
	}

	merge_wrap(&state, destination, hashlen, sources, numsources);

	XSRETURN_EMPTY;

SV *
new(char *class, SV *string_sv, int hashlen)
PREINIT:
	HV *hash;
	hashlookup_mapping_t hm = hashlookup_mapping_0;
	const char *string;
	STRLEN len;
	int (*cmp)(const void *, const void *);
CODE:
	/*
	 * Build an in-memory hash set from a byte string made of hashlen-byte
	 * keys. The keys are copied into an anonymous mapping, sorted and
	 * de-duplicated. Croaks on an unsupported hashlen, on a UTF-8 flagged
	 * string, or when the string length is not a multiple of hashlen.
	 */
	cmp = hashcmp(hashlen);
	if(SvUTF8(string_sv))
		croak("attempt to use an UTF-8 string as a hash buffer");
	string = SvPV(string_sv, len);
	if(len % hashlen)
		croak("File::Hashset::new: string size (%ld) is not a multiple of the key length (%d)", (long int)len, hashlen);
	if(len) {
		hm.buf = mmap(NULL, (size_t)len, PROT_READ|PROT_WRITE, MAP_PRIVATE|MAP_ANONYMOUS, -1, 0);
		if(hm.buf == MAP_FAILED)
			croak("mmap(): %s\n", strerror(errno));
		memcpy(hm.buf, string, len);
		qsort(hm.buf, len / hashlen, hashlen, cmp);
		hm.size = hm.mapsize = len;
		/* dedup() refuses a zero key length, so record it before the call. */
		hm.hashlen = hashlen;
		dedup(&hm);
	}
	hash = newHV();
	attach_magic((SV *)hash, &hashlookup_vtable, "hashlookup", &hm, sizeof hm);
	RETVAL = sv_bless(newRV_noinc((SV *)hash), gv_stashpv(class, 0));
OUTPUT:
	RETVAL

SV *
load(char *class, const char *filename)
PREINIT:
	HV *hash;
	hashlookup_mapping_t hm = hashlookup_mapping_0;
	int fd;
	int err;
	struct stat st;
CODE:
	/*
	 * Map an existing hash file read-only and wrap it in a File::Hashset
	 * object. No key length is recorded for loaded files. Croaks when the
	 * file cannot be opened, examined or mapped, or when its size is not a
	 * multiple of 8.
	 */
	fd = open(filename, O_RDONLY|O_NOCTTY|O_LARGEFILE);
	if(fd == -1)
		croak("open(%s): %s\n", filename, strerror(errno));
	if(fstat(fd, &st) == -1) {
		/* Save errno first: close() may overwrite it. */
		err = errno;
		close(fd);
		croak("stat(%s): %s\n", filename, strerror(err));
	}
	if(st.st_size % 8) {
		close(fd);
		croak("File::Hashset::load: file size (%ld) is not a multiple of the minimum key length (8)", (long int)st.st_size);
	}
	if(st.st_size) {
		hm.buf = mmap(NULL, st.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
		err = errno;
		/* The descriptor is not needed once the mapping attempt is done. */
		close(fd);
		if(hm.buf == MAP_FAILED)
			croak("mmap(%s): %s\n", filename, strerror(err));
	} else {
		close(fd);
	}
	hm.size = hm.mapsize = st.st_size;

	if(st.st_size) {
		/* Advisory only; failures are deliberately ignored. */
		madvise(hm.buf, hm.mapsize, MADV_WILLNEED);
#ifdef MADV_UNMERGEABLE
		madvise(hm.buf, hm.mapsize, MADV_UNMERGEABLE);
#endif
	}
	hm.filename = strdup(filename);
	if(!hm.filename) {
		err = errno;
		if(hm.buf != MAP_FAILED)
			munmap(hm.buf, hm.mapsize);
		croak("strdup(%s): %s\n", filename, strerror(err));
	}

	hash = newHV();
	attach_magic((SV *)hash, &hashlookup_vtable, "hashlookup", &hm, sizeof hm);
	RETVAL = sv_bless(newRV_noinc((SV *)hash), gv_stashpv(class, 0));
OUTPUT:
	RETVAL

void
exists(SV *self, SV *key)
PREINIT:
	hashlookup_mapping_t *mapping;
	STRLEN keylen;
	char *keybuf;
PPCODE:
	/*
	 * Return a true value when `key` is a member of the set and a false
	 * value otherwise. Croaks when `self` is not a File::Hashset object or
	 * when the key length is unacceptable to hashlookup_find().
	 */
	mapping = find_magic(self, &hashlookup_vtable);
	if(!mapping)
		croak("Invalid File::Hashset object");

	keybuf = SvPV(key, keylen);
	if(!hashlookup_find(mapping, keybuf, keylen))
		XSRETURN_NO;
	XSRETURN_YES;
