Skip to content
Snippets Groups Projects
Commit bc67170c authored by Jeremie Deray's avatar Jeremie Deray
Browse files

ceres_manager inherits from new solver_manager

parent 546200e4
No related branches found
No related tags found
No related merge requests found
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
namespace wolf { namespace wolf {
CeresManager::CeresManager(ProblemPtr _wolf_problem, const ceres::Solver::Options& _ceres_options) : CeresManager::CeresManager(ProblemPtr& _wolf_problem,
const ceres::Solver::Options& _ceres_options) :
SolverManager(_wolf_problem), SolverManager(_wolf_problem),
ceres_options_(_ceres_options) ceres_options_(_ceres_options)
{ {
...@@ -28,46 +29,37 @@ CeresManager::CeresManager(ProblemPtr _wolf_problem, const ceres::Solver::Option ...@@ -28,46 +29,37 @@ CeresManager::CeresManager(ProblemPtr _wolf_problem, const ceres::Solver::Option
problem_options.cost_function_ownership = ceres::DO_NOT_TAKE_OWNERSHIP; problem_options.cost_function_ownership = ceres::DO_NOT_TAKE_OWNERSHIP;
problem_options.loss_function_ownership = ceres::TAKE_OWNERSHIP; problem_options.loss_function_ownership = ceres::TAKE_OWNERSHIP;
problem_options.local_parameterization_ownership = ceres::DO_NOT_TAKE_OWNERSHIP; problem_options.local_parameterization_ownership = ceres::DO_NOT_TAKE_OWNERSHIP;
ceres_problem_ = wolf::make_unique<ceres::Problem>(problem_options); ceres_problem_ = wolf::make_unique<ceres::Problem>(problem_options);
} }
CeresManager::~CeresManager() CeresManager::~CeresManager()
{ {
// std::cout << "ceres residual blocks: " << ceres_problem_->NumResidualBlocks() << std::endl;
// std::cout << "ceres parameter blocks: " << ceres_problem_->NumParameterBlocks() << std::endl;
while (!ctr_2_residual_idx_.empty()) while (!ctr_2_residual_idx_.empty())
removeConstraint(ctr_2_residual_idx_.begin()->first); removeConstraint(ctr_2_residual_idx_.begin()->first);
// std::cout << "all residuals removed! \n";
} }
std::string CeresManager::solve(const unsigned int& _report_level) std::string CeresManager::solveImpl(const ReportVerbosity report_level)
{ {
//std::cout << "Residual blocks: " << ceres_problem_->NumResidualBlocks() << " Parameter blocks: " << ceres_problem_->NumParameterBlocks() << std::endl;
// update problem // update problem
update(); update();
//std::cout << "After Update: Residual blocks: " << ceres_problem_->NumResidualBlocks() << " Parameter blocks: " << ceres_problem_->NumParameterBlocks() << std::endl;
// run Ceres Solver // run Ceres Solver
ceres::Solve(ceres_options_, ceres_problem_.get(), &summary_); ceres::Solve(ceres_options_, ceres_problem_.get(), &summary_);
//std::cout << "solved" << std::endl;
std::string report;
//return report //return report
if (_report_level == 0) if (report_level == ReportVerbosity::BRIEF)
return std::string(); report = summary_.BriefReport();
else if (_report_level == 1) else if (report_level == ReportVerbosity::FULL)
return summary_.BriefReport(); report = summary_.FullReport();
else if (_report_level == 2)
return summary_.FullReport(); return report;
else
throw std::invalid_argument( "Report level should be 0, 1 or 2!" );
} }
void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks) void CeresManager::computeCovariances(const CovarianceBlocksToBeComputed _blocks)
{ {
//std::cout << "CeresManager: computing covariances..." << std::endl;
// update problem // update problem
update(); update();
...@@ -80,10 +72,10 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks) ...@@ -80,10 +72,10 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks)
switch (_blocks) switch (_blocks)
{ {
case ALL: case CovarianceBlocksToBeComputed::ALL:
{ {
// first create a vector containing all state blocks // first create a vector containing all state blocks
std::vector<StateBlockPtr> all_state_blocks, landmark_state_blocks, Frame_state_blocks; std::vector<StateBlockPtr> all_state_blocks, landmark_state_blocks;
//frame state blocks //frame state blocks
for(auto fr_ptr : wolf_problem_->getTrajectoryPtr()->getFrameList()) for(auto fr_ptr : wolf_problem_->getTrajectoryPtr()->getFrameList())
if (fr_ptr->isKey()) if (fr_ptr->isKey())
...@@ -102,11 +94,12 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks) ...@@ -102,11 +94,12 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks)
for (unsigned int j = i; j < all_state_blocks.size(); j++) for (unsigned int j = i; j < all_state_blocks.size(); j++)
{ {
state_block_pairs.emplace_back(all_state_blocks[i],all_state_blocks[j]); state_block_pairs.emplace_back(all_state_blocks[i],all_state_blocks[j]);
double_pairs.emplace_back(all_state_blocks[i]->getPtr(),all_state_blocks[j]->getPtr()); double_pairs.emplace_back(getAssociatedMemBlockPtr(all_state_blocks[i]),
getAssociatedMemBlockPtr(all_state_blocks[j]));
} }
break; break;
} }
case ALL_MARGINALS: case CovarianceBlocksToBeComputed::ALL_MARGINALS:
{ {
// first create a vector containing all state blocks // first create a vector containing all state blocks
for(auto fr_ptr : wolf_problem_->getTrajectoryPtr()->getFrameList()) for(auto fr_ptr : wolf_problem_->getTrajectoryPtr()->getFrameList())
...@@ -117,7 +110,7 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks) ...@@ -117,7 +110,7 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks)
if (sb) if (sb)
{ {
state_block_pairs.emplace_back(sb, sb2); state_block_pairs.emplace_back(sb, sb2);
double_pairs.emplace_back(sb->getPtr(), sb2->getPtr()); double_pairs.emplace_back(getAssociatedMemBlockPtr(sb), getAssociatedMemBlockPtr(sb2));
if (sb == sb2) break; if (sb == sb2) break;
} }
...@@ -127,7 +120,7 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks) ...@@ -127,7 +120,7 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks)
for(auto sb2 : l_ptr->getUsedStateBlockVec()) for(auto sb2 : l_ptr->getUsedStateBlockVec())
{ {
state_block_pairs.emplace_back(sb, sb2); state_block_pairs.emplace_back(sb, sb2);
double_pairs.emplace_back(sb->getPtr(), sb2->getPtr()); double_pairs.emplace_back(getAssociatedMemBlockPtr(sb), getAssociatedMemBlockPtr(sb2));
if (sb == sb2) break; if (sb == sb2) break;
} }
// // loop all marginals (PO marginals) // // loop all marginals (PO marginals)
...@@ -143,7 +136,7 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks) ...@@ -143,7 +136,7 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks)
// } // }
break; break;
} }
case ROBOT_LANDMARKS: case CovarianceBlocksToBeComputed::ROBOT_LANDMARKS:
{ {
//robot-robot //robot-robot
auto last_key_frame = wolf_problem_->getLastKeyFramePtr(); auto last_key_frame = wolf_problem_->getLastKeyFramePtr();
...@@ -152,9 +145,12 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks) ...@@ -152,9 +145,12 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks)
state_block_pairs.emplace_back(last_key_frame->getPPtr(), last_key_frame->getOPtr()); state_block_pairs.emplace_back(last_key_frame->getPPtr(), last_key_frame->getOPtr());
state_block_pairs.emplace_back(last_key_frame->getOPtr(), last_key_frame->getOPtr()); state_block_pairs.emplace_back(last_key_frame->getOPtr(), last_key_frame->getOPtr());
double_pairs.emplace_back(last_key_frame->getPPtr()->getPtr(), last_key_frame->getPPtr()->getPtr()); double_pairs.emplace_back(getAssociatedMemBlockPtr(last_key_frame->getPPtr()),
double_pairs.emplace_back(last_key_frame->getPPtr()->getPtr(), last_key_frame->getOPtr()->getPtr()); getAssociatedMemBlockPtr(last_key_frame->getPPtr()));
double_pairs.emplace_back(last_key_frame->getOPtr()->getPtr(), last_key_frame->getOPtr()->getPtr()); double_pairs.emplace_back(getAssociatedMemBlockPtr(last_key_frame->getPPtr()),
getAssociatedMemBlockPtr(last_key_frame->getOPtr()));
double_pairs.emplace_back(getAssociatedMemBlockPtr(last_key_frame->getOPtr()),
getAssociatedMemBlockPtr(last_key_frame->getOPtr()));
// landmarks // landmarks
std::vector<StateBlockPtr> landmark_state_blocks; std::vector<StateBlockPtr> landmark_state_blocks;
...@@ -168,14 +164,17 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks) ...@@ -168,14 +164,17 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks)
// robot - landmark // robot - landmark
state_block_pairs.emplace_back(last_key_frame->getPPtr(), *state_it); state_block_pairs.emplace_back(last_key_frame->getPPtr(), *state_it);
state_block_pairs.emplace_back(last_key_frame->getOPtr(), *state_it); state_block_pairs.emplace_back(last_key_frame->getOPtr(), *state_it);
double_pairs.emplace_back(last_key_frame->getPPtr()->getPtr(), (*state_it)->getPtr()); double_pairs.emplace_back(getAssociatedMemBlockPtr(last_key_frame->getPPtr()),
double_pairs.emplace_back(last_key_frame->getOPtr()->getPtr(), (*state_it)->getPtr()); getAssociatedMemBlockPtr((*state_it)));
double_pairs.emplace_back(getAssociatedMemBlockPtr(last_key_frame->getOPtr()),
getAssociatedMemBlockPtr((*state_it)));
// landmark marginal // landmark marginal
for (auto next_state_it = state_it; next_state_it != landmark_state_blocks.end(); next_state_it++) for (auto next_state_it = state_it; next_state_it != landmark_state_blocks.end(); next_state_it++)
{ {
state_block_pairs.emplace_back(*state_it, *next_state_it); state_block_pairs.emplace_back(*state_it, *next_state_it);
double_pairs.emplace_back((*state_it)->getPtr(), (*next_state_it)->getPtr()); double_pairs.emplace_back(getAssociatedMemBlockPtr((*state_it)),
getAssociatedMemBlockPtr((*next_state_it)));
} }
} }
} }
...@@ -218,8 +217,9 @@ void CeresManager::computeCovariances(const StateBlockList& st_list) ...@@ -218,8 +217,9 @@ void CeresManager::computeCovariances(const StateBlockList& st_list)
for (auto st_it1 = st_list.begin(); st_it1 != st_list.end(); st_it1++) for (auto st_it1 = st_list.begin(); st_it1 != st_list.end(); st_it1++)
for (auto st_it2 = st_it1; st_it2 != st_list.end(); st_it2++) for (auto st_it2 = st_it1; st_it2 != st_list.end(); st_it2++)
{ {
state_block_pairs.push_back(std::pair<StateBlockPtr, StateBlockPtr>(*st_it1,*st_it2)); state_block_pairs.emplace_back(*st_it1, *st_it2);
double_pairs.push_back(std::pair<const double*, const double*>((*st_it1)->getPtr(),(*st_it2)->getPtr())); double_pairs.emplace_back(getAssociatedMemBlockPtr((*st_it1)),
getAssociatedMemBlockPtr((*st_it2)));
} }
//std::cout << "pairs... " << double_pairs.size() << std::endl; //std::cout << "pairs... " << double_pairs.size() << std::endl;
...@@ -238,72 +238,69 @@ void CeresManager::computeCovariances(const StateBlockList& st_list) ...@@ -238,72 +238,69 @@ void CeresManager::computeCovariances(const StateBlockList& st_list)
std::cout << "WARNING: Couldn't compute covariances!" << std::endl; std::cout << "WARNING: Couldn't compute covariances!" << std::endl;
} }
void CeresManager::addConstraint(ConstraintBasePtr _ctr_ptr) void CeresManager::addConstraint(const ConstraintBasePtr& ctr_ptr)
{ {
ctr_2_costfunction_[_ctr_ptr] = createCostFunction(_ctr_ptr); auto cost_func_ptr = createCostFunction(ctr_ptr);
// std::cout << "adding constraint " << _ctr_ptr->id() << std::endl; ctr_2_costfunction_[ctr_ptr] = cost_func_ptr;
// std::cout << "constraint pointer " << _ctr_ptr << std::endl;
if (_ctr_ptr->getApplyLossFunction()) std::vector<Scalar*> res_block_mem;
ctr_2_residual_idx_[_ctr_ptr] = ceres_problem_->AddResidualBlock(ctr_2_costfunction_[_ctr_ptr].get(), new ceres::CauchyLoss(0.5), _ctr_ptr->getStateScalarPtrVector()); for (const StateBlockPtr& st : ctr_ptr->getStateBlockPtrVector())
else {
ctr_2_residual_idx_[_ctr_ptr] = ceres_problem_->AddResidualBlock(ctr_2_costfunction_[_ctr_ptr].get(), NULL, _ctr_ptr->getStateScalarPtrVector()); res_block_mem.push_back( getAssociatedMemBlockPtr(st) );
}
auto loss_func_ptr = (ctr_ptr->getApplyLossFunction())?
new ceres::CauchyLoss(0.5): nullptr;
ctr_2_residual_idx_[ctr_ptr] =
ceres_problem_->AddResidualBlock(cost_func_ptr.get(),
loss_func_ptr, res_block_mem);
assert(ceres_problem_->NumResidualBlocks() == ctr_2_residual_idx_.size() && "ceres residuals different from wrapper residuals"); assert(ceres_problem_->NumResidualBlocks() == ctr_2_residual_idx_.size() && "ceres residuals different from wrapper residuals");
} }
void CeresManager::removeConstraint(ConstraintBasePtr _ctr_ptr) void CeresManager::removeConstraint(const ConstraintBasePtr& _ctr_ptr)
{ {
// std::cout << "removing constraint " << _ctr_ptr->id() << std::endl;
assert(ctr_2_residual_idx_.find(_ctr_ptr) != ctr_2_residual_idx_.end()); assert(ctr_2_residual_idx_.find(_ctr_ptr) != ctr_2_residual_idx_.end());
ceres_problem_->RemoveResidualBlock(ctr_2_residual_idx_[_ctr_ptr]); ceres_problem_->RemoveResidualBlock(ctr_2_residual_idx_[_ctr_ptr]);
ctr_2_residual_idx_.erase(_ctr_ptr); ctr_2_residual_idx_.erase(_ctr_ptr);
ctr_2_costfunction_.erase(_ctr_ptr); ctr_2_costfunction_.erase(_ctr_ptr);
// std::cout << "removingremoved!" << std::endl;
assert(ceres_problem_->NumResidualBlocks() == ctr_2_residual_idx_.size() && "ceres residuals different from wrapper residuals"); assert(ceres_problem_->NumResidualBlocks() == ctr_2_residual_idx_.size() && "ceres residuals different from wrapper residuals");
} }
void CeresManager::addStateBlock(StateBlockPtr _st_ptr) void CeresManager::addStateBlock(const StateBlockPtr& state_ptr)
{ {
// std::cout << "Adding State Block " << _st_ptr->getPtr() << std::endl; /// @todo we create a new object LocalParametrizationWrapper
// std::cout << " size: " << _st_ptr->getSize() << std::endl; /// but ceres do not take the ownership, thus this is not deleted properly
// std::cout << " vector: " << _st_ptr->getVector().transpose() << std::endl; auto local_parametrization_ptr = (state_ptr->hasLocalParametrization())?
new LocalParametrizationWrapper(state_ptr->getLocalParametrizationPtr()) : nullptr;
if (_st_ptr->hasLocalParametrization())
{ ceres_problem_->AddParameterBlock(getAssociatedMemBlockPtr(state_ptr),
// std::cout << "Local Parametrization to be added:" << _st_ptr->getLocalParametrizationPtr() << std::endl; state_ptr->getSize(),
ceres_problem_->AddParameterBlock(_st_ptr->getPtr(), _st_ptr->getSize(), new LocalParametrizationWrapper(_st_ptr->getLocalParametrizationPtr())); local_parametrization_ptr);
}
else
{
// std::cout << "No Local Parametrization to be added" << std::endl;
ceres_problem_->AddParameterBlock(_st_ptr->getPtr(), _st_ptr->getSize(), nullptr);
}
} }
void CeresManager::removeStateBlock(StateBlockPtr _st_ptr) void CeresManager::removeStateBlock(const StateBlockPtr& state_ptr)
{ {
//std::cout << "Removing State Block " << _st_ptr << std::endl; assert(state_ptr);
assert(_st_ptr); ceres_problem_->RemoveParameterBlock(getAssociatedMemBlockPtr(state_ptr));
ceres_problem_->RemoveParameterBlock(_st_ptr->getPtr());
} }
void CeresManager::updateStateBlockStatus(StateBlockPtr _st_ptr) void CeresManager::updateStateBlockStatus(const StateBlockPtr& state_ptr)
{ {
assert(_st_ptr != nullptr); assert(state_ptr != nullptr);
if (_st_ptr->isFixed()) if (state_ptr->isFixed())
ceres_problem_->SetParameterBlockConstant(_st_ptr->getPtr()); ceres_problem_->SetParameterBlockConstant(getAssociatedMemBlockPtr(state_ptr));
else else
ceres_problem_->SetParameterBlockVariable(_st_ptr->getPtr()); ceres_problem_->SetParameterBlockVariable(getAssociatedMemBlockPtr(state_ptr));
} }
ceres::CostFunctionPtr CeresManager::createCostFunction(ConstraintBasePtr _ctr_ptr) ceres::CostFunctionPtr CeresManager::createCostFunction(const ConstraintBasePtr& _ctr_ptr)
{ {
assert(_ctr_ptr != nullptr); assert(_ctr_ptr != nullptr);
//std::cout << "creating cost function for constraint " << _ctr_ptr->id() << std::endl;
// analitic & autodiff jacobian // analitic & autodiff jacobian
if (_ctr_ptr->getJacobianMethod() == JAC_ANALYTIC || _ctr_ptr->getJacobianMethod() == JAC_AUTO) if (_ctr_ptr->getJacobianMethod() == JAC_ANALYTIC || _ctr_ptr->getJacobianMethod() == JAC_AUTO)
...@@ -314,7 +311,7 @@ ceres::CostFunctionPtr CeresManager::createCostFunction(ConstraintBasePtr _ctr_p ...@@ -314,7 +311,7 @@ ceres::CostFunctionPtr CeresManager::createCostFunction(ConstraintBasePtr _ctr_p
return createNumericDiffCostFunction(_ctr_ptr); return createNumericDiffCostFunction(_ctr_ptr);
else else
throw std::invalid_argument( "Bad Jacobian Method!" ); throw std::invalid_argument( "Wrong Jacobian Method!" );
} }
} // namespace wolf } // namespace wolf
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "glog/logging.h" #include "glog/logging.h"
//wolf includes //wolf includes
#include "solver_manager.h" #include "../solver/solver_manager.h"
#include "cost_function_wrapper.h" #include "cost_function_wrapper.h"
#include "local_parametrization_wrapper.h" #include "local_parametrization_wrapper.h"
#include "create_numeric_diff_cost_function.h" #include "create_numeric_diff_cost_function.h"
...@@ -39,15 +39,17 @@ protected: ...@@ -39,15 +39,17 @@ protected:
std::unique_ptr<ceres::Covariance> covariance_; std::unique_ptr<ceres::Covariance> covariance_;
public: public:
CeresManager(ProblemPtr _wolf_problem, const ceres::Solver::Options& _ceres_options = ceres::Solver::Options());
~CeresManager(); CeresManager(ProblemPtr& _wolf_problem,
const ceres::Solver::Options& _ceres_options
= ceres::Solver::Options());
virtual std::string solve(const unsigned int& _report_level); ~CeresManager();
ceres::Solver::Summary getSummary(); ceres::Solver::Summary getSummary();
virtual void computeCovariances(CovarianceBlocksToBeComputed _blocks = ROBOT_LANDMARKS); virtual void computeCovariances(CovarianceBlocksToBeComputed _blocks
= CovarianceBlocksToBeComputed::ROBOT_LANDMARKS);
virtual void computeCovariances(const StateBlockList& st_list); virtual void computeCovariances(const StateBlockList& st_list);
...@@ -55,17 +57,19 @@ public: ...@@ -55,17 +57,19 @@ public:
private: private:
virtual void addConstraint(ConstraintBasePtr _ctr_ptr); std::string solveImpl(const ReportVerbosity report_level) override;
void addConstraint(const ConstraintBasePtr& ctr_ptr) override;
virtual void removeConstraint(ConstraintBasePtr _ctr_ptr); void removeConstraint(const ConstraintBasePtr& ctr_ptr) override;
virtual void addStateBlock(StateBlockPtr _st_ptr); void addStateBlock(const StateBlockPtr& state_ptr) override;
virtual void removeStateBlock(StateBlockPtr _st_ptr); void removeStateBlock(const StateBlockPtr& state_ptr) override;
virtual void updateStateBlockStatus(StateBlockPtr _st_ptr); void updateStateBlockStatus(const StateBlockPtr& state_ptr) override;
ceres::CostFunctionPtr createCostFunction(ConstraintBasePtr _ctr_ptr); ceres::CostFunctionPtr createCostFunction(const ConstraintBasePtr& _ctr_ptr);
}; };
inline ceres::Solver::Summary CeresManager::getSummary() inline ceres::Solver::Summary CeresManager::getSummary()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment