Merge 710c93291e
into ad39d957d7
This commit is contained in:
commit
0fdd2fec83
12 changed files with 1183 additions and 1385 deletions
|
@ -29,7 +29,7 @@ from shadowsocks import common, lru_cache, eventloop, shell
|
||||||
|
|
||||||
CACHE_SWEEP_INTERVAL = 30
|
CACHE_SWEEP_INTERVAL = 30
|
||||||
|
|
||||||
VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d\-_]{1,63}(?<!-)$", re.IGNORECASE)
|
VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
|
||||||
|
|
||||||
common.patch_socket()
|
common.patch_socket()
|
||||||
|
|
||||||
|
@ -242,13 +242,13 @@ class DNSResponse(object):
|
||||||
return '%s: %s' % (self.hostname, str(self.answers))
|
return '%s: %s' % (self.hostname, str(self.answers))
|
||||||
|
|
||||||
|
|
||||||
STATUS_FIRST = 0
|
STATUS_IPV4 = 0
|
||||||
STATUS_SECOND = 1
|
STATUS_IPV6 = 1
|
||||||
|
|
||||||
|
|
||||||
class DNSResolver(object):
|
class DNSResolver(object):
|
||||||
|
|
||||||
def __init__(self, server_list=None, prefer_ipv6=False):
|
def __init__(self):
|
||||||
self._loop = None
|
self._loop = None
|
||||||
self._hosts = {}
|
self._hosts = {}
|
||||||
self._hostname_status = {}
|
self._hostname_status = {}
|
||||||
|
@ -256,15 +256,8 @@ class DNSResolver(object):
|
||||||
self._cb_to_hostname = {}
|
self._cb_to_hostname = {}
|
||||||
self._cache = lru_cache.LRUCache(timeout=300)
|
self._cache = lru_cache.LRUCache(timeout=300)
|
||||||
self._sock = None
|
self._sock = None
|
||||||
if server_list is None:
|
self._servers = None
|
||||||
self._servers = None
|
self._parse_resolv()
|
||||||
self._parse_resolv()
|
|
||||||
else:
|
|
||||||
self._servers = server_list
|
|
||||||
if prefer_ipv6:
|
|
||||||
self._QTYPES = [QTYPE_AAAA, QTYPE_A]
|
|
||||||
else:
|
|
||||||
self._QTYPES = [QTYPE_A, QTYPE_AAAA]
|
|
||||||
self._parse_hosts()
|
self._parse_hosts()
|
||||||
# TODO monitor hosts change and reload hosts
|
# TODO monitor hosts change and reload hosts
|
||||||
# TODO parse /etc/gai.conf and follow its rules
|
# TODO parse /etc/gai.conf and follow its rules
|
||||||
|
@ -276,18 +269,15 @@ class DNSResolver(object):
|
||||||
content = f.readlines()
|
content = f.readlines()
|
||||||
for line in content:
|
for line in content:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if not (line and line.startswith(b'nameserver')):
|
if line:
|
||||||
continue
|
if line.startswith(b'nameserver'):
|
||||||
|
parts = line.split()
|
||||||
parts = line.split()
|
if len(parts) >= 2:
|
||||||
if len(parts) < 2:
|
server = parts[1]
|
||||||
continue
|
if common.is_ip(server) == socket.AF_INET:
|
||||||
|
if type(server) != str:
|
||||||
server = parts[1]
|
server = server.decode('utf8')
|
||||||
if common.is_ip(server) == socket.AF_INET:
|
self._servers.append(server)
|
||||||
if type(server) != str:
|
|
||||||
server = server.decode('utf8')
|
|
||||||
self._servers.append(server)
|
|
||||||
except IOError:
|
except IOError:
|
||||||
pass
|
pass
|
||||||
if not self._servers:
|
if not self._servers:
|
||||||
|
@ -302,17 +292,13 @@ class DNSResolver(object):
|
||||||
for line in f.readlines():
|
for line in f.readlines():
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
parts = line.split()
|
parts = line.split()
|
||||||
if len(parts) < 2:
|
if len(parts) >= 2:
|
||||||
continue
|
ip = parts[0]
|
||||||
|
if common.is_ip(ip):
|
||||||
ip = parts[0]
|
for i in range(1, len(parts)):
|
||||||
if not common.is_ip(ip):
|
hostname = parts[i]
|
||||||
continue
|
if hostname:
|
||||||
|
self._hosts[hostname] = ip
|
||||||
for i in range(1, len(parts)):
|
|
||||||
hostname = parts[i]
|
|
||||||
if hostname:
|
|
||||||
self._hosts[hostname] = ip
|
|
||||||
except IOError:
|
except IOError:
|
||||||
self._hosts['localhost'] = '127.0.0.1'
|
self._hosts['localhost'] = '127.0.0.1'
|
||||||
|
|
||||||
|
@ -352,22 +338,21 @@ class DNSResolver(object):
|
||||||
answer[2] == QCLASS_IN:
|
answer[2] == QCLASS_IN:
|
||||||
ip = answer[0]
|
ip = answer[0]
|
||||||
break
|
break
|
||||||
if not ip and self._hostname_status.get(hostname, STATUS_SECOND) \
|
if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \
|
||||||
== STATUS_FIRST:
|
== STATUS_IPV4:
|
||||||
self._hostname_status[hostname] = STATUS_SECOND
|
self._hostname_status[hostname] = STATUS_IPV6
|
||||||
self._send_req(hostname, self._QTYPES[1])
|
self._send_req(hostname, QTYPE_AAAA)
|
||||||
else:
|
else:
|
||||||
if ip:
|
if ip:
|
||||||
self._cache[hostname] = ip
|
self._cache[hostname] = ip
|
||||||
self._call_callback(hostname, ip)
|
self._call_callback(hostname, ip)
|
||||||
elif self._hostname_status.get(hostname, None) \
|
elif self._hostname_status.get(hostname, None) == STATUS_IPV6:
|
||||||
== STATUS_SECOND:
|
|
||||||
for question in response.questions:
|
for question in response.questions:
|
||||||
if question[1] == self._QTYPES[1]:
|
if question[1] == QTYPE_AAAA:
|
||||||
self._call_callback(hostname, None)
|
self._call_callback(hostname, None)
|
||||||
break
|
break
|
||||||
|
|
||||||
def handle_event(self, sock, fd, event):
|
def handle_event(self, sock, event):
|
||||||
if sock != self._sock:
|
if sock != self._sock:
|
||||||
return
|
return
|
||||||
if event & eventloop.POLL_ERR:
|
if event & eventloop.POLL_ERR:
|
||||||
|
@ -429,14 +414,14 @@ class DNSResolver(object):
|
||||||
return
|
return
|
||||||
arr = self._hostname_to_cb.get(hostname, None)
|
arr = self._hostname_to_cb.get(hostname, None)
|
||||||
if not arr:
|
if not arr:
|
||||||
self._hostname_status[hostname] = STATUS_FIRST
|
self._hostname_status[hostname] = STATUS_IPV4
|
||||||
self._send_req(hostname, self._QTYPES[0])
|
self._send_req(hostname, QTYPE_A)
|
||||||
self._hostname_to_cb[hostname] = [callback]
|
self._hostname_to_cb[hostname] = [callback]
|
||||||
self._cb_to_hostname[callback] = hostname
|
self._cb_to_hostname[callback] = hostname
|
||||||
else:
|
else:
|
||||||
arr.append(callback)
|
arr.append(callback)
|
||||||
# TODO send again only if waited too long
|
# TODO send again only if waited too long
|
||||||
self._send_req(hostname, self._QTYPES[0])
|
self._send_req(hostname, QTYPE_A)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self._sock:
|
if self._sock:
|
||||||
|
|
33
shadowsocks/client.py
Normal file
33
shadowsocks/client.py
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function, \
|
||||||
|
with_statement
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
sys.path.insert(0, path)
|
||||||
|
from shadowsocks.eventloop import EventLoop
|
||||||
|
from shadowsocks.tcprelay import TcpRelay, TcpRelayClientHanler
|
||||||
|
|
||||||
|
|
||||||
|
FORMATTER = '%(asctime)s - %(levelname)s - %(message)s'
|
||||||
|
LOGGING_LEVEL = logging.INFO
|
||||||
|
logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER)
|
||||||
|
|
||||||
|
LISTEN_ADDR = ('127.0.0.1', 7070)
|
||||||
|
REMOTE_ADDR = ('104.129.176.124', 9000)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
loop = EventLoop()
|
||||||
|
relay = TcpRelay(TcpRelayClientHanler, LISTEN_ADDR, REMOTE_ADDR)
|
||||||
|
relay.add_to_loop(loop)
|
||||||
|
loop.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
|
@ -21,25 +21,6 @@ from __future__ import absolute_import, division, print_function, \
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
import logging
|
import logging
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
|
|
||||||
|
|
||||||
ONETIMEAUTH_BYTES = 10
|
|
||||||
ONETIMEAUTH_CHUNK_BYTES = 12
|
|
||||||
ONETIMEAUTH_CHUNK_DATA_LEN = 2
|
|
||||||
|
|
||||||
|
|
||||||
def sha1_hmac(secret, data):
|
|
||||||
return hmac.new(secret, data, hashlib.sha1).digest()
|
|
||||||
|
|
||||||
|
|
||||||
def onetimeauth_verify(_hash, data, key):
|
|
||||||
return _hash == sha1_hmac(key, data)[:ONETIMEAUTH_BYTES]
|
|
||||||
|
|
||||||
|
|
||||||
def onetimeauth_gen(data, key):
|
|
||||||
return sha1_hmac(key, data)[:ONETIMEAUTH_BYTES]
|
|
||||||
|
|
||||||
|
|
||||||
def compat_ord(s):
|
def compat_ord(s):
|
||||||
|
@ -137,11 +118,9 @@ def patch_socket():
|
||||||
patch_socket()
|
patch_socket()
|
||||||
|
|
||||||
|
|
||||||
ADDRTYPE_IPV4 = 0x01
|
ADDRTYPE_IPV4 = 1
|
||||||
ADDRTYPE_IPV6 = 0x04
|
ADDRTYPE_IPV6 = 4
|
||||||
ADDRTYPE_HOST = 0x03
|
ADDRTYPE_HOST = 3
|
||||||
ADDRTYPE_AUTH = 0x10
|
|
||||||
ADDRTYPE_MASK = 0xF
|
|
||||||
|
|
||||||
|
|
||||||
def pack_addr(address):
|
def pack_addr(address):
|
||||||
|
@ -165,17 +144,17 @@ def parse_header(data):
|
||||||
dest_addr = None
|
dest_addr = None
|
||||||
dest_port = None
|
dest_port = None
|
||||||
header_length = 0
|
header_length = 0
|
||||||
if addrtype & ADDRTYPE_MASK == ADDRTYPE_IPV4:
|
if addrtype == ADDRTYPE_IPV4:
|
||||||
if len(data) >= 7:
|
if len(data) >= 7:
|
||||||
dest_addr = socket.inet_ntoa(data[1:5])
|
dest_addr = socket.inet_ntoa(data[1:5])
|
||||||
dest_port = struct.unpack('>H', data[5:7])[0]
|
dest_port = struct.unpack('>H', data[5:7])[0]
|
||||||
header_length = 7
|
header_length = 7
|
||||||
else:
|
else:
|
||||||
logging.warn('header is too short')
|
logging.warn('header is too short')
|
||||||
elif addrtype & ADDRTYPE_MASK == ADDRTYPE_HOST:
|
elif addrtype == ADDRTYPE_HOST:
|
||||||
if len(data) > 2:
|
if len(data) > 2:
|
||||||
addrlen = ord(data[1])
|
addrlen = ord(data[1])
|
||||||
if len(data) >= 4 + addrlen:
|
if len(data) >= 2 + addrlen:
|
||||||
dest_addr = data[2:2 + addrlen]
|
dest_addr = data[2:2 + addrlen]
|
||||||
dest_port = struct.unpack('>H', data[2 + addrlen:4 +
|
dest_port = struct.unpack('>H', data[2 + addrlen:4 +
|
||||||
addrlen])[0]
|
addrlen])[0]
|
||||||
|
@ -184,7 +163,7 @@ def parse_header(data):
|
||||||
logging.warn('header is too short')
|
logging.warn('header is too short')
|
||||||
else:
|
else:
|
||||||
logging.warn('header is too short')
|
logging.warn('header is too short')
|
||||||
elif addrtype & ADDRTYPE_MASK == ADDRTYPE_IPV6:
|
elif addrtype == ADDRTYPE_IPV6:
|
||||||
if len(data) >= 19:
|
if len(data) >= 19:
|
||||||
dest_addr = socket.inet_ntop(socket.AF_INET6, data[1:17])
|
dest_addr = socket.inet_ntop(socket.AF_INET6, data[1:17])
|
||||||
dest_port = struct.unpack('>H', data[17:19])[0]
|
dest_port = struct.unpack('>H', data[17:19])[0]
|
||||||
|
|
|
@ -32,7 +32,7 @@ buf_size = 2048
|
||||||
|
|
||||||
|
|
||||||
def load_openssl():
|
def load_openssl():
|
||||||
global loaded, libcrypto, buf, ctx_cleanup
|
global loaded, libcrypto, buf
|
||||||
|
|
||||||
libcrypto = util.find_library(('crypto', 'eay32'),
|
libcrypto = util.find_library(('crypto', 'eay32'),
|
||||||
'EVP_get_cipherbyname',
|
'EVP_get_cipherbyname',
|
||||||
|
@ -49,12 +49,7 @@ def load_openssl():
|
||||||
libcrypto.EVP_CipherUpdate.argtypes = (c_void_p, c_void_p, c_void_p,
|
libcrypto.EVP_CipherUpdate.argtypes = (c_void_p, c_void_p, c_void_p,
|
||||||
c_char_p, c_int)
|
c_char_p, c_int)
|
||||||
|
|
||||||
try:
|
libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,)
|
||||||
libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,)
|
|
||||||
ctx_cleanup = libcrypto.EVP_CIPHER_CTX_cleanup
|
|
||||||
except AttributeError:
|
|
||||||
libcrypto.EVP_CIPHER_CTX_reset.argtypes = (c_void_p,)
|
|
||||||
ctx_cleanup = libcrypto.EVP_CIPHER_CTX_reset
|
|
||||||
libcrypto.EVP_CIPHER_CTX_free.argtypes = (c_void_p,)
|
libcrypto.EVP_CIPHER_CTX_free.argtypes = (c_void_p,)
|
||||||
if hasattr(libcrypto, 'OpenSSL_add_all_ciphers'):
|
if hasattr(libcrypto, 'OpenSSL_add_all_ciphers'):
|
||||||
libcrypto.OpenSSL_add_all_ciphers()
|
libcrypto.OpenSSL_add_all_ciphers()
|
||||||
|
@ -113,7 +108,7 @@ class OpenSSLCrypto(object):
|
||||||
|
|
||||||
def clean(self):
|
def clean(self):
|
||||||
if self._ctx:
|
if self._ctx:
|
||||||
ctx_cleanup(self._ctx)
|
libcrypto.EVP_CIPHER_CTX_cleanup(self._ctx)
|
||||||
libcrypto.EVP_CIPHER_CTX_free(self._ctx)
|
libcrypto.EVP_CIPHER_CTX_free(self._ctx)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
from __future__ import absolute_import, division, print_function, \
|
from __future__ import absolute_import, division, print_function, \
|
||||||
with_statement
|
with_statement
|
||||||
|
|
||||||
from ctypes import c_char_p, c_int, c_ulonglong, byref, c_ulong, \
|
from ctypes import c_char_p, c_int, c_ulonglong, byref, \
|
||||||
create_string_buffer, c_void_p
|
create_string_buffer, c_void_p
|
||||||
|
|
||||||
from shadowsocks.crypto import util
|
from shadowsocks.crypto import util
|
||||||
|
@ -29,7 +29,7 @@ loaded = False
|
||||||
|
|
||||||
buf_size = 2048
|
buf_size = 2048
|
||||||
|
|
||||||
# for salsa20 and chacha20 and chacha20-ietf
|
# for salsa20 and chacha20
|
||||||
BLOCK_SIZE = 64
|
BLOCK_SIZE = 64
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,13 +51,6 @@ def load_libsodium():
|
||||||
c_ulonglong,
|
c_ulonglong,
|
||||||
c_char_p, c_ulonglong,
|
c_char_p, c_ulonglong,
|
||||||
c_char_p)
|
c_char_p)
|
||||||
libsodium.crypto_stream_chacha20_ietf_xor_ic.restype = c_int
|
|
||||||
libsodium.crypto_stream_chacha20_ietf_xor_ic.argtypes = (c_void_p,
|
|
||||||
c_char_p,
|
|
||||||
c_ulonglong,
|
|
||||||
c_char_p,
|
|
||||||
c_ulong,
|
|
||||||
c_char_p)
|
|
||||||
|
|
||||||
buf = create_string_buffer(buf_size)
|
buf = create_string_buffer(buf_size)
|
||||||
loaded = True
|
loaded = True
|
||||||
|
@ -75,8 +68,6 @@ class SodiumCrypto(object):
|
||||||
self.cipher = libsodium.crypto_stream_salsa20_xor_ic
|
self.cipher = libsodium.crypto_stream_salsa20_xor_ic
|
||||||
elif cipher_name == 'chacha20':
|
elif cipher_name == 'chacha20':
|
||||||
self.cipher = libsodium.crypto_stream_chacha20_xor_ic
|
self.cipher = libsodium.crypto_stream_chacha20_xor_ic
|
||||||
elif cipher_name == 'chacha20-ietf':
|
|
||||||
self.cipher = libsodium.crypto_stream_chacha20_ietf_xor_ic
|
|
||||||
else:
|
else:
|
||||||
raise Exception('Unknown cipher')
|
raise Exception('Unknown cipher')
|
||||||
# byte counter, not block counter
|
# byte counter, not block counter
|
||||||
|
@ -106,7 +97,6 @@ class SodiumCrypto(object):
|
||||||
ciphers = {
|
ciphers = {
|
||||||
'salsa20': (32, 8, SodiumCrypto),
|
'salsa20': (32, 8, SodiumCrypto),
|
||||||
'chacha20': (32, 8, SodiumCrypto),
|
'chacha20': (32, 8, SodiumCrypto),
|
||||||
'chacha20-ietf': (32, 12, SodiumCrypto),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -125,15 +115,6 @@ def test_chacha20():
|
||||||
util.run_cipher(cipher, decipher)
|
util.run_cipher(cipher, decipher)
|
||||||
|
|
||||||
|
|
||||||
def test_chacha20_ietf():
|
|
||||||
|
|
||||||
cipher = SodiumCrypto('chacha20-ietf', b'k' * 32, b'i' * 16, 1)
|
|
||||||
decipher = SodiumCrypto('chacha20-ietf', b'k' * 32, b'i' * 16, 0)
|
|
||||||
|
|
||||||
util.run_cipher(cipher, decipher)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_chacha20()
|
test_chacha20()
|
||||||
test_salsa20()
|
test_salsa20()
|
||||||
test_chacha20_ietf()
|
|
||||||
|
|
|
@ -69,18 +69,17 @@ def EVP_BytesToKey(password, key_len, iv_len):
|
||||||
|
|
||||||
|
|
||||||
class Encryptor(object):
|
class Encryptor(object):
|
||||||
def __init__(self, password, method):
|
def __init__(self, key, method):
|
||||||
self.password = password
|
self.key = key
|
||||||
self.key = None
|
|
||||||
self.method = method
|
self.method = method
|
||||||
|
self.iv = None
|
||||||
self.iv_sent = False
|
self.iv_sent = False
|
||||||
self.cipher_iv = b''
|
self.cipher_iv = b''
|
||||||
self.decipher = None
|
self.decipher = None
|
||||||
self.decipher_iv = None
|
|
||||||
method = method.lower()
|
method = method.lower()
|
||||||
self._method_info = self.get_method_info(method)
|
self._method_info = self.get_method_info(method)
|
||||||
if self._method_info:
|
if self._method_info:
|
||||||
self.cipher = self.get_cipher(password, method, 1,
|
self.cipher = self.get_cipher(key, method, 1,
|
||||||
random_string(self._method_info[1]))
|
random_string(self._method_info[1]))
|
||||||
else:
|
else:
|
||||||
logging.error('method %s not supported' % method)
|
logging.error('method %s not supported' % method)
|
||||||
|
@ -102,7 +101,7 @@ class Encryptor(object):
|
||||||
else:
|
else:
|
||||||
# key_length == 0 indicates we should use the key directly
|
# key_length == 0 indicates we should use the key directly
|
||||||
key, iv = password, b''
|
key, iv = password, b''
|
||||||
self.key = key
|
|
||||||
iv = iv[:m[1]]
|
iv = iv[:m[1]]
|
||||||
if op == 1:
|
if op == 1:
|
||||||
# this iv is for cipher not decipher
|
# this iv is for cipher not decipher
|
||||||
|
@ -124,8 +123,7 @@ class Encryptor(object):
|
||||||
if self.decipher is None:
|
if self.decipher is None:
|
||||||
decipher_iv_len = self._method_info[1]
|
decipher_iv_len = self._method_info[1]
|
||||||
decipher_iv = buf[:decipher_iv_len]
|
decipher_iv = buf[:decipher_iv_len]
|
||||||
self.decipher_iv = decipher_iv
|
self.decipher = self.get_cipher(self.key, self.method, 0,
|
||||||
self.decipher = self.get_cipher(self.password, self.method, 0,
|
|
||||||
iv=decipher_iv)
|
iv=decipher_iv)
|
||||||
buf = buf[decipher_iv_len:]
|
buf = buf[decipher_iv_len:]
|
||||||
if len(buf) == 0:
|
if len(buf) == 0:
|
||||||
|
@ -133,47 +131,10 @@ class Encryptor(object):
|
||||||
return self.decipher.update(buf)
|
return self.decipher.update(buf)
|
||||||
|
|
||||||
|
|
||||||
def gen_key_iv(password, method):
|
|
||||||
method = method.lower()
|
|
||||||
(key_len, iv_len, m) = method_supported[method]
|
|
||||||
key = None
|
|
||||||
if key_len > 0:
|
|
||||||
key, _ = EVP_BytesToKey(password, key_len, iv_len)
|
|
||||||
else:
|
|
||||||
key = password
|
|
||||||
iv = random_string(iv_len)
|
|
||||||
return key, iv, m
|
|
||||||
|
|
||||||
|
|
||||||
def encrypt_all_m(key, iv, m, method, data):
|
|
||||||
result = []
|
|
||||||
result.append(iv)
|
|
||||||
cipher = m(method, key, iv, 1)
|
|
||||||
result.append(cipher.update(data))
|
|
||||||
return b''.join(result)
|
|
||||||
|
|
||||||
|
|
||||||
def dencrypt_all(password, method, data):
|
|
||||||
result = []
|
|
||||||
method = method.lower()
|
|
||||||
(key_len, iv_len, m) = method_supported[method]
|
|
||||||
key = None
|
|
||||||
if key_len > 0:
|
|
||||||
key, _ = EVP_BytesToKey(password, key_len, iv_len)
|
|
||||||
else:
|
|
||||||
key = password
|
|
||||||
iv = data[:iv_len]
|
|
||||||
data = data[iv_len:]
|
|
||||||
cipher = m(method, key, iv, 0)
|
|
||||||
result.append(cipher.update(data))
|
|
||||||
return b''.join(result), key, iv
|
|
||||||
|
|
||||||
|
|
||||||
def encrypt_all(password, method, op, data):
|
def encrypt_all(password, method, op, data):
|
||||||
result = []
|
result = []
|
||||||
method = method.lower()
|
method = method.lower()
|
||||||
(key_len, iv_len, m) = method_supported[method]
|
(key_len, iv_len, m) = method_supported[method]
|
||||||
key = None
|
|
||||||
if key_len > 0:
|
if key_len > 0:
|
||||||
key, _ = EVP_BytesToKey(password, key_len, iv_len)
|
key, _ = EVP_BytesToKey(password, key_len, iv_len)
|
||||||
else:
|
else:
|
||||||
|
@ -221,18 +182,6 @@ def test_encrypt_all():
|
||||||
assert plain == plain2
|
assert plain == plain2
|
||||||
|
|
||||||
|
|
||||||
def test_encrypt_all_m():
|
|
||||||
from os import urandom
|
|
||||||
plain = urandom(10240)
|
|
||||||
for method in CIPHERS_TO_TEST:
|
|
||||||
logging.warn(method)
|
|
||||||
key, iv, m = gen_key_iv(b'key', method)
|
|
||||||
cipher = encrypt_all_m(key, iv, m, method, plain)
|
|
||||||
plain2, key, iv = dencrypt_all(b'key', method, cipher)
|
|
||||||
assert plain == plain2
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_encrypt_all()
|
test_encrypt_all()
|
||||||
test_encryptor()
|
test_encryptor()
|
||||||
test_encrypt_all_m()
|
|
||||||
|
|
|
@ -1,181 +1,53 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/env python
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
#
|
|
||||||
# Copyright 2013-2015 clowwindy
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
||||||
# not use this file except in compliance with the License. You may obtain
|
|
||||||
# a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
# License for the specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
|
|
||||||
# from ssloop
|
|
||||||
# https://github.com/clowwindy/ssloop
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, \
|
from __future__ import absolute_import, division, print_function, \
|
||||||
with_statement
|
with_statement
|
||||||
|
|
||||||
import os
|
import logging
|
||||||
import time
|
import time
|
||||||
import socket
|
|
||||||
import select
|
|
||||||
import traceback
|
import traceback
|
||||||
import errno
|
import errno
|
||||||
import logging
|
from shadowsocks import selectors
|
||||||
from collections import defaultdict
|
from shadowsocks.selectors import (EVENT_READ, EVENT_WRITE, EVENT_ERROR,
|
||||||
|
errno_from_exception)
|
||||||
from shadowsocks import shell
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['EventLoop', 'POLL_NULL', 'POLL_IN', 'POLL_OUT', 'POLL_ERR',
|
POLL_IN = EVENT_READ
|
||||||
'POLL_HUP', 'POLL_NVAL', 'EVENT_NAMES']
|
POLL_OUT = EVENT_WRITE
|
||||||
|
POLL_ERR = EVENT_ERROR
|
||||||
|
|
||||||
POLL_NULL = 0x00
|
|
||||||
POLL_IN = 0x01
|
|
||||||
POLL_OUT = 0x04
|
|
||||||
POLL_ERR = 0x08
|
|
||||||
POLL_HUP = 0x10
|
|
||||||
POLL_NVAL = 0x20
|
|
||||||
|
|
||||||
|
|
||||||
EVENT_NAMES = {
|
|
||||||
POLL_NULL: 'POLL_NULL',
|
|
||||||
POLL_IN: 'POLL_IN',
|
|
||||||
POLL_OUT: 'POLL_OUT',
|
|
||||||
POLL_ERR: 'POLL_ERR',
|
|
||||||
POLL_HUP: 'POLL_HUP',
|
|
||||||
POLL_NVAL: 'POLL_NVAL',
|
|
||||||
}
|
|
||||||
|
|
||||||
# we check timeouts every TIMEOUT_PRECISION seconds
|
|
||||||
TIMEOUT_PRECISION = 10
|
TIMEOUT_PRECISION = 10
|
||||||
|
|
||||||
|
|
||||||
class KqueueLoop(object):
|
class EventLoop:
|
||||||
|
|
||||||
MAX_EVENTS = 1024
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._kqueue = select.kqueue()
|
self._selector = selectors.DefaultSelector()
|
||||||
self._fds = {}
|
self._stopping = False
|
||||||
|
|
||||||
def _control(self, fd, mode, flags):
|
|
||||||
events = []
|
|
||||||
if mode & POLL_IN:
|
|
||||||
events.append(select.kevent(fd, select.KQ_FILTER_READ, flags))
|
|
||||||
if mode & POLL_OUT:
|
|
||||||
events.append(select.kevent(fd, select.KQ_FILTER_WRITE, flags))
|
|
||||||
for e in events:
|
|
||||||
self._kqueue.control([e], 0)
|
|
||||||
|
|
||||||
def poll(self, timeout):
|
|
||||||
if timeout < 0:
|
|
||||||
timeout = None # kqueue behaviour
|
|
||||||
events = self._kqueue.control(None, KqueueLoop.MAX_EVENTS, timeout)
|
|
||||||
results = defaultdict(lambda: POLL_NULL)
|
|
||||||
for e in events:
|
|
||||||
fd = e.ident
|
|
||||||
if e.filter == select.KQ_FILTER_READ:
|
|
||||||
results[fd] |= POLL_IN
|
|
||||||
elif e.filter == select.KQ_FILTER_WRITE:
|
|
||||||
results[fd] |= POLL_OUT
|
|
||||||
return results.items()
|
|
||||||
|
|
||||||
def register(self, fd, mode):
|
|
||||||
self._fds[fd] = mode
|
|
||||||
self._control(fd, mode, select.KQ_EV_ADD)
|
|
||||||
|
|
||||||
def unregister(self, fd):
|
|
||||||
self._control(fd, self._fds[fd], select.KQ_EV_DELETE)
|
|
||||||
del self._fds[fd]
|
|
||||||
|
|
||||||
def modify(self, fd, mode):
|
|
||||||
self.unregister(fd)
|
|
||||||
self.register(fd, mode)
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self._kqueue.close()
|
|
||||||
|
|
||||||
|
|
||||||
class SelectLoop(object):
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._r_list = set()
|
|
||||||
self._w_list = set()
|
|
||||||
self._x_list = set()
|
|
||||||
|
|
||||||
def poll(self, timeout):
|
|
||||||
r, w, x = select.select(self._r_list, self._w_list, self._x_list,
|
|
||||||
timeout)
|
|
||||||
results = defaultdict(lambda: POLL_NULL)
|
|
||||||
for p in [(r, POLL_IN), (w, POLL_OUT), (x, POLL_ERR)]:
|
|
||||||
for fd in p[0]:
|
|
||||||
results[fd] |= p[1]
|
|
||||||
return results.items()
|
|
||||||
|
|
||||||
def register(self, fd, mode):
|
|
||||||
if mode & POLL_IN:
|
|
||||||
self._r_list.add(fd)
|
|
||||||
if mode & POLL_OUT:
|
|
||||||
self._w_list.add(fd)
|
|
||||||
if mode & POLL_ERR:
|
|
||||||
self._x_list.add(fd)
|
|
||||||
|
|
||||||
def unregister(self, fd):
|
|
||||||
if fd in self._r_list:
|
|
||||||
self._r_list.remove(fd)
|
|
||||||
if fd in self._w_list:
|
|
||||||
self._w_list.remove(fd)
|
|
||||||
if fd in self._x_list:
|
|
||||||
self._x_list.remove(fd)
|
|
||||||
|
|
||||||
def modify(self, fd, mode):
|
|
||||||
self.unregister(fd)
|
|
||||||
self.register(fd, mode)
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class EventLoop(object):
|
|
||||||
def __init__(self):
|
|
||||||
if hasattr(select, 'epoll'):
|
|
||||||
self._impl = select.epoll()
|
|
||||||
model = 'epoll'
|
|
||||||
elif hasattr(select, 'kqueue'):
|
|
||||||
self._impl = KqueueLoop()
|
|
||||||
model = 'kqueue'
|
|
||||||
elif hasattr(select, 'select'):
|
|
||||||
self._impl = SelectLoop()
|
|
||||||
model = 'select'
|
|
||||||
else:
|
|
||||||
raise Exception('can not find any available functions in select '
|
|
||||||
'package')
|
|
||||||
self._fdmap = {} # (f, handler)
|
|
||||||
self._last_time = time.time()
|
self._last_time = time.time()
|
||||||
self._periodic_callbacks = []
|
self._periodic_callbacks = []
|
||||||
self._stopping = False
|
|
||||||
logging.debug('using event model: %s', model)
|
|
||||||
|
|
||||||
def poll(self, timeout=None):
|
def poll(self, timeout=None):
|
||||||
events = self._impl.poll(timeout)
|
return self._selector.select(timeout)
|
||||||
return [(self._fdmap[fd][0], fd, event) for fd, event in events]
|
|
||||||
|
|
||||||
def add(self, f, mode, handler):
|
def add(self, sock, events, data):
|
||||||
fd = f.fileno()
|
events |= selectors.EVENT_ERROR
|
||||||
self._fdmap[fd] = (f, handler)
|
return self._selector.register(sock, events, data)
|
||||||
self._impl.register(fd, mode)
|
|
||||||
|
|
||||||
def remove(self, f):
|
def remove(self, sock):
|
||||||
fd = f.fileno()
|
try:
|
||||||
del self._fdmap[fd]
|
return self._selector.unregister(sock)
|
||||||
self._impl.unregister(fd)
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def modify(self, sock, events, data):
|
||||||
|
events |= selectors.EVENT_ERROR
|
||||||
|
try:
|
||||||
|
key = self._selector.modify(sock, events, data)
|
||||||
|
except KeyError:
|
||||||
|
key = self.add(sock, events, data)
|
||||||
|
return key
|
||||||
|
|
||||||
def add_periodic(self, callback):
|
def add_periodic(self, callback):
|
||||||
self._periodic_callbacks.append(callback)
|
self._periodic_callbacks.append(callback)
|
||||||
|
@ -183,69 +55,57 @@ class EventLoop(object):
|
||||||
def remove_periodic(self, callback):
|
def remove_periodic(self, callback):
|
||||||
self._periodic_callbacks.remove(callback)
|
self._periodic_callbacks.remove(callback)
|
||||||
|
|
||||||
def modify(self, f, mode):
|
def fd_count(self):
|
||||||
fd = f.fileno()
|
return len(self._selector.get_map())
|
||||||
self._impl.modify(fd, mode)
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
self._stopping = True
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
events = []
|
logging.debug('Starting event loop')
|
||||||
|
|
||||||
while not self._stopping:
|
while not self._stopping:
|
||||||
asap = False
|
asap = False
|
||||||
try:
|
try:
|
||||||
events = self.poll(TIMEOUT_PRECISION)
|
events = self.poll(timeout=TIMEOUT_PRECISION)
|
||||||
except (OSError, IOError) as e:
|
except (OSError, IOError) as e:
|
||||||
if errno_from_exception(e) in (errno.EPIPE, errno.EINTR):
|
if errno_from_exception(e) in (errno.EPIPE, errno.EINTR):
|
||||||
# EPIPE: Happens when the client closes the connection
|
# EPIPE: Happens when the client closes the connection
|
||||||
# EINTR: Happens when received a signal
|
# EINTR: Happens when received a signal
|
||||||
# handles them as soon as possible
|
# handles them as soon as possible
|
||||||
asap = True
|
asap = True
|
||||||
logging.debug('poll:%s', e)
|
logging.debug('poll: %s', e)
|
||||||
else:
|
else:
|
||||||
logging.error('poll:%s', e)
|
logging.error('poll: %s', e)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for sock, fd, event in events:
|
for key, event in events:
|
||||||
handler = self._fdmap.get(fd, None)
|
if type(key.data) == tuple:
|
||||||
if handler is not None:
|
handler = key.data[0]
|
||||||
handler = handler[1]
|
args = key.data[1:]
|
||||||
try:
|
else:
|
||||||
handler.handle_event(sock, fd, event)
|
handler = key.data
|
||||||
except (OSError, IOError) as e:
|
args = ()
|
||||||
shell.print_exception(e)
|
|
||||||
|
sock = key.fileobj
|
||||||
|
if hasattr(handler, 'handle_event'):
|
||||||
|
handler = handler.handle_event
|
||||||
|
|
||||||
|
try:
|
||||||
|
handler(sock, event, *args)
|
||||||
|
except Exception as e:
|
||||||
|
logging.debug(e)
|
||||||
|
traceback.print_exc()
|
||||||
|
raise
|
||||||
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
if asap or now - self._last_time >= TIMEOUT_PRECISION:
|
if asap or now - self._last_time >= TIMEOUT_PRECISION:
|
||||||
for callback in self._periodic_callbacks:
|
for callback in self._periodic_callbacks:
|
||||||
callback()
|
callback()
|
||||||
self._last_time = now
|
self._last_time = now
|
||||||
|
|
||||||
def __del__(self):
|
logging.debug('Got {} fds registered'.format(self.fd_count()))
|
||||||
self._impl.close()
|
|
||||||
|
|
||||||
|
logging.debug('Stopping event loop')
|
||||||
|
self._selector.close()
|
||||||
|
|
||||||
# from tornado
|
def stop(self):
|
||||||
def errno_from_exception(e):
|
self._stopping = True
|
||||||
"""Provides the errno from an Exception object.
|
|
||||||
|
|
||||||
There are cases that the errno attribute was not set so we pull
|
|
||||||
the errno out of the args but if someone instatiates an Exception
|
|
||||||
without any args you will get a tuple error. So this function
|
|
||||||
abstracts all that behavior to give you a safe way to get the
|
|
||||||
errno.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if hasattr(e, 'errno'):
|
|
||||||
return e.errno
|
|
||||||
elif e.args:
|
|
||||||
return e.args[0]
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# from tornado
|
|
||||||
def get_sock_error(sock):
|
|
||||||
error_number = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
|
|
||||||
return socket.error(error_number, os.strerror(error_number))
|
|
||||||
|
|
|
@ -87,8 +87,8 @@ class LRUCache(collections.MutableMapping):
|
||||||
if value not in self._closed_values:
|
if value not in self._closed_values:
|
||||||
self.close_callback(value)
|
self.close_callback(value)
|
||||||
self._closed_values.add(value)
|
self._closed_values.add(value)
|
||||||
self._last_visits.popleft()
|
|
||||||
for key in self._time_to_keys[least]:
|
for key in self._time_to_keys[least]:
|
||||||
|
self._last_visits.popleft()
|
||||||
if key in self._store:
|
if key in self._store:
|
||||||
if now - self._keys_to_last_time[key] > self.timeout:
|
if now - self._keys_to_last_time[key] > self.timeout:
|
||||||
del self._store[key]
|
del self._store[key]
|
||||||
|
|
652
shadowsocks/selectors.py
Normal file
652
shadowsocks/selectors.py
Normal file
|
@ -0,0 +1,652 @@
|
||||||
|
"""Selectors module.
|
||||||
|
|
||||||
|
This module allows high-level and efficient I/O multiplexing, built upon the
|
||||||
|
`select` module primitives.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
|
from collections import namedtuple, Mapping
|
||||||
|
import errno
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import select
|
||||||
|
import socket
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
# generic events, that must be mapped to implementation-specific ones
|
||||||
|
EVENT_READ = 0x001
|
||||||
|
EVENT_WRITE = 0x004
|
||||||
|
EVENT_ERROR = 0x018
|
||||||
|
|
||||||
|
|
||||||
|
def errno_from_exception(e):
|
||||||
|
"""Provides the errno from an Exception object.
|
||||||
|
|
||||||
|
There are cases that the errno attribute was not set so we pull
|
||||||
|
the errno out of the args but if someone instatiates an Exception
|
||||||
|
without any args you will get a tuple error. So this function
|
||||||
|
abstracts all that behavior to give you a safe way to get the
|
||||||
|
errno.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if hasattr(e, 'errno'):
|
||||||
|
return e.errno
|
||||||
|
elif e.args:
|
||||||
|
return e.args[0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_sock_error(sock):
|
||||||
|
if not sock:
|
||||||
|
return
|
||||||
|
error_number = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
|
||||||
|
return socket.error(error_number, os.strerror(error_number))
|
||||||
|
|
||||||
|
|
||||||
|
def _fileobj_to_fd(fileobj):
|
||||||
|
"""Return a file descriptor from a file object.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
fileobj -- file object or file descriptor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
corresponding file descriptor
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if the object is invalid
|
||||||
|
"""
|
||||||
|
if isinstance(fileobj, int):
|
||||||
|
fd = fileobj
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
fd = int(fileobj.fileno())
|
||||||
|
except (AttributeError, TypeError, ValueError):
|
||||||
|
raise ValueError("Invalid file object: "
|
||||||
|
"{!r}".format(fileobj))
|
||||||
|
if fd < 0:
|
||||||
|
raise ValueError("Invalid file descriptor: {}".format(fd))
|
||||||
|
return fd
|
||||||
|
|
||||||
|
|
||||||
|
SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data'])
|
||||||
|
"""Object used to associate a file object to its backing file descriptor,
|
||||||
|
selected event mask and attached data."""
|
||||||
|
|
||||||
|
|
||||||
|
class _SelectorMapping(Mapping):
|
||||||
|
"""Mapping of file objects to selector keys."""
|
||||||
|
|
||||||
|
def __init__(self, selector):
|
||||||
|
self._selector = selector
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._selector._fd_to_key)
|
||||||
|
|
||||||
|
def __getitem__(self, fileobj):
|
||||||
|
try:
|
||||||
|
fd = self._selector._fileobj_lookup(fileobj)
|
||||||
|
return self._selector._fd_to_key[fd]
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError("{!r} is not registered".format(fileobj))
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._selector._fd_to_key)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSelector(object):
|
||||||
|
__metaclass__ = ABCMeta
|
||||||
|
"""Selector abstract base class.
|
||||||
|
|
||||||
|
A selector supports registering file objects to be monitored for specific
|
||||||
|
I/O events.
|
||||||
|
|
||||||
|
A file object is a file descriptor or any object with a `fileno()` method.
|
||||||
|
An arbitrary object can be attached to the file object, which can be used
|
||||||
|
for example to store context information, a callback, etc.
|
||||||
|
|
||||||
|
A selector can use various implementations (select(), poll(), epoll()...)
|
||||||
|
depending on the platform. The default `Selector` class uses the most
|
||||||
|
efficient implementation on the current platform.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def register(self, fileobj, events, data=None):
|
||||||
|
"""Register a file object.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
fileobj -- file object or file descriptor
|
||||||
|
events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
|
||||||
|
data -- attached data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SelectorKey instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if events is invalid
|
||||||
|
KeyError if fileobj is already registered
|
||||||
|
OSError if fileobj is closed or otherwise is unacceptable to
|
||||||
|
the underlying system call (if a system call is made)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
OSError may or may not be raised
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def unregister(self, fileobj):
|
||||||
|
"""Unregister a file object.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
fileobj -- file object or file descriptor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SelectorKey instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError if fileobj is not registered
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If fileobj is registered but has since been closed this does
|
||||||
|
*not* raise OSError (even if the wrapped syscall does)
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def modify(self, fileobj, events, data=None):
|
||||||
|
"""Change a registered file object monitored events or attached data.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
fileobj -- file object or file descriptor
|
||||||
|
events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
|
||||||
|
data -- attached data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SelectorKey instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Anything that unregister() or register() raises
|
||||||
|
"""
|
||||||
|
self.unregister(fileobj)
|
||||||
|
return self.register(fileobj, events, data)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def select(self, timeout=None):
|
||||||
|
"""Perform the actual selection, until some monitored file objects are
|
||||||
|
ready or a timeout expires.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
timeout -- if timeout > 0, this specifies the maximum wait time, in
|
||||||
|
seconds
|
||||||
|
if timeout <= 0, the select() call won't block, and will
|
||||||
|
report the currently ready file objects
|
||||||
|
if timeout is None, select() will block until a monitored
|
||||||
|
file object becomes ready
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of (key, events) for ready file objects
|
||||||
|
`events` is a bitwise mask of EVENT_READ|EVENT_WRITE
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close the selector.
|
||||||
|
|
||||||
|
This must be called to make sure that any underlying resource is freed.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_key(self, fileobj):
|
||||||
|
"""Return the key associated to a registered file object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SelectorKey for this file object
|
||||||
|
"""
|
||||||
|
mapping = self.get_map()
|
||||||
|
try:
|
||||||
|
if mapping is None:
|
||||||
|
raise KeyError
|
||||||
|
return mapping[fileobj]
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError("{!r} is not registered".format(fileobj))
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_map(self):
|
||||||
|
"""Return a mapping of file objects to selector keys."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
|
class _BaseSelectorImpl(BaseSelector):
|
||||||
|
"""Base selector implementation."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# this maps file descriptors to keys
|
||||||
|
self._fd_to_key = {}
|
||||||
|
# read-only mapping returned by get_map()
|
||||||
|
self._map = _SelectorMapping(self)
|
||||||
|
|
||||||
|
def _fileobj_lookup(self, fileobj):
|
||||||
|
"""Return a file descriptor from a file object.
|
||||||
|
|
||||||
|
This wraps _fileobj_to_fd() to do an exhaustive search in case
|
||||||
|
the object is invalid but we still have it in our map. This
|
||||||
|
is used by unregister() so we can unregister an object that
|
||||||
|
was previously registered even if it is closed. It is also
|
||||||
|
used by _SelectorMapping.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return _fileobj_to_fd(fileobj)
|
||||||
|
except ValueError:
|
||||||
|
# Do an exhaustive search.
|
||||||
|
for key in self._fd_to_key.values():
|
||||||
|
if key.fileobj is fileobj:
|
||||||
|
return key.fd
|
||||||
|
# Raise ValueError after all.
|
||||||
|
raise
|
||||||
|
|
||||||
|
def register(self, fileobj, events, data=None):
|
||||||
|
if (not events) or (events & ~(EVENT_READ | EVENT_WRITE |
|
||||||
|
EVENT_ERROR)):
|
||||||
|
raise ValueError("Invalid events: {!r}".format(events))
|
||||||
|
|
||||||
|
key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data)
|
||||||
|
|
||||||
|
if key.fd in self._fd_to_key:
|
||||||
|
raise KeyError("{!r} (FD {}) is already registered"
|
||||||
|
.format(fileobj, key.fd))
|
||||||
|
|
||||||
|
self._fd_to_key[key.fd] = key
|
||||||
|
return key
|
||||||
|
|
||||||
|
def unregister(self, fileobj):
|
||||||
|
try:
|
||||||
|
key = self._fd_to_key.pop(self._fileobj_lookup(fileobj))
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError("{!r} is not registered".format(fileobj))
|
||||||
|
return key
|
||||||
|
|
||||||
|
def modify(self, fileobj, events, data=None):
|
||||||
|
# TODO: Subclasses can probably optimize this even further.
|
||||||
|
try:
|
||||||
|
key = self._fd_to_key[self._fileobj_lookup(fileobj)]
|
||||||
|
except KeyError:
|
||||||
|
raise KeyError("{!r} is not registered".format(fileobj))
|
||||||
|
if events != key.events:
|
||||||
|
self.unregister(fileobj)
|
||||||
|
key = self.register(fileobj, events, data)
|
||||||
|
elif data != key.data:
|
||||||
|
# Use a shortcut to update the data.
|
||||||
|
key = key._replace(data=data)
|
||||||
|
self._fd_to_key[key.fd] = key
|
||||||
|
return key
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self._fd_to_key.clear()
|
||||||
|
self._map = None
|
||||||
|
|
||||||
|
def get_map(self):
|
||||||
|
return self._map
|
||||||
|
|
||||||
|
def _key_from_fd(self, fd):
|
||||||
|
"""Return the key associated to a given file descriptor.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
fd -- file descriptor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
corresponding key, or None if not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self._fd_to_key[fd]
|
||||||
|
except KeyError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class SelectSelector(_BaseSelectorImpl):
|
||||||
|
"""Select-based selector."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(self.__class__, self).__init__()
|
||||||
|
self._readers = set()
|
||||||
|
self._writers = set()
|
||||||
|
self._errors = set()
|
||||||
|
|
||||||
|
def register(self, fileobj, events, data=None):
|
||||||
|
key = super(self.__class__, self).register(fileobj, events, data)
|
||||||
|
if events & EVENT_READ:
|
||||||
|
self._readers.add(key.fd)
|
||||||
|
if events & EVENT_WRITE:
|
||||||
|
self._writers.add(key.fd)
|
||||||
|
if events & EVENT_ERROR:
|
||||||
|
self._errors.add(key.fd)
|
||||||
|
return key
|
||||||
|
|
||||||
|
def unregister(self, fileobj):
|
||||||
|
key = super(self.__class__, self).unregister(fileobj)
|
||||||
|
self._readers.discard(key.fd)
|
||||||
|
self._writers.discard(key.fd)
|
||||||
|
self._errors.discard(key.fd)
|
||||||
|
return key
|
||||||
|
|
||||||
|
if sys.platform == 'win32':
|
||||||
|
def _select(self, r, w, x, timeout=None):
|
||||||
|
r, w, x = select.select(r, w, x, timeout)
|
||||||
|
return r, w, x
|
||||||
|
else:
|
||||||
|
_select = select.select
|
||||||
|
|
||||||
|
def select(self, timeout=None):
|
||||||
|
timeout = None if timeout is None else max(timeout, 0)
|
||||||
|
ready = []
|
||||||
|
try:
|
||||||
|
r, w, x = self._select(self._readers, self._writers, self._errors,
|
||||||
|
timeout)
|
||||||
|
except OSError as e:
|
||||||
|
if errno_from_exception(e) == errno.EAGAIN:
|
||||||
|
return ready
|
||||||
|
r = set(r)
|
||||||
|
w = set(w)
|
||||||
|
x = set(x)
|
||||||
|
for fd in r | w | x:
|
||||||
|
events = 0
|
||||||
|
if fd in r:
|
||||||
|
events |= EVENT_READ
|
||||||
|
if fd in w:
|
||||||
|
events |= EVENT_WRITE
|
||||||
|
if fd in x:
|
||||||
|
events |= EVENT_ERROR
|
||||||
|
|
||||||
|
key = self._key_from_fd(fd)
|
||||||
|
if key:
|
||||||
|
ready.append((key, events & key.events))
|
||||||
|
return ready
|
||||||
|
|
||||||
|
|
||||||
|
if hasattr(select, 'poll'):
|
||||||
|
|
||||||
|
class PollSelector(_BaseSelectorImpl):
|
||||||
|
"""Poll-based selector."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(self.__class__, self).__init__()
|
||||||
|
self._poll = select.poll()
|
||||||
|
|
||||||
|
def register(self, fileobj, events, data=None):
|
||||||
|
key = super(self.__class__, self).register(fileobj, events, data)
|
||||||
|
poll_events = 0
|
||||||
|
if events & EVENT_READ:
|
||||||
|
poll_events |= select.POLLIN
|
||||||
|
if events & EVENT_WRITE:
|
||||||
|
poll_events |= select.POLLOUT
|
||||||
|
if events & EVENT_ERROR:
|
||||||
|
poll_events |= select.POLLERR | select.POLLHUP
|
||||||
|
self._poll.register(key.fd, poll_events)
|
||||||
|
return key
|
||||||
|
|
||||||
|
def unregister(self, fileobj):
|
||||||
|
key = super(self.__class__, self).unregister(fileobj)
|
||||||
|
self._poll.unregister(key.fd)
|
||||||
|
return key
|
||||||
|
|
||||||
|
def select(self, timeout=None):
|
||||||
|
if timeout is None:
|
||||||
|
timeout = None
|
||||||
|
elif timeout <= 0:
|
||||||
|
timeout = 0
|
||||||
|
else:
|
||||||
|
# poll() has a resolution of 1 millisecond, round away from
|
||||||
|
# zero to wait *at least* timeout seconds.
|
||||||
|
timeout = math.ceil(timeout * 1e3)
|
||||||
|
ready = []
|
||||||
|
try:
|
||||||
|
fd_event_list = self._poll.poll(timeout)
|
||||||
|
except OSError as e:
|
||||||
|
if errno_from_exception(e) == errno.EAGAIN:
|
||||||
|
return ready
|
||||||
|
for fd, event in fd_event_list:
|
||||||
|
events = 0
|
||||||
|
if event & select.POLLIN:
|
||||||
|
events |= EVENT_READ
|
||||||
|
if event & select.POLLOUT:
|
||||||
|
events |= EVENT_WRITE
|
||||||
|
if event & (select.POLLERR | select.POLLHUP):
|
||||||
|
events |= EVENT_ERROR
|
||||||
|
|
||||||
|
key = self._key_from_fd(fd)
|
||||||
|
if key:
|
||||||
|
ready.append((key, events & key.events))
|
||||||
|
return ready
|
||||||
|
|
||||||
|
|
||||||
|
if hasattr(select, 'epoll'):
|
||||||
|
|
||||||
|
class EpollSelector(_BaseSelectorImpl):
|
||||||
|
"""Epoll-based selector."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(self.__class__, self).__init__()
|
||||||
|
self._epoll = select.epoll()
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return self._epoll.fileno()
|
||||||
|
|
||||||
|
def register(self, fileobj, events, data=None):
|
||||||
|
key = super(self.__class__, self).register(fileobj, events, data)
|
||||||
|
epoll_events = 0
|
||||||
|
if events & EVENT_READ:
|
||||||
|
epoll_events |= select.EPOLLIN
|
||||||
|
if events & EVENT_WRITE:
|
||||||
|
epoll_events |= select.EPOLLOUT
|
||||||
|
if events & EVENT_ERROR:
|
||||||
|
epoll_events |= select.EPOLLERR | select.EPOLLHUP
|
||||||
|
self._epoll.register(key.fd, epoll_events)
|
||||||
|
return key
|
||||||
|
|
||||||
|
def unregister(self, fileobj):
|
||||||
|
key = super(self.__class__, self).unregister(fileobj)
|
||||||
|
try:
|
||||||
|
self._epoll.unregister(key.fd)
|
||||||
|
except OSError:
|
||||||
|
# This can happen if the FD was closed since it
|
||||||
|
# was registered.
|
||||||
|
pass
|
||||||
|
return key
|
||||||
|
|
||||||
|
def select(self, timeout=None):
|
||||||
|
if timeout is None:
|
||||||
|
timeout = -1
|
||||||
|
elif timeout <= 0:
|
||||||
|
timeout = 0
|
||||||
|
else:
|
||||||
|
# epoll_wait() has a resolution of 1 millisecond, round away
|
||||||
|
# from zero to wait *at least* timeout seconds.
|
||||||
|
timeout = math.ceil(timeout * 1e3) * 1e-3
|
||||||
|
|
||||||
|
# epoll_wait() expects `maxevents` to be greater than zero;
|
||||||
|
# we want to make sure that `select()` can be called when no
|
||||||
|
# FD is registered.
|
||||||
|
max_ev = max(len(self._fd_to_key), 1)
|
||||||
|
|
||||||
|
ready = []
|
||||||
|
try:
|
||||||
|
fd_event_list = self._epoll.poll(timeout, max_ev)
|
||||||
|
except OSError as e:
|
||||||
|
if errno_from_exception(e) == errno.EAGAIN:
|
||||||
|
return ready
|
||||||
|
for fd, event in fd_event_list:
|
||||||
|
events = 0
|
||||||
|
if event & select.EPOLLIN:
|
||||||
|
events |= EVENT_READ
|
||||||
|
if event & select.EPOLLOUT:
|
||||||
|
events |= EVENT_WRITE
|
||||||
|
if event & (select.EPOLLERR | select.EPOLLHUP):
|
||||||
|
events |= EVENT_ERROR
|
||||||
|
|
||||||
|
key = self._key_from_fd(fd)
|
||||||
|
if key:
|
||||||
|
ready.append((key, events & key.events))
|
||||||
|
return ready
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
try:
|
||||||
|
self._epoll.close()
|
||||||
|
finally:
|
||||||
|
super(self.__class__, self).close()
|
||||||
|
|
||||||
|
|
||||||
|
if hasattr(select, 'devpoll'):
|
||||||
|
|
||||||
|
class DevpollSelector(_BaseSelectorImpl):
|
||||||
|
"""Solaris /dev/poll selector."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._devpoll = select.devpoll()
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return self._devpoll.fileno()
|
||||||
|
|
||||||
|
def register(self, fileobj, events, data=None):
|
||||||
|
key = super().register(fileobj, events, data)
|
||||||
|
poll_events = 0
|
||||||
|
if events & EVENT_READ:
|
||||||
|
poll_events |= select.POLLIN
|
||||||
|
if events & EVENT_WRITE:
|
||||||
|
poll_events |= select.POLLOUT
|
||||||
|
if events & EVENT_ERROR:
|
||||||
|
poll_events |= select.POLLERR | select.POLLHUP
|
||||||
|
self._devpoll.register(key.fd, poll_events)
|
||||||
|
return key
|
||||||
|
|
||||||
|
def unregister(self, fileobj):
|
||||||
|
key = super().unregister(fileobj)
|
||||||
|
self._devpoll.unregister(key.fd)
|
||||||
|
return key
|
||||||
|
|
||||||
|
def select(self, timeout=None):
|
||||||
|
if timeout is None:
|
||||||
|
timeout = None
|
||||||
|
elif timeout <= 0:
|
||||||
|
timeout = 0
|
||||||
|
else:
|
||||||
|
# devpoll() has a resolution of 1 millisecond, round away from
|
||||||
|
# zero to wait *at least* timeout seconds.
|
||||||
|
timeout = math.ceil(timeout * 1e3)
|
||||||
|
ready = []
|
||||||
|
try:
|
||||||
|
fd_event_list = self._devpoll.poll(timeout)
|
||||||
|
except InterruptedError:
|
||||||
|
return ready
|
||||||
|
for fd, event in fd_event_list:
|
||||||
|
events = 0
|
||||||
|
if event & select.POLLIN:
|
||||||
|
events |= EVENT_READ
|
||||||
|
if event & select.POLLOUT:
|
||||||
|
events |= EVENT_WRITE
|
||||||
|
if event & (select.POLLERR | select.POLLHUP):
|
||||||
|
events |= EVENT_ERROR
|
||||||
|
|
||||||
|
key = self._key_from_fd(fd)
|
||||||
|
if key:
|
||||||
|
ready.append((key, events & key.events))
|
||||||
|
return ready
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self._devpoll.close()
|
||||||
|
super().close()
|
||||||
|
|
||||||
|
|
||||||
|
if hasattr(select, 'kqueue'):
|
||||||
|
|
||||||
|
class KqueueSelector(_BaseSelectorImpl):
|
||||||
|
"""Kqueue-based selector."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(self.__class__, self).__init__()
|
||||||
|
self._kqueue = select.kqueue()
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return self._kqueue.fileno()
|
||||||
|
|
||||||
|
def register(self, fileobj, events, data=None):
|
||||||
|
key = super(self.__class__, self).register(fileobj, events, data)
|
||||||
|
if events & EVENT_READ:
|
||||||
|
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
|
||||||
|
select.KQ_EV_ADD)
|
||||||
|
self._kqueue.control([kev], 0, 0)
|
||||||
|
if events & EVENT_WRITE:
|
||||||
|
kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
|
||||||
|
select.KQ_EV_ADD)
|
||||||
|
self._kqueue.control([kev], 0, 0)
|
||||||
|
return key
|
||||||
|
|
||||||
|
def unregister(self, fileobj):
|
||||||
|
key = super(self.__class__, self).unregister(fileobj)
|
||||||
|
if key.events & EVENT_READ:
|
||||||
|
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
|
||||||
|
select.KQ_EV_DELETE)
|
||||||
|
try:
|
||||||
|
self._kqueue.control([kev], 0, 0)
|
||||||
|
except OSError:
|
||||||
|
# This can happen if the FD was closed since it
|
||||||
|
# was registered.
|
||||||
|
pass
|
||||||
|
if key.events & EVENT_WRITE:
|
||||||
|
kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
|
||||||
|
select.KQ_EV_DELETE)
|
||||||
|
try:
|
||||||
|
self._kqueue.control([kev], 0, 0)
|
||||||
|
except OSError:
|
||||||
|
# See comment above.
|
||||||
|
pass
|
||||||
|
return key
|
||||||
|
|
||||||
|
def select(self, timeout=None):
|
||||||
|
timeout = None if timeout is None else max(timeout, 0)
|
||||||
|
max_ev = len(self._fd_to_key)
|
||||||
|
ready = []
|
||||||
|
try:
|
||||||
|
kev_list = self._kqueue.control(None, max_ev, timeout)
|
||||||
|
except OSError as e:
|
||||||
|
if errno_from_exception(e) == errno.EAGAIN:
|
||||||
|
return ready
|
||||||
|
for kev in kev_list:
|
||||||
|
fd = kev.ident
|
||||||
|
flag = kev.filter
|
||||||
|
events = 0
|
||||||
|
if flag == select.KQ_FILTER_READ:
|
||||||
|
events |= EVENT_READ
|
||||||
|
if flag == select.KQ_FILTER_WRITE:
|
||||||
|
events |= EVENT_WRITE
|
||||||
|
|
||||||
|
key = self._key_from_fd(fd)
|
||||||
|
if key:
|
||||||
|
ready.append((key, events & key.events))
|
||||||
|
return ready
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
try:
|
||||||
|
self._kqueue.close()
|
||||||
|
finally:
|
||||||
|
super(self.__class__, self).close()
|
||||||
|
|
||||||
|
|
||||||
|
# Choose the best implementation: roughly, epoll|kqueue > poll > select.
|
||||||
|
# select() also can't accept a FD > FD_SETSIZE (usually around 1024)
|
||||||
|
if 'KqueueSelector' in globals():
|
||||||
|
DefaultSelector = KqueueSelector
|
||||||
|
elif 'EpollSelector' in globals():
|
||||||
|
DefaultSelector = EpollSelector
|
||||||
|
elif 'DevpollSelector' in globals():
|
||||||
|
DefaultSelector = DevpollSelector
|
||||||
|
elif 'PollSelector' in globals():
|
||||||
|
DefaultSelector = PollSelector
|
||||||
|
else:
|
||||||
|
DefaultSelector = SelectSelector
|
148
shadowsocks/server.py
Executable file → Normal file
148
shadowsocks/server.py
Executable file → Normal file
|
@ -1,143 +1,35 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
#
|
|
||||||
# Copyright 2015 clowwindy
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
||||||
# not use this file except in compliance with the License. You may obtain
|
|
||||||
# a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
# License for the specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, \
|
from __future__ import absolute_import, division, print_function, \
|
||||||
with_statement
|
with_statement
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import logging
|
import logging
|
||||||
import signal
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../'))
|
path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, \
|
sys.path.insert(0, path)
|
||||||
asyncdns, manager
|
from shadowsocks.eventloop import EventLoop
|
||||||
|
from shadowsocks.tcprelay import TcpRelay, TcpRelayServerHandler
|
||||||
|
from shadowsocks.asyncdns import DNSResolver
|
||||||
|
|
||||||
|
|
||||||
|
FORMATTER = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
LOGGING_LEVEL = logging.INFO
|
||||||
|
logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER)
|
||||||
|
|
||||||
|
LISTEN_ADDR = ('0.0.0.0', 9000)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
shell.check_python()
|
loop = EventLoop()
|
||||||
|
dns_resolver = DNSResolver()
|
||||||
config = shell.get_config(False)
|
relay = TcpRelay(TcpRelayServerHandler, LISTEN_ADDR,
|
||||||
|
dns_resolver=dns_resolver)
|
||||||
daemon.daemon_exec(config)
|
dns_resolver.add_to_loop(loop)
|
||||||
|
relay.add_to_loop(loop)
|
||||||
if config['port_password']:
|
loop.run()
|
||||||
if config['password']:
|
|
||||||
logging.warn('warning: port_password should not be used with '
|
|
||||||
'server_port and password. server_port and password '
|
|
||||||
'will be ignored')
|
|
||||||
else:
|
|
||||||
config['port_password'] = {}
|
|
||||||
server_port = config['server_port']
|
|
||||||
if type(server_port) == list:
|
|
||||||
for a_server_port in server_port:
|
|
||||||
config['port_password'][a_server_port] = config['password']
|
|
||||||
else:
|
|
||||||
config['port_password'][str(server_port)] = config['password']
|
|
||||||
|
|
||||||
if config.get('manager_address', 0):
|
|
||||||
logging.info('entering manager mode')
|
|
||||||
manager.run(config)
|
|
||||||
return
|
|
||||||
|
|
||||||
tcp_servers = []
|
|
||||||
udp_servers = []
|
|
||||||
|
|
||||||
if 'dns_server' in config: # allow override settings in resolv.conf
|
|
||||||
dns_resolver = asyncdns.DNSResolver(config['dns_server'],
|
|
||||||
config['prefer_ipv6'])
|
|
||||||
else:
|
|
||||||
dns_resolver = asyncdns.DNSResolver(prefer_ipv6=config['prefer_ipv6'])
|
|
||||||
|
|
||||||
port_password = config['port_password']
|
|
||||||
del config['port_password']
|
|
||||||
for port, password in port_password.items():
|
|
||||||
a_config = config.copy()
|
|
||||||
a_config['server_port'] = int(port)
|
|
||||||
a_config['password'] = password
|
|
||||||
logging.info("starting server at %s:%d" %
|
|
||||||
(a_config['server'], int(port)))
|
|
||||||
tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False))
|
|
||||||
udp_servers.append(udprelay.UDPRelay(a_config, dns_resolver, False))
|
|
||||||
|
|
||||||
def run_server():
|
|
||||||
def child_handler(signum, _):
|
|
||||||
logging.warn('received SIGQUIT, doing graceful shutting down..')
|
|
||||||
list(map(lambda s: s.close(next_tick=True),
|
|
||||||
tcp_servers + udp_servers))
|
|
||||||
signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM),
|
|
||||||
child_handler)
|
|
||||||
|
|
||||||
def int_handler(signum, _):
|
|
||||||
sys.exit(1)
|
|
||||||
signal.signal(signal.SIGINT, int_handler)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop = eventloop.EventLoop()
|
|
||||||
dns_resolver.add_to_loop(loop)
|
|
||||||
list(map(lambda s: s.add_to_loop(loop), tcp_servers + udp_servers))
|
|
||||||
|
|
||||||
daemon.set_user(config.get('user', None))
|
|
||||||
loop.run()
|
|
||||||
except Exception as e:
|
|
||||||
shell.print_exception(e)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if int(config['workers']) > 1:
|
|
||||||
if os.name == 'posix':
|
|
||||||
children = []
|
|
||||||
is_child = False
|
|
||||||
for i in range(0, int(config['workers'])):
|
|
||||||
r = os.fork()
|
|
||||||
if r == 0:
|
|
||||||
logging.info('worker started')
|
|
||||||
is_child = True
|
|
||||||
run_server()
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
children.append(r)
|
|
||||||
if not is_child:
|
|
||||||
def handler(signum, _):
|
|
||||||
for pid in children:
|
|
||||||
try:
|
|
||||||
os.kill(pid, signum)
|
|
||||||
os.waitpid(pid, 0)
|
|
||||||
except OSError: # child may already exited
|
|
||||||
pass
|
|
||||||
sys.exit()
|
|
||||||
signal.signal(signal.SIGTERM, handler)
|
|
||||||
signal.signal(signal.SIGQUIT, handler)
|
|
||||||
signal.signal(signal.SIGINT, handler)
|
|
||||||
|
|
||||||
# master
|
|
||||||
for a_tcp_server in tcp_servers:
|
|
||||||
a_tcp_server.close()
|
|
||||||
for a_udp_server in udp_servers:
|
|
||||||
a_udp_server.close()
|
|
||||||
dns_resolver.close()
|
|
||||||
|
|
||||||
for child in children:
|
|
||||||
os.waitpid(child, 0)
|
|
||||||
else:
|
|
||||||
logging.warn('worker is only available on Unix/Linux')
|
|
||||||
run_server()
|
|
||||||
else:
|
|
||||||
run_server()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -23,10 +23,6 @@ import json
|
||||||
import sys
|
import sys
|
||||||
import getopt
|
import getopt
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
|
||||||
|
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
from shadowsocks.common import to_bytes, to_str, IPNetwork
|
from shadowsocks.common import to_bytes, to_str, IPNetwork
|
||||||
from shadowsocks import encrypt
|
from shadowsocks import encrypt
|
||||||
|
|
||||||
|
@ -57,49 +53,6 @@ def print_exception(e):
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
def exception_handle(self_, err_msg=None, exit_code=None,
|
|
||||||
destroy=False, conn_err=False):
|
|
||||||
# self_: if function passes self as first arg
|
|
||||||
|
|
||||||
def process_exception(e, self=None):
|
|
||||||
print_exception(e)
|
|
||||||
if err_msg:
|
|
||||||
logging.error(err_msg)
|
|
||||||
if exit_code:
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if not self_:
|
|
||||||
return
|
|
||||||
|
|
||||||
if conn_err:
|
|
||||||
addr, port = self._client_address[0], self._client_address[1]
|
|
||||||
logging.error('%s when handling connection from %s:%d' %
|
|
||||||
(e, addr, port))
|
|
||||||
if self._config['verbose']:
|
|
||||||
traceback.print_exc()
|
|
||||||
if destroy:
|
|
||||||
self.destroy()
|
|
||||||
|
|
||||||
def decorator(func):
|
|
||||||
if self_:
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(self, *args, **kwargs):
|
|
||||||
try:
|
|
||||||
func(self, *args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
process_exception(e, self)
|
|
||||||
else:
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
try:
|
|
||||||
func(*args, **kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
process_exception(e)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def print_shadowsocks():
|
def print_shadowsocks():
|
||||||
version = ''
|
version = ''
|
||||||
try:
|
try:
|
||||||
|
@ -125,30 +78,13 @@ def check_config(config, is_local):
|
||||||
# no need to specify configuration for daemon stop
|
# no need to specify configuration for daemon stop
|
||||||
return
|
return
|
||||||
|
|
||||||
if is_local:
|
|
||||||
if config.get('server', None) is None:
|
|
||||||
logging.error('server addr not specified')
|
|
||||||
print_local_help()
|
|
||||||
sys.exit(2)
|
|
||||||
else:
|
|
||||||
config['server'] = to_str(config['server'])
|
|
||||||
else:
|
|
||||||
config['server'] = to_str(config.get('server', '0.0.0.0'))
|
|
||||||
try:
|
|
||||||
config['forbidden_ip'] = \
|
|
||||||
IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128'))
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(e)
|
|
||||||
sys.exit(2)
|
|
||||||
|
|
||||||
if is_local and not config.get('password', None):
|
if is_local and not config.get('password', None):
|
||||||
logging.error('password not specified')
|
logging.error('password not specified')
|
||||||
print_help(is_local)
|
print_help(is_local)
|
||||||
sys.exit(2)
|
sys.exit(2)
|
||||||
|
|
||||||
if not is_local and not config.get('password', None) \
|
if not is_local and not config.get('password', None) \
|
||||||
and not config.get('port_password', None) \
|
and not config.get('port_password', None):
|
||||||
and not config.get('manager_address'):
|
|
||||||
logging.error('password or port_password not specified')
|
logging.error('password or port_password not specified')
|
||||||
print_help(is_local)
|
print_help(is_local)
|
||||||
sys.exit(2)
|
sys.exit(2)
|
||||||
|
@ -194,14 +130,13 @@ def get_config(is_local):
|
||||||
logging.basicConfig(level=logging.INFO,
|
logging.basicConfig(level=logging.INFO,
|
||||||
format='%(levelname)-s: %(message)s')
|
format='%(levelname)-s: %(message)s')
|
||||||
if is_local:
|
if is_local:
|
||||||
shortopts = 'hd:s:b:p:k:l:m:c:t:vqa'
|
shortopts = 'hd:s:b:p:k:l:m:c:t:vq'
|
||||||
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'user=',
|
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'user=',
|
||||||
'version']
|
'version']
|
||||||
else:
|
else:
|
||||||
shortopts = 'hd:s:p:k:m:c:t:vqa'
|
shortopts = 'hd:s:p:k:m:c:t:vq'
|
||||||
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=',
|
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=',
|
||||||
'forbidden-ip=', 'user=', 'manager-address=', 'version',
|
'forbidden-ip=', 'user=', 'manager-address=', 'version']
|
||||||
'prefer-ipv6']
|
|
||||||
try:
|
try:
|
||||||
config_path = find_config()
|
config_path = find_config()
|
||||||
optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts)
|
optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts)
|
||||||
|
@ -239,8 +174,6 @@ def get_config(is_local):
|
||||||
v_count += 1
|
v_count += 1
|
||||||
# '-vv' turns on more verbose mode
|
# '-vv' turns on more verbose mode
|
||||||
config['verbose'] = v_count
|
config['verbose'] = v_count
|
||||||
elif key == '-a':
|
|
||||||
config['one_time_auth'] = True
|
|
||||||
elif key == '-t':
|
elif key == '-t':
|
||||||
config['timeout'] = int(value)
|
config['timeout'] = int(value)
|
||||||
elif key == '--fast-open':
|
elif key == '--fast-open':
|
||||||
|
@ -271,8 +204,6 @@ def get_config(is_local):
|
||||||
elif key == '-q':
|
elif key == '-q':
|
||||||
v_count -= 1
|
v_count -= 1
|
||||||
config['verbose'] = v_count
|
config['verbose'] = v_count
|
||||||
elif key == '--prefer-ipv6':
|
|
||||||
config['prefer_ipv6'] = True
|
|
||||||
except getopt.GetoptError as e:
|
except getopt.GetoptError as e:
|
||||||
print(e, file=sys.stderr)
|
print(e, file=sys.stderr)
|
||||||
print_help(is_local)
|
print_help(is_local)
|
||||||
|
@ -294,8 +225,21 @@ def get_config(is_local):
|
||||||
config['verbose'] = config.get('verbose', False)
|
config['verbose'] = config.get('verbose', False)
|
||||||
config['local_address'] = to_str(config.get('local_address', '127.0.0.1'))
|
config['local_address'] = to_str(config.get('local_address', '127.0.0.1'))
|
||||||
config['local_port'] = config.get('local_port', 1080)
|
config['local_port'] = config.get('local_port', 1080)
|
||||||
config['one_time_auth'] = config.get('one_time_auth', False)
|
if is_local:
|
||||||
config['prefer_ipv6'] = config.get('prefer_ipv6', False)
|
if config.get('server', None) is None:
|
||||||
|
logging.error('server addr not specified')
|
||||||
|
print_local_help()
|
||||||
|
sys.exit(2)
|
||||||
|
else:
|
||||||
|
config['server'] = to_str(config['server'])
|
||||||
|
else:
|
||||||
|
config['server'] = to_str(config.get('server', '0.0.0.0'))
|
||||||
|
try:
|
||||||
|
config['forbidden_ip'] = \
|
||||||
|
IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128'))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(e)
|
||||||
|
sys.exit(2)
|
||||||
config['server_port'] = config.get('server_port', 8388)
|
config['server_port'] = config.get('server_port', 8388)
|
||||||
|
|
||||||
logging.getLogger('').handlers = []
|
logging.getLogger('').handlers = []
|
||||||
|
@ -342,7 +286,6 @@ Proxy options:
|
||||||
-k PASSWORD password
|
-k PASSWORD password
|
||||||
-m METHOD encryption method, default: aes-256-cfb
|
-m METHOD encryption method, default: aes-256-cfb
|
||||||
-t TIMEOUT timeout in seconds, default: 300
|
-t TIMEOUT timeout in seconds, default: 300
|
||||||
-a ONE_TIME_AUTH one time auth
|
|
||||||
--fast-open use TCP_FASTOPEN, requires Linux 3.7+
|
--fast-open use TCP_FASTOPEN, requires Linux 3.7+
|
||||||
|
|
||||||
General options:
|
General options:
|
||||||
|
@ -372,12 +315,10 @@ Proxy options:
|
||||||
-k PASSWORD password
|
-k PASSWORD password
|
||||||
-m METHOD encryption method, default: aes-256-cfb
|
-m METHOD encryption method, default: aes-256-cfb
|
||||||
-t TIMEOUT timeout in seconds, default: 300
|
-t TIMEOUT timeout in seconds, default: 300
|
||||||
-a ONE_TIME_AUTH one time auth
|
|
||||||
--fast-open use TCP_FASTOPEN, requires Linux 3.7+
|
--fast-open use TCP_FASTOPEN, requires Linux 3.7+
|
||||||
--workers WORKERS number of workers, available on Unix/Linux
|
--workers WORKERS number of workers, available on Unix/Linux
|
||||||
--forbidden-ip IPLIST comma seperated IP list forbidden to connect
|
--forbidden-ip IPLIST comma seperated IP list forbidden to connect
|
||||||
--manager-address ADDR optional server manager UDP address, see wiki
|
--manager-address ADDR optional server manager UDP address, see wiki
|
||||||
--prefer-ipv6 resolve ipv6 address first
|
|
||||||
|
|
||||||
General options:
|
General options:
|
||||||
-h, --help show this help message and exit
|
-h, --help show this help message and exit
|
||||||
|
|
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue