﻿using UnityEngine;
using System.Collections;

/// <summary>
/// Drives a camera from a head tracker and replaces the camera projection with an off-center
/// frustum that keeps a fixed "virtual screen" rectangle framed regardless of where the head is.
/// When a ClusterRenderer is attached to the same GameObject, the virtual screen is additionally
/// rotated (Cave) or subdivided (Wall) according to that renderer's camera index.
/// </summary>
public class HeadTracker : MonoBehaviour
{
	// The measurement unit used here is the unit of Unity 3D space.
	// You should always use the real measurements of your physical screen to describe the virtual screen.
	// e.g. The defaults map 1 unit to 1cm in the physical world, for a monitor 52cm wide and 32cm high.
	//      The origin of the trackers is at the bottom of the screen (thus the screen is shifted 16cm up
	//      in the virtual world) and 10cm away from the screen.
	public float m_ScreenWidth = 52.0f;
	public float m_ScreenHeight = 32.0f;
	public Vector3 m_ScreenPosition = new Vector3(0.0f, 16.0f, 10.0f);
	// Euler angles (degrees) of the virtual screen; fed to Quaternion.Euler.
	public Vector3 m_ScreenOrientation = new Vector3 (0.0f, 0.0f, 0.0f);
	// When true, Update() keeps the head inside the CAVE bound, or in front of the screen otherwise.
	public bool m_RestrictHead = false;

	// Per-axis multiplier applied to the tracker position (after the optional reordering).
	public Vector3 m_TrackerAxisScale = new Vector3 (100.0f, 100.0f, 100.0f);
	// One entry per output axis: "X", "Y" or "Z", optionally prefixed with '-' (inverted) or '+'.
	public string[] m_TrackerReorderingFormat = {"X", "Y", "Z"};
	public bool m_ReorderingTrackInputAxis = false;

	// When true, the near/far planes below are overwritten from the camera every frame.
	public bool m_UseCameraClippingPlanes = true;
	public float m_NearClipPlane = 0.1f;
	public float m_FarClipPlane = 1000.0f;

	public bool m_UseVRPNTracker = true;
	public string m_VRPNServerURL = "localhost";
	public string m_VRPNDevieName = "Tracker0";
	public int m_VRPNDeviceChannel = 0;

	// Extent of the virtual screen rectangle, expressed in the screen's own space.
	[System.NonSerialized]
	public float m_OriginalLeft, m_OriginalRight, m_OriginalTop, m_OriginalBottom;

	// Warning logged for any malformed m_TrackerReorderingFormat entry.
	private const string k_InvalidReorderingMessage = "Invalid Reordeing format";

	private Matrix4x4 m_ScreenMatrix = Matrix4x4.zero;
	private Camera m_Camera;
	private Quaternion m_HeadRotation = Quaternion.identity;

	// For output axis i: which tracker axis to read, and the sign (+1 / -1) to apply to it.
	private int[] m_TrackerReorderedIndex = { 0, 1, 2 };
	private int[] m_TrackerReorderedInversion = { 1, 1, 1 };

	private Bounds m_CAVEBound = new Bounds();

	/// <summary>
	/// Registers the tracker input, pushes the screen aspect to the cluster rig (if any) and
	/// computes the screen matrix, frustum extent, CAVE bound and axis reordering once.
	/// </summary>
	void Start ()
	{
		m_Camera = GetComponent<Camera> ();

		if (m_UseVRPNTracker)
		{
			// Dynamically add a new Cluster Input entry
			ClusterInput.AddInput ("HeadTracker", m_VRPNDevieName, m_VRPNServerURL, m_VRPNDeviceChannel, ClusterInputType.Tracker);
		}

		// Detect whether a cluster renderer component is attached to the same camera.
		ClusterRenderer clusterRenderer = GetComponent<ClusterRenderer> ();
		if (clusterRenderer != null)
		{
			CameraRig rig = clusterRenderer.rig;

			// Force the cluster renderer to follow the head tracking setting for the screen aspect
			rig.screenAspectWidth = m_ScreenWidth;
			rig.screenAspectHeight = m_ScreenHeight;
		}

		// Calculate the virtual screen transformation matrix and the virtual screen frustum extent once.
		RecalculateScreenMatrix (ref m_ScreenMatrix);
		RecalculateExtent ();

		// Calculate the bound of the screens in CAVE
		CalculateBounds ();

		// Generate the reorder index and +/- of the tracker input axis
		GenerateReordering ();
	}

	/// <summary>
	/// Writes the virtual screen transformation into <paramref name="matrix"/>: the screen
	/// position/orientation, pre-multiplied by the parent's localToWorldMatrix (identity without a
	/// parent). For a Cave cluster the screen is also rotated for this node's camera index.
	/// </summary>
	/// <param name="matrix">Receives the resulting screen matrix.</param>
	public void RecalculateScreenMatrix (ref Matrix4x4 matrix)
	{
		Transform parent = this.transform.parent;
		Matrix4x4 parentToWorld = parent != null ? parent.localToWorldMatrix : Matrix4x4.identity;

		ClusterRenderer clusterRenderer = GetComponent<ClusterRenderer> ();
		if (clusterRenderer == null || clusterRenderer.rig.clusterType != CameraRig.ClusterType.Cave)
		{
			// Plain screen matrix: used for Wall clusters and for cameras without a cluster renderer.
			matrix.SetTRS (m_ScreenPosition, Quaternion.Euler (m_ScreenOrientation), Vector3.one);
			matrix = parentToWorld * matrix;
			return;
		}

		// Cave: rotate the virtual screen according to the index of the cluster node and make sure
		// the virtual screen is still aligned correctly after the rotation.
		CameraRig rig = clusterRenderer.rig;
		var cameraIndex = clusterRenderer.GetCameraIndex ();

		Matrix4x4 originalScreenMatrix = Matrix4x4.zero;
		originalScreenMatrix.SetTRS (m_ScreenPosition, Quaternion.Euler (m_ScreenOrientation), Vector3.one);

		// Pivot of the rotation: the point behind the screen from which the screen width spans the
		// horizontal FOV reported by the rig.
		float halfRotationAngle = rig.CalculateCaveHorizontalFOV () / 2.0f;
		Vector3 rotation = rig.CalculateRotationForCaveNode (cameraIndex);
		float backwardZ = (m_ScreenWidth / 2.0f) / (Mathf.Tan (halfRotationAngle * Mathf.Deg2Rad));

		Matrix4x4 pivotTranslation = Matrix4x4.zero;
		pivotTranslation.SetTRS (new Vector3 (0.0f, 0.0f, backwardZ), Quaternion.identity, Vector3.one);

		Matrix4x4 rotationMatrix = Matrix4x4.zero;
		rotationMatrix.SetTRS (Vector3.zero, Quaternion.Euler (rotation), Vector3.one);

		Matrix4x4 skyFloorFix = Matrix4x4.identity;
		bool isSky = rig.IsSky (cameraIndex);
		if (isSky || rig.IsFloor (cameraIndex))
		{
			// Floor/sky screens are square (width x width), so shift them by half of the
			// width/height difference; the shift is flipped for the sky.
			float shiftY = (m_ScreenWidth - m_ScreenHeight) / 2.0f;
			if (isSky)
				shiftY *= -1;

			skyFloorFix.SetTRS (new Vector3 (0.0f, shiftY, 0.0f), Quaternion.identity, Vector3.one);
		}

		// Final matrix to apply to the virtual screen.
		matrix = parentToWorld * originalScreenMatrix * skyFloorFix * pivotTranslation.inverse * rotationMatrix * pivotTranslation;
	}

	/// <summary>
	/// Recomputes the extent of the virtual screen, centred on the screen origin. Floor and sky
	/// screens of a Cave cluster use a square (width x width) extent. For a Cave cluster the rig's
	/// screen aspect is also refreshed from the screen size.
	/// </summary>
	public void RecalculateExtent()
	{
		float halfWidth = m_ScreenWidth / 2.0f;
		float halfHeight = m_ScreenHeight / 2.0f;

		// Define the frustum extent of the virtual screen.
		m_OriginalLeft = -halfWidth;
		m_OriginalRight = halfWidth;
		m_OriginalTop = halfHeight;
		m_OriginalBottom = -halfHeight;

		ClusterRenderer clusterRenderer = GetComponent<ClusterRenderer> ();
		if (clusterRenderer == null)
			return;

		CameraRig rig = clusterRenderer.rig;
		if (rig.clusterType != CameraRig.ClusterType.Cave)
			return;

		rig.screenAspectWidth = m_ScreenWidth;
		rig.screenAspectHeight = m_ScreenHeight;

		// CAVE floor/sky need to respect a 1:1 ratio: only the vertical extent changes.
		if (rig.IsFloor (clusterRenderer.GetCameraIndex ())
			|| rig.IsSky (clusterRenderer.GetCameraIndex ()))
		{
			m_OriginalTop = halfWidth;
			m_OriginalBottom = -halfWidth;
		}
	}

	/// <summary>
	/// For a Cave cluster, computes the box (width x height x width) used to restrict the head.
	/// Its centre is the screen position moved half a screen width along -Z, transformed by the
	/// parent's localToWorldMatrix. Does nothing for other configurations.
	/// </summary>
	void CalculateBounds()
	{
		if (!IsClusterRendereCAVE ())
			return;

		Vector3 boundPos = new Vector3 (m_ScreenPosition.x, m_ScreenPosition.y, m_ScreenPosition.z - (m_ScreenWidth / 2.0f));

		Transform parent = this.transform.parent;
		Matrix4x4 parentToWorld = parent != null ? parent.localToWorldMatrix : Matrix4x4.identity;

		Matrix4x4 matrix = Matrix4x4.identity;
		matrix.SetTRS (boundPos, Quaternion.Euler (m_ScreenOrientation), Vector3.one);
		matrix = parentToWorld * matrix;

		m_CAVEBound.size = new Vector3 (m_ScreenWidth, m_ScreenHeight, m_ScreenWidth);
		m_CAVEBound.center = matrix.MultiplyPoint (Vector3.zero);
	}

	/// <summary>
	/// Installs the head-tracked projection matrix (or both stereo projection matrices) on the
	/// camera. The camera's projection is left untouched when no valid frustum can be computed.
	/// </summary>
	void OnPreCull()
	{
		// Update every frame in case the camera changed its far and near planes.
		if (m_UseCameraClippingPlanes)
		{
			m_NearClipPlane = m_Camera.nearClipPlane;
			m_FarClipPlane = m_Camera.farClipPlane;
		}

		// Frustum extent
		float left = 0;
		float right = 0;
		float bottom = 0;
		float top = 0;

		if (m_Camera.stereoEnabled)
		{
			// For stereoscopic, we treat each eye as a camera (head in head tracking context).
			// Using the default setting of stereo eye separation, we compute the eye position and
			// calculate the frustum for each eye.
			// NOTE : Users with their own stereo eye transformation will need to alter this code.
			Matrix4x4 headMatrix = Matrix4x4.identity;
			headMatrix.SetTRS (this.transform.position, m_HeadRotation, Vector3.one);

			float halfSeparation = m_Camera.stereoSeparation / 2.0f;
			Vector3 leftCameraPos = headMatrix.MultiplyPoint (new Vector3 (-halfSeparation, 0.0f, 0.0f));
			Vector3 rightCameraPos = headMatrix.MultiplyPoint (new Vector3 (halfSeparation, 0.0f, 0.0f));

			// Both eyes must yield a valid frustum; otherwise keep the current stereo projection
			// rather than installing a matrix built from stale or degenerate extents.
			if (!CalculateFrustrum (leftCameraPos, ref left, ref right, ref bottom, ref top))
				return;
			Matrix4x4 leftProjMatrix = PerspectiveOffCenter (left, right, bottom, top, m_NearClipPlane, m_FarClipPlane);

			if (!CalculateFrustrum (rightCameraPos, ref left, ref right, ref bottom, ref top))
				return;
			Matrix4x4 rightProjMatrix = PerspectiveOffCenter (left, right, bottom, top, m_NearClipPlane, m_FarClipPlane);

			// Then we set the custom projection matrix for each stereo eye.
			m_Camera.SetStereoProjectionMatrices (leftProjMatrix, rightProjMatrix);
		}
		else
		{
			if (CalculateFrustrum (this.transform.position, ref left, ref right, ref bottom, ref top))
				m_Camera.projectionMatrix = PerspectiveOffCenter (left, right, bottom, top, m_NearClipPlane, m_FarClipPlane);
		}
	}

	/// <summary>
	/// Computes the near-plane frustum extent that frames the virtual screen when viewed from
	/// <paramref name="position"/>. For a Wall cluster the extent is further reduced to the tile
	/// belonging to this node's camera index.
	/// </summary>
	/// <param name="position">Eye position; it is transformed by the inverse screen matrix.</param>
	/// <param name="left">Receives the left extent on the near plane.</param>
	/// <param name="right">Receives the right extent on the near plane.</param>
	/// <param name="bottom">Receives the bottom extent on the near plane.</param>
	/// <param name="top">Receives the top extent on the near plane.</param>
	/// <returns>
	/// False when no usable frustum exists (the eye lies on the screen plane, or the extent is
	/// empty or NaN); the extent values must then be ignored.
	/// </returns>
	private bool CalculateFrustrum (Vector3 position, ref float left, ref float right, ref float bottom, ref float top)
	{
		// Locate the head position relative to the virtual screen.
		Vector3 headPosition = m_ScreenMatrix.inverse.MultiplyPoint (position);

		/* The main idea of making head tracking work is to generate a projection matrix that always keeps
		 * the virtual screen extent intact regardless of the head position.
		 * The new frustum is scaled by the Thales factor: near clip plane over the head's Z distance to the
		 * virtual screen.
		 * For more information, please refer to http://stackoverflow.com/a/16755262
		 */
		float thalesFactor = m_NearClipPlane / Mathf.Abs (headPosition.z);
		if (float.IsNaN (thalesFactor) || float.IsInfinity(thalesFactor))
			return false;

		left = (m_OriginalLeft - headPosition.x) * thalesFactor;
		top = (m_OriginalTop - headPosition.y) * thalesFactor;
		right = (m_OriginalRight - headPosition.x) * thalesFactor;
		bottom = (m_OriginalBottom - headPosition.y) * thalesFactor;

		// An empty extent would divide by zero in PerspectiveOffCenter. Subtracting float.Epsilon
		// cannot repair it (it rounds away for any normal-magnitude value), so report the frustum
		// as invalid instead. The negated comparisons also reject NaN.
		if (!(left < right) || !(bottom < top))
			return false;

		// If there is a cluster renderer and Wall is in use, divide the frustum calculated above
		// instead of relying on the original calculation used in CameraRig.
		ClusterRenderer clusterRenderer = GetComponent<ClusterRenderer> ();
		if (clusterRenderer != null)
		{
			CameraRig rig = clusterRenderer.rig;
			if (rig.clusterType == CameraRig.ClusterType.Wall)
			{
				// Column / row of this node's tile; rows are counted from the top.
				int x = clusterRenderer.GetCameraIndex () % rig.wallParams.cols;
				int y = (int)(Mathf.Floor (clusterRenderer.GetCameraIndex () / rig.wallParams.cols));

				float horizontalOffset = (right - left) / rig.wallParams.cols;
				float verticalOffset = (top - bottom) / rig.wallParams.rows;

				left = left + x * horizontalOffset;
				right = left + horizontalOffset;
				top = top - y * verticalOffset;
				bottom = top - verticalOffset;
			}
		}

		return true;
	}

	/// <summary>
	/// Builds an off-center perspective projection matrix from near-plane extents and the
	/// near/far clip distances.
	/// </summary>
	private Matrix4x4 PerspectiveOffCenter (float left, float right, float bottom, float top, float near, float far)
	{
		float width = right - left;
		float height = top - bottom;
		float depth = far - near;

		Matrix4x4 m = Matrix4x4.zero;
		m.m00 = (2.0f * near) / width;
		m.m02 = (right + left) / width;
		m.m11 = (2.0f * near) / height;
		m.m12 = (top + bottom) / height;
		m.m22 = -(far + near) / depth;
		m.m23 = -(2.0f * far * near) / depth;
		m.m32 = -1.0f;
		return m;
	}

	/// <summary>
	/// Reads the tracker, applies the axis reordering and scale, optionally restricts the head
	/// position, and writes the result to the camera's local position.
	/// </summary>
	void Update ()
	{
		m_HeadRotation = GetTrackerRotation ();

		Vector3 pos = GetTrackerPosition();
		if (m_ReorderingTrackInputAxis)
		{
			Vector3 reorderedPos = Vector3.zero;
			for (int i = 0; i < 3; ++i)
			{
				reorderedPos[i] = pos[m_TrackerReorderedIndex[i]] * m_TrackerReorderedInversion[i];
			}
			pos = reorderedPos;
		}
		pos.Scale(m_TrackerAxisScale);

		// Check whether the head goes beyond the screen (or the bound, if CAVE)
		// NOTE(review): pos ends up as a local position, while m_CAVEBound and m_ScreenMatrix include
		// the parent's localToWorldMatrix - these only agree when the parent is at identity. Confirm.
		if (m_RestrictHead)
		{
			if (IsClusterRendereCAVE ())
			{
				// Clamping is a no-op for a position that is already inside the bound.
				Vector3 min = m_CAVEBound.min;
				Vector3 max = m_CAVEBound.max;
				pos.x = Mathf.Clamp (pos.x, min.x, max.x);
				pos.y = Mathf.Clamp (pos.y, min.y, max.y);
				pos.z = Mathf.Clamp (pos.z, min.z, max.z);
			}
			else
			{
				Vector3 headPosition = m_ScreenMatrix.inverse.MultiplyPoint (pos);

				// Keep the head strictly on the negative-Z side of the screen (never zero or positive).
				if (headPosition.z >= 0.0f)
				{
					headPosition.z = -0.000001f;
					pos = m_ScreenMatrix.MultiplyPoint (headPosition);
				}
			}
		}

		// The final head position is applied to the camera local transformation. Camera = head.
		// It's highly recommended to nest the camera under a parent empty game object.
		this.transform.localPosition = pos;
	}

	/// <summary>Tracker position from ClusterInput, or zero when VRPN is not used.</summary>
	Vector3 GetTrackerPosition ()
	{
		if (m_UseVRPNTracker)
			return ClusterInput.GetTrackerPosition ("HeadTracker");

		// User should fill this in if not using vrpn
		return Vector3.zero;
	}

	/// <summary>Tracker rotation from ClusterInput, or identity when VRPN is not used.</summary>
	Quaternion GetTrackerRotation ()
	{
		if (m_UseVRPNTracker)
			return ClusterInput.GetTrackerRotation ("HeadTracker");

		// User should fill this in if not using vrpn
		return Quaternion.identity;
	}

	/// <summary>
	/// Parses m_TrackerReorderingFormat into the index/sign tables used by Update(). A missing or
	/// malformed entry logs a warning and falls back to the identity mapping for that axis, so
	/// Update() never indexes the tracker position out of range.
	/// </summary>
	private void GenerateReordering ()
	{
		if (!m_ReorderingTrackInputAxis)
			return;

		for (int i = 0; i < 3; ++i)
		{
			string format = null;
			if (m_TrackerReorderingFormat != null && i < m_TrackerReorderingFormat.Length)
				format = m_TrackerReorderingFormat[i];

			int index = -1;
			int sign = 0;
			if (format != null)
			{
				index = GetReorderingAxisIndex (format.ToUpperInvariant ());
				sign = GetReorderingAxisInversion (format);
			}
			else
			{
				Debug.LogWarning (k_InvalidReorderingMessage);
			}

			// Assign (not accumulate) so the tables always reflect the current format.
			m_TrackerReorderedIndex[i] = index >= 0 ? index : i;
			m_TrackerReorderedInversion[i] = sign != 0 ? sign : 1;
		}
	}

	/// <summary>
	/// Maps an upper-case axis entry ("X", "-Y", "+Z", ...) to its axis index (0, 1 or 2).
	/// Logs a warning and returns -1 for anything else.
	/// </summary>
	private int GetReorderingAxisIndex (string str)
	{
		if (str != null && (str.Length == 1 || str.Length == 2))
		{
			// The axis letter is always the last character; a leading sign is optional.
			switch (str[str.Length - 1])
			{
				case 'X':
					return 0;
				case 'Y':
					return 1;
				case 'Z':
					return 2;
			}
		}

		Debug.LogWarning (k_InvalidReorderingMessage);
		return -1;
	}

	/// <summary>
	/// Returns the sign encoded in an axis entry: -1 for a '-' prefix, 1 for a '+' prefix or no
	/// prefix. Logs a warning and returns 0 for anything else.
	/// </summary>
	private int GetReorderingAxisInversion (string str)
	{
		if (str != null)
		{
			if (str.Length == 1)
				return 1;

			if (str.Length == 2)
			{
				if (str[0] == '-')
					return -1;
				if (str[0] == '+')
					return 1;
			}
		}

		Debug.LogWarning (k_InvalidReorderingMessage);
		return 0;
	}

	/// <summary>
	/// True when a ClusterRenderer is attached to this GameObject and its rig is of Cave type.
	/// </summary>
	private bool IsClusterRendereCAVE ()
	{
		ClusterRenderer clusterRenderer = GetComponent<ClusterRenderer> ();
		return clusterRenderer != null
			&& clusterRenderer.rig.clusterType == CameraRig.ClusterType.Cave;
	}
}
