#!/usr/bin/env python

import sys
import os
import socket
import subprocess
import nmap
from optparse import OptionParser

# Running totals of peers whose UDP port was found in a non-closed state,
# split by address family; check_udp_reachability() increments them.
ip4 = 0
ip6 = 0


# ANSI escape sequences used to colour the terminal output.
ANSI_COLOR_ERR = "\x1b[31m"    # red
ANSI_COLOR_WARN = "\x1b[33m"   # yellow
ANSI_COLOR_OK = "\x1b[32m"     # green
ANSI_COLOR_RESET = "\x1b[0m"   # back to the terminal default


def error(*arg):
    """Print *arg* to stderr in the error colour.

    The colour code is handed to ``print`` as its own argument, so the
    default separator puts a space between it and the first message item.
    The colour is reset right before the trailing newline.
    """
    terminator = '%s\n' % ANSI_COLOR_RESET
    print(ANSI_COLOR_ERR, *arg, end=terminator, file=sys.stderr)


def warn(*arg):
    """Print *arg* to stderr in the warning colour.

    The colour code is handed to ``print`` as its own argument, so the
    default separator puts a space between it and the first message item.
    The colour is reset right before the trailing newline.
    """
    terminator = '%s\n' % ANSI_COLOR_RESET
    print(ANSI_COLOR_WARN, *arg, end=terminator, file=sys.stderr)


def ok(*arg):
    """Print *arg* to stderr in the success colour.

    Like the other message helpers this writes to stderr, not stdout.
    The colour code is handed to ``print`` as its own argument, so the
    default separator puts a space between it and the first message item.
    The colour is reset right before the trailing newline.
    """
    terminator = '%s\n' % ANSI_COLOR_RESET
    print(ANSI_COLOR_OK, *arg, end=terminator, file=sys.stderr)


def check_host_lookup(hostname, port):
    """Resolve *hostname*/*port* and return the ``getaddrinfo`` records.

    Any exception raised by the lookup is reported through ``error()`` and
    turned into an empty list, so callers can simply test the result for
    truthiness.
    """
    try:
        records = socket.getaddrinfo(hostname, port)
    except Exception:
        # best effort: a failed lookup must not abort the whole run
        error("DNS Lookup for {hostname} failed".format(hostname=hostname))
        records = []
    return records


def check_icmp_reachability(gai_record):
    """Ping the address of a ``getaddrinfo`` record once.

    ``gai_record[0]`` is taken as the address family and
    ``gai_record[4][0]`` as the address to ping.  ``ping`` is used for
    AF_INET records and ``ping6`` for everything else, each invoked with
    ``-c 1 -W 5``.

    Returns True when the ping process exits with status 0; otherwise an
    error message is printed and False is returned.
    """
    family, address = gai_record[0], gai_record[4][0]

    is_ipv4 = family is socket.AddressFamily.AF_INET
    ping_binary = 'ping' if is_ipv4 else 'ping6'
    command = [ping_binary, address, '-c', '1', '-W', '5']

    # stdout is piped and discarded so ping's own output stays off the
    # terminal; only the exit status matters here
    proc = subprocess.Popen(command, stdout=subprocess.PIPE)
    proc.communicate()

    reachable = proc.returncode == 0
    if not reachable:
        error("{host} is icmp unreachable".format(host=address))
    return reachable


def check_udp_reachability(gai_record):
    """Scan the UDP port of a ``getaddrinfo`` record with nmap.

    The address and port come from the first two items of
    ``gai_record[4]``, the address family from ``gai_record[0]``.  Records
    that are not AF_INET are scanned with nmap's ``-6`` flag.

    A port reported as ``closed`` counts as failure (False).  Every other
    state counts as success (True) and bumps the module-level ``ip4`` or
    ``ip6`` counter matching the record's family.
    """
    global ip4, ip6

    family = gai_record[0]
    host, port = gai_record[4][:2]
    is_ipv4 = family is socket.AddressFamily.AF_INET

    nmap_args = '-sU -PN' if is_ipv4 else '-sU -PN -6'

    # -sU requires root
    report = nmap.PortScanner().scan(host, str(port), nmap_args)
    state = report['scan'][host]['udp'][port]['state']

    message = "{host} port {port}/udp is {state}".format(
        host=host, port=port, state=state)

    if state == 'closed':
        error(message)
        return False

    ok(message)
    if is_ipv4:
        ip4 += 1
    else:
        ip6 += 1
    return True


def _parse_address(value):
    """Turn the value of an ``Address`` line into a ``(host, port)`` pair.

    *value* is the already stripped and lower-cased right-hand side of the
    line.  A bare host yields ``(host, None)``; the caller substitutes the
    file's default port later.  ``host port`` (separated by any amount of
    whitespace) yields ``(host, int(port))``.

    Returns None, after reporting the problem through ``error()``, when
    the value has more than two fields or the port is not an integer, so
    malformed entries never reach the address list.
    """
    fields = value.split()
    if len(fields) <= 1:
        return (value, None)
    if len(fields) != 2:
        error("unknown address format")
        return None
    try:
        return (fields[0], int(fields[1]))
    except ValueError:
        error("non-integer port given")
        return None


def _parse_host_lines(lines):
    """Collect the ``(host, port)`` pairs declared in one tinc host file.

    *lines* is any iterable of text lines.  Lines starting with ``#`` and
    everything from the RSA public key BEGIN marker up to the END marker
    are skipped, as is any line that does not split into exactly one
    ``key = value`` pair.  ``Port`` sets the default port, ``Address``
    adds a peer address, the ECDSA/Ed25519 public key entries are accepted
    silently and every other key is reported through ``error()``.
    """
    inside_rsa_key = False
    default_port = 655  # tinc default port
    parsed = []

    for line in lines:
        if '-----BEGIN RSA PUBLIC KEY-----' in line:
            inside_rsa_key = True
        elif '-----END RSA PUBLIC KEY-----' in line:
            inside_rsa_key = False

        if inside_rsa_key or line.startswith("#"):
            continue

        chunks = line.split("=")
        if len(chunks) != 2:
            continue
        key, value = (chunk.strip().lower() for chunk in chunks)

        if key == "port":
            try:
                default_port = int(value)
            except ValueError:
                error("non-integer default port given")
        elif key == "address":
            pair = _parse_address(value)
            if pair is not None:
                parsed.append(pair)
        elif key not in ('ecdsapublickey', 'ed25519publickey'):
            error("unknown key {key} with value {val}"
                  .format(key=key, val=value))

    # A Port line may follow the Address lines it applies to, so the
    # default is filled in only after the whole file has been read.
    return [(host, default_port if port is None else port)
            for host, port in parsed]


def get_hosts_data(srcdir):
    """Yield the peer addresses of every tinc host file in *srcdir*.

    Regular files are visited in sorted name order; names starting with a
    dot and anything that is not a regular file are skipped.  For each
    file a dict with two keys is yielded:

    ``community``
        the file name
    ``addresses``
        a list of ``(host, port)`` tuples with ``port`` always an int
        (the file's ``Port`` value, or 655 if none is given, for addresses
        without an explicit port)

    Malformed ``Address`` lines are reported through ``error()`` and left
    out of the list.  Each file is closed before its entry is yielded.
    """
    for fname in sorted(os.listdir(srcdir)):
        if fname.startswith("."):
            continue

        fpath = os.path.join(srcdir, fname)
        if not os.path.isfile(fpath):
            continue

        with open(fpath) as f:
            addresses = _parse_host_lines(f)

        yield {"community": fname, "addresses": addresses}


def do_checks(srcdir):
    """Run the DNS, ICMP and UDP checks for every host file in *srcdir*.

    For each entry produced by ``get_hosts_data()`` every address is
    resolved.  Each resulting UDP (SOCK_DGRAM) record is pinged and, only
    if the ping succeeds, port-scanned.  A failed lookup, ping or port
    scan each adds one error; an entry without addresses adds one warning.

    A summary is printed at the end.  Returns 0 when no error was counted
    and 1 otherwise, which makes the result usable as a process exit
    status.
    """
    global ip4, ip6

    errors = 0
    warnings = 0

    for entry in get_hosts_data(srcdir):
        print("Checking {community}".format(community=entry['community']))

        addresses = entry['addresses']
        if not addresses:
            warn("no addresses specified")
            warnings += 1

        for hostname, port in addresses:
            # dns lookup
            records = check_host_lookup(hostname, port)
            if not records:
                errors += 1
                continue

            for record in records:
                # vpn connections are udp based, so skip everything else
                if record[1] is not socket.SOCK_DGRAM:
                    continue

                # the port scan runs only for hosts that answered the ping
                reachable = (check_icmp_reachability(record)
                             and check_udp_reachability(record))
                if not reachable:
                    errors += 1

    print("\nfound {}/{} working ipv4/ipv6 peers".format(ip4, ip6))

    error("{} errors".format(errors))
    warn("{} warnings".format(warnings))

    return 1 if errors else 0


if __name__ == "__main__":
    # Script entry point: read the host directory from the command line,
    # run all checks and use their result as the process exit status.
    cli = OptionParser()
    cli.add_option(
        "-s", "--sourcedir",
        dest="src",
        metavar="DIR",
        default="../hosts/",
        help="Location of tinc host files. Default: ../hosts",
    )

    opts, _positional = cli.parse_args()

    sys.exit(do_checks(opts.src))