From dc0b63fd68852e02dd9ee9ffe0ac05fab55b7dfd Mon Sep 17 00:00:00 2001 From: Max Muoto Date: Sun, 11 Aug 2024 18:17:06 -0500 Subject: [PATCH] Accurate overloads for `ZipFile.__init__` (#12119) Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com> --- stdlib/@tests/test_cases/check_zipfile.py | 131 ++++++++++++++++++++++ stdlib/zipfile/__init__.pyi | 80 ++++++++++++- 2 files changed, 208 insertions(+), 3 deletions(-) create mode 100644 stdlib/@tests/test_cases/check_zipfile.py diff --git a/stdlib/@tests/test_cases/check_zipfile.py b/stdlib/@tests/test_cases/check_zipfile.py new file mode 100644 index 000000000..3012271cc --- /dev/null +++ b/stdlib/@tests/test_cases/check_zipfile.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import io +import pathlib +import zipfile +from typing import Literal + +### +# Tests for `zipfile.ZipFile` +### + +p = pathlib.Path("test.zip") + + +class CustomPathObj: + def __init__(self, path: str) -> None: + self.path = path + + def __fspath__(self) -> str: + return self.path + + +class NonPathObj: + def __init__(self, path: str) -> None: + self.path = path + + +class ReadableObj: + def seek(self, offset: int, whence: int = 0) -> int: + return 0 + + def read(self, n: int | None = -1) -> bytes: + return b"test" + + +class TellableObj: + def tell(self) -> int: + return 0 + + +class WriteableObj: + def close(self) -> None: + pass + + def write(self, b: bytes) -> int: + return len(b) + + def flush(self) -> None: + pass + + +class ReadTellableObj(ReadableObj): + def tell(self) -> int: + return 0 + + +class SeekTellObj: + def seek(self, offset: int, whence: int = 0) -> int: + return 0 + + def tell(self) -> int: + return 0 + + +def write_zip(mode: Literal["r", "w", "x", "a"]) -> None: + # Test any mode with `pathlib.Path` + with zipfile.ZipFile(p, mode) as z: + z.writestr("test.txt", "test") + + # Test any mode with `str` path + with zipfile.ZipFile("test.zip", mode) as z: + z.writestr("test.txt", "test") + + # Test any mode with `os.PathLike` object + with zipfile.ZipFile(CustomPathObj("test.zip"), mode) as z: + z.writestr("test.txt", "test") + + # Non-PathLike object should raise an error + with zipfile.ZipFile(NonPathObj("test.zip"), mode) as z: # type: ignore + z.writestr("test.txt", "test") + + # IO[bytes] like-obj should work for any mode. + io_obj = io.BytesIO(b"This is a test") + with zipfile.ZipFile(io_obj, mode) as z: + z.writestr("test.txt", "test") + + # Readable object should not work for any mode. + with zipfile.ZipFile(ReadableObj(), mode) as z: # type: ignore + z.writestr("test.txt", "test") + + # Readable object should work for "r" mode. + with zipfile.ZipFile(ReadableObj(), "r") as z: + z.writestr("test.txt", "test") + + # Readable/tellable object should work for "a" mode. + with zipfile.ZipFile(ReadTellableObj(), "a") as z: + z.writestr("test.txt", "test") + + # If it doesn't have 'tell' method, it should raise an error. + with zipfile.ZipFile(ReadableObj(), "a") as z: # type: ignore + z.writestr("test.txt", "test") + + # Readable object should not work for "w" mode. + with zipfile.ZipFile(ReadableObj(), "w") as z: # type: ignore + z.writestr("test.txt", "test") + + # Tellable object should not work for any mode. + with zipfile.ZipFile(TellableObj(), mode) as z: # type: ignore + z.writestr("test.txt", "test") + + # Tellable object shouldn't work for "w" mode. + # As `__del__` will call close. + with zipfile.ZipFile(TellableObj(), "w") as z: # type: ignore + z.writestr("test.txt", "test") + + # Writeable object should not work for any mode. + with zipfile.ZipFile(WriteableObj(), mode) as z: # type: ignore + z.writestr("test.txt", "test") + + # Writeable object should work for "w" mode. + with zipfile.ZipFile(WriteableObj(), "w") as z: + z.writestr("test.txt", "test") + + # Seekable and Tellable object should not work for any mode. + with zipfile.ZipFile(SeekTellObj(), mode) as z: # type: ignore + z.writestr("test.txt", "test") + + # Seekable and Tellable object shouldn't work for "w" mode. + # Cause `__del__` will call close. + with zipfile.ZipFile(SeekTellObj(), "w") as z: # type: ignore + z.writestr("test.txt", "test") diff --git a/stdlib/zipfile/__init__.pyi b/stdlib/zipfile/__init__.pyi index 57a8a6aaa..85eb2b6df 100644 --- a/stdlib/zipfile/__init__.pyi +++ b/stdlib/zipfile/__init__.pyi @@ -94,6 +94,20 @@ class ZipExtFile(io.BufferedIOBase): class _Writer(Protocol): def write(self, s: str, /) -> object: ... +class _ZipReadable(Protocol): + def seek(self, offset: int, whence: int = 0, /) -> int: ... + def read(self, n: int = -1, /) -> bytes: ... + +class _ZipTellable(Protocol): + def tell(self) -> int: ... + +class _ZipReadableTellable(_ZipReadable, _ZipTellable, Protocol): ... + +class _ZipWritable(Protocol): + def flush(self) -> None: ... + def close(self) -> None: ... + def write(self, b: bytes, /) -> int: ... + class ZipFile: filename: str | None debug: int @@ -106,24 +120,50 @@ class ZipFile: compresslevel: int | None # undocumented mode: _ZipFileMode # undocumented pwd: bytes | None # undocumented + # metadata_encoding is new in 3.11 if sys.version_info >= (3, 11): @overload def __init__( self, file: StrPath | IO[bytes], + mode: _ZipFileMode = "r", + compression: int = 0, + allowZip64: bool = True, + compresslevel: int | None = None, + *, + strict_timestamps: bool = True, + metadata_encoding: str | None = None, + ) -> None: ... + # metadata_encoding is only allowed for read mode + @overload + def __init__( + self, + file: StrPath | _ZipReadable, mode: Literal["r"] = "r", compression: int = 0, allowZip64: bool = True, compresslevel: int | None = None, *, strict_timestamps: bool = True, - metadata_encoding: str | None, + metadata_encoding: str | None = None, ) -> None: ... @overload def __init__( self, - file: StrPath | IO[bytes], - mode: _ZipFileMode = "r", + file: StrPath | _ZipWritable, + mode: Literal["w", "x"] = ..., + compression: int = 0, + allowZip64: bool = True, + compresslevel: int | None = None, + *, + strict_timestamps: bool = True, + metadata_encoding: None = None, + ) -> None: ... + @overload + def __init__( + self, + file: StrPath | _ZipReadableTellable, + mode: Literal["a"] = ..., compression: int = 0, allowZip64: bool = True, compresslevel: int | None = None, @@ -132,6 +172,7 @@ class ZipFile: metadata_encoding: None = None, ) -> None: ... else: + @overload def __init__( self, file: StrPath | IO[bytes], @@ -142,6 +183,39 @@ class ZipFile: *, strict_timestamps: bool = True, ) -> None: ... + @overload + def __init__( + self, + file: StrPath | _ZipReadable, + mode: Literal["r"] = "r", + compression: int = 0, + allowZip64: bool = True, + compresslevel: int | None = None, + *, + strict_timestamps: bool = True, + ) -> None: ... + @overload + def __init__( + self, + file: StrPath | _ZipWritable, + mode: Literal["w", "x"] = ..., + compression: int = 0, + allowZip64: bool = True, + compresslevel: int | None = None, + *, + strict_timestamps: bool = True, + ) -> None: ... + @overload + def __init__( + self, + file: StrPath | _ZipReadableTellable, + mode: Literal["a"] = ..., + compression: int = 0, + allowZip64: bool = True, + compresslevel: int | None = None, + *, + strict_timestamps: bool = True, + ) -> None: ... def __enter__(self) -> Self: ... def __exit__(