Skip to content

API reference

gather_draws

gather_draws(
    data: InferenceData,
    group: str = "posterior",
    combined: bool = True,
    var_names: Iterable[str] | None = None,
    filter_vars: str | None = None,
    num_samples: int | None = None,
    rng: bool | int | Generator | None = None,
    value_name: str | None = None,
    variable_name: str | None = None,
) -> DataFrame

Convert an ArviZ InferenceData object to a polars DataFrame of tidy (gathered) draws, using the syntax of arviz.extract.

Parameters:

Name Type Description Default
data InferenceData

Data to convert.

required
group str

group parameter passed to arviz.extract.

'posterior'
combined bool

combined parameter passed to arviz.extract.

True
var_names Iterable[str] | None

var_names parameter passed to arviz.extract.

None
filter_vars str | None

filter_vars parameter passed to arviz.extract.

None
num_samples int | None

num_samples parameter passed to arviz.extract.

None
rng bool | int | Generator | None

rng parameter passed to arviz.extract.

None
value_name str | None

Name for the value column in the output DataFrame. If None (default), use "value".

None
variable_name str | None

Name for the variable column in the output DataFrame. If None (default), use "variable".

None

Returns:

Type Description
DataFrame

The DataFrame of tidy (gathered) draws, including standard columns to identify a unique sample (typically "chain" and "draw"), a column of variable names, a column of associated variable values, plus (as needed) columns that index array-valued variables.

Source code in polarbayes/gather.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
def gather_draws(
    data: az.InferenceData,
    group: str = "posterior",
    combined: bool = True,
    var_names: Iterable[str] | None = None,
    filter_vars: str | None = None,
    num_samples: int | None = None,
    rng: bool | int | np.random.Generator | None = None,
    value_name: str | None = None,
    variable_name: str | None = None,
) -> pl.DataFrame:
    """
    Convert an ArviZ InferenceData object to a polars
    DataFrame of tidy (gathered) draws, using the syntax of
    [`arviz.extract`][].

    Parameters
    ----------
    data
        Data to convert.

    group
        `group` parameter passed to [`arviz.extract`][].

    combined
        `combined` parameter passed to [`arviz.extract`][].

    var_names
        `var_names` parameter passed to [`arviz.extract`][].

    filter_vars
        `filter_vars` parameter passed to [`arviz.extract`][].

    num_samples
        `num_samples` parameter passed to [`arviz.extract`][].

    rng
        `rng` parameter passed to [`arviz.extract`][].

    value_name
        Name for the value column in the output DataFrame. If `None`
        (default), use `"value"`.

    variable_name
        Name for the variable column in the output DataFrame. If `None`
        (default), use `"variable"`.

    Returns
    -------
    pl.DataFrame
        The DataFrame of tidy (gathered) draws, including
        standard columns to identify a unique sample
        (typically `"chain"` and `"draw"`), a column of variable
        names, a column of associated variable values,
        plus (as needed) columns that index array-valued variables.
    """
    variable_col = VARIABLE_NAME if variable_name is None else variable_name
    value_col = VALUE_NAME if value_name is None else value_name

    # Extract every requested variable in one call so that all of them
    # share exactly the same draws (relevant when `num_samples` / `rng`
    # subsample or shuffle).
    extracted = az.extract(
        data,
        group=group,
        combined=combined,
        var_names=var_names,
        filter_vars=filter_vars,
        num_samples=num_samples,
        keep_dataset=True,
        rng=rng,
    )

    # Gather each variable separately: array-valued variables may carry
    # their own index columns, so frames are stacked diagonally below.
    gathered_frames = []
    for name in extracted.data_vars.keys():
        spread_df, index_cols = spread_draws_and_get_index_cols(
            extracted,
            group=group,
            var_names=name,
            combined=False,
            filter_vars=None,
            num_samples=None,
            rng=False,
            enforce_drop_chain_draw=combined,
        )
        gathered_frames.append(
            gather_variables(
                spread_df,
                index_cols,
                variable_name=variable_col,
                value_name=value_col,
            )
        )

    stacked = pl.concat(gathered_frames, how="diagonal_relaxed")

    # Index columns are ordered again here (not only inside each
    # gather_variables() call) because a later variable may introduce
    # index columns absent from earlier ones, which the diagonal concat
    # would otherwise append out of order.
    leading_cols = order_index_column_names(
        [col for col in stacked.columns if col not in (variable_col, value_col)]
    )
    return stacked.select(leading_cols + [variable_col, value_col])

gather_variables

gather_variables(
    data: LazyFrame | DataFrame,
    index: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None = None,
    value_name: str | None = None,
    variable_name: str | None = None,
)

Gather variable columns into key-value pairs. Light wrapper of pl.DataFrame.unpivot designed for use with spread_draws output.

Parameters:

Name Type Description Default
data LazyFrame | DataFrame

Input DataFrame to (un)pivot from wide to long format.

required
index ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None

Polars expression selecting mandatory or optional columns to index the gather. Passed as the index argument to pl.DataFrame.unpivot. If None (default), use the columns ["chain", "draw"] if they are present. Those are the MCMC index columns created when spread_draws is called on a standard az.InferenceData object.

None
value_name str | None

Name for the value column in the output DataFrame. If None (default), use "value".

None
variable_name str | None

Name for the variable column in the output DataFrame. If None (default), use "variable".

None

Returns:

Type Description
LazyFrame | DataFrame

Unpivoted (pivoted longer) tidy data frame with index columns plus variable name and value columns.

Raises:

Type Description
ValueError

If value_name or variable_name conflicts with requested index columns.

Source code in polarbayes/gather.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def gather_variables(
    data: pl.LazyFrame | pl.DataFrame,
    index: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None = None,
    value_name: str | None = None,
    variable_name: str | None = None,
) -> pl.LazyFrame | pl.DataFrame:
    """
    Gather variable columns into key-value pairs.
    Light wrapper of [`pl.DataFrame.unpivot`][polars.DataFrame.unpivot]
    designed for use with
    [`spread_draws`][polarbayes.spread.spread_draws] output.

    Parameters
    ----------
    data
        Input DataFrame to (un)pivot from wide to long format.
    index
        Polars expression selecting mandatory or optional columns to
        index the gather. Passed as the `index` argument to
        [`pl.DataFrame.unpivot`][polars.DataFrame.unpivot].
        If `None` (default), use the columns
        `["chain", "draw"]` if they are present. Those are the MCMC
        index columns created when
        [`spread_draws`][polarbayes.spread.spread_draws] is called on
        a standard [`az.InferenceData`][arviz.InferenceData] object.

    value_name
        Name for the value column in the output DataFrame.
        If `None` (default), use `"value"`.

    variable_name
        Name for the variable column in the output DataFrame.
        If `None` (default), use `"variable"`.

    Returns
    -------
    pl.LazyFrame | pl.DataFrame
        Unpivoted (pivoted longer) tidy data frame with index columns plus
        variable name and value columns. Output columns are ordered as
        the (ordered) index columns, then the variable column, then the
        value column.

    Raises
    ------
    ValueError
        If `value_name` or `variable_name` conflicts with requested
        index columns.
    """
    if variable_name is None:
        variable_name = VARIABLE_NAME
    if value_name is None:
        value_name = VALUE_NAME
    if index is None:
        # require_all=False: chain/draw are used only if present.
        index = cs.by_name(CHAIN_NAME, DRAW_NAME, require_all=False)

    # Resolve the index selection to concrete column names (in canonical
    # order) so the output columns can be ordered explicitly below.
    index_names = order_index_column_names(
        data.select(index).collect_schema().names()
    )

    # Validate up front: gives a more informative error message than
    # `unpivot()` raises on its own. An explicit loop (instead of a list
    # comprehension built only for its side effects) makes it clear these
    # calls are validation steps.
    for arg_name, col_name in (
        ("value_name", value_name),
        ("variable_name", variable_name),
    ):
        _assert_not_in_index_columns(arg_name, col_name, index_names)

    return data.unpivot(
        index=index, variable_name=variable_name, value_name=value_name
    ).select(index_names + [variable_name, value_name])  # order output columns

spread_draws

spread_draws(
    data: InferenceData,
    group: str = "posterior",
    combined: bool = True,
    var_names: Iterable[str] | None = None,
    filter_vars: str | None = None,
    num_samples: int | None = None,
    rng: bool | int | Generator | None = None,
) -> DataFrame

Convert an ArviZ InferenceData object to a polars DataFrame of tidy (spread) draws, using the syntax of arviz.extract.

Parameters:

Name Type Description Default
data InferenceData

Data to convert.

required
group str

group parameter passed to arviz.extract.

'posterior'
combined bool

combined parameter passed to arviz.extract.

True
var_names Iterable[str] | None

var_names parameter passed to arviz.extract.

None
filter_vars str | None

filter_vars parameter passed to arviz.extract.

None
num_samples int | None

num_samples parameter passed to arviz.extract.

None
rng bool | int | Generator | None

rng parameter passed to arviz.extract.

None

Returns:

Type Description
DataFrame

The DataFrame of tidy draws. Consists of columns named for variables and index columns. Columns named for variables contain the sampled values of those variables. Index columns include standard columns to identify a unique sample (typically "chain" and "draw") plus (as needed) columns that index array-valued variables.

Source code in polarbayes/spread.py
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
def spread_draws(
    data: az.InferenceData,
    group: str = "posterior",
    combined: bool = True,
    var_names: Iterable[str] | None = None,
    filter_vars: str | None = None,
    num_samples: int | None = None,
    rng: bool | int | np.random.Generator | None = None,
) -> pl.DataFrame:
    """
    Convert an ArviZ InferenceData object to a polars
    DataFrame of tidy (spread) draws, using the syntax of
    [`arviz.extract`][].

    Parameters
    ----------
    data
        Data to convert.

    group
        `group` parameter passed to [`arviz.extract`][].

    combined
        `combined` parameter passed to [`arviz.extract`][].

    var_names
        `var_names` parameter passed to [`arviz.extract`][].

    filter_vars
        `filter_vars` parameter passed to [`arviz.extract`][].

    num_samples
        `num_samples` parameter passed to [`arviz.extract`][].

    rng
        `rng` parameter passed to [`arviz.extract`][].

    Returns
    -------
    pl.DataFrame
        The DataFrame of tidy draws. Consists of columns named for
        variables and index columns. Columns named for variables
        contain the sampled values of those variables. Index columns
        include standard columns to identify a unique
        sample (typically `"chain"` and `"draw"`) plus (as needed)
        columns that index array-valued variables.
    """
    # Thin public wrapper: discard the index-column names that the
    # helper also returns.
    result, _ = spread_draws_and_get_index_cols(
        data,
        group=group,
        combined=combined,
        var_names=var_names,
        filter_vars=filter_vars,
        num_samples=num_samples,
        rng=rng,
    )
    return result

spread_draws_and_get_index_cols

spread_draws_and_get_index_cols(
    data: InferenceData,
    group: str = "posterior",
    combined: bool = True,
    var_names: Iterable[str] | None = None,
    filter_vars: str | None = None,
    num_samples: int | None = None,
    rng: bool | int | Generator | None = None,
    enforce_drop_chain_draw: bool = False,
) -> tuple[DataFrame, tuple]

Convert an ArviZ InferenceData object to a polars DataFrame of tidy (spread) draws, using the syntax of arviz.extract. Return that DataFrame alongside a tuple giving the names of the DataFrame's index columns.

Parameters:

Name Type Description Default
data InferenceData

Data to convert.

required
group str

group parameter passed to arviz.extract.

'posterior'
combined bool

combined parameter passed to arviz.extract.

True
var_names Iterable[str] | None

var_names parameter passed to arviz.extract.

None
filter_vars str | None

filter_vars parameter passed to arviz.extract.

None
num_samples int | None

num_samples parameter passed to arviz.extract.

None
rng bool | int | Generator | None

rng parameter passed to arviz.extract.

None

Returns:

Type Description
tuple[DataFrame, tuple]

Two-entry tuple whose first entry is the DataFrame, and whose second entry is a tuple giving the names of that DataFrame's index columns. The DataFrame consists of columns named for variables and index columns. Columns named for variables contain the sampled values of those variables. Index columns include standard columns to identify a unique sample (typically "chain" and "draw") plus (as needed) columns that index array-valued variables.

Source code in polarbayes/spread.py
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def spread_draws_and_get_index_cols(
    data: az.InferenceData,
    group: str = "posterior",
    combined: bool = True,
    var_names: Iterable[str] | None = None,
    filter_vars: str | None = None,
    num_samples: int | None = None,
    rng: bool | int | np.random.Generator | None = None,
    enforce_drop_chain_draw: bool = False,
) -> tuple[pl.DataFrame, tuple]:
    """
    Convert an ArviZ InferenceData object to a polars
    DataFrame of tidy (spread) draws, using the syntax of
    arviz.extract. Return that DataFrame alongside the names
    of the DataFrame's index columns.

    Parameters
    ----------
    data
        Data to convert.

    group
        `group` parameter passed to [`arviz.extract`][].

    combined
        `combined` parameter passed to [`arviz.extract`][].

    var_names
        `var_names` parameter passed to [`arviz.extract`][].

    filter_vars
        `filter_vars` parameter passed to [`arviz.extract`][].

    num_samples
        `num_samples` parameter passed to [`arviz.extract`][].

    rng
        `rng` parameter passed to [`arviz.extract`][].

    enforce_drop_chain_draw
        If `True`, drop the chain and draw columns (`CHAIN_NAME`,
        `DRAW_NAME`) before conversion to polars. Intended for when
        `data` is already-combined output of [`arviz.extract`][] and
        is processed here with `combined=False`. Default `False`.

    Returns
    -------
    tuple[pl.DataFrame, tuple]
        Two-entry tuple whose first entry is the DataFrame, and whose
        second entry gives the names of that DataFrame's index columns
        (as ordered by `order_index_column_names`). The DataFrame
        consists of index columns followed by columns named for
        variables. Columns named for variables contain the sampled
        values of those variables. Index columns include standard
        columns to identify a unique sample (typically `"chain"` and
        `"draw"`) plus (as needed) columns that index array-valued
        variables.
    """
    pandas_df = spread_draws_to_pandas_(
        data,
        group=group,
        combined=combined,
        var_names=var_names,
        filter_vars=filter_vars,
        num_samples=num_samples,
        rng=rng,
    )
    if enforce_drop_chain_draw:
        # This is handled automatically when `combined=True` by
        # spread_draws_to_pandas_, but not when `combined=False` and the
        # `data` input is an already-combined output of
        # [`arviz.extract`][].
        pandas_df = pandas_df.drop([CHAIN_NAME, DRAW_NAME], axis=1)

    # Capture the index level names before reset_index() turns those
    # levels into ordinary columns of the polars frame.
    index_cols = pandas_df.index.names
    df = pl.DataFrame(pandas_df.reset_index())
    index_cols_ordered = order_index_column_names(index_cols)

    # Put index columns first (all must exist), then the variable columns.
    return (
        df.select(
            cs.by_name(index_cols_ordered, require_all=True),
            cs.exclude(index_cols_ordered),
        ),
        index_cols_ordered,
    )