#!/usr/bin/env python3

# Copyright (C) 2007-2016 Giampaolo Rodola' <g.rodola@gmail.com>.
# Use of this source code is governed by MIT license that can be
# found in the LICENSE file.

"""
FTP server benchmark script.

In order to run this you must have a listening FTP server with a user
with writing permissions configured.
This is a stand-alone script which does not depend on pyftpdlib.
The optional psutil dependency can be installed to keep track of the
FTP server's memory usage.

Example usages:
  ftpbench -u USER -p PASSWORD
  ftpbench -u USER -p PASSWORD -H ftp.domain.com -P 21   # host / port
  ftpbench -u USER -p PASSWORD -b transfer
  ftpbench -u USER -p PASSWORD -b concurrence
  ftpbench -u USER -p PASSWORD -b all
  ftpbench -u USER -p PASSWORD -b concurrence -n 500     # 500 clients
  ftpbench -u USER -p PASSWORD -b concurrence -s 20M     # file size
  ftpbench -u USER -p PASSWORD -b concurrence -k 3521    # memory usage
"""

# Some benchmarks (Linux 3.0.0, Intel core duo - 3.1 Ghz).
# pyftpdlib 1.0.0:
#
#   (starting with 6.7M of memory being used)
#   STOR (client -> server)                              557.97 MB/sec  6.7M
#   RETR (server -> client)                             1613.82 MB/sec  6.8M
#   300 concurrent clients (connect, login)                1.20 secs    8.8M
#   STOR (1 file with 300 idle clients)                  567.52 MB/sec  8.8M
#   RETR (1 file with 300 idle clients)                 1561.41 MB/sec  8.8M
#   300 concurrent clients (RETR 10.0M file)               3.26 secs    10.8M
#   300 concurrent clients (STOR 10.0M file)               8.46 secs    12.6M
#   300 concurrent clients (QUIT)                          0.07 secs
#
#
# proftpd 1.3.4a:
#
#   (starting with 1.4M of memory being used)
#   STOR (client -> server)                              554.67 MB/sec  3.2M
#   RETR (server -> client)                             1517.12 MB/sec  3.2M
#   300 concurrent clients (connect, login)                9.30 secs    568.6M
#   STOR (1 file with 300 idle clients)                  484.11 MB/sec  570.6M
#   RETR (1 file with 300 idle clients)                 1534.61 MB/sec  570.6M
#   300 concurrent clients (RETR 10.0M file)               3.67 secs    568.6M
#   300 concurrent clients (STOR 10.0M file)              11.21 secs    568.7M
#   300 concurrent clients (QUIT)                          0.43 secs
#
#
# vsftpd 2.3.2
#
#   (starting with 352.0K of memory being used)
#   STOR (client -> server)                              607.23 MB/sec  816.0K
#   RETR (server -> client)                             1506.59 MB/sec  816.0K
#   300 concurrent clients (connect, login)               18.91 secs    140.9M
#   STOR (1 file with 300 idle clients)                  618.99 MB/sec  141.4M
#   RETR (1 file with 300 idle clients)                 1402.48 MB/sec  141.4M
#   300 concurrent clients (RETR 10.0M file)               3.64 secs    140.9M
#   300 concurrent clients (STOR 10.0M file)               9.74 secs    140.9M
#   300 concurrent clients (QUIT)                          0.00 secs

import argparse
import asynchat
import asyncore
import atexit
import contextlib
import ftplib
import os
import ssl
import sys
import time


try:
    import resource
except ImportError:
    resource = None

try:
    import psutil
except ImportError:
    psutil = None


# Connection defaults. main() overwrites these from the command line.
HOST = 'localhost'
PORT = 21
USER = None
PASSWORD = None
# Name of the scratch file uploaded to / downloaded from the server.
TESTFN = "$testfile"
# Size in bytes of each chunk sent or received on data connections.
BUFFER_LEN = 8192
# psutil.Process for the FTP server (set from -k), or None when the
# server memory is not being tracked.
SERVER_PROC = None
TIMEOUT = None
# Human-readable default; main() replaces it with a byte count obtained
# from human2bytes().
FILE_SIZE = "10M"
SSL = False
# Defined at module level so that helpers reading it (e.g.
# handle_ssl_want_rw_errs()) do not hit a NameError before main() has
# had a chance to set it.
DEBUG = False

# Stack of human-readable memory readings pushed by register_memory()
# and popped by print_bench().
server_memory = []


if sys.stdout.isatty() and os.name == 'posix':
    # https://goo.gl/6V8Rm
    def hilite(string, ok=True, bold=False):
        """Wrap 'string' in ANSI escape sequences.

        ok=True selects green, ok=False selects red and ok=None applies
        no color; bold adds the bold attribute.
        """
        codes = []
        if ok is not None:
            codes.append('32' if ok else '31')
        if bold:
            codes.append('1')
        return '\x1b[%sm%s\x1b[0m' % (';'.join(codes), string)

else:

    def hilite(s, *args, **kwargs):
        """Return 's' untouched (output is not a POSIX terminal)."""
        return s


def print_bench(what, value, unit=""):
    """Print a single benchmark result line.

    'what' is the left-aligned description, 'value' the measured number
    (shown with two decimals) and 'unit' its unit. If a memory reading
    is pending in server_memory, the most recent one is popped and
    appended to the line.
    """
    label = hilite("%-50s" % what, ok=None, bold=0)
    figure = hilite("%8.2f" % value)
    line = "%s %s %-8s" % (label, figure, unit)
    if server_memory:
        line += "%s" % hilite(server_memory.pop())
    print(line.strip())


# https://goo.gl/zeJZl
def bytes2human(n, format="%(value).1f%(symbol)s"):
    """Convert a byte count into a human-readable string.

    'n' is the number of bytes. 'format' is a %-style template which
    may reference the 'value' (scaled number) and 'symbol' (unit
    letter) keys. The largest unit (powers of 1024) not exceeding 'n'
    is used; values below 1K are reported in bytes.

    >>> bytes2human(10000)
    '9.8K'
    >>> bytes2human(100001221)
    '95.4M'
    """
    symbols = ('B', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y')
    # Walk the units from the largest ('Y') down to 'K'; the unit at
    # index i is worth 1024 ** i bytes.
    for exponent in range(len(symbols) - 1, 0, -1):
        factor = 1 << (exponent * 10)
        if n >= factor:
            # Explicit mapping instead of locals(): only the documented
            # keys are exposed to the format string.
            return format % {
                "value": float(n) / factor,
                "symbol": symbols[exponent],
            }
    return format % {"value": n, "symbol": symbols[0]}


# https://goo.gl/zeJZl
def human2bytes(s):
    """Convert a human-readable size such as '10M' into a byte count.

    's' must be a run of digits followed by exactly one unit letter
    among B, K, M, G, T, P, E, Z, Y (case-insensitive); units are
    powers of 1024.

    Raises ValueError if 's' does not have that shape.

    >>> human2bytes('1M')
    1048576
    >>> human2bytes('1G')
    1073741824
    """
    symbols = ('B', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y')
    letter = s[-1:].strip().upper()
    num = s[:-1]
    # Explicit check rather than `assert`: assertions are stripped when
    # Python runs with -O, which would let malformed input through.
    if not num.isdigit() or letter not in symbols:
        raise ValueError("invalid size %r" % (s,))
    # The unit at index i is worth 1024 ** i bytes ('B' -> 1).
    return int(num) * (1 << (symbols.index(letter) * 10))


def register_memory():
    """Register an approximation of memory used by FTP server process
    and its children (as returned by SERVER_PROC.children()).

    The reading is converted with bytes2human() and pushed onto
    server_memory. Nothing happens when SERVER_PROC is None.
    """

    # XXX How to get a reliable representation of memory being used is
    # not clear. (rss - shared) seems kind of ok but we might also use
    # the private working set via memory_maps().
    def get_mem(proc):
        # memory_info() is the same psutil API main() already relies
        # on; the legacy memory_info_ex() / get_memory_info() spellings
        # are not available in current psutil releases.
        mem = proc.memory_info()
        # 'shared' is only reported on some platforms; where it is
        # missing the plain rss is used.
        return mem.rss - getattr(mem, 'shared', 0)

    if SERVER_PROC is not None:
        total = get_mem(SERVER_PROC)
        for child in SERVER_PROC.children():
            total += get_mem(child)
        server_memory.append(bytes2human(total))


def timethis(what):
    """Utility function for making simple benchmarks (calculates time calls).
    It can be used either as a context manager or as a decorator.

    When 'what' is callable it is treated as the function to decorate
    and the returned wrapper reports the duration of every call, using
    the function object itself as the label. Otherwise 'what' is the
    label and a context manager is returned. In both cases the elapsed
    time is reported through print_bench() in seconds.
    """

    @contextlib.contextmanager
    def benchmark():
        # perf_counter() is available on every platform; the previously
        # used time.clock() was removed in Python 3.8.
        timer = time.perf_counter
        start = timer()
        yield
        stop = timer()
        res = stop - start
        print_bench(what, res, "secs")

    if callable(what):

        def timed(*args, **kwargs):
            with benchmark():
                return what(*args, **kwargs)

        return timed
    else:
        return benchmark()


def connect():
    """Open a control connection to HOST:PORT and log in.

    Returns the logged-in ftplib.FTP instance (an ftplib.FTP_TLS one,
    with the data channel switched to protected mode, when SSL is set).
    """
    client_factory = ftplib.FTP_TLS if SSL else ftplib.FTP
    session = client_factory(timeout=TIMEOUT)
    session.connect(HOST, PORT)
    session.login(USER, PASSWORD)
    if SSL:
        # secure data connection
        session.prot_p()
    return session


def retr(ftp):
    """Download TESTFN in binary mode and throw the payload away.

    Mirrors ftplib's retrbinary() without a callback: the data socket
    is read until EOF, then the final server reply is consumed.
    """
    ftp.voidcmd('TYPE I')
    data_sock = ftp.transfercmd("RETR " + TESTFN)
    with contextlib.closing(data_sock):
        # Keep reading until recv() returns an empty chunk (EOF).
        while data_sock.recv(BUFFER_LEN):
            pass
    ftp.voidresp()


def stor(ftp=None):
    """Upload dummy data to TESTFN, like ftplib's storbinary() but
    without reading from a real file.

    Chunks of BUFFER_LEN bytes are sent until at least FILE_SIZE bytes
    have gone out. When no 'ftp' session is given a temporary one is
    opened with connect() and closed with QUIT afterwards. The session
    that was used is returned.
    """
    own_session = ftp is None
    if own_session:
        ftp = connect()
    ftp.voidcmd('TYPE I')
    with contextlib.closing(ftp.transfercmd("STOR " + TESTFN)) as data_sock:
        payload = b'x' * BUFFER_LEN
        # At least one chunk is always sent, then keep going until the
        # requested size is reached.
        uploaded = data_sock.send(payload)
        while not uploaded >= FILE_SIZE:
            uploaded += data_sock.send(payload)
    ftp.voidresp()
    if own_session:
        ftp.quit()
    return ftp


def bytes_per_second(ftp, retr=True):
    """Return the number of bytes transmitted in 1 second.

    'ftp' is expected to be a logged-in ftplib.FTP instance (callers in
    this file pass the result of connect()). If 'retr' is true TESTFN
    is downloaded, and requested again whenever it is fully received
    before the second has elapsed; otherwise dummy data is uploaded to
    TESTFN. register_memory() is called once the data connection is
    open.
    """
    tot_bytes = 0
    if retr:

        def request_file():
            # Issue a binary RETR and hand back the data socket.
            ftp.voidcmd('TYPE I')
            conn = ftp.transfercmd("retr " + TESTFN)
            return conn

        with contextlib.closing(request_file()) as conn:
            register_memory()
            stop_at = time.time() + 1.0
            while stop_at > time.time():
                chunk = conn.recv(BUFFER_LEN)
                if not chunk:
                    # EOF: the whole file arrived before the deadline,
                    # so restart the download. 'a' marks when the
                    # restart began.
                    a = time.time()
                    # NOTE(review): this performs at most one extra
                    # recv() because of the unconditional break --
                    # confirm the intent.
                    while conn.recv(BUFFER_LEN):
                        break
                    conn.close()
                    ftp.voidresp()
                    conn = request_file()
                    # Push the deadline forward by the time spent
                    # restarting, so it is not counted as transfer time.
                    stop_at += time.time() - a
                tot_bytes += len(chunk)

        # closing() only closes the socket it was originally given;
        # 'conn' may have been rebound to a newer one in the loop.
        conn.close()
        try:
            ftp.voidresp()
        except (ftplib.error_temp, ftplib.error_perm):
            # The download was cut short; presumably the server may
            # answer with an error reply, which is deliberately ignored.
            pass
    else:
        ftp.voidcmd('TYPE I')
        with contextlib.closing(ftp.transfercmd("STOR " + TESTFN)) as conn:
            register_memory()
            chunk = b'x' * BUFFER_LEN
            stop_at = time.time() + 1
            while stop_at > time.time():
                # send() may accept fewer bytes than offered; count
                # only what it reports as sent.
                tot_bytes += conn.send(chunk)
        ftp.voidresp()

    return tot_bytes


def cleanup():
    """Remove TESTFN from the server, if present.

    FTP permanent/temporary errors raised while listing or deleting
    are reported on stderr instead of being propagated.
    """
    session = connect()
    try:
        listing = session.nlst()
        if TESTFN in listing:
            session.delete(TESTFN)
    except (ftplib.error_perm, ftplib.error_temp) as err:
        warning = "could not delete %r test file on cleanup: %r" % (
            TESTFN,
            err,
        )
        print(hilite(warning, ok=False), file=sys.stderr)
    session.quit()


def bench_stor(ftp=None, title="STOR (client -> server)"):
    """Measure and print the upload throughput in MB/sec.

    A new session is opened with connect() when 'ftp' is None; the
    session used is closed with QUIT at the end.
    """
    session = connect() if ftp is None else ftp
    transferred = bytes_per_second(session, retr=False)
    megabytes = transferred / 1024.0 / 1024.0
    print_bench(title, round(megabytes, 2), "MB/sec")
    session.quit()


def bench_retr(ftp=None, title="RETR (server -> client)"):
    """Measure and print the download throughput in MB/sec.

    A new session is opened with connect() when 'ftp' is None; the
    session used is closed with QUIT at the end.
    """
    session = connect() if ftp is None else ftp
    transferred = bytes_per_second(session, retr=True)
    megabytes = transferred / 1024.0 / 1024.0
    print_bench(title, round(megabytes, 2), "MB/sec")
    session.quit()


def bench_multi(howmany):
    """Run the concurrency benchmarks with 'howmany' simultaneous clients.

    Steps, in order: connect and log in all clients, single-file STOR
    and RETR while those clients sit idle, concurrent RETR, concurrent
    STOR and finally concurrent QUIT.
    """
    # The OS usually sets a limit of 1024 as the maximum number of
    # open file descriptors for the current process.
    # Let's set the highest number possible, just to be sure.
    if howmany > 500 and resource is not None:
        hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
        resource.setrlimit(resource.RLIMIT_NOFILE, (hard_limit, hard_limit))

    def connect_all():
        # Open and log in every client, timing the whole batch.
        with timethis("%i concurrent clients (connect, login)" % howmany):
            sessions = [connect() for _ in range(howmany)]
            register_memory()
        return sessions

    def retr_all(sessions):
        # Make sure the file to download exists before timing starts.
        stor(sessions[0])
        label = "%s concurrent clients (RETR %s file)" % (
            howmany,
            bytes2human(FILE_SIZE),
        )
        with timethis(label):
            for session in sessions:
                session.voidcmd('TYPE I')
                AsyncReader(session.transfercmd("RETR " + TESTFN))
            register_memory()
            asyncore.loop(use_poll=True)
        for session in sessions:
            session.voidresp()

    def stor_all(sessions):
        label = "%s concurrent clients (STOR %s file)" % (
            howmany,
            bytes2human(FILE_SIZE),
        )
        with timethis(label):
            for session in sessions:
                session.voidcmd('TYPE I')
                AsyncWriter(session.transfercmd("STOR " + TESTFN), FILE_SIZE)
            register_memory()
            asyncore.loop(use_poll=True)
        for session in sessions:
            session.voidresp()

    def quit_all(sessions):
        # The QUIT commands are queued first; only the event loop that
        # flushes them and reads the replies is timed.
        for session in sessions:
            AsyncQuit(session.sock)
        with timethis("%i concurrent clients (QUIT)" % howmany):
            asyncore.loop(use_poll=True)

    clients = connect_all()
    bench_stor(title="STOR (1 file with %s idle clients)" % len(clients))
    bench_retr(title="RETR (1 file with %s idle clients)" % len(clients))
    retr_all(clients)
    stor_all(clients)
    quit_all(clients)


@contextlib.contextmanager
def handle_ssl_want_rw_errs():
    """Context manager swallowing ssl.SSLWantReadError and
    ssl.SSLWantWriteError raised by the wrapped block.

    The swallowed error is printed when DEBUG is set.
    """
    retryable = (ssl.SSLWantReadError, ssl.SSLWantWriteError)
    try:
        yield
    except retryable as err:
        if not DEBUG:
            return
        print(err)


class AsyncReader(asyncore.dispatcher):
    """Just read data from a connected socket, asynchronously."""

    def __init__(self, sock):
        """Register the already connected 'sock' with asyncore."""
        asyncore.dispatcher.__init__(self, sock)

    def handle_read(self):
        """Read and discard one chunk; close the channel on EOF."""
        if SSL:
            # None means "nothing read": if recv() raises an
            # SSLWantRead/WriteError the context manager swallows it
            # and 'chunk' keeps this sentinel.
            chunk = None
            with handle_ssl_want_rw_errs():
                chunk = self.socket.recv(65536)
            if chunk is None:
                # Not EOF, the TLS layer just needs more I/O; wait for
                # the next read event.
                return
        else:
            chunk = self.socket.recv(65536)
        if not chunk:
            self.close()

    def handle_close(self):
        """Close the channel when asyncore reports a close event."""
        self.close()

    def handle_error(self):
        """Propagate the exception currently being handled."""
        raise  # noqa


class AsyncWriter(asyncore.dispatcher):
    """Asynchronously push dummy bytes down a connected socket until
    'size' bytes have been sent.
    """

    def __init__(self, sock, size):
        asyncore.dispatcher.__init__(self, sock)
        self.size = size
        self.sent = 0
        self.chunk = b'x' * BUFFER_LEN

    def _send_chunk(self):
        # Send one chunk and account for the bytes actually written.
        self.sent += asyncore.dispatcher.send(self, self.chunk)

    def handle_write(self):
        if not SSL:
            self._send_chunk()
        else:
            # SSLWantRead/WriteError are swallowed; in that case
            # nothing is added to the counter.
            with handle_ssl_want_rw_errs():
                self._send_chunk()
        if self.sent >= self.size:
            self.handle_close()

    def handle_error(self):
        # Propagate the exception currently being handled.
        raise  # noqa


class AsyncQuit(asynchat.async_chat):
    """Send QUIT over a control socket and close the channel once a
    CRLF-terminated line has been received.
    """

    def __init__(self, sock):
        super().__init__(sock)
        self.in_buffer = []
        self.set_terminator(b'\r\n')
        self.push(b'QUIT\r\n')

    def collect_incoming_data(self, data):
        # Accumulate whatever arrives before the terminator.
        self.in_buffer.append(data)

    def found_terminator(self):
        # A full line was received: we are done with this channel.
        self.handle_close()

    def handle_error(self):
        # Propagate the exception currently being handled.
        raise


def main():
    """Parse the command line, verify write access and run the
    requested benchmark(s).

    The parsed options are stored in the module-level settings (HOST,
    PORT, USER, ...). Before benchmarking, a test STOR/DELE of TESTFN
    is performed to make sure the user can write, and cleanup() is
    registered to run at exit. Exits with an error message if the
    file size or the benchmark name is invalid; raises ImportError if
    -k is given but psutil is not installed.
    """
    global HOST, PORT, USER, PASSWORD, SERVER_PROC
    global TIMEOUT, SSL, FILE_SIZE, DEBUG

    USAGE = (
        "%s -u USERNAME -p PASSWORD [-H] [-P] [-b] [-n] [-s] [-k] "
        "[-t] [-d] [-S]" % (os.path.basename(__file__))
    )

    parser = argparse.ArgumentParser(
        usage=USAGE,
    )

    parser.add_argument(
        '-u', '--user', dest='user', required=True, help='username'
    )
    parser.add_argument(
        '-p', '--pass', dest='password', required=True, help='password'
    )
    parser.add_argument(
        '-H', '--host', dest='host', default=HOST, help='hostname'
    )
    parser.add_argument(
        '-P', '--port', dest='port', default=PORT, type=int, help='port'
    )
    parser.add_argument(
        '-b',
        '--benchmark',
        dest='benchmark',
        default='transfer',
        help=(
            "benchmark type ('transfer', 'download', 'upload', 'concurrence',"
            " 'all')"
        ),
    )
    parser.add_argument(
        '-n',
        '--clients',
        dest='clients',
        default=200,
        type=int,
        help="number of concurrent clients used by 'concurrence' benchmark",
    )
    parser.add_argument(
        '-s',
        '--filesize',
        dest='filesize',
        default="10M",
        help="file size used by 'concurrence' benchmark (e.g. '10M')",
    )
    parser.add_argument(
        '-k',
        '--pid',
        dest='pid',
        default=None,
        type=int,
        help="the PID of the FTP server process, to track its memory usage",
    )
    parser.add_argument(
        '-t',
        '--timeout',
        dest='timeout',
        default=TIMEOUT,
        type=int,
        help="the socket timeout",
    )
    parser.add_argument(
        '-d',
        '--debug',
        action='store_true',
        dest='debug',
        help="whether to print debugging info",
    )
    parser.add_argument(
        '-S',
        '--ssl',
        action='store_true',
        dest='ssl',
        help="whether to use FTPS",
    )

    options = parser.parse_args()

    USER = options.user
    PASSWORD = options.password
    HOST = options.host
    PORT = options.port
    TIMEOUT = options.timeout
    SSL = options.ssl
    DEBUG = options.debug

    try:
        FILE_SIZE = human2bytes(options.filesize)
    except (ValueError, AssertionError):
        parser.error("invalid file size %r" % options.filesize)

    if options.pid is not None:
        if psutil is None:
            raise ImportError("-k option requires psutil module")
        SERVER_PROC = psutil.Process(options.pid)

    # before starting make sure we have write permissions
    ftp = connect()
    conn = ftp.transfercmd("STOR " + TESTFN)
    conn.close()
    ftp.voidresp()
    ftp.delete(TESTFN)
    ftp.quit()
    atexit.register(cleanup)

    # start benchmark
    if SERVER_PROC is not None:
        register_memory()
        # Consume the reading just registered: it is already in
        # human-readable form, and leaving it on the stack would make
        # a later, unrelated benchmark line print it.
        print(
            "(starting with %s of memory being used)"
            % (hilite(server_memory.pop()))
        )

    if options.benchmark == 'download':
        # The file must exist on the server before it can be retrieved.
        stor()
        bench_retr()
    elif options.benchmark == 'upload':
        bench_stor()
    elif options.benchmark == 'transfer':
        bench_stor()
        bench_retr()
    elif options.benchmark == 'concurrence':
        bench_multi(options.clients)
    elif options.benchmark == 'all':
        bench_stor()
        bench_retr()
        bench_multi(options.clients)
    else:
        sys.exit("invalid 'benchmark' parameter %r" % options.benchmark)


# Run the benchmark only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
