"""Implementation of AES (Advanced Encryption Standard).  See FIPS 197."""

class Object(object):
    """Minimal attribute container.

    Every positional argument is stored as an attribute named after its own
    __name__ (e.g. a function), and every keyword argument as an attribute
    named after the keyword.  A positional argument wins over a keyword of
    the same name.
    """
    def __init__(self, *args, **kw):
        for named in args:
            kw[named.__name__] = named
        # The keyword dict itself becomes the instance namespace.
        self.__dict__ = kw

# S-box array (FIPS 197 page 16).
def unhex(hex):
    """Return the integer value of the hexadecimal string 'hex' (e.g. '7c')."""
    value = int(hex, 16)
    return value

# Built as a real list (not a lazy map object) so that S[i] indexing works on
# Python 3 as well as Python 2.  Each entry is parsed as a hex byte.
S = [int(_tok, 16) for _tok in '''
    63 7c 77 7b f2 6b 6f c5 30 01 67 2b fe d7 ab 76
    ca 82 c9 7d fa 59 47 f0 ad d4 a2 af 9c a4 72 c0
    b7 fd 93 26 36 3f f7 cc 34 a5 e5 f1 71 d8 31 15
    04 c7 23 c3 18 96 05 9a 07 12 80 e2 eb 27 b2 75
    09 83 2c 1a 1b 6e 5a a0 52 3b d6 b3 29 e3 2f 84
    53 d1 00 ed 20 fc b1 5b 6a cb be 39 4a 4c 58 cf
    d0 ef aa fb 43 4d 33 85 45 f9 02 7f 50 3c 9f a8
    51 a3 40 8f 92 9d 38 f5 bc b6 da 21 10 ff f3 d2
    cd 0c 13 ec 5f 97 44 17 c4 a7 7e 3d 64 5d 19 73
    60 81 4f dc 22 2a 90 88 46 ee b8 14 de 5e 0b db
    e0 32 3a 0a 49 06 24 5c c2 d3 ac 62 91 95 e4 79
    e7 c8 37 6d 8d d5 4e a9 6c 56 f4 ea 65 7a ae 08
    ba 78 25 2e 1c a6 b4 c6 e8 dd 74 1f 4b bd 8b 8a
    70 3e b5 66 48 03 f6 0e 61 35 57 b9 86 c1 1d 9e
    e1 f8 98 11 69 d9 8e 94 9b 1e 87 e9 ce 55 28 df
    8c a1 89 0d bf e6 42 68 41 99 2d 0f b0 54 bb 16
'''.split()]

# Inverse S-box array: InvS[S[i]] == i for every byte value i, built by
# scattering each index i into the slot selected by S[i].
InvS = [0] * 256
for i, _image in enumerate(S):
    InvS[_image] = i

# Conversions between integers, bytes, words, and blocks.
def byte(n):
    """Return n reduced modulo 256, i.e. the low 8 bits of an integer n."""
    modulus = 0x100
    return n % modulus

def word(n):
    """Return n reduced modulo 2**32, i.e. the low 32 bits of an integer n."""
    modulus = 0x100000000
    return n % modulus

def toword(a0, a1, a2, a3):
    """Pack four bytes, most significant first, into one 32-bit word.

    Each argument is reduced modulo 256 first, so only its low 8 bits count.
    """
    result = 0
    for part in (a0, a1, a2, a3):
        result = (result << 8) | (part % 256)
    return result

def tobytes(w):
    """Split a word into a tuple of four bytes, most significant first.

    Only the low 32 bits of w contribute to the result.
    """
    return tuple((w >> shift) % 256 for shift in (24, 16, 8, 0))

def tolong(block):
    """Convert a block of bytes, most significant first, to an integer.

    'block' may be any iterable of integers -- a list, tuple, bytearray or a
    lazy iterator such as a Python 3 map object.  The original version needed
    len(block) and therefore rejected iterators.  An empty block gives 0.
    """
    value = 0
    for b in block:
        # Horner's rule: move the bytes seen so far up one position, add b.
        value = (value << 8) + b
    return value

def toblock(long, blockbytes):
    """Convert an integer to a list of 'blockbytes' bytes, most significant first.

    Only the low 8*blockbytes bits of the integer are represented.
    """
    out = []
    remaining = long
    for _ in range(blockbytes):
        out.append(remaining % 256)
        remaining >>= 8
    # Bytes were collected least significant first.
    out.reverse()
    return out

def blockxor(a, b):
    """XOR two byte sequences element-wise into a list.

    The result is as long as the shorter input (zip semantics).
    """
    out = []
    for left, right in zip(a, b):
        out.append(left ^ right)
    return out

# Word manipulation routines.
def SubWord(w):
    """Substitute each of the four bytes of word w through the S-box."""
    substituted = [S[b] for b in tobytes(w)]
    return toword(*substituted)

def RotWord(w):
    """Rotate a word left by one byte: [a0, a1, a2, a3] -> [a1, a2, a3, a0].

    Bits of w above the low 32 are ignored.
    """
    low = w & 0xFFFFFFFF
    return ((low << 8) & 0xFFFFFFFF) | (low >> 24)

def reduce(n):
    """Reduce a polynomial modulo x^8 + x^4 + x^3 + x + 1."""
    top, divisor = 0x100, 0x11B  # 0x11B encodes x^8 + x^4 + x^3 + x + 1
    # Scale the divisor up until its leading term is above every bit of n.
    while top <= n:
        top <<= 1
        divisor <<= 1
    # Long division over GF(2): clear each set bit from the top down to x^8.
    while top >= 0x100:
        if n & top:
            n ^= divisor
        top >>= 1
        divisor >>= 1
    return n

def mult(a, b):
    """Multiply two polynomials modulo x^8 + x^4 + x^3 + x + 1.

    Every set bit of a contributes a shifted copy of b.  The copies are
    combined with XOR (polynomial addition over GF(2)), and reduce() brings
    the result below degree 8.
    """
    acc = 0
    mask = 1
    while mask <= a:
        if mask & a:
            acc ^= mask * b
        mask <<= 1
    return reduce(acc)

# Round constants: Rcon[i] is the word [x^(i-1), {00}, {00}, {00}], with the
# power of x computed by repeated doubling modulo the AES polynomial.
# Rcon[0] is an unused placeholder so that indices match FIPS 197.
Rcon = [0]
x = 1
for i in range(1, 256):
    if i > 1:
        x = reduce(x << 1)  # multiply the previous power by x, i.e. {02}
    Rcon.append(toword(x, 0, 0, 0))

# Key expansion according to FIPS 197, page 19.
def KeyExpansion(K, Nk, Nb, Nr):
    """Expand the cipher key K into the key schedule (FIPS 197, Figure 11).

    K is the key, given as one of:
      * a string of exactly 4*Nk bytes (a bytes/bytearray object, or a text
        string whose characters all have code points below 256);
      * a non-negative integer below 2**(32*Nk), read big-endian so that its
        most significant word becomes w[0].
    Nk is the key length in words, Nb the block length in words and Nr the
    number of rounds.

    Returns a list of Nb*(Nr + 1) 32-bit words.

    Raises ValueError for a string key of the wrong length or with characters
    outside 0-255, and for an integer key that is negative or too large.
    Silently zero-padding or truncating such keys would encrypt under a
    different key than the caller supplied.
    """
    # On Python 2 'bytes' is 'str', so byte strings take the first branch on
    # both versions; the second branch handles Python 3 text strings.
    if isinstance(K, (bytes, bytearray)):
        octets = list(bytearray(K))
    elif isinstance(K, str):
        octets = [ord(ch) for ch in K]
    else:
        octets = None

    if octets is not None:
        if len(octets) != 4 * Nk:
            raise ValueError('key must be %d bytes long, got %d'
                             % (4 * Nk, len(octets)))
        if any(o > 0xff for o in octets):
            raise ValueError('key characters must be in range 0-255')
        K = tolong(octets)
    elif not 0 <= K < (1 << (32 * Nk)):
        raise ValueError('integer key must be in range [0, 2**%d)' % (32 * Nk))

    # The first Nk words are the key itself, most significant word first.
    w = [word(K >> (32 * (Nk - 1 - j))) for j in range(Nk)]

    # The schedule holds Nb words for each of the Nr + 1 round keys.
    for i in range(Nk, Nb * (Nr + 1)):
        temp = w[i - 1]
        quotient, remainder = divmod(i, Nk)
        if remainder == 0:
            temp = SubWord(RotWord(temp)) ^ Rcon[quotient]
        elif Nk > 6 and remainder == 4:
            # Keys longer than six words (AES-256) get an extra SubWord here.
            temp = SubWord(temp)
        w.append(w[i - Nk] ^ temp)
    return w

# Transformations used in the cipher.
def SubBytes(state):
    """Return a new list with every byte of the state replaced by S[byte]."""
    return list(map(S.__getitem__, state))

def InvSubBytes(state):
    """Return a new list with every byte of the state replaced by InvS[byte]."""
    return list(map(InvS.__getitem__, state))

def ShiftRows(state):
    """Cyclically shift the rows of a 16-byte state (FIPS 197 Sec. 5.1.2).

    The state is laid out column by column, so byte index r + 4*c holds row r
    of column c.  Row r is rotated left by r positions.  Returns a new list.
    Raises ValueError unless the state has exactly 16 elements.
    """
    # Unpacked in the body: tuple parameters are a SyntaxError on Python 3
    # (PEP 3113).  The behaviour on Python 2 is unchanged.
    (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) = state
    return [a, f, k, p, e, j, o, d, i, n, c, h, m, b, g, l]

def InvShiftRows(state):
    """Undo ShiftRows on a 16-byte state (FIPS 197 Sec. 5.3.1).

    With the column-by-column layout (index r + 4*c), row r is rotated right
    by r positions.  Returns a new list.  Raises ValueError unless the state
    has exactly 16 elements.
    """
    # Unpacked in the body: tuple parameters are a SyntaxError on Python 3
    # (PEP 3113).  The behaviour on Python 2 is unchanged.
    (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p) = state
    return [a, n, k, h, e, b, o, l, i, f, c, p, m, j, g, d]

def MixColumn(column):
    """Mix one 4-byte state column (FIPS 197 Sec. 5.1.3).

    Multiplies the column over GF(2^8) by the fixed matrix with rows
    [2 3 1 1], [1 2 3 1], [1 1 2 3], [3 1 1 2].  Returns a list of 4 bytes.
    """
    # Unpacked in the body: tuple parameters are a SyntaxError on Python 3
    # (PEP 3113).  The behaviour on Python 2 is unchanged.
    c0, c1, c2, c3 = column
    m0 = mult(c0, 2) ^ mult(c1, 3) ^ c2 ^ c3
    m1 = c0 ^ mult(c1, 2) ^ mult(c2, 3) ^ c3
    m2 = c0 ^ c1 ^ mult(c2, 2) ^ mult(c3, 3)
    m3 = mult(c0, 3) ^ c1 ^ c2 ^ mult(c3, 2)
    return [m0, m1, m2, m3]

def MixColumns(state):
    """Apply MixColumn to each of the four 4-byte columns of a 16-byte state."""
    columns = [MixColumn(state[start:start + 4]) for start in (0, 4, 8, 12)]
    return columns[0] + columns[1] + columns[2] + columns[3]

def InvMixColumn(column):
    """Undo MixColumn on one 4-byte state column (FIPS 197 Sec. 5.3.3).

    Multiplies the column over GF(2^8) by the fixed matrix with rows
    [14 11 13 9], [9 14 11 13], [13 9 14 11], [11 13 9 14].
    Returns a tuple of 4 bytes; note that MixColumn returns a list.
    """
    # Unpacked in the body: tuple parameters are a SyntaxError on Python 3
    # (PEP 3113).  The behaviour on Python 2 is unchanged.
    c0, c1, c2, c3 = column
    m0 = mult(c0, 14) ^ mult(c1, 11) ^ mult(c2, 13) ^ mult(c3, 9)
    m1 = mult(c0, 9) ^ mult(c1, 14) ^ mult(c2, 11) ^ mult(c3, 13)
    m2 = mult(c0, 13) ^ mult(c1, 9) ^ mult(c2, 14) ^ mult(c3, 11)
    m3 = mult(c0, 11) ^ mult(c1, 13) ^ mult(c2, 9) ^ mult(c3, 14)
    return m0, m1, m2, m3

def InvMixColumns(state):
    """Apply InvMixColumn to each of the four 4-byte columns of a 16-byte state.

    The pieces are concatenated with '+', so the result has the sequence type
    that InvMixColumn returns.
    """
    columns = [InvMixColumn(state[start:start + 4]) for start in (0, 4, 8, 12)]
    return columns[0] + columns[1] + columns[2] + columns[3]

def AddRoundKey(state, key):
    """XOR the state with the bytes of the given round-key words.

    Each word in 'key' is split into four bytes, most significant first.
    The result is a list as long as the shorter of state and the key bytes.
    """
    key_bytes = [octet for k in key for octet in tobytes(k)]
    return blockxor(state, key_bytes)

def AES(keybits, K):
    """Create a cipher object with a given key K.

    'keybits' must be 128, 192 or 256 and selects the AES variant.  The key
    may be given as a long integer or as a string of keybits/8 bytes; it is
    handed to KeyExpansion to build the round-key schedule.

    The returned object has attributes 'keybits' and 'blockbytes' (16) and
    two methods, encipher and decipher.  Each takes one block -- a string of
    16 characters or a sequence of 16 integer bytes -- and returns the
    transformed block as a string for string input, else as a list of ints.

    Raises ValueError for an unsupported key length.  encipher and decipher
    raise ValueError for a block that is not exactly 16 bytes long.
    """
    # Nk: key length in words, Nb: block length in words, Nr: number of rounds.
    if keybits == 128:
        Nk, Nb, Nr = 4, 4, 10
    elif keybits == 192:
        Nk, Nb, Nr = 6, 4, 12
    elif keybits == 256:
        Nk, Nb, Nr = 8, 4, 14
    else:
        raise ValueError('invalid key length %d' % keybits)
    blockbytes = 4 * Nb

    # Prepare the key schedule: Nb words for each of the Nr + 1 round keys.
    w = KeyExpansion(K, Nk, Nb, Nr)

    def round_key(r):
        """Return the Nb schedule words that form the key for round r."""
        return w[r * Nb:(r + 1) * Nb]

    def load(block):
        """Convert an input block to a list of integer bytes of checked length.

        Without the length check, the zip inside AddRoundKey would silently
        drop any bytes beyond the first 16 and encrypt only part of the input.
        """
        if isinstance(block, str):
            state = [ord(ch) for ch in block]
        else:
            state = list(block)
        if len(state) != blockbytes:
            raise ValueError('block must be %d bytes, got %d'
                             % (blockbytes, len(state)))
        return state

    def store(state, like):
        """Return the state as a string when the caller's input was a string."""
        if isinstance(like, str):
            return ''.join([chr(b) for b in state])
        return state

    def encipher(input):  # Argument is a string or list of 4*Nb bytes.
        """Encrypt one block (FIPS 197 Cipher, Figure 5)."""
        state = AddRoundKey(load(input), round_key(0))
        for rnd in range(1, Nr):
            state = SubBytes(state)
            state = ShiftRows(state)
            state = MixColumns(state)
            state = AddRoundKey(state, round_key(rnd))
        # The final round omits MixColumns.
        state = SubBytes(state)
        state = ShiftRows(state)
        state = AddRoundKey(state, round_key(Nr))
        return store(state, input)

    def decipher(input):  # Argument is a string or list of 4*Nb bytes.
        """Decrypt one block (FIPS 197 InvCipher, Figure 12)."""
        state = AddRoundKey(load(input), round_key(Nr))
        for rnd in range(Nr - 1, 0, -1):
            state = InvShiftRows(state)
            state = InvSubBytes(state)
            state = AddRoundKey(state, round_key(rnd))
            state = InvMixColumns(state)
        # The final round omits InvMixColumns.
        state = InvShiftRows(state)
        state = InvSubBytes(state)
        state = AddRoundKey(state, round_key(0))
        return store(state, input)

    return Object(encipher, decipher, keybits=keybits, blockbytes=blockbytes)

def pad(data, blockbytes):
    """Pad a string to a whole number of blocks.

    Appends one '\\x80' marker byte, then just enough '\\x00' bytes to make
    the total length a multiple of blockbytes.  At least one byte is always
    added, so already-aligned input grows by a whole block.
    """
    zeros = -(len(data) + 1) % blockbytes
    return data + '\x80' + '\x00' * zeros

def toblocks(padded, blockbytes):
    """Convert a string to a list of blocks (a block is a list of bytes).

    Each block holds the ord() values of 'blockbytes' consecutive characters.
    Real lists are built (rather than map objects) so that callers can index
    and re-iterate blocks on Python 3 as well as Python 2.

    Raises ValueError if len(padded) is not a multiple of blockbytes.
    """
    if len(padded) % blockbytes != 0:
        raise ValueError('incomplete padded data')
    return [[ord(ch) for ch in padded[start:start + blockbytes]]
            for start in range(0, len(padded), blockbytes)]

# AES-CMAC as specified in RFC 4493, which was published from the draft at
# http://www.ietf.org/internet-drafts/draft-songlee-aes-cmac-03.txt.
def _cmac_double(value):
    """Return value*x in GF(2^128) for a 128-bit integer (CMAC subkey step).

    Shifts left by one bit.  If a bit is carried out of position 127, it is
    cleared and the constant R_128 = 0x87 is XORed in.
    """
    value <<= 1
    if value & (1 << 128):
        value ^= (1 << 128) | 0x87
    return value

def mac(cipher, data):
    """Use the given cipher to compute AES-CMAC for the given string.

    cipher: an object as returned by AES(); it must use a 128-bit key.
    data:   the message as a string (it goes through pad() and toblocks(),
            which work on strings).

    Returns the 128-bit tag as an integer.
    Raises ValueError if the cipher does not use a 128-bit key.
    """
    if cipher.keybits != 128:
        raise ValueError('AES-CMAC requires 128-bit keys')

    # Generate the two subkeys from L = E_K(0^128).
    L = tolong(cipher.encipher([0] * 16))
    K1 = _cmac_double(L)
    K2 = _cmac_double(K1)
    K1, K2 = toblock(K1, 16), toblock(K2, 16)

    # A non-empty message made of whole blocks is masked with K1.  Anything
    # else (including the empty message) is padded and masked with K2.
    if len(data) > 0 and len(data) % 16 == 0:
        blocks = toblocks(data, 16)
        M_last = blockxor(blocks[-1], K1)
    else:
        blocks = toblocks(pad(data, 16), 16)
        M_last = blockxor(blocks[-1], K2)

    # CBC-MAC chain over all blocks but the last, then the masked last block.
    X = [0] * 16
    for block in blocks[:-1]:
        X = cipher.encipher(blockxor(X, block))
    T = cipher.encipher(blockxor(X, M_last))
    return tolong(T)

