baseline

from tmlt.tune import baseline
tmlt.tune.baseline(name)

Decorator to define a custom baseline in a SessionProgramTuner.

To use the "default" baseline in addition to this custom baseline, you must also add the entry `"default": NoPrivacySession.Options()` to the `baseline_options` class variable.

Parameters:

name (str) – A name for the custom baseline.

>>> from tmlt.analytics import Session
>>> class Program(SessionProgram):
...     class ProtectedInputs:
...         protected_df: DataFrame
...     class UnprotectedInputs:
...         unprotected_df: DataFrame
...     class Outputs:
...         output_df: DataFrame
...     def session_interaction(self, session: Session):
...         ...
>>> class Tuner(SessionProgramTuner, program=Program):
...     @baseline("custom_baseline")
...     @staticmethod
...     def custom_baseline(
...         protected_inputs: Dict[str, DataFrame],
...     ) -> Dict[str, DataFrame]:
...         ...
...     @baseline("another_custom_baseline")
...     @staticmethod
...     def another_custom_baseline(
...         protected_inputs: Dict[str, DataFrame],
...         unprotected_inputs: Dict[str, DataFrame],
...     ) -> Dict[str, DataFrame]:
...         # If the program has unprotected inputs or parameters, the custom
...         # baseline method can take them as arguments.
...         ...
...     baseline_options = {
...         "default": NoPrivacySession.Options()
...     }  # This is required to keep the default baseline