Commit 54dae43b authored by l2m2's avatar l2m2

upload.

parent 3ca68ace
...@@ -26,4 +26,4 @@ from ._exceptions import * ...@@ -26,4 +26,4 @@ from ._exceptions import *
from ._logging import * from ._logging import *
from ._socket import * from ._socket import *
__version__ = "0.47.0" __version__ = "0.57.0"
...@@ -23,6 +23,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris) ...@@ -23,6 +23,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
""" """
WebSocketApp provides higher level APIs. WebSocketApp provides higher level APIs.
""" """
import inspect
import select import select
import sys import sys
import threading import threading
...@@ -44,23 +45,27 @@ class Dispatcher: ...@@ -44,23 +45,27 @@ class Dispatcher:
self.app = app self.app = app
self.ping_timeout = ping_timeout self.ping_timeout = ping_timeout
def read(self, sock, callback): def read(self, sock, read_callback, check_callback):
while self.app.sock.connected: while self.app.keep_running:
r, w, e = select.select( r, w, e = select.select(
(self.app.sock.sock, ), (), (), self.ping_timeout) # Use a 10 second timeout to avoid to wait forever on close (self.app.sock.sock, ), (), (), self.ping_timeout)
if r: if r:
callback() if not read_callback():
break
check_callback()
class SSLDispacther: class SSLDispatcher:
def __init__(self, app, ping_timeout): def __init__(self, app, ping_timeout):
self.app = app self.app = app
self.ping_timeout = ping_timeout self.ping_timeout = ping_timeout
def read(self, sock, callback): def read(self, sock, read_callback, check_callback):
while self.app.sock.connected: while self.app.keep_running:
r = self.select() r = self.select()
if r: if r:
callback() if not read_callback():
break
check_callback()
def select(self): def select(self):
sock = self.app.sock.sock sock = self.app.sock.sock
...@@ -70,6 +75,7 @@ class SSLDispacther: ...@@ -70,6 +75,7 @@ class SSLDispacther:
r, w, e = select.select((sock, ), (), (), self.ping_timeout) r, w, e = select.select((sock, ), (), (), self.ping_timeout)
return r return r
class WebSocketApp(object): class WebSocketApp(object):
""" """
Higher level of APIs are provided. Higher level of APIs are provided.
...@@ -113,7 +119,7 @@ class WebSocketApp(object): ...@@ -113,7 +119,7 @@ class WebSocketApp(object):
The 2nd argument is utf-8 string which we get from the server. The 2nd argument is utf-8 string which we get from the server.
The 3rd argument is data type. ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY will be came. The 3rd argument is data type. ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY will be came.
The 4th argument is continue flag. if 0, the data continue The 4th argument is continue flag. if 0, the data continue
keep_running: this parameter is obosleted and ignored it. keep_running: this parameter is obsolete and ignored.
get_mask_key: a callable to produce new mask keys, get_mask_key: a callable to produce new mask keys,
see the WebSocket.set_mask_key's docstring for more information see the WebSocket.set_mask_key's docstring for more information
subprotocols: array of available sub protocols. default is None. subprotocols: array of available sub protocols. default is None.
...@@ -121,6 +127,7 @@ class WebSocketApp(object): ...@@ -121,6 +127,7 @@ class WebSocketApp(object):
self.url = url self.url = url
self.header = header if header is not None else [] self.header = header if header is not None else []
self.cookie = cookie self.cookie = cookie
self.on_open = on_open self.on_open = on_open
self.on_message = on_message self.on_message = on_message
self.on_data = on_data self.on_data = on_data
...@@ -155,6 +162,7 @@ class WebSocketApp(object): ...@@ -155,6 +162,7 @@ class WebSocketApp(object):
self.keep_running = False self.keep_running = False
if self.sock: if self.sock:
self.sock.close(**kwargs) self.sock.close(**kwargs)
self.sock = None
def _send_ping(self, interval, event): def _send_ping(self, interval, event):
while not event.wait(interval): while not event.wait(interval):
...@@ -171,7 +179,8 @@ class WebSocketApp(object): ...@@ -171,7 +179,8 @@ class WebSocketApp(object):
http_proxy_host=None, http_proxy_port=None, http_proxy_host=None, http_proxy_port=None,
http_no_proxy=None, http_proxy_auth=None, http_no_proxy=None, http_proxy_auth=None,
skip_utf8_validation=False, skip_utf8_validation=False,
host=None, origin=None, dispatcher=None): host=None, origin=None, dispatcher=None,
suppress_origin=False, proxy_type=None):
""" """
run event loop for WebSocket framework. run event loop for WebSocket framework.
This loop is infinite loop and is alive during websocket is available. This loop is infinite loop and is alive during websocket is available.
...@@ -189,31 +198,41 @@ class WebSocketApp(object): ...@@ -189,31 +198,41 @@ class WebSocketApp(object):
skip_utf8_validation: skip utf8 validation. skip_utf8_validation: skip utf8 validation.
host: update host header. host: update host header.
origin: update origin header. origin: update origin header.
dispatcher: customize reading data from socket.
suppress_origin: suppress outputting origin header.
Returns
-------
False if caught KeyboardInterrupt
True if other exception was raised during a loop
""" """
if not ping_timeout or ping_timeout <= 0: if ping_timeout is not None and ping_timeout <= 0:
ping_timeout = None ping_timeout = None
if ping_timeout and ping_interval and ping_interval <= ping_timeout: if ping_timeout and ping_interval and ping_interval <= ping_timeout:
raise WebSocketException("Ensure ping_interval > ping_timeout") raise WebSocketException("Ensure ping_interval > ping_timeout")
if sockopt is None: if not sockopt:
sockopt = [] sockopt = []
if sslopt is None: if not sslopt:
sslopt = {} sslopt = {}
if self.sock: if self.sock:
raise WebSocketException("socket is already opened") raise WebSocketException("socket is already opened")
thread = None thread = None
close_frame = None
self.keep_running = True self.keep_running = True
self.last_ping_tm = 0 self.last_ping_tm = 0
self.last_pong_tm = 0 self.last_pong_tm = 0
def teardown(): def teardown(close_frame=None):
if not self.keep_running: """
return Tears down the connection.
If close_frame is set, we will invoke the on_close handler with the
statusCode and reason from there.
"""
if thread and thread.isAlive(): if thread and thread.isAlive():
event.set() event.set()
thread.join() thread.join()
self.keep_running = False self.keep_running = False
if self.sock:
self.sock.close() self.sock.close()
close_args = self._get_close_args( close_args = self._get_close_args(
close_frame.data if close_frame else None) close_frame.data if close_frame else None)
...@@ -223,15 +242,17 @@ class WebSocketApp(object): ...@@ -223,15 +242,17 @@ class WebSocketApp(object):
try: try:
self.sock = WebSocket( self.sock = WebSocket(
self.get_mask_key, sockopt=sockopt, sslopt=sslopt, self.get_mask_key, sockopt=sockopt, sslopt=sslopt,
fire_cont_frame=self.on_cont_message and True or False, fire_cont_frame=self.on_cont_message is not None,
skip_utf8_validation=skip_utf8_validation) skip_utf8_validation=skip_utf8_validation,
enable_multithread=True if ping_interval else False)
self.sock.settimeout(getdefaulttimeout()) self.sock.settimeout(getdefaulttimeout())
self.sock.connect( self.sock.connect(
self.url, header=self.header, cookie=self.cookie, self.url, header=self.header, cookie=self.cookie,
http_proxy_host=http_proxy_host, http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy, http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy,
http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols, http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols,
host=host, origin=origin) host=host, origin=origin, suppress_origin=suppress_origin,
proxy_type=proxy_type)
if not dispatcher: if not dispatcher:
dispatcher = self.create_dispatcher(ping_timeout) dispatcher = self.create_dispatcher(ping_timeout)
...@@ -250,8 +271,7 @@ class WebSocketApp(object): ...@@ -250,8 +271,7 @@ class WebSocketApp(object):
op_code, frame = self.sock.recv_data_frame(True) op_code, frame = self.sock.recv_data_frame(True)
if op_code == ABNF.OPCODE_CLOSE: if op_code == ABNF.OPCODE_CLOSE:
close_frame = frame return teardown(frame)
return teardown()
elif op_code == ABNF.OPCODE_PING: elif op_code == ABNF.OPCODE_PING:
self._callback(self.on_ping, frame.data) self._callback(self.on_ping, frame.data)
elif op_code == ABNF.OPCODE_PONG: elif op_code == ABNF.OPCODE_PONG:
...@@ -269,31 +289,39 @@ class WebSocketApp(object): ...@@ -269,31 +289,39 @@ class WebSocketApp(object):
self._callback(self.on_data, data, frame.opcode, True) self._callback(self.on_data, data, frame.opcode, True)
self._callback(self.on_message, data) self._callback(self.on_message, data)
if ping_timeout and self.last_ping_tm \ return True
and time.time() - self.last_ping_tm > ping_timeout \
and self.last_ping_tm - self.last_pong_tm > ping_timeout: def check():
if (ping_timeout):
has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout
has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0
has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout
if (self.last_ping_tm
and has_timeout_expired
and (has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)):
raise WebSocketTimeoutException("ping/pong timed out") raise WebSocketTimeoutException("ping/pong timed out")
return True return True
dispatcher.read(self.sock.sock, read) dispatcher.read(self.sock.sock, read, check)
except (Exception, KeyboardInterrupt, SystemExit) as e: except (Exception, KeyboardInterrupt, SystemExit) as e:
self._callback(self.on_error, e) self._callback(self.on_error, e)
if isinstance(e, SystemExit): if isinstance(e, SystemExit):
# propagate SystemExit further # propagate SystemExit further
raise raise
teardown() teardown()
return not isinstance(e, KeyboardInterrupt)
def create_dispatcher(self, ping_timeout): def create_dispatcher(self, ping_timeout):
timeout = ping_timeout or 10 timeout = ping_timeout or 10
if self.sock.is_ssl(): if self.sock.is_ssl():
return SSLDispacther(self, timeout) return SSLDispatcher(self, timeout)
return Dispatcher(self, timeout) return Dispatcher(self, timeout)
def _get_close_args(self, data): def _get_close_args(self, data):
""" this functions extracts the code, reason from the close body """ this functions extracts the code, reason from the close body
if they exists, and if the self.on_close except three arguments """ if they exists, and if the self.on_close except three arguments """
import inspect
# if the on_close callback is "old", just return empty list # if the on_close callback is "old", just return empty list
if sys.version_info < (3, 0): if sys.version_info < (3, 0):
if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3: if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3:
...@@ -312,7 +340,11 @@ class WebSocketApp(object): ...@@ -312,7 +340,11 @@ class WebSocketApp(object):
def _callback(self, callback, *args): def _callback(self, callback, *args):
if callback: if callback:
try: try:
if inspect.ismethod(callback):
callback(*args)
else:
callback(self, *args) callback(self, *args)
except Exception as e: except Exception as e:
_logging.error("error from callback {}: {}".format(callback, e)) _logging.error("error from callback {}: {}".format(callback, e))
if _logging.isEnabledForDebug(): if _logging.isEnabledForDebug():
......
...@@ -24,6 +24,7 @@ from __future__ import print_function ...@@ -24,6 +24,7 @@ from __future__ import print_function
import socket import socket
import struct import struct
import threading import threading
import time
import six import six
...@@ -201,6 +202,7 @@ class WebSocket(object): ...@@ -201,6 +202,7 @@ class WebSocket(object):
options: "header" -> custom http header list or dict. options: "header" -> custom http header list or dict.
"cookie" -> cookie value. "cookie" -> cookie value.
"origin" -> custom origin url. "origin" -> custom origin url.
"suppress_origin" -> suppress outputting origin header.
"host" -> custom host header string. "host" -> custom host header string.
"http_proxy_host" - http proxy host name. "http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80. "http_proxy_port" - http proxy port. If not set, set to 80.
...@@ -208,16 +210,27 @@ class WebSocket(object): ...@@ -208,16 +210,27 @@ class WebSocket(object):
"http_proxy_auth" - http proxy auth information. "http_proxy_auth" - http proxy auth information.
tuple of username and password. tuple of username and password.
default is None default is None
"redirect_limit" -> number of redirects to follow.
"subprotocols" - array of available sub protocols. "subprotocols" - array of available sub protocols.
default is None. default is None.
"socket" - pre-initialized stream socket. "socket" - pre-initialized stream socket.
""" """
# FIXME: "subprotocols" are getting lost, not passed down
# FIXME: "header", "cookie", "origin" and "host" too
self.sock_opt.timeout = options.get('timeout', self.sock_opt.timeout)
self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options), self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options),
options.pop('socket', None)) options.pop('socket', None))
try: try:
self.handshake_response = handshake(self.sock, *addrs, **options) self.handshake_response = handshake(self.sock, *addrs, **options)
for attempt in range(options.pop('redirect_limit', 3)):
if self.handshake_response.status in SUPPORTED_REDIRECT_STATUSES:
url = self.handshake_response.headers['location']
self.sock.close()
self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options),
options.pop('socket', None))
self.handshake_response = handshake(self.sock, *addrs, **options)
self.connected = True self.connected = True
except: except:
if self.sock: if self.sock:
...@@ -258,6 +271,7 @@ class WebSocket(object): ...@@ -258,6 +271,7 @@ class WebSocket(object):
frame.get_mask_key = self.get_mask_key frame.get_mask_key = self.get_mask_key
data = frame.format() data = frame.format()
length = len(data) length = len(data)
if (isEnabledForTrace()):
trace("send: " + repr(data)) trace("send: " + repr(data))
with self.lock: with self.lock:
...@@ -397,14 +411,19 @@ class WebSocket(object): ...@@ -397,14 +411,19 @@ class WebSocket(object):
reason, ABNF.OPCODE_CLOSE) reason, ABNF.OPCODE_CLOSE)
sock_timeout = self.sock.gettimeout() sock_timeout = self.sock.gettimeout()
self.sock.settimeout(timeout) self.sock.settimeout(timeout)
start_time = time.time()
while timeout is None or time.time() - start_time < timeout:
try: try:
frame = self.recv_frame() frame = self.recv_frame()
if frame.opcode != ABNF.OPCODE_CLOSE:
continue
if isEnabledForError(): if isEnabledForError():
recv_status = struct.unpack("!H", frame.data[0:2])[0] recv_status = struct.unpack("!H", frame.data[0:2])[0]
if recv_status != STATUS_NORMAL: if recv_status != STATUS_NORMAL:
error("close status: " + repr(recv_status)) error("close status: " + repr(recv_status))
break
except: except:
pass break
self.sock.settimeout(sock_timeout) self.sock.settimeout(sock_timeout)
self.sock.shutdown(socket.SHUT_RDWR) self.sock.shutdown(socket.SHUT_RDWR)
except: except:
...@@ -466,6 +485,7 @@ def create_connection(url, timeout=None, class_=WebSocket, **options): ...@@ -466,6 +485,7 @@ def create_connection(url, timeout=None, class_=WebSocket, **options):
options: "header" -> custom http header list or dict. options: "header" -> custom http header list or dict.
"cookie" -> cookie value. "cookie" -> cookie value.
"origin" -> custom origin url. "origin" -> custom origin url.
"suppress_origin" -> suppress outputting origin header.
"host" -> custom host header string. "host" -> custom host header string.
"http_proxy_host" - http proxy host name. "http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80. "http_proxy_port" - http proxy port. If not set, set to 80.
...@@ -474,6 +494,7 @@ def create_connection(url, timeout=None, class_=WebSocket, **options): ...@@ -474,6 +494,7 @@ def create_connection(url, timeout=None, class_=WebSocket, **options):
tuple of username and password. tuple of username and password.
default is None default is None
"enable_multithread" -> enable lock for multithread. "enable_multithread" -> enable lock for multithread.
"redirect_limit" -> number of redirects to follow.
"sockopt" -> socket options "sockopt" -> socket options
"sslopt" -> ssl option "sslopt" -> ssl option
"subprotocols" - array of available sub protocols. "subprotocols" - array of available sub protocols.
......
...@@ -74,11 +74,12 @@ class WebSocketBadStatusException(WebSocketException): ...@@ -74,11 +74,12 @@ class WebSocketBadStatusException(WebSocketException):
WebSocketBadStatusException will be raised when we get bad handshake status code. WebSocketBadStatusException will be raised when we get bad handshake status code.
""" """
def __init__(self, message, status_code, status_message=None): def __init__(self, message, status_code, status_message=None, resp_headers=None):
msg = message % (status_code, status_message) if status_message is not None \ msg = message % (status_code, status_message)
else message % status_code
super(WebSocketBadStatusException, self).__init__(msg) super(WebSocketBadStatusException, self).__init__(msg)
self.status_code = status_code self.status_code = status_code
self.resp_headers = resp_headers
class WebSocketAddressException(WebSocketException): class WebSocketAddressException(WebSocketException):
""" """
......
...@@ -31,12 +31,20 @@ from ._http import * ...@@ -31,12 +31,20 @@ from ._http import *
from ._logging import * from ._logging import *
from ._socket import * from ._socket import *
if six.PY3: if hasattr(six, 'PY3') and six.PY3:
from base64 import encodebytes as base64encode from base64 import encodebytes as base64encode
else: else:
from base64 import encodestring as base64encode from base64 import encodestring as base64encode
__all__ = ["handshake_response", "handshake"] if hasattr(six, 'PY3') and six.PY3:
if hasattr(six, 'PY34') and six.PY34:
from http import client as HTTPStatus
else:
from http import HTTPStatus
else:
import httplib as HTTPStatus
__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
if hasattr(hmac, "compare_digest"): if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest compare_digest = hmac.compare_digest
...@@ -47,6 +55,9 @@ else: ...@@ -47,6 +55,9 @@ else:
# websocket supported version. # websocket supported version.
VERSION = 13 VERSION = 13
SUPPORTED_REDIRECT_STATUSES = (HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER,)
SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,)
CookieJar = SimpleCookieJar() CookieJar = SimpleCookieJar()
...@@ -67,12 +78,15 @@ def handshake(sock, hostname, port, resource, **options): ...@@ -67,12 +78,15 @@ def handshake(sock, hostname, port, resource, **options):
dump("request header", header_str) dump("request header", header_str)
status, resp = _get_resp_headers(sock) status, resp = _get_resp_headers(sock)
if status in SUPPORTED_REDIRECT_STATUSES:
return handshake_response(status, resp, None)
success, subproto = _validate(resp, key, options.get("subprotocols")) success, subproto = _validate(resp, key, options.get("subprotocols"))
if not success: if not success:
raise WebSocketException("Invalid WebSocket Header") raise WebSocketException("Invalid WebSocket Header")
return handshake_response(status, resp, subproto) return handshake_response(status, resp, subproto)
def _pack_hostname(hostname): def _pack_hostname(hostname):
# IPv6 address # IPv6 address
if ':' in hostname: if ':' in hostname:
...@@ -83,28 +97,40 @@ def _pack_hostname(hostname): ...@@ -83,28 +97,40 @@ def _pack_hostname(hostname):
def _get_handshake_headers(resource, host, port, options): def _get_handshake_headers(resource, host, port, options):
headers = [ headers = [
"GET %s HTTP/1.1" % resource, "GET %s HTTP/1.1" % resource,
"Upgrade: websocket", "Upgrade: websocket"
"Connection: Upgrade"
] ]
if port == 80 or port == 443: if port == 80 or port == 443:
hostport = _pack_hostname(host) hostport = _pack_hostname(host)
else: else:
hostport = "%s:%d" % (_pack_hostname(host), port) hostport = "%s:%d" % (_pack_hostname(host), port)
if "host" in options and options["host"] is not None: if "host" in options and options["host"] is not None:
headers.append("Host: %s" % options["host"]) headers.append("Host: %s" % options["host"])
else: else:
headers.append("Host: %s" % hostport) headers.append("Host: %s" % hostport)
if "suppress_origin" not in options or not options["suppress_origin"]:
if "origin" in options and options["origin"] is not None: if "origin" in options and options["origin"] is not None:
headers.append("Origin: %s" % options["origin"]) headers.append("Origin: %s" % options["origin"])
else: else:
headers.append("Origin: http://%s" % hostport) headers.append("Origin: http://%s" % hostport)
key = _create_sec_websocket_key() key = _create_sec_websocket_key()
# Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
if not 'header' in options or 'Sec-WebSocket-Key' not in options['header']:
key = _create_sec_websocket_key()
headers.append("Sec-WebSocket-Key: %s" % key) headers.append("Sec-WebSocket-Key: %s" % key)
else:
key = options['header']['Sec-WebSocket-Key']
if not 'header' in options or 'Sec-WebSocket-Version' not in options['header']:
headers.append("Sec-WebSocket-Version: %s" % VERSION) headers.append("Sec-WebSocket-Version: %s" % VERSION)
if not 'connection' in options or options['connection'] is None:
headers.append('Connection: upgrade')
else:
headers.append(options['connection'])
subprotocols = options.get("subprotocols") subprotocols = options.get("subprotocols")
if subprotocols: if subprotocols:
headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols)) headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
...@@ -112,7 +138,11 @@ def _get_handshake_headers(resource, host, port, options): ...@@ -112,7 +138,11 @@ def _get_handshake_headers(resource, host, port, options):
if "header" in options: if "header" in options:
header = options["header"] header = options["header"]
if isinstance(header, dict): if isinstance(header, dict):
header = map(": ".join, header.items()) header = [
": ".join([k, v])
for k, v in header.items()
if v is not None
]
headers.extend(header) headers.extend(header)
server_cookie = CookieJar.get(host) server_cookie = CookieJar.get(host)
...@@ -129,12 +159,13 @@ def _get_handshake_headers(resource, host, port, options): ...@@ -129,12 +159,13 @@ def _get_handshake_headers(resource, host, port, options):
return headers, key return headers, key
def _get_resp_headers(sock, success_status=101): def _get_resp_headers(sock, success_statuses=SUCCESS_STATUSES):
status, resp_headers, status_message = read_headers(sock) status, resp_headers, status_message = read_headers(sock)
if status != success_status: if status not in success_statuses:
raise WebSocketBadStatusException("Handshake status %d %s", status, status_message) raise WebSocketBadStatusException("Handshake status %d %s", status, status_message, resp_headers)
return status, resp_headers return status, resp_headers
_HEADERS_TO_CHECK = { _HEADERS_TO_CHECK = {
"upgrade": "websocket", "upgrade": "websocket",
"connection": "upgrade", "connection": "upgrade",
......
...@@ -39,21 +39,72 @@ else: ...@@ -39,21 +39,72 @@ else:
__all__ = ["proxy_info", "connect", "read_headers"] __all__ = ["proxy_info", "connect", "read_headers"]
try:
import socks
ProxyConnectionError = socks.ProxyConnectionError
HAS_PYSOCKS = True
except:
class ProxyConnectionError(BaseException):
pass
HAS_PYSOCKS = False
class proxy_info(object): class proxy_info(object):
def __init__(self, **options): def __init__(self, **options):
self.type = options.get("proxy_type") or "http"
if not(self.type in ['http', 'socks4', 'socks5', 'socks5h']):
raise ValueError("proxy_type must be 'http', 'socks4', 'socks5' or 'socks5h'")
self.host = options.get("http_proxy_host", None) self.host = options.get("http_proxy_host", None)
if self.host: if self.host:
self.port = options.get("http_proxy_port", 0) self.port = options.get("http_proxy_port", 0)
self.auth = options.get("http_proxy_auth", None) self.auth = options.get("http_proxy_auth", None)
self.no_proxy = options.get("http_no_proxy", None)
else: else:
self.port = 0 self.port = 0
self.auth = None self.auth = None
self.no_proxy = options.get("http_no_proxy", None) self.no_proxy = None
def _open_proxied_socket(url, options, proxy):
hostname, port, resource, is_secure = parse_url(url)
if not HAS_PYSOCKS:
raise WebSocketException("PySocks module not found.")
ptype = socks.SOCKS5
rdns = False
if proxy.type == "socks4":
ptype = socks.SOCKS4
if proxy.type == "http":
ptype = socks.HTTP
if proxy.type[-1] == "h":
rdns = True
sock = socks.create_connection(
(hostname, port),
proxy_type = ptype,
proxy_addr = proxy.host,
proxy_port = proxy.port,
proxy_rdns = rdns,
proxy_username = proxy.auth[0] if proxy.auth else None,
proxy_password = proxy.auth[1] if proxy.auth else None,
timeout = options.timeout,
socket_options = DEFAULT_SOCKET_OPTION + options.sockopt
)
if is_secure:
if HAVE_SSL:
sock = _ssl_socket(sock, options.sslopt, hostname)
else:
raise WebSocketException("SSL not available.")
return sock, (hostname, port, resource)
def connect(url, options, proxy, socket): def connect(url, options, proxy, socket):
if proxy.host and not socket and not (proxy.type == 'http'):
return _open_proxied_socket(url, options, proxy)
hostname, port, resource, is_secure = parse_url(url) hostname, port, resource, is_secure = parse_url(url)
if socket: if socket:
...@@ -88,13 +139,20 @@ def _get_addrinfo_list(hostname, port, is_secure, proxy): ...@@ -88,13 +139,20 @@ def _get_addrinfo_list(hostname, port, is_secure, proxy):
phost, pport, pauth = get_proxy_info( phost, pport, pauth = get_proxy_info(
hostname, is_secure, proxy.host, proxy.port, proxy.auth, proxy.no_proxy) hostname, is_secure, proxy.host, proxy.port, proxy.auth, proxy.no_proxy)
try: try:
# when running on windows 10, getaddrinfo without socktype returns a socktype 0.
# This generates an error exception: `_on_error: exception Socket type must be stream or datagram, not 0`
# or `OSError: [Errno 22] Invalid argument` when creating socket. Force the socket type to SOCK_STREAM.
if not phost: if not phost:
addrinfo_list = socket.getaddrinfo( addrinfo_list = socket.getaddrinfo(
hostname, port, 0, 0, socket.SOL_TCP) hostname, port, 0, socket.SOCK_STREAM, socket.SOL_TCP)
return addrinfo_list, False, None return addrinfo_list, False, None
else: else:
pport = pport and pport or 80 pport = pport and pport or 80
addrinfo_list = socket.getaddrinfo(phost, pport, 0, 0, socket.SOL_TCP) # when running on windows 10, the getaddrinfo used above
# returns a socktype 0. This generates an error exception:
# _on_error: exception Socket type must be stream or datagram, not 0
# Force the socket type to SOCK_STREAM
addrinfo_list = socket.getaddrinfo(phost, pport, 0, socket.SOCK_STREAM, socket.SOL_TCP)
return addrinfo_list, True, pauth return addrinfo_list, True, pauth
except socket.gaierror as e: except socket.gaierror as e:
raise WebSocketAddressException(e) raise WebSocketAddressException(e)
...@@ -106,28 +164,40 @@ def _open_socket(addrinfo_list, sockopt, timeout): ...@@ -106,28 +164,40 @@ def _open_socket(addrinfo_list, sockopt, timeout):
family, socktype, proto = addrinfo[:3] family, socktype, proto = addrinfo[:3]
sock = socket.socket(family, socktype, proto) sock = socket.socket(family, socktype, proto)
sock.settimeout(timeout) sock.settimeout(timeout)
# for opts in DEFAULT_SOCKET_OPTION: for opts in DEFAULT_SOCKET_OPTION:
# sock.setsockopt(*opts) sock.setsockopt(*opts)
# for opts in sockopt: for opts in sockopt:
# sock.setsockopt(*opts) sock.setsockopt(*opts)
address = addrinfo[4] address = addrinfo[4]
err = None
while not err:
try: try:
sock.connect(address) sock.connect(address)
except ProxyConnectionError as error:
err = WebSocketProxyException(str(error))
err.remote_ip = str(address[0])
continue
except socket.error as error: except socket.error as error:
error.remote_ip = str(address[0]) error.remote_ip = str(address[0])
try: try:
eConnRefused = (errno.ECONNREFUSED, errno.WSAECONNREFUSED) eConnRefused = (errno.ECONNREFUSED, errno.WSAECONNREFUSED)
except: except:
eConnRefused = (errno.ECONNREFUSED, ) eConnRefused = (errno.ECONNREFUSED, )
if error.errno in eConnRefused: if error.errno == errno.EINTR:
continue
elif error.errno in eConnRefused:
err = error err = error
continue continue
else: else:
raise raise error
else:
break
else: else:
continue
break break
else: else:
if err:
raise err raise err
return sock return sock
...@@ -141,7 +211,12 @@ def _wrap_sni_socket(sock, sslopt, hostname, check_hostname): ...@@ -141,7 +211,12 @@ def _wrap_sni_socket(sock, sslopt, hostname, check_hostname):
context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23)) context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23))
if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE: if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE:
context.load_verify_locations(cafile=sslopt.get('ca_certs', None), capath=sslopt.get('ca_cert_path', None)) cafile = sslopt.get('ca_certs', None)
capath = sslopt.get('ca_cert_path', None)
if cafile or capath:
context.load_verify_locations(cafile=cafile, capath=capath)
elif hasattr(context, 'load_default_certs'):
context.load_default_certs(ssl.Purpose.SERVER_AUTH)
if sslopt.get('certfile', None): if sslopt.get('certfile', None):
context.load_cert_chain( context.load_cert_chain(
sslopt['certfile'], sslopt['certfile'],
...@@ -173,15 +248,13 @@ def _ssl_socket(sock, user_sslopt, hostname): ...@@ -173,15 +248,13 @@ def _ssl_socket(sock, user_sslopt, hostname):
sslopt = dict(cert_reqs=ssl.CERT_REQUIRED) sslopt = dict(cert_reqs=ssl.CERT_REQUIRED)
sslopt.update(user_sslopt) sslopt.update(user_sslopt)
if os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE'):
certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE') certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE')
else: if certPath and os.path.isfile(certPath) \
certPath = os.path.join( and user_sslopt.get('ca_certs', None) is None \
os.path.dirname(__file__), "cacert.pem")
if os.path.isfile(certPath) and user_sslopt.get('ca_certs', None) is None \
and user_sslopt.get('ca_cert', None) is None: and user_sslopt.get('ca_cert', None) is None:
sslopt['ca_certs'] = certPath sslopt['ca_certs'] = certPath
elif os.path.isdir(certPath) and user_sslopt.get('ca_cert_path', None) is None: elif certPath and os.path.isdir(certPath) \
and user_sslopt.get('ca_cert_path', None) is None:
sslopt['ca_cert_path'] = certPath sslopt['ca_cert_path'] = certPath
check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop( check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop(
...@@ -207,7 +280,7 @@ def _tunnel(sock, host, port, auth): ...@@ -207,7 +280,7 @@ def _tunnel(sock, host, port, auth):
auth_str = auth[0] auth_str = auth[0]
if auth[1]: if auth[1]:
auth_str += ":" + auth[1] auth_str += ":" + auth[1]
encoded_str = base64encode(auth_str.encode()).strip().decode() encoded_str = base64encode(auth_str.encode()).strip().decode().replace('\n', '')
connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str
connect_header += "\r\n" connect_header += "\r\n"
dump("request header", connect_header) dump("request header", connect_header)
...@@ -242,6 +315,7 @@ def read_headers(sock): ...@@ -242,6 +315,7 @@ def read_headers(sock):
status_info = line.split(" ", 2) status_info = line.split(" ", 2)
status = int(status_info[1]) status = int(status_info[1])
if len(status_info) > 2:
status_message = status_info[2] status_message = status_info[2]
else: else:
kv = line.split(":", 1) kv = line.split(":", 1)
......
...@@ -22,13 +22,22 @@ Copyright (C) 2010 Hiroki Ohtani(liris) ...@@ -22,13 +22,22 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import logging import logging
_logger = logging.getLogger('websocket') _logger = logging.getLogger('websocket')
try:
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
pass
_logger.addHandler(NullHandler())
_traceEnabled = False _traceEnabled = False
__all__ = ["enableTrace", "dump", "error", "warning", "debug", "trace", __all__ = ["enableTrace", "dump", "error", "warning", "debug", "trace",
"isEnabledForError", "isEnabledForDebug"] "isEnabledForError", "isEnabledForDebug", "isEnabledForTrace"]
def enableTrace(traceable): def enableTrace(traceable, handler = logging.StreamHandler()):
""" """
turn on/off the traceability. turn on/off the traceability.
...@@ -37,11 +46,9 @@ def enableTrace(traceable): ...@@ -37,11 +46,9 @@ def enableTrace(traceable):
global _traceEnabled global _traceEnabled
_traceEnabled = traceable _traceEnabled = traceable
if traceable: if traceable:
if not _logger.handlers: _logger.addHandler(handler)
_logger.addHandler(logging.StreamHandler())
_logger.setLevel(logging.DEBUG) _logger.setLevel(logging.DEBUG)
def dump(title, message): def dump(title, message):
if _traceEnabled: if _traceEnabled:
_logger.debug("--- " + title + " ---") _logger.debug("--- " + title + " ---")
...@@ -72,3 +79,6 @@ def isEnabledForError(): ...@@ -72,3 +79,6 @@ def isEnabledForError():
def isEnabledForDebug(): def isEnabledForDebug():
return _logger.isEnabledFor(logging.DEBUG) return _logger.isEnabledFor(logging.DEBUG)
def isEnabledForTrace():
return _traceEnabled
...@@ -19,6 +19,8 @@ Copyright (C) 2010 Hiroki Ohtani(liris) ...@@ -19,6 +19,8 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
import errno
import select
import socket import socket
import six import six
...@@ -77,8 +79,27 @@ def recv(sock, bufsize): ...@@ -77,8 +79,27 @@ def recv(sock, bufsize):
if not sock: if not sock:
raise WebSocketConnectionClosedException("socket is already closed.") raise WebSocketConnectionClosedException("socket is already closed.")
def _recv():
try: try:
return sock.recv(bufsize)
except SSLWantReadError:
pass
except socket.error as exc:
error_code = extract_error_code(exc)
if error_code is None:
raise
if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK:
raise
r, w, e = select.select((sock, ), (), (), sock.gettimeout())
if r:
return sock.recv(bufsize)
try:
if sock.gettimeout() == 0:
bytes_ = sock.recv(bufsize) bytes_ = sock.recv(bufsize)
else:
bytes_ = _recv()
except socket.timeout as e: except socket.timeout as e:
message = extract_err_message(e) message = extract_err_message(e)
raise WebSocketTimeoutException(message) raise WebSocketTimeoutException(message)
...@@ -113,8 +134,27 @@ def send(sock, data): ...@@ -113,8 +134,27 @@ def send(sock, data):
if not sock: if not sock:
raise WebSocketConnectionClosedException("socket is already closed.") raise WebSocketConnectionClosedException("socket is already closed.")
def _send():
try:
return sock.send(data)
except SSLWantWriteError:
pass
except socket.error as exc:
error_code = extract_error_code(exc)
if error_code is None:
raise
if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK:
raise
r, w, e = select.select((), (sock, ), (), sock.gettimeout())
if w:
return sock.send(data)
try: try:
if sock.gettimeout() == 0:
return sock.send(data) return sock.send(data)
else:
return _send()
except socket.timeout as e: except socket.timeout as e:
message = extract_err_message(e) message = extract_err_message(e)
raise WebSocketTimeoutException(message) raise WebSocketTimeoutException(message)
......
...@@ -19,11 +19,13 @@ Copyright (C) 2010 Hiroki Ohtani(liris) ...@@ -19,11 +19,13 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA Boston, MA 02110-1335 USA
""" """
__all__ = ["HAVE_SSL", "ssl", "SSLError"] __all__ = ["HAVE_SSL", "ssl", "SSLError", "SSLWantReadError", "SSLWantWriteError"]
try: try:
import ssl import ssl
from ssl import SSLError from ssl import SSLError
from ssl import SSLWantReadError
from ssl import SSLWantWriteError
if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'): if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'):
HAVE_CONTEXT_CHECK_HOSTNAME = True HAVE_CONTEXT_CHECK_HOSTNAME = True
else: else:
...@@ -41,4 +43,12 @@ except ImportError: ...@@ -41,4 +43,12 @@ except ImportError:
class SSLError(Exception): class SSLError(Exception):
pass pass
class SSLWantReadError(Exception):
pass
class SSLWantWriteError(Exception):
pass
ssl = lambda: None
HAVE_SSL = False HAVE_SSL = False
...@@ -103,6 +103,7 @@ def _is_address_in_network(ip, net): ...@@ -103,6 +103,7 @@ def _is_address_in_network(ip, net):
def _is_no_proxy_host(hostname, no_proxy): def _is_no_proxy_host(hostname, no_proxy):
if not no_proxy: if not no_proxy:
v = os.environ.get("no_proxy", "").replace(" ", "") v = os.environ.get("no_proxy", "").replace(" ", "")
if v:
no_proxy = v.split(",") no_proxy = v.split(",")
if not no_proxy: if not no_proxy:
no_proxy = DEFAULT_NO_PROXY_HOST no_proxy = DEFAULT_NO_PROXY_HOST
...@@ -117,7 +118,7 @@ def _is_no_proxy_host(hostname, no_proxy): ...@@ -117,7 +118,7 @@ def _is_no_proxy_host(hostname, no_proxy):
def get_proxy_info( def get_proxy_info(
hostname, is_secure, proxy_host=None, proxy_port=0, proxy_auth=None, hostname, is_secure, proxy_host=None, proxy_port=0, proxy_auth=None,
no_proxy=None): no_proxy=None, proxy_type='http'):
""" """
try to retrieve proxy host and port from environment try to retrieve proxy host and port from environment
if not provided in options. if not provided in options.
...@@ -137,6 +138,9 @@ def get_proxy_info( ...@@ -137,6 +138,9 @@ def get_proxy_info(
"http_proxy_auth" - http proxy auth information. "http_proxy_auth" - http proxy auth information.
tuple of username and password. tuple of username and password.
default is None default is None
"proxy_type" - if set to "socks5" PySocks wrapper
will be used in place of a http proxy.
default is "http"
""" """
if _is_no_proxy_host(hostname, no_proxy): if _is_no_proxy_host(hostname, no_proxy):
return None, 0, None return None, 0, None
......
...@@ -21,7 +21,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris) ...@@ -21,7 +21,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
""" """
import six import six
__all__ = ["NoLock", "validate_utf8", "extract_err_message"] __all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"]
class NoLock(object): class NoLock(object):
...@@ -32,6 +32,7 @@ class NoLock(object): ...@@ -32,6 +32,7 @@ class NoLock(object):
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
pass pass
try: try:
# If wsaccel is available we use compiled routines to validate UTF-8 # If wsaccel is available we use compiled routines to validate UTF-8
# strings. # strings.
...@@ -103,3 +104,8 @@ def extract_err_message(exception): ...@@ -103,3 +104,8 @@ def extract_err_message(exception):
return exception.args[0] return exception.args[0]
else: else:
return None return None
def extract_error_code(exception):
if exception.args and len(exception.args) > 1:
return exception.args[0] if isinstance(exception.args[0], int) else None
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade
Upgrade: WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
some_header: something
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade
Upgrade WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
some_header: something
import unittest
from websocket._cookiejar import SimpleCookieJar
try:
import Cookie
except:
import http.cookies as Cookie
class CookieJarTest(unittest.TestCase):
def testAdd(self):
cookie_jar = SimpleCookieJar()
cookie_jar.add("")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=.abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=xyz")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
self.assertEquals(cookie_jar.get("xyz"), "e=f")
self.assertEquals(cookie_jar.get("something"), "")
def testSet(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=.abc")
self.assertEquals(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=xyz")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
self.assertEquals(cookie_jar.get("xyz"), "e=f")
self.assertEquals(cookie_jar.get("something"), "")
def testGet(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc.com")
self.assertEquals(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("abc.com.es"), "")
self.assertEquals(cookie_jar.get("xabc.com"), "")
cookie_jar.set("a=b; c=d; domain=.abc.com")
self.assertEquals(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("abc.com.es"), "")
self.assertEquals(cookie_jar.get("xabc.com"), "")
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment