join#
Transformations for joining Spark DataFrames.
See the architecture overview for more information on transformations.
Classes#
- PublicJoin: Join a Spark DataFrame with a public Spark DataFrame.
- TruncationStrategy: Enumerating truncation strategies for PrivateJoin.
- PrivateJoin: Join two private SparkDataFrames.
- PrivateJoinOnKey: Join two private SparkDataFrames including a key column.
- class PublicJoin(input_domain, metric, public_df, public_df_domain=None, join_cols=None, join_on_nulls=False, how='inner')#
Bases:
tmlt.core.transformations.base.Transformation
Join a Spark DataFrame with a public Spark DataFrame.
Performs an inner join by default; a left join can be requested via how. As with a standard PySpark join, null values are not considered equal to each other, but join_on_nulls can be set to treat nulls in the join columns as equal (unlike PySpark).
Examples
Natural join:
>>> # Example input
>>> print_sdf(spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
>>> # Create example public dataframe
>>> public_dataframe = spark.createDataFrame(
...     pd.DataFrame(
...         {
...             "B": ["b1", "b2", "b2"],
...             "C": ["c1", "c2", "c3"],
...         }
...     )
... )
>>> # Create the transformation
>>> natural_join = PublicJoin(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     public_df=public_dataframe,
...     metric=SymmetricDifference(),
... )
>>> # Apply transformation to data
>>> joined_spark_dataframe = natural_join(spark_dataframe)
>>> print_sdf(joined_spark_dataframe)
    B   A   C
0  b1  a1  c1
1  b1  a2  c1
2  b2  a3  c2
3  b2  a3  c2
4  b2  a3  c3
5  b2  a3  c3
Join with some common columns excluded from the join:
>>> # Example input
>>> print_sdf(spark_dataframe)
    A   B
0  a1  b1
1  a2  b1
2  a3  b2
3  a3  b2
>>> # Create example public dataframe
>>> public_dataframe = spark.createDataFrame(
...     pd.DataFrame(
...         {
...             "A": ["a1", "a1", "a2"],
...             "B": ["b1", "b1", "b2"],
...         }
...     )
... )
>>> # Create the transformation
>>> public_join = PublicJoin(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     public_df=public_dataframe,
...     metric=SymmetricDifference(),
...     join_cols=["A"],
... )
>>> # Apply transformation to data
>>> joined_spark_dataframe = public_join(spark_dataframe)
>>> print_sdf(joined_spark_dataframe)
    A B_left B_right
0  a1     b1      b1
1  a1     b1      b1
2  a2     b1      b2
Join on nulls:
>>> # Example input
>>> print_sdf(spark_dataframe_with_null)
      A   B
0    a1  b1
1    a2  b1
2    a3  b2
3  None  b2
>>> # Create example public dataframe
>>> public_dataframe = spark.createDataFrame(
...     pd.DataFrame(
...         {
...             "A": ["a1", "a2", None],
...             "C": ["c1", "c2", "c3"],
...         }
...     )
... )
>>> # Create the transformation
>>> join_transformation = PublicJoin(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     public_df=public_dataframe,
...     metric=SymmetricDifference(),
...     join_on_nulls=True,
... )
>>> # Apply transformation to data
>>> joined_spark_dataframe = join_transformation(spark_dataframe_with_null)
>>> print_sdf(joined_spark_dataframe)
      A   B   C
0    a1  b1  c1
1    a2  b1  c2
2  None  b2  c3
- Transformation Contract:
Input domain - SparkDataFrameDomain
Output domain - SparkDataFrameDomain
Input metric - SymmetricDifference or IfGroupedBy
Output metric - SymmetricDifference or IfGroupedBy (matches input metric)
>>> public_join.input_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False)})
>>> public_join.output_domain
SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B_left': SparkStringColumnDescriptor(allow_null=False), 'B_right': SparkStringColumnDescriptor(allow_null=True)})
>>> public_join.input_metric
SymmetricDifference()
>>> public_join.output_metric
SymmetricDifference()
- Stability Guarantee:
For SymmetricDifference(), IfGroupedBy(column, SumOf(SymmetricDifference())), and IfGroupedBy(column, RootSumOfSquared(SymmetricDifference())), PublicJoin's stability_function() returns d_in times the maximum count of any combination of values in the join columns of public_df.
>>> # Both example transformations had a stability of 2
>>> natural_join.join_cols
['B']
>>> natural_join.public_df.toPandas()
    B   C
0  b1  c1
1  b2  c2
2  b2  c3
>>> # Notice that 'b2' occurs twice
>>> natural_join.stability_function(1)
2
>>> natural_join.stability_function(2)
4
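The multiplicity used in this bound can also be recomputed by hand with ordinary PySpark. The following is a minimal sketch (not part of the PublicJoin API) that reproduces it from natural_join.public_df:
>>> # Sketch: recompute the maximum join-key multiplicity in public_df.
>>> max_count = (
...     natural_join.public_df.groupBy(natural_join.join_cols)
...     .count()
...     .agg({"count": "max"})
...     .collect()[0][0]
... )
>>> max_count  # 'b2' appears twice in column B
2
>>> # stability_function(d_in) is then d_in * max_count
>>> 3 * max_count
6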
For IfGroupedBy(column, SymmetricDifference()), PublicJoin's stability_function() returns d_in.
>>> PublicJoin(
...     input_domain=SparkDataFrameDomain(
...         {
...             "A": SparkStringColumnDescriptor(),
...             "B": SparkStringColumnDescriptor(),
...         }
...     ),
...     public_df=public_dataframe,
...     metric=IfGroupedBy("A", SymmetricDifference()),
... ).stability_function(2)
2
- Parameters:
input_domain (tmlt.core.domains.spark_domains.SparkDataFrameDomain)
metric (Union[tmlt.core.metrics.SymmetricDifference, tmlt.core.metrics.IfGroupedBy])
public_df (pyspark.sql.DataFrame)
public_df_domain (Optional[tmlt.core.domains.spark_domains.SparkDataFrameDomain])
join_cols (Optional[List[str]])
join_on_nulls (bool)
how (str)
- property public_df: pyspark.sql.DataFrame#
Returns the public DataFrame being joined with.
- Return type:
pyspark.sql.DataFrame
- property stability: int#
Returns the stability of the public join.
The stability is the maximum count of any combination of values in the join columns.
- Return type:
int
- property input_domain: tmlt.core.domains.base.Domain#
Return input domain for the transformation.
- Return type:
tmlt.core.domains.base.Domain
- property input_metric: tmlt.core.metrics.Metric#
Distance metric on input domain.
- Return type:
tmlt.core.metrics.Metric
- property output_domain: tmlt.core.domains.base.Domain#
Return output domain for the transformation.
- Return type:
tmlt.core.domains.base.Domain
- property output_metric: tmlt.core.metrics.Metric#
Distance metric on output domain.
- Return type:
tmlt.core.metrics.Metric
- __init__(input_domain, metric, public_df, public_df_domain=None, join_cols=None, join_on_nulls=False, how='inner')#
Constructor.
- Parameters:
input_domain (SparkDataFrameDomain) – Domain of the input Spark DataFrames.
metric (Union[SymmetricDifference, IfGroupedBy]) – Metric for input/output Spark DataFrames.
public_df (DataFrame) – A Spark DataFrame to join with.
public_df_domain (Optional[SparkDataFrameDomain]) – Domain of the public DataFrame to join with. If this domain indicates that a float column does not allow NaNs (or infs), all rows in public_df containing a NaN (or an inf) in that column will be dropped. If None, the domain is inferred from the schema of public_df, and any float column will be marked as allowing inf and NaN values.
join_cols (Optional[List[str]]) – Names of columns to join on. If None, a natural join is performed.
join_on_nulls (bool) – If True, null values on corresponding join columns of the public and private DataFrames will be considered equal.
how (str) – Type of join to perform. Defaults to "inner". Note that only "inner" and "left" joins are supported.
- stability_function(d_in)#
Returns the smallest d_out satisfied by the transformation.
See the privacy and stability tutorial for more information.
- Parameters:
d_in (tmlt.core.utils.exact_number.ExactNumberInput) – Distance between inputs under input_metric.
- Return type:
tmlt.core.utils.exact_number.ExactNumber
- __call__(sdf)#
Perform public join.
- Parameters:
sdf (pyspark.sql.DataFrame) – Private DataFrame to join public DataFrame with.
- Return type:
pyspark.sql.DataFrame
- stability_relation(d_in, d_out)#
Returns True only if close inputs produce close outputs.
See the privacy and stability tutorial for more information.
- Parameters:
d_in (Any) – Distance between inputs under input_metric.
d_out (Any) – Distance between outputs under output_metric.
- Return type:
bool
- __or__(other: Transformation) Transformation #
- __or__(other: tmlt.core.measurements.base.Measurement) tmlt.core.measurements.base.Measurement
Return this transformation chained with another component.
- class TruncationStrategy#
Bases:
enum.Enum
Enumerating truncation strategies for PrivateJoin.
See stability_function() for the stability of each strategy.
- TRUNCATE = 1#
Use truncate_large_groups().
- DROP = 2#
Use drop_large_groups().
- NO_TRUNCATION = 3#
No truncation, results in infinite stability.
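For intuition, the observable behavior of the two truncating strategies can be sketched in plain PySpark. This is only an illustration, not the library's implementation (the actual helpers are truncate_large_groups() and drop_large_groups() in tmlt.core.utils.truncation, and details such as which rows TRUNCATE keeps may differ):
>>> # Sketch only: approximate behavior of the truncation strategies.
>>> from pyspark.sql import functions as sf
>>> from pyspark.sql.window import Window
>>> def truncate_sketch(df, join_cols, threshold):
...     # TRUNCATE: keep at most `threshold` rows per join-key group.
...     w = Window.partitionBy(*join_cols).orderBy(sf.rand())
...     return (
...         df.withColumn("row", sf.row_number().over(w))
...         .filter(sf.col("row") <= threshold)
...         .drop("row")
...     )
...
>>> def drop_sketch(df, join_cols, threshold):
...     # DROP: remove every join-key group larger than `threshold`.
...     small = (
...         df.groupBy(*join_cols).count()
...         .filter(sf.col("count") <= threshold)
...         .drop("count")
...     )
...     return df.join(small, on=join_cols, how="inner")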
- class PrivateJoin(input_domain, left_key, right_key, left_truncation_strategy, right_truncation_strategy, left_truncation_threshold, right_truncation_threshold, join_cols=None, join_on_nulls=False)#
Bases:
tmlt.core.transformations.base.Transformation
Join two private SparkDataFrames.
Performs an inner join. By default, this mimics the behavior of a PySpark join, but it can also be set to consider null values equal to each other (unlike PySpark).
Example
>>> # Example input
>>> print_sdf(left_spark_dataframe)
    A   B  X
0  a1  b1  2
1  a1  b1  3
2  a1  b1  5
3  a1  b2 -1
4  a1  b2  4
5  a2  b1 -5
>>> print_sdf(right_spark_dataframe)
    B   C
0  b1  c1
1  b2  c2
2  b2  c3
>>> # Create transformation
>>> left_domain = SparkDataFrameDomain(
...     {
...         "A": SparkStringColumnDescriptor(),
...         "B": SparkStringColumnDescriptor(),
...         "X": SparkIntegerColumnDescriptor(),
...     },
... )
>>> assert left_spark_dataframe in left_domain
>>> right_domain = SparkDataFrameDomain(
...     {
...         "B": SparkStringColumnDescriptor(),
...         "C": SparkStringColumnDescriptor(),
...     },
... )
>>> assert right_spark_dataframe in right_domain
>>> private_join = PrivateJoin(
...     input_domain=DictDomain(
...         {
...             "left": left_domain,
...             "right": right_domain,
...         }
...     ),
...     left_key="left",
...     right_key="right",
...     left_truncation_strategy=TruncationStrategy.TRUNCATE,
...     left_truncation_threshold=2,
...     right_truncation_strategy=TruncationStrategy.TRUNCATE,
...     right_truncation_threshold=2,
... )
>>> input_dictionary = {
...     "left": left_spark_dataframe,
...     "right": right_spark_dataframe
... }
>>> # Apply transformation to data
>>> joined_dataframe = private_join(input_dictionary)
>>> print_sdf(joined_dataframe)
    B   A  X   C
0  b1  a1  5  c1
1  b1  a2 -5  c1
2  b2  a1 -1  c2
3  b2  a1 -1  c3
4  b2  a1  4  c2
5  b2  a1  4  c3
- Transformation Contract:
Input domain - DictDomain containing two SparkDataFrame domains.
Output domain - SparkDataFrameDomain
Input metric - DictMetric with SymmetricDifference for each input.
Output metric - SymmetricDifference
>>> private_join.input_domain
DictDomain(key_to_domain={'left': SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False), 'X': SparkIntegerColumnDescriptor(allow_null=False, size=64)}), 'right': SparkDataFrameDomain(schema={'B': SparkStringColumnDescriptor(allow_null=False), 'C': SparkStringColumnDescriptor(allow_null=False)})})
>>> private_join.output_domain
SparkDataFrameDomain(schema={'B': SparkStringColumnDescriptor(allow_null=False), 'A': SparkStringColumnDescriptor(allow_null=False), 'X': SparkIntegerColumnDescriptor(allow_null=False, size=64), 'C': SparkStringColumnDescriptor(allow_null=False)})
>>> private_join.input_metric
DictMetric(key_to_metric={'left': SymmetricDifference(), 'right': SymmetricDifference()})
>>> private_join.output_metric
SymmetricDifference()
- Stability Guarantee:
Let \(T_l\) and \(T_r\) be the left and right truncation strategies with stabilities \(s_l\) and \(s_r\) and thresholds \(\tau_l\) and \(\tau_r\). PrivateJoin's stability_function() returns
\[\tau_l \cdot s_r \cdot (df_{r1} \Delta df_{r2}) + \tau_r \cdot s_l \cdot (df_{l1} \Delta df_{l2})\]
where:
- \(df_{r1} \Delta df_{r2}\) is d_in[self.right]
- \(df_{l1} \Delta df_{l2}\) is d_in[self.left]
- TruncationStrategy.DROP has a stability equal to the truncation threshold (adding one row can cause up to that many rows to be dropped).
- TruncationStrategy.TRUNCATE has a stability of 2 (adding a new row can both add a row to the output and displace another row).
- TruncationStrategy.NO_TRUNCATION has infinite stability.
>>> # TRUNCATE has a stability of 2
>>> s_r = s_l = private_join.truncation_strategy_stability(
...     TruncationStrategy.TRUNCATE, 1
... )
>>> tau_r = tau_l = 2
>>> tau_l * s_r * 1 + tau_r * s_l * 1
8
>>> private_join.stability_function({"left": 1, "right": 1})
8
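As a worked instance of the formula (plain arithmetic, not library output): with TruncationStrategy.DROP on both sides and thresholds \(\tau_l = \tau_r = 3\), each strategy's stability equals its threshold, so \(s_l = s_r = 3\); for d_in = {"left": 1, "right": 1} the bound is
\[3 \cdot 3 \cdot 1 + 3 \cdot 3 \cdot 1 = 18\]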
- Parameters:
input_domain (tmlt.core.domains.collections.DictDomain)
left_key (Any)
right_key (Any)
left_truncation_strategy (TruncationStrategy)
right_truncation_strategy (TruncationStrategy)
left_truncation_threshold (Union[int, float])
right_truncation_threshold (Union[int, float])
join_cols (Optional[List[str]])
join_on_nulls (bool)
- property left_key: Any#
Returns key to left DataFrame.
- Return type:
Any
- property right_key: Any#
Returns key to right DataFrame.
- Return type:
Any
- property left_truncation_strategy: TruncationStrategy#
Returns TruncationStrategy for truncating the left DataFrame.
- Return type:
TruncationStrategy
- property right_truncation_strategy: TruncationStrategy#
Returns TruncationStrategy for truncating the right DataFrame.
- Return type:
TruncationStrategy
- property left_truncation_threshold: int | float#
Returns the threshold for truncating the left DataFrame.
- property right_truncation_threshold: int | float#
Returns the threshold for truncating the right DataFrame.
- property input_domain: tmlt.core.domains.base.Domain#
Return input domain for the transformation.
- Return type:
tmlt.core.domains.base.Domain
- property input_metric: tmlt.core.metrics.Metric#
Distance metric on input domain.
- Return type:
tmlt.core.metrics.Metric
- property output_domain: tmlt.core.domains.base.Domain#
Return output domain for the transformation.
- Return type:
tmlt.core.domains.base.Domain
- property output_metric: tmlt.core.metrics.Metric#
Distance metric on output domain.
- Return type:
tmlt.core.metrics.Metric
- __init__(input_domain, left_key, right_key, left_truncation_strategy, right_truncation_strategy, left_truncation_threshold, right_truncation_threshold, join_cols=None, join_on_nulls=False)#
Constructor.
The following conditions are checked (a sketch of a construction that fails the last check appears after the parameter list below):
- input_domain is a DictDomain with 2 SparkDataFrameDomains.
- left and right are the two keys in the input domain.
- join_cols is not empty, when provided or computed (if None).
- Columns in join_cols are common to both tables.
- Columns in join_cols have matching column types in both tables.
- Parameters:
input_domain (DictDomain) – Domain of input dictionaries (with exactly two keys).
left_key (Any) – Key for the left DataFrame.
right_key (Any) – Key for the right DataFrame.
left_truncation_strategy (TruncationStrategy) – TruncationStrategy to use for truncating the left DataFrame.
right_truncation_strategy (TruncationStrategy) – TruncationStrategy to use for truncating the right DataFrame.
left_truncation_threshold (Union[int, float]) – The maximum number of rows to allow for each combination of values of join_cols in the left DataFrame.
right_truncation_threshold (Union[int, float]) – The maximum number of rows to allow for each combination of values of join_cols in the right DataFrame.
join_cols (Optional[List[str]]) – Columns to perform join on. If None, a natural join is computed.
join_on_nulls (bool) – If True, null values on corresponding join columns of both dataframes will be considered to be equal.
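As an illustration of the type-matching condition above, here is a hypothetical failing construction (the exact exception type and message are assumptions, not taken from the library):
>>> # Sketch: mismatched join-column types should fail validation.
>>> bad_domain = DictDomain(
...     {
...         "left": SparkDataFrameDomain(
...             {"B": SparkStringColumnDescriptor()}
...         ),
...         "right": SparkDataFrameDomain(
...             {"B": SparkIntegerColumnDescriptor()}
...         ),
...     }
... )
>>> PrivateJoin(
...     input_domain=bad_domain,
...     left_key="left",
...     right_key="right",
...     left_truncation_strategy=TruncationStrategy.TRUNCATE,
...     right_truncation_strategy=TruncationStrategy.TRUNCATE,
...     left_truncation_threshold=1,
...     right_truncation_threshold=1,
... )
Traceback (most recent call last):
    ...
ValueError: ...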
- static truncation_strategy_stability(truncation_strategy, threshold)#
Returns the stability for the given truncation strategy.
- Parameters:
truncation_strategy (TruncationStrategy)
threshold (Union[int, float])
- Return type:
tmlt.core.utils.exact_number.ExactNumber
- stability_function(d_in)#
Returns the smallest d_out satisfied by the transformation.
See the architecture overview for more information.
- Parameters:
d_in (Dict[Any, tmlt.core.utils.exact_number.ExactNumberInput]) – Distance between inputs under input_metric.
- Return type:
tmlt.core.utils.exact_number.ExactNumber
- __call__(dfs)#
Perform join.
- Parameters:
dfs (Dict[Any, pyspark.sql.DataFrame])
- Return type:
pyspark.sql.DataFrame
- stability_relation(d_in, d_out)#
Returns True only if close inputs produce close outputs.
See the privacy and stability tutorial for more information.
- Parameters:
d_in (Any) – Distance between inputs under input_metric.
d_out (Any) – Distance between outputs under output_metric.
- Return type:
bool
- __or__(other: Transformation) Transformation #
- __or__(other: tmlt.core.measurements.base.Measurement) tmlt.core.measurements.base.Measurement
Return this transformation chained with another component.
- class PrivateJoinOnKey(input_domain, input_metric, left_key, right_key, new_key, join_cols=None, join_on_nulls=False)#
Bases:
tmlt.core.transformations.base.Transformation
Join two private SparkDataFrames including a key column.
Example
>>> # Example input
>>> print_sdf(left_spark_dataframe)
    A   B  X
0  a1  b1  2
1  a1  b1  3
2  a1  b1  5
3  a1  b2 -1
4  a1  b2  4
5  a2  b1 -5
>>> print_sdf(right_spark_dataframe)
    B   C
0  b1  c1
1  b2  c2
2  b2  c3
>>> print_sdf(ignored_dataframe)
    B   D
0  b1  d1
1  b2  d1
2  b2  d2
>>> # Create transformation
>>> left_domain = SparkDataFrameDomain(
...     {
...         "A": SparkStringColumnDescriptor(),
...         "B": SparkStringColumnDescriptor(),
...         "X": SparkIntegerColumnDescriptor(),
...     },
... )
>>> assert left_spark_dataframe in left_domain
>>> right_domain = SparkDataFrameDomain(
...     {
...         "B": SparkStringColumnDescriptor(),
...         "C": SparkStringColumnDescriptor(),
...     },
... )
>>> assert right_spark_dataframe in right_domain
>>> ignored_domain = SparkDataFrameDomain(
...     {
...         "B": SparkStringColumnDescriptor(),
...         "D": SparkStringColumnDescriptor(),
...     },
... )
>>> assert ignored_dataframe in ignored_domain
>>> private_join = PrivateJoinOnKey(
...     input_domain=DictDomain(
...         {
...             "left": left_domain,
...             "right": right_domain,
...             "ignored": ignored_domain,
...         }
...     ),
...     input_metric=AddRemoveKeys(
...         {
...             "left": "B",
...             "right": "B",
...             "ignored": "B",
...         }
...     ),
...     left_key="left",
...     right_key="right",
...     new_key="joined",
... )
>>> input_dictionary = {
...     "left": left_spark_dataframe,
...     "right": right_spark_dataframe,
...     "ignored": ignored_dataframe,
... }
>>> # Apply transformation to data
>>> output_dictionary = private_join(input_dictionary)
>>> assert left_spark_dataframe is output_dictionary["left"]
>>> assert right_spark_dataframe is output_dictionary["right"]
>>> assert ignored_dataframe is output_dictionary["ignored"]
>>> joined_dataframe = output_dictionary["joined"]
>>> print_sdf(joined_dataframe)
    B   A  X   C
0  b1  a1  2  c1
1  b1  a1  3  c1
2  b1  a1  5  c1
3  b1  a2 -5  c1
4  b2  a1 -1  c2
5  b2  a1 -1  c3
6  b2  a1  4  c2
7  b2  a1  4  c3
- Transformation Contract:
Input domain - DictDomain containing two or more SparkDataFrame domains.
Output domain - The same as the input DictDomain, with the addition of a new SparkDataFrameDomain for the joined table.
Input metric - AddRemoveKeys
Output metric - AddRemoveKeys
>>> private_join.input_domain
DictDomain(key_to_domain={'left': SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False), 'X': SparkIntegerColumnDescriptor(allow_null=False, size=64)}), 'right': SparkDataFrameDomain(schema={'B': SparkStringColumnDescriptor(allow_null=False), 'C': SparkStringColumnDescriptor(allow_null=False)}), 'ignored': SparkDataFrameDomain(schema={'B': SparkStringColumnDescriptor(allow_null=False), 'D': SparkStringColumnDescriptor(allow_null=False)})})
>>> private_join.output_domain
DictDomain(key_to_domain={'left': SparkDataFrameDomain(schema={'A': SparkStringColumnDescriptor(allow_null=False), 'B': SparkStringColumnDescriptor(allow_null=False), 'X': SparkIntegerColumnDescriptor(allow_null=False, size=64)}), 'right': SparkDataFrameDomain(schema={'B': SparkStringColumnDescriptor(allow_null=False), 'C': SparkStringColumnDescriptor(allow_null=False)}), 'ignored': SparkDataFrameDomain(schema={'B': SparkStringColumnDescriptor(allow_null=False), 'D': SparkStringColumnDescriptor(allow_null=False)}), 'joined': SparkDataFrameDomain(schema={'B': SparkStringColumnDescriptor(allow_null=False), 'A': SparkStringColumnDescriptor(allow_null=False), 'X': SparkIntegerColumnDescriptor(allow_null=False, size=64), 'C': SparkStringColumnDescriptor(allow_null=False)})})
>>> private_join.input_metric
AddRemoveKeys(df_to_key_column={'left': 'B', 'right': 'B', 'ignored': 'B'})
>>> private_join.output_metric
AddRemoveKeys(df_to_key_column={'left': 'B', 'right': 'B', 'ignored': 'B', 'joined': 'B'})
- Stability Guarantee:
PrivateJoinOnKey's stability_function() returns d_in.
>>> private_join.stability_function(1)
1
>>> private_join.stability_function(2)
2
- Parameters:
input_domain (tmlt.core.domains.collections.DictDomain)
input_metric (tmlt.core.metrics.AddRemoveKeys)
left_key (Any)
right_key (Any)
new_key (Any)
join_cols (Optional[List[str]])
join_on_nulls (bool)
- property left_key: Any#
Returns key to left DataFrame.
- Return type:
Any
- property right_key: Any#
Returns key to right DataFrame.
- Return type:
Any
- property new_key: Any#
Returns key to output DataFrame.
- Return type:
Any
- property input_domain: tmlt.core.domains.base.Domain#
Return input domain for the transformation.
- Return type:
tmlt.core.domains.base.Domain
- property input_metric: tmlt.core.metrics.Metric#
Distance metric on input domain.
- Return type:
tmlt.core.metrics.Metric
- property output_domain: tmlt.core.domains.base.Domain#
Return output domain for the transformation.
- Return type:
tmlt.core.domains.base.Domain
- property output_metric: tmlt.core.metrics.Metric#
Distance metric on output domain.
- Return type:
tmlt.core.metrics.Metric
- __init__(input_domain, input_metric, left_key, right_key, new_key, join_cols=None, join_on_nulls=False)#
Constructor.
- Parameters:
input_domain (DictDomain) – Domain of the input dictionaries. Must contain left_key and right_key, but may also contain other keys.
input_metric (AddRemoveKeys) – AddRemoveKeys metric for the input dictionaries. The left and right dataframes must use the same key column.
left_key (Any) – Key for the left DataFrame.
right_key (Any) – Key for the right DataFrame.
new_key (Any) – Key for the output DataFrame.
join_cols (Optional[List[str]]) – Columns to perform join on. If None or empty, a natural join is computed.
join_on_nulls (bool) – If True, null values on corresponding join columns of both dataframes will be considered equal.
- stability_function(d_in)#
Returns the smallest d_out satisfied by the transformation.
See the architecture overview for more information on transformations.
- Parameters:
d_in (tmlt.core.utils.exact_number.ExactNumberInput) – Distance between inputs under input_metric.
- Return type:
tmlt.core.utils.exact_number.ExactNumber
- __call__(dfs)#
Perform join.
- Parameters:
dfs (Dict[Any, pyspark.sql.DataFrame])
- Return type:
Dict[Any, pyspark.sql.DataFrame]
- stability_relation(d_in, d_out)#
Returns True only if close inputs produce close outputs.
See the privacy and stability tutorial for more information.
- Parameters:
d_in (Any) – Distance between inputs under input_metric.
d_out (Any) – Distance between outputs under output_metric.
- Return type:
bool
- __or__(other: Transformation) Transformation #
- __or__(other: tmlt.core.measurements.base.Measurement) tmlt.core.measurements.base.Measurement
Return this transformation chained with another component.