diff --git a/exception.go b/error.go similarity index 54% rename from exception.go rename to error.go index e6995b29..9b8b0054 100644 --- a/exception.go +++ b/error.go @@ -19,6 +19,24 @@ func (e *ResourceError) Error() string { return fmt.Sprintf("insufficient resource: %s", e.Resource) } +// RepresentationError is an error that signifies one of the implementation limits exceeded. +type RepresentationError struct { + flag string +} + +func (e *RepresentationError) Error() string { + return fmt.Sprintf("implementation limit exceeded: %s", e.flag) +} + +// SyntaxError is an error that signifies a syntax error. +type SyntaxError struct { + impDepAtom string +} + +func (e *SyntaxError) Error() string { + return fmt.Sprintf("syntax error: %s", e.impDepAtom) +} + // TypeError is an error that signifies an incorrect type. type TypeError struct { ValidType string @@ -29,6 +47,16 @@ func (e *TypeError) Error() string { return fmt.Sprintf("invalid type: expected %s, got %s", e.ValidType, e.Culprit) } +// DomainError is an error that signifies an incorrect value. +type DomainError struct { + ValidDomain string + Culprit Term +} + +func (e *DomainError) Error() string { + return fmt.Sprintf("invalid domain: expected %s, got %s", e.ValidDomain, e.Culprit) +} + // UninstantiationError is an error that signifies a term is non-variable. 
type UninstantiationError struct { Culprit Term diff --git a/exception_test.go b/error_test.go similarity index 100% rename from exception_test.go rename to error_test.go diff --git a/formatter.go b/formatter.go new file mode 100644 index 00000000..084c1aea --- /dev/null +++ b/formatter.go @@ -0,0 +1,493 @@ +package prolog + +import ( + "fmt" + "io" + "regexp" + "strconv" + "strings" +) + +type formatState struct { + priority int + visited map[Term]struct{} + prefixMinus bool + left, right operator + depth int +} + +type Formatter struct { + Term Term + Heap *Heap + + IgnoreOps bool + Quoted bool + VariableName map[Variable]string + NumberVars bool + + Ops Operators + MaxDepth int + Precision int +} + +func (f *Formatter) Format(s fmt.State, verb rune) { + c := *f + c.Quoted = verb == 'q' + c.IgnoreOps = s.Flag('-') + c.NumberVars = s.Flag('#') + + if w, ok := s.Width(); ok { + c.MaxDepth = w + } else { + c.MaxDepth = 10 + } + + if p, ok := s.Precision(); ok { + c.Precision = p + } else { + c.Precision = -1 + } + + _, _ = c.WriteTo(s) +} + +func (f *Formatter) WriteTo(w io.Writer) (int64, error) { + state := formatState{ + priority: 1201, + } + return writeTerm(w, f.Heap, f.Term, f, state) +} + +func writeTerm(w io.Writer, h *Heap, t Term, opts *Formatter, state formatState) (int64, error) { + t = t.resolve(h) + + if _, ok := state.visited[t]; ok || (opts.MaxDepth > 0 && state.depth > opts.MaxDepth) { + return writeAtom(w, "...", opts, state) + } + + if v, err := t.Variable(h); err == nil { + return writeVariable(w, v, opts, state) + } + + if name, err := t.Atom(h); err == nil { + return writeAtom(w, name, opts, state) + } + + if i, err := t.Integer(h); err == nil { + return writeInteger(w, i, opts, state) + } + + if f, err := t.Float(h); err == nil { + return writeFloat(w, f, opts, state) + } + + if state.visited == nil { + state.visited = map[Term]struct{}{} + } + state.visited[t] = struct{}{} + + c, err := t.Compound(h) + if err != nil { + return 0, err + } + + 
return writeCompound(w, h, c, opts, state) +} + +func writeVariable(w io.Writer, v Variable, opts *Formatter, state formatState) (int64, error) { + ew := errWriter{w: w} + if letterDigit(state.left.name) { + _, _ = fmt.Fprint(&ew, " ") + } + if name, ok := opts.VariableName[v]; ok { + f := *opts + f.Quoted = false + _, _ = writeAtom(&ew, name, &f, state) // Use the copy so the variable name is written unquoted even when Quoted is set. + } else { + _, _ = fmt.Fprintf(&ew, "_%d", v) + } + if letterDigit(state.right.name) { + _, _ = fmt.Fprint(&ew, " ") + } + return ew.Result() +} + +func writeAtom(w io.Writer, name string, opts *Formatter, state formatState) (int64, error) { + ew := errWriter{w: w} + openClose := (state.left != (operator{}) || state.right != (operator{})) && opts.Ops.defined(name) + + if openClose { + if state.left.name != "" && state.left.specifier.class() == operatorClassPrefix { + _, _ = fmt.Fprint(&ew, " ") + } + _, _ = fmt.Fprint(&ew, "(") + state.left, state.right = operator{}, operator{} + } + + if opts.Quoted && needQuoted(name) { + if state.left != (operator{}) && needQuoted(state.left.name) { // Avoid 'FOO''BAR'. + _, _ = fmt.Fprint(&ew, " ") + } + _, _ = ew.Write([]byte(quote(name))) + if state.right != (operator{}) && needQuoted(state.right.name) { // Avoid 'FOO''BAR'. 
+ _, _ = fmt.Fprint(&ew, " ") + } + } else { + if (letterDigit(state.left.name) && letterDigit(name)) || (graphic(state.left.name) && graphic(name)) { + _, _ = fmt.Fprint(&ew, " ") + } + _, _ = fmt.Fprint(&ew, name) + if (letterDigit(state.right.name) && letterDigit(name)) || (graphic(state.right.name) && graphic(name)) { + _, _ = fmt.Fprint(&ew, " ") + } + } + + if openClose { + _, _ = fmt.Fprint(&ew, ")") + } + + return ew.Result() +} + +func needQuoted(name string) bool { + p := NewParser(strings.NewReader(name), Operators{}, doubleQuotesChars) + parsed, ok, err := p.atom() + return err != nil || !ok || parsed != name +} + +var ( + quotedAtomEscapePattern = regexp.MustCompile(`[[:cntrl:]]|\\|'`) +) + +func quotedIdentEscape(s string) string { + switch s { + case "\a": + return `\a` + case "\b": + return `\b` + case "\f": + return `\f` + case "\n": + return `\n` + case "\r": + return `\r` + case "\t": + return `\t` + case "\v": + return `\v` + case `\`: + return `\\` + case `'`: + return `\'` + default: + var ret []string + for _, r := range s { + ret = append(ret, fmt.Sprintf(`\x%x\`, r)) + } + return strings.Join(ret, "") + } +} + +func quote(s string) string { + return fmt.Sprintf("'%s'", quotedAtomEscapePattern.ReplaceAllStringFunc(s, quotedIdentEscape)) +} + +func letterDigit(s string) bool { + return len(s) > 0 && isSmallLetterChar([]rune(s)[0]) +} + +func graphic(s string) bool { + return len(s) > 0 && (isGraphicChar([]rune(s)[0]) || []rune(s)[0] == '\\') +} + +func writeInteger(w io.Writer, i int64, _ *Formatter, state formatState) (int64, error) { + ew := errWriter{w: w} + openClose := state.left.name == "-" && state.left.specifier.class() == operatorClassPrefix && i > 0 + + if openClose { + _, _ = ew.Write([]byte(" (")) + state.left = operator{} + state.right = operator{} + } else { + if state.left != (operator{}) && (letterDigit(state.left.name) || (i < 0 && graphic(state.left.name))) { + _, _ = ew.Write([]byte(" ")) + } + } + + s := 
strconv.FormatInt(i, 10) + _, _ = ew.Write([]byte(s)) + + if openClose { + _, _ = ew.Write([]byte(")")) + } + + // Avoid ambiguous 0b, 0o, 0x or 0'. + if !openClose && state.right != (operator{}) && (letterDigit(state.right.name) || (needQuoted(state.right.name) && state.right.name != "," && state.right.name != "|")) { + _, _ = ew.Write([]byte(" ")) + } + + return ew.Result() +} + +func writeFloat(w io.Writer, f float64, opts *Formatter, state formatState) (int64, error) { + ew := errWriter{w: w} + openClose := state.left.name == "-" && state.left.specifier.class() == operatorClassPrefix && f > 0 + + if openClose || (f < 0 && state.left != operator{}) { + _, _ = ew.Write([]byte(" ")) + } + + if openClose { + _, _ = ew.Write([]byte("(")) + } + + s := strconv.FormatFloat(f, 'g', opts.Precision, 64) + if !strings.ContainsRune(s, '.') { + if strings.ContainsRune(s, 'e') { + s = strings.Replace(s, "e", ".0e", 1) + } else { + s += ".0" + } + } + _, _ = ew.Write([]byte(s)) + + if openClose { + _, _ = ew.Write([]byte(")")) + } + + if !openClose && state.right != (operator{}) && (state.right.name == "e" || state.right.name == "E") { + _, _ = ew.Write([]byte(" ")) + } + + return ew.Result() +} + +func writeCompound(w io.Writer, h *Heap, c *Compound, opts *Formatter, state formatState) (int64, error) { + if c.Functor == (Functor{Name: "$VAR", Arity: 1}) && opts.NumberVars { + a := c.Arg(h, 0) + if n, err := a.Integer(h); err == nil && n >= 0 { // A negative N would index the letter table out of range; write it in functional notation instead. + return writeCompoundNumberVars(w, n) + } + } + + if !opts.IgnoreOps { + switch c.Functor { + case Functor{Name: ".", Arity: 2}: + return writeCompoundList(w, h, c, opts, state) + case Functor{Name: "{}", Arity: 1}: + return writeCompoundCurlyBracketed(w, h, c, opts, state) + } + + ops := opts.Ops.ops + switch c.Arity { + case 1: + if op, ok := ops[opKey{name: c.Name, opClass: operatorClassPrefix}]; ok { + return writeCompoundOpPrefix(w, h, c, &op, opts, state) + } + if op, ok := ops[opKey{name: c.Name, opClass: operatorClassPostfix}]; ok { + 
return writeCompoundOpPostfix(w, h, c, &op, opts, state) + } + case 2: + if op, ok := ops[opKey{name: c.Name, opClass: operatorClassInfix}]; ok { + return writeCompoundOpInfix(w, h, c, &op, opts, state) + } + } + } + + return writeCompoundFunctionalNotation(w, h, c, opts, state) +} + +func writeCompoundNumberVars(w io.Writer, n int64) (int64, error) { + const letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + ew := errWriter{w: w} + i, j := int(n)%len(letters), int(n)/len(letters) + _, _ = fmt.Fprint(&ew, string(letters[i])) + if j != 0 { + _, _ = fmt.Fprint(&ew, strconv.Itoa(j)) + } + return ew.Result() +} + +func writeCompoundList(w io.Writer, h *Heap, c *Compound, opts *Formatter, state formatState) (int64, error) { + ew := errWriter{w: w} + state.priority = 999 + state.left = operator{} + state.right = operator{} + _, _ = fmt.Fprint(&ew, "[") + _, _ = writeTerm(&ew, h, c.Arg(h, 0), opts, state) + for elem, err := range c.Arg(h, 1).List(h, AllowCycle(opts.MaxDepth > state.depth)) { + if err != nil { + _, _ = fmt.Fprint(&ew, "|") + if c, err := elem.Compound(h); err == nil && c.Functor == (Functor{Name: ".", Arity: 2}) { + _, _ = writeAtom(&ew, "...", opts, state) + } else { + _, _ = writeTerm(&ew, h, elem, opts, state) + } + break + } + + state.depth++ + _, _ = fmt.Fprint(&ew, ",") + _, _ = writeTerm(&ew, h, elem, opts, state) + } + _, _ = fmt.Fprint(&ew, "]") + return ew.Result() +} + +func writeCompoundCurlyBracketed(w io.Writer, h *Heap, c *Compound, opts *Formatter, state formatState) (int64, error) { + ew := errWriter{w: w} + state.left = operator{} + _, _ = fmt.Fprint(&ew, "{") + _, _ = writeTerm(&ew, h, c.Arg(h, 0), opts, state) + _, _ = fmt.Fprint(&ew, "}") + return ew.Result() +} + +func writeCompoundOpPrefix(w io.Writer, h *Heap, c *Compound, op *operator, opts *Formatter, state formatState) (int64, error) { + ew := errWriter{w: w} + _, r := op.bindingPriorities() + openClose := state.priority < op.priority || (state.right != operator{} && r >= 
state.right.priority) + + if state.left != (operator{}) { + _, _ = fmt.Fprint(&ew, " ") + } + if openClose { + _, _ = fmt.Fprint(&ew, "(") + state.left = operator{} + state.right = operator{} + } + { + state := state + state.left = operator{} + state.right = operator{} + _, _ = writeAtom(&ew, c.Name, opts, state) + } + { + state := state + state.priority = r + state.left = *op + state.depth++ + _, _ = writeTerm(&ew, h, c.Arg(h, 0), opts, state) + } + if openClose { + _, _ = fmt.Fprint(&ew, ")") + } + return ew.Result() +} + +func writeCompoundOpPostfix(w io.Writer, h *Heap, c *Compound, op *operator, opts *Formatter, state formatState) (int64, error) { + ew := errWriter{w: w} + l, _ := op.bindingPriorities() + openClose := state.priority < op.priority || (state.left.name == "-" && state.left.specifier.class() == operatorClassPrefix) + + if openClose { + if state.left != (operator{}) { + _, _ = fmt.Fprint(&ew, " ") + } + _, _ = fmt.Fprint(&ew, "(") + state.left = operator{} + state.right = operator{} + } + { + state := state + state.priority = l + state.right = *op + state.depth++ + _, _ = writeTerm(&ew, h, c.Arg(h, 0), opts, state) + } + { + state := state + state.left = operator{} + state.right = operator{} + _, _ = writeAtom(&ew, c.Name, opts, state) + } + if openClose { + _, _ = fmt.Fprint(&ew, ")") + } else if state.right != (operator{}) { + _, _ = fmt.Fprint(&ew, " ") + } + return ew.Result() +} + +func writeCompoundOpInfix(w io.Writer, h *Heap, c *Compound, op *operator, opts *Formatter, state formatState) (int64, error) { + ew := errWriter{w: w} + l, r := op.bindingPriorities() + openClose := state.priority < op.priority || + (state.left.name == "-" && state.left.specifier.class() == operatorClassPrefix) || + (state.right != operator{} && r >= state.right.priority) + + if openClose { + if state.left != (operator{}) && state.left.specifier.class() == operatorClassPrefix { + _, _ = fmt.Fprint(&ew, " ") + } + _, _ = fmt.Fprint(&ew, "(") + state.left = 
operator{} + state.right = operator{} + } + { + state := state + state.priority = l + state.right = *op + state.depth++ + _, _ = writeTerm(&ew, h, c.Arg(h, 0), opts, state) + } + switch name := c.Name; name { + case ",", "|": + _, _ = fmt.Fprint(&ew, name) + default: + state := state + state.left = operator{} + state.right = operator{} + _, _ = writeAtom(&ew, name, opts, state) + } + { + state := state + state.priority = r + state.left = *op + state.depth++ + _, _ = writeTerm(&ew, h, c.Arg(h, 1), opts, state) + } + if openClose { + _, _ = fmt.Fprint(&ew, ")") + } + return ew.Result() +} + +func writeCompoundFunctionalNotation(w io.Writer, h *Heap, c *Compound, opts *Formatter, state formatState) (int64, error) { + ew := errWriter{w: w} + state.right = operator{} + _, _ = writeAtom(&ew, c.Name, opts, state) + _, _ = fmt.Fprint(&ew, "(") + state.left = operator{} + state.priority = 999 + state.depth++ + for i, a := range c.Args(h) { + if i != 0 { + _, _ = fmt.Fprint(&ew, ",") + } + _, _ = writeTerm(&ew, h, a, opts, state) + } + _, _ = fmt.Fprint(&ew, ")") + return ew.Result() +} + +// https://go.dev/blog/errors-are-values +type errWriter struct { + w io.Writer + n int64 + err error +} + +func (ew *errWriter) Write(p []byte) (int, error) { + if ew.err != nil { + return 0, nil + } + var n int + n, ew.err = ew.w.Write(p) + ew.n += int64(n) + return n, nil +} + +func (ew *errWriter) Result() (int64, error) { + return ew.n, ew.err +} diff --git a/formatter_test.go b/formatter_test.go new file mode 100644 index 00000000..3c68179e --- /dev/null +++ b/formatter_test.go @@ -0,0 +1,488 @@ +package prolog + +import ( + "bytes" + "fmt" + "reflect" + "testing" +) + +func TestFormatter_WriteTo(t *testing.T) { + h := NewHeap(30 * 1024) + + x, err := NewVariable(h) + if err != nil { + t.Fatal(err) + } + + y, err := NewVariable(h) + if err != nil { + t.Fatal(err) + } + + a, err := NewAtom(h, "a") + if err != nil { + t.Fatal(err) + } + + b, err := NewAtom(h, "b") + if err != nil { + 
t.Fatal(err) + } + + c, err := NewAtom(h, "c") + if err != nil { + t.Fatal(err) + } + + X, err := NewAtom(h, "X") + if err != nil { + t.Fatal(err) + } + + rest, err := NewAtom(h, "rest") + if err != nil { + t.Fatal(err) + } + + escapeSequence, err := NewAtom(h, "\a\b\f\n\r\t\v\x00\\'\"`") + if err != nil { + t.Fatal(err) + } + + comma, err := NewAtom(h, ",") + if err != nil { + t.Fatal(err) + } + + emptyList, err := NewAtom(h, "[]") + if err != nil { + t.Fatal(err) + } + + emptyBlock, err := NewAtom(h, "{}") + if err != nil { + t.Fatal(err) + } + + minus, err := NewAtom(h, "-") + if err != nil { + t.Fatal(err) + } + + foo, err := NewAtom(h, "foo") + if err != nil { + t.Fatal(err) + } + + bar, err := NewAtom(h, "bar") + if err != nil { + t.Fatal(err) + } + + baz, err := NewAtom(h, "baz") + if err != nil { + t.Fatal(err) + } + + thirtyThree, err := NewInteger(h, 33) + if err != nil { + t.Fatal(err) + } + + minusThirtyThree, err := NewInteger(h, -33) + if err != nil { + t.Fatal(err) + } + + zero, err := NewInteger(h, 0) + if err != nil { + t.Fatal(err) + } + + one, err := NewInteger(h, 1) + if err != nil { + t.Fatal(err) + } + + two, err := NewInteger(h, 2) + if err != nil { + t.Fatal(err) + } + + twentyFive, err := NewInteger(h, 25) + if err != nil { + t.Fatal(err) + } + + twentySix, err := NewInteger(h, 26) + if err != nil { + t.Fatal(err) + } + + twentySeven, err := NewInteger(h, 27) + if err != nil { + t.Fatal(err) + } + + minusTwo, err := NewInteger(h, -2) + if err != nil { + t.Fatal(err) + } + + floatThirtyThree, err := NewFloat(h, 33) + if err != nil { + t.Fatal(err) + } + + floatWithE, err := NewFloat(h, 3.0e+100) + if err != nil { + t.Fatal(err) + } + + floatMinusThirtyThree, err := NewFloat(h, -33) + if err != nil { + t.Fatal(err) + } + + list, err := NewList(h, a, b, c) + if err != nil { + t.Fatal(err) + } + + listish, err := NewPartialList(h, rest, a, b) + if err != nil { + t.Fatal(err) + } + + v, err := NewVariable(h) + if err != nil { + t.Fatal(err) + } 
+ circularList, err := NewPartialList(h, v, a, b) + if err != nil { + t.Fatal(err) + } + if err := h.env.Values.Set(Variable(v.payload), circularList); err != nil { + t.Fatal(err) + } + + curlyBrackets, err := NewCompound(h, "{}", foo) + if err != nil { + t.Fatal(err) + } + + ifFoo, err := NewCompound(h, ":-", foo) + if err != nil { + t.Fatal(err) + } + ifIfFoo, err := NewCompound(h, ":-", ifFoo) + if err != nil { + t.Fatal(err) + } + + notFoo, err := NewCompound(h, `\+`, foo) + if err != nil { + t.Fatal(err) + } + + minusNotFoo, err := NewCompound(h, `-`, notFoo) + if err != nil { + t.Fatal(err) + } + + notMinusNotFoo, err := NewCompound(h, `\+`, minusNotFoo) + if err != nil { + t.Fatal(err) + } + + fiFoo, err := NewCompound(h, `-:`, foo) + if err != nil { + t.Fatal(err) + } + + fiFiFoo, err := NewCompound(h, `-:`, fiFoo) + if err != nil { + t.Fatal(err) + } + + tonFoo, err := NewCompound(h, `+/`, foo) + if err != nil { + t.Fatal(err) + } + + minusMinusTonFoo, err := NewCompound(h, `--`, tonFoo) + if err != nil { + t.Fatal(err) + } + + tonMinusMinusTonFoo, err := NewCompound(h, `+/`, minusMinusTonFoo) + if err != nil { + t.Fatal(err) + } + + ifBarBaz, err := NewCompound(h, `:-`, bar, baz) + if err != nil { + t.Fatal(err) + } + + ifFooIfBarBaz, err := NewCompound(h, `:-`, foo, ifBarBaz) + if err != nil { + t.Fatal(err) + } + + plusTwoTwo, err := NewCompound(h, `+`, two, two) + if err != nil { + t.Fatal(err) + } + + asteriskTwoPlusTwoTwo, err := NewCompound(h, `*`, two, plusTwoTwo) + if err != nil { + t.Fatal(err) + } + + barTwoTwo, err := NewCompound(h, `|`, two, two) + if err != nil { + t.Fatal(err) + } + + commaTwoBarTwoTwo, err := NewCompound(h, `,`, two, barTwoTwo) + if err != nil { + t.Fatal(err) + } + + plusTwoMinusTwo, err := NewCompound(h, `+`, two, minusTwo) + if err != nil { + t.Fatal(err) + } + + varZero, err := NewCompound(h, `$VAR`, zero) + if err != nil { + t.Fatal(err) + } + + varOne, err := NewCompound(h, `$VAR`, one) + if err != nil { + 
t.Fatal(err) + } + + varTwentyFive, err := NewCompound(h, `$VAR`, twentyFive) + if err != nil { + t.Fatal(err) + } + + varTwentySix, err := NewCompound(h, `$VAR`, twentySix) + if err != nil { + t.Fatal(err) + } + + varTwentySeven, err := NewCompound(h, `$VAR`, twentySeven) + if err != nil { + t.Fatal(err) + } + + fVars, err := NewCompound(h, `f`, varZero, varOne, varTwentyFive, varTwentySix, varTwentySeven) + if err != nil { + t.Fatal(err) + } + + minusB, err := NewCompound(h, `-`, b) + if err != nil { + t.Fatal(err) + } + + asteriskAMinusB, err := NewCompound(h, `*`, a, minusB) + if err != nil { + t.Fatal(err) + } + + tonA, err := NewCompound(h, `+/`, a) + if err != nil { + t.Fatal(err) + } + + minusTonA, err := NewCompound(h, `-`, tonA) + if err != nil { + t.Fatal(err) + } + + asteriskAB, err := NewCompound(h, `*`, a, b) + if err != nil { + t.Fatal(err) + } + + minusAsteriskAB, err := NewCompound(h, `-`, asteriskAB) + if err != nil { + t.Fatal(err) + } + + w, err := NewVariable(h) + if err != nil { + t.Fatal(err) + } + r, err := NewCompound(h, "f", w) + if err != nil { + t.Fatal(err) + } + if err := h.env.Values.Set(Variable(w.payload), r); err != nil { + t.Fatal(err) + } + + isXY, err := NewCompound(h, "is", x, y) + if err != nil { + t.Fatal(err) + } + + minusMinus, err := NewCompound(h, "-", minus) + if err != nil { + t.Fatal(err) + } + + minusMinusMinus, err := NewCompound(h, "--", minus) + if err != nil { + t.Fatal(err) + } + + FXX, err := NewCompound(h, `F`, X, X) + if err != nil { + t.Fatal(err) + } + + isFooFoo, err := NewCompound(h, `is`, foo, foo) + if err != nil { + t.Fatal(err) + } + + unaryMinusThirtyThree, err := NewCompound(h, `-`, thirtyThree) + if err != nil { + t.Fatal(err) + } + + b0Zero, err := NewCompound(h, `b0`, zero) + if err != nil { + t.Fatal(err) + } + + o0Zero, err := NewCompound(h, `o0`, zero) + if err != nil { + t.Fatal(err) + } + + x0Zero, err := NewCompound(h, `x0`, zero) + if err != nil { + t.Fatal(err) + } + + FooZero, err := 
NewCompound(h, `Foo`, zero) + if err != nil { + t.Fatal(err) + } + + minusFloatThirtyThree, err := NewCompound(h, `-`, floatThirtyThree) + if err != nil { + t.Fatal(err) + } + + eFloatThirtyThree, err := NewCompound(h, `e`, floatThirtyThree) + if err != nil { + t.Fatal(err) + } + + var ops Operators + ops.Define(1200, XFX, `:-`) + ops.Define(1200, FX, `:-`) + ops.Define(1200, XF, `-:`) + ops.Define(1105, XFY, `|`) + ops.Define(1000, XFY, `,`) + ops.Define(900, FY, `\+`) + ops.Define(900, YF, `+/`) + ops.Define(700, XFX, `is`) + ops.Define(700, XFX, `F`) + ops.Define(500, YFX, `+`) + ops.Define(400, YFX, `*`) + ops.Define(200, FY, "+") + ops.Define(200, FY, `-`) + ops.Define(200, YF, `--`) + ops.Define(200, YF, `b0`) + ops.Define(200, YF, `o0`) + ops.Define(200, YF, `x0`) + ops.Define(200, YF, `Foo`) + ops.Define(200, YF, `e`) + + tests := []struct { + title string + formatter Formatter + output string + err error + }{ + {title: "variable: unnamed", formatter: Formatter{Term: x, Heap: h}, output: fmt.Sprintf("_%d", x.payload)}, + {title: "variable: variable_names", formatter: Formatter{Term: x, Heap: h, VariableName: map[Variable]string{ + Variable(x.payload): "Foo", + }}, output: `Foo`}, + + {title: "atom: a", formatter: Formatter{Term: a, Heap: h, Quoted: false}, output: `a`}, + {title: "atom: a with quoted", formatter: Formatter{Term: a, Heap: h, Quoted: true}, output: `a`}, + {title: "atom: escape sequence", formatter: Formatter{Term: escapeSequence, Heap: h, Quoted: false}, output: "\a\b\f\n\r\t\v\x00\\'\"`"}, + {title: "atom: escape sequence with quoted", formatter: Formatter{Term: escapeSequence, Heap: h, Quoted: true}, output: "'\\a\\b\\f\\n\\r\\t\\v\\x0\\\\\\\\'\"`'"}, + {title: "atom: comma", formatter: Formatter{Term: comma, Heap: h, Quoted: false}, output: `,`}, + {title: "atom: comma with quoted", formatter: Formatter{Term: comma, Heap: h, Quoted: true}, output: `','`}, + {title: "atom: empty list", formatter: Formatter{Term: emptyList, Heap: h, Quoted: 
false}, output: `[]`}, + {title: "atom: empty list with quoted", formatter: Formatter{Term: emptyList, Heap: h, Quoted: true}, output: `[]`}, + {title: "atom: empty block", formatter: Formatter{Term: emptyBlock, Heap: h, Quoted: false}, output: `{}`}, + {title: "atom: empty block with quoted", formatter: Formatter{Term: emptyBlock, Heap: h, Quoted: true}, output: `{}`}, + {title: "atom: minus", formatter: Formatter{Term: minus, Heap: h}, output: `-`}, + + {title: "integer: positive", formatter: Formatter{Term: thirtyThree, Heap: h}, output: `33`}, + {title: "integer: negative", formatter: Formatter{Term: minusThirtyThree, Heap: h}, output: `-33`}, + + {title: "float: positive", formatter: Formatter{Term: floatThirtyThree, Heap: h, Precision: -1}, output: `33.0`}, + {title: "float: with e", formatter: Formatter{Term: floatWithE, Heap: h, Precision: -1}, output: `3.0e+100`}, + {title: "float: negative", formatter: Formatter{Term: floatMinusThirtyThree, Heap: h, Precision: -1}, output: `-33.0`}, + + {title: "compound: list", formatter: Formatter{Term: list, Heap: h}, output: `[a,b,c]`}, + {title: "compound: list-ish", formatter: Formatter{Term: listish, Heap: h}, output: `[a,b|rest]`}, + {title: "compound: circular list", formatter: Formatter{Term: circularList, Heap: h}, output: `[a,b,a|...]`}, + {title: "compound: curly brackets", formatter: Formatter{Term: curlyBrackets, Heap: h}, output: `{foo}`}, + {title: "compound: fx", formatter: Formatter{Term: ifIfFoo, Heap: h, Ops: ops}, output: `:- (:-foo)`}, + {title: "compound: fy", formatter: Formatter{Term: notMinusNotFoo, Heap: h, Ops: ops}, output: `\+ - (\+foo)`}, + {title: "compound: xf", formatter: Formatter{Term: fiFiFoo, Heap: h, Ops: ops}, output: `(foo-:)-:`}, + {title: "compound: yf", formatter: Formatter{Term: tonMinusMinusTonFoo, Heap: h, Ops: ops}, output: `(foo+/)-- +/`}, + {title: "compound: xfx", formatter: Formatter{Term: ifFooIfBarBaz, Heap: h, Ops: ops}, output: `foo:-(bar:-baz)`}, + {title: 
"compound: yfx", formatter: Formatter{Term: asteriskTwoPlusTwoTwo, Heap: h, Ops: ops}, output: `2*(2+2)`}, + {title: "compound: xfy", formatter: Formatter{Term: commaTwoBarTwoTwo, Heap: h, Ops: ops}, output: `2,(2|2)`}, + {title: "compound: ignore_ops(false)", formatter: Formatter{Term: plusTwoMinusTwo, Heap: h, IgnoreOps: false, Ops: ops}, output: `2+ -2`}, + {title: "compound: ignore_ops(true)", formatter: Formatter{Term: plusTwoMinusTwo, Heap: h, IgnoreOps: true, Ops: ops}, output: `+(2,-2)`}, + {title: "compound: number_vars(false)", formatter: Formatter{Term: fVars, Heap: h, Quoted: true, NumberVars: false, Ops: ops}, output: `f('$VAR'(0),'$VAR'(1),'$VAR'(25),'$VAR'(26),'$VAR'(27))`}, + {title: "compound: number_vars(true)", formatter: Formatter{Term: fVars, Heap: h, Quoted: true, NumberVars: true, Ops: ops}, output: `f(A,B,Z,A1,B1)`}, + {title: "compound: prefix: spacing between operators", formatter: Formatter{Term: asteriskAMinusB, Heap: h, Ops: ops}, output: `a* -b`}, + {title: "compound: postfix: spacing between unary minus and open/close", formatter: Formatter{Term: minusTonA, Heap: h, Ops: ops}, output: `- (a+/)`}, + {title: "compound: infix: spacing between unary minus and open/close", formatter: Formatter{Term: minusAsteriskAB, Heap: h, Ops: ops}, output: `- (a*b)`}, + {title: "compound: recursive", formatter: Formatter{Term: r, Heap: h}, output: `f(...)`}, + {title: "compound: variable following/followed by a letter-digit operator", formatter: Formatter{Term: isXY, Heap: h, Ops: ops}, output: fmt.Sprintf("_%d is _%d", x.payload, y.payload)}, + {title: "compound: atom minus right after an operator", formatter: Formatter{Term: minusMinus, Heap: h, Ops: ops}, output: `- (-)`}, + {title: "compound: atom minus right before an operator", formatter: Formatter{Term: minusMinusMinus, Heap: h, Ops: ops}, output: `(-)--`}, + {title: "compound: atom X right before/after an operator that requires quotes", formatter: Formatter{Term: FXX, Heap: h, Quoted: true, 
Ops: ops}, output: `'X' 'F' 'X'`}, + {title: "compound: atom foo right before/after a letter-digit operator", formatter: Formatter{Term: isFooFoo, Heap: h, Ops: ops}, output: `foo is foo`}, // So that it won't be barfoo. + {title: "compound: positive integer following unary minus", formatter: Formatter{Term: unaryMinusThirtyThree, Heap: h, Ops: ops}, output: `- (33)`}, + {title: "compound: integer ambiguous 0b", formatter: Formatter{Term: b0Zero, Heap: h, Ops: ops}, output: `0 b0`}, // So that it won't be 0b0. + {title: "compound: integer ambiguous 0o", formatter: Formatter{Term: o0Zero, Heap: h, Ops: ops}, output: `0 o0`}, // So that it won't be 0o0. + {title: "compound: integer ambiguous 0x", formatter: Formatter{Term: x0Zero, Heap: h, Ops: ops}, output: `0 x0`}, // So that it won't be 0x0. + {title: "compound: integer ambiguous 0'", formatter: Formatter{Term: FooZero, Heap: h, Quoted: true, Ops: ops}, output: `0 'Foo'`}, // So that it won't be 0'Foo'. + {title: "float: positive following unary minus", formatter: Formatter{Term: minusFloatThirtyThree, Heap: h, Ops: ops, Precision: -1}, output: `- (33.0)`}, + {title: "float: ambiguous e", formatter: Formatter{Term: eFloatThirtyThree, Heap: h, Ops: ops, Precision: -1}, output: `33.0 e`}, // So that it won't be 33.0e. 
+ } + + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + var buf bytes.Buffer + _, err := tt.formatter.WriteTo(&buf) + if !reflect.DeepEqual(tt.err, err) { + t.Errorf("want %v, got %v", tt.err, err) + } + + if tt.output != buf.String() { + t.Errorf("want %s, got %s", tt.output, buf.String()) + } + }) + } +} diff --git a/internal/ring/rune_reader.go b/internal/ring/rune_reader.go new file mode 100644 index 00000000..1eb4d90d --- /dev/null +++ b/internal/ring/rune_reader.go @@ -0,0 +1,37 @@ +package ring + +import "io" + +type runeWithSize struct { + rune rune + size int +} + +type RuneReader struct { + base io.RuneReader + buf *Buffer[runeWithSize] +} + +func NewRuneReader(r io.RuneReader, size int) *RuneReader { + return &RuneReader{ + base: r, + buf: NewBuffer[runeWithSize](size), + } +} + +func (r *RuneReader) ReadRune() (rune, int, error) { + if r.buf.Empty() { + c, n, err := r.base.ReadRune() + if err != nil { + return c, n, err + } + r.buf.Put(runeWithSize{rune: c, size: n}) + } + rs := r.buf.Get() + return rs.rune, rs.size, nil +} + +func (r *RuneReader) UnreadRune() error { + r.buf.Backup() + return nil +} diff --git a/internal/ring/rune_reader_test.go b/internal/ring/rune_reader_test.go new file mode 100644 index 00000000..b6136c0c --- /dev/null +++ b/internal/ring/rune_reader_test.go @@ -0,0 +1,38 @@ +package ring + +import ( + "io" + "reflect" + "strings" + "testing" +) + +func TestRuneReader_ReadRune(t *testing.T) { + tests := []struct { + title string + str string + size int + r rune + n int + err error + }{ + {title: "EOF", str: "", size: 0, r: 0, n: 0, err: io.EOF}, + {title: "ok", str: "foo", size: 2, r: 'f', n: 1}, + } + + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + rr := NewRuneReader(strings.NewReader(tt.str), tt.size) + r, n, err := rr.ReadRune() + if !reflect.DeepEqual(err, tt.err) { + t.Errorf("ReadRune() error = %v, wantErr %v", err, tt.err) + } + if r != tt.r { + t.Errorf("ReadRune() r = %v, want 
%v", r, tt.r) + } + if n != tt.n { + t.Errorf("ReadRune() n = %v, want %v", n, tt.n) + } + }) + } +} diff --git a/lexer.go b/lexer.go new file mode 100644 index 00000000..37226785 --- /dev/null +++ b/lexer.go @@ -0,0 +1,884 @@ +package prolog + +import ( + "bytes" + "fmt" + "io" + "strings" + "unicode" + "unicode/utf8" + "unsafe" + + "github.com/ichiban/prolog/v2/internal/ring" +) + +// lexer turns runes into tokens. +type lexer struct { + input *ring.RuneReader + charConversions map[rune]rune + + buf bytes.Buffer + offset int +} + +// Token returns the next token. +func (l *lexer) Token() (token, error) { + l.offset = l.buf.Len() + return l.layoutTextSequence(false) +} + +func (l *lexer) next() (rune, error) { + r, err := l.rawNext() + return l.conv(r), err +} + +func (l *lexer) rawNext() (rune, error) { + r, _, err := l.input.ReadRune() + return r, err +} + +func (l *lexer) conv(r rune) rune { + if r, ok := l.charConversions[r]; ok { + return r + } + return r +} + +func (l *lexer) backup() { + _ = l.input.UnreadRune() +} + +func (l *lexer) accept(r rune) { + _, _ = l.buf.WriteRune(r) +} + +func (l *lexer) chunk() string { + b := l.buf.Bytes()[l.offset:] + return *(*string)(unsafe.Pointer(&b)) +} + +// token is a smallest meaningful unit of prolog program. +type token struct { + kind tokenKind + val string +} + +func (t token) String() string { + return fmt.Sprintf("%s(%s)", t.kind.String(), t.val) +} + +// tokenKind is a type of token. +type tokenKind byte + +const ( + // tokenInvalid represents an invalid token. + tokenInvalid tokenKind = iota + + // tokenLetterDigit represents a letter digit token. + tokenLetterDigit + + // tokenGraphic represents a graphical token. + tokenGraphic + + // tokenQuoted represents a quoted token. + tokenQuoted + + // tokenSemicolon represents a semicolon token. + tokenSemicolon + + // tokenCut represents a cut token. + tokenCut + + // tokenVariable represents a variable token. 
+ tokenVariable + + // tokenInteger represents an integer token. + tokenInteger + + // tokenFloatNumber represents a floating-point token. + tokenFloatNumber + + // tokenDoubleQuotedList represents a double-quoted string. + tokenDoubleQuotedList + + // tokenOpen represents an open parenthesis. + tokenOpen + + // tokenOpenCT represents an open CT parenthesis. + tokenOpenCT + + // tokenClose represents a close parenthesis. + tokenClose + + // tokenOpenList represents an open bracket. + tokenOpenList + + // tokenCloseList represents a close bracket. + tokenCloseList + + // tokenOpenCurly represents an open brace. + tokenOpenCurly + + // tokenCloseCurly represents a close brace. + tokenCloseCurly + + // tokenBar represents a bar. + tokenBar + + // tokenComma represents a comma. + tokenComma + + // tokenEnd represents a period. + tokenEnd +) + +// GoString returns a string representation of tokenKind. +func (k tokenKind) GoString() string { + return k.String() +} + +func (k tokenKind) String() string { + return [...]string{ + tokenInvalid: "invalid", + tokenLetterDigit: "letter digit", + tokenGraphic: "graphic", + tokenQuoted: "quoted", + tokenSemicolon: "semicolon", + tokenCut: "cut", + tokenVariable: "variable", + tokenInteger: "integer", + tokenFloatNumber: "float number", + tokenDoubleQuotedList: "double quoted list", + tokenOpen: "open", + tokenOpenCT: "open ct", + tokenClose: "close", + tokenOpenList: "open list", + tokenCloseList: "close list", + tokenOpenCurly: "open curly", + tokenCloseCurly: "close curly", + tokenBar: "bar", + tokenComma: "comma", + tokenEnd: "end", + }[k] +} + +// Tokens + +var soloTokenKinds = [...]tokenKind{ + ';': tokenSemicolon, + '!': tokenCut, + ')': tokenClose, + '[': tokenOpenList, + ']': tokenCloseList, + '{': tokenOpenCurly, + '}': tokenCloseCurly, + '|': tokenBar, + ',': tokenComma, +} + +func (l *lexer) token(afterLayout bool) (token, error) { + switch r, err := l.next(); { + case err != nil: + return token{}, err + case 
isSmallLetterChar(r): + l.accept(r) + return l.letterDigitToken() + case r == '.': + l.accept(r) + if l.wasEndChar() { + return token{kind: tokenEnd, val: l.chunk()}, nil + } + return l.graphicToken() + case isGraphicChar(r), r == '\\': + l.accept(r) + return l.graphicToken() + case r == '\'': + l.accept(r) + return l.quotedToken() + case r == '_', isCapitalLetterChar(r): + l.accept(r) + return l.variableToken() + case isDecimalDigitChar(r): + return l.integerToken(r) + case r == '"': + l.accept(r) + return l.doubleQuotedListToken() + case r == '(': + l.accept(r) + if afterLayout { + return token{kind: tokenOpen, val: l.chunk()}, nil + } + return token{kind: tokenOpenCT, val: l.chunk()}, nil + default: + k := tokenInvalid + if int(r) < len(soloTokenKinds) { + k = soloTokenKinds[r] + } + l.accept(r) + return token{kind: k, val: l.chunk()}, nil + } +} + +func (l *lexer) wasEndChar() bool { + r, err := l.next() + if err != nil { + return true + } + l.backup() + return isLayoutChar(r) || r == '%' +} + +//// Layout text + +func (l *lexer) layoutTextSequence(afterLayout bool) (token, error) { + for { + switch r, err := l.next(); { + case err == io.EOF: + return l.token(afterLayout) + case err != nil: + return token{}, err + case isLayoutChar(r): + afterLayout = true + continue + case r == '%': + return l.commentText(false) + case r == '/': + return l.commentOpen() + default: + l.backup() + return l.token(afterLayout) + } + } +} + +func (l *lexer) commentText(bracketed bool) (token, error) { + for { + switch r, err := l.next(); { + case err != nil: + return token{}, err + case bracketed && r == '*': + return l.commentClose() + case !bracketed && r == '\n': + return l.layoutTextSequence(true) + } + } +} + +func (l *lexer) commentOpen() (token, error) { + switch r, err := l.next(); { + case err == io.EOF: + l.accept('/') + return l.graphicToken() + case err != nil: + return token{}, err + case r == '*': + return l.commentText(true) + default: + l.backup() + l.accept('/') + 
return l.graphicToken() + } +} + +func (l *lexer) commentClose() (token, error) { + switch r, err := l.next(); { + case err != nil: + return token{}, err + case r == '/': + return l.layoutTextSequence(true) + case r == '*': + return l.commentClose() + default: + return l.commentText(true) + } +} + +//// Names + +func (l *lexer) letterDigitToken() (token, error) { + for { + switch r, err := l.next(); { + case err == io.EOF: + return token{kind: tokenLetterDigit, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isAlphanumericChar(r): + l.accept(r) + default: + l.backup() + return token{kind: tokenLetterDigit, val: l.chunk()}, nil + } + } +} + +func (l *lexer) graphicToken() (token, error) { + for { + switch r, err := l.next(); { + case err == io.EOF: + return token{kind: tokenGraphic, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isGraphicChar(r), r == '\\': + l.accept(r) + default: + l.backup() + return token{kind: tokenGraphic, val: l.chunk()}, nil + } + } +} + +func (l *lexer) quotedToken() (token, error) { + for { + switch r, err := l.rawNext(); { + case err != nil: + return token{}, err + case isSingleQuotedCharacter(r): + l.accept(r) + continue + case r == '\'': + l.accept(r) + switch r, err := l.rawNext(); { + case err == io.EOF: + break + case err != nil: + return token{}, err + case r == '\'': + l.accept(r) + continue + default: + l.backup() + } + + s := l.chunk() + + // Checks if it contains invalid octal or hexadecimal escape sequences. 
+ if strings.ContainsRune(unquote(s), utf8.RuneError) { + return token{kind: tokenInvalid, val: s}, nil + } + + return token{kind: tokenQuoted, val: s}, nil + case r == '\\': + l.accept(r) + switch r, err := l.rawNext(); { + case err == io.EOF: + break + case err != nil: + return token{}, err + case r == '\n': + l.accept(r) + continue + default: + l.backup() + } + + return l.escapeSequence(l.quotedToken) + default: + l.accept(r) + return token{kind: tokenInvalid, val: l.chunk()}, nil + } + } +} + +func (l *lexer) escapeSequence(cont func() (token, error)) (token, error) { + switch r, err := l.rawNext(); { + case err != nil: + return token{}, err + case isMetaChar(r), isSymbolicControlChar(r): + l.accept(r) + return cont() + case isOctalDigitChar(r): + l.accept(r) + return l.octalEscapeSequence(cont) + case r == 'x': + l.accept(r) + return l.hexadecimalEscapeSequence(cont) + default: + l.accept(r) + return token{kind: tokenInvalid, val: l.chunk()}, nil + } +} + +func (l *lexer) octalEscapeSequence(cont func() (token, error)) (token, error) { + for { + switch r, err := l.rawNext(); { + case err != nil: + return token{}, err + case r == '\\': + l.accept(r) + return cont() + case isOctalDigitChar(r): + l.accept(r) + continue + default: + l.accept(r) + return token{kind: tokenInvalid, val: l.chunk()}, nil + } + } +} + +func (l *lexer) hexadecimalEscapeSequence(cont func() (token, error)) (token, error) { + switch r, err := l.rawNext(); { + case err != nil: + return token{}, err + case isHexadecimalDigitChar(r): + l.accept(r) + default: + l.accept(r) + return token{kind: tokenInvalid, val: l.chunk()}, nil + } + + for { + switch r, err := l.next(); { + case err != nil: + return token{}, err + case r == '\\': + l.accept(r) + return cont() + case isHexadecimalDigitChar(r): + l.accept(r) + continue + default: + l.accept(r) + return token{kind: tokenInvalid, val: l.chunk()}, nil + } + } +} + +//// Variables + +func (l *lexer) variableToken() (token, error) { + for { + switch 
r, err := l.next(); { + case err == io.EOF: + return token{kind: tokenVariable, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isAlphanumericChar(r): + l.accept(r) + default: + l.backup() + return token{kind: tokenVariable, val: l.chunk()}, nil + } + } +} + +//// Integer numbers + +func (l *lexer) integerToken(first rune) (token, error) { + switch first { + case '0': + l.accept(first) + switch r, err := l.next(); { + case err == io.EOF: + return l.integerConstant() + case err != nil: + return token{}, err + case r == '\'': + return l.integerTokenCharacterCode(r) + case r == 'b': + return l.integerTokenBinary(r) + case r == 'o': + return l.integerTokenOctal(r) + case r == 'x': + return l.integerTokenHexadecimal(r) + default: + l.backup() + return l.integerConstant() + } + default: + l.accept(first) + return l.integerConstant() + } +} + +func (l *lexer) integerTokenCharacterCode(r rune) (token, error) { + switch r, err := l.next(); { + case err == io.EOF: + break + case err != nil: + return token{}, err + case r == '\'': + switch r, err := l.next(); { + case err == io.EOF: + l.backup() + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil // 0 + case err != nil: + return token{}, err + case r == '\'': // 0''' + l.backup() + l.backup() + default: + l.backup() + l.backup() + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil // 0 + } + case r == '\\': + switch r, err := l.next(); { + case err == io.EOF: + l.backup() + case err != nil: + return token{}, err + case r == '\n': + l.backup() + l.backup() + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil // 0 + default: + l.backup() + l.backup() + } + default: + l.backup() + } + l.accept(r) + return l.characterCodeConstant() +} + +func (l *lexer) integerTokenBinary(r rune) (token, error) { + switch r, err := l.next(); { + case err == io.EOF: + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil + case err != nil: + return token{}, err + 
case isBinaryDigitChar(r): + l.backup() + default: + l.backup() + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil + } + l.accept(r) + return l.binaryConstant() +} + +func (l *lexer) integerTokenOctal(r rune) (token, error) { + switch r, err := l.next(); { + case err == io.EOF: + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isOctalDigitChar(r): + l.backup() + default: + l.backup() + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil + } + l.accept(r) + return l.octalConstant() +} + +func (l *lexer) integerTokenHexadecimal(r rune) (token, error) { + switch r, err := l.next(); { + case err == io.EOF: + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isHexadecimalDigitChar(r): + l.backup() + default: + l.backup() + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil + } + l.accept(r) + return l.hexadecimalConstant() +} + +func (l *lexer) integerConstant() (token, error) { + for { + switch r, err := l.next(); { + case err == io.EOF: + return token{kind: tokenInteger, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isDecimalDigitChar(r): + l.accept(r) + case r == '.': + switch r, err := l.next(); { + case err == io.EOF: + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isDecimalDigitChar(r): + l.accept('.') + l.accept(r) + return l.fraction() + default: + l.backup() + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil + } + default: + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil + } + } +} + +func (l *lexer) characterCodeConstant() (token, error) { + switch r, err := l.next(); { + case err != nil: + return token{}, err + case r == '\'': + l.accept(r) + r, _ := l.next() // r == '\'' + l.accept(r) + return token{kind: tokenInteger, val: l.chunk()}, nil + case r == 
'\\': + l.accept(r) + return l.escapeSequence(func() (token, error) { + return token{kind: tokenInteger, val: l.chunk()}, nil + }) + case isGraphicChar(r), isAlphanumericChar(r), isSoloChar(r), r == ' ': + l.accept(r) + return token{kind: tokenInteger, val: l.chunk()}, nil + default: + l.accept(r) + return token{kind: tokenInvalid, val: l.chunk()}, nil + } +} + +func (l *lexer) binaryConstant() (token, error) { + for { + switch r, err := l.next(); { + case err == io.EOF: + return token{kind: tokenInteger, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isBinaryDigitChar(r): + l.accept(r) + default: + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil + } + } +} + +func (l *lexer) octalConstant() (token, error) { + for { + switch r, err := l.next(); { + case err == io.EOF: + return token{kind: tokenInteger, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isOctalDigitChar(r): + l.accept(r) + default: + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil + } + } +} + +func (l *lexer) hexadecimalConstant() (token, error) { + for { + switch r, err := l.next(); { + case err == io.EOF: + return token{kind: tokenInteger, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isHexadecimalDigitChar(r): + l.accept(r) + default: + l.backup() + return token{kind: tokenInteger, val: l.chunk()}, nil + } + } +} + +//// Floating point numbers + +func (l *lexer) fraction() (token, error) { + for { + switch r, err := l.next(); { + case err == io.EOF: + return token{kind: tokenFloatNumber, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isDecimalDigitChar(r): + l.accept(r) + case isExponentChar(r): + var sign rune + switch r, err := l.next(); { + case err == io.EOF: + l.backup() // for 'e' or 'E' + return token{kind: tokenFloatNumber, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isSignChar(r): + sign = r + default: + l.backup() + } + + switch r, err := 
l.next(); { + case err == io.EOF: + if sign != 0 { + l.backup() + } + l.backup() // for 'e' or 'E' + return token{kind: tokenFloatNumber, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isDecimalDigitChar(r): + l.backup() + break + default: + l.backup() + if sign != 0 { + l.backup() + } + l.backup() // for 'e' or 'E' + return token{kind: tokenFloatNumber, val: l.chunk()}, nil + } + + l.accept(r) // 'e' or 'E' + if sign != 0 { + l.accept(sign) + } + return l.exponent() + default: + l.backup() + return token{kind: tokenFloatNumber, val: l.chunk()}, nil + } + } +} + +func (l *lexer) exponent() (token, error) { + for { + switch r, err := l.next(); { + case err == io.EOF: + return token{kind: tokenFloatNumber, val: l.chunk()}, nil + case err != nil: + return token{}, err + case isDecimalDigitChar(r): + l.accept(r) + default: + l.backup() + return token{kind: tokenFloatNumber, val: l.chunk()}, nil + } + } +} + +//// Double quoted lists + +func (l *lexer) doubleQuotedListToken() (token, error) { + for { + switch r, err := l.rawNext(); { + case err != nil: + return token{}, err + case r == '"': + l.accept(r) + switch r, err := l.next(); { + case err == io.EOF: + return token{kind: tokenDoubleQuotedList, val: l.chunk()}, nil + case err != nil: + return token{}, err + case r == '"': + l.accept(r) + default: + l.backup() + return token{kind: tokenDoubleQuotedList, val: l.chunk()}, nil + } + case r == '\\': + l.accept(r) + switch r, err := l.next(); { + case err != nil: + return token{}, err + case r == '\n': + l.accept(r) + default: + l.backup() + return l.escapeSequence(l.doubleQuotedListToken) + } + default: + l.accept(r) + } + } +} + +// Characters + +func isGraphicChar(r rune) bool { + return strings.ContainsRune(`#$&*+-./:<=>?@^~`, r) || unicode.In(r, &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x2200, Hi: 0x22FF, Stride: 1}, // Mathematical Operators + {Lo: 0x2A00, Hi: 0x2AFF, Stride: 1}, // Supplemental Mathematical Operators + }, + }) +} 
// isAlphanumericChar reports whether r is an alpha char (underscore or
// letter) or an ASCII decimal digit.
func isAlphanumericChar(r rune) bool {
	return isAlphaChar(r) || isDecimalDigitChar(r)
}

// isAlphaChar reports whether r is an underscore or a letter.
func isAlphaChar(r rune) bool {
	return isUnderscoreChar(r) || isLetterChar(r)
}

// isLetterChar reports whether r is a capital or a small letter.
func isLetterChar(r rune) bool {
	return isCapitalLetterChar(r) || isSmallLetterChar(r)
}

// isSmallLetterChar reports whether r belongs to the Unicode categories Ll,
// Lo or Lm. Letters without case (Lo, Lm) are therefore treated as small
// letters.
func isSmallLetterChar(r rune) bool {
	return unicode.In(r, unicode.Ll, unicode.Lo, unicode.Lm)
}

// isCapitalLetterChar reports whether r is an upper case letter according to
// unicode.IsUpper.
func isCapitalLetterChar(r rune) bool {
	return unicode.IsUpper(r)
}

// isDecimalDigitChar reports whether r is one of the ASCII digits 0-9.
func isDecimalDigitChar(r rune) bool {
	return strings.ContainsRune(`0123456789`, r)
}

// isBinaryDigitChar reports whether r is '0' or '1'.
func isBinaryDigitChar(r rune) bool {
	return strings.ContainsRune(`01`, r)
}

// isOctalDigitChar reports whether r is one of the ASCII digits 0-7.
func isOctalDigitChar(r rune) bool {
	return strings.ContainsRune("01234567", r)
}

// isHexadecimalDigitChar reports whether r is a hexadecimal digit. r is
// upper-cased before the lookup, so both a-f and A-F are accepted.
func isHexadecimalDigitChar(r rune) bool {
	return strings.ContainsRune("0123456789ABCDEF", unicode.ToUpper(r))
}

// isSoloChar reports whether r is one of the characters !(),;[]{}|%.
func isSoloChar(r rune) bool {
	return strings.ContainsRune(`!(),;[]{}|%`, r)
}

// isUnderscoreChar reports whether r is '_'.
func isUnderscoreChar(r rune) bool {
	return r == '_'
}

// isLayoutChar reports whether r is white space according to unicode.IsSpace.
func isLayoutChar(r rune) bool {
	return unicode.IsSpace(r)
}

// isMetaChar reports whether r is a backslash, a single quote, a double quote
// or a back quote.
func isMetaChar(r rune) bool {
	return strings.ContainsRune("\\'\"`", r)
}

// isSymbolicControlChar reports whether r is one of the letters a, b, r, f,
// t, n or v, i.e. the letter of a symbolic control escape such as \n.
func isSymbolicControlChar(r rune) bool {
	return strings.ContainsRune(`abrftnv`, r)
}

// isSingleQuotedCharacter reports whether r is accepted as-is inside a
// single-quoted token: a graphic, alphanumeric or solo char, a space, a
// double quote or a back quote. Single quotes and backslashes are excluded;
// quotedToken handles them separately.
func isSingleQuotedCharacter(r rune) bool {
	return isGraphicChar(r) || isAlphanumericChar(r) || isSoloChar(r) || r == ' ' || r == '"' || r == '`'
}

// isExponentChar reports whether r introduces a float exponent ('e' or 'E').
func isExponentChar(r rune) bool {
	return r == 'e' || r == 'E'
}

// isSignChar reports whether r is '-' or '+'.
func isSignChar(r rune) bool {
	return r == '-' || r == '+'
}
diff --git a/operator.go b/operator.go
new file mode 100644
index 00000000..405e46e5
--- /dev/null
+++ b/operator.go
@@ -0,0 +1,155 @@
package prolog

import (
	"math"
)

// Operators is a set of defined operators.
// Entries are keyed by operator name and operator class (see opKey), so the
// same name can be defined once per class.
type Operators struct {
	ops map[opKey]operator
}

// Define defines an operator.
+func (o *Operators) Define(priority int, spec OperatorSpecifier, name string) { + if o.ops == nil { + o.ops = map[opKey]operator{} + } + o.ops[opKey{ + name: name, + opClass: operatorSpecifiers[spec].opClass, + }] = operator{ + priority: priority, + specifier: spec, + name: name, + } +} + +func (o *Operators) definedIn(name string, opClass operatorClass) bool { + _, ok := o.ops[opKey{name: name, opClass: opClass}] + return ok +} + +func (o *Operators) defined(name string) bool { + return o.definedIn(name, operatorClassPrefix) || + o.definedIn(name, operatorClassPostfix) || + o.definedIn(name, operatorClassInfix) +} + +type opKey struct { + name string + opClass operatorClass +} + +type operatorClass int8 + +const ( + operatorClassPrefix operatorClass = iota + operatorClassPostfix + operatorClassInfix +) + +var operatorClasses = [...]struct { + arity int +}{ + operatorClassPrefix: { + arity: 1, + }, + operatorClassPostfix: { + arity: 1, + }, + operatorClassInfix: { + arity: 2, + }, +} + +// OperatorSpecifier specifies a class and associativity of an operator. 
+type OperatorSpecifier int8 + +const ( + FX OperatorSpecifier = iota + FY + XF + YF + XFX + XFY + YFX +) + +var operatorSpecifiers = [...]struct { + name string + opClass operatorClass + priorities func(p int) (left int, right int) +}{ + FX: { + name: "fx", + opClass: operatorClassPrefix, + priorities: func(p int) (left int, right int) { + return math.MaxInt, p - 1 + }, + }, + FY: { + name: "fy", + opClass: operatorClassPrefix, + priorities: func(p int) (left int, right int) { + return math.MaxInt, p + }, + }, + XF: { + name: "xf", + opClass: operatorClassPostfix, + priorities: func(p int) (left int, right int) { + return p - 1, math.MaxInt + }, + }, + YF: { + name: "yf", + opClass: operatorClassPostfix, + priorities: func(p int) (left int, right int) { + return p, math.MaxInt + }, + }, + XFX: { + name: "xfx", + opClass: operatorClassInfix, + priorities: func(p int) (left int, right int) { + return p - 1, p - 1 + }, + }, + XFY: { + name: "xFy", + opClass: operatorClassInfix, + priorities: func(p int) (left int, right int) { + return p - 1, p + }, + }, + YFX: { + name: "yFx", + opClass: operatorClassInfix, + priorities: func(p int) (left int, right int) { + return p, p - 1 + }, + }, +} + +func (s OperatorSpecifier) class() operatorClass { + return operatorSpecifiers[s].opClass +} + +func (s OperatorSpecifier) String() string { + return operatorSpecifiers[s].name +} + +func (s OperatorSpecifier) arity() int { + return operatorClasses[operatorSpecifiers[s].opClass].arity +} + +type operator struct { + priority int // 1 ~ 1200 + specifier OperatorSpecifier + name string +} + +// Pratt parser's binding powers but in Prolog priority. 
// It returns the (left, right) pair that the specifier's priorities function
// computes from o.priority.
func (o *operator) bindingPriorities() (int, int) {
	return operatorSpecifiers[o.specifier].priorities(o.priority)
}
diff --git a/parser.go b/parser.go
new file mode 100644
index 00000000..1f29681a
--- /dev/null
+++ b/parser.go
@@ -0,0 +1,830 @@
package prolog

import (
	"fmt"
	"io"
	"math/big"
	"regexp"
	"strconv"
	"strings"

	"github.com/ichiban/prolog/v2/internal/ring"
)

// Parser turns bytes into Term.
type Parser struct {
	lexer        lexer        // token source
	operators    Operators    // operator table consulted while parsing
	doubleQuotes doubleQuotes // how double-quoted text is interpreted

	// buf holds tokens already read from the lexer so that backup can
	// hand them out again.
	buf *ring.Buffer[token]
}

// ParsedVariable is a set of information regarding a variable in a parsed term.
type ParsedVariable struct {
	Name     string   // name as written in the source text
	Variable Variable // variable created for that name
	Count    int      // number of occurrences of the name in the term
}

// NewParser creates a new parser that reads runes from r, using ops as its
// operator table and doubleQuotes to decide how double-quoted text is read.
// NOTE(review): the 4 passed to ring.NewRuneReader and ring.NewBuffer is
// presumably the look-back capacity in runes / tokens; confirm against the
// ring package.
func NewParser(r io.RuneReader, ops Operators, doubleQuotes doubleQuotes) *Parser {
	return &Parser{
		lexer: lexer{
			input: ring.NewRuneReader(r, 4),
		},
		operators:    ops,
		doubleQuotes: doubleQuotes,
		buf:          ring.NewBuffer[token](4),
	}
}

// next returns the next token. A new token is lexed only when the buffer
// reports that it is empty; otherwise a buffered token is handed out again.
func (p *Parser) next() (token, error) {
	if p.buf.Empty() {
		t, err := p.lexer.Token()
		if err != nil {
			return token{}, err
		}
		p.buf.Put(t)
	}
	return p.buf.Get(), nil
}

// backup steps the token buffer back by one token so that the token is
// returned again by a later call to next.
func (p *Parser) backup() {
	p.buf.Backup()
}

// current returns the token the buffer currently points at without consuming
// anything.
func (p *Parser) current() token {
	return p.buf.Current()
}

// Term parses a term followed by a full stop.
+func (p *Parser) Term(h *Heap) (_ Term, _ []ParsedVariable, err error) { + snapshot := *h + defer func() { + if err != nil { + *h = snapshot + } + }() + + var pvs []ParsedVariable + t, ok, err := p.term(h, &pvs, 1201) + if err != nil { + return Term{}, nil, err + } + if !ok { + return Term{}, nil, &SyntaxError{impDepAtom: fmt.Sprintf("unexpected token: %s", p.current())} + } + + switch t, _ := p.next(); t.kind { + case tokenEnd: + break + default: + p.backup() + return Term{}, nil, &SyntaxError{impDepAtom: fmt.Sprintf("unexpected token: %s", p.current())} + } + + return t, pvs, nil +} + +// Number parses a number term. +func (p *Parser) Number(h *Heap) (_ Term, err error) { + snapshot := *h + defer func() { + if err != nil { + *h = snapshot + } + }() + + var n Term + t, err := p.next() + if err != nil { + return Term{}, err + } + switch t.kind { + case tokenInteger: + n, err = integer(h, 1, t.val) + case tokenFloatNumber: + n, err = float(h, 1, t.val) + default: + p.backup() + var ( + a string + ok bool + ) + a, ok, err = p.name() + if err != nil { + return Term{}, err + } + if !ok { + return Term{}, &SyntaxError{impDepAtom: "not_a_number"} + } + + if a != "-" { + p.backup() + return Term{}, &SyntaxError{impDepAtom: "not_a_number"} + } + + t, err = p.next() + if err != nil { + return Term{}, &SyntaxError{impDepAtom: "not_a_number"} + } + switch t.kind { + case tokenInteger: + n, err = integer(h, -1, t.val) + case tokenFloatNumber: + n, err = float(h, -1, t.val) + default: + p.backup() + p.backup() + return Term{}, &SyntaxError{impDepAtom: "not_a_number"} + } + } + if err != nil { + return Term{}, err + } + + // No more runes after a number. + switch _, err := p.lexer.rawNext(); err { + case io.EOF: + return n, nil + default: + return Term{}, &SyntaxError{impDepAtom: "not_a_number"} + } +} + +// More checks if the parser has more tokens to read. 
+func (p *Parser) More() bool { + if _, err := p.next(); err != nil { + return false + } + p.backup() + return true +} + +type doubleQuotes int + +const ( + doubleQuotesChars doubleQuotes = iota + doubleQuotesCodes + doubleQuotesAtom +) + +var doubleQuoteNames = [...]string{ + doubleQuotesCodes: "codes", + doubleQuotesChars: "chars", + doubleQuotesAtom: "atom", +} + +func (d doubleQuotes) String() string { + return doubleQuoteNames[d] +} + +// Loosely based on Pratt parser explained in this article: https://matklad.github.io/2020/04/13/simple-but-powerful-pratt-parsing.html +func (p *Parser) term(h *Heap, pvs *[]ParsedVariable, maxPriority int) (Term, bool, error) { + var lhs Term + switch op, ok, err := p.prefix(maxPriority); { + case err != nil: + return Term{}, false, err + case !ok: + lhs, ok, err = p.term0(h, pvs, maxPriority) + if err != nil || !ok { + return Term{}, ok, err + } + default: + _, rbp := op.bindingPriorities() + t, ok, err := p.term(h, pvs, rbp) + if err != nil { + return Term{}, false, err + } + if !ok { + p.backup() + return p.term0(h, pvs, maxPriority) + } + lhs, err = NewCompound(h, op.name, t) + if err != nil { + return Term{}, false, err + } + } + + for { + op, ok, err := p.infix(maxPriority) + if err != nil { + return Term{}, false, err + } + if !ok { + break + } + + switch _, rbp := op.bindingPriorities(); { + case rbp > 1200: + var err error + lhs, err = NewCompound(h, op.name, lhs) + if err != nil { + return Term{}, false, err + } + default: + rhs, ok, err := p.term(h, pvs, rbp) + if err != nil || !ok { + return Term{}, ok, err + } + lhs, err = NewCompound(h, op.name, lhs, rhs) + if err != nil { + return Term{}, false, err + } + } + } + + return lhs, true, nil +} + +func (p *Parser) prefix(maxPriority int) (operator, bool, error) { + a, ok, err := p.op(maxPriority) + if err != nil || !ok { + return operator{}, ok, err + } + + if a == "-" { + t, err := p.next() + if err != nil { + return operator{}, false, err + } + switch t.kind { + 
case tokenInteger, tokenFloatNumber: + p.backup() + p.backup() + return operator{}, false, nil + default: + p.backup() + } + } + + t, err := p.next() + if err != nil { + return operator{}, false, err + } + switch t.kind { + case tokenOpenCT: + p.backup() + p.backup() + return operator{}, false, nil + default: + p.backup() + } + + op, ok := p.operators.ops[opKey{name: a, opClass: operatorClassPrefix}] + if !ok || op.priority > maxPriority { + p.backup() + return operator{}, false, nil + } + return op, true, nil +} + +func (p *Parser) infix(maxPriority int) (operator, bool, error) { + a, ok, err := p.op(maxPriority) + if err != nil || !ok { + return operator{}, ok, err + } + + if op := p.operators.ops[opKey{name: a, opClass: operatorClassInfix}]; op != (operator{}) { + l, _ := op.bindingPriorities() + if l <= maxPriority { + return op, true, nil + } + } + if op := p.operators.ops[opKey{name: a, opClass: operatorClassPostfix}]; op != (operator{}) { + l, _ := op.bindingPriorities() + if l <= maxPriority { + return op, true, nil + } + } + + p.backup() + return operator{}, false, nil +} + +func (p *Parser) op(maxPriority int) (string, bool, error) { + a, ok, err := p.atom() + if err != nil { + return "", false, err + } + if ok { + switch a { + case "[]": + p.backup() + if p.current().kind == tokenCloseList { + p.backup() + } + return "", false, nil + case "{}": + p.backup() + if p.current().kind == tokenCloseCurly { + p.backup() + } + return "", false, nil + default: + return a, true, nil + } + } + + t, err := p.next() + if err != nil { + return "", false, err + } + switch t.kind { + case tokenComma: + if maxPriority >= 1000 { + return t.val, true, nil + } + case tokenBar: + return t.val, true, nil + default: + break + } + + p.backup() + return "", false, nil +} + +func (p *Parser) term0(h *Heap, pvs *[]ParsedVariable, maxPriority int) (Term, bool, error) { + t, err := p.next() + if err != nil { + return Term{}, false, err + } + switch t.kind { + case tokenOpen, 
tokenOpenCT: + return p.openClose(h, pvs) + case tokenInteger: + i, err := integer(h, 1, t.val) + if err != nil { + return Term{}, false, err + } + return i, true, nil + case tokenFloatNumber: + f, err := float(h, 1, t.val) + if err != nil { + return Term{}, false, err + } + return f, true, nil + case tokenVariable: + v, err := p.variable(h, pvs, t.val) + if err != nil { + return Term{}, false, err + } + return v, true, nil + case tokenOpenList: + if t, _ := p.next(); t.kind == tokenCloseList { + p.backup() + p.backup() + break + } + p.backup() + return p.list(h, pvs) + case tokenOpenCurly: + if t, _ := p.next(); t.kind == tokenCloseCurly { + p.backup() + p.backup() + break + } + p.backup() + return p.curlyBracketedTerm(h, pvs) + case tokenDoubleQuotedList: + switch p.doubleQuotes { + case doubleQuotesChars: + cl, err := NewCharList(h, unDoubleQuote(t.val)) + if err != nil { + return Term{}, false, err + } + return cl, true, nil + case doubleQuotesCodes: + cl, err := NewCodeList(h, unDoubleQuote(t.val)) + if err != nil { + return Term{}, false, err + } + return cl, true, nil + default: + p.backup() + break + } + default: + p.backup() + } + + return p.term0Atom(h, pvs, maxPriority) +} + +func (p *Parser) term0Atom(h *Heap, pvs *[]ParsedVariable, maxPriority int) (Term, bool, error) { + a, ok, err := p.atom() + if err != nil || !ok { + return Term{}, ok, err + } + + if a == "-" { + t, err := p.next() + if err != nil { + return Term{}, false, err + } + switch t.kind { + case tokenInteger: + i, err := integer(h, -1, t.val) + if err != nil { + return Term{}, false, err + } + return i, true, nil + case tokenFloatNumber: + f, err := float(h, -1, t.val) + if err != nil { + return Term{}, false, err + } + return f, true, nil + default: + p.backup() + } + } + + t, ok, err := p.functionalNotation(h, pvs, a) + if err != nil || !ok { + return Term{}, ok, err + } + + // 6.3.1.3 An atom which is an operator shall not be the immediate operand (3.120) of an operator. 
+ if a, err := t.Atom(h); err == nil && maxPriority < 1201 && p.operators.defined(a) { + p.backup() + return Term{}, false, nil + } + + return t, true, nil +} + +func (p *Parser) variable(h *Heap, pvs *[]ParsedVariable, s string) (Term, error) { + if s == "_" { + v, err := NewVariable(h) + return v, err + } + for i, pv := range *pvs { + if pv.Name == s { + (*pvs)[i].Count++ + return Term{tag: termTagVariable, payload: int32(pv.Variable)}, nil + } + } + v, err := NewVariable(h) + if err != nil { + return Term{}, err + } + *pvs = append(*pvs, ParsedVariable{Name: s, Variable: Variable(v.payload), Count: 1}) + return v, nil +} + +func (p *Parser) openClose(h *Heap, pvs *[]ParsedVariable) (Term, bool, error) { + t, ok, err := p.term(h, pvs, 1201) + if err != nil || !ok { + return Term{}, ok, err + } + if t, _ := p.next(); t.kind != tokenClose { + p.backup() + return Term{}, false, nil + } + return t, true, nil +} + +func (p *Parser) atom() (string, bool, error) { + if a, ok, err := p.name(); err != nil || ok { + return a, ok, err + } + + t, err := p.next() + if err != nil { + return "", false, err + } + switch t.kind { + case tokenOpenList: + t, err := p.next() + if err != nil { + return "", false, err + } + switch t.kind { + case tokenCloseList: + return "[]", true, nil + default: + p.backup() + p.backup() + return "", false, nil + } + case tokenOpenCurly: + t, err := p.next() + if err != nil { + return "", false, err + } + switch t.kind { + case tokenCloseCurly: + return "{}", true, nil + default: + p.backup() + p.backup() + return "", false, nil + } + case tokenDoubleQuotedList: + switch p.doubleQuotes { + case doubleQuotesAtom: + return unDoubleQuote(t.val), true, nil + default: + p.backup() + return "", false, nil + } + default: + p.backup() + return "", false, nil + } +} + +func (p *Parser) name() (string, bool, error) { + t, err := p.next() + if err != nil { + return "", false, err + } + switch t.kind { + case tokenLetterDigit, tokenGraphic, tokenSemicolon, 
tokenCut: + return t.val, true, nil + case tokenQuoted: + return unquote(t.val), true, nil + default: + p.backup() + return "", false, nil + } +} + +func (p *Parser) list(h *Heap, pvs *[]ParsedVariable) (Term, bool, error) { + var elems []Term + arg, err := p.arg(h, pvs) + if err != nil { + return Term{}, false, err + } + elems = append(elems, arg) + for { + switch t, _ := p.next(); t.kind { + case tokenComma: + arg, err := p.arg(h, pvs) + if err != nil { + return Term{}, false, err + } + elems = append(elems, arg) + case tokenBar: + tail, err := p.arg(h, pvs) + if err != nil { + return Term{}, false, err + } + + switch t, _ := p.next(); t.kind { + case tokenCloseList: + pl, err := NewPartialList(h, tail, elems...) + if err != nil { + return Term{}, false, err + } + + return pl, true, nil + default: + p.backup() + return Term{}, false, nil + } + case tokenCloseList: + l, err := NewList(h, elems...) + if err != nil { + return Term{}, false, err + } + + return l, true, nil + default: + p.backup() + return Term{}, false, nil + } + } +} + +func (p *Parser) curlyBracketedTerm(h *Heap, pvs *[]ParsedVariable) (Term, bool, error) { + t, ok, err := p.term(h, pvs, 1201) + if err != nil || !ok { + return Term{}, ok, err + } + + if t, _ := p.next(); t.kind != tokenCloseCurly { + p.backup() + return Term{}, false, nil + } + + c, err := NewCompound(h, "{}", t) + if err != nil { + return Term{}, false, err + } + + return c, true, nil +} + +func (p *Parser) functionalNotation(h *Heap, pvs *[]ParsedVariable, functor string) (Term, bool, error) { + switch t, _ := p.next(); t.kind { + case tokenOpenCT: + arg, err := p.arg(h, pvs) + if err != nil { + return Term{}, false, err + } + args := []Term{arg} + for { + switch t, _ := p.next(); t.kind { + case tokenComma: + arg, err := p.arg(h, pvs) + if err != nil { + return Term{}, false, err + } + args = append(args, arg) + case tokenClose: + c, err := NewCompound(h, functor, args...) 
+ if err != nil { + return Term{}, false, err + } + + return c, true, nil + default: + p.backup() + return Term{}, false, nil + } + } + default: + p.backup() + a, err := NewAtom(h, functor) + if err != nil { + return Term{}, false, err + } + return a, true, nil + } +} + +func (p *Parser) arg(h *Heap, pvs *[]ParsedVariable) (Term, error) { + arg, ok, err := p.atom() + if err != nil { + return Term{}, err + } + if ok { + if p.operators.defined(arg) { + // Check if this atom is not followed by its own arguments. + switch t, _ := p.next(); t.kind { + case tokenComma, tokenClose, tokenBar, tokenCloseList: + p.backup() + a, err := NewAtom(h, arg) + if err != nil { + return Term{}, err + } + return a, nil + default: + p.backup() + } + } + p.backup() + if p.current().kind == tokenCloseList || p.current().kind == tokenCloseCurly { + p.backup() // Unquoted [] or {} consist of 2 tokens. + } + } + + t, ok, err := p.term(h, pvs, 999) + if err != nil { + return Term{}, err + } + if !ok { + return Term{}, &SyntaxError{impDepAtom: fmt.Sprintf("unexpected token: %s", p.current())} + } + return t, nil +} + +func integer(h *Heap, sign int64, s string) (Term, error) { + base := 10 + switch { + case strings.HasPrefix(s, "0'"): + s = s[2:] + s = quotedIdentEscapePattern.ReplaceAllStringFunc(s, quotedIdentUnescape) + return NewInteger(h, sign*int64([]rune(s)[0])) + case strings.HasPrefix(s, "0b"): + base = 2 + s = s[2:] + case strings.HasPrefix(s, "0o"): + base = 8 + s = s[2:] + case strings.HasPrefix(s, "0x"): + base = 16 + s = s[2:] + } + + f, _, _ := big.ParseFloat(s, base, 0, big.ToZero) + f.Mul(big.NewFloat(float64(sign)), f) + + switch i, a := f.Int64(); a { + case big.Above: + return Term{}, &RepresentationError{flag: "min_integer"} + case big.Below: + return Term{}, &RepresentationError{flag: "max_integer"} + default: + return NewInteger(h, i) + } +} + +func float(h *Heap, sign float64, s string) (Term, error) { + bf, _, _ := big.ParseFloat(s, 10, 0, big.ToZero) + 
bf.Mul(big.NewFloat(sign), bf) + + f, _ := bf.Float64() + return NewFloat(h, f) +} + +var ( + quotedIdentEscapePattern = regexp.MustCompile("''|\\\\(?:[\\nabfnrtv\\\\'\"`]|(?:x[\\da-fA-F]+|[0-8]+)\\\\)") + doubleQuotedEscapePattern = regexp.MustCompile("\"\"|\\\\(?:[\\nabfnrtv\\\\'\"`]|(?:x[\\da-fA-F]+|[0-8]+)\\\\)") +) + +func unquote(s string) string { + return quotedIdentEscapePattern.ReplaceAllStringFunc(s[1:len(s)-1], quotedIdentUnescape) +} + +func quotedIdentUnescape(s string) string { + switch s { + case "''": + return "'" + case "\\\n": + return "" + case `\a`: + return "\a" + case `\b`: + return "\b" + case `\f`: + return "\f" + case `\n`: + return "\n" + case `\r`: + return "\r" + case `\t`: + return "\t" + case `\v`: + return "\v" + case `\\`: + return `\` + case `\'`: + return `'` + case `\"`: + return `"` + case "\\`": + return "`" + default: // `\x23\` or `\23\` + s = s[1 : len(s)-1] // `x23` or `23` + base := 8 + + if s[0] == 'x' { + s = s[1:] + base = 16 + } + + r, _ := strconv.ParseInt(s, base, 4*8) // rune is up to 4 bytes + return string(rune(r)) + } +} + +func unDoubleQuote(s string) string { + return doubleQuotedEscapePattern.ReplaceAllStringFunc(s[1:len(s)-1], doubleQuotedUnescape) +} + +func doubleQuotedUnescape(s string) string { + switch s { + case `""`: + return `"` + case "\\\n": + return "" + case `\a`: + return "\a" + case `\b`: + return "\b" + case `\f`: + return "\f" + case `\n`: + return "\n" + case `\r`: + return "\r" + case `\t`: + return "\t" + case `\v`: + return "\v" + case `\\`: + return `\` + case `\'`: + return `'` + case `\"`: + return `"` + case "\\`": + return "`" + default: // `\x23\` or `\23\` + s = s[1 : len(s)-1] // `x23` or `23` + base := 8 + + if s[0] == 'x' { + s = s[1:] + base = 16 + } + + r, _ := strconv.ParseInt(s, base, 4*8) // rune is up to 4 bytes + return string(rune(r)) + } +} diff --git a/parser_test.go b/parser_test.go new file mode 100644 index 00000000..d961f336 --- /dev/null +++ b/parser_test.go @@ 
-0,0 +1,263 @@ +package prolog + +import ( + "io" + "reflect" + "strings" + "testing" +) + +func TestParser_Term(t *testing.T) { + h := NewHeap(11 * 1024) + + must := func(term Term, err error) Term { + if err != nil { + t.Fatal(err) + } + return term + } + + var ops Operators + ops.Define(1000, XFY, `,`) + ops.Define(500, YFX, `+`) + ops.Define(400, YFX, `*`) + ops.Define(200, FY, `-`) + ops.Define(200, YF, `--`) + + tests := []struct { + input string + doubleQuotes doubleQuotes + term Term + vars []ParsedVariable + err error + }{ + {input: ``, err: io.EOF}, + {input: `foo`, err: io.EOF}, + {input: `.`, err: &SyntaxError{impDepAtom: "unexpected token: end(.)"}}, + + {input: `(foo).`, term: must(NewAtom(h, "foo"))}, + {input: `(a b).`, err: &SyntaxError{impDepAtom: "unexpected token: letter digit(b)"}}, + + {input: `foo.`, term: must(NewAtom(h, "foo"))}, + {input: `[].`, term: must(NewAtom(h, "[]"))}, + {input: `[ ].`, term: must(NewAtom(h, "[]"))}, + {input: `{}.`, term: must(NewAtom(h, "{}"))}, + {input: `{ }.`, term: must(NewAtom(h, "{}"))}, + {input: `'abc'.`, term: must(NewAtom(h, "abc"))}, + {input: `'don''t panic'.`, term: must(NewAtom(h, "don't panic"))}, + {input: "'this is \\\na quoted ident'.", term: must(NewAtom(h, "this is a quoted ident"))}, + {input: `'\a'.`, term: must(NewAtom(h, "\a"))}, + {input: `'\b'.`, term: must(NewAtom(h, "\b"))}, + {input: `'\f'.`, term: must(NewAtom(h, "\f"))}, + {input: `'\n'.`, term: must(NewAtom(h, "\n"))}, + {input: `'\r'.`, term: must(NewAtom(h, "\r"))}, + {input: `'\t'.`, term: must(NewAtom(h, "\t"))}, + {input: `'\v'.`, term: must(NewAtom(h, "\v"))}, + {input: `'\43\'.`, term: must(NewAtom(h, "#"))}, + {input: `'\xa3\'.`, term: must(NewAtom(h, "£"))}, + {input: `'\\'.`, term: must(NewAtom(h, `\`))}, + {input: `'\''.`, term: must(NewAtom(h, `'`))}, + {input: `'\"'.`, term: must(NewAtom(h, `"`))}, + {input: "'\\`'.", term: must(NewAtom(h, "`"))}, + {input: `[`, err: io.EOF}, + {input: `{`, err: io.EOF}, + + {input: 
`1.`, term: must(NewInteger(h, 1))}, + {input: `0'1.`, term: must(NewInteger(h, 49))}, + {input: `0b1.`, term: must(NewInteger(h, 1))}, + {input: `0o1.`, term: must(NewInteger(h, 1))}, + {input: `0x1.`, term: must(NewInteger(h, 1))}, + {input: `-1.`, term: must(NewInteger(h, -1))}, + {input: `- 1.`, term: must(NewInteger(h, -1))}, + {input: `'-'1.`, term: must(NewInteger(h, -1))}, + {input: `9223372036854775808.`, err: &RepresentationError{flag: "max_integer"}}, + {input: `-9223372036854775809.`, err: &RepresentationError{flag: "min_integer"}}, + {input: `-`, err: io.EOF}, + {input: `- -`, err: io.EOF}, + + {input: `1.0.`, term: must(NewFloat(h, 1))}, + {input: `-1.0.`, term: must(NewFloat(h, -1))}, + {input: `- 1.0.`, term: must(NewFloat(h, -1))}, + {input: `'-'1.0.`, term: must(NewFloat(h, -1))}, + + {input: `_.`, term: Term{tag: termTagVariable, payload: 1}}, + {input: `X.`, term: Term{tag: termTagVariable, payload: 1}, vars: []ParsedVariable{ + {Name: "X", Variable: 1, Count: 1}, + }}, + + {input: `foo(a, b).`, term: must(NewCompound(h, "foo", must(NewAtom(h, "a")), must(NewAtom(h, "b"))))}, + {input: `foo(-(a)).`, term: must(NewCompound(h, "foo", must(NewCompound(h, "-", must(NewAtom(h, "a"))))))}, + {input: `foo(-).`, term: must(NewCompound(h, "foo", must(NewAtom(h, "-"))))}, + {input: `foo((), b).`, err: &SyntaxError{impDepAtom: "unexpected token: close())"}}, + {input: `foo([]).`, term: must(NewCompound(h, "foo", must(NewAtom(h, "[]"))))}, + {input: `foo(a, ()).`, err: &SyntaxError{impDepAtom: "unexpected token: close())"}}, + {input: `foo(a b).`, err: &SyntaxError{impDepAtom: "unexpected token: letter digit(b)"}}, + {input: `foo(a, b`, err: io.EOF}, + + {input: `[a, b].`, term: must(NewList(h, must(NewAtom(h, "a")), must(NewAtom(h, "b"))))}, + {input: `[(), b].`, err: &SyntaxError{impDepAtom: "unexpected token: close())"}}, + {input: `[a, ()].`, err: &SyntaxError{impDepAtom: "unexpected token: close())"}}, + {input: `[a b].`, err: &SyntaxError{impDepAtom: 
"unexpected token: letter digit(b)"}}, + {input: `[a|X].`, term: must(NewCompound(h, ".", must(NewAtom(h, "a")), Term{tag: termTagVariable, payload: 1})), vars: []ParsedVariable{ + {Name: "X", Variable: 1, Count: 1}, + }}, + {input: `[a, b|X].`, term: must(NewPartialList(h, Term{tag: termTagVariable, payload: 1}, must(NewAtom(h, "a")), must(NewAtom(h, "b")))), vars: []ParsedVariable{ + {Name: "X", Variable: 1, Count: 1}, + }}, + {input: `[a, b|()].`, err: &SyntaxError{impDepAtom: "unexpected token: close())"}}, + {input: `[a, b|c d].`, err: &SyntaxError{impDepAtom: "unexpected token: letter digit(d)"}}, + {input: `[a `, err: io.EOF}, + + {input: `{a}.`, term: must(NewCompound(h, "{}", must(NewAtom(h, "a"))))}, + {input: `{()}.`, err: &SyntaxError{impDepAtom: "unexpected token: close())"}}, + {input: `{a b}.`, err: &SyntaxError{impDepAtom: "unexpected token: letter digit(b)"}}, + + {input: `-a.`, term: must(NewCompound(h, "-", must(NewAtom(h, "a"))))}, + {input: `- .`, term: must(NewAtom(h, "-"))}, + + {input: `a-- .`, term: must(NewCompound(h, "--", must(NewAtom(h, "a"))))}, + + {input: `a + b.`, term: must(NewCompound(h, "+", must(NewAtom(h, "a")), must(NewAtom(h, "b"))))}, + {input: `a + ().`, err: &SyntaxError{impDepAtom: "unexpected token: close())"}}, + {input: `a * b + c.`, term: must(NewCompound(h, "+", must(NewCompound(h, "*", must(NewAtom(h, "a")), must(NewAtom(h, "b")))), must(NewAtom(h, "c"))))}, + {input: `a [] b.`, err: &SyntaxError{impDepAtom: "unexpected token: open list([)"}}, + {input: `a {} b.`, err: &SyntaxError{impDepAtom: "unexpected token: open curly({)"}}, + {input: `a, b.`, term: must(NewCompound(h, ",", must(NewAtom(h, "a")), must(NewAtom(h, "b"))))}, + {input: `+ * + .`, err: &SyntaxError{impDepAtom: "unexpected token: graphic(+)"}}, + + {input: `"abc".`, doubleQuotes: doubleQuotesChars, term: must(NewCharList(h, "abc"))}, + {input: `"abc".`, doubleQuotes: doubleQuotesCodes, term: must(NewCodeList(h, "abc"))}, + {input: `"abc".`, 
doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, "abc"))}, + {input: `"don""t panic".`, doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, "don\"t panic"))}, + {input: "\"this is \\\na double-quoted string\".", doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, "this is a double-quoted string"))}, + {input: `"\a".`, doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, "\a"))}, + {input: `"\b".`, doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, "\b"))}, + {input: `"\f".`, doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, "\f"))}, + {input: `"\n".`, doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, "\n"))}, + {input: `"\r".`, doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, "\r"))}, + {input: `"\t".`, doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, "\t"))}, + {input: `"\v".`, doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, "\v"))}, + {input: `"\xa3\".`, doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, "£"))}, + {input: `"\43\".`, doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, "#"))}, + {input: `"\\".`, doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, `\`))}, + {input: `"\'".`, doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, `'`))}, + {input: `"\"".`, doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, `"`))}, + {input: "\"\\`\".", doubleQuotes: doubleQuotesAtom, term: must(NewAtom(h, "`"))}, + + // https://github.com/ichiban/prolog/issues/219#issuecomment-1200489336 + {input: `write('[]').`, term: must(NewCompound(h, "write", must(NewAtom(h, "[]"))))}, + {input: `write('{}').`, term: must(NewCompound(h, "write", must(NewAtom(h, "{}"))))}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + snapshot := *h + defer func() { + *h = snapshot + }() + + p := NewParser(strings.NewReader(tt.input), ops, tt.doubleQuotes) + term, pvs, err := p.Term(h) + if !reflect.DeepEqual(err, tt.err) { + t.Errorf("expected error %q, got %q", tt.err, err) + } + if o := term.Compare(h, tt.term); o != 0 
{ + t.Errorf("expected %4q, got %4q", &Formatter{Term: tt.term, Heap: h}, &Formatter{Term: term, Heap: h}) + } + if !reflect.DeepEqual(pvs, tt.vars) { + t.Errorf("expected %v, got %v", tt.vars, pvs) + } + }) + } +} + +func TestParser_Number(t *testing.T) { + h := NewHeap(11 * 1024) + + must := func(term Term, err error) Term { + if err != nil { + t.Fatal(err) + } + return term + } + + tests := []struct { + input string + number Term + err error + }{ + {input: `33`, number: must(NewInteger(h, 33))}, + {input: `-33`, number: must(NewInteger(h, -33))}, + {input: `- 33`, number: must(NewInteger(h, -33))}, + {input: `'-'33`, number: must(NewInteger(h, -33))}, + {input: ` 33`, number: must(NewInteger(h, 33))}, + {input: `9223372036854775808.`, err: &RepresentationError{flag: "max_integer"}}, + {input: `-9223372036854775809.`, err: &RepresentationError{flag: "min_integer"}}, + + {input: `0'!`, number: must(NewInteger(h, 33))}, + {input: `-0'!`, number: must(NewInteger(h, -33))}, + {input: `- 0'!`, number: must(NewInteger(h, -33))}, + {input: `'-'0'!`, number: must(NewInteger(h, -33))}, + + {input: `0b1`, number: must(NewInteger(h, 1))}, + {input: `0o1`, number: must(NewInteger(h, 1))}, + {input: `0x1`, number: must(NewInteger(h, 1))}, + + {input: `3.3`, number: must(NewFloat(h, 3.3))}, + {input: `-3.3`, number: must(NewFloat(h, -3.3))}, + {input: `- 3.3`, number: must(NewFloat(h, -3.3))}, + {input: `'-'3.3`, number: must(NewFloat(h, -3.3))}, + + {input: ``, err: io.EOF}, + {input: `X`, err: &SyntaxError{impDepAtom: "not_a_number"}}, + {input: `33 three`, err: &SyntaxError{impDepAtom: "not_a_number"}}, + {input: `3 `, err: &SyntaxError{impDepAtom: "not_a_number"}}, + {input: `3.`, err: &SyntaxError{impDepAtom: "not_a_number"}}, + {input: `three`, err: &SyntaxError{impDepAtom: "not_a_number"}}, + {input: `-`, err: &SyntaxError{impDepAtom: "not_a_number"}}, + {input: `-a.`, err: &SyntaxError{impDepAtom: "not_a_number"}}, + {input: `()`, err: &SyntaxError{impDepAtom: 
"not_a_number"}}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + p := NewParser(strings.NewReader(tt.input), Operators{}, doubleQuotesChars) + n, err := p.Number(h) + if !reflect.DeepEqual(err, tt.err) { + t.Errorf("expected error %q, got %q", tt.err, err) + } + if o := n.Compare(h, tt.number); o != 0 { + t.Errorf("expected %4q, got %4q", tt.number, n) + } + }) + } +} + +func TestParser_More(t *testing.T) { + h := NewHeap(1024) + p := NewParser(strings.NewReader(`foo. bar.`), Operators{}, doubleQuotesChars) + term, _, err := p.Term(h) + if err != nil { + t.Fatal(err) + } + if a, err := term.Atom(h); err != nil || a != "foo" { + t.Errorf("expected foo, got %v", a) + } + + if !p.More() { + t.Fatal("expected more") + } + + term, _, err = p.Term(h) + if err != nil { + t.Fatal(err) + } + + if a, err := term.Atom(h); err != nil || a != "bar" { + t.Errorf("expected bar, got %v", a) + } + + if p.More() { + t.Fatal("expected no more") + } +} diff --git a/term.go b/term.go index aa7138eb..f78eeb28 100644 --- a/term.go +++ b/term.go @@ -67,7 +67,7 @@ func NewHeap(bytes int) *Heap { atomBytes = int(unsafe.Sizeof(rbtree.Map[string, atomID]{})) // TODO: ? integerBytes = int(unsafe.Sizeof(int64(0))) floatBytes = int(unsafe.Sizeof(float64(0))) - bindingBytes = int(unsafe.Sizeof(rbtree.Node[variable, Term]{})) + bindingBytes = int(unsafe.Sizeof(rbtree.Node[Variable, Term]{})) stringBytes = int(unsafe.Sizeof(stringEntry{})) ) @@ -95,8 +95,8 @@ func NewHeap(bytes int) *Heap { terms: make([]Term, 0, MaxTerms), env: env{ - Values: rbtree.Map[variable, Term]{ - Nodes: make([]rbtree.Node[variable, Term], 0, MaxBindings), + Values: rbtree.Map[Variable, Term]{ + Nodes: make([]rbtree.Node[Variable, Term], 0, MaxBindings), }, }, atoms: atomTable{ @@ -161,12 +161,12 @@ func NewVariable(h *Heap) (Term, error) { } // Variable returns an error if it's not a variable term. 
-func (t Term) Variable(h *Heap) error { +func (t Term) Variable(h *Heap) (Variable, error) { t = t.resolve(h) if t.tag != termTagVariable { - return &UninstantiationError{Culprit: t} + return 0, &UninstantiationError{Culprit: t} } - return nil + return Variable(t.payload), nil } func (t Term) resolve(h *Heap) Term { @@ -175,7 +175,7 @@ func (t Term) resolve(h *Heap) Term { return t } - val, ok := h.env.Values.Get(variable(t.payload)) + val, ok := h.env.Values.Get(Variable(t.payload)) if !ok { return t } @@ -306,21 +306,21 @@ func NewCompound(h *Heap, name string, args ...Term) (Term, error) { } // NewList creates a series of compound terms for a list. -func NewList(h *Heap, elems iter.Seq[Term]) (Term, error) { +func NewList(h *Heap, elems ...Term) (Term, error) { tail, err := NewAtom(h, "[]") if err != nil { return Term{}, err } - return NewPartialList(h, elems, tail) + return NewPartialList(h, tail, elems...) } // NewPartialList creates a series of compound terms for a partial list with the specified tail term. -func NewPartialList(h *Heap, elems iter.Seq[Term], tail Term) (Term, error) { +func NewPartialList(h *Heap, tail Term, elems ...Term) (Term, error) { id := int32(len(h.terms)) // CDR coding empty := true - for t := range elems { + for _, t := range elems { empty = false if _, err := h.putFunctor(Functor{Name: ".", Arity: 2}); err != nil { return Term{}, err @@ -356,6 +356,27 @@ func NewPartialCharList(h *Heap, str string, tail Term) (Term, error) { return Term{tag: termTagString, payload: int32(id)}, nil } +// NewCodeList creates a list of single-character atoms. 
+func NewCodeList(h *Heap, str string) (Term, error) { + tail, err := NewAtom(h, "[]") + if err != nil { + return Term{}, err + } + return NewPartialCodeList(h, str, tail) +} + +func NewPartialCodeList(h *Heap, str string, tail Term) (Term, error) { + var elems []Term + for _, r := range str { + i, err := NewInteger(h, int64(r)) + if err != nil { + return Term{}, err + } + elems = append(elems, i) + } + return NewPartialList(h, tail, elems...) +} + // Functor is a Name with Arity. type Functor struct { Name string @@ -439,13 +460,17 @@ type ListOptions struct { type ListOption func(*ListOptions) // AllowCycle configures the list iterator to allow cyclic lists. -func AllowCycle(opts *ListOptions) { - opts.allowCycle = true +func AllowCycle(ok bool) func(*ListOptions) { + return func(opts *ListOptions) { + opts.allowCycle = ok + } } // AllowPartial configures the list iterator to allow partial lists. -func AllowPartial(opts *ListOptions) { - opts.allowPartial = true +func AllowPartial(ok bool) func(*ListOptions) { + return func(opts *ListOptions) { + opts.allowPartial = ok + } } // List returns an iterator iterates over the elements of a list. @@ -465,7 +490,7 @@ func (t Term) List(h *Heap, opts ...ListOption) iter.Seq2[Term, error] { return func(yield func(Term, error) bool) { for { if tortoise == hare && !o.allowCycle { // Detected a cycle. 
- _ = yield(Term{}, &TypeError{ValidType: "list", Culprit: t}) + _ = yield(hare, &TypeError{ValidType: "list", Culprit: t}) return } @@ -475,23 +500,23 @@ func (t Term) List(h *Heap, opts ...ListOption) iter.Seq2[Term, error] { lam = 0 } - if err := hare.Variable(h); err == nil { + if _, err := hare.Variable(h); err == nil { if !o.allowPartial { - _ = yield(Term{}, ErrInstantiation) + _ = yield(hare, ErrInstantiation) } return } if a, err := hare.Atom(h); err == nil { if a != "[]" { - _ = yield(Term{}, &TypeError{ValidType: "list", Culprit: t}) + _ = yield(hare, &TypeError{ValidType: "list", Culprit: t}) } return } c, err := hare.Compound(h) if err != nil || c.Functor != (Functor{Name: ".", Arity: 2}) { - _ = yield(Term{}, &TypeError{ValidType: "list", Culprit: t}) + _ = yield(hare, &TypeError{ValidType: "list", Culprit: t}) return } @@ -560,13 +585,13 @@ func unify(h *Heap, x, y Term, occursCheck bool) (bool, error) { return true, nil } - if err := x.Variable(h); err == nil { + if _, err := x.Variable(h); err == nil { if _, err := y.Compound(h); err == nil && occursCheck { if y.Contains(h, x) { return false, nil } } - if err := h.env.Values.Set(variable(x.payload), y); err != nil { + if err := h.env.Values.Set(Variable(x.payload), y); err != nil { return false, &ResourceError{Resource: "variables"} } return true, nil @@ -589,7 +614,7 @@ func unify(h *Heap, x, y Term, occursCheck bool) (bool, error) { } } - if err := y.Variable(h); err == nil { + if _, err := y.Variable(h); err == nil { return unify(h, y, x, occursCheck) } @@ -634,7 +659,7 @@ func renamedCopy(h *Heap, t Term, copied map[Term]Term) (Term, error) { return c, nil } - if err := t.Variable(h); err == nil { + if _, err := t.Variable(h); err == nil { c, err := NewVariable(h) if err != nil { return Term{}, err @@ -719,16 +744,16 @@ func (t Term) Compare(h *Heap, u Term) int { return 0 } - if err := x.Variable(h); err == nil { - if err := y.Variable(h); err == nil { - return int(x.payload) - int(y.payload) 
+ if vx, err := x.Variable(h); err == nil { + if vy, err := y.Variable(h); err == nil { + return int(vx) - int(vy) } return -1 } if x, err := x.Float(h); err == nil { - if err := y.Variable(h); err == nil { + if _, err := y.Variable(h); err == nil { return 1 } @@ -738,6 +763,8 @@ func (t Term) Compare(h *Heap, u Term) int { return 1 case x < y: return -1 + default: + return 0 } } @@ -745,7 +772,7 @@ func (t Term) Compare(h *Heap, u Term) int { } if x, err := x.Integer(h); err == nil { - if err := y.Variable(h); err == nil { + if _, err := y.Variable(h); err == nil { return 1 } @@ -759,6 +786,8 @@ func (t Term) Compare(h *Heap, u Term) int { return 1 case x < y: return -1 + default: + return 0 } } @@ -766,7 +795,7 @@ func (t Term) Compare(h *Heap, u Term) int { } if x, err := x.Atom(h); err == nil { - if err := y.Variable(h); err == nil { + if _, err := y.Variable(h); err == nil { return 1 } @@ -810,14 +839,14 @@ func (t Term) Compare(h *Heap, u Term) int { return 0 } -type variable int32 +type Variable int32 type env struct { - lastVariable variable - Values rbtree.Map[variable, Term] + lastVariable Variable + Values rbtree.Map[Variable, Term] } -func (e *env) Generate() (variable, error) { +func (e *env) Generate() (Variable, error) { if e.lastVariable == math.MaxInt32 { return 0, &ResourceError{Resource: "variables"} } @@ -838,12 +867,6 @@ func (a *atomTable) Put(name string) (atomID, error) { } id := atomID(len(a.Names)) - - // A hack to improve performance. 
- if id%2 == 0 { - id *= -1 - } - if err := a.IDs.Set(name, id); err != nil { return 0, err } @@ -882,13 +905,16 @@ func (s *stringPool) Put(str string, tail Term) (stringID, error) { if !ok { return 0, &ResourceError{Resource: "strings"} } - return stringID(len(*s) - 1), nil + return e.offset, nil } func (s *stringPool) First(id stringID) Term { i := slices.IndexFunc(*s, func(e stringEntry) bool { - return e.offset+stringID(len(e.string)) >= id + return e.offset+stringID(len(e.string)) > id }) + if i == -1 { + return Term{} + } e := (*s)[i] str := e.string[id-e.offset:] r, _ := utf8.DecodeRuneInString(str) @@ -897,8 +923,11 @@ func (s *stringPool) First(id stringID) Term { func (s *stringPool) Rest(id stringID) Term { i := slices.IndexFunc(*s, func(e stringEntry) bool { - return e.offset+stringID(len(e.string)) >= id + return e.offset+stringID(len(e.string)) > id }) + if i == -1 { + return Term{} + } e := (*s)[i] str := e.string[id-e.offset:] _, n := utf8.DecodeRuneInString(str) diff --git a/term_test.go b/term_test.go index d9b9ddf1..88875566 100644 --- a/term_test.go +++ b/term_test.go @@ -1,11 +1,9 @@ package prolog import ( - "iter" "maps" "math" "reflect" - "slices" "testing" "github.com/ichiban/prolog/v2/internal/rbtree" @@ -296,7 +294,7 @@ func TestNewList(t *testing.T) { tests := []struct { title string heap *Heap - elems iter.Seq[Term] + elems []Term term Term err error }{ @@ -310,7 +308,7 @@ func TestNewList(t *testing.T) { Names: make([]string, 0, 1), }, }, - elems: slices.Values([]Term{}), + elems: []Term{}, term: Term{tag: termTagAtom, payload: 0}, }, { @@ -325,10 +323,10 @@ func TestNewList(t *testing.T) { }, integers: make([]int64, 0, 2), }, - elems: slices.Values([]Term{ + elems: []Term{ {tag: termTagCharacter, payload: 'a'}, {tag: termTagCharacter, payload: 'b'}, - }), + }, term: Term{tag: termTagReference, payload: 0}, }, { @@ -340,10 +338,10 @@ func TestNewList(t *testing.T) { }, integers: make([]int64, 0, 2), }, - elems: slices.Values([]Term{ + 
elems: []Term{ {tag: termTagCharacter, payload: 'a'}, {tag: termTagCharacter, payload: 'b'}, - }), + }, err: &ResourceError{Resource: "atoms"}, }, { @@ -358,10 +356,10 @@ func TestNewList(t *testing.T) { }, integers: make([]int64, 0, 2), }, - elems: slices.Values([]Term{ + elems: []Term{ {tag: termTagCharacter, payload: 'a'}, {tag: termTagCharacter, payload: 'b'}, - }), + }, err: &ResourceError{Resource: "terms"}, }, { @@ -376,10 +374,10 @@ func TestNewList(t *testing.T) { }, integers: make([]int64, 0, 2), }, - elems: slices.Values([]Term{ + elems: []Term{ {tag: termTagCharacter, payload: 'a'}, {tag: termTagCharacter, payload: 'b'}, - }), + }, err: &ResourceError{Resource: "terms"}, }, { @@ -394,17 +392,17 @@ func TestNewList(t *testing.T) { }, integers: make([]int64, 0, 2), }, - elems: slices.Values([]Term{ + elems: []Term{ {tag: termTagCharacter, payload: 'a'}, {tag: termTagCharacter, payload: 'b'}, - }), + }, err: &ResourceError{Resource: "terms"}, }, } for _, tt := range tests { t.Run(tt.title, func(t *testing.T) { - l, err := NewList(tt.heap, tt.elems) + l, err := NewList(tt.heap, tt.elems...) 
if !reflect.DeepEqual(err, tt.err) { t.Errorf("expected: %v, got: %v", tt.err, err) } @@ -470,12 +468,6 @@ func TestNewCharList(t *testing.T) { } }) } - - for _, tt := range tests { - t.Run(tt.title, func(t *testing.T) { - - }) - } } func TestTerm_Variable(t *testing.T) { @@ -502,7 +494,7 @@ func TestTerm_Variable(t *testing.T) { for _, tt := range tests { t.Run(tt.title, func(t *testing.T) { - err := tt.term.Variable(h) + _, err := tt.term.Variable(h) if !reflect.DeepEqual(err, tt.err) { t.Errorf("expected: %v, got: %v", tt.err, err) } @@ -722,22 +714,22 @@ func TestTerm_List(t *testing.T) { t.Fatal(err) } - l, err := NewList(h, slices.Values([]Term{a, b})) + l, err := NewList(h, a, b) if err != nil { t.Fatal(err) } - nl, err := NewPartialList(h, slices.Values([]Term{a, b}), a) + nl, err := NewPartialList(h, a, a, b) if err != nil { t.Fatal(err) } - nl2, err := NewPartialList(h, slices.Values([]Term{a, b}), one) + nl2, err := NewPartialList(h, one, a, b) if err != nil { t.Fatal(err) } - pl, err := NewPartialList(h, slices.Values([]Term{a, b}), v) + pl, err := NewPartialList(h, v, a, b) if err != nil { t.Fatal(err) } @@ -746,7 +738,7 @@ func TestTerm_List(t *testing.T) { if err != nil { t.Fatal(err) } - cl, err := NewPartialList(h, slices.Values([]Term{a, b}), tail) + cl, err := NewPartialList(h, tail, a, b) if err != nil { t.Fatal(err) } @@ -780,28 +772,28 @@ func TestTerm_List(t *testing.T) { {title: "[a, b|a]", term: nl, results: []result{ {term: a}, {term: b}, - {err: &TypeError{ValidType: "list", Culprit: nl}}, + {term: a, err: &TypeError{ValidType: "list", Culprit: nl}}, }}, {title: "[a, b|1]", term: nl2, results: []result{ {term: a}, {term: b}, - {err: &TypeError{ValidType: "list", Culprit: nl2}}, + {term: one, err: &TypeError{ValidType: "list", Culprit: nl2}}, }}, {title: "[a, b|_]", term: pl, results: []result{ {term: a}, {term: b}, - {err: ErrInstantiation}, + {term: Term{tag: termTagVariable, payload: 1}, err: ErrInstantiation}, }}, - {title: "[a, b|_] 
with AllowPartial", term: pl, options: []ListOption{AllowPartial}, results: []result{ + {title: "[a, b|_] with AllowPartial", term: pl, options: []ListOption{AllowPartial(true)}, results: []result{ {term: a}, {term: b}, }}, {title: "[a, b, a, b|...]", term: cl, results: []result{ {term: a}, {term: b}, - {err: &TypeError{ValidType: "list", Culprit: cl}}, + {term: cl, err: &TypeError{ValidType: "list", Culprit: cl}}, }}, - {title: "[a, b, a, b|...] with AllowCyclic", term: cl, options: []ListOption{AllowCycle}, count: 8, results: []result{ + {title: "[a, b, a, b|...] with AllowCyclic", term: cl, options: []ListOption{AllowCycle(true)}, count: 8, results: []result{ {term: a}, {term: b}, {term: a}, @@ -849,7 +841,7 @@ func TestTerm_CharList(t *testing.T) { t.Fatal(err) } - list, err := NewList(h, slices.Values([]Term{a, b, c})) + list, err := NewList(h, a, b, c) if err != nil { t.Fatal(err) } @@ -979,21 +971,21 @@ func TestTerm_Unify(t *testing.T) { x, y Term ok bool err error - env map[variable]Term + env map[Variable]Term }{ {title: "a = a", heap: h, x: a, y: a, ok: true}, {title: "V = V", heap: h, x: v, y: v, ok: true}, - {title: "V = W", heap: h, x: v, y: w, ok: true, env: map[variable]Term{ - variable(v.payload): w, + {title: "V = W", heap: h, x: v, y: w, ok: true, env: map[Variable]Term{ + Variable(v.payload): w, }}, {title: "f(a) = g(a)", heap: h, x: fa, y: ga, ok: false}, {title: "f(a) = f(b)", heap: h, x: fa, y: fb, ok: false}, - {title: "a = V", heap: h, x: a, y: v, ok: true, env: map[variable]Term{ - variable(v.payload): a, + {title: "a = V", heap: h, x: a, y: v, ok: true, env: map[Variable]Term{ + Variable(v.payload): a, }}, {title: "a = b", heap: h, x: a, y: b, ok: false}, - {title: "X = f(X)", heap: h, x: x, y: fx, ok: true, env: map[variable]Term{ - variable(x.payload): fx, + {title: "X = f(X)", heap: h, x: x, y: fx, ok: true, env: map[Variable]Term{ + Variable(x.payload): fx, }}, {title: "insufficient variables", heap: &Heap{}, x: v, y: a, err: 
&ResourceError{Resource: "variables"}}, } @@ -1014,7 +1006,7 @@ func TestTerm_Unify(t *testing.T) { t.Errorf("expected: %v, got: %v", tt.ok, ok) } - env := map[variable]Term{} + env := map[Variable]Term{} for k, v := range h.env.Values.All() { env[k] = v } @@ -1080,17 +1072,17 @@ func TestTerm_UnifyWithOccursCheck(t *testing.T) { x, y Term ok bool err error - env map[variable]Term + env map[Variable]Term }{ {title: "a = a", heap: h, x: a, y: a, ok: true}, {title: "V = V", heap: h, x: v, y: v, ok: true}, - {title: "V = W", heap: h, x: v, y: w, ok: true, env: map[variable]Term{ - variable(v.payload): w, + {title: "V = W", heap: h, x: v, y: w, ok: true, env: map[Variable]Term{ + Variable(v.payload): w, }}, {title: "f(a) = g(a)", heap: h, x: fa, y: ga, ok: false}, {title: "f(a) = f(b)", heap: h, x: fa, y: fb, ok: false}, - {title: "a = V", heap: h, x: a, y: v, ok: true, env: map[variable]Term{ - variable(v.payload): a, + {title: "a = V", heap: h, x: a, y: v, ok: true, env: map[Variable]Term{ + Variable(v.payload): a, }}, {title: "a = b", heap: h, x: a, y: b, ok: false}, {title: "X = f(X)", heap: h, x: x, y: fx, ok: false}, @@ -1113,7 +1105,7 @@ func TestTerm_UnifyWithOccursCheck(t *testing.T) { t.Errorf("expected: %v, got: %v", tt.ok, ok) } - env := map[variable]Term{} + env := map[Variable]Term{} for k, v := range h.env.Values.All() { env[k] = v } @@ -1519,12 +1511,12 @@ func TestCompound_Arg(t *testing.T) { t.Fatal(err) } - listAB, err := NewList(h, slices.Values([]Term{a, b})) + listAB, err := NewList(h, a, b) if err != nil { t.Fatal(err) } - listB, err := NewList(h, slices.Values([]Term{b})) + listB, err := NewList(h, b) if err != nil { t.Fatal(err) } @@ -1570,6 +1562,33 @@ func TestCompound_Arg(t *testing.T) { } } -func indirect[T any](t T) *T { - return &t +func TestStringPool_First(t *testing.T) { + h := NewHeap(1024) + if _, err := h.strings.Put("ab", Term{tag: termTagVariable, payload: 0}); err != nil { + t.Fatal(err) + } + if _, err := h.strings.Put("cd", 
Term{tag: termTagVariable, payload: 1}); err != nil { + t.Fatal(err) + } + + tests := []struct { + title string + id stringID + term Term + }{ + {title: "0", id: 0, term: Term{tag: termTagCharacter, payload: 'a'}}, + {title: "1", id: 1, term: Term{tag: termTagCharacter, payload: 'b'}}, + {title: "2", id: 2, term: Term{tag: termTagCharacter, payload: 'c'}}, + {title: "3", id: 3, term: Term{tag: termTagCharacter, payload: 'd'}}, + {title: "4", id: 4, term: Term{}}, + } + + for _, tt := range tests { + t.Run(tt.title, func(t *testing.T) { + term := h.strings.First(tt.id) + if o := term.Compare(h, tt.term); o != 0 { + t.Errorf("expected %v, got %v", tt.term, term) + } + }) + } }