Skip to content
Snippets Groups Projects
Commit ee9f892a authored by Joan Vallvé Navarro's avatar Joan Vallvé Navarro
Browse files

improved tests, all ok

parent b2e6175f
No related branches found
No related tags found
1 merge request!391Resolve "SolverManager posponing "floating" state blocks"
Pipeline #5895 passed
...@@ -169,14 +169,13 @@ inline ceres::Solver::Options& SolverCeres::getSolverOptions() ...@@ -169,14 +169,13 @@ inline ceres::Solver::Options& SolverCeres::getSolverOptions()
inline bool SolverCeres::isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const inline bool SolverCeres::isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const
{ {
return fac_2_residual_idx_.find(fac_ptr) != fac_2_residual_idx_.end() return fac_2_residual_idx_.find(fac_ptr) != fac_2_residual_idx_.end() and
&& fac_2_costfunction_.find(fac_ptr) != fac_2_costfunction_.end(); fac_2_costfunction_.find(fac_ptr) != fac_2_costfunction_.end();
} }
inline bool SolverCeres::isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const inline bool SolverCeres::isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const
{ {
return state_blocks_local_param_.find(state_ptr) != state_blocks_local_param_.end() return ceres_problem_->HasParameterBlock(getAssociatedMemBlockPtr(state_ptr));
&& ceres_problem_->HasParameterBlock(getAssociatedMemBlockPtr(state_ptr));
} }
inline bool SolverCeres::isStateBlockFixedDerived(const StateBlockPtr& st) inline bool SolverCeres::isStateBlockFixedDerived(const StateBlockPtr& st)
...@@ -184,14 +183,6 @@ inline bool SolverCeres::isStateBlockFixedDerived(const StateBlockPtr& st) ...@@ -184,14 +183,6 @@ inline bool SolverCeres::isStateBlockFixedDerived(const StateBlockPtr& st)
return ceres_problem_->IsParameterBlockConstant(SolverManager::getAssociatedMemBlockPtr(st)); return ceres_problem_->IsParameterBlockConstant(SolverManager::getAssociatedMemBlockPtr(st));
}; };
inline bool SolverCeres::hasThisLocalParametrizationDerived(const StateBlockPtr& st,
const LocalParametrizationBasePtr& local_param)
{
return state_blocks_local_param_.count(st) == 1 and
state_blocks_local_param_.at(st)->getLocalParametrization() == local_param and
ceres_problem_->GetParameterization(getAssociatedMemBlockPtr(st)) == state_blocks_local_param_.at(st).get();
};
inline bool SolverCeres::hasLocalParametrizationDerived(const StateBlockPtr& st) const inline bool SolverCeres::hasLocalParametrizationDerived(const StateBlockPtr& st) const
{ {
return state_blocks_local_param_.count(st) == 1; return state_blocks_local_param_.count(st) == 1;
......
...@@ -579,6 +579,14 @@ const Eigen::SparseMatrixd SolverCeres::computeHessian() const ...@@ -579,6 +579,14 @@ const Eigen::SparseMatrixd SolverCeres::computeHessian() const
return H; return H;
} }
bool SolverCeres::hasThisLocalParametrizationDerived(const StateBlockPtr& st,
const LocalParametrizationBasePtr& local_param)
{
return state_blocks_local_param_.count(st) == 1
&& state_blocks_local_param_.at(st)->getLocalParametrization() == local_param
&& ceres_problem_->GetParameterization(getAssociatedMemBlockPtr(st))
== state_blocks_local_param_.at(st).get();
}
} // namespace wolf } // namespace wolf
#include "core/solver/factory_solver.h" #include "core/solver/factory_solver.h"
......
...@@ -368,26 +368,35 @@ bool SolverManager::isFactorRegistered(const FactorBasePtr& fac_ptr) const ...@@ -368,26 +368,35 @@ bool SolverManager::isFactorRegistered(const FactorBasePtr& fac_ptr) const
bool SolverManager::isStateBlockFixed(const StateBlockPtr& st) bool SolverManager::isStateBlockFixed(const StateBlockPtr& st)
{ {
if (!isStateBlockRegistered(st))
return false;
if (isStateBlockFloating(st)) if (isStateBlockFloating(st))
return st->isFixed(); return st->isFixed();
else
return isStateBlockFixedDerived(st); return isStateBlockFixedDerived(st);
} }
bool SolverManager::hasThisLocalParametrization(const StateBlockPtr& st, const LocalParametrizationBasePtr& local_param) bool SolverManager::hasThisLocalParametrization(const StateBlockPtr& st, const LocalParametrizationBasePtr& local_param)
{ {
if (!isStateBlockRegistered(st))
return false;
if (isStateBlockFloating(st)) if (isStateBlockFloating(st))
return st->getLocalParametrization() == local_param; return st->getLocalParametrization() == local_param;
else
return hasThisLocalParametrizationDerived(st, local_param); return hasThisLocalParametrizationDerived(st, local_param);
}; };
bool SolverManager::hasLocalParametrization(const StateBlockPtr& st) const bool SolverManager::hasLocalParametrization(const StateBlockPtr& st) const
{ {
if (!isStateBlockRegistered(st))
return false;
if (isStateBlockFloating(st)) if (isStateBlockFloating(st))
return st->hasLocalParametrization(); return st->hasLocalParametrization();
else
return hasLocalParametrizationDerived(st); return hasLocalParametrizationDerived(st);
}; };
int SolverManager::numFactors() const int SolverManager::numFactors() const
...@@ -483,7 +492,7 @@ bool SolverManager::check(std::string prefix) const ...@@ -483,7 +492,7 @@ bool SolverManager::check(std::string prefix) const
for (auto sb : state_blocks_) for (auto sb : state_blocks_)
if (!isStateBlockRegisteredDerived(sb.first)) if (!isStateBlockRegisteredDerived(sb.first))
{ {
WOLF_ERROR("SolverManager::checkAgainstDerived: state block ", sb.first, " is in state_blocks_ but not registered in derived solver - in ", prefix); WOLF_ERROR("SolverManager::check: state block ", sb.first, " is in state_blocks_ but not registered in derived solver - in ", prefix);
ok = false; ok = false;
} }
...@@ -491,19 +500,19 @@ bool SolverManager::check(std::string prefix) const ...@@ -491,19 +500,19 @@ bool SolverManager::check(std::string prefix) const
for (auto fac : factors_) for (auto fac : factors_)
if (!isFactorRegisteredDerived(fac)) if (!isFactorRegisteredDerived(fac))
{ {
WOLF_ERROR("SolverManager::checkAgainstDerived: factor ", fac->id(), " is in factors_ but not registered in derived solver - in ", prefix); WOLF_ERROR("SolverManager::check: factor ", fac->id(), " is in factors_ but not registered in derived solver - in ", prefix);
ok = false; ok = false;
} }
// numbers // numbers
if (numStateBlocksDerived() != state_blocks_.size()) if (numStateBlocksDerived() != state_blocks_.size())
{ {
WOLF_ERROR("SolverManager::checkAgainstDerived: numStateBlocksDerived() = ", numStateBlocksDerived(), " DIFFERENT THAN state_blocks_.size() = ", state_blocks_.size(), " - in ", prefix); WOLF_ERROR("SolverManager::check: numStateBlocksDerived() = ", numStateBlocksDerived(), " DIFFERENT THAN state_blocks_.size() = ", state_blocks_.size(), " - in ", prefix);
ok = false; ok = false;
} }
if (numFactorsDerived() != numFactors()) if (numFactorsDerived() != numFactors())
{ {
WOLF_ERROR("SolverManager::checkAgainstDerived: numFactorsDerived() = ", numFactorsDerived(), " DIFFERENT THAN numFactors() = ", numFactors(), " - in ", prefix); WOLF_ERROR("SolverManager::check: numFactorsDerived() = ", numFactorsDerived(), " DIFFERENT THAN numFactors() = ", numFactors(), " - in ", prefix);
ok = false; ok = false;
} }
......
...@@ -33,12 +33,12 @@ class SolverManagerDummy : public SolverManager ...@@ -33,12 +33,12 @@ class SolverManagerDummy : public SolverManager
bool hasThisLocalParametrizationDerived(const StateBlockPtr& st, const LocalParametrizationBasePtr& local_param) override bool hasThisLocalParametrizationDerived(const StateBlockPtr& st, const LocalParametrizationBasePtr& local_param) override
{ {
return state_block_local_param_.count(st) == 1 and state_block_local_param_.at(st) == local_param; return state_block_local_param_.at(st) == local_param;
}; };
bool hasLocalParametrizationDerived(const StateBlockPtr& st) const override bool hasLocalParametrizationDerived(const StateBlockPtr& st) const override
{ {
return state_block_local_param_.count(st) == 1; return state_block_local_param_.at(st) != nullptr;
}; };
virtual int numStateBlocksDerived() const override virtual int numStateBlocksDerived() const override
...@@ -61,8 +61,6 @@ class SolverManagerDummy : public SolverManager ...@@ -61,8 +61,6 @@ class SolverManagerDummy : public SolverManager
double initialCost() override { return double(1); } double initialCost() override { return double(1); }
double finalCost() override { return double(0); } double finalCost() override { return double(0); }
protected: protected:
bool checkDerived(std::string prefix="") const override {return true;} bool checkDerived(std::string prefix="") const override {return true;}
...@@ -91,10 +89,7 @@ class SolverManagerDummy : public SolverManager ...@@ -91,10 +89,7 @@ class SolverManagerDummy : public SolverManager
}; };
void updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) override void updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) override
{ {
if (state_ptr->getLocalParametrization() == nullptr) state_block_local_param_[state_ptr] = state_ptr->getLocalParametrization();
state_block_local_param_.erase(state_ptr);
else
state_block_local_param_[state_ptr] = state_ptr->getLocalParametrization();
}; };
bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const override bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const override
{ {
......
This diff is collapsed.
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment