SpuriousCount
from tmlt.tune import SpuriousCount
- class tmlt.tune.SpuriousCount(join_columns, *, name=None, description=None, baseline=None, output=None, grouping_columns=None)
Bases: JoinedOutputMetric
Computes the number of values in the DP output but not in the baseline output.
This metric returns the number of values of join_columns that appear in the DP output but not in the baseline output (such values are called spurious).
If grouping_columns is defined, then the DP output and the baseline output are both grouped by these columns, the spurious count is calculated separately for each group, and the metric returns a DataFrame. Otherwise, the metric returns a single number.
In each group (or globally if grouping_columns is None), each combination of values of join_columns must appear in at most one row of the DP output and at most one row of the baseline output. Otherwise, the metric returns an error.
Example
>>> dp_df = spark.createDataFrame(
...     pd.DataFrame(
...         {
...             "A": ["a1", "a2", "a3", "c"],
...             "X": [50, 110, 100, 50]
...         }
...     )
... )
>>> dp_outputs = {"O": dp_df}
>>> baseline_df = spark.createDataFrame(
...     pd.DataFrame(
...         {
...             "A": ["a1", "a2", "a3", "b"],
...             "X": [100, 100, 100, 50]
...         }
...     )
... )
>>> baseline_outputs = {"default": {"O": baseline_df}}
>>> metric = SpuriousCount(
...     join_columns=["A"]
... )
>>> metric.join_columns
['A']
>>> metric(dp_outputs, baseline_outputs).value
1
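To make the definition concrete, the following is a minimal sketch of the same computation in plain pandas. It is not the library's implementation: the function name spurious_count and the result column name spurious_count are hypothetical, and the sketch assumes that, when grouping columns are given, a value is spurious if it appears in a group of the DP output but not in the same group of the baseline output. It reports only groups present in the DP output.

```python
import pandas as pd


def spurious_count(dp, baseline, join_columns, grouping_columns=None):
    """Illustrative only: count join_columns values in dp but not in baseline."""
    groups = list(grouping_columns or [])
    keys = groups + list(join_columns)

    # Each key combination may appear in at most one row of each output.
    for label, df in (("DP", dp), ("baseline", baseline)):
        if df.duplicated(subset=keys).any():
            raise ValueError(f"Duplicate join_columns values in the {label} output")

    # Left join from the DP output; rows with no baseline match are spurious.
    merged = dp[keys].merge(baseline[keys], on=keys, how="left", indicator=True)
    merged["is_spurious"] = merged["_merge"] == "left_only"

    if not groups:
        return int(merged["is_spurious"].sum())

    # One row per group found in the DP output (groups with no spurious
    # values get a count of 0).
    return (
        merged.groupby(groups)["is_spurious"]
        .sum()
        .astype(int)
        .reset_index(name="spurious_count")
    )


# Same data as the example above: only "c" is missing from the baseline.
dp = pd.DataFrame({"A": ["a1", "a2", "a3", "c"], "X": [50, 110, 100, 50]})
baseline = pd.DataFrame({"A": ["a1", "a2", "a3", "b"], "X": [100, 100, 100, 50]})
ungrouped = spurious_count(dp, baseline, ["A"])

# Hypothetical grouped data: group g1 has two spurious values, g2 has none.
dp_g = pd.DataFrame({"G": ["g1", "g1", "g1", "g2"], "A": ["a1", "a2", "a3", "a1"]})
baseline_g = pd.DataFrame({"G": ["g1", "g2"], "A": ["a1", "a1"]})
grouped = spurious_count(dp_g, baseline_g, ["A"], grouping_columns=["G"])
```

Here ungrouped is a single number, while grouped is a DataFrame with one row per group, mirroring the two return shapes described above.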
- count_spurious_rows(joined_output, result_column_name)
Computes spurious count given DP and baseline outputs.