Skip to content
Snippets Groups Projects
Commit 81b02518 authored by Joan Vallvé Navarro's avatar Joan Vallvé Navarro
Browse files

added 'floating_state_blocks' in SolverManager

parent dae47f89
No related branches found
No related tags found
1 merge request!391Resolve "SolverManager posponing "floating" state blocks"
Pipeline #5892 passed
......@@ -33,3 +33,4 @@ src/examples/map_apriltag_save.yaml
build_release/
.clangd
wolfcore.found
/wolf.found
......@@ -138,7 +138,7 @@ class SolverCeres : public SolverManager
bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const override;
bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) override;
bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const override;
};
inline ceres::Solver::Summary SolverCeres::getSummary()
......@@ -162,7 +162,7 @@ inline bool SolverCeres::isFactorRegisteredDerived(const FactorBasePtr& fac_ptr)
&& fac_2_costfunction_.find(fac_ptr) != fac_2_costfunction_.end();
}
inline bool SolverCeres::isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr)
inline bool SolverCeres::isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const
{
return state_blocks_local_param_.find(state_ptr) != state_blocks_local_param_.end()
&& ceres_problem_->HasParameterBlock(getAssociatedMemBlockPtr(state_ptr));
......
......@@ -119,17 +119,18 @@ class SolverManager
ReportVerbosity getVerbosity() const;
virtual bool isStateBlockRegistered(const StateBlockPtr& state_ptr);
virtual bool isStateBlockRegistered(const StateBlockPtr& state_ptr) const final;
virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const;
virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const final;
bool check(std::string prefix="") const;
virtual bool check(std::string prefix="") const final;
protected:
std::map<StateBlockPtr, Eigen::VectorXd> state_blocks_;
std::map<StateBlockPtr, FactorBasePtrList> state_blocks_2_factors_;
std::set<FactorBasePtr> factors_;
std::set<StateBlockPtr> floating_state_blocks_;
virtual Eigen::VectorXd& getAssociatedMemBlock(const StateBlockPtr& state_ptr);
const double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const;
......@@ -137,13 +138,13 @@ class SolverManager
private:
// SolverManager functions
void addFactor(const FactorBasePtr& fac_ptr);
void removeFactor(const FactorBasePtr& fac_ptr);
void addStateBlock(const StateBlockPtr& state_ptr);
void removeStateBlock(const StateBlockPtr& state_ptr);
void updateStateBlockState(const StateBlockPtr& state_ptr);
void updateStateBlockStatus(const StateBlockPtr& state_ptr);
void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr);
virtual void addFactor(const FactorBasePtr& fac_ptr) final;
virtual void removeFactor(const FactorBasePtr& fac_ptr) final;
virtual void addStateBlock(const StateBlockPtr& state_ptr) final;
virtual void removeStateBlock(const StateBlockPtr& state_ptr) final;
virtual void updateStateBlockState(const StateBlockPtr& state_ptr) final;
virtual void updateStateBlockStatus(const StateBlockPtr& state_ptr) final;
virtual void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr) final;
protected:
// Derived virtual functions
......@@ -154,7 +155,7 @@ class SolverManager
virtual void removeStateBlockDerived(const StateBlockPtr& state_ptr) = 0;
virtual void updateStateBlockStatusDerived(const StateBlockPtr& state_ptr) = 0;
virtual void updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) = 0;
virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) = 0;
virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const = 0;
virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const = 0;
virtual bool checkDerived(std::string prefix="") const = 0;
};
......
......@@ -56,12 +56,24 @@ void SolverManager::update()
// remove
else
removeStateBlock(sb_notification_map.begin()->first);
{
if (floating_state_blocks_.count(sb_notification_map.begin()->first) == 1)
floating_state_blocks_.erase(sb_notification_map.begin()->first);
else
removeStateBlock(sb_notification_map.begin()->first);
}
// remove notification
sb_notification_map.erase(sb_notification_map.begin());
}
// ADD "floating" STATE BLOCKS (last update they weren't involved in any factor)
while (!floating_state_blocks_.empty())
{
addStateBlock(*floating_state_blocks_.begin());
floating_state_blocks_.erase(floating_state_blocks_.begin());
}
// ADD FACTORS
while (!fac_notification_map.empty())
{
......@@ -79,6 +91,14 @@ void SolverManager::update()
{
auto state_ptr = state_pair.first;
// Check for "floating" state blocks (estimated but not involved in any factor -> not observable problem)
if (state_blocks_2_factors_.at(state_ptr).empty())
{
WOLF_INFO("SolverManager::update(): 'Floating' StateBlock ", state_ptr, " (not involved in any factor) Storing it apart.");
floating_state_blocks_.insert(state_ptr);
continue;
}
// state update
if (state_ptr->stateUpdated())
updateStateBlockState(state_ptr);
......@@ -92,6 +112,16 @@ void SolverManager::update()
updateStateBlockLocalParametrization(state_ptr);
}
// REMOVE "floating" STATE BLOCKS (will be added next update() call)
for (auto state_ptr : floating_state_blocks_)
{
removeStateBlock(state_ptr);
// reset flags meaning "solver will handle this change" (state, fix and local param will be set in addStateBlock)
state_ptr->resetStateUpdated();
state_ptr->resetFixUpdated();
state_ptr->resetLocalParamUpdated();
}
#ifdef _WOLF_DEBUG
assert(check("after update()"));
#endif
......@@ -321,9 +351,9 @@ SolverManager::ReportVerbosity SolverManager::getVerbosity() const
return params_->verbose;
}
bool SolverManager::isStateBlockRegistered(const StateBlockPtr& state_ptr)
bool SolverManager::isStateBlockRegistered(const StateBlockPtr& state_ptr) const
{
return state_blocks_.count(state_ptr) ==1 && isStateBlockRegisteredDerived(state_ptr);
return floating_state_blocks_.count(state_ptr) == 1 or (state_blocks_.count(state_ptr) == 1 and isStateBlockRegisteredDerived(state_ptr));
}
bool SolverManager::isFactorRegistered(const FactorBasePtr& fac_ptr) const
......@@ -357,15 +387,32 @@ bool SolverManager::check(std::string prefix) const
ok = false;
}
// factor involving state block in factors_
for (auto fac : sb_fac_it->second)
// no factors involving state block
if (sb_fac_it->second.empty())
{
if (factors_.count(fac) == 0)
WOLF_ERROR("SolverManager::check: state block ", sb_fac_it->first, " is in state_blocks_ but not involved in any factor - in ", prefix);
ok = false;
}
else
{
// factor involving state block in factors_
for (auto fac : sb_fac_it->second)
{
WOLF_ERROR("SolverManager::check: factor ", fac->id(), " (involved in sb ", sb_fac_it->first, ") missing in factors_ map - in ", prefix);
ok = false;
if (factors_.count(fac) == 0)
{
WOLF_ERROR("SolverManager::check: factor ", fac->id(), " (involved in sb ", sb_fac_it->first, ") missing in factors_ map - in ", prefix);
ok = false;
}
}
}
// can't be in both state_blocks_ and floating_state_blocks_
if (floating_state_blocks_.count(sb_fac_it->first) == 1)
{
WOLF_ERROR("SolverManager::check: state block ", sb_fac_it->first, " is both in state_blocks_ and floating_state_blocks_ - in ", prefix);
ok = false;
}
sb_vec_it++;
sb_fac_it++;
}
......
......@@ -17,7 +17,7 @@ WOLF_PTR_TYPEDEFS(SolverManagerDummy);
class SolverManagerDummy : public SolverManager
{
public:
std::list<FactorBasePtr> factors_;
std::set<FactorBasePtr> factors_;
std::map<StateBlockPtr,bool> state_block_fixed_;
std::map<StateBlockPtr,LocalParametrizationBasePtr> state_block_local_param_;
......@@ -26,35 +26,33 @@ class SolverManagerDummy : public SolverManager
{
};
bool isStateBlockRegistered(const StateBlockPtr& st) override
{
return state_blocks_.find(st)!=state_blocks_.end();
};
bool isStateBlockFixed(const StateBlockPtr& st) const
{
return state_block_fixed_.at(st);
};
bool isFactorRegistered(const FactorBasePtr& fac_ptr) const override
{
return std::find(factors_.begin(), factors_.end(), fac_ptr) != factors_.end();
if (floating_state_blocks_.count(st))
return st->isFixed();
else
return state_block_fixed_.at(st);
};
bool hasThisLocalParametrization(const StateBlockPtr& st, const LocalParametrizationBasePtr& local_param) const
{
return state_block_local_param_.find(st) != state_block_local_param_.end() && state_block_local_param_.at(st) == local_param;
if (floating_state_blocks_.count(st))
return st->getLocalParametrization() == local_param;
else
return state_block_local_param_.count(st) == 1 and state_block_local_param_.at(st) == local_param;
};
bool hasLocalParametrization(const StateBlockPtr& st) const
{
return state_block_local_param_.find(st) != state_block_local_param_.end();
if (floating_state_blocks_.count(st))
return st->hasLocalParametrization();
else
return state_block_local_param_.count(st) == 1;
};
void computeCovariances(const CovarianceBlocksToBeComputed blocks) override {};
void computeCovariances(const std::vector<StateBlockPtr>& st_list) override {};
bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) override {return true;};
bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const override {return true;};
// The following are dummy implementations
bool hasConverged() override { return true; }
......@@ -70,11 +68,11 @@ class SolverManagerDummy : public SolverManager
std::string solveDerived(const ReportVerbosity report_level) override { return std::string("");};
void addFactorDerived(const FactorBasePtr& fac_ptr) override
{
factors_.push_back(fac_ptr);
factors_.insert(fac_ptr);
};
void removeFactorDerived(const FactorBasePtr& fac_ptr) override
{
factors_.remove(fac_ptr);
factors_.erase(fac_ptr);
};
void addStateBlockDerived(const StateBlockPtr& state_ptr) override
{
......@@ -97,6 +95,15 @@ class SolverManagerDummy : public SolverManager
else
state_block_local_param_[state_ptr] = state_ptr->getLocalParametrization();
};
bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) const override
{
return state_block_fixed_.count(state_ptr) == 1 and state_block_local_param_.count(state_ptr) == 1;
};
bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const override
{
return factors_.count(fac_ptr) == 1;
};
};
}
......
......@@ -36,19 +36,12 @@ class SolverCeresWrapper : public SolverCeres
{
};
bool isStateBlockRegisteredSolverCeres(const StateBlockPtr& st)
{
return ceres_problem_->HasParameterBlock(SolverManager::getAssociatedMemBlockPtr(st));
};
bool isStateBlockRegisteredSolverManager(const StateBlockPtr& st)
{
return state_blocks_.find(st)!=state_blocks_.end();
};
bool isStateBlockFixed(const StateBlockPtr& st)
{
return ceres_problem_->IsParameterBlockConstant(SolverManager::getAssociatedMemBlockPtr(st));
if (floating_state_blocks_.count(st))
return st->isFixed();
else
return ceres_problem_->IsParameterBlockConstant(SolverManager::getAssociatedMemBlockPtr(st));
};
int numStateBlocks()
......@@ -61,21 +54,22 @@ class SolverCeresWrapper : public SolverCeres
return ceres_problem_->NumResidualBlocks();
};
bool isFactorRegistered(const FactorBasePtr& fac_ptr) const override
{
return fac_2_residual_idx_.find(fac_ptr) != fac_2_residual_idx_.end() && fac_2_costfunction_.find(fac_ptr) != fac_2_costfunction_.end();
};
bool hasThisLocalParametrization(const StateBlockPtr& st, const LocalParametrizationBasePtr& local_param)
{
return state_blocks_local_param_.find(st) != state_blocks_local_param_.end() &&
state_blocks_local_param_.at(st)->getLocalParametrization() == local_param &&
ceres_problem_->GetParameterization(getAssociatedMemBlockPtr(st)) == state_blocks_local_param_.at(st).get();
if (floating_state_blocks_.count(st))
return st->getLocalParametrization() == local_param;
else
return state_blocks_local_param_.count(st) == 1 and
state_blocks_local_param_.at(st)->getLocalParametrization() == local_param and
ceres_problem_->GetParameterization(getAssociatedMemBlockPtr(st)) == state_blocks_local_param_.at(st).get();
};
bool hasLocalParametrization(const StateBlockPtr& st) const
{
return state_blocks_local_param_.find(st) != state_blocks_local_param_.end();
if (floating_state_blocks_.count(st))
return st->hasLocalParametrization();
else
return state_blocks_local_param_.count(st) == 1;
};
};
......@@ -108,8 +102,7 @@ TEST(SolverCeres, AddStateBlock)
solver_ceres->update();
// check stateblock
ASSERT_TRUE(solver_ceres->isStateBlockRegisteredSolverManager(sb_ptr));
ASSERT_TRUE(solver_ceres->isStateBlockRegisteredSolverCeres(sb_ptr));
ASSERT_TRUE(solver_ceres->isStateBlockRegistered(sb_ptr));
// run ceres manager check
ASSERT_TRUE(solver_ceres->check());
......@@ -137,8 +130,7 @@ TEST(SolverCeres, DoubleAddStateBlock)
solver_ceres->update();
// check stateblock
ASSERT_TRUE(solver_ceres->isStateBlockRegisteredSolverManager(sb_ptr));
ASSERT_TRUE(solver_ceres->isStateBlockRegisteredSolverCeres(sb_ptr));
ASSERT_TRUE(solver_ceres->isStateBlockRegistered(sb_ptr));
// run ceres manager check
ASSERT_TRUE(solver_ceres->check());
......@@ -194,8 +186,7 @@ TEST(SolverCeres, AddUpdateStateBlock)
solver_ceres->update();
// check stateblock fixed
ASSERT_TRUE(solver_ceres->isStateBlockRegisteredSolverManager(sb_ptr));
ASSERT_TRUE(solver_ceres->isStateBlockRegisteredSolverCeres(sb_ptr));
ASSERT_TRUE(solver_ceres->isStateBlockRegistered(sb_ptr));
ASSERT_TRUE(solver_ceres->isStateBlockFixed(sb_ptr));
// run ceres manager check
......@@ -217,8 +208,7 @@ TEST(SolverCeres, RemoveStateBlock)
// update solver
solver_ceres->update();
ASSERT_TRUE(solver_ceres->isStateBlockRegisteredSolverManager(sb_ptr));
ASSERT_TRUE(solver_ceres->isStateBlockRegisteredSolverCeres(sb_ptr));
ASSERT_TRUE(solver_ceres->isStateBlockRegistered(sb_ptr));
// remove state_block
P->notifyStateBlock(sb_ptr,REMOVE);
......@@ -227,7 +217,7 @@ TEST(SolverCeres, RemoveStateBlock)
solver_ceres->update();
// check stateblock
ASSERT_FALSE(solver_ceres->isStateBlockRegisteredSolverManager(sb_ptr));
ASSERT_FALSE(solver_ceres->isStateBlockRegistered(sb_ptr));
ASSERT_EQ(solver_ceres->numStateBlocks(), 0);
// run ceres manager check
......@@ -253,7 +243,7 @@ TEST(SolverCeres, AddRemoveStateBlock)
solver_ceres->update();
// check no stateblocks
ASSERT_FALSE(solver_ceres->isStateBlockRegisteredSolverManager(sb_ptr));
ASSERT_FALSE(solver_ceres->isStateBlockRegistered(sb_ptr));
ASSERT_EQ(solver_ceres->numStateBlocks(), 0);
// run ceres manager check
......
......@@ -48,7 +48,7 @@ TEST(SolverManager, AddStateBlock)
solver_manager_ptr->update();
// check stateblock
ASSERT_TRUE(solver_manager_ptr->isStateBlockRegistered(sb_ptr));
EXPECT_TRUE(solver_manager_ptr->isStateBlockRegistered(sb_ptr));
}
TEST(SolverManager, DoubleAddStateBlock)
......@@ -73,7 +73,7 @@ TEST(SolverManager, DoubleAddStateBlock)
solver_manager_ptr->update();
// check stateblock
ASSERT_TRUE(solver_manager_ptr->isStateBlockRegistered(sb_ptr));
EXPECT_TRUE(solver_manager_ptr->isStateBlockRegistered(sb_ptr));
}
TEST(SolverManager, UpdateStateBlock)
......@@ -551,7 +551,7 @@ TEST(SolverManager, DoubleRemoveFactor)
ASSERT_FALSE(solver_manager_ptr->isFactorRegistered(c));
}
TEST(SolverManager, MultiThreadingTruncatedNotifications)
/*TEST(SolverManager, MultiThreadingTruncatedNotifications)
{
double Dt = 5.0;
ProblemPtr P = Problem::create("PO", 2);
......@@ -588,7 +588,7 @@ TEST(SolverManager, MultiThreadingTruncatedNotifications)
}
t.join();
}
}*/
int main(int argc, char **argv)
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment