diff --git a/include/core/ceres_wrapper/solver_ceres.h b/include/core/ceres_wrapper/solver_ceres.h index 6724d7be89d594a6f8c88c4cce0d57bcde47d37a..215df00ffb8abb292c20a237f231777355fe5d80 100644 --- a/include/core/ceres_wrapper/solver_ceres.h +++ b/include/core/ceres_wrapper/solver_ceres.h @@ -61,7 +61,7 @@ class SolverCeres : public SolverManager ~SolverCeres() override; - ceres::Solver::Summary getSummary(); + ceres::Solver::Summary getSummary() const; std::unique_ptr<ceres::Problem>& getCeresProblem(); @@ -70,12 +70,13 @@ class SolverCeres : public SolverManager bool computeCovariancesDerived(const std::vector<StateBlockPtr>& st_list) override; - bool hasConverged() override; - bool wasStopped() override; - unsigned int iterations() override; - double initialCost() override; - double finalCost() override; - double totalTime() override; + bool converged() const override; + bool failed() const override; + bool wasStopped() const override; + unsigned int iterations() const override; + double initialCost() const override; + double finalCost() const override; + double totalTime() const override; ceres::Solver::Options& getSolverOptions(); @@ -108,10 +109,10 @@ class SolverCeres : public SolverManager bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const override; - bool isStateBlockFixedDerived(const StateBlockPtr& st) override; + bool isStateBlockFixedDerived(const StateBlockPtr& st) const override; bool hasThisLocalParametrizationDerived(const StateBlockPtr& st, - const LocalParametrizationBasePtr& local_param) override; + const LocalParametrizationBasePtr& local_param) const override; bool hasLocalParametrizationDerived(const StateBlockPtr& st) const override; @@ -121,7 +122,7 @@ class SolverCeres : public SolverManager ceres::Solver::Options solver_options_; }; -inline ceres::Solver::Summary SolverCeres::getSummary() +inline ceres::Solver::Summary SolverCeres::getSummary() const { return summary_; } @@ -147,7 +148,7 @@ inline bool SolverCeres::isStateBlockRegisteredDerived(const StateBlockPtr& stat return ceres_problem_->HasParameterBlock(getAssociatedMemBlockPtr(state_ptr)); } -inline bool SolverCeres::isStateBlockFixedDerived(const StateBlockPtr& st) +inline bool SolverCeres::isStateBlockFixedDerived(const StateBlockPtr& st) const { if (state_blocks_.count(st) == 0) return false; return ceres_problem_->IsParameterBlockConstant(getAssociatedMemBlockPtr(st)); diff --git a/include/core/solver/solver_manager.h b/include/core/solver/solver_manager.h index e28f456e8c60d69535d8f2665210e561c5d22c31..1a83b822faadfd17d22bf7a5da08e4cb6c87c816 100644 --- a/include/core/solver/solver_manager.h +++ b/include/core/solver/solver_manager.h @@ -176,12 +176,12 @@ class SolverManager virtual bool isStateBlockFloating(const StateBlockPtr& state_ptr) const final; - virtual bool isStateBlockFixed(const StateBlockPtr& st) final; + virtual bool isStateBlockFixed(const StateBlockPtr& st) const final; virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const final; virtual bool hasThisLocalParametrization(const StateBlockPtr& st, - const LocalParametrizationBasePtr& local_param) final; + const LocalParametrizationBasePtr& local_param) const final; virtual bool hasLocalParametrization(const StateBlockPtr& st) const final; @@ -227,23 +227,24 @@ class SolverManager virtual void updateStateBlockStatusDerived(const StateBlockPtr& state_ptr) = 0; virtual void updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) = 0; - virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const = 0; - virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const = 0; - virtual bool isStateBlockFixedDerived(const StateBlockPtr& st) = 0; - virtual bool hasLocalParametrizationDerived(const StateBlockPtr& st) const = 0; + virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const = 0; + virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const = 0; + virtual bool isStateBlockFixedDerived(const StateBlockPtr& st) const = 0; + virtual bool hasLocalParametrizationDerived(const StateBlockPtr& st) const = 0; virtual bool hasThisLocalParametrizationDerived(const StateBlockPtr& st, - const LocalParametrizationBasePtr& local_param) = 0; + const LocalParametrizationBasePtr& local_param) const = 0; virtual void printProfilingDerived(std::ostream& stream = std::cout) const = 0; virtual bool checkDerived(std::string prefix = "") const = 0; public: - virtual bool hasConverged() = 0; - virtual bool wasStopped() = 0; - virtual unsigned int iterations() = 0; - virtual double initialCost() = 0; - virtual double finalCost() = 0; - virtual double totalTime() = 0; + virtual bool converged() const = 0; + virtual bool failed() const = 0; + virtual bool wasStopped() const = 0; + virtual unsigned int iterations() const = 0; + virtual double initialCost() const = 0; + virtual double finalCost() const = 0; + virtual double totalTime() const = 0; protected: // PARAMS diff --git a/src/ceres_wrapper/solver_ceres.cpp b/src/ceres_wrapper/solver_ceres.cpp index 08867c35a222a5aee188de831595d35e60b18c5f..a04f5bd76cb677bed023177fe23f2deafa0fcdee 100644 --- a/src/ceres_wrapper/solver_ceres.cpp +++ b/src/ceres_wrapper/solver_ceres.cpp @@ -134,7 +134,7 @@ std::string SolverCeres::solveDerived(const ReportVerbosity report_level) n_iter_max_ = std::max(n_iter_max_, iterations()); // convergence (profiling) - if (hasConverged()) + if (converged()) n_convergence_++; else if (wasStopped()) n_interrupted_++; @@ -649,33 +649,38 @@ void SolverCeres::updateStateBlockLocalParametrizationDerived(const StateBlockPt for (auto fac : involved_factors) addFactorDerived(fac); } -bool SolverCeres::hasConverged() +bool SolverCeres::converged() const { return summary_.termination_type == ceres::CONVERGENCE; } -bool SolverCeres::wasStopped() +bool SolverCeres::wasStopped() const { return summary_.termination_type == ceres::USER_FAILURE or summary_.termination_type == ceres::USER_SUCCESS; } -unsigned int SolverCeres::iterations() +bool SolverCeres::failed() const +{ + return summary_.termination_type == ceres::USER_FAILURE or summary_.termination_type == ceres::FAILURE; +} + +unsigned int SolverCeres::iterations() const { if (summary_.num_successful_steps + summary_.num_unsuccessful_steps < 1) return 0; return summary_.num_successful_steps + summary_.num_unsuccessful_steps; } -double SolverCeres::initialCost() +double SolverCeres::initialCost() const { return double(summary_.initial_cost); } -double SolverCeres::finalCost() +double SolverCeres::finalCost() const { return double(summary_.final_cost); } -double SolverCeres::totalTime() +double SolverCeres::totalTime() const { return double(summary_.total_time_in_seconds); } @@ -885,7 +890,7 @@ const Eigen::SparseMatrixd SolverCeres::computeHessian() const } bool SolverCeres::hasThisLocalParametrizationDerived(const StateBlockPtr& st, - const LocalParametrizationBasePtr& local_param) + const LocalParametrizationBasePtr& local_param) const { return state_blocks_local_param_.count(st) == 1 && state_blocks_local_param_.at(st)->getLocalParametrization() == local_param && diff --git a/src/solver/solver_manager.cpp b/src/solver/solver_manager.cpp index 1bec6c95e2e8e789cbc71059fd859938dcb9324a..e1f1006342d9bc7a587a0f104e1ae2a0ceece565 100644 --- a/src/solver/solver_manager.cpp +++ b/src/solver/solver_manager.cpp @@ -568,7 +568,7 @@ bool SolverManager::isFactorRegistered(const FactorBasePtr& fac_ptr) const return factors_.count(fac_ptr) == 1 and isFactorRegisteredDerived(fac_ptr); } -bool SolverManager::isStateBlockFixed(const StateBlockPtr& st) +bool SolverManager::isStateBlockFixed(const StateBlockPtr& st) const { if (!isStateBlockRegistered(st)) return false; @@ -578,7 +578,7 @@ bool SolverManager::isStateBlockFixed(const StateBlockPtr& st) } bool SolverManager::hasThisLocalParametrization(const StateBlockPtr& st, - const LocalParametrizationBasePtr& local_param) + const LocalParametrizationBasePtr& local_param) const { if (!isStateBlockRegistered(st)) return false; diff --git a/test/dummy/solver_dummy.cpp b/test/dummy/solver_dummy.cpp index 0ea15b1e6cbe519afb9087ab4af3558762a0a033..006b032e9a365a29c23b5cdc1cf1dae62a9d089a 100644 --- a/test/dummy/solver_dummy.cpp +++ b/test/dummy/solver_dummy.cpp @@ -21,15 +21,15 @@ namespace wolf { SolverDummy::SolverDummy(const ProblemPtr& wolf_problem, const YAML::Node params) - : SolverManager(wolf_problem, params){}; + : SolverManager(wolf_problem, params) {}; -bool SolverDummy::isStateBlockFixedDerived(const StateBlockPtr& st) +bool SolverDummy::isStateBlockFixedDerived(const StateBlockPtr& st) const { return state_block_fixed_.at(st); }; bool SolverDummy::hasThisLocalParametrizationDerived(const StateBlockPtr& st, - const LocalParametrizationBasePtr& local_param) + const LocalParametrizationBasePtr& local_param) const { return state_block_local_param_.at(st) == local_param; }; @@ -59,27 +59,31 @@ bool SolverDummy::computeCovariancesDerived(const std::vector<StateBlockPtr>& st }; // The following are dummy implementations -bool SolverDummy::hasConverged() +bool SolverDummy::converged() const { return true; } -bool SolverDummy::wasStopped() +bool SolverDummy::failed() const { return false; } -unsigned int SolverDummy::iterations() +bool SolverDummy::wasStopped() const +{ + return false; +} +unsigned int SolverDummy::iterations() const { return 1; } -double SolverDummy::initialCost() +double SolverDummy::initialCost() const { return double(1); } -double SolverDummy::finalCost() +double SolverDummy::finalCost() const { return double(0); } -double SolverDummy::totalTime() +double SolverDummy::totalTime() const { return double(0); } diff --git a/test/dummy/solver_dummy.h b/test/dummy/solver_dummy.h index fc233a9505ba0aff6e1e0243820f6ced417b05fc..d3cb42db19c3a24e5aa4d492de24fcc431d8425f 100644 --- a/test/dummy/solver_dummy.h +++ b/test/dummy/solver_dummy.h @@ -34,10 +34,10 @@ class SolverDummy : public SolverManager SolverDummy(const ProblemPtr& wolf_problem, const YAML::Node params); WOLF_SOLVER_CREATE(SolverDummy); - bool isStateBlockFixedDerived(const StateBlockPtr& st) override; + bool isStateBlockFixedDerived(const StateBlockPtr& st) const override; bool hasThisLocalParametrizationDerived(const StateBlockPtr& st, - const LocalParametrizationBasePtr& local_param) override; + const LocalParametrizationBasePtr& local_param) const override; bool hasLocalParametrizationDerived(const StateBlockPtr& st) const override; @@ -49,12 +49,13 @@ class SolverDummy : public SolverManager bool computeCovariancesDerived(const std::vector<StateBlockPtr>& st_list) override; // The following are dummy implementations - bool hasConverged() override; - bool wasStopped() override; - unsigned int iterations() override; - double initialCost() override; - double finalCost() override; - double totalTime() override; + bool converged() const override; + bool failed() const override; + bool wasStopped() const override; + unsigned int iterations() const override; + double initialCost() const override; + double finalCost() const override; + double totalTime() const override; void printProfilingDerived(std::ostream& _stream) const override; protected: diff --git a/test/gtest_factor_velocity_local_direction_3d.cpp b/test/gtest_factor_velocity_local_direction_3d.cpp index 312722b9616e676799f31d95ca29f2d33bb9d8a6..7cc77c8410d079b0ba47eb87d4c152e7dcc4aa68 100644 --- a/test/gtest_factor_velocity_local_direction_3d.cpp +++ b/test/gtest_factor_velocity_local_direction_3d.cpp @@ -139,7 +139,7 @@ class FactorVelocityLocalDirection3dTest : public testing::Test fac->getFeature()->remove(); // Update performaces - convergence.push_back(solver->hasConverged() ? 1 : 0); + convergence.push_back(solver->converged() ? 1 : 0); iterations.push_back(solver->iterations()); times.push_back(solver->totalTime()); error.push_back(acos(cos_angle_local));