# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from typing import (
Callable,
cast,
Iterable,
List,
Mapping,
MutableMapping,
Optional,
TYPE_CHECKING,
)
from libcst._metadata_dependent import MetadataDependent
from libcst._typed_visitor import CSTTypedVisitorFunctions
from libcst._visitors import CSTNodeT, CSTVisitor
if TYPE_CHECKING:
from libcst._nodes.base import CSTNode # noqa: F401
# A visitor callback: called with a single CST node, returns nothing.
VisitorMethod = Callable[["CSTNode"], None]
# Maps a visitor method name (e.g. ``visit_<Type>`` or ``leave_<Type>_<attribute>``)
# to every callback registered under that name across the batched visitors.
_VisitorMethodCollection = Mapping[str, List[VisitorMethod]]
class BatchableCSTVisitor(CSTTypedVisitorFunctions, MetadataDependent):
    """
    Low-level base class for visitors that traverse a CST as one member of a
    batched set of traversals.

    Use it together with the :func:`~libcst.visit_batched` function or the
    :func:`~libcst.MetadataWrapper.visit_batched` method from
    :class:`~libcst.MetadataWrapper` to visit a tree.

    Instances of this class cannot modify the tree.
    """

    def get_visitors(self) -> Mapping[str, VisitorMethod]:
        """
        Return the visitor methods this instance actually implements.

        The result maps attribute name to bound method for every
        ``visit_<Type[CSTNode]>``, ``visit_<Type[CSTNode]>_<attribute>``,
        ``leave_<Type[CSTNode]>`` and ``leave_<Type[CSTNode]>_<attribute>``
        method defined by this visitor, excluding all empty stubs (methods
        carrying a truthy ``_is_no_op`` attribute).
        """

        def _is_active_hook(member: object) -> bool:
            # Only bound methods can be visitor hooks.
            if not inspect.ismethod(member):
                return False
            # The check is on the function's own ``__name__``, not on the
            # attribute name it is reachable under.
            if not member.__name__.startswith(("visit_", "leave_")):
                return False
            # Empty stubs are flagged with ``_is_no_op`` and are skipped.
            return not getattr(member, "_is_no_op", False)

        # TODO: verify all visitor methods reference valid node classes.
        return {
            name: method
            for name, method in inspect.getmembers(self, _is_active_hook)
        }
def visit_batched(
    node: CSTNodeT,
    batchable_visitors: Iterable[BatchableCSTVisitor],
    before_visit: Optional[VisitorMethod] = None,
    after_leave: Optional[VisitorMethod] = None,
) -> CSTNodeT:
    """
    Run every visitor in ``batchable_visitors`` over ``node`` in a single
    batched traversal.

    ``before_visit`` and ``after_leave`` are optional hooks. When given,
    ``before_visit`` is executed before the ``visit_<Type[CSTNode]>`` methods
    of the batched visitors, and ``after_leave`` is executed after their
    ``leave_<Type[CSTNode]>`` methods.

    This function does not handle metadata dependency resolution for
    ``batchable_visitors``. See :func:`~libcst.MetadataWrapper.visit_batched`
    from :class:`~libcst.MetadataWrapper` for batched traversal with metadata
    dependency resolution.

    Returns the result of ``node.visit(...)``, typed as ``CSTNodeT``.
    """
    dispatcher = _BatchedCSTVisitor(
        _get_visitor_methods(batchable_visitors),
        before_visit=before_visit,
        after_leave=after_leave,
    )
    visited = node.visit(dispatcher)
    # ``cast`` only informs the type checker; it does nothing at runtime.
    return cast(CSTNodeT, visited)
def _get_visitor_methods(
    batchable_visitors: Iterable[BatchableCSTVisitor],
) -> _VisitorMethodCollection:
    """
    Gather all ``visit_<Type[CSTNode]>``, ``visit_<Type[CSTNode]>_<attribute>``,
    ``leave_<Type[CSTNode]>`` and ``leave_<Type[CSTNode]>_<attribute>`` methods
    from ``batchable_visitors``.

    Methods are grouped by name. Within each group they appear in the order
    their visitors were supplied.
    """
    grouped: MutableMapping[str, List[VisitorMethod]] = {}
    for visitor in batchable_visitors:
        for method_name, method in visitor.get_visitors().items():
            # Start a fresh bucket the first time a method name is seen.
            if method_name not in grouped:
                grouped[method_name] = []
            grouped[method_name].append(method)
    return grouped
class _BatchedCSTVisitor(CSTVisitor):
    """
    Internal visitor that fans each traversal event out to the matching
    methods of a batch of visitors.
    """

    # Callbacks grouped by visitor method name.
    visitor_methods: _VisitorMethodCollection
    # Optional hook called before the ``visit_<Type>`` callbacks for a node.
    before_visit: Optional[VisitorMethod]
    # Optional hook called after the ``leave_<Type>`` callbacks for a node.
    after_leave: Optional[VisitorMethod]

    def __init__(
        self,
        visitor_methods: _VisitorMethodCollection,
        *,
        before_visit: Optional[VisitorMethod] = None,
        after_leave: Optional[VisitorMethod] = None,
    ) -> None:
        super().__init__()
        self.visitor_methods = visitor_methods
        self.before_visit = before_visit
        self.after_leave = after_leave

    def on_visit(self, node: "CSTNode") -> bool:
        """
        Run ``before_visit`` (if set), then every ``visit_<Type>`` callback
        registered for the type of ``node``.

        Always returns ``True``.
        NOTE(review): presumably this tells the base traversal to descend
        into the node's children; confirm against ``CSTVisitor``.
        """
        hook = self.before_visit
        if hook is not None:
            hook(node)

        key = f"visit_{type(node).__name__}"
        callbacks = self.visitor_methods.get(key, [])
        for callback in callbacks:
            callback(node)
        return True

    def on_leave(self, original_node: "CSTNode") -> None:
        """
        Run every ``leave_<Type>`` callback registered for the type of
        ``original_node``, then ``after_leave`` (if set).
        """
        key = f"leave_{type(original_node).__name__}"
        callbacks = self.visitor_methods.get(key, [])
        for callback in callbacks:
            callback(original_node)

        hook = self.after_leave
        if hook is not None:
            hook(original_node)

    def on_visit_attribute(self, node: "CSTNode", attribute: str) -> None:
        """
        Run every ``visit_<Type>_<attribute>`` callback registered for the
        type of ``node`` and the given ``attribute``.
        """
        key = f"visit_{type(node).__name__}_{attribute}"
        callbacks = self.visitor_methods.get(key, [])
        for callback in callbacks:
            callback(node)

    def on_leave_attribute(self, original_node: "CSTNode", attribute: str) -> None:
        """
        Run every ``leave_<Type>_<attribute>`` callback registered for the
        type of ``original_node`` and the given ``attribute``.
        """
        key = f"leave_{type(original_node).__name__}_{attribute}"
        callbacks = self.visitor_methods.get(key, [])
        for callback in callbacks:
            callback(original_node)