Commit 54dae43b authored by l2m2's avatar l2m2

upload.

parent 3ca68ace
......@@ -26,4 +26,4 @@ from ._exceptions import *
from ._logging import *
from ._socket import *
__version__ = "0.47.0"
__version__ = "0.57.0"
......@@ -23,6 +23,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
"""
WebSocketApp provides higher level APIs.
"""
import inspect
import select
import sys
import threading
......@@ -44,23 +45,27 @@ class Dispatcher:
self.app = app
self.ping_timeout = ping_timeout
def read(self, sock, callback):
while self.app.sock.connected:
def read(self, sock, read_callback, check_callback):
while self.app.keep_running:
r, w, e = select.select(
(self.app.sock.sock, ), (), (), self.ping_timeout) # Use a 10 second timeout to avoid to wait forever on close
(self.app.sock.sock, ), (), (), self.ping_timeout)
if r:
callback()
if not read_callback():
break
check_callback()
class SSLDispacther:
class SSLDispatcher:
def __init__(self, app, ping_timeout):
self.app = app
self.ping_timeout = ping_timeout
def read(self, sock, callback):
while self.app.sock.connected:
def read(self, sock, read_callback, check_callback):
while self.app.keep_running:
r = self.select()
if r:
callback()
if not read_callback():
break
check_callback()
def select(self):
sock = self.app.sock.sock
......@@ -70,6 +75,7 @@ class SSLDispacther:
r, w, e = select.select((sock, ), (), (), self.ping_timeout)
return r
class WebSocketApp(object):
"""
Higher level of APIs are provided.
......@@ -113,7 +119,7 @@ class WebSocketApp(object):
The 2nd argument is utf-8 string which we get from the server.
The 3rd argument is data type. ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY will be came.
The 4th argument is continue flag. if 0, the data continue
keep_running: this parameter is obosleted and ignored it.
keep_running: this parameter is obsolete and ignored.
get_mask_key: a callable to produce new mask keys,
see the WebSocket.set_mask_key's docstring for more information
subprotocols: array of available sub protocols. default is None.
......@@ -121,6 +127,7 @@ class WebSocketApp(object):
self.url = url
self.header = header if header is not None else []
self.cookie = cookie
self.on_open = on_open
self.on_message = on_message
self.on_data = on_data
......@@ -155,6 +162,7 @@ class WebSocketApp(object):
self.keep_running = False
if self.sock:
self.sock.close(**kwargs)
self.sock = None
def _send_ping(self, interval, event):
while not event.wait(interval):
......@@ -171,7 +179,8 @@ class WebSocketApp(object):
http_proxy_host=None, http_proxy_port=None,
http_no_proxy=None, http_proxy_auth=None,
skip_utf8_validation=False,
host=None, origin=None, dispatcher=None):
host=None, origin=None, dispatcher=None,
suppress_origin=False, proxy_type=None):
"""
run event loop for WebSocket framework.
This loop is infinite loop and is alive during websocket is available.
......@@ -189,31 +198,41 @@ class WebSocketApp(object):
skip_utf8_validation: skip utf8 validation.
host: update host header.
origin: update origin header.
dispatcher: customize reading data from socket.
suppress_origin: suppress outputting origin header.
Returns
-------
False if caught KeyboardInterrupt
True if other exception was raised during a loop
"""
if not ping_timeout or ping_timeout <= 0:
if ping_timeout is not None and ping_timeout <= 0:
ping_timeout = None
if ping_timeout and ping_interval and ping_interval <= ping_timeout:
raise WebSocketException("Ensure ping_interval > ping_timeout")
if sockopt is None:
if not sockopt:
sockopt = []
if sslopt is None:
if not sslopt:
sslopt = {}
if self.sock:
raise WebSocketException("socket is already opened")
thread = None
close_frame = None
self.keep_running = True
self.last_ping_tm = 0
self.last_pong_tm = 0
def teardown():
if not self.keep_running:
return
def teardown(close_frame=None):
"""
Tears down the connection.
If close_frame is set, we will invoke the on_close handler with the
statusCode and reason from there.
"""
if thread and thread.isAlive():
event.set()
thread.join()
self.keep_running = False
if self.sock:
self.sock.close()
close_args = self._get_close_args(
close_frame.data if close_frame else None)
......@@ -223,15 +242,17 @@ class WebSocketApp(object):
try:
self.sock = WebSocket(
self.get_mask_key, sockopt=sockopt, sslopt=sslopt,
fire_cont_frame=self.on_cont_message and True or False,
skip_utf8_validation=skip_utf8_validation)
fire_cont_frame=self.on_cont_message is not None,
skip_utf8_validation=skip_utf8_validation,
enable_multithread=True if ping_interval else False)
self.sock.settimeout(getdefaulttimeout())
self.sock.connect(
self.url, header=self.header, cookie=self.cookie,
http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy,
http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols,
host=host, origin=origin)
host=host, origin=origin, suppress_origin=suppress_origin,
proxy_type=proxy_type)
if not dispatcher:
dispatcher = self.create_dispatcher(ping_timeout)
......@@ -250,8 +271,7 @@ class WebSocketApp(object):
op_code, frame = self.sock.recv_data_frame(True)
if op_code == ABNF.OPCODE_CLOSE:
close_frame = frame
return teardown()
return teardown(frame)
elif op_code == ABNF.OPCODE_PING:
self._callback(self.on_ping, frame.data)
elif op_code == ABNF.OPCODE_PONG:
......@@ -269,31 +289,39 @@ class WebSocketApp(object):
self._callback(self.on_data, data, frame.opcode, True)
self._callback(self.on_message, data)
if ping_timeout and self.last_ping_tm \
and time.time() - self.last_ping_tm > ping_timeout \
and self.last_ping_tm - self.last_pong_tm > ping_timeout:
return True
def check():
if (ping_timeout):
has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout
has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0
has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout
if (self.last_ping_tm
and has_timeout_expired
and (has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)):
raise WebSocketTimeoutException("ping/pong timed out")
return True
dispatcher.read(self.sock.sock, read)
dispatcher.read(self.sock.sock, read, check)
except (Exception, KeyboardInterrupt, SystemExit) as e:
self._callback(self.on_error, e)
if isinstance(e, SystemExit):
# propagate SystemExit further
raise
teardown()
return not isinstance(e, KeyboardInterrupt)
def create_dispatcher(self, ping_timeout):
timeout = ping_timeout or 10
if self.sock.is_ssl():
return SSLDispacther(self, timeout)
return SSLDispatcher(self, timeout)
return Dispatcher(self, timeout)
def _get_close_args(self, data):
""" this functions extracts the code, reason from the close body
if they exists, and if the self.on_close except three arguments """
import inspect
# if the on_close callback is "old", just return empty list
if sys.version_info < (3, 0):
if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3:
......@@ -312,7 +340,11 @@ class WebSocketApp(object):
def _callback(self, callback, *args):
if callback:
try:
if inspect.ismethod(callback):
callback(*args)
else:
callback(self, *args)
except Exception as e:
_logging.error("error from callback {}: {}".format(callback, e))
if _logging.isEnabledForDebug():
......
......@@ -24,6 +24,7 @@ from __future__ import print_function
import socket
import struct
import threading
import time
import six
......@@ -201,6 +202,7 @@ class WebSocket(object):
options: "header" -> custom http header list or dict.
"cookie" -> cookie value.
"origin" -> custom origin url.
"suppress_origin" -> suppress outputting origin header.
"host" -> custom host header string.
"http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80.
......@@ -208,16 +210,27 @@ class WebSocket(object):
"http_proxy_auth" - http proxy auth information.
tuple of username and password.
default is None
"redirect_limit" -> number of redirects to follow.
"subprotocols" - array of available sub protocols.
default is None.
"socket" - pre-initialized stream socket.
"""
# FIXME: "subprotocols" are getting lost, not passed down
# FIXME: "header", "cookie", "origin" and "host" too
self.sock_opt.timeout = options.get('timeout', self.sock_opt.timeout)
self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options),
options.pop('socket', None))
try:
self.handshake_response = handshake(self.sock, *addrs, **options)
for attempt in range(options.pop('redirect_limit', 3)):
if self.handshake_response.status in SUPPORTED_REDIRECT_STATUSES:
url = self.handshake_response.headers['location']
self.sock.close()
self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options),
options.pop('socket', None))
self.handshake_response = handshake(self.sock, *addrs, **options)
self.connected = True
except:
if self.sock:
......@@ -258,6 +271,7 @@ class WebSocket(object):
frame.get_mask_key = self.get_mask_key
data = frame.format()
length = len(data)
if (isEnabledForTrace()):
trace("send: " + repr(data))
with self.lock:
......@@ -397,14 +411,19 @@ class WebSocket(object):
reason, ABNF.OPCODE_CLOSE)
sock_timeout = self.sock.gettimeout()
self.sock.settimeout(timeout)
start_time = time.time()
while timeout is None or time.time() - start_time < timeout:
try:
frame = self.recv_frame()
if frame.opcode != ABNF.OPCODE_CLOSE:
continue
if isEnabledForError():
recv_status = struct.unpack("!H", frame.data[0:2])[0]
if recv_status != STATUS_NORMAL:
error("close status: " + repr(recv_status))
break
except:
pass
break
self.sock.settimeout(sock_timeout)
self.sock.shutdown(socket.SHUT_RDWR)
except:
......@@ -466,6 +485,7 @@ def create_connection(url, timeout=None, class_=WebSocket, **options):
options: "header" -> custom http header list or dict.
"cookie" -> cookie value.
"origin" -> custom origin url.
"suppress_origin" -> suppress outputting origin header.
"host" -> custom host header string.
"http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80.
......@@ -474,6 +494,7 @@ def create_connection(url, timeout=None, class_=WebSocket, **options):
tuple of username and password.
default is None
"enable_multithread" -> enable lock for multithread.
"redirect_limit" -> number of redirects to follow.
"sockopt" -> socket options
"sslopt" -> ssl option
"subprotocols" - array of available sub protocols.
......
......@@ -74,11 +74,12 @@ class WebSocketBadStatusException(WebSocketException):
WebSocketBadStatusException will be raised when we get bad handshake status code.
"""
def __init__(self, message, status_code, status_message=None):
msg = message % (status_code, status_message) if status_message is not None \
else message % status_code
def __init__(self, message, status_code, status_message=None, resp_headers=None):
msg = message % (status_code, status_message)
super(WebSocketBadStatusException, self).__init__(msg)
self.status_code = status_code
self.resp_headers = resp_headers
class WebSocketAddressException(WebSocketException):
"""
......
......@@ -31,12 +31,20 @@ from ._http import *
from ._logging import *
from ._socket import *
if six.PY3:
if hasattr(six, 'PY3') and six.PY3:
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
__all__ = ["handshake_response", "handshake"]
if hasattr(six, 'PY3') and six.PY3:
if hasattr(six, 'PY34') and six.PY34:
from http import client as HTTPStatus
else:
from http import HTTPStatus
else:
import httplib as HTTPStatus
__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
......@@ -47,6 +55,9 @@ else:
# websocket supported version.
VERSION = 13
SUPPORTED_REDIRECT_STATUSES = (HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER,)
SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,)
CookieJar = SimpleCookieJar()
......@@ -67,12 +78,15 @@ def handshake(sock, hostname, port, resource, **options):
dump("request header", header_str)
status, resp = _get_resp_headers(sock)
if status in SUPPORTED_REDIRECT_STATUSES:
return handshake_response(status, resp, None)
success, subproto = _validate(resp, key, options.get("subprotocols"))
if not success:
raise WebSocketException("Invalid WebSocket Header")
return handshake_response(status, resp, subproto)
def _pack_hostname(hostname):
# IPv6 address
if ':' in hostname:
......@@ -83,28 +97,40 @@ def _pack_hostname(hostname):
def _get_handshake_headers(resource, host, port, options):
headers = [
"GET %s HTTP/1.1" % resource,
"Upgrade: websocket",
"Connection: Upgrade"
"Upgrade: websocket"
]
if port == 80 or port == 443:
hostport = _pack_hostname(host)
else:
hostport = "%s:%d" % (_pack_hostname(host), port)
if "host" in options and options["host"] is not None:
headers.append("Host: %s" % options["host"])
else:
headers.append("Host: %s" % hostport)
if "suppress_origin" not in options or not options["suppress_origin"]:
if "origin" in options and options["origin"] is not None:
headers.append("Origin: %s" % options["origin"])
else:
headers.append("Origin: http://%s" % hostport)
key = _create_sec_websocket_key()
# Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
if not 'header' in options or 'Sec-WebSocket-Key' not in options['header']:
key = _create_sec_websocket_key()
headers.append("Sec-WebSocket-Key: %s" % key)
else:
key = options['header']['Sec-WebSocket-Key']
if not 'header' in options or 'Sec-WebSocket-Version' not in options['header']:
headers.append("Sec-WebSocket-Version: %s" % VERSION)
if not 'connection' in options or options['connection'] is None:
headers.append('Connection: upgrade')
else:
headers.append(options['connection'])
subprotocols = options.get("subprotocols")
if subprotocols:
headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
......@@ -112,7 +138,11 @@ def _get_handshake_headers(resource, host, port, options):
if "header" in options:
header = options["header"]
if isinstance(header, dict):
header = map(": ".join, header.items())
header = [
": ".join([k, v])
for k, v in header.items()
if v is not None
]
headers.extend(header)
server_cookie = CookieJar.get(host)
......@@ -129,12 +159,13 @@ def _get_handshake_headers(resource, host, port, options):
return headers, key
def _get_resp_headers(sock, success_status=101):
def _get_resp_headers(sock, success_statuses=SUCCESS_STATUSES):
status, resp_headers, status_message = read_headers(sock)
if status != success_status:
raise WebSocketBadStatusException("Handshake status %d %s", status, status_message)
if status not in success_statuses:
raise WebSocketBadStatusException("Handshake status %d %s", status, status_message, resp_headers)
return status, resp_headers
_HEADERS_TO_CHECK = {
"upgrade": "websocket",
"connection": "upgrade",
......
......@@ -39,21 +39,72 @@ else:
__all__ = ["proxy_info", "connect", "read_headers"]
try:
import socks
ProxyConnectionError = socks.ProxyConnectionError
HAS_PYSOCKS = True
except:
class ProxyConnectionError(BaseException):
pass
HAS_PYSOCKS = False
class proxy_info(object):
def __init__(self, **options):
self.type = options.get("proxy_type") or "http"
if not(self.type in ['http', 'socks4', 'socks5', 'socks5h']):
raise ValueError("proxy_type must be 'http', 'socks4', 'socks5' or 'socks5h'")
self.host = options.get("http_proxy_host", None)
if self.host:
self.port = options.get("http_proxy_port", 0)
self.auth = options.get("http_proxy_auth", None)
self.no_proxy = options.get("http_no_proxy", None)
else:
self.port = 0
self.auth = None
self.no_proxy = options.get("http_no_proxy", None)
self.no_proxy = None
def _open_proxied_socket(url, options, proxy):
hostname, port, resource, is_secure = parse_url(url)
if not HAS_PYSOCKS:
raise WebSocketException("PySocks module not found.")
ptype = socks.SOCKS5
rdns = False
if proxy.type == "socks4":
ptype = socks.SOCKS4
if proxy.type == "http":
ptype = socks.HTTP
if proxy.type[-1] == "h":
rdns = True
sock = socks.create_connection(
(hostname, port),
proxy_type = ptype,
proxy_addr = proxy.host,
proxy_port = proxy.port,
proxy_rdns = rdns,
proxy_username = proxy.auth[0] if proxy.auth else None,
proxy_password = proxy.auth[1] if proxy.auth else None,
timeout = options.timeout,
socket_options = DEFAULT_SOCKET_OPTION + options.sockopt
)
if is_secure:
if HAVE_SSL:
sock = _ssl_socket(sock, options.sslopt, hostname)
else:
raise WebSocketException("SSL not available.")
return sock, (hostname, port, resource)
def connect(url, options, proxy, socket):
if proxy.host and not socket and not (proxy.type == 'http'):
return _open_proxied_socket(url, options, proxy)
hostname, port, resource, is_secure = parse_url(url)
if socket:
......@@ -88,13 +139,20 @@ def _get_addrinfo_list(hostname, port, is_secure, proxy):
phost, pport, pauth = get_proxy_info(
hostname, is_secure, proxy.host, proxy.port, proxy.auth, proxy.no_proxy)
try:
# when running on windows 10, getaddrinfo without socktype returns a socktype 0.
# This generates an error exception: `_on_error: exception Socket type must be stream or datagram, not 0`
# or `OSError: [Errno 22] Invalid argument` when creating socket. Force the socket type to SOCK_STREAM.
if not phost:
addrinfo_list = socket.getaddrinfo(
hostname, port, 0, 0, socket.SOL_TCP)
hostname, port, 0, socket.SOCK_STREAM, socket.SOL_TCP)
return addrinfo_list, False, None
else:
pport = pport and pport or 80
addrinfo_list = socket.getaddrinfo(phost, pport, 0, 0, socket.SOL_TCP)
# when running on windows 10, the getaddrinfo used above
# returns a socktype 0. This generates an error exception:
# _on_error: exception Socket type must be stream or datagram, not 0
# Force the socket type to SOCK_STREAM
addrinfo_list = socket.getaddrinfo(phost, pport, 0, socket.SOCK_STREAM, socket.SOL_TCP)
return addrinfo_list, True, pauth
except socket.gaierror as e:
raise WebSocketAddressException(e)
......@@ -106,28 +164,40 @@ def _open_socket(addrinfo_list, sockopt, timeout):
family, socktype, proto = addrinfo[:3]
sock = socket.socket(family, socktype, proto)
sock.settimeout(timeout)
# for opts in DEFAULT_SOCKET_OPTION:
# sock.setsockopt(*opts)
# for opts in sockopt:
# sock.setsockopt(*opts)
for opts in DEFAULT_SOCKET_OPTION:
sock.setsockopt(*opts)
for opts in sockopt:
sock.setsockopt(*opts)
address = addrinfo[4]
err = None
while not err:
try:
sock.connect(address)
except ProxyConnectionError as error:
err = WebSocketProxyException(str(error))
err.remote_ip = str(address[0])
continue
except socket.error as error:
error.remote_ip = str(address[0])
try:
eConnRefused = (errno.ECONNREFUSED, errno.WSAECONNREFUSED)
except:
eConnRefused = (errno.ECONNREFUSED, )
if error.errno in eConnRefused:
if error.errno == errno.EINTR:
continue
elif error.errno in eConnRefused:
err = error
continue
else:
raise
raise error
else:
break
else:
continue
break
else:
if err:
raise err
return sock
......@@ -141,7 +211,12 @@ def _wrap_sni_socket(sock, sslopt, hostname, check_hostname):
context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23))
if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE:
context.load_verify_locations(cafile=sslopt.get('ca_certs', None), capath=sslopt.get('ca_cert_path', None))
cafile = sslopt.get('ca_certs', None)
capath = sslopt.get('ca_cert_path', None)
if cafile or capath:
context.load_verify_locations(cafile=cafile, capath=capath)
elif hasattr(context, 'load_default_certs'):
context.load_default_certs(ssl.Purpose.SERVER_AUTH)
if sslopt.get('certfile', None):
context.load_cert_chain(
sslopt['certfile'],
......@@ -173,15 +248,13 @@ def _ssl_socket(sock, user_sslopt, hostname):
sslopt = dict(cert_reqs=ssl.CERT_REQUIRED)
sslopt.update(user_sslopt)
if os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE'):
certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE')
else:
certPath = os.path.join(
os.path.dirname(__file__), "cacert.pem")
if os.path.isfile(certPath) and user_sslopt.get('ca_certs', None) is None \
if certPath and os.path.isfile(certPath) \
and user_sslopt.get('ca_certs', None) is None \
and user_sslopt.get('ca_cert', None) is None:
sslopt['ca_certs'] = certPath
elif os.path.isdir(certPath) and user_sslopt.get('ca_cert_path', None) is None:
elif certPath and os.path.isdir(certPath) \
and user_sslopt.get('ca_cert_path', None) is None:
sslopt['ca_cert_path'] = certPath
check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop(
......@@ -207,7 +280,7 @@ def _tunnel(sock, host, port, auth):
auth_str = auth[0]
if auth[1]:
auth_str += ":" + auth[1]
encoded_str = base64encode(auth_str.encode()).strip().decode()
encoded_str = base64encode(auth_str.encode()).strip().decode().replace('\n', '')
connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str
connect_header += "\r\n"
dump("request header", connect_header)
......@@ -242,6 +315,7 @@ def read_headers(sock):
status_info = line.split(" ", 2)
status = int(status_info[1])
if len(status_info) > 2:
status_message = status_info[2]
else:
kv = line.split(":", 1)
......
......@@ -22,13 +22,22 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import logging
_logger = logging.getLogger('websocket')
try:
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
pass
_logger.addHandler(NullHandler())
_traceEnabled = False
__all__ = ["enableTrace", "dump", "error", "warning", "debug", "trace",
"isEnabledForError", "isEnabledForDebug"]
"isEnabledForError", "isEnabledForDebug", "isEnabledForTrace"]
def enableTrace(traceable):
def enableTrace(traceable, handler = logging.StreamHandler()):
"""
turn on/off the traceability.
......@@ -37,11 +46,9 @@ def enableTrace(traceable):
global _traceEnabled
_traceEnabled = traceable
if traceable:
if not _logger.handlers:
_logger.addHandler(logging.StreamHandler())
_logger.addHandler(handler)
_logger.setLevel(logging.DEBUG)
def dump(title, message):
if _traceEnabled:
_logger.debug("--- " + title + " ---")
......@@ -72,3 +79,6 @@ def isEnabledForError():
def isEnabledForDebug():
return _logger.isEnabledFor(logging.DEBUG)
def isEnabledForTrace():
return _traceEnabled
......@@ -19,6 +19,8 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA
"""
import errno
import select
import socket
import six
......@@ -77,8 +79,27 @@ def recv(sock, bufsize):
if not sock:
raise WebSocketConnectionClosedException("socket is already closed.")
def _recv():
try:
return sock.recv(bufsize)
except SSLWantReadError:
pass
except socket.error as exc:
error_code = extract_error_code(exc)
if error_code is None:
raise
if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK:
raise
r, w, e = select.select((sock, ), (), (), sock.gettimeout())
if r:
return sock.recv(bufsize)
try:
if sock.gettimeout() == 0:
bytes_ = sock.recv(bufsize)
else:
bytes_ = _recv()
except socket.timeout as e:
message = extract_err_message(e)
raise WebSocketTimeoutException(message)
......@@ -113,8 +134,27 @@ def send(sock, data):
if not sock:
raise WebSocketConnectionClosedException("socket is already closed.")
def _send():
try:
return sock.send(data)
except SSLWantWriteError:
pass
except socket.error as exc:
error_code = extract_error_code(exc)
if error_code is None:
raise
if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK:
raise
r, w, e = select.select((), (sock, ), (), sock.gettimeout())
if w:
return sock.send(data)
try:
if sock.gettimeout() == 0:
return sock.send(data)
else:
return _send()
except socket.timeout as e:
message = extract_err_message(e)
raise WebSocketTimeoutException(message)
......
......@@ -19,11 +19,13 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA
"""
__all__ = ["HAVE_SSL", "ssl", "SSLError"]
__all__ = ["HAVE_SSL", "ssl", "SSLError", "SSLWantReadError", "SSLWantWriteError"]
try:
import ssl
from ssl import SSLError
from ssl import SSLWantReadError
from ssl import SSLWantWriteError
if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'):
HAVE_CONTEXT_CHECK_HOSTNAME = True
else:
......@@ -41,4 +43,12 @@ except ImportError:
class SSLError(Exception):
pass
class SSLWantReadError(Exception):
pass
class SSLWantWriteError(Exception):
pass
ssl = lambda: None
HAVE_SSL = False
......@@ -103,6 +103,7 @@ def _is_address_in_network(ip, net):
def _is_no_proxy_host(hostname, no_proxy):
if not no_proxy:
v = os.environ.get("no_proxy", "").replace(" ", "")
if v:
no_proxy = v.split(",")
if not no_proxy:
no_proxy = DEFAULT_NO_PROXY_HOST
......@@ -117,7 +118,7 @@ def _is_no_proxy_host(hostname, no_proxy):
def get_proxy_info(
hostname, is_secure, proxy_host=None, proxy_port=0, proxy_auth=None,
no_proxy=None):
no_proxy=None, proxy_type='http'):
"""
try to retrieve proxy host and port from environment
if not provided in options.
......@@ -137,6 +138,9 @@ def get_proxy_info(
"http_proxy_auth" - http proxy auth information.
tuple of username and password.
default is None
"proxy_type" - if set to "socks5" PySocks wrapper
will be used in place of a http proxy.
default is "http"
"""
if _is_no_proxy_host(hostname, no_proxy):
return None, 0, None
......
......@@ -21,7 +21,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
"""
import six
__all__ = ["NoLock", "validate_utf8", "extract_err_message"]
__all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"]
class NoLock(object):
......@@ -32,6 +32,7 @@ class NoLock(object):
def __exit__(self, exc_type, exc_value, traceback):
pass
try:
# If wsaccel is available we use compiled routines to validate UTF-8
# strings.
......@@ -103,3 +104,8 @@ def extract_err_message(exception):
return exception.args[0]
else:
return None
def extract_error_code(exception):
if exception.args and len(exception.args) > 1:
return exception.args[0] if isinstance(exception.args[0], int) else None
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade
Upgrade: WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
some_header: something
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade
Upgrade WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
some_header: something
import unittest
from websocket._cookiejar import SimpleCookieJar
try:
import Cookie
except:
import http.cookies as Cookie
class CookieJarTest(unittest.TestCase):
def testAdd(self):
cookie_jar = SimpleCookieJar()
cookie_jar.add("")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=.abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=xyz")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
self.assertEquals(cookie_jar.get("xyz"), "e=f")
self.assertEquals(cookie_jar.get("something"), "")
def testSet(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=.abc")
self.assertEquals(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=xyz")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
self.assertEquals(cookie_jar.get("xyz"), "e=f")
self.assertEquals(cookie_jar.get("something"), "")
def testGet(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc.com")
self.assertEquals(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("abc.com.es"), "")
self.assertEquals(cookie_jar.get("xabc.com"), "")
cookie_jar.set("a=b; c=d; domain=.abc.com")
self.assertEquals(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("abc.com.es"), "")
self.assertEquals(cookie_jar.get("xabc.com"), "")
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment