truncation#

Transformations for truncating Spark DataFrames.

Classes#

LimitRowsPerGroup

Keep at most k rows per group.

LimitKeysPerGroup

Keep at most k keys per group.

class LimitRowsPerGroup(input_domain, grouping_column, threshold)#

Bases: tmlt.core.transformations.base.Transformation

Keep at most k rows per group.

See truncate_large_groups() for more information about truncation.

Example

>>> # Example input
>>> print_sdf(spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b1
6  a4  b2
7  a4  b3
8  a4  b4
>>> truncate = LimitRowsPerGroup(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     grouping_column="A",
...     threshold=2,
... )
>>> # Apply transformation to data
>>> truncated_spark_dataframe = truncate(spark_dataframe)
>>> print_sdf(truncated_spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a4  b2
5  a4  b3
Transformation Contract:
>>> truncate.input_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.output_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.input_metric
IfGroupedBy(column='A', inner_metric=SymmetricDifference())
>>> truncate.output_metric
SymmetricDifference()
Stability Guarantee:

LimitRowsPerGroup’s stability_function() returns threshold * d_in.

>>> truncate.stability_function(1)
2
>>> truncate.stability_function(2)
4
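
For intuition, the effect of this transformation can be reproduced with an ordinary Spark window query. The following is a minimal sketch, not the library's implementation: truncate_large_groups() decides deterministically which rows survive, and the hash-based ordering used here is only an illustrative stand-in.

# Minimal sketch of "keep at most `threshold` rows per group" with plain Spark.
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("a1", "b1"), ("a2", "b1"), ("a3", "b2"), ("a3", "b2"), ("a3", "b2"),
     ("a4", "b1"), ("a4", "b2"), ("a4", "b3"), ("a4", "b4")],
    ["A", "B"],
)
threshold = 2

# Rank rows within each group "A" in an arbitrary but stable order, then keep
# at most `threshold` rows per group.
rank = F.row_number().over(Window.partitionBy("A").orderBy(F.hash(*df.columns)))
truncated = (
    df.withColumn("_rank", rank)
    .filter(F.col("_rank") <= threshold)
    .drop("_rank")
)
truncated.show()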
__init__(input_domain, grouping_column, threshold)#

Constructor.

Parameters
  • input_domain (SparkDataFrameDomain) – Domain of input DataFrames.

  • grouping_column (str) – Name of the column defining the groups to truncate.

  • threshold (int) – The maximum number of rows per group after truncation.
property grouping_column(self)#

Returns the column defining the groups to truncate.

Return type

str

property threshold(self)#

Returns the maximum number of rows per group after truncation.

Return type

int

stability_function(self, d_in)#

Returns the smallest d_out satisfied by the transformation.

See the privacy and stability tutorial for more information.

Parameters

d_in (tmlt.core.utils.exact_number.ExactNumberInput) – Distance between inputs under input_metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

__call__(self, sdf)#

Returns a truncated dataframe.

Parameters

sdf (pyspark.sql.DataFrame) – DataFrame to truncate.

Return type

pyspark.sql.DataFrame

property input_domain(self)#

Return input domain for the transformation.

Return type

tmlt.core.domains.base.Domain

property input_metric(self)#

Distance metric on input domain.

Return type

tmlt.core.metrics.Metric

property output_domain(self)#

Return output domain for the transformation.

Return type

tmlt.core.domains.base.Domain

property output_metric(self)#

Distance metric on output domain.

Return type

tmlt.core.metrics.Metric

stability_relation(self, d_in, d_out)#

Returns True only if close inputs produce close outputs.

See the privacy and stability tutorial for more information.

Parameters
  • d_in (Any) – Distance between inputs under input_metric.

  • d_out (Any) – Distance between outputs under output_metric.

Return type

bool
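
Assuming the usual relationship between the two stability methods in tmlt.core, stability_relation(d_in, d_out) holds exactly when stability_function(d_in) is at most d_out. Reusing the truncate object from the example above (threshold=2):

# Hedged illustration: the relation should agree with the stability function.
assert truncate.stability_relation(1, 2)      # stability_function(1) == 2 <= 2
assert not truncate.stability_relation(2, 3)  # stability_function(2) == 4 >  3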

__or__(self, other: Transformation) -> Transformation#
__or__(self, other: tmlt.core.measurements.base.Measurement) -> tmlt.core.measurements.base.Measurement

Return this transformation chained with another component.
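
As a hedged sketch of chaining: assuming tmlt.core provides an Identity transformation (the import path and constructor used below are assumptions), any component whose input domain and metric match this transformation's output domain and metric can be chained on the right of |:

# Assumption: tmlt.core.transformations.identity.Identity exists and accepts
# keyword arguments `metric` and `domain`. Any compatible transformation or
# measurement could stand in its place.
from tmlt.core.transformations.identity import Identity

chained = truncate | Identity(
    metric=truncate.output_metric, domain=truncate.output_domain
)
# Chaining composes stabilities; an identity stage adds nothing.
assert chained.stability_function(1) == truncate.stability_function(1)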

class LimitKeysPerGroup(input_domain, grouping_column, key_column, threshold, use_l2)#

Bases: tmlt.core.transformations.base.Transformation

Keep at most k keys per group.

See limit_keys_per_group() for more information about truncation.

Example

>>> # Example input
>>> print_sdf(spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b1
6  a4  b2
7  a4  b3
8  a4  b4
>>> truncate = LimitKeysPerGroup(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     grouping_column="A",
...     key_column="B",
...     threshold=2,
...     use_l2=False,
... )
>>> # Apply transformation to data
>>> truncated_spark_dataframe = truncate(spark_dataframe)
>>> print_sdf(truncated_spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b2
6  a4  b3
Transformation Contract:
>>> truncate.input_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.output_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.input_metric
IfGroupedBy(column='A', inner_metric=SymmetricDifference())
>>> truncate.output_metric
IfGroupedBy(column='B', inner_metric=SumOf(inner_metric=IfGroupedBy(column='A', inner_metric=SymmetricDifference())))
Stability Guarantee:

LimitKeysPerGroup’s stability_function() returns threshold * d_in if use_l2 is False and sqrt(threshold) * d_in otherwise.

>>> truncate.stability_function(1)
2
>>> truncate.stability_function(2)
4
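
For intuition, the effect of this truncation on the example above can be approximated with ordinary Spark operations. The following is a minimal sketch, not the library's implementation: limit_keys_per_group() decides deterministically which keys survive, and the hash-based ordering used here is only an illustrative stand-in (spark_dataframe is the example input shown above).

from pyspark.sql import Window
from pyspark.sql import functions as F

threshold = 2

# For each group "A", rank its distinct keys "B" in an arbitrary but stable
# order and keep at most `threshold` of them.
kept_keys = (
    spark_dataframe.select("A", "B")
    .distinct()
    .withColumn(
        "_rank",
        F.row_number().over(Window.partitionBy("A").orderBy(F.hash("B"))),
    )
    .filter(F.col("_rank") <= threshold)
    .drop("_rank")
)

# Keep only rows whose (A, B) pair survived the key selection.
truncated = spark_dataframe.join(kept_keys, on=["A", "B"], how="leftsemi")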
Methods#

grouping_column()

Returns the column defining the groups to truncate.

key_column()

Returns the column defining the keys.

threshold()

Returns the maximum number of keys per group after truncation.

use_l2()

Returns whether the output metric will use RootSumOfSquared.

stability_function()

Returns the smallest d_out satisfied by the transformation.

__call__()

Returns a truncated dataframe.

input_domain()

Return input domain for the transformation.

input_metric()

Distance metric on input domain.

output_domain()

Return output domain for the transformation.

output_metric()

Distance metric on output domain.

stability_relation()

Returns True only if close inputs produce close outputs.

__or__()

Return this transformation chained with another component.

__init__(input_domain, grouping_column, key_column, threshold, use_l2)#

Constructor.

Parameters
  • input_domain (SparkDataFrameDomain) – Domain of input DataFrames.

  • grouping_column (str) – Name of the column defining the groups to truncate.

  • key_column (str) – Name of the column defining the keys.

  • threshold (int) – The maximum number of keys per group after truncation.

  • use_l2 (bool) – If True, the output metric uses RootSumOfSquared instead of SumOf.
property grouping_column(self)#

Returns the column defining the groups to truncate.

Return type

str

property key_column(self)#

Returns the column defining the keys.

Return type

str

property threshold(self)#

Returns the maximum number of keys per group after truncation.

Return type

int

property use_l2(self)#

Returns whether the output metric will use RootSumOfSquared.

Return type

bool
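
As a hedged illustration of this flag, the snippet below rebuilds the example transformation with use_l2=True; per the stability guarantee above, the bound becomes sqrt(threshold) * d_in, returned as an ExactNumber.

# Same construction as the example, but with use_l2=True: the output metric
# uses RootSumOfSquared and the guarantee becomes sqrt(threshold) * d_in.
truncate_l2 = LimitKeysPerGroup(
    input_domain=SparkDataFrameDomain(
        {"A": SparkStringColumnDescriptor(), "B": SparkStringColumnDescriptor()}
    ),
    grouping_column="A",
    key_column="B",
    threshold=2,
    use_l2=True,
)
print(truncate_l2.stability_function(1))  # sqrt(2), represented exactly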

stability_function(self, d_in)#

Returns the smallest d_out satisfied by the transformation.

See the privacy and stability tutorial for more information.

Parameters

d_in (tmlt.core.utils.exact_number.ExactNumberInput) – Distance between inputs under input_metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

__call__(self, sdf)#

Returns a truncated dataframe.

Parameters

sdf (pyspark.sql.DataFrame) – DataFrame to truncate.

Return type

pyspark.sql.DataFrame

property input_domain(self)#

Return input domain for the transformation.

Return type

tmlt.core.domains.base.Domain

property input_metric(self)#

Distance metric on input domain.

Return type

tmlt.core.metrics.Metric

property output_domain(self)#

Return output domain for the transformation.

Return type

tmlt.core.domains.base.Domain

property output_metric(self)#

Distance metric on output domain.

Return type

tmlt.core.metrics.Metric

stability_relation(self, d_in, d_out)#

Returns True only if close inputs produce close outputs.

See the privacy and stability tutorial for more information.

Parameters
  • d_in (Any) – Distance between inputs under input_metric.

  • d_out (Any) – Distance between outputs under output_metric.

Return type

bool

__or__(self, other: Transformation) -> Transformation#
__or__(self, other: tmlt.core.measurements.base.Measurement) -> tmlt.core.measurements.base.Measurement

Return this transformation chained with another component.