import os
import pandas as pd
import numpy as np
from typing import List

from huggingface_hub import login
from transformers import AutoTokenizer
from datasets import load_dataset, Dataset, DatasetDict, concatenate_datasets
from dotenv import load_dotenv

load_dotenv(dotenv_path="/mnt/t7/dnn/llm_practicing/.env")

login(token=os.environ["HF_TOKEN"])

# Cache directory settings
DATA_CACHE_DIR = "/mnt/t7/.cache/huggingface/datasets"
MODEL_CACHE_DIR = "/mnt/t7/.cache/huggingface/models"

TOKENIZER_NAME = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, cache_dir=MODEL_CACHE_DIR)
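
# Optional sanity check (illustrative sketch, not part of the pipeline): format a
# toy conversation with the chat template and count its tokens; the same pattern
# drives the length filtering below. The example messages are made up.
example_messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there, how can I help?"},
]
example_formatted = tokenizer.apply_chat_template(example_messages, tokenize=False)
print(f"Example token count: {len(tokenizer.encode(example_formatted))}")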

# Empty placeholder datasets with the target schema; filtered splits from each
# source are concatenated onto these accumulators below.
columns = ['chosen', 'rejected', 'source', 'score_chosen', 'score_rejected', 'train_method']
empty_df = pd.DataFrame({col: [] for col in columns})
sft_dataset = Dataset.from_pandas(empty_df)
dpo_dataset = Dataset.from_pandas(empty_df)

datasets_info = [
    {
      "name": "aeolian83/allenai-llama-3.1-tulu-3-405b-preference-mixture_filtered",
      "train_method": "dpo",
      "split_list": ["train"],
      "score_list": ['chosen_rating', 'rejected_rating'],
    },
    {
      "name": "aeolian83/HuggingFaceH4-ultrafeedback_binarized_filtered",
      "train_method": "dpo",
      "split_list": ["train_prefs", "test_prefs"],
      "score_list": ['score_chosen', 'score_rejected'],
    },
    {
      "name": "aeolian83/HuggingFaceH4-ultrachat_200k_filtered",
      "train_method": "sft",
      "split_list": ["train_sft", "test_sft", "train_gen", "test_gen"],
      "score_list": None,
    },
    {
      "name": "aeolian83/HuggingFaceTB-smoltalk_filtered",
      "train_method": "sft",
      "split_list": ["train", "test"],
      "score_list": None,
    },
    {
      "name": "aeolian83/allenai-tulu-3-sft-mixture_filtered",
      "train_method": "sft",
      "split_list": ["train"],
      "score_list": None,
    }
]
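
# Lightweight validation (optional sketch): make sure every entry provides the keys
# that the processing functions below rely on.
required_keys = {"name", "train_method", "split_list", "score_list"}
for info in datasets_info:
    missing = required_keys - info.keys()
    assert not missing, f"{info.get('name', '?')} is missing keys: {missing}"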

# Assumes a global `tokenizer` object with methods:
#   - apply_chat_template(conversation: List[dict], tokenize: bool) -> str
#   - encode(text: str) -> List[int]

def compute_min_count(ex: dict) -> dict:
    """
    Compute the minimum token count between chat-formatted 'chosen' and 'rejected' fields.
    Returns a dict with 'min_count'.
    """
    chosen_fmt = tokenizer.apply_chat_template(ex['chosen'], tokenize=False)
    rejected_fmt = tokenizer.apply_chat_template(ex['rejected'], tokenize=False)
    return {
        'min_count': min(
            len(tokenizer.encode(chosen_fmt)),
            len(tokenizer.encode(rejected_fmt))
        )
    }
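
# Usage sketch for compute_min_count (toy preference pair, purely illustrative):
#   pair = {
#       "chosen":   [{"role": "user", "content": "2+2?"}, {"role": "assistant", "content": "4"}],
#       "rejected": [{"role": "user", "content": "2+2?"}, {"role": "assistant", "content": "Five."}],
#   }
#   compute_min_count(pair)  # -> {"min_count": <smaller of the two token counts>}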

def sft_data_processing(
    datasets_name: str,
    low_threshold: int,
    high_threshold: int,
    split_list: List[str],
    target_dataset: Dataset
) -> Dataset:
    """
    Use Hugging Face dataset's map and filter methods to select examples
    whose formatted 'messages' token count falls within [low_threshold, high_threshold].
    """
    ds_dict = load_dataset(datasets_name, cache_dir=DATA_CACHE_DIR)
    filtered_splits = []

    for split in split_list:
        if split not in ds_dict:
            continue
        split_ds = ds_dict[split]
        # Compute token count per example
        split_ds = split_ds.map(
            lambda ex: {
                'token_count': len(
                    tokenizer.encode(
                        tokenizer.apply_chat_template(ex['messages'], tokenize=False)
                    )
                )
            },
            remove_columns=[]
        )
        # Filter by threshold
        split_ds = split_ds.filter(
            lambda ex: low_threshold <= ex['token_count'] <= high_threshold
        )
        # Format records
        split_ds = split_ds.map(
            lambda ex: {
                'chosen': ex['messages'],
                'rejected': ex['messages'],
                'source': datasets_name,
                'score_chosen': np.nan,
                'score_rejected': np.nan,
                'train_method': 'sft'
            },
            remove_columns=split_ds.column_names
        )
        filtered_splits.append(split_ds)
        
    print("-" * 50)
    print("filtered data num")            
    print(len(filtered_splits))
    print("-" * 50)

    if filtered_splits:
        new_ds = concatenate_datasets(filtered_splits)
        return concatenate_datasets([target_dataset, new_ds])
    return target_dataset

def dpo_data_processing(
    datasets_name: str,
    low_threshold: int,
    high_threshold: int,
    split_list: List[str],
    score_list: List[str],
    target_dataset: Dataset
) -> Dataset:
    """
    Use Hugging Face dataset's map and filter methods to select examples
    whose minimum token count between 'chosen' and 'rejected' falls within bounds.
    """
    ds_dict = load_dataset(datasets_name, cache_dir=DATA_CACHE_DIR)
    filtered_splits = []

    for split in split_list:
        if split not in ds_dict:
            continue
        split_ds = ds_dict[split]
        # Compute min token count using standalone function
        split_ds = split_ds.map(compute_min_count, remove_columns=[])
        # Filter by threshold
        split_ds = split_ds.filter(
            lambda ex: low_threshold <= ex['min_count'] <= high_threshold
        )
        # Format records
        split_ds = split_ds.map(
            lambda ex: {
                'chosen': ex['chosen'],
                'rejected': ex['rejected'],
                'source': datasets_name,
                'score_chosen': ex[score_list[0]],
                'score_rejected': ex[score_list[1]],
                'train_method': 'dpo'
            },
            remove_columns=split_ds.column_names
        )
        filtered_splits.append(split_ds)
    
    print("-" * 50)
    print("filtered data num")            
    print(len(filtered_splits))
    print("-" * 50)

    if filtered_splits:
        new_ds = concatenate_datasets(filtered_splits)
        return concatenate_datasets([target_dataset, new_ds])
    return target_dataset
    
    
low_threshold = 2048
high_threshold = 4096

for dataset_info in datasets_info:
    train_method = dataset_info["train_method"]
    if train_method == "dpo":
        print(dataset_info["name"])
        dpo_dataset = dpo_data_processing(
            dataset_info["name"],
            low_threshold,
            high_threshold,
            dataset_info["split_list"],
            dataset_info["score_list"],
            dpo_dataset,
        )
    else:
        print(dataset_info["name"])
        sft_dataset = sft_data_processing(
            dataset_info["name"],
            low_threshold,
            high_threshold,
            dataset_info["split_list"],
            sft_dataset,
        )
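
# Optional sanity check (sketch): report how many examples ended up in each bucket
# before assembling the DatasetDict.
print(f"sft examples: {len(sft_dataset)}, dpo examples: {len(dpo_dataset)}")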

mergeup_dataset_dict_2k4k = DatasetDict({
    "sft": sft_dataset, 
    "dpo": dpo_dataset
})

mergeup_dataset_dict_2k4k.push_to_hub("aeolian83/mergeup_dataset_dict_2k_to_4k", private=False)
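
# Optional verification (sketch, assumes the push above succeeded): reload the
# published dataset from the Hub and check the split sizes.
check = load_dataset("aeolian83/mergeup_dataset_dict_2k_to_4k", cache_dir=DATA_CACHE_DIR)
print({split: len(ds) for split, ds in check.items()})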

# Published datasets:
#   aeolian83/mergeup_dataset_dict_1k_to_4k · Datasets at Hugging Face
#   aeolian83/mergeup_dataset_dict_2k_to_4k · Datasets at Hugging Face