Remove redundant mock backport dependency and upgrade syntax for Python 3.8+ (#785)

* Upgrade syntax with pyupgrade --py38-plus

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>

* Convert to f-strings with flynt

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>

* Format with Black

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>

* Remove redundant mock backport dependency

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>

* isort imports

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>

* Add changelog entry

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>

---------

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>
This commit is contained in:
Hugo van Kemenade
2024-07-20 23:19:20 +03:00
committed by GitHub
parent de96d28e45
commit 6e3f1a1194
95 changed files with 229 additions and 300 deletions
+1
View File
@@ -8,6 +8,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
### Deprecated ### Deprecated
### Removed ### Removed
- Removed redundant dependency on six ([#781](https://github.com/opensearch-project/opensearch-py/pull/781)) - Removed redundant dependency on six ([#781](https://github.com/opensearch-project/opensearch-py/pull/781))
- Removed redundant dependency on mock and upgrade Python syntax ([#785](https://github.com/opensearch-project/opensearch-py/pull/785))
### Fixed ### Fixed
- Fixed Search helper to ensure proper retention of the _collapse attribute in chained operations. ([#771](https://github.com/opensearch-project/opensearch-py/pull/771)) - Fixed Search helper to ensure proper retention of the _collapse attribute in chained operations. ([#771](https://github.com/opensearch-project/opensearch-py/pull/771))
### Updated APIs ### Updated APIs
-1
View File
@@ -2,7 +2,6 @@ requests>=2, <3
pytest pytest
pytest-cov pytest-cov
coverage coverage
mock
sphinx<7.4 sphinx<7.4
sphinx_rtd_theme sphinx_rtd_theme
jinja2 jinja2
-1
View File
@@ -93,7 +93,6 @@ def lint(session: Any) -> None:
"types-simplejson", "types-simplejson",
"types-python-dateutil", "types-python-dateutil",
"types-PyYAML", "types-PyYAML",
"types-mock",
"types-pytz", "types-pytz",
) )
+2 -3
View File
@@ -26,7 +26,6 @@
# flake8: noqa # flake8: noqa
from __future__ import absolute_import
import logging import logging
import re import re
@@ -34,9 +33,9 @@ import warnings
from ._version import __versionstr__ from ._version import __versionstr__
_major, _minor, _patch = [ _major, _minor, _patch = (
int(x) for x in re.search(r"^(\d+)\.(\d+)\.(\d+)", __versionstr__).groups() # type: ignore int(x) for x in re.search(r"^(\d+)\.(\d+)\.(\d+)", __versionstr__).groups() # type: ignore
] )
VERSION = __version__ = (_major, _minor, _patch) VERSION = __version__ = (_major, _minor, _patch)
+3 -5
View File
@@ -34,8 +34,6 @@
# -----------------------------------------------------------------------------------------+ # -----------------------------------------------------------------------------------------+
from __future__ import unicode_literals
import logging import logging
from typing import Any, Type from typing import Any, Type
@@ -197,7 +195,7 @@ class AsyncOpenSearch(Client):
self, self,
hosts: Any = None, hosts: Any = None,
transport_class: Type[AsyncTransport] = AsyncTransport, transport_class: Type[AsyncTransport] = AsyncTransport,
**kwargs: Any **kwargs: Any,
) -> None: ) -> None:
""" """
:arg hosts: list of nodes, or a single node, we should connect to. :arg hosts: list of nodes, or a single node, we should connect to.
@@ -240,10 +238,10 @@ class AsyncOpenSearch(Client):
# truncate to 5 if there are too many # truncate to 5 if there are too many
if len(cons) > 5: if len(cons) > 5:
cons = cons[:5] + ["..."] cons = cons[:5] + ["..."]
return "<{cls}({cons})>".format(cls=self.__class__.__name__, cons=cons) return f"<{self.__class__.__name__}({cons})>"
except Exception: except Exception:
# probably operating on custom transport and connection_pool, ignore # probably operating on custom transport and connection_pool, ignore
return super(AsyncOpenSearch, self).__repr__() return super().__repr__()
async def __aenter__(self) -> Any: async def __aenter__(self) -> Any:
if hasattr(self.transport, "_async_call"): if hasattr(self.transport, "_async_call"):
+1 -1
View File
@@ -13,7 +13,7 @@ from opensearchpy.client.utils import _normalize_hosts
from opensearchpy.transport import Transport from opensearchpy.transport import Transport
class Client(object): class Client:
""" """
A generic async OpenSearch client. A generic async OpenSearch client.
""" """
+1 -1
View File
@@ -15,7 +15,7 @@ from .utils import NamespacedClient
class HttpClient(NamespacedClient): class HttpClient(NamespacedClient):
def __init__(self, client: Client) -> None: def __init__(self, client: Client) -> None:
super(HttpClient, self).__init__(client) super().__init__(client)
async def get( async def get(
self, self,
+1 -1
View File
@@ -26,7 +26,7 @@ class PluginsClient(NamespacedClient):
index_management: Any index_management: Any
def __init__(self, client: Client) -> None: def __init__(self, client: Client) -> None:
super(PluginsClient, self).__init__(client) super().__init__(client)
self.ml = MlClient(client) self.ml = MlClient(client)
self.transforms = TransformsClient(client) self.transforms = TransformsClient(client)
self.rollups = RollupsClient(client) self.rollups = RollupsClient(client)
+2 -2
View File
@@ -117,7 +117,7 @@ class AsyncDocument(ObjectBase, metaclass=AsyncIndexMeta):
return "{}({})".format( return "{}({})".format(
self.__class__.__name__, self.__class__.__name__,
", ".join( ", ".join(
"{}={!r}".format(key, getattr(self.meta, key)) f"{key}={getattr(self.meta, key)!r}"
for key in ("index", "id") for key in ("index", "id")
if key in self.meta if key in self.meta
), ),
@@ -249,7 +249,7 @@ class AsyncDocument(ObjectBase, metaclass=AsyncIndexMeta):
raise RequestError(400, message, error_docs) raise RequestError(400, message, error_docs)
if missing_docs: if missing_docs:
missing_ids = [doc["_id"] for doc in missing_docs] missing_ids = [doc["_id"] for doc in missing_docs]
message = "Documents %s not found." % ", ".join(missing_ids) message = f"Documents {', '.join(missing_ids)} not found."
raise NotFoundError(404, message, {"docs": missing_docs}) raise NotFoundError(404, message, {"docs": missing_docs})
return objs return objs
+2 -2
View File
@@ -18,7 +18,7 @@ from opensearchpy.helpers import analysis
from opensearchpy.helpers.utils import merge from opensearchpy.helpers.utils import merge
class AsyncIndexTemplate(object): class AsyncIndexTemplate:
def __init__( def __init__(
self, self,
name: Any, name: Any,
@@ -57,7 +57,7 @@ class AsyncIndexTemplate(object):
) )
class AsyncIndex(object): class AsyncIndex:
def __init__(self, name: Any, using: Any = "default") -> None: def __init__(self, name: Any, using: Any = "default") -> None:
""" """
:arg name: name of the index :arg name: name of the index
@@ -31,7 +31,7 @@ class AsyncUpdateByQuery(Request):
overridden by methods (`using`, `index` and `doc_type` respectively). overridden by methods (`using`, `index` and `doc_type` respectively).
""" """
super(AsyncUpdateByQuery, self).__init__(**kwargs) super().__init__(**kwargs)
self._response_class = UpdateByQueryResponse self._response_class = UpdateByQueryResponse
self._script: Any = {} self._script: Any = {}
self._query_proxy = QueryProxy(self, "query") self._query_proxy = QueryProxy(self, "query")
@@ -70,7 +70,7 @@ class AsyncUpdateByQuery(Request):
of all the underlying objects. Used internally by most state modifying of all the underlying objects. Used internally by most state modifying
APIs. APIs.
""" """
ubq = super(AsyncUpdateByQuery, self)._clone() ubq = super()._clone()
ubq._response_class = self._response_class ubq._response_class = self._response_class
ubq._script = self._script.copy() ubq._script = self._script.copy()
+3 -3
View File
@@ -93,7 +93,7 @@ class AIOHttpConnection(AsyncConnection):
opaque_id: Optional[str] = None, opaque_id: Optional[str] = None,
loop: Any = None, loop: Any = None,
trust_env: Optional[bool] = False, trust_env: Optional[bool] = False,
**kwargs: Any **kwargs: Any,
) -> None: ) -> None:
""" """
Default connection class for ``AsyncOpenSearch`` using the `aiohttp` library and the http protocol. Default connection class for ``AsyncOpenSearch`` using the `aiohttp` library and the http protocol.
@@ -140,7 +140,7 @@ class AIOHttpConnection(AsyncConnection):
headers=headers, headers=headers,
http_compress=http_compress, http_compress=http_compress,
opaque_id=opaque_id, opaque_id=opaque_id,
**kwargs **kwargs,
) )
if http_auth is not None: if http_auth is not None:
@@ -276,7 +276,7 @@ class AIOHttpConnection(AsyncConnection):
else: else:
url = self.url_prefix + url url = self.url_prefix + url
if query_string: if query_string:
url = "%s?%s" % (url, query_string) url = f"{url}?{query_string}"
url = self.host + url url = self.host + url
timeout = aiohttp.ClientTimeout( timeout = aiohttp.ClientTimeout(
+1 -1
View File
@@ -121,7 +121,7 @@ class AsyncTransport(Transport):
self._async_init_called = False self._async_init_called = False
self._sniff_on_start_event: Optional[asyncio.Event] = None self._sniff_on_start_event: Optional[asyncio.Event] = None
super(AsyncTransport, self).__init__( super().__init__(
hosts=[], hosts=[],
connection_class=connection_class, connection_class=connection_class,
connection_pool_class=connection_pool_class, connection_pool_class=connection_pool_class,
+3 -5
View File
@@ -34,8 +34,6 @@
# -----------------------------------------------------------------------------------------+ # -----------------------------------------------------------------------------------------+
from __future__ import unicode_literals
import logging import logging
from typing import Any, Type from typing import Any, Type
@@ -197,7 +195,7 @@ class OpenSearch(Client):
self, self,
hosts: Any = None, hosts: Any = None,
transport_class: Type[Transport] = Transport, transport_class: Type[Transport] = Transport,
**kwargs: Any **kwargs: Any,
) -> None: ) -> None:
""" """
:arg hosts: list of nodes, or a single node, we should connect to. :arg hosts: list of nodes, or a single node, we should connect to.
@@ -240,10 +238,10 @@ class OpenSearch(Client):
# truncate to 5 if there are too many # truncate to 5 if there are too many
if len(cons) > 5: if len(cons) > 5:
cons = cons[:5] + ["..."] cons = cons[:5] + ["..."]
return "<{cls}({cons})>".format(cls=self.__class__.__name__, cons=cons) return f"<{self.__class__.__name__}({cons})>"
except Exception: except Exception:
# probably operating on custom transport and connection_pool, ignore # probably operating on custom transport and connection_pool, ignore
return super(OpenSearch, self).__repr__() return super().__repr__()
def __enter__(self) -> Any: def __enter__(self) -> Any:
if hasattr(self.transport, "_async_call"): if hasattr(self.transport, "_async_call"):
+1 -1
View File
@@ -13,7 +13,7 @@ from opensearchpy.client.utils import _normalize_hosts
from opensearchpy.transport import Transport from opensearchpy.transport import Transport
class Client(object): class Client:
""" """
A generic async OpenSearch client. A generic async OpenSearch client.
""" """
+1 -1
View File
@@ -15,7 +15,7 @@ from .utils import NamespacedClient
class HttpClient(NamespacedClient): class HttpClient(NamespacedClient):
def __init__(self, client: Client) -> None: def __init__(self, client: Client) -> None:
super(HttpClient, self).__init__(client) super().__init__(client)
def get( def get(
self, self,
+1 -1
View File
@@ -26,7 +26,7 @@ class PluginsClient(NamespacedClient):
index_management: Any index_management: Any
def __init__(self, client: Client) -> None: def __init__(self, client: Client) -> None:
super(PluginsClient, self).__init__(client) super().__init__(client)
self.ml = MlClient(client) self.ml = MlClient(client)
self.transforms = TransformsClient(client) self.transforms = TransformsClient(client)
self.rollups = RollupsClient(client) self.rollups = RollupsClient(client)
+5 -9
View File
@@ -25,8 +25,6 @@
# under the License. # under the License.
from __future__ import unicode_literals
import base64 import base64
import weakref import weakref
from datetime import date, datetime from datetime import date, datetime
@@ -59,7 +57,7 @@ def _normalize_hosts(hosts: Any) -> Any:
for host in hosts: for host in hosts:
if isinstance(host, string_types): if isinstance(host, string_types):
if "://" not in host: if "://" not in host:
host = "//%s" % host # type: ignore host = f"//{host}" # type: ignore
parsed_url = urlparse(host) parsed_url = urlparse(host)
h = {"host": parsed_url.hostname} h = {"host": parsed_url.hostname}
@@ -72,7 +70,7 @@ def _normalize_hosts(hosts: Any) -> Any:
h["use_ssl"] = True h["use_ssl"] = True
if parsed_url.username or parsed_url.password: if parsed_url.username or parsed_url.password:
h["http_auth"] = "%s:%s" % ( h["http_auth"] = "{}:{}".format(
unquote(parsed_url.username), unquote(parsed_url.username),
unquote(parsed_url.password), unquote(parsed_url.password),
) )
@@ -160,11 +158,9 @@ def query_params(*opensearch_query_params: Any) -> Callable: # type: ignore
"Only one of 'http_auth' and 'api_key' may be passed at a time" "Only one of 'http_auth' and 'api_key' may be passed at a time"
) )
elif http_auth is not None: elif http_auth is not None:
headers["authorization"] = "Basic %s" % ( headers["authorization"] = f"Basic {_base64_auth_header(http_auth)}"
_base64_auth_header(http_auth),
)
elif api_key is not None: elif api_key is not None:
headers["authorization"] = "ApiKey %s" % (_base64_auth_header(api_key),) headers["authorization"] = f"ApiKey {_base64_auth_header(api_key)}"
# don't escape ignore, request_timeout, or timeout # don't escape ignore, request_timeout, or timeout
for p in ("ignore", "request_timeout", "timeout"): for p in ("ignore", "request_timeout", "timeout"):
@@ -209,7 +205,7 @@ def _base64_auth_header(auth_value: Any) -> str:
return to_str(auth_value) return to_str(auth_value)
class NamespacedClient(object): class NamespacedClient:
def __init__(self, client: Any) -> None: def __init__(self, client: Any) -> None:
self.client = client self.client = client
+2 -2
View File
@@ -68,7 +68,7 @@ class AsyncConnections:
errors += 1 errors += 1
if errors == 2: if errors == 2:
raise KeyError("There is no connection with alias %r." % alias) raise KeyError(f"There is no connection with alias {alias!r}.")
async def create_connection(self, alias: str = "default", **kwargs: Any) -> Any: async def create_connection(self, alias: str = "default", **kwargs: Any) -> Any:
""" """
@@ -104,7 +104,7 @@ class AsyncConnections:
return await self.create_connection(alias, **self._kwargs[alias]) return await self.create_connection(alias, **self._kwargs[alias])
except KeyError: except KeyError:
# no connection and no kwargs to set one up # no connection and no kwargs to set one up
raise KeyError("There is no connection with alias %r." % alias) raise KeyError(f"There is no connection with alias {alias!r}.")
async_connections = AsyncConnections() async_connections = AsyncConnections()
+9 -9
View File
@@ -53,7 +53,7 @@ if not TRACER_ALREADY_CONFIGURED:
_WARNING_RE = re.compile(r"\"([^\"]*)\"") _WARNING_RE = re.compile(r"\"([^\"]*)\"")
class Connection(object): class Connection:
""" """
Class responsible for maintaining a connection to an OpenSearch node. It Class responsible for maintaining a connection to an OpenSearch node. It
holds persistent connection pool to it and its main interface holds persistent connection pool to it and its main interface
@@ -81,7 +81,7 @@ class Connection(object):
headers: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None,
http_compress: Optional[bool] = None, http_compress: Optional[bool] = None,
opaque_id: Optional[str] = None, opaque_id: Optional[str] = None,
**kwargs: Any **kwargs: Any,
) -> None: ) -> None:
if port is None: if port is None:
port = 9200 port = 9200
@@ -119,27 +119,27 @@ class Connection(object):
self.hostname = host self.hostname = host
self.port = port self.port = port
if ":" in host: # IPv6 if ":" in host: # IPv6
self.host = "%s://[%s]" % (scheme, host) self.host = f"{scheme}://[{host}]"
else: else:
self.host = "%s://%s" % (scheme, host) self.host = f"{scheme}://{host}"
if self.port is not None: if self.port is not None:
self.host += ":%s" % self.port self.host += f":{self.port}"
if url_prefix: if url_prefix:
url_prefix = "/" + url_prefix.strip("/") url_prefix = "/" + url_prefix.strip("/")
self.url_prefix = url_prefix self.url_prefix = url_prefix
self.timeout = timeout self.timeout = timeout
def __repr__(self) -> str: def __repr__(self) -> str:
return "<%s: %s>" % (self.__class__.__name__, self.host) return f"<{self.__class__.__name__}: {self.host}>"
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection): if not isinstance(other, Connection):
raise TypeError("Unsupported equality check for %s and %s" % (self, other)) raise TypeError(f"Unsupported equality check for {self} and {other}")
return self.__hash__() == other.__hash__() return self.__hash__() == other.__hash__()
def __lt__(self, other: object) -> bool: def __lt__(self, other: object) -> bool:
if not isinstance(other, Connection): if not isinstance(other, Connection):
raise TypeError("Unsupported lt check for %s and %s" % (self, other)) raise TypeError(f"Unsupported lt check for {self} and {other}")
return self.__hash__() < other.__hash__() return self.__hash__() < other.__hash__()
def __hash__(self) -> int: def __hash__(self) -> int:
@@ -317,7 +317,7 @@ class Connection(object):
) )
def _get_default_user_agent(self) -> str: def _get_default_user_agent(self) -> str:
return "opensearch-py/%s (Python %s)" % (__versionstr__, python_version()) return f"opensearch-py/{__versionstr__} (Python {python_version()})"
@staticmethod @staticmethod
def default_ca_certs() -> Union[str, None]: def default_ca_certs() -> Union[str, None]:
+2 -2
View File
@@ -82,7 +82,7 @@ class Connections:
errors += 1 errors += 1
if errors == 2: if errors == 2:
raise KeyError("There is no connection with alias %r." % alias) raise KeyError(f"There is no connection with alias {alias!r}.")
def create_connection(self, alias: str = "default", **kwargs: Any) -> Any: def create_connection(self, alias: str = "default", **kwargs: Any) -> Any:
""" """
@@ -118,7 +118,7 @@ class Connections:
return self.create_connection(alias, **self._kwargs[alias]) return self.create_connection(alias, **self._kwargs[alias])
except KeyError: except KeyError:
# no connection and no kwargs to set one up # no connection and no kwargs to set one up
raise KeyError("There is no connection with alias %r." % alias) raise KeyError(f"There is no connection with alias {alias!r}.")
connections = Connections() connections = Connections()
+3 -3
View File
@@ -52,7 +52,7 @@ class AsyncHttpConnection(AIOHttpConnection):
http_compress: Optional[bool] = None, http_compress: Optional[bool] = None,
opaque_id: Optional[str] = None, opaque_id: Optional[str] = None,
loop: Any = None, loop: Any = None,
**kwargs: Any **kwargs: Any,
) -> None: ) -> None:
self.headers = {} self.headers = {}
@@ -63,7 +63,7 @@ class AsyncHttpConnection(AIOHttpConnection):
headers=headers, headers=headers,
http_compress=http_compress, http_compress=http_compress,
opaque_id=opaque_id, opaque_id=opaque_id,
**kwargs **kwargs,
) )
if http_auth is not None: if http_auth is not None:
@@ -186,7 +186,7 @@ class AsyncHttpConnection(AIOHttpConnection):
# then we pass a string into ClientSession.request() instead. # then we pass a string into ClientSession.request() instead.
url = self.url_prefix + url url = self.url_prefix + url
if query_string: if query_string:
url = "%s?%s" % (url, query_string) url = f"{url}?{query_string}"
url = self.host + url url = self.host + url
timeout = aiohttp.ClientTimeout( timeout = aiohttp.ClientTimeout(
+5 -8
View File
@@ -92,7 +92,7 @@ class RequestsHttpConnection(Connection):
opaque_id: Any = None, opaque_id: Any = None,
pool_maxsize: Any = None, pool_maxsize: Any = None,
metrics: Metrics = MetricsNone(), metrics: Metrics = MetricsNone(),
**kwargs: Any **kwargs: Any,
) -> None: ) -> None:
self.metrics = metrics self.metrics = metrics
if not REQUESTS_AVAILABLE: if not REQUESTS_AVAILABLE:
@@ -111,14 +111,14 @@ class RequestsHttpConnection(Connection):
self.session.mount("http://", pool_adapter) self.session.mount("http://", pool_adapter)
self.session.mount("https://", pool_adapter) self.session.mount("https://", pool_adapter)
super(RequestsHttpConnection, self).__init__( super().__init__(
host=host, host=host,
port=port, port=port,
use_ssl=use_ssl, use_ssl=use_ssl,
headers=headers, headers=headers,
http_compress=http_compress, http_compress=http_compress,
opaque_id=opaque_id, opaque_id=opaque_id,
**kwargs **kwargs,
) )
if not self.http_compress: if not self.http_compress:
@@ -132,10 +132,7 @@ class RequestsHttpConnection(Connection):
http_auth = tuple(http_auth.split(":", 1)) # type: ignore http_auth = tuple(http_auth.split(":", 1)) # type: ignore
self.session.auth = http_auth self.session.auth = http_auth
self.base_url = "%s%s" % ( self.base_url = f"{self.host}{self.url_prefix}"
self.host,
self.url_prefix,
)
self.session.verify = verify_certs self.session.verify = verify_certs
if not client_key: if not client_key:
self.session.cert = client_cert self.session.cert = client_cert
@@ -176,7 +173,7 @@ class RequestsHttpConnection(Connection):
url = self.base_url + url url = self.base_url + url
headers = headers or {} headers = headers or {}
if params: if params:
url = "%s?%s" % (url, urlencode(params or {})) url = f"{url}?{urlencode(params or {})}"
orig_body = body orig_body = body
if self.http_compress and body: if self.http_compress and body:
+4 -4
View File
@@ -121,20 +121,20 @@ class Urllib3HttpConnection(Connection):
http_compress: Any = None, http_compress: Any = None,
opaque_id: Any = None, opaque_id: Any = None,
metrics: Metrics = MetricsNone(), metrics: Metrics = MetricsNone(),
**kwargs: Any **kwargs: Any,
) -> None: ) -> None:
self.metrics = metrics self.metrics = metrics
# Initialize headers before calling super().__init__(). # Initialize headers before calling super().__init__().
self.headers = urllib3.make_headers(keep_alive=True) self.headers = urllib3.make_headers(keep_alive=True)
super(Urllib3HttpConnection, self).__init__( super().__init__(
host=host, host=host,
port=port, port=port,
use_ssl=use_ssl, use_ssl=use_ssl,
headers=headers, headers=headers,
http_compress=http_compress, http_compress=http_compress,
opaque_id=opaque_id, opaque_id=opaque_id,
**kwargs **kwargs,
) )
self.http_auth = http_auth self.http_auth = http_auth
@@ -245,7 +245,7 @@ class Urllib3HttpConnection(Connection):
url = self.url_prefix + url url = self.url_prefix + url
if params: if params:
url = "%s?%s" % (url, urlencode(params)) url = f"{url}?{urlencode(params)}"
full_url = self.host + url full_url = self.host + url
+1 -1
View File
@@ -47,7 +47,7 @@ class PoolingConnection(Connection):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
self._free_connections = queue.Queue() self._free_connections = queue.Queue()
super(PoolingConnection, self).__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def _make_connection(self) -> Connection: def _make_connection(self) -> Connection:
raise NotImplementedError raise NotImplementedError
+5 -5
View File
@@ -38,7 +38,7 @@ from .exceptions import ImproperlyConfigured
logger: logging.Logger = logging.getLogger("opensearch") logger: logging.Logger = logging.getLogger("opensearch")
class ConnectionSelector(object): class ConnectionSelector:
""" """
Simple class used to select a connection from a list of currently live Simple class used to select a connection from a list of currently live
connection instances. In init time it is passed a dictionary containing all connection instances. In init time it is passed a dictionary containing all
@@ -87,7 +87,7 @@ class RoundRobinSelector(ConnectionSelector):
""" """
def __init__(self, opts: Sequence[Tuple[Connection, Any]]) -> None: def __init__(self, opts: Sequence[Tuple[Connection, Any]]) -> None:
super(RoundRobinSelector, self).__init__(opts) super().__init__(opts)
self.data = threading.local() self.data = threading.local()
def select(self, connections: Sequence[Connection]) -> Any: def select(self, connections: Sequence[Connection]) -> Any:
@@ -96,7 +96,7 @@ class RoundRobinSelector(ConnectionSelector):
return connections[self.data.rr] return connections[self.data.rr]
class ConnectionPool(object): class ConnectionPool:
""" """
Container holding the :class:`~opensearchpy.Connection` instances, Container holding the :class:`~opensearchpy.Connection` instances,
managing the selection process (via a managing the selection process (via a
@@ -135,7 +135,7 @@ class ConnectionPool(object):
timeout_cutoff: int = 5, timeout_cutoff: int = 5,
selector_class: Type[ConnectionSelector] = RoundRobinSelector, selector_class: Type[ConnectionSelector] = RoundRobinSelector,
randomize_hosts: bool = True, randomize_hosts: bool = True,
**kwargs: Any **kwargs: Any,
) -> None: ) -> None:
""" """
:arg connections: list of tuples containing the :arg connections: list of tuples containing the
@@ -290,7 +290,7 @@ class ConnectionPool(object):
conn.close() conn.close()
def __repr__(self) -> str: def __repr__(self) -> str:
return "<%s: %r>" % (type(self).__name__, self.connections) return f"<{type(self).__name__}: {self.connections!r}>"
class DummyConnectionPool(ConnectionPool): class DummyConnectionPool(ConnectionPool):
+3 -3
View File
@@ -120,7 +120,7 @@ class TransportError(OpenSearchException):
except LookupError: except LookupError:
pass pass
msg = ", ".join(filter(None, [str(self.status_code), repr(self.error), cause])) msg = ", ".join(filter(None, [str(self.status_code), repr(self.error), cause]))
return "%s(%s)" % (self.__class__.__name__, msg) return f"{self.__class__.__name__}({msg})"
class ConnectionError(TransportError): class ConnectionError(TransportError):
@@ -131,7 +131,7 @@ class ConnectionError(TransportError):
""" """
def __str__(self) -> str: def __str__(self) -> str:
return "ConnectionError(%s) caused by: %s(%s)" % ( return "ConnectionError({}) caused by: {}({})".format(
self.error, self.error,
self.info.__class__.__name__, self.info.__class__.__name__,
self.info, self.info,
@@ -146,7 +146,7 @@ class ConnectionTimeout(ConnectionError):
"""A network timeout. Doesn't cause a node retry by default.""" """A network timeout. Doesn't cause a node retry by default."""
def __str__(self) -> str: def __str__(self) -> str:
return "ConnectionTimeout caused by - %s(%s)" % ( return "ConnectionTimeout caused by - {}({})".format(
self.info.__class__.__name__, self.info.__class__.__name__,
self.info, self.info,
) )
+14 -17
View File
@@ -198,7 +198,7 @@ def _process_bulk_chunk_success(
yield ok, {op_type: item} yield ok, {op_type: item}
if errors: if errors:
raise BulkIndexError("%i document(s) failed to index." % len(errors), errors) raise BulkIndexError(f"{len(errors)} document(s) failed to index.", errors)
def _process_bulk_chunk_error( def _process_bulk_chunk_error(
@@ -228,7 +228,7 @@ def _process_bulk_chunk_error(
# emulate standard behavior for failed actions # emulate standard behavior for failed actions
if raise_on_error and error.status_code not in ignore_status: if raise_on_error and error.status_code not in ignore_status:
raise BulkIndexError( raise BulkIndexError(
"%i document(s) failed to index." % len(exc_errors), exc_errors f"{len(exc_errors)} document(s) failed to index.", exc_errors
) )
else: else:
for err in exc_errors: for err in exc_errors:
@@ -243,7 +243,7 @@ def _process_bulk_chunk(
raise_on_error: bool = True, raise_on_error: bool = True,
ignore_status: Any = (), ignore_status: Any = (),
*args: Any, *args: Any,
**kwargs: Any **kwargs: Any,
) -> Any: ) -> Any:
""" """
Send a bulk request to opensearch and process the output. Send a bulk request to opensearch and process the output.
@@ -269,8 +269,7 @@ def _process_bulk_chunk(
ignore_status=ignore_status, ignore_status=ignore_status,
raise_on_error=raise_on_error, raise_on_error=raise_on_error,
) )
for item in gen: yield from gen
yield item
def streaming_bulk( def streaming_bulk(
@@ -287,7 +286,7 @@ def streaming_bulk(
yield_ok: bool = True, yield_ok: bool = True,
ignore_status: Any = (), ignore_status: Any = (),
*args: Any, *args: Any,
**kwargs: Any **kwargs: Any,
) -> Any: ) -> Any:
""" """
Streaming bulk consumes actions from the iterable passed in and yields Streaming bulk consumes actions from the iterable passed in and yields
@@ -344,7 +343,7 @@ def streaming_bulk(
raise_on_error, raise_on_error,
ignore_status, ignore_status,
*args, *args,
**kwargs **kwargs,
), ),
): ):
if not ok: if not ok:
@@ -384,7 +383,7 @@ def bulk(
stats_only: bool = False, stats_only: bool = False,
ignore_status: Any = (), ignore_status: Any = (),
*args: Any, *args: Any,
**kwargs: Any **kwargs: Any,
) -> Any: ) -> Any:
""" """
Helper for the :meth:`~opensearchpy.OpenSearch.bulk` api that provides Helper for the :meth:`~opensearchpy.OpenSearch.bulk` api that provides
@@ -445,7 +444,7 @@ def parallel_bulk(
raise_on_error: bool = True, raise_on_error: bool = True,
ignore_status: Any = (), ignore_status: Any = (),
*args: Any, *args: Any,
**kwargs: Any **kwargs: Any,
) -> Any: ) -> Any:
""" """
Parallel version of the bulk helper run in multiple threads at once. Parallel version of the bulk helper run in multiple threads at once.
@@ -474,7 +473,7 @@ def parallel_bulk(
class BlockingPool(ThreadPool): class BlockingPool(ThreadPool):
def _setup_queues(self) -> None: def _setup_queues(self) -> None:
super(BlockingPool, self)._setup_queues() # type: ignore super()._setup_queues() # type: ignore
# The queue must be at least the size of the number of threads to # The queue must be at least the size of the number of threads to
# prevent hanging when inserting sentinel values during teardown. # prevent hanging when inserting sentinel values during teardown.
self._inqueue: Any = Queue(max(queue_size, thread_count)) self._inqueue: Any = Queue(max(queue_size, thread_count))
@@ -493,15 +492,14 @@ def parallel_bulk(
raise_on_error, raise_on_error,
ignore_status, ignore_status,
*args, *args,
**kwargs **kwargs,
) )
), ),
_chunk_actions( _chunk_actions(
actions, chunk_size, max_chunk_bytes, client.transport.serializer actions, chunk_size, max_chunk_bytes, client.transport.serializer
), ),
): ):
for item in result: yield from result
yield item
finally: finally:
pool.close() pool.close()
@@ -518,7 +516,7 @@ def scan(
request_timeout: Optional[float] = None, request_timeout: Optional[float] = None,
clear_scroll: Optional[bool] = True, clear_scroll: Optional[bool] = True,
scroll_kwargs: Any = None, scroll_kwargs: Any = None,
**kwargs: Any **kwargs: Any,
) -> Any: ) -> Any:
""" """
Simple abstraction on top of the Simple abstraction on top of the
@@ -587,8 +585,7 @@ def scan(
try: try:
while scroll_id and resp.get("hits", {}).get("hits"): while scroll_id and resp.get("hits", {}).get("hits"):
for hit in resp.get("hits", {}).get("hits", []): yield from resp.get("hits", {}).get("hits", [])
yield hit
_shards = resp.get("_shards") _shards = resp.get("_shards")
@@ -687,5 +684,5 @@ def reindex(
target_client, target_client,
_change_doc_index(docs, target_index), _change_doc_index(docs, target_index),
chunk_size=chunk_size, chunk_size=chunk_size,
**kwargs **kwargs,
) )
+5 -5
View File
@@ -89,7 +89,7 @@ class Agg(DslBase):
return False return False
def to_dict(self) -> Any: def to_dict(self) -> Any:
d = super(Agg, self).to_dict() d = super().to_dict()
if "meta" in d[self.name]: if "meta" in d[self.name]:
d["meta"] = d[self.name].pop("meta") d["meta"] = d[self.name].pop("meta")
return d return d
@@ -98,7 +98,7 @@ class Agg(DslBase):
return AggResponse(self, search, data) return AggResponse(self, search, data)
class AggBase(object): class AggBase:
_param_defs = { _param_defs = {
"aggs": {"type": "agg", "hash": True}, "aggs": {"type": "agg", "hash": True},
} }
@@ -151,7 +151,7 @@ class AggBase(object):
class Bucket(AggBase, Agg): class Bucket(AggBase, Agg):
def __init__(self, **params: Any) -> None: def __init__(self, **params: Any) -> None:
super(Bucket, self).__init__(**params) super().__init__(**params)
# remember self for chaining # remember self for chaining
self._base = self self._base = self
@@ -172,10 +172,10 @@ class Filter(Bucket):
def __init__(self, filter: Any = None, **params: Any) -> None: def __init__(self, filter: Any = None, **params: Any) -> None:
if filter is not None: if filter is not None:
params["filter"] = filter params["filter"] = filter
super(Filter, self).__init__(**params) super().__init__(**params)
def to_dict(self) -> Any: def to_dict(self) -> Any:
d = super(Filter, self).to_dict() d = super().to_dict()
d[self.name].update(d[self.name].pop("filter", {})) d[self.name].update(d[self.name].pop("filter", {}))
return d return d
+1 -1
View File
@@ -38,7 +38,7 @@ class AnalysisBase:
) -> Any: ) -> Any:
if isinstance(name_or_instance, cls): if isinstance(name_or_instance, cls):
if type or kwargs: if type or kwargs:
raise ValueError("%s() cannot accept parameters." % cls.__name__) raise ValueError(f"{cls.__name__}() cannot accept parameters.")
return name_or_instance return name_or_instance
if not (type or kwargs): if not (type or kwargs):
+2 -2
View File
@@ -188,7 +188,7 @@ class Document(ObjectBase, metaclass=IndexMeta):
return "{}({})".format( return "{}({})".format(
self.__class__.__name__, self.__class__.__name__,
", ".join( ", ".join(
"{}={!r}".format(key, getattr(self.meta, key)) f"{key}={getattr(self.meta, key)!r}"
for key in ("index", "id") for key in ("index", "id")
if key in self.meta if key in self.meta
), ),
@@ -310,7 +310,7 @@ class Document(ObjectBase, metaclass=IndexMeta):
raise RequestError(400, message, error_docs) raise RequestError(400, message, error_docs)
if missing_docs: if missing_docs:
missing_ids = [doc["_id"] for doc in missing_docs] missing_ids = [doc["_id"] for doc in missing_docs]
message = "Documents %s not found." % ", ".join(missing_ids) message = f"Documents {', '.join(missing_ids)} not found."
raise NotFoundError(404, message, {"docs": missing_docs}) raise NotFoundError(404, message, {"docs": missing_docs})
return objs return objs
+1 -1
View File
@@ -41,5 +41,5 @@ class ScanError(OpenSearchException):
scroll_id: str scroll_id: str
def __init__(self, scroll_id: str, *args: Any, **kwargs: Any) -> None: def __init__(self, scroll_id: str, *args: Any, **kwargs: Any) -> None:
super(ScanError, self).__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.scroll_id = scroll_id self.scroll_id = scroll_id
+3 -3
View File
@@ -157,7 +157,7 @@ class Object(Field):
doc_class: Any = None, doc_class: Any = None,
dynamic: Any = None, dynamic: Any = None,
properties: Any = None, properties: Any = None,
**kwargs: Any **kwargs: Any,
) -> None: ) -> None:
""" """
:arg document.InnerDoc doc_class: base doc class that handles mapping. :arg document.InnerDoc doc_class: base doc class that handles mapping.
@@ -281,7 +281,7 @@ class Date(Field):
data = parser.parse(data) data = parser.parse(data)
except Exception as e: except Exception as e:
raise ValidationException( raise ValidationException(
"Could not parse date from the value (%r)" % data, e f"Could not parse date from the value ({data!r})", e
) )
if isinstance(data, datetime): if isinstance(data, datetime):
@@ -294,7 +294,7 @@ class Date(Field):
# Divide by a float to preserve milliseconds on the datetime. # Divide by a float to preserve milliseconds on the datetime.
return datetime.utcfromtimestamp(data / 1000.0) return datetime.utcfromtimestamp(data / 1000.0)
raise ValidationException("Could not parse date from the value (%r)" % data) raise ValidationException(f"Could not parse date from the value ({data!r})")
class Text(Field): class Text(Field):
+3 -3
View File
@@ -48,7 +48,7 @@ def SF(name_or_sf: Any, **params: Any) -> Any: # pylint: disable=invalid-name
elif len(sf) == 1: elif len(sf) == 1:
name, params = sf.popitem() name, params = sf.popitem()
else: else:
raise ValueError("SF() got an unexpected fields in the dictionary: %r" % sf) raise ValueError(f"SF() got an unexpected fields in the dictionary: {sf!r}")
# boost factor special case, see https://github.com/elastic/elasticsearch/issues/6343 # boost factor special case, see https://github.com/elastic/elasticsearch/issues/6343
if not isinstance(params, collections_abc.Mapping): if not isinstance(params, collections_abc.Mapping):
@@ -81,7 +81,7 @@ class ScoreFunction(DslBase):
name: Optional[str] = None name: Optional[str] = None
def to_dict(self) -> Any: def to_dict(self) -> Any:
d = super(ScoreFunction, self).to_dict() d = super().to_dict()
# filter and query dicts should be at the same level as us # filter and query dicts should be at the same level as us
for k in self._param_defs: for k in self._param_defs:
if k in d[self.name]: if k in d[self.name]:
@@ -97,7 +97,7 @@ class BoostFactor(ScoreFunction):
name = "boost_factor" name = "boost_factor"
def to_dict(self) -> Any: def to_dict(self) -> Any:
d = super(BoostFactor, self).to_dict() d = super().to_dict()
if "value" in d[self.name]: if "value" in d[self.name]:
d[self.name] = d[self.name].pop("value") d[self.name] = d[self.name].pop("value")
else: else:
+2 -2
View File
@@ -37,7 +37,7 @@ from .update_by_query import UpdateByQuery
from .utils import merge from .utils import merge
class IndexTemplate(object): class IndexTemplate:
def __init__( def __init__(
self, self,
name: Any, name: Any,
@@ -76,7 +76,7 @@ class IndexTemplate(object):
) )
class Index(object): class Index:
def __init__(self, name: Any, using: Any = "default") -> None: def __init__(self, name: Any, using: Any = "default") -> None:
""" """
:arg name: name of the index :arg name: name of the index
+1 -1
View File
@@ -261,7 +261,7 @@ class FunctionScore(Query):
for name in ScoreFunction._classes: for name in ScoreFunction._classes:
if name in kwargs: if name in kwargs:
fns.append({name: kwargs.pop(name)}) fns.append({name: kwargs.pop(name)})
super(FunctionScore, self).__init__(**kwargs) super().__init__(**kwargs)
# compound queries # compound queries
+5 -5
View File
@@ -34,7 +34,7 @@ class Response(AttrDict):
def __init__(self, search: Any, response: Any, doc_class: Any = None) -> None: def __init__(self, search: Any, response: Any, doc_class: Any = None) -> None:
super(AttrDict, self).__setattr__("_search", search) super(AttrDict, self).__setattr__("_search", search)
super(AttrDict, self).__setattr__("_doc_class", doc_class) super(AttrDict, self).__setattr__("_doc_class", doc_class)
super(Response, self).__init__(response) super().__init__(response)
def __iter__(self) -> Any: def __iter__(self) -> Any:
return iter(self.hits) return iter(self.hits)
@@ -43,7 +43,7 @@ class Response(AttrDict):
if isinstance(key, (slice, int)): if isinstance(key, (slice, int)):
# for slicing etc # for slicing etc
return self.hits[key] return self.hits[key]
return super(Response, self).__getitem__(key) return super().__getitem__(key)
def __nonzero__(self) -> Any: def __nonzero__(self) -> Any:
return bool(self.hits) return bool(self.hits)
@@ -103,14 +103,14 @@ class Response(AttrDict):
class AggResponse(AttrDict): class AggResponse(AttrDict):
def __init__(self, aggs: Any, search: Any, data: Any) -> None: def __init__(self, aggs: Any, search: Any, data: Any) -> None:
super(AttrDict, self).__setattr__("_meta", {"search": search, "aggs": aggs}) super(AttrDict, self).__setattr__("_meta", {"search": search, "aggs": aggs})
super(AggResponse, self).__init__(data) super().__init__(data)
def __getitem__(self, attr_name: Any) -> Any: def __getitem__(self, attr_name: Any) -> Any:
if attr_name in self._meta["aggs"]: if attr_name in self._meta["aggs"]:
# don't do self._meta['aggs'][attr_name] to avoid copying # don't do self._meta['aggs'][attr_name] to avoid copying
agg = self._meta["aggs"].aggs[attr_name] agg = self._meta["aggs"].aggs[attr_name]
return agg.result(self._meta["search"], self._d_[attr_name]) return agg.result(self._meta["search"], self._d_[attr_name])
return super(AggResponse, self).__getitem__(attr_name) return super().__getitem__(attr_name)
def __iter__(self) -> Any: def __iter__(self) -> Any:
for name in self._meta["aggs"]: for name in self._meta["aggs"]:
@@ -121,7 +121,7 @@ class UpdateByQueryResponse(AttrDict):
def __init__(self, search: Any, response: Any, doc_class: Any = None) -> None: def __init__(self, search: Any, response: Any, doc_class: Any = None) -> None:
super(AttrDict, self).__setattr__("_search", search) super(AttrDict, self).__setattr__("_search", search)
super(AttrDict, self).__setattr__("_doc_class", doc_class) super(AttrDict, self).__setattr__("_doc_class", doc_class)
super(UpdateByQueryResponse, self).__init__(response) super().__init__(response)
def success(self) -> bool: def success(self) -> bool:
return not self.timed_out and not self.failures return not self.timed_out and not self.failures
+4 -4
View File
@@ -32,14 +32,14 @@ from . import AggResponse, Response
class Bucket(AggResponse): class Bucket(AggResponse):
def __init__(self, aggs: Any, search: Any, data: Any, field: Any = None) -> None: def __init__(self, aggs: Any, search: Any, data: Any, field: Any = None) -> None:
super(Bucket, self).__init__(aggs, search, data) super().__init__(aggs, search, data)
class FieldBucket(Bucket): class FieldBucket(Bucket):
def __init__(self, aggs: Any, search: Any, data: Any, field: Any = None) -> None: def __init__(self, aggs: Any, search: Any, data: Any, field: Any = None) -> None:
if field: if field:
data["key"] = field.deserialize(data["key"]) data["key"] = field.deserialize(data["key"])
super(FieldBucket, self).__init__(aggs, search, data, field) super().__init__(aggs, search, data, field)
class BucketData(AggResponse): class BucketData(AggResponse):
@@ -62,7 +62,7 @@ class BucketData(AggResponse):
def __getitem__(self, key: Any) -> Any: def __getitem__(self, key: Any) -> Any:
if isinstance(key, (int, slice)): if isinstance(key, (int, slice)):
return self.buckets[key] return self.buckets[key]
return super(BucketData, self).__getitem__(key) return super().__getitem__(key)
@property @property
def buckets(self) -> Any: def buckets(self) -> Any:
@@ -88,7 +88,7 @@ class TopHitsData(Response):
super(AttrDict, self).__setattr__( super(AttrDict, self).__setattr__(
"meta", AttrDict({"agg": agg, "search": search}) "meta", AttrDict({"agg": agg, "search": search})
) )
super(TopHitsData, self).__init__(search, data) super().__init__(search, data)
__all__ = ["AggResponse"] __all__ = ["AggResponse"]
+5 -5
View File
@@ -37,28 +37,28 @@ class Hit(AttrDict):
if "fields" in document: if "fields" in document:
data.update(document["fields"]) data.update(document["fields"])
super(Hit, self).__init__(data) super().__init__(data)
# assign meta as attribute and not as key in self._d_ # assign meta as attribute and not as key in self._d_
super(AttrDict, self).__setattr__("meta", HitMeta(document)) super(AttrDict, self).__setattr__("meta", HitMeta(document))
def __getstate__(self) -> Any: def __getstate__(self) -> Any:
# add self.meta since it is not in self.__dict__ # add self.meta since it is not in self.__dict__
return super(Hit, self).__getstate__() + (self.meta,) return super().__getstate__() + (self.meta,)
def __setstate__(self, state: Any) -> None: def __setstate__(self, state: Any) -> None:
super(AttrDict, self).__setattr__("meta", state[-1]) super(AttrDict, self).__setattr__("meta", state[-1])
super(Hit, self).__setstate__(state[:-1]) super().__setstate__(state[:-1])
def __dir__(self) -> Any: def __dir__(self) -> Any:
# be sure to expose meta in dir(self) # be sure to expose meta in dir(self)
return super(Hit, self).__dir__() + ["meta"] return super().__dir__() + ["meta"]
def __repr__(self) -> str: def __repr__(self) -> str:
return "<Hit({}): {}>".format( return "<Hit({}): {}>".format(
"/".join( "/".join(
getattr(self.meta, key) for key in ("index", "id") if key in self.meta getattr(self.meta, key) for key in ("index", "id") if key in self.meta
), ),
super(Hit, self).__repr__(), super().__repr__(),
) )
+1 -1
View File
@@ -96,7 +96,7 @@ class ProxyDescriptor:
""" """
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
self._attr_name = "_%s_proxy" % name self._attr_name = f"_{name}_proxy"
def __get__(self, instance: Any, owner: Any) -> Any: def __get__(self, instance: Any, owner: Any) -> Any:
return getattr(instance, self._attr_name) return getattr(instance, self._attr_name)
+3 -3
View File
@@ -116,9 +116,9 @@ class RequestsAWSV4SignerAuth(requests.auth.AuthBase):
) )
# fetch the host information from headers # fetch the host information from headers
headers = dict( headers = {
(key.lower(), value) for key, value in prepared_request.headers.items() key.lower(): value for key, value in prepared_request.headers.items()
) }
location = headers.get("host") or url.netloc location = headers.get("host") or url.netloc
# construct the url and return # construct the url and return
+2 -2
View File
@@ -49,7 +49,7 @@ class UpdateByQuery(Request):
overridden by methods (`using`, `index` and `doc_type` respectively). overridden by methods (`using`, `index` and `doc_type` respectively).
""" """
super(UpdateByQuery, self).__init__(**kwargs) super().__init__(**kwargs)
self._response_class = UpdateByQueryResponse self._response_class = UpdateByQueryResponse
self._script: Any = {} self._script: Any = {}
self._query_proxy = QueryProxy(self, "query") self._query_proxy = QueryProxy(self, "query")
@@ -88,7 +88,7 @@ class UpdateByQuery(Request):
of all the underlying objects. Used internally by most state modifying of all the underlying objects. Used internally by most state modifying
APIs. APIs.
""" """
ubq = super(UpdateByQuery, self)._clone() ubq = super()._clone()
ubq._response_class = self._response_class ubq._response_class = self._response_class
ubq._script = self._script.copy() ubq._script = self._script.copy()
+11 -23
View File
@@ -163,9 +163,7 @@ class AttrDict:
return self.__getitem__(attr_name) return self.__getitem__(attr_name)
except KeyError: except KeyError:
raise AttributeError( raise AttributeError(
"{!r} object has no attribute {!r}".format( f"{self.__class__.__name__!r} object has no attribute {attr_name!r}"
self.__class__.__name__, attr_name
)
) )
def get(self, key: Any, default: Any = None) -> Any: def get(self, key: Any, default: Any = None) -> Any:
@@ -181,9 +179,7 @@ class AttrDict:
del self._d_[attr_name] del self._d_[attr_name]
except KeyError: except KeyError:
raise AttributeError( raise AttributeError(
"{!r} object has no attribute {!r}".format( f"{self.__class__.__name__!r} object has no attribute {attr_name!r}"
self.__class__.__name__, attr_name
)
) )
def __getitem__(self, key: Any) -> Any: def __getitem__(self, key: Any) -> Any:
@@ -245,7 +241,7 @@ class DslMeta(type):
try: try:
return cls._types[name] return cls._types[name]
except KeyError: except KeyError:
raise UnknownDslObject("DSL type %s does not exist." % name) raise UnknownDslObject(f"DSL type {name} does not exist.")
class DslBase(metaclass=DslMeta): class DslBase(metaclass=DslMeta):
@@ -275,7 +271,7 @@ class DslBase(metaclass=DslMeta):
if default is not None: if default is not None:
return cls._classes[default] return cls._classes[default]
raise UnknownDslObject( raise UnknownDslObject(
"DSL class `{}` does not exist in {}.".format(name, cls._type_name) f"DSL class `{name}` does not exist in {cls._type_name}."
) )
def __init__(self, _expand__to_dot: Any = EXPAND__TO_DOT, **params: Any) -> None: def __init__(self, _expand__to_dot: Any = EXPAND__TO_DOT, **params: Any) -> None:
@@ -288,14 +284,14 @@ class DslBase(metaclass=DslMeta):
def _repr_params(self) -> str: def _repr_params(self) -> str:
"""Produce a repr of all our parameters to be used in __repr__.""" """Produce a repr of all our parameters to be used in __repr__."""
return ", ".join( return ", ".join(
"{}={!r}".format(n.replace(".", "__"), v) f"{n.replace('.', '__')}={v!r}"
for (n, v) in sorted(self._params.items()) for (n, v) in sorted(self._params.items())
# make sure we don't include empty typed params # make sure we don't include empty typed params
if "type" not in self._param_defs.get(n, {}) or v if "type" not in self._param_defs.get(n, {}) or v
) )
def __repr__(self) -> str: def __repr__(self) -> str:
return "{}({})".format(self.__class__.__name__, self._repr_params()) return f"{self.__class__.__name__}({self._repr_params()})"
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and other.to_dict() == self.to_dict() return isinstance(other, self.__class__) and other.to_dict() == self.to_dict()
@@ -341,9 +337,7 @@ class DslBase(metaclass=DslMeta):
def __getattr__(self, name: str) -> Any: def __getattr__(self, name: str) -> Any:
if name.startswith("_"): if name.startswith("_"):
raise AttributeError( raise AttributeError(
"{!r} object has no attribute {!r}".format( f"{self.__class__.__name__!r} object has no attribute {name!r}"
self.__class__.__name__, name
)
) )
value = None value = None
@@ -360,9 +354,7 @@ class DslBase(metaclass=DslMeta):
value = self._params.setdefault(name, {}) value = self._params.setdefault(name, {})
if value is None: if value is None:
raise AttributeError( raise AttributeError(
"{!r} object has no attribute {!r}".format( f"{self.__class__.__name__!r} object has no attribute {name!r}"
self.__class__.__name__, name
)
) )
# wrap nested dicts in AttrDict for convenient access # wrap nested dicts in AttrDict for convenient access
@@ -567,16 +559,12 @@ def merge(data: Any, new_data: Any, raise_on_conflict: bool = False) -> None:
and isinstance(new_data, (AttrDict, collections_abc.Mapping)) and isinstance(new_data, (AttrDict, collections_abc.Mapping))
): ):
raise ValueError( raise ValueError(
"You can only merge two dicts! Got {!r} and {!r} instead.".format( f"You can only merge two dicts! Got {data!r} and {new_data!r} instead."
data, new_data
)
) )
if not isinstance(new_data, Dict): if not isinstance(new_data, Dict):
raise ValueError( raise ValueError(
"You can only merge two dicts! Got {!r} and {!r} instead.".format( f"You can only merge two dicts! Got {data!r} and {new_data!r} instead."
data, new_data
)
) )
for key, value in new_data.items(): for key, value in new_data.items():
@@ -587,7 +575,7 @@ def merge(data: Any, new_data: Any, raise_on_conflict: bool = False) -> None:
): ):
merge(data[key], value, raise_on_conflict) merge(data[key], value, raise_on_conflict)
elif key in data and data[key] != value and raise_on_conflict: elif key in data and data[key] != value and raise_on_conflict:
raise ValueError("Incompatible data for key %r, cannot be merged." % key) raise ValueError(f"Incompatible data for key {key!r}, cannot be merged.")
else: else:
data[key] = value # type: ignore data[key] = value # type: ignore
+1 -1
View File
@@ -47,7 +47,7 @@ class Range(AttrDict):
for k in data: for k in data:
if k not in self.OPS: if k not in self.OPS:
raise ValueError("Range received an unknown operator %r" % k) raise ValueError(f"Range received an unknown operator {k!r}")
if "gt" in data and "gte" in data: if "gt" in data and "gte" in data:
raise ValueError("You cannot specify both gt and gte for Range.") raise ValueError("You cannot specify both gt and gte for Range.")
+7 -7
View File
@@ -45,7 +45,7 @@ FLOAT_TYPES = (Decimal,)
TIME_TYPES = (date, datetime) TIME_TYPES = (date, datetime)
class Serializer(object): class Serializer:
mimetype: str = "" mimetype: str = ""
def loads(self, s: str) -> Any: def loads(self, s: str) -> Any:
@@ -65,7 +65,7 @@ class TextSerializer(Serializer):
if isinstance(data, string_types): if isinstance(data, string_types):
return data return data
raise SerializationError("Cannot serialize %r into text." % data) raise SerializationError(f"Cannot serialize {data!r} into text.")
class JSONSerializer(Serializer): class JSONSerializer(Serializer):
@@ -140,7 +140,7 @@ class JSONSerializer(Serializer):
except ImportError: except ImportError:
pass pass
raise TypeError("Unable to serialize %r (type: %s)" % (data, type(data))) raise TypeError(f"Unable to serialize {data!r} (type: {type(data)})")
def loads(self, s: str) -> Any: def loads(self, s: str) -> Any:
try: try:
@@ -167,7 +167,7 @@ DEFAULT_SERIALIZERS: Dict[str, Serializer] = {
} }
class Deserializer(object): class Deserializer:
def __init__( def __init__(
self, self,
serializers: Dict[str, Serializer], serializers: Dict[str, Serializer],
@@ -177,7 +177,7 @@ class Deserializer(object):
self.default = serializers[default_mimetype] self.default = serializers[default_mimetype]
except KeyError: except KeyError:
raise ImproperlyConfigured( raise ImproperlyConfigured(
"Cannot find default serializer (%s)" % default_mimetype f"Cannot find default serializer ({default_mimetype})"
) )
self.serializers = serializers self.serializers = serializers
@@ -196,7 +196,7 @@ class Deserializer(object):
deserializer = self.serializers[mimetype] deserializer = self.serializers[mimetype]
except KeyError: except KeyError:
raise SerializationError( raise SerializationError(
"Unknown mimetype, unable to deserialize: %s" % mimetype f"Unknown mimetype, unable to deserialize: {mimetype}"
) )
return deserializer.loads(s) return deserializer.loads(s)
@@ -208,7 +208,7 @@ class AttrJSONSerializer(JSONSerializer):
return data._l_ return data._l_
if hasattr(data, "to_dict"): if hasattr(data, "to_dict"):
return data.to_dict() return data.to_dict()
return super(AttrJSONSerializer, self).default(data) return super().default(data)
serializer = AttrJSONSerializer() serializer = AttrJSONSerializer()
+1 -1
View File
@@ -64,7 +64,7 @@ def get_host_info(
return host return host
class Transport(object): class Transport:
""" """
Encapsulation of transport-related to logic. Handles instantiation of the Encapsulation of transport-related to logic. Handles instantiation of the
individual connections as well as creating a connection pool to hold them. individual connections as well as creating a connection pool to hold them.
-1
View File
@@ -64,7 +64,6 @@ install_requires = [
tests_require = [ tests_require = [
"requests>=2.0.0, <3.0.0", "requests>=2.0.0, <3.0.0",
"coverage<8.0.0", "coverage<8.0.0",
"mock",
"pyyaml", "pyyaml",
"pytest>=3.0.0", "pytest>=3.0.0",
"pytest-cov", "pytest-cov",
+8 -10
View File
@@ -31,8 +31,6 @@
# under the License. # under the License.
from __future__ import print_function
import subprocess import subprocess
import sys import sys
from os import environ from os import environ
@@ -78,10 +76,10 @@ def fetch_opensearch_repo() -> None:
# no test directory # no test directory
if not exists(repo_path): if not exists(repo_path):
subprocess.check_call("mkdir %s" % repo_path, shell=True) subprocess.check_call(f"mkdir {repo_path}", shell=True)
# make a new blank repository in the test directory # make a new blank repository in the test directory
subprocess.check_call("cd %s && git init" % repo_path, shell=True) subprocess.check_call(f"cd {repo_path} && git init", shell=True)
try: try:
# add a remote # add a remote
@@ -104,7 +102,7 @@ def fetch_opensearch_repo() -> None:
# fetch the sha commit, version from info() # fetch the sha commit, version from info()
print("Fetching opensearch repo...") print("Fetching opensearch repo...")
subprocess.check_call("cd %s && git fetch origin %s" % (repo_path, sha), shell=True) subprocess.check_call(f"cd {repo_path} && git fetch origin {sha}", shell=True)
def run_all(argv: Any = None) -> None: def run_all(argv: Any = None) -> None:
@@ -136,18 +134,18 @@ def run_all(argv: Any = None) -> None:
argv = [ argv = [
"pytest", "pytest",
"--cov=opensearchpy", "--cov=opensearchpy",
"--junitxml=%s" % junit_xml, f"--junitxml={junit_xml}",
"--log-level=DEBUG", "--log-level=DEBUG",
"--cache-clear", "--cache-clear",
"-vv", "-vv",
"--cov-report=xml:%s" % codecov_xml, f"--cov-report=xml:{codecov_xml}",
] ]
if ( if (
"OPENSEARCHPY_GEN_HTML_COV" in environ "OPENSEARCHPY_GEN_HTML_COV" in environ
and environ.get("OPENSEARCHPY_GEN_HTML_COV") == "true" and environ.get("OPENSEARCHPY_GEN_HTML_COV") == "true"
): ):
codecov_html = join(abspath(dirname(dirname(__file__))), "junit", "html") codecov_html = join(abspath(dirname(dirname(__file__))), "junit", "html")
argv.append("--cov-report=html:%s" % codecov_html) argv.append(f"--cov-report=html:{codecov_html}")
secured = False secured = False
if environ.get("OPENSEARCH_URL", "").startswith("https://"): if environ.get("OPENSEARCH_URL", "").startswith("https://"):
@@ -156,7 +154,7 @@ def run_all(argv: Any = None) -> None:
# check TEST_PATTERN env var for specific test to run # check TEST_PATTERN env var for specific test to run
test_pattern = environ.get("TEST_PATTERN") test_pattern = environ.get("TEST_PATTERN")
if test_pattern: if test_pattern:
argv.append("-k %s" % test_pattern) argv.append(f"-k {test_pattern}")
else: else:
ignores = [ ignores = [
"test_opensearchpy/test_server/", "test_opensearchpy/test_server/",
@@ -196,7 +194,7 @@ def run_all(argv: Any = None) -> None:
) )
if ignores: if ignores:
argv.extend(["--ignore=%s" % ignore for ignore in ignores]) argv.extend([f"--ignore={ignore}" for ignore in ignores])
# Not in CI, run all tests specified. # Not in CI, run all tests specified.
else: else:
@@ -32,11 +32,11 @@ import ssl
import warnings import warnings
from platform import python_version from platform import python_version
from typing import Any from typing import Any
from unittest.mock import MagicMock, patch
import aiohttp import aiohttp
import pytest import pytest
from _pytest.mark.structures import MarkDecorator from _pytest.mark.structures import MarkDecorator
from mock import MagicMock, patch
from multidict import CIMultiDict from multidict import CIMultiDict
from pytest import raises from pytest import raises
@@ -154,7 +154,7 @@ class TestAIOHttpConnection:
async def test_default_user_agent(self) -> None: async def test_default_user_agent(self) -> None:
con = AIOHttpConnection() con = AIOHttpConnection()
assert con._get_default_user_agent() == "opensearch-py/%s (Python %s)" % ( assert con._get_default_user_agent() == "opensearch-py/{} (Python {})".format(
__versionstr__, __versionstr__,
python_version(), python_version(),
) )
@@ -342,7 +342,7 @@ class TestAIOHttpConnection:
buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa" buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"
con = await self._get_mock_connection(response_body=buf) con = await self._get_mock_connection(response_body=buf)
_, _, data = await con.perform_request("GET", "/") _, _, data = await con.perform_request("GET", "/")
assert u"你好\uda6a" == data # fmt: skip assert "你好\uda6a" == data # fmt: skip
@pytest.mark.parametrize("exception_cls", reraise_exceptions) # type: ignore @pytest.mark.parametrize("exception_cls", reraise_exceptions) # type: ignore
async def test_recursion_error_reraised(self, exception_cls: Any) -> None: async def test_recursion_error_reraised(self, exception_cls: Any) -> None:
@@ -9,10 +9,10 @@
from typing import Any from typing import Any
from unittest.mock import Mock
import pytest import pytest
from _pytest.mark.structures import MarkDecorator from _pytest.mark.structures import MarkDecorator
from mock import Mock
from pytest import fixture from pytest import fixture
from opensearchpy.connection.async_connections import add_connection, async_connections from opensearchpy.connection.async_connections import add_connection, async_connections
@@ -7,7 +7,6 @@
# Modifications Copyright OpenSearch Contributors. See # Modifications Copyright OpenSearch Contributors. See
# GitHub history for details. # GitHub history for details.
from __future__ import unicode_literals
import codecs import codecs
import ipaddress import ipaddress
@@ -73,7 +73,7 @@ async def test_cloned_index_has_analysis_attribute() -> None:
client = object() client = object()
i = AsyncIndex("my-index", using=client) i = AsyncIndex("my-index", using=client)
random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100))) random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100))
random_analyzer = analyzer( random_analyzer = analyzer(
random_analyzer_name, tokenizer="standard", filter="standard" random_analyzer_name, tokenizer="standard", filter="standard"
) )
@@ -117,7 +117,7 @@ async def test_registered_doc_type_included_in_search() -> None:
async def test_aliases_add_to_object() -> None: async def test_aliases_add_to_object() -> None:
random_alias = "".join((choice(string.ascii_letters) for _ in range(100))) random_alias = "".join(choice(string.ascii_letters) for _ in range(100))
alias_dict: Any = {random_alias: {}} alias_dict: Any = {random_alias: {}}
index = AsyncIndex("i", using="alias") index = AsyncIndex("i", using="alias")
@@ -127,7 +127,7 @@ async def test_aliases_add_to_object() -> None:
async def test_aliases_returned_from_to_dict() -> None: async def test_aliases_returned_from_to_dict() -> None:
random_alias = "".join((choice(string.ascii_letters) for _ in range(100))) random_alias = "".join(choice(string.ascii_letters) for _ in range(100))
alias_dict: Any = {random_alias: {}} alias_dict: Any = {random_alias: {}}
index = AsyncIndex("i", using="alias") index = AsyncIndex("i", using="alias")
@@ -137,7 +137,7 @@ async def test_aliases_returned_from_to_dict() -> None:
async def test_analyzers_added_to_object() -> None: async def test_analyzers_added_to_object() -> None:
random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100))) random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100))
random_analyzer = analyzer( random_analyzer = analyzer(
random_analyzer_name, tokenizer="standard", filter="standard" random_analyzer_name, tokenizer="standard", filter="standard"
) )
@@ -153,7 +153,7 @@ async def test_analyzers_added_to_object() -> None:
async def test_analyzers_returned_from_to_dict() -> None: async def test_analyzers_returned_from_to_dict() -> None:
random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100))) random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100))
random_analyzer = analyzer( random_analyzer = analyzer(
random_analyzer_name, tokenizer="standard", filter="standard" random_analyzer_name, tokenizer="standard", filter="standard"
) )
@@ -26,8 +26,8 @@
from typing import Any from typing import Any
from unittest import mock
import mock
import pytest import pytest
from multidict import CIMultiDict from multidict import CIMultiDict
@@ -25,8 +25,6 @@
# under the License. # under the License.
from __future__ import unicode_literals
from typing import Any from typing import Any
import pytest import pytest
@@ -27,9 +27,9 @@
import asyncio import asyncio
from typing import Any, List from typing import Any, List
from unittest.mock import MagicMock, patch
import pytest import pytest
from mock import MagicMock, patch
from opensearchpy import TransportError from opensearchpy import TransportError
from opensearchpy._async.helpers import actions from opensearchpy._async.helpers import actions
@@ -40,13 +40,13 @@ pytestmark = pytest.mark.asyncio
class AsyncMock(MagicMock): class AsyncMock(MagicMock):
async def __call__(self, *args: Any, **kwargs: Any) -> Any: async def __call__(self, *args: Any, **kwargs: Any) -> Any:
return super(AsyncMock, self).__call__(*args, **kwargs) return super().__call__(*args, **kwargs)
def __await__(self) -> Any: def __await__(self) -> Any:
return self().__await__() return self().__await__()
class FailingBulkClient(object): class FailingBulkClient:
def __init__( def __init__(
self, self,
client: Any, client: Any,
@@ -69,7 +69,7 @@ class FailingBulkClient(object):
return await self.client.bulk(*args, **kwargs) return await self.client.bulk(*args, **kwargs)
class TestStreamingBulk(object): class TestStreamingBulk:
async def test_actions_remain_unchanged(self, async_client: Any) -> None: async def test_actions_remain_unchanged(self, async_client: Any) -> None:
actions1 = [{"_id": 1}, {"_id": 2}] actions1 = [{"_id": 1}, {"_id": 2}]
async for ok, _ in actions.async_streaming_bulk( async for ok, _ in actions.async_streaming_bulk(
@@ -281,7 +281,7 @@ class TestStreamingBulk(object):
assert 4 == failing_client._called assert 4 == failing_client._called
class TestBulk(object): class TestBulk:
async def test_bulk_works_with_single_item(self, async_client: Any) -> None: async def test_bulk_works_with_single_item(self, async_client: Any) -> None:
docs = [{"answer": 42, "_id": 1}] docs = [{"answer": 42, "_id": 1}]
success, failed = await actions.async_bulk( success, failed = await actions.async_bulk(
@@ -453,7 +453,7 @@ async def scan_teardown(async_client: Any) -> Any:
await async_client.clear_scroll(scroll_id="_all") await async_client.clear_scroll(scroll_id="_all")
class TestScan(object): class TestScan:
async def test_order_can_be_preserved( async def test_order_can_be_preserved(
self, async_client: Any, scan_teardown: Any self, async_client: Any, scan_teardown: Any
) -> None: ) -> None:
@@ -492,8 +492,8 @@ class TestScan(object):
] ]
assert 100 == len(docs) assert 100 == len(docs)
assert set(map(str, range(100))) == set(d["_id"] for d in docs) assert set(map(str, range(100))) == {d["_id"] for d in docs}
assert set(range(100)) == set(d["_source"]["answer"] for d in docs) assert set(range(100)) == {d["_source"]["answer"] for d in docs}
async def test_scroll_error(self, async_client: Any, scan_teardown: Any) -> None: async def test_scroll_error(self, async_client: Any, scan_teardown: Any) -> None:
bulk: Any = [] bulk: Any = []
@@ -824,7 +824,7 @@ async def reindex_setup(async_client: Any) -> Any:
yield yield
class TestReindex(object): class TestReindex:
async def test_reindex_passes_kwargs_to_scan_and_bulk( async def test_reindex_passes_kwargs_to_scan_and_bulk(
self, async_client: Any, reindex_setup: Any self, async_client: Any, reindex_setup: Any
) -> None: ) -> None:
@@ -7,7 +7,6 @@
# Modifications Copyright OpenSearch Contributors. See # Modifications Copyright OpenSearch Contributors. See
# GitHub history for details. # GitHub history for details.
from __future__ import unicode_literals
from typing import Any, Dict from typing import Any, Dict
@@ -64,7 +64,7 @@ class Repository(AsyncDocument):
@classmethod @classmethod
def search(cls, using: Any = None, index: Optional[str] = None) -> Any: def search(cls, using: Any = None, index: Optional[str] = None) -> Any:
return super(Repository, cls).search().filter("term", commit_repo="repo") return super().search().filter("term", commit_repo="repo")
class Index: class Index:
name = "git" name = "git"
@@ -98,7 +98,7 @@ def repo_search_cls(opensearch_version: Any) -> Any:
} }
def search(self) -> Any: def search(self) -> Any:
s = super(RepoSearch, self).search() s = super().search()
return s.filter("term", commit_repo="repo") return s.filter("term", commit_repo="repo")
return RepoSearch return RepoSearch
@@ -7,7 +7,6 @@
# Modifications Copyright OpenSearch Contributors. See # Modifications Copyright OpenSearch Contributors. See
# GitHub history for details. # GitHub history for details.
from __future__ import unicode_literals
from typing import Any from typing import Any
@@ -31,7 +30,7 @@ class Repository(AsyncDocument):
@classmethod @classmethod
def search(cls, using: Any = None, index: Any = None) -> Any: def search(cls, using: Any = None, index: Any = None) -> Any:
return super(Repository, cls).search().filter("term", commit_repo="repo") return super().search().filter("term", commit_repo="repo")
class Index: class Index:
name = "git" name = "git"
@@ -8,8 +8,6 @@
# GitHub history for details. # GitHub history for details.
from __future__ import unicode_literals
import unittest import unittest
import pytest import pytest
@@ -8,8 +8,6 @@
# GitHub history for details. # GitHub history for details.
from __future__ import unicode_literals
import pytest import pytest
from _pytest.mark.structures import MarkDecorator from _pytest.mark.structures import MarkDecorator
@@ -121,7 +121,7 @@ class AsyncYamlRunner(YamlRunner):
if hasattr(self, "run_" + action_type): if hasattr(self, "run_" + action_type):
await await_if_coro(getattr(self, "run_" + action_type)(action)) await await_if_coro(getattr(self, "run_" + action_type)(action))
else: else:
raise RuntimeError("Invalid action type %r" % (action_type,)) raise RuntimeError(f"Invalid action type {action_type!r}")
async def run_do(self, action: Any) -> Any: async def run_do(self, action: Any) -> Any:
api = self.client api = self.client
@@ -170,7 +170,7 @@ class AsyncYamlRunner(YamlRunner):
else: else:
if catch: if catch:
raise AssertionError( raise AssertionError(
"Failed to catch %r in %r." % (catch, self.last_response) f"Failed to catch {catch!r} in {self.last_response!r}."
) )
# Filter out warnings raised by other components. # Filter out warnings raised by other components.
@@ -197,7 +197,7 @@ class AsyncYamlRunner(YamlRunner):
for feature in features: for feature in features:
if feature in IMPLEMENTED_FEATURES: if feature in IMPLEMENTED_FEATURES:
continue continue
pytest.skip("feature '%s' is not supported" % feature) pytest.skip(f"feature '{feature}' is not supported")
if "version" in skip: if "version" in skip:
version, reason = skip["version"], skip["reason"] version, reason = skip["version"], skip["reason"]
@@ -8,8 +8,6 @@
# GitHub history for details. # GitHub history for details.
from __future__ import unicode_literals
import os import os
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
+1 -1
View File
@@ -8,10 +8,10 @@
# GitHub history for details. # GitHub history for details.
import uuid import uuid
from unittest.mock import Mock
import pytest import pytest
from _pytest.mark.structures import MarkDecorator from _pytest.mark.structures import MarkDecorator
from mock import Mock
pytestmark: MarkDecorator = pytest.mark.asyncio pytestmark: MarkDecorator = pytest.mark.asyncio
@@ -25,15 +25,13 @@
# under the License. # under the License.
from __future__ import unicode_literals
import asyncio import asyncio
import json import json
from typing import Any from typing import Any
from unittest.mock import patch
import pytest import pytest
from _pytest.mark.structures import MarkDecorator from _pytest.mark.structures import MarkDecorator
from mock import patch
from opensearchpy import AIOHttpConnection, AsyncTransport from opensearchpy import AIOHttpConnection, AsyncTransport
from opensearchpy.connection import Connection from opensearchpy.connection import Connection
@@ -51,7 +49,7 @@ class DummyConnection(Connection):
self.delay = kwargs.pop("delay", 0) self.delay = kwargs.pop("delay", 0)
self.calls: Any = [] self.calls: Any = []
self.closed = False self.closed = False
super(DummyConnection, self).__init__(**kwargs) super().__init__(**kwargs)
async def perform_request(self, *args: Any, **kwargs: Any) -> Any: async def perform_request(self, *args: Any, **kwargs: Any) -> Any:
if self.closed: if self.closed:
@@ -253,7 +251,7 @@ class TestTransport:
assert dt is t.connection_pool.dead_timeout assert dt is t.connection_pool.dead_timeout
async def test_custom_connection_class(self) -> None: async def test_custom_connection_class(self) -> None:
class MyConnection(object): class MyConnection:
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
self.kwargs = kwargs self.kwargs = kwargs
+2 -2
View File
@@ -32,7 +32,7 @@ from unittest import SkipTest, TestCase
from opensearchpy import OpenSearch from opensearchpy import OpenSearch
class DummyTransport(object): class DummyTransport:
def __init__( def __init__(
self, hosts: Sequence[str], responses: Any = None, **kwargs: Any self, hosts: Sequence[str], responses: Any = None, **kwargs: Any
) -> None: ) -> None:
@@ -59,7 +59,7 @@ class DummyTransport(object):
class OpenSearchTestCase(TestCase): class OpenSearchTestCase(TestCase):
def setUp(self) -> None: def setUp(self) -> None:
super(OpenSearchTestCase, self).setUp() super().setUp()
self.client: Any = OpenSearch(transport_class=DummyTransport) # type: ignore self.client: Any = OpenSearch(transport_class=DummyTransport) # type: ignore
def assert_call_count_equals(self, count: int) -> None: def assert_call_count_equals(self, count: int) -> None:
@@ -25,8 +25,6 @@
# under the License. # under the License.
from __future__ import unicode_literals
import warnings import warnings
from opensearchpy.client import OpenSearch from opensearchpy.client import OpenSearch
@@ -25,8 +25,6 @@
# under the License. # under the License.
from __future__ import unicode_literals
from typing import Any from typing import Any
from opensearchpy.client.utils import _bulk_body, _escape, _make_path, query_params from opensearchpy.client.utils import _bulk_body, _escape, _make_path, query_params
@@ -30,9 +30,9 @@ import re
import uuid import uuid
import warnings import warnings
from typing import Any from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
from mock import MagicMock, Mock, patch
from requests.auth import AuthBase from requests.auth import AuthBase
from opensearchpy.connection import Connection, RequestsHttpConnection from opensearchpy.connection import Connection, RequestsHttpConnection
@@ -258,7 +258,7 @@ class TestRequestsHttpConnection(TestCase):
"GET", "GET",
"/", "/",
{"param": 42}, {"param": 42},
"{}".encode("utf-8"), b"{}",
) )
# trace request # trace request
@@ -282,7 +282,7 @@ class TestRequestsHttpConnection(TestCase):
"GET", "GET",
"/", "/",
{"param": 42}, {"param": 42},
"""{"question": "what's that?"}""".encode("utf-8"), b"""{"question": "what's that?"}""",
) )
# trace request # trace request
@@ -397,7 +397,7 @@ class TestRequestsHttpConnection(TestCase):
self.assertEqual("http://localhost:9200/", request.url) self.assertEqual("http://localhost:9200/", request.url)
self.assertEqual("GET", request.method) self.assertEqual("GET", request.method)
self.assertEqual('{"answer": 42}'.encode("utf-8"), request.body) self.assertEqual(b'{"answer": 42}', request.body)
def test_http_auth_attached(self) -> None: def test_http_auth_attached(self) -> None:
con = self._get_mock_connection({"http_auth": "username:secret"}) con = self._get_mock_connection({"http_auth": "username:secret"})
@@ -414,7 +414,7 @@ class TestRequestsHttpConnection(TestCase):
self.assertEqual("http://localhost:9200/some-prefix/_search", request.url) self.assertEqual("http://localhost:9200/some-prefix/_search", request.url)
self.assertEqual("GET", request.method) self.assertEqual("GET", request.method)
self.assertEqual('{"answer": 42}'.encode("utf-8"), request.body) self.assertEqual(b'{"answer": 42}', request.body)
# trace request # trace request
trace_curl_cmd = ( trace_curl_cmd = (
@@ -431,7 +431,7 @@ class TestRequestsHttpConnection(TestCase):
buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa" buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"
con = self._get_mock_connection(response_body=buf) con = self._get_mock_connection(response_body=buf)
_, _, data = con.perform_request("GET", "/") _, _, data = con.perform_request("GET", "/")
self.assertEqual(u"你好\uda6a", data) # fmt: skip self.assertEqual("你好\uda6a", data) # fmt: skip
def test_recursion_error_reraised(self) -> None: def test_recursion_error_reraised(self) -> None:
conn = RequestsHttpConnection() conn = RequestsHttpConnection()
@@ -32,10 +32,10 @@ from gzip import GzipFile
from io import BytesIO from io import BytesIO
from platform import python_version from platform import python_version
from typing import Any from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
import urllib3 import urllib3
from mock import MagicMock, Mock, patch
from urllib3._collections import HTTPHeaderDict from urllib3._collections import HTTPHeaderDict
from opensearchpy import __versionstr__ from opensearchpy import __versionstr__
@@ -130,7 +130,7 @@ class TestUrllib3HttpConnection(TestCase):
con = Urllib3HttpConnection() con = Urllib3HttpConnection()
self.assertEqual( self.assertEqual(
con._get_default_user_agent(), con._get_default_user_agent(),
"opensearch-py/%s (Python %s)" % (__versionstr__, python_version()), f"opensearch-py/{__versionstr__} (Python {python_version()})",
) )
def test_timeout_set(self) -> None: def test_timeout_set(self) -> None:
@@ -385,7 +385,7 @@ class TestUrllib3HttpConnection(TestCase):
buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa" buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"
con = self._get_mock_connection(response_body=buf) con = self._get_mock_connection(response_body=buf)
_, _, data = con.perform_request("GET", "/") _, _, data = con.perform_request("GET", "/")
self.assertEqual(u"你好\uda6a", data) # fmt: skip self.assertEqual("你好\uda6a", data) # fmt: skip
def test_recursion_error_reraised(self) -> None: def test_recursion_error_reraised(self) -> None:
conn = Urllib3HttpConnection() conn = Urllib3HttpConnection()
+1 -3
View File
@@ -68,9 +68,7 @@ class TestConnectionPool(TestCase):
def test_selectors_have_access_to_connection_opts(self) -> None: def test_selectors_have_access_to_connection_opts(self) -> None:
class MySelector(RoundRobinSelector): class MySelector(RoundRobinSelector):
def select(self, connections: Any) -> Any: def select(self, connections: Any) -> Any:
return self.connection_opts[ return self.connection_opts[super().select(connections)]["actual"]
super(MySelector, self).select(connections)
]["actual"]
pool = ConnectionPool( pool = ConnectionPool(
[(x, {"actual": x * x}) for x in range(100)], [(x, {"actual": x * x}) for x in range(100)],
+1 -1
View File
@@ -26,8 +26,8 @@
from typing import Any from typing import Any
from unittest.mock import Mock
from mock import Mock
from pytest import fixture from pytest import fixture
from opensearchpy.connection.connections import add_connection, connections from opensearchpy.connection.connections import add_connection, connections
@@ -28,9 +28,9 @@
import threading import threading
import time import time
from typing import Any from typing import Any
from unittest import mock
from unittest.mock import Mock from unittest.mock import Mock
import mock
import pytest import pytest
from opensearchpy import OpenSearch, helpers from opensearchpy import OpenSearch, helpers
@@ -131,7 +131,7 @@ class TestParallelBulk(TestCase):
results = list( results = list(
helpers.parallel_bulk(OpenSearch(), actions, thread_count=10, chunk_size=2) helpers.parallel_bulk(OpenSearch(), actions, thread_count=10, chunk_size=2)
) )
self.assertTrue(len(set([r[1] for r in results])) > 1) self.assertTrue(len({r[1] for r in results}) > 1)
class TestChunkActions(TestCase): class TestChunkActions(TestCase):
@@ -266,7 +266,7 @@ class TestChunkActions(TestCase):
) )
self.assertEqual(25, len(chunks)) self.assertEqual(25, len(chunks))
for _, chunk_actions in chunks: for _, chunk_actions in chunks:
chunk = u"".join(chunk_actions) # fmt: skip chunk = "".join(chunk_actions) # fmt: skip
chunk = chunk if isinstance(chunk, str) else chunk.encode("utf-8") chunk = chunk if isinstance(chunk, str) else chunk.encode("utf-8")
self.assertLessEqual(len(chunk), max_byte_size) self.assertLessEqual(len(chunk), max_byte_size)
@@ -24,7 +24,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from __future__ import unicode_literals
import codecs import codecs
import ipaddress import ipaddress
+5 -5
View File
@@ -84,7 +84,7 @@ def test_cloned_index_has_analysis_attribute() -> None:
client = object() client = object()
i: Any = Index("my-index", using=client) i: Any = Index("my-index", using=client)
random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100))) random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100))
random_analyzer = analyzer( random_analyzer = analyzer(
random_analyzer_name, tokenizer="standard", filter="standard" random_analyzer_name, tokenizer="standard", filter="standard"
) )
@@ -128,7 +128,7 @@ def test_registered_doc_type_included_in_search() -> None:
def test_aliases_add_to_object() -> None: def test_aliases_add_to_object() -> None:
random_alias = "".join((choice(string.ascii_letters) for _ in range(100))) random_alias = "".join(choice(string.ascii_letters) for _ in range(100))
alias_dict: Any = {random_alias: {}} alias_dict: Any = {random_alias: {}}
index: Any = Index("i", using="alias") index: Any = Index("i", using="alias")
@@ -138,7 +138,7 @@ def test_aliases_add_to_object() -> None:
def test_aliases_returned_from_to_dict() -> None: def test_aliases_returned_from_to_dict() -> None:
random_alias = "".join((choice(string.ascii_letters) for _ in range(100))) random_alias = "".join(choice(string.ascii_letters) for _ in range(100))
alias_dict: Any = {random_alias: {}} alias_dict: Any = {random_alias: {}}
index: Any = Index("i", using="alias") index: Any = Index("i", using="alias")
@@ -148,7 +148,7 @@ def test_aliases_returned_from_to_dict() -> None:
def test_analyzers_added_to_object() -> None: def test_analyzers_added_to_object() -> None:
random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100))) random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100))
random_analyzer = analyzer( random_analyzer = analyzer(
random_analyzer_name, tokenizer="standard", filter="standard" random_analyzer_name, tokenizer="standard", filter="standard"
) )
@@ -164,7 +164,7 @@ def test_analyzers_added_to_object() -> None:
def test_analyzers_returned_from_to_dict() -> None: def test_analyzers_returned_from_to_dict() -> None:
random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100))) random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100))
random_analyzer = analyzer( random_analyzer = analyzer(
random_analyzer_name, tokenizer="standard", filter="standard" random_analyzer_name, tokenizer="standard", filter="standard"
) )
@@ -95,7 +95,7 @@ def test_interactive_helpers(dummy_response: Any) -> None:
) )
assert res assert res
assert "<Response: %s>" % rhits == repr(res) assert f"<Response: {rhits}>" == repr(res)
assert rhits == repr(hits) assert rhits == repr(hits)
assert {"meta", "city", "name"} == set(dir(h)) assert {"meta", "city", "name"} == set(dir(h))
assert "<Hit(test-index/opensearch): %r>" % dummy_response["hits"]["hits"][0][ assert "<Hit(test-index/opensearch): %r>" % dummy_response["hits"]["hits"][0][
+1 -1
View File
@@ -99,7 +99,7 @@ def test_serializer_deals_with_attr_versions() -> None:
def test_serializer_deals_with_objects_with_to_dict() -> None: def test_serializer_deals_with_objects_with_to_dict() -> None:
class MyClass(object): class MyClass:
def to_dict(self) -> int: def to_dict(self) -> int:
return 42 return 42
@@ -66,7 +66,7 @@ class AutoNowDate(Date):
def clean(self, data: Any) -> Any: def clean(self, data: Any) -> Any:
if data is None: if data is None:
data = datetime.now() data = datetime.now()
return super(AutoNowDate, self).clean(data) return super().clean(data)
class Log(Document): class Log(Document):
+1 -1
View File
@@ -74,7 +74,7 @@ def sync_client_factory() -> Any:
except ConnectionError: except ConnectionError:
time.sleep(0.1) time.sleep(0.1)
else: else:
pytest.skip("OpenSearch wasn't running at %r" % (OPENSEARCH_URL,)) pytest.skip(f"OpenSearch wasn't running at {OPENSEARCH_URL!r}")
wipe_cluster(client) wipe_cluster(client)
yield client yield client
@@ -25,8 +25,6 @@
# under the License. # under the License.
from __future__ import unicode_literals
from . import OpenSearchTestCase from . import OpenSearchTestCase
@@ -26,8 +26,7 @@
from typing import Any from typing import Any
from unittest.mock import patch
from mock import patch
from opensearchpy import TransportError, helpers from opensearchpy import TransportError, helpers
from opensearchpy.helpers import ScanError from opensearchpy.helpers import ScanError
@@ -36,7 +35,7 @@ from ...test_cases import SkipTest
from .. import OpenSearchTestCase from .. import OpenSearchTestCase
class FailingBulkClient(object): class FailingBulkClient:
def __init__( def __init__(
self, self,
client: Any, client: Any,
@@ -383,7 +382,7 @@ class TestScan(OpenSearchTestCase):
def teardown_method(self, m: Any) -> None: def teardown_method(self, m: Any) -> None:
self.client.transport.perform_request("DELETE", "/_search/scroll/_all") self.client.transport.perform_request("DELETE", "/_search/scroll/_all")
super(TestScan, self).teardown_method(m) super().teardown_method(m)
def test_order_can_be_preserved(self) -> None: def test_order_can_be_preserved(self) -> None:
bulk: Any = [] bulk: Any = []
@@ -415,8 +414,8 @@ class TestScan(OpenSearchTestCase):
docs = list(helpers.scan(self.client, index="test_index", size=2)) docs = list(helpers.scan(self.client, index="test_index", size=2))
self.assertEqual(100, len(docs)) self.assertEqual(100, len(docs))
self.assertEqual(set(map(str, range(100))), set(d["_id"] for d in docs)) self.assertEqual(set(map(str, range(100))), {d["_id"] for d in docs})
self.assertEqual(set(range(100)), set(d["_source"]["answer"] for d in docs)) self.assertEqual(set(range(100)), {d["_source"]["answer"] for d in docs})
def test_scroll_error(self) -> None: def test_scroll_error(self) -> None:
bulk: Any = [] bulk: Any = []
@@ -24,7 +24,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from __future__ import unicode_literals
from typing import Any, Dict from typing import Any, Dict
@@ -79,7 +79,7 @@ class Repository(Document):
@classmethod @classmethod
def search(cls, using: Any = None, index: Any = None) -> Any: def search(cls, using: Any = None, index: Any = None) -> Any:
return super(Repository, cls).search().filter("term", commit_repo="repo") return super().search().filter("term", commit_repo="repo")
class Index: class Index:
name = "git" name = "git"
@@ -106,7 +106,7 @@ def repo_search_cls(opensearch_version: Any) -> Any:
} }
def search(self) -> Any: def search(self) -> Any:
s = super(RepoSearch, self).search() s = super().search()
return s.filter("term", commit_repo="repo") return s.filter("term", commit_repo="repo")
return RepoSearch return RepoSearch
@@ -24,7 +24,6 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from __future__ import unicode_literals
from typing import Any from typing import Any
@@ -52,7 +51,7 @@ class Repository(Document):
@classmethod @classmethod
def search(cls, using: Any = None, index: Any = None) -> Any: def search(cls, using: Any = None, index: Any = None) -> Any:
return super(Repository, cls).search().filter("term", commit_repo="repo") return super().search().filter("term", commit_repo="repo")
class Index: class Index:
name = "git" name = "git"
@@ -7,7 +7,6 @@
# Modifications Copyright OpenSearch Contributors. See # Modifications Copyright OpenSearch Contributors. See
# GitHub history for details. # GitHub history for details.
from __future__ import unicode_literals
import time import time
@@ -8,8 +8,6 @@
# GitHub history for details. # GitHub history for details.
from __future__ import unicode_literals
import unittest import unittest
from opensearchpy.helpers.test import OPENSEARCH_VERSION from opensearchpy.helpers.test import OPENSEARCH_VERSION
@@ -8,8 +8,6 @@
# GitHub history for details. # GitHub history for details.
from __future__ import unicode_literals
from opensearchpy.exceptions import NotFoundError from opensearchpy.exceptions import NotFoundError
from .. import OpenSearchTestCase from .. import OpenSearchTestCase
@@ -8,8 +8,6 @@
# GitHub history for details. # GitHub history for details.
from __future__ import unicode_literals
import unittest import unittest
from typing import Any, Dict from typing import Any, Dict
@@ -169,7 +169,7 @@ class YamlRunner:
if hasattr(self, "run_" + action_type): if hasattr(self, "run_" + action_type):
getattr(self, "run_" + action_type)(action) getattr(self, "run_" + action_type)(action)
else: else:
raise RuntimeError("Invalid action type %r" % (action_type,)) raise RuntimeError(f"Invalid action type {action_type!r}")
def run_do(self, action: Any) -> Any: def run_do(self, action: Any) -> Any:
api = self.client api = self.client
@@ -218,7 +218,7 @@ class YamlRunner:
else: else:
if catch: if catch:
raise AssertionError( raise AssertionError(
"Failed to catch %r in %r." % (catch, self.last_response) f"Failed to catch {catch!r} in {self.last_response!r}."
) )
# Filter out warnings raised by other components. # Filter out warnings raised by other components.
@@ -248,7 +248,7 @@ class YamlRunner:
elif catch[0] == "/" and catch[-1] == "/": elif catch[0] == "/" and catch[-1] == "/":
assert ( assert (
re.search(catch[1:-1], exception.error + " " + repr(exception.info)), re.search(catch[1:-1], exception.error + " " + repr(exception.info)),
"%s not in %r" % (catch, exception.info), f"{catch} not in {exception.info!r}",
) is not None ) is not None
self.last_response = exception.info self.last_response = exception.info
@@ -262,7 +262,7 @@ class YamlRunner:
for feature in features: for feature in features:
if feature in IMPLEMENTED_FEATURES: if feature in IMPLEMENTED_FEATURES:
continue continue
pytest.skip("feature '%s' is not supported" % feature) pytest.skip(f"feature '{feature}' is not supported")
if "version" in skip: if "version" in skip:
version, reason = skip["version"], skip["reason"] version, reason = skip["version"], skip["reason"]
@@ -328,10 +328,7 @@ class YamlRunner:
and expected.strip().endswith("/") and expected.strip().endswith("/")
): ):
expected = re.compile(expected.strip()[1:-1], re.VERBOSE | re.MULTILINE) expected = re.compile(expected.strip()[1:-1], re.VERBOSE | re.MULTILINE)
assert expected.search(value), "%r does not match %r" % ( assert expected.search(value), f"{value!r} does not match {expected!r}"
value,
expected,
)
else: else:
self._assert_match_equals(value, expected) self._assert_match_equals(value, expected)
@@ -341,7 +338,7 @@ class YamlRunner:
expected = self._resolve(expected) # dict[str, str] expected = self._resolve(expected) # dict[str, str]
if expected not in value: if expected not in value:
raise AssertionError("%s is not contained by %s" % (expected, value)) raise AssertionError(f"{expected} is not contained by {value}")
def run_transform_and_set(self, action: Any) -> None: def run_transform_and_set(self, action: Any) -> None:
for key, value in action.items(): for key, value in action.items():
@@ -371,7 +368,7 @@ class YamlRunner:
break break
if isinstance(value, dict): if isinstance(value, dict):
value = dict((k, self._resolve(v)) for (k, v) in value.items()) value = {k: self._resolve(v) for (k, v) in value.items()}
elif isinstance(value, list): elif isinstance(value, list):
value = list(map(self._resolve, value)) value = list(map(self._resolve, value))
return value return value
@@ -412,7 +409,7 @@ class YamlRunner:
if isinstance(b, string_types) and isinstance(a, float) and "e" in repr(a): if isinstance(b, string_types) and isinstance(a, float) and "e" in repr(a):
a = repr(a).replace("e+", "E") a = repr(a).replace("e+", "E")
assert a == b, "%r does not match %r" % (a, b) assert a == b, f"{a!r} does not match {b!r}"
@pytest.fixture(scope="function") # type: ignore @pytest.fixture(scope="function") # type: ignore
@@ -473,7 +470,7 @@ def load_rest_api_tests() -> None:
for prefix in ("rest-api-spec/", "test/", "oss/"): for prefix in ("rest-api-spec/", "test/", "oss/"):
if pytest_test_name.startswith(prefix): if pytest_test_name.startswith(prefix):
pytest_test_name = pytest_test_name[len(prefix) :] pytest_test_name = pytest_test_name[len(prefix) :]
pytest_param_id = "%s[%d]" % (pytest_test_name, test_number) pytest_param_id = f"{pytest_test_name}[{test_number}]"
pytest_param = { pytest_param = {
"setup": setup_steps, "setup": setup_steps,
@@ -487,7 +484,7 @@ def load_rest_api_tests() -> None:
YAML_TEST_SPECS.append(pytest.param(pytest_param, id=pytest_param_id)) YAML_TEST_SPECS.append(pytest.param(pytest_param, id=pytest_param_id))
except Exception as e: except Exception as e:
warnings.warn("Could not load REST API tests: %s" % (str(e),)) warnings.warn(f"Could not load REST API tests: {str(e)}")
load_rest_api_tests() load_rest_api_tests()
@@ -8,8 +8,6 @@
# GitHub history for details. # GitHub history for details.
from __future__ import unicode_literals
import os import os
from unittest import TestCase from unittest import TestCase
+2 -5
View File
@@ -25,13 +25,10 @@
# under the License. # under the License.
from __future__ import unicode_literals
import json import json
import time import time
from typing import Any from typing import Any
from unittest.mock import patch
from mock import patch
from opensearchpy.connection import Connection from opensearchpy.connection import Connection
from opensearchpy.connection_pool import DummyConnectionPool from opensearchpy.connection_pool import DummyConnectionPool
@@ -47,7 +44,7 @@ class DummyConnection(Connection):
self.status, self.data = kwargs.pop("status", 200), kwargs.pop("data", "{}") self.status, self.data = kwargs.pop("status", 200), kwargs.pop("data", "{}")
self.headers = kwargs.pop("headers", {}) self.headers = kwargs.pop("headers", {})
self.calls: Any = [] self.calls: Any = []
super(DummyConnection, self).__init__(**kwargs) super().__init__(**kwargs)
def perform_request(self, *args: Any, **kwargs: Any) -> Any: def perform_request(self, *args: Any, **kwargs: Any) -> Any:
self.calls.append((args, kwargs)) self.calls.append((args, kwargs))
+5 -5
View File
@@ -69,7 +69,7 @@ def run(*argv: Any, expect_exit_code: int = 0) -> None:
else: else:
os.chdir(TMP_DIR) os.chdir(TMP_DIR)
cmd = " ".join(shlex.quote(x) for x in argv) cmd = shlex.join(argv)
print("$ " + cmd) print("$ " + cmd)
exit_code = os.system(cmd) exit_code = os.system(cmd)
if exit_code != expect_exit_code: if exit_code != expect_exit_code:
@@ -270,7 +270,7 @@ def main() -> None:
# Rename the module to fit the suffix. # Rename the module to fit the suffix.
shutil.move( shutil.move(
os.path.join(BASE_DIR, "opensearchpy"), os.path.join(BASE_DIR, "opensearchpy"),
os.path.join(BASE_DIR, "opensearchpy%s" % suffix), os.path.join(BASE_DIR, f"opensearchpy{suffix}"),
) )
# Ensure that the version within 'opensearchpy/_version.py' is correct. # Ensure that the version within 'opensearchpy/_version.py' is correct.
@@ -279,7 +279,7 @@ def main() -> None:
version_data = file.read() version_data = file.read()
version_data = re.sub( version_data = re.sub(
r"__versionstr__: str = \"[^\"]+\"", r"__versionstr__: str = \"[^\"]+\"",
'__versionstr__: str = "%s"' % version, f'__versionstr__: str = "{version}"',
version_data, version_data,
) )
with open(version_path, "w", encoding="utf-8") as file: with open(version_path, "w", encoding="utf-8") as file:
@@ -296,7 +296,7 @@ def main() -> None:
file.write( file.write(
setup_py.replace( setup_py.replace(
'PACKAGE_NAME = "opensearch-py"', 'PACKAGE_NAME = "opensearch-py"',
'PACKAGE_NAME = "opensearch-py%s"' % suffix, f'PACKAGE_NAME = "opensearch-py{suffix}"',
) )
) )
@@ -306,7 +306,7 @@ def main() -> None:
# Clean up everything. # Clean up everything.
run("git", "checkout", "--", "setup.py", "opensearchpy/") run("git", "checkout", "--", "setup.py", "opensearchpy/")
if suffix: if suffix:
run("rm", "-rf", "opensearchpy%s/" % suffix) run("rm", "-rf", f"opensearchpy{suffix}/")
# Test everything that got created # Test everything that got created
dists = os.listdir(os.path.join(BASE_DIR, "dist")) dists = os.listdir(os.path.join(BASE_DIR, "dist"))
+4 -4
View File
@@ -95,7 +95,7 @@ def blacken(filename: Any) -> None:
assert result.exit_code == 0, result.output assert result.exit_code == 0, result.output
@lru_cache() @lru_cache
def is_valid_url(url: str) -> bool: def is_valid_url(url: str) -> bool:
""" """
makes a call to the url makes a call to the url
@@ -222,7 +222,7 @@ class Module:
# Identifying the insertion point for the "THIS CODE IS AUTOMATICALLY GENERATED" header. # Identifying the insertion point for the "THIS CODE IS AUTOMATICALLY GENERATED" header.
if os.path.exists(self.filepath): if os.path.exists(self.filepath):
with open(self.filepath, "r", encoding="utf-8") as file: with open(self.filepath, encoding="utf-8") as file:
content = file.read() content = file.read()
if header_separator in content: if header_separator in content:
update_header = False update_header = False
@@ -249,7 +249,7 @@ class Module:
generated_file_header_path = os.path.join( generated_file_header_path = os.path.join(
current_script_folder, "generated_file_headers.txt" current_script_folder, "generated_file_headers.txt"
) )
with open(generated_file_header_path, "r", encoding="utf-8") as header_file: with open(generated_file_header_path, encoding="utf-8") as header_file:
header_content = header_file.read() header_content = header_file.read()
# Imports are temporarily removed from the header and are regenerated # Imports are temporarily removed from the header and are regenerated
@@ -287,7 +287,7 @@ class Module:
# Generating imports for each module # Generating imports for each module
utils_imports = "" utils_imports = ""
file_content = "" file_content = ""
with open(self.filepath, "r", encoding="utf-8") as file: with open(self.filepath, encoding="utf-8") as file:
content = file.read() content = file.read()
keywords = [ keywords = [
"SKIP_IN_PATH", "SKIP_IN_PATH",
+2 -2
View File
@@ -55,7 +55,7 @@ def does_file_need_fix(filepath: str) -> bool:
if not re.search(r"\.py$", filepath): if not re.search(r"\.py$", filepath):
return False return False
existing_header = "" existing_header = ""
with open(filepath, mode="r", encoding="utf-8") as file: with open(filepath, encoding="utf-8") as file:
for line in file: for line in file:
line = line.strip() line = line.strip()
if len(line) == 0 or line in LINES_TO_KEEP: if len(line) == 0 or line in LINES_TO_KEEP:
@@ -73,7 +73,7 @@ def add_header_to_file(filepath: str) -> None:
writes the license header to the beginning of a file writes the license header to the beginning of a file
:param filepath: relative or absolute filepath to update :param filepath: relative or absolute filepath to update
""" """
with open(filepath, mode="r", encoding="utf-8") as file: with open(filepath, encoding="utf-8") as file:
lines = list(file) lines = list(file)
i = 0 i = 0
for i, line in enumerate(lines): for i, line in enumerate(lines):