Skip to content

Commit 76cf590

Browse files
committed
Add csv conversion
1 parent 0f65390 commit 76cf590

File tree

1 file changed

+99
-21
lines changed

1 file changed

+99
-21
lines changed

ydb/library/workload/vector/vector_data_generator.cpp

Lines changed: 99 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
#include "vector_data_generator.h"
22

3+
#include <ydb/library/yql/udfs/common/knn/knn-serializer-shared.h>
4+
35
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_binary.h>
46
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_nested.h>
57
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_primitive.h>
68
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_binary.h>
9+
#include <contrib/libs/apache/arrow/cpp/src/arrow/chunked_array.h>
710
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.h>
11+
#include <contrib/libs/apache/arrow/cpp/src/arrow/csv/api.h>
12+
#include <contrib/libs/apache/arrow/cpp/src/arrow/csv/reader.h>
13+
#include <contrib/libs/apache/arrow/cpp/src/arrow/csv/writer.h>
814
#include <contrib/libs/apache/arrow/cpp/src/arrow/io/memory.h>
915
#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/dictionary.h>
1016
#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/reader.h>
1117
#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/writer.h>
1218
#include <contrib/libs/apache/arrow/cpp/src/arrow/record_batch.h>
19+
#include <contrib/libs/apache/arrow/cpp/src/arrow/table.h>
1320
#include <contrib/libs/apache/arrow/cpp/src/arrow/type.h>
1421
#include <contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h>
1522

23+
#include <util/stream/mem.h>
24+
1625
namespace NYdbWorkload {
1726

1827
namespace {
@@ -35,39 +44,48 @@ class TTransformingDataGenerator final: public IBulkDataGenerator {
3544
return std::make_pair(schema, recordBatch);
3645
}
3746

38-
void ConvertArrow(TDataPortion::TArrow* data) {
47+
// Parses the CSV text held in data->Data into an Arrow table.
// The reader is built over the portion's bytes with the default
// read / parse / convert options. Failure to create the reader or to read
// the table aborts via ValueOrDie().
static std::shared_ptr<arrow::Table> Deserialize(TDataPortion::TCsv* data) {
    const arrow::util::string_view csvText(data->Data.data(), data->Data.size());
    auto csvInput = std::make_shared<arrow::io::BufferReader>(csvText);

    const auto readOptions = arrow::csv::ReadOptions::Defaults();
    const auto parseOptions = arrow::csv::ParseOptions::Defaults();
    const auto convertOptions = arrow::csv::ConvertOptions::Defaults();

    auto tableReader = arrow::csv::TableReader::Make(
        arrow::io::default_io_context(),
        csvInput,
        readOptions,
        parseOptions,
        convertOptions
    ).ValueOrDie();

    auto table = tableReader->Read().ValueOrDie();
    return table;
}
59+
60+
void TransformArrow(TDataPortion::TArrow* data) {
3961
const auto [schema, batch] = Deserialize(data);
4062

41-
// extract
63+
// id
4264
const auto idColumn = batch->GetColumnByName("id");
43-
const auto embeddingColumn = batch->GetColumnByName(EmbeddingSourceField);
44-
const auto embeddingListColumn = dynamic_cast<arrow::ListArray*>(embeddingColumn.get());
45-
46-
// conversion
4765
const auto newIdColumn = arrow::compute::Cast(idColumn, arrow::uint64()).ValueOrDie().make_array();
4866

49-
arrow::StringBuilder builder;
67+
// embedding
68+
const auto embeddingColumn = std::dynamic_pointer_cast<arrow::ListArray>(batch->GetColumnByName(EmbeddingSourceField));
69+
arrow::StringBuilder newEmbeddingsBuilder;
5070
for (int64_t row = 0; row < batch->num_rows(); ++row) {
51-
auto embeddingAsFloats = static_cast<const arrow::FloatArray*>(embeddingListColumn->value_slice(row).get());
71+
const auto embeddingFloatList = std::static_pointer_cast<arrow::FloatArray>(embeddingColumn->value_slice(row));
5272

53-
std::string serialized(embeddingAsFloats->length() * sizeof(float) + 1, '\0');
54-
float* float_bytes = reinterpret_cast<float*>(serialized.data());
55-
for (int64_t i = 0; i < embeddingAsFloats->length(); ++i) {
56-
float_bytes[i] = embeddingAsFloats->Value(i);
73+
TStringBuilder buffer;
74+
NKnnVectorSerialization::TSerializer<float> serializer(&buffer.Out);
75+
for (int64_t i = 0; i < embeddingFloatList->length(); ++i) {
76+
serializer.HandleElement(embeddingFloatList->Value(i));
5777
}
58-
serialized.back() = '\x01';
78+
serializer.Finish();
5979

60-
if (const auto status = builder.Append(serialized); !status.ok()) {
80+
if (const auto status = newEmbeddingsBuilder.Append(buffer.MutRef()); !status.ok()) {
6181
status.Abort();
6282
}
6383
}
64-
6584
std::shared_ptr<arrow::StringArray> newEmbeddingColumn;
66-
if (const auto status = builder.Finish(&newEmbeddingColumn); !status.ok()) {
85+
if (const auto status = newEmbeddingsBuilder.Finish(&newEmbeddingColumn); !status.ok()) {
6786
status.Abort();
6887
}
6988

70-
// serialize
7189
const auto newSchema = arrow::schema({
7290
arrow::field("id", arrow::uint64()),
7391
arrow::field("embedding", arrow::utf8()),
@@ -84,9 +102,67 @@ class TTransformingDataGenerator final: public IBulkDataGenerator {
84102
data->Data = arrow::ipc::SerializeRecordBatch(*newRecordBatch, arrow::ipc::IpcWriteOptions{}).ValueOrDie()->ToString();
85103
}
86104

87-
void Convert(TDataPortion::TDataType& data) {
105+
// Rewrites a CSV data portion in place.
//
// Keeps only the "id" column and the column named by EmbeddingSourceField.
// Each embedding cell is read as text: an optional leading "[" and trailing "]"
// are stripped, then floats are read one by one with a single separator
// character skipped after each. The floats are re-encoded with
// NKnnVectorSerialization::TSerializer<float>. The resulting two-column table
// is written back to data->Data as CSV and data->FormatString is reset to "".
//
// Any Arrow error aborts the process (ValueOrDie / Status::Abort).
void TransformCsv(TDataPortion::TCsv* data) {
    const auto table = Deserialize(data);

    // id: the CSV reader picks the column type itself, so cast explicitly to the
    // uint64 type declared in the output schema below (TransformArrow does the same).
    const auto idColumn = arrow::compute::Cast(table->GetColumnByName("id"), arrow::uint64()).ValueOrDie().chunked_array();

    // embedding
    const auto embeddingColumn = table->GetColumnByName(EmbeddingSourceField);
    arrow::StringBuilder newEmbeddingsBuilder;
    // Walk the chunks directly instead of slicing the chunked array once per row.
    for (const auto& chunk : embeddingColumn->chunks()) {
        const auto embeddingStrings = std::static_pointer_cast<arrow::StringArray>(chunk);
        for (int64_t row = 0; row < embeddingStrings->length(); ++row) {
            const auto embeddingListString = embeddingStrings->Value(row);

            TStringBuf buffer(embeddingListString.data(), embeddingListString.size());
            buffer.SkipPrefix("[");
            buffer.ChopSuffix("]");
            TMemoryInput input(buffer);

            TStringBuilder newEmbeddingBuilder;
            NKnnVectorSerialization::TSerializer<float> serializer(&newEmbeddingBuilder.Out);
            while (!input.Exhausted()) {
                float val = 0;
                input >> val;
                input.Skip(1); // step over the one-character separator after the number
                serializer.HandleElement(val);
            }
            serializer.Finish();

            if (const auto status = newEmbeddingsBuilder.Append(newEmbeddingBuilder.MutRef()); !status.ok()) {
                status.Abort();
            }
        }
    }
    std::shared_ptr<arrow::StringArray> newEmbeddingColumn;
    if (const auto status = newEmbeddingsBuilder.Finish(&newEmbeddingColumn); !status.ok()) {
        status.Abort();
    }

    const auto newSchema = arrow::schema({
        arrow::field("id", arrow::uint64()),
        arrow::field("embedding", arrow::utf8()),
    });
    const auto newTable = arrow::Table::Make(
        newSchema,
        {
            idColumn,
            arrow::ChunkedArray::Make({newEmbeddingColumn}).ValueOrDie(),
        }
    );
    auto outputStream = arrow::io::BufferOutputStream::Create().ValueOrDie();
    if (const auto status = arrow::csv::WriteCSV(*newTable, arrow::csv::WriteOptions::Defaults(), outputStream.get()); !status.ok()) {
        status.Abort();
    }
    data->FormatString = "";
    data->Data = outputStream->Finish().ValueOrDie()->ToString();
}
159+
160+
// Dispatches on the alternative currently held by the variant:
// Arrow payloads go to TransformArrow, CSV payloads to TransformCsv.
// Any other alternative is left untouched.
void Transform(TDataPortion::TDataType& data) {
    if (auto* arrowPayload = std::get_if<TDataPortion::TArrow>(&data)) {
        TransformArrow(arrowPayload);
        return;
    }

    if (auto* csvPayload = std::get_if<TDataPortion::TCsv>(&data)) {
        TransformCsv(csvPayload);
    }
}
92168

@@ -100,7 +176,7 @@ class TTransformingDataGenerator final: public IBulkDataGenerator {
100176
// Pulls the next batch of portions from the wrapped generator and runs
// Transform on each portion's data in place before returning the batch.
virtual TDataPortions GenerateDataPortion() override {
    TDataPortions portions = InnerDataGenerator->GenerateDataPortion();
    // Bind by reference: copying the portion handle per iteration is unnecessary.
    for (const auto& portion : portions) {
        Transform(portion->MutableData());
    }
    return portions;
}
@@ -122,7 +198,9 @@ void TWorkloadVectorFilesDataInitializer::ConfigureOpts(NLastGetopt::TOpts& opts
122198
).Required().StoreResult(&DataFiles);
123199
opts.AddLongOption('t', "transform",
124200
"Perform transformation of input data. "
125-
"Parquet: leave only required fields, cast to expected types, convert list of floats into serialized ydb representation."
201+
"Parquet: leave only required fields, cast to expected types, convert list of floats into serialized representation. "
202+
"CSV: leave only required fields, parse float list from string and serialize. "
203+
"Reference for embedding serialization: https://ydb.tech/docs/yql/reference/udf/list/knn#functions-convert"
126204
).Optional().StoreTrue(&DoTransform);
127205
opts.AddLongOption(
128206
"transform-embedding-source-field",

0 commit comments

Comments
 (0)