Merge cee1b0be33
into 938bba32a4
This commit is contained in:
commit
5a89d7429c
19 changed files with 3444 additions and 1 deletions
61
.gitignore
vendored
Normal file
61
.gitignore
vendored
Normal file
|
@ -0,0 +1,61 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
env/
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.coverage
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# database file
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
*.db
|
||||
|
||||
# temporary file
|
||||
*.swp
|
|
@ -1 +1 @@
|
|||
Removed according to regulations.
|
||||
Just for learn.
|
||||
|
|
18
__init__.py
Normal file
18
__init__.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
#!/usr/bin/python
|
||||
#
|
||||
# Copyright 2012-2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
481
asyncdns.py
Normal file
481
asyncdns.py
Normal file
|
@ -0,0 +1,481 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2014-2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import os
|
||||
import socket
|
||||
import struct
|
||||
import re
|
||||
import logging
|
||||
|
||||
from shadowsocks import common, lru_cache, eventloop, shell
|
||||
|
||||
|
||||
CACHE_SWEEP_INTERVAL = 30
|
||||
|
||||
VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
|
||||
|
||||
common.patch_socket()
|
||||
|
||||
# rfc1035
|
||||
# format
|
||||
# +---------------------+
|
||||
# | Header |
|
||||
# +---------------------+
|
||||
# | Question | the question for the name server
|
||||
# +---------------------+
|
||||
# | Answer | RRs answering the question
|
||||
# +---------------------+
|
||||
# | Authority | RRs pointing toward an authority
|
||||
# +---------------------+
|
||||
# | Additional | RRs holding additional information
|
||||
# +---------------------+
|
||||
#
|
||||
# header
|
||||
# 1 1 1 1 1 1
|
||||
# 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
# | ID |
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
# |QR| Opcode |AA|TC|RD|RA| Z | RCODE |
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
# | QDCOUNT |
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
# | ANCOUNT |
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
# | NSCOUNT |
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
# | ARCOUNT |
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
|
||||
QTYPE_ANY = 255
|
||||
QTYPE_A = 1
|
||||
QTYPE_AAAA = 28
|
||||
QTYPE_CNAME = 5
|
||||
QTYPE_NS = 2
|
||||
QCLASS_IN = 1
|
||||
|
||||
|
||||
def build_address(address):
|
||||
address = address.strip(b'.')
|
||||
labels = address.split(b'.')
|
||||
results = []
|
||||
for label in labels:
|
||||
l = len(label)
|
||||
if l > 63:
|
||||
return None
|
||||
results.append(common.chr(l))
|
||||
results.append(label)
|
||||
results.append(b'\0')
|
||||
return b''.join(results)
|
||||
|
||||
|
||||
def build_request(address, qtype):
|
||||
request_id = os.urandom(2)
|
||||
header = struct.pack('!BBHHHH', 1, 0, 1, 0, 0, 0)
|
||||
addr = build_address(address)
|
||||
qtype_qclass = struct.pack('!HH', qtype, QCLASS_IN)
|
||||
return request_id + header + addr + qtype_qclass
|
||||
|
||||
|
||||
def parse_ip(addrtype, data, length, offset):
|
||||
if addrtype == QTYPE_A:
|
||||
return socket.inet_ntop(socket.AF_INET, data[offset:offset + length])
|
||||
elif addrtype == QTYPE_AAAA:
|
||||
return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length])
|
||||
elif addrtype in [QTYPE_CNAME, QTYPE_NS]:
|
||||
return parse_name(data, offset)[1]
|
||||
else:
|
||||
return data[offset:offset + length]
|
||||
|
||||
|
||||
def parse_name(data, offset):
|
||||
p = offset
|
||||
labels = []
|
||||
l = common.ord(data[p])
|
||||
while l > 0:
|
||||
if (l & (128 + 64)) == (128 + 64):
|
||||
# pointer
|
||||
pointer = struct.unpack('!H', data[p:p + 2])[0]
|
||||
pointer &= 0x3FFF
|
||||
r = parse_name(data, pointer)
|
||||
labels.append(r[1])
|
||||
p += 2
|
||||
# pointer is the end
|
||||
return p - offset, b'.'.join(labels)
|
||||
else:
|
||||
labels.append(data[p + 1:p + 1 + l])
|
||||
p += 1 + l
|
||||
l = common.ord(data[p])
|
||||
return p - offset + 1, b'.'.join(labels)
|
||||
|
||||
|
||||
# rfc1035
|
||||
# record
|
||||
# 1 1 1 1 1 1
|
||||
# 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
# | |
|
||||
# / /
|
||||
# / NAME /
|
||||
# | |
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
# | TYPE |
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
# | CLASS |
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
# | TTL |
|
||||
# | |
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
# | RDLENGTH |
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--|
|
||||
# / RDATA /
|
||||
# / /
|
||||
# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
def parse_record(data, offset, question=False):
|
||||
nlen, name = parse_name(data, offset)
|
||||
if not question:
|
||||
record_type, record_class, record_ttl, record_rdlength = struct.unpack(
|
||||
'!HHiH', data[offset + nlen:offset + nlen + 10]
|
||||
)
|
||||
ip = parse_ip(record_type, data, record_rdlength, offset + nlen + 10)
|
||||
return nlen + 10 + record_rdlength, \
|
||||
(name, ip, record_type, record_class, record_ttl)
|
||||
else:
|
||||
record_type, record_class = struct.unpack(
|
||||
'!HH', data[offset + nlen:offset + nlen + 4]
|
||||
)
|
||||
return nlen + 4, (name, None, record_type, record_class, None, None)
|
||||
|
||||
|
||||
def parse_header(data):
|
||||
if len(data) >= 12:
|
||||
header = struct.unpack('!HBBHHHH', data[:12])
|
||||
res_id = header[0]
|
||||
res_qr = header[1] & 128
|
||||
res_tc = header[1] & 2
|
||||
res_ra = header[2] & 128
|
||||
res_rcode = header[2] & 15
|
||||
# assert res_tc == 0
|
||||
# assert res_rcode in [0, 3]
|
||||
res_qdcount = header[3]
|
||||
res_ancount = header[4]
|
||||
res_nscount = header[5]
|
||||
res_arcount = header[6]
|
||||
return (res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount,
|
||||
res_ancount, res_nscount, res_arcount)
|
||||
return None
|
||||
|
||||
|
||||
def parse_response(data):
|
||||
try:
|
||||
if len(data) >= 12:
|
||||
header = parse_header(data)
|
||||
if not header:
|
||||
return None
|
||||
res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount, \
|
||||
res_ancount, res_nscount, res_arcount = header
|
||||
|
||||
qds = []
|
||||
ans = []
|
||||
offset = 12
|
||||
for i in range(0, res_qdcount):
|
||||
l, r = parse_record(data, offset, True)
|
||||
offset += l
|
||||
if r:
|
||||
qds.append(r)
|
||||
for i in range(0, res_ancount):
|
||||
l, r = parse_record(data, offset)
|
||||
offset += l
|
||||
if r:
|
||||
ans.append(r)
|
||||
for i in range(0, res_nscount):
|
||||
l, r = parse_record(data, offset)
|
||||
offset += l
|
||||
for i in range(0, res_arcount):
|
||||
l, r = parse_record(data, offset)
|
||||
offset += l
|
||||
response = DNSResponse()
|
||||
if qds:
|
||||
response.hostname = qds[0][0]
|
||||
for an in qds:
|
||||
response.questions.append((an[1], an[2], an[3]))
|
||||
for an in ans:
|
||||
response.answers.append((an[1], an[2], an[3]))
|
||||
return response
|
||||
except Exception as e:
|
||||
shell.print_exception(e)
|
||||
return None
|
||||
|
||||
|
||||
def is_valid_hostname(hostname):
|
||||
if len(hostname) > 255:
|
||||
return False
|
||||
if hostname[-1] == b'.':
|
||||
hostname = hostname[:-1]
|
||||
return all(VALID_HOSTNAME.match(x) for x in hostname.split(b'.'))
|
||||
|
||||
|
||||
class DNSResponse(object):
|
||||
def __init__(self):
|
||||
self.hostname = None
|
||||
self.questions = [] # each: (addr, type, class)
|
||||
self.answers = [] # each: (addr, type, class)
|
||||
|
||||
def __str__(self):
|
||||
return '%s: %s' % (self.hostname, str(self.answers))
|
||||
|
||||
|
||||
STATUS_IPV4 = 0
|
||||
STATUS_IPV6 = 1
|
||||
|
||||
|
||||
class DNSResolver(object):
|
||||
|
||||
def __init__(self):
|
||||
self._loop = None
|
||||
self._hosts = {}
|
||||
self._hostname_status = {}
|
||||
self._hostname_to_cb = {}
|
||||
self._cb_to_hostname = {}
|
||||
self._cache = lru_cache.LRUCache(timeout=300)
|
||||
self._sock = None
|
||||
self._servers = None
|
||||
self._parse_resolv()
|
||||
self._parse_hosts()
|
||||
# TODO monitor hosts change and reload hosts
|
||||
# TODO parse /etc/gai.conf and follow its rules
|
||||
|
||||
def _parse_resolv(self):
|
||||
self._servers = []
|
||||
try:
|
||||
with open('/etc/resolv.conf', 'rb') as f:
|
||||
content = f.readlines()
|
||||
for line in content:
|
||||
line = line.strip()
|
||||
if line:
|
||||
if line.startswith(b'nameserver'):
|
||||
parts = line.split()
|
||||
if len(parts) >= 2:
|
||||
server = parts[1]
|
||||
if common.is_ip(server) == socket.AF_INET:
|
||||
if type(server) != str:
|
||||
server = server.decode('utf8')
|
||||
self._servers.append(server)
|
||||
except IOError:
|
||||
pass
|
||||
if not self._servers:
|
||||
self._servers = ['8.8.4.4', '8.8.8.8']
|
||||
|
||||
def _parse_hosts(self):
|
||||
etc_path = '/etc/hosts'
|
||||
if 'WINDIR' in os.environ:
|
||||
etc_path = os.environ['WINDIR'] + '/system32/drivers/etc/hosts'
|
||||
try:
|
||||
with open(etc_path, 'rb') as f:
|
||||
for line in f.readlines():
|
||||
line = line.strip()
|
||||
parts = line.split()
|
||||
if len(parts) >= 2:
|
||||
ip = parts[0]
|
||||
if common.is_ip(ip):
|
||||
for i in range(1, len(parts)):
|
||||
hostname = parts[i]
|
||||
if hostname:
|
||||
self._hosts[hostname] = ip
|
||||
except IOError:
|
||||
self._hosts['localhost'] = '127.0.0.1'
|
||||
|
||||
def add_to_loop(self, loop):
|
||||
if self._loop:
|
||||
raise Exception('already add to loop')
|
||||
self._loop = loop
|
||||
# TODO when dns server is IPv6
|
||||
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
|
||||
socket.SOL_UDP)
|
||||
self._sock.setblocking(False)
|
||||
loop.add(self._sock, eventloop.POLL_IN, self)
|
||||
loop.add_periodic(self.handle_periodic)
|
||||
|
||||
def _call_callback(self, hostname, ip, error=None):
|
||||
callbacks = self._hostname_to_cb.get(hostname, [])
|
||||
for callback in callbacks:
|
||||
if callback in self._cb_to_hostname:
|
||||
del self._cb_to_hostname[callback]
|
||||
if ip or error:
|
||||
callback((hostname, ip), error)
|
||||
else:
|
||||
callback((hostname, None),
|
||||
Exception('unknown hostname %s' % hostname))
|
||||
if hostname in self._hostname_to_cb:
|
||||
del self._hostname_to_cb[hostname]
|
||||
if hostname in self._hostname_status:
|
||||
del self._hostname_status[hostname]
|
||||
|
||||
def _handle_data(self, data):
|
||||
response = parse_response(data)
|
||||
if response and response.hostname:
|
||||
hostname = response.hostname
|
||||
ip = None
|
||||
for answer in response.answers:
|
||||
if answer[1] in (QTYPE_A, QTYPE_AAAA) and \
|
||||
answer[2] == QCLASS_IN:
|
||||
ip = answer[0]
|
||||
break
|
||||
if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \
|
||||
== STATUS_IPV4:
|
||||
self._hostname_status[hostname] = STATUS_IPV6
|
||||
self._send_req(hostname, QTYPE_AAAA)
|
||||
else:
|
||||
if ip:
|
||||
self._cache[hostname] = ip
|
||||
self._call_callback(hostname, ip)
|
||||
elif self._hostname_status.get(hostname, None) == STATUS_IPV6:
|
||||
for question in response.questions:
|
||||
if question[1] == QTYPE_AAAA:
|
||||
self._call_callback(hostname, None)
|
||||
break
|
||||
|
||||
def handle_event(self, sock, event):
|
||||
if sock != self._sock:
|
||||
return
|
||||
if event & eventloop.POLL_ERR:
|
||||
logging.error('dns socket err')
|
||||
self._loop.remove(self._sock)
|
||||
self._sock.close()
|
||||
# TODO when dns server is IPv6
|
||||
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM,
|
||||
socket.SOL_UDP)
|
||||
self._sock.setblocking(False)
|
||||
self._loop.add(self._sock, eventloop.POLL_IN, self)
|
||||
else:
|
||||
data, addr = sock.recvfrom(1024)
|
||||
if addr[0] not in self._servers:
|
||||
logging.warn('received a packet other than our dns')
|
||||
return
|
||||
self._handle_data(data)
|
||||
|
||||
def handle_periodic(self):
|
||||
self._cache.sweep()
|
||||
|
||||
def remove_callback(self, callback):
|
||||
hostname = self._cb_to_hostname.get(callback)
|
||||
if hostname:
|
||||
del self._cb_to_hostname[callback]
|
||||
arr = self._hostname_to_cb.get(hostname, None)
|
||||
if arr:
|
||||
arr.remove(callback)
|
||||
if not arr:
|
||||
del self._hostname_to_cb[hostname]
|
||||
if hostname in self._hostname_status:
|
||||
del self._hostname_status[hostname]
|
||||
|
||||
def _send_req(self, hostname, qtype):
|
||||
req = build_request(hostname, qtype)
|
||||
for server in self._servers:
|
||||
logging.debug('resolving %s with type %d using server %s',
|
||||
hostname, qtype, server)
|
||||
self._sock.sendto(req, (server, 53))
|
||||
|
||||
def resolve(self, hostname, callback):
|
||||
if type(hostname) != bytes:
|
||||
hostname = hostname.encode('utf8')
|
||||
if not hostname:
|
||||
callback(None, Exception('empty hostname'))
|
||||
elif common.is_ip(hostname):
|
||||
callback((hostname, hostname), None)
|
||||
elif hostname in self._hosts:
|
||||
logging.debug('hit hosts: %s', hostname)
|
||||
ip = self._hosts[hostname]
|
||||
callback((hostname, ip), None)
|
||||
elif hostname in self._cache:
|
||||
logging.debug('hit cache: %s', hostname)
|
||||
ip = self._cache[hostname]
|
||||
callback((hostname, ip), None)
|
||||
else:
|
||||
if not is_valid_hostname(hostname):
|
||||
callback(None, Exception('invalid hostname: %s' % hostname))
|
||||
return
|
||||
arr = self._hostname_to_cb.get(hostname, None)
|
||||
if not arr:
|
||||
self._hostname_status[hostname] = STATUS_IPV4
|
||||
self._send_req(hostname, QTYPE_A)
|
||||
self._hostname_to_cb[hostname] = [callback]
|
||||
self._cb_to_hostname[callback] = hostname
|
||||
else:
|
||||
arr.append(callback)
|
||||
# TODO send again only if waited too long
|
||||
self._send_req(hostname, QTYPE_A)
|
||||
|
||||
def close(self):
|
||||
if self._sock:
|
||||
if self._loop:
|
||||
self._loop.remove_periodic(self.handle_periodic)
|
||||
self._loop.remove(self._sock)
|
||||
self._sock.close()
|
||||
self._sock = None
|
||||
|
||||
|
||||
def test():
|
||||
dns_resolver = DNSResolver()
|
||||
loop = eventloop.EventLoop()
|
||||
dns_resolver.add_to_loop(loop)
|
||||
|
||||
global counter
|
||||
counter = 0
|
||||
|
||||
def make_callback():
|
||||
global counter
|
||||
|
||||
def callback(result, error):
|
||||
global counter
|
||||
# TODO: what can we assert?
|
||||
print(result, error)
|
||||
counter += 1
|
||||
if counter == 9:
|
||||
dns_resolver.close()
|
||||
loop.stop()
|
||||
a_callback = callback
|
||||
return a_callback
|
||||
|
||||
assert(make_callback() != make_callback())
|
||||
|
||||
dns_resolver.resolve(b'google.com', make_callback())
|
||||
dns_resolver.resolve('google.com', make_callback())
|
||||
dns_resolver.resolve('example.com', make_callback())
|
||||
dns_resolver.resolve('ipv6.google.com', make_callback())
|
||||
dns_resolver.resolve('www.facebook.com', make_callback())
|
||||
dns_resolver.resolve('ns2.google.com', make_callback())
|
||||
dns_resolver.resolve('invalid.@!#$%^&$@.hostname', make_callback())
|
||||
dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo'
|
||||
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
|
||||
'long.hostname', make_callback())
|
||||
dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo'
|
||||
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
|
||||
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
|
||||
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
|
||||
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
|
||||
'ooooooooooooooooooooooooooooooooooooooooooooooooooo'
|
||||
'long.hostname', make_callback())
|
||||
|
||||
loop.run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
33
client.py
Normal file
33
client.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, path)
|
||||
from shadowsocks.eventloop import EventLoop
|
||||
from shadowsocks.tcprelay import TcpRelay, TcpRelayClientHanler
|
||||
|
||||
|
||||
FORMATTER = '%(asctime)s - %(levelname)s - %(message)s'
|
||||
LOGGING_LEVEL = logging.INFO
|
||||
logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER)
|
||||
|
||||
LISTEN_ADDR = ('127.0.0.1', 1080)
|
||||
REMOTE_ADDR = ('127.0.0.1', 9000)
|
||||
|
||||
|
||||
def main():
|
||||
loop = EventLoop()
|
||||
relay = TcpRelay(TcpRelayClientHanler, LISTEN_ADDR, REMOTE_ADDR)
|
||||
relay.add_to_loop(loop)
|
||||
loop.run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
281
common.py
Normal file
281
common.py
Normal file
|
@ -0,0 +1,281 @@
|
|||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2013-2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import socket
|
||||
import struct
|
||||
import logging
|
||||
|
||||
|
||||
def compat_ord(s):
|
||||
if type(s) == int:
|
||||
return s
|
||||
return _ord(s)
|
||||
|
||||
|
||||
def compat_chr(d):
|
||||
if bytes == str:
|
||||
return _chr(d)
|
||||
return bytes([d])
|
||||
|
||||
|
||||
_ord = ord
|
||||
_chr = chr
|
||||
ord = compat_ord
|
||||
chr = compat_chr
|
||||
|
||||
|
||||
def to_bytes(s):
|
||||
if bytes != str:
|
||||
if type(s) == str:
|
||||
return s.encode('utf-8')
|
||||
return s
|
||||
|
||||
|
||||
def to_str(s):
|
||||
if bytes != str:
|
||||
if type(s) == bytes:
|
||||
return s.decode('utf-8')
|
||||
return s
|
||||
|
||||
|
||||
def inet_ntop(family, ipstr):
|
||||
if family == socket.AF_INET:
|
||||
return to_bytes(socket.inet_ntoa(ipstr))
|
||||
elif family == socket.AF_INET6:
|
||||
import re
|
||||
v6addr = ':'.join(('%02X%02X' % (ord(i), ord(j))).lstrip('0')
|
||||
for i, j in zip(ipstr[::2], ipstr[1::2]))
|
||||
v6addr = re.sub('::+', '::', v6addr, count=1)
|
||||
return to_bytes(v6addr)
|
||||
|
||||
|
||||
def inet_pton(family, addr):
|
||||
addr = to_str(addr)
|
||||
if family == socket.AF_INET:
|
||||
return socket.inet_aton(addr)
|
||||
elif family == socket.AF_INET6:
|
||||
if '.' in addr: # a v4 addr
|
||||
v4addr = addr[addr.rindex(':') + 1:]
|
||||
v4addr = socket.inet_aton(v4addr)
|
||||
v4addr = map(lambda x: ('%02X' % ord(x)), v4addr)
|
||||
v4addr.insert(2, ':')
|
||||
newaddr = addr[:addr.rindex(':') + 1] + ''.join(v4addr)
|
||||
return inet_pton(family, newaddr)
|
||||
dbyts = [0] * 8 # 8 groups
|
||||
grps = addr.split(':')
|
||||
for i, v in enumerate(grps):
|
||||
if v:
|
||||
dbyts[i] = int(v, 16)
|
||||
else:
|
||||
for j, w in enumerate(grps[::-1]):
|
||||
if w:
|
||||
dbyts[7 - j] = int(w, 16)
|
||||
else:
|
||||
break
|
||||
break
|
||||
return b''.join((chr(i // 256) + chr(i % 256)) for i in dbyts)
|
||||
else:
|
||||
raise RuntimeError("What family?")
|
||||
|
||||
|
||||
def is_ip(address):
|
||||
for family in (socket.AF_INET, socket.AF_INET6):
|
||||
try:
|
||||
if type(address) != str:
|
||||
address = address.decode('utf8')
|
||||
inet_pton(family, address)
|
||||
return family
|
||||
except (TypeError, ValueError, OSError, IOError):
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def patch_socket():
|
||||
if not hasattr(socket, 'inet_pton'):
|
||||
socket.inet_pton = inet_pton
|
||||
|
||||
if not hasattr(socket, 'inet_ntop'):
|
||||
socket.inet_ntop = inet_ntop
|
||||
|
||||
|
||||
patch_socket()
|
||||
|
||||
|
||||
ADDRTYPE_IPV4 = 1
|
||||
ADDRTYPE_IPV6 = 4
|
||||
ADDRTYPE_HOST = 3
|
||||
|
||||
|
||||
def pack_addr(address):
|
||||
address_str = to_str(address)
|
||||
for family in (socket.AF_INET, socket.AF_INET6):
|
||||
try:
|
||||
r = socket.inet_pton(family, address_str)
|
||||
if family == socket.AF_INET6:
|
||||
return b'\x04' + r
|
||||
else:
|
||||
return b'\x01' + r
|
||||
except (TypeError, ValueError, OSError, IOError):
|
||||
pass
|
||||
if len(address) > 255:
|
||||
address = address[:255] # TODO
|
||||
return b'\x03' + chr(len(address)) + address
|
||||
|
||||
|
||||
def parse_header(data):
|
||||
addrtype = ord(data[0])
|
||||
dest_addr = None
|
||||
dest_port = None
|
||||
header_length = 0
|
||||
if addrtype == ADDRTYPE_IPV4:
|
||||
if len(data) >= 7:
|
||||
dest_addr = socket.inet_ntoa(data[1:5])
|
||||
dest_port = struct.unpack('>H', data[5:7])[0]
|
||||
header_length = 7
|
||||
else:
|
||||
logging.warn('header is too short')
|
||||
elif addrtype == ADDRTYPE_HOST:
|
||||
if len(data) > 2:
|
||||
addrlen = ord(data[1])
|
||||
if len(data) >= 2 + addrlen:
|
||||
dest_addr = data[2:2 + addrlen]
|
||||
dest_port = struct.unpack('>H', data[2 + addrlen:4 +
|
||||
addrlen])[0]
|
||||
header_length = 4 + addrlen
|
||||
else:
|
||||
logging.warn('header is too short')
|
||||
else:
|
||||
logging.warn('header is too short')
|
||||
elif addrtype == ADDRTYPE_IPV6:
|
||||
if len(data) >= 19:
|
||||
dest_addr = socket.inet_ntop(socket.AF_INET6, data[1:17])
|
||||
dest_port = struct.unpack('>H', data[17:19])[0]
|
||||
header_length = 19
|
||||
else:
|
||||
logging.warn('header is too short')
|
||||
else:
|
||||
logging.warn('unsupported addrtype %d, maybe wrong password or '
|
||||
'encryption method' % addrtype)
|
||||
if dest_addr is None:
|
||||
return None
|
||||
return addrtype, to_bytes(dest_addr), dest_port, header_length
|
||||
|
||||
|
||||
class IPNetwork(object):
|
||||
ADDRLENGTH = {socket.AF_INET: 32, socket.AF_INET6: 128, False: 0}
|
||||
|
||||
def __init__(self, addrs):
|
||||
self._network_list_v4 = []
|
||||
self._network_list_v6 = []
|
||||
if type(addrs) == str:
|
||||
addrs = addrs.split(',')
|
||||
list(map(self.add_network, addrs))
|
||||
|
||||
def add_network(self, addr):
|
||||
if addr is "":
|
||||
return
|
||||
block = addr.split('/')
|
||||
addr_family = is_ip(block[0])
|
||||
addr_len = IPNetwork.ADDRLENGTH[addr_family]
|
||||
if addr_family is socket.AF_INET:
|
||||
ip, = struct.unpack("!I", socket.inet_aton(block[0]))
|
||||
elif addr_family is socket.AF_INET6:
|
||||
hi, lo = struct.unpack("!QQ", inet_pton(addr_family, block[0]))
|
||||
ip = (hi << 64) | lo
|
||||
else:
|
||||
raise Exception("Not a valid CIDR notation: %s" % addr)
|
||||
if len(block) is 1:
|
||||
prefix_size = 0
|
||||
while (ip & 1) == 0 and ip is not 0:
|
||||
ip >>= 1
|
||||
prefix_size += 1
|
||||
logging.warn("You did't specify CIDR routing prefix size for %s, "
|
||||
"implicit treated as %s/%d" % (addr, addr, addr_len))
|
||||
elif block[1].isdigit() and int(block[1]) <= addr_len:
|
||||
prefix_size = addr_len - int(block[1])
|
||||
ip >>= prefix_size
|
||||
else:
|
||||
raise Exception("Not a valid CIDR notation: %s" % addr)
|
||||
if addr_family is socket.AF_INET:
|
||||
self._network_list_v4.append((ip, prefix_size))
|
||||
else:
|
||||
self._network_list_v6.append((ip, prefix_size))
|
||||
|
||||
def __contains__(self, addr):
|
||||
addr_family = is_ip(addr)
|
||||
if addr_family is socket.AF_INET:
|
||||
ip, = struct.unpack("!I", socket.inet_aton(addr))
|
||||
return any(map(lambda n_ps: n_ps[0] == ip >> n_ps[1],
|
||||
self._network_list_v4))
|
||||
elif addr_family is socket.AF_INET6:
|
||||
hi, lo = struct.unpack("!QQ", inet_pton(addr_family, addr))
|
||||
ip = (hi << 64) | lo
|
||||
return any(map(lambda n_ps: n_ps[0] == ip >> n_ps[1],
|
||||
self._network_list_v6))
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def test_inet_conv():
|
||||
ipv4 = b'8.8.4.4'
|
||||
b = inet_pton(socket.AF_INET, ipv4)
|
||||
assert inet_ntop(socket.AF_INET, b) == ipv4
|
||||
ipv6 = b'2404:6800:4005:805::1011'
|
||||
b = inet_pton(socket.AF_INET6, ipv6)
|
||||
assert inet_ntop(socket.AF_INET6, b) == ipv6
|
||||
|
||||
|
||||
def test_parse_header():
|
||||
assert parse_header(b'\x03\x0ewww.google.com\x00\x50') == \
|
||||
(3, b'www.google.com', 80, 18)
|
||||
assert parse_header(b'\x01\x08\x08\x08\x08\x00\x35') == \
|
||||
(1, b'8.8.8.8', 53, 7)
|
||||
assert parse_header((b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00'
|
||||
b'\x00\x10\x11\x00\x50')) == \
|
||||
(4, b'2404:6800:4005:805::1011', 80, 19)
|
||||
|
||||
|
||||
def test_pack_header():
|
||||
assert pack_addr(b'8.8.8.8') == b'\x01\x08\x08\x08\x08'
|
||||
assert pack_addr(b'2404:6800:4005:805::1011') == \
|
||||
b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00\x00\x10\x11'
|
||||
assert pack_addr(b'www.google.com') == b'\x03\x0ewww.google.com'
|
||||
|
||||
|
||||
def test_ip_network():
|
||||
ip_network = IPNetwork('127.0.0.0/24,::ff:1/112,::1,192.168.1.1,192.0.2.0')
|
||||
assert '127.0.0.1' in ip_network
|
||||
assert '127.0.1.1' not in ip_network
|
||||
assert ':ff:ffff' in ip_network
|
||||
assert '::ffff:1' not in ip_network
|
||||
assert '::1' in ip_network
|
||||
assert '::2' not in ip_network
|
||||
assert '192.168.1.1' in ip_network
|
||||
assert '192.168.1.2' not in ip_network
|
||||
assert '192.0.2.1' in ip_network
|
||||
assert '192.0.3.1' in ip_network # 192.0.2.0 is treated as 192.0.2.0/23
|
||||
assert 'www.google.com' not in ip_network
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_inet_conv()
|
||||
test_parse_header()
|
||||
test_pack_header()
|
||||
test_ip_network()
|
18
crypto/__init__.py
Normal file
18
crypto/__init__.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright 2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
181
crypto/openssl.py
Normal file
181
crypto/openssl.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright 2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
from ctypes import c_char_p, c_int, c_long, byref,\
|
||||
create_string_buffer, c_void_p
|
||||
|
||||
from shadowsocks import common
|
||||
from shadowsocks.crypto import util
|
||||
|
||||
__all__ = ['ciphers']
|
||||
|
||||
libcrypto = None
|
||||
loaded = False
|
||||
|
||||
buf_size = 2048
|
||||
|
||||
|
||||
def load_openssl():
|
||||
global loaded, libcrypto, buf
|
||||
|
||||
libcrypto = util.find_library(('crypto', 'eay32'),
|
||||
'EVP_get_cipherbyname',
|
||||
'libcrypto')
|
||||
if libcrypto is None:
|
||||
raise Exception('libcrypto(OpenSSL) not found')
|
||||
|
||||
libcrypto.EVP_get_cipherbyname.restype = c_void_p
|
||||
libcrypto.EVP_CIPHER_CTX_new.restype = c_void_p
|
||||
|
||||
libcrypto.EVP_CipherInit_ex.argtypes = (c_void_p, c_void_p, c_char_p,
|
||||
c_char_p, c_char_p, c_int)
|
||||
|
||||
libcrypto.EVP_CipherUpdate.argtypes = (c_void_p, c_void_p, c_void_p,
|
||||
c_char_p, c_int)
|
||||
|
||||
libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,)
|
||||
libcrypto.EVP_CIPHER_CTX_free.argtypes = (c_void_p,)
|
||||
if hasattr(libcrypto, 'OpenSSL_add_all_ciphers'):
|
||||
libcrypto.OpenSSL_add_all_ciphers()
|
||||
|
||||
buf = create_string_buffer(buf_size)
|
||||
loaded = True
|
||||
|
||||
|
||||
def load_cipher(cipher_name):
|
||||
func_name = 'EVP_' + cipher_name.replace('-', '_')
|
||||
if bytes != str:
|
||||
func_name = str(func_name, 'utf-8')
|
||||
cipher = getattr(libcrypto, func_name, None)
|
||||
if cipher:
|
||||
cipher.restype = c_void_p
|
||||
return cipher()
|
||||
return None
|
||||
|
||||
|
||||
class OpenSSLCrypto(object):
|
||||
def __init__(self, cipher_name, key, iv, op):
|
||||
self._ctx = None
|
||||
if not loaded:
|
||||
load_openssl()
|
||||
cipher_name = common.to_bytes(cipher_name)
|
||||
cipher = libcrypto.EVP_get_cipherbyname(cipher_name)
|
||||
if not cipher:
|
||||
cipher = load_cipher(cipher_name)
|
||||
if not cipher:
|
||||
raise Exception('cipher %s not found in libcrypto' % cipher_name)
|
||||
key_ptr = c_char_p(key)
|
||||
iv_ptr = c_char_p(iv)
|
||||
self._ctx = libcrypto.EVP_CIPHER_CTX_new()
|
||||
if not self._ctx:
|
||||
raise Exception('can not create cipher context')
|
||||
r = libcrypto.EVP_CipherInit_ex(self._ctx, cipher, None,
|
||||
key_ptr, iv_ptr, c_int(op))
|
||||
if not r:
|
||||
self.clean()
|
||||
raise Exception('can not initialize cipher context')
|
||||
|
||||
def update(self, data):
|
||||
global buf_size, buf
|
||||
cipher_out_len = c_long(0)
|
||||
l = len(data)
|
||||
if buf_size < l:
|
||||
buf_size = l * 2
|
||||
buf = create_string_buffer(buf_size)
|
||||
libcrypto.EVP_CipherUpdate(self._ctx, byref(buf),
|
||||
byref(cipher_out_len), c_char_p(data), l)
|
||||
# buf is copied to a str object when we access buf.raw
|
||||
return buf.raw[:cipher_out_len.value]
|
||||
|
||||
def __del__(self):
|
||||
self.clean()
|
||||
|
||||
def clean(self):
|
||||
if self._ctx:
|
||||
libcrypto.EVP_CIPHER_CTX_cleanup(self._ctx)
|
||||
libcrypto.EVP_CIPHER_CTX_free(self._ctx)
|
||||
|
||||
|
||||
ciphers = {
|
||||
'aes-128-cfb': (16, 16, OpenSSLCrypto),
|
||||
'aes-192-cfb': (24, 16, OpenSSLCrypto),
|
||||
'aes-256-cfb': (32, 16, OpenSSLCrypto),
|
||||
'aes-128-ofb': (16, 16, OpenSSLCrypto),
|
||||
'aes-192-ofb': (24, 16, OpenSSLCrypto),
|
||||
'aes-256-ofb': (32, 16, OpenSSLCrypto),
|
||||
'aes-128-ctr': (16, 16, OpenSSLCrypto),
|
||||
'aes-192-ctr': (24, 16, OpenSSLCrypto),
|
||||
'aes-256-ctr': (32, 16, OpenSSLCrypto),
|
||||
'aes-128-cfb8': (16, 16, OpenSSLCrypto),
|
||||
'aes-192-cfb8': (24, 16, OpenSSLCrypto),
|
||||
'aes-256-cfb8': (32, 16, OpenSSLCrypto),
|
||||
'aes-128-cfb1': (16, 16, OpenSSLCrypto),
|
||||
'aes-192-cfb1': (24, 16, OpenSSLCrypto),
|
||||
'aes-256-cfb1': (32, 16, OpenSSLCrypto),
|
||||
'bf-cfb': (16, 8, OpenSSLCrypto),
|
||||
'camellia-128-cfb': (16, 16, OpenSSLCrypto),
|
||||
'camellia-192-cfb': (24, 16, OpenSSLCrypto),
|
||||
'camellia-256-cfb': (32, 16, OpenSSLCrypto),
|
||||
'cast5-cfb': (16, 8, OpenSSLCrypto),
|
||||
'des-cfb': (8, 8, OpenSSLCrypto),
|
||||
'idea-cfb': (16, 8, OpenSSLCrypto),
|
||||
'rc2-cfb': (16, 8, OpenSSLCrypto),
|
||||
'rc4': (16, 0, OpenSSLCrypto),
|
||||
'seed-cfb': (16, 16, OpenSSLCrypto),
|
||||
}
|
||||
|
||||
|
||||
def run_method(method):
|
||||
|
||||
cipher = OpenSSLCrypto(method, b'k' * 32, b'i' * 16, 1)
|
||||
decipher = OpenSSLCrypto(method, b'k' * 32, b'i' * 16, 0)
|
||||
|
||||
util.run_cipher(cipher, decipher)
|
||||
|
||||
|
||||
def test_aes_128_cfb():
|
||||
run_method('aes-128-cfb')
|
||||
|
||||
|
||||
def test_aes_256_cfb():
|
||||
run_method('aes-256-cfb')
|
||||
|
||||
|
||||
def test_aes_128_cfb8():
|
||||
run_method('aes-128-cfb8')
|
||||
|
||||
|
||||
def test_aes_256_ofb():
|
||||
run_method('aes-256-ofb')
|
||||
|
||||
|
||||
def test_aes_256_ctr():
|
||||
run_method('aes-256-ctr')
|
||||
|
||||
|
||||
def test_bf_cfb():
|
||||
run_method('bf-cfb')
|
||||
|
||||
|
||||
def test_rc4():
|
||||
run_method('rc4')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_aes_128_cfb()
|
51
crypto/rc4_md5.py
Normal file
51
crypto/rc4_md5.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright 2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import hashlib
|
||||
|
||||
from shadowsocks.crypto import openssl
|
||||
|
||||
__all__ = ['ciphers']
|
||||
|
||||
|
||||
def create_cipher(alg, key, iv, op, key_as_bytes=0, d=None, salt=None,
|
||||
i=1, padding=1):
|
||||
md5 = hashlib.md5()
|
||||
md5.update(key)
|
||||
md5.update(iv)
|
||||
rc4_key = md5.digest()
|
||||
return openssl.OpenSSLCrypto(b'rc4', rc4_key, b'', op)
|
||||
|
||||
|
||||
ciphers = {
|
||||
'rc4-md5': (16, 16, create_cipher),
|
||||
}
|
||||
|
||||
|
||||
def test():
|
||||
from shadowsocks.crypto import util
|
||||
|
||||
cipher = create_cipher('rc4-md5', b'k' * 32, b'i' * 16, 1)
|
||||
decipher = create_cipher('rc4-md5', b'k' * 32, b'i' * 16, 0)
|
||||
|
||||
util.run_cipher(cipher, decipher)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
120
crypto/sodium.py
Normal file
120
crypto/sodium.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright 2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
from ctypes import c_char_p, c_int, c_ulonglong, byref, \
|
||||
create_string_buffer, c_void_p
|
||||
|
||||
from shadowsocks.crypto import util
|
||||
|
||||
__all__ = ['ciphers']
|
||||
|
||||
libsodium = None
|
||||
loaded = False
|
||||
|
||||
buf_size = 2048
|
||||
|
||||
# for salsa20 and chacha20
|
||||
BLOCK_SIZE = 64
|
||||
|
||||
|
||||
def load_libsodium():
|
||||
global loaded, libsodium, buf
|
||||
|
||||
libsodium = util.find_library('sodium', 'crypto_stream_salsa20_xor_ic',
|
||||
'libsodium')
|
||||
if libsodium is None:
|
||||
raise Exception('libsodium not found')
|
||||
|
||||
libsodium.crypto_stream_salsa20_xor_ic.restype = c_int
|
||||
libsodium.crypto_stream_salsa20_xor_ic.argtypes = (c_void_p, c_char_p,
|
||||
c_ulonglong,
|
||||
c_char_p, c_ulonglong,
|
||||
c_char_p)
|
||||
libsodium.crypto_stream_chacha20_xor_ic.restype = c_int
|
||||
libsodium.crypto_stream_chacha20_xor_ic.argtypes = (c_void_p, c_char_p,
|
||||
c_ulonglong,
|
||||
c_char_p, c_ulonglong,
|
||||
c_char_p)
|
||||
|
||||
buf = create_string_buffer(buf_size)
|
||||
loaded = True
|
||||
|
||||
|
||||
class SodiumCrypto(object):
|
||||
def __init__(self, cipher_name, key, iv, op):
|
||||
if not loaded:
|
||||
load_libsodium()
|
||||
self.key = key
|
||||
self.iv = iv
|
||||
self.key_ptr = c_char_p(key)
|
||||
self.iv_ptr = c_char_p(iv)
|
||||
if cipher_name == 'salsa20':
|
||||
self.cipher = libsodium.crypto_stream_salsa20_xor_ic
|
||||
elif cipher_name == 'chacha20':
|
||||
self.cipher = libsodium.crypto_stream_chacha20_xor_ic
|
||||
else:
|
||||
raise Exception('Unknown cipher')
|
||||
# byte counter, not block counter
|
||||
self.counter = 0
|
||||
|
||||
def update(self, data):
|
||||
global buf_size, buf
|
||||
l = len(data)
|
||||
|
||||
# we can only prepend some padding to make the encryption align to
|
||||
# blocks
|
||||
padding = self.counter % BLOCK_SIZE
|
||||
if buf_size < padding + l:
|
||||
buf_size = (padding + l) * 2
|
||||
buf = create_string_buffer(buf_size)
|
||||
|
||||
if padding:
|
||||
data = (b'\0' * padding) + data
|
||||
self.cipher(byref(buf), c_char_p(data), padding + l,
|
||||
self.iv_ptr, int(self.counter / BLOCK_SIZE), self.key_ptr)
|
||||
self.counter += l
|
||||
# buf is copied to a str object when we access buf.raw
|
||||
# strip off the padding
|
||||
return buf.raw[padding:padding + l]
|
||||
|
||||
|
||||
ciphers = {
|
||||
'salsa20': (32, 8, SodiumCrypto),
|
||||
'chacha20': (32, 8, SodiumCrypto),
|
||||
}
|
||||
|
||||
|
||||
def test_salsa20():
|
||||
cipher = SodiumCrypto('salsa20', b'k' * 32, b'i' * 16, 1)
|
||||
decipher = SodiumCrypto('salsa20', b'k' * 32, b'i' * 16, 0)
|
||||
|
||||
util.run_cipher(cipher, decipher)
|
||||
|
||||
|
||||
def test_chacha20():
|
||||
|
||||
cipher = SodiumCrypto('chacha20', b'k' * 32, b'i' * 16, 1)
|
||||
decipher = SodiumCrypto('chacha20', b'k' * 32, b'i' * 16, 0)
|
||||
|
||||
util.run_cipher(cipher, decipher)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_chacha20()
|
||||
test_salsa20()
|
174
crypto/table.py
Normal file
174
crypto/table.py
Normal file
|
@ -0,0 +1,174 @@
|
|||
# !/usr/bin/env python
|
||||
#
|
||||
# Copyright 2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import string
|
||||
import struct
|
||||
import hashlib
|
||||
|
||||
|
||||
__all__ = ['ciphers']
|
||||
|
||||
cached_tables = {}
|
||||
|
||||
if hasattr(string, 'maketrans'):
|
||||
maketrans = string.maketrans
|
||||
translate = string.translate
|
||||
else:
|
||||
maketrans = bytes.maketrans
|
||||
translate = bytes.translate
|
||||
|
||||
|
||||
def get_table(key):
|
||||
m = hashlib.md5()
|
||||
m.update(key)
|
||||
s = m.digest()
|
||||
a, b = struct.unpack('<QQ', s)
|
||||
table = maketrans(b'', b'')
|
||||
table = [table[i: i + 1] for i in range(len(table))]
|
||||
for i in range(1, 1024):
|
||||
table.sort(key=lambda x: int(a % (ord(x) + i)))
|
||||
return table
|
||||
|
||||
|
||||
def init_table(key):
|
||||
if key not in cached_tables:
|
||||
encrypt_table = b''.join(get_table(key))
|
||||
decrypt_table = maketrans(encrypt_table, maketrans(b'', b''))
|
||||
cached_tables[key] = [encrypt_table, decrypt_table]
|
||||
return cached_tables[key]
|
||||
|
||||
|
||||
class TableCipher(object):
|
||||
def __init__(self, cipher_name, key, iv, op):
|
||||
self._encrypt_table, self._decrypt_table = init_table(key)
|
||||
self._op = op
|
||||
|
||||
def update(self, data):
|
||||
if self._op:
|
||||
return translate(data, self._encrypt_table)
|
||||
else:
|
||||
return translate(data, self._decrypt_table)
|
||||
|
||||
|
||||
ciphers = {
|
||||
'table': (0, 0, TableCipher)
|
||||
}
|
||||
|
||||
|
||||
def test_table_result():
|
||||
from shadowsocks.common import ord
|
||||
target1 = [
|
||||
[60, 53, 84, 138, 217, 94, 88, 23, 39, 242, 219, 35, 12, 157, 165, 181,
|
||||
255, 143, 83, 247, 162, 16, 31, 209, 190, 171, 115, 65, 38, 41, 21,
|
||||
245, 236, 46, 121, 62, 166, 233, 44, 154, 153, 145, 230, 49, 128, 216,
|
||||
173, 29, 241, 119, 64, 229, 194, 103, 131, 110, 26, 197, 218, 59, 204,
|
||||
56, 27, 34, 141, 221, 149, 239, 192, 195, 24, 155, 170, 183, 11, 254,
|
||||
213, 37, 137, 226, 75, 203, 55, 19, 72, 248, 22, 129, 33, 175, 178,
|
||||
10, 198, 71, 77, 36, 113, 167, 48, 2, 117, 140, 142, 66, 199, 232,
|
||||
243, 32, 123, 54, 51, 82, 57, 177, 87, 251, 150, 196, 133, 5, 253,
|
||||
130, 8, 184, 14, 152, 231, 3, 186, 159, 76, 89, 228, 205, 156, 96,
|
||||
163, 146, 18, 91, 132, 85, 80, 109, 172, 176, 105, 13, 50, 235, 127,
|
||||
0, 189, 95, 98, 136, 250, 200, 108, 179, 211, 214, 106, 168, 78, 79,
|
||||
74, 210, 30, 73, 201, 151, 208, 114, 101, 174, 92, 52, 120, 240, 15,
|
||||
169, 220, 182, 81, 224, 43, 185, 40, 99, 180, 17, 212, 158, 42, 90, 9,
|
||||
191, 45, 6, 25, 4, 222, 67, 126, 1, 116, 124, 206, 69, 61, 7, 68, 97,
|
||||
202, 63, 244, 20, 28, 58, 93, 134, 104, 144, 227, 147, 102, 118, 135,
|
||||
148, 47, 238, 86, 112, 122, 70, 107, 215, 100, 139, 223, 225, 164,
|
||||
237, 111, 125, 207, 160, 187, 246, 234, 161, 188, 193, 249, 252],
|
||||
[151, 205, 99, 127, 201, 119, 199, 211, 122, 196, 91, 74, 12, 147, 124,
|
||||
180, 21, 191, 138, 83, 217, 30, 86, 7, 70, 200, 56, 62, 218, 47, 168,
|
||||
22, 107, 88, 63, 11, 95, 77, 28, 8, 188, 29, 194, 186, 38, 198, 33,
|
||||
230, 98, 43, 148, 110, 177, 1, 109, 82, 61, 112, 219, 59, 0, 210, 35,
|
||||
215, 50, 27, 103, 203, 212, 209, 235, 93, 84, 169, 166, 80, 130, 94,
|
||||
164, 165, 142, 184, 111, 18, 2, 141, 232, 114, 6, 131, 195, 139, 176,
|
||||
220, 5, 153, 135, 213, 154, 189, 238, 174, 226, 53, 222, 146, 162,
|
||||
236, 158, 143, 55, 244, 233, 96, 173, 26, 206, 100, 227, 49, 178, 34,
|
||||
234, 108, 207, 245, 204, 150, 44, 87, 121, 54, 140, 118, 221, 228,
|
||||
155, 78, 3, 239, 101, 64, 102, 17, 223, 41, 137, 225, 229, 66, 116,
|
||||
171, 125, 40, 39, 71, 134, 13, 193, 129, 247, 251, 20, 136, 242, 14,
|
||||
36, 97, 163, 181, 72, 25, 144, 46, 175, 89, 145, 113, 90, 159, 190,
|
||||
15, 183, 73, 123, 187, 128, 248, 252, 152, 24, 197, 68, 253, 52, 69,
|
||||
117, 57, 92, 104, 157, 170, 214, 81, 60, 133, 208, 246, 172, 23, 167,
|
||||
160, 192, 76, 161, 237, 45, 4, 58, 10, 182, 65, 202, 240, 185, 241,
|
||||
79, 224, 132, 51, 42, 126, 105, 37, 250, 149, 32, 243, 231, 67, 179,
|
||||
48, 9, 106, 216, 31, 249, 19, 85, 254, 156, 115, 255, 120, 75, 16]]
|
||||
|
||||
target2 = [
|
||||
[124, 30, 170, 247, 27, 127, 224, 59, 13, 22, 196, 76, 72, 154, 32,
|
||||
209, 4, 2, 131, 62, 101, 51, 230, 9, 166, 11, 99, 80, 208, 112, 36,
|
||||
248, 81, 102, 130, 88, 218, 38, 168, 15, 241, 228, 167, 117, 158, 41,
|
||||
10, 180, 194, 50, 204, 243, 246, 251, 29, 198, 219, 210, 195, 21, 54,
|
||||
91, 203, 221, 70, 57, 183, 17, 147, 49, 133, 65, 77, 55, 202, 122,
|
||||
162, 169, 188, 200, 190, 125, 63, 244, 96, 31, 107, 106, 74, 143, 116,
|
||||
148, 78, 46, 1, 137, 150, 110, 181, 56, 95, 139, 58, 3, 231, 66, 165,
|
||||
142, 242, 43, 192, 157, 89, 175, 109, 220, 128, 0, 178, 42, 255, 20,
|
||||
214, 185, 83, 160, 253, 7, 23, 92, 111, 153, 26, 226, 33, 176, 144,
|
||||
18, 216, 212, 28, 151, 71, 206, 222, 182, 8, 174, 205, 201, 152, 240,
|
||||
155, 108, 223, 104, 239, 98, 164, 211, 184, 34, 193, 14, 114, 187, 40,
|
||||
254, 12, 67, 93, 217, 6, 94, 16, 19, 82, 86, 245, 24, 197, 134, 132,
|
||||
138, 229, 121, 5, 235, 238, 85, 47, 103, 113, 179, 69, 250, 45, 135,
|
||||
156, 25, 61, 75, 44, 146, 189, 84, 207, 172, 119, 53, 123, 186, 120,
|
||||
171, 68, 227, 145, 136, 100, 90, 48, 79, 159, 149, 39, 213, 236, 126,
|
||||
52, 60, 225, 199, 105, 73, 233, 252, 118, 215, 35, 115, 64, 37, 97,
|
||||
129, 161, 177, 87, 237, 141, 173, 191, 163, 140, 234, 232, 249],
|
||||
[117, 94, 17, 103, 16, 186, 172, 127, 146, 23, 46, 25, 168, 8, 163, 39,
|
||||
174, 67, 137, 175, 121, 59, 9, 128, 179, 199, 132, 4, 140, 54, 1, 85,
|
||||
14, 134, 161, 238, 30, 241, 37, 224, 166, 45, 119, 109, 202, 196, 93,
|
||||
190, 220, 69, 49, 21, 228, 209, 60, 73, 99, 65, 102, 7, 229, 200, 19,
|
||||
82, 240, 71, 105, 169, 214, 194, 64, 142, 12, 233, 88, 201, 11, 72,
|
||||
92, 221, 27, 32, 176, 124, 205, 189, 177, 246, 35, 112, 219, 61, 129,
|
||||
170, 173, 100, 84, 242, 157, 26, 218, 20, 33, 191, 155, 232, 87, 86,
|
||||
153, 114, 97, 130, 29, 192, 164, 239, 90, 43, 236, 208, 212, 185, 75,
|
||||
210, 0, 81, 227, 5, 116, 243, 34, 18, 182, 70, 181, 197, 217, 95, 183,
|
||||
101, 252, 248, 107, 89, 136, 216, 203, 68, 91, 223, 96, 141, 150, 131,
|
||||
13, 152, 198, 111, 44, 222, 125, 244, 76, 251, 158, 106, 24, 42, 38,
|
||||
77, 2, 213, 207, 249, 147, 113, 135, 245, 118, 193, 47, 98, 145, 66,
|
||||
160, 123, 211, 165, 78, 204, 80, 250, 110, 162, 48, 58, 10, 180, 55,
|
||||
231, 79, 149, 74, 62, 50, 148, 143, 206, 28, 15, 57, 159, 139, 225,
|
||||
122, 237, 138, 171, 36, 56, 115, 63, 144, 154, 6, 230, 133, 215, 41,
|
||||
184, 22, 104, 254, 234, 253, 187, 226, 247, 188, 156, 151, 40, 108,
|
||||
51, 83, 178, 52, 3, 31, 255, 195, 53, 235, 126, 167, 120]]
|
||||
|
||||
encrypt_table = b''.join(get_table(b'foobar!'))
|
||||
decrypt_table = maketrans(encrypt_table, maketrans(b'', b''))
|
||||
|
||||
for i in range(0, 256):
|
||||
assert (target1[0][i] == ord(encrypt_table[i]))
|
||||
assert (target1[1][i] == ord(decrypt_table[i]))
|
||||
|
||||
encrypt_table = b''.join(get_table(b'barfoo!'))
|
||||
decrypt_table = maketrans(encrypt_table, maketrans(b'', b''))
|
||||
|
||||
for i in range(0, 256):
|
||||
assert (target2[0][i] == ord(encrypt_table[i]))
|
||||
assert (target2[1][i] == ord(decrypt_table[i]))
|
||||
|
||||
|
||||
def test_encryption():
|
||||
from shadowsocks.crypto import util
|
||||
|
||||
cipher = TableCipher('table', b'test', b'', 1)
|
||||
decipher = TableCipher('table', b'test', b'', 0)
|
||||
|
||||
util.run_cipher(cipher, decipher)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_table_result()
|
||||
test_encryption()
|
138
crypto/util.py
Normal file
138
crypto/util.py
Normal file
|
@ -0,0 +1,138 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright 2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import os
|
||||
import logging
|
||||
|
||||
|
||||
def find_library_nt(name):
|
||||
# modified from ctypes.util
|
||||
# ctypes.util.find_library just returns first result he found
|
||||
# but we want to try them all
|
||||
# because on Windows, users may have both 32bit and 64bit version installed
|
||||
results = []
|
||||
for directory in os.environ['PATH'].split(os.pathsep):
|
||||
fname = os.path.join(directory, name)
|
||||
if os.path.isfile(fname):
|
||||
results.append(fname)
|
||||
if fname.lower().endswith(".dll"):
|
||||
continue
|
||||
fname = fname + ".dll"
|
||||
if os.path.isfile(fname):
|
||||
results.append(fname)
|
||||
return results
|
||||
|
||||
|
||||
def find_library(possible_lib_names, search_symbol, library_name):
|
||||
import ctypes.util
|
||||
from ctypes import CDLL
|
||||
|
||||
paths = []
|
||||
|
||||
if type(possible_lib_names) not in (list, tuple):
|
||||
possible_lib_names = [possible_lib_names]
|
||||
|
||||
lib_names = []
|
||||
for lib_name in possible_lib_names:
|
||||
lib_names.append(lib_name)
|
||||
lib_names.append('lib' + lib_name)
|
||||
|
||||
for name in lib_names:
|
||||
if os.name == "nt":
|
||||
paths.extend(find_library_nt(name))
|
||||
else:
|
||||
path = ctypes.util.find_library(name)
|
||||
if path:
|
||||
paths.append(path)
|
||||
|
||||
if not paths:
|
||||
# We may get here when find_library fails because, for example,
|
||||
# the user does not have sufficient privileges to access those
|
||||
# tools underlying find_library on linux.
|
||||
import glob
|
||||
|
||||
for name in lib_names:
|
||||
patterns = [
|
||||
'/usr/local/lib*/lib%s.*' % name,
|
||||
'/usr/lib*/lib%s.*' % name,
|
||||
'lib%s.*' % name,
|
||||
'%s.dll' % name]
|
||||
|
||||
for pat in patterns:
|
||||
files = glob.glob(pat)
|
||||
if files:
|
||||
paths.extend(files)
|
||||
for path in paths:
|
||||
try:
|
||||
lib = CDLL(path)
|
||||
if hasattr(lib, search_symbol):
|
||||
logging.info('loading %s from %s', library_name, path)
|
||||
return lib
|
||||
else:
|
||||
logging.warn('can\'t find symbol %s in %s', search_symbol,
|
||||
path)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def run_cipher(cipher, decipher):
|
||||
from os import urandom
|
||||
import random
|
||||
import time
|
||||
|
||||
BLOCK_SIZE = 16384
|
||||
rounds = 1 * 1024
|
||||
plain = urandom(BLOCK_SIZE * rounds)
|
||||
|
||||
results = []
|
||||
pos = 0
|
||||
print('test start')
|
||||
start = time.time()
|
||||
while pos < len(plain):
|
||||
l = random.randint(100, 32768)
|
||||
c = cipher.update(plain[pos:pos + l])
|
||||
results.append(c)
|
||||
pos += l
|
||||
pos = 0
|
||||
c = b''.join(results)
|
||||
results = []
|
||||
while pos < len(plain):
|
||||
l = random.randint(100, 32768)
|
||||
results.append(decipher.update(c[pos:pos + l]))
|
||||
pos += l
|
||||
end = time.time()
|
||||
print('speed: %d bytes/s' % (BLOCK_SIZE * rounds / (end - start)))
|
||||
assert b''.join(results) == plain
|
||||
|
||||
|
||||
def test_find_library():
|
||||
assert find_library('c', 'strcpy', 'libc') is not None
|
||||
assert find_library(['c'], 'strcpy', 'libc') is not None
|
||||
assert find_library(('c',), 'strcpy', 'libc') is not None
|
||||
assert find_library(('crypto', 'eay32'), 'EVP_CipherUpdate',
|
||||
'libcrypto') is not None
|
||||
assert find_library('notexist', 'strcpy', 'libnotexist') is None
|
||||
assert find_library('c', 'symbol_not_exist', 'c') is None
|
||||
assert find_library(('notexist', 'c', 'crypto', 'eay32'),
|
||||
'EVP_CipherUpdate', 'libc') is not None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_find_library()
|
187
encrypt.py
Normal file
187
encrypt.py
Normal file
|
@ -0,0 +1,187 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright 2012-2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from shadowsocks import common
|
||||
from shadowsocks.crypto import rc4_md5, openssl, sodium, table
|
||||
|
||||
|
||||
method_supported = {}
|
||||
method_supported.update(rc4_md5.ciphers)
|
||||
method_supported.update(openssl.ciphers)
|
||||
method_supported.update(sodium.ciphers)
|
||||
method_supported.update(table.ciphers)
|
||||
|
||||
|
||||
def random_string(length):
|
||||
return os.urandom(length)
|
||||
|
||||
|
||||
cached_keys = {}
|
||||
|
||||
|
||||
def try_cipher(key, method=None):
|
||||
Encryptor(key, method)
|
||||
|
||||
|
||||
def EVP_BytesToKey(password, key_len, iv_len):
|
||||
# equivalent to OpenSSL's EVP_BytesToKey() with count 1
|
||||
# so that we make the same key and iv as nodejs version
|
||||
cached_key = '%s-%d-%d' % (password, key_len, iv_len)
|
||||
r = cached_keys.get(cached_key, None)
|
||||
if r:
|
||||
return r
|
||||
m = []
|
||||
i = 0
|
||||
while len(b''.join(m)) < (key_len + iv_len):
|
||||
md5 = hashlib.md5()
|
||||
data = password
|
||||
if i > 0:
|
||||
data = m[i - 1] + password
|
||||
md5.update(data)
|
||||
m.append(md5.digest())
|
||||
i += 1
|
||||
ms = b''.join(m)
|
||||
key = ms[:key_len]
|
||||
iv = ms[key_len:key_len + iv_len]
|
||||
cached_keys[cached_key] = (key, iv)
|
||||
return key, iv
|
||||
|
||||
|
||||
class Encryptor(object):
|
||||
def __init__(self, key, method):
|
||||
self.key = key
|
||||
self.method = method
|
||||
self.iv = None
|
||||
self.iv_sent = False
|
||||
self.cipher_iv = b''
|
||||
self.decipher = None
|
||||
method = method.lower()
|
||||
self._method_info = self.get_method_info(method)
|
||||
if self._method_info:
|
||||
self.cipher = self.get_cipher(key, method, 1,
|
||||
random_string(self._method_info[1]))
|
||||
else:
|
||||
logging.error('method %s not supported' % method)
|
||||
sys.exit(1)
|
||||
|
||||
def get_method_info(self, method):
|
||||
method = method.lower()
|
||||
m = method_supported.get(method)
|
||||
return m
|
||||
|
||||
def iv_len(self):
|
||||
return len(self.cipher_iv)
|
||||
|
||||
def get_cipher(self, password, method, op, iv):
|
||||
password = common.to_bytes(password)
|
||||
m = self._method_info
|
||||
if m[0] > 0:
|
||||
key, iv_ = EVP_BytesToKey(password, m[0], m[1])
|
||||
else:
|
||||
# key_length == 0 indicates we should use the key directly
|
||||
key, iv = password, b''
|
||||
|
||||
iv = iv[:m[1]]
|
||||
if op == 1:
|
||||
# this iv is for cipher not decipher
|
||||
self.cipher_iv = iv[:m[1]]
|
||||
return m[2](method, key, iv, op)
|
||||
|
||||
def encrypt(self, buf):
|
||||
if len(buf) == 0:
|
||||
return buf
|
||||
if self.iv_sent:
|
||||
return self.cipher.update(buf)
|
||||
else:
|
||||
self.iv_sent = True
|
||||
return self.cipher_iv + self.cipher.update(buf)
|
||||
|
||||
def decrypt(self, buf):
|
||||
if len(buf) == 0:
|
||||
return buf
|
||||
if self.decipher is None:
|
||||
decipher_iv_len = self._method_info[1]
|
||||
decipher_iv = buf[:decipher_iv_len]
|
||||
self.decipher = self.get_cipher(self.key, self.method, 0,
|
||||
iv=decipher_iv)
|
||||
buf = buf[decipher_iv_len:]
|
||||
if len(buf) == 0:
|
||||
return buf
|
||||
return self.decipher.update(buf)
|
||||
|
||||
|
||||
def encrypt_all(password, method, op, data):
|
||||
result = []
|
||||
method = method.lower()
|
||||
(key_len, iv_len, m) = method_supported[method]
|
||||
if key_len > 0:
|
||||
key, _ = EVP_BytesToKey(password, key_len, iv_len)
|
||||
else:
|
||||
key = password
|
||||
if op:
|
||||
iv = random_string(iv_len)
|
||||
result.append(iv)
|
||||
else:
|
||||
iv = data[:iv_len]
|
||||
data = data[iv_len:]
|
||||
cipher = m(method, key, iv, op)
|
||||
result.append(cipher.update(data))
|
||||
return b''.join(result)
|
||||
|
||||
|
||||
CIPHERS_TO_TEST = [
|
||||
'aes-128-cfb',
|
||||
'aes-256-cfb',
|
||||
'rc4-md5',
|
||||
'salsa20',
|
||||
'chacha20',
|
||||
'table',
|
||||
]
|
||||
|
||||
|
||||
def test_encryptor():
|
||||
from os import urandom
|
||||
plain = urandom(10240)
|
||||
for method in CIPHERS_TO_TEST:
|
||||
logging.warn(method)
|
||||
encryptor = Encryptor(b'key', method)
|
||||
decryptor = Encryptor(b'key', method)
|
||||
cipher = encryptor.encrypt(plain)
|
||||
plain2 = decryptor.decrypt(cipher)
|
||||
assert plain == plain2
|
||||
|
||||
|
||||
def test_encrypt_all():
|
||||
from os import urandom
|
||||
plain = urandom(10240)
|
||||
for method in CIPHERS_TO_TEST:
|
||||
logging.warn(method)
|
||||
cipher = encrypt_all(b'key', method, 1, plain)
|
||||
plain2 = encrypt_all(b'key', method, 0, cipher)
|
||||
assert plain == plain2
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_encrypt_all()
|
||||
test_encryptor()
|
111
eventloop.py
Normal file
111
eventloop.py
Normal file
|
@ -0,0 +1,111 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
import errno
|
||||
from shadowsocks import selectors
|
||||
from shadowsocks.selectors import (EVENT_READ, EVENT_WRITE, EVENT_ERROR,
|
||||
errno_from_exception)
|
||||
|
||||
|
||||
POLL_IN = EVENT_READ
|
||||
POLL_OUT = EVENT_WRITE
|
||||
POLL_ERR = EVENT_ERROR
|
||||
|
||||
TIMEOUT_PRECISION = 10
|
||||
|
||||
|
||||
class EventLoop:
|
||||
|
||||
def __init__(self):
|
||||
self._selector = selectors.DefaultSelector()
|
||||
self._stopping = False
|
||||
self._last_time = time.time()
|
||||
self._periodic_callbacks = []
|
||||
|
||||
def poll(self, timeout=None):
|
||||
return self._selector.select(timeout)
|
||||
|
||||
def add(self, sock, events, data):
|
||||
events |= selectors.EVENT_ERROR
|
||||
return self._selector.register(sock, events, data)
|
||||
|
||||
def remove(self, sock):
|
||||
try:
|
||||
return self._selector.unregister(sock)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def modify(self, sock, events, data):
|
||||
events |= selectors.EVENT_ERROR
|
||||
try:
|
||||
key = self._selector.modify(sock, events, data)
|
||||
except KeyError:
|
||||
key = self.add(sock, events, data)
|
||||
return key
|
||||
|
||||
def add_periodic(self, callback):
|
||||
self._periodic_callbacks.append(callback)
|
||||
|
||||
def remove_periodic(self, callback):
|
||||
self._periodic_callbacks.remove(callback)
|
||||
|
||||
def fd_count(self):
|
||||
return len(self._selector.get_map())
|
||||
|
||||
def run(self):
|
||||
logging.debug('Starting event loop')
|
||||
|
||||
while not self._stopping:
|
||||
asap = False
|
||||
try:
|
||||
events = self.poll(timeout=TIMEOUT_PRECISION)
|
||||
except (OSError, IOError) as e:
|
||||
if errno_from_exception(e) in (errno.EPIPE, errno.EINTR):
|
||||
# EPIPE: Happens when the client closes the connection
|
||||
# EINTR: Happens when received a signal
|
||||
# handles them as soon as possible
|
||||
asap = True
|
||||
logging.debug('poll: %s', e)
|
||||
else:
|
||||
logging.error('poll: %s', e)
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
for key, event in events:
|
||||
if type(key.data) == tuple:
|
||||
handler = key.data[0]
|
||||
args = key.data[1:]
|
||||
else:
|
||||
handler = key.data
|
||||
args = ()
|
||||
|
||||
sock = key.fileobj
|
||||
if hasattr(handler, 'handle_event'):
|
||||
handler = handler.handle_event
|
||||
|
||||
try:
|
||||
handler(sock, event, *args)
|
||||
except Exception as e:
|
||||
logging.debug(e)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
now = time.time()
|
||||
if asap or now - self._last_time >= TIMEOUT_PRECISION:
|
||||
for callback in self._periodic_callbacks:
|
||||
callback()
|
||||
self._last_time = now
|
||||
|
||||
logging.debug('Got {} fds registered'.format(self.fd_count()))
|
||||
|
||||
logging.debug('Stopping event loop')
|
||||
self._selector.close()
|
||||
|
||||
def stop(self):
|
||||
self._stopping = True
|
150
lru_cache.py
Normal file
150
lru_cache.py
Normal file
|
@ -0,0 +1,150 @@
|
|||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import time
|
||||
|
||||
|
||||
# this LRUCache is optimized for concurrency, not QPS
|
||||
# n: concurrency, keys stored in the cache
|
||||
# m: visits not timed out, proportional to QPS * timeout
|
||||
# get & set is O(1), not O(n). thus we can support very large n
|
||||
# TODO: if timeout or QPS is too large, then this cache is not very efficient,
|
||||
# as sweep() causes long pause
|
||||
|
||||
|
||||
class LRUCache(collections.MutableMapping):
|
||||
"""This class is not thread safe"""
|
||||
|
||||
def __init__(self, timeout=60, close_callback=None, *args, **kwargs):
|
||||
self.timeout = timeout
|
||||
self.close_callback = close_callback
|
||||
self._store = {}
|
||||
self._time_to_keys = collections.defaultdict(list)
|
||||
self._keys_to_last_time = {}
|
||||
self._last_visits = collections.deque()
|
||||
self._closed_values = set()
|
||||
self.update(dict(*args, **kwargs)) # use the free update to set keys
|
||||
|
||||
def __getitem__(self, key):
|
||||
# O(1)
|
||||
t = time.time()
|
||||
self._keys_to_last_time[key] = t
|
||||
self._time_to_keys[t].append(key)
|
||||
self._last_visits.append(t)
|
||||
return self._store[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# O(1)
|
||||
t = time.time()
|
||||
self._keys_to_last_time[key] = t
|
||||
self._store[key] = value
|
||||
self._time_to_keys[t].append(key)
|
||||
self._last_visits.append(t)
|
||||
|
||||
def __delitem__(self, key):
|
||||
# O(1)
|
||||
del self._store[key]
|
||||
del self._keys_to_last_time[key]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._store)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._store)
|
||||
|
||||
def sweep(self):
|
||||
# O(m)
|
||||
now = time.time()
|
||||
c = 0
|
||||
while len(self._last_visits) > 0:
|
||||
least = self._last_visits[0]
|
||||
if now - least <= self.timeout:
|
||||
break
|
||||
if self.close_callback is not None:
|
||||
for key in self._time_to_keys[least]:
|
||||
if key in self._store:
|
||||
if now - self._keys_to_last_time[key] > self.timeout:
|
||||
value = self._store[key]
|
||||
if value not in self._closed_values:
|
||||
self.close_callback(value)
|
||||
self._closed_values.add(value)
|
||||
for key in self._time_to_keys[least]:
|
||||
self._last_visits.popleft()
|
||||
if key in self._store:
|
||||
if now - self._keys_to_last_time[key] > self.timeout:
|
||||
del self._store[key]
|
||||
del self._keys_to_last_time[key]
|
||||
c += 1
|
||||
del self._time_to_keys[least]
|
||||
if c:
|
||||
self._closed_values.clear()
|
||||
logging.debug('%d keys swept' % c)
|
||||
|
||||
|
||||
def test():
|
||||
c = LRUCache(timeout=0.3)
|
||||
|
||||
c['a'] = 1
|
||||
assert c['a'] == 1
|
||||
|
||||
time.sleep(0.5)
|
||||
c.sweep()
|
||||
assert 'a' not in c
|
||||
|
||||
c['a'] = 2
|
||||
c['b'] = 3
|
||||
time.sleep(0.2)
|
||||
c.sweep()
|
||||
assert c['a'] == 2
|
||||
assert c['b'] == 3
|
||||
|
||||
time.sleep(0.2)
|
||||
c.sweep()
|
||||
c['b']
|
||||
time.sleep(0.2)
|
||||
c.sweep()
|
||||
assert 'a' not in c
|
||||
assert c['b'] == 3
|
||||
|
||||
time.sleep(0.5)
|
||||
c.sweep()
|
||||
assert 'a' not in c
|
||||
assert 'b' not in c
|
||||
|
||||
global close_cb_called
|
||||
close_cb_called = False
|
||||
|
||||
def close_cb(t):
|
||||
global close_cb_called
|
||||
assert not close_cb_called
|
||||
close_cb_called = True
|
||||
|
||||
c = LRUCache(timeout=0.1, close_callback=close_cb)
|
||||
c['s'] = 1
|
||||
c['s']
|
||||
time.sleep(0.1)
|
||||
c['s']
|
||||
time.sleep(0.3)
|
||||
c.sweep()
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
652
selectors.py
Normal file
652
selectors.py
Normal file
|
@ -0,0 +1,652 @@
|
|||
"""Selectors module.
|
||||
|
||||
This module allows high-level and efficient I/O multiplexing, built upon the
|
||||
`select` module primitives.
|
||||
"""
|
||||
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from collections import namedtuple, Mapping
|
||||
import errno
|
||||
import math
|
||||
import os
|
||||
import select
|
||||
import socket
|
||||
import sys
|
||||
|
||||
|
||||
# generic events, that must be mapped to implementation-specific ones
|
||||
EVENT_READ = 0x001
|
||||
EVENT_WRITE = 0x004
|
||||
EVENT_ERROR = 0x018
|
||||
|
||||
|
||||
def errno_from_exception(e):
|
||||
"""Provides the errno from an Exception object.
|
||||
|
||||
There are cases that the errno attribute was not set so we pull
|
||||
the errno out of the args but if someone instatiates an Exception
|
||||
without any args you will get a tuple error. So this function
|
||||
abstracts all that behavior to give you a safe way to get the
|
||||
errno.
|
||||
"""
|
||||
|
||||
if hasattr(e, 'errno'):
|
||||
return e.errno
|
||||
elif e.args:
|
||||
return e.args[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_sock_error(sock):
|
||||
if not sock:
|
||||
return
|
||||
error_number = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
|
||||
return socket.error(error_number, os.strerror(error_number))
|
||||
|
||||
|
||||
def _fileobj_to_fd(fileobj):
|
||||
"""Return a file descriptor from a file object.
|
||||
|
||||
Parameters:
|
||||
fileobj -- file object or file descriptor
|
||||
|
||||
Returns:
|
||||
corresponding file descriptor
|
||||
|
||||
Raises:
|
||||
ValueError if the object is invalid
|
||||
"""
|
||||
if isinstance(fileobj, int):
|
||||
fd = fileobj
|
||||
else:
|
||||
try:
|
||||
fd = int(fileobj.fileno())
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
raise ValueError("Invalid file object: "
|
||||
"{!r}".format(fileobj))
|
||||
if fd < 0:
|
||||
raise ValueError("Invalid file descriptor: {}".format(fd))
|
||||
return fd
|
||||
|
||||
|
||||
SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data'])
|
||||
"""Object used to associate a file object to its backing file descriptor,
|
||||
selected event mask and attached data."""
|
||||
|
||||
|
||||
class _SelectorMapping(Mapping):
|
||||
"""Mapping of file objects to selector keys."""
|
||||
|
||||
def __init__(self, selector):
|
||||
self._selector = selector
|
||||
|
||||
def __len__(self):
|
||||
return len(self._selector._fd_to_key)
|
||||
|
||||
def __getitem__(self, fileobj):
|
||||
try:
|
||||
fd = self._selector._fileobj_lookup(fileobj)
|
||||
return self._selector._fd_to_key[fd]
|
||||
except KeyError:
|
||||
raise KeyError("{!r} is not registered".format(fileobj))
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._selector._fd_to_key)
|
||||
|
||||
|
||||
class BaseSelector(object):
|
||||
__metaclass__ = ABCMeta
|
||||
"""Selector abstract base class.
|
||||
|
||||
A selector supports registering file objects to be monitored for specific
|
||||
I/O events.
|
||||
|
||||
A file object is a file descriptor or any object with a `fileno()` method.
|
||||
An arbitrary object can be attached to the file object, which can be used
|
||||
for example to store context information, a callback, etc.
|
||||
|
||||
A selector can use various implementations (select(), poll(), epoll()...)
|
||||
depending on the platform. The default `Selector` class uses the most
|
||||
efficient implementation on the current platform.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def register(self, fileobj, events, data=None):
|
||||
"""Register a file object.
|
||||
|
||||
Parameters:
|
||||
fileobj -- file object or file descriptor
|
||||
events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
|
||||
data -- attached data
|
||||
|
||||
Returns:
|
||||
SelectorKey instance
|
||||
|
||||
Raises:
|
||||
ValueError if events is invalid
|
||||
KeyError if fileobj is already registered
|
||||
OSError if fileobj is closed or otherwise is unacceptable to
|
||||
the underlying system call (if a system call is made)
|
||||
|
||||
Note:
|
||||
OSError may or may not be raised
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def unregister(self, fileobj):
|
||||
"""Unregister a file object.
|
||||
|
||||
Parameters:
|
||||
fileobj -- file object or file descriptor
|
||||
|
||||
Returns:
|
||||
SelectorKey instance
|
||||
|
||||
Raises:
|
||||
KeyError if fileobj is not registered
|
||||
|
||||
Note:
|
||||
If fileobj is registered but has since been closed this does
|
||||
*not* raise OSError (even if the wrapped syscall does)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def modify(self, fileobj, events, data=None):
|
||||
"""Change a registered file object monitored events or attached data.
|
||||
|
||||
Parameters:
|
||||
fileobj -- file object or file descriptor
|
||||
events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE)
|
||||
data -- attached data
|
||||
|
||||
Returns:
|
||||
SelectorKey instance
|
||||
|
||||
Raises:
|
||||
Anything that unregister() or register() raises
|
||||
"""
|
||||
self.unregister(fileobj)
|
||||
return self.register(fileobj, events, data)
|
||||
|
||||
@abstractmethod
|
||||
def select(self, timeout=None):
|
||||
"""Perform the actual selection, until some monitored file objects are
|
||||
ready or a timeout expires.
|
||||
|
||||
Parameters:
|
||||
timeout -- if timeout > 0, this specifies the maximum wait time, in
|
||||
seconds
|
||||
if timeout <= 0, the select() call won't block, and will
|
||||
report the currently ready file objects
|
||||
if timeout is None, select() will block until a monitored
|
||||
file object becomes ready
|
||||
|
||||
Returns:
|
||||
list of (key, events) for ready file objects
|
||||
`events` is a bitwise mask of EVENT_READ|EVENT_WRITE
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self):
|
||||
"""Close the selector.
|
||||
|
||||
This must be called to make sure that any underlying resource is freed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_key(self, fileobj):
|
||||
"""Return the key associated to a registered file object.
|
||||
|
||||
Returns:
|
||||
SelectorKey for this file object
|
||||
"""
|
||||
mapping = self.get_map()
|
||||
try:
|
||||
if mapping is None:
|
||||
raise KeyError
|
||||
return mapping[fileobj]
|
||||
except KeyError:
|
||||
raise KeyError("{!r} is not registered".format(fileobj))
|
||||
|
||||
@abstractmethod
|
||||
def get_map(self):
|
||||
"""Return a mapping of file objects to selector keys."""
|
||||
raise NotImplementedError
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.close()
|
||||
|
||||
|
||||
class _BaseSelectorImpl(BaseSelector):
|
||||
"""Base selector implementation."""
|
||||
|
||||
def __init__(self):
|
||||
# this maps file descriptors to keys
|
||||
self._fd_to_key = {}
|
||||
# read-only mapping returned by get_map()
|
||||
self._map = _SelectorMapping(self)
|
||||
|
||||
def _fileobj_lookup(self, fileobj):
|
||||
"""Return a file descriptor from a file object.
|
||||
|
||||
This wraps _fileobj_to_fd() to do an exhaustive search in case
|
||||
the object is invalid but we still have it in our map. This
|
||||
is used by unregister() so we can unregister an object that
|
||||
was previously registered even if it is closed. It is also
|
||||
used by _SelectorMapping.
|
||||
"""
|
||||
try:
|
||||
return _fileobj_to_fd(fileobj)
|
||||
except ValueError:
|
||||
# Do an exhaustive search.
|
||||
for key in self._fd_to_key.values():
|
||||
if key.fileobj is fileobj:
|
||||
return key.fd
|
||||
# Raise ValueError after all.
|
||||
raise
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
if (not events) or (events & ~(EVENT_READ | EVENT_WRITE |
|
||||
EVENT_ERROR)):
|
||||
raise ValueError("Invalid events: {!r}".format(events))
|
||||
|
||||
key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data)
|
||||
|
||||
if key.fd in self._fd_to_key:
|
||||
raise KeyError("{!r} (FD {}) is already registered"
|
||||
.format(fileobj, key.fd))
|
||||
|
||||
self._fd_to_key[key.fd] = key
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
try:
|
||||
key = self._fd_to_key.pop(self._fileobj_lookup(fileobj))
|
||||
except KeyError:
|
||||
raise KeyError("{!r} is not registered".format(fileobj))
|
||||
return key
|
||||
|
||||
def modify(self, fileobj, events, data=None):
|
||||
# TODO: Subclasses can probably optimize this even further.
|
||||
try:
|
||||
key = self._fd_to_key[self._fileobj_lookup(fileobj)]
|
||||
except KeyError:
|
||||
raise KeyError("{!r} is not registered".format(fileobj))
|
||||
if events != key.events:
|
||||
self.unregister(fileobj)
|
||||
key = self.register(fileobj, events, data)
|
||||
elif data != key.data:
|
||||
# Use a shortcut to update the data.
|
||||
key = key._replace(data=data)
|
||||
self._fd_to_key[key.fd] = key
|
||||
return key
|
||||
|
||||
def close(self):
|
||||
self._fd_to_key.clear()
|
||||
self._map = None
|
||||
|
||||
def get_map(self):
|
||||
return self._map
|
||||
|
||||
def _key_from_fd(self, fd):
|
||||
"""Return the key associated to a given file descriptor.
|
||||
|
||||
Parameters:
|
||||
fd -- file descriptor
|
||||
|
||||
Returns:
|
||||
corresponding key, or None if not found
|
||||
"""
|
||||
try:
|
||||
return self._fd_to_key[fd]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
|
||||
class SelectSelector(_BaseSelectorImpl):
|
||||
"""Select-based selector."""
|
||||
|
||||
def __init__(self):
|
||||
super(self.__class__, self).__init__()
|
||||
self._readers = set()
|
||||
self._writers = set()
|
||||
self._errors = set()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super(self.__class__, self).register(fileobj, events, data)
|
||||
if events & EVENT_READ:
|
||||
self._readers.add(key.fd)
|
||||
if events & EVENT_WRITE:
|
||||
self._writers.add(key.fd)
|
||||
if events & EVENT_ERROR:
|
||||
self._errors.add(key.fd)
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super(self.__class__, self).unregister(fileobj)
|
||||
self._readers.discard(key.fd)
|
||||
self._writers.discard(key.fd)
|
||||
self._errors.discard(key.fd)
|
||||
return key
|
||||
|
||||
if sys.platform == 'win32':
|
||||
def _select(self, r, w, x, timeout=None):
|
||||
r, w, x = select.select(r, w, x, timeout)
|
||||
return r, w, x
|
||||
else:
|
||||
_select = select.select
|
||||
|
||||
def select(self, timeout=None):
|
||||
timeout = None if timeout is None else max(timeout, 0)
|
||||
ready = []
|
||||
try:
|
||||
r, w, x = self._select(self._readers, self._writers, self._errors,
|
||||
timeout)
|
||||
except OSError as e:
|
||||
if errno_from_exception(e) == errno.EAGAIN:
|
||||
return ready
|
||||
r = set(r)
|
||||
w = set(w)
|
||||
x = set(x)
|
||||
for fd in r | w | x:
|
||||
events = 0
|
||||
if fd in r:
|
||||
events |= EVENT_READ
|
||||
if fd in w:
|
||||
events |= EVENT_WRITE
|
||||
if fd in x:
|
||||
events |= EVENT_ERROR
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
ready.append((key, events & key.events))
|
||||
return ready
|
||||
|
||||
|
||||
if hasattr(select, 'poll'):
|
||||
|
||||
class PollSelector(_BaseSelectorImpl):
|
||||
"""Poll-based selector."""
|
||||
|
||||
def __init__(self):
|
||||
super(self.__class__, self).__init__()
|
||||
self._poll = select.poll()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super(self.__class__, self).register(fileobj, events, data)
|
||||
poll_events = 0
|
||||
if events & EVENT_READ:
|
||||
poll_events |= select.POLLIN
|
||||
if events & EVENT_WRITE:
|
||||
poll_events |= select.POLLOUT
|
||||
if events & EVENT_ERROR:
|
||||
poll_events |= select.POLLERR | select.POLLHUP
|
||||
self._poll.register(key.fd, poll_events)
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super(self.__class__, self).unregister(fileobj)
|
||||
self._poll.unregister(key.fd)
|
||||
return key
|
||||
|
||||
def select(self, timeout=None):
|
||||
if timeout is None:
|
||||
timeout = None
|
||||
elif timeout <= 0:
|
||||
timeout = 0
|
||||
else:
|
||||
# poll() has a resolution of 1 millisecond, round away from
|
||||
# zero to wait *at least* timeout seconds.
|
||||
timeout = math.ceil(timeout * 1e3)
|
||||
ready = []
|
||||
try:
|
||||
fd_event_list = self._poll.poll(timeout)
|
||||
except OSError as e:
|
||||
if errno_from_exception(e) == errno.EAGAIN:
|
||||
return ready
|
||||
for fd, event in fd_event_list:
|
||||
events = 0
|
||||
if event & select.POLLIN:
|
||||
events |= EVENT_READ
|
||||
if event & select.POLLOUT:
|
||||
events |= EVENT_WRITE
|
||||
if event & (select.POLLERR | select.POLLHUP):
|
||||
events |= EVENT_ERROR
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
ready.append((key, events & key.events))
|
||||
return ready
|
||||
|
||||
|
||||
if hasattr(select, 'epoll'):
|
||||
|
||||
class EpollSelector(_BaseSelectorImpl):
|
||||
"""Epoll-based selector."""
|
||||
|
||||
def __init__(self):
|
||||
super(self.__class__, self).__init__()
|
||||
self._epoll = select.epoll()
|
||||
|
||||
def fileno(self):
|
||||
return self._epoll.fileno()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super(self.__class__, self).register(fileobj, events, data)
|
||||
epoll_events = 0
|
||||
if events & EVENT_READ:
|
||||
epoll_events |= select.EPOLLIN
|
||||
if events & EVENT_WRITE:
|
||||
epoll_events |= select.EPOLLOUT
|
||||
if events & EVENT_ERROR:
|
||||
epoll_events |= select.EPOLLERR | select.EPOLLHUP
|
||||
self._epoll.register(key.fd, epoll_events)
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super(self.__class__, self).unregister(fileobj)
|
||||
try:
|
||||
self._epoll.unregister(key.fd)
|
||||
except OSError:
|
||||
# This can happen if the FD was closed since it
|
||||
# was registered.
|
||||
pass
|
||||
return key
|
||||
|
||||
def select(self, timeout=None):
|
||||
if timeout is None:
|
||||
timeout = -1
|
||||
elif timeout <= 0:
|
||||
timeout = 0
|
||||
else:
|
||||
# epoll_wait() has a resolution of 1 millisecond, round away
|
||||
# from zero to wait *at least* timeout seconds.
|
||||
timeout = math.ceil(timeout * 1e3) * 1e-3
|
||||
|
||||
# epoll_wait() expects `maxevents` to be greater than zero;
|
||||
# we want to make sure that `select()` can be called when no
|
||||
# FD is registered.
|
||||
max_ev = max(len(self._fd_to_key), 1)
|
||||
|
||||
ready = []
|
||||
try:
|
||||
fd_event_list = self._epoll.poll(timeout, max_ev)
|
||||
except OSError as e:
|
||||
if errno_from_exception(e) == errno.EAGAIN:
|
||||
return ready
|
||||
for fd, event in fd_event_list:
|
||||
events = 0
|
||||
if event & select.EPOLLIN:
|
||||
events |= EVENT_READ
|
||||
if event & select.EPOLLOUT:
|
||||
events |= EVENT_WRITE
|
||||
if event & (select.EPOLLERR | select.EPOLLHUP):
|
||||
events |= EVENT_ERROR
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
ready.append((key, events & key.events))
|
||||
return ready
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
self._epoll.close()
|
||||
finally:
|
||||
super(self.__class__, self).close()
|
||||
|
||||
|
||||
if hasattr(select, 'devpoll'):
|
||||
|
||||
class DevpollSelector(_BaseSelectorImpl):
|
||||
"""Solaris /dev/poll selector."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._devpoll = select.devpoll()
|
||||
|
||||
def fileno(self):
|
||||
return self._devpoll.fileno()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super().register(fileobj, events, data)
|
||||
poll_events = 0
|
||||
if events & EVENT_READ:
|
||||
poll_events |= select.POLLIN
|
||||
if events & EVENT_WRITE:
|
||||
poll_events |= select.POLLOUT
|
||||
if events & EVENT_ERROR:
|
||||
poll_events |= select.POLLERR | select.POLLHUP
|
||||
self._devpoll.register(key.fd, poll_events)
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super().unregister(fileobj)
|
||||
self._devpoll.unregister(key.fd)
|
||||
return key
|
||||
|
||||
def select(self, timeout=None):
|
||||
if timeout is None:
|
||||
timeout = None
|
||||
elif timeout <= 0:
|
||||
timeout = 0
|
||||
else:
|
||||
# devpoll() has a resolution of 1 millisecond, round away from
|
||||
# zero to wait *at least* timeout seconds.
|
||||
timeout = math.ceil(timeout * 1e3)
|
||||
ready = []
|
||||
try:
|
||||
fd_event_list = self._devpoll.poll(timeout)
|
||||
except InterruptedError:
|
||||
return ready
|
||||
for fd, event in fd_event_list:
|
||||
events = 0
|
||||
if event & select.POLLIN:
|
||||
events |= EVENT_READ
|
||||
if event & select.POLLOUT:
|
||||
events |= EVENT_WRITE
|
||||
if event & (select.POLLERR | select.POLLHUP):
|
||||
events |= EVENT_ERROR
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
ready.append((key, events & key.events))
|
||||
return ready
|
||||
|
||||
def close(self):
|
||||
self._devpoll.close()
|
||||
super().close()
|
||||
|
||||
|
||||
if hasattr(select, 'kqueue'):
|
||||
|
||||
class KqueueSelector(_BaseSelectorImpl):
|
||||
"""Kqueue-based selector."""
|
||||
|
||||
def __init__(self):
|
||||
super(self.__class__, self).__init__()
|
||||
self._kqueue = select.kqueue()
|
||||
|
||||
def fileno(self):
|
||||
return self._kqueue.fileno()
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = super(self.__class__, self).register(fileobj, events, data)
|
||||
if events & EVENT_READ:
|
||||
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
|
||||
select.KQ_EV_ADD)
|
||||
self._kqueue.control([kev], 0, 0)
|
||||
if events & EVENT_WRITE:
|
||||
kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
|
||||
select.KQ_EV_ADD)
|
||||
self._kqueue.control([kev], 0, 0)
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
key = super(self.__class__, self).unregister(fileobj)
|
||||
if key.events & EVENT_READ:
|
||||
kev = select.kevent(key.fd, select.KQ_FILTER_READ,
|
||||
select.KQ_EV_DELETE)
|
||||
try:
|
||||
self._kqueue.control([kev], 0, 0)
|
||||
except OSError:
|
||||
# This can happen if the FD was closed since it
|
||||
# was registered.
|
||||
pass
|
||||
if key.events & EVENT_WRITE:
|
||||
kev = select.kevent(key.fd, select.KQ_FILTER_WRITE,
|
||||
select.KQ_EV_DELETE)
|
||||
try:
|
||||
self._kqueue.control([kev], 0, 0)
|
||||
except OSError:
|
||||
# See comment above.
|
||||
pass
|
||||
return key
|
||||
|
||||
def select(self, timeout=None):
|
||||
timeout = None if timeout is None else max(timeout, 0)
|
||||
max_ev = len(self._fd_to_key)
|
||||
ready = []
|
||||
try:
|
||||
kev_list = self._kqueue.control(None, max_ev, timeout)
|
||||
except OSError as e:
|
||||
if errno_from_exception(e) == errno.EAGAIN:
|
||||
return ready
|
||||
for kev in kev_list:
|
||||
fd = kev.ident
|
||||
flag = kev.filter
|
||||
events = 0
|
||||
if flag == select.KQ_FILTER_READ:
|
||||
events |= EVENT_READ
|
||||
if flag == select.KQ_FILTER_WRITE:
|
||||
events |= EVENT_WRITE
|
||||
|
||||
key = self._key_from_fd(fd)
|
||||
if key:
|
||||
ready.append((key, events & key.events))
|
||||
return ready
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
self._kqueue.close()
|
||||
finally:
|
||||
super(self.__class__, self).close()
|
||||
|
||||
|
||||
# Choose the best implementation: roughly, epoll|kqueue > poll > select.
|
||||
# select() also can't accept a FD > FD_SETSIZE (usually around 1024)
|
||||
if 'KqueueSelector' in globals():
|
||||
DefaultSelector = KqueueSelector
|
||||
elif 'EpollSelector' in globals():
|
||||
DefaultSelector = EpollSelector
|
||||
elif 'DevpollSelector' in globals():
|
||||
DefaultSelector = DevpollSelector
|
||||
elif 'PollSelector' in globals():
|
||||
DefaultSelector = PollSelector
|
||||
else:
|
||||
DefaultSelector = SelectSelector
|
35
server.py
Normal file
35
server.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
|
||||
path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, path)
|
||||
from shadowsocks.eventloop import EventLoop
|
||||
from shadowsocks.tcprelay import TcpRelay, TcpRelayServerHandler
|
||||
from shadowsocks.asyncdns import DNSResolver
|
||||
|
||||
|
||||
FORMATTER = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
LOGGING_LEVEL = logging.INFO
|
||||
logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER)
|
||||
|
||||
LISTEN_ADDR = ('0.0.0.0', 9000)
|
||||
|
||||
|
||||
def main():
|
||||
loop = EventLoop()
|
||||
dns_resolver = DNSResolver()
|
||||
relay = TcpRelay(TcpRelayServerHandler, LISTEN_ADDR,
|
||||
dns_resolver=dns_resolver)
|
||||
dns_resolver.add_to_loop(loop)
|
||||
relay.add_to_loop(loop)
|
||||
loop.run()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
365
shell.py
Normal file
365
shell.py
Normal file
|
@ -0,0 +1,365 @@
|
|||
#!/usr/bin/python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2015 clowwindy
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
import getopt
|
||||
import logging
|
||||
from shadowsocks.common import to_bytes, to_str, IPNetwork
|
||||
from shadowsocks import encrypt
|
||||
|
||||
|
||||
VERBOSE_LEVEL = 5
|
||||
|
||||
verbose = 0
|
||||
|
||||
|
||||
def check_python():
|
||||
info = sys.version_info
|
||||
if info[0] == 2 and not info[1] >= 6:
|
||||
print('Python 2.6+ required')
|
||||
sys.exit(1)
|
||||
elif info[0] == 3 and not info[1] >= 3:
|
||||
print('Python 3.3+ required')
|
||||
sys.exit(1)
|
||||
elif info[0] not in [2, 3]:
|
||||
print('Python version not supported')
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def print_exception(e):
|
||||
global verbose
|
||||
logging.error(e)
|
||||
if verbose > 0:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
def print_shadowsocks():
|
||||
version = ''
|
||||
try:
|
||||
import pkg_resources
|
||||
version = pkg_resources.get_distribution('shadowsocks').version
|
||||
except Exception:
|
||||
pass
|
||||
print('Shadowsocks %s' % version)
|
||||
|
||||
|
||||
def find_config():
|
||||
config_path = 'config.json'
|
||||
if os.path.exists(config_path):
|
||||
return config_path
|
||||
config_path = os.path.join(os.path.dirname(__file__), '../', 'config.json')
|
||||
if os.path.exists(config_path):
|
||||
return config_path
|
||||
return None
|
||||
|
||||
|
||||
def check_config(config, is_local):
|
||||
if config.get('daemon', None) == 'stop':
|
||||
# no need to specify configuration for daemon stop
|
||||
return
|
||||
|
||||
if is_local and not config.get('password', None):
|
||||
logging.error('password not specified')
|
||||
print_help(is_local)
|
||||
sys.exit(2)
|
||||
|
||||
if not is_local and not config.get('password', None) \
|
||||
and not config.get('port_password', None):
|
||||
logging.error('password or port_password not specified')
|
||||
print_help(is_local)
|
||||
sys.exit(2)
|
||||
|
||||
if 'local_port' in config:
|
||||
config['local_port'] = int(config['local_port'])
|
||||
|
||||
if 'server_port' in config and type(config['server_port']) != list:
|
||||
config['server_port'] = int(config['server_port'])
|
||||
|
||||
if config.get('local_address', '') in [b'0.0.0.0']:
|
||||
logging.warn('warning: local set to listen on 0.0.0.0, it\'s not safe')
|
||||
if config.get('server', '') in ['127.0.0.1', 'localhost']:
|
||||
logging.warn('warning: server set to listen on %s:%s, are you sure?' %
|
||||
(to_str(config['server']), config['server_port']))
|
||||
if (config.get('method', '') or '').lower() == 'table':
|
||||
logging.warn('warning: table is not safe; please use a safer cipher, '
|
||||
'like AES-256-CFB')
|
||||
if (config.get('method', '') or '').lower() == 'rc4':
|
||||
logging.warn('warning: RC4 is not safe; please use a safer cipher, '
|
||||
'like AES-256-CFB')
|
||||
if config.get('timeout', 300) < 100:
|
||||
logging.warn('warning: your timeout %d seems too short' %
|
||||
int(config.get('timeout')))
|
||||
if config.get('timeout', 300) > 600:
|
||||
logging.warn('warning: your timeout %d seems too long' %
|
||||
int(config.get('timeout')))
|
||||
if config.get('password') in [b'mypassword']:
|
||||
logging.error('DON\'T USE DEFAULT PASSWORD! Please change it in your '
|
||||
'config.json!')
|
||||
sys.exit(1)
|
||||
if config.get('user', None) is not None:
|
||||
if os.name != 'posix':
|
||||
logging.error('user can be used only on Unix')
|
||||
sys.exit(1)
|
||||
|
||||
encrypt.try_cipher(config['password'], config['method'])
|
||||
|
||||
|
||||
def get_config(is_local):
|
||||
global verbose
|
||||
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(levelname)-s: %(message)s')
|
||||
if is_local:
|
||||
shortopts = 'hd:s:b:p:k:l:m:c:t:vq'
|
||||
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'user=',
|
||||
'version']
|
||||
else:
|
||||
shortopts = 'hd:s:p:k:m:c:t:vq'
|
||||
longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=',
|
||||
'forbidden-ip=', 'user=', 'manager-address=', 'version']
|
||||
try:
|
||||
config_path = find_config()
|
||||
optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts)
|
||||
for key, value in optlist:
|
||||
if key == '-c':
|
||||
config_path = value
|
||||
|
||||
if config_path:
|
||||
logging.info('loading config from %s' % config_path)
|
||||
with open(config_path, 'rb') as f:
|
||||
try:
|
||||
config = parse_json_in_str(f.read().decode('utf8'))
|
||||
except ValueError as e:
|
||||
logging.error('found an error in config.json: %s',
|
||||
e.message)
|
||||
sys.exit(1)
|
||||
else:
|
||||
config = {}
|
||||
|
||||
v_count = 0
|
||||
for key, value in optlist:
|
||||
if key == '-p':
|
||||
config['server_port'] = int(value)
|
||||
elif key == '-k':
|
||||
config['password'] = to_bytes(value)
|
||||
elif key == '-l':
|
||||
config['local_port'] = int(value)
|
||||
elif key == '-s':
|
||||
config['server'] = to_str(value)
|
||||
elif key == '-m':
|
||||
config['method'] = to_str(value)
|
||||
elif key == '-b':
|
||||
config['local_address'] = to_str(value)
|
||||
elif key == '-v':
|
||||
v_count += 1
|
||||
# '-vv' turns on more verbose mode
|
||||
config['verbose'] = v_count
|
||||
elif key == '-t':
|
||||
config['timeout'] = int(value)
|
||||
elif key == '--fast-open':
|
||||
config['fast_open'] = True
|
||||
elif key == '--workers':
|
||||
config['workers'] = int(value)
|
||||
elif key == '--manager-address':
|
||||
config['manager_address'] = value
|
||||
elif key == '--user':
|
||||
config['user'] = to_str(value)
|
||||
elif key == '--forbidden-ip':
|
||||
config['forbidden_ip'] = to_str(value).split(',')
|
||||
elif key in ('-h', '--help'):
|
||||
if is_local:
|
||||
print_local_help()
|
||||
else:
|
||||
print_server_help()
|
||||
sys.exit(0)
|
||||
elif key == '--version':
|
||||
print_shadowsocks()
|
||||
sys.exit(0)
|
||||
elif key == '-d':
|
||||
config['daemon'] = to_str(value)
|
||||
elif key == '--pid-file':
|
||||
config['pid-file'] = to_str(value)
|
||||
elif key == '--log-file':
|
||||
config['log-file'] = to_str(value)
|
||||
elif key == '-q':
|
||||
v_count -= 1
|
||||
config['verbose'] = v_count
|
||||
except getopt.GetoptError as e:
|
||||
print(e, file=sys.stderr)
|
||||
print_help(is_local)
|
||||
sys.exit(2)
|
||||
|
||||
if not config:
|
||||
logging.error('config not specified')
|
||||
print_help(is_local)
|
||||
sys.exit(2)
|
||||
|
||||
config['password'] = to_bytes(config.get('password', b''))
|
||||
config['method'] = to_str(config.get('method', 'aes-256-cfb'))
|
||||
config['port_password'] = config.get('port_password', None)
|
||||
config['timeout'] = int(config.get('timeout', 300))
|
||||
config['fast_open'] = config.get('fast_open', False)
|
||||
config['workers'] = config.get('workers', 1)
|
||||
config['pid-file'] = config.get('pid-file', '/var/run/shadowsocks.pid')
|
||||
config['log-file'] = config.get('log-file', '/var/log/shadowsocks.log')
|
||||
config['verbose'] = config.get('verbose', False)
|
||||
config['local_address'] = to_str(config.get('local_address', '127.0.0.1'))
|
||||
config['local_port'] = config.get('local_port', 1080)
|
||||
if is_local:
|
||||
if config.get('server', None) is None:
|
||||
logging.error('server addr not specified')
|
||||
print_local_help()
|
||||
sys.exit(2)
|
||||
else:
|
||||
config['server'] = to_str(config['server'])
|
||||
else:
|
||||
config['server'] = to_str(config.get('server', '0.0.0.0'))
|
||||
try:
|
||||
config['forbidden_ip'] = \
|
||||
IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128'))
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
sys.exit(2)
|
||||
config['server_port'] = config.get('server_port', 8388)
|
||||
|
||||
logging.getLogger('').handlers = []
|
||||
logging.addLevelName(VERBOSE_LEVEL, 'VERBOSE')
|
||||
if config['verbose'] >= 2:
|
||||
level = VERBOSE_LEVEL
|
||||
elif config['verbose'] == 1:
|
||||
level = logging.DEBUG
|
||||
elif config['verbose'] == -1:
|
||||
level = logging.WARN
|
||||
elif config['verbose'] <= -2:
|
||||
level = logging.ERROR
|
||||
else:
|
||||
level = logging.INFO
|
||||
verbose = config['verbose']
|
||||
logging.basicConfig(level=level,
|
||||
format='%(asctime)s %(levelname)-8s %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S')
|
||||
|
||||
check_config(config, is_local)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def print_help(is_local):
|
||||
if is_local:
|
||||
print_local_help()
|
||||
else:
|
||||
print_server_help()
|
||||
|
||||
|
||||
def print_local_help():
|
||||
print('''usage: sslocal [OPTION]...
|
||||
A fast tunnel proxy that helps you bypass firewalls.
|
||||
|
||||
You can supply configurations via either config file or command line arguments.
|
||||
|
||||
Proxy options:
|
||||
-c CONFIG path to config file
|
||||
-s SERVER_ADDR server address
|
||||
-p SERVER_PORT server port, default: 8388
|
||||
-b LOCAL_ADDR local binding address, default: 127.0.0.1
|
||||
-l LOCAL_PORT local port, default: 1080
|
||||
-k PASSWORD password
|
||||
-m METHOD encryption method, default: aes-256-cfb
|
||||
-t TIMEOUT timeout in seconds, default: 300
|
||||
--fast-open use TCP_FASTOPEN, requires Linux 3.7+
|
||||
|
||||
General options:
|
||||
-h, --help show this help message and exit
|
||||
-d start/stop/restart daemon mode
|
||||
--pid-file PID_FILE pid file for daemon mode
|
||||
--log-file LOG_FILE log file for daemon mode
|
||||
--user USER username to run as
|
||||
-v, -vv verbose mode
|
||||
-q, -qq quiet mode, only show warnings/errors
|
||||
--version show version information
|
||||
|
||||
Online help: <https://github.com/shadowsocks/shadowsocks>
|
||||
''')
|
||||
|
||||
|
||||
def print_server_help():
|
||||
print('''usage: ssserver [OPTION]...
|
||||
A fast tunnel proxy that helps you bypass firewalls.
|
||||
|
||||
You can supply configurations via either config file or command line arguments.
|
||||
|
||||
Proxy options:
|
||||
-c CONFIG path to config file
|
||||
-s SERVER_ADDR server address, default: 0.0.0.0
|
||||
-p SERVER_PORT server port, default: 8388
|
||||
-k PASSWORD password
|
||||
-m METHOD encryption method, default: aes-256-cfb
|
||||
-t TIMEOUT timeout in seconds, default: 300
|
||||
--fast-open use TCP_FASTOPEN, requires Linux 3.7+
|
||||
--workers WORKERS number of workers, available on Unix/Linux
|
||||
--forbidden-ip IPLIST comma seperated IP list forbidden to connect
|
||||
--manager-address ADDR optional server manager UDP address, see wiki
|
||||
|
||||
General options:
|
||||
-h, --help show this help message and exit
|
||||
-d start/stop/restart daemon mode
|
||||
--pid-file PID_FILE pid file for daemon mode
|
||||
--log-file LOG_FILE log file for daemon mode
|
||||
--user USER username to run as
|
||||
-v, -vv verbose mode
|
||||
-q, -qq quiet mode, only show warnings/errors
|
||||
--version show version information
|
||||
|
||||
Online help: <https://github.com/shadowsocks/shadowsocks>
|
||||
''')
|
||||
|
||||
|
||||
def _decode_list(data):
|
||||
rv = []
|
||||
for item in data:
|
||||
if hasattr(item, 'encode'):
|
||||
item = item.encode('utf-8')
|
||||
elif isinstance(item, list):
|
||||
item = _decode_list(item)
|
||||
elif isinstance(item, dict):
|
||||
item = _decode_dict(item)
|
||||
rv.append(item)
|
||||
return rv
|
||||
|
||||
|
||||
def _decode_dict(data):
|
||||
rv = {}
|
||||
for key, value in data.items():
|
||||
if hasattr(value, 'encode'):
|
||||
value = value.encode('utf-8')
|
||||
elif isinstance(value, list):
|
||||
value = _decode_list(value)
|
||||
elif isinstance(value, dict):
|
||||
value = _decode_dict(value)
|
||||
rv[key] = value
|
||||
return rv
|
||||
|
||||
|
||||
def parse_json_in_str(data):
|
||||
# parse json and convert everything from unicode to str
|
||||
return json.loads(data, object_hook=_decode_dict)
|
387
tcprelay.py
Normal file
387
tcprelay.py
Normal file
|
@ -0,0 +1,387 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import absolute_import, division, print_function, \
|
||||
with_statement
|
||||
|
||||
import errno
|
||||
import logging
|
||||
import socket
|
||||
|
||||
from shadowsocks.selectors import (EVENT_READ, EVENT_WRITE, EVENT_ERROR,
|
||||
errno_from_exception, get_sock_error)
|
||||
from shadowsocks.common import parse_header, to_str
|
||||
from shadowsocks import encrypt
|
||||
|
||||
|
||||
BUF_SIZE = 32 * 1024
|
||||
CMD_CONNECT = 1
|
||||
|
||||
|
||||
def create_sock(ip, port):
|
||||
addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM,
|
||||
socket.SOL_TCP)
|
||||
if len(addrs) == 0:
|
||||
raise Exception("Getaddrinfo failed for %s:%d" % (ip, port))
|
||||
|
||||
af, socktype, proto, canonname, sa = addrs[0]
|
||||
sock = socket.socket(af, socktype, proto)
|
||||
sock.setblocking(False)
|
||||
sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
|
||||
return sock
|
||||
|
||||
|
||||
class TcpRelayHanler(object):
|
||||
|
||||
def __init__(self, local_sock, local_addr, remote_addr=None,
|
||||
dns_resolver=None):
|
||||
self._loop = None
|
||||
self._local_sock = local_sock
|
||||
self._local_addr = local_addr
|
||||
self._remote_addr = remote_addr
|
||||
# self._crypt = None
|
||||
self._crypt = encrypt.Encryptor(b'PassThrouthGFW', 'aes-256-cfb')
|
||||
self._dns_resolver = dns_resolver
|
||||
self._remote_sock = None
|
||||
self._local_sock_mode = 0
|
||||
self._remote_sock_mode = 0
|
||||
self._data_to_write_to_local = []
|
||||
self._data_to_write_to_remote = []
|
||||
self._id = id(self)
|
||||
|
||||
def handle_event(self, sock, event, call, *args):
|
||||
if event & EVENT_ERROR:
|
||||
logging.error(get_sock_error(sock))
|
||||
self.close()
|
||||
else:
|
||||
try:
|
||||
call(sock, event, *args)
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
self.close()
|
||||
|
||||
def __del__(self):
|
||||
logging.debug('Deleting {}'.format(self._id))
|
||||
|
||||
def add_to_loop(self, loop):
|
||||
if self._loop:
|
||||
raise Exception('Already added to loop')
|
||||
self._loop = loop
|
||||
loop.add(self._local_sock, EVENT_READ, (self, self.start))
|
||||
|
||||
def modify_local_sock_mode(self, event):
|
||||
if self._local_sock_mode != event:
|
||||
self._local_sock_mode = self.modify_sock_mode(self._local_sock,
|
||||
event)
|
||||
|
||||
def modify_remote_sock_mode(self, event):
|
||||
if self._remote_sock_mode != event:
|
||||
self._remote_sock_mode = self.modify_sock_mode(self._remote_sock,
|
||||
event)
|
||||
|
||||
def modify_sock_mode(self, sock, event):
|
||||
key = self._loop.modify(sock, event, (self, self.stream))
|
||||
return key.events
|
||||
|
||||
def close_sock(self, sock):
|
||||
self._loop.remove(sock)
|
||||
sock.close()
|
||||
|
||||
def close(self):
|
||||
if self._local_sock:
|
||||
self.close_sock(self._local_sock)
|
||||
self._local_sock = None
|
||||
if self._remote_sock:
|
||||
self.close_sock(self._remote_sock)
|
||||
self._remote_sock = None
|
||||
|
||||
def sock_connect(self, sock, addr):
|
||||
while True:
|
||||
try:
|
||||
sock.connect(addr)
|
||||
except (OSError, IOError) as e:
|
||||
err = errno_from_exception(e)
|
||||
if err == errno.EINTR:
|
||||
pass
|
||||
elif err == errno.EINPROGRESS:
|
||||
break
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
break
|
||||
|
||||
def sock_recv(self, sock, size=BUF_SIZE):
|
||||
try:
|
||||
data = sock.recv(size)
|
||||
if not data:
|
||||
self.close()
|
||||
except (OSError, IOError) as e:
|
||||
if errno_from_exception(e) in (errno.EAGAIN, errno.EWOULDBLOCK,
|
||||
errno.EINTR):
|
||||
return
|
||||
else:
|
||||
raise
|
||||
return data
|
||||
|
||||
def sock_send(self, sock, data):
|
||||
try:
|
||||
s = sock.send(data)
|
||||
data = data[s:]
|
||||
except (OSError, IOError) as e:
|
||||
if errno_from_exception(e) in (errno.EAGAIN, errno.EWOULDBLOCK,
|
||||
errno.EINPROGRESS, errno.EINTR):
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
return data
|
||||
|
||||
def on_local_read(self, size=BUF_SIZE):
|
||||
logging.debug('on_local_read')
|
||||
if not self._local_sock:
|
||||
return
|
||||
|
||||
data = self.sock_recv(self._local_sock, size)
|
||||
if not data:
|
||||
return
|
||||
|
||||
logging.debug('Received {} bytes from {}:{}'.format(len(data),
|
||||
*self._local_addr))
|
||||
if self._crypt:
|
||||
if self._is_client:
|
||||
data = self._crypt.encrypt(data)
|
||||
else:
|
||||
data = self._crypt.decrypt(data)
|
||||
|
||||
if data:
|
||||
self._data_to_write_to_remote.append(data)
|
||||
self.on_remote_write()
|
||||
|
||||
def on_remote_read(self, size=BUF_SIZE):
|
||||
logging.debug('on_remote_read')
|
||||
if not self._remote_sock:
|
||||
return
|
||||
|
||||
data = self.sock_recv(self._remote_sock, size)
|
||||
if not data:
|
||||
return
|
||||
|
||||
logging.debug('Received {} bytes from {}:{}'.format(
|
||||
len(data), *self._remote_addr))
|
||||
|
||||
if self._crypt:
|
||||
if self._is_client:
|
||||
data = self._crypt.decrypt(data)
|
||||
else:
|
||||
data = self._crypt.encrypt(data)
|
||||
|
||||
if data:
|
||||
self._data_to_write_to_local.append(data)
|
||||
self.on_local_write()
|
||||
|
||||
def on_local_write(self):
|
||||
logging.debug('on_local_write')
|
||||
if not self._local_sock:
|
||||
return
|
||||
|
||||
if not self._data_to_write_to_local:
|
||||
self.modify_local_sock_mode(EVENT_READ)
|
||||
return
|
||||
|
||||
data = b''.join(self._data_to_write_to_local)
|
||||
self._data_to_write_to_local = []
|
||||
|
||||
data = self.sock_send(self._local_sock, data)
|
||||
|
||||
if data:
|
||||
self._data_to_write_to_local.append(data)
|
||||
self.modify_local_sock_mode(EVENT_WRITE)
|
||||
else:
|
||||
self.modify_local_sock_mode(EVENT_READ)
|
||||
|
||||
def on_remote_write(self):
|
||||
logging.debug('on_remote_write')
|
||||
if not self._remote_sock:
|
||||
return
|
||||
|
||||
if not self._data_to_write_to_remote:
|
||||
self.modify_remote_sock_mode(EVENT_READ)
|
||||
return
|
||||
|
||||
data = b''.join(self._data_to_write_to_remote)
|
||||
self._data_to_write_to_remote = []
|
||||
|
||||
data = self.sock_send(self._remote_sock, data)
|
||||
|
||||
if data:
|
||||
self._data_to_write_to_remote.append(data)
|
||||
self.modify_remote_sock_mode(EVENT_WRITE)
|
||||
else:
|
||||
self.modify_remote_sock_mode(EVENT_READ)
|
||||
|
||||
def stream(self, sock, event):
|
||||
logging.debug('stream')
|
||||
|
||||
if sock == self._local_sock:
|
||||
if event & EVENT_READ:
|
||||
self.on_local_read()
|
||||
if event & EVENT_WRITE:
|
||||
self.on_local_write()
|
||||
elif sock == self._remote_sock:
|
||||
if event & EVENT_READ:
|
||||
self.on_remote_read()
|
||||
if event & EVENT_WRITE:
|
||||
self.on_remote_write()
|
||||
else:
|
||||
logging.warn('Unknow sock {}'.format(sock))
|
||||
|
||||
|
||||
class TcpRelayClientHanler(TcpRelayHanler):
|
||||
|
||||
_is_client = True
|
||||
|
||||
def start(self, sock, event):
|
||||
data = self.sock_recv(sock)
|
||||
if not data:
|
||||
return
|
||||
reply = b'\x05\x00'
|
||||
self.send_reply(sock, None, reply)
|
||||
|
||||
def send_reply(self, sock, event, data):
|
||||
data = self.sock_send(sock, data)
|
||||
if data:
|
||||
self._loop.modify(sock, EVENT_WRITE, (self, self.send_reply, data))
|
||||
else:
|
||||
self._loop.modify(sock, EVENT_READ, (self, self.handle_addr))
|
||||
|
||||
def handle_addr(self, sock, event):
|
||||
data = self.sock_recv(sock)
|
||||
if not data:
|
||||
return
|
||||
# self._loop.remove(sock)
|
||||
|
||||
if ord(data[1:2]) != CMD_CONNECT:
|
||||
raise Exception('Command not suppored')
|
||||
|
||||
result = parse_header(data[3:])
|
||||
if not result:
|
||||
raise Exception('Header cannot be parsed')
|
||||
|
||||
self._remote_sock = create_sock(*self._remote_addr)
|
||||
self.sock_connect(self._remote_sock, self._remote_addr)
|
||||
|
||||
dest_addr = (to_str(result[1]), result[2])
|
||||
logging.info('Connecting to {}:{}'.format(*dest_addr))
|
||||
data = '{}:{}\n'.format(*dest_addr).encode('utf-8')
|
||||
if self._crypt:
|
||||
data = self._crypt.encrypt(data)
|
||||
self._data_to_write_to_remote.append(data)
|
||||
|
||||
bind_addr = b'\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00'
|
||||
self.send_bind_addr(sock, None, bind_addr)
|
||||
|
||||
def send_bind_addr(self, sock, event, data):
|
||||
data = self.sock_send(sock, data)
|
||||
if data:
|
||||
self._loop.modify(sock, EVENT_WRITE, (self, self.send_bind_addr,
|
||||
data))
|
||||
else:
|
||||
self.modify_local_sock_mode(EVENT_READ)
|
||||
|
||||
|
||||
class TcpRelayServerHandler(TcpRelayHanler):
|
||||
|
||||
_is_client = False
|
||||
|
||||
def start(self, sock, event, data=None):
|
||||
data = self.sock_recv(sock)
|
||||
if not data:
|
||||
return
|
||||
self._loop.remove(sock)
|
||||
|
||||
if self._crypt:
|
||||
data = self._crypt.decrypt(data)
|
||||
remote, data = data.split(b'\n', 1)
|
||||
host, port = remote.split(b':')
|
||||
self._remote = (host, int(port))
|
||||
self._data_to_write_to_remote.append(data)
|
||||
self._dns_resolver.resolve(host, self.dns_resolved)
|
||||
|
||||
def dns_resolved(self, result, error):
|
||||
try:
|
||||
ip = result[1]
|
||||
except (TypeError, IndexError):
|
||||
ip = None
|
||||
|
||||
if not ip:
|
||||
raise Exception('Hostname {} cannot resolved'.format(
|
||||
self._remote[0]))
|
||||
|
||||
self._remote_addr = (ip, self._remote[1])
|
||||
self._remote_sock = create_sock(*self._remote_addr)
|
||||
logging.info('Connecting to {}'.format(self._remote[0]))
|
||||
self.sock_connect(self._remote_sock, self._remote_addr)
|
||||
self.modify_remote_sock_mode(EVENT_WRITE)
|
||||
self.modify_local_sock_mode(EVENT_READ)
|
||||
|
||||
|
||||
class TcpRelay(object):
|
||||
|
||||
def __init__(self, handler_type, listen_addr, remote_addr=None,
|
||||
dns_resolver=None):
|
||||
self._loop = None
|
||||
self._handler_type = handler_type
|
||||
self._listen_addr = listen_addr
|
||||
self._remote_addr = remote_addr
|
||||
self._dns_resolver = dns_resolver
|
||||
self._create_listen_sock()
|
||||
|
||||
def _create_listen_sock(self):
|
||||
sock = create_sock(*self._listen_addr)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(self._listen_addr)
|
||||
sock.listen(1024)
|
||||
self._listen_sock = sock
|
||||
logging.info('Listening on {}:{}'.format(*self._listen_addr))
|
||||
|
||||
def add_to_loop(self, loop):
|
||||
if self._loop:
|
||||
raise Exception('Already added to loop')
|
||||
self._loop = loop
|
||||
loop.add(self._listen_sock, EVENT_READ, (self, self.accept))
|
||||
|
||||
def _accept(self, listen_sock):
|
||||
try:
|
||||
sock, addr = listen_sock.accept()
|
||||
except (OSError, IOError) as e:
|
||||
if errno_from_exception(e) in (
|
||||
errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS,
|
||||
errno.EINTR, errno.ECONNABORTED
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
logging.info('Connected from {}:{}'.format(*addr))
|
||||
sock.setblocking(False)
|
||||
sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
|
||||
return (sock, addr)
|
||||
|
||||
def accept(self, listen_sock, event):
|
||||
sock, addr = self._accept(listen_sock)
|
||||
handler = self._handler_type(sock, addr, self._remote_addr,
|
||||
self._dns_resolver)
|
||||
handler.add_to_loop(self._loop)
|
||||
|
||||
def close(self):
|
||||
self._loop.remove(self._listen_sock)
|
||||
|
||||
def handle_event(self, sock, event, call, *args):
|
||||
if event & EVENT_ERROR:
|
||||
logging.error(get_sock_error(sock))
|
||||
self.close()
|
||||
else:
|
||||
try:
|
||||
call(sock, event, *args)
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
self.close()
|
Loading…
Add table
Add a link
Reference in a new issue