드레:
코딩 뿌시기
드레:
전체 방문자
오늘
어제
  • 분류 전체보기 (268)
    • Python (74)
      • Python 기초 (42)
      • Numpy (8)
      • Pandas (22)
    • Machine Learning (31)
      • Machine Learning (1)
      • Deep Learning (27)
    • AWS (22)
      • RDS (3)
      • EC2 (9)
      • Lambda (8)
      • S3 (2)
    • MySQL (24)
    • Git (8)
    • Streamlit (12)
    • REST API (22)
    • Java (24)
    • Android (36)
    • Debugging (15)

블로그 메뉴

  • 홈
  • 태그
  • 방명록

공지사항

인기 글

태그

  • AWS
  • 딥러닝
  • GET
  • Transfer Learning
  • JWT
  • Lambda
  • 안드로이드 스튜디오
  • pandas
  • aws s3
  • volley
  • flask
  • 깃이그노어
  • tensorflow
  • Java
  • Callback
  • fine tuning
  • Streamlit
  • Retrofit2
  • API
  • CNN
  • AWS Lambda
  • 네이버 API
  • rest api
  • EC2
  • github
  • 서버리스
  • Python
  • 액션바
  • serverless
  • Ann

최근 댓글

최근 글

티스토리

hELLO · Designed By 정상우.
드레:

코딩 뿌시기

Tensorflow - model을 파일로 저장하고 불러오는 방법
Machine Learning/Deep Learning

Tensorflow - model을 파일로 저장하고 불러오는 방법

2022. 12. 30. 18:04

1. 전체 모델구조(Network)와  가중치(Weight) 함께 저장하기

 

  • 모델 + 가중치를 저장해서 불러오면 모델에 complie을 할 필요없이 바로 사용 가능하다.

 

폴더구조로 저장, 불러오기

# 폴더구조로 저장
model.save('fashion_mnist_model')

# 불러오기
model2 = tf.keras.models.load_model('fashion_mnist_model')

 

 

파일구조로 저장, 불러오기

# 모델을 h5파일 하나로 저장
model.save('fashion_mnist_model.h5')

# 불러오기
model3 = tf.keras.models.load_model('fashion_mnist_model.h5')

 

 

 

2. 모델 구조(Network)만 저장하고 불러오기

 

# 네트워크를 json 파일로 저장하는 코드
fashin_mnist_network = model.to_json()
with open('fashin_mnist_network.json', 'w') as file:
  file.write(fashin_mnist_network)
  
# 저장된 네트워크를 읽어오는 코드
with open('fashin_mnist_network.json', 'r') as file:
  fashion_net = file.read()
  
# 위의 네트워크로부터 모델을 만들고 싶으면
model4 = tf.keras.models.model_from_json(fashion_net)

주의!!!

model4는 네트워크만 가져온 것이지, 학습 완료된 weight는 가져온 것이 아니다.

따라서 현재 weight는 랜덤으로 세팅된 weight다.

이 모델로 바로 예측 수행을 하면 안되고 가중치를 불러와야 한다.

 

 

 

3.  가중치(weight)만 저장하고 불러오기

 

# 웨이트를 h5파일로 저장
model.save_weights('fashion_mnist_weight.h5')

# 위의 model4에 웨이트를 load
model4.load_weights('fashion_mnist_weight.h5')

네트워크와 가중치를 따로 불러온 경우, compile을 해줘야 사용 가능하다!

(model4.predict는 compile을 안 해도 사용 가능하다.)

 

# 컴파일 예시
model4.compile('adam', 'sparse_categorical_crossentropy', ['accuracy'])

 

'Machine Learning > Deep Learning' 카테고리의 다른 글

Convolution Neural Network(CNN, 합성곱 신경망) 개념  (0) 2022.12.31
Tensorflow - 레이블인코딩 된 y값을 원핫인코딩으로 바꾸기  (0) 2022.12.31
Tensorflow - 이미지를 1차원으로 만드는 방법(Flatten)  (0) 2022.12.30
Tensorflow - Callback 클래스를 이용해서, 원하는 조건이 되면 학습을 멈추기  (0) 2022.12.29
Tensorflow - Dropout 사용법  (0) 2022.12.29
    'Machine Learning/Deep Learning' 카테고리의 다른 글
    • Convolution Neural Network(CNN, 합성곱 신경망) 개념
    • Tensorflow - 레이블인코딩 된 y값을 원핫인코딩으로 바꾸기
    • Tensorflow - 이미지를 1차원으로 만드는 방법(Flatten)
    • Tensorflow - Callback 클래스를 이용해서, 원하는 조건이 되면 학습을 멈추기
    드레:
    드레:

    티스토리툴바