//> using dep "org.scala-lang.modules::scala-parallel-collections:1.0.4"

import scala.collection.parallel.CollectionConverters.*
import scala.collection.mutable.ListBuffer
import scala.collection.parallel.mutable.ParArray

/** Evaluates `e1` and `e2` by mapping over a two-element parallel array, so the
  * two computations may run concurrently.
  *
  * @param e1 by-name expression producing the first component
  * @param e2 by-name expression producing the second component
  * @return the pair (value of `e1`, value of `e2`)
  */
def parallel[A](e1: => A, e2: => A): (A,A) =
  // Slot 1 selects e1; the only other slot selects e2.
  val results = ParArray(1, 2).map {
    case 1 => e1
    case _ => e2
  }
  (results(0), results(1))

/** Variant of `parallel` that wraps each by-name argument in a thunk and forces
  * the thunks through a parallel `map`.
  *
  * @param e1 by-name expression producing the first component
  * @param e2 by-name expression producing the second component
  * @return the pair (value of `e1`, value of `e2`)
  */
def parallel1[A](e1: => A, e2: => A): (A,A) =
  val thunks = ParArray[() => A](() => e1, () => e2)
  val results = thunks.map(thunk => thunk.apply())
  (results(0), results(1))

// Element type of the arrays handled in this file (`Norm` takes `Array[T]`, `arr` is an `Array[T]`).
type T = Long

/** p-norm of an array, in a sequential and a divide-and-conquer parallel variant.
  *
  * Every intermediate value is an `Int`: each power is converted with `toInt`
  * and the running sums use `Int` addition.
  */
object Norm:
  /** Computes `|x|^p` as `exp(p * log|x|)` and converts the result with `toInt`. */
  def power(x: T, p: Double): Int =
    math.exp(p * math.log(math.abs(x))).toInt

  /** Sequentially sums `power(xs(i), p)` for every index `i` in `[from, until)`.
    *
    * @return the `Int` sum; 0 when the range is empty
    */
  def sumSegment(xs: Array[T], p: Double, from: Int, until: Int): Int =
    var i = from
    var s = 0
    while i < until do
      s += power(xs(i), p)
      i += 1
    s

  /** Sequential norm: the `1/p` power of the sum of p-th powers over the whole array. */
  def normSeq(xs: Array[T], p: Double): Int =
    power(sumSegment(xs, p, 0, xs.length), 1.0 / p)

  /** Sums `power(xs(i), p)` over `[from, until)`, splitting the range in half and
    * handing both halves to `parallel` until a segment is shorter than `threshold`.
    *
    * @param threshold segments with fewer elements than this are summed sequentially
    * @return the same sum `sumSegment` produces for the range
    */
  def segmentPar(xs: Array[T], p: Double, from: Int, until: Int, threshold: Int): Int =
    val length = until - from
    // A segment of fewer than two elements cannot be split: its "half" would be the
    // segment itself, so without this guard threshold <= 1 recurses forever.
    if length < threshold || length < 2 then
      sumSegment(xs, p, from, until)
    else
      val mid = from + length / 2
      val (leftSum, rightSum) =
        parallel(segmentPar(xs, p, from, mid, threshold),
                 segmentPar(xs, p, mid, until, threshold))
      leftSum + rightSum
  end segmentPar

  /** Parallel norm: the `1/p` power of the sum computed by `segmentPar` over the whole array. */
  def norm(xs: Array[T], p: Double, threshold: Int): Int =
    power(segmentPar(xs, p, 0, xs.length, threshold), 1.0 / p)
end Norm

// Benchmark input: one element per i in 2_500_000, 2_499_997, ..., 1.
// The product i * i is evaluated in Int arithmetic before the Long factor widens it.
val arr: Array[T] =
  val indices = 2_500_000 to 1 by -3
  indices.iterator.map(i => i * i * 230492384L).toArray

/** Current value of the system clock in milliseconds (`System.currentTimeMillis`). */
inline def time(): Long =
  System.currentTimeMillis()

/** Evaluates `code` 20 times, printing a dot after each run, then prints the
  * average duration of one run in milliseconds.
  *
  * Durations are taken with `System.nanoTime`, which is meant for elapsed-time
  * measurement, so sub-millisecond runs are not rounded to whole milliseconds.
  *
  * @param code the expression to time; its result is discarded
  */
inline def measureTime[R](inline code: => R): Unit =
  val iters = 20
  var totalNanos: Long = 0
  for _ <- 1 to iters do
    val start = System.nanoTime()
    val _ = code
    totalNanos += System.nanoTime() - start
    print(".")
  val averageMs = totalNanos.toDouble / iters.toDouble / 1_000_000.0
  println(f"\nTime: $averageMs%2.2f ms average.\n")

/** Prints the given threshold and then times `Norm.norm` on `arr` with p = 2.0
  * and that threshold.
  */
inline def show(threshold: Int): Unit =
  val header = f"threshold $threshold:"
  println(header)
  measureTime(Norm.norm(arr, p = 2.0, threshold = threshold))


/** Entry point: runs three passes, each timing the norm for a decreasing list of
  * thresholds. The first threshold is twice the array length.
  */
@main
def test: Unit =
  for pass <- 1 to 3 do
    println(f"PASS: ${pass}")
    val thresholds = List(2 * arr.length, 120_000, 60_000, 20_000, 5_000, 3_000)
    thresholds.foreach(t => show(t))
    
/* on laraquad3 20-core xeon:

threshold 1666668:
....................
Time: 372.55 ms average.

threshold 120000:
....................
Time: 59.70 ms average.

threshold 60000:
....................
Time: 32.30 ms average.

threshold 20000:
....................
Time: 19.10 ms average.

threshold 5000:
....................
Time: 17.00 ms average.

threshold 3000:
....................
Time: 16.35 ms average.

*/