676 lines
23 KiB
C++
Executable File
676 lines
23 KiB
C++
Executable File
// ----------------------------------------------------------------------------
|
|
// - Open3D: www.open3d.org -
|
|
// ----------------------------------------------------------------------------
|
|
// Copyright (c) 2018-2023 www.open3d.org
|
|
// SPDX-License-Identifier: MIT
|
|
// ----------------------------------------------------------------------------
|
|
|
|
#pragma once
|
|
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
|
|
|
|
namespace open3d {
|
|
namespace ml {
|
|
namespace op_util {
|
|
|
|
/// One entry of an actual (observed) shape: either a known extent or
/// unknown.
class DimValue {
public:
    /// Creates an unknown dimension value.
    DimValue() : value_(0), constant_(false) {}

    /// Creates a known dimension value. Intentionally implicit so shapes can
    /// be written as brace-initialized integer lists, e.g. {30, 40}.
    DimValue(int64_t v) : value_(v), constant_(true) {}

    /// Multiplies this value by \p b. The result stays known only if both
    /// factors are known; otherwise this value becomes unknown.
    DimValue& operator*=(const DimValue& b) {
        const bool both_known = constant_ && b.constant_;
        if (!both_known) {
            constant_ = false;
            return *this;
        }
        value_ *= b.value_;
        return *this;
    }

    /// Returns the decimal representation, or "?" when unknown.
    std::string ToString() const {
        return constant_ ? std::to_string(value_) : std::string("?");
    }

    /// Mutable access to the stored extent.
    /// \throws std::runtime_error if the value is unknown.
    int64_t& value() {
        if (constant_) return value_;
        throw std::runtime_error("DimValue is not constant");
    }

    /// Mutable access to the "value is known" flag.
    bool& constant() { return constant_; }

private:
    int64_t value_;
    bool constant_;
};
|
|
|
|
/// Convenience factory for a shape entry whose extent is unknown.
inline DimValue UnknownValue() {
    DimValue unknown;
    return unknown;
}
|
|
|
|
/// Class for dimensions whose value should be inferred.
///
/// A Dim created without a value owns its value and its "known" flag and
/// points origin_ at itself. Copies inherit that origin_ pointer, so reading
/// or assigning through a copy reads or writes the original Dim. A Dim
/// created with an explicit value has no origin and is known from the start.
class Dim {
public:
    /// Unknown, unnamed dimension.
    explicit Dim() : value_(0), constant_(false), origin_(this) {}

    /// Unknown dimension with a name that is shown in messages.
    explicit Dim(const std::string& name)
        : value_(0), constant_(false), origin_(this), name_(name) {}

    /// Known dimension. Implicit so integer literals can be used inside
    /// dim expressions.
    Dim(int64_t value, const std::string& name = "")
        : value_(value), constant_(true), origin_(nullptr), name_(name) {}

    /// The copy shares \p other's origin, so it aliases the same value.
    Dim(const Dim& other)
        : value_(other.value_),
          constant_(other.constant_),
          origin_(other.origin_),
          name_(other.name_) {}

    ~Dim() {}

    Dim& operator=(const Dim&) = delete;

    /// Value of the origin Dim, or of this Dim if it has no origin.
    int64_t& value() { return origin_ ? origin_->value_ : value_; }

    /// Known-flag of the origin Dim, or of this Dim if it has no origin.
    bool& constant() { return origin_ ? origin_->constant_ : constant_; }

    /// Stores \p a if the dimension is still unknown. Returns whether the
    /// (possibly just stored) value equals \p a.
    bool assign(int64_t a) {
        if (constant()) return value() == a;
        value() = a;
        constant() = true;
        return true;
    }

    /// Returns "name(value)", "name" (if \p show_value is false), or just the
    /// value for unnamed dims. Unknown values are printed as "?".
    std::string ToString(bool show_value = true) {
        const std::string value_str =
                constant() ? std::to_string(value()) : std::string("?");
        if (name_.empty()) return value_str;
        if (!show_value) return name_;
        return name_ + "(" + value_str + ")";
    }

private:
    int64_t value_;
    bool constant_;
    Dim* origin_;
    std::string name_;
};
|
|
|
|
//
|
|
// Dim expression operator classes
|
|
//
|
|
|
|
/// Policy for the "+" operator in dim expressions.
struct DimXPlus {
    /// A sum of known operands is known.
    static bool constant() { return true; }

    static int64_t apply(int64_t a, int64_t b) { return a + b; }

    /// Solves ans == a + b for whichever operand is unknown and assigns it.
    /// \throws std::runtime_error if both operands are unknown.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        const bool a_known = a.constant();
        if (!a_known && !b.constant()) {
            const std::string expression =
                    GetString(a, false) + ToString() + GetString(b, false);
            throw std::runtime_error("Illegal dim expression: " + expression);
        }
        if (a_known) return b.assign(ans - a.value());
        return a.assign(ans - b.value());
    }

    static std::string ToString() { return "+"; }
};
|
|
|
|
/// Policy for the "-" operator in dim expressions.
struct DimXMinus {
    /// A difference of known operands is known.
    static bool constant() { return true; }

    static int64_t apply(int64_t a, int64_t b) { return a - b; }

    /// Solves ans == a - b for whichever operand is unknown and assigns it.
    /// \throws std::runtime_error if both operands are unknown.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        const bool a_known = a.constant();
        if (!a_known && !b.constant()) {
            const std::string expression =
                    GetString(a, false) + ToString() + GetString(b, false);
            throw std::runtime_error("Illegal dim expression: " + expression);
        }
        if (a_known) return b.assign(a.value() - ans);
        return a.assign(ans + b.value());
    }

    static std::string ToString() { return "-"; }
};
|
|
|
|
/// Policy for the "*" operator in dim expressions.
struct DimXMultiply {
    /// A product of known operands is known.
    static bool constant() { return true; }

    static int64_t apply(int64_t a, int64_t b) { return a * b; }

    /// Unknown operands are never inferred through a product.
    /// \throws std::runtime_error always.
    template <class T1, class T2>
    static bool backprop(int64_t /*ans*/, T1 a, T2 b) {
        const std::string expression =
                GetString(a, false) + ToString() + GetString(b, false);
        throw std::runtime_error("Illegal dim expression: " + expression);
    }

    static std::string ToString() { return "*"; }
};
|
|
|
|
/// Policy for the "/" operator in dim expressions.
struct DimXDivide {
    /// A quotient of known operands is known.
    static bool constant() { return true; }

    /// Truncating integer division of two known dims.
    /// \throws std::runtime_error if \p b is zero. Dims can legitimately be
    /// 0, and integer division by zero is undefined behavior, so it is
    /// reported as an error instead.
    static int64_t apply(int64_t a, int64_t b) {
        if (b == 0) {
            throw std::runtime_error("Division by zero in dim expression");
        }
        return a / b;
    }

    /// Unknown operands are never inferred through a quotient.
    /// \throws std::runtime_error always.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        std::string exstr =
                GetString(a, false) + ToString() + GetString(b, false);
        throw std::runtime_error("Illegal dim expression: " + exstr);
        return false;
    }

    static std::string ToString() { return "/"; }
};
|
|
|
|
/// Policy for the "||" operator: the observed value must match either the
/// left or the right operand.
struct DimXOr {
    /// An alternative never evaluates to a single known value.
    static bool constant() { return false; }

    /// \throws std::runtime_error always; an OR expression has no value.
    static int64_t apply(int64_t /*a*/, int64_t /*b*/) {
        throw std::runtime_error("Cannot evaluate OR expression");
    }

    /// Tries to assign \p ans to the left operand first and only falls back
    /// to the right operand if that fails.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        if (a.assign(ans)) return true;
        return b.assign(ans);
    }

    static std::string ToString() { return "||"; }
};
|
|
|
|
/// Dim expression class: a binary node that combines two operands (each a
/// Dim or a nested DimX) using the operator policy TOp.
template <class TLeft, class TRight, class TOp>
class DimX {
public:
    /// Factory. The constructor is private, so expressions are built only
    /// through this function (used by the overloaded operators).
    static DimX<TLeft, TRight, TOp> Create(TLeft left, TRight right) {
        return DimX<TLeft, TRight, TOp>(left, right);
    }

    /// Evaluates the expression. Returns 0 if it is not fully known.
    int64_t value() {
        if (!constant_) return 0;
        return TOp::apply(left_.value(), right_.value());
    }

    bool& constant() { return constant_; }

    /// Assigns \p a to the expression. A known expression is only compared
    /// with \p a; otherwise TOp infers the unknown operand(s) from \p a.
    bool assign(int64_t a) {
        return constant_ ? value() == a : TOp::backprop(a, left_, right_);
    }

    /// Returns the expression text, e.g. "A(3)+4".
    std::string ToString(bool show_value = true) {
        std::string text = left_.ToString(show_value);
        text += TOp::ToString();
        text += right_.ToString(show_value);
        return text;
    }

private:
    // The known-flag is fixed at construction time from the operands and
    // the operator policy.
    DimX(TLeft left, TRight right)
        : left_(left),
          right_(right),
          constant_(left.constant() && right.constant() && TOp::constant()) {}

    TLeft left_;
    TRight right_;
    bool constant_;
};
|
|
|
|
//
|
|
// define operators for dim expressions
|
|
//
|
|
|
|
#define DEFINE_DIMX_OPERATOR(opclass, symbol) \
|
|
inline DimX<Dim, Dim, opclass> operator symbol(Dim a, Dim b) { \
|
|
return DimX<Dim, Dim, opclass>::Create(a, b); \
|
|
} \
|
|
\
|
|
template <class TL, class TR, class TOp> \
|
|
inline DimX<Dim, DimX<TL, TR, TOp>, opclass> operator symbol( \
|
|
Dim a, DimX<TL, TR, TOp>&& b) { \
|
|
return DimX<Dim, DimX<TL, TR, TOp>, opclass>::Create(a, b); \
|
|
} \
|
|
\
|
|
template <class TL, class TR, class TOp> \
|
|
inline DimX<DimX<TL, TR, TOp>, Dim, opclass> operator symbol( \
|
|
DimX<TL, TR, TOp>&& a, Dim b) { \
|
|
return DimX<DimX<TL, TR, TOp>, Dim, opclass>::Create(a, b); \
|
|
} \
|
|
\
|
|
template <class TL1, class TR1, class TOp1, class TL2, class TR2, \
|
|
class TOp2> \
|
|
inline DimX<DimX<TL1, TR1, TOp1>, DimX<TL2, TR2, TOp2>, opclass> \
|
|
operator symbol(DimX<TL1, TR1, TOp1>&& a, DimX<TL2, TR2, TOp2>&& b) { \
|
|
return DimX<DimX<TL1, TR1, TOp1>, DimX<TL2, TR2, TOp2>, \
|
|
opclass>::Create(a, b); \
|
|
}
|
|
|
|
DEFINE_DIMX_OPERATOR(DimXPlus, +)
|
|
DEFINE_DIMX_OPERATOR(DimXMinus, -)
|
|
DEFINE_DIMX_OPERATOR(DimXMultiply, *)
|
|
DEFINE_DIMX_OPERATOR(DimXDivide, /)
|
|
DEFINE_DIMX_OPERATOR(DimXOr, ||)
|
|
#undef DEFINE_DIMX_OPERATOR
|
|
|
|
//
|
|
// define operators for comparing DimValue to dim expressions.
|
|
// Using these operators will try to assign the dim value to the expression.
|
|
//
|
|
|
|
template <class TLeft, class TRight, class TOp>
|
|
inline bool operator==(DimValue a, DimX<TLeft, TRight, TOp>&& b) {
|
|
if (a.constant()) {
|
|
auto b_copy(b);
|
|
return b_copy.assign(a.value());
|
|
} else
|
|
return true;
|
|
}
|
|
|
|
/// Matches a shape entry against a Dim. An unknown entry is compatible with
/// anything. A known entry is assigned to the Dim if the Dim is still unknown,
/// and otherwise compared with it.
inline bool operator==(DimValue a, Dim b) {
    return !a.constant() || b.assign(a.value());
}
|
|
|
|
//
|
|
// some helper classes and functions
|
|
//
|
|
|
|
template <class... args>
|
|
struct CountArgs {
|
|
static const size_t value = sizeof...(args);
|
|
};
|
|
|
|
template <class TLeft, class TRight, class TOp>
|
|
std::string GetString(DimX<TLeft, TRight, TOp> a, bool show_value = true) {
|
|
return a.ToString(show_value);
|
|
}
|
|
|
|
/// Returns the textual form of a single Dim (see Dim::ToString).
inline std::string GetString(Dim dim, bool show_value = true) {
    return dim.ToString(show_value);
}
|
|
|
|
template <class TLeft, class TRight, class TOp>
|
|
int64_t GetValue(DimX<TLeft, TRight, TOp> a) {
|
|
return a.value();
|
|
}
|
|
|
|
template <class TLeft, class TRight, class TOp>
|
|
int64_t GetValue(DimX<TLeft, TRight, TOp> a, int64_t unknown_dim_value) {
|
|
if (a.constant()) {
|
|
return a.value();
|
|
} else {
|
|
return unknown_dim_value;
|
|
}
|
|
return a.value();
|
|
}
|
|
|
|
/// Returns the Dim's current value (the stored value, even if it is unknown).
inline int64_t GetValue(Dim dim) {
    return dim.value();
}
|
|
|
|
/// Returns the Dim's value, or \p unknown_dim_value if it is still unknown.
inline int64_t GetValue(Dim dim, int64_t unknown_dim_value) {
    return dim.constant() ? dim.value() : unknown_dim_value;
}
|
|
|
|
/// Terminal case for an empty list of dim expressions: an empty string.
inline std::string CreateDimXString() {
    return {};
}
|
|
|
|
/// Text for a single dim expression.
template <class TDimX>
std::string CreateDimXString(TDimX dimex) {
    std::string text = GetString(dimex);
    return text;
}
|
|
|
|
/// Comma-separated text for a list of dim expressions.
template <class TDimX, class... TArgs>
std::string CreateDimXString(TDimX dimex, TArgs... args) {
    std::string text = GetString(dimex);
    text += ", ";
    text += CreateDimXString(args...);
    return text;
}
|
|
|
|
/// Appends the value of \p dimex to \p out, using \p unknown_dim_value if
/// the expression is not known.
template <class TDimX>
void CreateDimVector(std::vector<int64_t>& out,
                     int64_t unknown_dim_value,
                     TDimX dimex) {
    const int64_t dim_value = GetValue(dimex, unknown_dim_value);
    out.push_back(dim_value);
}
|
|
|
|
/// Appends the values of all given dim expressions to \p out, in order,
/// using \p unknown_dim_value for expressions that are not known.
template <class TDimX, class... TArgs>
void CreateDimVector(std::vector<int64_t>& out,
                     int64_t unknown_dim_value,
                     TDimX dimex,
                     TArgs... args) {
    CreateDimVector(out, unknown_dim_value, dimex);
    CreateDimVector(out, unknown_dim_value, args...);
}
|
|
|
|
/// Returns a one-element vector holding the value of \p dimex, or
/// \p unknown_dim_value if the expression is not known.
template <class TDimX>
std::vector<int64_t> CreateDimVector(int64_t unknown_dim_value, TDimX dimex) {
    std::vector<int64_t> dims;
    CreateDimVector(dims, unknown_dim_value, dimex);
    return dims;
}
|
|
|
|
/// Returns the values of all given dim expressions, in order, using
/// \p unknown_dim_value for expressions that are not known.
template <class TDimX, class... TArgs>
std::vector<int64_t> CreateDimVector(int64_t unknown_dim_value,
                                     TDimX dimex,
                                     TArgs... args) {
    std::vector<int64_t> dims;
    CreateDimVector(dims, unknown_dim_value, dimex, args...);
    return dims;
}
|
|
|
|
//
|
|
// functions which check if the dim value is compatible with the expression
|
|
//
|
|
|
|
template <class TLeft, class TRight, class TOp>
|
|
bool CheckDim(const DimValue& lhs, DimX<TLeft, TRight, TOp>&& rhs) {
|
|
bool status = (lhs == std::forward<DimX<TLeft, TRight, TOp>>(rhs));
|
|
return status;
|
|
}
|
|
|
|
/// Checks one shape entry against a Dim, assigning the Dim if it is unknown.
inline bool CheckDim(const DimValue& lhs, Dim d) {
    return lhs == d;
}
|
|
|
|
/// Check shape options. They control how a shape with more dims than there
/// are dim expressions is matched (see CheckShape).
enum class CSOpt {
    /// Ranks must match exactly.
    NONE = 0,
    /// Surplus leading dims are multiplied into the first expression.
    COMBINE_FIRST_DIMS = 1,
    /// Surplus leading dims are skipped.
    IGNORE_FIRST_DIMS = 2,
    /// Surplus trailing dims are multiplied into the last expression.
    COMBINE_LAST_DIMS = 3,
    /// Surplus trailing dims are skipped.
    IGNORE_LAST_DIMS = 4
};
|
|
|
|
template <CSOpt Opt = CSOpt::NONE, class TDimX>
|
|
bool _CheckShape(const std::vector<DimValue>& shape, TDimX&& dimex) {
|
|
// check rank
|
|
const int rank_diff = shape.size() - 1;
|
|
if (Opt != CSOpt::NONE) {
|
|
if (rank_diff < 0) {
|
|
return false;
|
|
}
|
|
} else {
|
|
if (rank_diff != 0) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// check dim
|
|
bool status;
|
|
if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
|
|
DimValue s(1);
|
|
for (int i = 0; i < rank_diff + 1; ++i) s *= shape[i];
|
|
status = CheckDim(s, std::forward<TDimX>(dimex));
|
|
} else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
|
|
status = CheckDim(shape[rank_diff], std::forward<TDimX>(dimex));
|
|
} else if (Opt == CSOpt::COMBINE_LAST_DIMS) {
|
|
DimValue s(1);
|
|
for (DimValue x : shape) s *= x;
|
|
status = CheckDim(s, std::forward<TDimX>(dimex));
|
|
} else {
|
|
status = CheckDim(shape[0], std::forward<TDimX>(dimex));
|
|
}
|
|
return status;
|
|
}
|
|
|
|
template <CSOpt Opt = CSOpt::NONE, class TDimX, class... TArgs>
|
|
bool _CheckShape(const std::vector<DimValue>& shape,
|
|
TDimX&& dimex,
|
|
TArgs&&... args) {
|
|
// check rank
|
|
const int rank_diff = shape.size() - (CountArgs<TArgs...>::value + 1);
|
|
if (Opt != CSOpt::NONE) {
|
|
if (rank_diff < 0) {
|
|
return false;
|
|
}
|
|
} else {
|
|
if (rank_diff != 0) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// check dim
|
|
bool status;
|
|
if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
|
|
DimValue s(1);
|
|
for (int i = 0; i < rank_diff + 1; ++i) s *= shape[i];
|
|
status = CheckDim(s, std::forward<TDimX>(dimex));
|
|
} else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
|
|
status = CheckDim(shape[rank_diff], std::forward<TDimX>(dimex));
|
|
} else {
|
|
status = CheckDim(shape[0], std::forward<TDimX>(dimex));
|
|
}
|
|
|
|
const int offset = 1 + (Opt == CSOpt::COMBINE_FIRST_DIMS ||
|
|
Opt == CSOpt::IGNORE_FIRST_DIMS
|
|
? rank_diff
|
|
: 0);
|
|
std::vector<DimValue> shape2(shape.begin() + offset, shape.end());
|
|
bool status2 = _CheckShape<Opt>(shape2, std::forward<TArgs>(args)...);
|
|
|
|
return status && status2;
|
|
}
|
|
|
|
/// Function for checking a shape with dim expressions.
|
|
/// Usage example:
|
|
///
|
|
/// Dim depth("depth");
|
|
/// Dim height("height");
|
|
/// Dim width("width");
|
|
/// status = CheckShape({30,40}, height, width); // VALID, will assign values
|
|
/// // to height and width
|
|
///
|
|
/// status = CheckShape({50,41}, height+20, width+1); // VALID, values match
|
|
/// status = CheckShape({20,30,40}, depth+10, height, width); // VALID, will
|
|
/// // assign 10 to depth
|
|
///
|
|
/// status = CheckShape({0},depth||0); // VALID, shape must match depth or 0
|
|
/// status = CheckShape({10}, depth||0); // VALID, shape must match depth or 0
|
|
/// status = CheckShape({123,10}, Dim(), depth); // VALID, first dim may be
|
|
/// // anything
|
|
///
|
|
/// status = CheckShape<CSOpt::COMBINE_LAST_DIMS>({123,10,4}, Dim(), width);
|
|
/// // VALID, width==40==10*4
|
|
///
|
|
/// status = CheckShape<CSOpt::COMBINE_FIRST_DIMS>(
|
|
/// {10,2,2,123,456}, width, Dim(), Dim());
|
|
/// // VALID, width==40==10*2*2
|
|
///
|
|
/// status = CheckShape({70}, height+width); // VALID, works because height
|
|
/// // and width have been initialized since the first call to CheckShape
|
|
///
|
|
/// status = CheckShape({1,2,3}, Dim(), Dim()); // INVALID, rank mismatch 3vs2
|
|
/// status = CheckShape({1,2,3}, depth, width, height); // INVALID, at least
|
|
/// // one dim does not match
|
|
///
|
|
/// The options CSOpt::COMBINE_FIRST_DIMS and CSOpt::COMBINE_LAST_DIMS allow
|
|
/// to match the rank of the dim expressions by combining the shape
|
|
/// dimensions at the beginning or end.
|
|
/// The options CSOpt::IGNORE_FIRST_DIMS and CSOpt::IGNORE_LAST_DIMS allow to
|
|
/// ignore additional dimensions in the shape.
|
|
///
|
|
/// The shape to be checked may contain unknowns
|
|
/// Dim A("A");
|
|
/// Dim B("B");
|
|
/// status = CheckShape({30, UnknownValue()}, A, B); // VALID, A is 30 and B
|
|
/// // is still unknown
|
|
///
|
|
/// status =
|
|
/// CheckShape<CSOpt::COMBINE_LAST_DIMS>({30,1,2,UnknownValue()},A,B);
|
|
/// // VALID, A is 30 and B is still unknown
|
|
///
|
|
/// The following shows some limitations of the dim expressions
|
|
/// Dim A("A");
|
|
/// Dim B("B");
|
|
/// status = CheckShape({30}, A+B); // THROWS EXCEPTION, illegal expression
|
|
/// // because neither A or B is a constant
|
|
///
|
|
/// However, the following will work
|
|
/// Dim A(20,"A");
|
|
/// Dim B("B");
|
|
/// status = CheckShape({30}, A+B); // VALID, B is now 10
|
|
///
|
|
/// This will work, too
|
|
/// Dim A("A"); // uninitialized
|
|
/// Dim B("B");
|
|
/// status = CheckShape({20}, A); // VALID, A is now 20
|
|
/// status = CheckShape({30}, A+B); // VALID, B is now 10
|
|
///
|
|
/// Multiplication and division are not allowed for unknown dims
|
|
/// Dim A("A");
|
|
/// status = CheckShape({30}, 3*A); // THROWS EXCEPTION, although expression
|
|
/// // seems reasonable
|
|
/// status = CheckShape({20}, 3*A); // THROWS EXCEPTION, this
|
|
/// // is the reason why mul/div is only allowed for known dims
|
|
///
|
|
/// Important, do not create objects of dim expressions, i.e.,
|
|
/// auto dimx = Dim("tmp") + 3;
|
|
/// status = CheckShape({20}, dimx); // intended to not compile
|
|
/// Assigning a value to dimx will assign a value to Dim("tmp") which has a
|
|
/// shorter lifetime.
|
|
///
|
|
/// The return value is a tuple <bool,std::string>. If the bool is false
|
|
/// then the shape is INVALID and the string contains an error message of the
|
|
/// form "got [shape], expected [dim expressions]".
|
|
/// If true then the shape is VALID and the error string is empty.
|
|
///
|
|
/// Note the goal of this function is to simplify checking tensor shapes. There
|
|
/// may be cases where shapes cannot be checked with the provided functionality
|
|
/// and you have to write custom shape checking code.
|
|
///
|
|
/// \param shape This is the actual shape of an object.
|
|
/// \param args This is a list of dim expression
|
|
///
|
|
template <CSOpt Opt = CSOpt::NONE, class TDimX, class... TArgs>
|
|
std::tuple<bool, std::string> CheckShape(const std::vector<DimValue>& shape,
|
|
TDimX&& dimex,
|
|
TArgs&&... args) {
|
|
const bool status = _CheckShape<Opt>(shape, std::forward<TDimX>(dimex),
|
|
std::forward<TArgs>(args)...);
|
|
if (status) {
|
|
return std::make_tuple(status, std::string());
|
|
} else {
|
|
const int rank_diff = shape.size() - (CountArgs<TArgs...>::value + 1);
|
|
|
|
// generate string for the actual shape. This is a bit involved because
|
|
// of the many options.
|
|
std::string shape_str;
|
|
if (rank_diff <= 0) {
|
|
shape_str = "[";
|
|
for (int i = 0; i < int(shape.size()); ++i) {
|
|
shape_str += shape[i].ToString();
|
|
if (i + 1 < int(shape.size())) shape_str += ", ";
|
|
}
|
|
shape_str += "]";
|
|
} else {
|
|
if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
|
|
shape_str += "[";
|
|
for (int i = 0; i < rank_diff; ++i) {
|
|
shape_str += shape[i].ToString();
|
|
if (i + 1 < int(shape.size())) shape_str += "*";
|
|
}
|
|
} else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
|
|
shape_str += "(";
|
|
for (int i = 0; i < rank_diff; ++i) {
|
|
shape_str += shape[i].ToString();
|
|
if (i + 1 < rank_diff) shape_str += ", ";
|
|
}
|
|
shape_str += ")[";
|
|
} else {
|
|
shape_str = "[";
|
|
}
|
|
int start = 0;
|
|
if (Opt == CSOpt::COMBINE_FIRST_DIMS ||
|
|
Opt == CSOpt::IGNORE_FIRST_DIMS) {
|
|
start = rank_diff;
|
|
}
|
|
|
|
int end = shape.size();
|
|
if (Opt == CSOpt::COMBINE_LAST_DIMS) {
|
|
end -= rank_diff + 1;
|
|
} else if (Opt == CSOpt::IGNORE_LAST_DIMS) {
|
|
end -= rank_diff;
|
|
}
|
|
for (int i = start; i < end; ++i) {
|
|
shape_str += shape[i].ToString();
|
|
if (i + 1 < end) shape_str += ", ";
|
|
}
|
|
if (Opt == CSOpt::COMBINE_LAST_DIMS) {
|
|
shape_str += ", ";
|
|
for (int i = std::max<int>(0, shape.size() - rank_diff - 1);
|
|
i < int(shape.size()); ++i) {
|
|
shape_str += shape[i].ToString();
|
|
if (i + 1 < int(shape.size())) shape_str += "*";
|
|
}
|
|
shape_str += "]";
|
|
} else if (Opt == CSOpt::IGNORE_LAST_DIMS) {
|
|
shape_str += "](";
|
|
for (int i = std::max<int>(0, shape.size() - rank_diff);
|
|
i < int(shape.size()); ++i) {
|
|
shape_str += shape[i].ToString();
|
|
if (i + 1 < int(shape.size())) shape_str += ", ";
|
|
}
|
|
shape_str += ")";
|
|
} else {
|
|
shape_str += "]";
|
|
}
|
|
}
|
|
|
|
// generate string for the expected shape with the dim expressions
|
|
std::string expected_shape;
|
|
if ((CountArgs<TArgs...>::value + 1) == 1) {
|
|
expected_shape = "[" + GetString(dimex) + "]";
|
|
|
|
} else {
|
|
expected_shape = "[" + GetString(dimex) + ", " +
|
|
CreateDimXString(args...) + "]";
|
|
}
|
|
|
|
std::string errstr;
|
|
// print rank information if there is a problem with the rank
|
|
if ((Opt != CSOpt::NONE && rank_diff < 0) ||
|
|
(Opt == CSOpt::NONE && rank_diff != 0)) {
|
|
errstr = "got rank " + std::to_string(shape.size()) + " " +
|
|
shape_str + ", expected rank " +
|
|
std::to_string(CountArgs<TArgs...>::value + 1) + " " +
|
|
expected_shape;
|
|
} else { // rank is OK print just the shapes
|
|
errstr = "got " + shape_str + ", expected " + expected_shape;
|
|
}
|
|
return std::make_tuple(status, errstr);
|
|
}
|
|
}
|
|
|
|
} // namespace op_util
|
|
} // namespace ml
|
|
} // namespace open3d
|