# Source code for libcst.metadata.type_inference_provider

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import subprocess
from pathlib import Path
from typing import Dict, List, Mapping, Optional, Sequence, Tuple

from mypy_extensions import TypedDict

import libcst as cst
from libcst._position import CodePosition, CodeRange
from libcst.metadata.base_provider import BatchableMetadataProvider
from libcst.metadata.position_provider import PositionProvider


class Position(TypedDict):
    """A single point in a file, expressed as a line and a column."""

    # Line number of the point.
    line: int
    # Column offset of the point within that line.
    column: int


class Location(TypedDict):
    """A span inside a file: the file's path plus a start and a stop position."""

    # Path of the file the span belongs to.
    path: str
    # Position where the span begins.
    start: Position
    # Position where the span ends.
    stop: Position


class InferredType(TypedDict):
    """One inferred-type record: a source location and its annotation string."""

    # Span of source code the annotation applies to.
    location: Location
    # The inferred type, rendered as a type-annotation string.
    annotation: str


class PyreData(TypedDict, total=False):
    """Per-file collection of inferred types.

    Declared with ``total=False``, so the ``types`` key may be absent.
    """

    # Inferred-type records for one file.
    types: Sequence[InferredType]


class TypeInferenceProvider(BatchableMetadataProvider[str]):
    """
    Access inferred type annotations through the
    `Pyre Query API <https://pyre-check.org/docs/querying-pyre.html>`_.
    It requires `setting up watchman <https://pyre-check.org/docs/watchman-integration.html>`_
    and starting the pyre server by running the ``pyre`` command.
    The inferred type is a string of
    `type annotation <https://docs.python.org/3/library/typing.html>`_.
    E.g. ``typing.List[libcst._nodes.expression.Name]``
    is the inferred type of name ``n`` in expression ``n = [cst.Name("")]``.
    All name references use the fully qualified name regardless of how the names
    are imported. (e.g. ``import libcst; libcst.Name`` and
    ``import libcst as cst; cst.Name`` refer to the same name.)
    Pyre infers the type of :class:`~libcst.Name`, :class:`~libcst.Attribute` and
    :class:`~libcst.Call` nodes.
    The inter-process communication with the Pyre server is managed by
    :class:`~libcst.metadata.FullRepoManager`.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    @staticmethod
    # pyre-fixme[40]: Static method `gen_cache` cannot override a non-static method
    #  defined in `cst.metadata.base_provider.BaseMetadataProvider`.
    def gen_cache(
        root_path: Path, paths: List[str], timeout: Optional[int]
    ) -> Mapping[str, object]:
        """
        Query pyre for the inferred types of ``paths`` and build the cache.

        A single ``pyre --noninteractive query "types(...)"`` command is run for
        all paths, each one joined onto ``root_path``.

        :param root_path: directory that every entry of ``paths`` is joined onto.
        :param paths: file paths to query.
        :param timeout: forwarded to the subprocess call; ``None`` means no limit.
        :returns: mapping from each entry of ``paths`` to its processed pyre data.
            Entries are paired with the response positionally, so if the response
            holds fewer items than ``paths`` the trailing paths are absent.
        :raises subprocess.TimeoutExpired: if the command exceeds ``timeout``.
        :raises Exception: if the command exits with a non-zero status, or its
            stdout cannot be parsed as JSON containing a ``"response"`` key.
        """
        params = ",".join(f"path='{root_path / path}'" for path in paths)
        cmd_args = ["pyre", "--noninteractive", "query", f"types({params})"]
        # A subprocess.TimeoutExpired raised here propagates to the caller as-is.
        stdout, stderr, return_code = run_command(cmd_args, timeout=timeout)

        if return_code != 0:
            raise Exception(f"stderr:\n {stderr}\nstdout:\n {stdout}")
        try:
            resp = json.loads(stdout)["response"]
        except Exception as e:
            # Keep the original failure attached as the explicit cause.
            raise Exception(f"{e}\n\nstderr:\n {stderr}\nstdout:\n {stdout}") from e
        return {path: _process_pyre_data(data) for path, data in zip(paths, resp)}

    def __init__(self, cache: PyreData) -> None:
        """
        Index the cached inferred types by their code range.

        :param cache: pyre data for one file; a missing ``"types"`` key is
            treated as an empty list.
        """
        super().__init__(cache)
        lookup: Dict[CodeRange, str] = {}
        cache_types = cache.get("types", [])
        for item in cache_types:
            location = item["location"]
            start = location["start"]
            end = location["stop"]
            lookup[
                CodeRange(
                    start=CodePosition(start["line"], start["column"]),
                    end=CodePosition(end["line"], end["column"]),
                )
            ] = item["annotation"]
        # Maps a node's code range to its inferred annotation string.
        self.lookup: Dict[CodeRange, str] = lookup

    def _parse_metadata(self, node: cst.CSTNode) -> None:
        """
        Attach the inferred type to ``node`` when its range is in the lookup.

        The matching entry is removed from the lookup once it has been assigned.
        """
        code_range = self.get_metadata(PositionProvider, node)
        if code_range in self.lookup:
            self.set_metadata(node, self.lookup.pop(code_range))

    def visit_Name(self, node: cst.Name) -> Optional[bool]:
        """Record the inferred type of a ``Name`` node, if one is cached."""
        self._parse_metadata(node)

    def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:
        """Record the inferred type of an ``Attribute`` node, if one is cached."""
        self._parse_metadata(node)

    def visit_Call(self, node: cst.Call) -> Optional[bool]:
        """Record the inferred type of a ``Call`` node, if one is cached."""
        self._parse_metadata(node)
def run_command(
    cmd_args: List[str], timeout: Optional[int] = None
) -> Tuple[str, str, int]:
    """
    Run ``cmd_args`` as a subprocess and collect its output.

    :param cmd_args: the command and its arguments.
    :param timeout: forwarded to :func:`subprocess.run`; ``None`` means no limit.
    :returns: a ``(stdout, stderr, return_code)`` tuple with both streams decoded
        to ``str``.
    """
    completed = subprocess.run(
        cmd_args,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=timeout,
    )
    out_text = completed.stdout.decode()
    err_text = completed.stderr.decode()
    return out_text, err_text, completed.returncode


class RawPyreData(TypedDict):
    """One per-file entry of a pyre ``types`` query response."""

    # Path of the file the entry describes.
    path: str
    # Inferred-type records reported for that file.
    types: Sequence[InferredType]


def _process_pyre_data(data: RawPyreData) -> PyreData:
    """Return the ``types`` of ``data`` ordered by source position."""
    ordered = list(data["types"])
    ordered.sort(key=_sort_by_position)
    return {"types": ordered}


def _sort_by_position(data: InferredType) -> Tuple[int, int, int, int]:
    """Build the sort key ``(start line, start column, stop line, stop column)``."""
    begin, finish = data["location"]["start"], data["location"]["stop"]
    return (begin["line"], begin["column"], finish["line"], finish["column"])