diff --git a/shadowsocks/local.py b/shadowsocks/local.py index 2b7d7f4..7f8fd40 100755 --- a/shadowsocks/local.py +++ b/shadowsocks/local.py @@ -50,6 +50,9 @@ import utils import udprelay +MSG_FASTOPEN = 0x20000000 + + def send_all(sock, data): bytes_sent = 0 while True: @@ -84,18 +87,30 @@ class Socks5Server(SocketServer.StreamRequestHandler): aPort = int(r.group(2)) return (aServer, aPort) - def handle_tcp(self, sock, remote): + def handle_tcp(self, sock, remote, pending_data=None, server=None, port=None): + connected = False try: - fdset = [sock, remote] + if FAST_OPEN: + fdset = [sock] + else: + fdset = [sock, remote] while True: r, w, e = select.select(fdset, [], []) if sock in r: - data = self.encrypt(sock.recv(4096)) - if len(data) <= 0: - break - result = send_all(remote, data) - if result < len(data): - raise Exception('failed to send all data') + if not connected and FAST_OPEN: + data = sock.recv(4096) + data = self.encrypt(pending_data + data) + remote.sendto(data, MSG_FASTOPEN, (server, port)) + connected = True + fdset = [sock, remote] + logging.info('fast open %s:%d' % (server, port)) + else: + data = self.encrypt(sock.recv(4096)) + if len(data) <= 0: + break + result = send_all(remote, data) + if result < len(data): + raise Exception('failed to send all data') if remote in r: data = self.decrypt(remote.recv(4096)) @@ -202,24 +217,37 @@ class Socks5Server(SocketServer.StreamRequestHandler): self.wfile.write(reply) # reply immediately aServer, aPort = self.getServer() - MSG_FASTOPEN = 0x20000000 - remote = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # remote = socket.create_connection((aServer, aPort)) - remote.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - data = self.encrypt(addr_to_send) - remote.sendto(data, MSG_FASTOPEN, (aServer, aPort)) - # self.send_encrypt(remote, addr_to_send) - logging.info('connecting %s:%d' % (addr, port[0])) - except socket.error, e: - logging.warn(e) - return - self.handle_tcp(sock, remote) - except socket.error, e: - logging.warn(e) + addrs = socket.getaddrinfo(aServer, aPort) + if addrs: + af, socktype, proto, canonname, sa = addrs[0] + if FAST_OPEN: + remote = socket.socket(af, socktype, proto) + # remote.setsockopt(socket.IPPROTO_TCP, + # socket.TCP_NODELAY, 1) + self.handle_tcp(sock, remote, addr_to_send, aServer, + aPort) + else: + remote = socket.create_connection((aServer, aPort)) + remote.setsockopt(socket.IPPROTO_TCP, + socket.TCP_NODELAY, 1) + self.send_encrypt(remote, addr_to_send) + logging.info('connecting %s:%d' % (addr, port[0])) + self.handle_tcp(sock, remote) + finally: + pass + # except socket.error, e: + # raise e + # logging.warn(e) + # return + finally: + pass + # except socket.error, e: + # raise e + # logging.warn(e) def main(): - global SERVER, REMOTE_PORT, KEY, METHOD + global SERVER, REMOTE_PORT, KEY, METHOD, FAST_OPEN logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)-8s %(message)s', @@ -289,6 +317,7 @@ def main(): METHOD = config.get('method', None) LOCAL = config.get('local_address', '127.0.0.1') TIMEOUT = config.get('timeout', 600) + FAST_OPEN = config.get('fast_open', False) if not KEY and not config_path: sys.exit('config not specified, please read '