Source code for kazoo.testing.harness

"""Kazoo testing harnesses"""
import atexit
import logging
import os
import uuid
import threading
import unittest

from kazoo.client import KazooClient
from kazoo.exceptions import NotEmptyError
from kazoo.protocol.states import (
from kazoo.testing.common import ZookeeperCluster
from kazoo.protocol.connection import _SESSION_EXPIRED

log = logging.getLogger(__name__)


def get_global_cluster():
    global CLUSTER
    if CLUSTER is None:
        ZK_HOME = os.environ.get("ZOOKEEPER_PATH")
        ZK_CLASSPATH = os.environ.get("ZOOKEEPER_CLASSPATH")
        ZK_PORT_OFFSET = int(os.environ.get("ZOOKEEPER_PORT_OFFSET", 20000))

        assert ZK_HOME or ZK_CLASSPATH, (
            "either ZOOKEEPER_PATH or ZOOKEEPER_CLASSPATH environment variable "
            "must be defined.\n"
            "For deb package installations this is /usr/share/java")

        CLUSTER = ZookeeperCluster(
        atexit.register(lambda cluster: cluster.terminate(), CLUSTER)
    return CLUSTER

[docs]class KazooTestHarness(unittest.TestCase): """Harness for testing code that uses Kazoo This object can be used directly or as a mixin. It supports starting and stopping a complete ZooKeeper cluster locally and provides an API for simulating errors and expiring sessions. Example:: class MyTestCase(KazooTestHarness): def setUp(self): self.setup_zookeeper() # additional test setup def tearDown(self): self.teardown_zookeeper() def test_something(self): something_that_needs_a_kazoo_client(self.client) def test_something_else(self): something_that_needs_zk_servers(self.servers) """ def __init__(self, *args, **kw): super(KazooTestHarness, self).__init__(*args, **kw) self.client = None self._clients = [] @property def cluster(self): return get_global_cluster() @property def servers(self): return ",".join([s.address for s in self.cluster]) def _get_nonchroot_client(self): return KazooClient(self.servers) def _get_client(self, **kwargs): c = KazooClient(self.hosts, **kwargs) try: self._clients.append(c) except AttributeError: self._client = [c] return c def expire_session(self, client_id=None): """Force ZK to expire a client session :param client_id: id of client to expire. If unspecified, the id of self.client will be used. """ client_id = client_id or self.client.client_id lost = threading.Event() safe = threading.Event() def watch_loss(state): if state == KazooState.LOST: lost.set() if lost.is_set() and state == KazooState.CONNECTED: safe.set() return True self.client.add_listener(watch_loss) self.client._call(_SESSION_EXPIRED, None) lost.wait(5) if not lost.isSet(): raise Exception("Failed to get notified of session loss") # Wait for the reconnect now safe.wait(15) if not safe.isSet(): raise Exception("Failed to see client reconnect") self.client.retry(self.client.get_async, '/') def setup_zookeeper(self, **client_options): """Create a ZK cluster and chrooted :class:`KazooClient` The cluster will only be created on the first invocation and won't be fully torn down until exit. """ if not self.cluster[0].running: self.cluster.start() namespace = "/kazootests" + uuid.uuid4().hex self.hosts = self.servers + namespace if 'timeout' not in client_options: client_options['timeout'] = 0.8 self.client = self._get_client(**client_options) self.client.start() self.client.ensure_path("/") def teardown_zookeeper(self): """Clean up any ZNodes created during the test """ if not self.cluster[0].running: self.cluster.start() tries = 0 if self.client and self.client.connected: while tries < 3: try: self.client.retry(self.client.delete, '/', recursive=True) break except NotEmptyError: pass tries += 1 self.client.stop() self.client.close() del self.client else: client = self._get_client() client.start() client.retry(client.delete, '/', recursive=True) client.stop() client.close() del client for client in self._clients: client.stop() del client self._clients = None
[docs]class KazooTestCase(KazooTestHarness): def setUp(self): self.setup_zookeeper() def tearDown(self): self.teardown_zookeeper()