diff --git a/include/core/ceres_wrapper/ceres_manager.h b/include/core/ceres_wrapper/ceres_manager.h index 0454c140eda8ed520a04625fa196cebd41fbbcdd..e994980fc20e1035e54075184ff3a6a2b18a58b0 100644 --- a/include/core/ceres_wrapper/ceres_manager.h +++ b/include/core/ceres_wrapper/ceres_manager.h @@ -72,7 +72,7 @@ class CeresManager : public SolverManager void check(); - private: + protected: std::string solveImpl(const ReportVerbosity report_level) override; @@ -89,6 +89,10 @@ class CeresManager : public SolverManager void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr) override; ceres::CostFunctionPtr createCostFunction(const FactorBasePtr& _fac_ptr); + + virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr); + + virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr); }; inline ceres::Solver::Summary CeresManager::getSummary() @@ -101,6 +105,18 @@ inline ceres::Solver::Options& CeresManager::getSolverOptions() return ceres_options_; } +inline bool CeresManager::isFactorRegistered(const FactorBasePtr& fac_ptr) +{ + return fac_2_residual_idx_.find(fac_ptr) != fac_2_residual_idx_.end() + && fac_2_costfunction_.find(fac_ptr) != fac_2_costfunction_.end(); +} + +inline bool CeresManager::isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) +{ + return state_blocks_local_param_.find(state_ptr) != state_blocks_local_param_.end() + && ceres_problem_->HasParameterBlock(getAssociatedMemBlockPtr(state_ptr)); +} + } // namespace wolf #endif diff --git a/include/core/solver/solver_manager.h b/include/core/solver/solver_manager.h index 240e563eb16412e4aa9a5af529db0af491060437..049cb951d07c5e1c0b70cdf9be9a1d9dea992261 100644 --- a/include/core/solver/solver_manager.h +++ b/include/core/solver/solver_manager.h @@ -67,9 +67,9 @@ public: ProblemPtr getProblem(); - virtual bool isRegistered(const StateBlockPtr& state_ptr) = 0; + virtual bool isStateBlockRegistered(const StateBlockPtr& state_ptr); - virtual bool isRegistered(const FactorBasePtr& fac_ptr) = 0; + virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr); protected: @@ -91,6 +91,10 @@ protected: virtual void updateStateBlockStatus(const StateBlockPtr& state_ptr) = 0; virtual void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr) = 0; + + virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) = 0; + + virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) = 0; }; } // namespace wolf diff --git a/src/solver/solver_manager.cpp b/src/solver/solver_manager.cpp index fa645762e86c1ffd414f88efdc053ea1d6a33617..1a15d9c170c918f455e85bb527a337b05a461706 100644 --- a/src/solver/solver_manager.cpp +++ b/src/solver/solver_manager.cpp @@ -160,4 +160,14 @@ Scalar* SolverManager::getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) return it->second.data(); } +bool SolverManager::isStateBlockRegistered(const StateBlockPtr& state_ptr) +{ + return state_blocks_.find(state_ptr) != state_blocks_.end() && isStateBlockRegisteredDerived(); +} + +bool SolverManager::isFactorRegistered(const FactorBasePtr& fac_ptr) +{ + return isFactorRegisteredDerived(); +} + } // namespace wolf