教育行業(yè)A股IPO第一股(股票代碼 003032)

全國(guó)咨詢/投訴熱線:400-618-4000

tf.keras是什么?tf.keras怎樣實(shí)現(xiàn)深度學(xué)習(xí)?

更新時(shí)間:2022年02月12日11時(shí)53分 來(lái)源:傳智教育 瀏覽次數(shù):

tf.keras是TensorFlow 2.0的高階API接口,為TensorFlow的代碼提供了新的風(fēng)格和設(shè)計(jì)模式,大大提升了TF代碼的簡(jiǎn)潔性和復(fù)用性,官方也推薦使用tf.keras來(lái)進(jìn)行模型設(shè)計(jì)和開發(fā)。

tf.keras是什么

常用模塊

tf.keras中常用模塊如下表所示:


模塊 概述
activations 激活函數(shù)
applications 預(yù)訓(xùn)練網(wǎng)絡(luò)模塊
Callbacks 在模型訓(xùn)練期間被調(diào)用
datasets tf.keras數(shù)據(jù)集模塊,包括boston_housing,cifar10,fashion_mnist,imdb ,mnist
layers Keras層API
losses 各種損失函數(shù)
metircs 各種評(píng)價(jià)指標(biāo)
models 模型創(chuàng)建模塊,以及與模型相關(guān)的API
optimizers 優(yōu)化方法
preprocessing Keras數(shù)據(jù)的預(yù)處理模塊
regularizers 正則化,L1,L2等
utils 輔助功能實(shí)現(xiàn)

常用方法

深度學(xué)習(xí)實(shí)現(xiàn)的主要流程:1.數(shù)據(jù)獲取,2,數(shù)據(jù)處理,3.模型創(chuàng)建與訓(xùn)練,4 模型測(cè)試與評(píng)估,5.模型預(yù)測(cè)。

深度學(xué)習(xí)常用方法

1.導(dǎo)入tf.keras

使用 tf.keras,首先需要在代碼開始時(shí)導(dǎo)入tf.keras。

import tensorflow as tf
from tensorflow import keras

2.數(shù)據(jù)輸入

對(duì)于小的數(shù)據(jù)集,可以直接使用numpy格式的數(shù)據(jù)進(jìn)行訓(xùn)練、評(píng)估模型,對(duì)于大型數(shù)據(jù)集或者要進(jìn)行跨設(shè)備訓(xùn)練時(shí)使用tf.data.datasets來(lái)進(jìn)行數(shù)據(jù)輸入。

3.模型構(gòu)建

  • 簡(jiǎn)單模型使用Sequential進(jìn)行構(gòu)建
  • 復(fù)雜模型使用函數(shù)式編程來(lái)構(gòu)建
  • 自定義layers

4.訓(xùn)練與評(píng)估

  • 配置訓(xùn)練過(guò)程:
# 配置優(yōu)化方法,損失函數(shù)和評(píng)價(jià)指標(biāo)
model.compile(optimizer=tf.train.AdamOptimizer(0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

模型訓(xùn)練:

# 指明訓(xùn)練數(shù)據(jù)集,訓(xùn)練epoch,批次大小和驗(yàn)證集數(shù)據(jù)model.fit/fit_generator(dataset, epochs=10, 
                        batch_size=3,
          validation_data=val_dataset,
          )

模型評(píng)估:

# 指明評(píng)估數(shù)據(jù)集和批次大小
model.evaluate(x, y, batch_size=32)

模型預(yù)測(cè):

# 對(duì)新的樣本進(jìn)行預(yù)測(cè)
model.predict(x, batch_size=32)

5.回調(diào)函數(shù)(callbacks)

回調(diào)函數(shù)用在模型訓(xùn)練過(guò)程中,來(lái)控制模型訓(xùn)練行為,可以自定義回調(diào)函數(shù),也可使用tf.keras.callbacks 內(nèi)置的 callback :

ModelCheckpoint:定期保存 checkpoints。 LearningRateScheduler:動(dòng)態(tài)改變學(xué)習(xí)速率。 EarlyStopping:當(dāng)驗(yàn)證集上的性能不再提高時(shí),終止訓(xùn)練。 TensorBoard:使用 TensorBoard 監(jiān)測(cè)模型的狀態(tài)。

6.模型的保存和恢復(fù)

只保存參數(shù):

# 只保存模型的權(quán)重
model.save_weights('./my_model')
# 加載模型的權(quán)重
model.load_weights('my_model')
保存整個(gè)模型:
# 保存模型架構(gòu)與權(quán)重在h5文件中
model.save('my_model.h5')
# 加載模型:包括架構(gòu)和對(duì)應(yīng)的權(quán)重
model = keras.models.load_model('my_model.h5')






猜你喜歡:

什么是深度學(xué)習(xí)?深度學(xué)習(xí)各層負(fù)責(zé)什么內(nèi)容?

OpenCV圖片相加和混合的方法【人工智能基礎(chǔ)】

機(jī)器學(xué)習(xí)中入門級(jí)必學(xué)的算法有哪些?

Numpy數(shù)組操作教程【傳智教育】

傳智教育Ai人工智能軟件工程師培訓(xùn)

0 分享到:
和我們?cè)诰€交談!