Coverage for sqlalchemy_csv_writer/writer.py: 100%

62 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-19 16:16 +0000

1"""Provide SQLAlchemy Csv Writer.""" 

2import csv 

3import typing 

4from pathlib import Path 

5from typing import Union 

6 

7from sqlalchemy import inspect 

8from sqlalchemy.orm import DeclarativeBase 

9 

10 

class SQLAlchemyCsvWriter:
    """Write SQL Alchemy results to a csv file."""

    def __init__(
        self,
        csvfile: Union[Path, str, typing.IO],
        header: Union[list[str], bool] = True,
        prefix_model_names: bool = False,
        field_formats: Union[dict[str, str], None] = None,
        dialect="excel",
        **fmtparams,
    ):
        """Create a SQLAlchemyCsvWriter instance.

        The instance's methods can be used to write rows to the specified csv file.

        Args:
            csvfile: Path or file-like object to write the resulting csv data to.
                A ``str``/``Path`` is opened (UTF-8) and later closed by this
                writer; missing parent directories are created. A file-like
                object stays owned by the caller: it is flushed, never closed.
            header: True to automatically generate header, False to disable
                header or list of strings for custom header
            prefix_model_names: Whether to prefix the model names in the header
            field_formats: Dictionary mapping column names (as ``str(key)``) to
                %-style format strings applied to that column's values
            dialect: csv dialect to use
            **fmtparams: extra formatting parameters to pass to csv.writer instance
        """
        # Initialise first so close()/__del__ stay safe even if opening fails below.
        self.csvfile = None
        self._owns_file = False

        if isinstance(csvfile, str):
            csvfile = Path(csvfile)
        if isinstance(csvfile, Path):
            csvfile.parent.mkdir(exist_ok=True, parents=True)
            # newline="" is required by the csv module: it writes its own line
            # terminator ("\r\n" for excel) and must not have it, or newlines
            # embedded in quoted fields, translated by the text layer.
            csvfile = open(csvfile, "w", encoding="utf-8", newline="")  # noqa: SIM115
            self._owns_file = True

        self.csvfile = csvfile
        self.writer = csv.writer(csvfile, dialect=dialect, **fmtparams)
        self.header = header
        self.prefix_model_names = prefix_model_names
        self.field_formats = field_formats if field_formats else {}
        self.header_row_written = False

    def close(self):
        """Release the output file.

        Closes the file if this writer opened it (``str``/``Path`` input);
        streams supplied by the caller are only flushed, never closed.
        Safe to call more than once.
        """
        # getattr: may run from __del__ on a partially initialised instance.
        csvfile = getattr(self, "csvfile", None)
        if csvfile is None:
            return
        if getattr(self, "_owns_file", False):
            csvfile.close()
        elif hasattr(csvfile, "flush") and not getattr(csvfile, "closed", False):
            csvfile.flush()

    def __del__(self):
        """Close open resources."""
        self.close()

    def __enter__(self):
        """SQLAlchemyCsvWriter may be used as a context manager."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit context manager; exceptions from the body are not suppressed."""
        self.close()

    async def write_rows_stream(self, results: typing.AsyncIterable):
        """Write query results to csv.

        Write query results retrieved with SQLAlchemy's .stream or .stream_scalars.

        Args:
            results: query results retrieved with SQLAlchemy's .stream or .stream_scalars
        """
        async for result in results:
            self._process_result(result)

    def write_rows(self, results: typing.Iterable):
        """Write query results to csv.

        Write query results retrieved with SQLAlchemy's .execute or .scalars to csv.

        Args:
            results: query results retrieved with SQLAlchemy's .execute or .scalars
        """
        for result in results:
            self._process_result(result)

    def _process_result(self, result):
        """Write one result as a csv row, preceded by the header on the first call."""
        columns = self._extract_columns(result)

        if self.header and not self.header_row_written:
            self._write_header(columns)

        values = []
        for key, value in columns:
            name = str(key)
            if name in self.field_formats:
                value = self.field_formats[name] % value
            values.append(value)
        self.writer.writerow(values)

    def _write_header(self, columns):
        """Write the header row: the column keys, or the user-supplied list.

        Raises:
            ValueError: if a custom header's length differs from the column count.
        """
        if self.header is True:
            self.writer.writerow([key for key, _ in columns])
        elif len(self.header) == len(columns):
            self.writer.writerow(self.header)
        else:
            raise ValueError("Length of header and content does not match.")
        self.header_row_written = True

    def _extract_columns(self, result):
        """Flatten one result into a list of ``(key, value)`` pairs.

        Rows (objects with ``_mapping``) contribute one pair per column, with
        ORM model instances expanded into their attributes; a bare ORM model
        instance is expanded directly.
        """
        if hasattr(result, "_mapping"):  # is not a scalar
            columns = []
            for key, value in result._mapping.items():
                if isinstance(value, DeclarativeBase):  # is an orm model
                    columns.extend(self._extract_model(value))
                else:  # is a column
                    columns.append((key, value))
            return columns
        if isinstance(result, DeclarativeBase):  # is a scalar orm model
            return self._extract_model(result)
        # NOTE(review): other scalars (e.g. plain values from
        # .scalars(select(Model.col))) yield no columns and thus empty rows —
        # confirm whether that is intended.
        return []

    def _extract_model(self, obj):
        """Return ``(key, value)`` pairs for every attribute in ``inspect(obj).attrs``.

        With ``prefix_model_names`` the key is the class-level attribute
        (``getattr(Model, name)``) instead of the bare name; its ``str()``
        presumably renders as ``"Model.name"``, which is what both the header
        and the ``field_formats`` lookup see.
        NOTE(review): ``attrs`` presumably also covers relationships, whose
        ``.value`` may trigger a lazy load — confirm this is desired.
        """
        insp = inspect(obj)
        if self.prefix_model_names:
            return [(getattr(insp.class_, attr.key), attr.value) for attr in insp.attrs]
        return [(attr.key, attr.value) for attr in insp.attrs]