Sampling
Sample rows from a DataFrame or Arrow Table using various sampling methods.
Sampling
Processing
This function samples a subset of rows from an input data structure (Pandas DataFrame, Polars DataFrame, or Arrow Table). It supports various sampling methods, including fixed row count sampling (first N, last N, random N) and ratio-based sampling (random ratio, stratified ratio), while allowing for column selection and optional sorting prior to sampling.
Inputs
- data
- The input DataFrame or Arrow Table containing the data to be sampled.
- num rows (optional)
- The exact number of rows desired for fixed-size sampling methods. If provided, this overrides the 'Number of Rows' option.
- columns (optional)
- A list of column names to include in the output dataset. If provided, this overrides the 'Columns to Select' option.
Input Types
| Input | Types |
|---|---|
| data | DataFrame, ArrowTable |
| num rows | Int |
| columns | List |
You can check the list of supported types here: Available Type Hints.
Outputs
- result
- The resulting data structure containing the sampled subset of rows, formatted according to the selected output format.
Output Types
| Output | Types |
|---|---|
| result | DataFrame, ArrowTable |
You can check the list of supported types here: Available Type Hints.
Options
The Sampling brick exposes the following configurable options:
- Apply Sampling
- If enabled, sampling rules are applied. If disabled, the entire input dataset is returned (subject only to column selection).
- Sampling Method
- Specifies the strategy for selecting rows. Choices include `first_records`, `last_records`, `random_fixed` (fixed count), `random_ratio` (percentage), `stratified_fixed`, and `stratified_ratio`.
- Number of Rows
- The target number of rows to retrieve for fixed-size sampling methods (`first_records`, `last_records`, `random_fixed`, `stratified_fixed`). This value is ignored if a ratio method is selected or if `num rows` is provided as an input.
- Sampling Ratio (%)
- The approximate percentage of rows to retrieve for ratio-based sampling methods.
- Stratify Column
- The name of the column used to ensure balanced representation across groups when using stratified sampling methods.
- Sort Columns
- Defines columns and directions (ASC/DESC) used to sort the data before sampling is applied (especially relevant for `first_records` and `last_records`). Sorting is not applied for `random_fixed`, `stratified_fixed`, or `stratified_ratio`.
- Columns to Select
- A list of column names to be included in the output. If left empty, all columns are selected.
- Output Format
- Specifies the desired format for the returned data structure: `pandas` (DataFrame), `polars` (DataFrame), or `arrow` (Arrow Table).
- Verbose
- Enables detailed logging of the sampling operations and resulting query execution.
import logging
import duckdb
import pandas as pd
import polars as pl
import pyarrow as pa
from coded_flows.types import Union, DataFrame, ArrowTable, Int, List
# NOTE(review): basicConfig() configures the *root* logger at import time (when
# it has no handlers yet), which affects logging for the whole host process;
# consider leaving logging configuration to the application instead.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by `sampling` for its verbose progress/error messages.
logger = logging.getLogger(__name__)
def _coalesce(*values):
return next((v for v in values if v is not None))
def _sanitize_identifier(identifier):
"""
Sanitize SQL identifier by escaping special characters.
Handles double quotes and other problematic characters.
"""
return identifier.replace('"', '""')
# Methods whose randomness/stratification makes a pre-sort meaningless.
_UNSORTED_SAMPLING_METHODS = ("random_fixed", "stratified_fixed", "stratified_ratio")
# Methods that consume the `num_rows` setting.
_FIXED_SIZE_SAMPLING_METHODS = (
    "first_records",
    "last_records",
    "random_fixed",
    "stratified_fixed",
)
# Whitelist of ORDER BY direction suffixes accepted from `sort_columns`.
_SORT_DIRECTIONS = frozenset(
    f"{direction}{nulls}"
    for direction in ("ASC", "DESC")
    for nulls in ("", " NULLS FIRST", " NULLS LAST")
)


def _detect_data_type(data):
    """Return "pandas", "polars" or "arrow" for a supported input, else None."""
    if isinstance(data, pd.DataFrame):
        return "pandas"
    if isinstance(data, pl.DataFrame):
        return "polars"
    if isinstance(data, pa.Table):
        return "arrow"
    return None


def _normalize_sort_direction(direction):
    """Return the upper-cased, whitespace-normalized direction, or None if not allowed."""
    normalized = " ".join(str(direction).split()).upper()
    return normalized if normalized in _SORT_DIRECTIONS else None


def _stratified_query(select_clause, stratify_column, quota_expression):
    """Build a stratified-sampling query over ``input_table``.

    Rows are shuffled within each stratum (a window partition on the stratify
    column, so NULL values form their own stratum rather than being dropped)
    and a row is kept when its random rank does not exceed the stratum quota.
    ``quota_expression`` is a SQL expression over ``__sample_stratum_cnt``
    giving that quota before rounding.
    """
    stratum = f'"{_sanitize_identifier(stratify_column)}"'
    return (
        "WITH ranked_data AS (\n"
        "    SELECT *,\n"
        f"        ROW_NUMBER() OVER (PARTITION BY {stratum} ORDER BY RANDOM()) AS __sample_rn,\n"
        f"        COUNT(*) OVER (PARTITION BY {stratum}) AS __sample_stratum_cnt\n"
        "    FROM input_table\n"
        ")\n"
        f"SELECT {select_clause}\n"
        "FROM ranked_data\n"
        f"WHERE __sample_rn <= CAST(ROUND({quota_expression}) AS INTEGER)"
    )


def sampling(
    data: Union[DataFrame, ArrowTable],
    num_rows: Int = None,
    columns: List = None,
    options=None,
) -> Union[DataFrame, ArrowTable]:
    """Sample rows (and optionally select columns) from tabular data via DuckDB.

    Args:
        data: pandas DataFrame, Polars DataFrame or Arrow Table to sample.
        num_rows: Row count for fixed-size methods; overrides
            ``options["num_rows"]`` (default 10000) when not None.
        columns: Column names to keep; overrides ``options["columns"]`` when
            not None. An empty list keeps every column.
        options: Optional dict with keys ``verbose``, ``apply_limit``,
            ``sampling_method``, ``num_rows``, ``ratio`` (percent),
            ``stratify_column``, ``sort_columns`` (list of
            ``{"key": column, "value": "ASC" | "DESC"}``), ``columns`` and
            ``output_format`` (``"pandas"``, ``"polars"`` or ``"arrow"``).
            Keys explicitly set to None fall back to their defaults.

    Returns:
        The sampled rows in the requested output format.

    Raises:
        ValueError: On invalid columns, input type, row count, sort
            specification, stratify column, sampling method or output format.
    """
    brick_display_name = "Sample Rows"
    options = options or {}
    verbose = options.get("verbose", True)
    apply_limit = options.get("apply_limit", True)
    sampling_method = options.get("sampling_method", "first_records")
    # Explicit None values in `options` fall back to the defaults.
    num_rows = _coalesce(num_rows, options.get("num_rows"), 10000)
    ratio = _coalesce(options.get("ratio"), 10)
    stratify_column = options.get("stratify_column") or ""
    sort_columns = options.get("sort_columns") or []
    columns = _coalesce(columns, options.get("columns"), [])
    output_format = options.get("output_format", "pandas")
    result = None

    def fail(log_message, error_message=None):
        """Log the problem (when verbose) and raise ValueError."""
        verbose and logger.error(f"[{brick_display_name}] {log_message}")
        raise ValueError(error_message or log_message)

    # Must be a list/tuple AND contain only strings (the original `and` let
    # lists of non-strings through and crashed on non-iterables).
    if not isinstance(columns, (list, tuple)) or not all(
        isinstance(c, str) for c in columns
    ):
        fail(
            "Invalid columns format! Expected a list.",
            "Columns must be provided as a list!",
        )
    columns = list(columns)

    conn = None
    try:
        verbose and logger.info(f"[{brick_display_name}] Starting sampling operation.")
        data_type = _detect_data_type(data)
        if data_type is None:
            fail(
                "Input data must be a pandas DataFrame, Polars DataFrame, or Arrow Table.",
                "Input data must be a pandas DataFrame, Polars DataFrame, or Arrow Table",
            )
        verbose and logger.info(
            f"[{brick_display_name}] Detected input format: {data_type}."
        )

        conn = duckdb.connect(":memory:")
        conn.register("input_table", data)
        all_columns = [row[0] for row in conn.execute("DESCRIBE input_table").fetchall()]
        known_columns = set(all_columns)
        total_rows = conn.execute("SELECT COUNT(*) FROM input_table").fetchone()[0]
        verbose and logger.info(
            f"[{brick_display_name}] Total rows in input data: {total_rows}."
        )

        if columns:
            missing_columns = [col for col in columns if col not in known_columns]
            if missing_columns:
                fail(f"Columns not found in data: {missing_columns}")
            selected_columns = columns
            verbose and logger.info(
                f"[{brick_display_name}] Selecting specific columns: {columns}."
            )
        else:
            selected_columns = all_columns
            verbose and logger.info(
                f"[{brick_display_name}] Selecting all columns ({len(all_columns)} columns)."
            )
        select_clause = ", ".join(
            f'"{_sanitize_identifier(col)}"' for col in selected_columns
        )

        # The ORDER BY goes into the same SELECT as LIMIT/OFFSET so the data is
        # sorted *before* rows are picked; it may reference columns that are
        # not part of the selection.
        order_by_sql = ""
        if sort_columns and sampling_method not in _UNSORTED_SAMPLING_METHODS:
            missing_sort_columns = [
                item["key"] for item in sort_columns if item["key"] not in known_columns
            ]
            if missing_sort_columns:
                fail(f"Sort columns not found in data: {missing_sort_columns}")
            order_by_parts = []
            for item in sort_columns:
                # Whitelisted because the direction is spliced into the SQL text.
                direction = _normalize_sort_direction(item["value"])
                if direction is None:
                    fail(
                        f"Invalid sort direction for column '{item['key']}': "
                        f"{item['value']!r}. Expected ASC or DESC."
                    )
                order_by_parts.append(
                    f'"{_sanitize_identifier(item["key"])}" {direction}'
                )
            order_by_sql = " ORDER BY " + ", ".join(order_by_parts)
            sorting_columns_display = [
                f"{item['key']} {_normalize_sort_direction(item['value'])}"
                for item in sort_columns
            ]
            verbose and logger.info(
                f"[{brick_display_name}] Sorting by columns: {', '.join(sorting_columns_display)}."
            )

        if apply_limit and sampling_method in _FIXED_SIZE_SAMPLING_METHODS:
            # The count is interpolated into SQL, so it must be a plain int >= 0.
            try:
                num_rows = int(num_rows)
            except (TypeError, ValueError):
                fail(f"Number of rows must be an integer, got {num_rows!r}.")
            if num_rows < 0:
                fail(f"Number of rows must be non-negative, got {num_rows}.")

        if not apply_limit:
            verbose and logger.info(
                f"[{brick_display_name}] No sampling applied. Returning all {total_rows} rows."
            )
            query = f"SELECT {select_clause} FROM input_table{order_by_sql}"
        else:
            verbose and logger.info(
                f"[{brick_display_name}] Applying sampling method: {sampling_method}."
            )
            if sampling_method == "first_records":
                actual_num_rows = min(num_rows, total_rows)
                verbose and logger.info(
                    f"[{brick_display_name}] Taking first {actual_num_rows} records."
                )
                query = (
                    f"SELECT {select_clause} FROM input_table{order_by_sql} "
                    f"LIMIT {actual_num_rows}"
                )
            elif sampling_method == "last_records":
                actual_num_rows = min(num_rows, total_rows)
                verbose and logger.info(
                    f"[{brick_display_name}] Taking last {actual_num_rows} records."
                )
                offset = total_rows - actual_num_rows
                query = (
                    f"SELECT {select_clause} FROM input_table{order_by_sql} "
                    f"OFFSET {offset}"
                )
            elif sampling_method == "random_fixed":
                actual_num_rows = min(num_rows, total_rows)
                verbose and logger.info(
                    f"[{brick_display_name}] Random sampling {actual_num_rows} records."
                )
                query = (
                    f"SELECT {select_clause} FROM input_table "
                    f"ORDER BY RANDOM() LIMIT {actual_num_rows}"
                )
            elif sampling_method == "random_ratio":
                ratio_decimal = float(ratio) / 100.0
                verbose and logger.info(
                    f"[{brick_display_name}] Random sampling approximately {ratio}% of records."
                )
                # RANDOM() is in [0, 1): strict `<` keeps nothing at 0%.
                query = (
                    f"SELECT {select_clause} FROM input_table "
                    f"WHERE RANDOM() < {ratio_decimal}{order_by_sql}"
                )
            elif sampling_method in ("stratified_fixed", "stratified_ratio"):
                if not stratify_column:
                    fail(
                        "Stratified sampling requires a stratify column.",
                        "Stratified sampling requires a stratify column",
                    )
                if stratify_column not in known_columns:
                    fail(
                        f"Stratify column '{stratify_column}' not found in data.",
                        f"Stratify column '{stratify_column}' not found in data",
                    )
                if sampling_method == "stratified_fixed":
                    verbose and logger.info(
                        f"[{brick_display_name}] Stratified sampling {num_rows} records by column '{stratify_column}'."
                    )
                    # Proportional allocation: stratum share of the table times
                    # num_rows. No per-stratum minimum, so small strata may get 0.
                    quota = (
                        f"CAST(__sample_stratum_cnt AS DOUBLE) * {num_rows} "
                        f"/ {max(total_rows, 1)}"
                    )
                else:
                    ratio_decimal = float(ratio) / 100.0
                    verbose and logger.info(
                        f"[{brick_display_name}] Stratified sampling approximately {ratio}% of records by column '{stratify_column}'."
                    )
                    # No per-stratum minimum, so small strata may get 0 rows.
                    quota = f"__sample_stratum_cnt * {ratio_decimal}"
                query = _stratified_query(select_clause, stratify_column, quota)
            else:
                fail(f"Unknown sampling method: {sampling_method}")

        verbose and logger.info(f"[{brick_display_name}] Executing query to sample rows.")
        if output_format == "pandas":
            result = conn.execute(query).df()
            verbose and logger.info(
                f"[{brick_display_name}] Converted result to pandas DataFrame."
            )
        elif output_format == "polars":
            result = conn.execute(query).pl()
            verbose and logger.info(
                f"[{brick_display_name}] Converted result to Polars DataFrame."
            )
        elif output_format == "arrow":
            result = conn.execute(query).fetch_arrow_table()
            verbose and logger.info(
                f"[{brick_display_name}] Converted result to Arrow Table."
            )
        else:
            fail(f"Unsupported output format: {output_format}")

        verbose and logger.info(
            f"[{brick_display_name}] Sampling operation completed successfully. Returned {len(result)} rows with {len(selected_columns)} columns."
        )
    except Exception as exc:
        verbose and logger.error(
            f"[{brick_display_name}] Error during sampling operation: {exc}"
        )
        raise
    finally:
        # Always release the in-memory DuckDB connection, including on errors
        # raised by DuckDB itself (the original leaked it on those paths).
        if conn is not None:
            conn.close()
    return result
Brick Info
- pandas
- polars[pyarrow]
- duckdb
- pyarrow