Metric

from tmlt.tune import Metric
class tmlt.tune.Metric(name, func, description=None, grouping_columns=None, measure_column=None, empty_value=None)

Bases: object

A generic metric defined using a function.

The function passed as func must accept the following parameters:

  • dp_outputs: A dictionary mapping output names to the DataFrames produced by the program.

  • baseline_outputs: A dictionary mapping baseline names to dictionaries of output DataFrames.

It may also accept the following optional parameters:

  • result_column_name: If the function returns a DataFrame, the name of the column that should hold the metric results.

  • unprotected_inputs: A dictionary containing the program’s unprotected inputs.

  • parameters: A dictionary containing the program’s parameters.
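The parameter rules above can be checked mechanically with Python’s inspect module. The sketch below is not the library’s implementation; check_metric_func_signature and output_count are hypothetical names, and treating any parameter outside the documented required and optional names as an error is our reading of the rules, not confirmed behavior.

```python
import inspect
from typing import Callable

# Parameter names copied from the Metric docstring.
REQUIRED_PARAMS = {"dp_outputs", "baseline_outputs"}
OPTIONAL_PARAMS = {"result_column_name", "unprotected_inputs", "parameters"}


def check_metric_func_signature(func: Callable) -> None:
    """Hypothetical helper: raise ValueError if func breaks the documented parameter rules."""
    names = set(inspect.signature(func).parameters)
    missing = REQUIRED_PARAMS - names
    if missing:
        raise ValueError(f"missing required parameters: {sorted(missing)}")
    # Assumption: names outside both documented sets are treated as errors.
    unexpected = names - REQUIRED_PARAMS - OPTIONAL_PARAMS
    if unexpected:
        raise ValueError(f"unexpected parameters: {sorted(unexpected)}")


def output_count(dp_outputs, baseline_outputs, parameters):
    """Hypothetical metric function that declares one optional parameter."""
    return len(dp_outputs)


check_metric_func_signature(output_count)  # passes: no exception raised
```

A function declaring only dp_outputs would be rejected for missing baseline_outputs, while one adding an undocumented parameter such as threshold would be rejected under this sketch’s stricter assumption.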

If the metric has no grouping columns, the function must return a single numeric, boolean, or string value. If the metric has grouping columns, the function must return a DataFrame containing the grouping columns plus exactly one additional column that holds the metric value for each group. That column’s type should be numeric, boolean, or string.
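The return-type rule can be expressed as a small check. This is a minimal sketch, not the library’s validate_result: check_result_shape is a hypothetical name, any object exposing a columns attribute stands in for a DataFrame (both pandas and Spark DataFrames have one), and the type of the metric column itself is not inspected.

```python
from numbers import Number
from types import SimpleNamespace
from typing import Any, List, Optional

# bool is already a Number subclass; it is listed for readability.
SCALAR_TYPES = (Number, bool, str)


def check_result_shape(result: Any, grouping_columns: Optional[List[str]]) -> None:
    """Hypothetical check of the documented return-type rule."""
    if not grouping_columns:
        if not isinstance(result, SCALAR_TYPES):
            raise TypeError(f"expected a number, bool, or str, got {type(result).__name__}")
        return
    # Assumption: anything with a `columns` attribute is treated as a DataFrame.
    columns = getattr(result, "columns", None)
    if columns is None:
        raise TypeError("a metric with grouping columns must return a DataFrame")
    missing = [c for c in grouping_columns if c not in columns]
    if missing:
        raise ValueError(f"result is missing grouping columns {missing}")
    extra = [c for c in columns if c not in grouping_columns]
    if len(extra) != 1:
        raise ValueError(f"expected exactly one metric column, found {extra}")


check_result_shape(0.25, None)  # ungrouped scalar: accepted
check_result_shape(SimpleNamespace(columns=["A", "error"]), ["A"])  # one metric column: accepted
```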

Example

>>> import pandas as pd
>>> from pyspark.sql import SparkSession
>>> from tmlt.tune import Metric
>>> spark = SparkSession.builder.getOrCreate()
>>> dp_df = spark.createDataFrame(pd.DataFrame({"A": [5]}))
>>> dp_outputs = {"O": dp_df}
>>> baseline_df1 = spark.createDataFrame(pd.DataFrame({"A": [5]}))
>>> baseline_df2 = spark.createDataFrame(pd.DataFrame({"A": [6]}))
>>> baseline_outputs = {
...     "baseline1": {"O": baseline_df1}, "baseline2": {"O": baseline_df2}
... }
>>> def size_difference(dp_outputs, baseline_outputs):
...     # Absolute difference in row count between the DP output "O" and baseline1's "O".
...     baseline_count = baseline_outputs["baseline1"]["O"].count()
...     return abs(baseline_count - dp_outputs["O"].count())
>>> metric = Metric(
...     func=size_difference,
...     name="Custom Metric",
...     description="Custom Description",
... )
>>> result = metric(dp_outputs, baseline_outputs)
>>> result.value
0

property name: str

The name of the metric.

property description: str

The description of the metric.

property func: Callable

The function to be applied.

property grouping_columns: List[str]

The grouping columns.

property measure_column: str | None

The measure column (if any).

property empty_value: Any

The value this metric returns when its inputs are empty.

required_func_parameters()

Returns the required parameters to the metric function.

optional_func_parameters()

Returns the optional parameters to the metric function.

get_column_name_from_baselines(baseline_outputs)

Gets the result column name for a given set of baseline outputs.

get_parameter_values(dp_outputs, baseline_outputs, unprotected_inputs, parameters)

Returns values for the function’s parameters.

Return type:

Dict[str, Any]
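One plausible way to build these values is to collect every documented input and keep only the names the metric function actually declares. The sketch below is an assumption, not the library’s implementation; select_parameter_values is a hypothetical name, and it takes result_column_name as an explicit argument, whereas the library presumably derives that name itself.

```python
import inspect
from typing import Any, Callable, Dict, Optional


def select_parameter_values(
    func: Callable,
    dp_outputs: Dict[str, Any],
    baseline_outputs: Dict[str, Dict[str, Any]],
    unprotected_inputs: Optional[Dict[str, Any]] = None,
    parameters: Optional[Dict[str, Any]] = None,
    result_column_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Hypothetical: build keyword arguments containing only the names func declares."""
    available = {
        "dp_outputs": dp_outputs,
        "baseline_outputs": baseline_outputs,
        "unprotected_inputs": unprotected_inputs,
        "parameters": parameters,
        "result_column_name": result_column_name,
    }
    declared = inspect.signature(func).parameters
    return {name: value for name, value in available.items() if name in declared}


def epsilon_metric(dp_outputs, baseline_outputs, parameters):
    # Hypothetical metric that only reads a program parameter.
    return parameters["epsilon"]


kwargs = select_parameter_values(epsilon_metric, {"O": None}, {}, parameters={"epsilon": 1.0})
print(epsilon_metric(**kwargs))  # 1.0
```

Filtering by the declared signature lets a metric function omit any optional parameter it does not need without raising a TypeError for unexpected keyword arguments.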

validate_result(result, baseline_outputs)

Checks that the metric result is an allowed type.

metric_function_inputs_empty(function_params)

Determines if the inputs to the metric function are empty.

Return type:

bool
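This page does not define what counts as "empty" or exactly when empty_value is used. As a minimal sketch, assuming the inputs are empty when any DP or baseline output table has zero rows, and that the metric function is skipped in that case, the short-circuit might look like this. inputs_empty, row_count, and call_with_empty_value are hypothetical names; plain lists stand in for DataFrames.

```python
from typing import Any, Callable, Dict


def row_count(frame: Any) -> int:
    # pandas DataFrames and lists support len(); Spark DataFrames expose .count().
    if hasattr(frame, "__len__"):
        return len(frame)
    return frame.count()


def inputs_empty(function_params: Dict[str, Any]) -> bool:
    """Hypothetical: True if any DP or baseline output table has zero rows."""
    frames = list(function_params.get("dp_outputs", {}).values())
    for outputs in function_params.get("baseline_outputs", {}).values():
        frames.extend(outputs.values())
    return any(row_count(frame) == 0 for frame in frames)


def call_with_empty_value(func: Callable, function_params: Dict[str, Any], empty_value: Any) -> Any:
    # Return empty_value instead of calling func when the inputs are empty.
    if inputs_empty(function_params):
        return empty_value
    return func(**function_params)


params = {"dp_outputs": {"O": []}, "baseline_outputs": {"b": {"O": [1]}}}
print(call_with_empty_value(lambda dp_outputs, baseline_outputs: 1, params, None))  # None
```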

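__call__(dp_outputs, baseline_outputs, unprotected_inputs=None, parameters=None)

Computes the metric on the given DP and baseline outputs.

Parameters:

  • dp_outputs: A dictionary mapping output names to the DataFrames produced by the program.

  • baseline_outputs: A dictionary mapping baseline names to dictionaries of output DataFrames.

  • unprotected_inputs: Optional dictionary containing the program’s unprotected inputs.

  • parameters: Optional dictionary containing the program’s parameters.

Return type: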

MetricResult