package calculator
|
|
import (
|
"errors"
|
"fmt"
|
"strconv"
|
)
|
|
var precedence = map[string]int{"+": 20, "-": 20, "*": 40, "/": 40, "%": 40, "^": 60}
|
|
type ExprAST interface {
|
toStr() string
|
}
|
|
type NumberExprAST struct {
|
Val float64
|
Str string
|
}
|
|
type BinaryExprAST struct {
|
Op string
|
Lhs,
|
Rhs ExprAST
|
}
|
|
type FunCallerExprAST struct {
|
Name string
|
Arg []ExprAST
|
}
|
|
func (n NumberExprAST) toStr() string {
|
return fmt.Sprintf(
|
"NumberExprAST:%s",
|
n.Str,
|
)
|
}
|
|
func (b BinaryExprAST) toStr() string {
|
return fmt.Sprintf(
|
"BinaryExprAST: (%s %s %s)",
|
b.Op,
|
b.Lhs.toStr(),
|
b.Rhs.toStr(),
|
)
|
}
|
|
func (n FunCallerExprAST) toStr() string {
|
return fmt.Sprintf(
|
"FunCallerExprAST:%s",
|
n.Name,
|
)
|
}
|
|
type AST struct {
|
Tokens []*Token
|
|
source string
|
currTok *Token
|
currIndex int
|
depth int
|
|
Err error
|
}
|
|
func NewAST(toks []*Token, s string) *AST {
|
a := &AST{
|
Tokens: toks,
|
source: s,
|
}
|
if a.Tokens == nil || len(a.Tokens) == 0 {
|
a.Err = errors.New("empty token")
|
} else {
|
a.currIndex = 0
|
a.currTok = a.Tokens[0]
|
}
|
return a
|
}
|
|
func (a *AST) ParseExpression() ExprAST {
|
a.depth++ // called depth
|
lhs := a.parsePrimary()
|
r := a.parseBinOpRHS(0, lhs)
|
a.depth--
|
if a.depth == 0 && a.currIndex != len(a.Tokens) && a.Err == nil {
|
a.Err = errors.New(
|
fmt.Sprintf("bad expression, reaching the end or missing the operator\n%s",
|
ErrPos(a.source, a.currTok.Offset)))
|
}
|
return r
|
}
|
|
func (a *AST) getNextToken() *Token {
|
a.currIndex++
|
if a.currIndex < len(a.Tokens) {
|
a.currTok = a.Tokens[a.currIndex]
|
return a.currTok
|
}
|
return nil
|
}
|
|
func (a *AST) getTokPrecedence() int {
|
if p, ok := precedence[a.currTok.Tok]; ok {
|
return p
|
}
|
return -1
|
}
|
|
func (a *AST) parseNumber() NumberExprAST {
|
f64, err := strconv.ParseFloat(a.currTok.Tok, 64)
|
if err != nil {
|
a.Err = errors.New(
|
fmt.Sprintf("%v\nwant '(' or '0-9' but get '%s'\n%s",
|
err.Error(),
|
a.currTok.Tok,
|
ErrPos(a.source, a.currTok.Offset)))
|
return NumberExprAST{}
|
}
|
n := NumberExprAST{
|
Val: f64,
|
Str: a.currTok.Tok,
|
}
|
a.getNextToken()
|
return n
|
}
|
|
func (a *AST) parseFunCallerOrConst() ExprAST {
|
name := a.currTok.Tok
|
a.getNextToken()
|
// call func
|
if a.currTok.Tok == "(" {
|
f := FunCallerExprAST{}
|
if _, ok := defFunc[name]; !ok {
|
a.Err = errors.New(
|
fmt.Sprintf("function `%s` is undefined\n%s",
|
name,
|
ErrPos(a.source, a.currTok.Offset)))
|
return f
|
}
|
a.getNextToken()
|
exprs := make([]ExprAST, 0)
|
if a.currTok.Tok == ")" {
|
// function call without parameters
|
// ignore the process of parameter resolution
|
} else {
|
exprs = append(exprs, a.ParseExpression())
|
for a.currTok.Tok != ")" && a.getNextToken() != nil {
|
if a.currTok.Type == COMMA {
|
continue
|
}
|
exprs = append(exprs, a.ParseExpression())
|
}
|
}
|
def := defFunc[name]
|
if def.argc >= 0 && len(exprs) != def.argc {
|
a.Err = errors.New(
|
fmt.Sprintf("wrong way calling function `%s`, parameters want %d but get %d\n%s",
|
name,
|
def.argc,
|
len(exprs),
|
ErrPos(a.source, a.currTok.Offset)))
|
}
|
a.getNextToken()
|
f.Name = name
|
f.Arg = exprs
|
return f
|
}
|
// call const
|
if v, ok := defConst[name]; ok {
|
return NumberExprAST{
|
Val: v,
|
Str: strconv.FormatFloat(v, 'f', 0, 64),
|
}
|
} else {
|
a.Err = errors.New(
|
fmt.Sprintf("const `%s` is undefined\n%s",
|
name,
|
ErrPos(a.source, a.currTok.Offset)))
|
return NumberExprAST{}
|
}
|
}
|
|
func (a *AST) parsePrimary() ExprAST {
|
switch a.currTok.Type {
|
case Identifier:
|
return a.parseFunCallerOrConst()
|
case Literal:
|
return a.parseNumber()
|
case Operator:
|
if a.currTok.Tok == "(" {
|
t := a.getNextToken()
|
if t == nil {
|
a.Err = errors.New(
|
fmt.Sprintf("want '(' or '0-9' but get EOF\n%s",
|
ErrPos(a.source, a.currTok.Offset)))
|
return nil
|
}
|
e := a.ParseExpression()
|
if e == nil {
|
return nil
|
}
|
if a.currTok.Tok != ")" {
|
a.Err = errors.New(
|
fmt.Sprintf("want ')' but get %s\n%s",
|
a.currTok.Tok,
|
ErrPos(a.source, a.currTok.Offset)))
|
return nil
|
}
|
a.getNextToken()
|
return e
|
} else if a.currTok.Tok == "-" {
|
if a.getNextToken() == nil {
|
a.Err = errors.New(
|
fmt.Sprintf("want '0-9' but get '-'\n%s",
|
ErrPos(a.source, a.currTok.Offset)))
|
return nil
|
}
|
bin := BinaryExprAST{
|
Op: "-",
|
Lhs: NumberExprAST{},
|
Rhs: a.parsePrimary(),
|
}
|
return bin
|
} else {
|
return a.parseNumber()
|
}
|
case COMMA:
|
a.Err = errors.New(
|
fmt.Sprintf("want '(' or '0-9' but get %s\n%s",
|
a.currTok.Tok,
|
ErrPos(a.source, a.currTok.Offset)))
|
return nil
|
default:
|
return nil
|
}
|
}
|
|
func (a *AST) parseBinOpRHS(execPrec int, lhs ExprAST) ExprAST {
|
for {
|
tokPrec := a.getTokPrecedence()
|
if tokPrec < execPrec {
|
return lhs
|
}
|
binOp := a.currTok.Tok
|
if a.getNextToken() == nil {
|
a.Err = errors.New(
|
fmt.Sprintf("want '(' or '0-9' but get EOF\n%s",
|
ErrPos(a.source, a.currTok.Offset)))
|
return nil
|
}
|
rhs := a.parsePrimary()
|
if rhs == nil {
|
return nil
|
}
|
nextPrec := a.getTokPrecedence()
|
if tokPrec < nextPrec {
|
rhs = a.parseBinOpRHS(tokPrec+1, rhs)
|
if rhs == nil {
|
return nil
|
}
|
}
|
lhs = BinaryExprAST{
|
Op: binOp,
|
Lhs: lhs,
|
Rhs: rhs,
|
}
|
}
|
}
|