Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 127 additions & 21 deletions querier/flightsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ type FlightSQLServer struct {
// Add result storage
results map[string]arrow.Record
resultsLock sync.RWMutex

// Add meta result storage for GetTables/GetSqlInfo
metaResults map[string]struct {
cmdType string
params []byte // store the raw command bytes for later use
}
metaResultsLock sync.RWMutex
}

// mustEmbedUnimplementedFlightServiceServer implements the FlightServiceServer interface
Expand All @@ -45,6 +52,10 @@ func NewFlightSQLServer(queryClient *QueryClient) *FlightSQLServer {
queryClient: queryClient,
mem: memory.DefaultAllocator,
results: make(map[string]arrow.Record),
metaResults: make(map[string]struct {
cmdType string
params []byte
}),
}
}

Expand Down Expand Up @@ -104,17 +115,15 @@ func (s *FlightSQLServer) GetFlightInfo(ctx context.Context, desc *flight.Flight
log.Printf("GetFlightInfo called with descriptor type: %v, path: %v, cmd: %v",
desc.Type, desc.Path, string(desc.Cmd))

// Handle SQL query command
if desc.Type == flight.DescriptorCMD {
// Unmarshal the Any message
any := &anypb.Any{}
if err := proto.Unmarshal(desc.Cmd, any); err != nil {
log.Printf("Failed to unmarshal Any message: %v", err)
return nil, fmt.Errorf("failed to unmarshal command: %w", err)
}

// Check if this is a CommandStatementQuery
if any.TypeUrl == "type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery" {
switch any.TypeUrl {
case "type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery":
// The query is in the Any message's value
query := string(any.Value)
// Clean up the query string
Expand Down Expand Up @@ -146,23 +155,8 @@ func (s *FlightSQLServer) GetFlightInfo(ctx context.Context, desc *flight.Flight
}
}

// Parse the query to extract time range
parsed, err := s.queryClient.ParseQuery(query, dbName)
if err != nil {
log.Printf("Failed to parse query: %v", err)
return nil, fmt.Errorf("failed to parse query: %w", err)
}

// Find relevant files based on the parsed query
files, err := s.queryClient.FindRelevantFiles(ctx, parsed.DbName, parsed.Measurement, parsed.TimeRange)
if err != nil {
log.Printf("Failed to find relevant files: %v", err)
return nil, fmt.Errorf("failed to find relevant files: %w", err)
}
log.Printf("Found %d relevant files for query", len(files))

// Execute the query using our existing QueryClient
results, err := s.queryClient.Query(ctx, query, parsed.DbName) // Use the parsed database name
// Use QueryClient.Query which now handles all fallback logic
results, err := s.queryClient.Query(ctx, query, dbName)
if err != nil {
log.Printf("Query execution failed: %v", err)
return nil, fmt.Errorf("failed to execute query: %w", err)
Expand Down Expand Up @@ -208,6 +202,48 @@ func (s *FlightSQLServer) GetFlightInfo(ctx context.Context, desc *flight.Flight

log.Printf("Returning flight info with %d records", recordBatch.NumRows())
return info, nil
case "type.googleapis.com/arrow.flight.protocol.sql.CommandGetTables":
// Generate a unique ticket
ticketID := fmt.Sprintf("get-tables-%d", time.Now().UnixNano())
s.metaResultsLock.Lock()
s.metaResults[ticketID] = struct {
cmdType string
params []byte
}{cmdType: "getTables", params: any.Value}
s.metaResultsLock.Unlock()
ticket := &flight.Ticket{Ticket: []byte(ticketID)}
info := &flight.FlightInfo{
FlightDescriptor: desc,
Endpoint: []*flight.FlightEndpoint{{
Ticket: ticket,
Location: []*flight.Location{{Uri: "grpc://localhost:8082"}},
}},
TotalRecords: -1,
TotalBytes: -1,
Schema: []byte{},
}
return info, nil
case "type.googleapis.com/arrow.flight.protocol.sql.CommandGetSqlInfo":
// Generate a unique ticket
ticketID := fmt.Sprintf("get-sqlinfo-%d", time.Now().UnixNano())
s.metaResultsLock.Lock()
s.metaResults[ticketID] = struct {
cmdType string
params []byte
}{cmdType: "getSqlInfo", params: any.Value}
s.metaResultsLock.Unlock()
ticket := &flight.Ticket{Ticket: []byte(ticketID)}
info := &flight.FlightInfo{
FlightDescriptor: desc,
Endpoint: []*flight.FlightEndpoint{{
Ticket: ticket,
Location: []*flight.Location{{Uri: "grpc://localhost:8082"}},
}},
TotalRecords: -1,
TotalBytes: -1,
Schema: []byte{},
}
return info, nil
}
}

Expand Down Expand Up @@ -268,6 +304,76 @@ func (s *FlightSQLServer) GetFlightInfoStatement(ctx context.Context, cmd *fligh
func (s *FlightSQLServer) DoGet(ticket *flight.Ticket, stream flight.FlightService_DoGetServer) error {
log.Printf("DoGet called with ticket: %v", string(ticket.Ticket))

// Check if this is a meta result (GetTables/GetSqlInfo)
s.metaResultsLock.RLock()
meta, isMeta := s.metaResults[string(ticket.Ticket)]
s.metaResultsLock.RUnlock()
if isMeta {
// Print incoming metadata for debugging
if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
log.Printf("DoGet metadata: %v", md)
}
switch meta.cmdType {
case "getTables":
// Enumerate real tables (directories) in the default database
dbName := "default"
// Use QueryClient logic to list tables
entries, err := s.queryClient.Query(context.Background(), "SHOW TABLES", dbName)
if err != nil {
return fmt.Errorf("failed to enumerate tables: %w", err)
}
schema := arrow.NewSchema([]arrow.Field{
{Name: "catalog_name", Type: arrow.BinaryTypes.String, Nullable: true},
{Name: "db_schema_name", Type: arrow.BinaryTypes.String, Nullable: true},
{Name: "table_name", Type: arrow.BinaryTypes.String, Nullable: false},
{Name: "table_type", Type: arrow.BinaryTypes.String, Nullable: false},
}, nil)
b0 := array.NewStringBuilder(memory.DefaultAllocator)
b1 := array.NewStringBuilder(memory.DefaultAllocator)
b2 := array.NewStringBuilder(memory.DefaultAllocator)
b3 := array.NewStringBuilder(memory.DefaultAllocator)
for _, row := range entries {
tableName, _ := row["table_name"].(string)
b0.Append("") // catalog_name
b1.Append("") // db_schema_name
b2.Append(tableName)
b3.Append("BASE TABLE")
}
record := array.NewRecord(schema, []arrow.Array{b0.NewArray(), b1.NewArray(), b2.NewArray(), b3.NewArray()}, int64(len(entries)))
defer record.Release()
writer := flight.NewRecordWriter(stream, ipc.WithSchema(schema))
err = writer.Write(record)
if err != nil {
return fmt.Errorf("failed to write getTables record: %w", err)
}
s.metaResultsLock.Lock()
delete(s.metaResults, string(ticket.Ticket))
s.metaResultsLock.Unlock()
return writer.Close()
case "getSqlInfo":
// For now, return a single info value (e.g., server name)
schema := arrow.NewSchema([]arrow.Field{
{Name: "info_name", Type: arrow.PrimitiveTypes.Uint32, Nullable: false},
{Name: "value", Type: arrow.BinaryTypes.String, Nullable: true},
}, nil)
b0 := array.NewUint32Builder(memory.DefaultAllocator)
b1 := array.NewStringBuilder(memory.DefaultAllocator)
b0.Append(1) // e.g., SQL_SERVER_NAME
b1.Append("Gigapi-Querier")
record := array.NewRecord(schema, []arrow.Array{b0.NewArray(), b1.NewArray()}, 1)
defer record.Release()
writer := flight.NewRecordWriter(stream, ipc.WithSchema(schema))
err := writer.Write(record)
if err != nil {
return fmt.Errorf("failed to write getSqlInfo record: %w", err)
}
s.metaResultsLock.Lock()
delete(s.metaResults, string(ticket.Ticket))
s.metaResultsLock.Unlock()
return writer.Close()
}
}

// Get the results from storage
s.resultsLock.RLock()
recordBatch, exists := s.results[string(ticket.Ticket)]
Expand Down
55 changes: 54 additions & 1 deletion querier/queryClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -665,12 +665,65 @@ func (c *QueryClient) Query(ctx context.Context, query, dbName string) ([]map[st
}
}
return results, nil

// Special InfluxDB IOx compatibility: information_schema.tables
case "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES":
// Return SHOW TABLES for the given dbName, but format as []map[string]interface{}{"table_name": ...}
tables, err := c.Query(ctx, "SHOW TABLES", dbName)
if err != nil {
return nil, err
}
// Only return the "table_name" field, as expected by the client
var result []map[string]interface{}
for _, t := range tables {
if name, ok := t["table_name"]; ok {
result = append(result, map[string]interface{}{"table_name": name})
}
}
return result, nil
}

// Parse the query
parsed, err := c.ParseQuery(query, dbName)
if err != nil {
return nil, err
// Fallback: Directly execute the query in DuckDB if ParseQuery fails
stmt, err2 := c.DB.Prepare(query)
if err2 != nil {
return nil, fmt.Errorf("failed to prepare query: %v", err2)
}
defer stmt.Close()

rows, err2 := stmt.Query()
if err2 != nil {
return nil, fmt.Errorf("query execution failed: %v", err2)
}
defer rows.Close()

columns, err2 := rows.Columns()
if err2 != nil {
return nil, fmt.Errorf("failed to get columns: %v", err2)
}

var result []map[string]interface{}
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
return nil, fmt.Errorf("error scanning row: %v", err)
}
row := make(map[string]interface{})
for i, col := range columns {
row[col] = values[i]
}
result = append(result, row)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating rows: %v", err)
}
return result, nil
}

// Find relevant files
Expand Down
Loading