Top N

Filter a dataset to get the top and/or bottom N rows based on specified column(s) sorting.

Top N

Processing

Filter a dataset to get the top and/or bottom N rows based on specified column(s) sorting.

Inputs

data
The input dataset (Pandas DataFrame, Polars DataFrame, or Arrow Table) to be filtered.
sort columns (optional)
List of columns and their sort direction (ASC/DESC) used to rank the data. If not provided, the value from the options is used.
top n count (optional)
The maximum number of top rows to retrieve. If not provided, the value from the options is used.
bottom n count (optional)
The maximum number of bottom rows to retrieve. If not provided, the value from the options is used.
group by (optional)
Columns used to partition the data before applying the N filter (e.g., top N per group). If not provided, the value from the options is used.

Input Types

Input Types
data DataFrame, ArrowTable
sort columns List
top n count Int
bottom n count Int
group by List

You can check the list of supported types here: Available Type Hints.

Outputs

result
The filtered dataset containing the top N and/or bottom N rows, potentially grouped, in the format specified by the 'Output Format' option.

Output Types

Output Types
result DataFrame, ArrowTable

You can check the list of supported types here: Available Type Hints.

Options

The Top N brick contains some changeable options:

Sort Columns
Defines the column(s) and direction (ASC or DESC) used for ranking and filtering the data.
Top N Rows
The maximum number of rows to retrieve from the top (highest rank).
Bottom N Rows
The maximum number of rows to retrieve from the bottom (lowest rank).
Group By Columns
List of columns to partition the dataset by, ensuring the top/bottom N filter applies independently within each group.
Combine Top and Bottom
If both Top N and Bottom N results are requested, determines whether they should be merged into a single output dataset. When disabled while both are requested, only the top rows are returned.
Add Rank Column
If enabled, a new column containing the calculated row rank will be added to the output dataset. The rank for bottom rows is typically negative.
Rank Column Name
The name of the column used to store the rank if 'Add Rank Column' is enabled (default is rank).
Output Format
Specifies the desired format of the returned data structure (pandas, polars, or arrow).
Safe Mode
If enabled, non-existent sort or group columns are skipped rather than raising an execution error.
Verbose
Enables detailed logging of the brick's execution steps.
import logging
import duckdb
import pandas as pd
import polars as pl
import pyarrow as pa
from coded_flows.types import Union, List, DataFrame, ArrowTable, Int, Bool

# Module-level logger for this brick. basicConfig is a no-op if the host
# application has already configured the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _coalesce(*values):
    return next((v for v in values if v is not None), None)


def _sanitize_identifier(identifier):
    """
    Sanitize SQL identifier by escaping special characters.
    Handles double quotes and other problematic characters.
    """
    return identifier.replace('"', '""')


def top_n(
    data: Union[DataFrame, ArrowTable],
    sort_columns: List = None,
    top_n_count: Int = None,
    bottom_n_count: Int = None,
    group_by: List = None,
    options=None,
) -> Union[DataFrame, ArrowTable]:
    """
    Filter a dataset to its top and/or bottom N rows by a column sort order.

    Parameters
    ----------
    data : pandas DataFrame, Polars DataFrame, or Arrow Table
        The dataset to filter.
    sort_columns : list, optional
        Columns used to rank rows: plain names (default direction DESC) or
        dicts of the form {"key": column_name, "value": "ASC"/"DESC"}.
        Falls back to options["sort_columns"].
    top_n_count : int, optional
        Number of top-ranked rows to keep. Falls back to options["top_n"]
        (default 10).
    bottom_n_count : int, optional
        Number of bottom-ranked rows to keep. Falls back to
        options["bottom_n"] (default 0).
    group_by : list, optional
        Columns to partition by so the N filter applies independently within
        each group. Falls back to options["group_by"].
    options : dict, optional
        Additional settings: combine_top_bottom (default True),
        add_rank_column (default False), rank_column_name (default "rank"),
        output_format ("pandas"/"polars"/"arrow", default "pandas"),
        safe_mode (default False), verbose (default True).

    Returns
    -------
    The filtered dataset in the requested output format. Bottom rows, when a
    rank column is added, carry negative ranks.

    Raises
    ------
    ValueError
        On missing/invalid sort columns, negative counts, both counts zero,
        unsupported input type, missing columns (unless safe_mode), or an
        unsupported output format.
    """
    brick_display_name = "Top N"
    options = options or {}
    verbose = options.get("verbose", True)
    sort_columns = _coalesce(sort_columns, options.get("sort_columns", []))
    top_n_count = _coalesce(top_n_count, options.get("top_n", 10))
    bottom_n_count = _coalesce(bottom_n_count, options.get("bottom_n", 0))
    group_by = _coalesce(group_by, options.get("group_by", []))
    combine_top_bottom = options.get("combine_top_bottom", True)
    add_rank_column = options.get("add_rank_column", False)
    rank_column_name = options.get("rank_column_name", "rank")
    output_format = options.get("output_format", "pandas")
    safe_mode = options.get("safe_mode", False)
    result = None
    conn = None
    try:
        # BUG FIX: the original guard used `and`, so an empty list (or any
        # falsy non-list value) slipped past and only failed later with a
        # less specific error. Require a non-empty list up front.
        if not sort_columns or not isinstance(sort_columns, list):
            verbose and logger.error(
                f"[{brick_display_name}] No sort columns specified!"
            )
            raise ValueError("At least one sort column must be specified!")
        if top_n_count < 0 or bottom_n_count < 0:
            verbose and logger.error(
                f"[{brick_display_name}] Top N and Bottom N values must be non-negative!"
            )
            raise ValueError("Top N and Bottom N values must be non-negative!")
        if top_n_count == 0 and bottom_n_count == 0:
            verbose and logger.error(
                f"[{brick_display_name}] At least one of Top N or Bottom N must be greater than 0!"
            )
            raise ValueError(
                "At least one of Top N or Bottom N must be greater than 0!"
            )
        verbose and logger.info(
            f"[{brick_display_name}] Starting Top N filtering. Top: {top_n_count}, Bottom: {bottom_n_count}."
        )
        # Detect the input container so the format can be reported; duckdb
        # registers all three natively.
        data_type = None
        if isinstance(data, pd.DataFrame):
            data_type = "pandas"
        elif isinstance(data, pl.DataFrame):
            data_type = "polars"
        elif isinstance(data, (pa.Table, pa.lib.Table)):
            data_type = "arrow"
        if data_type is None:
            verbose and logger.error(
                f"[{brick_display_name}] Input data must be a pandas DataFrame, Polars DataFrame, or Arrow Table"
            )
            raise ValueError(
                "Input data must be a pandas DataFrame, Polars DataFrame, or Arrow Table"
            )
        verbose and logger.info(
            f"[{brick_display_name}] Detected input format: {data_type}."
        )
        conn = duckdb.connect(":memory:")
        conn.register("input_table", data)
        column_info = conn.execute("DESCRIBE input_table").fetchall()
        all_columns = {col[0]: col[1] for col in column_info}
        verbose and logger.info(
            f"[{brick_display_name}] Total columns in data: {len(all_columns)}."
        )
        # Build the forward (top) and reversed (bottom) ORDER BY clauses per
        # column. BUG FIX: the original derived the reversed clause with a
        # chained str.replace('DESC','TEMP')... over the whole clause, which
        # corrupted quoted column names containing "ASC"/"DESC" substrings
        # (e.g. "CASCADE_ID") and never reversed lowercase directions.
        _opposite = {"ASC": "DESC", "DESC": "ASC"}
        order_by_parts = []
        reversed_order_parts = []
        for sort_spec in sort_columns:
            if isinstance(sort_spec, dict):
                col_name = sort_spec.get("key", "")
                sort_order = str(sort_spec.get("value", "DESC")).upper()
            else:
                col_name = sort_spec
                sort_order = "DESC"
            if not col_name:
                continue
            # SECURITY: the direction keyword is interpolated into SQL, so
            # only the two known keywords are allowed; anything else falls
            # back to the default direction instead of reaching the query.
            if sort_order not in _opposite:
                verbose and logger.warning(
                    f"[{brick_display_name}] Unknown sort direction '{sort_order}' for column '{col_name}'. Defaulting to DESC."
                )
                sort_order = "DESC"
            if col_name not in all_columns:
                if safe_mode:
                    verbose and logger.warning(
                        f"[{brick_display_name}] Safe mode: Skipping non-existent column '{col_name}'."
                    )
                    continue
                else:
                    verbose and logger.error(
                        f"[{brick_display_name}] Column '{col_name}' not found in data!"
                    )
                    raise ValueError(f"Column '{col_name}' not found in data!")
            quoted_col = f'"{_sanitize_identifier(col_name)}"'
            order_by_parts.append(f"{quoted_col} {sort_order}")
            reversed_order_parts.append(f"{quoted_col} {_opposite[sort_order]}")
        if not order_by_parts:
            verbose and logger.error(
                f"[{brick_display_name}] No valid sort columns found!"
            )
            raise ValueError("No valid sort columns found!")
        order_by_clause = ", ".join(order_by_parts)
        reversed_order_clause = ", ".join(reversed_order_parts)
        verbose and logger.info(f"[{brick_display_name}] Sorting by: {order_by_clause}")
        partition_clause = ""
        group_by_parts_for_order = []
        if group_by and isinstance(group_by, list) and (len(group_by) > 0):
            valid_group_cols = []
            for col in group_by:
                if col not in all_columns:
                    if safe_mode:
                        verbose and logger.warning(
                            f"[{brick_display_name}] Safe mode: Skipping non-existent group by column '{col}'."
                        )
                        continue
                    else:
                        verbose and logger.error(
                            f"[{brick_display_name}] Group by column '{col}' not found in data!"
                        )
                        raise ValueError(f"Group by column '{col}' not found in data!")
                valid_group_cols.append(_sanitize_identifier(col))
            if valid_group_cols:
                group_by_parts = [f'"{col}"' for col in valid_group_cols]
                group_by_parts_for_order = group_by_parts.copy()
                partition_clause = f"PARTITION BY {', '.join(group_by_parts)}"
                verbose and logger.info(
                    f"[{brick_display_name}] Grouping by columns: {group_by}"
                )
        queries = []
        # One CTE computes both ranks: row_rank_top counts from the best row
        # under the requested order, row_rank_bottom counts from the worst
        # (same columns with each direction flipped).
        rank_cte = f"""
        WITH ranked_data AS (
            SELECT *,
                ROW_NUMBER() OVER ({partition_clause} ORDER BY {order_by_clause}) as row_rank_top,
                ROW_NUMBER() OVER ({partition_clause} ORDER BY {reversed_order_clause}) as row_rank_bottom
            FROM input_table
        )
        """
        if top_n_count > 0:
            if add_rank_column:
                sanitized_rank_col = _sanitize_identifier(rank_column_name)
                top_query = f'\n                {rank_cte}\n                SELECT *, row_rank_top AS "{sanitized_rank_col}"\n                FROM ranked_data\n                WHERE row_rank_top <= {top_n_count}\n                '
            else:
                top_query = f"\n                {rank_cte}\n                SELECT * EXCLUDE (row_rank_top, row_rank_bottom)\n                FROM ranked_data\n                WHERE row_rank_top <= {top_n_count}\n                "
            queries.append(top_query)
            verbose and logger.info(
                f"[{brick_display_name}] Getting top {top_n_count} rows"
                + (f" per group" if partition_clause else "")
            )
        if bottom_n_count > 0:
            if add_rank_column:
                sanitized_rank_col = _sanitize_identifier(rank_column_name)
                # Bottom ranks are negated so they sort before/below top
                # ranks and are distinguishable in the combined output.
                bottom_query = f'\n                {rank_cte}\n                SELECT *, -row_rank_bottom AS "{sanitized_rank_col}"\n                FROM ranked_data\n                WHERE row_rank_bottom <= {bottom_n_count}\n                '
            else:
                bottom_query = f"\n                {rank_cte}\n                SELECT * EXCLUDE (row_rank_top, row_rank_bottom)\n                FROM ranked_data\n                WHERE row_rank_bottom <= {bottom_n_count}\n                "
            queries.append(bottom_query)
            verbose and logger.info(
                f"[{brick_display_name}] Getting bottom {bottom_n_count} rows"
                + (f" per group" if partition_clause else "")
            )
        if len(queries) == 2 and combine_top_bottom:
            combined_query = f"({queries[0]}) UNION ALL ({queries[1]})"
            verbose and logger.info(
                f"[{brick_display_name}] Combining top and bottom results."
            )
        else:
            # Single query, or both requested with combining disabled — in
            # the latter case only the top rows are returned.
            combined_query = queries[0]
            if len(queries) == 2 and (not combine_top_bottom):
                verbose and logger.warning(
                    f"[{brick_display_name}] Both top and bottom requested but combine is False. Returning only top rows."
                )
        # Re-sort the final result: group columns first (keeps groups
        # contiguous), then by rank column if present, otherwise by the
        # original sort order.
        final_order_parts = []
        if group_by_parts_for_order:
            final_order_parts.extend(group_by_parts_for_order)
        if add_rank_column:
            sanitized_rank_col = _sanitize_identifier(rank_column_name)
            final_order_parts.append(f'"{sanitized_rank_col}"')
        else:
            final_order_parts.extend(order_by_parts)
        final_order_by = ", ".join(final_order_parts)
        final_query = (
            f"SELECT * FROM ({combined_query}) final_result ORDER BY {final_order_by}"
        )
        verbose and logger.info(
            f"[{brick_display_name}] Executing query to filter data with maintained sort order."
        )
        if output_format == "pandas":
            result = conn.execute(final_query).df()
            verbose and logger.info(
                f"[{brick_display_name}] Converted result to pandas DataFrame. Shape: {result.shape}"
            )
        elif output_format == "polars":
            result = conn.execute(final_query).pl()
            verbose and logger.info(
                f"[{brick_display_name}] Converted result to Polars DataFrame. Shape: {result.shape}"
            )
        elif output_format == "arrow":
            result = conn.execute(final_query).fetch_arrow_table()
            verbose and logger.info(
                f"[{brick_display_name}] Converted result to Arrow Table. Rows: {result.num_rows}"
            )
        else:
            verbose and logger.error(
                f"[{brick_display_name}] Unsupported output format: {output_format}"
            )
            raise ValueError(f"Unsupported output format: {output_format}")
        verbose and logger.info(
            f"[{brick_display_name}] Top N filtering completed successfully."
        )
    except Exception as e:
        verbose and logger.error(
            f"[{brick_display_name}] Error during Top N filtering: {str(e)}"
        )
        raise
    finally:
        # Always release the in-memory duckdb connection.
        if conn is not None:
            conn.close()
    return result

Brick Info

version v0.1.3
python 3.10, 3.11, 3.12, 3.13
requirements
  • pandas
  • polars[pyarrow]
  • duckdb
  • pyarrow