[7.x] Add AIOHttpConnection
This commit is contained in:
committed by
Seth Michael Larson
parent
521d4db641
commit
02e3323650
@@ -14,3 +14,7 @@ pandas
|
||||
pyyaml<5.3
|
||||
|
||||
black; python_version>="3.6"
|
||||
|
||||
# Requirements for testing [async] extra
|
||||
aiohttp; python_version>="3.6"
|
||||
pytest-asyncio; python_version>="3.6"
|
||||
|
||||
@@ -9,6 +9,7 @@ VERSION = (7, 9, 0)
|
||||
__version__ = VERSION
|
||||
__versionstr__ = "7.9.0a1"
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
@@ -64,3 +65,14 @@ __all__ = [
|
||||
"AuthorizationException",
|
||||
"ElasticsearchDeprecationWarning",
|
||||
]
|
||||
|
||||
try:
|
||||
# Asyncio only supported on Python 3.6+
|
||||
if sys.version_info < (3, 6):
|
||||
raise ImportError
|
||||
|
||||
from ._async.http_aiohttp import AIOHttpConnection
|
||||
|
||||
__all__ += ["AIOHttpConnection"]
|
||||
except (ImportError, SyntaxError):
|
||||
pass
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
# Licensed to Elasticsearch B.V under one or more agreements.
|
||||
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
|
||||
# See the LICENSE file in the project root for more information
|
||||
@@ -0,0 +1,23 @@
|
||||
# Licensed to Elasticsearch B.V under one or more agreements.
|
||||
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
|
||||
# See the LICENSE file in the project root for more information
|
||||
|
||||
import asyncio
|
||||
|
||||
# Hack supporting Python 3.6 asyncio which didn't have 'get_running_loop()'.
|
||||
# Essentially we want to get away from having users pass in a loop to us.
|
||||
# Instead we should call 'get_running_loop()' whenever we need
|
||||
# the currently running loop.
|
||||
# See: https://aiopg.readthedocs.io/en/stable/run_loop.html#implementation
|
||||
try:
|
||||
from asyncio import get_running_loop
|
||||
except ImportError:
|
||||
|
||||
def get_running_loop():
|
||||
loop = asyncio.get_event_loop()
|
||||
if not loop.is_running():
|
||||
raise RuntimeError("no running event loop")
|
||||
return loop
|
||||
|
||||
|
||||
__all__ = ["get_running_loop"]
|
||||
@@ -0,0 +1,311 @@
|
||||
# Licensed to Elasticsearch B.V under one or more agreements.
|
||||
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
|
||||
# See the LICENSE file in the project root for more information
|
||||
|
||||
import asyncio
|
||||
import ssl
|
||||
import os
|
||||
import urllib3
|
||||
import warnings
|
||||
|
||||
import aiohttp
|
||||
import yarl
|
||||
from aiohttp.client_exceptions import ServerFingerprintMismatch, ServerTimeoutError
|
||||
|
||||
from .compat import get_running_loop
|
||||
from ..connection import Connection
|
||||
from ..compat import urlencode
|
||||
from ..exceptions import (
|
||||
ConnectionError,
|
||||
ConnectionTimeout,
|
||||
ImproperlyConfigured,
|
||||
SSLError,
|
||||
)
|
||||
|
||||
|
||||
# sentinel value for `verify_certs`.
|
||||
# This is used to detect if a user is passing in a value
|
||||
# for SSL kwargs if also using an SSLContext.
|
||||
VERIFY_CERTS_DEFAULT = object()
|
||||
SSL_SHOW_WARN_DEFAULT = object()
|
||||
|
||||
CA_CERTS = None
|
||||
|
||||
try:
|
||||
import certifi
|
||||
|
||||
CA_CERTS = certifi.where()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class AIOHttpConnection(Connection):
|
||||
def __init__(
|
||||
self,
|
||||
host="localhost",
|
||||
port=None,
|
||||
http_auth=None,
|
||||
use_ssl=False,
|
||||
verify_certs=VERIFY_CERTS_DEFAULT,
|
||||
ssl_show_warn=SSL_SHOW_WARN_DEFAULT,
|
||||
ca_certs=None,
|
||||
client_cert=None,
|
||||
client_key=None,
|
||||
ssl_version=None,
|
||||
ssl_assert_fingerprint=None,
|
||||
maxsize=10,
|
||||
headers=None,
|
||||
ssl_context=None,
|
||||
http_compress=None,
|
||||
cloud_id=None,
|
||||
api_key=None,
|
||||
opaque_id=None,
|
||||
loop=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Default connection class for ``AsyncElasticsearch`` using the `aiohttp` library and the http protocol.
|
||||
|
||||
:arg host: hostname of the node (default: localhost)
|
||||
:arg port: port to use (integer, default: 9200)
|
||||
:arg timeout: default timeout in seconds (float, default: 10)
|
||||
:arg http_auth: optional http auth information as either ':' separated
|
||||
string or a tuple
|
||||
:arg use_ssl: use ssl for the connection if `True`
|
||||
:arg verify_certs: whether to verify SSL certificates
|
||||
:arg ssl_show_warn: show warning when verify certs is disabled
|
||||
:arg ca_certs: optional path to CA bundle.
|
||||
See https://urllib3.readthedocs.io/en/latest/security.html#using-certifi-with-urllib3
|
||||
for instructions how to get default set
|
||||
:arg client_cert: path to the file containing the private key and the
|
||||
certificate, or cert only if using client_key
|
||||
:arg client_key: path to the file containing the private key if using
|
||||
separate cert and key files (client_cert will contain only the cert)
|
||||
:arg ssl_version: version of the SSL protocol to use. Choices are:
|
||||
SSLv23 (default) SSLv2 SSLv3 TLSv1 (see ``PROTOCOL_*`` constants in the
|
||||
``ssl`` module for exact options for your environment).
|
||||
:arg ssl_assert_hostname: use hostname verification if not `False`
|
||||
:arg ssl_assert_fingerprint: verify the supplied certificate fingerprint if not `None`
|
||||
:arg maxsize: the number of connections which will be kept open to this
|
||||
host. See https://urllib3.readthedocs.io/en/1.4/pools.html#api for more
|
||||
information.
|
||||
:arg headers: any custom http headers to be add to requests
|
||||
:arg http_compress: Use gzip compression
|
||||
:arg cloud_id: The Cloud ID from ElasticCloud. Convenient way to connect to cloud instances.
|
||||
Other host connection params will be ignored.
|
||||
:arg api_key: optional API Key authentication as either base64 encoded string or a tuple.
|
||||
:arg opaque_id: Send this value in the 'X-Opaque-Id' HTTP header
|
||||
For tracing all requests made by this transport.
|
||||
:arg loop: asyncio Event Loop to use with aiohttp. This is set by default to the currently running loop.
|
||||
"""
|
||||
|
||||
self.headers = {}
|
||||
|
||||
super().__init__(
|
||||
host=host,
|
||||
port=port,
|
||||
use_ssl=use_ssl,
|
||||
headers=headers,
|
||||
http_compress=http_compress,
|
||||
cloud_id=cloud_id,
|
||||
api_key=api_key,
|
||||
opaque_id=opaque_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if http_auth is not None:
|
||||
if isinstance(http_auth, (tuple, list)):
|
||||
http_auth = ":".join(http_auth)
|
||||
self.headers.update(urllib3.make_headers(basic_auth=http_auth))
|
||||
|
||||
# if providing an SSL context, raise error if any other SSL related flag is used
|
||||
if ssl_context and (
|
||||
(verify_certs is not VERIFY_CERTS_DEFAULT)
|
||||
or (ssl_show_warn is not SSL_SHOW_WARN_DEFAULT)
|
||||
or ca_certs
|
||||
or client_cert
|
||||
or client_key
|
||||
or ssl_version
|
||||
):
|
||||
warnings.warn(
|
||||
"When using `ssl_context`, all other SSL related kwargs are ignored"
|
||||
)
|
||||
|
||||
self.ssl_assert_fingerprint = ssl_assert_fingerprint
|
||||
if self.use_ssl and ssl_context is None:
|
||||
ssl_context = ssl.SSLContext(ssl_version or ssl.PROTOCOL_TLS)
|
||||
|
||||
# Convert all sentinel values to their actual default
|
||||
# values if not using an SSLContext.
|
||||
if verify_certs is VERIFY_CERTS_DEFAULT:
|
||||
verify_certs = True
|
||||
if ssl_show_warn is SSL_SHOW_WARN_DEFAULT:
|
||||
ssl_show_warn = True
|
||||
|
||||
if verify_certs:
|
||||
ssl_context.verify_mode = ssl.CERT_REQUIRED
|
||||
ssl_context.check_hostname = True
|
||||
else:
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
ssl_context.check_hostname = False
|
||||
|
||||
ca_certs = CA_CERTS if ca_certs is None else ca_certs
|
||||
if verify_certs:
|
||||
if not ca_certs:
|
||||
raise ImproperlyConfigured(
|
||||
"Root certificates are missing for certificate "
|
||||
"validation. Either pass them in using the ca_certs parameter or "
|
||||
"install certifi to use it automatically."
|
||||
)
|
||||
else:
|
||||
if ssl_show_warn:
|
||||
warnings.warn(
|
||||
"Connecting to %s using SSL with verify_certs=False is insecure."
|
||||
% self.host
|
||||
)
|
||||
|
||||
if os.path.isfile(ca_certs):
|
||||
ssl_context.load_verify_locations(cafile=ca_certs)
|
||||
elif os.path.isdir(ca_certs):
|
||||
ssl_context.load_verify_locations(capath=ca_certs)
|
||||
else:
|
||||
raise ImproperlyConfigured("ca_certs parameter is not a path")
|
||||
|
||||
self.headers.setdefault("connection", "keep-alive")
|
||||
self.loop = loop
|
||||
self.session = None
|
||||
|
||||
# Parameters for creating an aiohttp.ClientSession later.
|
||||
self._limit = maxsize
|
||||
self._http_auth = http_auth
|
||||
self._ssl_context = ssl_context
|
||||
|
||||
async def perform_request(
|
||||
self, method, url, params=None, body=None, timeout=None, ignore=(), headers=None
|
||||
):
|
||||
if self.session is None:
|
||||
await self._create_aiohttp_session()
|
||||
|
||||
orig_body = body
|
||||
url_path = url
|
||||
if params:
|
||||
query_string = urlencode(params)
|
||||
else:
|
||||
query_string = ""
|
||||
|
||||
# There is a bug in aiohttp that disables the re-use
|
||||
# of the connection in the pool when method=HEAD.
|
||||
# See: aio-libs/aiohttp#1769
|
||||
is_head = False
|
||||
if method == "HEAD":
|
||||
method = "GET"
|
||||
is_head = True
|
||||
|
||||
# Provide correct URL object to avoid string parsing in low-level code
|
||||
url = yarl.URL.build(
|
||||
scheme=self.scheme,
|
||||
host=self.hostname,
|
||||
port=self.port,
|
||||
path=url,
|
||||
query_string=query_string,
|
||||
encoded=True,
|
||||
)
|
||||
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=timeout if timeout is not None else self.timeout
|
||||
)
|
||||
|
||||
req_headers = self.headers.copy()
|
||||
if headers:
|
||||
req_headers.update(headers)
|
||||
|
||||
if self.http_compress and body:
|
||||
body = self._gzip_compress(body)
|
||||
req_headers["content-encoding"] = "gzip"
|
||||
|
||||
start = self.loop.time()
|
||||
try:
|
||||
async with self.session.request(
|
||||
method,
|
||||
url,
|
||||
data=body,
|
||||
headers=req_headers,
|
||||
timeout=timeout,
|
||||
fingerprint=self.ssl_assert_fingerprint,
|
||||
) as response:
|
||||
if is_head: # We actually called 'GET' so throw away the data.
|
||||
await response.release()
|
||||
raw_data = ""
|
||||
else:
|
||||
raw_data = await response.text()
|
||||
duration = self.loop.time() - start
|
||||
|
||||
# We want to reraise a cancellation.
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
self.log_request_fail(
|
||||
method, url, url_path, orig_body, self.loop.time() - start, exception=e
|
||||
)
|
||||
if isinstance(e, ServerFingerprintMismatch):
|
||||
raise SSLError("N/A", str(e), e)
|
||||
if isinstance(e, (asyncio.TimeoutError, ServerTimeoutError)):
|
||||
raise ConnectionTimeout("TIMEOUT", str(e), e)
|
||||
raise ConnectionError("N/A", str(e), e)
|
||||
|
||||
# raise warnings if any from the 'Warnings' header.
|
||||
warning_headers = response.headers.getall("warning", ())
|
||||
self._raise_warnings(warning_headers)
|
||||
|
||||
# raise errors based on http status codes, let the client handle those if needed
|
||||
if not (200 <= response.status < 300) and response.status not in ignore:
|
||||
self.log_request_fail(
|
||||
method,
|
||||
url,
|
||||
url_path,
|
||||
orig_body,
|
||||
duration,
|
||||
status_code=response.status,
|
||||
response=raw_data,
|
||||
)
|
||||
self._raise_error(response.status, raw_data)
|
||||
|
||||
self.log_request_success(
|
||||
method, url, url_path, orig_body, response.status, raw_data, duration
|
||||
)
|
||||
|
||||
return response.status, response.headers, raw_data
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Explicitly closes connection
|
||||
"""
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
|
||||
async def _create_aiohttp_session(self):
|
||||
"""Creates an aiohttp.ClientSession(). This is delayed until
|
||||
the first call to perform_request() so that AsyncTransport has
|
||||
a chance to set AIOHttpConnection.loop
|
||||
"""
|
||||
if self.loop is None:
|
||||
self.loop = get_running_loop()
|
||||
self.session = aiohttp.ClientSession(
|
||||
headers=self.headers,
|
||||
auto_decompress=True,
|
||||
loop=self.loop,
|
||||
cookie_jar=aiohttp.DummyCookieJar(),
|
||||
response_class=ESClientResponse,
|
||||
connector=aiohttp.TCPConnector(
|
||||
limit=self._limit, use_dns_cache=True, ssl=self._ssl_context,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ESClientResponse(aiohttp.ClientResponse):
|
||||
async def text(self, encoding=None, errors="strict"):
|
||||
if self._body is None:
|
||||
await self.read()
|
||||
|
||||
return self._body.decode("utf-8", "surrogatepass")
|
||||
@@ -65,6 +65,7 @@ class Connection(object):
|
||||
http_compress=None,
|
||||
cloud_id=None,
|
||||
api_key=None,
|
||||
opaque_id=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
@@ -101,6 +102,8 @@ class Connection(object):
|
||||
headers = headers or {}
|
||||
for key in headers:
|
||||
self.headers[key.lower()] = headers[key]
|
||||
if opaque_id:
|
||||
self.headers["x-opaque-id"] = opaque_id
|
||||
|
||||
self.headers.setdefault("content-type", "application/json")
|
||||
self.headers.setdefault("user-agent", self._get_default_user_agent())
|
||||
@@ -118,6 +121,7 @@ class Connection(object):
|
||||
self.use_ssl = use_ssl
|
||||
self.http_compress = http_compress or False
|
||||
|
||||
self.scheme = scheme
|
||||
self.hostname = host
|
||||
self.port = port
|
||||
self.host = "%s://%s" % (scheme, host)
|
||||
@@ -170,9 +174,7 @@ class Connection(object):
|
||||
warning_messages.append(header)
|
||||
|
||||
for message in warning_messages:
|
||||
warnings.warn(
|
||||
message, category=ElasticsearchDeprecationWarning, stacklevel=6
|
||||
)
|
||||
warnings.warn(message, category=ElasticsearchDeprecationWarning)
|
||||
|
||||
def _pretty_json(self, data):
|
||||
# pretty JSON in tracer curl logs
|
||||
|
||||
@@ -25,6 +25,7 @@ tests_require = [
|
||||
"pytest",
|
||||
"pytest-cov",
|
||||
]
|
||||
async_require = ["aiohttp>=3,<4", "yarl"]
|
||||
|
||||
docs_require = ["sphinx<1.7", "sphinx_rtd_theme"]
|
||||
generate_require = ["black", "jinja2"]
|
||||
@@ -67,5 +68,6 @@ setup(
|
||||
"develop": tests_require + docs_require + generate_require,
|
||||
"docs": docs_require,
|
||||
"requests": ["requests>=2.4.0, <3.0.0"],
|
||||
"async": async_require,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -78,9 +78,13 @@ def run_all(argv=None):
|
||||
"--log-level=DEBUG",
|
||||
"--cache-clear",
|
||||
"-vv",
|
||||
join(abspath(dirname(__file__)), "test_exceptions.py"),
|
||||
]
|
||||
|
||||
if sys.version_info < (3, 6):
|
||||
argv.append("--ignore=test_elasticsearch/test_async/")
|
||||
|
||||
argv.append(abspath(dirname(__file__)),)
|
||||
|
||||
exit_code = 0
|
||||
try:
|
||||
subprocess.check_call(argv, stdout=sys.stdout, stderr=sys.stderr)
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
# Licensed to Elasticsearch B.V under one or more agreements.
|
||||
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
|
||||
# See the LICENSE file in the project root for more information
|
||||
@@ -0,0 +1,310 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Licensed to Elasticsearch B.V under one or more agreements.
|
||||
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
|
||||
# See the LICENSE file in the project root for more information
|
||||
|
||||
import ssl
|
||||
import gzip
|
||||
import io
|
||||
from mock import patch
|
||||
import warnings
|
||||
from platform import python_version
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from elasticsearch import AIOHttpConnection
|
||||
from elasticsearch import __versionstr__
|
||||
from ..test_cases import TestCase, SkipTest
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def gzip_decompress(data):
|
||||
buf = gzip.GzipFile(fileobj=io.BytesIO(data), mode="rb")
|
||||
return buf.read()
|
||||
|
||||
|
||||
class TestAIOHttpConnection(TestCase):
|
||||
async def _get_mock_connection(self, connection_params={}, response_body=b"{}"):
|
||||
con = AIOHttpConnection(**connection_params)
|
||||
await con._create_aiohttp_session()
|
||||
|
||||
def _dummy_request(*args, **kwargs):
|
||||
class DummyResponse:
|
||||
async def __aenter__(self, *_, **__):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_, **__):
|
||||
pass
|
||||
|
||||
async def text(self):
|
||||
return response_body.decode("utf-8", "surrogatepass")
|
||||
|
||||
dummy_response = DummyResponse()
|
||||
dummy_response.headers = {}
|
||||
dummy_response.status = 200
|
||||
_dummy_request.call_args = (args, kwargs)
|
||||
return dummy_response
|
||||
|
||||
con.session.request = _dummy_request
|
||||
return con
|
||||
|
||||
async def test_ssl_context(self):
|
||||
try:
|
||||
context = ssl.create_default_context()
|
||||
except AttributeError:
|
||||
# if create_default_context raises an AttributeError Exception
|
||||
# it means SSLContext is not available for that version of python
|
||||
# and we should skip this test.
|
||||
raise SkipTest(
|
||||
"Test test_ssl_context is skipped cause SSLContext is not available for this version of ptyhon"
|
||||
)
|
||||
|
||||
con = AIOHttpConnection(use_ssl=True, ssl_context=context)
|
||||
await con._create_aiohttp_session()
|
||||
self.assertTrue(con.use_ssl)
|
||||
self.assertEqual(con.session.connector._ssl, context)
|
||||
|
||||
def test_opaque_id(self):
|
||||
con = AIOHttpConnection(opaque_id="app-1")
|
||||
self.assertEqual(con.headers["x-opaque-id"], "app-1")
|
||||
|
||||
def test_http_cloud_id(self):
|
||||
con = AIOHttpConnection(
|
||||
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng=="
|
||||
)
|
||||
self.assertTrue(con.use_ssl)
|
||||
self.assertEqual(
|
||||
con.host, "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io"
|
||||
)
|
||||
self.assertEqual(con.port, None)
|
||||
self.assertEqual(
|
||||
con.hostname, "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io"
|
||||
)
|
||||
self.assertTrue(con.http_compress)
|
||||
|
||||
con = AIOHttpConnection(
|
||||
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==",
|
||||
port=9243,
|
||||
)
|
||||
self.assertEqual(
|
||||
con.host,
|
||||
"https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io:9243",
|
||||
)
|
||||
self.assertEqual(con.port, 9243)
|
||||
self.assertEqual(
|
||||
con.hostname, "4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io"
|
||||
)
|
||||
|
||||
def test_api_key_auth(self):
|
||||
# test with tuple
|
||||
con = AIOHttpConnection(
|
||||
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==",
|
||||
api_key=("elastic", "changeme1"),
|
||||
)
|
||||
self.assertEqual(
|
||||
con.headers["authorization"], "ApiKey ZWxhc3RpYzpjaGFuZ2VtZTE="
|
||||
)
|
||||
self.assertEqual(
|
||||
con.host, "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io"
|
||||
)
|
||||
|
||||
# test with base64 encoded string
|
||||
con = AIOHttpConnection(
|
||||
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==",
|
||||
api_key="ZWxhc3RpYzpjaGFuZ2VtZTI=",
|
||||
)
|
||||
self.assertEqual(
|
||||
con.headers["authorization"], "ApiKey ZWxhc3RpYzpjaGFuZ2VtZTI="
|
||||
)
|
||||
self.assertEqual(
|
||||
con.host, "https://4fa8821e75634032bed1cf22110e2f97.us-east-1.aws.found.io"
|
||||
)
|
||||
|
||||
async def test_no_http_compression(self):
|
||||
con = await self._get_mock_connection()
|
||||
self.assertFalse(con.http_compress)
|
||||
self.assertNotIn("accept-encoding", con.headers)
|
||||
|
||||
await con.perform_request("GET", "/")
|
||||
|
||||
_, kwargs = con.session.request.call_args
|
||||
|
||||
self.assertFalse(kwargs["data"])
|
||||
self.assertNotIn("accept-encoding", kwargs["headers"])
|
||||
self.assertNotIn("content-encoding", kwargs["headers"])
|
||||
|
||||
async def test_http_compression(self):
|
||||
con = await self._get_mock_connection({"http_compress": True})
|
||||
self.assertTrue(con.http_compress)
|
||||
self.assertEqual(con.headers["accept-encoding"], "gzip,deflate")
|
||||
|
||||
# 'content-encoding' shouldn't be set at a connection level.
|
||||
# Should be applied only if the request is sent with a body.
|
||||
self.assertNotIn("content-encoding", con.headers)
|
||||
|
||||
await con.perform_request("GET", "/", body=b"{}")
|
||||
|
||||
_, kwargs = con.session.request.call_args
|
||||
|
||||
self.assertEqual(gzip_decompress(kwargs["data"]), b"{}")
|
||||
self.assertEqual(kwargs["headers"]["accept-encoding"], "gzip,deflate")
|
||||
self.assertEqual(kwargs["headers"]["content-encoding"], "gzip")
|
||||
|
||||
await con.perform_request("GET", "/")
|
||||
|
||||
_, kwargs = con.session.request.call_args
|
||||
|
||||
self.assertFalse(kwargs["data"])
|
||||
self.assertEqual(kwargs["headers"]["accept-encoding"], "gzip,deflate")
|
||||
self.assertNotIn("content-encoding", kwargs["headers"])
|
||||
|
||||
def test_cloud_id_http_compress_override(self):
|
||||
# 'http_compress' will be 'True' by default for connections with
|
||||
# 'cloud_id' set but should prioritize user-defined values.
|
||||
con = AIOHttpConnection(
|
||||
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==",
|
||||
)
|
||||
self.assertEqual(con.http_compress, True)
|
||||
|
||||
con = AIOHttpConnection(
|
||||
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==",
|
||||
http_compress=False,
|
||||
)
|
||||
self.assertEqual(con.http_compress, False)
|
||||
|
||||
con = AIOHttpConnection(
|
||||
cloud_id="cluster:dXMtZWFzdC0xLmF3cy5mb3VuZC5pbyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5NyQ0ZmE4ODIxZTc1NjM0MDMyYmVkMWNmMjIxMTBlMmY5Ng==",
|
||||
http_compress=True,
|
||||
)
|
||||
self.assertEqual(con.http_compress, True)
|
||||
|
||||
def test_default_user_agent(self):
|
||||
con = AIOHttpConnection()
|
||||
self.assertEqual(
|
||||
con._get_default_user_agent(),
|
||||
"elasticsearch-py/%s (Python %s)" % (__versionstr__, python_version()),
|
||||
)
|
||||
|
||||
def test_timeout_set(self):
|
||||
con = AIOHttpConnection(timeout=42)
|
||||
self.assertEqual(42, con.timeout)
|
||||
|
||||
def test_keep_alive_is_on_by_default(self):
|
||||
con = AIOHttpConnection()
|
||||
self.assertEqual(
|
||||
{
|
||||
"connection": "keep-alive",
|
||||
"content-type": "application/json",
|
||||
"user-agent": con._get_default_user_agent(),
|
||||
},
|
||||
con.headers,
|
||||
)
|
||||
|
||||
def test_http_auth(self):
|
||||
con = AIOHttpConnection(http_auth="username:secret")
|
||||
self.assertEqual(
|
||||
{
|
||||
"authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
|
||||
"connection": "keep-alive",
|
||||
"content-type": "application/json",
|
||||
"user-agent": con._get_default_user_agent(),
|
||||
},
|
||||
con.headers,
|
||||
)
|
||||
|
||||
def test_http_auth_tuple(self):
|
||||
con = AIOHttpConnection(http_auth=("username", "secret"))
|
||||
self.assertEqual(
|
||||
{
|
||||
"authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
|
||||
"content-type": "application/json",
|
||||
"connection": "keep-alive",
|
||||
"user-agent": con._get_default_user_agent(),
|
||||
},
|
||||
con.headers,
|
||||
)
|
||||
|
||||
def test_http_auth_list(self):
|
||||
con = AIOHttpConnection(http_auth=["username", "secret"])
|
||||
self.assertEqual(
|
||||
{
|
||||
"authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
|
||||
"content-type": "application/json",
|
||||
"connection": "keep-alive",
|
||||
"user-agent": con._get_default_user_agent(),
|
||||
},
|
||||
con.headers,
|
||||
)
|
||||
|
||||
def test_uses_https_if_verify_certs_is_off(self):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
con = AIOHttpConnection(use_ssl=True, verify_certs=False)
|
||||
self.assertEqual(1, len(w))
|
||||
self.assertEqual(
|
||||
"Connecting to https://localhost:9200 using SSL with verify_certs=False is insecure.",
|
||||
str(w[0].message),
|
||||
)
|
||||
|
||||
self.assertTrue(con.use_ssl)
|
||||
self.assertEqual(con.scheme, "https")
|
||||
self.assertEqual(con.host, "https://localhost:9200")
|
||||
|
||||
async def test_nowarn_when_test_uses_https_if_verify_certs_is_off(self):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
con = AIOHttpConnection(
|
||||
use_ssl=True, verify_certs=False, ssl_show_warn=False
|
||||
)
|
||||
await con._create_aiohttp_session()
|
||||
self.assertEqual(0, len(w))
|
||||
|
||||
self.assertIsInstance(con.session, aiohttp.ClientSession)
|
||||
|
||||
def test_doesnt_use_https_if_not_specified(self):
|
||||
con = AIOHttpConnection()
|
||||
self.assertFalse(con.use_ssl)
|
||||
|
||||
def test_no_warning_when_using_ssl_context(self):
|
||||
ctx = ssl.create_default_context()
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
AIOHttpConnection(ssl_context=ctx)
|
||||
self.assertEqual(0, len(w), str([x.message for x in w]))
|
||||
|
||||
def test_warns_if_using_non_default_ssl_kwargs_with_ssl_context(self):
|
||||
for kwargs in (
|
||||
{"ssl_show_warn": False},
|
||||
{"ssl_show_warn": True},
|
||||
{"verify_certs": True},
|
||||
{"verify_certs": False},
|
||||
{"ca_certs": "/path/to/certs"},
|
||||
{"ssl_show_warn": True, "ca_certs": "/path/to/certs"},
|
||||
):
|
||||
kwargs["ssl_context"] = ssl.create_default_context()
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
AIOHttpConnection(**kwargs)
|
||||
|
||||
self.assertEqual(1, len(w))
|
||||
self.assertEqual(
|
||||
"When using `ssl_context`, all other SSL related kwargs are ignored",
|
||||
str(w[0].message),
|
||||
)
|
||||
|
||||
@patch("elasticsearch.connection.base.logger")
|
||||
async def test_uncompressed_body_logged(self, logger):
|
||||
con = await self._get_mock_connection(connection_params={"http_compress": True})
|
||||
await con.perform_request("GET", "/", body=b'{"example": "body"}')
|
||||
|
||||
self.assertEqual(2, logger.debug.call_count)
|
||||
req, resp = logger.debug.call_args_list
|
||||
|
||||
self.assertEqual('> {"example": "body"}', req[0][0] % req[0][1:])
|
||||
self.assertEqual("< {}", resp[0][0] % resp[0][1:])
|
||||
|
||||
async def test_surrogatepass_into_bytes(self):
|
||||
buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"
|
||||
con = await self._get_mock_connection(response_body=buf)
|
||||
status, headers, data = await con.perform_request("GET", "/")
|
||||
self.assertEqual(u"你好\uda6a", data)
|
||||
Reference in New Issue
Block a user