139 changes: 139 additions & 0 deletions release/nightly_tests/dataset/tpch/tpch_q20.py
@@ -0,0 +1,139 @@
import ray
from ray.data.aggregate import Sum
from ray.data.expressions import col
from common import parse_tpch_args, load_table, to_f64, run_tpch_benchmark


def main(args):
def benchmark_fn():
from datetime import datetime

# Q20: Potential Part Promotion Query
# Identify suppliers in a given nation having parts with excess
# inventory (available quantity exceeds 50% of the quantity shipped
# in a given year for forest-colored parts).
#
# Equivalent SQL:
# SELECT s_name, s_address
# FROM supplier, nation
# WHERE s_suppkey IN (
# SELECT ps_suppkey
# FROM partsupp
# WHERE ps_partkey IN (
# SELECT p_partkey
# FROM part
# WHERE p_name LIKE 'forest%'
# )
# AND ps_availqty > (
# SELECT 0.5 * SUM(l_quantity)
# FROM lineitem
# WHERE l_partkey = ps_partkey
# AND l_suppkey = ps_suppkey
# AND l_shipdate >= DATE '1994-01-01'
# AND l_shipdate < DATE '1995-01-01'
# )
# )
# AND s_nationkey = n_nationkey
# AND n_name = 'CANADA'
# ORDER BY s_name;
#
# Note:
# The innermost IN subquery is a simple part filter turned into a
# left_semi join. The correlated scalar subquery on lineitem is
# decorrelated by pre-aggregating SUM(l_quantity) grouped by
# (l_partkey, l_suppkey), then joining with partsupp on the same
# composite key and applying the threshold filter.
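#
# Illustrative walk-through of the decorrelation (made-up numbers, not
# from the benchmark): suppose partsupp has a row (p1, s1, availqty=100)
# and lineitem ships 180 units of (p1, s1) during 1994. The correlated
# subquery would compute 0.5 * 180 = 90 for that row; the decorrelated
# plan instead computes sum_qty=180 once in the grouped aggregate, joins
# it back on (ps_partkey, ps_suppkey), and keeps the row because
# 100 > 90. Both plans yield the same qualifying (part, supplier) pairs.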

# Load tables with early projection.
supplier = load_table("supplier", args.sf).select_columns(
["s_suppkey", "s_name", "s_address", "s_nationkey"]
)
nation = load_table("nation", args.sf).select_columns(["n_nationkey", "n_name"])
part = load_table("part", args.sf).select_columns(["p_partkey", "p_name"])
partsupp = load_table("partsupp", args.sf).select_columns(
["ps_partkey", "ps_suppkey", "ps_availqty"]
)
lineitem = load_table("lineitem", args.sf).select_columns(
["l_partkey", "l_suppkey", "l_quantity", "l_shipdate"]
)

# Q20 parameters
color = "forest"
nation_name = "CANADA"
date_start = datetime(1994, 1, 1)
date_end = datetime(1995, 1, 1)

# ── Innermost subquery: forest parts ────────────────────────────
forest_parts = part.filter(
expr=col("p_name").str.starts_with(color)
).select_columns(["p_partkey"])

# ── Restrict partsupp to forest parts (IN subquery) ─────────────
ps_forest = partsupp.join(
forest_parts,
join_type="left_semi",
num_partitions=16,
Contributor (medium):
Hardcoding num_partitions=16 is likely too low for Scale Factor 100 (100 GB), where tables like lineitem and partsupp contain hundreds of millions of rows. This can lead to excessively large partitions (several GB each), causing memory pressure or underutilization of the cluster. It is generally better to let Ray Data determine the number of partitions automatically, or to set a much higher value (e.g., 200+) at this scale.

Member Author:
It looks like a convention across the existing tests.

on=("ps_partkey",),
right_on=("p_partkey",),
)
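
# A hedged alternative to the hardcoded partition count discussed in the
# review thread above (a sketch, not what this benchmark does): scale the
# partition count with the cluster, e.g.
#     num_partitions = 2 * int(ray.cluster_resources().get("CPU", 8))
# ray.cluster_resources() is a real Ray API; the 2x-CPUs heuristic is an
# assumption, not a Ray Data recommendation.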

# ── Decorrelate scalar subquery on lineitem ─────────────────────
# Pre-aggregate: SUM(l_quantity) grouped by (l_partkey, l_suppkey)
# for lineitems in the target date range.
li_filtered = lineitem.filter(
expr=(col("l_shipdate") >= date_start) & (col("l_shipdate") < date_end)
)
li_agg = li_filtered.groupby(["l_partkey", "l_suppkey"]).aggregate(
Sum(on="l_quantity", alias_name="sum_qty")
)
li_agg = li_agg.with_column("sum_qty_f", to_f64(col("sum_qty"))).select_columns(
["l_partkey", "l_suppkey", "sum_qty_f"]
)

# ── Join partsupp with lineitem aggregate and apply threshold ───
ps_forest = ps_forest.with_column("ps_availqty_f", to_f64(col("ps_availqty")))
ps_li = ps_forest.join(
li_agg,
join_type="inner",
num_partitions=16,
on=("ps_partkey", "ps_suppkey"),
right_on=("l_partkey", "l_suppkey"),
)
qualified_ps = ps_li.filter(
expr=col("ps_availqty_f") > 0.5 * col("sum_qty_f")
).select_columns(["ps_suppkey"])

# ── Main pipeline: Canadian suppliers with qualifying parts ─────
nation_filtered = nation.filter(expr=col("n_name") == nation_name)
canadian_suppliers = supplier.join(
nation_filtered,
join_type="inner",
num_partitions=16,
on=("s_nationkey",),
right_on=("n_nationkey",),
).select_columns(["s_suppkey", "s_name", "s_address"])

result = canadian_suppliers.join(
qualified_ps,
join_type="left_semi",
num_partitions=16,
on=("s_suppkey",),
right_on=("ps_suppkey",),
)

_ = (
result.select_columns(["s_name", "s_address"])
.sort(key="s_name")
.materialize()
)

# Report arguments for the benchmark.
return vars(args)

run_tpch_benchmark("tpch_q20", benchmark_fn)


if __name__ == "__main__":
ray.init()
args = parse_tpch_args()
main(args)
161 changes: 161 additions & 0 deletions release/nightly_tests/dataset/tpch/tpch_q21.py
@@ -0,0 +1,161 @@
import ray
from ray.data.aggregate import Count, CountDistinct
from ray.data.expressions import col
from common import parse_tpch_args, load_table, run_tpch_benchmark


def main(args):
def benchmark_fn():
# Q21: Suppliers Who Kept Orders Waiting Query
# Identify suppliers in a given nation whose shipments were received
# late, where at least one other supplier also filled the same order
# but none of those other suppliers delivered late.
#
# Equivalent SQL:
# SELECT s_name, COUNT(*) AS numwait
# FROM supplier, lineitem l1, orders, nation
# WHERE s_suppkey = l1.l_suppkey
# AND o_orderkey = l1.l_orderkey
# AND o_orderstatus = 'F'
# AND l1.l_receiptdate > l1.l_commitdate
# AND EXISTS (
# SELECT * FROM lineitem l2
# WHERE l2.l_orderkey = l1.l_orderkey
# AND l2.l_suppkey <> l1.l_suppkey
# )
# AND NOT EXISTS (
# SELECT * FROM lineitem l3
# WHERE l3.l_orderkey = l1.l_orderkey
# AND l3.l_suppkey <> l1.l_suppkey
# AND l3.l_receiptdate > l3.l_commitdate
# )
# AND s_nationkey = n_nationkey
# AND n_name = 'SAUDI ARABIA'
# GROUP BY s_name
# ORDER BY numwait DESC, s_name
# LIMIT 100;
#
# Note:
# The EXISTS and NOT EXISTS subqueries both use inequality predicates
# (l_suppkey <> l1.l_suppkey) which cannot be expressed as equi-join
# conditions. Instead we decorrelate them using pre-aggregated counts:
# - EXISTS (another supplier for the same order)
# ⟺ COUNT(DISTINCT l_suppkey) per order > 1
# - NOT EXISTS (no other LATE supplier for the same order)
# ⟺ COUNT(DISTINCT l_suppkey) among late lineitems per order == 1
# (since l1 itself is the only late supplier)
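#
# Illustrative example (made-up rows): order O1 has lineitems from
# suppliers {A, B}, and only A was late. Then per order O1:
# num_suppliers = 2 (> 1, so EXISTS holds) and num_late_suppliers = 1
# (A alone, so NOT EXISTS holds), and A's late lineitems on O1 count
# toward A's numwait. If B had also been late, num_late_suppliers would
# be 2 and both suppliers' rows would be filtered out.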

# Load tables with early projection.
supplier = load_table("supplier", args.sf).select_columns(
["s_suppkey", "s_name", "s_nationkey"]
)
lineitem = load_table("lineitem", args.sf).select_columns(
["l_orderkey", "l_suppkey", "l_receiptdate", "l_commitdate"]
)
orders = load_table("orders", args.sf).select_columns(
["o_orderkey", "o_orderstatus"]
)
nation = load_table("nation", args.sf).select_columns(["n_nationkey", "n_name"])

# Q21 parameters
nation_name = "SAUDI ARABIA"

# ── Pre-aggregate: distinct suppliers per order (EXISTS) ────────
# If an order has > 1 distinct supplier, there exists "another"
# supplier for any given supplier on that order.
# Filter early to reduce the right-side dataset size before the join.
suppliers_per_order = (
lineitem.select_columns(["l_orderkey", "l_suppkey"])
.groupby("l_orderkey")
.aggregate(CountDistinct(on="l_suppkey", alias_name="num_suppliers"))
.filter(expr=col("num_suppliers") > 1)
)

# ── Pre-aggregate: distinct LATE suppliers per order (NOT EXISTS) ─
# Late lineitem: l_receiptdate > l_commitdate.
# Materialize to avoid recomputing the filter in both the
# late_suppliers_per_order branch and the main pipeline
# (Ray Data performs no common-subexpression elimination).
late_lineitem = (
lineitem.filter(expr=col("l_receiptdate") > col("l_commitdate"))
.select_columns(["l_orderkey", "l_suppkey"])
.materialize()
)
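
# Ray Data datasets are lazy: without materialize(), each of the two
# consumers of late_lineitem (the late-suppliers aggregate above and the
# l1 base rows below) would re-read and re-filter lineitem.
# materialize() executes the filter once and caches the resulting blocks
# in the object store.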

late_suppliers_per_order = (
late_lineitem.groupby("l_orderkey")
.aggregate(CountDistinct(on="l_suppkey", alias_name="num_late_suppliers"))
.filter(expr=col("num_late_suppliers") == 1)
)

# ── Build main pipeline ─────────────────────────────────────────
# Saudi suppliers
saudi_nation = nation.filter(expr=col("n_name") == nation_name)
saudi_suppliers = supplier.join(
saudi_nation,
join_type="inner",
num_partitions=16,
on=("s_nationkey",),
right_on=("n_nationkey",),
).select_columns(["s_suppkey", "s_name"])

# Failed orders
failed_orders = orders.filter(expr=col("o_orderstatus") == "F").select_columns(
["o_orderkey"]
)

# Late lineitem joined with failed orders (l1 base rows)
ds = late_lineitem.join(
failed_orders,
join_type="left_semi",
num_partitions=16,
on=("l_orderkey",),
right_on=("o_orderkey",),
)

# Join with Saudi suppliers
ds = ds.join(
saudi_suppliers,
join_type="inner",
num_partitions=16,
on=("l_suppkey",),
right_on=("s_suppkey",),
)

# EXISTS: another supplier exists for this order (num_suppliers > 1)
# Filter already pushed down to suppliers_per_order.
ds = ds.join(
suppliers_per_order,
join_type="inner",
num_partitions=16,
on=("l_orderkey",),
)

# NOT EXISTS: no other late supplier (num_late_suppliers == 1)
# Filter already pushed down to late_suppliers_per_order.
ds = ds.join(
late_suppliers_per_order,
join_type="inner",
num_partitions=16,
on=("l_orderkey",),
)

# Group by supplier name, count, sort, and limit.
_ = (
ds.groupby("s_name")
.aggregate(Count(alias_name="numwait"))
.sort(key=["numwait", "s_name"], descending=[True, False])
.limit(100)
.materialize()
)

# Report arguments for the benchmark.
return vars(args)

run_tpch_benchmark("tpch_q21", benchmark_fn)


if __name__ == "__main__":
ray.init()
args = parse_tpch_args()
main(args)
96 changes: 96 additions & 0 deletions release/nightly_tests/dataset/tpch/tpch_q22.py
@@ -0,0 +1,96 @@
import ray
from ray.data.aggregate import Count, Sum
from ray.data.expressions import col
from common import parse_tpch_args, load_table, to_f64, run_tpch_benchmark


def main(args):
def benchmark_fn():
# Q22: Global Sales Opportunity Query
# Identify geographic areas where there are customers who may be
# likely to make a purchase (above-average balance, no existing orders).
#
# Equivalent SQL:
# SELECT cntrycode, COUNT(*) AS numcust,
# SUM(c_acctbal) AS totacctbal
# FROM (
# SELECT SUBSTRING(c_phone FROM 1 FOR 2) AS cntrycode,
# c_acctbal
# FROM customer
# WHERE SUBSTRING(c_phone FROM 1 FOR 2)
# IN ('13','31','23','29','30','18','17')
# AND c_acctbal > (
# SELECT AVG(c_acctbal)
# FROM customer
# WHERE c_acctbal > 0.00
# AND SUBSTRING(c_phone FROM 1 FOR 2)
# IN ('13','31','23','29','30','18','17')
# )
# AND NOT EXISTS (
# SELECT * FROM orders WHERE o_custkey = c_custkey
# )
# ) AS custsale
# GROUP BY cntrycode
# ORDER BY cntrycode;
#
# Note:
# The scalar AVG subquery is computed first as a plain float via
# Dataset.mean(). The NOT EXISTS is implemented as a left_anti join.
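# Unlike the lazy joins below, Dataset.mean() is a consuming operation:
# it executes the upstream plan and returns a plain Python float, so the
# result can be embedded as a constant in the later filter expression.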

# Load tables with early projection.
customer = load_table("customer", args.sf).select_columns(
["c_custkey", "c_phone", "c_acctbal"]
)
orders = load_table("orders", args.sf).select_columns(["o_custkey"])

# Q22 parameters
codes_regex = "^(13|31|23|29|30|18|17)$"

# Derive country code and cast acctbal to float64.
customer = customer.with_column("cntrycode", col("c_phone").str.slice(0, 2))
customer = customer.with_column("c_acctbal_f", to_f64(col("c_acctbal")))

# Filter to target country codes.
customer_filtered = customer.filter(
expr=col("cntrycode").str.match_regex(codes_regex)
)
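
# The anchored regex emulates the SQL IN list over two-character country
# codes: "^(13|31|23|29|30|18|17)$" matches exactly those seven values.
# An equivalent formulation (hypothetical, untested here) would OR
# together equality comparisons, e.g.
#     (col("cntrycode") == "13") | (col("cntrycode") == "31") | ...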

# Scalar AVG subquery: average balance among positive-balance
# customers in the target country codes.
avg_acctbal = customer_filtered.filter(expr=col("c_acctbal_f") > 0.0).mean(
"c_acctbal_f"
)

# Keep customers whose balance exceeds the average.
custsale = customer_filtered.filter(expr=col("c_acctbal_f") > avg_acctbal)

# NOT EXISTS: exclude customers who have placed orders.
custsale = custsale.join(
orders,
join_type="left_anti",
num_partitions=16,
on=("c_custkey",),
right_on=("o_custkey",),
)
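
# left_anti keeps exactly the customers with no matching orders row,
# which is the NOT EXISTS semantics: a customer with any order is
# dropped, a customer with none is retained.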

# Group by country code, aggregate count and total balance.
_ = (
custsale.groupby("cntrycode")
.aggregate(
Count(alias_name="numcust"),
Sum(on="c_acctbal_f", alias_name="totacctbal"),
)
.sort(key="cntrycode")
.materialize()
)

# Report arguments for the benchmark.
return vars(args)

run_tpch_benchmark("tpch_q22", benchmark_fn)


if __name__ == "__main__":
ray.init()
args = parse_tpch_args()
main(args)