Pylint integration updates (#643)

* updated files with docstrings to pass pylint

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* updated samples to prepare for enabling missing-docstring linter; will continue to work on this before committing setup.cfg

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* removed missing-function-docstring from setup.cfg so the linter doesn't fail while work on docstrings continues

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* corrected unnecessary return docstring values

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* fixing failure in 'black' on reformatting

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* updated utils to pass missing-function-docstring tests

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* updated functions with missing docstrings or pylint ignore instructions; added a utility to automatically add these ignore instructions to most functions that should be self-describing; rolled back some automatically generated code mistakenly changed

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* ignoring opensearchpy for pylint and then added it back to noxfile.py
* fixed some lints; created a feature flag for newer dynamic pylint so now lints can be fixed first in legacy code and then enabled by multiple people
* extracted a method for per-folder linting
* updated noxfile.lint_per_folder with type hints
* enabled unspecified-encoding in pylint
* added disable missing-function-docstring pragma to test_clients.py in test_async and test_server
* added more encodings to pass unspecified-encoding pylint tests
* updated changelog
Signed-off-by: Mark Cohen <markcoh@amazon.com>

* updated CHANGELOG.md entry
removed the feature flag for pylint lint_per_folder
fixed failures from mypy and pylint
removed pylint MESSAGE CONTROL config from setup.cfg after relocating to lint_per_folder method
Signed-off-by: Mark Cohen <markcoh@amazon.com>

* removed pylint ignore missing-function-docstring

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* added pylint.extensions.docparams plugin

updated some docstrings to correct parameters

removed pylint from setup.cfg

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* added four lints for opensearchpy/

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* adding await back to client.info() call

Signed-off-by: Mark Cohen <markcoh@amazon.com>

* updated TODOs as requested

renamed test_opensearchpy.test_async.test_server.test_helpers.conftest.setup_ubq_tests to setup_update_by_query_tests

added
OpenSearch-main/rest-api-spec/src/main/resources/rest-api-spec/test/indices/stats/50_noop_update[0]
to skip tests list

run_tests.py catches a CalledProcessError when the git repo already exists and the command to add the origin fails in fetch_opensearch_repo()

Signed-off-by: Mark Cohen <markcoh@amazon.com>

---------

Signed-off-by: Mark Cohen <markcoh@amazon.com>
This commit is contained in:
Mark Cohen
2024-01-19 13:36:05 -05:00
committed by GitHub
parent 2ab3a40307
commit 0ddbf8cafa
47 changed files with 563 additions and 218 deletions
+1
View File
@@ -3,6 +3,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
## [Unreleased] ## [Unreleased]
### Added ### Added
- Added pylint `unspecified-encoding` and `missing-function-docstring` and ignored opensearchpy for lints ([#643](https://github.com/opensearch-project/opensearch-py/pull/643))
- Added pylint `line-too-long` and `invalid-name` ([#590](https://github.com/opensearch-project/opensearch-py/pull/590)) - Added pylint `line-too-long` and `invalid-name` ([#590](https://github.com/opensearch-project/opensearch-py/pull/590))
- Added pylint `pointless-statement` ([#611](https://github.com/opensearch-project/opensearch-py/pull/611)) - Added pylint `pointless-statement` ([#611](https://github.com/opensearch-project/opensearch-py/pull/611))
- Added a log collection guide ([#579](https://github.com/opensearch-project/opensearch-py/pull/579)) - Added a log collection guide ([#579](https://github.com/opensearch-project/opensearch-py/pull/579))
+3 -2
View File
@@ -36,8 +36,9 @@ async def index_records(client: Any, index_name: str, item_count: int) -> None:
async def test_async(client_count: int = 1, item_count: int = 1) -> None: async def test_async(client_count: int = 1, item_count: int = 1) -> None:
""" """
asynchronously index with item_count records and run client_count clients. This function can be used to asynchronously index with item_count records and run client_count
test balancing the number of items indexed with the number of documents. clients. This function can be used to test balancing the number of
items indexed with the number of documents.
""" """
host = "localhost" host = "localhost"
port = 9200 port = 9200
+6 -7
View File
@@ -22,13 +22,12 @@ from opensearchpy import OpenSearch
def get_info(client: Any, request_count: int) -> float: def get_info(client: Any, request_count: int) -> float:
"""get info from client""" """get info from client"""
tt: float = 0 total_time: float = 0
for n in range(request_count): for request in range(request_count):
start = time.time() * 1000 start = time.time() * 1000
client.info() client.info()
total_time = time.time() * 1000 - start total_time += time.time() * 1000 - start
tt += total_time return total_time
return tt
def test(thread_count: int = 1, request_count: int = 1, client_count: int = 1) -> None: def test(thread_count: int = 1, request_count: int = 1, client_count: int = 1) -> None:
@@ -71,8 +70,8 @@ def test(thread_count: int = 1, request_count: int = 1, client_count: int = 1) -
thread.start() thread.start()
latency = 0 latency = 0
for t in threads: for thread in threads:
latency += t.join() latency += thread.join()
print(f"latency={latency}") print(f"latency={latency}")
+15 -15
View File
@@ -23,29 +23,29 @@ from opensearchpy import OpenSearch, Urllib3HttpConnection
def index_records(client: Any, index_name: str, item_count: int) -> Any: def index_records(client: Any, index_name: str, item_count: int) -> Any:
"""bulk index item_count records into index_name""" """bulk index item_count records into index_name"""
tt = 0 total_time = 0
for n in range(10): for iteration in range(10):
data: Any = [] data: Any = []
for i in range(item_count): for item in range(item_count):
data.append( data.append(
json.dumps({"index": {"_index": index_name, "_id": str(uuid.uuid4())}}) json.dumps({"index": {"_index": index_name, "_id": str(uuid.uuid4())}})
) )
data.append(json.dumps({"value": i})) data.append(json.dumps({"value": item}))
data = "\n".join(data) data = "\n".join(data)
start = time.time() * 1000 start = time.time() * 1000
rc = client.bulk(data) response = client.bulk(data)
if rc["errors"]: if response["errors"]:
raise Exception(rc["errors"]) raise Exception(response["errors"])
server_time = rc["took"] server_time = response["took"]
total_time = time.time() * 1000 - start this_time = time.time() * 1000 - start
if total_time < server_time: if this_time < server_time:
raise Exception(f"total={total_time} < server={server_time}") raise Exception(f"total={this_time} < server={server_time}")
tt += total_time - server_time total_time += this_time - server_time
return tt return total_time
def test(thread_count: int = 1, item_count: int = 1, client_count: int = 1) -> None: def test(thread_count: int = 1, item_count: int = 1, client_count: int = 1) -> None:
@@ -105,8 +105,8 @@ def test(thread_count: int = 1, item_count: int = 1, client_count: int = 1) -> N
thread.start() thread.start()
latency = 0 latency = 0
for t in threads: for thread in threads:
latency += t.join() latency += thread.join()
clients[0].indices.refresh(index=index_name) clients[0].indices.refresh(index=index_name)
count = clients[0].count(index=index_name) count = clients[0].count(index=index_name)
+80 -2
View File
@@ -25,7 +25,7 @@
# under the License. # under the License.
from typing import Any from typing import Any, List
import nox import nox
@@ -43,6 +43,10 @@ SOURCE_FILES = (
@nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10", "3.11"]) # type: ignore @nox.session(python=["3.6", "3.7", "3.8", "3.9", "3.10", "3.11"]) # type: ignore
def test(session: Any) -> None: def test(session: Any) -> None:
"""
runs all tests with a fresh python environment using "python setup.py test"
:param session: current nox session
"""
session.install(".") session.install(".")
# ensure client can be imported without aiohttp # ensure client can be imported without aiohttp
session.run("python", "-c", "import opensearchpy\nprint(opensearchpy.OpenSearch())") session.run("python", "-c", "import opensearchpy\nprint(opensearchpy.OpenSearch())")
@@ -59,6 +63,10 @@ def test(session: Any) -> None:
@nox.session(python=["3.7"]) # type: ignore @nox.session(python=["3.7"]) # type: ignore
def format(session: Any) -> None: def format(session: Any) -> None:
"""
runs black and isort to format the files accordingly
:param session: current nox session
"""
session.install(".") session.install(".")
session.install("black", "isort") session.install("black", "isort")
@@ -71,6 +79,10 @@ def format(session: Any) -> None:
@nox.session(python=["3.7"]) # type: ignore @nox.session(python=["3.7"]) # type: ignore
def lint(session: Any) -> None: def lint(session: Any) -> None:
"""
runs isort, black, flake8, pylint, and mypy to check the files according to each utility's function
:param session: current nox session
"""
session.install( session.install(
"flake8", "flake8",
"black", "black",
@@ -89,7 +101,9 @@ def lint(session: Any) -> None:
session.run("isort", "--check", *SOURCE_FILES) session.run("isort", "--check", *SOURCE_FILES)
session.run("black", "--check", *SOURCE_FILES) session.run("black", "--check", *SOURCE_FILES)
session.run("flake8", *SOURCE_FILES) session.run("flake8", *SOURCE_FILES)
session.run("pylint", *SOURCE_FILES)
lint_per_folder(session)
session.run("python", "utils/license_headers.py", "check", *SOURCE_FILES) session.run("python", "utils/license_headers.py", "check", *SOURCE_FILES)
# Workaround to make '-r' to still work despite uninstalling aiohttp below. # Workaround to make '-r' to still work despite uninstalling aiohttp below.
@@ -108,8 +122,68 @@ def lint(session: Any) -> None:
session.run("mypy", "--strict", "test_opensearchpy/test_types/sync_types.py") session.run("mypy", "--strict", "test_opensearchpy/test_types/sync_types.py")
def lint_per_folder(session: Any) -> None:
"""
allows configuration of pylint rules per folder and runs a pylint command for each folder
:param session: the current nox session
"""
# any paths that should not be run through pylint
exclude_path_from_linting: List[str] = []
# all paths not referenced in override_enable will run these lints
default_enable = [
"line-too-long",
"invalid-name",
"pointless-statement",
"unspecified-encoding",
"missing-function-docstring",
"missing-param-doc",
"differing-param-doc",
]
override_enable = {
"test_opensearchpy/": [
"line-too-long",
# "invalid-name", lots of short functions with one or two character names
"pointless-statement",
"unspecified-encoding",
"missing-param-doc",
"differing-param-doc",
# "missing-function-docstring", test names are usually self-describing
],
"opensearchpy/": [
"line-too-long",
"invalid-name",
"pointless-statement",
"unspecified-encoding",
],
}
for source_file in SOURCE_FILES:
if source_file in exclude_path_from_linting:
continue
args = [
"--disable=all",
"--max-line-length=240",
"--good-names-rgxs=^[_a-z][_a-z0-9]?$",
"--load-plugins",
"pylint.extensions.docparams",
]
if source_file in override_enable:
args.append(f"--enable={','.join(override_enable[source_file])}")
else:
args.append(f"--enable={','.join(default_enable)}")
args.append(source_file)
session.run("pylint", *args)
@nox.session() # type: ignore @nox.session() # type: ignore
def docs(session: Any) -> None: def docs(session: Any) -> None:
"""
builds the html documentation for the client
:param session: current nox session
"""
session.install(".") session.install(".")
session.install(".[docs]") session.install(".[docs]")
with session.chdir("docs"): with session.chdir("docs"):
@@ -118,6 +192,10 @@ def docs(session: Any) -> None:
@nox.session() # type: ignore @nox.session() # type: ignore
def generate(session: Any) -> None: def generate(session: Any) -> None:
"""
generates the base API code
:param session: current nox session
"""
session.install("-rdev-requirements.txt") session.install("-rdev-requirements.txt")
session.run("python", "utils/generate_api.py") session.run("python", "utils/generate_api.py")
format(session) format(session)
+6
View File
@@ -35,6 +35,12 @@ map = map # pylint: disable=invalid-name
def to_str(x: Union[str, bytes], encoding: str = "ascii") -> str: def to_str(x: Union[str, bytes], encoding: str = "ascii") -> str:
"""
returns x as a string encoded in "encoding" if it is not already a string
:param x: the value to convert to a str
:param encoding: the encoding to convert to - see https://docs.python.org/3/library/codecs.html#standard-encodings
:return: an encoded str
"""
if not isinstance(x, str): if not isinstance(x, str):
return x.decode(encoding) return x.decode(encoding)
return x return x
@@ -19,8 +19,8 @@ from opensearchpy import OpenSearch
def main() -> None: def main() -> None:
""" """
demonstrates various functions to operate on the index (e.g. clear different levels of cache, refreshing the demonstrates various functions to operate on the index
index) (e.g. clear different levels of cache, refreshing the index)
""" """
# Set up # Set up
client = OpenSearch( client = OpenSearch(
+6 -3
View File
@@ -21,10 +21,13 @@ from opensearchpy import OpenSearch, RequestsAWSV4SignerAuth, RequestsHttpConnec
def main() -> None: def main() -> None:
""" """
connects to a cluster specified in environment variables, creates an index, inserts documents, connects to a cluster specified in environment variables,
creates an index, inserts documents,
searches the index, deletes the document, deletes the index. searches the index, deletes the document, deletes the index.
the environment variables are "ENDPOINT" for the cluster endpoint, AWS_REGION for the region in which the cluster the environment variables are "ENDPOINT" for the cluster
is hosted, and SERVICE to indicate if this is an ES 7.10.2 compatible cluster endpoint, AWS_REGION for the region in which the cluster
is hosted, and SERVICE to indicate if this is an ES 7.10.2
compatible cluster
""" """
# verbose logging # verbose logging
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO) logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+4 -3
View File
@@ -21,9 +21,10 @@ from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnecti
def main() -> None: def main() -> None:
""" """
1. connects to an OpenSearch cluster on AWS defined by environment variables (i.e. ENDPOINT - cluster endpoint like 1. connects to an OpenSearch cluster on AWS defined by environment variables
my-test-domain.us-east-1.es.amazonaws.com; AWS_REGION like us-east-1, us-west-2; and SERVICE like es which (i.e. ENDPOINT - cluster endpoint like my-test-domain.us-east-1.es.
differentiates beteween serverless and the managed service. amazonaws.com; AWS_REGION like us-east-1, us-west-2; and SERVICE like es which
differentiates between serverless and the managed service.
2. creates an index called "movies" and adds a single document 2. creates an index called "movies" and adds a single document
3. queries for that document 3. queries for that document
4. deletes the document 4. deletes the document
+1 -1
View File
@@ -53,7 +53,7 @@ def main() -> None:
data.append({"index": {"_index": index_name, "_id": i}}) data.append({"index": {"_index": index_name, "_id": i}})
data.append({"value": i}) data.append({"value": i})
rc = client.bulk(data) rc = client.bulk(data) # pylint: disable=invalid-name
if rc["errors"]: if rc["errors"]:
print("There were errors:") print("There were errors:")
for item in rc["items"]: for item in rc["items"]:
+3 -3
View File
@@ -18,8 +18,8 @@ from opensearchpy import OpenSearch, helpers
def main() -> None: def main() -> None:
""" """
demonstrates how to bulk load data using opensearchpy.helpers including examples of serial, parallel, and streaming demonstrates how to bulk load data using opensearchpy.helpers
bulk load including examples of serial, parallel, and streaming bulk load
""" """
# connect to an instance of OpenSearch # connect to an instance of OpenSearch
@@ -56,7 +56,7 @@ def main() -> None:
data.append({"_index": index_name, "_id": i, "value": i}) data.append({"_index": index_name, "_id": i, "value": i})
# serialized bulk raising an exception on error # serialized bulk raising an exception on error
rc = helpers.bulk(client, data) rc = helpers.bulk(client, data) # pylint: disable=invalid-name
print(f"Bulk-inserted {rc[0]} items (bulk).") print(f"Bulk-inserted {rc[0]} items (bulk).")
# parallel bulk with explicit error checking # parallel bulk with explicit error checking
+1 -1
View File
@@ -55,7 +55,7 @@ def main() -> None:
data += json.dumps({"index": {"_index": index_name, "_id": i}}) + "\n" data += json.dumps({"index": {"_index": index_name, "_id": i}}) + "\n"
data += json.dumps({"value": i}) + "\n" data += json.dumps({"value": i}) + "\n"
rc = client.bulk(data) rc = client.bulk(data) # pylint: disable=invalid-name
if rc["errors"]: if rc["errors"]:
print("There were errors:") print("There were errors:")
for item in rc["items"]: for item in rc["items"]:
@@ -18,7 +18,8 @@ from opensearchpy import OpenSearch
def main() -> None: def main() -> None:
""" """
provides samples for different ways to handle documents including indexing, searching, updating, and deleting provides samples for different ways to handle documents
including indexing, searching, updating, and deleting
""" """
# Connect to OpenSearch # Connect to OpenSearch
client = OpenSearch( client = OpenSearch(
+10 -7
View File
@@ -17,8 +17,9 @@ from opensearchpy import OpenSearch
def main() -> None: def main() -> None:
""" """
an example showing how to create an synchronous connection to OpenSearch, create an index, index a document an example showing how to create an synchronous connection to
and search to return the document OpenSearch, create an index, index a document and search to
return the document
""" """
host = "localhost" host = "localhost"
port = 9200 port = 9200
@@ -49,19 +50,21 @@ def main() -> None:
document = {"title": "Moneyball", "director": "Bennett Miller", "year": "2011"} document = {"title": "Moneyball", "director": "Bennett Miller", "year": "2011"}
id = "1" doc_id = "1"
response = client.index(index=index_name, body=document, id=id, refresh=True) response = client.index(index=index_name, body=document, id=doc_id, refresh=True)
print(response) print(response)
# search for a document # search for a document
q = "miller" user_query = "miller"
query = { query = {
"size": 5, "size": 5,
"query": {"multi_match": {"query": q, "fields": ["title^2", "director"]}}, "query": {
"multi_match": {"query": user_query, "fields": ["title^2", "director"]}
},
} }
response = client.search(body=query, index=index_name) response = client.search(body=query, index=index_name)
@@ -70,7 +73,7 @@ def main() -> None:
# delete the document # delete the document
response = client.delete(index=index_name, id=id) response = client.delete(index=index_name, id=doc_id)
print(response) print(response)
+7 -4
View File
@@ -17,8 +17,9 @@ from opensearchpy import AsyncOpenSearch
async def main() -> None: async def main() -> None:
""" """
an example showing how to create an asynchronous connection to OpenSearch, create an index, index a document an example showing how to create an asynchronous connection
and search to return the document to OpenSearch, create an index, index a document and
search to return the document
""" """
# connect to OpenSearch # connect to OpenSearch
host = "localhost" host = "localhost"
@@ -68,11 +69,13 @@ async def main() -> None:
await client.indices.refresh(index=index_name) await client.indices.refresh(index=index_name)
# search for a document # search for a document
q = "miller" user_query = "miller"
query = { query = {
"size": 5, "size": 5,
"query": {"multi_match": {"query": q, "fields": ["title^2", "director"]}}, "query": {
"multi_match": {"query": user_query, "fields": ["title^2", "director"]}
},
} }
results = await client.search(body=query, index=index_name) results = await client.search(body=query, index=index_name)
@@ -13,6 +13,7 @@ from opensearchpy import OpenSearch
def main() -> None: def main() -> None:
""" """
# pylint: disable=line-too-long
1. connects to an OpenSearch instance running on localhost 1. connects to an OpenSearch instance running on localhost
2. Create an index template named `books` with default settings and mappings for indices of 2. Create an index template named `books` with default settings and mappings for indices of
the `books-*` pattern. You can create an index template to define default settings and mappings for indices the `books-*` pattern. You can create an index template to define default settings and mappings for indices
+7 -5
View File
@@ -46,24 +46,26 @@ def main() -> None:
document = {"title": "Moneyball", "director": "Bennett Miller", "year": "2011"} document = {"title": "Moneyball", "director": "Bennett Miller", "year": "2011"}
id = "1" doc_id = "1"
print(client.http.put(f"/{index_name}/_doc/{id}?refresh=true", body=document)) print(client.http.put(f"/{index_name}/_doc/{doc_id}?refresh=true", body=document))
# search for a document # search for a document
q = "miller" user_query = "miller"
query = { query = {
"size": 5, "size": 5,
"query": {"multi_match": {"query": q, "fields": ["title^2", "director"]}}, "query": {
"multi_match": {"query": user_query, "fields": ["title^2", "director"]}
},
} }
print(client.http.post(f"/{index_name}/_search", body=query)) print(client.http.post(f"/{index_name}/_search", body=query))
# delete the document # delete the document
print(client.http.delete(f"/{index_name}/_doc/{id}")) print(client.http.delete(f"/{index_name}/_doc/{doc_id}"))
# delete the index # delete the index
+9 -6
View File
@@ -17,7 +17,8 @@ from opensearchpy import AsyncOpenSearch
async def main() -> None: async def main() -> None:
""" """
this sample uses asyncio and AsyncOpenSearch to asynchronously connect to local OpenSearch cluster, create an index, this sample uses asyncio and AsyncOpenSearch to asynchronously
connect to local OpenSearch cluster, create an index,
index data, search the index, delete the document, delete the index index data, search the index, delete the document, delete the index
""" """
# connect to OpenSearch # connect to OpenSearch
@@ -51,28 +52,30 @@ async def main() -> None:
document = {"title": "Moneyball", "director": "Bennett Miller", "year": "2011"} document = {"title": "Moneyball", "director": "Bennett Miller", "year": "2011"}
id = "1" doc_id = "1"
print( print(
await client.http.put( await client.http.put(
f"/{index_name}/_doc/{id}?refresh=true", body=document f"/{index_name}/_doc/{doc_id}?refresh=true", body=document
) )
) )
# search for a document # search for a document
q = "miller" user_query = "miller"
query = { query = {
"size": 5, "size": 5,
"query": {"multi_match": {"query": q, "fields": ["title^2", "director"]}}, "query": {
"multi_match": {"query": user_query, "fields": ["title^2", "director"]}
},
} }
print(await client.http.post(f"/{index_name}/_search", body=query)) print(await client.http.post(f"/{index_name}/_search", body=query))
# delete the document # delete the document
print(await client.http.delete(f"/{index_name}/_doc/{id}")) print(await client.http.delete(f"/{index_name}/_doc/{doc_id}"))
# delete the index # delete the index
+6 -5
View File
@@ -24,8 +24,9 @@ urllib3.disable_warnings()
def main() -> None: def main() -> None:
""" """
sample for custom logging; this shows how to create a console handler, connect to OpenSearch, define a custom sample for custom logging; this shows how to create a
logger and log to an OpenSearch index console handler, connect to OpenSearch, define a custom
logger, and log to an OpenSearch index
""" """
print("Collecting logs.") print("Collecting logs.")
@@ -90,9 +91,9 @@ def main() -> None:
index=self._build_index_name(), index=self._build_index_name(),
body=document, body=document,
) )
except Exception as e: except Exception as ex:
print(f"Failed to send log to OpenSearch: {e}") print(f"Failed to send log to OpenSearch: {ex}")
logging.warning(f"Failed to send log to OpenSearch: {e}") logging.warning(f"Failed to send log to OpenSearch: {ex}")
raise raise
print("Creating an instance of OpenSearchHandler and adding it to the logger...") print("Creating an instance of OpenSearchHandler and adding it to the logger...")
-6
View File
@@ -22,10 +22,4 @@ target-version = 'py33'
[mypy] [mypy]
ignore_missing_imports=True ignore_missing_imports=True
[pylint]
max-line-length = 240
good-names-rgxs = ^[_a-z][_a-z0-9]?$ # allow for 1-character variable names
[pylint.MESSAGE CONTROL]
disable = all
enable = line-too-long, invalid-name, pointless-statement
+4 -2
View File
@@ -34,7 +34,9 @@ PACKAGE_NAME = "opensearch-py"
PACKAGE_VERSION = "" PACKAGE_VERSION = ""
BASE_DIR = abspath(dirname(__file__)) BASE_DIR = abspath(dirname(__file__))
with open(join(BASE_DIR, PACKAGE_NAME.replace("-", ""), "_version.py")) as f: with open(
join(BASE_DIR, PACKAGE_NAME.replace("-", ""), "_version.py"), encoding="utf-8"
) as f:
data = f.read() data = f.read()
m = re.search(r"^__versionstr__: str\s+=\s+[\"\']([^\"\']+)[\"\']", data, re.M) m = re.search(r"^__versionstr__: str\s+=\s+[\"\']([^\"\']+)[\"\']", data, re.M)
if m: if m:
@@ -42,7 +44,7 @@ with open(join(BASE_DIR, PACKAGE_NAME.replace("-", ""), "_version.py")) as f:
else: else:
raise Exception(f"Invalid version: {data}") raise Exception(f"Invalid version: {data}")
with open(join(BASE_DIR, "README.md")) as f: with open(join(BASE_DIR, "README.md"), encoding="utf-8") as f:
long_description = f.read().strip() long_description = f.read().strip()
MODULE_DIR = PACKAGE_NAME.replace("-", "") MODULE_DIR = PACKAGE_NAME.replace("-", "")
+31 -6
View File
@@ -37,10 +37,16 @@ import subprocess
import sys import sys
from os import environ from os import environ
from os.path import abspath, dirname, exists, join, pardir from os.path import abspath, dirname, exists, join, pardir
from subprocess import CalledProcessError
from typing import Any from typing import Any
def fetch_opensearch_repo() -> None: def fetch_opensearch_repo() -> None:
"""
runs a git fetch origin on configured opensearch core repo
:return: None if environmental variables TEST_OPENSEARCH_YAML_DIR
is set or TEST_OPENSEARCH_NOFETCH is set to False; else returns nothing
"""
# user is manually setting YAML dir, don't tamper with it # user is manually setting YAML dir, don't tamper with it
if "TEST_OPENSEARCH_YAML_DIR" in environ: if "TEST_OPENSEARCH_YAML_DIR" in environ:
return return
@@ -77,12 +83,20 @@ def fetch_opensearch_repo() -> None:
# make a new blank repository in the test directory # make a new blank repository in the test directory
subprocess.check_call("cd %s && git init" % repo_path, shell=True) subprocess.check_call("cd %s && git init" % repo_path, shell=True)
# add a remote try:
subprocess.check_call( # add a remote
"cd %s && git remote add origin https://github.com/opensearch-project/opensearch.git" subprocess.check_call(
% repo_path, "cd %s && git remote add origin https://github.com/opensearch-project/opensearch.git"
shell=True, % repo_path,
) shell=True,
)
except CalledProcessError as e:
# if the run is interrupted from a previous run, it doesn't clean up, and the git add origin command
# errors out; this allows the test to continue
remote_origin_already_exists = 3
print(e)
if e.returncode != remote_origin_already_exists:
sys.exit(1)
# fetch the sha commit, version from info() # fetch the sha commit, version from info()
print("Fetching opensearch repo...") print("Fetching opensearch repo...")
@@ -90,6 +104,17 @@ def fetch_opensearch_repo() -> None:
def run_all(argv: Any = None) -> None: def run_all(argv: Any = None) -> None:
"""
run all the tests given arguments and environment variables
- sets defaults if argv is None, running "pytest --cov=opensearchpy
--junitxml=<path to opensearch-py-junit.xml>
--log-level=DEBUG --cache-clear -vv --cov-report=<path to output code coverage>"
* GITHUB_ACTION: fetches yaml tests if this is not in environment variables
* TEST_PATTERN: specify a test to run
* TEST_TYPE: "server" runs on TLS connection; None is unencrypted
* OPENSEARCH_VERSION: "SNAPSHOT" does not do anything with plugins
:param argv: if this is None, then the default arguments are used
"""
sys.exitfunc = lambda: sys.stderr.write("Shutting down....\n") # type: ignore sys.exitfunc = lambda: sys.stderr.write("Shutting down....\n") # type: ignore
# fetch yaml tests anywhere that's not GitHub Actions # fetch yaml tests anywhere that's not GitHub Actions
if "GITHUB_ACTION" not in environ: if "GITHUB_ACTION" not in environ:
@@ -88,7 +88,8 @@ class TestAIOHttpConnection:
# it means SSLContext is not available for that version of python # it means SSLContext is not available for that version of python
# and we should skip this test. # and we should skip this test.
pytest.skip( pytest.skip(
"Test test_ssl_context is skipped cause SSLContext is not available for this version of Python" "Test test_ssl_context is skipped cause SSLContext is "
"not available for this version of Python"
) )
con = AIOHttpConnection(use_ssl=True, ssl_context=context) con = AIOHttpConnection(use_ssl=True, ssl_context=context)
@@ -202,8 +203,8 @@ class TestAIOHttpConnection:
con = AIOHttpConnection(use_ssl=True, verify_certs=False) con = AIOHttpConnection(use_ssl=True, verify_certs=False)
assert 1 == len(w) assert 1 == len(w)
assert ( assert (
"Connecting to https://localhost:9200 using SSL with verify_certs=False is insecure." "Connecting to https://localhost:9200 using SSL with "
== str(w[0].message) "verify_certs=False is insecure." == str(w[0].message)
) )
assert con.use_ssl assert con.use_ssl
@@ -379,13 +380,17 @@ class TestConnectionHttpServer:
@classmethod @classmethod
def setup_class(cls) -> None: def setup_class(cls) -> None:
# Start server """
Start server
"""
cls.server = TestHTTPServer(port=8081) cls.server = TestHTTPServer(port=8081)
cls.server.start() cls.server.start()
@classmethod @classmethod
def teardown_class(cls) -> None: def teardown_class(cls) -> None:
# Stop server """
stop server
"""
cls.server.stop() cls.server.stop()
async def httpserver(self, conn: Any, **kwargs: Any) -> Any: async def httpserver(self, conn: Any, **kwargs: Any) -> Any:
@@ -22,6 +22,10 @@ pytestmark: MarkDecorator = pytest.mark.asyncio
@fixture # type: ignore @fixture # type: ignore
async def mock_client(dummy_response: Any) -> Any: async def mock_client(dummy_response: Any) -> Any:
"""
yields a mock client with the dummy_response param
:param dummy_response: any kind of response for test
"""
client = Mock() client = Mock()
client.search.return_value = dummy_response client.search.return_value = dummy_response
await add_connection("mock", client) await add_connection("mock", client)
@@ -26,5 +26,6 @@ class TestPluginsClient:
client.plugins.__init__(client) # type: ignore client.plugins.__init__(client) # type: ignore
assert ( assert (
str(w[0].message) str(w[0].message)
== "Cannot load `alerting` directly to AsyncOpenSearch as it already exists. Use `AsyncOpenSearch.plugin.alerting` instead." == "Cannot load `alerting` directly to AsyncOpenSearch as it already exists. Use "
"`AsyncOpenSearch.plugin.alerting` instead."
) )
@@ -34,13 +34,19 @@ from ...utils import wipe_cluster
class AsyncOpenSearchTestCase(IsolatedAsyncioTestCase): # type: ignore class AsyncOpenSearchTestCase(IsolatedAsyncioTestCase): # type: ignore
async def asyncSetUp(self) -> None: # pylint: disable=invalid-name async def asyncSetUp(
self,
) -> None:
# pylint: disable=invalid-name,missing-function-docstring
self.client = await get_test_client( self.client = await get_test_client(
verify_certs=False, http_auth=("admin", "admin") verify_certs=False, http_auth=("admin", "admin")
) )
await add_connection("default", self.client) await add_connection("default", self.client)
async def asyncTearDown(self) -> None: # pylint: disable=invalid-name async def asyncTearDown(
self,
) -> None:
# pylint: disable=invalid-name,missing-function-docstring
wipe_cluster(self.client) wipe_cluster(self.client)
if self.client: if self.client:
await self.client.close() await self.client.close()
@@ -60,7 +60,9 @@ class TestYarlMissing:
async def test_aiohttp_connection_works_without_yarl( async def test_aiohttp_connection_works_without_yarl(
self, async_client: Any, monkeypatch: Any self, async_client: Any, monkeypatch: Any
) -> None: ) -> None:
# This is a defensive test case for if aiohttp suddenly stops using yarl. """
This is a defensive test case for if aiohttp suddenly stops using yarl.
"""
from opensearchpy._async import http_aiohttp from opensearchpy._async import http_aiohttp
monkeypatch.setattr(http_aiohttp, "yarl", False) monkeypatch.setattr(http_aiohttp, "yarl", False)
@@ -43,12 +43,24 @@ async def client() -> Any:
@fixture(scope="function") # type: ignore @fixture(scope="function") # type: ignore
async def opensearch_version(client: Any) -> Any: async def opensearch_version(client: Any) -> Any:
"""
yields the version of the OpenSearch cluster
:param client: client connection to OpenSearch
:return: yields major version number
"""
info = await client.info() info = await client.info()
print(info) print(info)
yield tuple( yield (int(x) async for x in match_version(info))
int(x)
for x in re.match(r"^([0-9.]+)", info["version"]["number"]).group(1).split(".") # type: ignore
) async def match_version(info: Any) -> Any:
"""
matches the full semver server version with the given info
:param info: response from the OpenSearch cluster
"""
match = re.match(r"^([0-9.]+)", info["version"]["number"])
assert match is not None
yield match.group(1).split(".")
@fixture # type: ignore @fixture # type: ignore
@@ -60,7 +72,9 @@ async def write_client(client: Any) -> Any:
@fixture # type: ignore @fixture # type: ignore
async def data_client(client: Any) -> Any: async def data_client(client: Any) -> Any:
# create mappings """
create mappings
"""
await create_git_index(client, "git") await create_git_index(client, "git")
await create_flat_git_index(client, "flat-git") await create_flat_git_index(client, "flat-git")
# load data # load data
@@ -73,6 +87,11 @@ async def data_client(client: Any) -> Any:
@fixture # type: ignore @fixture # type: ignore
async def pull_request(write_client: Any) -> Any: async def pull_request(write_client: Any) -> Any:
"""
create dummy pull request instance
:param write_client: #todo not used
:return: instance of PullRequest
"""
await PullRequest.init() await PullRequest.init()
pr = PullRequest( pr = PullRequest(
_id=42, _id=42,
@@ -96,7 +115,12 @@ async def pull_request(write_client: Any) -> Any:
@fixture # type: ignore @fixture # type: ignore
async def setup_ubq_tests(client: Any) -> str: async def setup_update_by_query_tests(client: Any) -> str:
"""
sets up update by query tests
:param client:
:return: an index name
"""
index = "test-git" index = "test-git"
await create_git_index(client, index) await create_git_index(client, index)
await async_bulk(client, TEST_GIT_DATA, raise_on_error=True, refresh=True) await async_bulk(client, TEST_GIT_DATA, raise_on_error=True, refresh=True)
@@ -60,6 +60,9 @@ class FailingBulkClient(object):
self._fail_with = fail_with self._fail_with = fail_with
async def bulk(self, *args: Any, **kwargs: Any) -> Any: async def bulk(self, *args: Any, **kwargs: Any) -> Any:
"""
increments the number of times called and, when that count is in self._fail_at, raises self._fail_with
"""
self._called += 1 self._called += 1
if self._called in self._fail_at: if self._called in self._fail_at:
raise self._fail_with raise self._fail_with
@@ -56,6 +56,10 @@ class MetricSearch(AsyncFacetedSearch):
@pytest.fixture(scope="function") # type: ignore @pytest.fixture(scope="function") # type: ignore
def commit_search_cls(opensearch_version: Any) -> Any: def commit_search_cls(opensearch_version: Any) -> Any:
"""
:param opensearch_version: the semver version of OpenSearch
:return: an AsyncFacetedSearch for git commits
"""
interval_kwargs = {"fixed_interval": "1d"} interval_kwargs = {"fixed_interval": "1d"}
class CommitSearch(AsyncFacetedSearch): class CommitSearch(AsyncFacetedSearch):
@@ -102,6 +106,10 @@ def repo_search_cls(opensearch_version: Any) -> Any:
@pytest.fixture(scope="function") # type: ignore @pytest.fixture(scope="function") # type: ignore
def pr_search_cls(opensearch_version: Any) -> Any: def pr_search_cls(opensearch_version: Any) -> Any:
"""
:param opensearch_version: not used here... #TODO remove this parameter?
:return: an AsyncFacetedSearch for pull requests
"""
interval_type = "calendar_interval" interval_type = "calendar_interval"
class PRSearch(AsyncFacetedSearch): class PRSearch(AsyncFacetedSearch):
@@ -19,9 +19,9 @@ pytestmark: MarkDecorator = pytest.mark.asyncio
async def test_update_by_query_no_script( async def test_update_by_query_no_script(
write_client: Any, setup_ubq_tests: Any write_client: Any, setup_update_by_query_tests: Any
) -> None: ) -> None:
index = setup_ubq_tests index = setup_update_by_query_tests
ubq = ( ubq = (
AsyncUpdateByQuery(using=write_client) AsyncUpdateByQuery(using=write_client)
@@ -40,9 +40,9 @@ async def test_update_by_query_no_script(
async def test_update_by_query_with_script( async def test_update_by_query_with_script(
write_client: Any, setup_ubq_tests: Any write_client: Any, setup_update_by_query_tests: Any
) -> None: ) -> None:
index = setup_ubq_tests index = setup_update_by_query_tests
ubq = ( ubq = (
AsyncUpdateByQuery(using=write_client) AsyncUpdateByQuery(using=write_client)
@@ -59,9 +59,9 @@ async def test_update_by_query_with_script(
async def test_delete_by_query_with_script( async def test_delete_by_query_with_script(
write_client: Any, setup_ubq_tests: Any write_client: Any, setup_update_by_query_tests: Any
) -> None: ) -> None:
index = setup_ubq_tests index = setup_update_by_query_tests
ubq = ( ubq = (
AsyncUpdateByQuery(using=write_client) AsyncUpdateByQuery(using=write_client)
@@ -54,6 +54,10 @@ OPENSEARCH_VERSION = None
async def await_if_coro(x: Any) -> Any: async def await_if_coro(x: Any) -> Any:
"""
awaits if x is a coroutine
:return: x
"""
if inspect.iscoroutine(x): if inspect.iscoroutine(x):
return await x return await x
return x return x
@@ -40,13 +40,15 @@ class TestSecurityPlugin(IsolatedAsyncioTestCase): # type: ignore
USER_NAME = "test-user" USER_NAME = "test-user"
USER_CONTENT = {"password": "opensearchpy@123", "opendistro_security_roles": []} USER_CONTENT = {"password": "opensearchpy@123", "opendistro_security_roles": []}
async def asyncSetUp(self) -> None: # pylint: disable=invalid-name async def asyncSetUp(self) -> None:
# pylint: disable=invalid-name, missing-function-docstring
self.client = await get_test_client( self.client = await get_test_client(
verify_certs=False, http_auth=("admin", "admin") verify_certs=False, http_auth=("admin", "admin")
) )
await add_connection("default", self.client) await add_connection("default", self.client)
async def asyncTearDown(self) -> None: # pylint: disable=invalid-name async def asyncTearDown(self) -> None:
# pylint: disable=invalid-name
if self.client: if self.client:
await self.client.close() await self.client.close()
@@ -449,8 +449,10 @@ class TestTransport:
assert event_loop.time() - 1 < t.last_sniff < event_loop.time() + 0.01 assert event_loop.time() - 1 < t.last_sniff < event_loop.time() + 0.01
async def test_sniff_7x_publish_host(self) -> None: async def test_sniff_7x_publish_host(self) -> None:
# Test the response shaped when a 7.x node has publish_host set """
# and the returend data is shaped in the fqdn/ip:port format. Test the response shaped when a 7.x node has publish_host set
and the returned data is shaped in the fqdn/ip:port format.
"""
t: Any = AsyncTransport( t: Any = AsyncTransport(
[{"data": CLUSTER_NODES_7X_PUBLISH_HOST}], [{"data": CLUSTER_NODES_7X_PUBLISH_HOST}],
connection_class=DummyConnection, connection_class=DummyConnection,
@@ -20,5 +20,6 @@ class TestPluginsClient(TestCase):
client.plugins.__init__(client) # type: ignore client.plugins.__init__(client) # type: ignore
self.assertEqual( self.assertEqual(
str(w.warnings[0].message), str(w.warnings[0].message),
"Cannot load `alerting` directly to OpenSearch as it already exists. Use `OpenSearch.plugin.alerting` instead.", "Cannot load `alerting` directly to OpenSearch as "
"it already exists. Use `OpenSearch.plugin.alerting` instead.",
) )
@@ -139,7 +139,8 @@ class TestRequestsHttpConnection(TestCase):
) )
self.assertEqual(1, len(w)) self.assertEqual(1, len(w))
self.assertEqual( self.assertEqual(
"Connecting to https://localhost:9200 using SSL with verify_certs=False is insecure.", "Connecting to https://localhost:9200 using SSL with "
"verify_certs=False is insecure.",
str(w[0].message), str(w[0].message),
) )
@@ -286,8 +287,9 @@ class TestRequestsHttpConnection(TestCase):
# trace request # trace request
self.assertEqual(1, tracer.info.call_count) self.assertEqual(1, tracer.info.call_count)
trace_curl_cmd = "curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/?pretty&param=42' -d '{\n \"question\": \"what\\u0027s that?\"\n}'" # pylint: disable=line-too-long
self.assertEqual( self.assertEqual(
"""curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/?pretty&param=42' -d '{\n "question": "what\\u0027s that?"\n}'""", trace_curl_cmd,
tracer.info.call_args[0][0] % tracer.info.call_args[0][1:], tracer.info.call_args[0][0] % tracer.info.call_args[0][1:],
) )
# trace response # trace response
@@ -415,9 +417,13 @@ class TestRequestsHttpConnection(TestCase):
self.assertEqual('{"answer": 42}'.encode("utf-8"), request.body) self.assertEqual('{"answer": 42}'.encode("utf-8"), request.body)
# trace request # trace request
trace_curl_cmd = (
"curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/_search?pretty' "
"-d '{\n \"answer\": 42\n}'"
)
self.assertEqual(1, tracer.info.call_count) self.assertEqual(1, tracer.info.call_count)
self.assertEqual( self.assertEqual(
"curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/_search?pretty' -d '{\n \"answer\": 42\n}'", trace_curl_cmd,
tracer.info.call_args[0][0] % tracer.info.call_args[0][1:], tracer.info.call_args[0][0] % tracer.info.call_args[0][1:],
) )
@@ -514,7 +520,7 @@ class TestRequestsConnectionRedirect(TestCase):
@classmethod @classmethod
def setup_class(cls) -> None: def setup_class(cls) -> None:
# Start servers """Start servers"""
cls.server1 = TestHTTPServer(port=8080) cls.server1 = TestHTTPServer(port=8080)
cls.server1.start() cls.server1.start()
cls.server2 = TestHTTPServer(port=8090) cls.server2 = TestHTTPServer(port=8090)
@@ -522,7 +528,7 @@ class TestRequestsConnectionRedirect(TestCase):
@classmethod @classmethod
def teardown_class(cls) -> None: def teardown_class(cls) -> None:
# Stop servers """Stop servers"""
cls.server2.stop() cls.server2.stop()
cls.server1.stop() cls.server1.stop()
@@ -73,7 +73,8 @@ class TestUrllib3HttpConnection(TestCase):
# it means SSLContext is not available for that version of python # it means SSLContext is not available for that version of python
# and we should skip this test. # and we should skip this test.
raise SkipTest( raise SkipTest(
"Test test_ssl_context is skipped cause SSLContext is not available for this version of python" "Test test_ssl_context is skipped cause SSLContext"
" is not available for this version of python"
) )
con = Urllib3HttpConnection(use_ssl=True, ssl_context=context) con = Urllib3HttpConnection(use_ssl=True, ssl_context=context)
@@ -272,7 +273,8 @@ class TestUrllib3HttpConnection(TestCase):
con = Urllib3HttpConnection(use_ssl=True, verify_certs=False) con = Urllib3HttpConnection(use_ssl=True, verify_certs=False)
self.assertEqual(1, len(w)) self.assertEqual(1, len(w))
self.assertEqual( self.assertEqual(
"Connecting to https://localhost:9200 using SSL with verify_certs=False is insecure.", "Connecting to https://localhost:9200 using SSL with "
"verify_certs=False is insecure.",
str(w[0].message), str(w[0].message),
) )
+1 -1
View File
@@ -112,7 +112,7 @@ class TestConnectionPool(TestCase):
# Nothing should be marked dead # Nothing should be marked dead
self.assertEqual(0, len(pool.dead_count)) self.assertEqual(0, len(pool.dead_count))
def test_connection_is_forcibly_resurrected_when_no_live_ones_are_availible( def test_connection_is_forcibly_resurrected_when_no_live_ones_are_available(
self, self,
) -> None: ) -> None:
pool = ConnectionPool([(x, {}) for x in range(2)]) pool = ConnectionPool([(x, {}) for x in range(2)])
@@ -136,7 +136,12 @@ class TestParallelBulk(TestCase):
class TestChunkActions(TestCase): class TestChunkActions(TestCase):
def setup_method(self, _: Any) -> None: def setup_method(self, _: Any) -> None:
self.actions: Any = [({"index": {}}, {"some": u"datá", "i": i}) for i in range(100)] # fmt: skip """
creates some documents for testing
"""
self.actions: Any = [
({"index": {}}, {"some": "datá", "i": i}) for i in range(100)
]
def test_expand_action(self) -> None: def test_expand_action(self) -> None:
self.assertEqual(helpers.expand_action({}), ({"index": {}}, {})) self.assertEqual(helpers.expand_action({}), ({"index": {}}, {}))
@@ -280,7 +285,9 @@ class TestScanFunction(TestCase):
def test_scan_with_missing_hits_key( def test_scan_with_missing_hits_key(
self, mock_search: Mock, mock_scroll: Mock, mock_clear_scroll: Mock self, mock_search: Mock, mock_scroll: Mock, mock_clear_scroll: Mock
) -> None: ) -> None:
# Simulate a response where the 'hits' key is missing """
Simulate a response where the 'hits' key is missing
"""
mock_search.return_value = {"_scroll_id": "dummy_scroll_id", "_shards": {}} mock_search.return_value = {"_scroll_id": "dummy_scroll_id", "_shards": {}}
mock_scroll.side_effect = [{"_scroll_id": "dummy_scroll_id", "_shards": {}}] mock_scroll.side_effect = [{"_scroll_id": "dummy_scroll_id", "_shards": {}}]
@@ -38,6 +38,11 @@ from opensearchpy.helpers.response.aggs import AggResponse, Bucket, BucketData
@fixture # type: ignore @fixture # type: ignore
def agg_response(aggs_search: Any, aggs_data: Any) -> Any: def agg_response(aggs_search: Any, aggs_data: Any) -> Any:
"""
:param aggs_search: aggregation search
:param aggs_data: data to aggregate
:return: the aggregated data
"""
return response.Response(aggs_search, aggs_data) return response.Response(aggs_search, aggs_data)
+9
View File
@@ -17,6 +17,9 @@ class TestHTTPRequestHandler(BaseHTTPRequestHandler):
__test__ = False __test__ = False
def do_GET(self) -> None: # pylint: disable=invalid-name def do_GET(self) -> None: # pylint: disable=invalid-name
"""
writes a response out to a file given mocked parameters on this object
"""
headers = self.headers headers = self.headers
if self.path == "/redirect": if self.path == "/redirect":
@@ -49,6 +52,9 @@ class TestHTTPServer(HTTPServer):
self._server_thread = None self._server_thread = None
def start(self) -> None: def start(self) -> None:
"""
start the test HTTP server
"""
if self._server_thread is not None: if self._server_thread is not None:
return return
@@ -56,6 +62,9 @@ class TestHTTPServer(HTTPServer):
self._server_thread.start() self._server_thread.start()
def stop(self) -> None: def stop(self) -> None:
"""
stop the test HTTP server
"""
if self._server_thread is None: if self._server_thread is None:
return return
self.socket.close() self.socket.close()
@@ -53,12 +53,23 @@ def client() -> Any:
@fixture(scope="session") # type: ignore @fixture(scope="session") # type: ignore
def opensearch_version(client: Any) -> Any: def opensearch_version(client: Any) -> Any:
info = client.info() """
yields a major version from the client
:param client: client to connect to opensearch
"""
info: Any = client.info()
print(info) print(info)
yield tuple( yield (int(x) for x in match_version(info))
int(x)
for x in re.match(r"^([0-9.]+)", info["version"]["number"]).group(1).split(".") # type: ignore
) def match_version(info: Any) -> Any:
"""
matches the major version from the given client info
:param info: part of the response from OpenSearch
"""
match = re.match(r"^([0-9.]+)", info["version"]["number"])
assert match is not None
yield match.group(1).split(".")
@fixture # type: ignore @fixture # type: ignore
@@ -107,6 +118,7 @@ def pull_request(write_client: Any) -> Any:
@fixture # type: ignore @fixture # type: ignore
def setup_ubq_tests(client: Any) -> str: def setup_ubq_tests(client: Any) -> str:
# todo what's a ubq test?
index = "test-git" index = "test-git"
create_git_index(client, index) create_git_index(client, index)
bulk(client, TEST_GIT_DATA, raise_on_error=True, refresh=True) bulk(client, TEST_GIT_DATA, raise_on_error=True, refresh=True)
@@ -509,7 +509,9 @@ class TestScan(OpenSearchTestCase):
} }
client_mock.clear_scroll.return_value = {} client_mock.clear_scroll.return_value = {}
data = list(helpers.scan(self.client, index="test_index", **{key: val})) # type: ignore data = list(
helpers.scan(self.client, index="test_index", **{key: val}) # type: ignore
)
self.assertEqual(data, [{"search_data": 1}]) self.assertEqual(data, [{"search_data": 1}])
@@ -89,6 +89,7 @@ SKIP_TESTS = {
"OpenSearch-main/rest-api-spec/src/main/resources/rest-api-spec/test/search/aggregation/20_terms[4]", "OpenSearch-main/rest-api-spec/src/main/resources/rest-api-spec/test/search/aggregation/20_terms[4]",
"OpenSearch-main/rest-api-spec/src/main/resources/rest-api-spec/test/tasks/list/10_basic[0]", "OpenSearch-main/rest-api-spec/src/main/resources/rest-api-spec/test/tasks/list/10_basic[0]",
"OpenSearch-main/rest-api-spec/src/main/resources/rest-api-spec/test/index/90_unsigned_long[1]", "OpenSearch-main/rest-api-spec/src/main/resources/rest-api-spec/test/index/90_unsigned_long[1]",
"OpenSearch-main/rest-api-spec/src/main/resources/rest-api-spec/test/indices/stats/50_noop_update[0]",
"search/aggregation/250_moving_fn[1]", "search/aggregation/250_moving_fn[1]",
# body: null # body: null
"indices/simulate_index_template/10_basic[2]", "indices/simulate_index_template/10_basic[2]",
@@ -157,7 +158,7 @@ class YamlRunner:
self._teardown_code = test_spec.pop("teardown", None) self._teardown_code = test_spec.pop("teardown", None)
def setup(self) -> Any: def setup(self) -> Any:
# Pull skips from individual tests to not do unnecessary setup. """Pull skips from individual tests to not do unnecessary setup."""
skip_code: Any = [] skip_code: Any = []
for action in self._run_code: for action in self._run_code:
assert len(action) == 1 assert len(action) == 1
@@ -472,7 +473,7 @@ client = get_client()
def load_rest_api_tests() -> None: def load_rest_api_tests() -> None:
# Try loading the REST API test specs from OpenSearch core. """Try loading the REST API test specs from OpenSearch core."""
try: try:
# Construct the HTTP and OpenSearch client # Construct the HTTP and OpenSearch client
http = urllib3.PoolManager(retries=10) http = urllib3.PoolManager(retries=10)
+36 -15
View File
@@ -45,6 +45,9 @@ TMP_DIR = None
@contextlib.contextmanager # type: ignore @contextlib.contextmanager # type: ignore
def set_tmp_dir() -> None: def set_tmp_dir() -> None:
"""
makes and yields a temporary directory for any working files needed for a process during a build
"""
global TMP_DIR global TMP_DIR
TMP_DIR = tempfile.mkdtemp() TMP_DIR = tempfile.mkdtemp()
yield TMP_DIR yield TMP_DIR
@@ -53,6 +56,13 @@ def set_tmp_dir() -> None:
def run(*argv: Any, expect_exit_code: int = 0) -> None: def run(*argv: Any, expect_exit_code: int = 0) -> None:
"""
runs a command within this script
:param argv: command to run e.g. "git" "checkout" "--" "setup.py" "opensearchpy/"
:param expect_exit_code: code to compare with actual exit code from command.
will exit the process if they do not
match the proper exit code
"""
global TMP_DIR global TMP_DIR
if TMP_DIR is None: if TMP_DIR is None:
os.chdir(BASE_DIR) os.chdir(BASE_DIR)
@@ -71,6 +81,10 @@ def run(*argv: Any, expect_exit_code: int = 0) -> None:
def test_dist(dist: Any) -> None: def test_dist(dist: Any) -> None:
"""
validate that the distribution created works
:param dist: base directory of the distribution
"""
with set_tmp_dir() as tmp_dir: # type: ignore with set_tmp_dir() as tmp_dir: # type: ignore
dist_name = re.match( # type: ignore dist_name = re.match( # type: ignore
r"^(opensearchpy\d*)-", r"^(opensearchpy\d*)-",
@@ -181,17 +195,24 @@ def test_dist(dist: Any) -> None:
def main() -> None: def main() -> None:
"""
creates a distribution of the OpenSearch Python client
Notes: does not run on MacOS; this script is generally driven by a GitHub Action located in
.github/workflows/unified-release.yml
"""
run("git", "checkout", "--", "setup.py", "opensearchpy/") run("git", "checkout", "--", "setup.py", "opensearchpy/")
run("rm", "-rf", "build/", "dist/*", "*.egg-info", ".eggs") run("rm", "-rf", "build/", "dist/*", "*.egg-info", ".eggs")
run("python", "setup.py", "sdist", "bdist_wheel") run("python", "setup.py", "sdist", "bdist_wheel")
# Grab the major version to be used as a suffix. # Grab the major version to be used as a suffix.
version_path = os.path.join(BASE_DIR, "opensearchpy/_version.py") version_path = os.path.join(BASE_DIR, "opensearchpy/_version.py")
with open(version_path) as f: with open(version_path, encoding="utf-8") as file:
data = f.read() data = file.read()
m = re.search(r"^__versionstr__: str\s+=\s+[\"\']([^\"\']+)[\"\']", data, re.M) version_match = re.search(
if m: r"^__versionstr__: str\s+=\s+[\"\']([^\"\']+)[\"\']", data, re.M
version = m.group(1) )
if version_match:
version = version_match.group(1)
else: else:
raise Exception(f"Invalid version: {data}") raise Exception(f"Invalid version: {data}")
@@ -254,25 +275,25 @@ def main() -> None:
# Ensure that the version within 'opensearchpy/_version.py' is correct. # Ensure that the version within 'opensearchpy/_version.py' is correct.
version_path = os.path.join(BASE_DIR, f"opensearchpy{suffix}/_version.py") version_path = os.path.join(BASE_DIR, f"opensearchpy{suffix}/_version.py")
with open(version_path) as f: with open(version_path, encoding="utf-8") as file:
version_data = f.read() version_data = file.read()
version_data = re.sub( version_data = re.sub(
r"__versionstr__: str = \"[^\"]+\"", r"__versionstr__: str = \"[^\"]+\"",
'__versionstr__: str = "%s"' % version, '__versionstr__: str = "%s"' % version,
version_data, version_data,
) )
with open(version_path, "w") as f: with open(version_path, "w", encoding="utf-8") as file:
f.truncate() file.truncate()
f.write(version_data) file.write(version_data)
# Rewrite setup.py with the new name. # Rewrite setup.py with the new name.
setup_py_path = os.path.join(BASE_DIR, "setup.py") setup_py_path = os.path.join(BASE_DIR, "setup.py")
with open(setup_py_path) as f: with open(setup_py_path, encoding="utf-8") as file:
setup_py = f.read() setup_py = file.read()
with open(setup_py_path, "w") as f: with open(setup_py_path, "w", encoding="utf-8") as file:
f.truncate() file.truncate()
assert 'PACKAGE_NAME = "opensearch-py"' in setup_py assert 'PACKAGE_NAME = "opensearch-py"' in setup_py
f.write( file.write(
setup_py.replace( setup_py.replace(
'PACKAGE_NAME = "opensearch-py"', 'PACKAGE_NAME = "opensearch-py"',
'PACKAGE_NAME = "opensearch-py%s"' % suffix, 'PACKAGE_NAME = "opensearch-py%s"' % suffix,
+150 -71
View File
@@ -80,6 +80,10 @@ jinja_env = Environment(
def blacken(filename: Any) -> None: def blacken(filename: Any) -> None:
"""
runs 'black' https://pypi.org/project/black/ on the given file
:param filename: file to reformat
"""
runner = CliRunner() runner = CliRunner()
result = runner.invoke(black.main, [str(filename)]) result = runner.invoke(black.main, [str(filename)])
assert result.exit_code == 0, result.output assert result.exit_code == 0, result.output
@@ -87,6 +91,11 @@ def blacken(filename: Any) -> None:
@lru_cache() @lru_cache()
def is_valid_url(url: str) -> bool: def is_valid_url(url: str) -> bool:
"""
makes a call to the url
:param url: url to check
:return: True if status code is between HTTP 200 inclusive and 400 exclusive; False otherwise
"""
return 200 <= http.request("HEAD", url).status < 400 return 200 <= http.request("HEAD", url).status < 400
@@ -97,17 +106,24 @@ class Module:
self.parse_orig() self.parse_orig()
def add(self, api: Any) -> None: def add(self, api: Any) -> None:
"""
add an API to this module's list of APIs
:param api: an API object
"""
self._apis.append(api) self._apis.append(api)
def parse_orig(self) -> None: def parse_orig(self) -> None:
"""
reads the written module and updates with important code specific to this client
"""
self.orders = [] self.orders = []
self.header = "from typing import Any, Collection, Optional, Tuple, Union\n\n" self.header = "from typing import Any, Collection, Optional, Tuple, Union\n\n"
namespace_new = "".join(word.capitalize() for word in self.namespace.split("_")) namespace_new = "".join(word.capitalize() for word in self.namespace.split("_"))
self.header += "class " + namespace_new + "Client(NamespacedClient):" self.header += "class " + namespace_new + "Client(NamespacedClient):"
if os.path.exists(self.filepath): if os.path.exists(self.filepath):
with open(self.filepath) as f: with open(self.filepath, encoding="utf-8") as file:
content = f.read() content = file.read()
header_lines = [] header_lines = []
for line in content.split("\n"): for line in content.split("\n"):
header_lines.append(line) header_lines.append(line)
@@ -122,7 +138,7 @@ class Module:
if "security.py" in str(self.filepath): if "security.py" in str(self.filepath):
# TODO: FIXME, import code # TODO: FIXME, import code
header_lines.append( header_lines.append(
" from ._patch import health_check, update_audit_config # type: ignore" " from ._patch import health_check, update_audit_config # type: ignore" # pylint: disable=line-too-long
) )
break break
self.header = "\n".join(header_lines) self.header = "\n".join(header_lines)
@@ -137,14 +153,21 @@ class Module:
return len(self.orders) return len(self.orders)
def sort(self) -> None: def sort(self) -> None:
"""
sorts the list of APIs by the Module._position key
"""
self._apis.sort(key=self._position) self._apis.sort(key=self._position)
def dump(self) -> None: def dump(self) -> None:
"""
writes the module out to disk
"""
self.sort() self.sort()
# This code snippet adds headers to each generated module indicating that the code is generated. # This code snippet adds headers to each generated module indicating
# The separator is the last line in the "THIS CODE IS AUTOMATICALLY GENERATED" header. # that the code is generated. The separator is the last line in the
header_separator = "# -----------------------------------------------------------------------------------------+" # "THIS CODE IS AUTOMATICALLY GENERATED" header.
header_separator = "# -----------------------------------------------------------------------------------------+" # pylint: disable=line-too-long
license_header_end_1 = "# GitHub history for details." license_header_end_1 = "# GitHub history for details."
license_header_end_2 = "# under the License." license_header_end_2 = "# under the License."
@@ -153,8 +176,8 @@ class Module:
# Identifying the insertion point for the "THIS CODE IS AUTOMATICALLY GENERATED" header. # Identifying the insertion point for the "THIS CODE IS AUTOMATICALLY GENERATED" header.
if os.path.exists(self.filepath): if os.path.exists(self.filepath):
with open(self.filepath, "r") as f: with open(self.filepath, "r", encoding="utf-8") as file:
content = f.read() content = file.read()
if header_separator in content: if header_separator in content:
update_header = False update_header = False
header_end_position = ( header_end_position = (
@@ -180,17 +203,18 @@ class Module:
generated_file_header_path = os.path.join( generated_file_header_path = os.path.join(
current_script_folder, "generated_file_headers.txt" current_script_folder, "generated_file_headers.txt"
) )
with open(generated_file_header_path, "r") as header_file: with open(generated_file_header_path, "r", encoding="utf-8") as header_file:
header_content = header_file.read() header_content = header_file.read()
# Imports are temporarily removed from the header and are regenerated later to ensure imports are updated after code generation. # Imports are temporarily removed from the header and are regenerated
# later to ensure imports are updated after code generation.
self.header = "\n".join( self.header = "\n".join(
line for line in self.header.split("\n") if "from .utils import" not in line line for line in self.header.split("\n") if "from .utils import" not in line
) )
with open(self.filepath, "w") as f: with open(self.filepath, "w", encoding="utf-8") as file:
if update_header is True: if update_header is True:
f.write( file.write(
self.header[:license_position] self.header[:license_position]
+ "\n" + "\n"
+ header_content + header_content
@@ -199,20 +223,20 @@ class Module:
+ self.header[license_position:] + self.header[license_position:]
) )
else: else:
f.write( file.write(
self.header[:header_position] self.header[:header_position]
+ "\n" + "\n"
+ "#replace_token#\n" + "#replace_token#\n"
+ self.header[header_position:] + self.header[header_position:]
) )
for api in self._apis: for api in self._apis:
f.write(api.to_python()) file.write(api.to_python())
# Generating imports for each module # Generating imports for each module
utils_imports = "" utils_imports = ""
file_content = "" file_content = ""
with open(self.filepath, "r") as f: with open(self.filepath, "r", encoding="utf-8") as file:
content = f.read() content = file.read()
keywords = [ keywords = [
"SKIP_IN_PATH", "SKIP_IN_PATH",
"_normalize_hosts", "_normalize_hosts",
@@ -232,11 +256,14 @@ class Module:
utils_imports = result utils_imports = result
file_content = content.replace("#replace_token#", utils_imports) file_content = content.replace("#replace_token#", utils_imports)
with open(self.filepath, "w") as f: with open(self.filepath, "w", encoding="utf-8") as file:
f.write(file_content) file.write(file_content)
@property @property
def filepath(self) -> Any: def filepath(self) -> Any:
"""
:return: absolute path to the module
"""
return CODE_ROOT / f"opensearchpy/_async/client/{self.namespace}.py" return CODE_ROOT / f"opensearchpy/_async/client/{self.namespace}.py"
@@ -287,16 +314,20 @@ class API:
@property @property
def all_parts(self) -> Dict[str, str]: def all_parts(self) -> Dict[str, str]:
"""
updates the url parts from the specification
:return: dict of updated parts
"""
parts = {} parts = {}
for url in self._def["url"]["paths"]: for url in self._def["url"]["paths"]:
parts.update(url.get("parts", {})) parts.update(url.get("parts", {}))
for p in parts: for part in parts:
if "required" not in parts[p]: if "required" not in parts[part]:
parts[p]["required"] = all( parts[part]["required"] = all(
p in url.get("parts", {}) for url in self._def["url"]["paths"] part in url.get("parts", {}) for url in self._def["url"]["paths"]
) )
parts[p]["type"] = "Any" parts[part]["type"] = "Any"
# This piece of logic corresponds to calling # This piece of logic corresponds to calling
# client.tasks.get() w/o a task_id which was erroneously # client.tasks.get() w/o a task_id which was erroneously
@@ -322,6 +353,9 @@ class API:
@property @property
def params(self) -> Any: def params(self) -> Any:
"""
:return: itertools.chain of required parts of the API
"""
parts = self.all_parts parts = self.all_parts
params = self._def.get("params", {}) params = self._def.get("params", {})
return chain( return chain(
@@ -337,24 +371,30 @@ class API:
@property @property
def body(self) -> Any: def body(self) -> Any:
b = self._def.get("body", {}) """
if b: :return: body of the API spec
b.setdefault("required", False) """
return b body_api_spec = self._def.get("body", {})
if body_api_spec:
body_api_spec.setdefault("required", False)
return body_api_spec
@property @property
def query_params(self) -> Any: def query_params(self) -> Any:
"""
:return: any query string parameters from the specification
"""
return ( return (
k key
for k in sorted(self._def.get("params", {}).keys()) for key in sorted(self._def.get("params", {}).keys())
if k not in self.all_parts if key not in self.all_parts
) )
@property @property
def all_func_params(self) -> Any: def all_func_params(self) -> Any:
"""Parameters that will be in the '@query_params' decorator list """
Parameters that will be in the '@query_params' decorator list
and parameters that will be in the function signature. and parameters that will be in the function signature.
This doesn't include
""" """
params = list(self._def.get("params", {}).keys()) params = list(self._def.get("params", {}).keys())
for url in self._def["url"]["paths"]: for url in self._def["url"]["paths"]:
@@ -365,6 +405,9 @@ class API:
@property @property
def path(self) -> Any: def path(self) -> Any:
"""
:return: the entry in url.paths whose path contains the most "{...}" placeholders
"""
return max( return max(
(path for path in self._def["url"]["paths"]), (path for path in self._def["url"]["paths"]),
key=lambda p: len(re.findall(r"\{([^}]+)\}", p["path"])), key=lambda p: len(re.findall(r"\{([^}]+)\}", p["path"])),
@@ -372,8 +415,12 @@ class API:
@property @property
def method(self) -> Any: def method(self) -> Any:
# To adhere to the HTTP RFC we shouldn't send """
# bodies in GET requests. To adhere to the HTTP RFC we shouldn't send
bodies in GET requests.
:return: an updated HTTP method to use to communicate with the OpenSearch API
"""
default_method = self.path["methods"][0] default_method = self.path["methods"][0]
if self.name == "refresh" or self.name == "flush": if self.name == "refresh" or self.name == "flush":
return "POST" return "POST"
@@ -385,6 +432,9 @@ class API:
@property @property
def url_parts(self) -> Any: def url_parts(self) -> Any:
"""
:return: tuple of boolean (if the path is dynamic), list of url parts
"""
path = self.path["path"] path = self.path["path"]
dynamic = "{" in path dynamic = "{" in path
@@ -406,6 +456,9 @@ class API:
@property @property
def required_parts(self) -> Any: def required_parts(self) -> Any:
"""
:return: list of parts of the url that are required plus the body
"""
parts = self.all_parts parts = self.all_parts
required = [p for p in parts if parts[p]["required"]] # type: ignore required = [p for p in parts if parts[p]["required"]] # type: ignore
if self.body.get("required"): if self.body.get("required"):
@@ -413,12 +466,15 @@ class API:
return required return required
def to_python(self) -> Any: def to_python(self) -> Any:
"""
:return: rendered Jinja template
"""
try: try:
t = jinja_env.get_template(f"overrides/{self.namespace}/{self.name}") template = jinja_env.get_template(f"overrides/{self.namespace}/{self.name}")
except TemplateNotFound: except TemplateNotFound:
t = jinja_env.get_template("base") template = jinja_env.get_template("base")
return t.render( return template.render(
api=self, api=self,
substitutions={v: k for k, v in SUBSTITUTIONS.items()}, substitutions={v: k for k, v in SUBSTITUTIONS.items()},
global_query_params=GLOBAL_QUERY_PARAMS, global_query_params=GLOBAL_QUERY_PARAMS,
@@ -426,52 +482,61 @@ class API:
def read_modules() -> Any: def read_modules() -> Any:
"""
checks the opensearch-api spec at
https://raw.githubusercontent.com/opensearch-project/opensearch-api-specification/main/OpenSearch.openapi.json
and parses it into one or more API modules
:return: a dict of API objects
"""
modules = {} modules = {}
# Load the OpenAPI specification file # Load the OpenAPI specification file
response = requests.get( response = requests.get(
"https://raw.githubusercontent.com/opensearch-project/opensearch-api-specification/main/OpenSearch.openapi.json" "https://raw.githubusercontent.com/opensearch-project/opensearch-api-"
"specification/main/OpenSearch.openapi.json"
) )
data = response.json() data = response.json()
list_of_dicts = [] list_of_dicts = []
for path in data["paths"]: for path in data["paths"]:
for x in data["paths"][path]: for param in data["paths"][path]: # pylint: disable=invalid-name
if data["paths"][path][x]["x-operation-group"] == "nodes.hot_threads": if data["paths"][path][param]["x-operation-group"] == "nodes.hot_threads":
if "deprecated" in data["paths"][path][x]: if "deprecated" in data["paths"][path][param]:
continue continue
data["paths"][path][x].update({"path": path, "method": x}) data["paths"][path][param].update({"path": path, "method": param})
list_of_dicts.append(data["paths"][path][x]) list_of_dicts.append(data["paths"][path][param])
# Update parameters in each endpoint # Update parameters in each endpoint
for p in list_of_dicts: for param_dict in list_of_dicts:
if "parameters" in p: if "parameters" in param_dict:
params = [] params = []
parts = [] parts = []
# Iterate over the list of parameters and update them # Iterate over the list of parameters and update them
for x in p["parameters"]: for param in param_dict["parameters"]:
if "schema" in x and "$ref" in x["schema"]: if "schema" in param and "$ref" in param["schema"]:
schema_path_ref = x["schema"]["$ref"].split("/")[-1] schema_path_ref = param["schema"]["$ref"].split("/")[-1]
x["schema"] = data["components"]["schemas"][schema_path_ref] param["schema"] = data["components"]["schemas"][schema_path_ref]
params.append(x) params.append(param)
else: else:
params.append(x) params.append(param)
# Iterate over the list of updated parameters to separate "parts" from "params" # Iterate over the list of updated parameters to separate "parts" from "params"
k = params.copy() params_copy = params.copy()
for q in k: for param in params_copy:
if q["in"] == "path": if param["in"] == "path":
parts.append(q) parts.append(param)
params.remove(q) params.remove(param)
# Convert "params" and "parts" into the structure required for generator. # Convert "params" and "parts" into the structure required for generator.
params_new = {} params_new = {}
parts_new = {} parts_new = {}
for m in params: for m in params: # pylint: disable=invalid-name
a = dict(type=m["schema"]["type"], description=m["description"]) a = dict( # pylint: disable=invalid-name
type=m["schema"]["type"], description=m["description"]
) # pylint: disable=invalid-name
if "default" in m["schema"]: if "default" in m["schema"]:
a.update({"default": m["schema"]["default"]}) a.update({"default": m["schema"]["default"]})
@@ -488,22 +553,25 @@ def read_modules() -> Any:
params_new.update({m["name"]: a}) params_new.update({m["name"]: a})
# Removing the deprecated "type" # Removing the deprecated "type"
if p["x-operation-group"] != "nodes.hot_threads" and "type" in params_new: if (
param_dict["x-operation-group"] != "nodes.hot_threads"
and "type" in params_new
):
params_new.pop("type") params_new.pop("type")
if ( if (
p["x-operation-group"] == "cluster.health" param_dict["x-operation-group"] == "cluster.health"
and "ensure_node_commissioned" in params_new and "ensure_node_commissioned" in params_new
): ):
params_new.pop("ensure_node_commissioned") params_new.pop("ensure_node_commissioned")
if bool(params_new): if bool(params_new):
p.update({"params": params_new}) param_dict.update({"params": params_new})
p.pop("parameters") param_dict.pop("parameters")
for n in parts: for n in parts: # pylint: disable=invalid-name
b = dict(type=n["schema"]["type"]) b = dict(type=n["schema"]["type"]) # pylint: disable=invalid-name
if "description" in n: if "description" in n:
b.update({"description": n["description"]}) b.update({"description": n["description"]})
@@ -526,7 +594,7 @@ def read_modules() -> Any:
parts_new.update({n["name"]: b}) parts_new.update({n["name"]: b})
if bool(parts_new): if bool(parts_new):
p.update({"parts": parts_new}) param_dict.update({"parts": parts_new})
# Sort the input list by the value of the "x-operation-group" key # Sort the input list by the value of the "x-operation-group" key
list_of_dicts = sorted(list_of_dicts, key=itemgetter("x-operation-group")) list_of_dicts = sorted(list_of_dicts, key=itemgetter("x-operation-group"))
@@ -550,7 +618,7 @@ def read_modules() -> Any:
# Extract the HTTP methods from the data in the current subgroup # Extract the HTTP methods from the data in the current subgroup
methods = [] methods = []
parts_final = {} parts_final = {}
for z in value2: for z in value2: # pylint: disable=invalid-name
methods.append(z["method"].upper()) methods.append(z["method"].upper())
# Update 'api' dictionary # Update 'api' dictionary
@@ -570,9 +638,9 @@ def read_modules() -> Any:
body = {"required": False} body = {"required": False}
if "required" in z["requestBody"]: if "required" in z["requestBody"]:
body.update({"required": z["requestBody"]["required"]}) body.update({"required": z["requestBody"]["required"]})
q = z["requestBody"]["content"]["application/json"]["schema"][ q = z["requestBody"]["content"][ # pylint: disable=invalid-name
"$ref" "application/json"
].split("/")[-1] ]["schema"]["$ref"].split("/")[-1]
if "description" in data["components"]["schemas"][q]: if "description" in data["components"]["schemas"][q]:
body.update( body.update(
{ {
@@ -644,17 +712,28 @@ def read_modules() -> Any:
def apply_patch(namespace: str, name: str, api: Any) -> Any: def apply_patch(namespace: str, name: str, api: Any) -> Any:
"""
applies patches as specified in {name}.json
:param namespace: directory containing overrides
:param name: base name of the override file; ".json" is appended to it
:param api: specific api to override
:return: modified api
"""
override_file_path = ( override_file_path = (
CODE_ROOT / "utils/templates/overrides" / namespace / f"{name}.json" CODE_ROOT / "utils/templates/overrides" / namespace / f"{name}.json"
) )
if os.path.exists(override_file_path): if os.path.exists(override_file_path):
with open(override_file_path) as f: with open(override_file_path, encoding="utf-8") as file:
override_json = json.load(f) override_json = json.load(file)
api = deepmerge.always_merger.merge(api, override_json) api = deepmerge.always_merger.merge(api, override_json)
return api return api
def dump_modules(modules: Any) -> None: def dump_modules(modules: Any) -> None:
"""
writes out modules to disk
:param modules: a dict whose values are module objects providing a dump() method
"""
for mod in modules.values(): for mod in modules.values():
mod.dump() mod.dump()
+22 -7
View File
@@ -47,11 +47,16 @@ def find_files_to_fix(sources: List[str]) -> Iterator[str]:
def does_file_need_fix(filepath: str) -> bool: def does_file_need_fix(filepath: str) -> bool:
"""
checks if the correct license header exists at the top of the file
:param filepath: an absolute or relative filepath to a file to check
:return: True if the file needs a header, False if it does not
"""
if not re.search(r"\.py$", filepath): if not re.search(r"\.py$", filepath):
return False return False
existing_header = "" existing_header = ""
with open(filepath, mode="r") as f: with open(filepath, mode="r", encoding="utf-8") as file:
for line in f: for line in file:
line = line.strip() line = line.strip()
if len(line) == 0 or line in LINES_TO_KEEP: if len(line) == 0 or line in LINES_TO_KEEP:
pass pass
@@ -64,20 +69,30 @@ def does_file_need_fix(filepath: str) -> bool:
def add_header_to_file(filepath: str) -> None: def add_header_to_file(filepath: str) -> None:
with open(filepath, mode="r") as f: """
lines = list(f) writes the license header to the beginning of a file
:param filepath: relative or absolute filepath to update
"""
with open(filepath, mode="r", encoding="utf-8") as file:
lines = list(file)
i = 0 i = 0
for i, line in enumerate(lines): for i, line in enumerate(lines):
if len(line) > 0 and line not in LINES_TO_KEEP: if len(line) > 0 and line not in LINES_TO_KEEP:
break break
lines = lines[:i] + [LICENSE_HEADER] + lines[i:] lines = lines[:i] + [LICENSE_HEADER] + lines[i:]
with open(filepath, mode="w") as f: with open(filepath, mode="w", encoding="utf-8") as file:
f.truncate() file.truncate()
f.write("".join(lines)) file.write("".join(lines))
print(f"Fixed {os.path.relpath(filepath, os.getcwd())}") print(f"Fixed {os.path.relpath(filepath, os.getcwd())}")
def main() -> None: def main() -> None:
"""
arguments:
fix: find all files without license headers and insert headers at the top of the file
check: prints a list of files without license headers
list of one or more directories: search in these directories
"""
mode = sys.argv[1] mode = sys.argv[1]
assert mode in ("fix", "check") assert mode in ("fix", "check")
sources = [os.path.abspath(x) for x in sys.argv[2:]] sources = [os.path.abspath(x) for x in sys.argv[2:]]