Add PEP 706 filters to tarfile (#10316)

Fixes #10315

Co-authored-by: Sebastian Rittau <srittau@rittau.biz>
This commit is contained in:
Jelle Zijlstra
2023-06-14 07:08:32 -07:00
committed by GitHub
parent 5f9d05c7f5
commit 5beddbe883
8 changed files with 116 additions and 23 deletions

View File

@@ -7,7 +7,7 @@ from collections.abc import Callable, Iterable, Iterator, Mapping
from gzip import _ReadableFileobj as _GzipReadableFileobj, _WritableFileobj as _GzipWritableFileobj
from types import TracebackType
from typing import IO, ClassVar, Protocol, overload
from typing_extensions import Literal, Self
from typing_extensions import Literal, Self, TypeAlias
__all__ = [
"TarFile",
@@ -26,6 +26,21 @@ __all__ = [
"DEFAULT_FORMAT",
"open",
]
if sys.version_info >= (3, 12):
__all__ += [
"fully_trusted_filter",
"data_filter",
"tar_filter",
"FilterError",
"AbsoluteLinkError",
"OutsideDestinationError",
"SpecialFileError",
"AbsolutePathError",
"LinkOutsideDestinationError",
]
_FilterFunction: TypeAlias = Callable[[TarInfo, str], TarInfo | None]
_TarfileFilter: TypeAlias = Literal["fully_trusted", "tar", "data"] | _FilterFunction
class _Fileobj(Protocol):
def read(self, __size: int) -> bytes: ...
@@ -125,6 +140,7 @@ class TarFile:
debug: int | None
errorlevel: int | None
offset: int # undocumented
extraction_filter: _FilterFunction | None
def __init__(
self,
name: StrOrBytesPath | None = None,
@@ -275,12 +291,32 @@ class TarFile:
def getnames(self) -> _list[str]: ...
def list(self, verbose: bool = True, *, members: _list[TarInfo] | None = None) -> None: ...
def next(self) -> TarInfo | None: ...
def extractall(
self, path: StrOrBytesPath = ".", members: Iterable[TarInfo] | None = None, *, numeric_owner: bool = False
) -> None: ...
def extract(
self, member: str | TarInfo, path: StrOrBytesPath = "", set_attrs: bool = True, *, numeric_owner: bool = False
) -> None: ...
if sys.version_info >= (3, 8):
def extractall(
self,
path: StrOrBytesPath = ".",
members: Iterable[TarInfo] | None = None,
*,
numeric_owner: bool = False,
filter: _TarfileFilter | None = ...,
) -> None: ...
def extract(
self,
member: str | TarInfo,
path: StrOrBytesPath = "",
set_attrs: bool = True,
*,
numeric_owner: bool = False,
filter: _TarfileFilter | None = ...,
) -> None: ...
else:
def extractall(
self, path: StrOrBytesPath = ".", members: Iterable[TarInfo] | None = None, *, numeric_owner: bool = False
) -> None: ...
def extract(
self, member: str | TarInfo, path: StrOrBytesPath = "", set_attrs: bool = True, *, numeric_owner: bool = False
) -> None: ...
def _extract_member(
self, tarinfo: TarInfo, targetpath: str, set_attrs: bool = True, numeric_owner: bool = False
) -> None: ... # undocumented
@@ -324,6 +360,31 @@ class StreamError(TarError): ...
class ExtractError(TarError): ...
class HeaderError(TarError): ...
if sys.version_info >= (3, 8):
class FilterError(TarError):
# This attribute is only set directly on the subclasses, but the documentation guarantees
# that it is always present on FilterError.
tarinfo: TarInfo
class AbsolutePathError(FilterError):
def __init__(self, tarinfo: TarInfo) -> None: ...
class OutsideDestinationError(FilterError):
def __init__(self, tarinfo: TarInfo, path: str) -> None: ...
class SpecialFileError(FilterError):
def __init__(self, tarinfo: TarInfo) -> None: ...
class AbsoluteLinkError(FilterError):
def __init__(self, tarinfo: TarInfo) -> None: ...
class LinkOutsideDestinationError(FilterError):
def __init__(self, tarinfo: TarInfo, path: str) -> None: ...
def fully_trusted_filter(member: TarInfo, dest_path: str) -> TarInfo: ...
def tar_filter(member: TarInfo, dest_path: str) -> TarInfo: ...
def data_filter(member: TarInfo, dest_path: str) -> TarInfo: ...
class TarInfo:
name: str
path: str
@@ -353,6 +414,21 @@ class TarInfo:
def linkpath(self) -> str: ...
@linkpath.setter
def linkpath(self, linkname: str) -> None: ...
if sys.version_info >= (3, 8):
def replace(
self,
*,
name: str = ...,
mtime: int = ...,
mode: int = ...,
linkname: str = ...,
uid: int = ...,
gid: int = ...,
uname: str = ...,
gname: str = ...,
deep: bool = True,
) -> Self: ...
def get_info(self) -> Mapping[str, str | int | bytes | Mapping[str, str]]: ...
if sys.version_info >= (3, 8):
def tobuf(self, format: int | None = 2, encoding: str | None = "utf-8", errors: str = "surrogateescape") -> bytes: ...