/* SPDX-License-Identifier: GPL-2.0-only OR GPL-3.0-only */
/* Copyright (c) 2025 Brett A C Sheffield <bacs@librecast.net> */

#include "test.h"
#include "testnet.h"
#include <agent.h>
#include <pthread.h>
#include <semaphore.h>
#include <sys/wait.h>
#include <unistd.h>

/* channel name handed to both the "send" and "recv" agent command lines */
#define CHANNEL_NAME "Cui bono?"
/* payload passed directly as the last "send" argument */
#define CMDLINE_PAYLOAD "Alonso"
/* seconds to wait for the recv thread to signal completion */
#define WAITS 8
/* value passed to the sender's --bpslimit option */
#define SPEEDLIMIT "1048576"
/* size of the payload and capture buffers; parenthesized so the macro
 * expands safely inside larger expressions (e.g. PAYLOAD_SZ * 2) */
#define PAYLOAD_SZ (32768 + BUFSIZ)
/*
 * Redirect stdout to /dev/null with a caller-provided stdio buffer, and back
 * to the terminal unbuffered. Both expand to complete statements (no trailing
 * semicolon needed at the call site) and require a `FILE *f` in scope;
 * REDIRECT_BUF additionally requires a char array named `buf`.
 */
#define REDIRECT_BUF f = freopen("/dev/null", "a", stdout); assert(f); setbuffer(stdout, buf, sizeof buf);
#define REDIRECT_OUT f = freopen("/dev/tty", "a", stdout); assert(f); setbuf(stdout, NULL);

/* posted by thread_recv() once when it is ready and once when it is done;
 * waited on by test_payload() */
static sem_t sem_recv;
/* interface index returned by get_multicast_if() in main() */
static unsigned int ifx;
/* name of interface ifx, filled in by if_indextoname() in main() */
static char ifname[IFNAMSIZ];
/* payload the sender transmits; thread_recv() compares its capture against it */
static char payload[PAYLOAD_SZ];
/* number of bytes of payload[] in use for the current test */
static size_t paylen;

enum {
	PAYLOAD_CMD,
	PAYLOAD_STDIN,
#if 0
	PAYLOAD_FILE,
#endif
	PAYLOAD_TYPES,
};

/*
 * Receiver thread: loads an agent configuration for "recv" on CHANNEL_NAME,
 * runs the agent with stdout redirected into a capture buffer, then checks
 * that the first paylen bytes captured match payload[].
 *
 * sem_recv is posted once after the configuration is loaded (receiver ready)
 * and once more on the way out (receiver done).
 *
 * arg is not used and is returned unchanged.
 */
static void *thread_recv(void *arg)
{
	/*
	 * The capture buffer has static storage on purpose: while the agent runs,
	 * stdout uses it as its stdio buffer (see REDIRECT_BUF). test_payload()
	 * cancels this thread on timeout, and a buffer on this thread's stack
	 * would leave stdout pointing at freed memory after the join. Only one
	 * recv thread exists at a time, so sharing the buffer between runs is
	 * safe; it is cleared at the start of each run.
	 */
	static char buf[PAYLOAD_SZ];
	state_t state = {0};
	char *argv_00[] = { PACKAGE_NAME, "-v", "-i", ifname, "recv", CHANNEL_NAME, NULL };
	FILE *f;
	int argc = sizeof argv_00 / sizeof argv_00[0] - 1; /* exclude NULL terminator */
	int rc;

	memset(buf, 0, sizeof buf);
	rc = agent_load_config(&state, argc, argv_00, NULL);
	sem_post(&sem_recv); /* tell main thread we're ready to recv */
	if (!test_assert(rc == 0, "agent_load_config()")) goto skip_test;

	/* capture whatever the agent prints on stdout in buf */
	REDIRECT_BUF
	rc = agent_run(&state);
	free_state(&state);
	fflush(stdout);
	/* restore stdout before reporting results */
	REDIRECT_OUT
	test_assert(rc == EXIT_SUCCESS, "agent() returned %i", rc);

	/* check payload */
	test_assert(!memcmp(payload, buf, paylen), "recv buffer matches (%zu bytes)", paylen);
skip_test:
	sem_post(&sem_recv); /* tell main thread we're done */

	return arg;
}

/*
 * Prepare key material for the test: run the agent's "whoami" command, then
 * run "key add" with the public key hex found in the agent state afterwards
 * (a self-signed trust entry). Both steps are always attempted.
 *
 * Returns the global test_status.
 */
static int create_keys_and_tokens(void)
{
	state_t state = {0};
	int status;

	/* step 1: generate keys */
	char *whoami_args[] = { PACKAGE_NAME, "whoami", NULL };
	int whoami_argc = sizeof whoami_args / sizeof whoami_args[0] - 1;

	status = agent(&state, whoami_argc, whoami_args);
	test_assert(status == EXIT_SUCCESS, "agent() returned %i", status);

	/* step 2: add trusted key (self-signed); the key is read from state
	 * only after the first agent() call has returned */
	char *keyadd_args[] = { PACKAGE_NAME, "key", "add", state.defaults.keyring.phex, NULL };
	int keyadd_argc = sizeof keyadd_args / sizeof keyadd_args[0] - 1;

	status = agent(&state, keyadd_argc, keyadd_args);
	test_assert(status == EXIT_SUCCESS, "agent() returned %i", status);

	return test_status;
}

/*
 * Send a random payload by feeding it to the agent on stdin ("send ... -").
 *
 * A payload of 32768..32768+BUFSIZ-1 random bytes is generated into
 * payload[]/paylen. A forked child writes it into a pipe whose read end is
 * temporarily installed as this process's stdin while agent() runs.
 *
 * state: agent state passed through to agent().
 * Returns the global test_status.
 */
static int test_payload_stdin(state_t *state)
{
	char *argv[] = { PACKAGE_NAME, "send", "-v", "-i", ifname, "--loopback", "--bpslimit", SPEEDLIMIT, CHANNEL_NAME, "-", NULL };
	int argc = sizeof argv / sizeof argv[0] - 1;
	int rc;
	int status = 0;
	int pipefd[2];
	int saved_stdin;
	pid_t pid;

	paylen = 32768 + arc4random_uniform(BUFSIZ);
	test_log("generating random payload of %zu bytes\n", paylen);
	arc4random_buf(payload, paylen);
	if (!test_assert(pipe(pipefd) == 0, "pipe()")) return test_status;

	pid = fork();
	if (pid == 0) {
		/* child writes back to parent stdin through pipe */
		const char *ptr = payload;
		size_t left = paylen;

		close(pipefd[0]); /* close read end */
		while (left > 0) { /* write() may be short; keep going */
			ssize_t ret = write(pipefd[1], ptr, left);
			if (ret <= 0) break;
			ptr += ret;
			left -= (size_t)ret;
		}
		close(pipefd[1]);
		/* _exit(): this is a forked copy of a multithreaded process, so
		 * skip atexit handlers and stdio flushing inherited from the parent */
		_exit((left == 0) ? EXIT_SUCCESS : EXIT_FAILURE);
	}
	if (!test_assert(pid != -1, "fork()")) {
		close(pipefd[0]);
		close(pipefd[1]);
		return test_status;
	}
	close(pipefd[1]);   /* close write end */

	/* hook up stdin to read end of pipe, remembering the original stdin */
	saved_stdin = dup(STDIN_FILENO);
	if (!test_assert(dup2(pipefd[0], STDIN_FILENO) != -1, "dup2()")) {
		/* closing the last read end makes the child's write fail, so it
		 * terminates and can be reaped */
		close(pipefd[0]);
		if (saved_stdin != -1) close(saved_stdin);
		waitpid(pid, NULL, 0);
		return test_status;
	}

	rc = agent(state, argc, argv);
	test_assert(rc == EXIT_SUCCESS, "agent() returned %i", rc);

	/* drop our read ends before waiting, so a child still blocked in
	 * write() (agent stopped reading early) is released instead of
	 * deadlocking the waitpid() below */
	close(pipefd[0]);
	if (saved_stdin != -1) {
		dup2(saved_stdin, STDIN_FILENO);
		close(saved_stdin);
	}

	if (test_assert(waitpid(pid, &status, 0) == pid, "waitpid()")) {
		/* report the exit code when there is one, the raw status otherwise */
		rc = WIFEXITED(status) ? WEXITSTATUS(status) : status;
		test_assert(WIFEXITED(status) && rc == EXIT_SUCCESS, "child writer returned %i", rc);
	}
	return test_status;
}

/*
 * Send CMDLINE_PAYLOAD by passing it as the final "send" argument.
 *
 * Records the expected payload in payload[]/paylen for the receiver to
 * compare against, then runs the agent.
 *
 * state: agent state passed through to agent().
 * Returns the global test_status, like test_payload_stdin().
 */
static int test_payload_cmdline(state_t *state)
{
	char *argv[] = { PACKAGE_NAME, "send", "-v", "-i", ifname, "--loopback", "--bpslimit", SPEEDLIMIT, CHANNEL_NAME, CMDLINE_PAYLOAD, NULL };
	int argc = sizeof argv / sizeof argv[0] - 1;
	int rc;

	paylen = strlen(CMDLINE_PAYLOAD);
	/* plain copy (including the terminating NUL): the payload must never be
	 * interpreted as a printf format string */
	memcpy(payload, CMDLINE_PAYLOAD, sizeof CMDLINE_PAYLOAD);
	rc = agent(state, argc, argv);
	test_assert(rc == EXIT_SUCCESS, "agent() returned %i", rc);
	return test_status;
}

/*
 * Run one send/recv round trip for the given payload type.
 *
 * Creates a temporary directory and points HOME at it, generates keys,
 * starts the receiver thread, sends the payload using the method selected by
 * paytype, then waits up to WAITS seconds for the receiver to finish.
 *
 * paytype: one of the PAYLOAD_* enum values.
 * Returns the global test_status.
 */
static int test_payload(int paytype)
{
	char fakehome[] = "0000-0021-XXXXXX";
	state_t state = {0};
	struct timespec timeout;
	pthread_t tid;
	int rc;

	/* create fake home directory */
	if (!test_assert(mkdtemp(fakehome) != NULL, "mkdtemp()")) {
		perror("mkdtemp");
		return test_status;
	}
	setenv("HOME", fakehome, 1);

	/* generate keys */
	if (create_keys_and_tokens()) return test_status;

	/* start recv thread */
	rc = sem_init(&sem_recv, 0, 0);
	if (!test_assert(rc == 0, "sem_init()")) return test_status;
	rc = pthread_create(&tid, NULL, &thread_recv, NULL);
	if (!test_assert(rc == 0, "pthread_create() recv thread")) goto err_sem_recv_destroy;

	sem_wait(&sem_recv); /* wait until recv thread ready */

	switch (paytype) {
		case PAYLOAD_CMD:
			test_payload_cmdline(&state);
			break;
		case PAYLOAD_STDIN:
			test_payload_stdin(&state);
			break;
#if 0
		case PAYLOAD_FILE:
			test_payload_cmdline(&state);
			break;
#endif
		default:
			/* nothing was sent; record the failure explicitly */
			test_assert(0, "unknown payload type %i", paytype);
			break;
	}

	/* wait (bounded) for the recv thread to report completion */
	rc = clock_gettime(CLOCK_REALTIME, &timeout);
	if (test_assert(rc == 0, "clock_gettime()")) {
		timeout.tv_sec += WAITS;
		rc = sem_timedwait(&sem_recv, &timeout);
		test_assert(rc == 0, "timeout");
	}

	/* always stop and reap the recv thread before the semaphore it posts
	 * to is destroyed, including when the wait could not be set up */
	pthread_cancel(tid);
	pthread_join(tid, NULL);

err_sem_recv_destroy:
	sem_destroy(&sem_recv);

	return test_status;
}

/*
 * Test entry point: pick a multicast interface, then run test_payload() for
 * each enabled payload type, stopping at the first failure. No payload types
 * are run when RUNNING_ON_VALGRIND is true.
 *
 * Returns test_status; TEST_WARN when get_multicast_if() returns 0.
 */
int main(void)
{
	char name[] = PACKAGE_NAME " send + recv (large payloads)";

	test_name(name);
	test_require_net(TEST_NET_BASIC);

	/* resolve the interface index and its name for the agent's -i option */
	ifx = get_multicast_if();
	if (ifx == 0) {
		test_status = TEST_WARN;
		return test_status;
	}
	if (!test_assert(if_indextoname(ifx, ifname) != NULL, "if_indextoname()")) {
		return test_status;
	}

	const int ntypes = (RUNNING_ON_VALGRIND) ? 0 : PAYLOAD_TYPES;
	int paytype = 0;
	while (paytype < ntypes) {
		if (!test_assert(test_payload(paytype) == 0, "test with payload (%i)", paytype)) {
			break; /* first failure ends the run */
		}
		paytype++;
	}

	return test_status;
}
