_output_schema_visitor#

Defines a visitor for determining the output schemas of query expressions.

Classes#

OutputSchemaVisitor

A visitor to get the output schema of a query expression.

class OutputSchemaVisitor(catalog)#

Bases: tmlt.analytics.query_expr.QueryExprVisitor

A visitor to get the output schema of a query expression.

Methods#

visit_private_source()

Returns the resulting schema from evaluating a PrivateSource.

visit_rename()

Returns the resulting schema from evaluating a Rename.

visit_filter()

Returns the resulting schema from evaluating a Filter.

visit_select()

Returns the resulting schema from evaluating a Select.

visit_map()

Returns the resulting schema from evaluating a Map.

visit_flat_map()

Returns the resulting schema from evaluating a FlatMap.

visit_join_private()

Returns the resulting schema from evaluating a JoinPrivate.

visit_join_public()

Returns the resulting schema from evaluating a JoinPublic.

visit_replace_null_and_nan()

Returns the resulting schema from evaluating a ReplaceNullAndNan.

visit_replace_infinity()

Returns the resulting schema from evaluating a ReplaceInfinity.

visit_drop_null_and_nan()

Returns the resulting schema from evaluating a DropNullAndNan.

visit_drop_infinity()

Returns the resulting schema from evaluating a DropInfinity.

visit_enforce_constraint()

Returns the resulting schema from evaluating an EnforceConstraint.

visit_get_groups()

Returns the resulting schema from evaluating a GetGroups.

visit_groupby_count()

Returns the resulting schema from evaluating a GroupByCount.

visit_groupby_count_distinct()

Returns the resulting schema from evaluating a GroupByCountDistinct.

visit_groupby_quantile()

Returns the resulting schema from evaluating a GroupByQuantile.

visit_groupby_bounded_sum()

Returns the resulting schema from evaluating a GroupByBoundedSum.

visit_groupby_bounded_average()

Returns the resulting schema from evaluating a GroupByBoundedAverage.

visit_groupby_bounded_variance()

Returns the resulting schema from evaluating a GroupByBoundedVariance.

visit_groupby_bounded_stdev()

Returns the resulting schema from evaluating a GroupByBoundedSTDEV.

Parameters

catalog (tmlt.analytics._catalog.Catalog) – The catalog defining schemas and relations between tables.

__init__(catalog)#

Visitor constructor.

Parameters

catalog (Catalog) – The catalog defining schemas and relations between tables.

visit_private_source(expr)#

Returns the resulting schema from evaluating a PrivateSource.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={"A": ColumnType.VARCHAR, "B": ColumnType.INTEGER},
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> query = PrivateSource("private")
>>> query.accept(output_schema_visitor).column_types
{'A': 'VARCHAR', 'B': 'INTEGER'}
Parameters

expr (tmlt.analytics.query_expr.PrivateSource) –

Return type

tmlt.analytics._schema.Schema

visit_rename(expr)#

Returns the resulting schema from evaluating a Rename.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={"A": ColumnType.VARCHAR, "B": ColumnType.INTEGER},
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> query = Rename(PrivateSource("private"), {"B": "C"})
>>> query.accept(output_schema_visitor).column_types
{'A': 'VARCHAR', 'C': 'INTEGER'}
Parameters

expr (tmlt.analytics.query_expr.Rename) –

Return type

tmlt.analytics._schema.Schema

visit_filter(expr)#

Returns the resulting schema from evaluating a Filter.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={"A": ColumnType.VARCHAR, "B": ColumnType.INTEGER},
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> query = Filter(PrivateSource("private"), 'B > 10')
>>> query.accept(output_schema_visitor).column_types
{'A': 'VARCHAR', 'B': 'INTEGER'}
Parameters

expr (tmlt.analytics.query_expr.Filter) –

Return type

tmlt.analytics._schema.Schema

visit_select(expr)#

Returns the resulting schema from evaluating a Select.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={"A": ColumnType.VARCHAR, "B": ColumnType.INTEGER},
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> query = Select(PrivateSource("private"), ["A"])
>>> query.accept(output_schema_visitor).column_types
{'A': 'VARCHAR'}
Parameters

expr (tmlt.analytics.query_expr.Select) –

Return type

tmlt.analytics._schema.Schema

visit_map(expr)#

Returns the resulting schema from evaluating a Map.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={
...         "A": ColumnType.VARCHAR,
...         "B": ColumnType.INTEGER,
...     },
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> query1 = Map( # Augment = False example
...     child=PrivateSource("private"),
...     f=lambda row: {"C": row["B"] + 1, "D": "A"},
...     schema_new_columns=Schema(
...         {"C": ColumnType.INTEGER, "D": ColumnType.VARCHAR}
...     ),
...     augment=False,
... )
>>> query1.accept(output_schema_visitor).column_types
{'C': 'INTEGER', 'D': 'VARCHAR'}
>>> query2 = Map( # Augment = True example
...     child=PrivateSource("private"),
...     f=lambda row: {"C": row["B"] + 1, "D": "A"},
...     schema_new_columns=Schema(
...         {"C": ColumnType.INTEGER, "D": ColumnType.VARCHAR}
...     ),
...     augment=True,
... )
>>> query2.accept(output_schema_visitor).column_types
{'A': 'VARCHAR', 'B': 'INTEGER', 'C': 'INTEGER', 'D': 'VARCHAR'}
Parameters

expr (tmlt.analytics.query_expr.Map) –

Return type

tmlt.analytics._schema.Schema

visit_flat_map(expr)#

Returns the resulting schema from evaluating a FlatMap.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={"A": ColumnType.VARCHAR, "B": ColumnType.INTEGER},
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> query1 = FlatMap( # Augment = False example
...     child=PrivateSource("private"),
...     f=lambda row: [{"C": row["B"]}, {"C": row["B"] + 1}],
...     schema_new_columns=Schema({"C": ColumnType.INTEGER}),
...     augment=False,
...     max_rows=2,
... )
>>> query1.accept(output_schema_visitor).column_types
{'C': 'INTEGER'}
>>> query2 = FlatMap( # Augment = True example
...     child=PrivateSource("private"),
...     f=lambda row: [{"C": row["B"]}, {"C": row["B"] + 1}],
...     schema_new_columns=Schema({"C": ColumnType.INTEGER}),
...     augment=True,
...     max_rows=2,
... )
>>> query2.accept(output_schema_visitor).column_types
{'A': 'VARCHAR', 'B': 'INTEGER', 'C': 'INTEGER'}
>>> query3 = FlatMap( # Grouping example
...     child=PrivateSource("private"),
...     f=lambda row: [{"C": row["B"]}, {"C": row["B"] + 1}],
...     schema_new_columns=Schema(
...         {"C": ColumnType.INTEGER}, grouping_column="C",
...     ),
...     augment=True,
...     max_rows=2,
... )
>>> query3.accept(output_schema_visitor).column_types
{'A': 'VARCHAR', 'B': 'INTEGER', 'C': 'INTEGER'}
Parameters

expr (tmlt.analytics.query_expr.FlatMap) –

Return type

tmlt.analytics._schema.Schema

visit_join_private(expr)#

Returns the resulting schema from evaluating a JoinPrivate.

The ordering of output columns is:

  1. The join columns

  2. Columns that are only in the left table

  3. Columns that are only in the right table

  4. Columns that are in both tables, but not included in the join columns. These columns are included with _left and _right suffixes.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="left_source",
...     col_types={
...         "left_only": ColumnType.DECIMAL,
...         "common1": ColumnType.INTEGER,
...         "common2": ColumnType.VARCHAR,
...         "common3": ColumnType.INTEGER
...     },
... )
>>> catalog.add_private_table(
...     source_id="right_source",
...     col_types={
...         "common1": ColumnType.INTEGER,
...         "common2": ColumnType.VARCHAR,
...         "common3": ColumnType.INTEGER,
...         "right_only": ColumnType.VARCHAR,
...    },
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> # join_columns default behavior is ["common1", "common2", "common3"]
>>> query1 = JoinPrivate(
...     child=PrivateSource("left_source"),
...     right_operand_expr=PrivateSource("right_source"),
...     truncation_strategy_left=TruncationStrategy.DropExcess(1),
...     truncation_strategy_right=TruncationStrategy.DropExcess(1),
... )
>>> query1.accept(output_schema_visitor).column_types
{'common1': 'INTEGER', 'common2': 'VARCHAR', 'common3': 'INTEGER', 'left_only': 'DECIMAL', 'right_only': 'VARCHAR'}
>>> query2 = JoinPrivate(
...     child=PrivateSource("left_source"),
...     right_operand_expr=PrivateSource("right_source"),
...     truncation_strategy_left=TruncationStrategy.DropExcess(1),
...     truncation_strategy_right=TruncationStrategy.DropExcess(1),
...     join_columns=["common3"],
... )
>>> query2.accept(output_schema_visitor).column_types
{'common3': 'INTEGER', 'left_only': 'DECIMAL', 'common1_left': 'INTEGER', 'common2_left': 'VARCHAR', 'common1_right': 'INTEGER', 'common2_right': 'VARCHAR', 'right_only': 'VARCHAR'}
Parameters

expr (tmlt.analytics.query_expr.JoinPrivate) –

Return type

tmlt.analytics._schema.Schema

visit_join_public(expr)#

Returns the resulting schema from evaluating a JoinPublic.

Has analogous behavior to OutputSchemaVisitor.visit_join_private(), where the private table is the left table.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={"A": ColumnType.VARCHAR, "B": ColumnType.INTEGER},
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> catalog.add_public_table(
...     "public", {"B": ColumnType.INTEGER, "C": ColumnType.DECIMAL}
... )
>>> query = JoinPublic(
...    child=PrivateSource("private"), public_table="public"
... )
>>> query.accept(output_schema_visitor).column_types
{'B': 'INTEGER', 'A': 'VARCHAR', 'C': 'DECIMAL'}
Parameters

expr (tmlt.analytics.query_expr.JoinPublic) –

Return type

tmlt.analytics._schema.Schema

visit_replace_null_and_nan(expr)#

Returns the resulting schema from evaluating a ReplaceNullAndNan.

Parameters

expr (tmlt.analytics.query_expr.ReplaceNullAndNan) –

Return type

tmlt.analytics._schema.Schema

visit_replace_infinity(expr)#

Returns the resulting schema from evaluating a ReplaceInfinity.

Parameters

expr (tmlt.analytics.query_expr.ReplaceInfinity) –

Return type

tmlt.analytics._schema.Schema

visit_drop_null_and_nan(expr)#

Returns the resulting schema from evaluating a DropNullAndNan.

Parameters

expr (tmlt.analytics.query_expr.DropNullAndNan) –

Return type

tmlt.analytics._schema.Schema

visit_drop_infinity(expr)#

Returns the resulting schema from evaluating a DropInfinity.

Parameters

expr (tmlt.analytics.query_expr.DropInfinity) –

Return type

tmlt.analytics._schema.Schema

visit_enforce_constraint(expr)#

Returns the resulting schema from evaluating an EnforceConstraint.

Parameters

expr (tmlt.analytics.query_expr.EnforceConstraint) –

Return type

tmlt.analytics._schema.Schema

visit_get_groups(expr)#

Returns the resulting schema from evaluating a GetGroups.

Parameters

expr (tmlt.analytics.query_expr.GetGroups) –

Return type

tmlt.analytics._schema.Schema

visit_groupby_count(expr)#

Returns the resulting schema from evaluating a GroupByCount.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={"A": ColumnType.VARCHAR, "B": ColumnType.INTEGER},
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> query = GroupByCount(
...     child=PrivateSource("private"),
...     groupby_keys=KeySet.from_dict({"A": ["a1", "a2", "a3"]}),
...     output_column="count",
... )
>>> query.accept(output_schema_visitor).column_types
{'A': 'VARCHAR', 'count': 'INTEGER'}
Parameters

expr (tmlt.analytics.query_expr.GroupByCount) –

Return type

tmlt.analytics._schema.Schema

visit_groupby_count_distinct(expr)#

Returns the resulting schema from evaluating a GroupByCountDistinct.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={"A": ColumnType.VARCHAR, "B": ColumnType.INTEGER},
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> query = GroupByCountDistinct(
...     child=PrivateSource("private"),
...     groupby_keys=KeySet.from_dict({"A": ["a1", "a2", "a3"]}),
...     output_column="count_distinct",
... )
>>> query.accept(output_schema_visitor).column_types
{'A': 'VARCHAR', 'count_distinct': 'INTEGER'}
Parameters

expr (tmlt.analytics.query_expr.GroupByCountDistinct) –

Return type

tmlt.analytics._schema.Schema

visit_groupby_quantile(expr)#

Returns the resulting schema from evaluating a GroupByQuantile.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={"A": ColumnType.VARCHAR, "B": ColumnType.INTEGER},
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> query = GroupByQuantile(
...     child=PrivateSource("private"),
...     groupby_keys=KeySet.from_dict({"A": ["a1", "a2", "a3"]}),
...     measure_column="B",
...     quantile=0.5,
...     low=0,
...     high=10,
...     output_column="quantile",
... )
>>> query.accept(output_schema_visitor).column_types
{'A': 'VARCHAR', 'quantile': 'DECIMAL'}
Parameters

expr (tmlt.analytics.query_expr.GroupByQuantile) –

Return type

tmlt.analytics._schema.Schema

visit_groupby_bounded_sum(expr)#

Returns the resulting schema from evaluating a GroupByBoundedSum.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={"A": ColumnType.VARCHAR, "B": ColumnType.INTEGER},
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> query = GroupByBoundedSum(
...     child=PrivateSource("private"),
...     groupby_keys=KeySet.from_dict({"A": ["a1", "a2", "a3"]}),
...     measure_column="B",
...     low=0,
...     high=10,
...     output_column="sum",
... )
>>> query.accept(output_schema_visitor).column_types
{'A': 'VARCHAR', 'sum': 'INTEGER'}
Parameters

expr (tmlt.analytics.query_expr.GroupByBoundedSum) –

Return type

tmlt.analytics._schema.Schema

visit_groupby_bounded_average(expr)#

Returns the resulting schema from evaluating a GroupByBoundedAverage.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={"A": ColumnType.VARCHAR, "B": ColumnType.INTEGER},
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> query = GroupByBoundedAverage(
...     child=PrivateSource("private"),
...     groupby_keys=KeySet.from_dict({"A": ["a1", "a2", "a3"]}),
...     measure_column="B",
...     low=0,
...     high=10,
...     output_column="average",
... )
>>> query.accept(output_schema_visitor).column_types
{'A': 'VARCHAR', 'average': 'DECIMAL'}
Parameters

expr (tmlt.analytics.query_expr.GroupByBoundedAverage) –

Return type

tmlt.analytics._schema.Schema

visit_groupby_bounded_variance(expr)#

Returns the resulting schema from evaluating a GroupByBoundedVariance.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={"A": ColumnType.VARCHAR, "B": ColumnType.INTEGER},
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> query = GroupByBoundedVariance(
...     child=PrivateSource("private"),
...     groupby_keys=KeySet.from_dict({"A": ["a1", "a2", "a3"]}),
...     measure_column="B",
...     low=0,
...     high=10,
...     output_column="variance",
... )
>>> query.accept(output_schema_visitor).column_types
{'A': 'VARCHAR', 'variance': 'DECIMAL'}
Parameters

expr (tmlt.analytics.query_expr.GroupByBoundedVariance) –

Return type

tmlt.analytics._schema.Schema

visit_groupby_bounded_stdev(expr)#

Returns the resulting schema from evaluating a GroupByBoundedSTDEV.

Example

>>> catalog = Catalog()
>>> catalog.add_private_table(
...     source_id="private",
...     col_types={"A": ColumnType.VARCHAR, "B": ColumnType.INTEGER},
... )
>>> output_schema_visitor = OutputSchemaVisitor(catalog)
>>> query = GroupByBoundedSTDEV(
...     child=PrivateSource("private"),
...     groupby_keys=KeySet.from_dict({"A": ["a1", "a2", "a3"]}),
...     measure_column="B",
...     low=0,
...     high=10,
...     output_column="stdev",
... )
>>> query.accept(output_schema_visitor).column_types
{'A': 'VARCHAR', 'stdev': 'DECIMAL'}
Parameters

expr (tmlt.analytics.query_expr.GroupByBoundedSTDEV) –

Return type

tmlt.analytics._schema.Schema