#include <linux/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <netinet/ip6.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <linux/netfilter.h>            /* for NF_ACCEPT */
#include <libnetfilter_queue/libnetfilter_queue.h>

/**
 * This file implements the overall program logic: it sets up a
 * netfilter queue callback that analyses packets and issues an
 * ACCEPT or DROP verdict for each of them.
 */

// API for caching of verdicts (implemented outside this file).
// As used by cb(): the lookup_cache result is converted to int and
// treated as "not cached" when negative, "blocked" when positive and
// "allowed" when zero. add_cache is called with a domain and the
// check_domain result for it (presumably to record that verdict --
// confirm against the implementation). hash_code is only used for
// log output in this file.
unsigned int lookup_cache(unsigned char *domain);
void add_cache(unsigned char *domain,unsigned int ix);
int hash_code(unsigned char *domain);

// API for BAD domains database (implemented outside this file).
// As used by cb(): a check_domain result greater than zero means the
// domain is to be blocked. main() calls load_domains once for each
// command line argument, bracketed by start_domain_database_loading
// and end_domain_database_loading.
unsigned int check_domain(unsigned char *domain);
void load_domains(char *file);
void start_domain_database_loading(void);
void end_domain_database_loading(void);

/**
 * Fetch the netfilter packet id of a queued packet.
 *
 * Returns the id converted from network to host byte order, or 0
 * when the packet has no message packet header.
 */
static u_int32_t get_packet_id(struct nfq_data *tb) {
    struct nfqnl_msg_packet_hdr *hdr = nfq_get_msg_packet_hdr( tb );
    if ( !hdr ) {
        return 0;
    }
    return ntohl( hdr->packet_id );
}

/**
 * Packet header if ipv4 packet: the fixed IPv4 header directly
 * followed by the TCP header.
 * NOTE(review): this overlay only matches packets without IPv4
 * options (struct ip covers the fixed header only) -- confirm that
 * users of second do not rely on it otherwise.
 */
struct ipv4_pkt {
    struct ip first;          // IPv4 header; .ip_dst[4 bytes]
    struct tcphdr second;     // TCP header
};

/**
 * Packet header if ipv6 packet: the fixed IPv6 header directly
 * followed by the TCP header.
 * NOTE(review): this overlay does not account for IPv6 extension
 * headers between the two -- confirm that this is acceptable for the
 * traffic sent to the queue.
 */
struct ipv6_pkt {
    struct ip6_hdr first;     // IPv6 header; .ip6_dst[16 bytes]
    struct tcphdr second;     // TCP header
};

/**
 * Convenience type for network packets of undetermined family.
 * Both members overlay the same packet bytes; the code in this file
 * reads p.pkt4.first.ip_v to find out which of them applies.
 */
struct packet {
    union {
	struct ipv4_pkt pkt4;     // layout when the IP version is 4
	struct ipv6_pkt pkt6;     // layout when the IP version is 6
    } p;
};

/**
 * View a packet byte blob through the header overlay type. The
 * result points at the very same bytes as data; nothing is copied.
 */
static struct packet *get_headerP(unsigned char *data) {
    struct packet *header = (struct packet *) data;
    return header;
}

/**
 * Render the destination address of a packet as text.
 *
 * The IP version field selects between the IPv4 and the IPv6 header
 * layout, and the result is then whatever inet_ntop returns. For any
 * other version the result is the version number followed by " ???".
 * The text is kept in a static buffer that the next call overwrites.
 */
static const char *tell_ip(struct packet *ip) {
    static char THEIP[200];
    int version = ip->p.pkt4.first.ip_v;
    if ( version == 4 ) {
        return inet_ntop( AF_INET, &ip->p.pkt4.first.ip_dst,
                          THEIP, sizeof( THEIP ) );
    }
    if ( version == 6 ) {
        return inet_ntop( AF_INET6, &ip->p.pkt6.first.ip6_dst,
                          THEIP, sizeof( THEIP ) );
    }
    snprintf( THEIP, sizeof( THEIP ), "%d ???", version );
    return THEIP;
}

#define END 400

/**
 * Debugging aid: dump the beginning of a packet to stderr.
 *
 * Writes a single line holding the destination address, the combined
 * size of the fixed IP and TCP header structures for the packet's IP
 * version (0 if the version is neither 4 nor 6), the TCP destination
 * port (likewise 0 for other versions) and the packet length. These
 * are followed by at most END bytes of the packet, starting at the IP
 * header, each printed as a two-digit hex value.
 *
 * data points at the IP header and length is the number of bytes
 * available at data; nothing is dumped when length is not positive.
 */
static void view_payload(unsigned char *data,int length) {
    struct packet *header = get_headerP( data );
    u_int16_t port = 0;
    u_int8_t header_size = 0;
    switch ( header->p.pkt4.first.ip_v ) {
    case 4:
        port = ntohs( header->p.pkt4.second.th_dport );
        header_size = sizeof( struct ipv4_pkt );
        break;
    case 6:
        port = ntohs( header->p.pkt6.second.th_dport );
        header_size = sizeof( struct ipv6_pkt );
        break;
    default:
        break;
    }
    // Count bytes rather than forming an end pointer, so a negative
    // length never produces a pointer before the start of the packet.
    int count = ( length > END )? END : length;
    fprintf( stderr, "%s %d %d %d ",
             tell_ip( header ), header_size, port, length );
    for ( int i = 0; i < count; i++ ) {
        fprintf( stderr, "%02x ", data[ i ] );
    }
    fprintf( stderr, "\n" );
}

// Shared result buffer: ssl_host and http_host copy the host name
// they find into it and return a pointer to it, so a returned name is
// only valid until the next call of either function.
static unsigned char buffer[1000];

/**
 * Locate the SNI host name within a TLS "Client Hello".
 *
 * p points at the first byte of the TLS record and n is the number of
 * bytes available there. The caller has already verified the record
 * and handshake headers in p[0..10].
 *
 * Returns a pointer to the (not NUL terminated) host name within p
 * and stores its length in *name_length, or returns 0 if there is no
 * host name or the message is malformed or truncated. Every offset is
 * checked against n, since the message content is untrusted.
 */
static const unsigned char *sni_from_client_hello(
    const unsigned char *p, unsigned int n, unsigned int *name_length )
{
    // p[11..42] holds the 32 byte random. It is followed by three
    // variable length fields that need skipping: the session id
    // (1 byte length), the cipher suites (2 byte length) and the
    // compression methods (1 byte length).
    unsigned int i = 43;
    if ( n < i + 1 ) {
        return 0;
    }
    i += 1 + p[i];                       // now at cipher suites length
    if ( n < i + 2 ) {
        return 0;
    }
    i += 2 + ( 256u * p[i] + p[i + 1] ); // now at compression length
    if ( n < i + 1 ) {
        return 0;
    }
    i += 1 + p[i];                       // now at extensions length
    if ( n < i + 2 ) {
        return 0;
    }
    unsigned int extensions_length = ( 256u * p[i] ) + p[i + 1];
    i += 2;
    fprintf( stderr, "TLS 1.2 %u %u\n", i, extensions_length );
    // Never walk past the bytes that are actually available.
    if ( extensions_length > n - i ) {
        extensions_length = n - i;
    }
    // k is the offset of the current extension within the extensions
    // block. Each extension has a 2 byte type, a 2 byte length and
    // then that many bytes of data.
    unsigned int k = 0;
    while ( extensions_length - k >= 4 ) {
        const unsigned char *ext = p + i + k;
        unsigned int type = ( 256u * ext[0] ) + ext[1];
        unsigned int ext_length = ( 256u * ext[2] ) + ext[3];
        fprintf( stderr, "Extension %u %u\n", k, type );
        k += 4;
        if ( ext_length > extensions_length - k ) {
            return 0; // Extension data is truncated
        }
        if ( type == 0 ) { // Server Name
            // Data layout: list length (2 bytes), name type (1 byte,
            // 0 for a host name), name length (2 bytes), the name.
            if ( ext_length < 5 || ext[6] != 0 ) {
                return 0; // Name badness
            }
            unsigned int len = ( 256u * ext[7] ) + ext[8];
            if ( len > ext_length - 5 ) {
                return 0; // Name runs past its extension
            }
            *name_length = len;
            return ext + 9;
        }
        k += ext_length;
    }
    return 0;
}

/**
 * SSL traffic includes a data packet with a clear text host name.
 * This is known as the SNI extension of the "Client Hello" message.
 *
 * data points at the IP header of a packet and length is the number
 * of bytes available there. The TCP payload is located by means of
 * the header length fields of the packet (IPv4 header length and TCP
 * data offset). For IPv6 the TCP header is assumed to follow the
 * fixed 40 byte header directly, i.e. without extension headers.
 *
 * Returns the host name as a NUL terminated string in the shared
 * static buffer, or 0 if there is none. A packet that is recognised
 * as a Client Hello with client version 3.3 but offers no usable host
 * name is also dumped to stderr via view_payload.
 */
static unsigned char *ssl_host(unsigned char *data,int length) {
    if ( length < 1 ) {
        return 0;
    }
    unsigned int total = (unsigned int) length;

    // The IP version is the high nibble of the first byte for both
    // IPv4 and IPv6.
    unsigned int ip_size;
    switch ( data[0] >> 4 ) {
    case 4:
        ip_size = 4u * ( data[0] & 0x0f ); // header length in words
        if ( ip_size < 20 ) {
            return 0;
        }
        break;
    case 6:
        ip_size = 40;
        break;
    default:
        return 0;
    }
    // The TCP data offset (the header size in 32-bit words) is the
    // high nibble of byte 12 of the TCP header.
    if ( total < ip_size + 20 ) {
        return 0;
    }
    unsigned int tcp_size = 4u * ( data[ ip_size + 12 ] >> 4 );
    if ( tcp_size < 20 || total < ip_size + tcp_size ) {
        return 0;
    }
    const unsigned char *p = data + ip_size + tcp_size;
    unsigned int n = total - ip_size - tcp_size; // TCP payload size

    // Check that it's a "Client Hello" message: a handshake record
    // (0x16) of major version 3 that starts with handshake type 1,
    // where the top byte of the 3 byte handshake length is 0.
    if ( n < 11 ||
         p[0] != 0x16 || p[1] != 0x03 || p[5] != 0x01 || p[6] != 0x00 ) {
        return 0;
    }
    fprintf( stderr, "Client Hello\n" );
    // Note minor version p[2] is not checked
    // record_length = 256 * p[3] + p[4]
    // handshake_message_length = 256 * p[7] + p[8]
    if ( p[9] != 0x03 || p[10] != 0x03 ) { // client version 3.3 (TLS 1.2)
        return 0;
    }
    fprintf( stderr, "TLS 1.2\n" );

    unsigned int name_length = 0;
    const unsigned char *name = sni_from_client_hello( p, n, &name_length );
    // The name must fit the result buffer including its terminator.
    if ( name && name_length < sizeof( buffer ) ) {
        memcpy( buffer, name, name_length );
        buffer[ name_length ] = '\0';
        return buffer;
    }
    // This point is only reached on "missing or bad SNI".
    view_payload( data, length );
    return 0;
}

/**
 * HTTP traffic includes a data packet with the host name as a
 * "Host:" attribute.
 *
 * data points at the IP header of a packet and length is the number
 * of bytes available there. The TCP payload is located by means of
 * the header length fields of the packet (IPv4 header length and TCP
 * data offset). For IPv6 the TCP header is assumed to follow the
 * fixed 40 byte header directly, i.e. without extension headers.
 *
 * Only payloads starting with "GET " or "POST " are considered. The
 * header lines are then scanned for one starting with "Host:" (case
 * sensitive), and the scan gives up at the empty line that ends the
 * headers. Returns the host name as a NUL terminated string in the
 * shared static buffer, or 0 if none is found, if it is shorter than
 * 5 characters, or if it does not fit the buffer.
 */
static unsigned char *http_host(unsigned char *data,int length) {
    if ( length < 1 ) {
        return 0;
    }
    unsigned int total = (unsigned int) length;

    // The IP version is the high nibble of the first byte for both
    // IPv4 and IPv6.
    unsigned int ip_size;
    switch ( data[0] >> 4 ) {
    case 4:
        ip_size = 4u * ( data[0] & 0x0f ); // header length in words
        if ( ip_size < 20 ) {
            return 0;
        }
        break;
    case 6:
        ip_size = 40;
        break;
    default:
        return 0;
    }
    // The TCP data offset (the header size in 32-bit words) is the
    // high nibble of byte 12 of the TCP header.
    if ( total < ip_size + 20 ) {
        return 0;
    }
    unsigned int tcp_size = 4u * ( data[ ip_size + 12 ] >> 4 );
    // Scanning stops 6 bytes before the end of the packet, which keeps
    // all the fixed-size comparisons below within the available data.
    if ( tcp_size < 20 || total < ip_size + tcp_size + 6 ) {
        return 0;
    }
    unsigned char *body = data + ip_size + tcp_size;
    unsigned char *end = data + length - 6;

    if ( memcmp( body, "GET ", 4 ) != 0 &&
         memcmp( body, "POST ", 5 ) != 0 ) {
        return 0;
    }
    int check = 0; // Set when body is at the start of a header line
    for ( ; body < end; body++ ) {
        if ( check ) {
            if ( memcmp( body, "Host:", 5 ) == 0 ) {
                body += 5;
                for( ; body < end; body++ ) if ( *body != ' ' ) break;
                unsigned char *start = body;
                int n = 0;
                for( ; body < end; n++, body++ ) if ( *body <= ' ' ) break;
                // The name must fit the result buffer including its
                // terminator.
                if ( n < 5 || (size_t) n >= sizeof( buffer ) ) {
                    return 0;
                }
                memcpy( buffer, start, n );
                buffer[ n ] = '\0';
                return buffer;
            }
            if ( memcmp( body, "\r\n", 2 ) == 0 ) {
                return 0; // End of headers without any Host line
            }
            // Some other header line: skip to the end of it.
            for( ; body < end; body++ ) if ( *body == '\n' ) break;
            if ( body >= end ) {
                return 0;
            }
        }
        check = ( *body == '\n' );
    }
    return 0;
}

/**
 * Callback function to handle a packet.
 *
 * Packets with at least 100 bytes of payload are searched for a host
 * name, first as an HTTP "Host:" attribute and then as a TLS SNI
 * extension. If a name is found, the verdict cache is consulted: its
 * result, converted to int, is taken as "not cached" when negative,
 * "blocked" when positive and "allowed" when zero. A name that is not
 * cached is looked up in the domain database and the outcome is
 * handed to the cache; a result greater than zero blocks the packet
 * and is logged to stderr.
 *
 * The verdict is NF_DROP for blocked names and NF_ACCEPT for
 * everything else, including packets without a recognised host name.
 * nfmsg and code are not used.
 *
 * Returns the result of nfq_set_verdict.
 */
static int cb(
    struct nfq_q_handle *qh,
    struct nfgenmsg *nfmsg,
    struct nfq_data *nfa, void *code )
{
    (void) nfmsg;
    (void) code;
    u_int32_t id = get_packet_id( nfa );
    // Stays 0 if nfq_get_payload fails without setting it.
    unsigned char *data = 0;
    int length = nfq_get_payload( nfa, &data );
    int verdict = NF_ACCEPT;
    unsigned char *host = 0;
    if ( length >= 100 && data ) {
        host = http_host( data, length );
        if ( host == 0 ) {
            host = ssl_host( data, length );
        }
    }
    if ( host ) {
        int cached = lookup_cache( host );
        if ( cached < 0 ) {
            unsigned int ix = check_domain( host );
            add_cache( host, ix );
            if ( ix > 0 ) {
                // Notify "new" domain blocking
                fprintf( stderr, "%d: block %s at %s by %u\n",
                         hash_code( host ), host,
                         tell_ip( get_headerP( data ) ), ix );
                verdict = NF_DROP;
            }
        } else if ( cached > 0 ) {
            verdict = NF_DROP;
        }
    }
    return nfq_set_verdict(qh, id, verdict, 0, NULL);
}

/**
 * Program main function. Load block lists, register netfilter
 * callback and go for it.
 *
 * Every command line argument is passed to load_domains as a block
 * list. The program then binds to netfilter queue THEQUEUE with cb
 * as the packet handler and processes queued packets until recv
 * returns 0 or fails with an error other than ENOBUFS. Failure of
 * any setup step terminates the program with exit status 1.
 */
int main(int argc, char **argv) {
    // Load the database
    start_domain_database_loading();
    for ( int n = 1; n < argc; n++ ) {
        fprintf( stderr, "blockdomains loads block list %s\n", argv[ n ] );
        load_domains( argv[ n ] );
    }
    end_domain_database_loading();

    // Receive buffer: room for the largest packet copy requested
    // below (0xffff bytes) plus the netlink message overhead, so that
    // recv() does not truncate the message of a large packet.
    static char buf[0xffff + 4096] __attribute__ ((aligned));

    fprintf( stderr, "blockdomains opens library handle\n");
    struct nfq_handle *h = nfq_open();
    if ( !h ) {
        fprintf(stderr, "blockdomains error during nfq_open()\n");
        exit(1);
    }

    fprintf( stderr, "blockdomains unbinds any existing nf_queue handler\n" );
    if ( nfq_unbind_pf(h, AF_INET) < 0 ) {
        fprintf(stderr, "error during nfq_unbind_pf()\n");
        exit(1);
    }

    fprintf( stderr, "blockdomains binds as nf_queue handler\n" );
    if ( nfq_bind_pf(h, AF_INET) < 0 ) {
        fprintf(stderr, "error during nfq_bind_pf()\n");
        exit(1);
    }

#define THEQUEUE 99
    fprintf( stderr, "blockdomains registers to queue '%d'\n", THEQUEUE );
    struct nfq_q_handle *qh = nfq_create_queue( h,  THEQUEUE, &cb, NULL );
    if ( !qh ) {
        fprintf(stderr, "blockdomains error during nfq_create_queue()\n");
        exit(1);
    }

    fprintf( stderr, "blockdomains setting copy_packet mode\n" );
    if ( nfq_set_mode(qh, NFQNL_COPY_PACKET, 0xffff ) < 0) {
        fprintf(stderr, "blockdomains can't set packet_copy mode\n");
        exit(1);
    }

    int fd = nfq_fd( h );

    for ( ;; ) {
        ssize_t rv = recv( fd, buf, sizeof(buf), 0 );
        if ( rv > 0 ) {
            nfq_handle_packet( h, buf, (int) rv );
            continue;
        }
        if ( rv == 0 ) {
            break;
        }
        if ( errno == ENOBUFS ) {
            // The kernel could not enqueue packets because the socket
            // buffer was full, i.e. packets arrive faster than they
            // are handled. This is not fatal, so keep serving.
            fprintf( stderr, "blockdomains losing packets (ENOBUFS)\n" );
            continue;
        }
        fprintf( stderr, "blockdomains recv failed: %s\n",
                 strerror( errno ) );
        break;
    }

    fprintf( stderr, "blockdomains unbinding from queue %d\n", THEQUEUE);
    nfq_destroy_queue(qh);

#ifdef INSANE
    /* normally, applications SHOULD NOT issue this command, since it
       detaches other programs/sockets from AF_INET, too ! */
    fprintf( stderr, "blockdomains unbinding from AF_INET\n");
    nfq_unbind_pf(h, AF_INET);
#endif

    fprintf( stderr, "blockdomains closing library handle\n");
    nfq_close( h );

    exit( 0 );
}
