From b97f300cba10c9ebb2726ed1c952cdd3611496f8 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 1 Oct 2025 10:53:59 -0700 Subject: [PATCH 01/18] implement row2 --- server/handler.go | 28 +++++++++++++ sql/plan/process.go | 23 +++++++++++ sql/rowexec/transaction_iters.go | 16 ++++++++ sql/rows.go | 8 +++- sql/table_iter.go | 70 ++++++++++++++++++++++++++++++++ sql/type.go | 1 - 6 files changed, 144 insertions(+), 2 deletions(-) diff --git a/server/handler.go b/server/handler.go index 2275ca7a2d..701c1141ca 100644 --- a/server/handler.go +++ b/server/handler.go @@ -495,6 +495,8 @@ func (h *Handler) doQuery( r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) + } else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { + r, err = h.resultForDefaultIter2(sqlCtx, ri2, resultFields, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } @@ -768,6 +770,32 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s return r, processedAtLeastOneBatch, nil } +func (h *Handler) resultForDefaultIter2(ctx *sql.Context, iter sql.RowIter2, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, error) { + res := &sqltypes.Result{Fields: resultFields} + for { + if res.RowsAffected == rowsBatch { + if err := callback(res, more); err != nil { + return nil, err + } + res = nil + } + row, err := iter.Next2(ctx) + if err == io.EOF { + return res, nil + } + if err != nil { + return nil, err + } + + outRow := make([]sqltypes.Value, len(res.Rows)) + for i := range row { + outRow[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val) + } + res.Rows = append(res.Rows, outRow) + res.RowsAffected++ + } +} + // See https://dev.mysql.com/doc/internals/en/status-flags.html func setConnStatusFlags(ctx *sql.Context, c *mysql.Conn) error { ok, err := isSessionAutocommit(ctx) diff --git a/sql/plan/process.go b/sql/plan/process.go index ee95249f10..92f33ba19f 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -317,6 +317,29 @@ func (i *TrackedRowIter) Next(ctx *sql.Context) (sql.Row, error) { return row, nil } +func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.Row2, error) { + ri2, ok := i.iter.(sql.RowIter2) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.RowIter2 interface", i.iter)) + } + row, err := ri2.Next2(ctx) + if err != nil { + return nil, err + } + i.numRows++ + if i.onNext != nil { + i.onNext() + } + return row, nil +} + +func (i *TrackedRowIter) IsRowIter2(ctx *sql.Context) bool { + if ri2, ok := i.iter.(sql.RowIter2); ok { + return ri2.IsRowIter2(ctx) + } + return false +} + func (i *TrackedRowIter) Close(ctx *sql.Context) error { err := i.iter.Close(ctx) diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index db69cf5327..aacc1f095d 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -15,6 +15,7 @@ package rowexec import ( + "fmt" "io" "gopkg.in/src-d/go-errors.v1" @@ -99,6 +100,21 @@ func (t *TransactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) { return t.childIter.Next(ctx) } +func (t *TransactionCommittingIter) Next2(ctx *sql.Context) (sql.Row2, error) { + ri2, ok := t.childIter.(sql.RowIter2) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.RowIter2 interface", t.childIter)) + } + return ri2.Next2(ctx) +} + +func (t *TransactionCommittingIter) IsRowIter2(ctx *sql.Context) bool { + if ri2, ok := t.childIter.(sql.RowIter2); ok { + return ri2.IsRowIter2(ctx) + } + return false +} + func (t *TransactionCommittingIter) Close(ctx *sql.Context) error { var err error if t.childIter != nil { diff --git a/sql/rows.go b/sql/rows.go index a9e5f55d5c..191147ad68 100644 --- a/sql/rows.go +++ b/sql/rows.go @@ -92,6 +92,12 @@ type RowIter interface { Closer } +type RowIter2 interface { + RowIter + Next2(ctx *Context) (Row2, error) + IsRowIter2(ctx *Context) bool +} + // RowIterToRows converts a row iterator to a slice of rows. func RowIterToRows(ctx *Context, i RowIter) ([]Row, error) { var rows []Row @@ -112,7 +118,7 @@ func RowIterToRows(ctx *Context, i RowIter) ([]Row, error) { return rows, i.Close(ctx) } -func rowFromRow2(sch Schema, r Row2) Row { +func RowFromRow2(sch Schema, r Row2) Row { row := make(Row, len(sch)) for i, col := range sch { switch col.Type.Type() { diff --git a/sql/table_iter.go b/sql/table_iter.go index e302d5428a..6ac205c377 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -15,6 +15,7 @@ package sql import ( + "fmt" "io" ) @@ -24,6 +25,8 @@ type TableRowIter struct { partitions PartitionIter partition Partition rows RowIter + + rows2 RowIter2 } var _ RowIter = (*TableRowIter)(nil) @@ -76,6 +79,73 @@ func (i *TableRowIter) Next(ctx *Context) (Row, error) { return row, err } +func (i *TableRowIter) Next2(ctx *Context) (Row2, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if i.partition == nil { + partition, err := i.partitions.Next(ctx) + if err != nil { + if err == io.EOF { + if e := i.partitions.Close(ctx); e != nil { + return nil, e + } + } + + return nil, err + } + + i.partition = partition + } + + if i.rows2 == nil { + rows, err := i.table.PartitionRows(ctx, i.partition) + if err != nil { + return nil, err + } + ri2, ok := rows.(RowIter2) + if !ok { + panic(fmt.Sprintf("%T does not implement RowIter2", rows)) + } + i.rows2 = ri2 + } + + row, err := i.rows2.Next2(ctx) + if err != nil && err == io.EOF { + if err = i.rows2.Close(ctx); err != nil { + return nil, err + } + i.partition = nil + i.rows2 = nil + row, err = i.Next2(ctx) + } + return row, err +} + +func (i *TableRowIter) IsRowIter2(ctx *Context) bool { + if i.partition == nil { + partition, err := i.partitions.Next(ctx) + if err != nil { + return false + } + i.partition = partition + } + if i.rows2 == nil { + rows, err := i.table.PartitionRows(ctx, i.partition) + if err != nil { + return false + } + ri2, ok := rows.(RowIter2) + if !ok { + return false + } + i.rows2 = ri2 + } + return i.rows2.IsRowIter2(ctx) +} + func (i *TableRowIter) Close(ctx *Context) error { if i.rows != nil { if err := i.rows.Close(ctx); err != nil { diff --git a/sql/type.go b/sql/type.go index 59af5360f1..e4d0f8ff96 100644 --- a/sql/type.go +++ b/sql/type.go @@ -294,7 +294,6 @@ func IsDecimalType(t Type) bool { type Type2 interface { Type - // Compare2 returns an integer comparing two Values. Compare2(Value, Value) (int, error) // Convert2 converts a value of a compatible type. From 1f43907c5ed9bee8d2ea8e544c831b2d6ebcdfaf Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 1 Oct 2025 16:37:57 -0700 Subject: [PATCH 02/18] better --- server/handler.go | 117 +++++++++++++++++++++++++++++++++++++-------- sql/plan/filter.go | 24 ++++++++++ 2 files changed, 121 insertions(+), 20 deletions(-) diff --git a/server/handler.go b/server/handler.go index 701c1141ca..bd82c27dfa 100644 --- a/server/handler.go +++ b/server/handler.go @@ -496,7 +496,7 @@ func (h *Handler) doQuery( } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) } else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { - r, err = h.resultForDefaultIter2(sqlCtx, ri2, resultFields, callback, more) + r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } @@ -770,30 +770,107 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s return r, processedAtLeastOneBatch, nil } -func (h *Handler) resultForDefaultIter2(ctx *sql.Context, iter sql.RowIter2, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, error) { - res := &sqltypes.Result{Fields: resultFields} - for { - if res.RowsAffected == rowsBatch { - if err := callback(res, more); err != nil { - return nil, err - } - res = nil - } - row, err := iter.Next2(ctx) - if err == io.EOF { - return res, nil +func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sql.RowIter2, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) { + defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End() + + eg, ctx := ctx.NewErrgroup() + pan2err := func(err *error) { + if recoveredPanic := recover(); recoveredPanic != nil { + stack := debug.Stack() + wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, stack) + *err = goerrors.Join(*err, wrappedErr) } - if err != nil { - return nil, err + } + + // TODO: poll for closed connections should obviously also run even if + // we're doing something with an OK result or a single row result, etc. + // This should be in the caller. + pollCtx, cancelF := ctx.NewSubContext() + eg.Go(func() (err error) { + defer pan2err(&err) + return h.pollForClosedConnection(pollCtx, c) + }) + + // Default waitTime is one minute if there is no timeout configured, in which case + // it will loop to iterate again unless the socket died by the OS timeout or other problems. + // If there is a timeout, it will be enforced to ensure that Vitess has a chance to + // call Handler.CloseConnection() + waitTime := 1 * time.Minute + if h.readTimeout > 0 { + waitTime = h.readTimeout + } + timer := time.NewTimer(waitTime) + defer timer.Stop() + + wg := sync.WaitGroup{} + wg.Add(1) + + var res *sqltypes.Result + var processedAtLeastOneBatch bool + eg.Go(func() (err error) { + defer pan2err(&err) + defer cancelF() + defer wg.Done() + for { + if res == nil { + res = &sqltypes.Result{Fields: resultFields} + } + if res.RowsAffected == rowsBatch { + if err := callback(res, more); err != nil { + return err + } + res = nil + processedAtLeastOneBatch = true + continue + } + + select { + case <-ctx.Done(): + return context.Cause(ctx) + case <-timer.C: + // TODO: timer should probably go in its own thread, as rowChan is blocking + if h.readTimeout != 0 { + // Cancel and return so Vitess can call the CloseConnection callback + ctx.GetLogger().Tracef("connection timeout") + return ErrRowTimeout.New() + } + default: + row, err := iter.Next2(ctx) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + outRow := make([]sqltypes.Value, len(row)) + for i := range row { + outRow[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val) + } + res.Rows = append(res.Rows, outRow) + res.RowsAffected++ + } + timer.Reset(waitTime) } + }) - outRow := make([]sqltypes.Value, len(res.Rows)) - for i := range row { - outRow[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val) + // Close() kills this PID in the process list, + // wait until all rows have be sent over the wire + eg.Go(func() (err error) { + defer pan2err(&err) + wg.Wait() + return iter.Close(ctx) + }) + + err := eg.Wait() + if err != nil { + ctx.GetLogger().WithError(err).Warn("error running query") + if verboseErrorLogging { + fmt.Printf("Err: %+v", err) } - res.Rows = append(res.Rows, outRow) - res.RowsAffected++ + return nil, false, err } + + return res, processedAtLeastOneBatch, nil } // See https://dev.mysql.com/doc/internals/en/status-flags.html diff --git a/sql/plan/filter.go b/sql/plan/filter.go index f2c0691112..57c64664df 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -15,6 +15,7 @@ package plan import ( + "fmt" "github.com/dolthub/go-mysql-server/sql" ) @@ -133,6 +134,29 @@ func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) { } } +func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { + ri2, ok := i.childIter.(sql.RowIter2) + if !ok { + panic(fmt.Sprintf("%T is not a sql.RowIter2", i.childIter)) + } + + for { + row, err := ri2.Next(ctx) + if err != nil { + return nil, err + } + + res, err := sql.EvaluateCondition(ctx, i.cond, row) + if err != nil { + return nil, err + } + + if sql.IsTrue(res) { + return nil, nil + } + } +} + // Close implements the RowIter interface. func (i *FilterIter) Close(ctx *sql.Context) error { return i.childIter.Close(ctx) From 461cf294e9f57001decb8ea2636b05b3614fc82e Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 2 Oct 2025 17:59:53 +0000 Subject: [PATCH 03/18] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/plan/filter.go | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/plan/filter.go b/sql/plan/filter.go index 57c64664df..8be34ec2c7 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -16,6 +16,7 @@ package plan import ( "fmt" + "github.com/dolthub/go-mysql-server/sql" ) From af05fafb96e2588e10ff04ad26702a60443c7ac2 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 10 Oct 2025 13:28:31 -0700 Subject: [PATCH 04/18] disable row2 --- server/handler.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/handler.go b/server/handler.go index bd82c27dfa..cc2621f54c 100644 --- a/server/handler.go +++ b/server/handler.go @@ -495,8 +495,8 @@ func (h *Handler) doQuery( r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) - } else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { - r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more) + //} else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { + // r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } From d7f3358571817c9dd1af1bcccf1d6ac597cd838c Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 10 Oct 2025 16:28:23 -0700 Subject: [PATCH 05/18] reenable row2 --- server/handler.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/handler.go b/server/handler.go index cc2621f54c..bd82c27dfa 100644 --- a/server/handler.go +++ b/server/handler.go @@ -495,8 +495,8 @@ func (h *Handler) doQuery( r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) - //} else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { - // r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more) + } else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { + r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } From 1a979832fafba48fb7480618ca0f1e3dcc42daee Mon Sep 17 00:00:00 2001 From: James Cor Date: Sun, 12 Oct 2025 17:52:13 -0700 Subject: [PATCH 06/18] implement expr2 for filters --- sql/core.go | 1 + sql/expression/comparison.go | 56 ++++++++++++++++++++++++++++++++++++ sql/expression/get_field.go | 4 +++ sql/expression/literal.go | 4 +++ sql/expression/unresolved.go | 4 +++ sql/plan/filter.go | 24 +++++++++++++--- 6 files changed, 89 insertions(+), 4 deletions(-) diff --git a/sql/core.go b/sql/core.go index c1e1f90b2a..c2996039eb 100644 --- a/sql/core.go +++ b/sql/core.go @@ -467,6 +467,7 @@ type Expression2 interface { Eval2(ctx *Context, row Row2) (Value, error) // Type2 returns the expression type. Type2() Type2 + IsExpr2() bool } var SystemVariables SystemVariableRegistry diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 97e5f4ba6e..6b01443c9f 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -15,7 +15,9 @@ package expression import ( + "bytes" "fmt" + querypb "github.com/dolthub/vitess/go/vt/proto/query" errors "gopkg.in/src-d/go-errors.v1" @@ -518,6 +520,60 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return result == 1, nil } +func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { + l, ok := gt.Left().(sql.Expression2) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Left())) + } + r, ok := gt.Right().(sql.Expression2) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Right())) + } + + lv, err := l.Eval2(ctx, row) + if err != nil { + return sql.Value{}, nil + } + rv, err := r.Eval2(ctx, row) + if err != nil { + return sql.Value{}, nil + } + + // TODO: better implementation + res := bytes.Compare(lv.Val, rv.Val) // TODO: this is probably wrong + var rb byte + if res == 1 { + rb = 1 + } + ret := sql.Value{ + Val: sql.ValueBytes{rb}, + Typ: querypb.Type_INT8, + } + return ret, nil +} + +func (gt *GreaterThan) Type2() sql.Type2 { + return nil +} + +func (gt *GreaterThan) IsExpr2() bool { + lExpr, isExpr2 := gt.Left().(sql.Expression2) + if !isExpr2 { + return false + } + if !lExpr.IsExpr2() { + return false + } + rExpr, isExpr2 := gt.Right().(sql.Expression2) + if !isExpr2 { + return false + } + if !rExpr.IsExpr2() { + return false + } + return true +} + // WithChildren implements the Expression interface. func (gt *GreaterThan) WithChildren(children ...sql.Expression) (sql.Expression, error) { if len(children) != 2 { diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index f4ff9b429e..5e9263760f 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -157,6 +157,10 @@ func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { return row.GetField(p.fieldIndex), nil } +func (p *GetField) IsExpr2() bool { + return true +} + // WithChildren implements the Expression interface. func (p *GetField) WithChildren(children ...sql.Expression) (sql.Expression, error) { if len(children) != 0 { diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 8fff9557a7..2b5583dc5b 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -140,6 +140,10 @@ func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { return lit.val2, nil } +func (lit *Literal) IsExpr2() bool { + return true +} + func (lit *Literal) Type2() sql.Type2 { t2, ok := lit.Typ.(sql.Type2) if !ok { diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index 78e2e9d0b9..c421699722 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -79,6 +79,10 @@ func (uc *UnresolvedColumn) Type2() sql.Type2 { panic("unresolved column is a placeholder node, but Type2 was called") } +func (uc *UnresolvedColumn) IsExpr2() bool { + panic("unresolved column is a placeholder node, but IsExpr2 was called") +} + // Name implements the Nameable interface. func (uc *UnresolvedColumn) Name() string { return uc.name } diff --git a/sql/plan/filter.go b/sql/plan/filter.go index 8be34ec2c7..b9edb0aa41 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -142,20 +142,36 @@ func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { } for { - row, err := ri2.Next(ctx) + row, err := ri2.Next2(ctx) if err != nil { return nil, err } - res, err := sql.EvaluateCondition(ctx, i.cond, row) + // TODO: write EvaluateCondition2? + cond, isCond2 := i.cond.(sql.Expression2) + if !isCond2 { + panic(fmt.Sprintf("%T does not implement sql.Expression2 interface", i.cond)) + } + res, err := cond.Eval2(ctx, row) if err != nil { return nil, err } + if res.Val[0] == 1 { + return row, nil + } + } +} - if sql.IsTrue(res) { - return nil, nil +func (i *FilterIter) IsRowIter2(ctx *sql.Context) bool { + if cond, isExpr2 := i.cond.(sql.Expression2); isExpr2 { + if !cond.IsExpr2() { + return false } } + if ri2, ok := i.childIter.(sql.RowIter2); ok { + return ri2.IsRowIter2(ctx) + } + return false } // Close implements the RowIter interface. From 16e6f0205a4c4504d0d4fa3152fbc09d8dafc0c3 Mon Sep 17 00:00:00 2001 From: jycor Date: Mon, 13 Oct 2025 00:55:24 +0000 Subject: [PATCH 07/18] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/comparison.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 6b01443c9f..9ff27d108e 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -17,8 +17,8 @@ package expression import ( "bytes" "fmt" - querypb "github.com/dolthub/vitess/go/vt/proto/query" + querypb "github.com/dolthub/vitess/go/vt/proto/query" errors "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" From b7c458bdd1474a2da11e305b49ecedfcb5d03fc6 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 13 Oct 2025 01:02:58 -0700 Subject: [PATCH 08/18] reduce type asserts --- sql/plan/filter.go | 36 +++++++++++++------------------- sql/plan/process.go | 15 +++++++------ sql/rowexec/transaction_iters.go | 16 +++++++------- sql/table_iter.go | 2 +- 4 files changed, 29 insertions(+), 40 deletions(-) diff --git a/sql/plan/filter.go b/sql/plan/filter.go index b9edb0aa41..c2bf35c50e 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -15,8 +15,6 @@ package plan import ( - "fmt" - "github.com/dolthub/go-mysql-server/sql" ) @@ -106,6 +104,9 @@ func (f *Filter) Expressions() []sql.Expression { type FilterIter struct { cond sql.Expression childIter sql.RowIter + + cond2 sql.Expression2 + childIter2 sql.RowIter2 } // NewFilterIter creates a new FilterIter. @@ -136,23 +137,12 @@ func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) { } func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { - ri2, ok := i.childIter.(sql.RowIter2) - if !ok { - panic(fmt.Sprintf("%T is not a sql.RowIter2", i.childIter)) - } - for { - row, err := ri2.Next2(ctx) + row, err := i.childIter2.Next2(ctx) if err != nil { return nil, err } - - // TODO: write EvaluateCondition2? - cond, isCond2 := i.cond.(sql.Expression2) - if !isCond2 { - panic(fmt.Sprintf("%T does not implement sql.Expression2 interface", i.cond)) - } - res, err := cond.Eval2(ctx, row) + res, err := i.cond2.Eval2(ctx, row) if err != nil { return nil, err } @@ -163,15 +153,17 @@ func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { } func (i *FilterIter) IsRowIter2(ctx *sql.Context) bool { - if cond, isExpr2 := i.cond.(sql.Expression2); isExpr2 { - if !cond.IsExpr2() { - return false - } + cond, ok := i.cond.(sql.Expression2) + if !ok || !cond.IsExpr2() { + return false } - if ri2, ok := i.childIter.(sql.RowIter2); ok { - return ri2.IsRowIter2(ctx) + childIter, ok := i.childIter.(sql.RowIter2) + if !ok || !childIter.IsRowIter2(ctx) { + return false } - return false + i.cond2 = cond + i.childIter2 = childIter + return true } // Close implements the RowIter interface. diff --git a/sql/plan/process.go b/sql/plan/process.go index 92f33ba19f..70a687247f 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -226,6 +226,7 @@ const ( type TrackedRowIter struct { node sql.Node iter sql.RowIter + iter2 sql.RowIter2 onDone NotifyFunc onNext NotifyFunc numRows int64 @@ -318,11 +319,7 @@ func (i *TrackedRowIter) Next(ctx *sql.Context) (sql.Row, error) { } func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.Row2, error) { - ri2, ok := i.iter.(sql.RowIter2) - if !ok { - panic(fmt.Sprintf("%T does not implement sql.RowIter2 interface", i.iter)) - } - row, err := ri2.Next2(ctx) + row, err := i.iter2.Next2(ctx) if err != nil { return nil, err } @@ -334,10 +331,12 @@ func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.Row2, error) { } func (i *TrackedRowIter) IsRowIter2(ctx *sql.Context) bool { - if ri2, ok := i.iter.(sql.RowIter2); ok { - return ri2.IsRowIter2(ctx) + iter, ok := i.iter.(sql.RowIter2) + if !ok || !iter.IsRowIter2(ctx) { + return false } - return false + i.iter2 = iter + return true } func (i *TrackedRowIter) Close(ctx *sql.Context) error { diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index aacc1f095d..f0f56168ef 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -15,7 +15,6 @@ package rowexec import ( - "fmt" "io" "gopkg.in/src-d/go-errors.v1" @@ -72,6 +71,7 @@ func getLockableTable(table sql.Table) (sql.Lockable, error) { // during the Close() operation type TransactionCommittingIter struct { childIter sql.RowIter + childIter2 sql.RowIter2 transactionDatabase string autoCommit bool implicitCommit bool @@ -101,18 +101,16 @@ func (t *TransactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) { } func (t *TransactionCommittingIter) Next2(ctx *sql.Context) (sql.Row2, error) { - ri2, ok := t.childIter.(sql.RowIter2) - if !ok { - panic(fmt.Sprintf("%T does not implement sql.RowIter2 interface", t.childIter)) - } - return ri2.Next2(ctx) + return t.childIter2.Next2(ctx) } func (t *TransactionCommittingIter) IsRowIter2(ctx *sql.Context) bool { - if ri2, ok := t.childIter.(sql.RowIter2); ok { - return ri2.IsRowIter2(ctx) + childIter, ok := t.childIter.(sql.RowIter2) + if !ok || !childIter.IsRowIter2(ctx) { + return false } - return false + t.childIter2 = childIter + return true } func (t *TransactionCommittingIter) Close(ctx *sql.Context) error { diff --git a/sql/table_iter.go b/sql/table_iter.go index 6ac205c377..884778307a 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -106,7 +106,7 @@ func (i *TableRowIter) Next2(ctx *Context) (Row2, error) { return nil, err } ri2, ok := rows.(RowIter2) - if !ok { + if !ok || !ri2.IsRowIter2(ctx) { panic(fmt.Sprintf("%T does not implement RowIter2", rows)) } i.rows2 = ri2 From b9ef4ee014fbbc04e49e148a7506488680e183c3 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 13 Oct 2025 01:09:45 -0700 Subject: [PATCH 09/18] split send and receive --- server/handler.go | 46 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/server/handler.go b/server/handler.go index bd82c27dfa..34ca72882d 100644 --- a/server/handler.go +++ b/server/handler.go @@ -771,7 +771,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s } func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sql.RowIter2, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) { - defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End() + defer trace.StartRegion(ctx, "Handler.resultForDefaultIter2").End() eg, ctx := ctx.NewErrgroup() pan2err := func(err *error) { @@ -803,7 +803,34 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq defer timer.Stop() wg := sync.WaitGroup{} - wg.Add(1) + wg.Add(2) + + // TODO: this should be merged below go func + var rowChan = make(chan sql.Row2, 512) + eg.Go(func() (err error) { + defer pan2err(&err) + defer wg.Done() + defer close(rowChan) + for { + select { + case <-ctx.Done(): + return context.Cause(ctx) + default: + row, err := iter.Next2(ctx) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + select { + case rowChan <- row: + case <-ctx.Done(): + return nil + } + } + } + }) var res *sqltypes.Result var processedAtLeastOneBatch bool @@ -813,7 +840,10 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq defer wg.Done() for { if res == nil { - res = &sqltypes.Result{Fields: resultFields} + res = &sqltypes.Result{ + Fields: resultFields, + Rows: make([][]sqltypes.Value, 0, rowsBatch), + } } if res.RowsAffected == rowsBatch { if err := callback(res, more); err != nil { @@ -834,14 +864,12 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq ctx.GetLogger().Tracef("connection timeout") return ErrRowTimeout.New() } - default: - row, err := iter.Next2(ctx) - if err == io.EOF { + case row, ok := <-rowChan: + if !ok { return nil } - if err != nil { - return err - } + // TODO: we can avoid deep copy here by redefining sql.Row2 + ctx.GetLogger().Tracef("spooling result row %s", row) outRow := make([]sqltypes.Value, len(row)) for i := range row { outRow[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val) From 643530a2fd358dbc11ee2e94991537630ed855bb Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 13 Oct 2025 11:52:39 -0700 Subject: [PATCH 10/18] directly return rows --- server/handler.go | 47 +++------- sql/convert_value.go | 92 +++++-------------- sql/core.go | 3 +- sql/expression/comparison.go | 28 +++--- sql/expression/get_field.go | 8 +- sql/expression/literal.go | 5 +- sql/expression/unresolved.go | 3 +- sql/plan/filter.go | 5 +- sql/row_frame.go | 15 ++-- sql/rows.go | 24 ++--- sql/type.go | 8 +- sql/types/number.go | 168 +++++++++++++---------------------- 12 files changed, 145 insertions(+), 261 deletions(-) diff --git a/server/handler.go b/server/handler.go index 34ca72882d..fe1b50d700 100644 --- a/server/handler.go +++ b/server/handler.go @@ -803,34 +803,7 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq defer timer.Stop() wg := sync.WaitGroup{} - wg.Add(2) - - // TODO: this should be merged below go func - var rowChan = make(chan sql.Row2, 512) - eg.Go(func() (err error) { - defer pan2err(&err) - defer wg.Done() - defer close(rowChan) - for { - select { - case <-ctx.Done(): - return context.Cause(ctx) - default: - row, err := iter.Next2(ctx) - if err == io.EOF { - return nil - } - if err != nil { - return err - } - select { - case rowChan <- row: - case <-ctx.Done(): - return nil - } - } - } - }) + wg.Add(1) var res *sqltypes.Result var processedAtLeastOneBatch bool @@ -864,18 +837,20 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq ctx.GetLogger().Tracef("connection timeout") return ErrRowTimeout.New() } - case row, ok := <-rowChan: - if !ok { + default: + row, err := iter.Next2(ctx) + if err == io.EOF { return nil } - // TODO: we can avoid deep copy here by redefining sql.Row2 - ctx.GetLogger().Tracef("spooling result row %s", row) - outRow := make([]sqltypes.Value, len(row)) - for i := range row { - outRow[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val) + if err != nil { + return err } - res.Rows = append(res.Rows, outRow) + ctx.GetLogger().Tracef("spooling result row %s", row) + res.Rows = append(res.Rows, row) res.RowsAffected++ + if !timer.Stop() { + <-timer.C + } } timer.Reset(waitTime) } diff --git a/sql/convert_value.go b/sql/convert_value.go index d46fe4de4e..d64d5fbe98 100644 --- a/sql/convert_value.go +++ b/sql/convert_value.go @@ -3,98 +3,46 @@ package sql import ( "fmt" - "github.com/dolthub/vitess/go/vt/proto/query" - "github.com/dolthub/go-mysql-server/sql/values" + + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" ) // ConvertToValue converts the interface to a sql value. -func ConvertToValue(v interface{}) (Value, error) { +func ConvertToValue(v interface{}) (sqltypes.Value, error) { switch v := v.(type) { case nil: - return Value{ - Typ: query.Type_NULL_TYPE, - Val: nil, - }, nil + return sqltypes.MakeTrusted(query.Type_NULL_TYPE, nil), nil case int: - return Value{ - Typ: query.Type_INT64, - Val: values.WriteInt64(make([]byte, values.Int64Size), int64(v)), - }, nil + return sqltypes.MakeTrusted(query.Type_INT64, values.WriteInt64(make([]byte, values.Int64Size), int64(v))), nil case int8: - return Value{ - Typ: query.Type_INT8, - Val: values.WriteInt8(make([]byte, values.Int8Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_INT8, values.WriteInt8(make([]byte, values.Int8Size), v)), nil case int16: - return Value{ - Typ: query.Type_INT16, - Val: values.WriteInt16(make([]byte, values.Int16Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_INT16, values.WriteInt16(make([]byte, values.Int16Size), v)), nil case int32: - return Value{ - Typ: query.Type_INT32, - Val: values.WriteInt32(make([]byte, values.Int32Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_INT32, values.WriteInt32(make([]byte, values.Int32Size), v)), nil case int64: - return Value{ - Typ: query.Type_INT64, - Val: values.WriteInt64(make([]byte, values.Int64Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_INT64, values.WriteInt64(make([]byte, values.Int64Size), v)), nil case uint: - return Value{ - Typ: query.Type_UINT64, - Val: values.WriteUint64(make([]byte, values.Uint64Size), uint64(v)), - }, nil + return sqltypes.MakeTrusted(query.Type_UINT64, values.WriteUint64(make([]byte, values.Uint64Size), uint64(v))), nil case uint8: - return Value{ - Typ: query.Type_UINT8, - Val: values.WriteUint8(make([]byte, values.Uint8Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_UINT8, values.WriteUint8(make([]byte, values.Uint8Size), v)), nil case uint16: - return Value{ - Typ: query.Type_UINT16, - Val: values.WriteUint16(make([]byte, values.Uint16Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_UINT16, values.WriteUint16(make([]byte, values.Uint16Size), v)), nil case uint32: - return Value{ - Typ: query.Type_UINT32, - Val: values.WriteUint32(make([]byte, values.Uint32Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_UINT32, values.WriteUint32(make([]byte, values.Uint32Size), v)), nil case uint64: - return Value{ - Typ: query.Type_UINT64, - Val: values.WriteUint64(make([]byte, values.Uint64Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_UINT64, values.WriteUint64(make([]byte, values.Uint64Size), v)), nil case float32: - return Value{ - Typ: query.Type_FLOAT32, - Val: values.WriteFloat32(make([]byte, values.Float32Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_FLOAT32, values.WriteFloat32(make([]byte, values.Float32Size), v)), nil case float64: - return Value{ - Typ: query.Type_FLOAT64, - Val: values.WriteFloat64(make([]byte, values.Float64Size), v), - }, nil + return sqltypes.MakeTrusted(query.Type_FLOAT64, values.WriteFloat64(make([]byte, values.Float64Size), v)), nil case string: - return Value{ - Typ: query.Type_VARCHAR, - Val: values.WriteString(make([]byte, len(v)), v, values.ByteOrderCollation), - }, nil + return sqltypes.MakeTrusted(query.Type_VARCHAR, values.WriteString(make([]byte, len(v)), v, values.ByteOrderCollation)), nil case []byte: - return Value{ - Typ: query.Type_BLOB, - Val: values.WriteBytes(make([]byte, len(v)), v, values.ByteOrderCollation), - }, nil + return sqltypes.MakeTrusted(query.Type_BLOB, values.WriteBytes(make([]byte, len(v)), v, values.ByteOrderCollation)), nil default: - return Value{}, fmt.Errorf("type %T not implemented", v) - } -} - -func MustConvertToValue(v interface{}) Value { - ret, err := ConvertToValue(v) - if err != nil { - panic(err) + return sqltypes.Value{}, fmt.Errorf("type %T not implemented", v) } - return ret } diff --git a/sql/core.go b/sql/core.go index c2996039eb..19a7a7a895 100644 --- a/sql/core.go +++ b/sql/core.go @@ -30,6 +30,7 @@ import ( "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql/values" + "github.com/dolthub/vitess/go/sqltypes" ) // Expression is a combination of one or more SQL expressions. @@ -464,7 +465,7 @@ func DebugString(nodeOrExpression interface{}) string { type Expression2 interface { Expression // Eval2 evaluates the given row frame and returns a result. - Eval2(ctx *Context, row Row2) (Value, error) + Eval2(ctx *Context, row Row2) (sqltypes.Value, error) // Type2 returns the expression type. Type2() Type2 IsExpr2() bool diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 9ff27d108e..aa4ade4ff9 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -15,8 +15,8 @@ package expression import ( - "bytes" "fmt" + "github.com/dolthub/vitess/go/sqltypes" querypb "github.com/dolthub/vitess/go/vt/proto/query" errors "gopkg.in/src-d/go-errors.v1" @@ -494,6 +494,7 @@ type GreaterThan struct { } var _ sql.Expression = (*GreaterThan)(nil) +var _ sql.Expression2 = (*GreaterThan)(nil) var _ sql.CollationCoercible = (*GreaterThan)(nil) // NewGreaterThan creates a new GreaterThan expression. @@ -520,7 +521,7 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return result == 1, nil } -func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { +func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { l, ok := gt.Left().(sql.Expression2) if !ok { panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Left())) @@ -532,23 +533,28 @@ func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) lv, err := l.Eval2(ctx, row) if err != nil { - return sql.Value{}, nil + return sqltypes.Value{}, err } rv, err := r.Eval2(ctx, row) if err != nil { - return sql.Value{}, nil + return sqltypes.Value{}, err } - // TODO: better implementation - res := bytes.Compare(lv.Val, rv.Val) // TODO: this is probably wrong + // TODO: just assume they are int64 + l64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, lv) + if err != nil { + return sqltypes.Value{}, err + } + r64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, rv) + if err != nil { + return sqltypes.Value{}, err + } var rb byte - if res == 1 { + if l64 > r64 { rb = 1 } - ret := sql.Value{ - Val: sql.ValueBytes{rb}, - Typ: querypb.Type_INT8, - } + + ret := sqltypes.MakeTrusted(querypb.Type_INT8, []byte{rb}) return ret, nil } diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index 5e9263760f..398ca7107a 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -16,6 +16,7 @@ package expression import ( "fmt" + "github.com/dolthub/vitess/go/sqltypes" "strings" errors "gopkg.in/src-d/go-errors.v1" @@ -149,12 +150,11 @@ func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return row[p.fieldIndex], nil } -func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { +func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { if p.fieldIndex < 0 || p.fieldIndex >= row.Len() { - return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len()) + return sqltypes.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len()) } - - return row.GetField(p.fieldIndex), nil + return row[p.fieldIndex], nil } func (p *GetField) IsExpr2() bool { diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 2b5583dc5b..b386b86412 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -16,6 +16,7 @@ package expression import ( "fmt" + "github.com/dolthub/vitess/go/sqltypes" "strings" "github.com/dolthub/vitess/go/vt/proto/query" @@ -30,7 +31,7 @@ import ( type Literal struct { Val interface{} Typ sql.Type - val2 sql.Value + val2 sqltypes.Value } var _ sql.Expression = &Literal{} @@ -136,7 +137,7 @@ func (*Literal) Children() []sql.Expression { return nil } -func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { +func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { return lit.val2, nil } diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index c421699722..0173651464 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -16,6 +16,7 @@ package expression import ( "fmt" + "github.com/dolthub/vitess/go/sqltypes" "strings" "gopkg.in/src-d/go-errors.v1" @@ -71,7 +72,7 @@ func (*UnresolvedColumn) CollationCoercibility(ctx *sql.Context) (collation sql. return sql.Collation_binary, 7 } -func (uc *UnresolvedColumn) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { +func (uc *UnresolvedColumn) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { panic("unresolved column is a placeholder node, but Eval2 was called") } diff --git a/sql/plan/filter.go b/sql/plan/filter.go index c2bf35c50e..0d0e4ebe39 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -109,6 +109,9 @@ type FilterIter struct { childIter2 sql.RowIter2 } +var _ sql.RowIter = (*FilterIter)(nil) +var _ sql.RowIter2 = (*FilterIter)(nil) + // NewFilterIter creates a new FilterIter. func NewFilterIter( cond sql.Expression, @@ -146,7 +149,7 @@ func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { if err != nil { return nil, err } - if res.Val[0] == 1 { + if res.Raw()[0] == 1 { return row, nil } } diff --git a/sql/row_frame.go b/sql/row_frame.go index ef3ea6010f..a4384ec458 100644 --- a/sql/row_frame.go +++ b/sql/row_frame.go @@ -17,6 +17,7 @@ package sql import ( "sync" + "github.com/dolthub/vitess/go/sqltypes" querypb "github.com/dolthub/vitess/go/vt/proto/query" ) @@ -26,10 +27,10 @@ const ( ) // Row2 is a slice of values -type Row2 []Value +type Row2 []sqltypes.Value // GetField returns the Value for the ith field in this row. -func (r Row2) GetField(i int) Value { +func (r Row2) GetField(i int) sqltypes.Value { return r[i] } @@ -97,10 +98,7 @@ func (f *RowFrame) Row2() Row2 { rs := make(Row2, len(f.Values)) for i := range f.Values { - rs[i] = Value{ - Typ: f.Types[i], - Val: f.Values[i], - } + rs[i] = sqltypes.MakeTrusted(f.Types[i], f.Values[i]) } return rs } @@ -113,10 +111,7 @@ func (f *RowFrame) Row2Copy() Row2 { for i := range f.Values { v := make(ValueBytes, len(f.Values[i])) copy(v, f.Values[i]) - rs[i] = Value{ - Typ: f.Types[i], - Val: v, - } + rs[i] = sqltypes.MakeTrusted(f.Types[i], v) } return rs } diff --git a/sql/rows.go b/sql/rows.go index 191147ad68..2a969363bc 100644 --- a/sql/rows.go +++ b/sql/rows.go @@ -123,29 +123,29 @@ func RowFromRow2(sch Schema, r Row2) Row { for i, col := range sch { switch col.Type.Type() { case query.Type_INT8: - row[i] = values.ReadInt8(r.GetField(i).Val) + row[i] = values.ReadInt8(r.GetField(i).Raw()) case query.Type_UINT8: - row[i] = values.ReadUint8(r.GetField(i).Val) + row[i] = values.ReadUint8(r.GetField(i).Raw()) case query.Type_INT16: - row[i] = values.ReadInt16(r.GetField(i).Val) + row[i] = values.ReadInt16(r.GetField(i).Raw()) case query.Type_UINT16: - row[i] = values.ReadUint16(r.GetField(i).Val) + row[i] = values.ReadUint16(r.GetField(i).Raw()) case query.Type_INT32: - row[i] = values.ReadInt32(r.GetField(i).Val) + row[i] = values.ReadInt32(r.GetField(i).Raw()) case query.Type_UINT32: - row[i] = values.ReadUint32(r.GetField(i).Val) + row[i] = values.ReadUint32(r.GetField(i).Raw()) case query.Type_INT64: - row[i] = values.ReadInt64(r.GetField(i).Val) + row[i] = values.ReadInt64(r.GetField(i).Raw()) case query.Type_UINT64: - row[i] = values.ReadUint64(r.GetField(i).Val) + row[i] = values.ReadUint64(r.GetField(i).Raw()) case query.Type_FLOAT32: - row[i] = values.ReadFloat32(r.GetField(i).Val) + row[i] = values.ReadFloat32(r.GetField(i).Raw()) case query.Type_FLOAT64: - row[i] = values.ReadFloat64(r.GetField(i).Val) + row[i] = values.ReadFloat64(r.GetField(i).Raw()) case query.Type_TEXT, query.Type_VARCHAR, query.Type_CHAR: - row[i] = values.ReadString(r.GetField(i).Val, values.ByteOrderCollation) + row[i] = values.ReadString(r.GetField(i).Raw(), values.ByteOrderCollation) case query.Type_BLOB, query.Type_VARBINARY, query.Type_BINARY: - row[i] = values.ReadBytes(r.GetField(i).Val, values.ByteOrderCollation) + row[i] = values.ReadBytes(r.GetField(i).Raw(), values.ByteOrderCollation) case query.Type_BIT: fallthrough case query.Type_ENUM: diff --git a/sql/type.go b/sql/type.go index e4d0f8ff96..285744a564 100644 --- a/sql/type.go +++ b/sql/type.go @@ -295,13 +295,11 @@ func IsDecimalType(t Type) bool { type Type2 interface { Type // Compare2 returns an integer comparing two Values. - Compare2(Value, Value) (int, error) + Compare2(sqltypes.Value, sqltypes.Value) (int, error) // Convert2 converts a value of a compatible type. - Convert2(Value) (Value, error) + Convert2(sqltypes.Value) (sqltypes.Value, error) // Zero2 returns the zero Value for this type. - Zero2() Value - // SQL2 returns the sqltypes.Value for the given value - SQL2(Value) (sqltypes.Value, error) + Zero2() sqltypes.Value } // SpatialColumnType is a node that contains a reference to all spatial types. diff --git a/sql/types/number.go b/sql/types/number.go index 1601be830a..ff05d914e8 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -728,7 +728,7 @@ func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqlt return sqltypes.MakeTrusted(t.baseType, val), nil } -func (t NumberTypeImpl_) Compare2(a sql.Value, b sql.Value) (int, error) { +func (t NumberTypeImpl_) Compare2(a sqltypes.Value, b sqltypes.Value) (int, error) { switch t.baseType { case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: ca, err := convertValueToUint64(t, a) @@ -765,11 +765,11 @@ func (t NumberTypeImpl_) Compare2(a sql.Value, b sql.Value) (int, error) { } return +1, nil default: - ca, err := convertValueToInt64(t, a) + ca, err := ConvertValueToInt64(t, a) if err != nil { return 0, err } - cb, err := convertValueToInt64(t, b) + cb, err := ConvertValueToInt64(t, b) if err != nil { return 0, err } @@ -784,84 +784,40 @@ func (t NumberTypeImpl_) Compare2(a sql.Value, b sql.Value) (int, error) { } } -func (t NumberTypeImpl_) Convert2(value sql.Value) (sql.Value, error) { +func (t NumberTypeImpl_) Convert2(value sqltypes.Value) (sqltypes.Value, error) { panic("implement me") } -func (t NumberTypeImpl_) Zero2() sql.Value { +func (t NumberTypeImpl_) Zero2() sqltypes.Value { switch t.baseType { case sqltypes.Int8: - x := values.WriteInt8(make([]byte, values.Int8Size), 0) - return sql.Value{ - Typ: query.Type_INT8, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_INT8, make([]byte, values.Int8Size)) case sqltypes.Int16: - x := values.WriteInt16(make([]byte, values.Int16Size), 0) - return sql.Value{ - Typ: query.Type_INT16, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_INT16, make([]byte, values.Int16Size)) case sqltypes.Int24: - x := values.WriteInt24(make([]byte, values.Int24Size), 0) - return sql.Value{ - Typ: query.Type_INT24, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_INT24, make([]byte, values.Int24Size)) case sqltypes.Int32: - x := values.WriteInt32(make([]byte, values.Int32Size), 0) - return sql.Value{ - Typ: query.Type_INT32, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_INT32, make([]byte, values.Int32Size)) case sqltypes.Int64: - x := values.WriteInt64(make([]byte, values.Int64Size), 0) - return sql.Value{ - Typ: query.Type_INT64, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_INT64, make([]byte, values.Int64Size)) case sqltypes.Uint8: - x := values.WriteUint8(make([]byte, values.Uint8Size), 0) - return sql.Value{ - Typ: query.Type_UINT8, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_UINT8, make([]byte, values.Uint8Size)) case sqltypes.Uint16: - x := values.WriteUint16(make([]byte, values.Uint16Size), 0) - return sql.Value{ - Typ: query.Type_UINT16, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_UINT16, make([]byte, values.Uint16Size)) case sqltypes.Uint24: - x := values.WriteUint24(make([]byte, values.Uint24Size), 0) - return sql.Value{ - Typ: query.Type_UINT24, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_UINT24, make([]byte, values.Uint24Size)) case sqltypes.Uint32: - x := values.WriteUint32(make([]byte, values.Uint32Size), 0) - return sql.Value{ - Typ: query.Type_UINT32, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_UINT32, make([]byte, values.Uint32Size)) case sqltypes.Uint64: - x := values.WriteUint64(make([]byte, values.Uint64Size), 0) - return sql.Value{ - Typ: query.Type_UINT64, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_UINT64, make([]byte, values.Uint64Size)) case sqltypes.Float32: + // TODO: 0 float32 is just 0? x := values.WriteFloat32(make([]byte, values.Float32Size), 0) - return sql.Value{ - Typ: query.Type_FLOAT32, - Val: x, - } + return sqltypes.MakeTrusted(query.Type_FLOAT32, x) case sqltypes.Float64: - x := values.WriteUint64(make([]byte, values.Uint64Size), 0) - return sql.Value{ - Typ: query.Type_UINT64, - Val: x, - } + // TODO: 0 float64 is just 0? + x := values.WriteFloat64(make([]byte, values.Float64Size), 0) + return sqltypes.MakeTrusted(query.Type_FLOAT64, x) default: panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) } @@ -1149,34 +1105,34 @@ func convertToInt64(t NumberTypeImpl_, v any, round Round) (int64, sql.ConvertIn } } -func convertValueToInt64(t NumberTypeImpl_, v sql.Value) (int64, error) { - switch v.Typ { +func ConvertValueToInt64(t NumberTypeImpl_, v sqltypes.Value) (int64, error) { + switch v.Type() { case query.Type_INT8: - return int64(values.ReadInt8(v.Val)), nil + return int64(values.ReadInt8(v.Raw())), nil case query.Type_INT16: - return int64(values.ReadInt16(v.Val)), nil + return int64(values.ReadInt16(v.Raw())), nil case query.Type_INT24: - return int64(values.ReadInt24(v.Val)), nil + return int64(values.ReadInt24(v.Raw())), nil case query.Type_INT32: - return int64(values.ReadInt32(v.Val)), nil + return int64(values.ReadInt32(v.Raw())), nil case query.Type_INT64: - return values.ReadInt64(v.Val), nil + return values.ReadInt64(v.Raw()), nil case query.Type_UINT8: - return int64(values.ReadUint8(v.Val)), nil + return int64(values.ReadUint8(v.Raw())), nil case query.Type_UINT16: - return int64(values.ReadUint16(v.Val)), nil + return int64(values.ReadUint16(v.Raw())), nil case query.Type_UINT24: - return int64(values.ReadUint24(v.Val)), nil + return int64(values.ReadUint24(v.Raw())), nil case query.Type_UINT32: - return int64(values.ReadUint32(v.Val)), nil + return int64(values.ReadUint32(v.Raw())), nil case query.Type_UINT64: - v := values.ReadUint64(v.Val) + v := values.ReadUint64(v.Raw()) if v > math.MaxInt64 { return math.MaxInt64, nil } return int64(v), nil case query.Type_FLOAT32: - v := values.ReadFloat32(v.Val) + v := values.ReadFloat32(v.Raw()) if v > float32(math.MaxInt64) { return math.MaxInt64, nil } else if v < float32(math.MinInt64) { @@ -1184,7 +1140,7 @@ func convertValueToInt64(t NumberTypeImpl_, v sql.Value) (int64, error) { } return int64(math.Round(float64(v))), nil case query.Type_FLOAT64: - v := values.ReadFloat64(v.Val) + v := values.ReadFloat64(v.Raw()) if v > float64(math.MaxInt64) { return math.MaxInt64, nil } else if v < float64(math.MinInt64) { @@ -1197,36 +1153,36 @@ func convertValueToInt64(t NumberTypeImpl_, v sql.Value) (int64, error) { } } -func convertValueToUint64(t NumberTypeImpl_, v sql.Value) (uint64, error) { - switch v.Typ { +func convertValueToUint64(t NumberTypeImpl_, v sqltypes.Value) (uint64, error) { + switch v.Type() { case query.Type_INT8: - return uint64(values.ReadInt8(v.Val)), nil + return uint64(values.ReadInt8(v.Raw())), nil case query.Type_INT16: - return uint64(values.ReadInt16(v.Val)), nil + return uint64(values.ReadInt16(v.Raw())), nil case query.Type_INT24: - return uint64(values.ReadInt24(v.Val)), nil + return uint64(values.ReadInt24(v.Raw())), nil case query.Type_INT32: - return uint64(values.ReadInt32(v.Val)), nil + return uint64(values.ReadInt32(v.Raw())), nil case query.Type_INT64: - return uint64(values.ReadInt64(v.Val)), nil + return uint64(values.ReadInt64(v.Raw())), nil case query.Type_UINT8: - return uint64(values.ReadUint8(v.Val)), nil + return uint64(values.ReadUint8(v.Raw())), nil case query.Type_UINT16: - return uint64(values.ReadUint16(v.Val)), nil + return uint64(values.ReadUint16(v.Raw())), nil case query.Type_UINT24: - return uint64(values.ReadUint24(v.Val)), nil + return uint64(values.ReadUint24(v.Raw())), nil case query.Type_UINT32: - return uint64(values.ReadUint32(v.Val)), nil + return uint64(values.ReadUint32(v.Raw())), nil case query.Type_UINT64: - return values.ReadUint64(v.Val), nil + return values.ReadUint64(v.Raw()), nil case query.Type_FLOAT32: - v := values.ReadFloat32(v.Val) + v := values.ReadFloat32(v.Raw()) if v >= float32(math.MaxUint64) { return math.MaxUint64, nil } return uint64(math.Round(float64(v))), nil case query.Type_FLOAT64: - v := values.ReadFloat64(v.Val) + v := values.ReadFloat64(v.Raw()) if v >= float64(math.MaxUint64) { return math.MaxUint64, nil } @@ -1419,32 +1375,32 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) { } } -func convertValueToFloat64(t NumberTypeImpl_, v sql.Value) (float64, error) { - switch v.Typ { +func convertValueToFloat64(t NumberTypeImpl_, v sqltypes.Value) (float64, error) { + switch v.Type() { case query.Type_INT8: - return float64(values.ReadInt8(v.Val)), nil + return float64(values.ReadInt8(v.Raw())), nil case query.Type_INT16: - return float64(values.ReadInt16(v.Val)), nil + return float64(values.ReadInt16(v.Raw())), nil case query.Type_INT24: - return float64(values.ReadInt24(v.Val)), nil + return float64(values.ReadInt24(v.Raw())), nil case query.Type_INT32: - return float64(values.ReadInt32(v.Val)), nil + return float64(values.ReadInt32(v.Raw())), nil case query.Type_INT64: - return float64(values.ReadInt64(v.Val)), nil + return float64(values.ReadInt64(v.Raw())), nil case query.Type_UINT8: - return float64(values.ReadUint8(v.Val)), nil + return float64(values.ReadUint8(v.Raw())), nil case query.Type_UINT16: - return float64(values.ReadUint16(v.Val)), nil + return float64(values.ReadUint16(v.Raw())), nil case query.Type_UINT24: - return float64(values.ReadUint24(v.Val)), nil + return float64(values.ReadUint24(v.Raw())), nil case query.Type_UINT32: - return float64(values.ReadUint32(v.Val)), nil + return float64(values.ReadUint32(v.Raw())), nil case query.Type_UINT64: - return float64(values.ReadUint64(v.Val)), nil + return float64(values.ReadUint64(v.Raw())), nil case query.Type_FLOAT32: - return float64(values.ReadFloat32(v.Val)), nil + return float64(values.ReadFloat32(v.Raw())), nil case query.Type_FLOAT64: - return values.ReadFloat64(v.Val), nil + return values.ReadFloat64(v.Raw()), nil default: panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) } From aecb0e7b994a0e95f48b3c072d11bf072a88e66d Mon Sep 17 00:00:00 2001 From: jycor Date: Mon, 13 Oct 2025 18:54:34 +0000 Subject: [PATCH 11/18] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/core.go | 2 +- sql/expression/comparison.go | 2 +- sql/expression/get_field.go | 2 +- sql/expression/literal.go | 2 +- sql/expression/unresolved.go | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core.go b/sql/core.go index 19a7a7a895..ee8d6e2f4d 100644 --- a/sql/core.go +++ b/sql/core.go @@ -26,11 +26,11 @@ import ( "time" "unsafe" + "github.com/dolthub/vitess/go/sqltypes" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql/values" - "github.com/dolthub/vitess/go/sqltypes" ) // Expression is a combination of one or more SQL expressions. diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index aa4ade4ff9..d6d6d402ef 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -16,8 +16,8 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/sqltypes" querypb "github.com/dolthub/vitess/go/vt/proto/query" errors "gopkg.in/src-d/go-errors.v1" diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index 398ca7107a..2611858867 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -16,9 +16,9 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/sqltypes" "strings" + "github.com/dolthub/vitess/go/sqltypes" errors "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" diff --git a/sql/expression/literal.go b/sql/expression/literal.go index b386b86412..1411ca0830 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -16,9 +16,9 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/sqltypes" "strings" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/shopspring/decimal" diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index 0173651464..a0df5ab12f 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -16,9 +16,9 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/sqltypes" "strings" + "github.com/dolthub/vitess/go/sqltypes" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" From 8aaba4e43c43c19515afab39440ace53cb412093 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 13 Oct 2025 15:25:08 -0700 Subject: [PATCH 12/18] resplit --- server/handler.go | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/server/handler.go b/server/handler.go index fe1b50d700..09399c09cf 100644 --- a/server/handler.go +++ b/server/handler.go @@ -803,7 +803,34 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq defer timer.Stop() wg := sync.WaitGroup{} - wg.Add(1) + wg.Add(2) + + // Read rows from iter and send them off + var rowChan = make(chan sql.Row2, 512) + eg.Go(func() (err error) { + defer pan2err(&err) + defer wg.Done() + defer close(rowChan) + for { + select { + case <-ctx.Done(): + return context.Cause(ctx) + default: + row, err := iter.Next2(ctx) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + select { + case rowChan <- row: + case <-ctx.Done(): + return nil + } + } + } + }) var res *sqltypes.Result var processedAtLeastOneBatch bool @@ -831,20 +858,15 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq case <-ctx.Done(): return context.Cause(ctx) case <-timer.C: - // TODO: timer should probably go in its own thread, as rowChan is blocking if h.readTimeout != 0 { // Cancel and return so Vitess can call the CloseConnection callback ctx.GetLogger().Tracef("connection timeout") return ErrRowTimeout.New() } - default: - row, err := iter.Next2(ctx) - if err == io.EOF { + case row, ok := <-rowChan: + if !ok { return nil } - if err != nil { - return err - } ctx.GetLogger().Tracef("spooling result row %s", row) res.Rows = append(res.Rows, row) res.RowsAffected++ From f185d4ad236ad8d5bc9ac767aaa9e25c2d12be50 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 14 Oct 2025 02:07:50 -0700 Subject: [PATCH 13/18] TODO --- server/handler.go | 1 + 1 file changed, 1 insertion(+) diff --git a/server/handler.go b/server/handler.go index 09399c09cf..1000b98ae6 100644 --- a/server/handler.go +++ b/server/handler.go @@ -805,6 +805,7 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq wg := sync.WaitGroup{} wg.Add(2) + // TODO: send results instead of rows? // Read rows from iter and send them off var rowChan = make(chan sql.Row2, 512) eg.Go(func() (err error) { From 70fea5c6efb3c202b6c0fc73e0a4c32fb2aecaa7 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 14 Oct 2025 16:03:58 -0700 Subject: [PATCH 14/18] small fixes --- sql/expression/literal.go | 4 ++-- sql/values/encoding.go | 28 +--------------------------- 2 files changed, 3 insertions(+), 29 deletions(-) diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 1411ca0830..977a9f5506 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -154,8 +154,8 @@ func (lit *Literal) Type2() sql.Type2 { } // Value returns the literal value. -func (p *Literal) Value() interface{} { - return p.Val +func (lit *Literal) Value() interface{} { + return lit.Val } func (lit *Literal) WithResolvedChildren(children []any) (any, error) { diff --git a/sql/values/encoding.go b/sql/values/encoding.go index 3472e870e5..d00e630091 100644 --- a/sql/values/encoding.go +++ b/sql/values/encoding.go @@ -130,11 +130,7 @@ func ReadUint16(val []byte) uint16 { func ReadInt24(val []byte) (i int32) { expectSize(val, Int24Size) - var tmp [4]byte - // copy |val| to |tmp| - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - i = int32(binary.LittleEndian.Uint32(tmp[:])) + i = int32(binary.LittleEndian.Uint32([]byte{0, val[0], val[1], val[2]})) return } @@ -158,28 +154,6 @@ func ReadUint32(val []byte) uint32 { return binary.LittleEndian.Uint32(val) } -func ReadInt48(val []byte) (i int64) { - expectSize(val, Int48Size) - var tmp [8]byte - // copy |val| to |tmp| - tmp[5], tmp[4] = val[5], val[4] - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - i = int64(binary.LittleEndian.Uint64(tmp[:])) - return -} - -func ReadUint48(val []byte) (u uint64) { - expectSize(val, Uint48Size) - var tmp [8]byte - // copy |val| to |tmp| - tmp[5], tmp[4] = val[5], val[4] - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - u = binary.LittleEndian.Uint64(tmp[:]) - return -} - func ReadInt64(val []byte) int64 { expectSize(val, Int64Size) return int64(binary.LittleEndian.Uint64(val)) From 63813224735dacf47620d346b347b8ab85f81e61 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 15 Oct 2025 15:04:32 -0700 Subject: [PATCH 15/18] don't use vitess type --- server/handler.go | 8 +- sql/convert_value.go | 80 +++++++++++++---- sql/core.go | 3 +- sql/expression/comparison.go | 16 ++-- sql/expression/get_field.go | 5 +- sql/expression/literal.go | 5 +- sql/expression/unresolved.go | 3 +- sql/plan/filter.go | 2 +- sql/row_frame.go | 15 ++-- sql/rows.go | 24 ++--- sql/type.go | 6 +- sql/types/number.go | 164 ++++++++++++++++++++++------------- 12 files changed, 213 insertions(+), 118 deletions(-) diff --git a/server/handler.go b/server/handler.go index 1000b98ae6..1ac48b7c7f 100644 --- a/server/handler.go +++ b/server/handler.go @@ -868,8 +868,12 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq if !ok { return nil } - ctx.GetLogger().Tracef("spooling result row %s", row) - res.Rows = append(res.Rows, row) + resRow := make([]sqltypes.Value, len(row)) + for i, v := range row { + resRow[i] = sqltypes.MakeTrusted(v.Typ, v.Val) + } + ctx.GetLogger().Tracef("spooling result row %s", resRow) + res.Rows = append(res.Rows, resRow) res.RowsAffected++ if !timer.Stop() { <-timer.C diff --git a/sql/convert_value.go b/sql/convert_value.go index d64d5fbe98..880b9f2f58 100644 --- a/sql/convert_value.go +++ b/sql/convert_value.go @@ -5,44 +5,88 @@ import ( "github.com/dolthub/go-mysql-server/sql/values" - "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" ) // ConvertToValue converts the interface to a sql value. -func ConvertToValue(v interface{}) (sqltypes.Value, error) { +func ConvertToValue(v interface{}) (Value, error) { switch v := v.(type) { case nil: - return sqltypes.MakeTrusted(query.Type_NULL_TYPE, nil), nil + return Value{ + Typ: query.Type_NULL_TYPE, + Val: nil, + }, nil case int: - return sqltypes.MakeTrusted(query.Type_INT64, values.WriteInt64(make([]byte, values.Int64Size), int64(v))), nil + return Value{ + Typ: query.Type_INT64, + Val: values.WriteInt64(make([]byte, values.Int64Size), int64(v)), + }, nil case int8: - return sqltypes.MakeTrusted(query.Type_INT8, values.WriteInt8(make([]byte, values.Int8Size), v)), nil + return Value{ + Typ: query.Type_INT8, + Val: values.WriteInt8(make([]byte, values.Int8Size), v), + }, nil case int16: - return sqltypes.MakeTrusted(query.Type_INT16, values.WriteInt16(make([]byte, values.Int16Size), v)), nil + return Value{ + Typ: query.Type_INT16, + Val: values.WriteInt16(make([]byte, values.Int16Size), v), + }, nil case int32: - return sqltypes.MakeTrusted(query.Type_INT32, values.WriteInt32(make([]byte, values.Int32Size), v)), nil + return Value{ + Typ: query.Type_INT32, + Val: values.WriteInt32(make([]byte, values.Int32Size), v), + }, nil case int64: - return sqltypes.MakeTrusted(query.Type_INT64, values.WriteInt64(make([]byte, values.Int64Size), v)), nil + return Value{ + Typ: query.Type_INT64, + Val: values.WriteInt64(make([]byte, values.Int64Size), v), + }, nil case uint: - return sqltypes.MakeTrusted(query.Type_UINT64, values.WriteUint64(make([]byte, values.Uint64Size), uint64(v))), nil + return Value{ + Typ: query.Type_UINT64, + Val: values.WriteUint64(make([]byte, values.Uint64Size), uint64(v)), + }, nil case uint8: - return sqltypes.MakeTrusted(query.Type_UINT8, values.WriteUint8(make([]byte, values.Uint8Size), v)), nil + return Value{ + Typ: query.Type_UINT8, + Val: values.WriteUint8(make([]byte, values.Uint8Size), v), + }, nil case uint16: - return sqltypes.MakeTrusted(query.Type_UINT16, values.WriteUint16(make([]byte, values.Uint16Size), v)), nil + return Value{ + Typ: query.Type_UINT16, + Val: values.WriteUint16(make([]byte, values.Uint16Size), v), + }, nil case uint32: - return sqltypes.MakeTrusted(query.Type_UINT32, values.WriteUint32(make([]byte, values.Uint32Size), v)), nil + return Value{ + Typ: query.Type_UINT32, + Val: values.WriteUint32(make([]byte, values.Uint32Size), v), + }, nil case uint64: - return sqltypes.MakeTrusted(query.Type_UINT64, values.WriteUint64(make([]byte, values.Uint64Size), v)), nil + return Value{ + Typ: query.Type_UINT64, + Val: values.WriteUint64(make([]byte, values.Uint64Size), v), + }, nil case float32: - return sqltypes.MakeTrusted(query.Type_FLOAT32, values.WriteFloat32(make([]byte, values.Float32Size), v)), nil + return Value{ + Typ: query.Type_FLOAT32, + Val: values.WriteFloat32(make([]byte, values.Float32Size), v), + }, nil case float64: - return sqltypes.MakeTrusted(query.Type_FLOAT64, values.WriteFloat64(make([]byte, values.Float64Size), v)), nil + return Value{ + Typ: query.Type_FLOAT64, + Val: values.WriteFloat64(make([]byte, values.Float64Size), v), + }, nil case string: - return sqltypes.MakeTrusted(query.Type_VARCHAR, values.WriteString(make([]byte, len(v)), v, values.ByteOrderCollation)), nil + return Value{ + Typ: query.Type_VARCHAR, + Val: values.WriteString(make([]byte, len(v)), v, values.ByteOrderCollation), + }, nil case []byte: - return sqltypes.MakeTrusted(query.Type_BLOB, values.WriteBytes(make([]byte, len(v)), v, values.ByteOrderCollation)), nil + return Value{ + Typ: query.Type_BLOB, + Val: values.WriteBytes(make([]byte, len(v)), v, values.ByteOrderCollation), + }, nil default: - return sqltypes.Value{}, fmt.Errorf("type %T not implemented", v) + return Value{}, fmt.Errorf("type %T not implemented", v) } } diff --git a/sql/core.go b/sql/core.go index ee8d6e2f4d..c2996039eb 100644 --- a/sql/core.go +++ b/sql/core.go @@ -26,7 +26,6 @@ import ( "time" "unsafe" - "github.com/dolthub/vitess/go/sqltypes" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" @@ -465,7 +464,7 @@ func DebugString(nodeOrExpression interface{}) string { type Expression2 interface { Expression // Eval2 evaluates the given row frame and returns a result. - Eval2(ctx *Context, row Row2) (sqltypes.Value, error) + Eval2(ctx *Context, row Row2) (Value, error) // Type2 returns the expression type. Type2() Type2 IsExpr2() bool diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index d6d6d402ef..9eb2a422c4 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -17,7 +17,6 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/sqltypes" querypb "github.com/dolthub/vitess/go/vt/proto/query" errors "gopkg.in/src-d/go-errors.v1" @@ -521,7 +520,7 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return result == 1, nil } -func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { +func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { l, ok := gt.Left().(sql.Expression2) if !ok { panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Left())) @@ -533,28 +532,31 @@ func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, er lv, err := l.Eval2(ctx, row) if err != nil { - return sqltypes.Value{}, err + return sql.Value{}, err } rv, err := r.Eval2(ctx, row) if err != nil { - return sqltypes.Value{}, err + return sql.Value{}, err } // TODO: just assume they are int64 l64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, lv) if err != nil { - return sqltypes.Value{}, err + return sql.Value{}, err } r64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, rv) if err != nil { - return sqltypes.Value{}, err + return sql.Value{}, err } var rb byte if l64 > r64 { rb = 1 } - ret := sqltypes.MakeTrusted(querypb.Type_INT8, []byte{rb}) + ret := sql.Value{ + Val: []byte{rb}, + Typ: querypb.Type_INT8, + } return ret, nil } diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index 2611858867..319406e073 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -18,7 +18,6 @@ import ( "fmt" "strings" - "github.com/dolthub/vitess/go/sqltypes" errors "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" @@ -150,9 +149,9 @@ func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return row[p.fieldIndex], nil } -func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { +func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { if p.fieldIndex < 0 || p.fieldIndex >= row.Len() { - return sqltypes.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len()) + return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len()) } return row[p.fieldIndex], nil } diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 977a9f5506..cc74bd7dc6 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -18,7 +18,6 @@ import ( "fmt" "strings" - "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/shopspring/decimal" @@ -31,7 +30,7 @@ import ( type Literal struct { Val interface{} Typ sql.Type - val2 sqltypes.Value + val2 sql.Value } var _ sql.Expression = &Literal{} @@ -137,7 +136,7 @@ func (*Literal) Children() []sql.Expression { return nil } -func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { +func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { return lit.val2, nil } diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index a0df5ab12f..c421699722 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -18,7 +18,6 @@ import ( "fmt" "strings" - "github.com/dolthub/vitess/go/sqltypes" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" @@ -72,7 +71,7 @@ func (*UnresolvedColumn) CollationCoercibility(ctx *sql.Context) (collation sql. return sql.Collation_binary, 7 } -func (uc *UnresolvedColumn) Eval2(ctx *sql.Context, row sql.Row2) (sqltypes.Value, error) { +func (uc *UnresolvedColumn) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { panic("unresolved column is a placeholder node, but Eval2 was called") } diff --git a/sql/plan/filter.go b/sql/plan/filter.go index 0d0e4ebe39..79fa7d14e5 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -149,7 +149,7 @@ func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { if err != nil { return nil, err } - if res.Raw()[0] == 1 { + if res.Val[0] == 1 { return row, nil } } diff --git a/sql/row_frame.go b/sql/row_frame.go index a4384ec458..ebb79682e4 100644 --- a/sql/row_frame.go +++ b/sql/row_frame.go @@ -17,7 +17,6 @@ package sql import ( "sync" - "github.com/dolthub/vitess/go/sqltypes" querypb "github.com/dolthub/vitess/go/vt/proto/query" ) @@ -27,10 +26,10 @@ const ( ) // Row2 is a slice of values -type Row2 []sqltypes.Value +type Row2 []Value // GetField returns the Value for the ith field in this row. -func (r Row2) GetField(i int) sqltypes.Value { +func (r Row2) GetField(i int) Value { return r[i] } @@ -98,7 +97,10 @@ func (f *RowFrame) Row2() Row2 { rs := make(Row2, len(f.Values)) for i := range f.Values { - rs[i] = sqltypes.MakeTrusted(f.Types[i], f.Values[i]) + rs[i] = Value{ + Val: f.Values[i], + Typ: f.Types[i], + } } return rs } @@ -111,7 +113,10 @@ func (f *RowFrame) Row2Copy() Row2 { for i := range f.Values { v := make(ValueBytes, len(f.Values[i])) copy(v, f.Values[i]) - rs[i] = sqltypes.MakeTrusted(f.Types[i], v) + rs[i] = Value{ + Val: v, + Typ: f.Types[i], + } } return rs } diff --git a/sql/rows.go b/sql/rows.go index 2a969363bc..191147ad68 100644 --- a/sql/rows.go +++ b/sql/rows.go @@ -123,29 +123,29 @@ func RowFromRow2(sch Schema, r Row2) Row { for i, col := range sch { switch col.Type.Type() { case query.Type_INT8: - row[i] = values.ReadInt8(r.GetField(i).Raw()) + row[i] = values.ReadInt8(r.GetField(i).Val) case query.Type_UINT8: - row[i] = values.ReadUint8(r.GetField(i).Raw()) + row[i] = values.ReadUint8(r.GetField(i).Val) case query.Type_INT16: - row[i] = values.ReadInt16(r.GetField(i).Raw()) + row[i] = values.ReadInt16(r.GetField(i).Val) case query.Type_UINT16: - row[i] = values.ReadUint16(r.GetField(i).Raw()) + row[i] = values.ReadUint16(r.GetField(i).Val) case query.Type_INT32: - row[i] = values.ReadInt32(r.GetField(i).Raw()) + row[i] = values.ReadInt32(r.GetField(i).Val) case query.Type_UINT32: - row[i] = values.ReadUint32(r.GetField(i).Raw()) + row[i] = values.ReadUint32(r.GetField(i).Val) case query.Type_INT64: - row[i] = values.ReadInt64(r.GetField(i).Raw()) + row[i] = values.ReadInt64(r.GetField(i).Val) case query.Type_UINT64: - row[i] = values.ReadUint64(r.GetField(i).Raw()) + row[i] = values.ReadUint64(r.GetField(i).Val) case query.Type_FLOAT32: - row[i] = values.ReadFloat32(r.GetField(i).Raw()) + row[i] = values.ReadFloat32(r.GetField(i).Val) case query.Type_FLOAT64: - row[i] = values.ReadFloat64(r.GetField(i).Raw()) + row[i] = values.ReadFloat64(r.GetField(i).Val) case query.Type_TEXT, query.Type_VARCHAR, query.Type_CHAR: - row[i] = values.ReadString(r.GetField(i).Raw(), values.ByteOrderCollation) + row[i] = values.ReadString(r.GetField(i).Val, values.ByteOrderCollation) case query.Type_BLOB, query.Type_VARBINARY, query.Type_BINARY: - row[i] = values.ReadBytes(r.GetField(i).Raw(), values.ByteOrderCollation) + row[i] = values.ReadBytes(r.GetField(i).Val, values.ByteOrderCollation) case query.Type_BIT: fallthrough case query.Type_ENUM: diff --git a/sql/type.go b/sql/type.go index 285744a564..6d9f9adb01 100644 --- a/sql/type.go +++ b/sql/type.go @@ -295,11 +295,11 @@ func IsDecimalType(t Type) bool { type Type2 interface { Type // Compare2 returns an integer comparing two Values. - Compare2(sqltypes.Value, sqltypes.Value) (int, error) + Compare2(Value, Value) (int, error) // Convert2 converts a value of a compatible type. - Convert2(sqltypes.Value) (sqltypes.Value, error) + Convert2(Value) (Value, error) // Zero2 returns the zero Value for this type. - Zero2() sqltypes.Value + Zero2() Value } // SpatialColumnType is a node that contains a reference to all spatial types. diff --git a/sql/types/number.go b/sql/types/number.go index ff05d914e8..07d0af6245 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -728,7 +728,7 @@ func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqlt return sqltypes.MakeTrusted(t.baseType, val), nil } -func (t NumberTypeImpl_) Compare2(a sqltypes.Value, b sqltypes.Value) (int, error) { +func (t NumberTypeImpl_) Compare2(a sql.Value, b sql.Value) (int, error) { switch t.baseType { case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: ca, err := convertValueToUint64(t, a) @@ -784,40 +784,84 @@ func (t NumberTypeImpl_) Compare2(a sqltypes.Value, b sqltypes.Value) (int, erro } } -func (t NumberTypeImpl_) Convert2(value sqltypes.Value) (sqltypes.Value, error) { +func (t NumberTypeImpl_) Convert2(value sql.Value) (sql.Value, error) { panic("implement me") } -func (t NumberTypeImpl_) Zero2() sqltypes.Value { +func (t NumberTypeImpl_) Zero2() sql.Value { switch t.baseType { case sqltypes.Int8: - return sqltypes.MakeTrusted(query.Type_INT8, make([]byte, values.Int8Size)) + x := values.WriteInt8(make([]byte, values.Int8Size), 0) + return sql.Value{ + Typ: query.Type_INT8, + Val: x, + } case sqltypes.Int16: - return sqltypes.MakeTrusted(query.Type_INT16, make([]byte, values.Int16Size)) + x := values.WriteInt16(make([]byte, values.Int16Size), 0) + return sql.Value{ + Typ: query.Type_INT16, + Val: x, + } case sqltypes.Int24: - return sqltypes.MakeTrusted(query.Type_INT24, make([]byte, values.Int24Size)) + x := values.WriteInt24(make([]byte, values.Int24Size), 0) + return sql.Value{ + Typ: query.Type_INT24, + Val: x, + } case sqltypes.Int32: - return sqltypes.MakeTrusted(query.Type_INT32, make([]byte, values.Int32Size)) + x := values.WriteInt32(make([]byte, values.Int32Size), 0) + return sql.Value{ + Typ: query.Type_INT32, + Val: x, + } case sqltypes.Int64: - return sqltypes.MakeTrusted(query.Type_INT64, make([]byte, values.Int64Size)) + x := values.WriteInt64(make([]byte, values.Int64Size), 0) + return sql.Value{ + Typ: query.Type_INT64, + Val: x, + } case sqltypes.Uint8: - return sqltypes.MakeTrusted(query.Type_UINT8, make([]byte, values.Uint8Size)) + x := values.WriteUint8(make([]byte, values.Uint8Size), 0) + return sql.Value{ + Typ: query.Type_UINT8, + Val: x, + } case sqltypes.Uint16: - return sqltypes.MakeTrusted(query.Type_UINT16, make([]byte, values.Uint16Size)) + x := values.WriteUint16(make([]byte, values.Uint16Size), 0) + return sql.Value{ + Typ: query.Type_UINT16, + Val: x, + } case sqltypes.Uint24: - return sqltypes.MakeTrusted(query.Type_UINT24, make([]byte, values.Uint24Size)) + x := values.WriteUint24(make([]byte, values.Uint24Size), 0) + return sql.Value{ + Typ: query.Type_UINT24, + Val: x, + } case sqltypes.Uint32: - return sqltypes.MakeTrusted(query.Type_UINT32, make([]byte, values.Uint32Size)) + x := values.WriteUint32(make([]byte, values.Uint32Size), 0) + return sql.Value{ + Typ: query.Type_UINT32, + Val: x, + } case sqltypes.Uint64: - return sqltypes.MakeTrusted(query.Type_UINT64, make([]byte, values.Uint64Size)) + x := values.WriteUint64(make([]byte, values.Uint64Size), 0) + return sql.Value{ + Typ: query.Type_UINT64, + Val: x, + } case sqltypes.Float32: - // TODO: 0 float32 is just 0? x := values.WriteFloat32(make([]byte, values.Float32Size), 0) - return sqltypes.MakeTrusted(query.Type_FLOAT32, x) + return sql.Value{ + Typ: query.Type_FLOAT32, + Val: x, + } case sqltypes.Float64: - // TODO: 0 float64 is just 0? - x := values.WriteFloat64(make([]byte, values.Float64Size), 0) - return sqltypes.MakeTrusted(query.Type_FLOAT64, x) + x := values.WriteUint64(make([]byte, values.Uint64Size), 0) + return sql.Value{ + Typ: query.Type_UINT64, + Val: x, + } default: panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) } @@ -1105,34 +1149,34 @@ func convertToInt64(t NumberTypeImpl_, v any, round Round) (int64, sql.ConvertIn } } -func ConvertValueToInt64(t NumberTypeImpl_, v sqltypes.Value) (int64, error) { - switch v.Type() { +func ConvertValueToInt64(t NumberTypeImpl_, v sql.Value) (int64, error) { + switch v.Typ { case query.Type_INT8: - return int64(values.ReadInt8(v.Raw())), nil + return int64(values.ReadInt8(v.Val)), nil case query.Type_INT16: - return int64(values.ReadInt16(v.Raw())), nil + return int64(values.ReadInt16(v.Val)), nil case query.Type_INT24: - return int64(values.ReadInt24(v.Raw())), nil + return int64(values.ReadInt24(v.Val)), nil case query.Type_INT32: - return int64(values.ReadInt32(v.Raw())), nil + return int64(values.ReadInt32(v.Val)), nil case query.Type_INT64: - return values.ReadInt64(v.Raw()), nil + return values.ReadInt64(v.Val), nil case query.Type_UINT8: - return int64(values.ReadUint8(v.Raw())), nil + return int64(values.ReadUint8(v.Val)), nil case query.Type_UINT16: - return int64(values.ReadUint16(v.Raw())), nil + return int64(values.ReadUint16(v.Val)), nil case query.Type_UINT24: - return int64(values.ReadUint24(v.Raw())), nil + return int64(values.ReadUint24(v.Val)), nil case query.Type_UINT32: - return int64(values.ReadUint32(v.Raw())), nil + return int64(values.ReadUint32(v.Val)), nil case query.Type_UINT64: - v := values.ReadUint64(v.Raw()) + v := values.ReadUint64(v.Val) if v > math.MaxInt64 { return math.MaxInt64, nil } return int64(v), nil case query.Type_FLOAT32: - v := values.ReadFloat32(v.Raw()) + v := values.ReadFloat32(v.Val) if v > float32(math.MaxInt64) { return math.MaxInt64, nil } else if v < float32(math.MinInt64) { @@ -1140,7 +1184,7 @@ func ConvertValueToInt64(t NumberTypeImpl_, v sqltypes.Value) (int64, error) { } return int64(math.Round(float64(v))), nil case query.Type_FLOAT64: - v := values.ReadFloat64(v.Raw()) + v := values.ReadFloat64(v.Val) if v > float64(math.MaxInt64) { return math.MaxInt64, nil } else if v < float64(math.MinInt64) { @@ -1153,36 +1197,36 @@ func ConvertValueToInt64(t NumberTypeImpl_, v sqltypes.Value) (int64, error) { } } -func convertValueToUint64(t NumberTypeImpl_, v sqltypes.Value) (uint64, error) { - switch v.Type() { +func convertValueToUint64(t NumberTypeImpl_, v sql.Value) (uint64, error) { + switch v.Typ { case query.Type_INT8: - return uint64(values.ReadInt8(v.Raw())), nil + return uint64(values.ReadInt8(v.Val)), nil case query.Type_INT16: - return uint64(values.ReadInt16(v.Raw())), nil + return uint64(values.ReadInt16(v.Val)), nil case query.Type_INT24: - return uint64(values.ReadInt24(v.Raw())), nil + return uint64(values.ReadInt24(v.Val)), nil case query.Type_INT32: - return uint64(values.ReadInt32(v.Raw())), nil + return uint64(values.ReadInt32(v.Val)), nil case query.Type_INT64: - return uint64(values.ReadInt64(v.Raw())), nil + return uint64(values.ReadInt64(v.Val)), nil case query.Type_UINT8: - return uint64(values.ReadUint8(v.Raw())), nil + return uint64(values.ReadUint8(v.Val)), nil case query.Type_UINT16: - return uint64(values.ReadUint16(v.Raw())), nil + return uint64(values.ReadUint16(v.Val)), nil case query.Type_UINT24: - return uint64(values.ReadUint24(v.Raw())), nil + return uint64(values.ReadUint24(v.Val)), nil case query.Type_UINT32: - return uint64(values.ReadUint32(v.Raw())), nil + return uint64(values.ReadUint32(v.Val)), nil case query.Type_UINT64: - return values.ReadUint64(v.Raw()), nil + return values.ReadUint64(v.Val), nil case query.Type_FLOAT32: - v := values.ReadFloat32(v.Raw()) + v := values.ReadFloat32(v.Val) if v >= float32(math.MaxUint64) { return math.MaxUint64, nil } return uint64(math.Round(float64(v))), nil case query.Type_FLOAT64: - v := values.ReadFloat64(v.Raw()) + v := values.ReadFloat64(v.Val) if v >= float64(math.MaxUint64) { return math.MaxUint64, nil } @@ -1375,32 +1419,32 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) { } } -func convertValueToFloat64(t NumberTypeImpl_, v sqltypes.Value) (float64, error) { - switch v.Type() { +func convertValueToFloat64(t NumberTypeImpl_, v sql.Value) (float64, error) { + switch v.Typ { case query.Type_INT8: - return float64(values.ReadInt8(v.Raw())), nil + return float64(values.ReadInt8(v.Val)), nil case query.Type_INT16: - return float64(values.ReadInt16(v.Raw())), nil + return float64(values.ReadInt16(v.Val)), nil case query.Type_INT24: - return float64(values.ReadInt24(v.Raw())), nil + return float64(values.ReadInt24(v.Val)), nil case query.Type_INT32: - return float64(values.ReadInt32(v.Raw())), nil + return float64(values.ReadInt32(v.Val)), nil case query.Type_INT64: - return float64(values.ReadInt64(v.Raw())), nil + return float64(values.ReadInt64(v.Val)), nil case query.Type_UINT8: - return float64(values.ReadUint8(v.Raw())), nil + return float64(values.ReadUint8(v.Val)), nil case query.Type_UINT16: - return float64(values.ReadUint16(v.Raw())), nil + return float64(values.ReadUint16(v.Val)), nil case query.Type_UINT24: - return float64(values.ReadUint24(v.Raw())), nil + return float64(values.ReadUint24(v.Val)), nil case query.Type_UINT32: - return float64(values.ReadUint32(v.Raw())), nil + return float64(values.ReadUint32(v.Val)), nil case query.Type_UINT64: - return float64(values.ReadUint64(v.Raw())), nil + return float64(values.ReadUint64(v.Val)), nil case query.Type_FLOAT32: - return float64(values.ReadFloat32(v.Raw())), nil + return float64(values.ReadFloat32(v.Val)), nil case query.Type_FLOAT64: - return values.ReadFloat64(v.Raw()), nil + return values.ReadFloat64(v.Val), nil default: panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) } From 64f38c9e5623f9a1c21d47b1acaa9b2899c7a8f3 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 15 Oct 2025 14:33:13 -0700 Subject: [PATCH 16/18] implementing but something is wrong --- server/handler.go | 151 ++++++++++++++++++++++++++++++- sql/expression/comparison.go | 1 - sql/plan/filter.go | 21 +++++ sql/plan/process.go | 17 ++++ sql/row_frame.go | 13 ++- sql/rowexec/transaction_iters.go | 8 ++ sql/rows.go | 5 + sql/table_iter.go | 42 +++++++++ 8 files changed, 249 insertions(+), 9 deletions(-) diff --git a/server/handler.go b/server/handler.go index 1ac48b7c7f..79d784e639 100644 --- a/server/handler.go +++ b/server/handler.go @@ -495,8 +495,8 @@ func (h *Handler) doQuery( r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) - } else if ri2, ok := rowIter.(sql.RowIter2); ok && ri2.IsRowIter2(sqlCtx) { - r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, ri2, resultFields, callback, more) + } else if r2, ok := rowIter.(sql.RowIter2); ok && r2.IsRowIter2(sqlCtx) { + r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, r2, resultFields, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } @@ -805,7 +805,6 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq wg := sync.WaitGroup{} wg.Add(2) - // TODO: send results instead of rows? // Read rows from iter and send them off var rowChan = make(chan sql.Row2, 512) eg.Go(func() (err error) { @@ -903,6 +902,152 @@ func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sq return res, processedAtLeastOneBatch, nil } +func (h *Handler) resultForRowFrameIter(ctx *sql.Context, c *mysql.Conn, iter sql.RowFrameIter, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) { + defer trace.StartRegion(ctx, "Handler.resultForRowFrameIter").End() + + eg, ctx := ctx.NewErrgroup() + pan2err := func(err *error) { + if recoveredPanic := recover(); recoveredPanic != nil { + stack := debug.Stack() + wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, stack) + *err = goerrors.Join(*err, wrappedErr) + } + } + + // TODO: poll for closed connections should obviously also run even if + // we're doing something with an OK result or a single row result, etc. + // This should be in the caller. + pollCtx, cancelF := ctx.NewSubContext() + eg.Go(func() (err error) { + defer pan2err(&err) + return h.pollForClosedConnection(pollCtx, c) + }) + + // Default waitTime is one minute if there is no timeout configured, in which case + // it will loop to iterate again unless the socket died by the OS timeout or other problems. + // If there is a timeout, it will be enforced to ensure that Vitess has a chance to + // call Handler.CloseConnection() + waitTime := 1 * time.Minute + if h.readTimeout > 0 { + waitTime = h.readTimeout + } + timer := time.NewTimer(waitTime) + defer timer.Stop() + + wg := sync.WaitGroup{} + wg.Add(2) + + // TODO: send results instead of rows? + // Read rows from iter and send them off + var rowFrameChan = make(chan sql.Row2, 512) + //var rowFrameChan = make(chan *sql.RowFrame, 512) + eg.Go(func() (err error) { + defer pan2err(&err) + defer wg.Done() + defer close(rowFrameChan) + for { + select { + case <-ctx.Done(): + return context.Cause(ctx) + default: + //rowFrame := sql.NewRowFrame() + r2, ok := iter.(sql.RowIter2) + if !ok { + panic("aaaaaaasdfasdgsdfgsdfghsfgd") + } + row, err := r2.Next2(ctx) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + // DEEP COPY HERE IS IMPORTANT! + //row := rowFrame.Row2Copy() + // Should be safe to release memory + //rowFrame.Recycle() + select { + case rowFrameChan <- row: + case <-ctx.Done(): + return nil + } + } + } + }) + + var res *sqltypes.Result + var processedAtLeastOneBatch bool + eg.Go(func() (err error) { + defer pan2err(&err) + defer cancelF() + defer wg.Done() + for { + if res == nil { + res = &sqltypes.Result{ + Fields: resultFields, + Rows: make([][]sqltypes.Value, 0, rowsBatch), + } + } + if res.RowsAffected == rowsBatch { + if err := callback(res, more); err != nil { + return err + } + res = nil + processedAtLeastOneBatch = true + continue + } + + select { + case <-ctx.Done(): + return context.Cause(ctx) + case <-timer.C: + if h.readTimeout != 0 { + // Cancel and return so Vitess can call the CloseConnection callback + ctx.GetLogger().Tracef("connection timeout") + return ErrRowTimeout.New() + } + case rowFrame, ok := <-rowFrameChan: + if !ok { + return nil + } + //panic(fmt.Sprintf("TESTING: %v", rowFrame.Types)) + row := rowFrame + resRow := make([]sqltypes.Value, len(row)) + for i, val := range row { + resRow[i] = sqltypes.MakeTrusted(val.Typ, val.Val) + } + panic("received?") + ctx.GetLogger().Tracef("spooling result row %s", resRow) + res.Rows = append(res.Rows, resRow) + res.RowsAffected++ + if !timer.Stop() { + <-timer.C + } + } + timer.Reset(waitTime) + } + }) + + // Close() kills this PID in the process list, + // wait until all rows have be sent over the wire + eg.Go(func() (err error) { + defer pan2err(&err) + wg.Wait() + return iter.Close(ctx) + }) + + err := eg.Wait() + if err != nil { + ctx.GetLogger().WithError(err).Warn("error running query") + if verboseErrorLogging { + fmt.Printf("Err: %+v", err) + } + return nil, false, err + } + + return res, processedAtLeastOneBatch, nil +} + // See https://dev.mysql.com/doc/internals/en/status-flags.html func setConnStatusFlags(ctx *sql.Context, c *mysql.Conn) error { ok, err := isSessionAutocommit(ctx) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 9eb2a422c4..312405377d 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -552,7 +552,6 @@ func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) if l64 > r64 { rb = 1 } - ret := sql.Value{ Val: []byte{rb}, Typ: querypb.Type_INT8, diff --git a/sql/plan/filter.go b/sql/plan/filter.go index 79fa7d14e5..d1d0284afd 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -111,6 +111,7 @@ type FilterIter struct { var _ sql.RowIter = (*FilterIter)(nil) var _ sql.RowIter2 = (*FilterIter)(nil) +var _ sql.RowFrameIter = (*FilterIter)(nil) // NewFilterIter creates a new FilterIter. func NewFilterIter( @@ -155,6 +156,26 @@ func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { } } +func (i *FilterIter) NextRowFrame(ctx *sql.Context, rowFrame *sql.RowFrame) error { + // TODO: this is trickier... + childIter := i.childIter.(sql.RowFrameIter) + for { + err := childIter.NextRowFrame(ctx, rowFrame) + if err != nil { + return err + } + row := rowFrame.Row2() + res, err := i.cond2.Eval2(ctx, row) + if err != nil { + return err + } + if res.Val[0] == 1 { + return nil + } + rowFrame.Clear() + } +} + func (i *FilterIter) IsRowIter2(ctx *sql.Context) bool { cond, ok := i.cond.(sql.Expression2) if !ok || !cond.IsExpr2() { diff --git a/sql/plan/process.go b/sql/plan/process.go index 70a687247f..d13c28d90b 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -234,6 +234,10 @@ type TrackedRowIter struct { ShouldSetFoundRows bool } +var _ sql.RowIter = (*TrackedRowIter)(nil) +var _ sql.RowIter2 = (*TrackedRowIter)(nil) +var _ sql.RowFrameIter = (*TrackedRowIter)(nil) + func NewTrackedRowIter( node sql.Node, iter sql.RowIter, @@ -339,6 +343,19 @@ func (i *TrackedRowIter) IsRowIter2(ctx *sql.Context) bool { return true } +func (i *TrackedRowIter) NextRowFrame(ctx *sql.Context, rowFrame *sql.RowFrame) error { + iter := i.iter.(sql.RowFrameIter) + err := iter.NextRowFrame(ctx, rowFrame) + if err != nil { + return err + } + i.numRows++ + if i.onNext != nil { + i.onNext() + } + return nil +} + func (i *TrackedRowIter) Close(ctx *sql.Context) error { err := i.iter.Close(ctx) diff --git a/sql/row_frame.go b/sql/row_frame.go index ebb79682e4..a548860168 100644 --- a/sql/row_frame.go +++ b/sql/row_frame.go @@ -57,6 +57,7 @@ type RowFrame struct { // Values are the values this row. Values []ValueBytes + // TODO: this isn't used anywhere // varr is used as the backing array for the |Values| // slice when len(Values) <= valueArrSize varr [valueArrSize][]ValueBytes @@ -131,6 +132,7 @@ func (f *RowFrame) Clear() { // Append appends the values given into this frame. func (f *RowFrame) Append(vals ...Value) { + // TODO: one big copy here would be better probably, need to benchmark for _, v := range vals { f.append(v) } @@ -146,11 +148,12 @@ func (f *RowFrame) AppendMany(types []querypb.Type, vals []ValueBytes) { func (f *RowFrame) append(v Value) { buf := f.getBuffer(v) - copy(buf, v.Val) + copy(buf, v.Val) // TODO: not necessary if we're not referencing the backing array v.Val = buf f.Types = append(f.Types, v.Typ) + // TODO: do this? // if |f.Values| grows past |len(f.varr)| // we'll allocate a new backing array here f.Values = append(f.Values, v.Val) @@ -162,6 +165,7 @@ func (f *RowFrame) appendTypeAndVal(typ querypb.Type, val ValueBytes) { f.Types = append(f.Types, typ) + // TODO: do this? // if |f.Values| grows past |len(f.varr)| // we'll allocate a new backing array here f.Values = append(f.Values, v) @@ -172,11 +176,10 @@ func (f *RowFrame) getBuffer(v Value) (buf []byte) { } func (f *RowFrame) bufferForBytes(v ValueBytes) (buf []byte) { - if f.checkCapacity(v) { + if f.hasCapacity(v) { start := f.off f.off += uint16(len(v)) - stop := f.off - buf = f.farr[start:stop] + buf = f.farr[start:f.off] } else { buf = make([]byte, len(v)) } @@ -184,6 +187,6 @@ func (f *RowFrame) bufferForBytes(v ValueBytes) (buf []byte) { return } -func (f *RowFrame) checkCapacity(v ValueBytes) bool { +func (f *RowFrame) hasCapacity(v ValueBytes) bool { return len(v) <= (len(f.farr) - int(f.off)) } diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index f0f56168ef..ec09a39868 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -77,6 +77,10 @@ type TransactionCommittingIter struct { implicitCommit bool } +var _ sql.RowIter = (*TransactionCommittingIter)(nil) +var _ sql.RowIter2 = (*TransactionCommittingIter)(nil) +var _ sql.RowFrameIter = (*TransactionCommittingIter)(nil) + func AddTransactionCommittingIter(ctx *sql.Context, qFlags *sql.QueryFlags, iter sql.RowIter) (sql.RowIter, error) { // TODO: This is a bit of a hack. Need to figure out better relationship between new transaction node and warnings. if (qFlags != nil && qFlags.IsSet(sql.QFlagShowWarnings)) || ctx.IsInterpreted() { @@ -113,6 +117,10 @@ func (t *TransactionCommittingIter) IsRowIter2(ctx *sql.Context) bool { return true } +func (t *TransactionCommittingIter) NextRowFrame(ctx *sql.Context, rowFrame *sql.RowFrame) error { + return t.childIter.(sql.RowFrameIter).NextRowFrame(ctx, rowFrame) +} + func (t *TransactionCommittingIter) Close(ctx *sql.Context) error { var err error if t.childIter != nil { diff --git a/sql/rows.go b/sql/rows.go index 191147ad68..cb453562f3 100644 --- a/sql/rows.go +++ b/sql/rows.go @@ -248,3 +248,8 @@ type MutableRowIter interface { GetChildIter() RowIter WithChildIter(childIter RowIter) RowIter } + +type RowFrameIter interface { + RowIter + NextRowFrame(ctx *Context, frame *RowFrame) error +} diff --git a/sql/table_iter.go b/sql/table_iter.go index 884778307a..78caf432cd 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -30,6 +30,8 @@ type TableRowIter struct { } var _ RowIter = (*TableRowIter)(nil) +var _ RowIter2 = (*TableRowIter)(nil) +var _ RowFrameIter = (*TableRowIter)(nil) // NewTableRowIter returns a new iterator over the rows in the partitions of the table given. func NewTableRowIter(ctx *Context, table Table, partitions PartitionIter) *TableRowIter { @@ -146,6 +148,46 @@ func (i *TableRowIter) IsRowIter2(ctx *Context) bool { return i.rows2.IsRowIter2(ctx) } +func (i *TableRowIter) NextRowFrame(ctx *Context, rowFrame *RowFrame) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if i.partition == nil { + partition, err := i.partitions.Next(ctx) + if err != nil { + if err == io.EOF { + if e := i.partitions.Close(ctx); e != nil { + return e + } + } + return err + } + i.partition = partition + } + + if i.rows2 == nil { + rows, err := i.table.PartitionRows(ctx, i.partition) + if err != nil { + return err + } + i.rows2 = rows.(RowIter2) + } + + err := i.rows2.(RowFrameIter).NextRowFrame(ctx, rowFrame) + if err != nil && err == io.EOF { + if err = i.rows2.Close(ctx); err != nil { + return err + } + i.partition = nil + i.rows2 = nil + err = i.NextRowFrame(ctx, rowFrame) + } + return nil +} + func (i *TableRowIter) Close(ctx *Context) error { if i.rows != nil { if err := i.rows.Close(ctx); err != nil { From 0210056c1194cad4cf8ea239cf54bbe2d4280736 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 15 Oct 2025 15:43:41 -0700 Subject: [PATCH 17/18] try to fix --- server/handler.go | 32 +++++++++++++++----------------- sql/plan/filter.go | 6 +++++- sql/plan/process.go | 5 ++++- sql/rowexec/transaction_iters.go | 7 ++++++- sql/table_iter.go | 6 +++++- 5 files changed, 35 insertions(+), 21 deletions(-) diff --git a/server/handler.go b/server/handler.go index 79d784e639..303e81c481 100644 --- a/server/handler.go +++ b/server/handler.go @@ -496,7 +496,11 @@ func (h *Handler) doQuery( } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) } else if r2, ok := rowIter.(sql.RowIter2); ok && r2.IsRowIter2(sqlCtx) { - r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, r2, resultFields, callback, more) + if rf, ok := r2.(sql.RowFrameIter); ok { + r, processedAtLeastOneBatch, err = h.resultForRowFrameIter(sqlCtx, c, rf, resultFields, callback, more) + } else { + r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, r2, resultFields, callback, more) + } } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } @@ -939,8 +943,7 @@ func (h *Handler) resultForRowFrameIter(ctx *sql.Context, c *mysql.Conn, iter sq // TODO: send results instead of rows? // Read rows from iter and send them off - var rowFrameChan = make(chan sql.Row2, 512) - //var rowFrameChan = make(chan *sql.RowFrame, 512) + var rowFrameChan = make(chan *sql.RowFrame, 512) eg.Go(func() (err error) { defer pan2err(&err) defer wg.Done() @@ -950,24 +953,16 @@ func (h *Handler) resultForRowFrameIter(ctx *sql.Context, c *mysql.Conn, iter sq case <-ctx.Done(): return context.Cause(ctx) default: - //rowFrame := sql.NewRowFrame() - r2, ok := iter.(sql.RowIter2) - if !ok { - panic("aaaaaaasdfasdgsdfgsdfghsfgd") - } - row, err := r2.Next2(ctx) + rowFrame := sql.NewRowFrame() + err := iter.NextRowFrame(ctx, rowFrame) if err == io.EOF { return nil } if err != nil { return err } - // DEEP COPY HERE IS IMPORTANT! - //row := rowFrame.Row2Copy() - // Should be safe to release memory - //rowFrame.Recycle() select { - case rowFrameChan <- row: + case rowFrameChan <- rowFrame: case <-ctx.Done(): return nil } @@ -1010,13 +1005,16 @@ func (h *Handler) resultForRowFrameIter(ctx *sql.Context, c *mysql.Conn, iter sq if !ok { return nil } - //panic(fmt.Sprintf("TESTING: %v", rowFrame.Types)) - row := rowFrame + // DEEP COPY HERE IS IMPORTANT! + row := rowFrame.Row2Copy() resRow := make([]sqltypes.Value, len(row)) for i, val := range row { resRow[i] = sqltypes.MakeTrusted(val.Typ, val.Val) } - panic("received?") + + // Should be safe to release memory + rowFrame.Recycle() + ctx.GetLogger().Tracef("spooling result row %s", resRow) res.Rows = append(res.Rows, resRow) res.RowsAffected++ diff --git a/sql/plan/filter.go b/sql/plan/filter.go index d1d0284afd..d0a8efe55c 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -15,6 +15,7 @@ package plan import ( + "fmt" "github.com/dolthub/go-mysql-server/sql" ) @@ -158,7 +159,10 @@ func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { func (i *FilterIter) NextRowFrame(ctx *sql.Context, rowFrame *sql.RowFrame) error { // TODO: this is trickier... - childIter := i.childIter.(sql.RowFrameIter) + childIter, ok := i.childIter.(sql.RowFrameIter) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.RowFrameIter", i.childIter)) + } for { err := childIter.NextRowFrame(ctx, rowFrame) if err != nil { diff --git a/sql/plan/process.go b/sql/plan/process.go index d13c28d90b..d909d1a5fe 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -344,7 +344,10 @@ func (i *TrackedRowIter) IsRowIter2(ctx *sql.Context) bool { } func (i *TrackedRowIter) NextRowFrame(ctx *sql.Context, rowFrame *sql.RowFrame) error { - iter := i.iter.(sql.RowFrameIter) + iter, ok := i.iter.(sql.RowFrameIter) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.RowFrameIter", i.iter)) + } err := iter.NextRowFrame(ctx, rowFrame) if err != nil { return err diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index ec09a39868..166e6b7f82 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -15,6 +15,7 @@ package rowexec import ( + "fmt" "io" "gopkg.in/src-d/go-errors.v1" @@ -118,7 +119,11 @@ func (t *TransactionCommittingIter) IsRowIter2(ctx *sql.Context) bool { } func (t *TransactionCommittingIter) NextRowFrame(ctx *sql.Context, rowFrame *sql.RowFrame) error { - return t.childIter.(sql.RowFrameIter).NextRowFrame(ctx, rowFrame) + childIter, ok := t.childIter.(sql.RowFrameIter) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.RowFrameIter", t.childIter)) + } + return childIter.NextRowFrame(ctx, rowFrame) } func (t *TransactionCommittingIter) Close(ctx *sql.Context) error { diff --git a/sql/table_iter.go b/sql/table_iter.go index 78caf432cd..c9d057b9d4 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -175,8 +175,12 @@ func (i *TableRowIter) NextRowFrame(ctx *Context, rowFrame *RowFrame) error { } i.rows2 = rows.(RowIter2) } + rows, ok := i.rows2.(RowFrameIter) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.RowFrameIter", i.rows)) + } - err := i.rows2.(RowFrameIter).NextRowFrame(ctx, rowFrame) + err := rows.NextRowFrame(ctx, rowFrame) if err != nil && err == io.EOF { if err = i.rows2.Close(ctx); err != nil { return err From aaa30654d7acadcd11f8f544c11ab47be6b05d22 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 15 Oct 2025 15:56:35 -0700 Subject: [PATCH 18/18] oh... --- sql/table_iter.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/table_iter.go b/sql/table_iter.go index c9d057b9d4..f06cd3b566 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -175,6 +175,7 @@ func (i *TableRowIter) NextRowFrame(ctx *Context, rowFrame *RowFrame) error { } i.rows2 = rows.(RowIter2) } + rows, ok := i.rows2.(RowFrameIter) if !ok { panic(fmt.Sprintf("%T does not implement sql.RowFrameIter", i.rows)) @@ -189,7 +190,7 @@ func (i *TableRowIter) NextRowFrame(ctx *Context, rowFrame *RowFrame) error { i.rows2 = nil err = i.NextRowFrame(ctx, rowFrame) } - return nil + return err } func (i *TableRowIter) Close(ctx *Context) error {