Working With Files
Files are one of the most fundamental entities that users of Python work with, and they are fully supported by Flyte. In the IDL, they are known as Blob literals which are backed by the blob type.
Let’s assume our mission here is pretty simple. We download a few csv files from their links, read them with the Python built-in csv.DictReader class, normalize some pre-specified columns, and output the normalized columns to another csv file.
First, let’s import the libraries.
import csv
import os
from collections import defaultdict
from typing import List
import flytekit
from flytekit import task, workflow
from flytekit.types.file import FlyteFile
Next, we write a task that accepts a FlyteFile, a list of column names, and a list of column names to normalize, then outputs a csv file of only the normalized columns. For this example we’ll use z-score normalization, i.e. mean-centering and standard-deviation-scaling.
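As a quick illustration of z-score normalization on its own, here is a minimal sketch in plain Python. The helper name z_score and the zero-spread guard are assumptions made for this illustration; the task in this example performs the same arithmetic inline and has no such guard.

```python
from typing import List


def z_score(values: List[float]) -> List[float]:
    """Mean-center the values and scale by the population standard deviation."""
    mean = sum(values) / len(values)
    # Population standard deviation: divide by N, not N - 1.
    std = (sum((x - mean) ** 2 for x in values) / len(values)) ** 0.5
    if std == 0:
        # Hypothetical guard: a constant column has no spread to scale by.
        return [0.0 for _ in values]
    return [(x - mean) / std for x in values]


normalized = z_score([2.0, 4.0, 6.0])
print(normalized)
```

After normalization the values have a mean of zero and unit standard deviation, which is why the middle value above maps to 0.0 and the outer two are symmetric around it.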
Note
The FlyteFile literal can be scoped with a string, which gets inserted into the format of the Blob type (“jpeg” is the string in FlyteFile[typing.TypeVar("jpeg")]). The format is entirely optional, and if not specified, defaults to "".
@task
def normalize_columns(
    csv_url: FlyteFile,
    column_names: List[str],
    columns_to_normalize: List[str],
    output_location: str,
) -> FlyteFile:
    # read the data from the raw csv file
    parsed_data = defaultdict(list)
    # opening the FlyteFile triggers the download; the csv module expects newline=""
    with open(csv_url, newline="") as input_file:
        reader = csv.DictReader(input_file, fieldnames=column_names)
        # fieldnames are given explicitly, so row 0 is the header: skip it
        for row in (x for i, x in enumerate(reader) if i > 0):
            for column in columns_to_normalize:
                parsed_data[column].append(float(row[column].strip()))

    # normalize the data
    normalized_data = defaultdict(list)
    for colname, values in parsed_data.items():
        mean = sum(values) / len(values)
        std = (sum([(x - mean) ** 2 for x in values]) / len(values)) ** 0.5
        normalized_data[colname] = [(x - mean) / std for x in values]

    # write to local path
    out_path = os.path.join(
        flytekit.current_context().working_directory,
        f"normalized-{os.path.splitext(os.path.basename(csv_url.path))[0]}.csv",
    )
    with open(out_path, mode="w", newline="") as output_file:
        writer = csv.DictWriter(output_file, fieldnames=columns_to_normalize)
        writer.writeheader()
        for row in zip(*normalized_data.values()):
            writer.writerow({k: row[i] for i, k in enumerate(columns_to_normalize)})

    if output_location:
        return FlyteFile(path=out_path, remote_path=output_location)
    else:
        return FlyteFile(path=out_path)
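The header handling in the task above is easy to miss, so here is a small standalone sketch of it using only the standard library. The in-memory CSV text and the two column names are made up for illustration; they merely imitate the padded layout of the sample files.

```python
import csv
import io

# Hypothetical in-memory CSV standing in for a downloaded file.
raw = io.StringIO('"Name", "Age"\n"Alex",   41\n"Bert",   42\n')

# With explicit fieldnames, DictReader does not consume the header line,
# so the first row it yields is the header itself and must be skipped.
reader = csv.DictReader(raw, fieldnames=["Name", "Age"])
rows = [row for i, row in enumerate(reader) if i > 0]

# Values arrive as strings, possibly padded with spaces, hence strip() + float().
ages = [float(row["Age"].strip()) for row in rows]
print(ages)
```

If fieldnames were omitted, DictReader would take the keys from the first line of the file instead, and no row would need to be skipped.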
When the CSV URL is sent to the task, the Flytekit engine translates it into a FlyteFile object on the local drive (but doesn’t download it). Calling the download method triggers the download, as does opening the object directly with open, which is what the task above does. The path attribute holds the local path at which the file can be opened.
If the output_location argument is specified, it is passed to the remote_path argument of FlyteFile, which uses that path as the storage location instead of a random location in Flyte’s object store.
When this task finishes, the Flytekit engine returns the FlyteFile instance, uploads the file to the location, and creates a Blob literal pointing to it.
Lastly, we define a normalize_csv_file workflow. Note that the workflow takes an output_location argument, which is passed to the output_location input of the task. If it’s not an empty string, the task attempts to upload its file to that location.
@workflow
def normalize_csv_file(
    csv_url: FlyteFile,
    column_names: List[str],
    columns_to_normalize: List[str],
    output_location: str = "",
) -> FlyteFile:
    return normalize_columns(
        csv_url=csv_url,
        column_names=column_names,
        columns_to_normalize=columns_to_normalize,
        output_location=output_location,
    )
Finally, we can run the workflow locally.
if __name__ == "__main__":
    default_files = [
        (
            "https://people.sc.fsu.edu/~jburkardt/data/csv/biostats.csv",
            ["Name", "Sex", "Age", "Heights (in)", "Weight (lbs)"],
            ["Age"],
        ),
        (
            "https://people.sc.fsu.edu/~jburkardt/data/csv/faithful.csv",
            ["Index", "Eruption length (mins)", "Eruption wait (mins)"],
            ["Eruption length (mins)"],
        ),
    ]
    print(f"Running {__file__} main...")
    for csv_url, column_names, columns_to_normalize in default_files:
        normalized_columns = normalize_csv_file(
            csv_url=csv_url,
            column_names=column_names,
            columns_to_normalize=columns_to_normalize,
        )
        # report the single file that was just processed, not the whole list
        print(
            f"Running normalize_csv_file workflow on {csv_url}: "
            f"{normalized_columns}"
        )