Files
opensearch-pyd/test_opensearchpy/test_async/test_connection.py
T
Merlin 12c379d32d Implement AsyncOpenSearch() parameter ssl_assert_hostname (#843)
* Implement AsyncOpenSearch() parameter `ssl_assert_hostname` to allow disabling SSL hostname verification

Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com>

* Update PR link

Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com>

* Add test

Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com>

* Update docs

Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com>

* Add test for default value

Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com>

* Fix formatting

Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com>

* Fix test failing on Python >= 3.12.7

Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com>

* Fix formatting

Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com>

---------

Signed-off-by: merlinz01 <158784988+merlinz01@users.noreply.github.com>
Signed-off-by: Daniel (dB.) Doubrovkine <dblock@amazon.com>
Co-authored-by: Daniel (dB.) Doubrovkine <dblock@amazon.com>
2024-11-16 08:29:10 -05:00

560 lines
20 KiB
Python

# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
#
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import gzip
import io
import json
import ssl
import sys
import warnings
from platform import python_version
from typing import Any
from unittest.mock import MagicMock, patch
import aiohttp
import pytest
from _pytest.mark.structures import MarkDecorator
from multidict import CIMultiDict
from pytest import raises
from opensearchpy import AIOHttpConnection, AsyncOpenSearch, __versionstr__, serializer
from opensearchpy.compat import reraise_exceptions
from opensearchpy.connection import Connection, async_connections
from opensearchpy.exceptions import ConnectionError, NotFoundError, TransportError
from test_opensearchpy.test_http_server import TestHTTPServer
# Module-level ``pytestmark``: pytest applies this marker (pytest-asyncio's
# ``asyncio`` mark) to every test in this module, so the ``async def`` tests
# below are collected and awaited.
pytestmark: MarkDecorator = pytest.mark.asyncio
class TestAIOHttpConnection:
    """Unit tests for ``AIOHttpConnection``; most stub the session's ``request``."""

    async def _get_mock_connection(
        self,
        connection_params: Any = None,
        response_code: int = 200,
        response_body: bytes = b"{}",
        response_headers: Any = None,
    ) -> Any:
        """
        Build an ``AIOHttpConnection`` whose session ``request`` is a stub.

        :param connection_params: keyword arguments for ``AIOHttpConnection``;
            ``None`` means no extra arguments.
        :param response_code: status the stubbed response reports.
        :param response_body: raw bytes the stubbed response's ``text()``
            returns, decoded as UTF-8 with ``surrogatepass``.
        :param response_headers: headers of the stubbed response; ``None``
            means no headers.
        :return: the connection. The positional and keyword arguments of the
            most recent stubbed request are stored as
            ``con.session.request.call_args``. Callers should
            ``await con.close()`` when done.
        """
        # ``None`` defaults avoid sharing one mutable dict across calls.
        connection_params = {} if connection_params is None else connection_params
        response_headers = {} if response_headers is None else response_headers

        con = AIOHttpConnection(**connection_params)
        await con._create_aiohttp_session()

        def _dummy_request(*args: Any, **kwargs: Any) -> Any:
            class DummyResponse:
                async def __aenter__(self, *_: Any, **__: Any) -> Any:
                    return self

                async def __aexit__(self, *_: Any, **__: Any) -> None:
                    pass

                async def text(self) -> Any:
                    return response_body.decode("utf-8", "surrogatepass")

            dummy_response: Any = DummyResponse()
            dummy_response.headers = CIMultiDict(**response_headers)
            dummy_response.status = response_code
            # Record the call so tests can inspect what would have been sent.
            _dummy_request.call_args = (args, kwargs)  # type: ignore
            return dummy_response

        con.session.request = _dummy_request
        return con

    async def test_ssl_context(self) -> None:
        try:
            context = ssl.create_default_context()
        except AttributeError:
            # if create_default_context raises an AttributeError Exception
            # it means SSLContext is not available for that version of python
            # and we should skip this test.
            pytest.skip(
                "Test test_ssl_context is skipped cause SSLContext is "
                "not available for this version of Python"
            )
        con = AIOHttpConnection(use_ssl=True, ssl_context=context)
        await con._create_aiohttp_session()
        assert con.use_ssl
        assert con.session.connector._ssl == context
        await con.close()

    async def test_ssl_assert_hostname(self) -> None:
        con = AIOHttpConnection(use_ssl=True, ssl_assert_hostname=True)
        await con._create_aiohttp_session()
        assert con.use_ssl
        assert con.session.connector._ssl.check_hostname is True
        await con.close()

        con = AIOHttpConnection(use_ssl=True, ssl_assert_hostname=False)
        await con._create_aiohttp_session()
        assert con.use_ssl
        assert con.session.connector._ssl.check_hostname is False
        await con.close()

    async def test_opaque_id(self) -> None:
        con = AIOHttpConnection(opaque_id="app-1")
        assert con.headers["x-opaque-id"] == "app-1"

    async def test_no_http_compression(self) -> None:
        con = await self._get_mock_connection()
        assert not con.http_compress
        assert "accept-encoding" not in con.headers

        await con.perform_request("GET", "/")

        _, kwargs = con.session.request.call_args

        assert not kwargs["data"]
        assert "accept-encoding" not in kwargs["headers"]
        assert "content-encoding" not in kwargs["headers"]
        await con.close()

    async def test_http_compression(self) -> None:
        con = await self._get_mock_connection({"http_compress": True})
        assert con.http_compress
        assert con.headers["accept-encoding"] == "gzip,deflate"

        # 'content-encoding' shouldn't be set at a connection level.
        # Should be applied only if the request is sent with a body.
        assert "content-encoding" not in con.headers

        await con.perform_request("GET", "/", body=b"{}")

        _, kwargs = con.session.request.call_args

        buf = gzip.GzipFile(fileobj=io.BytesIO(kwargs["data"]), mode="rb")
        assert buf.read() == b"{}"
        assert kwargs["headers"]["accept-encoding"] == "gzip,deflate"
        assert kwargs["headers"]["content-encoding"] == "gzip"

        await con.perform_request("GET", "/")

        _, kwargs = con.session.request.call_args

        assert not kwargs["data"]
        assert kwargs["headers"]["accept-encoding"] == "gzip,deflate"
        assert "content-encoding" not in kwargs["headers"]
        await con.close()

    async def test_url_prefix(self) -> None:
        con = await self._get_mock_connection(
            connection_params={"url_prefix": "/_search/"}
        )
        assert con.url_prefix == "/_search"

        await con.perform_request("GET", "/")

        # Need to convert the yarl URL to a string to compare.
        method, yarl_url = con.session.request.call_args[0]
        assert method == "GET" and str(yarl_url) == "http://localhost:9200/_search/"
        await con.close()

    async def test_default_user_agent(self) -> None:
        con = AIOHttpConnection()
        assert con._get_default_user_agent() == "opensearch-py/{} (Python {})".format(
            __versionstr__,
            python_version(),
        )

    async def test_timeout_set(self) -> None:
        con = AIOHttpConnection(timeout=42)
        assert 42 == con.timeout

    async def test_keep_alive_is_on_by_default(self) -> None:
        con = AIOHttpConnection()
        assert {
            "connection": "keep-alive",
            "content-type": "application/json",
            "user-agent": con._get_default_user_agent(),
        } == con.headers

    async def test_http_auth(self) -> None:
        con = AIOHttpConnection(http_auth="username:secret")
        assert {
            "authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
            "connection": "keep-alive",
            "content-type": "application/json",
            "user-agent": con._get_default_user_agent(),
        } == con.headers

    async def test_http_auth_tuple(self) -> None:
        con = AIOHttpConnection(http_auth=("username", "secret"))
        assert {
            "authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
            "content-type": "application/json",
            "connection": "keep-alive",
            "user-agent": con._get_default_user_agent(),
        } == con.headers

    async def test_http_auth_list(self) -> None:
        con = AIOHttpConnection(http_auth=["username", "secret"])
        assert {
            "authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
            "content-type": "application/json",
            "connection": "keep-alive",
            "user-agent": con._get_default_user_agent(),
        } == con.headers

    async def test_uses_https_if_verify_certs_is_off(self) -> None:
        with warnings.catch_warnings(record=True) as w:
            con = AIOHttpConnection(use_ssl=True, verify_certs=False)
            assert 1 == len(w)
            assert (
                "Connecting to https://localhost:9200 using SSL with "
                "verify_certs=False is insecure." == str(w[0].message)
            )

        assert con.use_ssl
        assert con.scheme == "https"
        assert con.host == "https://localhost:9200"

    async def test_nowarn_when_test_uses_https_if_verify_certs_is_off(self) -> None:
        # Depending on the Python/aiohttp versions in use, aiohttp may warn
        # that ``enable_cleanup_closed`` is ignored. Its text embeds
        # ``sys.version_info``, so build it from the running interpreter
        # instead of hard-coding a single Python release.
        cleanup_closed_warning = (
            "enable_cleanup_closed ignored because "
            "https://github.com/python/cpython/pull/118960 is fixed in "
            f"Python version {sys.version_info}"
        )
        with warnings.catch_warnings(record=True) as w:
            con = AIOHttpConnection(
                use_ssl=True, verify_certs=False, ssl_show_warn=False
            )
            await con._create_aiohttp_session()

        # ssl_show_warn=False must suppress the insecure-connection warning;
        # the only tolerated warning is aiohttp's version-dependent one above.
        unexpected = [x for x in w if str(x.message) != cleanup_closed_warning]
        assert unexpected == [], str([x.message for x in unexpected])
        assert isinstance(con.session, aiohttp.ClientSession)
        await con.close()

    async def test_doesnt_use_https_if_not_specified(self) -> None:
        con = AIOHttpConnection()
        assert not con.use_ssl

    async def test_no_warning_when_using_ssl_context(self) -> None:
        ctx = ssl.create_default_context()
        with warnings.catch_warnings(record=True) as w:
            AIOHttpConnection(ssl_context=ctx)
            assert w == [], str([x.message for x in w])

    async def test_warns_if_using_non_default_ssl_kwargs_with_ssl_context(self) -> None:
        kwargs: Any
        for kwargs in (
            {"ssl_show_warn": False},
            {"ssl_show_warn": True},
            {"verify_certs": True},
            {"verify_certs": False},
            {"ca_certs": "/path/to/certs"},
            {"ssl_show_warn": True, "ca_certs": "/path/to/certs"},
        ):
            kwargs["ssl_context"] = ssl.create_default_context()

            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")

                AIOHttpConnection(**kwargs)

                assert 1 == len(w)
                assert (
                    "When using `ssl_context`, all other SSL related kwargs are ignored"
                    == str(w[0].message)
                )

    @patch("ssl.SSLContext", return_value=MagicMock())
    async def test_uses_given_ca_certs(self, ssl_context: Any, tmp_path: Any) -> None:
        path = tmp_path / "ca_certs.pem"
        path.touch()
        ssl_context.return_value.load_verify_locations.return_value = None
        AIOHttpConnection(use_ssl=True, ca_certs=str(path))
        ssl_context.return_value.load_verify_locations.assert_called_once_with(
            cafile=str(path)
        )

    @patch("ssl.SSLContext", return_value=MagicMock())
    async def test_uses_default_ca_certs(self, ssl_context: Any) -> None:
        ssl_context.return_value.load_verify_locations.return_value = None
        AIOHttpConnection(use_ssl=True)
        ssl_context.return_value.load_verify_locations.assert_called_once_with(
            cafile=Connection.default_ca_certs()
        )

    @patch("ssl.SSLContext", return_value=MagicMock())
    async def test_uses_no_ca_certs(self, ssl_context: Any) -> None:
        ssl_context.return_value.load_verify_locations.return_value = None
        AIOHttpConnection(use_ssl=True, verify_certs=False)
        ssl_context.return_value.load_verify_locations.assert_not_called()

    async def test_trust_env(self) -> None:
        con: Any = AIOHttpConnection(trust_env=True)
        await con._create_aiohttp_session()

        assert con._trust_env is True
        assert con.session.trust_env is True
        await con.close()

    async def test_trust_env_default_value_is_false(self) -> None:
        con = AIOHttpConnection()
        await con._create_aiohttp_session()

        assert con._trust_env is False
        assert con.session.trust_env is False
        await con.close()

    @patch("opensearchpy.connection.base.logger")
    async def test_uncompressed_body_logged(self, logger: Any) -> None:
        con = await self._get_mock_connection(connection_params={"http_compress": True})
        await con.perform_request("GET", "/", body=b'{"example": "body"}')

        assert 2 == logger.debug.call_count
        req, resp = logger.debug.call_args_list

        assert '> {"example": "body"}' == req[0][0] % req[0][1:]
        assert "< {}" == resp[0][0] % resp[0][1:]
        await con.close()

    @patch("opensearchpy.connection.base.logger", return_value=MagicMock())
    async def test_body_not_logged(self, logger: Any) -> None:
        logger.isEnabledFor.return_value = False

        con = await self._get_mock_connection()
        await con.perform_request("GET", "/", body=b'{"example": "body"}')

        assert logger.isEnabledFor.call_count == 1
        assert logger.debug.call_count == 0
        await con.close()

    @patch("opensearchpy.connection.base.logger")
    async def test_failure_body_logged(self, logger: Any) -> None:
        con = await self._get_mock_connection(response_code=404)
        with pytest.raises(NotFoundError) as e:
            await con.perform_request("GET", "/invalid", body=b'{"example": "body"}')
        assert str(e.value) == "NotFoundError(404, '{}')"

        assert 2 == logger.debug.call_count
        req, resp = logger.debug.call_args_list

        assert '> {"example": "body"}' == req[0][0] % req[0][1:]
        assert "< {}" == resp[0][0] % resp[0][1:]
        await con.close()

    @patch("opensearchpy.connection.base.logger", return_value=MagicMock())
    async def test_failure_body_not_logged(self, logger: Any) -> None:
        logger.isEnabledFor.return_value = False

        con = await self._get_mock_connection(response_code=404)
        with pytest.raises(NotFoundError) as e:
            await con.perform_request("GET", "/invalid")
        assert str(e.value) == "NotFoundError(404, '{}')"

        assert logger.isEnabledFor.call_count == 1
        assert logger.debug.call_count == 0
        await con.close()

    async def test_surrogatepass_into_bytes(self) -> None:
        buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"
        con = await self._get_mock_connection(response_body=buf)
        _, _, data = await con.perform_request("GET", "/")
        assert "你好\uda6a" == data  # fmt: skip
        await con.close()

    @pytest.mark.parametrize("exception_cls", reraise_exceptions)  # type: ignore
    async def test_recursion_error_reraised(self, exception_cls: Any) -> None:
        conn = AIOHttpConnection()

        def request_raise(*_: Any, **__: Any) -> Any:
            raise exception_cls("Wasn't modified!")

        await conn._create_aiohttp_session()
        conn.session.request = request_raise

        with pytest.raises(exception_cls) as e:
            await conn.perform_request("GET", "/")
        assert str(e.value) == "Wasn't modified!"
        await conn.close()

    async def test_json_errors_are_parsed(self) -> None:
        con = await self._get_mock_connection(
            response_code=400,
            response_body=b'{"error": {"type": "snapshot_in_progress_exception"}}',
            response_headers={"Content-Type": "application/json;"},
        )
        try:
            with pytest.raises(TransportError) as e:
                await con.perform_request("POST", "/", body=b'{"some": "json"')

            assert e.value.error == "snapshot_in_progress_exception"
        finally:
            await con.close()
class TestConnectionHttpServer:
    """Tests the HTTP connection implementations against a live server E2E"""

    server: Any

    @classmethod
    def setup_class(cls) -> None:
        """
        Start the local ``TestHTTPServer`` on port 8081 once for this class.
        """
        cls.server = TestHTTPServer(port=8081)
        cls.server.start()

    @classmethod
    def teardown_class(cls) -> None:
        """
        Stop the server started in ``setup_class``.
        """
        cls.server.stop()

    async def httpserver(self, conn: Any, **kwargs: Any) -> Any:
        """
        Send ``GET /`` through ``conn`` and decode the JSON reply.

        :param conn: the connection under test.
        :param kwargs: extra keyword arguments for ``perform_request``.
        :return: ``(status, decoded JSON body)``.

        The connection is closed afterwards, even if the request fails, so its
        aiohttp session is not leaked.
        """
        try:
            status, _, data = await conn.perform_request("GET", "/", **kwargs)
        finally:
            await conn.close()
        data = json.loads(data)
        return (status, data)

    async def test_aiohttp_connection(self) -> None:
        # Defaults
        conn = AIOHttpConnection("localhost", port=8081, use_ssl=False)
        user_agent = conn._get_default_user_agent()
        status, data = await self.httpserver(conn)
        assert status == 200
        assert data["method"] == "GET"
        assert data["headers"] == {
            "Content-Type": "application/json",
            "Host": "localhost:8081",
            "User-Agent": user_agent,
        }

        # http_compress=False
        conn = AIOHttpConnection(
            "localhost", port=8081, use_ssl=False, http_compress=False
        )
        status, data = await self.httpserver(conn)
        assert status == 200
        assert data["method"] == "GET"
        assert data["headers"] == {
            "Content-Type": "application/json",
            "Host": "localhost:8081",
            "User-Agent": user_agent,
        }

        # http_compress=True
        conn = AIOHttpConnection(
            "localhost", port=8081, use_ssl=False, http_compress=True
        )
        status, data = await self.httpserver(conn)
        assert status == 200
        assert data["headers"] == {
            "Accept-Encoding": "gzip,deflate",
            "Content-Type": "application/json",
            "Host": "localhost:8081",
            "User-Agent": user_agent,
        }

        # Headers
        conn = AIOHttpConnection(
            "localhost",
            port=8081,
            use_ssl=False,
            http_compress=True,
            headers={"header1": "value1"},
        )
        status, data = await self.httpserver(
            conn, headers={"header2": "value2", "header1": "override!"}
        )
        assert status == 200
        assert data["headers"] == {
            "Accept-Encoding": "gzip,deflate",
            "Content-Type": "application/json",
            "Host": "localhost:8081",
            "Header1": "override!",
            "Header2": "value2",
            "User-Agent": user_agent,
        }

    async def test_aiohttp_connection_error(self) -> None:
        conn = AIOHttpConnection("not.a.host.name")
        try:
            with pytest.raises(ConnectionError):
                await conn.perform_request("GET", "/")
        finally:
            # perform_request opened an aiohttp session; release it.
            await conn.close()
async def test_default_connection_is_returned_by_default() -> None:
    """``get_connection()`` with no alias returns the one added as "default"."""
    registry = async_connections.AsyncConnections()
    default_conn = object()
    other_conn = object()

    await registry.add_connection("default", default_conn)
    await registry.add_connection("not-default", other_conn)

    fetched = await registry.get_connection()
    assert fetched is default_conn
async def test_get_connection_created_connection_if_needed() -> None:
    """Configured aliases yield ``AsyncOpenSearch`` clients from ``get_connection``."""
    registry = async_connections.AsyncConnections()
    await registry.configure(
        default={"hosts": ["opensearch.com"]}, local={"hosts": ["localhost"]}
    )

    default_client = await registry.get_connection()
    local_client = await registry.get_connection("local")

    for client in (default_client, local_client):
        assert isinstance(client, AsyncOpenSearch)

    assert default_client.transport.hosts == [{"host": "opensearch.com"}]
    assert local_client.transport.hosts == [{"host": "localhost"}]
async def test_configure_preserves_unchanged_connections() -> None:
    """Re-running ``configure`` keeps the client whose settings did not change."""
    registry = async_connections.AsyncConnections()
    await registry.configure(
        default={"hosts": ["opensearch.com"]}, local={"hosts": ["localhost"]}
    )
    old_default = await registry.get_connection()
    old_local = await registry.get_connection("local")

    # Only the "default" alias gets different settings the second time round.
    await registry.configure(
        default={"hosts": ["not-opensearch.com"]}, local={"hosts": ["localhost"]}
    )
    current_default = await registry.get_connection()
    current_local = await registry.get_connection("local")

    assert current_local is old_local
    assert current_default is not old_default
async def test_remove_connection_removes_both_conn_and_conf() -> None:
    """
    ``remove_connection`` works for a configured-only alias ("default") and
    for an alias added as a connection object ("local2"); afterwards neither
    can be looked up.
    """
    c = async_connections.AsyncConnections()
    await c.configure(
        default={"hosts": ["opensearch.com"]}, local={"hosts": ["localhost"]}
    )
    await c.add_connection("local2", object())

    await c.remove_connection("default")
    await c.get_connection("local2")  # still present before its own removal
    await c.remove_connection("local2")

    # Check each alias in its own ``raises`` block: inside a shared block the
    # second lookup would never run once the first one raised.
    with raises(Exception):
        await c.get_connection("local2")
    with raises(Exception):
        await c.get_connection("default")
async def test_create_connection_constructs_client() -> None:
    """``create_connection`` registers a client built from the given hosts."""
    registry = async_connections.AsyncConnections()
    await registry.create_connection("testing", hosts=["opensearch.com"])

    client = await registry.get_connection("testing")
    expected_hosts = [{"host": "opensearch.com"}]
    assert client.transport.hosts == expected_hosts
async def test_create_connection_adds_our_serializer() -> None:
    """A client from ``create_connection`` uses opensearchpy's ``serializer.serializer``."""
    registry = async_connections.AsyncConnections()
    await registry.create_connection("testing", hosts=["opensearch.com"])

    client = await registry.get_connection("testing")
    assert client.transport.serializer is serializer.serializer