Remove redundant mock backport dependency and upgrade syntax for Python 3.8+ (#785)

* Upgrade syntax with pyupgrade --py38-plus

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>

* Convert to f-strings with flynt

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>

* Format with Black

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>

* Remove redundant mock backport dependency

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>

* isort imports

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>

* Add changelog entry

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>

---------

Signed-off-by: Hugo van Kemenade <1324225+hugovk@users.noreply.github.com>
This commit is contained in:
Hugo van Kemenade
2024-07-20 23:19:20 +03:00
committed by GitHub
parent de96d28e45
commit 6e3f1a1194
95 changed files with 229 additions and 300 deletions
+1
View File
@@ -8,6 +8,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
### Deprecated
### Removed
- Removed redundant dependency on six ([#781](https://github.com/opensearch-project/opensearch-py/pull/781))
- Removed redundant dependency on mock and upgrade Python syntax ([#785](https://github.com/opensearch-project/opensearch-py/pull/785))
### Fixed
- Fixed Search helper to ensure proper retention of the _collapse attribute in chained operations. ([#771](https://github.com/opensearch-project/opensearch-py/pull/771))
### Updated APIs
-1
View File
@@ -2,7 +2,6 @@ requests>=2, <3
pytest
pytest-cov
coverage
mock
sphinx<7.4
sphinx_rtd_theme
jinja2
-1
View File
@@ -93,7 +93,6 @@ def lint(session: Any) -> None:
"types-simplejson",
"types-python-dateutil",
"types-PyYAML",
"types-mock",
"types-pytz",
)
+2 -3
View File
@@ -26,7 +26,6 @@
# flake8: noqa
from __future__ import absolute_import
import logging
import re
@@ -34,9 +33,9 @@ import warnings
from ._version import __versionstr__
_major, _minor, _patch = [
_major, _minor, _patch = (
int(x) for x in re.search(r"^(\d+)\.(\d+)\.(\d+)", __versionstr__).groups() # type: ignore
]
)
VERSION = __version__ = (_major, _minor, _patch)
+3 -5
View File
@@ -34,8 +34,6 @@
# -----------------------------------------------------------------------------------------+
from __future__ import unicode_literals
import logging
from typing import Any, Type
@@ -197,7 +195,7 @@ class AsyncOpenSearch(Client):
self,
hosts: Any = None,
transport_class: Type[AsyncTransport] = AsyncTransport,
**kwargs: Any
**kwargs: Any,
) -> None:
"""
:arg hosts: list of nodes, or a single node, we should connect to.
@@ -240,10 +238,10 @@ class AsyncOpenSearch(Client):
# truncate to 5 if there are too many
if len(cons) > 5:
cons = cons[:5] + ["..."]
return "<{cls}({cons})>".format(cls=self.__class__.__name__, cons=cons)
return f"<{self.__class__.__name__}({cons})>"
except Exception:
# probably operating on custom transport and connection_pool, ignore
return super(AsyncOpenSearch, self).__repr__()
return super().__repr__()
async def __aenter__(self) -> Any:
if hasattr(self.transport, "_async_call"):
+1 -1
View File
@@ -13,7 +13,7 @@ from opensearchpy.client.utils import _normalize_hosts
from opensearchpy.transport import Transport
class Client(object):
class Client:
"""
A generic async OpenSearch client.
"""
+1 -1
View File
@@ -15,7 +15,7 @@ from .utils import NamespacedClient
class HttpClient(NamespacedClient):
def __init__(self, client: Client) -> None:
super(HttpClient, self).__init__(client)
super().__init__(client)
async def get(
self,
+1 -1
View File
@@ -26,7 +26,7 @@ class PluginsClient(NamespacedClient):
index_management: Any
def __init__(self, client: Client) -> None:
super(PluginsClient, self).__init__(client)
super().__init__(client)
self.ml = MlClient(client)
self.transforms = TransformsClient(client)
self.rollups = RollupsClient(client)
+2 -2
View File
@@ -117,7 +117,7 @@ class AsyncDocument(ObjectBase, metaclass=AsyncIndexMeta):
return "{}({})".format(
self.__class__.__name__,
", ".join(
"{}={!r}".format(key, getattr(self.meta, key))
f"{key}={getattr(self.meta, key)!r}"
for key in ("index", "id")
if key in self.meta
),
@@ -249,7 +249,7 @@ class AsyncDocument(ObjectBase, metaclass=AsyncIndexMeta):
raise RequestError(400, message, error_docs)
if missing_docs:
missing_ids = [doc["_id"] for doc in missing_docs]
message = "Documents %s not found." % ", ".join(missing_ids)
message = f"Documents {', '.join(missing_ids)} not found."
raise NotFoundError(404, message, {"docs": missing_docs})
return objs
+2 -2
View File
@@ -18,7 +18,7 @@ from opensearchpy.helpers import analysis
from opensearchpy.helpers.utils import merge
class AsyncIndexTemplate(object):
class AsyncIndexTemplate:
def __init__(
self,
name: Any,
@@ -57,7 +57,7 @@ class AsyncIndexTemplate(object):
)
class AsyncIndex(object):
class AsyncIndex:
def __init__(self, name: Any, using: Any = "default") -> None:
"""
:arg name: name of the index
@@ -31,7 +31,7 @@ class AsyncUpdateByQuery(Request):
overridden by methods (`using`, `index` and `doc_type` respectively).
"""
super(AsyncUpdateByQuery, self).__init__(**kwargs)
super().__init__(**kwargs)
self._response_class = UpdateByQueryResponse
self._script: Any = {}
self._query_proxy = QueryProxy(self, "query")
@@ -70,7 +70,7 @@ class AsyncUpdateByQuery(Request):
of all the underlying objects. Used internally by most state modifying
APIs.
"""
ubq = super(AsyncUpdateByQuery, self)._clone()
ubq = super()._clone()
ubq._response_class = self._response_class
ubq._script = self._script.copy()
+3 -3
View File
@@ -93,7 +93,7 @@ class AIOHttpConnection(AsyncConnection):
opaque_id: Optional[str] = None,
loop: Any = None,
trust_env: Optional[bool] = False,
**kwargs: Any
**kwargs: Any,
) -> None:
"""
Default connection class for ``AsyncOpenSearch`` using the `aiohttp` library and the http protocol.
@@ -140,7 +140,7 @@ class AIOHttpConnection(AsyncConnection):
headers=headers,
http_compress=http_compress,
opaque_id=opaque_id,
**kwargs
**kwargs,
)
if http_auth is not None:
@@ -276,7 +276,7 @@ class AIOHttpConnection(AsyncConnection):
else:
url = self.url_prefix + url
if query_string:
url = "%s?%s" % (url, query_string)
url = f"{url}?{query_string}"
url = self.host + url
timeout = aiohttp.ClientTimeout(
+1 -1
View File
@@ -121,7 +121,7 @@ class AsyncTransport(Transport):
self._async_init_called = False
self._sniff_on_start_event: Optional[asyncio.Event] = None
super(AsyncTransport, self).__init__(
super().__init__(
hosts=[],
connection_class=connection_class,
connection_pool_class=connection_pool_class,
+3 -5
View File
@@ -34,8 +34,6 @@
# -----------------------------------------------------------------------------------------+
from __future__ import unicode_literals
import logging
from typing import Any, Type
@@ -197,7 +195,7 @@ class OpenSearch(Client):
self,
hosts: Any = None,
transport_class: Type[Transport] = Transport,
**kwargs: Any
**kwargs: Any,
) -> None:
"""
:arg hosts: list of nodes, or a single node, we should connect to.
@@ -240,10 +238,10 @@ class OpenSearch(Client):
# truncate to 5 if there are too many
if len(cons) > 5:
cons = cons[:5] + ["..."]
return "<{cls}({cons})>".format(cls=self.__class__.__name__, cons=cons)
return f"<{self.__class__.__name__}({cons})>"
except Exception:
# probably operating on custom transport and connection_pool, ignore
return super(OpenSearch, self).__repr__()
return super().__repr__()
def __enter__(self) -> Any:
if hasattr(self.transport, "_async_call"):
+1 -1
View File
@@ -13,7 +13,7 @@ from opensearchpy.client.utils import _normalize_hosts
from opensearchpy.transport import Transport
class Client(object):
class Client:
"""
A generic async OpenSearch client.
"""
+1 -1
View File
@@ -15,7 +15,7 @@ from .utils import NamespacedClient
class HttpClient(NamespacedClient):
def __init__(self, client: Client) -> None:
super(HttpClient, self).__init__(client)
super().__init__(client)
def get(
self,
+1 -1
View File
@@ -26,7 +26,7 @@ class PluginsClient(NamespacedClient):
index_management: Any
def __init__(self, client: Client) -> None:
super(PluginsClient, self).__init__(client)
super().__init__(client)
self.ml = MlClient(client)
self.transforms = TransformsClient(client)
self.rollups = RollupsClient(client)
+5 -9
View File
@@ -25,8 +25,6 @@
# under the License.
from __future__ import unicode_literals
import base64
import weakref
from datetime import date, datetime
@@ -59,7 +57,7 @@ def _normalize_hosts(hosts: Any) -> Any:
for host in hosts:
if isinstance(host, string_types):
if "://" not in host:
host = "//%s" % host # type: ignore
host = f"//{host}" # type: ignore
parsed_url = urlparse(host)
h = {"host": parsed_url.hostname}
@@ -72,7 +70,7 @@ def _normalize_hosts(hosts: Any) -> Any:
h["use_ssl"] = True
if parsed_url.username or parsed_url.password:
h["http_auth"] = "%s:%s" % (
h["http_auth"] = "{}:{}".format(
unquote(parsed_url.username),
unquote(parsed_url.password),
)
@@ -160,11 +158,9 @@ def query_params(*opensearch_query_params: Any) -> Callable: # type: ignore
"Only one of 'http_auth' and 'api_key' may be passed at a time"
)
elif http_auth is not None:
headers["authorization"] = "Basic %s" % (
_base64_auth_header(http_auth),
)
headers["authorization"] = f"Basic {_base64_auth_header(http_auth)}"
elif api_key is not None:
headers["authorization"] = "ApiKey %s" % (_base64_auth_header(api_key),)
headers["authorization"] = f"ApiKey {_base64_auth_header(api_key)}"
# don't escape ignore, request_timeout, or timeout
for p in ("ignore", "request_timeout", "timeout"):
@@ -209,7 +205,7 @@ def _base64_auth_header(auth_value: Any) -> str:
return to_str(auth_value)
class NamespacedClient(object):
class NamespacedClient:
def __init__(self, client: Any) -> None:
self.client = client
+2 -2
View File
@@ -68,7 +68,7 @@ class AsyncConnections:
errors += 1
if errors == 2:
raise KeyError("There is no connection with alias %r." % alias)
raise KeyError(f"There is no connection with alias {alias!r}.")
async def create_connection(self, alias: str = "default", **kwargs: Any) -> Any:
"""
@@ -104,7 +104,7 @@ class AsyncConnections:
return await self.create_connection(alias, **self._kwargs[alias])
except KeyError:
# no connection and no kwargs to set one up
raise KeyError("There is no connection with alias %r." % alias)
raise KeyError(f"There is no connection with alias {alias!r}.")
async_connections = AsyncConnections()
+9 -9
View File
@@ -53,7 +53,7 @@ if not TRACER_ALREADY_CONFIGURED:
_WARNING_RE = re.compile(r"\"([^\"]*)\"")
class Connection(object):
class Connection:
"""
Class responsible for maintaining a connection to an OpenSearch node. It
holds persistent connection pool to it and its main interface
@@ -81,7 +81,7 @@ class Connection(object):
headers: Optional[Dict[str, str]] = None,
http_compress: Optional[bool] = None,
opaque_id: Optional[str] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
if port is None:
port = 9200
@@ -119,27 +119,27 @@ class Connection(object):
self.hostname = host
self.port = port
if ":" in host: # IPv6
self.host = "%s://[%s]" % (scheme, host)
self.host = f"{scheme}://[{host}]"
else:
self.host = "%s://%s" % (scheme, host)
self.host = f"{scheme}://{host}"
if self.port is not None:
self.host += ":%s" % self.port
self.host += f":{self.port}"
if url_prefix:
url_prefix = "/" + url_prefix.strip("/")
self.url_prefix = url_prefix
self.timeout = timeout
def __repr__(self) -> str:
return "<%s: %s>" % (self.__class__.__name__, self.host)
return f"<{self.__class__.__name__}: {self.host}>"
def __eq__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise TypeError("Unsupported equality check for %s and %s" % (self, other))
raise TypeError(f"Unsupported equality check for {self} and {other}")
return self.__hash__() == other.__hash__()
def __lt__(self, other: object) -> bool:
if not isinstance(other, Connection):
raise TypeError("Unsupported lt check for %s and %s" % (self, other))
raise TypeError(f"Unsupported lt check for {self} and {other}")
return self.__hash__() < other.__hash__()
def __hash__(self) -> int:
@@ -317,7 +317,7 @@ class Connection(object):
)
def _get_default_user_agent(self) -> str:
return "opensearch-py/%s (Python %s)" % (__versionstr__, python_version())
return f"opensearch-py/{__versionstr__} (Python {python_version()})"
@staticmethod
def default_ca_certs() -> Union[str, None]:
+2 -2
View File
@@ -82,7 +82,7 @@ class Connections:
errors += 1
if errors == 2:
raise KeyError("There is no connection with alias %r." % alias)
raise KeyError(f"There is no connection with alias {alias!r}.")
def create_connection(self, alias: str = "default", **kwargs: Any) -> Any:
"""
@@ -118,7 +118,7 @@ class Connections:
return self.create_connection(alias, **self._kwargs[alias])
except KeyError:
# no connection and no kwargs to set one up
raise KeyError("There is no connection with alias %r." % alias)
raise KeyError(f"There is no connection with alias {alias!r}.")
connections = Connections()
+3 -3
View File
@@ -52,7 +52,7 @@ class AsyncHttpConnection(AIOHttpConnection):
http_compress: Optional[bool] = None,
opaque_id: Optional[str] = None,
loop: Any = None,
**kwargs: Any
**kwargs: Any,
) -> None:
self.headers = {}
@@ -63,7 +63,7 @@ class AsyncHttpConnection(AIOHttpConnection):
headers=headers,
http_compress=http_compress,
opaque_id=opaque_id,
**kwargs
**kwargs,
)
if http_auth is not None:
@@ -186,7 +186,7 @@ class AsyncHttpConnection(AIOHttpConnection):
# then we pass a string into ClientSession.request() instead.
url = self.url_prefix + url
if query_string:
url = "%s?%s" % (url, query_string)
url = f"{url}?{query_string}"
url = self.host + url
timeout = aiohttp.ClientTimeout(
+5 -8
View File
@@ -92,7 +92,7 @@ class RequestsHttpConnection(Connection):
opaque_id: Any = None,
pool_maxsize: Any = None,
metrics: Metrics = MetricsNone(),
**kwargs: Any
**kwargs: Any,
) -> None:
self.metrics = metrics
if not REQUESTS_AVAILABLE:
@@ -111,14 +111,14 @@ class RequestsHttpConnection(Connection):
self.session.mount("http://", pool_adapter)
self.session.mount("https://", pool_adapter)
super(RequestsHttpConnection, self).__init__(
super().__init__(
host=host,
port=port,
use_ssl=use_ssl,
headers=headers,
http_compress=http_compress,
opaque_id=opaque_id,
**kwargs
**kwargs,
)
if not self.http_compress:
@@ -132,10 +132,7 @@ class RequestsHttpConnection(Connection):
http_auth = tuple(http_auth.split(":", 1)) # type: ignore
self.session.auth = http_auth
self.base_url = "%s%s" % (
self.host,
self.url_prefix,
)
self.base_url = f"{self.host}{self.url_prefix}"
self.session.verify = verify_certs
if not client_key:
self.session.cert = client_cert
@@ -176,7 +173,7 @@ class RequestsHttpConnection(Connection):
url = self.base_url + url
headers = headers or {}
if params:
url = "%s?%s" % (url, urlencode(params or {}))
url = f"{url}?{urlencode(params or {})}"
orig_body = body
if self.http_compress and body:
+4 -4
View File
@@ -121,20 +121,20 @@ class Urllib3HttpConnection(Connection):
http_compress: Any = None,
opaque_id: Any = None,
metrics: Metrics = MetricsNone(),
**kwargs: Any
**kwargs: Any,
) -> None:
self.metrics = metrics
# Initialize headers before calling super().__init__().
self.headers = urllib3.make_headers(keep_alive=True)
super(Urllib3HttpConnection, self).__init__(
super().__init__(
host=host,
port=port,
use_ssl=use_ssl,
headers=headers,
http_compress=http_compress,
opaque_id=opaque_id,
**kwargs
**kwargs,
)
self.http_auth = http_auth
@@ -245,7 +245,7 @@ class Urllib3HttpConnection(Connection):
url = self.url_prefix + url
if params:
url = "%s?%s" % (url, urlencode(params))
url = f"{url}?{urlencode(params)}"
full_url = self.host + url
+1 -1
View File
@@ -47,7 +47,7 @@ class PoolingConnection(Connection):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self._free_connections = queue.Queue()
super(PoolingConnection, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
def _make_connection(self) -> Connection:
raise NotImplementedError
+5 -5
View File
@@ -38,7 +38,7 @@ from .exceptions import ImproperlyConfigured
logger: logging.Logger = logging.getLogger("opensearch")
class ConnectionSelector(object):
class ConnectionSelector:
"""
Simple class used to select a connection from a list of currently live
connection instances. In init time it is passed a dictionary containing all
@@ -87,7 +87,7 @@ class RoundRobinSelector(ConnectionSelector):
"""
def __init__(self, opts: Sequence[Tuple[Connection, Any]]) -> None:
super(RoundRobinSelector, self).__init__(opts)
super().__init__(opts)
self.data = threading.local()
def select(self, connections: Sequence[Connection]) -> Any:
@@ -96,7 +96,7 @@ class RoundRobinSelector(ConnectionSelector):
return connections[self.data.rr]
class ConnectionPool(object):
class ConnectionPool:
"""
Container holding the :class:`~opensearchpy.Connection` instances,
managing the selection process (via a
@@ -135,7 +135,7 @@ class ConnectionPool(object):
timeout_cutoff: int = 5,
selector_class: Type[ConnectionSelector] = RoundRobinSelector,
randomize_hosts: bool = True,
**kwargs: Any
**kwargs: Any,
) -> None:
"""
:arg connections: list of tuples containing the
@@ -290,7 +290,7 @@ class ConnectionPool(object):
conn.close()
def __repr__(self) -> str:
return "<%s: %r>" % (type(self).__name__, self.connections)
return f"<{type(self).__name__}: {self.connections!r}>"
class DummyConnectionPool(ConnectionPool):
+3 -3
View File
@@ -120,7 +120,7 @@ class TransportError(OpenSearchException):
except LookupError:
pass
msg = ", ".join(filter(None, [str(self.status_code), repr(self.error), cause]))
return "%s(%s)" % (self.__class__.__name__, msg)
return f"{self.__class__.__name__}({msg})"
class ConnectionError(TransportError):
@@ -131,7 +131,7 @@ class ConnectionError(TransportError):
"""
def __str__(self) -> str:
return "ConnectionError(%s) caused by: %s(%s)" % (
return "ConnectionError({}) caused by: {}({})".format(
self.error,
self.info.__class__.__name__,
self.info,
@@ -146,7 +146,7 @@ class ConnectionTimeout(ConnectionError):
"""A network timeout. Doesn't cause a node retry by default."""
def __str__(self) -> str:
return "ConnectionTimeout caused by - %s(%s)" % (
return "ConnectionTimeout caused by - {}({})".format(
self.info.__class__.__name__,
self.info,
)
+14 -17
View File
@@ -198,7 +198,7 @@ def _process_bulk_chunk_success(
yield ok, {op_type: item}
if errors:
raise BulkIndexError("%i document(s) failed to index." % len(errors), errors)
raise BulkIndexError(f"{len(errors)} document(s) failed to index.", errors)
def _process_bulk_chunk_error(
@@ -228,7 +228,7 @@ def _process_bulk_chunk_error(
# emulate standard behavior for failed actions
if raise_on_error and error.status_code not in ignore_status:
raise BulkIndexError(
"%i document(s) failed to index." % len(exc_errors), exc_errors
f"{len(exc_errors)} document(s) failed to index.", exc_errors
)
else:
for err in exc_errors:
@@ -243,7 +243,7 @@ def _process_bulk_chunk(
raise_on_error: bool = True,
ignore_status: Any = (),
*args: Any,
**kwargs: Any
**kwargs: Any,
) -> Any:
"""
Send a bulk request to opensearch and process the output.
@@ -269,8 +269,7 @@ def _process_bulk_chunk(
ignore_status=ignore_status,
raise_on_error=raise_on_error,
)
for item in gen:
yield item
yield from gen
def streaming_bulk(
@@ -287,7 +286,7 @@ def streaming_bulk(
yield_ok: bool = True,
ignore_status: Any = (),
*args: Any,
**kwargs: Any
**kwargs: Any,
) -> Any:
"""
Streaming bulk consumes actions from the iterable passed in and yields
@@ -344,7 +343,7 @@ def streaming_bulk(
raise_on_error,
ignore_status,
*args,
**kwargs
**kwargs,
),
):
if not ok:
@@ -384,7 +383,7 @@ def bulk(
stats_only: bool = False,
ignore_status: Any = (),
*args: Any,
**kwargs: Any
**kwargs: Any,
) -> Any:
"""
Helper for the :meth:`~opensearchpy.OpenSearch.bulk` api that provides
@@ -445,7 +444,7 @@ def parallel_bulk(
raise_on_error: bool = True,
ignore_status: Any = (),
*args: Any,
**kwargs: Any
**kwargs: Any,
) -> Any:
"""
Parallel version of the bulk helper run in multiple threads at once.
@@ -474,7 +473,7 @@ def parallel_bulk(
class BlockingPool(ThreadPool):
def _setup_queues(self) -> None:
super(BlockingPool, self)._setup_queues() # type: ignore
super()._setup_queues() # type: ignore
# The queue must be at least the size of the number of threads to
# prevent hanging when inserting sentinel values during teardown.
self._inqueue: Any = Queue(max(queue_size, thread_count))
@@ -493,15 +492,14 @@ def parallel_bulk(
raise_on_error,
ignore_status,
*args,
**kwargs
**kwargs,
)
),
_chunk_actions(
actions, chunk_size, max_chunk_bytes, client.transport.serializer
),
):
for item in result:
yield item
yield from result
finally:
pool.close()
@@ -518,7 +516,7 @@ def scan(
request_timeout: Optional[float] = None,
clear_scroll: Optional[bool] = True,
scroll_kwargs: Any = None,
**kwargs: Any
**kwargs: Any,
) -> Any:
"""
Simple abstraction on top of the
@@ -587,8 +585,7 @@ def scan(
try:
while scroll_id and resp.get("hits", {}).get("hits"):
for hit in resp.get("hits", {}).get("hits", []):
yield hit
yield from resp.get("hits", {}).get("hits", [])
_shards = resp.get("_shards")
@@ -687,5 +684,5 @@ def reindex(
target_client,
_change_doc_index(docs, target_index),
chunk_size=chunk_size,
**kwargs
**kwargs,
)
+5 -5
View File
@@ -89,7 +89,7 @@ class Agg(DslBase):
return False
def to_dict(self) -> Any:
d = super(Agg, self).to_dict()
d = super().to_dict()
if "meta" in d[self.name]:
d["meta"] = d[self.name].pop("meta")
return d
@@ -98,7 +98,7 @@ class Agg(DslBase):
return AggResponse(self, search, data)
class AggBase(object):
class AggBase:
_param_defs = {
"aggs": {"type": "agg", "hash": True},
}
@@ -151,7 +151,7 @@ class AggBase(object):
class Bucket(AggBase, Agg):
def __init__(self, **params: Any) -> None:
super(Bucket, self).__init__(**params)
super().__init__(**params)
# remember self for chaining
self._base = self
@@ -172,10 +172,10 @@ class Filter(Bucket):
def __init__(self, filter: Any = None, **params: Any) -> None:
if filter is not None:
params["filter"] = filter
super(Filter, self).__init__(**params)
super().__init__(**params)
def to_dict(self) -> Any:
d = super(Filter, self).to_dict()
d = super().to_dict()
d[self.name].update(d[self.name].pop("filter", {}))
return d
+1 -1
View File
@@ -38,7 +38,7 @@ class AnalysisBase:
) -> Any:
if isinstance(name_or_instance, cls):
if type or kwargs:
raise ValueError("%s() cannot accept parameters." % cls.__name__)
raise ValueError(f"{cls.__name__}() cannot accept parameters.")
return name_or_instance
if not (type or kwargs):
+2 -2
View File
@@ -188,7 +188,7 @@ class Document(ObjectBase, metaclass=IndexMeta):
return "{}({})".format(
self.__class__.__name__,
", ".join(
"{}={!r}".format(key, getattr(self.meta, key))
f"{key}={getattr(self.meta, key)!r}"
for key in ("index", "id")
if key in self.meta
),
@@ -310,7 +310,7 @@ class Document(ObjectBase, metaclass=IndexMeta):
raise RequestError(400, message, error_docs)
if missing_docs:
missing_ids = [doc["_id"] for doc in missing_docs]
message = "Documents %s not found." % ", ".join(missing_ids)
message = f"Documents {', '.join(missing_ids)} not found."
raise NotFoundError(404, message, {"docs": missing_docs})
return objs
+1 -1
View File
@@ -41,5 +41,5 @@ class ScanError(OpenSearchException):
scroll_id: str
def __init__(self, scroll_id: str, *args: Any, **kwargs: Any) -> None:
super(ScanError, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.scroll_id = scroll_id
+3 -3
View File
@@ -157,7 +157,7 @@ class Object(Field):
doc_class: Any = None,
dynamic: Any = None,
properties: Any = None,
**kwargs: Any
**kwargs: Any,
) -> None:
"""
:arg document.InnerDoc doc_class: base doc class that handles mapping.
@@ -281,7 +281,7 @@ class Date(Field):
data = parser.parse(data)
except Exception as e:
raise ValidationException(
"Could not parse date from the value (%r)" % data, e
f"Could not parse date from the value ({data!r})", e
)
if isinstance(data, datetime):
@@ -294,7 +294,7 @@ class Date(Field):
# Divide by a float to preserve milliseconds on the datetime.
return datetime.utcfromtimestamp(data / 1000.0)
raise ValidationException("Could not parse date from the value (%r)" % data)
raise ValidationException(f"Could not parse date from the value ({data!r})")
class Text(Field):
+3 -3
View File
@@ -48,7 +48,7 @@ def SF(name_or_sf: Any, **params: Any) -> Any: # pylint: disable=invalid-name
elif len(sf) == 1:
name, params = sf.popitem()
else:
raise ValueError("SF() got an unexpected fields in the dictionary: %r" % sf)
raise ValueError(f"SF() got an unexpected fields in the dictionary: {sf!r}")
# boost factor special case, see https://github.com/elastic/elasticsearch/issues/6343
if not isinstance(params, collections_abc.Mapping):
@@ -81,7 +81,7 @@ class ScoreFunction(DslBase):
name: Optional[str] = None
def to_dict(self) -> Any:
d = super(ScoreFunction, self).to_dict()
d = super().to_dict()
# filter and query dicts should be at the same level as us
for k in self._param_defs:
if k in d[self.name]:
@@ -97,7 +97,7 @@ class BoostFactor(ScoreFunction):
name = "boost_factor"
def to_dict(self) -> Any:
d = super(BoostFactor, self).to_dict()
d = super().to_dict()
if "value" in d[self.name]:
d[self.name] = d[self.name].pop("value")
else:
+2 -2
View File
@@ -37,7 +37,7 @@ from .update_by_query import UpdateByQuery
from .utils import merge
class IndexTemplate(object):
class IndexTemplate:
def __init__(
self,
name: Any,
@@ -76,7 +76,7 @@ class IndexTemplate(object):
)
class Index(object):
class Index:
def __init__(self, name: Any, using: Any = "default") -> None:
"""
:arg name: name of the index
+1 -1
View File
@@ -261,7 +261,7 @@ class FunctionScore(Query):
for name in ScoreFunction._classes:
if name in kwargs:
fns.append({name: kwargs.pop(name)})
super(FunctionScore, self).__init__(**kwargs)
super().__init__(**kwargs)
# compound queries
+5 -5
View File
@@ -34,7 +34,7 @@ class Response(AttrDict):
def __init__(self, search: Any, response: Any, doc_class: Any = None) -> None:
super(AttrDict, self).__setattr__("_search", search)
super(AttrDict, self).__setattr__("_doc_class", doc_class)
super(Response, self).__init__(response)
super().__init__(response)
def __iter__(self) -> Any:
return iter(self.hits)
@@ -43,7 +43,7 @@ class Response(AttrDict):
if isinstance(key, (slice, int)):
# for slicing etc
return self.hits[key]
return super(Response, self).__getitem__(key)
return super().__getitem__(key)
def __nonzero__(self) -> Any:
return bool(self.hits)
@@ -103,14 +103,14 @@ class Response(AttrDict):
class AggResponse(AttrDict):
def __init__(self, aggs: Any, search: Any, data: Any) -> None:
super(AttrDict, self).__setattr__("_meta", {"search": search, "aggs": aggs})
super(AggResponse, self).__init__(data)
super().__init__(data)
def __getitem__(self, attr_name: Any) -> Any:
if attr_name in self._meta["aggs"]:
# don't do self._meta['aggs'][attr_name] to avoid copying
agg = self._meta["aggs"].aggs[attr_name]
return agg.result(self._meta["search"], self._d_[attr_name])
return super(AggResponse, self).__getitem__(attr_name)
return super().__getitem__(attr_name)
def __iter__(self) -> Any:
for name in self._meta["aggs"]:
@@ -121,7 +121,7 @@ class UpdateByQueryResponse(AttrDict):
def __init__(self, search: Any, response: Any, doc_class: Any = None) -> None:
super(AttrDict, self).__setattr__("_search", search)
super(AttrDict, self).__setattr__("_doc_class", doc_class)
super(UpdateByQueryResponse, self).__init__(response)
super().__init__(response)
def success(self) -> bool:
return not self.timed_out and not self.failures
+4 -4
View File
@@ -32,14 +32,14 @@ from . import AggResponse, Response
class Bucket(AggResponse):
def __init__(self, aggs: Any, search: Any, data: Any, field: Any = None) -> None:
super(Bucket, self).__init__(aggs, search, data)
super().__init__(aggs, search, data)
class FieldBucket(Bucket):
def __init__(self, aggs: Any, search: Any, data: Any, field: Any = None) -> None:
if field:
data["key"] = field.deserialize(data["key"])
super(FieldBucket, self).__init__(aggs, search, data, field)
super().__init__(aggs, search, data, field)
class BucketData(AggResponse):
@@ -62,7 +62,7 @@ class BucketData(AggResponse):
def __getitem__(self, key: Any) -> Any:
if isinstance(key, (int, slice)):
return self.buckets[key]
return super(BucketData, self).__getitem__(key)
return super().__getitem__(key)
@property
def buckets(self) -> Any:
@@ -88,7 +88,7 @@ class TopHitsData(Response):
super(AttrDict, self).__setattr__(
"meta", AttrDict({"agg": agg, "search": search})
)
super(TopHitsData, self).__init__(search, data)
super().__init__(search, data)
__all__ = ["AggResponse"]
+5 -5
View File
@@ -37,28 +37,28 @@ class Hit(AttrDict):
if "fields" in document:
data.update(document["fields"])
super(Hit, self).__init__(data)
super().__init__(data)
# assign meta as attribute and not as key in self._d_
super(AttrDict, self).__setattr__("meta", HitMeta(document))
def __getstate__(self) -> Any:
# add self.meta since it is not in self.__dict__
return super(Hit, self).__getstate__() + (self.meta,)
return super().__getstate__() + (self.meta,)
def __setstate__(self, state: Any) -> None:
super(AttrDict, self).__setattr__("meta", state[-1])
super(Hit, self).__setstate__(state[:-1])
super().__setstate__(state[:-1])
def __dir__(self) -> Any:
# be sure to expose meta in dir(self)
return super(Hit, self).__dir__() + ["meta"]
return super().__dir__() + ["meta"]
def __repr__(self) -> str:
return "<Hit({}): {}>".format(
"/".join(
getattr(self.meta, key) for key in ("index", "id") if key in self.meta
),
super(Hit, self).__repr__(),
super().__repr__(),
)
+1 -1
View File
@@ -96,7 +96,7 @@ class ProxyDescriptor:
"""
def __init__(self, name: str) -> None:
self._attr_name = "_%s_proxy" % name
self._attr_name = f"_{name}_proxy"
def __get__(self, instance: Any, owner: Any) -> Any:
return getattr(instance, self._attr_name)
+3 -3
View File
@@ -116,9 +116,9 @@ class RequestsAWSV4SignerAuth(requests.auth.AuthBase):
)
# fetch the host information from headers
headers = dict(
(key.lower(), value) for key, value in prepared_request.headers.items()
)
headers = {
key.lower(): value for key, value in prepared_request.headers.items()
}
location = headers.get("host") or url.netloc
# construct the url and return
+2 -2
View File
@@ -49,7 +49,7 @@ class UpdateByQuery(Request):
overridden by methods (`using`, `index` and `doc_type` respectively).
"""
super(UpdateByQuery, self).__init__(**kwargs)
super().__init__(**kwargs)
self._response_class = UpdateByQueryResponse
self._script: Any = {}
self._query_proxy = QueryProxy(self, "query")
@@ -88,7 +88,7 @@ class UpdateByQuery(Request):
of all the underlying objects. Used internally by most state modifying
APIs.
"""
ubq = super(UpdateByQuery, self)._clone()
ubq = super()._clone()
ubq._response_class = self._response_class
ubq._script = self._script.copy()
+11 -23
View File
@@ -163,9 +163,7 @@ class AttrDict:
return self.__getitem__(attr_name)
except KeyError:
raise AttributeError(
"{!r} object has no attribute {!r}".format(
self.__class__.__name__, attr_name
)
f"{self.__class__.__name__!r} object has no attribute {attr_name!r}"
)
def get(self, key: Any, default: Any = None) -> Any:
@@ -181,9 +179,7 @@ class AttrDict:
del self._d_[attr_name]
except KeyError:
raise AttributeError(
"{!r} object has no attribute {!r}".format(
self.__class__.__name__, attr_name
)
f"{self.__class__.__name__!r} object has no attribute {attr_name!r}"
)
def __getitem__(self, key: Any) -> Any:
@@ -245,7 +241,7 @@ class DslMeta(type):
try:
return cls._types[name]
except KeyError:
raise UnknownDslObject("DSL type %s does not exist." % name)
raise UnknownDslObject(f"DSL type {name} does not exist.")
class DslBase(metaclass=DslMeta):
@@ -275,7 +271,7 @@ class DslBase(metaclass=DslMeta):
if default is not None:
return cls._classes[default]
raise UnknownDslObject(
"DSL class `{}` does not exist in {}.".format(name, cls._type_name)
f"DSL class `{name}` does not exist in {cls._type_name}."
)
def __init__(self, _expand__to_dot: Any = EXPAND__TO_DOT, **params: Any) -> None:
@@ -288,14 +284,14 @@ class DslBase(metaclass=DslMeta):
def _repr_params(self) -> str:
"""Produce a repr of all our parameters to be used in __repr__."""
return ", ".join(
"{}={!r}".format(n.replace(".", "__"), v)
f"{n.replace('.', '__')}={v!r}"
for (n, v) in sorted(self._params.items())
# make sure we don't include empty typed params
if "type" not in self._param_defs.get(n, {}) or v
)
def __repr__(self) -> str:
return "{}({})".format(self.__class__.__name__, self._repr_params())
return f"{self.__class__.__name__}({self._repr_params()})"
def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and other.to_dict() == self.to_dict()
@@ -341,9 +337,7 @@ class DslBase(metaclass=DslMeta):
def __getattr__(self, name: str) -> Any:
if name.startswith("_"):
raise AttributeError(
"{!r} object has no attribute {!r}".format(
self.__class__.__name__, name
)
f"{self.__class__.__name__!r} object has no attribute {name!r}"
)
value = None
@@ -360,9 +354,7 @@ class DslBase(metaclass=DslMeta):
value = self._params.setdefault(name, {})
if value is None:
raise AttributeError(
"{!r} object has no attribute {!r}".format(
self.__class__.__name__, name
)
f"{self.__class__.__name__!r} object has no attribute {name!r}"
)
# wrap nested dicts in AttrDict for convenient access
@@ -567,16 +559,12 @@ def merge(data: Any, new_data: Any, raise_on_conflict: bool = False) -> None:
and isinstance(new_data, (AttrDict, collections_abc.Mapping))
):
raise ValueError(
"You can only merge two dicts! Got {!r} and {!r} instead.".format(
data, new_data
)
f"You can only merge two dicts! Got {data!r} and {new_data!r} instead."
)
if not isinstance(new_data, Dict):
raise ValueError(
"You can only merge two dicts! Got {!r} and {!r} instead.".format(
data, new_data
)
f"You can only merge two dicts! Got {data!r} and {new_data!r} instead."
)
for key, value in new_data.items():
@@ -587,7 +575,7 @@ def merge(data: Any, new_data: Any, raise_on_conflict: bool = False) -> None:
):
merge(data[key], value, raise_on_conflict)
elif key in data and data[key] != value and raise_on_conflict:
raise ValueError("Incompatible data for key %r, cannot be merged." % key)
raise ValueError(f"Incompatible data for key {key!r}, cannot be merged.")
else:
data[key] = value # type: ignore
+1 -1
View File
@@ -47,7 +47,7 @@ class Range(AttrDict):
for k in data:
if k not in self.OPS:
raise ValueError("Range received an unknown operator %r" % k)
raise ValueError(f"Range received an unknown operator {k!r}")
if "gt" in data and "gte" in data:
raise ValueError("You cannot specify both gt and gte for Range.")
+7 -7
View File
@@ -45,7 +45,7 @@ FLOAT_TYPES = (Decimal,)
TIME_TYPES = (date, datetime)
class Serializer(object):
class Serializer:
mimetype: str = ""
def loads(self, s: str) -> Any:
@@ -65,7 +65,7 @@ class TextSerializer(Serializer):
if isinstance(data, string_types):
return data
raise SerializationError("Cannot serialize %r into text." % data)
raise SerializationError(f"Cannot serialize {data!r} into text.")
class JSONSerializer(Serializer):
@@ -140,7 +140,7 @@ class JSONSerializer(Serializer):
except ImportError:
pass
raise TypeError("Unable to serialize %r (type: %s)" % (data, type(data)))
raise TypeError(f"Unable to serialize {data!r} (type: {type(data)})")
def loads(self, s: str) -> Any:
try:
@@ -167,7 +167,7 @@ DEFAULT_SERIALIZERS: Dict[str, Serializer] = {
}
class Deserializer(object):
class Deserializer:
def __init__(
self,
serializers: Dict[str, Serializer],
@@ -177,7 +177,7 @@ class Deserializer(object):
self.default = serializers[default_mimetype]
except KeyError:
raise ImproperlyConfigured(
"Cannot find default serializer (%s)" % default_mimetype
f"Cannot find default serializer ({default_mimetype})"
)
self.serializers = serializers
@@ -196,7 +196,7 @@ class Deserializer(object):
deserializer = self.serializers[mimetype]
except KeyError:
raise SerializationError(
"Unknown mimetype, unable to deserialize: %s" % mimetype
f"Unknown mimetype, unable to deserialize: {mimetype}"
)
return deserializer.loads(s)
@@ -208,7 +208,7 @@ class AttrJSONSerializer(JSONSerializer):
return data._l_
if hasattr(data, "to_dict"):
return data.to_dict()
return super(AttrJSONSerializer, self).default(data)
return super().default(data)
serializer = AttrJSONSerializer()
+1 -1
View File
@@ -64,7 +64,7 @@ def get_host_info(
return host
class Transport(object):
class Transport:
"""
Encapsulation of transport-related to logic. Handles instantiation of the
individual connections as well as creating a connection pool to hold them.
-1
View File
@@ -64,7 +64,6 @@ install_requires = [
tests_require = [
"requests>=2.0.0, <3.0.0",
"coverage<8.0.0",
"mock",
"pyyaml",
"pytest>=3.0.0",
"pytest-cov",
+8 -10
View File
@@ -31,8 +31,6 @@
# under the License.
from __future__ import print_function
import subprocess
import sys
from os import environ
@@ -78,10 +76,10 @@ def fetch_opensearch_repo() -> None:
# no test directory
if not exists(repo_path):
subprocess.check_call("mkdir %s" % repo_path, shell=True)
subprocess.check_call(f"mkdir {repo_path}", shell=True)
# make a new blank repository in the test directory
subprocess.check_call("cd %s && git init" % repo_path, shell=True)
subprocess.check_call(f"cd {repo_path} && git init", shell=True)
try:
# add a remote
@@ -104,7 +102,7 @@ def fetch_opensearch_repo() -> None:
# fetch the sha commit, version from info()
print("Fetching opensearch repo...")
subprocess.check_call("cd %s && git fetch origin %s" % (repo_path, sha), shell=True)
subprocess.check_call(f"cd {repo_path} && git fetch origin {sha}", shell=True)
def run_all(argv: Any = None) -> None:
@@ -136,18 +134,18 @@ def run_all(argv: Any = None) -> None:
argv = [
"pytest",
"--cov=opensearchpy",
"--junitxml=%s" % junit_xml,
f"--junitxml={junit_xml}",
"--log-level=DEBUG",
"--cache-clear",
"-vv",
"--cov-report=xml:%s" % codecov_xml,
f"--cov-report=xml:{codecov_xml}",
]
if (
"OPENSEARCHPY_GEN_HTML_COV" in environ
and environ.get("OPENSEARCHPY_GEN_HTML_COV") == "true"
):
codecov_html = join(abspath(dirname(dirname(__file__))), "junit", "html")
argv.append("--cov-report=html:%s" % codecov_html)
argv.append(f"--cov-report=html:{codecov_html}")
secured = False
if environ.get("OPENSEARCH_URL", "").startswith("https://"):
@@ -156,7 +154,7 @@ def run_all(argv: Any = None) -> None:
# check TEST_PATTERN env var for specific test to run
test_pattern = environ.get("TEST_PATTERN")
if test_pattern:
argv.append("-k %s" % test_pattern)
argv.append(f"-k {test_pattern}")
else:
ignores = [
"test_opensearchpy/test_server/",
@@ -196,7 +194,7 @@ def run_all(argv: Any = None) -> None:
)
if ignores:
argv.extend(["--ignore=%s" % ignore for ignore in ignores])
argv.extend([f"--ignore={ignore}" for ignore in ignores])
# Not in CI, run all tests specified.
else:
@@ -32,11 +32,11 @@ import ssl
import warnings
from platform import python_version
from typing import Any
from unittest.mock import MagicMock, patch
import aiohttp
import pytest
from _pytest.mark.structures import MarkDecorator
from mock import MagicMock, patch
from multidict import CIMultiDict
from pytest import raises
@@ -154,7 +154,7 @@ class TestAIOHttpConnection:
async def test_default_user_agent(self) -> None:
con = AIOHttpConnection()
assert con._get_default_user_agent() == "opensearch-py/%s (Python %s)" % (
assert con._get_default_user_agent() == "opensearch-py/{} (Python {})".format(
__versionstr__,
python_version(),
)
@@ -342,7 +342,7 @@ class TestAIOHttpConnection:
buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"
con = await self._get_mock_connection(response_body=buf)
_, _, data = await con.perform_request("GET", "/")
assert u"你好\uda6a" == data # fmt: skip
assert "你好\uda6a" == data # fmt: skip
@pytest.mark.parametrize("exception_cls", reraise_exceptions) # type: ignore
async def test_recursion_error_reraised(self, exception_cls: Any) -> None:
@@ -9,10 +9,10 @@
from typing import Any
from unittest.mock import Mock
import pytest
from _pytest.mark.structures import MarkDecorator
from mock import Mock
from pytest import fixture
from opensearchpy.connection.async_connections import add_connection, async_connections
@@ -7,7 +7,6 @@
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
from __future__ import unicode_literals
import codecs
import ipaddress
@@ -73,7 +73,7 @@ async def test_cloned_index_has_analysis_attribute() -> None:
client = object()
i = AsyncIndex("my-index", using=client)
random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100)))
random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100))
random_analyzer = analyzer(
random_analyzer_name, tokenizer="standard", filter="standard"
)
@@ -117,7 +117,7 @@ async def test_registered_doc_type_included_in_search() -> None:
async def test_aliases_add_to_object() -> None:
random_alias = "".join((choice(string.ascii_letters) for _ in range(100)))
random_alias = "".join(choice(string.ascii_letters) for _ in range(100))
alias_dict: Any = {random_alias: {}}
index = AsyncIndex("i", using="alias")
@@ -127,7 +127,7 @@ async def test_aliases_add_to_object() -> None:
async def test_aliases_returned_from_to_dict() -> None:
random_alias = "".join((choice(string.ascii_letters) for _ in range(100)))
random_alias = "".join(choice(string.ascii_letters) for _ in range(100))
alias_dict: Any = {random_alias: {}}
index = AsyncIndex("i", using="alias")
@@ -137,7 +137,7 @@ async def test_aliases_returned_from_to_dict() -> None:
async def test_analyzers_added_to_object() -> None:
random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100)))
random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100))
random_analyzer = analyzer(
random_analyzer_name, tokenizer="standard", filter="standard"
)
@@ -153,7 +153,7 @@ async def test_analyzers_added_to_object() -> None:
async def test_analyzers_returned_from_to_dict() -> None:
random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100)))
random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100))
random_analyzer = analyzer(
random_analyzer_name, tokenizer="standard", filter="standard"
)
@@ -26,8 +26,8 @@
from typing import Any
from unittest import mock
import mock
import pytest
from multidict import CIMultiDict
@@ -25,8 +25,6 @@
# under the License.
from __future__ import unicode_literals
from typing import Any
import pytest
@@ -27,9 +27,9 @@
import asyncio
from typing import Any, List
from unittest.mock import MagicMock, patch
import pytest
from mock import MagicMock, patch
from opensearchpy import TransportError
from opensearchpy._async.helpers import actions
@@ -40,13 +40,13 @@ pytestmark = pytest.mark.asyncio
class AsyncMock(MagicMock):
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
return super(AsyncMock, self).__call__(*args, **kwargs)
return super().__call__(*args, **kwargs)
def __await__(self) -> Any:
return self().__await__()
class FailingBulkClient(object):
class FailingBulkClient:
def __init__(
self,
client: Any,
@@ -69,7 +69,7 @@ class FailingBulkClient(object):
return await self.client.bulk(*args, **kwargs)
class TestStreamingBulk(object):
class TestStreamingBulk:
async def test_actions_remain_unchanged(self, async_client: Any) -> None:
actions1 = [{"_id": 1}, {"_id": 2}]
async for ok, _ in actions.async_streaming_bulk(
@@ -281,7 +281,7 @@ class TestStreamingBulk(object):
assert 4 == failing_client._called
class TestBulk(object):
class TestBulk:
async def test_bulk_works_with_single_item(self, async_client: Any) -> None:
docs = [{"answer": 42, "_id": 1}]
success, failed = await actions.async_bulk(
@@ -453,7 +453,7 @@ async def scan_teardown(async_client: Any) -> Any:
await async_client.clear_scroll(scroll_id="_all")
class TestScan(object):
class TestScan:
async def test_order_can_be_preserved(
self, async_client: Any, scan_teardown: Any
) -> None:
@@ -492,8 +492,8 @@ class TestScan(object):
]
assert 100 == len(docs)
assert set(map(str, range(100))) == set(d["_id"] for d in docs)
assert set(range(100)) == set(d["_source"]["answer"] for d in docs)
assert set(map(str, range(100))) == {d["_id"] for d in docs}
assert set(range(100)) == {d["_source"]["answer"] for d in docs}
async def test_scroll_error(self, async_client: Any, scan_teardown: Any) -> None:
bulk: Any = []
@@ -824,7 +824,7 @@ async def reindex_setup(async_client: Any) -> Any:
yield
class TestReindex(object):
class TestReindex:
async def test_reindex_passes_kwargs_to_scan_and_bulk(
self, async_client: Any, reindex_setup: Any
) -> None:
@@ -7,7 +7,6 @@
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
from __future__ import unicode_literals
from typing import Any, Dict
@@ -64,7 +64,7 @@ class Repository(AsyncDocument):
@classmethod
def search(cls, using: Any = None, index: Optional[str] = None) -> Any:
return super(Repository, cls).search().filter("term", commit_repo="repo")
return super().search().filter("term", commit_repo="repo")
class Index:
name = "git"
@@ -98,7 +98,7 @@ def repo_search_cls(opensearch_version: Any) -> Any:
}
def search(self) -> Any:
s = super(RepoSearch, self).search()
s = super().search()
return s.filter("term", commit_repo="repo")
return RepoSearch
@@ -7,7 +7,6 @@
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
from __future__ import unicode_literals
from typing import Any
@@ -31,7 +30,7 @@ class Repository(AsyncDocument):
@classmethod
def search(cls, using: Any = None, index: Any = None) -> Any:
return super(Repository, cls).search().filter("term", commit_repo="repo")
return super().search().filter("term", commit_repo="repo")
class Index:
name = "git"
@@ -8,8 +8,6 @@
# GitHub history for details.
from __future__ import unicode_literals
import unittest
import pytest
@@ -8,8 +8,6 @@
# GitHub history for details.
from __future__ import unicode_literals
import pytest
from _pytest.mark.structures import MarkDecorator
@@ -121,7 +121,7 @@ class AsyncYamlRunner(YamlRunner):
if hasattr(self, "run_" + action_type):
await await_if_coro(getattr(self, "run_" + action_type)(action))
else:
raise RuntimeError("Invalid action type %r" % (action_type,))
raise RuntimeError(f"Invalid action type {action_type!r}")
async def run_do(self, action: Any) -> Any:
api = self.client
@@ -170,7 +170,7 @@ class AsyncYamlRunner(YamlRunner):
else:
if catch:
raise AssertionError(
"Failed to catch %r in %r." % (catch, self.last_response)
f"Failed to catch {catch!r} in {self.last_response!r}."
)
# Filter out warnings raised by other components.
@@ -197,7 +197,7 @@ class AsyncYamlRunner(YamlRunner):
for feature in features:
if feature in IMPLEMENTED_FEATURES:
continue
pytest.skip("feature '%s' is not supported" % feature)
pytest.skip(f"feature '{feature}' is not supported")
if "version" in skip:
version, reason = skip["version"], skip["reason"]
@@ -8,8 +8,6 @@
# GitHub history for details.
from __future__ import unicode_literals
import os
from unittest import IsolatedAsyncioTestCase
+1 -1
View File
@@ -8,10 +8,10 @@
# GitHub history for details.
import uuid
from unittest.mock import Mock
import pytest
from _pytest.mark.structures import MarkDecorator
from mock import Mock
pytestmark: MarkDecorator = pytest.mark.asyncio
@@ -25,15 +25,13 @@
# under the License.
from __future__ import unicode_literals
import asyncio
import json
from typing import Any
from unittest.mock import patch
import pytest
from _pytest.mark.structures import MarkDecorator
from mock import patch
from opensearchpy import AIOHttpConnection, AsyncTransport
from opensearchpy.connection import Connection
@@ -51,7 +49,7 @@ class DummyConnection(Connection):
self.delay = kwargs.pop("delay", 0)
self.calls: Any = []
self.closed = False
super(DummyConnection, self).__init__(**kwargs)
super().__init__(**kwargs)
async def perform_request(self, *args: Any, **kwargs: Any) -> Any:
if self.closed:
@@ -253,7 +251,7 @@ class TestTransport:
assert dt is t.connection_pool.dead_timeout
async def test_custom_connection_class(self) -> None:
class MyConnection(object):
class MyConnection:
def __init__(self, **kwargs: Any) -> None:
self.kwargs = kwargs
+2 -2
View File
@@ -32,7 +32,7 @@ from unittest import SkipTest, TestCase
from opensearchpy import OpenSearch
class DummyTransport(object):
class DummyTransport:
def __init__(
self, hosts: Sequence[str], responses: Any = None, **kwargs: Any
) -> None:
@@ -59,7 +59,7 @@ class DummyTransport(object):
class OpenSearchTestCase(TestCase):
def setUp(self) -> None:
super(OpenSearchTestCase, self).setUp()
super().setUp()
self.client: Any = OpenSearch(transport_class=DummyTransport) # type: ignore
def assert_call_count_equals(self, count: int) -> None:
@@ -25,8 +25,6 @@
# under the License.
from __future__ import unicode_literals
import warnings
from opensearchpy.client import OpenSearch
@@ -25,8 +25,6 @@
# under the License.
from __future__ import unicode_literals
from typing import Any
from opensearchpy.client.utils import _bulk_body, _escape, _make_path, query_params
@@ -30,9 +30,9 @@ import re
import uuid
import warnings
from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pytest
from mock import MagicMock, Mock, patch
from requests.auth import AuthBase
from opensearchpy.connection import Connection, RequestsHttpConnection
@@ -258,7 +258,7 @@ class TestRequestsHttpConnection(TestCase):
"GET",
"/",
{"param": 42},
"{}".encode("utf-8"),
b"{}",
)
# trace request
@@ -282,7 +282,7 @@ class TestRequestsHttpConnection(TestCase):
"GET",
"/",
{"param": 42},
"""{"question": "what's that?"}""".encode("utf-8"),
b"""{"question": "what's that?"}""",
)
# trace request
@@ -397,7 +397,7 @@ class TestRequestsHttpConnection(TestCase):
self.assertEqual("http://localhost:9200/", request.url)
self.assertEqual("GET", request.method)
self.assertEqual('{"answer": 42}'.encode("utf-8"), request.body)
self.assertEqual(b'{"answer": 42}', request.body)
def test_http_auth_attached(self) -> None:
con = self._get_mock_connection({"http_auth": "username:secret"})
@@ -414,7 +414,7 @@ class TestRequestsHttpConnection(TestCase):
self.assertEqual("http://localhost:9200/some-prefix/_search", request.url)
self.assertEqual("GET", request.method)
self.assertEqual('{"answer": 42}'.encode("utf-8"), request.body)
self.assertEqual(b'{"answer": 42}', request.body)
# trace request
trace_curl_cmd = (
@@ -431,7 +431,7 @@ class TestRequestsHttpConnection(TestCase):
buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"
con = self._get_mock_connection(response_body=buf)
_, _, data = con.perform_request("GET", "/")
self.assertEqual(u"你好\uda6a", data) # fmt: skip
self.assertEqual("你好\uda6a", data) # fmt: skip
def test_recursion_error_reraised(self) -> None:
conn = RequestsHttpConnection()
@@ -32,10 +32,10 @@ from gzip import GzipFile
from io import BytesIO
from platform import python_version
from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pytest
import urllib3
from mock import MagicMock, Mock, patch
from urllib3._collections import HTTPHeaderDict
from opensearchpy import __versionstr__
@@ -130,7 +130,7 @@ class TestUrllib3HttpConnection(TestCase):
con = Urllib3HttpConnection()
self.assertEqual(
con._get_default_user_agent(),
"opensearch-py/%s (Python %s)" % (__versionstr__, python_version()),
f"opensearch-py/{__versionstr__} (Python {python_version()})",
)
def test_timeout_set(self) -> None:
@@ -385,7 +385,7 @@ class TestUrllib3HttpConnection(TestCase):
buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"
con = self._get_mock_connection(response_body=buf)
_, _, data = con.perform_request("GET", "/")
self.assertEqual(u"你好\uda6a", data) # fmt: skip
self.assertEqual("你好\uda6a", data) # fmt: skip
def test_recursion_error_reraised(self) -> None:
conn = Urllib3HttpConnection()
+1 -3
View File
@@ -68,9 +68,7 @@ class TestConnectionPool(TestCase):
def test_selectors_have_access_to_connection_opts(self) -> None:
class MySelector(RoundRobinSelector):
def select(self, connections: Any) -> Any:
return self.connection_opts[
super(MySelector, self).select(connections)
]["actual"]
return self.connection_opts[super().select(connections)]["actual"]
pool = ConnectionPool(
[(x, {"actual": x * x}) for x in range(100)],
+1 -1
View File
@@ -26,8 +26,8 @@
from typing import Any
from unittest.mock import Mock
from mock import Mock
from pytest import fixture
from opensearchpy.connection.connections import add_connection, connections
@@ -28,9 +28,9 @@
import threading
import time
from typing import Any
from unittest import mock
from unittest.mock import Mock
import mock
import pytest
from opensearchpy import OpenSearch, helpers
@@ -131,7 +131,7 @@ class TestParallelBulk(TestCase):
results = list(
helpers.parallel_bulk(OpenSearch(), actions, thread_count=10, chunk_size=2)
)
self.assertTrue(len(set([r[1] for r in results])) > 1)
self.assertTrue(len({r[1] for r in results}) > 1)
class TestChunkActions(TestCase):
@@ -266,7 +266,7 @@ class TestChunkActions(TestCase):
)
self.assertEqual(25, len(chunks))
for _, chunk_actions in chunks:
chunk = u"".join(chunk_actions) # fmt: skip
chunk = "".join(chunk_actions) # fmt: skip
chunk = chunk if isinstance(chunk, str) else chunk.encode("utf-8")
self.assertLessEqual(len(chunk), max_byte_size)
@@ -24,7 +24,6 @@
# specific language governing permissions and limitations
# under the License.
from __future__ import unicode_literals
import codecs
import ipaddress
+5 -5
View File
@@ -84,7 +84,7 @@ def test_cloned_index_has_analysis_attribute() -> None:
client = object()
i: Any = Index("my-index", using=client)
random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100)))
random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100))
random_analyzer = analyzer(
random_analyzer_name, tokenizer="standard", filter="standard"
)
@@ -128,7 +128,7 @@ def test_registered_doc_type_included_in_search() -> None:
def test_aliases_add_to_object() -> None:
random_alias = "".join((choice(string.ascii_letters) for _ in range(100)))
random_alias = "".join(choice(string.ascii_letters) for _ in range(100))
alias_dict: Any = {random_alias: {}}
index: Any = Index("i", using="alias")
@@ -138,7 +138,7 @@ def test_aliases_add_to_object() -> None:
def test_aliases_returned_from_to_dict() -> None:
random_alias = "".join((choice(string.ascii_letters) for _ in range(100)))
random_alias = "".join(choice(string.ascii_letters) for _ in range(100))
alias_dict: Any = {random_alias: {}}
index: Any = Index("i", using="alias")
@@ -148,7 +148,7 @@ def test_aliases_returned_from_to_dict() -> None:
def test_analyzers_added_to_object() -> None:
random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100)))
random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100))
random_analyzer = analyzer(
random_analyzer_name, tokenizer="standard", filter="standard"
)
@@ -164,7 +164,7 @@ def test_analyzers_added_to_object() -> None:
def test_analyzers_returned_from_to_dict() -> None:
random_analyzer_name = "".join((choice(string.ascii_letters) for _ in range(100)))
random_analyzer_name = "".join(choice(string.ascii_letters) for _ in range(100))
random_analyzer = analyzer(
random_analyzer_name, tokenizer="standard", filter="standard"
)
@@ -95,7 +95,7 @@ def test_interactive_helpers(dummy_response: Any) -> None:
)
assert res
assert "<Response: %s>" % rhits == repr(res)
assert f"<Response: {rhits}>" == repr(res)
assert rhits == repr(hits)
assert {"meta", "city", "name"} == set(dir(h))
assert "<Hit(test-index/opensearch): %r>" % dummy_response["hits"]["hits"][0][
+1 -1
View File
@@ -99,7 +99,7 @@ def test_serializer_deals_with_attr_versions() -> None:
def test_serializer_deals_with_objects_with_to_dict() -> None:
class MyClass(object):
class MyClass:
def to_dict(self) -> int:
return 42
@@ -66,7 +66,7 @@ class AutoNowDate(Date):
def clean(self, data: Any) -> Any:
if data is None:
data = datetime.now()
return super(AutoNowDate, self).clean(data)
return super().clean(data)
class Log(Document):
+1 -1
View File
@@ -74,7 +74,7 @@ def sync_client_factory() -> Any:
except ConnectionError:
time.sleep(0.1)
else:
pytest.skip("OpenSearch wasn't running at %r" % (OPENSEARCH_URL,))
pytest.skip(f"OpenSearch wasn't running at {OPENSEARCH_URL!r}")
wipe_cluster(client)
yield client
@@ -25,8 +25,6 @@
# under the License.
from __future__ import unicode_literals
from . import OpenSearchTestCase
@@ -26,8 +26,7 @@
from typing import Any
from mock import patch
from unittest.mock import patch
from opensearchpy import TransportError, helpers
from opensearchpy.helpers import ScanError
@@ -36,7 +35,7 @@ from ...test_cases import SkipTest
from .. import OpenSearchTestCase
class FailingBulkClient(object):
class FailingBulkClient:
def __init__(
self,
client: Any,
@@ -383,7 +382,7 @@ class TestScan(OpenSearchTestCase):
def teardown_method(self, m: Any) -> None:
self.client.transport.perform_request("DELETE", "/_search/scroll/_all")
super(TestScan, self).teardown_method(m)
super().teardown_method(m)
def test_order_can_be_preserved(self) -> None:
bulk: Any = []
@@ -415,8 +414,8 @@ class TestScan(OpenSearchTestCase):
docs = list(helpers.scan(self.client, index="test_index", size=2))
self.assertEqual(100, len(docs))
self.assertEqual(set(map(str, range(100))), set(d["_id"] for d in docs))
self.assertEqual(set(range(100)), set(d["_source"]["answer"] for d in docs))
self.assertEqual(set(map(str, range(100))), {d["_id"] for d in docs})
self.assertEqual(set(range(100)), {d["_source"]["answer"] for d in docs})
def test_scroll_error(self) -> None:
bulk: Any = []
@@ -24,7 +24,6 @@
# specific language governing permissions and limitations
# under the License.
from __future__ import unicode_literals
from typing import Any, Dict
@@ -79,7 +79,7 @@ class Repository(Document):
@classmethod
def search(cls, using: Any = None, index: Any = None) -> Any:
return super(Repository, cls).search().filter("term", commit_repo="repo")
return super().search().filter("term", commit_repo="repo")
class Index:
name = "git"
@@ -106,7 +106,7 @@ def repo_search_cls(opensearch_version: Any) -> Any:
}
def search(self) -> Any:
s = super(RepoSearch, self).search()
s = super().search()
return s.filter("term", commit_repo="repo")
return RepoSearch
@@ -24,7 +24,6 @@
# specific language governing permissions and limitations
# under the License.
from __future__ import unicode_literals
from typing import Any
@@ -52,7 +51,7 @@ class Repository(Document):
@classmethod
def search(cls, using: Any = None, index: Any = None) -> Any:
return super(Repository, cls).search().filter("term", commit_repo="repo")
return super().search().filter("term", commit_repo="repo")
class Index:
name = "git"
@@ -7,7 +7,6 @@
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
from __future__ import unicode_literals
import time
@@ -8,8 +8,6 @@
# GitHub history for details.
from __future__ import unicode_literals
import unittest
from opensearchpy.helpers.test import OPENSEARCH_VERSION
@@ -8,8 +8,6 @@
# GitHub history for details.
from __future__ import unicode_literals
from opensearchpy.exceptions import NotFoundError
from .. import OpenSearchTestCase
@@ -8,8 +8,6 @@
# GitHub history for details.
from __future__ import unicode_literals
import unittest
from typing import Any, Dict
@@ -169,7 +169,7 @@ class YamlRunner:
if hasattr(self, "run_" + action_type):
getattr(self, "run_" + action_type)(action)
else:
raise RuntimeError("Invalid action type %r" % (action_type,))
raise RuntimeError(f"Invalid action type {action_type!r}")
def run_do(self, action: Any) -> Any:
api = self.client
@@ -218,7 +218,7 @@ class YamlRunner:
else:
if catch:
raise AssertionError(
"Failed to catch %r in %r." % (catch, self.last_response)
f"Failed to catch {catch!r} in {self.last_response!r}."
)
# Filter out warnings raised by other components.
@@ -248,7 +248,7 @@ class YamlRunner:
elif catch[0] == "/" and catch[-1] == "/":
assert (
re.search(catch[1:-1], exception.error + " " + repr(exception.info)),
"%s not in %r" % (catch, exception.info),
f"{catch} not in {exception.info!r}",
) is not None
self.last_response = exception.info
@@ -262,7 +262,7 @@ class YamlRunner:
for feature in features:
if feature in IMPLEMENTED_FEATURES:
continue
pytest.skip("feature '%s' is not supported" % feature)
pytest.skip(f"feature '{feature}' is not supported")
if "version" in skip:
version, reason = skip["version"], skip["reason"]
@@ -328,10 +328,7 @@ class YamlRunner:
and expected.strip().endswith("/")
):
expected = re.compile(expected.strip()[1:-1], re.VERBOSE | re.MULTILINE)
assert expected.search(value), "%r does not match %r" % (
value,
expected,
)
assert expected.search(value), f"{value!r} does not match {expected!r}"
else:
self._assert_match_equals(value, expected)
@@ -341,7 +338,7 @@ class YamlRunner:
expected = self._resolve(expected) # dict[str, str]
if expected not in value:
raise AssertionError("%s is not contained by %s" % (expected, value))
raise AssertionError(f"{expected} is not contained by {value}")
def run_transform_and_set(self, action: Any) -> None:
for key, value in action.items():
@@ -371,7 +368,7 @@ class YamlRunner:
break
if isinstance(value, dict):
value = dict((k, self._resolve(v)) for (k, v) in value.items())
value = {k: self._resolve(v) for (k, v) in value.items()}
elif isinstance(value, list):
value = list(map(self._resolve, value))
return value
@@ -412,7 +409,7 @@ class YamlRunner:
if isinstance(b, string_types) and isinstance(a, float) and "e" in repr(a):
a = repr(a).replace("e+", "E")
assert a == b, "%r does not match %r" % (a, b)
assert a == b, f"{a!r} does not match {b!r}"
@pytest.fixture(scope="function") # type: ignore
@@ -473,7 +470,7 @@ def load_rest_api_tests() -> None:
for prefix in ("rest-api-spec/", "test/", "oss/"):
if pytest_test_name.startswith(prefix):
pytest_test_name = pytest_test_name[len(prefix) :]
pytest_param_id = "%s[%d]" % (pytest_test_name, test_number)
pytest_param_id = f"{pytest_test_name}[{test_number}]"
pytest_param = {
"setup": setup_steps,
@@ -487,7 +484,7 @@ def load_rest_api_tests() -> None:
YAML_TEST_SPECS.append(pytest.param(pytest_param, id=pytest_param_id))
except Exception as e:
warnings.warn("Could not load REST API tests: %s" % (str(e),))
warnings.warn(f"Could not load REST API tests: {str(e)}")
load_rest_api_tests()
@@ -8,8 +8,6 @@
# GitHub history for details.
from __future__ import unicode_literals
import os
from unittest import TestCase
+2 -5
View File
@@ -25,13 +25,10 @@
# under the License.
from __future__ import unicode_literals
import json
import time
from typing import Any
from mock import patch
from unittest.mock import patch
from opensearchpy.connection import Connection
from opensearchpy.connection_pool import DummyConnectionPool
@@ -47,7 +44,7 @@ class DummyConnection(Connection):
self.status, self.data = kwargs.pop("status", 200), kwargs.pop("data", "{}")
self.headers = kwargs.pop("headers", {})
self.calls: Any = []
super(DummyConnection, self).__init__(**kwargs)
super().__init__(**kwargs)
def perform_request(self, *args: Any, **kwargs: Any) -> Any:
self.calls.append((args, kwargs))
+5 -5
View File
@@ -69,7 +69,7 @@ def run(*argv: Any, expect_exit_code: int = 0) -> None:
else:
os.chdir(TMP_DIR)
cmd = " ".join(shlex.quote(x) for x in argv)
cmd = shlex.join(argv)
print("$ " + cmd)
exit_code = os.system(cmd)
if exit_code != expect_exit_code:
@@ -270,7 +270,7 @@ def main() -> None:
# Rename the module to fit the suffix.
shutil.move(
os.path.join(BASE_DIR, "opensearchpy"),
os.path.join(BASE_DIR, "opensearchpy%s" % suffix),
os.path.join(BASE_DIR, f"opensearchpy{suffix}"),
)
# Ensure that the version within 'opensearchpy/_version.py' is correct.
@@ -279,7 +279,7 @@ def main() -> None:
version_data = file.read()
version_data = re.sub(
r"__versionstr__: str = \"[^\"]+\"",
'__versionstr__: str = "%s"' % version,
f'__versionstr__: str = "{version}"',
version_data,
)
with open(version_path, "w", encoding="utf-8") as file:
@@ -296,7 +296,7 @@ def main() -> None:
file.write(
setup_py.replace(
'PACKAGE_NAME = "opensearch-py"',
'PACKAGE_NAME = "opensearch-py%s"' % suffix,
f'PACKAGE_NAME = "opensearch-py{suffix}"',
)
)
@@ -306,7 +306,7 @@ def main() -> None:
# Clean up everything.
run("git", "checkout", "--", "setup.py", "opensearchpy/")
if suffix:
run("rm", "-rf", "opensearchpy%s/" % suffix)
run("rm", "-rf", f"opensearchpy{suffix}/")
# Test everything that got created
dists = os.listdir(os.path.join(BASE_DIR, "dist"))
+4 -4
View File
@@ -95,7 +95,7 @@ def blacken(filename: Any) -> None:
assert result.exit_code == 0, result.output
@lru_cache()
@lru_cache
def is_valid_url(url: str) -> bool:
"""
makes a call to the url
@@ -222,7 +222,7 @@ class Module:
# Identifying the insertion point for the "THIS CODE IS AUTOMATICALLY GENERATED" header.
if os.path.exists(self.filepath):
with open(self.filepath, "r", encoding="utf-8") as file:
with open(self.filepath, encoding="utf-8") as file:
content = file.read()
if header_separator in content:
update_header = False
@@ -249,7 +249,7 @@ class Module:
generated_file_header_path = os.path.join(
current_script_folder, "generated_file_headers.txt"
)
with open(generated_file_header_path, "r", encoding="utf-8") as header_file:
with open(generated_file_header_path, encoding="utf-8") as header_file:
header_content = header_file.read()
# Imports are temporarily removed from the header and are regenerated
@@ -287,7 +287,7 @@ class Module:
# Generating imports for each module
utils_imports = ""
file_content = ""
with open(self.filepath, "r", encoding="utf-8") as file:
with open(self.filepath, encoding="utf-8") as file:
content = file.read()
keywords = [
"SKIP_IN_PATH",
+2 -2
View File
@@ -55,7 +55,7 @@ def does_file_need_fix(filepath: str) -> bool:
if not re.search(r"\.py$", filepath):
return False
existing_header = ""
with open(filepath, mode="r", encoding="utf-8") as file:
with open(filepath, encoding="utf-8") as file:
for line in file:
line = line.strip()
if len(line) == 0 or line in LINES_TO_KEEP:
@@ -73,7 +73,7 @@ def add_header_to_file(filepath: str) -> None:
writes the license header to the beginning of a file
:param filepath: relative or absolute filepath to update
"""
with open(filepath, mode="r", encoding="utf-8") as file:
with open(filepath, encoding="utf-8") as file:
lines = list(file)
i = 0
for i, line in enumerate(lines):