@@ -19,7 +19,8 @@ type QueryValue struct {
1919
2020 // Column is kept so late in the generation process around to differentiate
2121 // between mysql slices and pg arrays
22- Column * plugin.Column
22+ Column * plugin.Column
23+ QueryText string
2324}
2425
2526func (v QueryValue ) EmitStruct () bool {
@@ -84,6 +85,9 @@ func (v QueryValue) SlicePair() string {
8485
8586func (v QueryValue ) Type () string {
8687 if v .Typ != "" {
88+ if v .isUsedWithArrayComparison () {
89+ return strings .Trim (v .Typ , "[]" ) // Return single type if used in array comparison.
90+ }
8791 return v .Typ
8892 }
8993 if v .Struct != nil {
@@ -112,6 +116,9 @@ func (v QueryValue) UniqueFields() []Field {
112116 fields := make ([]Field , 0 , len (v .Struct .Fields ))
113117
114118 for _ , field := range v .Struct .Fields {
119+ if v .isUsedWithArrayComparison () {
120+ field .Type = strings .Trim (field .Type , "[]" )
121+ }
115122 if _ , found := seen [field .Name ]; found {
116123 continue
117124 }
@@ -128,14 +135,14 @@ func (v QueryValue) Params() string {
128135 }
129136 var out []string
130137 if v .Struct == nil {
131- if ! v .Column .IsSqlcSlice && strings .HasPrefix (v .Typ , "[]" ) && v .Typ != "[]byte" && ! v .SQLDriver .IsPGX () {
138+ if ! v .Column .IsSqlcSlice && strings .HasPrefix (v .Typ , "[]" ) && v .Typ != "[]byte" && ! v .SQLDriver .IsPGX () && ! v . isUsedWithArrayComparison () {
132139 out = append (out , "pq.Array(" + escape (v .Name )+ ")" )
133140 } else {
134141 out = append (out , escape (v .Name ))
135142 }
136143 } else {
137144 for _ , f := range v .Struct .Fields {
138- if ! f .HasSqlcSlice () && strings .HasPrefix (f .Type , "[]" ) && f .Type != "[]byte" && ! v .SQLDriver .IsPGX () {
145+ if ! f .HasSqlcSlice () && strings .HasPrefix (f .Type , "[]" ) && f .Type != "[]byte" && ! v .SQLDriver .IsPGX () && ! v . isUsedWithArrayComparison () {
139146 out = append (out , "pq.Array(" + escape (v .VariableForField (f ))+ ")" )
140147 } else {
141148 out = append (out , escape (v .VariableForField (f )))
@@ -253,6 +260,22 @@ func (v QueryValue) VariableForField(f Field) string {
253260 return v .Name + "." + f .Name
254261}
255262
263+ // isUsedWithArrayComparison returns true if the parameter is used with the ANY/SOME/ALL keyword in query.
264+ func (v QueryValue ) isUsedWithArrayComparison () bool {
265+ if v .Struct != nil {
266+ for _ , f := range v .Struct .Fields {
267+ if strings .Contains (v .QueryText , fmt .Sprintf ("ANY(%s)" , f .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("SOME(%s)" , f .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("ALL(%s)" , f .DBName )) {
268+ return true
269+ }
270+ }
271+ } else {
272+ if strings .Contains (v .QueryText , fmt .Sprintf ("ANY(%s)" , v .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("SOME(%s)" , v .DBName )) || strings .Contains (v .QueryText , fmt .Sprintf ("ALL(%s)" , v .DBName )) {
273+ return true
274+ }
275+ }
276+ return false
277+ }
278+
256279// A struct used to generate methods and fields on the Queries struct
257280type Query struct {
258281 Cmd string
0 commit comments