single_output_metric
from tmlt.tune import single_output_metric
tmlt.tune.single_output_metric(name, description=None, baseline=None, output=None, grouping_columns=None, measure_column=None, empty_value=None)
Decorator to define a custom SingleOutputMetric. If your metric can be written with the joined_output_metric() decorator instead, that decorator is likely easier to use.

The decorated function must have the following parameters:
- dp_output: the chosen DP output DataFrame.
- baseline_output: the chosen baseline output DataFrame.
It may also have the following optional parameters:
- result_column_name: if the function returns a DataFrame, the metric results should be in a column with this name.
- unprotected_inputs: a dictionary containing the program’s unprotected inputs.
- parameters: a dictionary containing the program’s parameters.
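The split between required and optional parameters can be pictured with the following minimal sketch. It is not tmlt.tune's implementation: `call_metric`, `row_count_error`, and `scaled_error` are hypothetical names, plain Python lists stand in for Spark DataFrames, and it assumes the required outputs are passed as the keyword arguments `dp_output` and `baseline_output` (the spelling used in the example on this page). The caller inspects the decorated function's signature and passes only the optional arguments that the function actually declares.

```python
import inspect
from typing import Any, Callable, Dict

# Optional parameter names listed in the documentation above.
OPTIONAL_PARAMETERS = ("result_column_name", "unprotected_inputs", "parameters")


def call_metric(
    func: Callable[..., Any],
    dp_output: Any,
    baseline_output: Any,
    available: Dict[str, Any],
) -> Any:
    """Hypothetical dispatcher (not part of tmlt.tune).

    Always passes the two required outputs, then adds only those optional
    arguments that appear in ``func``'s signature.
    """
    declared = inspect.signature(func).parameters
    extra = {
        name: available[name]
        for name in OPTIONAL_PARAMETERS
        if name in declared and name in available
    }
    return func(dp_output=dp_output, baseline_output=baseline_output, **extra)


def row_count_error(dp_output, baseline_output):
    # Declares only the required parameters.
    return abs(len(dp_output) - len(baseline_output))


def scaled_error(dp_output, baseline_output, parameters):
    # Also declares the optional ``parameters`` dictionary.
    return parameters["scale"] * abs(len(dp_output) - len(baseline_output))


available = {
    "result_column_name": "custom_metric",
    "unprotected_inputs": {},
    "parameters": {"scale": 2},
}
print(call_metric(row_count_error, [1, 2, 3], [1, 2], available))  # 1
print(call_metric(scaled_error, [1, 2, 3], [1, 2], available))  # 2
```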
If the metric does not have grouping columns, the function must return a numeric value, a boolean, or a string. If the metric has grouping columns, the function must instead return a DataFrame. This DataFrame should contain the grouping columns and exactly one additional column holding the metric value for each group; this column’s type should be numeric, boolean, or string.
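The return-type rules above can be illustrated with a stdlib-only sketch. Lists of dicts stand in for DataFrames, and `grouped_total_error` and `is_valid_result` are hypothetical helpers, not part of tmlt.tune; the grouping and measure columns are passed as plain arguments here only to keep the sketch self-contained. The grouped metric returns one row per group containing the grouping columns plus a single result column, and the checker tests a result against the rules.

```python
from collections import defaultdict
from typing import Any, Dict, List, Sequence, Tuple

Row = Dict[str, Any]  # one row of a DataFrame stand-in
SCALAR_TYPES = (bool, int, float, str)  # numeric, boolean, or string


def grouped_total_error(
    dp_output: List[Row],
    baseline_output: List[Row],
    grouping_columns: Sequence[str],
    measure_column: str,
    result_column_name: str,
) -> List[Row]:
    """Hypothetical grouped metric: per group, |sum(dp) - sum(baseline)|."""
    totals: Dict[Tuple[Any, ...], List[float]] = defaultdict(lambda: [0.0, 0.0])
    for index, rows in enumerate((dp_output, baseline_output)):
        for row in rows:
            key = tuple(row[column] for column in grouping_columns)
            totals[key][index] += row[measure_column]
    # One output row per group: the grouping columns plus one result column.
    return [
        {
            **dict(zip(grouping_columns, key)),
            result_column_name: abs(dp_total - base_total),
        }
        for key, (dp_total, base_total) in sorted(totals.items())
    ]


def is_valid_result(result: Any, grouping_columns: Sequence[str]) -> bool:
    """Hypothetical checker for the documented return-type rules."""
    if not grouping_columns:
        return isinstance(result, SCALAR_TYPES)
    if not isinstance(result, list):
        return False
    for row in result:
        extra = set(row) - set(grouping_columns)
        if not set(grouping_columns) <= set(row) or len(extra) != 1:
            return False
        if not isinstance(row[extra.pop()], SCALAR_TYPES):
            return False
    return True


dp = [{"state": "CA", "Y": 10}, {"state": "NY", "Y": 4}]
baseline = [{"state": "CA", "Y": 12}, {"state": "NY", "Y": 4}]
result = grouped_total_error(dp, baseline, ["state"], "Y", "abs_error")
print(result)  # [{'state': 'CA', 'abs_error': 2.0}, {'state': 'NY', 'abs_error': 0.0}]
print(is_valid_result(result, ["state"]))  # True
```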
To use built-in metrics in addition to this custom metric, you can specify them separately in the metrics class variable.

Parameters:
- name (str) – A name for the metric.
- baseline (Optional[str]) – The name of the baseline program used for the error report. If None, the tuner must have exactly one baseline, which is used.
- output (Optional[str]) – The name of the program output to be used for the metric. If None, the program must have exactly one output, which is used.
- grouping_columns (Optional[List[str]]) – If specified, the metric groups the outputs by the given columns and calculates the metric for each group.
- measure_column (Optional[str]) – If specified, the column in the outputs to measure.
- empty_value (Optional[Any]) – If all DP and baseline outputs are empty, the metric returns this value.
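One way to picture empty_value is a wrapper that skips the metric function when there is nothing to compare. The sketch below assumes a single DP output and a single baseline output, uses lists as stand-ins for DataFrames, and uses a hypothetical `compute_metric` helper; it is not how tmlt.tune is implemented. It shows why the parameter is useful: a relative-error function would otherwise divide by zero on empty outputs.

```python
from typing import Any, Callable, Optional, Sequence


def compute_metric(
    func: Callable[[Sequence[Any], Sequence[Any]], Any],
    dp_output: Sequence[Any],
    baseline_output: Sequence[Any],
    empty_value: Optional[Any] = None,
) -> Any:
    """Hypothetical wrapper (not part of tmlt.tune) illustrating empty_value."""
    # If both outputs are empty, skip the metric function entirely.
    if len(dp_output) == 0 and len(baseline_output) == 0:
        return empty_value
    return func(dp_output, baseline_output)


def relative_row_count_error(dp_output, baseline_output):
    # Raises ZeroDivisionError when the baseline output is empty.
    return abs(len(dp_output) - len(baseline_output)) / len(baseline_output)


print(compute_metric(relative_row_count_error, [], [], empty_value=0.0))  # 0.0
print(compute_metric(relative_row_count_error, [1, 2, 3], [1, 2, 3, 4]))  # 0.25
```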
>>> from pyspark.sql import DataFrame
>>> from tmlt.analytics import Session
>>> from tmlt.tune import (
...     MedianAbsoluteError,
...     SessionProgram,
...     SessionProgramTuner,
...     single_output_metric,
... )
>>> class Program(SessionProgram):
...     class ProtectedInputs:
...         protected_df: DataFrame
...     class UnprotectedInputs:
...         unprotected_df: DataFrame
...     class Outputs:
...         output_df: DataFrame
...     def session_interaction(self, session: Session):
...         # Placeholder: evaluate queries on `session` to build the DP output.
...         dp_output = ...
...         return {"output_df": dp_output}
>>> class Tuner(SessionProgramTuner, program=Program):
...     @single_output_metric(name="custom_metric")
...     @staticmethod
...     def custom_metric(
...         dp_output: DataFrame,
...         baseline_output: DataFrame,
...     ):
...         # If the program has unprotected inputs and/or parameters, the custom
...         # metric method can also take them as arguments.
...         # Example metric: the difference in row counts (a numeric value).
...         return dp_output.count() - baseline_output.count()
...     # You can mix custom and built-in metrics.
...     metrics = [
...         MedianAbsoluteError(
...             output="output_df",
...             join_columns=["join_column"],
...             measure_column="Y",
...         ),
...     ]