zmc
2023-08-08 e792e9a60d958b93aef96050644f369feb25d61b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#-----------------------------------------------------------------------------
# Copyright (c) 2005-2023, PyInstaller Development Team.
#
# Distributed under the terms of the GNU General Public License (version 2
# or later) with exception for distributing the bootloader.
#
# The full license is in the file COPYING.txt, distributed with this software.
#
# SPDX-License-Identifier: (GPL-2.0-or-later WITH Bootloader-exception)
#-----------------------------------------------------------------------------
 
import re
 
from PyInstaller import isolated
from PyInstaller.lib.modulegraph.modulegraph import SourceModule
from PyInstaller.lib.modulegraph.util import guess_encoding
from PyInstaller.utils.hooks import is_module_satisfies, logger
 
# 'sqlalchemy.testing' causes bundling a lot of unnecessary modules.
excludedimports = ['sqlalchemy.testing']
 
# Include most common database bindings some database bindings are detected and include some are not. We should
# explicitly include database backends.
hiddenimports = ['pysqlite2', 'MySQLdb', 'psycopg2', 'sqlalchemy.ext.baked']
 
if is_module_satisfies('sqlalchemy >= 1.4'):
    hiddenimports.append("sqlalchemy.sql.default_comparator")
 
 
@isolated.decorate
def _get_dialect_modules(module_name):
    import importlib
    module = importlib.import_module(module_name)
    return [f"{module_name}.{submodule_name}" for submodule_name in module.__all__]
 
 
# In SQLAlchemy >= 0.6, the "sqlalchemy.dialects" package provides dialects.
# In SQLAlchemy <= 0.5, the "sqlalchemy.databases" package provides dialects.
if is_module_satisfies('sqlalchemy >= 0.6'):
    hiddenimports += _get_dialect_modules("sqlalchemy.dialects")
else:
    hiddenimports += _get_dialect_modules("sqlalchemy.databases")
 
 
def hook(hook_api):
    """
    SQLAlchemy 0.9 introduced the decorator 'util.dependencies'.  This decorator does imports. E.g.:
 
            @util.dependencies("sqlalchemy.sql.schema")
 
    This hook scans for included SQLAlchemy modules and then scans those modules for any util.dependencies and marks
    those modules as hidden imports.
    """
 
    if not is_module_satisfies('sqlalchemy >= 0.9'):
        return
 
    # this parser is very simplistic but seems to catch all cases as of V1.1
    depend_regex = re.compile(r'@util.dependencies\([\'"](.*?)[\'"]\)')
 
    hidden_imports_set = set()
    known_imports = set()
    for node in hook_api.module_graph.iter_graph(start=hook_api.module):
        if isinstance(node, SourceModule) and node.identifier.startswith('sqlalchemy.'):
            known_imports.add(node.identifier)
            # Determine the encoding of the source file.
            with open(node.filename, 'rb') as f:
                encoding = guess_encoding(f)
            # Use that to open the file.
            with open(node.filename, 'r', encoding=encoding) as f:
                for match in depend_regex.findall(f.read()):
                    hidden_imports_set.add(match)
 
    hidden_imports_set -= known_imports
    if len(hidden_imports_set):
        logger.info("  Found %d sqlalchemy hidden imports", len(hidden_imports_set))
        hook_api.add_imports(*list(hidden_imports_set))