Arrayutils

Utility functions for processing arrays.

PeriodicProcessSample

Bases: NamedTuple

A container for holding the output from process.PeriodicProcess().

Attributes:

Name Type Description
value ArrayLike

The sampled quantity.
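
As a brief illustration of how a one-field NamedTuple container behaves, the sketch below defines a hypothetical stand-in class, `SampleSketch`, with the single documented `value` attribute. It is not the library's class and assumes nothing about defaults or additional fields.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax.typing import ArrayLike


class SampleSketch(NamedTuple):
    """Hypothetical stand-in mirroring the single documented field."""

    value: ArrayLike


sample = SampleSketch(value=jnp.array([0.1, 0.2, 0.3]))

# The field can be read by name...
by_name = sample.value

# ...or by tuple unpacking, since a NamedTuple is still a tuple.
(by_position,) = sample
```

Both access styles return the same underlying array object.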

pad_edges_to_match

pad_edges_to_match(
    x: ArrayLike,
    y: ArrayLike,
    axis: int = 0,
    pad_direction: str = "end",
    fix_y: bool = False,
) -> tuple[ArrayLike, ArrayLike]

Pad the shorter array at the start or end using the edge values to match the length of the longer array.

Parameters:

Name Type Description Default
x ArrayLike

First array.

required
y ArrayLike

Second array.

required
axis int

Axis along which to add padding, by default 0.

0
pad_direction str

Direction to pad the shorter array, either "start" or "end", by default "end".

'end'
fix_y bool

If True, raise an error when y is shorter than x, by default False.

False

Returns:

Type Description
tuple[ArrayLike, ArrayLike]

Tuple of the two arrays with the same length.
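
As a rough illustration of the edge padding described above, the following sketch calls `jnp.pad` with `mode="edge"` directly rather than importing pyrenew. The variable names are illustrative only, and the pad widths are written out by hand to mirror `pad_direction="end"` and `pad_direction="start"`.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5])
y = jnp.array([10, 20, 30])

# y is shorter than x by this many entries along axis 0.
pad_size = x.shape[0] - y.shape[0]

# Padding at the end repeats the last value of y.
y_end = jnp.pad(y, [(0, pad_size)], mode="edge")

# Padding at the start repeats the first value of y.
y_start = jnp.pad(y, [(pad_size, 0)], mode="edge")

print(y_end.tolist())    # [10, 20, 30, 30, 30]
print(y_start.tolist())  # [10, 10, 10, 20, 30]
```

According to the parameter description, passing `fix_y=True` in this situation (y shorter than x) would raise an error instead of padding y.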

Source code in pyrenew/arrayutils.py
def pad_edges_to_match(
    x: ArrayLike,
    y: ArrayLike,
    axis: int = 0,
    pad_direction: str = "end",
    fix_y: bool = False,
) -> tuple[ArrayLike, ArrayLike]:
    """
    Pad the shorter array at the start or end using the
    edge values to match the length of the longer array.

    Parameters
    ----------
    x
        First array.
    y
        Second array.
    axis
        Axis along which to add padding, by default 0
    pad_direction
        Direction to pad the shorter array, either "start" or "end", by default "end".
    fix_y
        If True, raise an error when `y` is shorter than `x`, by default False.

    Returns
    -------
    tuple[ArrayLike, ArrayLike]
        Tuple of the two arrays with the same length.
    """
    x = jnp.atleast_1d(x)
    y = jnp.atleast_1d(y)
    x_len = x.shape[axis]
    y_len = y.shape[axis]
    pad_size = abs(x_len - y_len)
    pad_width = [(0, 0)] * x.ndim

    if pad_direction not in ["start", "end"]:
        raise ValueError(
            f"pad_direction must be either 'start' or 'end'. Got {pad_direction}."
        )

    pad_width[axis] = {"start": (pad_size, 0), "end": (0, pad_size)}.get(
        pad_direction, None
    )

    if x_len > y_len:
        if fix_y:
            raise ValueError(
                f"Cannot fix y when x is longer than y. x_len: {x_len}, y_len: {y_len}."
            )
        y = jnp.pad(y, pad_width, mode="edge")

    elif y_len > x_len:
        x = jnp.pad(x, pad_width, mode="edge")

    return x, y

repeat_until_n

repeat_until_n(
    data: ArrayLike, period_size: int, n_timepoints: int, offset: int = 0
)

Repeat each entry in data a given number of times (period_size) until an array of length n_timepoints has been produced.

Notes

When offset is nonzero, the function shifts the data after the repeat operation. With an offset of 2, the data are repeated out to n_timepoints + 2 entries and the first 2 entries are then dropped, returning an array of size n_timepoints. This is a way to start part-way into a period. For example, the following code repeats each array element four times, out to 10 + 2 = 12 entries, and then drops the first 2:

    data = jnp.array([1, 2, 3])
    repeat_until_n(data, 4, 10, 2)
    # Array([1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=int32)

Parameters:

Name Type Description Default
data ArrayLike

Data to broadcast.

required
period_size int

Size of the period for the repeat broadcast.

required
n_timepoints int

Duration of the sequence.

required
offset int

Relative point at which data starts, must be between 0 and period_size - 1. By default 0, i.e., no offset.

0

Returns:

Type Description
ArrayLike

Repeated data.
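
The repeat-then-slice behaviour can be sketched with `jnp.repeat` directly. This is a minimal illustration that assumes weekly values expanded to daily values; the names `weekly` and `daily` are hypothetical and the snippet does not import pyrenew.

```python
import jax.numpy as jnp

weekly = jnp.array([10, 20, 30])  # one value per period
period_size = 7
n_timepoints = 10
offset = 3  # start three steps into the first period

# The repeated data must cover n_timepoints + offset entries.
assert weekly.size * period_size >= n_timepoints + offset

# Repeat each entry period_size times, truncated to the needed length.
repeated = jnp.repeat(
    weekly, period_size, total_repeat_length=n_timepoints + offset
)

# Drop the first `offset` entries so the output starts part-way into a period.
daily = repeated[offset : offset + n_timepoints]

print(daily.tolist())
# [10, 10, 10, 10, 20, 20, 20, 20, 20, 20]
```

Because the first period is entered three steps in, only four of its seven entries appear in the output.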

Source code in pyrenew/arrayutils.py
def repeat_until_n(
    data: ArrayLike,
    period_size: int,
    n_timepoints: int,
    offset: int = 0,
):
    """
    Repeat each entry in `data` a given number of times (`period_size`)
    until an array of length `n_timepoints` has been produced.

    Notes
    -----
    Using the `offset` parameter, the function will offset the data after
    the repeat operation. So, if the offset is 2, the repeat operation
    will repeat the data until `n_timepoints + 2` and then offset the
    data by 2, returning an array of size `n_timepoints`. This is a way to start
    part-way into a period. For example, the following code will repeat each
    array element four times until 10 timepoints and then offset the data by 2:

    .. code-block:: python

      data = jnp.array([1, 2, 3])
      repeat_until_n(data, 4, 10, 2)
      # Array([1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=int32)

    Parameters
    ----------
    data
        Data to broadcast.
    period_size
        Size of the period for the repeat broadcast.
    n_timepoints
        Duration of the sequence.
    offset
        Relative point at which data starts, must be between 0 and
        period_size - 1. By default 0, i.e., no offset.

    Returns
    -------
    ArrayLike
        Repeated data.
    """

    # offset should be a non-negative integer
    assert isinstance(offset, int), (
        f"offset should be an integer. It is {type(offset)}."
    )

    assert 0 <= offset, f"offset should be a non-negative integer. It is {offset}."

    # Period size should be a positive integer
    assert isinstance(period_size, int), (
        f"period_size should be an integer. It is {type(period_size)}."
    )

    assert period_size > 0, (
        f"period_size should be a positive integer. It is {period_size}."
    )

    assert offset <= period_size - 1, (
        "offset should be less than or equal to period_size - 1. "
        f"It is {offset}. It should be less than or equal "
        f"to {period_size - 1}."
    )

    if (data.size * period_size) < (n_timepoints + offset):
        raise ValueError(
            "The data is too short to broadcast to the given number "
            f"of timepoints + offset ({n_timepoints + offset}). The "
            "repeated data would have a size of data.size * "
            f"period_size = {data.size} * {period_size} = "
            f"{data.size * period_size}."
        )

    return jnp.repeat(
        a=data,
        repeats=period_size,
        total_repeat_length=n_timepoints + offset,
    )[offset : (offset + n_timepoints)]

tile_until_n

tile_until_n(data: ArrayLike, n_timepoints: int, offset: int = 0) -> ArrayLike

Tile the data until it reaches n_timepoints.

Parameters:

Name Type Description Default
data ArrayLike

Data to broadcast.

required
n_timepoints int

Duration of the sequence.

required
offset int

Relative point at which data starts, must be a non-negative integer. Default is zero, i.e., no offset.

0

Notes

Using the offset parameter, the function will start the broadcast from the offset-th element of the data. If the data is shorter than n_timepoints, the function will tile the data until it reaches n_timepoints.

Returns:

Type Description
ArrayLike

Tiled data.
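
The tile-then-slice idea can be sketched with `jnp.tile` directly. This is a minimal illustration that assumes a seven-entry weekly pattern; the names `pattern` and `n_copies` are hypothetical. In this sketch the number of copies is chosen to cover `offset + n_timepoints`, so that the slice always has `n_timepoints` entries.

```python
import jax.numpy as jnp

pattern = jnp.arange(7)  # e.g. one entry per day of the week
n_timepoints = 10
offset = 5  # start from the element at index 5

# Enough whole copies of the pattern to cover offset + n_timepoints entries.
n_copies = (n_timepoints + offset) // pattern.size + 1

tiled = jnp.tile(pattern, n_copies)
result = tiled[offset : offset + n_timepoints]

print(result.tolist())
# [5, 6, 0, 1, 2, 3, 4, 5, 6, 0]
```

The output starts at the element with index `offset` and wraps around to the beginning of the pattern as often as needed.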

Source code in pyrenew/arrayutils.py
def tile_until_n(
    data: ArrayLike,
    n_timepoints: int,
    offset: int = 0,
) -> ArrayLike:
    """
    Tile the data until it reaches `n_timepoints`.

    Parameters
    ----------
    data
        Data to broadcast.
    n_timepoints
        Duration of the sequence.
    offset
        Relative point at which data starts, must be a non-negative integer.
        Default is zero, i.e., no offset.

    Notes
    -----
    Using the `offset` parameter, the function will start the broadcast
    from the `offset`-th element of the data. If the data is shorter than
    `n_timepoints`, the function will tile the data until it
    reaches `n_timepoints`.

    Returns
    -------
    ArrayLike
        Tiled data.
    """

    # offset should be a non-negative integer
    assert isinstance(offset, int), (
        f"offset should be an integer. It is {type(offset)}."
    )

    assert 0 <= offset, f"offset should be a non-negative integer. It is {offset}."

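    # Tile enough copies to cover offset + n_timepoints, so that the
    # slice below always returns n_timepoints entries.
    return jnp.tile(data, ((n_timepoints + offset) // data.size) + 1)[
        offset : (offset + n_timepoints)
    ]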