# Copyright John Lambert



import numpy as np 
import matplotlib.pyplot as plt 
import math
import pdb
import scipy.linalg
import itertools    

def generate_whitened_noise(cov,dim):
	"""
	Draw a single zero-mean Gaussian noise vector shaped by `cov`.

	A standard-normal vector z of length `dim` is multiplied by the matrix
	square root of `cov`. For a symmetric positive semi-definite `cov` the
	resulting sample has covariance `cov`.

	Args:
	-	cov: square matrix of size (dim x dim)
	-	dim: integer length of the noise vector to draw

	Returns:
	-	noise sample with singleton dimensions squeezed out
	"""
	cov_sqrt = scipy.linalg.sqrtm(cov)
	standard_normal = np.random.randn(dim)
	colored = np.matmul(cov_sqrt, standard_normal)
	return np.squeeze(colored)


class MultiHypothesisKalmanFilter():
	"""
	Multi-hypothesis Kalman filter for 3 targets with unknown
	measurement-to-target association.

	Each target contributes 2 entries to a stacked 6-vector state. The filter
	keeps a fixed number of weighted Gaussian hypotheses over the stacked
	state; at every step each hypothesis is scored against every possible
	measurement permutation and only the best-scoring combinations are kept.
	"""

	def __init__(self, num_timesteps):
		"""
		State and measurement dimension = 2 per target (6 for the stacked vector).
		Args:
		-	num_timesteps: integer denoting number of timesteps for which to run the tracker

		Returns:
		-	None
		"""
		self.num_timesteps = num_timesteps
		self.num_targets = 3
		self.R = 1 * np.eye(self.num_targets*2) # measurement noise: 3 block diagonal identity mats I_{2 x 2}
		self.Q = 0.1 * np.eye(self.num_targets*2) # process noise
		self.dt = 1.

		# TODO: make these uniform later
		self.weights = np.array([0.5,0.25,0.25])
		self.true_states = [ np.array([0, 0, 10, 0, 20, 0]) ]
		self.measurements = []

		# true measurement association: measurement slot 0 observes target 0,
		# slot 1 observes target 2 and slot 2 observes target 1
		self.Ctrue = np.zeros((6 ,6))
		self.Ctrue[0:2,0:2] = np.eye(2)
		self.Ctrue[2:4,4:6] = np.eye(2)
		self.Ctrue[4:6,2:4] = np.eye(2)

		self.possible_permutations = self.numpy_perms(self.num_targets)


	def generate_sequence(self):
		"""
		Run simulation of the dynamics and measurement models
		for some number of timesteps before the filter is actually executed.

		The simulated values are appended to self.true_states and
		self.measurements, so calling this twice extends the same lists.

		Returns:
		-	true_states: Python list of state vectors, x_0 ... x_{num_timesteps-1}
		-	measurements: Python list of noisy measurement vectors, one per state
		"""
		x_t = self.true_states[0] # initial state
		self.measurements.append(self.h(x_t, add_noise=True))

		for t in range(self.num_timesteps-1):
			# x_{t+1} is driven by the control evaluated at time t
			x_t = self.dynamics(x_t, t, add_noise = True)
			self.true_states.append(x_t)
			self.measurements.append(self.h(x_t, add_noise=True))

		return self.true_states, self.measurements


	def h(self, x_t, add_noise=True):
		"""
		Measurement model for generating measurements y = Ctrue x (+ noise).
		Args:
		-	x_t: stacked state vector
		-	add_noise: if True, add a Gaussian noise sample with covariance self.R
		Returns:
		-	y_t: stacked measurement vector (a new array)
		"""
		y_t = self.Ctrue.dot( x_t )
		if add_noise:
			y_t += generate_whitened_noise(self.R, self.R.shape[0])
		return y_t


	def get_initial_state_estimate(self):
		"""
		Build the initial Gaussian for every hypothesis: the true initial state
		perturbed by uniform noise in [0, 0.5), with covariance 100 * I.

		Returns:
		-	mu: array of shape (num_targets, state_dim), one mean per hypothesis
		-	cov: array of shape (num_targets, state_dim, state_dim)
		"""
		state_dim = 2 * self.num_targets
		mu = np.zeros((self.num_targets, state_dim))
		cov = np.zeros((self.num_targets, state_dim, state_dim))
		# loop over the hypotheses (which are true targets initially)
		for i in range(self.num_targets):
			mu[i] = 0.5 * np.random.rand(state_dim) + self.true_states[0]
			cov[i] = 100 * np.eye(state_dim)

		return mu, cov


	def numpy_perms(self, n):
		"""
		Python equivalent of Matlab's "perms([1,2,3])"
		Args:
		-	n: integer representing total number of indices that can be permuted (e.g. num targets)
		Returns:
		-	int64 array of shape (n!, n); each row is a permutation of 0..n-1,
			listed in reverse lexicographic order
		"""
		# build the array in one shot instead of re-allocating it once per row
		permut_array = np.array(list(itertools.permutations(range(n))), dtype=np.int64)
		return permut_array.reshape(-1, n)[::-1]


	def generate_permutation_matrix(self, perm_vec):
		"""
		Args:
		-	perm_vec: sequence of integer target indices; perm_vec[i] is the target
			observed in measurement slot i
		Returns:
		-	permutation matrix of shape (2n x 2n), n = len(perm_vec), built from
			2x2 identity blocks
		"""
		n = len(perm_vec)
		Ci = np.zeros((2*n, 2*n))
		for slot, target in enumerate(perm_vec):
			Ci[2*slot : 2*(slot+1), 2*target : 2*(target+1) ] = np.eye(2)
		return Ci


	def generate_permutation_matrices(self):
		"""
		Returns:
		-	perm_mats: Python list of 6x6 permutation matrices, with dtype 'float64',
			in the same order as self.possible_permutations
		"""
		return [self.generate_permutation_matrix(perm_vec) for perm_vec in self.possible_permutations]



	def dynamics(self, x_t, t, add_noise):
		"""
		x_{t+1}^i = x_t^i + u_t^i + w_t^i

		Q is our process noise

		Args:
		-	x_t: stacked state vector at time t (not modified)
		-	t: time index at which the control input u_t is evaluated
		-	add_noise: if True, add a Gaussian noise sample with covariance self.Q
		Returns:
		-	x_tplus1: stacked state vector at time t+1 (a new array)
		"""
		# time-varying control input, one (dx, dy) pair per target
		u_t = np.hstack([[	np.cos(0.1*t), 	np.sin(0.1*t)	],
						[	-np.cos(0.2*t), np.sin(0.2*t)	],
						[	np.cos(0.1*t), 	np.sin(0.2*t)	] ])
		x_tplus1 = x_t + u_t
		# add white Gaussian noise
		if add_noise:
			x_tplus1 += generate_whitened_noise(self.Q, self.Q.shape[0])
		return x_tplus1


	def gauss_pdf(self, x, mu, cov):
		"""
		Evaluate a multivariate Gaussian density.
		Args:
		-	x: the point at which we evaluate
		-	mu: the mean of the Gaussian
		-	cov: the covariance of the Gaussian
		Returns:
		-	scalar density N(x; mu, cov)
		"""
		n = cov.shape[0]
		diff = x - mu
		exponent = -0.5 * diff.dot(scipy.linalg.inv(cov)).dot(diff)
		eta = ((2 * np.pi) ** (-n/2.)) * (scipy.linalg.det(cov) ** (-0.5) )
		return eta * np.exp(exponent)


	def run_mhkf_step(self, mu, cov, t, y_true):
		"""
		keep only the first three most likely Gaussian components at each step
		We multiply measurements from 3 original gaussian with
		6 possible permutation matrices. Then we take the top 3 scoring
		out of the 18 generated.

		Args:
		-	mu: float array (num_hypotheses, state_dim) of hypothesis means; overwritten in place
		-	cov: float array (num_hypotheses, state_dim, state_dim) of hypothesis
			covariances; overwritten in place
		-	t: time index passed to self.dynamics for the prediction
		-	y_true: stacked measurement vector to fuse
		Returns:
		-	mu: the same array object, now holding the surviving posterior means
		-	cov: the same array object, now holding the surviving posterior covariances

		Side effect: self.weights is replaced by the normalised scores of the
		surviving components.
		"""
		num_hypotheses = len(self.weights)
		identity = np.eye(mu.shape[1])

		# predict (state transition matrix is the identity, so P <- P + Q)
		for j in range(num_hypotheses):
			mu[j] = self.dynamics(mu[j], t, add_noise=False)
			cov[j] = cov[j] + self.Q

		# Snapshot the predicted priors. The update below writes its results back
		# into mu/cov slot by slot, and a surviving component may descend from a
		# hypothesis whose slot has already been overwritten.
		mu_pred = mu.copy()
		cov_pred = cov.copy()

		perm_mats = self.generate_permutation_matrices()
		num_perms = len(perm_mats)

		# score every (hypothesis j, association i) pair; flat index = j*num_perms + i
		alphas = np.empty(num_hypotheses * num_perms)
		for j in range(num_hypotheses):
			for i, perm_mat in enumerate(perm_mats):
				# this is the expected measurement
				gauss_mu = perm_mat.dot(mu_pred[j])
				gauss_cov = self.R + perm_mat.dot(cov_pred[j]).dot(perm_mat.T)
				# these scores are the probabilities p(y_t | y_{1:t-1} ), weighted by the prior
				alphas[j * num_perms + i] = self.weights[j] * self.gauss_pdf(x=y_true, mu=gauss_mu, cov=gauss_cov)

		# prune: argsort is ascending, so reverse and keep the best num_hypotheses
		best_idxs = np.argsort(alphas)[::-1][:num_hypotheses]

		# Kalman update of each surviving component from its predicted parent
		new_weights = np.empty(num_hypotheses)
		for slot, k in enumerate(best_idxs):
			parent, perm_idx = divmod(k, num_perms)
			perm_mat = perm_mats[perm_idx]
			new_weights[slot] = alphas[k]

			prior_mu = mu_pred[parent]
			prior_cov = cov_pred[parent]
			innovation_cov = perm_mat.dot(prior_cov).dot(perm_mat.T) + self.R
			K = prior_cov.dot(perm_mat.T).dot(scipy.linalg.inv(innovation_cov))
			pred_meas = perm_mat.dot(prior_mu)
			mu[slot] = prior_mu + K.dot( y_true - pred_meas)
			cov[slot] = (identity - K.dot(perm_mat)).dot(prior_cov)

		total = np.sum(new_weights)
		if total > 0 and np.isfinite(total):
			self.weights = new_weights / total
		else:
			# every likelihood underflowed to zero (or overflowed); fall back to
			# uniform weights rather than propagating NaNs into later steps
			self.weights = np.full(num_hypotheses, 1.0 / num_hypotheses)
		return mu, cov


def run_MHKF_filter():
	"""
	Simulate 50 timesteps of the 3-target system, run the multi-hypothesis
	Kalman filter over the simulated measurements and save one scatter plot
	per timestep to 'mhkf/<t>.png'.

	Frame t shows the true states ('+' markers) and the estimates of the three
	hypotheses ('.', '*' and 'p' markers) for timesteps 0..t, with one colour
	per target.

	Returns:
	-	None
	"""
	import os # local import: only needed to create the output directory

	num_timesteps = 50
	mhkf = MultiHypothesisKalmanFilter(num_timesteps)
	true_states, true_meas = mhkf.generate_sequence()

	pred_state, pred_cov = mhkf.get_initial_state_estimate()
	# The estimate for time 0 is the initial guess itself.
	# run_mhkf_step overwrites its inputs in place, so store copies.
	pred_states = [np.copy(pred_state)]

	# The simulator produces x_t from x_{t-1} with the control evaluated at t-1,
	# and y_t measures x_t. So the step that fuses y_t must predict with t-1.
	for t in range(1, num_timesteps):
		y_true = true_meas[t]
		pred_state, pred_cov = mhkf.run_mhkf_step(pred_state, pred_cov, t - 1, y_true)
		pred_states += [np.copy(pred_state)]

	pred_states = np.array(pred_states)
	true_states = np.array(true_states)

	# savefig does not create missing directories
	os.makedirs('mhkf', exist_ok=True)

	target_colors = ('r', 'g', 'b')
	hypothesis_markers = ('.', '*', 'p')

	for t in range(num_timesteps):
		# start every frame from an empty figure instead of piling new artists
		# on top of all previously drawn ones
		plt.clf()
		upto = t + 1 # include timestep t itself in frame t

		# ground truth
		for target, color in enumerate(target_colors):
			ix, iy = 2 * target, 2 * target + 1
			plt.scatter(true_states[:upto,ix], true_states[:upto,iy], 30, color=color, marker='+')

		# Keep around 3 possible hypothesis of assignments
		for hyp, marker in enumerate(hypothesis_markers):
			for target, color in enumerate(target_colors):
				ix, iy = 2 * target, 2 * target + 1
				plt.scatter(pred_states[:upto,hyp,ix], pred_states[:upto,hyp,iy], 30, color=color, marker=marker)

		plt.savefig('mhkf/{}.png'.format(t))






def main():
	"""Script entry point: run the multi-hypothesis Kalman filter demo."""
	run_MHKF_filter()



# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
	main()