משחקים עם חישוב מקבילי בסקאלה

06/02/2024

אחת הבעיות של עבודה לא מבוססת Java בתוך ה JVM היא שיש יותר מדי דרכים לעשות דברים, ולא תמיד ברור במה לבחור. במקרה של סקאלה ומקביליות זה נהיה מסובך כי יש גם שיקולים של ארכיטקטורה ותכנות פונקציונאלי.

בשביל המשחק רציתי לספור כמה מספרים ראשוניים יש עד 10 מיליון, ולפצל את החישוב לכמה תהליכונים תוך שימוש בשתי גישות פשוטות למקביליות ובהשוואה עם חישוב סדרתי.

הגישה המקבילית הראשונה היתה פשוט לפתוח Future לכל מספר כדי לזהות אם הוא ראשוני, ולתת ל Java לשגר את התהליכונים.

הגישה המקבילית השניה היתה הספריה parallel-collections שמציעה מימוש של map מקבילי.

וכן נצטרך לעשות פוסט המשך עם cats-effect.

טוב קוד? יאללה. זאת התוכנית:

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.util.Random
import concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
import java.net.{URI, URL}
import java.util.concurrent.{CompletableFuture, Executors}
import scala.language.implicitConversions
import scala.util.chaining._
import scala.collection.parallel.CollectionConverters._


object futures {

  private def isPrime(n: Int): Boolean =
    2.to(Math.sqrt(n.toDouble).toInt).forall(n % _ != 0)

  @main
  def virtualThreadsDemo(): Unit =
    val s0 = System.nanoTime()
    1.to(10000000)
      .map(n => Future { isPrime(n) })
      .map(Await.result(_, Duration.Inf))
      .count(identity)
      .pipe(println)

    val s1 = System.nanoTime()

    1.to(10000000)
      .map(isPrime)
      .count(identity)
      .pipe(println)

    val s2 = System.nanoTime()

    1.to(10000000)
      .par
      .map(isPrime)
      .count(identity)
      .pipe(println)

    val s3 = System.nanoTime()

    println(s"1 thread = ${s2 - s1}")
    println(s"* thread = ${s1 - s0}")
    println(s"pmap     = ${s3 - s2}")
}

וכן בשביל המשחק כתבתי אותה גם בפייתון כדי שנוכל להשוות זמנים:

import time
import multiprocessing
import math

def isprime(n):
    for i in range(2, int(math.sqrt(n)) + 1):
        if n % i == 0:
            return False
    return True

if __name__ == "__main__":
    pool = multiprocessing.Pool(5)

    s0 = time.time_ns()
    print(sum(pool.map(isprime, range(10_000_000))))
    s1 = time.time_ns()
    print(s1 - s0)

והתוצאות לפחות אצלי על המחשב:

1 thread = 2025199791
* thread = 3463589500
pmap     = 674844834
python   = 14593734000

לא סיפרתי קודם אבל ניסיתי גם להחליף את ה Executor שמריץ את ה Threads לכזה שמשתמש ב Virtual Threads של Java אבל התוצאות לא עשו חשק לדבר על זה אז קברתי את הניסוי.

מה למדתי?

  1. חישובים בפייתון עובדים לאט. גם כשניסיתי להריץ את התוכנית בפייתון בלי multiprocessing זה לא עזר.

  2. אי אפשר סתם ליצור Thread לכל מספר. ככל שהמשימה מסובכת שווה להשקיע זמן ולחשוב איך לחלק אותה למספר תהליכונים.

  3. לא סתם בחרתי 10 מיליון. במספרים קטנים יותר (אפילו מיליון) כמעט לא היה הבדל בין תהליכון אחד למספר תהליכונים. חשוב להבין טוב את המשימה לפני שבונים פיתרון מבוסס תהליכונים כדי לא לעשות "אופטימיזציות" מיותרות.