zmc
2023-10-12 ed135d79df12a2466b52dae1a82326941211dcc9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""Common utilities for Numba operations with groupby ops"""
from __future__ import annotations
 
import functools
import inspect
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
)
 
import numpy as np
 
from pandas._typing import Scalar
from pandas.compat._optional import import_optional_dependency
 
from pandas.core.util.numba_ import (
    NumbaUtilError,
    jit_user_function,
)
 
 
def validate_udf(func: Callable) -> None:
    """
    Validate user defined function for ops when using Numba with groupby ops.
 
    The first signature arguments should include:
 
    def f(values, index, ...):
        ...
 
    Parameters
    ----------
    func : function, default False
        user defined function
 
    Returns
    -------
    None
 
    Raises
    ------
    NumbaUtilError
    """
    if not callable(func):
        raise NotImplementedError(
            "Numba engine can only be used with a single function."
        )
    udf_signature = list(inspect.signature(func).parameters.keys())
    expected_args = ["values", "index"]
    min_number_args = len(expected_args)
    if (
        len(udf_signature) < min_number_args
        or udf_signature[:min_number_args] != expected_args
    ):
        raise NumbaUtilError(
            f"The first {min_number_args} arguments to {func.__name__} must be "
            f"{expected_args}"
        )
 
 
@functools.lru_cache(maxsize=None)
def generate_numba_agg_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Generate a numba jitted agg function specified by values from engine_kwargs.
 
    1. jit the user's function
    2. Return a groupby agg function with the jitted function inline
 
    Configurations specified in engine_kwargs apply to both the user's
    function _AND_ the groupby evaluation loop.
 
    Parameters
    ----------
    func : function
        function to be applied to each group and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
 
    Returns
    -------
    Numba function
    """
    numba_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")
 
    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def group_agg(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:
        assert len(begin) == len(end)
        num_groups = len(begin)
 
        result = np.empty((num_groups, num_columns))
        for i in numba.prange(num_groups):
            group_index = index[begin[i] : end[i]]
            for j in numba.prange(num_columns):
                group = values[begin[i] : end[i], j]
                result[i, j] = numba_func(group, group_index, *args)
        return result
 
    return group_agg
 
 
@functools.lru_cache(maxsize=None)
def generate_numba_transform_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
    """
    Generate a numba jitted transform function specified by values from engine_kwargs.
 
    1. jit the user's function
    2. Return a groupby transform function with the jitted function inline
 
    Configurations specified in engine_kwargs apply to both the user's
    function _AND_ the groupby evaluation loop.
 
    Parameters
    ----------
    func : function
        function to be applied to each window and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
 
    Returns
    -------
    Numba function
    """
    numba_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")
 
    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def group_transform(
        values: np.ndarray,
        index: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        num_columns: int,
        *args: Any,
    ) -> np.ndarray:
        assert len(begin) == len(end)
        num_groups = len(begin)
 
        result = np.empty((len(values), num_columns))
        for i in numba.prange(num_groups):
            group_index = index[begin[i] : end[i]]
            for j in numba.prange(num_columns):
                group = values[begin[i] : end[i], j]
                result[begin[i] : end[i], j] = numba_func(group, group_index, *args)
        return result
 
    return group_transform