Sampling

Sample rows from a DataFrame or Arrow Table using various sampling methods.

Sampling

Processing

This function samples a subset of rows from an input data structure (Pandas DataFrame, Polars DataFrame, or Arrow Table). It supports fixed row count sampling (first N, last N, random N, stratified N) and ratio-based sampling (random ratio, stratified ratio). It also supports column selection and optional sorting prior to sampling.

Inputs

data
The input DataFrame or Arrow Table containing the data to be sampled.
num rows (optional)
The number of rows to return for fixed-size sampling methods (capped at the number of rows in the input). If provided, this overrides the 'Number of Rows' option.
columns (optional)
A list of column names to include in the output dataset. If provided, this overrides the 'Columns to Select' option.

Input Types

Input Types
data DataFrame, ArrowTable
num rows Int
columns List

You can check the list of supported types here: Available Type Hints.

Outputs

result
The resulting data structure containing the sampled subset of rows, formatted according to the selected output format.

Output Types

Output Types
result DataFrame, ArrowTable

You can check the list of supported types here: Available Type Hints.

Options

The Sampling brick exposes the following configurable options:

Apply Sampling
If enabled, the selected sampling method is applied. If disabled, the entire input dataset is returned, subject only to column selection and, where configured, sorting.
Sampling Method
Specifies the strategy for selecting rows. Choices are first_records, last_records, random_fixed (fixed count), random_ratio (percentage), stratified_fixed (fixed total count, allocated to each group in proportion to its size), and stratified_ratio (percentage of each group).
Number of Rows
The target number of rows to retrieve for fixed-size sampling methods. This value is ignored if a ratio method is selected or if num rows is provided as an input.
Sampling Ratio (%)
The approximate percentage of rows (between 0 and 100) to retrieve for ratio-based sampling methods.
Stratify Column
The name of the column whose groups (strata) are sampled in proportion to their size when using stratified sampling methods. Required for stratified_fixed and stratified_ratio.
Sort Columns
Defines the columns and directions (ASC/DESC) used to sort the data before sampling is applied. This mainly matters for first_records and last_records. Sorting is ignored by the random_fixed, stratified_fixed, and stratified_ratio methods.
Columns to Select
A list of column names to be included in the output. If left empty, all columns are selected.
Output Format
Specifies the desired format for the returned data structure: pandas (DataFrame), polars (DataFrame), or arrow (Arrow Table).
Verbose
Enables detailed logging of the sampling operations and resulting query execution.
import logging
import duckdb
import pandas as pd
import polars as pl
import pyarrow as pa
from coded_flows.types import Union, DataFrame, ArrowTable, Int, List

# Configure root logging at import time with INFO level so the brick's
# `verbose` info messages are visible even if the host application never
# configured logging (per the stdlib logging docs, basicConfig does nothing
# when the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by every log call in this file.
logger = logging.getLogger(__name__)


def _coalesce(*values):
    return next((v for v in values if v is not None))


def _sanitize_identifier(identifier):
    """
    Sanitize SQL identifier by escaping special characters.
    Handles double quotes and other problematic characters.
    """
    return identifier.replace('"', '""')


def sampling(
    data: Union[DataFrame, ArrowTable],
    num_rows: Int = None,
    columns: List = None,
    options=None,
) -> Union[DataFrame, ArrowTable]:
    """
    Sample rows from a pandas DataFrame, Polars DataFrame or Arrow Table.

    The input is registered in a private in-memory DuckDB connection, and a
    single SQL query implementing the configured column selection, sorting
    and sampling method is executed against it.

    Args:
        data: Input pandas DataFrame, Polars DataFrame or Arrow Table.
        num_rows: Row count for the fixed-size methods (``first_records``,
            ``last_records``, ``random_fixed``, ``stratified_fixed``). Takes
            precedence over ``options["num_rows"]`` (default 10000). The
            count is capped at the number of input rows.
        columns: Column names to keep. Takes precedence over
            ``options["columns"]``. An empty list keeps every column.
        options: Optional dict. Recognised keys are ``verbose`` (default
            True), ``apply_limit`` (default True), ``sampling_method``
            (default ``"first_records"``), ``num_rows``, ``ratio``
            (percentage 0-100, default 10), ``stratify_column``,
            ``sort_columns`` (list of ``{"key": column, "value": "ASC" |
            "DESC"}``), ``columns`` and ``output_format`` (``"pandas"``,
            ``"polars"`` or ``"arrow"``, default ``"pandas"``).

    Returns:
        The selected rows in the requested output format.

    Raises:
        ValueError: For a malformed column list, unknown columns, an
            unsupported input type, an invalid row count, ratio or sort
            direction, a missing stratify column, an unknown sampling method
            or an unsupported output format.
    """
    brick_display_name = "Sample Rows"
    fixed_methods = (
        "first_records",
        "last_records",
        "random_fixed",
        "stratified_fixed",
    )
    ratio_methods = ("random_ratio", "stratified_ratio")
    stratified_methods = ("stratified_fixed", "stratified_ratio")
    # Methods for which a user-defined sort is honoured when sampling is on;
    # random_fixed and the stratified methods ignore sorting.
    sortable_methods = ("first_records", "last_records", "random_ratio")

    options = options or {}
    verbose = options.get("verbose", True)
    apply_limit = options.get("apply_limit", True)
    sampling_method = options.get("sampling_method", "first_records")
    # Trailing literals keep the documented defaults even when an option is
    # explicitly set to None.
    num_rows = _coalesce(num_rows, options.get("num_rows"), 10000)
    ratio = _coalesce(options.get("ratio"), 10)
    stratify_column = options.get("stratify_column") or ""
    sort_columns = options.get("sort_columns") or []
    columns = _coalesce(columns, options.get("columns"), [])
    output_format = options.get("output_format", "pandas")
    result = None

    # Both conditions are independent failures: a non-list, or a list that
    # holds anything other than strings.
    if not isinstance(columns, list) or not all(isinstance(c, str) for c in columns):
        verbose and logger.error(
            f"[{brick_display_name}] Invalid columns format! Expected a list."
        )
        raise ValueError("Columns must be provided as a list!")

    conn = None
    try:
        verbose and logger.info(f"[{brick_display_name}] Starting sampling operation.")
        if isinstance(data, pd.DataFrame):
            data_type = "pandas"
        elif isinstance(data, pl.DataFrame):
            data_type = "polars"
        elif isinstance(data, pa.Table):
            data_type = "arrow"
        else:
            verbose and logger.error(
                f"[{brick_display_name}] Input data must be a pandas DataFrame, Polars DataFrame, or Arrow Table."
            )
            raise ValueError(
                "Input data must be a pandas DataFrame, Polars DataFrame, or Arrow Table"
            )
        verbose and logger.info(
            f"[{brick_display_name}] Detected input format: {data_type}."
        )

        conn = duckdb.connect(":memory:")
        conn.register("input_table", data)
        all_columns = [
            col[0] for col in conn.execute("DESCRIBE input_table").fetchall()
        ]
        total_rows = conn.execute("SELECT COUNT(*) FROM input_table").fetchone()[0]
        verbose and logger.info(
            f"[{brick_display_name}] Total rows in input data: {total_rows}."
        )

        if columns:
            missing_columns = [col for col in columns if col not in all_columns]
            if missing_columns:
                verbose and logger.error(
                    f"[{brick_display_name}] Columns not found in data: {missing_columns}"
                )
                raise ValueError(f"Columns not found in data: {missing_columns}")
            selected_columns = columns
            verbose and logger.info(
                f"[{brick_display_name}] Selecting specific columns: {columns}."
            )
        else:
            selected_columns = all_columns
            verbose and logger.info(
                f"[{brick_display_name}] Selecting all columns ({len(all_columns)} columns)."
            )
        select_clause = ", ".join(
            f'"{_sanitize_identifier(col)}"' for col in selected_columns
        )

        # The ORDER BY is embedded in the sampling query itself (not wrapped
        # around it), so LIMIT/OFFSET operate on the sorted data. Sort columns
        # therefore do not need to be part of the selected columns.
        order_by_sql = ""
        if sort_columns and (not apply_limit or sampling_method in sortable_methods):
            sort_col_names = [item["key"] for item in sort_columns]
            missing_sort_columns = [
                col for col in sort_col_names if col not in all_columns
            ]
            if missing_sort_columns:
                verbose and logger.error(
                    f"[{brick_display_name}] Sort columns not found in data: {missing_sort_columns}"
                )
                raise ValueError(
                    f"Sort columns not found in data: {missing_sort_columns}"
                )
            order_by_parts = []
            for item in sort_columns:
                sort_key = item["key"]
                # The direction is interpolated into SQL, so only the two
                # keywords are accepted.
                direction = str(item["value"]).strip().upper()
                if direction not in ("ASC", "DESC"):
                    verbose and logger.error(
                        f"[{brick_display_name}] Invalid sort direction '{item['value']}' for column '{sort_key}'."
                    )
                    raise ValueError(
                        f"Invalid sort direction '{item['value']}' for column '{sort_key}'. Expected ASC or DESC."
                    )
                order_by_parts.append(f'"{_sanitize_identifier(sort_key)}" {direction}')
            order_by_sql = " ORDER BY " + ", ".join(order_by_parts)
            sorting_columns_display = [
                f"{item['key']} {item['value']}" for item in sort_columns
            ]
            verbose and logger.info(
                f"[{brick_display_name}] Sorting by columns: {', '.join(sorting_columns_display)}."
            )

        if not apply_limit:
            verbose and logger.info(
                f"[{brick_display_name}] No sampling applied. Returning all {total_rows} rows."
            )
            query = f"SELECT {select_clause} FROM input_table{order_by_sql}"
        else:
            verbose and logger.info(
                f"[{brick_display_name}] Applying sampling method: {sampling_method}."
            )

            # Validate only the parameters the chosen method actually uses.
            if sampling_method in fixed_methods:
                try:
                    row_count = int(num_rows)
                except (TypeError, ValueError, OverflowError):
                    row_count = -1
                if row_count < 0:
                    verbose and logger.error(
                        f"[{brick_display_name}] Invalid number of rows: {num_rows!r}."
                    )
                    raise ValueError(
                        f"Number of rows must be a non-negative integer, got {num_rows!r}"
                    )
                actual_num_rows = min(row_count, total_rows)
            if sampling_method in ratio_methods:
                try:
                    ratio_percent = float(ratio)
                except (TypeError, ValueError):
                    ratio_percent = float("nan")
                # NaN fails both comparisons, so it is rejected here as well.
                if not 0 <= ratio_percent <= 100:
                    verbose and logger.error(
                        f"[{brick_display_name}] Invalid sampling ratio: {ratio!r}."
                    )
                    raise ValueError(
                        f"Sampling ratio must be between 0 and 100, got {ratio!r}"
                    )
                ratio_decimal = ratio_percent / 100.0
            if sampling_method in stratified_methods:
                if not stratify_column:
                    verbose and logger.error(
                        f"[{brick_display_name}] Stratified sampling requires a stratify column."
                    )
                    raise ValueError("Stratified sampling requires a stratify column")
                if stratify_column not in all_columns:
                    verbose and logger.error(
                        f"[{brick_display_name}] Stratify column '{stratify_column}' not found in data."
                    )
                    raise ValueError(
                        f"Stratify column '{stratify_column}' not found in data"
                    )
                quoted_stratify = f'"{_sanitize_identifier(stratify_column)}"'

            if sampling_method == "first_records":
                verbose and logger.info(
                    f"[{brick_display_name}] Taking first {actual_num_rows} records."
                )
                query = (
                    f"SELECT {select_clause} FROM input_table{order_by_sql} "
                    f"LIMIT {actual_num_rows}"
                )
            elif sampling_method == "last_records":
                verbose and logger.info(
                    f"[{brick_display_name}] Taking last {actual_num_rows} records."
                )
                offset = total_rows - actual_num_rows
                query = (
                    f"SELECT {select_clause} FROM input_table{order_by_sql} "
                    f"LIMIT {actual_num_rows} OFFSET {offset}"
                )
            elif sampling_method == "random_fixed":
                verbose and logger.info(
                    f"[{brick_display_name}] Random sampling {actual_num_rows} records."
                )
                query = (
                    f"SELECT {select_clause} FROM input_table "
                    f"ORDER BY RANDOM() LIMIT {actual_num_rows}"
                )
            elif sampling_method == "random_ratio":
                verbose and logger.info(
                    f"[{brick_display_name}] Random sampling approximately {ratio}% of records."
                )
                # RANDOM() is in [0, 1): a strict '<' selects nothing at 0%
                # and everything at 100%.
                query = (
                    f"SELECT {select_clause} FROM input_table "
                    f"WHERE RANDOM() < {ratio_decimal}{order_by_sql}"
                )
            elif sampling_method == "stratified_fixed":
                verbose and logger.info(
                    f"[{brick_display_name}] Stratified sampling {actual_num_rows} records by column '{stratify_column}'."
                )
                qualified_select_clause = ", ".join(
                    f'r."{_sanitize_identifier(col)}"' for col in selected_columns
                )
                # Largest-remainder allocation. Each stratum first gets
                # floor(cnt * n / total) rows, and the rows still missing go to
                # the strata with the largest remainders. The quotas therefore
                # sum to exactly n and never exceed a stratum's size.
                # IS NOT DISTINCT FROM keeps the NULL stratum joinable.
                divisor = max(total_rows, 1)
                query = f"""
                WITH stratum_counts AS (
                    SELECT {quoted_stratify} AS stratum, COUNT(*) AS cnt
                    FROM input_table
                    GROUP BY {quoted_stratify}
                ),
                stratum_quotas AS (
                    SELECT stratum,
                           (cnt * {actual_num_rows}) // {divisor} AS base_cnt,
                           (cnt * {actual_num_rows}) % {divisor} AS quota_remainder
                    FROM stratum_counts
                ),
                stratum_ranked AS (
                    SELECT stratum,
                           base_cnt,
                           ROW_NUMBER() OVER (ORDER BY quota_remainder DESC, RANDOM()) AS remainder_rank,
                           {actual_num_rows} - SUM(base_cnt) OVER () AS leftover
                    FROM stratum_quotas
                ),
                stratum_targets AS (
                    SELECT stratum,
                           base_cnt + CASE WHEN remainder_rank <= leftover THEN 1 ELSE 0 END AS target_cnt
                    FROM stratum_ranked
                ),
                ranked_data AS (
                    SELECT *,
                           ROW_NUMBER() OVER (PARTITION BY {quoted_stratify} ORDER BY RANDOM()) AS "__sampling_rn"
                    FROM input_table
                )
                SELECT {qualified_select_clause}
                FROM ranked_data r
                JOIN stratum_targets t
                  ON r.{quoted_stratify} IS NOT DISTINCT FROM t.stratum
                WHERE r."__sampling_rn" <= t.target_cnt
                """
            elif sampling_method == "stratified_ratio":
                verbose and logger.info(
                    f"[{brick_display_name}] Stratified sampling approximately {ratio}% of records by column '{stratify_column}'."
                )
                # Each stratum keeps ROUND(size * ratio) random rows. Small
                # strata may legitimately receive 0 rows. The helper column
                # names are prefixed to avoid clashing with user columns.
                query = f"""
                WITH ranked_data AS (
                    SELECT *,
                           ROW_NUMBER() OVER (PARTITION BY {quoted_stratify} ORDER BY RANDOM()) AS "__sampling_rn",
                           COUNT(*) OVER (PARTITION BY {quoted_stratify}) AS "__sampling_stratum_cnt"
                    FROM input_table
                )
                SELECT {select_clause}
                FROM ranked_data
                WHERE "__sampling_rn" <= CAST(ROUND("__sampling_stratum_cnt" * {ratio_decimal}) AS BIGINT)
                """
            else:
                verbose and logger.error(
                    f"[{brick_display_name}] Unknown sampling method: {sampling_method}"
                )
                raise ValueError(f"Unknown sampling method: {sampling_method}")

        verbose and logger.info(
            f"[{brick_display_name}] Executing query to sample rows."
        )
        if output_format == "pandas":
            result = conn.execute(query).df()
            verbose and logger.info(
                f"[{brick_display_name}] Converted result to pandas DataFrame."
            )
        elif output_format == "polars":
            result = conn.execute(query).pl()
            verbose and logger.info(
                f"[{brick_display_name}] Converted result to Polars DataFrame."
            )
        elif output_format == "arrow":
            result = conn.execute(query).fetch_arrow_table()
            verbose and logger.info(
                f"[{brick_display_name}] Converted result to Arrow Table."
            )
        else:
            verbose and logger.error(
                f"[{brick_display_name}] Unsupported output format: {output_format}"
            )
            raise ValueError(f"Unsupported output format: {output_format}")
        result_rows = len(result)
        verbose and logger.info(
            f"[{brick_display_name}] Sampling operation completed successfully. Returned {result_rows} rows with {len(selected_columns)} columns."
        )
    except Exception:
        verbose and logger.error(
            f"[{brick_display_name}] Error during sampling operation."
        )
        raise
    finally:
        # Release the in-memory DuckDB connection on every path, including
        # failures raised by DuckDB itself during query execution.
        if conn is not None:
            conn.close()
    return result

Brick Info

version v0.1.3
python 3.10, 3.11, 3.12, 3.13
requirements
  • pandas
  • polars[pyarrow]
  • duckdb
  • pyarrow