Distributed Locks for Django Celery Tasks

Implementing fast and efficient locks that work with distributed async tasks

If you are running asynchronous (background) tasks in a Django application, Celery is the go-to option without a second thought. In a previous post, I outlined some considerations for production-ready Celery tasks. In this post, we will learn how to safely handle shared resources when many Celery tasks modify them. This is going to be a slightly long post, but I promise you will learn the following:

  • Need for locks when running Celery tasks
  • Implementing locks that work in a distributed environment like Celery
  • Key considerations when using such locks in a production setup

We will consider a close-to-real-world use case and walk through example code snippets along the way.

Before we dive into Celery and locks, let's quickly define shared resources. If you are a backend engineer, you are probably already familiar with the idea. Still, to set the context, I will briefly cover it here.

Shared resource

A shared resource is a piece of memory or storage that is shared by multiple threads or processes, any of which may try to change it at the same time. To make the concepts easier to relate to, let's consider a real-world use case.

Use case at hand

Let's say that we want to log page visits on an e-commerce store to a database and show the number of visits for the day as a real-time counter. We have a Celery task that records each page visit event in the database. When you are handling millions of visitors per day, showing that count in real time is not trivial, for the following reasons:

  • Aggregating the day's visit records in the database is expensive if we want the count to be real-time
  • Since we want the counter to keep ticking in real time, the repeated aggregation queries will put unnecessary load on the database server

So, a better-performing approach is to maintain a counter in a global cache and have the Celery tasks increment it. Now, let's get our hands dirty and put together a Django model and a Celery task for this use case.

# these are all the imports needed across all the code snippets in this post;
# we will omit the imports in the rest of the snippets for brevity.
import threading
import time
import uuid
from contextlib import contextmanager

from celery import shared_task
from django.core.cache import cache
from django.db import models

class PageVisit(models.Model):  
    id = models.UUIDField(primary_key=True, default=uuid.uuid4)  
    timestamp = models.DateTimeField(auto_now_add=True)  
    data = models.JSONField()

@shared_task()  
def page_visit_logger(data: dict):  
    PageVisit.objects.create(data=data)  
    counter_key = 'page-visit-counter'  
    counter = cache.get(counter_key)  
    if counter is None:  
        counter = 1  
    else:  
        counter += 1
    # resetting the counter at midnight is left out for simplicity
    cache.set(counter_key, counter)

Essentially, the task adds a database entry and increments the counter stored in the Django cache. The counter is a shared resource in this example, since multiple Celery tasks will try to alter its value.
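
To see what can go wrong, here is a minimal, deterministic sketch of two tasks interleaving the read, increment, and write steps of page_visit_logger. A plain dict (the hypothetical fake_cache) stands in for the Django cache, and the interleaving is forced by hand rather than produced by real concurrency.

```python
# A dict stands in for the shared cache; the two "tasks" are interleaved by
# hand so the read-modify-write race shows up deterministically.
fake_cache = {}
COUNTER_KEY = "page-visit-counter"


def read_counter():
    return fake_cache.get(COUNTER_KEY)


def write_incremented(value_read):
    # Same logic as page_visit_logger: start at 1, otherwise add one.
    fake_cache[COUNTER_KEY] = 1 if value_read is None else value_read + 1


# Both tasks read before either of them writes...
seen_by_task_a = read_counter()
seen_by_task_b = read_counter()

# ...so both write 1, and one of the two visits is lost.
write_incremented(seen_by_task_a)
write_incremented(seen_by_task_b)

print(fake_cache[COUNTER_KEY])  # 1, although two visits were recorded
```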

Need for a lock

All looks good with the above Celery task when you are testing locally by visiting a sample page, but the situation is different when you deploy this code to production! What if more than one Celery task executes at the same time and tries to read and increment the counter? Two tasks may read the same value and write back the same incremented number, losing one visit. This read-modify-write race is the fundamental flaw when tasks run in parallel. If you are familiar with multithreading in Python, you can protect the counter with a lock from the threading module, as below:

counter_lock = threading.Lock()  

@shared_task()  
def page_visit_logger(data: dict):  
    PageVisit.objects.create(data=data)  
    counter_key = 'page-visit-counter'  
    with counter_lock:  
        counter = cache.get(counter_key)  
        if counter is None:  
            counter = 1  
        else:  
            counter += 1  
        cache.set(counter_key, counter)

Above, we take a thread lock before reading the counter from the cache and writing the incremented value. This works as expected when multiple tasks run in parallel as threads of the same process. But there is a catch: a threading lock only protects code within a single process. With Celery's default prefork pool, each worker child process gets its own independent copy of the lock, and even multiprocessing locks are limited to a single machine!

Typically, in production, you will have a Celery cluster consisting of multiple worker nodes, and the same concurrency problem appears across them. To solve it, we need a lock that lives outside the Celery cluster, provides atomic operations, and is performant (fast to acquire and release). A distributed cache server like memcached is a perfect fit for this. Let's see how it can help us with the necessary locking.

Distributed lock with memcached

We can build a lock that works in a distributed setup by setting a key-value pair in memcached. Here is how you can do it.
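
A minimal sketch of the two helpers that the task below calls, assuming a memcached-backed Django cache. The names acquire_counter_lock and release_counter_lock come from that task; the key name COUNTER_LOCK_KEY and the FakeCache class are hypothetical. FakeCache only mimics the add and delete behaviour of django.core.cache.cache so that the snippet runs on its own; in the Django app you would drop it and use the cache imported at the top of the post.

```python
import threading
import time


class FakeCache:
    """Hypothetical stand-in for a memcached-backed django.core.cache.cache.

    Only add() and delete() are mimicked, with per-key expiry.
    """

    def __init__(self):
        self._data = {}  # key -> (value, expires_at)
        self._guard = threading.Lock()  # makes add() atomic, as memcached does

    def add(self, key, value, timeout):
        with self._guard:
            item = self._data.get(key)
            if item is not None and time.monotonic() < item[1]:
                return False  # key present and not expired: add fails
            self._data[key] = (value, time.monotonic() + timeout)
            return True

    def delete(self, key):
        with self._guard:
            self._data.pop(key, None)


cache = FakeCache()  # in the Django app: from django.core.cache import cache

COUNTER_LOCK_KEY = "page-visit-counter-lock"  # hypothetical key name
COUNTER_LOCK_TIMEOUT = 30  # seconds; no task should hold the lock longer


def acquire_counter_lock() -> bool:
    # add() only succeeds if the key does not exist yet, so at most one
    # caller at a time gets True back.
    return cache.add(COUNTER_LOCK_KEY, True, COUNTER_LOCK_TIMEOUT)


def release_counter_lock() -> None:
    # Deleting the key lets the next add() succeed.
    cache.delete(COUNTER_LOCK_KEY)


print(acquire_counter_lock())  # True: the lock was free
print(acquire_counter_lock())  # False: it is already held
release_counter_lock()
print(acquire_counter_lock())  # True again after the release
```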

Nothing complicated. We just add a key to the cache to acquire the lock (the value can be anything; we set a boolean here) and delete the key to release the lock. Adding a key fails and returns False if the key is already present in the cache, in which case we fail to acquire the lock. Memcached performs this add atomically on the server, so only one task can succeed at a time. We also set a timeout on the cache key, as we don't expect any task to hold the lock for more than 30 seconds in this case. We are only operating on a counter stored in a cache, after all!

Equipped with this, let's try to modify the celery task and make it work in a distributed setup.

@shared_task()  
def page_visit_logger(data: dict):  
    PageVisit.objects.create(data=data)  
    counter_key = 'page-visit-counter'  
    acquired = acquire_counter_lock()  

    # just block the task until we acquire the lock
    while not acquired:  
        time.sleep(0.01)  
        acquired = acquire_counter_lock()  

    counter = cache.get(counter_key)  
    if counter is None:  
        counter = 1  
    else:  
        counter += 1  
    cache.set(counter_key, counter)
    # we are done updating the shared counter, release the lock
    release_counter_lock()

Note that a task that fails to acquire the lock keeps waiting, with no upper bound, until it gets the lock.

This is cool stuff. We have a lock that works for distributed Celery tasks irrespective of how many worker nodes we add as we scale our application. However, this implementation still has some drawbacks:

  • Each task must remember to release the lock. If a task fails with an exception while holding the lock, the lock stays acquired until the timeout expires, which keeps other tasks waiting.
  • The blocking logic is part of the task
  • This lock will only work for this specific task

Let's address these issues with the following changes to our implementation:

  • Make the lock usable with Python's with statement (as a context manager), so that it is released automatically even if the task that acquired it fails with an exception
  • Move the blocking logic into the lock and let the task control whether, and for how long, to block
  • Make the lock more generic so that it can be used for multiple types of tasks

@contextmanager
def distributed_lock(key: str, timeout: int = 30, max_blocking: int = 0):
    # remember when each add() is attempted; the key expires `timeout`
    # seconds after the add() that succeeds
    attempted_at = time.monotonic()
    acquired = cache.add(key, True, timeout)

    # if the caller wants to block, sleep wait and keep checking
    if not acquired and max_blocking:
        block_until = time.monotonic() + max_blocking
        while time.monotonic() < block_until:
            time.sleep(0.01)
            attempted_at = time.monotonic()
            acquired = cache.add(key, True, timeout)
            if acquired:
                break
    timeout_at = attempted_at + timeout
    try:
        yield acquired
    finally:
        # if we acquired the lock and the caller exits the context before
        # the cache key expires, delete the key to release the lock; after
        # expiry the key may belong to another task, so leave it alone
        if acquired and time.monotonic() < timeout_at:
            cache.delete(key)

Now any Celery task can use the above lock with a with statement, specifying its own timeout and an optional blocking period.
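
A usage sketch under stated assumptions: FakeCache is a hypothetical in-process stand-in for the memcached-backed Django cache, increment_visit_counter is a hypothetical reduction of page_visit_logger to its counter update, and the lock follows distributed_lock above with the blocking loop written slightly more compactly. Twenty threads play the role of concurrent tasks, and the last part checks that an exception inside the with block still releases the lock.

```python
import threading
import time
from contextlib import contextmanager


class FakeCache:
    """Hypothetical in-process stand-in for a memcached-backed Django cache.

    Mimics add/get/set/delete with per-key expiry; add() is atomic.
    """

    def __init__(self):
        self._data = {}  # key -> (value, expires_at)
        self._guard = threading.Lock()

    def _live(self, key):
        # Return the stored (value, expires_at) pair, dropping expired keys.
        item = self._data.get(key)
        if item is not None and time.monotonic() >= item[1]:
            del self._data[key]
            item = None
        return item

    def add(self, key, value, timeout=300):
        with self._guard:
            if self._live(key) is not None:
                return False  # key exists: add() fails, like memcached
            self._data[key] = (value, time.monotonic() + timeout)
            return True

    def get(self, key, default=None):
        with self._guard:
            item = self._live(key)
            return default if item is None else item[0]

    def set(self, key, value, timeout=300):
        with self._guard:
            self._data[key] = (value, time.monotonic() + timeout)

    def delete(self, key):
        with self._guard:
            self._data.pop(key, None)


cache = FakeCache()  # in the Django app: from django.core.cache import cache


@contextmanager
def distributed_lock(key, timeout=30, max_blocking=0):
    attempted_at = time.monotonic()
    acquired = cache.add(key, True, timeout)
    # optionally sleep-wait and retry until max_blocking seconds have passed
    block_until = attempted_at + max_blocking
    while not acquired and time.monotonic() < block_until:
        time.sleep(0.01)
        attempted_at = time.monotonic()
        acquired = cache.add(key, True, timeout)
    try:
        yield acquired
    finally:
        # release only a lock we hold whose key has not expired yet
        if acquired and time.monotonic() < attempted_at + timeout:
            cache.delete(key)


def increment_visit_counter(counter_key="page-visit-counter"):
    # Hypothetical stand-in for the counter part of page_visit_logger.
    with distributed_lock(f"{counter_key}-lock", max_blocking=5) as acquired:
        if not acquired:
            return False  # a real task might raise and retry here
        counter = cache.get(counter_key)
        cache.set(counter_key, 1 if counter is None else counter + 1)
        return True


# 20 concurrent "tasks" all update the counter under the lock.
workers = [threading.Thread(target=increment_visit_counter) for _ in range(20)]
for worker in workers:
    worker.start()
for worker in workers:
    worker.join()
print(cache.get("page-visit-counter"))  # 20

# An exception inside the with block still releases the lock.
try:
    with distributed_lock("demo-lock") as acquired:
        raise RuntimeError("task failed while holding the lock")
except RuntimeError:
    pass
with distributed_lock("demo-lock") as acquired:
    print(acquired)  # True: the key was deleted on the way out
```

A task that gets False back (here, after waiting up to 5 seconds) is the case the retry consideration below is about.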

Note that, throughout this post, we have assumed that the default cache in the Django app is memcached. If the cache is in-memory (for example, Django's local-memory backend, which keeps a separate cache per process), it will be of no use in the distributed scenario. Now that we have a functional lock, it's time to conclude, but first, some key considerations to keep in mind:

  • Failing to release the lock - If a task fails to release the lock, other tasks waiting for it are blocked until the cache key times out. The context manager avoids this scenario, unless the worker process itself is killed while holding the lock.
  • Execution time of protected code - If the code inside the lock context (the protected code) runs longer than the cache timeout, other tasks may acquire the lock while it is still running, resulting in unintended behavior. In this use case, we may get a wrong visitor count. Hence, make sure that the protected code finishes well within the timeout. On the other hand, if the timeout is very long and a lock is ever left behind, other tasks will sit waiting on it and occupy worker slots for no real value.
  • Task retry upon failure to acquire the lock - If a task fails to acquire the lock even after the blocking period, it is better to raise an exception and retry the task after some time has passed. I have outlined task retries in Using Celery in Django Production Setup.
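
The second consideration is easy to reproduce. In this sketch, FakeCache is again a hypothetical stand-in for memcached and the lock has the same shape as distributed_lock above; the two "tasks" are driven by hand through __enter__ and __exit__ so they overlap. Task A's protected code overruns its timeout, task B then acquires the same key, and A's late exit leaves B's key alone only because of the expiry check in the finally clause. That check prevents A from deleting B's lock, but it cannot stop A and B from running the protected code at the same time.

```python
import threading
import time
from contextlib import contextmanager


class FakeCache:
    """Hypothetical stand-in for a memcached-backed Django cache.

    Only add() and delete() are mimicked; values are not stored because
    the lock only cares whether an unexpired key exists.
    """

    def __init__(self):
        self._expiry = {}  # key -> expires_at (time.monotonic() seconds)
        self._guard = threading.Lock()  # makes add() atomic

    def add(self, key, value, timeout):
        with self._guard:
            expires_at = self._expiry.get(key)
            if expires_at is not None and time.monotonic() < expires_at:
                return False  # an unexpired key exists: add() fails
            self._expiry[key] = time.monotonic() + timeout
            return True

    def delete(self, key):
        with self._guard:
            self._expiry.pop(key, None)


cache = FakeCache()


@contextmanager
def distributed_lock(key, timeout=30, max_blocking=0):
    attempted_at = time.monotonic()
    acquired = cache.add(key, True, timeout)
    block_until = attempted_at + max_blocking
    while not acquired and time.monotonic() < block_until:
        time.sleep(0.01)
        attempted_at = time.monotonic()
        acquired = cache.add(key, True, timeout)
    try:
        yield acquired
    finally:
        # skip the delete once our key has expired: it may now be someone else's
        if acquired and time.monotonic() < attempted_at + timeout:
            cache.delete(key)


# Drive two "tasks" by hand with __enter__/__exit__ so that they overlap.
task_a = distributed_lock("counter-lock", timeout=0.05)
a_acquired = task_a.__enter__()
time.sleep(0.1)  # task A's protected code overruns its 0.05 s timeout

task_b = distributed_lock("counter-lock", timeout=5)
b_acquired = task_b.__enter__()  # A's key has expired, so B gets the lock too

task_a.__exit__(None, None, None)  # A finishes late; its key had expired
still_held = not cache.add("counter-lock", True, 5)  # B's key survived A's exit
task_b.__exit__(None, None, None)

print(a_acquired, b_acquired, still_held)  # True True True
```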

That's it on cache-based locks. Please subscribe and follow if you find my posts useful. And if you are keen to solve the toughest challenges in the e-commerce space and help sellers run profitable businesses at Konigle, join us.

Happy coding!