# Written by Bram Cohen
# see LICENSE.txt for license information
# Updated and modified for ABC_OKC : Old King Cole

import socket
from errno import EWOULDBLOCK, EAGAIN, ECONNREFUSED, EHOSTUNREACH
from time import clock, sleep
from sys import version_info
try:
    from select import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
    timemult = 1000
except ImportError:
    from selectpoll import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
    timemult = 1

from utility import exceptionArgsToString, getABCUtility
from random import shuffle, randrange
from natpunch import UPnP_open_port, UPnP_close_port
from buffer import newReceiveBuffer


# Poll mask used while a socket still has queued output: readable OR writable.
ALL = POLLOUT | POLLIN
# errno values that only mean "try again later" on a non-blocking socket.
WOULDBLOCK = (EWOULDBLOCK, EAGAIN)

# Message carried by the socket.error raised when UPnP port forwarding fails.
UPnP_ERROR = 'Unable to forward port via UPnP'


class SingleSocket:
    """One non-blocking socket managed by a SocketHandler.

    Holds the peer address, the queue of pending outgoing records and the
    state flags the owning SocketHandler uses when dispatching poll events
    to ``handler``.
    """
    def __init__(self, socket_handler, sock, handler, ip = None, port = None):
        """Wrap ``sock`` on behalf of ``socket_handler``.

        ip / port are fallbacks used when getpeername() fails (e.g. an
        outgoing connect_ex() that has not completed yet); when they are None
        the address is recorded as 'unknown'.
        """
        self.utility = getABCUtility()
        self.socket_handler = socket_handler
        self.socket = sock
        self.handler = handler
        # Pending output records (buffer, begin, end); end None = len(buffer).
        self.buffer = []
        # Creation / last-readable time, compared by scan_for_timeouts().
        self.last_hit = clock()
        self.fileno = sock.fileno()
        self.closed = False
        self.connected = False
        # Consecutive send attempts that made no progress (see try_write).
        self.skipped = 0
        peer = self._getpeername()
        if peer is None:
            if ip is None:
                ip = 'unknown'
            if port is None:
                port = 'unknown'
            self.ip = ip
            self.port = port
        else:
            self.ip = self._normalize_ip(peer[0])
            self.port = peer[1]

    def _getpeername(self):
        """Return self.socket.getpeername(), or None if it is unavailable.

        The result is a 2-tuple (host, port) for IPv4 but a 4-tuple
        (host, port, flowinfo, scopeid) for IPv6, so callers must index it
        rather than unpack it into two names. None is returned when the
        socket is not connected yet or has been closed (self.socket is None).
        """
        try:
            return self.socket.getpeername()
        except (socket.error, AttributeError):
            return None

    def _normalize_ip(self, ip):
        # IPv6 text addresses are expanded via the utility helper; IPv4 as-is.
        if ':' in ip:
            return self.utility.expandIPv6(ip)
        return ip

    def get_ip(self, real = False):
        """Return the peer IP; with real=True refresh it from the socket first."""
        if real:
            peer = self._getpeername()
            if peer is not None:
                self.ip = self._normalize_ip(peer[0])
        return self.ip

    def get_port(self, real = False):
        """Return the peer port; with real=True refresh it from the socket first."""
        if real:
            peer = self._getpeername()
            if peer is not None:
                self.port = peer[1]
        return self.port

    def get_dns(self, real = False):
        """Return (ip, port); with real=True refresh both from the socket first."""
        if real:
            peer = self._getpeername()
            if peer is not None:
                self.ip = self._normalize_ip(peer[0])
                self.port = peer[1]
        return (self.ip, self.port)

    def set_port(self, port):
        self.port = port

    def _release_record(self, rec):
        # Release rec's buffer only when this record extends to the buffer's
        # end (end None or == len(buffer)) - the same rule try_write applies
        # before releasing a fully sent buffer.
        if rec[2] is None or rec[2] == len(rec[0]):
            rec[0].release()

    def close(self):
        """Close the socket and detach it from its SocketHandler.

        Does nothing if already closed. Removes the single_sockets entry and
        the poll registration, closes and drops the socket, and releases the
        queued buffers (see _release_record).
        """
        if self.closed:
            return
        self.closed = True
        self.connected = False
        del self.socket_handler.single_sockets[self.fileno]
        self.socket_handler.poll.unregister(self.fileno)
        self.socket.close()
        self.socket = None
        for rec in self.buffer:
            self._release_record(rec)
        del self.buffer[:]

    def buffered(self):
        """Return the number of records still queued for output."""
        return len(self.buffer)

    def write(self, s):
        """Queue record s = (buffer, begin, end) for sending.

        If the queue was empty, a send is attempted immediately. Writing to a
        closed socket drops the record (releasing its buffer as close() would)
        instead of re-registering the stale fileno with poll.
        """
        if self.closed:
            self._release_record(s)
            return
        self.buffer.append(s)
        if len(self.buffer) == 1:
            self.try_write()

    def try_write(self):
        """Send as much queued output as the socket accepts without blocking.

        Fully sent records are dropped (their buffer is released when the
        record reaches the buffer's end); a partial send advances the head
        record's begin offset. A socket error other than WOULDBLOCK, or three
        consecutive attempts without progress, put this socket on
        socket_handler.dead_from_write. Otherwise the poll interest is set to
        POLLIN|POLLOUT while output remains queued and POLLIN when empty; this
        registration also happens before the socket is marked connected.
        """
        if self.connected:
            dead = False
            while self.buffer:
                buf, begin, end = self.buffer[0]
                if end is None:
                    end = len(buf)

                try:
                    amount = self.socket.send(buf.viewSlice(begin, end))
                except socket.error, e:
                    try:
                        dead = e[0] not in WOULDBLOCK
                    except (IndexError, TypeError):
                        # error carried no errno: treat as fatal
                        dead = True
                    self.skipped += 1
                    break

                if amount == 0:
                    self.skipped += 1
                    break
                self.skipped = 0
                if amount != end - begin:
                    # partial send: keep the unsent tail at the head
                    self.buffer[0] = (buf, begin + amount, end)
                    break
                if end == len(buf):
                    buf.release()
                del self.buffer[0]

            if self.skipped >= 3:
                dead = True
            if dead:
                self.socket_handler.dead_from_write.append(self)
                return

        if self.buffer:
            self.socket_handler.poll.register(self.fileno, ALL)
        else:
            self.socket_handler.poll.register(self.fileno, POLLIN)

    def set_handler(self, handler):
        self.handler = handler

    def set_socket_handler(self, socket_handler):
        """Move this socket from its current SocketHandler to socket_handler.

        Registers with the new handler first, then removes itself from the old
        one. Moving to the handler it already belongs to is a no-op; the
        register-then-delete sequence would otherwise delete the entry it just
        added.
        """
        if socket_handler is self.socket_handler:
            return
        socket_handler.single_sockets[self.fileno] = self
        socket_handler.poll.register(self.fileno, POLLIN)
        del self.socket_handler.single_sockets[self.fileno]
        self.socket_handler.poll.unregister(self.fileno)
        self.socket_handler = socket_handler


class SocketHandler:
    """Poll-driven owner of a listening port and its connections.

    TCP mode (udp False): listening sockets live in ``servers`` (keyed by
    fileno) and every accepted or outgoing connection is a SingleSocket in
    ``single_sockets``. UDP mode: one datagram socket, ``udpserver``, which
    set_handler() wraps in a SingleSocket.
    """
    def __init__(self, timeout, ipv6_enable, udp = False):
        self.utility = getABCUtility()
        # Idle time (measured with clock()) after which scan_for_timeouts()
        # closes a socket.
        self.timeout = timeout
        self.ipv6_enable = ipv6_enable
        self.udp = udp
        self.poll = poll()
        # {fileno: SingleSocket}
        self.single_sockets = {}
        # SingleSockets whose writes failed; closed by close_dead().
        self.dead_from_write = []
        self.max_connects = 1000
        # Port opened via UPnP by bind(); closed again by shutdown().
        self.port_forwarded = None
        # {fileno: listening socket} (TCP mode); None after a failed bind().
        self.servers = {}
        # Local addresses bound when bind() was given explicit addresses.
        self.interfaces = []
        self.udpserver = None
        # SO_SNDBUF / SO_RCVBUF values applied to each new connection.
        self.sndbuf = int(self.utility.abcparams['rawserversocksndbuf'])
        self.rcvbuf = int(self.utility.abcparams['rawserversockrcvbuf'])
        # Smaller receive buffer adopted by shrinkRcvBuf().
        self.rcvbuf_s = int(self.utility.abcparams['rawserversockrcvbufs'])
        if self.udp:
            self.socktype = socket.SOCK_DGRAM
            self.prot = 'UDP'
        else:
            self.socktype = socket.SOCK_STREAM
            self.prot = 'TCP'

    def shrinkRcvBuf(self):
        """Use the smaller receive buffer (rcvbuf_s) for future connections.

        Intended for when the torrent is completed; sockets that already
        exist keep their buffer size.
        """
        self.rcvbuf = self.rcvbuf_s

    def scan_for_timeouts(self):
        """Close every socket whose last_hit is older than self.timeout."""
        cutoff = clock() - self.timeout
        stale = [s for s in self.single_sockets.values() if s.last_hit < cutoff]
        for s in stale:
            if s.socket is not None:
                self._close_socket(s)

    def _close_servers(self):
        """Unregister from poll and close every server socket (best effort).

        Leaves the same "unbound" state the previous inline cleanup did:
        udpserver (UDP) or servers (TCP) and interfaces are set to None.
        Unregistering matters: a closed fd left in the poll set keeps being
        reported (or is silently reused by the next socket opened).
        """
        if self.udp:
            sockets = [self.udpserver]
        else:
            sockets = (self.servers or {}).values()
        for server in sockets:
            if server is None:
                continue
            try:
                self.poll.unregister(server.fileno())
            except Exception:
                # never registered (failure before register()) or already
                # closed; selectpoll's exception types are not visible here
                pass
            try:
                server.close()
            except Exception:
                pass
        if self.udp:
            self.udpserver = None
        else:
            self.servers = None
        self.interfaces = None

    def _bind_addrinfos(self, port, bind, ipv6_socket_style):
        """Return the address records bind() should try, in order.

        A non-empty ``bind`` is a comma-separated list of addresses or
        hostnames resolved with getaddrinfo (Python >= 2.2). An empty one
        yields the IPv6 wildcard (when ipv6_enable), plus the IPv4 wildcard
        unless IPv6 is enabled with ipv6_socket_style == 0.
        """
        addrinfos = []
        if bind:
            if self.ipv6_enable:
                sockfamily = socket.AF_UNSPEC
            else:
                sockfamily = socket.AF_INET
            for addr in bind.split(','):
                if version_info < (2, 2):
                    addrinfos.append((socket.AF_INET, None, None, None, (addr, port)))
                else:
                    addrinfos.extend(socket.getaddrinfo(addr, port,
                                     sockfamily, self.socktype))
        else:
            if self.ipv6_enable:
                addrinfos.append([socket.AF_INET6, None, None, None, ('', port)])
            if not addrinfos or ipv6_socket_style != 0:
                addrinfos.append([socket.AF_INET, None, None, None, ('', port)])
        return addrinfos

    def bind(self, port, bind = '', reuse = False, ipv6_socket_style = 1, upnp = 0):
        """Bind (and, for TCP, listen on) ``port`` on every address to use.

        port              -- port number (converted with int())
        bind              -- comma-separated addresses/hostnames; '' binds the
                             wildcard address(es) (see _bind_addrinfos)
        reuse             -- set SO_REUSEADDR on each server socket
        ipv6_socket_style -- 0: with IPv6 enabled, bind only the IPv6
                             wildcard socket (no separate IPv4 socket)
        upnp              -- when true, also forward the port with
                             UPnP_open_port()

        Returns udpserver in UDP mode, None in TCP mode.
        Raises socket.error when any address fails to bind, when nothing
        could be opened, or when UPnP forwarding fails; all server sockets
        opened by this handler are unregistered and closed first.
        """
        port = int(port)
        self.servers = {}
        self.interfaces = []
        addrinfos = self._bind_addrinfos(port, bind, ipv6_socket_style)
        for addrinfo in addrinfos:
            server = None
            try:
                server = socket.socket(addrinfo[0], self.socktype)
                if reuse:
                    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                server.setblocking(0)
                server.bind(addrinfo[4])
                if bind:
                    self.interfaces.append(server.getsockname()[0])
                if self.udp:
                    # NOTE(review): with several addresses in UDP mode only
                    # the last socket is kept here; earlier ones stay open
                    # and registered - confirm callers pass one address.
                    self.udpserver = server
                else:
                    self.servers[server.fileno()] = server
                    server.listen(5)
                self.poll.register(server.fileno(), POLLIN)
            except socket.error, e:
                # Must be read before cleanup: _close_servers() sets servers
                # to None, which made the 'blockedport' check unreachable.
                partially_bound = bool(self.servers)
                if server is not None and server is not self.udpserver \
                   and server not in self.servers.values():
                    # the socket that failed was never recorded, so
                    # _close_servers() would not close it
                    try:
                        server.close()
                    except Exception:
                        pass
                self._close_servers()
                if self.ipv6_enable and ipv6_socket_style == 0 and partially_bound:
                    raise socket.error(self.utility.lang.get('blockedport'))
                raise socket.error(exceptionArgsToString(e))
        if (not self.udp and not self.servers) or (self.udp and not self.udpserver):
            raise socket.error(self.utility.lang.get('cantopenserverport'))
        if upnp:
            if not UPnP_open_port(port, self.prot):
                self._close_servers()
                raise socket.error(UPnP_ERROR)
            self.port_forwarded = port
        self.port = port
        if self.udp:
            return self.udpserver

    def find_and_bind(self, minport, maxport, bind = '', reuse = False,
                      ipv6_socket_style = 1, upnp = 0, randomizer = False):
        """Bind the first port in [minport, maxport] that succeeds; return it.

        Without ``randomizer`` ports are tried in ascending order. With it, at
        most 20 ports are tried in random order. ``reuse`` and the other
        options are forwarded to bind() (reuse was previously dropped).
        Raises socket.error carrying the last bind failure.
        """
        last_error = 'Maxport less than minport - no ports to check'
        if maxport - minport < 50 or not randomizer:
            portrange = range(minport, maxport + 1)
            if randomizer:
                shuffle(portrange)
                portrange = portrange[:20]  # check a maximum of 20 ports
        else:
            # range has >= 51 ports, so 20 distinct picks always terminate
            portrange = []
            while len(portrange) < 20:
                listen_port = randrange(minport, maxport + 1)
                if listen_port not in portrange:
                    portrange.append(listen_port)
        for listen_port in portrange:
            try:
                self.bind(listen_port, bind, reuse = reuse,
                          ipv6_socket_style = ipv6_socket_style, upnp = upnp)
                return listen_port
            except socket.error, e:
                last_error = exceptionArgsToString(e)
        raise socket.error(last_error)

    def set_handler(self, handler):
        """Set the default handler; in UDP mode also wrap udpserver with it."""
        self.handler = handler
        if self.udp:
            s = SingleSocket(self, self.udpserver, handler)
            self.single_sockets[s.fileno] = s

    def start_connection_raw(self, dns, sockfamily, handler):
        """Start a non-blocking connect to address ``dns`` and track it.

        connect_ex() does not raise for ordinary connection errors; those
        surface later as poll events. The new socket is closed if setup
        fails, and non-socket exceptions are re-raised as socket.error.
        """
        sock = socket.socket(sockfamily, self.socktype)
        ok = False
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.sndbuf)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.rcvbuf)
            sock.setblocking(0)
            try:
                sock.connect_ex(dns)
            except socket.error:
                raise
            except Exception, e:
                raise socket.error(exceptionArgsToString(e))
            ok = True
        finally:
            if not ok:
                sock.close()
        s = SingleSocket(self, sock, handler, dns[0], dns[1])
        self.single_sockets[s.fileno] = s
        self.poll.register(s.fileno, POLLIN)
        return s

    def start_connection(self, dns, handler = None, randomize = False):
        """Connect to dns = (host, port), trying each resolved address in turn.

        handler defaults to self.handler. With randomize the resolved
        addresses are shuffled first. Raises socket.error when resolution
        fails or no address could be started.
        """
        if handler is None:
            handler = self.handler
        if version_info < (2, 2):
            return self.start_connection_raw(dns, socket.AF_INET, handler)
        if self.ipv6_enable:
            sockfamily = socket.AF_UNSPEC
        else:
            sockfamily = socket.AF_INET
        try:
            addrinfos = socket.getaddrinfo(dns[0], int(dns[1]),
                                           sockfamily, self.socktype)
        except socket.error:
            raise
        except Exception, e:
            raise socket.error(exceptionArgsToString(e))
        if randomize:
            shuffle(addrinfos)
        for addrinfo in addrinfos:
            try:
                return self.start_connection_raw(addrinfo[4], addrinfo[0], handler)
            except Exception:
                # try the next resolved address
                pass
        raise socket.error(self.utility.lang.get('cantconnect'))

    def _sleep(self):
        sleep(1)

    def handle_events(self, events, readsize):
        """Dispatch (fileno, eventmask) pairs returned by do_poll().

        Listening sockets: accept new connections while fewer than
        max_connects are open and report them to
        handler.external_connection_made(); a listener reporting
        POLLHUP/POLLERR is closed and forgotten.
        Connections: on POLLIN read up to readsize bytes into a receive
        buffer and hand it to the socket's handler.data_came_in(); on POLLOUT
        flush queued output and call connection_flushed() once it is empty.
        """
        for sock, event in events:
            if self.servers:
                s = self.servers.get(sock)
            else:
                # servers is None after a failed bind(): no listeners
                s = None
            if s:
                if (event & (POLLHUP | POLLERR)) != 0:
                    self.poll.unregister(s)
                    s.close()
                    del self.servers[sock]
                elif len(self.single_sockets) < self.max_connects:
                    try:
                        newsock, addr = s.accept()
                        newsock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.sndbuf)
                        newsock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.rcvbuf)
                        newsock.setblocking(0)
                        nss = SingleSocket(self, newsock, self.handler)
                        self.single_sockets[nss.fileno] = nss
                        self.poll.register(nss.fileno, POLLIN)
                        self.handler.external_connection_made(nss)
                    except socket.error:
                        self._sleep()
            else:
                s = self.single_sockets.get(sock)
                if not s:
                    continue
                s.connected = True
                if event & (POLLHUP | POLLERR):
                    if not self.udp:
                        self._close_socket(s)
                    continue
                if event & POLLIN:
                    s.last_hit = clock()
                    buf = newReceiveBuffer(readsize)
                    try:
                        nbrcvd, addr = s.socket.recvfrom_into(buf.buf, readsize)
                    except socket.error, e:
                        if not self.udp and e[0] not in WOULDBLOCK:
                            self._close_socket(s)
                        buf.release()
                        continue
                    if nbrcvd:
                        buf.setLength(nbrcvd)
                        if self.udp:
                            s.handler.data_came_in(buf, addr)
                        else:
                            s.handler.data_came_in(buf)
                    else:
                        # zero-byte read on TCP: the peer closed the connection
                        if not self.udp:
                            self._close_socket(s)
                        buf.release()
                # the handler may have closed the socket during data_came_in
                if event & POLLOUT and s.socket:
                    s.try_write()
                    if not s.buffered():
                        s.handler.connection_flushed()

    def close_dead(self):
        """Close sockets queued on dead_from_write, including ones queued
        while this runs."""
        while self.dead_from_write:
            old = self.dead_from_write
            self.dead_from_write = []
            for s in old:
                if s.socket:
                    self._close_socket(s)

    def _close_socket(self, s):
        s.close()
        s.handler.connection_lost()

    def do_poll(self, t):
        """Wait up to t seconds for events; return a list of (fileno, event).

        t is scaled by timemult (1000 for select.poll, which takes
        milliseconds). NOTE(review): a None result is presumably the
        selectpoll fallback signalling too many descriptors - confirm. In TCP
        mode it is handled by closing a random ~5% of connections and
        lowering max_connects, and [] is returned.
        """
        r = self.poll.poll(t * timemult)
        if r is None:
            if not self.udp:
                connects = len(self.single_sockets)
                to_close = int(connects * 0.05) + 1  # close 5% of sockets
                self.max_connects = connects - to_close
                closelist = list(self.single_sockets.values())
                shuffle(closelist)
                for sock in closelist[:to_close]:
                    self._close_socket(sock)
            return []
        return r

    def get_stats(self):
        return {'interfaces': self.interfaces,
                'port': self.port,
                'upnp': self.port_forwarded is not None}

    def shutdown(self):
        """Close every connection and server socket; remove any UPnP mapping.

        Tolerates a previous failed bind() (servers/udpserver already None).
        """
        for ss in list((self.single_sockets or {}).values()):
            ss.close()
        self.single_sockets = None
        self._close_servers()
        if self.port_forwarded is not None:
            UPnP_close_port(self.port_forwarded, self.prot)
