Commit 4e1e85eb authored by asuarez's avatar asuarez
Browse files

refactored types to use smart pointers

parent e4f9bc30
# driver source files
SET(sources imagine-planner.cpp types.cpp queries.cpp effects.cpp domains.cpp parsing.cpp)
SET(sources types.cpp) # queries.cpp effects.cpp domains.cpp parsing.cpp imagine-planner.cpp)
# application header files
SET(headers imagine-planner.h types.h queries.h effects.h domains.h parsing.h)
SET(headers types.h) # queries.h effects.h domains.h parsing.h imagine-planner.h)
# Ades
SET(CMAKE_MODULE_PATH /usr/local/lib/cmake/ades)
FIND_PACKAGE(ADES)
......
......@@ -47,80 +47,99 @@ std::ostream& operator<<(std::ostream& os, const std::unordered_set<Predicate,
CONTAINER_TO_STREAM(ctn, os);
}
BOOST_AUTO_TEST_CASE(dummy)
{
BOOST_CHECK(true);
}
BOOST_AUTO_TEST_CASE(literal_test)
{
TermWrapper l1("atom1");
TermWrapper l1_("atom1");
TermWrapper l2("atom2");
TermWrapper l3("Var1");
TermWrapper l4("_");
INFO("l1=" << l1);
INFO("l1_=" << l1_);
INFO("l2=" << l2);
INFO("l3=" << l3);
BOOST_CHECK_EQUAL(l1, l1_);
BOOST_CHECK_LE(l1, l1_);
BOOST_CHECK_NE(l1, l2);
BOOST_CHECK_LT(l1, l2);
BOOST_CHECK_LE(l1, l2);
BOOST_CHECK_GT(l2, l1);
BOOST_CHECK_GE(l2, l1);
BOOST_CHECK(l1.is_ground());
BOOST_CHECK(l2.is_ground());
BOOST_CHECK(not l3.is_ground());
BOOST_CHECK(not l4.is_ground());
BOOST_CHECK_EQUAL(l1.hash(), l1_.hash());
BOOST_WARN_NE(l1.hash(), l2.hash()); // Very likely the hashes are different
TermFactory T;
auto l1 = T("atom1");
auto l1_ = T("atom1");
auto l2 = T("atom2");
auto l3 = T("Var1");
auto l4 = T("_");
INFO("l1=" << *l1);
INFO("l1_=" << *l1_);
INFO("l2=" << *l2);
INFO("l3=" << *l3);
BOOST_CHECK_EQUAL(*l1, *l1_);
BOOST_CHECK_LE(*l1, *l1_);
BOOST_CHECK_NE(*l1, *l2);
BOOST_CHECK_LT(*l1, *l2);
BOOST_CHECK_LE(*l1, *l2);
BOOST_CHECK_GT(*l2, *l1);
BOOST_CHECK_GE(*l2, *l1);
BOOST_CHECK(l1->is_ground());
BOOST_CHECK(l2->is_ground());
BOOST_CHECK(not l3->is_ground());
BOOST_CHECK(not l4->is_ground());
BOOST_CHECK_EQUAL(l1->hash(), l1_->hash());
BOOST_WARN_NE(l1->hash(), l2->hash()); // Very likely the hashes are different
}
BOOST_AUTO_TEST_CASE(number_test)
{
TermWrapper n1(0.5);
TermWrapper n1_(0.5+2e-10);
TermWrapper n1__(0.5+5e-9);
TermWrapper n2(1);
TermWrapper n2_(1+2e-10);
TermWrapper n2__(1+5e-9);
INFO("n1=" << n1);
INFO("n1_=" << n1_);
INFO("n1__=" << n1__);
INFO("n2=" << n2);
INFO("n2_=" << n2_);
INFO("n2__=" << n2__);
TermFactory T;
auto n1 = T(0.5);
auto n1_ = T(0.5+2e-10);
auto n1__ = T(0.5+5e-9);
auto n2 = T(1);
auto n2_ = T(1+2e-10);
auto n2__ = T(1+5e-9);
INFO("n1=" << *n1);
INFO("n1_=" << *n1_);
INFO("n1__=" << *n1__);
INFO("n2=" << *n2);
INFO("n2_=" << *n2_);
INFO("n2__=" << *n2__);
// Check number specific method's
BOOST_CHECK_CLOSE(n1.get_number_term()->get_number(), 0.5, 0.00001);
BOOST_CHECK_CLOSE(n2.get_number_term()->get_number(), 1, 0.00001);
BOOST_CHECK(not n1.get_number_term()->is_int());
BOOST_CHECK(not n2__.get_number_term()->is_int());
BOOST_CHECK(n2.get_number_term()->is_int());
BOOST_CHECK(n2_.get_number_term()->is_int());
BOOST_CHECK_CLOSE(n1->get_number(), 0.5, 0.00001);
BOOST_CHECK_CLOSE(n2->get_number(), 1, 0.00001);
BOOST_CHECK(not n1->is_int());
BOOST_CHECK(not n2__->is_int());
BOOST_CHECK(n2->is_int());
BOOST_CHECK(n2_->is_int());
// Check == operator
BOOST_CHECK_EQUAL(n1, n1_);
BOOST_CHECK_EQUAL(n2, n2_);
BOOST_CHECK_NE(n1, n1__);
BOOST_CHECK_NE(n2, n2__);
BOOST_CHECK_NE(n1, n2);
BOOST_CHECK_EQUAL(*n1, *n1_);
BOOST_CHECK_EQUAL(*n2, *n2_);
BOOST_CHECK_NE(*n1, *n1__);
BOOST_CHECK_NE(*n2, *n2__);
BOOST_CHECK_NE(*n1, *n2);
// Check < operator
BOOST_CHECK_LT(n1, n2);
BOOST_CHECK_LT(n1, n1__);
BOOST_CHECK_GE(n1_, n1);
BOOST_CHECK_GT(n2, n1);
BOOST_CHECK_LT(*n1, *n2);
BOOST_CHECK_LT(*n1, *n1__);
BOOST_CHECK_GE(*n1_, *n1);
BOOST_CHECK_GT(*n2, *n1);
// Check hash correctness
BOOST_CHECK_EQUAL(n1.hash(), n1.hash()); // non-ints should have the same hash
BOOST_CHECK_EQUAL(n2.hash(), n2_.hash());
BOOST_WARN_NE(n2.hash(), n2__.hash()); // Very likely the hashes are
// different
BOOST_CHECK_EQUAL(n1->hash(), n1->hash()); // non-ints should have the same hash
BOOST_CHECK_EQUAL(n2->hash(), n2_->hash());
BOOST_WARN_NE(n2->hash(), n2__->hash()); // Very likely the hashes are
// different
}
BOOST_AUTO_TEST_CASE(mixed_test)
{
TermWrapper t1(0);
TermWrapper t2("atom");
INFO("t1=" << t1);
INFO("t2=" << t2);
BOOST_CHECK_NE(t1, t2);
BOOST_CHECK_LT(t2, t1);
BOOST_WARN_NE(t1.hash(), t2.hash());
TermFactory T;
auto t1 = T(0);
auto t2 = T("atom");
INFO("t1=" << *t1);
INFO("t2=" << *t2);
BOOST_CHECK_NE(*t1, *t2);
BOOST_CHECK_LT(*t2, *t1);
BOOST_WARN_NE(t1->hash(), t2->hash());
}
BOOST_AUTO_TEST_CASE(term_v)
{
TermFactory T;
TermV v1 = T.vector();
TermV v2 = T.vector("a", "b");
for (auto& t : v1)
INFO(*t);
for (auto& t : v2)
INFO(*t);
}
BOOST_AUTO_TEST_CASE(predicate_test)
......@@ -207,14 +226,15 @@ BOOST_AUTO_TEST_CASE(unordered_set_test)
BOOST_AUTO_TEST_CASE(substitution_test)
{
TermFactory T;
Predicate p1("p", "X", "a");
Predicate p2("p", "x", "y");
Predicate p3("p", "X", "Y");
Predicate p4("q", "Y");
Predicate p5("q", "Z");
Predicate p6("r");
Substitution sigma({{"X", TermWrapper("a")}, {"Y", TermWrapper("b")}});
Substitution sigma_({{"W", TermWrapper("c")}});
Substitution sigma({{"X", T("a")}, {"Y", T("b")}});
Substitution sigma_({{"W", T("c")}});
INFO("p1 = " << p1);
INFO("p2 = " << p2);
INFO("p3 = " << p3);
......
......@@ -41,9 +41,8 @@ bool LiteralTerm::operator<(const Term& term) const
return type() < term.type();
}
// NumberTerm
double NumberTerm::EPSILON_ = 1E-9;
// NumberTerm
NumberTerm::NumberTerm(double number) : number_(number) {}
......@@ -76,51 +75,16 @@ bool NumberTerm::operator<(const Term& term) const
return type() < term.type();
}
// TermWrapper
TermWrapper::TermWrapper(const TermWrapper& term) : term_(term.term_->clone()) {}
TermWrapper::TermWrapper(const Term& term) : term_(term.clone()) {}
TermWrapper::TermWrapper(const std::string& literal) :
term_(new LiteralTerm(literal)) {}
TermWrapper::TermWrapper(double number) : term_(new NumberTerm(number)) {}
const LiteralTerm* TermWrapper::get_literal_term() const
{
return dynamic_cast<const LiteralTerm*>(term_);
}
const NumberTerm* TermWrapper::get_number_term() const
{
return dynamic_cast<const NumberTerm*>(term_);
}
TermWrapper& TermWrapper::operator=(const TermWrapper& term)
{
delete term_;
term_ = term.term_->clone();
return *this;
}
TermWrapper::~TermWrapper()
{
delete term_;
}
double NumberTerm::EPSILON_ = 1E-9;
TermV create_term_v() { return TermV(); }
// Predicate
Predicate::Predicate(const std::string& name, const TermV& arguments) :
name_(name), arguments_(arguments) {}
bool Predicate::is_ground() const
{
for (const TermWrapper& term : arguments_)
for (const Term::Ptr& term : arguments_)
{
if (not term.is_ground()) return false;
if (not term->is_ground()) return false;
}
return true;
}
......@@ -137,10 +101,10 @@ std::string Predicate::to_str() const
std::ostringstream oss;
oss << name_ << '(';
bool first = true;
for (const TermWrapper& term : arguments_)
for (const Term::Ptr& term : arguments_)
{
if (not first) oss << ',';
oss << term;
oss << *term;
first = false;
}
oss << ')';
......@@ -152,9 +116,9 @@ std::size_t Predicate::hash() const
std::hash<std::string> hasher_s;
std::hash<Hashable> hasher_h;
std::size_t acc = hasher_s(name_);
for (const TermWrapper& term : arguments_)
for (const Term::Ptr& term : arguments_)
{
acc ^= hasher_h(term) + 0x9e3779b9 + (acc<<6) + (acc>>2);
acc ^= hasher_h(*term) + 0x9e3779b9 + (acc<<6) + (acc>>2);
}
return acc;
}
......@@ -165,7 +129,7 @@ bool Predicate::operator==(const Predicate& predicate) const
if (arguments_.size() != predicate.arguments_.size()) return false;
for (std::size_t idx = 0; idx < arguments_.size(); ++idx)
{
if (arguments_[idx] != predicate.arguments_[idx]) return false;
if (*arguments_[idx] != *predicate.arguments_[idx]) return false;
}
return true;
}
......@@ -178,18 +142,14 @@ bool Predicate::operator<(const Predicate& predicate) const
if (arguments_.size() > predicate.arguments_.size()) return false;
for (std::size_t idx = 0; idx < arguments_.size(); ++idx)
{
if (arguments_[idx] < predicate.arguments_[idx]) return true;
if (arguments_[idx] > predicate.arguments_[idx]) return false;
if (*arguments_[idx] < *predicate.arguments_[idx]) return true;
if (*arguments_[idx] > *predicate.arguments_[idx]) return false;
}
return false;
}
// Implementation of Substitution
Substitution::Substitution() {}
Substitution::Substitution(const std::map<std::string, TermWrapper>& sigma)
: sigma_(sigma) {}
// Implementation of Substitution
std::string Substitution::to_str() const
{
......@@ -199,11 +159,11 @@ std::string Substitution::to_str() const
{
oss << '{';
bool first = true;
for (auto entry : sigma_)
for (const auto& entry : sigma_)
{
if (not first) oss << ',';
oss << "\n ";
oss << entry.first << " -> " << entry.second;
oss << entry.first << " -> " << *entry.second;
first = false;
}
oss << "\n}";
......@@ -216,16 +176,16 @@ bool Substitution::has(const std::string& varname) const
return (bool)sigma_.count(varname);
}
const Term* Substitution::get(const std::string& varname) const
Term::Ptr Substitution::get(const std::string& varname) const
{
auto it = sigma_.find(varname);
if (it != sigma_.end()) return it->second.get_term();
if (it != sigma_.end()) return it->second;
return nullptr;
}
void Substitution::put(const std::string& varname, const TermWrapper& term)
void Substitution::put(const std::string& varname, const Term::Ptr& value)
{
sigma_.insert(std::make_pair(varname, term));
sigma_.insert(std::make_pair(varname, value));
}
void Substitution::remove(const std::string& varname)
......@@ -235,7 +195,7 @@ void Substitution::remove(const std::string& varname)
void Substitution::operator+=(const Substitution& other)
{
for (auto entry : other)
for (const auto& entry : other)
{
sigma_.insert(entry);
}
......@@ -243,7 +203,7 @@ void Substitution::operator+=(const Substitution& other)
void Substitution::operator-=(const Substitution& other)
{
for (auto entry : other)
for (const auto& entry : other)
{
sigma_.erase(entry.first);
}
......@@ -258,7 +218,7 @@ Substitution Substitution::operator+(const Substitution& other) const
Substitution Substitution::operator-(const Substitution& other) const
{
Substitution ret;
Substitution ret(*this);
ret -= other;
return ret;
}
......@@ -267,34 +227,17 @@ Predicate Substitution::operator()(const Predicate& predicate) const
{
TermV arguments;
arguments.reserve(predicate.arity());
for (const TermWrapper& term : predicate.get_arguments())
for (const Term::Ptr& term : predicate.get_arguments())
{
const Term* value = get(term.to_str());
if (not value) value = term.get_term();
arguments.push_back(TermWrapper(*value));
Term::Ptr value = get(term->to_str());
if (not value) value = term;
arguments.push_back(value);
}
return Predicate(predicate.get_name(), arguments);
}
// Implementation of PlanState
PlanState::PlanState() {}
PlanState::PlanState(const std::set<Predicate>& predicates) :
predicates_(predicates) {}
std::set<TermWrapper> PlanState::symbols() const
{
std::set<TermWrapper> sym;
for (const Predicate& pred : predicates_)
{
for (const TermWrapper& arg : pred.get_arguments())
{
if (arg.get_literal_term()) sym.insert(arg);
}
}
return sym;
}
// Implementation of PlanState
std::string PlanState::to_str() const
{
......@@ -313,12 +256,24 @@ std::string PlanState::to_str() const
std::size_t PlanState::hash() const
{
std::size_t acc = 0;
for (const Predicate& pred : predicates_)
if (not hash_cached_)
{
acc += pred.hash();
std::size_t acc = 0;
for (const Predicate& pred : predicates_)
{
// since we are using an ordered set, two equal sets result in the same
// iteration order, so we can use this hash combination.
acc ^= pred.hash() + 0x9e3779b9 + (acc<<6) + (acc>>2);
// for unordered sets...
//acc += pred.hash();
}
// The use of const_cast here is appropriate since we are just storing the
// hash value for faster retrieval in the future.
PlanState* modifiable = const_cast<PlanState*>(this);
modifiable->hash_cached_ = true;
modifiable->hash_ = acc;
}
return acc;
return hash_;
}
bool PlanState::subset_of(const PlanState& other) const
......@@ -331,15 +286,10 @@ bool PlanState::subset_of(const PlanState& other) const
return true;
}
// Implementation of GoalSpecification's methods
GoalSpecification::GoalSpecification() {}
GoalSpecification::GoalSpecification(const std::vector<Predicate>& must_appear,
const std::vector<Predicate>& cannot_appear)
: must_appear_(must_appear), cannot_appear_(cannot_appear) {}
// Implementation of GoalCondition's methods
std::string GoalSpecification::to_str() const
std::string GoalCondition::to_str() const
{
std::ostringstream oss;
bool first = true;
......@@ -362,7 +312,7 @@ std::string GoalSpecification::to_str() const
return oss.str();
}
bool GoalSpecification::operator()(const PlanState& state) const
bool GoalCondition::operator()(const PlanState& state) const
{
for (const Predicate& pred : must_appear_)
{
......
#ifndef _LIBIMAGINE_PLANNER_TYPES_H_
#define _LIBIMAGINE_PLANNER_TYPES_H_
// Let's group some functions and typedefs that are common to all the Types
// defined here in a macro for convenience.
#define TYPES_COMMON(Class)\
virtual std::string type() const override { return #Class; }\
virtual Term* clone() const override { return new Class(*this); };
#include <exception>
#include <functional>
#include <map>
......@@ -25,7 +19,6 @@ class Hashable;
class Term;
class LiteralTerm;
class NumberTerm;
class TermWrapper;
class Predicate;
class Substitution;
class PlanState;
......@@ -33,13 +26,15 @@ class GoalSpecification;
class ImaginePlannerException : public std::exception
{
private:
std::string msg_;
public:
explicit ImaginePlannerException(const std::string& msg) : msg_(msg) {}
virtual const char* what() const throw() { return msg_.c_str(); }
private:
std::string msg_;
};
/**
......@@ -60,6 +55,10 @@ class Stringifiable
virtual ~Stringifiable() {}
};
/**
* @brief Abstract class (interface-like) that declares the hash method.
*/
class Hashable
{
public:
......@@ -68,11 +67,17 @@ class Hashable
virtual ~Hashable() {}
};
/**
* @brief General term (Either a Literal or a Number)
*/
class Term : public Stringifiable,
public Hashable
{
public:
typedef std::shared_ptr<const Term> Ptr;
virtual std::string type() const =0;
virtual bool is_ground() const =0;
......@@ -95,19 +100,25 @@ class Term : public Stringifiable,
virtual bool operator>=(const Term& term) const { return term <= *this; }
virtual Term* clone() const =0;
virtual Ptr clone() const =0;
virtual ~Term() {}
};
/**
* @brief A string Atom
*/
class LiteralTerm : public Term
{
private:
std::string literal_;
public:
typedef std::shared_ptr<LiteralTerm> Ptr;
LiteralTerm(const std::string& literal);
virtual std::string type() const override { return "LiteralTerm"; }
virtual std::string to_str() const override { return literal_; }
virtual std::size_t hash() const override;
......@@ -120,19 +131,27 @@ class LiteralTerm : public Term
virtual bool operator<(const Term& term) const override;
TYPES_COMMON(LiteralTerm)
virtual Term::Ptr clone() const override { return Term::Ptr(new LiteralTerm(literal_)); }
private:
std::string literal_;
};
/**
* @brief A numeric (double) atom
*/
class NumberTerm : public Term
{
private:
static double EPSILON_;
double number_;
public:
typedef std::shared_ptr<NumberTerm> Ptr;
NumberTerm(double number);
virtual std::string type() const override { return "NumberTerm"; }
std::string to_str() const override { return std::to_string(number_); }
virtual std::size_t hash() const override;
......@@ -147,98 +166,61 @@ class NumberTerm : public Term
virtual bool operator<(const Term& term) const override;
TYPES_COMMON(NumberTerm)
};
virtual Term::Ptr clone() const override { return Term::Ptr(new NumberTerm(number_)); }
class TermWrapper : public Stringifiable,
public Hashable
{
private:
const Term* term_;
public:
TermWrapper(const TermWrapper& term);
TermWrapper(const Term& term);
TermWrapper(const std::string& literal);
TermWrapper(double number);
const Term* get_term() const { return term_; }
const LiteralTerm* get_literal_term() const;
const NumberTerm* get_number_term() const;
bool is_ground() const { return term_->is_ground(); }