Skip to content

Commit 0f65390

Browse files
committed
Add transformation logic for parquet files
1 parent 3cee930 commit 0f65390

File tree

3 files changed

+132
-4
lines changed

3 files changed

+132
-4
lines changed
Lines changed: 129 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,113 @@
11
#include "vector_data_generator.h"
22

3+
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_binary.h>
4+
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_nested.h>
5+
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_primitive.h>
6+
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_binary.h>
7+
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.h>
8+
#include <contrib/libs/apache/arrow/cpp/src/arrow/io/memory.h>
9+
#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/dictionary.h>
10+
#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/reader.h>
11+
#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/writer.h>
12+
#include <contrib/libs/apache/arrow/cpp/src/arrow/record_batch.h>
13+
#include <contrib/libs/apache/arrow/cpp/src/arrow/type.h>
14+
#include <contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h>
15+
316
namespace NYdbWorkload {
417

18+
namespace {
19+
20+
class TTransformingDataGenerator final: public IBulkDataGenerator {
21+
private:
22+
std::shared_ptr<IBulkDataGenerator> InnerDataGenerator;
23+
const TString EmbeddingSourceField;
24+
25+
private:
26+
static std::pair<std::shared_ptr<arrow::Schema>, std::shared_ptr<arrow::RecordBatch>> Deserialize(TDataPortion::TArrow* data) {
27+
arrow::ipc::DictionaryMemo dictionary;
28+
29+
arrow::io::BufferReader schemaBuffer(arrow::util::string_view(data->Schema.data(), data->Schema.size()));
30+
const std::shared_ptr<arrow::Schema> schema = arrow::ipc::ReadSchema(&schemaBuffer, &dictionary).ValueOrDie();
31+
32+
arrow::io::BufferReader recordBatchBuffer(arrow::util::string_view(data->Data.data(), data->Data.size()));
33+
const std::shared_ptr<arrow::RecordBatch> recordBatch = arrow::ipc::ReadRecordBatch(schema, &dictionary, {}, &recordBatchBuffer).ValueOrDie();
34+
35+
return std::make_pair(schema, recordBatch);
36+
}
37+
38+
void ConvertArrow(TDataPortion::TArrow* data) {
39+
const auto [schema, batch] = Deserialize(data);
40+
41+
// extract
42+
const auto idColumn = batch->GetColumnByName("id");
43+
const auto embeddingColumn = batch->GetColumnByName(EmbeddingSourceField);
44+
const auto embeddingListColumn = dynamic_cast<arrow::ListArray*>(embeddingColumn.get());
45+
46+
// conversion
47+
const auto newIdColumn = arrow::compute::Cast(idColumn, arrow::uint64()).ValueOrDie().make_array();
48+
49+
arrow::StringBuilder builder;
50+
for (int64_t row = 0; row < batch->num_rows(); ++row) {
51+
auto embeddingAsFloats = static_cast<const arrow::FloatArray*>(embeddingListColumn->value_slice(row).get());
52+
53+
std::string serialized(embeddingAsFloats->length() * sizeof(float) + 1, '\0');
54+
float* float_bytes = reinterpret_cast<float*>(serialized.data());
55+
for (int64_t i = 0; i < embeddingAsFloats->length(); ++i) {
56+
float_bytes[i] = embeddingAsFloats->Value(i);
57+
}
58+
serialized.back() = '\x01';
59+
60+
if (const auto status = builder.Append(serialized); !status.ok()) {
61+
status.Abort();
62+
}
63+
}
64+
65+
std::shared_ptr<arrow::StringArray> newEmbeddingColumn;
66+
if (const auto status = builder.Finish(&newEmbeddingColumn); !status.ok()) {
67+
status.Abort();
68+
}
69+
70+
// serialize
71+
const auto newSchema = arrow::schema({
72+
arrow::field("id", arrow::uint64()),
73+
arrow::field("embedding", arrow::utf8()),
74+
});
75+
const auto newRecordBatch = arrow::RecordBatch::Make(
76+
newSchema,
77+
batch->num_rows(),
78+
{
79+
newIdColumn,
80+
newEmbeddingColumn,
81+
}
82+
);
83+
data->Schema = arrow::ipc::SerializeSchema(*newSchema).ValueOrDie()->ToString();
84+
data->Data = arrow::ipc::SerializeRecordBatch(*newRecordBatch, arrow::ipc::IpcWriteOptions{}).ValueOrDie()->ToString();
85+
}
86+
87+
void Convert(TDataPortion::TDataType& data) {
88+
if (auto* value = std::get_if<TDataPortion::TArrow>(&data)) {
89+
ConvertArrow(value);
90+
}
91+
}
92+
93+
public:
94+
TTransformingDataGenerator(std::shared_ptr<IBulkDataGenerator> innerDataGenerator, const TString embeddingSourceField)
95+
: IBulkDataGenerator(innerDataGenerator->GetName(), innerDataGenerator->GetSize())
96+
, InnerDataGenerator(innerDataGenerator)
97+
, EmbeddingSourceField(embeddingSourceField)
98+
{}
99+
100+
virtual TDataPortions GenerateDataPortion() override {
101+
TDataPortions portions = InnerDataGenerator->GenerateDataPortion();
102+
for (auto portion : portions) {
103+
Convert(portion->MutableData());
104+
}
105+
return portions;
106+
}
107+
};
108+
109+
}
110+
5111
TWorkloadVectorFilesDataInitializer::TWorkloadVectorFilesDataInitializer(const TVectorWorkloadParams& params)
6112
: TWorkloadDataInitializerBase("files", "Import vectors from files", params)
7113
, Params(params)
@@ -12,14 +118,33 @@ void TWorkloadVectorFilesDataInitializer::ConfigureOpts(NLastGetopt::TOpts& opts
12118
"File or Directory with dataset. If directory is set, all its available files will be used. "
13119
"Supports zipped and unzipped csv, tsv files and parquet ones that may be downloaded here: "
14120
"https://huggingface.co/datasets/Cohere/wikipedia-22-12-simple-embeddings. "
15-
"For better perfomanse you may split it to some parts for parrallel upload."
121+
"For better performance you may split it into some parts for parallel upload."
16122
).Required().StoreResult(&DataFiles);
123+
opts.AddLongOption('t', "transform",
124+
"Perform transformation of input data. "
125+
"Parquet: leave only required fields, cast to expected types, convert list of floats into serialized ydb representation."
126+
).Optional().StoreTrue(&DoTransform);
127+
opts.AddLongOption(
128+
"transform-embedding-source-field",
129+
"Specify field that contains list of floats to be converted into YDB embedding format."
130+
).DefaultValue(EmbeddingSourceField).StoreResult(&EmbeddingSourceField);
17131
}
18132

19133
TBulkDataGeneratorList TWorkloadVectorFilesDataInitializer::DoGetBulkInitialData() {
20-
return {
21-
std::make_shared<TDataGenerator>(*this, Params.TableName, 0, Params.TableName, DataFiles, Params.GetColumns(), TDataGenerator::EPortionSizeUnit::Line)
22-
};
134+
auto dataGenerator = std::make_shared<TDataGenerator>(
135+
*this,
136+
Params.TableName,
137+
0,
138+
Params.TableName,
139+
DataFiles,
140+
Params.GetColumns(),
141+
TDataGenerator::EPortionSizeUnit::Line
142+
);
143+
144+
if (DoTransform) {
145+
return {std::make_shared<TTransformingDataGenerator>(dataGenerator, EmbeddingSourceField)};
146+
}
147+
return {dataGenerator};
23148
}
24149

25150
} // namespace NYdbWorkload

ydb/library/workload/vector/vector_data_generator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ class TWorkloadVectorFilesDataInitializer : public TWorkloadDataInitializerBase
1111
private:
1212
const TVectorWorkloadParams& Params;
1313
TString DataFiles;
14+
bool DoTransform = false;
15+
TString EmbeddingSourceField = "embedding";
1416

1517
public:
1618
TWorkloadVectorFilesDataInitializer(const TVectorWorkloadParams& params);

ydb/library/workload/vector/ya.make

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ SRCS(
1212
)
1313

1414
PEERDIR(
15+
contrib/libs/apache/arrow
1516
ydb/library/workload/abstract
1617
)
1718

0 commit comments

Comments
 (0)