티스토리 뷰

Linear regression


대한민국 국민들의 연봉에 관한 데이터가 있을 때, 각 개인의 키에 따른 연봉을 본다고 생각해 보자. 이런 개인의 특징을 이용하여 연봉을 예측하고자 할때 가장 기본적으로 사용할 수 있는 모델이 Linear regression 이다. 이는 통계학에서 사용되는 것과 동일한 개념이다. 통계학에서는 모델의 유의성, 변수의 중요도 등에 초점을 맞추는 반면 머신 러닝에서는 예측 자체를 위한 알고리즘에 초점을 맞춘다.


위의 문제에서는 키와 연봉에 관한 데이터가 하나의 데이터가 되며 그 모임이 모델을 만들때 사용되는 트레이닝 셋이다. 예측에 이용하는 feature가 하나이면 univariate, 여러 개이면 multivariate 이며 각 feature의 선형 결합에 의해 모델을 구성하면 linear regression이다. 


모델 구축을 위해서 가장 간단한 방법으로 gradient descent 알고리즘을 이용한다. 모델과 실제 데이터 값의 차이를 계산하는 cost function을 minimize하는 방향으로 알고리즘이 진행되며 partial derivative를 이용해 최소값을 찾아나가는 과정과 동일하다. 


gradient descent는 feature의 수 만큼 iteration이 필요한 데, 모델을 찾는 과정은 feature가 한개라면 결국 2차 함수에서 최소 값을 찾는 것과 동일하다. 2차함수에서 바로 미분을 통해 최소값을 찾을 수 있듯이, 수학적으로 이것도 풀 수가 있는데 design matrix X를 구성하고 theta = (X^T * X )^-1 * X  *Y 를 풀면 된다. 이 방법을 normal equation이라고 하는 데, 바로 풀 수 있어 속도가 빠르지만 트레이닝 셋의 크기가 커지면 연산에 matrix의 product와 inverse가 들어가기 때문에 computational cost가 매우 크다.

따라서 n이 10000 정도 까지는 괜찮으나, 더 커지면 gradient descent를 쓰는 것이 낫다.


linear regression은 매우 단순한 모델이어 예측을 위한 기계학습에서는 크게 사용되지 않지만 기본이 제일 중요하므로 잘 알아두자. (너말이야 너)




참조

[1] Machine Learning by Andrew Ng, Coursera. http://coursera.org

댓글