# Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
# See khashmir.py for license information
# Updated and modified for ABC_OKC : Old King Cole

from time import clock
from hashlib import sha1

from const import K, CONCURRENT_REQS, STORE_REDUNDANCY, SECRET_RENEW
from khash import intify, newID


class ActionBase:
    """
        base class for some long running asynchronous processes like finding nodes or values

        dht      -- the owning DHT object (subclasses call its Node/insertNode/... helpers)
        target   -- the ID being looked up; its integer form is kept in self.num
        callback -- result callback, invoked by the subclasses
        sched    -- scheduling function used by the subclasses to defer calls
    """
    def __init__(self, dht, target, callback, sched):
        self.dht, self.target = dht, target
        self.callback, self.sched = callback, sched
        self.num = intify(target)
        # per-search bookkeeping dictionaries (subclasses key them by node id)
        self.found, self.queried, self.answered = {}, {}, {}
        # number of requests in flight, and a 0/1 "search is over" flag
        self.outstanding = self.finished = 0

        def sort(a, b, num = self.num):
            """
                cmp-style comparator: orders nodes by XOR distance to the target,
                closest first (returns -1, 0 or 1)
            """
            dist_a = num ^ a.num
            dist_b = num ^ b.num
            return (dist_a > dist_b) - (dist_a < dist_b)
        self.sort = sort

    def goWithNodes(self, t):
        """
            entry point that starts the action; a no-op here, overridden by subclasses
        """
        pass


class FindNode(ActionBase):
    """
        find node action merits its own class as it is a long running stateful process

        Repeatedly queries the closest known, not-yet-queried nodes (at most
        CONCURRENT_REQS in flight) until either a node whose ID equals the target
        is found (callback gets [that_node], called immediately) or no queries are
        outstanding (callback gets the K closest nodes found, via self.sched).
    """
    def handleGotNodes(self, dict):
        """
            success callback for a find_node query.

            dict -- reply wrapper: dict['rsp'] is the response body (its 'id' and,
                    optionally, compact 'nodes' / 'nodes6' strings) and
                    dict['_krpc_sender'] is the (host, port) it came from
        """
        _krpc_sender = dict['_krpc_sender']
        rsp = dict['rsp']
        sender = self.dht.Node().init(rsp['id'], _krpc_sender[0], _krpc_sender[1])
        self.dht.insertNode(sender)
        if self.finished:
            # a day late and a dollar short
            return
        self.outstanding -= 1
        if sender.id in self.answered:
            # This reply still completes one of our queries. Reschedule so the
            # search can finish if it was the last one outstanding; returning
            # without scheduling could leave the lookup stalled forever.
            self.schedule()
            return
        self.answered[sender.id] = 1
        self._addNodes(rsp.get('nodes'), False)
        if self.dht.ipv6:
            self._addNodes(rsp.get('nodes6'), True)
        self.schedule()

    def _addNodes(self, packed, ipv6):
        """
            decode a compact node list from a reply and record every new candidate.

            Skips ourselves (by ID or by our own host/port) and nodes already found
            or queried. Undecodable input from the remote peer is ignored.
        """
        if not packed:
            return
        if ipv6:
            decode = self.dht.decodeNodeInfos6
        else:
            decode = self.dht.decodeNodeInfos
        try:
            nodes = decode(packed)
        except Exception:
            # malformed compact node info from a remote peer: best effort, ignore it
            return
        own_addr = (self.dht.host, self.dht.port)
        for node in nodes:
            # decoded entries are (id, host, port)
            nid = node[0]
            if nid == self.dht.node.id or (node[1], node[2]) == own_addr:
                continue
            if nid in self.found or nid in self.queried:
                continue
            n = self.dht.Node().init(nid, node[1], node[2])
            # learned second-hand, not contacted by us yet
            self.dht.insertNode(n, ipv6 = ipv6, contacted = 0)
            self.found[nid] = n

    def makeMsgFailed(self, node):
        """
            build the errback for a find_node query sent to `node`: it reports the
            failure to the matching routing table and continues the search.
        """
        def defaultGotNodes(err, self = self, node = node):
            if node.isIPv6():
                self.dht.table6.nodeFailed(node)
            else:
                self.dht.table.nodeFailed(node)
            self.outstanding -= 1
            self.schedule()
        return defaultGotNodes

    def schedule(self):
        """
            send messages to new peers, if necessary
        """
        if self.finished:
            return
        closest = self.found.values()
        closest.sort(self.sort)
        for node in closest[:K]:
            if node.id == self.target:
                # exact hit: report just that node, right away
                self.finished = 1
                self.found = None
                return self.callback([node])
            if node.id not in self.queried:
                df = node.findNode(self.target, self.dht.node.id)
                df.addCallbacks(self.handleGotNodes, self.makeMsgFailed(node))
                self.outstanding += 1
                self.queried[node.id] = 1
            if self.outstanding >= CONCURRENT_REQS:
                break
        assert self.outstanding >= 0
        if self.outstanding == 0:
            # all done !!
            self.finished = 1
            self.found = None
            self.sched(self.callback, 0, [closest[:K]])

    def goWithNodes(self, nodes):
        """
            this starts the process: seed the search with `nodes` (our own node is
            excluded) and send the first round of queries
        """
        for node in nodes:
            if node.id != self.dht.node.id:
                self.found[node.id] = node
        self.schedule()


class GetValue(FindNode):
    """
        get value task

        Queries the closest known nodes for values stored under `target`. Each
        batch of newly seen, successfully decoded peer infos is passed to
        `callback` (through self.sched). When no query is outstanding, `callback`
        receives an empty list and, if set, `internalcallback` receives the nodes
        that answered.
    """
    def __init__(self, dht, target, callback, internalcallback, sched):
        ActionBase.__init__(self, dht, target, callback, sched)
        self.internalcallback = internalcallback

    def handleGotNodes(self, dict):
        """
            success callback for a get_peers/get-value query.

            dict -- reply wrapper: dict['rsp'] is the response body ('id' plus
                    either 'values' or compact 'nodes' / 'nodes6') and
                    dict['_krpc_sender'] is the (host, port) it came from
        """
        _krpc_sender = dict['_krpc_sender']
        rsp = dict['rsp']
        sender = self.dht.Node().init(rsp['id'], _krpc_sender[0], _krpc_sender[1])
        self.dht.insertNode(sender)
        if self.finished:
            # a day late and a dollar short
            return
        self.outstanding -= 1
        if sender.id in self.answered:
            # This reply still completes one of our queries. Reschedule so the
            # search can finish if it was the last one outstanding.
            self.schedule()
            return
        self.answered[sender.id] = sender
        values = rsp.get('values')
        if values:
            peers = self._decodeNewPeers(values)
            # Only report non-empty batches: an empty list is what schedule()
            # sends when the search is complete.
            if peers:
                self.sched(self.callback, 0, [peers])
        else:
            # no values: go through nodes; closer ones get queried by schedule()
            self._addFoundNodes(rsp.get('nodes'), False)
            if self.dht.ipv6:
                self._addFoundNodes(rsp.get('nodes6'), True)
        self.schedule()

    def _decodeNewPeers(self, values):
        """
            return the decoded peer infos for entries of `values` not seen before.

            Every new entry is recorded in self.results, even if it cannot be
            decoded. 6-byte entries go through decodePeerInfo, and 18-byte entries
            (only when IPv6 is enabled) through decodePeerInfo6. Entries of any
            other length, or that fail to decode, are skipped. The reply's list
            itself is not modified.
        """
        decoded = []
        for peerinfo in values:
            if peerinfo in self.results:
                continue
            self.results[peerinfo] = 1
            if len(peerinfo) == 6:
                decode = self.dht.decodePeerInfo
            elif self.dht.ipv6 and len(peerinfo) == 18:
                decode = self.dht.decodePeerInfo6
            else:
                # Unknown length. The previous code re-appended the last decoded
                # peer here, or raised UnboundLocalError on the first entry.
                continue
            try:
                decoded.append(decode(peerinfo))
            except Exception:
                # malformed peer info from a remote node: skip it
                continue
        return decoded

    def _addFoundNodes(self, packed, ipv6):
        """
            decode a compact node list from a reply and record every node that is
            neither us nor already found; undecodable input is ignored.
        """
        if not packed:
            return
        if ipv6:
            decode = self.dht.decodeNodeInfos6
        else:
            decode = self.dht.decodeNodeInfos
        try:
            nodes = decode(packed)
        except Exception:
            return
        for node in nodes:
            # decoded entries are (id, host, port)
            if node[0] != self.dht.node.id and node[0] not in self.found:
                n = self.dht.Node().init(node[0], node[1], node[2])
                # NOTE(review): FindNode passes contacted = 0 here and also skips
                # our own host/port; confirm whether this path should do the same.
                self.dht.insertNode(n, ipv6 = ipv6)
                self.found[node[0]] = n

    def makeMsgFailed(self, node):
        """
            build the errback for a query sent to `node`: it reports the failure to
            the matching routing table and continues the search.
        """
        def defaultGotNodes(err, self = self, node = node):
            if node.isIPv6():
                self.dht.table6.nodeFailed(node)
            else:
                self.dht.table.nodeFailed(node)
            self.outstanding -= 1
            self.schedule()
        return defaultGotNodes

    def schedule(self):
        """
            query the K closest found nodes not yet asked, keeping at most
            CONCURRENT_REQS requests in flight; wrap up when none are outstanding.
        """
        if self.finished:
            return
        closest = self.found.values()
        closest.sort(self.sort)

        for node in closest[:K]:
            if node.id not in self.queried and node.id != self.dht.node.id:
                df = node.valueForKey(self.target, self.dht.node.id)
                df.addCallbacks(self.handleGotNodes, self.makeMsgFailed(node))
                self.outstanding += 1
                self.queried[node.id] = 1
            if self.outstanding >= CONCURRENT_REQS:
                break
        assert self.outstanding >= 0
        if self.outstanding == 0:
            # all done: callback gets an empty list (presumably its end-of-search marker)
            self.finished = 1
            self.sched(self.callback, 0, [[]])
            if self.internalcallback:
                self.sched(self.internalcallback, 0, [self.answered.values()])
            self.found = None
            self.answered = None

    def goWithNodes(self, nodes, alreadyfound = None):
        """
            start the lookup from `nodes`; peer infos in `alreadyfound` (if any) are
            treated as already reported and will not be passed to the callback again.
        """
        self.results = {}
        for peerinfo in alreadyfound or ():
            self.results[peerinfo] = 1
        for node in nodes:
            if node.id != self.dht.node.id:
                self.found[node.id] = node
        self.schedule()


class StoreValue(ActionBase):
    """
        store value task

        Sends store requests for `value` under `target` to the given nodes,
        closest first, with at most CONCURRENT_REQS in flight. It stops once
        STORE_REDUNDANCY nodes have acknowledged, or when the candidates run out
        and nothing is outstanding. Our own node and nodes with no entry in
        self.dht.tokens are skipped. The callback (if set) receives the list of
        responder IDs that acknowledged.
    """
    def __init__(self, dht, target, value, callback, sched):
        ActionBase.__init__(self, dht, target, callback, sched)
        self.value = value
        self.stored = []

    def _finish(self):
        """mark the task done, drop the remaining candidates and report the result."""
        self.finished = 1
        self.nodes = None
        if self.callback:
            self.callback(self.stored)

    def storedValue(self, dict, node):
        """
            success callback for a store request sent to `node`; returns `dict`
            unchanged so later callbacks in the chain still see the reply.
        """
        self.outstanding -= 1
        self.dht.insertNode(node)
        if self.finished:
            return dict
        self.stored.append(dict['rsp']['id'])
        if len(self.stored) >= STORE_REDUNDANCY:
            self._finish()
        elif len(self.stored) + self.outstanding < STORE_REDUNDANCY:
            # even if every in-flight request succeeds we would fall short
            self.schedule()
        return dict

    def storeFailed(self, err, node):
        """
            errback for a store request sent to `node`: report the failure to the
            matching routing table, try another node, and pass `err` along.
        """
        if node.isIPv6():
            self.dht.table6.nodeFailed(node)
        else:
            self.dht.table.nodeFailed(node)
        self.outstanding -= 1
        if self.finished:
            return err
        self.schedule()
        return err

    def schedule(self):
        """
            send store requests to the next candidates until the concurrency limit or
            the number of acknowledgements still needed is reached
        """
        if self.finished:
            return
        num = min(CONCURRENT_REQS - self.outstanding, STORE_REDUNDANCY - len(self.stored))
        if num == 0:
            self._finish()
            return
        while num > 0:
            try:
                node = self.nodes.pop(0)
            except IndexError:
                # out of candidates: finish once nothing is in flight any more
                if self.outstanding == 0:
                    self._finish()
                return
            if node.id == self.dht.node.id:
                continue
            try:
                token = self.dht.tokens[node.id]
            except KeyError:
                # no token recorded for this node: skip it
                continue
            df = node.storeValue(self.target, self.value, self.dht.node.id, token)
            df.addCallback(self.storedValue, node = node)
            df.addErrback(self.storeFailed, node = node)
            self.outstanding += 1
            num -= 1

    def goWithNodes(self, nodes):
        """
            start storing on `nodes`; a private copy is sorted closest-first and
            consumed, so the caller's list is left untouched
        """
        self.nodes = list(nodes)
        self.nodes.sort(self.sort)
        self.schedule()


class Expirer:
    """
        periodic task that calls store.expire(clock() - age): first after
        `initialdelay`, then again every `delay`, using `sched` to defer each run.

        NOTE(review): time.clock() is wall-clock time on Windows but CPU time on
        Unix under Python 2, and it was removed in Python 3.8. Confirm the store
        timestamps its entries with the same clock.
    """
    def __init__(self, store, sched, age, initialdelay, delay):
        self.store, self.sched = store, sched
        self.age, self.delay = age, delay
        # arm the first run; doExpire re-arms itself afterwards
        sched(self.doExpire, initialdelay)

    def doExpire(self):
        cutoff = clock() - self.age
        self.store.expire(cutoff)
        self.sched(self.doExpire, self.delay)


class SecretRenewer:
    """
        periodic task that rotates dht.secret every SECRET_RENEW (in `sched`'s
        delay units), keeping the previous value in dht.prevsecret.

        The new secret is a hashlib sha1 object (not a digest) seeded with a fresh
        newID(); presumably consumers .copy() it to derive values — confirm.
    """
    def __init__(self, dht, sched):
        self.dht, self.sched = dht, sched
        # arm the first rotation; doRenew re-arms itself afterwards
        sched(self.doRenew, SECRET_RENEW)

    def doRenew(self):
        dht = self.dht
        current = dht.secret
        dht.prevsecret = current
        dht.secret = sha1(newID())
        self.sched(self.doRenew, SECRET_RENEW)
