diff --git a/src/ceres_wrapper/ceres_manager.cpp b/src/ceres_wrapper/ceres_manager.cpp index 13cb365c709d57f3ef10d845b697158d5f912a6e..bbd383c66f3ea4edbe0a3caadeebbd5af47d2a02 100644 --- a/src/ceres_wrapper/ceres_manager.cpp +++ b/src/ceres_wrapper/ceres_manager.cpp @@ -303,6 +303,7 @@ void CeresManager::removeStateBlock(const StateBlockPtr& state_ptr) { assert(state_ptr); ceres_problem_->RemoveParameterBlock(getAssociatedMemBlockPtr(state_ptr)); + state_blocks_local_param_.erase(state_ptr); } void CeresManager::updateStateBlockStatus(const StateBlockPtr& state_ptr) @@ -317,9 +318,37 @@ void CeresManager::updateStateBlockStatus(const StateBlockPtr& state_ptr) void CeresManager::updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr) { assert(state_ptr != nullptr); - // in ceres the easiest way to change (add or remove) a local parameterization is remove&add (The associated memory block is and MUST be the same) + + /* in ceres the easiest way to update (add or remove) a local parameterization + * of a state block (parameter block in ceres) is remove & add: + * - the state block: The associated memory block (that identified the parameter_block) is and MUST be the same + * - all involved constraints (residual_blocks in ceres) + */ + + // get all involved constraints + ConstraintBaseList involved_constraints; + for (auto pair : ctr_2_costfunction_) + for (const StateBlockPtr& st : pair.first->getStateBlockPtrVector()) + if (st == state_ptr) + { + // store + involved_constraints.push_back(pair.first); + break; + } + + // Remove all involved constraints (it does not remove any parameter block) + for (auto ctr : involved_constraints) + removeConstraint(ctr); + + // Remove state block (it removes all involved residual blocks but they just were removed) removeStateBlock(state_ptr); + + // Add state block addStateBlock(state_ptr); + + // Add all involved constraints + for (auto ctr : involved_constraints) + addConstraint(ctr); } ceres::CostFunctionPtr CeresManager::createCostFunction(const ConstraintBasePtr& _ctr_ptr) @@ -357,7 +386,7 @@ void CeresManager::check() assert(ctr_2_costfunction_[ctr_res_pair.first].get() == ceres_problem_->GetCostFunctionForResidualBlock(ctr_res_pair.second)); // constraint - residual - assert(ctr_res_pair.first == static_cast<const CostFunctionWrapper*>(ceres_problem_->GetCostFunctionForResidualBlock(ctr_res_pair.second))->constraint_ptr_); + assert(ctr_res_pair.first == static_cast<const CostFunctionWrapper*>(ceres_problem_->GetCostFunctionForResidualBlock(ctr_res_pair.second))->getConstraintPtr()); // parameter blocks - state blocks std::vector<Scalar*> param_blocks; diff --git a/src/ceres_wrapper/ceres_manager.h b/src/ceres_wrapper/ceres_manager.h index e9eeb93eba41c90f1f83d232777b2a3b7d8d0f94..ebc8b04e945fb3007aa1e6c6926b450c84211746 100644 --- a/src/ceres_wrapper/ceres_manager.h +++ b/src/ceres_wrapper/ceres_manager.h @@ -55,6 +55,8 @@ class CeresManager : public SolverManager ceres::Solver::Options& getSolverOptions(); + void check(); + private: std::string solveImpl(const ReportVerbosity report_level) override; @@ -72,8 +74,6 @@ class CeresManager : public SolverManager void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr) override; ceres::CostFunctionPtr createCostFunction(const ConstraintBasePtr& _ctr_ptr); - - void check(); }; inline ceres::Solver::Summary CeresManager::getSummary() diff --git a/src/test/gtest_ceres_manager.cpp b/src/test/gtest_ceres_manager.cpp index a65afe7671c8f28f8149a09ca6dfd483d82464bc..0df86c75e583f51f49a4dd918c79337be92a5e64 100644 --- a/src/test/gtest_ceres_manager.cpp +++ b/src/test/gtest_ceres_manager.cpp @@ -13,8 +13,11 @@ #include "../state_block.h" #include "../capture_void.h" #include "../constraint_pose_2D.h" +#include "../constraint_quaternion_absolute.h" #include "../solver/solver_manager.h" #include "../ceres_wrapper/ceres_manager.h" +#include "../local_parametrization_angle.h" +#include "../local_parametrization_quaternion.h" #include "ceres/ceres.h" @@ -52,11 +55,28 @@ class CeresManagerWrapper : public CeresManager return ceres_problem_->NumParameterBlocks(); }; + int numConstraints() + { + return ceres_problem_->NumResidualBlocks(); + }; + bool isConstraintRegistered(const ConstraintBasePtr& ctr_ptr) const { return ctr_2_residual_idx_.find(ctr_ptr) != ctr_2_residual_idx_.end() && ctr_2_costfunction_.find(ctr_ptr) != ctr_2_costfunction_.end(); }; + bool hasThisLocalParametrization(const StateBlockPtr& st, const LocalParametrizationBasePtr& local_param) + { + return state_blocks_local_param_.find(st) != state_blocks_local_param_.end() && + state_blocks_local_param_.at(st)->getLocalParametrizationPtr() == local_param && + ceres_problem_->GetParameterization(getAssociatedMemBlockPtr(st)) == state_blocks_local_param_.at(st).get(); + }; + + bool hasLocalParametrization(const StateBlockPtr& st) const + { + return state_blocks_local_param_.find(st) != state_blocks_local_param_.end(); + }; + }; TEST(CeresManager, Create) @@ -66,6 +86,9 @@ TEST(CeresManager, Create) // check double ointers to branches ASSERT_EQ(P, ceres_manager_ptr->getProblemPtr()); + + // run ceres manager check + ceres_manager_ptr->check(); } TEST(CeresManager, AddStateBlock) @@ -86,6 +109,9 @@ TEST(CeresManager, AddStateBlock) // check stateblock ASSERT_TRUE(ceres_manager_ptr->isStateBlockRegisteredSolverManager(sb_ptr)); ASSERT_TRUE(ceres_manager_ptr->isStateBlockRegisteredCeresManager(sb_ptr)); + + // run ceres manager check + ceres_manager_ptr->check(); } TEST(CeresManager, DoubleAddStateBlock) @@ -112,6 +138,9 @@ TEST(CeresManager, DoubleAddStateBlock) // check stateblock ASSERT_TRUE(ceres_manager_ptr->isStateBlockRegisteredSolverManager(sb_ptr)); ASSERT_TRUE(ceres_manager_ptr->isStateBlockRegisteredCeresManager(sb_ptr)); + + // run ceres manager check + ceres_manager_ptr->check(); } TEST(CeresManager, UpdateStateBlock) @@ -140,6 +169,9 @@ TEST(CeresManager, UpdateStateBlock) // check stateblock fixed ASSERT_TRUE(ceres_manager_ptr->isStateBlockFixed(sb_ptr)); + + // run ceres manager check + ceres_manager_ptr->check(); } TEST(CeresManager, AddUpdateStateBlock) @@ -164,6 +196,9 @@ TEST(CeresManager, AddUpdateStateBlock) ASSERT_TRUE(ceres_manager_ptr->isStateBlockRegisteredSolverManager(sb_ptr)); ASSERT_TRUE(ceres_manager_ptr->isStateBlockRegisteredCeresManager(sb_ptr)); ASSERT_TRUE(ceres_manager_ptr->isStateBlockFixed(sb_ptr)); + + // run ceres manager check + ceres_manager_ptr->check(); } TEST(CeresManager, RemoveStateBlock) @@ -192,7 +227,10 @@ TEST(CeresManager, RemoveStateBlock) // check stateblock ASSERT_FALSE(ceres_manager_ptr->isStateBlockRegisteredSolverManager(sb_ptr)); - ASSERT_TRUE(ceres_manager_ptr->numStateBlocks() == 0); + ASSERT_EQ(ceres_manager_ptr->numStateBlocks(), 0); + + // run ceres manager check + ceres_manager_ptr->check(); } TEST(CeresManager, AddRemoveStateBlock) @@ -215,7 +253,10 @@ TEST(CeresManager, AddRemoveStateBlock) // check no stateblocks ASSERT_FALSE(ceres_manager_ptr->isStateBlockRegisteredSolverManager(sb_ptr)); - ASSERT_TRUE(ceres_manager_ptr->numStateBlocks() == 0); + ASSERT_EQ(ceres_manager_ptr->numStateBlocks(), 0); + + // run ceres manager check + ceres_manager_ptr->check(); } TEST(CeresManager, RemoveUpdateStateBlock) @@ -238,6 +279,9 @@ TEST(CeresManager, RemoveUpdateStateBlock) // update solver ceres_manager_ptr->update(); + + // run ceres manager check + ceres_manager_ptr->check(); } TEST(CeresManager, DoubleRemoveStateBlock) @@ -263,6 +307,9 @@ TEST(CeresManager, DoubleRemoveStateBlock) // update solver manager ceres_manager_ptr->update(); + + // run ceres manager check + ceres_manager_ptr->check(); } TEST(CeresManager, AddConstraint) @@ -270,9 +317,27 @@ TEST(CeresManager, AddConstraint) ProblemPtr P = Problem::create("PO 2D"); CeresManagerWrapperPtr ceres_manager_ptr = std::make_shared<CeresManagerWrapper>(P); - // Create State block - Vector2s state; state << 1, 2; - StateBlockPtr sb_ptr = std::make_shared<StateBlock>(state); + // Create (and add) constraint point 2d + FrameBasePtr F = P->emplaceFrame(KEY_FRAME, P->zeroState(), TimeStamp(0)); + CaptureBasePtr C = F->addCapture(std::make_shared<CaptureVoid>(0, nullptr)); + FeatureBasePtr f = C->addFeature(std::make_shared<FeatureBase>("ODOM 2D", Vector3s::Zero(), Matrix3s::Identity())); + ConstraintPose2DPtr c = std::static_pointer_cast<ConstraintPose2D>(f->addConstraint(std::make_shared<ConstraintPose2D>(f))); + + // update solver + ceres_manager_ptr->update(); + + // check constraint + ASSERT_TRUE(ceres_manager_ptr->isConstraintRegistered(c)); + ASSERT_EQ(ceres_manager_ptr->numConstraints(), 1); + + // run ceres manager check + ceres_manager_ptr->check(); +} + +TEST(CeresManager, DoubleAddConstraint) +{ + ProblemPtr P = Problem::create("PO 2D"); + CeresManagerWrapperPtr ceres_manager_ptr = std::make_shared<CeresManagerWrapper>(P); // Create (and add) constraint point 2d FrameBasePtr F = P->emplaceFrame(KEY_FRAME, P->zeroState(), TimeStamp(0)); @@ -280,11 +345,18 @@ TEST(CeresManager, AddConstraint) FeatureBasePtr f = C->addFeature(std::make_shared<FeatureBase>("ODOM 2D", Vector3s::Zero(), Matrix3s::Identity())); ConstraintPose2DPtr c = std::static_pointer_cast<ConstraintPose2D>(f->addConstraint(std::make_shared<ConstraintPose2D>(f))); + // add constraint again + P->addConstraint(c); + // update solver ceres_manager_ptr->update(); // check constraint ASSERT_TRUE(ceres_manager_ptr->isConstraintRegistered(c)); + ASSERT_EQ(ceres_manager_ptr->numConstraints(), 1); + + // run ceres manager check + ceres_manager_ptr->check(); } TEST(CeresManager, RemoveConstraint) @@ -292,10 +364,6 @@ TEST(CeresManager, RemoveConstraint) ProblemPtr P = Problem::create("PO 2D"); CeresManagerWrapperPtr ceres_manager_ptr = std::make_shared<CeresManagerWrapper>(P); - // Create State block - Vector2s state; state << 1, 2; - StateBlockPtr sb_ptr = std::make_shared<StateBlock>(state); - // Create (and add) constraint point 2d FrameBasePtr F = P->emplaceFrame(KEY_FRAME, P->zeroState(), TimeStamp(0)); CaptureBasePtr C = F->addCapture(std::make_shared<CaptureVoid>(0, nullptr)); @@ -305,7 +373,7 @@ TEST(CeresManager, RemoveConstraint) // update solver ceres_manager_ptr->update(); - // add constraint + // remove constraint P->removeConstraint(c); // update solver @@ -313,6 +381,10 @@ TEST(CeresManager, RemoveConstraint) // check constraint ASSERT_FALSE(ceres_manager_ptr->isConstraintRegistered(c)); + ASSERT_EQ(ceres_manager_ptr->numConstraints(), 0); + + // run ceres manager check + ceres_manager_ptr->check(); } TEST(CeresManager, AddRemoveConstraint) @@ -320,10 +392,6 @@ TEST(CeresManager, AddRemoveConstraint) ProblemPtr P = Problem::create("PO 2D"); CeresManagerWrapperPtr ceres_manager_ptr = std::make_shared<CeresManagerWrapper>(P); - // Create State block - Vector2s state; state << 1, 2; - StateBlockPtr sb_ptr = std::make_shared<StateBlock>(state); - // Create (and add) constraint point 2d FrameBasePtr F = P->emplaceFrame(KEY_FRAME, P->zeroState(), TimeStamp(0)); CaptureBasePtr C = F->addCapture(std::make_shared<CaptureVoid>(0, nullptr)); @@ -332,7 +400,7 @@ TEST(CeresManager, AddRemoveConstraint) ASSERT_TRUE(P->getConstraintNotificationMap().begin()->first == c); - // add constraint + // remove constraint P->removeConstraint(c); ASSERT_TRUE(P->getConstraintNotificationMap().empty()); @@ -342,6 +410,10 @@ TEST(CeresManager, AddRemoveConstraint) // check constraint ASSERT_FALSE(ceres_manager_ptr->isConstraintRegistered(c)); + ASSERT_EQ(ceres_manager_ptr->numConstraints(), 0); + + // run ceres manager check + ceres_manager_ptr->check(); } TEST(CeresManager, DoubleRemoveConstraint) @@ -349,10 +421,6 @@ TEST(CeresManager, DoubleRemoveConstraint) ProblemPtr P = Problem::create("PO 2D"); CeresManagerWrapperPtr ceres_manager_ptr = std::make_shared<CeresManagerWrapper>(P); - // Create State block - Vector2s state; state << 1, 2; - StateBlockPtr sb_ptr = std::make_shared<StateBlock>(state); - // Create (and add) constraint point 2d FrameBasePtr F = P->emplaceFrame(KEY_FRAME, P->zeroState(), TimeStamp(0)); CaptureBasePtr C = F->addCapture(std::make_shared<CaptureVoid>(0, nullptr)); @@ -377,8 +445,191 @@ TEST(CeresManager, DoubleRemoveConstraint) // check constraint ASSERT_FALSE(ceres_manager_ptr->isConstraintRegistered(c)); + ASSERT_EQ(ceres_manager_ptr->numConstraints(), 0); + + // run ceres manager check + ceres_manager_ptr->check(); +} + +TEST(CeresManager, AddStateBlockLocalParam) +{ + ProblemPtr P = Problem::create("PO 2D"); + CeresManagerWrapperPtr ceres_manager_ptr = std::make_shared<CeresManagerWrapper>(P); + + // Create State block + Vector1s state; state << 1; + StateBlockPtr sb_ptr = std::make_shared<StateBlock>(state); + + // Local param + LocalParametrizationBasePtr local_param_ptr = std::make_shared<LocalParametrizationAngle>(); + sb_ptr->setLocalParametrizationPtr(local_param_ptr); + + // add stateblock + P->addStateBlock(sb_ptr); + + // update solver + ceres_manager_ptr->update(); + + // check stateblock + ASSERT_TRUE(ceres_manager_ptr->hasLocalParametrization(sb_ptr)); + ASSERT_TRUE(ceres_manager_ptr->hasThisLocalParametrization(sb_ptr,local_param_ptr)); + + // run ceres manager check + ceres_manager_ptr->check(); } +TEST(CeresManager, RemoveLocalParam) +{ + ProblemPtr P = Problem::create("PO 2D"); + CeresManagerWrapperPtr ceres_manager_ptr = std::make_shared<CeresManagerWrapper>(P); + + // Create State block + Vector1s state; state << 1; + StateBlockPtr sb_ptr = std::make_shared<StateBlock>(state); + + // Local param + LocalParametrizationBasePtr local_param_ptr = std::make_shared<LocalParametrizationAngle>(); + sb_ptr->setLocalParametrizationPtr(local_param_ptr); + + // add stateblock + P->addStateBlock(sb_ptr); + + // update solver + ceres_manager_ptr->update(); + + // Remove local param + sb_ptr->removeLocalParametrization(); + + // update solver + ceres_manager_ptr->update(); + + // check stateblock + ASSERT_FALSE(ceres_manager_ptr->hasLocalParametrization(sb_ptr)); + + // run ceres manager check + ceres_manager_ptr->check(); +} + +TEST(CeresManager, AddLocalParam) +{ + ProblemPtr P = Problem::create("PO 2D"); + CeresManagerWrapperPtr ceres_manager_ptr = std::make_shared<CeresManagerWrapper>(P); + + // Create State block + Vector1s state; state << 1; + StateBlockPtr sb_ptr = std::make_shared<StateBlock>(state); + + // add stateblock + P->addStateBlock(sb_ptr); + + // update solver + ceres_manager_ptr->update(); + + // check stateblock + ASSERT_FALSE(ceres_manager_ptr->hasLocalParametrization(sb_ptr)); + + // Local param + LocalParametrizationBasePtr local_param_ptr = std::make_shared<LocalParametrizationAngle>(); + sb_ptr->setLocalParametrizationPtr(local_param_ptr); + + // update solver + ceres_manager_ptr->update(); + + // check stateblock + ASSERT_TRUE(ceres_manager_ptr->hasLocalParametrization(sb_ptr)); + ASSERT_TRUE(ceres_manager_ptr->hasThisLocalParametrization(sb_ptr,local_param_ptr)); + + // run ceres manager check + ceres_manager_ptr->check(); +} + +TEST(CeresManager, ConstraintsRemoveLocalParam) +{ + ProblemPtr P = Problem::create("PO 3D"); + CeresManagerWrapperPtr ceres_manager_ptr = std::make_shared<CeresManagerWrapper>(P); + + // Create (and add) 2 constraints quaternion + FrameBasePtr F = P->emplaceFrame(KEY_FRAME, P->zeroState(), TimeStamp(0)); + CaptureBasePtr C = F->addCapture(std::make_shared<CaptureVoid>(0, nullptr)); + FeatureBasePtr f = C->addFeature(std::make_shared<FeatureBase>("ODOM 2D", Vector3s::Zero(), Matrix3s::Identity())); + ConstraintQuaternionAbsolutePtr c1 = std::static_pointer_cast<ConstraintQuaternionAbsolute>(f->addConstraint(std::make_shared<ConstraintQuaternionAbsolute>(F->getOPtr()))); + ConstraintQuaternionAbsolutePtr c2 = std::static_pointer_cast<ConstraintQuaternionAbsolute>(f->addConstraint(std::make_shared<ConstraintQuaternionAbsolute>(F->getOPtr()))); + + // update solver + ceres_manager_ptr->update(); + + // check local param + ASSERT_TRUE(ceres_manager_ptr->hasLocalParametrization(F->getOPtr())); + ASSERT_TRUE(ceres_manager_ptr->hasThisLocalParametrization(F->getOPtr(),F->getOPtr()->getLocalParametrizationPtr())); + + // check constraint + ASSERT_TRUE(ceres_manager_ptr->isConstraintRegistered(c1)); + ASSERT_TRUE(ceres_manager_ptr->isConstraintRegistered(c2)); + ASSERT_EQ(ceres_manager_ptr->numConstraints(), 2); + + // remove local param + F->getOPtr()->removeLocalParametrization(); + + // update solver + ceres_manager_ptr->update(); + + // check local param + ASSERT_FALSE(ceres_manager_ptr->hasLocalParametrization(F->getOPtr())); + + // check constraint + ASSERT_TRUE(ceres_manager_ptr->isConstraintRegistered(c1)); + ASSERT_TRUE(ceres_manager_ptr->isConstraintRegistered(c2)); + ASSERT_EQ(ceres_manager_ptr->numConstraints(), 2); + + // run ceres manager check + ceres_manager_ptr->check(); +} + +TEST(CeresManager, ConstraintsUpdateLocalParam) +{ + ProblemPtr P = Problem::create("PO 3D"); + CeresManagerWrapperPtr ceres_manager_ptr = std::make_shared<CeresManagerWrapper>(P); + + // Create (and add) 2 constraints quaternion + FrameBasePtr F = P->emplaceFrame(KEY_FRAME, P->zeroState(), TimeStamp(0)); + CaptureBasePtr C = F->addCapture(std::make_shared<CaptureVoid>(0, nullptr)); + FeatureBasePtr f = C->addFeature(std::make_shared<FeatureBase>("ODOM 2D", Vector3s::Zero(), Matrix3s::Identity())); + ConstraintQuaternionAbsolutePtr c1 = std::static_pointer_cast<ConstraintQuaternionAbsolute>(f->addConstraint(std::make_shared<ConstraintQuaternionAbsolute>(F->getOPtr()))); + ConstraintQuaternionAbsolutePtr c2 = std::static_pointer_cast<ConstraintQuaternionAbsolute>(f->addConstraint(std::make_shared<ConstraintQuaternionAbsolute>(F->getOPtr()))); + + // update solver + ceres_manager_ptr->update(); + + // check local param + ASSERT_TRUE(ceres_manager_ptr->hasLocalParametrization(F->getOPtr())); + ASSERT_TRUE(ceres_manager_ptr->hasThisLocalParametrization(F->getOPtr(),F->getOPtr()->getLocalParametrizationPtr())); + + // check constraint + ASSERT_TRUE(ceres_manager_ptr->isConstraintRegistered(c1)); + ASSERT_TRUE(ceres_manager_ptr->isConstraintRegistered(c2)); + ASSERT_EQ(ceres_manager_ptr->numConstraints(), 2); + + // remove local param + LocalParametrizationBasePtr local_param_ptr = std::make_shared<LocalParametrizationQuaternionGlobal>(); + F->getOPtr()->setLocalParametrizationPtr(local_param_ptr); + + // update solver + ceres_manager_ptr->update(); + + // check local param + ASSERT_TRUE(ceres_manager_ptr->hasLocalParametrization(F->getOPtr())); + ASSERT_TRUE(ceres_manager_ptr->hasThisLocalParametrization(F->getOPtr(),local_param_ptr)); + + // check constraint + ASSERT_TRUE(ceres_manager_ptr->isConstraintRegistered(c1)); + ASSERT_TRUE(ceres_manager_ptr->isConstraintRegistered(c2)); + ASSERT_EQ(ceres_manager_ptr->numConstraints(), 2); + + // run ceres manager check + ceres_manager_ptr->check(); +} + + int main(int argc, char **argv) { testing::InitGoogleTest(&argc, argv);