truncation#

Transformations for truncating Spark DataFrames.

Classes#

LimitRowsPerGroup

Keep at most k rows per group.

LimitKeysPerGroup

Keep at most k keys per group.

LimitRowsPerKeyPerGroup

For each group, keep at most k rows per key.

class LimitRowsPerGroup(input_domain, output_metric, grouping_column, threshold)#

Bases: tmlt.core.transformations.base.Transformation

Keep at most k rows per group.

See truncate_large_groups() for more information about truncation.
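The per-group row limit can be pictured with plain Python tuples. This is only a sketch, not the library's implementation: `limit_rows_per_group` is a hypothetical helper, and it keeps the first `threshold` rows of each group in input order. The library may choose different surviving rows, as the example output in this section suggests.

```python
from collections import Counter, defaultdict


def limit_rows_per_group(rows, grouping_index, threshold):
    """Keep at most `threshold` rows for each value of the grouping column.

    Hypothetical helper for illustration; rows are plain tuples.
    """
    kept_per_group = defaultdict(int)
    result = []
    for row in rows:
        group = row[grouping_index]
        if kept_per_group[group] < threshold:
            kept_per_group[group] += 1
            result.append(row)
    return result


rows = [
    ("a1", "b1"), ("a2", "b1"),
    ("a3", "b2"), ("a3", "b2"), ("a3", "b2"),
    ("a4", "b1"), ("a4", "b2"), ("a4", "b3"), ("a4", "b4"),
]
truncated = limit_rows_per_group(rows, grouping_index=0, threshold=2)
rows_per_group = Counter(row[0] for row in truncated)
print(dict(rows_per_group))
```

No group ends up with more than two rows, which is the property the transformation guarantees; only the choice of surviving rows is an assumption of the sketch.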

Example

>>> # Example input
>>> print_sdf(spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b1
6  a4  b2
7  a4  b3
8  a4  b4
>>> truncate = LimitRowsPerGroup(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     output_metric=SymmetricDifference(),
...     grouping_column="A",
...     threshold=2,
... )
>>> # Apply transformation to data
>>> truncated_spark_dataframe = truncate(spark_dataframe)
>>> print_sdf(truncated_spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a4  b3
5  a4  b4
Transformation Contract:
>>> truncate.input_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.output_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.input_metric
IfGroupedBy(column='A', inner_metric=SymmetricDifference())
>>> truncate.output_metric
SymmetricDifference()
Stability Guarantee:

LimitRowsPerGroup’s stability_function() returns threshold * d_in if output_metric is SymmetricDifference() and d_in otherwise.

>>> truncate.stability_function(1)
2
>>> truncate.stability_function(2)
4
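The stability rule stated above can be written out as a small function. This is a minimal sketch of the rule as described, using plain integers instead of the library's exact numbers; `rows_per_group_stability` and its boolean flag are hypothetical names, not part of the API.

```python
def rows_per_group_stability(d_in, threshold, output_is_symmetric_difference):
    """Mirror the documented rule for LimitRowsPerGroup (illustrative only)."""
    if output_is_symmetric_difference:
        # Each group that differs can contribute up to `threshold` rows.
        return threshold * d_in
    return d_in


print(rows_per_group_stability(1, threshold=2, output_is_symmetric_difference=True))
print(rows_per_group_stability(2, threshold=2, output_is_symmetric_difference=True))
```

With `threshold=2` these calls reproduce the values 2 and 4 shown in the doctest above.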
__init__(input_domain, output_metric, grouping_column, threshold)#

Constructor.

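Parameters
  • input_domain (SparkDataFrameDomain) – Domain of input DataFrame.

  • output_metric – Distance metric for output DataFrames.

  • grouping_column (str) – Name of column defining the groups to truncate.

  • threshold (int) – The maximum number of rows per group after truncation.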
property grouping_column#

Returns the column defining the groups to truncate.

Return type

str

property threshold#

Returns the maximum number of rows per group after truncation.

Return type

int

stability_function(d_in)#

Returns the smallest d_out satisfied by the transformation.

See the architecture overview for more information.

Parameters

d_in (tmlt.core.utils.exact_number.ExactNumberInput) – Distance between inputs under input_metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

__call__(sdf)#

Returns a truncated dataframe.

Parameters

sdf (pyspark.sql.DataFrame) –

Return type

pyspark.sql.DataFrame

property input_domain#

Return input domain for the transformation.

Return type

tmlt.core.domains.base.Domain

property input_metric#

Distance metric on input domain.

Return type

tmlt.core.metrics.Metric

property output_domain#

Return output domain for the transformation.

Return type

tmlt.core.domains.base.Domain

property output_metric#

Distance metric on output domain.

Return type

tmlt.core.metrics.Metric

stability_relation(d_in, d_out)#

Returns True only if close inputs produce close outputs.

See the privacy and stability tutorial for more information.

Parameters
  • d_in (Any) – Distance between inputs under input_metric.

  • d_out (Any) – Distance between outputs under output_metric.

Return type

bool
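Since stability_function() is documented as returning the smallest d_out satisfied by the transformation, one natural reading is that the relation holds exactly when that value does not exceed d_out. The sketch below assumes this reading and plain numeric distances; both function names are hypothetical and this is not the library's code.

```python
def stability_relation_from_function(stability_function, d_in, d_out):
    """Return True if the smallest satisfied d_out is at most the given d_out.

    Assumes distances are plain numbers that can be compared with <=.
    """
    return stability_function(d_in) <= d_out


def example_stability_function(d_in):
    # Illustrative rule: threshold * d_in with threshold = 2.
    return 2 * d_in


print(stability_relation_from_function(example_stability_function, 1, 2))
print(stability_relation_from_function(example_stability_function, 1, 1))
```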

__or__(other: Transformation) → Transformation#
__or__(other: tmlt.core.measurements.base.Measurement) → tmlt.core.measurements.base.Measurement

Return this transformation chained with another component.
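To illustrate what chaining with `|` means, here is a toy model covering only the transformation-to-transformation case. `ToyTransformation` is a hypothetical class, not the library's `Transformation`, and the sketch assumes the chained stability function is the composition of the two individual ones.

```python
class ToyTransformation:
    """Toy stand-in: a callable paired with a stability function."""

    def __init__(self, function, stability_function):
        self._function = function
        self.stability_function = stability_function

    def __call__(self, data):
        return self._function(data)

    def __or__(self, other):
        # Apply `self` first, then `other`; compose the stability functions
        # in the same order (an assumption of this sketch).
        return ToyTransformation(
            lambda data: other(self(data)),
            lambda d_in: other.stability_function(self.stability_function(d_in)),
        )


drop_negatives = ToyTransformation(
    lambda xs: [x for x in xs if x >= 0], lambda d_in: d_in
)
duplicate = ToyTransformation(lambda xs: xs + xs, lambda d_in: 2 * d_in)

chained = drop_negatives | duplicate
print(chained([1, -2, 3]))
print(chained.stability_function(1))
```

The left operand runs first, so `chained` filters and then duplicates.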

class LimitKeysPerGroup(input_domain, output_metric, grouping_column, key_column, threshold)#

Bases: tmlt.core.transformations.base.Transformation

Keep at most k keys per group.

See limit_keys_per_group() for more information about truncation.
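Limiting keys differs from limiting rows: every row whose key survives is kept, however many times that key repeats. The sketch below shows this with plain tuples. `limit_keys_per_group` here is a hypothetical helper that keeps the first `threshold` distinct keys seen in each group; which keys the library keeps is not assumed to match.

```python
def limit_keys_per_group(rows, grouping_index, key_index, threshold):
    """Keep rows belonging to at most `threshold` distinct keys per group.

    Illustrative helper; keys are retained in order of first appearance.
    """
    kept_keys = {}
    result = []
    for row in rows:
        group, key = row[grouping_index], row[key_index]
        keys = kept_keys.setdefault(group, set())
        if key in keys or len(keys) < threshold:
            keys.add(key)
            result.append(row)
    return result


rows = [
    ("a1", "b1"), ("a2", "b1"),
    ("a3", "b2"), ("a3", "b2"), ("a3", "b2"),
    ("a4", "b1"), ("a4", "b2"), ("a4", "b3"), ("a4", "b4"),
]
truncated = limit_keys_per_group(rows, grouping_index=0, key_index=1, threshold=2)
a3_rows = [row for row in truncated if row[0] == "a3"]
a4_keys = {row[1] for row in truncated if row[0] == "a4"}
print(len(a3_rows), len(a4_keys))
```

Group a3 keeps all three rows because they share one key, while group a4 is cut down to two distinct keys.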

Example

>>> # Example input
>>> print_sdf(spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b1
6  a4  b2
7  a4  b3
8  a4  b4
>>> truncate = LimitKeysPerGroup(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     output_metric=IfGroupedBy("B", SumOf(IfGroupedBy("A", SymmetricDifference()))),
...     grouping_column="A",
...     key_column="B",
...     threshold=2,
... )
>>> # Apply transformation to data
>>> truncated_spark_dataframe = truncate(spark_dataframe)
>>> print_sdf(truncated_spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b3
6  a4  b4
Transformation Contract:
>>> truncate.input_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.output_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.input_metric
IfGroupedBy(column='A', inner_metric=SymmetricDifference())
>>> truncate.output_metric
IfGroupedBy(column='B', inner_metric=SumOf(inner_metric=IfGroupedBy(column='A', inner_metric=SymmetricDifference())))
Stability Guarantee:

LimitKeysPerGroup’s stability_function() returns d_in if output_metric is IfGroupedBy(grouping_column, SymmetricDifference()), sqrt(threshold) * d_in if output_metric is IfGroupedBy(key_column, RootSumOfSquared(IfGroupedBy(grouping_column, SymmetricDifference()))), and threshold * d_in otherwise.

>>> truncate.stability_function(1)
2
>>> truncate.stability_function(2)
4
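The three cases of the stability rule above can be laid out side by side. This is a sketch of the rule as stated, with two simplifications that are assumptions of the example: the output metric is selected by a hypothetical string label, and the square root is a float rather than an exact value.

```python
import math


def keys_per_group_stability(d_in, threshold, output_metric_kind):
    """Mirror the documented rule for LimitKeysPerGroup (illustrative only).

    output_metric_kind is a hypothetical label:
      "grouped"             -> IfGroupedBy(grouping_column, SymmetricDifference())
      "root_sum_of_squared" -> IfGroupedBy(key_column, RootSumOfSquared(...))
      anything else         -> treated as the SumOf case
    """
    if output_metric_kind == "grouped":
        return d_in
    if output_metric_kind == "root_sum_of_squared":
        return math.sqrt(threshold) * d_in
    return threshold * d_in


print(keys_per_group_stability(1, 2, "sum_of"))
print(keys_per_group_stability(3, 4, "root_sum_of_squared"))
print(keys_per_group_stability(5, 4, "grouped"))
```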
Methods#

grouping_column()

Returns the column defining the groups to truncate.

key_column()

Returns the column defining the keys.

threshold()

Returns the maximum number of keys per group after truncation.

stability_function()

Returns the smallest d_out satisfied by the transformation.

__call__()

Returns a truncated dataframe.

input_domain()

Return input domain for the transformation.

input_metric()

Distance metric on input domain.

output_domain()

Return output domain for the transformation.

output_metric()

Distance metric on output domain.

stability_relation()

Returns True only if close inputs produce close outputs.

__or__()

Return this transformation chained with another component.

__init__(input_domain, output_metric, grouping_column, key_column, threshold)#

Constructor.

Parameters
  • input_domain (SparkDataFrameDomain) – Domain of input DataFrame.

  • output_metric (IfGroupedBy) – Distance metric for output DataFrames. This should be one of IfGroupedBy(key_column, SumOf(IfGroupedBy(grouping_column, SymmetricDifference()))), IfGroupedBy(key_column, RootSumOfSquared(IfGroupedBy(grouping_column, SymmetricDifference()))), or IfGroupedBy(grouping_column, SymmetricDifference()).

  • grouping_column (str) – Name of column defining the groups to truncate.

  • key_column (str) – Name of column defining the keys.

  • threshold (int) – The maximum number of keys per group after truncation.

property grouping_column#

Returns the column defining the groups to truncate.

Return type

str

property key_column#

Returns the column defining the keys.

Return type

str

property threshold#

Returns the maximum number of keys per group after truncation.

Return type

int

stability_function(d_in)#

Returns the smallest d_out satisfied by the transformation.

See the architecture overview for more information.

Parameters

d_in (tmlt.core.utils.exact_number.ExactNumberInput) – Distance between inputs under input_metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

__call__(sdf)#

Returns a truncated dataframe.

Parameters

sdf (pyspark.sql.DataFrame) –

Return type

pyspark.sql.DataFrame

property input_domain#

Return input domain for the transformation.

Return type

tmlt.core.domains.base.Domain

property input_metric#

Distance metric on input domain.

Return type

tmlt.core.metrics.Metric

property output_domain#

Return output domain for the transformation.

Return type

tmlt.core.domains.base.Domain

property output_metric#

Distance metric on output domain.

Return type

tmlt.core.metrics.Metric

stability_relation(d_in, d_out)#

Returns True only if close inputs produce close outputs.

See the privacy and stability tutorial for more information.

Parameters
  • d_in (Any) – Distance between inputs under input_metric.

  • d_out (Any) – Distance between outputs under output_metric.

Return type

bool

__or__(other: Transformation) → Transformation#
__or__(other: tmlt.core.measurements.base.Measurement) → tmlt.core.measurements.base.Measurement

Return this transformation chained with another component.

class LimitRowsPerKeyPerGroup(input_domain, input_metric, grouping_column, key_column, threshold)#

Bases: tmlt.core.transformations.base.Transformation

For each group, keep at most k rows per key.

See truncate_large_groups() for more information about truncation.
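Here the limit applies to each (group, key) pair rather than to the group as a whole. The sketch below counts rows per pair with plain tuples. `limit_rows_per_key_per_group` is a hypothetical helper that keeps the first `threshold` rows of each pair in input order, which is an assumption about row choice rather than the library's behavior.

```python
from collections import Counter


def limit_rows_per_key_per_group(rows, grouping_index, key_index, threshold):
    """Keep at most `threshold` rows for each (group, key) pair.

    Illustrative helper; rows are plain tuples.
    """
    seen = Counter()
    result = []
    for row in rows:
        pair = (row[grouping_index], row[key_index])
        if seen[pair] < threshold:
            seen[pair] += 1
            result.append(row)
    return result


rows = [
    ("a1", "b1"), ("a2", "b1"),
    ("a3", "b2"), ("a3", "b2"), ("a3", "b2"),
    ("a4", "b1"), ("a4", "b2"), ("a4", "b3"), ("a4", "b4"),
]
truncated = limit_rows_per_key_per_group(
    rows, grouping_index=0, key_index=1, threshold=2
)
pair_counts = Counter(truncated)
print(len(truncated), pair_counts[("a3", "b2")])
```

Only the pair (a3, b2) exceeds the limit, so a single row is dropped; group a4 keeps all four rows because each of its keys appears once.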

Example

>>> # Example input
>>> print_sdf(spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b1
6  a4  b2
7  a4  b3
8  a4  b4
>>> truncate = LimitRowsPerKeyPerGroup(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     input_metric=IfGroupedBy("B", SumOf(IfGroupedBy("A", SymmetricDifference()))),
...     grouping_column="A",
...     key_column="B",
...     threshold=2,
... )
>>> # Apply transformation to data
>>> truncated_spark_dataframe = truncate(spark_dataframe)
>>> print_sdf(truncated_spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a4  b1
5  a4  b2
6  a4  b3
7  a4  b4
Transformation Contract:
>>> truncate.input_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.output_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.input_metric
IfGroupedBy(column='B', inner_metric=SumOf(inner_metric=IfGroupedBy(column='A', inner_metric=SymmetricDifference())))
>>> truncate.output_metric
SymmetricDifference()
Stability Guarantee:

LimitRowsPerKeyPerGroup’s stability_function() returns d_in if input_metric is IfGroupedBy(grouping_column, SymmetricDifference()) and threshold * d_in otherwise.

>>> truncate.stability_function(1)
2
>>> truncate.stability_function(2)
4
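For this class the rule is driven by the input metric rather than the output metric. A minimal sketch of the rule as stated, with a hypothetical boolean flag standing in for the metric check and plain integers in place of exact numbers:

```python
def rows_per_key_per_group_stability(d_in, threshold, input_is_grouped_symmetric_difference):
    """Mirror the documented rule for LimitRowsPerKeyPerGroup (illustrative only).

    The flag stands for: input_metric is
    IfGroupedBy(grouping_column, SymmetricDifference()).
    """
    if input_is_grouped_symmetric_difference:
        return d_in
    return threshold * d_in


# The doctest above uses the SumOf input metric, so the flag is False.
print(rows_per_key_per_group_stability(1, 2, False))
print(rows_per_key_per_group_stability(2, 2, False))
```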
Methods#

grouping_column()

Returns the column defining the groups to truncate.

key_column()

Returns the column defining the keys.

threshold()

Returns the maximum number of rows each unique (key, grouping column value) pair may appear in after truncation.

stability_function()

Returns the smallest d_out satisfied by the transformation.

__call__()

Returns a truncated dataframe.

input_domain()

Return input domain for the transformation.

input_metric()

Distance metric on input domain.

output_domain()

Return output domain for the transformation.

output_metric()

Distance metric on output domain.

stability_relation()

Returns True only if close inputs produce close outputs.

__or__()

Return this transformation chained with another component.

__init__(input_domain, input_metric, grouping_column, key_column, threshold)#

Constructor.

Parameters
  • input_domain (SparkDataFrameDomain) – Domain of input DataFrame.

  • input_metric (IfGroupedBy) – Distance metric for input DataFrames. This should be one of IfGroupedBy(key_column, SumOf(IfGroupedBy(grouping_column, SymmetricDifference()))), IfGroupedBy(key_column, RootSumOfSquared(IfGroupedBy(grouping_column, SymmetricDifference()))), or IfGroupedBy(grouping_column, SymmetricDifference()).

  • grouping_column (str) – Name of column defining the groups to truncate.

  • key_column (str) – Name of column defining the keys.

  • threshold (int) – The maximum number of rows each unique (key, grouping column value) pair may appear in after truncation.

property grouping_column#

Returns the column defining the groups to truncate.

Return type

str

property key_column#

Returns the column defining the keys.

Return type

str

property threshold#

Returns the maximum number of rows each unique (key, grouping column value) pair may appear in after truncation.

Return type

int

stability_function(d_in)#

Returns the smallest d_out satisfied by the transformation.

See the architecture overview for more information.

Parameters

d_in (tmlt.core.utils.exact_number.ExactNumberInput) – Distance between inputs under input_metric.

Return type

tmlt.core.utils.exact_number.ExactNumber

__call__(sdf)#

Returns a truncated dataframe.

Parameters

sdf (pyspark.sql.DataFrame) –

Return type

pyspark.sql.DataFrame

property input_domain#

Return input domain for the transformation.

Return type

tmlt.core.domains.base.Domain

property input_metric#

Distance metric on input domain.

Return type

tmlt.core.metrics.Metric

property output_domain#

Return output domain for the transformation.

Return type

tmlt.core.domains.base.Domain

property output_metric#

Distance metric on output domain.

Return type

tmlt.core.metrics.Metric

stability_relation(d_in, d_out)#

Returns True only if close inputs produce close outputs.

See the privacy and stability tutorial for more information.

Parameters
  • d_in (Any) – Distance between inputs under input_metric.

  • d_out (Any) – Distance between outputs under output_metric.

Return type

bool

__or__(other: Transformation) → Transformation#
__or__(other: tmlt.core.measurements.base.Measurement) → tmlt.core.measurements.base.Measurement

Return this transformation chained with another component.