diff --git a/src/ceres_wrapper/ceres_manager.cpp b/src/ceres_wrapper/ceres_manager.cpp index 94343edc2ed198c82a041d5a930be3459d0e4bd0..fa590b0e32c42fb4c62d622cfdd9f4a3939b6af1 100644 --- a/src/ceres_wrapper/ceres_manager.cpp +++ b/src/ceres_wrapper/ceres_manager.cpp @@ -2,9 +2,10 @@ namespace wolf { -CeresManager::CeresManager(Problem* _wolf_problem, ceres::Problem::Options _options) : +CeresManager::CeresManager(Problem* _wolf_problem, ceres::Problem::Options _options, const bool _use_wolf_cost_functions) : ceres_problem_(new ceres::Problem(_options)), - wolf_problem_(_wolf_problem) + wolf_problem_(_wolf_problem), + use_wolf_auto_diff_(_use_wolf_cost_functions) { ceres::Covariance::Options covariance_options; covariance_options.algorithm_type = ceres::SUITE_SPARSE_QR;//ceres::DENSE_SVD; @@ -170,7 +171,7 @@ void CeresManager::computeCovariances(CovarianceBlocksToBeComputed _blocks) std::cout << "WARNING: Couldn't compute covariances!" << std::endl; } -void CeresManager::update(const bool _self_auto_diff, const bool _apply_loss_function) +void CeresManager::update(const bool _apply_loss_function) { //std::cout << "CeresManager: updating... getConstraintRemoveList()->size()" << wolf_problem_->getConstraintRemoveList()->size() << std::endl; @@ -201,14 +202,14 @@ void CeresManager::update(const bool _self_auto_diff, const bool _apply_loss_fun // ADD CONSTRAINTS while (!wolf_problem_->getConstraintAddList()->empty()) { - addConstraint(wolf_problem_->getConstraintAddList()->front(), _self_auto_diff, _apply_loss_function); + addConstraint(wolf_problem_->getConstraintAddList()->front(), _apply_loss_function); wolf_problem_->getConstraintAddList()->pop_front(); } } -void CeresManager::addConstraint(ConstraintBase* _corr_ptr, const bool _self_auto_diff, const bool _apply_loss) +void CeresManager::addConstraint(ConstraintBase* _corr_ptr, const bool _apply_loss) { - id_2_costfunction_[_corr_ptr->nodeId()] = createCostFunction(_corr_ptr, _self_auto_diff); + id_2_costfunction_[_corr_ptr->nodeId()] = createCostFunction(_corr_ptr); if (_apply_loss) id_2_residual_idx_[_corr_ptr->nodeId()] = ceres_problem_->AddResidualBlock(id_2_costfunction_[_corr_ptr->nodeId()], new ceres::CauchyLoss(0.5), _corr_ptr->getStateBlockPtrVector()); @@ -224,8 +225,9 @@ void CeresManager::removeConstraint(const unsigned int& _corr_id) //std::cout << "residual block removed!" << std::endl; id_2_residual_idx_.erase(_corr_id); // std::cout << "deleting cost function" << std::endl; -// assert(id_2_costfunction_.find(_corr_id) != id_2_costfunction_.end()); -// delete id_2_costfunction_[_corr_id]; + assert(id_2_costfunction_.find(_corr_id) != id_2_costfunction_.end()); + if (use_wolf_auto_diff_) + delete id_2_costfunction_[_corr_id]; // std::cout << "cost function deleted!" << std::endl; id_2_costfunction_.erase(_corr_id); } @@ -271,7 +273,7 @@ void CeresManager::updateStateBlockStatus(StateBlock* _st_ptr) ceres_problem_->SetParameterBlockVariable(_st_ptr->getPtr()); } -ceres::CostFunction* CeresManager::createCostFunction(ConstraintBase* _corrPtr, const bool _use_wolf_auto_diff) +ceres::CostFunction* CeresManager::createCostFunction(ConstraintBase* _corrPtr) { //std::cout << "adding ctr " << _corrPtr->nodeId() << std::endl; @@ -281,11 +283,11 @@ ceres::CostFunction* CeresManager::createCostFunction(ConstraintBase* _corrPtr, // auto jacobian else if (_corrPtr->getJacobianMethod() == JAC_AUTO) - return createAutoDiffCostFunction(_corrPtr, _use_wolf_auto_diff); + return createAutoDiffCostFunction(_corrPtr, use_wolf_auto_diff_); // numeric jacobian else if (_corrPtr->getJacobianMethod() == JAC_NUMERIC) - return createNumericDiffCostFunction(_corrPtr, _use_wolf_auto_diff); + return createNumericDiffCostFunction(_corrPtr, use_wolf_auto_diff_); else throw std::invalid_argument( "Bad Jacobian Method!" ); diff --git a/src/ceres_wrapper/ceres_manager.h b/src/ceres_wrapper/ceres_manager.h index 0a2e9dbd19a54ecf8875f83c14d32bd289f8ebf2..813a437a07c2a66e6d6bcf28631a62c7cf898611 100644 --- a/src/ceres_wrapper/ceres_manager.h +++ b/src/ceres_wrapper/ceres_manager.h @@ -42,9 +42,10 @@ class CeresManager ceres::Problem* ceres_problem_; ceres::Covariance* covariance_; Problem* wolf_problem_; + bool use_wolf_auto_diff_; public: - CeresManager(Problem* _wolf_problem, ceres::Problem::Options _options); + CeresManager(Problem* _wolf_problem, ceres::Problem::Options _options, const bool _use_wolf_cost_functions = true); ~CeresManager(); @@ -52,9 +53,9 @@ class CeresManager void computeCovariances(CovarianceBlocksToBeComputed _blocks = ROBOT_LANDMARKS); - void update(const bool _self_auto_diff = true, const bool _apply_loss_function = false); + void update(const bool _apply_loss_function = false); - void addConstraint(ConstraintBase* _corr_ptr, const bool _self_auto_diff, const bool _apply_loss); + void addConstraint(ConstraintBase* _corr_ptr, const bool _apply_loss); void removeConstraint(const unsigned int& _corr_idx); @@ -66,7 +67,7 @@ class CeresManager void updateStateBlockStatus(StateBlock* _st_ptr); - ceres::CostFunction* createCostFunction(ConstraintBase* _corrPtr, const bool _self_auto_diff); + ceres::CostFunction* createCostFunction(ConstraintBase* _corrPtr); }; } // namespace wolf diff --git a/src/examples/test_autodiff.cpp b/src/examples/test_autodiff.cpp index e7dd1649541216df66432f15c9e11da601eef934..3342250b38aeeb5086d216c8cb0f45579b611e55 100644 --- a/src/examples/test_autodiff.cpp +++ b/src/examples/test_autodiff.cpp @@ -148,8 +148,8 @@ int main(int argc, char** argv) problem_options.cost_function_ownership = ceres::DO_NOT_TAKE_OWNERSHIP; problem_options.loss_function_ownership = ceres::TAKE_OWNERSHIP;//ceres::DO_NOT_TAKE_OWNERSHIP; problem_options.local_parameterization_ownership = ceres::DO_NOT_TAKE_OWNERSHIP; - CeresManager* ceres_manager_ceres = new CeresManager(wolf_manager_ceres->getProblemPtr(), problem_options); - CeresManager* ceres_manager_wolf = new CeresManager(wolf_manager_wolf->getProblemPtr(), problem_options); + CeresManager* ceres_manager_ceres = new CeresManager(wolf_manager_ceres->getProblemPtr(), problem_options, false); + CeresManager* ceres_manager_wolf = new CeresManager(wolf_manager_wolf->getProblemPtr(), problem_options, true); std::ofstream log_file, landmark_file; //output file //std::cout << "START TRAJECTORY..." << std::endl; @@ -225,8 +225,8 @@ int main(int argc, char** argv) std::cout << "UPDATING CERES..." << std::endl; t1 = clock(); // update state units and constraints in ceres - ceres_manager_ceres->update(false); - ceres_manager_wolf->update(true); + ceres_manager_ceres->update(); + ceres_manager_wolf->update(); mean_times(2) += ((double) clock() - t1) / CLOCKS_PER_SEC; // SOLVE OPTIMIZATION --------------------------- diff --git a/src/examples/test_wolf_autodiffwrapper.cpp b/src/examples/test_wolf_autodiffwrapper.cpp index b0848e1b51e540597db2e9afc1ed087dd54f6182..6bbee2999d98789cfb40c18c3630fa7f994f5ea2 100644 --- a/src/examples/test_wolf_autodiffwrapper.cpp +++ b/src/examples/test_wolf_autodiffwrapper.cpp @@ -58,8 +58,8 @@ int main(int argc, char** argv) problem_options.cost_function_ownership = ceres::DO_NOT_TAKE_OWNERSHIP; problem_options.loss_function_ownership = ceres::TAKE_OWNERSHIP; problem_options.local_parameterization_ownership = ceres::DO_NOT_TAKE_OWNERSHIP; - CeresManager* ceres_manager_ceres_diff = new CeresManager(wolf_problem_ceres_diff, problem_options); - CeresManager* ceres_manager_wolf_diff = new CeresManager(wolf_problem_wolf_diff, problem_options); + CeresManager* ceres_manager_ceres_diff = new CeresManager(wolf_problem_ceres_diff, problem_options, false); + CeresManager* ceres_manager_wolf_diff = new CeresManager(wolf_problem_wolf_diff, problem_options, true); @@ -277,10 +277,10 @@ int main(int argc, char** argv) // BUILD SOLVER PROBLEM std::cout << "updating ceres..." << std::endl; t1 = clock(); - ceres_manager_ceres_diff->update(false); + ceres_manager_ceres_diff->update(); double t_update_ceres = ((double) clock() - t1) / CLOCKS_PER_SEC; t1 = clock(); - ceres_manager_wolf_diff->update(true); + ceres_manager_wolf_diff->update(); double t_update_wolf = ((double) clock() - t1) / CLOCKS_PER_SEC; std::cout << "updated!" << std::endl;