Skip to content

celery

Updated celery configuration.

handle_worker_shutdown(sender=None, **kwargs)

Update the database for a worker entry when a worker shuts down.

Parameters:

Name Type Description Default
sender str

The hostname of the worker that is shutting down

None
Source code in merlin/celery.py
@worker_shutdown.connect
def handle_worker_shutdown(sender: str = None, **kwargs):
    """
    Mark a worker's database entry as stopped once that worker shuts down.

    The entry is looked up by the sender's name. When it exists, its status
    is set to `WorkerStatus.STOPPED` and its stored pid is cleared. A warning
    is logged instead if Celery supplied no sender or if no matching entry
    is found.

    Args:
        sender (str): The hostname of the worker that is shutting down
    """
    if sender is None:
        LOG.warning("On worker shutdown no sender was provided from Celery.")
        return

    LOG.debug(f"Worker {sender} is shutting down.")
    worker_entry = MerlinDatabase().get("physical_worker", str(sender))
    if not worker_entry:
        LOG.warning(f"Worker {sender} not found in the database.")
        return

    # Record the stopped state and drop the pid of the process that is exiting
    worker_entry.set_worker_status(WorkerStatus.STOPPED)
    worker_entry.set_pid(None)

handle_worker_startup(sender=None, **kwargs)

Store information about each physical worker instance in the database.

When workers first start up, the celeryd_init signal is the first signal that they receive. This specific function will create a PhysicalWorkerModel and store it in the database. It does this through the use of the MerlinDatabase class.

Parameters:

Name Type Description Default
sender str

The hostname of the worker that was just started

None
Source code in merlin/celery.py
@celeryd_init.connect
def handle_worker_startup(sender: str = None, **kwargs):
    """
    Register a newly started physical worker in the database.

    The `celeryd_init` signal is the first signal a worker receives when it
    starts up. This handler reacts to it by creating a
    [`PhysicalWorkerModel`][db_scripts.data_models.PhysicalWorkerModel]
    entry through the
    [`MerlinDatabase`][db_scripts.merlin_db.MerlinDatabase] class and
    attaching that entry to the logical worker it belongs to.

    Nothing is stored (only a warning is logged) when Celery provides no
    sender or no `options` keyword argument. Any error raised while
    talking to the database is logged rather than propagated.

    Args:
        sender (str): The hostname of the worker that was just started
    """
    if sender is None:
        LOG.warning("On worker connect no sender was provided from Celery.")
        return

    LOG.debug(f"Worker {sender} has started.")
    worker_options = kwargs.get("options")
    if worker_options is None:
        LOG.warning("On worker connect could not retrieve worker options from Celery.")
        return

    # Celery does not output errors raised in this handler, so catch and log them here
    try:
        # The sender looks like celery@worker_name.%hostname
        node_name = sender.split("@")[1]
        worker_name, host = node_name.split(".%")
        database = MerlinDatabase()
        logical = database.get("logical_worker", worker_name=worker_name, queues=worker_options.get("queues"))
        physical = database.create(
            "physical_worker",
            name=str(sender),
            host=host,
            worker_status=WorkerStatus.RUNNING.value,
            logical_worker_id=logical.get_id(),
            pid=os.getpid(),
        )
        # Link the new physical worker back to its logical worker
        logical.add_physical_worker(physical.get_id())
    except Exception as exc:
        LOG.error(f"An error occurred when processing handle_worker_startup: {exc}")

patch_celery()

Patch redis backend so that errors in chords don't break workflows. Celery has error callbacks but they do not work properly on chords that are nested within chains.

Credit for this function goes to the following post: https://danidee10.github.io/2019/07/09/celery-chords.html

Source code in merlin/celery.py
def patch_celery():
    """
    Replace the redis backend's chord-result unpacking so that a failed
    chord member does not break the workflow.

    Celery has error callbacks but they do not work properly on chords that
    are nested within chains. The replacement installed here returns a
    `"<ExceptionClass>: <message>"` string for results whose state is in
    `PROPAGATE_STATES`.

    Credit to this function goes to
    [the following post](https://danidee10.github.io/2019/07/09/celery-chords.html).

    Returns:
        The `celery` module, after `RedisBackend._unpack_chord_result` has
            been replaced.
    """

    def _unpack_chord_result(
        self,
        tup,
        decode,
        EXCEPTION_STATES=states.EXCEPTION_STATES,
        PROPAGATE_STATES=states.PROPAGATE_STATES,
    ):
        # Only the state and the return value of the decoded result are used
        _, _task_id, state, retval = decode(tup)

        if state in EXCEPTION_STATES:
            retval = self.exception_to_python(retval)
        if state in PROPAGATE_STATES:
            # retval is an Exception; swap it for a readable description
            error_type = retval.__class__.__name__
            retval = f"{error_type}: {retval!s}"

        return retval

    redis_backend = celery.backends.redis.RedisBackend
    redis_backend._unpack_chord_result = _unpack_chord_result

    return celery

route_for_task(name, args, kwargs, options, task=None, **kw)

Custom task router for Celery queues.

This function routes tasks to specific queues based on the task name. If the task name contains a colon, it splits the name to determine the queue.

Parameters:

Name Type Description Default
name str

The name of the task being routed.

required
args List[Any]

The positional arguments passed to the task.

required
kwargs Dict[Any, Any]

The keyword arguments passed to the task.

required
options Dict[Any, Any]

Additional options for the task.

required
task Task

The task instance (default is None).

None
**kw Dict[Any, Any]

Additional keyword arguments for THIS function (not the task).

{}

Returns:

Type Description
Dict[Any, Any]

A dictionary specifying the queue to route the task to. If the task name contains a colon, it returns a dictionary with the key "queue" set to the queue name. Otherwise, it returns None (the function ends without an explicit return value).

Example

Using a colon in the name will return the string before the colon as the queue:

>>> route_for_task("my_queue:my_task", [], {}, {})
{"queue": "my_queue"}
Source code in merlin/celery.py
def route_for_task(
    name: str,
    args: List[Any],
    kwargs: Dict[Any, Any],
    options: Dict[Any, Any],
    task: celery.Task = None,
    **kw: Dict[Any, Any],
) -> Dict[Any, Any]:  # pylint: disable=W0613,R1710
    """
    Custom task router for Celery queues.

    This function routes tasks to specific queues based on the task name.
    If the task name contains a colon, everything before the FIRST colon is
    used as the queue name; any further colons are left in the (ignored)
    remainder instead of causing an error.

    Args:
        name: The name of the task being routed.
        args: The positional arguments passed to the task (unused).
        kwargs: The keyword arguments passed to the task (unused).
        options: Additional options for the task (unused).
        task: The task instance (default is None; unused).
        **kw: Additional keyword arguments for THIS function (not the task).

    Returns:
        A dictionary specifying the queue to route the task to.
            If the task name contains a colon, it returns a dictionary with
            the key "queue" set to the queue name. Otherwise, it returns
            None, i.e. this router suggests no route for the task.

    Example:
        Using a colon in the name will return the string before the colon as the queue:

        ```python
        >>> route_for_task("my_queue:my_task", [], {}, {})
        {"queue": "my_queue"}
        ```

        A name without a colon yields no route:

        ```python
        >>> route_for_task("my_task", [], {}, {}) is None
        True
        ```
    """
    # partition never raises, unlike unpacking the result of split(":"),
    # which fails when the name holds more than one colon.
    queue, separator, _ = name.partition(":")
    if separator:
        return {"queue": queue}
    return None

setup(**kwargs)

Set affinity for the worker on startup (works on toss3 nodes).

Parameters:

Name Type Description Default
**kwargs Dict[Any, Any]

Keyword arguments.

{}
Source code in merlin/celery.py
@worker_process_init.connect()
def setup(**kwargs: Dict[Any, Any]):  # pylint: disable=W0613
    """
    Set affinity for the worker on startup (works on toss3 nodes).

    Affinity is only changed when the `CELERY_AFFINITY` environment
    variable is set to an integer greater than 1. That value is the number
    of cpus assigned to each worker process: the process's prefork index
    selects a starting cpu slot, and the process is pinned to that many
    consecutive cpus. Cpu indices wrap around at the cpu count, so a slot
    near the end of the cpu range never requests a cpu that does not exist.

    Args:
        **kwargs: Keyword arguments (unused by this handler).
    """
    if "CELERY_AFFINITY" in os.environ and int(os.environ["CELERY_AFFINITY"]) > 1:
        # Number of cpus between workers.
        cpu_skip: int = int(os.environ["CELERY_AFFINITY"])
        npu: int = psutil.cpu_count()
        process: psutil.Process = psutil.Process()
        # pylint is upset that typing accesses a protected class, ignoring W0212
        # pylint is upset that billiard doesn't have a current_process() method - it does
        current: billiard.process._MainProcess = billiard.current_process()  # pylint: disable=W0212, E1101
        prefork_id: int = current._identity[0] - 1  # pylint: disable=W0212  # range 0:nworkers-1
        cpu_slot: int = (prefork_id * cpu_skip) % npu
        # Wrap each index modulo the cpu count; a plain range starting at
        # cpu_slot could run past the last valid cpu index.
        cpus: List[int] = sorted({(cpu_slot + offset) % npu for offset in range(cpu_skip)})
        process.cpu_affinity(cpus)