비버놀로지

[회귀(Regression)] 당뇨 진행상황을 선형 회귀로 예측 해보기 본문

인공지능 머신러닝

[회귀(Regression)] 당뇨 진행상황을 선형 회귀로 예측 해보기

KUNDUZ 2022. 11. 6. 15:38
728x90

당뇨 진행상황을 선형 회귀로 예측 해보기

선형 회귀는 여러 입력 변수(X)를 바탕으로 Y를 예측하는 여러 과제에 적용이 가능합니다.

이번 시간에는 sklearn에서 제공하는 데이터셋 중 하나인 당뇨(diabetes) 데이터셋에 지금까지 학습한 다양한 회귀 알고리즘을 적용하여, 가장 높은 성능을 도출하는 알고리즘을 확인해보겠습니다.

데이터 준비를 위한 사이킷런 함수/라이브러리

  • from sklearn.datasets import load_diabetes : 사이킷런의 당뇨(diabetes) 데이터를 불러옵니다.
  • load_diabetes(return_X_y = True) : (X, y) 형태의 당뇨(diabetes) 데이터를 반환합니다.
 

실습

  1. 당뇨병 데이터셋을 불러오고, 학습을 위한 준비를 진행하는 load_data() 함수를 구현합니다.
    • 불러온 데이터 셋의 변수 개수, 데이터의 형태 등을 살펴봅니다.
  2. 원하는 회귀 모델을 불러오고, 테스트 데이터에 대한 예측 결과를 반환하는 reg_model()함수를 구현합니다.
  3. 구현한 회귀 모델의 설명력을 표현하는 R2 값을 반환하는 r_square() 함수를 구현합니다.
  4. 구현한 함수들을 활용하여 당뇨병 데이터에 대한 회귀를 진행하는 main() 함수를 완성합니다.
  • 다양한 회귀 알고리즘을 적용해보고, 가장 높은 성능을 도출하는 회귀 알고리즘을 선택하여 R2값을 0.5 이상으로 높여 제출합니다.

Tips!

미션을 진행하며 필요한 모듈을 직접 import 해봅니다.

 

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

from sklearn.metrics import r2_score
"""
    1. 학습을 위한 준비가 완료된 데이터를 반환하는
       load_data() 함수를 구현합니다.
       
       Step01. 당뇨병 관련 데이터셋을 (X, y)의 형태로 불러옵니다. 
              
       Step02. 모델 학습을 위해 데이터를 
               학습용(80%)/테스트용(20%)로 분리합니다.
               (random_state = 100)
"""

def load_data():
    X, y = load_diabetes(return_X_y = True)
    
    print(X)
    print(y)
    
    train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.2, random_state = 100)
    
    return train_X, test_X, train_y, test_y
    
    
"""
    2. 회귀 모델을 불러오고,
       테스트 데이터에 대한 예측 결과를 반환하는
       reg_model() 함수를 구현합니다.
"""
def reg_model(train_X, test_X, train_y):
    
        
    lr = LinearRegression()
    
    lr.fit(train_X,train_y)

    return lr.predict(test_X)
    

"""
    3. 구현한 회귀 모델의 r_square 값을
       반환하는 r_square() 함수를 구현합니다.

"""
def r_square(pred, test_y):
    
    
    return r2_score(test_y, pred)
    
    
"""
    4. 구현한 함수들을 활용하여 
       당뇨병 데이터에 대한 회귀를 진행하는 
       main() 함수를 구현합니다.
"""

def main():
    
    train_X, test_X, train_y, test_y = load_data()
    
    predicted = reg_model(train_X, test_X, train_y)
    
    r2 = r_square(predicted,test_y) 

    print("r2 score : ",r2)
        
    return r2
    
if __name__ == "__main__":
    main()

 

 

 

728x90
Comments