11#include " vector_data_generator.h"
22
3+ #include < contrib/libs/apache/arrow/cpp/src/arrow/array/array_binary.h>
4+ #include < contrib/libs/apache/arrow/cpp/src/arrow/array/array_nested.h>
5+ #include < contrib/libs/apache/arrow/cpp/src/arrow/array/array_primitive.h>
6+ #include < contrib/libs/apache/arrow/cpp/src/arrow/array/builder_binary.h>
7+ #include < contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.h>
8+ #include < contrib/libs/apache/arrow/cpp/src/arrow/io/memory.h>
9+ #include < contrib/libs/apache/arrow/cpp/src/arrow/ipc/dictionary.h>
10+ #include < contrib/libs/apache/arrow/cpp/src/arrow/ipc/reader.h>
11+ #include < contrib/libs/apache/arrow/cpp/src/arrow/ipc/writer.h>
12+ #include < contrib/libs/apache/arrow/cpp/src/arrow/record_batch.h>
13+ #include < contrib/libs/apache/arrow/cpp/src/arrow/type.h>
14+ #include < contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h>
15+
316namespace NYdbWorkload {
417
18+ namespace {
19+
20+ class TTransformingDataGenerator final : public IBulkDataGenerator {
21+ private:
22+ std::shared_ptr<IBulkDataGenerator> InnerDataGenerator;
23+ const TString EmbeddingSourceField;
24+
25+ private:
26+ static std::pair<std::shared_ptr<arrow::Schema>, std::shared_ptr<arrow::RecordBatch>> Deserialize (TDataPortion::TArrow* data) {
27+ arrow::ipc::DictionaryMemo dictionary;
28+
29+ arrow::io::BufferReader schemaBuffer (arrow::util::string_view (data->Schema .data (), data->Schema .size ()));
30+ const std::shared_ptr<arrow::Schema> schema = arrow::ipc::ReadSchema (&schemaBuffer, &dictionary).ValueOrDie ();
31+
32+ arrow::io::BufferReader recordBatchBuffer (arrow::util::string_view (data->Data .data (), data->Data .size ()));
33+ const std::shared_ptr<arrow::RecordBatch> recordBatch = arrow::ipc::ReadRecordBatch (schema, &dictionary, {}, &recordBatchBuffer).ValueOrDie ();
34+
35+ return std::make_pair (schema, recordBatch);
36+ }
37+
38+ void ConvertArrow (TDataPortion::TArrow* data) {
39+ const auto [schema, batch] = Deserialize (data);
40+
41+ // extract
42+ const auto idColumn = batch->GetColumnByName (" id" );
43+ const auto embeddingColumn = batch->GetColumnByName (EmbeddingSourceField);
44+ const auto embeddingListColumn = dynamic_cast <arrow::ListArray*>(embeddingColumn.get ());
45+
46+ // conversion
47+ const auto newIdColumn = arrow::compute::Cast (idColumn, arrow::uint64 ()).ValueOrDie ().make_array ();
48+
49+ arrow::StringBuilder builder;
50+ for (int64_t row = 0 ; row < batch->num_rows (); ++row) {
51+ auto embeddingAsFloats = static_cast <const arrow::FloatArray*>(embeddingListColumn->value_slice (row).get ());
52+
53+ std::string serialized (embeddingAsFloats->length () * sizeof (float ) + 1 , ' \0 ' );
54+ float * float_bytes = reinterpret_cast <float *>(serialized.data ());
55+ for (int64_t i = 0 ; i < embeddingAsFloats->length (); ++i) {
56+ float_bytes[i] = embeddingAsFloats->Value (i);
57+ }
58+ serialized.back () = ' \x01 ' ;
59+
60+ if (const auto status = builder.Append (serialized); !status.ok ()) {
61+ status.Abort ();
62+ }
63+ }
64+
65+ std::shared_ptr<arrow::StringArray> newEmbeddingColumn;
66+ if (const auto status = builder.Finish (&newEmbeddingColumn); !status.ok ()) {
67+ status.Abort ();
68+ }
69+
70+ // serialize
71+ const auto newSchema = arrow::schema ({
72+ arrow::field (" id" , arrow::uint64 ()),
73+ arrow::field (" embedding" , arrow::utf8 ()),
74+ });
75+ const auto newRecordBatch = arrow::RecordBatch::Make (
76+ newSchema,
77+ batch->num_rows (),
78+ {
79+ newIdColumn,
80+ newEmbeddingColumn,
81+ }
82+ );
83+ data->Schema = arrow::ipc::SerializeSchema (*newSchema).ValueOrDie ()->ToString ();
84+ data->Data = arrow::ipc::SerializeRecordBatch (*newRecordBatch, arrow::ipc::IpcWriteOptions{}).ValueOrDie ()->ToString ();
85+ }
86+
87+ void Convert (TDataPortion::TDataType& data) {
88+ if (auto * value = std::get_if<TDataPortion::TArrow>(&data)) {
89+ ConvertArrow (value);
90+ }
91+ }
92+
93+ public:
94+ TTransformingDataGenerator (std::shared_ptr<IBulkDataGenerator> innerDataGenerator, const TString embeddingSourceField)
95+ : IBulkDataGenerator(innerDataGenerator->GetName (), innerDataGenerator->GetSize())
96+ , InnerDataGenerator(innerDataGenerator)
97+ , EmbeddingSourceField(embeddingSourceField)
98+ {}
99+
100+ virtual TDataPortions GenerateDataPortion () override {
101+ TDataPortions portions = InnerDataGenerator->GenerateDataPortion ();
102+ for (auto portion : portions) {
103+ Convert (portion->MutableData ());
104+ }
105+ return portions;
106+ }
107+ };
108+
109+ }
110+
5111TWorkloadVectorFilesDataInitializer::TWorkloadVectorFilesDataInitializer (const TVectorWorkloadParams& params)
6112 : TWorkloadDataInitializerBase(" files" , " Import vectors from files" , params)
7113 , Params(params)
@@ -12,14 +118,33 @@ void TWorkloadVectorFilesDataInitializer::ConfigureOpts(NLastGetopt::TOpts& opts
12118 " File or Directory with dataset. If directory is set, all its available files will be used. "
13119 " Supports zipped and unzipped csv, tsv files and parquet ones that may be downloaded here: "
14120 " https://huggingface.co/datasets/Cohere/wikipedia-22-12-simple-embeddings. "
15- " For better perfomanse you may split it to some parts for parrallel upload."
121+ " For better performance you may split it into some parts for parallel upload."
16122 ).Required ().StoreResult (&DataFiles);
123+ opts.AddLongOption (' t' , " transform" ,
124+ " Perform transformation of input data. "
125+ " Parquet: leave only required fields, cast to expected types, convert list of floats into serialized ydb representation."
126+ ).Optional ().StoreTrue (&DoTransform);
127+ opts.AddLongOption (
128+ " transform-embedding-source-field" ,
129+ " Specify field that contains list of floats to be converted into YDB embedding format."
130+ ).DefaultValue (EmbeddingSourceField).StoreResult (&EmbeddingSourceField);
17131}
18132
19133TBulkDataGeneratorList TWorkloadVectorFilesDataInitializer::DoGetBulkInitialData () {
20- return {
21- std::make_shared<TDataGenerator>(*this , Params.TableName , 0 , Params.TableName , DataFiles, Params.GetColumns (), TDataGenerator::EPortionSizeUnit::Line)
22- };
134+ auto dataGenerator = std::make_shared<TDataGenerator>(
135+ *this ,
136+ Params.TableName ,
137+ 0 ,
138+ Params.TableName ,
139+ DataFiles,
140+ Params.GetColumns (),
141+ TDataGenerator::EPortionSizeUnit::Line
142+ );
143+
144+ if (DoTransform) {
145+ return {std::make_shared<TTransformingDataGenerator>(dataGenerator, EmbeddingSourceField)};
146+ }
147+ return {dataGenerator};
23148}
24149
25150} // namespace NYdbWorkload
0 commit comments