This commit is contained in:
JinXing 2014-04-15 23:33:12 +08:00
parent 64317a70de
commit 636b4fdeda
3 changed files with 85 additions and 22 deletions

View file

@ -1,4 +1,4 @@
1. fd 不在 fd_map 中却注册到了 epoll
1. fd 不在 fd_map 中却注册到了 epoll (可能跟 unrigxxxx 的时效有关,明天来再看下 epoll 的示例代码)
2. epoll 出现 MY_POLLEV_ERR 事件
3. 大部分 SSL 连接时浏览器会报: ERR_SSL_PROTOCOL_ERROR,
但是当尝试把 read() 及 write() 的数据打印出来时又可以正常连接(应该是跟打印输出会消耗一定的时间有关)

View file

@ -13,6 +13,7 @@ import logging
import socket
import errno
import binascii
import time
try:
from cStringIO import StringIO
@ -85,16 +86,19 @@ class IOLoop(object):
logging.warn('fd %d not in fd_map', fd)
self._poller.unregister(fd)
continue
# logging.info('fd %d, events %d', fd, events)
handler = self._fd_map[fd]
if events & MY_POLLEV_ERR:
# logging.debug("fd[%s] events MY_POLLEV_ERR | MY_POLLEV_HUP", fd)
handler.handle_error(events)
handler.handle_error(fd, events)
elif events & MY_POLLEV_IN or events & MY_POLLEV_PRI:
# logging.debug("fd[%s] events MY_POLLEV_IN | MY_POLLEV_PRI", fd)
handler.handle_read()
elif events & MY_POLLEV_OUT:
# logging.debug("fd[%s] events MY_POLLEV_OUT", fd)
handler.handle_write()
else:
logging.error("unknow events %d", events)
#@staticmethod
#def _set_nonblocking(fd):
@ -161,8 +165,8 @@ class BaseHandler(object):
def handle_write(self):
raise
def handle_error(self):
raise
def handle_error(self, fd, events):
logging.warn("socket error, fd: %d, events: %d", fd, events)
class IOHandler(BaseHandler):
@ -201,9 +205,12 @@ class IOHandler(BaseHandler):
"""fd 可写事件出现"""
self._ios.real_write()
def handle_error(self, events):
logging.error("handle_error fd(%s), events: %s", self._fd, binascii.b2a_hex(events))
self._ios.close()
def handle_error(self, fd, events):
logging.error("handle_error fd(%s), events: %r", fd, events)
try:
self._ios.close()
except Exception, e:
loggin.error("handle_error() close() exception: %s", e)
class SimpleCopyFileHandler(IOHandler):

View file

@ -120,6 +120,7 @@ class BaseTunnelHandler(ioloop.IOHandler):
ioloop.IOHandler.__init__(self, *args, **kwargs)
self.encryptor = encrypt.Encryptor(G_CONFIG["password"], G_CONFIG["method"])
self._remote_ios = None
self._rs_connecting = False
def encrypt(self, data):
return self.encryptor.encrypt(data)
@ -154,23 +155,31 @@ class BaseTunnelHandler(ioloop.IOHandler):
def connect_to_remote(self):
raise
def set_remote_ts(self, sock):
raise
def handle_read(self):
"""fd 可读事件出现"""
# logging.info("%r, remote_ios: %r, _rs_connecting: %r", self, self._remote_ios, self._rs_connecting)
if not self._remote_ios:
self._remote_ios = self.connect_to_remote()
if not self._remote_ios:
self.close_tunnel()
return
if not self._rs_connecting:
self.connect_to_remote()
return
logging.info("handle_read(), local:%d, remote:%d, Handler:%r",
logging.debug("handle_read(), local:%d, remote:%d, Handler:%r",
self._ios.fileno(), self._remote_ios.fileno(), self)
try:
_s = time.time()
s = self.do_stream_read()
# logging.debug('do_stream_read() cast time %f', time.time()-_s)
if len(s) == 0:
logging.debug('iostream[%s].read() return len(s) == 0, close it', self._fd)
self.close_tunnel()
return self.write_to_remote(s)
_s = time.time()
self.write_to_remote(s)
# logging.debug('write_to_remote() cast time %f', time.time()-_s)
return
except socket.error, _e:
if _e.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
@ -184,6 +193,10 @@ class LeftTunnelHandler(BaseTunnelHandler):
super(self.__class__, self).__init__(*args, **kwargs)
self._remote_ios = None
def set_remote_ios(self, sock_stream):
self._remote_ios = sock_stream
logging.info("self._remote_ios: %r", self._remote_ios)
def connect_to_remote(self):
rfile = self._ios
iv_len = self.encryptor.iv_len()
@ -205,19 +218,25 @@ class LeftTunnelHandler(BaseTunnelHandler):
logging.info('connecting to remote %s:%d', addr, port)
_start_time = time.time()
remote_socket = socket.socket()
remote_socket.connect((addr, port))
remote_socket.setblocking(0)
logging.info('cost time: %d', time.time()-_start_time)
try:
remote_socket.connect((addr, port))
except socket.error, _e:
if _e.errno != errno.EINPROGRESS:
raise _e
logging.info('socket.connect() cost time: %f', time.time()-_start_time)
except socket.error, e:
# Connection refused
logging.warn(e)
return None
remote_ts = TunnelStream(remote_socket)
handler = RightTunnelHandler( self._ios, self._ioloop, remote_ts)
handler = ShadowConnectHandler(self._ioloop, self, remote_ts)
self._ioloop.add_handler(remote_ts.fileno(), handler, m_read=True, m_write=True)
logging.info('New tunnel %d <=> %d' % (self._ios.fileno(), remote_ts.fileno()))
return remote_ts
self._rs_connecting = True
return None
def do_stream_read(self, size=4096):
"""从客户端读"""
@ -244,9 +263,33 @@ class RightTunnelHandler(BaseTunnelHandler):
# logging.debug('send to left: %s', list(data))
self._remote_ios.write(data)
class MyAcceptHandler(ioloop.BaseHandler):
def __init__(self, ioloop, srv_socket):
self._ioloop = ioloop
class ShadowConnectHandler(ioloop.BaseHandler):
def __init__(self, _ioloop, left_handler, right_ts):
self._ioloop = _ioloop
self._left_handler = left_handler
self._left_ts = self._left_handler._ios
self._right_ts = right_ts
def handle_write(self):
self.handle_connect_res()
def handle_read(self):
self.handle_connect_res()
def handle_connect_res(self):
self._left_handler.set_remote_ios(self._right_ts)
print self._left_handler._remote_ios
handler = RightTunnelHandler( self._left_ts, self._ioloop, self._right_ts)
self._ioloop.modify_handler( self._right_ts.fileno(), handler, m_read=True, m_write=True)
logging.info('New tunnel (%d,%d) <=> (%d,%d)' % (
self._left_handler._ios.fileno(), self._left_handler._remote_ios.fileno(),
handler._ios.fileno(), handler._remote_ios.fileno(),
))
class ShadowAcceptHandler(ioloop.BaseHandler):
def __init__(self, _ioloop, srv_socket):
self._ioloop = _ioloop
self._srv_socket = srv_socket
def handle_read(self):
@ -302,9 +345,22 @@ def main():
sock.bind((SERVER, PORT))
logging.info("listing on %s", str(sock.getsockname()))
sock.listen(1024)
io.add_handler(sock.fileno(), MyAcceptHandler(io, sock), m_read=True)
io.add_handler(sock.fileno(), ShadowAcceptHandler(io, sock), m_read=True)
next_tick = time.time() + 10
count = 0
while True:
count += 1
if time.time() >= next_tick:
logging.info("loop count %d", count)
next_tick = time.time() + 10
pass
_s = time.time()
io.wait_events(0.1)
use_time = time.time() - _s
if use_time > 0.2:
logging.error("events process cost time: %f", _e-_s)
elif use_time < 0.1:
time.sleep(0.1-use_time)
global G_CONFIG
if __name__ == '__main__':