[머신러닝의 해석] 5편. 모델에 영향을 주는 교호작용 파악: ICE plot (Individual Conditional Expectation) | Python
머신러닝의 해석 저번 4편에서는 어떤 feature가 모델에서 target에 평균적으로 어떤 영향을 주는지 한눈에 파악하는 Partial Dependence Plot에 대해서 알아보았는데요. PDP를 알고계신다면 ICE plot은 이미 끝났습니다ㅎㅎ 그래서 이번 포스트에서는 이 플롯의 아이디어에 대해 아주 간단하게 리마인드만 하고, 바로 파이썬으로 실습해보려고 합니다!
ICE plot의 아이디어
먼저, 저번 포스트에서 다뤘던 Partial Dependence plot을 살짝 리마인드 해보겠습니다! PDP는, 모델에서 특정 feature가 target에 어떤 영향을 어떻게 주는지 한눈에 파악하는 플롯입니다. 예를 들어, 선형 회귀에서는 어떤 변수에 대한 회귀계수를 기울기로 하여 그려보면, 해당 변수와 target이 양의 선형 관계인지, 음의 선형 관계인지 알 수 있잖아요? 그런데, PDP에서는 선형적 관계가 아니더라도, 즉 비선형 관계도 시각적으로 확인할 수 있습니다. 어떻게 하냐구요?
자, X가 ‘선호 정당’, Y가 ‘소득’이라고 해봅시다. 그 외에도 여러가지 변수들도 있다고 해봅시다. X의 범주에는 ‘보수’,’진보’,’중립’의 3가지가 있습니다. 이 때, X의 Partial Dependence Function을 구하기 위해서는, 다음의 과정을 거치게 됩니다.
- Step1. 주어진 데이터를 가지고 모델을 학습한다.
- Step2. 나머지 값은 모두 고정한 채, X의 모든 값을 ‘보수’로 바꿔서 각 샘플에 대한 예측값($\hat y_{1i}, \ i=1,\cdots , n$)들을 구한다.
- Step3. 나머지 값은 모두 고정한 채, X의 모든 값을 ‘진보’로 바꿔서 각 샘플에 대한 예측값($\hat y_{2i}, \ i=1,\cdots , n$)들을 구한다.
- Step4. 나머지 값은 모두 고정한 채, X의 모든 값을 ‘중립’으로 바꿔서 각 샘플에 대한 예측값(\(\hat y_{3i}, \ i=1,\cdots , n\))들을 구한다.
자, 그 다음에 PDP는 아래처럼 했었죠. 즉, 예측값들을 평균낸 값을 가지고 플롯을 그리는 것입니다. 그렇기 때문에, PDP는 특정 feature의 평균 영향력을 본 것이죠.
- Step5. x축은 [‘보수’,’진보’,’중립’]으로 하고, y축은 [$mean(\hat y_{1i}), mean(\hat y_{2i}), mean(\hat y_{3i})$] 으로 하는 플롯을 그린다.
그런데 이 단계에서, 평균을 내지 않고 모든 선을 그려버리는 것이 바로 ICE plot입니다!
모든 선을 그린다는 것은, 전체 n개의 샘플에 대해 [$\hat y_{1i}, \hat y_{2i}, \hat y_{3i}$], $(i=1,\cdots , n)$를 그리는 것이죠. 이렇게 다 그리게 되면 뭘 캐치할 수 있을까요? PDP에서는 잡을 수 없던 것을 잡을 수 있게 됩니다. 바로 교호작용입니다!
PDP vs ICE
이 둘의 비교는 그림으로 보는게 빠를 것 같습니다. 데이터는, Medium의 이 포스트에서 직접 만든 데이터를 이용했습니다! 플랏을 이해할 정도로만 데이터를 소개하자면, target은 회사에서 받는 bonus이고, 변수에는 experience(경력)와 degree(학위)가 있습니다. 그리고, 이 둘은 교호작용이 존재하도록 설계되었습니다. 이제 experience에 대한 ICE plot을 그려볼 건데요. ICE plot을 그리는 함수는 직접 작성한 것이고 이따가 설명하겠습니다. 저기 가운데 노란색 선이 평균값에 해당하는 것으로 PDP와 같다고 볼 수 있겠네요! 평균 영향만 보면, 경력이 증가할수록 보너스가 높아지는 것으로 보입니다. 그런데 ICE plot(회색 선들)을 보면, 샘플들이 두 그룹으로 나눠집니다. 모든 선들이 똑같이 우상향하는 그림이 아니라, 한 그룹은 우상향하고, 한 그룹은 평행하네요. 여기서 학위에 따라 색깔을 다르게 그려보도록 하겠습니다.
학위가 있으면 1, 학위가 없으면 0인데요. 학위에 따라 근무경력과 보너스의 관계가 다르네요! 학위가 존재하면 근무 경력이 많아질수록 보너스가 증가하지만, 학위가 존재하지 않으면 근무경력과 상관없이 보너스가 일정합니다. 즉, 학위와 경력 간 교호작용이 이렇게 존재하는데, PDP만 보면 교호작용을 파악할 수 없었을 것입니다. 교호작용의 존재 여부에 따라 샘플들 간 차이가 존재할 수 있다는 걸 쉽게 이해하기 위해 이렇게 딱 만들어진 데이터입니다ㅎㅎ
ICE plot은 PDP와 같은 주의할 점들이 몇가지 있습니다. (1)샘플 수가 많거나 변수의 cardinality가 클 때 연산량이 너무 많아진다는 점, (2)변수 간 상관성이 높을 때, 비정상적인 개체를 생성할 수 있다는 점 등으로, 구체적인 설명은 저번 포스트 참고하시면 되겠습니다!
이 외에, 샘플 수가 많으면 사실 선이 너무 많아지면서 플롯이 지저분해집니다. 샘플 수가 많을 때에는, 모든 샘플에 대한 예측값을 그리기 보다는, 일부만 추출해서 그리는 것이 더 보기 편합니다.
Python으로 ICE plot 그리기
그럼 이제 Python으로 ICE plot을 그려봅시다. 아래 모든 내용에 대한 코드는 Github에도 있습니다! partial_dependence, plot_partial_dependence 모듈을 이용하면, ‘kind’라는 파라미터가 존재하는데요. kind에 “average”, “individual”, “both” 중 옵션을 지정해주면 PDP 또는 ICE 플롯을 그리거나 값을 추출할 수 있습니다.
- kind = “average” : PDP
- kind = “individual” : ICE
- kind = “both”: PDP, ICE
일단 데이터는 머신러닝의 해석 시리즈에서 계속 쓰고 있는 adult census data이구요. 데이터 전처리를 살짝 진행했는데, 전처리 과정은 이곳에 있습니다.
자, 여기서는 target을 income으로 하여 RandomForest를 이용해서 모델 학습을 시켰습니다.
from sklearn.ensemble import RandomForestClassifier
from sklearn import metrics
rf = RandomForestClassifier(n_estimators=300, random_state=0, max_depth=10).fit(X_train, y_train)
print('train set accuracy: ', metrics.accuracy_score(y_train, rf.predict(X_train)))
print('test set accuracy: ', metrics.accuracy_score(y_test, rf.predict(X_test)))
## train set accuracy: 0.8712512205328498
## test set accuracy: 0.8568019093078759
먼저 기본 plot_partial_dependence를 이용해서 변수 capital.gain에 대한 ICE plot을 그려보겠습니다.
plot_partial_dependence(model, data, feature_index, kind=”both”)
capital.gain은 X_train 데이터에서 9번째 칼럼이기 때문에, 해당 데이터에서의 변수 인덱스 [8]을 집어넣습니다.
from sklearn.inspection import plot_partial_dependence, partial_dependence
plot_partial_dependence(rf, X_train, [2], kind='both')
그 결과 이렇게 나왔습니다! 잘 보이지는 않지만, 아래쪽에 살짝 굵은 선이 바로 평균에 해당하는 선으로 PDP입니다. 기본 함수를 이용하면 조금 뭔가 아쉬워서 값들만 추출해서 직접 그려보기로 했습니다. 아래는 직접 작성한 ICE plot을 그리는 함수입니다.
def plot_ICE(model,data,col,n_sample=500):
pdp = partial_dependence(model, data, col, kind='both')
average = pdp['average'][0] #extract pdp values
x = pdp['values'][0] #feature levels
n_level = len(x)
ice = pd.DataFrame(pdp['individual'][0]) #extract ice values
ice = ice.sample(n_sample, random_state=0)
for idx in range(ice.shape[0]):
plt.plot(x,ice.iloc[idx,:n_level], c='gray', alpha=0.8, linewidth=0.6)
plt.plot(x,average, c='#f2a154', linewidth=4)
plt.title("ICE plot for %s" %col)
plt.xlabel('%s' %col)
plt.ylabel('y_hat')
아래는, 다른 변수에 따라 색깔이 바뀌도록 hue 설정을 추가한 ICE plot 함수입니다.
import random
def plot_ICE_hue(model,data,col,hue,n_sample=500):
pdp = partial_dependence(model, data, col, kind='individual')
x = pdp['values'][0]
n_level = len(x)
ice = pd.DataFrame(pdp['individual'][0])
data = data.reset_index(drop=True)
ice[hue] = data[hue]
ice = ice.sample(n_sample,random_state=0)
#hue color
colors = dict()
for level in data[hue].unique():
colors[level] = "#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)])
ice['color'] = ice[hue].apply(lambda x: colors[x])
#ICE plot
for idx in range(ice.shape[0]):
plt.plot(x,ice.iloc[idx,:n_level], c=ice.iloc[idx,:]['color'], alpha=0.8, linewidth=0.6)
#hue legend
from matplotlib.lines import Line2D
lines = [Line2D([0], [0], color=c) for c in colors.values()]
labels = [name for name in colors.keys()]
plt.legend(lines,labels)
plt.title("ICE plot for %s according to %s" %(col,hue))
plt.xlabel('%s' %col)
plt.ylabel('y_hat')
자, 이제 이 함수를 이용해서 변수들의 ICE plot을 그려보겠습니다. 변수 capital.gain에 대해 ICE plot을 그려보면 다음과 같습니다.
plot_ICE(rf, X_train,'capital.gain')
이때, marital.status(0:결혼하고 같이 살고 있음, 1:결혼 안했음, 2:결혼을 한 적 있지만 같이 살고 있지 않음.)에 따라, 색깔을 다르게 보면 다음과 같습니다. 결혼을 안했거나 결혼을 했지만 따로 살고 있는 경우에는, capital.gain이 5000 직전에 삐쭉 솟은 패턴이 보이네요. 결혼을 하고 현재도 같이 살고 있는 경우에 해당하는 패턴과 비교적 구분되는 모습입니다. PDP만 봤을 때는, 자본이득에 따라 연봉이 높을 확률이 높게 예측되지는 않는 것처럼, 관련이 크게 없는 것처럼 보였습니다. 그런데, 결혼 여부에 따라 색을 달리 하여 여러가지 샘플들의 경우를 확인해보니, 결혼을 했지만 현재 몇가지 이유로 따로 살고 있거나, 결혼하지 않은 사람들은 자본 이득이 4000 넘기고 나서 연봉이 높을 확률이 급상승하는 걸 볼 수 있었습니다. 왜 그런지는 모르겠지만 그림을 해석해보면 그러네요ㅋㅋㅋ
plot_ICE_hue(rf, X_train, 'capital.gain', 'marital.status', 200)
일단 이런 식으로 ICE plot을 그려보았습니다! 여러가지 다른 변수들도 해봤는데 사실 이 데이터에서는 엄청난 교호작용을 파악하지는 못한 것 같습니다. 나중에 다른 데이터 가지고도 한번 해봐야겠습니다😅 워낙 연산량이 많고, 모든 변수들을 시도해보기는 어려운 방법이라, 도메인 지식을 발휘해서 자신의 판단 하에서 일부 변수들, 또는 변수 중요도가 높게 나온 변수들을 그려보기 좋은 방법인 것 같습니다.
$Reference$