https://towardsdatascience.com/the-only-guide-you-need-to-understand-regression-trees-4964992a07a8
"Classification And Regression Trees" 약자
이진 트리(binary tree) 생성
수치형 타겟 변수 지원 (regression)
범주 분류(classification)에는 Gini impurity를, 수치 회귀(regression)에는 Regression tree를 사용하여 노드 분할
Gini Impurity를 사용한 분류 처리
Classification에는 Gini impurity가 사용된다.
Gini impurity는 데이터셋에 다른 데이터가 섞여 있는 정도이다.
클래스가 두 개인 경우 0~0.5 값을 가지며, 모든 데이터가 같은 클래스이면 0, 두 클래스의 데이터 수가 같으면 0.5의 값을 가진다.
2023.08.18 - [AI,머신러닝,데이터/용어] - 지니 불순도 (Gini Impurity)
Gini impurity를 구하는 식은 다음과 같다.
$ Gini(D) = 1 - \sum\limits_{i=1}^{k}p_i^2 $
모든 피쳐의 범주(또는, 수치 분할지점들)에 대해 Gini impurity를 계산해서, 이 값이 가장 작은 피쳐를 사용해 노드를 분할한다.
만약 피쳐 A를 기준으로 데이터셋 $D$를 $D_1, D_2$ 로 나눴고, 각각의 크기가 $n_1, n_2$일 때, Gini impurity는 다음과 같이 데이터셋 크기에 비례해서 합산한다.
$ Gini_A(D) = \cfrac{n_1}{n}Gini(D_1) + \cfrac{n_2}{n}Gini(D_2) $
분할지점(Breakpoint) 결정
수치형 피쳐는 특정 분할지점(breakpoint)을 기준으로 초과/이하의 그룹으로 나누며, Breakpoints 후보 선택은 데이터 정렬 후 class가 바뀌는 지점을 모두 후보로 한다. (ID3, C4.5와 동일)
위 데이터셋에서 $x$가 [2.5, 3.5, 6.5]일 때의 Gini impurity는 각각 [0.125, 0.457, 0.069]이다.
Gini impurity가 0.069로 최소인 6.5를 기준으로 노드를 분할해야 한다.
Entropy와 Gini Impurity 비교
ID3, C4.5 등의 알고리즘에서 사용하는 entropy 값과 Gini impurity 비교
- 데이터셋 클래스 비율($p$)에 따른 수치 비교
p = np.arange(0.01, 1.0, 0.01)
plt.xlabel("p")
plt.ylabel("surprise factor")
plt.plot(p, -p * np.log2(p) - (1 - p) * np.log2(1 - p), label="Entropy");
plt.plot(p, -2*p*(p - 1), label="Gini impurity")
plt.legend();
Entropy와 스케일만 다를 뿐 동일한 형태를 갖는다.
Regression tree를 사용한 회귀 처리
Regression에서도 classification과 거의 비슷한 절차를 따라 트리를 생성한다.
단, 최적의 피쳐 및 분할지점을 결정하기 위해 Gini impurity 대신 mean squared error 등을 사용한다.
# Import the necessary modules and libraries
import matplotlib.pyplot as plt
import numpy as np
from sklearn.tree import DecisionTreeRegressor
# Create a random dataset
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))
# Fit regression model
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(X, y)
regr_2.fit(X, y)
# Predict
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)
# Plot the results
plt.figure()
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()
위 예제는 max_depth 2와 5로 학습한 트리에 0~5 사이의 값을 넣었을 때의 예측치를 파란색, 연두색 선으로 보여주고 있는데, leaf 노드의 평균값을 출력하는 것을 볼 수 있다.
트리 생성 예제
Scikit learn에는 DecisionTreeClassifier 및 DecisionTreeRegressor가 CART를 기반으로 구현되어 있다.
- DecisionTreeClassifier: Gini impurity를 사용하여 트리 생성 (생성자에서 critetion="entropy"
를 지정하면 Shannon 정보 이득을 사용할 수도 있음)
criterion : {"gini", "entropy", "log_loss"}, default="gini"
The function to measure the quality of a split. Supported criteria are
"gini" for the Gini impurity and "log_loss" and "entropy" both for the
Shannon information gain, see :ref:`tree_mathematical_formulation`.
- DecisionTreeRegressor: Mean squared error를 사용하여 트리 생성 (criterion
에 friedman_mse, absolute_error, poisson
지정 가능)
criterion : {"squared_error", "friedman_mse", "absolute_error", \
"poisson"}, default="squared_error"
The function to measure the quality of a split. Supported criteria
are "squared_error" for the mean squared error, which is equal to
variance reduction as feature selection criterion and minimizes the L2
loss using the mean of each terminal node, "friedman_mse", which uses
mean squared error with Friedman's improvement score for potential
splits, "absolute_error" for the mean absolute error, which minimizes
the L1 loss using the median of each terminal node, and "poisson" which
uses reduction in Poisson deviance to find splits.
Iris 트리
from sklearn.datasets import load_iris
from sklearn import tree
import matplotlib.pyplot as plt
iris = load_iris()
iris_tree = tree.DecisionTreeClassifier()
iris_tree.fit(iris.data, iris.target)
plt.figure(figsize=(15, 10))
tree.plot_tree(iris_tree,
feature_names=iris.feature_names,
class_names=list(iris.target_names),
filled=True, rounded=True)
plt.show()
iris 데이터로부터 트리를 학습하면 위와 같은 모습이 된다.
- 최초 데이터셋에 3개의 클래스에 50개씩 동수의 데이터 존재 (gini = 0.667)
- 루트 노드는 $petal width <= 0.8$ 을 기준으로 분할 (최소 gini impurity가 되는 속성 및 분할지점)
- gini = 0 이 될 때까지 트리 노드 분할
'AI,머신러닝' 카테고리의 다른 글
Gradient Boosting (XGBoost, LightGBM, CatBoost 비교) (0) | 2023.09.19 |
---|---|
SHAP (ML 모델 피쳐 중요도 측정) (0) | 2023.09.04 |
의사결정 트리(Decision Tree) - 피쳐 중요도(Feature Importance) 측정 (0) | 2023.09.01 |
의사결정 트리(Decision Tree) - C4.5 알고리즘 (0) | 2023.08.24 |
의사결정 트리 (Decision Tree) (0) | 2023.08.24 |