diff --git a/include/core/ceres_wrapper/ceres_manager.h b/include/core/ceres_wrapper/ceres_manager.h index a9a05eb9bb076dc68127e60b3a28a8964ebd3196..196ce659d5b7800b94bd395c4c21565963028a96 100644 --- a/include/core/ceres_wrapper/ceres_manager.h +++ b/include/core/ceres_wrapper/ceres_manager.h @@ -82,25 +82,25 @@ class CeresManager : public SolverManager ceres::Solver::Options& getSolverOptions(); - virtual bool check() override; + virtual bool check(std::string prefix="") const override; const Eigen::SparseMatrixd computeHessian() const; protected: - std::string solveImpl(const ReportVerbosity report_level) override; + std::string solveDerived(const ReportVerbosity report_level) override; - void addFactor(const FactorBasePtr& fac_ptr) override; + void addFactorDerived(const FactorBasePtr& fac_ptr) override; - void removeFactor(const FactorBasePtr& fac_ptr) override; + void removeFactorDerived(const FactorBasePtr& fac_ptr) override; - void addStateBlock(const StateBlockPtr& state_ptr) override; + void addStateBlockDerived(const StateBlockPtr& state_ptr) override; - void removeStateBlock(const StateBlockPtr& state_ptr) override; + void removeStateBlockDerived(const StateBlockPtr& state_ptr) override; - void updateStateBlockStatus(const StateBlockPtr& state_ptr) override; + void updateStateBlockStatusDerived(const StateBlockPtr& state_ptr) override; - void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr) override; + void updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) override; ceres::CostFunctionPtr createCostFunction(const FactorBasePtr& _fac_ptr); diff --git a/include/core/solver/solver_manager.h b/include/core/solver/solver_manager.h index e28c544bd1266362b453b3f9619626647d4b5dcb..2e61df70b18ec9251e1e3cedb14f386dbc2914cb 100644 --- a/include/core/solver/solver_manager.h +++ b/include/core/solver/solver_manager.h @@ -93,7 +93,7 @@ public: SolverManager(const ProblemPtr& wolf_problem); - virtual ~SolverManager() = default; + virtual ~SolverManager(); std::string solve(const ReportVerbosity report_level = ReportVerbosity::QUIET); @@ -117,38 +117,43 @@ public: virtual bool isFactorRegistered(const FactorBasePtr& fac_ptr) const; - virtual bool check() = 0; + virtual bool check(std::string prefix="") const = 0; - void assertCheck(); + void assertCheck() const; protected: - std::map<StateBlockPtr, Eigen::VectorXd> state_blocks_; + std::map<StateBlockPtr, Eigen::VectorXd> state_blocks_; + std::map<StateBlockPtr, FactorBasePtrList> state_blocks_2_factors_; - virtual Eigen::VectorXd& getAssociatedMemBlock(const StateBlockPtr& state_ptr); - //const double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const; - virtual double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr); + virtual Eigen::VectorXd& getAssociatedMemBlock(const StateBlockPtr& state_ptr); + const double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const; + virtual double* getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr); - virtual std::string solveImpl(const ReportVerbosity report_level) = 0; +private: + // SolverManager functions + void addFactor(const FactorBasePtr& fac_ptr); + void removeFactor(const FactorBasePtr& fac_ptr); + void addStateBlock(const StateBlockPtr& state_ptr); + void removeStateBlock(const StateBlockPtr& state_ptr); + void updateStateBlockState(const StateBlockPtr& state_ptr); + void updateStateBlockStatus(const StateBlockPtr& state_ptr); + void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr); - virtual void addFactor(const FactorBasePtr& fac_ptr) = 0; - - virtual void removeFactor(const FactorBasePtr& fac_ptr) = 0; - - virtual void addStateBlock(const StateBlockPtr& state_ptr) = 0; - - virtual void removeStateBlock(const StateBlockPtr& state_ptr) = 0; - - virtual void updateStateBlockStatus(const StateBlockPtr& state_ptr) = 0; - - virtual void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr) = 0; - - virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) = 0; - - virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const = 0; +protected: + // Derived virtual functions + virtual std::string solveDerived(const ReportVerbosity report_level) = 0; + virtual void addFactorDerived(const FactorBasePtr& fac_ptr) = 0; + virtual void removeFactorDerived(const FactorBasePtr& fac_ptr) = 0; + virtual void addStateBlockDerived(const StateBlockPtr& state_ptr) = 0; + virtual void removeStateBlockDerived(const StateBlockPtr& state_ptr) = 0; + virtual void updateStateBlockStatusDerived(const StateBlockPtr& state_ptr) = 0; + virtual void updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) = 0; + virtual bool isStateBlockRegisteredDerived(const StateBlockPtr& state_ptr) = 0; + virtual bool isFactorRegisteredDerived(const FactorBasePtr& fac_ptr) const = 0; }; -inline void SolverManager::assertCheck() +inline void SolverManager::assertCheck() const { #ifdef _WOLF_DEBUG assert(check()); diff --git a/src/ceres_wrapper/ceres_manager.cpp b/src/ceres_wrapper/ceres_manager.cpp index aad86ac2950a77fbe4067a200a419be9d553eaec..7092f97a3ba8c3fd945cc493bc23dda7784522de 100644 --- a/src/ceres_wrapper/ceres_manager.cpp +++ b/src/ceres_wrapper/ceres_manager.cpp @@ -36,8 +36,6 @@ CeresManager::CeresManager(const ProblemPtr& _wolf_problem, CeresManager::~CeresManager() { - while (!fac_2_residual_idx_.empty()) - removeFactor(fac_2_residual_idx_.begin()->first); } SolverManagerPtr CeresManager::create(const ProblemPtr &_wolf_problem, const ParamsServer& _server) @@ -48,7 +46,7 @@ CeresManager::~CeresManager() return std::make_shared<CeresManager>(_wolf_problem, opt); } -std::string CeresManager::solveImpl(const ReportVerbosity report_level) +std::string CeresManager::solveDerived(const ReportVerbosity report_level) { // run Ceres Solver ceres::Solve(ceres_options_, ceres_problem_.get(), &summary_); @@ -254,7 +252,7 @@ void CeresManager::computeCovariances(const std::vector<StateBlockPtr>& st_list) std::cout << "WARNING: Couldn't compute covariances!" << std::endl; } -void CeresManager::addFactor(const FactorBasePtr& fac_ptr) +void CeresManager::addFactorDerived(const FactorBasePtr& fac_ptr) { assert(fac_2_costfunction_.find(fac_ptr) == fac_2_costfunction_.end() && "adding a factor that is already in the fac_2_costfunction_ map"); @@ -282,7 +280,7 @@ void CeresManager::addFactor(const FactorBasePtr& fac_ptr) assert((unsigned int)(ceres_problem_->NumResidualBlocks()) == fac_2_residual_idx_.size() && "ceres residuals different from wrapper residuals"); } -void CeresManager::removeFactor(const FactorBasePtr& _fac_ptr) +void CeresManager::removeFactorDerived(const FactorBasePtr& _fac_ptr) { assert(fac_2_residual_idx_.find(_fac_ptr) != fac_2_residual_idx_.end() && "removing a factor that is not in the fac_2_residual map"); @@ -293,7 +291,7 @@ void CeresManager::removeFactor(const FactorBasePtr& _fac_ptr) assert((unsigned int)(ceres_problem_->NumResidualBlocks()) == fac_2_residual_idx_.size() && "ceres residuals different from wrapper residuals"); } -void CeresManager::addStateBlock(const StateBlockPtr& state_ptr) +void CeresManager::addStateBlockDerived(const StateBlockPtr& state_ptr) { ceres::LocalParameterization* local_parametrization_ptr = nullptr; @@ -313,10 +311,10 @@ void CeresManager::addStateBlock(const StateBlockPtr& state_ptr) if (state_ptr->isFixed()) ceres_problem_->SetParameterBlockConstant(getAssociatedMemBlockPtr(state_ptr)); - updateStateBlockStatus(state_ptr); + updateStateBlockStatusDerived(state_ptr); } -void CeresManager::removeStateBlock(const StateBlockPtr& state_ptr) +void CeresManager::removeStateBlockDerived(const StateBlockPtr& state_ptr) { //std::cout << "CeresManager::removeStateBlock " << state_ptr.get() << " - " << getAssociatedMemBlockPtr(state_ptr) << std::endl; assert(state_ptr); @@ -324,7 +322,7 @@ void CeresManager::removeStateBlock(const StateBlockPtr& state_ptr) state_blocks_local_param_.erase(state_ptr); } -void CeresManager::updateStateBlockStatus(const StateBlockPtr& state_ptr) +void CeresManager::updateStateBlockStatusDerived(const StateBlockPtr& state_ptr) { assert(state_ptr != nullptr); if (state_ptr->isFixed()) @@ -333,7 +331,7 @@ void CeresManager::updateStateBlockStatus(const StateBlockPtr& state_ptr) ceres_problem_->SetParameterBlockVariable(getAssociatedMemBlockPtr(state_ptr)); } -void CeresManager::updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr) +void CeresManager::updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) { assert(state_ptr != nullptr); @@ -356,17 +354,17 @@ void CeresManager::updateStateBlockLocalParametrization(const StateBlockPtr& sta // Remove all involved factors (it does not remove any parameter block) for (auto fac : involved_factors) - removeFactor(fac); + removeFactorDerived(fac); // Remove state block (it removes all involved residual blocks but they just were removed) - removeStateBlock(state_ptr); + removeStateBlockDerived(state_ptr); // Add state block - addStateBlock(state_ptr); + addStateBlockDerived(state_ptr); // Add all involved factors for (auto fac : involved_factors) - addFactor(fac); + addFactorDerived(fac); } bool CeresManager::hasConverged() @@ -405,74 +403,82 @@ ceres::CostFunctionPtr CeresManager::createCostFunction(const FactorBasePtr& _fa throw std::invalid_argument( "Wrong Jacobian Method!" ); } -bool CeresManager::check() +bool CeresManager::check(std::string prefix) const { bool ok = true; // Check numbers if (ceres_problem_->NumResidualBlocks() != fac_2_costfunction_.size()) { - WOLF_ERROR("CeresManager::check: number of residuals mismatch"); + WOLF_ERROR("CeresManager::check: number of residuals mismatch - in ", prefix); ok = false; } if (ceres_problem_->NumResidualBlocks() != fac_2_residual_idx_.size()) { - WOLF_ERROR("CeresManager::check: number of residuals mismatch"); + WOLF_ERROR("CeresManager::check: number of residuals mismatch - in ", prefix); ok = false; } if (ceres_problem_->NumParameterBlocks() != state_blocks_.size()) { - WOLF_ERROR("CeresManager::check: number of parameters mismatch. ceres_problem_->NumParameterBlocks() = ", ceres_problem_->NumParameterBlocks(), " - state_blocks_.size() = ", state_blocks_.size()); + WOLF_ERROR("CeresManager::check: number of parameters mismatch. ceres_problem_->NumParameterBlocks() = ", ceres_problem_->NumParameterBlocks(), " - state_blocks_.size() = ", state_blocks_.size(), " - in ", prefix); ok = false; } // Check parameter blocks - for (auto&& state_block_pair : state_blocks_) - //for (const auto& state_block_pair : state_blocks_) + //for (auto&& state_block_pair : state_blocks_) + for (const auto& state_block_pair : state_blocks_) { if (!ceres_problem_->HasParameterBlock(state_block_pair.second.data())) { - WOLF_ERROR("CeresManager::check: any state block is missing in ceres problem"); + WOLF_ERROR("CeresManager::check: any state block is missing in ceres problem - in ", prefix); ok = false; } } // Check residual blocks - for (auto&& fac_res_pair : fac_2_residual_idx_) - //for (const auto& fac_res_pair : fac_2_residual_idx_) + //for (auto&& fac_res_pair : fac_2_residual_idx_) + for (const auto& fac_res_pair : fac_2_residual_idx_) { // costfunction - residual if (fac_2_costfunction_.find(fac_res_pair.first) == fac_2_costfunction_.end()) { - WOLF_ERROR("CeresManager::check: any factor in fac_2_residual_idx_ is not in fac_2_costfunction_"); + WOLF_ERROR("CeresManager::check: any factor in fac_2_residual_idx_ is not in fac_2_costfunction_ - in ", prefix); ok = false; } - if (fac_2_costfunction_[fac_res_pair.first].get() != ceres_problem_->GetCostFunctionForResidualBlock(fac_res_pair.second)) - //if (fac_2_costfunction_.at(fac_res_pair.first).get() != ceres_problem_->GetCostFunctionForResidualBlock(fac_res_pair.second)) + //if (fac_2_costfunction_[fac_res_pair.first].get() != ceres_problem_->GetCostFunctionForResidualBlock(fac_res_pair.second)) + if (fac_2_costfunction_.at(fac_res_pair.first).get() != ceres_problem_->GetCostFunctionForResidualBlock(fac_res_pair.second)) { - WOLF_ERROR("CeresManager::check: fac_2_costfunction_ and ceres mismatch"); + WOLF_ERROR("CeresManager::check: fac_2_costfunction_ and ceres mismatch - in ", prefix); ok = false; } // factor - residual if (fac_res_pair.first != static_cast<const CostFunctionWrapper*>(ceres_problem_->GetCostFunctionForResidualBlock(fac_res_pair.second))->getFactor()) { - WOLF_ERROR("CeresManager::check: fac_2_residual_idx_ and ceres mismatch"); + WOLF_ERROR("CeresManager::check: fac_2_residual_idx_ and ceres mismatch - in ", prefix); ok = false; } // parameter blocks - state blocks std::vector<double*> param_blocks; ceres_problem_->GetParameterBlocksForResidualBlock(fac_res_pair.second, ¶m_blocks); - auto i = 0; - for (const StateBlockPtr& st : fac_res_pair.first->getStateBlockPtrVector()) + if (param_blocks.size() != fac_res_pair.first->getStateBlockPtrVector().size()) { - if (getAssociatedMemBlockPtr(st) != param_blocks[i]) + WOLF_ERROR("CeresManager::check: number parameter blocks in ceres residual", param_blocks.size(), " different from number of state blocks in factor ", fac_res_pair.first->getStateBlockPtrVector().size(), " - in ", prefix); + ok = false; + } + else + { + auto i = 0; + for (const StateBlockPtr& st : fac_res_pair.first->getStateBlockPtrVector()) { - WOLF_ERROR("CeresManager::check: parameters mismatch"); - ok = false; + if (getAssociatedMemBlockPtr(st) != param_blocks[i]) + { + WOLF_ERROR("CeresManager::check: parameter", i, " mismatch - in ", prefix); + ok = false; + } + i++; } - i++; } } return ok; diff --git a/src/solver/solver_manager.cpp b/src/solver/solver_manager.cpp index b2f1f807821ba42bfb2ff28678c376517b33dc35..5489ea265c0f740a0ddc12abfc2eb3a26f6a5fba 100644 --- a/src/solver/solver_manager.cpp +++ b/src/solver/solver_manager.cpp @@ -11,12 +11,54 @@ SolverManager::SolverManager(const ProblemPtr& _wolf_problem) : assert(_wolf_problem != nullptr && "Passed a nullptr ProblemPtr."); } +SolverManager::~SolverManager() +{ + while (!state_blocks_.empty()) + { + auto sb = state_blocks_.begin()->first; + while(!state_blocks_2_factors_.at(sb).empty()) + removeFactor(state_blocks_2_factors_.at(sb).back()); + removeStateBlock(sb); + } +} + void SolverManager::update() { // Consume notification maps auto fac_notification_map = wolf_problem_->consumeFactorNotificationMap(); auto sb_notification_map = wolf_problem_->consumeStateBlockNotificationMap(); + // Check not-removed factors involved in removed state blocks + // This could happen in multi-threading if update() is called + // after removing a node with state blocks (e.g. a frame) + // and before the related factors are not removed (virally). + for (const auto& sb_notification : sb_notification_map) + { + // check only for removing state blocks (that actually are in registered) + if (sb_notification.second == REMOVE and state_blocks_2_factors_.count(sb_notification.first)) + { + for (auto fac : state_blocks_2_factors_.at(sb_notification.first)) + { + // All involved factors should have been notified to be removed + bool fac_remove_notification = false; + for (const auto& fac_notification : fac_notification_map) + { + if (fac_notification.first == fac and fac_notification.second == REMOVE) + { + fac_remove_notification = true; + break; + } + } + // Notification of remove of this factor not found + if (!fac_remove_notification) + { + WOLF_WARN("****\n****\nSolverManager::update(): The StateBlock ", sb_notification.first, " is notified to be removed but the involved factor ", fac->id(), " is not. Removing it.\n****\n****"); + fac_notification_map.emplace(fac, REMOVE); + } + } + } + } + // REMOVE FACTORS for (auto fac_notification_it = fac_notification_map.begin(); fac_notification_it != fac_notification_map.end(); @@ -24,9 +66,7 @@ void SolverManager::update() { if (fac_notification_it->second == REMOVE) { - assert(check()); removeFactor(fac_notification_it->first); - assert(check()); fac_notification_it = fac_notification_map.erase(fac_notification_it); } else @@ -36,43 +76,15 @@ void SolverManager::update() // ADD/REMOVE STATE BLOCS while ( !sb_notification_map.empty() ) { - StateBlockPtr state_ptr = sb_notification_map.begin()->first; - + // add if (sb_notification_map.begin()->second == ADD) - { - // only add if not added - if (state_blocks_.find(state_ptr) == state_blocks_.end()) - { - assert(check()); - state_blocks_.emplace(state_ptr, state_ptr->getState()); - addStateBlock(state_ptr); - // A state_block is added with its last state_ptr, status and local_param, thus, no update is needed for any of those things -> reset flags - state_ptr->resetStateUpdated(); - state_ptr->resetFixUpdated(); - state_ptr->resetLocalParamUpdated(); - assert(check()); - } - else - { - WOLF_DEBUG("Tried to add a StateBlock that was already added !"); - } - } + addStateBlock(sb_notification_map.begin()->first); + + // remove else - { - // only remove if it exists - if (state_blocks_.find(state_ptr)!=state_blocks_.end()) - { - assert(check()); - removeStateBlock(state_ptr); - state_blocks_.erase(state_ptr); - assert(check()); - } - else - { - WOLF_DEBUG("Tried to remove a StateBlock that was not added !"); - } - } - // next notification + removeStateBlock(sb_notification_map.begin()->first); + + // remove notification sb_notification_map.erase(sb_notification_map.begin()); } @@ -82,9 +94,8 @@ void SolverManager::update() assert(fac_notification_map.begin()->second == ADD && "unexpected factor notification value after all REMOVE have been processed, this should be ADD"); // add factor - assert(check()); addFactor(fac_notification_map.begin()->first); - assert(check()); + // remove notification fac_notification_map.erase(fac_notification_map.begin()); } @@ -96,27 +107,15 @@ void SolverManager::update() // state update if (state_ptr->stateUpdated()) - { - Eigen::VectorXd new_state = state_ptr->getState(); - // We assume the same size for the states in both WOLF and the solver. - std::copy(new_state.data(),new_state.data()+new_state.size(),getAssociatedMemBlockPtr(state_ptr)); - // reset flag - state_ptr->resetStateUpdated(); - } + updateStateBlockState(state_ptr); + // fix update if (state_ptr->fixUpdated()) - { updateStateBlockStatus(state_ptr); - // reset flag - state_ptr->resetFixUpdated(); - } + // local parameterization update if (state_ptr->localParamUpdated()) - { updateStateBlockLocalParametrization(state_ptr); - // reset flag - state_ptr->resetLocalParamUpdated(); - } } } @@ -133,24 +132,130 @@ std::string SolverManager::solve(const ReportVerbosity report_level) // Check assertCheck(); - std::string report = solveImpl(report_level); + std::string report = solveDerived(report_level); // update StateBlocks with optimized state value. /// @todo whatif someone has changed the state notification during opti ?? /// JV: I do not see a problem here, the solver provides the optimal state given the factors, if someone changed the state during optimization, it will be overwritten by the optimal one. - std::map<StateBlockPtr, Eigen::VectorXd>::iterator it = state_blocks_.begin(), - it_end = state_blocks_.end(); - for (; it != it_end; ++it) + //std::map<StateBlockPtr, Eigen::VectorXd>::iterator it = state_blocks_.begin(), + // it_end = state_blocks_.end(); + for (auto& stateblock_statevector : state_blocks_) { // Avoid usuless copies - if (!it->first->isFixed()) - it->first->setState(it->second, false); // false = do not raise the flag state_updated_ + if (!stateblock_statevector.first->isFixed()) + stateblock_statevector.first->setState(stateblock_statevector.second, false); // false = do not raise the flag state_updated_ } return report; } +void SolverManager::addFactor(const FactorBasePtr& fac_ptr) +{ + assert(check("before addFactor")); + + // add to state-factor map + for (const StateBlockPtr& st : fac_ptr->getStateBlockPtrVector()) + { + assert(state_blocks_.count(st) != 0 && "SolverManager::addFactor before adding any involved state block"); + assert(state_blocks_2_factors_.count(st) != 0 && "SolverManager::addFactor before adding any involved state block"); + state_blocks_2_factors_.at(st).push_back(fac_ptr); + } + + addFactorDerived(fac_ptr); + + assert(check("after addFactor")); +} + +void SolverManager::removeFactor(const FactorBasePtr& fac_ptr) +{ + assert(check("before removeFactor")); + + removeFactorDerived(fac_ptr); + + for (const auto& st : fac_ptr->getStateBlockPtrVector()) + { + assert(state_blocks_.count(st) != 0 && "SolverManager::removeFactor missing any involved state block"); + assert(state_blocks_2_factors_.count(st) != 0 && "SolverManager::removeFactor missing any involved state block"); + state_blocks_2_factors_.at(st).remove(fac_ptr); + } + + assert(check("after removeFactor")); +} + +void SolverManager::addStateBlock(const StateBlockPtr& state_ptr) +{ + // Warning if adding an already added state block + if (state_blocks_.count(state_ptr) != 0) + { + WOLF_WARN("Tried to add a StateBlock that was already added !"); + return; + } + + assert(state_blocks_.count(state_ptr) == 0 && "SolverManager::addStateBlock state block already added"); + assert(state_blocks_2_factors_.count(state_ptr) == 0 && "SolverManager::addStateBlock state block already added"); + + assert(check("before addStateBlock")); + + state_blocks_.emplace(state_ptr, state_ptr->getState()); + state_blocks_2_factors_.emplace(state_ptr, FactorBasePtrList()); + + addStateBlockDerived(state_ptr); + + // A state_block is added with its last state_ptr, status and local_param, thus, no update is needed for any of those things -> reset flags + state_ptr->resetStateUpdated(); + state_ptr->resetFixUpdated(); + state_ptr->resetLocalParamUpdated(); + + assert(check("after addStateBlock")); +} + +void SolverManager::removeStateBlock(const StateBlockPtr& state_ptr) +{ + // Warning if removing a missing state block + if (state_blocks_.count(state_ptr) == 0) + { + WOLF_WARN("Tried to remove a StateBlock that was not added !"); + return; + } + + assert(state_blocks_.count(state_ptr) != 0 && "SolverManager::removeStateBlock missing state block"); + assert(state_blocks_2_factors_.count(state_ptr) != 0 && "SolverManager::removeStateBlock missing state block"); + assert(state_blocks_2_factors_.at(state_ptr).empty() && "SolverManager::removeStateBlock removing state block before removing all factors involved"); + + assert(check("before removeStateBlock")); + + removeStateBlockDerived(state_ptr); + + state_blocks_.erase(state_ptr); + state_blocks_2_factors_.erase(state_ptr); + + assert(check("after removeStateBlock")); +} + +void SolverManager::updateStateBlockStatus(const StateBlockPtr& state_ptr) +{ + updateStateBlockStatusDerived(state_ptr); + // reset flag + state_ptr->resetFixUpdated(); +} + +void SolverManager::updateStateBlockState(const StateBlockPtr& state_ptr) +{ + Eigen::VectorXd new_state = state_ptr->getState(); + // We assume the same size for the states in both WOLF and the solver. + std::copy(new_state.data(),new_state.data()+new_state.size(),getAssociatedMemBlockPtr(state_ptr)); + // reset flag + state_ptr->resetStateUpdated(); +} + +void SolverManager::updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr) +{ + updateStateBlockLocalParametrizationDerived(state_ptr); + // reset flag + state_ptr->resetLocalParamUpdated(); +} + Eigen::VectorXd& SolverManager::getAssociatedMemBlock(const StateBlockPtr& state_ptr) { auto it = state_blocks_.find(state_ptr); @@ -161,15 +266,15 @@ Eigen::VectorXd& SolverManager::getAssociatedMemBlock(const StateBlockPtr& state return it->second; } -//const double* SolverManager::getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const -//{ -// auto it = state_blocks_.find(state_ptr); -// -// if (it == state_blocks_.end()) -// throw std::runtime_error("Tried to retrieve the memory block of an unregistered StateBlock !"); -// -// return it->second.data(); -//} +const double* SolverManager::getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) const +{ + auto it = state_blocks_.find(state_ptr); + + if (it == state_blocks_.end()) + throw std::runtime_error("Tried to retrieve the memory block of an unregistered StateBlock !"); + + return it->second.data(); +} double* SolverManager::getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) { @@ -183,7 +288,7 @@ double* SolverManager::getAssociatedMemBlockPtr(const StateBlockPtr& state_ptr) bool SolverManager::isStateBlockRegistered(const StateBlockPtr& state_ptr) { - return state_blocks_.find(state_ptr) != state_blocks_.end() && isStateBlockRegisteredDerived(state_ptr); + return state_blocks_.count(state_ptr) ==1 && isStateBlockRegisteredDerived(state_ptr); } bool SolverManager::isFactorRegistered(const FactorBasePtr& fac_ptr) const diff --git a/test/dummy/solver_manager_dummy.h b/test/dummy/solver_manager_dummy.h index c84e3954b4c680963bbe1425cd883be208949d48..c31a2ff0179d95fb7c068287b95a4a2f640faafa 100644 --- a/test/dummy/solver_manager_dummy.h +++ b/test/dummy/solver_manager_dummy.h @@ -61,36 +61,36 @@ class SolverManagerDummy : public SolverManager SizeStd iterations() { return 1; } double initialCost() { return double(1); } double finalCost() { return double(0); } - virtual bool check() override {return true;} + virtual bool check(std::string prefix="") const override {return true;} protected: - virtual std::string solveImpl(const ReportVerbosity report_level){ return std::string("");}; - virtual void addFactor(const FactorBasePtr& fac_ptr) + virtual std::string solveDerived(const ReportVerbosity report_level){ return std::string("");}; + virtual void addFactorDerived(const FactorBasePtr& fac_ptr) { factors_.push_back(fac_ptr); }; - virtual void removeFactor(const FactorBasePtr& fac_ptr) + virtual void removeFactorDerived(const FactorBasePtr& fac_ptr) { factors_.remove(fac_ptr); }; - virtual void addStateBlock(const StateBlockPtr& state_ptr) + virtual void addStateBlockDerived(const StateBlockPtr& state_ptr) { state_block_fixed_[state_ptr] = state_ptr->isFixed(); state_block_local_param_[state_ptr] = state_ptr->getLocalParametrization(); }; - virtual void removeStateBlock(const StateBlockPtr& state_ptr) + virtual void removeStateBlockDerived(const StateBlockPtr& state_ptr) { state_block_fixed_.erase(state_ptr); state_block_local_param_.erase(state_ptr); }; - virtual void updateStateBlockStatus(const StateBlockPtr& state_ptr) + virtual void updateStateBlockStatusDerived(const StateBlockPtr& state_ptr) { state_block_fixed_[state_ptr] = state_ptr->isFixed(); }; - virtual void updateStateBlockLocalParametrization(const StateBlockPtr& state_ptr) + virtual void updateStateBlockLocalParametrizationDerived(const StateBlockPtr& state_ptr) { if (state_ptr->getLocalParametrization() == nullptr) state_block_local_param_.erase(state_ptr);