diff --git a/CMakeLists.txt b/CMakeLists.txt index 271db2988212108d5e8baf42e206601f92a98eaf..a6e7ee166720a726fc6c9178acc4119e0c4286e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -123,7 +123,7 @@ ENDIF(BUILD_DEMOS OR BUILD_TESTS) FIND_PACKAGE(Threads REQUIRED) -FIND_PACKAGE(Ceres REQUIRED) #Ceres is not required +FIND_PACKAGE(Ceres REQUIRED) #Ceres is required FIND_PACKAGE(Eigen3 3.3 REQUIRED) if(${EIGEN3_VERSION_STRING} VERSION_LESS 3.3) @@ -392,10 +392,11 @@ SET(SRCS_YAML IF (Ceres_FOUND) SET(HDRS_WRAPPER #ceres_wrapper/qr_manager.h - include/core/ceres_wrapper/solver_ceres.h include/core/ceres_wrapper/cost_function_wrapper.h include/core/ceres_wrapper/create_numeric_diff_cost_function.h include/core/ceres_wrapper/local_parametrization_wrapper.h + include/core/ceres_wrapper/iteration_update_callback.h + include/core/ceres_wrapper/solver_ceres.h include/core/solver/solver_manager.h include/core/solver_suitesparse/sparse_utils.h ) diff --git a/include/core/ceres_wrapper/iteration_update_callback.h b/include/core/ceres_wrapper/iteration_update_callback.h new file mode 100644 index 0000000000000000000000000000000000000000..0496c336e9b04b364696b2b5de38e26498e57b70 --- /dev/null +++ b/include/core/ceres_wrapper/iteration_update_callback.h @@ -0,0 +1,41 @@ +/* + * iteration_callback.h + * + * Created on: Jun 15, 2020 + * Author: joanvallve + */ + +#ifndef INCLUDE_CORE_CERES_WRAPPER_ITERATION_UPDATE_CALLBACK_H_ +#define INCLUDE_CORE_CERES_WRAPPER_ITERATION_UPDATE_CALLBACK_H_ + +#include "core/problem/problem.h" +#include "ceres/ceres.h" + +namespace wolf { + +class IterationUpdateCallback : public ceres::IterationCallback +{ + public: + explicit IterationUpdateCallback(ProblemPtr _problem) + : problem_(_problem) {} + + ~IterationUpdateCallback() {} + + ceres::CallbackReturnType operator()(const ceres::IterationSummary& summary) override + { + if (problem_->getStateBlockNotificationMapSize() != 0 or + problem_->getFactorNotificationMapSize() != 0) + { + WOLF_INFO("Stopping solver to update the problem!"); + return ceres::SOLVER_TERMINATE_SUCCESSFULLY; + } + return ceres::SOLVER_CONTINUE; + } + + private: + ProblemPtr problem_; +}; + +} + +#endif /* INCLUDE_CORE_CERES_WRAPPER_ITERATION_UPDATE_CALLBACK_H_ */ diff --git a/include/core/ceres_wrapper/solver_ceres.h b/include/core/ceres_wrapper/solver_ceres.h index cb160c87dc4e7d4dca9136656737ac4fce36419f..d3c6d17ace1b7cf5dd75497954ed0aeaa270c067 100644 --- a/include/core/ceres_wrapper/solver_ceres.h +++ b/include/core/ceres_wrapper/solver_ceres.h @@ -9,9 +9,6 @@ //wolf includes #include "core/solver/solver_manager.h" #include "core/utils/params_server.h" -#include "core/ceres_wrapper/cost_function_wrapper.h" -#include "core/ceres_wrapper/local_parametrization_wrapper.h" -#include "core/ceres_wrapper/create_numeric_diff_cost_function.h" namespace ceres { typedef std::shared_ptr<CostFunction> CostFunctionPtr; @@ -20,6 +17,7 @@ typedef std::shared_ptr<CostFunction> CostFunctionPtr; namespace wolf { WOLF_PTR_TYPEDEFS(SolverCeres); +WOLF_PTR_TYPEDEFS(LocalParametrizationWrapper); WOLF_STRUCT_PTR_TYPEDEFS(ParamsCeres); struct ParamsCeres : public ParamsSolver @@ -49,7 +47,6 @@ struct ParamsCeres : public ParamsSolver #endif covariance_options.num_threads = 1; covariance_options.apply_loss_function = false; - } ParamsCeres(std::string _unique_name, const ParamsServer& _server) : @@ -109,7 +106,6 @@ class SolverCeres : public SolverManager ceres::Solver::Options& getSolverOptions(); - const Eigen::SparseMatrixd computeHessian() const; protected: diff --git a/include/core/solver/solver_manager.h b/include/core/solver/solver_manager.h index 1fc5e60943ef0d37a72f5e8d2283afb8bbfcc9f9..96cdf4494e411802920f29caeb2fcbb0c9dccc35 100644 --- a/include/core/solver/solver_manager.h +++ b/include/core/solver/solver_manager.h @@ -121,6 +121,8 @@ class SolverManager */ virtual bool ready() const; + ReportVerbosity getVerbosity() const; + virtual bool isStateBlockRegistered(const StateBlockPtr& state_ptr); virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const; diff --git a/src/ceres_wrapper/solver_ceres.cpp b/src/ceres_wrapper/solver_ceres.cpp index 6e7c8ac56a03d1738c2d1caaa6debad2448f62de..b3413b2c7760f7f29966427c8f2b4cce50a2953c 100644 --- a/src/ceres_wrapper/solver_ceres.cpp +++ b/src/ceres_wrapper/solver_ceres.cpp @@ -1,5 +1,8 @@ -#include "../../include/core/ceres_wrapper/solver_ceres.h" +#include "core/ceres_wrapper/solver_ceres.h" #include "core/ceres_wrapper/create_numeric_diff_cost_function.h" +#include "core/ceres_wrapper/cost_function_wrapper.h" +#include "core/ceres_wrapper/iteration_update_callback.h" +#include "core/ceres_wrapper/local_parametrization_wrapper.h" #include "core/trajectory/trajectory_base.h" #include "core/map/map_base.h" #include "core/landmark/landmark_base.h" @@ -11,16 +14,18 @@ namespace wolf SolverCeres::SolverCeres(const ProblemPtr& _wolf_problem) : SolverCeres(_wolf_problem, std::make_shared<ParamsCeres>()) { - } SolverCeres::SolverCeres(const ProblemPtr& _wolf_problem, - const ParamsCeresPtr& _params) : - SolverManager(_wolf_problem, _params), - params_ceres_(_params) + const ParamsCeresPtr& _params) + : SolverManager(_wolf_problem, _params) + , params_ceres_(_params) { covariance_ = wolf::make_unique<ceres::Covariance>(params_ceres_->covariance_options); ceres_problem_ = wolf::make_unique<ceres::Problem>(params_ceres_->problem_options); + + if (params_ceres_->update_immediately) + getSolverOptions().callbacks.push_back(new IterationUpdateCallback(wolf_problem_)); } SolverCeres::~SolverCeres() @@ -33,6 +38,12 @@ SolverCeres::~SolverCeres() removeStateBlockDerived(state_blocks_.begin()->first); state_blocks_.erase(state_blocks_.begin()); } + + while (!getSolverOptions().callbacks.empty()) + { + delete getSolverOptions().callbacks.back(); + getSolverOptions().callbacks.pop_back(); + } } std::string SolverCeres::solveDerived(const ReportVerbosity report_level) diff --git a/src/solver/solver_manager.cpp b/src/solver/solver_manager.cpp index 4d32ac6de815acb26ba5e7759facbf2e5d3f9ea2..a0b6f028ce349227ad94a979c16e788c51a09915 100644 --- a/src/solver/solver_manager.cpp +++ b/src/solver/solver_manager.cpp @@ -323,6 +323,11 @@ bool SolverManager::ready() const return (!last_solve_ts_.ok() || (TimeStamp::Now() - last_solve_ts_) > params_->period); } +SolverManager::ReportVerbosity SolverManager::getVerbosity() const +{ + return params_->verbose; +} + bool SolverManager::isStateBlockRegistered(const StateBlockPtr& state_ptr) { return state_blocks_.count(state_ptr) ==1 && isStateBlockRegisteredDerived(state_ptr); diff --git a/test/gtest_solver_ceres.cpp b/test/gtest_solver_ceres.cpp index 7b811e7b970416180132252b1b25462572f19557..42b2f1c825fe7cfc79bbe44e8f577eed7ec1dfd8 100644 --- a/test/gtest_solver_ceres.cpp +++ b/test/gtest_solver_ceres.cpp @@ -7,21 +7,22 @@ #include "core/utils/utils_gtest.h" - #include "core/problem/problem.h" #include "core/sensor/sensor_base.h" #include "core/state_block/state_block.h" #include "core/capture/capture_void.h" #include "core/factor/factor_pose_2d.h" #include "core/factor/factor_quaternion_absolute.h" -#include "core/solver/solver_manager.h" #include "core/state_block/local_parametrization_angle.h" #include "core/state_block/local_parametrization_quaternion.h" +#include "core/solver/solver_manager.h" +#include "core/ceres_wrapper/solver_ceres.h" +#include "core/ceres_wrapper/local_parametrization_wrapper.h" + #include "ceres/ceres.h" #include <iostream> -#include "../include/core/ceres_wrapper/solver_ceres.h" using namespace wolf; using namespace Eigen;