@@ -61,6 +61,7 @@ from .exceptions import (
|
||||
SSLError,
|
||||
TransportError,
|
||||
)
|
||||
from .helpers import AWSV4SignerAuth
|
||||
from .serializer import JSONSerializer
|
||||
from .transport import Transport
|
||||
|
||||
@@ -92,6 +93,7 @@ __all__ = [
|
||||
"AuthorizationException",
|
||||
"OpenSearchWarning",
|
||||
"OpenSearchDeprecationWarning",
|
||||
"AWSV4SignerAuth",
|
||||
]
|
||||
|
||||
try:
|
||||
|
||||
@@ -57,6 +57,7 @@ try:
|
||||
from ._async.client import AsyncOpenSearch as AsyncOpenSearch
|
||||
from ._async.http_aiohttp import AIOHttpConnection as AIOHttpConnection
|
||||
from ._async.transport import AsyncTransport as AsyncTransport
|
||||
from .helpers import AWSV4SignerAuth as AWSV4SignerAuth
|
||||
except (ImportError, SyntaxError):
|
||||
pass
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||
from .connection import Connection
|
||||
|
||||
try:
|
||||
from Queue import PriorityQueue # type: ignore
|
||||
from Queue import PriorityQueue
|
||||
except ImportError:
|
||||
from queue import PriorityQueue
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ from .actions import (
|
||||
streaming_bulk,
|
||||
)
|
||||
from .errors import BulkIndexError, ScanError
|
||||
from .signer import AWSV4SignerAuth
|
||||
|
||||
__all__ = [
|
||||
"BulkIndexError",
|
||||
@@ -49,6 +50,7 @@ __all__ = [
|
||||
"reindex",
|
||||
"_chunk_actions",
|
||||
"_process_bulk_chunk",
|
||||
"AWSV4SignerAuth",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -46,5 +46,6 @@ try:
|
||||
from .._async.helpers import async_reindex as async_reindex
|
||||
from .._async.helpers import async_scan as async_scan
|
||||
from .._async.helpers import async_streaming_bulk as async_streaming_bulk
|
||||
from .signer import AWSV4SignerAuth as AWSV4SignerAuth
|
||||
except (ImportError, SyntaxError):
|
||||
pass
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# The OpenSearch Contributors require contributions made to
|
||||
# this file be licensed under the Apache-2.0 license or a
|
||||
# compatible open source license.
|
||||
#
|
||||
# Modifications Copyright OpenSearch Contributors. See
|
||||
# GitHub history for details.
|
||||
|
||||
import sys
|
||||
|
||||
import requests
|
||||
|
||||
OPENSEARCH_SERVICE = "es"
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
|
||||
if PY3:
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
|
||||
def fetch_url(prepared_request): # type: ignore
|
||||
"""
|
||||
This is a util method that helps in reconstructing the request url.
|
||||
:param prepared_request: unsigned request
|
||||
:return: reconstructed url
|
||||
"""
|
||||
url = urlparse(prepared_request.url)
|
||||
path = url.path or "/"
|
||||
|
||||
# fetch the query string if present in the request
|
||||
querystring = ""
|
||||
if url.query:
|
||||
querystring = "?" + urlencode(
|
||||
parse_qs(url.query, keep_blank_values=True), doseq=True
|
||||
)
|
||||
|
||||
# fetch the host information from headers
|
||||
headers = dict(
|
||||
(key.lower(), value) for key, value in prepared_request.headers.items()
|
||||
)
|
||||
location = headers.get("host") or url.netloc
|
||||
|
||||
# construct the url and return
|
||||
return url.scheme + "://" + location + path + querystring
|
||||
|
||||
|
||||
class AWSV4SignerAuth(requests.auth.AuthBase):
|
||||
"""
|
||||
AWS V4 Request Signer for Requests.
|
||||
"""
|
||||
|
||||
def __init__(self, credentials, region): # type: ignore
|
||||
if not credentials:
|
||||
raise ValueError("Credentials cannot be empty")
|
||||
self.credentials = credentials
|
||||
|
||||
if not region:
|
||||
raise ValueError("Region cannot be empty")
|
||||
self.region = region
|
||||
|
||||
def __call__(self, request): # type: ignore
|
||||
return self._sign_request(request) # type: ignore
|
||||
|
||||
def _sign_request(self, prepared_request): # type: ignore
|
||||
"""
|
||||
This method helps in signing the request by injecting the required headers.
|
||||
:param prepared_request: unsigned request
|
||||
:return: signed request
|
||||
"""
|
||||
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
|
||||
url = fetch_url(prepared_request) # type: ignore
|
||||
|
||||
# create an AWS request object and sign it using SigV4Auth
|
||||
aws_request = AWSRequest(
|
||||
method=prepared_request.method.upper(),
|
||||
url=url,
|
||||
data=prepared_request.body,
|
||||
)
|
||||
sig_v4_auth = SigV4Auth(self.credentials, OPENSEARCH_SERVICE, self.region)
|
||||
sig_v4_auth.add_auth(aws_request)
|
||||
|
||||
# copy the headers from AWS request object into the prepared_request
|
||||
prepared_request.headers.update(dict(aws_request.headers.items()))
|
||||
|
||||
return prepared_request
|
||||
@@ -12,3 +12,6 @@ junit_family=legacy
|
||||
|
||||
[tool:isort]
|
||||
profile=black
|
||||
|
||||
[mypy]
|
||||
ignore_missing_imports=True
|
||||
|
||||
@@ -59,6 +59,7 @@ tests_require = [
|
||||
"pyyaml",
|
||||
"pytest",
|
||||
"pytest-cov",
|
||||
"botocore;python_version>='3.6'",
|
||||
]
|
||||
async_require = ["aiohttp>=3,<4"]
|
||||
|
||||
|
||||
@@ -31,7 +31,9 @@ import json
|
||||
import os
|
||||
import re
|
||||
import ssl
|
||||
import sys
|
||||
import unittest
|
||||
import uuid
|
||||
import warnings
|
||||
from platform import python_version
|
||||
|
||||
@@ -284,6 +286,61 @@ class TestUrllib3Connection(TestCase):
|
||||
con.headers,
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 6), reason="AWSV4SignerAuth requires python3.6+"
|
||||
)
|
||||
def test_aws_signer_as_http_auth(self):
|
||||
region = "us-west-2"
|
||||
|
||||
import requests
|
||||
|
||||
from opensearchpy.helpers.signer import AWSV4SignerAuth
|
||||
|
||||
auth = AWSV4SignerAuth(self.mock_session(), region)
|
||||
con = RequestsHttpConnection(http_auth=auth)
|
||||
prepared_request = requests.Request("GET", "http://localhost").prepare()
|
||||
auth(prepared_request)
|
||||
self.assertEqual(auth, con.session.auth)
|
||||
self.assertIn("Authorization", prepared_request.headers)
|
||||
self.assertIn("X-Amz-Date", prepared_request.headers)
|
||||
self.assertIn("X-Amz-Security-Token", prepared_request.headers)
|
||||
|
||||
def test_aws_signer_when_region_is_null(self):
|
||||
session = self.mock_session()
|
||||
|
||||
from opensearchpy.helpers.signer import AWSV4SignerAuth
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
AWSV4SignerAuth(session, None)
|
||||
assert str(e.value) == "Region cannot be empty"
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
AWSV4SignerAuth(session, "")
|
||||
assert str(e.value) == "Region cannot be empty"
|
||||
|
||||
def test_aws_signer_when_credentials_is_null(self):
|
||||
region = "us-west-1"
|
||||
|
||||
from opensearchpy.helpers.signer import AWSV4SignerAuth
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
AWSV4SignerAuth(None, region)
|
||||
assert str(e.value) == "Credentials cannot be empty"
|
||||
|
||||
with pytest.raises(ValueError) as e:
|
||||
AWSV4SignerAuth("", region)
|
||||
assert str(e.value) == "Credentials cannot be empty"
|
||||
|
||||
def mock_session(self):
|
||||
access_key = uuid.uuid4().hex
|
||||
secret_key = uuid.uuid4().hex
|
||||
token = uuid.uuid4().hex
|
||||
dummy_session = Mock()
|
||||
dummy_session.access_key = access_key
|
||||
dummy_session.secret_key = secret_key
|
||||
dummy_session.token = token
|
||||
return dummy_session
|
||||
|
||||
def test_uses_https_if_verify_certs_is_off(self):
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
con = Urllib3HttpConnection(use_ssl=True, verify_certs=False)
|
||||
|
||||
Reference in New Issue
Block a user