Running GPU jobs on AWS Batch
Explore how to best run periodic GPU-intensive machine learning inference on cloud GPUs with AWS Batch.
AWS Batch
Amazon ECR
Amazon EventBridge
Amazon S3
Terraform
PyTorch
Python
After successfully training and validating a machine learning algorithm, the next inevitable step in its life cycle is usually a production deployment. While hosting an always-on real-time API is the most popular way to do it, it is also not the only way.
In many scenarios — such as image processing, content generation, personalization and recommendation, and metadata generation — you should consider running some of the inference offline, in batches.
Consider a scenario we’ve had in one of our projects:
- You process tens of thousands of images daily,
- The processing mechanism generates new images from existing ones for a particular website,
- It does not need to happen immediately. We can easily wait a day before processing a new batch of images.
While it was tempting to hide our processing mechanism behind a 24/7 real-time API for another system, this would dramatically increase the project’s cost for very little benefit. At the time of writing, serverless GPU workloads are not an option and our mechanism requires a GPU. Therefore, we’d need to run at least one GPU instance 24/7 (or more for higher availability). These aren’t cheap.
A similar scenario might occur in your project as well.
The solution
After carefully reviewing the use case with the customer, we noticed that they don’t need the results immediately. What we settled on was:
- we let the external system upload images whenever they want, via a traditional serverless API that stores them on S3;
- once per day, we run a batch job to process all images from a given day and send them back to the external system.
This slashed the AWS and maintenance costs of the project drastically and made the implementation easier.
There are various methods you can use to run “offline jobs” on AWS. SageMaker Batch Transform is one way to do it, but due to cost constraints and no other SageMaker services being used in the solution, we opted for a cheaper option, namely AWS Batch. To our surprise, running GPU-based AWS Batch jobs comes with some caveats, and there aren’t many materials available on how to run them properly.
To save you the hassle, here’s a simple hello-GPU AWS Batch solution to get you started.
Diving into AWS Batch
Our imaginary workflow will process images present in one S3 bucket via AWS Batch and an arbitrary machine learning algorithm, and finally save the results to another S3 bucket.
The GPU-intensive algorithm that we’ll use is Segment Anything from Meta. Of course, your processing script can be anything; you could just as well run generative AI models such as Llama 3.1 70B. Segment Anything is just an example.
We will also schedule the job to run daily using Amazon EventBridge.

Our solution uses Terraform and Python, but you can recreate it in your tooling of choice. We will also use some premade modules, such as terraform-aws-modules, to speed things up. Note that we skip the basics, such as AWS authentication or applying changes via Terraform. You need to fill in the blanks, but we provide the “hard parts.”
We’ll start with defining some Terraform identifiers for future use:
locals {
id = "batch-gpu-example"
aws_region = "eu-west-1" # referenced by the modules below; adjust to your region
vpc_cidr = "172.31.0.0/16"
}
Make sure that you tweak the locals.id value with your own random-ish name; otherwise you’ll have a collision, since S3 bucket names are globally unique.
Then, we create two buckets — one for input and one for output:
resource "aws_s3_bucket" "input" {
bucket = "${local.id}-input"
}
resource "aws_s3_bucket" "output" {
bucket = "${local.id}-output"
}
Our processing script has the following flow:
- It loads images from a given S3 bucket,
- It processes those images using SAM,
- Finally, it saves them back to another S3 bucket.
The algorithm of our choice is open source, and the script simply pulls the weights from the Hugging Face Hub at startup — in your case, you could pull them from SageMaker Model Registry or other places.
Also, note that we use a certain convention here — the script gets the date as input (YYYY/MM/DD) and will process only the files from the input S3 bucket that are present in the YYYY/MM/DD directory. This is an example of how you can cleanly divide your workload by day.
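For illustration, here is a minimal sketch of that convention (the helper below is ours, purely for demonstration):
from datetime import datetime, timezone

def day_prefix(now: datetime | None = None) -> str:
    # Build the YYYY/MM/DD prefix under which a given day's images live in S3.
    now = now or datetime.now(timezone.utc)
    return now.strftime("%Y/%m/%d")

# e.g. objects for 29 August 2024 live under s3://<input-bucket>/2024/08/29/...
print(day_prefix(datetime(2024, 8, 29, tzinfo=timezone.utc)))
With that convention in mind, the full processing script looks as follows: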
import os
from io import BytesIO
from pathlib import Path

import boto3
import cv2
import numpy as np
import supervision as sv
import torch
from PIL import Image
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

INPUT_BUCKET_NAME = os.environ["INPUT_BUCKET_NAME"]
INPUT_BUCKET_PREFIX = os.environ["INPUT_BUCKET_PREFIX"]
OUTPUT_BUCKET_NAME = os.environ["OUTPUT_BUCKET_NAME"]
REGION = os.getenv("REGION", "eu-west-1")


def save_array_to_img_buffer(image: np.ndarray, buf: BytesIO) -> None:
    image_pil = Image.fromarray(image)
    image_pil.save(buf, format="JPEG")
    buf.seek(0)


def list_files_in_dir_and_subdirs(dir: Path) -> list[Path]:
    return [file for file in dir.glob("**/*") if file.is_file()]


def download_folder_from_s3(s3_resource, bucket_name: str, s3_prefix: str, local_path: Path):
    # Mirror every object under the given prefix into the local directory.
    bucket = s3_resource.Bucket(bucket_name)
    for obj in bucket.objects.filter(Prefix=s3_prefix):
        target = local_path / Path(obj.key).relative_to(s3_prefix)
        target.parent.mkdir(parents=True, exist_ok=True)
        bucket.download_file(obj.key, str(target))


def main() -> None:
    session = boto3.Session(region_name=REGION)
    s3_client = session.client("s3")
    s3_resource = session.resource("s3")

    mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
        "facebook/sam2-hiera-large", points_per_batch=4
    )
    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

    temp_path: Path = Path("tmp")
    temp_path.mkdir(parents=True, exist_ok=True)
    download_folder_from_s3(
        s3_resource=s3_resource,
        bucket_name=INPUT_BUCKET_NAME,
        s3_prefix=INPUT_BUCKET_PREFIX,
        local_path=temp_path,
    )

    image_paths_to_process: list[Path] = list_files_in_dir_and_subdirs(temp_path)
    for image_path in image_paths_to_process:
        image: np.ndarray = cv2.imread(str(image_path))
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            masks = mask_generator.generate(image)
        detections = sv.Detections.from_sam(sam_result=masks)
        annotated_image = mask_annotator.annotate(
            scene=image.copy(), detections=detections
        )
        # Upload the annotated image to the output bucket under the same day prefix.
        with BytesIO() as buf:
            save_array_to_img_buffer(annotated_image, buf)
            s3_client.upload_fileobj(
                buf,
                OUTPUT_BUCKET_NAME,
                f"{INPUT_BUCKET_PREFIX}/{image_path.name}",
            )


if __name__ == "__main__":
    main()
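If you want to smoke-test the script locally before containerizing it (assuming valid AWS credentials, the dependencies installed, and ideally a GPU), a hypothetical helper could look like this; note that the configuration is read from environment variables at import time, so they have to be set before importing main:
# local_smoke_test.py -- hypothetical helper, not part of the deployed job
import os

# The script reads its configuration at import time, so set it first.
os.environ["INPUT_BUCKET_NAME"] = "batch-gpu-example-input"    # placeholder names
os.environ["OUTPUT_BUCKET_NAME"] = "batch-gpu-example-output"
os.environ["INPUT_BUCKET_PREFIX"] = "2024/08/29"               # the day to process

from main import main  # noqa: E402 -- deliberately imported after setting the env vars

main()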
The next step is to create an ECR repository, because AWS Batch runs Docker containers and the container images need to be stored somewhere. You can, of course, use your own image registry instead.
resource "aws_ecr_repository" "this" {
name = "${local.id}-repository"
}
We will also need to create a relevant Dockerfile:
FROM nvcr.io/nvidia/pytorch:24.01-py3
WORKDIR /app
RUN apt-get update && apt-get install -y git
RUN git clone https://github.com/facebookresearch/segment-anything-2.git
WORKDIR /app/segment-anything-2
RUN python -m pip install -e .
WORKDIR /app
COPY requirements.txt .
RUN python -m pip install --no-cache-dir -r requirements.txt
COPY main.py main.py
ENTRYPOINT ["python", "main.py"]
With a requirements.txt relevant to our processing script:
opencv-python-headless==4.8.0.74
supervision==0.22.0
boto3>=1.35.1
huggingface-hub==0.24.6
Having all that, perform terraform apply and proceed. By this point, just the ECR repository and the two S3 buckets should have been created.
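Optionally, you can double-check from Python that everything exists before moving on; a small boto3 sketch (the bucket and repository names follow the locals.id defined earlier):
import boto3

session = boto3.Session(region_name="eu-west-1")  # adjust to your region

# Each call raises a ClientError if the resource is missing or inaccessible.
session.client("s3").head_bucket(Bucket="batch-gpu-example-input")
session.client("s3").head_bucket(Bucket="batch-gpu-example-output")
repos = session.client("ecr").describe_repositories(
    repositoryNames=["batch-gpu-example-repository"]
)
print(repos["repositories"][0]["repositoryUri"])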
The next step is to build and push the image straight to ECR. Here are the commands that you may use to do it locally from the CLI, provided you’re correctly authenticated and authorized:
#!/usr/bin/env bash
image=$1
if [ "$image" == "" ]
then
echo "Usage: $0 <image-name>"
exit 1
fi
# Get the account number associated with the current IAM credentials
account=$(aws sts get-caller-identity --query Account --output text)
if [ $? -ne 0 ]
then
exit 255
fi
# Get the region defined in the current configuration
region=$(aws configure get region)
fullname="${account}.dkr.ecr.${region}.amazonaws.com/${image}:latest"
# If the repository doesn't exist in ECR, create it.
aws ecr describe-repositories --repository-names "${image}" > /dev/null 2>&1
if [ $? -ne 0 ]
then
aws ecr create-repository --repository-name "${image}" > /dev/null
fi
# Get the login command from ECR and execute it directly
aws ecr get-login-password --region "${region}" | docker login --username AWS --password-stdin "${account}".dkr.ecr."${region}".amazonaws.com
# Build the docker image locally with the image name and then push it to ECR
# with the full name.
docker build -t ${image} .
docker tag ${image} ${fullname}
docker push ${fullname}
And then run:
chmod +x ./build_and_push.sh
./build_and_push.sh batch-gpu-example
Note that building that image may take some time. In real-world projects, this step would typically be done in a CI/CD pipeline, and the Dockerfile could also be optimized. We skip these parts for clarity, though.
The last important step is to define the AWS Batch compute environment, job definitions, and other relevant resources.
We must also configure the network for the compute environment — particularly the subnets. We could make use of the existing default VPC’s subnets; however, we will create them manually:
module "vpc" {
source = "terraform-aws-modules/vpc/aws"
version = "5.8.1"
name = "${local.id}-vpc"
cidr = local.vpc_cidr
azs = ["${local.aws_region}a", "${local.aws_region}b", "${local.aws_region}c"]
public_subnets = [
cidrsubnet(local.vpc_cidr, 8, 0),
cidrsubnet(local.vpc_cidr, 8, 4),
cidrsubnet(local.vpc_cidr, 8, 8)
]
private_subnets = [
cidrsubnet(local.vpc_cidr, 8, 12),
cidrsubnet(local.vpc_cidr, 8, 16),
cidrsubnet(local.vpc_cidr, 8, 20)
]
enable_nat_gateway = true
single_nat_gateway = true
public_route_table_tags = { Name = "${local.id}-public" }
public_subnet_tags = { Name = "${local.id}-public" }
private_route_table_tags = { Name = "${local.id}-private" }
private_subnet_tags = { Name = "${local.id}-private" }
enable_dhcp_options = true
enable_dns_hostnames = true
}
With the VPC in place (you can reuse an existing one if you wish), there are now only a couple of things left to define.
AWS Batch compute environment:
module "batch" {
source = "terraform-aws-modules/batch/aws"
version = "2.0.2"
instance_iam_role_name = "${local.id}-inst-rol"
instance_iam_role_path = "/batch/"
instance_iam_role_description = "IAM instance role/profile for AWS Batch ECS instance(s)"
service_iam_role_name = "${local.id}-batch"
service_iam_role_path = "/batch/"
service_iam_role_description = "IAM service role for AWS Batch"
compute_environments = {
ec2 = {
name_prefix = "ec2"
compute_resources = {
type = "EC2"
min_vcpus = 0
max_vcpus = 16
desired_vcpus = 4
instance_types = ["g4dn.xlarge"]
launch_template = {
launch_template_id = aws_launch_template.this.id
version = "$Latest"
}
subnets = module.vpc.private_subnets
}
}
}
job_queues = {
default = {
name = "${local.id}-jobq-ec2"
state = "ENABLED"
priority = 100
create_scheduling_policy = false
compute_environments = ["ec2"]
}
}
job_definitions = {
image_processing = {
name = "${local.id}-img-proc"
container_properties = jsonencode({
image = "${aws_ecr_repository.this.repository_url}:latest"
jobRoleArn = aws_iam_role.ecs_task_execution_role.arn,
resourceRequirements = [
{ type = "VCPU", value = "4" },
{ type = "MEMORY", value = "14336" },
{ type = "GPU", value = "1" }
]
logConfiguration = {
logDriver = "awslogs"
options = {
awslogs-group = aws_cloudwatch_log_group.this.id
awslogs-region = local.aws_region
awslogs-stream-prefix = local.id
}
}
environment = [
{
name = "NVIDIA_DRIVER_CAPABILITIES",
value = "all"
},
{
name = "INPUT_BUCKET_NAME",
value = aws_s3_bucket.input.bucket
},
{
name = "OUTPUT_BUCKET_NAME",
value = aws_s3_bucket.output.bucket
},
{
name = "INPUT_BUCKET_PREFIX",
value = "2024/08/29"
}
]
})
retry_strategy = {
attempts = 2
evaluate_on_exit = {
retry_error = {
action = "RETRY"
on_exit_code = 1
}
exit_success = {
action = "EXIT"
on_exit_code = 0
}
}
}
}
}
}
resource "aws_cloudwatch_log_group" "this" {
name = "/aws/batch/${local.id}"
retention_in_days = 14
}
resource "aws_launch_template" "this" {
name_prefix = "${local.id}-launch-template"
block_device_mappings {
device_name = "/dev/xvda"
ebs {
volume_size = 256
volume_type = "gp2"
}
}
}
In our case, there were several tricky parts here:
- the default root volume size for AWS Batch instances was 8 GB, so we needed to modify the launch template,
- we needed to set the NVIDIA_DRIVER_CAPABILITIES variable for this particular image (see the sanity check sketched below),
- the resourceRequirements needed to include a GPU entry.
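A quick way to confirm these settings actually took effect is to verify GPU visibility at the start of the job. This is an optional addition to main.py, a minimal sketch using torch (already available in the base image):
import torch

def log_gpu_info() -> None:
    # Fails loudly if the container cannot see a GPU, which usually points at a
    # missing GPU entry in resourceRequirements or at NVIDIA_DRIVER_CAPABILITIES.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available inside the container")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")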
Having all that, we just define the necessary IAM policies via Terraform:
data "aws_iam_policy_document" "ecs_task_execution_role_policy" {
statement {
actions = ["sts:AssumeRole"]
principals {
type = "Service"
identifiers = ["ecs-tasks.amazonaws.com"]
}
}
}
data "aws_iam_policy_document" "allow_s3_policy" {
statement {
effect = "Allow"
actions = [
"s3:ListBucket",
"s3:GetObject",
"s3:PutObject"
]
resources = [
aws_s3_bucket.input.arn,
"${aws_s3_bucket.input.arn}/*",
aws_s3_bucket.output.arn,
"${aws_s3_bucket.output.arn}/*",
]
}
}
resource "aws_iam_role" "ecs_task_execution_role" {
name = "${local.id}-ecs-role"
assume_role_policy = data.aws_iam_policy_document.ecs_task_execution_role_policy.json
}
resource "aws_iam_role_policy" "this" {
name = "${local.id}-ecs-pol"
role = aws_iam_role.ecs_task_execution_role.id
policy = data.aws_iam_policy_document.allow_s3_policy.json
}
A cherry on top is scheduling it via EventBridge. We will configure it to trigger a Lambda function every day at 17:00 UTC.
So, define a Lambda:
module "run_job_lambda_function" {
source = "terraform-aws-modules/lambda/aws"
version = "6.3.0"
function_name = "${local.id}-run-job"
description = "Lambda for triggering Image Processing Batch Job."
handler = "handler.lambda_handler"
timeout = 30
memory_size = 128
runtime = "python3.11"
source_path = "../run_job_lambda"
environment_variables = {
REGION = local.aws_region
JOB_QUEUE = module.batch.job_queues.default.name
JOB_DEFINITION = module.batch.job_definitions.image_processing.name
}
create_current_version_allowed_triggers = false
trusted_entities = ["scheduler.amazonaws.com"]
attach_policy_statements = true
policy_statements = {
batch = {
effect = "Allow",
actions = [
"batch:SubmitJob",
]
resources = [
module.batch.job_definitions.image_processing.arn_prefix,
module.batch.job_queues.default.arn
]
}
}
allowed_triggers = {
EventBridge = {
principal = "scheduler.amazonaws.com"
source_arn = module.eventbridge.eventbridge_schedule_arns["${local.id}-lambda-cron"]
}
}
}
With the following code:
import os
from datetime import datetime, timezone
from http import HTTPStatus

import boto3

REGION = os.getenv("REGION", "eu-west-1")
JOB_QUEUE = os.environ["JOB_QUEUE"]
JOB_DEFINITION = os.environ["JOB_DEFINITION"]


def lambda_handler(event, context):
    client = boto3.client("batch", REGION)
    now = datetime.now(timezone.utc)
    # Submit the daily job and point it at today's YYYY/MM/DD prefix.
    response = client.submit_job(
        jobName=f"batch-job-example-{now.strftime('%Y%m%d')}",
        jobQueue=JOB_QUEUE,
        jobDefinition=JOB_DEFINITION,
        containerOverrides={
            "environment": [{
                "name": "INPUT_BUCKET_PREFIX",
                "value": now.strftime("%Y/%m/%d")
            }]
        }
    )
    status_code = response["ResponseMetadata"]["HTTPStatusCode"]
    if status_code != HTTPStatus.OK:
        return {"statusCode": status_code}
    return {"statusCode": HTTPStatus.OK}
Note that this configuration instructs the batch job to pull files from the current date’s directory in the input S3 bucket, i.e., it opens the bucket and loads files from the YYYY/MM/DD directory, where YYYY/MM/DD is today’s date. That’s just our convention and the way the processing script was written.
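If you want to follow up on the submitted job from Python (for instance, from a monitoring script), here is a hedged sketch using the Batch API (the job ID comes from the submit_job response):
import boto3

batch = boto3.client("batch", region_name="eu-west-1")  # adjust to your region

def get_job_status(job_id: str) -> str:
    # Returns one of SUBMITTED, PENDING, RUNNABLE, STARTING, RUNNING,
    # SUCCEEDED or FAILED.
    response = batch.describe_jobs(jobs=[job_id])
    return response["jobs"][0]["status"]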
Finish everything with the EventBridge definition:
module "eventbridge" {
source = "terraform-aws-modules/eventbridge/aws"
version = "3.7.0"
create_bus = false
attach_lambda_policy = true
role_name = "${local.id}-eventbridge-role"
lambda_target_arns = [module.run_job_lambda_function.lambda_function_arn]
schedules = {
"${local.id}-lambda-cron" = {
description = "Trigger for a Lambda"
schedule_expression = "cron(0 17 * * ? *)" # Every day at 17:00 UTC
timezone = "UTC"
arn = module.run_job_lambda_function.lambda_function_arn
}
}
}
After all the above steps are done and deployed (perform terraform apply now), you should have an end-to-end hello-GPU processing example running on AWS Batch.
Upload any JPEG file of a car you can find on the Internet to the input bucket under today’s YYYY/MM/DD directory and wait until 17:00 UTC passes.
…or just run it manually (or script it, as sketched below):
- run the AWS Batch job via the AWS Console, setting INPUT_BUCKET_PREFIX in the environment variables to the YYYY/MM/DD of your upload, such as 2024/09/09;
- wait until it finishes;
- see the results in the target S3 bucket.
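If you prefer to script that manual run instead of clicking through the console, a minimal boto3 sketch could look like this (the bucket, queue, and job definition names assume the Terraform defaults above):
from datetime import datetime, timezone
import boto3

prefix = datetime.now(timezone.utc).strftime("%Y/%m/%d")
session = boto3.Session(region_name="eu-west-1")  # adjust to your region

# Upload a test image under today's prefix in the input bucket.
session.client("s3").upload_file(
    "car.jpg", "batch-gpu-example-input", f"{prefix}/car.jpg"
)

# Submit the Batch job, overriding INPUT_BUCKET_PREFIX to today's date.
session.client("batch").submit_job(
    jobName=f"batch-gpu-example-manual-{prefix.replace('/', '-')}",
    jobQueue="batch-gpu-example-jobq-ec2",
    jobDefinition="batch-gpu-example-img-proc",
    containerOverrides={
        "environment": [{"name": "INPUT_BUCKET_PREFIX", "value": prefix}]
    },
)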
Conclusion
Machine learning inference does not have a silver bullet solution. Since serverless GPUs are not (yet?) available, batch offline processing is something you should consider for your non-time-sensitive ML workloads. You can slash the cost of the project and lower the maintenance effort drastically, making you and your customers happier.
The steps above should provide you with a great starting point. As we have mentioned, most of the code is arbitrary and for educational purposes; you will inevitably need to tweak it to match your requirements. The next step you might want to take is to parallelize the work by using more instances and splitting the workload evenly between them.
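One way to do that with AWS Batch is array jobs: you submit a single job with an arrayProperties size, Batch launches that many containers, and each container receives its index in the AWS_BATCH_JOB_ARRAY_INDEX environment variable, which you can use to shard the file list. Here is a hedged sketch of the sharding side (the helper and the ARRAY_SIZE variable are our own convention, not something Batch provides):
import os
from pathlib import Path

def shard_for_this_worker(paths: list[Path]) -> list[Path]:
    # AWS Batch array jobs expose the worker index via AWS_BATCH_JOB_ARRAY_INDEX;
    # a plain (non-array) job falls back to processing everything.
    index = int(os.getenv("AWS_BATCH_JOB_ARRAY_INDEX", "0"))
    size = int(os.getenv("ARRAY_SIZE", "1"))  # pass this yourself, e.g. via containerOverrides
    return [p for i, p in enumerate(sorted(paths)) if i % size == index]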
Happy processing!