This commit is contained in:
Shengdun Hua 2017-02-08 13:07:16 +00:00 committed by GitHub
commit 0fdd2fec83
12 changed files with 1183 additions and 1385 deletions

View file

@ -29,7 +29,7 @@ from shadowsocks import common, lru_cache, eventloop, shell
CACHE_SWEEP_INTERVAL = 30
VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d\-_]{1,63}(?<!-)$", re.IGNORECASE)
VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
common.patch_socket()
@ -242,13 +242,13 @@ class DNSResponse(object):
return '%s: %s' % (self.hostname, str(self.answers))
STATUS_FIRST = 0
STATUS_SECOND = 1
STATUS_IPV4 = 0
STATUS_IPV6 = 1
class DNSResolver(object):
def __init__(self, server_list=None, prefer_ipv6=False):
def __init__(self):
self._loop = None
self._hosts = {}
self._hostname_status = {}
@ -256,15 +256,8 @@ class DNSResolver(object):
self._cb_to_hostname = {}
self._cache = lru_cache.LRUCache(timeout=300)
self._sock = None
if server_list is None:
self._servers = None
self._parse_resolv()
else:
self._servers = server_list
if prefer_ipv6:
self._QTYPES = [QTYPE_AAAA, QTYPE_A]
else:
self._QTYPES = [QTYPE_A, QTYPE_AAAA]
self._servers = None
self._parse_resolv()
self._parse_hosts()
# TODO monitor hosts change and reload hosts
# TODO parse /etc/gai.conf and follow its rules
@ -276,18 +269,15 @@ class DNSResolver(object):
content = f.readlines()
for line in content:
line = line.strip()
if not (line and line.startswith(b'nameserver')):
continue
parts = line.split()
if len(parts) < 2:
continue
server = parts[1]
if common.is_ip(server) == socket.AF_INET:
if type(server) != str:
server = server.decode('utf8')
self._servers.append(server)
if line:
if line.startswith(b'nameserver'):
parts = line.split()
if len(parts) >= 2:
server = parts[1]
if common.is_ip(server) == socket.AF_INET:
if type(server) != str:
server = server.decode('utf8')
self._servers.append(server)
except IOError:
pass
if not self._servers:
@ -302,17 +292,13 @@ class DNSResolver(object):
for line in f.readlines():
line = line.strip()
parts = line.split()
if len(parts) < 2:
continue
ip = parts[0]
if not common.is_ip(ip):
continue
for i in range(1, len(parts)):
hostname = parts[i]
if hostname:
self._hosts[hostname] = ip
if len(parts) >= 2:
ip = parts[0]
if common.is_ip(ip):
for i in range(1, len(parts)):
hostname = parts[i]
if hostname:
self._hosts[hostname] = ip
except IOError:
self._hosts['localhost'] = '127.0.0.1'
@ -352,22 +338,21 @@ class DNSResolver(object):
answer[2] == QCLASS_IN:
ip = answer[0]
break
if not ip and self._hostname_status.get(hostname, STATUS_SECOND) \
== STATUS_FIRST:
self._hostname_status[hostname] = STATUS_SECOND
self._send_req(hostname, self._QTYPES[1])
if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \
== STATUS_IPV4:
self._hostname_status[hostname] = STATUS_IPV6
self._send_req(hostname, QTYPE_AAAA)
else:
if ip:
self._cache[hostname] = ip
self._call_callback(hostname, ip)
elif self._hostname_status.get(hostname, None) \
== STATUS_SECOND:
elif self._hostname_status.get(hostname, None) == STATUS_IPV6:
for question in response.questions:
if question[1] == self._QTYPES[1]:
if question[1] == QTYPE_AAAA:
self._call_callback(hostname, None)
break
def handle_event(self, sock, fd, event):
def handle_event(self, sock, event):
if sock != self._sock:
return
if event & eventloop.POLL_ERR:
@ -429,14 +414,14 @@ class DNSResolver(object):
return
arr = self._hostname_to_cb.get(hostname, None)
if not arr:
self._hostname_status[hostname] = STATUS_FIRST
self._send_req(hostname, self._QTYPES[0])
self._hostname_status[hostname] = STATUS_IPV4
self._send_req(hostname, QTYPE_A)
self._hostname_to_cb[hostname] = [callback]
self._cb_to_hostname[callback] = hostname
else:
arr.append(callback)
# TODO send again only if waited too long
self._send_req(hostname, self._QTYPES[0])
self._send_req(hostname, QTYPE_A)
def close(self):
if self._sock:

33
shadowsocks/client.py Normal file
View file

@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, \
with_statement
import logging
import os
import sys
path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, path)
from shadowsocks.eventloop import EventLoop
from shadowsocks.tcprelay import TcpRelay, TcpRelayClientHanler
FORMATTER = '%(asctime)s - %(levelname)s - %(message)s'
LOGGING_LEVEL = logging.INFO
logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER)
LISTEN_ADDR = ('127.0.0.1', 7070)
REMOTE_ADDR = ('104.129.176.124', 9000)
def main():
loop = EventLoop()
relay = TcpRelay(TcpRelayClientHanler, LISTEN_ADDR, REMOTE_ADDR)
relay.add_to_loop(loop)
loop.run()
if __name__ == '__main__':
main()

View file

@ -21,25 +21,6 @@ from __future__ import absolute_import, division, print_function, \
import socket
import struct
import logging
import hashlib
import hmac
ONETIMEAUTH_BYTES = 10
ONETIMEAUTH_CHUNK_BYTES = 12
ONETIMEAUTH_CHUNK_DATA_LEN = 2
def sha1_hmac(secret, data):
return hmac.new(secret, data, hashlib.sha1).digest()
def onetimeauth_verify(_hash, data, key):
return _hash == sha1_hmac(key, data)[:ONETIMEAUTH_BYTES]
def onetimeauth_gen(data, key):
return sha1_hmac(key, data)[:ONETIMEAUTH_BYTES]
def compat_ord(s):
@ -137,11 +118,9 @@ def patch_socket():
patch_socket()
ADDRTYPE_IPV4 = 0x01
ADDRTYPE_IPV6 = 0x04
ADDRTYPE_HOST = 0x03
ADDRTYPE_AUTH = 0x10
ADDRTYPE_MASK = 0xF
ADDRTYPE_IPV4 = 1
ADDRTYPE_IPV6 = 4
ADDRTYPE_HOST = 3
def pack_addr(address):
@ -165,17 +144,17 @@ def parse_header(data):
dest_addr = None
dest_port = None
header_length = 0
if addrtype & ADDRTYPE_MASK == ADDRTYPE_IPV4:
if addrtype == ADDRTYPE_IPV4:
if len(data) >= 7:
dest_addr = socket.inet_ntoa(data[1:5])
dest_port = struct.unpack('>H', data[5:7])[0]
header_length = 7
else:
logging.warn('header is too short')
elif addrtype & ADDRTYPE_MASK == ADDRTYPE_HOST:
elif addrtype == ADDRTYPE_HOST:
if len(data) > 2:
addrlen = ord(data[1])
if len(data) >= 4 + addrlen:
if len(data) >= 2 + addrlen:
dest_addr = data[2:2 + addrlen]
dest_port = struct.unpack('>H', data[2 + addrlen:4 +
addrlen])[0]
@ -184,7 +163,7 @@ def parse_header(data):
logging.warn('header is too short')
else:
logging.warn('header is too short')
elif addrtype & ADDRTYPE_MASK == ADDRTYPE_IPV6:
elif addrtype == ADDRTYPE_IPV6:
if len(data) >= 19:
dest_addr = socket.inet_ntop(socket.AF_INET6, data[1:17])
dest_port = struct.unpack('>H', data[17:19])[0]

View file

@ -32,7 +32,7 @@ buf_size = 2048
def load_openssl():
global loaded, libcrypto, buf, ctx_cleanup
global loaded, libcrypto, buf
libcrypto = util.find_library(('crypto', 'eay32'),
'EVP_get_cipherbyname',
@ -49,12 +49,7 @@ def load_openssl():
libcrypto.EVP_CipherUpdate.argtypes = (c_void_p, c_void_p, c_void_p,
c_char_p, c_int)
try:
libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,)
ctx_cleanup = libcrypto.EVP_CIPHER_CTX_cleanup
except AttributeError:
libcrypto.EVP_CIPHER_CTX_reset.argtypes = (c_void_p,)
ctx_cleanup = libcrypto.EVP_CIPHER_CTX_reset
libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,)
libcrypto.EVP_CIPHER_CTX_free.argtypes = (c_void_p,)
if hasattr(libcrypto, 'OpenSSL_add_all_ciphers'):
libcrypto.OpenSSL_add_all_ciphers()
@ -113,7 +108,7 @@ class OpenSSLCrypto(object):
def clean(self):
if self._ctx:
ctx_cleanup(self._ctx)
libcrypto.EVP_CIPHER_CTX_cleanup(self._ctx)
libcrypto.EVP_CIPHER_CTX_free(self._ctx)

View file

@ -17,7 +17,7 @@
from __future__ import absolute_import, division, print_function, \
with_statement
from ctypes import c_char_p, c_int, c_ulonglong, byref, c_ulong, \
from ctypes import c_char_p, c_int, c_ulonglong, byref, \
create_string_buffer, c_void_p
from shadowsocks.crypto import util
@ -29,7 +29,7 @@ loaded = False
buf_size = 2048
# for salsa20 and chacha20 and chacha20-ietf
# for salsa20 and chacha20
BLOCK_SIZE = 64
@ -51,13 +51,6 @@ def load_libsodium():
c_ulonglong,
c_char_p, c_ulonglong,
c_char_p)
libsodium.crypto_stream_chacha20_ietf_xor_ic.restype = c_int
libsodium.crypto_stream_chacha20_ietf_xor_ic.argtypes = (c_void_p,
c_char_p,
c_ulonglong,
c_char_p,
c_ulong,
c_char_p)
buf = create_string_buffer(buf_size)
loaded = True
@ -75,8 +68,6 @@ class SodiumCrypto(object):
self.cipher = libsodium.crypto_stream_salsa20_xor_ic
elif cipher_name == 'chacha20':
self.cipher = libsodium.crypto_stream_chacha20_xor_ic
elif cipher_name == 'chacha20-ietf':
self.cipher = libsodium.crypto_stream_chacha20_ietf_xor_ic
else:
raise Exception('Unknown cipher')
# byte counter, not block counter
@ -106,7 +97,6 @@ class SodiumCrypto(object):
ciphers = {
'salsa20': (32, 8, SodiumCrypto),
'chacha20': (32, 8, SodiumCrypto),
'chacha20-ietf': (32, 12, SodiumCrypto),
}
@ -125,15 +115,6 @@ def test_chacha20():
util.run_cipher(cipher, decipher)
def test_chacha20_ietf():
cipher = SodiumCrypto('chacha20-ietf', b'k' * 32, b'i' * 16, 1)
decipher = SodiumCrypto('chacha20-ietf', b'k' * 32, b'i' * 16, 0)
util.run_cipher(cipher, decipher)
if __name__ == '__main__':
test_chacha20()
test_salsa20()
test_chacha20_ietf()

View file

@ -69,18 +69,17 @@ def EVP_BytesToKey(password, key_len, iv_len):
class Encryptor(object):
def __init__(self, password, method):
self.password = password
self.key = None
def __init__(self, key, method):
self.key = key
self.method = method
self.iv = None
self.iv_sent = False
self.cipher_iv = b''
self.decipher = None
self.decipher_iv = None
method = method.lower()
self._method_info = self.get_method_info(method)
if self._method_info:
self.cipher = self.get_cipher(password, method, 1,
self.cipher = self.get_cipher(key, method, 1,
random_string(self._method_info[1]))
else:
logging.error('method %s not supported' % method)
@ -102,7 +101,7 @@ class Encryptor(object):
else:
# key_length == 0 indicates we should use the key directly
key, iv = password, b''
self.key = key
iv = iv[:m[1]]
if op == 1:
# this iv is for cipher not decipher
@ -124,8 +123,7 @@ class Encryptor(object):
if self.decipher is None:
decipher_iv_len = self._method_info[1]
decipher_iv = buf[:decipher_iv_len]
self.decipher_iv = decipher_iv
self.decipher = self.get_cipher(self.password, self.method, 0,
self.decipher = self.get_cipher(self.key, self.method, 0,
iv=decipher_iv)
buf = buf[decipher_iv_len:]
if len(buf) == 0:
@ -133,47 +131,10 @@ class Encryptor(object):
return self.decipher.update(buf)
def gen_key_iv(password, method):
method = method.lower()
(key_len, iv_len, m) = method_supported[method]
key = None
if key_len > 0:
key, _ = EVP_BytesToKey(password, key_len, iv_len)
else:
key = password
iv = random_string(iv_len)
return key, iv, m
def encrypt_all_m(key, iv, m, method, data):
result = []
result.append(iv)
cipher = m(method, key, iv, 1)
result.append(cipher.update(data))
return b''.join(result)
def dencrypt_all(password, method, data):
result = []
method = method.lower()
(key_len, iv_len, m) = method_supported[method]
key = None
if key_len > 0:
key, _ = EVP_BytesToKey(password, key_len, iv_len)
else:
key = password
iv = data[:iv_len]
data = data[iv_len:]
cipher = m(method, key, iv, 0)
result.append(cipher.update(data))
return b''.join(result), key, iv
def encrypt_all(password, method, op, data):
result = []
method = method.lower()
(key_len, iv_len, m) = method_supported[method]
key = None
if key_len > 0:
key, _ = EVP_BytesToKey(password, key_len, iv_len)
else:
@ -221,18 +182,6 @@ def test_encrypt_all():
assert plain == plain2
def test_encrypt_all_m():
from os import urandom
plain = urandom(10240)
for method in CIPHERS_TO_TEST:
logging.warn(method)
key, iv, m = gen_key_iv(b'key', method)
cipher = encrypt_all_m(key, iv, m, method, plain)
plain2, key, iv = dencrypt_all(b'key', method, cipher)
assert plain == plain2
if __name__ == '__main__':
test_encrypt_all()
test_encryptor()
test_encrypt_all_m()

View file

@ -1,181 +1,53 @@
#!/usr/bin/python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2013-2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# from ssloop
# https://github.com/clowwindy/ssloop
from __future__ import absolute_import, division, print_function, \
with_statement
import os
import logging
import time
import socket
import select
import traceback
import errno
import logging
from collections import defaultdict
from shadowsocks import shell
from shadowsocks import selectors
from shadowsocks.selectors import (EVENT_READ, EVENT_WRITE, EVENT_ERROR,
errno_from_exception)
__all__ = ['EventLoop', 'POLL_NULL', 'POLL_IN', 'POLL_OUT', 'POLL_ERR',
'POLL_HUP', 'POLL_NVAL', 'EVENT_NAMES']
POLL_IN = EVENT_READ
POLL_OUT = EVENT_WRITE
POLL_ERR = EVENT_ERROR
POLL_NULL = 0x00
POLL_IN = 0x01
POLL_OUT = 0x04
POLL_ERR = 0x08
POLL_HUP = 0x10
POLL_NVAL = 0x20
EVENT_NAMES = {
POLL_NULL: 'POLL_NULL',
POLL_IN: 'POLL_IN',
POLL_OUT: 'POLL_OUT',
POLL_ERR: 'POLL_ERR',
POLL_HUP: 'POLL_HUP',
POLL_NVAL: 'POLL_NVAL',
}
# we check timeouts every TIMEOUT_PRECISION seconds
TIMEOUT_PRECISION = 10
class KqueueLoop(object):
MAX_EVENTS = 1024
class EventLoop:
def __init__(self):
self._kqueue = select.kqueue()
self._fds = {}
def _control(self, fd, mode, flags):
events = []
if mode & POLL_IN:
events.append(select.kevent(fd, select.KQ_FILTER_READ, flags))
if mode & POLL_OUT:
events.append(select.kevent(fd, select.KQ_FILTER_WRITE, flags))
for e in events:
self._kqueue.control([e], 0)
def poll(self, timeout):
if timeout < 0:
timeout = None # kqueue behaviour
events = self._kqueue.control(None, KqueueLoop.MAX_EVENTS, timeout)
results = defaultdict(lambda: POLL_NULL)
for e in events:
fd = e.ident
if e.filter == select.KQ_FILTER_READ:
results[fd] |= POLL_IN
elif e.filter == select.KQ_FILTER_WRITE:
results[fd] |= POLL_OUT
return results.items()
def register(self, fd, mode):
self._fds[fd] = mode
self._control(fd, mode, select.KQ_EV_ADD)
def unregister(self, fd):
self._control(fd, self._fds[fd], select.KQ_EV_DELETE)
del self._fds[fd]
def modify(self, fd, mode):
self.unregister(fd)
self.register(fd, mode)
def close(self):
self._kqueue.close()
class SelectLoop(object):
def __init__(self):
self._r_list = set()
self._w_list = set()
self._x_list = set()
def poll(self, timeout):
r, w, x = select.select(self._r_list, self._w_list, self._x_list,
timeout)
results = defaultdict(lambda: POLL_NULL)
for p in [(r, POLL_IN), (w, POLL_OUT), (x, POLL_ERR)]:
for fd in p[0]:
results[fd] |= p[1]
return results.items()
def register(self, fd, mode):
if mode & POLL_IN:
self._r_list.add(fd)
if mode & POLL_OUT:
self._w_list.add(fd)
if mode & POLL_ERR:
self._x_list.add(fd)
def unregister(self, fd):
if fd in self._r_list:
self._r_list.remove(fd)
if fd in self._w_list:
self._w_list.remove(fd)
if fd in self._x_list:
self._x_list.remove(fd)
def modify(self, fd, mode):
self.unregister(fd)
self.register(fd, mode)
def close(self):
pass
class EventLoop(object):
def __init__(self):
if hasattr(select, 'epoll'):
self._impl = select.epoll()
model = 'epoll'
elif hasattr(select, 'kqueue'):
self._impl = KqueueLoop()
model = 'kqueue'
elif hasattr(select, 'select'):
self._impl = SelectLoop()
model = 'select'
else:
raise Exception('can not find any available functions in select '
'package')
self._fdmap = {} # (f, handler)
self._selector = selectors.DefaultSelector()
self._stopping = False
self._last_time = time.time()
self._periodic_callbacks = []
self._stopping = False
logging.debug('using event model: %s', model)
def poll(self, timeout=None):
events = self._impl.poll(timeout)
return [(self._fdmap[fd][0], fd, event) for fd, event in events]
return self._selector.select(timeout)
def add(self, f, mode, handler):
fd = f.fileno()
self._fdmap[fd] = (f, handler)
self._impl.register(fd, mode)
def add(self, sock, events, data):
events |= selectors.EVENT_ERROR
return self._selector.register(sock, events, data)
def remove(self, f):
fd = f.fileno()
del self._fdmap[fd]
self._impl.unregister(fd)
def remove(self, sock):
try:
return self._selector.unregister(sock)
except KeyError:
pass
def modify(self, sock, events, data):
events |= selectors.EVENT_ERROR
try:
key = self._selector.modify(sock, events, data)
except KeyError:
key = self.add(sock, events, data)
return key
def add_periodic(self, callback):
self._periodic_callbacks.append(callback)
@ -183,69 +55,57 @@ class EventLoop(object):
def remove_periodic(self, callback):
self._periodic_callbacks.remove(callback)
def modify(self, f, mode):
fd = f.fileno()
self._impl.modify(fd, mode)
def stop(self):
self._stopping = True
def fd_count(self):
return len(self._selector.get_map())
def run(self):
events = []
logging.debug('Starting event loop')
while not self._stopping:
asap = False
try:
events = self.poll(TIMEOUT_PRECISION)
events = self.poll(timeout=TIMEOUT_PRECISION)
except (OSError, IOError) as e:
if errno_from_exception(e) in (errno.EPIPE, errno.EINTR):
# EPIPE: Happens when the client closes the connection
# EINTR: Happens when received a signal
# handles them as soon as possible
asap = True
logging.debug('poll:%s', e)
logging.debug('poll: %s', e)
else:
logging.error('poll:%s', e)
logging.error('poll: %s', e)
traceback.print_exc()
continue
for sock, fd, event in events:
handler = self._fdmap.get(fd, None)
if handler is not None:
handler = handler[1]
try:
handler.handle_event(sock, fd, event)
except (OSError, IOError) as e:
shell.print_exception(e)
for key, event in events:
if type(key.data) == tuple:
handler = key.data[0]
args = key.data[1:]
else:
handler = key.data
args = ()
sock = key.fileobj
if hasattr(handler, 'handle_event'):
handler = handler.handle_event
try:
handler(sock, event, *args)
except Exception as e:
logging.debug(e)
traceback.print_exc()
raise
now = time.time()
if asap or now - self._last_time >= TIMEOUT_PRECISION:
for callback in self._periodic_callbacks:
callback()
self._last_time = now
def __del__(self):
self._impl.close()
logging.debug('Got {} fds registered'.format(self.fd_count()))
logging.debug('Stopping event loop')
self._selector.close()
# from tornado
def errno_from_exception(e):
"""Provides the errno from an Exception object.
There are cases that the errno attribute was not set so we pull
the errno out of the args but if someone instatiates an Exception
without any args you will get a tuple error. So this function
abstracts all that behavior to give you a safe way to get the
errno.
"""
if hasattr(e, 'errno'):
return e.errno
elif e.args:
return e.args[0]
else:
return None
# from tornado
def get_sock_error(sock):
error_number = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
return socket.error(error_number, os.strerror(error_number))
def stop(self):
self._stopping = True

View file

@ -87,8 +87,8 @@ class LRUCache(collections.MutableMapping):
if value not in self._closed_values:
self.close_callback(value)
self._closed_values.add(value)
self._last_visits.popleft()
for key in self._time_to_keys[least]:
self._last_visits.popleft()
if key in self._store:
if now - self._keys_to_last_time[key] > self.timeout:
del self._store[key]

652
shadowsocks/selectors.py Normal file
View file

@ -0,0 +1,652 @@
"""Selectors module.
This module allows high-level and efficient I/O multiplexing, built upon the
`select` module primitives.
"""
from abc import ABCMeta, abstractmethod
from collections import namedtuple, Mapping
import errno
import math
import os
import select
import socket
import sys
# generic events, that must be mapped to implementation-specific ones
EVENT_READ = 0x001
EVENT_WRITE = 0x004
EVENT_ERROR = 0x018
def errno_from_exception(e):
"""Provides the errno from an Exception object.
There are cases that the errno attribute was not set so we pull
the errno out of the args but if someone instatiates an Exception
without any args you will get a tuple error. So this function
abstracts all that behavior to give you a safe way to get the
errno.
"""
if hasattr(e, 'errno'):
return e.errno
elif e.args:
return e.args[0]
else:
return None
def get_sock_error(sock):
if not sock:
return
error_number = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
return socket.error(error_number, os.strerror(error_number))
def _fileobj_to_fd(fileobj):
"""Return a file descriptor from a file object.
Parameters:
fileobj -- file object or file descriptor
Returns:
corresponding file descriptor
Raises:
ValueError if the object is invalid
"""
if isinstance(fileobj, int):
fd = fileobj
else:
try:
fd = int(fileobj.fileno())
except (AttributeError, TypeError, ValueError):
raise ValueError("Invalid file object: "
"{!r}".format(fileobj))
if fd < 0:
raise ValueError("Invalid file descriptor: {}".format(fd))
return fd
SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data'])
"""Object used to associate a file object to its backing file descriptor,
selected event mask and attached data."""
class _SelectorMapping(Mapping):
"""Mapping of file objects to selector keys."""
def __init__(self, selector):
self._selector = selector
def __len__(self):
return len(self._selector._fd_to_key)
def __getitem__(self, fileobj):
try:
fd = self._selector._fileobj_lookup(fileobj)
return self._selector._fd_to_key[fd]
except KeyError:
raise KeyError("{!r} is not registered".format(fileobj))
def __iter__(self):
return iter(self._selector._fd_to_key)
class BaseSelector(object):
__metaclass__ = ABCMeta
"""Selector abstract base class.
A selector supports registering file objects to be monitored for specific
I/O events.
A file object is a file descriptor or any object with a `fileno()` method.
An arbitrary object can be attached to the file object, which can be used
for example to store context information, a callback, etc.
A selector can use various implementations (select(), poll(), epoll()...)
depending on the platform. The default `Selector` class uses the most
efficient implementation on the current platform.
"""
@abstractmethod
def register(self, fileobj, events, data=None):
"""Register a file object.
Parameters:
fileobj -- file object or file descriptor
events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
data -- attached data
Returns:
SelectorKey instance
Raises:
ValueError if events is invalid
KeyError if fileobj is already registered
OSError if fileobj is closed or otherwise is unacceptable to
the underlying system call (if a system call is made)
Note:
OSError may or may not be raised
"""
raise NotImplementedError
@abstractmethod
def unregister(self, fileobj):
"""Unregister a file object.
Parameters:
fileobj -- file object or file descriptor
Returns:
SelectorKey instance
Raises:
KeyError if fileobj is not registered
Note:
If fileobj is registered but has since been closed this does
*not* raise OSError (even if the wrapped syscall does)
"""
raise NotImplementedError
def modify(self, fileobj, events, data=None):
"""Change a registered file object monitored events or attached data.
Parameters:
fileobj -- file object or file descriptor
events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
data -- attached data
Returns:
SelectorKey instance
Raises:
Anything that unregister() or register() raises
"""
self.unregister(fileobj)
return self.register(fileobj, events, data)
@abstractmethod
def select(self, timeout=None):
"""Perform the actual selection, until some monitored file objects are
ready or a timeout expires.
Parameters:
timeout -- if timeout > 0, this specifies the maximum wait time, in
seconds
if timeout <= 0, the select() call won't block, and will
report the currently ready file objects
if timeout is None, select() will block until a monitored
file object becomes ready
Returns:
list of (key, events) for ready file objects
`events` is a bitwise mask of EVENT_READ|EVENT_WRITE
"""
raise NotImplementedError
def close(self):
"""Close the selector.
This must be called to make sure that any underlying resource is freed.
"""
pass
def get_key(self, fileobj):
"""Return the key associated to a registered file object.
Returns:
SelectorKey for this file object
"""
mapping = self.get_map()
try:
if mapping is None:
raise KeyError
return mapping[fileobj]
except KeyError:
raise KeyError("{!r} is not registered".format(fileobj))
@abstractmethod
def get_map(self):
"""Return a mapping of file objects to selector keys."""
raise NotImplementedError
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
class _BaseSelectorImpl(BaseSelector):
"""Base selector implementation."""
def __init__(self):
# this maps file descriptors to keys
self._fd_to_key = {}
# read-only mapping returned by get_map()
self._map = _SelectorMapping(self)
def _fileobj_lookup(self, fileobj):
"""Return a file descriptor from a file object.
This wraps _fileobj_to_fd() to do an exhaustive search in case
the object is invalid but we still have it in our map. This
is used by unregister() so we can unregister an object that
was previously registered even if it is closed. It is also
used by _SelectorMapping.
"""
try:
return _fileobj_to_fd(fileobj)
except ValueError:
# Do an exhaustive search.
for key in self._fd_to_key.values():
if key.fileobj is fileobj:
return key.fd
# Raise ValueError after all.
raise
def register(self, fileobj, events, data=None):
if (not events) or (events & ~(EVENT_READ | EVENT_WRITE |
EVENT_ERROR)):
raise ValueError("Invalid events: {!r}".format(events))
key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data)
if key.fd in self._fd_to_key:
raise KeyError("{!r} (FD {}) is already registered"
.format(fileobj, key.fd))
self._fd_to_key[key.fd] = key
return key
def unregister(self, fileobj):
try:
key = self._fd_to_key.pop(self._fileobj_lookup(fileobj))
except KeyError:
raise KeyError("{!r} is not registered".format(fileobj))
return key
def modify(self, fileobj, events, data=None):
# TODO: Subclasses can probably optimize this even further.
try:
key = self._fd_to_key[self._fileobj_lookup(fileobj)]
except KeyError:
raise KeyError("{!r} is not registered".format(fileobj))
if events != key.events:
self.unregister(fileobj)
key = self.register(fileobj, events, data)
elif data != key.data:
# Use a shortcut to update the data.
key = key._replace(data=data)
self._fd_to_key[key.fd] = key
return key
def close(self):
self._fd_to_key.clear()
self._map = None
def get_map(self):
return self._map
def _key_from_fd(self, fd):
"""Return the key associated to a given file descriptor.
Parameters:
fd -- file descriptor
Returns:
corresponding key, or None if not found
"""
try:
return self._fd_to_key[fd]
except KeyError:
return None
class SelectSelector(_BaseSelectorImpl):
"""Select-based selector."""
def __init__(self):
super(self.__class__, self).__init__()
self._readers = set()
self._writers = set()
self._errors = set()
def register(self, fileobj, events, data=None):
key = super(self.__class__, self).register(fileobj, events, data)
if events & EVENT_READ:
self._readers.add(key.fd)
if events & EVENT_WRITE:
self._writers.add(key.fd)
if events & EVENT_ERROR:
self._errors.add(key.fd)
return key
def unregister(self, fileobj):
key = super(self.__class__, self).unregister(fileobj)
self._readers.discard(key.fd)
self._writers.discard(key.fd)
self._errors.discard(key.fd)
return key
if sys.platform == 'win32':
def _select(self, r, w, x, timeout=None):
r, w, x = select.select(r, w, x, timeout)
return r, w, x
else:
_select = select.select
def select(self, timeout=None):
timeout = None if timeout is None else max(timeout, 0)
ready = []
try:
r, w, x = self._select(self._readers, self._writers, self._errors,
timeout)
except OSError as e:
if errno_from_exception(e) == errno.EAGAIN:
return ready
r = set(r)
w = set(w)
x = set(x)
for fd in r | w | x:
events = 0
if fd in r:
events |= EVENT_READ
if fd in w:
events |= EVENT_WRITE
if fd in x:
events |= EVENT_ERROR
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
if hasattr(select, 'poll'):
class PollSelector(_BaseSelectorImpl):
"""Poll-based selector."""
def __init__(self):
super(self.__class__, self).__init__()
self._poll = select.poll()
def register(self, fileobj, events, data=None):
key = super(self.__class__, self).register(fileobj, events, data)
poll_events = 0
if events & EVENT_READ:
poll_events |= select.POLLIN
if events & EVENT_WRITE:
poll_events |= select.POLLOUT
if events & EVENT_ERROR:
poll_events |= select.POLLERR | select.POLLHUP
self._poll.register(key.fd, poll_events)
return key
def unregister(self, fileobj):
key = super(self.__class__, self).unregister(fileobj)
self._poll.unregister(key.fd)
return key
def select(self, timeout=None):
if timeout is None:
timeout = None
elif timeout <= 0:
timeout = 0
else:
# poll() has a resolution of 1 millisecond, round away from
# zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3)
ready = []
try:
fd_event_list = self._poll.poll(timeout)
except OSError as e:
if errno_from_exception(e) == errno.EAGAIN:
return ready
for fd, event in fd_event_list:
events = 0
if event & select.POLLIN:
events |= EVENT_READ
if event & select.POLLOUT:
events |= EVENT_WRITE
if event & (select.POLLERR | select.POLLHUP):
events |= EVENT_ERROR
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
if hasattr(select, 'epoll'):
class EpollSelector(_BaseSelectorImpl):
"""Epoll-based selector."""
def __init__(self):
super(self.__class__, self).__init__()
self._epoll = select.epoll()
def fileno(self):
return self._epoll.fileno()
def register(self, fileobj, events, data=None):
key = super(self.__class__, self).register(fileobj, events, data)
epoll_events = 0
if events & EVENT_READ:
epoll_events |= select.EPOLLIN
if events & EVENT_WRITE:
epoll_events |= select.EPOLLOUT
if events & EVENT_ERROR:
epoll_events |= select.EPOLLERR | select.EPOLLHUP
self._epoll.register(key.fd, epoll_events)
return key
def unregister(self, fileobj):
key = super(self.__class__, self).unregister(fileobj)
try:
self._epoll.unregister(key.fd)
except OSError:
# This can happen if the FD was closed since it
# was registered.
pass
return key
def select(self, timeout=None):
if timeout is None:
timeout = -1
elif timeout <= 0:
timeout = 0
else:
# epoll_wait() has a resolution of 1 millisecond, round away
# from zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3) * 1e-3
# epoll_wait() expects `maxevents` to be greater than zero;
# we want to make sure that `select()` can be called when no
# FD is registered.
max_ev = max(len(self._fd_to_key), 1)
ready = []
try:
fd_event_list = self._epoll.poll(timeout, max_ev)
except OSError as e:
if errno_from_exception(e) == errno.EAGAIN:
return ready
for fd, event in fd_event_list:
events = 0
if event & select.EPOLLIN:
events |= EVENT_READ
if event & select.EPOLLOUT:
events |= EVENT_WRITE
if event & (select.EPOLLERR | select.EPOLLHUP):
events |= EVENT_ERROR
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
def close(self):
try:
self._epoll.close()
finally:
super(self.__class__, self).close()
if hasattr(select, 'devpoll'):
class DevpollSelector(_BaseSelectorImpl):
"""Solaris /dev/poll selector."""
def __init__(self):
super().__init__()
self._devpoll = select.devpoll()
def fileno(self):
return self._devpoll.fileno()
def register(self, fileobj, events, data=None):
key = super().register(fileobj, events, data)
poll_events = 0
if events & EVENT_READ:
poll_events |= select.POLLIN
if events & EVENT_WRITE:
poll_events |= select.POLLOUT
if events & EVENT_ERROR:
poll_events |= select.POLLERR | select.POLLHUP
self._devpoll.register(key.fd, poll_events)
return key
def unregister(self, fileobj):
key = super().unregister(fileobj)
self._devpoll.unregister(key.fd)
return key
def select(self, timeout=None):
if timeout is None:
timeout = None
elif timeout <= 0:
timeout = 0
else:
# devpoll() has a resolution of 1 millisecond, round away from
# zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3)
ready = []
try:
fd_event_list = self._devpoll.poll(timeout)
except InterruptedError:
return ready
for fd, event in fd_event_list:
events = 0
if event & select.POLLIN:
events |= EVENT_READ
if event & select.POLLOUT:
events |= EVENT_WRITE
if event & (select.POLLERR | select.POLLHUP):
events |= EVENT_ERROR
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
def close(self):
self._devpoll.close()
super().close()
if hasattr(select, 'kqueue'):
class KqueueSelector(_BaseSelectorImpl):
"""Kqueue-based selector."""
def __init__(self):
super(self.__class__, self).__init__()
self._kqueue = select.kqueue()
def fileno(self):
return self._kqueue.fileno()
def register(self, fileobj, events, data=None):
key = super(self.__class__, self).register(fileobj, events, data)
if events & EVENT_READ:
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
select.KQ_EV_ADD)
self._kqueue.control([kev], 0, 0)
if events & EVENT_WRITE:
kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
select.KQ_EV_ADD)
self._kqueue.control([kev], 0, 0)
return key
def unregister(self, fileobj):
key = super(self.__class__, self).unregister(fileobj)
if key.events & EVENT_READ:
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
select.KQ_EV_DELETE)
try:
self._kqueue.control([kev], 0, 0)
except OSError:
# This can happen if the FD was closed since it
# was registered.
pass
if key.events & EVENT_WRITE:
kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
select.KQ_EV_DELETE)
try:
self._kqueue.control([kev], 0, 0)
except OSError:
# See comment above.
pass
return key
def select(self, timeout=None):
timeout = None if timeout is None else max(timeout, 0)
max_ev = len(self._fd_to_key)
ready = []
try:
kev_list = self._kqueue.control(None, max_ev, timeout)
except OSError as e:
if errno_from_exception(e) == errno.EAGAIN:
return ready
for kev in kev_list:
fd = kev.ident
flag = kev.filter
events = 0
if flag == select.KQ_FILTER_READ:
events |= EVENT_READ
if flag == select.KQ_FILTER_WRITE:
events |= EVENT_WRITE
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
def close(self):
try:
self._kqueue.close()
finally:
super(self.__class__, self).close()
# Choose the best implementation: roughly, epoll|kqueue > poll > select.
# select() also can't accept a FD > FD_SETSIZE (usually around 1024)
if 'KqueueSelector' in globals():
DefaultSelector = KqueueSelector
elif 'EpollSelector' in globals():
DefaultSelector = EpollSelector
elif 'DevpollSelector' in globals():
DefaultSelector = DevpollSelector
elif 'PollSelector' in globals():
DefaultSelector = PollSelector
else:
DefaultSelector = SelectSelector

148
shadowsocks/server.py Executable file → Normal file
View file

@ -1,143 +1,35 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement
import sys
import os
import sys
import logging
import signal
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../'))
from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, \
asyncdns, manager
path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, path)
from shadowsocks.eventloop import EventLoop
from shadowsocks.tcprelay import TcpRelay, TcpRelayServerHandler
from shadowsocks.asyncdns import DNSResolver
FORMATTER = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
LOGGING_LEVEL = logging.INFO
logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER)
LISTEN_ADDR = ('0.0.0.0', 9000)
def main():
shell.check_python()
config = shell.get_config(False)
daemon.daemon_exec(config)
if config['port_password']:
if config['password']:
logging.warn('warning: port_password should not be used with '
'server_port and password. server_port and password '
'will be ignored')
else:
config['port_password'] = {}
server_port = config['server_port']
if type(server_port) == list:
for a_server_port in server_port:
config['port_password'][a_server_port] = config['password']
else:
config['port_password'][str(server_port)] = config['password']
if config.get('manager_address', 0):
logging.info('entering manager mode')
manager.run(config)
return
tcp_servers = []
udp_servers = []
if 'dns_server' in config: # allow override settings in resolv.conf
dns_resolver = asyncdns.DNSResolver(config['dns_server'],
config['prefer_ipv6'])
else:
dns_resolver = asyncdns.DNSResolver(prefer_ipv6=config['prefer_ipv6'])
port_password = config['port_password']
del config['port_password']
for port, password in port_password.items():
a_config = config.copy()
a_config['server_port'] = int(port)
a_config['password'] = password
logging.info("starting server at %s:%d" %
(a_config['server'], int(port)))
tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False))
udp_servers.append(udprelay.UDPRelay(a_config, dns_resolver, False))
def run_server():
def child_handler(signum, _):
logging.warn('received SIGQUIT, doing graceful shutting down..')
list(map(lambda s: s.close(next_tick=True),
tcp_servers + udp_servers))
signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM),
child_handler)
def int_handler(signum, _):
sys.exit(1)
signal.signal(signal.SIGINT, int_handler)
try:
loop = eventloop.EventLoop()
dns_resolver.add_to_loop(loop)
list(map(lambda s: s.add_to_loop(loop), tcp_servers + udp_servers))
daemon.set_user(config.get('user', None))
loop.run()
except Exception as e:
shell.print_exception(e)
sys.exit(1)
if int(config['workers']) > 1:
if os.name == 'posix':
children = []
is_child = False
for i in range(0, int(config['workers'])):
r = os.fork()
if r == 0:
logging.info('worker started')
is_child = True
run_server()
break
else:
children.append(r)
if not is_child:
def handler(signum, _):
for pid in children:
try:
os.kill(pid, signum)
os.waitpid(pid, 0)
except OSError: # child may already exited
pass
sys.exit()
signal.signal(signal.SIGTERM, handler)
signal.signal(signal.SIGQUIT, handler)
signal.signal(signal.SIGINT, handler)
# master
for a_tcp_server in tcp_servers:
a_tcp_server.close()
for a_udp_server in udp_servers:
a_udp_server.close()
dns_resolver.close()
for child in children:
os.waitpid(child, 0)
else:
logging.warn('worker is only available on Unix/Linux')
run_server()
else:
run_server()
loop = EventLoop()
dns_resolver = DNSResolver()
relay = TcpRelay(TcpRelayServerHandler, LISTEN_ADDR,
dns_resolver=dns_resolver)
dns_resolver.add_to_loop(loop)
relay.add_to_loop(loop)
loop.run()
if __name__ == '__main__':
main()

View file

@ -23,10 +23,6 @@ import json
import sys
import getopt
import logging
import traceback
from functools import wraps
from shadowsocks.common import to_bytes, to_str, IPNetwork
from shadowsocks import encrypt
@ -57,49 +53,6 @@ def print_exception(e):
traceback.print_exc()
def exception_handle(self_, err_msg=None, exit_code=None,
destroy=False, conn_err=False):
# self_: if function passes self as first arg
def process_exception(e, self=None):
print_exception(e)
if err_msg:
logging.error(err_msg)
if exit_code:
sys.exit(1)
if not self_:
return
if conn_err:
addr, port = self._client_address[0], self._client_address[1]
logging.error('%s when handling connection from %s:%d' %
(e, addr, port))
if self._config['verbose']:
traceback.print_exc()
if destroy:
self.destroy()
def decorator(func):
if self_:
@wraps(func)
def wrapper(self, *args, **kwargs):
try:
func(self, *args, **kwargs)
except Exception as e:
process_exception(e, self)
else:
@wraps(func)
def wrapper(*args, **kwargs):
try:
func(*args, **kwargs)
except Exception as e:
process_exception(e)
return wrapper
return decorator
def print_shadowsocks():
version = ''
try:
@ -125,30 +78,13 @@ def check_config(config, is_local):
# no need to specify configuration for daemon stop
return
if is_local:
if config.get('server', None) is None:
logging.error('server addr not specified')
print_local_help()
sys.exit(2)
else:
config['server'] = to_str(config['server'])
else:
config['server'] = to_str(config.get('server', '0.0.0.0'))
try:
config['forbidden_ip'] = \
IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128'))
except Exception as e:
logging.error(e)
sys.exit(2)
if is_local and not config.get('password', None):
logging.error('password not specified')
print_help(is_local)
sys.exit(2)
if not is_local and not config.get('password', None) \
and not config.get('port_password', None) \
and not config.get('manager_address'):
and not config.get('port_password', None):
logging.error('password or port_password not specified')
print_help(is_local)
sys.exit(2)
@ -194,14 +130,13 @@ def get_config(is_local):
logging.basicConfig(level=logging.INFO,
format='%(levelname)-s: %(message)s')
if is_local:
shortopts = 'hd:s:b:p:k:l:m:c:t:vqa'
shortopts = 'hd:s:b:p:k:l:m:c:t:vq'
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'user=',
'version']
else:
shortopts = 'hd:s:p:k:m:c:t:vqa'
shortopts = 'hd:s:p:k:m:c:t:vq'
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=',
'forbidden-ip=', 'user=', 'manager-address=', 'version',
'prefer-ipv6']
'forbidden-ip=', 'user=', 'manager-address=', 'version']
try:
config_path = find_config()
optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts)
@ -239,8 +174,6 @@ def get_config(is_local):
v_count += 1
# '-vv' turns on more verbose mode
config['verbose'] = v_count
elif key == '-a':
config['one_time_auth'] = True
elif key == '-t':
config['timeout'] = int(value)
elif key == '--fast-open':
@ -271,8 +204,6 @@ def get_config(is_local):
elif key == '-q':
v_count -= 1
config['verbose'] = v_count
elif key == '--prefer-ipv6':
config['prefer_ipv6'] = True
except getopt.GetoptError as e:
print(e, file=sys.stderr)
print_help(is_local)
@ -294,8 +225,21 @@ def get_config(is_local):
config['verbose'] = config.get('verbose', False)
config['local_address'] = to_str(config.get('local_address', '127.0.0.1'))
config['local_port'] = config.get('local_port', 1080)
config['one_time_auth'] = config.get('one_time_auth', False)
config['prefer_ipv6'] = config.get('prefer_ipv6', False)
if is_local:
if config.get('server', None) is None:
logging.error('server addr not specified')
print_local_help()
sys.exit(2)
else:
config['server'] = to_str(config['server'])
else:
config['server'] = to_str(config.get('server', '0.0.0.0'))
try:
config['forbidden_ip'] = \
IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128'))
except Exception as e:
logging.error(e)
sys.exit(2)
config['server_port'] = config.get('server_port', 8388)
logging.getLogger('').handlers = []
@ -342,7 +286,6 @@ Proxy options:
-k PASSWORD password
-m METHOD encryption method, default: aes-256-cfb
-t TIMEOUT timeout in seconds, default: 300
-a ONE_TIME_AUTH one time auth
--fast-open use TCP_FASTOPEN, requires Linux 3.7+
General options:
@ -372,12 +315,10 @@ Proxy options:
-k PASSWORD password
-m METHOD encryption method, default: aes-256-cfb
-t TIMEOUT timeout in seconds, default: 300
-a ONE_TIME_AUTH one time auth
--fast-open use TCP_FASTOPEN, requires Linux 3.7+
--workers WORKERS number of workers, available on Unix/Linux
--forbidden-ip IPLIST comma seperated IP list forbidden to connect
--manager-address ADDR optional server manager UDP address, see wiki
--prefer-ipv6 resolve ipv6 address first
General options:
-h, --help show this help message and exit

File diff suppressed because it is too large Load diff