/*
 *    poor man's windows sniffer w/ raw sockets (v0.1)
 *	    by nad (nad@somethinginteresting.org)
 *		(all code public domain)
 */

#include "poorsniff.h"

/*
 * Entry point: parse the command line, open and configure the raw socket,
 * then print every received packet. The loop only ends when recv() fails,
 * in which case die() terminates the program.
 */
int main(int argc, char **argv) {

	char pak[PAKSIZE];
	int bytes;	/* recv() returns an int: byte count or SOCKET_ERROR */

	init_opt(argc, argv);
	init_net();

	for (;;) {
		/* clear leftovers from the previous (possibly longer) packet */
		memset(pak, 0, sizeof(pak));

		bytes = recv(s0k, pak, (int) sizeof(pak), 0);
		if (bytes == SOCKET_ERROR)
			die("socket error on recv\n");

		process_pak(pak, bytes);
	}

	return 0;
}

/*
 * Copy a format string into one of the fmt* buffers. The copy is always
 * NUL-terminated and never exceeds MAXFMT bytes. Assumes dst holds at least
 * MAXFMT bytes, as the original strncpy calls already did.
 */
static void opt_set_format(char *dst, const char *src)
{
	strncpy(dst, src, MAXFMT - 1);
	dst[MAXFMT - 1] = '\0';
}

/*
 * Parse a 0/1 switch value. Anything else, including empty or trailing junk,
 * terminates the program through die(msg).
 */
static int opt_parse_flag(const char *val, char *msg)
{
	char *end;
	long v = strtol(val, &end, 10);

	if (end == val || *end != '\0' || (v != 0 && v != 1))
		die(msg);
	return (int) v;
}

/*
 * Parse the command line and install the default output format strings.
 *
 * "-h" prints usage and exits. Every other argument must be variable=value,
 * with the variable name matched exactly:
 *   tcp, udp, icmp, prom       0 or 1
 *   int                        interface number (0 = list interfaces)
 *   fmttcp, fmtudp, fmticmp    output format string for that protocol
 * Any malformed or unknown argument terminates the program.
 * Note: the '=' inside argv[i] is overwritten with a NUL.
 */
void init_opt(int argc, char **argv) {

	int i;
	char *eq, *var, *val;

	/* default format strings */
	opt_set_format(fmttcp, "%s:%S -> %d:%D (tcp) %P[0-256]");
	opt_set_format(fmtudp, "%s:%S -> %d:%D (udp) %P[0-256]");
	opt_set_format(fmticmp, "%s -> %d (icmptype: %T icmpcode %C) %P[0-256]");

	for (i = 1; i < argc; i++) {

		if (argv[i][0] == '-' && argv[i][1] == 'h')
			usage();

		/* valid args are variable=value */
		if ((eq = strchr(argv[i], '=')) == NULL) {
			printf("invalid argument: %s\n", argv[i]);
			exit(-1);
		}
		*eq = '\0';
		var = argv[i];
		val = eq + 1;

		/* exact names: a prefix match made "tcpfmt=..." act as "tcp=0" */
		if (strcmp(var, "tcp") == 0)
			scantcp = opt_parse_flag(val, "tcp, udp, and icmp only take 1 or 0 as valid values.\n");
		else if (strcmp(var, "udp") == 0)
			scanudp = opt_parse_flag(val, "tcp, udp, and icmp only take 1 or 0 as valid values.\n");
		else if (strcmp(var, "icmp") == 0)
			scanicmp = opt_parse_flag(val, "tcp, udp, and icmp only take 1 or 0 as valid values.\n");
		else if (strcmp(var, "fmttcp") == 0)
			opt_set_format(fmttcp, val);
		else if (strcmp(var, "fmtudp") == 0)
			opt_set_format(fmtudp, val);
		else if (strcmp(var, "fmticmp") == 0)
			opt_set_format(fmticmp, val);
		else if (strcmp(var, "prom") == 0)
			promiscuous = opt_parse_flag(val, "prom only takes 1 or 0 as a valid value.\n");
		else if (strcmp(var, "int") == 0)
			interface_choice = atoi(val);
		else {
			printf("unknown variable, i'll just DIE until you figure out what you're doing\n");
			exit(-1);
		}
	}

}

/*
 * Start Winsock, open the raw IP socket (global s0k), bind it, and optionally
 * enable SIO_RCVALL so all IP traffic on the bound interface is delivered.
 *
 * interface_choice:
 *   == 0  print the interface list and exit (via die(""))
 *   >  0  bind to that entry of the interface list (1-based)
 *   <  0  bind to the IPv4 address the local hostname resolves to
 * Any failure terminates the program via die().
 */
void init_net() {

	WSADATA w;
	SOCKADDR_IN sa;
	DWORD bytes;
	char hostname[HOSTNAME_LEN];
	struct hostent *h;
	unsigned int opt = 1;	/* value passed to SIO_RCVALL: 1 turns it on */

	if (WSAStartup(MAKEWORD(2,0), &w) != 0)
		die("WSAStartup failed\n");

	if (interface_choice == 0) {
		list_interfaces();
		die("");
	}

	if ((s0k = socket(AF_INET, SOCK_RAW, IPPROTO_IP)) == INVALID_SOCKET)
		die("unable to open raw socket\n");

	if (interface_choice > 0) {

		bind_to_interface(interface_choice);

	} else {

		/* use the default interface: whatever our hostname resolves to */
		if (gethostname(hostname, HOSTNAME_LEN) == SOCKET_ERROR)
			die("unable to gethostname\n");

		if ((h = gethostbyname(hostname)) == NULL)
			die("unable to gethostbyname\n");

		/* only a 4-byte IPv4 address fits in sin_addr; never trust h_length blindly */
		if (h->h_addrtype != AF_INET ||
		    h->h_length != (int) sizeof(sa.sin_addr) ||
		    h->h_addr_list[0] == NULL)
			die("hostname does not resolve to an IPv4 address\n");

		memset(&sa, 0, sizeof(sa));	/* also clears sin_zero */
		sa.sin_family = AF_INET;
		sa.sin_port = htons(6000);
		memcpy(&sa.sin_addr, h->h_addr_list[0], sizeof(sa.sin_addr));

		if (bind(s0k, (SOCKADDR *)&sa, sizeof(sa)) == SOCKET_ERROR)
			die("unable to bind() socket\n");

	}

	/* prom=0 on the command line disables promiscuous mode */
	if (promiscuous) {
		if (WSAIoctl(s0k, SIO_RCVALL, &opt, sizeof(opt), NULL, 0, &bytes, NULL, NULL) == SOCKET_ERROR)
			die("failed to set promiscuous mode\n");
	}

}


/* parse pak, print out requested fields */
void process_pak(char *pak, int len) {

	struct iphdr *ip;
	struct tcphdr *tcp;
	struct udphdr *udp;
	struct icmphdr *icmp;
	char *fmt, *data, tmp[6];
	unsigned char proto;	/* to avoid repeated dereferencing */
	int i, start, end, datasize;

	SOCKADDR_IN src, dst;
	char *p;

	ip = (struct iphdr *) pak;
	proto = ip->proto;
	switch(proto) {
	case IPPROTO_TCP:
		if (!scantcp) return;
		tcp = (struct tcphdr *) (pak + (ip->ihl * 4));
		fmt = fmttcp;
		break;
	case IPPROTO_UDP:
		if (!scanudp) return;
		udp = (struct udphdr *) (pak + (ip->ihl * 4));
		fmt = fmtudp;
		break;
	case IPPROTO_ICMP:
		if (!scanicmp) return;
		icmp = (struct icmphdr *) (pak + (ip->ihl * 4));
		fmt = fmticmp;
		break;
	default:
		printf("unknown protocol %d: what the heck is that?\n", proto);
		return;
	}

	for (p=fmt; *p; p++) {

		if (*p != '%') {
			putchar(*p);
			continue;
		}

		switch(*++p) {

			/* IP HEADER FIELDS */

			case 'h':	/* header len	*/
				printf("%d", ip->ihl);
				break;
			case 'v':	/* version		*/
				printf("%d", ip->ver);
				break;
			case 'o':	/* tos			*/
				printf("%d", ip->tos);
				break;
			case 'l':	/* total len	*/
				printf("%d", ntohs(ip->totlen));
				break;
			case 'i':	/* ip id		*/
				printf("%d", ip->id);
				break;
			case 't':	/* ttl			*/
				printf("%d", ip->ttl);
				break;
			case 'p':	/* proto		*/
				printf("%d", ip->proto);
				break;
			case 'c':	/* checksum		*/
				printf("%d", ntohs(ip->checksum));
				break;
			case 's':	/* src ip		*/
				src.sin_addr.s_addr = ip->src;
				printf("%s", inet_ntoa(src.sin_addr));
				break;
			case 'd':	/* dst ip		*/
				dst.sin_addr.s_addr = ip->dst;
				printf("%s", inet_ntoa(dst.sin_addr));
				break;

			/* PROTOCOL SPECIFIC FIELDS */
			case 'S':	/* tcp/udp source port */
				switch(proto) {
				case IPPROTO_TCP:
					printf("%d", ntohs(tcp->sport));
					break;
				case IPPROTO_UDP:
					printf("%d", ntohs(udp->sport));
					break;
				}
				break;
			case 'D':	/* tcp/udp dest port */
				switch(proto) {
				case IPPROTO_TCP:
					printf("%d", ntohs(tcp->dport));
					break;
				case IPPROTO_UDP:
					printf("%d", ntohs(udp->dport));
					break;
				}
				break;
			case 'T':	/* icmp type */
				if (proto == IPPROTO_ICMP) printf("%d", icmp->type);
				break;
			case 'C':	/* icmp code */
				if (proto == IPPROTO_ICMP) printf("%d", icmp->code);
				break;
			case 'I':	/* icmp ID */
				if (proto == IPPROTO_ICMP) printf("%d", ntohs(icmp->id));
				break;
			case 'Q':	/* tcp sequence, icmp sequence */
				switch(proto) {
				case IPPROTO_TCP:
					printf("%d", ntohl(tcp->seq));
					break;
				case IPPROTO_ICMP:
					printf("%d", ntohl(icmp->seq));
					break;
				}
				break;
			case 'L':	/* udp length */
				if (proto == IPPROTO_UDP) printf("%d", ntohs(udp->len));
				break;
			case 'A':	/* tcp ack num */
				if (proto == IPPROTO_TCP) printf("%d", ntohl(tcp->acknum));
				break;
			case 'K':	/* tcp/udp/icmp checksum */
				switch(proto) {
				case IPPROTO_TCP:
					printf("%d", ntohl(tcp->cksum));
					break;
				case IPPROTO_UDP:
					printf("%d", ntohs(udp->cksum));
					break;
				case IPPROTO_ICMP:
					printf("%d", ntohl(icmp->cksum));
					break;
				}
				break;
			case 'P':	/* raw data offset */

				if (*(++p) == '[') {

					p++;
					i=0;
					while (*p != ']') {

						if (isdigit(*p)) {
							tmp[i] = *p;
							i++;
							if (i > 5) die("invalid data range (index)\n");
						} else if (*p == '-') {
							tmp[i] = '\0';
							start = atoi(tmp);
							i=0;
							tmp[i] = '\0';
						} else {
							die("invalid data range (fucked up range)\n");
						}

						p++;

					}
					tmp[i] = '\0';
					end = atoi(tmp);

					switch(proto) {
					case IPPROTO_TCP:
						data = pak + (ip->ihl * 4) + (tcp->tcphl * 4);
						datasize = ntohs(ip->totlen) - (ip->ihl*4) - (tcp->tcphl*4);
						break;
					case IPPROTO_UDP:
						data = pak + (ip->ihl * 4) + (sizeof(struct udphdr));
						datasize = ntohs(ip->totlen) - (ip->ihl*4) - (sizeof(struct udphdr));
						break;
					case IPPROTO_ICMP:
						data = pak + (ip->ihl * 4);
						datasize = ntohs(ip->totlen) - (ip->ihl * 4);
						break;
					}

					printf("\n");
					if (datasize != 0 && start < end && start >= 0) {

						if (end > datasize)
							end = datasize;

						for (i=start; i<end; i++)
							if (isalnum(data[i]))
								printf("%c", data[i]);
							else
								printf(".");

						printf("\n");
					}

				} else
					printf("invalid data range specified\n");

				break;

			default:
				printf("unknown format character!\n");
		}

	}
	printf("\n");

}


void list_interfaces() {

	SOCKET sd;
	sd = WSASocket(AF_INET, SOCK_DGRAM, 0, 0, 0, 0);
	if (sd == SOCKET_ERROR)
		printf("error on WSASocket\n");

	INTERFACE_INFO InterfaceList[20];
	unsigned long nBytesReturned;
	if (WSAIoctl(sd, SIO_GET_INTERFACE_LIST, 0, 0, &InterfaceList, sizeof(InterfaceList), &nBytesReturned, 0, 0) == SOCKET_ERROR) {
		printf("error fetching interface list\n");
	}

	int nNumInterfaces = nBytesReturned / sizeof(INTERFACE_INFO);
	printf("found %d interfaces\n", nNumInterfaces);

    for (int i = 0; i < nNumInterfaces; ++i) {

		printf("%d. ", i+1);

        SOCKADDR_IN *pAddress;
        pAddress = (SOCKADDR_IN *) & (InterfaceList[i].iiAddress);
        printf("%s ", inet_ntoa(pAddress->sin_addr));

        pAddress = (SOCKADDR_IN *) & (InterfaceList[i].iiBroadcastAddress);
        printf("  %s ", inet_ntoa(pAddress->sin_addr));

        pAddress = (SOCKADDR_IN *) & (InterfaceList[i].iiNetmask);
        printf("  %s ", inet_ntoa(pAddress->sin_addr));;

        u_long nFlags = InterfaceList[i].iiFlags;

        if (nFlags & IFF_UP)
			printf("(up");
        else
			printf("(down");

        if (nFlags & IFF_POINTTOPOINT)
			printf(", point-to-point");
        if (nFlags & IFF_LOOPBACK)
			printf(", loopback");

        if (nFlags & IFF_BROADCAST)
			printf(", broadcast");

        if (nFlags & IFF_MULTICAST)
			printf(", multicast");


		printf(")\n");
    }

	return;
}


void bind_to_interface(int choice) {

	SOCKET sd;
	sd = WSASocket(AF_INET, SOCK_DGRAM, 0, 0, 0, 0);
	if (sd == SOCKET_ERROR)
		printf("error on WSASocket\n");

	INTERFACE_INFO InterfaceList[20];
	unsigned long nBytesReturned;
	if (WSAIoctl(sd, SIO_GET_INTERFACE_LIST, 0, 0, &InterfaceList, sizeof(InterfaceList), &nBytesReturned, 0, 0) == SOCKET_ERROR) {
		printf("error fetching interface list\n");
	}

	int nNumInterfaces = nBytesReturned / sizeof(INTERFACE_INFO);
	if (choice > nNumInterfaces) {
		die("invalid interface selection\n");
	}

	if (choice) {
		// bind to the specified interface and return

        SOCKADDR_IN *pAddress;
        pAddress = (SOCKADDR_IN *) & (InterfaceList[choice-1].iiAddress);
        printf("using interface: %s\n", inet_ntoa(pAddress->sin_addr));

		if ((bind(s0k, (SOCKADDR *)&(InterfaceList[choice-1].iiAddress), sizeof(SOCKADDR_IN))) == SOCKET_ERROR)
			die("unable to bind() socket\n");

		return;
	}


}


/*
 * Print command-line help (variables and format directives) and exit with
 * status 0. The example uses the real variable name "fmttcp"; the previous
 * "tcpfmt" was not a recognized variable.
 */
void usage() {
	printf("poorsniff <variables=values>\n");
	printf(" ignore protocols:\n");
	printf("   tcp=0   udp=0   icmp=0\n");
	printf("\n");
	printf(" promiscuous mode off:\n");
	printf("   prom=0\n");
	printf("\n");
	printf(" select interface (0 for list):\n");
	printf("   int=0   int=1   etc.\n");
	printf("\n");
	printf(" override default print format\n");
	printf("   fmttcp=fmt  fmtudp=fmt  fmticmp=fmt\n");
	printf("\n");
	printf(" ip header:	\n");
	printf("   %%s  src ip\n");
	printf("   %%d  dst ip\n");
	printf("   %%h  header length\n");
	printf("   %%v  ip version\n");
	printf("   %%o  type of service\n");
	printf("   %%l  total length\n");
	printf("   %%i  ip identification\n");
	printf("   %%t  time to live\n");
	printf("   %%p  protocol\n");
	printf("   %%c  ip checksum\n");
	printf("\n");
	printf(" tcp/udp/icmp header:\n");
	printf("   %%S  src port            tcp,udp	\n");
	printf("   %%D  dst port            tcp,udp	\n");
	printf("   %%A  ack number          tcp	\n");
	printf("   %%Q  sequence number     tcp,icmp\n");
	printf("   %%L  length              udp	\n");
	printf("   %%K  checksum            tcp,udp,icmp\n");
	printf("   %%T  type                icmp\n");
	printf("   %%C  code                icmp\n");
	printf("   %%I  id                  icmp\n");
	printf("\n");
	printf(" data\n");
	printf("   %%P[start-end]   (index begins at data)\n");
	printf("\n");
	printf("example: poorsniff udp=0 fmttcp=\"%%s:%%S -> %%d:%%D %%P[0-256]\"\n");
	exit(0);
}


/*
 * Fatal exit path: tear down Winsock, print the caller's message verbatim
 * (no newline added) and terminate with status -1.
 * Called with "" when there is nothing to report.
 */
void die(char *s) {
	(void) WSACleanup();	/* harmless even if WSAStartup never ran */
	fputs(s, stdout);
	exit(-1);
}


