This commit is contained in:
Shengdun Hua 2017-02-08 05:56:55 +00:00 committed by GitHub
commit 5a89d7429c
19 changed files with 3444 additions and 1 deletions

61
.gitignore vendored Normal file
View file

@ -0,0 +1,61 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml
# Translations
*.mo
*.pot
# Django stuff:
*.log
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# database file
*.sqlite
*.sqlite3
*.db
# temporary file
*.swp

View file

@ -1 +1 @@
Removed according to regulations.
Just for learn.

18
__init__.py Normal file
View file

@ -0,0 +1,18 @@
#!/usr/bin/python
#
# Copyright 2012-2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement

481
asyncdns.py Normal file
View file

@ -0,0 +1,481 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2014-2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement
import os
import socket
import struct
import re
import logging
from shadowsocks import common, lru_cache, eventloop, shell
CACHE_SWEEP_INTERVAL = 30
VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
common.patch_socket()
# rfc1035
# format
# +---------------------+
# | Header |
# +---------------------+
# | Question | the question for the name server
# +---------------------+
# | Answer | RRs answering the question
# +---------------------+
# | Authority | RRs pointing toward an authority
# +---------------------+
# | Additional | RRs holding additional information
# +---------------------+
#
# header
# 1 1 1 1 1 1
# 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# | ID |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# |QR| Opcode |AA|TC|RD|RA| Z | RCODE |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# | QDCOUNT |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# | ANCOUNT |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# | NSCOUNT |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# | ARCOUNT |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
QTYPE_ANY = 255
QTYPE_A = 1
QTYPE_AAAA = 28
QTYPE_CNAME = 5
QTYPE_NS = 2
QCLASS_IN = 1
def build_address(address):
address = address.strip(b'.')
labels = address.split(b'.')
results = []
for label in labels:
l = len(label)
if l > 63:
return None
results.append(common.chr(l))
results.append(label)
results.append(b'\0')
return b''.join(results)
def build_request(address, qtype):
request_id = os.urandom(2)
header = struct.pack('!BBHHHH', 1, 0, 1, 0, 0, 0)
addr = build_address(address)
qtype_qclass = struct.pack('!HH', qtype, QCLASS_IN)
return request_id + header + addr + qtype_qclass
def parse_ip(addrtype, data, length, offset):
if addrtype == QTYPE_A:
return socket.inet_ntop(socket.AF_INET, data[offset:offset + length])
elif addrtype == QTYPE_AAAA:
return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length])
elif addrtype in [QTYPE_CNAME, QTYPE_NS]:
return parse_name(data, offset)[1]
else:
return data[offset:offset + length]
def parse_name(data, offset):
p = offset
labels = []
l = common.ord(data[p])
while l > 0:
if (l & (128 + 64)) == (128 + 64):
# pointer
pointer = struct.unpack('!H', data[p:p + 2])[0]
pointer &= 0x3FFF
r = parse_name(data, pointer)
labels.append(r[1])
p += 2
# pointer is the end
return p - offset, b'.'.join(labels)
else:
labels.append(data[p + 1:p + 1 + l])
p += 1 + l
l = common.ord(data[p])
return p - offset + 1, b'.'.join(labels)
# rfc1035
# record
# 1 1 1 1 1 1
# 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# | |
# / /
# / NAME /
# | |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# | TYPE |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# | CLASS |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# | TTL |
# | |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
# | RDLENGTH |
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
# / RDATA /
# / /
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
def parse_record(data, offset, question=False):
nlen, name = parse_name(data, offset)
if not question:
record_type, record_class, record_ttl, record_rdlength = struct.unpack(
'!HHiH', data[offset + nlen:offset + nlen + 10]
)
ip = parse_ip(record_type, data, record_rdlength, offset + nlen + 10)
return nlen + 10 + record_rdlength, \
(name, ip, record_type, record_class, record_ttl)
else:
record_type, record_class = struct.unpack(
'!HH', data[offset + nlen:offset + nlen + 4]
)
return nlen + 4, (name, None, record_type, record_class, None, None)
def parse_header(data):
if len(data) >= 12:
header = struct.unpack('!HBBHHHH', data[:12])
res_id = header[0]
res_qr = header[1] & 128
res_tc = header[1] & 2
res_ra = header[2] & 128
res_rcode = header[2] & 15
# assert res_tc == 0
# assert res_rcode in [0, 3]
res_qdcount = header[3]
res_ancount = header[4]
res_nscount = header[5]
res_arcount = header[6]
return (res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount,
res_ancount, res_nscount, res_arcount)
return None
def parse_response(data):
try:
if len(data) >= 12:
header = parse_header(data)
if not header:
return None
res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount, \
res_ancount, res_nscount, res_arcount = header
qds = []
ans = []
offset = 12
for i in range(0, res_qdcount):
l, r = parse_record(data, offset, True)
offset += l
if r:
qds.append(r)
for i in range(0, res_ancount):
l, r = parse_record(data, offset)
offset += l
if r:
ans.append(r)
for i in range(0, res_nscount):
l, r = parse_record(data, offset)
offset += l
for i in range(0, res_arcount):
l, r = parse_record(data, offset)
offset += l
response = DNSResponse()
if qds:
response.hostname = qds[0][0]
for an in qds:
response.questions.append((an[1], an[2], an[3]))
for an in ans:
response.answers.append((an[1], an[2], an[3]))
return response
except Exception as e:
shell.print_exception(e)
return None
def is_valid_hostname(hostname):
if len(hostname) > 255:
return False
if hostname[-1] == b'.':
hostname = hostname[:-1]
return all(VALID_HOSTNAME.match(x) for x in hostname.split(b'.'))
class DNSResponse(object):
def __init__(self):
self.hostname = None
self.questions = [] # each: (addr, type, class)
self.answers = [] # each: (addr, type, class)
def __str__(self):
return '%s: %s' % (self.hostname, str(self.answers))
STATUS_IPV4 = 0
STATUS_IPV6 = 1
class DNSResolver(object):
def __init__(self):
self._loop = None
self._hosts = {}
self._hostname_status = {}
self._hostname_to_cb = {}
self._cb_to_hostname = {}
self._cache = lru_cache.LRUCache(timeout=300)
self._sock = None
self._servers = None
self._parse_resolv()
self._parse_hosts()
# TODO monitor hosts change and reload hosts
# TODO parse /etc/gai.conf and follow its rules
def _parse_resolv(self):
self._servers = []
try:
with open('/etc/resolv.conf', 'rb') as f:
content = f.readlines()
for line in content:
line = line.strip()
if line:
if line.startswith(b'nameserver'):
parts = line.split()
if len(parts) >= 2:
server = parts[1]
if common.is_ip(server) == socket.AF_INET:
if type(server) != str:
server = server.decode('utf8')
self._servers.append(server)
except IOError:
pass
if not self._servers:
self._servers = ['8.8.4.4', '8.8.8.8']
def _parse_hosts(self):
etc_path = '/etc/hosts'
if 'WINDIR' in os.environ:
etc_path = os.environ['WINDIR'] + '/system32/drivers/etc/hosts'
try:
with open(etc_path, 'rb') as f:
for line in f.readlines():
line = line.strip()
parts = line.split()
if len(parts) >= 2:
ip = parts[0]
if common.is_ip(ip):
for i in range(1, len(parts)):
hostname = parts[i]
if hostname:
self._hosts[hostname] = ip
except IOError:
self._hosts['localhost'] = '127.0.0.1'
def add_to_loop(self, loop):
if self._loop:
raise Exception('already add to loop')
self._loop = loop
# TODO when dns server is IPv6
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
socket.SOL_UDP)
self._sock.setblocking(False)
loop.add(self._sock, eventloop.POLL_IN, self)
loop.add_periodic(self.handle_periodic)
def _call_callback(self, hostname, ip, error=None):
callbacks = self._hostname_to_cb.get(hostname, [])
for callback in callbacks:
if callback in self._cb_to_hostname:
del self._cb_to_hostname[callback]
if ip or error:
callback((hostname, ip), error)
else:
callback((hostname, None),
Exception('unknown hostname %s' % hostname))
if hostname in self._hostname_to_cb:
del self._hostname_to_cb[hostname]
if hostname in self._hostname_status:
del self._hostname_status[hostname]
def _handle_data(self, data):
response = parse_response(data)
if response and response.hostname:
hostname = response.hostname
ip = None
for answer in response.answers:
if answer[1] in (QTYPE_A, QTYPE_AAAA) and \
answer[2] == QCLASS_IN:
ip = answer[0]
break
if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \
== STATUS_IPV4:
self._hostname_status[hostname] = STATUS_IPV6
self._send_req(hostname, QTYPE_AAAA)
else:
if ip:
self._cache[hostname] = ip
self._call_callback(hostname, ip)
elif self._hostname_status.get(hostname, None) == STATUS_IPV6:
for question in response.questions:
if question[1] == QTYPE_AAAA:
self._call_callback(hostname, None)
break
def handle_event(self, sock, event):
if sock != self._sock:
return
if event & eventloop.POLL_ERR:
logging.error('dns socket err')
self._loop.remove(self._sock)
self._sock.close()
# TODO when dns server is IPv6
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
socket.SOL_UDP)
self._sock.setblocking(False)
self._loop.add(self._sock, eventloop.POLL_IN, self)
else:
data, addr = sock.recvfrom(1024)
if addr[0] not in self._servers:
logging.warn('received a packet other than our dns')
return
self._handle_data(data)
def handle_periodic(self):
self._cache.sweep()
def remove_callback(self, callback):
hostname = self._cb_to_hostname.get(callback)
if hostname:
del self._cb_to_hostname[callback]
arr = self._hostname_to_cb.get(hostname, None)
if arr:
arr.remove(callback)
if not arr:
del self._hostname_to_cb[hostname]
if hostname in self._hostname_status:
del self._hostname_status[hostname]
def _send_req(self, hostname, qtype):
req = build_request(hostname, qtype)
for server in self._servers:
logging.debug('resolving %s with type %d using server %s',
hostname, qtype, server)
self._sock.sendto(req, (server, 53))
def resolve(self, hostname, callback):
if type(hostname) != bytes:
hostname = hostname.encode('utf8')
if not hostname:
callback(None, Exception('empty hostname'))
elif common.is_ip(hostname):
callback((hostname, hostname), None)
elif hostname in self._hosts:
logging.debug('hit hosts: %s', hostname)
ip = self._hosts[hostname]
callback((hostname, ip), None)
elif hostname in self._cache:
logging.debug('hit cache: %s', hostname)
ip = self._cache[hostname]
callback((hostname, ip), None)
else:
if not is_valid_hostname(hostname):
callback(None, Exception('invalid hostname: %s' % hostname))
return
arr = self._hostname_to_cb.get(hostname, None)
if not arr:
self._hostname_status[hostname] = STATUS_IPV4
self._send_req(hostname, QTYPE_A)
self._hostname_to_cb[hostname] = [callback]
self._cb_to_hostname[callback] = hostname
else:
arr.append(callback)
# TODO send again only if waited too long
self._send_req(hostname, QTYPE_A)
def close(self):
if self._sock:
if self._loop:
self._loop.remove_periodic(self.handle_periodic)
self._loop.remove(self._sock)
self._sock.close()
self._sock = None
def test():
dns_resolver = DNSResolver()
loop = eventloop.EventLoop()
dns_resolver.add_to_loop(loop)
global counter
counter = 0
def make_callback():
global counter
def callback(result, error):
global counter
# TODO: what can we assert?
print(result, error)
counter += 1
if counter == 9:
dns_resolver.close()
loop.stop()
a_callback = callback
return a_callback
assert(make_callback() != make_callback())
dns_resolver.resolve(b'google.com', make_callback())
dns_resolver.resolve('google.com', make_callback())
dns_resolver.resolve('example.com', make_callback())
dns_resolver.resolve('ipv6.google.com', make_callback())
dns_resolver.resolve('www.facebook.com', make_callback())
dns_resolver.resolve('ns2.google.com', make_callback())
dns_resolver.resolve('invalid.@!#$%^&$@.hostname', make_callback())
dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo'
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'long.hostname', make_callback())
dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo'
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
'long.hostname', make_callback())
loop.run()
if __name__ == '__main__':
test()

33
client.py Normal file
View file

@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, \
with_statement
import logging
import os
import sys
path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, path)
from shadowsocks.eventloop import EventLoop
from shadowsocks.tcprelay import TcpRelay, TcpRelayClientHanler
FORMATTER = '%(asctime)s - %(levelname)s - %(message)s'
LOGGING_LEVEL = logging.INFO
logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER)
LISTEN_ADDR = ('127.0.0.1', 1080)
REMOTE_ADDR = ('127.0.0.1', 9000)
def main():
loop = EventLoop()
relay = TcpRelay(TcpRelayClientHanler, LISTEN_ADDR, REMOTE_ADDR)
relay.add_to_loop(loop)
loop.run()
if __name__ == '__main__':
main()

281
common.py Normal file
View file

@ -0,0 +1,281 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright 2013-2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement
import socket
import struct
import logging
def compat_ord(s):
if type(s) == int:
return s
return _ord(s)
def compat_chr(d):
if bytes == str:
return _chr(d)
return bytes([d])
_ord = ord
_chr = chr
ord = compat_ord
chr = compat_chr
def to_bytes(s):
if bytes != str:
if type(s) == str:
return s.encode('utf-8')
return s
def to_str(s):
if bytes != str:
if type(s) == bytes:
return s.decode('utf-8')
return s
def inet_ntop(family, ipstr):
if family == socket.AF_INET:
return to_bytes(socket.inet_ntoa(ipstr))
elif family == socket.AF_INET6:
import re
v6addr = ':'.join(('%02X%02X' % (ord(i), ord(j))).lstrip('0')
for i, j in zip(ipstr[::2], ipstr[1::2]))
v6addr = re.sub('::+', '::', v6addr, count=1)
return to_bytes(v6addr)
def inet_pton(family, addr):
addr = to_str(addr)
if family == socket.AF_INET:
return socket.inet_aton(addr)
elif family == socket.AF_INET6:
if '.' in addr: # a v4 addr
v4addr = addr[addr.rindex(':') + 1:]
v4addr = socket.inet_aton(v4addr)
v4addr = map(lambda x: ('%02X' % ord(x)), v4addr)
v4addr.insert(2, ':')
newaddr = addr[:addr.rindex(':') + 1] + ''.join(v4addr)
return inet_pton(family, newaddr)
dbyts = [0] * 8 # 8 groups
grps = addr.split(':')
for i, v in enumerate(grps):
if v:
dbyts[i] = int(v, 16)
else:
for j, w in enumerate(grps[::-1]):
if w:
dbyts[7 - j] = int(w, 16)
else:
break
break
return b''.join((chr(i // 256) + chr(i % 256)) for i in dbyts)
else:
raise RuntimeError("What family?")
def is_ip(address):
for family in (socket.AF_INET, socket.AF_INET6):
try:
if type(address) != str:
address = address.decode('utf8')
inet_pton(family, address)
return family
except (TypeError, ValueError, OSError, IOError):
pass
return False
def patch_socket():
if not hasattr(socket, 'inet_pton'):
socket.inet_pton = inet_pton
if not hasattr(socket, 'inet_ntop'):
socket.inet_ntop = inet_ntop
patch_socket()
ADDRTYPE_IPV4 = 1
ADDRTYPE_IPV6 = 4
ADDRTYPE_HOST = 3
def pack_addr(address):
address_str = to_str(address)
for family in (socket.AF_INET, socket.AF_INET6):
try:
r = socket.inet_pton(family, address_str)
if family == socket.AF_INET6:
return b'\x04' + r
else:
return b'\x01' + r
except (TypeError, ValueError, OSError, IOError):
pass
if len(address) > 255:
address = address[:255] # TODO
return b'\x03' + chr(len(address)) + address
def parse_header(data):
addrtype = ord(data[0])
dest_addr = None
dest_port = None
header_length = 0
if addrtype == ADDRTYPE_IPV4:
if len(data) >= 7:
dest_addr = socket.inet_ntoa(data[1:5])
dest_port = struct.unpack('>H', data[5:7])[0]
header_length = 7
else:
logging.warn('header is too short')
elif addrtype == ADDRTYPE_HOST:
if len(data) > 2:
addrlen = ord(data[1])
if len(data) >= 2 + addrlen:
dest_addr = data[2:2 + addrlen]
dest_port = struct.unpack('>H', data[2 + addrlen:4 +
addrlen])[0]
header_length = 4 + addrlen
else:
logging.warn('header is too short')
else:
logging.warn('header is too short')
elif addrtype == ADDRTYPE_IPV6:
if len(data) >= 19:
dest_addr = socket.inet_ntop(socket.AF_INET6, data[1:17])
dest_port = struct.unpack('>H', data[17:19])[0]
header_length = 19
else:
logging.warn('header is too short')
else:
logging.warn('unsupported addrtype %d, maybe wrong password or '
'encryption method' % addrtype)
if dest_addr is None:
return None
return addrtype, to_bytes(dest_addr), dest_port, header_length
class IPNetwork(object):
ADDRLENGTH = {socket.AF_INET: 32, socket.AF_INET6: 128, False: 0}
def __init__(self, addrs):
self._network_list_v4 = []
self._network_list_v6 = []
if type(addrs) == str:
addrs = addrs.split(',')
list(map(self.add_network, addrs))
def add_network(self, addr):
if addr is "":
return
block = addr.split('/')
addr_family = is_ip(block[0])
addr_len = IPNetwork.ADDRLENGTH[addr_family]
if addr_family is socket.AF_INET:
ip, = struct.unpack("!I", socket.inet_aton(block[0]))
elif addr_family is socket.AF_INET6:
hi, lo = struct.unpack("!QQ", inet_pton(addr_family, block[0]))
ip = (hi << 64) | lo
else:
raise Exception("Not a valid CIDR notation: %s" % addr)
if len(block) is 1:
prefix_size = 0
while (ip & 1) == 0 and ip is not 0:
ip >>= 1
prefix_size += 1
logging.warn("You did't specify CIDR routing prefix size for %s, "
"implicit treated as %s/%d" % (addr, addr, addr_len))
elif block[1].isdigit() and int(block[1]) <= addr_len:
prefix_size = addr_len - int(block[1])
ip >>= prefix_size
else:
raise Exception("Not a valid CIDR notation: %s" % addr)
if addr_family is socket.AF_INET:
self._network_list_v4.append((ip, prefix_size))
else:
self._network_list_v6.append((ip, prefix_size))
def __contains__(self, addr):
addr_family = is_ip(addr)
if addr_family is socket.AF_INET:
ip, = struct.unpack("!I", socket.inet_aton(addr))
return any(map(lambda n_ps: n_ps[0] == ip >> n_ps[1],
self._network_list_v4))
elif addr_family is socket.AF_INET6:
hi, lo = struct.unpack("!QQ", inet_pton(addr_family, addr))
ip = (hi << 64) | lo
return any(map(lambda n_ps: n_ps[0] == ip >> n_ps[1],
self._network_list_v6))
else:
return False
def test_inet_conv():
ipv4 = b'8.8.4.4'
b = inet_pton(socket.AF_INET, ipv4)
assert inet_ntop(socket.AF_INET, b) == ipv4
ipv6 = b'2404:6800:4005:805::1011'
b = inet_pton(socket.AF_INET6, ipv6)
assert inet_ntop(socket.AF_INET6, b) == ipv6
def test_parse_header():
assert parse_header(b'\x03\x0ewww.google.com\x00\x50') == \
(3, b'www.google.com', 80, 18)
assert parse_header(b'\x01\x08\x08\x08\x08\x00\x35') == \
(1, b'8.8.8.8', 53, 7)
assert parse_header((b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00'
b'\x00\x10\x11\x00\x50')) == \
(4, b'2404:6800:4005:805::1011', 80, 19)
def test_pack_header():
assert pack_addr(b'8.8.8.8') == b'\x01\x08\x08\x08\x08'
assert pack_addr(b'2404:6800:4005:805::1011') == \
b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00\x00\x10\x11'
assert pack_addr(b'www.google.com') == b'\x03\x0ewww.google.com'
def test_ip_network():
ip_network = IPNetwork('127.0.0.0/24,::ff:1/112,::1,192.168.1.1,192.0.2.0')
assert '127.0.0.1' in ip_network
assert '127.0.1.1' not in ip_network
assert ':ff:ffff' in ip_network
assert '::ffff:1' not in ip_network
assert '::1' in ip_network
assert '::2' not in ip_network
assert '192.168.1.1' in ip_network
assert '192.168.1.2' not in ip_network
assert '192.0.2.1' in ip_network
assert '192.0.3.1' in ip_network # 192.0.2.0 is treated as 192.0.2.0/23
assert 'www.google.com' not in ip_network
if __name__ == '__main__':
test_inet_conv()
test_parse_header()
test_pack_header()
test_ip_network()

18
crypto/__init__.py Normal file
View file

@ -0,0 +1,18 @@
#!/usr/bin/env python
#
# Copyright 2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement

181
crypto/openssl.py Normal file
View file

@ -0,0 +1,181 @@
#!/usr/bin/env python
#
# Copyright 2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement
from ctypes import c_char_p, c_int, c_long, byref,\
create_string_buffer, c_void_p
from shadowsocks import common
from shadowsocks.crypto import util
__all__ = ['ciphers']
libcrypto = None
loaded = False
buf_size = 2048
def load_openssl():
global loaded, libcrypto, buf
libcrypto = util.find_library(('crypto', 'eay32'),
'EVP_get_cipherbyname',
'libcrypto')
if libcrypto is None:
raise Exception('libcrypto(OpenSSL) not found')
libcrypto.EVP_get_cipherbyname.restype = c_void_p
libcrypto.EVP_CIPHER_CTX_new.restype = c_void_p
libcrypto.EVP_CipherInit_ex.argtypes = (c_void_p, c_void_p, c_char_p,
c_char_p, c_char_p, c_int)
libcrypto.EVP_CipherUpdate.argtypes = (c_void_p, c_void_p, c_void_p,
c_char_p, c_int)
libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,)
libcrypto.EVP_CIPHER_CTX_free.argtypes = (c_void_p,)
if hasattr(libcrypto, 'OpenSSL_add_all_ciphers'):
libcrypto.OpenSSL_add_all_ciphers()
buf = create_string_buffer(buf_size)
loaded = True
def load_cipher(cipher_name):
func_name = 'EVP_' + cipher_name.replace('-', '_')
if bytes != str:
func_name = str(func_name, 'utf-8')
cipher = getattr(libcrypto, func_name, None)
if cipher:
cipher.restype = c_void_p
return cipher()
return None
class OpenSSLCrypto(object):
def __init__(self, cipher_name, key, iv, op):
self._ctx = None
if not loaded:
load_openssl()
cipher_name = common.to_bytes(cipher_name)
cipher = libcrypto.EVP_get_cipherbyname(cipher_name)
if not cipher:
cipher = load_cipher(cipher_name)
if not cipher:
raise Exception('cipher %s not found in libcrypto' % cipher_name)
key_ptr = c_char_p(key)
iv_ptr = c_char_p(iv)
self._ctx = libcrypto.EVP_CIPHER_CTX_new()
if not self._ctx:
raise Exception('can not create cipher context')
r = libcrypto.EVP_CipherInit_ex(self._ctx, cipher, None,
key_ptr, iv_ptr, c_int(op))
if not r:
self.clean()
raise Exception('can not initialize cipher context')
def update(self, data):
global buf_size, buf
cipher_out_len = c_long(0)
l = len(data)
if buf_size < l:
buf_size = l * 2
buf = create_string_buffer(buf_size)
libcrypto.EVP_CipherUpdate(self._ctx, byref(buf),
byref(cipher_out_len), c_char_p(data), l)
# buf is copied to a str object when we access buf.raw
return buf.raw[:cipher_out_len.value]
def __del__(self):
self.clean()
def clean(self):
if self._ctx:
libcrypto.EVP_CIPHER_CTX_cleanup(self._ctx)
libcrypto.EVP_CIPHER_CTX_free(self._ctx)
ciphers = {
'aes-128-cfb': (16, 16, OpenSSLCrypto),
'aes-192-cfb': (24, 16, OpenSSLCrypto),
'aes-256-cfb': (32, 16, OpenSSLCrypto),
'aes-128-ofb': (16, 16, OpenSSLCrypto),
'aes-192-ofb': (24, 16, OpenSSLCrypto),
'aes-256-ofb': (32, 16, OpenSSLCrypto),
'aes-128-ctr': (16, 16, OpenSSLCrypto),
'aes-192-ctr': (24, 16, OpenSSLCrypto),
'aes-256-ctr': (32, 16, OpenSSLCrypto),
'aes-128-cfb8': (16, 16, OpenSSLCrypto),
'aes-192-cfb8': (24, 16, OpenSSLCrypto),
'aes-256-cfb8': (32, 16, OpenSSLCrypto),
'aes-128-cfb1': (16, 16, OpenSSLCrypto),
'aes-192-cfb1': (24, 16, OpenSSLCrypto),
'aes-256-cfb1': (32, 16, OpenSSLCrypto),
'bf-cfb': (16, 8, OpenSSLCrypto),
'camellia-128-cfb': (16, 16, OpenSSLCrypto),
'camellia-192-cfb': (24, 16, OpenSSLCrypto),
'camellia-256-cfb': (32, 16, OpenSSLCrypto),
'cast5-cfb': (16, 8, OpenSSLCrypto),
'des-cfb': (8, 8, OpenSSLCrypto),
'idea-cfb': (16, 8, OpenSSLCrypto),
'rc2-cfb': (16, 8, OpenSSLCrypto),
'rc4': (16, 0, OpenSSLCrypto),
'seed-cfb': (16, 16, OpenSSLCrypto),
}
def run_method(method):
cipher = OpenSSLCrypto(method, b'k' * 32, b'i' * 16, 1)
decipher = OpenSSLCrypto(method, b'k' * 32, b'i' * 16, 0)
util.run_cipher(cipher, decipher)
def test_aes_128_cfb():
run_method('aes-128-cfb')
def test_aes_256_cfb():
run_method('aes-256-cfb')
def test_aes_128_cfb8():
run_method('aes-128-cfb8')
def test_aes_256_ofb():
run_method('aes-256-ofb')
def test_aes_256_ctr():
run_method('aes-256-ctr')
def test_bf_cfb():
run_method('bf-cfb')
def test_rc4():
run_method('rc4')
if __name__ == '__main__':
test_aes_128_cfb()

51
crypto/rc4_md5.py Normal file
View file

@ -0,0 +1,51 @@
#!/usr/bin/env python
#
# Copyright 2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement
import hashlib
from shadowsocks.crypto import openssl
__all__ = ['ciphers']
def create_cipher(alg, key, iv, op, key_as_bytes=0, d=None, salt=None,
i=1, padding=1):
md5 = hashlib.md5()
md5.update(key)
md5.update(iv)
rc4_key = md5.digest()
return openssl.OpenSSLCrypto(b'rc4', rc4_key, b'', op)
ciphers = {
'rc4-md5': (16, 16, create_cipher),
}
def test():
from shadowsocks.crypto import util
cipher = create_cipher('rc4-md5', b'k' * 32, b'i' * 16, 1)
decipher = create_cipher('rc4-md5', b'k' * 32, b'i' * 16, 0)
util.run_cipher(cipher, decipher)
if __name__ == '__main__':
test()

120
crypto/sodium.py Normal file
View file

@ -0,0 +1,120 @@
#!/usr/bin/env python
#
# Copyright 2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement
from ctypes import c_char_p, c_int, c_ulonglong, byref, \
create_string_buffer, c_void_p
from shadowsocks.crypto import util
__all__ = ['ciphers']
libsodium = None
loaded = False
buf_size = 2048
# for salsa20 and chacha20
BLOCK_SIZE = 64
def load_libsodium():
global loaded, libsodium, buf
libsodium = util.find_library('sodium', 'crypto_stream_salsa20_xor_ic',
'libsodium')
if libsodium is None:
raise Exception('libsodium not found')
libsodium.crypto_stream_salsa20_xor_ic.restype = c_int
libsodium.crypto_stream_salsa20_xor_ic.argtypes = (c_void_p, c_char_p,
c_ulonglong,
c_char_p, c_ulonglong,
c_char_p)
libsodium.crypto_stream_chacha20_xor_ic.restype = c_int
libsodium.crypto_stream_chacha20_xor_ic.argtypes = (c_void_p, c_char_p,
c_ulonglong,
c_char_p, c_ulonglong,
c_char_p)
buf = create_string_buffer(buf_size)
loaded = True
class SodiumCrypto(object):
def __init__(self, cipher_name, key, iv, op):
if not loaded:
load_libsodium()
self.key = key
self.iv = iv
self.key_ptr = c_char_p(key)
self.iv_ptr = c_char_p(iv)
if cipher_name == 'salsa20':
self.cipher = libsodium.crypto_stream_salsa20_xor_ic
elif cipher_name == 'chacha20':
self.cipher = libsodium.crypto_stream_chacha20_xor_ic
else:
raise Exception('Unknown cipher')
# byte counter, not block counter
self.counter = 0
def update(self, data):
global buf_size, buf
l = len(data)
# we can only prepend some padding to make the encryption align to
# blocks
padding = self.counter % BLOCK_SIZE
if buf_size < padding + l:
buf_size = (padding + l) * 2
buf = create_string_buffer(buf_size)
if padding:
data = (b'\0' * padding) + data
self.cipher(byref(buf), c_char_p(data), padding + l,
self.iv_ptr, int(self.counter / BLOCK_SIZE), self.key_ptr)
self.counter += l
# buf is copied to a str object when we access buf.raw
# strip off the padding
return buf.raw[padding:padding + l]
ciphers = {
'salsa20': (32, 8, SodiumCrypto),
'chacha20': (32, 8, SodiumCrypto),
}
def test_salsa20():
cipher = SodiumCrypto('salsa20', b'k' * 32, b'i' * 16, 1)
decipher = SodiumCrypto('salsa20', b'k' * 32, b'i' * 16, 0)
util.run_cipher(cipher, decipher)
def test_chacha20():
cipher = SodiumCrypto('chacha20', b'k' * 32, b'i' * 16, 1)
decipher = SodiumCrypto('chacha20', b'k' * 32, b'i' * 16, 0)
util.run_cipher(cipher, decipher)
if __name__ == '__main__':
test_chacha20()
test_salsa20()

174
crypto/table.py Normal file
View file

@ -0,0 +1,174 @@
# !/usr/bin/env python
#
# Copyright 2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement
import string
import struct
import hashlib
__all__ = ['ciphers']
cached_tables = {}
if hasattr(string, 'maketrans'):
maketrans = string.maketrans
translate = string.translate
else:
maketrans = bytes.maketrans
translate = bytes.translate
def get_table(key):
m = hashlib.md5()
m.update(key)
s = m.digest()
a, b = struct.unpack('<QQ', s)
table = maketrans(b'', b'')
table = [table[i: i + 1] for i in range(len(table))]
for i in range(1, 1024):
table.sort(key=lambda x: int(a % (ord(x) + i)))
return table
def init_table(key):
if key not in cached_tables:
encrypt_table = b''.join(get_table(key))
decrypt_table = maketrans(encrypt_table, maketrans(b'', b''))
cached_tables[key] = [encrypt_table, decrypt_table]
return cached_tables[key]
class TableCipher(object):
def __init__(self, cipher_name, key, iv, op):
self._encrypt_table, self._decrypt_table = init_table(key)
self._op = op
def update(self, data):
if self._op:
return translate(data, self._encrypt_table)
else:
return translate(data, self._decrypt_table)
ciphers = {
'table': (0, 0, TableCipher)
}
def test_table_result():
from shadowsocks.common import ord
target1 = [
[60, 53, 84, 138, 217, 94, 88, 23, 39, 242, 219, 35, 12, 157, 165, 181,
255, 143, 83, 247, 162, 16, 31, 209, 190, 171, 115, 65, 38, 41, 21,
245, 236, 46, 121, 62, 166, 233, 44, 154, 153, 145, 230, 49, 128, 216,
173, 29, 241, 119, 64, 229, 194, 103, 131, 110, 26, 197, 218, 59, 204,
56, 27, 34, 141, 221, 149, 239, 192, 195, 24, 155, 170, 183, 11, 254,
213, 37, 137, 226, 75, 203, 55, 19, 72, 248, 22, 129, 33, 175, 178,
10, 198, 71, 77, 36, 113, 167, 48, 2, 117, 140, 142, 66, 199, 232,
243, 32, 123, 54, 51, 82, 57, 177, 87, 251, 150, 196, 133, 5, 253,
130, 8, 184, 14, 152, 231, 3, 186, 159, 76, 89, 228, 205, 156, 96,
163, 146, 18, 91, 132, 85, 80, 109, 172, 176, 105, 13, 50, 235, 127,
0, 189, 95, 98, 136, 250, 200, 108, 179, 211, 214, 106, 168, 78, 79,
74, 210, 30, 73, 201, 151, 208, 114, 101, 174, 92, 52, 120, 240, 15,
169, 220, 182, 81, 224, 43, 185, 40, 99, 180, 17, 212, 158, 42, 90, 9,
191, 45, 6, 25, 4, 222, 67, 126, 1, 116, 124, 206, 69, 61, 7, 68, 97,
202, 63, 244, 20, 28, 58, 93, 134, 104, 144, 227, 147, 102, 118, 135,
148, 47, 238, 86, 112, 122, 70, 107, 215, 100, 139, 223, 225, 164,
237, 111, 125, 207, 160, 187, 246, 234, 161, 188, 193, 249, 252],
[151, 205, 99, 127, 201, 119, 199, 211, 122, 196, 91, 74, 12, 147, 124,
180, 21, 191, 138, 83, 217, 30, 86, 7, 70, 200, 56, 62, 218, 47, 168,
22, 107, 88, 63, 11, 95, 77, 28, 8, 188, 29, 194, 186, 38, 198, 33,
230, 98, 43, 148, 110, 177, 1, 109, 82, 61, 112, 219, 59, 0, 210, 35,
215, 50, 27, 103, 203, 212, 209, 235, 93, 84, 169, 166, 80, 130, 94,
164, 165, 142, 184, 111, 18, 2, 141, 232, 114, 6, 131, 195, 139, 176,
220, 5, 153, 135, 213, 154, 189, 238, 174, 226, 53, 222, 146, 162,
236, 158, 143, 55, 244, 233, 96, 173, 26, 206, 100, 227, 49, 178, 34,
234, 108, 207, 245, 204, 150, 44, 87, 121, 54, 140, 118, 221, 228,
155, 78, 3, 239, 101, 64, 102, 17, 223, 41, 137, 225, 229, 66, 116,
171, 125, 40, 39, 71, 134, 13, 193, 129, 247, 251, 20, 136, 242, 14,
36, 97, 163, 181, 72, 25, 144, 46, 175, 89, 145, 113, 90, 159, 190,
15, 183, 73, 123, 187, 128, 248, 252, 152, 24, 197, 68, 253, 52, 69,
117, 57, 92, 104, 157, 170, 214, 81, 60, 133, 208, 246, 172, 23, 167,
160, 192, 76, 161, 237, 45, 4, 58, 10, 182, 65, 202, 240, 185, 241,
79, 224, 132, 51, 42, 126, 105, 37, 250, 149, 32, 243, 231, 67, 179,
48, 9, 106, 216, 31, 249, 19, 85, 254, 156, 115, 255, 120, 75, 16]]
target2 = [
[124, 30, 170, 247, 27, 127, 224, 59, 13, 22, 196, 76, 72, 154, 32,
209, 4, 2, 131, 62, 101, 51, 230, 9, 166, 11, 99, 80, 208, 112, 36,
248, 81, 102, 130, 88, 218, 38, 168, 15, 241, 228, 167, 117, 158, 41,
10, 180, 194, 50, 204, 243, 246, 251, 29, 198, 219, 210, 195, 21, 54,
91, 203, 221, 70, 57, 183, 17, 147, 49, 133, 65, 77, 55, 202, 122,
162, 169, 188, 200, 190, 125, 63, 244, 96, 31, 107, 106, 74, 143, 116,
148, 78, 46, 1, 137, 150, 110, 181, 56, 95, 139, 58, 3, 231, 66, 165,
142, 242, 43, 192, 157, 89, 175, 109, 220, 128, 0, 178, 42, 255, 20,
214, 185, 83, 160, 253, 7, 23, 92, 111, 153, 26, 226, 33, 176, 144,
18, 216, 212, 28, 151, 71, 206, 222, 182, 8, 174, 205, 201, 152, 240,
155, 108, 223, 104, 239, 98, 164, 211, 184, 34, 193, 14, 114, 187, 40,
254, 12, 67, 93, 217, 6, 94, 16, 19, 82, 86, 245, 24, 197, 134, 132,
138, 229, 121, 5, 235, 238, 85, 47, 103, 113, 179, 69, 250, 45, 135,
156, 25, 61, 75, 44, 146, 189, 84, 207, 172, 119, 53, 123, 186, 120,
171, 68, 227, 145, 136, 100, 90, 48, 79, 159, 149, 39, 213, 236, 126,
52, 60, 225, 199, 105, 73, 233, 252, 118, 215, 35, 115, 64, 37, 97,
129, 161, 177, 87, 237, 141, 173, 191, 163, 140, 234, 232, 249],
[117, 94, 17, 103, 16, 186, 172, 127, 146, 23, 46, 25, 168, 8, 163, 39,
174, 67, 137, 175, 121, 59, 9, 128, 179, 199, 132, 4, 140, 54, 1, 85,
14, 134, 161, 238, 30, 241, 37, 224, 166, 45, 119, 109, 202, 196, 93,
190, 220, 69, 49, 21, 228, 209, 60, 73, 99, 65, 102, 7, 229, 200, 19,
82, 240, 71, 105, 169, 214, 194, 64, 142, 12, 233, 88, 201, 11, 72,
92, 221, 27, 32, 176, 124, 205, 189, 177, 246, 35, 112, 219, 61, 129,
170, 173, 100, 84, 242, 157, 26, 218, 20, 33, 191, 155, 232, 87, 86,
153, 114, 97, 130, 29, 192, 164, 239, 90, 43, 236, 208, 212, 185, 75,
210, 0, 81, 227, 5, 116, 243, 34, 18, 182, 70, 181, 197, 217, 95, 183,
101, 252, 248, 107, 89, 136, 216, 203, 68, 91, 223, 96, 141, 150, 131,
13, 152, 198, 111, 44, 222, 125, 244, 76, 251, 158, 106, 24, 42, 38,
77, 2, 213, 207, 249, 147, 113, 135, 245, 118, 193, 47, 98, 145, 66,
160, 123, 211, 165, 78, 204, 80, 250, 110, 162, 48, 58, 10, 180, 55,
231, 79, 149, 74, 62, 50, 148, 143, 206, 28, 15, 57, 159, 139, 225,
122, 237, 138, 171, 36, 56, 115, 63, 144, 154, 6, 230, 133, 215, 41,
184, 22, 104, 254, 234, 253, 187, 226, 247, 188, 156, 151, 40, 108,
51, 83, 178, 52, 3, 31, 255, 195, 53, 235, 126, 167, 120]]
encrypt_table = b''.join(get_table(b'foobar!'))
decrypt_table = maketrans(encrypt_table, maketrans(b'', b''))
for i in range(0, 256):
assert (target1[0][i] == ord(encrypt_table[i]))
assert (target1[1][i] == ord(decrypt_table[i]))
encrypt_table = b''.join(get_table(b'barfoo!'))
decrypt_table = maketrans(encrypt_table, maketrans(b'', b''))
for i in range(0, 256):
assert (target2[0][i] == ord(encrypt_table[i]))
assert (target2[1][i] == ord(decrypt_table[i]))
def test_encryption():
from shadowsocks.crypto import util
cipher = TableCipher('table', b'test', b'', 1)
decipher = TableCipher('table', b'test', b'', 0)
util.run_cipher(cipher, decipher)
if __name__ == '__main__':
test_table_result()
test_encryption()

138
crypto/util.py Normal file
View file

@ -0,0 +1,138 @@
#!/usr/bin/env python
#
# Copyright 2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement
import os
import logging
def find_library_nt(name):
# modified from ctypes.util
# ctypes.util.find_library just returns first result he found
# but we want to try them all
# because on Windows, users may have both 32bit and 64bit version installed
results = []
for directory in os.environ['PATH'].split(os.pathsep):
fname = os.path.join(directory, name)
if os.path.isfile(fname):
results.append(fname)
if fname.lower().endswith(".dll"):
continue
fname = fname + ".dll"
if os.path.isfile(fname):
results.append(fname)
return results
def find_library(possible_lib_names, search_symbol, library_name):
import ctypes.util
from ctypes import CDLL
paths = []
if type(possible_lib_names) not in (list, tuple):
possible_lib_names = [possible_lib_names]
lib_names = []
for lib_name in possible_lib_names:
lib_names.append(lib_name)
lib_names.append('lib' + lib_name)
for name in lib_names:
if os.name == "nt":
paths.extend(find_library_nt(name))
else:
path = ctypes.util.find_library(name)
if path:
paths.append(path)
if not paths:
# We may get here when find_library fails because, for example,
# the user does not have sufficient privileges to access those
# tools underlying find_library on linux.
import glob
for name in lib_names:
patterns = [
'/usr/local/lib*/lib%s.*' % name,
'/usr/lib*/lib%s.*' % name,
'lib%s.*' % name,
'%s.dll' % name]
for pat in patterns:
files = glob.glob(pat)
if files:
paths.extend(files)
for path in paths:
try:
lib = CDLL(path)
if hasattr(lib, search_symbol):
logging.info('loading %s from %s', library_name, path)
return lib
else:
logging.warn('can\'t find symbol %s in %s', search_symbol,
path)
except Exception:
pass
return None
def run_cipher(cipher, decipher):
from os import urandom
import random
import time
BLOCK_SIZE = 16384
rounds = 1 * 1024
plain = urandom(BLOCK_SIZE * rounds)
results = []
pos = 0
print('test start')
start = time.time()
while pos < len(plain):
l = random.randint(100, 32768)
c = cipher.update(plain[pos:pos + l])
results.append(c)
pos += l
pos = 0
c = b''.join(results)
results = []
while pos < len(plain):
l = random.randint(100, 32768)
results.append(decipher.update(c[pos:pos + l]))
pos += l
end = time.time()
print('speed: %d bytes/s' % (BLOCK_SIZE * rounds / (end - start)))
assert b''.join(results) == plain
def test_find_library():
assert find_library('c', 'strcpy', 'libc') is not None
assert find_library(['c'], 'strcpy', 'libc') is not None
assert find_library(('c',), 'strcpy', 'libc') is not None
assert find_library(('crypto', 'eay32'), 'EVP_CipherUpdate',
'libcrypto') is not None
assert find_library('notexist', 'strcpy', 'libnotexist') is None
assert find_library('c', 'symbol_not_exist', 'c') is None
assert find_library(('notexist', 'c', 'crypto', 'eay32'),
'EVP_CipherUpdate', 'libc') is not None
if __name__ == '__main__':
test_find_library()

187
encrypt.py Normal file
View file

@ -0,0 +1,187 @@
#!/usr/bin/env python
#
# Copyright 2012-2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement
import os
import sys
import hashlib
import logging
from shadowsocks import common
from shadowsocks.crypto import rc4_md5, openssl, sodium, table
method_supported = {}
method_supported.update(rc4_md5.ciphers)
method_supported.update(openssl.ciphers)
method_supported.update(sodium.ciphers)
method_supported.update(table.ciphers)
def random_string(length):
return os.urandom(length)
cached_keys = {}
def try_cipher(key, method=None):
Encryptor(key, method)
def EVP_BytesToKey(password, key_len, iv_len):
# equivalent to OpenSSL's EVP_BytesToKey() with count 1
# so that we make the same key and iv as nodejs version
cached_key = '%s-%d-%d' % (password, key_len, iv_len)
r = cached_keys.get(cached_key, None)
if r:
return r
m = []
i = 0
while len(b''.join(m)) < (key_len + iv_len):
md5 = hashlib.md5()
data = password
if i > 0:
data = m[i - 1] + password
md5.update(data)
m.append(md5.digest())
i += 1
ms = b''.join(m)
key = ms[:key_len]
iv = ms[key_len:key_len + iv_len]
cached_keys[cached_key] = (key, iv)
return key, iv
class Encryptor(object):
def __init__(self, key, method):
self.key = key
self.method = method
self.iv = None
self.iv_sent = False
self.cipher_iv = b''
self.decipher = None
method = method.lower()
self._method_info = self.get_method_info(method)
if self._method_info:
self.cipher = self.get_cipher(key, method, 1,
random_string(self._method_info[1]))
else:
logging.error('method %s not supported' % method)
sys.exit(1)
def get_method_info(self, method):
method = method.lower()
m = method_supported.get(method)
return m
def iv_len(self):
return len(self.cipher_iv)
def get_cipher(self, password, method, op, iv):
password = common.to_bytes(password)
m = self._method_info
if m[0] > 0:
key, iv_ = EVP_BytesToKey(password, m[0], m[1])
else:
# key_length == 0 indicates we should use the key directly
key, iv = password, b''
iv = iv[:m[1]]
if op == 1:
# this iv is for cipher not decipher
self.cipher_iv = iv[:m[1]]
return m[2](method, key, iv, op)
def encrypt(self, buf):
if len(buf) == 0:
return buf
if self.iv_sent:
return self.cipher.update(buf)
else:
self.iv_sent = True
return self.cipher_iv + self.cipher.update(buf)
def decrypt(self, buf):
if len(buf) == 0:
return buf
if self.decipher is None:
decipher_iv_len = self._method_info[1]
decipher_iv = buf[:decipher_iv_len]
self.decipher = self.get_cipher(self.key, self.method, 0,
iv=decipher_iv)
buf = buf[decipher_iv_len:]
if len(buf) == 0:
return buf
return self.decipher.update(buf)
def encrypt_all(password, method, op, data):
result = []
method = method.lower()
(key_len, iv_len, m) = method_supported[method]
if key_len > 0:
key, _ = EVP_BytesToKey(password, key_len, iv_len)
else:
key = password
if op:
iv = random_string(iv_len)
result.append(iv)
else:
iv = data[:iv_len]
data = data[iv_len:]
cipher = m(method, key, iv, op)
result.append(cipher.update(data))
return b''.join(result)
CIPHERS_TO_TEST = [
'aes-128-cfb',
'aes-256-cfb',
'rc4-md5',
'salsa20',
'chacha20',
'table',
]
def test_encryptor():
from os import urandom
plain = urandom(10240)
for method in CIPHERS_TO_TEST:
logging.warn(method)
encryptor = Encryptor(b'key', method)
decryptor = Encryptor(b'key', method)
cipher = encryptor.encrypt(plain)
plain2 = decryptor.decrypt(cipher)
assert plain == plain2
def test_encrypt_all():
from os import urandom
plain = urandom(10240)
for method in CIPHERS_TO_TEST:
logging.warn(method)
cipher = encrypt_all(b'key', method, 1, plain)
plain2 = encrypt_all(b'key', method, 0, cipher)
assert plain == plain2
if __name__ == '__main__':
test_encrypt_all()
test_encryptor()

111
eventloop.py Normal file
View file

@ -0,0 +1,111 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, \
with_statement
import logging
import time
import traceback
import errno
from shadowsocks import selectors
from shadowsocks.selectors import (EVENT_READ, EVENT_WRITE, EVENT_ERROR,
errno_from_exception)
POLL_IN = EVENT_READ
POLL_OUT = EVENT_WRITE
POLL_ERR = EVENT_ERROR
TIMEOUT_PRECISION = 10
class EventLoop:
def __init__(self):
self._selector = selectors.DefaultSelector()
self._stopping = False
self._last_time = time.time()
self._periodic_callbacks = []
def poll(self, timeout=None):
return self._selector.select(timeout)
def add(self, sock, events, data):
events |= selectors.EVENT_ERROR
return self._selector.register(sock, events, data)
def remove(self, sock):
try:
return self._selector.unregister(sock)
except KeyError:
pass
def modify(self, sock, events, data):
events |= selectors.EVENT_ERROR
try:
key = self._selector.modify(sock, events, data)
except KeyError:
key = self.add(sock, events, data)
return key
def add_periodic(self, callback):
self._periodic_callbacks.append(callback)
def remove_periodic(self, callback):
self._periodic_callbacks.remove(callback)
def fd_count(self):
return len(self._selector.get_map())
def run(self):
logging.debug('Starting event loop')
while not self._stopping:
asap = False
try:
events = self.poll(timeout=TIMEOUT_PRECISION)
except (OSError, IOError) as e:
if errno_from_exception(e) in (errno.EPIPE, errno.EINTR):
# EPIPE: Happens when the client closes the connection
# EINTR: Happens when received a signal
# handles them as soon as possible
asap = True
logging.debug('poll: %s', e)
else:
logging.error('poll: %s', e)
traceback.print_exc()
continue
for key, event in events:
if type(key.data) == tuple:
handler = key.data[0]
args = key.data[1:]
else:
handler = key.data
args = ()
sock = key.fileobj
if hasattr(handler, 'handle_event'):
handler = handler.handle_event
try:
handler(sock, event, *args)
except Exception as e:
logging.debug(e)
traceback.print_exc()
raise
now = time.time()
if asap or now - self._last_time >= TIMEOUT_PRECISION:
for callback in self._periodic_callbacks:
callback()
self._last_time = now
logging.debug('Got {} fds registered'.format(self.fd_count()))
logging.debug('Stopping event loop')
self._selector.close()
def stop(self):
self._stopping = True

150
lru_cache.py Normal file
View file

@ -0,0 +1,150 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright 2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement
import collections
import logging
import time
# this LRUCache is optimized for concurrency, not QPS
# n: concurrency, keys stored in the cache
# m: visits not timed out, proportional to QPS * timeout
# get & set is O(1), not O(n). thus we can support very large n
# TODO: if timeout or QPS is too large, then this cache is not very efficient,
# as sweep() causes long pause
class LRUCache(collections.MutableMapping):
"""This class is not thread safe"""
def __init__(self, timeout=60, close_callback=None, *args, **kwargs):
self.timeout = timeout
self.close_callback = close_callback
self._store = {}
self._time_to_keys = collections.defaultdict(list)
self._keys_to_last_time = {}
self._last_visits = collections.deque()
self._closed_values = set()
self.update(dict(*args, **kwargs)) # use the free update to set keys
def __getitem__(self, key):
# O(1)
t = time.time()
self._keys_to_last_time[key] = t
self._time_to_keys[t].append(key)
self._last_visits.append(t)
return self._store[key]
def __setitem__(self, key, value):
# O(1)
t = time.time()
self._keys_to_last_time[key] = t
self._store[key] = value
self._time_to_keys[t].append(key)
self._last_visits.append(t)
def __delitem__(self, key):
# O(1)
del self._store[key]
del self._keys_to_last_time[key]
def __iter__(self):
return iter(self._store)
def __len__(self):
return len(self._store)
def sweep(self):
# O(m)
now = time.time()
c = 0
while len(self._last_visits) > 0:
least = self._last_visits[0]
if now - least <= self.timeout:
break
if self.close_callback is not None:
for key in self._time_to_keys[least]:
if key in self._store:
if now - self._keys_to_last_time[key] > self.timeout:
value = self._store[key]
if value not in self._closed_values:
self.close_callback(value)
self._closed_values.add(value)
for key in self._time_to_keys[least]:
self._last_visits.popleft()
if key in self._store:
if now - self._keys_to_last_time[key] > self.timeout:
del self._store[key]
del self._keys_to_last_time[key]
c += 1
del self._time_to_keys[least]
if c:
self._closed_values.clear()
logging.debug('%d keys swept' % c)
def test():
c = LRUCache(timeout=0.3)
c['a'] = 1
assert c['a'] == 1
time.sleep(0.5)
c.sweep()
assert 'a' not in c
c['a'] = 2
c['b'] = 3
time.sleep(0.2)
c.sweep()
assert c['a'] == 2
assert c['b'] == 3
time.sleep(0.2)
c.sweep()
c['b']
time.sleep(0.2)
c.sweep()
assert 'a' not in c
assert c['b'] == 3
time.sleep(0.5)
c.sweep()
assert 'a' not in c
assert 'b' not in c
global close_cb_called
close_cb_called = False
def close_cb(t):
global close_cb_called
assert not close_cb_called
close_cb_called = True
c = LRUCache(timeout=0.1, close_callback=close_cb)
c['s'] = 1
c['s']
time.sleep(0.1)
c['s']
time.sleep(0.3)
c.sweep()
if __name__ == '__main__':
test()

652
selectors.py Normal file
View file

@ -0,0 +1,652 @@
"""Selectors module.
This module allows high-level and efficient I/O multiplexing, built upon the
`select` module primitives.
"""
from abc import ABCMeta, abstractmethod
from collections import namedtuple, Mapping
import errno
import math
import os
import select
import socket
import sys
# generic events, that must be mapped to implementation-specific ones
EVENT_READ = 0x001
EVENT_WRITE = 0x004
EVENT_ERROR = 0x018
def errno_from_exception(e):
"""Provides the errno from an Exception object.
There are cases that the errno attribute was not set so we pull
the errno out of the args but if someone instatiates an Exception
without any args you will get a tuple error. So this function
abstracts all that behavior to give you a safe way to get the
errno.
"""
if hasattr(e, 'errno'):
return e.errno
elif e.args:
return e.args[0]
else:
return None
def get_sock_error(sock):
if not sock:
return
error_number = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
return socket.error(error_number, os.strerror(error_number))
def _fileobj_to_fd(fileobj):
"""Return a file descriptor from a file object.
Parameters:
fileobj -- file object or file descriptor
Returns:
corresponding file descriptor
Raises:
ValueError if the object is invalid
"""
if isinstance(fileobj, int):
fd = fileobj
else:
try:
fd = int(fileobj.fileno())
except (AttributeError, TypeError, ValueError):
raise ValueError("Invalid file object: "
"{!r}".format(fileobj))
if fd < 0:
raise ValueError("Invalid file descriptor: {}".format(fd))
return fd
SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data'])
"""Object used to associate a file object to its backing file descriptor,
selected event mask and attached data."""
class _SelectorMapping(Mapping):
"""Mapping of file objects to selector keys."""
def __init__(self, selector):
self._selector = selector
def __len__(self):
return len(self._selector._fd_to_key)
def __getitem__(self, fileobj):
try:
fd = self._selector._fileobj_lookup(fileobj)
return self._selector._fd_to_key[fd]
except KeyError:
raise KeyError("{!r} is not registered".format(fileobj))
def __iter__(self):
return iter(self._selector._fd_to_key)
class BaseSelector(object):
__metaclass__ = ABCMeta
"""Selector abstract base class.
A selector supports registering file objects to be monitored for specific
I/O events.
A file object is a file descriptor or any object with a `fileno()` method.
An arbitrary object can be attached to the file object, which can be used
for example to store context information, a callback, etc.
A selector can use various implementations (select(), poll(), epoll()...)
depending on the platform. The default `Selector` class uses the most
efficient implementation on the current platform.
"""
@abstractmethod
def register(self, fileobj, events, data=None):
"""Register a file object.
Parameters:
fileobj -- file object or file descriptor
events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
data -- attached data
Returns:
SelectorKey instance
Raises:
ValueError if events is invalid
KeyError if fileobj is already registered
OSError if fileobj is closed or otherwise is unacceptable to
the underlying system call (if a system call is made)
Note:
OSError may or may not be raised
"""
raise NotImplementedError
@abstractmethod
def unregister(self, fileobj):
"""Unregister a file object.
Parameters:
fileobj -- file object or file descriptor
Returns:
SelectorKey instance
Raises:
KeyError if fileobj is not registered
Note:
If fileobj is registered but has since been closed this does
*not* raise OSError (even if the wrapped syscall does)
"""
raise NotImplementedError
def modify(self, fileobj, events, data=None):
"""Change a registered file object monitored events or attached data.
Parameters:
fileobj -- file object or file descriptor
events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
data -- attached data
Returns:
SelectorKey instance
Raises:
Anything that unregister() or register() raises
"""
self.unregister(fileobj)
return self.register(fileobj, events, data)
@abstractmethod
def select(self, timeout=None):
"""Perform the actual selection, until some monitored file objects are
ready or a timeout expires.
Parameters:
timeout -- if timeout > 0, this specifies the maximum wait time, in
seconds
if timeout <= 0, the select() call won't block, and will
report the currently ready file objects
if timeout is None, select() will block until a monitored
file object becomes ready
Returns:
list of (key, events) for ready file objects
`events` is a bitwise mask of EVENT_READ|EVENT_WRITE
"""
raise NotImplementedError
def close(self):
"""Close the selector.
This must be called to make sure that any underlying resource is freed.
"""
pass
def get_key(self, fileobj):
"""Return the key associated to a registered file object.
Returns:
SelectorKey for this file object
"""
mapping = self.get_map()
try:
if mapping is None:
raise KeyError
return mapping[fileobj]
except KeyError:
raise KeyError("{!r} is not registered".format(fileobj))
@abstractmethod
def get_map(self):
"""Return a mapping of file objects to selector keys."""
raise NotImplementedError
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
class _BaseSelectorImpl(BaseSelector):
"""Base selector implementation."""
def __init__(self):
# this maps file descriptors to keys
self._fd_to_key = {}
# read-only mapping returned by get_map()
self._map = _SelectorMapping(self)
def _fileobj_lookup(self, fileobj):
"""Return a file descriptor from a file object.
This wraps _fileobj_to_fd() to do an exhaustive search in case
the object is invalid but we still have it in our map. This
is used by unregister() so we can unregister an object that
was previously registered even if it is closed. It is also
used by _SelectorMapping.
"""
try:
return _fileobj_to_fd(fileobj)
except ValueError:
# Do an exhaustive search.
for key in self._fd_to_key.values():
if key.fileobj is fileobj:
return key.fd
# Raise ValueError after all.
raise
def register(self, fileobj, events, data=None):
if (not events) or (events & ~(EVENT_READ | EVENT_WRITE |
EVENT_ERROR)):
raise ValueError("Invalid events: {!r}".format(events))
key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data)
if key.fd in self._fd_to_key:
raise KeyError("{!r} (FD {}) is already registered"
.format(fileobj, key.fd))
self._fd_to_key[key.fd] = key
return key
def unregister(self, fileobj):
try:
key = self._fd_to_key.pop(self._fileobj_lookup(fileobj))
except KeyError:
raise KeyError("{!r} is not registered".format(fileobj))
return key
def modify(self, fileobj, events, data=None):
# TODO: Subclasses can probably optimize this even further.
try:
key = self._fd_to_key[self._fileobj_lookup(fileobj)]
except KeyError:
raise KeyError("{!r} is not registered".format(fileobj))
if events != key.events:
self.unregister(fileobj)
key = self.register(fileobj, events, data)
elif data != key.data:
# Use a shortcut to update the data.
key = key._replace(data=data)
self._fd_to_key[key.fd] = key
return key
def close(self):
self._fd_to_key.clear()
self._map = None
def get_map(self):
return self._map
def _key_from_fd(self, fd):
"""Return the key associated to a given file descriptor.
Parameters:
fd -- file descriptor
Returns:
corresponding key, or None if not found
"""
try:
return self._fd_to_key[fd]
except KeyError:
return None
class SelectSelector(_BaseSelectorImpl):
"""Select-based selector."""
def __init__(self):
super(self.__class__, self).__init__()
self._readers = set()
self._writers = set()
self._errors = set()
def register(self, fileobj, events, data=None):
key = super(self.__class__, self).register(fileobj, events, data)
if events & EVENT_READ:
self._readers.add(key.fd)
if events & EVENT_WRITE:
self._writers.add(key.fd)
if events & EVENT_ERROR:
self._errors.add(key.fd)
return key
def unregister(self, fileobj):
key = super(self.__class__, self).unregister(fileobj)
self._readers.discard(key.fd)
self._writers.discard(key.fd)
self._errors.discard(key.fd)
return key
if sys.platform == 'win32':
def _select(self, r, w, x, timeout=None):
r, w, x = select.select(r, w, x, timeout)
return r, w, x
else:
_select = select.select
def select(self, timeout=None):
timeout = None if timeout is None else max(timeout, 0)
ready = []
try:
r, w, x = self._select(self._readers, self._writers, self._errors,
timeout)
except OSError as e:
if errno_from_exception(e) == errno.EAGAIN:
return ready
r = set(r)
w = set(w)
x = set(x)
for fd in r | w | x:
events = 0
if fd in r:
events |= EVENT_READ
if fd in w:
events |= EVENT_WRITE
if fd in x:
events |= EVENT_ERROR
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
if hasattr(select, 'poll'):
class PollSelector(_BaseSelectorImpl):
"""Poll-based selector."""
def __init__(self):
super(self.__class__, self).__init__()
self._poll = select.poll()
def register(self, fileobj, events, data=None):
key = super(self.__class__, self).register(fileobj, events, data)
poll_events = 0
if events & EVENT_READ:
poll_events |= select.POLLIN
if events & EVENT_WRITE:
poll_events |= select.POLLOUT
if events & EVENT_ERROR:
poll_events |= select.POLLERR | select.POLLHUP
self._poll.register(key.fd, poll_events)
return key
def unregister(self, fileobj):
key = super(self.__class__, self).unregister(fileobj)
self._poll.unregister(key.fd)
return key
def select(self, timeout=None):
if timeout is None:
timeout = None
elif timeout <= 0:
timeout = 0
else:
# poll() has a resolution of 1 millisecond, round away from
# zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3)
ready = []
try:
fd_event_list = self._poll.poll(timeout)
except OSError as e:
if errno_from_exception(e) == errno.EAGAIN:
return ready
for fd, event in fd_event_list:
events = 0
if event & select.POLLIN:
events |= EVENT_READ
if event & select.POLLOUT:
events |= EVENT_WRITE
if event & (select.POLLERR | select.POLLHUP):
events |= EVENT_ERROR
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
if hasattr(select, 'epoll'):
class EpollSelector(_BaseSelectorImpl):
"""Epoll-based selector."""
def __init__(self):
super(self.__class__, self).__init__()
self._epoll = select.epoll()
def fileno(self):
return self._epoll.fileno()
def register(self, fileobj, events, data=None):
key = super(self.__class__, self).register(fileobj, events, data)
epoll_events = 0
if events & EVENT_READ:
epoll_events |= select.EPOLLIN
if events & EVENT_WRITE:
epoll_events |= select.EPOLLOUT
if events & EVENT_ERROR:
epoll_events |= select.EPOLLERR | select.EPOLLHUP
self._epoll.register(key.fd, epoll_events)
return key
def unregister(self, fileobj):
key = super(self.__class__, self).unregister(fileobj)
try:
self._epoll.unregister(key.fd)
except OSError:
# This can happen if the FD was closed since it
# was registered.
pass
return key
def select(self, timeout=None):
if timeout is None:
timeout = -1
elif timeout <= 0:
timeout = 0
else:
# epoll_wait() has a resolution of 1 millisecond, round away
# from zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3) * 1e-3
# epoll_wait() expects `maxevents` to be greater than zero;
# we want to make sure that `select()` can be called when no
# FD is registered.
max_ev = max(len(self._fd_to_key), 1)
ready = []
try:
fd_event_list = self._epoll.poll(timeout, max_ev)
except OSError as e:
if errno_from_exception(e) == errno.EAGAIN:
return ready
for fd, event in fd_event_list:
events = 0
if event & select.EPOLLIN:
events |= EVENT_READ
if event & select.EPOLLOUT:
events |= EVENT_WRITE
if event & (select.EPOLLERR | select.EPOLLHUP):
events |= EVENT_ERROR
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
def close(self):
try:
self._epoll.close()
finally:
super(self.__class__, self).close()
if hasattr(select, 'devpoll'):
class DevpollSelector(_BaseSelectorImpl):
"""Solaris /dev/poll selector."""
def __init__(self):
super().__init__()
self._devpoll = select.devpoll()
def fileno(self):
return self._devpoll.fileno()
def register(self, fileobj, events, data=None):
key = super().register(fileobj, events, data)
poll_events = 0
if events & EVENT_READ:
poll_events |= select.POLLIN
if events & EVENT_WRITE:
poll_events |= select.POLLOUT
if events & EVENT_ERROR:
poll_events |= select.POLLERR | select.POLLHUP
self._devpoll.register(key.fd, poll_events)
return key
def unregister(self, fileobj):
key = super().unregister(fileobj)
self._devpoll.unregister(key.fd)
return key
def select(self, timeout=None):
if timeout is None:
timeout = None
elif timeout <= 0:
timeout = 0
else:
# devpoll() has a resolution of 1 millisecond, round away from
# zero to wait *at least* timeout seconds.
timeout = math.ceil(timeout * 1e3)
ready = []
try:
fd_event_list = self._devpoll.poll(timeout)
except InterruptedError:
return ready
for fd, event in fd_event_list:
events = 0
if event & select.POLLIN:
events |= EVENT_READ
if event & select.POLLOUT:
events |= EVENT_WRITE
if event & (select.POLLERR | select.POLLHUP):
events |= EVENT_ERROR
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
def close(self):
self._devpoll.close()
super().close()
if hasattr(select, 'kqueue'):
class KqueueSelector(_BaseSelectorImpl):
"""Kqueue-based selector."""
def __init__(self):
super(self.__class__, self).__init__()
self._kqueue = select.kqueue()
def fileno(self):
return self._kqueue.fileno()
def register(self, fileobj, events, data=None):
key = super(self.__class__, self).register(fileobj, events, data)
if events & EVENT_READ:
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
select.KQ_EV_ADD)
self._kqueue.control([kev], 0, 0)
if events & EVENT_WRITE:
kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
select.KQ_EV_ADD)
self._kqueue.control([kev], 0, 0)
return key
def unregister(self, fileobj):
key = super(self.__class__, self).unregister(fileobj)
if key.events & EVENT_READ:
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
select.KQ_EV_DELETE)
try:
self._kqueue.control([kev], 0, 0)
except OSError:
# This can happen if the FD was closed since it
# was registered.
pass
if key.events & EVENT_WRITE:
kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
select.KQ_EV_DELETE)
try:
self._kqueue.control([kev], 0, 0)
except OSError:
# See comment above.
pass
return key
def select(self, timeout=None):
timeout = None if timeout is None else max(timeout, 0)
max_ev = len(self._fd_to_key)
ready = []
try:
kev_list = self._kqueue.control(None, max_ev, timeout)
except OSError as e:
if errno_from_exception(e) == errno.EAGAIN:
return ready
for kev in kev_list:
fd = kev.ident
flag = kev.filter
events = 0
if flag == select.KQ_FILTER_READ:
events |= EVENT_READ
if flag == select.KQ_FILTER_WRITE:
events |= EVENT_WRITE
key = self._key_from_fd(fd)
if key:
ready.append((key, events & key.events))
return ready
def close(self):
try:
self._kqueue.close()
finally:
super(self.__class__, self).close()
# Choose the best implementation: roughly, epoll|kqueue > poll > select.
# select() also can't accept a FD > FD_SETSIZE (usually around 1024)
if 'KqueueSelector' in globals():
DefaultSelector = KqueueSelector
elif 'EpollSelector' in globals():
DefaultSelector = EpollSelector
elif 'DevpollSelector' in globals():
DefaultSelector = DevpollSelector
elif 'PollSelector' in globals():
DefaultSelector = PollSelector
else:
DefaultSelector = SelectSelector

35
server.py Normal file
View file

@ -0,0 +1,35 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, \
with_statement
import os
import sys
import logging
path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, path)
from shadowsocks.eventloop import EventLoop
from shadowsocks.tcprelay import TcpRelay, TcpRelayServerHandler
from shadowsocks.asyncdns import DNSResolver
FORMATTER = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
LOGGING_LEVEL = logging.INFO
logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER)
LISTEN_ADDR = ('0.0.0.0', 9000)
def main():
loop = EventLoop()
dns_resolver = DNSResolver()
relay = TcpRelay(TcpRelayServerHandler, LISTEN_ADDR,
dns_resolver=dns_resolver)
dns_resolver.add_to_loop(loop)
relay.add_to_loop(loop)
loop.run()
if __name__ == '__main__':
main()

365
shell.py Normal file
View file

@ -0,0 +1,365 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright 2015 clowwindy
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, \
with_statement
import os
import json
import sys
import getopt
import logging
from shadowsocks.common import to_bytes, to_str, IPNetwork
from shadowsocks import encrypt
VERBOSE_LEVEL = 5
verbose = 0
def check_python():
info = sys.version_info
if info[0] == 2 and not info[1] >= 6:
print('Python 2.6+ required')
sys.exit(1)
elif info[0] == 3 and not info[1] >= 3:
print('Python 3.3+ required')
sys.exit(1)
elif info[0] not in [2, 3]:
print('Python version not supported')
sys.exit(1)
def print_exception(e):
global verbose
logging.error(e)
if verbose > 0:
import traceback
traceback.print_exc()
def print_shadowsocks():
version = ''
try:
import pkg_resources
version = pkg_resources.get_distribution('shadowsocks').version
except Exception:
pass
print('Shadowsocks %s' % version)
def find_config():
config_path = 'config.json'
if os.path.exists(config_path):
return config_path
config_path = os.path.join(os.path.dirname(__file__), '../', 'config.json')
if os.path.exists(config_path):
return config_path
return None
def check_config(config, is_local):
if config.get('daemon', None) == 'stop':
# no need to specify configuration for daemon stop
return
if is_local and not config.get('password', None):
logging.error('password not specified')
print_help(is_local)
sys.exit(2)
if not is_local and not config.get('password', None) \
and not config.get('port_password', None):
logging.error('password or port_password not specified')
print_help(is_local)
sys.exit(2)
if 'local_port' in config:
config['local_port'] = int(config['local_port'])
if 'server_port' in config and type(config['server_port']) != list:
config['server_port'] = int(config['server_port'])
if config.get('local_address', '') in [b'0.0.0.0']:
logging.warn('warning: local set to listen on 0.0.0.0, it\'s not safe')
if config.get('server', '') in ['127.0.0.1', 'localhost']:
logging.warn('warning: server set to listen on %s:%s, are you sure?' %
(to_str(config['server']), config['server_port']))
if (config.get('method', '') or '').lower() == 'table':
logging.warn('warning: table is not safe; please use a safer cipher, '
'like AES-256-CFB')
if (config.get('method', '') or '').lower() == 'rc4':
logging.warn('warning: RC4 is not safe; please use a safer cipher, '
'like AES-256-CFB')
if config.get('timeout', 300) < 100:
logging.warn('warning: your timeout %d seems too short' %
int(config.get('timeout')))
if config.get('timeout', 300) > 600:
logging.warn('warning: your timeout %d seems too long' %
int(config.get('timeout')))
if config.get('password') in [b'mypassword']:
logging.error('DON\'T USE DEFAULT PASSWORD! Please change it in your '
'config.json!')
sys.exit(1)
if config.get('user', None) is not None:
if os.name != 'posix':
logging.error('user can be used only on Unix')
sys.exit(1)
encrypt.try_cipher(config['password'], config['method'])
def get_config(is_local):
global verbose
logging.basicConfig(level=logging.INFO,
format='%(levelname)-s: %(message)s')
if is_local:
shortopts = 'hd:s:b:p:k:l:m:c:t:vq'
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'user=',
'version']
else:
shortopts = 'hd:s:p:k:m:c:t:vq'
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=',
'forbidden-ip=', 'user=', 'manager-address=', 'version']
try:
config_path = find_config()
optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts)
for key, value in optlist:
if key == '-c':
config_path = value
if config_path:
logging.info('loading config from %s' % config_path)
with open(config_path, 'rb') as f:
try:
config = parse_json_in_str(f.read().decode('utf8'))
except ValueError as e:
logging.error('found an error in config.json: %s',
e.message)
sys.exit(1)
else:
config = {}
v_count = 0
for key, value in optlist:
if key == '-p':
config['server_port'] = int(value)
elif key == '-k':
config['password'] = to_bytes(value)
elif key == '-l':
config['local_port'] = int(value)
elif key == '-s':
config['server'] = to_str(value)
elif key == '-m':
config['method'] = to_str(value)
elif key == '-b':
config['local_address'] = to_str(value)
elif key == '-v':
v_count += 1
# '-vv' turns on more verbose mode
config['verbose'] = v_count
elif key == '-t':
config['timeout'] = int(value)
elif key == '--fast-open':
config['fast_open'] = True
elif key == '--workers':
config['workers'] = int(value)
elif key == '--manager-address':
config['manager_address'] = value
elif key == '--user':
config['user'] = to_str(value)
elif key == '--forbidden-ip':
config['forbidden_ip'] = to_str(value).split(',')
elif key in ('-h', '--help'):
if is_local:
print_local_help()
else:
print_server_help()
sys.exit(0)
elif key == '--version':
print_shadowsocks()
sys.exit(0)
elif key == '-d':
config['daemon'] = to_str(value)
elif key == '--pid-file':
config['pid-file'] = to_str(value)
elif key == '--log-file':
config['log-file'] = to_str(value)
elif key == '-q':
v_count -= 1
config['verbose'] = v_count
except getopt.GetoptError as e:
print(e, file=sys.stderr)
print_help(is_local)
sys.exit(2)
if not config:
logging.error('config not specified')
print_help(is_local)
sys.exit(2)
config['password'] = to_bytes(config.get('password', b''))
config['method'] = to_str(config.get('method', 'aes-256-cfb'))
config['port_password'] = config.get('port_password', None)
config['timeout'] = int(config.get('timeout', 300))
config['fast_open'] = config.get('fast_open', False)
config['workers'] = config.get('workers', 1)
config['pid-file'] = config.get('pid-file', '/var/run/shadowsocks.pid')
config['log-file'] = config.get('log-file', '/var/log/shadowsocks.log')
config['verbose'] = config.get('verbose', False)
config['local_address'] = to_str(config.get('local_address', '127.0.0.1'))
config['local_port'] = config.get('local_port', 1080)
if is_local:
if config.get('server', None) is None:
logging.error('server addr not specified')
print_local_help()
sys.exit(2)
else:
config['server'] = to_str(config['server'])
else:
config['server'] = to_str(config.get('server', '0.0.0.0'))
try:
config['forbidden_ip'] = \
IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128'))
except Exception as e:
logging.error(e)
sys.exit(2)
config['server_port'] = config.get('server_port', 8388)
logging.getLogger('').handlers = []
logging.addLevelName(VERBOSE_LEVEL, 'VERBOSE')
if config['verbose'] >= 2:
level = VERBOSE_LEVEL
elif config['verbose'] == 1:
level = logging.DEBUG
elif config['verbose'] == -1:
level = logging.WARN
elif config['verbose'] <= -2:
level = logging.ERROR
else:
level = logging.INFO
verbose = config['verbose']
logging.basicConfig(level=level,
format='%(asctime)s %(levelname)-8s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
check_config(config, is_local)
return config
def print_help(is_local):
if is_local:
print_local_help()
else:
print_server_help()
def print_local_help():
print('''usage: sslocal [OPTION]...
A fast tunnel proxy that helps you bypass firewalls.
You can supply configurations via either config file or command line arguments.
Proxy options:
-c CONFIG path to config file
-s SERVER_ADDR server address
-p SERVER_PORT server port, default: 8388
-b LOCAL_ADDR local binding address, default: 127.0.0.1
-l LOCAL_PORT local port, default: 1080
-k PASSWORD password
-m METHOD encryption method, default: aes-256-cfb
-t TIMEOUT timeout in seconds, default: 300
--fast-open use TCP_FASTOPEN, requires Linux 3.7+
General options:
-h, --help show this help message and exit
-d start/stop/restart daemon mode
--pid-file PID_FILE pid file for daemon mode
--log-file LOG_FILE log file for daemon mode
--user USER username to run as
-v, -vv verbose mode
-q, -qq quiet mode, only show warnings/errors
--version show version information
Online help: <https://github.com/shadowsocks/shadowsocks>
''')
def print_server_help():
print('''usage: ssserver [OPTION]...
A fast tunnel proxy that helps you bypass firewalls.
You can supply configurations via either config file or command line arguments.
Proxy options:
-c CONFIG path to config file
-s SERVER_ADDR server address, default: 0.0.0.0
-p SERVER_PORT server port, default: 8388
-k PASSWORD password
-m METHOD encryption method, default: aes-256-cfb
-t TIMEOUT timeout in seconds, default: 300
--fast-open use TCP_FASTOPEN, requires Linux 3.7+
--workers WORKERS number of workers, available on Unix/Linux
--forbidden-ip IPLIST comma seperated IP list forbidden to connect
--manager-address ADDR optional server manager UDP address, see wiki
General options:
-h, --help show this help message and exit
-d start/stop/restart daemon mode
--pid-file PID_FILE pid file for daemon mode
--log-file LOG_FILE log file for daemon mode
--user USER username to run as
-v, -vv verbose mode
-q, -qq quiet mode, only show warnings/errors
--version show version information
Online help: <https://github.com/shadowsocks/shadowsocks>
''')
def _decode_list(data):
rv = []
for item in data:
if hasattr(item, 'encode'):
item = item.encode('utf-8')
elif isinstance(item, list):
item = _decode_list(item)
elif isinstance(item, dict):
item = _decode_dict(item)
rv.append(item)
return rv
def _decode_dict(data):
rv = {}
for key, value in data.items():
if hasattr(value, 'encode'):
value = value.encode('utf-8')
elif isinstance(value, list):
value = _decode_list(value)
elif isinstance(value, dict):
value = _decode_dict(value)
rv[key] = value
return rv
def parse_json_in_str(data):
# parse json and convert everything from unicode to str
return json.loads(data, object_hook=_decode_dict)

387
tcprelay.py Normal file
View file

@ -0,0 +1,387 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, \
with_statement
import errno
import logging
import socket
from shadowsocks.selectors import (EVENT_READ, EVENT_WRITE, EVENT_ERROR,
errno_from_exception, get_sock_error)
from shadowsocks.common import parse_header, to_str
from shadowsocks import encrypt
BUF_SIZE = 32 * 1024
CMD_CONNECT = 1
def create_sock(ip, port):
addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM,
socket.SOL_TCP)
if len(addrs) == 0:
raise Exception("Getaddrinfo failed for %s:%d" % (ip, port))
af, socktype, proto, canonname, sa = addrs[0]
sock = socket.socket(af, socktype, proto)
sock.setblocking(False)
sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
return sock
class TcpRelayHanler(object):
def __init__(self, local_sock, local_addr, remote_addr=None,
dns_resolver=None):
self._loop = None
self._local_sock = local_sock
self._local_addr = local_addr
self._remote_addr = remote_addr
# self._crypt = None
self._crypt = encrypt.Encryptor(b'PassThrouthGFW', 'aes-256-cfb')
self._dns_resolver = dns_resolver
self._remote_sock = None
self._local_sock_mode = 0
self._remote_sock_mode = 0
self._data_to_write_to_local = []
self._data_to_write_to_remote = []
self._id = id(self)
def handle_event(self, sock, event, call, *args):
if event & EVENT_ERROR:
logging.error(get_sock_error(sock))
self.close()
else:
try:
call(sock, event, *args)
except Exception as e:
logging.error(e)
self.close()
def __del__(self):
logging.debug('Deleting {}'.format(self._id))
def add_to_loop(self, loop):
if self._loop:
raise Exception('Already added to loop')
self._loop = loop
loop.add(self._local_sock, EVENT_READ, (self, self.start))
def modify_local_sock_mode(self, event):
if self._local_sock_mode != event:
self._local_sock_mode = self.modify_sock_mode(self._local_sock,
event)
def modify_remote_sock_mode(self, event):
if self._remote_sock_mode != event:
self._remote_sock_mode = self.modify_sock_mode(self._remote_sock,
event)
def modify_sock_mode(self, sock, event):
key = self._loop.modify(sock, event, (self, self.stream))
return key.events
def close_sock(self, sock):
self._loop.remove(sock)
sock.close()
def close(self):
if self._local_sock:
self.close_sock(self._local_sock)
self._local_sock = None
if self._remote_sock:
self.close_sock(self._remote_sock)
self._remote_sock = None
def sock_connect(self, sock, addr):
while True:
try:
sock.connect(addr)
except (OSError, IOError) as e:
err = errno_from_exception(e)
if err == errno.EINTR:
pass
elif err == errno.EINPROGRESS:
break
else:
raise
else:
break
def sock_recv(self, sock, size=BUF_SIZE):
try:
data = sock.recv(size)
if not data:
self.close()
except (OSError, IOError) as e:
if errno_from_exception(e) in (errno.EAGAIN, errno.EWOULDBLOCK,
errno.EINTR):
return
else:
raise
return data
def sock_send(self, sock, data):
try:
s = sock.send(data)
data = data[s:]
except (OSError, IOError) as e:
if errno_from_exception(e) in (errno.EAGAIN, errno.EWOULDBLOCK,
errno.EINPROGRESS, errno.EINTR):
pass
else:
raise
return data
def on_local_read(self, size=BUF_SIZE):
logging.debug('on_local_read')
if not self._local_sock:
return
data = self.sock_recv(self._local_sock, size)
if not data:
return
logging.debug('Received {} bytes from {}:{}'.format(len(data),
*self._local_addr))
if self._crypt:
if self._is_client:
data = self._crypt.encrypt(data)
else:
data = self._crypt.decrypt(data)
if data:
self._data_to_write_to_remote.append(data)
self.on_remote_write()
def on_remote_read(self, size=BUF_SIZE):
logging.debug('on_remote_read')
if not self._remote_sock:
return
data = self.sock_recv(self._remote_sock, size)
if not data:
return
logging.debug('Received {} bytes from {}:{}'.format(
len(data), *self._remote_addr))
if self._crypt:
if self._is_client:
data = self._crypt.decrypt(data)
else:
data = self._crypt.encrypt(data)
if data:
self._data_to_write_to_local.append(data)
self.on_local_write()
def on_local_write(self):
logging.debug('on_local_write')
if not self._local_sock:
return
if not self._data_to_write_to_local:
self.modify_local_sock_mode(EVENT_READ)
return
data = b''.join(self._data_to_write_to_local)
self._data_to_write_to_local = []
data = self.sock_send(self._local_sock, data)
if data:
self._data_to_write_to_local.append(data)
self.modify_local_sock_mode(EVENT_WRITE)
else:
self.modify_local_sock_mode(EVENT_READ)
def on_remote_write(self):
logging.debug('on_remote_write')
if not self._remote_sock:
return
if not self._data_to_write_to_remote:
self.modify_remote_sock_mode(EVENT_READ)
return
data = b''.join(self._data_to_write_to_remote)
self._data_to_write_to_remote = []
data = self.sock_send(self._remote_sock, data)
if data:
self._data_to_write_to_remote.append(data)
self.modify_remote_sock_mode(EVENT_WRITE)
else:
self.modify_remote_sock_mode(EVENT_READ)
def stream(self, sock, event):
logging.debug('stream')
if sock == self._local_sock:
if event & EVENT_READ:
self.on_local_read()
if event & EVENT_WRITE:
self.on_local_write()
elif sock == self._remote_sock:
if event & EVENT_READ:
self.on_remote_read()
if event & EVENT_WRITE:
self.on_remote_write()
else:
logging.warn('Unknow sock {}'.format(sock))
class TcpRelayClientHanler(TcpRelayHanler):
_is_client = True
def start(self, sock, event):
data = self.sock_recv(sock)
if not data:
return
reply = b'\x05\x00'
self.send_reply(sock, None, reply)
def send_reply(self, sock, event, data):
data = self.sock_send(sock, data)
if data:
self._loop.modify(sock, EVENT_WRITE, (self, self.send_reply, data))
else:
self._loop.modify(sock, EVENT_READ, (self, self.handle_addr))
def handle_addr(self, sock, event):
data = self.sock_recv(sock)
if not data:
return
# self._loop.remove(sock)
if ord(data[1:2]) != CMD_CONNECT:
raise Exception('Command not suppored')
result = parse_header(data[3:])
if not result:
raise Exception('Header cannot be parsed')
self._remote_sock = create_sock(*self._remote_addr)
self.sock_connect(self._remote_sock, self._remote_addr)
dest_addr = (to_str(result[1]), result[2])
logging.info('Connecting to {}:{}'.format(*dest_addr))
data = '{}:{}\n'.format(*dest_addr).encode('utf-8')
if self._crypt:
data = self._crypt.encrypt(data)
self._data_to_write_to_remote.append(data)
bind_addr = b'\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00'
self.send_bind_addr(sock, None, bind_addr)
def send_bind_addr(self, sock, event, data):
data = self.sock_send(sock, data)
if data:
self._loop.modify(sock, EVENT_WRITE, (self, self.send_bind_addr,
data))
else:
self.modify_local_sock_mode(EVENT_READ)
class TcpRelayServerHandler(TcpRelayHanler):
_is_client = False
def start(self, sock, event, data=None):
data = self.sock_recv(sock)
if not data:
return
self._loop.remove(sock)
if self._crypt:
data = self._crypt.decrypt(data)
remote, data = data.split(b'\n', 1)
host, port = remote.split(b':')
self._remote = (host, int(port))
self._data_to_write_to_remote.append(data)
self._dns_resolver.resolve(host, self.dns_resolved)
def dns_resolved(self, result, error):
try:
ip = result[1]
except (TypeError, IndexError):
ip = None
if not ip:
raise Exception('Hostname {} cannot resolved'.format(
self._remote[0]))
self._remote_addr = (ip, self._remote[1])
self._remote_sock = create_sock(*self._remote_addr)
logging.info('Connecting to {}'.format(self._remote[0]))
self.sock_connect(self._remote_sock, self._remote_addr)
self.modify_remote_sock_mode(EVENT_WRITE)
self.modify_local_sock_mode(EVENT_READ)
class TcpRelay(object):
def __init__(self, handler_type, listen_addr, remote_addr=None,
dns_resolver=None):
self._loop = None
self._handler_type = handler_type
self._listen_addr = listen_addr
self._remote_addr = remote_addr
self._dns_resolver = dns_resolver
self._create_listen_sock()
def _create_listen_sock(self):
sock = create_sock(*self._listen_addr)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(self._listen_addr)
sock.listen(1024)
self._listen_sock = sock
logging.info('Listening on {}:{}'.format(*self._listen_addr))
def add_to_loop(self, loop):
if self._loop:
raise Exception('Already added to loop')
self._loop = loop
loop.add(self._listen_sock, EVENT_READ, (self, self.accept))
def _accept(self, listen_sock):
try:
sock, addr = listen_sock.accept()
except (OSError, IOError) as e:
if errno_from_exception(e) in (
errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS,
errno.EINTR, errno.ECONNABORTED
):
pass
else:
raise
logging.info('Connected from {}:{}'.format(*addr))
sock.setblocking(False)
sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
return (sock, addr)
def accept(self, listen_sock, event):
sock, addr = self._accept(listen_sock)
handler = self._handler_type(sock, addr, self._remote_addr,
self._dns_resolver)
handler.add_to_loop(self._loop)
def close(self):
self._loop.remove(self._listen_sock)
def handle_event(self, sock, event, call, *args):
if event & EVENT_ERROR:
logging.error(get_sock_error(sock))
self.close()
else:
try:
call(sock, event, *args)
except Exception as e:
logging.error(e)
self.close()