// ----------------------------------------------------------------------------
// -                        Open3D: www.open3d.org                            -
// ----------------------------------------------------------------------------
// Copyright (c) 2018-2023 www.open3d.org
// SPDX-License-Identifier: MIT
// ----------------------------------------------------------------------------

#pragma once

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

namespace open3d {
namespace ml {
namespace op_util {

/// Class for representing a possibly unknown dimension value.
///
/// A default constructed DimValue is unknown; a DimValue constructed from an
/// integer is known ("constant"). The integer constructor is intentionally
/// implicit so that shapes can be written as brace lists, e.g. {30, 40}.
class DimValue {
public:
    DimValue() : value_(0), constant_(false) {}
    DimValue(int64_t v) : value_(v), constant_(true) {}

    /// Multiplies in place. The product stays known only if both factors are
    /// known, otherwise the result becomes unknown.
    DimValue& operator*=(const DimValue& b) {
        if (constant_ && b.constant_)
            value_ *= b.value_;
        else
            constant_ = false;
        return *this;
    }

    /// Returns the value as string or "?" if the value is unknown.
    std::string ToString() const {
        if (constant_)
            return std::to_string(value_);
        else
            return "?";
    }

    /// Returns a reference to the value.
    /// \throws std::runtime_error if the value is unknown.
    int64_t& value() {
        if (!constant_) throw std::runtime_error("DimValue is not constant");
        return value_;
    }

    /// Returns a reference to the flag that marks the value as known.
    bool& constant() { return constant_; }

private:
    int64_t value_;
    bool constant_;
};

/// Returns a DimValue representing an unknown dimension.
inline DimValue UnknownValue() { return DimValue(); }

/// Class for dimensions for which the value should be inferred.
///
/// A Dim created without a value (default or name-only constructor) refers to
/// its own storage. Copies of such a Dim keep referring to the storage of the
/// original object, so assigning a value through a copy (e.g. inside a dim
/// expression) updates the original. The original must therefore outlive all
/// of its copies.
/// A Dim created from an integer is constant and each copy carries its own
/// value.
class Dim {
public:
    explicit Dim() : value_(0), constant_(false), origin_(this) {}

    explicit Dim(const std::string& name)
        : value_(0), constant_(false), origin_(this), name_(name) {}

    Dim(int64_t value, const std::string& name = "")
        : value_(value), constant_(true), origin_(nullptr), name_(name) {}

    Dim(const Dim& other)
        : value_(other.value_),
          constant_(other.constant_),
          origin_(other.origin_),
          name_(other.name_) {}

    ~Dim() {}

    Dim& operator=(const Dim&) = delete;

    /// Returns a reference to the value, stored in the original Dim if this
    /// Dim refers to one.
    int64_t& value() {
        if (origin_)
            return origin_->value_;
        else
            return value_;
    }

    /// Returns a reference to the "value is known" flag, stored in the
    /// original Dim if this Dim refers to one.
    bool& constant() {
        if (origin_)
            return origin_->constant_;
        else
            return constant_;
    }

    /// Tries to assign a value to the Dim if it is not marked as constant and
    /// compares the Dim value with the value to be assigned.
    /// \return True if the (possibly newly assigned) value equals \p a.
    bool assign(int64_t a) {
        if (!constant()) {
            value() = a;
            constant() = true;
        }
        return value() == a;
    }

    /// Returns "name(value)" / "name(?)" for named dims, or just "name" if
    /// \p show_value is false. Unnamed dims return the value or "?".
    std::string ToString(bool show_value = true) {
        if (name_.size()) {
            if (show_value)
                return name_ + "(" +
                       (constant() ? std::to_string(value()) : "?") + ")";
            else
                return name_;
        }
        if (constant())
            return std::to_string(value());
        else
            return "?";
    }

private:
    int64_t value_;
    bool constant_;
    Dim* origin_;
    std::string name_;
};

//
// Dim expression operator classes
//
// Each operator class provides
//   constant() - whether the expression can be evaluated once both operands
//                are known,
//   apply()    - evaluation of the operator,
//   backprop() - inference of an unknown operand from the expected result,
//   ToString() - the operator symbol.
//
// GetString() used in backprop() is declared further below; it is found at
// instantiation time through argument-dependent lookup.
//

struct DimXPlus {
    static bool constant() { return true; }
    static int64_t apply(int64_t a, int64_t b) { return a + b; }

    /// Infers the unknown operand of (a + b == ans).
    /// \throws std::runtime_error if both operands are unknown.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        if (!a.constant() && !b.constant()) {
            std::string exstr =
                    GetString(a, false) + ToString() + GetString(b, false);
            throw std::runtime_error("Illegal dim expression: " + exstr);
        } else if (!a.constant()) {
            return a.assign(ans - b.value());
        } else {
            return b.assign(ans - a.value());
        }
    }

    static std::string ToString() { return "+"; }
};

struct DimXMinus {
    static bool constant() { return true; }
    static int64_t apply(int64_t a, int64_t b) { return a - b; }

    /// Infers the unknown operand of (a - b == ans).
    /// \throws std::runtime_error if both operands are unknown.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        if (!a.constant() && !b.constant()) {
            std::string exstr =
                    GetString(a, false) + ToString() + GetString(b, false);
            throw std::runtime_error("Illegal dim expression: " + exstr);
        } else if (!a.constant()) {
            return a.assign(ans + b.value());
        } else {
            return b.assign(a.value() - ans);
        }
    }

    static std::string ToString() { return "-"; }
};

struct DimXMultiply {
    static bool constant() { return true; }
    static int64_t apply(int64_t a, int64_t b) { return a * b; }

    /// Multiplication cannot be used to infer unknown dims (the result may
    /// not be divisible), therefore this always throws.
    template <class T1, class T2>
    static bool backprop(int64_t /*ans*/, T1 a, T2 b) {
        std::string exstr =
                GetString(a, false) + ToString() + GetString(b, false);
        throw std::runtime_error("Illegal dim expression: " + exstr);
    }

    static std::string ToString() { return "*"; }
};

struct DimXDivide {
    static bool constant() { return true; }

    /// Integer division.
    /// \throws std::runtime_error on division by zero, which would otherwise
    /// be undefined behavior.
    static int64_t apply(int64_t a, int64_t b) {
        if (b == 0) {
            throw std::runtime_error("Division by zero in dim expression");
        }
        return a / b;
    }

    /// Division cannot be used to infer unknown dims, therefore this always
    /// throws.
    template <class T1, class T2>
    static bool backprop(int64_t /*ans*/, T1 a, T2 b) {
        std::string exstr =
                GetString(a, false) + ToString() + GetString(b, false);
        throw std::runtime_error("Illegal dim expression: " + exstr);
    }

    static std::string ToString() { return "/"; }
};

struct DimXOr {
    static bool constant() { return false; }

    static int64_t apply(int64_t /*a*/, int64_t /*b*/) {
        throw std::runtime_error("Cannot evaluate OR expression");
    }

    /// Succeeds if either operand matches (or can be assigned) \p ans. The
    /// right operand is only tried if the left one does not match.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        return a.assign(ans) || b.assign(ans);
    }

    static std::string ToString() { return "||"; }
};

/// Dim expression class.
///
/// Stores copies of its operands. Copies of non-constant Dims refer to the
/// original Dim objects, so assigning through an expression updates them.
template <class TLeft, class TRight, class TOp>
class DimX {
public:
    static DimX<TLeft, TRight, TOp> Create(TLeft left, TRight right) {
        return DimX(left, right);
    }

    /// Evaluates the expression, or returns 0 if it cannot be evaluated yet.
    int64_t value() {
        if (constant()) {
            return TOp::apply(left_.value(), right_.value());
        }
        return 0;
    }

    /// Returns whether the expression can be evaluated. The flag is
    /// recomputed from the operands on every call because operands may have
    /// been assigned values after this expression was created, e.g. by an
    /// earlier dim expression in the same CheckShape call.
    bool& constant() {
        constant_ = left_.constant() && right_.constant() && TOp::constant();
        return constant_;
    }

    /// Assigns a value to the expression. Constant expressions are compared
    /// with \p a; otherwise the value is propagated to the unknown operand.
    bool assign(int64_t a) {
        if (constant()) {
            return value() == a;
        } else {
            return TOp::backprop(a, left_, right_);
        }
    }

    std::string ToString(bool show_value = true) {
        return left_.ToString(show_value) + TOp::ToString() +
               right_.ToString(show_value);
    }

private:
    DimX(TLeft left, TRight right) : left_(left), right_(right) {
        constant_ = left.constant() && right.constant() && TOp::constant();
    }

    TLeft left_;
    TRight right_;
    bool constant_;
};

//
// define operators for dim expressions
//
// Dim expressions are only accepted as rvalues, which prevents storing an
// expression in a variable that may outlive the Dims it refers to.
//
#define DEFINE_DIMX_OPERATOR(opclass, symbol)                                 \
    inline DimX<Dim, Dim, opclass> operator symbol(Dim a, Dim b) {            \
        return DimX<Dim, Dim, opclass>::Create(a, b);                         \
    }                                                                         \
                                                                              \
    template <class TL, class TR, class TOp>                                  \
    inline DimX<Dim, DimX<TL, TR, TOp>, opclass> operator symbol(             \
            Dim a, DimX<TL, TR, TOp>&& b) {                                   \
        return DimX<Dim, DimX<TL, TR, TOp>, opclass>::Create(a, b);           \
    }                                                                         \
                                                                              \
    template <class TL, class TR, class TOp>                                  \
    inline DimX<DimX<TL, TR, TOp>, Dim, opclass> operator symbol(             \
            DimX<TL, TR, TOp>&& a, Dim b) {                                   \
        return DimX<DimX<TL, TR, TOp>, Dim, opclass>::Create(a, b);           \
    }                                                                         \
                                                                              \
    template <class TL1, class TR1, class TOp1, class TL2, class TR2,         \
              class TOp2>                                                     \
    inline DimX<DimX<TL1, TR1, TOp1>, DimX<TL2, TR2, TOp2>, opclass>          \
    operator symbol(DimX<TL1, TR1, TOp1>&& a, DimX<TL2, TR2, TOp2>&& b) {     \
        return DimX<DimX<TL1, TR1, TOp1>, DimX<TL2, TR2, TOp2>,               \
                    opclass>::Create(a, b);                                   \
    }

DEFINE_DIMX_OPERATOR(DimXPlus, +)
DEFINE_DIMX_OPERATOR(DimXMinus, -)
DEFINE_DIMX_OPERATOR(DimXMultiply, *)
DEFINE_DIMX_OPERATOR(DimXDivide, /)
DEFINE_DIMX_OPERATOR(DimXOr, ||)
#undef DEFINE_DIMX_OPERATOR

//
// define operators for comparing DimValue to dim expressions.
// Using these operators will try to assign the dim value to the expression.
// An unknown DimValue matches anything.
//
template <class TL, class TR, class TOp>
inline bool operator==(DimValue a, DimX<TL, TR, TOp>&& b) {
    if (a.constant()) {
        auto b_copy(b);
        return b_copy.assign(a.value());
    } else
        return true;
}

inline bool operator==(DimValue a, Dim b) {
    if (a.constant())
        return b.assign(a.value());
    else
        return true;
}

//
// some helper classes
//

/// Counts the number of types in a parameter pack.
template <class... args>
struct CountArgs {
    static const size_t value = sizeof...(args);
};

template <class TL, class TR, class TOp>
std::string GetString(DimX<TL, TR, TOp> a, bool show_value = true) {
    return a.ToString(show_value);
}

inline std::string GetString(Dim a, bool show_value = true) {
    return a.ToString(show_value);
}

/// Returns the value of the expression, or 0 if it cannot be evaluated.
template <class TL, class TR, class TOp>
int64_t GetValue(DimX<TL, TR, TOp> a) {
    return a.value();
}

/// Returns the value of the expression, or \p unknown_dim_value if it cannot
/// be evaluated.
template <class TL, class TR, class TOp>
int64_t GetValue(DimX<TL, TR, TOp> a, int64_t unknown_dim_value) {
    return a.constant() ? a.value() : unknown_dim_value;
}

inline int64_t GetValue(Dim a) { return a.value(); }

inline int64_t GetValue(Dim a, int64_t unknown_dim_value) {
    return a.constant() ? a.value() : unknown_dim_value;
}

/// Joins the string representations of the dim expressions with ", ".
inline std::string CreateDimXString() { return std::string(); }

template <class TDimX>
std::string CreateDimXString(TDimX dimex) {
    return GetString(dimex);
}

template <class TDimX, class... TArgs>
std::string CreateDimXString(TDimX dimex, TArgs... args) {
    return GetString(dimex) + ", " + CreateDimXString(args...);
}

/// Appends the values of the dim expressions to \p out, using
/// \p unknown_dim_value for expressions that cannot be evaluated.
template <class TDimX>
void CreateDimVector(std::vector<int64_t>& out,
                     int64_t unknown_dim_value,
                     TDimX dimex) {
    out.push_back(GetValue(dimex, unknown_dim_value));
}

template <class TDimX, class... TArgs>
void CreateDimVector(std::vector<int64_t>& out,
                     int64_t unknown_dim_value,
                     TDimX dimex,
                     TArgs... args) {
    out.push_back(GetValue(dimex, unknown_dim_value));
    CreateDimVector(out, unknown_dim_value, args...);
}

/// Returns the values of the dim expressions as vector, using
/// \p unknown_dim_value for expressions that cannot be evaluated.
template <class TDimX>
std::vector<int64_t> CreateDimVector(int64_t unknown_dim_value, TDimX dimex) {
    std::vector<int64_t> out;
    CreateDimVector(out, unknown_dim_value, dimex);
    return out;
}

template <class TDimX, class... TArgs>
std::vector<int64_t> CreateDimVector(int64_t unknown_dim_value,
                                     TDimX dimex,
                                     TArgs... args) {
    std::vector<int64_t> out;
    CreateDimVector(out, unknown_dim_value, dimex, args...);
    return out;
}

//
// classes which check if the dim value is compatible with the expression
//
template <class TL, class TR, class TOp>
bool CheckDim(const DimValue& lhs, DimX<TL, TR, TOp>&& rhs) {
    bool status = (lhs == std::forward<DimX<TL, TR, TOp>>(rhs));
    return status;
}

inline bool CheckDim(const DimValue& lhs, Dim d) {
    bool status = lhs == d;
    return status;
}

/// Check shape options
enum class CSOpt {
    NONE = 0,
    COMBINE_FIRST_DIMS,
    IGNORE_FIRST_DIMS,
    COMBINE_LAST_DIMS,
    IGNORE_LAST_DIMS
};

/// Checks the last dim expression against the remaining \p shape.
/// Implementation detail of CheckShape.
template <CSOpt Opt = CSOpt::NONE, class TDimX>
bool _CheckShape(const std::vector<DimValue>& shape, TDimX&& dimex) {
    // check rank
    const int rank_diff = static_cast<int>(shape.size()) - 1;
    if (Opt != CSOpt::NONE) {
        if (rank_diff < 0) return false;
    } else {
        if (rank_diff != 0) return false;
    }

    // check dim
    if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
        DimValue s(1);
        for (int i = 0; i < rank_diff + 1; ++i) s *= shape[i];
        return CheckDim(s, std::forward<TDimX>(dimex));
    } else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
        return CheckDim(shape[rank_diff], std::forward<TDimX>(dimex));
    } else if (Opt == CSOpt::COMBINE_LAST_DIMS) {
        DimValue s(1);
        for (const DimValue& x : shape) s *= x;
        return CheckDim(s, std::forward<TDimX>(dimex));
    } else {
        return CheckDim(shape[0], std::forward<TDimX>(dimex));
    }
}

/// Checks the first dim expression and recurses on the rest of \p shape.
/// All dim expressions are checked even if an earlier one fails.
/// Implementation detail of CheckShape.
template <CSOpt Opt = CSOpt::NONE, class TDimX, class... TArgs>
bool _CheckShape(const std::vector<DimValue>& shape,
                 TDimX&& dimex,
                 TArgs&&... args) {
    // check rank
    const int rank_diff = static_cast<int>(shape.size()) -
                          static_cast<int>(CountArgs<TArgs...>::value + 1);
    if (Opt != CSOpt::NONE) {
        if (rank_diff < 0) return false;
    } else {
        if (rank_diff != 0) return false;
    }

    // check dim
    bool status;
    if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
        DimValue s(1);
        for (int i = 0; i < rank_diff + 1; ++i) s *= shape[i];
        status = CheckDim(s, std::forward<TDimX>(dimex));
    } else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
        status = CheckDim(shape[rank_diff], std::forward<TDimX>(dimex));
    } else {
        status = CheckDim(shape[0], std::forward<TDimX>(dimex));
    }

    // the first-dims options consume the extra leading dims here
    const int offset = 1 + (Opt == CSOpt::COMBINE_FIRST_DIMS ||
                                            Opt == CSOpt::IGNORE_FIRST_DIMS
                                    ? rank_diff
                                    : 0);
    std::vector<DimValue> shape2(shape.begin() + offset, shape.end());
    bool status2 = _CheckShape<Opt>(shape2, std::forward<TArgs>(args)...);

    return status && status2;
}

/// Function for checking a shape with dim expressions.
/// Usage example:
///
///   Dim depth("depth");
///   Dim height("height");
///   Dim width("width");
///   status = CheckShape({30,40}, height, width); // VALID, will assign values
///                                                // to height and width
///
///   status = CheckShape({50,41}, height+20, width+1); // VALID, values match
///   status = CheckShape({20,30,40}, depth+10, height, width); // VALID, will
///                                                    // assign 10 to depth
///
///   status = CheckShape({0},depth||0); // VALID, shape must match depth or 0
///   status = CheckShape({10}, depth||0); // VALID, shape must match depth or 0
///   status = CheckShape({123,10}, Dim(), depth); // VALID, first dim may be
///                                                // anything
///
///   status = CheckShape<CSOpt::COMBINE_LAST_DIMS>({123,10,4}, Dim(), width);
///                                                // VALID, width==40==10*4
///
///   status = CheckShape<CSOpt::COMBINE_FIRST_DIMS>(
///                   {10,2,2,123,456}, width, Dim(), Dim());
///                                                // VALID, width==40==10*2*2
///
///   status = CheckShape({70}, height+width); // VALID, works because height
///   // and width have been initialized since the first call to CheckShape
///
///   status = CheckShape({1,2,3}, Dim(), Dim()); // INVALID, rank mismatch 3vs2
///   status = CheckShape({1,2,3}, depth, width, height); // INVALID, at least
///                                                       // one dim does not
///                                                       // match
///
/// The options CSOpt::COMBINE_FIRST_DIMS and CSOpt::COMBINE_LAST_DIMS allow
/// to match the rank of the dim expressions by combining the shape
/// dimensions at the beginning or end.
/// The options CSOpt::IGNORE_FIRST_DIMS and CSOpt::IGNORE_LAST_DIMS allow to
/// ignore additional dimensions in the shape.
///
/// The shape to be checked may contain unknowns
///   Dim A("A");
///   Dim B("B");
///   status = CheckShape({30, UnknownValue()}, A, B); // VALID, A is 30 and B
///                                                    // is still unknown
///
///   status = CheckShape<CSOpt::COMBINE_LAST_DIMS>(
///                   {30,1,2,UnknownValue()},A,B);
///                                  // VALID, A is 30 and B is still unknown
///
/// The following shows some limitations of the dim expressions
///   Dim A("A");
///   Dim B("B");
///   status = CheckShape({30}, A+B); // THROWS EXCEPTION, illegal expression
///                                   // because neither A nor B is a constant
///
/// However, the following will work
///   Dim A(20,"A");
///   Dim B("B");
///   status = CheckShape({30}, A+B); // VALID, B is now 10
///
/// This will work, too
///   Dim A("A"); // uninitialized
///   Dim B("B");
///   status = CheckShape({20}, A); // VALID, A is now 20
///   status = CheckShape({30}, A+B); // VALID, B is now 10
///
/// Multiplication and division are not allowed for unknown dims
///   Dim A("A");
///   status = CheckShape({30}, 3*A); // THROWS EXCEPTION, although expression
///                                   // seems reasonable
///   status = CheckShape({20}, 3*A); // THROWS EXCEPTION, this
///           // is the reason why mul/div is only allowed for known dims
///
/// Important, do not create objects of dim expressions, i.e.,
///   auto dimx = Dim("tmp") + 3;
///   status = CheckShape({20}, dimx); // intended to not compile
/// Assigning a value to dimx will assign a value to Dim("tmp") which has a
/// shorter lifetime.
///
/// The return value is a tuple <bool,std::string>. If the bool is false
/// then the shape is INVALID and the string contains an error message of the
/// form "got [shape], expected [dim expressions]".
/// If true then the shape is VALID and the error string is empty.
///
/// Note the goal of this function is to simplify checking tensor shapes. There
/// may be cases where shapes cannot be checked with the provided functionality
/// and you have to write custom shape checking code.
///
/// \param shape This is the actual shape of an object.
/// \param args This is a list of dim expression
///
template <CSOpt Opt = CSOpt::NONE, class TDimX, class... TArgs>
std::tuple<bool, std::string> CheckShape(const std::vector<DimValue>& shape,
                                         TDimX&& dimex,
                                         TArgs&&... args) {
    const bool status = _CheckShape<Opt>(shape, std::forward<TDimX>(dimex),
                                         std::forward<TArgs>(args)...);
    if (status) {
        return std::make_tuple(status, std::string());
    }

    const int rank = static_cast<int>(shape.size());
    const int expected_rank =
            static_cast<int>(CountArgs<TArgs...>::value + 1);
    const int rank_diff = rank - expected_rank;

    // Generate the string for the actual shape. Dims that are combined by the
    // option are joined with '*'; ignored dims are listed in parentheses
    // outside of the brackets.
    std::string shape_str;
    if (rank_diff <= 0) {
        shape_str = "[";
        for (int i = 0; i < rank; ++i) {
            shape_str += shape[i].ToString();
            if (i + 1 < rank) shape_str += ", ";
        }
        shape_str += "]";
    } else {
        if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
            // the first rank_diff+1 dims are combined; the last of them is
            // printed by the main loop below
            shape_str = "[";
            for (int i = 0; i < rank_diff; ++i) {
                shape_str += shape[i].ToString();
                shape_str += "*";
            }
        } else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
            shape_str = "(";
            for (int i = 0; i < rank_diff; ++i) {
                shape_str += shape[i].ToString();
                if (i + 1 < rank_diff) shape_str += ", ";
            }
            shape_str += ")[";
        } else {
            shape_str = "[";
        }

        // [start, end) are the dims printed as regular comma separated list
        int start = 0;
        if (Opt == CSOpt::COMBINE_FIRST_DIMS ||
            Opt == CSOpt::IGNORE_FIRST_DIMS) {
            start = rank_diff;
        }
        int end = rank;
        if (Opt == CSOpt::COMBINE_LAST_DIMS) {
            end -= rank_diff + 1;
        } else if (Opt == CSOpt::IGNORE_LAST_DIMS) {
            end -= rank_diff;
        }
        for (int i = start; i < end; ++i) {
            shape_str += shape[i].ToString();
            if (i + 1 < end) shape_str += ", ";
        }

        if (Opt == CSOpt::COMBINE_LAST_DIMS) {
            // With a single dim expression all dims are combined and there is
            // nothing to separate from.
            if (end > 0) shape_str += ", ";
            for (int i = end; i < rank; ++i) {
                shape_str += shape[i].ToString();
                if (i + 1 < rank) shape_str += "*";
            }
            shape_str += "]";
        } else if (Opt == CSOpt::IGNORE_LAST_DIMS) {
            shape_str += "](";
            for (int i = end; i < rank; ++i) {
                shape_str += shape[i].ToString();
                if (i + 1 < rank) shape_str += ", ";
            }
            shape_str += ")";
        } else {
            shape_str += "]";
        }
    }

    // generate string for the expected shape with the dim expressions
    std::string expected_shape;
    if (CountArgs<TArgs...>::value == 0) {
        expected_shape = "[" + GetString(dimex) + "]";
    } else {
        expected_shape = "[" + GetString(dimex) + ", " +
                         CreateDimXString(args...) + "]";
    }

    std::string errstr;
    // print rank information if there is a problem with the rank
    if ((Opt != CSOpt::NONE && rank_diff < 0) ||
        (Opt == CSOpt::NONE && rank_diff != 0)) {
        errstr = "got rank " + std::to_string(shape.size()) + " " +
                 shape_str + ", expected rank " +
                 std::to_string(CountArgs<TArgs...>::value + 1) + " " +
                 expected_shape;
    } else {  // rank is OK print just the shapes
        errstr = "got " + shape_str + ", expected " + expected_shape;
    }
    return std::make_tuple(status, errstr);
}

}  // namespace op_util
}  // namespace ml
}  // namespace open3d