
middleware

Custom ASGI app middleware.

These middleware are based on Starlette's BaseHTTPMiddleware. See the specific Starlette documentation page for more information on its middleware implementation.

OPTIMADE_MIDDLEWARE: Tuple[starlette.middleware.base.BaseHTTPMiddleware]

A tuple of all the middleware classes that implement certain required features of the OPTIMADE specification, e.g. warnings and URL versioning.

Note

The order in which middleware is added to an application matters.

As discussed in the docstring of AddWarnings, this middleware is the final entry in this list so that it is the first to be applied by the server. Any other middleware should therefore be added before iterating through this variable. This is the reverse of the example in the Starlette documentation, which initialises the application with a pre-built middleware list in the opposite order to OPTIMADE_MIDDLEWARE.

To use this variable in FastAPI app code after initialisation:

from fastapi import FastAPI
from optimade.server.middleware import OPTIMADE_MIDDLEWARE

app = FastAPI()
for middleware in OPTIMADE_MIDDLEWARE:
    app.add_middleware(middleware)

Alternatively, to use this variable on initialisation:

from fastapi import FastAPI
from starlette.middleware import Middleware
from optimade.server.middleware import OPTIMADE_MIDDLEWARE

app = FastAPI(
    ...,
    middleware=[Middleware(m) for m in reversed(OPTIMADE_MIDDLEWARE)]
)
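
The ordering rule can be illustrated without any web framework. The following is a simplified model, not Starlette's actual code: `ToyApp` and `make_middleware` are hypothetical, and the two middleware names are used only as labels. It assumes only what the note above states, namely that each call to `add_middleware` wraps everything added so far, so the last middleware added is the first to see a request.

```python
from typing import Callable, List

Handler = Callable[[str], str]


def make_middleware(name: str, log: List[str]) -> Callable[[Handler], Handler]:
    """Return a toy middleware factory that records when it runs."""

    def wrap(inner: Handler) -> Handler:
        def handler(request: str) -> str:
            log.append(f"{name}: before")
            response = inner(request)
            log.append(f"{name}: after")
            return response

        return handler

    return wrap


class ToyApp:
    """Minimal stand-in for an app with an `add_middleware` method."""

    def __init__(self) -> None:
        self._handler: Handler = lambda request: f"response to {request}"

    def add_middleware(self, middleware: Callable[[Handler], Handler]) -> None:
        # Each new middleware wraps the current stack and becomes the outermost layer.
        self._handler = middleware(self._handler)

    def handle(self, request: str) -> str:
        return self._handler(request)


log: List[str] = []
app = ToyApp()
# Same pattern as iterating through a tuple whose final entry is AddWarnings.
for name in ("EnsureQueryParamIntegrity", "AddWarnings"):
    app.add_middleware(make_middleware(name, log))

app.handle("GET /v1/structures")
print(log[0])  # AddWarnings: before
```

Because `AddWarnings` was added last, it is the outermost layer: it runs first on the way in and last on the way out.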

AddWarnings (BaseHTTPMiddleware)

Add OptimadeWarnings to the response.

All sub-classes of OptimadeWarning will also be added to the response's meta.warnings list.

By overriding the warnings.showwarning() function with the showwarning method, all usages of warnings.warn() will result in the regular printing of the warning message to stderr, but also its addition to an in-memory list of warnings. This middleware will, after the URL request has been handled, add the list of accumulated warnings to the JSON response under the meta.warnings field.

To make sure this last step happens correctly, and because a BaseHTTPMiddleware sub-class is expected to return a Starlette StreamingResponse, a new StreamingResponse is instantiated with an updated Content-Length header. The response's body content is made streamable by breaking it down into chunks of the original response's chunk size.
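
As a rough sketch of the body-rewriting step only, the following standard-library example adds collected warnings under `meta.warnings` and recomputes the `Content-Length` value. The function name `attach_warnings` and the sample payload are assumptions made for illustration; this is not the middleware's own code and it does not involve Starlette.

```python
import json
from typing import Dict, List, Tuple


def attach_warnings(body: bytes, collected: List[dict]) -> Tuple[bytes, Dict[str, str]]:
    """Add `collected` to `meta.warnings` and recompute Content-Length."""
    payload = json.loads(body.decode("utf-8"))
    if collected:
        meta = payload.setdefault("meta", {})
        meta.setdefault("warnings", []).extend(collected)
    new_body = json.dumps(payload).encode("utf-8")
    # The body changed size, so the original Content-Length would now be wrong.
    headers = {"content-length": str(len(new_body))}
    return new_body, headers


original = json.dumps({"data": [], "meta": {"api_version": "1.0.0"}}).encode("utf-8")
collected = [{"title": "FieldValueNotRecognized", "detail": "'1.0' is not recognized."}]
new_body, headers = attach_warnings(original, collected)
print(headers["content-length"] == str(len(new_body)))  # True
```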

Important

It is recommended to add this middleware as the last one to your application.

This is to ensure it is invoked first, updating warnings.showwarning() and catching all warnings that should be added to the response.

This can be achieved by applying AddWarnings after all other middleware with the .add_middleware() method, or by initialising the app with a middleware list in which AddWarnings appears first. More information can be found in the docstring of OPTIMADE_MIDDLEWARE.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| _warnings | List[Warnings] | List of Warnings added through usages of warnings.warn() via showwarning. |

chunk_it_up(content, chunk_size) staticmethod

Return generator for string in chunks of size chunk_size.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| content | str | String-content to separate into chunks. | required |
| chunk_size | int | The size of the chunks, i.e. the length of the string-chunks. | required |

Returns:

| Type | Description |
| --- | --- |
| Generator | A Python generator to be converted later to an asyncio generator. |

Source code in optimade/server/middleware.py
@staticmethod
def chunk_it_up(content: str, chunk_size: int) -> Generator:
    """Return generator for string in chunks of size `chunk_size`.

    Parameters:
        content: String-content to separate into chunks.
        chunk_size: The size of the chunks, i.e. the length of the string-chunks.

    Returns:
        A Python generator to be converted later to an `asyncio` generator.

    """
    if chunk_size <= 0:
        chunk_size = 1
    return (content[i : chunk_size + i] for i in range(0, len(content), chunk_size))
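
A short usage sketch may help here. It uses a standalone copy of the function body shown above (so it runs without the package installed) and an arbitrary sample string; both the sample and the chunk size of 10 are chosen only for illustration.

```python
from typing import Generator


def chunk_it_up(content: str, chunk_size: int) -> Generator:
    """Standalone copy of the static method shown above."""
    if chunk_size <= 0:
        chunk_size = 1
    return (content[i : chunk_size + i] for i in range(0, len(content), chunk_size))


body = '{"data": [], "meta": {}}'
chunks = list(chunk_it_up(body, 10))
print([len(chunk) for chunk in chunks])  # [10, 10, 4]

# Joining the chunks reproduces the original content exactly.
print("".join(chunks) == body)  # True

# A non-positive chunk size falls back to single-character chunks.
print(list(chunk_it_up("abc", 0)))  # ['a', 'b', 'c']
```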

showwarning(self, message, category, filename, lineno, file=None, line=None)

Hook to write a warning to a file using the built-in warnings lib.

In the documentation for the built-in warnings library, there are a few recommended ways of customizing the printing of warning messages.

This method can override the warnings.showwarning function, which is called as part of the warnings library's workflow to print warning messages, e.g., when using warnings.warn(). Originally, it prints warning messages to stderr. This method will also print warning messages to stderr by calling warnings._showwarning_orig() or warnings._showwarnmsg_impl(). The first function will be called if the issued warning is not recognized as an OptimadeWarning. This is equivalent to "standard behaviour". The second function will be called after an OptimadeWarning has been handled.

An OptimadeWarning will be translated into an OPTIMADE Warnings JSON object in accordance with the specification. This process is similar to the Exception handlers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| message | Warning | The Warning object to show and possibly handle. | required |
| category | Type[Warning] | Warning type being warned about. This amounts to type(message). | required |
| filename | str | Name of the file, where the warning was issued. | required |
| lineno | int | Line number in the file, where the warning was issued. | required |
| file | Optional[IO] | A file-like object to which the warning should be written. | None |
| line | Optional[str] | Source content of the line that issued the warning. | None |

Source code in optimade/server/middleware.py
def showwarning(
    self,
    message: Warning,
    category: Type[Warning],
    filename: str,
    lineno: int,
    file: Optional[IO] = None,
    line: Optional[str] = None,
) -> None:
    """
    Hook to write a warning to a file using the built-in `warnings` lib.

    In [the documentation](https://docs.python.org/3/library/warnings.html)
    for the built-in `warnings` library, there are a few recommended ways of
    customizing the printing of warning messages.

    This method can override the `warnings.showwarning` function,
    which is called as part of the `warnings` library's workflow to print
    warning messages, e.g., when using `warnings.warn()`.
    Originally, it prints warning messages to `stderr`.
    This method will also print warning messages to `stderr` by calling
    `warnings._showwarning_orig()` or `warnings._showwarnmsg_impl()`.
    The first function will be called if the issued warning is not recognized
    as an [`OptimadeWarning`][optimade.server.warnings.OptimadeWarning].
    This is equivalent to "standard behaviour".
    The second function will be called _after_ an
    [`OptimadeWarning`][optimade.server.warnings.OptimadeWarning] has been handled.

    An [`OptimadeWarning`][optimade.server.warnings.OptimadeWarning] will be
    translated into an OPTIMADE Warnings JSON object in accordance with
    [the specification](https://github.com/Materials-Consortia/OPTIMADE/blob/v1.0.0/optimade.rst#json-response-schema-common-fields).
    This process is similar to the [Exception handlers][optimade.server.exception_handlers].

    Parameters:
        message: The `Warning` object to show and possibly handle.
        category: `Warning` type being warned about. This amounts to `type(message)`.
        filename: Name of the file, where the warning was issued.
        lineno: Line number in the file, where the warning was issued.
        file: A file-like object to which the warning should be written.
        line: Source content of the line that issued the warning.

    """
    assert isinstance(
        message, Warning
    ), "'message' is expected to be a Warning or subclass thereof."

    if not isinstance(message, OptimadeWarning):
        # If the Warning is not an OptimadeWarning or subclass thereof,
        # use the regular 'showwarning' function.
        warnings._showwarning_orig(message, category, filename, lineno, file, line)
        return

    # Format warning
    try:
        title = str(message.title)
    except AttributeError:
        title = str(message.__class__.__name__)

    try:
        detail = str(message.detail)
    except AttributeError:
        detail = str(message)

    if CONFIG.debug:
        if line is None:
            # All this is taken directly from the warnings library.
            # See 'warnings._formatwarnmsg_impl()' for the original code.
            try:
                import linecache

                line = linecache.getline(filename, lineno)
            except Exception:
                # When a warning is logged during Python shutdown, linecache
                # and the import machinery don't work anymore
                line = None
                linecache = None
        meta = {
            "filename": filename,
            "lineno": lineno,
        }
        if line:
            meta["line"] = line.strip()

    if CONFIG.debug:
        new_warning = Warnings(title=title, detail=detail, meta=meta)
    else:
        new_warning = Warnings(title=title, detail=detail)

    # Add new warning to self._warnings
    self._warnings.append(new_warning.dict(exclude_unset=True))

    # Show warning message as normal in sys.stderr
    warnings._showwarnmsg_impl(
        warnings.WarningMessage(message, category, filename, lineno, file, line)
    )
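
The general pattern of replacing `warnings.showwarning` with a collecting hook can be tried in isolation. The sketch below is a simplification under stated assumptions: `DemoWarning` is a hypothetical stand-in for OptimadeWarning, `WarningCollector` is a hypothetical class, and the hook defers to the previously installed `showwarning` instead of the private functions used in the method above.

```python
import warnings
from typing import List


class DemoWarning(UserWarning):
    """Hypothetical stand-in for OptimadeWarning."""


class WarningCollector:
    """Collect DemoWarning instances while still showing every warning."""

    def __init__(self) -> None:
        self.collected: List[dict] = []
        self._original = warnings.showwarning

    def showwarning(self, message, category, filename, lineno, file=None, line=None):
        if isinstance(message, DemoWarning):
            # Record the warning in a response-friendly form.
            self.collected.append(
                {"title": type(message).__name__, "detail": str(message)}
            )
        # Fall through to the previous hook so the message is still printed.
        self._original(message, category, filename, lineno, file, line)


collector = WarningCollector()
with warnings.catch_warnings():
    warnings.simplefilter("always")
    warnings.showwarning = collector.showwarning
    warnings.warn(DemoWarning("something looks off"))
    warnings.warn("a regular warning")  # shown on stderr, but not collected

print(collector.collected)
```

Using `warnings.catch_warnings()` here restores the original `showwarning` on exit, which keeps the demonstration from leaking into other code.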

CheckWronglyVersionedBaseUrls (BaseHTTPMiddleware)

If a non-supported versioned base URL is supplied, return 553 Version Not Supported.

check_url(url) staticmethod

Check URL path for versioned part.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| url | URL | A complete urllib-parsed raw URL. | required |

Exceptions:

| Type | Description |
| --- | --- |
| VersionNotSupported | If the URL represents an OPTIMADE versioned base URL and the version part is not supported by the implementation. |

Source code in optimade/server/middleware.py
@staticmethod
def check_url(url: StarletteURL):
    """Check URL path for versioned part.

    Parameters:
        url: A complete urllib-parsed raw URL.

    Raises:
        VersionNotSupported: If the URL represents an OPTIMADE versioned base URL
            and the version part is not supported by the implementation.

    """
    base_url = get_base_url(url)
    optimade_path = f"{url.scheme}://{url.netloc}{url.path}"[len(base_url) :]
    match = re.match(r"^(?P<version>/v[0-9]+(\.[0-9]+){0,2}).*", optimade_path)
    if match is not None:
        if match.group("version") not in BASE_URL_PREFIXES.values():
            raise VersionNotSupported(
                detail=(
                    f"The parsed versioned base URL {match.group('version')!r} from "
                    f"{url} is not supported by this implementation. "
                    f"Supported versioned base URLs are: {', '.join(BASE_URL_PREFIXES.values())}"
                )
            )
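
To see what the regular expression in `check_url` extracts, the following sketch applies the same pattern to a few paths. The contents of `SUPPORTED_PREFIXES` are assumed values for illustration (the real ones come from BASE_URL_PREFIXES), and `extract_version` and `is_supported` are hypothetical helper names.

```python
import re
from typing import Optional

# Assumed values for illustration; the real ones come from BASE_URL_PREFIXES.
SUPPORTED_PREFIXES = {"major": "/v1", "minor": "/v1.0", "patch": "/v1.0.0"}

VERSION_PATTERN = re.compile(r"^(?P<version>/v[0-9]+(\.[0-9]+){0,2}).*")


def extract_version(optimade_path: str) -> Optional[str]:
    """Return the leading version part of a path, or None if it has none."""
    match = VERSION_PATTERN.match(optimade_path)
    return None if match is None else match.group("version")


def is_supported(optimade_path: str) -> bool:
    """Unversioned paths pass; versioned paths must use a supported prefix."""
    version = extract_version(optimade_path)
    return version is None or version in SUPPORTED_PREFIXES.values()


for path in ("/structures", "/v1/structures", "/v1.0.0/info", "/v0.9/structures"):
    print(path, extract_version(path), is_supported(path))
```

Under these assumed prefixes, only `/v0.9/structures` is rejected; the unversioned `/structures` path has no version part and is therefore not checked at all.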

EnsureQueryParamIntegrity (BaseHTTPMiddleware)

Ensure all query parameters are followed by an equal sign (=).

check_url(url_query) staticmethod

Check parsed URL query part for parameters not followed by =.

URL query parameters are considered to be split by ampersand (&) and semi-colon (;).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| url_query | str | The raw urllib-parsed query part. | required |

Exceptions:

| Type | Description |
| --- | --- |
| BadRequest | If a query parameter does not come with a value. |

Returns:

| Type | Description |
| --- | --- |
| set | The set of individual query parameters and their values. This is mainly for testing and not actually needed by the middleware, since a 400 Bad Request response will be returned if the URL exhibits an invalid query part. |

Source code in optimade/server/middleware.py
@staticmethod
def check_url(url_query: str) -> set:
    """Check parsed URL query part for parameters not followed by `=`.

    URL query parameters are considered to be split by ampersand (`&`)
    and semi-colon (`;`).

    Parameters:
        url_query: The raw urllib-parsed query part.

    Raises:
        BadRequest: If a query parameter does not come with a value.

    Returns:
        The set of individual query parameters and their values.

        This is mainly for testing and not actually needed by the middleware,
        since if the URL exhibits an invalid query part a `400 Bad Request`
        response will be returned.

    """
    queries_amp = set(url_query.split("&"))
    queries = set()
    for query in queries_amp:
        queries.update(set(query.split(";")))
    for query in queries:
        if "=" not in query and query != "":
            raise BadRequest(
                detail="A query parameter without an equal sign (=) is not supported by this server"
            )
    return queries  # Useful for testing
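
The splitting behaviour can be tried with a standalone sketch. `check_query` is a hypothetical re-statement of the logic above, and `ValueError` is used as a stand-in for the BadRequest exception so that the example needs only the standard library.

```python
from typing import Set


def check_query(url_query: str) -> Set[str]:
    """Split on `&` and `;`, rejecting any non-empty part without `=`."""
    parts: Set[str] = set()
    for amp_part in url_query.split("&"):
        parts.update(amp_part.split(";"))
    for part in parts:
        if part and "=" not in part:
            # ValueError is a stand-in for the BadRequest exception.
            raise ValueError(f"Query parameter {part!r} has no equal sign (=)")
    return parts


print(sorted(check_query("page_limit=5&response_format=json;page_offset=10")))
# ['page_limit=5', 'page_offset=10', 'response_format=json']

try:
    check_query("page_limit=5&response_fields")
except ValueError as exc:
    print(exc)
```

Note that empty parts, such as those produced by an empty query string or a doubled `&&`, are accepted.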

HandleApiHint (BaseHTTPMiddleware)

Handle api_hint query parameter.

handle_api_hint(api_hint) staticmethod

Handle api_hint parameter value.

There are several scenarios that can play out when handling the api_hint query parameter:

If several api_hint query parameters have been used, or a "standard" JSON list (,-separated value) has been supplied, a warning will be added to the response and the api_hint query parameter will not be applied.

If the passed value does not comply with the rules set out in the specification, a warning will be added to the response and the api_hint query parameter will not be applied.

If the value is part of the implementation's accepted versioned base URLs, it will be returned as is.

If the value represents a major version that is newer than what is supported by the implementation, a 553 Version Not Supported response will be returned, as is stated by the specification.

On the other hand, if the value represents a major version equal to or lower than the implementation's supported major version, then the implementation's supported major version will be returned and tried for the request.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| api_hint | List[str] | The urllib-parsed query parameter value for api_hint. | required |

Exceptions:

| Type | Description |
| --- | --- |
| VersionNotSupported | If the requested major version is newer than the supported major version of the implementation. |

Returns:

| Type | Description |
| --- | --- |
| Union[NoneType, str] | Either a valid api_hint value or None. |

Source code in optimade/server/middleware.py
@staticmethod
def handle_api_hint(api_hint: List[str]) -> Union[None, str]:
    """Handle `api_hint` parameter value.

    There are several scenarios that can play out, when handling the `api_hint`
    query parameter:

    If several `api_hint` query parameters have been used, or a "standard" JSON
    list (`,`-separated value) has been supplied, a warning will be added to the
    response and the `api_hint` query parameter will not be applied.

    If the passed value does not comply with the rules set out in
    [the specification](https://github.com/Materials-Consortia/OPTIMADE/blob/v1.0.0/optimade.rst#version-negotiation),
    a warning will be added to the response and the `api_hint` query parameter
    will not be applied.

    If the value is part of the implementation's accepted versioned base URLs,
    it will be returned as is.

    If the value represents a major version that is newer than what is supported
    by the implementation, a `553 Version Not Supported` response will be returned,
    as is stated by [the specification](https://github.com/Materials-Consortia/OPTIMADE/blob/v1.0.0/optimade.rst#version-negotiation).

    On the other hand, if the value represents a major version equal to or lower
    than the implementation's supported major version, then the implementation's
    supported major version will be returned and tried for the request.

    Parameters:
        api_hint: The urllib-parsed query parameter value for `api_hint`.

    Raises:
        VersionNotSupported: If the requested major version is newer than the
            supported major version of the implementation.

    Returns:
        Either a valid `api_hint` value or `None`.

    """
    # Try to split by `,` if value is provided once, but in JSON-type "list" format
    _api_hint = []
    for value in api_hint:
        values = value.split(",")
        _api_hint.extend(values)
    api_hint = _api_hint

    if len(api_hint) > 1:
        warnings.warn(
            TooManyValues(
                detail="`api_hint` should only be supplied once, with a single value."
            )
        )
        return None

    api_hint = f"/{api_hint[0]}"
    if re.match(r"^/v[0-9]+(\.[0-9]+)?$", api_hint) is None:
        warnings.warn(
            FieldValueNotRecognized(
                detail=f"{api_hint[1:]!r} is not recognized as a valid `api_hint` value."
            )
        )
        return None

    if api_hint in BASE_URL_PREFIXES.values():
        return api_hint

    major_api_hint = int(re.findall(r"/v([0-9]+)", api_hint)[0])
    major_implementation = int(BASE_URL_PREFIXES["major"][len("/v") :])

    if major_api_hint > major_implementation:
        # Let's not try to handle a request for a newer major version
        raise VersionNotSupported(
            detail=(
                f"The provided `api_hint` ({api_hint[1:]!r}) is not supported by this implementation. "
                f"Supported versions include: {', '.join(BASE_URL_PREFIXES.values())}"
            )
        )
    if major_api_hint <= major_implementation:
        # If less than:
        # Use the current implementation in hope that it can still handle older requests
        #
        # If equal:
        # Go to /v<MAJOR>, since this should point to the latest available
        return BASE_URL_PREFIXES["major"]
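
The scenarios listed in the docstring can be walked through with a condensed, standalone sketch. Everything named here is an assumption for illustration: `SUPPORTED_PREFIXES` holds example values (the real ones come from BASE_URL_PREFIXES), `VersionNotSupportedError` stands in for VersionNotSupported, and plain `warnings.warn` messages replace the TooManyValues and FieldValueNotRecognized warnings.

```python
import re
import warnings
from typing import List, Optional

# Assumed values for illustration; the real ones come from BASE_URL_PREFIXES.
SUPPORTED_PREFIXES = {"major": "/v1", "minor": "/v1.0", "patch": "/v1.0.0"}


class VersionNotSupportedError(Exception):
    """Hypothetical stand-in for the VersionNotSupported exception."""


def resolve_api_hint(api_hint: List[str]) -> Optional[str]:
    """Return the versioned prefix to use, or None if the hint is ignored."""
    values = [part for value in api_hint for part in value.split(",")]
    if len(values) > 1:
        warnings.warn("`api_hint` should only be supplied once, with a single value.")
        return None

    hint = f"/{values[0]}"
    if re.match(r"^/v[0-9]+(\.[0-9]+)?$", hint) is None:
        warnings.warn(f"{values[0]!r} is not recognized as a valid `api_hint` value.")
        return None

    if hint in SUPPORTED_PREFIXES.values():
        return hint

    major_hint = int(re.findall(r"/v([0-9]+)", hint)[0])
    major_supported = int(SUPPORTED_PREFIXES["major"][len("/v"):])
    if major_hint > major_supported:
        raise VersionNotSupportedError(f"{values[0]!r} is not supported.")
    # Same or older major version: fall back to the supported major prefix.
    return SUPPORTED_PREFIXES["major"]


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # keep the demonstration output quiet
    for hint in (["v1.0"], ["v1.2"], ["v0"], ["v1,v2"], ["1.0"]):
        print(hint, "->", resolve_api_hint(hint))
```

With the assumed prefixes, `v1.0` is returned as `/v1.0`, `v1.2` and `v0` both fall back to `/v1`, and the last two hints are ignored with a warning. A hint of `v2` would raise the stand-in exception.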

is_versioned_base_url(url) staticmethod

Determine whether a request is for a versioned base URL.

First, simply check whether a /vMAJOR(.MINOR.PATCH) part exists in the URL. If not, return False. Otherwise, remove the unversioned base URL from the URL and check again. Return the bool of the final result.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| url | str | The full URL to check. | required |

Returns:

| Type | Description |
| --- | --- |
| bool | Whether or not the full URL represents an OPTIMADE versioned base URL. |

Source code in optimade/server/middleware.py
@staticmethod
def is_versioned_base_url(url: str) -> bool:
    """Determine whether a request is for a versioned base URL.

    First, simply check whether a `/vMAJOR(.MINOR.PATCH)` part exists in the URL.
    If not, return `False`, else, remove unversioned base URL from the URL and check again.
    Return `bool` of final result.

    Parameters:
        url: The full URL to check.

    Returns:
        Whether or not the full URL represents an OPTIMADE versioned base URL.

    """
    if not re.findall(r"(/v[0-9]+(\.[0-9]+){0,2})", url):
        return False

    base_url = get_base_url(url)
    return bool(re.findall(r"(/v[0-9]+(\.[0-9]+){0,2})", url[len(base_url) :]))
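
The reason for checking twice becomes clearer with a URL whose host name happens to look like a version. In the sketch below, `toy_base_url` is a hypothetical stand-in for `get_base_url` that assumes the base URL is just the scheme and host; the real function may behave differently, for example when a base URL is configured.

```python
import re
from urllib.parse import urlparse

VERSION_PART = r"(/v[0-9]+(\.[0-9]+){0,2})"


def toy_base_url(url: str) -> str:
    """Hypothetical stand-in for get_base_url: scheme and host only."""
    parsed = urlparse(url)
    return f"{parsed.scheme}://{parsed.netloc}"


def is_versioned(url: str) -> bool:
    """Two-step check mirroring the method above."""
    if not re.findall(VERSION_PART, url):
        return False  # early exit: no version-like part anywhere in the URL
    remainder = url[len(toy_base_url(url)):]
    return bool(re.findall(VERSION_PART, remainder))


print(is_versioned("https://example.org/v1/structures"))  # True
print(is_versioned("https://example.org/structures"))  # False
print(is_versioned("https://v1.example.org/structures"))  # False
```

In the last case, the first check finds `/v1` inside `https://v1.example.org`, and only the second check, made after stripping the base URL, correctly reports that the path itself is unversioned.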