fix salsa20 on python 3

This commit is contained in:
clowwindy 2014-10-31 18:56:24 +08:00
parent 753e46654d
commit ede2b1b120
2 changed files with 18 additions and 11 deletions

View file

@ -163,8 +163,8 @@ def test():
# decipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 0,
# key_as_bytes=0, d='md5', salt=None, i=1,
# padding=1)
cipher = CtypesCrypto('aes-128-cfb', 'k' * 32, 'i' * 16, 1)
decipher = CtypesCrypto('aes-128-cfb', 'k' * 32, 'i' * 16, 0)
cipher = CtypesCrypto('aes-128-cfb', b'k' * 32, b'i' * 16, 1)
decipher = CtypesCrypto('aes-128-cfb', b'k' * 32, b'i' * 16, 0)
# cipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1)
# decipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1)

View file

@ -72,9 +72,16 @@ def numpy_xor(a, b):
def py_xor_str(a, b):
c = []
for i in range(0, len(a)):
c.append(chr(ord(a[i]) ^ ord(b[i])))
return ''.join(c)
if bytes == str:
for i in range(0, len(a)):
c.append(chr(ord(a[i]) ^ ord(b[i])))
else:
for i in range(0, len(a)):
c.append(a[i] ^ b[i])
if bytes == str:
return ''.join(c)
else:
return bytes(c)
class Salsa20Cipher(object):
@ -83,7 +90,7 @@ class Salsa20Cipher(object):
def __init__(self, alg, key, iv, op, key_as_bytes=0, d=None, salt=None,
i=1, padding=1):
run_imports()
if alg != 'salsa20-ctr':
if alg != b'salsa20-ctr':
raise Exception('unknown algorithm')
self._key = key
self._nonce = struct.unpack('<Q', iv)[0]
@ -115,7 +122,7 @@ class Salsa20Cipher(object):
self._pos = 0
if not data:
break
return ''.join(results)
return b''.join(results)
ciphers = {
@ -137,8 +144,8 @@ def test():
# key_as_bytes=0, d='md5', salt=None, i=1,
# padding=1)
cipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1)
decipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1)
cipher = Salsa20Cipher(b'salsa20-ctr', b'k' * 32, b'i' * 8, 1)
decipher = Salsa20Cipher(b'salsa20-ctr', b'k' * 32, b'i' * 8, 1)
results = []
pos = 0
print('salsa20 test start')
@ -149,7 +156,7 @@ def test():
results.append(c)
pos += l
pos = 0
c = ''.join(results)
c = b''.join(results)
results = []
while pos < len(plain):
l = random.randint(100, 32768)
@ -157,7 +164,7 @@ def test():
pos += l
end = time.time()
print('speed: %d bytes/s' % (BLOCK_SIZE * rounds / (end - start)))
assert ''.join(results) == plain
assert b''.join(results) == plain
if __name__ == '__main__':