Commit 54dae43b authored by l2m2's avatar l2m2

upload.

parent 3ca68ace
......@@ -26,4 +26,4 @@ from ._exceptions import *
from ._logging import *
from ._socket import *
__version__ = "0.47.0"
__version__ = "0.57.0"
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -23,6 +23,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
"""
WebSocketApp provides higher level APIs.
"""
import inspect
import select
import sys
import threading
......@@ -44,23 +45,27 @@ class Dispatcher:
self.app = app
self.ping_timeout = ping_timeout
def read(self, sock, callback):
while self.app.sock.connected:
def read(self, sock, read_callback, check_callback):
while self.app.keep_running:
r, w, e = select.select(
(self.app.sock.sock, ), (), (), self.ping_timeout) # Use a 10 second timeout to avoid to wait forever on close
(self.app.sock.sock, ), (), (), self.ping_timeout)
if r:
callback()
if not read_callback():
break
check_callback()
class SSLDispacther:
class SSLDispatcher:
def __init__(self, app, ping_timeout):
self.app = app
self.ping_timeout = ping_timeout
def read(self, sock, callback):
while self.app.sock.connected:
def read(self, sock, read_callback, check_callback):
while self.app.keep_running:
r = self.select()
if r:
callback()
if not read_callback():
break
check_callback()
def select(self):
sock = self.app.sock.sock
......@@ -70,6 +75,7 @@ class SSLDispacther:
r, w, e = select.select((sock, ), (), (), self.ping_timeout)
return r
class WebSocketApp(object):
"""
Higher level of APIs are provided.
......@@ -113,7 +119,7 @@ class WebSocketApp(object):
The 2nd argument is utf-8 string which we get from the server.
The 3rd argument is data type. ABNF.OPCODE_TEXT or ABNF.OPCODE_BINARY will be came.
The 4th argument is continue flag. if 0, the data continue
keep_running: this parameter is obosleted and ignored it.
keep_running: this parameter is obsolete and ignored.
get_mask_key: a callable to produce new mask keys,
see the WebSocket.set_mask_key's docstring for more information
subprotocols: array of available sub protocols. default is None.
......@@ -121,6 +127,7 @@ class WebSocketApp(object):
self.url = url
self.header = header if header is not None else []
self.cookie = cookie
self.on_open = on_open
self.on_message = on_message
self.on_data = on_data
......@@ -155,6 +162,7 @@ class WebSocketApp(object):
self.keep_running = False
if self.sock:
self.sock.close(**kwargs)
self.sock = None
def _send_ping(self, interval, event):
while not event.wait(interval):
......@@ -171,7 +179,8 @@ class WebSocketApp(object):
http_proxy_host=None, http_proxy_port=None,
http_no_proxy=None, http_proxy_auth=None,
skip_utf8_validation=False,
host=None, origin=None, dispatcher=None):
host=None, origin=None, dispatcher=None,
suppress_origin=False, proxy_type=None):
"""
run event loop for WebSocket framework.
This loop is infinite loop and is alive during websocket is available.
......@@ -189,31 +198,41 @@ class WebSocketApp(object):
skip_utf8_validation: skip utf8 validation.
host: update host header.
origin: update origin header.
dispatcher: customize reading data from socket.
suppress_origin: suppress outputting origin header.
Returns
-------
False if caught KeyboardInterrupt
True if other exception was raised during a loop
"""
if not ping_timeout or ping_timeout <= 0:
if ping_timeout is not None and ping_timeout <= 0:
ping_timeout = None
if ping_timeout and ping_interval and ping_interval <= ping_timeout:
raise WebSocketException("Ensure ping_interval > ping_timeout")
if sockopt is None:
if not sockopt:
sockopt = []
if sslopt is None:
if not sslopt:
sslopt = {}
if self.sock:
raise WebSocketException("socket is already opened")
thread = None
close_frame = None
self.keep_running = True
self.last_ping_tm = 0
self.last_pong_tm = 0
def teardown():
if not self.keep_running:
return
def teardown(close_frame=None):
"""
Tears down the connection.
If close_frame is set, we will invoke the on_close handler with the
statusCode and reason from there.
"""
if thread and thread.isAlive():
event.set()
thread.join()
self.keep_running = False
if self.sock:
self.sock.close()
close_args = self._get_close_args(
close_frame.data if close_frame else None)
......@@ -223,15 +242,17 @@ class WebSocketApp(object):
try:
self.sock = WebSocket(
self.get_mask_key, sockopt=sockopt, sslopt=sslopt,
fire_cont_frame=self.on_cont_message and True or False,
skip_utf8_validation=skip_utf8_validation)
fire_cont_frame=self.on_cont_message is not None,
skip_utf8_validation=skip_utf8_validation,
enable_multithread=True if ping_interval else False)
self.sock.settimeout(getdefaulttimeout())
self.sock.connect(
self.url, header=self.header, cookie=self.cookie,
http_proxy_host=http_proxy_host,
http_proxy_port=http_proxy_port, http_no_proxy=http_no_proxy,
http_proxy_auth=http_proxy_auth, subprotocols=self.subprotocols,
host=host, origin=origin)
host=host, origin=origin, suppress_origin=suppress_origin,
proxy_type=proxy_type)
if not dispatcher:
dispatcher = self.create_dispatcher(ping_timeout)
......@@ -250,8 +271,7 @@ class WebSocketApp(object):
op_code, frame = self.sock.recv_data_frame(True)
if op_code == ABNF.OPCODE_CLOSE:
close_frame = frame
return teardown()
return teardown(frame)
elif op_code == ABNF.OPCODE_PING:
self._callback(self.on_ping, frame.data)
elif op_code == ABNF.OPCODE_PONG:
......@@ -269,31 +289,39 @@ class WebSocketApp(object):
self._callback(self.on_data, data, frame.opcode, True)
self._callback(self.on_message, data)
if ping_timeout and self.last_ping_tm \
and time.time() - self.last_ping_tm > ping_timeout \
and self.last_ping_tm - self.last_pong_tm > ping_timeout:
return True
def check():
if (ping_timeout):
has_timeout_expired = time.time() - self.last_ping_tm > ping_timeout
has_pong_not_arrived_after_last_ping = self.last_pong_tm - self.last_ping_tm < 0
has_pong_arrived_too_late = self.last_pong_tm - self.last_ping_tm > ping_timeout
if (self.last_ping_tm
and has_timeout_expired
and (has_pong_not_arrived_after_last_ping or has_pong_arrived_too_late)):
raise WebSocketTimeoutException("ping/pong timed out")
return True
dispatcher.read(self.sock.sock, read)
dispatcher.read(self.sock.sock, read, check)
except (Exception, KeyboardInterrupt, SystemExit) as e:
self._callback(self.on_error, e)
if isinstance(e, SystemExit):
# propagate SystemExit further
raise
teardown()
return not isinstance(e, KeyboardInterrupt)
def create_dispatcher(self, ping_timeout):
timeout = ping_timeout or 10
if self.sock.is_ssl():
return SSLDispacther(self, timeout)
return SSLDispatcher(self, timeout)
return Dispatcher(self, timeout)
def _get_close_args(self, data):
""" this functions extracts the code, reason from the close body
if they exists, and if the self.on_close except three arguments """
import inspect
# if the on_close callback is "old", just return empty list
if sys.version_info < (3, 0):
if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3:
......@@ -312,7 +340,11 @@ class WebSocketApp(object):
def _callback(self, callback, *args):
if callback:
try:
if inspect.ismethod(callback):
callback(*args)
else:
callback(self, *args)
except Exception as e:
_logging.error("error from callback {}: {}".format(callback, e))
if _logging.isEnabledForDebug():
......
......@@ -24,6 +24,7 @@ from __future__ import print_function
import socket
import struct
import threading
import time
import six
......@@ -201,6 +202,7 @@ class WebSocket(object):
options: "header" -> custom http header list or dict.
"cookie" -> cookie value.
"origin" -> custom origin url.
"suppress_origin" -> suppress outputting origin header.
"host" -> custom host header string.
"http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80.
......@@ -208,16 +210,27 @@ class WebSocket(object):
"http_proxy_auth" - http proxy auth information.
tuple of username and password.
default is None
"redirect_limit" -> number of redirects to follow.
"subprotocols" - array of available sub protocols.
default is None.
"socket" - pre-initialized stream socket.
"""
# FIXME: "subprotocols" are getting lost, not passed down
# FIXME: "header", "cookie", "origin" and "host" too
self.sock_opt.timeout = options.get('timeout', self.sock_opt.timeout)
self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options),
options.pop('socket', None))
try:
self.handshake_response = handshake(self.sock, *addrs, **options)
for attempt in range(options.pop('redirect_limit', 3)):
if self.handshake_response.status in SUPPORTED_REDIRECT_STATUSES:
url = self.handshake_response.headers['location']
self.sock.close()
self.sock, addrs = connect(url, self.sock_opt, proxy_info(**options),
options.pop('socket', None))
self.handshake_response = handshake(self.sock, *addrs, **options)
self.connected = True
except:
if self.sock:
......@@ -258,6 +271,7 @@ class WebSocket(object):
frame.get_mask_key = self.get_mask_key
data = frame.format()
length = len(data)
if (isEnabledForTrace()):
trace("send: " + repr(data))
with self.lock:
......@@ -397,14 +411,19 @@ class WebSocket(object):
reason, ABNF.OPCODE_CLOSE)
sock_timeout = self.sock.gettimeout()
self.sock.settimeout(timeout)
start_time = time.time()
while timeout is None or time.time() - start_time < timeout:
try:
frame = self.recv_frame()
if frame.opcode != ABNF.OPCODE_CLOSE:
continue
if isEnabledForError():
recv_status = struct.unpack("!H", frame.data[0:2])[0]
if recv_status != STATUS_NORMAL:
error("close status: " + repr(recv_status))
break
except:
pass
break
self.sock.settimeout(sock_timeout)
self.sock.shutdown(socket.SHUT_RDWR)
except:
......@@ -466,6 +485,7 @@ def create_connection(url, timeout=None, class_=WebSocket, **options):
options: "header" -> custom http header list or dict.
"cookie" -> cookie value.
"origin" -> custom origin url.
"suppress_origin" -> suppress outputting origin header.
"host" -> custom host header string.
"http_proxy_host" - http proxy host name.
"http_proxy_port" - http proxy port. If not set, set to 80.
......@@ -474,6 +494,7 @@ def create_connection(url, timeout=None, class_=WebSocket, **options):
tuple of username and password.
default is None
"enable_multithread" -> enable lock for multithread.
"redirect_limit" -> number of redirects to follow.
"sockopt" -> socket options
"sslopt" -> ssl option
"subprotocols" - array of available sub protocols.
......
......@@ -74,11 +74,12 @@ class WebSocketBadStatusException(WebSocketException):
WebSocketBadStatusException will be raised when we get bad handshake status code.
"""
def __init__(self, message, status_code, status_message=None):
msg = message % (status_code, status_message) if status_message is not None \
else message % status_code
def __init__(self, message, status_code, status_message=None, resp_headers=None):
msg = message % (status_code, status_message)
super(WebSocketBadStatusException, self).__init__(msg)
self.status_code = status_code
self.resp_headers = resp_headers
class WebSocketAddressException(WebSocketException):
"""
......
......@@ -31,12 +31,20 @@ from ._http import *
from ._logging import *
from ._socket import *
if six.PY3:
if hasattr(six, 'PY3') and six.PY3:
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
__all__ = ["handshake_response", "handshake"]
if hasattr(six, 'PY3') and six.PY3:
if hasattr(six, 'PY34') and six.PY34:
from http import client as HTTPStatus
else:
from http import HTTPStatus
else:
import httplib as HTTPStatus
__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
......@@ -47,6 +55,9 @@ else:
# websocket supported version.
VERSION = 13
SUPPORTED_REDIRECT_STATUSES = (HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER,)
SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,)
CookieJar = SimpleCookieJar()
......@@ -67,12 +78,15 @@ def handshake(sock, hostname, port, resource, **options):
dump("request header", header_str)
status, resp = _get_resp_headers(sock)
if status in SUPPORTED_REDIRECT_STATUSES:
return handshake_response(status, resp, None)
success, subproto = _validate(resp, key, options.get("subprotocols"))
if not success:
raise WebSocketException("Invalid WebSocket Header")
return handshake_response(status, resp, subproto)
def _pack_hostname(hostname):
# IPv6 address
if ':' in hostname:
......@@ -83,28 +97,40 @@ def _pack_hostname(hostname):
def _get_handshake_headers(resource, host, port, options):
headers = [
"GET %s HTTP/1.1" % resource,
"Upgrade: websocket",
"Connection: Upgrade"
"Upgrade: websocket"
]
if port == 80 or port == 443:
hostport = _pack_hostname(host)
else:
hostport = "%s:%d" % (_pack_hostname(host), port)
if "host" in options and options["host"] is not None:
headers.append("Host: %s" % options["host"])
else:
headers.append("Host: %s" % hostport)
if "suppress_origin" not in options or not options["suppress_origin"]:
if "origin" in options and options["origin"] is not None:
headers.append("Origin: %s" % options["origin"])
else:
headers.append("Origin: http://%s" % hostport)
key = _create_sec_websocket_key()
# Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
if not 'header' in options or 'Sec-WebSocket-Key' not in options['header']:
key = _create_sec_websocket_key()
headers.append("Sec-WebSocket-Key: %s" % key)
else:
key = options['header']['Sec-WebSocket-Key']
if not 'header' in options or 'Sec-WebSocket-Version' not in options['header']:
headers.append("Sec-WebSocket-Version: %s" % VERSION)
if not 'connection' in options or options['connection'] is None:
headers.append('Connection: upgrade')
else:
headers.append(options['connection'])
subprotocols = options.get("subprotocols")
if subprotocols:
headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
......@@ -112,7 +138,11 @@ def _get_handshake_headers(resource, host, port, options):
if "header" in options:
header = options["header"]
if isinstance(header, dict):
header = map(": ".join, header.items())
header = [
": ".join([k, v])
for k, v in header.items()
if v is not None
]
headers.extend(header)
server_cookie = CookieJar.get(host)
......@@ -129,12 +159,13 @@ def _get_handshake_headers(resource, host, port, options):
return headers, key
def _get_resp_headers(sock, success_status=101):
def _get_resp_headers(sock, success_statuses=SUCCESS_STATUSES):
status, resp_headers, status_message = read_headers(sock)
if status != success_status:
raise WebSocketBadStatusException("Handshake status %d %s", status, status_message)
if status not in success_statuses:
raise WebSocketBadStatusException("Handshake status %d %s", status, status_message, resp_headers)
return status, resp_headers
_HEADERS_TO_CHECK = {
"upgrade": "websocket",
"connection": "upgrade",
......
......@@ -39,21 +39,72 @@ else:
__all__ = ["proxy_info", "connect", "read_headers"]
try:
import socks
ProxyConnectionError = socks.ProxyConnectionError
HAS_PYSOCKS = True
except:
class ProxyConnectionError(BaseException):
pass
HAS_PYSOCKS = False
class proxy_info(object):
def __init__(self, **options):
self.type = options.get("proxy_type") or "http"
if not(self.type in ['http', 'socks4', 'socks5', 'socks5h']):
raise ValueError("proxy_type must be 'http', 'socks4', 'socks5' or 'socks5h'")
self.host = options.get("http_proxy_host", None)
if self.host:
self.port = options.get("http_proxy_port", 0)
self.auth = options.get("http_proxy_auth", None)
self.no_proxy = options.get("http_no_proxy", None)
else:
self.port = 0
self.auth = None
self.no_proxy = options.get("http_no_proxy", None)
self.no_proxy = None
def _open_proxied_socket(url, options, proxy):
hostname, port, resource, is_secure = parse_url(url)
if not HAS_PYSOCKS:
raise WebSocketException("PySocks module not found.")
ptype = socks.SOCKS5
rdns = False
if proxy.type == "socks4":
ptype = socks.SOCKS4
if proxy.type == "http":
ptype = socks.HTTP
if proxy.type[-1] == "h":
rdns = True
sock = socks.create_connection(
(hostname, port),
proxy_type = ptype,
proxy_addr = proxy.host,
proxy_port = proxy.port,
proxy_rdns = rdns,
proxy_username = proxy.auth[0] if proxy.auth else None,
proxy_password = proxy.auth[1] if proxy.auth else None,
timeout = options.timeout,
socket_options = DEFAULT_SOCKET_OPTION + options.sockopt
)
if is_secure:
if HAVE_SSL:
sock = _ssl_socket(sock, options.sslopt, hostname)
else:
raise WebSocketException("SSL not available.")
return sock, (hostname, port, resource)
def connect(url, options, proxy, socket):
if proxy.host and not socket and not (proxy.type == 'http'):
return _open_proxied_socket(url, options, proxy)
hostname, port, resource, is_secure = parse_url(url)
if socket:
......@@ -88,13 +139,20 @@ def _get_addrinfo_list(hostname, port, is_secure, proxy):
phost, pport, pauth = get_proxy_info(
hostname, is_secure, proxy.host, proxy.port, proxy.auth, proxy.no_proxy)
try:
# when running on windows 10, getaddrinfo without socktype returns a socktype 0.
# This generates an error exception: `_on_error: exception Socket type must be stream or datagram, not 0`
# or `OSError: [Errno 22] Invalid argument` when creating socket. Force the socket type to SOCK_STREAM.
if not phost:
addrinfo_list = socket.getaddrinfo(
hostname, port, 0, 0, socket.SOL_TCP)
hostname, port, 0, socket.SOCK_STREAM, socket.SOL_TCP)
return addrinfo_list, False, None
else:
pport = pport and pport or 80
addrinfo_list = socket.getaddrinfo(phost, pport, 0, 0, socket.SOL_TCP)
# when running on windows 10, the getaddrinfo used above
# returns a socktype 0. This generates an error exception:
# _on_error: exception Socket type must be stream or datagram, not 0
# Force the socket type to SOCK_STREAM
addrinfo_list = socket.getaddrinfo(phost, pport, 0, socket.SOCK_STREAM, socket.SOL_TCP)
return addrinfo_list, True, pauth
except socket.gaierror as e:
raise WebSocketAddressException(e)
......@@ -106,28 +164,40 @@ def _open_socket(addrinfo_list, sockopt, timeout):
family, socktype, proto = addrinfo[:3]
sock = socket.socket(family, socktype, proto)
sock.settimeout(timeout)
# for opts in DEFAULT_SOCKET_OPTION:
# sock.setsockopt(*opts)
# for opts in sockopt:
# sock.setsockopt(*opts)
for opts in DEFAULT_SOCKET_OPTION:
sock.setsockopt(*opts)
for opts in sockopt:
sock.setsockopt(*opts)
address = addrinfo[4]
err = None
while not err:
try:
sock.connect(address)
except ProxyConnectionError as error:
err = WebSocketProxyException(str(error))
err.remote_ip = str(address[0])
continue
except socket.error as error:
error.remote_ip = str(address[0])
try:
eConnRefused = (errno.ECONNREFUSED, errno.WSAECONNREFUSED)
except:
eConnRefused = (errno.ECONNREFUSED, )
if error.errno in eConnRefused:
if error.errno == errno.EINTR:
continue
elif error.errno in eConnRefused:
err = error
continue
else:
raise
raise error
else:
break
else:
continue
break
else:
if err:
raise err
return sock
......@@ -141,7 +211,12 @@ def _wrap_sni_socket(sock, sslopt, hostname, check_hostname):
context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23))
if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE:
context.load_verify_locations(cafile=sslopt.get('ca_certs', None), capath=sslopt.get('ca_cert_path', None))
cafile = sslopt.get('ca_certs', None)
capath = sslopt.get('ca_cert_path', None)
if cafile or capath:
context.load_verify_locations(cafile=cafile, capath=capath)
elif hasattr(context, 'load_default_certs'):
context.load_default_certs(ssl.Purpose.SERVER_AUTH)
if sslopt.get('certfile', None):
context.load_cert_chain(
sslopt['certfile'],
......@@ -173,15 +248,13 @@ def _ssl_socket(sock, user_sslopt, hostname):
sslopt = dict(cert_reqs=ssl.CERT_REQUIRED)
sslopt.update(user_sslopt)
if os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE'):
certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE')
else:
certPath = os.path.join(
os.path.dirname(__file__), "cacert.pem")
if os.path.isfile(certPath) and user_sslopt.get('ca_certs', None) is None \
if certPath and os.path.isfile(certPath) \
and user_sslopt.get('ca_certs', None) is None \
and user_sslopt.get('ca_cert', None) is None:
sslopt['ca_certs'] = certPath
elif os.path.isdir(certPath) and user_sslopt.get('ca_cert_path', None) is None:
elif certPath and os.path.isdir(certPath) \
and user_sslopt.get('ca_cert_path', None) is None:
sslopt['ca_cert_path'] = certPath
check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop(
......@@ -207,7 +280,7 @@ def _tunnel(sock, host, port, auth):
auth_str = auth[0]
if auth[1]:
auth_str += ":" + auth[1]
encoded_str = base64encode(auth_str.encode()).strip().decode()
encoded_str = base64encode(auth_str.encode()).strip().decode().replace('\n', '')
connect_header += "Proxy-Authorization: Basic %s\r\n" % encoded_str
connect_header += "\r\n"
dump("request header", connect_header)
......@@ -242,6 +315,7 @@ def read_headers(sock):
status_info = line.split(" ", 2)
status = int(status_info[1])
if len(status_info) > 2:
status_message = status_info[2]
else:
kv = line.split(":", 1)
......
......@@ -22,13 +22,22 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import logging
_logger = logging.getLogger('websocket')
try:
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
pass
_logger.addHandler(NullHandler())
_traceEnabled = False
__all__ = ["enableTrace", "dump", "error", "warning", "debug", "trace",
"isEnabledForError", "isEnabledForDebug"]
"isEnabledForError", "isEnabledForDebug", "isEnabledForTrace"]
def enableTrace(traceable):
def enableTrace(traceable, handler = logging.StreamHandler()):
"""
turn on/off the traceability.
......@@ -37,11 +46,9 @@ def enableTrace(traceable):
global _traceEnabled
_traceEnabled = traceable
if traceable:
if not _logger.handlers:
_logger.addHandler(logging.StreamHandler())
_logger.addHandler(handler)
_logger.setLevel(logging.DEBUG)
def dump(title, message):
if _traceEnabled:
_logger.debug("--- " + title + " ---")
......@@ -72,3 +79,6 @@ def isEnabledForError():
def isEnabledForDebug():
return _logger.isEnabledFor(logging.DEBUG)
def isEnabledForTrace():
return _traceEnabled
......@@ -19,6 +19,8 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA
"""
import errno
import select
import socket
import six
......@@ -77,8 +79,27 @@ def recv(sock, bufsize):
if not sock:
raise WebSocketConnectionClosedException("socket is already closed.")
def _recv():
try:
return sock.recv(bufsize)
except SSLWantReadError:
pass
except socket.error as exc:
error_code = extract_error_code(exc)
if error_code is None:
raise
if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK:
raise
r, w, e = select.select((sock, ), (), (), sock.gettimeout())
if r:
return sock.recv(bufsize)
try:
if sock.gettimeout() == 0:
bytes_ = sock.recv(bufsize)
else:
bytes_ = _recv()
except socket.timeout as e:
message = extract_err_message(e)
raise WebSocketTimeoutException(message)
......@@ -113,8 +134,27 @@ def send(sock, data):
if not sock:
raise WebSocketConnectionClosedException("socket is already closed.")
def _send():
try:
return sock.send(data)
except SSLWantWriteError:
pass
except socket.error as exc:
error_code = extract_error_code(exc)
if error_code is None:
raise
if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK:
raise
r, w, e = select.select((), (sock, ), (), sock.gettimeout())
if w:
return sock.send(data)
try:
if sock.gettimeout() == 0:
return sock.send(data)
else:
return _send()
except socket.timeout as e:
message = extract_err_message(e)
raise WebSocketTimeoutException(message)
......
......@@ -19,11 +19,13 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Boston, MA 02110-1335 USA
"""
__all__ = ["HAVE_SSL", "ssl", "SSLError"]
__all__ = ["HAVE_SSL", "ssl", "SSLError", "SSLWantReadError", "SSLWantWriteError"]
try:
import ssl
from ssl import SSLError
from ssl import SSLWantReadError
from ssl import SSLWantWriteError
if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'):
HAVE_CONTEXT_CHECK_HOSTNAME = True
else:
......@@ -41,4 +43,12 @@ except ImportError:
class SSLError(Exception):
pass
class SSLWantReadError(Exception):
pass
class SSLWantWriteError(Exception):
pass
ssl = lambda: None
HAVE_SSL = False
......@@ -103,6 +103,7 @@ def _is_address_in_network(ip, net):
def _is_no_proxy_host(hostname, no_proxy):
if not no_proxy:
v = os.environ.get("no_proxy", "").replace(" ", "")
if v:
no_proxy = v.split(",")
if not no_proxy:
no_proxy = DEFAULT_NO_PROXY_HOST
......@@ -117,7 +118,7 @@ def _is_no_proxy_host(hostname, no_proxy):
def get_proxy_info(
hostname, is_secure, proxy_host=None, proxy_port=0, proxy_auth=None,
no_proxy=None):
no_proxy=None, proxy_type='http'):
"""
try to retrieve proxy host and port from environment
if not provided in options.
......@@ -137,6 +138,9 @@ def get_proxy_info(
"http_proxy_auth" - http proxy auth information.
tuple of username and password.
default is None
"proxy_type" - if set to "socks5" PySocks wrapper
will be used in place of a http proxy.
default is "http"
"""
if _is_no_proxy_host(hostname, no_proxy):
return None, 0, None
......
......@@ -21,7 +21,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
"""
import six
__all__ = ["NoLock", "validate_utf8", "extract_err_message"]
__all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"]
class NoLock(object):
......@@ -32,6 +32,7 @@ class NoLock(object):
def __exit__(self, exc_type, exc_value, traceback):
pass
try:
# If wsaccel is available we use compiled routines to validate UTF-8
# strings.
......@@ -103,3 +104,8 @@ def extract_err_message(exception):
return exception.args[0]
else:
return None
def extract_error_code(exception):
if exception.args and len(exception.args) > 1:
return exception.args[0] if isinstance(exception.args[0], int) else None
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade
Upgrade: WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
some_header: something
HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade
Upgrade WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
some_header: something
import unittest
from websocket._cookiejar import SimpleCookieJar
try:
import Cookie
except:
import http.cookies as Cookie
class CookieJarTest(unittest.TestCase):
def testAdd(self):
cookie_jar = SimpleCookieJar()
cookie_jar.add("")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=.abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d; e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc")
cookie_jar.add("e=f; domain=xyz")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
self.assertEquals(cookie_jar.get("xyz"), "e=f")
self.assertEquals(cookie_jar.get("something"), "")
def testSet(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b")
self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=.abc")
self.assertTrue(".abc" in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; domain=abc")
self.assertTrue(".abc" in cookie_jar.jar)
self.assertTrue("abc" not in cookie_jar.jar)
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=abc")
self.assertEquals(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=.abc")
self.assertEquals(cookie_jar.get("abc"), "e=f")
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc")
cookie_jar.set("e=f; domain=xyz")
self.assertEquals(cookie_jar.get("abc"), "a=b; c=d")
self.assertEquals(cookie_jar.get("xyz"), "e=f")
self.assertEquals(cookie_jar.get("something"), "")
def testGet(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc.com")
self.assertEquals(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("abc.com.es"), "")
self.assertEquals(cookie_jar.get("xabc.com"), "")
cookie_jar.set("a=b; c=d; domain=.abc.com")
self.assertEquals(cookie_jar.get("abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("x.abc.com"), "a=b; c=d")
self.assertEquals(cookie_jar.get("abc.com.es"), "")
self.assertEquals(cookie_jar.get("xabc.com"), "")
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment