diff --git a/ast/Expression.go b/ast/Expression.go index e6fdc8f..20d25d2 100755 --- a/ast/Expression.go +++ b/ast/Expression.go @@ -17,10 +17,11 @@ package ast import ( "errors" "fmt" - "github.com/hyperjumptech/grule-rule-engine/ast/unique" "reflect" "strings" + "github.com/hyperjumptech/grule-rule-engine/ast/unique" + "github.com/hyperjumptech/grule-rule-engine/pkg" ) @@ -79,7 +80,8 @@ type Expression struct { Negated bool Value reflect.Value - Evaluated bool + Evaluated bool + CompareNilValues bool } // MakeCatalog will create a catalog entry from Expression node. @@ -279,7 +281,18 @@ func (e *Expression) SetGrlText(grlText string) { // Evaluate will evaluate this AST graph for when scope evaluation func (e *Expression) Evaluate(dataContext IDataContext, memory *WorkingMemory) (reflect.Value, error) { - if e.Evaluated == true { + compareNode := dataContext.Get("COMPARE_NILS") + if compareNode != nil { + if rv := compareNode.Value(); rv.IsValid() && rv.Kind() == reflect.Bool { + e.CompareNilValues = rv.Bool() + } else { + e.CompareNilValues = false + } + } else { + e.CompareNilValues = false + } + + if e.Evaluated { return e.Value, nil } @@ -350,41 +363,84 @@ func (e *Expression) Evaluate(dataContext IDataContext, memory *WorkingMemory) ( return reflect.Value{}, fmt.Errorf("right hand expression error. got %v", rerr) } - switch e.Operator { - case OpMul: - val, opErr = pkg.EvaluateMultiplication(lval, rval) - case OpDiv: - val, opErr = pkg.EvaluateDivision(lval, rval) - case OpMod: - val, opErr = pkg.EvaluateModulo(lval, rval) - case OpAdd: - val, opErr = pkg.EvaluateAddition(lval, rval) - case OpSub: - val, opErr = pkg.EvaluateSubtraction(lval, rval) - case OpBitAnd: - val, opErr = pkg.EvaluateBitAnd(lval, rval) - case OpBitOr: - val, opErr = pkg.EvaluateBitOr(lval, rval) - case OpGT: - val, opErr = pkg.EvaluateGreaterThan(lval, rval) - case OpLT: - val, opErr = pkg.EvaluateLesserThan(lval, rval) - case OpGTE: - val, opErr = pkg.EvaluateGreaterThanEqual(lval, rval) - case OpLTE: - val, opErr = pkg.EvaluateLesserThanEqual(lval, rval) - case OpEq: - val, opErr = pkg.EvaluateEqual(lval, rval) - case OpNEq: - val, opErr = pkg.EvaluateNotEqual(lval, rval) - case OpAnd: - val, opErr = pkg.EvaluateLogicAnd(lval, rval) - case OpOr: - val, opErr = pkg.EvaluateLogicOr(lval, rval) - } - if opErr == nil { - e.Value = val - e.Evaluated = true + if e.CompareNilValues && (!lval.IsValid() || !rval.IsValid()) { + if e.CompareNilValues { + AstLog.Debugf("Values have invalid value (%v and %v) but continuing with null handling", lval, rval) + switch e.Operator { + case OpMul, OpDiv, OpBitAnd, OpBitOr, OpMod: + // Can be left as nil, as these operators with Nil are not defined + e.Evaluated = true + case OpAdd: + if lval.IsValid() { + val = lval + } else if rval.IsValid() { + val = rval + } + e.Evaluated = true + case OpSub: + if lval.IsValid() { + val = lval + } else if rval.IsValid() { + val, _ = pkg.EvaluateSubtraction(reflect.ValueOf(0), rval) + } + e.Evaluated = true + case OpOr: + lvale := pkg.GetValueElem(lval) + rvale := pkg.GetValueElem(rval) + if (lvale.IsValid() && lvale.Kind() == reflect.Bool && lvale.Bool()) || (rvale.IsValid() && rvale.Kind() == reflect.Bool && rvale.Bool()) { + val = reflect.ValueOf(true) + } else { + val = reflect.ValueOf(false) + } + e.Value = val + e.Evaluated = true + case OpEq, OpGTE, OpLTE, OpGT, OpLT, OpAnd: + val = reflect.ValueOf(false) + e.Value = val + e.Evaluated = true + case OpNEq: + val = reflect.ValueOf(true) + e.Value = val + e.Evaluated = true + } + } + } else { + switch e.Operator { + case OpMul: + val, opErr = pkg.EvaluateMultiplication(lval, rval) + case OpDiv: + val, opErr = pkg.EvaluateDivision(lval, rval) + case OpMod: + val, opErr = pkg.EvaluateModulo(lval, rval) + case OpAdd: + val, opErr = pkg.EvaluateAddition(lval, rval) + case OpSub: + val, opErr = pkg.EvaluateSubtraction(lval, rval) + case OpBitAnd: + val, opErr = pkg.EvaluateBitAnd(lval, rval) + case OpBitOr: + val, opErr = pkg.EvaluateBitOr(lval, rval) + case OpGT: + val, opErr = pkg.EvaluateGreaterThan(lval, rval) + case OpLT: + val, opErr = pkg.EvaluateLesserThan(lval, rval) + case OpGTE: + val, opErr = pkg.EvaluateGreaterThanEqual(lval, rval) + case OpLTE: + val, opErr = pkg.EvaluateLesserThanEqual(lval, rval) + case OpEq: + val, opErr = pkg.EvaluateEqual(lval, rval) + case OpNEq: + val, opErr = pkg.EvaluateNotEqual(lval, rval) + case OpAnd: + val, opErr = pkg.EvaluateLogicAnd(lval, rval) + case OpOr: + val, opErr = pkg.EvaluateLogicOr(lval, rval) + } + if opErr == nil { + e.Value = val + e.Evaluated = true + } } return val, opErr diff --git a/ast/ExpressionAtom.go b/ast/ExpressionAtom.go index dae3fda..85f8c42 100755 --- a/ast/ExpressionAtom.go +++ b/ast/ExpressionAtom.go @@ -17,11 +17,12 @@ package ast import ( "errors" "fmt" - "github.com/hyperjumptech/grule-rule-engine/ast/unique" - "github.com/hyperjumptech/grule-rule-engine/model" "reflect" "strings" + "github.com/hyperjumptech/grule-rule-engine/ast/unique" + "github.com/hyperjumptech/grule-rule-engine/model" + "github.com/hyperjumptech/grule-rule-engine/pkg" ) @@ -49,7 +50,8 @@ type ExpressionAtom struct { Value reflect.Value ValueNode model.ValueNode - Evaluated bool + Evaluated bool + CompareNilValues bool } // MakeCatalog will create a catalog entry from ExpressionAtom node. @@ -265,10 +267,22 @@ func (e *ExpressionAtom) SetGrlText(grlText string) { // Evaluate will evaluate this AST graph for when scope evaluation func (e *ExpressionAtom) Evaluate(dataContext IDataContext, memory *WorkingMemory) (val reflect.Value, err error) { - if e.Evaluated == true { + // Extract COMPARE_NILS from dataContext as a bool, defaulting to false when unavailable or not boolean. + compareNode := dataContext.Get("COMPARE_NILS") + if compareNode != nil { + if rv := compareNode.Value(); rv.IsValid() && rv.Kind() == reflect.Bool { + e.CompareNilValues = rv.Bool() + } else { + e.CompareNilValues = false + } + } else { + e.CompareNilValues = false + } + if e.Evaluated { return e.Value, nil } + if e.Constant != nil { val, err := e.Constant.Evaluate(dataContext, memory) if err != nil { @@ -346,8 +360,15 @@ func (e *ExpressionAtom) Evaluate(dataContext IDataContext, memory *WorkingMemor return reflect.ValueOf(nil), err } + if e.ExpressionAtom.ValueNode == nil && e.CompareNilValues { + return reflect.ValueOf(nil), nil + } + retVal, err := e.ExpressionAtom.ValueNode.CallFunction(e.FunctionCall.FunctionName, args...) if err != nil { + if e.CompareNilValues { + return reflect.ValueOf(nil), nil + } return reflect.ValueOf(nil), err } @@ -368,6 +389,13 @@ func (e *ExpressionAtom) Evaluate(dataContext IDataContext, memory *WorkingMemor } valueNode, err := e.ExpressionAtom.ValueNode.GetChildNodeByField(e.VariableName) if err != nil { + if e.CompareNilValues { + e.ValueNode = model.NewGoValueNode(reflect.ValueOf(nil), fmt.Sprintf("%s.%s->nil", e.ExpressionAtom.ValueNode.IdentifiedAs(), e.VariableName)) + e.Value = e.ValueNode.Value() + e.Evaluated = true + + return e.Value, nil + } return reflect.Value{}, err } diff --git a/ast/Variable.go b/ast/Variable.go index 65aeef3..8c88886 100755 --- a/ast/Variable.go +++ b/ast/Variable.go @@ -16,11 +16,12 @@ package ast import ( "fmt" - "github.com/hyperjumptech/grule-rule-engine/ast/unique" - "github.com/hyperjumptech/grule-rule-engine/model" "reflect" "strings" + "github.com/hyperjumptech/grule-rule-engine/ast/unique" + "github.com/hyperjumptech/grule-rule-engine/model" + "github.com/hyperjumptech/grule-rule-engine/pkg" ) @@ -43,6 +44,8 @@ type Variable struct { ValueNode model.ValueNode Value reflect.Value + + CompareNilValues bool } // MakeCatalog create a catalog entry for this AST Node @@ -219,6 +222,17 @@ func (e *Variable) Assign(newVal reflect.Value, dataContext IDataContext, memory // Evaluate will evaluate this AST graph for when scope evaluation func (e *Variable) Evaluate(dataContext IDataContext, memory *WorkingMemory) (reflect.Value, error) { + compareNode := dataContext.Get("COMPARE_NILS") + if compareNode != nil { + if rv := compareNode.Value(); rv.IsValid() && rv.Kind() == reflect.Bool { + e.CompareNilValues = rv.Bool() + } else { + e.CompareNilValues = false + } + } else { + e.CompareNilValues = false + } + if len(e.Name) > 0 && e.Variable == nil { valueNode := dataContext.Get(e.Name) if valueNode == nil { @@ -238,7 +252,9 @@ func (e *Variable) Evaluate(dataContext IDataContext, memory *WorkingMemory) (re } valueNode, err := e.Variable.ValueNode.GetChildNodeByField(e.Name) if err != nil { - + if e.CompareNilValues { + return reflect.ValueOf(nil), nil + } return reflect.Value{}, err } e.ValueNode = valueNode diff --git a/engine/GruleEngine.go b/engine/GruleEngine.go index b65f867..e8740df 100755 --- a/engine/GruleEngine.go +++ b/engine/GruleEngine.go @@ -17,11 +17,12 @@ package engine import ( "context" "fmt" + "sort" + "time" + "github.com/rs/zerolog" "github.com/sirupsen/logrus" "go.uber.org/zap" - "sort" - "time" "github.com/hyperjumptech/grule-rule-engine/ast" "github.com/hyperjumptech/grule-rule-engine/logger" @@ -87,6 +88,7 @@ func NewGruleEngine() *GruleEngine { type GruleEngine struct { MaxCycle uint64 ReturnErrOnFailedRuleEvaluation bool + CompareNilValues bool Listeners []GruleEngineListener } @@ -150,6 +152,13 @@ func (g *GruleEngine) ExecuteWithContext(ctx context.Context, dataCtx ast.IDataC return err } + err = dataCtx.Add("COMPARE_NILS", g.CompareNilValues) + if err != nil { + log.Error("COMPARE_NILS add err") + + return err + } + // Working memory need to be resetted. all Expression will be set as not evaluated. log.Debugf("Resetting Working memory") knowledge.WorkingMemory.ResetAll() @@ -279,6 +288,12 @@ func (g *GruleEngine) FetchMatchingRules(dataCtx ast.IDataContext, knowledge *as return nil, err } + err = dataCtx.Add("COMPARE_NILS", g.CompareNilValues) + if err != nil { + log.Error("COMPARE_NILS add err") + + return nil, err + } // Working memory need to be resetted. all Expression will be set as not evaluated. log.Debugf("Resetting Working memory") knowledge.WorkingMemory.ResetAll() diff --git a/engine/GruleEngine_test.go b/engine/GruleEngine_test.go index 3808923..7eb6139 100755 --- a/engine/GruleEngine_test.go +++ b/engine/GruleEngine_test.go @@ -140,9 +140,10 @@ func getTypeOf(i interface{}) string { return t.Name() } +// TODO: Add also tests when function argument(s) are nil pointers const ruleWithAccessErr = `rule AccessErrRule "test access error rule" salience 10 { when - TestStruct.NotExist == 1 + TestStruct.NotExist == 1 || TestStruct.OtherNonExists || TestStruct.ThirdNonExist.Contains("included value") == true || TestStruct.exist.Contains(TestStruct.NonExisting) == true then Retract("AccessErrRule"); }` @@ -154,7 +155,7 @@ func TestEngine_ExecuteErr(t *testing.T) { lib := ast.NewKnowledgeLibrary() rb := builder.NewRuleBuilder(lib) - err = rb.BuildRuleFromResource("Test", "0.1.1", pkg.NewBytesResource([]byte(rules))) + err = rb.BuildRuleFromResource("Test", "0.1.1", pkg.NewBytesResource([]byte(ruleWithAccessErr))) assert.NoError(t, err) engine := NewGruleEngine() @@ -165,6 +166,45 @@ func TestEngine_ExecuteErr(t *testing.T) { assert.Error(t, err) } +func TestEngine_ExecuteHandleNilsJSON(t *testing.T) { + dctx := ast.NewDataContext() + testJson := `{ "exist": "\"This field exist\"" }` + err := dctx.AddJSON("TestStruct", []byte(testJson)) + assert.NoError(t, err) + + lib := ast.NewKnowledgeLibrary() + rb := builder.NewRuleBuilder(lib) + err = rb.BuildRuleFromResource("Test", "0.1.1", pkg.NewBytesResource([]byte(ruleWithAccessErr))) + assert.NoError(t, err) + + engine := NewGruleEngine() + engine.ReturnErrOnFailedRuleEvaluation = true + engine.CompareNilValues = true + kb, err := lib.NewKnowledgeBaseInstance("Test", "0.1.1") + assert.NoError(t, err) + err = engine.Execute(dctx, kb) + assert.NoError(t, err) +} + +func TestEngine_ExecuteHandleNils(t *testing.T) { + dctx := ast.NewDataContext() + err := dctx.Add("TestStruct", &TestStruct{}) + assert.NoError(t, err) + + lib := ast.NewKnowledgeLibrary() + rb := builder.NewRuleBuilder(lib) + err = rb.BuildRuleFromResource("Test", "0.1.1", pkg.NewBytesResource([]byte(ruleWithAccessErr))) + assert.NoError(t, err) + + engine := NewGruleEngine() + engine.ReturnErrOnFailedRuleEvaluation = true + engine.CompareNilValues = true + kb, err := lib.NewKnowledgeBaseInstance("Test", "0.1.1") + assert.NoError(t, err) + err = engine.Execute(dctx, kb) + assert.NoError(t, err) +} + func TestEmptyValueEquality(t *testing.T) { t1 := time.Time{} tv1 := reflect.ValueOf(t1)