more work

This commit is contained in:
clowwindy 2014-05-30 23:55:33 +08:00
parent 2cdddd4515
commit 0d6c39900b
3 changed files with 241 additions and 79 deletions

64
shadowsocks/common.py Normal file
View file

@ -0,0 +1,64 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright (c) 2014 clowwindy
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import socket
import struct
import logging
def parse_header(data):
addrtype = ord(data[0])
dest_addr = None
dest_port = None
header_length = 0
if addrtype == 1:
if len(data) >= 7:
dest_addr = socket.inet_ntoa(data[1:5])
dest_port = struct.unpack('>H', data[5:7])[0]
header_length = 7
else:
logging.warn('header is too short')
elif addrtype == 3:
if len(data) > 2:
addrlen = ord(data[1])
if len(data) >= 2 + addrlen:
dest_addr = data[2:2 + addrlen]
dest_port = struct.unpack('>H', data[2 + addrlen:4 +
addrlen])[0]
header_length = 4 + addrlen
else:
logging.warn('header is too short')
else:
logging.warn('header is too short')
elif addrtype == 4:
if len(data) >= 19:
dest_addr = socket.inet_ntop(socket.AF_INET6, data[1:17])
dest_port = struct.unpack('>H', data[17:19])[0]
header_length = 19
else:
logging.warn('header is too short')
else:
logging.warn('unsupported addrtype %d, maybe wrong password' % addrtype)
if dest_addr is None:
return None
return (addrtype, dest_addr, dest_port, header_length)

View file

@ -22,41 +22,170 @@
# SOFTWARE. # SOFTWARE.
import time import time
import threading
import socket import socket
import logging import logging
import struct
import encrypt import encrypt
import eventloop
import errno import errno
import eventloop
from common import parse_header
# local:
# stage 0 init
# stage 1 hello received, hello sent
# stage 4 addr received, reply sent
# stage 5 remote connected
# remote:
# stage 0 init
# stage 4 addr received, reply sent
# stage 5 remote connected
BUF_SIZE = 8 * 1024
class TCPRelayHandler(object): class TCPRelayHandler(object):
def __init__(self, fd_to_handlers, loop, conn, config, is_local): def __init__(self, fd_to_handlers, loop, local_sock, config, is_local):
self._fd_to_handlers = fd_to_handlers self._fd_to_handlers = fd_to_handlers
self._loop = loop self._loop = loop
self._local_conn = conn self._local_sock = local_sock
self._remote_conn = None self._remote_sock = None
self._remains_data_for_local = None
self._remains_data_for_remote = None
self._config = config self._config = config
self._is_local = is_local self._is_local = is_local
self._stage = 0 self._stage = 0
fd_to_handlers[conn.fileno()] = self self._encryptor = encrypt.Encryptor(config['password'],
conn.setblocking(False) config['method'])
loop.add(conn, eventloop.POLL_IN) self._data_to_write_to_local = []
self._data_to_write_to_remote = []
fd_to_handlers[local_sock.fileno()] = self
local_sock.setblocking(False)
loop.add(local_sock, eventloop.POLL_IN)
def resume_reading(self, sock):
pass
def pause_reading(self, sock):
pass
def resume_writing(self, sock):
pass
def pause_writing(self, sock):
pass
def write_all_to_sock(self, data, sock):
# write to sock
# put remaining bytes into buffer
# return true if all written
# return false if some bytes left in buffer
# raise if encounter error
return True
def on_local_read(self): def on_local_read(self):
pass if not self._local_sock:
return
is_local = self._is_local
data = self._local_sock.recv(BUF_SIZE)
if not is_local:
data = self._encryptor.decrypt(data)
if self._stage == 5:
if self._is_local:
data = self._encryptor.encrypt(data)
if not self.write_all_to_sock(data, self._remote_sock):
self.pause_reading(self._local_sock)
return
if is_local and self._stage == 0:
# TODO check auth method
self.write_all_to_sock('\x05\00', self._local_sock)
self._stage = 1
return
if self._stage == 4:
self._data_to_write_to_remote.append(data)
if (is_local and self._stage == 0) or \
(not is_local and self._stage == 1):
try:
if is_local:
cmd = ord(data[1])
# TODO check cmd == 1
assert cmd == 1
# just trim VER CMD RSV
data = data[3:]
header_result = parse_header(data)
if header_result is None:
raise Exception('can not parse header')
addrtype, remote_addr, remote_port, header_length =\
header_result
logging.info('connecting %s:%d' % (remote_addr, remote_port))
if is_local:
# forward address to remote
self._data_to_write_to_remote.append(data[:header_length])
self.write_all_to_sock('\x05\x00\x00\x01' +
'\x00\x00\x00\x00\x10\x10',
self._local_sock)
else:
remote_addr = self._config['server']
remote_port = self._config['server_port']
# TODO async DNS
addrs = socket.getaddrinfo(remote_addr, remote_port, 0,
socket.SOCK_STREAM, socket.SOL_TCP)
if len(addrs) == 0:
raise Exception("can't get addrinfo for %s:%d" %
(remote_addr, remote_port))
af, socktype, proto, canonname, sa = addrs[0]
self._remote_sock = socket.socket(af, socktype, proto)
self._remote_sock.setblocking(False)
# TODO support TCP fast open
self._remote_sock.connect(sa)
self._loop.add(self._remote_sock, eventloop.POLL_OUT)
if len(data) > header_length:
self._data_to_write_to_remote.append(data[header_length:])
self._stage = 4
self.pause_reading(self._local_sock)
return
except Exception:
import traceback
traceback.print_exc()
# TODO use logging when debug completed
self.destroy()
if self._stage == 4:
self._data_to_write_to_remote.append(data)
def on_remote_read(self): def on_remote_read(self):
pass data = self._remote_sock.recv(BUF_SIZE)
if self._is_local:
data = self._encryptor.decrypt(data)
try:
if not self.write_all_to_sock(data, self._local_sock):
self.pause_reading(self._remote_sock)
self.resume_writing(self._local_sock)
except Exception:
import traceback
traceback.print_exc()
# TODO use logging when debug completed
self.destroy()
def on_local_write(self): def on_local_write(self):
pass if self._data_to_write_to_local:
written = self.write_all_to_sock(
''.join(self._data_to_write_to_local), self._local_sock)
if written:
self.pause_writing(self._local_sock)
else:
self.pause_writing(self._local_sock)
def on_remote_write(self): def on_remote_write(self):
pass if self._data_to_write_to_remote:
written = self.write_all_to_sock(
''.join(self._data_to_write_to_remote), self._remote_sock)
if written:
self.pause_writing(self._remote_sock)
else:
self.pause_writing(self._remote_sock)
def on_local_error(self): def on_local_error(self):
self.destroy() self.destroy()
@ -66,33 +195,34 @@ class TCPRelayHandler(object):
def handle_event(self, sock, event): def handle_event(self, sock, event):
# order is important # order is important
if sock == self._local_conn: if sock == self._remote_sock:
if event & eventloop.POLL_IN:
self.on_local_read()
if event & eventloop.POLL_OUT:
self.on_local_write()
if event & eventloop.POLL_ERR:
self.on_local_error()
elif sock == self._remote_conn:
if event & eventloop.POLL_IN: if event & eventloop.POLL_IN:
self.on_remote_read() self.on_remote_read()
if event & eventloop.POLL_OUT: if event & eventloop.POLL_OUT:
self.on_remote_write() self.on_remote_write()
if event & eventloop.POLL_ERR: if event & eventloop.POLL_ERR:
self.on_remote_error() self.on_remote_error()
elif sock == self._local_sock:
if event & eventloop.POLL_IN:
self.on_local_read()
if event & eventloop.POLL_OUT:
self.on_local_write()
if event & eventloop.POLL_ERR:
self.on_local_error()
else: else:
logging.warn('unknown socket') logging.warn('unknown socket')
def destroy(self): def destroy(self):
if self._local_conn: if self._remote_sock:
self._local_conn.close() self._remote_sock.close()
eventloop.remove(self._local_conn) self._loop.remove(self._remote_sock)
# TODO maybe better to delete the key del self._fd_to_handlers[self._remote_sock.fileno()]
self._fd_to_handlers[self._local_conn.fileno()] = None self._remote_sock = None
if self._remote_conn: if self._local_sock:
self._remote_conn.close() self._local_sock.close()
eventloop.remove(self._remote_conn) self._loop.remove(self._local_sock)
self._fd_to_handlers[self._local_conn.fileno()] = None del self._fd_to_handlers[self._local_sock.fileno()]
self._local_sock = None
class TCPRelay(object): class TCPRelay(object):
@ -102,15 +232,23 @@ class TCPRelay(object):
self._closed = False self._closed = False
self._fd_to_handlers = {} self._fd_to_handlers = {}
addrs = socket.getaddrinfo(self._listen_addr, self._listen_port, 0, if is_local:
listen_addr = config['local_address']
listen_port = config['local_port']
else:
listen_addr = config['server']
listen_port = config['server_port']
addrs = socket.getaddrinfo(listen_addr, listen_port, 0,
socket.SOCK_STREAM, socket.SOL_TCP) socket.SOCK_STREAM, socket.SOL_TCP)
if len(addrs) == 0: if len(addrs) == 0:
raise Exception("can't get addrinfo for %s:%d" % raise Exception("can't get addrinfo for %s:%d" %
(self._listen_addr, self._listen_port)) (listen_addr, listen_port))
af, socktype, proto, canonname, sa = addrs[0] af, socktype, proto, canonname, sa = addrs[0]
server_socket = socket.socket(af, socktype, proto) server_socket = socket.socket(af, socktype, proto)
server_socket.bind((self._listen_addr, self._listen_port)) server_socket.bind(sa)
server_socket.setblocking(False) server_socket.setblocking(False)
server_socket.listen(1024)
self._server_socket = server_socket self._server_socket = server_socket
def _run(self): def _run(self):
@ -132,8 +270,8 @@ class TCPRelay(object):
if sock == self._server_socket: if sock == self._server_socket:
try: try:
conn = self._server_socket.accept() conn = self._server_socket.accept()
TCPRelayHandler(loop, conn, remote_addr, remote_port, TCPRelayHandler(self._eventloop, conn, self._config,
password, method, timeout, is_local) self._is_local)
except (OSError, IOError) as e: except (OSError, IOError) as e:
error_no = eventloop.errno_from_exception(e) error_no = eventloop.errno_from_exception(e)
if error_no in [errno.EAGAIN, errno.EINPROGRESS]: if error_no in [errno.EAGAIN, errno.EINPROGRESS]:
@ -149,6 +287,3 @@ class TCPRelay(object):
if now - last_time > 5: if now - last_time > 5:
# TODO sweep timeouts # TODO sweep timeouts
last_time = now last_time = now

View file

@ -71,53 +71,16 @@ import threading
import socket import socket
import logging import logging
import struct import struct
import errno
import encrypt import encrypt
import eventloop import eventloop
import lru_cache import lru_cache
import errno from common import parse_header
BUF_SIZE = 65536 BUF_SIZE = 65536
def parse_header(data):
addrtype = ord(data[0])
dest_addr = None
dest_port = None
header_length = 0
if addrtype == 1:
if len(data) >= 7:
dest_addr = socket.inet_ntoa(data[1:5])
dest_port = struct.unpack('>H', data[5:7])[0]
header_length = 7
else:
logging.warn('[udp] header is too short')
elif addrtype == 3:
if len(data) > 2:
addrlen = ord(data[1])
if len(data) >= 2 + addrlen:
dest_addr = data[2:2 + addrlen]
dest_port = struct.unpack('>H', data[2 + addrlen:4 +
addrlen])[0]
header_length = 4 + addrlen
else:
logging.warn('[udp] header is too short')
else:
logging.warn('[udp] header is too short')
elif addrtype == 4:
if len(data) >= 19:
dest_addr = socket.inet_ntop(socket.AF_INET6, data[1:17])
dest_port = struct.unpack('>H', data[17:19])[0]
header_length = 19
else:
logging.warn('[udp] header is too short')
else:
logging.warn('unsupported addrtype %d' % addrtype)
if dest_addr is None:
return None
return (addrtype, dest_addr, dest_port, header_length)
def client_key(a, b, c, d): def client_key(a, b, c, d):
return '%s:%s:%s:%s' % (a, b, c, d) return '%s:%s:%s:%s' % (a, b, c, d)