metrics

Module containing metrics used for constructing measurements and transformations.

Classes

Metric

Base class for input/output metrics.

NullMetric

Metric for use when distance is undefined.

ExactNumberMetric

A metric whose distances are exact numbers.

AbsoluteDifference

The absolute value of the difference of two values.

SymmetricDifference

The number of elements that are in only one of two sets.

HammingDistance

The number of elements that are different between two sets of the same size.

AggregationMetric

Distances resulting from aggregating distances of its components.

SumOf

Distances resulting from summing distances of its components.

RootSumOfSquared

The square root of the sum of the squares of component distances.

OnColumn

The value of a metric applied to a single column treated as a vector.

OnColumns

A tuple containing the values of multiple OnColumn metrics.

IfGroupedBy

Distance between two DataFrames that shall be grouped by a given attribute.

DictMetric

Distance between two dictionaries with identical sets of keys.

AddRemoveKeys

The number of keys by which two dictionaries of dataframes differ.

class Metric

Bases: abc.ABC

Base class for input/output metrics.

abstract validate(self, value)

Raises an error if value is not a valid distance.

Parameters

value (Any) – A distance between two datasets under this metric.

abstract compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

Parameters
  • value1 (Any) –

  • value2 (Any) –

Return type

bool

abstract supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) –

Return type

bool

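abstract distance(self, value1, value2, domain)

Returns the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type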

Any

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool
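The four abstract methods above make up the whole contract a subclass must fill in. The following is a minimal pure-Python sketch of a class with the same four abstract methods plus an equality check, built with the standard `abc` module. `ToyMetric` and `ToyAbsoluteDifference` are hypothetical names. The domain argument is simplified to a Python type, and the equality rule (same class, same attributes) is an assumption. None of this reflects how `tmlt.core` is actually implemented.

```python
from abc import ABC, abstractmethod
from typing import Any


class ToyMetric(ABC):
    """Hypothetical stand-in for the Metric interface described above."""

    @abstractmethod
    def validate(self, value: Any) -> None:
        """Raise an error if value is not a valid distance."""

    @abstractmethod
    def compare(self, value1: Any, value2: Any) -> bool:
        """Return True if value1 <= value2 under this metric."""

    @abstractmethod
    def supports_domain(self, domain: Any) -> bool:
        """Return True if the metric is implemented for the domain."""

    @abstractmethod
    def distance(self, value1: Any, value2: Any, domain: Any) -> Any:
        """Return the distance between two elements of the domain."""

    def __eq__(self, other: Any) -> bool:
        # Assumption: metrics of the same class with the same attributes are equal.
        return type(self) is type(other) and vars(self) == vars(other)

    def __hash__(self) -> int:
        # Keep instances hashable, consistent with the equality rule above.
        return hash(type(self))


class ToyAbsoluteDifference(ToyMetric):
    """Toy metric on Python ints; the "domain" is simplified to a type."""

    def validate(self, value: Any) -> None:
        if not isinstance(value, int) or value < 0:
            raise ValueError(f"{value!r} is not a nonnegative integer")

    def compare(self, value1: int, value2: int) -> bool:
        self.validate(value1)
        self.validate(value2)
        return value1 <= value2

    def supports_domain(self, domain: Any) -> bool:
        return domain is int

    def distance(self, value1: int, value2: int, domain: Any) -> int:
        if not self.supports_domain(domain):
            raise ValueError(f"unsupported domain: {domain!r}")
        return abs(value1 - value2)


metric = ToyAbsoluteDifference()
print(metric.distance(20, 82, int))  # 62
print(metric.compare(3, 5))          # True
```

Because `ToyMetric` leaves all four methods abstract, it cannot be instantiated directly. Only a subclass that implements every method can be.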

class NullMetric

Bases: Metric

Metric for use when distance is undefined.

abstract validate(self, value)

Raises an error if value is not a valid distance.

This method is not implemented.

Parameters

value (Any) – A distance between two datasets under this metric.

abstract compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

This method is not implemented.

Parameters
  • value1 (Any) – A distance between two datasets under this metric.

  • value2 (Any) – A distance between two datasets under this metric.

Return type

bool

supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) – The domain to check against.

Return type

bool

abstract distance(self, value1, value2, domain)

Return the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type

Any

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool

class ExactNumberMetric

Bases: Metric

A metric whose distances are exact numbers.

abstract distance(self, value1, value2, domain)

Return the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

abstract validate(self, value)

Raises an error if value is not a valid distance.

Parameters

value (Any) – A distance between two datasets under this metric.

abstract compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

Parameters
  • value1 (Any) –

  • value2 (Any) –

Return type

bool

abstract supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) –

Return type

bool

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool

class AbsoluteDifference

Bases: ExactNumberMetric

The absolute value of the difference of two values.

Example

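>>> import numpy as np
>>> from tmlt.core.domains.numpy_domains import NumpyFloatDomain, NumpyIntegerDomain
>>> from tmlt.core.metrics import AbsoluteDifference
>>> AbsoluteDifference().distance(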
...     np.int64(20), np.int64(82), NumpyIntegerDomain()
... )
62
>>> # 1.2 is first converted to rational 5404319552844595/4503599627370496
>>> AbsoluteDifference().distance(
...     np.float64(1.2), np.float64(1.0), NumpyFloatDomain()
... )
900719925474099/4503599627370496
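The conversion described in the comment above can be reproduced with Python's standard `fractions.Fraction`, which converts a binary float to the exact rational value it stores. This is a minimal stdlib sketch of the arithmetic only. It makes no assumption about how `ExactNumber` performs the conversion internally, and `exact_absolute_difference` is a hypothetical helper.

```python
from fractions import Fraction


def exact_absolute_difference(x: float, y: float) -> Fraction:
    """Hypothetical helper: |x - y| computed exactly on the stored rational values."""
    # Fraction(float) is exact: it recovers the binary value the float stores,
    # so 1.2 becomes 5404319552844595/4503599627370496 rather than 6/5.
    return abs(Fraction(x) - Fraction(y))


print(Fraction(1.2))                        # 5404319552844595/4503599627370496
print(exact_absolute_difference(1.2, 1.0))  # 900719925474099/4503599627370496
print(exact_absolute_difference(20, 82))    # 62
```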
validate(self, value)

Raises an error if value is not a valid distance.

  • value must be a nonnegative real number or infinity

Parameters

value (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

Parameters
  • value1 (tmlt.core.utils.exact_number.ExactNumberInput) –

  • value2 (tmlt.core.utils.exact_number.ExactNumberInput) –

Return type

bool

supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) – The domain to check against.

Return type

bool

distance(self, value1, value2, domain)

Return the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool

class SymmetricDifference

Bases: ExactNumberMetric

The number of elements that are in only one of two sets.

This metric is compatible with spark dataframes, pandas dataframes, and pandas series. It ignores ordering and, in the case of pandas, indices. That is, it treats each collection as a multiset of items. For non-grouped data, it treats each record as an item. For grouped data, the distance is the sum of the distances between corresponding groups, and each group falls into one of three cases:

  • The distance between two groups with the same multiset of records is 0

  • The distance between two groups where exactly one is empty is 1

  • The distance between two groups with different records (where neither is empty) is 2
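Both the record-level rule and the three grouped cases above can be checked with a small pure-Python sketch. It represents each table as a list of row tuples and uses `collections.Counter` as a multiset. The helpers `symmetric_difference` and `grouped_symmetric_difference` are hypothetical and are not the library's Spark implementation. They assume that only the listed grouping keys are compared, so a row whose key is not among them (such as the `B=3` row in the example below) does not contribute.

```python
from collections import Counter
from typing import Hashable, Iterable, Sequence, Tuple

Row = Tuple[Hashable, ...]


def symmetric_difference(rows1: Iterable[Row], rows2: Iterable[Row]) -> int:
    """Count records that are in only one of two multisets (order is ignored)."""
    c1, c2 = Counter(rows1), Counter(rows2)
    # Counter subtraction keeps only positive counts, i.e. the multiset difference.
    return sum(((c1 - c2) + (c2 - c1)).values())


def grouped_symmetric_difference(
    rows1: Sequence[Row], rows2: Sequence[Row], key_index: int, keys: Iterable[Hashable]
) -> int:
    """Apply the 0/1/2 rule to each group and add up the results."""
    total = 0
    for key in keys:
        group1 = Counter(r for r in rows1 if r[key_index] == key)
        group2 = Counter(r for r in rows2 if r[key_index] == key)
        if group1 == group2:
            continue  # same multiset of records (including both empty): 0
        # Exactly one group empty: 1; both non-empty but different: 2.
        total += 1 if not group1 or not group2 else 2
    return total


# Same data as the Spark example below, as (A, B) tuples; groups are keyed by B.
df1 = [(1, 2), (1, 2), (1, 2), (2, 4), (3, 3)]
df2 = [(1, 2), (2, 4), (1, 1)]
print(symmetric_difference(df1, df2))                             # 4
print(grouped_symmetric_difference(df1, df2, 1, keys=[1, 2, 4]))  # 3
```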

Examples

>>> import pandas as pd
>>> from pyspark.sql import SparkSession
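>>> from tmlt.core.domains.spark_domains import (
...     SparkDataFrameDomain,
...     SparkGroupedDataFrameDomain,
...     SparkIntegerColumnDescriptor,
... )
>>> from tmlt.core.metrics import SymmetricDifference
>>> from tmlt.core.utils.grouped_dataframe import GroupedDataFrame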
>>> spark = SparkSession.builder.getOrCreate()
>>> domain = SparkDataFrameDomain(
...     {
...         "A": SparkIntegerColumnDescriptor(),
...         "B": SparkIntegerColumnDescriptor(),
...     }
... )
>>> df1 = spark.createDataFrame(
...     pd.DataFrame({"A": [1, 1, 1, 2, 3], "B": [2, 2, 2, 4, 3]})
... )
>>> df2 = spark.createDataFrame(pd.DataFrame({"A": [1, 2, 1], "B": [2, 4, 1]}))
>>> SymmetricDifference().distance(df1, df2, domain)
4
>>> group_keys = spark.createDataFrame(pd.DataFrame({"B": [1, 2, 4]}))
>>> domain = SparkGroupedDataFrameDomain(
...     {
...         "A": SparkIntegerColumnDescriptor(),
...         "B": SparkIntegerColumnDescriptor(),
...     },
...     group_keys,
... )
>>> grouped_df1 = GroupedDataFrame(df1, group_keys)
>>> grouped_df2 = GroupedDataFrame(df2, group_keys)
>>> SymmetricDifference().distance(grouped_df1, grouped_df2, domain)
3
validate(self, value)

Raises an error if value is not a valid distance.

  • value must be a nonnegative integer or infinity

Parameters

value (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

Parameters
  • value1 (tmlt.core.utils.exact_number.ExactNumberInput) –

  • value2 (tmlt.core.utils.exact_number.ExactNumberInput) –

Return type

bool

supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) – The domain to check against.

Return type

bool

distance(self, value1, value2, domain)

Return the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool

class HammingDistance

Bases: ExactNumberMetric

The number of elements that are different between two sets of the same size.

This metric is compatible with spark dataframes, pandas dataframes, and pandas series. It ignores ordering and, in the case of pandas, indices. That is, it treats each collection as a multiset of records.

If the sets are not the same size, the distance is infinity.
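Because records are compared as multisets, the Hamming distance between two collections of equal size is half their symmetric difference. Each changed record removes one item and adds another. The following stdlib sketch reproduces both results of the example below. It uses a hypothetical `hamming_distance` helper, and `math.inf` stands in for the library's `oo`.

```python
import math
from collections import Counter
from typing import Hashable, Sequence, Tuple, Union

Row = Tuple[Hashable, ...]


def hamming_distance(rows1: Sequence[Row], rows2: Sequence[Row]) -> Union[int, float]:
    """Number of records that differ between two equal-size multisets."""
    if len(rows1) != len(rows2):
        return math.inf  # different sizes: the distance is infinite
    c1, c2 = Counter(rows1), Counter(rows2)
    # With equal sizes, c1 - c2 and c2 - c1 have the same total count,
    # so either side of the multiset difference gives the number of changes.
    return sum((c1 - c2).values())


df1 = [(1, 2), (1, 2), (1, 2), (3, 4)]
df2 = [(1, 2), (2, 4)]
df3 = [(1, 2), (2, 4), (3, 4), (1, 2)]
print(hamming_distance(df1, df2))  # inf
print(hamming_distance(df1, df3))  # 1
```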

Example

>>> import pandas as pd
>>> from pyspark.sql import SparkSession
>>> from tmlt.core.domains.spark_domains import (
...     SparkDataFrameDomain,
...     SparkIntegerColumnDescriptor,
... )
>>> from tmlt.core.metrics import HammingDistance
>>> spark = SparkSession.builder.getOrCreate()
>>> domain = SparkDataFrameDomain(
...     {
...         "A": SparkIntegerColumnDescriptor(),
...         "B": SparkIntegerColumnDescriptor(),
...     }
... )
>>> df1 = spark.createDataFrame(
...     pd.DataFrame({"A": [1, 1, 1, 3], "B": [2, 2, 2, 4]})
... )
>>> df2 = spark.createDataFrame(pd.DataFrame({"A": [1, 2], "B": [2, 4]}))
>>> HammingDistance().distance(df1, df2, domain)
oo
>>> df3 = spark.createDataFrame(
...     pd.DataFrame({"A": [1, 2, 3, 1], "B": [2, 4, 4, 2]})
... )
>>> HammingDistance().distance(df1, df3, domain)
1
validate(self, value)

Raises an error if value is not a valid distance.

  • value must be a nonnegative integer or infinity

Parameters

value (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

Parameters
  • value1 (tmlt.core.utils.exact_number.ExactNumberInput) –

  • value2 (tmlt.core.utils.exact_number.ExactNumberInput) –

Return type

bool

supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) – The domain to check against.

Return type

bool

distance(self, value1, value2, domain)

Return the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool

class AggregationMetric(inner_metric)

Bases: ExactNumberMetric

Distances resulting from aggregating distances of its components.

Components may be elements of a series, groups of a grouped dataframe, or elements of a list. This metric is parameterized by an inner_metric that is used to compute the distances of the components. See SumOf or RootSumOfSquared for example usage.

Parameters

inner_metric (Union[AbsoluteDifference, SymmetricDifference, HammingDistance, IfGroupedBy]) –
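The split between an aggregation and an inner metric can be sketched in plain Python as two kinds of function. The inner metric turns each pair of matching components into a number, and the aggregation combines those numbers. In this hypothetical sketch, components are paired by position, distances are kept exact with `fractions.Fraction`, and the root-sum-of-squares variant returns the exact squared value rather than a symbolic square root. None of these names are part of `tmlt.core`.

```python
from fractions import Fraction
from typing import Any, Callable, Iterator, Sequence

# Hypothetical type: an inner metric maps one pair of components to a distance.
InnerMetric = Callable[[Any, Any], Fraction]


def component_distances(
    values1: Sequence[Any], values2: Sequence[Any], inner: InnerMetric
) -> Iterator[Fraction]:
    """Apply the inner metric to each pair of components, paired by position."""
    if len(values1) != len(values2):
        raise ValueError("both values must have the same number of components")
    return (inner(a, b) for a, b in zip(values1, values2))


def sum_of(values1: Sequence[Any], values2: Sequence[Any], inner: InnerMetric) -> Fraction:
    """SumOf-style aggregation: add up the component distances."""
    return sum(component_distances(values1, values2, inner), Fraction(0))


def squared_root_sum_of_squared(
    values1: Sequence[Any], values2: Sequence[Any], inner: InnerMetric
) -> Fraction:
    """Square of the RootSumOfSquared-style aggregation, kept exact."""
    return sum(
        (d * d for d in component_distances(values1, values2, inner)), Fraction(0)
    )


def absolute_difference(a: Any, b: Any) -> Fraction:
    """Inner metric on numbers: exact |a - b|."""
    return abs(Fraction(a) - Fraction(b))


print(sum_of([1, 5, 2], [4, 5, 0], absolute_difference))  # 5
print(squared_root_sum_of_squared([1, 5, 2], [4, 5, 0], absolute_difference))  # 13, i.e. sqrt(13)
```

You can swap `absolute_difference` for another per-component function to change the inner metric without touching the aggregation.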

__init__(inner_metric)

Constructor.

Parameters

inner_metric (Union[AbsoluteDifference, SymmetricDifference, HammingDistance, IfGroupedBy]) – Metric to be applied to the components.

property inner_metric(self)

Returns the metric used to compute the distances of the components.

Return type

Union[AbsoluteDifference, SymmetricDifference, HammingDistance, IfGroupedBy]

compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

Parameters
  • value1 (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

  • value2 (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

Return type

bool

supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) – The domain to check against.

Return type

bool

distance(self, value1, value2, domain)

Return the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

abstract validate(self, value)

Raises an error if value is not a valid distance.

Parameters

value (Any) – A distance between two datasets under this metric.

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool

class SumOf(inner_metric)

Bases: AggregationMetric

Distances resulting from summing distances of its components.

These components may be elements of a series, groups of a grouped dataframe, or elements of a list. This metric is parameterized by an inner_metric that is used to compute the distances of the components.
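The float-to-rational behavior in the pandas part of the example below can be reproduced with `fractions.Fraction`. This stdlib sketch uses a hypothetical `sum_of_absolute_differences` helper and pairs entries by position, which is an assumption. Its exact result differs slightly from the decimal answer 0.9 + 0.6 = 1.5, because 1.2, 0.8, 0.3, and 1.4 cannot be represented exactly as binary floats.

```python
from fractions import Fraction
from typing import Sequence


def sum_of_absolute_differences(series1: Sequence[float], series2: Sequence[float]) -> Fraction:
    """Hypothetical SumOf(AbsoluteDifference())-style distance between two series."""
    if len(series1) != len(series2):
        raise ValueError("series must have the same length")
    # Each float is converted to the exact rational it stores before subtracting,
    # so no rounding error accumulates in the sum.
    return sum(
        (abs(Fraction(x) - Fraction(y)) for x, y in zip(series1, series2)),
        Fraction(0),
    )


exact = sum_of_absolute_differences([1.2, 0.8], [0.3, 1.4])
print(exact)  # 27021597764222973/18014398509481984
```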

Example

>>> import pandas as pd
>>> from pyspark.sql import SparkSession
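>>> from tmlt.core.domains.numpy_domains import NumpyFloatDomain
>>> from tmlt.core.domains.pandas_domains import PandasSeriesDomain
>>> from tmlt.core.domains.spark_domains import (
...     SparkGroupedDataFrameDomain,
...     SparkIntegerColumnDescriptor,
... )
>>> from tmlt.core.metrics import (
...     AbsoluteDifference,
...     HammingDistance,
...     SumOf,
...     SymmetricDifference,
... )
>>> from tmlt.core.utils.grouped_dataframe import GroupedDataFrame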
>>> spark = SparkSession.builder.getOrCreate()
>>> # Symmetric difference on SparkGroupedDataFrame
>>> group_keys = spark.createDataFrame(pd.DataFrame({"A": [1, 2]}))
>>> domain = SparkGroupedDataFrameDomain(
...     {
...         "A": SparkIntegerColumnDescriptor(),
...         "B": SparkIntegerColumnDescriptor(),
...     },
...     group_keys,
... )
>>> df1 = GroupedDataFrame(
...     spark.createDataFrame(
...         pd.DataFrame({"A": [1, 1, 2, 3], "B": [1, 1, 2, 4]})
...     ),
...     group_keys,
... )
>>> df2 = GroupedDataFrame(
...     spark.createDataFrame(
...         pd.DataFrame({"A": [1, 2, 2, 3], "B": [1, 3, 4, 5]})
...     ),
...     group_keys,
... )
>>> SumOf(SymmetricDifference()).distance(df1, df2, domain)
4
>>> # Using HammingDistance gives a distance of infinity since the groups
>>> # are different sizes, despite the fact that the two dataframes are the
>>> # same size.
>>> SumOf(HammingDistance()).distance(df1, df2, domain)
oo
>>> # Absolute difference on pandas series first converts the floats to
>>> # rationals, then exactly computes the distance.
>>> domain = PandasSeriesDomain(NumpyFloatDomain())
>>> series1 = pd.Series([1.2, 0.8])
>>> series2 = pd.Series([0.3, 1.4])
>>> SumOf(AbsoluteDifference()).distance(series1, series2, domain)
27021597764222973/18014398509481984
Parameters

inner_metric (Union[AbsoluteDifference, SymmetricDifference, HammingDistance, IfGroupedBy]) –

__init__(inner_metric)

Constructor.

Parameters

inner_metric (Union[AbsoluteDifference, SymmetricDifference, HammingDistance, IfGroupedBy]) – Metric to be applied to the components.

validate(self, value)

Raises an error if value is not a valid distance.

Parameters

value (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

property inner_metric(self)

Returns metric to be used for summing.

Return type

Union[AbsoluteDifference, SymmetricDifference, HammingDistance, IfGroupedBy]

compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

Parameters
  • value1 (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

  • value2 (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

Return type

bool

supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) – The domain to check against.

Return type

bool

distance(self, value1, value2, domain)

Return the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool

class RootSumOfSquared(inner_metric)

Bases: AggregationMetric

The square root of the sum of the squares of component distances.

These components may be elements of a series, groups of a grouped dataframe, or elements of a list. This metric is parameterized by an inner_metric that is used to compute the distances of the components.
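A root-sum-of-squares distance such as `sqrt(10)` is usually irrational. One way to keep comparisons exact is to compare squares instead of square roots, which is valid because both sides are nonnegative. The stdlib sketch below uses hypothetical helpers, not the library's `ExactNumber` machinery. It takes the per-group symmetric differences 1 and 3 from the example below, which give a squared distance of 10. It then shows that `sqrt(10)` is at most the `SumOf` result of 4. In general, the root-sum-of-squares of nonnegative distances never exceeds their sum.

```python
from fractions import Fraction
from typing import Iterable


def squared_root_sum_of_squares(component_distances: Iterable[int]) -> Fraction:
    """Return d**2, where d is the root-sum-of-squares of the component distances."""
    return sum((Fraction(d) ** 2 for d in component_distances), Fraction(0))


def sqrt_leq(square: Fraction, bound: Fraction) -> bool:
    """Exactly decide sqrt(square) <= bound for nonnegative square, without floats."""
    if bound < 0:
        return False
    return square <= bound * bound


per_group = [1, 3]  # symmetric differences of groups A=1 and A=2 in the example below
squared = squared_root_sum_of_squares(per_group)
print(squared)                                      # 10, i.e. the distance is sqrt(10)
print(sqrt_leq(squared, Fraction(sum(per_group))))  # True: sqrt(10) <= 4
print(sqrt_leq(squared, Fraction(3)))               # False: sqrt(10) > 3
```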

Example

>>> import pandas as pd
>>> from pyspark.sql import SparkSession
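>>> from tmlt.core.domains.spark_domains import (
...     SparkGroupedDataFrameDomain,
...     SparkIntegerColumnDescriptor,
... )
>>> from tmlt.core.metrics import (
...     HammingDistance,
...     RootSumOfSquared,
...     SymmetricDifference,
... )
>>> from tmlt.core.utils.grouped_dataframe import GroupedDataFrame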
>>> spark = SparkSession.builder.getOrCreate()
>>> # Symmetric difference on SparkGroupedDataFrame
>>> group_keys = spark.createDataFrame(pd.DataFrame({"A": [1, 2]}))
>>> domain = SparkGroupedDataFrameDomain(
...     {
...         "A": SparkIntegerColumnDescriptor(),
...         "B": SparkIntegerColumnDescriptor(),
...     },
...     group_keys,
... )
>>> df1 = GroupedDataFrame(
...     spark.createDataFrame(
...         pd.DataFrame({"A": [1, 1, 2, 3], "B": [1, 1, 2, 4]})
...     ),
...     group_keys,
... )
>>> df2 = GroupedDataFrame(
...     spark.createDataFrame(
...         pd.DataFrame({"A": [1, 2, 2, 3], "B": [1, 3, 4, 5]})
...     ),
...     group_keys,
... )
>>> RootSumOfSquared(SymmetricDifference()).distance(df1, df2, domain)
sqrt(10)
>>> # Using HammingDistance gives a distance of infinity since the groups
>>> # are different sizes, despite the fact that the two dataframes are the
>>> # same size.
>>> RootSumOfSquared(HammingDistance()).distance(df1, df2, domain)
oo
Parameters

inner_metric (Union[AbsoluteDifference, SymmetricDifference, HammingDistance, IfGroupedBy]) –

__init__(inner_metric)

Constructor.

Parameters

inner_metric (Union[AbsoluteDifference, SymmetricDifference, HammingDistance, IfGroupedBy]) – Metric to be applied to the components.

validate(self, value)

Raises an error if value is not a valid distance.

  • value must be a nonnegative real or infinity

Parameters

value (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

property inner_metric(self)

Returns metric to be used for summing.

Return type

Union[AbsoluteDifference, SymmetricDifference, HammingDistance, IfGroupedBy]

compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

Parameters
  • value1 (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

  • value2 (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

Return type

bool

supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) – The domain to check against.

Return type

bool

distance(self, value1, value2, domain)

Return the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool

class OnColumn(column, metric)

Bases: ExactNumberMetric

The value of a metric applied to a single column treated as a vector.
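Treating a column as a vector means extracting that column from each table and passing the two resulting vectors to the aggregation metric. The sketch below does this for tables represented as lists of row dictionaries. It pairs entries by row position, an assumption made here for illustration, since Spark DataFrames have no inherent row order. `column_vector` and `on_column` are hypothetical helpers. The sketch reproduces both results of the example below, reporting the root-sum-of-squares result as its exact square.

```python
from fractions import Fraction
from typing import Any, Callable, Dict, List, Sequence

# Hypothetical representation: a table is a list of row dictionaries.
Table = List[Dict[str, Any]]
Aggregation = Callable[[Sequence[Fraction], Sequence[Fraction]], Fraction]


def column_vector(table: Table, column: str) -> List[Fraction]:
    """Extract one column, in row order, as a vector of exact numbers."""
    return [Fraction(row[column]) for row in table]


def on_column(table1: Table, table2: Table, column: str, aggregate: Aggregation) -> Fraction:
    """Apply an aggregation metric to the same column of both tables."""
    v1, v2 = column_vector(table1, column), column_vector(table2, column)
    if len(v1) != len(v2):
        raise ValueError("columns must have the same length")
    return aggregate(v1, v2)


def sum_of_absolute_differences(v1: Sequence[Fraction], v2: Sequence[Fraction]) -> Fraction:
    return sum((abs(a - b) for a, b in zip(v1, v2)), Fraction(0))


def squared_root_sum_of_squares(v1: Sequence[Fraction], v2: Sequence[Fraction]) -> Fraction:
    # Square of RootSumOfSquared(AbsoluteDifference()), kept exact.
    return sum(((a - b) ** 2 for a, b in zip(v1, v2)), Fraction(0))


value1 = [{"A": 1, "B": 3}, {"A": 23, "B": 1}]
value2 = [{"A": 2, "B": 1}, {"A": 20, "B": 8}]
print(on_column(value1, value2, "A", sum_of_absolute_differences))  # 4
print(on_column(value1, value2, "B", squared_root_sum_of_squares))  # 53, i.e. sqrt(53)
```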

Example

>>> import pandas as pd
>>> from pyspark.sql import SparkSession
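>>> from tmlt.core.domains.spark_domains import (
...     SparkDataFrameDomain,
...     SparkIntegerColumnDescriptor,
... )
>>> from tmlt.core.metrics import (
...     AbsoluteDifference,
...     OnColumn,
...     RootSumOfSquared,
...     SumOf,
... )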
>>> spark = SparkSession.builder.getOrCreate()
>>> domain = SparkDataFrameDomain(
...     {
...         "A": SparkIntegerColumnDescriptor(),
...         "B": SparkIntegerColumnDescriptor(),
...     }
... )
>>> value1 = spark.createDataFrame(
...     pd.DataFrame({"A": [1, 23], "B": [3, 1]})
... )
>>> value2 = spark.createDataFrame(
...     pd.DataFrame({"A": [2, 20], "B": [1, 8]})
... )
>>> OnColumn("A", SumOf(AbsoluteDifference())).distance(value1, value2, domain)
4
>>> OnColumn("B", RootSumOfSquared(AbsoluteDifference())).distance(
...     value1, value2, domain
... )
sqrt(53)
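Parameters
  • column (str) – The column to apply the metric to.

  • metric (Union[SumOf, RootSumOfSquared]) – The metric to apply.

__init__(column, metric)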

Constructor.

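Parameters
  • column (str) – The column to apply the metric to.

  • metric (Union[SumOf, RootSumOfSquared]) – The metric to apply.

property column(self)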

Return the column to apply the metric to.

Return type

str

property metric(self)

Return the metric to apply.

Return type

Union[SumOf, RootSumOfSquared]

validate(self, value)

Raises an error if value is not a valid distance.

  • value must be a valid distance for metric

Parameters

value (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

Parameters
  • value1 (tmlt.core.utils.exact_number.ExactNumberInput) –

  • value2 (tmlt.core.utils.exact_number.ExactNumberInput) –

Return type

bool

supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) – The domain to check against.

Return type

bool

distance(self, value1, value2, domain)

Return the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool

class OnColumns(on_columns)

Bases: Metric

A tuple containing the values of multiple OnColumn metrics.
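A tuple of per-column distances such as `(4, sqrt(53))` has no single natural ordering, so comparing two such distances needs a rule for tuples. The sketch below assumes the componentwise rule, where each entry of the first tuple must be at most the matching entry of the second. Under that rule, some pairs of distances are incomparable. `validate_tuple` and `compare_tuples` are hypothetical helpers, validity is simplified to nonnegativity, and none of this is taken from the library source.

```python
from fractions import Fraction
from typing import Sequence


def validate_tuple(value: Sequence[Fraction], num_metrics: int) -> None:
    """Hypothetical check: one nonnegative distance per OnColumn metric."""
    if len(value) != num_metrics:
        raise ValueError(f"expected {num_metrics} distances, got {len(value)}")
    for distance in value:
        if distance < 0:
            raise ValueError(f"distance {distance} is negative")


def compare_tuples(value1: Sequence[Fraction], value2: Sequence[Fraction]) -> bool:
    """Assumed componentwise order: value1 <= value2 in every position."""
    if len(value1) != len(value2):
        raise ValueError("distances must have the same number of entries")
    return all(a <= b for a, b in zip(value1, value2))


small = (Fraction(4), Fraction(7))
large = (Fraction(5), Fraction(7))
mixed = (Fraction(4), Fraction(8))
validate_tuple(small, num_metrics=2)
print(compare_tuples(small, large))                                # True
print(compare_tuples(mixed, large), compare_tuples(large, mixed))  # False False
```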

Example

>>> import pandas as pd
>>> from pyspark.sql import SparkSession
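>>> from tmlt.core.domains.spark_domains import (
...     SparkDataFrameDomain,
...     SparkIntegerColumnDescriptor,
... )
>>> from tmlt.core.metrics import (
...     AbsoluteDifference,
...     OnColumn,
...     OnColumns,
...     RootSumOfSquared,
...     SumOf,
... )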
>>> spark = SparkSession.builder.getOrCreate()
>>> domain = SparkDataFrameDomain(
...     {
...         "A": SparkIntegerColumnDescriptor(),
...         "B": SparkIntegerColumnDescriptor(),
...     }
... )
>>> metric = OnColumns(
...     [
...         OnColumn("A", SumOf(AbsoluteDifference())),
...         OnColumn("B", RootSumOfSquared(AbsoluteDifference())),
...     ]
... )
>>> value1 = spark.createDataFrame(
...     pd.DataFrame({"A": [1, 23], "B": [3, 1]})
... )
>>> value2 = spark.createDataFrame(
...     pd.DataFrame({"A": [2, 20], "B": [1, 8]})
... )
>>> metric.distance(value1, value2, domain)
(4, sqrt(53))
Parameters

on_columns (List[OnColumn]) –

__init__(on_columns)

Constructor.

Parameters

on_columns (List[OnColumn]) – The OnColumn metrics to apply.

property on_columns(self)

Return the OnColumn metrics to apply.

Return type

List[OnColumn]

validate(self, value)

Raises an error if value is not a valid distance.

  • value must be a tuple with one value for each metric in on_columns

  • each value must be a valid distance for the corresponding metric

Parameters

value (Tuple[tmlt.core.utils.exact_number.ExactNumberInput, ...]) – A distance between two datasets under this metric.

compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

Parameters
  • value1 (Tuple[tmlt.core.utils.exact_number.ExactNumberInput, ...]) – A distance between two datasets under this metric.

  • value2 (Tuple[tmlt.core.utils.exact_number.ExactNumberInput, ...]) – A distance between two datasets under this metric.

Return type

bool

supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) – The domain to check against.

Return type

bool

distance(self, value1, value2, domain)

Return the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type

Tuple[tmlt.core.utils.exact_number.ExactNumber, …]

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool

class IfGroupedBy(column, inner_metric)

Bases: ExactNumberMetric

Distance between two DataFrames that shall be grouped by a given attribute.

The distance under this metric is an upper bound on the distance between the grouped DataFrames for any fixed set of grouping keys. This relies on the distance between two empty groups being zero, so the inner metric must satisfy this property.

The grouping column cannot contain floating point values.
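The grouping step can be sketched in plain Python in four stages. First, split both tables by the value of the grouping column. Second, treat a key that is missing from one table as an empty group. Third, compute the inner per-group distance. Finally, aggregate the per-group distances. This hypothetical sketch hard-codes the `RootSumOfSquared(SymmetricDifference())` case from the first example below and returns the squared distance exactly. It illustrates the idea and is not the library's implementation.

```python
from collections import Counter, defaultdict
from fractions import Fraction
from typing import Dict, Hashable, Sequence, Tuple

Row = Tuple[Hashable, ...]


def group_by(rows: Sequence[Row], key_index: int) -> Dict[Hashable, Counter]:
    """Map each key in the grouping column to the multiset of its rows."""
    groups: Dict[Hashable, Counter] = defaultdict(Counter)
    for row in rows:
        groups[row[key_index]][row] += 1
    return groups


def squared_if_grouped_by_rss(
    rows1: Sequence[Row], rows2: Sequence[Row], key_index: int
) -> Fraction:
    """Square of IfGroupedBy(column, RootSumOfSquared(SymmetricDifference()))."""
    groups1, groups2 = group_by(rows1, key_index), group_by(rows2, key_index)
    total = Fraction(0)
    # Visit keys present in either table; a missing key behaves as an empty group.
    # Keys absent from both tables are two empty groups and contribute 0.
    for key in set(groups1) | set(groups2):
        g1, g2 = groups1.get(key, Counter()), groups2.get(key, Counter())
        per_group = sum(((g1 - g2) + (g2 - g1)).values())
        total += per_group ** 2
    return total


# (A, B, C) rows from the first example below, grouped by C (index 2).
value1 = [(1, 2, 1), (1, 1, 1), (3, 4, 2)]
value2 = [(2, 1, 1), (1, 1, 1)]
print(squared_if_grouped_by_rss(value1, value2, key_index=2))  # 5, i.e. sqrt(5)
```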

Examples

>>> import pandas as pd
>>> from pyspark.sql import SparkSession
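>>> from tmlt.core.domains.spark_domains import (
...     SparkDataFrameDomain,
...     SparkIntegerColumnDescriptor,
... )
>>> from tmlt.core.metrics import (
...     IfGroupedBy,
...     RootSumOfSquared,
...     SymmetricDifference,
... )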
>>> spark = SparkSession.builder.getOrCreate()
>>> domain = SparkDataFrameDomain(
...     {
...         "A": SparkIntegerColumnDescriptor(),
...         "B": SparkIntegerColumnDescriptor(),
...         "C": SparkIntegerColumnDescriptor(),
...     },
... )
>>> metric = IfGroupedBy("C", RootSumOfSquared(SymmetricDifference()))
>>> value1 = spark.createDataFrame(
...     pd.DataFrame({"A": [1, 1, 3], "B": [2, 1, 4], "C": [1, 1, 2]}),
... )
>>> value2 = spark.createDataFrame(
...     pd.DataFrame({"A": [2, 1], "B": [1, 1], "C": [1, 1]})
... )
>>> metric.distance(value1, value2, domain)
sqrt(5)
>>> metric = IfGroupedBy("C", SymmetricDifference())
>>> value1 = spark.createDataFrame(
...     pd.DataFrame({"A": [1, 1, 3], "B": [2, 1, 4], "C": [1, 1, 2]}),
... )
>>> value2 = spark.createDataFrame(
...     pd.DataFrame({"A": [1, 1], "B": [2, 1], "C": [1, 1]})
... )
>>> metric.distance(value1, value2, domain)
1
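Parameters
  • column (str) – Column that DataFrame shall be grouped by.

  • inner_metric (Union[SumOf, RootSumOfSquared, SymmetricDifference]) – Metric to be applied for corresponding groups.

__init__(column, inner_metric)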

Constructor.

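Parameters
  • column (str) – Column that DataFrame shall be grouped by.

  • inner_metric (Union[SumOf, RootSumOfSquared, SymmetricDifference]) – Metric to be applied for corresponding groups.

property column(self)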

Column that DataFrame shall be grouped by.

Return type

str

property inner_metric(self)

Metric to be applied for corresponding groups.

Return type

Union[SumOf, RootSumOfSquared, SymmetricDifference]

validate(self, value)

Raises an error if value is not a valid distance.

Parameters

value (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

Parameters
  • value1 (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

  • value2 (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

Return type

bool

supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) – The domain to check against.

Return type

bool

distance(self, value1, value2, domain)

Return the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool

class DictMetric(key_to_metric)

Bases: Metric

Distance between two dictionaries with identical sets of keys.
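A dictionary-valued distance is one distance per key, each computed by that key's own metric. In the sketch below, metrics are modeled as plain Python functions. The dictionary values are simplified to an integer and a list of (A, B) row tuples that mirror the example below. `dict_distance` and `dict_compare` are hypothetical helpers, and the key-by-key comparison rule is an assumption made for illustration.

```python
from collections import Counter
from typing import Any, Callable, Dict

MetricFn = Callable[[Any, Any], int]


def absolute_difference(a: int, b: int) -> int:
    return abs(a - b)


def symmetric_difference(rows1: Any, rows2: Any) -> int:
    c1, c2 = Counter(rows1), Counter(rows2)
    return sum(((c1 - c2) + (c2 - c1)).values())


def dict_distance(
    key_to_metric: Dict[Any, MetricFn], value1: Dict[Any, Any], value2: Dict[Any, Any]
) -> Dict[Any, int]:
    """Apply each key's metric to the matching entries; keys must be identical."""
    if set(value1) != set(key_to_metric) or set(value2) != set(key_to_metric):
        raise ValueError("values must have exactly the metric's keys")
    return {key: metric(value1[key], value2[key]) for key, metric in key_to_metric.items()}


def dict_compare(d1: Dict[Any, int], d2: Dict[Any, int]) -> bool:
    """Assumed rule: d1 <= d2 when it holds for every key."""
    return d1.keys() == d2.keys() and all(d1[k] <= d2[k] for k in d1)


metric = {"x": absolute_difference, "y": symmetric_difference}
value1 = {"x": 1, "y": [(1, 2), (1, 1), (3, 4)]}
value2 = {"x": 10, "y": [(2, 1), (1, 1)]}
distance = dict_distance(metric, value1, value2)
print(distance)                                   # {'x': 9, 'y': 3}
print(dict_compare(distance, {"x": 10, "y": 3}))  # True
```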

Example

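>>> import numpy as np
>>> import pandas as pd
>>> from pyspark.sql import SparkSession
>>> from tmlt.core.domains.collections import DictDomain
>>> from tmlt.core.domains.numpy_domains import NumpyIntegerDomain
>>> from tmlt.core.domains.spark_domains import (
...     SparkDataFrameDomain,
...     SparkIntegerColumnDescriptor,
... )
>>> from tmlt.core.metrics import (
...     AbsoluteDifference,
...     DictMetric,
...     SymmetricDifference,
... )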
>>> spark = SparkSession.builder.getOrCreate()
>>> metric = DictMetric(
...     {"x": AbsoluteDifference(), "y": SymmetricDifference()}
... )
>>> domain = DictDomain(
...     {
...         "x": NumpyIntegerDomain(),
...         "y": SparkDataFrameDomain(
...             {
...                 "A": SparkIntegerColumnDescriptor(),
...                 "B": SparkIntegerColumnDescriptor(),
...             }
...         ),
...     }
... )
>>> df1 = spark.createDataFrame(
...     pd.DataFrame({"A": [1, 1, 3], "B": [2, 1, 4]})
... )
>>> df2 = spark.createDataFrame(pd.DataFrame({"A": [2, 1], "B": [1, 1]}))
>>> value1 = {"x": np.int64(1), "y": df1}
>>> value2 = {"x": np.int64(10), "y": df2}
>>> metric.distance(value1, value2, domain)
{'x': 9, 'y': 3}
Parameters

key_to_metric (Dict[Any, Metric]) –

__init__(key_to_metric)

Constructor.

Parameters

key_to_metric (Dict[Any, Metric]) – Mapping from dictionary key to metric.

property key_to_metric(self)

Returns mapping from keys to metrics.

Return type

Dict[Any, Metric]

validate(self, value)

Raises an error if value is not a valid distance.

  • value must be a dictionary with the same keys as key_to_metric

  • each value in the dictionary must be a valid distance under the corresponding metric

Parameters

value (Dict[Any, Any]) – A distance between two datasets under this metric.

compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

Parameters
  • value1 (Dict[Any, Any]) – A distance between two datasets under this metric.

  • value2 (Dict[Any, Any]) – A distance between two datasets under this metric.

Return type

bool

supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) – The domain to check against.

Return type

bool

distance(self, value1, value2, domain)

Return the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type

Dict[Any, Any]

__getitem__(self, key)

Returns metric associated with given key.

Parameters

key (Any) –

Return type

Metric

__len__(self)

Returns number of keys in the metric.

Return type

int

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool

class AddRemoveKeys(column)

Bases: Metric

The number of keys by which two dictionaries of dataframes differ.

This metric can be thought of as an extension of IfGroupedBy with inner metric SymmetricDifference, except that it is applied to a dictionary of dataframes instead of a single dataframe.

Both IfGroupedBy(X, SymmetricDifference()) and AddRemoveKeys(X) can be described in the following way:

Sum, over each key that appears in column X in either neighbor, of:

  • 0 if both neighbors “match” for X = key

  • 1 if only one neighbor has records for X = key

  • 2 if both neighbors have records for X = key, but they don’t “match”

The key column cannot contain floating point values, and all dataframes must have the same type for the key column.

Examples

>>> import pandas as pd
>>> from pyspark.sql import SparkSession
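>>> from tmlt.core.domains.collections import DictDomain
>>> from tmlt.core.domains.spark_domains import (
...     SparkDataFrameDomain,
...     SparkIntegerColumnDescriptor,
...     SparkStringColumnDescriptor,
... )
>>> from tmlt.core.metrics import AddRemoveKeys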
>>> spark = SparkSession.builder.getOrCreate()
>>> domain = DictDomain(
...     {
...         1: SparkDataFrameDomain(
...             {
...                 "A": SparkIntegerColumnDescriptor(),
...                 "B": SparkIntegerColumnDescriptor(),
...             },
...         ),
...         2: SparkDataFrameDomain(
...             {
...                 "A": SparkIntegerColumnDescriptor(),
...                 "C": SparkStringColumnDescriptor(),
...             },
...         ),
...     }
... )
>>> metric = AddRemoveKeys("A")
>>> # A=1 matches, A=2 is only in value1, A=3 is only in value2, A=4 differs
>>> value1 = {
...     1: spark.createDataFrame(
...         pd.DataFrame(
...             {
...                 "A": [1, 1, 2],
...                 "B": [1, 1, 1],
...             }
...         )
...     ),
...     2: spark.createDataFrame(
...         pd.DataFrame(
...             {
...                 "A": [1, 4],
...                 "C": ["1", "1"],
...             }
...         )
...     )
... }
>>> value2 = {
...     1: spark.createDataFrame(
...         pd.DataFrame(
...             {
...                 "A": [1, 1, 3],
...                 "B": [1, 1, 1],
...             }
...         )
...     ),
...     2: spark.createDataFrame(
...         pd.DataFrame(
...             {
...                 "A": [1, 4],
...                 "C": ["1", "2"],
...             }
...         )
...     )
... }
>>> metric.distance(value1, value2, domain)
4
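The per-key rule can be sketched in plain Python by collecting, for every key, the records from each table in the dictionary. In this sketch, a key "matches" when every table holds the same multiset of records for it in both neighbors. That interpretation of "match", the list-of-tuples representation, and the helper names `records_for_key` and `add_remove_keys` are assumptions for illustration. The data mirrors the example above, where the key column is at index 0 of every table.

```python
from collections import Counter
from typing import Dict, Hashable, List, Tuple

Row = Tuple[Hashable, ...]
Tables = Dict[Hashable, List[Row]]


def records_for_key(tables: Tables, key: Hashable, key_index: int) -> Dict[Hashable, Counter]:
    """For each table, the multiset of its rows whose key column equals key."""
    return {
        name: Counter(row for row in rows if row[key_index] == key)
        for name, rows in tables.items()
    }


def add_remove_keys(value1: Tables, value2: Tables, key_index: int = 0) -> int:
    """Sum the 0/1/2 contribution of every key that appears in either neighbor."""
    keys = {
        row[key_index]
        for tables in (value1, value2)
        for rows in tables.values()
        for row in rows
    }
    total = 0
    for key in keys:
        records1 = records_for_key(value1, key, key_index)
        records2 = records_for_key(value2, key, key_index)
        has1 = any(records1.values())  # an empty Counter is falsy
        has2 = any(records2.values())
        if has1 != has2:
            total += 1  # only one neighbor has records for this key
        elif records1 != records2:
            total += 2  # both have records, but they do not match
    return total


value1 = {1: [(1, 1), (1, 1), (2, 1)], 2: [(1, "1"), (4, "1")]}
value2 = {1: [(1, 1), (1, 1), (3, 1)], 2: [(1, "1"), (4, "2")]}
print(add_remove_keys(value1, value2))  # 4
```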
Parameters

column (str) –

__init__(column)

Constructor.

Parameters

column (str) – The column defining the keys.

property column(self)

Returns the key column.

Return type

str

validate(self, value)

Raises an error if value is not a valid distance.

  • value must be a nonnegative real number or infinity

Parameters

value (tmlt.core.utils.exact_number.ExactNumberInput) – A distance between two datasets under this metric.

compare(self, value1, value2)

Returns True if value1 is less than or equal to value2.

Parameters
  • value1 (tmlt.core.utils.exact_number.ExactNumberInput) –

  • value2 (tmlt.core.utils.exact_number.ExactNumberInput) –

Return type

bool

supports_domain(self, domain)

Return True if the metric is implemented for the passed domain.

Parameters

domain (tmlt.core.domains.base.Domain) – The domain to check against.

Return type

bool

distance(self, value1, value2, domain)

Return the metric distance between two elements of a supported domain.

Parameters
  • value1 (Any) – An element of the domain.

  • value2 (Any) – An element of the domain.

  • domain (tmlt.core.domains.base.Domain) – A domain compatible with the metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

__eq__(self, other)

Return True if both metrics are equal.

Parameters

other (Any) –

Return type

bool