# Source code for libcst.helpers._template

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

from typing import Dict, Mapping, Optional, Set, Union

import libcst as cst
from libcst.helpers.common import ensure_type

# Markers placed immediately before and after a template variable's name to
# build its mangled identifier (see ``mangled_name``). The suffix is the
# prefix spelled backwards. ``mangle_template`` refuses any template that
# already contains either marker.
TEMPLATE_PREFIX: str = "__LIBCST_MANGLED_NAME_"
TEMPLATE_SUFFIX: str = "_EMAN_DELGNAM_TSCBIL__"


# Every node type that may be passed as a template replacement value.
# ``TemplateTransformer.__init__`` sorts replacements into one lookup table
# per member of this union and raises for a value matching none of them.
ValidReplacementType = Union[
    cst.BaseExpression,
    cst.Annotation,
    cst.AssignTarget,
    cst.Param,
    cst.Parameters,
    cst.Arg,
    cst.BaseStatement,
    cst.BaseSmallStatement,
    cst.BaseSuite,
    cst.BaseSlice,
    cst.SubscriptElement,
    cst.Decorator,
]


def mangled_name(var: str) -> str:
    """Return ``var`` wrapped in ``TEMPLATE_PREFIX`` and ``TEMPLATE_SUFFIX``."""
    return "{}{}{}".format(TEMPLATE_PREFIX, var, TEMPLATE_SUFFIX)


def unmangled_name(var: str) -> Optional[str]:
    """
    Recover the template variable name from a mangled identifier.

    ``var`` is accepted only when it has the exact shape
    ``TEMPLATE_PREFIX + name + TEMPLATE_SUFFIX``: nothing may precede the
    first prefix marker, and nothing may follow the first suffix marker found
    after it.

    Returns:
        The text between the two markers, or ``None`` when ``var`` is not a
        valid mangled name.
    """
    leading, prefix_found, remainder = var.partition(TEMPLATE_PREFIX)
    if not prefix_found or leading:
        # No prefix marker at all, or extra text in front of it.
        return None

    name, suffix_found, trailing = remainder.partition(TEMPLATE_SUFFIX)
    if not suffix_found or trailing:
        # Either text follows the suffix marker, or the suffix marker does not
        # occur after the prefix (e.g. it only appears before it). The latter
        # used to blow up with ValueError when unpacking the split result;
        # such a string is simply not a mangled name.
        return None

    return name


def mangle_template(template: str, template_vars: Set[str]) -> str:
    """
    Replace every ``{var}`` placeholder in ``template`` with its mangled name.

    Args:
        template: The template source text.
        template_vars: Names of the variables expected to appear in
            ``template`` as ``{name}``.

    Returns:
        The template with each placeholder substituted by
        ``mangled_name(name)``.

    Raises:
        Exception: if ``template`` already contains ``TEMPLATE_PREFIX`` or
            ``TEMPLATE_SUFFIX``, or if one of ``template_vars`` has no
            ``{name}`` placeholder in the template.
    """
    reserved_markers = (TEMPLATE_PREFIX, TEMPLATE_SUFFIX)
    if any(marker in template for marker in reserved_markers):
        raise Exception("Cannot parse a template containing reserved strings")

    mangled = template
    for name in template_vars:
        placeholder = f"{{{name}}}"
        # Splitting on the placeholder yields a single piece exactly when the
        # placeholder does not occur in the text.
        pieces = mangled.split(placeholder)
        if len(pieces) == 1:
            raise Exception(
                f'Template string is missing a reference to "{name}" referred to in kwargs'
            )
        mangled = mangled_name(name).join(pieces)
    return mangled


class TemplateTransformer(cst.CSTTransformer):
    """
    Transformer that swaps mangled placeholder names for caller-supplied nodes.

    Replacements are sorted into one lookup table per supported node type when
    the transformer is constructed. Each ``leave_*`` hook recognises the shape
    a bare placeholder name takes in that position and, when the unmangled
    name is present in the matching table, returns a ``deep_clone()`` of the
    replacement instead of the visited node.
    """

    def __init__(
        self, template_replacements: Mapping[str, ValidReplacementType]
    ) -> None:
        """
        Sort ``template_replacements`` into per-type lookup tables.

        Raises:
            Exception: if any replacement value is an instance of none of the
                supported node types.
        """
        self.simple_replacements: Dict[str, cst.BaseExpression] = {}
        self.annotation_replacements: Dict[str, cst.Annotation] = {}
        self.assignment_replacements: Dict[str, cst.AssignTarget] = {}
        self.param_replacements: Dict[str, cst.Param] = {}
        self.parameters_replacements: Dict[str, cst.Parameters] = {}
        self.arg_replacements: Dict[str, cst.Arg] = {}
        self.small_statement_replacements: Dict[str, cst.BaseSmallStatement] = {}
        self.statement_replacements: Dict[str, cst.BaseStatement] = {}
        self.suite_replacements: Dict[str, cst.BaseSuite] = {}
        self.subscript_element_replacements: Dict[str, cst.SubscriptElement] = {}
        self.subscript_index_replacements: Dict[str, cst.BaseSlice] = {}
        self.decorator_replacements: Dict[str, cst.Decorator] = {}

        for key, node in template_replacements.items():
            # Each check is independent (no elif): a value is recorded in
            # every table whose type it is an instance of.
            if isinstance(node, cst.BaseExpression):
                self.simple_replacements[key] = node
            if isinstance(node, cst.Annotation):
                self.annotation_replacements[key] = node
            if isinstance(node, cst.AssignTarget):
                self.assignment_replacements[key] = node
            if isinstance(node, cst.Param):
                self.param_replacements[key] = node
            if isinstance(node, cst.Parameters):
                self.parameters_replacements[key] = node
            if isinstance(node, cst.Arg):
                self.arg_replacements[key] = node
            if isinstance(node, cst.BaseSmallStatement):
                self.small_statement_replacements[key] = node
            if isinstance(node, cst.BaseStatement):
                self.statement_replacements[key] = node
            if isinstance(node, cst.BaseSuite):
                self.suite_replacements[key] = node
            if isinstance(node, cst.SubscriptElement):
                self.subscript_element_replacements[key] = node
            if isinstance(node, cst.BaseSlice):
                self.subscript_index_replacements[key] = node
            if isinstance(node, cst.Decorator):
                self.decorator_replacements[key] = node

        # Any variable that landed in no table is one we cannot insert into
        # a template.
        supported_vars: Set[str] = set().union(
            self.simple_replacements,
            self.annotation_replacements,
            self.assignment_replacements,
            self.param_replacements,
            self.parameters_replacements,
            self.arg_replacements,
            self.small_statement_replacements,
            self.statement_replacements,
            self.suite_replacements,
            self.subscript_element_replacements,
            self.subscript_index_replacements,
            self.decorator_replacements,
        )
        unsupported_vars = {
            key for key in template_replacements if key not in supported_vars
        }
        if unsupported_vars:
            raise Exception(
                f'Template replacement for "{next(iter(unsupported_vars))}" is unsupported'
            )

    @staticmethod
    def _lone_name(
        node: Union[cst.SimpleStatementLine, cst.SimpleStatementSuite]
    ) -> Optional[cst.Name]:
        """
        Return the ``Name`` when ``node``'s body is exactly one ``Expr``
        wrapping a bare ``Name``; otherwise return ``None``.
        """
        if len(node.body) != 1:
            return None
        only_statement = node.body[0]
        if not isinstance(only_statement, cst.Expr):
            return None
        wrapped = only_statement.value
        return wrapped if isinstance(wrapped, cst.Name) else None

    def leave_Name(
        self, original_node: cst.Name, updated_node: cst.Name
    ) -> cst.BaseExpression:
        """Substitute an expression replacement for a mangled name."""
        key = unmangled_name(updated_node.value)
        if key is not None and key in self.simple_replacements:
            return self.simple_replacements[key].deep_clone()
        # Not a mangled name (or not one of ours): leave it untouched.
        return updated_node

    def leave_Annotation(
        self,
        original_node: cst.Annotation,
        updated_node: cst.Annotation,
    ) -> cst.Annotation:
        """Substitute a whole ``Annotation`` when it is a bare placeholder."""
        # Matchers can't be used here due to circular imports.
        inner = updated_node.annotation
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.annotation_replacements:
            return updated_node
        return self.annotation_replacements[key].deep_clone()

    def leave_AssignTarget(
        self,
        original_node: cst.AssignTarget,
        updated_node: cst.AssignTarget,
    ) -> cst.AssignTarget:
        """Substitute a whole ``AssignTarget`` when it is a bare placeholder."""
        # Matchers can't be used here due to circular imports.
        inner = updated_node.target
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.assignment_replacements:
            return updated_node
        return self.assignment_replacements[key].deep_clone()

    def leave_Param(
        self,
        original_node: cst.Param,
        updated_node: cst.Param,
    ) -> cst.Param:
        """Substitute a whole ``Param`` when its name is a placeholder."""
        key = unmangled_name(updated_node.name.value)
        if key not in self.param_replacements:
            return updated_node
        return self.param_replacements[key].deep_clone()

    def leave_Parameters(
        self,
        original_node: cst.Parameters,
        updated_node: cst.Parameters,
    ) -> cst.Parameters:
        """
        Substitute an entire parameter list.

        This only applies in the very special case where the parameter list
        consists of a single plain parameter and nothing else, and that
        parameter's name is a placeholder for a ``Parameters`` replacement.
        """
        is_single_plain_param = (
            len(updated_node.params) == 1
            and updated_node.star_arg == cst.MaybeSentinel.DEFAULT
            and len(updated_node.kwonly_params) == 0
            and updated_node.star_kwarg is None
            and len(updated_node.posonly_params) == 0
            and updated_node.posonly_ind == cst.MaybeSentinel.DEFAULT
        )
        if not is_single_plain_param:
            return updated_node

        # The one and only parameter is possibly a replacement.
        key = unmangled_name(updated_node.params[0].name.value)
        if key not in self.parameters_replacements:
            return updated_node
        return self.parameters_replacements[key].deep_clone()

    def leave_Arg(self, original_node: cst.Arg, updated_node: cst.Arg) -> cst.Arg:
        """Substitute a whole ``Arg`` when its value is a bare placeholder."""
        # Matchers can't be used here due to circular imports.
        inner = updated_node.value
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.arg_replacements:
            return updated_node
        return self.arg_replacements[key].deep_clone()

    def leave_SimpleStatementLine(
        self,
        original_node: cst.SimpleStatementLine,
        updated_node: cst.SimpleStatementLine,
    ) -> cst.BaseStatement:
        """
        Expand a line holding only a placeholder into a full statement.

        A name alone on a line parses as an ``Expr`` inside a
        ``SimpleStatementLine``; that shape is what gets checked against the
        statement replacements. Matchers can't be used here due to circular
        imports.
        """
        placeholder = self._lone_name(updated_node)
        if placeholder is None:
            return updated_node
        key = unmangled_name(placeholder.value)
        if key not in self.statement_replacements:
            return updated_node
        return self.statement_replacements[key].deep_clone()

    def leave_Expr(
        self,
        original_node: cst.Expr,
        updated_node: cst.Expr,
    ) -> cst.BaseSmallStatement:
        """
        Expand an ``Expr`` holding only a placeholder into a small statement.

        Same trick as for ``SimpleStatementLine``, but one level down, to
        support templates substituting a ``BaseSmallStatement``. Matchers
        can't be used here due to circular imports.
        """
        inner = updated_node.value
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.small_statement_replacements:
            return updated_node
        return self.small_statement_replacements[key].deep_clone()

    def leave_SimpleStatementSuite(
        self,
        original_node: cst.SimpleStatementSuite,
        updated_node: cst.SimpleStatementSuite,
    ) -> cst.BaseSuite:
        """
        Expand a simple suite holding only a placeholder into a suite.

        A name in a simple suite parses as an ``Expr`` inside a
        ``SimpleStatementSuite``; that shape is what gets checked against the
        suite replacements. Matchers can't be used here due to circular
        imports.
        """
        placeholder = self._lone_name(updated_node)
        if placeholder is None:
            return updated_node
        key = unmangled_name(placeholder.value)
        if key not in self.suite_replacements:
            return updated_node
        return self.suite_replacements[key].deep_clone()

    def leave_IndentedBlock(
        self,
        original_node: cst.IndentedBlock,
        updated_node: cst.IndentedBlock,
    ) -> cst.BaseSuite:
        """
        Expand an indented block holding only a placeholder into a suite.

        A name alone in an indented block parses as an ``Expr`` inside a
        ``SimpleStatementLine`` inside the block; that shape is what gets
        checked against the suite replacements. Matchers can't be used here
        due to circular imports.
        """
        if len(updated_node.body) != 1:
            return updated_node
        only_line = updated_node.body[0]
        if not isinstance(only_line, cst.SimpleStatementLine):
            return updated_node
        placeholder = self._lone_name(only_line)
        if placeholder is None:
            return updated_node
        key = unmangled_name(placeholder.value)
        if key not in self.suite_replacements:
            return updated_node
        return self.suite_replacements[key].deep_clone()

    def leave_Index(
        self,
        original_node: cst.Index,
        updated_node: cst.Index,
    ) -> cst.BaseSlice:
        """Substitute a slice replacement for an ``Index`` placeholder."""
        # Matchers can't be used here due to circular imports.
        inner = updated_node.value
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.subscript_index_replacements:
            return updated_node
        return self.subscript_index_replacements[key].deep_clone()

    def leave_SubscriptElement(
        self,
        original_node: cst.SubscriptElement,
        updated_node: cst.SubscriptElement,
    ) -> cst.SubscriptElement:
        """
        Substitute a whole ``SubscriptElement`` for a placeholder.

        The shape checked for is an ``Index`` wrapping a bare ``Name`` inside
        the ``SubscriptElement``. Matchers can't be used here due to circular
        imports.
        """
        slice_node = updated_node.slice
        if not isinstance(slice_node, cst.Index):
            return updated_node
        inner = slice_node.value
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.subscript_element_replacements:
            return updated_node
        return self.subscript_element_replacements[key].deep_clone()

    def leave_Decorator(
        self, original_node: cst.Decorator, updated_node: cst.Decorator
    ) -> cst.Decorator:
        """Substitute a whole ``Decorator`` when it is a bare placeholder."""
        # Matchers can't be used here due to circular imports.
        inner = updated_node.decorator
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.decorator_replacements:
            return updated_node
        return self.decorator_replacements[key].deep_clone()


class TemplateChecker(cst.CSTVisitor):
    """
    Visitor that fails when a template variable's mangled name is still
    present as a ``Name`` in the visited tree.
    """

    def __init__(self, template_vars: Set[str]) -> None:
        # Names of the template variables whose mangled form must be gone.
        self.template_vars = template_vars

    def visit_Name(self, node: cst.Name) -> None:
        """Raise if ``node`` is the mangled form of any template variable."""
        leftovers = [
            name for name in self.template_vars if node.value == mangled_name(name)
        ]
        if leftovers:
            raise Exception(
                f'Template variable "{leftovers[0]}" was not replaced properly'
            )


def unmangle_nodes(
    tree: cst.CSTNode,
    template_replacements: Mapping[str, ValidReplacementType],
) -> cst.CSTNode:
    """
    Run a ``TemplateTransformer`` built from ``template_replacements`` over
    ``tree`` and return the result, checked to be a ``CSTNode``.
    """
    transformer = TemplateTransformer(template_replacements)
    visited = tree.visit(transformer)
    return ensure_type(visited, cst.CSTNode)


# Default value for the ``config`` parameter of the ``parse_template_*``
# functions below; a ``PartialParserConfig`` constructed with no arguments.
_DEFAULT_PARTIAL_PARSER_CONFIG: cst.PartialParserConfig = cst.PartialParserConfig()


def parse_template_module(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> cst.Module:
    """
    Accepts an entire python module template, including all leading and
    trailing whitespace. Any :class:`~libcst.CSTNode` provided as a keyword
    argument to this function will be inserted into the template at the
    appropriate location similar to an f-string expansion. For example::

      module = parse_template_module("from {mod} import Foo\\n", mod=Name("bar"))

    The above code will parse to a module containing a single
    :class:`~libcst.FromImport` statement, referencing module ``bar`` and
    importing object ``Foo`` from it. Remember that if you are parsing a
    template as part of a substitution inside a transform, it's considered
    :ref:`best practice <libcst-config_best_practice>` to pass in a ``config``
    from the current module under transformation.

    Note that unlike :func:`~libcst.parse_module`, this function does not
    support bytes as an input. This is due to the fact that it is processed
    as a template before parsing as a module.
    """
    # The keyword names are the template variables; the same set drives both
    # the mangling step and the final leftover check.
    template_vars = set(template_replacements)
    source = mangle_template(template, template_vars)
    module = cst.parse_module(source, config)
    new_module = ensure_type(unmangle_nodes(module, template_replacements), cst.Module)
    # Fail loudly if any placeholder survived substitution.
    new_module.visit(TemplateChecker(template_vars))
    return new_module


def parse_template_statement(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]:
    """
    Accepts a statement template followed by a trailing newline. If a
    trailing newline is not provided, one will be added. Any
    :class:`~libcst.CSTNode` provided as a keyword argument to this function
    will be inserted into the template at the appropriate location similar to
    an f-string expansion. For example::

      statement = parse_template_statement("assert x > 0, {msg}", msg=SimpleString('"Uh oh!"'))

    The above code will parse to an assert statement checking that some
    variable ``x`` is greater than zero, or providing the assert message
    ``"Uh oh!"``.

    Remember that if you are parsing a template as part of a substitution
    inside a transform, it's considered
    :ref:`best practice <libcst-config_best_practice>` to pass in a ``config``
    from the current module under transformation.
    """
    # The keyword names are the template variables; the same set drives both
    # the mangling step and the final leftover check.
    template_vars = set(template_replacements)
    source = mangle_template(template, template_vars)
    statement = cst.parse_statement(source, config)
    new_statement = unmangle_nodes(statement, template_replacements)
    # Substitution may have replaced the top-level node itself, so confirm
    # the result is still a statement before returning it.
    if not isinstance(
        new_statement, (cst.SimpleStatementLine, cst.BaseCompoundStatement)
    ):
        raise Exception(
            f"Expected a statement but got a {new_statement.__class__.__name__}!"
        )
    # Fail loudly if any placeholder survived substitution.
    new_statement.visit(TemplateChecker(template_vars))
    return new_statement


def parse_template_expression(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> cst.BaseExpression:
    """
    Accepts an expression template on a single line. Leading and trailing
    whitespace is not valid (there's nowhere to store it on the expression
    node). Any :class:`~libcst.CSTNode` provided as a keyword argument to this
    function will be inserted into the template at the appropriate location
    similar to an f-string expansion. For example::

      expression = parse_template_expression("x + {foo}", foo=Name("y"))

    The above code will parse to a :class:`~libcst.BinaryOperation` expression
    adding two names (``x`` and ``y``) together.

    Remember that if you are parsing a template as part of a substitution
    inside a transform, it's considered
    :ref:`best practice <libcst-config_best_practice>` to pass in a ``config``
    from the current module under transformation.
    """
    # The keyword names are the template variables; the same set drives both
    # the mangling step and the final leftover check.
    template_vars = set(template_replacements)
    source = mangle_template(template, template_vars)
    expression = cst.parse_expression(source, config)
    new_expression = ensure_type(
        unmangle_nodes(expression, template_replacements), cst.BaseExpression
    )
    # Fail loudly if any placeholder survived substitution.
    new_expression.visit(TemplateChecker(template_vars))
    return new_expression