From af46629cd1303ce757eb8b3c3449aba5b06d3015 Mon Sep 17 00:00:00 2001 From: clowwindy Date: Sun, 4 May 2014 00:16:12 +0800 Subject: [PATCH] fix salsa20 --- shadowsocks/encrypt_salsa20.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/shadowsocks/encrypt_salsa20.py b/shadowsocks/encrypt_salsa20.py index e061c7d..85b30ff 100644 --- a/shadowsocks/encrypt_salsa20.py +++ b/shadowsocks/encrypt_salsa20.py @@ -4,6 +4,7 @@ import time import struct import logging import sys +import encrypt slow_xor = False imported = False @@ -72,14 +73,17 @@ class Salsa20Cipher(object): cur_data = data[:remain] cur_data_len = len(cur_data) cur_stream = self._stream[self._pos:self._pos + cur_data_len] - self._pos = (self._pos + cur_data_len) % BLOCK_SIZE + self._pos = self._pos + cur_data_len data = data[remain:] results.append(numpy_xor(cur_data, cur_stream)) + if self._pos >= BLOCK_SIZE: + self._next_stream() + self._pos -= BLOCK_SIZE + assert self._pos == 0 if not data: break - self._next_stream() return ''.join(results) @@ -87,8 +91,16 @@ def test(): from os import urandom import random - rounds = 1 * 10 + rounds = 1 * 1024 plain = urandom(BLOCK_SIZE * rounds) + import M2Crypto.EVP + cipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 1, + key_as_bytes=0, d='md5', salt=None, i=1, + padding=1) + decipher = M2Crypto.EVP.Cipher('aes_128_cfb', 'k' * 32, 'i' * 16, 0, + key_as_bytes=0, d='md5', salt=None, i=1, + padding=1) + cipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1) decipher = Salsa20Cipher('salsa20-ctr', 'k' * 32, 'i' * 8, 1) results = [] @@ -96,13 +108,20 @@ def test(): print 'start' start = time.time() while pos < len(plain): - l = random.randint(10000, 32768) + l = random.randint(100, 16384) c = cipher.update(plain[pos:pos + l]) - results.append(decipher.update(c)) + results.append(c) + pos += l + pos = 0 + c = ''.join(results) + results = [] + while pos < len(plain): + l = random.randint(100, 16384) + results.append(decipher.update(c[pos:pos + l])) pos += l - assert ''.join(results) == plain end = time.time() print BLOCK_SIZE * rounds / (end - start) + assert ''.join(results) == plain if __name__ == '__main__':