# =========================================================
# Recod.ai/LUC - Scientific Image Forgery Detection
# Deep-dive EDA on Google Drive (Colab)
# - train_images/authentic, train_images/forged 구조 지원
# - Safe image/mask IO (.png/.npy), union of multi-masks
# - Component-level analytics (area/box/aspect/centroid/border)
# - Global forgery heatmap
# - Mask-vs-background color stats
# - Resolution/size correlations & numeric corr matrix
# - Optional texture: entropy & autocorr side-peak (sampled)
# =========================================================

# ----------------------------
# 0) Mount Google Drive (Colab)
# ----------------------------
try:
    # Mount Google Drive so the dataset under /content/drive/MyDrive is reachable.
    from google.colab import drive
    drive.mount("/content/drive")
except ImportError:
    # Not running in Colab: skip mounting; the paths below must then exist locally.
    pass

# ----------------------------
# 1) Imports & config
# ----------------------------
import os, re, math, random, warnings
from pathlib import Path
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (9, 6)
plt.rcParams["axes.grid"] = True

from PIL import Image

try:
    import cv2
except Exception:
    cv2 = None

try:
    from skimage.measure import label as sk_label, regionprops
except Exception:
    sk_label = None
    regionprops = None

# Fixed seeds so the random.sample() calls below pick the same subsets on every
# run (given the same input list ordering).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Upper bounds on the number of cases sampled for the slow analyses.
SAMPLE_MAX_SLOW = 800
SAMPLE_MAX_TEXTURE = 200

# ----------------------------
# 2) Locate competition paths (Google Drive 버전)
# ----------------------------

# Assumes the dataset lives directly under MyDrive when running in Colab.
COMP_ROOT = Path("/content/drive/MyDrive")

# Expected directory layout:
# /content/drive/MyDrive/train_images/authentic/*.png
# /content/drive/MyDrive/train_images/forged/*.png
# /content/drive/MyDrive/train_masks/*.npy
# /content/drive/MyDrive/supplemental_images/*.png
# /content/drive/MyDrive/supplemental_masks/*.png, .npy
# /content/drive/MyDrive/test_images/*.png

DIR_TRAIN_IMG  = COMP_ROOT / "train_images"      # contains authentic/ and forged/ sub-folders
DIR_TRAIN_MASK = COMP_ROOT / "train_masks"
DIR_SUP_IMG    = COMP_ROOT / "supplemental_images"
DIR_SUP_MASK   = COMP_ROOT / "supplemental_masks"
DIR_TEST_IMG   = COMP_ROOT / "test_images"

print("[PATH] COMP_ROOT:", COMP_ROOT)
# Report which expected directories exist; the collectors below skip missing ones.
for p in [DIR_TRAIN_IMG, DIR_TRAIN_MASK, DIR_SUP_IMG, DIR_SUP_MASK, DIR_TEST_IMG]:
    print("  -", p, ":", "OK" if p.exists() else "Missing")

# ----------------------------
# 3) IO helpers (robust)
# ----------------------------
# File extensions accepted when scanning image folders / mask folders.
IMG_EXTS  = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}
MASK_EXTS = {".png", ".npy"}

def get_case_id(path: Path):
    """Return the numeric case id embedded in a file name.

    The first run of decimal digits in the file stem is returned as an int
    (e.g. ``12345.png -> 12345``, ``img_007.png -> 7``). When the stem holds no
    digits, the stem string itself is returned unchanged.

    Note: the int conversion drops leading zeros (``00123.png -> 123``).
    """
    # Single backslash inside the raw string: a doubled one would match a
    # literal backslash followed by "d" and never find the digits.
    m = re.search(r"\d+", path.stem)
    return int(m.group()) if m else path.stem

def read_image(path: Path):
    """Load an image file as an RGB ``uint8`` array of shape (H, W, 3).

    The file is opened in a ``with`` block so the handle is released right away
    rather than whenever the lazily-opened PIL image is garbage-collected
    (relevant for multi-frame formats such as TIFF, which IMG_EXTS accepts).
    """
    with Image.open(path) as im:
        return np.array(im.convert("RGB"))  # HWC, uint8

def read_mask_file(path: Path):
    """Read one mask file and binarise it to a 2-D ``uint8`` array of {0, 1}.

    ``.npy`` files may hold a single mask ``(H, W)`` or a stack of instance
    masks shaped ``(N, H, W)`` (channels-first) or ``(H, W, N)``
    (channels-last). A stack is collapsed to its pixel-wise union, so callers
    always receive ``(H, W)``. Any other file is read as an image (OpenCV when
    available, otherwise Pillow), and every non-zero pixel counts as foreground.
    """
    if path.suffix.lower() == ".npy":
        m = np.load(path)
        if m.ndim == 3:
            # Union over the (smaller) instance axis. This also covers the
            # single-instance (1, H, W) and (H, W, 1) layouts.
            axis = 0 if m.shape[0] <= m.shape[-1] else -1
            m = (m > 0).any(axis=axis)
        return (m > 0).astype(np.uint8)

    m = None
    if cv2 is not None:
        m = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
        if m is not None and m.ndim == 3:
            m = cv2.cvtColor(m, cv2.COLOR_BGR2GRAY)
    if m is None:
        # OpenCV missing or unable to decode the file: fall back to Pillow.
        with Image.open(path) as im:
            m = np.array(im.convert("L"))
    return (m > 0).astype(np.uint8)

def list_mask_files(case_id, mask_dir: Path):
    """Return every mask file in ``mask_dir`` that belongs to ``case_id``.

    A file matches when it is named ``<case_id>.<ext>`` or
    ``<case_id>_*.<ext>``, or when ``case_id`` appears as a whole word in its
    stem (e.g. ``mask-12.npy``). Only extensions listed in ``MASK_EXTS`` are
    kept. Results are de-duplicated and sorted by file name. A missing
    directory yields an empty list.
    """
    if not mask_dir.exists():
        return []
    cid = str(case_id)
    # Single backslashes in the raw f-string produce real regex word boundaries.
    # re.escape guards against stem-fallback ids containing regex metacharacters.
    word_re = re.compile(rf"\b{re.escape(cid)}\b")
    cands = [
        *mask_dir.glob(f"{cid}.*"),
        *mask_dir.glob(f"{cid}_*.*"),
        *(p for p in mask_dir.glob("*.*") if word_re.search(p.stem)),
    ]
    unique = {str(p): p for p in cands if p.suffix.lower() in MASK_EXTS}
    return sorted(unique.values(), key=lambda x: x.name)

def read_union_mask(case_id, target_hw=None):
    """Union every mask of ``case_id`` found in train_masks and supplemental_masks.

    Args:
        case_id: id used to look up mask files in both mask directories.
        target_hw: optional ``(H, W)``. Masks of a different size are resized
            to it with nearest-neighbour interpolation. When omitted, the first
            mask's size becomes the target, so masks of differing sizes can
            still be combined instead of failing inside ``np.maximum``.

    Returns:
        ``(union, n_files)``, where ``union`` is a ``uint8`` {0, 1} array, or
        ``(None, 0)`` when no mask file exists.

    NOTE(review): the same id is looked up in both directories, so ids that
    collide between train and supplemental would be merged. Confirm that ids
    are disjoint.
    """
    files = list_mask_files(case_id, DIR_TRAIN_MASK) + list_mask_files(case_id, DIR_SUP_MASK)
    if not files:
        return None, 0
    union = None
    for f in files:
        m = read_mask_file(f)
        if target_hw is None:
            target_hw = m.shape[:2]
        H, W = int(target_hw[0]), int(target_hw[1])
        if (m.shape[0], m.shape[1]) != (H, W):
            if cv2 is not None:
                m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)
            else:
                m = np.array(Image.fromarray(m).resize((W, H), resample=Image.NEAREST))
        m = (m > 0).astype(np.uint8)
        # Pixel-wise OR of the binary masks.
        union = m if union is None else np.maximum(union, m)
    return union, len(files)

# ----------------------------
# 4) Metadata table
# ----------------------------

def collect_train_images_with_tags(root: Path):
    """Collect image records from ``root/authentic`` and ``root/forged``.

    Each record is a dict with ``case_id``, ``img_path``, ``split="train"`` and
    ``img_tag`` (the sub-folder name, "authentic" or "forged"). Missing
    sub-folders are skipped. Extensions are matched case-insensitively against
    ``IMG_EXTS``. Files are listed in a deterministic order (the same ordering
    ``collect_images`` uses), because the seeded random sampling downstream is
    only reproducible when the record order is stable. Iterating the
    ``IMG_EXTS`` set is not stable across interpreter runs.
    """
    recs = []
    for tag in ["authentic", "forged"]:
        sub = root / tag
        if not sub.exists():
            continue
        files = [p for p in sub.iterdir() if p.suffix.lower() in IMG_EXTS]
        for p in sorted(files, key=lambda x: (len(x.stem), x.stem)):
            recs.append(
                {
                    "case_id": get_case_id(p),
                    "img_path": p,
                    "split": "train",
                    "img_tag": tag,  # image-level label taken from the folder name
                }
            )
    return recs

def collect_images(root: Path, split_name: str):
    """Collect image records from a flat folder (supplemental_images, test_images, ...).

    Each record is a dict with ``case_id``, ``img_path``, ``split=split_name``
    and ``img_tag=None`` (no image-level label outside train). Extensions are
    matched case-insensitively against ``IMG_EXTS``, so ``.PNG`` files are not
    silently skipped. Records are ordered by (stem length, stem). A missing
    folder yields an empty list.
    """
    if not root.exists():
        return []
    files = [p for p in root.iterdir() if p.suffix.lower() in IMG_EXTS]
    return [
        {
            "case_id": get_case_id(p),
            "img_path": p,
            "split": split_name,
            "img_tag": None,  # not train, so no tag information
        }
        for p in sorted(files, key=lambda x: (len(x.stem), x.stem))
    ]

# Build the per-image metadata table from the train, supplemental and test folders.
records = []
records += collect_train_images_with_tags(DIR_TRAIN_IMG)
records += collect_images(DIR_SUP_IMG, "supplemental")
records += collect_images(DIR_TEST_IMG, "test")

meta = pd.DataFrame.from_records(records)
print("[META] images:", len(meta))
# Fail fast: every later section assumes at least one image was found.
if len(meta) == 0:
    raise RuntimeError("No images found. Check dataset paths.")

def probe_fast(path: Path):
    """Return image width/height (px) and file size (MiB) as a pandas Series.

    Any failure while opening/decoding the image or stat-ing the file yields a
    Series of NaNs under the same keys instead of raising.
    """
    nan_row = {"width": np.nan, "height": np.nan, "filesize_mb": np.nan}
    try:
        with Image.open(path) as handle:
            width, height = handle.convert("RGB").size
        megabytes = path.stat().st_size / 1024 ** 2
    except Exception:
        return pd.Series(nan_row)
    return pd.Series({"width": width, "height": height, "filesize_mb": megabytes})

meta = pd.concat([meta, meta["img_path"].apply(probe_fast)], axis=1)  # adds width / height / filesize_mb (NaN if unreadable)

# ----------------------------
# 5) Mask stats per image
# ----------------------------
def mask_row_stats(row):
    """Compute mask statistics for one metadata row.

    Returns a Series with ``mask_count`` (number of mask files found),
    ``coverage`` (foreground pixels / image pixels) and ``n_comp`` (number of
    8-connected foreground components). Test rows, rows without masks and rows
    whose masks cannot be processed all get zeros. Processing failures emit a
    warning instead of being silently swallowed, so broken paths or lookups
    stay visible.

    NOTE(review): masks are looked up by ``case_id`` only. An authentic image
    that shares its id with a forged one would also receive that mask.
    Confirm that ids are disjoint across the authentic/forged folders.
    """
    empty = pd.Series({"mask_count": 0, "coverage": 0.0, "n_comp": 0})
    if row["split"] == "test":
        # Masks are assumed not to exist for the test split.
        return empty
    try:
        H = int(row["height"])
        W = int(row["width"])
        union, mcnt = read_union_mask(row["case_id"], target_hw=(H, W))
        if union is None:
            return empty
        fg = (union > 0).astype(np.uint8)
        area = int(fg.sum())
        cov = float(area) / float(max(1, H * W))
        if cv2 is not None:
            num, _ = cv2.connectedComponents(fg, connectivity=8)
            ncomp = max(0, num - 1)  # label 0 is the background
        elif sk_label is not None:
            ncomp = int(sk_label(fg, connectivity=2).max())
        else:
            # No labelling backend: can only report "any foreground or not".
            ncomp = int(area > 0)
        return pd.Series({"mask_count": mcnt, "coverage": cov, "n_comp": ncomp})
    except Exception as exc:
        warnings.warn(
            f"mask stats failed for case_id={row['case_id']!r} "
            f"({row['img_path']}): {exc!r}"
        )
        return empty

# Per-image mask statistics (mask_count / coverage / n_comp) appended as columns.
mask_stats = meta.apply(mask_row_stats, axis=1)
meta = pd.concat([meta, mask_stats], axis=1)

# Image-level forgery flag: 1 when at least one mask file was found for the image.
meta["is_forged"] = (meta["mask_count"] > 0).astype(int)
# Width / height ratio.
meta["aspect"] = meta["width"] / meta["height"]

print(meta.head())

# ----------------------------
# 6) Plot helpers
# ----------------------------
def hist1(series, title, bins=30, xlabel=None):
    """Show a histogram of the numeric values in ``series``.

    Non-numeric entries are coerced to NaN and dropped. If nothing remains, a
    "[SKIP]" message is printed and no figure is drawn. The x-axis label
    defaults to the series name when ``xlabel`` is falsy.
    """
    numeric = pd.to_numeric(series, errors="coerce").dropna()
    if numeric.empty:
        print("[SKIP] empty for", title)
        return
    label = series.name if not xlabel else xlabel
    plt.figure(figsize=(8, 5))
    plt.hist(numeric, bins=bins)
    plt.title(title)
    plt.xlabel(label)
    plt.ylabel("count")
    plt.tight_layout()
    plt.show()

def scatter_xy(x, y, title, xlabel, ylabel, s=14, alpha=0.7):
    """Scatter-plot ``y`` against ``x`` using only rows where both are numeric.

    Values are coerced with ``pd.to_numeric``. If no row has both values, a
    "[SKIP]" message is printed and nothing is drawn.
    """
    x_num = pd.to_numeric(x, errors="coerce")
    y_num = pd.to_numeric(y, errors="coerce")
    valid = x_num.notna() & y_num.notna()
    if not valid.any():
        print("[SKIP] empty for", title)
        return
    plt.figure(figsize=(7, 6))
    plt.scatter(x_num[valid], y_num[valid], s=s, alpha=alpha)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.tight_layout()
    plt.show()

def corr_heatmap(df, title="Correlation (numeric)"):
    """Plot the Pearson correlation matrix of the numeric columns of ``df``.

    Non-numeric columns are ignored. With fewer than two numeric columns, a
    "[SKIP]" message is printed instead. The colour scale is fixed to [-1, 1]
    so that plots are comparable.
    """
    num = df.select_dtypes(include=[np.number])
    if num.shape[1] < 2:
        print("[SKIP] not enough numeric columns for corr")
        return
    # Columns are already numeric, so a plain .corr() suffices. The extra
    # numeric_only=True keyword only exists in pandas >= 1.5 and raised there-before.
    c = num.corr()
    plt.figure(figsize=(7, 6))
    im = plt.imshow(c, cmap="coolwarm", vmin=-1, vmax=1)
    plt.colorbar(im, fraction=0.046, pad=0.04)
    plt.xticks(range(len(c.columns)), c.columns, rotation=90)
    plt.yticks(range(len(c.columns)), c.columns)
    plt.title(title)
    plt.tight_layout()
    plt.show()

# ----------------------------
# 7) Component-level analytics
# ----------------------------
def _component_record(area, bw, bh, aspect, cx, cy, touches, ecc):
    """Pack one component's measurements into the dict layout used by component_props."""
    return {
        "area_px": area,
        "bbox_w": bw,
        "bbox_h": bh,
        "aspect": aspect,
        "cx": cx,
        "cy": cy,
        "touches_border": int(touches),
        "eccentricity": ecc,
    }


def component_props(binary_mask):
    """Measure every connected foreground component of a binary mask.

    Each record holds ``area_px``, the bounding-box size (``bbox_w``,
    ``bbox_h``), ``aspect`` (bbox w / h), the centroid normalised by image
    width/height (``cx``, ``cy``), ``touches_border`` (0/1) and
    ``eccentricity`` (NaN unless scikit-image is used).

    Backends, in order of preference:
    - scikit-image (8-connectivity);
    - OpenCV (8-connectivity);
    - a pure-NumPy fallback that treats all foreground pixels as one component.

    Returns an empty list for ``None`` or an all-zero mask.
    """
    if binary_mask is None or binary_mask.max() == 0:
        return []

    fg = (binary_mask > 0).astype(np.uint8)
    H, W = fg.shape[:2]
    out = []

    if regionprops is not None:
        for rp in regionprops(sk_label(fg, connectivity=2)):
            minr, minc, maxr, maxc = rp.bbox
            bw = maxc - minc
            bh = maxr - minr
            cy, cx = rp.centroid
            # regionprops bbox max coordinates are exclusive, hence == H / == W.
            touches = (minr == 0) or (minc == 0) or (maxr == H) or (maxc == W)
            out.append(
                _component_record(
                    int(rp.area),
                    bw,
                    bh,
                    float(bw) / float(max(1, bh)),
                    cx / W,
                    cy / H,
                    touches,
                    float(getattr(rp, "eccentricity", np.nan)),
                )
            )
        return out

    if cv2 is not None:
        num, _labels, stats, centroids = cv2.connectedComponentsWithStats(fg, connectivity=8)
        for idx in range(1, num):  # label 0 is the background
            x, y, w, h, area = stats[idx]
            cx, cy = centroids[idx]
            touches = (x == 0) or (y == 0) or (x + w == W) or (y + h == H)
            out.append(
                _component_record(
                    int(area),
                    int(w),
                    int(h),
                    float(w) / float(max(1, h)),
                    float(cx) / W,
                    float(cy) / H,
                    touches,
                    float(np.nan),
                )
            )
        return out

    # Fallback without a labelling library: a single pseudo-component.
    ys, xs = np.where(fg > 0)
    if len(xs) > 0:
        bw = int(xs.max() - xs.min() + 1)
        bh = int(ys.max() - ys.min() + 1)
        touches = (
            (xs.min() == 0)
            or (ys.min() == 0)
            or (xs.max() == W - 1)
            or (ys.max() == H - 1)
        )
        out.append(
            _component_record(
                int(len(xs)),
                bw,
                bh,
                float(bw) / float(max(1, bh)),
                float(xs.mean()) / W,
                float(ys.mean()) / H,
                touches,
                np.nan,
            )
        )
    return out

# Per-component analytics on a random sample of non-test images that have masks.
# Case ids are de-duplicated first. An id that appears in several rows (e.g. in
# both train and supplemental) would otherwise be sampled and counted twice.
non_test_ids = meta.query("split!='test' and mask_count>0")["case_id"].drop_duplicates().tolist()
comp_sample_ids = random.sample(non_test_ids, k=min(len(non_test_ids), SAMPLE_MAX_SLOW))

comp_rows = []
for cid in comp_sample_ids:
    # The first metadata row with this id supplies the image.
    ipath = meta.loc[meta["case_id"] == cid, "img_path"].iloc[0]
    # Only the size is needed, so read it from the header instead of decoding pixels.
    with Image.open(ipath) as im:
        W, H = im.size
    union, _ = read_union_mask(cid, target_hw=(H, W))
    for d in component_props(union):
        d["case_id"] = cid
        # Component area as a fraction of the whole image.
        d["area_pct"] = d["area_px"] / float(max(1, H * W))
        comp_rows.append(d)

comp_df = pd.DataFrame(comp_rows)
print("[COMP] components sampled:", len(comp_df))

if len(comp_df):
    hist1(comp_df["area_pct"], "Component area ratio", bins=40, xlabel="area / image")
    hist1(comp_df["aspect"], "Component bbox aspect", bins=40, xlabel="w / h")
    hist1(
        comp_df["touches_border"],
        "Component touches border (0/1)",
        bins=3,
        xlabel="0/1",
    )
    scatter_xy(
        comp_df["cx"],
        comp_df["cy"],
        "Component centroid distribution",
        "cx (0~1)",
        "cy (0~1)",
    )
else:
    print("[INFO] no components to analyze")

# ----------------------------
# 8) Global forgery heatmap
# ----------------------------
def aggregate_forgery_heatmap(case_ids, grid=64):
    """Accumulate per-case forgery masks on a common ``grid x grid`` canvas.

    Each case's union mask is resized to the grid with area averaging, so a
    cell receives the fraction of its area that is masked. Nearest-neighbour
    subsampling could skip small components entirely and under-represent them.

    Args:
        case_ids: case ids to aggregate. The image of the first metadata row
            with each id defines the mask's target size.
        grid: side length of the output canvas.

    Returns:
        ``(acc, total)``: the float32 ``(grid, grid)`` sum over cases, and the
        number of cases with a non-empty mask that contributed.
    """
    acc = np.zeros((grid, grid), dtype=np.float32)
    total = 0
    for cid in case_ids:
        ipath = meta.loc[meta["case_id"] == cid, "img_path"].iloc[0]
        # Only the image size is needed; read it from the header without decoding.
        with Image.open(ipath) as im:
            W, H = im.size
        union, _ = read_union_mask(cid, target_hw=(H, W))
        if union is None or union.max() == 0:
            continue
        fg = (union > 0).astype(np.float32)
        if cv2 is not None:
            m_small = cv2.resize(fg, (grid, grid), interpolation=cv2.INTER_AREA)
        else:
            m_small = np.array(
                Image.fromarray(fg).resize((grid, grid), resample=Image.BOX)
            )
        acc += m_small.astype(np.float32)
        total += 1
    return acc, total

heatmap, hm_n = aggregate_forgery_heatmap(comp_sample_ids, grid=64)
hm_max = float(heatmap.max())
# Also require a non-zero maximum: if every mask vanished when resized to the
# grid, normalising would divide by zero and draw an all-NaN image.
if hm_n > 0 and hm_max > 0:
    plt.figure(figsize=(6, 5))
    plt.imshow(heatmap / hm_max, cmap="magma")  # normalised to [0, 1] for display
    plt.title(f"Global forgery heatmap (n={hm_n})")
    plt.axis("off")
    plt.tight_layout()
    plt.show()
else:
    print("[INFO] heatmap skipped (no masks)")

# ----------------------------
# 9) Mask vs background color stats
# ----------------------------
def mask_color_stats(cid):
    """Compare per-channel RGB statistics inside vs. outside a case's mask.

    Pixel values are scaled to [0, 1]. Returns a dict with ``case_id`` plus
    the mean and standard deviation of R, G and B over masked (``fg_*``) and
    unmasked (``bg_*``) pixels. Returns ``None`` when the case has no
    non-empty mask or either region contains no pixels.
    """
    img_path = meta.loc[meta["case_id"] == cid, "img_path"].iloc[0]
    rgb = read_image(img_path).astype(np.float32) / 255.0
    union, _ = read_union_mask(cid, target_hw=rgb.shape[:2])
    if union is None or union.max() == 0:
        return None
    inside = union > 0
    regions = {"fg": rgb[inside], "bg": rgb[~inside]}
    if any(len(px) == 0 for px in regions.values()):
        return None
    # Key order: fg means, bg means, fg stds, bg stds (each r, g, b).
    stats = {"case_id": cid}
    for stat_name, reducer in (("mean", np.mean), ("std", np.std)):
        for region, px in regions.items():
            for idx, ch in enumerate("rgb"):
                stats[f"{region}_{stat_name}_{ch}"] = reducer(px[:, idx])
    return stats

# Mask-vs-background colour statistics on a second random sample of masked cases.
color_rows = []
color_ids = random.sample(non_test_ids, k=min(len(non_test_ids), SAMPLE_MAX_SLOW))
for cid in color_ids:
    d = mask_color_stats(cid)
    # Cases without a usable mask return None and are skipped.
    if d is not None:
        color_rows.append(d)
color_df = pd.DataFrame(color_rows)
print("[COLOR] samples:", len(color_df))

if len(color_df):
    # Per-channel difference of means (positive = masked region brighter than background).
    for ch in ["r", "g", "b"]:
        diff = color_df[f"fg_mean_{ch}"] - color_df[f"bg_mean_{ch}"]
        hist1(
            diff,
            f"Mean difference (mask - background) [{ch.upper()}]",
            bins=40,
            xlabel="mean_fg - mean_bg",
        )
else:
    print("[INFO] color diff skipped (no data)")

# ----------------------------
# 10) Resolution/size relations + correlation
# ----------------------------
# Image geometry: aspect-ratio distribution and how file size scales with resolution.
hist1(meta["aspect"], "Image aspect ratio (w/h)", bins=40, xlabel="aspect")
scatter_xy(
    meta["width"],
    meta["filesize_mb"],
    "Resolution vs File size",
    "width (px)",
    "file size (MiB)",
)
scatter_xy(
    meta["width"] * meta["height"],
    meta["filesize_mb"],
    "Pixels vs File size",
    "#pixels",
    "file size (MiB)",
)

# Pairwise correlations between geometry, file size and mask statistics.
corr_cols = ["width", "height", "filesize_mb", "coverage", "n_comp", "aspect"]
corr_heatmap(meta[corr_cols], title="Correlation (selected numeric)")

# ----------------------------
# 11) Texture: local entropy & autocorr side-peak
# ----------------------------
def local_entropy(gray, win=9):
    """Mean local Shannon entropy (bits) of a grayscale image.

    Steps:
    1. Cast the image to ``uint8`` and downscale it by 2 (OpenCV area
       interpolation when available, otherwise Pillow's default filter).
    2. For every square window of side ``2 * (win // 2) + 1`` that lies fully
       inside the downscaled image, compute the entropy of its gray-level
       histogram.
    3. Return the mean over those windows.

    Windows that would overhang the border are excluded. Previously they stayed
    at 0 and were included in the mean, which biased it downwards. Returns NaN
    when the downscaled image is smaller than one window.
    """
    k = 2 * (win // 2) + 1  # effective (odd) window side, as in the original patch slicing
    g = gray.astype(np.uint8)
    new_size = (g.shape[1] // 2, g.shape[0] // 2)  # (width, height)
    if new_size[0] < k or new_size[1] < k:
        return np.nan
    if cv2 is not None:
        small = cv2.resize(g, new_size, interpolation=cv2.INTER_AREA)
    else:
        small = np.array(Image.fromarray(g).resize(new_size))

    H, W = small.shape
    ent = np.zeros((H - k + 1, W - k + 1), dtype=np.float64)
    n = float(k * k)
    # For each gray level present, count its occurrences in every window with a
    # summed-area table. This replaces a Python loop over every pixel.
    for level in np.unique(small):
        ind = (small == level).astype(np.int64)
        sat = np.zeros((H + 1, W + 1), dtype=np.int64)
        sat[1:, 1:] = ind.cumsum(axis=0).cumsum(axis=1)
        counts = sat[k:, k:] - sat[:-k, k:] - sat[k:, :-k] + sat[:-k, :-k]
        present = counts > 0
        p = counts[present] / n
        ent[present] -= p * np.log2(p)
    return float(ent.mean())

def autocorr_sidepeak(gray):
    """Ratio of the strongest off-centre autocorrelation value to the zero-lag value.

    The circular autocorrelation is computed via FFT of the power spectrum on
    the mean-subtracted image. Without removing the mean, the DC term dominates
    every lag and the ratio mostly reflects brightness rather than repeated
    structure. Lags within ``rad = 5`` pixels of zero (a square window) are
    excluded.

    Values near 1 mean the image repeats itself strongly at some lag (e.g.
    periodic texture); values near 0 mean no repetition. Returns NaN for a
    constant image, or when the exclusion window covers the whole array.
    """
    g = np.asarray(gray, dtype=np.float64)
    g = g - g.mean()
    f = np.fft.fft2(g)
    ac = np.fft.fftshift(np.fft.ifft2(np.abs(f) ** 2).real)
    cy, cx = np.array(ac.shape) // 2  # zero lag after fftshift
    center = ac[cy, cx]
    rad = 5
    ac2 = ac.copy()
    # Clamp at 0 so that small images don't get negative (wrapping) slice starts.
    ac2[max(0, cy - rad) : cy + rad + 1, max(0, cx - rad) : cx + rad + 1] = -np.inf
    peak2 = np.max(ac2)
    if not np.isfinite(peak2) or center <= 0:
        return np.nan
    return float(peak2 / center)

# Texture statistics on a random sample of images drawn from every split.
tex_rows = []
tex_ids = random.sample(meta.index.tolist(), k=min(len(meta), SAMPLE_MAX_TEXTURE))
for i in tex_ids:
    try:
        img = read_image(meta.loc[i, "img_path"])
        # Weighted RGB -> grayscale (0.299 R + 0.587 G + 0.114 B).
        gray = (
            0.299 * img[..., 0]
            + 0.587 * img[..., 1]
            + 0.114 * img[..., 2]
        ).astype(np.float32)
        # Entropy is only computed when OpenCV is available (NaN otherwise),
        # even though local_entropy itself has a Pillow fallback.
        ent = local_entropy(gray, win=9) if cv2 is not None else np.nan
        acp = autocorr_sidepeak(gray)
        tex_rows.append(
            {
                "case_id": meta.loc[i, "case_id"],
                "entropy_mean": ent,
                "autocorr_side_peak": acp,
            }
        )
    except Exception:
        # Best-effort sampling: any image that fails to load or process is skipped silently.
        continue

tex_df = pd.DataFrame(tex_rows)
print("[TEXTURE] samples:", len(tex_df))

if len(tex_df):
    hist1(tex_df["entropy_mean"], "Local entropy (mean)", bins=30, xlabel="entropy")
    hist1(
        tex_df["autocorr_side_peak"],
        "Autocorr side-peak ratio",
        bins=30,
        xlabel="peak2 / center",
    )
else:
    print("[INFO] texture plots skipped")

# ----------------------------
# 12) Summary prints
# ----------------------------
def pct(x):
    """Format a fraction as a percentage string with two decimals (0.5 -> '50.00%')."""
    return "{:.2f}%".format(float(x) * 100.0)

# Split sizes and image-level forged/authentic counts over non-test rows.
train_n = (meta["split"] == "train").sum()
sup_n = (meta["split"] == "supplemental").sum()
test_n = (meta["split"] == "test").sum()
non_test = meta["split"] != "test"
# "Forged" here means that at least one mask file was found for the image.
forged_n = int((meta.loc[non_test, "mask_count"] > 0).sum())
auth_n = int(non_test.sum() - forged_n)

# Leading blank line: a real newline, not the two characters backslash + n.
print("\n================ SUMMARY ================")
print(f"Train: {train_n:,} | Supplemental: {sup_n:,} | Test: {test_n:,}")
print(f"Forged images (mask>0): {forged_n:,} | Authentic images (mask=0): {auth_n:,}")
print(
    f"Median coverage (non-test): "
    f"{meta.loc[non_test, 'coverage'].median():.5f}"
)
print(
    f"Mean #components (non-test, with mask): "
    f"{meta.loc[(non_test) & (meta['mask_count'] > 0), 'n_comp'].mean():.3f}"
)
if len(comp_df):
    print(f"Component area (median pct): {np.median(comp_df['area_pct']):.5f}")
    print(f"Touching border ratio: {comp_df['touches_border'].mean():.3f}")
print("=========================================")

# [Figure: image.png]
#
# [Figure: image.png]

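# 1. Component area ratio histogram (Component area ratio 히스토그램)
#
#    1. What does the plot show? (무엇을 그린 그래프인가?)
#    2. Key terms (핵심 용어)
#    3. Observations (관찰 결과)
#    4. Interpretation (해석)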