Functional Dependency Injection in Python

The Problem

In software engineering, dependency injection is a programming technique in which an object or function receives other objects or functions that it requires, as opposed to creating them internally. Dependency injection aims to separate the concerns of constructing objects and using them, leading to loosely coupled programs.

Source: https://en.wikipedia.org/wiki/Dependency_injection

If you've found your way here, you're probably already familiar with the problem of dependency injection.

I'm going to show you a simple way to tackle it using closures.

But remember...

There are no solutions. There are only trade-offs.

- Thomas Sowell

Closures

Let's say we're working with an external library called blockchain.

blockchain has two classes: an API Client and an Account manager:

from collections.abc import Callable
from dataclasses import dataclass, field
from hashlib import sha256
from secrets import token_bytes
from typing import ParamSpec, TypeVar


@dataclass
class Client:
    """Blockchain API client."""

    url: str
    registered_accounts: set[str] = field(default_factory=set)

    def register_account(self, address: str) -> str:
        """Register an account with the blockchain.

        Args:
            address (str): The address of the account to register.

        Returns:
            str: A string confirming the account registration.
        """
        self.registered_accounts.add(address)
        return f"Registered account {address}"


@dataclass
class Account:
    """Blockchain account."""

    private_key: bytes
    address: str

    @classmethod
    def create(cls) -> "Account":
        """Create a new account."""
        private_key = token_bytes()
        address = sha256(private_key).hexdigest()
        return cls(private_key=private_key, address=address)

We want to make it easy for users to register an account on the blockchain, and check that it has been registered successfully.

Let's keep it simple and define some functions:

def register_account(client: Client, account: Account) -> str:
    """Register an account with the blockchain.

    Args:
        client (Client): The client to use.
        account (Account): The account to register.

    Returns:
        str: A string confirming the account registration.
    """
    return client.register_account(account.address)


def is_account_registered(client: Client, account: Account) -> bool:
    """Check if an account is registered with the blockchain.

    Args:
        client (Client): The client to use.
        account (Account): The account to check.

    Returns:
        bool: True if the account is registered, otherwise False.
    """
    return account.address in client.registered_accounts

Notice the pattern: each function requires a Client and an Account to be passed in.

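The benefit of this design is that each function's dependencies are explicit: everything it needs arrives through its arguments. (The functions are not strictly pure, since register_account() changes the client's state.)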

But we've added the slight inconvenience of having to repeatedly pass the same arguments:

>>> account = Account.create()
... client = Client("https://api.blockchain.com")

... print(register_account(client, account))
... print(is_account_registered(client, account))
Registered account b49b0bfbed97ae5d0bbe8c473c79c0644b08c3abac3ce1d8109b989be46f3846
True
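One practical payoff of explicit arguments is that a test can hand the functions a stand-in instead of the real client. The sketch below assumes two hypothetical test doubles, FakeClient and FakeAccount, which mimic only the attributes the functions touch; the two functions are repeated in simplified form so the block runs on its own.

```python
from dataclasses import dataclass, field


@dataclass
class FakeClient:
    """Hypothetical test double: records registrations in memory."""

    registered_accounts: set[str] = field(default_factory=set)
    calls: list[str] = field(default_factory=list)

    def register_account(self, address: str) -> str:
        # Remember each call so a test can inspect it afterwards.
        self.calls.append(address)
        self.registered_accounts.add(address)
        return f"Registered account {address}"


@dataclass
class FakeAccount:
    """Hypothetical stand-in exposing only the address attribute."""

    address: str


def register_account(client, account) -> str:
    return client.register_account(account.address)


def is_account_registered(client, account) -> bool:
    return account.address in client.registered_accounts


fake_client = FakeClient()
fake_account = FakeAccount(address="abc123")  # made-up address

before = is_account_registered(fake_client, fake_account)
message = register_account(fake_client, fake_account)
after = is_account_registered(fake_client, fake_account)
print(before, message, after)
```

Nothing in the functions had to change to make this work, which is the loose coupling that dependency injection is after.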

We can eliminate the repetition with a closure:

T = TypeVar("T")

def provide_context(
    client: Client, account: Account
) -> Callable[[Callable[[Client, Account], T]], T]:
    """A closure that provides client and account arguments to a function.

    Args:
        client (Client): The client to use.
        account (Account): The account to use.

    Returns:
        Callable[[Callable[[Client, Account], T]], T]: A function that takes a callable requiring a client and account, calls it, and returns its result.
    """

    def wrapped(fn: Callable[[Client, Account], T]) -> T:
        """Calls the function with the provided client and account.

        Args:
            fn (Callable[[Client, Account], T]): The function to call.

        Returns:
            T: The result of calling the function.
        """
        return fn(client, account)

    return wrapped

The provide_context() function takes two arguments: a Client and an Account.

It returns another function, which accepts a callable requiring a Client and an Account and calls it with the captured values.

Let's see it in action:

>>> account = Account.create()
... client = Client("https://api.blockchain.com")

... # Define our context once
... with_context = provide_context(client, account)
... # Use the context to call our functions
... print(with_context(register_account))
... print(with_context(is_account_registered))
Registered account e4735977494a46593a5cd2b9bd233d3c629a630ebb792a52b0a6718366fc7405
True

Now we've abstracted away the passing of the Client and Account objects to our functions, without storing them in a class.
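For comparison, the standard library's functools.partial makes a related trade-off: it binds the context to one function at a time instead of producing a reusable context. This is a minimal sketch, using simplified stand-ins for Client and Account (with a fixed, made-up address) so that it runs on its own.

```python
from dataclasses import dataclass, field
from functools import partial


@dataclass
class Client:
    """Simplified stand-in for the Client class above."""

    url: str
    registered_accounts: set[str] = field(default_factory=set)

    def register_account(self, address: str) -> str:
        self.registered_accounts.add(address)
        return f"Registered account {address}"


@dataclass
class Account:
    """Simplified stand-in with a fixed, made-up address."""

    address: str


def register_account(client: Client, account: Account) -> str:
    return client.register_account(account.address)


def is_account_registered(client: Client, account: Account) -> bool:
    return account.address in client.registered_accounts


client = Client("https://api.blockchain.com")
account = Account(address="abc123")

# partial binds the context to each function individually.
register = partial(register_account, client, account)
is_registered = partial(is_account_registered, client, account)

result = register()
registered = is_registered()
print(result)
print(registered)
```

With partial, every function needs its own binding. With provide_context(), the context is defined once and reused for any function that accepts it.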

But what if we need to pass additional arguments to a function, like the one below?

def check_and_greet_user(client: Client, account: Account, user: str) -> str:
    """Check if the user's account is registered and greet them.

    Args:
        client (Client): The client to use.
        account (Account): The account to check.
        user (str): The user to greet.

    Returns:
        str: A greeting for the user.
    """
    if is_account_registered(client, account):
        return f"Hello, {user}!"
    else:
        return f"Hello, {user}! Please register your account."

Let's adapt the provide_context() function so that it accepts any number of arguments as context, and lets us pass additional arguments through to the function being called:

from typing import Any

T = TypeVar("T")


def provide_context(*args: Any, **kwargs: Any) -> Callable[..., Any]:
    """A closure that provides context arguments to a function.

    The context and the additional arguments are independent of each other,
    so they are typed loosely with Any rather than with a shared ParamSpec.

    Args:
        *args: Positional context arguments, passed to the function first.
        **kwargs: Keyword context arguments.

    Returns:
        Callable[..., Any]: A function that calls a given function with the
            context arguments followed by any additional arguments.
    """

    def wrapped(fn: Callable[..., T], *fn_args: Any, **fn_kwargs: Any) -> T:
        """Calls the function with the context arguments and any additional arguments passed in.

        Args:
            fn (Callable[..., T]): The function to call.
            *fn_args: Additional positional arguments, passed after the context.
            **fn_kwargs: Additional keyword arguments.

        Returns:
            T: The result of calling the function.
        """
        return fn(*args, *fn_args, **kwargs, **fn_kwargs)

    return wrapped

Which gives us:

>>> account = Account.create()
... client = Client("https://api.blockchain.com")

... # Define our context once
... with_context = provide_context(client, account)
... # Use the context to call our functions
... print(with_context(register_account))
... print(with_context(is_account_registered))
... # Pass additional arguments to the function
... print(with_context(check_and_greet_user, user="Alice"))
Registered account 1a5e4c7b6d8f911f18b122b80b9ecba0558cc83da77d6c2a44e5b1cf59e097e6
True
Hello, Alice!
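One trade-off of the generic version is worth knowing about: positional context always comes before any additional positional arguments, and a keyword supplied both as context and at call time is rejected by Python with a TypeError. The sketch below repeats the closure without detailed type hints and uses a hypothetical greet() function purely for illustration.

```python
from typing import Any


def provide_context(*args: Any, **kwargs: Any):
    """Same shape as the generic closure above, without detailed type hints."""

    def wrapped(fn, *fn_args: Any, **fn_kwargs: Any):
        return fn(*args, *fn_args, **kwargs, **fn_kwargs)

    return wrapped


def greet(greeting: str, user: str) -> str:
    """Hypothetical function used only for this illustration."""
    return f"{greeting}, {user}!"


with_context = provide_context(greeting="Hello")

# The context supplies `greeting`; the call supplies `user`.
ok = with_context(greet, user="Alice")

# Supplying a keyword that the context already provides is an error.
try:
    with_context(greet, greeting="Hi", user="Bob")
    collided = False
except TypeError:
    collided = True

print(ok)
print(collided)
```

If call-time values should be able to override the context, one option is to merge the two dictionaries before the call, for example fn(*args, *fn_args, **{**kwargs, **fn_kwargs}), at the cost of making overrides silent.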

Summary

Closures provide a simple way to pass the same arguments to different functions.

We can use a generic provide_context() function to avoid repeating code, without introducing any new classes.