diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e182e7b --- /dev/null +++ b/.gitignore @@ -0,0 +1,61 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo +*.pot + +# Django stuff: +*.log + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# database file +*.sqlite +*.sqlite3 +*.db + +# temporary file +*.swp diff --git a/README.md b/README.md index 15db4b5..32561fd 100644 --- a/README.md +++ b/README.md @@ -1 +1 @@ -Removed according to regulations. +Just for learn. diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..dc3abd4 --- /dev/null +++ b/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/python +# +# Copyright 2012-2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement diff --git a/asyncdns.py b/asyncdns.py new file mode 100644 index 0000000..0461f61 --- /dev/null +++ b/asyncdns.py @@ -0,0 +1,481 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2014-2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import socket +import struct +import re +import logging + +from shadowsocks import common, lru_cache, eventloop, shell + + +CACHE_SWEEP_INTERVAL = 30 + +VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d-]{1,63}(? 63: + return None + results.append(common.chr(l)) + results.append(label) + results.append(b'\0') + return b''.join(results) + + +def build_request(address, qtype): + request_id = os.urandom(2) + header = struct.pack('!BBHHHH', 1, 0, 1, 0, 0, 0) + addr = build_address(address) + qtype_qclass = struct.pack('!HH', qtype, QCLASS_IN) + return request_id + header + addr + qtype_qclass + + +def parse_ip(addrtype, data, length, offset): + if addrtype == QTYPE_A: + return socket.inet_ntop(socket.AF_INET, data[offset:offset + length]) + elif addrtype == QTYPE_AAAA: + return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length]) + elif addrtype in [QTYPE_CNAME, QTYPE_NS]: + return parse_name(data, offset)[1] + else: + return data[offset:offset + length] + + +def parse_name(data, offset): + p = offset + labels = [] + l = common.ord(data[p]) + while l > 0: + if (l & (128 + 64)) == (128 + 64): + # pointer + pointer = struct.unpack('!H', data[p:p + 2])[0] + pointer &= 0x3FFF + r = parse_name(data, pointer) + labels.append(r[1]) + p += 2 + # pointer is the end + return p - offset, b'.'.join(labels) + else: + labels.append(data[p + 1:p + 1 + l]) + p += 1 + l + l = common.ord(data[p]) + return p - offset + 1, b'.'.join(labels) + + +# rfc1035 +# record +# 1 1 1 1 1 1 +# 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +# | | +# / / +# / NAME / +# | | +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +# | TYPE | +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +# | CLASS | +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +# | TTL | +# | | +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +# | RDLENGTH | +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--| +# / RDATA / +# / / +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +def parse_record(data, offset, question=False): + nlen, name = parse_name(data, offset) + if not question: + record_type, record_class, record_ttl, record_rdlength = struct.unpack( + '!HHiH', data[offset + nlen:offset + nlen + 10] + ) + ip = parse_ip(record_type, data, record_rdlength, offset + nlen + 10) + return nlen + 10 + record_rdlength, \ + (name, ip, record_type, record_class, record_ttl) + else: + record_type, record_class = struct.unpack( + '!HH', data[offset + nlen:offset + nlen + 4] + ) + return nlen + 4, (name, None, record_type, record_class, None, None) + + +def parse_header(data): + if len(data) >= 12: + header = struct.unpack('!HBBHHHH', data[:12]) + res_id = header[0] + res_qr = header[1] & 128 + res_tc = header[1] & 2 + res_ra = header[2] & 128 + res_rcode = header[2] & 15 + # assert res_tc == 0 + # assert res_rcode in [0, 3] + res_qdcount = header[3] + res_ancount = header[4] + res_nscount = header[5] + res_arcount = header[6] + return (res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount, + res_ancount, res_nscount, res_arcount) + return None + + +def parse_response(data): + try: + if len(data) >= 12: + header = parse_header(data) + if not header: + return None + res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount, \ + res_ancount, res_nscount, res_arcount = header + + qds = [] + ans = [] + offset = 12 + for i in range(0, res_qdcount): + l, r = parse_record(data, offset, True) + offset += l + if r: + qds.append(r) + for i in range(0, res_ancount): + l, r = parse_record(data, offset) + offset += l + if r: + ans.append(r) + for i in range(0, res_nscount): + l, r = parse_record(data, offset) + offset += l + for i in range(0, res_arcount): + l, r = parse_record(data, offset) + offset += l + response = DNSResponse() + if qds: + response.hostname = qds[0][0] + for an in qds: + response.questions.append((an[1], an[2], an[3])) + for an in ans: + response.answers.append((an[1], an[2], an[3])) + return response + except Exception as e: + shell.print_exception(e) + return None + + +def is_valid_hostname(hostname): + if len(hostname) > 255: + return False + if hostname[-1] == b'.': + hostname = hostname[:-1] + return all(VALID_HOSTNAME.match(x) for x in hostname.split(b'.')) + + +class DNSResponse(object): + def __init__(self): + self.hostname = None + self.questions = [] # each: (addr, type, class) + self.answers = [] # each: (addr, type, class) + + def __str__(self): + return '%s: %s' % (self.hostname, str(self.answers)) + + +STATUS_IPV4 = 0 +STATUS_IPV6 = 1 + + +class DNSResolver(object): + + def __init__(self): + self._loop = None + self._hosts = {} + self._hostname_status = {} + self._hostname_to_cb = {} + self._cb_to_hostname = {} + self._cache = lru_cache.LRUCache(timeout=300) + self._sock = None + self._servers = None + self._parse_resolv() + self._parse_hosts() + # TODO monitor hosts change and reload hosts + # TODO parse /etc/gai.conf and follow its rules + + def _parse_resolv(self): + self._servers = [] + try: + with open('/etc/resolv.conf', 'rb') as f: + content = f.readlines() + for line in content: + line = line.strip() + if line: + if line.startswith(b'nameserver'): + parts = line.split() + if len(parts) >= 2: + server = parts[1] + if common.is_ip(server) == socket.AF_INET: + if type(server) != str: + server = server.decode('utf8') + self._servers.append(server) + except IOError: + pass + if not self._servers: + self._servers = ['8.8.4.4', '8.8.8.8'] + + def _parse_hosts(self): + etc_path = '/etc/hosts' + if 'WINDIR' in os.environ: + etc_path = os.environ['WINDIR'] + '/system32/drivers/etc/hosts' + try: + with open(etc_path, 'rb') as f: + for line in f.readlines(): + line = line.strip() + parts = line.split() + if len(parts) >= 2: + ip = parts[0] + if common.is_ip(ip): + for i in range(1, len(parts)): + hostname = parts[i] + if hostname: + self._hosts[hostname] = ip + except IOError: + self._hosts['localhost'] = '127.0.0.1' + + def add_to_loop(self, loop): + if self._loop: + raise Exception('already add to loop') + self._loop = loop + # TODO when dns server is IPv6 + self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, + socket.SOL_UDP) + self._sock.setblocking(False) + loop.add(self._sock, eventloop.POLL_IN, self) + loop.add_periodic(self.handle_periodic) + + def _call_callback(self, hostname, ip, error=None): + callbacks = self._hostname_to_cb.get(hostname, []) + for callback in callbacks: + if callback in self._cb_to_hostname: + del self._cb_to_hostname[callback] + if ip or error: + callback((hostname, ip), error) + else: + callback((hostname, None), + Exception('unknown hostname %s' % hostname)) + if hostname in self._hostname_to_cb: + del self._hostname_to_cb[hostname] + if hostname in self._hostname_status: + del self._hostname_status[hostname] + + def _handle_data(self, data): + response = parse_response(data) + if response and response.hostname: + hostname = response.hostname + ip = None + for answer in response.answers: + if answer[1] in (QTYPE_A, QTYPE_AAAA) and \ + answer[2] == QCLASS_IN: + ip = answer[0] + break + if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \ + == STATUS_IPV4: + self._hostname_status[hostname] = STATUS_IPV6 + self._send_req(hostname, QTYPE_AAAA) + else: + if ip: + self._cache[hostname] = ip + self._call_callback(hostname, ip) + elif self._hostname_status.get(hostname, None) == STATUS_IPV6: + for question in response.questions: + if question[1] == QTYPE_AAAA: + self._call_callback(hostname, None) + break + + def handle_event(self, sock, event): + if sock != self._sock: + return + if event & eventloop.POLL_ERR: + logging.error('dns socket err') + self._loop.remove(self._sock) + self._sock.close() + # TODO when dns server is IPv6 + self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, + socket.SOL_UDP) + self._sock.setblocking(False) + self._loop.add(self._sock, eventloop.POLL_IN, self) + else: + data, addr = sock.recvfrom(1024) + if addr[0] not in self._servers: + logging.warn('received a packet other than our dns') + return + self._handle_data(data) + + def handle_periodic(self): + self._cache.sweep() + + def remove_callback(self, callback): + hostname = self._cb_to_hostname.get(callback) + if hostname: + del self._cb_to_hostname[callback] + arr = self._hostname_to_cb.get(hostname, None) + if arr: + arr.remove(callback) + if not arr: + del self._hostname_to_cb[hostname] + if hostname in self._hostname_status: + del self._hostname_status[hostname] + + def _send_req(self, hostname, qtype): + req = build_request(hostname, qtype) + for server in self._servers: + logging.debug('resolving %s with type %d using server %s', + hostname, qtype, server) + self._sock.sendto(req, (server, 53)) + + def resolve(self, hostname, callback): + if type(hostname) != bytes: + hostname = hostname.encode('utf8') + if not hostname: + callback(None, Exception('empty hostname')) + elif common.is_ip(hostname): + callback((hostname, hostname), None) + elif hostname in self._hosts: + logging.debug('hit hosts: %s', hostname) + ip = self._hosts[hostname] + callback((hostname, ip), None) + elif hostname in self._cache: + logging.debug('hit cache: %s', hostname) + ip = self._cache[hostname] + callback((hostname, ip), None) + else: + if not is_valid_hostname(hostname): + callback(None, Exception('invalid hostname: %s' % hostname)) + return + arr = self._hostname_to_cb.get(hostname, None) + if not arr: + self._hostname_status[hostname] = STATUS_IPV4 + self._send_req(hostname, QTYPE_A) + self._hostname_to_cb[hostname] = [callback] + self._cb_to_hostname[callback] = hostname + else: + arr.append(callback) + # TODO send again only if waited too long + self._send_req(hostname, QTYPE_A) + + def close(self): + if self._sock: + if self._loop: + self._loop.remove_periodic(self.handle_periodic) + self._loop.remove(self._sock) + self._sock.close() + self._sock = None + + +def test(): + dns_resolver = DNSResolver() + loop = eventloop.EventLoop() + dns_resolver.add_to_loop(loop) + + global counter + counter = 0 + + def make_callback(): + global counter + + def callback(result, error): + global counter + # TODO: what can we assert? + print(result, error) + counter += 1 + if counter == 9: + dns_resolver.close() + loop.stop() + a_callback = callback + return a_callback + + assert(make_callback() != make_callback()) + + dns_resolver.resolve(b'google.com', make_callback()) + dns_resolver.resolve('google.com', make_callback()) + dns_resolver.resolve('example.com', make_callback()) + dns_resolver.resolve('ipv6.google.com', make_callback()) + dns_resolver.resolve('www.facebook.com', make_callback()) + dns_resolver.resolve('ns2.google.com', make_callback()) + dns_resolver.resolve('invalid.@!#$%^&$@.hostname', make_callback()) + dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'long.hostname', make_callback()) + dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'long.hostname', make_callback()) + + loop.run() + + +if __name__ == '__main__': + test() diff --git a/client.py b/client.py new file mode 100644 index 0000000..b196fee --- /dev/null +++ b/client.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import logging +import os +import sys + +path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, path) +from shadowsocks.eventloop import EventLoop +from shadowsocks.tcprelay import TcpRelay, TcpRelayClientHanler + + +FORMATTER = '%(asctime)s - %(levelname)s - %(message)s' +LOGGING_LEVEL = logging.INFO +logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER) + +LISTEN_ADDR = ('127.0.0.1', 1080) +REMOTE_ADDR = ('127.0.0.1', 9000) + + +def main(): + loop = EventLoop() + relay = TcpRelay(TcpRelayClientHanler, LISTEN_ADDR, REMOTE_ADDR) + relay.add_to_loop(loop) + loop.run() + + +if __name__ == '__main__': + main() diff --git a/common.py b/common.py new file mode 100644 index 0000000..db4beea --- /dev/null +++ b/common.py @@ -0,0 +1,281 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2013-2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import socket +import struct +import logging + + +def compat_ord(s): + if type(s) == int: + return s + return _ord(s) + + +def compat_chr(d): + if bytes == str: + return _chr(d) + return bytes([d]) + + +_ord = ord +_chr = chr +ord = compat_ord +chr = compat_chr + + +def to_bytes(s): + if bytes != str: + if type(s) == str: + return s.encode('utf-8') + return s + + +def to_str(s): + if bytes != str: + if type(s) == bytes: + return s.decode('utf-8') + return s + + +def inet_ntop(family, ipstr): + if family == socket.AF_INET: + return to_bytes(socket.inet_ntoa(ipstr)) + elif family == socket.AF_INET6: + import re + v6addr = ':'.join(('%02X%02X' % (ord(i), ord(j))).lstrip('0') + for i, j in zip(ipstr[::2], ipstr[1::2])) + v6addr = re.sub('::+', '::', v6addr, count=1) + return to_bytes(v6addr) + + +def inet_pton(family, addr): + addr = to_str(addr) + if family == socket.AF_INET: + return socket.inet_aton(addr) + elif family == socket.AF_INET6: + if '.' in addr: # a v4 addr + v4addr = addr[addr.rindex(':') + 1:] + v4addr = socket.inet_aton(v4addr) + v4addr = map(lambda x: ('%02X' % ord(x)), v4addr) + v4addr.insert(2, ':') + newaddr = addr[:addr.rindex(':') + 1] + ''.join(v4addr) + return inet_pton(family, newaddr) + dbyts = [0] * 8 # 8 groups + grps = addr.split(':') + for i, v in enumerate(grps): + if v: + dbyts[i] = int(v, 16) + else: + for j, w in enumerate(grps[::-1]): + if w: + dbyts[7 - j] = int(w, 16) + else: + break + break + return b''.join((chr(i // 256) + chr(i % 256)) for i in dbyts) + else: + raise RuntimeError("What family?") + + +def is_ip(address): + for family in (socket.AF_INET, socket.AF_INET6): + try: + if type(address) != str: + address = address.decode('utf8') + inet_pton(family, address) + return family + except (TypeError, ValueError, OSError, IOError): + pass + return False + + +def patch_socket(): + if not hasattr(socket, 'inet_pton'): + socket.inet_pton = inet_pton + + if not hasattr(socket, 'inet_ntop'): + socket.inet_ntop = inet_ntop + + +patch_socket() + + +ADDRTYPE_IPV4 = 1 +ADDRTYPE_IPV6 = 4 +ADDRTYPE_HOST = 3 + + +def pack_addr(address): + address_str = to_str(address) + for family in (socket.AF_INET, socket.AF_INET6): + try: + r = socket.inet_pton(family, address_str) + if family == socket.AF_INET6: + return b'\x04' + r + else: + return b'\x01' + r + except (TypeError, ValueError, OSError, IOError): + pass + if len(address) > 255: + address = address[:255] # TODO + return b'\x03' + chr(len(address)) + address + + +def parse_header(data): + addrtype = ord(data[0]) + dest_addr = None + dest_port = None + header_length = 0 + if addrtype == ADDRTYPE_IPV4: + if len(data) >= 7: + dest_addr = socket.inet_ntoa(data[1:5]) + dest_port = struct.unpack('>H', data[5:7])[0] + header_length = 7 + else: + logging.warn('header is too short') + elif addrtype == ADDRTYPE_HOST: + if len(data) > 2: + addrlen = ord(data[1]) + if len(data) >= 2 + addrlen: + dest_addr = data[2:2 + addrlen] + dest_port = struct.unpack('>H', data[2 + addrlen:4 + + addrlen])[0] + header_length = 4 + addrlen + else: + logging.warn('header is too short') + else: + logging.warn('header is too short') + elif addrtype == ADDRTYPE_IPV6: + if len(data) >= 19: + dest_addr = socket.inet_ntop(socket.AF_INET6, data[1:17]) + dest_port = struct.unpack('>H', data[17:19])[0] + header_length = 19 + else: + logging.warn('header is too short') + else: + logging.warn('unsupported addrtype %d, maybe wrong password or ' + 'encryption method' % addrtype) + if dest_addr is None: + return None + return addrtype, to_bytes(dest_addr), dest_port, header_length + + +class IPNetwork(object): + ADDRLENGTH = {socket.AF_INET: 32, socket.AF_INET6: 128, False: 0} + + def __init__(self, addrs): + self._network_list_v4 = [] + self._network_list_v6 = [] + if type(addrs) == str: + addrs = addrs.split(',') + list(map(self.add_network, addrs)) + + def add_network(self, addr): + if addr is "": + return + block = addr.split('/') + addr_family = is_ip(block[0]) + addr_len = IPNetwork.ADDRLENGTH[addr_family] + if addr_family is socket.AF_INET: + ip, = struct.unpack("!I", socket.inet_aton(block[0])) + elif addr_family is socket.AF_INET6: + hi, lo = struct.unpack("!QQ", inet_pton(addr_family, block[0])) + ip = (hi << 64) | lo + else: + raise Exception("Not a valid CIDR notation: %s" % addr) + if len(block) is 1: + prefix_size = 0 + while (ip & 1) == 0 and ip is not 0: + ip >>= 1 + prefix_size += 1 + logging.warn("You did't specify CIDR routing prefix size for %s, " + "implicit treated as %s/%d" % (addr, addr, addr_len)) + elif block[1].isdigit() and int(block[1]) <= addr_len: + prefix_size = addr_len - int(block[1]) + ip >>= prefix_size + else: + raise Exception("Not a valid CIDR notation: %s" % addr) + if addr_family is socket.AF_INET: + self._network_list_v4.append((ip, prefix_size)) + else: + self._network_list_v6.append((ip, prefix_size)) + + def __contains__(self, addr): + addr_family = is_ip(addr) + if addr_family is socket.AF_INET: + ip, = struct.unpack("!I", socket.inet_aton(addr)) + return any(map(lambda n_ps: n_ps[0] == ip >> n_ps[1], + self._network_list_v4)) + elif addr_family is socket.AF_INET6: + hi, lo = struct.unpack("!QQ", inet_pton(addr_family, addr)) + ip = (hi << 64) | lo + return any(map(lambda n_ps: n_ps[0] == ip >> n_ps[1], + self._network_list_v6)) + else: + return False + + +def test_inet_conv(): + ipv4 = b'8.8.4.4' + b = inet_pton(socket.AF_INET, ipv4) + assert inet_ntop(socket.AF_INET, b) == ipv4 + ipv6 = b'2404:6800:4005:805::1011' + b = inet_pton(socket.AF_INET6, ipv6) + assert inet_ntop(socket.AF_INET6, b) == ipv6 + + +def test_parse_header(): + assert parse_header(b'\x03\x0ewww.google.com\x00\x50') == \ + (3, b'www.google.com', 80, 18) + assert parse_header(b'\x01\x08\x08\x08\x08\x00\x35') == \ + (1, b'8.8.8.8', 53, 7) + assert parse_header((b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00' + b'\x00\x10\x11\x00\x50')) == \ + (4, b'2404:6800:4005:805::1011', 80, 19) + + +def test_pack_header(): + assert pack_addr(b'8.8.8.8') == b'\x01\x08\x08\x08\x08' + assert pack_addr(b'2404:6800:4005:805::1011') == \ + b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00\x00\x10\x11' + assert pack_addr(b'www.google.com') == b'\x03\x0ewww.google.com' + + +def test_ip_network(): + ip_network = IPNetwork('127.0.0.0/24,::ff:1/112,::1,192.168.1.1,192.0.2.0') + assert '127.0.0.1' in ip_network + assert '127.0.1.1' not in ip_network + assert ':ff:ffff' in ip_network + assert '::ffff:1' not in ip_network + assert '::1' in ip_network + assert '::2' not in ip_network + assert '192.168.1.1' in ip_network + assert '192.168.1.2' not in ip_network + assert '192.0.2.1' in ip_network + assert '192.0.3.1' in ip_network # 192.0.2.0 is treated as 192.0.2.0/23 + assert 'www.google.com' not in ip_network + + +if __name__ == '__main__': + test_inet_conv() + test_parse_header() + test_pack_header() + test_ip_network() diff --git a/crypto/__init__.py b/crypto/__init__.py new file mode 100644 index 0000000..401c7b7 --- /dev/null +++ b/crypto/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement diff --git a/crypto/openssl.py b/crypto/openssl.py new file mode 100644 index 0000000..3775b6c --- /dev/null +++ b/crypto/openssl.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +from ctypes import c_char_p, c_int, c_long, byref,\ + create_string_buffer, c_void_p + +from shadowsocks import common +from shadowsocks.crypto import util + +__all__ = ['ciphers'] + +libcrypto = None +loaded = False + +buf_size = 2048 + + +def load_openssl(): + global loaded, libcrypto, buf + + libcrypto = util.find_library(('crypto', 'eay32'), + 'EVP_get_cipherbyname', + 'libcrypto') + if libcrypto is None: + raise Exception('libcrypto(OpenSSL) not found') + + libcrypto.EVP_get_cipherbyname.restype = c_void_p + libcrypto.EVP_CIPHER_CTX_new.restype = c_void_p + + libcrypto.EVP_CipherInit_ex.argtypes = (c_void_p, c_void_p, c_char_p, + c_char_p, c_char_p, c_int) + + libcrypto.EVP_CipherUpdate.argtypes = (c_void_p, c_void_p, c_void_p, + c_char_p, c_int) + + libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,) + libcrypto.EVP_CIPHER_CTX_free.argtypes = (c_void_p,) + if hasattr(libcrypto, 'OpenSSL_add_all_ciphers'): + libcrypto.OpenSSL_add_all_ciphers() + + buf = create_string_buffer(buf_size) + loaded = True + + +def load_cipher(cipher_name): + func_name = 'EVP_' + cipher_name.replace('-', '_') + if bytes != str: + func_name = str(func_name, 'utf-8') + cipher = getattr(libcrypto, func_name, None) + if cipher: + cipher.restype = c_void_p + return cipher() + return None + + +class OpenSSLCrypto(object): + def __init__(self, cipher_name, key, iv, op): + self._ctx = None + if not loaded: + load_openssl() + cipher_name = common.to_bytes(cipher_name) + cipher = libcrypto.EVP_get_cipherbyname(cipher_name) + if not cipher: + cipher = load_cipher(cipher_name) + if not cipher: + raise Exception('cipher %s not found in libcrypto' % cipher_name) + key_ptr = c_char_p(key) + iv_ptr = c_char_p(iv) + self._ctx = libcrypto.EVP_CIPHER_CTX_new() + if not self._ctx: + raise Exception('can not create cipher context') + r = libcrypto.EVP_CipherInit_ex(self._ctx, cipher, None, + key_ptr, iv_ptr, c_int(op)) + if not r: + self.clean() + raise Exception('can not initialize cipher context') + + def update(self, data): + global buf_size, buf + cipher_out_len = c_long(0) + l = len(data) + if buf_size < l: + buf_size = l * 2 + buf = create_string_buffer(buf_size) + libcrypto.EVP_CipherUpdate(self._ctx, byref(buf), + byref(cipher_out_len), c_char_p(data), l) + # buf is copied to a str object when we access buf.raw + return buf.raw[:cipher_out_len.value] + + def __del__(self): + self.clean() + + def clean(self): + if self._ctx: + libcrypto.EVP_CIPHER_CTX_cleanup(self._ctx) + libcrypto.EVP_CIPHER_CTX_free(self._ctx) + + +ciphers = { + 'aes-128-cfb': (16, 16, OpenSSLCrypto), + 'aes-192-cfb': (24, 16, OpenSSLCrypto), + 'aes-256-cfb': (32, 16, OpenSSLCrypto), + 'aes-128-ofb': (16, 16, OpenSSLCrypto), + 'aes-192-ofb': (24, 16, OpenSSLCrypto), + 'aes-256-ofb': (32, 16, OpenSSLCrypto), + 'aes-128-ctr': (16, 16, OpenSSLCrypto), + 'aes-192-ctr': (24, 16, OpenSSLCrypto), + 'aes-256-ctr': (32, 16, OpenSSLCrypto), + 'aes-128-cfb8': (16, 16, OpenSSLCrypto), + 'aes-192-cfb8': (24, 16, OpenSSLCrypto), + 'aes-256-cfb8': (32, 16, OpenSSLCrypto), + 'aes-128-cfb1': (16, 16, OpenSSLCrypto), + 'aes-192-cfb1': (24, 16, OpenSSLCrypto), + 'aes-256-cfb1': (32, 16, OpenSSLCrypto), + 'bf-cfb': (16, 8, OpenSSLCrypto), + 'camellia-128-cfb': (16, 16, OpenSSLCrypto), + 'camellia-192-cfb': (24, 16, OpenSSLCrypto), + 'camellia-256-cfb': (32, 16, OpenSSLCrypto), + 'cast5-cfb': (16, 8, OpenSSLCrypto), + 'des-cfb': (8, 8, OpenSSLCrypto), + 'idea-cfb': (16, 8, OpenSSLCrypto), + 'rc2-cfb': (16, 8, OpenSSLCrypto), + 'rc4': (16, 0, OpenSSLCrypto), + 'seed-cfb': (16, 16, OpenSSLCrypto), +} + + +def run_method(method): + + cipher = OpenSSLCrypto(method, b'k' * 32, b'i' * 16, 1) + decipher = OpenSSLCrypto(method, b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +def test_aes_128_cfb(): + run_method('aes-128-cfb') + + +def test_aes_256_cfb(): + run_method('aes-256-cfb') + + +def test_aes_128_cfb8(): + run_method('aes-128-cfb8') + + +def test_aes_256_ofb(): + run_method('aes-256-ofb') + + +def test_aes_256_ctr(): + run_method('aes-256-ctr') + + +def test_bf_cfb(): + run_method('bf-cfb') + + +def test_rc4(): + run_method('rc4') + + +if __name__ == '__main__': + test_aes_128_cfb() diff --git a/crypto/rc4_md5.py b/crypto/rc4_md5.py new file mode 100644 index 0000000..1f07a82 --- /dev/null +++ b/crypto/rc4_md5.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import hashlib + +from shadowsocks.crypto import openssl + +__all__ = ['ciphers'] + + +def create_cipher(alg, key, iv, op, key_as_bytes=0, d=None, salt=None, + i=1, padding=1): + md5 = hashlib.md5() + md5.update(key) + md5.update(iv) + rc4_key = md5.digest() + return openssl.OpenSSLCrypto(b'rc4', rc4_key, b'', op) + + +ciphers = { + 'rc4-md5': (16, 16, create_cipher), +} + + +def test(): + from shadowsocks.crypto import util + + cipher = create_cipher('rc4-md5', b'k' * 32, b'i' * 16, 1) + decipher = create_cipher('rc4-md5', b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +if __name__ == '__main__': + test() diff --git a/crypto/sodium.py b/crypto/sodium.py new file mode 100644 index 0000000..ae86fef --- /dev/null +++ b/crypto/sodium.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +from ctypes import c_char_p, c_int, c_ulonglong, byref, \ + create_string_buffer, c_void_p + +from shadowsocks.crypto import util + +__all__ = ['ciphers'] + +libsodium = None +loaded = False + +buf_size = 2048 + +# for salsa20 and chacha20 +BLOCK_SIZE = 64 + + +def load_libsodium(): + global loaded, libsodium, buf + + libsodium = util.find_library('sodium', 'crypto_stream_salsa20_xor_ic', + 'libsodium') + if libsodium is None: + raise Exception('libsodium not found') + + libsodium.crypto_stream_salsa20_xor_ic.restype = c_int + libsodium.crypto_stream_salsa20_xor_ic.argtypes = (c_void_p, c_char_p, + c_ulonglong, + c_char_p, c_ulonglong, + c_char_p) + libsodium.crypto_stream_chacha20_xor_ic.restype = c_int + libsodium.crypto_stream_chacha20_xor_ic.argtypes = (c_void_p, c_char_p, + c_ulonglong, + c_char_p, c_ulonglong, + c_char_p) + + buf = create_string_buffer(buf_size) + loaded = True + + +class SodiumCrypto(object): + def __init__(self, cipher_name, key, iv, op): + if not loaded: + load_libsodium() + self.key = key + self.iv = iv + self.key_ptr = c_char_p(key) + self.iv_ptr = c_char_p(iv) + if cipher_name == 'salsa20': + self.cipher = libsodium.crypto_stream_salsa20_xor_ic + elif cipher_name == 'chacha20': + self.cipher = libsodium.crypto_stream_chacha20_xor_ic + else: + raise Exception('Unknown cipher') + # byte counter, not block counter + self.counter = 0 + + def update(self, data): + global buf_size, buf + l = len(data) + + # we can only prepend some padding to make the encryption align to + # blocks + padding = self.counter % BLOCK_SIZE + if buf_size < padding + l: + buf_size = (padding + l) * 2 + buf = create_string_buffer(buf_size) + + if padding: + data = (b'\0' * padding) + data + self.cipher(byref(buf), c_char_p(data), padding + l, + self.iv_ptr, int(self.counter / BLOCK_SIZE), self.key_ptr) + self.counter += l + # buf is copied to a str object when we access buf.raw + # strip off the padding + return buf.raw[padding:padding + l] + + +ciphers = { + 'salsa20': (32, 8, SodiumCrypto), + 'chacha20': (32, 8, SodiumCrypto), +} + + +def test_salsa20(): + cipher = SodiumCrypto('salsa20', b'k' * 32, b'i' * 16, 1) + decipher = SodiumCrypto('salsa20', b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +def test_chacha20(): + + cipher = SodiumCrypto('chacha20', b'k' * 32, b'i' * 16, 1) + decipher = SodiumCrypto('chacha20', b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +if __name__ == '__main__': + test_chacha20() + test_salsa20() diff --git a/crypto/table.py b/crypto/table.py new file mode 100644 index 0000000..bc693f5 --- /dev/null +++ b/crypto/table.py @@ -0,0 +1,174 @@ +# !/usr/bin/env python +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import string +import struct +import hashlib + + +__all__ = ['ciphers'] + +cached_tables = {} + +if hasattr(string, 'maketrans'): + maketrans = string.maketrans + translate = string.translate +else: + maketrans = bytes.maketrans + translate = bytes.translate + + +def get_table(key): + m = hashlib.md5() + m.update(key) + s = m.digest() + a, b = struct.unpack(' 0: + data = m[i - 1] + password + md5.update(data) + m.append(md5.digest()) + i += 1 + ms = b''.join(m) + key = ms[:key_len] + iv = ms[key_len:key_len + iv_len] + cached_keys[cached_key] = (key, iv) + return key, iv + + +class Encryptor(object): + def __init__(self, key, method): + self.key = key + self.method = method + self.iv = None + self.iv_sent = False + self.cipher_iv = b'' + self.decipher = None + method = method.lower() + self._method_info = self.get_method_info(method) + if self._method_info: + self.cipher = self.get_cipher(key, method, 1, + random_string(self._method_info[1])) + else: + logging.error('method %s not supported' % method) + sys.exit(1) + + def get_method_info(self, method): + method = method.lower() + m = method_supported.get(method) + return m + + def iv_len(self): + return len(self.cipher_iv) + + def get_cipher(self, password, method, op, iv): + password = common.to_bytes(password) + m = self._method_info + if m[0] > 0: + key, iv_ = EVP_BytesToKey(password, m[0], m[1]) + else: + # key_length == 0 indicates we should use the key directly + key, iv = password, b'' + + iv = iv[:m[1]] + if op == 1: + # this iv is for cipher not decipher + self.cipher_iv = iv[:m[1]] + return m[2](method, key, iv, op) + + def encrypt(self, buf): + if len(buf) == 0: + return buf + if self.iv_sent: + return self.cipher.update(buf) + else: + self.iv_sent = True + return self.cipher_iv + self.cipher.update(buf) + + def decrypt(self, buf): + if len(buf) == 0: + return buf + if self.decipher is None: + decipher_iv_len = self._method_info[1] + decipher_iv = buf[:decipher_iv_len] + self.decipher = self.get_cipher(self.key, self.method, 0, + iv=decipher_iv) + buf = buf[decipher_iv_len:] + if len(buf) == 0: + return buf + return self.decipher.update(buf) + + +def encrypt_all(password, method, op, data): + result = [] + method = method.lower() + (key_len, iv_len, m) = method_supported[method] + if key_len > 0: + key, _ = EVP_BytesToKey(password, key_len, iv_len) + else: + key = password + if op: + iv = random_string(iv_len) + result.append(iv) + else: + iv = data[:iv_len] + data = data[iv_len:] + cipher = m(method, key, iv, op) + result.append(cipher.update(data)) + return b''.join(result) + + +CIPHERS_TO_TEST = [ + 'aes-128-cfb', + 'aes-256-cfb', + 'rc4-md5', + 'salsa20', + 'chacha20', + 'table', +] + + +def test_encryptor(): + from os import urandom + plain = urandom(10240) + for method in CIPHERS_TO_TEST: + logging.warn(method) + encryptor = Encryptor(b'key', method) + decryptor = Encryptor(b'key', method) + cipher = encryptor.encrypt(plain) + plain2 = decryptor.decrypt(cipher) + assert plain == plain2 + + +def test_encrypt_all(): + from os import urandom + plain = urandom(10240) + for method in CIPHERS_TO_TEST: + logging.warn(method) + cipher = encrypt_all(b'key', method, 1, plain) + plain2 = encrypt_all(b'key', method, 0, cipher) + assert plain == plain2 + + +if __name__ == '__main__': + test_encrypt_all() + test_encryptor() diff --git a/eventloop.py b/eventloop.py new file mode 100644 index 0000000..36bca4e --- /dev/null +++ b/eventloop.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import logging +import time +import traceback +import errno +from shadowsocks import selectors +from shadowsocks.selectors import (EVENT_READ, EVENT_WRITE, EVENT_ERROR, + errno_from_exception) + + +POLL_IN = EVENT_READ +POLL_OUT = EVENT_WRITE +POLL_ERR = EVENT_ERROR + +TIMEOUT_PRECISION = 10 + + +class EventLoop: + + def __init__(self): + self._selector = selectors.DefaultSelector() + self._stopping = False + self._last_time = time.time() + self._periodic_callbacks = [] + + def poll(self, timeout=None): + return self._selector.select(timeout) + + def add(self, sock, events, data): + events |= selectors.EVENT_ERROR + return self._selector.register(sock, events, data) + + def remove(self, sock): + try: + return self._selector.unregister(sock) + except KeyError: + pass + + def modify(self, sock, events, data): + events |= selectors.EVENT_ERROR + try: + key = self._selector.modify(sock, events, data) + except KeyError: + key = self.add(sock, events, data) + return key + + def add_periodic(self, callback): + self._periodic_callbacks.append(callback) + + def remove_periodic(self, callback): + self._periodic_callbacks.remove(callback) + + def fd_count(self): + return len(self._selector.get_map()) + + def run(self): + logging.debug('Starting event loop') + + while not self._stopping: + asap = False + try: + events = self.poll(timeout=TIMEOUT_PRECISION) + except (OSError, IOError) as e: + if errno_from_exception(e) in (errno.EPIPE, errno.EINTR): + # EPIPE: Happens when the client closes the connection + # EINTR: Happens when received a signal + # handles them as soon as possible + asap = True + logging.debug('poll: %s', e) + else: + logging.error('poll: %s', e) + traceback.print_exc() + continue + + for key, event in events: + if type(key.data) == tuple: + handler = key.data[0] + args = key.data[1:] + else: + handler = key.data + args = () + + sock = key.fileobj + if hasattr(handler, 'handle_event'): + handler = handler.handle_event + + try: + handler(sock, event, *args) + except Exception as e: + logging.debug(e) + traceback.print_exc() + raise + + now = time.time() + if asap or now - self._last_time >= TIMEOUT_PRECISION: + for callback in self._periodic_callbacks: + callback() + self._last_time = now + + logging.debug('Got {} fds registered'.format(self.fd_count())) + + logging.debug('Stopping event loop') + self._selector.close() + + def stop(self): + self._stopping = True diff --git a/lru_cache.py b/lru_cache.py new file mode 100644 index 0000000..401f19b --- /dev/null +++ b/lru_cache.py @@ -0,0 +1,150 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import collections +import logging +import time + + +# this LRUCache is optimized for concurrency, not QPS +# n: concurrency, keys stored in the cache +# m: visits not timed out, proportional to QPS * timeout +# get & set is O(1), not O(n). thus we can support very large n +# TODO: if timeout or QPS is too large, then this cache is not very efficient, +# as sweep() causes long pause + + +class LRUCache(collections.MutableMapping): + """This class is not thread safe""" + + def __init__(self, timeout=60, close_callback=None, *args, **kwargs): + self.timeout = timeout + self.close_callback = close_callback + self._store = {} + self._time_to_keys = collections.defaultdict(list) + self._keys_to_last_time = {} + self._last_visits = collections.deque() + self._closed_values = set() + self.update(dict(*args, **kwargs)) # use the free update to set keys + + def __getitem__(self, key): + # O(1) + t = time.time() + self._keys_to_last_time[key] = t + self._time_to_keys[t].append(key) + self._last_visits.append(t) + return self._store[key] + + def __setitem__(self, key, value): + # O(1) + t = time.time() + self._keys_to_last_time[key] = t + self._store[key] = value + self._time_to_keys[t].append(key) + self._last_visits.append(t) + + def __delitem__(self, key): + # O(1) + del self._store[key] + del self._keys_to_last_time[key] + + def __iter__(self): + return iter(self._store) + + def __len__(self): + return len(self._store) + + def sweep(self): + # O(m) + now = time.time() + c = 0 + while len(self._last_visits) > 0: + least = self._last_visits[0] + if now - least <= self.timeout: + break + if self.close_callback is not None: + for key in self._time_to_keys[least]: + if key in self._store: + if now - self._keys_to_last_time[key] > self.timeout: + value = self._store[key] + if value not in self._closed_values: + self.close_callback(value) + self._closed_values.add(value) + for key in self._time_to_keys[least]: + self._last_visits.popleft() + if key in self._store: + if now - self._keys_to_last_time[key] > self.timeout: + del self._store[key] + del self._keys_to_last_time[key] + c += 1 + del self._time_to_keys[least] + if c: + self._closed_values.clear() + logging.debug('%d keys swept' % c) + + +def test(): + c = LRUCache(timeout=0.3) + + c['a'] = 1 + assert c['a'] == 1 + + time.sleep(0.5) + c.sweep() + assert 'a' not in c + + c['a'] = 2 + c['b'] = 3 + time.sleep(0.2) + c.sweep() + assert c['a'] == 2 + assert c['b'] == 3 + + time.sleep(0.2) + c.sweep() + c['b'] + time.sleep(0.2) + c.sweep() + assert 'a' not in c + assert c['b'] == 3 + + time.sleep(0.5) + c.sweep() + assert 'a' not in c + assert 'b' not in c + + global close_cb_called + close_cb_called = False + + def close_cb(t): + global close_cb_called + assert not close_cb_called + close_cb_called = True + + c = LRUCache(timeout=0.1, close_callback=close_cb) + c['s'] = 1 + c['s'] + time.sleep(0.1) + c['s'] + time.sleep(0.3) + c.sweep() + +if __name__ == '__main__': + test() diff --git a/selectors.py b/selectors.py new file mode 100644 index 0000000..d95251e --- /dev/null +++ b/selectors.py @@ -0,0 +1,652 @@ +"""Selectors module. + +This module allows high-level and efficient I/O multiplexing, built upon the +`select` module primitives. +""" + + +from abc import ABCMeta, abstractmethod +from collections import namedtuple, Mapping +import errno +import math +import os +import select +import socket +import sys + + +# generic events, that must be mapped to implementation-specific ones +EVENT_READ = 0x001 +EVENT_WRITE = 0x004 +EVENT_ERROR = 0x018 + + +def errno_from_exception(e): + """Provides the errno from an Exception object. + + There are cases that the errno attribute was not set so we pull + the errno out of the args but if someone instatiates an Exception + without any args you will get a tuple error. So this function + abstracts all that behavior to give you a safe way to get the + errno. + """ + + if hasattr(e, 'errno'): + return e.errno + elif e.args: + return e.args[0] + else: + return None + + +def get_sock_error(sock): + if not sock: + return + error_number = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + return socket.error(error_number, os.strerror(error_number)) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + corresponding file descriptor + + Raises: + ValueError if the object is invalid + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, TypeError, ValueError): + raise ValueError("Invalid file object: " + "{!r}".format(fileobj)) + if fd < 0: + raise ValueError("Invalid file descriptor: {}".format(fd)) + return fd + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) +"""Object used to associate a file object to its backing file descriptor, +selected event mask and attached data.""" + + +class _SelectorMapping(Mapping): + """Mapping of file objects to selector keys.""" + + def __init__(self, selector): + self._selector = selector + + def __len__(self): + return len(self._selector._fd_to_key) + + def __getitem__(self, fileobj): + try: + fd = self._selector._fileobj_lookup(fileobj) + return self._selector._fd_to_key[fd] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) + + def __iter__(self): + return iter(self._selector._fd_to_key) + + +class BaseSelector(object): + __metaclass__ = ABCMeta + """Selector abstract base class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + efficient implementation on the current platform. + """ + + @abstractmethod + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + ValueError if events is invalid + KeyError if fileobj is already registered + OSError if fileobj is closed or otherwise is unacceptable to + the underlying system call (if a system call is made) + + Note: + OSError may or may not be raised + """ + raise NotImplementedError + + @abstractmethod + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + SelectorKey instance + + Raises: + KeyError if fileobj is not registered + + Note: + If fileobj is registered but has since been closed this does + *not* raise OSError (even if the wrapped syscall does) + """ + raise NotImplementedError + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + Anything that unregister() or register() raises + """ + self.unregister(fileobj) + return self.register(fileobj, events, data) + + @abstractmethod + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout <= 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (key, events) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + pass + + def get_key(self, fileobj): + """Return the key associated to a registered file object. + + Returns: + SelectorKey for this file object + """ + mapping = self.get_map() + try: + if mapping is None: + raise KeyError + return mapping[fileobj] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) + + @abstractmethod + def get_map(self): + """Return a mapping of file objects to selector keys.""" + raise NotImplementedError + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +class _BaseSelectorImpl(BaseSelector): + """Base selector implementation.""" + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # read-only mapping returned by get_map() + self._map = _SelectorMapping(self) + + def _fileobj_lookup(self, fileobj): + """Return a file descriptor from a file object. + + This wraps _fileobj_to_fd() to do an exhaustive search in case + the object is invalid but we still have it in our map. This + is used by unregister() so we can unregister an object that + was previously registered even if it is closed. It is also + used by _SelectorMapping. + """ + try: + return _fileobj_to_fd(fileobj) + except ValueError: + # Do an exhaustive search. + for key in self._fd_to_key.values(): + if key.fileobj is fileobj: + return key.fd + # Raise ValueError after all. + raise + + def register(self, fileobj, events, data=None): + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE | + EVENT_ERROR)): + raise ValueError("Invalid events: {!r}".format(events)) + + key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{!r} (FD {}) is already registered" + .format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + return key + + def unregister(self, fileobj): + try: + key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fd_to_key[self._fileobj_lookup(fileobj)] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) + if events != key.events: + self.unregister(fileobj) + key = self.register(fileobj, events, data) + elif data != key.data: + # Use a shortcut to update the data. + key = key._replace(data=data) + self._fd_to_key[key.fd] = key + return key + + def close(self): + self._fd_to_key.clear() + self._map = None + + def get_map(self): + return self._map + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key, or None if not found + """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + +class SelectSelector(_BaseSelectorImpl): + """Select-based selector.""" + + def __init__(self): + super(self.__class__, self).__init__() + self._readers = set() + self._writers = set() + self._errors = set() + + def register(self, fileobj, events, data=None): + key = super(self.__class__, self).register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + if events & EVENT_ERROR: + self._errors.add(key.fd) + return key + + def unregister(self, fileobj): + key = super(self.__class__, self).unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + self._errors.discard(key.fd) + return key + + if sys.platform == 'win32': + def _select(self, r, w, x, timeout=None): + r, w, x = select.select(r, w, x, timeout) + return r, w, x + else: + _select = select.select + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + ready = [] + try: + r, w, x = self._select(self._readers, self._writers, self._errors, + timeout) + except OSError as e: + if errno_from_exception(e) == errno.EAGAIN: + return ready + r = set(r) + w = set(w) + x = set(x) + for fd in r | w | x: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + if fd in x: + events |= EVENT_ERROR + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'poll'): + + class PollSelector(_BaseSelectorImpl): + """Poll-based selector.""" + + def __init__(self): + super(self.__class__, self).__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super(self.__class__, self).register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + if events & EVENT_ERROR: + poll_events |= select.POLLERR | select.POLLHUP + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super(self.__class__, self).unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # poll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except OSError as e: + if errno_from_exception(e) == errno.EAGAIN: + return ready + for fd, event in fd_event_list: + events = 0 + if event & select.POLLIN: + events |= EVENT_READ + if event & select.POLLOUT: + events |= EVENT_WRITE + if event & (select.POLLERR | select.POLLHUP): + events |= EVENT_ERROR + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'epoll'): + + class EpollSelector(_BaseSelectorImpl): + """Epoll-based selector.""" + + def __init__(self): + super(self.__class__, self).__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super(self.__class__, self).register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= select.EPOLLIN + if events & EVENT_WRITE: + epoll_events |= select.EPOLLOUT + if events & EVENT_ERROR: + epoll_events |= select.EPOLLERR | select.EPOLLHUP + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super(self.__class__, self).unregister(fileobj) + try: + self._epoll.unregister(key.fd) + except OSError: + # This can happen if the FD was closed since it + # was registered. + pass + return key + + def select(self, timeout=None): + if timeout is None: + timeout = -1 + elif timeout <= 0: + timeout = 0 + else: + # epoll_wait() has a resolution of 1 millisecond, round away + # from zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) * 1e-3 + + # epoll_wait() expects `maxevents` to be greater than zero; + # we want to make sure that `select()` can be called when no + # FD is registered. + max_ev = max(len(self._fd_to_key), 1) + + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except OSError as e: + if errno_from_exception(e) == errno.EAGAIN: + return ready + for fd, event in fd_event_list: + events = 0 + if event & select.EPOLLIN: + events |= EVENT_READ + if event & select.EPOLLOUT: + events |= EVENT_WRITE + if event & (select.EPOLLERR | select.EPOLLHUP): + events |= EVENT_ERROR + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + try: + self._epoll.close() + finally: + super(self.__class__, self).close() + + +if hasattr(select, 'devpoll'): + + class DevpollSelector(_BaseSelectorImpl): + """Solaris /dev/poll selector.""" + + def __init__(self): + super().__init__() + self._devpoll = select.devpoll() + + def fileno(self): + return self._devpoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + if events & EVENT_ERROR: + poll_events |= select.POLLERR | select.POLLHUP + self._devpoll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._devpoll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # devpoll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + ready = [] + try: + fd_event_list = self._devpoll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & select.POLLIN: + events |= EVENT_READ + if event & select.POLLOUT: + events |= EVENT_WRITE + if event & (select.POLLERR | select.POLLHUP): + events |= EVENT_ERROR + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._devpoll.close() + super().close() + + +if hasattr(select, 'kqueue'): + + class KqueueSelector(_BaseSelectorImpl): + """Kqueue-based selector.""" + + def __init__(self): + super(self.__class__, self).__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super(self.__class__, self).register(fileobj, events, data) + if events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def unregister(self, fileobj): + key = super(self.__class__, self).unregister(fileobj) + if key.events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # This can happen if the FD was closed since it + # was registered. + pass + if key.events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # See comment above. + pass + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except OSError as e: + if errno_from_exception(e) == errno.EAGAIN: + return ready + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == select.KQ_FILTER_READ: + events |= EVENT_READ + if flag == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + try: + self._kqueue.close() + finally: + super(self.__class__, self).close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'DevpollSelector' in globals(): + DefaultSelector = DevpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/server.py b/server.py new file mode 100644 index 0000000..1d99c8d --- /dev/null +++ b/server.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import logging + +path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, path) +from shadowsocks.eventloop import EventLoop +from shadowsocks.tcprelay import TcpRelay, TcpRelayServerHandler +from shadowsocks.asyncdns import DNSResolver + + +FORMATTER = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' +LOGGING_LEVEL = logging.INFO +logging.basicConfig(level=LOGGING_LEVEL, format=FORMATTER) + +LISTEN_ADDR = ('0.0.0.0', 9000) + + +def main(): + loop = EventLoop() + dns_resolver = DNSResolver() + relay = TcpRelay(TcpRelayServerHandler, LISTEN_ADDR, + dns_resolver=dns_resolver) + dns_resolver.add_to_loop(loop) + relay.add_to_loop(loop) + loop.run() + +if __name__ == '__main__': + main() diff --git a/shell.py b/shell.py new file mode 100644 index 0000000..c91fc22 --- /dev/null +++ b/shell.py @@ -0,0 +1,365 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import json +import sys +import getopt +import logging +from shadowsocks.common import to_bytes, to_str, IPNetwork +from shadowsocks import encrypt + + +VERBOSE_LEVEL = 5 + +verbose = 0 + + +def check_python(): + info = sys.version_info + if info[0] == 2 and not info[1] >= 6: + print('Python 2.6+ required') + sys.exit(1) + elif info[0] == 3 and not info[1] >= 3: + print('Python 3.3+ required') + sys.exit(1) + elif info[0] not in [2, 3]: + print('Python version not supported') + sys.exit(1) + + +def print_exception(e): + global verbose + logging.error(e) + if verbose > 0: + import traceback + traceback.print_exc() + + +def print_shadowsocks(): + version = '' + try: + import pkg_resources + version = pkg_resources.get_distribution('shadowsocks').version + except Exception: + pass + print('Shadowsocks %s' % version) + + +def find_config(): + config_path = 'config.json' + if os.path.exists(config_path): + return config_path + config_path = os.path.join(os.path.dirname(__file__), '../', 'config.json') + if os.path.exists(config_path): + return config_path + return None + + +def check_config(config, is_local): + if config.get('daemon', None) == 'stop': + # no need to specify configuration for daemon stop + return + + if is_local and not config.get('password', None): + logging.error('password not specified') + print_help(is_local) + sys.exit(2) + + if not is_local and not config.get('password', None) \ + and not config.get('port_password', None): + logging.error('password or port_password not specified') + print_help(is_local) + sys.exit(2) + + if 'local_port' in config: + config['local_port'] = int(config['local_port']) + + if 'server_port' in config and type(config['server_port']) != list: + config['server_port'] = int(config['server_port']) + + if config.get('local_address', '') in [b'0.0.0.0']: + logging.warn('warning: local set to listen on 0.0.0.0, it\'s not safe') + if config.get('server', '') in ['127.0.0.1', 'localhost']: + logging.warn('warning: server set to listen on %s:%s, are you sure?' % + (to_str(config['server']), config['server_port'])) + if (config.get('method', '') or '').lower() == 'table': + logging.warn('warning: table is not safe; please use a safer cipher, ' + 'like AES-256-CFB') + if (config.get('method', '') or '').lower() == 'rc4': + logging.warn('warning: RC4 is not safe; please use a safer cipher, ' + 'like AES-256-CFB') + if config.get('timeout', 300) < 100: + logging.warn('warning: your timeout %d seems too short' % + int(config.get('timeout'))) + if config.get('timeout', 300) > 600: + logging.warn('warning: your timeout %d seems too long' % + int(config.get('timeout'))) + if config.get('password') in [b'mypassword']: + logging.error('DON\'T USE DEFAULT PASSWORD! Please change it in your ' + 'config.json!') + sys.exit(1) + if config.get('user', None) is not None: + if os.name != 'posix': + logging.error('user can be used only on Unix') + sys.exit(1) + + encrypt.try_cipher(config['password'], config['method']) + + +def get_config(is_local): + global verbose + + logging.basicConfig(level=logging.INFO, + format='%(levelname)-s: %(message)s') + if is_local: + shortopts = 'hd:s:b:p:k:l:m:c:t:vq' + longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'user=', + 'version'] + else: + shortopts = 'hd:s:p:k:m:c:t:vq' + longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=', + 'forbidden-ip=', 'user=', 'manager-address=', 'version'] + try: + config_path = find_config() + optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts) + for key, value in optlist: + if key == '-c': + config_path = value + + if config_path: + logging.info('loading config from %s' % config_path) + with open(config_path, 'rb') as f: + try: + config = parse_json_in_str(f.read().decode('utf8')) + except ValueError as e: + logging.error('found an error in config.json: %s', + e.message) + sys.exit(1) + else: + config = {} + + v_count = 0 + for key, value in optlist: + if key == '-p': + config['server_port'] = int(value) + elif key == '-k': + config['password'] = to_bytes(value) + elif key == '-l': + config['local_port'] = int(value) + elif key == '-s': + config['server'] = to_str(value) + elif key == '-m': + config['method'] = to_str(value) + elif key == '-b': + config['local_address'] = to_str(value) + elif key == '-v': + v_count += 1 + # '-vv' turns on more verbose mode + config['verbose'] = v_count + elif key == '-t': + config['timeout'] = int(value) + elif key == '--fast-open': + config['fast_open'] = True + elif key == '--workers': + config['workers'] = int(value) + elif key == '--manager-address': + config['manager_address'] = value + elif key == '--user': + config['user'] = to_str(value) + elif key == '--forbidden-ip': + config['forbidden_ip'] = to_str(value).split(',') + elif key in ('-h', '--help'): + if is_local: + print_local_help() + else: + print_server_help() + sys.exit(0) + elif key == '--version': + print_shadowsocks() + sys.exit(0) + elif key == '-d': + config['daemon'] = to_str(value) + elif key == '--pid-file': + config['pid-file'] = to_str(value) + elif key == '--log-file': + config['log-file'] = to_str(value) + elif key == '-q': + v_count -= 1 + config['verbose'] = v_count + except getopt.GetoptError as e: + print(e, file=sys.stderr) + print_help(is_local) + sys.exit(2) + + if not config: + logging.error('config not specified') + print_help(is_local) + sys.exit(2) + + config['password'] = to_bytes(config.get('password', b'')) + config['method'] = to_str(config.get('method', 'aes-256-cfb')) + config['port_password'] = config.get('port_password', None) + config['timeout'] = int(config.get('timeout', 300)) + config['fast_open'] = config.get('fast_open', False) + config['workers'] = config.get('workers', 1) + config['pid-file'] = config.get('pid-file', '/var/run/shadowsocks.pid') + config['log-file'] = config.get('log-file', '/var/log/shadowsocks.log') + config['verbose'] = config.get('verbose', False) + config['local_address'] = to_str(config.get('local_address', '127.0.0.1')) + config['local_port'] = config.get('local_port', 1080) + if is_local: + if config.get('server', None) is None: + logging.error('server addr not specified') + print_local_help() + sys.exit(2) + else: + config['server'] = to_str(config['server']) + else: + config['server'] = to_str(config.get('server', '0.0.0.0')) + try: + config['forbidden_ip'] = \ + IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128')) + except Exception as e: + logging.error(e) + sys.exit(2) + config['server_port'] = config.get('server_port', 8388) + + logging.getLogger('').handlers = [] + logging.addLevelName(VERBOSE_LEVEL, 'VERBOSE') + if config['verbose'] >= 2: + level = VERBOSE_LEVEL + elif config['verbose'] == 1: + level = logging.DEBUG + elif config['verbose'] == -1: + level = logging.WARN + elif config['verbose'] <= -2: + level = logging.ERROR + else: + level = logging.INFO + verbose = config['verbose'] + logging.basicConfig(level=level, + format='%(asctime)s %(levelname)-8s %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + + check_config(config, is_local) + + return config + + +def print_help(is_local): + if is_local: + print_local_help() + else: + print_server_help() + + +def print_local_help(): + print('''usage: sslocal [OPTION]... +A fast tunnel proxy that helps you bypass firewalls. + +You can supply configurations via either config file or command line arguments. + +Proxy options: + -c CONFIG path to config file + -s SERVER_ADDR server address + -p SERVER_PORT server port, default: 8388 + -b LOCAL_ADDR local binding address, default: 127.0.0.1 + -l LOCAL_PORT local port, default: 1080 + -k PASSWORD password + -m METHOD encryption method, default: aes-256-cfb + -t TIMEOUT timeout in seconds, default: 300 + --fast-open use TCP_FASTOPEN, requires Linux 3.7+ + +General options: + -h, --help show this help message and exit + -d start/stop/restart daemon mode + --pid-file PID_FILE pid file for daemon mode + --log-file LOG_FILE log file for daemon mode + --user USER username to run as + -v, -vv verbose mode + -q, -qq quiet mode, only show warnings/errors + --version show version information + +Online help: +''') + + +def print_server_help(): + print('''usage: ssserver [OPTION]... +A fast tunnel proxy that helps you bypass firewalls. + +You can supply configurations via either config file or command line arguments. + +Proxy options: + -c CONFIG path to config file + -s SERVER_ADDR server address, default: 0.0.0.0 + -p SERVER_PORT server port, default: 8388 + -k PASSWORD password + -m METHOD encryption method, default: aes-256-cfb + -t TIMEOUT timeout in seconds, default: 300 + --fast-open use TCP_FASTOPEN, requires Linux 3.7+ + --workers WORKERS number of workers, available on Unix/Linux + --forbidden-ip IPLIST comma seperated IP list forbidden to connect + --manager-address ADDR optional server manager UDP address, see wiki + +General options: + -h, --help show this help message and exit + -d start/stop/restart daemon mode + --pid-file PID_FILE pid file for daemon mode + --log-file LOG_FILE log file for daemon mode + --user USER username to run as + -v, -vv verbose mode + -q, -qq quiet mode, only show warnings/errors + --version show version information + +Online help: +''') + + +def _decode_list(data): + rv = [] + for item in data: + if hasattr(item, 'encode'): + item = item.encode('utf-8') + elif isinstance(item, list): + item = _decode_list(item) + elif isinstance(item, dict): + item = _decode_dict(item) + rv.append(item) + return rv + + +def _decode_dict(data): + rv = {} + for key, value in data.items(): + if hasattr(value, 'encode'): + value = value.encode('utf-8') + elif isinstance(value, list): + value = _decode_list(value) + elif isinstance(value, dict): + value = _decode_dict(value) + rv[key] = value + return rv + + +def parse_json_in_str(data): + # parse json and convert everything from unicode to str + return json.loads(data, object_hook=_decode_dict) diff --git a/tcprelay.py b/tcprelay.py new file mode 100644 index 0000000..25623ec --- /dev/null +++ b/tcprelay.py @@ -0,0 +1,387 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import errno +import logging +import socket + +from shadowsocks.selectors import (EVENT_READ, EVENT_WRITE, EVENT_ERROR, + errno_from_exception, get_sock_error) +from shadowsocks.common import parse_header, to_str +from shadowsocks import encrypt + + +BUF_SIZE = 32 * 1024 +CMD_CONNECT = 1 + + +def create_sock(ip, port): + addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, + socket.SOL_TCP) + if len(addrs) == 0: + raise Exception("Getaddrinfo failed for %s:%d" % (ip, port)) + + af, socktype, proto, canonname, sa = addrs[0] + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) + return sock + + +class TcpRelayHanler(object): + + def __init__(self, local_sock, local_addr, remote_addr=None, + dns_resolver=None): + self._loop = None + self._local_sock = local_sock + self._local_addr = local_addr + self._remote_addr = remote_addr + # self._crypt = None + self._crypt = encrypt.Encryptor(b'PassThrouthGFW', 'aes-256-cfb') + self._dns_resolver = dns_resolver + self._remote_sock = None + self._local_sock_mode = 0 + self._remote_sock_mode = 0 + self._data_to_write_to_local = [] + self._data_to_write_to_remote = [] + self._id = id(self) + + def handle_event(self, sock, event, call, *args): + if event & EVENT_ERROR: + logging.error(get_sock_error(sock)) + self.close() + else: + try: + call(sock, event, *args) + except Exception as e: + logging.error(e) + self.close() + + def __del__(self): + logging.debug('Deleting {}'.format(self._id)) + + def add_to_loop(self, loop): + if self._loop: + raise Exception('Already added to loop') + self._loop = loop + loop.add(self._local_sock, EVENT_READ, (self, self.start)) + + def modify_local_sock_mode(self, event): + if self._local_sock_mode != event: + self._local_sock_mode = self.modify_sock_mode(self._local_sock, + event) + + def modify_remote_sock_mode(self, event): + if self._remote_sock_mode != event: + self._remote_sock_mode = self.modify_sock_mode(self._remote_sock, + event) + + def modify_sock_mode(self, sock, event): + key = self._loop.modify(sock, event, (self, self.stream)) + return key.events + + def close_sock(self, sock): + self._loop.remove(sock) + sock.close() + + def close(self): + if self._local_sock: + self.close_sock(self._local_sock) + self._local_sock = None + if self._remote_sock: + self.close_sock(self._remote_sock) + self._remote_sock = None + + def sock_connect(self, sock, addr): + while True: + try: + sock.connect(addr) + except (OSError, IOError) as e: + err = errno_from_exception(e) + if err == errno.EINTR: + pass + elif err == errno.EINPROGRESS: + break + else: + raise + else: + break + + def sock_recv(self, sock, size=BUF_SIZE): + try: + data = sock.recv(size) + if not data: + self.close() + except (OSError, IOError) as e: + if errno_from_exception(e) in (errno.EAGAIN, errno.EWOULDBLOCK, + errno.EINTR): + return + else: + raise + return data + + def sock_send(self, sock, data): + try: + s = sock.send(data) + data = data[s:] + except (OSError, IOError) as e: + if errno_from_exception(e) in (errno.EAGAIN, errno.EWOULDBLOCK, + errno.EINPROGRESS, errno.EINTR): + pass + else: + raise + + return data + + def on_local_read(self, size=BUF_SIZE): + logging.debug('on_local_read') + if not self._local_sock: + return + + data = self.sock_recv(self._local_sock, size) + if not data: + return + + logging.debug('Received {} bytes from {}:{}'.format(len(data), + *self._local_addr)) + if self._crypt: + if self._is_client: + data = self._crypt.encrypt(data) + else: + data = self._crypt.decrypt(data) + + if data: + self._data_to_write_to_remote.append(data) + self.on_remote_write() + + def on_remote_read(self, size=BUF_SIZE): + logging.debug('on_remote_read') + if not self._remote_sock: + return + + data = self.sock_recv(self._remote_sock, size) + if not data: + return + + logging.debug('Received {} bytes from {}:{}'.format( + len(data), *self._remote_addr)) + + if self._crypt: + if self._is_client: + data = self._crypt.decrypt(data) + else: + data = self._crypt.encrypt(data) + + if data: + self._data_to_write_to_local.append(data) + self.on_local_write() + + def on_local_write(self): + logging.debug('on_local_write') + if not self._local_sock: + return + + if not self._data_to_write_to_local: + self.modify_local_sock_mode(EVENT_READ) + return + + data = b''.join(self._data_to_write_to_local) + self._data_to_write_to_local = [] + + data = self.sock_send(self._local_sock, data) + + if data: + self._data_to_write_to_local.append(data) + self.modify_local_sock_mode(EVENT_WRITE) + else: + self.modify_local_sock_mode(EVENT_READ) + + def on_remote_write(self): + logging.debug('on_remote_write') + if not self._remote_sock: + return + + if not self._data_to_write_to_remote: + self.modify_remote_sock_mode(EVENT_READ) + return + + data = b''.join(self._data_to_write_to_remote) + self._data_to_write_to_remote = [] + + data = self.sock_send(self._remote_sock, data) + + if data: + self._data_to_write_to_remote.append(data) + self.modify_remote_sock_mode(EVENT_WRITE) + else: + self.modify_remote_sock_mode(EVENT_READ) + + def stream(self, sock, event): + logging.debug('stream') + + if sock == self._local_sock: + if event & EVENT_READ: + self.on_local_read() + if event & EVENT_WRITE: + self.on_local_write() + elif sock == self._remote_sock: + if event & EVENT_READ: + self.on_remote_read() + if event & EVENT_WRITE: + self.on_remote_write() + else: + logging.warn('Unknow sock {}'.format(sock)) + + +class TcpRelayClientHanler(TcpRelayHanler): + + _is_client = True + + def start(self, sock, event): + data = self.sock_recv(sock) + if not data: + return + reply = b'\x05\x00' + self.send_reply(sock, None, reply) + + def send_reply(self, sock, event, data): + data = self.sock_send(sock, data) + if data: + self._loop.modify(sock, EVENT_WRITE, (self, self.send_reply, data)) + else: + self._loop.modify(sock, EVENT_READ, (self, self.handle_addr)) + + def handle_addr(self, sock, event): + data = self.sock_recv(sock) + if not data: + return + # self._loop.remove(sock) + + if ord(data[1:2]) != CMD_CONNECT: + raise Exception('Command not suppored') + + result = parse_header(data[3:]) + if not result: + raise Exception('Header cannot be parsed') + + self._remote_sock = create_sock(*self._remote_addr) + self.sock_connect(self._remote_sock, self._remote_addr) + + dest_addr = (to_str(result[1]), result[2]) + logging.info('Connecting to {}:{}'.format(*dest_addr)) + data = '{}:{}\n'.format(*dest_addr).encode('utf-8') + if self._crypt: + data = self._crypt.encrypt(data) + self._data_to_write_to_remote.append(data) + + bind_addr = b'\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00' + self.send_bind_addr(sock, None, bind_addr) + + def send_bind_addr(self, sock, event, data): + data = self.sock_send(sock, data) + if data: + self._loop.modify(sock, EVENT_WRITE, (self, self.send_bind_addr, + data)) + else: + self.modify_local_sock_mode(EVENT_READ) + + +class TcpRelayServerHandler(TcpRelayHanler): + + _is_client = False + + def start(self, sock, event, data=None): + data = self.sock_recv(sock) + if not data: + return + self._loop.remove(sock) + + if self._crypt: + data = self._crypt.decrypt(data) + remote, data = data.split(b'\n', 1) + host, port = remote.split(b':') + self._remote = (host, int(port)) + self._data_to_write_to_remote.append(data) + self._dns_resolver.resolve(host, self.dns_resolved) + + def dns_resolved(self, result, error): + try: + ip = result[1] + except (TypeError, IndexError): + ip = None + + if not ip: + raise Exception('Hostname {} cannot resolved'.format( + self._remote[0])) + + self._remote_addr = (ip, self._remote[1]) + self._remote_sock = create_sock(*self._remote_addr) + logging.info('Connecting to {}'.format(self._remote[0])) + self.sock_connect(self._remote_sock, self._remote_addr) + self.modify_remote_sock_mode(EVENT_WRITE) + self.modify_local_sock_mode(EVENT_READ) + + +class TcpRelay(object): + + def __init__(self, handler_type, listen_addr, remote_addr=None, + dns_resolver=None): + self._loop = None + self._handler_type = handler_type + self._listen_addr = listen_addr + self._remote_addr = remote_addr + self._dns_resolver = dns_resolver + self._create_listen_sock() + + def _create_listen_sock(self): + sock = create_sock(*self._listen_addr) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(self._listen_addr) + sock.listen(1024) + self._listen_sock = sock + logging.info('Listening on {}:{}'.format(*self._listen_addr)) + + def add_to_loop(self, loop): + if self._loop: + raise Exception('Already added to loop') + self._loop = loop + loop.add(self._listen_sock, EVENT_READ, (self, self.accept)) + + def _accept(self, listen_sock): + try: + sock, addr = listen_sock.accept() + except (OSError, IOError) as e: + if errno_from_exception(e) in ( + errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS, + errno.EINTR, errno.ECONNABORTED + ): + pass + else: + raise + logging.info('Connected from {}:{}'.format(*addr)) + sock.setblocking(False) + sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) + return (sock, addr) + + def accept(self, listen_sock, event): + sock, addr = self._accept(listen_sock) + handler = self._handler_type(sock, addr, self._remote_addr, + self._dns_resolver) + handler.add_to_loop(self._loop) + + def close(self): + self._loop.remove(self._listen_sock) + + def handle_event(self, sock, event, call, *args): + if event & EVENT_ERROR: + logging.error(get_sock_error(sock)) + self.close() + else: + try: + call(sock, event, *args) + except Exception as e: + logging.error(e) + self.close()