#775 fix UDP decrypt_all issue

This commit is contained in:
Zou Yong 2017-03-02 14:38:56 +08:00
parent d5cf37ab67
commit 3b9689aaa0
4 changed files with 35 additions and 31 deletions

View file

@ -22,7 +22,8 @@ from ctypes import c_char_p, c_int, c_long, byref,\
from shadowsocks import common from shadowsocks import common
from shadowsocks.crypto import util from shadowsocks.crypto import util
from shadowsocks.crypto.aead import * from shadowsocks.crypto.aead import AeadCryptoBase, EVP_CTRL_AEAD_SET_IVLEN, \
nonce_increment, EVP_CTRL_AEAD_GET_TAG, EVP_CTRL_AEAD_SET_TAG
__all__ = ['ciphers'] __all__ = ['ciphers']
@ -116,8 +117,9 @@ class OpenSSLCryptoBase(object):
buf_size = l * 2 buf_size = l * 2
buf = create_string_buffer(buf_size) buf = create_string_buffer(buf_size)
libcrypto.EVP_CipherUpdate( libcrypto.EVP_CipherUpdate(
self._ctx, byref(buf), self._ctx, byref(buf),
byref(cipher_out_len), c_char_p(data), l) byref(cipher_out_len), c_char_p(data), l
)
# buf is copied to a str object when we access buf.raw # buf is copied to a str object when we access buf.raw
return buf.raw[:cipher_out_len.value] return buf.raw[:cipher_out_len.value]
@ -139,17 +141,20 @@ class OpenSSLAeadCrypto(OpenSSLCryptoBase, AeadCryptoBase):
AeadCryptoBase.__init__(self, cipher_name, key, iv, op) AeadCryptoBase.__init__(self, cipher_name, key, iv, op)
r = libcrypto.EVP_CipherInit_ex( r = libcrypto.EVP_CipherInit_ex(
self._ctx, self._ctx,
self._cipher, None, self._cipher, None,
None, None, c_int(op)) None, None, c_int(op)
)
if not r: if not r:
self.clean() self.clean()
raise Exception('can not initialize cipher context') raise Exception('can not initialize cipher context')
r = libcrypto.EVP_CIPHER_CTX_ctrl( r = libcrypto.EVP_CIPHER_CTX_ctrl(
self._ctx, self._ctx,
c_int(EVP_CTRL_AEAD_SET_IVLEN), c_int(EVP_CTRL_AEAD_SET_IVLEN),
c_int(self._nlen), None) c_int(self._nlen),
None
)
if not r: if not r:
raise Exception('Set ivlen failed') raise Exception('Set ivlen failed')
@ -164,10 +169,10 @@ class OpenSSLAeadCrypto(OpenSSLCryptoBase, AeadCryptoBase):
iv_ptr = c_char_p(self._nonce.raw) iv_ptr = c_char_p(self._nonce.raw)
r = libcrypto.EVP_CipherInit_ex( r = libcrypto.EVP_CipherInit_ex(
self._ctx, self._ctx,
None, None, None, None,
key_ptr, iv_ptr, key_ptr, iv_ptr,
c_int(CIPHER_ENC_UNCHANGED) c_int(CIPHER_ENC_UNCHANGED)
) )
if not r: if not r:
self.clean() self.clean()
@ -184,9 +189,9 @@ class OpenSSLAeadCrypto(OpenSSLCryptoBase, AeadCryptoBase):
""" """
tag_len = self._tlen tag_len = self._tlen
r = libcrypto.EVP_CIPHER_CTX_ctrl( r = libcrypto.EVP_CIPHER_CTX_ctrl(
self._ctx, self._ctx,
c_int(EVP_CTRL_AEAD_SET_TAG), c_int(EVP_CTRL_AEAD_SET_TAG),
c_int(tag_len), c_char_p(tag) c_int(tag_len), c_char_p(tag)
) )
if not r: if not r:
raise Exception('Set tag failed') raise Exception('Set tag failed')
@ -199,9 +204,9 @@ class OpenSSLAeadCrypto(OpenSSLCryptoBase, AeadCryptoBase):
tag_len = self._tlen tag_len = self._tlen
tag_buf = create_string_buffer(tag_len) tag_buf = create_string_buffer(tag_len)
r = libcrypto.EVP_CIPHER_CTX_ctrl( r = libcrypto.EVP_CIPHER_CTX_ctrl(
self._ctx, self._ctx,
c_int(EVP_CTRL_AEAD_GET_TAG), c_int(EVP_CTRL_AEAD_GET_TAG),
c_int(tag_len), byref(tag_buf) c_int(tag_len), byref(tag_buf)
) )
if not r: if not r:
raise Exception('Get tag failed') raise Exception('Get tag failed')
@ -215,8 +220,8 @@ class OpenSSLAeadCrypto(OpenSSLCryptoBase, AeadCryptoBase):
global buf_size, buf global buf_size, buf
cipher_out_len = c_long(0) cipher_out_len = c_long(0)
r = libcrypto.EVP_CipherFinal_ex( r = libcrypto.EVP_CipherFinal_ex(
self._ctx, self._ctx,
byref(buf), byref(cipher_out_len) byref(buf), byref(cipher_out_len)
) )
if not r: if not r:
# print(self._nonce.raw, r, cipher_out_len) # print(self._nonce.raw, r, cipher_out_len)
@ -321,8 +326,8 @@ def test_aes_128_cfb():
def test_aes_gcm(bits=128): def test_aes_gcm(bits=128):
method = "aes-{0}-gcm".format(bits) method = "aes-{0}-gcm".format(bits)
print(method, int(bits/8)) print(method, int(bits / 8))
run_aead_method(method, bits/8) run_aead_method(method, bits / 8)
def test_aes_256_gcm(): def test_aes_256_gcm():

View file

@ -21,7 +21,7 @@ from ctypes import c_char_p, c_int, c_ulonglong, byref, c_ulong, \
create_string_buffer, c_void_p create_string_buffer, c_void_p
from shadowsocks.crypto import util from shadowsocks.crypto import util
from shadowsocks.crypto.aead import * from shadowsocks.crypto.aead import AeadCryptoBase
__all__ = ['ciphers'] __all__ = ['ciphers']

View file

@ -19,7 +19,6 @@ from __future__ import absolute_import, division, print_function, \
import os import os
import logging import logging
from ctypes import create_string_buffer
def find_library_nt(name): def find_library_nt(name):
@ -34,7 +33,7 @@ def find_library_nt(name):
results.append(fname) results.append(fname)
if fname.lower().endswith(".dll"): if fname.lower().endswith(".dll"):
continue continue
fname = fname + ".dll" fname += ".dll"
if os.path.isfile(fname): if os.path.isfile(fname):
results.append(fname) results.append(fname)
return results return results
@ -111,9 +110,9 @@ def run_cipher(cipher, decipher):
import random import random
import time import time
BLOCK_SIZE = 16384 block_size = 16384
rounds = 1 * 1024 rounds = 1 * 1024
plain = urandom(BLOCK_SIZE * rounds) plain = urandom(block_size * rounds)
results = [] results = []
pos = 0 pos = 0
@ -132,7 +131,7 @@ def run_cipher(cipher, decipher):
results.append(decipher.decrypt(c[pos:pos + l])) results.append(decipher.decrypt(c[pos:pos + l]))
pos += l pos += l
end = time.time() end = time.time()
print('speed: %d bytes/s' % (BLOCK_SIZE * rounds / (end - start))) print('speed: %d bytes/s' % (block_size * rounds / (end - start)))
assert b''.join(results) == plain assert b''.join(results) == plain

View file

@ -87,8 +87,8 @@ class Cryptor(object):
self._method_info = Cryptor.get_method_info(method) self._method_info = Cryptor.get_method_info(method)
if self._method_info: if self._method_info:
self.cipher = self.get_cipher( self.cipher = self.get_cipher(
password, method, CIPHER_ENC_ENCRYPTION, password, method, CIPHER_ENC_ENCRYPTION,
random_string(self._method_info[METHOD_INFO_IV_LEN]) random_string(self._method_info[METHOD_INFO_IV_LEN])
) )
else: else:
logging.error('method %s not supported' % method) logging.error('method %s not supported' % method)