Small refactor of AWS Signer classes for both sync and async clients (#866)

* made custom headers available to the async AWS signer

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* updated changelog

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* added tests for using host header for AWS request signature on both sync and async clients

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* added documentation guide about AWS auth when accessing via a tunnel

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* small refactor of AWS Signer classes on sync and async clients; improved testing on them as well

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* changelog

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* fixed test

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* lint fix

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

---------

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>
This commit is contained in:
Bruno Murino
2024-12-03 13:13:50 +00:00
committed by GitHub
parent 87aebcd653
commit 7815c6abe8
5 changed files with 81 additions and 124 deletions
+1
View File
@@ -7,6 +7,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Added sync and async sample that uses `search_after` parameter ([859](https://github.com/opensearch-project/opensearch-py/pull/859))
### Updated APIs
### Changed
- Small refactor of AWS Signer classes for both sync and async clients ([866](https://github.com/opensearch-project/opensearch-py/pull/866))
### Deprecated
### Removed
### Fixed
+9 -68
View File
@@ -8,7 +8,8 @@
# GitHub history for details.
from typing import Any, Dict, Optional, Union
from urllib.parse import parse_qs, urlencode, urlparse
from opensearchpy.helpers.signer import AWSV4Signer
class AWSV4SignerAsyncAuth:
@@ -17,33 +18,21 @@ class AWSV4SignerAsyncAuth:
"""
def __init__(self, credentials: Any, region: str, service: str = "es") -> None:
if not credentials:
raise ValueError("Credentials cannot be empty")
self.credentials = credentials
if not region:
raise ValueError("Region cannot be empty")
self.region = region
if not service:
raise ValueError("Service name cannot be empty")
self.service = service
self.signer = AWSV4Signer(credentials, region, service)
def __call__(
self,
method: str,
url: str,
query_string: Optional[str] = None,
body: Optional[Union[str, bytes]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
return self._sign_request(method, url, query_string, body, headers)
return self._sign_request(method=method, url=url, body=body, headers=headers)
def _sign_request(
self,
method: str,
url: str,
query_string: Optional[str],
body: Optional[Union[str, bytes]],
headers: Optional[Dict[str, str]],
) -> Dict[str, str]:
@@ -53,58 +42,10 @@ class AWSV4SignerAsyncAuth:
:return: signed headers
"""
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
signature_host = self._fetch_url(url, headers or dict())
# create an AWS request object and sign it using SigV4Auth
aws_request = AWSRequest(
updated_headers = self.signer.sign(
method=method,
url=signature_host,
data=body,
url=url,
body=body,
headers=headers,
)
# credentials objects expose access_key, secret_key and token attributes
# via @property annotations that call _refresh() on every access,
# creating a race condition if the credentials expire before secret_key
# is called but after access_key- the end result is the access_key doesn't
# correspond to the secret_key used to sign the request. To avoid this,
# get_frozen_credentials() which returns non-refreshing credentials is
# called if it exists.
credentials = (
self.credentials.get_frozen_credentials()
if hasattr(self.credentials, "get_frozen_credentials")
and callable(self.credentials.get_frozen_credentials)
else self.credentials
)
sig_v4_auth = SigV4Auth(credentials, self.service, self.region)
sig_v4_auth.add_auth(aws_request)
aws_request.headers["X-Amz-Content-SHA256"] = sig_v4_auth.payload(aws_request)
# copy the headers from AWS request object into the prepared_request
return dict(aws_request.headers.items())
def _fetch_url(self, url: str, headers: Optional[Dict[str, str]]) -> str:
"""
This is a util method that helps in reconstructing the request url
that is used for signing. The host part is taken from the ``host``
header when one is supplied, otherwise from the url itself.
Only scheme, host, path and query string are kept in the result.
:param url: url of the unsigned request
:param headers: headers of the unsigned request; may be ``None``
:return: reconstructed url
"""
parsed_url = urlparse(url)
# an empty path is normalised to the root path
path = parsed_url.path or "/"
# fetch the query string if present in the request
# (it is parsed and re-encoded; blank values are kept and repeated
# keys are expanded via doseq=True)
querystring = ""
if parsed_url.query:
querystring = "?" + urlencode(
parse_qs(parsed_url.query, keep_blank_values=True), doseq=True
)
# fetch the host information from headers
# (header names are lower-cased so the "host" lookup is case-insensitive;
# a missing or empty host header falls back to the url's netloc)
headers = {key.lower(): value for key, value in (headers or dict()).items()}
location = headers.get("host") or parsed_url.netloc
# construct the url and return
return parsed_url.scheme + "://" + location + path + querystring
return updated_headers
+42 -36
View File
@@ -7,7 +7,7 @@
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional
from urllib.parse import parse_qs, urlencode, urlparse
import requests
@@ -31,7 +31,9 @@ class AWSV4Signer:
raise ValueError("Service name cannot be empty")
self.service = service
def sign(self, method: str, url: str, body: Any) -> Dict[str, str]:
def sign(
self, method: str, url: str, body: Any, headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
"""
This method signs the request and returns headers.
:param method: HTTP method
@@ -43,8 +45,10 @@ class AWSV4Signer:
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
signature_host = self._fetch_url(url, headers or dict())
# create an AWS request object and sign it using SigV4Auth
aws_request = AWSRequest(method=method.upper(), url=url, data=body)
aws_request = AWSRequest(method=method.upper(), url=signature_host, data=body)
# credentials objects expose access_key, secret_key and token attributes
# via @property annotations that call _refresh() on every access,
@@ -69,6 +73,30 @@ class AWSV4Signer:
return headers
@staticmethod
def _fetch_url(url: str, headers: Optional[Dict[str, str]]) -> str:
"""
This is a util method that helps in reconstructing the request url
that is used for signing. The host part is taken from the ``host``
header when one is supplied, otherwise from the url itself.
Only scheme, host, path and query string are kept in the result.
:param url: url of the unsigned request
:param headers: headers of the unsigned request; may be ``None``
:return: reconstructed url
"""
parsed_url = urlparse(url)
# an empty path is normalised to the root path
path = parsed_url.path or "/"
# fetch the query string if present in the request
# (it is parsed and re-encoded; blank values are kept and repeated
# keys are expanded via doseq=True)
querystring = ""
if parsed_url.query:
querystring = "?" + urlencode(
parse_qs(parsed_url.query, keep_blank_values=True), doseq=True
)
# fetch the host information from headers
# (header names are lower-cased so the "host" lookup is case-insensitive;
# a missing or empty host header falls back to the url's netloc)
headers = {key.lower(): value for key, value in (headers or dict()).items()}
location = headers.get("host") or parsed_url.netloc
# construct the url and return
return parsed_url.scheme + "://" + location + path + querystring
class RequestsAWSV4SignerAuth(requests.auth.AuthBase):
"""
@@ -89,41 +117,17 @@ class RequestsAWSV4SignerAuth(requests.auth.AuthBase):
:return: signed request
"""
prepared_request.headers.update(
self.signer.sign(
prepared_request.method,
self._fetch_url(prepared_request),
prepared_request.body,
)
updated_headers = self.signer.sign(
method=prepared_request.method,
url=prepared_request.url,
body=prepared_request.body,
headers=prepared_request.headers,
)
prepared_request.headers.update(updated_headers)
return prepared_request
def _fetch_url(self, prepared_request: requests.PreparedRequest) -> str:
"""
This is a util method that helps in reconstructing the request url
that is used for signing. The host part is taken from the request's
``host`` header when one is set, otherwise from the request url.
Only scheme, host, path and query string are kept in the result.
:param prepared_request: unsigned request
:return: reconstructed url
"""
url = urlparse(prepared_request.url)
# an empty path is normalised to the root path
path = url.path or "/"
# fetch the query string if present in the request
# (it is parsed and re-encoded; blank values are kept and repeated
# keys are expanded via doseq=True)
querystring = ""
if url.query:
querystring = "?" + urlencode(
parse_qs(url.query, keep_blank_values=True), doseq=True # type: ignore
)
# fetch the host information from headers
# (header names are lower-cased so the "host" lookup is case-insensitive;
# a missing or empty host header falls back to the url's netloc)
headers = {
key.lower(): value for key, value in prepared_request.headers.items()
}
location = headers.get("host") or url.netloc
# construct the url and return
return url.scheme + "://" + location + path + querystring # type: ignore
# Deprecated: use RequestsAWSV4SignerAuth
class AWSV4SignerAuth(RequestsAWSV4SignerAuth):
@@ -135,5 +139,7 @@ class Urllib3AWSV4SignerAuth(Callable): # type: ignore
self.signer = AWSV4Signer(credentials, region, service)
self.service = service # tools like LangChain rely on this, see https://github.com/opensearch-project/opensearch-py/issues/600
def __call__(self, method: str, url: str, body: Any) -> Dict[str, str]:
return self.signer.sign(method, url, body)
def __call__(
self, method: str, url: str, body: Any, headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
return self.signer.sign(method, url, body, headers)
+11 -9
View File
@@ -9,7 +9,7 @@
import uuid
from typing import Any, Collection, Dict, Mapping, Optional, Tuple, Union
from unittest.mock import Mock
from unittest.mock import Mock, patch
import pytest
from _pytest.mark.structures import MarkDecorator
@@ -81,15 +81,18 @@ class TestAsyncSigner:
region = "us-west-2"
service = "aoss"
from botocore.awsrequest import AWSRequest
from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth
auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service)
signature_host = auth._fetch_url(
"http://localhost/?foo=bar", headers={"host": "otherhost"}
)
assert signature_host == "http://otherhost/?foo=bar"
with patch(
"botocore.awsrequest.AWSRequest", side_effect=AWSRequest
) as mock_aws_request:
auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service)
auth("GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"})
mock_aws_request.assert_called_with(
method="GET", url="http://otherhost:443/?foo=bar", data=None
)
class TestAsyncSignerWithFrozenCredentials(TestAsyncSigner):
@@ -155,7 +158,6 @@ class TestAsyncSignerWithSpecialCharacters:
self,
method: str,
url: str,
query_string: Optional[str] = None,
body: Optional[Union[str, bytes]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
@@ -457,22 +457,27 @@ class TestRequestsHttpConnection(TestCase):
return dummy_session
def test_aws_signer_fetch_url_with_querystring(self) -> None:
def test_aws_signer_url_with_querystring_and_custom_header(self) -> None:
region = "us-west-2"
import requests
from botocore.awsrequest import AWSRequest
from opensearchpy.helpers.signer import RequestsAWSV4SignerAuth
auth = RequestsAWSV4SignerAuth(self.mock_session(), region)
with patch(
"botocore.awsrequest.AWSRequest", side_effect=AWSRequest
) as mock_aws_request:
prepared_request = requests.Request(
"GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"}
).prepare()
auth = RequestsAWSV4SignerAuth(self.mock_session(), region)
prepared_request = requests.Request(
"GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"}
).prepare()
auth(prepared_request)
signature_host = auth._fetch_url(prepared_request)
assert signature_host == "http://otherhost:443/?foo=bar"
mock_aws_request.assert_called_with(
method="GET", url="http://otherhost:443/?foo=bar", data=None
)
def test_aws_signer_as_http_auth(self) -> None:
region = "us-west-2"
@@ -525,9 +530,11 @@ class TestRequestsHttpConnection(TestCase):
).prepare()
auth(prepared_request)
self.assertEqual(mock_sign.call_count, 1)
self.assertEqual(
mock_sign.call_args[0],
("GET", "http://localhost/?key1=value1&key2=value2", None),
mock_sign.assert_called_with(
method="GET",
url="http://localhost/?key1=value1&key2=value2",
body=None,
headers={},
)
def test_aws_signer_consitent_url(self) -> None: