Source code for libcst.helpers.node_fields

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING

from libcst import IndentedBlock, Module
from libcst._nodes.deep_equals import deep_equals

if TYPE_CHECKING:
    from typing import Sequence

    from libcst import CSTNode


[docs] def get_node_fields(node: CSTNode) -> Sequence[dataclasses.Field[CSTNode]]: """ Returns the sequence of a given CST-node's fields. """ return dataclasses.fields(node)
[docs] def is_whitespace_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool: """ Returns True if a given CST-node's field is a whitespace-related field (whitespace, indent, header, footer, etc.). """ if "whitespace" in field.name: return True if "leading_lines" in field.name: return True if "lines_after_decorators" in field.name: return True if isinstance(node, (IndentedBlock, Module)) and field.name in [ "header", "footer", ]: return True if isinstance(node, IndentedBlock) and field.name == "indent": return True return False
[docs] def is_syntax_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool: """ Returns True if a given CST-node's field is a syntax-related field (colon, semicolon, dot, encoding, etc.). """ if isinstance(node, Module) and field.name in [ "encoding", "default_indent", "default_newline", "has_trailing_newline", ]: return True type_str = repr(field.type) if ( "Sentinel" in type_str and field.name not in ["star_arg", "star", "posonly_ind"] and "whitespace" not in field.name ): # This is a value that can optionally be specified, so its # definitely syntax. return True for name in ["Semicolon", "Colon", "Comma", "Dot", "AssignEqual"]: # These are all nodes that exist for separation syntax if name in type_str: return True return False
[docs] def get_field_default_value(field: dataclasses.Field[CSTNode]) -> object: """ Returns the default value of a CST-node's field. """ if field.default_factory is not dataclasses.MISSING: # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE, # dataclasses._DefaultFactory[object]]` is not a function. return field.default_factory() return field.default
[docs] def is_default_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool: """ Returns True if a given CST-node's field has its default value. """ return deep_equals(getattr(node, field.name), get_field_default_value(field))
[docs] def filter_node_fields( node: CSTNode, *, show_defaults: bool, show_syntax: bool, show_whitespace: bool, ) -> Sequence[dataclasses.Field[CSTNode]]: """ Returns a filtered sequence of a CST-node's fields. Setting ``show_whitespace`` to ``False`` will filter whitespace fields. Setting ``show_defaults`` to ``False`` will filter fields if their value is equal to the default value ; while respecting the value of ``show_whitespace``. Setting ``show_syntax`` to ``False`` will filter syntax fields ; while respecting the value of ``show_whitespace`` & ``show_defaults``. """ fields: Sequence[dataclasses.Field[CSTNode]] = dataclasses.fields(node) # Hide all fields prefixed with "_" fields = [f for f in fields if f.name[0] != "_"] # Filter whitespace nodes if needed if not show_whitespace: fields = [f for f in fields if not is_whitespace_node_field(node, f)] # Filter values which aren't changed from their defaults if not show_defaults: fields = [f for f in fields if not is_default_node_field(node, f)] # Filter out values which aren't interesting if needed if not show_syntax: fields = [f for f in fields if not is_syntax_node_field(node, f)] return fields