Apache Spark: how to cancel job in code and kill running tasks?

I am running a Spark application (version 1.6.0) on a Hadoop cluster with Yarn (version 2.6.0) in client mode. I have a piece of code that runs a long computation, and I want to kill it if it takes too long (and then run some other function instead).
Here is an example:

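import java.util.concurrent.{TimeUnit, TimeoutException}
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf().setAppName("TIMEOUT_TEST")
val sc = new SparkContext(conf)
val lst = List(1, 2, 3)
// setting up an infinite action: every task spins forever
val future = sc.parallelize(lst).map { x => while (true) {}; x }.collectAsync()

try {
    Await.result(future, Duration(30, TimeUnit.SECONDS))
    println("success!")
} catch {
    case _: TimeoutException =>
        future.cancel() // the job is marked as failed, but its running tasks keep spinning
        println("timeout")
}

// sleep for 1 hour to allow inspecting the application in YARN
Thread.sleep(60 * 60 * 1000)
sc.stop()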

The timeout is set to 30 seconds, but the computation is of course infinite, so awaiting the future's result throws an exception, which is caught; the future is then canceled and the backup function executes.
This all works perfectly well, except that the canceled job doesn't terminate completely: when looking at the web UI for the application, the job is marked as failed, but I can see there are still running tasks inside.

The same thing happens when I use SparkContext.cancelAllJobs or SparkContext.cancelJobGroup. The problem is that even though I manage to get on with my program, the running tasks of the canceled job are still hogging valuable resources (which will eventually slow me down to a near stop).

To sum up: how do I kill a Spark job in a way that also terminates all of its running tasks (as opposed to what happens now, where the job stops launching new tasks but lets the currently running ones finish)?

UPDATE:
After ignoring this problem for a long time, we found a messy but efficient little workaround. Instead of trying to kill the appropriate Spark job/stage from within the Spark application, we simply logged the stage IDs of all active stages when the timeout occurred and issued an HTTP GET request to the URL the Spark Web UI uses for killing those stages.
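A minimal sketch of that workaround, assuming the Spark 1.6 web UI's stage-kill endpoint <ui>/stages/stage/kill/?id=<stageId>&terminate=true (the URL behind the UI's own "kill" links) and that spark.ui.killEnabled is left at its default of true. The object StageKiller and its methods are hypothetical names, and the endpoint should be checked against your Spark version before relying on it.

```scala
import java.net.{HttpURLConnection, URL}

object StageKiller {
  // Assumption: Spark 1.6 UI kill URL; verify against your Spark version.
  def killUrl(uiBase: String, stageId: Int): String =
    s"${uiBase.stripSuffix("/")}/stages/stage/kill/?id=$stageId&terminate=true"

  /** Issues one HTTP GET per stage ID and returns the HTTP status codes. */
  def killStages(uiBase: String, stageIds: Seq[Int]): Seq[Int] =
    stageIds.map { id =>
      val conn = new URL(killUrl(uiBase, id)).openConnection().asInstanceOf[HttpURLConnection]
      try {
        conn.setRequestMethod("GET")
        conn.getResponseCode // sends the request and waits for the response
      } finally conn.disconnect()
    }
}
```

At timeout, something like StageKiller.killStages("http://<driver-host>:4040", sc.statusTracker.getActiveStageIds) would kill every active stage; the host is a placeholder and 4040 is only Spark's default UI port. Note that this kills all active stages, including those of other jobs sharing the SparkContext; SparkStatusTracker.getJobIdsForGroup together with getJobInfo(jobId).map(_.stageIds) could narrow the list down to the timed-out job.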

Incorporeal answered 25/9, 2016 at 10:17
How about stopping the SparkContext when the exception occurs, i.e. calling SparkContext.stop(), and initialising a new SparkContext for subsequent jobs? (Ohg)
Unfortunately for me this is not an option, as the SparkContext is shared throughout the project. Stopping it and starting a new one will cause other modules to fail, since they are now holding a reference to a closed SparkContext. (Incorporeal)

For the sake of future visitors: Spark 2.0.3 introduced the task reaper, a built-in mechanism that (more or less) addresses this scenario. Note that it can eventually kill an executor if a task does not respond.
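A minimal sketch of enabling the task reaper on the question's SparkConf, assuming Spark 2.0.3 or later. The keys are the reaper's configuration properties; the values chosen here (10s polling, 120s kill timeout) are illustrative, not recommendations. With spark.task.reaper.killTimeout set, an executor whose killed task is still running after that timeout shuts its own JVM down, which is the "kill an executor eventually" behaviour described above.

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("TIMEOUT_TEST")
  // Monitor killed tasks until they actually stop running (default: false)
  .set("spark.task.reaper.enabled", "true")
  // How often the reaper checks on killed tasks (illustrative value)
  .set("spark.task.reaper.pollingInterval", "10s")
  // If a killed task is still running after this long, the executor JVM kills itself
  // (default -1 disables this; 120s is an illustrative value)
  .set("spark.task.reaper.killTimeout", "120s")
```

The same keys can also be passed with --conf on spark-submit or placed in spark-defaults.conf.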

Moreover, some built-in Spark data sources have been refactored to be more responsive to task interruption:

For the 1.6.0 version, Zohar's solution is a "messy but efficient" one.

Permalloy answered 20/9, 2020 at 10:12

I don't know if this answers your question. My need was to kill jobs that hang for too long (my jobs extract data from Oracle tables, and for some unknown reason the connection occasionally hangs forever).

After some study, I came to this solution:

import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global
import org.apache.spark.JobExecutionStatus
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}

// `sc` (SparkContext) and `spark` (SparkSession) come from the shell/notebook
val MAX_JOB_SECONDS = 100
val statusTracker = sc.statusTracker

val sparkListener = new SparkListener {
  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
    val jobId = jobStart.jobId
    // Watchdog for this job: poll its status once per second on a background thread
    Future {
      var c = MAX_JOB_SECONDS
      var mustCancel = false
      var running = true
      while (!mustCancel && running) {
        Thread.sleep(1000)
        c -= 1
        mustCancel = c <= 0
        // getJobInfo returns an Option; None means the job is no longer tracked
        running = statusTracker.getJobInfo(jobId)
          .exists(_.status() == JobExecutionStatus.RUNNING)
      }
      if (mustCancel) {
        sc.cancelJob(jobId)
      }
    }
  }
}

sc.addSparkListener(sparkListener)
try {
  val df = spark.sql("SELECT * FROM VERY_BIG_TABLE") // just an example of a long-running job
  println(df.count)
} catch {
  case exc: org.apache.spark.SparkException =>
    if (exc.getMessage.contains("cancelled"))
      throw new Exception("Job forcibly cancelled")
    else
      throw exc
  case ex: Throwable =>
    println(s"Another exception: $ex")
} finally {
  sc.removeSparkListener(sparkListener)
}
Embroider answered 4/11, 2019 at 10:6

On a Unity Catalog-enabled cluster in shared mode, the SparkContext is no longer available, because these clusters use Spark Connect, which you can read more about here.

For interruption to work, there are two methods going forward: tag your operations with addTag/removeTag and interrupt them with interruptTag, or interrupt a single operation by its ID with interruptOperation.

Sharing code below for your reference; modify it based on your requirements.

def runQueryWithTag(query: String, tag: String): Unit = {
  try {
    spark.addTag(tag)
    val df = spark.sql(query)
    println(df.count)
  } finally {
    spark.removeTag(tag)
  }
}

import scala.concurrent.{Future, ExecutionContext}
import scala.concurrent.duration._
import ExecutionContext.Implicits.global
import scala.util.{Success, Failure}

val queriesWithTags = Seq(
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c", "tag3"),
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b", "tag2"),
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c, system.information_schema.columns d, system.information_schema.columns e, system.information_schema.columns f, system.information_schema.columns g", "tag1")
)

val futures = queriesWithTags.map { case (query, tag) =>
  Future { runQueryWithTag(query, tag) }
}

Thread.sleep(30000)
println("Interrupting tag1")
spark.interruptTag("tag1")

OR

import java.util.concurrent.CopyOnWriteArrayList
import scala.concurrent.{Future, ExecutionContext}
import ExecutionContext.Implicits.global

// Queries run on several Future threads, so collect operation IDs in a thread-safe list
val list1 = new CopyOnWriteArrayList[String]()

def runQuery(query: String): Unit = {
  val result = spark.sql(query).collectResult()
  list1.add(result.operationId)
}

val bigQuery = "SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c, system.information_schema.columns d, system.information_schema.columns e"
val queries = Seq(bigQuery, bigQuery, bigQuery)

val futures = queries.map { query =>
  Future { runQuery(query) }
}

Thread.sleep(20000)
println("Interrupting query 1 !!!!")
println(list1)
spark.interruptOperation(list1.get(0))

Update (21 June 2024): adding a better version that tracks the tags and interrupts each tag individually after its own timeout.

import scala.concurrent.{Future, ExecutionContext}
import scala.concurrent.duration._
import ExecutionContext.Implicits.global
import java.util.concurrent.PriorityBlockingQueue

// Order them based on the expiration time.
case class TimedTag(tag: String, expirationTime: Long) extends Comparable[TimedTag] {
  override def compareTo(other: TimedTag): Int = expirationTime.compareTo(other.expirationTime)
}

class TimeoutQueue {
  private val queue = new PriorityBlockingQueue[TimedTag]()

  // Returns the earliest tag if its deadline has passed, removing it from the queue.
  private def popExpired(): Option[String] = {
    val currentTime = System.currentTimeMillis()
    Option(queue.peek()) match {
      // remove(head) instead of poll(): fails safely if a finishing query removed it meanwhile
      case Some(head) if head.expirationTime <= currentTime && queue.remove(head) =>
        Some(head.tag)
      case _ => None
    }
  }

  def add(tag: String, timeout: FiniteDuration): Unit = {
    val expirationTime = System.currentTimeMillis() + timeout.toMillis
    queue.put(TimedTag(tag, expirationTime))
  }

  def remove(tag: String): Unit = {
    queue.removeIf(_.tag == tag)
  }

  def loop(checkInterval: FiniteDuration): Unit = {
    while (queue.size() > 0) { // while we have tags, wait for them
      popExpired() match {
        case Some(tag) => { println(s"Interrupting $tag"); spark.interruptTag(tag) }
        case None => // No item ready to process
      }
      Thread.sleep(checkInterval.toMillis)
    }
  }
}

val timeoutQueue = new TimeoutQueue()

def runQueryWithTag(query: String, tag: String): Unit = {
  try {
    spark.addTag(tag)
    val df = spark.sql(query)
    println(df.count)
    val a = spark.getTags()
    a.foreach(println)
  } finally {
    println(s"Done with $tag")
    spark.removeTag(tag)
    timeoutQueue.remove(tag)
  }
}

val queriesWithTags = Seq(
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c", "tag3"),
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b", "tag2"),
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c, system.information_schema.columns d, system.information_schema.columns e, system.information_schema.columns f, system.information_schema.columns g", "tag1")
  // ("SELECT * FROM system.information_schema.columns a", "tag1")
)

// Register the deadlines before starting the queries, so that a query finishing
// early always finds (and removes) its own entry
timeoutQueue.add("tag1", 30000.milliseconds)
timeoutQueue.add("tag2", 55000.milliseconds)
timeoutQueue.add("tag3", 42000.milliseconds)

val futures = queriesWithTags.map { case (query, tag) =>
  Future { runQueryWithTag(query, tag) }
}

timeoutQueue.loop(500.milliseconds)
Gorse answered 17/6 at 6:47

According to the documentation of SparkContext.setJobGroup:

"If interruptOnCancel is set to true for the job group, then job cancellation will result in Thread.interrupt() being called on the job's executor threads."

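So the job must be started in a job group with interruptOnCancel = true, and the anonymous function in your map must be interruptible, like this:

sc.setJobGroup("TIMEOUT_TEST", "infinite job", interruptOnCancel = true)
val future = sc.parallelize(lst).map { x =>
  while (!Thread.interrupted()) {} // leave the loop once the task thread is interrupted
  x
}.collectAsync()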
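A minimal sketch of the same idea for a blocking call (for example an MLlib fit) wrapped in a future with a timeout. The object Timeouts, the helper runWithTimeout and its parameters are hypothetical. Because setJobGroup stores the group in thread-local properties, the group has to be set inside the body, on the thread that submits the jobs; on timeout, onTimeout can call sc.cancelJobGroup(group), and with interruptOnCancel = true Spark then calls Thread.interrupt() on the job's executor threads. This only helps if the task code reacts to interrupts (blocking I/O, sleeps, or explicit Thread.interrupted checks).

```scala
import java.util.concurrent.TimeoutException
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object Timeouts {
  /** Hypothetical helper: evaluates `body` on a background thread and waits up to
    * `timeout`. On timeout it calls `onTimeout` (e.g. to cancel a Spark job group)
    * and returns `fallback`. Exceptions thrown by `body` propagate unchanged. */
  def runWithTimeout[T](timeout: FiniteDuration, onTimeout: () => Unit, fallback: => T)(body: => T): T = {
    val work = Future(body) // `body` is by-name, so it runs on the pool thread
    try Await.result(work, timeout)
    catch {
      case _: TimeoutException =>
        onTimeout() // the Future itself is not stopped; cancellation must happen in Spark
        fallback
    }
  }
}
```

For example: Timeouts.runWithTimeout(30.seconds, () => sc.cancelJobGroup("TIMEOUT_TEST"), backup()) { sc.setJobGroup("TIMEOUT_TEST", "MLlib fit", interruptOnCancel = true); try fitModel() finally sc.clearJobGroup() }, where backup and fitModel stand in for your own functions. Clearing the group in finally keeps the thread-local property from leaking into later work on the same pool thread.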
Andreas answered 10/6, 2017 at 16:33
Thanks, but unfortunately this won't work in my case, since I don't actually use collectAsync() in my code; instead I call a long-running function from MLlib using a future with a timeout. (Incorporeal)
