Small refactor of AWS Signer classes for both sync and async clients (#866)
* made custom headers be available to async aws signer Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com> * updated changelog Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com> * added tests for using host header for AWS request signature on both sync and async clients Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com> * added documentation guide about aws auth when accessing via tunnel Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com> * small refactor of AWS Signer classes on sync and async clients; improved testing on them as well Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com> * changelog Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com> * fixed test Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com> * lint fix Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com> --------- Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>
This commit is contained in:
@@ -7,6 +7,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
|
||||
- Added sync and async sample that uses `search_after` parameter ([859](https://github.com/opensearch-project/opensearch-py/pull/859))
|
||||
### Updated APIs
|
||||
### Changed
|
||||
- Small refactor of AWS Signer classes for both sync and async clients ([866](https://github.com/opensearch-project/opensearch-py/pull/866))
|
||||
### Deprecated
|
||||
### Removed
|
||||
### Fixed
|
||||
|
||||
@@ -8,7 +8,8 @@
|
||||
# GitHub history for details.
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
from opensearchpy.helpers.signer import AWSV4Signer
|
||||
|
||||
|
||||
class AWSV4SignerAsyncAuth:
|
||||
@@ -17,33 +18,21 @@ class AWSV4SignerAsyncAuth:
|
||||
"""
|
||||
|
||||
def __init__(self, credentials: Any, region: str, service: str = "es") -> None:
|
||||
if not credentials:
|
||||
raise ValueError("Credentials cannot be empty")
|
||||
self.credentials = credentials
|
||||
|
||||
if not region:
|
||||
raise ValueError("Region cannot be empty")
|
||||
self.region = region
|
||||
|
||||
if not service:
|
||||
raise ValueError("Service name cannot be empty")
|
||||
self.service = service
|
||||
self.signer = AWSV4Signer(credentials, region, service)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
query_string: Optional[str] = None,
|
||||
body: Optional[Union[str, bytes]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, str]:
|
||||
return self._sign_request(method, url, query_string, body, headers)
|
||||
return self._sign_request(method=method, url=url, body=body, headers=headers)
|
||||
|
||||
def _sign_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
query_string: Optional[str],
|
||||
body: Optional[Union[str, bytes]],
|
||||
headers: Optional[Dict[str, str]],
|
||||
) -> Dict[str, str]:
|
||||
@@ -53,58 +42,10 @@ class AWSV4SignerAsyncAuth:
|
||||
:return: signed headers
|
||||
"""
|
||||
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
|
||||
signature_host = self._fetch_url(url, headers or dict())
|
||||
|
||||
# create an AWS request object and sign it using SigV4Auth
|
||||
aws_request = AWSRequest(
|
||||
updated_headers = self.signer.sign(
|
||||
method=method,
|
||||
url=signature_host,
|
||||
data=body,
|
||||
url=url,
|
||||
body=body,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# credentials objects expose access_key, secret_key and token attributes
|
||||
# via @property annotations that call _refresh() on every access,
|
||||
# creating a race condition if the credentials expire before secret_key
|
||||
# is called but after access_key- the end result is the access_key doesn't
|
||||
# correspond to the secret_key used to sign the request. To avoid this,
|
||||
# get_frozen_credentials() which returns non-refreshing credentials is
|
||||
# called if it exists.
|
||||
credentials = (
|
||||
self.credentials.get_frozen_credentials()
|
||||
if hasattr(self.credentials, "get_frozen_credentials")
|
||||
and callable(self.credentials.get_frozen_credentials)
|
||||
else self.credentials
|
||||
)
|
||||
|
||||
sig_v4_auth = SigV4Auth(credentials, self.service, self.region)
|
||||
sig_v4_auth.add_auth(aws_request)
|
||||
aws_request.headers["X-Amz-Content-SHA256"] = sig_v4_auth.payload(aws_request)
|
||||
|
||||
# copy the headers from AWS request object into the prepared_request
|
||||
return dict(aws_request.headers.items())
|
||||
|
||||
def _fetch_url(self, url: str, headers: Optional[Dict[str, str]]) -> str:
|
||||
"""
|
||||
This is a util method that helps in reconstructing the request url.
|
||||
:param prepared_request: unsigned request
|
||||
:return: reconstructed url
|
||||
"""
|
||||
parsed_url = urlparse(url)
|
||||
path = parsed_url.path or "/"
|
||||
|
||||
# fetch the query string if present in the request
|
||||
querystring = ""
|
||||
if parsed_url.query:
|
||||
querystring = "?" + urlencode(
|
||||
parse_qs(parsed_url.query, keep_blank_values=True), doseq=True
|
||||
)
|
||||
|
||||
# fetch the host information from headers
|
||||
headers = {key.lower(): value for key, value in (headers or dict()).items()}
|
||||
location = headers.get("host") or parsed_url.netloc
|
||||
|
||||
# construct the url and return
|
||||
return parsed_url.scheme + "://" + location + path + querystring
|
||||
return updated_headers
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
# Modifications Copyright OpenSearch Contributors. See
|
||||
# GitHub history for details.
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
import requests
|
||||
@@ -31,7 +31,9 @@ class AWSV4Signer:
|
||||
raise ValueError("Service name cannot be empty")
|
||||
self.service = service
|
||||
|
||||
def sign(self, method: str, url: str, body: Any) -> Dict[str, str]:
|
||||
def sign(
|
||||
self, method: str, url: str, body: Any, headers: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
This method signs the request and returns headers.
|
||||
:param method: HTTP method
|
||||
@@ -43,8 +45,10 @@ class AWSV4Signer:
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
|
||||
signature_host = self._fetch_url(url, headers or dict())
|
||||
|
||||
# create an AWS request object and sign it using SigV4Auth
|
||||
aws_request = AWSRequest(method=method.upper(), url=url, data=body)
|
||||
aws_request = AWSRequest(method=method.upper(), url=signature_host, data=body)
|
||||
|
||||
# credentials objects expose access_key, secret_key and token attributes
|
||||
# via @property annotations that call _refresh() on every access,
|
||||
@@ -69,6 +73,30 @@ class AWSV4Signer:
|
||||
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def _fetch_url(url: str, headers: Optional[Dict[str, str]]) -> str:
|
||||
"""
|
||||
This is a util method that helps in reconstructing the request url.
|
||||
:param prepared_request: unsigned request
|
||||
:return: reconstructed url
|
||||
"""
|
||||
parsed_url = urlparse(url)
|
||||
path = parsed_url.path or "/"
|
||||
|
||||
# fetch the query string if present in the request
|
||||
querystring = ""
|
||||
if parsed_url.query:
|
||||
querystring = "?" + urlencode(
|
||||
parse_qs(parsed_url.query, keep_blank_values=True), doseq=True
|
||||
)
|
||||
|
||||
# fetch the host information from headers
|
||||
headers = {key.lower(): value for key, value in (headers or dict()).items()}
|
||||
location = headers.get("host") or parsed_url.netloc
|
||||
|
||||
# construct the url and return
|
||||
return parsed_url.scheme + "://" + location + path + querystring
|
||||
|
||||
|
||||
class RequestsAWSV4SignerAuth(requests.auth.AuthBase):
|
||||
"""
|
||||
@@ -89,41 +117,17 @@ class RequestsAWSV4SignerAuth(requests.auth.AuthBase):
|
||||
:return: signed request
|
||||
"""
|
||||
|
||||
prepared_request.headers.update(
|
||||
self.signer.sign(
|
||||
prepared_request.method,
|
||||
self._fetch_url(prepared_request),
|
||||
prepared_request.body,
|
||||
)
|
||||
updated_headers = self.signer.sign(
|
||||
method=prepared_request.method,
|
||||
url=prepared_request.url,
|
||||
body=prepared_request.body,
|
||||
headers=prepared_request.headers,
|
||||
)
|
||||
|
||||
prepared_request.headers.update(updated_headers)
|
||||
|
||||
return prepared_request
|
||||
|
||||
def _fetch_url(self, prepared_request: requests.PreparedRequest) -> str:
|
||||
"""
|
||||
This is a util method that helps in reconstructing the request url.
|
||||
:param prepared_request: unsigned request
|
||||
:return: reconstructed url
|
||||
"""
|
||||
url = urlparse(prepared_request.url)
|
||||
path = url.path or "/"
|
||||
|
||||
# fetch the query string if present in the request
|
||||
querystring = ""
|
||||
if url.query:
|
||||
querystring = "?" + urlencode(
|
||||
parse_qs(url.query, keep_blank_values=True), doseq=True # type: ignore
|
||||
)
|
||||
|
||||
# fetch the host information from headers
|
||||
headers = {
|
||||
key.lower(): value for key, value in prepared_request.headers.items()
|
||||
}
|
||||
location = headers.get("host") or url.netloc
|
||||
|
||||
# construct the url and return
|
||||
return url.scheme + "://" + location + path + querystring # type: ignore
|
||||
|
||||
|
||||
# Deprecated: use RequestsAWSV4SignerAuth
|
||||
class AWSV4SignerAuth(RequestsAWSV4SignerAuth):
|
||||
@@ -135,5 +139,7 @@ class Urllib3AWSV4SignerAuth(Callable): # type: ignore
|
||||
self.signer = AWSV4Signer(credentials, region, service)
|
||||
self.service = service # tools like LangChain rely on this, see https://github.com/opensearch-project/opensearch-py/issues/600
|
||||
|
||||
def __call__(self, method: str, url: str, body: Any) -> Dict[str, str]:
|
||||
return self.signer.sign(method, url, body)
|
||||
def __call__(
|
||||
self, method: str, url: str, body: Any, headers: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, str]:
|
||||
return self.signer.sign(method, url, body, headers)
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
import uuid
|
||||
from typing import Any, Collection, Dict, Mapping, Optional, Tuple, Union
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from _pytest.mark.structures import MarkDecorator
|
||||
@@ -81,15 +81,18 @@ class TestAsyncSigner:
|
||||
region = "us-west-2"
|
||||
service = "aoss"
|
||||
|
||||
from botocore.awsrequest import AWSRequest
|
||||
|
||||
from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth
|
||||
|
||||
auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service)
|
||||
|
||||
signature_host = auth._fetch_url(
|
||||
"http://localhost/?foo=bar", headers={"host": "otherhost"}
|
||||
)
|
||||
|
||||
assert signature_host == "http://otherhost/?foo=bar"
|
||||
with patch(
|
||||
"botocore.awsrequest.AWSRequest", side_effect=AWSRequest
|
||||
) as mock_aws_request:
|
||||
auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service)
|
||||
auth("GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"})
|
||||
mock_aws_request.assert_called_with(
|
||||
method="GET", url="http://otherhost:443/?foo=bar", data=None
|
||||
)
|
||||
|
||||
|
||||
class TestAsyncSignerWithFrozenCredentials(TestAsyncSigner):
|
||||
@@ -155,7 +158,6 @@ class TestAsyncSignerWithSpecialCharacters:
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
query_string: Optional[str] = None,
|
||||
body: Optional[Union[str, bytes]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, str]:
|
||||
|
||||
@@ -457,22 +457,27 @@ class TestRequestsHttpConnection(TestCase):
|
||||
|
||||
return dummy_session
|
||||
|
||||
def test_aws_signer_fetch_url_with_querystring(self) -> None:
|
||||
def test_aws_signer_url_with_querystring_and_custom_header(self) -> None:
|
||||
region = "us-west-2"
|
||||
|
||||
import requests
|
||||
from botocore.awsrequest import AWSRequest
|
||||
|
||||
from opensearchpy.helpers.signer import RequestsAWSV4SignerAuth
|
||||
|
||||
auth = RequestsAWSV4SignerAuth(self.mock_session(), region)
|
||||
with patch(
|
||||
"botocore.awsrequest.AWSRequest", side_effect=AWSRequest
|
||||
) as mock_aws_request:
|
||||
|
||||
prepared_request = requests.Request(
|
||||
"GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"}
|
||||
).prepare()
|
||||
auth = RequestsAWSV4SignerAuth(self.mock_session(), region)
|
||||
prepared_request = requests.Request(
|
||||
"GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"}
|
||||
).prepare()
|
||||
auth(prepared_request)
|
||||
|
||||
signature_host = auth._fetch_url(prepared_request)
|
||||
|
||||
assert signature_host == "http://otherhost:443/?foo=bar"
|
||||
mock_aws_request.assert_called_with(
|
||||
method="GET", url="http://otherhost:443/?foo=bar", data=None
|
||||
)
|
||||
|
||||
def test_aws_signer_as_http_auth(self) -> None:
|
||||
region = "us-west-2"
|
||||
@@ -525,9 +530,11 @@ class TestRequestsHttpConnection(TestCase):
|
||||
).prepare()
|
||||
auth(prepared_request)
|
||||
self.assertEqual(mock_sign.call_count, 1)
|
||||
self.assertEqual(
|
||||
mock_sign.call_args[0],
|
||||
("GET", "http://localhost/?key1=value1&key2=value2", None),
|
||||
mock_sign.assert_called_with(
|
||||
method="GET",
|
||||
url="http://localhost/?key1=value1&key2=value2",
|
||||
body=None,
|
||||
headers={},
|
||||
)
|
||||
|
||||
def test_aws_signer_consitent_url(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user