Commit ef970f7e authored by Alejandro Suarez Hernandez's avatar Alejandro Suarez Hernandez
Browse files

StateQuery tentative interface. Using Boost test suite.

parent 18bd3c05
......@@ -2,9 +2,10 @@
SET(sources imagine-planner.cpp types.cpp)
# application header files
SET(headers imagine-planner.h types.h)
FIND_PACKAGE(Boost COMPONENTS system filesystem unit_test_framework REQUIRED)
# locate the necessary dependencies
# add the necessary include directories
INCLUDE_DIRECTORIES(.)
INCLUDE_DIRECTORIES(. ${Boost_INCLUDE_DIR})
# create the shared library
ADD_LIBRARY(imagine-planner SHARED ${sources})
# link necessary libraries
......
ADD_EXECUTABLE(types_test types_test.cpp)
TARGET_LINK_LIBRARIES(types_test imagine-planner)
# create an example application
ADD_EXECUTABLE(imagine-planner_test imagine-planner_test.cpp)
# link necessary libraries
TARGET_LINK_LIBRARIES(imagine-planner_test imagine-planner)
ADD_EXECUTABLE(types_test types_test.cpp)
TARGET_LINK_LIBRARIES(types_test
imagine-planner
${Boost_FILESYSTEM_LIBRARY}
${Boost_SYSTEM_LIBRARY}
${Boost_UNIT_TEST_FRAMEWORK_LIBRARY})
#define BOOST_TEST_MODULE Types Test
#define BOOST_TEST_DYN_LINK
#include <boost/test/unit_test.hpp>
#include "types.h"
#include <iostream>
#include <set>
#include <unordered_set>
using namespace imagine_planner;
#if 1
#define INFO(x)\
std::cout << "\e[1m\e[32m\e[4m" << __PRETTY_FUNCTION__ << " [" << __LINE__ << "]" << "\e[24m: " << x << "\e[0m" << std::endl;
#else
#define INFO(x)
#endif
#define CONTAINER_TO_STREAM(ctn, os)\
bool first = true;\
os << '{';\
......@@ -37,55 +48,213 @@ std::ostream& operator<<(std::ostream& os, const std::unordered_set<Predicate, s
CONTAINER_TO_STREAM(ctn, os);
}
int main(int argc, char* argv[])
BOOST_AUTO_TEST_CASE(atom_test)
{
auto atom1 = A("atom1");
auto atom1_ = A("atom1");
auto atom2 = A("atom2");
INFO("atom1=" << *atom1);
INFO("atom1_=" << *atom1_);
INFO("atom2=" << *atom2);
// Check name method
BOOST_CHECK_EQUAL(atom1->name(), "atom1");
// Check == operator
BOOST_CHECK_EQUAL(*atom1, *atom1_);
BOOST_CHECK_NE(*atom1, *atom2);
// Check < operator
BOOST_CHECK_LT(*atom1, *atom2);
BOOST_CHECK(not(*atom2 < *atom1));
// Check hash correctness
BOOST_CHECK_EQUAL(atom1->hash(), atom1_->hash());
BOOST_CHECK_NE(atom1->hash(), atom2->hash());
}
BOOST_AUTO_TEST_CASE(number_test)
{
std::hash<Term> h;
Term::Ptr atom(new Atom("my_atom"));
Term::Ptr atom_(new Atom("another_atom"));
std::cout << "Testing stream operator for Atom: " << *atom << std::endl;
std::cout << "Testing hash function for Atom: " << h(*atom) << std::endl;
std::cout << "Testing == operator (equal): " << b2s(*atom == *atom) << std::endl;
std::cout << "Testing == operator (not equal): " << b2s(*atom == *atom_) << std::endl;
std::cout << "type(atom): " << atom->type() << std::endl;
std::cout << "Testing < operator (equal): " << b2s(*atom < *atom) << std::endl;
std::cout << "Testing < operator (lt): " << b2s(*atom_ < *atom) << std::endl;
std::cout << "Testing clone() method: " << *(atom->clone()) << std::endl;
std::cout << std::endl;
Term::Ptr term_v(new TermV{Term::Ptr(new Atom("a1")), Term::Ptr(new Atom("a2"))});
Term::Ptr term_v_(new TermV{Term::Ptr(new Atom("a2")), Term::Ptr(new Atom("a3"))});
std::cout << "Testing stream operator for TermV: " << *term_v << std::endl;
std::cout << "Testing hash function for TermV: " << h(*term_v) << std::endl;
std::cout << "Testing == operator (equal): " << b2s(*term_v == *term_v) << std::endl;
std::cout << "Testing == operator (not equal): " << b2s(*term_v == *atom_) << std::endl;
std::cout << "Testing == operator (not equal): " << b2s(*term_v == *term_v_) << std::endl;
std::cout << "type(atom): " << term_v->type() << std::endl;
std::cout << "Testing < operator (equal): " << b2s(*term_v < *term_v) << std::endl;
std::cout << "Testing < operator (lt): " << b2s(*term_v < *term_v_) << std::endl;
std::cout << "Testing < operator (gt): " << b2s(*term_v < *atom_) << std::endl;
std::cout << "Testing < operator (lt): " << b2s(*atom_ < *term_v) << std::endl;
std::cout << "Testing clone() method: " << *(term_v->clone()) << std::endl;
// Testing PredicateFactory
std::cout << std::endl;
auto p1 = P("p", A("x"), A("y")); std::cout << p1->hash() << std::endl;
auto p2 = P("p", A("x"), A("y")); std::cout << p2->hash() << std::endl;
auto p3 = P("p", A("x"), V("Y")); std::cout << p3->hash() << std::endl;
auto p4 = P("p", A("y"), A("x")); std::cout << p4->hash() << std::endl;
auto p5 = P("p", A("z")); std::cout << p5->hash() << std::endl;
auto p6 = P("p", N(1), N(3)); std::cout << p6->hash() << std::endl;
auto p7 = P("r", A("x"), N(0.5)); std::cout << p7->hash() << std::endl;
auto p8 = P("r", A("x"), N(0.5+5E-10)); std::cout << p8->hash() << std::endl;
auto p9 = P("r", A("x"), N(0.5+2E-9)); std::cout << p9->hash() << std::endl;
auto p10 = P("emptyhand");
std::cout << b2s(p1->arguments() == p2->arguments()) << std::endl;
std::set<Predicate> predicate_set{*p1, *p2, *p3, *p4, *p5, *p6, *p7, *p8, *p9, *p10};
std::unordered_set<Predicate, std::hash<Term>> predicate_uset{*p1, *p2, *p3, *p4, *p5, *p6, *p7, *p8, *p9, *p10};
std::cout << predicate_set << std::endl;
std::cout << predicate_uset << std::endl;
auto number1 = N(0.5);
auto number1_ = N(0.5+2e-10);
auto number1__ = N(0.5+5e-9);
auto number2 = N(1);
auto number2_ = N(1+2e-10);
auto number2__ = N(1+5e-9);
INFO("number1=" << *number1);
INFO("number1_=" << *number1_);
INFO("number1__=" << *number1__);
INFO("number2=" << *number2);
INFO("number2_=" << *number2_);
INFO("number2__=" << *number2__);
// Check number specific method's
BOOST_CHECK_CLOSE(number1->number(), 0.5, 0.00001);
BOOST_CHECK_CLOSE(number2->number(), 1, 0.00001);
BOOST_CHECK(not number1->is_int());
BOOST_CHECK(number2->is_int());
BOOST_CHECK(not number2__->is_int());
// Check == operator
BOOST_CHECK_EQUAL(*number1, *number1_);
BOOST_CHECK_EQUAL(*number2, *number2_);
BOOST_CHECK_NE(*number1, *number1__);
BOOST_CHECK_NE(*number2, *number2__);
// Check < operator
BOOST_CHECK_LT(*number1, *number2);
BOOST_CHECK_LT(*number1, *number1__);
BOOST_CHECK(not(*number1 < *number1_));
BOOST_CHECK(not(*number2 < *number1));
// Check hash correctness
BOOST_CHECK_EQUAL(number1->hash(), number1__->hash()); // non-ints should have the same hash
BOOST_CHECK_EQUAL(number2->hash(), number2_->hash());
BOOST_CHECK_NE(number2->hash(), number2__->hash());
}
BOOST_AUTO_TEST_CASE(variable_test)
{
auto variable1 = V("variable1");
auto variable1_ = V("variable1");
auto variable2 = V("variable2");
INFO("variable1=" << *variable1);
INFO("variable1_=" << *variable1_);
INFO("variable2=" << *variable2);
// Check name method
BOOST_CHECK_EQUAL(variable1->name(), "variable1");
// Check == operator
BOOST_CHECK_EQUAL(*variable1, *variable1_);
BOOST_CHECK_NE(*variable1, *variable2);
// Check < operator
BOOST_CHECK_LT(*variable1, *variable2);
BOOST_CHECK(not(*variable2 < *variable1));
// Check hash correctness
BOOST_CHECK_EQUAL(variable1->hash(), variable1_->hash());
BOOST_CHECK_NE(variable1->hash(), variable2->hash());
}
BOOST_AUTO_TEST_CASE(term_v_test)
{
auto term_v_1 = L(A("a"), N(1), V("X"));
auto term_v_1_ = L(A("a"), N(1), V("X"));
auto term_v_2 = L(A("a"), N(1), A("x"));
auto term_v_3 = L(A("a"), A("x"));
auto term_v_4 = L(A("a"), A("y"));
auto term_v_5 = L(A("y"), A("a"));
INFO("term_v_1=" << *term_v_1);
INFO("term_v_1_=" << *term_v_1_);
INFO("term_v_2=" << *term_v_2);
INFO("term_v_3=" << *term_v_3);
INFO("term_v_4=" << *term_v_4);
INFO("term_v_5=" << *term_v_5);
// Check size, at, and iterators
BOOST_CHECK_EQUAL(term_v_1->size(), 3);
BOOST_CHECK_EQUAL(*term_v_1->at(0), *A("a"));
BOOST_CHECK_EQUAL(*term_v_1->at(1), *N(1));
BOOST_CHECK_EQUAL(*term_v_1->at(2), *V("X"));
auto it = term_v_1->begin();
BOOST_CHECK_EQUAL(**it, *A("a"));
it += 3;
BOOST_CHECK(it == term_v_1->end());
// Check is_ground
//BOOST_CHECK(term_v_2->is_ground());
//BOOST_CHECK(not ẗerm_v_1->is_ground());
// Check == operator
BOOST_CHECK_EQUAL(*term_v_1, *term_v_1_);
BOOST_CHECK_NE(*term_v_1, *term_v_2);
BOOST_CHECK_NE(*term_v_1, *term_v_3);
BOOST_CHECK_NE(*term_v_4, *term_v_5);
// Check < operator
BOOST_CHECK_LT(*term_v_2, *term_v_1);
BOOST_CHECK_LT(*term_v_3, *term_v_2);
BOOST_CHECK_LT(*term_v_3, *term_v_4);
// Check hash correctness
BOOST_CHECK_EQUAL(term_v_1->hash(), term_v_1_->hash());
BOOST_CHECK_NE(term_v_4->hash(), term_v_5->hash());
// Check to_str
BOOST_CHECK_EQUAL(term_v_4->to_str(), "a,y");
}
BOOST_AUTO_TEST_CASE(predicate_test)
{
auto pred1 = P("p", A("x"), A("y"));
auto pred1_ = P("p", A("x"), A("y"));
auto pred2 = P("q", A("x"), A("y"));
auto pred3 = P("p", A("x"));
INFO("pred1=" << *pred1);
INFO("pred1_=" << *pred1_);
INFO("pred2=" << *pred2);
INFO("pred3=" << *pred3);
// Check == operator
BOOST_CHECK_EQUAL(*pred1, *pred1_);
BOOST_CHECK_NE(*pred1, *pred2);
BOOST_CHECK_NE(*pred1, *pred3);
// CHECK < OPERATOR
BOOST_CHECK_LT(*pred1, *pred2);
BOOST_CHECK_LT(*pred3, *pred1);
BOOST_CHECK(not(*pred1 < *pred3));
// Check to_str
BOOST_CHECK_EQUAL(pred1->to_str(), "p(x,y)");
}
BOOST_AUTO_TEST_CASE(substitution_test)
{
Substitution sigma{{"X", A("a")}, {"Y", N(0)}};
INFO("sigma=" << sigma);
auto pred1 = P("p", V("X"), V("Y"));
auto pred2 = P("q", A("a"), V("Z"), V("X"));
auto pred1_ = sigma(pred1);
auto pred2_ = sigma(pred2);
BOOST_CHECK_EQUAL(*pred1_, *P("p", A("a"), N(0)));
BOOST_CHECK_EQUAL(*pred2_, *P("q", A("a"), V("Z"), A("a")));
}
BOOST_AUTO_TEST_CASE(ordered_set_test)
{
// Further check the correctness of the < and == operators via set operations.
Predicate p1("p", {A("x"), A("y")});
Predicate p1_("p", {A("x"), A("y")});
Predicate p2("p", {V("X"), N(5)});
Predicate p3("q", {N(0)});
Predicate p4("r", {});
std::set<Predicate> o_set{p1, p1_, p2, p3, p4};
INFO("Ordered set (on creation): " << o_set);
BOOST_CHECK_EQUAL(o_set.size(), 4);
o_set.erase(p1);
INFO("Ordered set (after removing p(x,y)): " << o_set);
BOOST_CHECK_EQUAL(o_set.size(), 3);
o_set.erase(Predicate("q", {N(5e-9)}));
INFO("Ordered set (after trying to remove non-existent number): " << o_set);
BOOST_CHECK_EQUAL(o_set.size(), 3);
o_set.erase(Predicate("q", {N(5e-10)}));
INFO("Ordered set (after removing existent number): " << o_set);
BOOST_CHECK_EQUAL(o_set.size(), 2);
o_set.insert(Predicate("r", {}));
INFO("Ordered set (after trying to insert existent predicate): " << o_set);
BOOST_CHECK_EQUAL(o_set.size(), 2);
o_set.insert(Predicate("s", {N(7)}));
INFO("Ordered set (after inserting non-existent predicate): " << o_set);
BOOST_CHECK_EQUAL(o_set.size(), 3);
}
BOOST_AUTO_TEST_CASE(unordered_set_test)
{
// Further check the correctness of the hash and == methods via set operations.
Predicate p1("p", {A("x"), A("y")});
Predicate p1_("p", {A("x"), A("y")});
Predicate p2("p", {V("X"), N(5)});
Predicate p3("q", {N(0)});
Predicate p4("r", {});
std::unordered_set<Predicate, std::hash<Term>> u_set{p1, p1_, p2, p3, p4};
INFO("Unordered set (on creation): " << u_set);
BOOST_CHECK_EQUAL(u_set.size(), 4);
u_set.erase(p1);
INFO("Unordered set (after removing p(x,y)): " << u_set);
BOOST_CHECK_EQUAL(u_set.size(), 3);
u_set.erase(Predicate("q", {N(5e-9)}));
INFO("Unordered set (after trying to remove non-existent number): " << u_set);
BOOST_CHECK_EQUAL(u_set.size(), 3);
u_set.erase(Predicate("q", {N(5e-10)}));
INFO("Unordered set (after removing existent number): " << u_set);
BOOST_CHECK_EQUAL(u_set.size(), 2);
u_set.insert(Predicate("r", {}));
INFO("Unordered set (after trying to insert existent predicate): " << u_set);
BOOST_CHECK_EQUAL(u_set.size(), 2);
u_set.insert(Predicate("s", {N(7)}));
INFO("Unordered set (after inserting non-existent predicate): " << u_set);
BOOST_CHECK_EQUAL(u_set.size(), 3);
}
......@@ -87,7 +87,7 @@ bool Variable::operator<(const Term& other) const
{
if (const Variable* c_other = dynamic_cast<const Variable*>(&other))
{
return c_other->name_ < c_other->name_;
return this->name_ < c_other->name_;
}
return this->type() < other.type();
}
......@@ -100,24 +100,11 @@ std::size_t Variable::hash() const
// Implementation of TermV's methods
#define COPY_TERMV() for (auto term : term_v) term_v_.push_back(term->clone())
TermV::TermV() {}
TermV::TermV(const TermV& term_v)
{
COPY_TERMV();
}
TermV::TermV(const std::vector<Term::Ptr>& term_v)
{
COPY_TERMV();
}
TermV::TermV(const std::vector<Term::Ptr>& term_v) : term_v_(term_v) {}
TermV::TermV(const std::initializer_list<Term::Ptr>& term_v)
{
COPY_TERMV();
}
TermV::TermV(const std::initializer_list<Term::Ptr>& il) : term_v_(il) {}
std::string TermV::to_str() const
{
......@@ -141,12 +128,6 @@ bool TermV::is_ground() const
return true;
}
TermV& TermV::operator=(const TermV& term_v)
{
COPY_TERMV();
return *this;
}
bool TermV::operator==(const Term& other) const
{
if (const TermV* c_other = dynamic_cast<const TermV*>(&other))
......@@ -231,6 +212,22 @@ std::size_t Predicate::hash() const
Substitution::Substitution(std::initializer_list<Entry> il) : map_(il) {}
std::string Substitution::to_str() const
{
std::ostringstream oss;
oss << "{";
bool first = true;
for (auto entry : map_)
{
if (not first) oss << ',';
oss << "\n " << entry.first << " -> " << *entry.second;
first = false;
}
if (not first) oss << '\n';
oss << '}';
return oss.str();
}
Term::Ptr Substitution::get(const std::string& key, Term::Ptr default_) const
{
auto it = map_.find(key);
......@@ -264,12 +261,57 @@ Predicate::Ptr Substitution::operator()(Predicate::Ptr input) const
return Predicate::Ptr(new Predicate(input->name(), arguments));
}
// Implementation of State's methods
// Implementation of StateQuery's method
//void StateQuery::reset(const Substitution& sigma_0)
//{
//sigma_0_ = sigma_0;
//cur_predicate_ = predicate_range_.first;
//if (child != nullptr) child->reset(sigma_0);
//}
//StateQuery::StateQuery(const State& state, const Predicate* query,
//int query_size, const Substitution& sigma_0) :
//state_(state), query_(query), query_size_(query_size), child_(nullptr),
//done_(false), sigma_0_(sigma_0)
//{
//if (query_ != nullptr and query_size_ != 0)
//{
//child_ = StateQuery(state_, query_+sizeof(Predicate), query_size_-1,
//sigma_0_);
//}
//predicate_range_ = state_.predicates(query->name());
//cur_predicate_ = predicate_range_.first;
//}
//bool StateQuery::next_solution()
//{
//if (query_size_ == 0)
//{
//if (done_) return false;
//done_ = true;
//return true;
//}
//}
//const Substitution& StateQuery::get_substitution() const
//{
//return query_size_ == 0? sigma_0_ : child_->get_substitution();
//}
//StateQuery::~StateQuery()
//{
//if (child_ != nullptr) delete child_;
//}
// Implementation of stream operator
std::ostream& operator<<(std::ostream& os, const Stringifiable& strable)
{
os << strable.to_str();
return os;
return (os << strable.to_str());
}
} /* end imagine_planner namespace */
......
......@@ -6,17 +6,13 @@
#define TYPES_COMMON(Class)\
typedef std::shared_ptr<const Class> Ptr;\
virtual std::string type() const override { return #Class; }\
virtual Term::Ptr clone() const override\
{\
return Term::Ptr(new Class(*this));\
}\
#include <functional>
#include <memory>
#include <ostream>
#include <string>
#include <vector>
#include <unordered_map>
#include <map>
namespace imagine_planner
{
......@@ -46,7 +42,7 @@ class Stringifiable
* Declares the == and < operators and the hash function so they can be used
* with both unordered and ordered sets/maps. The class is abstract.
* Functionality is given by this class descendants (Atom, Number, TermV and
* Predicate).
* Predicate). Terms are immutable.
*/
class Term : public Stringifiable
{
......@@ -121,19 +117,6 @@ class Term : public Stringifiable
*/
virtual std::string type() const =0;
/**
* @brief Copies and returns a pointer to the copy of the current term.
*
* The attributes of the derived class are maintained. The method is
* automatically introduced by the TYPES_COMMON macro, and its definition
* makes use of the copy constructor of the derived class. Therefore,
* the copy constructor should be explicitly defined if the behavior of the
* default one is not satisfactory.
*
* @return Pointer to a copy of this term.
*/
virtual Term::Ptr clone() const =0;
/**
* @brief Virtual destructor.
*/
......@@ -307,11 +290,9 @@ class TermV : public Term
TermV();
TermV(const TermV& term_v);
TermV(const std::vector<Term::Ptr>& term_v);
TermV(const std::initializer_list<Term::Ptr>& term_v);
TermV(const std::initializer_list<Term::Ptr>& il);
virtual std::string to_str() const override;
......@@ -319,8 +300,6 @@ class TermV : public Term
virtual bool is_atomic() const override { return false; }
TermV& operator=(const TermV& term_v);
virtual bool operator==(const Term& other) const override;
virtual bool operator<(const Term& other) const override;
......@@ -331,9 +310,9 @@ class TermV : public Term
const Term::Ptr at(int idx) const { return term_v_.at(idx); }
CIter begin() const { return term_v_.cbegin(); }
CIter begin() const { return term_v_.begin(); }
CIter end() const { return term_v_.cend(); }
CIter end() const { return term_v_.end(); }
};
......@@ -448,7 +427,7 @@ class Substitution : public Stringifiable
{
private:
typedef std::unordered_map<std::string, Term::Ptr> Map;
typedef std::map<std::string, Term::Ptr> Map;
Map map_;
......@@ -458,6 +437,8 @@ class Substitution : public Stringifiable
Substitution(std::initializer_list<Entry> il);
virtual std::string to_str() const override;
Term::Ptr get(const std::string& key, Term::Ptr default_=nullptr) const;
void put(const std::string& key, Term::Ptr term);
......@@ -468,12 +449,90 @@ class Substitution : public Stringifiable
int size() const { return map_.size(); }
Map::const_iterator begin() const { return map_.cbegin(); }
Map::const_iterator begin() const { return map_.begin(); }
Map::const_iterator end() const { return map_.begin(); }
};
class StateQuery;
class State : public Stringifiable
{
private:
//typedef std::unordered_multiset<Predicate, std::hash<Predicate>> Set;
typedef std::multimap<std::string, Predicate> Map;
Map predicates_;
std::size_t hash_;
public:
typedef std::pair<const std::string, Predicate> Entry;
typedef Map::const_iterator CIter;
typedef std::pair<CIter, CIter> Range;
State();
State(std::initializer_list<Entry> il);
Map::const_iterator end() const { return map_.cbegin(); }
Range predicates(const std::string& type) const;
CIter begin() const { return predicates_.begin(); }
CIter end() const { return predicates_.end(); }
};
class StateQuery
{
private:
const State& state_;
const Predicate* query_;
int query_size_;
Substitution sigma_0_;
State::Range predicate_range_;
State::CIter cur_predicate_;
StateQuery* child_;
bool done_;
void reset();
protected:
StateQuery(const State& state, const Predicate* query=nullptr,
int query_size=0, const Substitution& sigma_0=Substitution{});
public:
bool next_solution();
const Substitution& get_substitution() const;
</