
Running GPU jobs on AWS Batch

Explore how to best run periodic GPU-intensive machine learning inference on cloud GPUs with AWS Batch.



After successfully training and validating a machine learning algorithm, the next step in its life cycle is usually a production deployment. While hosting an always-on real-time API is the most popular way to do it, it is not the only one.

In many scenarios — such as image processing, content generation, personalization and recommendation, and metadata generation — you should consider running some of the inference offline, in batches.

Consider a scenario we’ve had in one of our projects:

While it was tempting to hide our processing mechanism behind a 24/7 real-time API for another system, this would dramatically increase the project’s cost for very little benefit. At the time of writing, serverless GPU workloads are not an option and our mechanism requires a GPU. Therefore, we’d need to run at least one GPU instance 24/7 (or more for higher availability). These aren’t cheap.
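
To make the cost argument concrete, here is a back-of-the-envelope sketch. The hourly rate and the one-hour daily job duration are placeholder assumptions, not actual AWS prices or measurements from our project; substitute your own numbers.

```python
HOURS_PER_DAY = 24
DAYS_PER_MONTH = 30  # simplification: every month has 30 days


def monthly_cost(hourly_rate: float, hours_per_day: float, instances: int = 1) -> float:
    """Cost of running `instances` machines for `hours_per_day` hours every day."""
    return hourly_rate * hours_per_day * DAYS_PER_MONTH * instances


# Placeholder rate: look up the real on-demand price for your instance type and region.
HYPOTHETICAL_HOURLY_RATE = 1.00

always_on = monthly_cost(HYPOTHETICAL_HOURLY_RATE, HOURS_PER_DAY)
daily_batch = monthly_cost(HYPOTHETICAL_HOURLY_RATE, hours_per_day=1)  # assumed 1-hour daily job

print(f"always-on: {always_on:.2f}, daily batch: {daily_batch:.2f}")
print(f"ratio under these assumptions: {always_on / daily_batch:.0f}x")
```

The ratio depends only on how many hours per day the job actually runs, so the shorter the daily job, the stronger the case for batch processing.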

A similar scenario might occur in your project as well.

The solution

However, after carefully reviewing the use case with the customer, we’ve noticed that they don’t need the results immediately. What we’ve settled upon was:

This slashed the AWS and maintenance costs of the project drastically and made the implementation easier.

There are various methods you can use to run “offline jobs” on AWS. SageMaker Batch Transform is one way to do it, but due to cost constraints and because the solution used no other SageMaker services, we opted for a cheaper option, namely AWS Batch. To our surprise, running GPU-based AWS Batch jobs comes with some caveats, and there aren’t many materials available on how to run them properly.

To save you the hassle, here’s a simple hello-GPU AWS Batch solution to get you started.

Diving into AWS Batch

Our imaginary workflow will process images present in one S3 bucket via AWS Batch and an arbitrary machine learning algorithm and finally save the results to another S3 bucket.

The GPU-intensive algorithm that we’ll use is Segment Anything from Meta. Of course, your processing script can be anything. You could run generative AI models such as Llama 3.1 70B for all we know — Segment Anything is just an example.

We will also schedule the job to run daily using Amazon EventBridge.

An overview of our AWS Batch-based GPU processing solution.

Our solution uses Terraform and Python, but you can recreate it in your tooling of choice. We will also use some premade modules, such as terraform-aws-modules, to speed things up. Note that we skip the basics, such as AWS authentication or applying changes via Terraform. You need to fill in the blanks, but we provide the “hard parts.”

We’ll start with defining some Terraform identifiers for future use:

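main.tf

    locals {
      id       = "batch-gpu-example"
      vpc_cidr = "172.31.0.0/16"

      # Referenced later as local.aws_region. The value is an assumption that
      # matches the default region in the Python scripts; adjust it to yours.
      aws_region = "eu-west-1"
    }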

Make sure that you replace the local.id value with your own unique name. S3 bucket names are globally unique, so reusing ours will cause a collision.

Then, we create two buckets, one for input and one for output:

main.tf

    resource "aws_s3_bucket" "input" {
      bucket = "${local.id}-input"
    }

    resource "aws_s3_bucket" "output" {
      bucket = "${local.id}-output"
    }

Our processing script has the following flow:

  1. It loads images from a given S3 bucket,
  2. It processes those images using SAM,
  3. Finally, it saves them back to another S3 bucket.

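The algorithm of our choice is open source. In the script below, from_pretrained downloads the weights from the Hugging Face Hub when the job starts, so the job needs outbound internet access (the NAT gateway defined later provides it). In your case, you could instead bake the weights into the container image or pull them from SageMaker Model Registry or another store.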

Also, note that we use a certain convention here: the script receives a date prefix (YYYY/MM/DD) through the INPUT_BUCKET_PREFIX environment variable and processes only the files stored under that prefix in the input S3 bucket. This is one clear way to split your workload by day.
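
The convention can be sketched in a few lines. The helper names below (`date_prefix`, `keys_for_day`) are ours, introduced for illustration; they do not appear in the processing script.

```python
from datetime import date


def date_prefix(day: date) -> str:
    """Return the S3 key prefix for a given day, for example 2024/08/29."""
    return day.strftime("%Y/%m/%d")


def keys_for_day(keys: list[str], day: date) -> list[str]:
    """Keep only the keys stored under that day's prefix."""
    prefix = date_prefix(day) + "/"
    return [key for key in keys if key.startswith(prefix)]


keys = ["2024/08/28/a.jpg", "2024/08/29/b.jpg", "2024/08/29/sub/c.jpg"]
print(keys_for_day(keys, date(2024, 8, 29)))
```

In the real script, the same filtering is done server-side by passing the prefix to the S3 listing call, so only that day's objects are ever downloaded.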

main.py

    import os
    from io import BytesIO
    from pathlib import Path

    import boto3
    import cv2
    import numpy as np
    import supervision as sv
    import torch
    from PIL import Image
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

    INPUT_BUCKET_NAME = os.environ["INPUT_BUCKET_NAME"]
    INPUT_BUCKET_PREFIX = os.environ["INPUT_BUCKET_PREFIX"]
    OUTPUT_BUCKET_NAME = os.environ["OUTPUT_BUCKET_NAME"]
    REGION = os.getenv("REGION", "eu-west-1")


    def save_array_to_img_buffer(image: np.ndarray, buf: BytesIO) -> None:
        """Encode an RGB array as JPEG into buf and rewind it for reading."""
        image_pil = Image.fromarray(image)
        image_pil.save(buf, format="JPEG")
        buf.seek(0)


    def list_files_in_dir_and_subdirs(dir: Path) -> list[Path]:
        return [file for file in dir.glob("**/*") if file.is_file()]


    def download_folder_from_s3(
        s3_resource, bucket_name: str, s3_prefix: str, local_path: Path
    ) -> None:
        bucket = s3_resource.Bucket(bucket_name)
        for obj in bucket.objects.filter(Prefix=s3_prefix):
            if obj.key.endswith("/"):
                continue  # skip zero-byte "folder" placeholder objects
            target = local_path / Path(obj.key).relative_to(s3_prefix)
            target.parent.mkdir(parents=True, exist_ok=True)
            bucket.download_file(obj.key, str(target))


    def main() -> None:
        session = boto3.Session(region_name=REGION)
        s3_client = session.client("s3")
        s3_resource = session.resource("s3")

        mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
            "facebook/sam2-hiera-large", points_per_batch=4
        )
        mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

        temp_path: Path = Path("tmp")
        temp_path.mkdir(parents=True, exist_ok=True)

        download_folder_from_s3(
            s3_resource=s3_resource,
            bucket_name=INPUT_BUCKET_NAME,
            s3_prefix=INPUT_BUCKET_PREFIX,
            local_path=temp_path,
        )
        image_paths_to_process: list[Path] = list_files_in_dir_and_subdirs(temp_path)

        for image_path in image_paths_to_process:
            image_bgr = cv2.imread(str(image_path))
            if image_bgr is None:
                continue  # not a readable image file
            # OpenCV loads images as BGR, while SAM 2 and PIL expect RGB.
            image: np.ndarray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

            with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
                masks = mask_generator.generate(image)
                detections = sv.Detections.from_sam(sam_result=masks)
                annotated_image = mask_annotator.annotate(
                    scene=image.copy(), detections=detections
                )

            with BytesIO() as buf:
                save_array_to_img_buffer(annotated_image, buf)
                s3_client.upload_fileobj(
                    buf,
                    OUTPUT_BUCKET_NAME,
                    f"{INPUT_BUCKET_PREFIX}/{image_path.name}",
                )


    if __name__ == "__main__":
        main()

The next step is to create an ECR repository, because AWS Batch runs Docker containers and the container images need to be stored somewhere. You can, of course, use your own image registry instead.

main.tf

    resource "aws_ecr_repository" "this" {
      name = "${local.id}-repository"
    }

We will also need to create a relevant Dockerfile:

Dockerfile

    FROM nvcr.io/nvidia/pytorch:24.01-py3

    WORKDIR /app

    RUN apt-get update && apt-get install -y git
    RUN git clone https://github.com/facebookresearch/segment-anything-2.git

    WORKDIR /app/segment-anything-2
    RUN python -m pip install -e .

    WORKDIR /app
    COPY requirements.txt .
    RUN python -m pip install --no-cache-dir -r requirements.txt

    COPY main.py main.py

    ENTRYPOINT ["python", "main.py"]

With requirements.txt relevant to our processing script:

requirements.txt

    opencv-python-headless==4.8.0.74
    supervision==0.22.0
    boto3>=1.35.1
    huggingface-hub==0.24.6

With all that in place, run terraform apply. At this point, only the ECR repository and the two S3 buckets should exist.

The next step is to build and push the image straight to ECR. Here are the commands that you may use to do it locally from the CLI, provided you’re correctly authenticated and authorized:

build_and_push.sh

    #!/usr/bin/env bash

    image=$1

    if [ "$image" == "" ]
    then
        echo "Usage: $0 <image-name>"
        exit 1
    fi

    # Get the account number associated with the current IAM credentials
    account=$(aws sts get-caller-identity --query Account --output text)

    if [ $? -ne 0 ]
    then
        exit 255
    fi

    # Get the region defined in the current configuration (default to us-west-2 if none defined)
    region=$(aws configure get region)
    region=${region:-us-west-2}

    fullname="${account}.dkr.ecr.${region}.amazonaws.com/${image}:latest"

    # If the repository doesn't exist in ECR, create it.
    aws ecr describe-repositories --repository-names "${image}" > /dev/null 2>&1

    if [ $? -ne 0 ]
    then
        aws ecr create-repository --repository-name "${image}" > /dev/null
    fi

    # Get the login password from ECR and pipe it straight into docker login
    aws ecr get-login-password --region "${region}" | docker login --username AWS --password-stdin "${account}.dkr.ecr.${region}.amazonaws.com"

    # Build the docker image locally with the image name and then push it to ECR
    # with the full name.
    docker build -t "${image}" .
    docker tag "${image}" "${fullname}"
    docker push "${fullname}"

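And then run it, passing the name of the ECR repository that Terraform created (batch-gpu-example-repository with the default local.id), so that the image lands where the job definition expects it:

    chmod +x ./build_and_push.sh
    ./build_and_push.sh batch-gpu-example-repository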

Note that building the image may take some time. In real-world projects, this step would run in a CI/CD pipeline, and the Dockerfile could be optimized. We skip these parts for clarity.

The last important step is to define the AWS Batch compute environment, job queue, job definition, and other relevant resources.

We must also configure the network for the compute environment, particularly the subnets. We could use the default VPC’s subnets, but we will create a dedicated VPC instead:

main.tf

    module "vpc" {
      source  = "terraform-aws-modules/vpc/aws"
      version = "5.8.1"

      name = "${local.id}-vpc"
      cidr = local.vpc_cidr

      azs = ["${local.aws_region}a", "${local.aws_region}b", "${local.aws_region}c"]

      public_subnets = [
        cidrsubnet(local.vpc_cidr, 8, 0),
        cidrsubnet(local.vpc_cidr, 8, 4),
        cidrsubnet(local.vpc_cidr, 8, 8)
      ]
      private_subnets = [
        cidrsubnet(local.vpc_cidr, 8, 12),
        cidrsubnet(local.vpc_cidr, 8, 16),
        cidrsubnet(local.vpc_cidr, 8, 20)
      ]

      enable_nat_gateway = true
      single_nat_gateway = true

      public_route_table_tags  = { Name = "${local.id}-public" }
      public_subnet_tags       = { Name = "${local.id}-public" }
      private_route_table_tags = { Name = "${local.id}-private" }
      private_subnet_tags      = { Name = "${local.id}-private" }

      enable_dhcp_options  = true
      enable_dns_hostnames = true
    }
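
If the cidrsubnet calls look opaque, the following sketch reproduces the same arithmetic with Python's standard ipaddress module. The `cidrsubnet` function here is our own stand-in for the Terraform built-in, written for illustration only.

```python
import ipaddress


def cidrsubnet(prefix: str, newbits: int, netnum: int) -> str:
    """Mimic Terraform's cidrsubnet(): extend the prefix length by `newbits`
    bits and return subnet number `netnum` of the resulting subnets."""
    network = ipaddress.ip_network(prefix)
    subnets = list(network.subnets(prefixlen_diff=newbits))
    return str(subnets[netnum])


vpc_cidr = "172.31.0.0/16"
public = [cidrsubnet(vpc_cidr, 8, n) for n in (0, 4, 8)]
private = [cidrsubnet(vpc_cidr, 8, n) for n in (12, 16, 20)]

print("public: ", public)
print("private:", private)
```

Each call carves a /24 out of the /16, so the six subnets hold 256 addresses each and do not overlap.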

With the VPC in place (you can reuse an existing one if you wish), there are only a couple of things left to define.

First, the AWS Batch compute environment, job queue, and job definition:

main.tf

    module "batch" {
      source  = "terraform-aws-modules/batch/aws"
      version = "2.0.2"

      instance_iam_role_name        = "${local.id}-inst-rol"
      instance_iam_role_path        = "/batch/"
      instance_iam_role_description = "IAM instance role/profile for AWS Batch ECS instance(s)"

      service_iam_role_name        = "${local.id}-batch"
      service_iam_role_path        = "/batch/"
      service_iam_role_description = "IAM service role for AWS Batch"

      compute_environments = {
        ec2 = {
          name_prefix = "ec2"

          compute_resources = {
            type           = "EC2"
            min_vcpus      = 0
            max_vcpus      = 16
            desired_vcpus  = 4
            instance_types = ["g4dn.xlarge"]

            launch_template = {
              launch_template_id = aws_launch_template.this.id
              version            = "$Latest"
            }

            subnets = module.vpc.private_subnets
          }
        }
      }

      job_queues = {
        default = {
          name     = "${local.id}-jobq-ec2"
          state    = "ENABLED"
          priority = 100

          create_scheduling_policy = false

          compute_environments = ["ec2"]
        }
      }

      job_definitions = {
        image_processing = {
          name = "${local.id}-img-proc"

          container_properties = jsonencode({
            image      = "${aws_ecr_repository.this.repository_url}:latest"
            jobRoleArn = aws_iam_role.ecs_task_execution_role.arn
            resourceRequirements = [
              { type = "VCPU", value = "4" },
              { type = "MEMORY", value = "14336" },
              { type = "GPU", value = "1" }
            ]
            logConfiguration = {
              logDriver = "awslogs"
              options = {
                awslogs-group         = aws_cloudwatch_log_group.this.id
                awslogs-region        = local.aws_region
                awslogs-stream-prefix = local.id
              }
            }
            environment = [
              {
                name  = "NVIDIA_DRIVER_CAPABILITIES"
                value = "all"
              },
              {
                name  = "INPUT_BUCKET_NAME"
                value = aws_s3_bucket.input.bucket
              },
              {
                name  = "OUTPUT_BUCKET_NAME"
                value = aws_s3_bucket.output.bucket
              },
              {
                name  = "INPUT_BUCKET_PREFIX"
                value = "2024/08/29"
              }
            ]
          })

          retry_strategy = {
            attempts = 2
            evaluate_on_exit = {
              retry_error = {
                action       = "RETRY"
                on_exit_code = 1
              }
              exit_success = {
                action       = "EXIT"
                on_exit_code = 0
              }
            }
          }
        }
      }
    }

    resource "aws_cloudwatch_log_group" "this" {
      name              = "/aws/batch/${local.id}"
      retention_in_days = 14
    }

    resource "aws_launch_template" "this" {
      name_prefix = "${local.id}-launch-template"

      block_device_mappings {
        device_name = "/dev/xvda"

        ebs {
          volume_size = 256
          volume_type = "gp2"
        }
      }
    }
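
It is worth checking that the job's resourceRequirements fit on the chosen instance type, since a job that requests more than any instance in the compute environment can offer may simply wait in the queue. The sketch below uses the published g4dn.xlarge size (4 vCPUs, 16 GiB of memory, 1 GPU); the `Resources` class and the `fits` helper are hypothetical names of ours. Some memory on the instance is taken by the operating system and the ECS agent, so the request has to stay below the nominal figure, which is presumably why the job definition asks for 14336 MiB rather than the full 16384.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Resources:
    vcpus: int
    memory_mib: int
    gpus: int


def fits(job: Resources, instance: Resources) -> bool:
    """True if every requested resource is within the instance's nominal capacity."""
    return (
        job.vcpus <= instance.vcpus
        and job.memory_mib <= instance.memory_mib
        and job.gpus <= instance.gpus
    )


# Nominal capacity of a g4dn.xlarge instance.
G4DN_XLARGE = Resources(vcpus=4, memory_mib=16 * 1024, gpus=1)

# Values taken from the job definition above.
job = Resources(vcpus=4, memory_mib=14336, gpus=1)

print(fits(job, G4DN_XLARGE))
print(fits(Resources(vcpus=4, memory_mib=14336, gpus=2), G4DN_XLARGE))
```

The check compares against nominal capacity only; the memory actually available to containers is lower, so treat a passing result as necessary rather than sufficient.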

In our case, there were several tricky parts here:

Next, we define the necessary IAM role and policy via Terraform:

main.tf

    data "aws_iam_policy_document" "ecs_task_execution_role_policy" {
      statement {
        actions = ["sts:AssumeRole"]

        principals {
          type        = "Service"
          identifiers = ["ecs-tasks.amazonaws.com"]
        }
      }
    }

    data "aws_iam_policy_document" "allow_s3_policy" {
      statement {
        effect = "Allow"
        actions = [
          "s3:ListBucket",
          "s3:GetObject",
          "s3:PutObject"
        ]
        # Bucket ARNs are references, not string literals; quoting them would
        # put the literal text "aws_s3_bucket.input.arn" into the policy.
        resources = [
          aws_s3_bucket.input.arn,
          "${aws_s3_bucket.input.arn}/*",
          aws_s3_bucket.output.arn,
          "${aws_s3_bucket.output.arn}/*",
        ]
      }
    }

    resource "aws_iam_role" "ecs_task_execution_role" {
      name               = "${local.id}-ecs-role"
      assume_role_policy = data.aws_iam_policy_document.ecs_task_execution_role_policy.json
    }

    resource "aws_iam_role_policy" "this" {
      name   = "${local.id}-ecs-pol"
      role   = aws_iam_role.ecs_task_execution_role.id
      policy = data.aws_iam_policy_document.allow_s3_policy.json
    }
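
Note why the policy lists both the bucket ARN and the /* object ARN: s3:ListBucket applies to the bucket itself, while s3:GetObject and s3:PutObject apply to the objects inside it. The sketch below builds an equivalent policy document as plain JSON; the `s3_access_policy` helper is our own, and the bucket names assume the default local.id.

```python
import json


def s3_access_policy(bucket_names: list[str]) -> dict:
    """Build an IAM policy document granting list/get/put on the given buckets."""
    bucket_arns = [f"arn:aws:s3:::{name}" for name in bucket_names]  # for ListBucket
    object_arns = [f"{arn}/*" for arn in bucket_arns]  # for GetObject / PutObject
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": ["s3:ListBucket", "s3:GetObject", "s3:PutObject"],
                "Resource": bucket_arns + object_arns,
            }
        ],
    }


policy = s3_access_policy(["batch-gpu-example-input", "batch-gpu-example-output"])
print(json.dumps(policy, indent=2))
```

A tighter variant would split this into separate statements, granting read access on the input bucket and write access on the output bucket only.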

A cherry on top is scheduling it via EventBridge. We will configure a schedule that fires every day at 17:00 UTC and invokes a Lambda function, which in turn submits the Batch job.

So, define a Lambda:

main.tf

    module "run_job_lambda_function" {
      source  = "terraform-aws-modules/lambda/aws"
      version = "6.3.0"

      # Note: no stray "}" inside the name; Lambda function names cannot contain it.
      function_name = "${local.id}-run-job"
      description   = "Lambda for triggering Image Processing Batch Job."
      handler       = "handler.lambda_handler"
      timeout       = 30
      memory_size   = 128

      runtime = "python3.11"

      source_path = "../run_job_lambda"

      environment_variables = {
        REGION         = local.aws_region
        JOB_QUEUE      = module.batch.job_queues.default.name
        JOB_DEFINITION = module.batch.job_definitions.image_processing.name
      }

      create_current_version_allowed_triggers = false

      trusted_entities = ["scheduler.amazonaws.com"]

      attach_policy_statements = true
      policy_statements = {
        batch = {
          effect = "Allow"
          actions = [
            "batch:SubmitJob",
          ]
          resources = [
            module.batch.job_definitions.image_processing.arn_prefix,
            module.batch.job_queues.default.arn
          ]
        }
      }

      allowed_triggers = {
        EventBridge = {
          principal  = "scheduler.amazonaws.com"
          source_arn = module.eventbridge.eventbridge_schedule_arns["${local.id}-lambda-cron"]
        }
      }
    }

With the following code:

run_job_lambda/handler.py

    import os
    from datetime import datetime, timezone
    from http import HTTPStatus

    import boto3

    REGION = os.getenv("REGION", "eu-west-1")
    JOB_QUEUE = os.environ["JOB_QUEUE"]
    JOB_DEFINITION = os.environ["JOB_DEFINITION"]


    def lambda_handler(event, context):
        client = boto3.client("batch", region_name=REGION)

        # The date is taken in UTC, so the prefix matches the schedule's time zone.
        now = datetime.now(timezone.utc)
        response = client.submit_job(
            # Single quotes inside the f-string keep this valid on Python 3.11,
            # which does not allow reusing the outer quote character.
            jobName=f"batch-job-example-{now.strftime('%Y%m%d')}",
            jobQueue=JOB_QUEUE,
            jobDefinition=JOB_DEFINITION,
            containerOverrides={
                "environment": [
                    {"name": "INPUT_BUCKET_PREFIX", "value": now.strftime("%Y/%m/%d")}
                ]
            },
        )

        status_code = response["ResponseMetadata"]["HTTPStatusCode"]
        if status_code != HTTPStatus.OK:
            return {"statusCode": status_code}

Note that this configuration instructs the batch job to read files stored under the current date in the input S3 bucket, i.e. keys that start with YYYY/MM/DD, where YYYY/MM/DD is today’s date in UTC. The override replaces the hard-coded INPUT_BUCKET_PREFIX value from the job definition. That’s just our convention and the way the processing script was written.

Finish everything with the EventBridge definition:

main.tf

    module "eventbridge" {
      source  = "terraform-aws-modules/eventbridge/aws"
      version = "3.7.0"

      create_bus = false

      attach_lambda_policy = true
      role_name            = "${local.id}-eventbridge-role"
      lambda_target_arns   = [module.run_job_lambda_function.lambda_function_arn]

      schedules = {
        "${local.id}-lambda-cron" = {
          description         = "Trigger for a Lambda"
          schedule_expression = "cron(0 17 * * ? *)" # Every day at 17:00 UTC
          timezone            = "UTC"
          arn                 = module.run_job_lambda_function.lambda_function_arn
        }
      }
    }
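
The cron expression above fires at 17:00 UTC. To see what that means on a Central European clock, the sketch below converts the time using fixed UTC offsets (UTC+1 for CET, UTC+2 for CEST). The `run_time_in` helper and the sample date are ours, chosen for illustration.

```python
from datetime import datetime, timedelta, timezone

CET = timezone(timedelta(hours=1), "CET")    # Central European Time (winter)
CEST = timezone(timedelta(hours=2), "CEST")  # Central European Summer Time


def run_time_in(tz: timezone, hour_utc: int = 17) -> str:
    """Express the daily run time (given as an hour in UTC) in another time zone."""
    run_utc = datetime(2024, 8, 29, hour_utc, 0, tzinfo=timezone.utc)  # arbitrary date
    return run_utc.astimezone(tz).strftime("%H:%M %Z")


print(run_time_in(CET))
print(run_time_in(CEST))
```

If you would rather pin the schedule to local time, one option is to set the schedule's timezone to your local zone instead of UTC and adjust the hour accordingly.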

After all the above steps are done and deployed (run terraform apply now), you should have an end-to-end hello-GPU example running on AWS Batch.

Upload any JPEG file of a car you can find on the Internet to the input bucket under the YYYY/MM/DD prefix for today’s date (in UTC) and wait until the schedule fires at 17:00 UTC.

…or just run it manually:
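
A minimal sketch of a manual submission with boto3 follows. It assumes the default local.id from this walkthrough (so the queue is batch-gpu-example-jobq-ec2 and the job definition is batch-gpu-example-img-proc), the eu-west-1 region, and valid AWS credentials; the `build_submit_args` and `submit` helpers are our own names.

```python
from datetime import date

# Names derived from the default local.id; change them if you changed local.id.
JOB_QUEUE = "batch-gpu-example-jobq-ec2"
JOB_DEFINITION = "batch-gpu-example-img-proc"


def build_submit_args(day: date) -> dict:
    """Build the keyword arguments for the Batch SubmitJob call for a given day."""
    return {
        "jobName": f"manual-run-{day.strftime('%Y%m%d')}",
        "jobQueue": JOB_QUEUE,
        "jobDefinition": JOB_DEFINITION,
        "containerOverrides": {
            "environment": [
                {"name": "INPUT_BUCKET_PREFIX", "value": day.strftime("%Y/%m/%d")}
            ]
        },
    }


def submit(day: date, region: str = "eu-west-1") -> str:
    """Submit the job and return its ID. Requires boto3 and AWS credentials."""
    import boto3  # imported lazily so build_submit_args works without boto3

    client = boto3.client("batch", region_name=region)
    return client.submit_job(**build_submit_args(day))["jobId"]


if __name__ == "__main__":
    print(build_submit_args(date.today()))
    # Uncomment to actually submit the job:
    # print(submit(date.today()))
```

Passing an explicit date also makes it easy to reprocess an earlier day's prefix.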

Conclusion

Machine learning inference does not have a silver bullet solution. Since serverless GPUs are not (yet?) available, batch offline processing is something you should consider for your non-time-sensitive ML workloads. You can slash the cost of the project and lower the maintenance effort drastically, making you and your customers happier.

The steps above should provide you with a great starting point. As we have mentioned, most of the code is arbitrary / for educational purposes — you will inevitably need to tweak it to match your expectations. The next step you might want is to parallelize the work by using more instances and splitting work evenly between them.
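
One way to split the work, sketched below, is a round-robin shard over the list of object keys. With an AWS Batch array job, each child job receives its index in the AWS_BATCH_JOB_ARRAY_INDEX environment variable. The ARRAY_SIZE variable and the `shard` helper are assumptions of ours; you would have to pass the array size in yourself, for example through the job definition's environment.

```python
import os


def shard(items: list[str], index: int, total: int) -> list[str]:
    """Return the slice of `items` that worker `index` out of `total` should process."""
    if not 0 <= index < total:
        raise ValueError(f"index must be in [0, {total}), got {index}")
    # Sorting first makes every worker see the same order, so shards never overlap.
    return sorted(items)[index::total]


keys = [f"2024/08/29/img-{n}.jpg" for n in range(7)]

index = int(os.getenv("AWS_BATCH_JOB_ARRAY_INDEX", "0"))
total = int(os.getenv("ARRAY_SIZE", "3"))  # hypothetical variable, not set by AWS Batch

print(shard(keys, index, total))
```

Round-robin sharding balances the number of files per worker, not the processing time; if image sizes vary a lot, a work queue may distribute the load more evenly.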

Happy processing!
