update server

This commit is contained in:
clowwindy 2013-05-17 17:12:12 +08:00
parent ef8b741182
commit a8bf892154

202
server.py
View file

@ -26,23 +26,14 @@ if sys.version_info < (2, 6):
import simplejson as json import simplejson as json
else: else:
import json import json
try:
import gevent, gevent.monkey
gevent.monkey.patch_all(dns=gevent.version_info[0]>=1)
except ImportError:
gevent = None
print >>sys.stderr, 'warning: gevent not found, using threading instead'
import socket
import select
import SocketServer
import struct import struct
import string import string
import hashlib import hashlib
import os import os
import logging import logging
import getopt import getopt
import socket
def get_table(key): def get_table(key):
m = hashlib.md5() m = hashlib.md5()
@ -54,108 +45,145 @@ def get_table(key):
table.sort(lambda x, y: int(a % (ord(x) + i) - a % (ord(y) + i))) table.sort(lambda x, y: int(a % (ord(x) + i) - a % (ord(y) + i)))
return table return table
def send_all(sock, data):
bytes_sent = 0
while True:
r = sock.send(data[bytes_sent:])
if r < 0:
return r
bytes_sent += r
if bytes_sent == len(data):
return bytes_sent
class ThreadingTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer): def encrypt(data):
allow_reuse_address = True return data.translate(encrypt_table)
class Socks5Server(SocketServer.StreamRequestHandler): def decrypt(data):
def handle_tcp(self, sock, remote): return data.translate(decrypt_table)
try:
fdset = [sock, remote]
while True:
r, w, e = select.select(fdset, [], [])
if sock in r:
data = sock.recv(4096)
if len(data) <= 0:
break
result = send_all(remote, self.decrypt(data))
if result < len(data):
raise Exception('failed to send all data')
if remote in r:
data = remote.recv(4096)
if len(data) <= 0:
break
result = send_all(sock, self.encrypt(data))
if result < len(data):
raise Exception('failed to send all data')
finally:
sock.close()
remote.close()
def encrypt(self, data): class RemoteHandler(object):
return data.translate(encrypt_table) def __init__(self, conn, local_handler):
self.conn = conn
self.local_handler = local_handler
conn.on('connect', self.on_connect)
conn.on('data', self.on_data)
conn.on('close', self.on_close)
conn.on('end', self.on_end)
conn.connect(local_handler.remote_addr_pair)
def decrypt(self, data): def on_connect(self, s):
return data.translate(decrypt_table) for piece in self.local_handler.cached_pieces:
self.conn.write(decrypt(piece))
# TODO write cached pieces
self.local_handler.stage = 5
def handle(self): def on_data(self, s, data):
try: data = encrypt(data)
sock = self.connection self.local_handler.conn.write(data)
addrtype = ord(self.decrypt(sock.recv(1)))
if addrtype == 1: def on_close(self, s):
addr = socket.inet_ntoa(self.decrypt(self.rfile.read(4))) # self.local_handler.conn.end()
elif addrtype == 3: pass
addr = self.decrypt(
self.rfile.read(ord(self.decrypt(sock.recv(1))))) def on_end(self, s):
elif addrtype == 4: self.local_handler.conn.end()
addr = socket.inet_ntop(socket.AF_INET6, self.decrypt(self.rfile.read(16)))
else:
# not support class LocalHandler(object):
logging.warn('addr_type not support') def on_data(self, s, data):
return if self.stage == 5:
port = struct.unpack('>H', self.decrypt(self.rfile.read(2))) data = decrypt(data)
self.remote_handler.conn.write(data)
return
if self.stage == 0:
try: try:
logging.info('connecting %s:%d' % (addr, port[0])) addrtype = ord(data[0])
remote = socket.create_connection((addr, port[0])) # TODO check cmd == 1
except socket.error, e: if addrtype == 1:
# Connection refused remote_addr = socket.inet_ntoa(data[1:5])
logging.warn(e) remote_port = data[5:7]
header_length = 7
elif addrtype == 4:
remote_addr = socket.inet_ntop(data[1:17])
remote_port = data[17:19]
header_length = 19
elif addrtype == 3:
addr_len = ord(data[1])
remote_addr = data[2:2 + addr_len]
remote_port = data[2 + addr_len:2 + addr_len + 2]
header_length = 2 + addr_len + 2
else:
# TODO check addrtype in (1, 3, 4)
raise
remote_port = struct.unpack('>H', remote_port)[0]
self.remote_addr_pair = (remote_addr, remote_port)
logging.info('connecting %s:%d' % self.remote_addr_pair)
remote_conn = ssloop.Socket()
self.remote_handler = RemoteHandler(remote_conn, self)
if len(data) > header_length:
self.cached_pieces.append(data[header_length:])
# TODO save other bytes
self.stage = 4
return return
self.handle_tcp(sock, remote) except:
except socket.error, e: import traceback
logging.warn(e) traceback.print_exc()
if self.stage == 4:
self.cached_pieces.append(data)
def on_end(self, s):
if self.remote_handler:
self.remote_handler.conn.end()
def on_close(self, s):
pass
# self.remote_handler.conn.end()
def __init__(self, conn):
self.stage = 0
self.remote = None
self.addr_len = 0
self.addr_to_send = ''
self.conn = conn
self.cached_pieces = []
conn.on('data', self.on_data)
conn.on('end', self.on_end)
conn.on('close', self.on_close)
def on_connection(s, conn):
LocalHandler(conn)
if __name__ == '__main__': if __name__ == '__main__':
os.chdir(os.path.dirname(__file__) or '.') os.chdir(os.path.dirname(__file__) or '.')
sys.path.append('./ssloop')
print 'shadowsocks v1.1' import ssloop
print 'shadowsocks v2.0'
with open('config.json', 'rb') as f: with open('config.json', 'rb') as f:
config = json.load(f) config = json.load(f)
SERVER = config['server']
PORT = config['server_port'] PORT = config['server_port']
KEY = config['password'] KEY = config['password']
optlist, args = getopt.getopt(sys.argv[1:], 'p:k:') argv = sys.argv[1:]
if '-6' in sys.argv[1:]:
argv.remove('-6')
optlist, args = getopt.getopt(argv, 'p:k:')
for key, value in optlist: for key, value in optlist:
if key == '-p': if key == '-p':
PORT = int(value) PORT = int(value)
elif key == '-k': elif key == '-k':
KEY = value KEY = value
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)-8s %(message)s', logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S', filemode='a+') datefmt='%Y-%m-%d %H:%M:%S', filemode='a+')
encrypt_table = ''.join(get_table(KEY)) encrypt_table = ''.join(get_table(KEY))
decrypt_table = string.maketrans(encrypt_table, string.maketrans('', '')) decrypt_table = string.maketrans(encrypt_table, string.maketrans('', ''))
if '-6' in sys.argv[1:]:
ThreadingTCPServer.address_family = socket.AF_INET6
try: try:
server = ThreadingTCPServer(('', PORT), Socks5Server)
logging.info("starting server at port %d ..." % PORT) logging.info("starting server at port %d ..." % PORT)
server.serve_forever() loop = ssloop.instance()
except socket.error, e: s = ssloop.Server(('0.0.0.0', PORT))
logging.error(e) s.on('connection', on_connection)
s.listen()
loop.start()
except KeyboardInterrupt:
sys.exit(0)