资讯详情

Tensorflow2训练Fer2013数据集

文件主要分为Train和Model两部分,可以分开两个文件,我在一起。

第三方工具:

tensorflow_gpu == 2.6.0

pandas == 1.3.2

numpy == 1.19.5

附件:Fer2013数据集

模型文件

有两个模型很简单CNN层,两个FC层

# ----------------------模型---------------------------------------- class CNN(tf.keras.Model):     def __init__(self, num_class, keep_prob):         super(CNN, self).__init__()          self.conv1 = tf.keras.layers.Conv2D(filters=32, kernel_size=5, strides=1, use_bias=True, padding='same')         self.conv1_act = tf.keras.layers.Activation('relu')         self.conv1_pool = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='same')          self.conv2 = tf.keras.layers.Conv2D(64, 5, strides=2, use_bias=True, padding='same')         self.conv2_act = tf.keras.layers.Activation('relu')         self.conv2_pool = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='same')          self.flat = tf.keras.layers.Flatten()          self.dense1 = tf.keras.layers.Dense(1024, activation='relu', use_bias=True)         self.dense1_act = tf.keras.layers.Activation('relu')          self.drop = tf.keras.layers.Dropout(rate=keep_prob)          self.dense2 = tf.keras.layers.Dense(num_class, use_bias=True)      def call(self, inputs):         x = self.conv1(inputs)         x = self.conv1_act(x)         x = self.conv1_pool(x)          x = self.conv2(x)         x = self.conv2_act(x)         x = self.conv2_pool(x)          x = self.flat(x)          x = self.dense1(x)         x = self.dense1_act(x)         x = self.drop(x)         x = self.dense2(x)         return x

训练文件

由于fer2013文件是csv文件、思路:

1.使用pandas读取csv文件

2.分为特征和标签

3.将pixel转换为所需类型

4.转换成tensorflow的Dateset

5.提取部分数据集

# ---------------------------参数--------------------- model_path = r'./checkpoint/emotion_analysis.ckpt'  # path where to save the trained model penalty_parameter = 0.02  # the SVM C penalty parameter log_path = r'/logs/'  # path where to save the TensorBoard logs num_classes = 7 dropout_rate = 0.5  # dropout batch_size = 128 epoch = 100 lr = 3e-4 weight_decay = 1e-4 data_path = r'../Datasets/archive/fer2013.csv'  AUTOTUNE = tf.data.experimental.AUTOTUNE np.set_printoptions(precision=3, suppress=True)  # 读取csv文件 df = pd.read_csv(filepath_or_buffer=data_path, usecols=["emotion", "pixels"], dtype={"pixels": str}) fer_pixels = df.copy()  # 分为特征和标签 fer_label = fer_pixels.pop('emotion') fer_pixels = np.asarray(fer_pixels)  # 将特征转换为模型所需的类型 fer_train = [] for i in range(len(fer_label)):     pixels_new = np.asarray([float(p) for p in fer_pixels[i][0].split()]).reshape([48,48,1])     fer_train.append(pixels_new) fer_train = np.asarray(fer_train) fer_label = np.asarray(fer_label)  # 转换为tf.Dateset类型 dataset = tf.data.Dataset.from_tensor_slices((fer_train, fer_label))  # 数据集验证集测试集的拆分 train_dataset = dataset.take(1000)  # 为了测试,这里只使用了1000张图片。 test_dataset = dataset.skip(32297)  # 打乱 train_dataset = (train_dataset.cache().shuffle(5 * batch_size).batch(batch_size).prefetch(AUTOTUNE))  # 训练 strategy = tf.distribute.MirroredStrategy() with strategy.scope():     model = CNN(num_class=num_classes, keep_prob=dropout_rate)     model.compile(loss=tf.keras.losses.Hinge(),                   optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),                   metrics=['accuracy'])     # 断点续训     if os.path.exists(model_path '/saved_model.pb'):         print('-加载模型-')         model = tf.keras.models.load_model(model_path)#加载model     cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=model_path,                                                      save_weights_only=False,                                                      monitor='val_accuracy',                                                      model='max',                                                      save_best_only=True)     # 训练     history = model.fit(x=train_dataset, epochs=epoch,callbacks=[cp_callback])     model.summary()

第一段代码的模型文件可以更改

标签: fer连接电缆meto

锐单商城拥有海量元器件数据手册IC替代型号,打造 电子元器件IC百科大全!

锐单商城 - 一站式电子元器件采购平台