Skip to content

Commit 9a4960f

Browse files
fix: resolve nullability issues and improve type combinations
1 parent 55dba9f commit 9a4960f

File tree

1 file changed

+104
-50
lines changed

1 file changed

+104
-50
lines changed

internal/compiler/output_columns.go

Lines changed: 104 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
126126
continue
127127
}
128128
switch n := res.Val.(type) {
129-
130129
case *ast.A_Const:
131130
name := ""
132131
if res.Name != nil {
@@ -150,7 +149,8 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
150149
if res.Name != nil {
151150
name = *res.Name
152151
}
153-
switch op := astutils.Join(n.Name, ""); {
152+
op := astutils.Join(n.Name, "")
153+
switch {
154154
case lang.IsComparisonOperator(op):
155155
// TODO: Generate a name for these operations
156156
cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
@@ -774,98 +774,152 @@ func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List)
774774
}
775775

776776
// inferMathExpressionType attempts to infer the data type of a mathematical expression
777-
// by analyzing its operands and the operation being performed
777+
// by analyzing its operands and the operation being performed.
778778
func (c *Compiler) inferMathExpressionType(expr *ast.A_Expr, tables []*Table, op string) (string, bool) {
779-
// Try to infer types from left and right operands
780-
leftType := c.inferOperandType(expr.Lexpr, tables)
781-
rightType := c.inferOperandType(expr.Rexpr, tables)
779+
leftType, leftNotNull := c.inferOperandType(expr.Lexpr, tables)
780+
rightType, rightNotNull := c.inferOperandType(expr.Rexpr, tables)
782781

783-
// Debug logging to understand what's happening
784-
// fmt.Printf("DEBUG: Math expression %s: left=%s, right=%s\n", op, leftType, rightType)
782+
// result is non-null only if both sides are non-null
783+
notNull := leftNotNull && rightNotNull
785784

786-
// Determine the result type based on operands and operation
787785
resultType := c.combineTypes(leftType, rightType, op)
788-
789-
// For now, assume nullable since we're dealing with database columns
790-
// In a more sophisticated implementation, we could track nullability through the expression
791-
notNull := false
792-
793786
return resultType, notNull
794787
}
795788

796-
// inferOperandType tries to determine the type of an operand in an expression
797-
func (c *Compiler) inferOperandType(operand ast.Node, tables []*Table) string {
789+
// inferOperandType tries to determine the type and nullability of an operand in an expression.
790+
func (c *Compiler) inferOperandType(operand ast.Node, tables []*Table) (string, bool) {
798791
switch n := operand.(type) {
799792
case *ast.ColumnRef:
800-
// Look up the column in the available tables
801793
parts := stringSlice(n.Fields)
802794
var name string
803795
if len(parts) >= 1 {
804-
name = parts[len(parts)-1] // Get the column name (last part)
796+
name = parts[len(parts)-1]
805797
}
806-
807798
for _, table := range tables {
808799
for _, col := range table.Columns {
809800
if col.Name == name {
810-
return col.DataType
801+
return col.DataType, col.NotNull
811802
}
812803
}
813804
}
814-
return "any"
805+
// Unknown column: assume non-null by default so generated code
806+
// keeps the previous non-nullable behavior (avoids sql.Null*).
807+
return "any", true
815808
case *ast.A_Const:
816-
// Determine type based on constant value
809+
// constants are non-nullable
817810
switch n.Val.(type) {
818811
case *ast.Integer:
819-
return "int"
812+
return "int", true
820813
case *ast.Float:
821-
return "float"
814+
return "float", true
822815
case *ast.String:
823-
return "text"
816+
return "text", true
824817
default:
825-
return "any"
818+
return "any", true
826819
}
827820
case *ast.A_Expr:
828-
// Recursive case for nested expressions
821+
// nested expression
822+
nestedOp := ""
829823
if n.Name != nil {
830-
nestedOp := astutils.Join(n.Name, "")
831-
if lang.IsMathematicalOperator(nestedOp) {
832-
resultType, _ := c.inferMathExpressionType(n, tables, nestedOp)
833-
return resultType
834-
}
824+
nestedOp = astutils.Join(n.Name, "")
835825
}
836-
return "any"
826+
if lang.IsMathematicalOperator(nestedOp) {
827+
t, notNull := c.inferMathExpressionType(n, tables, nestedOp)
828+
return t, notNull
829+
}
830+
return "any", true
837831
default:
838-
return "any"
832+
return "any", true
839833
}
840834
}
841835

842836
// combineTypes determines the result type when combining two operand types with an operation
843837
func (c *Compiler) combineTypes(leftType, rightType, op string) string {
844-
// Handle division specially - division operations typically result in float
845-
if op == "/" {
846-
// If either operand is float, result is float
847-
if leftType == "float" || rightType == "float" {
838+
// Helper function to check if a type is a float variant
839+
isFloatType := func(t string) bool {
840+
return t == "float" || t == "float32" || t == "float64" || t == "double" || t == "double precision" || t == "real"
841+
}
842+
843+
// Helper function to check if a type is an integer variant
844+
isIntType := func(t string) bool {
845+
return t == "int" || t == "int32" || t == "int64" || t == "integer" || t == "int4" || t == "int8" || t == "bigint" || t == "smallint"
846+
}
847+
848+
// Normalize common DB types to standard types
849+
normalizeType := func(t string) string {
850+
switch t {
851+
case "int4", "integer", "int32":
852+
return "int"
853+
case "int8", "bigint", "int64":
854+
return "int"
855+
case "smallint":
856+
return "int"
857+
case "float4", "real", "float32":
858+
return "float"
859+
case "float8", "double precision", "float64":
860+
return "float"
861+
case "any":
862+
return "any"
863+
default:
864+
if isIntType(t) {
865+
return "int"
866+
}
867+
if isFloatType(t) {
868+
return "float"
869+
}
870+
return t
871+
}
872+
}
873+
874+
leftNorm := normalizeType(leftType)
875+
rightNorm := normalizeType(rightType)
876+
877+
// treat MySQL "div" same as "/" for division semantics
878+
if op == "/" || op == "div" {
879+
if leftNorm == "float" || rightNorm == "float" {
880+
return "float"
881+
}
882+
// If both are ints, return float for division (mathematical accuracy)
883+
if leftNorm == "int" && rightNorm == "int" {
884+
return "float"
885+
}
886+
// If at least one is numeric, prefer float
887+
if (leftNorm == "int" || leftNorm == "float") && rightNorm == "any" {
848888
return "float"
849889
}
850-
// Even integer division might want to be float in many cases
851-
// For safety, return float for division unless both operands are clearly non-numeric
852-
if leftType != "text" && rightType != "text" {
890+
if leftNorm == "any" && (rightNorm == "int" || rightNorm == "float") {
853891
return "float"
854892
}
893+
// For mixed types with at least one numeric, prefer float
894+
if leftNorm != "text" && rightNorm != "text" {
895+
return "float"
896+
}
897+
return "any"
855898
}
856899

857-
// For other mathematical operations
858-
switch {
859-
case leftType == "float" || rightType == "float":
900+
// other math ops (* + -):
901+
if leftNorm == "float" || rightNorm == "float" {
860902
return "float"
861-
case leftType == "int" && rightType == "int":
903+
}
904+
if leftNorm == "int" && rightNorm == "int" {
862905
return "int"
863-
case leftType == "int" && rightType == "any":
906+
}
907+
// If one side is numeric and other is any, prefer the numeric type
908+
if leftNorm == "int" && rightNorm == "any" {
864909
return "int"
865-
case leftType == "any" && rightType == "int":
910+
}
911+
if leftNorm == "any" && rightNorm == "int" {
912+
return "int"
913+
}
914+
if leftNorm == "float" && rightNorm == "any" {
915+
return "float"
916+
}
917+
if leftNorm == "any" && rightNorm == "float" {
918+
return "float"
919+
}
920+
// If both are any, prefer int as a reasonable default for math
921+
if leftNorm == "any" && rightNorm == "any" {
866922
return "int"
867-
default:
868-
// Default fallback
869-
return "any"
870923
}
924+
return "any"
871925
}

0 commit comments

Comments
 (0)