C#使用TensorFlow.NET训练自己的数据集的方法

时间:2021-05-20

今天,我结合代码来详细介绍如何使用 SciSharp STACK 的 TensorFlow.NET 来训练CNN模型,该模型主要实现 图像的分类 ,可以直接移植该代码在 CPU 或 GPU 下使用,并针对你们自己本地的图像数据集进行训练和推理。TensorFlow.NET是基于 .NET Standard 框架的完整实现的TensorFlow,可以支持 .NET Framework 或 .NET CORE , TensorFlow.NET 为广大.NET开发者提供了完美的机器学习框架选择。

SciSharp STACK:https://github.com/SciSharp

什么是TensorFlow.NET?

TensorFlow.NET 是 SciSharp STACK

开源社区团队的贡献,其使命是打造一个完全属于.NET开发者自己的机器学习平台,特别对于C#开发人员来说,是一个“0”学习成本的机器学习平台,该平台集成了大量API和底层封装,力图使TensorFlow的Python代码风格和编程习惯可以无缝移植到.NET平台,下图是同样TF任务的Python实现和C#实现的语法相似度对比,从中读者基本可以略窥一二。

由于TensorFlow.NET在.NET平台的优秀性能,同时搭配SciSharp的NumSharp、SharpCV、Pandas.NET、Keras.NET、Matplotlib.Net等模块,可以完全脱离Python环境使用,目前已经被微软ML.NET官方的底层算法集成,并被谷歌写入TensorFlow官网教程推荐给全球开发者。

SciSharp 产品结构

微软 ML.NET底层集成算法

谷歌官方推荐.NET开发者使用

URL: https://pleteAdding(); }); foreach (var item in BlockC.GetConsumingEnumerable()) { sess.run(optimizer, (x, item.c_x), (y, item.c_y)); if (item.iter % display_freq == 0) { // Calculate and display the batch loss and accuracy var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, item.c_x), new FeedItem(y, item.c_y)); loss_val = result[0]; accuracy_val = result[1]; print("CNN:" + ($"iter {item.iter.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms")); sw.Restart(); } } // Run validation after every epoch (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_valid), (y, y_valid)); print("CNN:" + "---------------------------------------------------------"); print("CNN:" + $"gloabl steps: {sess.run(gloabl_steps) },learning rate: {sess.run(learning_rate)}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); print("CNN:" + "---------------------------------------------------------"); if (SaverBest) { if (accuracy_val > max_accuracy) { max_accuracy = accuracy_val; saver.save(sess, path_model + "\\CNN_Best"); print("CKPT Model is save."); } } else { saver.save(sess, path_model + string.Format("\\CNN_Epoch_{0}_Loss_{1}_Acc_{2}", epoch, loss_val, accuracy_val)); print("CKPT Model is save."); } } Write_Dictionary(path_model + "\\dic.txt", Dict_Label);}private void Write_Dictionary(string path, Dictionary<Int64, string> mydic){ FileStream fs = new FileStream(path, FileMode.Create); StreamWriter sw = new StreamWriter(fs); foreach (var d in mydic) { sw.Write(d.Key + "," + d.Value + "\r\n"); } sw.Flush(); sw.Close(); fs.Close(); print("Write_Dictionary");}private (NDArray, NDArray) Randomize(NDArray x, NDArray y){ var perm = np.random.permutation(y.shape[0]); np.random.shuffle(perm); return (x[perm], y[perm]);}private (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end){ var slice = new Slice(start, end); var x_batch = x[slice]; var y_batch = y[slice]; return (x_batch, y_batch);}private unsafe (NDArray, NDArray) GetNextBatch(Session sess, string[] x, NDArray y, int start, int end){ NDArray x_batch = np.zeros(end - start, img_h, img_w, n_channels); int n = 0; for (int i = start; i < end; i++) { NDArray img4 = cv2.imread(x[i], IMREAD_COLOR.IMREAD_GRAYSCALE); x_batch[n] = sess.run(normalized, (decodeJpeg, img4)); n++; } var slice = new Slice(start, end); var y_batch = y[slice]; return (x_batch, y_batch);}#endregion

测试集预测

训练完成的模型对test数据集进行预测,并统计准确率

计算图中增加了一个提取预测结果Top-1的概率的节点,最后测试集预测的时候可以把详细的预测数据进行输出,方便实际工程中进行调试和优化。

public void Test(Session sess){ (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, x_test), (y, y_test)); print("CNN:" + "---------------------------------------------------------"); print("CNN:" + $"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); print("CNN:" + "---------------------------------------------------------"); (Test_Cls, Test_Data) = sess.run((cls_prediction, prob), (x, x_test));}private void TestDataOutput(){ for (int i = 0; i < ArrayLabel_Test.Length; i++) { Int64 real = ArrayLabel_Test[i]; int predict = (int)(Test_Cls[i]); var probability = Test_Data[i, predict]; string result = (real == predict) ? "OK" : "NG"; string fileName = ArrayFileName_Test[i]; string real_str = Dict_Label[real]; string predict_str = Dict_Label[predict]; print((i + 1).ToString() + "|" + "result:" + result + "|" + "real_str:" + real_str + "|" + "predict_str:" + predict_str + "|" + "probability:" + probability.GetSingle().ToString() + "|" + "fileName:" + fileName); }}

总结

本文主要是.NET下的TensorFlow在实际工业现场视觉检测项目中的应用,使用SciSharp的TensorFlow.NET构建了简单的CNN图像分类模型,该模型包含输入层、卷积与池化层、扁平化层、全连接层和输出层,这些层都是CNN分类模型的必要的层,针对工业现场的实际图像进行了分类,分类准确性较高。

完整代码可以直接用于大家自己的数据集进行训练,已经在工业现场经过大量测试,可以在GPU或CPU环境下运行,只需要更换tensorflow.dll文件即可实现训练环境的切换。

同时,训练完成的模型文件,可以使用 “CKPT+Meta” 或 冻结成“PB” 2种方式,进行现场的部署,模型部署和现场应用推理可以全部在.NET平台下进行,实现工业现场程序的无缝对接。摆脱了以往Python下 需要通过Flask搭建服务器进行数据通讯交互 的方式,现场部署应用时无需配置Python和TensorFlow的环境【无需对工业现场的原有PC升级安装一大堆环境】,整个过程全部使用传统的.NET的DLL引用的方式。

到此这篇关于C#使用TensorFlow.NET训练自己的数据集的方法的文章就介绍到这了,更多相关C# TensorFlow.NET训练数据集内容请搜索以前的文章或继续浏览下面的相关文章希望大家以后多多支持!

声明:本页内容来源网络,仅供用户参考;我单位不保证亦不表示资料全面及准确无误,也不保证亦不表示这些资料为最新信息,如因任何原因,本网内容或者用户因倚赖本网内容造成任何损失或损害,我单位将不会负任何法律责任。如涉及版权问题,请提交至online#300.cn邮箱联系删除。

相关文章