Testing agents in a local Python environment
You can test agents locally without running the backend server.
To test an agent locally, create a class for the agent task that inherits from SyncAgentExecutorMixin or AsyncAgentExecutorMixin.
These mixins can handle synchronous and asynchronous tasks, respectively, and allow flytekit to mimic FlytePropeller’s behavior in calling the agent.
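The following toy sketch illustrates the general mixin pattern described above: a mixin listed first in the class bases overrides the method that would otherwise run the task in-process and hands the work to an agent instead. Every class and method name here (LocalAgentExecutorMixin, BaseTask, EchoAgent, get_agent) is hypothetical and uses only the standard library. It is not flytekit's implementation, only a picture of the idea.

```python
class LocalAgentExecutorMixin:
    """Hypothetical stand-in for an executor mixin (illustration only)."""

    def execute(self, **kwargs):
        # Instead of running the task body, hand the inputs to an agent.
        agent = self.get_agent()
        return agent.run(self.name, kwargs)


class BaseTask:
    """Hypothetical base task whose default execute() runs in-process."""

    def __init__(self, name):
        self.name = name

    def execute(self, **kwargs):
        return f"{self.name} ran in-process with {kwargs}"


class EchoAgent:
    """Hypothetical agent that reports which task it was asked to run."""

    def run(self, task_name, inputs):
        return f"agent ran {task_name} with {inputs}"


class AgentBackedTask(LocalAgentExecutorMixin, BaseTask):
    # The mixin comes first in the bases, so its execute() wins in the
    # method resolution order and replaces BaseTask.execute().
    def get_agent(self):
        return EchoAgent()


plain = BaseTask("plain")
routed = AgentBackedTask("routed")
print(plain.execute(version=10))
print(routed.execute(version=10))
```

The order of the base classes matters in this sketch: if the mixin were listed after BaseTask, the in-process execute() would be used instead.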
BigQuery example
To test the BigQuery agent, copy the following code to a file called bigquery_task.py, modifying as needed.
Note
In some cases, you will need to store credentials in your local environment when testing locally.
For example, you need to set the GOOGLE_APPLICATION_CREDENTIALS environment variable when running BigQuery tasks to test the BigQuery agent.
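Before running a local test, it can help to confirm that the credentials variable is actually usable. The sketch below is a standard-library check, not part of flytekit. The helper name missing_credentials is hypothetical, and it assumes the variable holds a path to a credentials file, which is how GOOGLE_APPLICATION_CREDENTIALS is normally used.

```python
import os
from pathlib import Path


def missing_credentials(var_name="GOOGLE_APPLICATION_CREDENTIALS", environ=None):
    """Return a description of the problem, or None if the variable looks usable."""
    environ = os.environ if environ is None else environ
    value = environ.get(var_name)
    if not value:
        return f"{var_name} is not set"
    # Assumption: the variable is expected to point at an existing key file.
    if not Path(value).is_file():
        return f"{var_name} points to a missing file: {value}"
    return None


problem = missing_credentials()
if problem:
    print(f"Cannot test the BigQuery agent yet: {problem}")
```

This check only confirms that the file exists. It does not validate the credentials themselves.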
Add AsyncAgentExecutorMixin or SyncAgentExecutorMixin to the class to tell flytekit to use the agent to run the task.
class BigQueryTask(AsyncAgentExecutorMixin, SQLTask[BigQueryConfig]):
    ...

class ChatGPTTask(SyncAgentExecutorMixin, PythonTask):
    ...
Flytekit then automatically uses the agent to run the task during local execution.
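from flytekit import StructuredDataset, kwtypes
from flytekitplugins.bigquery import BigQueryConfig, BigQueryTask

# Query ten Dogecoin transactions for a given version from a public dataset.
bigquery_doge_coin = BigQueryTask(
    name="bigquery.doge_coin",
    inputs=kwtypes(version=int),
    query_template="SELECT * FROM `bigquery-public-data.crypto_dogecoin.transactions` WHERE version = @version LIMIT 10;",
    output_structured_dataset_type=StructuredDataset,
    task_config=BigQueryConfig(ProjectID="flyte-test-340607"),
)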
You can run the above example task locally and test the agent with the following command:
pyflyte run bigquery_task.py bigquery_doge_coin --version 10
You can also run a BigQuery task in your Python interpreter to test the agent locally.
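In an interpreter, this usually amounts to importing the task object from bigquery_task.py and calling it with keyword arguments. The sketch below wraps those two steps in a small standard-library helper so that it works for any file path. The helper name call_task_from_file is hypothetical and not part of flytekit.

```python
import importlib.util


def call_task_from_file(path, task_name, **inputs):
    """Load a Python file as a module and call one of its top-level callables."""
    spec = importlib.util.spec_from_file_location("local_task_module", path)
    module = importlib.util.module_from_spec(spec)
    # Executing the module defines its top-level objects, including the task.
    spec.loader.exec_module(module)
    return getattr(module, task_name)(**inputs)
```

For the example above, the call would be call_task_from_file("bigquery_task.py", "bigquery_doge_coin", version=10), assuming flytekit, the BigQuery plugin, and valid credentials are available in the environment.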
Databricks example
To test the Databricks agent, copy the following code to a file called databricks_task.py, modifying as needed.
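import random
from operator import add
import flytekit
from flytekit import task
from flytekitplugins.spark import Databricks

def f(_):
    # Assumed helper (not shown in the original): 1 if a random point in the square falls inside the unit circle.
    x, y = random.random() * 2 - 1, random.random() * 2 - 1
    return 1 if x**2 + y**2 <= 1 else 0

@task(task_config=Databricks(...))
def hello_spark(partitions: int) -> float:
    print("Starting Spark with Partitions: {}".format(partitions))
    n = 100000 * partitions
    sess = flytekit.current_context().spark_session
    count = (
        sess.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
    )
    pi_val = 4.0 * count / n
    print("Pi val is :{}".format(pi_val))
    return pi_val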
To execute the Spark task on the agent, you must configure the raw-output-data-prefix with a remote path.
This configuration ensures that flytekit transfers the input data to the blob storage and allows the Spark job running on Databricks to access the input data directly from the designated bucket.
Note
The Spark task will run locally if the raw-output-data-prefix is not set.
pyflyte run --raw-output-data-prefix s3://my-s3-bucket/databricks databricks_task.py hello_spark
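Because the task silently falls back to local execution when the prefix is missing, a wrapper script can guard against that by refusing to build the command unless the prefix is a remote path. The sketch below uses only the standard library. The helper name build_pyflyte_command and the set of URL schemes treated as remote are assumptions made for illustration, so adjust them to your storage backend.

```python
import shlex
from urllib.parse import urlparse

# Assumption: these schemes count as "remote" for this check.
REMOTE_SCHEMES = {"s3", "gs"}


def build_pyflyte_command(script, task_name, raw_output_prefix, **inputs):
    """Build the argument list for `pyflyte run` with a remote output prefix."""
    scheme = urlparse(raw_output_prefix).scheme
    if scheme not in REMOTE_SCHEMES:
        raise ValueError(f"raw output prefix must be a remote path, got: {raw_output_prefix}")
    cmd = ["pyflyte", "run", "--raw-output-data-prefix", raw_output_prefix, script, task_name]
    # Task inputs are passed as --name value pairs after the task name.
    for key, value in inputs.items():
        cmd += [f"--{key}", str(value)]
    return cmd


cmd = build_pyflyte_command(
    "databricks_task.py", "hello_spark", "s3://my-s3-bucket/databricks", partitions=4
)
print(shlex.join(cmd))
```

The resulting list can be passed to subprocess.run to launch the task. The partitions value of 4 is only an example input.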