diff --git a/CMakeLists.txt b/CMakeLists.txt index 271db2988212108d5e8baf42e206601f92a98eaf..a6e7ee166720a726fc6c9178acc4119e0c4286e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -123,7 +123,7 @@ ENDIF(BUILD_DEMOS OR BUILD_TESTS) FIND_PACKAGE(Threads REQUIRED) -FIND_PACKAGE(Ceres REQUIRED) #Ceres is not required +FIND_PACKAGE(Ceres REQUIRED) #Ceres is required FIND_PACKAGE(Eigen3 3.3 REQUIRED) if(${EIGEN3_VERSION_STRING} VERSION_LESS 3.3) @@ -392,10 +392,11 @@ SET(SRCS_YAML IF (Ceres_FOUND) SET(HDRS_WRAPPER #ceres_wrapper/qr_manager.h - include/core/ceres_wrapper/solver_ceres.h include/core/ceres_wrapper/cost_function_wrapper.h include/core/ceres_wrapper/create_numeric_diff_cost_function.h include/core/ceres_wrapper/local_parametrization_wrapper.h + include/core/ceres_wrapper/iteration_update_callback.h + include/core/ceres_wrapper/solver_ceres.h include/core/solver/solver_manager.h include/core/solver_suitesparse/sparse_utils.h ) diff --git a/include/core/ceres_wrapper/iteration_update_callback.h b/include/core/ceres_wrapper/iteration_update_callback.h new file mode 100644 index 0000000000000000000000000000000000000000..bf847c2270dd58af2ea02ba399909073ac125afc --- /dev/null +++ b/include/core/ceres_wrapper/iteration_update_callback.h @@ -0,0 +1,44 @@ +/* + * iteration_callback.h + * + * Created on: Jun 15, 2020 + * Author: joanvallve + */ + +#ifndef INCLUDE_CORE_CERES_WRAPPER_ITERATION_UPDATE_CALLBACK_H_ +#define INCLUDE_CORE_CERES_WRAPPER_ITERATION_UPDATE_CALLBACK_H_ + +#include "core/problem/problem.h" +#include "ceres/ceres.h" + +namespace wolf { + +class IterationUpdateCallback : public ceres::IterationCallback +{ + public: + explicit IterationUpdateCallback(ProblemPtr _problem, bool verbose = false) + : problem_(_problem) + , verbose_(verbose) + {} + + ~IterationUpdateCallback() {} + + ceres::CallbackReturnType operator()(const ceres::IterationSummary& summary) override + { + if (problem_->getStateBlockNotificationMapSize() != 0 or + problem_->getFactorNotificationMapSize() != 0) + { + WOLF_INFO_COND(verbose_, "Stopping solver to update the problem!"); + return ceres::SOLVER_TERMINATE_SUCCESSFULLY; + } + return ceres::SOLVER_CONTINUE; + } + + private: + ProblemPtr problem_; + bool verbose_; +}; + +} + +#endif /* INCLUDE_CORE_CERES_WRAPPER_ITERATION_UPDATE_CALLBACK_H_ */ diff --git a/include/core/ceres_wrapper/solver_ceres.h b/include/core/ceres_wrapper/solver_ceres.h index e5c0ab3b057cfbd48c4bc280b27d142723b49f70..e4306aa8ca126185a6343435405c1bcb5940e650 100644 --- a/include/core/ceres_wrapper/solver_ceres.h +++ b/include/core/ceres_wrapper/solver_ceres.h @@ -9,9 +9,6 @@ //wolf includes #include "core/solver/solver_manager.h" #include "core/utils/params_server.h" -#include "core/ceres_wrapper/cost_function_wrapper.h" -#include "core/ceres_wrapper/local_parametrization_wrapper.h" -#include "core/ceres_wrapper/create_numeric_diff_cost_function.h" namespace ceres { typedef std::shared_ptr<CostFunction> CostFunctionPtr; @@ -20,21 +17,39 @@ typedef std::shared_ptr<CostFunction> CostFunctionPtr; namespace wolf { WOLF_PTR_TYPEDEFS(SolverCeres); +WOLF_PTR_TYPEDEFS(LocalParametrizationWrapper); WOLF_STRUCT_PTR_TYPEDEFS(ParamsCeres); struct ParamsCeres : public ParamsSolver { - bool update_immediately; + bool update_immediately = false; ceres::Solver::Options solver_options; ceres::Problem::Options problem_options; ceres::Covariance::Options covariance_options; ParamsCeres() : - update_immediately(false), - solver_options(), - problem_options(), - covariance_options() + ParamsSolver() { + loadHardcodedValues(); + } + + ParamsCeres(std::string _unique_name, const ParamsServer& _server) : + ParamsSolver(_unique_name, _server) + { + loadHardcodedValues(); + + // stop solver whenever the problem is updated (via ceres::iterationCallback) + update_immediately = _server.getParam<bool>(prefix + "update_immediately"); + + // ceres solver options + solver_options.max_num_iterations = _server.getParam<int>(prefix + "max_num_iterations"); + } + + void loadHardcodedValues() + { + solver_options = ceres::Solver::Options(); + problem_options = ceres::Problem::Options(); + covariance_options = ceres::Covariance::Options(); problem_options.cost_function_ownership = ceres::DO_NOT_TAKE_OWNERSHIP; problem_options.loss_function_ownership = ceres::TAKE_OWNERSHIP; problem_options.local_parameterization_ownership = ceres::DO_NOT_TAKE_OWNERSHIP; @@ -49,16 +64,6 @@ struct ParamsCeres : public ParamsSolver #endif covariance_options.num_threads = 1; covariance_options.apply_loss_function = false; - - } - - ParamsCeres(std::string _unique_name, const ParamsServer& _server) : - ParamsCeres() - { - update_immediately = _server.getParam<bool>(prefix + "update_immediately"); - - // ceres solver options - solver_options.max_num_iterations = _server.getParam<int>(prefix + "max_num_iterations"); } ~ParamsCeres() override = default; @@ -84,7 +89,7 @@ class SolverCeres : public SolverManager SolverCeres(const ProblemPtr& _wolf_problem); SolverCeres(const ProblemPtr& _wolf_problem, - const ParamsCeresPtr& _params); + const ParamsCeresPtr& _params); WOLF_SOLVER_CREATE(SolverCeres, ParamsCeres); @@ -109,7 +114,6 @@ class SolverCeres : public SolverManager ceres::Solver::Options& getSolverOptions(); - const Eigen::SparseMatrixd computeHessian() const; protected: diff --git a/include/core/solver/solver_manager.h b/include/core/solver/solver_manager.h index 5c0faab09b6bf1729c1d47d7dfd53dea530ae795..96cdf4494e411802920f29caeb2fcbb0c9dccc35 100644 --- a/include/core/solver/solver_manager.h +++ b/include/core/solver/solver_manager.h @@ -38,115 +38,147 @@ static SolverManagerPtr create(const ProblemPtr& _problem, return std::make_shared<SolverClass>(_problem, params); \ } \ -struct ParamsSolver: public ParamsBase -{ - std::string prefix = "solver/"; - - ParamsSolver() = default; - ParamsSolver(std::string _unique_name, const ParamsServer& _server): - ParamsBase(_unique_name, _server) - { - // - } - - ~ParamsSolver() override = default; -}; + struct ParamsSolver; /** * \brief Solver manager for WOLF */ class SolverManager { -public: - - /** \brief Enumeration of covariance blocks to be computed - * - * Enumeration of covariance blocks to be computed - * - */ - enum class CovarianceBlocksToBeComputed : std::size_t - { - ALL, ///< All blocks and all cross-covariances - ALL_MARGINALS, ///< All marginals - ROBOT_LANDMARKS ///< marginals of landmarks and current robot pose plus cross covariances of current robot and all landmarks - }; - - /** - * \brief Enumeration for the verbosity of the solver report. - */ - enum class ReportVerbosity : std::size_t - { - QUIET = 0, - BRIEF, - FULL - }; - -protected: - - ProblemPtr wolf_problem_; - -public: - - SolverManager(const ProblemPtr& wolf_problem); - - virtual ~SolverManager(); - - std::string solve(const ReportVerbosity report_level = ReportVerbosity::QUIET); - - virtual void computeCovariances(const CovarianceBlocksToBeComputed blocks) = 0; - - virtual void computeCovariances(const std::vector<StateBlockPtr>& st_list) = 0; - - virtual bool hasConverged() = 0; - - virtual SizeStd iterations() = 0; - - virtual double initialCost() = 0; - - virtual double finalCost() = 0; - - virtual void update(); - - ProblemPtr getProblem(); - - virtual bool isStateBlockRegistered(const StateBlockPtr& state_ptr); - - virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const; - - bool check(std::string prefix="") const; - -protected: - - std::map<StateBlockPtr, Eigen::VectorXd> state_blocks_; - std::map<StateBlockPtr, FactorBasePtrList> state_blocks_2_factors_; - std::set<FactorBasePtr> factors_; - - virtual Eigen::VectorXd& getAssociatedMemBlock(const StateBlockPtr& state_ptr); - const double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const; - virtual double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr); - -private: - // SolverManager functions - void addFactor(const FactorBasePtr& fac_ptr); - void removeFactor(const FactorBasePtr& fac_ptr); - void addStateBlock(const StateBlockPtr& state_ptr); - void removeStateBlock(const StateBlockPtr& state_ptr); - void updateStateBlockState(const StateBlockPtr& state_ptr); - void updateStateBlockStatus(const StateBlockPtr& state_ptr); - void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr); + public: + + /** \brief Enumeration of covariance blocks to be computed + * + * Enumeration of covariance blocks to be computed + * + */ + enum class CovarianceBlocksToBeComputed : std::size_t + { + ALL, ///< All blocks and all cross-covariances + ALL_MARGINALS, ///< All marginals + ROBOT_LANDMARKS ///< marginals of landmarks and current robot pose plus cross covariances of current robot and all landmarks + }; + + /** + * \brief Enumeration for the verbosity of the solver report. + */ + enum class ReportVerbosity : std::size_t + { + QUIET = 0, + BRIEF, + FULL + }; + + protected: + + ProblemPtr wolf_problem_; + ParamsSolverPtr params_; + TimeStamp last_solve_ts_; + + public: + /** + * \brief Constructor with default params_ + */ + SolverManager(const ProblemPtr& _problem); + /** + * \brief Constructor with given params_ + */ + SolverManager(const ProblemPtr& _problem, + const ParamsSolverPtr& _params); + + virtual ~SolverManager(); + + /** + * \brief Solves with the verbosity defined in params_ + */ + std::string solve(); + /** + * \brief Solves with a given verbosity + */ + std::string solve(const ReportVerbosity report_level); + + virtual void computeCovariances(const CovarianceBlocksToBeComputed blocks) = 0; + + virtual void computeCovariances(const std::vector<StateBlockPtr>& st_list) = 0; + + virtual bool hasConverged() = 0; + + virtual SizeStd iterations() = 0; + + virtual double initialCost() = 0; + + virtual double finalCost() = 0; + + /** + * \brief Updates solver's problem according to the wolf_problem + */ + virtual void update(); + + ProblemPtr getProblem(); + + /** + * \brief Returns if solve() should be called (according to period, can be derived to implement other criteria) + */ + virtual bool ready() const; + + ReportVerbosity getVerbosity() const; + + virtual bool isStateBlockRegistered(const StateBlockPtr& state_ptr); + + virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const; + + bool check(std::string prefix="") const; + + protected: + + std::map<StateBlockPtr, Eigen::VectorXd> state_blocks_; + std::map<StateBlockPtr, FactorBasePtrList> state_blocks_2_factors_; + std::set<FactorBasePtr> factors_; + + virtual Eigen::VectorXd& getAssociatedMemBlock(const StateBlockPtr& state_ptr); + const double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const; + virtual double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr); + + private: + // SolverManager functions + void addFactor(const FactorBasePtr& fac_ptr); + void removeFactor(const FactorBasePtr& fac_ptr); + void addStateBlock(const StateBlockPtr& state_ptr); + void removeStateBlock(const StateBlockPtr& state_ptr); + void updateStateBlockState(const StateBlockPtr& state_ptr); + void updateStateBlockStatus(const StateBlockPtr& state_ptr); + void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr); + + protected: + // Derived virtual functions + virtual std::string solveDerived(const ReportVerbosity report_level) = 0; + virtual void addFactorDerived(const FactorBasePtr& fac_ptr) = 0; + virtual void removeFactorDerived(const FactorBasePtr& fac_ptr) = 0; + virtual void addStateBlockDerived(const StateBlockPtr& state_ptr) = 0; + virtual void removeStateBlockDerived(const StateBlockPtr& state_ptr) = 0; + virtual void updateStateBlockStatusDerived(const StateBlockPtr& state_ptr) = 0; + virtual void updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) = 0; + virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) = 0; + virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const = 0; + virtual bool checkDerived(std::string prefix="") const = 0; +}; -protected: - // Derived virtual functions - virtual std::string solveDerived(const ReportVerbosity report_level) = 0; - virtual void addFactorDerived(const FactorBasePtr& fac_ptr) = 0; - virtual void removeFactorDerived(const FactorBasePtr& fac_ptr) = 0; - virtual void addStateBlockDerived(const StateBlockPtr& state_ptr) = 0; - virtual void removeStateBlockDerived(const StateBlockPtr& state_ptr) = 0; - virtual void updateStateBlockStatusDerived(const StateBlockPtr& state_ptr) = 0; - virtual void updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) = 0; - virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) = 0; - virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const = 0; - virtual bool checkDerived(std::string prefix="") const = 0; +// Params (here becaure it needs of declaration of SolverManager::ReportVerbosity) +struct ParamsSolver: public ParamsBase +{ + std::string prefix = "solver/"; + double period = 0.0; + SolverManager::ReportVerbosity verbose = SolverManager::ReportVerbosity::QUIET; + + ParamsSolver() = default; + ParamsSolver(std::string _unique_name, const ParamsServer& _server): + ParamsBase(_unique_name, _server) + { + period = _server.getParam<double>(prefix + "period"); + verbose = (SolverManager::ReportVerbosity)_server.getParam<int>(prefix + "verbose"); + } + + ~ParamsSolver() override = default; }; } // namespace wolf diff --git a/src/ceres_wrapper/solver_ceres.cpp b/src/ceres_wrapper/solver_ceres.cpp index e499173a8d9c2f2551393beea506b3d47d099273..e2ebc5e1f71c59599433158196870fe69b21de7f 100644 --- a/src/ceres_wrapper/solver_ceres.cpp +++ b/src/ceres_wrapper/solver_ceres.cpp @@ -1,5 +1,8 @@ -#include "../../include/core/ceres_wrapper/solver_ceres.h" +#include "core/ceres_wrapper/solver_ceres.h" #include "core/ceres_wrapper/create_numeric_diff_cost_function.h" +#include "core/ceres_wrapper/cost_function_wrapper.h" +#include "core/ceres_wrapper/iteration_update_callback.h" +#include "core/ceres_wrapper/local_parametrization_wrapper.h" #include "core/trajectory/trajectory_base.h" #include "core/map/map_base.h" #include "core/landmark/landmark_base.h" @@ -11,16 +14,18 @@ namespace wolf SolverCeres::SolverCeres(const ProblemPtr& _wolf_problem) : SolverCeres(_wolf_problem, std::make_shared<ParamsCeres>()) { - } SolverCeres::SolverCeres(const ProblemPtr& _wolf_problem, - const ParamsCeresPtr& _params) : - SolverManager(_wolf_problem), - params_ceres_(_params) + const ParamsCeresPtr& _params) + : SolverManager(_wolf_problem, _params) + , params_ceres_(_params) { covariance_ = wolf::make_unique<ceres::Covariance>(params_ceres_->covariance_options); ceres_problem_ = wolf::make_unique<ceres::Problem>(params_ceres_->problem_options); + + if (params_ceres_->update_immediately) + getSolverOptions().callbacks.push_back(new IterationUpdateCallback(wolf_problem_, params_ceres_->verbose != SolverManager::ReportVerbosity::QUIET)); } SolverCeres::~SolverCeres() @@ -33,12 +38,18 @@ SolverCeres::~SolverCeres() removeStateBlockDerived(state_blocks_.begin()->first); state_blocks_.erase(state_blocks_.begin()); } + + while (!getSolverOptions().callbacks.empty()) + { + delete getSolverOptions().callbacks.back(); + getSolverOptions().callbacks.pop_back(); + } } std::string SolverCeres::solveDerived(const ReportVerbosity report_level) { // run Ceres Solver - ceres::Solve(params_ceres_->solver_options, ceres_problem_.get(), &summary_); + ceres::Solve(getSolverOptions(), ceres_problem_.get(), &summary_); std::string report; diff --git a/src/solver/solver_manager.cpp b/src/solver/solver_manager.cpp index 24a674f5f517e7bbd5ac8448c8678a7517632b90..a0b6f028ce349227ad94a979c16e788c51a09915 100644 --- a/src/solver/solver_manager.cpp +++ b/src/solver/solver_manager.cpp @@ -5,10 +5,17 @@ namespace wolf { -SolverManager::SolverManager(const ProblemPtr& _wolf_problem) : - wolf_problem_(_wolf_problem) +SolverManager::SolverManager(const ProblemPtr& _problem) : + SolverManager(_problem, std::make_shared<ParamsSolver>()) { - assert(_wolf_problem != nullptr && "Passed a nullptr ProblemPtr."); +} + +SolverManager::SolverManager(const ProblemPtr& _problem, + const ParamsSolverPtr& _params) : + wolf_problem_(_problem), + params_(_params) +{ + assert(_problem != nullptr && "Passed a nullptr ProblemPtr."); } SolverManager::~SolverManager() @@ -95,11 +102,19 @@ wolf::ProblemPtr SolverManager::getProblem() return wolf_problem_; } +std::string SolverManager::solve() +{ + return solve(params_->verbose); +} + std::string SolverManager::solve(const ReportVerbosity report_level) { // update problem update(); + last_solve_ts_ = TimeStamp::Now(); + + // call derived solver std::string report = solveDerived(report_level); // update StateBlocks with optimized state value. @@ -303,6 +318,16 @@ double* SolverManager::getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) return it->second.data(); } +bool SolverManager::ready() const +{ + return (!last_solve_ts_.ok() || (TimeStamp::Now() - last_solve_ts_) > params_->period); +} + +SolverManager::ReportVerbosity SolverManager::getVerbosity() const +{ + return params_->verbose; +} + bool SolverManager::isStateBlockRegistered(const StateBlockPtr& state_ptr) { return state_blocks_.count(state_ptr) ==1 && isStateBlockRegisteredDerived(state_ptr); diff --git a/test/gtest_solver_ceres.cpp b/test/gtest_solver_ceres.cpp index 7b811e7b970416180132252b1b25462572f19557..42b2f1c825fe7cfc79bbe44e8f577eed7ec1dfd8 100644 --- a/test/gtest_solver_ceres.cpp +++ b/test/gtest_solver_ceres.cpp @@ -7,21 +7,22 @@ #include "core/utils/utils_gtest.h" - #include "core/problem/problem.h" #include "core/sensor/sensor_base.h" #include "core/state_block/state_block.h" #include "core/capture/capture_void.h" #include "core/factor/factor_pose_2d.h" #include "core/factor/factor_quaternion_absolute.h" -#include "core/solver/solver_manager.h" #include "core/state_block/local_parametrization_angle.h" #include "core/state_block/local_parametrization_quaternion.h" +#include "core/solver/solver_manager.h" +#include "core/ceres_wrapper/solver_ceres.h" +#include "core/ceres_wrapper/local_parametrization_wrapper.h" + #include "ceres/ceres.h" #include <iostream> -#include "../include/core/ceres_wrapper/solver_ceres.h" using namespace wolf; using namespace Eigen;