#include "stdafx.h"
#include "ttd.h"
#include "command.h"

#if defined(WIN32)
#	include <windows.h>
#	include <winsock.h>

# pragma comment (lib, "ws2_32.lib")
# define ENABLE_NETWORK
#endif

#if defined(UNIX)
// Make compatible with WIN32 names
#	define ioctlsocket ioctl
#	define SOCKET int
#	define INVALID_SOCKET -1

// Need this for FIONREAD on solaris
#	define BSD_COMP
#	include <unistd.h>
#	include <sys/ioctl.h>

// Socket stuff
#	include <sys/socket.h>
#	include <netinet/in.h>
#	include <arpa/inet.h>

# ifndef TCP_NODELAY
#  define TCP_NODELAY 0x0001
# endif

#endif

#define SEND_MTU 1500

#if defined(ENABLE_NETWORK)

// Wire format
// -----------
// Every packet starts with two bytes: packet_length (total size of the packet
// in bytes, header included) and packet_type (0 = command, 1 = sync, 2 = ack;
// see ReadPackets()). SendPacket() copies these structs onto the socket byte
// for byte, so their in-memory layout -- including any compiler padding -- is
// the wire layout, and both ends must be built with the same struct layout.
// Multi-byte header fields are converted to little endian before sending.

// sent from client -> server whenever the client wants to exec a command.
// sent from server -> client when another player execs a command.
// Only the first packet_length bytes are transmitted: the sender drops trailing
// all-zero words of dp[] (see NetworkSendCommand()).
typedef struct CommandPacket {
	byte packet_length; // bytes actually sent: COMMAND_PACKET_BASE_SIZE + 4 * (number of dp words sent)
	byte packet_type;   // always 0
	uint16 cmd;         // low 16 bits of the command value
	uint32 p1,p2;       // command parameters handed to DoCommandP()
	TileIndex tile;     // tile handed to DoCommandP()
	byte player;        // index of the _command_queues entry the command is queued in
	uint32 dp[8];       // extra parameters copied from/to _decode_parameters
	                    // NOTE(review): dp[] words are not byte-swapped on TTD_BIG_ENDIAN builds
} CommandPacket;

// sizeof(CommandPacket) minus the 8-word dp[] array: the smallest valid packet_length.
#define COMMAND_PACKET_BASE_SIZE (sizeof(CommandPacket) - 8 * sizeof(uint32))

// sent from server -> client periodically to tell the client about the current tick in the server
// and how far the client may progress.
typedef struct SyncPacket {
	byte packet_length; // sizeof(SyncPacket)
	byte packet_type;   // always 1
	byte frames; // how many more frames may the client execute? this is relative to the old value of max.
	byte server; // where is the server currently executing? this is negatively relative to the old value of max.
	uint32 random_seed_1; // current random state at server (little endian). used to detect out of sync.
	uint32 random_seed_2; // server's _sync_seed_2 (little endian)
} SyncPacket;

// sent from server -> client as an acknowledgement that the server received the command.
// the command will be executed at the current value of "max".
// An ack carries no command id: the client pairs it with the oldest entry of
// its _ack_queue (see HandleAckPacket()).
typedef struct AckPacket {
	byte packet_length; // always 2
	byte packet_type;   // always 2
} AckPacket;

// One buffer of a client's outgoing byte stream. Buffers form a singly linked
// list (ClientState.head/last): SendPacket() appends bytes, SendPackets() drains them.
typedef struct Packet Packet;
struct Packet {
	Packet *next; // this one has to be the first element: SendPacket() casts ClientState.last (&tail->next) back to the tail Packet*.
	uint siz;     // number of bytes of buf[] in use
	byte buf[SEND_MTU]; // packet payload
};

// Per-connection state. On the server there is one entry per accepted client;
// on a client only _clients[0] is used and it refers to the server.
typedef struct ClientState {
	int socket; // 0 = unused slot; loops over _clients stop at the first zero socket
	
	uint buflen;											// receive buffer len: bytes of a partial packet kept between ReadPackets() calls
	byte buf[sizeof(CommandPacket)];	// receive buffer holding that partial packet
	
	uint eaten; // bytes at the front of head that SendPackets() has already sent
	Packet *head, **last; // outgoing buffer list; last points at the tail's next field, or at head when empty
} ClientState;


// NOTE(review): _sync_frame_count is not referenced anywhere in this file; looks unused.
static uint _sync_frame_count;
// server only: number of NetworkSend() calls since the last sync packet was queued.
static uint _not_packet;

// A command waiting to be executed (or, on a client, waiting for its ack).
// Nodes form a singly linked list owned by a CommandQueue.
typedef struct QueuedCommand QueuedCommand;
struct QueuedCommand {
	QueuedCommand *next;	
	CommandPacket cp;          // the command packet as sent/received
	CommandCallback *callback; // passed to DoCommandP(); NULL for commands received from the network
	uint32 cmd;                // command value: full 32 bits for local commands, the 16 bits of cp.cmd for received ones
	uint32 frame;              // frame in which NetworkProcessCommands() executes it
};

// FIFO of QueuedCommands. last points at the tail's next field, or at head when empty.
typedef struct CommandQueue CommandQueue;
struct CommandQueue {
	QueuedCommand *head, **last;
};

// Number of slots in the per-player command queues and in the client table.
#define MAX_PLAYERS 4

// packets waiting to be executed, for each of the players.
// this list is sorted in frame order, so the item on the front will be executed first.
// (entries are appended with frame = _frame_counter_max, which this file only ever increases.)
static CommandQueue _command_queues[MAX_PLAYERS]; 

// in the client, this is the list of commands that have not yet been acked.
// when it is acked, it will be moved to the appropriate position at the end of the player queue.
static CommandQueue _ack_queue;

// Connected peers, filled from index 0; a zero socket marks the end of the used entries.
static ClientState _clients[MAX_PLAYERS];

// keep a history of the 16 most recent seeds to be able to capture out of sync errors.
// Indexed by (frame & 15); [0] = _sync_seed_1, [1] = _sync_seed_2. Only filled on clients.
static uint32 _my_seed_list[16][2];

// Random seeds the server reported for a frame the client has not executed yet.
typedef struct FutureSeeds {
	uint32 frame;
	uint32 seed[2];
} FutureSeeds;

// remember some future seeds that the server sent to us.
// Kept in arrival order and consumed from the front by NetworkProcessCommands();
// reports arriving while all 8 slots are in use are dropped.
static FutureSeeds _future_seed[8];
static int _num_future_seed; // number of valid entries at the front of _future_seed

//////////////////////////////////////////////////////////////////////

// Allocate a zero-filled QueuedCommand and link it in as the new tail of nq.
// The caller fills in the command fields. Because the node is zeroed, its
// `next` pointer is NULL, so it correctly terminates the list.
static QueuedCommand *AllocQueuedCommand(CommandQueue *nq)
{
	QueuedCommand *node = (QueuedCommand*)calloc(1, sizeof(*node));

	assert(node != NULL);

	// *last is nq->head for an empty queue, otherwise the old tail's next field.
	*nq->last = node;
	nq->last = &node->next;

	return node;
}


// go through the player queues for each player and see if there are any pending commands
// that should be executed this frame. if there are, execute them.
void NetworkProcessCommands()
{
	int i;
	CommandQueue *nq;
	QueuedCommand *qp;

	for(i=0,nq = _command_queues; i!=MAX_PLAYERS; i++,nq++) {
		while ( (qp=nq->head) && qp->frame <= _frame_counter) {
			// unlink it.
			if (!(nq->head = qp->next)) nq->last = &nq->head;

			if (qp->frame < _frame_counter) {
				error("qp->cp.frame < _frame_counter, %d < %d\n", qp->frame, _frame_counter);
			}

			// run the command
			_current_player = i;
			memcpy(_decode_parameters, qp->cp.dp, (qp->cp.packet_length - COMMAND_PACKET_BASE_SIZE));
			DoCommandP(qp->cp.tile, qp->cp.p1, qp->cp.p2, qp->callback, qp->cmd | CMD_DONT_NETWORK);
			free(qp);
		}
	}

	if (!_networking_server) {
		// remember the random seed so we can check if we're out of sync.
		_my_seed_list[_frame_counter & 15][0] = _sync_seed_1;
		_my_seed_list[_frame_counter & 15][1] = _sync_seed_2;

		while (_num_future_seed) {
			assert(_future_seed[0].frame >= _frame_counter); 
			if (_future_seed[0].frame != _frame_counter) break;
			if (_future_seed[0].seed[0] != _sync_seed_1 ||_future_seed[0].seed[1] != _sync_seed_2) error("network sync error");
			memcpy(_future_seed, _future_seed + 1, --_num_future_seed * sizeof(FutureSeeds));
		}
	}
}

// send a packet to a client.
// Queues the packet at `bytes` (whose first byte is its total length) on cs's
// outgoing buffer chain. The tail buffer is filled first. When it is full, new
// buffers are allocated and appended, so a packet may straddle two buffers.
// Nothing is written to the socket here; SendPackets() does that.
static void SendPacket(ClientState *cs, const void *bytes)
{
	const byte *b = (const byte*)bytes;
	uint len = b[0], n;
	Packet *p;

	// see if there's space in the last packet. cs->last points at the tail's
	// `next` field, which is its first member, so it doubles as the tail pointer.
	if (!cs->head || (p = (Packet*)cs->last, p->siz == sizeof(p->buf)))
		p = NULL;

	do {
		if (!p) {
			// need to allocate a new packet buffer.
			p = (Packet*)malloc(sizeof(Packet));
			if (p == NULL) error("SendPacket: out of memory");

			// insert at the end of the linked list.
			*cs->last = p;
			cs->last = &p->next;
			p->next = NULL;
			p->siz = 0;
		}

		// copy as many bytes as fit into this buffer.
		n = minu(sizeof(p->buf) - p->siz, len);
		memcpy(p->buf + p->siz, b, n);
		p->siz += n;
		b += n;
		p = NULL; // any remaining bytes go to a freshly allocated buffer
	} while (len -= n);
}

// client:
//   add it to the client's ack queue, and send the command to the server
// server:
//   add it to the server's player queue, and send it to all clients.
// Either way the command is executed later, by NetworkProcessCommands().
// Its extra parameters are the first 8 words of _decode_parameters; trailing
// zero words are not transmitted.
void NetworkSendCommand(TileIndex tile, uint32 p1, uint32 p2, uint32 cmd, CommandCallback *callback)
{
	int nump;
	QueuedCommand *qp;
	ClientState *cs;
	ClientState *const end = _clients + sizeof(_clients) / sizeof(_clients[0]);

	// The server queues its own command under _local_player, the same index it
	// puts in cp.player. Clients queue and execute it under that index, and
	// HandleAckPacket() uses the same queue on clients. Using a fixed queue 0
	// here would make the server run it as a different player, and in a
	// different order, from the clients whenever _local_player != 0.
	qp = AllocQueuedCommand(_networking_server ? &_command_queues[_local_player] : &_ack_queue);
	qp->cp.packet_type = 0;
	qp->cp.tile = tile;
	qp->cp.p1 = p1;
	qp->cp.p2 = p2;
	qp->cp.cmd = (uint16)cmd;
	qp->cp.player = _local_player;
	qp->cmd = cmd;
	qp->callback = callback;

	// for server: execute at the furthest frame the clients may reach.
	// (on a client this is overwritten when the ack arrives.)
	qp->frame = _frame_counter_max;

	// calculate the amount of extra bytes: trailing zero words are not sent.
	nump = 8;
	while ( nump != 0 && ((uint32*)_decode_parameters)[nump-1] == 0) nump--;
	qp->cp.packet_length = COMMAND_PACKET_BASE_SIZE + nump * sizeof(uint32);
	if (nump != 0) memcpy(qp->cp.dp, _decode_parameters, nump * sizeof(uint32));

#if defined(TTD_BIG_ENDIAN)
	// need to convert the command to little endian before sending it.
	{
		CommandPacket cp;
		// memcpy rather than struct assignment: padding bytes (zeroed by calloc)
		// are copied as well, instead of going out on the wire uninitialised.
		memcpy(&cp, &qp->cp, sizeof(cp));
		cp.cmd = TO_LE16(cp.cmd);
		cp.tile = TO_LE16(cp.tile);
		cp.p1 = TO_LE32(cp.p1);
		cp.p2 = TO_LE32(cp.p2);
		for(cs=_clients; cs != end && cs->socket; cs++) SendPacket(cs, &cp);
	}
#else
	// send it to the peers (bounded: every slot of _clients may be in use)
	for(cs=_clients; cs != end && cs->socket; cs++) SendPacket(cs, &qp->cp);
#endif
}

// client:
//   server sends a command from another player that we should execute.
//   put it in the appropriate player queue.
// server:
//   client sends a command that it wants to execute.
//   fill the when field so the client knows when to execute it.
//   put it in the appropriate player queue.
//   send it to all other clients.
//   send an ack packet to the actual client.
// In both cases the command is scheduled for frame _frame_counter_max.
// np->packet_length and np->player come from the network and are validated
// before being used as a copy size and a queue index.
static void HandleCommandPacket(ClientState *cs, CommandPacket *np)
{
	QueuedCommand *qp;
	ClientState *c;
	ClientState *const end = _clients + sizeof(_clients) / sizeof(_clients[0]);
	AckPacket ap;
	uint len = np->packet_length;

	printf("net: cmd size %d\n", np->packet_length);

	// explicit checks instead of assert(): a malformed packet must not overflow
	// qp->cp.dp or index past _command_queues, even in NDEBUG builds.
	if (len < COMMAND_PACKET_BASE_SIZE || len > sizeof(CommandPacket))
		error("net: invalid command packet length");
	if (np->player >= sizeof(_command_queues) / sizeof(_command_queues[0]))
		error("net: invalid player in command packet");

	// put it into the packet queue for the right player.
	qp = AllocQueuedCommand(&_command_queues[np->player]);
	// copy only the bytes that were actually received (header + sent dp words);
	// the unsent dp words stay zero from calloc.
	memcpy(&qp->cp, np, len);

	qp->frame = _frame_counter_max;
	qp->callback = NULL;

	ap.packet_type = 2;
	ap.packet_length = 2;

	// send it to the peers (qp->cp is still in little-endian wire order here)
	if (_networking_server) {
		for(c=_clients; c != end && c->socket; c++) {
			if (c == cs) {
				SendPacket(c, &ap);
			} else {
				SendPacket(c, &qp->cp);
			}
		}
	}

// convert from little endian to big endian?
#if defined(TTD_BIG_ENDIAN)
	qp->cp.cmd = TO_LE16(qp->cp.cmd);
	qp->cp.tile = TO_LE16(qp->cp.tile);
	qp->cp.p1 = TO_LE32(qp->cp.p1);
	qp->cp.p2 = TO_LE32(qp->cp.p2);
#endif

	// take the command value only after the byte-order conversion, so that
	// NetworkProcessCommands() passes a host-order value to DoCommandP().
	qp->cmd = qp->cp.cmd;
}

// sent from server -> client periodically to tell the client about the current tick in the server
// and how far the client may progress.
// The server's frame is derived from the old _frame_counter_max before that limit is
// advanced by sp->frames. If the client has already executed the server's frame, and
// it is still within the 16-entry seed history, the seeds are compared immediately.
// Otherwise they are parked in _future_seed for NetworkProcessCommands() to check.
static void HandleSyncPacket(SyncPacket *sp)
{
	uint32 seed_a, seed_b;

	_frame_counter_srv = _frame_counter_max - sp->server;
	_frame_counter_max += sp->frames;
	printf("net: sync max=%d  cur=%d  server=%d\n", _frame_counter_max, _frame_counter, _frame_counter_srv);

	seed_a = TO_LE32(sp->random_seed_1);
	seed_b = TO_LE32(sp->random_seed_2);

	if (_frame_counter_srv > _frame_counter) {
		// server is ahead of us: keep its seeds (if there is room) until we reach that frame.
		if (_num_future_seed < lengthof(_future_seed)) {
			FutureSeeds *slot = &_future_seed[_num_future_seed++];
			slot->frame = _frame_counter_srv;
			slot->seed[0] = seed_a;
			slot->seed[1] = seed_b;
		}
		return;
	}

	// we are at or past the server's frame: compare if it is still in our history.
	if (_frame_counter_srv + 16 > _frame_counter) {
		const uint32 *ours = _my_seed_list[_frame_counter_srv & 0xF];
		if (seed_a != ours[0] || seed_b != ours[1])
			error("network is desynched");
	}
}

// sent from server -> client as an acknowledgement that the server received the command.
// the command will be executed at the current value of "max".
static void HandleAckPacket()
{
	QueuedCommand *q;
	// move a packet from the ack queue to the end of this player's queue.
	q = _ack_queue.head;
	assert(q);
	if (!(_ack_queue.head = q->next)) _ack_queue.last = &_ack_queue.head;
	q->next = NULL;
	q->frame = _frame_counter_max;
	*_command_queues[_local_player].last = q;
	_command_queues[_local_player].last = &q->next;

	printf("net: ack\n");
}

#define NETWORK_BUFFER_SIZE 2048

// Read all data currently available on cs->socket and dispatch every complete
// packet in it.
// Framing: byte 0 of a packet is its total length (header included), and
// byte 1 is its type (0 = command, 1 = sync, 2 = ack).
// An incomplete trailing packet is saved in cs->buf/cs->buflen and completed on a later call.
static void ReadPackets(ClientState *cs)
{
	byte network_buffer[NETWORK_BUFFER_SIZE];
	uint size;       // bytes held in network_buffer, always starting at offset 0
	uint read_count; // bytes reported available by the OS and not yet read
#if defined(WIN32)
	unsigned long avail; // ioctlsocket(FIONREAD) stores a u_long
#else
	int avail;           // ioctl(FIONREAD) stores an int; a wider variable would be left partly uninitialised
#endif

	if (ioctlsocket(cs->socket, FIONREAD, &avail) != 0) error("ioctlsocket failed.");
	read_count = avail > 0 ? (uint)avail : 0;
	if (read_count == 0) return;

	// start with the partial packet left over from the previous call.
	if ((size=cs->buflen) != 0) memcpy(network_buffer, cs->buf, size);

	do {
		uint pos = 0;
		uint want = sizeof(network_buffer) - size;
		int recv_bytes;

		if (want > read_count) want = read_count;
		recv_bytes = recv(cs->socket, (char*)network_buffer + size, want, 0);
		if (recv_bytes < 0) error("recv() failed");
		// FIONREAD promised data. 0 means the peer closed the connection, and
		// retrying would loop forever.
		if (recv_bytes == 0) error("recv() failed: connection closed");
		read_count -= recv_bytes; // bytes left to read
		size += recv_bytes;       // bytes buffered

		while (size >= 2) {
			byte *packet = network_buffer + pos;
			uint len = packet[0];

			// a length below the 2-byte header would never advance pos, and one
			// larger than cs->buf could not be carried over if it arrived partially.
			if (len < 2 || len > sizeof(cs->buf)) error("invalid packet length");
			// whole packet not there yet?
			if (size < len) break;
			size -= len;
			pos += len;

			// `packet` may sit at any offset, so copy it into a properly aligned
			// struct before a handler reads its multi-byte fields.
			switch(packet[1]) {
			case 0: {
				CommandPacket cp;
				memset(&cp, 0, sizeof(cp));
				memcpy(&cp, packet, len);
				HandleCommandPacket(cs, &cp);
				break;
			}
			case 1: {
				SyncPacket sp;
				assert(!_networking_server);
				memset(&sp, 0, sizeof(sp));
				memcpy(&sp, packet, len < sizeof(sp) ? len : sizeof(sp));
				HandleSyncPacket(&sp);
				break;
			}
			case 2:
				assert(!_networking_server);
				HandleAckPacket();
				break;
			default:
				error("unknown packet type");
			}
		}

		// move the unparsed tail to the front: the next recv() appends at
		// offset `size` and parsing restarts at offset 0.
		if (pos != 0 && size != 0) memmove(network_buffer, network_buffer + pos, size);
	} while (read_count != 0);

	assert(size < sizeof(cs->buf));
	cs->buflen = size;
	memcpy(cs->buf, network_buffer, size);
}


// Push as much of cs's outgoing buffer chain to the socket as send() accepts,
// then release the buffers that went out completely.
// cs->eaten is the number of bytes of the head buffer already sent by an
// earlier call; it is updated for the next call. The last buffer of the chain
// is never freed, only emptied, so cs->last always stays valid.
static void SendPackets(ClientState *cs)
{
	Packet *cur;
	int sent;
	uint offset = cs->eaten; // bytes to skip in the buffer currently being sent
	uint done = offset;      // bytes of the chain, counted from head, that have been sent

	// try sending as much as possible.
	for (cur = cs->head; cur != NULL; cur = cur->next) {
		if (cur->siz == 0) continue;

		assert(offset < cur->siz);
		sent = send(cs->socket, cur->buf + offset, cur->siz - offset, 0);
		if (sent == -1) error("send() failed");
		done += sent;

		// partial send: assume the OS buffer is full and retry on the next call.
		if (offset + sent != cur->siz) break;
		offset = 0;
	}

	// the first `done` bytes of the chain are now invalid: drop whole buffers.
	while (done != 0) {
		cur = cs->head;
		assert(cur->siz != 0);

		// the head buffer still has unsent bytes.
		if (done < cur->siz) break;

		done -= cur->siz;
		cur->siz = 0;
		if (cur->next != NULL) {
			cs->head = cur->next;
			free(cur);
		}
	}

	cs->eaten = done;
}

// Read and dispatch incoming packets from every connected peer.
// _clients is filled from index 0 and a zero socket ends the used entries. The
// loop is also bounded by the array size, so a completely full table is not
// read past its end.
void NetworkReceive()
{
	ClientState *cs;
	ClientState *const end = _clients + sizeof(_clients) / sizeof(_clients[0]);

	// get stuff from all the clients
	for (cs = _clients; cs != end && cs->socket; cs++) ReadPackets(cs);
}

// Per-frame network output.
// On the server, a SyncPacket is queued for every client once every
// _network_sync_freq calls. It raises _frame_counter_max to at least
// _frame_counter + _network_ahead_frames and carries the current random seeds.
// Afterwards all queued outgoing data is flushed to every connected peer.
// Loops over _clients are bounded by the array size as well as by the first zero socket.
void NetworkSend()
{
	ClientState *cs;
	ClientState *const end = _clients + sizeof(_clients) / sizeof(_clients[0]);

	// send sync packets?
	if (_networking_server) {
		if (++_not_packet >= _network_sync_freq) {
			SyncPacket sp;
			uint new_max;

			_not_packet = 0;

			new_max = max(_frame_counter + _network_ahead_frames, _frame_counter_max);

			sp.packet_length = sizeof(sp);
			sp.packet_type = 1;
			// both deltas are relative to the old max and stored in a byte.
			// NOTE(review): values >= 256 would be truncated.
			sp.frames = new_max - _frame_counter_max;
			sp.server = _frame_counter_max - _frame_counter;
			sp.random_seed_1 = TO_LE32(_sync_seed_1);
			sp.random_seed_2 = TO_LE32(_sync_seed_2);
			_frame_counter_max = new_max;

			// send it to all the clients
			for(cs=_clients; cs != end && cs->socket; cs++) {
				printf("net: sending sync\n");
				SendPacket(cs, &sp);
			}
		}
	}

	// send stuff to all clients
	for(cs=_clients; cs != end && cs->socket; cs++) SendPackets(cs);
}

void NetworkConnect(const char *hostname, int port)
{
	SOCKET s;
	struct sockaddr_in sin;
	int b;

	s = socket(AF_INET, SOCK_STREAM, 0);
	if (s == -1) error("socket() failed");
	
	b = 1;
	setsockopt(s, IPPROTO_TCP, TCP_NODELAY, (const char*)&b, sizeof(b));
		
	sin.sin_family = AF_INET;
	sin.sin_addr.s_addr = inet_addr(hostname);
	sin.sin_port = htons(port);

	if (connect(s, (struct sockaddr*) &sin, sizeof(sin)) != 0)
		error("connect() failed");

	// in client mode, only the first client field is used. it's pointing to the server.
	_clients[0].socket = s;
}

void NetworkListen(int port, int n)
{
	SOCKET ls, s;
	struct sockaddr_in sin;
	int i, sin_len;
	int b;

	ls = socket(AF_INET, SOCK_STREAM, 0);
	if (ls == -1)
		error("socket() on listen socket failed");

	sin.sin_family = AF_INET;
	sin.sin_addr.s_addr = 0;
	sin.sin_port = htons(port);

	if (bind(ls, (struct sockaddr*)&sin, sizeof(sin)) != 0)
		error("bind() failed");

	if (listen(ls, 1) != 0)
		error("listen() failed");

	for(i=0; i!=n; i++) {
		sin_len = sizeof(sin);
		s = accept(ls, (struct sockaddr*)&sin, &sin_len);
		if (s == INVALID_SOCKET) error("accept() failed");

		b = 1;
		setsockopt(s, IPPROTO_TCP, TCP_NODELAY, (const char*)&b, sizeof(b));

		_clients[i].socket = s;
	}
}


void NetworkInitialize()
{
	int i;

#if defined(WIN32)
	WSADATA wsa;
	if (WSAStartup(MAKEWORD(2,0), &wsa) != 0)
		error("WSAStartup failed");
#endif
	
	// initialize queues
	for(i=0; i!=MAX_PLAYERS; i++) {
		_command_queues[i].last = &_command_queues[i].head;
		_clients[i].last = &_clients[i].head;
	}

	_ack_queue.last = &_ack_queue.head;
}


#else // ENABLE_NETWORK


// stubs: networking is compiled out, so every entry point does nothing.
void NetworkInitialize() {}
void NetworkListen(int port, int n) { (void)port; (void)n; }
void NetworkConnect(const char *hostname, int port) { (void)hostname; (void)port; }
void NetworkReceive() {}
void NetworkSend() {}
void NetworkSendCommand(TileIndex tile, uint32 p1, uint32 p2, uint32 cmd, CommandCallback *callback)
{
	// nothing to queue or transmit without networking
	(void)tile; (void)p1; (void)p2; (void)cmd; (void)callback;
}
void NetworkProcessCommands() {}
#endif // ENABLE_NETWORK