Orchestrate Ray jobs with Apache Airflow®

Ray is an open-source framework for scaling Python applications, particularly machine learning and AI workloads, where it provides the layer for parallel processing and distributed computing. Many large language models (LLMs), including OpenAI’s GPT models, are trained using Ray.

The Ray provider package for Apache Airflow® allows you to interact with Ray from your Airflow Dags. This tutorial demonstrates how to use the Ray provider package to orchestrate a simple Ray job with Airflow in an existing Ray cluster. For more in-depth information, see the Ray provider documentation.

For instructions on how to run Ray jobs on the Anyscale platform with Airflow, see the Orchestrate Ray jobs on Anyscale with Apache Airflow® tutorial.

This tutorial shows a simple implementation of the Ray provider package. For a more complex example, see the Processing User Feedback: an LLM-fine-tuning reference architecture with Ray on Anyscale reference architecture.

Time to complete

This tutorial takes approximately 30 minutes to complete.

Assumed knowledge

To get the most out of this tutorial, make sure you have an understanding of:

  • Airflow fundamentals, such as writing Dags and defining tasks.
  • Airflow connections.
  • The basics of Ray.

Prerequisites

  • The Astro CLI.
  • Optional: A pre-existing Ray cluster. This tutorial shows how to spin up a local Ray cluster using Docker. To connect to your existing Ray cluster, modify the connection defined in Step 2.

The Ray provider package can also create a Ray cluster for you in an existing Kubernetes cluster. For more information, see the Ray provider package documentation. Note that for the provider to create the cluster, you need a Kubernetes cluster with a pre-configured LoadBalancer service.
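For illustration, a Dag that lets the provider manage the cluster lifecycle could look roughly like the sketch below. It assumes a hypothetical RayCluster spec file called ray.yaml next to the Dag file, a connection named ray_k8s_conn, and the provider's SetupRayCluster and DeleteRayCluster operators; exact parameter names can vary between provider versions, so check the Ray provider documentation for your release.

    from pathlib import Path

    from airflow.sdk import chain, dag
    from ray_provider.operators import DeleteRayCluster, SetupRayCluster, SubmitRayJob

    # Hypothetical RayCluster spec and connection ID -- adjust to your setup
    RAY_SPEC = str(Path(__file__).parent / "ray.yaml")
    CONN_ID = "ray_k8s_conn"


    @dag()
    def ray_cluster_lifecycle():

        setup_cluster = SetupRayCluster(
            task_id="setup_ray_cluster",
            conn_id=CONN_ID,
            ray_cluster_yaml=RAY_SPEC,
        )

        submit_job = SubmitRayJob(
            task_id="submit_ray_job",
            conn_id=CONN_ID,
            entrypoint="python ray_script.py 1 2 3",
            runtime_env={"working_dir": str(Path(__file__).parent)},
        )

        delete_cluster = DeleteRayCluster(
            task_id="delete_ray_cluster",
            conn_id=CONN_ID,
            ray_cluster_yaml=RAY_SPEC,
        )

        # Create the cluster, run the job, then remove the cluster
        chain(setup_cluster, submit_job, delete_cluster)


    ray_cluster_lifecycle()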

Step 1: Configure your Astro project

Use the Astro CLI to create and run an Airflow project on your local machine.

  1. Create a new Astro project:

    $ mkdir astro-ray-tutorial && cd astro-ray-tutorial
    $ astro dev init
  2. In the requirements.txt file, add the Ray provider.

    astro-provider-ray==0.3.1
  3. (Optional). If you don’t have a pre-existing Ray cluster, you can spin up a local Ray cluster alongside your local Astro project by using a docker-compose.override.yml file. Create a new file in your project’s root directory called docker-compose.override.yml and add the following:

    services:
      ray-head:
        image: rayproject/ray:latest
        container_name: ray-head
        command: >
          ray start
          --head
          --dashboard-host=0.0.0.0
          --dashboard-port=8265
          --ray-client-server-port=10001
          --port=6379
          --num-cpus=4
          --block
        ports:
          - "8265:8265" # Ray dashboard
          - "10001:10001" # Ray client server
          - "6379:6379" # Ray Redis
        networks:
          - airflow
        environment:
          - RAY_GRAFANA_HOST=http://grafana:3000
          - RAY_PROMETHEUS_HOST=http://prometheus:9090
        healthcheck:
          test: ["CMD", "ray", "status"]
          interval: 30s
          timeout: 10s
          retries: 5
          start_period: 30s
        restart: unless-stopped

    networks:
      airflow:
  4. In your .env file, specify your Ray cluster address. Modify this address if you are using a pre-existing Ray cluster.

    RAY_ADDRESS=http://ray-head:8265
  5. Run the following command to start your Astro project, then optionally verify that the Ray cluster is reachable as shown below:

    astro dev start
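Once the containers are running, you can confirm that the local Ray cluster is up before configuring Airflow. The Ray dashboard is mapped to port 8265 on your machine, so http://localhost:8265 should show it in your browser. As an optional command-line sanity check, you can query the dashboard's REST API (assuming curl is available):

    $ curl http://localhost:8265/api/version

If the cluster is healthy, this returns a small JSON payload containing the Ray version.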

Step 2: Configure a Ray connection

For Astro customers, Astronomer recommends using the Astro Environment Manager to store connections in an Astro-managed secrets backend. These connections can be shared across multiple deployed and local Airflow environments. See Manage Astro connections in branch-based deploy workflows.

  1. In the Airflow UI, go to Admin -> Connections and click +.

  2. Create a new connection and choose the Ray connection type. If you used the docker-compose.override.yml file to spin up a local Ray cluster, use the information below. If you are connecting to your existing Ray cluster, you will need to modify your values accordingly.

    • Connection ID: ray_conn
    • Host: ray-head
    • Port: 8265
    • Extra Fields:
      • ray_dashboard_url: "http://ray-head:8265"
      • disable_job_log_to_stdout: false
  3. Click Save.

If you are connecting to a Ray cluster running on a cloud provider, you need to provide the .kubeconfig file of the Kubernetes cluster where the Ray cluster is running in the Kube config (JSON format) field, as well as valid cloud credentials as environment variables.
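If you prefer to manage the connection in code rather than the Airflow UI, you can also define it as an environment variable in your .env file using Airflow's JSON connection format. A minimal sketch using the same values as above (this assumes the Ray provider registers the ray connection type; adjust the values for your own cluster):

    AIRFLOW_CONN_RAY_CONN='{"conn_type": "ray", "host": "ray-head", "port": 8265, "extra": {"ray_dashboard_url": "http://ray-head:8265", "disable_job_log_to_stdout": false}}'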

Step 3: Write a Dag to orchestrate Ray jobs

  1. Create a new file in your dags directory called ray_tutorial.py.

  2. Copy and paste the code below into the file:

Taskflow
1"""
2## Ray Tutorial
3
4This tutorial demonstrates how to use the Ray provider in Airflow to parallelize
5a task using Ray.
6"""
7
8from airflow.sdk import dag, task
9from ray_provider.decorators import ray
10
11CONN_ID = "ray_conn"
12RAY_TASK_CONFIG = {
13 "conn_id": CONN_ID,
14 "num_cpus": 1,
15 "num_gpus": 0,
16 "memory": 0,
17 "poll_interval": 5,
18}
19
20
21@dag(doc_md=__doc__)
22def ray_example_dag():
23
24 @task
25 def generate_data() -> list:
26 """
27 Generate sample data
28 Returns:
29 list: List of integers
30 """
31 import random
32
33 return [random.randint(1, 100) for _ in range(10)]
34
35 # use the @ray.task decorator to parallelize the task
36 @ray.task(config=RAY_TASK_CONFIG)
37 def get_mean_squared_value(data: list) -> float:
38 """
39 Get the mean squared value from a list of integers
40 Args:
41 data (list): List of integers
42 Returns:
43 float: Mean value of the list
44 """
45 import numpy as np
46 import ray
47
48 @ray.remote
49 def square(x: int) -> int:
50 """
51 Square a number
52 Args:
53 x (int): Number to square
54 Returns:
55 int: Squared number
56 """
57 return x**2
58
59 ray.init()
60 data = np.array(data)
61 futures = [square.remote(x) for x in data]
62 results = ray.get(futures)
63 mean = np.mean(results)
64 print(f"Mean squared value: {mean}")
65
66 data = generate_data()
67 get_mean_squared_value(data)
68
69
70ray_example_dag()
Traditional
1"""
2## Ray Tutorial
3
4This tutorial demonstrates how to use the Ray provider in Airflow to
5parallelize a task using Ray.
6"""
7
8from airflow.sdk import dag, chain
9from airflow.providers.standard.operators.python import PythonOperator
10from ray_provider.operators import SubmitRayJob
11from pathlib import Path
12
13CONN_ID = "ray_conn"
14FOLDER_PATH = Path(__file__).parent
15RAY_RUNTIME_ENV = {"working_dir": str(FOLDER_PATH)}
16
17
18def _generate_data() -> list:
19 """
20 Generate sample data
21 Returns:
22 list: List of integers
23 """
24 import random
25
26 return [random.randint(1, 100) for _ in range(10)]
27
28
29@dag(doc_md=__doc__)
30def ray_tutorial():
31
32 data = PythonOperator(
33 task_id="generate_data",
34 python_callable=_generate_data,
35 )
36
37 get_mean_squared_value = SubmitRayJob(
38 task_id="SubmitRayJob",
39 conn_id=CONN_ID,
40 entrypoint="python ray_script.py {{ ti.xcom_pull(task_ids='generate_data') | join(' ') }}",
41 runtime_env=RAY_RUNTIME_ENV,
42 num_cpus=1,
43 num_gpus=0,
44 memory=0,
45 resources={},
46 xcom_task_key="SubmitRayJob.dashboard",
47 fetch_logs=True,
48 wait_for_completion=True,
49 job_timeout_seconds=600,
50 poll_interval=5,
51 )
52
53 chain(data, get_mean_squared_value)
54
55
56ray_tutorial()

This is a simple Dag consisting of two tasks:

  • The generate_data task randomly generates a list of 10 integers.
  • The get_mean_squared_value task submits a Ray job to your Ray cluster to calculate the mean squared value of the list of integers.
  3. (Optional). If you are using the traditional syntax with the SubmitRayJob operator, you need to provide the Python code to run in the Ray job as a script. Create a new file in your dags directory called ray_script.py and add the following code (an example of the rendered entrypoint command follows the script):

    # ray_script.py
    import numpy as np
    import ray
    import argparse

    @ray.remote
    def square(x):
        return x**2

    def main(data):
        ray.init()
        data = np.array(data)
        futures = [square.remote(x) for x in data]
        results = ray.get(futures)
        mean = np.mean(results)
        print(f"Mean of this population is {mean}")
        return mean

    if __name__ == "__main__":
        parser = argparse.ArgumentParser(description="Process some integers.")
        parser.add_argument('data', nargs='+', type=float, help='List of numbers to process')
        args = parser.parse_args()

        data = args.data
        main(data)
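The entrypoint of the SubmitRayJob task uses a Jinja template to pull the list produced by generate_data from XCom and pass it to this script as command-line arguments. For example, if generate_data returned the (hypothetical) values [12, 7, 45], the rendered entrypoint would be:

    python ray_script.py 12 7 45

Because the script only depends on ray and numpy, you can also run it on its own to test it outside of Airflow, as long as both packages are installed in that environment.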

Step 4: Run the Dag

  1. In the Airflow UI, click the play button to manually run your Dag.

  2. After the Dag runs successfully, go to your Ray dashboard to see the job submitted by Airflow. You can also check the job from the command line, as shown below.

    Ray dashboard showing a Job completed successfully.
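A quick command-line check against the local cluster from Step 1 (assuming ray is installed on your machine) might look like this:

    $ ray job list --address http://localhost:8265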

Conclusion

Congratulations! You’ve run a Ray job using Apache Airflow. You can now use the Ray provider package to orchestrate more complex Ray jobs; see Processing User Feedback: an LLM-fine-tuning reference architecture with Ray on Anyscale for an example.