diff --git a/stdlib/csv.pyi b/stdlib/csv.pyi index 606694dca..db48e9e6c 100644 --- a/stdlib/csv.pyi +++ b/stdlib/csv.pyi @@ -17,10 +17,14 @@ from _csv import ( unregister_dialect as unregister_dialect, writer as writer, ) -from collections import OrderedDict -from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Text, Type +from typing import Any, Generic, Iterable, Iterator, List, Mapping, Optional, Sequence, Text, Type, TypeVar, overload -_DictRow = Mapping[str, Any] +if sys.version_info >= (3, 8) or sys.version_info < (3, 6): + from typing import Dict as _DictReadMapping +else: + from collections import OrderedDict as _DictReadMapping + +_T = TypeVar("_T") class excel(Dialect): delimiter: str @@ -42,23 +46,28 @@ if sys.version_info >= (3,): lineterminator: str quoting: int -if sys.version_info >= (3, 8): - _DRMapping = Dict[str, str] -elif sys.version_info >= (3, 6): - _DRMapping = OrderedDict[str, str] -else: - _DRMapping = Dict[str, str] - -class DictReader(Iterator[_DRMapping]): +class DictReader(Generic[_T], Iterator[_DictReadMapping[_T, str]]): + fieldnames: Optional[Sequence[_T]] restkey: Optional[str] restval: Optional[str] reader: _reader dialect: _DialectLike line_num: int - fieldnames: Optional[Sequence[str]] + @overload def __init__( self, f: Iterable[Text], + fieldnames: Sequence[_T], + restkey: Optional[str] = ..., + restval: Optional[str] = ..., + dialect: _DialectLike = ..., + *args: Any, + **kwds: Any, + ) -> None: ... + @overload + def __init__( + self: DictReader[str], + f: Iterable[Text], fieldnames: Optional[Sequence[str]] = ..., restkey: Optional[str] = ..., restval: Optional[str] = ..., @@ -66,21 +75,21 @@ class DictReader(Iterator[_DRMapping]): *args: Any, **kwds: Any, ) -> None: ... - def __iter__(self) -> DictReader: ... + def __iter__(self) -> DictReader[_T]: ... if sys.version_info >= (3,): - def __next__(self) -> _DRMapping: ... + def __next__(self) -> _DictReadMapping[_T, str]: ... else: - def next(self) -> _DRMapping: ... + def next(self) -> _DictReadMapping[_T, str]: ... -class DictWriter(object): - fieldnames: Sequence[str] +class DictWriter(Generic[_T]): + fieldnames: Sequence[_T] restval: Optional[Any] extrasaction: str writer: _writer def __init__( self, f: Any, - fieldnames: Iterable[str], + fieldnames: Sequence[_T], restval: Optional[Any] = ..., extrasaction: str = ..., dialect: _DialectLike = ..., @@ -91,8 +100,8 @@ class DictWriter(object): def writeheader(self) -> Any: ... else: def writeheader(self) -> None: ... - def writerow(self, rowdict: _DictRow) -> Any: ... - def writerows(self, rowdicts: Iterable[_DictRow]) -> None: ... + def writerow(self, rowdict: Mapping[_T, Any]) -> Any: ... + def writerows(self, rowdicts: Iterable[Mapping[_T, Any]]) -> None: ... class Sniffer(object): preferred: List[str]