# Written by Bram Cohen
# see LICENSE.txt for license information
# Updated and modified for ABC_OKC : Old King Cole

from BitTornado.CurrentRateMeasure import Measure
from BitTornado.bitfield import Bitfield
from random import shuffle
from time import clock


# Age (in time.clock() units, i.e. seconds) after which Downloader.expire_old_seeds
# forgets a remembered disconnected seed: 120 * 60 = two hours.
SEEDSEXPIRETIME = 120 * 60


class PerIPStats:
    """Statistics shared by every connection originating from one IP address."""

    def __init__(self):
        # piece index -> number of times data from this IP failed for that piece
        self.bad = {}
        # count of currently open connections from this IP
        self.numconnections = 0
        # most recent SingleDownload created for this IP, and the peer id it used
        self.lastdownload = self.peerid = None


class BadDataGuard:
    """Count bad-data failures for the IP behind one download.

    When a piece, or the number of distinct bad pieces, reaches the configured
    kick limit, the peer's connection is kicked. Reaching the ban limit also
    bans the IP. Below the ban limit, the failed piece is bumped in the
    picker instead.
    """

    def __init__(self, download):
        owner = download.downloader
        self.download = download
        self.ip = download.dns[0]
        self.downloader = owner
        self.config = owner.config
        self.stats = owner.perip[self.ip]

    def failed(self, index):
        """Record that data received for piece ``index`` was bad and react."""
        stats = self.stats
        cfg = self.config
        failures = stats.bad[index] = stats.bad.get(index, 0) + 1
        distinct = len(stats.bad)
        ban = (failures >= cfg['ban_err_per_piece']
               or distinct >= cfg['ban_bad_pieces'])
        kick = (failures >= cfg['kick_err_per_piece']
                or distinct >= cfg['kick_bad_pieces'])
        if kick:
            if self.download is not None:
                self.downloader.try_kick(self.download.connection, ban)
            elif stats.numconnections == 1 and stats.lastdownload is not None:
                # our download is gone; fall back to the IP's only live one
                self.downloader.try_kick(stats.lastdownload.connection, ban)
        if not ban:
            self.downloader.picker.bump(index)
        else:
            self.downloader.try_ban(self.ip)


class SingleDownload:
    """Download-side state for a single peer connection.

    Holds the peer's piece availability (``have``), the block requests
    currently outstanding to it (``active_requests``), choke/interest state,
    and the rate measurements used for request pipelining and snub detection.
    Most decisions are delegated to the owning ``Downloader``
    (``self.downloader``).
    """

    def __init__(self, downloader, connection):
        self.downloader = downloader
        self.connection = connection
        self.choked = True
        self.interested = False
        # (piece index, begin, length) tuples requested from this peer and
        # not yet received
        self.active_requests = []
        # rate of data received from this peer
        self.measure = Measure(downloader.window)
        # estimate of the peer's own download rate, fed from its HAVE messages
        self.peermeasure = Measure(600)
        self.have = Bitfield(downloader.numpieces)
        self.gothavebitfield = False
        self.gothave = False
        # clock() timestamps consulted by the snub checks; initialised to
        # -1000 so they start out far in the past
        self.last = self.last2 = -1000
        # a piece index this peer was last found interesting for
        self.example_interest = None
        # number of requests to keep in flight (recomputed by _backlog)
        self.backlog = 4
        self.dns = connection.get_dns()
        self.guard = BadDataGuard(self)
        self.maxmissingpiecesforrequest = self._max_missing_for_request()
        self.starttime = clock()

    def _max_missing_for_request(self):
        """Return the missing-piece threshold used by the CPU-hog check.

        Derived from config['min_comp_for_request'] and the downloader's
        *current* piece count, so it must be refreshed whenever that count
        changes.
        """
        return ((1 - self.downloader.config['min_comp_for_request'])
                * self.downloader.numpieces)

    def updateData(self, bitfield, magnethave):
        """Rebuild ``have`` from the downloader's current piece count.

        bitfield -- the peer's bitfield payload (an object with ``tostring()``)
                    or None if none was received.
        magnethave -- piece indices the peer announced; indices that are out
                      of range for the current piece count are ignored.
        A bitfield that cannot be parsed closes the connection.
        """
        # numpieces may differ from the value seen in __init__ (the Bitfield
        # below is rebuilt from it), so the derived threshold must follow.
        self.maxmissingpiecesforrequest = self._max_missing_for_request()
        if bitfield is None:
            have = Bitfield(self.downloader.numpieces)
        else:
            try:
                have = Bitfield(self.downloader.numpieces, bitfield.tostring())
            except Exception:
                # malformed bitfield: drop the peer
                self.connection.close()
                return
        for h in magnethave:
            if h < self.downloader.numpieces:
                have[h] = 1
        if self.got_have_bitfield(have):
            self.connection.upload.got_not_interested()
        # Reset delay to take into account next HAVE for the peer estimated down rate
        self.peermeasure.start = self.starttime = clock()
        if not self.downloader.paused and not self.choked and self.interested:
            self._request_more(new_unchoke = True)

    def _backlog(self, just_unchoked):
        """Compute, store in ``self.backlog`` and return the request depth.

        The depth follows the measured receive rate (expressed in chunks) and
        is bounded by the downloader's rate-limit ``queue_limit()``.
        """
        queuelimit = self.downloader.queue_limit()
        if queuelimit == 1 and just_unchoked:
            # allow a small burst right after an unchoke
            queuelimit = 4

        # twice the measured receive rate, expressed in chunks
        currentchunk = 2 * int(self.measure.get_rate() / self.downloader.chunksize)

        if queuelimit - currentchunk < 1:
            self.backlog = queuelimit
        else:
            # at most two requests beyond what the current rate justifies
            self.backlog = currentchunk + min(int(queuelimit - currentchunk), 2)

        if self.backlog > 50:
            # values above 50 collapse to 50 unless backlog * 0.075 is larger
            self.backlog = max(50, int(self.backlog * 0.075))
        return self.backlog

    def disconnected(self, shutdown = False):
        """Tear down this peer's download state when its connection closes.

        Removes the peer from the downloader, withdraws its pieces from the
        picker's availability counts (skipped in magnet mode), remembers it as
        a disconnected seed if it was complete during endgame, and releases
        its outstanding requests.
        """
        self.downloader.lost_peer(self)
        if not self.downloader.config['magnet']:
            if self.have.complete():
                self.downloader.picker.lost_seed()
            else:
                for i in xrange(len(self.have)):
                    if self.have[i]:
                        self.downloader.picker.lost_have(i)
            if self.have.complete() and self.downloader.storage.is_endgame():
                self.downloader.add_disconnected_seed(self.connection.get_dns())
        self._letgo(shutdown)
        # the guard may still be referenced (it is handed to storage); stop it
        # from kicking through this dead download
        self.guard.download = None

    def _letgo(self, shutdown = False):
        """Release this peer's outstanding requests.

        In endgame mode the requests are just forgotten locally (they remain
        listed in downloader.all_requests). Otherwise each block is handed
        back to storage. Unless paused, unchoked peers are asked to request
        more, and choked, uninterested peers holding one of the lost pieces
        become interested. Nothing is released on shutdown or in magnet mode.
        """
        if self in self.downloader.queued_out:
            del self.downloader.queued_out[self]
        if shutdown or not self.active_requests or self.downloader.config['magnet']:
            return
        if self.downloader.endgamemode:
            self.active_requests[:] = []
            return
        lost = []
        for index, begin, length in self.active_requests:
            self.downloader.storage.request_lost(index, begin, length)
            lost.append(index)
        self.active_requests[:] = []
        if self.downloader.paused:
            return
        ds = [d for d in self.downloader.downloads if not d.choked]
        shuffle(ds)
        for d in ds:
            d._request_more()
        for d in self.downloader.downloads:
            if d.choked and not d.interested:
                for l in lost:
                    if d.have[l] and self.downloader.storage.do_I_have_requests(l):
                        d.send_interested()
                        break

    def got_choke(self):
        """Handle a CHOKE from the peer: release outstanding requests."""
        if not self.choked:
            self.choked = True
            self._letgo()

    def got_unchoke(self):
        """Handle an UNCHOKE (also acted on while we consider the peer snubbed)."""
        if self.choked or self.is_snubbed():
            self.choked = False
            if self.interested and not self.downloader.config['magnet']:
                self._request_more(new_unchoke = True)
            self.last2 = clock()

    def is_choked(self):
        return self.choked

    def is_interested(self):
        return self.interested

    def send_interested(self):
        """Tell the peer we are interested (no-op if already interested)."""
        if not self.interested:
            self.interested = True
            self.connection.send_interested()
            if not self.choked:
                self.last2 = clock()

    def send_not_interested(self):
        """Tell the peer we are no longer interested (no-op if not interested)."""
        if self.interested:
            self.interested = False
            self.connection.send_not_interested()

    def got_piece(self, index, begin, piece):
        """Handle a received block.

        Returns False if the block matched no outstanding request or storage
        rejected it (piece_came_in returned false). Otherwise returns whether
        the whole piece is now stored (storage.do_I_have).
        """
        self.last = self.last2 = clock()
        length = len(piece)
        self.measure.update_rate(length)
        self.downloader.measurefunc(length)
        try:
            self.active_requests.remove((index, begin, length))
        except ValueError:
            # unrequested / already cancelled block
            self.downloader.discarded += length
            piece.release()
            return False
        if self.downloader.endgamemode:
            try:
                self.downloader.all_requests.remove((index, begin, length))
            except ValueError:
                # already served by another peer
                pass
        if not self.downloader.storage.piece_came_in(index, begin, piece, self.guard):
            self.downloader.piece_flunked(index)
            return False
        if self.downloader.storage.do_I_have(index):
            self.downloader.picker.complete(index)
        if self.downloader.endgamemode:
            # cancel the duplicate request sent to other peers
            for d in self.downloader.downloads:
                if d is not self:
                    if d.interested:
                        if d.choked:
                            assert not d.active_requests
                            d.fix_download_endgame()
                        else:
                            try:
                                d.active_requests.remove((index, begin, length))
                            except ValueError:
                                continue
                            d.connection.send_cancel(index, begin, length)
                            d.fix_download_endgame()
                    else:
                        assert not d.active_requests
        self._request_more()
        self.downloader.check_complete(index)
        return self.downloader.storage.do_I_have(index)

    def _request_more(self, new_unchoke = False):
        """Top up outstanding requests to this (unchoked) peer up to the backlog.

        Defers to fix_download_endgame in endgame mode. When a piece runs out
        of requestable blocks, other idle, interested peers re-evaluate
        whether they still have something to request.
        """
        assert not self.choked
        if self.downloader.endgamemode:
            self.fix_download_endgame(new_unchoke)
            return
        if self.downloader.paused:
            return
        if len(self.active_requests) >= self._backlog(new_unchoke):
            if not self.active_requests:
                # backlog is zero: park until the rate limit allows more
                self.downloader.queued_out[self] = 1
            return
        if self.have.complete():
            seeddownloading = None
        else:
            seeddownloading = self.downloader.has_seed_downloading()
        ################################################################
        # CPU hog trick
        # To avoid CPU hog when too many pieces in torrent and peer is incomplete with few pieces
        # while connections have at least one downloading seed
        if self.downloader.numpieces > 5000 and self.have.nbzero > self.maxmissingpiecesforrequest \
           and seeddownloading:
            return
        ################################################################
        lost_interests = []
        while len(self.active_requests) < self.backlog:
            interest = self.downloader.picker.next(self.have,
                                                   self.downloader.storage.do_I_have_requests,
                                                   self.downloader.too_many_partials(),
                                                   self.downloader.downloads,
                                                   seeddownloading)
            if interest is None:
                break
            self.example_interest = interest
            self.send_interested()
            while len(self.active_requests) < self.backlog:
                begin, length = self.downloader.storage.new_request(interest)
                self.downloader.picker.requested(interest)
                if not self.active_requests:
                    self.last = self.last2 = clock()
                self.active_requests.append((interest, begin, length))
                self.connection.send_request(interest, begin, length)
                self.downloader.chunk_requested(length)
                if not self.downloader.storage.do_I_have_requests(interest):
                    lost_interests.append(interest)
                    break
        if not self.active_requests:
            self.send_not_interested()
        if lost_interests:
            for d in self.downloader.downloads:
                if d.active_requests or not d.interested:
                    continue
                if d.example_interest is not None and self.downloader.storage.do_I_have_requests(d.example_interest):
                    continue
                for lost in lost_interests:
                    if d.have[lost]:
                        break
                else:
                    continue
                interest = self.downloader.picker.next(d.have,
                                                       self.downloader.storage.do_I_have_requests,
                                                       self.downloader.too_many_partials(),
                                                       self.downloader.downloads,
                                                       seeddownloading)
                if interest is None:
                    d.send_not_interested()
                else:
                    d.example_interest = interest
        if self.downloader.storage.is_endgame():
            self.downloader.start_endgame()

    def fix_download_endgame(self, new_unchoke = False):
        """Endgame variant of _request_more.

        Requests, in random order and up to the backlog, blocks from
        downloader.all_requests that this peer has and that it has not been
        asked for yet. Interest is updated even while choked.
        """
        if self.downloader.paused:
            return
        if len(self.active_requests) >= self._backlog(new_unchoke):
            if not self.active_requests and not self.choked:
                self.downloader.queued_out[self] = 1
            return
        want = [a for a in self.downloader.all_requests if self.have[a[0]] and a not in self.active_requests]
        if not (self.active_requests or want):
            self.send_not_interested()
            return
        if want:
            self.send_interested()
        if self.choked:
            return
        shuffle(want)
        del want[self.backlog - len(self.active_requests):]
        self.active_requests.extend(want)
        for piece, begin, length in want:
            self.connection.send_request(piece, begin, length)
            self.downloader.chunk_requested(length)

    def got_have(self, index):
        """Handle a HAVE for piece ``index``.

        Returns whether the peer is now complete. Returns False after closing
        the connection when both the peer and we are complete.
        """
        if self.gothave or clock() - self.starttime >= 30:
            # The first HAVE messages do not give valuable information about the peer download rate
            # because we don't know when the peer started to download these pieces
            # and also because these messages may be next to an incomplete BITFIELD message
            if index == self.downloader.storage.lastpiece:
                downed = self.downloader.storage.lastpiece_length
            else:
                downed = self.downloader.storage.piece_length
            self.peermeasure.update_rate(downed)
            self.gothave = True

        if not self.have[index]:
            self.have[index] = True
            self.downloader.picker.got_have(index)
            if self.have.complete():
                self.downloader.picker.became_seed()
                if self.downloader.storage.am_I_complete():
                    self.downloader.add_disconnected_seed(self.connection.get_dns())
                    self.connection.close()
                    return False
            if self.downloader.endgamemode:
                self.fix_download_endgame()
            elif (not self.downloader.paused
                  and not self.downloader.picker.is_blocked(index)
                  and self.downloader.storage.do_I_have_requests(index)):
                if not self.choked:
                    self._request_more()
                else:
                    self.send_interested()
        return self.have.complete()

    def _check_interests(self):
        """Become interested if the peer has any piece we can still request."""
        if self.interested or self.downloader.paused:
            return
        for i in xrange(len(self.have)):
            if self.have[i] and not self.downloader.picker.is_blocked(i) \
               and (self.downloader.endgamemode or self.downloader.storage.do_I_have_requests(i)):
                self.send_interested()
                return

    def got_have_bitfield(self, have):
        """Install the peer's full availability bitfield.

        Returns whether the peer is complete. Returns False after closing the
        connection when both the peer and we are complete.
        """
        if self.downloader.storage.am_I_complete() and have.complete():
            if self.downloader.super_seeding:
                self.connection.send_bitfield(have.tostring()) # be nice, show you're a seed too
            self.downloader.add_disconnected_seed(self.connection.get_dns())
            self.connection.close()
            self.gothavebitfield = True
            return False
        self.have = have
        if have.complete():
            self.downloader.picker.got_seed()
        else:
            for i in xrange(len(have)):
                if have[i]:
                    self.downloader.picker.got_have(i)
        if self.downloader.endgamemode and not self.downloader.paused:
            for piece, begin, length in self.downloader.all_requests:
                if self.have[piece]:
                    self.send_interested()
                    break
        else:
            self._check_interests()
        self.gothavebitfield = True
        return have.complete()

    def get_rate(self):
        """Return the measured rate of data received from this peer."""
        return self.measure.get_rate()

    def _check_snubbed(self):
        """Cancel long-stalled requests and return whether the peer is snubbed.

        Requests pending with no activity for more than twice snub_time are
        cancelled and released.
        """
        t = clock()
        if self.active_requests and self.interested and not self.choked \
           and t - self.last2 > 2 * self.downloader.snub_time:
            for index, begin, length in self.active_requests:
                self.connection.send_cancel(index, begin, length)
            self._letgo()
        if self.active_requests:
            return (t - self.last > self.downloader.snub_time)
        return False

    def is_snubbed(self):
        """Return True if requests are pending and nothing arrived for snub_time."""
        if self.active_requests:
            return (clock() - self.last > self.downloader.snub_time)
        return False

    def is_seed_downloading(self):
        """Truthy when the peer is a seed we are currently receiving data from."""
        return self.have.complete() and self.measure.get_rate()


class Downloader:
    """Coordinate block requests across all peer connections of a torrent.

    Creates and tracks one SingleDownload per connection and applies the
    optional download rate limit to request pipelining. It also manages
    endgame mode (``all_requests``) and keeps per-IP statistics used to kick
    and ban peers that send bad data.
    """

    def __init__(self, storage, picker, numpieces, measurefunc, kickfunc,
                 banfunc, sched, config):
        self.config = config
        self.storage = storage
        self.picker = picker
        self.window = self.config['max_rate_period']
        self.measurefunc = measurefunc
        self.numpieces = numpieces
        self.chunksize = self.config['download_slice_size']
        self.snub_time = self.config['snub_time']
        self.kickfunc = kickfunc
        self.banfunc = banfunc
        self.sched = sched
        # dns -> clock() time the seed was dropped (see expire_old_seeds)
        self.disconnectedseeds = {}
        self.downloads = []
        # ip -> PerIPStats
        self.perip = {}
        # ip -> peer id of kicked / banned peers
        self.kicked = {}
        self.banned = {}
        self.kickbans_ok = self.config['auto_kick']
        self.super_seeding = False
        self.endgamemode = False
        # pieces whose requeue is deferred while endgame requests are pending
        self.endgame_queued_pieces = []
        # in endgame mode: every outstanding (index, begin, length) request
        self.all_requests = []
        # bytes received that matched no outstanding request; plain 0 (not
        # 0L) also compiles on Python 3, and Python 2 promotes to long anyway
        self.discarded = 0
        self.download_rate = 0.
        self.bytes_requested = 0.
        self.last_time = clock()
        # downloads whose backlog was zero (e.g. rate limit reached);
        # retried by set_download_rate
        self.queued_out = {}
        self.paused = False
        self.sched(self.expire_old_seeds, 900)

    def updateData(self, numpieces):
        """Record a new piece count for the torrent."""
        self.numpieces = numpieces

    def set_download_rate(self, rate):
        """Set the download rate limit; 0 disables it.

        ``rate`` is multiplied by 1024 (presumably kB/s -> bytes/s). When the
        limit changes, every unchoked peer gets a chance to request more.
        Otherwise, only peers parked in queued_out are retried.
        """
        if rate != 0:
            rate = max(int(rate * 1024), 1)
        if rate != self.download_rate:
            if rate == 0 or self.download_rate == 0:
                # limiting switched on/off: restart the byte budget
                self.bytes_requested = 0.
            self.download_rate = rate
            ds = [d for d in self.downloads if not d.choked]
            shuffle(ds)
            self.queued_out.clear()
            for d in ds:
                d._request_more()
        elif self.queued_out:
            q = list(self.queued_out)
            shuffle(q)
            self.queued_out.clear()
            for d in q:
                d._request_more()

    def queue_limit(self):
        """Return how many more chunks may be requested under the rate limit.

        Without a limit a huge number is returned. Otherwise the byte budget
        is replenished for the elapsed time. The accumulated credit is capped
        at 2 * download_rate bytes.
        """
        if not self.download_rate:
            return 10e10    # that's a big queue !
        t = clock()
        self.bytes_requested -= (t - self.last_time) * self.download_rate
        if -self.bytes_requested > 2 * self.download_rate:
            self.bytes_requested = -2 * self.download_rate
        self.last_time = t
        if self.bytes_requested < -self.chunksize:
            return int(-self.bytes_requested / self.chunksize)
        if self.bytes_requested >= 0:
            return 0
        return 1

    def chunk_requested(self, size):
        """Charge ``size`` requested bytes against the rate-limit budget."""
        if self.download_rate:
            self.bytes_requested += size

    def make_download(self, connection):
        """Create and register a SingleDownload for ``connection``."""
        ip = connection.get_ip()
        perip = self.perip.setdefault(ip, PerIPStats())
        perip.peerid = connection.get_id()
        perip.numconnections += 1
        d = SingleDownload(self, connection)
        perip.lastdownload = d
        self.downloads.append(d)
        return d

    def piece_flunked(self, index):
        """Re-request piece ``index`` after its data failed in storage."""
        if self.paused:
            return
        if self.endgamemode:
            if self.downloads:
                while self.storage.do_I_have_requests(index):
                    nb, nl = self.storage.new_request(index)
                    self.all_requests.append((index, nb, nl))
                for d in self.downloads:
                    d.fix_download_endgame()
                return
            self._reset_endgame()
            return
        ds = [d for d in self.downloads if not d.choked]
        shuffle(ds)
        for d in ds:
            d._request_more()
        for d in self.downloads:
            if not d.interested and d.have[index]:
                d.example_interest = index
                d.send_interested()

    def has_downloaders(self):
        return len(self.downloads)

    def lost_peer(self, download):
        """Unregister ``download``; leave endgame if no peers remain."""
        ip = download.dns[0]
        self.perip[ip].numconnections -= 1
        if self.perip[ip].lastdownload is download:
            self.perip[ip].lastdownload = None
        self.downloads.remove(download)
        if self.endgamemode and not self.downloads: # all peers gone
            self._reset_endgame()

    def _reset_endgame(self):
        """Hand endgame requests back to storage and leave endgame mode."""
        self.storage.reset_endgame(self.all_requests)
        self.endgamemode = False
        self.all_requests[:] = []
        self.endgame_queued_pieces[:] = []

    def add_disconnected_seed(self, dns):
        self.disconnectedseeds[dns] = clock()

    def num_disconnected_seeds(self):
        return len(self.disconnectedseeds)

    def expire_old_seeds(self):
        """Forget seeds dropped longer than SEEDSEXPIRETIME ago; reschedule."""
        now = clock()
        expired = [dns for dns, t in self.disconnectedseeds.items()
                   if now - t > SEEDSEXPIRETIME]
        for dns in expired:
            del self.disconnectedseeds[dns]
        self.sched(self.expire_old_seeds, 900)

    def findConnection(self, dns):
        """Return the connection of the download with address ``dns``, or None."""
        for d in self.downloads:
            if d.dns == dns:
                return d.connection
        return None

    def _check_kicks_ok(self, bantrigger):
        # auto-kick must be enabled; a plain kick also requires another peer
        return self.kickbans_ok and (len(self.downloads) > 1 or bantrigger)

    def try_kick(self, connection, bantrigger, force = False, dns = None, pid = None):
        """Kick a peer if allowed (or ``force``) and remember it as kicked.

        ``connection`` may be None, in which case the peer is looked up by
        ``dns`` and ``pid`` is used as its id (call from the abc detail frame).
        """
        if force or self._check_kicks_ok(bantrigger):
            if connection:
                ip = connection.get_ip()
                pid = connection.get_id()
            else:
                # Call from abc detail frame
                connection = self.findConnection(dns)
                ip = dns[0]
            self.kicked[ip] = self.perip[ip].peerid = pid
            if connection:
                self.kickfunc(connection)

    def try_ban(self, ip):
        """Ban ``ip`` if auto kick/ban is enabled."""
        if self._check_kicks_ok(True):
            self.banfunc(ip)
            self.banned[ip] = self.perip[ip].peerid
            self.kicked.pop(ip, None)

    def close_ban(self, connection, dns = None, pid = None):
        """Unconditionally ban a peer and close its connection if known."""
        if connection:
            ip = connection.get_ip()
            pid = connection.get_id()
        else:
            # Call from abc detail frame
            connection = self.findConnection(dns)
            ip = dns[0]
        self.banfunc(ip)
        self.banned[ip] = self.perip[ip].peerid = pid
        self.kicked.pop(ip, None)
        if connection:
            self.kickfunc(connection)

    def unban(self, ip):
        self.banfunc(ip, False)
        if ip in self.banned:
            del self.banned[ip]

    def set_super_seed(self):
        self.super_seeding = True

    def check_complete(self, index):
        """Update endgame state after piece ``index`` arrived.

        Returns True when the whole torrent is complete; in that case
        connections to other seeds are closed.
        """
        if self.config['magnet']:
            return False
        if self.endgamemode and not self.all_requests:
            self.endgamemode = False
        if self.endgame_queued_pieces and not self.endgamemode:
            self.requeue_piece_download()
        if self.storage.am_I_complete():
            assert not self.all_requests
            assert not self.endgamemode
            for d in self.downloads[:]:
                if d.have.complete():
                    d.connection.send_have(index)   # be nice, tell the other seed you completed
                    self.add_disconnected_seed(d.connection.get_dns())
                    d.connection.close()
            return True
        return False

    def too_many_partials(self):
        return len(self.storage.dirty) > len(self.downloads) / 2

    def cancel_piece_download(self, pieces):
        """Cancel every outstanding request for the given piece indices."""
        if self.endgamemode:
            if self.endgame_queued_pieces:
                for piece in pieces:
                    try:
                        self.endgame_queued_pieces.remove(piece)
                    except ValueError:
                        pass
            new_all_requests = []
            for index, nb, nl in self.all_requests:
                if index in pieces:
                    self.storage.request_lost(index, nb, nl)
                else:
                    new_all_requests.append((index, nb, nl))
            self.all_requests = new_all_requests

        for d in self.downloads:
            hit = False
            for index, nb, nl in d.active_requests:
                if index in pieces:
                    hit = True
                    d.connection.send_cancel(index, nb, nl)
                    if not self.endgamemode:
                        self.storage.request_lost(index, nb, nl)
            if hit:
                d.active_requests = [r for r in d.active_requests
                                     if r[0] not in pieces]
                d._request_more()
            if not self.endgamemode and d.choked:
                d._check_interests()

    def requeue_piece_download(self, pieces = None):
        """Let peers request again after pieces become wanted.

        This is presumably the counterpart of cancel_piece_download.
        ``pieces`` is an iterable of piece indices (default: none). While
        endgame requests are still outstanding, the pieces are only
        remembered in endgame_queued_pieces. Otherwise endgame mode is left,
        the queue is consumed, and every peer re-checks interest or requests
        more.
        """
        # Copy: never alias the caller's list (or a shared default) into
        # self.endgame_queued_pieces, which is mutated later.
        pieces = [] if pieces is None else list(pieces)
        if self.endgame_queued_pieces:
            for piece in pieces:
                if piece not in self.endgame_queued_pieces:
                    self.endgame_queued_pieces.append(piece)
            pieces = self.endgame_queued_pieces
        if self.endgamemode:
            if self.all_requests:
                self.endgame_queued_pieces = pieces
                return
            self.endgamemode = False
        # The queued pieces are served by the pass below. Keep the attribute
        # a list (_reset_endgame slice-assigns it) and empty it so
        # check_complete does not trigger the same requeue again.
        self.endgame_queued_pieces = []

        ds = self.downloads[:]
        shuffle(ds)
        for d in ds:
            if d.choked:
                d._check_interests()
            else:
                d._request_more()

    def has_seed_downloading(self):
        """Return True if we are receiving data from at least one seed."""
        return any(d.is_seed_downloading() for d in self.downloads)

    def start_endgame(self):
        """Enter endgame: pool all outstanding requests and share them out."""
        assert not self.endgamemode
        self.endgamemode = True
        assert not self.all_requests
        for d in self.downloads:
            if d.active_requests:
                assert d.interested and not d.choked
            for request in d.active_requests:
                assert not request in self.all_requests
                self.all_requests.append(request)
        for d in self.downloads:
            d.fix_download_endgame()

    def pause(self, flag):
        """Pause (cancel everything) or resume downloading."""
        self.paused = flag
        if self.config['magnet']:
            return
        if flag:
            for d in self.downloads:
                for index, begin, length in d.active_requests:
                    d.connection.send_cancel(index, begin, length)
                d._letgo()
                d.send_not_interested()
            if self.endgamemode:
                self._reset_endgame()
        else:
            shuffle(self.downloads)
            for d in self.downloads:
                d._check_interests()
                if d.interested and not d.choked:
                    d._request_more()

    def shutdown(self):
        self.downloads = None
        self.storage = None
