# Written by Old King Cole
# see LICENSE.txt for license information

from BitTornado.BTcrypto import Crypto
from BitTornado.buffer import newMsgBuffer
from BitTornado.BT1.Encrypter import toint, tobinary16

# Protocol identifier expected, preceded by a one-byte length prefix, at the
# start of a BitTorrent handshake (plaintext, or inside the encrypted stream).
protocol_name = 'BitTorrent protocol'


class PeerListener:
    """Accept incoming peer connections on behalf of all started torrents.

    Every accepted connection gets its own PeerListenerConnection handler,
    which finds the target download by info-hash in ``startedtorrentdic``.
    """

    def __init__(self, rawserver, dht, utility):
        self.sched = rawserver.add_task
        self.utility = utility
        self.abcparams = utility.abcparams
        # info-hash -> download object registered through regTorrent()
        self.startedtorrentdic = {}
        # Eight reserved handshake option bytes handed to each connection
        # handler: 0x10 in byte 5 marks extended messages, 0x01 in byte 7
        # marks DHT support (only set when dht is truthy).
        dht_flag = chr(1) if dht else chr(0)
        self.optionpattern = chr(0) * 5 + chr(16) + chr(0) + dht_flag

    def external_connection_made(self, connection):
        """Attach a fresh handshake handler to a newly accepted connection."""
        connection.set_handler(
            PeerListenerConnection(self, connection, self.optionpattern))

    def regTorrent(self, dow):
        """Make ``dow`` reachable by incoming connections for its info-hash."""
        key = dow.infohash
        self.startedtorrentdic[key] = dow

    def unregTorrent(self, dow):
        """Forget ``dow``; raises KeyError if it was never registered."""
        self.startedtorrentdic.pop(dow.infohash)

    def getRawServerParams(self, time, events):
        """Return the fixed pair (0.01, True); both arguments are ignored.

        NOTE(review): the meaning of the pair is defined by the raw server
        that calls this -- not visible here.
        """
        params = (0.01, True)
        return params


class PeerListenerConnection:
    """Handshake handler for one incoming peer connection.

    Reads the BitTorrent handshake -- optionally wrapped in the encrypted
    handshake negotiated through ``Crypto`` -- up to and including the
    info-hash. It then passes the socket, any still-unparsed buffered data
    and the crypto state to the matching download's ``addPLPeer()``, and
    detaches itself.

    Parsing is a small state machine driven by ``next_len``/``next_func``.
    ``_read`` buffers incoming data and calls ``next_func`` once ``next_len``
    bytes are available. ``next_func`` returns one of:

    * ``(next_len, next_func)`` to continue;
    * ``None`` to close the connection;
    * ``False`` to detach from it without closing (used after the hand-off).
    """

    def __init__(self, encoder, connection, optionpattern):
        self.abcparams = encoder.abcparams
        self.encoder = encoder
        self.connection = connection
        self.option_pattern = optionpattern
        self.closed = False
        self.buffer = ''
        # read/write are rebound to the encrypter's methods while the
        # handshake is encrypted (see _start_crypto/_end_crypto).
        self.read = self._read
        self.write = self._write
        self.cryptmode = 0
        self.encrypter = None
        self.uses_dht = False
        self.uses_extended = False
        self.encrypted = None       # don't know yet
        # The connection is handed off right after the info-hash, so no peer
        # id is ever read here; define it so get_id() does not raise.
        self.id = None
        # First expect the one-byte length prefix plus the protocol name.
        self.next_len, self.next_func = 1 + len(protocol_name), self.read_header
        # Close connections that are still ours when this timer fires.
        encoder.sched(self._auto_close, encoder.utility.maxpeerconnectiontimeout)

    def get_ip(self, real=False):
        """Delegate to the underlying connection's get_ip()."""
        return self.connection.get_ip(real)

    def get_port(self, real=False):
        """Delegate to the underlying connection's get_port()."""
        return self.connection.get_port(real)

    def set_port(self, port):
        """Delegate to the underlying connection's set_port()."""
        self.connection.set_port(port)

    def get_dns(self, real=False):
        """Delegate to the underlying connection's get_dns()."""
        return self.connection.get_dns(real)

    def get_id(self):
        """Return self.id (None: this handler never reads a peer id)."""
        return self.id

    def is_encrypted(self):
        """Return True once an encrypted handshake has been detected."""
        return bool(self.encrypted)

    def buffered(self):
        """Delegate to the underlying connection's buffered()."""
        return self.connection.buffered()

    def _read_header(self, s):
        """Return (8, read_options) if s is the protocol header, else None."""
        if s == chr(len(protocol_name)) + protocol_name:
            return 8, self.read_options
        return None

    def read_header(self, s):
        """Handle the first 1 + len(protocol_name) bytes of the stream.

        A plaintext header continues with the 8 option bytes, unless the
        'crypto' setting is '3' (stealth). Anything else is treated as the
        start of an encrypted handshake, unless 'crypto' is '0'. In that
        case the bytes are pushed back and ``encrypter.keylength`` bytes
        are read next.
        """
        if self._read_header(s):
            if self.abcparams['crypto'] == '3':   # crypto_stealth
                return None
            return 8, self.read_options
        # Stream is encrypted
        if self.abcparams['crypto'] == '0':
            return None
        if not self.encrypted:
            self.encrypted = True
            self.encrypter = Crypto(False)
        self._write_buffer(s)
        return self.encrypter.keylength, self.read_crypto_header

    ################## ENCRYPTION SUPPORT ######################

    def _start_crypto(self):
        """Route reads/writes through the encrypter and decrypt the buffer."""
        self.encrypter.setrawaccess(self._read, self._write)
        self.write = self.encrypter.write
        self.read = self.encrypter.read
        if self.buffer:
            self.buffer = self.encrypter.decrypt(self.buffer)

    def _end_crypto(self):
        """Return to plaintext reads/writes and drop the encrypter."""
        self.read = self._read
        self.write = self._write
        self.encrypter = None

    def read_crypto_header(self, s):
        """Pass the received key bytes to the encrypter and reply.

        The reply is our pubkey plus padding. Afterwards we scan for
        ``block3a``.
        """
        self.encrypter.received_key(s)
        self.write((newMsgBuffer(self.encrypter.pubkey + self.encrypter.padding()), 0, None))
        # Budget of bytes that may be scanned while searching for block3a.
        self._max_search = 520
        return 0, self.read_crypto_block3a

    def _search_for_pattern(self, s, pat):
        """Look for ``pat`` in ``s``.

        On success, push the bytes after the match back into the buffer and
        return True.

        On failure, keep the trailing len(pat) - 1 bytes (a possible
        partial match) for the next read and charge the scanned bytes
        against ``self._max_search``. Once that budget is exhausted, close
        the connection. Either way, return False.
        """
        p = s.find(pat)
        if p < 0:
            if len(s) >= len(pat):
                self._max_search -= len(s) + 1 - len(pat)
            if self._max_search < 0:
                self.close()
                return False
            self._write_buffer(s[1 - len(pat):])
            return False
        self._write_buffer(s[p + len(pat):])
        return True

    ### INCOMING CONNECTION ###
    def read_crypto_block3a(self, s):
        """Wait until block3a shows up in the stream, then read 20 bytes."""
        if not self._search_for_pattern(s, self.encrypter.block3a):
            return -1, self.read_crypto_block3a     # wait for more data
        return 20, self.read_crypto_block3b

    def read_crypto_block3b(self, s):
        """Identify the torrent from the 20-byte encrypted download id.

        Each registered info-hash is tried with ``encrypter.test_skey()``.
        The connection is closed if none matches. On a match, encryption is
        switched on for the rest of the handshake.
        """
        for infohash in self.encoder.startedtorrentdic.keys():
            if self.encrypter.test_skey(s, infohash):
                break
        else:
            return None
        self._start_crypto()
        return 14, self.read_crypto_block3c

    def read_crypto_block3c(self, s):
        """Check the 14-byte VC / crypto mode / padding-length block.

        * s[:8] must be the all-zero verification constant.
        * s[8:12] is reduced mod 4 to the crypto mode.
        * s[12:14] is the big-endian length (at most 512) of the padding
          that follows.
        """
        if s[:8] != ('\x00' * 8):           # check VC
            return None
        try:
            self.cryptmode = toint(s[8:12]) % 4
        except Exception:
            return None
        if self.cryptmode == 0:            # no encryption selected
            return None
        # Header-only encryption (mode 1) is refused when the 'crypto'
        # setting is 2 or higher.
        if self.cryptmode == 1 and int(self.abcparams['crypto']) >= 2:
            return None
        padlen = (ord(s[12]) << 8) + ord(s[13])
        if padlen > 512:
            return None
        # Read the padding plus the two-byte IA length that follows it.
        return padlen + 2, self.read_crypto_pad3

    def read_crypto_pad3(self, s):
        """Consume the padding and IA length.

        Replies with VC, the chosen mode and PadD.
        """
        s = s[-2:]                          # trailing two bytes: len(IA)
        # Built from two bytes, ialen is at most 65535; no range check needed.
        ialen = (ord(s[0]) << 8) + ord(s[1])
        if self.cryptmode == 1:
            cryptmode = '\x00\x00\x00\x01'    # header only encryption
        else:
            cryptmode = '\x00\x00\x00\x02'    # full stream encryption
        padd = self.encrypter.padding()
        self.write((newMsgBuffer(
                   ('\x00' * 8)            # VC
                   + cryptmode             # encryption mode
                   + tobinary16(len(padd))
                   + padd), 0, None))                 # PadD
        if ialen:
            return ialen, self.read_crypto_ia
        return self.read_crypto_block3done()

    def read_crypto_ia(self, s):
        """Finish the encrypted handshake with the ialen-byte IA payload."""
        return self.read_crypto_block3done(s)

    def read_crypto_block3done(self, ia=''):
        """Leave the encrypted handshake and expect the protocol header next.

        With header-only encryption (cryptmode 1) crypto is switched off
        here. Any bytes already waiting in the buffer at this point have
        been decrypted (by _start_crypto or, presumably, encrypter.read)
        even though they should be plaintext. Such a connection is
        therefore rejected.

        The IA payload, if any, is pushed back so it is parsed as the
        protocol header.
        """
        if self.cryptmode == 1:     # only handshake encryption
            if self.buffer:         # leftover data would be mis-decrypted
                return None
            self._end_crypto()
        if ia:
            self._write_buffer(ia)
        return 1 + len(protocol_name), self.read_encrypted_header

    ### START PROTOCOL OVER ENCRYPTED CONNECTION ###
    def read_encrypted_header(self, s):
        """Check the protocol header that follows the encrypted handshake."""
        return self._read_header(s)

    ################################################

    def read_options(self, s):
        """Record the 8 reserved option bytes and the features they flag."""
        self.options = s
        if ord(self.options[7]) & 1:
            self.uses_dht = True
        if ord(self.options[5]) & 16:
            self.uses_extended = True
        return 20, self.read_download_id

    def read_download_id(self, s):
        """Hand the connection to the download registered for info-hash ``s``.

        Returns False (detach without closing) after the hand-off, or None
        (close) when no started torrent uses this info-hash.
        """
        # Is info_hash regged ?
        torrentdow = self.encoder.startedtorrentdic.get(s)
        if not torrentdow:
            return None
        # Switch to torrent and wait for completion
        torrentdow.addPLPeer(self.connection, self.buffer, self.uses_dht,
                             self.uses_extended, self.encrypted,
                             self.encrypter, self.cryptmode)
        return False

    def read_dead(self, s):
        """Parser state after an exception: close on any further data."""
        return None

    def _auto_close(self):
        """Timer callback: close the connection unless closed or handed off."""
        if not self.closed:
            self.close()

    def close(self):
        """Close the underlying connection and detach from it."""
        if self.closed:
            return
        self.connection.close()
        self.sever()

    def sever(self):
        """Detach from the connection without closing it."""
        if self.closed:
            return
        self.closed = True
        self.connection = None

    def _write(self, message):
        """Write ``message`` to the connection.

        If already closed, release the buffer in message[0] instead.
        """
        if self.closed:
            message[0].release()
            return
        self.connection.write(message)

    def data_came_in(self, s):
        """Connection callback: feed data to the current reader."""
        self.read(s)

    def _write_buffer(self, s):
        """Push ``s`` back onto the front of the unparsed buffer."""
        self.buffer = s + self.buffer

    def _read(self, s):
        """Buffer incoming data and run the parser state machine over it.

        ``s`` is a buffer object: its contents are copied and it is
        released.
        """
        if self.closed:
            s.release()
            return
        self.buffer += s[:]
        s.release()
        while True:
            if self.closed:
                return
            # self.next_len = # of characters function expects
            # or 0 = all characters in the buffer
            # or -1 = wait for next read, then all characters in the buffer
            # not compatible w/ keepalives, switch out after all negotiation complete
            if self.next_len <= 0:
                m = self.buffer
                self.buffer = ''
            elif len(self.buffer) >= self.next_len:
                m = self.buffer[:self.next_len]
                self.buffer = self.buffer[self.next_len:]
            else:
                return
            try:
                x = self.next_func(m)
            except:
                # Make sure no further data is parsed, then propagate.
                self.next_len, self.next_func = 1, self.read_dead
                raise
            if x is None:
                self.close()
                return
            if not x:
                self.sever()
                return
            self.next_len, self.next_func = x
            if self.next_len < 0:  # already checked buffer
                return             # wait for additional data

    def connection_flushed(self):
        """Connection callback; nothing to do."""
        pass

    def connection_lost(self):
        """Connection callback: the socket is gone, just detach."""
        self.sever()
