leon 5 сар өмнө
parent
commit
ed7866852a
1 өөрчлөгдсөн 13 нэмэгдсэн , 4 устгасан
  1. 13 4
      src/resnet.cu

+ 13 - 4
src/resnet.cu

@@ -400,16 +400,25 @@ class InferImpl : public Infer {
                                  output_array_.gpu_bytes(), cudaMemcpyDeviceToHost, stream_));
     checkRuntime(cudaStreamSynchronize(stream_));
 
+    vector<Attribute> arrout(num_image);
+
     for (int ib = 0; ib < num_image; ++ib) {
       float *output_array_cpu = output_array_.cpu() + ib * num_classes_;
+      float max_score = 0.f;
+      int index = -1;
       for (int i = 0; i < num_classes_; i++)
       {
-        printf("prob : %f\t", *(output_array_cpu+i));
+        if (*(output_array_cpu+i) > max_score)
+        {
+            index = i;
+            max_score = *(output_array_cpu+i);
+        }
       }
+      arrout.emplace_back(max_score, index);
+    }
+    for (int ib = 0; ib < num_image; ++ib) {
+        std::cout << "score : " << arrout[i].confidence << " label : " << arrout[i].class_label << std::endl;
     }
-
-    vector<Attribute> arrout(num_image);
-
     return arrout;
   }
 };