import numpy as np
from scipy.stats import multivariate_normal

from utils import HMM

# Gaussian definitions
# Per-vowel parameters of a 2-D Gaussian emission density.
# The numbers look like the first two formant frequencies (F1, F2) of each
# vowel in Hz -- presumably; confirm against the data source.
# Key "0" is a placeholder (zero mean, all-zero covariance) that the HMM
# definitions pair with their "I" and "F" states.
mean = {
    "0": [0, 0],
    "a": [730, 1090],
    "e": [530, 1840],
    "i": [270, 2290],
    "o": [570, 840],
    "y": [440, 1020],
}

# 2x2 covariance matrix for each key of ``mean``.
cov = {
    "0": [[0, 0], [0, 0]],
    "a": [[1625, 5300], [5300, 53300]],
    "e": [[15025, 7750], [7750, 36725]],
    "i": [[2525, 1200], [1200, 36125]],
    "o": [[2000, 3600], [3600, 20000]],
    "y": [[8000, 8400], [8400, 18500]],
}

# Fail fast if the two tables drift apart, instead of a late KeyError (or a
# silently missing vowel) when building GAUSSIANS.
if mean.keys() != cov.keys():
    raise ValueError(
        "mean and cov must define the same vowels; mismatched keys: "
        f"{sorted(mean.keys() ^ cov.keys())}"
    )

# One frozen scipy multivariate normal per vowel, keyed (and ordered) like
# ``mean``. allow_singular=True is required for the all-zero "0" covariance.
GAUSSIANS = {
    vowel: multivariate_normal(mean=mu, cov=cov[vowel], allow_singular=True)
    for vowel, mu in mean.items()
}

# HMM definitions
# Strict left-to-right topology shared by hmm4, hmm5 and hmm6:
# I -> s1 -> s2 -> s3 -> F. Each emitting state s_k stays put with
# probability 0.95 and advances with probability 0.05; "I" always moves to
# s1 and "F" is absorbing.
transitions = [
    [0.0, 1.0, 0.0, 0.0, 0.0],
    [0.0, 0.95, 0.05, 0.0, 0.0],
    [0.0, 0.0, 0.95, 0.05, 0.0],
    [0.0, 0.0, 0.0, 0.95, 0.05],
    [0.0, 0.0, 0.0, 0.0, 1.0],
]


def _make_hmm(vowels, transition_matrix):
    """Build an ``HMM`` with states ``"I"``, one state per vowel, then ``"F"``.

    ``labels`` and ``gaussians`` are both derived from *vowels*, so a label
    cannot be paired with the wrong emission density by a copy-paste slip.
    The initial and final states get the degenerate ``GAUSSIANS["0"]``
    placeholder, exactly as the hand-written definitions did.

    Args:
        vowels: Keys of ``GAUSSIANS`` for the emitting states, in state
            order, e.g. ``("a", "i", "y")`` gives labels
            ``["I", "/a/", "/i/", "/y/", "F"]``.
        transition_matrix: Square list-of-rows matrix indexed like the
            labels, i.e. of size ``len(vowels) + 2``.

    Returns:
        The constructed ``HMM``.

    Raises:
        ValueError: If *transition_matrix* is not ``len(vowels) + 2`` square.
        KeyError: If a vowel has no entry in ``GAUSSIANS``.
    """
    n_states = len(vowels) + 2
    if len(transition_matrix) != n_states or any(
        len(row) != n_states for row in transition_matrix
    ):
        raise ValueError(
            f"transition matrix must be {n_states}x{n_states} "
            f"for vowels {vowels!r}"
        )
    return HMM(
        labels=["I", *(f"/{v}/" for v in vowels), "F"],
        # Give every model its own copy of the rows: the module-level
        # ``transitions`` template is shared by several models, and passing
        # the same list object would alias their matrices if HMM keeps a
        # reference to it.
        transitions=[list(row) for row in transition_matrix],
        gaussians=[
            GAUSSIANS["0"],
            *(GAUSSIANS[v] for v in vowels),
            GAUSSIANS["0"],
        ],
    )


# hmm1: fully connected over /a/, /i/, /y/ with weak self-loops (0.3-0.4),
# so the state changes often; only /y/ can exit to "F" (probability 0.1).
hmm1 = _make_hmm(
    ("a", "i", "y"),
    [
        [0.0, 1.0, 0.0, 0.0, 0.0],
        [0.0, 0.4, 0.3, 0.3, 0.0],
        [0.0, 0.3, 0.4, 0.3, 0.0],
        [0.0, 0.3, 0.3, 0.3, 0.1],
        [0.0, 0.0, 0.0, 0.0, 1.0],
    ],
)

# hmm2: same connectivity as hmm1 but with strong self-loops (0.95), so each
# vowel tends to persist; again only /y/ can exit to "F" (probability 0.01).
hmm2 = _make_hmm(
    ("a", "i", "y"),
    [
        [0.0, 1.0, 0.0, 0.0, 0.0],
        [0.0, 0.95, 0.025, 0.025, 0.0],
        [0.0, 0.025, 0.95, 0.025, 0.0],
        [0.0, 0.02, 0.02, 0.95, 0.01],
        [0.0, 0.0, 0.0, 0.0, 1.0],
    ],
)

# hmm3: left-to-right /a/ -> /i/ -> /y/ with 0.5 stay / 0.5 advance.
hmm3 = _make_hmm(
    ("a", "i", "y"),
    [
        [0.0, 1.0, 0.0, 0.0, 0.0],
        [0.0, 0.5, 0.5, 0.0, 0.0],
        [0.0, 0.0, 0.5, 0.5, 0.0],
        [0.0, 0.0, 0.0, 0.5, 0.5],
        [0.0, 0.0, 0.0, 0.0, 1.0],
    ],
)

# hmm4-hmm6: the shared left-to-right ``transitions`` with different vowel
# orders.
hmm4 = _make_hmm(("a", "i", "y"), transitions)
hmm5 = _make_hmm(("y", "i", "a"), transitions)
hmm6 = _make_hmm(("a", "i", "e"), transitions)

# Observation sequence definitions
# Observation sequence X1: 43 observations, one 2-component vector per row.
# ST1 holds the matching state path, with len(ST1) == len(X1) + 2 because it
# is bracketed by the "I" (0) and "F" (4) states. The two columns are
# presumably (F1, F2) in the same units as ``mean`` -- TODO confirm.
X1 = np.array(
    [
        [683.3237129, 1047.38823556],
        [379.80698404, 897.10445171],
        [777.19174766, 1442.30820159],
        [362.41160186, 888.0624029],
        [734.63870443, 1237.56882299],
        [676.07679559, 796.61364498],
        [427.79868516, 1006.71700095],
        [455.00400837, 961.9980049],
        [716.06659034, 987.01370096],
        [757.92648752, 1441.49841675],
        [337.5868643, 820.58353803],
        [728.35707034, 891.5672656],
        [718.70516296, 946.38541218],
        [299.5343153, 2397.72485808],
        [681.7843396, 613.43151388],
        [284.59274009, 2593.32153737],
        [734.03636262, 746.87566093],
        [406.44077205, 928.23813143],
        [212.60057974, 2050.3500739],
        [282.29383591, 2209.87646423],
        [270.95069168, 2220.38302827],
        [248.60613456, 2358.1576441],
        [316.97494743, 2694.95184691],
        [662.0423761, 827.65371194],
        [319.94271219, 2221.12725678],
        [342.12045539, 1005.7699173],
        [351.20159785, 2432.26605092],
        [370.39194953, 2395.64997096],
        [788.46228017, 1049.1338716],
        [685.88525125, 1018.64601078],
        [490.57124459, 922.50265894],
        [466.76544573, 967.22738897],
        [292.06003332, 2083.78623716],
        [264.27671306, 1909.56292186],
        [419.03488706, 1060.34272483],
        [743.49594881, 1367.98770602],
        [226.80996152, 1845.05348723],
        [708.00568909, 1308.16639855],
        [266.23491864, 2361.48636836],
        [306.6078529, 1891.83992787],
        [360.25361762, 902.65361868],
        [690.72356006, 1027.75842594],
        [560.14350856, 1075.84349025],
    ]
)

# Observation sequence X2: 9 observations, one 2-component vector per row.
# ST2 holds the matching state path, with len(ST2) == len(X2) + 2 because it
# is bracketed by the "I" (0) and "F" (4) states. The two columns are
# presumably (F1, F2) in the same units as ``mean`` -- TODO confirm.
X2 = np.array(
    [
        [653.75009086, 1149.87533574],
        [722.27102122, 956.38503025],
        [685.47199435, 828.35362072],
        [735.91631046, 1176.29462627],
        [250.98245645, 2193.36752705],
        [267.32726795, 2256.61564403],
        [415.1207097, 1149.74062386],
        [526.51040914, 1204.15870827],
        [355.43295008, 854.6576449],
    ]
)

# Observation sequence X3: 62 observations, one 2-component vector per row.
# ST3 holds the matching state path, with len(ST3) == len(X3) + 2 because it
# is bracketed by the "I" (0) and "F" (4) states. The two columns are
# presumably (F1, F2) in the same units as ``mean`` -- TODO confirm.
X3 = np.array(
    [
        [447.71647221, 1012.69994713],
        [230.32366919, 918.24771014],
        [359.27609199, 950.43552672],
        [367.46386772, 856.32471035],
        [495.99448594, 1074.37037297],
        [447.61314587, 1121.93740614],
        [521.93694956, 939.0693422],
        [393.50040762, 874.58115509],
        [520.0734155, 1115.05508129],
        [376.25074963, 926.61233093],
        [416.76285805, 1110.22082526],
        [414.12499379, 990.5807],
        [351.83156481, 1010.30396875],
        [666.81000657, 1165.27159542],
        [200.09377825, 2329.77527091],
        [198.95422158, 2311.05002975],
        [188.94231396, 2418.18044855],
        [334.84832291, 2590.57675652],
        [335.39104444, 2152.88062224],
        [296.68862658, 2465.32191496],
        [217.30452868, 2322.96433469],
        [301.70222517, 2049.05874861],
        [226.24278127, 2433.08834598],
        [272.30588057, 2389.62083847],
        [340.62378642, 2388.40883177],
        [234.40670408, 2428.94453134],
        [257.57446395, 2134.08221944],
        [375.10402214, 2315.69930709],
        [226.31394043, 2406.10389282],
        [357.06957058, 2411.20063271],
        [310.03391303, 2313.94899525],
        [266.42560918, 2454.67952293],
        [260.63942601, 2303.36635759],
        [240.24083585, 2157.12479348],
        [250.63888774, 2062.8723285],
        [282.94120707, 2310.70870217],
        [256.71174458, 1862.75432606],
        [245.70738939, 2334.80221946],
        [289.35581883, 2325.89334147],
        [260.11928006, 1985.1847084],
        [286.31292787, 2266.78124543],
        [297.25324311, 2406.18858154],
        [267.32491591, 2351.37533442],
        [251.6305925, 2227.05893514],
        [246.11443908, 2106.94765214],
        [287.8550032, 2525.86436158],
        [238.56573317, 2191.81524188],
        [306.34066824, 2301.11902979],
        [338.55867409, 1974.50635414],
        [678.5449015, 893.67241741],
        [814.80715069, 1232.79930224],
        [715.06813602, 1088.02374607],
        [738.50843366, 1308.5910947],
        [746.99957726, 1278.49734924],
        [749.95652723, 1050.85459222],
        [728.57967507, 1117.15252768],
        [697.24842139, 1019.13105607],
        [706.40350327, 852.8599759],
        [740.17944166, 848.45532334],
        [681.73540975, 806.80076003],
        [753.68242518, 954.8181818],
        [685.861949, 1242.52045479],
    ]
)

# Observation sequence X4: 68 observations, one 2-component vector per row.
# ST4 holds the matching state path, with len(ST4) == len(X4) + 2 because it
# is bracketed by the "I" (0) and "F" (4) states. The two columns are
# presumably (F1, F2) in the same units as ``mean`` -- TODO confirm.
X4 = np.array(
    [
        [793.91077568, 1273.45993886],
        [717.66259491, 895.08049595],
        [728.26353214, 1357.76645812],
        [758.11590213, 973.16933447],
        [762.55818412, 1156.06371439],
        [247.16679431, 2130.04691206],
        [275.24885693, 1945.84492303],
        [201.42575831, 2665.11122236],
        [279.54544589, 2789.13287924],
        [258.70053502, 2321.70594917],
        [315.26758885, 2260.69096563],
        [251.07456453, 2559.46425041],
        [230.53183664, 2439.08235325],
        [248.18583845, 2428.71451213],
        [308.24541166, 2136.5583552],
        [301.81396375, 2488.9061707],
        [219.20937042, 2353.54149047],
        [262.02249128, 2056.56530751],
        [255.4428801, 2258.14180975],
        [206.05019531, 1913.28170347],
        [224.49774348, 2284.39260546],
        [190.14089188, 2521.36722876],
        [229.65870723, 1894.8521183],
        [282.98361829, 2573.20126355],
        [273.4600108, 2321.04135203],
        [269.52398502, 2322.84508678],
        [315.01985983, 2329.04677829],
        [310.30494388, 2390.30304213],
        [277.99452767, 2343.87918067],
        [216.13718507, 2105.70915026],
        [271.11885059, 2357.26031485],
        [229.32511379, 2439.82754997],
        [289.13508747, 2809.56004097],
        [284.01964789, 2043.75162168],
        [291.6878431, 2386.82473139],
        [245.63533089, 2332.08126123],
        [240.21259837, 2290.95732995],
        [296.92765665, 2518.40692599],
        [295.33960265, 1791.45390611],
        [288.40195799, 2448.3847471],
        [273.88606618, 2453.06263837],
        [315.26913863, 2542.00355314],
        [336.36794498, 2478.28460929],
        [182.3206354, 2093.55201943],
        [455.82408124, 920.6763391],
        [418.93369874, 1085.01416092],
        [484.1559306, 1159.0960261],
        [351.89642879, 1028.69436264],
        [251.57547974, 767.09458766],
        [457.88795038, 1013.50392598],
        [337.87121824, 936.75642331],
        [357.54319899, 880.17328907],
        [400.30155094, 836.28780242],
        [384.07731689, 769.27937418],
        [496.05801236, 1085.8108051],
        [472.9778014, 960.37548042],
        [451.51264298, 907.77680822],
        [427.05807083, 779.98887133],
        [433.44406475, 936.05651006],
        [384.20361859, 1008.80439882],
        [451.88297851, 1132.72516821],
        [471.96076658, 1147.46449001],
        [404.73018515, 1009.94938853],
        [516.54041484, 1181.0660933],
        [478.12911071, 1007.42036214],
        [302.82339635, 875.9661708],
        [426.62590622, 903.56355338],
        [507.59658368, 1175.57691559],
    ]
)

# Observation sequence X5: 49 observations, one 2-component vector per row.
# ST5 holds the matching state path, with len(ST5) == len(X5) + 2 because it
# is bracketed by the "I" (0) and "F" (4) states. The two columns are
# presumably (F1, F2) in the same units as ``mean`` -- TODO confirm.
X5 = np.array(
    [
        [718.81501642, 949.44953301],
        [678.98145422, 1165.49309562],
        [722.06205378, 1170.010651],
        [722.1318535, 979.2595887],
        [752.92273853, 1260.10054724],
        [738.04763959, 1123.83976384],
        [720.91260277, 1370.36409674],
        [720.42308575, 521.19555915],
        [664.70318001, 1137.08492805],
        [675.24775259, 789.69361645],
        [804.36633118, 934.81826386],
        [693.24725473, 1135.27035896],
        [681.57518614, 812.37688431],
        [668.70189139, 925.61803273],
        [718.30484311, 757.03117948],
        [726.71364834, 1085.36365292],
        [737.62119247, 1309.93436438],
        [810.91284877, 1774.45227797],
        [208.3563265, 2313.60888514],
        [306.55125801, 2184.29790689],
        [304.74018734, 2610.26004831],
        [314.42290921, 2461.91992252],
        [260.2757021, 2345.30739558],
        [243.17282738, 2257.8766794],
        [270.06093398, 2478.96985782],
        [291.59488233, 2287.28538359],
        [249.79131639, 2242.04518017],
        [215.52731768, 2427.92938982],
        [323.09084322, 2101.36444224],
        [352.60861261, 2204.61517889],
        [280.40598406, 2358.13522611],
        [256.95839221, 2261.07794995],
        [211.43955532, 2279.972853],
        [310.24321123, 2647.60573345],
        [282.13371248, 2595.92833444],
        [244.46874644, 2041.17492503],
        [230.97716953, 2179.03333672],
        [294.40584978, 2039.36879681],
        [258.33052343, 2079.84963269],
        [133.60128671, 2324.11732425],
        [232.76940927, 2501.42831245],
        [260.87627186, 2507.83329731],
        [708.70092699, 2270.25798634],
        [422.74796896, 1606.55571898],
        [633.76945356, 1786.85039764],
        [495.93100915, 1921.97562458],
        [592.78042203, 1923.50339185],
        [683.69298039, 2030.28361457],
        [398.6281814, 1886.91150344],
    ]
)

# Observation sequence X6: 108 observations, one 2-component vector per row.
# ST6 holds the matching state path, with len(ST6) == len(X6) + 2 because it
# is bracketed by the "I" (0) and "F" (4) states. The two columns are
# presumably (F1, F2) in the same units as ``mean`` -- TODO confirm.
X6 = np.array(
    [
        [723.74774402, 955.38150018],
        [714.46091167, 1039.52856822],
        [802.32176356, 1191.012574],
        [784.03336945, 1329.64059618],
        [759.96534837, 1142.83601059],
        [674.16049406, 1012.89230034],
        [750.23863661, 1029.96693911],
        [241.99662666, 1918.74030655],
        [328.15744446, 2447.69686517],
        [257.85887869, 2071.2295713],
        [305.13188122, 2148.68599544],
        [343.2827616, 2458.0309559],
        [172.34090761, 1905.75204965],
        [452.46285034, 2244.55306136],
        [291.22821988, 2366.16489438],
        [242.72739661, 2254.27351644],
        [351.85146794, 2416.0065515],
        [309.61416632, 2469.78319061],
        [328.4394445, 2219.44242499],
        [234.8773876, 2224.77631998],
        [245.14648336, 2243.53920208],
        [300.5855591, 2175.16912779],
        [215.10291221, 2340.16607845],
        [309.42411966, 2618.72691865],
        [839.86698498, 1290.65677567],
        [698.06696931, 989.45586353],
        [811.76248632, 1300.0073246],
        [734.18005469, 1308.31484963],
        [796.16105284, 1380.66227913],
        [710.75672739, 952.43198513],
        [725.12091222, 1236.18827168],
        [714.70261946, 933.10320891],
        [812.67230976, 1178.83787779],
        [771.99558126, 1505.18034588],
        [752.0777443, 1288.48705215],
        [356.81812213, 885.89453273],
        [488.44321823, 949.99170824],
        [465.91582534, 1133.86135905],
        [286.11137241, 820.33859561],
        [437.42903616, 832.7435448],
        [483.63543813, 1066.45048504],
        [458.36129897, 943.47369332],
        [425.02724601, 1026.88966069],
        [624.53961319, 1140.47908616],
        [488.8586898, 1036.88985517],
        [545.31701114, 1089.94677472],
        [352.01163681, 940.92274906],
        [454.96149064, 1008.92592039],
        [488.25203376, 1063.06633365],
        [487.46771993, 1018.67830465],
        [510.88639733, 1162.03110143],
        [407.26508312, 852.40597532],
        [438.69292735, 1034.94066003],
        [479.87746134, 1127.12259656],
        [229.57239465, 2394.06185018],
        [241.88243495, 2219.97827842],
        [299.47200455, 2163.4295396],
        [233.3017185, 2380.10491977],
        [234.32097882, 1937.06335951],
        [203.34587481, 2498.07966961],
        [317.45660829, 2432.63658546],
        [314.05287926, 2338.04691989],
        [157.1006496, 2036.86413126],
        [247.28747269, 2469.78477254],
        [344.78381248, 2168.0781612],
        [277.65580883, 2189.10530795],
        [281.08810389, 2266.12561174],
        [109.45306338, 2164.61492613],
        [334.74257149, 2326.40662572],
        [324.23280436, 2285.74699667],
        [365.48569518, 2352.74064124],
        [198.84821088, 2070.44941546],
        [331.03471836, 2198.71087604],
        [282.60606843, 2443.43306108],
        [248.64617869, 2659.37941087],
        [258.17966943, 2364.72478513],
        [313.67894716, 2311.77457053],
        [300.27507706, 2324.10620221],
        [340.81183553, 2451.00232701],
        [224.50086075, 2577.17511095],
        [271.2981141, 2086.48842752],
        [291.88673001, 2367.83658695],
        [227.84059319, 2039.06760255],
        [253.74510992, 2042.91623216],
        [342.90211384, 2107.1946109],
        [234.49643817, 2045.61455336],
        [225.23589646, 2532.63671559],
        [722.48837537, 1267.88731375],
        [804.42915078, 1260.35355876],
        [708.05829522, 1054.10596132],
        [740.17452045, 1578.05495679],
        [720.06402037, 1157.0242264],
        [152.15789935, 2325.65846279],
        [267.9909854, 2229.43348364],
        [287.17419339, 2326.33733889],
        [228.45167284, 2148.92038825],
        [267.15427192, 2150.39438923],
        [243.94932152, 2193.80386329],
        [232.53552133, 1947.58858946],
        [305.23809326, 882.23770089],
        [560.03152492, 1115.45078395],
        [502.90799168, 1214.90758316],
        [488.40467997, 1103.79889243],
        [419.14871113, 1079.79455687],
        [383.13060915, 1051.74040967],
        [391.06710437, 1109.6329007],
        [301.32947493, 828.56001445],
        [415.8748568, 1042.84415187],
    ]
)

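# Reference state sequences, one per observation sequence: STk has
# len(Xk) + 2 entries, starting with state 0 ("I") and ending with state 4
# ("F"); the entries in between give the state index of each row of Xk.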
# State path for X1: "I" (0), then one state index per row of X1, then
# "F" (4). The emitting part switches state frequently, so it is kept as an
# explicit list, grouped ten entries per line.
ST1 = np.array(
    [0]
    + [
        1, 3, 1, 3, 1, 1, 3, 3, 1, 1,
        3, 1, 1, 2, 1, 2, 1, 3, 2, 2,
        2, 2, 2, 1, 2, 3, 2, 2, 1, 1,
        3, 3, 2, 2, 3, 1, 2, 1, 2, 2,
        3, 1, 3,
    ]
    + [4]
)

# State path for X2 as runs: state 0 ("I") once, state 1 x4, state 2 x2,
# state 3 x3, state 4 ("F") once.
ST2 = np.repeat([0, 1, 2, 3, 4], [1, 4, 2, 3, 1])

# State path for X3 as runs: state 0 ("I") once, state 1 x14, state 2 x35,
# state 3 x13, state 4 ("F") once.
ST3 = np.repeat([0, 1, 2, 3, 4], [1, 14, 35, 13, 1])

# State path for X4 as runs: state 0 ("I") once, state 1 x5, state 2 x39,
# state 3 x24, state 4 ("F") once.
ST4 = np.repeat([0, 1, 2, 3, 4], [1, 5, 39, 24, 1])

# State path for X5 as runs: state 0 ("I") once, state 1 x18, state 2 x24,
# state 3 x7, state 4 ("F") once.
ST5 = np.repeat([0, 1, 2, 3, 4], [1, 18, 24, 7, 1])

# State path for X6 as runs of (state, length); unlike ST2-ST5 it revisits
# states: 1 -> 2 -> 1 -> 3 -> 2 -> 1 -> 2 -> 3 between "I" (0) and "F" (4).
ST6 = np.repeat(
    [0, 1, 2, 1, 3, 2, 1, 2, 3, 4],  # state of each run
    [1, 7, 17, 10, 20, 33, 5, 7, 9, 1],  # length of each run
)
