선형 모델을 이용한 n차 다항식의 회귀

선형 모델은 1차 다항식인 직선에 대한 모델 만을 예측할 수 있습니다. 그렇다면 직선이 아닌 곡선, 즉 2차 다항식 이상의 모델을 예측하기 위해서는 선형 모델을 사용할 수 없다고 생각할 수 있습니다. 하지만 생각과는 다르게 선형 모델로도 2차 다항식 이상의 모델도 예측할 수 있는데, 이는 약간의 발상의 전환이 필요합니다. 즉, 선형 모델의 경우 특징 변수가 1차로 다항식으로만 구성되어 있습니다. 만약 x^2와 같은 거듭제곱인 2차식의 경우일때 더 이상 선형 모델이 아니게 되지만, x^2을 z라는 1차 다항식으로 취급하게 되면 선형 모델로도 2차 이상의 다항식도 회귀분석이 가능합니다.

다음의 코드는 3차 다항식에 대한 회귀분석에 대한 코드입니다.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

np.random.seed(3224)
m = 100
X = 10 * np.random.rand(m, 1) - 5
y = (-0.8 * X**3) + (0.5 * X**2) + (2 * X) - 3 + (np.random.randn(m, 1) * 10)

poly_features = PolynomialFeatures(degree=3, include_bias=False)
X_poly = poly_features.fit_transform(X)

model = LinearRegression()
model.fit(X_poly, y)

X_new = np.linspace(-5, 5, 20).reshape(20, 1)
X_new_poly = poly_features.transform(X_new)
y_new = model.predict(X_new_poly)

plt.plot(X, y, "b.")
plt.plot(X_new, y_new, "r-")
plt.show()

결과는 다음과 같습니다.

코드를 살펴보면, 6-9는 잡음이 섞인 샘플 데이터는 3차 다항식의 형태로 구성합니다. 11-12는 1개의 특성을 2차항과 3차항에 대한 독립적인 특성을 추가로 생성해 줍니다. 즉, 특성값이 2라면 4와 8이라는 특성값이 생성됩니다.

분석가 관점에서 데이터를 개략적으로 살펴보기

수집된 데이터를 활용하여 AI 학습하기에 앞서 가장 먼저 해야 할 것은 수집된 데이터를 개략적으로 살펴보는 일입니다. 이 글은 간단하지만 의미있는 데이터셋을 개략적으로 살펴보는 것에 대한 내용을 살펴봅니다.

간단하지만 의미있는 데이터셋은 Kaggle에서 제공하는 전복(Abalone) 데이터셋이며 다운로드 받은 파일은 CSV 형식으로 파일을 열어 그 일부를 보면 다음과 같습니다.

내용을 보면 일반적인 첫줄에 컬럼명이 아닌 바로 데이터값으로 시작하는 것과 총 9가지의 컬럼값으로 구성되어 있다는 것을 파악할 수 있습니다.

이제 이 데이터를 파이선을 통해 개략적으로 살펴보도록 하겠습니다.

pandas를 사용하여 파일을 불러오는 코드로 시작합니다.

import pandas as pd

raw_data = pd.read_csv('./datasets/datasets_1495_2672_abalone.data.csv', 
        names=['sex', 'tall', 'radius', 'height', 'weg1', 'weg2', 'weg3', 'weg4', 'ring_cnt'])

데이터에 컬럼 정보가 없으므로 names 인자를 통해 컬럼의 의미를 파악할 수 있으면서 식별자로 사용할 수 있는 이름을 지정해 줍니다. 총 9개인데, 각각의 의미는 ‘성별’, ‘키’, ‘지름’, ‘높이’, ‘전체무게’, ‘몸통무게’, ‘내장무게’, ‘껍질무게’, ‘껍질의고리수’입니다.

개략적인 내용 파익으로 이 데이터셋의 실제 내용 중 시작부분을 살펴보는 코드입니다.

print(raw_data.head())

결과는 다음과 같습니다.

다음은 전체적인 데이터의 구성을 살펴보는 코드입니다.

print(raw_data.info())

결과는 다음과 같은데, 총 4177개의 전복 데이터가 있으며 각 컬럼 데이터의 타입과 Null 값이 아닌 데이터의 개수 정보를 파악할 수 있습니다. sex 컬럼의 데이터 타입은 object인데, 이는 문자열이기 때문입니다.

앞서 sex가 문자열인데, 이는 전복의 성별값입니다. I는 유충, M은 수컷, F는 암컷인데, 이 sex에 대한 정보를 좀더 살펴보기 위한 코드입니다.

print(raw_data['sex'].value_counts())

결과는 다음과 같은데, 수컷(M)이 1528개, I가 1342개, F가 1307라는 것을 알 수 있습니다.

다음은 데이터에 대한 간단한 통계를 확인하기 위한 코드입니다.

print(raw_data.describe())

결과는 다음과 같습니다.

각 컬럼에 대한 데이터수, 평균, 편차, 최대값, 최소값, 25%/50%/75%에 대한 백분위수(Percentile)가 제공됩니다.

끝으로 각 컬럼에 대한 히스토그램을 살펴보는 코드입니다.

import matplotlib.pyplot as plt

raw_data.hist(bins=10)
plt.show()

결과는 다음과 같습니다.