diff --git a/querier/flightsql.go b/querier/flightsql.go index da4ff31..f37d581 100644 --- a/querier/flightsql.go +++ b/querier/flightsql.go @@ -34,6 +34,13 @@ type FlightSQLServer struct { // Add result storage results map[string]arrow.Record resultsLock sync.RWMutex + + // Add meta result storage for GetTables/GetSqlInfo + metaResults map[string]struct { + cmdType string + params []byte // store the raw command bytes for later use + } + metaResultsLock sync.RWMutex } // mustEmbedUnimplementedFlightServiceServer implements the FlightServiceServer interface @@ -45,6 +52,10 @@ func NewFlightSQLServer(queryClient *QueryClient) *FlightSQLServer { queryClient: queryClient, mem: memory.DefaultAllocator, results: make(map[string]arrow.Record), + metaResults: make(map[string]struct { + cmdType string + params []byte + }), } } @@ -104,17 +115,15 @@ func (s *FlightSQLServer) GetFlightInfo(ctx context.Context, desc *flight.Flight log.Printf("GetFlightInfo called with descriptor type: %v, path: %v, cmd: %v", desc.Type, desc.Path, string(desc.Cmd)) - // Handle SQL query command if desc.Type == flight.DescriptorCMD { - // Unmarshal the Any message any := &anypb.Any{} if err := proto.Unmarshal(desc.Cmd, any); err != nil { log.Printf("Failed to unmarshal Any message: %v", err) return nil, fmt.Errorf("failed to unmarshal command: %w", err) } - // Check if this is a CommandStatementQuery - if any.TypeUrl == "type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery" { + switch any.TypeUrl { + case "type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery": // The query is in the Any message's value query := string(any.Value) // Clean up the query string @@ -146,23 +155,8 @@ func (s *FlightSQLServer) GetFlightInfo(ctx context.Context, desc *flight.Flight } } - // Parse the query to extract time range - parsed, err := s.queryClient.ParseQuery(query, dbName) - if err != nil { - log.Printf("Failed to parse query: %v", err) - return nil, fmt.Errorf("failed to parse query: %w", err) - } - - // Find relevant files based on the parsed query - files, err := s.queryClient.FindRelevantFiles(ctx, parsed.DbName, parsed.Measurement, parsed.TimeRange) - if err != nil { - log.Printf("Failed to find relevant files: %v", err) - return nil, fmt.Errorf("failed to find relevant files: %w", err) - } - log.Printf("Found %d relevant files for query", len(files)) - - // Execute the query using our existing QueryClient - results, err := s.queryClient.Query(ctx, query, parsed.DbName) // Use the parsed database name + // Use QueryClient.Query which now handles all fallback logic + results, err := s.queryClient.Query(ctx, query, dbName) if err != nil { log.Printf("Query execution failed: %v", err) return nil, fmt.Errorf("failed to execute query: %w", err) @@ -208,6 +202,48 @@ func (s *FlightSQLServer) GetFlightInfo(ctx context.Context, desc *flight.Flight log.Printf("Returning flight info with %d records", recordBatch.NumRows()) return info, nil + case "type.googleapis.com/arrow.flight.protocol.sql.CommandGetTables": + // Generate a unique ticket + ticketID := fmt.Sprintf("get-tables-%d", time.Now().UnixNano()) + s.metaResultsLock.Lock() + s.metaResults[ticketID] = struct { + cmdType string + params []byte + }{cmdType: "getTables", params: any.Value} + s.metaResultsLock.Unlock() + ticket := &flight.Ticket{Ticket: []byte(ticketID)} + info := &flight.FlightInfo{ + FlightDescriptor: desc, + Endpoint: []*flight.FlightEndpoint{{ + Ticket: ticket, + Location: []*flight.Location{{Uri: "grpc://localhost:8082"}}, + }}, + TotalRecords: -1, + TotalBytes: -1, + Schema: []byte{}, + } + return info, nil + case "type.googleapis.com/arrow.flight.protocol.sql.CommandGetSqlInfo": + // Generate a unique ticket + ticketID := fmt.Sprintf("get-sqlinfo-%d", time.Now().UnixNano()) + s.metaResultsLock.Lock() + s.metaResults[ticketID] = struct { + cmdType string + params []byte + }{cmdType: "getSqlInfo", params: any.Value} + s.metaResultsLock.Unlock() + ticket := &flight.Ticket{Ticket: []byte(ticketID)} + info := &flight.FlightInfo{ + FlightDescriptor: desc, + Endpoint: []*flight.FlightEndpoint{{ + Ticket: ticket, + Location: []*flight.Location{{Uri: "grpc://localhost:8082"}}, + }}, + TotalRecords: -1, + TotalBytes: -1, + Schema: []byte{}, + } + return info, nil } } @@ -268,6 +304,76 @@ func (s *FlightSQLServer) GetFlightInfoStatement(ctx context.Context, cmd *fligh func (s *FlightSQLServer) DoGet(ticket *flight.Ticket, stream flight.FlightService_DoGetServer) error { log.Printf("DoGet called with ticket: %v", string(ticket.Ticket)) + // Check if this is a meta result (GetTables/GetSqlInfo) + s.metaResultsLock.RLock() + meta, isMeta := s.metaResults[string(ticket.Ticket)] + s.metaResultsLock.RUnlock() + if isMeta { + // Print incoming metadata for debugging + if md, ok := metadata.FromIncomingContext(stream.Context()); ok { + log.Printf("DoGet metadata: %v", md) + } + switch meta.cmdType { + case "getTables": + // Enumerate real tables (directories) in the default database + dbName := "default" + // Use QueryClient logic to list tables + entries, err := s.queryClient.Query(context.Background(), "SHOW TABLES", dbName) + if err != nil { + return fmt.Errorf("failed to enumerate tables: %w", err) + } + schema := arrow.NewSchema([]arrow.Field{ + {Name: "catalog_name", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "db_schema_name", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "table_name", Type: arrow.BinaryTypes.String, Nullable: false}, + {Name: "table_type", Type: arrow.BinaryTypes.String, Nullable: false}, + }, nil) + b0 := array.NewStringBuilder(memory.DefaultAllocator) + b1 := array.NewStringBuilder(memory.DefaultAllocator) + b2 := array.NewStringBuilder(memory.DefaultAllocator) + b3 := array.NewStringBuilder(memory.DefaultAllocator) + for _, row := range entries { + tableName, _ := row["table_name"].(string) + b0.Append("") // catalog_name + b1.Append("") // db_schema_name + b2.Append(tableName) + b3.Append("BASE TABLE") + } + record := array.NewRecord(schema, []arrow.Array{b0.NewArray(), b1.NewArray(), b2.NewArray(), b3.NewArray()}, int64(len(entries))) + defer record.Release() + writer := flight.NewRecordWriter(stream, ipc.WithSchema(schema)) + err = writer.Write(record) + if err != nil { + return fmt.Errorf("failed to write getTables record: %w", err) + } + s.metaResultsLock.Lock() + delete(s.metaResults, string(ticket.Ticket)) + s.metaResultsLock.Unlock() + return writer.Close() + case "getSqlInfo": + // For now, return a single info value (e.g., server name) + schema := arrow.NewSchema([]arrow.Field{ + {Name: "info_name", Type: arrow.PrimitiveTypes.Uint32, Nullable: false}, + {Name: "value", Type: arrow.BinaryTypes.String, Nullable: true}, + }, nil) + b0 := array.NewUint32Builder(memory.DefaultAllocator) + b1 := array.NewStringBuilder(memory.DefaultAllocator) + b0.Append(1) // e.g., SQL_SERVER_NAME + b1.Append("Gigapi-Querier") + record := array.NewRecord(schema, []arrow.Array{b0.NewArray(), b1.NewArray()}, 1) + defer record.Release() + writer := flight.NewRecordWriter(stream, ipc.WithSchema(schema)) + err := writer.Write(record) + if err != nil { + return fmt.Errorf("failed to write getSqlInfo record: %w", err) + } + s.metaResultsLock.Lock() + delete(s.metaResults, string(ticket.Ticket)) + s.metaResultsLock.Unlock() + return writer.Close() + } + } + // Get the results from storage s.resultsLock.RLock() recordBatch, exists := s.results[string(ticket.Ticket)] diff --git a/querier/queryClient.go b/querier/queryClient.go index 354d620..ce89906 100644 --- a/querier/queryClient.go +++ b/querier/queryClient.go @@ -665,12 +665,65 @@ func (c *QueryClient) Query(ctx context.Context, query, dbName string) ([]map[st } } return results, nil + + // Special InfluxDB IOx compatibility: information_schema.tables + case "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES": + // Return SHOW TABLES for the given dbName, but format as []map[string]interface{}{"table_name": ...} + tables, err := c.Query(ctx, "SHOW TABLES", dbName) + if err != nil { + return nil, err + } + // Only return the "table_name" field, as expected by the client + var result []map[string]interface{} + for _, t := range tables { + if name, ok := t["table_name"]; ok { + result = append(result, map[string]interface{}{"table_name": name}) + } + } + return result, nil } // Parse the query parsed, err := c.ParseQuery(query, dbName) if err != nil { - return nil, err + // Fallback: Directly execute the query in DuckDB if ParseQuery fails + stmt, err2 := c.DB.Prepare(query) + if err2 != nil { + return nil, fmt.Errorf("failed to prepare query: %v", err2) + } + defer stmt.Close() + + rows, err2 := stmt.Query() + if err2 != nil { + return nil, fmt.Errorf("query execution failed: %v", err2) + } + defer rows.Close() + + columns, err2 := rows.Columns() + if err2 != nil { + return nil, fmt.Errorf("failed to get columns: %v", err2) + } + + var result []map[string]interface{} + for rows.Next() { + values := make([]interface{}, len(columns)) + valuePtrs := make([]interface{}, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + if err := rows.Scan(valuePtrs...); err != nil { + return nil, fmt.Errorf("error scanning row: %v", err) + } + row := make(map[string]interface{}) + for i, col := range columns { + row[col] = values[i] + } + result = append(result, row) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating rows: %v", err) + } + return result, nil } // Find relevant files