Skip to content

Commit 03ab100

Browse files
committed
Fixed minor bugs
1 parent e16ab0c commit 03ab100

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/tf_inputs/core.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ def from_file_paths(cls, file_paths, parse_fn=None, flatten=False, **kwargs):
9595

9696
def map_fn(filename):
9797
file_contents = tf.io.read_file(filename)
98-
if read_fn is not None:
99-
file_contents = read_fn(file_contents)
98+
if parse_fn is not None:
99+
file_contents = parse_fn(file_contents)
100100
return file_contents
101101

102102
def dataset_fn():
@@ -108,6 +108,7 @@ def dataset_fn():
108108
)
109109
if flatten:
110110
dataset = dataset.flat_map(tf.data.Dataset.from_tensor_slices)
111+
return dataset
111112

112113
return cls.from_dataset_fn(dataset_fn)
113114

@@ -120,7 +121,7 @@ def from_dataset(cls, dataset, **kwargs):
120121
@classmethod
121122
def from_dataset_fn(cls, dataset_fn, **kwargs):
122123
self = cls(**kwargs)
123-
self.read_data = dataset_fn.__get__(self)
124+
self.read_data = dataset_fn
124125
return self
125126

126127
@property

0 commit comments

Comments
 (0)