# Source code for fabrictools.quality.clean

"""Pure DataFrame cleaning helpers."""

from __future__ import annotations

import re
import unicodedata
from typing import List, Optional

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import (
    DateType,
    DoubleType,
    IntegerType,
    StringType,
)

from fabrictools.core import log
from fabrictools.io import resolve_lakehouse_read_candidate


def _to_snake_case(name: str) -> str:
    normalized = unicodedata.normalize("NFKD", name.strip())
    cleaned = "".join(ch for ch in normalized if not unicodedata.combining(ch))
    cleaned = re.sub(r"[^0-9A-Za-z]+", "_", cleaned)
    cleaned = re.sub(r"_+", "_", cleaned).strip("_").lower()
    if not cleaned:
        return "col"
    if cleaned[0].isdigit():
        return f"col_{cleaned}"
    return cleaned


def _build_unique_column_names(columns: List[str]) -> List[str]:
    """Snake_case every column name, deduplicating repeats in input order.

    The first occurrence of each normalized name is kept as-is; later
    occurrences receive a ``_2``, ``_3``, ... suffix.
    """
    occurrences: dict[str, int] = {}
    unique_names: List[str] = []
    for original in columns:
        normalized = _to_snake_case(original)
        nth = occurrences.get(normalized, 0) + 1
        occurrences[normalized] = nth
        unique_names.append(normalized if nth == 1 else f"{normalized}_{nth}")
    return unique_names


def _normalized_name_collisions(columns: List[str]) -> dict[str, List[str]]:
    """Map each snake_case name to the original names that collide on it.

    Only normalized names claimed by two or more original columns are
    included in the result.
    """
    by_normalized: dict[str, List[str]] = {}
    for original in columns:
        key = _to_snake_case(original)
        if key in by_normalized:
            by_normalized[key].append(original)
        else:
            by_normalized[key] = [original]
    return {key: names for key, names in by_normalized.items() if len(names) > 1}


def _replace_empty_strings_with_nulls(df: DataFrame) -> DataFrame:
    """Trim every string column and convert blank values to null.

    Non-string columns pass through untouched; the original column order
    is preserved.
    """
    string_names = {
        field.name
        for field in df.schema.fields
        if isinstance(field.dataType, StringType)
    }
    if not string_names:
        return df

    def _projection(col_name: str):
        if col_name not in string_names:
            return F.col(col_name)
        trimmed = F.trim(F.col(col_name))
        # Blank-after-trim becomes null; otherwise keep the trimmed value.
        return F.when(trimmed == "", F.lit(None)).otherwise(trimmed).alias(col_name)

    return df.select(*[_projection(name) for name in df.columns])


# Date-only shape (no time suffix). Used for diagnostics in mismatch logs, not for casting rules.
# Allows 1–2 digit month/day; ISO yyyy-first dash, European dd-MM-yyyy / US MM-dd-yyyy hyphen, slash, dot.
_DATE_ONLY_PATTERN = (
    r"^("
    r"\d{4}-\d{1,2}-\d{1,2}|"
    r"\d{1,2}-\d{1,2}-\d{4}|"
    r"\d{4}/\d{1,2}/\d{1,2}|"
    r"\d{1,2}/\d{1,2}/\d{4}|"
    r"\d{1,2}\.\d{1,2}\.\d{4}|"
    r"\d{4}\.\d{1,2}\.\d{1,2}"
    r")$"
)
# Loose prefix check (digits separated by -, / or .) used to pre-filter strings
# before feeding them into the expensive to_date/to_timestamp parse chains.
_DATE_CANDIDATE_PATTERN = r"^\d{1,4}[-/\.]\d{1,2}[-/\.]\d{1,4}"
# Whole-string signed integer (e.g. "42", "-7", "+3").
_INT_TEXT_PATTERN = r"^[+-]?\d+$"
# Whole-string signed decimal, optional fraction and scientific exponent.
_FLOAT_TEXT_PATTERN = r"^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$"
# Presumably caps how many sample values appear in date-mismatch diagnostics;
# not referenced in this chunk — TODO confirm usage elsewhere in the package.
_PARSED_DATE_SAMPLE_LIMIT = 5
# Spark session config key controlling the datetime parser policy.
_TIME_PARSER_POLICY_KEY = "spark.sql.legacy.timeParserPolicy"


def detect_and_cast_columns(df: DataFrame, verbose: bool = False) -> DataFrame:
    """Infer primitive types from string columns and cast when the column is uniform.

    Order of detection (first match wins): **date** (uniform non-null success of
    a ``to_date`` / ``to_timestamp`` chain over several patterns—European forms
    before US for ambiguous day/month; strings with a trailing time-of-day may
    still yield a calendar day and are cast to ``date``, dropping the time part;
    US slash dates with 12-hour clock and AM/PM suffix are handled via
    ``h:mm[:ss] a`` patterns), **timestamp** (``to_timestamp`` with several
    patterns including US 12h + AM/PM, 24h, plus ISO ``T``), **integer** (full
    string matches ``^[+-]?\\d+$``), **double** (decimal/scientific), else the
    column remains ``string``. Columns that are all-null are skipped; null cells
    are kept through casts.

    Sets ``spark.sql.legacy.timeParserPolicy`` to ``CORRECTED`` for the session
    and intentionally does **not** restore the previous value: the returned
    DataFrame is lazy and still requires the CORRECTED policy when it is
    eventually evaluated (e.g. on write/count).

    :param df: Input dataframe.
    :type df: ~pyspark.sql.DataFrame
    :param verbose: Accepted for signature consistency with sibling helpers;
        currently unused in this function.
    :type verbose: bool
    :returns: Dataframe with qualifying string columns cast.
    :rtype: ~pyspark.sql.DataFrame
    """
    spark = df.sparkSession
    # Use CORRECTED to let Spark handle ancient dates (before 1582/1900)
    # without throwing an error and to fix parsing errors where LEGACY fails
    # on some formats like 4/14/2026.
    # Note: We do not restore the policy because the returned DataFrame is lazy
    # and requires CORRECTED policy during evaluation (e.g. write/count).
    spark.conf.set(_TIME_PARSER_POLICY_KEY, "CORRECTED")

    string_columns = [
        field.name
        for field in df.schema.fields
        if isinstance(field.dataType, StringType)
    ]
    if not string_columns:
        return df

    def _get_parsed_date_expr(safe_trimmed):
        # European (day-first) patterns precede US (month-first) ones so
        # ambiguous values like 03-04-2024 resolve day/month-first.
        return F.coalesce(
            F.to_date(safe_trimmed, "yyyy-MM-dd"),
            F.to_date(safe_trimmed, "yyyy/M/d"),
            F.to_date(safe_trimmed, "dd-MM-yyyy"),
            F.to_date(safe_trimmed, "d-M-yyyy"),
            F.to_date(safe_trimmed, "MM-dd-yyyy"),
            F.to_date(safe_trimmed, "M-d-yyyy"),
            F.to_date(safe_trimmed, "dd/MM/yyyy"),
            F.to_date(safe_trimmed, "d/M/yyyy"),
            F.to_date(safe_trimmed, "dd.MM.yyyy"),
            F.to_date(safe_trimmed, "d.M.yyyy"),
            F.to_date(safe_trimmed, "MM/dd/yyyy"),
            F.to_date(safe_trimmed, "M/d/yyyy"),
            F.to_date(safe_trimmed, "MM.dd.yyyy"),
            F.to_date(safe_trimmed, "M.d.yyyy"),
            # US 12-hour timestamps still count as dates (time part dropped).
            F.to_timestamp(safe_trimmed, "M/d/yyyy h:mm:ss a").cast(DateType()),
            F.to_timestamp(safe_trimmed, "MM/dd/yyyy h:mm:ss a").cast(DateType()),
            F.to_timestamp(safe_trimmed, "M/d/yyyy h:mm a").cast(DateType()),
            F.to_timestamp(safe_trimmed, "MM/dd/yyyy h:mm a").cast(DateType()),
        )

    def _get_parsed_ts_expr(safe_trimmed):
        return F.coalesce(
            F.to_timestamp(safe_trimmed, "yyyy-MM-dd HH:mm:ss"),
            F.to_timestamp(safe_trimmed, "dd-MM-yyyy HH:mm:ss"),
            F.to_timestamp(safe_trimmed, "d-M-yyyy HH:mm:ss"),
            F.to_timestamp(safe_trimmed, "MM-dd-yyyy HH:mm:ss"),
            F.to_timestamp(safe_trimmed, "M-d-yyyy HH:mm:ss"),
            F.to_timestamp(safe_trimmed, "dd/MM/yyyy HH:mm:ss"),
            F.to_timestamp(safe_trimmed, "d/M/yyyy HH:mm:ss"),
            F.to_timestamp(safe_trimmed, "MM/dd/yyyy HH:mm:ss"),
            F.to_timestamp(safe_trimmed, "M/d/yyyy HH:mm:ss"),
            F.to_timestamp(safe_trimmed, "M/d/yyyy h:mm:ss a"),
            F.to_timestamp(safe_trimmed, "MM/dd/yyyy h:mm:ss a"),
            F.to_timestamp(safe_trimmed, "M/d/yyyy h:mm a"),
            F.to_timestamp(safe_trimmed, "MM/dd/yyyy h:mm a"),
            F.to_timestamp(safe_trimmed, "yyyy-MM-dd'T'HH:mm:ss"),
        )

    # Single aggregation pass: per string column, count non-null cells and how
    # many of them fail each candidate interpretation.
    agg_exprs = []
    for col_name in string_columns:
        col_expr = F.col(col_name)
        trimmed = F.trim(col_expr)
        agg_exprs.append(
            F.sum(F.when(col_expr.isNotNull(), 1).otherwise(0)).alias(f"{col_name}__nn")
        )
        agg_exprs.append(
            F.sum(
                F.when(col_expr.isNotNull() & ~trimmed.rlike(_INT_TEXT_PATTERN), 1).otherwise(0)
            ).alias(f"{col_name}__int_fail")
        )
        agg_exprs.append(
            F.sum(
                F.when(col_expr.isNotNull() & ~trimmed.rlike(_FLOAT_TEXT_PATTERN), 1).otherwise(0)
            ).alias(f"{col_name}__float_fail")
        )
        # Only strings that roughly look like dates reach the parser chains.
        safe_trimmed = F.when(trimmed.rlike(_DATE_CANDIDATE_PATTERN), trimmed).otherwise(F.lit(None))
        parsed_date = _get_parsed_date_expr(safe_trimmed)
        parsed_ts = _get_parsed_ts_expr(safe_trimmed)
        agg_exprs.append(
            F.sum(
                F.when(col_expr.isNotNull() & parsed_date.isNull(), 1).otherwise(0)
            ).alias(f"{col_name}__date_fail")
        )
        agg_exprs.append(
            F.sum(
                F.when(col_expr.isNotNull() & parsed_ts.isNull(), 1).otherwise(0)
            ).alias(f"{col_name}__ts_fail")
        )

    stats = df.agg(*agg_exprs).collect()[0].asDict() if agg_exprs else {}

    # Build one projection per column; first uniform interpretation wins.
    select_exprs = []
    for col_name in df.columns:
        if col_name not in string_columns:
            select_exprs.append(F.col(col_name))
            continue
        nn = stats.get(f"{col_name}__nn", 0)
        if nn == 0:
            # All-null column: leave it as string.
            select_exprs.append(F.col(col_name))
            continue
        date_fail = stats.get(f"{col_name}__date_fail", 0)
        ts_fail = stats.get(f"{col_name}__ts_fail", 0)
        int_fail = stats.get(f"{col_name}__int_fail", 0)
        float_fail = stats.get(f"{col_name}__float_fail", 0)
        col_expr = F.col(col_name)
        trimmed = F.trim(col_expr)
        safe_trimmed = F.when(trimmed.rlike(_DATE_CANDIDATE_PATTERN), trimmed).otherwise(F.lit(None))
        if date_fail == 0:
            parsed_date = _get_parsed_date_expr(safe_trimmed)
            select_exprs.append(
                F.when(col_expr.isNull(), None).otherwise(parsed_date).alias(col_name)
            )
        elif ts_fail == 0:
            parsed_ts = _get_parsed_ts_expr(safe_trimmed)
            select_exprs.append(
                F.when(col_expr.isNull(), None).otherwise(parsed_ts).alias(col_name)
            )
        elif int_fail == 0:
            select_exprs.append(
                F.when(col_expr.isNull(), None).otherwise(col_expr.cast(IntegerType())).alias(col_name)
            )
        elif float_fail == 0:
            select_exprs.append(
                F.when(col_expr.isNull(), None).otherwise(col_expr.cast(DoubleType())).alias(col_name)
            )
        else:
            select_exprs.append(col_expr)
    return df.select(*select_exprs)
def add_silver_metadata(
    df: DataFrame,
    source_lakehouse_name: str,
    source_relative_path: str,
    source_layer: str = "bronze",
    ingestion_timestamp_col: str = "ingestion_timestamp",
    source_layer_col: str = "ingestion_source_layer",
    source_path_col: str = "ingestion_source_path",
    year_col: str = "ingestion_year",
    month_col: str = "ingestion_month",
    day_col: str = "ingestion_day",
    spark: Optional[SparkSession] = None,
    verbose: bool = False,
) -> DataFrame:
    """Add Silver-layer metadata columns (ingestion time, source path, date parts).

    Resolves ``source_relative_path`` with
    :py:func:`fabrictools.io.lakehouse.resolve_lakehouse_read_candidate`.
    Date partition columns (``year_col`` / ``month_col`` / ``day_col``) are
    derived from the current ingestion date.

    :param df: Bronze or intermediate dataframe.
    :param source_lakehouse_name: Source Lakehouse display name.
    :param source_relative_path: Source path passed to path resolution.
    :param source_layer: Literal stored in ``source_layer_col`` (default ``bronze``).
    :param ingestion_timestamp_col: Column name for ``current_timestamp()``.
    :param source_layer_col: Column name for the layer literal.
    :param source_path_col: Column name for the resolved relative path string.
    :param year_col: Partition year column name.
    :param month_col: Partition month column name.
    :param day_col: Partition day-of-month column name.
    :param spark: Optional ``SparkSession`` for path resolution.
    :type df: ~pyspark.sql.DataFrame
    :type source_lakehouse_name: str
    :type source_relative_path: str
    :type source_layer: str
    :type ingestion_timestamp_col: str
    :type source_layer_col: str
    :type source_path_col: str
    :type year_col: str
    :type month_col: str
    :type day_col: str
    :type spark: ~pyspark.sql.SparkSession | None
    :returns: ``df`` with metadata and partition columns appended/overwritten.
    :rtype: ~pyspark.sql.DataFrame

    .. rubric:: Example

    >>> silver_df = add_silver_metadata(  # doctest: +SKIP
    ...     bronze_df,
    ...     source_lakehouse_name="BronzeLakehouse",
    ...     source_relative_path="dbo.RawOrders",
    ... )
    """
    resolved_source_path = resolve_lakehouse_read_candidate(
        lakehouse_name=source_lakehouse_name,
        relative_path=source_relative_path,
        spark=spark,
    )
    current_date_expr = F.current_date()
    metadata_df = (
        df.withColumn(ingestion_timestamp_col, F.current_timestamp())
        .withColumn(source_layer_col, F.lit(source_layer))
        .withColumn(source_path_col, F.lit(resolved_source_path))
        .withColumn(year_col, F.year(current_date_expr))
        .withColumn(month_col, F.month(current_date_expr))
        .withColumn(day_col, F.dayofmonth(current_date_expr))
    )
    if verbose:
        log(
            "Silver metadata added: "
            f"{ingestion_timestamp_col}, {source_layer_col}, {source_path_col}, "
            f"{year_col}, {month_col}, {day_col} "
            "(partition source: current_date())"
        )
    return metadata_df


def clean_data(
    df: DataFrame,
    drop_duplicates: bool = True,
    drop_all_null_rows: bool = True,
    verbose: bool = False,
) -> DataFrame:
    """Normalize names, trim empty strings to null, infer types, optionally dedupe.

    Renames columns to unique snake_case (via internal helpers), replaces blank
    strings with null on string columns, runs
    :py:func:`detect_and_cast_columns`, then optionally drops duplicate rows
    and rows that are all-null.

    :param df: Input dataframe.
    :param drop_duplicates: If ``True``, call ``dropDuplicates()`` after cleaning.
    :param drop_all_null_rows: If ``True``, call ``dropna(how="all")``.
    :type df: ~pyspark.sql.DataFrame
    :type drop_duplicates: bool
    :type drop_all_null_rows: bool
    :returns: Cleaned dataframe.
    :rtype: ~pyspark.sql.DataFrame

    .. rubric:: Example

    >>> cleaned = clean_data(raw_df, drop_duplicates=True, drop_all_null_rows=True)  # doctest: +SKIP
    """
    before_cols = len(df.columns)
    normalized_columns = _build_unique_column_names(df.columns)
    cleaned_df = df.toDF(*normalized_columns)
    cleaned_df = _replace_empty_strings_with_nulls(cleaned_df)
    cleaned_df = detect_and_cast_columns(cleaned_df, verbose=verbose)
    if drop_duplicates:
        cleaned_df = cleaned_df.dropDuplicates()
    if drop_all_null_rows:
        cleaned_df = cleaned_df.dropna(how="all")
    after_cols = len(cleaned_df.columns)
    if verbose:
        log(f"Data cleaned: columns {before_cols} -> {after_cols}")
    return cleaned_df


__all__ = [
    "clean_data",
    "add_silver_metadata",
    "detect_and_cast_columns",
    "_to_snake_case",
    "_build_unique_column_names",
    "_normalized_name_collisions",
    "_replace_empty_strings_with_nulls",
]


if __name__ == "__main__":
    print("Test")