diff --git a/include/core/ceres_wrapper/ceres_manager.h b/include/core/ceres_wrapper/ceres_manager.h index 95eed7b383c3f3e003edae1eaa5347469a51d00e..a9a05eb9bb076dc68127e60b3a28a8964ebd3196 100644 --- a/include/core/ceres_wrapper/ceres_manager.h +++ b/include/core/ceres_wrapper/ceres_manager.h @@ -82,7 +82,7 @@ class CeresManager : public SolverManager ceres::Solver::Options& getSolverOptions(); - virtual bool check() const override; + virtual bool check() override; const Eigen::SparseMatrixd computeHessian() const; diff --git a/include/core/solver/solver_manager.h b/include/core/solver/solver_manager.h index a410690155fc898411c715a142376047c852f6a3..e28c544bd1266362b453b3f9619626647d4b5dcb 100644 --- a/include/core/solver/solver_manager.h +++ b/include/core/solver/solver_manager.h @@ -117,14 +117,16 @@ public: virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const; - virtual bool check() const = 0; + virtual bool check() = 0; + + void assertCheck(); protected: std::map<StateBlockPtr, Eigen::VectorXd> state_blocks_; virtual Eigen::VectorXd& getAssociatedMemBlock(const StateBlockPtr& state_ptr); - const double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const; + //const double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const; virtual double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr); virtual std::string solveImpl(const ReportVerbosity report_level) = 0; @@ -146,6 +148,13 @@ protected: virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const = 0; }; +inline void SolverManager::assertCheck() +{ + #ifdef _WOLF_DEBUG + assert(check()); + #endif +} + } // namespace wolf #endif /* _WOLF_SOLVER_MANAGER_H_ */ diff --git a/src/ceres_wrapper/ceres_manager.cpp b/src/ceres_wrapper/ceres_manager.cpp index af6f468fb7756fe527c104e78f5807e736cd4d6e..aad86ac2950a77fbe4067a200a419be9d553eaec 100644 --- a/src/ceres_wrapper/ceres_manager.cpp +++ b/src/ceres_wrapper/ceres_manager.cpp @@ -50,11 +50,6 @@ CeresManager::~CeresManager() std::string CeresManager::solveImpl(const ReportVerbosity report_level) { - // Check - #ifdef _WOLF_DEBUG - assert(check()); - #endif - // run Ceres Solver ceres::Solve(ceres_options_, ceres_problem_.get(), &summary_); @@ -410,7 +405,7 @@ ceres::CostFunctionPtr CeresManager::createCostFunction(const FactorBasePtr& _fa throw std::invalid_argument( "Wrong Jacobian Method!" ); } -bool CeresManager::check() const +bool CeresManager::check() { bool ok = true; @@ -432,7 +427,8 @@ bool CeresManager::check() const } // Check parameter blocks - for (const auto& state_block_pair : state_blocks_) + for (auto&& state_block_pair : state_blocks_) + //for (const auto& state_block_pair : state_blocks_) { if (!ceres_problem_->HasParameterBlock(state_block_pair.second.data())) { @@ -442,7 +438,8 @@ bool CeresManager::check() const } // Check residual blocks - for (const auto& fac_res_pair : fac_2_residual_idx_) + for (auto&& fac_res_pair : fac_2_residual_idx_) + //for (const auto& fac_res_pair : fac_2_residual_idx_) { // costfunction - residual if (fac_2_costfunction_.find(fac_res_pair.first) == fac_2_costfunction_.end()) @@ -450,7 +447,8 @@ bool CeresManager::check() const WOLF_ERROR("CeresManager::check: any factor in fac_2_residual_idx_ is not in fac_2_costfunction_"); ok = false; } - if (fac_2_costfunction_.at(fac_res_pair.first).get() != ceres_problem_->GetCostFunctionForResidualBlock(fac_res_pair.second)) + if (fac_2_costfunction_[fac_res_pair.first].get() != ceres_problem_->GetCostFunctionForResidualBlock(fac_res_pair.second)) + //if (fac_2_costfunction_.at(fac_res_pair.first).get() != ceres_problem_->GetCostFunctionForResidualBlock(fac_res_pair.second)) { WOLF_ERROR("CeresManager::check: fac_2_costfunction_ and ceres mismatch"); ok = false; diff --git a/src/solver/solver_manager.cpp b/src/solver/solver_manager.cpp index 800201dd064d6a7fd5197972ee0a5f16e423ba2a..b2f1f807821ba42bfb2ff28678c376517b33dc35 100644 --- a/src/solver/solver_manager.cpp +++ b/src/solver/solver_manager.cpp @@ -130,6 +130,9 @@ std::string SolverManager::solve(const ReportVerbosity report_level) // update problem update(); + // Check + assertCheck(); + std::string report = solveImpl(report_level); // update StateBlocks with optimized state value. @@ -158,15 +161,15 @@ Eigen::VectorXd& SolverManager::getAssociatedMemBlock(const StateBlockPtr& state return it->second; } -const double* SolverManager::getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const -{ - auto it = state_blocks_.find(state_ptr); - - if (it == state_blocks_.end()) - throw std::runtime_error("Tried to retrieve the memory block of an unregistered StateBlock !"); - - return it->second.data(); -} +//const double* SolverManager::getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const +//{ +// auto it = state_blocks_.find(state_ptr); +// +// if (it == state_blocks_.end()) +// throw std::runtime_error("Tried to retrieve the memory block of an unregistered StateBlock !"); +// +// return it->second.data(); +//} double* SolverManager::getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) { diff --git a/test/dummy/solver_manager_dummy.h b/test/dummy/solver_manager_dummy.h index e9154bb38532ffa4cdcaae21a202103e24f16948..c84e3954b4c680963bbe1425cd883be208949d48 100644 --- a/test/dummy/solver_manager_dummy.h +++ b/test/dummy/solver_manager_dummy.h @@ -61,7 +61,7 @@ class SolverManagerDummy : public SolverManager SizeStd iterations() { return 1; } double initialCost() { return double(1); } double finalCost() { return double(0); } - virtual bool check() const override {return true;} + virtual bool check() override {return true;}