@@ -126,7 +126,6 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
126126 continue
127127 }
128128 switch n := res .Val .(type ) {
129-
130129 case * ast.A_Const :
131130 name := ""
132131 if res .Name != nil {
@@ -150,7 +149,8 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
150149 if res .Name != nil {
151150 name = * res .Name
152151 }
153- switch op := astutils .Join (n .Name , "" ); {
152+ op := astutils .Join (n .Name , "" )
153+ switch {
154154 case lang .IsComparisonOperator (op ):
155155 // TODO: Generate a name for these operations
156156 cols = append (cols , & Column {Name : name , DataType : "bool" , NotNull : true })
@@ -774,98 +774,152 @@ func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List)
774774}
775775
776776// inferMathExpressionType attempts to infer the data type of a mathematical expression
777- // by analyzing its operands and the operation being performed
777+ // by analyzing its operands and the operation being performed.
778778func (c * Compiler ) inferMathExpressionType (expr * ast.A_Expr , tables []* Table , op string ) (string , bool ) {
779- // Try to infer types from left and right operands
780- leftType := c .inferOperandType (expr .Lexpr , tables )
781- rightType := c .inferOperandType (expr .Rexpr , tables )
779+ leftType , leftNotNull := c .inferOperandType (expr .Lexpr , tables )
780+ rightType , rightNotNull := c .inferOperandType (expr .Rexpr , tables )
782781
783- // Debug logging to understand what's happening
784- // fmt.Printf("DEBUG: Math expression %s: left=%s, right=%s\n", op, leftType, rightType)
782+ // result is non-null only if both sides are non-null
783+ notNull := leftNotNull && rightNotNull
785784
786- // Determine the result type based on operands and operation
787785 resultType := c .combineTypes (leftType , rightType , op )
788-
789- // For now, assume nullable since we're dealing with database columns
790- // In a more sophisticated implementation, we could track nullability through the expression
791- notNull := false
792-
793786 return resultType , notNull
794787}
795788
796- // inferOperandType tries to determine the type of an operand in an expression
797- func (c * Compiler ) inferOperandType (operand ast.Node , tables []* Table ) string {
789+ // inferOperandType tries to determine the type and nullability of an operand in an expression.
790+ func (c * Compiler ) inferOperandType (operand ast.Node , tables []* Table ) ( string , bool ) {
798791 switch n := operand .(type ) {
799792 case * ast.ColumnRef :
800- // Look up the column in the available tables
801793 parts := stringSlice (n .Fields )
802794 var name string
803795 if len (parts ) >= 1 {
804- name = parts [len (parts )- 1 ] // Get the column name (last part)
796+ name = parts [len (parts )- 1 ]
805797 }
806-
807798 for _ , table := range tables {
808799 for _ , col := range table .Columns {
809800 if col .Name == name {
810- return col .DataType
801+ return col .DataType , col . NotNull
811802 }
812803 }
813804 }
814- return "any"
805+ // Unknown column: assume non-null by default so generated code
806+ // keeps the previous non-nullable behavior (avoids sql.Null*).
807+ return "any" , true
815808 case * ast.A_Const :
816- // Determine type based on constant value
809+ // constants are non-nullable
817810 switch n .Val .(type ) {
818811 case * ast.Integer :
819- return "int"
812+ return "int" , true
820813 case * ast.Float :
821- return "float"
814+ return "float" , true
822815 case * ast.String :
823- return "text"
816+ return "text" , true
824817 default :
825- return "any"
818+ return "any" , true
826819 }
827820 case * ast.A_Expr :
828- // Recursive case for nested expressions
821+ // nested expression
822+ nestedOp := ""
829823 if n .Name != nil {
830- nestedOp := astutils .Join (n .Name , "" )
831- if lang .IsMathematicalOperator (nestedOp ) {
832- resultType , _ := c .inferMathExpressionType (n , tables , nestedOp )
833- return resultType
834- }
824+ nestedOp = astutils .Join (n .Name , "" )
835825 }
836- return "any"
826+ if lang .IsMathematicalOperator (nestedOp ) {
827+ t , notNull := c .inferMathExpressionType (n , tables , nestedOp )
828+ return t , notNull
829+ }
830+ return "any" , true
837831 default :
838- return "any"
832+ return "any" , true
839833 }
840834}
841835
842836// combineTypes determines the result type when combining two operand types with an operation
843837func (c * Compiler ) combineTypes (leftType , rightType , op string ) string {
844- // Handle division specially - division operations typically result in float
845- if op == "/" {
846- // If either operand is float, result is float
847- if leftType == "float" || rightType == "float" {
838+ // Helper function to check if a type is a float variant
839+ isFloatType := func (t string ) bool {
840+ return t == "float" || t == "float32" || t == "float64" || t == "double" || t == "double precision" || t == "real"
841+ }
842+
843+ // Helper function to check if a type is an integer variant
844+ isIntType := func (t string ) bool {
845+ return t == "int" || t == "int32" || t == "int64" || t == "integer" || t == "int4" || t == "int8" || t == "bigint" || t == "smallint"
846+ }
847+
848+ // Normalize common DB types to standard types
849+ normalizeType := func (t string ) string {
850+ switch t {
851+ case "int4" , "integer" , "int32" :
852+ return "int"
853+ case "int8" , "bigint" , "int64" :
854+ return "int"
855+ case "smallint" :
856+ return "int"
857+ case "float4" , "real" , "float32" :
858+ return "float"
859+ case "float8" , "double precision" , "float64" :
860+ return "float"
861+ case "any" :
862+ return "any"
863+ default :
864+ if isIntType (t ) {
865+ return "int"
866+ }
867+ if isFloatType (t ) {
868+ return "float"
869+ }
870+ return t
871+ }
872+ }
873+
874+ leftNorm := normalizeType (leftType )
875+ rightNorm := normalizeType (rightType )
876+
877+ // treat MySQL "div" same as "/" for division semantics
878+ if op == "/" || op == "div" {
879+ if leftNorm == "float" || rightNorm == "float" {
880+ return "float"
881+ }
882+ // If both are ints, return float for division (mathematical accuracy)
883+ if leftNorm == "int" && rightNorm == "int" {
884+ return "float"
885+ }
886+ // If at least one is numeric, prefer float
887+ if (leftNorm == "int" || leftNorm == "float" ) && rightNorm == "any" {
848888 return "float"
849889 }
850- // Even integer division might want to be float in many cases
851- // For safety, return float for division unless both operands are clearly non-numeric
852- if leftType != "text" && rightType != "text" {
890+ if leftNorm == "any" && (rightNorm == "int" || rightNorm == "float" ) {
853891 return "float"
854892 }
893+ // For mixed types with at least one numeric, prefer float
894+ if leftNorm != "text" && rightNorm != "text" {
895+ return "float"
896+ }
897+ return "any"
855898 }
856899
857- // For other mathematical operations
858- switch {
859- case leftType == "float" || rightType == "float" :
900+ // other math ops (* + -):
901+ if leftNorm == "float" || rightNorm == "float" {
860902 return "float"
861- case leftType == "int" && rightType == "int" :
903+ }
904+ if leftNorm == "int" && rightNorm == "int" {
862905 return "int"
863- case leftType == "int" && rightType == "any" :
906+ }
907+ // If one side is numeric and other is any, prefer the numeric type
908+ if leftNorm == "int" && rightNorm == "any" {
864909 return "int"
865- case leftType == "any" && rightType == "int" :
910+ }
911+ if leftNorm == "any" && rightNorm == "int" {
912+ return "int"
913+ }
914+ if leftNorm == "float" && rightNorm == "any" {
915+ return "float"
916+ }
917+ if leftNorm == "any" && rightNorm == "float" {
918+ return "float"
919+ }
920+ // If both are any, prefer int as a reasonable default for math
921+ if leftNorm == "any" && rightNorm == "any" {
866922 return "int"
867- default :
868- // Default fallback
869- return "any"
870923 }
924+ return "any"
871925}
0 commit comments