问题描述

今天使用tensorflow2.6 训练了一个模型,在使用predict_classes(x_test[0:10]) 预测类别的时候报错:(如下图)

AttributeError: 'Sequential' object has no attribute 'predict_classes


 

原因分析:

经过查找资料,我发现这个报错是因为 Tensorflow 版本的问题。

Tensorflow 2.6 版本,删除了predict_classes() 这个函数


解决方案:

目前有两个解决方法:

1 使用低版本的 Tensorflow,低于2.6版本即可(不推荐)。

2 使用np.argmax() 函数,如下代码即可。

predict_x = model.predict(x_test)
classes_x=np.argmax(predict_x,axis=1)
print(classes_x)

运行可以得到结果:

 

Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐