From 710c93291e23494732294b0fb37b3b8518ed4fc8 Mon Sep 17 00:00:00 2001 From: Sheng Date: Wed, 8 Feb 2017 20:59:33 +0800 Subject: [PATCH] just for learn --- shadowsocks/asyncdns.py | 79 +-- shadowsocks/client.py | 33 + shadowsocks/common.py | 35 +- shadowsocks/crypto/openssl.py | 11 +- shadowsocks/crypto/sodium.py | 23 +- shadowsocks/encrypt.py | 63 +- shadowsocks/eventloop.py | 258 ++------ shadowsocks/lru_cache.py | 2 +- shadowsocks/selectors.py | 652 ++++++++++++++++++ shadowsocks/server.py | 148 +---- shadowsocks/shell.py | 97 +-- shadowsocks/tcprelay.py | 1167 ++++++++++----------------------- 12 files changed, 1183 insertions(+), 1385 deletions(-) create mode 100644 shadowsocks/client.py create mode 100644 shadowsocks/selectors.py mode change 100755 => 100644 shadowsocks/server.py diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index fa5be41..0461f61 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -29,7 +29,7 @@ from shadowsocks import common, lru_cache, eventloop, shell CACHE_SWEEP_INTERVAL = 30 -VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d\-_]{1,63}(?= 2: + server = parts[1] + if common.is_ip(server) == socket.AF_INET: + if type(server) != str: + server = server.decode('utf8') + self._servers.append(server) except IOError: pass if not self._servers: @@ -302,17 +292,13 @@ class DNSResolver(object): for line in f.readlines(): line = line.strip() parts = line.split() - if len(parts) < 2: - continue - - ip = parts[0] - if not common.is_ip(ip): - continue - - for i in range(1, len(parts)): - hostname = parts[i] - if hostname: - self._hosts[hostname] = ip + if len(parts) >= 2: + ip = parts[0] + if common.is_ip(ip): + for i in range(1, len(parts)): + hostname = parts[i] + if hostname: + self._hosts[hostname] = ip except IOError: self._hosts['localhost'] = '127.0.0.1' @@ -352,22 +338,21 @@ class DNSResolver(object): answer[2] == QCLASS_IN: ip = answer[0] break - if not ip and self._hostname_status.get(hostname, STATUS_SECOND) \ - == STATUS_FIRST: - self._hostname_status[hostname] = STATUS_SECOND - self._send_req(hostname, self._QTYPES[1]) + if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \ + == STATUS_IPV4: + self._hostname_status[hostname] = STATUS_IPV6 + self._send_req(hostname, QTYPE_AAAA) else: if ip: self._cache[hostname] = ip self._call_callback(hostname, ip) - elif self._hostname_status.get(hostname, None) \ - == STATUS_SECOND: + elif self._hostname_status.get(hostname, None) == STATUS_IPV6: for question in response.questions: - if question[1] == self._QTYPES[1]: + if question[1] == QTYPE_AAAA: self._call_callback(hostname, None) break - def handle_event(self, sock, fd, event): + def handle_event(self, sock, event): if sock != self._sock: return if event & eventloop.POLL_ERR: @@ -429,14 +414,14 @@ class DNSResolver(object): return arr = self._hostname_to_cb.get(hostname, None) if not arr: - self._hostname_status[hostname] = STATUS_FIRST - self._send_req(hostname, self._QTYPES[0]) + self._hostname_status[hostname] = STATUS_IPV4 + self._send_req(hostname, QTYPE_A) self._hostname_to_cb[hostname] = [callback] self._cb_to_hostname[callback] = hostname else: arr.append(callback) # TODO send again only if waited too long - self._send_req(hostname, self._QTYPES[0]) + self._send_req(hostname, QTYPE_A) def close(self): if self._sock: diff --git a/shadowsocks/client.py b/shadowsocks/client.py new file mode 100644 index 0000000..99067c1 --- /dev/null +++ b/shadowsocks/client.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + 
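+# Local entry point for the stripped-down relay (a learning sketch, not the
+# full sslocal): it wires a TcpRelay using TcpRelayClientHanler into an
+# EventLoop and forwards connections accepted on LISTEN_ADDR to the remote
+# relay at REMOTE_ADDR defined below.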
+from __future__ import absolute_import, division, print_function, \ + with_statement + +import logging +import os +import sys + +path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, path) +from shadowsocks.eventloop import EventLoop +from shadowsocks.tcprelay import TcpRelay, TcpRelayClientHanler + + +FORMATTER = '%(asctime)s - %(levelname)s - %(message)s' +LOGGING_LEVEL = logging.INFO +logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER) + +LISTEN_ADDR = ('127.0.0.1', 7070) +REMOTE_ADDR = ('104.129.176.124', 9000) + + +def main(): + loop = EventLoop() + relay = TcpRelay(TcpRelayClientHanler, LISTEN_ADDR, REMOTE_ADDR) + relay.add_to_loop(loop) + loop.run() + + +if __name__ == '__main__': + main() diff --git a/shadowsocks/common.py b/shadowsocks/common.py index ee14995..db4beea 100644 --- a/shadowsocks/common.py +++ b/shadowsocks/common.py @@ -21,25 +21,6 @@ from __future__ import absolute_import, division, print_function, \ import socket import struct import logging -import hashlib -import hmac - - -ONETIMEAUTH_BYTES = 10 -ONETIMEAUTH_CHUNK_BYTES = 12 -ONETIMEAUTH_CHUNK_DATA_LEN = 2 - - -def sha1_hmac(secret, data): - return hmac.new(secret, data, hashlib.sha1).digest() - - -def onetimeauth_verify(_hash, data, key): - return _hash == sha1_hmac(key, data)[:ONETIMEAUTH_BYTES] - - -def onetimeauth_gen(data, key): - return sha1_hmac(key, data)[:ONETIMEAUTH_BYTES] def compat_ord(s): @@ -137,11 +118,9 @@ def patch_socket(): patch_socket() -ADDRTYPE_IPV4 = 0x01 -ADDRTYPE_IPV6 = 0x04 -ADDRTYPE_HOST = 0x03 -ADDRTYPE_AUTH = 0x10 -ADDRTYPE_MASK = 0xF +ADDRTYPE_IPV4 = 1 +ADDRTYPE_IPV6 = 4 +ADDRTYPE_HOST = 3 def pack_addr(address): @@ -165,17 +144,17 @@ def parse_header(data): dest_addr = None dest_port = None header_length = 0 - if addrtype & ADDRTYPE_MASK == ADDRTYPE_IPV4: + if addrtype == ADDRTYPE_IPV4: if len(data) >= 7: dest_addr = socket.inet_ntoa(data[1:5]) dest_port = struct.unpack('>H', data[5:7])[0] header_length = 7 else: logging.warn('header is too short') - elif addrtype & ADDRTYPE_MASK == ADDRTYPE_HOST: + elif addrtype == ADDRTYPE_HOST: if len(data) > 2: addrlen = ord(data[1]) - if len(data) >= 4 + addrlen: + if len(data) >= 2 + addrlen: dest_addr = data[2:2 + addrlen] dest_port = struct.unpack('>H', data[2 + addrlen:4 + addrlen])[0] @@ -184,7 +163,7 @@ def parse_header(data): logging.warn('header is too short') else: logging.warn('header is too short') - elif addrtype & ADDRTYPE_MASK == ADDRTYPE_IPV6: + elif addrtype == ADDRTYPE_IPV6: if len(data) >= 19: dest_addr = socket.inet_ntop(socket.AF_INET6, data[1:17]) dest_port = struct.unpack('>H', data[17:19])[0] diff --git a/shadowsocks/crypto/openssl.py b/shadowsocks/crypto/openssl.py index da7f177..3775b6c 100644 --- a/shadowsocks/crypto/openssl.py +++ b/shadowsocks/crypto/openssl.py @@ -32,7 +32,7 @@ buf_size = 2048 def load_openssl(): - global loaded, libcrypto, buf, ctx_cleanup + global loaded, libcrypto, buf libcrypto = util.find_library(('crypto', 'eay32'), 'EVP_get_cipherbyname', @@ -49,12 +49,7 @@ def load_openssl(): libcrypto.EVP_CipherUpdate.argtypes = (c_void_p, c_void_p, c_void_p, c_char_p, c_int) - try: - libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,) - ctx_cleanup = libcrypto.EVP_CIPHER_CTX_cleanup - except AttributeError: - libcrypto.EVP_CIPHER_CTX_reset.argtypes = (c_void_p,) - ctx_cleanup = libcrypto.EVP_CIPHER_CTX_reset + libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,) libcrypto.EVP_CIPHER_CTX_free.argtypes = (c_void_p,) if hasattr(libcrypto, 
'OpenSSL_add_all_ciphers'): libcrypto.OpenSSL_add_all_ciphers() @@ -113,7 +108,7 @@ class OpenSSLCrypto(object): def clean(self): if self._ctx: - ctx_cleanup(self._ctx) + libcrypto.EVP_CIPHER_CTX_cleanup(self._ctx) libcrypto.EVP_CIPHER_CTX_free(self._ctx) diff --git a/shadowsocks/crypto/sodium.py b/shadowsocks/crypto/sodium.py index b744e2c..ae86fef 100644 --- a/shadowsocks/crypto/sodium.py +++ b/shadowsocks/crypto/sodium.py @@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function, \ with_statement -from ctypes import c_char_p, c_int, c_ulonglong, byref, c_ulong, \ +from ctypes import c_char_p, c_int, c_ulonglong, byref, \ create_string_buffer, c_void_p from shadowsocks.crypto import util @@ -29,7 +29,7 @@ loaded = False buf_size = 2048 -# for salsa20 and chacha20 and chacha20-ietf +# for salsa20 and chacha20 BLOCK_SIZE = 64 @@ -51,13 +51,6 @@ def load_libsodium(): c_ulonglong, c_char_p, c_ulonglong, c_char_p) - libsodium.crypto_stream_chacha20_ietf_xor_ic.restype = c_int - libsodium.crypto_stream_chacha20_ietf_xor_ic.argtypes = (c_void_p, - c_char_p, - c_ulonglong, - c_char_p, - c_ulong, - c_char_p) buf = create_string_buffer(buf_size) loaded = True @@ -75,8 +68,6 @@ class SodiumCrypto(object): self.cipher = libsodium.crypto_stream_salsa20_xor_ic elif cipher_name == 'chacha20': self.cipher = libsodium.crypto_stream_chacha20_xor_ic - elif cipher_name == 'chacha20-ietf': - self.cipher = libsodium.crypto_stream_chacha20_ietf_xor_ic else: raise Exception('Unknown cipher') # byte counter, not block counter @@ -106,7 +97,6 @@ class SodiumCrypto(object): ciphers = { 'salsa20': (32, 8, SodiumCrypto), 'chacha20': (32, 8, SodiumCrypto), - 'chacha20-ietf': (32, 12, SodiumCrypto), } @@ -125,15 +115,6 @@ def test_chacha20(): util.run_cipher(cipher, decipher) -def test_chacha20_ietf(): - - cipher = SodiumCrypto('chacha20-ietf', b'k' * 32, b'i' * 16, 1) - decipher = SodiumCrypto('chacha20-ietf', b'k' * 32, b'i' * 16, 0) - - util.run_cipher(cipher, decipher) - - if __name__ == '__main__': test_chacha20() test_salsa20() - test_chacha20_ietf() diff --git a/shadowsocks/encrypt.py b/shadowsocks/encrypt.py index ece72ec..4e87f41 100644 --- a/shadowsocks/encrypt.py +++ b/shadowsocks/encrypt.py @@ -69,18 +69,17 @@ def EVP_BytesToKey(password, key_len, iv_len): class Encryptor(object): - def __init__(self, password, method): - self.password = password - self.key = None + def __init__(self, key, method): + self.key = key self.method = method + self.iv = None self.iv_sent = False self.cipher_iv = b'' self.decipher = None - self.decipher_iv = None method = method.lower() self._method_info = self.get_method_info(method) if self._method_info: - self.cipher = self.get_cipher(password, method, 1, + self.cipher = self.get_cipher(key, method, 1, random_string(self._method_info[1])) else: logging.error('method %s not supported' % method) @@ -102,7 +101,7 @@ class Encryptor(object): else: # key_length == 0 indicates we should use the key directly key, iv = password, b'' - self.key = key + iv = iv[:m[1]] if op == 1: # this iv is for cipher not decipher @@ -124,8 +123,7 @@ class Encryptor(object): if self.decipher is None: decipher_iv_len = self._method_info[1] decipher_iv = buf[:decipher_iv_len] - self.decipher_iv = decipher_iv - self.decipher = self.get_cipher(self.password, self.method, 0, + self.decipher = self.get_cipher(self.key, self.method, 0, iv=decipher_iv) buf = buf[decipher_iv_len:] if len(buf) == 0: @@ -133,47 +131,10 @@ class Encryptor(object): return self.decipher.update(buf) -def 
gen_key_iv(password, method): - method = method.lower() - (key_len, iv_len, m) = method_supported[method] - key = None - if key_len > 0: - key, _ = EVP_BytesToKey(password, key_len, iv_len) - else: - key = password - iv = random_string(iv_len) - return key, iv, m - - -def encrypt_all_m(key, iv, m, method, data): - result = [] - result.append(iv) - cipher = m(method, key, iv, 1) - result.append(cipher.update(data)) - return b''.join(result) - - -def dencrypt_all(password, method, data): - result = [] - method = method.lower() - (key_len, iv_len, m) = method_supported[method] - key = None - if key_len > 0: - key, _ = EVP_BytesToKey(password, key_len, iv_len) - else: - key = password - iv = data[:iv_len] - data = data[iv_len:] - cipher = m(method, key, iv, 0) - result.append(cipher.update(data)) - return b''.join(result), key, iv - - def encrypt_all(password, method, op, data): result = [] method = method.lower() (key_len, iv_len, m) = method_supported[method] - key = None if key_len > 0: key, _ = EVP_BytesToKey(password, key_len, iv_len) else: @@ -221,18 +182,6 @@ def test_encrypt_all(): assert plain == plain2 -def test_encrypt_all_m(): - from os import urandom - plain = urandom(10240) - for method in CIPHERS_TO_TEST: - logging.warn(method) - key, iv, m = gen_key_iv(b'key', method) - cipher = encrypt_all_m(key, iv, m, method, plain) - plain2, key, iv = dencrypt_all(b'key', method, cipher) - assert plain == plain2 - - if __name__ == '__main__': test_encrypt_all() test_encryptor() - test_encrypt_all_m() diff --git a/shadowsocks/eventloop.py b/shadowsocks/eventloop.py index ce5da37..36bca4e 100644 --- a/shadowsocks/eventloop.py +++ b/shadowsocks/eventloop.py @@ -1,181 +1,53 @@ -#!/usr/bin/python +#!/usr/bin/env python # -*- coding: utf-8 -*- -# -# Copyright 2013-2015 clowwindy -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. 
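+# EventLoop is now a thin wrapper around shadowsocks.selectors.DefaultSelector:
+# add()/modify() force EVENT_ERROR into the event mask, poll() delegates to
+# selector.select(), and run() dispatches each ready (key, event) pair to the
+# handler stored in key.data (a handler object exposing handle_event, or a
+# (callable, *args) tuple), firing the periodic callbacks roughly every
+# TIMEOUT_PRECISION seconds.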
- -# from ssloop -# https://github.com/clowwindy/ssloop from __future__ import absolute_import, division, print_function, \ with_statement -import os +import logging import time -import socket -import select import traceback import errno -import logging -from collections import defaultdict - -from shadowsocks import shell +from shadowsocks import selectors +from shadowsocks.selectors import (EVENT_READ, EVENT_WRITE, EVENT_ERROR, + errno_from_exception) -__all__ = ['EventLoop', 'POLL_NULL', 'POLL_IN', 'POLL_OUT', 'POLL_ERR', - 'POLL_HUP', 'POLL_NVAL', 'EVENT_NAMES'] +POLL_IN = EVENT_READ +POLL_OUT = EVENT_WRITE +POLL_ERR = EVENT_ERROR -POLL_NULL = 0x00 -POLL_IN = 0x01 -POLL_OUT = 0x04 -POLL_ERR = 0x08 -POLL_HUP = 0x10 -POLL_NVAL = 0x20 - - -EVENT_NAMES = { - POLL_NULL: 'POLL_NULL', - POLL_IN: 'POLL_IN', - POLL_OUT: 'POLL_OUT', - POLL_ERR: 'POLL_ERR', - POLL_HUP: 'POLL_HUP', - POLL_NVAL: 'POLL_NVAL', -} - -# we check timeouts every TIMEOUT_PRECISION seconds TIMEOUT_PRECISION = 10 -class KqueueLoop(object): - - MAX_EVENTS = 1024 +class EventLoop: def __init__(self): - self._kqueue = select.kqueue() - self._fds = {} - - def _control(self, fd, mode, flags): - events = [] - if mode & POLL_IN: - events.append(select.kevent(fd, select.KQ_FILTER_READ, flags)) - if mode & POLL_OUT: - events.append(select.kevent(fd, select.KQ_FILTER_WRITE, flags)) - for e in events: - self._kqueue.control([e], 0) - - def poll(self, timeout): - if timeout < 0: - timeout = None # kqueue behaviour - events = self._kqueue.control(None, KqueueLoop.MAX_EVENTS, timeout) - results = defaultdict(lambda: POLL_NULL) - for e in events: - fd = e.ident - if e.filter == select.KQ_FILTER_READ: - results[fd] |= POLL_IN - elif e.filter == select.KQ_FILTER_WRITE: - results[fd] |= POLL_OUT - return results.items() - - def register(self, fd, mode): - self._fds[fd] = mode - self._control(fd, mode, select.KQ_EV_ADD) - - def unregister(self, fd): - self._control(fd, self._fds[fd], select.KQ_EV_DELETE) - del self._fds[fd] - - def modify(self, fd, mode): - self.unregister(fd) - self.register(fd, mode) - - def close(self): - self._kqueue.close() - - -class SelectLoop(object): - - def __init__(self): - self._r_list = set() - self._w_list = set() - self._x_list = set() - - def poll(self, timeout): - r, w, x = select.select(self._r_list, self._w_list, self._x_list, - timeout) - results = defaultdict(lambda: POLL_NULL) - for p in [(r, POLL_IN), (w, POLL_OUT), (x, POLL_ERR)]: - for fd in p[0]: - results[fd] |= p[1] - return results.items() - - def register(self, fd, mode): - if mode & POLL_IN: - self._r_list.add(fd) - if mode & POLL_OUT: - self._w_list.add(fd) - if mode & POLL_ERR: - self._x_list.add(fd) - - def unregister(self, fd): - if fd in self._r_list: - self._r_list.remove(fd) - if fd in self._w_list: - self._w_list.remove(fd) - if fd in self._x_list: - self._x_list.remove(fd) - - def modify(self, fd, mode): - self.unregister(fd) - self.register(fd, mode) - - def close(self): - pass - - -class EventLoop(object): - def __init__(self): - if hasattr(select, 'epoll'): - self._impl = select.epoll() - model = 'epoll' - elif hasattr(select, 'kqueue'): - self._impl = KqueueLoop() - model = 'kqueue' - elif hasattr(select, 'select'): - self._impl = SelectLoop() - model = 'select' - else: - raise Exception('can not find any available functions in select ' - 'package') - self._fdmap = {} # (f, handler) + self._selector = selectors.DefaultSelector() + self._stopping = False self._last_time = time.time() self._periodic_callbacks = [] - self._stopping = 
False - logging.debug('using event model: %s', model) def poll(self, timeout=None): - events = self._impl.poll(timeout) - return [(self._fdmap[fd][0], fd, event) for fd, event in events] + return self._selector.select(timeout) - def add(self, f, mode, handler): - fd = f.fileno() - self._fdmap[fd] = (f, handler) - self._impl.register(fd, mode) + def add(self, sock, events, data): + events |= selectors.EVENT_ERROR + return self._selector.register(sock, events, data) - def remove(self, f): - fd = f.fileno() - del self._fdmap[fd] - self._impl.unregister(fd) + def remove(self, sock): + try: + return self._selector.unregister(sock) + except KeyError: + pass + + def modify(self, sock, events, data): + events |= selectors.EVENT_ERROR + try: + key = self._selector.modify(sock, events, data) + except KeyError: + key = self.add(sock, events, data) + return key def add_periodic(self, callback): self._periodic_callbacks.append(callback) @@ -183,69 +55,57 @@ class EventLoop(object): def remove_periodic(self, callback): self._periodic_callbacks.remove(callback) - def modify(self, f, mode): - fd = f.fileno() - self._impl.modify(fd, mode) - - def stop(self): - self._stopping = True + def fd_count(self): + return len(self._selector.get_map()) def run(self): - events = [] + logging.debug('Starting event loop') + while not self._stopping: asap = False try: - events = self.poll(TIMEOUT_PRECISION) + events = self.poll(timeout=TIMEOUT_PRECISION) except (OSError, IOError) as e: if errno_from_exception(e) in (errno.EPIPE, errno.EINTR): # EPIPE: Happens when the client closes the connection # EINTR: Happens when received a signal # handles them as soon as possible asap = True - logging.debug('poll:%s', e) + logging.debug('poll: %s', e) else: - logging.error('poll:%s', e) + logging.error('poll: %s', e) traceback.print_exc() continue - for sock, fd, event in events: - handler = self._fdmap.get(fd, None) - if handler is not None: - handler = handler[1] - try: - handler.handle_event(sock, fd, event) - except (OSError, IOError) as e: - shell.print_exception(e) + for key, event in events: + if type(key.data) == tuple: + handler = key.data[0] + args = key.data[1:] + else: + handler = key.data + args = () + + sock = key.fileobj + if hasattr(handler, 'handle_event'): + handler = handler.handle_event + + try: + handler(sock, event, *args) + except Exception as e: + logging.debug(e) + traceback.print_exc() + raise + now = time.time() if asap or now - self._last_time >= TIMEOUT_PRECISION: for callback in self._periodic_callbacks: callback() self._last_time = now - def __del__(self): - self._impl.close() + logging.debug('Got {} fds registered'.format(self.fd_count())) + logging.debug('Stopping event loop') + self._selector.close() -# from tornado -def errno_from_exception(e): - """Provides the errno from an Exception object. - - There are cases that the errno attribute was not set so we pull - the errno out of the args but if someone instatiates an Exception - without any args you will get a tuple error. So this function - abstracts all that behavior to give you a safe way to get the - errno. 
- """ - - if hasattr(e, 'errno'): - return e.errno - elif e.args: - return e.args[0] - else: - return None - - -# from tornado -def get_sock_error(sock): - error_number = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - return socket.error(error_number, os.strerror(error_number)) + def stop(self): + self._stopping = True diff --git a/shadowsocks/lru_cache.py b/shadowsocks/lru_cache.py index ff4fc7d..401f19b 100644 --- a/shadowsocks/lru_cache.py +++ b/shadowsocks/lru_cache.py @@ -87,8 +87,8 @@ class LRUCache(collections.MutableMapping): if value not in self._closed_values: self.close_callback(value) self._closed_values.add(value) - self._last_visits.popleft() for key in self._time_to_keys[least]: + self._last_visits.popleft() if key in self._store: if now - self._keys_to_last_time[key] > self.timeout: del self._store[key] diff --git a/shadowsocks/selectors.py b/shadowsocks/selectors.py new file mode 100644 index 0000000..d95251e --- /dev/null +++ b/shadowsocks/selectors.py @@ -0,0 +1,652 @@ +"""Selectors module. + +This module allows high-level and efficient I/O multiplexing, built upon the +`select` module primitives. +""" + + +from abc import ABCMeta, abstractmethod +from collections import namedtuple, Mapping +import errno +import math +import os +import select +import socket +import sys + + +# generic events, that must be mapped to implementation-specific ones +EVENT_READ = 0x001 +EVENT_WRITE = 0x004 +EVENT_ERROR = 0x018 + + +def errno_from_exception(e): + """Provides the errno from an Exception object. + + There are cases that the errno attribute was not set so we pull + the errno out of the args but if someone instatiates an Exception + without any args you will get a tuple error. So this function + abstracts all that behavior to give you a safe way to get the + errno. + """ + + if hasattr(e, 'errno'): + return e.errno + elif e.args: + return e.args[0] + else: + return None + + +def get_sock_error(sock): + if not sock: + return + error_number = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + return socket.error(error_number, os.strerror(error_number)) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + corresponding file descriptor + + Raises: + ValueError if the object is invalid + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, TypeError, ValueError): + raise ValueError("Invalid file object: " + "{!r}".format(fileobj)) + if fd < 0: + raise ValueError("Invalid file descriptor: {}".format(fd)) + return fd + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) +"""Object used to associate a file object to its backing file descriptor, +selected event mask and attached data.""" + + +class _SelectorMapping(Mapping): + """Mapping of file objects to selector keys.""" + + def __init__(self, selector): + self._selector = selector + + def __len__(self): + return len(self._selector._fd_to_key) + + def __getitem__(self, fileobj): + try: + fd = self._selector._fileobj_lookup(fileobj) + return self._selector._fd_to_key[fd] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) + + def __iter__(self): + return iter(self._selector._fd_to_key) + + +class BaseSelector(object): + __metaclass__ = ABCMeta + """Selector abstract base class. + + A selector supports registering file objects to be monitored for specific + I/O events. 
+ + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + efficient implementation on the current platform. + """ + + @abstractmethod + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + ValueError if events is invalid + KeyError if fileobj is already registered + OSError if fileobj is closed or otherwise is unacceptable to + the underlying system call (if a system call is made) + + Note: + OSError may or may not be raised + """ + raise NotImplementedError + + @abstractmethod + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + SelectorKey instance + + Raises: + KeyError if fileobj is not registered + + Note: + If fileobj is registered but has since been closed this does + *not* raise OSError (even if the wrapped syscall does) + """ + raise NotImplementedError + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + Anything that unregister() or register() raises + """ + self.unregister(fileobj) + return self.register(fileobj, events, data) + + @abstractmethod + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout <= 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (key, events) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + pass + + def get_key(self, fileobj): + """Return the key associated to a registered file object. + + Returns: + SelectorKey for this file object + """ + mapping = self.get_map() + try: + if mapping is None: + raise KeyError + return mapping[fileobj] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) + + @abstractmethod + def get_map(self): + """Return a mapping of file objects to selector keys.""" + raise NotImplementedError + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +class _BaseSelectorImpl(BaseSelector): + """Base selector implementation.""" + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # read-only mapping returned by get_map() + self._map = _SelectorMapping(self) + + def _fileobj_lookup(self, fileobj): + """Return a file descriptor from a file object. 
+ + This wraps _fileobj_to_fd() to do an exhaustive search in case + the object is invalid but we still have it in our map. This + is used by unregister() so we can unregister an object that + was previously registered even if it is closed. It is also + used by _SelectorMapping. + """ + try: + return _fileobj_to_fd(fileobj) + except ValueError: + # Do an exhaustive search. + for key in self._fd_to_key.values(): + if key.fileobj is fileobj: + return key.fd + # Raise ValueError after all. + raise + + def register(self, fileobj, events, data=None): + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE | + EVENT_ERROR)): + raise ValueError("Invalid events: {!r}".format(events)) + + key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{!r} (FD {}) is already registered" + .format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + return key + + def unregister(self, fileobj): + try: + key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fd_to_key[self._fileobj_lookup(fileobj)] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) + if events != key.events: + self.unregister(fileobj) + key = self.register(fileobj, events, data) + elif data != key.data: + # Use a shortcut to update the data. + key = key._replace(data=data) + self._fd_to_key[key.fd] = key + return key + + def close(self): + self._fd_to_key.clear() + self._map = None + + def get_map(self): + return self._map + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. 
+ + Parameters: + fd -- file descriptor + + Returns: + corresponding key, or None if not found + """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + +class SelectSelector(_BaseSelectorImpl): + """Select-based selector.""" + + def __init__(self): + super(self.__class__, self).__init__() + self._readers = set() + self._writers = set() + self._errors = set() + + def register(self, fileobj, events, data=None): + key = super(self.__class__, self).register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + if events & EVENT_ERROR: + self._errors.add(key.fd) + return key + + def unregister(self, fileobj): + key = super(self.__class__, self).unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + self._errors.discard(key.fd) + return key + + if sys.platform == 'win32': + def _select(self, r, w, x, timeout=None): + r, w, x = select.select(r, w, x, timeout) + return r, w, x + else: + _select = select.select + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + ready = [] + try: + r, w, x = self._select(self._readers, self._writers, self._errors, + timeout) + except OSError as e: + if errno_from_exception(e) == errno.EAGAIN: + return ready + r = set(r) + w = set(w) + x = set(x) + for fd in r | w | x: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + if fd in x: + events |= EVENT_ERROR + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'poll'): + + class PollSelector(_BaseSelectorImpl): + """Poll-based selector.""" + + def __init__(self): + super(self.__class__, self).__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super(self.__class__, self).register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + if events & EVENT_ERROR: + poll_events |= select.POLLERR | select.POLLHUP + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super(self.__class__, self).unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # poll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. 
+ timeout = math.ceil(timeout * 1e3) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except OSError as e: + if errno_from_exception(e) == errno.EAGAIN: + return ready + for fd, event in fd_event_list: + events = 0 + if event & select.POLLIN: + events |= EVENT_READ + if event & select.POLLOUT: + events |= EVENT_WRITE + if event & (select.POLLERR | select.POLLHUP): + events |= EVENT_ERROR + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'epoll'): + + class EpollSelector(_BaseSelectorImpl): + """Epoll-based selector.""" + + def __init__(self): + super(self.__class__, self).__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super(self.__class__, self).register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= select.EPOLLIN + if events & EVENT_WRITE: + epoll_events |= select.EPOLLOUT + if events & EVENT_ERROR: + epoll_events |= select.EPOLLERR | select.EPOLLHUP + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super(self.__class__, self).unregister(fileobj) + try: + self._epoll.unregister(key.fd) + except OSError: + # This can happen if the FD was closed since it + # was registered. + pass + return key + + def select(self, timeout=None): + if timeout is None: + timeout = -1 + elif timeout <= 0: + timeout = 0 + else: + # epoll_wait() has a resolution of 1 millisecond, round away + # from zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) * 1e-3 + + # epoll_wait() expects `maxevents` to be greater than zero; + # we want to make sure that `select()` can be called when no + # FD is registered. + max_ev = max(len(self._fd_to_key), 1) + + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except OSError as e: + if errno_from_exception(e) == errno.EAGAIN: + return ready + for fd, event in fd_event_list: + events = 0 + if event & select.EPOLLIN: + events |= EVENT_READ + if event & select.EPOLLOUT: + events |= EVENT_WRITE + if event & (select.EPOLLERR | select.EPOLLHUP): + events |= EVENT_ERROR + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + try: + self._epoll.close() + finally: + super(self.__class__, self).close() + + +if hasattr(select, 'devpoll'): + + class DevpollSelector(_BaseSelectorImpl): + """Solaris /dev/poll selector.""" + + def __init__(self): + super().__init__() + self._devpoll = select.devpoll() + + def fileno(self): + return self._devpoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + if events & EVENT_ERROR: + poll_events |= select.POLLERR | select.POLLHUP + self._devpoll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._devpoll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # devpoll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. 
+ timeout = math.ceil(timeout * 1e3) + ready = [] + try: + fd_event_list = self._devpoll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & select.POLLIN: + events |= EVENT_READ + if event & select.POLLOUT: + events |= EVENT_WRITE + if event & (select.POLLERR | select.POLLHUP): + events |= EVENT_ERROR + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._devpoll.close() + super().close() + + +if hasattr(select, 'kqueue'): + + class KqueueSelector(_BaseSelectorImpl): + """Kqueue-based selector.""" + + def __init__(self): + super(self.__class__, self).__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super(self.__class__, self).register(fileobj, events, data) + if events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def unregister(self, fileobj): + key = super(self.__class__, self).unregister(fileobj) + if key.events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # This can happen if the FD was closed since it + # was registered. + pass + if key.events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # See comment above. + pass + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except OSError as e: + if errno_from_exception(e) == errno.EAGAIN: + return ready + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == select.KQ_FILTER_READ: + events |= EVENT_READ + if flag == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + try: + self._kqueue.close() + finally: + super(self.__class__, self).close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'DevpollSelector' in globals(): + DefaultSelector = DevpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/shadowsocks/server.py b/shadowsocks/server.py old mode 100755 new mode 100644 index 4dc5621..1d99c8d --- a/shadowsocks/server.py +++ b/shadowsocks/server.py @@ -1,143 +1,35 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# -# Copyright 2015 clowwindy -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. 
You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. from __future__ import absolute_import, division, print_function, \ with_statement -import sys import os +import sys import logging -import signal -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../')) -from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, \ - asyncdns, manager +path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, path) +from shadowsocks.eventloop import EventLoop +from shadowsocks.tcprelay import TcpRelay, TcpRelayServerHandler +from shadowsocks.asyncdns import DNSResolver + + +FORMATTER = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' +LOGGING_LEVEL = logging.INFO +logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER) + +LISTEN_ADDR = ('0.0.0.0', 9000) def main(): - shell.check_python() - - config = shell.get_config(False) - - daemon.daemon_exec(config) - - if config['port_password']: - if config['password']: - logging.warn('warning: port_password should not be used with ' - 'server_port and password. server_port and password ' - 'will be ignored') - else: - config['port_password'] = {} - server_port = config['server_port'] - if type(server_port) == list: - for a_server_port in server_port: - config['port_password'][a_server_port] = config['password'] - else: - config['port_password'][str(server_port)] = config['password'] - - if config.get('manager_address', 0): - logging.info('entering manager mode') - manager.run(config) - return - - tcp_servers = [] - udp_servers = [] - - if 'dns_server' in config: # allow override settings in resolv.conf - dns_resolver = asyncdns.DNSResolver(config['dns_server'], - config['prefer_ipv6']) - else: - dns_resolver = asyncdns.DNSResolver(prefer_ipv6=config['prefer_ipv6']) - - port_password = config['port_password'] - del config['port_password'] - for port, password in port_password.items(): - a_config = config.copy() - a_config['server_port'] = int(port) - a_config['password'] = password - logging.info("starting server at %s:%d" % - (a_config['server'], int(port))) - tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False)) - udp_servers.append(udprelay.UDPRelay(a_config, dns_resolver, False)) - - def run_server(): - def child_handler(signum, _): - logging.warn('received SIGQUIT, doing graceful shutting down..') - list(map(lambda s: s.close(next_tick=True), - tcp_servers + udp_servers)) - signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), - child_handler) - - def int_handler(signum, _): - sys.exit(1) - signal.signal(signal.SIGINT, int_handler) - - try: - loop = eventloop.EventLoop() - dns_resolver.add_to_loop(loop) - list(map(lambda s: s.add_to_loop(loop), tcp_servers + udp_servers)) - - daemon.set_user(config.get('user', None)) - loop.run() - except Exception as e: - shell.print_exception(e) - sys.exit(1) - - if int(config['workers']) > 1: - if os.name == 'posix': - children = [] - is_child = False - for i in range(0, int(config['workers'])): - r = os.fork() - if r == 0: - logging.info('worker started') - is_child = True - run_server() - break - else: - children.append(r) - if not is_child: - def handler(signum, _): - for pid in 
children: - try: - os.kill(pid, signum) - os.waitpid(pid, 0) - except OSError: # child may already exited - pass - sys.exit() - signal.signal(signal.SIGTERM, handler) - signal.signal(signal.SIGQUIT, handler) - signal.signal(signal.SIGINT, handler) - - # master - for a_tcp_server in tcp_servers: - a_tcp_server.close() - for a_udp_server in udp_servers: - a_udp_server.close() - dns_resolver.close() - - for child in children: - os.waitpid(child, 0) - else: - logging.warn('worker is only available on Unix/Linux') - run_server() - else: - run_server() - + loop = EventLoop() + dns_resolver = DNSResolver() + relay = TcpRelay(TcpRelayServerHandler, LISTEN_ADDR, + dns_resolver=dns_resolver) + dns_resolver.add_to_loop(loop) + relay.add_to_loop(loop) + loop.run() if __name__ == '__main__': main() diff --git a/shadowsocks/shell.py b/shadowsocks/shell.py index 3c6676f..c91fc22 100644 --- a/shadowsocks/shell.py +++ b/shadowsocks/shell.py @@ -23,10 +23,6 @@ import json import sys import getopt import logging -import traceback - -from functools import wraps - from shadowsocks.common import to_bytes, to_str, IPNetwork from shadowsocks import encrypt @@ -57,49 +53,6 @@ def print_exception(e): traceback.print_exc() -def exception_handle(self_, err_msg=None, exit_code=None, - destroy=False, conn_err=False): - # self_: if function passes self as first arg - - def process_exception(e, self=None): - print_exception(e) - if err_msg: - logging.error(err_msg) - if exit_code: - sys.exit(1) - - if not self_: - return - - if conn_err: - addr, port = self._client_address[0], self._client_address[1] - logging.error('%s when handling connection from %s:%d' % - (e, addr, port)) - if self._config['verbose']: - traceback.print_exc() - if destroy: - self.destroy() - - def decorator(func): - if self_: - @wraps(func) - def wrapper(self, *args, **kwargs): - try: - func(self, *args, **kwargs) - except Exception as e: - process_exception(e, self) - else: - @wraps(func) - def wrapper(*args, **kwargs): - try: - func(*args, **kwargs) - except Exception as e: - process_exception(e) - - return wrapper - return decorator - - def print_shadowsocks(): version = '' try: @@ -125,30 +78,13 @@ def check_config(config, is_local): # no need to specify configuration for daemon stop return - if is_local: - if config.get('server', None) is None: - logging.error('server addr not specified') - print_local_help() - sys.exit(2) - else: - config['server'] = to_str(config['server']) - else: - config['server'] = to_str(config.get('server', '0.0.0.0')) - try: - config['forbidden_ip'] = \ - IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128')) - except Exception as e: - logging.error(e) - sys.exit(2) - if is_local and not config.get('password', None): logging.error('password not specified') print_help(is_local) sys.exit(2) if not is_local and not config.get('password', None) \ - and not config.get('port_password', None) \ - and not config.get('manager_address'): + and not config.get('port_password', None): logging.error('password or port_password not specified') print_help(is_local) sys.exit(2) @@ -194,14 +130,13 @@ def get_config(is_local): logging.basicConfig(level=logging.INFO, format='%(levelname)-s: %(message)s') if is_local: - shortopts = 'hd:s:b:p:k:l:m:c:t:vqa' + shortopts = 'hd:s:b:p:k:l:m:c:t:vq' longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'user=', 'version'] else: - shortopts = 'hd:s:p:k:m:c:t:vqa' + shortopts = 'hd:s:p:k:m:c:t:vq' longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=', - 'forbidden-ip=', 
'user=', 'manager-address=', 'version', - 'prefer-ipv6'] + 'forbidden-ip=', 'user=', 'manager-address=', 'version'] try: config_path = find_config() optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts) @@ -239,8 +174,6 @@ def get_config(is_local): v_count += 1 # '-vv' turns on more verbose mode config['verbose'] = v_count - elif key == '-a': - config['one_time_auth'] = True elif key == '-t': config['timeout'] = int(value) elif key == '--fast-open': @@ -271,8 +204,6 @@ def get_config(is_local): elif key == '-q': v_count -= 1 config['verbose'] = v_count - elif key == '--prefer-ipv6': - config['prefer_ipv6'] = True except getopt.GetoptError as e: print(e, file=sys.stderr) print_help(is_local) @@ -294,8 +225,21 @@ def get_config(is_local): config['verbose'] = config.get('verbose', False) config['local_address'] = to_str(config.get('local_address', '127.0.0.1')) config['local_port'] = config.get('local_port', 1080) - config['one_time_auth'] = config.get('one_time_auth', False) - config['prefer_ipv6'] = config.get('prefer_ipv6', False) + if is_local: + if config.get('server', None) is None: + logging.error('server addr not specified') + print_local_help() + sys.exit(2) + else: + config['server'] = to_str(config['server']) + else: + config['server'] = to_str(config.get('server', '0.0.0.0')) + try: + config['forbidden_ip'] = \ + IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128')) + except Exception as e: + logging.error(e) + sys.exit(2) config['server_port'] = config.get('server_port', 8388) logging.getLogger('').handlers = [] @@ -342,7 +286,6 @@ Proxy options: -k PASSWORD password -m METHOD encryption method, default: aes-256-cfb -t TIMEOUT timeout in seconds, default: 300 - -a ONE_TIME_AUTH one time auth --fast-open use TCP_FASTOPEN, requires Linux 3.7+ General options: @@ -372,12 +315,10 @@ Proxy options: -k PASSWORD password -m METHOD encryption method, default: aes-256-cfb -t TIMEOUT timeout in seconds, default: 300 - -a ONE_TIME_AUTH one time auth --fast-open use TCP_FASTOPEN, requires Linux 3.7+ --workers WORKERS number of workers, available on Unix/Linux --forbidden-ip IPLIST comma seperated IP list forbidden to connect --manager-address ADDR optional server manager UDP address, see wiki - --prefer-ipv6 resolve ipv6 address first General options: -h, --help show this help message and exit diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 207407a..25623ec 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -1,856 +1,387 @@ -#!/usr/bin/python +#!/usr/bin/env python # -*- coding: utf-8 -*- -# -# Copyright 2015 clowwindy -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. 
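+# Rewritten relay for the simplified client/server above: create_sock() builds
+# a non-blocking TCP_NODELAY socket, and TcpRelayHanler keeps one local and one
+# remote socket per connection, buffers pending writes for each direction, and
+# routes loop events through handle_event(), which closes the pair on
+# EVENT_ERROR. The cipher is a hardcoded aes-256-cfb Encryptor in this learning
+# version instead of being taken from a config.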
from __future__ import absolute_import, division, print_function, \ with_statement -import time -import socket import errno -import struct import logging -import traceback -import random +import socket -from shadowsocks import encrypt, eventloop, shell, common -from shadowsocks.common import parse_header, onetimeauth_verify, \ - onetimeauth_gen, ONETIMEAUTH_BYTES, ONETIMEAUTH_CHUNK_BYTES, \ - ONETIMEAUTH_CHUNK_DATA_LEN, ADDRTYPE_AUTH +from shadowsocks.selectors import (EVENT_READ, EVENT_WRITE, EVENT_ERROR, + errno_from_exception, get_sock_error) +from shadowsocks.common import parse_header, to_str +from shadowsocks import encrypt -# we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time -TIMEOUTS_CLEAN_SIZE = 512 - -MSG_FASTOPEN = 0x20000000 - -# SOCKS METHOD definition -METHOD_NOAUTH = 0 - -# SOCKS command definition -CMD_CONNECT = 1 -CMD_BIND = 2 -CMD_UDP_ASSOCIATE = 3 - -# for each opening port, we have a TCP Relay - -# for each connection, we have a TCP Relay Handler to handle the connection - -# for each handler, we have 2 sockets: -# local: connected to the client -# remote: connected to remote server - -# for each handler, it could be at one of several stages: - -# as sslocal: -# stage 0 auth METHOD received from local, reply with selection message -# stage 1 addr received from local, query DNS for remote -# stage 2 UDP assoc -# stage 3 DNS resolved, connect to remote -# stage 4 still connecting, more data from local received -# stage 5 remote connected, piping local and remote - -# as ssserver: -# stage 0 just jump to stage 1 -# stage 1 addr received from local, query DNS for remote -# stage 3 DNS resolved, connect to remote -# stage 4 still connecting, more data from local received -# stage 5 remote connected, piping local and remote - -STAGE_INIT = 0 -STAGE_ADDR = 1 -STAGE_UDP_ASSOC = 2 -STAGE_DNS = 3 -STAGE_CONNECTING = 4 -STAGE_STREAM = 5 -STAGE_DESTROYED = -1 - -# for each handler, we have 2 stream directions: -# upstream: from client to server direction -# read local and write to remote -# downstream: from server to client direction -# read remote and write to local - -STREAM_UP = 0 -STREAM_DOWN = 1 - -# for each stream, it's waiting for reading, or writing, or both -WAIT_STATUS_INIT = 0 -WAIT_STATUS_READING = 1 -WAIT_STATUS_WRITING = 2 -WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING BUF_SIZE = 32 * 1024 +CMD_CONNECT = 1 -# helper exceptions for TCPRelayHandler +def create_sock(ip, port): + addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, + socket.SOL_TCP) + if len(addrs) == 0: + raise Exception("Getaddrinfo failed for %s:%d" % (ip, port)) -class BadSocksHeader(Exception): - pass + af, socktype, proto, canonname, sa = addrs[0] + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) + return sock -class NoAcceptableMethods(Exception): - pass +class TcpRelayHanler(object): - -class TCPRelayHandler(object): - def __init__(self, server, fd_to_handlers, loop, local_sock, config, - dns_resolver, is_local): - self._server = server - self._fd_to_handlers = fd_to_handlers - self._loop = loop + def __init__(self, local_sock, local_addr, remote_addr=None, + dns_resolver=None): + self._loop = None self._local_sock = local_sock - self._remote_sock = None - self._config = config + self._local_addr = local_addr + self._remote_addr = remote_addr + # self._crypt = None + self._crypt = encrypt.Encryptor(b'PassThrouthGFW', 'aes-256-cfb') self._dns_resolver = dns_resolver - - # TCP Relay works 
as either sslocal or ssserver - # if is_local, this is sslocal - self._is_local = is_local - self._stage = STAGE_INIT - self._encryptor = encrypt.Encryptor(config['password'], - config['method']) - self._ota_enable = config.get('one_time_auth', False) - self._ota_enable_session = self._ota_enable - self._ota_buff_head = b'' - self._ota_buff_data = b'' - self._ota_len = 0 - self._ota_chunk_idx = 0 - self._fastopen_connected = False + self._remote_sock = None + self._local_sock_mode = 0 + self._remote_sock_mode = 0 self._data_to_write_to_local = [] self._data_to_write_to_remote = [] - self._upstream_status = WAIT_STATUS_READING - self._downstream_status = WAIT_STATUS_INIT - self._client_address = local_sock.getpeername()[:2] - self._remote_address = None - self._forbidden_iplist = config.get('forbidden_ip') - if is_local: - self._chosen_server = self._get_a_server() - fd_to_handlers[local_sock.fileno()] = self - local_sock.setblocking(False) - local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR, - self._server) - self.last_activity = 0 - self._update_activity() + self._id = id(self) - def __hash__(self): - # default __hash__ is id / 16 - # we want to eliminate collisions - return id(self) - - @property - def remote_address(self): - return self._remote_address - - def _get_a_server(self): - server = self._config['server'] - server_port = self._config['server_port'] - if type(server_port) == list: - server_port = random.choice(server_port) - if type(server) == list: - server = random.choice(server) - logging.debug('chosen server: %s:%d', server, server_port) - return server, server_port - - def _update_activity(self, data_len=0): - # tell the TCP Relay we have activities recently - # else it will think we are inactive and timed out - self._server.update_activity(self, data_len) - - def _update_stream(self, stream, status): - # update a stream to a new waiting status - - # check if status is changed - # only update if dirty - dirty = False - if stream == STREAM_DOWN: - if self._downstream_status != status: - self._downstream_status = status - dirty = True - elif stream == STREAM_UP: - if self._upstream_status != status: - self._upstream_status = status - dirty = True - if not dirty: - return - - if self._local_sock: - event = eventloop.POLL_ERR - if self._downstream_status & WAIT_STATUS_WRITING: - event |= eventloop.POLL_OUT - if self._upstream_status & WAIT_STATUS_READING: - event |= eventloop.POLL_IN - self._loop.modify(self._local_sock, event) - if self._remote_sock: - event = eventloop.POLL_ERR - if self._downstream_status & WAIT_STATUS_READING: - event |= eventloop.POLL_IN - if self._upstream_status & WAIT_STATUS_WRITING: - event |= eventloop.POLL_OUT - self._loop.modify(self._remote_sock, event) - - def _write_to_sock(self, data, sock): - # write data to sock - # if only some of the data are written, put remaining in the buffer - # and update the stream to wait for writing - if not data or not sock: - return False - uncomplete = False - try: - l = len(data) - s = sock.send(data) - if s < l: - data = data[s:] - uncomplete = True - except (OSError, IOError) as e: - error_no = eventloop.errno_from_exception(e) - if error_no in (errno.EAGAIN, errno.EINPROGRESS, - errno.EWOULDBLOCK): - uncomplete = True - else: - shell.print_exception(e) - self.destroy() - return False - if uncomplete: - if sock == self._local_sock: - self._data_to_write_to_local.append(data) - self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) - elif 
sock == self._remote_sock: - self._data_to_write_to_remote.append(data) - self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) - else: - logging.error('write_all_to_sock:unknown socket') + def handle_event(self, sock, event, call, *args): + if event & EVENT_ERROR: + logging.error(get_sock_error(sock)) + self.close() else: - if sock == self._local_sock: - self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) - elif sock == self._remote_sock: - self._update_stream(STREAM_UP, WAIT_STATUS_READING) - else: - logging.error('write_all_to_sock:unknown socket') - return True - - def _handle_stage_connecting(self, data): - if not self._is_local: - if self._ota_enable_session: - self._ota_chunk_data(data, - self._data_to_write_to_remote.append) - else: - self._data_to_write_to_remote.append(data) - return - - if self._ota_enable_session: - data = self._ota_chunk_data_gen(data) - data = self._encryptor.encrypt(data) - self._data_to_write_to_remote.append(data) - - if self._config['fast_open'] and not self._fastopen_connected: - # for sslocal and fastopen, we basically wait for data and use - # sendto to connect try: - # only connect once - self._fastopen_connected = True - remote_sock = \ - self._create_remote_socket(self._chosen_server[0], - self._chosen_server[1]) - self._loop.add(remote_sock, eventloop.POLL_ERR, self._server) - data = b''.join(self._data_to_write_to_remote) - l = len(data) - s = remote_sock.sendto(data, MSG_FASTOPEN, - self._chosen_server) - if s < l: - data = data[s:] - self._data_to_write_to_remote = [data] - else: - self._data_to_write_to_remote = [] - self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) - except (OSError, IOError) as e: - if eventloop.errno_from_exception(e) == errno.EINPROGRESS: - # in this case data is not sent at all - self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) - elif eventloop.errno_from_exception(e) == errno.ENOTCONN: - logging.error('fast open not supported on this OS') - self._config['fast_open'] = False - self.destroy() - else: - shell.print_exception(e) - if self._config['verbose']: - traceback.print_exc() - self.destroy() + call(sock, event, *args) + except Exception as e: + logging.error(e) + self.close() - @shell.exception_handle(self_=True, destroy=True, conn_err=True) - def _handle_stage_addr(self, data): - if self._is_local: - cmd = common.ord(data[1]) - if cmd == CMD_UDP_ASSOCIATE: - logging.debug('UDP associate') - if self._local_sock.family == socket.AF_INET6: - header = b'\x05\x00\x00\x04' - else: - header = b'\x05\x00\x00\x01' - addr, port = self._local_sock.getsockname()[:2] - addr_to_send = socket.inet_pton(self._local_sock.family, - addr) - port_to_send = struct.pack('>H', port) - self._write_to_sock(header + addr_to_send + port_to_send, - self._local_sock) - self._stage = STAGE_UDP_ASSOC - # just wait for the client to disconnect - return - elif cmd == CMD_CONNECT: - # just trim VER CMD RSV - data = data[3:] - else: - logging.error('unknown command %d', cmd) - self.destroy() - return - header_result = parse_header(data) - if header_result is None: - raise Exception('can not parse header') - addrtype, remote_addr, remote_port, header_length = header_result - logging.info('connecting %s:%d from %s:%d' % - (common.to_str(remote_addr), remote_port, - self._client_address[0], self._client_address[1])) - if self._is_local is False: - # spec https://shadowsocks.org/en/spec/one-time-auth.html - self._ota_enable_session = addrtype & ADDRTYPE_AUTH - if self._ota_enable and not self._ota_enable_session: - logging.warn('client one 
time auth is required') - return - if self._ota_enable_session: - if len(data) < header_length + ONETIMEAUTH_BYTES: - logging.warn('one time auth header is too short') - return None - offset = header_length + ONETIMEAUTH_BYTES - _hash = data[header_length: offset] - _data = data[:header_length] - key = self._encryptor.decipher_iv + self._encryptor.key - if onetimeauth_verify(_hash, _data, key) is False: - logging.warn('one time auth fail') - self.destroy() - return - header_length += ONETIMEAUTH_BYTES - self._remote_address = (common.to_str(remote_addr), remote_port) - # pause reading - self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) - self._stage = STAGE_DNS - if self._is_local: - # forward address to remote - self._write_to_sock((b'\x05\x00\x00\x01' - b'\x00\x00\x00\x00\x10\x10'), - self._local_sock) - # spec https://shadowsocks.org/en/spec/one-time-auth.html - # ATYP & 0x10 == 0x10, then OTA is enabled. - if self._ota_enable_session: - data = common.chr(addrtype | ADDRTYPE_AUTH) + data[1:] - key = self._encryptor.cipher_iv + self._encryptor.key - _header = data[:header_length] - sha110 = onetimeauth_gen(data, key) - data = _header + sha110 + data[header_length:] - data_to_send = self._encryptor.encrypt(data) - self._data_to_write_to_remote.append(data_to_send) - # notice here may go into _handle_dns_resolved directly - self._dns_resolver.resolve(self._chosen_server[0], - self._handle_dns_resolved) - else: - if self._ota_enable_session: - data = data[header_length:] - self._ota_chunk_data(data, - self._data_to_write_to_remote.append) - elif len(data) > header_length: - self._data_to_write_to_remote.append(data[header_length:]) - # notice here may go into _handle_dns_resolved directly - self._dns_resolver.resolve(remote_addr, - self._handle_dns_resolved) - - def _create_remote_socket(self, ip, port): - addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, - socket.SOL_TCP) - if len(addrs) == 0: - raise Exception("getaddrinfo failed for %s:%d" % (ip, port)) - af, socktype, proto, canonname, sa = addrs[0] - if self._forbidden_iplist: - if common.to_str(sa[0]) in self._forbidden_iplist: - raise Exception('IP %s is in forbidden list, reject' % - common.to_str(sa[0])) - remote_sock = socket.socket(af, socktype, proto) - self._remote_sock = remote_sock - self._fd_to_handlers[remote_sock.fileno()] = self - remote_sock.setblocking(False) - remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) - return remote_sock - - @shell.exception_handle(self_=True) - def _handle_dns_resolved(self, result, error): - if error: - addr, port = self._client_address[0], self._client_address[1] - logging.error('%s when handling connection from %s:%d' % - (error, addr, port)) - self.destroy() - return - if not (result and result[1]): - self.destroy() - return - - ip = result[1] - self._stage = STAGE_CONNECTING - remote_addr = ip - if self._is_local: - remote_port = self._chosen_server[1] - else: - remote_port = self._remote_address[1] - - if self._is_local and self._config['fast_open']: - # for fastopen: - # wait for more data arrive and send them in one SYN - self._stage = STAGE_CONNECTING - # we don't have to wait for remote since it's not - # created - self._update_stream(STREAM_UP, WAIT_STATUS_READING) - # TODO when there is already data in this packet - else: - # else do connect - remote_sock = self._create_remote_socket(remote_addr, - remote_port) - try: - remote_sock.connect((remote_addr, remote_port)) - except (OSError, IOError) as e: - if eventloop.errno_from_exception(e) == \ - 
errno.EINPROGRESS: - pass - self._loop.add(remote_sock, - eventloop.POLL_ERR | eventloop.POLL_OUT, - self._server) - self._stage = STAGE_CONNECTING - self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) - self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) - - def _write_to_sock_remote(self, data): - self._write_to_sock(data, self._remote_sock) - - def _ota_chunk_data(self, data, data_cb): - # spec https://shadowsocks.org/en/spec/one-time-auth.html - unchunk_data = b'' - while len(data) > 0: - if self._ota_len == 0: - # get DATA.LEN + HMAC-SHA1 - length = ONETIMEAUTH_CHUNK_BYTES - len(self._ota_buff_head) - self._ota_buff_head += data[:length] - data = data[length:] - if len(self._ota_buff_head) < ONETIMEAUTH_CHUNK_BYTES: - # wait more data - return - data_len = self._ota_buff_head[:ONETIMEAUTH_CHUNK_DATA_LEN] - self._ota_len = struct.unpack('>H', data_len)[0] - length = min(self._ota_len - len(self._ota_buff_data), len(data)) - self._ota_buff_data += data[:length] - data = data[length:] - if len(self._ota_buff_data) == self._ota_len: - # get a chunk data - _hash = self._ota_buff_head[ONETIMEAUTH_CHUNK_DATA_LEN:] - _data = self._ota_buff_data - index = struct.pack('>I', self._ota_chunk_idx) - key = self._encryptor.decipher_iv + index - if onetimeauth_verify(_hash, _data, key) is False: - logging.warn('one time auth fail, drop chunk !') - else: - unchunk_data += _data - self._ota_chunk_idx += 1 - self._ota_buff_head = b'' - self._ota_buff_data = b'' - self._ota_len = 0 - data_cb(unchunk_data) - return - - def _ota_chunk_data_gen(self, data): - data_len = struct.pack(">H", len(data)) - index = struct.pack('>I', self._ota_chunk_idx) - key = self._encryptor.cipher_iv + index - sha110 = onetimeauth_gen(data, key) - self._ota_chunk_idx += 1 - return data_len + sha110 + data - - def _handle_stage_stream(self, data): - if self._is_local: - if self._ota_enable_session: - data = self._ota_chunk_data_gen(data) - data = self._encryptor.encrypt(data) - self._write_to_sock(data, self._remote_sock) - else: - if self._ota_enable_session: - self._ota_chunk_data(data, self._write_to_sock_remote) - else: - self._write_to_sock(data, self._remote_sock) - return - - def _check_auth_method(self, data): - # VER, NMETHODS, and at least 1 METHODS - if len(data) < 3: - logging.warning('method selection header too short') - raise BadSocksHeader - socks_version = common.ord(data[0]) - nmethods = common.ord(data[1]) - if socks_version != 5: - logging.warning('unsupported SOCKS protocol version ' + - str(socks_version)) - raise BadSocksHeader - if nmethods < 1 or len(data) != nmethods + 2: - logging.warning('NMETHODS and number of METHODS mismatch') - raise BadSocksHeader - noauth_exist = False - for method in data[2:]: - if common.ord(method) == METHOD_NOAUTH: - noauth_exist = True - break - if not noauth_exist: - logging.warning('none of SOCKS METHOD\'s ' - 'requested by client is supported') - raise NoAcceptableMethods - - def _handle_stage_init(self, data): - try: - self._check_auth_method(data) - except BadSocksHeader: - self.destroy() - return - except NoAcceptableMethods: - self._write_to_sock(b'\x05\xff', self._local_sock) - self.destroy() - return - - self._write_to_sock(b'\x05\00', self._local_sock) - self._stage = STAGE_ADDR - - def _on_local_read(self): - # handle all local read events and dispatch them to methods for - # each stage - if not self._local_sock: - return - is_local = self._is_local - data = None - try: - data = self._local_sock.recv(BUF_SIZE) - except (OSError, IOError) as e: - if 
eventloop.errno_from_exception(e) in \ - (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK): - return - if not data: - self.destroy() - return - self._update_activity(len(data)) - if not is_local: - data = self._encryptor.decrypt(data) - if not data: - return - if self._stage == STAGE_STREAM: - self._handle_stage_stream(data) - return - elif is_local and self._stage == STAGE_INIT: - self._handle_stage_init(data) - elif self._stage == STAGE_CONNECTING: - self._handle_stage_connecting(data) - elif (is_local and self._stage == STAGE_ADDR) or \ - (not is_local and self._stage == STAGE_INIT): - self._handle_stage_addr(data) - - def _on_remote_read(self): - # handle all remote read events - data = None - try: - data = self._remote_sock.recv(BUF_SIZE) - - except (OSError, IOError) as e: - if eventloop.errno_from_exception(e) in \ - (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK): - return - if not data: - self.destroy() - return - self._update_activity(len(data)) - if self._is_local: - data = self._encryptor.decrypt(data) - else: - data = self._encryptor.encrypt(data) - try: - self._write_to_sock(data, self._local_sock) - except Exception as e: - shell.print_exception(e) - if self._config['verbose']: - traceback.print_exc() - # TODO use logging when debug completed - self.destroy() - - def _on_local_write(self): - # handle local writable event - if self._data_to_write_to_local: - data = b''.join(self._data_to_write_to_local) - self._data_to_write_to_local = [] - self._write_to_sock(data, self._local_sock) - else: - self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) - - def _on_remote_write(self): - # handle remote writable event - self._stage = STAGE_STREAM - if self._data_to_write_to_remote: - data = b''.join(self._data_to_write_to_remote) - self._data_to_write_to_remote = [] - self._write_to_sock(data, self._remote_sock) - else: - self._update_stream(STREAM_UP, WAIT_STATUS_READING) - - def _on_local_error(self): - logging.debug('got local error') - if self._local_sock: - logging.error(eventloop.get_sock_error(self._local_sock)) - self.destroy() - - def _on_remote_error(self): - logging.debug('got remote error') - if self._remote_sock: - logging.error(eventloop.get_sock_error(self._remote_sock)) - self.destroy() - - def handle_event(self, sock, event): - # handle all events in this handler and dispatch them to methods - if self._stage == STAGE_DESTROYED: - logging.debug('ignore handle_event: destroyed') - return - # order is important - if sock == self._remote_sock: - if event & eventloop.POLL_ERR: - self._on_remote_error() - if self._stage == STAGE_DESTROYED: - return - if event & (eventloop.POLL_IN | eventloop.POLL_HUP): - self._on_remote_read() - if self._stage == STAGE_DESTROYED: - return - if event & eventloop.POLL_OUT: - self._on_remote_write() - elif sock == self._local_sock: - if event & eventloop.POLL_ERR: - self._on_local_error() - if self._stage == STAGE_DESTROYED: - return - if event & (eventloop.POLL_IN | eventloop.POLL_HUP): - self._on_local_read() - if self._stage == STAGE_DESTROYED: - return - if event & eventloop.POLL_OUT: - self._on_local_write() - else: - logging.warn('unknown socket') - - def destroy(self): - # destroy the handler and release any resources - # promises: - # 1. destroy won't make another destroy() call inside - # 2. destroy releases resources so it prevents future call to destroy - # 3. destroy won't raise any exceptions - # if any of the promises are broken, it indicates a bug has been - # introduced! 
mostly likely memory leaks, etc - if self._stage == STAGE_DESTROYED: - # this couldn't happen - logging.debug('already destroyed') - return - self._stage = STAGE_DESTROYED - if self._remote_address: - logging.debug('destroy: %s:%d' % - self._remote_address) - else: - logging.debug('destroy') - if self._remote_sock: - logging.debug('destroying remote') - self._loop.remove(self._remote_sock) - del self._fd_to_handlers[self._remote_sock.fileno()] - self._remote_sock.close() - self._remote_sock = None - if self._local_sock: - logging.debug('destroying local') - self._loop.remove(self._local_sock) - del self._fd_to_handlers[self._local_sock.fileno()] - self._local_sock.close() - self._local_sock = None - self._dns_resolver.remove_callback(self._handle_dns_resolved) - self._server.remove_handler(self) - - -class TCPRelay(object): - def __init__(self, config, dns_resolver, is_local, stat_callback=None): - self._config = config - self._is_local = is_local - self._dns_resolver = dns_resolver - self._closed = False - self._eventloop = None - self._fd_to_handlers = {} - - self._timeout = config['timeout'] - self._timeouts = [] # a list for all the handlers - # we trim the timeouts once a while - self._timeout_offset = 0 # last checked position for timeout - self._handler_to_timeouts = {} # key: handler value: index in timeouts - - if is_local: - listen_addr = config['local_address'] - listen_port = config['local_port'] - else: - listen_addr = config['server'] - listen_port = config['server_port'] - self._listen_port = listen_port - - addrs = socket.getaddrinfo(listen_addr, listen_port, 0, - socket.SOCK_STREAM, socket.SOL_TCP) - if len(addrs) == 0: - raise Exception("can't get addrinfo for %s:%d" % - (listen_addr, listen_port)) - af, socktype, proto, canonname, sa = addrs[0] - server_socket = socket.socket(af, socktype, proto) - server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_socket.bind(sa) - server_socket.setblocking(False) - if config['fast_open']: - try: - server_socket.setsockopt(socket.SOL_TCP, 23, 5) - except socket.error: - logging.error('warning: fast open is not available') - self._config['fast_open'] = False - server_socket.listen(1024) - self._server_socket = server_socket - self._stat_callback = stat_callback + def __del__(self): + logging.debug('Deleting {}'.format(self._id)) def add_to_loop(self, loop): - if self._eventloop: - raise Exception('already add to loop') - if self._closed: - raise Exception('already closed') - self._eventloop = loop - self._eventloop.add(self._server_socket, - eventloop.POLL_IN | eventloop.POLL_ERR, self) - self._eventloop.add_periodic(self.handle_periodic) + if self._loop: + raise Exception('Already added to loop') + self._loop = loop + loop.add(self._local_sock, EVENT_READ, (self, self.start)) - def remove_handler(self, handler): - index = self._handler_to_timeouts.get(hash(handler), -1) - if index >= 0: - # delete is O(n), so we just set it to None - self._timeouts[index] = None - del self._handler_to_timeouts[hash(handler)] + def modify_local_sock_mode(self, event): + if self._local_sock_mode != event: + self._local_sock_mode = self.modify_sock_mode(self._local_sock, + event) - def update_activity(self, handler, data_len): - if data_len and self._stat_callback: - self._stat_callback(self._listen_port, data_len) + def modify_remote_sock_mode(self, event): + if self._remote_sock_mode != event: + self._remote_sock_mode = self.modify_sock_mode(self._remote_sock, + event) - # set handler to active - now = int(time.time()) - if now 
- handler.last_activity < eventloop.TIMEOUT_PRECISION: - # thus we can lower timeout modification frequency - return - handler.last_activity = now - index = self._handler_to_timeouts.get(hash(handler), -1) - if index >= 0: - # delete is O(n), so we just set it to None - self._timeouts[index] = None - length = len(self._timeouts) - self._timeouts.append(handler) - self._handler_to_timeouts[hash(handler)] = length + def modify_sock_mode(self, sock, event): + key = self._loop.modify(sock, event, (self, self.stream)) + return key.events - def _sweep_timeout(self): - # tornado's timeout memory management is more flexible than we need - # we just need a sorted last_activity queue and it's faster than heapq - # in fact we can do O(1) insertion/remove so we invent our own - if self._timeouts: - logging.log(shell.VERBOSE_LEVEL, 'sweeping timeouts') - now = time.time() - length = len(self._timeouts) - pos = self._timeout_offset - while pos < length: - handler = self._timeouts[pos] - if handler: - if now - handler.last_activity < self._timeout: - break - else: - if handler.remote_address: - logging.warn('timed out: %s:%d' % - handler.remote_address) - else: - logging.warn('timed out') - handler.destroy() - self._timeouts[pos] = None # free memory - pos += 1 - else: - pos += 1 - if pos > TIMEOUTS_CLEAN_SIZE and pos > length >> 1: - # clean up the timeout queue when it gets larger than half - # of the queue - self._timeouts = self._timeouts[pos:] - for key in self._handler_to_timeouts: - self._handler_to_timeouts[key] -= pos - pos = 0 - self._timeout_offset = pos + def close_sock(self, sock): + self._loop.remove(sock) + sock.close() - def handle_event(self, sock, fd, event): - # handle events and dispatch to handlers - if sock: - logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd, - eventloop.EVENT_NAMES.get(event, event)) - if sock == self._server_socket: - if event & eventloop.POLL_ERR: - # TODO - raise Exception('server_socket error') + def close(self): + if self._local_sock: + self.close_sock(self._local_sock) + self._local_sock = None + if self._remote_sock: + self.close_sock(self._remote_sock) + self._remote_sock = None + + def sock_connect(self, sock, addr): + while True: try: - logging.debug('accept') - conn = self._server_socket.accept() - TCPRelayHandler(self, self._fd_to_handlers, - self._eventloop, conn[0], self._config, - self._dns_resolver, self._is_local) + sock.connect(addr) except (OSError, IOError) as e: - error_no = eventloop.errno_from_exception(e) - if error_no in (errno.EAGAIN, errno.EINPROGRESS, - errno.EWOULDBLOCK): - return + err = errno_from_exception(e) + if err == errno.EINTR: + pass + elif err == errno.EINPROGRESS: + break else: - shell.print_exception(e) - if self._config['verbose']: - traceback.print_exc() - else: - if sock: - handler = self._fd_to_handlers.get(fd, None) - if handler: - handler.handle_event(sock, event) + raise else: - logging.warn('poll removed fd') + break - def handle_periodic(self): - if self._closed: - if self._server_socket: - self._eventloop.remove(self._server_socket) - self._server_socket.close() - self._server_socket = None - logging.info('closed TCP port %d', self._listen_port) - if not self._fd_to_handlers: - logging.info('stopping') - self._eventloop.stop() - self._sweep_timeout() + def sock_recv(self, sock, size=BUF_SIZE): + try: + data = sock.recv(size) + if not data: + self.close() + except (OSError, IOError) as e: + if errno_from_exception(e) in (errno.EAGAIN, errno.EWOULDBLOCK, + errno.EINTR): + return + else: + raise + return data - def 
close(self, next_tick=False):
-        logging.debug('TCP close')
-        self._closed = True
-        if not next_tick:
-            if self._eventloop:
-                self._eventloop.remove_periodic(self.handle_periodic)
-                self._eventloop.remove(self._server_socket)
-            self._server_socket.close()
-            for handler in list(self._fd_to_handlers.values()):
-                handler.destroy()
+    def sock_send(self, sock, data):
+        try:
+            s = sock.send(data)
+            data = data[s:]
+        except (OSError, IOError) as e:
+            if errno_from_exception(e) in (errno.EAGAIN, errno.EWOULDBLOCK,
+                                           errno.EINPROGRESS, errno.EINTR):
+                pass
+            else:
+                raise
+
+        return data
+
+    def on_local_read(self, size=BUF_SIZE):
+        logging.debug('on_local_read')
+        if not self._local_sock:
+            return
+
+        data = self.sock_recv(self._local_sock, size)
+        if not data:
+            return
+
+        logging.debug('Received {} bytes from {}:{}'.format(len(data),
+                      *self._local_addr))
+        if self._crypt:
+            if self._is_client:
+                data = self._crypt.encrypt(data)
+            else:
+                data = self._crypt.decrypt(data)
+
+        if data:
+            self._data_to_write_to_remote.append(data)
+            self.on_remote_write()
+
+    def on_remote_read(self, size=BUF_SIZE):
+        logging.debug('on_remote_read')
+        if not self._remote_sock:
+            return
+
+        data = self.sock_recv(self._remote_sock, size)
+        if not data:
+            return
+
+        logging.debug('Received {} bytes from {}:{}'.format(
+            len(data), *self._remote_addr))
+
+        if self._crypt:
+            if self._is_client:
+                data = self._crypt.decrypt(data)
+            else:
+                data = self._crypt.encrypt(data)
+
+        if data:
+            self._data_to_write_to_local.append(data)
+            self.on_local_write()
+
+    def on_local_write(self):
+        logging.debug('on_local_write')
+        if not self._local_sock:
+            return
+
+        if not self._data_to_write_to_local:
+            self.modify_local_sock_mode(EVENT_READ)
+            return
+
+        data = b''.join(self._data_to_write_to_local)
+        self._data_to_write_to_local = []
+
+        data = self.sock_send(self._local_sock, data)
+
+        if data:
+            self._data_to_write_to_local.append(data)
+            self.modify_local_sock_mode(EVENT_WRITE)
+        else:
+            self.modify_local_sock_mode(EVENT_READ)
+
+    def on_remote_write(self):
+        logging.debug('on_remote_write')
+        if not self._remote_sock:
+            return
+
+        if not self._data_to_write_to_remote:
+            self.modify_remote_sock_mode(EVENT_READ)
+            return
+
+        data = b''.join(self._data_to_write_to_remote)
+        self._data_to_write_to_remote = []
+
+        data = self.sock_send(self._remote_sock, data)
+
+        if data:
+            self._data_to_write_to_remote.append(data)
+            self.modify_remote_sock_mode(EVENT_WRITE)
+        else:
+            self.modify_remote_sock_mode(EVENT_READ)
+
+    def stream(self, sock, event):
+        logging.debug('stream')
+
+        if sock == self._local_sock:
+            if event & EVENT_READ:
+                self.on_local_read()
+            if event & EVENT_WRITE:
+                self.on_local_write()
+        elif sock == self._remote_sock:
+            if event & EVENT_READ:
+                self.on_remote_read()
+            if event & EVENT_WRITE:
+                self.on_remote_write()
+        else:
+            logging.warn('Unknown sock {}'.format(sock))
+
+
+class TcpRelayClientHanler(TcpRelayHanler):
+
+    _is_client = True
+
+    def start(self, sock, event):
+        data = self.sock_recv(sock)
+        if not data:
+            return
+        reply = b'\x05\x00'
+        self.send_reply(sock, None, reply)
+
+    def send_reply(self, sock, event, data):
+        data = self.sock_send(sock, data)
+        if data:
+            self._loop.modify(sock, EVENT_WRITE, (self, self.send_reply, data))
+        else:
+            self._loop.modify(sock, EVENT_READ, (self, self.handle_addr))
+
+    def handle_addr(self, sock, event):
+        data = self.sock_recv(sock)
+        if not data:
+            return
+        # self._loop.remove(sock)
+
+        if ord(data[1:2]) != CMD_CONNECT:
+            raise Exception('Command not supported')
+
+        result = parse_header(data[3:])
+        if not result:
+            raise Exception('Header cannot be parsed')
+
+        self._remote_sock = create_sock(*self._remote_addr)
+        self.sock_connect(self._remote_sock, self._remote_addr)
+
+        dest_addr = (to_str(result[1]), result[2])
+        logging.info('Connecting to {}:{}'.format(*dest_addr))
+        data = '{}:{}\n'.format(*dest_addr).encode('utf-8')
+        if self._crypt:
+            data = self._crypt.encrypt(data)
+        self._data_to_write_to_remote.append(data)
+
+        bind_addr = b'\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00'
+        self.send_bind_addr(sock, None, bind_addr)
+
+    def send_bind_addr(self, sock, event, data):
+        data = self.sock_send(sock, data)
+        if data:
+            self._loop.modify(sock, EVENT_WRITE, (self, self.send_bind_addr,
+                                                  data))
+        else:
+            self.modify_local_sock_mode(EVENT_READ)
+
+
+class TcpRelayServerHandler(TcpRelayHanler):
+
+    _is_client = False
+
+    def start(self, sock, event, data=None):
+        data = self.sock_recv(sock)
+        if not data:
+            return
+        self._loop.remove(sock)
+
+        if self._crypt:
+            data = self._crypt.decrypt(data)
+        remote, data = data.split(b'\n', 1)
+        host, port = remote.split(b':')
+        self._remote = (host, int(port))
+        self._data_to_write_to_remote.append(data)
+        self._dns_resolver.resolve(host, self.dns_resolved)
+
+    def dns_resolved(self, result, error):
+        try:
+            ip = result[1]
+        except (TypeError, IndexError):
+            ip = None
+
+        if not ip:
+            raise Exception('Hostname {} cannot be resolved'.format(
+                self._remote[0]))
+
+        self._remote_addr = (ip, self._remote[1])
+        self._remote_sock = create_sock(*self._remote_addr)
+        logging.info('Connecting to {}'.format(self._remote[0]))
+        self.sock_connect(self._remote_sock, self._remote_addr)
+        self.modify_remote_sock_mode(EVENT_WRITE)
+        self.modify_local_sock_mode(EVENT_READ)
+
+
+class TcpRelay(object):
+
+    def __init__(self, handler_type, listen_addr, remote_addr=None,
+                 dns_resolver=None):
+        self._loop = None
+        self._handler_type = handler_type
+        self._listen_addr = listen_addr
+        self._remote_addr = remote_addr
+        self._dns_resolver = dns_resolver
+        self._create_listen_sock()
+
+    def _create_listen_sock(self):
+        sock = create_sock(*self._listen_addr)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        sock.bind(self._listen_addr)
+        sock.listen(1024)
+        self._listen_sock = sock
+        logging.info('Listening on {}:{}'.format(*self._listen_addr))
+
+    def add_to_loop(self, loop):
+        if self._loop:
+            raise Exception('Already added to loop')
+        self._loop = loop
+        loop.add(self._listen_sock, EVENT_READ, (self, self.accept))
+
+    def _accept(self, listen_sock):
+        try:
+            sock, addr = listen_sock.accept()
+        except (OSError, IOError) as e:
+            if errno_from_exception(e) in (
+                errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS,
+                errno.EINTR, errno.ECONNABORTED
+            ):
+                return None, None
+            else:
+                raise
+        logging.info('Connected from {}:{}'.format(*addr))
+        sock.setblocking(False)
+        sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
+        return (sock, addr)
+
+    def accept(self, listen_sock, event):
+        sock, addr = self._accept(listen_sock)
+        if not sock:
+            return
+        handler = self._handler_type(sock, addr, self._remote_addr,
+                                     self._dns_resolver)
+        handler.add_to_loop(self._loop)
+
+    def close(self):
+        self._loop.remove(self._listen_sock)
+
+    def handle_event(self, sock, event, call, *args):
+        if event & EVENT_ERROR:
+            logging.error(get_sock_error(sock))
+            self.close()
+        else:
+            try:
+                call(sock, event, *args)
+            except Exception as e:
+                logging.error(e)
+                self.close()
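The relay classes above register sockets with loop.add(sock, EVENT_READ, (self, self.start)) and read key.events back from loop.modify(), but the matching eventloop/selectors changes are outside this hunk. The following is a minimal sketch of the dispatch contract the relay code appears to assume, built on the stdlib selectors module; the EventLoop name, the EVENT_ERROR flag value and the add/modify/remove signatures are assumptions here, not the patch's actual implementation.

# Hypothetical sketch of the event-loop contract assumed by the relay code;
# EventLoop, EVENT_ERROR and the method names are illustrative only.
import selectors

EVENT_READ = selectors.EVENT_READ
EVENT_WRITE = selectors.EVENT_WRITE
EVENT_ERROR = 1 << 4  # assumed extra flag for error conditions


class EventLoop(object):

    def __init__(self):
        self._selector = selectors.DefaultSelector()
        self._stopping = False

    def add(self, sock, events, data):
        # data is a (handler, callback, *extra) tuple, matching
        # loop.add(sock, EVENT_READ, (self, self.start)) in the patch
        self._selector.register(sock, events, data)

    def modify(self, sock, events, data):
        # returns a SelectorKey; the handler reads key.events back
        return self._selector.modify(sock, events, data)

    def remove(self, sock):
        self._selector.unregister(sock)

    def stop(self):
        self._stopping = True

    def run(self):
        while not self._stopping:
            for key, events in self._selector.select(timeout=1):
                handler, call = key.data[0], key.data[1]
                extra = key.data[2:]
                # every registered object exposes
                # handle_event(sock, event, call, *args)
                handler.handle_event(key.fileobj, events, call, *extra)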
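sock_recv() and sock_send() above treat EAGAIN/EWOULDBLOCK/EINTR as "try again later" rather than as failures, which is what keeps a partial read or write from destroying the handler. Below is a standalone sketch of the same pattern, outside the handler class; e.errno is used directly instead of the project's errno_from_exception helper, and the BUF_SIZE value is illustrative.

# Standalone sketch of the EAGAIN-tolerant I/O pattern behind
# sock_recv()/sock_send(); BUF_SIZE here is an illustrative value.
import errno

BUF_SIZE = 32 * 1024


def nonblocking_recv(sock, size=BUF_SIZE):
    """Return b'' on EOF, None when the socket is not readable yet."""
    try:
        return sock.recv(size)
    except (OSError, IOError) as e:
        if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK, errno.EINTR):
            return None
        raise


def nonblocking_send(sock, data):
    """Send what the kernel accepts; return the unsent remainder."""
    try:
        sent = sock.send(data)
        return data[sent:]
    except (OSError, IOError) as e:
        if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK,
                       errno.EINPROGRESS, errno.EINTR):
            return data  # nothing was written, keep it buffered
        raise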
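on_local_write() and on_remote_write() implement a buffer-and-flush scheme: pending chunks are joined, whatever the kernel refuses is re-queued, and the socket's registration flips between EVENT_WRITE (buffer non-empty) and EVENT_READ (buffer drained). A compact sketch of that idea for a single socket follows; the WriteBuffer class is made up, and it assumes the socket is already non-blocking and registered with the selector.

# Sketch of the buffer-and-flush / interest-switching idea behind
# on_local_write()/on_remote_write(); WriteBuffer is a made-up helper and
# the socket is assumed to be non-blocking and already registered.
import selectors


class WriteBuffer(object):

    def __init__(self, selector, sock):
        self._selector = selector
        self._sock = sock
        self._pending = []

    def send(self, data):
        self._pending.append(data)
        self.flush()

    def flush(self):
        data = b''.join(self._pending)
        self._pending = []
        if data:
            try:
                sent = self._sock.send(data)
            except BlockingIOError:
                sent = 0
            if sent < len(data):
                self._pending.append(data[sent:])
        if self._pending:
            # kernel buffer full: wake up again when the socket drains
            self._selector.modify(self._sock, selectors.EVENT_WRITE, self)
        else:
            # everything flushed: go back to waiting for input
            self._selector.modify(self._sock, selectors.EVENT_READ, self)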
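TcpRelayClientHanler.start(), send_reply() and handle_addr() together perform the local side of the SOCKS5 handshake: answer the method-selection message with b'\x05\x00', accept only CMD_CONNECT, parse the destination with parse_header(data[3:]), and reply with an all-zero bind address. Here is a blocking, Python 3 sketch of the same byte exchange without the event loop or encryption; error handling and partial reads are ignored.

# Blocking Python 3 sketch of the SOCKS5 exchange performed by
# start()/send_reply()/handle_addr() above (no event loop, no encryption).
import socket
import struct

CMD_CONNECT = 1


def socks5_handshake(conn):
    """conn: an accepted, blocking socket from a SOCKS5 client."""
    conn.recv(257)                     # VER NMETHODS METHODS..., ignored here
    conn.sendall(b'\x05\x00')          # version 5, "no authentication"

    request = conn.recv(262)           # VER CMD RSV ATYP DST.ADDR DST.PORT
    if request[1] != CMD_CONNECT:
        raise Exception('Command not supported')

    atyp = request[3]
    if atyp == 1:                      # IPv4
        dest_addr = socket.inet_ntoa(request[4:8])
        dest_port = struct.unpack('>H', request[8:10])[0]
    elif atyp == 3:                    # hostname
        alen = request[4]
        dest_addr = request[5:5 + alen].decode('utf-8')
        dest_port = struct.unpack('>H', request[5 + alen:7 + alen])[0]
    else:
        raise Exception('Address type not supported')

    # dummy all-zero bind address, as in handle_addr() above
    conn.sendall(b'\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00')
    return dest_addr, dest_port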
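Instead of forwarding the raw SOCKS address header the way upstream shadowsocks does, this patch has the client prepend a plaintext '{host}:{port}\n' line (the whole stream is then encrypted), and TcpRelayServerHandler.start() splits on the first b'\n' to recover the destination. A tiny sketch of that framing follows, with made-up helper names; it uses rsplit(b':', 1) so a literal colon in the host part does not confuse the parser, whereas the patch itself uses a plain split.

# Sketch of the simplified '{host}:{port}\n' tunnel header exchanged between
# the client and server handlers; pack_dest/unpack_dest are made-up names.


def pack_dest(host, port, payload=b''):
    # client side: what gets appended to _data_to_write_to_remote
    return '{}:{}\n'.format(host, port).encode('utf-8') + payload


def unpack_dest(data):
    # server side: what TcpRelayServerHandler.start() does with the
    # decrypted first read
    header, rest = data.split(b'\n', 1)
    host, port = header.rsplit(b':', 1)
    return (host.decode('utf-8'), int(port)), rest


if __name__ == '__main__':
    blob = pack_dest('example.com', 443, b'GET / HTTP/1.1\r\n')
    assert unpack_dest(blob) == (('example.com', 443), b'GET / HTTP/1.1\r\n')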
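The _ota_chunk_data()/_ota_chunk_data_gen() methods removed above framed each payload as DATA.LEN (2 bytes, big-endian), a 10-byte truncated HMAC-SHA1, then DATA, keyed with the cipher IV concatenated with a 4-byte chunk index (see the spec URL cited in the removed code). Since this patch drops one-time-auth entirely, the sketch below only records the framing that went away; the function names are made up.

# Reference sketch of the removed one-time-auth chunk framing:
# DATA.LEN (>H) | HMAC-SHA1 truncated to 10 bytes | DATA, keyed with iv + index.
import hashlib
import hmac
import struct

ONETIMEAUTH_BYTES = 10


def ota_chunk(data, iv, chunk_idx):
    key = iv + struct.pack('>I', chunk_idx)
    tag = hmac.new(key, data, hashlib.sha1).digest()[:ONETIMEAUTH_BYTES]
    return struct.pack('>H', len(data)) + tag + data


def ota_verify_chunk(chunk, iv, chunk_idx):
    data_len = struct.unpack('>H', chunk[:2])[0]
    tag = chunk[2:2 + ONETIMEAUTH_BYTES]
    data = chunk[2 + ONETIMEAUTH_BYTES:2 + ONETIMEAUTH_BYTES + data_len]
    key = iv + struct.pack('>I', chunk_idx)
    expected = hmac.new(key, data, hashlib.sha1).digest()[:ONETIMEAUTH_BYTES]
    return hmac.compare_digest(tag, expected), data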