From 9150db52a6280c7a40f1a8029f5b390314ba0dda Mon Sep 17 00:00:00 2001 From: Nikita Sobolev Date: Fri, 11 Feb 2022 14:30:00 +0300 Subject: [PATCH] Improve `pstats` typing (#7174) * Add __all__ * Use Literal for Stats.sort_stats() argument type. * Add FunctionProfile, StatsProfile, and Stats.get_stats_profile() (Python 3.9+) --- stdlib/pstats.pyi | 36 ++++++++++++++++++++++++++--- tests/stubtest_allowlists/py310.txt | 3 --- tests/stubtest_allowlists/py39.txt | 3 --- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/stdlib/pstats.pyi b/stdlib/pstats.pyi index 7b44e34d7..c4fe28477 100644 --- a/stdlib/pstats.pyi +++ b/stdlib/pstats.pyi @@ -3,6 +3,14 @@ from _typeshed import Self, StrOrBytesPath from cProfile import Profile as _cProfile from profile import Profile from typing import IO, Any, Iterable, Union, overload +from typing_extensions import Literal + +if sys.version_info >= (3, 9): + __all__ = ["Stats", "SortKey", "FunctionProfile", "StatsProfile"] +elif sys.version_info >= (3, 7): + __all__ = ["Stats", "SortKey"] +else: + __all__ = ["Stats"] _Selector = Union[str, float, int] @@ -20,8 +28,27 @@ if sys.version_info >= (3, 7): STDNAME: str TIME: str +if sys.version_info >= (3, 9): + from dataclasses import dataclass + + @dataclass(unsafe_hash=True) + class FunctionProfile: + ncalls: int + tottime: float + percall_tottime: float + cumtime: float + percall_cumtime: float + file_name: str + line_number: int + @dataclass(unsafe_hash=True) + class StatsProfile: + total_tt: float + func_profiles: dict[str, FunctionProfile] + +_SortArgDict = dict[str, tuple[tuple[tuple[int, int], ...], str]] + class Stats: - sort_arg_dict_default: dict[str, tuple[Any, str]] + sort_arg_dict_default: _SortArgDict def __init__( self: Self, __arg: None | str | Profile | _cProfile = ..., @@ -33,15 +60,18 @@ class Stats: def get_top_level_stats(self) -> None: ... def add(self: Self, *arg_list: None | str | Profile | _cProfile | Self) -> Self: ... def dump_stats(self, filename: StrOrBytesPath) -> None: ... - def get_sort_arg_defs(self) -> dict[str, tuple[tuple[tuple[int, int], ...], str]]: ... + def get_sort_arg_defs(self) -> _SortArgDict: ... @overload - def sort_stats(self: Self, field: int) -> Self: ... + def sort_stats(self: Self, field: Literal[-1, 0, 1, 2]) -> Self: ... @overload def sort_stats(self: Self, *field: str) -> Self: ... def reverse_order(self: Self) -> Self: ... def strip_dirs(self: Self) -> Self: ... def calc_callees(self) -> None: ... def eval_print_amount(self, sel: _Selector, list: list[str], msg: str) -> tuple[list[str], str]: ... + if sys.version_info >= (3, 9): + def get_stats_profile(self) -> StatsProfile: ... + def get_print_list(self, sel_list: Iterable[_Selector]) -> tuple[int, list[str]]: ... def print_stats(self: Self, *amount: _Selector) -> Self: ... def print_callees(self: Self, *amount: _Selector) -> Self: ... diff --git a/tests/stubtest_allowlists/py310.txt b/tests/stubtest_allowlists/py310.txt index e9205c10b..d9ba1aeb1 100644 --- a/tests/stubtest_allowlists/py310.txt +++ b/tests/stubtest_allowlists/py310.txt @@ -176,9 +176,6 @@ multiprocessing.managers.SharedMemoryServer.public multiprocessing.managers.SharedMemoryServer.release_segment multiprocessing.managers.SharedMemoryServer.shutdown multiprocessing.managers.SharedMemoryServer.track_segment -pstats.FunctionProfile -pstats.Stats.get_stats_profile -pstats.StatsProfile pyexpat.XMLParserType.SkippedEntityHandler pyexpat.XMLParserType.intern sched.Event.sequence diff --git a/tests/stubtest_allowlists/py39.txt b/tests/stubtest_allowlists/py39.txt index f39513daf..92d763db8 100644 --- a/tests/stubtest_allowlists/py39.txt +++ b/tests/stubtest_allowlists/py39.txt @@ -150,9 +150,6 @@ multiprocessing.managers.SharedMemoryServer.public multiprocessing.managers.SharedMemoryServer.release_segment multiprocessing.managers.SharedMemoryServer.shutdown multiprocessing.managers.SharedMemoryServer.track_segment -pstats.FunctionProfile -pstats.Stats.get_stats_profile -pstats.StatsProfile pyexpat.XMLParserType.SkippedEntityHandler pyexpat.XMLParserType.intern stringprep.unicodedata # re-exported from unicodedata