Skip to content

Commit 10a81e5

Browse files
committed
sql: wrap CASE expressions with parenthesis within PLpgSQL context
Currently, if we generate a CASE expression within the condition of the IF block in PLpgSQL, it cannot be parsed because it's unclear whether THEN belongs to CASE or to IF. In order to disambiguate this, we need to wrap CASE expressions in parenthesis, and we now do so in all sqlsmith-generated queries whenever we're in PLpgSQL context. Release note: None
1 parent 1df70f3 commit 10a81e5

File tree

6 files changed

+36
-12
lines changed

6 files changed

+36
-12
lines changed

pkg/internal/sqlsmith/plpgsql.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
"github.com/cockroachdb/errors"
1515
)
1616

17+
var plpgsqlFlags = tree.FmtParsable | tree.FmtPLpgSQLParen
18+
1719
func (s *Smither) makeRoutineBodyPLpgSQL(
1820
params tree.ParamTypes, rTyp *types.T, vol tree.RoutineVolatility,
1921
) string {
@@ -25,7 +27,7 @@ func (s *Smither) makeRoutineBodyPLpgSQL(
2527
// errors.
2628
block := s.makePLpgSQLBlock(scope)
2729
block.Body = append(block.Body, s.makePLpgSQLReturn(scope))
28-
return "\n" + tree.AsStringWithFlags(s.makePLpgSQLBlock(scope), tree.FmtParsable)
30+
return "\n" + tree.AsStringWithFlags(s.makePLpgSQLBlock(scope), plpgsqlFlags, tree.FmtInPLpgSQL(true /* inPLpgSQL */))
2931
}
3032

3133
func (s *Smither) makePLpgSQLBlock(scope plpgsqlBlockScope) *ast.Block {

pkg/internal/sqlsmith/setup_test.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ package sqlsmith_test
88
import (
99
"context"
1010
"flag"
11-
"fmt"
12-
"strings"
1311
"testing"
1412

1513
"github.com/cockroachdb/cockroach/pkg/base"
@@ -92,9 +90,7 @@ func TestGenerateParse(t *testing.T) {
9290
t.Fatalf("unknown setting %s", settingName)
9391
}
9492
settings := setting(rnd)
95-
t.Log("setting:", settingName, settings.Options)
9693
setupSQL := setup(rnd)
97-
t.Log(strings.Join(setupSQL, "\n"))
9894
for _, stmt := range setupSQL {
9995
db.Exec(t, stmt)
10096
}
@@ -119,7 +115,6 @@ func TestGenerateParse(t *testing.T) {
119115
if err != nil {
120116
t.Fatal(err)
121117
}
122-
fmt.Print("STMT: ", i, "\n", stmt, ";\n\n")
123118
if *flagExec {
124119
_, err = conn.ExecContext(ctx, `SET statement_timeout = '9s'`)
125120
if err != nil {
@@ -129,7 +124,7 @@ func TestGenerateParse(t *testing.T) {
129124
es := err.Error()
130125
if !seen[es] {
131126
seen[es] = true
132-
fmt.Printf("ERR (%d): %v\n", i, err)
127+
t.Logf("ERR (%d): %v\n", i, err)
133128
}
134129
}
135130
}

pkg/internal/sqlsmith/sqlsmith.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ var prettyCfg = func() tree.PrettyCfg {
192192
cfg := tree.DefaultPrettyCfg()
193193
cfg.LineWidth = 120
194194
cfg.Simplify = false
195+
cfg.FmtFlags = tree.FmtPLpgSQLParen
195196
return cfg
196197
}()
197198

@@ -213,10 +214,10 @@ func (s *Smither) Generate() string {
213214
i = 0
214215

215216
printCfg := prettyCfg
216-
fl := tree.FmtParsable
217+
fl := plpgsqlFlags
217218
if s.postgres {
218-
printCfg.FmtFlags = tree.FmtPGCatalog
219-
fl = tree.FmtPGCatalog
219+
printCfg.FmtFlags = tree.FmtPGCatalog | tree.FmtPLpgSQLParen
220+
fl = tree.FmtPGCatalog | tree.FmtPLpgSQLParen
220221
}
221222
p, err := printCfg.Pretty(stmt)
222223
if err != nil {

pkg/sql/sem/plpgsqltree/statements.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,7 @@ func (s *DoBlock) Format(ctx *tree.FmtCtx) {
11211121
// Format the body of the DO block separately so that FormatStringDollarQuotes
11221122
// can examine the resulting string and determine how to quote the block.
11231123
bodyCtx := ctx.Clone()
1124+
tree.FmtInPLpgSQL(true /* inPLpgSQL */)(bodyCtx)
11241125
bodyCtx.FormatNode(s.Block)
11251126
bodyStr := "\n" + bodyCtx.CloseAndGetString()
11261127

pkg/sql/sem/tree/expr.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,13 @@ type CaseExpr struct {
14381438

14391439
// Format implements the NodeFormatter interface.
14401440
func (node *CaseExpr) Format(ctx *FmtCtx) {
1441+
if ctx.HasFlags(FmtPLpgSQLParen) && ctx.inPLpgSQL {
1442+
// In some cases in PLpgSQL context we need to wrap the CASE expression
1443+
// in parenthesis to make it parsable. We do so only if the caller
1444+
// requested it.
1445+
ctx.WriteByte('(')
1446+
defer ctx.WriteByte(')')
1447+
}
14411448
ctx.WriteString("CASE ")
14421449
if node.Expr != nil {
14431450
ctx.FormatNode(node.Expr)

pkg/sql/sem/tree/format.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ const (
201201

202202
// FmtHideHints skips over any hints.
203203
FmtHideHints
204+
205+
// FmtPLpgSQLParen will wrap some expressions in parenthesis when in PLpgSQL
206+
// context. This should only be used in tests.
207+
FmtPLpgSQLParen
204208
)
205209

206210
const genericArityIndicator = "__more__"
@@ -343,6 +347,11 @@ type FmtCtx struct {
343347
// indexedTypeFormatter is an optional interceptor for formatting
344348
// IDTypeReferences differently than normal.
345349
indexedTypeFormatter func(*FmtCtx, *OIDTypeReference)
350+
351+
// inPLpgSQL, if set, indicates that we're formatting a node within PLpgSQL
352+
// context.
353+
inPLpgSQL bool
354+
346355
// small scratch buffer to reduce allocations.
347356
scratch [64]byte
348357
}
@@ -404,6 +413,14 @@ func FmtLocation(loc *time.Location) FmtCtxOption {
404413
}
405414
}
406415

416+
// FmtInPLpgSQL modifies FmtCtx to indicate whether we're in the PLpgSQL
417+
// context.
418+
func FmtInPLpgSQL(inPLpgSQL bool) FmtCtxOption {
419+
return func(ctx *FmtCtx) {
420+
ctx.inPLpgSQL = inPLpgSQL
421+
}
422+
}
423+
407424
// NewFmtCtx creates a FmtCtx; only flags that don't require Annotations
408425
// can be used.
409426
func NewFmtCtx(f FmtFlags, opts ...FmtCtxOption) *FmtCtx {
@@ -423,14 +440,15 @@ func NewFmtCtx(f FmtFlags, opts ...FmtCtxOption) *FmtCtx {
423440
// original.
424441
func (ctx *FmtCtx) Clone() *FmtCtx {
425442
newCtx := fmtCtxPool.Get().(*FmtCtx)
443+
newCtx.dataConversionConfig = ctx.dataConversionConfig
444+
newCtx.location = ctx.location
426445
newCtx.flags = ctx.flags
427446
newCtx.ann = ctx.ann
428447
newCtx.indexedVarFormat = ctx.indexedVarFormat
429448
newCtx.placeholderFormat = ctx.placeholderFormat
430449
newCtx.tableNameFormatter = ctx.tableNameFormatter
431450
newCtx.indexedTypeFormatter = ctx.indexedTypeFormatter
432-
newCtx.dataConversionConfig = ctx.dataConversionConfig
433-
newCtx.location = ctx.location
451+
newCtx.inPLpgSQL = ctx.inPLpgSQL
434452
return newCtx
435453
}
436454

0 commit comments

Comments
 (0)