カラツバ法 正の整数範囲

#カラツバ法 正の整数範囲


import math


#多倍長整数
class Multiple:
    #数字列の各桁のリストを受け取る
    #1234 → 1 2 3 4
    #12 34 → 12 34
    def __init__(self, values):
        self.values = [v for v in reversed(values)]
        self.N = len(self.values)
        i = 1
        while(i < self.N):
            i = i * 2
        for j in range(i - self.N):
            self.values.append(0)
        self.N = len(self.values)

    def __str__(self):
        s = ''
        f = False
        for d in reversed(self.values):
            if 0 < d or f:
                s += str(d)
                f = True
        return s
    # +
    def __add__(self, a):
        self.setDigits(a)
        add = []
        for i in range(self.N):
            add.append(self.values[i] + a.values[i])
        return Multiple(list(reversed(add)))

    # -
    def __sub__(self, a):
        self.setDigits(a)
        sub = []
        for i in range(self.N):
            sub.append(self.values[i] - a.values[i])
        m = Multiple(list(reversed(sub)))
        #m.caryy()
        return m

    # <<
    def __lshift__(self, shift):
        ls = []
        #シフト後の空きスペースを0で埋める先頭の0は追加しない
        for i in range(shift):
            ls.append(0)

        #シフト後の値を下位桁から追加する
        head = 0
        for v in reversed(self.values):
            if v == 0:
                head += 1
            else:
                break
        for i in range(self.N - head):
            ls.append(self.values[i])

        #桁数を2の累乗に調整して空いたスペースは0で埋める
        size = self.N
        while(size < len(ls)):
            size *= 2
        space = size - len(ls)
        for i in range(space):
            ls.append(0)

        return Multiple(list(reversed(ls)))


    #計算する2つの値に桁数の違いがある場合、大きい桁数に揃える
    def setDigits(self, a):
        p = a if self.N <= a.N else self
        q = self if self.N <= a.N else a
        for i in range(p.N - q.N):
            q.values.append(0)
        q.N = p.N

    #selfとqについて,"桁ごとの加算"を行う
    def add(self, q):
        return self + q

    #"繰り上がり"・"繰り下がり"の処理を行う
    def caryy(self):
        #繰り上がり
        carry = 0
        for i in range(self.N):
            self.values[i] += carry
            if 10 <= self.values[i]:
                carry = int(self.values[i] / 10)
                self.values[i] = self.values[i] - (carry * 10)
            else:
                carry = 0
        if carry:
            self.values.append(carry)

        #繰り下がり
        for i in range(self.N - 1):
            if self.values[i] < 0:
                k = 0
                for j in range(i + 1, self.N):
                    if 0 < self.values[j]:
                        k = j
                        break
                if 0 < k:
                    self.values[k] -= 1
                    for n in range(k - 1, i, -1):
                        self.values[n] += 9
                    self.values[i] += 10

        return self


    #selfについてvaluesの添字が大きい方のk個の要素を返す
    def left(self, k):
        ret = Multiple(list(reversed(self.values[-k:])))
       # print('left:', ret)
        return ret

    #selfについてvaluesの添字が小さい方のk個の要素を返す
    def right(self, k):
        ret = Multiple(list(reversed(self.values[:k])))
       # print('right', ret)
        return ret

    #left(k) + right(k) の結果を返す
    def lradd(self, k):
        return self.left(k) + self.right(k)

    #selfを10^k倍する
    def shift(self, k):
        return self << k

    #selfとqについて,"桁ごとの減算"を行いself-qを返す
    def sub(self, q):
        return self - q


#カラツバ法
class Karatsuba:

    class Node:
        def __init__(self, n, val1, val2):
            self.N = n                     #最大の桁数
            self.val1 = val1               #左辺の値
            self.val2 = val2               #右辺の値
            self.result = Multiple([0,0])  #計算結果

        def printnode(self):
            print(self.N)
            print(self.val1)
            print(self.val2)
            print(self.result)

    def __init__(self, val1, val2):
        val1.setDigits(val2)
        self.root = Karatsuba.Node(val1.N, val1, val2)
        self.t_depth = int(math.log(self.root.N, 2) + 1)
        size = 0
        for i in range(self.t_depth):
            size += int(math.pow(3, i))
        self.elements = [0] * size
        self.layer_top = []

        self.newLayer_Top()
        self.newElements()
        self.bottomLayerResult()
        self.layerResult()
        self.result = self.carry()

     #ツリーの各層のエレメント配列上でのインデックスを算出する
    def newLayer_Top(self):
        self.layer_top.append(1)
        for i in range(self.t_depth - 1):
            self.layer_top.append(self.layer_top[i] + int(math.pow(3, i)))

    #ツリーを構築する
    #ルートノードを用意。桁数はval1の桁数を使う
    #ルートノードの層から、最下層以外の層を順に処理
    #親ノードになる層の要素数だけ繰り返す
    #親ノードの要素を取得
    #子ノードの桁数を算出
    #子ノード①へのインデックス
    #左:A*C
    #中央:B*D
    #右:(A+B)*(C+D)
    def newElements(self):
        self.elements[0] = self.root
        for dp in range(self.t_depth - 1):
            for i in range(int(math.pow(3, dp))):
                pe = self.elements[self.layer_top[dp] - 1 + i]
                cn = int(pe.N / 2)
                tidx = self.layer_top[dp + 1] - 1 + (i * 3)
                self.elements[tidx] = self.Node(cn, pe.val1.left(cn), pe.val2.left(cn))
                self.elements[tidx + 1] = self.Node(cn, pe.val1.right(cn), pe.val2.right(cn))
                self.elements[tidx + 2] = self.Node(cn, pe.val1.lradd(cn), pe.val2.lradd(cn))

    #最下層の計算
    def bottomLayerResult(self):
        for i in range(int(math.pow(3, self.t_depth - 1))):
            el = self.elements[self.layer_top[self.t_depth - 1] - 1 + i]
            mul = el.val1.values[0] * el.val2.values[0]
            el.result.N = 2
            el.result.values[0] = int(mul % 10)
            el.result.values[1] = int(mul / 10)

    #最下層以外の計算
    def layerResult(self):
        for dp in range(self.t_depth - 1, 0, -1):
            for i in range(int(math.pow(3, dp - 1))):
                el = self.elements[self.layer_top[dp - 1] - 1 + i]
                cidx = self.layer_top[dp] - 1 + (i * 3)
                s1 = self.elements[cidx + 2].result.sub(self.elements[cidx].result)
                s2 = s1.sub(self.elements[cidx + 1].result)
                p1 = self.elements[cidx].result.shift(el.N)
                p2 = s2.shift(int(el.N / 2))
                p3 = self.elements[cidx + 1].result
                el.result = (p1.add(p2)).add(p3)


    #キャリー処理
    def carry(self):
        return self.elements[0].result.caryy()

    def getresult(self):
        return self.root.result

    def printkara(self):
        self.root.printnode()



if __name__ == '__main__':
    #a, b = input().split()
    a = list(map(int, '123333333333321'))
    b = list(map(int, '900900900900990990990991'))
    mpl =  Multiple(a)
    mpl2 = Multiple(b)
    kara = Karatsuba(mpl, mpl2)
    kara.printkara()