//! \file
//! \brief Thompson regex algorithm implementation.
//! \mainpage The Thompson algorithm
//! \brief Thompson regex algorithm implementation.
//! \author Mathieu Turcotte
//! \date November 2010
//!
//! To compile on GCC:
//!     g++ -std=gnu++0x thompson.cpp
//!
//! Should compile fine on VS2010.
//!
//! The regex grammar used. A fragment takes at most one postfix
//! operator: a second '*', '?' or '+' is read as a literal character
//! starting the next fragment.<pre>
//! \<regexp>    ::= \<branche> { '|' \<branche> }
//! \<branche>   ::= \<fragment> { \<fragment> }
//! \<fragment>  ::= \<atome> [ '*' | '?' | '+' ]
//! \<atome>     ::= '(' \<regexp> ')' | \<caractere> | \<nombre>
//! \<caractere> ::= 'a' | 'b' | ... | 'z'
//! \<nombre>    ::= '0' | '1' | ... | '9'</pre>

#include <cstddef>      // std::size_t, std::max_align_t
#include <cstdlib>      // NULL, EXIT_SUCCESS, size_t
#include <iostream>
#include <memory>       // C++0x std::shared_ptr
#include <new>          // placement new
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

//! Thompson implementation namespace.
namespace thompson {

class NFA;

//! Thompson implementation private namespace.
namespace internal {

struct State;
struct Transition;
struct Split;
struct Match;

//! \brief The base class for all NFA visitors.
//!
//! Each concrete state's visit() forwards to the matching visit_*
//! callback below, so an algorithm can walk an NFA without knowing
//! the concrete state types.
class Visitor {
public:
    //! \brief Virtual destructor so a visitor can be safely destroyed
    //! through a pointer to this base class.
    virtual ~Visitor() {}

    //! \brief Called when visiting a NFA.
    //! \param[in] start The NFA start state.
    //! \param[in] end The NFA end state.
    virtual void visit_nfa(const State* start, const State* end) = 0;
    //! \brief Called when visiting a NFA Transition state.
    //! \param[in] trans The visited transition.
    virtual void visit_transition(const Transition* trans) = 0;
    //! \brief Called when visiting a NFA Split state.
    //! \param[in] split The visited split.
    virtual void visit_split(const Split* split) = 0;
    //! \brief Called when visiting a NFA Match state.
    //! \param[in] match The visited match.
    virtual void visit_match(const Match* match) = 0;
};

//! NFA state base class.
struct State {

    //! Define explicitly three kind of state to supplement
    //! C++ polymorphic behavior. Only for debugging purpose.
    typedef enum {
        labeled, split, match
    } Kind;

    //! \brief Construct a state of the defined Kind.
    //!
    //! Explicit so that a Kind value is never silently converted
    //! into a State.
    //! \param[in] kind The state kind.
    explicit State(Kind kind) :
        kind(kind) {
    }

    //! \brief Visitor entry point.
    //!
    //! The base implementation does nothing; Transition, Split and
    //! Match override it to call their matching Visitor callback.
    //! \param visitor The visiting visitor (unused here).
    virtual void visit(Visitor* /*visitor*/) const {}

    Kind kind;  /**< The state kind. Only for debugging purpose. */
};

//! NFA labeled transition.
struct Transition: public State {

    //! \brief Construct a Transition State.
    //! \param[in] label The transition label.
    Transition(char label) :
        State(labeled), label(label) {
    }

    virtual void visit(Visitor* visitor) const {
        visitor->visit_transition(this);
    }

    State* next;    /**< A pointer to the next state. */
    char label;     /**< The state transition label. */
};

//! \brief NFA splitting state.
//!
//! An unlabeled state with two outgoing edges, first and second;
//! a visitor receives it through Visitor::visit_split().
struct Split: public State {

    //! \brief Build a split between two successor states.
    //! \param[in] lhs State reached through the first edge.
    //! \param[in] rhs State reached through the second edge.
    Split(State* lhs, State* rhs)
        : State(split)
        , first(lhs)
        , second(rhs)
    {}

    //! \brief Dispatch to Visitor::visit_split().
    //! \param visitor The visiting visitor.
    virtual void visit(Visitor* visitor) const
    {
        visitor->visit_split(this);
    }

    State* first;   /**< Successor reached through the first edge. */
    State* second;  /**< Successor reached through the second edge. */
};

//! \brief NFA accepting state.
//!
//! Terminal state of an NFA: it has no outgoing edge and is
//! reported to visitors through Visitor::visit_match().
struct Match: public State {
    //! \brief Dispatch to Visitor::visit_match().
    //! \param visitor The visiting visitor.
    virtual void visit(Visitor* visitor) const { visitor->visit_match(this); }

    //! \brief Build the accepting state.
    Match() : State(match) {}
};

//! \brief NFA states "mark and release" memory pool.
//!
//! This "mark and release" memory pool is used to
//! allocate and construct the NFA states when parsing
//! a regular expression. All states are released at
//! once when the pool is destructed. This is loosely
//! inspired by the GNU libc obstacks.
//!
//! This is a simple solution to simple needs.
class StatePool {

    //! \brief StatePool memory representation.
    //!
    //! The data array is aligned on alignof(std::max_align_t) and
    //! every reservation is rounded up to a multiple of that value,
    //! so each pointer handed out by reserve() is suitably aligned
    //! for any object, including the polymorphic NFA states.
    template <std::size_t chunk_size>
    struct Chunk {
    public:
        // Guarantees that a rounded-up request never exceeds
        // chunk_size when the raw request does not.
        static_assert(chunk_size % alignof(std::max_align_t) == 0,
                      "chunk_size must be a multiple of alignof(std::max_align_t)");

        //! \brief Default constructor.
        //! \post The next pointer is NULL
        //! \post The data_index is set to 0.
        Chunk(): next(NULL), data_index(0) {}

        //! \brief Reserve N bytes in this chunk.
        //! \param nbytes Number of bytes to allocate/reserve.
        //! \pre The requested number of bytes (nbytes) must be
        //!      smaller or equal (<=) to the chunk_size.
        //! \throw std::logic_error if nbytes > chunk_size.
        //! \return A maximally aligned pointer to the reserved memory,
        //!         or NULL if this chunk has not enough room left.
        void* reserve(std::size_t nbytes) {

            // Make sure that a Chunk can really
            // serve a request of this size.
            if (nbytes > chunk_size)
                throw std::logic_error("StatePool::Chunk::reserve(): the "
                                       "requested number of bytes is bigger "
                                       "than a chunk_size.");

            // Round the request up so the following reservation
            // also starts on an aligned boundary.
            const std::size_t alignment = alignof(std::max_align_t);
            const std::size_t rounded =
                (nbytes + alignment - 1) / alignment * alignment;

            // Make sure that we have enough room to serve
            // the request. If not return NULL. The pool
            // will handle this by allocating a new chunk.
            if (chunk_size - data_index >= rounded) {
                void* dptr = data + data_index;
                data_index += rounded;
                return dptr;
            }

            return NULL;
        }

        Chunk* next;             /**< A pointer to the next chunk. */
        std::size_t data_index;  /**< Current position in the data array. */
        alignas(std::max_align_t)
        char data[chunk_size];   /**< Allocation array containing chunk_size bytes. */
    };

    typedef Chunk<64> MemoryChunk;

public:

    //! \brief Default constructor; allocates the first chunk.
    StatePool() : root_chunk(new MemoryChunk), last_chunk(root_chunk) {}

    //! \brief Destructor.
    //!
    //! Free all memory associated to the states allocated from
    //! this pool. Beware that it won't call the destructor on
    //! those objects.
    ~StatePool() {
        while (root_chunk) {
            MemoryChunk* temp_chunk = root_chunk;
            root_chunk = root_chunk->next;
            delete temp_chunk;
        }
    }

    // The pool owns its chunks: copying it would free them twice.
    StatePool(const StatePool&) = delete;
    StatePool& operator=(const StatePool&) = delete;

    //! \brief Construct a split state from the pool.
    //! \param[in] first The first neighbor State.
    //! \param[in] second The second neighbor State.
    //! \return A pointer to the newly constructed split state.
    Split* make_split(State* first, State* second) {
        return new (reserve(sizeof(Split))) Split(first, second);
    }

    //! \brief Construct a transition state from the pool.
    //! \param label The transition label.
    //! \return A pointer to the newly constructed transition state.
    Transition* make_transition(char label) {
        return new (reserve(sizeof(Transition))) Transition(label);
    }

    //! \brief Construct a match state from the pool.
    //! \return A pointer to the newly constructed match state.
    Match* make_match() {
        return new (reserve(sizeof(Match))) Match();
    }

protected:

    //! \brief Reserve N bytes from this pool.
    //! \param nbytes Number of bytes to allocate/reserve.
    //! \throw std::logic_error if nbytes exceeds the chunk size.
    //! \return A maximally aligned pointer to the beginning of the
    //!         reserved memory space.
    void* reserve(std::size_t nbytes) {
        // Try to reserve N bytes from the last
        // allocated memory chunk.
        void* data = last_chunk->reserve(nbytes);

        // Check if it's a miss. If so, we'll
        // need to allocate a new chunk.
        if (!data) {
            last_chunk->next = new MemoryChunk;
            last_chunk = last_chunk->next;
            data = last_chunk->reserve(nbytes);
        }

        return data;
    }

    MemoryChunk* root_chunk; /**< First allocated chunk. */
    MemoryChunk* last_chunk; /**< Last allocated chunk. */
};

} // namespace internal

//! \brief Nondeterministic Finite Automata representation.
//!
//! <pre>
//!                                    .---.
//!                            .------>| b |
//!                            |       '---'
//!                        .-------.     |
//!                 .----->| split |<----'
//!                 |      '-------'
//!             .-------.      |       .---.     .-------.
//!  start ---->| split |      '------>| a |---->| match |
//!             '-------'              '---'     '-------'
//!                 |        .---.                   ^
//!                 '------->| c |-------------------'
//!                          '---'
//! </pre>
//!
//! Only the start and accepting states are stored here. Every state
//! lives in the StatePool this object shares ownership of, so the raw
//! state pointers stay valid for the lifetime of the NFA. The
//! constructor is protected; RegexParser is a friend and builds them.
class NFA {
protected:

    friend class RegexParser;

    typedef internal::Visitor Visitor;
    typedef internal::StatePool StatePool;
    typedef internal::State State;

    //! \brief Build an NFA from already linked states.
    //! \param start_state The state the simulation starts from.
    //! \param accept_state The accepting state.
    //! \param state_pool The pool owning every state of the automaton.
    NFA(State* start_state, State* accept_state,
        std::shared_ptr<StatePool> state_pool)
        : start(start_state)
        , end(accept_state)
        , pool(state_pool)
    {}

public:

    //! \brief Hand the NFA to a visitor.
    //!
    //! Makes a single visit_nfa(start, end) call on the visitor.
    //! \param visitor The visiting visitor.
    void visit(Visitor* visitor) const
    {
        visitor->visit_nfa(start, end);
    }

protected:

    State* start;                    /**< Where the simulation starts. */
    State* end;                      /**< The accepting state. */
    std::shared_ptr<StatePool> pool; /**< Keeps every state alive. */
};

//! \brief Exception raised when a regular expression is malformed.
//!
//! On top of the what() message inherited from std::runtime_error,
//! it records the parser position at which the problem was found.
class ParseError: public std::runtime_error {
public:
    //! \brief Build a parse error.
    //! \param message Human readable description of the problem.
    //! \param position The parser position when the error was raised.
    ParseError(const std::string& message, int position)
        : std::runtime_error(message)
        , pos(position)
    {}

    //! \brief Position at which parsing failed.
    //! \return The position given at construction.
    int where() const { return pos; }

protected:
    int pos; /**< The parser position when the error was raised. */
};

//! \brief A regexp parser.
class RegexParser {

    typedef internal::State State;
    typedef internal::Split Split;
    typedef internal::Match Match;
    typedef internal::Transition Transition;
    typedef internal::StatePool StatePool;

    //! \brief A class representing a partially constructed NFA.
    //!
    //! <pre>
    //!                                    .---.
    //!                            .------>| b |-----> out edge
    //!                            |       '---'
    //!                        .-------.
    //!                 .----->| split |
    //!                 |      '-------'
    //!             .-------.      |       .---.
    //!  start ---->| split |      '------>| a |-----> out edge
    //!             '-------'              '---'
    //!                 |        .---.
    //!                 '------->| c |---------------> out edge
    //!                          '---'
    //! </pre>
    class Fragment {
    public:
        //! \brief Construct an empty fragment.
        Fragment(std::shared_ptr<StatePool> pool) :
            start(NULL), pool(pool) {
        }

        //! \brief Construct a fragment consisting of one labeled transition.
        Fragment(char label, std::shared_ptr<StatePool> pool) :
            pool(pool) {
            Transition* state = pool->make_transition(label);
            out_edges.push_back(&state->next);
            start = state;
        }

        //! \brief Concatenate two NFA fragments.
        //! \param[in] frag The fragment to be concatenate.
        //!
        //! <pre>
        //!  ab  .-------.    .-------.
        //!  --->| frag1 |--->| frag2 |--->
        //!      '-------'    '-------'
        //! </pre>
        void concat(const Fragment& frag) {
            // Set all outgoing edges of the current fragment to
            // point to the concatenated fragment start state.
            for (size_t i = 0; i < out_edges.size(); i++) {
                *out_edges[i] = frag.start;
            }
            // Our outgoing edges are now the one of the
            // concatenated fragment.
            out_edges = frag.out_edges;
        }

        //! \brief Perform the union of two NFA fragments.
        //! \param[in] frag The fragment to be united.
        //!
        //! <pre>
        //!                   .-------.
        //!  a|b     .------->| frag2 |--->
        //!          |        '-------'
        //!      .-------.
        //!  --->| split |
        //!      '-------'
        //!          |        .-------.
        //!          '------->| frag1 |--->
        //!                   '-------'
        //! </pre>
        void split(const Fragment& frag) {
            // Our new start state is a split.
            start = pool->make_split(start, frag.start);

            // Append to our outgoing edges the one from frag.
            for (size_t i = 0; i < frag.out_edges.size(); i++) {
                out_edges.push_back(frag.out_edges[i]);
            }
        }

        //! \brief Perform a closure on this NFA fragment.
        //!
        //! <pre>
        //!                   .-------.
        //!  a*      .------->| frag  |
        //!          |        '-------'
        //!      .-------.        |
        //!  --->| split |<-------'
        //!      '-------'
        //!          |
        //!          '-------------------->
        //! </pre>
        void closure() {
            // Our new start state is a split.
            Split* split = pool->make_split(start, NULL);

            // All outgoing edges point back to the new start.
            for (size_t i = 0; i < out_edges.size(); i++) {
                *out_edges[i] = split;
            }
            start = split;
            out_edges.clear();
            out_edges.push_back(&split->second);
        }

        //! \brief Apply the + operator on this NFA fragment.
        //!
        //! <pre>
        //!      .-------.
        //!  --->| frag  |<-------.
        //!      '-------'        |
        //!          |            |
        //!  a+      |        .-------.
        //!          '------->| split |------->
        //!                   '-------'
        //! </pre>
        void one_or_more() {
            Split* split = pool->make_split(start, NULL);
            // All outgoing edges point to the new start.
            for (size_t i = 0; i < out_edges.size(); i++) {
                *out_edges[i] = split;
            }
            out_edges.clear();
            out_edges.push_back(&split->second);
        }

        //! \brief Apply the ? operator on this NFA fragment.
        //!
        //! <pre>
        //!                   .-------.
        //!  a?      .------->| frag  |--->
        //!          |        '-------'
        //!      .-------.
        //!  --->| split |
        //!      '-------'
        //!          |
        //!          '-------------------->
        //! </pre>
        void zero_or_one() {
            Split* split = pool->make_split(start, NULL);
            out_edges.push_back(&split->second);
            start = split;
        }

        //! \brief Convert this fragment to a complete NFA by appending
        //! to it an accepting state.
        std::shared_ptr<NFA> to_nfa() {
            Match* match = pool->make_match();
            // All outgoing edges point to the accepting state.
            for (size_t i = 0; i < out_edges.size(); i++) {
                *out_edges[i] = match;
            }
            return std::shared_ptr<NFA>(new NFA(start, match, pool));
        }

    protected:
        State* start;                    /**< The fragment start state. */
        std::vector<State**> out_edges;  /**< List of pointers to the unbouded
                                              outgoing edges of the fragment. */
        std::shared_ptr<StatePool> pool; /**< Pool from which this fragment
                                              allocates NFA states. */
    };

public:

    //! \brief Build the NFA for a regular expression.
    //! \param regexp The regular expression string.
    //! \return A heap allocated NFA.
    std::shared_ptr<NFA> parse(const std::string& regexp) {
        pool = std::make_shared<StatePool>();
        this->regexp = regexp;  // Set parser input.
        length = regexp.size(); // Set input size.
        pos = 0;    // Set current position in input.
        advance();  // Put the first symbol in curtok.
        Fragment fragment = parse_regexp(); // Parse the regexp,
        return fragment.to_nfa();           // converts it to an NFA.
    }

protected:

    //! \brief Advance by one curtok.
    //! \post curtok contains the next regexp symbol or the null
    //! char if the input regexp is exhausted.
    void advance() {
        if (pos < length)
            curtok = regexp[pos++];
        else
            curtok = '\0';
    }

    //! \brief Advance curtok by one, checking curtok current value.
    //! \param c The character to match curtok against.
    //! \post curtok contains the next regexp symbol or the null
    //! char if the input regexp is exhausted.
    void advance(char c) {
        if (curtok != c) {
            std::stringstream error_msg;
            error_msg << "expected " << c <<" but got ";
            if (curtok == '\0') { error_msg << "EOL"; }
            else { error_msg << curtok; }
            throw ParseError(error_msg.str(), pos);
        }
        advance();
    }

    //! \brief Parse a regexp.
    //! \return A NFA fragment.
    //!
    //! \<regexp>    ::= \<branche> { '|' \<branche> }
    Fragment parse_regexp() {
        Fragment frag = parse_branch();
        while (curtok == '|') {
            advance('|');
            frag.split(parse_branch());
        }
        return frag;
    }

    //! \brief Parse a regexp branch.
    //! \return A NFA fragment.
    //!
    //! \<branche>   ::= \<particule> { \<particule> }
    Fragment parse_branch() {
        Fragment frag = parse_particule();
        // Particule are delimited either by:
        // - the union symbol;
        // - a close parenthesis;
        // - a null chararter.
        while (curtok != '|' && curtok != ')' && curtok != '\0') {
            frag.concat(parse_particule());
        }
        return frag;
    }

    //! \brief Parse a regexp particule.
    //! \return A NFA fragment.
    //!
    //! \<particule> ::= \<atome> { '*' | '?' | '+' }
    Fragment parse_particule() {
        Fragment frag = parse_atome();
        switch (curtok) {
        case '*':
            frag.closure();
            advance('*');
            break;
        case '+':
            frag.one_or_more();
            advance('+');
            break;
        case '?':
            frag.zero_or_one();
            advance('?');
        }
        return frag;
    }

    //! \brief Parse a regexp atome.
    //! \return A NFA fragment.
    //!
    //! \<atome>     ::= '(' \<regexp> ')' | \<char> | \<number>
    //! \<char>      ::= 'a' | 'b' | ... | 'z'
    //! \<number>    ::= '0' | '1' | ... | '9'
    Fragment parse_atome() {
        Fragment frag(pool);
        if (curtok == '(') {
            advance('(');
            frag = parse_regexp();
            advance(')');
        } else {
            frag = Fragment(curtok, pool);
            advance();
        }
        return frag;
    }

    std::shared_ptr<StatePool> pool;  /**< State memory pool. */
    std::string regexp; /**< The input regular expression. */
    int length;         /**< length of the input regular expression. */
    int pos;            /**< Current position in the input regexp. */
    char curtok;        /**< The current token. */
};

//! \brief Visitor encapsulating the Thompson algorithm.
//!
//! This visitor is implementing the Thompson NFA simulation.
class ThompsonAlgorithm: public internal::Visitor {

    typedef internal::State State;
    typedef internal::Split Split;
    typedef internal::Transition Transition;
    typedef internal::Match Match;

public:

    //! \brief Simulate the NFA over the input string.
    //! \param str The string to match.
    //! \param nfa The NFA to execute.
    bool accept(const std::string& str, const std::shared_ptr<NFA> nfa) {

        input = str;         // Set the input string.
        pos = 0;             // Set current position in the input.
        length = str.size(); // Set the input length.
        curtok = 'a';        // Just to be sure that it's not a null char.

        nfa->visit(this);

        // In order to accept the input
        // - the final set must contain the accepting state (end) or
        //   matched is true, which means that the matching state was
        //   visited during the last iteration
        // - and the input must have all been consumed.
        return (matched || current.find(end) != current.end()) && curtok == '\0';
    }

    void visit_nfa(const State* start, const State* end) {
        this->start = start;
        this->end = end;

        current.clear();        // Prime the "current" set
        current.insert(start);  // for the first iteration.

        // Iterate over every input characters or until there is
        // no more accessible state in the "current" set.
        while (curtok != '\0' && current.size() != 0) {
            matched = false;
            advance();

            // We visit every accessible state.
            for (std::set<const State*>::const_iterator i = current.begin();
                   i != current.end(); ++i) {

                (*i)->visit(this);
            }

            current.swap(next); // Swap the "current" set with the
            next.clear();       // "next" set and then clear the
                                // "next" set for the next iteration.
        }
    }

    void visit_transition(const Transition* trans) {
        // If we can follow the transition based
        // on the current token, insert the next
        // state in the "next" state.
        if (curtok == trans->label)
            next.insert(trans->next);
    }

    void visit_split(const Split* split) {
        // It's a split, visit both branches.
        split->first->visit(this);
        split->second->visit(this);
    }

    void visit_match(const Match* match) {
        matched = true;
    }

protected:

    //! \brief Advance by one curtok.
    //! \post curtok contains the next regexp symbol or the null
    //! char if the input is exhausted.
    void advance() {
        if (pos < length) {
            curtok = input[pos++];
        } else {
            curtok = '\0';
        }
    }

    bool matched;       /**< Have we visited the accepting
                             state on the last iteration. */
    char curtok;        /**< The char we're looking at in the input string. */
    int pos;            /**< The current position in the input string. */
    int length;         /**< The input string length. */
    std::string input;  /**< The string to match. */
    const State* start; /**< Pointer to the start state. */
    const State* end;   /**< Pointer to the accepting state. */
    std::set<const State*> current; /**< The set of states being visited. */
    std::set<const State*> next;    /**< The set of states to be visited. */
};

} // namespace thompson

//! \brief Test a regular expression against some input string.
//! \param regex The regular expression.
//! \param input The input string.
//! \param expect The expected result of matching the
//!        input string against the regular expression.
bool regex_test(const std::string& regex,
        const std::string& input, bool expect) {

    thompson::RegexParser parser;
    thompson::ThompsonAlgorithm thompson_algorithm;

    std::shared_ptr<thompson::NFA> nfa = parser.parse(regex);
    bool result = thompson_algorithm.accept(input, nfa);

    if (result == expect) {
        std::cout << "[OK] " << regex << (expect ? " accept " : " reject ")
                << input << " as expected." << std::endl;
    } else {
        std::cout << "[ERR] " << regex << " erroneously "
                << (expect ? "reject " : "accept ") << input << "." << std::endl;
    }

    return result == expect;
}

//! \brief Run the regex test suite.
//! \return EXIT_SUCCESS if every test passed, EXIT_FAILURE otherwise.
int main() {

    int success = 0;
    int error = 0;

#define TEST(regex, input, expect) regex_test(regex, input, expect) \
                                        ? success++: error++;

    // Union.
    TEST("a(b|c)d", "abd", true);
    TEST("a(b|c)d", "acd", true);
    TEST("a(b|c)d", "e", false);
    TEST("a(b|c)d", "ace", false);
    TEST("a(b|c)d", "abbd", false);
    TEST("a(b|c)d", "abb", false);
    TEST("a(b|c)d", "add", false);

    // One or many.
    TEST("a+b", "ab", true);
    TEST("a+b", "aab", true);
    TEST("a+b", "aaab", true);
    TEST("a+b", "b", false);

    // Zero or many.
    TEST("a*", "", true);
    TEST("a*", "a", true);
    TEST("a*", "aa", true);
    TEST("a*", "ab", false);

    // Zero or one.
    TEST("a?", "", true);
    TEST("a?", "a", true);
    TEST("a?", "aa", false);
    TEST("a?", "ab", false);

    TEST("a(b|c)*d", "abbbbd", true);
    TEST("a(b|c)*d", "accccd", true);
    TEST("a(b|c)*d", "acbbcd", true);
    TEST("a(b|c)*d", "ad", true);
    TEST("a(b|c)*d", "aed", false);

    TEST("a|b|c", "a", true);
    TEST("a|b|c", "d", false);

    TEST("abc*", "ab", true);
    TEST("abc*", "abc", true);
    TEST("abc*", "abcc", true);

    TEST("ab(c*|a?)", "ab", true);
    TEST("ab(c*|a?)", "abc", true);
    TEST("ab(c*|a?)", "abcc", true);
    TEST("ab(c*|a?)", "aba", true);
    TEST("ab(c*|a?)", "abaa", false);

    TEST("a?a?a?a?a?a?a?a?a?a?a?a?a?a?a?", "aaaa", true);

    TEST("(a|aa)*b", "b", true);
    TEST("(a|aa)*b", "aaaaab", true);
    TEST("(a|aa)*b", "aa", false);

    TEST("(a|aa)**b", "*b", true);

    TEST("(a|(b|(c|d)))", "a", true);
    TEST("(a|(b|(c|d)))", "b", true);
    TEST("(a|(b|(c|d)))", "c", true);
    TEST("(a|(b|(c|d)))", "d", true);

    TEST("a|(b|(c|d)*)", "", true);
    TEST("a|(b|(c|d)*)", "cd", true);

    TEST("(a|(b|(c|d)+))", "c", true);
    TEST("(a|(b|(c|d)+))", "cc", true);
    TEST("(a|(b|(c|d)+))", "cd", true);
    TEST("(a|(b|(c|d)+))", "dd", true);

    std::cout << error << " error over " << success + error << " tests." << std::endl;

#undef TEST

    // Report failures through the exit status so scripts and CI can
    // detect them, not only through the printed summary.
    return error == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}