From c721a1c02f8a8278ccd7d12e619c84667111ae67 Mon Sep 17 00:00:00 2001 From: clowwindy Date: Sun, 1 Jun 2014 15:58:37 +0800 Subject: [PATCH] local works --- shadowsocks/tcprelay.py | 69 +++++++++++++++++++++++++++++++---------- shadowsocks/utils.py | 6 ++-- 2 files changed, 56 insertions(+), 19 deletions(-) diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index b7902d9..2a9d552 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -53,8 +53,9 @@ STREAM_UP = 0 STREAM_DOWN = 1 # stream status -STATUS_WAIT_READING = 0 -STATUS_WAIT_WRITING = 1 +STATUS_WAIT_INIT = 0 +STATUS_WAIT_READING = 1 +STATUS_WAIT_WRITING = 2 BUF_SIZE = 8 * 1024 @@ -72,8 +73,8 @@ class TCPRelayHandler(object): config['method']) self._data_to_write_to_local = [] self._data_to_write_to_remote = [] - self._upstream_status = STATUS_WAIT_READING - self._downstream_status = STATUS_WAIT_READING + self._upstream_status = STATUS_WAIT_INIT + self._downstream_status = STATUS_WAIT_INIT fd_to_handlers[local_sock.fileno()] = self local_sock.setblocking(False) loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR) @@ -132,12 +133,24 @@ class TCPRelayHandler(object): logging.error('write_all_to_sock:unknown socket') def on_local_read(self): + # TODO update timeout if not self._local_sock: return is_local = self._is_local - data = self._local_sock.recv(BUF_SIZE) + data = None + try: + data = self._local_sock.recv(BUF_SIZE) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) in \ + (errno.ETIMEDOUT, errno.EAGAIN): + return + if not data: + self.destroy() + return if not is_local: data = self._encryptor.decrypt(data) + if not data: + return if self._stage == STAGE_STREAM: if self._is_local: data = self._encryptor.encrypt(data) @@ -167,13 +180,17 @@ class TCPRelayHandler(object): logging.info('connecting %s:%d' % (remote_addr, remote_port)) if is_local: # forward address to remote - self._data_to_write_to_remote.append(data[:header_length]) self.write_all_to_sock('\x05\x00\x00\x01' + '\x00\x00\x00\x00\x10\x10', self._local_sock) - else: + data_to_send = self._encryptor.encrypt(data) + self._data_to_write_to_remote.append(data_to_send) remote_addr = self._config['server'] remote_port = self._config['server_port'] + else: + if len(data) > header_length: + self._data_to_write_to_remote.append( + data[header_length:]) # TODO async DNS addrs = socket.getaddrinfo(remote_addr, remote_port, 0, @@ -183,17 +200,20 @@ class TCPRelayHandler(object): (remote_addr, remote_port)) af, socktype, proto, canonname, sa = addrs[0] self._remote_sock = socket.socket(af, socktype, proto) + self._fd_to_handlers[self._remote_sock.fileno()] = self self._remote_sock.setblocking(False) # TODO support TCP fast open - self._remote_sock.connect(sa) + try: + self._remote_sock.connect(sa) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) == errno.EINPROGRESS: + pass self._loop.add(self._remote_sock, eventloop.POLL_ERR | eventloop.POLL_OUT) - if len(data) > header_length: - self._data_to_write_to_remote.append(data[header_length:]) - self._stage = 4 self.update_stream(STREAM_UP, STATUS_WAIT_WRITING) + self.update_stream(STREAM_DOWN, STATUS_WAIT_READING) return except Exception: import traceback @@ -205,7 +225,17 @@ class TCPRelayHandler(object): self._data_to_write_to_remote.append(data) def on_remote_read(self): - data = self._remote_sock.recv(BUF_SIZE) + # TODO update timeout + data = None + try: + data = self._remote_sock.recv(BUF_SIZE) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) in \ + (errno.ETIMEDOUT, errno.EAGAIN): + return + if not data: + self.destroy() + return if self._is_local: data = self._encryptor.decrypt(data) try: @@ -225,12 +255,13 @@ class TCPRelayHandler(object): self.update_stream(STREAM_DOWN, STATUS_WAIT_READING) def on_remote_write(self): + self._stage = STAGE_STREAM if self._data_to_write_to_remote: data = ''.join(self._data_to_write_to_remote) self._data_to_write_to_remote = [] self.write_all_to_sock(data, self._remote_sock) else: - self.update_stream(STREAM_DOWN, STATUS_WAIT_READING) + self.update_stream(STREAM_UP, STATUS_WAIT_READING) def on_local_error(self): logging.error(eventloop.get_sock_error(self._local_sock)) @@ -261,14 +292,14 @@ class TCPRelayHandler(object): def destroy(self): if self._remote_sock: - self._remote_sock.close() self._loop.remove(self._remote_sock) del self._fd_to_handlers[self._remote_sock.fileno()] + self._remote_sock.close() self._remote_sock = None if self._local_sock: - self._local_sock.close() self._loop.remove(self._local_sock) del self._fd_to_handlers[self._local_sock.fileno()] + self._local_sock.close() self._local_sock = None @@ -317,6 +348,7 @@ class TCPRelay(object): logging.error(e) continue for sock, event in events: + logging.debug('%s %d', sock, event) if sock == self._server_socket: if event & eventloop.POLL_ERR: # TODO @@ -324,11 +356,13 @@ class TCPRelay(object): try: conn = self._server_socket.accept() TCPRelayHandler(self._fd_to_handlers, self._eventloop, - conn, self._config, self._is_local) + conn[0], self._config, self._is_local) except (OSError, IOError) as e: error_no = eventloop.errno_from_exception(e) if error_no in (errno.EAGAIN, errno.EINPROGRESS): continue + else: + logging.error(e) else: handler = self._fd_to_handlers.get(sock.fileno(), None) if handler: @@ -336,6 +370,7 @@ class TCPRelay(object): else: logging.warn('can not find handler for fd %d', sock.fileno()) + self._eventloop.remove(sock) now = time.time() if now - last_time > 5: # TODO sweep timeouts @@ -346,7 +381,7 @@ class TCPRelay(object): if self._closed: raise Exception('closed') t = threading.Thread(target=self._run) - t.setName('UDPThread') + t.setName('TCPThread') t.setDaemon(False) t.start() self._thread = t diff --git a/shadowsocks/utils.py b/shadowsocks/utils.py index 180cdcb..595068a 100644 --- a/shadowsocks/utils.py +++ b/shadowsocks/utils.py @@ -139,15 +139,17 @@ def get_config(is_local): config['verbose'] = config.get('verbose', False) config['local_address'] = config.get('local_address', '127.0.0.1') - check_config(config) - if config['verbose']: level = logging.DEBUG else: level = logging.WARNING + logging.getLogger('').handlers = [] logging.basicConfig(level=level, format='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S', filemode='a+') + + check_config(config) + return config