Coverage for sqlalchemy_csv_writer/writer.py: 100%
62 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-19 16:16 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-19 16:16 +0000
1"""Provide SQLAlchemy Csv Writer."""
2import csv
3import typing
4from pathlib import Path
5from typing import Union
7from sqlalchemy import inspect
8from sqlalchemy.orm import DeclarativeBase
class SQLAlchemyCsvWriter:
    """Write SQLAlchemy query results to a csv file.

    The writer accepts either a path (which it opens for writing) or an
    already open file-like object. It can be used as a context manager; the
    underlying file is closed on exit and when the instance is finalized.
    """

    def __init__(
        self,
        csvfile: Union[Path, str, typing.IO],
        header: Union[list[str], bool] = True,
        prefix_model_names: bool = False,
        field_formats: Union[dict[str, str], None] = None,
        dialect="excel",
        **fmtparams,
    ):
        """Create a SQLAlchemyCsvWriter instance.

        The instance's methods can be used to write rows to the specified
        csv file.

        Args:
            csvfile: Path or file-like object to write the resulting csv
                data to. Missing parent directories of a path are created.
            header: True to automatically generate the header, False to
                disable the header, or a list of strings for a custom header.
            prefix_model_names: Whether to prefix the model names in the
                header.
            field_formats: Dictionary mapping column names to column formats
                (using % style format syntax).
            dialect: csv dialect to use.
            **fmtparams: Extra formatting parameters passed to the
                ``csv.writer`` instance.
        """
        if isinstance(csvfile, str):
            csvfile = Path(csvfile)
        if isinstance(csvfile, Path):
            csvfile.parent.mkdir(exist_ok=True, parents=True)
            # newline="" is required by the csv module: the writer emits its
            # own line terminators, and newline translation would otherwise
            # turn "\r\n" into "\r\r\n" on platforms that translate "\n".
            csvfile = open(  # noqa: SIM115
                csvfile, "w", encoding="utf-8", newline=""
            )

        # Assign the file first so close()/__del__ can release it even if
        # creating the csv writer below raises.
        self.csvfile = csvfile
        self.writer = csv.writer(csvfile, dialect=dialect, **fmtparams)
        self.header = header
        self.prefix_model_names = prefix_model_names
        self.field_formats = field_formats or {}
        self.header_row_written = False

    def close(self):
        """Close the underlying file object, if there is one to close.

        Safe to call more than once, and safe to call on an instance whose
        ``__init__`` did not get as far as assigning ``csvfile``.
        """
        csvfile = getattr(self, "csvfile", None)
        if hasattr(csvfile, "close"):
            csvfile.close()

    def __del__(self):
        """Close open resources."""
        self.close()

    def __enter__(self):
        """SQLAlchemyCsvWriter may be used as a context manager."""
        return self

    def __exit__(self, type, value, traceback):  # noqa: A002
        """Exit context manager and close the underlying file.

        Returns None, so an exception raised inside the ``with`` body is
        propagated to the caller.
        """
        self.close()

    async def write_rows_stream(self, results: typing.AsyncIterable):
        """Write query results to csv.

        Write query results retrieved with SQLAlchemy's .stream or
        .stream_scalars.

        Args:
            results: query results retrieved with SQLAlchemy's .stream or
                .stream_scalars (any asynchronous iterable of results)
        """
        async for result in results:
            self._process_result(result)

    def write_rows(self, results: typing.Iterable):
        """Write query results to csv.

        Write query results retrieved with SQLAlchemy's .execute or .scalars.

        Args:
            results: query results retrieved with SQLAlchemy's .execute or
                .scalars (any iterable of results)
        """
        for result in results:
            self._process_result(result)

    def _process_result(self, result):
        """Write one result as a csv row, preceded by the header if due.

        The header is written at most once per writer instance, before the
        first data row.

        Raises:
            ValueError: If a custom header was given and its length differs
                from the number of columns extracted from ``result``.
        """
        columns = self._extract_columns(result)

        # write header
        if self.header and not self.header_row_written:
            if self.header is True:
                header_row = [name for name, _ in columns]
            elif len(self.header) == len(columns):
                header_row = self.header
            else:
                raise ValueError("Length of header and content does not match.")
            self.writer.writerow(header_row)
            self.header_row_written = True

        # write data
        values = []
        for key, value in columns:
            # Keys may be non-string objects (see _extract_model), so formats
            # are looked up by the key's string form.
            name = str(key)
            if name in self.field_formats:
                value = self.field_formats[name] % value
            values.append(value)
        self.writer.writerow(values)

    def _extract_columns(self, result):
        """Flatten one result into a list of ``(key, value)`` pairs.

        Results exposing ``_mapping`` are treated as rows: ORM model elements
        are expanded into their attributes, other elements are kept as
        ``(key, value)``. A bare ORM model instance is expanded directly.
        Any other result yields an empty list.
        """
        columns = []
        if hasattr(result, "_mapping"):  # is not a scalar
            for element_key, element_value in result._mapping.items():
                if isinstance(element_value, DeclarativeBase):  # is an orm model
                    columns.extend(self._extract_model(element_value))
                else:  # is a column
                    columns.append((element_key, element_value))
        elif isinstance(result, DeclarativeBase):  # is a scalar orm model
            columns.extend(self._extract_model(result))

        return columns

    def _extract_model(self, obj):
        """Return ``(key, value)`` pairs for every attribute of an ORM model.

        With ``prefix_model_names`` enabled the key is the attribute object
        looked up on the model class instead of the plain attribute name;
        callers stringify it for the header and for format lookups.
        """
        insp = inspect(obj)
        if self.prefix_model_names:
            return [
                (getattr(insp.class_, attr.key), attr.value) for attr in insp.attrs
            ]
        return [(attr.key, attr.value) for attr in insp.attrs]