from typing import List, Dict, Optional

import pandas as pd
from tqdm import tqdm

import fasttext
from langdetect import detect_langs

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, PreTrainedTokenizer

def is_english_message(
    messages: List[Dict[str, str]],
    lang_detector: fasttext.FastText._FastText,
) -> bool:
    """Check if all messages are in English using fastText (fallback to langdetect)"""
    try:
        for m in messages:
            content = m["content"].strip()
            if len(content) < 10:
                return False
            # fastText's predict() rejects strings that contain newlines,
            # so collapse them to spaces before predicting
            pred = lang_detector.predict(
                content.replace("\n", " "),
                k=1,
            )
            if pred[0][0] != "__label__en" or pred[1][0] < 0.9:
                return False
        return True
    except Exception:
        try:
            for m in messages:
                langs = detect_langs(m["content"])
                if langs[0].lang != "en" or langs[0].prob < 0.9:
                    return False
            return True
        except Exception:
            return False

def get_token_length(
    messages: List[Dict[str, str]],
    tokenizer: PreTrainedTokenizer,
) -> Optional[int]:
    """Apply chat template and calculate token length"""
    try:
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
        )
        # The rendered template already contains the special tokens,
        # so skip adding them again when counting
        return len(tokenizer.encode(text, add_special_tokens=False))
    except Exception:
        return None

def process_dataset(
    tokenizer_name_or_path: str,
    dataset_name_or_path: str,
    message_col_name: str,
    min_token_length: int,
    max_token_length: int,
    max_count: int,
    full_data_path: str,
    sampled_data_path: str,
) -> None:
    """Main pipeline to filter dataset and save full + sampled versions"""
    tqdm.write(" Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name_or_path,
        use_fast=True,
    )

    tqdm.write(" Loading fastText language model...")
    lang_detector = fasttext.load_model("lid.176.ftz")

    tqdm.write(" Loading dataset...")
    dataset: Dataset = load_dataset(
        dataset_name_or_path,
        split="train",
    )

    records = []
    tqdm.write(" Filtering and annotating...")
    for example in tqdm(dataset, desc="Processing"):
        messages = example.get(message_col_name)
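        # Keep only single-turn conversations: exactly one user message
        # followed by one assistant message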
        if not (
            isinstance(messages, list)
            and len(messages) == 2
            and messages[0]["role"] == "user"
            and messages[1]["role"] == "assistant"
        ):
            continue

        if not is_english_message(
            messages,
            lang_detector,
        ):
            continue

        token_len = get_token_length(
            messages,
            tokenizer,
        )
        if token_len is None:
            continue

        example["token_length"] = token_len
        records.append(example)

    tqdm.write(f"✅ Total valid examples: {len(records):,}")
    df_all = pd.DataFrame(records)

    tqdm.write(f" Saving full filtered dataset to: {full_data_path}")
    df_all.to_parquet(
        full_data_path,
        index=False,
    )

    tqdm.write(" Sampling based on token length range...")
    in_range = df_all[
        (df_all["token_length"] >= min_token_length)
        & (df_all["token_length"] <= max_token_length)
    ]

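    # If the in-range pool falls short of max_count, top it up with the longest
    # examples below min_token_length so the sample still favors longer sequences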
    if len(in_range) < max_count:
        need = max_count - len(in_range)
        under_range = df_all[df_all["token_length"] < min_token_length].sort_values(
            by="token_length",
            ascending=False,
        )
        supplement = under_range.head(need)
        final_df = pd.concat(
            [in_range, supplement],
            ignore_index=True,
        )
    else:
        final_df = in_range

    final_df = final_df[final_df["token_length"] <= max_token_length]

    tqdm.write(
        f" Saving sampled dataset to: {sampled_data_path} ({len(final_df):,} rows)"
    )
    final_df.to_parquet(
        sampled_data_path,
        index=False,
    )

    tqdm.write(" Done!")

if __name__ == "__main__":
    process_dataset(
        tokenizer_name_or_path="HuggingFaceTB/SmolLM2-135M-Instruct",
        dataset_name_or_path="allenai/tulu-3-sft-mixture",
        message_col_name="messages",
        min_token_length=2048,
        max_token_length=4096,
        max_count=100_000,
        full_data_path="tulu3_lang_filtered_with_token_length.parquet",
        sampled_data_path="tulu3_final_sampled_2048to4096.parquet",
    )

Afterwards, the token-length distribution of the filtered data was inspected and almost no examples turned out to exceed 2,048 tokens, so the data was re-sampled with a 512 to 2,048 token range and roughly 70k examples were saved (tulu3_final_sampled_512to2048.parquet).
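
A minimal sketch of that re-sampling step, assuming it is done directly on the saved full parquet rather than by re-running the whole pipeline (the distribution check is illustrative; only the 512 to 2,048 window and the output file name come from the note above):

import pandas as pd

df_all = pd.read_parquet("tulu3_lang_filtered_with_token_length.parquet")

# Inspect the token-length distribution; very few rows exceed 2,048 tokens
print(df_all["token_length"].describe())
print("share over 2048 tokens:", (df_all["token_length"] > 2048).mean())

# Re-sample with the adjusted 512-2048 window (about 70k rows) and save
resampled = df_all[
    (df_all["token_length"] >= 512) & (df_all["token_length"] <= 2048)
]
resampled.to_parquet("tulu3_final_sampled_512to2048.parquet", index=False)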

The resulting sampled files:

- tulu3_final_sampled_512to2048.parquet
- smoltalk_final_sampled_512to2048.parquet
- Infinity-Instruct_final_sampled_512to4096.parquet