Merge 710c93291e
into ad39d957d7
This commit is contained in:
commit
0fdd2fec83
12 changed files with 1183 additions and 1385 deletions
|
@ -29,7 +29,7 @@ from shadowsocks import common, lru_cache, eventloop, shell
|
|||
|
||||
CACHE_SWEEP_INTERVAL = 30
|
||||
|
||||
VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d\-_]{1,63}(?<!-)$", re.IGNORECASE)
|
||||
VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
|
||||
|
||||
common.patch_socket()
|
||||
|
||||
|
@ -242,13 +242,13 @@ class DNSResponse(object):
|
|||
return '%s: %s' % (self.hostname, str(self.answers))
|
||||
|
||||
|
||||
STATUS_FIRST = 0
|
||||
STATUS_SECOND = 1
|
||||
STATUS_IPV4 = 0
|
||||
STATUS_IPV6 = 1
|
||||
|
||||
|
||||
class DNSResolver(object):
|
||||
|
||||
def __init__(self, server_list=None, prefer_ipv6=False):
|
||||
def __init__(self):
|
||||
self._loop = None
|
||||
self._hosts = {}
|
||||
self._hostname_status = {}
|
||||
|
@ -256,15 +256,8 @@ class DNSResolver(object):
|
|||
self._cb_to_hostname = {}
|
||||
self._cache = lru_cache.LRUCache(timeout=300)
|
||||
self._sock = None
|
||||
if server_list is None:
|
||||
self._servers = None
|
||||
self._parse_resolv()
|
||||
else:
|
||||
self._servers = server_list
|
||||
if prefer_ipv6:
|
||||
self._QTYPES = [QTYPE_AAAA, QTYPE_A]
|
||||
else:
|
||||
self._QTYPES = [QTYPE_A, QTYPE_AAAA]
|
||||
self._servers = None
|
||||
self._parse_resolv()
|
||||
self._parse_hosts()
|
||||
# TODO monitor hosts change and reload hosts
|
||||
# TODO parse /etc/gai.conf and follow its rules
|
||||
|
@ -276,18 +269,15 @@ class DNSResolver(object):
|
|||
content = f.readlines()
|
||||
for line in content:
|
||||
line = line.strip()
|
||||
if not (line and line.startswith(b'nameserver')):
|
||||
continue
|
||||
|
||||
parts = line.split()
|
||||
if len(parts) < 2:
|
||||
continue
|
||||
|
||||
server = parts[1]
|
||||
if common.is_ip(server) == socket.AF_INET:
|
||||
if type(server) != str:
|
||||
server = server.decode('utf8')
|
||||
self._servers.append(server)
|
||||
if line:
|
||||
if line.startswith(b'nameserver'):
|
||||
parts = line.split()
|
||||
if len(parts) >= 2:
|
||||
server = parts[1]
|
||||
if common.is_ip(server) == socket.AF_INET:
|
||||
if type(server) != str:
|
||||
server = server.decode('utf8')
|
||||
self._servers.append(server)
|
||||
except IOError:
|
||||
pass
|
||||
if not self._servers:
|
||||
|
@ -302,17 +292,13 @@ class DNSResolver(object):
|
|||
for line in f.readlines():
|
||||
line = line.strip()
|
||||
parts = line.split()
|
||||
if len(parts) < 2:
|
||||
continue
|
||||
|
||||
ip = parts[0]
|
||||
if not common.is_ip(ip):
|
||||
continue
|
||||
|
||||
for i in range(1, len(parts)):
|
||||
hostname = parts[i]
|
||||
if hostname:
|
||||
self._hosts[hostname] = ip
|
||||
if len(parts) >= 2:
|
||||
ip = parts[0]
|
||||
if common.is_ip(ip):
|
||||
for i in range(1, len(parts)):
|
||||
hostname = parts[i]
|
||||
if hostname:
|
||||
self._hosts[hostname] = ip
|
||||
except IOError:
|
||||
self._hosts['localhost'] = '127.0.0.1'
|
||||
|
||||
|
@ -352,22 +338,21 @@ class DNSResolver(object):
|
|||
answer[2] == QCLASS_IN:
|
||||
ip = answer[0]
|
||||
break
|
||||
if not ip and self._hostname_status.get(hostname, STATUS_SECOND) \
|
||||
== STATUS_FIRST:
|
||||
self._hostname_status[hostname] = STATUS_SECOND
|
||||
self._send_req(hostname, self._QTYPES[1])
|
||||
if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \
|
||||
== STATUS_IPV4:
|
||||
self._hostname_status[hostname] = STATUS_IPV6
|
||||
self._send_req(hostname, QTYPE_AAAA)
|
||||
else:
|
||||
if ip:
|
||||
self._cache[hostname] = ip
|
||||
self._call_callback(hostname, ip)
|
||||
elif self._hostname_status.get(hostname, None) \
|
||||
== STATUS_SECOND:
|
||||
elif self._hostname_status.get(hostname, None) == STATUS_IPV6:
|
||||
for question in response.questions:
|
||||
if question[1] == self._QTYPES[1]:
|
||||
if question[1] == QTYPE_AAAA:
|
||||
self._call_callback(hostname, None)
|
||||
break
|
||||
|
||||
def handle_event(self, sock, fd, event):
|
||||
def handle_event(self, sock, event):
|
||||
if sock != self._sock:
|
||||
return
|
||||
if event & eventloop.POLL_ERR:
|
||||
|
@ -429,14 +414,14 @@ class DNSResolver(object):
|
|||
return
|
||||
arr = self._hostname_to_cb.get(hostname, None)
|
||||
if not arr:
|
||||
self._hostname_status[hostname] = STATUS_FIRST
|
||||
self._send_req(hostname, self._QTYPES[0])
|
||||
self._hostname_status[hostname] = STATUS_IPV4
|
||||
self._send_req(hostname, QTYPE_A)
|
||||
self._hostname_to_cb[hostname] = [callback]
|
||||
self._cb_to_hostname[callback] = hostname
|
||||
else:
|
||||
arr.append(callback)
|
||||
# TODO send again only if waited too long
|
||||
self._send_req(hostname, self._QTYPES[0])
|
||||
self._send_req(hostname, QTYPE_A)
|
||||
|
||||
def close(self):
|
||||
if self._sock:
|
||||
|
|
33
shadowsocks/client.py
Normal file
33
shadowsocks/client.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, path)
|
||||
from shadowsocks.eventloop import EventLoop
|
||||
from shadowsocks.tcprelay import TcpRelay, TcpRelayClientHanler
|
||||
|
||||
|
||||
FORMATTER = '%(asctime)s - %(levelname)s - %(message)s'
|
||||
LOGGING_LEVEL = logging.INFO
|
||||
logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER)
|
||||
|
||||
LISTEN_ADDR = ('127.0.0.1', 7070)
|
||||
REMOTE_ADDR = ('104.129.176.124', 9000)
|
||||
|
||||
|
||||
def main():
|
||||
loop = EventLoop()
|
||||
relay = TcpRelay(TcpRelayClientHanler, LISTEN_ADDR, REMOTE_ADDR)
|
||||
relay.add_to_loop(loop)
|
||||
loop.run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -21,25 +21,6 @@ from __future__ import absolute_import, division, print_function, \
|
|||
import socket
|
||||
import struct
|
||||
import logging
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
|
||||
ONETIMEAUTH_BYTES = 10
|
||||
ONETIMEAUTH_CHUNK_BYTES = 12
|
||||
ONETIMEAUTH_CHUNK_DATA_LEN = 2
|
||||
|
||||
|
||||
def sha1_hmac(secret, data):
|
||||
return hmac.new(secret, data, hashlib.sha1).digest()
|
||||
|
||||
|
||||
def onetimeauth_verify(_hash, data, key):
|
||||
return _hash == sha1_hmac(key, data)[:ONETIMEAUTH_BYTES]
|
||||
|
||||
|
||||
def onetimeauth_gen(data, key):
|
||||
return sha1_hmac(key, data)[:ONETIMEAUTH_BYTES]
|
||||
|
||||
|
||||
def compat_ord(s):
|
||||
|
@ -137,11 +118,9 @@ def patch_socket():
|
|||
patch_socket()
|
||||
|
||||
|
||||
ADDRTYPE_IPV4 = 0x01
|
||||
ADDRTYPE_IPV6 = 0x04
|
||||
ADDRTYPE_HOST = 0x03
|
||||
ADDRTYPE_AUTH = 0x10
|
||||
ADDRTYPE_MASK = 0xF
|
||||
ADDRTYPE_IPV4 = 1
|
||||
ADDRTYPE_IPV6 = 4
|
||||
ADDRTYPE_HOST = 3
|
||||
|
||||
|
||||
def pack_addr(address):
|
||||
|
@ -165,17 +144,17 @@ def parse_header(data):
|
|||
dest_addr = None
|
||||
dest_port = None
|
||||
header_length = 0
|
||||
if addrtype & ADDRTYPE_MASK == ADDRTYPE_IPV4:
|
||||
if addrtype == ADDRTYPE_IPV4:
|
||||
if len(data) >= 7:
|
||||
dest_addr = socket.inet_ntoa(data[1:5])
|
||||
dest_port = struct.unpack('>H', data[5:7])[0]
|
||||
header_length = 7
|
||||
else:
|
||||
logging.warn('header is too short')
|
||||
elif addrtype & ADDRTYPE_MASK == ADDRTYPE_HOST:
|
||||
elif addrtype == ADDRTYPE_HOST:
|
||||
if len(data) > 2:
|
||||
addrlen = ord(data[1])
|
||||
if len(data) >= 4 + addrlen:
|
||||
if len(data) >= 2 + addrlen:
|
||||
dest_addr = data[2:2 + addrlen]
|
||||
dest_port = struct.unpack('>H', data[2 + addrlen:4 +
|
||||
addrlen])[0]
|
||||
|
@ -184,7 +163,7 @@ def parse_header(data):
|
|||
logging.warn('header is too short')
|
||||
else:
|
||||
logging.warn('header is too short')
|
||||
elif addrtype & ADDRTYPE_MASK == ADDRTYPE_IPV6:
|
||||
elif addrtype == ADDRTYPE_IPV6:
|
||||
if len(data) >= 19:
|
||||
dest_addr = socket.inet_ntop(socket.AF_INET6, data[1:17])
|
||||
dest_port = struct.unpack('>H', data[17:19])[0]
|
||||
|
|
|
@ -32,7 +32,7 @@ buf_size = 2048
|
|||
|
||||
|
||||
def load_openssl():
|
||||
global loaded, libcrypto, buf, ctx_cleanup
|
||||
global loaded, libcrypto, buf
|
||||
|
||||
libcrypto = util.find_library(('crypto', 'eay32'),
|
||||
'EVP_get_cipherbyname',
|
||||
|
@ -49,12 +49,7 @@ def load_openssl():
|
|||
libcrypto.EVP_CipherUpdate.argtypes = (c_void_p, c_void_p, c_void_p,
|
||||
c_char_p, c_int)
|
||||
|
||||
try:
|
||||
libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,)
|
||||
ctx_cleanup = libcrypto.EVP_CIPHER_CTX_cleanup
|
||||
except AttributeError:
|
||||
libcrypto.EVP_CIPHER_CTX_reset.argtypes = (c_void_p,)
|
||||
ctx_cleanup = libcrypto.EVP_CIPHER_CTX_reset
|
||||
libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,)
|
||||
libcrypto.EVP_CIPHER_CTX_free.argtypes = (c_void_p,)
|
||||
if hasattr(libcrypto, 'OpenSSL_add_all_ciphers'):
|
||||
libcrypto.OpenSSL_add_all_ciphers()
|
||||
|
@ -113,7 +108,7 @@ class OpenSSLCrypto(object):
|
|||
|
||||
def clean(self):
|
||||
if self._ctx:
|
||||
ctx_cleanup(self._ctx)
|
||||
libcrypto.EVP_CIPHER_CTX_cleanup(self._ctx)
|
||||
libcrypto.EVP_CIPHER_CTX_free(self._ctx)
|
||||
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
from ctypes import c_char_p, c_int, c_ulonglong, byref, c_ulong, \
|
||||
from ctypes import c_char_p, c_int, c_ulonglong, byref, \
|
||||
create_string_buffer, c_void_p
|
||||
|
||||
from shadowsocks.crypto import util
|
||||
|
@ -29,7 +29,7 @@ loaded = False
|
|||
|
||||
buf_size = 2048
|
||||
|
||||
# for salsa20 and chacha20 and chacha20-ietf
|
||||
# for salsa20 and chacha20
|
||||
BLOCK_SIZE = 64
|
||||
|
||||
|
||||
|
@ -51,13 +51,6 @@ def load_libsodium():
|
|||
c_ulonglong,
|
||||
c_char_p, c_ulonglong,
|
||||
c_char_p)
|
||||
libsodium.crypto_stream_chacha20_ietf_xor_ic.restype = c_int
|
||||
libsodium.crypto_stream_chacha20_ietf_xor_ic.argtypes = (c_void_p,
|
||||
c_char_p,
|
||||
c_ulonglong,
|
||||
c_char_p,
|
||||
c_ulong,
|
||||
c_char_p)
|
||||
|
||||
buf = create_string_buffer(buf_size)
|
||||
loaded = True
|
||||
|
@ -75,8 +68,6 @@ class SodiumCrypto(object):
|
|||
self.cipher = libsodium.crypto_stream_salsa20_xor_ic
|
||||
elif cipher_name == 'chacha20':
|
||||
self.cipher = libsodium.crypto_stream_chacha20_xor_ic
|
||||
elif cipher_name == 'chacha20-ietf':
|
||||
self.cipher = libsodium.crypto_stream_chacha20_ietf_xor_ic
|
||||
else:
|
||||
raise Exception('Unknown cipher')
|
||||
# byte counter, not block counter
|
||||
|
@ -106,7 +97,6 @@ class SodiumCrypto(object):
|
|||
ciphers = {
|
||||
'salsa20': (32, 8, SodiumCrypto),
|
||||
'chacha20': (32, 8, SodiumCrypto),
|
||||
'chacha20-ietf': (32, 12, SodiumCrypto),
|
||||
}
|
||||
|
||||
|
||||
|
@ -125,15 +115,6 @@ def test_chacha20():
|
|||
util.run_cipher(cipher, decipher)
|
||||
|
||||
|
||||
def test_chacha20_ietf():
|
||||
|
||||
cipher = SodiumCrypto('chacha20-ietf', b'k' * 32, b'i' * 16, 1)
|
||||
decipher = SodiumCrypto('chacha20-ietf', b'k' * 32, b'i' * 16, 0)
|
||||
|
||||
util.run_cipher(cipher, decipher)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_chacha20()
|
||||
test_salsa20()
|
||||
test_chacha20_ietf()
|
||||
|
|
|
@ -69,18 +69,17 @@ def EVP_BytesToKey(password, key_len, iv_len):
|
|||
|
||||
|
||||
class Encryptor(object):
|
||||
def __init__(self, password, method):
|
||||
self.password = password
|
||||
self.key = None
|
||||
def __init__(self, key, method):
|
||||
self.key = key
|
||||
self.method = method
|
||||
self.iv = None
|
||||
self.iv_sent = False
|
||||
self.cipher_iv = b''
|
||||
self.decipher = None
|
||||
self.decipher_iv = None
|
||||
method = method.lower()
|
||||
self._method_info = self.get_method_info(method)
|
||||
if self._method_info:
|
||||
self.cipher = self.get_cipher(password, method, 1,
|
||||
self.cipher = self.get_cipher(key, method, 1,
|
||||
random_string(self._method_info[1]))
|
||||
else:
|
||||
logging.error('method %s not supported' % method)
|
||||
|
@ -102,7 +101,7 @@ class Encryptor(object):
|
|||
else:
|
||||
# key_length == 0 indicates we should use the key directly
|
||||
key, iv = password, b''
|
||||
self.key = key
|
||||
|
||||
iv = iv[:m[1]]
|
||||
if op == 1:
|
||||
# this iv is for cipher not decipher
|
||||
|
@ -124,8 +123,7 @@ class Encryptor(object):
|
|||
if self.decipher is None:
|
||||
decipher_iv_len = self._method_info[1]
|
||||
decipher_iv = buf[:decipher_iv_len]
|
||||
self.decipher_iv = decipher_iv
|
||||
self.decipher = self.get_cipher(self.password, self.method, 0,
|
||||
self.decipher = self.get_cipher(self.key, self.method, 0,
|
||||
iv=decipher_iv)
|
||||
buf = buf[decipher_iv_len:]
|
||||
if len(buf) == 0:
|
||||
|
@ -133,47 +131,10 @@ class Encryptor(object):
|
|||
return self.decipher.update(buf)
|
||||
|
||||
|
||||
def gen_key_iv(password, method):
|
||||
method = method.lower()
|
||||
(key_len, iv_len, m) = method_supported[method]
|
||||
key = None
|
||||
if key_len > 0:
|
||||
key, _ = EVP_BytesToKey(password, key_len, iv_len)
|
||||
else:
|
||||
key = password
|
||||
iv = random_string(iv_len)
|
||||
return key, iv, m
|
||||
|
||||
|
||||
def encrypt_all_m(key, iv, m, method, data):
|
||||
result = []
|
||||
result.append(iv)
|
||||
cipher = m(method, key, iv, 1)
|
||||
result.append(cipher.update(data))
|
||||
return b''.join(result)
|
||||
|
||||
|
||||
def dencrypt_all(password, method, data):
|
||||
result = []
|
||||
method = method.lower()
|
||||
(key_len, iv_len, m) = method_supported[method]
|
||||
key = None
|
||||
if key_len > 0:
|
||||
key, _ = EVP_BytesToKey(password, key_len, iv_len)
|
||||
else:
|
||||
key = password
|
||||
iv = data[:iv_len]
|
||||
data = data[iv_len:]
|
||||
cipher = m(method, key, iv, 0)
|
||||
result.append(cipher.update(data))
|
||||
return b''.join(result), key, iv
|
||||
|
||||
|
||||
def encrypt_all(password, method, op, data):
|
||||
result = []
|
||||
method = method.lower()
|
||||
(key_len, iv_len, m) = method_supported[method]
|
||||
key = None
|
||||
if key_len > 0:
|
||||
key, _ = EVP_BytesToKey(password, key_len, iv_len)
|
||||
else:
|
||||
|
@ -221,18 +182,6 @@ def test_encrypt_all():
|
|||
assert plain == plain2
|
||||
|
||||
|
||||
def test_encrypt_all_m():
|
||||
from os import urandom
|
||||
plain = urandom(10240)
|
||||
for method in CIPHERS_TO_TEST:
|
||||
logging.warn(method)
|
||||
key, iv, m = gen_key_iv(b'key', method)
|
||||
cipher = encrypt_all_m(key, iv, m, method, plain)
|
||||
plain2, key, iv = dencrypt_all(b'key', method, cipher)
|
||||
assert plain == plain2
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_encrypt_all()
|
||||
test_encryptor()
|
||||
test_encrypt_all_m()
|
||||
|
|
|
@ -1,181 +1,53 @@
|
|||
#!/usr/bin/python
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2013-2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# from ssloop
|
||||
# https://github.com/clowwindy/ssloop
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
import socket
|
||||
import select
|
||||
import traceback
|
||||
import errno
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
from shadowsocks import shell
|
||||
from shadowsocks import selectors
|
||||
from shadowsocks.selectors import (EVENT_READ, EVENT_WRITE, EVENT_ERROR,
|
||||
errno_from_exception)
|
||||
|
||||
|
||||
__all__ = ['EventLoop', 'POLL_NULL', 'POLL_IN', 'POLL_OUT', 'POLL_ERR',
|
||||
'POLL_HUP', 'POLL_NVAL', 'EVENT_NAMES']
|
||||
POLL_IN = EVENT_READ
|
||||
POLL_OUT = EVENT_WRITE
|
||||
POLL_ERR = EVENT_ERROR
|
||||
|
||||
POLL_NULL = 0x00
|
||||
POLL_IN = 0x01
|
||||
POLL_OUT = 0x04
|
||||
POLL_ERR = 0x08
|
||||
POLL_HUP = 0x10
|
||||
POLL_NVAL = 0x20
|
||||
|
||||
|
||||
EVENT_NAMES = {
|
||||
POLL_NULL: 'POLL_NULL',
|
||||
POLL_IN: 'POLL_IN',
|
||||
POLL_OUT: 'POLL_OUT',
|
||||
POLL_ERR: 'POLL_ERR',
|
||||
POLL_HUP: 'POLL_HUP',
|
||||
POLL_NVAL: 'POLL_NVAL',
|
||||
}
|
||||
|
||||
# we check timeouts every TIMEOUT_PRECISION seconds
|
||||
TIMEOUT_PRECISION = 10
|
||||
|
||||
|
||||
class KqueueLoop(object):
|
||||
|
||||
MAX_EVENTS = 1024
|
||||
class EventLoop:
|
||||
|
||||
def __init__(self):
|
||||
self._kqueue = select.kqueue()
|
||||
self._fds = {}
|
||||
|
||||
def _control(self, fd, mode, flags):
|
||||
events = []
|
||||
if mode & POLL_IN:
|
||||
events.append(select.kevent(fd, select.KQ_FILTER_READ, flags))
|
||||
if mode & POLL_OUT:
|
||||
events.append(select.kevent(fd, select.KQ_FILTER_WRITE, flags))
|
||||
for e in events:
|
||||
self._kqueue.control([e], 0)
|
||||
|
||||
def poll(self, timeout):
|
||||
if timeout < 0:
|
||||
timeout = None # kqueue behaviour
|
||||
events = self._kqueue.control(None, KqueueLoop.MAX_EVENTS, timeout)
|
||||
results = defaultdict(lambda: POLL_NULL)
|
||||
for e in events:
|
||||
fd = e.ident
|
||||
if e.filter == select.KQ_FILTER_READ:
|
||||
results[fd] |= POLL_IN
|
||||
elif e.filter == select.KQ_FILTER_WRITE:
|
||||
results[fd] |= POLL_OUT
|
||||
return results.items()
|
||||
|
||||
def register(self, fd, mode):
|
||||
self._fds[fd] = mode
|
||||
self._control(fd, mode, select.KQ_EV_ADD)
|
||||
|
||||
def unregister(self, fd):
|
||||
self._control(fd, self._fds[fd], select.KQ_EV_DELETE)
|
||||
del self._fds[fd]
|
||||
|
||||
def modify(self, fd, mode):
|
||||
self.unregister(fd)
|
||||
self.register(fd, mode)
|
||||
|
||||
def close(self):
|
||||
self._kqueue.close()
|
||||
|
||||
|
||||
class SelectLoop(object):
|
||||
|
||||
def __init__(self):
|
||||
self._r_list = set()
|
||||
self._w_list = set()
|
||||
self._x_list = set()
|
||||
|
||||
def poll(self, timeout):
|
||||
r, w, x = select.select(self._r_list, self._w_list, self._x_list,
|
||||
timeout)
|
||||
results = defaultdict(lambda: POLL_NULL)
|
||||
for p in [(r, POLL_IN), (w, POLL_OUT), (x, POLL_ERR)]:
|
||||
for fd in p[0]:
|
||||
results[fd] |= p[1]
|
||||
return results.items()
|
||||
|
||||
def register(self, fd, mode):
|
||||
if mode & POLL_IN:
|
||||
self._r_list.add(fd)
|
||||
if mode & POLL_OUT:
|
||||
self._w_list.add(fd)
|
||||
if mode & POLL_ERR:
|
||||
self._x_list.add(fd)
|
||||
|
||||
def unregister(self, fd):
|
||||
if fd in self._r_list:
|
||||
self._r_list.remove(fd)
|
||||
if fd in self._w_list:
|
||||
self._w_list.remove(fd)
|
||||
if fd in self._x_list:
|
||||
self._x_list.remove(fd)
|
||||
|
||||
def modify(self, fd, mode):
|
||||
self.unregister(fd)
|
||||
self.register(fd, mode)
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
class EventLoop(object):
|
||||
def __init__(self):
|
||||
if hasattr(select, 'epoll'):
|
||||
self._impl = select.epoll()
|
||||
model = 'epoll'
|
||||
elif hasattr(select, 'kqueue'):
|
||||
self._impl = KqueueLoop()
|
||||
model = 'kqueue'
|
||||
elif hasattr(select, 'select'):
|
||||
self._impl = SelectLoop()
|
||||
model = 'select'
|
||||
else:
|
||||
raise Exception('can not find any available functions in select '
|
||||
'package')
|
||||
self._fdmap = {} # (f, handler)
|
||||
self._selector = selectors.DefaultSelector()
|
||||
self._stopping = False
|
||||
self._last_time = time.time()
|
||||
self._periodic_callbacks = []
|
||||
self._stopping = False
|
||||
logging.debug('using event model: %s', model)
|
||||
|
||||
def poll(self, timeout=None):
|
||||
events = self._impl.poll(timeout)
|
||||
return [(self._fdmap[fd][0], fd, event) for fd, event in events]
|
||||
return self._selector.select(timeout)
|
||||
|
||||
def add(self, f, mode, handler):
|
||||
fd = f.fileno()
|
||||
self._fdmap[fd] = (f, handler)
|
||||
self._impl.register(fd, mode)
|
||||
def add(self, sock, events, data):
|
||||
events |= selectors.EVENT_ERROR
|
||||
return self._selector.register(sock, events, data)
|
||||
|
||||
def remove(self, f):
|
||||
fd = f.fileno()
|
||||
del self._fdmap[fd]
|
||||
self._impl.unregister(fd)
|
||||
def remove(self, sock):
|
||||
try:
|
||||
return self._selector.unregister(sock)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def modify(self, sock, events, data):
|
||||
events |= selectors.EVENT_ERROR
|
||||
try:
|
||||
key = self._selector.modify(sock, events, data)
|
||||
except KeyError:
|
||||
key = self.add(sock, events, data)
|
||||
return key
|
||||
|
||||
def add_periodic(self, callback):
|
||||
self._periodic_callbacks.append(callback)
|
||||
|
@ -183,69 +55,57 @@ class EventLoop(object):
|
|||
def remove_periodic(self, callback):
|
||||
self._periodic_callbacks.remove(callback)
|
||||
|
||||
def modify(self, f, mode):
|
||||
fd = f.fileno()
|
||||
self._impl.modify(fd, mode)
|
||||
|
||||
def stop(self):
|
||||
self._stopping = True
|
||||
def fd_count(self):
|
||||
return len(self._selector.get_map())
|
||||
|
||||
def run(self):
|
||||
events = []
|
||||
logging.debug('Starting event loop')
|
||||
|
||||
while not self._stopping:
|
||||
asap = False
|
||||
try:
|
||||
events = self.poll(TIMEOUT_PRECISION)
|
||||
events = self.poll(timeout=TIMEOUT_PRECISION)
|
||||
except (OSError, IOError) as e:
|
||||
if errno_from_exception(e) in (errno.EPIPE, errno.EINTR):
|
||||
# EPIPE: Happens when the client closes the connection
|
||||
# EINTR: Happens when received a signal
|
||||
# handles them as soon as possible
|
||||
asap = True
|
||||
logging.debug('poll:%s', e)
|
||||
logging.debug('poll: %s', e)
|
||||
else:
|
||||
logging.error('poll:%s', e)
|
||||
logging.error('poll: %s', e)
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
for sock, fd, event in events:
|
||||
handler = self._fdmap.get(fd, None)
|
||||
if handler is not None:
|
||||
handler = handler[1]
|
||||
try:
|
||||
handler.handle_event(sock, fd, event)
|
||||
except (OSError, IOError) as e:
|
||||
shell.print_exception(e)
|
||||
for key, event in events:
|
||||
if type(key.data) == tuple:
|
||||
handler = key.data[0]
|
||||
args = key.data[1:]
|
||||
else:
|
||||
handler = key.data
|
||||
args = ()
|
||||
|
||||
sock = key.fileobj
|
||||
if hasattr(handler, 'handle_event'):
|
||||
handler = handler.handle_event
|
||||
|
||||
try:
|
||||
handler(sock, event, *args)
|
||||
except Exception as e:
|
||||
logging.debug(e)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
now = time.time()
|
||||
if asap or now - self._last_time >= TIMEOUT_PRECISION:
|
||||
for callback in self._periodic_callbacks:
|
||||
callback()
|
||||
self._last_time = now
|
||||
|
||||
def __del__(self):
|
||||
self._impl.close()
|
||||
logging.debug('Got {} fds registered'.format(self.fd_count()))
|
||||
|
||||
logging.debug('Stopping event loop')
|
||||
self._selector.close()
|
||||
|
||||
# from tornado
|
||||
def errno_from_exception(e):
|
||||
"""Provides the errno from an Exception object.
|
||||
|
||||
There are cases that the errno attribute was not set so we pull
|
||||
the errno out of the args but if someone instatiates an Exception
|
||||
without any args you will get a tuple error. So this function
|
||||
abstracts all that behavior to give you a safe way to get the
|
||||
errno.
|
||||
"""
|
||||
|
||||
if hasattr(e, 'errno'):
|
||||
return e.errno
|
||||
elif e.args:
|
||||
return e.args[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
# from tornado
|
||||
def get_sock_error(sock):
|
||||
error_number = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
|
||||
return socket.error(error_number, os.strerror(error_number))
|
||||
def stop(self):
|
||||
self._stopping = True
|
||||
|
|
|
@ -87,8 +87,8 @@ class LRUCache(collections.MutableMapping):
|
|||
if value not in self._closed_values:
|
||||
self.close_callback(value)
|
||||
self._closed_values.add(value)
|
||||
self._last_visits.popleft()
|
||||
for key in self._time_to_keys[least]:
|
||||
self._last_visits.popleft()
|
||||
if key in self._store:
|
||||
if now - self._keys_to_last_time[key] > self.timeout:
|
||||
del self._store[key]
|
||||
|
|
652
shadowsocks/selectors.py
Normal file
652
shadowsocks/selectors.py
Normal file
|
@ -0,0 +1,652 @@
|
|||
"""Selectors module.
|
||||
|
||||
This module allows high-level and efficient I/O multiplexing, built upon the
|
||||
`select` module primitives.
|
||||
"""
|
||||
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import namedtuple, Mapping
|
||||
import errno
|
||||
import math
|
||||
import os
|
||||
import select
|
||||
import socket
|
||||
import sys
|
||||
|
||||
|
||||
# generic events, that must be mapped to implementation-specific ones
|
||||
EVENT_READ = 0x001
|
||||
EVENT_WRITE = 0x004
|
||||
EVENT_ERROR = 0x018
|
||||
|
||||
|
||||
def errno_from_exception(e):
|
||||
"""Provides the errno from an Exception object.
|
||||
|
||||
There are cases that the errno attribute was not set so we pull
|
||||
the errno out of the args but if someone instatiates an Exception
|
||||
without any args you will get a tuple error. So this function
|
||||
abstracts all that behavior to give you a safe way to get the
|
||||
errno.
|
||||
"""
|
||||
|
||||
if hasattr(e, 'errno'):
|
||||
return e.errno
|
||||
elif e.args:
|
||||
return e.args[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_sock_error(sock):
|
||||
if not sock:
|
||||
return
|
||||
error_number = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
|
||||
return socket.error(error_number, os.strerror(error_number))
|
||||
|
||||
|
||||
def _fileobj_to_fd(fileobj):
|
||||
"""Return a file descriptor from a file object.
|
||||
|
||||
Parameters:
|
||||
fileobj -- file object or file descriptor
|
||||
|
||||
Returns:
|
||||
corresponding file descriptor
|
||||
|
||||
Raises:
|
||||
ValueError if the object is invalid
|
||||
"""
|
||||
if isinstance(fileobj, int):
|
||||
fd = fileobj
|
||||
else:
|
||||
try:
|
||||
fd = int(fileobj.fileno())
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
raise ValueError("Invalid file object: "
|
||||
"{!r}".format(fileobj))
|
||||
if fd < 0:
|
||||
raise ValueError("Invalid file descriptor: {}".format(fd))
|
||||
return fd
|
||||
|
||||
|
||||
SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data'])
|
||||
"""Object used to associate a file object to its backing file descriptor,
|
||||
selected event mask and attached data."""
|
||||
|
||||
|
||||
class _SelectorMapping(Mapping):
|
||||
"""Mapping of file objects to selector keys."""
|
||||
|
||||
def __init__(self, selector):
|
||||
self._selector = selector
|
||||
|
||||
def __len__(self):
|
||||
return len(self._selector._fd_to_key)
|
||||
|
||||
def __getitem__(self, fileobj):
|
||||
try:
|
||||
fd = self._selector._fileobj_lookup(fileobj)
|
||||
return self._selector._fd_to_key[fd]
|
||||
except KeyError:
|
||||
raise KeyError("{!r} is not registered".format(fileobj))
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._selector._fd_to_key)
|
||||
|
||||
|
||||
class BaseSelector(object):
|
||||
__metaclass__ = ABCMeta
|
||||
"""Selector abstract base class.
|
||||
|
||||
A selector supports registering file objects to be monitored for specific
|
||||
I/O events.
|
||||
|
||||
A file object is a file descriptor or any object with a `fileno()` method.
|
||||
An arbitrary object can be attached to the file object, which can be used
|
||||
for example to store context information, a callback, etc.
|
||||
|
||||
A selector can use various implementations (select(), poll(), epoll()...)
|
||||
depending on the platform. The default `Selector` class uses the most
|
||||
efficient implementation on the current platform.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def register(self, fileobj, events, data=None):
|
||||
"""Register a file object.
|
||||
|
||||
Parameters:
|
||||
fileobj -- file object or file descriptor
|
||||
events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
|
||||
data -- attached data
|
||||
|
||||
Returns:
|
||||
SelectorKey instance
|
||||
|
||||
Raises:
|
||||
ValueError if events is invalid
|
||||
KeyError if fileobj is already registered
|
||||
OSError if fileobj is closed or otherwise is unacceptable to
|
||||
the underlying system call (if a system call is made)
|
||||
|
||||
Note:
|
||||
OSError may or may not be raised
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def unregister(self, fileobj):
|
||||
"""Unregister a file object.
|
||||
|
||||
Parameters:
|
||||
fileobj -- file object or file descriptor
|
||||
|
||||
Returns:
|
||||
SelectorKey instance
|
||||
|
||||
Raises:
|
||||
KeyError if fileobj is not registered
|
||||
|
||||
Note:
|
||||
If fileobj is registered but has since been closed this does
|
||||
*not* raise OSError (even if the wrapped syscall does)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def modify(self, fileobj, events, data=None):
|
||||
"""Change a registered file object monitored events or attached data.
|
||||
|
||||
Parameters:
|
||||
fileobj -- file object or file descriptor
|
||||
events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
|
||||
data -- attached data
|
||||
|
||||
Returns:
|
||||
SelectorKey instance
|
||||
|
||||
Raises:
|
||||
Anything that unregister() or register() raises
|
||||
"""
|
||||
self.unregister(fileobj)
|
||||
return self.register(fileobj, events, data)
|
||||
|
||||
@abstractmethod
|
||||
def select(self, timeout=None):
|
||||
"""Perform the actual selection, until some monitored file objects are
|
||||
ready or a timeout expires.
|
||||
|
||||
Parameters:
|
||||
timeout -- if timeout > 0, this specifies the maximum wait time, in
|
||||
seconds
|
||||
if timeout <= 0, the select() call won't block, and will
|
||||
report the currently ready file objects
|
||||
if timeout is None, select() will block until a monitored
|
||||
file object becomes ready
|
||||
|
||||
Returns:
|
||||
list of (key, events) for ready file objects
|
||||
`events` is a bitwise mask of EVENT_READ|EVENT_WRITE
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self):
|
||||
"""Close the selector.
|
||||
|
||||
This must be called to make sure that any underlying resource is freed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_key(self, fileobj):
|
||||
"""Return the key associated to a registered file object.
|
||||
|
||||
Returns:
|
||||
SelectorKey for this file object
|
||||
"""
|
||||
mapping = self.get_map()
|
||||
try:
|
||||
if mapping is None:
|
||||
raise KeyError
|
||||
return mapping[fileobj]
|
||||
except KeyError:
|
||||
raise KeyError("{!r} is not registered".format(fileobj))
|
||||
|
||||
@abstractmethod
|
||||
def get_map(self):
|
||||
"""Return a mapping of file objects to selector keys."""
|
||||
raise NotImplementedError
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.close()
|
||||
|
||||
|
||||
class _BaseSelectorImpl(BaseSelector):
|
||||
"""Base selector implementation."""
|
||||
|
||||
def __init__(self):
|
||||
# this maps file descriptors to keys
|
||||
self._fd_to_key = {}
|
||||
# read-only mapping returned by get_map()
|
||||
self._map = _SelectorMapping(self)
|
||||
|
||||
def _fileobj_lookup(self, fileobj):
|
||||
"""Return a file descriptor from a file object.
|
||||
|
||||
This wraps _fileobj_to_fd() to do an exhaustive search in case
|
||||
the object is invalid but we still have it in our map. This
|
||||
is used by unregister() so we can unregister an object that
|
||||
was previously registered even if it is closed. It is also
|
||||
used by _SelectorMapping.
|
||||
"""
|
||||
try:
|
||||
return _fileobj_to_fd(fileobj)
|
||||
except ValueError:
|
||||
# Do an exhaustive search.
|
||||
for key in self._fd_to_key.values():
|
||||
if key.fileobj is fileobj:
|
||||
return key.fd
|
||||
# Raise ValueError after all.
|
||||
raise
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
if (not events) or (events & ~(EVENT_READ | EVENT_WRITE |
|
||||
EVENT_ERROR)):
|
||||
raise ValueError("Invalid events: {!r}".format(events))
|
||||
|
||||
key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data)
|
||||
|
||||
if key.fd in self._fd_to_key:
|
||||
raise KeyError("{!r} (FD {}) is already registered"
|
||||
.format(fileobj, key.fd))
|
||||
|
||||
self._fd_to_key[key.fd] = key
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
try:
|
||||
key = self._fd_to_key.pop(self._fileobj_lookup(fileobj))
|
||||
except KeyError:
|
||||
raise KeyError("{!r} is not registered".format(fileobj))
|
||||
return key
|
||||
|
||||
def modify(self, fileobj, events, data=None):
|
||||
# TODO: Subclasses can probably optimize this even further.
|
||||
try:
|
||||
key = self._fd_to_key[self._fileobj_lookup(fileobj)]
|
||||
except KeyError:
|
||||
raise KeyError("{!r} is not registered".format(fileobj))
|
||||
if events != key.events:
|
||||
self.unregister(fileobj)
|
||||
key = self.register(fileobj, events, data)
|
||||
elif data != key.data:
|
||||
# Use a shortcut to update the data.
|
||||
key = key._replace(data=data)
|
||||
self._fd_to_key[key.fd] = key
|
||||
return key
|
||||
|
||||
def close(self):
|
||||
self._fd_to_key.clear()
|
||||
self._map = None
|
||||
|
||||
def get_map(self):
|
||||
return self._map
|
||||
|
||||
def _key_from_fd(self, fd):
|
||||
"""Return the key associated to a given file descriptor.
|
||||
|
||||
Parameters:
|
||||
fd -- file descriptor
|
||||
|
||||
Returns:
|
||||
corresponding key, or None if not found
|
||||
"""
|
||||
try:
|
||||
return self._fd_to_key[fd]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
|
||||
class SelectSelector(_BaseSelectorImpl):
|
||||
"""Select-based selector."""
|
||||
|
||||
def __init__(self):
|
||||
super(self.__class__, self).__init__()
|
||||
self._readers = set()
|
||||
self._writers = set()
|
||||
self._errors = set()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super(self.__class__, self).register(fileobj, events, data)
|
||||
if events & EVENT_READ:
|
||||
self._readers.add(key.fd)
|
||||
if events & EVENT_WRITE:
|
||||
self._writers.add(key.fd)
|
||||
if events & EVENT_ERROR:
|
||||
self._errors.add(key.fd)
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super(self.__class__, self).unregister(fileobj)
|
||||
self._readers.discard(key.fd)
|
||||
self._writers.discard(key.fd)
|
||||
self._errors.discard(key.fd)
|
||||
return key
|
||||
|
||||
if sys.platform == 'win32':
|
||||
def _select(self, r, w, x, timeout=None):
|
||||
r, w, x = select.select(r, w, x, timeout)
|
||||
return r, w, x
|
||||
else:
|
||||
_select = select.select
|
||||
|
||||
def select(self, timeout=None):
|
||||
timeout = None if timeout is None else max(timeout, 0)
|
||||
ready = []
|
||||
try:
|
||||
r, w, x = self._select(self._readers, self._writers, self._errors,
|
||||
timeout)
|
||||
except OSError as e:
|
||||
if errno_from_exception(e) == errno.EAGAIN:
|
||||
return ready
|
||||
r = set(r)
|
||||
w = set(w)
|
||||
x = set(x)
|
||||
for fd in r | w | x:
|
||||
events = 0
|
||||
if fd in r:
|
||||
events |= EVENT_READ
|
||||
if fd in w:
|
||||
events |= EVENT_WRITE
|
||||
if fd in x:
|
||||
events |= EVENT_ERROR
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
ready.append((key, events & key.events))
|
||||
return ready
|
||||
|
||||
|
||||
if hasattr(select, 'poll'):
|
||||
|
||||
class PollSelector(_BaseSelectorImpl):
|
||||
"""Poll-based selector."""
|
||||
|
||||
def __init__(self):
|
||||
super(self.__class__, self).__init__()
|
||||
self._poll = select.poll()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super(self.__class__, self).register(fileobj, events, data)
|
||||
poll_events = 0
|
||||
if events & EVENT_READ:
|
||||
poll_events |= select.POLLIN
|
||||
if events & EVENT_WRITE:
|
||||
poll_events |= select.POLLOUT
|
||||
if events & EVENT_ERROR:
|
||||
poll_events |= select.POLLERR | select.POLLHUP
|
||||
self._poll.register(key.fd, poll_events)
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super(self.__class__, self).unregister(fileobj)
|
||||
self._poll.unregister(key.fd)
|
||||
return key
|
||||
|
||||
def select(self, timeout=None):
|
||||
if timeout is None:
|
||||
timeout = None
|
||||
elif timeout <= 0:
|
||||
timeout = 0
|
||||
else:
|
||||
# poll() has a resolution of 1 millisecond, round away from
|
||||
# zero to wait *at least* timeout seconds.
|
||||
timeout = math.ceil(timeout * 1e3)
|
||||
ready = []
|
||||
try:
|
||||
fd_event_list = self._poll.poll(timeout)
|
||||
except OSError as e:
|
||||
if errno_from_exception(e) == errno.EAGAIN:
|
||||
return ready
|
||||
for fd, event in fd_event_list:
|
||||
events = 0
|
||||
if event & select.POLLIN:
|
||||
events |= EVENT_READ
|
||||
if event & select.POLLOUT:
|
||||
events |= EVENT_WRITE
|
||||
if event & (select.POLLERR | select.POLLHUP):
|
||||
events |= EVENT_ERROR
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
ready.append((key, events & key.events))
|
||||
return ready
|
||||
|
||||
|
||||
if hasattr(select, 'epoll'):
|
||||
|
||||
class EpollSelector(_BaseSelectorImpl):
|
||||
"""Epoll-based selector."""
|
||||
|
||||
def __init__(self):
|
||||
super(self.__class__, self).__init__()
|
||||
self._epoll = select.epoll()
|
||||
|
||||
def fileno(self):
|
||||
return self._epoll.fileno()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super(self.__class__, self).register(fileobj, events, data)
|
||||
epoll_events = 0
|
||||
if events & EVENT_READ:
|
||||
epoll_events |= select.EPOLLIN
|
||||
if events & EVENT_WRITE:
|
||||
epoll_events |= select.EPOLLOUT
|
||||
if events & EVENT_ERROR:
|
||||
epoll_events |= select.EPOLLERR | select.EPOLLHUP
|
||||
self._epoll.register(key.fd, epoll_events)
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super(self.__class__, self).unregister(fileobj)
|
||||
try:
|
||||
self._epoll.unregister(key.fd)
|
||||
except OSError:
|
||||
# This can happen if the FD was closed since it
|
||||
# was registered.
|
||||
pass
|
||||
return key
|
||||
|
||||
def select(self, timeout=None):
|
||||
if timeout is None:
|
||||
timeout = -1
|
||||
elif timeout <= 0:
|
||||
timeout = 0
|
||||
else:
|
||||
# epoll_wait() has a resolution of 1 millisecond, round away
|
||||
# from zero to wait *at least* timeout seconds.
|
||||
timeout = math.ceil(timeout * 1e3) * 1e-3
|
||||
|
||||
# epoll_wait() expects `maxevents` to be greater than zero;
|
||||
# we want to make sure that `select()` can be called when no
|
||||
# FD is registered.
|
||||
max_ev = max(len(self._fd_to_key), 1)
|
||||
|
||||
ready = []
|
||||
try:
|
||||
fd_event_list = self._epoll.poll(timeout, max_ev)
|
||||
except OSError as e:
|
||||
if errno_from_exception(e) == errno.EAGAIN:
|
||||
return ready
|
||||
for fd, event in fd_event_list:
|
||||
events = 0
|
||||
if event & select.EPOLLIN:
|
||||
events |= EVENT_READ
|
||||
if event & select.EPOLLOUT:
|
||||
events |= EVENT_WRITE
|
||||
if event & (select.EPOLLERR | select.EPOLLHUP):
|
||||
events |= EVENT_ERROR
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
ready.append((key, events & key.events))
|
||||
return ready
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
self._epoll.close()
|
||||
finally:
|
||||
super(self.__class__, self).close()
|
||||
|
||||
|
||||
if hasattr(select, 'devpoll'):
|
||||
|
||||
class DevpollSelector(_BaseSelectorImpl):
|
||||
"""Solaris /dev/poll selector."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._devpoll = select.devpoll()
|
||||
|
||||
def fileno(self):
|
||||
return self._devpoll.fileno()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super().register(fileobj, events, data)
|
||||
poll_events = 0
|
||||
if events & EVENT_READ:
|
||||
poll_events |= select.POLLIN
|
||||
if events & EVENT_WRITE:
|
||||
poll_events |= select.POLLOUT
|
||||
if events & EVENT_ERROR:
|
||||
poll_events |= select.POLLERR | select.POLLHUP
|
||||
self._devpoll.register(key.fd, poll_events)
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super().unregister(fileobj)
|
||||
self._devpoll.unregister(key.fd)
|
||||
return key
|
||||
|
||||
def select(self, timeout=None):
|
||||
if timeout is None:
|
||||
timeout = None
|
||||
elif timeout <= 0:
|
||||
timeout = 0
|
||||
else:
|
||||
# devpoll() has a resolution of 1 millisecond, round away from
|
||||
# zero to wait *at least* timeout seconds.
|
||||
timeout = math.ceil(timeout * 1e3)
|
||||
ready = []
|
||||
try:
|
||||
fd_event_list = self._devpoll.poll(timeout)
|
||||
except InterruptedError:
|
||||
return ready
|
||||
for fd, event in fd_event_list:
|
||||
events = 0
|
||||
if event & select.POLLIN:
|
||||
events |= EVENT_READ
|
||||
if event & select.POLLOUT:
|
||||
events |= EVENT_WRITE
|
||||
if event & (select.POLLERR | select.POLLHUP):
|
||||
events |= EVENT_ERROR
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
ready.append((key, events & key.events))
|
||||
return ready
|
||||
|
||||
def close(self):
|
||||
self._devpoll.close()
|
||||
super().close()
|
||||
|
||||
|
||||
if hasattr(select, 'kqueue'):
|
||||
|
||||
class KqueueSelector(_BaseSelectorImpl):
|
||||
"""Kqueue-based selector."""
|
||||
|
||||
def __init__(self):
|
||||
super(self.__class__, self).__init__()
|
||||
self._kqueue = select.kqueue()
|
||||
|
||||
def fileno(self):
|
||||
return self._kqueue.fileno()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super(self.__class__, self).register(fileobj, events, data)
|
||||
if events & EVENT_READ:
|
||||
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
|
||||
select.KQ_EV_ADD)
|
||||
self._kqueue.control([kev], 0, 0)
|
||||
if events & EVENT_WRITE:
|
||||
kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
|
||||
select.KQ_EV_ADD)
|
||||
self._kqueue.control([kev], 0, 0)
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super(self.__class__, self).unregister(fileobj)
|
||||
if key.events & EVENT_READ:
|
||||
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
|
||||
select.KQ_EV_DELETE)
|
||||
try:
|
||||
self._kqueue.control([kev], 0, 0)
|
||||
except OSError:
|
||||
# This can happen if the FD was closed since it
|
||||
# was registered.
|
||||
pass
|
||||
if key.events & EVENT_WRITE:
|
||||
kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
|
||||
select.KQ_EV_DELETE)
|
||||
try:
|
||||
self._kqueue.control([kev], 0, 0)
|
||||
except OSError:
|
||||
# See comment above.
|
||||
pass
|
||||
return key
|
||||
|
||||
def select(self, timeout=None):
|
||||
timeout = None if timeout is None else max(timeout, 0)
|
||||
max_ev = len(self._fd_to_key)
|
||||
ready = []
|
||||
try:
|
||||
kev_list = self._kqueue.control(None, max_ev, timeout)
|
||||
except OSError as e:
|
||||
if errno_from_exception(e) == errno.EAGAIN:
|
||||
return ready
|
||||
for kev in kev_list:
|
||||
fd = kev.ident
|
||||
flag = kev.filter
|
||||
events = 0
|
||||
if flag == select.KQ_FILTER_READ:
|
||||
events |= EVENT_READ
|
||||
if flag == select.KQ_FILTER_WRITE:
|
||||
events |= EVENT_WRITE
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
ready.append((key, events & key.events))
|
||||
return ready
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
self._kqueue.close()
|
||||
finally:
|
||||
super(self.__class__, self).close()
|
||||
|
||||
|
||||
# Choose the best implementation: roughly, epoll|kqueue > poll > select.
|
||||
# select() also can't accept a FD > FD_SETSIZE (usually around 1024)
|
||||
if 'KqueueSelector' in globals():
|
||||
DefaultSelector = KqueueSelector
|
||||
elif 'EpollSelector' in globals():
|
||||
DefaultSelector = EpollSelector
|
||||
elif 'DevpollSelector' in globals():
|
||||
DefaultSelector = DevpollSelector
|
||||
elif 'PollSelector' in globals():
|
||||
DefaultSelector = PollSelector
|
||||
else:
|
||||
DefaultSelector = SelectSelector
|
148
shadowsocks/server.py
Executable file → Normal file
148
shadowsocks/server.py
Executable file → Normal file
|
@ -1,143 +1,35 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import sys
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import signal
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../'))
|
||||
from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, \
|
||||
asyncdns, manager
|
||||
path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, path)
|
||||
from shadowsocks.eventloop import EventLoop
|
||||
from shadowsocks.tcprelay import TcpRelay, TcpRelayServerHandler
|
||||
from shadowsocks.asyncdns import DNSResolver
|
||||
|
||||
|
||||
FORMATTER = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
LOGGING_LEVEL = logging.INFO
|
||||
logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER)
|
||||
|
||||
LISTEN_ADDR = ('0.0.0.0', 9000)
|
||||
|
||||
|
||||
def main():
|
||||
shell.check_python()
|
||||
|
||||
config = shell.get_config(False)
|
||||
|
||||
daemon.daemon_exec(config)
|
||||
|
||||
if config['port_password']:
|
||||
if config['password']:
|
||||
logging.warn('warning: port_password should not be used with '
|
||||
'server_port and password. server_port and password '
|
||||
'will be ignored')
|
||||
else:
|
||||
config['port_password'] = {}
|
||||
server_port = config['server_port']
|
||||
if type(server_port) == list:
|
||||
for a_server_port in server_port:
|
||||
config['port_password'][a_server_port] = config['password']
|
||||
else:
|
||||
config['port_password'][str(server_port)] = config['password']
|
||||
|
||||
if config.get('manager_address', 0):
|
||||
logging.info('entering manager mode')
|
||||
manager.run(config)
|
||||
return
|
||||
|
||||
tcp_servers = []
|
||||
udp_servers = []
|
||||
|
||||
if 'dns_server' in config: # allow override settings in resolv.conf
|
||||
dns_resolver = asyncdns.DNSResolver(config['dns_server'],
|
||||
config['prefer_ipv6'])
|
||||
else:
|
||||
dns_resolver = asyncdns.DNSResolver(prefer_ipv6=config['prefer_ipv6'])
|
||||
|
||||
port_password = config['port_password']
|
||||
del config['port_password']
|
||||
for port, password in port_password.items():
|
||||
a_config = config.copy()
|
||||
a_config['server_port'] = int(port)
|
||||
a_config['password'] = password
|
||||
logging.info("starting server at %s:%d" %
|
||||
(a_config['server'], int(port)))
|
||||
tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False))
|
||||
udp_servers.append(udprelay.UDPRelay(a_config, dns_resolver, False))
|
||||
|
||||
def run_server():
|
||||
def child_handler(signum, _):
|
||||
logging.warn('received SIGQUIT, doing graceful shutting down..')
|
||||
list(map(lambda s: s.close(next_tick=True),
|
||||
tcp_servers + udp_servers))
|
||||
signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM),
|
||||
child_handler)
|
||||
|
||||
def int_handler(signum, _):
|
||||
sys.exit(1)
|
||||
signal.signal(signal.SIGINT, int_handler)
|
||||
|
||||
try:
|
||||
loop = eventloop.EventLoop()
|
||||
dns_resolver.add_to_loop(loop)
|
||||
list(map(lambda s: s.add_to_loop(loop), tcp_servers + udp_servers))
|
||||
|
||||
daemon.set_user(config.get('user', None))
|
||||
loop.run()
|
||||
except Exception as e:
|
||||
shell.print_exception(e)
|
||||
sys.exit(1)
|
||||
|
||||
if int(config['workers']) > 1:
|
||||
if os.name == 'posix':
|
||||
children = []
|
||||
is_child = False
|
||||
for i in range(0, int(config['workers'])):
|
||||
r = os.fork()
|
||||
if r == 0:
|
||||
logging.info('worker started')
|
||||
is_child = True
|
||||
run_server()
|
||||
break
|
||||
else:
|
||||
children.append(r)
|
||||
if not is_child:
|
||||
def handler(signum, _):
|
||||
for pid in children:
|
||||
try:
|
||||
os.kill(pid, signum)
|
||||
os.waitpid(pid, 0)
|
||||
except OSError: # child may already exited
|
||||
pass
|
||||
sys.exit()
|
||||
signal.signal(signal.SIGTERM, handler)
|
||||
signal.signal(signal.SIGQUIT, handler)
|
||||
signal.signal(signal.SIGINT, handler)
|
||||
|
||||
# master
|
||||
for a_tcp_server in tcp_servers:
|
||||
a_tcp_server.close()
|
||||
for a_udp_server in udp_servers:
|
||||
a_udp_server.close()
|
||||
dns_resolver.close()
|
||||
|
||||
for child in children:
|
||||
os.waitpid(child, 0)
|
||||
else:
|
||||
logging.warn('worker is only available on Unix/Linux')
|
||||
run_server()
|
||||
else:
|
||||
run_server()
|
||||
|
||||
loop = EventLoop()
|
||||
dns_resolver = DNSResolver()
|
||||
relay = TcpRelay(TcpRelayServerHandler, LISTEN_ADDR,
|
||||
dns_resolver=dns_resolver)
|
||||
dns_resolver.add_to_loop(loop)
|
||||
relay.add_to_loop(loop)
|
||||
loop.run()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -23,10 +23,6 @@ import json
|
|||
import sys
|
||||
import getopt
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from shadowsocks.common import to_bytes, to_str, IPNetwork
|
||||
from shadowsocks import encrypt
|
||||
|
||||
|
@ -57,49 +53,6 @@ def print_exception(e):
|
|||
traceback.print_exc()
|
||||
|
||||
|
||||
def exception_handle(self_, err_msg=None, exit_code=None,
|
||||
destroy=False, conn_err=False):
|
||||
# self_: if function passes self as first arg
|
||||
|
||||
def process_exception(e, self=None):
|
||||
print_exception(e)
|
||||
if err_msg:
|
||||
logging.error(err_msg)
|
||||
if exit_code:
|
||||
sys.exit(1)
|
||||
|
||||
if not self_:
|
||||
return
|
||||
|
||||
if conn_err:
|
||||
addr, port = self._client_address[0], self._client_address[1]
|
||||
logging.error('%s when handling connection from %s:%d' %
|
||||
(e, addr, port))
|
||||
if self._config['verbose']:
|
||||
traceback.print_exc()
|
||||
if destroy:
|
||||
self.destroy()
|
||||
|
||||
def decorator(func):
|
||||
if self_:
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
func(self, *args, **kwargs)
|
||||
except Exception as e:
|
||||
process_exception(e, self)
|
||||
else:
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
process_exception(e)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def print_shadowsocks():
|
||||
version = ''
|
||||
try:
|
||||
|
@ -125,30 +78,13 @@ def check_config(config, is_local):
|
|||
# no need to specify configuration for daemon stop
|
||||
return
|
||||
|
||||
if is_local:
|
||||
if config.get('server', None) is None:
|
||||
logging.error('server addr not specified')
|
||||
print_local_help()
|
||||
sys.exit(2)
|
||||
else:
|
||||
config['server'] = to_str(config['server'])
|
||||
else:
|
||||
config['server'] = to_str(config.get('server', '0.0.0.0'))
|
||||
try:
|
||||
config['forbidden_ip'] = \
|
||||
IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128'))
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
sys.exit(2)
|
||||
|
||||
if is_local and not config.get('password', None):
|
||||
logging.error('password not specified')
|
||||
print_help(is_local)
|
||||
sys.exit(2)
|
||||
|
||||
if not is_local and not config.get('password', None) \
|
||||
and not config.get('port_password', None) \
|
||||
and not config.get('manager_address'):
|
||||
and not config.get('port_password', None):
|
||||
logging.error('password or port_password not specified')
|
||||
print_help(is_local)
|
||||
sys.exit(2)
|
||||
|
@ -194,14 +130,13 @@ def get_config(is_local):
|
|||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(levelname)-s: %(message)s')
|
||||
if is_local:
|
||||
shortopts = 'hd:s:b:p:k:l:m:c:t:vqa'
|
||||
shortopts = 'hd:s:b:p:k:l:m:c:t:vq'
|
||||
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'user=',
|
||||
'version']
|
||||
else:
|
||||
shortopts = 'hd:s:p:k:m:c:t:vqa'
|
||||
shortopts = 'hd:s:p:k:m:c:t:vq'
|
||||
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=',
|
||||
'forbidden-ip=', 'user=', 'manager-address=', 'version',
|
||||
'prefer-ipv6']
|
||||
'forbidden-ip=', 'user=', 'manager-address=', 'version']
|
||||
try:
|
||||
config_path = find_config()
|
||||
optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts)
|
||||
|
@ -239,8 +174,6 @@ def get_config(is_local):
|
|||
v_count += 1
|
||||
# '-vv' turns on more verbose mode
|
||||
config['verbose'] = v_count
|
||||
elif key == '-a':
|
||||
config['one_time_auth'] = True
|
||||
elif key == '-t':
|
||||
config['timeout'] = int(value)
|
||||
elif key == '--fast-open':
|
||||
|
@ -271,8 +204,6 @@ def get_config(is_local):
|
|||
elif key == '-q':
|
||||
v_count -= 1
|
||||
config['verbose'] = v_count
|
||||
elif key == '--prefer-ipv6':
|
||||
config['prefer_ipv6'] = True
|
||||
except getopt.GetoptError as e:
|
||||
print(e, file=sys.stderr)
|
||||
print_help(is_local)
|
||||
|
@ -294,8 +225,21 @@ def get_config(is_local):
|
|||
config['verbose'] = config.get('verbose', False)
|
||||
config['local_address'] = to_str(config.get('local_address', '127.0.0.1'))
|
||||
config['local_port'] = config.get('local_port', 1080)
|
||||
config['one_time_auth'] = config.get('one_time_auth', False)
|
||||
config['prefer_ipv6'] = config.get('prefer_ipv6', False)
|
||||
if is_local:
|
||||
if config.get('server', None) is None:
|
||||
logging.error('server addr not specified')
|
||||
print_local_help()
|
||||
sys.exit(2)
|
||||
else:
|
||||
config['server'] = to_str(config['server'])
|
||||
else:
|
||||
config['server'] = to_str(config.get('server', '0.0.0.0'))
|
||||
try:
|
||||
config['forbidden_ip'] = \
|
||||
IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128'))
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
sys.exit(2)
|
||||
config['server_port'] = config.get('server_port', 8388)
|
||||
|
||||
logging.getLogger('').handlers = []
|
||||
|
@ -342,7 +286,6 @@ Proxy options:
|
|||
-k PASSWORD password
|
||||
-m METHOD encryption method, default: aes-256-cfb
|
||||
-t TIMEOUT timeout in seconds, default: 300
|
||||
-a ONE_TIME_AUTH one time auth
|
||||
--fast-open use TCP_FASTOPEN, requires Linux 3.7+
|
||||
|
||||
General options:
|
||||
|
@ -372,12 +315,10 @@ Proxy options:
|
|||
-k PASSWORD password
|
||||
-m METHOD encryption method, default: aes-256-cfb
|
||||
-t TIMEOUT timeout in seconds, default: 300
|
||||
-a ONE_TIME_AUTH one time auth
|
||||
--fast-open use TCP_FASTOPEN, requires Linux 3.7+
|
||||
--workers WORKERS number of workers, available on Unix/Linux
|
||||
--forbidden-ip IPLIST comma seperated IP list forbidden to connect
|
||||
--manager-address ADDR optional server manager UDP address, see wiki
|
||||
--prefer-ipv6 resolve ipv6 address first
|
||||
|
||||
General options:
|
||||
-h, --help show this help message and exit
|
||||
|
|
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue