# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import subprocess
from pathlib import Path
from typing import Dict, List, Mapping, Optional, Sequence, Tuple
from mypy_extensions import TypedDict
import libcst as cst
from libcst._position import CodePosition, CodeRange
from libcst.metadata.base_provider import BatchableMetadataProvider
from libcst.metadata.position_provider import PositionProvider
class Position(TypedDict):
    """
    A line/column point inside a Pyre ``types`` query result.

    NOTE(review): these values are compared directly against libcst
    ``PositionProvider`` ranges (via ``CodePosition``) in
    ``TypeInferenceProvider.__init__``, so they presumably use the same
    line/column numbering convention — confirm against Pyre's output.
    """

    line: int
    column: int
class Location(TypedDict):
    """A source span reported by Pyre: the file ``path`` plus its ``start`` and ``stop`` points."""

    path: str
    start: Position
    # Becomes ``CodeRange.end`` when TypeInferenceProvider builds its lookup table.
    stop: Position
class InferredType(TypedDict):
    """One Pyre inference result: the annotation string inferred for the expression at ``location``."""

    location: Location
    # Type annotation string as produced by Pyre, e.g. ``typing.List[...]``.
    annotation: str
class PyreData(TypedDict, total=False):
    """
    Per-file cache payload consumed by ``TypeInferenceProvider``.

    ``total=False`` makes ``types`` optional; ``TypeInferenceProvider.__init__``
    treats a missing ``types`` key as an empty list.
    """

    # Inference results; ``_process_pyre_data`` stores them sorted by position.
    types: Sequence[InferredType]
class TypeInferenceProvider(BatchableMetadataProvider[str]):
    """
    Access inferred type annotations through the
    `Pyre Query API <https://pyre-check.org/docs/querying-pyre.html>`_.

    It requires you to `set up watchman <https://pyre-check.org/docs/watchman-integration.html>`_
    and to start the Pyre server by running the ``pyre`` command.

    The inferred type is a string of
    `type annotation <https://docs.python.org/3/library/typing.html>`_.
    E.g. ``typing.List[libcst._nodes.expression.Name]``
    is the inferred type of name ``n`` in expression ``n = [cst.Name("")]``.
    All name references use the fully qualified name regardless of how the names are imported
    (e.g. ``import libcst; libcst.Name`` and ``import libcst as cst; cst.Name`` refer to the same name).
    Pyre infers the type of :class:`~libcst.Name`, :class:`~libcst.Attribute` and :class:`~libcst.Call` nodes.
    The inter-process communication with the Pyre server is managed by :class:`~libcst.metadata.FullRepoManager`.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    @staticmethod
    # pyre-fixme[40]: Static method `gen_cache` cannot override a non-static method
    # defined in `cst.metadata.base_provider.BaseMetadataProvider`.
    def gen_cache(
        root_path: Path, paths: List[str], timeout: Optional[int]
    ) -> Mapping[str, object]:
        """
        Run ``pyre query types(...)`` for ``paths`` and return per-path cache data.

        Each entry of ``paths`` is joined onto ``root_path`` when building the
        query. The returned mapping is keyed by the original ``paths`` entries;
        each value is the position-sorted data built by ``_process_pyre_data``.

        Raises ``subprocess.TimeoutExpired`` if the query exceeds ``timeout``
        seconds. Raises ``Exception`` (carrying Pyre's stdout/stderr in the
        message) if ``pyre`` exits non-zero or its output is not a JSON object
        with a ``response`` key.
        """
        # NOTE(review): paths are wrapped in single quotes without escaping, so a
        # path that itself contains ``'`` would yield a malformed query.
        params = ",".join(f"path='{root_path / path}'" for path in paths)
        cmd_args = ["pyre", "--noninteractive", "query", f"types({params})"]
        # subprocess.TimeoutExpired raised by run_command propagates unchanged.
        stdout, stderr, return_code = run_command(cmd_args, timeout=timeout)
        if return_code != 0:
            raise Exception(f"stderr:\n {stderr}\nstdout:\n {stdout}")
        try:
            resp = json.loads(stdout)["response"]
        except (ValueError, KeyError, TypeError) as e:
            # ValueError covers json.JSONDecodeError; KeyError/TypeError cover a
            # payload that is not an object containing a "response" key.
            raise Exception(f"{e}\n\nstderr:\n {stderr}\nstdout:\n {stdout}") from e
        # Results are paired with the requested paths by position; zip stops at
        # the shorter sequence if Pyre returns fewer entries than requested.
        return {path: _process_pyre_data(data) for path, data in zip(paths, resp)}

    def __init__(self, cache: PyreData) -> None:
        """Index the cached Pyre annotations by their source ``CodeRange``."""
        super().__init__(cache)
        lookup: Dict[CodeRange, str] = {}
        # A missing "types" key (PyreData is total=False) means no annotations.
        cache_types = cache.get("types", [])
        for item in cache_types:
            location = item["location"]
            start = location["start"]
            end = location["stop"]
            # Later entries with an identical span overwrite earlier ones.
            lookup[
                CodeRange(
                    start=CodePosition(start["line"], start["column"]),
                    end=CodePosition(end["line"], end["column"]),
                )
            ] = item["annotation"]
        self.lookup: Dict[CodeRange, str] = lookup

    def _parse_metadata(self, node: cst.CSTNode) -> None:
        """Attach the Pyre annotation whose span exactly matches ``node``'s position."""
        node_range = self.get_metadata(PositionProvider, node)
        if node_range in self.lookup:
            # pop() hands each annotation to at most one node: the first one
            # visited with this exact range.
            self.set_metadata(node, self.lookup.pop(node_range))

    def visit_Name(self, node: cst.Name) -> Optional[bool]:
        """Record the inferred type for a ``Name`` node, if Pyre reported one."""
        self._parse_metadata(node)

    def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:
        """Record the inferred type for an ``Attribute`` node, if Pyre reported one."""
        self._parse_metadata(node)

    def visit_Call(self, node: cst.Call) -> Optional[bool]:
        """Record the inferred type for a ``Call`` node, if Pyre reported one."""
        self._parse_metadata(node)
def run_command(
    cmd_args: List[str], timeout: Optional[int] = None
) -> Tuple[str, str, int]:
    """
    Execute ``cmd_args`` (without a shell) and collect its output.

    Returns ``(stdout, stderr, returncode)``, with both streams decoded using
    ``bytes.decode``'s default UTF-8 codec. ``subprocess.TimeoutExpired`` is
    raised if ``timeout`` seconds pass before the process exits.
    """
    completed = subprocess.run(
        cmd_args,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=timeout,
    )
    out_text = completed.stdout.decode()
    err_text = completed.stderr.decode()
    return out_text, err_text, completed.returncode
class RawPyreData(TypedDict):
    """
    One element of the ``response`` list parsed from ``pyre query types(...)``.

    ``gen_cache`` pairs these elements with the requested paths by position
    and passes each one to ``_process_pyre_data``.
    """

    # Reported file path; not read by ``_process_pyre_data``.
    path: str
    # Unsorted inference results for that file.
    types: Sequence[InferredType]
def _process_pyre_data(data: RawPyreData) -> PyreData:
    """
    Convert one raw Pyre result into cache form.

    Keeps only the ``types`` entries, copied into a new list and stably
    ordered by ``_sort_by_position``; the raw ``path`` field is dropped.
    """
    ordered_types = list(data["types"])
    ordered_types.sort(key=_sort_by_position)
    return {"types": ordered_types}
def _sort_by_position(data: InferredType) -> Tuple[int, int, int, int]:
    """Sort key: ``(start line, start column, stop line, stop column)`` of ``data``'s span."""
    span = data["location"]
    begin, finish = span["start"], span["stop"]
    return (begin["line"], begin["column"], finish["line"], finish["column"])