diff --git a/ext/opencv/mat.cpp b/ext/opencv/mat.cpp index e5b8a77..7f34c61 100644 --- a/ext/opencv/mat.cpp +++ b/ext/opencv/mat.cpp @@ -716,6 +716,26 @@ namespace rubyopencv { return mat2obj(retptr, CLASS_OF(self)); } + VALUE rb_diag(int argc, VALUE *argv, VALUE self) { + VALUE d; + rb_scan_args(argc, argv, "01", &d); + int d_value = NIL_P(d) ? 0 : NUM2INT(d); + cv::Mat* selfptr = obj2mat(self); + cv::Mat* retptr = NULL; + + try { + retptr = new cv::Mat(); + cv::Mat tmp = selfptr->diag(d_value); + tmp.copyTo(*retptr); + } + catch (cv::Exception& e) { + delete retptr; + Error::raise(e); + } + + return mat2obj(retptr, CLASS_OF(self)); + } + /* * Sets all or some of the array elements to the specified value. * @@ -858,7 +878,8 @@ namespace rubyopencv { rb_define_method(rb_klass, "-", RUBY_METHOD_FUNC(rb_sub), 1); rb_define_method(rb_klass, "*", RUBY_METHOD_FUNC(rb_mul), 1); rb_define_method(rb_klass, "/", RUBY_METHOD_FUNC(rb_div), 1); - + rb_define_method(rb_klass, "diag", RUBY_METHOD_FUNC(rb_diag), -1); + rb_define_method(rb_klass, "clone", RUBY_METHOD_FUNC(rb_clone), 0); rb_define_method(rb_klass, "rows", RUBY_METHOD_FUNC(rb_rows), 0); diff --git a/test/test_mat.rb b/test/test_mat.rb index db5f2b3..8975617 100755 --- a/test/test_mat.rb +++ b/test/test_mat.rb @@ -322,6 +322,37 @@ class TestMat < OpenCVTestCase } end + def test_diag + m0 = Mat.new(3, 3, CV_8U) + i = 1 + m0.rows.times { |r| + m0.cols.times { |c| + m0[r, c] = Scalar.new(i) + i += 1 + } + } + + m1 = m0.diag + elems = m1.to_s.scan(/(\[[^\]]+\])/m).flatten[0] + assert_equal("[ 1;\n 5;\n 9]", elems) + + m2 = m0.diag(0) + elems = m2.to_s.scan(/(\[[^\]]+\])/m).flatten[0] + assert_equal("[ 1;\n 5;\n 9]", elems) + + m3 = m0.diag(1) + elems = m3.to_s.scan(/(\[[^\]]+\])/m).flatten[0] + assert_equal("[ 2;\n 6]", elems) + + m4 = m0.diag(-1) + elems = m4.to_s.scan(/(\[[^\]]+\])/m).flatten[0] + assert_equal("[ 4;\n 8]", elems) + + assert_raise(TypeError) { + m0.diag(DUMMY_OBJ) + } + end + def test_cvt_color m = Mat.new(1, 1, CV_32FC3) m[0, 0] = Scalar.new(1.0, 2.0, 3.0)