Skip to content
Snippets Groups Projects
Commit 5b1e38cd authored by Joan Vallvé Navarro's avatar Joan Vallvé Navarro
Browse files

Resolve "SolverCeres stop solving if can update"

parent 8769670d
No related branches found
No related tags found
No related merge requests found
...@@ -123,7 +123,7 @@ ENDIF(BUILD_DEMOS OR BUILD_TESTS) ...@@ -123,7 +123,7 @@ ENDIF(BUILD_DEMOS OR BUILD_TESTS)
FIND_PACKAGE(Threads REQUIRED) FIND_PACKAGE(Threads REQUIRED)
FIND_PACKAGE(Ceres REQUIRED) #Ceres is not required FIND_PACKAGE(Ceres REQUIRED) #Ceres is required
FIND_PACKAGE(Eigen3 3.3 REQUIRED) FIND_PACKAGE(Eigen3 3.3 REQUIRED)
if(${EIGEN3_VERSION_STRING} VERSION_LESS 3.3) if(${EIGEN3_VERSION_STRING} VERSION_LESS 3.3)
...@@ -392,10 +392,11 @@ SET(SRCS_YAML ...@@ -392,10 +392,11 @@ SET(SRCS_YAML
IF (Ceres_FOUND) IF (Ceres_FOUND)
SET(HDRS_WRAPPER SET(HDRS_WRAPPER
#ceres_wrapper/qr_manager.h #ceres_wrapper/qr_manager.h
include/core/ceres_wrapper/solver_ceres.h
include/core/ceres_wrapper/cost_function_wrapper.h include/core/ceres_wrapper/cost_function_wrapper.h
include/core/ceres_wrapper/create_numeric_diff_cost_function.h include/core/ceres_wrapper/create_numeric_diff_cost_function.h
include/core/ceres_wrapper/local_parametrization_wrapper.h include/core/ceres_wrapper/local_parametrization_wrapper.h
include/core/ceres_wrapper/iteration_update_callback.h
include/core/ceres_wrapper/solver_ceres.h
include/core/solver/solver_manager.h include/core/solver/solver_manager.h
include/core/solver_suitesparse/sparse_utils.h include/core/solver_suitesparse/sparse_utils.h
) )
......
/*
* iteration_callback.h
*
* Created on: Jun 15, 2020
* Author: joanvallve
*/
#ifndef INCLUDE_CORE_CERES_WRAPPER_ITERATION_UPDATE_CALLBACK_H_
#define INCLUDE_CORE_CERES_WRAPPER_ITERATION_UPDATE_CALLBACK_H_
#include "core/problem/problem.h"
#include "ceres/ceres.h"
namespace wolf {
class IterationUpdateCallback : public ceres::IterationCallback
{
public:
explicit IterationUpdateCallback(ProblemPtr _problem, bool verbose = false)
: problem_(_problem)
, verbose_(verbose)
{}
~IterationUpdateCallback() {}
ceres::CallbackReturnType operator()(const ceres::IterationSummary& summary) override
{
if (problem_->getStateBlockNotificationMapSize() != 0 or
problem_->getFactorNotificationMapSize() != 0)
{
WOLF_INFO_COND(verbose_, "Stopping solver to update the problem!");
return ceres::SOLVER_TERMINATE_SUCCESSFULLY;
}
return ceres::SOLVER_CONTINUE;
}
private:
ProblemPtr problem_;
bool verbose_;
};
}
#endif /* INCLUDE_CORE_CERES_WRAPPER_ITERATION_UPDATE_CALLBACK_H_ */
...@@ -9,9 +9,6 @@ ...@@ -9,9 +9,6 @@
//wolf includes //wolf includes
#include "core/solver/solver_manager.h" #include "core/solver/solver_manager.h"
#include "core/utils/params_server.h" #include "core/utils/params_server.h"
#include "core/ceres_wrapper/cost_function_wrapper.h"
#include "core/ceres_wrapper/local_parametrization_wrapper.h"
#include "core/ceres_wrapper/create_numeric_diff_cost_function.h"
namespace ceres { namespace ceres {
typedef std::shared_ptr<CostFunction> CostFunctionPtr; typedef std::shared_ptr<CostFunction> CostFunctionPtr;
...@@ -20,21 +17,39 @@ typedef std::shared_ptr<CostFunction> CostFunctionPtr; ...@@ -20,21 +17,39 @@ typedef std::shared_ptr<CostFunction> CostFunctionPtr;
namespace wolf { namespace wolf {
WOLF_PTR_TYPEDEFS(SolverCeres); WOLF_PTR_TYPEDEFS(SolverCeres);
WOLF_PTR_TYPEDEFS(LocalParametrizationWrapper);
WOLF_STRUCT_PTR_TYPEDEFS(ParamsCeres); WOLF_STRUCT_PTR_TYPEDEFS(ParamsCeres);
struct ParamsCeres : public ParamsSolver struct ParamsCeres : public ParamsSolver
{ {
bool update_immediately; bool update_immediately = false;
ceres::Solver::Options solver_options; ceres::Solver::Options solver_options;
ceres::Problem::Options problem_options; ceres::Problem::Options problem_options;
ceres::Covariance::Options covariance_options; ceres::Covariance::Options covariance_options;
ParamsCeres() : ParamsCeres() :
update_immediately(false), ParamsSolver()
solver_options(),
problem_options(),
covariance_options()
{ {
loadHardcodedValues();
}
ParamsCeres(std::string _unique_name, const ParamsServer& _server) :
ParamsSolver(_unique_name, _server)
{
loadHardcodedValues();
// stop solver whenever the problem is updated (via ceres::iterationCallback)
update_immediately = _server.getParam<bool>(prefix + "update_immediately");
// ceres solver options
solver_options.max_num_iterations = _server.getParam<int>(prefix + "max_num_iterations");
}
void loadHardcodedValues()
{
solver_options = ceres::Solver::Options();
problem_options = ceres::Problem::Options();
covariance_options = ceres::Covariance::Options();
problem_options.cost_function_ownership = ceres::DO_NOT_TAKE_OWNERSHIP; problem_options.cost_function_ownership = ceres::DO_NOT_TAKE_OWNERSHIP;
problem_options.loss_function_ownership = ceres::TAKE_OWNERSHIP; problem_options.loss_function_ownership = ceres::TAKE_OWNERSHIP;
problem_options.local_parameterization_ownership = ceres::DO_NOT_TAKE_OWNERSHIP; problem_options.local_parameterization_ownership = ceres::DO_NOT_TAKE_OWNERSHIP;
...@@ -49,16 +64,6 @@ struct ParamsCeres : public ParamsSolver ...@@ -49,16 +64,6 @@ struct ParamsCeres : public ParamsSolver
#endif #endif
covariance_options.num_threads = 1; covariance_options.num_threads = 1;
covariance_options.apply_loss_function = false; covariance_options.apply_loss_function = false;
}
ParamsCeres(std::string _unique_name, const ParamsServer& _server) :
ParamsCeres()
{
update_immediately = _server.getParam<bool>(prefix + "update_immediately");
// ceres solver options
solver_options.max_num_iterations = _server.getParam<int>(prefix + "max_num_iterations");
} }
~ParamsCeres() override = default; ~ParamsCeres() override = default;
...@@ -84,7 +89,7 @@ class SolverCeres : public SolverManager ...@@ -84,7 +89,7 @@ class SolverCeres : public SolverManager
SolverCeres(const ProblemPtr& _wolf_problem); SolverCeres(const ProblemPtr& _wolf_problem);
SolverCeres(const ProblemPtr& _wolf_problem, SolverCeres(const ProblemPtr& _wolf_problem,
const ParamsCeresPtr& _params); const ParamsCeresPtr& _params);
WOLF_SOLVER_CREATE(SolverCeres, ParamsCeres); WOLF_SOLVER_CREATE(SolverCeres, ParamsCeres);
...@@ -109,7 +114,6 @@ class SolverCeres : public SolverManager ...@@ -109,7 +114,6 @@ class SolverCeres : public SolverManager
ceres::Solver::Options& getSolverOptions(); ceres::Solver::Options& getSolverOptions();
const Eigen::SparseMatrixd computeHessian() const; const Eigen::SparseMatrixd computeHessian() const;
protected: protected:
......
...@@ -38,115 +38,147 @@ static SolverManagerPtr create(const ProblemPtr& _problem, ...@@ -38,115 +38,147 @@ static SolverManagerPtr create(const ProblemPtr& _problem,
return std::make_shared<SolverClass>(_problem, params); \ return std::make_shared<SolverClass>(_problem, params); \
} \ } \
struct ParamsSolver: public ParamsBase struct ParamsSolver;
{
std::string prefix = "solver/";
ParamsSolver() = default;
ParamsSolver(std::string _unique_name, const ParamsServer& _server):
ParamsBase(_unique_name, _server)
{
//
}
~ParamsSolver() override = default;
};
/** /**
* \brief Solver manager for WOLF * \brief Solver manager for WOLF
*/ */
class SolverManager class SolverManager
{ {
public: public:
/** \brief Enumeration of covariance blocks to be computed /** \brief Enumeration of covariance blocks to be computed
* *
* Enumeration of covariance blocks to be computed * Enumeration of covariance blocks to be computed
* *
*/ */
enum class CovarianceBlocksToBeComputed : std::size_t enum class CovarianceBlocksToBeComputed : std::size_t
{ {
ALL, ///< All blocks and all cross-covariances ALL, ///< All blocks and all cross-covariances
ALL_MARGINALS, ///< All marginals ALL_MARGINALS, ///< All marginals
ROBOT_LANDMARKS ///< marginals of landmarks and current robot pose plus cross covariances of current robot and all landmarks ROBOT_LANDMARKS ///< marginals of landmarks and current robot pose plus cross covariances of current robot and all landmarks
}; };
/** /**
* \brief Enumeration for the verbosity of the solver report. * \brief Enumeration for the verbosity of the solver report.
*/ */
enum class ReportVerbosity : std::size_t enum class ReportVerbosity : std::size_t
{ {
QUIET = 0, QUIET = 0,
BRIEF, BRIEF,
FULL FULL
}; };
protected: protected:
ProblemPtr wolf_problem_; ProblemPtr wolf_problem_;
ParamsSolverPtr params_;
public: TimeStamp last_solve_ts_;
SolverManager(const ProblemPtr& wolf_problem); public:
/**
virtual ~SolverManager(); * \brief Constructor with default params_
*/
std::string solve(const ReportVerbosity report_level = ReportVerbosity::QUIET); SolverManager(const ProblemPtr& _problem);
/**
virtual void computeCovariances(const CovarianceBlocksToBeComputed blocks) = 0; * \brief Constructor with given params_
*/
virtual void computeCovariances(const std::vector<StateBlockPtr>& st_list) = 0; SolverManager(const ProblemPtr& _problem,
const ParamsSolverPtr& _params);
virtual bool hasConverged() = 0;
virtual ~SolverManager();
virtual SizeStd iterations() = 0;
/**
virtual double initialCost() = 0; * \brief Solves with the verbosity defined in params_
*/
virtual double finalCost() = 0; std::string solve();
/**
virtual void update(); * \brief Solves with a given verbosity
*/
ProblemPtr getProblem(); std::string solve(const ReportVerbosity report_level);
virtual bool isStateBlockRegistered(const StateBlockPtr& state_ptr); virtual void computeCovariances(const CovarianceBlocksToBeComputed blocks) = 0;
virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const; virtual void computeCovariances(const std::vector<StateBlockPtr>& st_list) = 0;
bool check(std::string prefix="") const; virtual bool hasConverged() = 0;
protected: virtual SizeStd iterations() = 0;
std::map<StateBlockPtr, Eigen::VectorXd> state_blocks_; virtual double initialCost() = 0;
std::map<StateBlockPtr, FactorBasePtrList> state_blocks_2_factors_;
std::set<FactorBasePtr> factors_; virtual double finalCost() = 0;
virtual Eigen::VectorXd& getAssociatedMemBlock(const StateBlockPtr& state_ptr); /**
const double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const; * \brief Updates solver's problem according to the wolf_problem
virtual double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr); */
virtual void update();
private:
// SolverManager functions ProblemPtr getProblem();
void addFactor(const FactorBasePtr& fac_ptr);
void removeFactor(const FactorBasePtr& fac_ptr); /**
void addStateBlock(const StateBlockPtr& state_ptr); * \brief Returns if solve() should be called (according to period, can be derived to implement other criteria)
void removeStateBlock(const StateBlockPtr& state_ptr); */
void updateStateBlockState(const StateBlockPtr& state_ptr); virtual bool ready() const;
void updateStateBlockStatus(const StateBlockPtr& state_ptr);
void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr); ReportVerbosity getVerbosity() const;
virtual bool isStateBlockRegistered(const StateBlockPtr& state_ptr);
virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const;
bool check(std::string prefix="") const;
protected:
std::map<StateBlockPtr, Eigen::VectorXd> state_blocks_;
std::map<StateBlockPtr, FactorBasePtrList> state_blocks_2_factors_;
std::set<FactorBasePtr> factors_;
virtual Eigen::VectorXd& getAssociatedMemBlock(const StateBlockPtr& state_ptr);
const double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const;
virtual double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr);
private:
// SolverManager functions
void addFactor(const FactorBasePtr& fac_ptr);
void removeFactor(const FactorBasePtr& fac_ptr);
void addStateBlock(const StateBlockPtr& state_ptr);
void removeStateBlock(const StateBlockPtr& state_ptr);
void updateStateBlockState(const StateBlockPtr& state_ptr);
void updateStateBlockStatus(const StateBlockPtr& state_ptr);
void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr);
protected:
// Derived virtual functions
virtual std::string solveDerived(const ReportVerbosity report_level) = 0;
virtual void addFactorDerived(const FactorBasePtr& fac_ptr) = 0;
virtual void removeFactorDerived(const FactorBasePtr& fac_ptr) = 0;
virtual void addStateBlockDerived(const StateBlockPtr& state_ptr) = 0;
virtual void removeStateBlockDerived(const StateBlockPtr& state_ptr) = 0;
virtual void updateStateBlockStatusDerived(const StateBlockPtr& state_ptr) = 0;
virtual void updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) = 0;
virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) = 0;
virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const = 0;
virtual bool checkDerived(std::string prefix="") const = 0;
};
protected: // Params (here becaure it needs of declaration of SolverManager::ReportVerbosity)
// Derived virtual functions struct ParamsSolver: public ParamsBase
virtual std::string solveDerived(const ReportVerbosity report_level) = 0; {
virtual void addFactorDerived(const FactorBasePtr& fac_ptr) = 0; std::string prefix = "solver/";
virtual void removeFactorDerived(const FactorBasePtr& fac_ptr) = 0; double period = 0.0;
virtual void addStateBlockDerived(const StateBlockPtr& state_ptr) = 0; SolverManager::ReportVerbosity verbose = SolverManager::ReportVerbosity::QUIET;
virtual void removeStateBlockDerived(const StateBlockPtr& state_ptr) = 0;
virtual void updateStateBlockStatusDerived(const StateBlockPtr& state_ptr) = 0; ParamsSolver() = default;
virtual void updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) = 0; ParamsSolver(std::string _unique_name, const ParamsServer& _server):
virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) = 0; ParamsBase(_unique_name, _server)
virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const = 0; {
virtual bool checkDerived(std::string prefix="") const = 0; period = _server.getParam<double>(prefix + "period");
verbose = (SolverManager::ReportVerbosity)_server.getParam<int>(prefix + "verbose");
}
~ParamsSolver() override = default;
}; };
} // namespace wolf } // namespace wolf
......
#include "../../include/core/ceres_wrapper/solver_ceres.h" #include "core/ceres_wrapper/solver_ceres.h"
#include "core/ceres_wrapper/create_numeric_diff_cost_function.h" #include "core/ceres_wrapper/create_numeric_diff_cost_function.h"
#include "core/ceres_wrapper/cost_function_wrapper.h"
#include "core/ceres_wrapper/iteration_update_callback.h"
#include "core/ceres_wrapper/local_parametrization_wrapper.h"
#include "core/trajectory/trajectory_base.h" #include "core/trajectory/trajectory_base.h"
#include "core/map/map_base.h" #include "core/map/map_base.h"
#include "core/landmark/landmark_base.h" #include "core/landmark/landmark_base.h"
...@@ -11,16 +14,18 @@ namespace wolf ...@@ -11,16 +14,18 @@ namespace wolf
SolverCeres::SolverCeres(const ProblemPtr& _wolf_problem) : SolverCeres::SolverCeres(const ProblemPtr& _wolf_problem) :
SolverCeres(_wolf_problem, std::make_shared<ParamsCeres>()) SolverCeres(_wolf_problem, std::make_shared<ParamsCeres>())
{ {
} }
SolverCeres::SolverCeres(const ProblemPtr& _wolf_problem, SolverCeres::SolverCeres(const ProblemPtr& _wolf_problem,
const ParamsCeresPtr& _params) : const ParamsCeresPtr& _params)
SolverManager(_wolf_problem), : SolverManager(_wolf_problem, _params)
params_ceres_(_params) , params_ceres_(_params)
{ {
covariance_ = wolf::make_unique<ceres::Covariance>(params_ceres_->covariance_options); covariance_ = wolf::make_unique<ceres::Covariance>(params_ceres_->covariance_options);
ceres_problem_ = wolf::make_unique<ceres::Problem>(params_ceres_->problem_options); ceres_problem_ = wolf::make_unique<ceres::Problem>(params_ceres_->problem_options);
if (params_ceres_->update_immediately)
getSolverOptions().callbacks.push_back(new IterationUpdateCallback(wolf_problem_, params_ceres_->verbose != SolverManager::ReportVerbosity::QUIET));
} }
SolverCeres::~SolverCeres() SolverCeres::~SolverCeres()
...@@ -33,12 +38,18 @@ SolverCeres::~SolverCeres() ...@@ -33,12 +38,18 @@ SolverCeres::~SolverCeres()
removeStateBlockDerived(state_blocks_.begin()->first); removeStateBlockDerived(state_blocks_.begin()->first);
state_blocks_.erase(state_blocks_.begin()); state_blocks_.erase(state_blocks_.begin());
} }
while (!getSolverOptions().callbacks.empty())
{
delete getSolverOptions().callbacks.back();
getSolverOptions().callbacks.pop_back();
}
} }
std::string SolverCeres::solveDerived(const ReportVerbosity report_level) std::string SolverCeres::solveDerived(const ReportVerbosity report_level)
{ {
// run Ceres Solver // run Ceres Solver
ceres::Solve(params_ceres_->solver_options, ceres_problem_.get(), &summary_); ceres::Solve(getSolverOptions(), ceres_problem_.get(), &summary_);
std::string report; std::string report;
......
...@@ -5,10 +5,17 @@ ...@@ -5,10 +5,17 @@
namespace wolf { namespace wolf {
SolverManager::SolverManager(const ProblemPtr& _wolf_problem) : SolverManager::SolverManager(const ProblemPtr& _problem) :
wolf_problem_(_wolf_problem) SolverManager(_problem, std::make_shared<ParamsSolver>())
{ {
assert(_wolf_problem != nullptr && "Passed a nullptr ProblemPtr."); }
SolverManager::SolverManager(const ProblemPtr& _problem,
const ParamsSolverPtr& _params) :
wolf_problem_(_problem),
params_(_params)
{
assert(_problem != nullptr && "Passed a nullptr ProblemPtr.");
} }
SolverManager::~SolverManager() SolverManager::~SolverManager()
...@@ -95,11 +102,19 @@ wolf::ProblemPtr SolverManager::getProblem() ...@@ -95,11 +102,19 @@ wolf::ProblemPtr SolverManager::getProblem()
return wolf_problem_; return wolf_problem_;
} }
std::string SolverManager::solve()
{
return solve(params_->verbose);
}
std::string SolverManager::solve(const ReportVerbosity report_level) std::string SolverManager::solve(const ReportVerbosity report_level)
{ {
// update problem // update problem
update(); update();
last_solve_ts_ = TimeStamp::Now();
// call derived solver
std::string report = solveDerived(report_level); std::string report = solveDerived(report_level);
// update StateBlocks with optimized state value. // update StateBlocks with optimized state value.
...@@ -303,6 +318,16 @@ double* SolverManager::getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) ...@@ -303,6 +318,16 @@ double* SolverManager::getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr)
return it->second.data(); return it->second.data();
} }
bool SolverManager::ready() const
{
return (!last_solve_ts_.ok() || (TimeStamp::Now() - last_solve_ts_) > params_->period);
}
SolverManager::ReportVerbosity SolverManager::getVerbosity() const
{
return params_->verbose;
}
bool SolverManager::isStateBlockRegistered(const StateBlockPtr& state_ptr) bool SolverManager::isStateBlockRegistered(const StateBlockPtr& state_ptr)
{ {
return state_blocks_.count(state_ptr) ==1 && isStateBlockRegisteredDerived(state_ptr); return state_blocks_.count(state_ptr) ==1 && isStateBlockRegisteredDerived(state_ptr);
......
...@@ -7,21 +7,22 @@ ...@@ -7,21 +7,22 @@
#include "core/utils/utils_gtest.h" #include "core/utils/utils_gtest.h"
#include "core/problem/problem.h" #include "core/problem/problem.h"
#include "core/sensor/sensor_base.h" #include "core/sensor/sensor_base.h"
#include "core/state_block/state_block.h" #include "core/state_block/state_block.h"
#include "core/capture/capture_void.h" #include "core/capture/capture_void.h"
#include "core/factor/factor_pose_2d.h" #include "core/factor/factor_pose_2d.h"
#include "core/factor/factor_quaternion_absolute.h" #include "core/factor/factor_quaternion_absolute.h"
#include "core/solver/solver_manager.h"
#include "core/state_block/local_parametrization_angle.h" #include "core/state_block/local_parametrization_angle.h"
#include "core/state_block/local_parametrization_quaternion.h" #include "core/state_block/local_parametrization_quaternion.h"
#include "core/solver/solver_manager.h"
#include "core/ceres_wrapper/solver_ceres.h"
#include "core/ceres_wrapper/local_parametrization_wrapper.h"
#include "ceres/ceres.h" #include "ceres/ceres.h"
#include <iostream> #include <iostream>
#include "../include/core/ceres_wrapper/solver_ceres.h"
using namespace wolf; using namespace wolf;
using namespace Eigen; using namespace Eigen;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment