1. Introduction

- concept bottleneck models
- given a data point $(x,c,y)$, predict the concepts $c$ from $x$, then use them to predict $y$
- construction: resize one layer to match the number of concepts and add an intermediate loss
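
The two-stage prediction described above can be sketched as follows. This is a minimal illustration, not the paper's architecture: the linear maps, the toy sizes (3 input features, 2 concepts), and the names `concept_predictor` and `label_predictor` are all assumptions made for the example.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[Vector]


def linear(weights: Matrix, bias: Vector, x: Vector) -> Vector:
    """Return weights @ x + bias using plain Python lists."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]


# Hypothetical parameters: the bottleneck layer has exactly one unit per concept.
CONCEPT_W: Matrix = [[1.0, 0.0, -1.0],
                     [0.5, 0.5, 0.0]]
CONCEPT_B: Vector = [0.0, 0.0]
LABEL_W: Matrix = [[2.0, -1.0]]
LABEL_B: Vector = [0.5]


def concept_predictor(x: Vector) -> Vector:
    """Map an input x to predicted concepts (the bottleneck)."""
    return linear(CONCEPT_W, CONCEPT_B, x)


def label_predictor(c: Vector) -> float:
    """Map concepts to the final prediction; it never sees x directly."""
    return linear(LABEL_W, LABEL_B, c)[0]


def predict(x: Vector) -> Tuple[Vector, float]:
    """Predict the concepts from x, then predict y from those concepts only."""
    c_hat = concept_predictor(x)
    y_hat = label_predictor(c_hat)
    return c_hat, y_hat


c_hat, y_hat = predict([1.0, 2.0, 3.0])
print(c_hat, y_hat)  # [-2.0, 1.5] -5.0
```

The point of the sketch is that the final prediction depends on the input only through the predicted concepts.
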
2. Related work
3. Setup
- predict a target $y \in \mathbb{R}$ from an input $x \in \mathbb{R}^d$
- training points $\{(x^{(i)}, y^{(i)}, c^{(i)})\}^n_{i=1}$, $c \in \mathbb{R}^k$
- $\hat{y} = f(g(x)) = f(\hat{c})$ ($\hat{c} = g(x)$) : concept bottleneck models
- $g : \mathbb{R}^d \rightarrow \mathbb{R}^k$ : map an input $x$ into the concept space ("bone spurs")
- $f : \mathbb{R}^k \rightarrow \mathbb{R}$ : map concepts into a final prediction ("arthritis severity")
- task accuracy : how accurately $f(g(x))$ predicts $y$
- concept accuracy : how accurately $g(x)$ predicts $c$ (averaged over each concept)
- learning concept bottleneck models
- $L_{C_j} : \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}_+$ : a loss function that measures the discrepancy between the predicted and true $j$-th concept
- $L_Y : \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}_+$ : the discrepancy between predicted and true targets
- independent bottleneck ($f$ and $g$ are trained separately)
- $\hat{f} = \text{argmin}_f \sum_i L_Y(f(c^{(i)}); y^{(i)})$
- $\hat{g} = \text{argmin}_g \sum_{i,j} L_{C_j} (g_j(x^{(i)}); c^{(i)}_j)$
- sequential bottleneck ($g$ and $f$ are trained one after the other)
- first, learn $\hat{g}$
- then learn $\hat{f} = \text{argmin}_f \sum_i L_Y (f(\hat{g}(x^{(i)})); y^{(i)})$
- joint bottleneck
- $\hat{f}, \hat{g} = \text{argmin}_{f,g} \sum_i [L_Y(f(g(x^{(i)})); y^{(i)}) + \sum_j \lambda L_{C_j}(g_j(x^{(i)}); c^{(i)}_j)]$ for some $\lambda > 0$
- $\lambda$ : the tradeoff between concept vs. task loss
- standard model
- $\hat{f}, \hat{g} = \text{argmin}_{f,g} \sum_i L_Y(f(g(x^{(i)})); y^{(i)})$
- one layer is resized to match the number of concepts $k$
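
To make the differences between the training schemes concrete, the sketch below evaluates each objective on a toy dataset. It only computes the quantities being minimized; it does not perform the minimization. The squared-error losses, the hand-picked `g` and `f`, the two-example dataset, and all function names are assumptions made for illustration.

```python
from typing import Callable, List, Tuple

Vector = List[float]
Example = Tuple[Vector, Vector, float]  # (x, c, y)
ConceptFn = Callable[[Vector], Vector]  # g : R^d -> R^k
TargetFn = Callable[[Vector], float]    # f : R^k -> R


def squared_error(pred: float, true: float) -> float:
    """Assumed choice for both L_Y and every L_{C_j}."""
    return (pred - true) ** 2


def concept_loss(g: ConceptFn, data: List[Example]) -> float:
    """sum_{i,j} L_{C_j}(g_j(x_i); c_ij): minimized over g (independent, sequential)."""
    return sum(squared_error(p, t) for x, c, _ in data for p, t in zip(g(x), c))


def task_loss_on_true_concepts(f: TargetFn, data: List[Example]) -> float:
    """sum_i L_Y(f(c_i); y_i): minimized over f in the independent bottleneck."""
    return sum(squared_error(f(c), y) for _, c, y in data)


def task_loss_on_predicted_concepts(f: TargetFn, g: ConceptFn,
                                    data: List[Example]) -> float:
    """sum_i L_Y(f(g(x_i)); y_i): minimized over f with g fixed (sequential),
    or over both f and g (standard model)."""
    return sum(squared_error(f(g(x)), y) for x, _, y in data)


def joint_loss(f: TargetFn, g: ConceptFn, data: List[Example], lam: float) -> float:
    """Task loss plus lam times the concept loss: minimized over f and g jointly."""
    return task_loss_on_predicted_concepts(f, g, data) + lam * concept_loss(g, data)


def g(x: Vector) -> Vector:
    """Hypothetical concept predictor with k = 2."""
    return [x[0] + x[1], x[0] - x[1]]


def f(c: Vector) -> float:
    """Hypothetical target predictor."""
    return 2.0 * c[0] - c[1]


data: List[Example] = [
    ([1.0, 2.0], [3.0, 0.0], 5.0),
    ([0.0, 1.0], [1.0, -1.0], 3.0),
]

print(concept_loss(g, data))                        # 1.0
print(task_loss_on_true_concepts(f, data))          # 1.0
print(task_loss_on_predicted_concepts(f, g, data))  # 4.0
print(joint_loss(f, g, data, lam=0.5))              # 4.5
```

In this toy case `f` fits the true concepts better than it fits the predicted ones (1.0 vs. 4.0), which is the gap the sequential and joint schemes address by training `f` on the outputs of `g`.
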
Classification
- independent bottleneck
- turn the real-valued scores (e.g., concept logits $\hat{l} = \hat{g}(x) \in \mathbb{R}^k$) into probabilistic predictions (e.g., $P(\hat{c}_j = 1) = \sigma(\hat{l}_j)$ for logistic regression)
- sequential and joint bottleneck
- connect $f$ to the logits $\hat{l}$
- compute both $P(\hat{c}_j = 1) = \sigma(\hat{g}_j(x))$ and $P(\hat{y} = 1) = \sigma(\hat{f}(\hat{g}(x)))$
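
The logit-to-probability step for the sequential and joint bottlenecks can be sketched as below. The concept logits, the linear `f_hat`, and its weights are hypothetical values chosen for the example; only the use of the sigmoid $\sigma$ on the logits comes from the notes above.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    """Logistic function sigma(z) = 1 / (1 + exp(-z))."""
    return 1.0 / (1.0 + math.exp(-z))


def concept_probabilities(logits: List[float]) -> List[float]:
    """P(c_j = 1) = sigma(l_j), applied to each concept logit."""
    return [sigmoid(l) for l in logits]


def f_hat(logits: List[float]) -> float:
    """Hypothetical target head connected to the concept logits (not to probabilities)."""
    weights = [1.0, 0.5, -0.5]
    return sum(w * l for w, l in zip(weights, logits))


# Assumed concept logits l_hat = g_hat(x) for k = 3 concepts.
concept_logits = [0.0, 2.0, -2.0]

concept_probs = concept_probabilities(concept_logits)  # P(c_j = 1) for each j
target_prob = sigmoid(f_hat(concept_logits))           # P(y = 1)

print(concept_probs[0])  # 0.5
```

Here `f_hat` receives the raw logits, while the sigmoid is applied separately to obtain the concept probabilities and the target probability.
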
4. Benchmarking bottleneck model accuracy
4.1. Applications
- x-ray grading and bird identification tasks
X-ray grading (OAI)
- knee x-rays from the Osteoarthritis Initiative (OAI)
- predict the Kellgren-Lawrence grade (KLG)
- 4-level ordinal variable assessed by radiologists that measures the severity of osteoarthritis
- $k=10$ (joint space narrowing, bone spurs, calcification, etc.)
Bird identification (CUB)
- Caltech-UCSD Birds-200-2011 (CUB) dataset