11#include " vector_data_generator.h"
22
3+ #include < ydb/library/yql/udfs/common/knn/knn-serializer-shared.h>
4+
35#include < contrib/libs/apache/arrow/cpp/src/arrow/array/array_binary.h>
46#include < contrib/libs/apache/arrow/cpp/src/arrow/array/array_nested.h>
57#include < contrib/libs/apache/arrow/cpp/src/arrow/array/array_primitive.h>
68#include < contrib/libs/apache/arrow/cpp/src/arrow/array/builder_binary.h>
9+ #include < contrib/libs/apache/arrow/cpp/src/arrow/chunked_array.h>
710#include < contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.h>
11+ #include < contrib/libs/apache/arrow/cpp/src/arrow/csv/api.h>
12+ #include < contrib/libs/apache/arrow/cpp/src/arrow/csv/reader.h>
13+ #include < contrib/libs/apache/arrow/cpp/src/arrow/csv/writer.h>
814#include < contrib/libs/apache/arrow/cpp/src/arrow/io/memory.h>
915#include < contrib/libs/apache/arrow/cpp/src/arrow/ipc/dictionary.h>
1016#include < contrib/libs/apache/arrow/cpp/src/arrow/ipc/reader.h>
1117#include < contrib/libs/apache/arrow/cpp/src/arrow/ipc/writer.h>
1218#include < contrib/libs/apache/arrow/cpp/src/arrow/record_batch.h>
19+ #include < contrib/libs/apache/arrow/cpp/src/arrow/table.h>
1320#include < contrib/libs/apache/arrow/cpp/src/arrow/type.h>
1421#include < contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h>
1522
23+ #include < util/stream/mem.h>
24+
1625namespace NYdbWorkload {
1726
1827namespace {
@@ -35,39 +44,48 @@ class TTransformingDataGenerator final: public IBulkDataGenerator {
3544 return std::make_pair (schema, recordBatch);
3645 }
3746
38- void ConvertArrow (TDataPortion::TArrow* data) {
47+ static std::shared_ptr<arrow::Table> Deserialize (TDataPortion::TCsv* data) {
48+ auto bufferReader = std::make_shared<arrow::io::BufferReader>(arrow::util::string_view (data->Data .data (), data->Data .size ()));
49+ auto csvReader = arrow::csv::TableReader::Make (
50+ arrow::io::default_io_context (),
51+ bufferReader,
52+ arrow::csv::ReadOptions::Defaults (),
53+ arrow::csv::ParseOptions::Defaults (),
54+ arrow::csv::ConvertOptions::Defaults ()
55+ ).ValueOrDie ();
56+
57+ return csvReader->Read ().ValueOrDie ();
58+ }
59+
60+ void TransformArrow (TDataPortion::TArrow* data) {
3961 const auto [schema, batch] = Deserialize (data);
4062
41- // extract
63+ // id
4264 const auto idColumn = batch->GetColumnByName (" id" );
43- const auto embeddingColumn = batch->GetColumnByName (EmbeddingSourceField);
44- const auto embeddingListColumn = dynamic_cast <arrow::ListArray*>(embeddingColumn.get ());
45-
46- // conversion
4765 const auto newIdColumn = arrow::compute::Cast (idColumn, arrow::uint64 ()).ValueOrDie ().make_array ();
4866
49- arrow::StringBuilder builder;
67+ // embedding
68+ const auto embeddingColumn = std::dynamic_pointer_cast<arrow::ListArray>(batch->GetColumnByName (EmbeddingSourceField));
69+ arrow::StringBuilder newEmbeddingsBuilder;
5070 for (int64_t row = 0 ; row < batch->num_rows (); ++row) {
51- auto embeddingAsFloats = static_cast < const arrow::FloatArray*>(embeddingListColumn ->value_slice (row). get ( ));
71+ const auto embeddingFloatList = std::static_pointer_cast< arrow::FloatArray>(embeddingColumn ->value_slice (row));
5272
53- std::string serialized (embeddingAsFloats-> length () * sizeof ( float ) + 1 , ' \0 ' ) ;
54- float * float_bytes = reinterpret_cast <float *>(serialized. data () );
55- for (int64_t i = 0 ; i < embeddingAsFloats ->length (); ++i) {
56- float_bytes[i] = embeddingAsFloats ->Value (i);
73+ TStringBuilder buffer ;
74+ NKnnVectorSerialization::TSerializer <float > serializer (&buffer. Out );
75+ for (int64_t i = 0 ; i < embeddingFloatList ->length (); ++i) {
76+ serializer. HandleElement (embeddingFloatList ->Value (i) );
5777 }
58- serialized. back () = ' \x01 ' ;
78+ serializer. Finish () ;
5979
60- if (const auto status = builder .Append (serialized ); !status.ok ()) {
80+ if (const auto status = newEmbeddingsBuilder .Append (buffer. MutRef () ); !status.ok ()) {
6181 status.Abort ();
6282 }
6383 }
64-
6584 std::shared_ptr<arrow::StringArray> newEmbeddingColumn;
66- if (const auto status = builder .Finish (&newEmbeddingColumn); !status.ok ()) {
85+ if (const auto status = newEmbeddingsBuilder .Finish (&newEmbeddingColumn); !status.ok ()) {
6786 status.Abort ();
6887 }
6988
70- // serialize
7189 const auto newSchema = arrow::schema ({
7290 arrow::field (" id" , arrow::uint64 ()),
7391 arrow::field (" embedding" , arrow::utf8 ()),
@@ -84,9 +102,67 @@ class TTransformingDataGenerator final: public IBulkDataGenerator {
84102 data->Data = arrow::ipc::SerializeRecordBatch (*newRecordBatch, arrow::ipc::IpcWriteOptions{}).ValueOrDie ()->ToString ();
85103 }
86104
87- void Convert (TDataPortion::TDataType& data) {
105+ void TransformCsv (TDataPortion::TCsv* data) {
106+ const auto table = Deserialize (data);
107+
108+ // id
109+ const auto idColumn = table->GetColumnByName (" id" );
110+
111+ // embedding
112+ const auto embeddingColumn = table->GetColumnByName (EmbeddingSourceField);
113+ arrow::StringBuilder newEmbeddingsBuilder;
114+ for (int64_t row = 0 ; row < table->num_rows (); ++row) {
115+ const auto embeddingListString = std::static_pointer_cast<arrow::StringArray>(embeddingColumn->Slice (row, 1 )->chunk (0 ))->Value (0 );
116+
117+ TStringBuf buffer (embeddingListString.data (), embeddingListString.size ());
118+ buffer.SkipPrefix (" [" );
119+ buffer.ChopSuffix (" ]" );
120+ TMemoryInput input (buffer);
121+
122+ TStringBuilder newEmbeddingBuilder;
123+ NKnnVectorSerialization::TSerializer<float > serializer (&newEmbeddingBuilder.Out );
124+ while (!input.Exhausted ()) {
125+ float val;
126+ input >> val;
127+ input.Skip (1 );
128+ serializer.HandleElement (val);
129+ }
130+ serializer.Finish ();
131+
132+ if (const auto status = newEmbeddingsBuilder.Append (newEmbeddingBuilder.MutRef ()); !status.ok ()) {
133+ status.Abort ();
134+ }
135+ }
136+ std::shared_ptr<arrow::StringArray> newEmbeddingColumn;
137+ if (const auto status = newEmbeddingsBuilder.Finish (&newEmbeddingColumn); !status.ok ()) {
138+ status.Abort ();
139+ }
140+
141+ const auto newSchema = arrow::schema ({
142+ arrow::field (" id" , arrow::uint64 ()),
143+ arrow::field (" embedding" , arrow::utf8 ()),
144+ });
145+ const auto newTable = arrow::Table::Make (
146+ newSchema,
147+ {
148+ idColumn,
149+ arrow::ChunkedArray::Make ({newEmbeddingColumn}).ValueOrDie (),
150+ }
151+ );
152+ auto outputStream = arrow::io::BufferOutputStream::Create ().ValueOrDie ();
153+ if (const auto status = arrow::csv::WriteCSV (*newTable, arrow::csv::WriteOptions::Defaults (), outputStream.get ()); !status.ok ()) {
154+ status.Abort ();
155+ }
156+ data->FormatString = " " ;
157+ data->Data = outputStream->Finish ().ValueOrDie ()->ToString ();
158+ }
159+
160+ void Transform (TDataPortion::TDataType& data) {
88161 if (auto * value = std::get_if<TDataPortion::TArrow>(&data)) {
89- ConvertArrow (value);
162+ TransformArrow (value);
163+ }
164+ if (auto * value = std::get_if<TDataPortion::TCsv>(&data)) {
165+ TransformCsv (value);
90166 }
91167 }
92168
@@ -100,7 +176,7 @@ class TTransformingDataGenerator final: public IBulkDataGenerator {
100176 virtual TDataPortions GenerateDataPortion () override {
101177 TDataPortions portions = InnerDataGenerator->GenerateDataPortion ();
102178 for (auto portion : portions) {
103- Convert (portion->MutableData ());
179+ Transform (portion->MutableData ());
104180 }
105181 return portions;
106182 }
@@ -122,7 +198,9 @@ void TWorkloadVectorFilesDataInitializer::ConfigureOpts(NLastGetopt::TOpts& opts
122198 ).Required ().StoreResult (&DataFiles);
123199 opts.AddLongOption (' t' , " transform" ,
124200 " Perform transformation of input data. "
125- " Parquet: leave only required fields, cast to expected types, convert list of floats into serialized ydb representation."
201+ " Parquet: leave only required fields, cast to expected types, convert list of floats into serialized representation. "
202+ " CSV: leave only required fields, parse float list from string and serialize. "
203+ " Reference for embedding serialization: https://ydb.tech/docs/yql/reference/udf/list/knn#functions-convert"
126204 ).Optional ().StoreTrue (&DoTransform);
127205 opts.AddLongOption (
128206 " transform-embedding-source-field" ,
0 commit comments