diff --git a/server.py b/server.py index c94dac6..5adb7b0 100755 --- a/server.py +++ b/server.py @@ -26,23 +26,14 @@ if sys.version_info < (2, 6): import simplejson as json else: import json - -try: - import gevent, gevent.monkey - gevent.monkey.patch_all(dns=gevent.version_info[0]>=1) -except ImportError: - gevent = None - print >>sys.stderr, 'warning: gevent not found, using threading instead' - -import socket -import select -import SocketServer import struct import string import hashlib import os import logging import getopt +import socket + def get_table(key): m = hashlib.md5() @@ -54,108 +45,145 @@ def get_table(key): table.sort(lambda x, y: int(a % (ord(x) + i) - a % (ord(y) + i))) return table -def send_all(sock, data): - bytes_sent = 0 - while True: - r = sock.send(data[bytes_sent:]) - if r < 0: - return r - bytes_sent += r - if bytes_sent == len(data): - return bytes_sent -class ThreadingTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer): - allow_reuse_address = True +def encrypt(data): + return data.translate(encrypt_table) -class Socks5Server(SocketServer.StreamRequestHandler): - def handle_tcp(self, sock, remote): - try: - fdset = [sock, remote] - while True: - r, w, e = select.select(fdset, [], []) - if sock in r: - data = sock.recv(4096) - if len(data) <= 0: - break - result = send_all(remote, self.decrypt(data)) - if result < len(data): - raise Exception('failed to send all data') - if remote in r: - data = remote.recv(4096) - if len(data) <= 0: - break - result = send_all(sock, self.encrypt(data)) - if result < len(data): - raise Exception('failed to send all data') +def decrypt(data): + return data.translate(decrypt_table) - finally: - sock.close() - remote.close() - def encrypt(self, data): - return data.translate(encrypt_table) +class RemoteHandler(object): + def __init__(self, conn, local_handler): + self.conn = conn + self.local_handler = local_handler + conn.on('connect', self.on_connect) + conn.on('data', self.on_data) + conn.on('close', self.on_close) + conn.on('end', self.on_end) + conn.connect(local_handler.remote_addr_pair) - def decrypt(self, data): - return data.translate(decrypt_table) + def on_connect(self, s): + for piece in self.local_handler.cached_pieces: + self.conn.write(decrypt(piece)) + # TODO write cached pieces + self.local_handler.stage = 5 - def handle(self): - try: - sock = self.connection - addrtype = ord(self.decrypt(sock.recv(1))) - if addrtype == 1: - addr = socket.inet_ntoa(self.decrypt(self.rfile.read(4))) - elif addrtype == 3: - addr = self.decrypt( - self.rfile.read(ord(self.decrypt(sock.recv(1))))) - elif addrtype == 4: - addr = socket.inet_ntop(socket.AF_INET6, self.decrypt(self.rfile.read(16))) - else: - # not support - logging.warn('addr_type not support') - return - port = struct.unpack('>H', self.decrypt(self.rfile.read(2))) + def on_data(self, s, data): + data = encrypt(data) + self.local_handler.conn.write(data) + + def on_close(self, s): + # self.local_handler.conn.end() + pass + + def on_end(self, s): + self.local_handler.conn.end() + + +class LocalHandler(object): + def on_data(self, s, data): + if self.stage == 5: + data = decrypt(data) + self.remote_handler.conn.write(data) + return + if self.stage == 0: try: - logging.info('connecting %s:%d' % (addr, port[0])) - remote = socket.create_connection((addr, port[0])) - except socket.error, e: - # Connection refused - logging.warn(e) + addrtype = ord(data[0]) + # TODO check cmd == 1 + if addrtype == 1: + remote_addr = socket.inet_ntoa(data[1:5]) + remote_port = data[5:7] + header_length = 7 + elif addrtype == 4: + remote_addr = socket.inet_ntop(data[1:17]) + remote_port = data[17:19] + header_length = 19 + elif addrtype == 3: + addr_len = ord(data[1]) + remote_addr = data[2:2 + addr_len] + remote_port = data[2 + addr_len:2 + addr_len + 2] + header_length = 2 + addr_len + 2 + else: + # TODO check addrtype in (1, 3, 4) + raise + remote_port = struct.unpack('>H', remote_port)[0] + self.remote_addr_pair = (remote_addr, remote_port) + logging.info('connecting %s:%d' % self.remote_addr_pair) + remote_conn = ssloop.Socket() + self.remote_handler = RemoteHandler(remote_conn, self) + + if len(data) > header_length: + self.cached_pieces.append(data[header_length:]) + + # TODO save other bytes + self.stage = 4 return - self.handle_tcp(sock, remote) - except socket.error, e: - logging.warn(e) + except: + import traceback + traceback.print_exc() + + if self.stage == 4: + self.cached_pieces.append(data) + + def on_end(self, s): + if self.remote_handler: + self.remote_handler.conn.end() + + def on_close(self, s): + pass + # self.remote_handler.conn.end() + + def __init__(self, conn): + self.stage = 0 + self.remote = None + self.addr_len = 0 + self.addr_to_send = '' + self.conn = conn + self.cached_pieces = [] + + conn.on('data', self.on_data) + conn.on('end', self.on_end) + conn.on('close', self.on_close) + + +def on_connection(s, conn): + LocalHandler(conn) if __name__ == '__main__': os.chdir(os.path.dirname(__file__) or '.') - - print 'shadowsocks v1.1' + sys.path.append('./ssloop') + import ssloop + print 'shadowsocks v2.0' with open('config.json', 'rb') as f: config = json.load(f) - - SERVER = config['server'] PORT = config['server_port'] KEY = config['password'] - optlist, args = getopt.getopt(sys.argv[1:], 'p:k:') + argv = sys.argv[1:] + if '-6' in sys.argv[1:]: + argv.remove('-6') + + optlist, args = getopt.getopt(argv, 'p:k:') for key, value in optlist: if key == '-p': PORT = int(value) elif key == '-k': KEY = value - logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)-8s %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', filemode='a+') + logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', filemode='a+') encrypt_table = ''.join(get_table(KEY)) decrypt_table = string.maketrans(encrypt_table, string.maketrans('', '')) - if '-6' in sys.argv[1:]: - ThreadingTCPServer.address_family = socket.AF_INET6 try: - server = ThreadingTCPServer(('', PORT), Socks5Server) logging.info("starting server at port %d ..." % PORT) - server.serve_forever() - except socket.error, e: - logging.error(e) - + loop = ssloop.instance() + s = ssloop.Server(('0.0.0.0', PORT)) + s.on('connection', on_connection) + s.listen() + loop.start() + except KeyboardInterrupt: + sys.exit(0)