//> using dep "org.scala-lang.modules::scala-parallel-collections:1.0.4"

import scala.collection.parallel.CollectionConverters.*
import scala.collection.parallel.immutable.ParSeq
import scala.reflect.ClassTag

/** Current time in milliseconds, as reported by `System.currentTimeMillis()`. */
inline def time() = System.currentTimeMillis()
/** Element type of the matrices multiplied in this file. */
type T = Long

/** A matrix stored as a nested array; the outer index is used as the row and the inner as the column. */
type Matrix[A] = Array[Array[A]]

/** Builds a `maxRow` x `maxCol` nested array whose entry at `(i, j)` is `f(i, j)`.
  *
  * @param maxRow number of rows (outer array length)
  * @param maxCol number of columns (length of every inner array)
  * @param f      generator invoked once per cell with its row and column index
  * @return a freshly allocated nested array filled by `f`
  */
def matrix[A: ClassTag](maxRow: Int, maxCol: Int, f: (Int,Int) => A): Array[Array[A]] =
  // The two-dimensional overload fills row by row, column by column.
  Array.tabulate(maxRow, maxCol)(f)

// Dimensions of the two operands: m1 is MaxRow x MaxCol, m2 is MaxCol x MaxRow.
val MaxRow = 500
val MaxCol = 64

// Deterministic operands: m1(i)(j) = 2*i + j and m2(i)(j) = i + 3*j.
val m1 = matrix[T](MaxRow,MaxCol, (i,j) => 2*i + j)
val m2 = matrix[T](MaxCol,MaxRow, (i,j) => i + 3*j)

/** Computes the single entry `(i, j)` of the product `m1 * m2`.
  *
  * The result is the sum over `k` of `m1(i)(k) * m2(k)(j)`, where `k` ranges
  * over the length of row `i` of `m1`.
  *
  * @param m1 left operand
  * @param m2 right operand
  * @param i  row index into `m1`
  * @param j  column index into `m2`
  * @return the dot product of row `i` of `m1` and column `j` of `m2`
  */
def prod(m1: Matrix[T], m2: Matrix[T])(i: Int, j: Int): T =
  val row = m1(i)
  @scala.annotation.tailrec
  def dot(k: Int, acc: T): T =
    if k >= row.length then acc
    else dot(k + 1, acc + row(k) * m2(k)(j))
  dot(0, 0L)

/** Multiplies `m1` by `m2` sequentially, one result row at a time.
  *
  * The result has `m1.length` rows; the column count is taken from the
  * length of the first row of `m2`.
  *
  * @param m1 left operand
  * @param m2 right operand
  * @return the product matrix
  */
def seqMultiply(m1: Matrix[T], m2: Matrix[T]): Matrix[T] =
  Array.tabulate(m1.length)(i =>
    // m2(0) is only touched while a row is being built, so an empty m1 never reads it.
    Array.tabulate(m2(0).length)(prod(m1, m2)(i, _)))

/** Multiplies `m1` by `m2`, mapping over the row indices with a parallel collection.
  *
  * Each result row is built independently by `prod`; the column count is
  * taken from the length of the first row of `m2`.
  *
  * @param m1 left operand
  * @param m2 right operand
  * @return the product matrix
  */
def parMultiply(m1: Matrix[T], m2: Matrix[T]): Matrix[T] =
  // Builds one complete row of the result.
  def rowOf(i: Int): Array[T] =
    Array.tabulate(m2(0).length)(j => prod(m1, m2)(i, j))
  m1.indices.par.map(i => rowOf(i)).toArray

/** Sums the diagonal entries `m(i)(i)` of `m`.
  *
  * Walks the diagonal from index 0 and stops at the end of the matrix or at
  * the first row that is too short to contain its diagonal entry.
  *
  * @param m matrix whose diagonal is summed
  * @return the sum of the diagonal entries visited
  */
def trace(m: Matrix[T]): T =
  m.indices.iterator
    .takeWhile(d => d < m(d).length)
    .foldLeft(0L)((acc, d) => acc + m(d)(d))

/** Evaluates `code` ten times and prints the mean elapsed time per evaluation.
  *
  * A dot is printed after each evaluation, followed by a summary line with
  * the average in milliseconds. The value produced by `code` is discarded.
  *
  * @param code by-name expression to time; it is re-evaluated on every iteration
  * @tparam R result type of `code`
  */
def measureTime[R](code: => R): Unit =
  val iters = 10
  var totalNanos: Long = 0
  var run = 0
  while run < iters do
    // nanoTime is the JVM's clock for elapsed-time measurement; the wall clock
    // can jump and only has millisecond granularity.
    val start = System.nanoTime()
    val _ = code // force evaluation of the by-name argument, result unused
    totalNanos += System.nanoTime() - start
    print(".")
    run += 1
  println(f"\nTime: ${totalNanos / 1e6 / iters}%2.2f ms average.\n")

/** Entry point: runs four passes, each timing the sequential and then the
  * parallel multiplication of `m1` by `m2` (followed by `trace` of the product).
  */
@main
def test: Unit =
  for i <- 1 to 4 do
    println(f"PASS: ${i}")
    println("First version:")
    measureTime(trace(seqMultiply(m1, m2)))
    println("Second version:")
    measureTime(trace(parMultiply(m1, m2)))
