Added sniffing logic to Transport
This commit is contained in:
@@ -1,14 +1,30 @@
|
||||
import re
|
||||
|
||||
from .connection import RequestsHttpConnection
|
||||
from .connection_pool import ConnectionPool
|
||||
from .serializer import JSONSerializer
|
||||
from .exceptions import TransportError
|
||||
|
||||
# get ip/port from "inet[wind/127.0.0.1:9200]"
|
||||
ADDRESS_RE = re.compile(r'/(?P<host>[^:]*):(?P<port>[0-9]+)\]')
|
||||
|
||||
def construct_hosts_list(nodes, transport):
|
||||
hosts = []
|
||||
address = '%s_address' % transport
|
||||
for n in nodes.values():
|
||||
match = ADDRESS_RE.search(n.get(address, ''))
|
||||
if match:
|
||||
hosts.append(match.groupdict())
|
||||
return hosts
|
||||
|
||||
class Transport(object):
|
||||
def __init__(self, hosts, connection_class=RequestsHttpConnection,
|
||||
connection_pool_class=ConnectionPool, serializer=JSONSerializer(),
|
||||
connection_pool_class=ConnectionPool, nodes_to_host_callback=construct_hosts_list,
|
||||
sniff_on_start=False, sniff_after_requests=None,
|
||||
sniff_on_connection_fail=False, serializer=JSONSerializer(),
|
||||
max_retries=3, **kwargs):
|
||||
|
||||
self.max_retries = 3
|
||||
self.max_retries = max_retries
|
||||
|
||||
# data serializer
|
||||
self.serializer = serializer
|
||||
@@ -24,6 +40,19 @@ class Transport(object):
|
||||
# ...and instantiate them
|
||||
self.set_connections(hosts)
|
||||
|
||||
# sniffing data
|
||||
self.req_counter = 0
|
||||
self.sniffs_due_to_failure = 0
|
||||
self.sniff_after_requests_original = sniff_after_requests
|
||||
self.sniff_after_requests = sniff_after_requests
|
||||
self.sniff_on_connection_fail = sniff_on_connection_fail
|
||||
|
||||
# callback to construct hosts dicts from /_cluster/nodes data
|
||||
self.nodes_to_host_callback = nodes_to_host_callback
|
||||
|
||||
if sniff_on_start:
|
||||
self.sniff_hosts()
|
||||
|
||||
def add_connection(self, host):
|
||||
self.hosts.append(host)
|
||||
self.set_connections(self.hosts)
|
||||
@@ -39,16 +68,47 @@ class Transport(object):
|
||||
# pass the hosts dicts to the connection pool to optionally extract parameters from
|
||||
self.connection_pool = self.connection_pool_class(zip(connections, hosts), **self.kwargs)
|
||||
|
||||
def perform_request(self, method, url, params=None, body=None):
|
||||
def get_connection(self, sniffing=False):
|
||||
if not sniffing and self.sniff_after_requests:
|
||||
if self.req_counter >= self.sniff_after_requests:
|
||||
self.sniff_hosts()
|
||||
self.req_counter += 1
|
||||
return self.connection_pool.get_connection()
|
||||
|
||||
def sniff_hosts(self, failure=False):
|
||||
# set the counter to 0 first so that perform_request doesn't trigger an
|
||||
# infinite loop
|
||||
self.req_counter = 0
|
||||
_, node_info = self.perform_request('GET', '/_cluster/nodes', sniffing=True)
|
||||
hosts = self.nodes_to_host_callback(node_info['nodes'], self.connection_class.transport_schema)
|
||||
self.set_connections(hosts)
|
||||
|
||||
# when sniffing due to failure, shorten the period between sniffs progressively
|
||||
if failure:
|
||||
self.sniffs_due_to_failure += 1
|
||||
if self.sniff_after_requests:
|
||||
self.sniff_after_requests = 1 + self.sniff_after_requests_original // 2**self.sniffs_due_to_failure
|
||||
else:
|
||||
self.sniffs_due_to_failure = 0
|
||||
self.sniff_after_requests = self.sniff_after_requests_original
|
||||
|
||||
def mark_dead(self, connection, sniffing=False):
|
||||
if not sniffing and self.sniff_on_connection_fail:
|
||||
self.sniff_hosts(True)
|
||||
else:
|
||||
self.connection_pool.mark_dead(connection)
|
||||
|
||||
def perform_request(self, method, url, params=None, body=None, sniffing=False):
|
||||
for attempt in range(self.max_retries):
|
||||
connection = self.connection_pool.get_connection()
|
||||
connection = self.get_connection(sniffing)
|
||||
print connection, attempt, method, url
|
||||
|
||||
if body:
|
||||
body = self.serializer.dumps(body)
|
||||
try:
|
||||
status, raw_data = connection.perform_request(method, url, params, body)
|
||||
except TransportError:
|
||||
self.connection_pool.mark_dead(connection)
|
||||
self.mark_dead(connection, sniffing)
|
||||
|
||||
# raise exception on last retry
|
||||
if attempt + 1 == self.max_retries:
|
||||
|
||||
@@ -17,6 +17,20 @@ class DummyConnection(Connection):
|
||||
raise self.exception
|
||||
return self.status, self.data
|
||||
|
||||
CLUSTER_NODES = '''{
|
||||
"ok" : true,
|
||||
"cluster_name" : "super_cluster",
|
||||
"nodes" : {
|
||||
"wE_6OGBNSjGksbONNncIbg" : {
|
||||
"name" : "Nightwind",
|
||||
"transport_address" : "inet[/127.0.0.1:9300]",
|
||||
"hostname" : "wind",
|
||||
"version" : "0.20.4",
|
||||
"http_address" : "inet[/1.1.1.1:123]"
|
||||
}
|
||||
}
|
||||
}'''
|
||||
|
||||
class TestTransport(TestCase):
|
||||
def test_kwargs_passed_on_to_connections(self):
|
||||
t = Transport([{'host': 'google.com'}], port=123)
|
||||
@@ -47,7 +61,7 @@ class TestTransport(TestCase):
|
||||
t = Transport([{'exception': TransportError('abandon ship')}], connection_class=DummyConnection)
|
||||
|
||||
self.assertRaises(TransportError, t.perform_request, 'GET', '/')
|
||||
self.assertEquals(3, len(t.connection_pool.get_connection().calls))
|
||||
self.assertEquals(3, len(t.get_connection().calls))
|
||||
|
||||
def test_failed_connection_will_be_marked_as_dead(self):
|
||||
t = Transport([{'exception': TransportError('abandon ship')}], connection_class=DummyConnection)
|
||||
@@ -55,3 +69,39 @@ class TestTransport(TestCase):
|
||||
self.assertRaises(TransportError, t.perform_request, 'GET', '/')
|
||||
self.assertEquals(0, len(t.connection_pool.connections))
|
||||
|
||||
def test_sniff_on_start_fetches_and_uses_nodes_list(self):
|
||||
t = Transport([{'data': CLUSTER_NODES}], connection_class=DummyConnection, sniff_on_start=True)
|
||||
self.assertEquals(1, len(t.connection_pool.connections))
|
||||
self.assertEquals('http://1.1.1.1:123', t.get_connection().host)
|
||||
|
||||
def test_sniff_on_fail_triggers_sniffing_on_fail(self):
|
||||
t = Transport([{'exception': TransportError('abandon ship')}, {"data": CLUSTER_NODES}],
|
||||
connection_class=DummyConnection, sniff_on_connection_fail=True, max_retries=1, randomize_hosts=False)
|
||||
|
||||
self.assertRaises(TransportError, t.perform_request, 'GET', '/')
|
||||
self.assertEquals(1, len(t.connection_pool.connections))
|
||||
self.assertEquals('http://1.1.1.1:123', t.get_connection().host)
|
||||
|
||||
def test_sniff_after_n_requests(self):
|
||||
t = Transport([{"data": CLUSTER_NODES}],
|
||||
connection_class=DummyConnection, sniff_after_requests=5)
|
||||
|
||||
for _ in range(4):
|
||||
t.perform_request('GET', '/')
|
||||
self.assertEquals(1, len(t.connection_pool.connections))
|
||||
self.assertIsInstance(t.get_connection(), DummyConnection)
|
||||
|
||||
t.perform_request('GET', '/')
|
||||
self.assertEquals(1, len(t.connection_pool.connections))
|
||||
self.assertEquals('http://1.1.1.1:123', t.get_connection().host)
|
||||
|
||||
def test_sniff_on_failure_shortens_sniff_after_n_requests(self):
|
||||
t = Transport([{'exception': TransportError('abandon ship')}, {"data": CLUSTER_NODES}],
|
||||
connection_class=DummyConnection, sniff_on_connection_fail=True, max_retries=1,
|
||||
randomize_hosts=False, sniff_after_requests=4)
|
||||
|
||||
self.assertRaises(TransportError, t.perform_request, 'GET', '/')
|
||||
self.assertEquals(1, len(t.connection_pool.connections))
|
||||
self.assertEquals('http://1.1.1.1:123', t.get_connection().host)
|
||||
self.assertEquals(3, t.sniff_after_requests)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user