update server

This commit is contained in:
clowwindy 2013-05-17 17:12:12 +08:00
parent ef8b741182
commit a8bf892154

198
server.py
View file

@ -26,23 +26,14 @@ if sys.version_info < (2, 6):
import simplejson as json import simplejson as json
else: else:
import json import json
try:
import gevent, gevent.monkey
gevent.monkey.patch_all(dns=gevent.version_info[0]>=1)
except ImportError:
gevent = None
print >>sys.stderr, 'warning: gevent not found, using threading instead'
import socket
import select
import SocketServer
import struct import struct
import string import string
import hashlib import hashlib
import os import os
import logging import logging
import getopt import getopt
import socket
def get_table(key): def get_table(key):
m = hashlib.md5() m = hashlib.md5()
@ -54,108 +45,145 @@ def get_table(key):
table.sort(lambda x, y: int(a % (ord(x) + i) - a % (ord(y) + i))) table.sort(lambda x, y: int(a % (ord(x) + i) - a % (ord(y) + i)))
return table return table
def send_all(sock, data):
bytes_sent = 0
while True:
r = sock.send(data[bytes_sent:])
if r < 0:
return r
bytes_sent += r
if bytes_sent == len(data):
return bytes_sent
class ThreadingTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer): def encrypt(data):
allow_reuse_address = True
class Socks5Server(SocketServer.StreamRequestHandler):
def handle_tcp(self, sock, remote):
try:
fdset = [sock, remote]
while True:
r, w, e = select.select(fdset, [], [])
if sock in r:
data = sock.recv(4096)
if len(data) <= 0:
break
result = send_all(remote, self.decrypt(data))
if result < len(data):
raise Exception('failed to send all data')
if remote in r:
data = remote.recv(4096)
if len(data) <= 0:
break
result = send_all(sock, self.encrypt(data))
if result < len(data):
raise Exception('failed to send all data')
finally:
sock.close()
remote.close()
def encrypt(self, data):
return data.translate(encrypt_table) return data.translate(encrypt_table)
def decrypt(self, data):
def decrypt(data):
return data.translate(decrypt_table) return data.translate(decrypt_table)
def handle(self):
class RemoteHandler(object):
def __init__(self, conn, local_handler):
self.conn = conn
self.local_handler = local_handler
conn.on('connect', self.on_connect)
conn.on('data', self.on_data)
conn.on('close', self.on_close)
conn.on('end', self.on_end)
conn.connect(local_handler.remote_addr_pair)
def on_connect(self, s):
for piece in self.local_handler.cached_pieces:
self.conn.write(decrypt(piece))
# TODO write cached pieces
self.local_handler.stage = 5
def on_data(self, s, data):
data = encrypt(data)
self.local_handler.conn.write(data)
def on_close(self, s):
# self.local_handler.conn.end()
pass
def on_end(self, s):
self.local_handler.conn.end()
class LocalHandler(object):
def on_data(self, s, data):
if self.stage == 5:
data = decrypt(data)
self.remote_handler.conn.write(data)
return
if self.stage == 0:
try: try:
sock = self.connection addrtype = ord(data[0])
addrtype = ord(self.decrypt(sock.recv(1))) # TODO check cmd == 1
if addrtype == 1: if addrtype == 1:
addr = socket.inet_ntoa(self.decrypt(self.rfile.read(4))) remote_addr = socket.inet_ntoa(data[1:5])
elif addrtype == 3: remote_port = data[5:7]
addr = self.decrypt( header_length = 7
self.rfile.read(ord(self.decrypt(sock.recv(1)))))
elif addrtype == 4: elif addrtype == 4:
addr = socket.inet_ntop(socket.AF_INET6, self.decrypt(self.rfile.read(16))) remote_addr = socket.inet_ntop(data[1:17])
remote_port = data[17:19]
header_length = 19
elif addrtype == 3:
addr_len = ord(data[1])
remote_addr = data[2:2 + addr_len]
remote_port = data[2 + addr_len:2 + addr_len + 2]
header_length = 2 + addr_len + 2
else: else:
# not support # TODO check addrtype in (1, 3, 4)
logging.warn('addr_type not support') raise
remote_port = struct.unpack('>H', remote_port)[0]
self.remote_addr_pair = (remote_addr, remote_port)
logging.info('connecting %s:%d' % self.remote_addr_pair)
remote_conn = ssloop.Socket()
self.remote_handler = RemoteHandler(remote_conn, self)
if len(data) > header_length:
self.cached_pieces.append(data[header_length:])
# TODO save other bytes
self.stage = 4
return return
port = struct.unpack('>H', self.decrypt(self.rfile.read(2))) except:
try: import traceback
logging.info('connecting %s:%d' % (addr, port[0])) traceback.print_exc()
remote = socket.create_connection((addr, port[0]))
except socket.error, e: if self.stage == 4:
# Connection refused self.cached_pieces.append(data)
logging.warn(e)
return def on_end(self, s):
self.handle_tcp(sock, remote) if self.remote_handler:
except socket.error, e: self.remote_handler.conn.end()
logging.warn(e)
def on_close(self, s):
pass
# self.remote_handler.conn.end()
def __init__(self, conn):
self.stage = 0
self.remote = None
self.addr_len = 0
self.addr_to_send = ''
self.conn = conn
self.cached_pieces = []
conn.on('data', self.on_data)
conn.on('end', self.on_end)
conn.on('close', self.on_close)
def on_connection(s, conn):
LocalHandler(conn)
if __name__ == '__main__': if __name__ == '__main__':
os.chdir(os.path.dirname(__file__) or '.') os.chdir(os.path.dirname(__file__) or '.')
sys.path.append('./ssloop')
print 'shadowsocks v1.1' import ssloop
print 'shadowsocks v2.0'
with open('config.json', 'rb') as f: with open('config.json', 'rb') as f:
config = json.load(f) config = json.load(f)
SERVER = config['server']
PORT = config['server_port'] PORT = config['server_port']
KEY = config['password'] KEY = config['password']
optlist, args = getopt.getopt(sys.argv[1:], 'p:k:') argv = sys.argv[1:]
if '-6' in sys.argv[1:]:
argv.remove('-6')
optlist, args = getopt.getopt(argv, 'p:k:')
for key, value in optlist: for key, value in optlist:
if key == '-p': if key == '-p':
PORT = int(value) PORT = int(value)
elif key == '-k': elif key == '-k':
KEY = value KEY = value
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)-8s %(message)s', logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S', filemode='a+') datefmt='%Y-%m-%d %H:%M:%S', filemode='a+')
encrypt_table = ''.join(get_table(KEY)) encrypt_table = ''.join(get_table(KEY))
decrypt_table = string.maketrans(encrypt_table, string.maketrans('', '')) decrypt_table = string.maketrans(encrypt_table, string.maketrans('', ''))
if '-6' in sys.argv[1:]:
ThreadingTCPServer.address_family = socket.AF_INET6
try: try:
server = ThreadingTCPServer(('', PORT), Socks5Server)
logging.info("starting server at port %d ..." % PORT) logging.info("starting server at port %d ..." % PORT)
server.serve_forever() loop = ssloop.instance()
except socket.error, e: s = ssloop.Server(('0.0.0.0', PORT))
logging.error(e) s.on('connection', on_connection)
s.listen()
loop.start()
except KeyboardInterrupt:
sys.exit(0)