package scalashop

import scalashop.image.*
import scalashop.image.GameOfLife.*

import java.io.File
import javax.imageio.ImageIO
import java.awt.image.BufferedImage

class GameOfLifeSuite extends ImageTestSuite:

  val deadColor = State.deadColor
  val aliveColor = State.aliveColor

  extension (img: Image)
    /** A single Game of Life generation of `img`, computed by [[Step]]. */
    def step: Image = Step(img)

    /** `n` generations of `img`. Each intermediate state is `build`-ed before
      * the next step is taken; `img` itself is returned unchanged when `n <= 0`.
      */
    def step(n: Int): Image =
      if n <= 0 then img
      else img.step.build.step(n - 1)

  /** Builds a board from text rows: `'#'` is a live cell, any other character
    * a dead one.
    *
    * The rows are concatenated in order to form the pixel array, so
    * `board("..", "#.")` passes `Array(dead, dead, alive, dead)` to `Image`.
    * NOTE(review): every board in this suite is square, so width/height order
    * does not matter here; for non-square boards this assumes the constructor
    * is `Image(width, height, pixels)` with a row-major array — confirm.
    */
  private def board(rows: String*): Image =
    require(rows.nonEmpty, "a board needs at least one row")
    val width = rows.head.length
    require(
      rows.forall(_.length == width),
      "all board rows must have the same length"
    )
    val cells = rows.mkString
    Image(
      width,
      rows.length,
      Array.tabulate(cells.length)(i =>
        if cells(i) == '#' then aliveColor else deadColor
      )
    )

  /** Loads a bitmap fixture from the test classpath as an [[Image]].
    *
    * Fails with an `IllegalStateException` naming `path` when the resource is
    * missing or `ImageIO` returns no image for it, rather than with an
    * unexplained error from deeper inside `ImageIO.read` / `Image.from`.
    */
  private def loadImage(path: String): Image =
    import scala.util.Using
    val stream = Option(getClass.getResourceAsStream(path))
      .getOrElse(throw IllegalStateException(s"missing test resource: $path"))
    // ImageIO.read(InputStream) does not close the stream it is given.
    val buffer = Using.resource(stream)(in => ImageIO.read(in))
    // ImageIO.read returns null when no registered reader can decode the data.
    if buffer == null then
      throw IllegalStateException(
        s"ImageIO could not decode test resource: $path"
      )
    Image.from(buffer)

  // Simple tests
  test("Functionality: Stepping blank state does nothing (2pts)"):
    val solidDead = pureColor(10, 10, deadColor)
    assertEquals(solidDead.step, solidDead)

  test(
    "Functionality: Stepping single pixel state leads to blank state (2pts)"
  ):
    // Only the centre pixel (index 40 of 9 * 9 = 81) is alive.
    val singleAlive = Image(
      9,
      9,
      Array.tabulate(81)(i => if i == 40 then aliveColor else deadColor)
    )
    val solidDead = pureColor(9, 9, deadColor)
    assertEquals(singleAlive.step, solidDead)

  // Basic functionality tests
  //
  // The 3x3 boards below wrap around, so every cell is adjacent to every other
  // cell: a cell's neighbour count is simply the number of *other* live cells.

  test("Functionality: Single neighbour cell vanishes (2pts)"):
    val singleAlive = board(
      "...",
      ".#.",
      "..."
    )
    val solidDead = pureColor(3, 3, deadColor)
    assertEquals(singleAlive.step, solidDead)

  test("Functionality: Two neighbour cell survives (2pts)"):
    val lShape = board(
      "...",
      ".##",
      "..#"
    )
    val nextState = lShape.step
    assertEquals(nextState(1, 1), aliveColor)
    assertEquals(nextState(1, 2), aliveColor)
    assertEquals(nextState(2, 2), aliveColor)

  test("Functionality: Three neighbour cell comes to life (2pts)"):
    val lShape = board(
      "...",
      ".##",
      "..#"
    )
    val nextState = lShape.step
    assertEquals(nextState(2, 1), aliveColor)

  test("Functionality: Three neighbour cells are stable (2pts)"):
    // A 2x2 block: each live cell has exactly three live neighbours.
    val block = board(
      "....",
      ".##.",
      ".##.",
      "...."
    )

    def valid(state: Image) =
      assertEquals(state(1, 1), aliveColor)
      assertEquals(state(1, 2), aliveColor)
      assertEquals(state(2, 2), aliveColor)
      assertEquals(state(2, 1), aliveColor)

    for i <- 0 to 10 do valid(block.step(i))

  test("Functionality: Four neighbour cell dies (2pts)"):
    val crowd = board(
      "...",
      "###",
      ".##"
    )
    val nextState = crowd.step

    assertEquals(nextState(1, 1), deadColor)

  test("Functionality: Five neighbour cell dies (2pts)"):
    val crowd = board(
      "#..",
      "###",
      ".##"
    )
    val nextState = crowd.step

    assertEquals(nextState(1, 1), deadColor)

  test("Functionality: Board state wraps around (single cell) (2pts)"):
    val singleCell = board("#")
    val nextState = singleCell.step
    // On a 1x1 board all eight neighbour positions wrap back to the cell
    // itself, so a live cell counts 8 live neighbours and dies of overcrowding.
    assertEquals(nextState(0, 0), deadColor)

  // Blinker tests

  /** A blinker: a vertical line of three live cells centred on a 5x5 board.
    *
    * {{{
    * . . . . .
    * . . # . .
    * . . # . .
    * . . # . .
    * . . . . .
    * }}}
    */
  val blinker = board(
    ".....",
    "..#..",
    "..#..",
    "..#..",
    "....."
  )

  /** The blinker after one step: the line has rotated to horizontal.
    *
    * {{{
    * . . . . .
    * . . . . .
    * . # # # .
    * . . . . .
    * . . . . .
    * }}}
    */
  val sideWaysBlinker = board(
    ".....",
    ".....",
    ".###.",
    ".....",
    "....."
  )
  test("Oscillation: Blinker rotates on step (3pts)"):
    assertEquals(blinker.step, sideWaysBlinker)
  test("Oscillation: Blinker is 2-periodic (3pts)"):
    assertEquals(blinker, blinker.step(2))

  // Pulsar tests

  /** A pulsar: an oscillator that returns to its initial state every 3 steps.
    *
    * Lazy so that a missing or unreadable fixture fails only the tests that use
    * it, instead of the whole suite at construction time.
    */
  lazy val pulsar: Image = loadImage("/scalashop/gameoflife/pulsar.bmp")

  test("Oscillation: Pulsar is not 1-periodic (3pts)"):
    assertNotEquals(pulsar, pulsar.step(1))
  test("Oscillation: Pulsar is not 2-periodic (3pts)"):
    assertNotEquals(pulsar, pulsar.step(2))
  test("Oscillation: Pulsar is 3-periodic (3pts)"):
    assertEquals(pulsar, pulsar.step(3))

  // Glider tests

  /** A glider: a pattern that moves one cell diagonally every 4 steps, so on a
    * wrapping board it eventually travels back to where it started.
    *
    * Lazy for the same reason as [[pulsar]].
    */
  lazy val glider: Image = loadImage("/scalashop/gameoflife/glider.bmp")

  test("Glider: Glider loops around correctly (36 steps) (7pts)"):
    val steps = 36
    // States after 1, 2, ..., `steps` generations, each derived from its
    // predecessor so the whole run costs `steps` generations, not O(steps^2).
    val gliderStates = Seq.iterate(glider.step(1), steps)(_.step(1))
    gliderStates.init.foreach(state => assertNotEquals(glider, state))
    assertEquals(glider, gliderStates.last)