truncation
Transformations for truncating Spark DataFrames.
Classes
LimitRowsPerGroup | Keep at most k rows per group.
LimitKeysPerGroup | Keep at most k keys per group.
LimitRowsPerKeyPerGroup | For each group, limit k rows per key.
- class LimitRowsPerGroup(input_domain, output_metric, grouping_column, threshold)
Bases: tmlt.core.transformations.base.Transformation
Keep at most k rows per group.
See truncate_large_groups() for more information about truncation.
Example
>>> # Example input
>>> print_sdf(spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b1
6  a4  b2
7  a4  b3
8  a4  b4
>>> truncate = LimitRowsPerGroup(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     output_metric=SymmetricDifference(),
...     grouping_column="A",
...     threshold=2,
... )
>>> # Apply transformation to data
>>> truncated_spark_dataframe = truncate(spark_dataframe)
>>> print_sdf(truncated_spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a4  b3
5  a4  b4
- Transformation Contract:
Input domain - SparkDataFrameDomain
Output domain - SparkDataFrameDomain (matches input domain)
Input metric - IfGroupedBy on the grouping column, with inner metric SymmetricDifference
Output metric - SymmetricDifference, or IfGroupedBy on the grouping column, with inner metric SymmetricDifference
>>> truncate.input_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.output_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.input_metric
IfGroupedBy(column='A', inner_metric=SymmetricDifference())
>>> truncate.output_metric
SymmetricDifference()
- Stability Guarantee:
LimitRowsPerGroup’s stability_function() returns threshold * d_in if output_metric is SymmetricDifference() and d_in otherwise.
>>> truncate.stability_function(1)
2
>>> truncate.stability_function(2)
4
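The guarantee above can be restated as a small function. This is a minimal sketch, not the library's code: the name `limit_rows_per_group_stability` and its boolean flag are hypothetical, and plain integers stand in for the exact numbers the real method accepts. Roughly, d_in counts differing groups under the input metric, and each such group can contribute at most threshold rows to the output.

```python
def limit_rows_per_group_stability(
    threshold: int, d_in: int, output_is_symmetric_difference: bool
) -> int:
    """Illustrative restatement of the documented stability guarantee.

    Hypothetical helper: the flag stands in for checking whether
    output_metric is SymmetricDifference().
    """
    if output_is_symmetric_difference:
        # Each differing group can change up to `threshold` output rows.
        return threshold * d_in
    # Output is still measured per group, so the distance is unchanged.
    return d_in


# Mirrors the doctest above (threshold=2, output_metric=SymmetricDifference()).
d_out_for_one = limit_rows_per_group_stability(2, 1, True)
d_out_for_two = limit_rows_per_group_stability(2, 2, True)
```

With threshold=2 these calls give 2 and 4, matching the stability_function outputs shown above.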
- Parameters
input_domain (tmlt.core.domains.spark_domains.SparkDataFrameDomain) –
output_metric (Union[tmlt.core.metrics.SymmetricDifference, tmlt.core.metrics.IfGroupedBy]) –
grouping_column (str) –
threshold (int) –
- __init__(input_domain, output_metric, grouping_column, threshold)
Constructor.
- Parameters
input_domain (SparkDataFrameDomain) – Domain of input DataFrame.
output_metric (Union[SymmetricDifference, IfGroupedBy]) – Distance metric for output DataFrames. This should be SymmetricDifference() or IfGroupedBy(grouping_column, SymmetricDifference()).
grouping_column (str) – Name of column defining the groups to truncate.
threshold (int) – The maximum number of rows per group after truncation.
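The roles of grouping_column and threshold can be illustrated without Spark. The following is a minimal pure-Python sketch under stated assumptions, not the library's implementation: the helper `limit_rows_per_group` is hypothetical, rows are plain (A, B) tuples, and the sketch keeps the first threshold rows of each group in input order. The library may keep a different subset of rows, as the class example shows for group a4.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Row = Tuple[str, str]  # (A, B), where A plays the role of grouping_column


def limit_rows_per_group(rows: List[Row], threshold: int) -> List[Row]:
    """Keep at most `threshold` rows for each value of column A.

    Hypothetical helper: which rows survive (here, the first ones seen)
    is an assumption made for illustration only.
    """
    kept_per_group: Dict[str, int] = defaultdict(int)
    result: List[Row] = []
    for row in rows:
        group = row[0]
        if kept_per_group[group] < threshold:
            kept_per_group[group] += 1
            result.append(row)
    return result


rows = [
    ("a1", "b1"), ("a2", "b1"),
    ("a3", "b2"), ("a3", "b2"), ("a3", "b2"),
    ("a4", "b1"), ("a4", "b2"), ("a4", "b3"), ("a4", "b4"),
]
truncated = limit_rows_per_group(rows, threshold=2)
```

As in the class example, nine input rows become six: groups a3 and a4 are each cut down to two rows.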
- stability_function(d_in)
Returns the smallest d_out satisfied by the transformation.
See the architecture overview for more information.
- Parameters
d_in (tmlt.core.utils.exact_number.ExactNumberInput) – Distance between inputs under input_metric.
- Return type
- __call__(sdf)
Returns a truncated dataframe.
- Parameters
sdf (pyspark.sql.DataFrame) –
- Return type
- property input_domain
Return input domain for the transformation.
- Return type
- property input_metric
Distance metric on input domain.
- Return type
- property output_domain
Return output domain for the transformation.
- Return type
- property output_metric
Distance metric on output domain.
- Return type
- stability_relation(d_in, d_out)
Returns True only if close inputs produce close outputs.
See the privacy and stability tutorial for more information.
- Parameters
d_in (Any) – Distance between inputs under input_metric.
d_out (Any) – Distance between outputs under output_metric.
- Return type
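One plausible reading of how stability_relation relates to stability_function, assumed here rather than stated in the text, is that a pair (d_in, d_out) is accepted whenever d_out is at least the smallest bound the stability function reports. The sketch below uses hypothetical stand-alone functions and plain integers with a fixed threshold of 2.

```python
THRESHOLD = 2  # illustrative value, matching the class example


def stability_function(d_in: int) -> int:
    """Smallest d_out for a row-truncation with SymmetricDifference output."""
    return THRESHOLD * d_in


def stability_relation(d_in: int, d_out: int) -> bool:
    """Assumed relation: accept d_out only if it covers the smallest bound."""
    return stability_function(d_in) <= d_out


accepted = stability_relation(1, 2)
rejected = stability_relation(1, 1)
```

Here `accepted` is True because inputs at distance 1 yield outputs at distance at most 2, while `rejected` is False because a bound of 1 is too tight.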
- __or__(other: Transformation) -> Transformation
- __or__(other: tmlt.core.measurements.base.Measurement) -> tmlt.core.measurements.base.Measurement
Return this transformation chained with another component.
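Chaining with `|` can be pictured with a toy class. This is a sketch under stated assumptions, not the library's implementation: `ToyTransformation` and its fields are hypothetical, it only covers chaining two transformations (not a transformation with a measurement), and it assumes the chained stability is the composition of the two stability functions.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ToyTransformation:
    """Hypothetical stand-in for a transformation with a stability function."""

    apply: Callable[[List[int]], List[int]]
    stability: Callable[[int], int]

    def __call__(self, data: List[int]) -> List[int]:
        return self.apply(data)

    def __or__(self, other: "ToyTransformation") -> "ToyTransformation":
        # Run `self` first, then `other`; compose the stability bounds.
        return ToyTransformation(
            apply=lambda data: other.apply(self.apply(data)),
            stability=lambda d_in: other.stability(self.stability(d_in)),
        )


# Duplicating every record doubles the distance between neighboring inputs.
duplicate = ToyTransformation(
    apply=lambda data: [x for x in data for _ in range(2)],
    stability=lambda d_in: 2 * d_in,
)
# Filtering records does not increase the distance.
keep_even = ToyTransformation(
    apply=lambda data: [x for x in data if x % 2 == 0],
    stability=lambda d_in: d_in,
)

chained = duplicate | keep_even
result = chained([1, 2, 3])
bound = chained.stability(1)
```

In this toy, `result` is `[2, 2]` and `bound` is 2, since the chained component applies the left operand first.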
- class LimitKeysPerGroup(input_domain, output_metric, grouping_column, key_column, threshold)
Bases: tmlt.core.transformations.base.Transformation
Keep at most k keys per group.
See limit_keys_per_group() for more information about truncation.
Example
>>> # Example input
>>> print_sdf(spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b1
6  a4  b2
7  a4  b3
8  a4  b4
>>> truncate = LimitKeysPerGroup(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     output_metric=IfGroupedBy("B", SumOf(IfGroupedBy("A", SymmetricDifference()))),
...     grouping_column="A",
...     key_column="B",
...     threshold=2,
... )
>>> # Apply transformation to data
>>> truncated_spark_dataframe = truncate(spark_dataframe)
>>> print_sdf(truncated_spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b3
6  a4  b4
- Transformation Contract:
Input domain - SparkDataFrameDomain
Output domain - SparkDataFrameDomain (matches input domain)
Input metric - IfGroupedBy on the grouping column, with inner metric SymmetricDifference
Output metric - IfGroupedBy on the grouping column, with inner metric SymmetricDifference, or IfGroupedBy on the key column, with inner metric a SumOf or RootSumOfSquared over an IfGroupedBy on the grouping column, with inner metric SymmetricDifference
>>> truncate.input_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.output_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.input_metric
IfGroupedBy(column='A', inner_metric=SymmetricDifference())
>>> truncate.output_metric
IfGroupedBy(column='B', inner_metric=SumOf(inner_metric=IfGroupedBy(column='A', inner_metric=SymmetricDifference())))
- Stability Guarantee:
LimitKeysPerGroup’s stability_function() returns d_in if output_metric is IfGroupedBy(grouping_column, SymmetricDifference()), sqrt(threshold) * d_in if output_metric is IfGroupedBy(key_column, RootSumOfSquared(IfGroupedBy(grouping_column, SymmetricDifference()))), and threshold * d_in otherwise.
>>> truncate.stability_function(1)
2
>>> truncate.stability_function(2)
4
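The three cases of this guarantee can be written out as a small function. This is a minimal sketch, not the library's code: the name `limit_keys_per_group_stability` and the string labels for the output metric are hypothetical, and floating-point arithmetic stands in for the exact numbers the real method works with.

```python
import math


def limit_keys_per_group_stability(
    threshold: int, d_in: float, output_metric: str
) -> float:
    """Illustrative restatement of the documented stability guarantee.

    Hypothetical labels for `output_metric`:
      "grouping" - IfGroupedBy(grouping_column, SymmetricDifference())
      "key_rss"  - IfGroupedBy(key_column, RootSumOfSquared(IfGroupedBy(...)))
      "key_sum"  - IfGroupedBy(key_column, SumOf(IfGroupedBy(...)))
    """
    if output_metric == "grouping":
        return d_in
    if output_metric == "key_rss":
        return math.sqrt(threshold) * d_in
    if output_metric == "key_sum":
        return threshold * d_in
    raise ValueError(f"unknown output metric label: {output_metric}")


# Mirrors the doctest above (threshold=2, SumOf-based output metric).
sum_bound = limit_keys_per_group_stability(2, 1, "key_sum")
# With threshold=4 the RootSumOfSquared case scales d_in by sqrt(4) = 2.
rss_bound = limit_keys_per_group_stability(4, 3, "key_rss")
```

The SumOf case reproduces the value 2 from the example above, and the RootSumOfSquared case with threshold=4 and d_in=3 gives 6.0.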
grouping_column | Returns the column defining the groups to truncate.
key_column | Returns the column defining the keys.
threshold | Returns the maximum number of keys per group after truncation.
stability_function | Returns the smallest d_out satisfied by the transformation.
__call__ | Returns a truncated dataframe.
input_domain | Return input domain for the transformation.
input_metric | Distance metric on input domain.
output_domain | Return output domain for the transformation.
output_metric | Distance metric on output domain.
stability_relation | Returns True only if close inputs produce close outputs.
__or__ | Return this transformation chained with another component.
- Parameters
input_domain (tmlt.core.domains.spark_domains.SparkDataFrameDomain) –
output_metric (tmlt.core.metrics.IfGroupedBy) –
grouping_column (str) –
key_column (str) –
threshold (int) –
- __init__(input_domain, output_metric, grouping_column, key_column, threshold)
Constructor.
- Parameters
input_domain (SparkDataFrameDomain) – Domain of input DataFrame.
output_metric (IfGroupedBy) – Distance metric for output DataFrames. This should be IfGroupedBy(key_column, SumOf(IfGroupedBy(grouping_column, SymmetricDifference()))) or IfGroupedBy(key_column, RootSumOfSquared(IfGroupedBy(grouping_column, SymmetricDifference()))) or IfGroupedBy(grouping_column, SymmetricDifference()).
grouping_column (str) – Name of column defining the groups to truncate.
threshold (int) – The maximum number of keys per group after truncation.
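The difference between limiting keys and limiting rows can be illustrated without Spark. The following is a minimal pure-Python sketch under stated assumptions, not the library's implementation: the helper `limit_keys_per_group` is hypothetical, rows are plain (A, B) tuples, and the sketch keeps the first threshold distinct keys seen in each group. The library may choose different keys, as the class example shows for group a4.

```python
from typing import Dict, List, Set, Tuple

Row = Tuple[str, str]  # (A, B): A is the grouping column, B the key column


def limit_keys_per_group(rows: List[Row], threshold: int) -> List[Row]:
    """Keep all rows belonging to at most `threshold` distinct keys per group.

    Hypothetical helper: which keys survive (here, the first ones seen)
    is an assumption made for illustration only.
    """
    kept_keys: Dict[str, Set[str]] = {}
    # First pass: pick up to `threshold` distinct keys for each group.
    for group, key in rows:
        keys = kept_keys.setdefault(group, set())
        if key not in keys and len(keys) < threshold:
            keys.add(key)
    # Second pass: keep every row whose key was selected for its group.
    return [(group, key) for group, key in rows if key in kept_keys[group]]


rows = [
    ("a1", "b1"), ("a2", "b1"),
    ("a3", "b2"), ("a3", "b2"), ("a3", "b2"),
    ("a4", "b1"), ("a4", "b2"), ("a4", "b3"), ("a4", "b4"),
]
truncated = limit_keys_per_group(rows, threshold=2)
```

All three a3 rows survive because they share a single key, while a4 is reduced from four keys to two, giving seven rows in total as in the class example.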
- stability_function(d_in)
Returns the smallest d_out satisfied by the transformation.
See the architecture overview for more information.
- Parameters
d_in (tmlt.core.utils.exact_number.ExactNumberInput) – Distance between inputs under input_metric.
- Return type
- __call__(sdf)
Returns a truncated dataframe.
- Parameters
sdf (pyspark.sql.DataFrame) –
- Return type
- property input_domain
Return input domain for the transformation.
- Return type
- property input_metric
Distance metric on input domain.
- Return type
- property output_domain
Return output domain for the transformation.
- Return type
- property output_metric
Distance metric on output domain.
- Return type
- stability_relation(d_in, d_out)
Returns True only if close inputs produce close outputs.
See the privacy and stability tutorial for more information.
- Parameters
d_in (Any) – Distance between inputs under input_metric.
d_out (Any) – Distance between outputs under output_metric.
- Return type
- __or__(other: Transformation) -> Transformation
- __or__(other: tmlt.core.measurements.base.Measurement) -> tmlt.core.measurements.base.Measurement
Return this transformation chained with another component.
- class LimitRowsPerKeyPerGroup(input_domain, input_metric, grouping_column, key_column, threshold)
Bases: tmlt.core.transformations.base.Transformation
For each group, limit k rows per key.
See truncate_large_groups() for more information about truncation.
Example
>>> # Example input
>>> print_sdf(spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b1
6  a4  b2
7  a4  b3
8  a4  b4
>>> truncate = LimitRowsPerKeyPerGroup(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     input_metric=IfGroupedBy("B", SumOf(IfGroupedBy("A", SymmetricDifference()))),
...     grouping_column="A",
...     key_column="B",
...     threshold=2,
... )
>>> # Apply transformation to data
>>> truncated_spark_dataframe = truncate(spark_dataframe)
>>> print_sdf(truncated_spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a4  b1
5  a4  b2
6  a4  b3
7  a4  b4
- Transformation Contract:
Input domain - SparkDataFrameDomain
Output domain - SparkDataFrameDomain (matches input domain)
Input metric - IfGroupedBy on the grouping column, with inner metric SymmetricDifference, or IfGroupedBy on the key column, with inner metric a SumOf or RootSumOfSquared over an IfGroupedBy on the grouping column, with inner metric SymmetricDifference
Output metric - SymmetricDifference, or IfGroupedBy on the key column, with inner metric a RootSumOfSquared with inner metric SymmetricDifference, or IfGroupedBy on the grouping column, with inner metric SymmetricDifference
>>> truncate.input_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.output_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.input_metric
IfGroupedBy(column='B', inner_metric=SumOf(inner_metric=IfGroupedBy(column='A', inner_metric=SymmetricDifference())))
>>> truncate.output_metric
SymmetricDifference()
- Stability Guarantee:
LimitRowsPerKeyPerGroup’s stability_function() returns d_in if input_metric is IfGroupedBy(grouping_column, SymmetricDifference()) and threshold * d_in otherwise.
>>> truncate.stability_function(1)
2
>>> truncate.stability_function(2)
4
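The guarantee above depends on the input metric rather than the output metric. The following is a minimal sketch, not the library's code: the name `limit_rows_per_key_per_group_stability` and its boolean flag are hypothetical, and plain integers stand in for the exact numbers the real method accepts.

```python
def limit_rows_per_key_per_group_stability(
    threshold: int, d_in: int, input_is_grouping_metric: bool
) -> int:
    """Illustrative restatement of the documented stability guarantee.

    Hypothetical helper: the flag stands in for checking whether
    input_metric is IfGroupedBy(grouping_column, SymmetricDifference()).
    """
    if input_is_grouping_metric:
        return d_in
    # Documented "otherwise" case: the bound scales with the threshold.
    return threshold * d_in


# Mirrors the doctest above (threshold=2, SumOf-based input metric).
d_out_for_one = limit_rows_per_key_per_group_stability(2, 1, False)
d_out_for_two = limit_rows_per_key_per_group_stability(2, 2, False)
```

With threshold=2 and the key-column input metric these calls give 2 and 4, matching the stability_function outputs shown above.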
grouping_column | Returns the column defining the groups to truncate.
key_column | Returns the column defining the keys.
threshold | Returns the maximum number of rows each unique (key, grouping column value) pair may appear in after truncation.
stability_function | Returns the smallest d_out satisfied by the transformation.
__call__ | Returns a truncated dataframe.
input_domain | Return input domain for the transformation.
input_metric | Distance metric on input domain.
output_domain | Return output domain for the transformation.
output_metric | Distance metric on output domain.
stability_relation | Returns True only if close inputs produce close outputs.
__or__ | Return this transformation chained with another component.
- Parameters
input_domain (tmlt.core.domains.spark_domains.SparkDataFrameDomain) –
input_metric (tmlt.core.metrics.IfGroupedBy) –
grouping_column (str) –
key_column (str) –
threshold (int) –
- __init__(input_domain, input_metric, grouping_column, key_column, threshold)
Constructor.
- Parameters
input_domain (SparkDataFrameDomain) – Domain of input DataFrame.
input_metric (IfGroupedBy) – Distance metric for input DataFrames. This should be IfGroupedBy(key_column, SumOf(IfGroupedBy(grouping_column, SymmetricDifference()))) or IfGroupedBy(key_column, RootSumOfSquared(IfGroupedBy(grouping_column, SymmetricDifference()))) or IfGroupedBy(grouping_column, SymmetricDifference()).
grouping_column (str) – Name of column defining the groups to truncate.
threshold (int) – The maximum number of rows each unique (key, grouping column value) pair may appear in after truncation.
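The per-pair limit described by threshold can be illustrated without Spark. The following is a minimal pure-Python sketch under stated assumptions, not the library's implementation: the helper `limit_rows_per_key_per_group` is hypothetical, rows are plain (A, B) tuples, and the sketch keeps the first threshold rows seen for each (grouping value, key) pair.

```python
from collections import Counter
from typing import List, Tuple

Row = Tuple[str, str]  # (A, B): A is the grouping column, B the key column


def limit_rows_per_key_per_group(rows: List[Row], threshold: int) -> List[Row]:
    """Keep at most `threshold` rows for each (grouping value, key) pair.

    Hypothetical helper: which rows survive (here, the first ones seen)
    is an assumption made for illustration only.
    """
    seen: Counter = Counter()
    result: List[Row] = []
    for row in rows:
        # Each row is itself the (grouping value, key) pair in this sketch.
        if seen[row] < threshold:
            seen[row] += 1
            result.append(row)
    return result


rows = [
    ("a1", "b1"), ("a2", "b1"),
    ("a3", "b2"), ("a3", "b2"), ("a3", "b2"),
    ("a4", "b1"), ("a4", "b2"), ("a4", "b3"), ("a4", "b4"),
]
truncated = limit_rows_per_key_per_group(rows, threshold=2)
```

Only the pair (a3, b2) exceeds the limit, so one of its three rows is dropped and all four a4 rows are kept, giving eight rows as in the class example.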
- property threshold
Returns the maximum number of rows each unique (key, grouping column value) pair may appear in after truncation.
- Return type
- stability_function(d_in)
Returns the smallest d_out satisfied by the transformation.
See the architecture overview for more information.
- Parameters
d_in (tmlt.core.utils.exact_number.ExactNumberInput) – Distance between inputs under input_metric.
- Return type
- __call__(sdf)
Returns a truncated dataframe.
- Parameters
sdf (pyspark.sql.DataFrame) –
- Return type
- property input_domain
Return input domain for the transformation.
- Return type
- property input_metric
Distance metric on input domain.
- Return type
- property output_domain
Return output domain for the transformation.
- Return type
- property output_metric
Distance metric on output domain.
- Return type
- stability_relation(d_in, d_out)
Returns True only if close inputs produce close outputs.
See the privacy and stability tutorial for more information.
- Parameters
d_in (Any) – Distance between inputs under input_metric.
d_out (Any) – Distance between outputs under output_metric.
- Return type
- __or__(other: Transformation) -> Transformation
- __or__(other: tmlt.core.measurements.base.Measurement) -> tmlt.core.measurements.base.Measurement
Return this transformation chained with another component.