diff --git a/shadowsocks/encrypt.py b/shadowsocks/encrypt.py index 09a6db6..5523c54 100644 --- a/shadowsocks/encrypt.py +++ b/shadowsocks/encrypt.py @@ -42,11 +42,7 @@ def get_table(key): table.sort(lambda x, y: int(a % (ord(x) + i) - a % (ord(y) + i))) return table -encrypt_table = None -decrypt_table = None - - -def init_table(key, method=None): +def try_encryptor(key, method=None): if method == 'table': method = None if method: @@ -56,7 +52,6 @@ def init_table(key, method=None): logging.error('M2Crypto is required to use encryption other than default method') sys.exit(1) if not method: - global encrypt_table, decrypt_table encrypt_table = ''.join(get_table(key)) decrypt_table = string.maketrans(encrypt_table, string.maketrans('', '')) else: @@ -118,6 +113,9 @@ class Encryptor(object): self.cipher = self.get_cipher(key, method, 1, iv=random_string(32)) else: self.cipher = None + self.encrypt_table = ''.join(get_table(key)) + self.decrypt_table = string.maketrans(self.encrypt_table, string.maketrans('', '')) + def get_cipher_len(self, method): method = method.lower() @@ -147,7 +145,7 @@ class Encryptor(object): if len(buf) == 0: return buf if self.method is None: - return string.translate(buf, encrypt_table) + return string.translate(buf, self.encrypt_table) else: if self.iv_sent: return self.cipher.update(buf) @@ -159,7 +157,7 @@ class Encryptor(object): if len(buf) == 0: return buf if self.method is None: - return string.translate(buf, decrypt_table) + return string.translate(buf, self.decrypt_table) else: if self.decipher is None: decipher_iv_len = self.get_cipher_len(self.method)[1] diff --git a/shadowsocks/local.py b/shadowsocks/local.py index b3061e6..a3c6b4d 100755 --- a/shadowsocks/local.py +++ b/shadowsocks/local.py @@ -235,7 +235,7 @@ def main(): utils.check_config(config) - encrypt.init_table(KEY, METHOD) + encrypt.try_encryptor(KEY, METHOD) try: if IPv6: diff --git a/shadowsocks/server.py b/shadowsocks/server.py index d045d5e..39cc650 100755 --- a/shadowsocks/server.py +++ b/shadowsocks/server.py @@ -193,7 +193,7 @@ def main(): PORTPASSWORD = {} PORTPASSWORD[str(PORT)] = KEY - encrypt.init_table(KEY, METHOD) + encrypt.try_encryptor(KEY, METHOD) if IPv6: ThreadingTCPServer.address_family = socket.AF_INET6 for port, key in PORTPASSWORD.items():