From 8e1c166e74267b412fee6eeed838fbe7416146b5 Mon Sep 17 00:00:00 2001 From: joanvallve <jvallve@iri.upc.edu> Date: Mon, 15 Jun 2020 15:00:07 +0200 Subject: [PATCH] implemented iteration callback checking for problem update --- CMakeLists.txt | 5 ++- .../ceres_wrapper/iteration_update_callback.h | 41 +++++++++++++++++++ include/core/ceres_wrapper/solver_ceres.h | 6 +-- include/core/solver/solver_manager.h | 2 + src/ceres_wrapper/solver_ceres.cpp | 21 +++++++--- src/solver/solver_manager.cpp | 5 +++ test/gtest_solver_ceres.cpp | 7 ++-- 7 files changed, 72 insertions(+), 15 deletions(-) create mode 100644 include/core/ceres_wrapper/iteration_update_callback.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 271db2988..a6e7ee166 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -123,7 +123,7 @@ ENDIF(BUILD_DEMOS OR BUILD_TESTS) FIND_PACKAGE(Threads REQUIRED) -FIND_PACKAGE(Ceres REQUIRED) #Ceres is not required +FIND_PACKAGE(Ceres REQUIRED) #Ceres is required FIND_PACKAGE(Eigen3 3.3 REQUIRED) if(${EIGEN3_VERSION_STRING} VERSION_LESS 3.3) @@ -392,10 +392,11 @@ SET(SRCS_YAML IF (Ceres_FOUND) SET(HDRS_WRAPPER #ceres_wrapper/qr_manager.h - include/core/ceres_wrapper/solver_ceres.h include/core/ceres_wrapper/cost_function_wrapper.h include/core/ceres_wrapper/create_numeric_diff_cost_function.h include/core/ceres_wrapper/local_parametrization_wrapper.h + include/core/ceres_wrapper/iteration_update_callback.h + include/core/ceres_wrapper/solver_ceres.h include/core/solver/solver_manager.h include/core/solver_suitesparse/sparse_utils.h ) diff --git a/include/core/ceres_wrapper/iteration_update_callback.h b/include/core/ceres_wrapper/iteration_update_callback.h new file mode 100644 index 000000000..0496c336e --- /dev/null +++ b/include/core/ceres_wrapper/iteration_update_callback.h @@ -0,0 +1,41 @@ +/* + * iteration_callback.h + * + * Created on: Jun 15, 2020 + * Author: joanvallve + */ + +#ifndef INCLUDE_CORE_CERES_WRAPPER_ITERATION_UPDATE_CALLBACK_H_ +#define INCLUDE_CORE_CERES_WRAPPER_ITERATION_UPDATE_CALLBACK_H_ + +#include "core/problem/problem.h" +#include "ceres/ceres.h" + +namespace wolf { + +class IterationUpdateCallback : public ceres::IterationCallback +{ + public: + explicit IterationUpdateCallback(ProblemPtr _problem) + : problem_(_problem) {} + + ~IterationUpdateCallback() {} + + ceres::CallbackReturnType operator()(const ceres::IterationSummary& summary) override + { + if (problem_->getStateBlockNotificationMapSize() != 0 or + problem_->getFactorNotificationMapSize() != 0) + { + WOLF_INFO("Stopping solver to update the problem!"); + return ceres::SOLVER_TERMINATE_SUCCESSFULLY; + } + return ceres::SOLVER_CONTINUE; + } + + private: + ProblemPtr problem_; +}; + +} + +#endif /* INCLUDE_CORE_CERES_WRAPPER_ITERATION_UPDATE_CALLBACK_H_ */ diff --git a/include/core/ceres_wrapper/solver_ceres.h b/include/core/ceres_wrapper/solver_ceres.h index cb160c87d..d3c6d17ac 100644 --- a/include/core/ceres_wrapper/solver_ceres.h +++ b/include/core/ceres_wrapper/solver_ceres.h @@ -9,9 +9,6 @@ //wolf includes #include "core/solver/solver_manager.h" #include "core/utils/params_server.h" -#include "core/ceres_wrapper/cost_function_wrapper.h" -#include "core/ceres_wrapper/local_parametrization_wrapper.h" -#include "core/ceres_wrapper/create_numeric_diff_cost_function.h" namespace ceres { typedef std::shared_ptr<CostFunction> CostFunctionPtr; @@ -20,6 +17,7 @@ typedef std::shared_ptr<CostFunction> CostFunctionPtr; namespace wolf { WOLF_PTR_TYPEDEFS(SolverCeres); +WOLF_PTR_TYPEDEFS(LocalParametrizationWrapper); WOLF_STRUCT_PTR_TYPEDEFS(ParamsCeres); struct ParamsCeres : public ParamsSolver @@ -49,7 +47,6 @@ struct ParamsCeres : public ParamsSolver #endif covariance_options.num_threads = 1; covariance_options.apply_loss_function = false; - } ParamsCeres(std::string _unique_name, const ParamsServer& _server) : @@ -109,7 +106,6 @@ class SolverCeres : public SolverManager ceres::Solver::Options& getSolverOptions(); - const Eigen::SparseMatrixd computeHessian() const; protected: diff --git a/include/core/solver/solver_manager.h b/include/core/solver/solver_manager.h index 1fc5e6094..96cdf4494 100644 --- a/include/core/solver/solver_manager.h +++ b/include/core/solver/solver_manager.h @@ -121,6 +121,8 @@ class SolverManager */ virtual bool ready() const; + ReportVerbosity getVerbosity() const; + virtual bool isStateBlockRegistered(const StateBlockPtr& state_ptr); virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const; diff --git a/src/ceres_wrapper/solver_ceres.cpp b/src/ceres_wrapper/solver_ceres.cpp index 6e7c8ac56..b3413b2c7 100644 --- a/src/ceres_wrapper/solver_ceres.cpp +++ b/src/ceres_wrapper/solver_ceres.cpp @@ -1,5 +1,8 @@ -#include "../../include/core/ceres_wrapper/solver_ceres.h" +#include "core/ceres_wrapper/solver_ceres.h" #include "core/ceres_wrapper/create_numeric_diff_cost_function.h" +#include "core/ceres_wrapper/cost_function_wrapper.h" +#include "core/ceres_wrapper/iteration_update_callback.h" +#include "core/ceres_wrapper/local_parametrization_wrapper.h" #include "core/trajectory/trajectory_base.h" #include "core/map/map_base.h" #include "core/landmark/landmark_base.h" @@ -11,16 +14,18 @@ namespace wolf SolverCeres::SolverCeres(const ProblemPtr& _wolf_problem) : SolverCeres(_wolf_problem, std::make_shared<ParamsCeres>()) { - } SolverCeres::SolverCeres(const ProblemPtr& _wolf_problem, - const ParamsCeresPtr& _params) : - SolverManager(_wolf_problem, _params), - params_ceres_(_params) + const ParamsCeresPtr& _params) + : SolverManager(_wolf_problem, _params) + , params_ceres_(_params) { covariance_ = wolf::make_unique<ceres::Covariance>(params_ceres_->covariance_options); ceres_problem_ = wolf::make_unique<ceres::Problem>(params_ceres_->problem_options); + + if (params_ceres_->update_immediately) + getSolverOptions().callbacks.push_back(new IterationUpdateCallback(wolf_problem_)); } SolverCeres::~SolverCeres() @@ -33,6 +38,12 @@ SolverCeres::~SolverCeres() removeStateBlockDerived(state_blocks_.begin()->first); state_blocks_.erase(state_blocks_.begin()); } + + while (!getSolverOptions().callbacks.empty()) + { + delete getSolverOptions().callbacks.back(); + getSolverOptions().callbacks.pop_back(); + } } std::string SolverCeres::solveDerived(const ReportVerbosity report_level) diff --git a/src/solver/solver_manager.cpp b/src/solver/solver_manager.cpp index 4d32ac6de..a0b6f028c 100644 --- a/src/solver/solver_manager.cpp +++ b/src/solver/solver_manager.cpp @@ -323,6 +323,11 @@ bool SolverManager::ready() const return (!last_solve_ts_.ok() || (TimeStamp::Now() - last_solve_ts_) > params_->period); } +SolverManager::ReportVerbosity SolverManager::getVerbosity() const +{ + return params_->verbose; +} + bool SolverManager::isStateBlockRegistered(const StateBlockPtr& state_ptr) { return state_blocks_.count(state_ptr) ==1 && isStateBlockRegisteredDerived(state_ptr); diff --git a/test/gtest_solver_ceres.cpp b/test/gtest_solver_ceres.cpp index 7b811e7b9..42b2f1c82 100644 --- a/test/gtest_solver_ceres.cpp +++ b/test/gtest_solver_ceres.cpp @@ -7,21 +7,22 @@ #include "core/utils/utils_gtest.h" - #include "core/problem/problem.h" #include "core/sensor/sensor_base.h" #include "core/state_block/state_block.h" #include "core/capture/capture_void.h" #include "core/factor/factor_pose_2d.h" #include "core/factor/factor_quaternion_absolute.h" -#include "core/solver/solver_manager.h" #include "core/state_block/local_parametrization_angle.h" #include "core/state_block/local_parametrization_quaternion.h" +#include "core/solver/solver_manager.h" +#include "core/ceres_wrapper/solver_ceres.h" +#include "core/ceres_wrapper/local_parametrization_wrapper.h" + #include "ceres/ceres.h" #include <iostream> -#include "../include/core/ceres_wrapper/solver_ceres.h" using namespace wolf; using namespace Eigen; -- GitLab