from typing import List, Dict, Optional
import pandas as pd
from tqdm import tqdm
import fasttext
from langdetect import detect_langs
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, PreTrainedTokenizer
def is_english_message(
messages: List[Dict[str, str]],
lang_detector: fasttext.FastText._FastText,
) -> bool:
"""Check if all messages are in English using fastText (fallback to langdetect)"""
try:
for m in messages:
content = m["content"].strip()
if len(content) < 10:
return False
pred = lang_detector.predict(
                content.replace("\n", " "),
k=1,
)
if pred[0][0] != "__label__en" or pred[1][0] < 0.9:
return False
return True
except Exception:
try:
for m in messages:
langs = detect_langs(m["content"])
if langs[0].lang != "en" or langs[0].prob < 0.9:
return False
return True
except Exception:
return False
def get_token_length(
messages: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
) -> Optional[int]:
"""Apply chat template and calculate token length"""
try:
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
        )
        # The rendered chat template already contains the special tokens,
        # so don't add them a second time when counting.
        return len(tokenizer.encode(text, add_special_tokens=False))
except Exception:
return None
def process_dataset(
tokenizer_name_or_path: str,
dataset_name_or_path: str,
message_col_name: str,
min_token_length: int,
max_token_length: int,
max_count: int,
full_data_path: str,
sampled_data_path: str,
) -> None:
"""Main pipeline to filter dataset and save full + sampled versions"""
tqdm.write(" Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name_or_path,
use_fast=True,
)
tqdm.write(" Loading fastText language model...")
lang_detector = fasttext.load_model("lid.176.ftz")
tqdm.write(" Loading dataset...")
dataset: Dataset = load_dataset(
dataset_name_or_path,
split="train",
)
records = []
tqdm.write(" Filtering and annotating...")
for example in tqdm(dataset, desc="Processing"):
messages = example.get(message_col_name)
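        # Keep only single-turn conversations: exactly one user message followed by one assistant reply.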
if not (
isinstance(messages, list)
and len(messages) == 2
and messages[0]["role"] == "user"
and messages[1]["role"] == "assistant"
):
continue
if not is_english_message(
messages,
lang_detector,
):
continue
token_len = get_token_length(
messages,
tokenizer,
)
if token_len is None:
continue
example["token_length"] = token_len
records.append(example)
tqdm.write(f"✅ Total valid examples: {len(records):,}")
df_all = pd.DataFrame(records)
tqdm.write(f" Saving full filtered dataset to: {full_data_path}")
df_all.to_parquet(
full_data_path,
index=False,
)
tqdm.write(" Sampling based on token length range...")
in_range = df_all[
(df_all["token_length"] >= min_token_length)
& (df_all["token_length"] <= max_token_length)
]
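    # Not enough in-range examples: backfill with the longest examples below min_token_length.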
if len(in_range) < max_count:
need = max_count - len(in_range)
under_range = df_all[df_all["token_length"] < min_token_length].sort_values(
by="token_length",
ascending=False,
)
supplement = under_range.head(need)
final_df = pd.concat(
[in_range, supplement],
ignore_index=True,
)
else:
final_df = in_range
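    # Re-apply the max cap as a safety check after concatenation (supplement rows are all below min_token_length, so this should be a no-op).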
final_df = final_df[final_df["token_length"] <= max_token_length]
tqdm.write(
f" Saving sampled dataset to: {sampled_data_path} ({len(final_df):,} rows)"
)
final_df.to_parquet(
sampled_data_path,
index=False,
)
tqdm.write(" Done!")
if __name__ == "__main__":
process_dataset(
tokenizer_name_or_path="HuggingFaceTB/SmolLM2-135M-Instruct",
dataset_name_or_path="allenai/tulu-3-sft-mixture",
message_col_name="messages",
min_token_length=2048,
max_token_length=4096,
max_count=100_000,
full_data_path="tulu3_lang_filtered_with_token_length.parquet",
sampled_data_path="tulu3_final_sampled_2048to4096.parquet",
)

Afterwards, I checked the dataset's token-length distribution and found that almost no examples exceed 2,048 tokens, so I re-sampled with a 512-2048 range instead and saved roughly 70k examples (tulu3_final_sampled_512to2048.parquet).
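The re-sampling code itself is not shown above; the following is a minimal sketch of that step, assuming it reuses the full filtered parquet written by process_dataset. The resample_by_token_length helper name and the plain range filter are assumptions, not the original implementation:

import pandas as pd

def resample_by_token_length(
    full_data_path: str,
    sampled_data_path: str,
    min_token_length: int,
    max_token_length: int,
) -> None:
    """Re-sample an already annotated parquet file by token length range."""
    df_all = pd.read_parquet(full_data_path)
    # Keep only rows whose pre-computed token_length falls inside the range.
    in_range = df_all[
        (df_all["token_length"] >= min_token_length)
        & (df_all["token_length"] <= max_token_length)
    ]
    print(f"Selected {len(in_range):,} rows")
    in_range.to_parquet(sampled_data_path, index=False)

# Re-sample to the 512-2048 range described above (~70k rows).
resample_by_token_length(
    full_data_path="tulu3_lang_filtered_with_token_length.parquet",
    sampled_data_path="tulu3_final_sampled_512to2048.parquet",
    min_token_length=512,
    max_token_length=2048,
)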