from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import pandas as pd
import re
import matplotlib.pyplot as plt
import seaborn as sns
import gc
from IPython.display import display

subjects = ["high_school_biology", "high_school_physics", "philosophy"]
choices = ["A", "B", "C", "D"]
models = {
    "TinyLlama": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Pythia-1B": "EleutherAI/pythia-1b"
}

def label_to_choice(index):
    return choices[index]

def build_fewshot_prompt(data, k_per_label=2):
    examples = []
    label_counts = {label: 0 for label in choices}
    for ex in data:
        ans = label_to_choice(ex["answer"])
        if label_counts[ans] >= k_per_label:
            continue
        q = ex["question"]
        opts = ex["choices"]
        block = f"Question: {q}\\n"
        for l, opt in zip(choices, opts):
            block += f"{l}. {opt}\\n"
        block += f"Answer: {ans}\\n"
        examples.append(block)
        label_counts[ans] += 1
        if all(v >= k_per_label for v in label_counts.values()):
            break
    return "\\n".join(examples)

def format_question(example):
    block = f"Question: {example['question']}\\n"
    for l, opt in zip(choices, example["choices"]):
        block += f"{l}. {opt}\\n"
    block += "Please answer with A, B, C, or D.\\nAnswer:"
    return block
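
# For reference, format_question produces a block of the following shape
# (illustrative placeholder question, not taken from the MMLU data):
#
# Question: Which organelle is the main site of ATP production?
# A. Nucleus
# B. Mitochondrion
# C. Ribosome
# D. Golgi apparatus
# Please answer with A, B, C, or D.
# Answer: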

def build_chat_messages(example, fewshot):
    return [
        {"role": "system", "content": "You are a helpful assistant that answers multiple-choice questions with A, B, C, or D."},
        {"role": "user", "content": fewshot + "\\n\\n" + format_question(example)},
        {"role": "assistant", "content": "Answer:"}
    ]

def extract_answer(text):
    # First look for an explicit "Answer: X" pattern
    match = re.search(r"Answer:\s*([ABCD])", text)
    if match:
        return match.group(1)
    # Otherwise fall back to the last standalone A/B/C/D in the text
    candidates = re.findall(r"\b([ABCD])\b", text)
    return candidates[-1] if candidates else "None"
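
# Quick sanity check of extract_answer on hand-written strings (illustrative
# examples, not real model generations):
assert extract_answer("Answer: B") == "B"
assert extract_answer("The correct option is C, because ...") == "C"
assert extract_answer("I am not sure.") == "None"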

def generate_output(model, tokenizer, prompt, max_new_tokens=60):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Decode only the newly generated tokens; the prompt already contains
    # "Answer: X" lines from the few-shot examples, which would otherwise be
    # picked up by extract_answer.
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)

comparison = []

for subject in subjects:
    print(f"\\n\\n📘 Evaluating subject: {subject}")
    dataset = load_dataset("cais/mmlu", subject)
    data = dataset["validation"].shuffle(seed=42)

    total_len = len(data)
    eval_len = min(50, total_len)
    fewshot_len = max(0, total_len - eval_len)

    eval_subset = data.select(range(eval_len))

    if fewshot_len >= 4:
        fewshot_data = data.select(range(eval_len, total_len))
        fewshot = build_fewshot_prompt(fewshot_data)
    else:
        fewshot = ""  # 자동 zero-shot fallback

    gt = [label_to_choice(ex["answer"]) for ex in eval_subset]
    outputs_by_model = {}

    for model_name, model_id in models.items():
        print(f"\\n🔍 Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto"
        )

        preds, raw_outputs = [], []
        for i, ex in enumerate(eval_subset):
            if "TinyLlama" in model_name:
                messages = build_chat_messages(ex, fewshot)
                prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            else:
                prompt = fewshot + "\\n\\n" + format_question(ex)

            gen = generate_output(model, tokenizer, prompt)
            answer = extract_answer(gen)

            preds.append(answer)
            raw_outputs.append(gen.strip())

            print(f"\\n--- {model_name} | Q{i+1} ---")
            print("Prompt:\\n", prompt[:300], "...")
            print("Output:\\n", gen.strip()[:300], "...")

        outputs_by_model[model_name] = {"preds": preds, "outputs": raw_outputs}

        del model
        del tokenizer
        gc.collect()
        torch.cuda.empty_cache()

    for i, ex in enumerate(eval_subset):
        comparison.append({
            "Subject": subject,
            "Question #": i+1,
            "GT Answer": gt[i],
            "TinyLlama": outputs_by_model["TinyLlama"]["preds"][i],
            "Tiny Correct": "✅" if outputs_by_model["TinyLlama"]["preds"][i] == gt[i] else "❌",
            "Pythia-1B": outputs_by_model["Pythia-1B"]["preds"][i],
            "Pythia Correct": "✅" if outputs_by_model["Pythia-1B"]["preds"][i] == gt[i] else "❌",
            "TinyLlama Output": outputs_by_model["TinyLlama"]["outputs"][i],
            "Pythia Output": outputs_by_model["Pythia-1B"]["outputs"][i]
        })

# 🎯 Accuracy summary
summary = []
for model in ["TinyLlama", "Pythia-1B"]:
    correct_key = "Tiny Correct" if model == "TinyLlama" else "Pythia Correct"
    for subject in subjects:
        subset = [row for row in comparison if row["Subject"] == subject]
        correct = sum(1 for row in subset if row[correct_key] == "✅")
        accuracy = round(correct / len(subset) * 100, 2)
        summary.append({"Model": model, "Subject": subject, "Accuracy": accuracy})

df = pd.DataFrame(summary)

# 📊 Visualization
plt.figure(figsize=(8, 5))
sns.barplot(data=df, x="Subject", y="Accuracy", hue="Model")
plt.title("TinyLlama vs Pythia-1B on MMLU (Stable Version)")
plt.ylim(0, 100)
plt.ylabel("Accuracy (%)")
plt.grid(True)
plt.show()

# 📋 Display result tables
comparison_df = pd.DataFrame(comparison)
display(df)
display(comparison_df[["Subject", "Question #", "GT Answer", "TinyLlama", "Tiny Correct", "Pythia-1B", "Pythia Correct"]])

MMLU Evaluation Results (TinyLlama vs Pythia-1B)

Experiment Overview

Two small open models, TinyLlama-1.1B-Chat-v1.0 and Pythia-1B, are compared on three MMLU subjects (high_school_biology, high_school_physics, philosophy). For each subject, up to 50 validation questions are evaluated with greedy decoding, using a label-balanced few-shot prompt (up to two examples per answer letter) built from the remaining validation data; when too few examples remain, the run falls back to zero-shot.


Evaluation Method Summary

Each prompt consists of the few-shot examples followed by the formatted question. TinyLlama is queried through its chat template, while Pythia-1B receives a plain-text prompt. Up to 60 new tokens are generated greedily, the answer letter is extracted with a regex (falling back to the last standalone A/B/C/D), and per-subject accuracy is reported for both models.


Proposed Ways to Improve TinyLlama Performance

1. ✅ Prompt structure optimization
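
One low-risk change is to drop the extra "Answer:" assistant turn and let the chat template's generation prompt do the work, keeping the answer cue inside the user message. A minimal sketch under that assumption (build_chat_messages_v2 is a hypothetical helper, not part of the code above):

def build_chat_messages_v2(example, fewshot):
    # System turn carries the instruction; the user turn ends with "Answer:"
    # (via format_question), so the model only needs to emit a single letter.
    return [
        {"role": "system", "content": "Answer the multiple-choice question with a single letter: A, B, C, or D."},
        {"role": "user", "content": (fewshot + "\n\n" if fewshot else "") + format_question(example)},
    ]

# prompt = tokenizer.apply_chat_template(
#     build_chat_messages_v2(ex, fewshot), tokenize=False, add_generation_prompt=True
# )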