Skip to content
Snippets Groups Projects
Commit b2e6175f authored by Joan Vallvé Navarro's avatar Joan Vallvé Navarro
Browse files

WIP

parent e7f4acd0
No related branches found
No related tags found
1 merge request!391Resolve "SolverManager posponing "floating" state blocks"
Pipeline #5894 failed
......@@ -116,6 +116,10 @@ class SolverCeres : public SolverManager
const Eigen::SparseMatrixd computeHessian() const;
virtual int numStateBlocksDerived() const override;
virtual int numFactorsDerived() const override;
protected:
bool checkDerived(std::string prefix="") const override;
......@@ -139,6 +143,13 @@ class SolverCeres : public SolverManager
bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const override;
bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const override;
bool isStateBlockFixedDerived(const StateBlockPtr& st) override;
bool hasThisLocalParametrizationDerived(const StateBlockPtr& st,
const LocalParametrizationBasePtr& local_param) override;
bool hasLocalParametrizationDerived(const StateBlockPtr& st) const override;
};
inline ceres::Solver::Summary SolverCeres::getSummary()
......@@ -168,6 +179,34 @@ inline bool SolverCeres::isStateBlockRegisteredDerived(const StateBlockPtr& stat
&& ceres_problem_->HasParameterBlock(getAssociatedMemBlockPtr(state_ptr));
}
inline bool SolverCeres::isStateBlockFixedDerived(const StateBlockPtr& st)
{
return ceres_problem_->IsParameterBlockConstant(SolverManager::getAssociatedMemBlockPtr(st));
};
inline bool SolverCeres::hasThisLocalParametrizationDerived(const StateBlockPtr& st,
const LocalParametrizationBasePtr& local_param)
{
return state_blocks_local_param_.count(st) == 1 and
state_blocks_local_param_.at(st)->getLocalParametrization() == local_param and
ceres_problem_->GetParameterization(getAssociatedMemBlockPtr(st)) == state_blocks_local_param_.at(st).get();
};
inline bool SolverCeres::hasLocalParametrizationDerived(const StateBlockPtr& st) const
{
return state_blocks_local_param_.count(st) == 1;
};
inline int SolverCeres::numStateBlocksDerived() const
{
return ceres_problem_->NumParameterBlocks();
}
inline int SolverCeres::numFactorsDerived() const
{
return ceres_problem_->NumResidualBlocks();
};
} // namespace wolf
#endif
......@@ -121,7 +121,22 @@ class SolverManager
virtual bool isStateBlockRegistered(const StateBlockPtr& state_ptr) const final;
virtual bool isStateBlockFloating(const StateBlockPtr& state_ptr) const final;
virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const final;
virtual bool isStateBlockFixed(const StateBlockPtr& st) final;
virtual bool hasThisLocalParametrization(const StateBlockPtr& st,
const LocalParametrizationBasePtr& local_param) final;
virtual bool hasLocalParametrization(const StateBlockPtr& st) const final;
virtual int numFactors() const final;
virtual int numStateBlocks() const final;
virtual int numStateBlocksFloating() const final;
virtual int numFactorsDerived() const = 0;
virtual int numStateBlocksDerived() const = 0;
virtual bool check(std::string prefix="") const final;
......@@ -134,7 +149,7 @@ class SolverManager
virtual Eigen::VectorXd& getAssociatedMemBlock(const StateBlockPtr& state_ptr);
const double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const;
virtual double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr);
double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr);
private:
// SolverManager functions
......@@ -158,6 +173,10 @@ class SolverManager
virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const = 0;
virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const = 0;
virtual bool checkDerived(std::string prefix="") const = 0;
virtual bool isStateBlockFixedDerived(const StateBlockPtr& st) = 0;
virtual bool hasThisLocalParametrizationDerived(const StateBlockPtr& st,
const LocalParametrizationBasePtr& local_param) = 0;
virtual bool hasLocalParametrizationDerived(const StateBlockPtr& st) const = 0;
};
// Params (here becaure it needs of declaration of SolverManager::ReportVerbosity)
......
......@@ -353,7 +353,12 @@ SolverManager::ReportVerbosity SolverManager::getVerbosity() const
bool SolverManager::isStateBlockRegistered(const StateBlockPtr& state_ptr) const
{
return floating_state_blocks_.count(state_ptr) == 1 or (state_blocks_.count(state_ptr) == 1 and isStateBlockRegisteredDerived(state_ptr));
return isStateBlockFloating(state_ptr) or (state_blocks_.count(state_ptr) == 1 and isStateBlockRegisteredDerived(state_ptr));
}
bool SolverManager::isStateBlockFloating(const StateBlockPtr& state_ptr) const
{
return floating_state_blocks_.count(state_ptr) == 1;
}
bool SolverManager::isFactorRegistered(const FactorBasePtr& fac_ptr) const
......@@ -361,6 +366,45 @@ bool SolverManager::isFactorRegistered(const FactorBasePtr& fac_ptr) const
return factors_.count(fac_ptr) and isFactorRegisteredDerived(fac_ptr);
}
bool SolverManager::isStateBlockFixed(const StateBlockPtr& st)
{
if (isStateBlockFloating(st))
return st->isFixed();
else
return isStateBlockFixedDerived(st);
}
bool SolverManager::hasThisLocalParametrization(const StateBlockPtr& st, const LocalParametrizationBasePtr& local_param)
{
if (isStateBlockFloating(st))
return st->getLocalParametrization() == local_param;
else
return hasThisLocalParametrizationDerived(st, local_param);
};
bool SolverManager::hasLocalParametrization(const StateBlockPtr& st) const
{
if (isStateBlockFloating(st))
return st->hasLocalParametrization();
else
return hasLocalParametrizationDerived(st);
};
int SolverManager::numFactors() const
{
return factors_.size();
}
int SolverManager::numStateBlocks() const
{
return state_blocks_.size() + floating_state_blocks_.size();
}
int SolverManager::numStateBlocksFloating() const
{
return floating_state_blocks_.size();
}
double SolverManager::getPeriod() const
{
return params_->period;
......@@ -431,9 +475,38 @@ bool SolverManager::check(std::string prefix) const
}
}
// checkDerived
// CHECK DERIVED ----------------------
ok &= checkDerived(prefix);
// CHECK IF DERIVED IS UP TO DATE ----------------------
// state blocks registered in derived solver
for (auto sb : state_blocks_)
if (!isStateBlockRegisteredDerived(sb.first))
{
WOLF_ERROR("SolverManager::checkAgainstDerived: state block ", sb.first, " is in state_blocks_ but not registered in derived solver - in ", prefix);
ok = false;
}
// factors registered in derived solver
for (auto fac : factors_)
if (!isFactorRegisteredDerived(fac))
{
WOLF_ERROR("SolverManager::checkAgainstDerived: factor ", fac->id(), " is in factors_ but not registered in derived solver - in ", prefix);
ok = false;
}
// numbers
if (numStateBlocksDerived() != state_blocks_.size())
{
WOLF_ERROR("SolverManager::checkAgainstDerived: numStateBlocksDerived() = ", numStateBlocksDerived(), " DIFFERENT THAN state_blocks_.size() = ", state_blocks_.size(), " - in ", prefix);
ok = false;
}
if (numFactorsDerived() != numFactors())
{
WOLF_ERROR("SolverManager::checkAgainstDerived: numFactorsDerived() = ", numFactorsDerived(), " DIFFERENT THAN numFactors() = ", numFactors(), " - in ", prefix);
ok = false;
}
return ok;
}
......
......@@ -17,7 +17,7 @@ WOLF_PTR_TYPEDEFS(SolverManagerDummy);
class SolverManagerDummy : public SolverManager
{
public:
std::set<FactorBasePtr> factors_;
std::set<FactorBasePtr> factors_derived_;
std::map<StateBlockPtr,bool> state_block_fixed_;
std::map<StateBlockPtr,LocalParametrizationBasePtr> state_block_local_param_;
......@@ -26,28 +26,29 @@ class SolverManagerDummy : public SolverManager
{
};
bool isStateBlockFixed(const StateBlockPtr& st) const
bool isStateBlockFixedDerived(const StateBlockPtr& st) override
{
if (floating_state_blocks_.count(st))
return st->isFixed();
else
return state_block_fixed_.at(st);
return state_block_fixed_.at(st);
};
bool hasThisLocalParametrization(const StateBlockPtr& st, const LocalParametrizationBasePtr& local_param) const
bool hasThisLocalParametrizationDerived(const StateBlockPtr& st, const LocalParametrizationBasePtr& local_param) override
{
if (floating_state_blocks_.count(st))
return st->getLocalParametrization() == local_param;
else
return state_block_local_param_.count(st) == 1 and state_block_local_param_.at(st) == local_param;
return state_block_local_param_.count(st) == 1 and state_block_local_param_.at(st) == local_param;
};
bool hasLocalParametrization(const StateBlockPtr& st) const
bool hasLocalParametrizationDerived(const StateBlockPtr& st) const override
{
if (floating_state_blocks_.count(st))
return st->hasLocalParametrization();
else
return state_block_local_param_.count(st) == 1;
return state_block_local_param_.count(st) == 1;
};
virtual int numStateBlocksDerived() const override
{
return state_block_fixed_.size();
}
virtual int numFactorsDerived() const override
{
return factors_derived_.size();
};
void computeCovariances(const CovarianceBlocksToBeComputed blocks) override {};
......@@ -68,11 +69,11 @@ class SolverManagerDummy : public SolverManager
std::string solveDerived(const ReportVerbosity report_level) override { return std::string("");};
void addFactorDerived(const FactorBasePtr& fac_ptr) override
{
factors_.insert(fac_ptr);
factors_derived_.insert(fac_ptr);
};
void removeFactorDerived(const FactorBasePtr& fac_ptr) override
{
factors_.erase(fac_ptr);
factors_derived_.erase(fac_ptr);
};
void addStateBlockDerived(const StateBlockPtr& state_ptr) override
{
......@@ -102,7 +103,7 @@ class SolverManagerDummy : public SolverManager
bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const override
{
return factors_.count(fac_ptr) == 1;
return factors_derived_.count(fac_ptr) == 1;
};
};
......
......@@ -36,14 +36,6 @@ class SolverCeresWrapper : public SolverCeres
{
};
bool isStateBlockFixed(const StateBlockPtr& st)
{
if (floating_state_blocks_.count(st))
return st->isFixed();
else
return ceres_problem_->IsParameterBlockConstant(SolverManager::getAssociatedMemBlockPtr(st));
};
int numStateBlocks()
{
return ceres_problem_->NumParameterBlocks();
......@@ -54,24 +46,6 @@ class SolverCeresWrapper : public SolverCeres
return ceres_problem_->NumResidualBlocks();
};
bool hasThisLocalParametrization(const StateBlockPtr& st, const LocalParametrizationBasePtr& local_param)
{
if (floating_state_blocks_.count(st))
return st->getLocalParametrization() == local_param;
else
return state_blocks_local_param_.count(st) == 1 and
state_blocks_local_param_.at(st)->getLocalParametrization() == local_param and
ceres_problem_->GetParameterization(getAssociatedMemBlockPtr(st)) == state_blocks_local_param_.at(st).get();
};
bool hasLocalParametrization(const StateBlockPtr& st) const
{
if (floating_state_blocks_.count(st))
return st->hasLocalParametrization();
else
return state_blocks_local_param_.count(st) == 1;
};
};
TEST(SolverCeres, Create)
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment