diff --git a/server/handler.go b/server/handler.go index 2275ca7a2d..303e81c481 100644 --- a/server/handler.go +++ b/server/handler.go @@ -495,6 +495,12 @@ func (h *Handler) doQuery( r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) + } else if r2, ok := rowIter.(sql.RowIter2); ok && r2.IsRowIter2(sqlCtx) { + if rf, ok := r2.(sql.RowFrameIter); ok { + r, processedAtLeastOneBatch, err = h.resultForRowFrameIter(sqlCtx, c, rf, resultFields, callback, more) + } else { + r, processedAtLeastOneBatch, err = h.resultForDefaultIter2(sqlCtx, c, r2, resultFields, callback, more) + } } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } @@ -768,6 +774,278 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s return r, processedAtLeastOneBatch, nil } +func (h *Handler) resultForDefaultIter2(ctx *sql.Context, c *mysql.Conn, iter sql.RowIter2, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) { + defer trace.StartRegion(ctx, "Handler.resultForDefaultIter2").End() + + eg, ctx := ctx.NewErrgroup() + pan2err := func(err *error) { + if recoveredPanic := recover(); recoveredPanic != nil { + stack := debug.Stack() + wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, stack) + *err = goerrors.Join(*err, wrappedErr) + } + } + + // TODO: poll for closed connections should obviously also run even if + // we're doing something with an OK result or a single row result, etc. + // This should be in the caller. + pollCtx, cancelF := ctx.NewSubContext() + eg.Go(func() (err error) { + defer pan2err(&err) + return h.pollForClosedConnection(pollCtx, c) + }) + + // Default waitTime is one minute if there is no timeout configured, in which case + // it will loop to iterate again unless the socket died by the OS timeout or other problems. + // If there is a timeout, it will be enforced to ensure that Vitess has a chance to + // call Handler.CloseConnection() + waitTime := 1 * time.Minute + if h.readTimeout > 0 { + waitTime = h.readTimeout + } + timer := time.NewTimer(waitTime) + defer timer.Stop() + + wg := sync.WaitGroup{} + wg.Add(2) + + // Read rows from iter and send them off + var rowChan = make(chan sql.Row2, 512) + eg.Go(func() (err error) { + defer pan2err(&err) + defer wg.Done() + defer close(rowChan) + for { + select { + case <-ctx.Done(): + return context.Cause(ctx) + default: + row, err := iter.Next2(ctx) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + select { + case rowChan <- row: + case <-ctx.Done(): + return nil + } + } + } + }) + + var res *sqltypes.Result + var processedAtLeastOneBatch bool + eg.Go(func() (err error) { + defer pan2err(&err) + defer cancelF() + defer wg.Done() + for { + if res == nil { + res = &sqltypes.Result{ + Fields: resultFields, + Rows: make([][]sqltypes.Value, 0, rowsBatch), + } + } + if res.RowsAffected == rowsBatch { + if err := callback(res, more); err != nil { + return err + } + res = nil + processedAtLeastOneBatch = true + continue + } + + select { + case <-ctx.Done(): + return context.Cause(ctx) + case <-timer.C: + if h.readTimeout != 0 { + // Cancel and return so Vitess can call the CloseConnection callback + ctx.GetLogger().Tracef("connection timeout") + return ErrRowTimeout.New() + } + case row, ok := <-rowChan: + if !ok { + return nil + } + resRow := make([]sqltypes.Value, len(row)) + for i, v := range row { + resRow[i] = sqltypes.MakeTrusted(v.Typ, v.Val) + } + ctx.GetLogger().Tracef("spooling result row %s", resRow) + res.Rows = append(res.Rows, resRow) + res.RowsAffected++ + if !timer.Stop() { + <-timer.C + } + } + timer.Reset(waitTime) + } + }) + + // Close() kills this PID in the process list, + // wait until all rows have be sent over the wire + eg.Go(func() (err error) { + defer pan2err(&err) + wg.Wait() + return iter.Close(ctx) + }) + + err := eg.Wait() + if err != nil { + ctx.GetLogger().WithError(err).Warn("error running query") + if verboseErrorLogging { + fmt.Printf("Err: %+v", err) + } + return nil, false, err + } + + return res, processedAtLeastOneBatch, nil +} + +func (h *Handler) resultForRowFrameIter(ctx *sql.Context, c *mysql.Conn, iter sql.RowFrameIter, resultFields []*querypb.Field, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) { + defer trace.StartRegion(ctx, "Handler.resultForRowFrameIter").End() + + eg, ctx := ctx.NewErrgroup() + pan2err := func(err *error) { + if recoveredPanic := recover(); recoveredPanic != nil { + stack := debug.Stack() + wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, stack) + *err = goerrors.Join(*err, wrappedErr) + } + } + + // TODO: poll for closed connections should obviously also run even if + // we're doing something with an OK result or a single row result, etc. + // This should be in the caller. + pollCtx, cancelF := ctx.NewSubContext() + eg.Go(func() (err error) { + defer pan2err(&err) + return h.pollForClosedConnection(pollCtx, c) + }) + + // Default waitTime is one minute if there is no timeout configured, in which case + // it will loop to iterate again unless the socket died by the OS timeout or other problems. + // If there is a timeout, it will be enforced to ensure that Vitess has a chance to + // call Handler.CloseConnection() + waitTime := 1 * time.Minute + if h.readTimeout > 0 { + waitTime = h.readTimeout + } + timer := time.NewTimer(waitTime) + defer timer.Stop() + + wg := sync.WaitGroup{} + wg.Add(2) + + // TODO: send results instead of rows? + // Read rows from iter and send them off + var rowFrameChan = make(chan *sql.RowFrame, 512) + eg.Go(func() (err error) { + defer pan2err(&err) + defer wg.Done() + defer close(rowFrameChan) + for { + select { + case <-ctx.Done(): + return context.Cause(ctx) + default: + rowFrame := sql.NewRowFrame() + err := iter.NextRowFrame(ctx, rowFrame) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + select { + case rowFrameChan <- rowFrame: + case <-ctx.Done(): + return nil + } + } + } + }) + + var res *sqltypes.Result + var processedAtLeastOneBatch bool + eg.Go(func() (err error) { + defer pan2err(&err) + defer cancelF() + defer wg.Done() + for { + if res == nil { + res = &sqltypes.Result{ + Fields: resultFields, + Rows: make([][]sqltypes.Value, 0, rowsBatch), + } + } + if res.RowsAffected == rowsBatch { + if err := callback(res, more); err != nil { + return err + } + res = nil + processedAtLeastOneBatch = true + continue + } + + select { + case <-ctx.Done(): + return context.Cause(ctx) + case <-timer.C: + if h.readTimeout != 0 { + // Cancel and return so Vitess can call the CloseConnection callback + ctx.GetLogger().Tracef("connection timeout") + return ErrRowTimeout.New() + } + case rowFrame, ok := <-rowFrameChan: + if !ok { + return nil + } + // DEEP COPY HERE IS IMPORTANT! + row := rowFrame.Row2Copy() + resRow := make([]sqltypes.Value, len(row)) + for i, val := range row { + resRow[i] = sqltypes.MakeTrusted(val.Typ, val.Val) + } + + // Should be safe to release memory + rowFrame.Recycle() + + ctx.GetLogger().Tracef("spooling result row %s", resRow) + res.Rows = append(res.Rows, resRow) + res.RowsAffected++ + if !timer.Stop() { + <-timer.C + } + } + timer.Reset(waitTime) + } + }) + + // Close() kills this PID in the process list, + // wait until all rows have be sent over the wire + eg.Go(func() (err error) { + defer pan2err(&err) + wg.Wait() + return iter.Close(ctx) + }) + + err := eg.Wait() + if err != nil { + ctx.GetLogger().WithError(err).Warn("error running query") + if verboseErrorLogging { + fmt.Printf("Err: %+v", err) + } + return nil, false, err + } + + return res, processedAtLeastOneBatch, nil +} + // See https://dev.mysql.com/doc/internals/en/status-flags.html func setConnStatusFlags(ctx *sql.Context, c *mysql.Conn) error { ok, err := isSessionAutocommit(ctx) diff --git a/sql/convert_value.go b/sql/convert_value.go index d46fe4de4e..880b9f2f58 100644 --- a/sql/convert_value.go +++ b/sql/convert_value.go @@ -3,9 +3,9 @@ package sql import ( "fmt" - "github.com/dolthub/vitess/go/vt/proto/query" - "github.com/dolthub/go-mysql-server/sql/values" + + "github.com/dolthub/vitess/go/vt/proto/query" ) // ConvertToValue converts the interface to a sql value. @@ -90,11 +90,3 @@ func ConvertToValue(v interface{}) (Value, error) { return Value{}, fmt.Errorf("type %T not implemented", v) } } - -func MustConvertToValue(v interface{}) Value { - ret, err := ConvertToValue(v) - if err != nil { - panic(err) - } - return ret -} diff --git a/sql/core.go b/sql/core.go index c1e1f90b2a..c2996039eb 100644 --- a/sql/core.go +++ b/sql/core.go @@ -467,6 +467,7 @@ type Expression2 interface { Eval2(ctx *Context, row Row2) (Value, error) // Type2 returns the expression type. Type2() Type2 + IsExpr2() bool } var SystemVariables SystemVariableRegistry diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 97e5f4ba6e..312405377d 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -17,6 +17,7 @@ package expression import ( "fmt" + querypb "github.com/dolthub/vitess/go/vt/proto/query" errors "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" @@ -492,6 +493,7 @@ type GreaterThan struct { } var _ sql.Expression = (*GreaterThan)(nil) +var _ sql.Expression2 = (*GreaterThan)(nil) var _ sql.CollationCoercible = (*GreaterThan)(nil) // NewGreaterThan creates a new GreaterThan expression. @@ -518,6 +520,67 @@ func (gt *GreaterThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) return result == 1, nil } +func (gt *GreaterThan) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { + l, ok := gt.Left().(sql.Expression2) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Left())) + } + r, ok := gt.Right().(sql.Expression2) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.Expression2", gt.Right())) + } + + lv, err := l.Eval2(ctx, row) + if err != nil { + return sql.Value{}, err + } + rv, err := r.Eval2(ctx, row) + if err != nil { + return sql.Value{}, err + } + + // TODO: just assume they are int64 + l64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, lv) + if err != nil { + return sql.Value{}, err + } + r64, err := types.ConvertValueToInt64(types.NumberTypeImpl_{}, rv) + if err != nil { + return sql.Value{}, err + } + var rb byte + if l64 > r64 { + rb = 1 + } + ret := sql.Value{ + Val: []byte{rb}, + Typ: querypb.Type_INT8, + } + return ret, nil +} + +func (gt *GreaterThan) Type2() sql.Type2 { + return nil +} + +func (gt *GreaterThan) IsExpr2() bool { + lExpr, isExpr2 := gt.Left().(sql.Expression2) + if !isExpr2 { + return false + } + if !lExpr.IsExpr2() { + return false + } + rExpr, isExpr2 := gt.Right().(sql.Expression2) + if !isExpr2 { + return false + } + if !rExpr.IsExpr2() { + return false + } + return true +} + // WithChildren implements the Expression interface. func (gt *GreaterThan) WithChildren(children ...sql.Expression) (sql.Expression, error) { if len(children) != 2 { diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index f4ff9b429e..319406e073 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -153,8 +153,11 @@ func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { if p.fieldIndex < 0 || p.fieldIndex >= row.Len() { return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len()) } + return row[p.fieldIndex], nil +} - return row.GetField(p.fieldIndex), nil +func (p *GetField) IsExpr2() bool { + return true } // WithChildren implements the Expression interface. diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 8fff9557a7..cc74bd7dc6 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -140,6 +140,10 @@ func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { return lit.val2, nil } +func (lit *Literal) IsExpr2() bool { + return true +} + func (lit *Literal) Type2() sql.Type2 { t2, ok := lit.Typ.(sql.Type2) if !ok { @@ -149,8 +153,8 @@ func (lit *Literal) Type2() sql.Type2 { } // Value returns the literal value. -func (p *Literal) Value() interface{} { - return p.Val +func (lit *Literal) Value() interface{} { + return lit.Val } func (lit *Literal) WithResolvedChildren(children []any) (any, error) { diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index 78e2e9d0b9..c421699722 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -79,6 +79,10 @@ func (uc *UnresolvedColumn) Type2() sql.Type2 { panic("unresolved column is a placeholder node, but Type2 was called") } +func (uc *UnresolvedColumn) IsExpr2() bool { + panic("unresolved column is a placeholder node, but IsExpr2 was called") +} + // Name implements the Nameable interface. func (uc *UnresolvedColumn) Name() string { return uc.name } diff --git a/sql/plan/filter.go b/sql/plan/filter.go index f2c0691112..d0a8efe55c 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -15,6 +15,7 @@ package plan import ( + "fmt" "github.com/dolthub/go-mysql-server/sql" ) @@ -104,8 +105,15 @@ func (f *Filter) Expressions() []sql.Expression { type FilterIter struct { cond sql.Expression childIter sql.RowIter + + cond2 sql.Expression2 + childIter2 sql.RowIter2 } +var _ sql.RowIter = (*FilterIter)(nil) +var _ sql.RowIter2 = (*FilterIter)(nil) +var _ sql.RowFrameIter = (*FilterIter)(nil) + // NewFilterIter creates a new FilterIter. func NewFilterIter( cond sql.Expression, @@ -133,6 +141,59 @@ func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) { } } +func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) { + for { + row, err := i.childIter2.Next2(ctx) + if err != nil { + return nil, err + } + res, err := i.cond2.Eval2(ctx, row) + if err != nil { + return nil, err + } + if res.Val[0] == 1 { + return row, nil + } + } +} + +func (i *FilterIter) NextRowFrame(ctx *sql.Context, rowFrame *sql.RowFrame) error { + // TODO: this is trickier... + childIter, ok := i.childIter.(sql.RowFrameIter) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.RowFrameIter", i.childIter)) + } + for { + err := childIter.NextRowFrame(ctx, rowFrame) + if err != nil { + return err + } + row := rowFrame.Row2() + res, err := i.cond2.Eval2(ctx, row) + if err != nil { + return err + } + if res.Val[0] == 1 { + return nil + } + rowFrame.Clear() + } +} + +func (i *FilterIter) IsRowIter2(ctx *sql.Context) bool { + cond, ok := i.cond.(sql.Expression2) + if !ok || !cond.IsExpr2() { + return false + } + childIter, ok := i.childIter.(sql.RowIter2) + if !ok || !childIter.IsRowIter2(ctx) { + return false + } + i.cond2 = cond + i.childIter2 = childIter + return true +} + // Close implements the RowIter interface. func (i *FilterIter) Close(ctx *sql.Context) error { return i.childIter.Close(ctx) diff --git a/sql/plan/process.go b/sql/plan/process.go index ee95249f10..d909d1a5fe 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -226,6 +226,7 @@ const ( type TrackedRowIter struct { node sql.Node iter sql.RowIter + iter2 sql.RowIter2 onDone NotifyFunc onNext NotifyFunc numRows int64 @@ -233,6 +234,10 @@ type TrackedRowIter struct { ShouldSetFoundRows bool } +var _ sql.RowIter = (*TrackedRowIter)(nil) +var _ sql.RowIter2 = (*TrackedRowIter)(nil) +var _ sql.RowFrameIter = (*TrackedRowIter)(nil) + func NewTrackedRowIter( node sql.Node, iter sql.RowIter, @@ -317,6 +322,43 @@ func (i *TrackedRowIter) Next(ctx *sql.Context) (sql.Row, error) { return row, nil } +func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.Row2, error) { + row, err := i.iter2.Next2(ctx) + if err != nil { + return nil, err + } + i.numRows++ + if i.onNext != nil { + i.onNext() + } + return row, nil +} + +func (i *TrackedRowIter) IsRowIter2(ctx *sql.Context) bool { + iter, ok := i.iter.(sql.RowIter2) + if !ok || !iter.IsRowIter2(ctx) { + return false + } + i.iter2 = iter + return true +} + +func (i *TrackedRowIter) NextRowFrame(ctx *sql.Context, rowFrame *sql.RowFrame) error { + iter, ok := i.iter.(sql.RowFrameIter) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.RowFrameIter", i.iter)) + } + err := iter.NextRowFrame(ctx, rowFrame) + if err != nil { + return err + } + i.numRows++ + if i.onNext != nil { + i.onNext() + } + return nil +} + func (i *TrackedRowIter) Close(ctx *sql.Context) error { err := i.iter.Close(ctx) diff --git a/sql/row_frame.go b/sql/row_frame.go index ef3ea6010f..a548860168 100644 --- a/sql/row_frame.go +++ b/sql/row_frame.go @@ -57,6 +57,7 @@ type RowFrame struct { // Values are the values this row. Values []ValueBytes + // TODO: this isn't used anywhere // varr is used as the backing array for the |Values| // slice when len(Values) <= valueArrSize varr [valueArrSize][]ValueBytes @@ -98,8 +99,8 @@ func (f *RowFrame) Row2() Row2 { rs := make(Row2, len(f.Values)) for i := range f.Values { rs[i] = Value{ - Typ: f.Types[i], Val: f.Values[i], + Typ: f.Types[i], } } return rs @@ -114,8 +115,8 @@ func (f *RowFrame) Row2Copy() Row2 { v := make(ValueBytes, len(f.Values[i])) copy(v, f.Values[i]) rs[i] = Value{ - Typ: f.Types[i], Val: v, + Typ: f.Types[i], } } return rs @@ -131,6 +132,7 @@ func (f *RowFrame) Clear() { // Append appends the values given into this frame. func (f *RowFrame) Append(vals ...Value) { + // TODO: one big copy here would be better probably, need to benchmark for _, v := range vals { f.append(v) } @@ -146,11 +148,12 @@ func (f *RowFrame) AppendMany(types []querypb.Type, vals []ValueBytes) { func (f *RowFrame) append(v Value) { buf := f.getBuffer(v) - copy(buf, v.Val) + copy(buf, v.Val) // TODO: not necessary if we're not referencing the backing array v.Val = buf f.Types = append(f.Types, v.Typ) + // TODO: do this? // if |f.Values| grows past |len(f.varr)| // we'll allocate a new backing array here f.Values = append(f.Values, v.Val) @@ -162,6 +165,7 @@ func (f *RowFrame) appendTypeAndVal(typ querypb.Type, val ValueBytes) { f.Types = append(f.Types, typ) + // TODO: do this? // if |f.Values| grows past |len(f.varr)| // we'll allocate a new backing array here f.Values = append(f.Values, v) @@ -172,11 +176,10 @@ func (f *RowFrame) getBuffer(v Value) (buf []byte) { } func (f *RowFrame) bufferForBytes(v ValueBytes) (buf []byte) { - if f.checkCapacity(v) { + if f.hasCapacity(v) { start := f.off f.off += uint16(len(v)) - stop := f.off - buf = f.farr[start:stop] + buf = f.farr[start:f.off] } else { buf = make([]byte, len(v)) } @@ -184,6 +187,6 @@ func (f *RowFrame) bufferForBytes(v ValueBytes) (buf []byte) { return } -func (f *RowFrame) checkCapacity(v ValueBytes) bool { +func (f *RowFrame) hasCapacity(v ValueBytes) bool { return len(v) <= (len(f.farr) - int(f.off)) } diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index db69cf5327..166e6b7f82 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -15,6 +15,7 @@ package rowexec import ( + "fmt" "io" "gopkg.in/src-d/go-errors.v1" @@ -71,11 +72,16 @@ func getLockableTable(table sql.Table) (sql.Lockable, error) { // during the Close() operation type TransactionCommittingIter struct { childIter sql.RowIter + childIter2 sql.RowIter2 transactionDatabase string autoCommit bool implicitCommit bool } +var _ sql.RowIter = (*TransactionCommittingIter)(nil) +var _ sql.RowIter2 = (*TransactionCommittingIter)(nil) +var _ sql.RowFrameIter = (*TransactionCommittingIter)(nil) + func AddTransactionCommittingIter(ctx *sql.Context, qFlags *sql.QueryFlags, iter sql.RowIter) (sql.RowIter, error) { // TODO: This is a bit of a hack. Need to figure out better relationship between new transaction node and warnings. if (qFlags != nil && qFlags.IsSet(sql.QFlagShowWarnings)) || ctx.IsInterpreted() { @@ -99,6 +105,27 @@ func (t *TransactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) { return t.childIter.Next(ctx) } +func (t *TransactionCommittingIter) Next2(ctx *sql.Context) (sql.Row2, error) { + return t.childIter2.Next2(ctx) +} + +func (t *TransactionCommittingIter) IsRowIter2(ctx *sql.Context) bool { + childIter, ok := t.childIter.(sql.RowIter2) + if !ok || !childIter.IsRowIter2(ctx) { + return false + } + t.childIter2 = childIter + return true +} + +func (t *TransactionCommittingIter) NextRowFrame(ctx *sql.Context, rowFrame *sql.RowFrame) error { + childIter, ok := t.childIter.(sql.RowFrameIter) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.RowFrameIter", t.childIter)) + } + return childIter.NextRowFrame(ctx, rowFrame) +} + func (t *TransactionCommittingIter) Close(ctx *sql.Context) error { var err error if t.childIter != nil { diff --git a/sql/rows.go b/sql/rows.go index a9e5f55d5c..cb453562f3 100644 --- a/sql/rows.go +++ b/sql/rows.go @@ -92,6 +92,12 @@ type RowIter interface { Closer } +type RowIter2 interface { + RowIter + Next2(ctx *Context) (Row2, error) + IsRowIter2(ctx *Context) bool +} + // RowIterToRows converts a row iterator to a slice of rows. func RowIterToRows(ctx *Context, i RowIter) ([]Row, error) { var rows []Row @@ -112,7 +118,7 @@ func RowIterToRows(ctx *Context, i RowIter) ([]Row, error) { return rows, i.Close(ctx) } -func rowFromRow2(sch Schema, r Row2) Row { +func RowFromRow2(sch Schema, r Row2) Row { row := make(Row, len(sch)) for i, col := range sch { switch col.Type.Type() { @@ -242,3 +248,8 @@ type MutableRowIter interface { GetChildIter() RowIter WithChildIter(childIter RowIter) RowIter } + +type RowFrameIter interface { + RowIter + NextRowFrame(ctx *Context, frame *RowFrame) error +} diff --git a/sql/table_iter.go b/sql/table_iter.go index e302d5428a..f06cd3b566 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -15,6 +15,7 @@ package sql import ( + "fmt" "io" ) @@ -24,9 +25,13 @@ type TableRowIter struct { partitions PartitionIter partition Partition rows RowIter + + rows2 RowIter2 } var _ RowIter = (*TableRowIter)(nil) +var _ RowIter2 = (*TableRowIter)(nil) +var _ RowFrameIter = (*TableRowIter)(nil) // NewTableRowIter returns a new iterator over the rows in the partitions of the table given. func NewTableRowIter(ctx *Context, table Table, partitions PartitionIter) *TableRowIter { @@ -76,6 +81,118 @@ func (i *TableRowIter) Next(ctx *Context) (Row, error) { return row, err } +func (i *TableRowIter) Next2(ctx *Context) (Row2, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if i.partition == nil { + partition, err := i.partitions.Next(ctx) + if err != nil { + if err == io.EOF { + if e := i.partitions.Close(ctx); e != nil { + return nil, e + } + } + + return nil, err + } + + i.partition = partition + } + + if i.rows2 == nil { + rows, err := i.table.PartitionRows(ctx, i.partition) + if err != nil { + return nil, err + } + ri2, ok := rows.(RowIter2) + if !ok || !ri2.IsRowIter2(ctx) { + panic(fmt.Sprintf("%T does not implement RowIter2", rows)) + } + i.rows2 = ri2 + } + + row, err := i.rows2.Next2(ctx) + if err != nil && err == io.EOF { + if err = i.rows2.Close(ctx); err != nil { + return nil, err + } + i.partition = nil + i.rows2 = nil + row, err = i.Next2(ctx) + } + return row, err +} + +func (i *TableRowIter) IsRowIter2(ctx *Context) bool { + if i.partition == nil { + partition, err := i.partitions.Next(ctx) + if err != nil { + return false + } + i.partition = partition + } + if i.rows2 == nil { + rows, err := i.table.PartitionRows(ctx, i.partition) + if err != nil { + return false + } + ri2, ok := rows.(RowIter2) + if !ok { + return false + } + i.rows2 = ri2 + } + return i.rows2.IsRowIter2(ctx) +} + +func (i *TableRowIter) NextRowFrame(ctx *Context, rowFrame *RowFrame) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if i.partition == nil { + partition, err := i.partitions.Next(ctx) + if err != nil { + if err == io.EOF { + if e := i.partitions.Close(ctx); e != nil { + return e + } + } + return err + } + i.partition = partition + } + + if i.rows2 == nil { + rows, err := i.table.PartitionRows(ctx, i.partition) + if err != nil { + return err + } + i.rows2 = rows.(RowIter2) + } + + rows, ok := i.rows2.(RowFrameIter) + if !ok { + panic(fmt.Sprintf("%T does not implement sql.RowFrameIter", i.rows)) + } + + err := rows.NextRowFrame(ctx, rowFrame) + if err != nil && err == io.EOF { + if err = i.rows2.Close(ctx); err != nil { + return err + } + i.partition = nil + i.rows2 = nil + err = i.NextRowFrame(ctx, rowFrame) + } + return err +} + func (i *TableRowIter) Close(ctx *Context) error { if i.rows != nil { if err := i.rows.Close(ctx); err != nil { diff --git a/sql/type.go b/sql/type.go index 59af5360f1..6d9f9adb01 100644 --- a/sql/type.go +++ b/sql/type.go @@ -294,15 +294,12 @@ func IsDecimalType(t Type) bool { type Type2 interface { Type - // Compare2 returns an integer comparing two Values. Compare2(Value, Value) (int, error) // Convert2 converts a value of a compatible type. Convert2(Value) (Value, error) // Zero2 returns the zero Value for this type. Zero2() Value - // SQL2 returns the sqltypes.Value for the given value - SQL2(Value) (sqltypes.Value, error) } // SpatialColumnType is a node that contains a reference to all spatial types. diff --git a/sql/types/number.go b/sql/types/number.go index 1601be830a..07d0af6245 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -765,11 +765,11 @@ func (t NumberTypeImpl_) Compare2(a sql.Value, b sql.Value) (int, error) { } return +1, nil default: - ca, err := convertValueToInt64(t, a) + ca, err := ConvertValueToInt64(t, a) if err != nil { return 0, err } - cb, err := convertValueToInt64(t, b) + cb, err := ConvertValueToInt64(t, b) if err != nil { return 0, err } @@ -1149,7 +1149,7 @@ func convertToInt64(t NumberTypeImpl_, v any, round Round) (int64, sql.ConvertIn } } -func convertValueToInt64(t NumberTypeImpl_, v sql.Value) (int64, error) { +func ConvertValueToInt64(t NumberTypeImpl_, v sql.Value) (int64, error) { switch v.Typ { case query.Type_INT8: return int64(values.ReadInt8(v.Val)), nil diff --git a/sql/values/encoding.go b/sql/values/encoding.go index 3472e870e5..d00e630091 100644 --- a/sql/values/encoding.go +++ b/sql/values/encoding.go @@ -130,11 +130,7 @@ func ReadUint16(val []byte) uint16 { func ReadInt24(val []byte) (i int32) { expectSize(val, Int24Size) - var tmp [4]byte - // copy |val| to |tmp| - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - i = int32(binary.LittleEndian.Uint32(tmp[:])) + i = int32(binary.LittleEndian.Uint32([]byte{0, val[0], val[1], val[2]})) return } @@ -158,28 +154,6 @@ func ReadUint32(val []byte) uint32 { return binary.LittleEndian.Uint32(val) } -func ReadInt48(val []byte) (i int64) { - expectSize(val, Int48Size) - var tmp [8]byte - // copy |val| to |tmp| - tmp[5], tmp[4] = val[5], val[4] - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - i = int64(binary.LittleEndian.Uint64(tmp[:])) - return -} - -func ReadUint48(val []byte) (u uint64) { - expectSize(val, Uint48Size) - var tmp [8]byte - // copy |val| to |tmp| - tmp[5], tmp[4] = val[5], val[4] - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - u = binary.LittleEndian.Uint64(tmp[:]) - return -} - func ReadInt64(val []byte) int64 { expectSize(val, Int64Size) return int64(binary.LittleEndian.Uint64(val))