001    package aima.logic.fol.inference;
002    
003    import java.util.ArrayList;
004    import java.util.HashMap;
005    import java.util.LinkedHashMap;
006    import java.util.LinkedHashSet;
007    import java.util.List;
008    import java.util.Map;
009    import java.util.Set;
010    
011    import aima.logic.fol.Connectors;
012    import aima.logic.fol.StandardizeApart;
013    import aima.logic.fol.StandardizeApartInPlace;
014    import aima.logic.fol.StandardizeApartIndexical;
015    import aima.logic.fol.StandardizeApartIndexicalFactory;
016    import aima.logic.fol.SubstVisitor;
017    import aima.logic.fol.SubsumptionElimination;
018    import aima.logic.fol.Unifier;
019    import aima.logic.fol.inference.proof.Proof;
020    import aima.logic.fol.inference.proof.ProofFinal;
021    import aima.logic.fol.inference.proof.ProofStepChainCancellation;
022    import aima.logic.fol.inference.proof.ProofStepChainDropped;
023    import aima.logic.fol.inference.proof.ProofStepChainFromClause;
024    import aima.logic.fol.inference.proof.ProofStepChainReduction;
025    import aima.logic.fol.inference.proof.ProofStepGoal;
026    import aima.logic.fol.inference.trace.FOLModelEliminationTracer;
027    import aima.logic.fol.kb.FOLKnowledgeBase;
028    import aima.logic.fol.kb.data.Chain;
029    import aima.logic.fol.kb.data.Clause;
030    import aima.logic.fol.kb.data.Literal;
031    import aima.logic.fol.kb.data.ReducedLiteral;
032    import aima.logic.fol.parsing.ast.AtomicSentence;
033    import aima.logic.fol.parsing.ast.ConnectedSentence;
034    import aima.logic.fol.parsing.ast.NotSentence;
035    import aima.logic.fol.parsing.ast.Sentence;
036    import aima.logic.fol.parsing.ast.Term;
037    import aima.logic.fol.parsing.ast.Variable;
038    
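/**
 * First-order logic inference procedure based on model elimination, searching
 * from a set of support (the negated query) with iterative-deepening
 * depth-limited search.
 * <p>
 * Based on lecture notes from:
 * http://logic.stanford.edu/classes/cs157/2008/lectures/lecture13.pdf
 * 
 * @author Ciaran O'Reilly
 * 
 */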
049    public class FOLModelElimination implements InferenceProcedure {
050    
051            // Ten seconds is default maximum query time permitted
052            private long maxQueryTime = 10 * 1000;
053            //
054            private FOLModelEliminationTracer tracer = null;
055            //
056            private Unifier unifier = new Unifier();
057            private SubstVisitor substVisitor = new SubstVisitor();
058    
059            public FOLModelElimination() {
060    
061            }
062    
063            public FOLModelElimination(long maxQueryTime) {
064                    setMaxQueryTime(maxQueryTime);
065            }
066    
067            public FOLModelElimination(FOLModelEliminationTracer tracer) {
068                    this.tracer = tracer;
069            }
070    
071            public FOLModelElimination(FOLModelEliminationTracer tracer,
072                            long maxQueryTime) {
073                    this.tracer = tracer;
074                    setMaxQueryTime(maxQueryTime);
075            }
076    
077            public long getMaxQueryTime() {
078                    return maxQueryTime;
079            }
080    
081            public void setMaxQueryTime(long maxQueryTime) {
082                    this.maxQueryTime = maxQueryTime;
083            }
084    
085            //
086            // START-InferenceProcedure
087    
088            public InferenceResult ask(FOLKnowledgeBase kb, Sentence aQuery) {
089                    //
090                    // Get the background knowledge - are assuming this is satisfiable
091                    // as using Set of Support strategy.
092                    Set<Clause> bgClauses = new LinkedHashSet<Clause>(kb.getAllClauses());
093                    bgClauses.removeAll(SubsumptionElimination.findSubsumedClauses(bgClauses));
094                    List<Chain> background = createChainsFromClauses(bgClauses);
095    
096                    // Collect the information necessary for constructing
097                    // an answer (supports use of answer literals).
098                    AnswerHandler ansHandler = new AnswerHandler(kb, aQuery, maxQueryTime);
099    
100                    IndexedFarParents ifps = new IndexedFarParents(ansHandler
101                                    .getSetOfSupport(), background);
102    
103                    // Iterative deepening to be used
104                    for (int maxDepth = 1; maxDepth < Integer.MAX_VALUE; maxDepth++) {
105                            // Track the depth actually reached
106                            ansHandler.resetMaxDepthReached();
107    
108                            if (null != tracer) {
109                                    tracer.reset();
110                            }
111    
112                            for (Chain nearParent : ansHandler.getSetOfSupport()) {
113                                    recursiveDLS(maxDepth, 0, nearParent, ifps, ansHandler);
114                                    if (ansHandler.isComplete()) {
115                                            return ansHandler;
116                                    }
117                            }
118                            // This means the search tree
119                            // has bottomed out (i.e. finite).
120                            // Return what I know based on exploring everything.
121                            if (ansHandler.getMaxDepthReached() < maxDepth) {
122                                    return ansHandler;
123                            }
124                    }
125    
126                    return ansHandler;
127            }
128    
129            // END-InferenceProcedure
130            //
131    
132            //
133            // PRIVATE METHODS
134            //
135            private List<Chain> createChainsFromClauses(Set<Clause> clauses) {
136                    List<Chain> chains = new ArrayList<Chain>();
137    
138                    for (Clause c : clauses) {
139                            Chain chn = new Chain(c.getLiterals());
140                            chn.setProofStep(new ProofStepChainFromClause(chn, c));
141                            chains.add(chn);
142                            chains.addAll(chn.getContrapositives());
143                    }
144    
145                    return chains;
146            }
147    
148            // Recursive Depth Limited Search
149            private void recursiveDLS(int maxDepth, int currentDepth, Chain nearParent,
150                            IndexedFarParents indexedFarParents, AnswerHandler ansHandler) {
151    
152                    // Keep track of the maximum depth reached.
153                    ansHandler.updateMaxDepthReached(currentDepth);
154    
155                    if (currentDepth == maxDepth) {
156                            return;
157                    }
158    
159                    int noCandidateFarParents = indexedFarParents
160                                    .getNumberCandidateFarParents(nearParent);
161                    if (null != tracer) {
162                            tracer.increment(currentDepth, noCandidateFarParents);
163                    }
164                    indexedFarParents.standardizeApart(nearParent);
165                    for (int farParentIdx = 0; farParentIdx < noCandidateFarParents; farParentIdx++) {
166                            // If have a complete answer, don't keep
167                            // checking candidate far parents
168                            if (ansHandler.isComplete()) {
169                                    break;
170                            }
171    
172                            // Reduction
173                            Chain nextNearParent = indexedFarParents.attemptReduction(
174                                            nearParent, farParentIdx);
175    
176                            if (null == nextNearParent) {
177                                    // Unable to remove the head via reduction
178                                    continue;
179                            }
180    
181                            // Handle Canceling and Dropping
182                            boolean cancelled = false;
183                            boolean dropped = false;
184                            do {
185                                    cancelled = false;
186                                    Chain nextParent = null;
187                                    while (nextNearParent != (nextParent = tryCancellation(nextNearParent))) {
188                                            nextNearParent = nextParent;
189                                            cancelled = true;
190                                    }
191    
192                                    dropped = false;
193                                    while (nextNearParent != (nextParent = tryDropping(nextNearParent))) {
194                                            nextNearParent = nextParent;
195                                            dropped = true;
196                                    }
197                            } while (dropped || cancelled);
198    
199                            // Check if have answer before
200                            // going to the next level
201                            if (!ansHandler.isAnswer(nextNearParent)) {
202                                    // Keep track of the current # of
203                                    // far parents that are possible for the next near parent.
204                                    int noNextFarParents = indexedFarParents
205                                                    .getNumberFarParents(nextNearParent);
206                                    // Add to indexed far parents
207                                    nextNearParent = indexedFarParents.addToIndex(nextNearParent);
208    
209                                    // Check the next level
210                                    recursiveDLS(maxDepth, currentDepth + 1, nextNearParent,
211                                                    indexedFarParents, ansHandler);
212    
213                                    // Reset the number of far parents possible
214                                    // when recursing back up.
215                                    indexedFarParents.resetNumberFarParentsTo(nextNearParent,
216                                                    noNextFarParents);
217                            }
218                    }
219            }
220    
221            // Returns c if no cancellation occurred
222            private Chain tryCancellation(Chain c) {
223                    Literal head = c.getHead();
224                    if (null != head && !(head instanceof ReducedLiteral)) {
225                            for (Literal l : c.getTail()) {
226                                    if (l instanceof ReducedLiteral) {
227                                            // if they can be resolved
228                                            if (head.isNegativeLiteral() != l.isNegativeLiteral()) {
229                                                    Map<Variable, Term> subst = unifier.unify(head
230                                                                    .getAtomicSentence(), l.getAtomicSentence());
231                                                    if (null != subst) {
232                                                            // I have a cancellation
233                                                            // Need to apply subst to all of the
234                                                            // literals in the cancellation
235                                                            List<Literal> cancLits = new ArrayList<Literal>();
236                                                            for (Literal lfc : c.getTail()) {
237                                                                    AtomicSentence a = (AtomicSentence) substVisitor
238                                                                                    .subst(subst, lfc.getAtomicSentence());
239                                                                    cancLits.add(lfc.newInstance(a));
240                                                            }
241                                                            Chain cancellation = new Chain(cancLits);
242                                                            cancellation
243                                                                            .setProofStep(new ProofStepChainCancellation(
244                                                                                            cancellation, c, subst));
245                                                            return cancellation;
246                                                    }
247                                            }
248                                    }
249                            }
250                    }
251                    return c;
252            }
253    
254            // Returns c if no dropping occurred
255            private Chain tryDropping(Chain c) {
256                    Literal head = c.getHead();
257                    if (null != head && (head instanceof ReducedLiteral)) {
258                            Chain dropped = new Chain(c.getTail());
259                            dropped.setProofStep(new ProofStepChainDropped(dropped, c));
260                            return dropped;
261                    }
262    
263                    return c;
264            }
265    
266            class AnswerHandler implements InferenceResult {
267                    private Chain answerChain = new Chain();
268                    private Set<Variable> answerLiteralVariables;
269                    private List<Chain> sos = null;
270                    private boolean complete = false;
271                    private long finishTime = 0L;
272                    private int maxDepthReached = 0;
273                    private List<Proof> proofs = new ArrayList<Proof>();
274                    private boolean timedOut = false;
275    
276                    public AnswerHandler(FOLKnowledgeBase kb, Sentence aQuery,
277                                    long maxQueryTime) {
278    
279                            finishTime = System.currentTimeMillis() + maxQueryTime;
280    
281                            Sentence refutationQuery = new NotSentence(aQuery);
282    
283                            // Want to use an answer literal to pull
284                            // query variables where necessary
285                            Literal answerLiteral = kb.createAnswerLiteral(refutationQuery);
286                            answerLiteralVariables = kb.collectAllVariables(answerLiteral
287                                            .getAtomicSentence());
288    
289                            // Create the Set of Support based on the Query.
290                            if (answerLiteralVariables.size() > 0) {
291                                    Sentence refutationQueryWithAnswer = new ConnectedSentence(
292                                                    Connectors.OR, refutationQuery, answerLiteral
293                                                                    .getAtomicSentence().copy());
294    
295                                    sos = createChainsFromClauses(kb
296                                                    .convertToClauses(refutationQueryWithAnswer));
297    
298                                    answerChain.addLiteral(answerLiteral);
299                            } else {
300                                    sos = createChainsFromClauses(kb
301                                                    .convertToClauses(refutationQuery));
302                            }
303    
304                            for (Chain s : sos) {
305                                    s.setProofStep(new ProofStepGoal(s));
306                            }
307                    }
308    
309                    //
310                    // START-InferenceResult
311                    public boolean isPossiblyFalse() {
312                            return !timedOut && proofs.size() == 0;
313                    }
314    
315                    public boolean isTrue() {
316                            return proofs.size() > 0;
317                    }
318    
319                    public boolean isUnknownDueToTimeout() {
320                            return timedOut && proofs.size() == 0;
321                    }
322    
323                    public boolean isPartialResultDueToTimeout() {
324                            return timedOut && proofs.size() > 0;
325                    }
326    
327                    public List<Proof> getProofs() {
328                            return proofs;
329                    }
330    
331                    // END-InferenceResult
332                    //
333    
334                    public List<Chain> getSetOfSupport() {
335                            return sos;
336                    }
337    
338                    public boolean isComplete() {
339                            return complete;
340                    }
341    
342                    public void resetMaxDepthReached() {
343                            maxDepthReached = 0;
344                    }
345    
346                    public int getMaxDepthReached() {
347                            return maxDepthReached;
348                    }
349    
350                    public void updateMaxDepthReached(int depth) {
351                            if (depth > maxDepthReached) {
352                                    maxDepthReached = depth;
353                            }
354                    }
355    
356                    public boolean isAnswer(Chain nearParent) {
357                            boolean isAns = false;
358                            if (answerChain.isEmpty()) {
359                                    if (nearParent.isEmpty()) {
360                                            proofs.add(new ProofFinal(nearParent.getProofStep(),
361                                                            new HashMap<Variable, Term>()));
362                                            complete = true;
363                                            isAns = true;
364                                    }
365                            } else {
366                                    if (nearParent.isEmpty()) {
367                                            // This should not happen
368                                            // as added an answer literal to sos, which
369                                            // implies the database (i.e. premises) are
370                                            // unsatisfiable to begin with.
371                                            throw new IllegalStateException(
372                                                            "Generated an empty chain while looking for an answer, implies original KB is unsatisfiable");
373                                    }
374                                    if (1 == nearParent.getNumberLiterals()
375                                                    && nearParent.getHead().getAtomicSentence()
376                                                                    .getSymbolicName().equals(
377                                                                                    answerChain.getHead()
378                                                                                                    .getAtomicSentence()
379                                                                                                    .getSymbolicName())) {
380                                            Map<Variable, Term> answerBindings = new HashMap<Variable, Term>();
381                                            List<Term> answerTerms = nearParent.getHead()
382                                                            .getAtomicSentence().getArgs();
383                                            int idx = 0;
384                                            for (Variable v : answerLiteralVariables) {
385                                                    answerBindings.put(v, answerTerms.get(idx));
386                                                    idx++;
387                                            }
388                                            boolean addNewAnswer = true;
389                                            for (Proof p : proofs) {
390                                                    if (p.getAnswerBindings().equals(answerBindings)) {
391                                                            addNewAnswer = false;
392                                                            break;
393                                                    }
394                                            }
395                                            if (addNewAnswer) {
396                                                    proofs.add(new ProofFinal(nearParent.getProofStep(),
397                                                                    answerBindings));
398                                            }
399                                            isAns = true;
400                                    }
401                            }
402    
403                            if (System.currentTimeMillis() > finishTime) {
404                                    complete = true;
405                                    // Indicate that I have run out of query time
406                                    timedOut = true;
407                            }
408    
409                            return isAns;
410                    }
411    
412                    public String toString() {
413                            StringBuilder sb = new StringBuilder();
414                            sb.append("isComplete=" + complete);
415                            sb.append("\n");
416                            sb.append("result=" + proofs);
417                            return sb.toString();
418                    }
419            }
420    }
421    
422    class IndexedFarParents {
423            //
424            private int saIdx = 0;
425            private Unifier unifier = new Unifier();
426            private SubstVisitor substVisitor = new SubstVisitor();
427            //
428            private Map<String, List<Chain>> posHeads = new LinkedHashMap<String, List<Chain>>();
429            private Map<String, List<Chain>> negHeads = new LinkedHashMap<String, List<Chain>>();
430    
431            public IndexedFarParents(List<Chain> sos, List<Chain> background) {
432                    constructInternalDataStructures(sos, background);
433            }
434    
435            public int getNumberFarParents(Chain farParent) {
436                    Literal head = farParent.getHead();
437    
438                    Map<String, List<Chain>> heads = null;
439                    if (head.isPositiveLiteral()) {
440                            heads = posHeads;
441                    } else {
442                            heads = negHeads;
443                    }
444                    String headKey = head.getAtomicSentence().getSymbolicName();
445    
446                    List<Chain> farParents = heads.get(headKey);
447                    if (null != farParents) {
448                            return farParents.size();
449                    }
450                    return 0;
451            }
452    
453            public void resetNumberFarParentsTo(Chain farParent, int toSize) {
454                    Literal head = farParent.getHead();
455                    Map<String, List<Chain>> heads = null;
456                    if (head.isPositiveLiteral()) {
457                            heads = posHeads;
458                    } else {
459                            heads = negHeads;
460                    }
461                    String key = head.getAtomicSentence().getSymbolicName();
462                    List<Chain> farParents = heads.get(key);
463                    while (farParents.size() > toSize) {
464                            farParents.remove(farParents.size() - 1);
465                    }
466            }
467    
468            public int getNumberCandidateFarParents(Chain nearParent) {
469                    Literal nearestHead = nearParent.getHead();
470    
471                    Map<String, List<Chain>> candidateHeads = null;
472                    if (nearestHead.isPositiveLiteral()) {
473                            candidateHeads = negHeads;
474                    } else {
475                            candidateHeads = posHeads;
476                    }
477    
478                    String nearestKey = nearestHead.getAtomicSentence().getSymbolicName();
479    
480                    List<Chain> farParents = candidateHeads.get(nearestKey);
481                    if (null != farParents) {
482                            return farParents.size();
483                    }
484                    return 0;
485            }
486    
487            public Chain attemptReduction(Chain nearParent, int farParentIndex) {
488                    Chain nnpc = null;
489    
490                    Literal nearLiteral = nearParent.getHead();
491    
492                    Map<String, List<Chain>> candidateHeads = null;
493                    if (nearLiteral.isPositiveLiteral()) {
494                            candidateHeads = negHeads;
495                    } else {
496                            candidateHeads = posHeads;
497                    }
498    
499                    AtomicSentence nearAtom = nearLiteral.getAtomicSentence();
500                    String nearestKey = nearAtom.getSymbolicName();
501                    List<Chain> farParents = candidateHeads.get(nearestKey);
502                    if (null != farParents) {
503                            Chain farParent = farParents.get(farParentIndex);
504                            standardizeApart(farParent);
505                            Literal farLiteral = farParent.getHead();
506                            AtomicSentence farAtom = farLiteral.getAtomicSentence();
507                            Map<Variable, Term> subst = unifier.unify(nearAtom, farAtom);
508    
509                            // If I was able to unify with one
510                            // of the far heads
511                            if (null != subst) {
512                                    // Want to always apply reduction uniformly
513                                    Chain topChain = farParent;
514                                    Literal botLit = nearLiteral;
515                                    Chain botChain = nearParent;
516    
517                                    // Need to apply subst to all of the
518                                    // literals in the reduction
519                                    List<Literal> reduction = new ArrayList<Literal>();
520                                    for (Literal l : topChain.getTail()) {
521                                            AtomicSentence atom = (AtomicSentence) substVisitor.subst(
522                                                            subst, l.getAtomicSentence());
523                                            reduction.add(l.newInstance(atom));
524                                    }
525                                    reduction.add(new ReducedLiteral((AtomicSentence) substVisitor
526                                                    .subst(subst, botLit.getAtomicSentence()), botLit
527                                                    .isNegativeLiteral()));
528                                    for (Literal l : botChain.getTail()) {
529                                            AtomicSentence atom = (AtomicSentence) substVisitor.subst(
530                                                            subst, l.getAtomicSentence());
531                                            reduction.add(l.newInstance(atom));
532                                    }
533    
534                                    nnpc = new Chain(reduction);
535                                    nnpc.setProofStep(new ProofStepChainReduction(nnpc, nearParent,
536                                                    farParent, subst));
537                            }
538                    }
539    
540                    return nnpc;
541            }
542    
543            public Chain addToIndex(Chain c) {
544                    Chain added = null;
545                    Literal head = c.getHead();
546                    if (null != head) {
547                            Map<String, List<Chain>> toAddTo = null;
548                            if (head.isPositiveLiteral()) {
549                                    toAddTo = posHeads;
550                            } else {
551                                    toAddTo = negHeads;
552                            }
553    
554                            String key = head.getAtomicSentence().getSymbolicName();
555                            List<Chain> farParents = toAddTo.get(key);
556                            if (null == farParents) {
557                                    farParents = new ArrayList<Chain>();
558                                    toAddTo.put(key, farParents);
559                            }
560                            
561                            added = c;
562                            farParents.add(added);
563                    }
564                    return added;
565            }
566            
567            public void standardizeApart(Chain c) {
568                    saIdx = StandardizeApartInPlace.standardizeApart(c, saIdx);
569            }
570    
571            public String toString() {
572                    StringBuilder sb = new StringBuilder();
573    
574                    sb.append("#");
575                    sb.append(posHeads.size());
576                    for (String key : posHeads.keySet()) {
577                            sb.append(",");
578                            sb.append(posHeads.get(key).size());
579                    }
580                    sb.append(" posHeads=");
581                    sb.append(posHeads.toString());
582                    sb.append("\n");
583                    sb.append("#");
584                    sb.append(negHeads.size());
585                    for (String key : negHeads.keySet()) {
586                            sb.append(",");
587                            sb.append(negHeads.get(key).size());
588                    }
589                    sb.append(" negHeads=");
590                    sb.append(negHeads.toString());
591    
592                    return sb.toString();
593            }
594    
595            // 
596            // PRIVATE METHODS
597            //
598            private void constructInternalDataStructures(List<Chain> sos,
599                            List<Chain> background) {
600                    List<Chain> toIndex = new ArrayList<Chain>();
601                    toIndex.addAll(sos);
602                    toIndex.addAll(background);
603    
604                    for (Chain c : toIndex) {
605                            addToIndex(c);
606                    }
607            }
608    }