diff --git a/src/dingdinghelper/websocket/__init__.py b/src/dingdinghelper/websocket/__init__.py index b90e65ada118589793a5229a44f7c4c3fda2be25..605f76cd1241100458acb1faa36ad0b17cf992b1 100644 --- a/src/dingdinghelper/websocket/__init__.py +++ b/src/dingdinghelper/websocket/__init__.py @@ -26,4 +26,4 @@ from ._exceptions import * from ._logging import * from ._socket import * -__version__ = "0.47.0" +__version__ = "0.57.0" diff --git a/src/dingdinghelper/websocket/__pycache__/__init__.cpython-37.pyc b/src/dingdinghelper/websocket/__pycache__/__init__.cpython-37.pyc index 09658ba4cf9cb1c77b84628043de348b65634d0c..54e815ede1d9b88b590f895e06b611e7d11d8592 100644 Binary files a/src/dingdinghelper/websocket/__pycache__/__init__.cpython-37.pyc and b/src/dingdinghelper/websocket/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/dingdinghelper/websocket/__pycache__/_abnf.cpython-37.pyc b/src/dingdinghelper/websocket/__pycache__/_abnf.cpython-37.pyc index 675b8727a796bb7eaf6fc00822e6ff714f10f725..343052960a63bf5d4bf4ffdf889ddbb9a5cfa0a3 100644 Binary files a/src/dingdinghelper/websocket/__pycache__/_abnf.cpython-37.pyc and b/src/dingdinghelper/websocket/__pycache__/_abnf.cpython-37.pyc differ diff --git a/src/dingdinghelper/websocket/__pycache__/_app.cpython-37.pyc b/src/dingdinghelper/websocket/__pycache__/_app.cpython-37.pyc index d12b5abacccd66fdc1f968b6c4e2cd849e31ecc5..94ad6cb7c06d327c75bb358571b4daee648ccf47 100644 Binary files a/src/dingdinghelper/websocket/__pycache__/_app.cpython-37.pyc and b/src/dingdinghelper/websocket/__pycache__/_app.cpython-37.pyc differ diff --git a/src/dingdinghelper/websocket/__pycache__/_cookiejar.cpython-37.pyc b/src/dingdinghelper/websocket/__pycache__/_cookiejar.cpython-37.pyc index 772669e0d913b269638b961557d22dae5cbcae56..cee60b02e431d8d9b60989d9342bfa5def6cef2d 100644 Binary files a/src/dingdinghelper/websocket/__pycache__/_cookiejar.cpython-37.pyc and b/src/dingdinghelper/websocket/__pycache__/_cookiejar.cpython-37.pyc differ diff --git a/src/dingdinghelper/websocket/__pycache__/_core.cpython-37.pyc b/src/dingdinghelper/websocket/__pycache__/_core.cpython-37.pyc index 5049dbd56a0da68f4a155a203d5c444f272b6953..3031e57fe329724d493e86d73f7cee7e13ab5cb9 100644 Binary files a/src/dingdinghelper/websocket/__pycache__/_core.cpython-37.pyc and b/src/dingdinghelper/websocket/__pycache__/_core.cpython-37.pyc differ diff --git a/src/dingdinghelper/websocket/__pycache__/_exceptions.cpython-37.pyc b/src/dingdinghelper/websocket/__pycache__/_exceptions.cpython-37.pyc index c31366500f4b2f2dfc4d3a9f43e6410f75a134ca..2dd26d8356d5c9622c383ef75b6c820e0bf4ddda 100644 Binary files a/src/dingdinghelper/websocket/__pycache__/_exceptions.cpython-37.pyc and b/src/dingdinghelper/websocket/__pycache__/_exceptions.cpython-37.pyc differ diff --git a/src/dingdinghelper/websocket/__pycache__/_handshake.cpython-37.pyc b/src/dingdinghelper/websocket/__pycache__/_handshake.cpython-37.pyc index 19b5a826035aa5afeda5277f31920f2f015444f3..75227ff9665754be287333d41c4927cf4d735793 100644 Binary files a/src/dingdinghelper/websocket/__pycache__/_handshake.cpython-37.pyc and b/src/dingdinghelper/websocket/__pycache__/_handshake.cpython-37.pyc differ diff --git a/src/dingdinghelper/websocket/__pycache__/_http.cpython-37.pyc b/src/dingdinghelper/websocket/__pycache__/_http.cpython-37.pyc index d2893092c41f5d78ff3f1c1f2af34b528ce59ecb..4b04641c16d6fa26a98a06dada0c2d334a563ece 100644 Binary files a/src/dingdinghelper/websocket/__pycache__/_http.cpython-37.pyc and b/src/dingdinghelper/websocket/__pycache__/_http.cpython-37.pyc differ diff --git a/src/dingdinghelper/websocket/__pycache__/_logging.cpython-37.pyc b/src/dingdinghelper/websocket/__pycache__/_logging.cpython-37.pyc index beaa83383e1d36cf64ac22a73ec44a1f57df5519..44ea69422ffb5aaf47c3f3baf8c1672ee5e8723b 100644 Binary files a/src/dingdinghelper/websocket/__pycache__/_logging.cpython-37.pyc and b/src/dingdinghelper/websocket/__pycache__/_logging.cpython-37.pyc differ diff --git a/src/dingdinghelper/websocket/__pycache__/_socket.cpython-37.pyc b/src/dingdinghelper/websocket/__pycache__/_socket.cpython-37.pyc index 29203cd57214893d772c18eebb97ccfabea64fe2..e1ff736e01ce6f4c57d76be5afec01a045f28477 100644 Binary files a/src/dingdinghelper/websocket/__pycache__/_socket.cpython-37.pyc and b/src/dingdinghelper/websocket/__pycache__/_socket.cpython-37.pyc differ diff --git a/src/dingdinghelper/websocket/__pycache__/_ssl_compat.cpython-37.pyc b/src/dingdinghelper/websocket/__pycache__/_ssl_compat.cpython-37.pyc index 938bda89b223cb998f22f4473b3bb9a3f720cfc8..200ca13b64c1433e236c1edd2287299fa46e8032 100644 Binary files a/src/dingdinghelper/websocket/__pycache__/_ssl_compat.cpython-37.pyc and b/src/dingdinghelper/websocket/__pycache__/_ssl_compat.cpython-37.pyc differ diff --git a/src/dingdinghelper/websocket/__pycache__/_url.cpython-37.pyc b/src/dingdinghelper/websocket/__pycache__/_url.cpython-37.pyc index bb42b0985bbe563ce02880796f3c10ecd3eea57b..e787773be933c2cfa03a2609b829eb03879dcabe 100644 Binary files a/src/dingdinghelper/websocket/__pycache__/_url.cpython-37.pyc and b/src/dingdinghelper/websocket/__pycache__/_url.cpython-37.pyc differ diff --git a/src/dingdinghelper/websocket/__pycache__/_utils.cpython-37.pyc b/src/dingdinghelper/websocket/__pycache__/_utils.cpython-37.pyc index 316b720dd937908f83eddd103b9c5f1e9f4efb98..d0540857569754af24a5135b3647219ae1c8a5f5 100644 Binary files a/src/dingdinghelper/websocket/__pycache__/_utils.cpython-37.pyc and b/src/dingdinghelper/websocket/__pycache__/_utils.cpython-37.pyc differ diff --git a/src/dingdinghelper/websocket/_app.py b/src/dingdinghelper/websocket/_app.py index 74e90ae02c1f18c64f840f335f533e6fd1703db4..e4e9f99c7d5919c8c5a00354815b19e3fa3f129d 100644 --- a/src/dingdinghelper/websocket/_app.py +++ b/src/dingdinghelper/websocket/_app.py @@ -23,6 +23,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris) """ WebSocketApp provides higher level APIs. """ +import inspect import select import sys import threading @@ -41,26 +42,30 @@ __all__ = ["WebSocketApp"] class Dispatcher: def __init__(self, app, ping_timeout): - self.app = app + self.app = app self.ping_timeout = ping_timeout - def read(self, sock, callback): - while self.app.sock.connected: + def read(self, sock, read_callback, check_callback): + while self.app.keep_running: r, w, e = select.select( - (self.app.sock.sock, ), (), (), self.ping_timeout) # Use a 10 second timeout to avoid to wait forever on close + (self.app.sock.sock, ), (), (), self.ping_timeout) if r: - callback() + if not read_callback(): + break + check_callback() -class SSLDispacther: +class SSLDispatcher: def __init__(self, app, ping_timeout): - self.app = app + self.app = app self.ping_timeout = ping_timeout - def read(self, sock, callback): - while self.app.sock.connected: + def read(self, sock, read_callback, check_callback): + while self.app.keep_running: r = self.select() if r: - callback() + if not read_callback(): + break + check_callback() def select(self): sock = self.app.sock.sock @@ -70,6 +75,7 @@ class SSLDispacther: r, w, e = select.select((sock, ), (), (), self.ping_timeout) return r + class WebSocketApp(object): """ Higher level of APIs are provided. @@ -113,7 +119,7 @@ class WebSocketApp(object): The 2nd argument is utf-8 string which we get from the server. The 3rd argument is data type. ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY will be came. The 4th argument is continue flag. if 0, the data continue - keep_running: this parameter is obosleted and ignored it. + keep_running: this parameter is obsolete and ignored. get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's docstring for more information subprotocols: array of available sub protocols. default is None. @@ -121,6 +127,7 @@ class WebSocketApp(object): self.url = url self.header = header if header is not None else [] self.cookie = cookie + self.on_open = on_open self.on_message = on_message self.on_data = on_data @@ -155,6 +162,7 @@ class WebSocketApp(object): self.keep_running = False if self.sock: self.sock.close(**kwargs) + self.sock = None def _send_ping(self, interval, event): while not event.wait(interval): @@ -171,7 +179,8 @@ class WebSocketApp(object): http_proxy_host=None, http_proxy_port=None, http_no_proxy=None, http_proxy_auth=None, skip_utf8_validation=False, - host=None, origin=None, dispatcher=None): + host=None, origin=None, dispatcher=None, + suppress_origin=False, proxy_type=None): """ run event loop for WebSocket framework. This loop is infinite loop and is alive during websocket is available. @@ -189,32 +198,42 @@ class WebSocketApp(object): skip_utf8_validation: skip utf8 validation. host: update host header. origin: update origin header. + dispatcher: customize reading data from socket. + suppress_origin: suppress outputting origin header. + + Returns + ------- + False if caught KeyboardInterrupt + True if other exception was raised during a loop """ - if not ping_timeout or ping_timeout <= 0: + if ping_timeout is not None and ping_timeout <= 0: ping_timeout = None if ping_timeout and ping_interval and ping_interval <= ping_timeout: raise WebSocketException("Ensure ping_interval > ping_timeout") - if sockopt is None: + if not sockopt: sockopt = [] - if sslopt is None: + if not sslopt: sslopt = {} if self.sock: raise WebSocketException("socket is already opened") thread = None - close_frame = None self.keep_running = True self.last_ping_tm = 0 self.last_pong_tm = 0 - def teardown(): - if not self.keep_running: - return + def teardown(close_frame=None): + """ + Tears down the connection. + If close_frame is set, we will invoke the on_close handler with the + statusCode and reason from there. + """ if thread and thread.isAlive(): event.set() thread.join() self.keep_running = False - self.sock.close() + if self.sock: + self.sock.close() close_args = self._get_close_args( close_frame.data if close_frame else None) self._callback(self.on_close, *close_args) @@ -223,15 +242,17 @@ class WebSocketApp(object): try: self.sock = WebSocket( self.get_mask_key, sockopt=sockopt, sslopt=sslopt, - fire_cont_frame=self.on_cont_message and True or False, - skip_utf8_validation=skip_utf8_validation) + fire_cont_frame=self.on_cont_message is not None, + skip_utf8_validation=skip_utf8_validation, + enable_multithread=True if ping_interval else False) self.sock.settimeout(getdefaulttimeout()) self.sock.connect( self.url, header=self.header, cookie=self.cookie, http_proxy_host=http_proxy_host, http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy, http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols, - host=host, origin=origin) + host=host, origin=origin, suppress_origin=suppress_origin, + proxy_type=proxy_type) if not dispatcher: dispatcher = self.create_dispatcher(ping_timeout) @@ -250,8 +271,7 @@ class WebSocketApp(object): op_code, frame = self.sock.recv_data_frame(True) if op_code == ABNF.OPCODE_CLOSE: - close_frame = frame - return teardown() + return teardown(frame) elif op_code == ABNF.OPCODE_PING: self._callback(self.on_ping, frame.data) elif op_code == ABNF.OPCODE_PONG: @@ -269,31 +289,39 @@ class WebSocketApp(object): self._callback(self.on_data, data, frame.opcode, True) self._callback(self.on_message, data) - if ping_timeout and self.last_ping_tm \ - and time.time() - self.last_ping_tm > ping_timeout \ - and self.last_ping_tm - self.last_pong_tm > ping_timeout: - raise WebSocketTimeoutException("ping/pong timed out") return True - dispatcher.read(self.sock.sock, read) + def check(): + if (ping_timeout): + has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout + has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0 + has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout + + if (self.last_ping_tm + and has_timeout_expired + and (has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)): + raise WebSocketTimeoutException("ping/pong timed out") + return True + + dispatcher.read(self.sock.sock, read, check) except (Exception, KeyboardInterrupt, SystemExit) as e: self._callback(self.on_error, e) if isinstance(e, SystemExit): # propagate SystemExit further raise teardown() + return not isinstance(e, KeyboardInterrupt) def create_dispatcher(self, ping_timeout): timeout = ping_timeout or 10 if self.sock.is_ssl(): - return SSLDispacther(self, timeout) + return SSLDispatcher(self, timeout) return Dispatcher(self, timeout) def _get_close_args(self, data): """ this functions extracts the code, reason from the close body if they exists, and if the self.on_close except three arguments """ - import inspect # if the on_close callback is "old", just return empty list if sys.version_info < (3, 0): if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3: @@ -312,7 +340,11 @@ class WebSocketApp(object): def _callback(self, callback, *args): if callback: try: - callback(self, *args) + if inspect.ismethod(callback): + callback(*args) + else: + callback(self, *args) + except Exception as e: _logging.error("error from callback {}: {}".format(callback, e)) if _logging.isEnabledForDebug(): diff --git a/src/dingdinghelper/websocket/_core.py b/src/dingdinghelper/websocket/_core.py index 2d009621fbcf5fe8aad17627f98ae8ddd973d443..418aafc42a76b2b7e32c1616887c4ec4f9408daa 100644 --- a/src/dingdinghelper/websocket/_core.py +++ b/src/dingdinghelper/websocket/_core.py @@ -24,6 +24,7 @@ from __future__ import print_function import socket import struct import threading +import time import six @@ -201,6 +202,7 @@ class WebSocket(object): options: "header" -> custom http header list or dict. "cookie" -> cookie value. "origin" -> custom origin url. + "suppress_origin" -> suppress outputting origin header. "host" -> custom host header string. "http_proxy_host" - http proxy host name. "http_proxy_port" - http proxy port. If not set, set to 80. @@ -208,16 +210,27 @@ class WebSocket(object): "http_proxy_auth" - http proxy auth information. tuple of username and password. default is None + "redirect_limit" -> number of redirects to follow. "subprotocols" - array of available sub protocols. default is None. "socket" - pre-initialized stream socket. """ + # FIXME: "subprotocols" are getting lost, not passed down + # FIXME: "header", "cookie", "origin" and "host" too + self.sock_opt.timeout = options.get('timeout', self.sock_opt.timeout) self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options), options.pop('socket', None)) try: self.handshake_response = handshake(self.sock, *addrs, **options) + for attempt in range(options.pop('redirect_limit', 3)): + if self.handshake_response.status in SUPPORTED_REDIRECT_STATUSES: + url = self.handshake_response.headers['location'] + self.sock.close() + self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options), + options.pop('socket', None)) + self.handshake_response = handshake(self.sock, *addrs, **options) self.connected = True except: if self.sock: @@ -258,7 +271,8 @@ class WebSocket(object): frame.get_mask_key = self.get_mask_key data = frame.format() length = len(data) - trace("send: " + repr(data)) + if (isEnabledForTrace()): + trace("send: " + repr(data)) with self.lock: while data: @@ -397,20 +411,25 @@ class WebSocket(object): reason, ABNF.OPCODE_CLOSE) sock_timeout = self.sock.gettimeout() self.sock.settimeout(timeout) - try: - frame = self.recv_frame() - if isEnabledForError(): - recv_status = struct.unpack("!H", frame.data[0:2])[0] - if recv_status != STATUS_NORMAL: - error("close status: " + repr(recv_status)) - except: - pass + start_time = time.time() + while timeout is None or time.time() - start_time < timeout: + try: + frame = self.recv_frame() + if frame.opcode != ABNF.OPCODE_CLOSE: + continue + if isEnabledForError(): + recv_status = struct.unpack("!H", frame.data[0:2])[0] + if recv_status != STATUS_NORMAL: + error("close status: " + repr(recv_status)) + break + except: + break self.sock.settimeout(sock_timeout) self.sock.shutdown(socket.SHUT_RDWR) except: pass - self.shutdown() + self.shutdown() def abort(self): """ @@ -466,6 +485,7 @@ def create_connection(url, timeout=None, class_=WebSocket, **options): options: "header" -> custom http header list or dict. "cookie" -> cookie value. "origin" -> custom origin url. + "suppress_origin" -> suppress outputting origin header. "host" -> custom host header string. "http_proxy_host" - http proxy host name. "http_proxy_port" - http proxy port. If not set, set to 80. @@ -474,6 +494,7 @@ def create_connection(url, timeout=None, class_=WebSocket, **options): tuple of username and password. default is None "enable_multithread" -> enable lock for multithread. + "redirect_limit" -> number of redirects to follow. "sockopt" -> socket options "sslopt" -> ssl option "subprotocols" - array of available sub protocols. diff --git a/src/dingdinghelper/websocket/_exceptions.py b/src/dingdinghelper/websocket/_exceptions.py index 24c85e0ee50cc0bd5f35e3aff3c830087dd61a52..207079026bad8f9b7c1c676a203aac7b94a81a78 100644 --- a/src/dingdinghelper/websocket/_exceptions.py +++ b/src/dingdinghelper/websocket/_exceptions.py @@ -74,11 +74,12 @@ class WebSocketBadStatusException(WebSocketException): WebSocketBadStatusException will be raised when we get bad handshake status code. """ - def __init__(self, message, status_code, status_message=None): - msg = message % (status_code, status_message) if status_message is not None \ - else message % status_code + def __init__(self, message, status_code, status_message=None, resp_headers=None): + msg = message % (status_code, status_message) super(WebSocketBadStatusException, self).__init__(msg) self.status_code = status_code + self.resp_headers = resp_headers + class WebSocketAddressException(WebSocketException): """ diff --git a/src/dingdinghelper/websocket/_handshake.py b/src/dingdinghelper/websocket/_handshake.py index 3fd5c9eed0fdcedc549366ab051d056358f5e3dd..7476a072c634f4c33421b996156ddc0288d90e7d 100644 --- a/src/dingdinghelper/websocket/_handshake.py +++ b/src/dingdinghelper/websocket/_handshake.py @@ -31,12 +31,20 @@ from ._http import * from ._logging import * from ._socket import * -if six.PY3: +if hasattr(six, 'PY3') and six.PY3: from base64 import encodebytes as base64encode else: from base64 import encodestring as base64encode -__all__ = ["handshake_response", "handshake"] +if hasattr(six, 'PY3') and six.PY3: + if hasattr(six, 'PY34') and six.PY34: + from http import client as HTTPStatus + else: + from http import HTTPStatus +else: + import httplib as HTTPStatus + +__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"] if hasattr(hmac, "compare_digest"): compare_digest = hmac.compare_digest @@ -47,6 +55,9 @@ else: # websocket supported version. VERSION = 13 +SUPPORTED_REDIRECT_STATUSES = (HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER,) +SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,) + CookieJar = SimpleCookieJar() @@ -67,12 +78,15 @@ def handshake(sock, hostname, port, resource, **options): dump("request header", header_str) status, resp = _get_resp_headers(sock) + if status in SUPPORTED_REDIRECT_STATUSES: + return handshake_response(status, resp, None) success, subproto = _validate(resp, key, options.get("subprotocols")) if not success: raise WebSocketException("Invalid WebSocket Header") return handshake_response(status, resp, subproto) + def _pack_hostname(hostname): # IPv6 address if ':' in hostname: @@ -83,27 +97,39 @@ def _pack_hostname(hostname): def _get_handshake_headers(resource, host, port, options): headers = [ "GET %s HTTP/1.1" % resource, - "Upgrade: websocket", - "Connection: Upgrade" + "Upgrade: websocket" ] if port == 80 or port == 443: hostport = _pack_hostname(host) else: hostport = "%s:%d" % (_pack_hostname(host), port) - if "host" in options and options["host"] is not None: headers.append("Host: %s" % options["host"]) else: headers.append("Host: %s" % hostport) - if "origin" in options and options["origin"] is not None: - headers.append("Origin: %s" % options["origin"]) - else: - headers.append("Origin: http://%s" % hostport) + if "suppress_origin" not in options or not options["suppress_origin"]: + if "origin" in options and options["origin"] is not None: + headers.append("Origin: %s" % options["origin"]) + else: + headers.append("Origin: http://%s" % hostport) key = _create_sec_websocket_key() - headers.append("Sec-WebSocket-Key: %s" % key) - headers.append("Sec-WebSocket-Version: %s" % VERSION) + + # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified + if not 'header' in options or 'Sec-WebSocket-Key' not in options['header']: + key = _create_sec_websocket_key() + headers.append("Sec-WebSocket-Key: %s" % key) + else: + key = options['header']['Sec-WebSocket-Key'] + + if not 'header' in options or 'Sec-WebSocket-Version' not in options['header']: + headers.append("Sec-WebSocket-Version: %s" % VERSION) + + if not 'connection' in options or options['connection'] is None: + headers.append('Connection: upgrade') + else: + headers.append(options['connection']) subprotocols = options.get("subprotocols") if subprotocols: @@ -112,7 +138,11 @@ def _get_handshake_headers(resource, host, port, options): if "header" in options: header = options["header"] if isinstance(header, dict): - header = map(": ".join, header.items()) + header = [ + ": ".join([k, v]) + for k, v in header.items() + if v is not None + ] headers.extend(header) server_cookie = CookieJar.get(host) @@ -129,12 +159,13 @@ def _get_handshake_headers(resource, host, port, options): return headers, key -def _get_resp_headers(sock, success_status=101): +def _get_resp_headers(sock, success_statuses=SUCCESS_STATUSES): status, resp_headers, status_message = read_headers(sock) - if status != success_status: - raise WebSocketBadStatusException("Handshake status %d %s", status, status_message) + if status not in success_statuses: + raise WebSocketBadStatusException("Handshake status %d %s", status, status_message, resp_headers) return status, resp_headers + _HEADERS_TO_CHECK = { "upgrade": "websocket", "connection": "upgrade", diff --git a/src/dingdinghelper/websocket/_http.py b/src/dingdinghelper/websocket/_http.py index e341dfd313a681cbd002488124a26fcca5cf0332..a8777de6096595b242cbe8f6aeaf748be914bc77 100644 --- a/src/dingdinghelper/websocket/_http.py +++ b/src/dingdinghelper/websocket/_http.py @@ -39,21 +39,72 @@ else: __all__ = ["proxy_info", "connect", "read_headers"] +try: + import socks + ProxyConnectionError = socks.ProxyConnectionError + HAS_PYSOCKS = True +except: + class ProxyConnectionError(BaseException): + pass + HAS_PYSOCKS = False class proxy_info(object): def __init__(self, **options): + self.type = options.get("proxy_type") or "http" + if not(self.type in ['http', 'socks4', 'socks5', 'socks5h']): + raise ValueError("proxy_type must be 'http', 'socks4', 'socks5' or 'socks5h'") self.host = options.get("http_proxy_host", None) if self.host: self.port = options.get("http_proxy_port", 0) self.auth = options.get("http_proxy_auth", None) + self.no_proxy = options.get("http_no_proxy", None) else: self.port = 0 self.auth = None - self.no_proxy = options.get("http_no_proxy", None) + self.no_proxy = None + + +def _open_proxied_socket(url, options, proxy): + hostname, port, resource, is_secure = parse_url(url) + + if not HAS_PYSOCKS: + raise WebSocketException("PySocks module not found.") + + ptype = socks.SOCKS5 + rdns = False + if proxy.type == "socks4": + ptype = socks.SOCKS4 + if proxy.type == "http": + ptype = socks.HTTP + if proxy.type[-1] == "h": + rdns = True + + sock = socks.create_connection( + (hostname, port), + proxy_type = ptype, + proxy_addr = proxy.host, + proxy_port = proxy.port, + proxy_rdns = rdns, + proxy_username = proxy.auth[0] if proxy.auth else None, + proxy_password = proxy.auth[1] if proxy.auth else None, + timeout = options.timeout, + socket_options = DEFAULT_SOCKET_OPTION + options.sockopt + ) + + if is_secure: + if HAVE_SSL: + sock = _ssl_socket(sock, options.sslopt, hostname) + else: + raise WebSocketException("SSL not available.") + + return sock, (hostname, port, resource) def connect(url, options, proxy, socket): + if proxy.host and not socket and not (proxy.type == 'http'): + return _open_proxied_socket(url, options, proxy) + hostname, port, resource, is_secure = parse_url(url) if socket: @@ -88,13 +139,20 @@ def _get_addrinfo_list(hostname, port, is_secure, proxy): phost, pport, pauth = get_proxy_info( hostname, is_secure, proxy.host, proxy.port, proxy.auth, proxy.no_proxy) try: + # when running on windows 10, getaddrinfo without socktype returns a socktype 0. + # This generates an error exception: `_on_error: exception Socket type must be stream or datagram, not 0` + # or `OSError: [Errno 22] Invalid argument` when creating socket. Force the socket type to SOCK_STREAM. if not phost: addrinfo_list = socket.getaddrinfo( - hostname, port, 0, 0, socket.SOL_TCP) + hostname, port, 0, socket.SOCK_STREAM, socket.SOL_TCP) return addrinfo_list, False, None else: pport = pport and pport or 80 - addrinfo_list = socket.getaddrinfo(phost, pport, 0, 0, socket.SOL_TCP) + # when running on windows 10, the getaddrinfo used above + # returns a socktype 0. This generates an error exception: + # _on_error: exception Socket type must be stream or datagram, not 0 + # Force the socket type to SOCK_STREAM + addrinfo_list = socket.getaddrinfo(phost, pport, 0, socket.SOCK_STREAM, socket.SOL_TCP) return addrinfo_list, True, pauth except socket.gaierror as e: raise WebSocketAddressException(e) @@ -106,29 +164,41 @@ def _open_socket(addrinfo_list, sockopt, timeout): family, socktype, proto = addrinfo[:3] sock = socket.socket(family, socktype, proto) sock.settimeout(timeout) - # for opts in DEFAULT_SOCKET_OPTION: - # sock.setsockopt(*opts) - # for opts in sockopt: - # sock.setsockopt(*opts) + for opts in DEFAULT_SOCKET_OPTION: + sock.setsockopt(*opts) + for opts in sockopt: + sock.setsockopt(*opts) address = addrinfo[4] - try: - sock.connect(address) - except socket.error as error: - error.remote_ip = str(address[0]) + err = None + while not err: try: - eConnRefused = (errno.ECONNREFUSED, errno.WSAECONNREFUSED) - except: - eConnRefused = (errno.ECONNREFUSED, ) - if error.errno in eConnRefused: - err = error + sock.connect(address) + except ProxyConnectionError as error: + err = WebSocketProxyException(str(error)) + err.remote_ip = str(address[0]) continue + except socket.error as error: + error.remote_ip = str(address[0]) + try: + eConnRefused = (errno.ECONNREFUSED, errno.WSAECONNREFUSED) + except: + eConnRefused = (errno.ECONNREFUSED, ) + if error.errno == errno.EINTR: + continue + elif error.errno in eConnRefused: + err = error + continue + else: + raise error else: - raise + break else: - break + continue + break else: - raise err + if err: + raise err return sock @@ -141,7 +211,12 @@ def _wrap_sni_socket(sock, sslopt, hostname, check_hostname): context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23)) if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE: - context.load_verify_locations(cafile=sslopt.get('ca_certs', None), capath=sslopt.get('ca_cert_path', None)) + cafile = sslopt.get('ca_certs', None) + capath = sslopt.get('ca_cert_path', None) + if cafile or capath: + context.load_verify_locations(cafile=cafile, capath=capath) + elif hasattr(context, 'load_default_certs'): + context.load_default_certs(ssl.Purpose.SERVER_AUTH) if sslopt.get('certfile', None): context.load_cert_chain( sslopt['certfile'], @@ -173,15 +248,13 @@ def _ssl_socket(sock, user_sslopt, hostname): sslopt = dict(cert_reqs=ssl.CERT_REQUIRED) sslopt.update(user_sslopt) - if os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE'): - certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE') - else: - certPath = os.path.join( - os.path.dirname(__file__), "cacert.pem") - if os.path.isfile(certPath) and user_sslopt.get('ca_certs', None) is None \ + certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE') + if certPath and os.path.isfile(certPath) \ + and user_sslopt.get('ca_certs', None) is None \ and user_sslopt.get('ca_cert', None) is None: sslopt['ca_certs'] = certPath - elif os.path.isdir(certPath) and user_sslopt.get('ca_cert_path', None) is None: + elif certPath and os.path.isdir(certPath) \ + and user_sslopt.get('ca_cert_path', None) is None: sslopt['ca_cert_path'] = certPath check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop( @@ -207,7 +280,7 @@ def _tunnel(sock, host, port, auth): auth_str = auth[0] if auth[1]: auth_str += ":" + auth[1] - encoded_str = base64encode(auth_str.encode()).strip().decode() + encoded_str = base64encode(auth_str.encode()).strip().decode().replace('\n', '') connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str connect_header += "\r\n" dump("request header", connect_header) @@ -242,7 +315,8 @@ def read_headers(sock): status_info = line.split(" ", 2) status = int(status_info[1]) - status_message = status_info[2] + if len(status_info) > 2: + status_message = status_info[2] else: kv = line.split(":", 1) if len(kv) == 2: diff --git a/src/dingdinghelper/websocket/_logging.py b/src/dingdinghelper/websocket/_logging.py index d406db6a99c8ea7a2186303e9c68e29e672f3c2e..c94777899091d3fcb197faf594f280077cb4ae24 100644 --- a/src/dingdinghelper/websocket/_logging.py +++ b/src/dingdinghelper/websocket/_logging.py @@ -22,13 +22,22 @@ Copyright (C) 2010 Hiroki Ohtani(liris) import logging _logger = logging.getLogger('websocket') +try: + from logging import NullHandler +except ImportError: + class NullHandler(logging.Handler): + def emit(self, record): + pass + +_logger.addHandler(NullHandler()) + _traceEnabled = False __all__ = ["enableTrace", "dump", "error", "warning", "debug", "trace", - "isEnabledForError", "isEnabledForDebug"] + "isEnabledForError", "isEnabledForDebug", "isEnabledForTrace"] -def enableTrace(traceable): +def enableTrace(traceable, handler = logging.StreamHandler()): """ turn on/off the traceability. @@ -37,11 +46,9 @@ def enableTrace(traceable): global _traceEnabled _traceEnabled = traceable if traceable: - if not _logger.handlers: - _logger.addHandler(logging.StreamHandler()) + _logger.addHandler(handler) _logger.setLevel(logging.DEBUG) - def dump(title, message): if _traceEnabled: _logger.debug("--- " + title + " ---") @@ -72,3 +79,6 @@ def isEnabledForError(): def isEnabledForDebug(): return _logger.isEnabledFor(logging.DEBUG) + +def isEnabledForTrace(): + return _traceEnabled diff --git a/src/dingdinghelper/websocket/_socket.py b/src/dingdinghelper/websocket/_socket.py index c84fcf90a0c31e25ed1986450df54f893d975dfe..7be39138c628dc63b56349c25d12c9e7de9d2bad 100644 --- a/src/dingdinghelper/websocket/_socket.py +++ b/src/dingdinghelper/websocket/_socket.py @@ -19,6 +19,8 @@ Copyright (C) 2010 Hiroki Ohtani(liris) Boston, MA 02110-1335 USA """ +import errno +import select import socket import six @@ -77,8 +79,27 @@ def recv(sock, bufsize): if not sock: raise WebSocketConnectionClosedException("socket is already closed.") + def _recv(): + try: + return sock.recv(bufsize) + except SSLWantReadError: + pass + except socket.error as exc: + error_code = extract_error_code(exc) + if error_code is None: + raise + if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK: + raise + + r, w, e = select.select((sock, ), (), (), sock.gettimeout()) + if r: + return sock.recv(bufsize) + try: - bytes_ = sock.recv(bufsize) + if sock.gettimeout() == 0: + bytes_ = sock.recv(bufsize) + else: + bytes_ = _recv() except socket.timeout as e: message = extract_err_message(e) raise WebSocketTimeoutException(message) @@ -113,8 +134,27 @@ def send(sock, data): if not sock: raise WebSocketConnectionClosedException("socket is already closed.") + def _send(): + try: + return sock.send(data) + except SSLWantWriteError: + pass + except socket.error as exc: + error_code = extract_error_code(exc) + if error_code is None: + raise + if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK: + raise + + r, w, e = select.select((), (sock, ), (), sock.gettimeout()) + if w: + return sock.send(data) + try: - return sock.send(data) + if sock.gettimeout() == 0: + return sock.send(data) + else: + return _send() except socket.timeout as e: message = extract_err_message(e) raise WebSocketTimeoutException(message) diff --git a/src/dingdinghelper/websocket/_ssl_compat.py b/src/dingdinghelper/websocket/_ssl_compat.py index 0304816286aba4db42f6c2163cda257acf49b54b..96cd173e6dbf5a811a41f51f77aed79ad1c09da7 100644 --- a/src/dingdinghelper/websocket/_ssl_compat.py +++ b/src/dingdinghelper/websocket/_ssl_compat.py @@ -19,11 +19,13 @@ Copyright (C) 2010 Hiroki Ohtani(liris) Boston, MA 02110-1335 USA """ -__all__ = ["HAVE_SSL", "ssl", "SSLError"] +__all__ = ["HAVE_SSL", "ssl", "SSLError", "SSLWantReadError", "SSLWantWriteError"] try: import ssl from ssl import SSLError + from ssl import SSLWantReadError + from ssl import SSLWantWriteError if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'): HAVE_CONTEXT_CHECK_HOSTNAME = True else: @@ -41,4 +43,12 @@ except ImportError: class SSLError(Exception): pass + class SSLWantReadError(Exception): + pass + + class SSLWantWriteError(Exception): + pass + + ssl = lambda: None + HAVE_SSL = False diff --git a/src/dingdinghelper/websocket/_url.py b/src/dingdinghelper/websocket/_url.py index f7bdf346708c30514f185fda917dcbc3072a672c..a394fc34963d95b2fef1a3f492572485838c372b 100644 --- a/src/dingdinghelper/websocket/_url.py +++ b/src/dingdinghelper/websocket/_url.py @@ -103,7 +103,8 @@ def _is_address_in_network(ip, net): def _is_no_proxy_host(hostname, no_proxy): if not no_proxy: v = os.environ.get("no_proxy", "").replace(" ", "") - no_proxy = v.split(",") + if v: + no_proxy = v.split(",") if not no_proxy: no_proxy = DEFAULT_NO_PROXY_HOST @@ -117,7 +118,7 @@ def _is_no_proxy_host(hostname, no_proxy): def get_proxy_info( hostname, is_secure, proxy_host=None, proxy_port=0, proxy_auth=None, - no_proxy=None): + no_proxy=None, proxy_type='http'): """ try to retrieve proxy host and port from environment if not provided in options. @@ -137,6 +138,9 @@ def get_proxy_info( "http_proxy_auth" - http proxy auth information. tuple of username and password. default is None + "proxy_type" - if set to "socks5" PySocks wrapper + will be used in place of a http proxy. + default is "http" """ if _is_no_proxy_host(hostname, no_proxy): return None, 0, None diff --git a/src/dingdinghelper/websocket/_utils.py b/src/dingdinghelper/websocket/_utils.py index 399fb89d9eece087ce2a929f4b0d7f537d8f3259..32ee12ee206e47e4c9c9b9302b44def5e7dee780 100644 --- a/src/dingdinghelper/websocket/_utils.py +++ b/src/dingdinghelper/websocket/_utils.py @@ -21,7 +21,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris) """ import six -__all__ = ["NoLock", "validate_utf8", "extract_err_message"] +__all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"] class NoLock(object): @@ -32,6 +32,7 @@ class NoLock(object): def __exit__(self, exc_type, exc_value, traceback): pass + try: # If wsaccel is available we use compiled routines to validate UTF-8 # strings. @@ -103,3 +104,8 @@ def extract_err_message(exception): return exception.args[0] else: return None + + +def extract_error_code(exception): + if exception.args and len(exception.args) > 1: + return exception.args[0] if isinstance(exception.args[0], int) else None diff --git a/src/dingdinghelper/websocket/tests/__init__.py b/src/dingdinghelper/websocket/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/dingdinghelper/websocket/tests/data/header01.txt b/src/dingdinghelper/websocket/tests/data/header01.txt new file mode 100644 index 0000000000000000000000000000000000000000..d44d24c205b97b501a309825b7dd13cc97e371a3 --- /dev/null +++ b/src/dingdinghelper/websocket/tests/data/header01.txt @@ -0,0 +1,6 @@ +HTTP/1.1 101 WebSocket Protocol Handshake +Connection: Upgrade +Upgrade: WebSocket +Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0= +some_header: something + diff --git a/src/dingdinghelper/websocket/tests/data/header02.txt b/src/dingdinghelper/websocket/tests/data/header02.txt new file mode 100644 index 0000000000000000000000000000000000000000..f481de928a8eeaffe1eec8f484cb66d9351cd869 --- /dev/null +++ b/src/dingdinghelper/websocket/tests/data/header02.txt @@ -0,0 +1,6 @@ +HTTP/1.1 101 WebSocket Protocol Handshake +Connection: Upgrade +Upgrade WebSocket +Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0= +some_header: something + diff --git a/src/dingdinghelper/websocket/tests/test_cookiejar.py b/src/dingdinghelper/websocket/tests/test_cookiejar.py new file mode 100644 index 0000000000000000000000000000000000000000..c40a00bd2cdd4e55f774e54e2f0691e98fafd671 --- /dev/null +++ b/src/dingdinghelper/websocket/tests/test_cookiejar.py @@ -0,0 +1,98 @@ +import unittest + +from websocket._cookiejar import SimpleCookieJar + +try: + import Cookie +except: + import http.cookies as Cookie + + +class CookieJarTest(unittest.TestCase): + def testAdd(self): + cookie_jar = SimpleCookieJar() + cookie_jar.add("") + self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar") + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b") + self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar") + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; domain=.abc") + self.assertTrue(".abc" in cookie_jar.jar) + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; domain=abc") + self.assertTrue(".abc" in cookie_jar.jar) + self.assertTrue("abc" not in cookie_jar.jar) + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; c=d; domain=abc") + self.assertEquals(cookie_jar.get("abc"), "a=b; c=d") + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; c=d; domain=abc") + cookie_jar.add("e=f; domain=abc") + self.assertEquals(cookie_jar.get("abc"), "a=b; c=d; e=f") + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; c=d; domain=abc") + cookie_jar.add("e=f; domain=.abc") + self.assertEquals(cookie_jar.get("abc"), "a=b; c=d; e=f") + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; c=d; domain=abc") + cookie_jar.add("e=f; domain=xyz") + self.assertEquals(cookie_jar.get("abc"), "a=b; c=d") + self.assertEquals(cookie_jar.get("xyz"), "e=f") + self.assertEquals(cookie_jar.get("something"), "") + + def testSet(self): + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b") + self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar") + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; domain=.abc") + self.assertTrue(".abc" in cookie_jar.jar) + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; domain=abc") + self.assertTrue(".abc" in cookie_jar.jar) + self.assertTrue("abc" not in cookie_jar.jar) + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc") + self.assertEquals(cookie_jar.get("abc"), "a=b; c=d") + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc") + cookie_jar.set("e=f; domain=abc") + self.assertEquals(cookie_jar.get("abc"), "e=f") + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc") + cookie_jar.set("e=f; domain=.abc") + self.assertEquals(cookie_jar.get("abc"), "e=f") + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc") + cookie_jar.set("e=f; domain=xyz") + self.assertEquals(cookie_jar.get("abc"), "a=b; c=d") + self.assertEquals(cookie_jar.get("xyz"), "e=f") + self.assertEquals(cookie_jar.get("something"), "") + + def testGet(self): + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc.com") + self.assertEquals(cookie_jar.get("abc.com"), "a=b; c=d") + self.assertEquals(cookie_jar.get("x.abc.com"), "a=b; c=d") + self.assertEquals(cookie_jar.get("abc.com.es"), "") + self.assertEquals(cookie_jar.get("xabc.com"), "") + + cookie_jar.set("a=b; c=d; domain=.abc.com") + self.assertEquals(cookie_jar.get("abc.com"), "a=b; c=d") + self.assertEquals(cookie_jar.get("x.abc.com"), "a=b; c=d") + self.assertEquals(cookie_jar.get("abc.com.es"), "") + self.assertEquals(cookie_jar.get("xabc.com"), "") diff --git a/src/dingdinghelper/websocket/tests/test_websocket.py b/src/dingdinghelper/websocket/tests/test_websocket.py new file mode 100644 index 0000000000000000000000000000000000000000..8b131bb6775513ef6cdf86c078c4c457d8a8b816 --- /dev/null +++ b/src/dingdinghelper/websocket/tests/test_websocket.py @@ -0,0 +1,665 @@ +# -*- coding: utf-8 -*- +# + +import sys +sys.path[0:0] = [""] + +import os +import os.path +import socket + +import six + +# websocket-client +import websocket as ws +from websocket._handshake import _create_sec_websocket_key, \ + _validate as _validate_header +from websocket._http import read_headers +from websocket._url import get_proxy_info, parse_url +from websocket._utils import validate_utf8 + +if six.PY3: + from base64 import decodebytes as base64decode +else: + from base64 import decodestring as base64decode + +if sys.version_info[0] == 2 and sys.version_info[1] < 7: + import unittest2 as unittest +else: + import unittest + +try: + from ssl import SSLError +except ImportError: + # dummy class of SSLError for ssl none-support environment. + class SSLError(Exception): + pass + +# Skip test to access the internet. +TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1' + +# Skip Secure WebSocket test. +TEST_SECURE_WS = True +TRACEABLE = True + + +def create_mask_key(_): + return "abcd" + + +class SockMock(object): + def __init__(self): + self.data = [] + self.sent = [] + + def add_packet(self, data): + self.data.append(data) + + def gettimeout(self): + return None + + def recv(self, bufsize): + if self.data: + e = self.data.pop(0) + if isinstance(e, Exception): + raise e + if len(e) > bufsize: + self.data.insert(0, e[bufsize:]) + return e[:bufsize] + + def send(self, data): + self.sent.append(data) + return len(data) + + def close(self): + pass + + +class HeaderSockMock(SockMock): + + def __init__(self, fname): + SockMock.__init__(self) + path = os.path.join(os.path.dirname(__file__), fname) + with open(path, "rb") as f: + self.add_packet(f.read()) + + +class WebSocketTest(unittest.TestCase): + def setUp(self): + ws.enableTrace(TRACEABLE) + + def tearDown(self): + pass + + def testDefaultTimeout(self): + self.assertEqual(ws.getdefaulttimeout(), None) + ws.setdefaulttimeout(10) + self.assertEqual(ws.getdefaulttimeout(), 10) + ws.setdefaulttimeout(None) + + def testParseUrl(self): + p = parse_url("ws://www.example.com/r") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], False) + + p = parse_url("ws://www.example.com/r/") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/r/") + self.assertEqual(p[3], False) + + p = parse_url("ws://www.example.com/") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/") + self.assertEqual(p[3], False) + + p = parse_url("ws://www.example.com") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/") + self.assertEqual(p[3], False) + + p = parse_url("ws://www.example.com:8080/r") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], False) + + p = parse_url("ws://www.example.com:8080/") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/") + self.assertEqual(p[3], False) + + p = parse_url("ws://www.example.com:8080") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/") + self.assertEqual(p[3], False) + + p = parse_url("wss://www.example.com:8080/r") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], True) + + p = parse_url("wss://www.example.com:8080/r?key=value") + self.assertEqual(p[0], "www.example.com") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r?key=value") + self.assertEqual(p[3], True) + + self.assertRaises(ValueError, parse_url, "http://www.example.com/r") + + if sys.version_info[0] == 2 and sys.version_info[1] < 7: + return + + p = parse_url("ws://[2a03:4000:123:83::3]/r") + self.assertEqual(p[0], "2a03:4000:123:83::3") + self.assertEqual(p[1], 80) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], False) + + p = parse_url("ws://[2a03:4000:123:83::3]:8080/r") + self.assertEqual(p[0], "2a03:4000:123:83::3") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], False) + + p = parse_url("wss://[2a03:4000:123:83::3]/r") + self.assertEqual(p[0], "2a03:4000:123:83::3") + self.assertEqual(p[1], 443) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], True) + + p = parse_url("wss://[2a03:4000:123:83::3]:8080/r") + self.assertEqual(p[0], "2a03:4000:123:83::3") + self.assertEqual(p[1], 8080) + self.assertEqual(p[2], "/r") + self.assertEqual(p[3], True) + + def testWSKey(self): + key = _create_sec_websocket_key() + self.assertTrue(key != 24) + self.assertTrue(six.u("Â¥n") not in key) + + def testWsUtils(self): + key = "c6b8hTg4EeGb2gQMztV1/g==" + required_header = { + "upgrade": "websocket", + "connection": "upgrade", + "sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0=", + } + self.assertEqual(_validate_header(required_header, key, None), (True, None)) + + header = required_header.copy() + header["upgrade"] = "http" + self.assertEqual(_validate_header(header, key, None), (False, None)) + del header["upgrade"] + self.assertEqual(_validate_header(header, key, None), (False, None)) + + header = required_header.copy() + header["connection"] = "something" + self.assertEqual(_validate_header(header, key, None), (False, None)) + del header["connection"] + self.assertEqual(_validate_header(header, key, None), (False, None)) + + header = required_header.copy() + header["sec-websocket-accept"] = "something" + self.assertEqual(_validate_header(header, key, None), (False, None)) + del header["sec-websocket-accept"] + self.assertEqual(_validate_header(header, key, None), (False, None)) + + header = required_header.copy() + header["sec-websocket-protocol"] = "sub1" + self.assertEqual(_validate_header(header, key, ["sub1", "sub2"]), (True, "sub1")) + self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None)) + + header = required_header.copy() + header["sec-websocket-protocol"] = "sUb1" + self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (True, "sub1")) + + + def testReadHeader(self): + status, header, status_message = read_headers(HeaderSockMock("data/header01.txt")) + self.assertEqual(status, 101) + self.assertEqual(header["connection"], "Upgrade") + + HeaderSockMock("data/header02.txt") + self.assertRaises(ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt")) + + def testSend(self): + # TODO: add longer frame data + sock = ws.WebSocket() + sock.set_mask_key(create_mask_key) + s = sock.sock = HeaderSockMock("data/header01.txt") + sock.send("Hello") + self.assertEqual(s.sent[0], six.b("\x81\x85abcd)\x07\x0f\x08\x0e")) + + sock.send("ã“ã‚“ã«ã¡ã¯") + self.assertEqual(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")) + + sock.send(u"ã“ã‚“ã«ã¡ã¯") + self.assertEqual(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")) + + sock.send("x" * 127) + + def testRecv(self): + # TODO: add longer frame data + sock = ws.WebSocket() + s = sock.sock = SockMock() + something = six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc") + s.add_packet(something) + data = sock.recv() + self.assertEqual(data, "ã“ã‚“ã«ã¡ã¯") + + s.add_packet(six.b("\x81\x85abcd)\x07\x0f\x08\x0e")) + data = sock.recv() + self.assertEqual(data, "Hello") + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def testIter(self): + count = 2 + for _ in ws.create_connection('ws://stream.meetup.com/2/rsvps'): + count -= 1 + if count == 0: + break + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def testNext(self): + sock = ws.create_connection('ws://stream.meetup.com/2/rsvps') + self.assertEqual(str, type(next(sock))) + + def testInternalRecvStrict(self): + sock = ws.WebSocket() + s = sock.sock = SockMock() + s.add_packet(six.b("foo")) + s.add_packet(socket.timeout()) + s.add_packet(six.b("bar")) + # s.add_packet(SSLError("The read operation timed out")) + s.add_packet(six.b("baz")) + with self.assertRaises(ws.WebSocketTimeoutException): + sock.frame_buffer.recv_strict(9) + # if six.PY2: + # with self.assertRaises(ws.WebSocketTimeoutException): + # data = sock._recv_strict(9) + # else: + # with self.assertRaises(SSLError): + # data = sock._recv_strict(9) + data = sock.frame_buffer.recv_strict(9) + self.assertEqual(data, six.b("foobarbaz")) + with self.assertRaises(ws.WebSocketConnectionClosedException): + sock.frame_buffer.recv_strict(1) + + def testRecvTimeout(self): + sock = ws.WebSocket() + s = sock.sock = SockMock() + s.add_packet(six.b("\x81")) + s.add_packet(socket.timeout()) + s.add_packet(six.b("\x8dabcd\x29\x07\x0f\x08\x0e")) + s.add_packet(socket.timeout()) + s.add_packet(six.b("\x4e\x43\x33\x0e\x10\x0f\x00\x40")) + with self.assertRaises(ws.WebSocketTimeoutException): + sock.recv() + with self.assertRaises(ws.WebSocketTimeoutException): + sock.recv() + data = sock.recv() + self.assertEqual(data, "Hello, World!") + with self.assertRaises(ws.WebSocketConnectionClosedException): + sock.recv() + + def testRecvWithSimpleFragmentation(self): + sock = ws.WebSocket() + s = sock.sock = SockMock() + # OPCODE=TEXT, FIN=0, MSG="Brevity is " + s.add_packet(six.b("\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")) + # OPCODE=CONT, FIN=1, MSG="the soul of wit" + s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")) + data = sock.recv() + self.assertEqual(data, "Brevity is the soul of wit") + with self.assertRaises(ws.WebSocketConnectionClosedException): + sock.recv() + + def testRecvWithFireEventOfFragmentation(self): + sock = ws.WebSocket(fire_cont_frame=True) + s = sock.sock = SockMock() + # OPCODE=TEXT, FIN=0, MSG="Brevity is " + s.add_packet(six.b("\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")) + # OPCODE=CONT, FIN=0, MSG="Brevity is " + s.add_packet(six.b("\x00\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")) + # OPCODE=CONT, FIN=1, MSG="the soul of wit" + s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")) + + _, data = sock.recv_data() + self.assertEqual(data, six.b("Brevity is ")) + _, data = sock.recv_data() + self.assertEqual(data, six.b("Brevity is ")) + _, data = sock.recv_data() + self.assertEqual(data, six.b("the soul of wit")) + + # OPCODE=CONT, FIN=0, MSG="Brevity is " + s.add_packet(six.b("\x80\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")) + + with self.assertRaises(ws.WebSocketException): + sock.recv_data() + + with self.assertRaises(ws.WebSocketConnectionClosedException): + sock.recv() + + def testClose(self): + sock = ws.WebSocket() + sock.sock = SockMock() + sock.connected = True + sock.close() + self.assertEqual(sock.connected, False) + + sock = ws.WebSocket() + s = sock.sock = SockMock() + sock.connected = True + s.add_packet(six.b('\x88\x80\x17\x98p\x84')) + sock.recv() + self.assertEqual(sock.connected, False) + + def testRecvContFragmentation(self): + sock = ws.WebSocket() + s = sock.sock = SockMock() + # OPCODE=CONT, FIN=1, MSG="the soul of wit" + s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")) + self.assertRaises(ws.WebSocketException, sock.recv) + + def testRecvWithProlongedFragmentation(self): + sock = ws.WebSocket() + s = sock.sock = SockMock() + # OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, " + s.add_packet(six.b("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15" + "\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC")) + # OPCODE=CONT, FIN=0, MSG="dear friends, " + s.add_packet(six.b("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07" + "\x17MB")) + # OPCODE=CONT, FIN=1, MSG="once more" + s.add_packet(six.b("\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04")) + data = sock.recv() + self.assertEqual( + data, + "Once more unto the breach, dear friends, once more") + with self.assertRaises(ws.WebSocketConnectionClosedException): + sock.recv() + + def testRecvWithFragmentationAndControlFrame(self): + sock = ws.WebSocket() + sock.set_mask_key(create_mask_key) + s = sock.sock = SockMock() + # OPCODE=TEXT, FIN=0, MSG="Too much " + s.add_packet(six.b("\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA")) + # OPCODE=PING, FIN=1, MSG="Please PONG this" + s.add_packet(six.b("\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")) + # OPCODE=CONT, FIN=1, MSG="of a good thing" + s.add_packet(six.b("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c" + "\x08\x0c\x04")) + data = sock.recv() + self.assertEqual(data, "Too much of a good thing") + with self.assertRaises(ws.WebSocketConnectionClosedException): + sock.recv() + self.assertEqual( + s.sent[0], + six.b("\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")) + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def testWebSocket(self): + s = ws.create_connection("ws://echo.websocket.org/") + self.assertNotEqual(s, None) + s.send("Hello, World") + result = s.recv() + self.assertEqual(result, "Hello, World") + + s.send(u"ã“ã«ã‚ƒã«ã‚ƒã¡ã¯ã€ä¸–界") + result = s.recv() + self.assertEqual(result, "ã“ã«ã‚ƒã«ã‚ƒã¡ã¯ã€ä¸–界") + s.close() + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def testPingPong(self): + s = ws.create_connection("ws://echo.websocket.org/") + self.assertNotEqual(s, None) + s.ping("Hello") + s.pong("Hi") + s.close() + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + @unittest.skipUnless(TEST_SECURE_WS, "wss://echo.websocket.org doesn't work well.") + def testSecureWebSocket(self): + if 1: + import ssl + s = ws.create_connection("wss://echo.websocket.org/") + self.assertNotEqual(s, None) + self.assertTrue(isinstance(s.sock, ssl.SSLSocket)) + s.send("Hello, World") + result = s.recv() + self.assertEqual(result, "Hello, World") + s.send(u"ã“ã«ã‚ƒã«ã‚ƒã¡ã¯ã€ä¸–界") + result = s.recv() + self.assertEqual(result, "ã“ã«ã‚ƒã«ã‚ƒã¡ã¯ã€ä¸–界") + s.close() + #except: + # pass + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def testWebSocketWihtCustomHeader(self): + s = ws.create_connection("ws://echo.websocket.org/", + headers={"User-Agent": "PythonWebsocketClient"}) + self.assertNotEqual(s, None) + s.send("Hello, World") + result = s.recv() + self.assertEqual(result, "Hello, World") + s.close() + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def testAfterClose(self): + s = ws.create_connection("ws://echo.websocket.org/") + self.assertNotEqual(s, None) + s.close() + self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello") + self.assertRaises(ws.WebSocketConnectionClosedException, s.recv) + + def testNonce(self): + """ WebSocket key should be a random 16-byte nonce. + """ + key = _create_sec_websocket_key() + nonce = base64decode(key.encode("utf-8")) + self.assertEqual(16, len(nonce)) + + +class WebSocketAppTest(unittest.TestCase): + + class NotSetYet(object): + """ A marker class for signalling that a value hasn't been set yet. + """ + + def setUp(self): + ws.enableTrace(TRACEABLE) + + WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet() + WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet() + WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet() + + def tearDown(self): + WebSocketAppTest.keep_running_open = WebSocketAppTest.NotSetYet() + WebSocketAppTest.keep_running_close = WebSocketAppTest.NotSetYet() + WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet() + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def testKeepRunning(self): + """ A WebSocketApp should keep running as long as its self.keep_running + is not False (in the boolean context). + """ + + def on_open(self, *args, **kwargs): + """ Set the keep_running flag for later inspection and immediately + close the connection. + """ + WebSocketAppTest.keep_running_open = self.keep_running + + self.close() + + def on_close(self, *args, **kwargs): + """ Set the keep_running flag for the test to use. + """ + WebSocketAppTest.keep_running_close = self.keep_running + + app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close) + app.run_forever() + + # if numpy is installed, this assertion fail + # self.assertFalse(isinstance(WebSocketAppTest.keep_running_open, + # WebSocketAppTest.NotSetYet)) + + # self.assertFalse(isinstance(WebSocketAppTest.keep_running_close, + # WebSocketAppTest.NotSetYet)) + + # self.assertEqual(True, WebSocketAppTest.keep_running_open) + # self.assertEqual(False, WebSocketAppTest.keep_running_close) + + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def testSockMaskKey(self): + """ A WebSocketApp should forward the received mask_key function down + to the actual socket. + """ + + def my_mask_key_func(): + pass + + def on_open(self, *args, **kwargs): + """ Set the value so the test can use it later on and immediately + close the connection. + """ + WebSocketAppTest.get_mask_key_id = id(self.get_mask_key) + self.close() + + app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, get_mask_key=my_mask_key_func) + app.run_forever() + + # if numpu is installed, this assertion fail + # Note: We can't use 'is' for comparing the functions directly, need to use 'id'. + # self.assertEqual(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func)) + + +class SockOptTest(unittest.TestCase): + @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") + def testSockOpt(self): + sockopt = ((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),) + s = ws.create_connection("ws://echo.websocket.org", sockopt=sockopt) + self.assertNotEqual(s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0) + s.close() + + +class UtilsTest(unittest.TestCase): + def testUtf8Validator(self): + state = validate_utf8(six.b('\xf0\x90\x80\x80')) + self.assertEqual(state, True) + state = validate_utf8(six.b('\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5\xed\xa0\x80edited')) + self.assertEqual(state, False) + state = validate_utf8(six.b('')) + self.assertEqual(state, True) + + +class ProxyInfoTest(unittest.TestCase): + def setUp(self): + self.http_proxy = os.environ.get("http_proxy", None) + self.https_proxy = os.environ.get("https_proxy", None) + if "http_proxy" in os.environ: + del os.environ["http_proxy"] + if "https_proxy" in os.environ: + del os.environ["https_proxy"] + + def tearDown(self): + if self.http_proxy: + os.environ["http_proxy"] = self.http_proxy + elif "http_proxy" in os.environ: + del os.environ["http_proxy"] + + if self.https_proxy: + os.environ["https_proxy"] = self.https_proxy + elif "https_proxy" in os.environ: + del os.environ["https_proxy"] + + def testProxyFromArgs(self): + self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost"), ("localhost", 0, None)) + self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128), ("localhost", 3128, None)) + self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost"), ("localhost", 0, None)) + self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128), ("localhost", 3128, None)) + + self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_auth=("a", "b")), + ("localhost", 0, ("a", "b"))) + self.assertEqual(get_proxy_info("echo.websocket.org", False, proxy_host="localhost", proxy_port=3128, proxy_auth=("a", "b")), + ("localhost", 3128, ("a", "b"))) + self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_auth=("a", "b")), + ("localhost", 0, ("a", "b"))) + self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, proxy_auth=("a", "b")), + ("localhost", 3128, ("a", "b"))) + + self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, no_proxy=["example.com"], proxy_auth=("a", "b")), + ("localhost", 3128, ("a", "b"))) + self.assertEqual(get_proxy_info("echo.websocket.org", True, proxy_host="localhost", proxy_port=3128, no_proxy=["echo.websocket.org"], proxy_auth=("a", "b")), + (None, 0, None)) + + def testProxyFromEnv(self): + os.environ["http_proxy"] = "http://localhost/" + self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, None)) + os.environ["http_proxy"] = "http://localhost:3128/" + self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, None)) + + os.environ["http_proxy"] = "http://localhost/" + os.environ["https_proxy"] = "http://localhost2/" + self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, None)) + os.environ["http_proxy"] = "http://localhost:3128/" + os.environ["https_proxy"] = "http://localhost2:3128/" + self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, None)) + + os.environ["http_proxy"] = "http://localhost/" + os.environ["https_proxy"] = "http://localhost2/" + self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", None, None)) + os.environ["http_proxy"] = "http://localhost:3128/" + os.environ["https_proxy"] = "http://localhost2:3128/" + self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", 3128, None)) + + + os.environ["http_proxy"] = "http://a:b@localhost/" + self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, ("a", "b"))) + os.environ["http_proxy"] = "http://a:b@localhost:3128/" + self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, ("a", "b"))) + + os.environ["http_proxy"] = "http://a:b@localhost/" + os.environ["https_proxy"] = "http://a:b@localhost2/" + self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", None, ("a", "b"))) + os.environ["http_proxy"] = "http://a:b@localhost:3128/" + os.environ["https_proxy"] = "http://a:b@localhost2:3128/" + self.assertEqual(get_proxy_info("echo.websocket.org", False), ("localhost", 3128, ("a", "b"))) + + os.environ["http_proxy"] = "http://a:b@localhost/" + os.environ["https_proxy"] = "http://a:b@localhost2/" + self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", None, ("a", "b"))) + os.environ["http_proxy"] = "http://a:b@localhost:3128/" + os.environ["https_proxy"] = "http://a:b@localhost2:3128/" + self.assertEqual(get_proxy_info("echo.websocket.org", True), ("localhost2", 3128, ("a", "b"))) + + os.environ["http_proxy"] = "http://a:b@localhost/" + os.environ["https_proxy"] = "http://a:b@localhost2/" + os.environ["no_proxy"] = "example1.com,example2.com" + self.assertEqual(get_proxy_info("example.1.com", True), ("localhost2", None, ("a", "b"))) + os.environ["http_proxy"] = "http://a:b@localhost:3128/" + os.environ["https_proxy"] = "http://a:b@localhost2:3128/" + os.environ["no_proxy"] = "example1.com,example2.com, echo.websocket.org" + self.assertEqual(get_proxy_info("echo.websocket.org", True), (None, 0, None)) + + os.environ["http_proxy"] = "http://a:b@localhost:3128/" + os.environ["https_proxy"] = "http://a:b@localhost2:3128/" + os.environ["no_proxy"] = "127.0.0.0/8, 192.168.0.0/16" + self.assertEqual(get_proxy_info("127.0.0.1", False), (None, 0, None)) + self.assertEqual(get_proxy_info("192.168.1.1", False), (None, 0, None)) + + +if __name__ == "__main__": + unittest.main()