Skip to content
Snippets Groups Projects
Commit 2d2c86c9 authored by Joan Vallvé Navarro's avatar Joan Vallvé Navarro
Browse files

Solvers const methods, added "failed" method

parent 09f17525
No related branches found
No related tags found
1 merge request!448Draft: Resolve "Implementation of new nodes creation"
Pipeline #20804 failed
...@@ -61,7 +61,7 @@ class SolverCeres : public SolverManager ...@@ -61,7 +61,7 @@ class SolverCeres : public SolverManager
~SolverCeres() override; ~SolverCeres() override;
ceres::Solver::Summary getSummary(); ceres::Solver::Summary getSummary() const;
std::unique_ptr<ceres::Problem>& getCeresProblem(); std::unique_ptr<ceres::Problem>& getCeresProblem();
...@@ -70,12 +70,13 @@ class SolverCeres : public SolverManager ...@@ -70,12 +70,13 @@ class SolverCeres : public SolverManager
bool computeCovariancesDerived(const std::vector<StateBlockPtr>& st_list) override; bool computeCovariancesDerived(const std::vector<StateBlockPtr>& st_list) override;
bool hasConverged() override; bool converged() const override;
bool wasStopped() override; bool failed() const override;
unsigned int iterations() override; bool wasStopped() const override;
double initialCost() override; unsigned int iterations() const override;
double finalCost() override; double initialCost() const override;
double totalTime() override; double finalCost() const override;
double totalTime() const override;
ceres::Solver::Options& getSolverOptions(); ceres::Solver::Options& getSolverOptions();
...@@ -108,10 +109,10 @@ class SolverCeres : public SolverManager ...@@ -108,10 +109,10 @@ class SolverCeres : public SolverManager
bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const override; bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const override;
bool isStateBlockFixedDerived(const StateBlockPtr& st) override; bool isStateBlockFixedDerived(const StateBlockPtr& st) const override;
bool hasThisLocalParametrizationDerived(const StateBlockPtr& st, bool hasThisLocalParametrizationDerived(const StateBlockPtr& st,
const LocalParametrizationBasePtr& local_param) override; const LocalParametrizationBasePtr& local_param) const override;
bool hasLocalParametrizationDerived(const StateBlockPtr& st) const override; bool hasLocalParametrizationDerived(const StateBlockPtr& st) const override;
...@@ -121,7 +122,7 @@ class SolverCeres : public SolverManager ...@@ -121,7 +122,7 @@ class SolverCeres : public SolverManager
ceres::Solver::Options solver_options_; ceres::Solver::Options solver_options_;
}; };
inline ceres::Solver::Summary SolverCeres::getSummary() inline ceres::Solver::Summary SolverCeres::getSummary() const
{ {
return summary_; return summary_;
} }
...@@ -147,7 +148,7 @@ inline bool SolverCeres::isStateBlockRegisteredDerived(const StateBlockPtr& stat ...@@ -147,7 +148,7 @@ inline bool SolverCeres::isStateBlockRegisteredDerived(const StateBlockPtr& stat
return ceres_problem_->HasParameterBlock(getAssociatedMemBlockPtr(state_ptr)); return ceres_problem_->HasParameterBlock(getAssociatedMemBlockPtr(state_ptr));
} }
inline bool SolverCeres::isStateBlockFixedDerived(const StateBlockPtr& st) inline bool SolverCeres::isStateBlockFixedDerived(const StateBlockPtr& st) const
{ {
if (state_blocks_.count(st) == 0) return false; if (state_blocks_.count(st) == 0) return false;
return ceres_problem_->IsParameterBlockConstant(getAssociatedMemBlockPtr(st)); return ceres_problem_->IsParameterBlockConstant(getAssociatedMemBlockPtr(st));
......
...@@ -176,12 +176,12 @@ class SolverManager ...@@ -176,12 +176,12 @@ class SolverManager
virtual bool isStateBlockFloating(const StateBlockPtr& state_ptr) const final; virtual bool isStateBlockFloating(const StateBlockPtr& state_ptr) const final;
virtual bool isStateBlockFixed(const StateBlockPtr& st) final; virtual bool isStateBlockFixed(const StateBlockPtr& st) const final;
virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const final; virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const final;
virtual bool hasThisLocalParametrization(const StateBlockPtr& st, virtual bool hasThisLocalParametrization(const StateBlockPtr& st,
const LocalParametrizationBasePtr& local_param) final; const LocalParametrizationBasePtr& local_param) const final;
virtual bool hasLocalParametrization(const StateBlockPtr& st) const final; virtual bool hasLocalParametrization(const StateBlockPtr& st) const final;
...@@ -227,23 +227,24 @@ class SolverManager ...@@ -227,23 +227,24 @@ class SolverManager
virtual void updateStateBlockStatusDerived(const StateBlockPtr& state_ptr) = 0; virtual void updateStateBlockStatusDerived(const StateBlockPtr& state_ptr) = 0;
virtual void updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) = 0; virtual void updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) = 0;
virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const = 0; virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const = 0;
virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const = 0; virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const = 0;
virtual bool isStateBlockFixedDerived(const StateBlockPtr& st) = 0; virtual bool isStateBlockFixedDerived(const StateBlockPtr& st) const = 0;
virtual bool hasLocalParametrizationDerived(const StateBlockPtr& st) const = 0; virtual bool hasLocalParametrizationDerived(const StateBlockPtr& st) const = 0;
virtual bool hasThisLocalParametrizationDerived(const StateBlockPtr& st, virtual bool hasThisLocalParametrizationDerived(const StateBlockPtr& st,
const LocalParametrizationBasePtr& local_param) = 0; const LocalParametrizationBasePtr& local_param) const = 0;
virtual void printProfilingDerived(std::ostream& stream = std::cout) const = 0; virtual void printProfilingDerived(std::ostream& stream = std::cout) const = 0;
virtual bool checkDerived(std::string prefix = "") const = 0; virtual bool checkDerived(std::string prefix = "") const = 0;
public: public:
virtual bool hasConverged() = 0; virtual bool converged() const = 0;
virtual bool wasStopped() = 0; virtual bool failed() const = 0;
virtual unsigned int iterations() = 0; virtual bool wasStopped() const = 0;
virtual double initialCost() = 0; virtual unsigned int iterations() const = 0;
virtual double finalCost() = 0; virtual double initialCost() const = 0;
virtual double totalTime() = 0; virtual double finalCost() const = 0;
virtual double totalTime() const = 0;
protected: protected:
// PARAMS // PARAMS
......
...@@ -134,7 +134,7 @@ std::string SolverCeres::solveDerived(const ReportVerbosity report_level) ...@@ -134,7 +134,7 @@ std::string SolverCeres::solveDerived(const ReportVerbosity report_level)
n_iter_max_ = std::max(n_iter_max_, iterations()); n_iter_max_ = std::max(n_iter_max_, iterations());
// convergence (profiling) // convergence (profiling)
if (hasConverged()) if (converged())
n_convergence_++; n_convergence_++;
else if (wasStopped()) else if (wasStopped())
n_interrupted_++; n_interrupted_++;
...@@ -649,33 +649,38 @@ void SolverCeres::updateStateBlockLocalParametrizationDerived(const StateBlockPt ...@@ -649,33 +649,38 @@ void SolverCeres::updateStateBlockLocalParametrizationDerived(const StateBlockPt
for (auto fac : involved_factors) addFactorDerived(fac); for (auto fac : involved_factors) addFactorDerived(fac);
} }
bool SolverCeres::hasConverged() bool SolverCeres::converged() const
{ {
return summary_.termination_type == ceres::CONVERGENCE; return summary_.termination_type == ceres::CONVERGENCE;
} }
bool SolverCeres::wasStopped() bool SolverCeres::wasStopped() const
{ {
return summary_.termination_type == ceres::USER_FAILURE or summary_.termination_type == ceres::USER_SUCCESS; return summary_.termination_type == ceres::USER_FAILURE or summary_.termination_type == ceres::USER_SUCCESS;
} }
unsigned int SolverCeres::iterations() bool SolverCeres::failed() const
{
return summary_.termination_type == ceres::USER_FAILURE or summary_.termination_type == ceres::FAILURE;
}
unsigned int SolverCeres::iterations() const
{ {
if (summary_.num_successful_steps + summary_.num_unsuccessful_steps < 1) return 0; if (summary_.num_successful_steps + summary_.num_unsuccessful_steps < 1) return 0;
return summary_.num_successful_steps + summary_.num_unsuccessful_steps; return summary_.num_successful_steps + summary_.num_unsuccessful_steps;
} }
double SolverCeres::initialCost() double SolverCeres::initialCost() const
{ {
return double(summary_.initial_cost); return double(summary_.initial_cost);
} }
double SolverCeres::finalCost() double SolverCeres::finalCost() const
{ {
return double(summary_.final_cost); return double(summary_.final_cost);
} }
double SolverCeres::totalTime() double SolverCeres::totalTime() const
{ {
return double(summary_.total_time_in_seconds); return double(summary_.total_time_in_seconds);
} }
...@@ -885,7 +890,7 @@ const Eigen::SparseMatrixd SolverCeres::computeHessian() const ...@@ -885,7 +890,7 @@ const Eigen::SparseMatrixd SolverCeres::computeHessian() const
} }
bool SolverCeres::hasThisLocalParametrizationDerived(const StateBlockPtr& st, bool SolverCeres::hasThisLocalParametrizationDerived(const StateBlockPtr& st,
const LocalParametrizationBasePtr& local_param) const LocalParametrizationBasePtr& local_param) const
{ {
return state_blocks_local_param_.count(st) == 1 && return state_blocks_local_param_.count(st) == 1 &&
state_blocks_local_param_.at(st)->getLocalParametrization() == local_param && state_blocks_local_param_.at(st)->getLocalParametrization() == local_param &&
......
...@@ -568,7 +568,7 @@ bool SolverManager::isFactorRegistered(const FactorBasePtr& fac_ptr) const ...@@ -568,7 +568,7 @@ bool SolverManager::isFactorRegistered(const FactorBasePtr& fac_ptr) const
return factors_.count(fac_ptr) == 1 and isFactorRegisteredDerived(fac_ptr); return factors_.count(fac_ptr) == 1 and isFactorRegisteredDerived(fac_ptr);
} }
bool SolverManager::isStateBlockFixed(const StateBlockPtr& st) bool SolverManager::isStateBlockFixed(const StateBlockPtr& st) const
{ {
if (!isStateBlockRegistered(st)) return false; if (!isStateBlockRegistered(st)) return false;
...@@ -578,7 +578,7 @@ bool SolverManager::isStateBlockFixed(const StateBlockPtr& st) ...@@ -578,7 +578,7 @@ bool SolverManager::isStateBlockFixed(const StateBlockPtr& st)
} }
bool SolverManager::hasThisLocalParametrization(const StateBlockPtr& st, bool SolverManager::hasThisLocalParametrization(const StateBlockPtr& st,
const LocalParametrizationBasePtr& local_param) const LocalParametrizationBasePtr& local_param) const
{ {
if (!isStateBlockRegistered(st)) return false; if (!isStateBlockRegistered(st)) return false;
......
...@@ -21,15 +21,15 @@ ...@@ -21,15 +21,15 @@
namespace wolf namespace wolf
{ {
SolverDummy::SolverDummy(const ProblemPtr& wolf_problem, const YAML::Node params) SolverDummy::SolverDummy(const ProblemPtr& wolf_problem, const YAML::Node params)
: SolverManager(wolf_problem, params){}; : SolverManager(wolf_problem, params) {};
bool SolverDummy::isStateBlockFixedDerived(const StateBlockPtr& st) bool SolverDummy::isStateBlockFixedDerived(const StateBlockPtr& st) const
{ {
return state_block_fixed_.at(st); return state_block_fixed_.at(st);
}; };
bool SolverDummy::hasThisLocalParametrizationDerived(const StateBlockPtr& st, bool SolverDummy::hasThisLocalParametrizationDerived(const StateBlockPtr& st,
const LocalParametrizationBasePtr& local_param) const LocalParametrizationBasePtr& local_param) const
{ {
return state_block_local_param_.at(st) == local_param; return state_block_local_param_.at(st) == local_param;
}; };
...@@ -59,27 +59,31 @@ bool SolverDummy::computeCovariancesDerived(const std::vector<StateBlockPtr>& st ...@@ -59,27 +59,31 @@ bool SolverDummy::computeCovariancesDerived(const std::vector<StateBlockPtr>& st
}; };
// The following are dummy implementations // The following are dummy implementations
bool SolverDummy::hasConverged() bool SolverDummy::converged() const
{ {
return true; return true;
} }
bool SolverDummy::wasStopped() bool SolverDummy::failed() const
{ {
return false; return false;
} }
unsigned int SolverDummy::iterations() bool SolverDummy::wasStopped() const
{
return false;
}
unsigned int SolverDummy::iterations() const
{ {
return 1; return 1;
} }
double SolverDummy::initialCost() double SolverDummy::initialCost() const
{ {
return double(1); return double(1);
} }
double SolverDummy::finalCost() double SolverDummy::finalCost() const
{ {
return double(0); return double(0);
} }
double SolverDummy::totalTime() double SolverDummy::totalTime() const
{ {
return double(0); return double(0);
} }
......
...@@ -34,10 +34,10 @@ class SolverDummy : public SolverManager ...@@ -34,10 +34,10 @@ class SolverDummy : public SolverManager
SolverDummy(const ProblemPtr& wolf_problem, const YAML::Node params); SolverDummy(const ProblemPtr& wolf_problem, const YAML::Node params);
WOLF_SOLVER_CREATE(SolverDummy); WOLF_SOLVER_CREATE(SolverDummy);
bool isStateBlockFixedDerived(const StateBlockPtr& st) override; bool isStateBlockFixedDerived(const StateBlockPtr& st) const override;
bool hasThisLocalParametrizationDerived(const StateBlockPtr& st, bool hasThisLocalParametrizationDerived(const StateBlockPtr& st,
const LocalParametrizationBasePtr& local_param) override; const LocalParametrizationBasePtr& local_param) const override;
bool hasLocalParametrizationDerived(const StateBlockPtr& st) const override; bool hasLocalParametrizationDerived(const StateBlockPtr& st) const override;
...@@ -49,12 +49,13 @@ class SolverDummy : public SolverManager ...@@ -49,12 +49,13 @@ class SolverDummy : public SolverManager
bool computeCovariancesDerived(const std::vector<StateBlockPtr>& st_list) override; bool computeCovariancesDerived(const std::vector<StateBlockPtr>& st_list) override;
// The following are dummy implementations // The following are dummy implementations
bool hasConverged() override; bool converged() const override;
bool wasStopped() override; bool failed() const override;
unsigned int iterations() override; bool wasStopped() const override;
double initialCost() override; unsigned int iterations() const override;
double finalCost() override; double initialCost() const override;
double totalTime() override; double finalCost() const override;
double totalTime() const override;
void printProfilingDerived(std::ostream& _stream) const override; void printProfilingDerived(std::ostream& _stream) const override;
protected: protected:
......
...@@ -139,7 +139,7 @@ class FactorVelocityLocalDirection3dTest : public testing::Test ...@@ -139,7 +139,7 @@ class FactorVelocityLocalDirection3dTest : public testing::Test
fac->getFeature()->remove(); fac->getFeature()->remove();
// Update performaces // Update performaces
convergence.push_back(solver->hasConverged() ? 1 : 0); convergence.push_back(solver->converged() ? 1 : 0);
iterations.push_back(solver->iterations()); iterations.push_back(solver->iterations());
times.push_back(solver->totalTime()); times.push_back(solver->totalTime());
error.push_back(acos(cos_angle_local)); error.push_back(acos(cos_angle_local));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment