Julian Mehne

Use this jupyter notebook setup function.

python, jupyter

Another building block for working efficiently with notebooks: use a setup function. I assume you use a git repository and that you polish code in the notebook before moving it into regular Python modules. Following common Python project structure conventions, the layout may look something like this:

my_project/some/module/important.py
notebooks/some_analysis/my_notebook.ipynb
notebooks/some_analysis/src/some_nb_specific_module.py

When a notebook starts, the current working directory is the directory the notebook is stored in, so we can import anything from its subdirectories. But I also want to be able to import anything from the project root directory and, more importantly, I want to make sure that any paths I reference are always relative to the project root. Otherwise it gets annoying when I want to reference sql/some_query.sql from any notebook (sure, we could use some kind of loader class, but let's keep it simple).
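To make the path problem concrete, here is a minimal sketch that builds a throwaway copy of the layout above in a temporary directory and checks whether sql/some_query.sql resolves from the notebook directory versus from the project root. The helper name `resolves_from` is made up for this illustration and is not part of the setup function.

```python
import os
import pathlib
import tempfile


def resolves_from(cwd: pathlib.Path, relative: str) -> bool:
    """Return True if `relative` exists when resolved from `cwd` (hypothetical helper)."""
    previous = pathlib.Path.cwd()
    os.chdir(cwd)
    try:
        return pathlib.Path(relative).exists()
    finally:
        os.chdir(previous)  # always restore the original working directory


with tempfile.TemporaryDirectory() as tmp:
    # Fake project layout, mirroring the structure described above.
    root = pathlib.Path(tmp)
    nb_dir = root / "notebooks" / "some_analysis"
    nb_dir.mkdir(parents=True)
    (root / "sql").mkdir()
    (root / "sql" / "some_query.sql").write_text("SELECT 1;")

    from_notebook_dir = resolves_from(nb_dir, "sql/some_query.sql")
    from_project_root = resolves_from(root, "sql/some_query.sql")

print(from_notebook_dir, from_project_root)
```

The relative path only resolves when the working directory is the project root, which is why the setup function changes into it.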

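The job of the setup function is thus to:

- configure logging and hand back a logger,
- enable autoreload, so that edited modules are picked up without restarting the kernel,
- add the notebook's directory to sys.path, so that notebook-specific modules stay importable,
- change the current working directory to the project (git) root,
- be safe to run more than once in the same kernel.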

Long story short, I just put this in my first notebook cell:

import os
import logging
import pathlib
import subprocess as sp
import sys
from dataclasses import dataclass


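def get_git_root() -> pathlib.Path:
    """Return the top-level directory of the enclosing git repository."""
    r = sp.run(["git", "rev-parse", "--show-toplevel"], check=True, capture_output=True)
    # Return a Path so the value matches the type declared in NBConfig.
    return pathlib.Path(r.stdout.decode("utf-8").strip())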


@dataclass(kw_only=True)
class NBConfig:
    """Let's make it easy to extend the return object and use a data class."""
    project_root_path: pathlib.Path
    notebook_file_dir_path: pathlib.Path
    logger: logging.Logger


def setup_notebook(log_level=logging.INFO) -> NBConfig:
    nb_setup_env_name = "NB_SETUP_DONE"
    nb_file_dir_path_env_name = "NB_FILE_DIR_PATH"

    logger = logging.getLogger(__name__)
    project_root_path = get_git_root()

    if os.getenv(nb_setup_env_name):
        logger.info("Notebook setup was already run.")
        notebook_file_dir_path = os.getenv(nb_file_dir_path_env_name)
        if notebook_file_dir_path is None:
            raise ValueError(
                f"{nb_setup_env_name} is set, but {nb_file_dir_path_env_name} is missing."
            )
        return NBConfig(
            project_root_path=project_root_path,
            # The environment stores a string; convert it back to a Path.
            notebook_file_dir_path=pathlib.Path(notebook_file_dir_path),
            logger=logger,
        )

    logging.basicConfig(level=log_level)
    logger.info("Running notebook setup.")

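    # Auto-reload modules, if they get changed (only works for direct imports, not
    # indirect imports!).
    # Note: the two lines below are IPython magics, so this function has to be
    # defined in a notebook cell; a plain .py module cannot contain them.
    %load_ext autoreload
    %autoreload 2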

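    # Remember where the notebook lives before we change the working directory.
    notebook_file_dir_path = pathlib.Path("").absolute()
    # Environment variables and sys.path entries must be strings.
    os.environ[nb_file_dir_path_env_name] = str(notebook_file_dir_path)
    sys.path.insert(0, str(notebook_file_dir_path))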
    os.chdir(project_root_path)

    logger.info(f"Using system path: {sys.path}")
    logger.info(f"Switched current working directory to {os.getcwd()}.")

    os.environ[nb_setup_env_name] = "1"
    return NBConfig(
        project_root_path=project_root_path,
        notebook_file_dir_path=notebook_file_dir_path,
        logger=logger,
    )

r = setup_notebook()
logger = r.logger
project_root_path = r.project_root_path
notebook_file_dir_path = r.notebook_file_dir_path
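Because of the two % magics, the function above only works when it is defined in a notebook cell. If you ever want to move it into a regular Python module, one possible approach is to call the magics through IPython's `get_ipython().run_line_magic(...)`. The following is a sketch under that assumption, not a drop-in replacement; the name `enable_autoreload` is made up for this example.

```python
def enable_autoreload(mode: str = "2") -> bool:
    """Enable IPython's autoreload extension if an IPython shell is active.

    Returns True if the magics were run, False otherwise.
    """
    try:
        from IPython import get_ipython
    except ImportError:
        return False  # IPython is not installed

    shell = get_ipython()
    if shell is None:
        return False  # plain Python interpreter, no magics available

    # Equivalent to `%load_ext autoreload` followed by `%autoreload 2`.
    shell.run_line_magic("load_ext", "autoreload")
    shell.run_line_magic("autoreload", mode)
    return True


autoreload_enabled = enable_autoreload()
print(f"autoreload enabled: {autoreload_enabled}")
```

Outside of IPython the function simply returns False, so importing the module from a script or a test does not fail.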

Happy notebooking!