Source code for hstreamdb.aio.producer

import abc
import asyncio
from typing import (
    Optional,
    Any,
    Callable,
    Awaitable,
    Iterator,
    Type,
    List,
    Dict,
    Sized,
    Tuple,
    Union,
)
import logging

import HStream.Server.HStreamApi_pb2 as ApiPb

from hstreamdb.types import RecordId
from hstreamdb.utils import encode_payload

__all__ = [
    "PayloadsFull",
    "PayloadTooBig",
    "AppendPayload",
    "PayloadGroup",
    "BufferedProducer",
]

logger = logging.getLogger(__name__)


class Timer:
    def __init__(self, delay, coro):
        self._delay = delay
        self._coro = coro
        self._continue: asyncio.Event = asyncio.Event()
        self._enable: bool = True
        self._task: asyncio.Task = asyncio.create_task(self._loop())

    def start(self):
        self._continue.set()

    def stop(self):
        self._continue.clear()
        self._task.cancel()

    def exit(self):
        self._enable = False

    async def _loop(self):
        while True:
            if not self._enable:
                break
            try:
                await self._continue.wait()
                await asyncio.sleep(self._delay)
                await asyncio.shield(self._coro())
            except asyncio.CancelledError:
                # cancelled by stop(); loop around and re-check _enable
                logger.debug("Timer: receive CancelledError")


class PayloadsFull(Exception):
    pass


class PayloadTooBig(Exception):
    pass


class Payload(abc.ABC, Sized):
    @abc.abstractmethod
    def __len__(self):
        ...


class PayloadGroup:
    def __init__(self, queue, key, maxsize=0, maxtime=0):
        self._payloads: List[Payload] = []
        self._size: int = 0
        self._flushing_payloads: List[Payload] = []
        self._flushing_size: int = 0
        self._flush_done: asyncio.Event = asyncio.Event()

        self._key: Any = key
        self._notify_queue: asyncio.Queue = queue
        self._lock: asyncio.Lock = asyncio.Lock()
        self._maxsize: int = maxsize
        self._timer: Optional[Timer] = (
            Timer(maxtime, self.flush) if maxtime > 0 else None
        )

    async def append(self, payload: Payload):
        if not self._payloads and self._timer:
            # first payload arrived: start the flush timer
            self._timer.start()

        payload_size = len(payload)

        if self._maxsize > 0 and payload_size > self._maxsize:
            raise PayloadTooBig

        # appending would exceed maxsize: flush the current batch first
        if self._upper(payload_size):
            await self.flush()

        await self._append_nowait(payload)

    async def flush(self):
        if self._flushing_payloads:
            # block until last flushing done
            logger.debug("waiting last flush done...")
            await self._flush_done.wait()

        if self._timer:
            self._timer.stop()

        # no current flushing, trigger it
        self._flush_nowait()

    def pop(self):
        return self._flushing_payloads, self._flushing_size

    async def post_flush(self):
        self._flushing_payloads = []
        self._flushing_size = 0
        self._flush_done.set()

    def exit(self):
        if self._timer:
            self._timer.exit()

    def _flush_nowait(self):
        self._flushing_payloads = self._payloads
        self._flushing_size = self._size
        self._payloads = []
        self._size = 0
        self._flush_done.clear()
        self._notify_queue.put_nowait(self._key)

    async def _append_nowait(self, payload: Payload):
        """Put a payload into the payloads without blocking.

        If no free bytes is immediately available, raise PayloadsFull.
        """
        payload_size = len(payload)
        if self._upper(payload_size):
            raise PayloadsFull

        # FIXME: is this lock really needed?
        async with self._lock:
            self._payloads.append(payload)
            self._size += payload_size

    def _upper(self, size):
        """Return True if there are not exceed maxsize bytes.

        Note: if the Payloads was initialized with maxsize=0 (the default),
        then _upper() is never True.
        """
        if self._maxsize <= 0:
            return False
        else:
            return (self._size + size) > self._maxsize
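

# --- Illustrative sketch (not part of the original module) -------------------
# A PayloadGroup buffers Payloads until the size trigger or the timer fires;
# flush() swaps the buffer into the "flushing" slot and pushes the group key
# onto the notify queue, where a worker pop()s the batch, writes it, and then
# calls post_flush().  _DemoPayload and _demo_payload_group_usage are
# hypothetical helpers written only for this sketch.
class _DemoPayload(Payload):
    def __init__(self, data: bytes):
        self._data = data

    def __len__(self):
        return len(self._data)


async def _demo_payload_group_usage():
    queue: asyncio.Queue = asyncio.Queue()
    group = PayloadGroup(queue, key=("demo-stream", 0), maxsize=64)

    await group.append(_DemoPayload(b"hello"))  # buffered, below the 64-byte trigger
    await group.flush()                         # hand the batch to a worker

    group_key = await queue.get()               # worker side: woken by the group key
    payloads, size = group.pop()                # take the flushing batch
    logger.debug("would write %d payloads (%d bytes) for %s",
                 len(payloads), size, group_key)
    await group.post_flush()                    # mark the batch as done
# ------------------------------------------------------------------------------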


# -----------------------------------------------------------------------------


class AppendPayload(Payload):
    _payload_bin: bytes
    _payload_type: ApiPb.HStreamRecordHeader.Flag

    def __init__(
        self,
        payload: Union[bytes, str, Dict[Any, Any]],
        key: Optional[str] = None,
    ):
        self.payload: Union[bytes, str, Dict[Any, Any]] = payload
        self._payload_bin, self._payload_type = encode_payload(self.payload)
        self.key: Optional[str] = key

    def __len__(self):
        return len(self._payload_bin)


class BufferedProducer:
    StreamKeyId = int
    GroupKeyTy = Tuple[str, StreamKeyId]  # (stream_name, shard_id)

    class AppendCallback(abc.ABC):
        @abc.abstractmethod
        def on_success(
            self,
            stream_name: str,
            payloads: List[AppendPayload],
            stream_keyid: int,
        ):
            ...

        @abc.abstractmethod
        def on_fail(
            self,
            stream_name: str,
            payloads: List[AppendPayload],
            stream_keyid: int,
            e: Exception,
        ):
            ...

    def __init__(
        self,
        flush_coro: Callable[
            [str, List[AppendPayload], int, Optional[str], int],
            Awaitable[Iterator[RecordId]],
        ],
        find_stream_key_id_coro: Callable[
            [str, Optional[str]], Awaitable[StreamKeyId]
        ],
        append_callback: Optional[Type[AppendCallback]] = None,
        size_trigger=0,
        time_trigger=0,
        workers=1,
        retry_count=0,
        retry_max_delay=60,  # seconds
        compresstype=None,
        compresslevel=9,
    ):
        if workers < 1:
            raise ValueError("workers must be no less than 1")

        self._group: Dict[BufferedProducer.GroupKeyTy, PayloadGroup] = {}
        self._size_trigger = size_trigger
        self._time_trigger = time_trigger
        self._retry_count = retry_count
        self._retry_max_delay = retry_max_delay
        self._flush_coro = flush_coro
        self._compresstype = compresstype
        self._compresslevel = compresslevel
        self._find_stream_key_id_coro = find_stream_key_id_coro
        self._append_callback = append_callback

        self._queues = [asyncio.Queue() for _ in range(workers)]
        self._workers = [
            asyncio.create_task(self._loop_queue(self._queues[i]))
            for i in range(workers)
        ]

    async def append(
        self,
        stream_name: str,
        payload: Union[bytes, str, Dict[Any, Any]],
        key: Optional[str] = None,
    ):
        group_key = await self._fetch_group_key(stream_name, key)
        bpayload = AppendPayload(payload, key=key)

        payloads: PayloadGroup
        if group_key not in self._group:
            payloads = PayloadGroup(
                self._find_queue(group_key),
                group_key,
                maxsize=self._size_trigger,
                maxtime=self._time_trigger,
            )
            self._group[group_key] = payloads
        else:
            payloads = self._group[group_key]

        await payloads.append(bpayload)

    async def flush(self, stream_name: str, shard_id: int):
        group_key = self._cons_group_key(stream_name, shard_id)
        await self._flush(group_key)

    async def flush_by_key(self, stream_name: str, key: Optional[str] = None):
        group_key = await self._fetch_group_key(stream_name, key)
        await self._flush(group_key)

    async def flushall(self):
        for _, payloads in self._group.items():
            await payloads.flush()

    async def close(self):
        for _, pg in self._group.items():
            pg.exit()
        await self.flushall()
        for q in self._queues:
            await q.put(None)  # signal each worker to stop after draining

    async def wait(self):
        try:
            await asyncio.gather(
                *[pg._timer._task for _, pg in self._group.items() if pg._timer]
            )
        except asyncio.CancelledError:
            pass
        await asyncio.gather(*self._workers)

    async def wait_and_close(self):
        await self.close()
        await self.wait()

    # -------------------------------------------------------------------------

    async def _flush(self, group_key):
        payloads = self._group.get(group_key)
        if not payloads:
            raise ValueError("No such payloads!")
        await payloads.flush()

    async def _loop_queue(self, queue):
        while True:
            group_key = await queue.get()
            if group_key is None:
                break
            payload_group = self._group[group_key]
            stream_name, stream_keyid = self._uncons_group_key(group_key)
            await self._flusing_worker(stream_name, payload_group, stream_keyid)
            queue.task_done()

    async def _flusing_worker(self, stream_name, payload_group, stream_keyid):
        payloads, _size = payload_group.pop()
        logger.debug(
            f"Flushing stream <{stream_name},{stream_keyid}> "
            f"with {len(payloads)} batches..."
        )
        retries = 0
        while True:
            try:
                await self._flush_coro(
                    stream_name,
                    payloads,
                    stream_keyid,
                    compresstype=self._compresstype,
                    compresslevel=self._compresslevel,
                )
                await payload_group.post_flush()
            except Exception as e:
                # TODO: should be a specific append exception
                if self._retry_count < 0 or retries < self._retry_count:
                    logger.debug(
                        f"Retrying {retries} with max delay "
                        f"{self._retry_max_delay}s..."
                    )
                    # exponential backoff, capped at retry_max_delay when set
                    await asyncio.sleep(
                        min(2**retries, self._retry_max_delay)
                        if self._retry_max_delay >= 0
                        else 2**retries
                    )
                    retries += 1
                    continue
                else:
                    if self._append_callback:
                        return self._append_callback.on_fail(
                            stream_name, payloads, stream_keyid, e
                        )
                    else:
                        raise e
            break

        if self._append_callback:
            return self._append_callback.on_success(
                stream_name, payloads, stream_keyid
            )

    def _find_queue(self, group_key):
        # deterministically map a group key onto one of the worker queues
        return self._queues[abs(hash(group_key)) % len(self._queues)]

    async def _fetch_group_key(self, name, key):
        shard_id = await self._find_stream_key_id_coro(name, key)
        return self._cons_group_key(name, shard_id)

    @staticmethod
    def _cons_group_key(name, shard_id):
        return (name, shard_id)

    @staticmethod
    def _uncons_group_key(group_key):
        return (group_key[0], group_key[1])
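

# --- Illustrative sketch (not part of the original module) -------------------
# Wiring a BufferedProducer against stub coroutines instead of a real HStreamDB
# connection: _demo_flush stands in for the client's append coroutine and
# _demo_find_shard for its shard lookup; both names (and the trigger values)
# are hypothetical, chosen only for this sketch.  In real use the client
# supplies these coroutines.
async def _demo_buffered_producer_usage():
    async def _demo_flush(
        stream_name, payloads, stream_keyid, compresstype=None, compresslevel=9
    ):
        logger.debug(
            "would append %d payloads to <%s,%s>",
            len(payloads), stream_name, stream_keyid,
        )
        return iter([])  # a real implementation returns the appended RecordIds

    async def _demo_find_shard(stream_name, key):
        return 0  # pretend every key maps to shard 0

    producer = BufferedProducer(
        _demo_flush,
        _demo_find_shard,
        size_trigger=4096,  # flush once 4 KiB is buffered ...
        time_trigger=1,     # ... or after one second
        workers=2,
    )
    await producer.append("demo-stream", {"temperature": 22}, key="sensor-1")
    await producer.wait_and_close()  # flush everything and stop the workers
# ------------------------------------------------------------------------------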