问题描述

今天使用tensorflow2.6 训练了一个模型,在使用predict_classes(x_test[0:10]) 预测类别的时候报错:(如下图)

AttributeError: 'Sequential' object has no attribute 'predict_classes


 

原因分析:

经过查找资料,我发现这个报错是因为 Tensorflow 版本的问题。

Tensorflow 2.6 版本,删除了predict_classes() 这个函数


解决方案:

目前有两个解决方法:

1 使用低版本的 Tensorflow,低于2.6版本即可(不推荐)。

2 使用np.argmax() 函数,如下代码即可。

predict_x = model.predict(x_test)
classes_x=np.argmax(predict_x,axis=1)
print(classes_x)

运行可以得到结果:

 

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐