드레:
코딩 뿌시기
드레:
전체 방문자
오늘
어제
  • 분류 전체보기 (268)
    • Python (74)
      • Python 기초 (42)
      • Numpy (8)
      • Pandas (22)
    • Machine Learning (31)
      • Machine Learning (1)
      • Deep Learning (27)
    • AWS (22)
      • RDS (3)
      • EC2 (9)
      • Lambda (8)
      • S3 (2)
    • MySQL (24)
    • Git (8)
    • Streamlit (12)
    • REST API (22)
    • Java (24)
    • Android (36)
    • Debugging (15)

블로그 메뉴

  • 홈
  • 태그
  • 방명록

공지사항

인기 글

태그

  • tensorflow
  • API
  • Retrofit2
  • fine tuning
  • rest api
  • 딥러닝
  • serverless
  • Python
  • 깃이그노어
  • EC2
  • CNN
  • volley
  • Callback
  • github
  • 네이버 API
  • 안드로이드 스튜디오
  • GET
  • Streamlit
  • Transfer Learning
  • Ann
  • AWS
  • aws s3
  • flask
  • JWT
  • AWS Lambda
  • pandas
  • Lambda
  • 서버리스
  • 액션바
  • Java

최근 댓글

최근 글

티스토리

hELLO · Designed By 정상우.
드레:

코딩 뿌시기

Tensorflow - 이진 분류 문제의 인공신경망(ANN)
Machine Learning/Deep Learning

Tensorflow - 이진 분류 문제의 인공신경망(ANN)

2022. 12. 28. 18:48

Churn_Modelling.csv
0.64MB

위 데이터를 사용해 금융상품 갱신 여부를 예측하는 ANN을 만들어보자.

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
df = pd.read_csv('Churn_Modelling.csv')

 

 

1. nan이 있는지 확인

df.isna().sum()

 

 

2. 학습데이터 X와 정답데이터 y로 분리

X = df.loc[: , 'CreditScore':'EstimatedSalary']
y = df['Exited']

 

 

3. 범주형 문자열 데이터를 숫자로 바꿔준다.

X['Gender'].unique()
X['Geography'].unique()
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.compose import ColumnTransformer

# 성별은 카테고리가 2개이므로 레이블 인코딩
label_encoder = LabelEncoder()
X['Gender'] = label_encoder.fit_transform(X['Gender'])

# 국가는 카테고리가 3개이므로 원핫 인코딩
# 주의: ColumnTransformer를 하면 해당 컬럼들이 가장 앞에 위치한다.
ct = ColumnTransformer( [ ('encoder', OneHotEncoder(), [1]) ],
                       remainder='passthrough')
X = ct.fit_transform(X.values)

 

 

4. dummy variable trap을 피하기 위해 첫 번째 컬럼을 제거

참고: https://donghyeok90.tistory.com/148

X_df = pd.DataFrame(X)
X_df.drop(0, axis=1, inplace=True)
X = X_df.values

 

 

5. feature scaling

from sklearn.preprocessing import MinMaxScaler

scaler_X = MinMaxScaler()
X = scaler_X.fit_transform(X)

 

 

6. train, test 데이터 분리

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

 

 

7. 인공신경망 모델 연결하기

import tensorflow as tf
from tensorflow import keras
from keras.models import Sequential
from keras.layers import Dense

# 비어있는 틀을 만든다.
model = Sequential() 
# layer 추가(units: 노드갯수, activation: 활성화함수)
model.add( Dense(units=6, activation='relu', input_shape=(11,)) ) 
model.add( Dense(units=8, activation=tf.nn.relu) )
# 두 가지로 분류하는 모델의 출력층 활성화함수는 sigmoid를 사용한다.
model.add( Dense(units=1, activation='sigmoid') )
  • 첫 번째 층에는 input_shape을 지정해줘야 한다. ( X의 컬럼 갯수)
  • 두 가지로 분류하는 모델의 출력층 활성화함수는 sigmoid를 사용한다.

 

  • 모델링이 끝나면 컴파일(Compile)을 해야 한다.
    - 컴파일이란, 옵티마이저(optimizer)와 로스펑션(loss function 오차함수, 손실함수), 검증방법을 셋팅
  • 두 가지로 분류하는 모델의 loss function은 binary_crossentropy로 설정한다.
model.compile( optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'] )

 

 

 

 

8. 모델 학습

 

model.fit(X_train, y_train, batch_size= 10, epochs= 20)

 

9. 모델 평가

 

# 학습이 끝나면, 평가를 한다.
model.evaluate(X_test, y_test)

 

# 컨퓨전 매트릭스, 정확도 확인
from sklearn.metrics import confusion_matrix, accuracy_score

y_pred = model.predict(X_test)
# 확률로 된 예측값을 0 or 1로 바꿔준다.
y_pred = (y_pred > 0.5).astype(int)
cm = confusion_matrix(y_test, y_pred)
accuracy_score(y_test, y_pred)

 

 

※ 모델이 학습한 가중치 확인하는 법

model.layers[인덱스].weights

 

'Machine Learning > Deep Learning' 카테고리의 다른 글

Tensorflow - Grid Search 를 이용한, 최적의 하이퍼 파라미터 찾기  (0) 2022.12.28
Tensorflow에서 학습시 batch size, step, epoch 란?  (0) 2022.12.28
원핫인코딩 할 시 주의할 점 - Dummy Variable Trap  (0) 2022.12.28
Google Colab - 드라이브 연결하기(마운트)  (0) 2022.12.28
Back propagation(오차 역전파)  (0) 2022.12.27
    'Machine Learning/Deep Learning' 카테고리의 다른 글
    • Tensorflow - Grid Search 를 이용한, 최적의 하이퍼 파라미터 찾기
    • Tensorflow에서 학습시 batch size, step, epoch 란?
    • 원핫인코딩 할 시 주의할 점 - Dummy Variable Trap
    • Google Colab - 드라이브 연결하기(마운트)
    드레:
    드레:

    티스토리툴바