diff --git a/ext/opencv/facerecognizer.cpp b/ext/opencv/facerecognizer.cpp index a6ae7f1..952ee8a 100644 --- a/ext/opencv/facerecognizer.cpp +++ b/ext/opencv/facerecognizer.cpp @@ -83,6 +83,43 @@ rb_train(VALUE self, VALUE src, VALUE labels) return Qnil; } +/* + * call-seq: + * udpate(src, labels) + * + * Updates a FaceRecognizer with given data and associated labels. Only valid on LBPH models. + */ +VALUE +rb_update(VALUE self, VALUE src, VALUE labels) +{ + Check_Type(src, T_ARRAY); + Check_Type(labels, T_ARRAY); + + VALUE *src_ptr = RARRAY_PTR(src); + int src_size = RARRAY_LEN(src); + std::vector images; + for (int i = 0; i < src_size; i++) { + images.push_back(cv::Mat(CVMAT_WITH_CHECK(src_ptr[i]))); + } + + VALUE *labels_ptr = RARRAY_PTR(labels); + int labels_size = RARRAY_LEN(labels); + std::vector local_labels; + for (int i = 0; i < labels_size; i++) { + local_labels.push_back(NUM2INT(labels_ptr[i])); + } + + cv::FaceRecognizer *self_ptr = FACERECOGNIZER(self); + try { + self_ptr->update(images, local_labels); + } + catch (cv::Exception& e) { + raise_cverror(e); + } + + return Qnil; +} + /* * call-seq: * predict(src) @@ -171,6 +208,7 @@ init_ruby_class() VALUE alghorithm = cAlgorithm::rb_class(); rb_klass = rb_define_class_under(opencv, "FaceRecognizer", alghorithm); rb_define_method(rb_klass, "train", RUBY_METHOD_FUNC(rb_train), 2); + rb_define_method(rb_klass, "update", RUBY_METHOD_FUNC(rb_update), 2); rb_define_method(rb_klass, "predict", RUBY_METHOD_FUNC(rb_predict), 1); rb_define_method(rb_klass, "save", RUBY_METHOD_FUNC(rb_save), 1); rb_define_method(rb_klass, "load", RUBY_METHOD_FUNC(rb_load), 1); diff --git a/test/test_lbph.rb b/test/test_lbph.rb index 8cbc3e4..c7c2be7 100755 --- a/test/test_lbph.rb +++ b/test/test_lbph.rb @@ -13,6 +13,7 @@ class TestLBPH < OpenCVTestCase @lbph = LBPH.new @lbph_trained = LBPH.new + @lbph_update = LBPH.new @images = [CvMat.load(FILENAME_LENA256x256, CV_LOAD_IMAGE_GRAYSCALE)] * 2 @labels = [1, 2] @lbph_trained.train(@images, @labels) @@ -52,6 +53,19 @@ class TestLBPH < OpenCVTestCase } end + def test_update + assert_nil(@lbph_update.train([@images[0]], [@labels[0]])) + assert_nil(@lbph_update.update([@images[1]], [@labels[1]])) + + assert_raise(TypeError) { + @lbph_update.train(DUMMY_OBJ, @labels) + } + + assert_raise(TypeError) { + @lbph_update.train(@images, DUMMY_OBJ) + } + end + def test_predict predicted_label, predicted_confidence = @lbph_trained.predict(@images[0]) assert_equal(@labels[0], predicted_label)