This commit is contained in:
clowwindy 2014-05-04 00:52:08 +08:00
parent af46629cd1
commit 92c9ba0009

View file

@ -4,7 +4,6 @@ import time
import struct import struct
import logging import logging
import sys import sys
import encrypt
slow_xor = False slow_xor = False
imported = False imported = False
@ -32,8 +31,14 @@ def run_imports():
def numpy_xor(a, b): def numpy_xor(a, b):
if slow_xor: if slow_xor:
return py_xor_str(a, b) return py_xor_str(a, b)
ab = numpy.frombuffer(a, dtype=numpy.byte) dtype = numpy.byte
bb = numpy.frombuffer(b, dtype=numpy.byte) if len(a) % 4 == 0:
dtype = numpy.uint32
elif len(a) % 2 == 0:
dtype = numpy.uint16
ab = numpy.frombuffer(a, dtype=dtype)
bb = numpy.frombuffer(b, dtype=dtype)
c = numpy.bitwise_xor(ab, bb) c = numpy.bitwise_xor(ab, bb)
r = c.tostring() r = c.tostring()
return r return r
@ -80,8 +85,7 @@ class Salsa20Cipher(object):
if self._pos >= BLOCK_SIZE: if self._pos >= BLOCK_SIZE:
self._next_stream() self._next_stream()
self._pos -= BLOCK_SIZE self._pos = 0
assert self._pos == 0
if not data: if not data:
break break
return ''.join(results) return ''.join(results)
@ -94,12 +98,12 @@ def test():
rounds = 1 * 1024 rounds = 1 * 1024
plain = urandom(BLOCK_SIZE * rounds) plain = urandom(BLOCK_SIZE * rounds)
import M2Crypto.EVP import M2Crypto.EVP
cipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 1, # cipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 1,
key_as_bytes=0, d='md5', salt=None, i=1, # key_as_bytes=0, d='md5', salt=None, i=1,
padding=1) # padding=1)
decipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 0, # decipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 0,
key_as_bytes=0, d='md5', salt=None, i=1, # key_as_bytes=0, d='md5', salt=None, i=1,
padding=1) # padding=1)
cipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1) cipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1)
decipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1) decipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1)
@ -108,7 +112,7 @@ def test():
print 'start' print 'start'
start = time.time() start = time.time()
while pos < len(plain): while pos < len(plain):
l = random.randint(100, 16384) l = random.randint(100, 32768)
c = cipher.update(plain[pos:pos + l]) c = cipher.update(plain[pos:pos + l])
results.append(c) results.append(c)
pos += l pos += l
@ -116,7 +120,7 @@ def test():
c = ''.join(results) c = ''.join(results)
results = [] results = []
while pos < len(plain): while pos < len(plain):
l = random.randint(100, 16384) l = random.randint(100, 32768)
results.append(decipher.update(c[pos:pos + l])) results.append(decipher.update(c[pos:pos + l]))
pos += l pos += l
end = time.time() end = time.time()