From 74168f92220d3a06d5938e1e66c95b30cfaa7027 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joan=20Vallv=C3=A9=20Navarro?= <jvallve@iri.upc.edu> Date: Wed, 15 May 2019 11:15:33 +0200 Subject: [PATCH] checkers for registered factors and state blocks --- include/core/ceres_wrapper/ceres_manager.h | 18 +++++++++++++++++- include/core/solver/solver_manager.h | 8 ++++++-- src/solver/solver_manager.cpp | 10 ++++++++++ 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/include/core/ceres_wrapper/ceres_manager.h b/include/core/ceres_wrapper/ceres_manager.h index 0454c140e..e994980fc 100644 --- a/include/core/ceres_wrapper/ceres_manager.h +++ b/include/core/ceres_wrapper/ceres_manager.h @@ -72,7 +72,7 @@ class CeresManager : public SolverManager void check(); - private: + protected: std::string solveImpl(const ReportVerbosity report_level) override; @@ -89,6 +89,10 @@ class CeresManager : public SolverManager void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr) override; ceres::CostFunctionPtr createCostFunction(const FactorBasePtr& _fac_ptr); + + virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr); + + virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr); }; inline ceres::Solver::Summary CeresManager::getSummary() @@ -101,6 +105,18 @@ inline ceres::Solver::Options& CeresManager::getSolverOptions() return ceres_options_; } +inline bool CeresManager::isFactorRegistered(const FactorBasePtr& fac_ptr) +{ + return fac_2_residual_idx_.find(fac_ptr) != fac_2_residual_idx_.end() + && fac_2_costfunction_.find(fac_ptr) != fac_2_costfunction_.end(); +} + +inline bool CeresManager::isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) +{ + return state_blocks_local_param_.find(state_ptr) != state_blocks_local_param_.end() + && ceres_problem_->HasParameterBlock(getAssociatedMemBlockPtr(state_ptr)); +} + } // namespace wolf #endif diff --git a/include/core/solver/solver_manager.h b/include/core/solver/solver_manager.h index 240e563eb..049cb951d 100644 --- a/include/core/solver/solver_manager.h +++ b/include/core/solver/solver_manager.h @@ -67,9 +67,9 @@ public: ProblemPtr getProblem(); - virtual bool isRegistered(const StateBlockPtr& state_ptr) = 0; + virtual bool isStateBlockRegistered(const StateBlockPtr& state_ptr); - virtual bool isRegistered(const FactorBasePtr& fac_ptr) = 0; + virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr); protected: @@ -91,6 +91,10 @@ protected: virtual void updateStateBlockStatus(const StateBlockPtr& state_ptr) = 0; virtual void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr) = 0; + + virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) = 0; + + virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) = 0; }; } // namespace wolf diff --git a/src/solver/solver_manager.cpp b/src/solver/solver_manager.cpp index fa645762e..1a15d9c17 100644 --- a/src/solver/solver_manager.cpp +++ b/src/solver/solver_manager.cpp @@ -160,4 +160,14 @@ Scalar* SolverManager::getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) return it->second.data(); } +bool SolverManager::isStateBlockRegistered(const StateBlockPtr& state_ptr) +{ + return state_blocks_.find(state_ptr) != state_blocks_.end() && isStateBlockRegisteredDerived(); +} + +bool SolverManager::isFactorRegistered(const FactorBasePtr& fac_ptr) +{ + return isFactorRegisteredDerived(); +} + } // namespace wolf -- GitLab