Source code for hstreamdb.aio.client

import functools
from typing import Optional, Any, Iterable, Type, Iterator, Dict, List, Tuple
import types
import grpc
import logging
from contextlib import asynccontextmanager
import itertools
from urllib.parse import urlparse

import HStream.Server.HStreamApi_pb2 as ApiPb
import HStream.Server.HStreamApi_pb2_grpc as ApiGrpc
from hstreamdb.aio.producer import BufferedProducer, AppendPayload
from hstreamdb.aio.consumer import Consumer
from hstreamdb.types import (
    RecordId,
    Record,
    record_id_from,
    Stream,
    stream_type_from,
    Subscription,
    subscription_type_from,
    Shard,
    shard_type_from,
    SpecialOffset,
    ShardOffset,
)
from hstreamdb.utils import (
    find_shard_id,
    decode_records,
    encode_records,
    encode_records_from_append_payload,
)

# Public API of this module: the two async client factories and the client
# class they construct.
__all__ = ["secure_client", "insecure_client", "HStreamDBClient"]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def dec_api(f):
    """Decorate an async client method so it is retried once on UNAVAILABLE.

    The wrapped coroutine function is awaited with the client as its first
    argument. If it raises ``grpc.aio.AioRpcError`` whose status code is
    ``UNAVAILABLE``, the client is asked to switch to another channel via
    ``client._switch_channel()`` and the call is attempted one more time.
    Any other RPC error is re-raised unchanged.
    """

    @functools.wraps(f)
    async def retrying(client, *args, **kwargs):
        try:
            return await f(client, *args, **kwargs)
        except grpc.aio.AioRpcError as err:
            if err.code() != grpc.StatusCode.UNAVAILABLE:
                raise err
            # The service is currently unavailable, so we choose another
            # target and retry the call exactly once.
            await client._switch_channel()
            return await f(client, *args, **kwargs)

    return retrying


class HStreamDBClient:
    """Asynchronous gRPC client for an HStreamDB cluster.

    The client keeps one "current" stub used for cluster-wide requests and
    a set of per-target channels. Appends, subscriptions and shard readers
    are routed to the node returned by the corresponding lookup RPC; the
    resolved targets are cached in the ``_*_channels`` dictionaries.

    Instances can be used as an async context manager; leaving the context
    closes every channel that has been opened.
    """

    _TargetTy = str

    _stub: ApiGrpc.HStreamApiStub
    # target -> channel; ``None`` means "known node, channel not opened yet".
    _channels: Dict[_TargetTy, Optional[grpc.aio.Channel]]
    _current_target: _TargetTy
    # Credentials given at construction time (``None`` for insecure clients).
    _credentials: Optional[grpc.ChannelCredentials]
    # {(stream_name, shard_id)}
    _append_channels: Dict[Tuple[str, int], _TargetTy]
    _subscription_channels: Dict[str, _TargetTy]
    _reader_channels: Dict[str, _TargetTy]
    _shards_info: Dict[str, List[Shard]]

    @staticmethod
    def _cons_target(host, port):
        """Build the ``"host:port"`` target string for a node."""
        return f"{host}:{port}"

    def __init__(
        self, host: str = "127.0.0.1", port: int = 6570, credentials=None
    ):
        """Open a channel to ``host:port`` and create the initial stub.

        Args:
            host: hostname of the bootstrap node.
            port: port of the bootstrap node.
            credentials: optional channel credentials; when given, secure
                channels are used, otherwise insecure channels.
        """
        # Remember the credentials so that channels opened later (to other
        # nodes of the cluster) use the same security settings.
        self._credentials = credentials
        self._current_target = self._cons_target(host, port)
        self._channels = {}
        self._append_channels = {}
        self._subscription_channels = {}
        self._reader_channels = {}
        self._shards_info = {}

        channel = self._new_channel(self._current_target)
        self._channels[self._current_target] = channel
        self._stub = ApiGrpc.HStreamApiStub(channel)

    async def init_cluster_info(self):
        """Fetch the cluster description and register every server node.

        Nodes that are not yet known are recorded with a ``None`` channel;
        the channel itself is opened on first use by ``_get_channel``.
        """
        cluster_info = await self._stub.DescribeCluster(None)
        # TODO: check protocolVersion, serverVersion
        for node in cluster_info.serverNodes:
            target = self._cons_target(node.host, node.port)
            if target not in self._channels:
                self._channels[target] = None

    # -------------------------------------------------------------------------

    @dec_api
    async def create_stream(
        self, name, replication_factor=1, backlog=0, shard_count=1
    ):
        """Create a stream.

        Args:
            name: stream name
            replication_factor: how stream can be replicated across nodes in
                the cluster
            backlog: how long streams of HStreamDB retain records after being
                appended, in seconds.
            shard_count: number of shards, sent as ``shardCount``.
        """
        await self._stub.CreateStream(
            ApiPb.Stream(
                streamName=name,
                replicationFactor=replication_factor,
                backlogDuration=backlog,
                shardCount=shard_count,
            )
        )

    @dec_api
    async def delete_stream(self, name, ignore_non_exist=False, force=False):
        """Delete the stream ``name``.

        ``ignore_non_exist`` and ``force`` are forwarded unchanged as the
        ``ignoreNonExist`` and ``force`` fields of the request.
        """
        await self._stub.DeleteStream(
            ApiPb.DeleteStreamRequest(
                streamName=name, ignoreNonExist=ignore_non_exist, force=force
            )
        )

    @dec_api
    async def list_streams(self) -> Iterator[Stream]:
        """List all streams"""
        r = await self._stub.ListStreams(ApiPb.ListStreamsRequest())
        return (stream_type_from(s) for s in r.streams)

    async def append(
        self,
        name: str,
        payloads: Iterable[Any],
        key: Optional[str] = None,
        compresstype=None,
        compresslevel=9,
    ) -> Iterator[RecordId]:
        """Append payloads to a stream.

        Args:
            name: stream name
            payloads: a list of string, bytes or dict(json).
            key: Optional stream key.
            compresstype: forwarded to ``encode_records``.
            compresslevel: forwarded to ``encode_records``.

        Returns:
            Appended RecordIds generator
        """
        shard_id, channel = await self._lookup_append(name, key, None)
        stub = ApiGrpc.HStreamApiStub(channel)
        r = await stub.Append(
            ApiPb.AppendRequest(
                streamName=name,
                shardId=shard_id,
                records=encode_records(
                    payloads,
                    key=key,
                    compresstype=compresstype,
                    compresslevel=compresslevel,
                ),
            )
        )
        return (record_id_from(x) for x in r.recordIds)

    def new_producer(
        self,
        append_callback: Optional[Type[BufferedProducer.AppendCallback]] = None,
        size_trigger=0,  # NOTE: this is the size of uncompressed records
        time_trigger=0,
        workers=1,
        retry_count=0,
        retry_max_delay=60,
        compresstype=None,
        compresslevel=9,
    ):
        """Create a ``BufferedProducer`` bound to this client.

        The producer is given ``_append_with_shard`` and ``_find_shard_id``
        plus the trigger/worker/retry settings passed here.
        """
        # NOTE(review): 'compresstype' and 'compresslevel' are accepted but
        # not forwarded to BufferedProducer — confirm whether its constructor
        # takes them before wiring them through.
        return BufferedProducer(
            self._append_with_shard,
            self._find_shard_id,
            append_callback=append_callback,
            size_trigger=size_trigger,
            time_trigger=time_trigger,
            workers=workers,
            retry_count=retry_count,
            retry_max_delay=retry_max_delay,
        )

    @dec_api
    async def list_shards(self, stream_name) -> List[Shard]:
        """Return the shards of ``stream_name``, cached per stream name.

        The server is queried only when nothing (or an empty list) is cached.
        """
        # FIXME: what if shards_info can be changed?
        shards = self._shards_info.get(stream_name)
        if not shards:
            r = await self._stub.ListShards(
                ApiPb.ListShardsRequest(streamName=stream_name)
            )
            shards = [shard_type_from(s) for s in r.shards]
            self._shards_info[stream_name] = shards
        return shards

    @dec_api
    async def create_subscription(
        self,
        subscription_id: str,
        stream_name: str,
        ack_timeout: int = 600,  # 10min
        max_unacks: int = 10000,
        offset: SpecialOffset = SpecialOffset.LATEST,
    ):
        """Create a subscription on ``stream_name``.

        Args:
            subscription_id: id of the new subscription.
            stream_name: stream to subscribe to.
            ack_timeout: sent as ``ackTimeoutSeconds``.
            max_unacks: sent as ``maxUnackedRecords``.
            offset: sent as the subscription ``offset``.
        """
        await self._stub.CreateSubscription(
            ApiPb.Subscription(
                subscriptionId=subscription_id,
                streamName=stream_name,
                ackTimeoutSeconds=ack_timeout,
                maxUnackedRecords=max_unacks,
                offset=offset,
            )
        )

    @dec_api
    async def list_subscriptions(self) -> Iterator[Subscription]:
        """List all subscriptions."""
        r = await self._stub.ListSubscriptions(None)
        return (subscription_type_from(s) for s in r.subscription)

    @dec_api
    async def does_subscription_exist(self, subscription_id: str):
        """Return the ``exists`` flag reported for ``subscription_id``."""
        r = await self._stub.CheckSubscriptionExist(
            ApiPb.CheckSubscriptionExistRequest(subscriptionId=subscription_id)
        )
        return r.exists

    async def delete_subscription(self, subscription_id: str, force=False):
        """Delete a subscription on the node that serves it."""
        channel = await self._lookup_subscription(subscription_id)
        stub = ApiGrpc.HStreamApiStub(channel)
        await stub.DeleteSubscription(
            ApiPb.DeleteSubscriptionRequest(
                subscriptionId=subscription_id, force=force
            )
        )
        # The subscription is gone: forget its cached target so that a later
        # subscription with the same id is looked up again.
        self._subscription_channels.pop(subscription_id, None)

    def new_consumer(self, name: str, subscription_id: str, processing_func):
        """Create a ``Consumer`` for ``subscription_id``.

        The consumer receives a coroutine function that resolves a stub for
        the node serving the subscription.
        """

        async def find_stub():
            channel = await self._lookup_subscription(subscription_id)
            return ApiGrpc.HStreamApiStub(channel)

        return Consumer(
            name,
            subscription_id,
            find_stub,
            processing_func,
        )

    @asynccontextmanager
    async def with_reader(
        self,
        stream_name: str,
        reader_id: str,
        shard_offset: ShardOffset,
        timeout: int,
        shard_id: Optional[int] = None,
        stream_key: Optional[str] = None,
    ):
        """Create a shard reader for the duration of an ``async with`` block.

        Yields an object whose ``read(max_records)`` returns the awaitable
        produced by ``read_reader`` for this reader. The reader is deleted
        when the block exits, whether or not it raised.
        """
        await self.create_reader(
            stream_name,
            reader_id,
            shard_offset,
            timeout,
            shard_id=shard_id,
            stream_key=stream_key,
        )
        try:
            obj = types.SimpleNamespace()
            obj.read = lambda max_records: self.read_reader(
                reader_id, max_records
            )
            yield obj
        finally:
            await self.delete_reader(reader_id)

    @dec_api
    async def create_reader(
        self,
        stream_name: str,
        reader_id: str,
        shard_offset: ShardOffset,
        timeout: int,
        shard_id: Optional[int] = None,
        stream_key: Optional[str] = None,
    ):
        """Create a reader.

        If the 'shard_id' is None, then use the shard which the optional
        'stream_key' corresponds.

        Returns:
            The response of the ``CreateShardReader`` RPC.
        """
        if shard_id is None:
            shard_id = await self._find_shard_id(stream_name, key=stream_key)

        return await self._stub.CreateShardReader(
            ApiPb.CreateShardReaderRequest(
                streamName=stream_name,
                shardId=shard_id,
                shardOffset=shard_offset,
                readerId=reader_id,
                timeout=timeout,
            )
        )

    async def read_reader(
        self, reader_id: str, max_records: int
    ) -> Iterator[Record]:
        """Read from a shard reader.

        Args:
            reader_id: id of an existing reader.
            max_records: sent as ``maxRecords`` of the read request.

        Returns:
            A flat iterator over the decoded records of every received batch.
        """
        stub = await self._lookup_reader_stub(reader_id)
        resp = await stub.ReadShard(
            ApiPb.ReadShardRequest(readerId=reader_id, maxRecords=max_records)
        )
        return itertools.chain.from_iterable(
            decode_records(r) for r in resp.receivedRecords
        )

    async def delete_reader(self, reader_id: str) -> None:
        """Delete a shard reader on the node that serves it."""
        stub = await self._lookup_reader_stub(reader_id)
        await stub.DeleteShardReader(
            ApiPb.DeleteShardReaderRequest(readerId=reader_id)
        )
        # The reader is gone: drop its cached target so the cache does not
        # grow without bound or route a re-created reader to a stale node.
        self._reader_channels.pop(reader_id, None)
        return None

    # -------------------------------------------------------------------------

    async def _find_shard_id(self, stream_name, key=None) -> int:
        """Resolve the shard id of ``stream_name`` that ``key`` maps to."""
        shards = await self.list_shards(stream_name)
        return find_shard_id(shards, key=key)

    async def _append_with_shard(
        self,
        name: str,
        payloads: List[AppendPayload],
        shard_id: int,
        compresstype=None,
        compresslevel=9,
    ) -> Optional[Iterator[RecordId]]:
        """Append already-built payloads to an explicit shard.

        Returns:
            A generator of appended RecordIds, or ``None`` (after logging a
            warning) when ``payloads`` is empty.
        """
        if not payloads:
            logger.warning("Empty payloads, ignored.")
            return None

        shard_id, channel = await self._lookup_append(name, None, shard_id)
        stub = ApiGrpc.HStreamApiStub(channel)
        r = await stub.Append(
            ApiPb.AppendRequest(
                streamName=name,
                shardId=shard_id,
                records=encode_records_from_append_payload(
                    payloads,
                    compresstype=compresstype,
                    compresslevel=compresslevel,
                ),
            )
        )
        return (record_id_from(x) for x in r.recordIds)

    async def _lookup_append(self, name, key, shard_id):
        """Resolve the shard id and channel to use for an append.

        When ``shard_id`` is given it is used as is and ``key`` is ignored;
        otherwise the shard is derived from ``key``.

        Returns:
            A ``(shard_id, channel)`` tuple.
        """
        if shard_id is not None:
            # NOTE: 'key' has no meaning when the shard is given explicitly.
            keyid = shard_id
        else:
            keyid = await self._find_shard_id(name, key=key)

        target = self._append_channels.get((name, keyid))
        if not target:
            node = await self._lookup_append_api(keyid)
            target = self._cons_target(node.host, node.port)
            self._append_channels[(name, keyid)] = target
            # Compare against None rather than truthiness: a shard id of 0
            # is still an explicitly given shard.
            if shard_id is None:
                logger.debug(f"Find target for stream <{name},{key}>: {target}")
            else:
                logger.debug(
                    f"Find target for stream <{name}> with shard id <{shard_id}>: {target}"
                )

        return keyid, self._get_channel(target)

    async def _lookup_append_stub(self, name, key, shard_id):
        """Like ``_lookup_append`` but return a stub instead of a channel."""
        keyid, channel = await self._lookup_append(name, key, shard_id)
        return keyid, ApiGrpc.HStreamApiStub(channel)

    async def _lookup_subscription(self, subscription_id: str):
        """Return the channel of the node serving ``subscription_id``."""
        target = self._subscription_channels.get(subscription_id)
        if not target:
            node = await self._lookup_subscription_api(subscription_id)
            target = self._cons_target(node.host, node.port)
            self._subscription_channels[subscription_id] = target
            logger.debug(
                f"Find target for subscription <{subscription_id}>: {target}"
            )

        return self._get_channel(target)

    async def _lookup_subscription_stub(self, subscription_id: str):
        """Return a stub for the node serving ``subscription_id``."""
        channel = await self._lookup_subscription(subscription_id)
        return ApiGrpc.HStreamApiStub(channel)

    async def _lookup_reader(self, reader_id: str):
        """Return the channel of the node serving ``reader_id``."""
        target = self._reader_channels.get(reader_id)
        if not target:
            node = await self._lookup_reader_api(reader_id)
            target = self._cons_target(node.host, node.port)
            self._reader_channels[reader_id] = target
            logger.debug(f"Find target for reader <{reader_id}>: {target}")

        return self._get_channel(target)

    async def _lookup_reader_stub(self, reader_id: str):
        """Return a stub for the node serving ``reader_id``."""
        channel = await self._lookup_reader(reader_id)
        return ApiGrpc.HStreamApiStub(channel)

    @dec_api
    async def _lookup_append_api(self, shard_id):
        """Ask the current node which server node owns ``shard_id``."""
        r = await self._stub.LookupShard(
            ApiPb.LookupShardRequest(shardId=shard_id)
        )
        # there is no reason that returned value does not equal to requested.
        assert r.shardId == shard_id
        return r.serverNode

    @dec_api
    async def _lookup_subscription_api(self, subscription_id: str):
        """Ask the current node which server node owns the subscription."""
        r = await self._stub.LookupSubscription(
            ApiPb.LookupSubscriptionRequest(subscriptionId=subscription_id)
        )
        assert r.subscriptionId == subscription_id
        return r.serverNode

    @dec_api
    async def _lookup_reader_api(self, reader_id: str):
        """Ask the current node which server node owns the shard reader."""
        r = await self._stub.LookupShardReader(
            ApiPb.LookupShardReaderRequest(readerId=reader_id)
        )
        assert r.readerId == reader_id
        return r.serverNode

    # -------------------------------------------------------------------------

    async def _switch_channel(self):
        """Drop the current target and move to another known one.

        Targets are tried in turn until ``init_cluster_info`` succeeds on one
        of them.

        Raises:
            RuntimeError: when no target is left to try.
        """
        while True:
            logger.warning(
                f"Target {self._current_target} unavailable, switching to another..."
            )
            # remove unavailable target
            # NOTE(review): the removed channel is not closed here — confirm
            # whether in-flight calls may still be using it before closing.
            self._channels.pop(self._current_target)
            if not self._channels:
                raise RuntimeError("No available targets!")

            # Now, self._channels should not be empty.
            self._current_target = next(iter(self._channels))
            channel = self._get_channel(self._current_target)
            self._stub = ApiGrpc.HStreamApiStub(channel)
            try:
                return await self.init_cluster_info()
            except grpc.aio.AioRpcError as e:
                # The service is currently unavailable, so we choose another
                logger.warning(
                    f"Fetch cluster info from {self._current_target} failed! \n {e}"
                )
                continue

    def _new_channel(self, target):
        """Open a channel to ``target`` honouring the client's credentials."""
        if self._credentials:
            return grpc.aio.secure_channel(target, self._credentials)
        return grpc.aio.insecure_channel(target)

    def _get_channel(self, target):
        """Return the channel for ``target``, opening it on first use."""
        channel = self._channels.get(target)
        if channel:
            return channel
        # new channel
        channel = self._new_channel(target)
        self._channels[target] = channel
        return channel

    # -------------------------------------------------------------------------

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close every channel that has actually been opened."""
        # Iterate over a snapshot: the mapping may be modified by other tasks
        # while a close is being awaited.
        for channel in list(self._channels.values()):
            if channel:
                await channel.close(grace=None)
async def insecure_client(host="127.0.0.1", port=6570, url=None):
    """Creates an insecure client to a cluster.

    Args:
        host: hostname to connect to HStreamDB, defaults to '127.0.0.1'
        port: port to connect to HStreamDB, defaults to 6570
        url: alternative service url to connect to HStreamDB, it should be in
            'hstream://your-host' format. Note that if you provide this url
            then the 'host' and 'port' args will be ignored.

    Returns:
        A :class:`HStreamDBClient` whose cluster info has been initialised.

    Raises:
        KeyError: if 'url' is given and its scheme is not 'hstream'.
    """
    if url:
        parsed = urlparse(url)
        scheme = parsed.scheme
        if scheme != "hstream":
            raise KeyError(f"Invalid service url scheme {scheme}")
        # FIXME: should the default port be the same as secure_client?
        host, port = parsed.hostname, (parsed.port or 6570)

    client = HStreamDBClient(host=host, port=port)
    await client.init_cluster_info()
    return client
def _read_creds_file(path):
    """Return the bytes stored at ``path``, or ``None`` when no path is given."""
    if path is None:
        return None
    with open(path, "rb") as f:
        return f.read()


async def secure_client(
    host="127.0.0.1",
    port=6570,
    url=None,
    is_creds_file=False,
    root_certificates=None,
    private_key=None,
    certificate_chain=None,
):
    """Creates a secure client to a cluster.

    Args:
        host: hostname to connect to HStreamDB, defaults to '127.0.0.1'
        port: port to connect to HStreamDB, defaults to 6570
        url: alternative service url to connect to HStreamDB, it should be in
            'hstreams://your-host' format. Note that if you provide this url
            then the 'host' and 'port' args will be ignored.
        is_creds_file: whether the credentials is a filepath or the contents.
        root_certificates: The PEM-encoded root certificates as a byte string,
            or None to retrieve them from a default location chosen by gRPC
            runtime. Note: if 'is_creds_file' is True this is the filepath
            instead of the contents.
        private_key: The PEM-encoded private key as a byte string, or None if
            no private key should be used. Note: if 'is_creds_file' is True
            this is the filepath instead of the contents.
        certificate_chain: The PEM-encoded certificate chain as a byte string
            to use or None if no certificate chain should be used. Note: if
            'is_creds_file' is True this is the filepath instead of the
            contents.

    Returns:
        A :class:`HStreamDBClient` whose cluster info has been initialised.

    Raises:
        KeyError: if 'url' is given and its scheme is not 'hstreams'.
        OSError: if 'is_creds_file' is True and a given file cannot be read.
    """
    if is_creds_file:
        # Each credential is optional: only read the files whose path was
        # actually provided, and keep None for the rest.
        root_certificates = _read_creds_file(root_certificates)
        private_key = _read_creds_file(private_key)
        certificate_chain = _read_creds_file(certificate_chain)

    if url:
        o = urlparse(url)
        if o.scheme != "hstreams":
            raise KeyError(f"Invalid service url scheme {o.scheme}")
        host = o.hostname
        port = o.port or 6570

    creds = grpc.ssl_channel_credentials(
        root_certificates=root_certificates,
        private_key=private_key,
        certificate_chain=certificate_chain,
    )

    client = HStreamDBClient(host=host, port=port, credentials=creds)
    await client.init_cluster_info()
    return client