diff --git a/ext/opencv/facerecognizer.cpp b/ext/opencv/facerecognizer.cpp index 4ecce45..371f843 100644 --- a/ext/opencv/facerecognizer.cpp +++ b/ext/opencv/facerecognizer.cpp @@ -106,6 +106,30 @@ rb_predict(VALUE self, VALUE src) return INT2NUM(label); } +/* + * call-seq: + * predict_with_confidence(src) + * + * Predicts a label and associated confidence (e.g. distance) for a given input image. + */ +VALUE +rb_predict_with_confidence(VALUE self, VALUE src) +{ + cv::Mat mat = cv::Mat(CVMAT_WITH_CHECK(src)); + cv::FaceRecognizer *self_ptr = FACERECOGNIZER(self); + int label; + double confidence; + try { + self_ptr->predict(mat, label, confidence); + } + catch (cv::Exception& e) { + raise_cverror(e); + } + + return rb_ary_new3(2, INT2NUM(label), DBL2NUM(confidence)); +} + + /* * call-seq: * save(filename) @@ -164,6 +188,7 @@ define_ruby_class() rb_klass = rb_define_class_under(opencv, "FaceRecognizer", cAlgorithm::rb_class()); rb_define_method(rb_klass, "train", RUBY_METHOD_FUNC(rb_train), 2); rb_define_method(rb_klass, "predict", RUBY_METHOD_FUNC(rb_predict), 1); + rb_define_method(rb_klass, "predict_with_confidence", RUBY_METHOD_FUNC(rb_predict_with_confidence), 1); rb_define_method(rb_klass, "save", RUBY_METHOD_FUNC(rb_save), 1); rb_define_method(rb_klass, "load", RUBY_METHOD_FUNC(rb_load), 1); } diff --git a/test/test_eigenfaces.rb b/test/test_eigenfaces.rb index 4e3453e..621891e 100755 --- a/test/test_eigenfaces.rb +++ b/test/test_eigenfaces.rb @@ -55,6 +55,19 @@ class TestEigenFaces < OpenCVTestCase } end + def test_predict_with_confidence + img = CvMat.load(FILENAME_LENA256x256, CV_LOAD_IMAGE_GRAYSCALE) + label = 1 + @eigenfaces.train([img], [label]) + lbl, conf = @eigenfaces.predict_with_confidence(img) + assert_equal(label, lbl) + assert_equal(0.0, conf) + + assert_raise(TypeError) { + @eigenfaces.predict_with_confidence(DUMMY_OBJ) + } + end + def test_save img = CvMat.load(FILENAME_LENA256x256, CV_LOAD_IMAGE_GRAYSCALE) label = 1 diff --git a/test/test_fisherfaces.rb b/test/test_fisherfaces.rb index 6849c13..8a9c154 100755 --- a/test/test_fisherfaces.rb +++ b/test/test_fisherfaces.rb @@ -52,6 +52,17 @@ class TestFisherFaces < OpenCVTestCase } end + def test_predict_with_confidence + label = 1 + lbl, conf = @fisherfaces_trained.predict_with_confidence(@images[0]) + assert_equal(1, lbl) + assert_equal(0.0, conf) + + assert_raise(TypeError) { + @fisherfaces_trained.predict_with_confidence(DUMMY_OBJ) + } + end + def test_save filename = "fisherfaces_save-#{DateTime.now.strftime('%Y%m%d%H%M%S')}.xml" begin