adds IAM auth support

Signed-off-by: Shivam Dhar <dhshivam@amazon.com>
This commit is contained in:
Shivam Dhar
2022-02-24 12:47:57 -08:00
committed by Shivam Dhar
parent 65c12d7224
commit 6ab90be906
9 changed files with 157 additions and 1 deletions
+2
View File
@@ -61,6 +61,7 @@ from .exceptions import (
SSLError,
TransportError,
)
from .helpers import AWSV4SignerAuth
from .serializer import JSONSerializer
from .transport import Transport
@@ -92,6 +93,7 @@ __all__ = [
"AuthorizationException",
"OpenSearchWarning",
"OpenSearchDeprecationWarning",
"AWSV4SignerAuth",
]
try:
+1
View File
@@ -57,6 +57,7 @@ try:
from ._async.client import AsyncOpenSearch as AsyncOpenSearch
from ._async.http_aiohttp import AIOHttpConnection as AIOHttpConnection
from ._async.transport import AsyncTransport as AsyncTransport
from .helpers import AWSV4SignerAuth as AWSV4SignerAuth
except (ImportError, SyntaxError):
pass
+1 -1
View File
@@ -30,7 +30,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from .connection import Connection
try:
from Queue import PriorityQueue # type: ignore
from Queue import PriorityQueue
except ImportError:
from queue import PriorityQueue
+2
View File
@@ -37,6 +37,7 @@ from .actions import (
streaming_bulk,
)
from .errors import BulkIndexError, ScanError
from .signer import AWSV4SignerAuth
__all__ = [
"BulkIndexError",
@@ -49,6 +50,7 @@ __all__ = [
"reindex",
"_chunk_actions",
"_process_bulk_chunk",
"AWSV4SignerAuth",
]
+1
View File
@@ -46,5 +46,6 @@ try:
from .._async.helpers import async_reindex as async_reindex
from .._async.helpers import async_scan as async_scan
from .._async.helpers import async_streaming_bulk as async_streaming_bulk
from .signer import AWSV4SignerAuth as AWSV4SignerAuth
except (ImportError, SyntaxError):
pass
+89
View File
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
import sys
import requests
OPENSEARCH_SERVICE = "es"
PY3 = sys.version_info[0] == 3
if PY3:
from urllib.parse import parse_qs, urlencode, urlparse
def fetch_url(prepared_request): # type: ignore
"""
This is a util method that helps in reconstructing the request url.
:param prepared_request: unsigned request
:return: reconstructed url
"""
url = urlparse(prepared_request.url)
path = url.path or "/"
# fetch the query string if present in the request
querystring = ""
if url.query:
querystring = "?" + urlencode(
parse_qs(url.query, keep_blank_values=True), doseq=True
)
# fetch the host information from headers
headers = dict(
(key.lower(), value) for key, value in prepared_request.headers.items()
)
location = headers.get("host") or url.netloc
# construct the url and return
return url.scheme + "://" + location + path + querystring
class AWSV4SignerAuth(requests.auth.AuthBase):
"""
AWS V4 Request Signer for Requests.
"""
def __init__(self, credentials, region): # type: ignore
if not credentials:
raise ValueError("Credentials cannot be empty")
self.credentials = credentials
if not region:
raise ValueError("Region cannot be empty")
self.region = region
def __call__(self, request): # type: ignore
return self._sign_request(request) # type: ignore
def _sign_request(self, prepared_request): # type: ignore
"""
This method helps in signing the request by injecting the required headers.
:param prepared_request: unsigned request
:return: signed request
"""
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
url = fetch_url(prepared_request) # type: ignore
# create an AWS request object and sign it using SigV4Auth
aws_request = AWSRequest(
method=prepared_request.method.upper(),
url=url,
data=prepared_request.body,
)
sig_v4_auth = SigV4Auth(self.credentials, OPENSEARCH_SERVICE, self.region)
sig_v4_auth.add_auth(aws_request)
# copy the headers from AWS request object into the prepared_request
prepared_request.headers.update(dict(aws_request.headers.items()))
return prepared_request
+3
View File
@@ -12,3 +12,6 @@ junit_family=legacy
[tool:isort]
profile=black
[mypy]
ignore_missing_imports=True
+1
View File
@@ -59,6 +59,7 @@ tests_require = [
"pyyaml",
"pytest",
"pytest-cov",
"botocore;python_version>='3.6'",
]
async_require = ["aiohttp>=3,<4"]
+57
View File
@@ -31,7 +31,9 @@ import json
import os
import re
import ssl
import sys
import unittest
import uuid
import warnings
from platform import python_version
@@ -284,6 +286,61 @@ class TestUrllib3Connection(TestCase):
con.headers,
)
@pytest.mark.skipif(
sys.version_info < (3, 6), reason="AWSV4SignerAuth requires python3.6+"
)
def test_aws_signer_as_http_auth(self):
region = "us-west-2"
import requests
from opensearchpy.helpers.signer import AWSV4SignerAuth
auth = AWSV4SignerAuth(self.mock_session(), region)
con = RequestsHttpConnection(http_auth=auth)
prepared_request = requests.Request("GET", "http://localhost").prepare()
auth(prepared_request)
self.assertEqual(auth, con.session.auth)
self.assertIn("Authorization", prepared_request.headers)
self.assertIn("X-Amz-Date", prepared_request.headers)
self.assertIn("X-Amz-Security-Token", prepared_request.headers)
def test_aws_signer_when_region_is_null(self):
session = self.mock_session()
from opensearchpy.helpers.signer import AWSV4SignerAuth
with pytest.raises(ValueError) as e:
AWSV4SignerAuth(session, None)
assert str(e.value) == "Region cannot be empty"
with pytest.raises(ValueError) as e:
AWSV4SignerAuth(session, "")
assert str(e.value) == "Region cannot be empty"
def test_aws_signer_when_credentials_is_null(self):
region = "us-west-1"
from opensearchpy.helpers.signer import AWSV4SignerAuth
with pytest.raises(ValueError) as e:
AWSV4SignerAuth(None, region)
assert str(e.value) == "Credentials cannot be empty"
with pytest.raises(ValueError) as e:
AWSV4SignerAuth("", region)
assert str(e.value) == "Credentials cannot be empty"
def mock_session(self):
access_key = uuid.uuid4().hex
secret_key = uuid.uuid4().hex
token = uuid.uuid4().hex
dummy_session = Mock()
dummy_session.access_key = access_key
dummy_session.secret_key = secret_key
dummy_session.token = token
return dummy_session
def test_uses_https_if_verify_certs_is_off(self):
with warnings.catch_warnings(record=True) as w:
con = Urllib3HttpConnection(use_ssl=True, verify_certs=False)