1 | | -import ast |
2 | | -from dataclasses import dataclass |
3 | | -from typing import TYPE_CHECKING, Optional, Sequence |
| 1 | +from typing import Optional |
4 | 2 |
5 | | -from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition |
6 | | -from pyspark.sql.pandas.types import from_arrow_schema |
| 3 | +from pyspark.sql.datasource import DataSource, DataSourceArrowWriter, DataSourceReader |
7 | 4 | from pyspark.sql.types import StructType |
8 | 5 |
9 | | -if TYPE_CHECKING: |
10 | | - from datasets import DatasetBuilder, IterableDataset |
| 6 | +from pyspark_huggingface.huggingface_sink import HuggingFaceSink |
| 7 | +from pyspark_huggingface.huggingface_source import HuggingFaceSource |
| 8 | + |
11 | 9 |
12 | 10 | class HuggingFaceDatasets(DataSource): |
13 | 11 | """ |
14 | | - A DataSource for reading and writing HuggingFace Datasets in Spark. |
15 | | -
16 | | - This data source allows reading public datasets from the HuggingFace Hub directly into Spark |
17 | | - DataFrames. The schema is automatically inferred from the dataset features. The split can be |
18 | | - specified using the `split` option. The default split is `train`. |
19 | | -
20 | | - Name: `huggingface` |
21 | | -
22 | | - Data Source Options: |
23 | | - - split (str): Specify which split to retrieve. Default: train |
24 | | - - config (str): Specify which subset or configuration to retrieve. |
25 | | - - streaming (bool): Specify whether to read a dataset without downloading it. |
26 | | -
|
27 | | - Notes: |
28 | | - ----- |
29 | | - - Currently it can only be used with public datasets. Private or gated ones are not supported. |
30 | | -
31 | | - Examples |
32 | | - -------- |
33 | | -
34 | | - Load a public dataset from the HuggingFace Hub. |
35 | | -
36 | | - >>> df = spark.read.format("huggingface").load("imdb") |
37 | | - DataFrame[text: string, label: bigint] |
38 | | -
39 | | - >>> df.show() |
40 | | - +--------------------+-----+ |
41 | | - | text|label| |
42 | | - +--------------------+-----+ |
43 | | - |I rented I AM CUR...| 0| |
44 | | - |"I Am Curious: Ye...| 0| |
45 | | - |... | ...| |
46 | | - +--------------------+-----+ |
47 | | -
48 | | - Load a specific split from a public dataset from the HuggingFace Hub. |
| 12 | + DataSource for reading and writing HuggingFace Datasets in Spark. |
49 | 13 |
50 | | - >>> spark.read.format("huggingface").option("split", "test").load("imdb").show() |
51 | | - +--------------------+-----+ |
52 | | - | text|label| |
53 | | - +--------------------+-----+ |
54 | | - |I love sci-fi and...| 0| |
55 | | - |Worth the enterta...| 0| |
56 | | - |... | ...| |
57 | | - +--------------------+-----+ |
| 14 | + Read |
 | 15 | + ----
| 16 | + See :py:class:`HuggingFaceSource` for more details. |
58 | 17 |
59 | | - Enable predicate pushdown for Parquet datasets. |
60 | | -
|
61 | | - >>> spark.read.format("huggingface") \ |
62 | | - ... .option("filters", '[("language_score", ">", 0.99)]') \ |
63 | | - ... .option("columns", '["text", "language_score"]') \ |
64 | | - ... .load("HuggingFaceFW/fineweb-edu") \ |
65 | | - ... .show() |
66 | | - +--------------------+------------------+ |
67 | | - | text| language_score| |
68 | | - +--------------------+------------------+ |
69 | | - |died Aug. 28, 181...|0.9901925325393677| |
70 | | - |Coyotes spend a g...|0.9902171492576599| |
71 | | - |... | ...| |
72 | | - +--------------------+------------------+ |
| 18 | + Write |
 | 19 | + -----
| 20 | + See :py:class:`HuggingFaceSink` for more details. |
73 | 21 | """ |
74 | 22 |
75 | | - DEFAULT_SPLIT: str = "train" |
76 | | - |
77 | | - def __init__(self, options): |
| 23 | + # Delegate the source and sink methods to the respective classes. |
| 24 | + def __init__(self, options: dict): |
78 | 25 | super().__init__(options) |
79 | | - from datasets import load_dataset_builder |
80 | | - |
81 | | - if "path" not in options or not options["path"]: |
82 | | - raise Exception("You must specify a dataset name.") |
| 26 | + self.options = options |
| 27 | + self.source: Optional[HuggingFaceSource] = None |
| 28 | + self.sink: Optional[HuggingFaceSink] = None |
83 | 29 |
84 | | - kwargs = dict(self.options) |
85 | | - self.dataset_name = kwargs.pop("path") |
86 | | - self.config_name = kwargs.pop("config", None) |
87 | | - self.split = kwargs.pop("split", self.DEFAULT_SPLIT) |
88 | | - self.streaming = kwargs.pop("streaming", "true").lower() == "true" |
89 | | - for arg in kwargs: |
90 | | - if kwargs[arg].lower() == "true": |
91 | | - kwargs[arg] = True |
92 | | - elif kwargs[arg].lower() == "false": |
93 | | - kwargs[arg] = False |
94 | | - else: |
95 | | - try: |
96 | | - kwargs[arg] = ast.literal_eval(kwargs[arg]) |
97 | | - except ValueError: |
98 | | - pass |
| 30 | + def get_source(self) -> HuggingFaceSource: |
| 31 | + if self.source is None: |
| 32 | + self.source = HuggingFaceSource(self.options.copy()) |
| 33 | + return self.source |
99 | 34 |
100 | | - self.builder = load_dataset_builder(self.dataset_name, self.config_name, **kwargs) |
101 | | - streaming_dataset = self.builder.as_streaming_dataset() |
102 | | - if self.split not in streaming_dataset: |
103 | | - raise Exception(f"Split {self.split} is invalid. Valid options are {list(streaming_dataset)}") |
104 | | - |
105 | | - self.streaming_dataset = streaming_dataset[self.split] |
106 | | - if not self.streaming_dataset.features: |
107 | | - self.streaming_dataset = self.streaming_dataset._resolve_features() |
 | 35 | + def get_sink(self) -> HuggingFaceSink:
| 36 | + if self.sink is None: |
| 37 | + self.sink = HuggingFaceSink(self.options.copy()) |
| 38 | + return self.sink |
108 | 39 |
109 | 40 | @classmethod |
110 | 41 | def name(cls): |
111 | 42 | return "huggingface" |
112 | 43 |
113 | 44 | def schema(self): |
114 | | - return from_arrow_schema(self.streaming_dataset.features.arrow_schema) |
| 45 | + return self.get_source().schema() |
115 | 46 |
116 | 47 | def reader(self, schema: StructType) -> "DataSourceReader": |
117 | | - return HuggingFaceDatasetsReader( |
118 | | - schema, |
119 | | - builder=self.builder, |
120 | | - split=self.split, |
121 | | - streaming_dataset=self.streaming_dataset if self.streaming else None |
122 | | - ) |
123 | | - |
124 | | - |
125 | | -@dataclass |
126 | | -class Shard(InputPartition): |
127 | | - """ Represents a dataset shard. """ |
128 | | - index: int |
129 | | - |
130 | | - |
131 | | -class HuggingFaceDatasetsReader(DataSourceReader): |
132 | | - |
133 | | - def __init__(self, schema: StructType, builder: "DatasetBuilder", split: str, streaming_dataset: Optional["IterableDataset"]): |
134 | | - self.schema = schema |
135 | | - self.builder = builder |
136 | | - self.split = split |
137 | | - self.streaming_dataset = streaming_dataset |
138 | | - # Get and validate the split name |
139 | | - |
140 | | - def partitions(self) -> Sequence[InputPartition]: |
141 | | - if self.streaming_dataset: |
142 | | - return [Shard(index=i) for i in range(self.streaming_dataset.num_shards)] |
143 | | - else: |
144 | | - return [Shard(index=0)] |
| 48 | + return self.get_source().reader(schema) |
145 | 49 |
146 | | - def read(self, partition: Shard): |
147 | | - columns = [field.name for field in self.schema.fields] |
148 | | - if self.streaming_dataset: |
149 | | - shard = self.streaming_dataset.shard(num_shards=self.streaming_dataset.num_shards, index=partition.index) |
150 | | - if shard._ex_iterable.iter_arrow: |
151 | | - for _, pa_table in shard._ex_iterable.iter_arrow(): |
152 | | - yield from pa_table.select(columns).to_batches() |
153 | | - else: |
154 | | - for _, example in shard: |
155 | | - yield example |
156 | | - else: |
157 | | - self.builder.download_and_prepare() |
158 | | - dataset = self.builder.as_dataset(self.split) |
159 | | - # Get the underlying arrow table of the dataset |
160 | | - table = dataset._data |
161 | | - yield from table.select(columns).to_batches() |
| 50 | + def writer(self, schema: StructType, overwrite: bool) -> "DataSourceArrowWriter": |
| 51 | + return self.get_sink().writer(schema, overwrite) |
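
For reference, a minimal usage sketch of the refactored, delegating data source (not part of the diff). It assumes PySpark 4.0's Python Data Source API, that `HuggingFaceDatasets` is importable from the `pyspark_huggingface` package (exact module path is an assumption), and a hypothetical target repo id:

from pyspark.sql import SparkSession

# Module path of HuggingFaceDatasets is an assumption for illustration.
from pyspark_huggingface.huggingface import HuggingFaceDatasets

spark = SparkSession.builder.getOrCreate()

# Register the data source under its short name, "huggingface".
spark.dataSource.register(HuggingFaceDatasets)

# Reads delegate to HuggingFaceSource through schema()/reader().
df = spark.read.format("huggingface").load("imdb")

# Writes delegate to HuggingFaceSink through writer(schema, overwrite);
# mode("overwrite") makes Spark pass overwrite=True.
df.write.format("huggingface").mode("overwrite").save("user/my-dataset")  # hypothetical repo id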