leon 5 月之前
父节点
当前提交
c41f92e4fe
共有 1 个文件被更改,包括 2 次插入2 次删除
  1. 2 2
      src/resnet.cu

+ 2 - 2
src/resnet.cu

@@ -291,7 +291,7 @@ class InferImpl : public Infer {
   string engine_file_;
   vector<shared_ptr<trt::Memory<unsigned char>>> preprocess_buffers_;
   trt::Memory<float> input_buffer_, output_array_;
-  trt::Memory<float> classes_indexes_;
+  trt::Memory<int> classes_indices_;
   int network_input_width_, network_input_height_;
   Norm normalize_;
   int num_classes_ = 0;
@@ -413,7 +413,7 @@ class InferImpl : public Infer {
 
     for (int ib = 0; ib < num_image; ++ib) {
       float *output_array_device = output_array_.gpu() + ib * num_classes_;
-      float *classes_indices_device = classes_indices_.gpu() + ib;
+      int *classes_indices_device = classes_indices_.gpu() + ib;
       classfier_softmax(output_array_device, num_classes_, classes_indices_device, stream_);
     }