# ------------------------------------------------------------------------------ # # Project: vsq # Authors: Fabian Schindler # # ------------------------------------------------------------------------------ # Copyright (C) 2021 EOX IT Services GmbH # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies of this Software or works derived from this Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
# ------------------------------------------------------------------------------

import sys
import enum
import json
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from traceback import format_tb
import logging
from importlib import import_module
from time import sleep, time

from redis import Redis, TimeoutError
import click

from .common import (
    TaskStatus, MessageType, now, QueueClosedError, TaskFailedException,
    TaskTimeoutException, TaskCommon,
)
from .logging import setup_logging


logger = logging.getLogger(__name__)

# Seconds to sleep between SPOP attempts when emulating a blocking pop on a
# Redis set (Redis offers BLPOP/BRPOP for lists but no blocking SPOP).
_SPOP_POLL_INTERVAL = 0.1


class ResponseChannel:
    """Interface for sending and awaiting the single response of one task.

    ``wait_for_response`` returns the raw response payload, or ``None`` when
    ``timeout`` seconds pass without one arriving.
    """

    def send_response(self, response: bytes):
        raise NotImplementedError()

    def wait_for_response(self, timeout: float = None):
        raise NotImplementedError()


class ListResponseChannel(ResponseChannel):
    """Response channel backed by a Redis list (LPUSH to send, BRPOP to wait).

    :param redis: client used for all commands.
    :param channel_name: key of the list that carries the response.
    :param expires: TTL in seconds applied to the list after each send, so a
        response nobody consumes eventually expires; ``None`` disables it.
    """

    def __init__(self, redis: Redis, channel_name: str,
                 expires: Optional[int] = 120):
        self.redis = redis
        self.channel_name = channel_name
        self.expires = expires

    def send_response(self, response: bytes):
        self.redis.lpush(self.channel_name, response)
        if self.expires is not None:
            self.redis.expire(self.channel_name, self.expires)

    def wait_for_response(self, timeout: float = None):
        """Block until a response arrives; return it, or ``None`` on timeout.

        ``timeout`` is forwarded to BRPOP (redis-py treats ``None`` as 0,
        i.e. wait indefinitely).
        """
        result = self.redis.brpop(self.channel_name, timeout)
        if result is None:
            # BRPOP replies nil when the timeout elapses without a response.
            return None
        _, response = result
        return response


class PubSubResponseChannel(ResponseChannel):
    """Response channel backed by Redis Pub/Sub.

    The channel is subscribed in the constructor so that a producer is
    already listening before its task is queued: Pub/Sub does not buffer
    messages for subscribers that join later.
    """

    def __init__(self, redis: Redis, channel_name: str):
        self.redis = redis
        self.channel_name = channel_name
        self.pubsub = redis.pubsub()
        self.pubsub.subscribe(channel_name)

    def send_response(self, response: bytes):
        self.redis.publish(self.channel_name, response)

    def wait_for_response(self, timeout: float = None):
        """Return the first published payload, or ``None`` on timeout.

        Non-data messages (e.g. the subscribe confirmation) are skipped.
        ``timeout=None`` waits indefinitely.
        """
        deadline = None if timeout is None else time() + timeout
        while True:
            if deadline is None:
                remaining = None
            else:
                remaining = deadline - time()
                if remaining <= 0.0:
                    return None
            message = self.pubsub.get_message(timeout=remaining)
            # get_message returns None when nothing arrived within `remaining`.
            if message is not None and message['type'] == 'message':
                return message['data']


@dataclass
class Task(TaskCommon):
    """A queued unit of work plus the channel its outcome is reported on.

    Consumers wrap the handler call in ``with task:``; producers call
    :meth:`get` to await the outcome. The serialized fields (``status``,
    ``id``, ``message``, ``result``, ``error``, ``created``, ``started``,
    ``finished``) are inherited from ``TaskCommon``.
    """

    # Transport for the outcome; not serialized by encode()/decode().
    response_channel: Optional[ResponseChannel] = None

    # Consumer API

    def __enter__(self):
        """Mark the task as processing and stamp its start time."""
        self.status = TaskStatus.PROCESSING
        self.started = now()
        return self

    def __exit__(self, etype, value, traceback):
        """Record the outcome and send it on the response channel.

        Exceptions derived from ``Exception`` are stored in ``error`` and
        suppressed so a consumer loop keeps running. Other ``BaseException``
        types (``KeyboardInterrupt``, ``SystemExit``) are still reported as a
        failure, then propagate so the consumer can be stopped.
        """
        self.finished = now()
        if etype is None:
            self.status = TaskStatus.DONE
        else:
            self.status = TaskStatus.FAILED
            self.error = {
                'type': str(etype),
                'value': str(value),
                'traceback': format_tb(traceback),
            }

        if self.response_channel is not None:
            self.response_channel.send_response(self.encode())
        else:
            logger.warning(
                "task %s has no response channel; outcome not sent", self.id
            )

        return etype is None or issubclass(etype, Exception)

    def encode(self):
        """Serialize the task (without its response channel) to JSON text."""

        def default(o):
            if isinstance(o, datetime):
                return o.isoformat()
            elif isinstance(o, Exception):
                return str(o)
            elif isinstance(o, TaskStatus):
                return o.value
            # NOTE(review): any other non-JSON type falls through to None and
            # is silently serialized as null.

        dct = {
            'status': self.status,
            'id': self.id,
            'message': self.message,
            'result': self.result,
            'error': self.error,
            'created': self.created,
            'started': self.started,
            'finished': self.finished,
        }
        return json.dumps(dct, default=default)

    @classmethod
    def decode(cls, raw_value) -> 'Task':
        """Rebuild a task from the JSON produced by :meth:`encode`.

        ``response_channel`` is not part of the payload and is left unset.
        """
        values = json.loads(raw_value)

        def optional_datetime(key):
            # `started`/`finished` are null until the task reaches that stage.
            raw = values[key]
            return datetime.fromisoformat(raw) if raw else None

        return cls(
            status=TaskStatus(values['status']),
            id=values['id'],
            message=values['message'],
            result=values['result'],
            error=values['error'],
            created=datetime.fromisoformat(values['created']),
            started=optional_datetime('started'),
            finished=optional_datetime('finished'),
        )

    # Producer API

    def get(self, timeout: float = None) -> MessageType:
        """Wait for the task's outcome and return its result.

        :param timeout: seconds to wait; ``None`` waits indefinitely.
        :raises TaskTimeoutException: if no response arrives within
            ``timeout``.
        :raises TaskFailedException: if the task failed; carries the
            consumer-side ``error`` dict.
        """
        raw = self.response_channel.wait_for_response(timeout)
        if raw is None:
            raise TaskTimeoutException()
        received = self.decode(raw)
        if received.status == TaskStatus.FAILED:
            raise TaskFailedException(received.error)
        return received.result


class MessageScheme(enum.Enum):
    """Redis commands used for the task queue: ``<producer>_<consumer>``.

    List schemes use the blocking BLPOP/BRPOP on the consumer side; the set
    scheme (SADD/SPOP) is unordered and is polled.
    """
    LPUSH_RPOP = 'LPUSH_RPOP'
    LPUSH_LPOP = 'LPUSH_LPOP'
    RPUSH_LPOP = 'RPUSH_LPOP'
    RPUSH_RPOP = 'RPUSH_RPOP'
    SADD_SPOP = 'SADD_SPOP'


class ResponseScheme(enum.Enum):
    """Transport for task responses.

    ``LPUSH_RPOP`` uses a :class:`ListResponseChannel`, ``PUBSUB`` a
    :class:`PubSubResponseChannel`.
    """
    LPUSH_RPOP = 'LPUSH_RPOP'
    PUBSUB = 'PUBSUB'


class Queue:
    """A named Redis task queue usable by both producers and consumers.

    :param queue_name: Redis key holding pending tasks.
    :param redis: client used for all commands.
    :param message_scheme: how tasks are pushed and popped.
    :param response_scheme: how task outcomes travel back to the producer.
    :param response_channel_template: ``str.format`` template for the
        per-task response channel name; receives ``id``.
    :param response_channel_expires: TTL in seconds for list-based response
        channels; ``None`` disables expiry. Unused for Pub/Sub.
    """

    def __init__(self, queue_name: str, redis: Redis,
                 message_scheme: MessageScheme = MessageScheme.LPUSH_RPOP,
                 response_scheme: ResponseScheme = ResponseScheme.LPUSH_RPOP,
                 response_channel_template: str = 'response_{id}',
                 response_channel_expires: Optional[int] = 120):
        self.queue_name = queue_name
        self.redis = redis
        self.message_scheme = message_scheme
        self.response_scheme = response_scheme
        self.response_channel_template = response_channel_template
        self.response_channel_expires = response_channel_expires

    def _get_response_channel(self, task):
        """Build the response channel for ``task`` per ``response_scheme``."""
        channel_name = self.response_channel_template.format(id=task.id)
        if self.response_scheme == ResponseScheme.LPUSH_RPOP:
            return ListResponseChannel(
                self.redis, channel_name,
                expires=self.response_channel_expires,
            )
        if self.response_scheme == ResponseScheme.PUBSUB:
            return PubSubResponseChannel(self.redis, channel_name)
        raise ValueError(
            f"unsupported response scheme: {self.response_scheme!r}"
        )

    def put(self, message: MessageType, msg_id: str = None) -> Task:
        """Enqueue ``message`` and return the task used to await its result.

        :param message: payload handed to the consumer's handler.
        :param msg_id: optional explicit task id; otherwise the ``Task``
            default is used.
        """
        if msg_id is not None:
            task = Task(id=msg_id, message=message)
        else:
            task = Task(message=message)

        # Create (and for Pub/Sub, subscribe) the response channel before
        # pushing, so a fast consumer cannot answer before we are listening.
        task.response_channel = self._get_response_channel(task)

        encoded = task.encode()
        if self.message_scheme in (MessageScheme.LPUSH_RPOP,
                                   MessageScheme.LPUSH_LPOP):
            self.redis.lpush(self.queue_name, encoded)
        elif self.message_scheme in (MessageScheme.RPUSH_RPOP,
                                     MessageScheme.RPUSH_LPOP):
            self.redis.rpush(self.queue_name, encoded)
        elif self.message_scheme == MessageScheme.SADD_SPOP:
            self.redis.sadd(self.queue_name, encoded)
        else:
            raise ValueError(
                f"unsupported message scheme: {self.message_scheme!r}"
            )
        return task

    def _spop_blocking(self, timeout):
        """Emulate a blocking SPOP by polling; ``None`` once ``timeout`` passes.

        As with BLPOP/BRPOP, a ``timeout`` of ``None`` or ``0`` waits forever.
        """
        deadline = time() + timeout if timeout else None
        while True:
            raw_value = self.redis.spop(self.queue_name)
            if raw_value is not None:
                return raw_value
            if deadline is None:
                sleep(_SPOP_POLL_INTERVAL)
                continue
            remaining = deadline - time()
            if remaining <= 0.0:
                return None
            sleep(min(_SPOP_POLL_INTERVAL, remaining))

    def _pop_raw(self, timeout):
        """Pop one encoded task per ``message_scheme``; ``None`` on timeout."""
        if self.message_scheme in (MessageScheme.RPUSH_LPOP,
                                   MessageScheme.LPUSH_LPOP):
            result = self.redis.blpop(self.queue_name, timeout)
        elif self.message_scheme in (MessageScheme.RPUSH_RPOP,
                                     MessageScheme.LPUSH_RPOP):
            result = self.redis.brpop(self.queue_name, timeout)
        elif self.message_scheme == MessageScheme.SADD_SPOP:
            # SPOP is non-blocking and its 2nd argument is a count, not a
            # timeout, so blocking semantics are emulated by polling.
            return self._spop_blocking(timeout)
        else:
            raise ValueError(
                f"unsupported message scheme: {self.message_scheme!r}"
            )
        # BLPOP/BRPOP reply with a (key, value) pair, or nil on timeout.
        return None if result is None else result[1]

    def get_task(self, timeout: float = None) -> Task:
        """Pop the next task, attaching its response channel.

        :param timeout: seconds to wait; ``None`` waits indefinitely.
        :raises TaskTimeoutException: if no task arrives in time, or the
            Redis connection times out.
        """
        try:
            raw_value = self._pop_raw(timeout)
            if raw_value is None:
                raise TaskTimeoutException()
            task = Task.decode(raw_value)
            task.response_channel = self._get_response_channel(task)
            return task
        except TimeoutError as exc:
            raise TaskTimeoutException() from exc

    def __iter__(self):
        """Yield tasks forever, blocking until each one is available."""
        while True:
            yield self.get_task()


@click.group()
@click.option('--host', type=str)
@click.option('--port', show_default=True, default=6379, type=int)
@click.option('--debug', type=bool)
@click.pass_context
def cli(ctx, host, port, debug):
    # Shared setup for all sub-commands: the Redis client and debug flag are
    # passed on through ctx.obj.
    ctx.ensure_object(dict)
    ctx.obj['redis'] = Redis(host=host, port=port)
    ctx.obj['debug'] = debug
    setup_logging(debug)
    return 0


@cli.command()
@click.argument('name', type=str)
@click.argument('handler', type=str)
@click.pass_context
def daemon(ctx, name, handler):
    """Start a task daemon listening on the specified queue"""
    # HANDLER is a dotted path "package.module.function"; the function is
    # called with each task's message and its return value is the result.
    handler_mod, _, handler_name = handler.rpartition('.')
    if not handler_mod or not handler_name:
        raise click.BadParameter(
            "expected a dotted path like 'package.module.function'",
            param_hint="HANDLER",
        )
    handler_func = getattr(import_module(handler_mod), handler_name)

    queue = Queue(name, ctx.obj['redis'])
    logger.debug(f"waiting for tasks on queue '{name}'...")
    for task in queue:
        # Task.__exit__ reports success or failure on the response channel
        # and suppresses ordinary handler exceptions so the loop continues.
        with task:
            task.result = handler_func(task.message)


@cli.command()
@click.argument('name', type=str)
@click.argument('value', type=str)
@click.option('-j', '--json', 'as_json', is_flag=True, type=bool)
@click.option('-w', '--wait', is_flag=True, type=bool)
@click.option('-t', '--timeout', type=int, default=None)
@click.pass_context
def message(ctx, name, value, as_json=False, wait=False, timeout=None):
    """Send a message to the specified queue"""
    payload = json.loads(value) if as_json else value
    queue = Queue(name, ctx.obj['redis'])
    task = queue.put(payload)
    if not wait:
        return 0

    # In standalone mode click ignores return values, so failures must go
    # through ctx.exit() to produce a non-zero exit status.
    try:
        result = task.get(timeout)
    except TaskTimeoutException:
        logger.error("timed out waiting for a response to task %s", task.id)
        ctx.exit(1)
    except TaskFailedException as e:
        logger.exception(e)
        ctx.exit(1)
    else:
        print(result)
        return 0


if __name__ == "__main__":
    sys.exit(cli())  # pragma: no cover