truncation
Transformations for truncating Spark DataFrames.
Classes
LimitRowsPerGroup | Keep at most k rows per group.
LimitKeysPerGroup | Keep at most k keys per group.
- class LimitRowsPerGroup(input_domain, grouping_column, threshold)
Bases: tmlt.core.transformations.base.Transformation
Keep at most k rows per group.
See truncate_large_groups() for more information about truncation.
Example
>>> # Example input
>>> print_sdf(spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b1
6  a4  b2
7  a4  b3
8  a4  b4
>>> truncate = LimitRowsPerGroup(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     grouping_column="A",
...     threshold=2,
... )
>>> # Apply transformation to data
>>> truncated_spark_dataframe = truncate(spark_dataframe)
>>> print_sdf(truncated_spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a4  b2
5  a4  b3
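To make the row-truncation idea concrete without a Spark session, here is a minimal pure-Python sketch operating on a list of (A, B) tuples. The helper `limit_rows_per_group` is hypothetical and is not the library's implementation: it simply keeps the first `threshold` rows of each group in input order, whereas the example above shows that the library may keep different rows (it kept b2 and b3 for a4).

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Row = Tuple[str, str]


def limit_rows_per_group(rows: Sequence[Row], group_index: int, threshold: int) -> List[Row]:
    """Keep at most `threshold` rows for each distinct value in column `group_index`.

    Hypothetical sketch: keeps the first `threshold` rows of each group in input
    order, which is not necessarily how the library picks rows to keep.
    """
    kept_per_group: Dict[str, int] = defaultdict(int)
    result: List[Row] = []
    for row in rows:
        group = row[group_index]
        # Only admit the row while its group is still under the threshold.
        if kept_per_group[group] < threshold:
            kept_per_group[group] += 1
            result.append(row)
    return result


rows = [
    ("a1", "b1"), ("a2", "b1"),
    ("a3", "b2"), ("a3", "b2"), ("a3", "b2"),
    ("a4", "b1"), ("a4", "b2"), ("a4", "b3"), ("a4", "b4"),
]
truncated = limit_rows_per_group(rows, group_index=0, threshold=2)
print(truncated)
```

Like the doctest, the sketch returns six rows: a3 and a4 are each cut down to two rows, while a1 and a2 are unaffected.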
- Transformation Contract:
Input domain - SparkDataFrameDomain
Output domain - SparkDataFrameDomain (matches input domain)
Input metric - IfGroupedBy on the grouping column, with inner metric SymmetricDifference
Output metric - SymmetricDifference
>>> truncate.input_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.output_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.input_metric
IfGroupedBy(column='A', inner_metric=SymmetricDifference())
>>> truncate.output_metric
SymmetricDifference()
- Stability Guarantee:
LimitRowsPerGroup's stability_function() returns threshold * d_in.
>>> truncate.stability_function(1)
2
>>> truncate.stability_function(2)
4
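The stability guarantee can be checked numerically with a small sketch. It uses Python's `fractions.Fraction` as a stand-in for the library's exact numbers (an assumption; `ExactNumber` is not used here), and the function names are hypothetical. It also assumes that the stability relation holds exactly when `d_out` is at least the value returned by the stability function.

```python
from fractions import Fraction


def limit_rows_stability_function(d_in, threshold: int) -> Fraction:
    """Smallest d_out for a given d_in: threshold * d_in, in exact arithmetic."""
    return threshold * Fraction(d_in)


def limit_rows_stability_relation(d_in, d_out, threshold: int) -> bool:
    """Assumed relation: True when d_out is at least the stability function's value."""
    return limit_rows_stability_function(d_in, threshold) <= Fraction(d_out)


print(limit_rows_stability_function(1, threshold=2))     # 2
print(limit_rows_stability_function(2, threshold=2))     # 4
print(limit_rows_stability_relation(1, 2, threshold=2))  # True
print(limit_rows_stability_relation(1, 1, threshold=2))  # False
```

The first two calls mirror the doctest values for threshold=2.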
- Parameters
input_domain (tmlt.core.domains.spark_domains.SparkDataFrameDomain)
grouping_column (str)
threshold (int)
- __init__(input_domain, grouping_column, threshold)
Constructor.
- Parameters
input_domain (SparkDataFrameDomain) – Domain of input DataFrame.
grouping_column (str) – Name of column defining the groups to truncate.
threshold (int) – The maximum number of rows per group after truncation.
- property threshold(self)
Returns the maximum number of rows per group after truncation.
- Return type
int
- stability_function(self, d_in)
Returns the smallest d_out satisfied by the transformation.
See the privacy and stability tutorial for more information.
- Parameters
d_in (tmlt.core.utils.exact_number.ExactNumberInput) – Distance between inputs under input_metric.
- Return type
ExactNumber
- __call__(self, sdf)
Returns a truncated DataFrame.
- Parameters
sdf (pyspark.sql.DataFrame) – The DataFrame to truncate.
- Return type
pyspark.sql.DataFrame
- property input_domain(self)
Returns the input domain of the transformation.
- Return type
Domain
- property input_metric(self)
Distance metric on input domain.
- Return type
Metric
- property output_domain(self)
Returns the output domain of the transformation.
- Return type
Domain
- property output_metric(self)
Distance metric on output domain.
- Return type
Metric
- stability_relation(self, d_in, d_out)
Returns True only if close inputs produce close outputs.
See the privacy and stability tutorial for more information.
- Parameters
d_in (Any) – Distance between inputs under input_metric.
d_out (Any) – Distance between outputs under output_metric.
- Return type
bool
- __or__(self, other: Transformation) → Transformation
- __or__(self, other: tmlt.core.measurements.base.Measurement) → tmlt.core.measurements.base.Measurement
Return this transformation chained with another component.
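As an informal sketch of what chaining means for stability: if the first component maps d_in-close inputs to d1-close outputs, and the second maps d1-close inputs to d2-close outputs, then the chain maps d_in-close inputs to d2-close outputs, so the chained stability function is the composition of the two. The helper `chain_stability` and the two stage functions below are hypothetical; plain callables stand in for Transformation objects.

```python
from fractions import Fraction
from typing import Callable

StabilityFn = Callable[[Fraction], Fraction]


def chain_stability(first: StabilityFn, second: StabilityFn) -> StabilityFn:
    """Stability function of `first` followed by `second` (hypothetical helper)."""
    return lambda d_in: second(first(d_in))


# Hypothetical stages: a truncation with threshold 2, then a stage with stability 3 * d.
truncate_stability: StabilityFn = lambda d: 2 * d
tripling_stability: StabilityFn = lambda d: 3 * d

chained = chain_stability(truncate_stability, tripling_stability)
print(chained(Fraction(1)))  # 6
```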
- class LimitKeysPerGroup(input_domain, grouping_column, key_column, threshold, use_l2)
Bases: tmlt.core.transformations.base.Transformation
Keep at most k keys per group.
See limit_keys_per_group() for more information about truncation.
Example
>>> # Example input
>>> print_sdf(spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b1
6  a4  b2
7  a4  b3
8  a4  b4
>>> truncate = LimitKeysPerGroup(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     grouping_column="A",
...     key_column="B",
...     threshold=2,
...     use_l2=False,
... )
>>> # Apply transformation to data
>>> truncated_spark_dataframe = truncate(spark_dataframe)
>>> print_sdf(truncated_spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
4  a3  b2
5  a4  b2
6  a4  b3
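The difference from row truncation is that whole (group, key) pairs are kept or dropped: a3 keeps all three of its rows because they share a single key. A minimal pure-Python sketch of that idea follows; the helper `limit_keys_per_group` is hypothetical and keeps the first `threshold` distinct keys seen in each group, which need not match the keys the library chooses (the example above kept b2 and b3 for a4).

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Set, Tuple

Row = Tuple[str, str]


def limit_keys_per_group(
    rows: Sequence[Row], group_index: int, key_index: int, threshold: int
) -> List[Row]:
    """Keep rows whose key is among at most `threshold` distinct keys of their group.

    Hypothetical sketch: the kept keys are the first ones seen in input order,
    and every row with a kept (group, key) pair survives.
    """
    kept_keys: Dict[str, Set[str]] = defaultdict(set)
    result: List[Row] = []
    for row in rows:
        group, key = row[group_index], row[key_index]
        keys = kept_keys[group]
        # Admit a new key only while the group still has room for one.
        if key not in keys and len(keys) < threshold:
            keys.add(key)
        if key in keys:
            result.append(row)
    return result


rows = [
    ("a1", "b1"), ("a2", "b1"),
    ("a3", "b2"), ("a3", "b2"), ("a3", "b2"),
    ("a4", "b1"), ("a4", "b2"), ("a4", "b3"), ("a4", "b4"),
]
truncated = limit_keys_per_group(rows, group_index=0, key_index=1, threshold=2)
print(truncated)
```

As in the doctest, seven rows survive: all three a3 rows and two of the four a4 rows.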
- Transformation Contract:
Input domain - SparkDataFrameDomain
Output domain - SparkDataFrameDomain (matches input domain)
Input metric - IfGroupedBy on the grouping column, with inner metric SymmetricDifference
Output metric - IfGroupedBy on the key column, with inner metric SumOf (if use_l2 is False) or RootSumOfSquared (if use_l2 is True) over an IfGroupedBy on the grouping column, with inner metric SymmetricDifference
>>> truncate.input_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.output_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> truncate.input_metric
IfGroupedBy(column='A', inner_metric=SymmetricDifference())
>>> truncate.output_metric
IfGroupedBy(column='B', inner_metric=SumOf(inner_metric=IfGroupedBy(column='A', inner_metric=SymmetricDifference())))
- Stability Guarantee:
LimitKeysPerGroup's stability_function() returns threshold * d_in if use_l2 is False and sqrt(threshold) * d_in otherwise.
>>> truncate.stability_function(1)
2
>>> truncate.stability_function(2)
4
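Informally, after truncation a single value of the grouping column appears under at most threshold distinct keys, so a change to one group can touch at most threshold of the per-key groups, each contributing 1. Adding those contributions (SumOf) gives threshold, while combining them as a root sum of squares (RootSumOfSquared) gives sqrt(threshold). The sketch below evaluates the stated formula; the function name is hypothetical, and it uses floats as an approximation, whereas the library works with exact numbers.

```python
import math


def limit_keys_stability_function(d_in: float, threshold: int, use_l2: bool) -> float:
    """threshold * d_in for SumOf, sqrt(threshold) * d_in for RootSumOfSquared.

    Float approximation of the documented formula; not the library's exact arithmetic.
    """
    if use_l2:
        return math.sqrt(threshold) * d_in
    return threshold * d_in


print(limit_keys_stability_function(1, threshold=2, use_l2=False))  # 2
print(limit_keys_stability_function(2, threshold=2, use_l2=False))  # 4
print(limit_keys_stability_function(1, threshold=4, use_l2=True))   # 2.0

# One changed group touching 4 keys, contributing 1 to each key's distance:
contributions = [1] * 4
print(sum(contributions))                            # 4 (SumOf)
print(math.sqrt(sum(c * c for c in contributions)))  # 2.0 (RootSumOfSquared)
```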
Methods
grouping_column | Returns the column defining the groups to truncate.
key_column | Returns the column defining the keys.
threshold | Returns the maximum number of keys per group after truncation.
use_l2 | Returns whether the output metric will use RootSumOfSquared.
stability_function() | Returns the smallest d_out satisfied by the transformation.
__call__() | Returns a truncated DataFrame.
input_domain | Returns the input domain of the transformation.
input_metric | Distance metric on input domain.
output_domain | Returns the output domain of the transformation.
output_metric | Distance metric on output domain.
stability_relation() | Returns True only if close inputs produce close outputs.
__or__() | Returns this transformation chained with another component.
- Parameters
input_domain (tmlt.core.domains.spark_domains.SparkDataFrameDomain)
grouping_column (str)
key_column (str)
threshold (int)
use_l2 (bool)
- __init__(input_domain, grouping_column, key_column, threshold, use_l2)
Constructor.
- Parameters
input_domain (SparkDataFrameDomain) – Domain of input DataFrame.
grouping_column (str) – Name of column defining the groups to truncate.
key_column (str) – Name of column defining the keys.
threshold (int) – The maximum number of keys per group after truncation.
use_l2 (bool) – If True, use RootSumOfSquared as the inner metric of the output IfGroupedBy metric of this transformation instead of SumOf.
- property threshold(self)
Returns the maximum number of keys per group after truncation.
- Return type
int
- property use_l2(self)
Returns whether the output metric will use RootSumOfSquared.
- Return type
bool
- stability_function(self, d_in)
Returns the smallest d_out satisfied by the transformation.
See the privacy and stability tutorial for more information.
- Parameters
d_in (tmlt.core.utils.exact_number.ExactNumberInput) – Distance between inputs under input_metric.
- Return type
ExactNumber
- __call__(self, sdf)
Returns a truncated DataFrame.
- Parameters
sdf (pyspark.sql.DataFrame) – The DataFrame to truncate.
- Return type
pyspark.sql.DataFrame
- property input_domain(self)
Returns the input domain of the transformation.
- Return type
Domain
- property input_metric(self)
Distance metric on input domain.
- Return type
Metric
- property output_domain(self)
Returns the output domain of the transformation.
- Return type
Domain
- property output_metric(self)
Distance metric on output domain.
- Return type
Metric
- stability_relation(self, d_in, d_out)
Returns True only if close inputs produce close outputs.
See the privacy and stability tutorial for more information.
- Parameters
d_in (Any) – Distance between inputs under input_metric.
d_out (Any) – Distance between outputs under output_metric.
- Return type
bool
- __or__(self, other: Transformation) → Transformation
- __or__(self, other: tmlt.core.measurements.base.Measurement) → tmlt.core.measurements.base.Measurement
Return this transformation chained with another component.