With a Unity Catalog-enabled cluster in shared access mode, the Spark context is no longer available because the Spark Connect feature is introduced, which you can read more about here.
For interruption to work, there are two approaches going forward: you can use addTag/removeTag together with interruptTag, or use interruptOperation.
Sample code is shared below for your reference, which you can modify based on your requirements.
// Executes `query` with `tag` attached to the Spark Connect session, so the
// running job can later be cancelled via spark.interruptTag(tag).
// The tag is always detached afterwards, whether the query succeeds or fails.
def runQueryWithTag(query: String, tag: String): Unit =
  try {
    spark.addTag(tag)
    println(spark.sql(query).count)
  } finally {
    spark.removeTag(tag)
  }
import scala.concurrent.{Future, ExecutionContext}
import scala.concurrent.duration._
import ExecutionContext.Implicits.global
import scala.util.{Success, Failure}
// (query, tag) pairs; the heaviest cross join carries "tag1" so it can be
// singled out for interruption below.
val queriesWithTags = Seq(
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c", "tag3"),
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b", "tag2"),
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c, system.information_schema.columns d, system.information_schema.columns e, system.information_schema.columns f, system.information_schema.columns g", "tag1")
)

// Launch every query concurrently, each under its own tag.
val futures = for ((query, tag) <- queriesWithTags)
  yield Future(runQueryWithTag(query, tag))

// Give the jobs time to get going, then cancel everything tagged "tag1".
Thread.sleep(30000)
println("Interrupting tag1")
spark.interruptTag("tag1")
Alternatively, using interruptOperation:
import scala.collection.mutable.ListBuffer
// Accumulates the Spark Connect operation ids of started queries.
// NOTE(review): shared across multiple Futures; ListBuffer itself is not
// thread-safe, so concurrent appends need external synchronization.
val list1 = ListBuffer[String]()
// Runs `query`, blocks until the result is materialized, and records the
// Spark Connect operation id so the operation can be interrupted later.
def runQuery(query: String): Unit = {
  val df = spark.sql(query).collectResult()
  val opid = df.operationId
  // Bug fix: this function is invoked from several Futures concurrently and
  // ListBuffer is not thread-safe — unsynchronized `+=` can corrupt the
  // buffer or drop ids. Guard the shared append.
  list1.synchronized {
    list1 += opid
  }
}
import scala.concurrent.{Future, ExecutionContext}
import scala.concurrent.duration._
import ExecutionContext.Implicits.global
import scala.util.{Success, Failure}
// The same five-way cross join, submitted three times concurrently.
val queries = Seq.fill(3)(
  "SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c, system.information_schema.columns d, system.information_schema.columns e"
)

// Start all queries; each one records its operation id into list1.
val futures = queries.map(q => Future(runQuery(q)))

// Let the jobs get going, then interrupt the first recorded operation.
// NOTE(review): assumes at least one operation id has been recorded by now.
Thread.sleep(20000)
println("Interrupting query 1 !!!!")
println(list1)
spark.interruptOperation(list1(0))
Update (21 June 2024): adding a better version that tracks the tags and terminates individual tags after a per-tag timeout.
// Single shared watchdog instance; queries deregister their own tag on completion.
// NOTE(review): the TimeoutQueue class is defined further down in this note —
// in a plain Scala script this forward reference would not compile; define/run
// the class first (notebook cell ordering).
val timeoutQueue = new TimeoutQueue()
// Variant of runQueryWithTag that also prints the session's active tags and
// unregisters the tag from the timeout watchdog once the query finishes
// (successfully or not), so a completed query is never interrupted.
def runQueryWithTag(query: String, tag: String): Unit =
  try {
    spark.addTag(tag)
    val result = spark.sql(query)
    println(result.count)
    spark.getTags().foreach(println)
  } finally {
    println(s"Done with $tag")
    spark.removeTag(tag)
    timeoutQueue.remove(tag)
  }
import scala.concurrent.{Future, ExecutionContext}
import scala.concurrent.duration._
import ExecutionContext.Implicits.global
import scala.util.{Success, Failure}
import java.util.concurrent.{Executors, ConcurrentLinkedQueue, PriorityBlockingQueue}
// (query, tag) pairs for the timeout demo; "tag1" is the heaviest query.
val queriesWithTags = Seq(
("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c", "tag3"),
("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b", "tag2"),
("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c, system.information_schema.columns d, system.information_schema.columns e, system.information_schema.columns f, system.information_schema.columns g", "tag1")
// ("SELECT * FROM system.information_schema.columns a", "tag1")
)
// A tag paired with its absolute expiry timestamp (epoch millis).
// Ordered by expiration time so the earliest-expiring tag sits at the head
// of a priority queue.
case class TimedTag(tag: String, expirationTime: Long) extends Comparable[TimedTag] {
  override def compareTo(other: TimedTag): Int =
    java.lang.Long.compare(expirationTime, other.expirationTime)
}
// Watches a set of tags, each with an absolute deadline, and fires
// spark.interruptTag for any tag still registered when its deadline passes.
// Tags whose queries finish on their own should be dropped via remove().
class TimeoutQueue {
  // Earliest deadline at the head.
  private val pending = new PriorityBlockingQueue[TimedTag]()

  // Removes and returns the head tag if its deadline has already passed;
  // None when the queue is empty or the head is not yet due.
  private def popExpired(): Option[String] = {
    val now = System.currentTimeMillis()
    Option(pending.peek()).collect {
      case head if head.expirationTime <= now =>
        pending.poll() // drop the expired head
        head.tag
    }
  }

  // Registers `tag` to be interrupted `timeout` from now.
  def add(tag: String, timeout: FiniteDuration): Unit =
    pending.put(TimedTag(tag, System.currentTimeMillis() + timeout.toMillis))

  // Deregisters `tag`, e.g. when its query completed before the deadline.
  def remove(tag: String): Unit = {
    pending.removeIf(_.tag == tag)
  }

  // Polls every `checkInterval` until no tags remain, interrupting expired
  // ones. At most one tag is interrupted per tick, matching the original.
  def loop(checkInterval: FiniteDuration): Unit = {
    while (!pending.isEmpty) {
      popExpired().foreach { tag =>
        println(s"Interrupting $tag")
        spark.interruptTag(tag)
      }
      Thread.sleep(checkInterval.toMillis)
    }
  }
}
// Kick off all tagged queries concurrently, arm a per-tag deadline, then
// run the watchdog loop until every tag has either finished or been killed.
val futures = for ((query, tag) <- queriesWithTags)
  yield Future(runQueryWithTag(query, tag))

timeoutQueue.add("tag1", 30.seconds)
timeoutQueue.add("tag2", 55.seconds)
timeoutQueue.add("tag3", 42.seconds)
timeoutQueue.loop(500.millis)