Pylint integration updates (#643)

* updated files with docstrings to pass pylint

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* updated samples to prepare for enabling missing-docstring linter; will continue to work on this before committing setup.cfg

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* removed missing-function-docstring from setup.cfg so the linter doesn't fail while work on docstrings continues

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* corrected unnecessary return docstring values

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* fixing failure in 'black' on reformatting

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* updated utils to pass missing-function-docstring tests

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* updated functions with missing docstrings or pylint ignore instructions; added a utility to automatically add these ignore instructions to most functions that should be self-describing; rolled back some automatically generated code mistakenly changed

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* * ignoring opensearchpy for pylint and then added it back to noxfile.py
* fixed some lints; created a feature flag for newer dynamic pylint so now lints can be fixed first in legacy code and then enabled by multiple people
* extracted a method for per-folder linting
* updated noxfile.lint_per_folder with type hints
* enabled unspecified-encoding in pylint
* added disable missing-function-docstring pragma to test_clients.py in test_async and test_server
* added more encodings to pass unspecified-encoding pylint tests
* updated changelog
Signed-off-by: Mark Cohen <markcoh@amazon.com>

* updated CHANGELOG.md entry
removed the feature flag for pylint lint_per_folder
fixed failures from mypy and pylint
removed pylint MESSAGE CONTROL config from setup.cfg after relocating to lint_per_folder method
Signed-off-by: Mark Cohen <markcoh@amazon.com>

* removed pylint ignore missing-function-docstring

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* added pylint.extensions.docparams plugin

updated some docstrings to correct parameters

removed pylint from setup.cfg

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* added four lints for opensearchpy/

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* adding await back to client.info() call

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* updated TODOs as requested

renamed test_opensearchpy.test_async.test_server.test_helpers.conftest.setup_ubq_tests to setup_update_by_query_tests

added
OpenSearch-main/rest-api-spec/src/main/resources/rest-api-spec/test/indices/stats/50_noop_update[0]
to skip tests list

run_tests.py catches a CalledProcessError when the git repo already exists and the command to add the origin fails in fetch_opensearch_repo()

Signed-off-by: Mark Cohen <markcoh@amazon.com>

---------

Signed-off-by: Mark Cohen <markcoh@amazon.com>
This commit is contained in:
Mark Cohen
2024-01-19 13:36:05 -05:00
committed by GitHub
parent 2ab3a40307
commit 0ddbf8cafa
47 changed files with 563 additions and 218 deletions
+1
View File
@@ -3,6 +3,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
## [Unreleased]
### Added
- Added pylint `unspecified-encoding` and `missing-function-docstring` and ignored opensearchpy for lints ([#643](https://github.com/opensearch-project/opensearch-py/pull/643))
- Added pylint `line-too-long` and `invalid-name` ([#590](https://github.com/opensearch-project/opensearch-py/pull/590))
- Added pylint `pointless-statement` ([#611](https://github.com/opensearch-project/opensearch-py/pull/611))
- Added a log collection guide ([#579](https://github.com/opensearch-project/opensearch-py/pull/579))
+3 -2
View File
@@ -36,8 +36,9 @@ async def index_records(client: Any, index_name: str, item_count: int) -> None:
async def test_async(client_count: int = 1, item_count: int = 1) -> None:
"""
asynchronously index with item_count records and run client_count clients. This function can be used to
test balancing the number of items indexed with the number of documents.
asynchronously index with item_count records and run client_count
clients. This function can be used to test balancing the number of
items indexed with the number of documents.
"""
host = "localhost"
port = 9200
+6 -7
View File
@@ -22,13 +22,12 @@ from opensearchpy import OpenSearch
def get_info(client: Any, request_count: int) -> float:
"""get info from client"""
tt: float = 0
for n in range(request_count):
total_time: float = 0
for request in range(request_count):
start = time.time() * 1000
client.info()
total_time = time.time() * 1000 - start
tt += total_time
return tt
total_time += time.time() * 1000 - start
return total_time
def test(thread_count: int = 1, request_count: int = 1, client_count: int = 1) -> None:
@@ -71,8 +70,8 @@ def test(thread_count: int = 1, request_count: int = 1, client_count: int = 1) -
thread.start()
latency = 0
for t in threads:
latency += t.join()
for thread in threads:
latency += thread.join()
print(f"latency={latency}")
+15 -15
View File
@@ -23,29 +23,29 @@ from opensearchpy import OpenSearch, Urllib3HttpConnection
def index_records(client: Any, index_name: str, item_count: int) -> Any:
"""bulk index item_count records into index_name"""
tt = 0
for n in range(10):
total_time = 0
for iteration in range(10):
data: Any = []
for i in range(item_count):
for item in range(item_count):
data.append(
json.dumps({"index": {"_index": index_name, "_id": str(uuid.uuid4())}})
)
data.append(json.dumps({"value": i}))
data.append(json.dumps({"value": item}))
data = "\n".join(data)
start = time.time() * 1000
rc = client.bulk(data)
if rc["errors"]:
raise Exception(rc["errors"])
response = client.bulk(data)
if response["errors"]:
raise Exception(response["errors"])
server_time = rc["took"]
total_time = time.time() * 1000 - start
server_time = response["took"]
this_time = time.time() * 1000 - start
if total_time < server_time:
raise Exception(f"total={total_time} < server={server_time}")
if this_time < server_time:
raise Exception(f"total={this_time} < server={server_time}")
tt += total_time - server_time
return tt
total_time += this_time - server_time
return total_time
def test(thread_count: int = 1, item_count: int = 1, client_count: int = 1) -> None:
@@ -105,8 +105,8 @@ def test(thread_count: int = 1, item_count: int = 1, client_count: int = 1) -> N
thread.start()
latency = 0
for t in threads:
latency += t.join()
for thread in threads:
latency += thread.join()
clients[0].indices.refresh(index=index_name)
count = clients[0].count(index=index_name)
+80 -2
View File
@@ -25,7 +25,7 @@
# under the License.
from typing import Any
from typing import Any, List
import nox
@@ -43,6 +43,10 @@ SOURCE_FILES = (
@nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10", "3.11"]) # type: ignore
def test(session: Any) -> None:
"""
runs all tests with a fresh python environment using "python setup.py test"
:param session: current nox session
"""
session.install(".")
# ensure client can be imported without aiohttp
session.run("python", "-c", "import opensearchpy\nprint(opensearchpy.OpenSearch())")
@@ -59,6 +63,10 @@ def test(session: Any) -> None:
@nox.session(python=["3.7"]) # type: ignore
def format(session: Any) -> None:
"""
runs black and isort to format the files accordingly
:param session: current nox session
"""
session.install(".")
session.install("black", "isort")
@@ -71,6 +79,10 @@ def format(session: Any) -> None:
@nox.session(python=["3.7"]) # type: ignore
def lint(session: Any) -> None:
"""
runs isort, black, flake8, pylint, and mypy to check the files according to each utility's function
:param session: current nox session
"""
session.install(
"flake8",
"black",
@@ -89,7 +101,9 @@ def lint(session: Any) -> None:
session.run("isort", "--check", *SOURCE_FILES)
session.run("black", "--check", *SOURCE_FILES)
session.run("flake8", *SOURCE_FILES)
session.run("pylint", *SOURCE_FILES)
lint_per_folder(session)
session.run("python", "utils/license_headers.py", "check", *SOURCE_FILES)
# Workaround to make '-r' to still work despite uninstalling aiohttp below.
@@ -108,8 +122,68 @@ def lint(session: Any) -> None:
session.run("mypy", "--strict", "test_opensearchpy/test_types/sync_types.py")
def lint_per_folder(session: Any) -> None:
    """
    runs pylint once for every entry in SOURCE_FILES so that each path can
    be checked with its own list of enabled messages
    :param session: the current nox session
    """

    # paths listed here are skipped and never handed to pylint
    skipped_paths: List[str] = []

    # options shared by every invocation: all messages are switched off first
    # so that only the ones named in --enable are reported; the good-names
    # pattern accepts one- or two-character lowercase names
    common_options = [
        "--disable=all",
        "--max-line-length=240",
        "--good-names-rgxs=^[_a-z][_a-z0-9]?$",
        "--load-plugins",
        "pylint.extensions.docparams",
    ]

    # messages enabled for any path without an entry in checks_by_path
    fallback_checks = [
        "line-too-long",
        "invalid-name",
        "pointless-statement",
        "unspecified-encoding",
        "missing-function-docstring",
        "missing-param-doc",
        "differing-param-doc",
    ]

    # per-path replacements for fallback_checks
    checks_by_path = {
        "test_opensearchpy/": [
            "line-too-long",
            # "invalid-name" is left out: lots of short functions with one
            # or two character names
            "pointless-statement",
            "unspecified-encoding",
            "missing-param-doc",
            "differing-param-doc",
            # "missing-function-docstring" is left out: test names are
            # usually self-describing
        ],
        "opensearchpy/": [
            "line-too-long",
            "invalid-name",
            "pointless-statement",
            "unspecified-encoding",
        ],
    }

    for path in SOURCE_FILES:
        if path in skipped_paths:
            continue

        enabled = checks_by_path.get(path, fallback_checks)
        session.run(
            "pylint",
            *common_options,
            f"--enable={','.join(enabled)}",
            path,
        )
@nox.session() # type: ignore
def docs(session: Any) -> None:
"""
builds the html documentation for the client
:param session: current nox session
"""
session.install(".")
session.install(".[docs]")
with session.chdir("docs"):
@@ -118,6 +192,10 @@ def docs(session: Any) -> None:
@nox.session() # type: ignore
def generate(session: Any) -> None:
"""
generates the base API code
:param session: current nox session
"""
session.install("-rdev-requirements.txt")
session.run("python", "utils/generate_api.py")
format(session)
+6
View File
@@ -35,6 +35,12 @@ map = map # pylint: disable=invalid-name
def to_str(x: Union[str, bytes], encoding: str = "ascii") -> str:
    """
    returns x unchanged when it is already a str; any other value is decoded with "encoding"
    :param x: the value to convert to a str
    :param encoding: the codec used to decode a non-str value - see https://docs.python.org/3/library/codecs.html#standard-encodings
    :return: x as a str
    """
    return x if isinstance(x, str) else x.decode(encoding)
@@ -19,8 +19,8 @@ from opensearchpy import OpenSearch
def main() -> None:
"""
demonstrates various functions to operate on the index (e.g. clear different levels of cache, refreshing the
index)
demonstrates various functions to operate on the index
(e.g. clear different levels of cache, refreshing the index)
"""
# Set up
client = OpenSearch(
+6 -3
View File
@@ -21,10 +21,13 @@ from opensearchpy import OpenSearch, RequestsAWSV4SignerAuth, RequestsHttpConnec
def main() -> None:
"""
connects to a cluster specified in environment variables, creates an index, inserts documents,
connects to a cluster specified in environment variables,
creates an index, inserts documents,
searches the index, deletes the document, deletes the index.
the environment variables are "ENDPOINT" for the cluster endpoint, AWS_REGION for the region in which the cluster
is hosted, and SERVICE to indicate if this is an ES 7.10.2 compatible cluster
the environment variables are "ENDPOINT" for the cluster
endpoint, AWS_REGION for the region in which the cluster
is hosted, and SERVICE to indicate if this is an ES 7.10.2
compatible cluster
"""
# verbose logging
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+4 -3
View File
@@ -21,9 +21,10 @@ from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnecti
def main() -> None:
"""
1. connects to an OpenSearch cluster on AWS defined by environment variables (i.e. ENDPOINT - cluster endpoint like
my-test-domain.us-east-1.es.amazonaws.com; AWS_REGION like us-east-1, us-west-2; and SERVICE like es which
differentiates beteween serverless and the managed service.
1. connects to an OpenSearch cluster on AWS defined by environment variables
(i.e. ENDPOINT - cluster endpoint like my-test-domain.us-east-1.es.
amazonaws.com; AWS_REGION like us-east-1, us-west-2; and SERVICE like es which
differentiates between serverless and the managed service.
2. creates an index called "movies" and adds a single document
3. queries for that document
4. deletes the document
+1 -1
View File
@@ -53,7 +53,7 @@ def main() -> None:
data.append({"index": {"_index": index_name, "_id": i}})
data.append({"value": i})
rc = client.bulk(data)
rc = client.bulk(data) # pylint: disable=invalid-name
if rc["errors"]:
print("There were errors:")
for item in rc["items"]:
+3 -3
View File
@@ -18,8 +18,8 @@ from opensearchpy import OpenSearch, helpers
def main() -> None:
"""
demonstrates how to bulk load data using opensearchpy.helpers including examples of serial, parallel, and streaming
bulk load
demonstrates how to bulk load data using opensearchpy.helpers
including examples of serial, parallel, and streaming bulk load
"""
# connect to an instance of OpenSearch
@@ -56,7 +56,7 @@ def main() -> None:
data.append({"_index": index_name, "_id": i, "value": i})
# serialized bulk raising an exception on error
rc = helpers.bulk(client, data)
rc = helpers.bulk(client, data) # pylint: disable=invalid-name
print(f"Bulk-inserted {rc[0]} items (bulk).")
# parallel bulk with explicit error checking
+1 -1
View File
@@ -55,7 +55,7 @@ def main() -> None:
data += json.dumps({"index": {"_index": index_name, "_id": i}}) + "\n"
data += json.dumps({"value": i}) + "\n"
rc = client.bulk(data)
rc = client.bulk(data) # pylint: disable=invalid-name
if rc["errors"]:
print("There were errors:")
for item in rc["items"]:
@@ -18,7 +18,8 @@ from opensearchpy import OpenSearch
def main() -> None:
"""
provides samples for different ways to handle documents including indexing, searching, updating, and deleting
provides samples for different ways to handle documents
including indexing, searching, updating, and deleting
"""
# Connect to OpenSearch
client = OpenSearch(
+10 -7
View File
@@ -17,8 +17,9 @@ from opensearchpy import OpenSearch
def main() -> None:
"""
an example showing how to create an synchronous connection to OpenSearch, create an index, index a document
and search to return the document
an example showing how to create a synchronous connection to
OpenSearch, create an index, index a document and search to
return the document
"""
host = "localhost"
port = 9200
@@ -49,19 +50,21 @@ def main() -> None:
document = {"title": "Moneyball", "director": "Bennett Miller", "year": "2011"}
id = "1"
doc_id = "1"
response = client.index(index=index_name, body=document, id=id, refresh=True)
response = client.index(index=index_name, body=document, id=doc_id, refresh=True)
print(response)
# search for a document
q = "miller"
user_query = "miller"
query = {
"size": 5,
"query": {"multi_match": {"query": q, "fields": ["title^2", "director"]}},
"query": {
"multi_match": {"query": user_query, "fields": ["title^2", "director"]}
},
}
response = client.search(body=query, index=index_name)
@@ -70,7 +73,7 @@ def main() -> None:
# delete the document
response = client.delete(index=index_name, id=id)
response = client.delete(index=index_name, id=doc_id)
print(response)
+7 -4
View File
@@ -17,8 +17,9 @@ from opensearchpy import AsyncOpenSearch
async def main() -> None:
"""
an example showing how to create an asynchronous connection to OpenSearch, create an index, index a document
and search to return the document
an example showing how to create an asynchronous connection
to OpenSearch, create an index, index a document and
search to return the document
"""
# connect to OpenSearch
host = "localhost"
@@ -68,11 +69,13 @@ async def main() -> None:
await client.indices.refresh(index=index_name)
# search for a document
q = "miller"
user_query = "miller"
query = {
"size": 5,
"query": {"multi_match": {"query": q, "fields": ["title^2", "director"]}},
"query": {
"multi_match": {"query": user_query, "fields": ["title^2", "director"]}
},
}
results = await client.search(body=query, index=index_name)
@@ -13,6 +13,7 @@ from opensearchpy import OpenSearch
def main() -> None:
"""
# pylint: disable=line-too-long
1. connects to an OpenSearch instance running on localhost
2. Create an index template named `books` with default settings and mappings for indices of
the `books-*` pattern. You can create an index template to define default settings and mappings for indices
+7 -5
View File
@@ -46,24 +46,26 @@ def main() -> None:
document = {"title": "Moneyball", "director": "Bennett Miller", "year": "2011"}
id = "1"
doc_id = "1"
print(client.http.put(f"/{index_name}/_doc/{id}?refresh=true", body=document))
print(client.http.put(f"/{index_name}/_doc/{doc_id}?refresh=true", body=document))
# search for a document
q = "miller"
user_query = "miller"
query = {
"size": 5,
"query": {"multi_match": {"query": q, "fields": ["title^2", "director"]}},
"query": {
"multi_match": {"query": user_query, "fields": ["title^2", "director"]}
},
}
print(client.http.post(f"/{index_name}/_search", body=query))
# delete the document
print(client.http.delete(f"/{index_name}/_doc/{id}"))
print(client.http.delete(f"/{index_name}/_doc/{doc_id}"))
# delete the index
+9 -6
View File
@@ -17,7 +17,8 @@ from opensearchpy import AsyncOpenSearch
async def main() -> None:
"""
this sample uses asyncio and AsyncOpenSearch to asynchronously connect to local OpenSearch cluster, create an index,
this sample uses asyncio and AsyncOpenSearch to asynchronously
connect to local OpenSearch cluster, create an index,
index data, search the index, delete the document, delete the index
"""
# connect to OpenSearch
@@ -51,28 +52,30 @@ async def main() -> None:
document = {"title": "Moneyball", "director": "Bennett Miller", "year": "2011"}
id = "1"
doc_id = "1"
print(
await client.http.put(
f"/{index_name}/_doc/{id}?refresh=true", body=document
f"/{index_name}/_doc/{doc_id}?refresh=true", body=document
)
)
# search for a document
q = "miller"
user_query = "miller"
query = {
"size": 5,
"query": {"multi_match": {"query": q, "fields": ["title^2", "director"]}},
"query": {
"multi_match": {"query": user_query, "fields": ["title^2", "director"]}
},
}
print(await client.http.post(f"/{index_name}/_search", body=query))
# delete the document
print(await client.http.delete(f"/{index_name}/_doc/{id}"))
print(await client.http.delete(f"/{index_name}/_doc/{doc_id}"))
# delete the index
+6 -5
View File
@@ -24,8 +24,9 @@ urllib3.disable_warnings()
def main() -> None:
"""
sample for custom logging; this shows how to create a console handler, connect to OpenSearch, define a custom
logger and log to an OpenSearch index
sample for custom logging; this shows how to create a
console handler, connect to OpenSearch, define a custom
logger, and log to an OpenSearch index
"""
print("Collecting logs.")
@@ -90,9 +91,9 @@ def main() -> None:
index=self._build_index_name(),
body=document,
)
except Exception as e:
print(f"Failed to send log to OpenSearch: {e}")
logging.warning(f"Failed to send log to OpenSearch: {e}")
except Exception as ex:
print(f"Failed to send log to OpenSearch: {ex}")
logging.warning(f"Failed to send log to OpenSearch: {ex}")
raise
print("Creating an instance of OpenSearchHandler and adding it to the logger...")
-6
View File
@@ -22,10 +22,4 @@ target-version = 'py33'
[mypy]
ignore_missing_imports=True
[pylint]
max-line-length = 240
good-names-rgxs = ^[_a-z][_a-z0-9]?$ # allow for 1-character variable names
[pylint.MESSAGE CONTROL]
disable = all
enable = line-too-long, invalid-name, pointless-statement
+4 -2
View File
@@ -34,7 +34,9 @@ PACKAGE_NAME = "opensearch-py"
PACKAGE_VERSION = ""
BASE_DIR = abspath(dirname(__file__))
with open(join(BASE_DIR, PACKAGE_NAME.replace("-", ""), "_version.py")) as f:
with open(
join(BASE_DIR, PACKAGE_NAME.replace("-", ""), "_version.py"), encoding="utf-8"
) as f:
data = f.read()
m = re.search(r"^__versionstr__: str\s+=\s+[\"\']([^\"\']+)[\"\']", data, re.M)
if m:
@@ -42,7 +44,7 @@ with open(join(BASE_DIR, PACKAGE_NAME.replace("-", ""), "_version.py")) as f:
else:
raise Exception(f"Invalid version: {data}")
with open(join(BASE_DIR, "README.md")) as f:
with open(join(BASE_DIR, "README.md"), encoding="utf-8") as f:
long_description = f.read().strip()
MODULE_DIR = PACKAGE_NAME.replace("-", "")
+25
View File
@@ -37,10 +37,16 @@ import subprocess
import sys
from os import environ
from os.path import abspath, dirname, exists, join, pardir
from subprocess import CalledProcessError
from typing import Any
def fetch_opensearch_repo() -> None:
"""
runs a git fetch origin on configured opensearch core repo
:return: None if environmental variables TEST_OPENSEARCH_YAML_DIR
is set or TEST_OPENSEARCH_NOFETCH is set to False; else returns nothing
"""
# user is manually setting YAML dir, don't tamper with it
if "TEST_OPENSEARCH_YAML_DIR" in environ:
return
@@ -77,12 +83,20 @@ def fetch_opensearch_repo() -> None:
# make a new blank repository in the test directory
subprocess.check_call("cd %s && git init" % repo_path, shell=True)
try:
# add a remote
subprocess.check_call(
"cd %s && git remote add origin https://github.com/opensearch-project/opensearch.git"
% repo_path,
shell=True,
)
except CalledProcessError as e:
# if the run is interrupted from a previous run, it doesn't clean up, and the git add origin command
# errors out; this allows the test to continue
remote_origin_already_exists = 3
print(e)
if e.returncode != remote_origin_already_exists:
sys.exit(1)
# fetch the sha commit, version from info()
print("Fetching opensearch repo...")
@@ -90,6 +104,17 @@ def fetch_opensearch_repo() -> None:
def run_all(argv: Any = None) -> None:
"""
run all the tests given arguments and environment variables
- sets defaults if argv is None, running "pytest --cov=opensearchpy
--junitxml=<path to opensearch-py-junit.xml>
--log-level=DEBUG --cache-clear -vv --cov-report=<path to output code coverage>"
* GITHUB_ACTION: fetches yaml tests if this is not in environment variables
* TEST_PATTERN: specify a test to run
* TEST_TYPE: "server" runs on TLS connection; None is unencrypted
* OPENSEARCH_VERSION: "SNAPSHOT" does not do anything with plugins
:param argv: if this is None, then the default arguments are used
"""
sys.exitfunc = lambda: sys.stderr.write("Shutting down....\n") # type: ignore
# fetch yaml tests anywhere that's not GitHub Actions
if "GITHUB_ACTION" not in environ:
@@ -88,7 +88,8 @@ class TestAIOHttpConnection:
# it means SSLContext is not available for that version of python
# and we should skip this test.
pytest.skip(
"Test test_ssl_context is skipped cause SSLContext is not available for this version of Python"
"Test test_ssl_context is skipped cause SSLContext is "
"not available for this version of Python"
)
con = AIOHttpConnection(use_ssl=True, ssl_context=context)
@@ -202,8 +203,8 @@ class TestAIOHttpConnection:
con = AIOHttpConnection(use_ssl=True, verify_certs=False)
assert 1 == len(w)
assert (
"Connecting to https://localhost:9200 using SSL with verify_certs=False is insecure."
== str(w[0].message)
"Connecting to https://localhost:9200 using SSL with "
"verify_certs=False is insecure." == str(w[0].message)
)
assert con.use_ssl
@@ -379,13 +380,17 @@ class TestConnectionHttpServer:
@classmethod
def setup_class(cls) -> None:
# Start server
"""
Start server
"""
cls.server = TestHTTPServer(port=8081)
cls.server.start()
@classmethod
def teardown_class(cls) -> None:
# Stop server
"""
stop server
"""
cls.server.stop()
async def httpserver(self, conn: Any, **kwargs: Any) -> Any:
@@ -22,6 +22,10 @@ pytestmark: MarkDecorator = pytest.mark.asyncio
@fixture # type: ignore
async def mock_client(dummy_response: Any) -> Any:
"""
yields a mock client with the dummy_response param
:param dummy_response: any kind of response for test
"""
client = Mock()
client.search.return_value = dummy_response
await add_connection("mock", client)
@@ -26,5 +26,6 @@ class TestPluginsClient:
client.plugins.__init__(client) # type: ignore
assert (
str(w[0].message)
== "Cannot load `alerting` directly to AsyncOpenSearch as it already exists. Use `AsyncOpenSearch.plugin.alerting` instead."
== "Cannot load `alerting` directly to AsyncOpenSearch as it already exists. Use "
"`AsyncOpenSearch.plugin.alerting` instead."
)
@@ -34,13 +34,19 @@ from ...utils import wipe_cluster
class AsyncOpenSearchTestCase(IsolatedAsyncioTestCase): # type: ignore
async def asyncSetUp(self) -> None: # pylint: disable=invalid-name
async def asyncSetUp(
self,
) -> None:
# pylint: disable=invalid-name,missing-function-docstring
self.client = await get_test_client(
verify_certs=False, http_auth=("admin", "admin")
)
await add_connection("default", self.client)
async def asyncTearDown(self) -> None: # pylint: disable=invalid-name
async def asyncTearDown(
self,
) -> None:
# pylint: disable=invalid-name,missing-function-docstring
wipe_cluster(self.client)
if self.client:
await self.client.close()
@@ -60,7 +60,9 @@ class TestYarlMissing:
async def test_aiohttp_connection_works_without_yarl(
self, async_client: Any, monkeypatch: Any
) -> None:
# This is a defensive test case for if aiohttp suddenly stops using yarl.
"""
This is a defensive test case for if aiohttp suddenly stops using yarl.
"""
from opensearchpy._async import http_aiohttp
monkeypatch.setattr(http_aiohttp, "yarl", False)
@@ -43,12 +43,24 @@ async def client() -> Any:
@fixture(scope="function") # type: ignore
async def opensearch_version(client: Any) -> Any:
"""
yields the version of the OpenSearch cluster
:param client: client connection to OpenSearch
:return: yields major version number
"""
info = await client.info()
print(info)
yield tuple(
int(x)
for x in re.match(r"^([0-9.]+)", info["version"]["number"]).group(1).split(".") # type: ignore
)
yield (int(x) async for x in match_version(info))
async def match_version(info: Any) -> Any:
    """
    yields, one at a time, the dot-separated components of the leading
    numeric part of the server version found in the given info
    (e.g. "2", "11", "0" for "2.11.0-SNAPSHOT")
    :param info: response from the OpenSearch cluster; info["version"]["number"] is read
    """
    match = re.match(r"^([0-9.]+)", info["version"]["number"])
    assert match is not None
    # yield every component separately: the opensearch_version fixture calls
    # int() on each yielded item, which fails if the whole list is yielded at once
    for component in match.group(1).split("."):
        yield component
@fixture # type: ignore
@@ -60,7 +72,9 @@ async def write_client(client: Any) -> Any:
@fixture # type: ignore
async def data_client(client: Any) -> Any:
# create mappings
"""
create mappings
"""
await create_git_index(client, "git")
await create_flat_git_index(client, "flat-git")
# load data
@@ -73,6 +87,11 @@ async def data_client(client: Any) -> Any:
@fixture # type: ignore
async def pull_request(write_client: Any) -> Any:
"""
create dummy pull request instance
:param write_client: #todo not used
:return: instance of PullRequest
"""
await PullRequest.init()
pr = PullRequest(
_id=42,
@@ -96,7 +115,12 @@ async def pull_request(write_client: Any) -> Any:
@fixture # type: ignore
async def setup_ubq_tests(client: Any) -> str:
async def setup_update_by_query_tests(client: Any) -> str:
"""
sets up update by query tests
:param client:
:return: an index name
"""
index = "test-git"
await create_git_index(client, index)
await async_bulk(client, TEST_GIT_DATA, raise_on_error=True, refresh=True)
@@ -60,6 +60,9 @@ class FailingBulkClient(object):
self._fail_with = fail_with
async def bulk(self, *args: Any, **kwargs: Any) -> Any:
"""
increments number of times called and, when it equals fail_at, raises self.fail_with
"""
self._called += 1
if self._called in self._fail_at:
raise self._fail_with
@@ -56,6 +56,10 @@ class MetricSearch(AsyncFacetedSearch):
@pytest.fixture(scope="function") # type: ignore
def commit_search_cls(opensearch_version: Any) -> Any:
"""
:param opensearch_version: the semver version of OpenSearch
:return: an AsyncFacetedSearch for git commits
"""
interval_kwargs = {"fixed_interval": "1d"}
class CommitSearch(AsyncFacetedSearch):
@@ -102,6 +106,10 @@ def repo_search_cls(opensearch_version: Any) -> Any:
@pytest.fixture(scope="function") # type: ignore
def pr_search_cls(opensearch_version: Any) -> Any:
"""
:param opensearch_version: not used here... #TODO remove this parameter?
:return: an AsyncFacetedSearch for pull requests
"""
interval_type = "calendar_interval"
class PRSearch(AsyncFacetedSearch):
@@ -19,9 +19,9 @@ pytestmark: MarkDecorator = pytest.mark.asyncio
async def test_update_by_query_no_script(
write_client: Any, setup_ubq_tests: Any
write_client: Any, setup_update_by_query_tests: Any
) -> None:
index = setup_ubq_tests
index = setup_update_by_query_tests
ubq = (
AsyncUpdateByQuery(using=write_client)
@@ -40,9 +40,9 @@ async def test_update_by_query_no_script(
async def test_update_by_query_with_script(
write_client: Any, setup_ubq_tests: Any
write_client: Any, setup_update_by_query_tests: Any
) -> None:
index = setup_ubq_tests
index = setup_update_by_query_tests
ubq = (
AsyncUpdateByQuery(using=write_client)
@@ -59,9 +59,9 @@ async def test_update_by_query_with_script(
async def test_delete_by_query_with_script(
write_client: Any, setup_ubq_tests: Any
write_client: Any, setup_update_by_query_tests: Any
) -> None:
index = setup_ubq_tests
index = setup_update_by_query_tests
ubq = (
AsyncUpdateByQuery(using=write_client)
@@ -54,6 +54,10 @@ OPENSEARCH_VERSION = None
async def await_if_coro(x: Any) -> Any:
    """
    resolves x when it is a coroutine object and passes any other value through
    :param x: a coroutine object or a plain value
    :return: the result of awaiting x if it is a coroutine, otherwise x itself
    """
    result = x
    if inspect.iscoroutine(result):
        result = await result
    return result
@@ -40,13 +40,15 @@ class TestSecurityPlugin(IsolatedAsyncioTestCase): # type: ignore
USER_NAME = "test-user"
USER_CONTENT = {"password": "opensearchpy@123", "opendistro_security_roles": []}
async def asyncSetUp(self) -> None: # pylint: disable=invalid-name
async def asyncSetUp(self) -> None:
# pylint: disable=invalid-name, missing-function-docstring
self.client = await get_test_client(
verify_certs=False, http_auth=("admin", "admin")
)
await add_connection("default", self.client)
async def asyncTearDown(self) -> None: # pylint: disable=invalid-name
async def asyncTearDown(self) -> None:
# pylint: disable=invalid-name, missing-function-docstring
# close the client stored on this test case, if there is one
if self.client:
await self.client.close()
@@ -449,8 +449,10 @@ class TestTransport:
assert event_loop.time() - 1 < t.last_sniff < event_loop.time() + 0.01
async def test_sniff_7x_publish_host(self) -> None:
# Test the response shaped when a 7.x node has publish_host set
# and the returend data is shaped in the fqdn/ip:port format.
"""
Test the response shaped when a 7.x node has publish_host set
and the returned data is shaped in the fqdn/ip:port format.
"""
t: Any = AsyncTransport(
[{"data": CLUSTER_NODES_7X_PUBLISH_HOST}],
connection_class=DummyConnection,
@@ -20,5 +20,6 @@ class TestPluginsClient(TestCase):
client.plugins.__init__(client) # type: ignore
self.assertEqual(
str(w.warnings[0].message),
"Cannot load `alerting` directly to OpenSearch as it already exists. Use `OpenSearch.plugin.alerting` instead.",
"Cannot load `alerting` directly to OpenSearch as "
"it already exists. Use `OpenSearch.plugin.alerting` instead.",
)
@@ -139,7 +139,8 @@ class TestRequestsHttpConnection(TestCase):
)
self.assertEqual(1, len(w))
self.assertEqual(
"Connecting to https://localhost:9200 using SSL with verify_certs=False is insecure.",
"Connecting to https://localhost:9200 using SSL with "
"verify_certs=False is insecure.",
str(w[0].message),
)
@@ -286,8 +287,9 @@ class TestRequestsHttpConnection(TestCase):
# trace request
self.assertEqual(1, tracer.info.call_count)
trace_curl_cmd = "curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/?pretty&param=42' -d '{\n \"question\": \"what\\u0027s that?\"\n}'" # pylint: disable=line-too-long
self.assertEqual(
"""curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/?pretty&param=42' -d '{\n "question": "what\\u0027s that?"\n}'""",
trace_curl_cmd,
tracer.info.call_args[0][0] % tracer.info.call_args[0][1:],
)
# trace response
@@ -415,9 +417,13 @@ class TestRequestsHttpConnection(TestCase):
self.assertEqual('{"answer": 42}'.encode("utf-8"), request.body)
# trace request
trace_curl_cmd = (
"curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/_search?pretty' "
"-d '{\n \"answer\": 42\n}'"
)
self.assertEqual(1, tracer.info.call_count)
self.assertEqual(
"curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/_search?pretty' -d '{\n \"answer\": 42\n}'",
trace_curl_cmd,
tracer.info.call_args[0][0] % tracer.info.call_args[0][1:],
)
@@ -514,7 +520,7 @@ class TestRequestsConnectionRedirect(TestCase):
@classmethod
def setup_class(cls) -> None:
# Start servers
"""Start servers"""
cls.server1 = TestHTTPServer(port=8080)
cls.server1.start()
cls.server2 = TestHTTPServer(port=8090)
@@ -522,7 +528,7 @@ class TestRequestsConnectionRedirect(TestCase):
@classmethod
def teardown_class(cls) -> None:
# Stop servers
"""Stop servers"""
cls.server2.stop()
cls.server1.stop()
@@ -73,7 +73,8 @@ class TestUrllib3HttpConnection(TestCase):
# it means SSLContext is not available for that version of python
# and we should skip this test.
raise SkipTest(
"Test test_ssl_context is skipped cause SSLContext is not available for this version of python"
"Test test_ssl_context is skipped cause SSLContext"
" is not available for this version of python"
)
con = Urllib3HttpConnection(use_ssl=True, ssl_context=context)
@@ -272,7 +273,8 @@ class TestUrllib3HttpConnection(TestCase):
con = Urllib3HttpConnection(use_ssl=True, verify_certs=False)
self.assertEqual(1, len(w))
self.assertEqual(
"Connecting to https://localhost:9200 using SSL with verify_certs=False is insecure.",
"Connecting to https://localhost:9200 using SSL with "
"verify_certs=False is insecure.",
str(w[0].message),
)
+1 -1
View File
@@ -112,7 +112,7 @@ class TestConnectionPool(TestCase):
# Nothing should be marked dead
self.assertEqual(0, len(pool.dead_count))
def test_connection_is_forcibly_resurrected_when_no_live_ones_are_availible(
def test_connection_is_forcibly_resurrected_when_no_live_ones_are_available(
self,
) -> None:
pool = ConnectionPool([(x, {}) for x in range(2)])
@@ -136,7 +136,12 @@ class TestParallelBulk(TestCase):
class TestChunkActions(TestCase):
def setup_method(self, _: Any) -> None:
    """
    Build the 100 (action, document) pairs stored on ``self.actions``.

    Every pair gets its own ``{"index": {}}`` action dict and a document
    carrying its position in the ``"i"`` field.
    """
    documents: Any = []
    for number in range(100):
        documents.append(({"index": {}}, {"some": "datá", "i": number}))
    self.actions: Any = documents
def test_expand_action(self) -> None:
self.assertEqual(helpers.expand_action({}), ({"index": {}}, {}))
@@ -280,7 +285,9 @@ class TestScanFunction(TestCase):
def test_scan_with_missing_hits_key(
self, mock_search: Mock, mock_scroll: Mock, mock_clear_scroll: Mock
) -> None:
# Simulate a response where the 'hits' key is missing
"""
Simulate a response where the 'hits' key is missing
"""
mock_search.return_value = {"_scroll_id": "dummy_scroll_id", "_shards": {}}
mock_scroll.side_effect = [{"_scroll_id": "dummy_scroll_id", "_shards": {}}]
@@ -38,6 +38,11 @@ from opensearchpy.helpers.response.aggs import AggResponse, Bucket, BucketData
@fixture  # type: ignore
def agg_response(aggs_search: Any, aggs_data: Any) -> Any:
    """
    Wrap the aggregation fixtures in a ``response.Response``.

    :param aggs_search: aggregation search, passed as the first argument
    :param aggs_data: data to aggregate, passed as the second argument
    :return: the constructed ``response.Response``
    """
    wrapped = response.Response(aggs_search, aggs_data)
    return wrapped
+9
View File
@@ -17,6 +17,9 @@ class TestHTTPRequestHandler(BaseHTTPRequestHandler):
__test__ = False
def do_GET(self) -> None: # pylint: disable=invalid-name
"""
writes a response out to a file given mocked parameters on this object
"""
headers = self.headers
if self.path == "/redirect":
@@ -49,6 +52,9 @@ class TestHTTPServer(HTTPServer):
self._server_thread = None
def start(self) -> None:
"""
start the test HTTP server
"""
if self._server_thread is not None:
return
@@ -56,6 +62,9 @@ class TestHTTPServer(HTTPServer):
self._server_thread.start()
def stop(self) -> None:
"""
stop the test HTTP server
"""
if self._server_thread is None:
return
self.socket.close()
@@ -53,12 +53,23 @@ def client() -> Any:
@fixture(scope="session") # type: ignore
def opensearch_version(client: Any) -> Any:
info = client.info()
"""
yields the OpenSearch version reported by the client as integer components
:param client: client to connect to opensearch
"""
info: Any = client.info()
print(info)
yield tuple(
int(x)
for x in re.match(r"^([0-9.]+)", info["version"]["number"]).group(1).split(".") # type: ignore
)
yield (int(x) for x in match_version(info))
def match_version(info: Any) -> Any:
    """
    Extract the dotted version components from an OpenSearch info response.

    :param info: part of the response from OpenSearch; ``info["version"]["number"]``
        is read
    :return: list of the leading numeric components as strings,
        e.g. ``["2", "11", "0"]``
    :raises ValueError: if the version number does not start with digits or dots
    """
    number = info["version"]["number"]
    match = re.match(r"^([0-9.]+)", number)
    if match is None:
        raise ValueError(f"Unrecognised OpenSearch version number: {number!r}")
    # Return (not yield) the components: a ``yield`` here turns the function
    # into a generator producing a single list, so callers iterating the
    # result would receive the whole list instead of each component.
    return match.group(1).split(".")
@fixture # type: ignore
@@ -107,6 +118,7 @@ def pull_request(write_client: Any) -> Any:
@fixture # type: ignore
def setup_ubq_tests(client: Any) -> str:
# TODO: explain what a "ubq" test is (presumably update-by-query; confirm)
index = "test-git"
create_git_index(client, index)
bulk(client, TEST_GIT_DATA, raise_on_error=True, refresh=True)
@@ -509,7 +509,9 @@ class TestScan(OpenSearchTestCase):
}
client_mock.clear_scroll.return_value = {}
data = list(helpers.scan(self.client, index="test_index", **{key: val})) # type: ignore
data = list(
helpers.scan(self.client, index="test_index", **{key: val}) # type: ignore
)
self.assertEqual(data, [{"search_data": 1}])
@@ -89,6 +89,7 @@ SKIP_TESTS = {
"OpenSearch-main/rest-api-spec/src/main/resources/rest-api-spec/test/search/aggregation/20_terms[4]",
"OpenSearch-main/rest-api-spec/src/main/resources/rest-api-spec/test/tasks/list/10_basic[0]",
"OpenSearch-main/rest-api-spec/src/main/resources/rest-api-spec/test/index/90_unsigned_long[1]",
"OpenSearch-main/rest-api-spec/src/main/resources/rest-api-spec/test/indices/stats/50_noop_update[0]",
"search/aggregation/250_moving_fn[1]",
# body: null
"indices/simulate_index_template/10_basic[2]",
@@ -157,7 +158,7 @@ class YamlRunner:
self._teardown_code = test_spec.pop("teardown", None)
def setup(self) -> Any:
# Pull skips from individual tests to not do unnecessary setup.
"""Pull skips from individual tests to not do unnecessary setup."""
skip_code: Any = []
for action in self._run_code:
assert len(action) == 1
@@ -472,7 +473,7 @@ client = get_client()
def load_rest_api_tests() -> None:
# Try loading the REST API test specs from OpenSearch core.
"""Try loading the REST API test specs from OpenSearch core."""
try:
# Construct the HTTP and OpenSearch client
http = urllib3.PoolManager(retries=10)
+36 -15
View File
@@ -45,6 +45,9 @@ TMP_DIR = None
@contextlib.contextmanager # type: ignore
def set_tmp_dir() -> None:
"""
makes and yields a temporary directory for any working files needed for a process during a build
"""
global TMP_DIR
TMP_DIR = tempfile.mkdtemp()
yield TMP_DIR
@@ -53,6 +56,13 @@ def set_tmp_dir() -> None:
def run(*argv: Any, expect_exit_code: int = 0) -> None:
"""
runs a command within this script
:param argv: command to run e.g. "git" "checkout" "--" "setup.py" "opensearchpy/"
:param expect_exit_code: code to compare with actual exit code from command.
will exit the process if they do not
match the proper exit code
"""
global TMP_DIR
if TMP_DIR is None:
os.chdir(BASE_DIR)
@@ -71,6 +81,10 @@ def run(*argv: Any, expect_exit_code: int = 0) -> None:
def test_dist(dist: Any) -> None:
"""
validate that the distribution created works
:param dist: base directory of the distribution
"""
with set_tmp_dir() as tmp_dir: # type: ignore
dist_name = re.match( # type: ignore
r"^(opensearchpy\d*)-",
@@ -181,17 +195,24 @@ def test_dist(dist: Any) -> None:
def main() -> None:
"""
creates a distribution of the OpenSearch Python client
Notes: does not run on MacOS; this script is generally driven by a GitHub Action located in
.github/workflows/unified-release.yml
"""
run("git", "checkout", "--", "setup.py", "opensearchpy/")
run("rm", "-rf", "build/", "dist/*", "*.egg-info", ".eggs")
run("python", "setup.py", "sdist", "bdist_wheel")
# Grab the major version to be used as a suffix.
version_path = os.path.join(BASE_DIR, "opensearchpy/_version.py")
with open(version_path) as f:
data = f.read()
m = re.search(r"^__versionstr__: str\s+=\s+[\"\']([^\"\']+)[\"\']", data, re.M)
if m:
version = m.group(1)
with open(version_path, encoding="utf-8") as file:
data = file.read()
version_match = re.search(
r"^__versionstr__: str\s+=\s+[\"\']([^\"\']+)[\"\']", data, re.M
)
if version_match:
version = version_match.group(1)
else:
raise Exception(f"Invalid version: {data}")
@@ -254,25 +275,25 @@ def main() -> None:
# Ensure that the version within 'opensearchpy/_version.py' is correct.
version_path = os.path.join(BASE_DIR, f"opensearchpy{suffix}/_version.py")
with open(version_path) as f:
version_data = f.read()
with open(version_path, encoding="utf-8") as file:
version_data = file.read()
version_data = re.sub(
r"__versionstr__: str = \"[^\"]+\"",
'__versionstr__: str = "%s"' % version,
version_data,
)
with open(version_path, "w") as f:
f.truncate()
f.write(version_data)
with open(version_path, "w", encoding="utf-8") as file:
file.truncate()
file.write(version_data)
# Rewrite setup.py with the new name.
setup_py_path = os.path.join(BASE_DIR, "setup.py")
with open(setup_py_path) as f:
setup_py = f.read()
with open(setup_py_path, "w") as f:
f.truncate()
with open(setup_py_path, encoding="utf-8") as file:
setup_py = file.read()
with open(setup_py_path, "w", encoding="utf-8") as file:
file.truncate()
assert 'PACKAGE_NAME = "opensearch-py"' in setup_py
f.write(
file.write(
setup_py.replace(
'PACKAGE_NAME = "opensearch-py"',
'PACKAGE_NAME = "opensearch-py%s"' % suffix,
+150 -71
View File
@@ -80,6 +80,10 @@ jinja_env = Environment(
def blacken(filename: Any) -> None:
    """
    runs 'black' https://pypi.org/project/black/ on the given file
    :param filename: file to reformat
    """
    outcome = CliRunner().invoke(black.main, [str(filename)])
    # surface black's own output when it exits with a non-zero code
    assert outcome.exit_code == 0, outcome.output
@@ -87,6 +91,11 @@ def blacken(filename: Any) -> None:
@lru_cache()
def is_valid_url(url: str) -> bool:
    """
    sends a HEAD request to the url; results are cached per url by lru_cache
    :param url: url to check
    :return: True if the status code is in the range [200, 400); False otherwise
    """
    status = http.request("HEAD", url).status
    return 200 <= status < 400
@@ -97,17 +106,24 @@ class Module:
self.parse_orig()
def add(self, api: Any) -> None:
    """
    append an API to this module's list of APIs
    :param api: the API object to append
    """
    registered = self._apis
    registered.append(api)
def parse_orig(self) -> None:
"""
reads the written module and updates with important code specific to this client
"""
self.orders = []
self.header = "from typing import Any, Collection, Optional, Tuple, Union\n\n"
namespace_new = "".join(word.capitalize() for word in self.namespace.split("_"))
self.header += "class " + namespace_new + "Client(NamespacedClient):"
if os.path.exists(self.filepath):
with open(self.filepath) as f:
content = f.read()
with open(self.filepath, encoding="utf-8") as file:
content = file.read()
header_lines = []
for line in content.split("\n"):
header_lines.append(line)
@@ -122,7 +138,7 @@ class Module:
if "security.py" in str(self.filepath):
# TODO: FIXME, import code
header_lines.append(
" from ._patch import health_check, update_audit_config # type: ignore"
" from ._patch import health_check, update_audit_config # type: ignore" # pylint: disable=line-too-long
)
break
self.header = "\n".join(header_lines)
@@ -137,14 +153,21 @@ class Module:
return len(self.orders)
def sort(self) -> None:
    """
    sorts the list of APIs in place, ordered by the self._position key function
    """
    ordering = self._position
    self._apis.sort(key=ordering)
def dump(self) -> None:
"""
writes the module out to disk
"""
self.sort()
# This code snippet adds headers to each generated module indicating that the code is generated.
# The separator is the last line in the "THIS CODE IS AUTOMATICALLY GENERATED" header.
header_separator = "# -----------------------------------------------------------------------------------------+"
# This code snippet adds headers to each generated module indicating
# that the code is generated. The separator is the last line in the
# "THIS CODE IS AUTOMATICALLY GENERATED" header.
header_separator = "# -----------------------------------------------------------------------------------------+" # pylint: disable=line-too-long
license_header_end_1 = "# GitHub history for details."
license_header_end_2 = "# under the License."
@@ -153,8 +176,8 @@ class Module:
# Identifying the insertion point for the "THIS CODE IS AUTOMATICALLY GENERATED" header.
if os.path.exists(self.filepath):
with open(self.filepath, "r") as f:
content = f.read()
with open(self.filepath, "r", encoding="utf-8") as file:
content = file.read()
if header_separator in content:
update_header = False
header_end_position = (
@@ -180,17 +203,18 @@ class Module:
generated_file_header_path = os.path.join(
current_script_folder, "generated_file_headers.txt"
)
with open(generated_file_header_path, "r") as header_file:
with open(generated_file_header_path, "r", encoding="utf-8") as header_file:
header_content = header_file.read()
# Imports are temporarily removed from the header and are regenerated later to ensure imports are updated after code generation.
# Imports are temporarily removed from the header and are regenerated
# later to ensure imports are updated after code generation.
self.header = "\n".join(
line for line in self.header.split("\n") if "from .utils import" not in line
)
with open(self.filepath, "w") as f:
with open(self.filepath, "w", encoding="utf-8") as file:
if update_header is True:
f.write(
file.write(
self.header[:license_position]
+ "\n"
+ header_content
@@ -199,20 +223,20 @@ class Module:
+ self.header[license_position:]
)
else:
f.write(
file.write(
self.header[:header_position]
+ "\n"
+ "#replace_token#\n"
+ self.header[header_position:]
)
for api in self._apis:
f.write(api.to_python())
file.write(api.to_python())
# Generating imports for each module
utils_imports = ""
file_content = ""
with open(self.filepath, "r") as f:
content = f.read()
with open(self.filepath, "r", encoding="utf-8") as file:
content = file.read()
keywords = [
"SKIP_IN_PATH",
"_normalize_hosts",
@@ -232,11 +256,14 @@ class Module:
utils_imports = result
file_content = content.replace("#replace_token#", utils_imports)
with open(self.filepath, "w") as f:
f.write(file_content)
with open(self.filepath, "w", encoding="utf-8") as file:
file.write(file_content)
@property
def filepath(self) -> Any:
    """
    :return: path, under CODE_ROOT, of the async client module for this namespace
    """
    relative = f"opensearchpy/_async/client/{self.namespace}.py"
    return CODE_ROOT / relative
@@ -287,16 +314,20 @@ class API:
@property
def all_parts(self) -> Dict[str, str]:
"""
updates the url parts from the specification
:return: dict of updated parts
"""
parts = {}
for url in self._def["url"]["paths"]:
parts.update(url.get("parts", {}))
for p in parts:
if "required" not in parts[p]:
parts[p]["required"] = all(
p in url.get("parts", {}) for url in self._def["url"]["paths"]
for part in parts:
if "required" not in parts[part]:
parts[part]["required"] = all(
part in url.get("parts", {}) for url in self._def["url"]["paths"]
)
parts[p]["type"] = "Any"
parts[part]["type"] = "Any"
# This piece of logic corresponds to calling
# client.tasks.get() w/o a task_id which was erroneously
@@ -322,6 +353,9 @@ class API:
@property
def params(self) -> Any:
"""
:return: itertools.chain of required parts of the API
"""
parts = self.all_parts
params = self._def.get("params", {})
return chain(
@@ -337,24 +371,30 @@ class API:
@property
def body(self) -> Any:
    """
    :return: the "body" entry of the API spec, or an empty dict when absent;
        a non-empty body gets "required" defaulted to False
    """
    spec = self._def.get("body", {})
    if not spec:
        return spec
    spec.setdefault("required", False)
    return spec
@property
def query_params(self) -> Any:
    """
    :return: generator over the sorted names in the spec's "params" that are
        not also keys of self.all_parts
    """
    spec_params = self._def.get("params", {})
    return (name for name in sorted(spec_params.keys()) if name not in self.all_parts)
@property
def all_func_params(self) -> Any:
"""Parameters that will be in the '@query_params' decorator list
"""
Parameters that will be in the '@query_params' decorator list
and parameters that will be in the function signature.
This doesn't include ... (NOTE(review): this sentence was left unfinished; state what is excluded)
"""
params = list(self._def.get("params", {}).keys())
for url in self._def["url"]["paths"]:
@@ -365,6 +405,9 @@ class API:
@property
def path(self) -> Any:
"""
:return: the entry of url.paths whose path has the most "{...}" placeholders (the first such entry wins ties)
"""
return max(
(path for path in self._def["url"]["paths"]),
key=lambda p: len(re.findall(r"\{([^}]+)\}", p["path"])),
@@ -372,8 +415,12 @@ class API:
@property
def method(self) -> Any:
# To adhere to the HTTP RFC we shouldn't send
# bodies in GET requests.
"""
To adhere to the HTTP RFC we shouldn't send
bodies in GET requests.
:return: an updated HTTP method to use to communicate with the OpenSearch API
"""
default_method = self.path["methods"][0]
if self.name == "refresh" or self.name == "flush":
return "POST"
@@ -385,6 +432,9 @@ class API:
@property
def url_parts(self) -> Any:
"""
:return: tuple of boolean (if the path is dynamic), list of url parts
"""
path = self.path["path"]
dynamic = "{" in path
@@ -406,6 +456,9 @@ class API:
@property
def required_parts(self) -> Any:
"""
:return: list of parts of the url that are required plus the body
"""
parts = self.all_parts
required = [p for p in parts if parts[p]["required"]] # type: ignore
if self.body.get("required"):
@@ -413,12 +466,15 @@ class API:
return required
def to_python(self) -> Any:
"""
:return: rendered Jinja template
"""
try:
t = jinja_env.get_template(f"overrides/{self.namespace}/{self.name}")
template = jinja_env.get_template(f"overrides/{self.namespace}/{self.name}")
except TemplateNotFound:
t = jinja_env.get_template("base")
template = jinja_env.get_template("base")
return t.render(
return template.render(
api=self,
substitutions={v: k for k, v in SUBSTITUTIONS.items()},
global_query_params=GLOBAL_QUERY_PARAMS,
@@ -426,52 +482,61 @@ class API:
def read_modules() -> Any:
"""
checks the opensearch-api spec at
https://raw.githubusercontent.com/opensearch-project/opensearch-api-specification/main/OpenSearch.openapi.json
and parses it into one or more API modules
:return: a dict of API objects
"""
modules = {}
# Load the OpenAPI specification file
response = requests.get(
"https://raw.githubusercontent.com/opensearch-project/opensearch-api-specification/main/OpenSearch.openapi.json"
"https://raw.githubusercontent.com/opensearch-project/opensearch-api-"
"specification/main/OpenSearch.openapi.json"
)
data = response.json()
list_of_dicts = []
for path in data["paths"]:
for x in data["paths"][path]:
if data["paths"][path][x]["x-operation-group"] == "nodes.hot_threads":
if "deprecated" in data["paths"][path][x]:
for param in data["paths"][path]: # pylint: disable=invalid-name
if data["paths"][path][param]["x-operation-group"] == "nodes.hot_threads":
if "deprecated" in data["paths"][path][param]:
continue
data["paths"][path][x].update({"path": path, "method": x})
list_of_dicts.append(data["paths"][path][x])
data["paths"][path][param].update({"path": path, "method": param})
list_of_dicts.append(data["paths"][path][param])
# Update parameters in each endpoint
for p in list_of_dicts:
if "parameters" in p:
for param_dict in list_of_dicts:
if "parameters" in param_dict:
params = []
parts = []
# Iterate over the list of parameters and update them
for x in p["parameters"]:
if "schema" in x and "$ref" in x["schema"]:
schema_path_ref = x["schema"]["$ref"].split("/")[-1]
x["schema"] = data["components"]["schemas"][schema_path_ref]
params.append(x)
for param in param_dict["parameters"]:
if "schema" in param and "$ref" in param["schema"]:
schema_path_ref = param["schema"]["$ref"].split("/")[-1]
param["schema"] = data["components"]["schemas"][schema_path_ref]
params.append(param)
else:
params.append(x)
params.append(param)
# Iterate over the list of updated parameters to separate "parts" from "params"
k = params.copy()
for q in k:
if q["in"] == "path":
parts.append(q)
params.remove(q)
params_copy = params.copy()
for param in params_copy:
if param["in"] == "path":
parts.append(param)
params.remove(param)
# Convert "params" and "parts" into the structure required for generator.
params_new = {}
parts_new = {}
for m in params:
a = dict(type=m["schema"]["type"], description=m["description"])
for m in params: # pylint: disable=invalid-name
a = dict( # pylint: disable=invalid-name
type=m["schema"]["type"], description=m["description"]
) # pylint: disable=invalid-name
if "default" in m["schema"]:
a.update({"default": m["schema"]["default"]})
@@ -488,22 +553,25 @@ def read_modules() -> Any:
params_new.update({m["name"]: a})
# Removing the deprecated "type"
if p["x-operation-group"] != "nodes.hot_threads" and "type" in params_new:
if (
param_dict["x-operation-group"] != "nodes.hot_threads"
and "type" in params_new
):
params_new.pop("type")
if (
p["x-operation-group"] == "cluster.health"
param_dict["x-operation-group"] == "cluster.health"
and "ensure_node_commissioned" in params_new
):
params_new.pop("ensure_node_commissioned")
if bool(params_new):
p.update({"params": params_new})
param_dict.update({"params": params_new})
p.pop("parameters")
param_dict.pop("parameters")
for n in parts:
b = dict(type=n["schema"]["type"])
for n in parts: # pylint: disable=invalid-name
b = dict(type=n["schema"]["type"]) # pylint: disable=invalid-name
if "description" in n:
b.update({"description": n["description"]})
@@ -526,7 +594,7 @@ def read_modules() -> Any:
parts_new.update({n["name"]: b})
if bool(parts_new):
p.update({"parts": parts_new})
param_dict.update({"parts": parts_new})
# Sort the input list by the value of the "x-operation-group" key
list_of_dicts = sorted(list_of_dicts, key=itemgetter("x-operation-group"))
@@ -550,7 +618,7 @@ def read_modules() -> Any:
# Extract the HTTP methods from the data in the current subgroup
methods = []
parts_final = {}
for z in value2:
for z in value2: # pylint: disable=invalid-name
methods.append(z["method"].upper())
# Update 'api' dictionary
@@ -570,9 +638,9 @@ def read_modules() -> Any:
body = {"required": False}
if "required" in z["requestBody"]:
body.update({"required": z["requestBody"]["required"]})
q = z["requestBody"]["content"]["application/json"]["schema"][
"$ref"
].split("/")[-1]
q = z["requestBody"]["content"][ # pylint: disable=invalid-name
"application/json"
]["schema"]["$ref"].split("/")[-1]
if "description" in data["components"]["schemas"][q]:
body.update(
{
@@ -644,17 +712,28 @@ def read_modules() -> Any:
def apply_patch(namespace: str, name: str, api: Any) -> Any:
    """
    merges the overrides from utils/templates/overrides/{namespace}/{name}.json
    into the api, when that file exists
    :param namespace: sub-directory of the overrides folder to look in
    :param name: base name of the override file, without the ".json" suffix
    :param api: the api definition to override
    :return: the merge result, or the api unchanged when there is no override file
    """
    override_file_path = (
        CODE_ROOT / "utils/templates/overrides" / namespace / f"{name}.json"
    )
    if not os.path.exists(override_file_path):
        return api

    with open(override_file_path, encoding="utf-8") as file:
        override_json = json.load(file)
    return deepmerge.always_merger.merge(api, override_json)
def dump_modules(modules: Any) -> None:
"""
writes out modules to disk
:param modules: a mapping whose values are the module objects to dump
"""
for mod in modules.values():
mod.dump()
+22 -7
View File
@@ -47,11 +47,16 @@ def find_files_to_fix(sources: List[str]) -> Iterator[str]:
def does_file_need_fix(filepath: str) -> bool:
"""
checks if the correct license header exists at the top of the file
:param filepath: an absolute or relative filepath to a file to check
:return: True if the file needs a header, False if it does not
"""
if not re.search(r"\.py$", filepath):
return False
existing_header = ""
with open(filepath, mode="r") as f:
for line in f:
with open(filepath, mode="r", encoding="utf-8") as file:
for line in file:
line = line.strip()
if len(line) == 0 or line in LINES_TO_KEEP:
pass
@@ -64,20 +69,30 @@ def does_file_need_fix(filepath: str) -> bool:
def add_header_to_file(filepath: str) -> None:
    """
    writes the license header into a file, after any leading lines that
    appear in LINES_TO_KEEP
    :param filepath: relative or absolute filepath to update
    """
    with open(filepath, mode="r", encoding="utf-8") as file:
        lines = list(file)

    # Default to appending: when every line is one of LINES_TO_KEEP the loop
    # never breaks, and the header must go after those lines rather than in
    # front of the last one.
    insert_at = len(lines)
    for index, line in enumerate(lines):
        if len(line) > 0 and line not in LINES_TO_KEEP:
            insert_at = index
            break

    lines[insert_at:insert_at] = [LICENSE_HEADER]

    # mode "w" already truncates the file, so no explicit truncate() is needed
    with open(filepath, mode="w", encoding="utf-8") as file:
        file.write("".join(lines))
    print(f"Fixed {os.path.relpath(filepath, os.getcwd())}")
def main() -> None:
"""
arguments:
fix: find all files without license headers and insert headers at the top of the file
check: prints a list of files without license headers
list of one or more directories: search in these directories
"""
mode = sys.argv[1]
assert mode in ("fix", "check")
sources = [os.path.abspath(x) for x in sys.argv[2:]]