#!/usr/bin/env python
import sys, time
import os
import roslib;
import rospy
import numpy as np
import timeit
from std_msgs.msg import Float32
import cv2
import matplotlib.pyplot as plt
from sensor_msgs.msg import LaserScan
import scipy.spatial.distance as dist
from leg_detection.msg import single_leg
from leg_detection.msg import multiple_legs
from sklearn.cluster import KMeans



def nan_helper(y):
    """Locate the NaN entries of an array and provide an index converter.

    Input:
        - y, 1d numpy array that may contain NaNs
    Output:
        - nans, boolean array that is True where y is NaN
        - index, a function mapping a boolean array to the integer
          positions of its True entries
    Example:
        >>> # fill the NaNs by linear interpolation
        >>> nans, index = nan_helper(y)
        >>> y[nans] = np.interp(index(nans), index(~nans), y[~nans])
    """

    def index(logical_indices):
        # nonzero() returns one index array per dimension; keep the first
        return logical_indices.nonzero()[0]

    nan_mask = np.isnan(y)
    return nan_mask, index


def get_zero_runs(a):
    """Find the runs of consecutive zeros in a 1d array.

    Returns an integer array of shape (n_runs, 2); each row holds the start
    index and the exclusive end index of one run of zeros in a.
    """
    # 1 where a is zero; a 0 on each side makes runs touching the borders
    # produce both a rising and a falling edge.
    zero_flags = np.equal(a, 0).view(np.int8)
    padded = np.concatenate(([0], zero_flags, [0]))
    # Every change between 0 and 1 marks the start or the end of a run.
    edges = np.flatnonzero(np.abs(np.diff(padded)) == 1)
    return edges.reshape(-1, 2)

def calc_coordinate(angle,r):
    """Convert an angle in degrees and a range into a coordinate pair.

    The angle is shifted by -90 degrees before conversion. Returns
    (x_coord, y_coord) with x_coord = sin(angle-90)*r and
    y_coord = cos(angle-90)*r.
    """
    theta = (angle-90.)*np.pi/180.
    return np.sin(theta)*r, np.cos(theta)*r

class legDetectionLaser:
    """Detects pairs of legs in laser scans and publishes them.

    Subscribes to the LaserScan topic named by the ~LASER_NODE parameter,
    splits every scan into segments of neighbouring readings with similar
    range, keeps the segments that pass the leg thresholds below, pairs the
    remaining segments and publishes the result as a multiple_legs message
    on "~output".
    """

    def __init__(self):
        """Read the node parameters, set the thresholds and start listening.

        Required private parameters: ~LASER_NODE, ~LOG_DIR, ~DO_LOG and
        ~MIN_DIST. When ~DO_LOG is true a time-stamped log directory is
        created below ~LOG_DIR.
        """
        # Where we want to publish
        self.LASER_NODE = rospy.get_param("~LASER_NODE")
        self.LEG_PUB = rospy.Publisher("~output",multiple_legs)#,queue_size=6)
        self.MAX_ADJ_DIST = 0.1 # 0.075 #Maximum allowable range difference between two adjacent scan points when finding segments
        self.LASER_RES = 180./509 #Resolution of laser scanner (degrees per reading)
        self.SEG_LENGTH = 4 #A segment must contain more than this many scan points / readings
        self.SEG_MEAN = 5 #2 #Maximum mean range of segment
        self.SEG_STD = 0.04 #0.04 #Maximum standard deviation of segment
        self.SEG_WIDTH_MIN = 0.05 #0.05 #Min width of segment
        self.SEG_WIDTH_MAX = 0.25 #0.25 #Max width of segment
        self.MAX_LEG_SPACE = 0.5 #0.4 #Max distance between two legs
        self.LOG_DIR = rospy.get_param("~LOG_DIR")
        self.DO_LOG = rospy.get_param("~DO_LOG")
        self.MIN_DIST = rospy.get_param("~MIN_DIST")
        file_date = time.strftime("%y_%m_%d_%H_%M_%S")
        self.FILE_PATH = os.path.join(self.LOG_DIR,'LEG_LASER_SCAN_'+file_date)
        if self.DO_LOG:
            os.makedirs(self.FILE_PATH)
            self.log_file_scan = os.path.join(self.FILE_PATH,'log_laser.txt')

        self.Node_name = rospy.get_name()
        rospy.loginfo('Initialization successful. Node: '+self.Node_name)
        # Subscribe last so that callback never sees a half-initialised object
        self.subscriber = rospy.Subscriber(self.LASER_NODE, LaserScan, self.callback) #Subscribes to array data

    def callback(self, data):
        """Process one laser scan and publish the detected leg pairs.

        data: message whose ranges field holds the range readings.
        A multiple_legs message is published for every scan; it carries
        N_leg_pairs = 0 and no leg_pairs when nothing was detected.
        """
        # Work on a private copy so the received message is never modified
        x = np.array(data.ranges)
        # NaN readings get a range of 5 (same value as SEG_MEAN);
        # readings closer than MIN_DIST are clamped to MIN_DIST.
        x[np.isnan(x)] = 5
        x[x<self.MIN_DIST] = self.MIN_DIST

        if self.DO_LOG:
            # One line per scan; the context manager closes the file even
            # if writing fails.
            with open(self.log_file_scan,'a') as log_file:
                np.savetxt(log_file,x,fmt='%10.2f',newline=' ')
                log_file.write("\n")

        leg_segments,leg_stats = self.calc_and_filter_segment_stats(x)
        person_segments = self.group_segments(leg_segments,leg_stats)

        N_person = len(person_segments)
        if N_person > 0:
            rospy.loginfo('N_stats :'+str(len(person_segments[0])))

        # One time stamp shared by the message and all the leg pairs in it
        stamp = rospy.Time.now()
        leg_msg = multiple_legs()
        leg_msg.N_leg_pairs = N_person
        leg_msg.header.stamp = stamp
        for pair in person_segments:
            single_leg_msg = single_leg()
            single_leg_msg.header.stamp = stamp
            single_leg_msg.angle = pair[0]
            single_leg_msg.range = pair[1]
            single_leg_msg.x_coord = pair[2]
            single_leg_msg.y_coord = pair[3]
            single_leg_msg.angle_1 = pair[4]
            single_leg_msg.angle_2 = pair[5]
            single_leg_msg.range_1 = pair[6]
            single_leg_msg.range_2 = pair[7]
            leg_msg.leg_pairs.append(single_leg_msg)

        self.LEG_PUB.publish(leg_msg)

    @staticmethod
    def _make_leg_pair(first,second):
        """Merge the statistics of two leg segments into one leg-pair record.

        Returns an array [mean angle, mean range, mean x, mean y,
        angle of first, angle of second, range of first, range of second].
        """
        return np.array([(second[0]+first[0])/2.,
                         (second[2]+first[2])/2.,
                         (second[5]+first[5])/2.,
                         (second[6]+first[6])/2.,
                         first[0], second[0], first[2], second[2]])

    def group_segments(self,segments,segment_stats):
        """Pair leg segments into persons.

        segments: unused; kept for interface compatibility.
        segment_stats: list of per-segment statistics as produced by
            calc_and_filter_segment_stats.
        Returns a list with one array per detected pair (layout as in
        _make_leg_pair). Exactly two segments are always paired; with more
        than two, segments are paired when each is the other's nearest
        neighbour among the segments not yet paired.
        """
        N_segments = len(segment_stats)
        rospy.loginfo('N_segments: '+str(N_segments))
        if N_segments < 2:
            return []
        if N_segments == 2:
            return [self._make_leg_pair(segment_stats[0],segment_stats[1])]

        leg_pair = []
        X = np.asarray([stats[5:7] for stats in segment_stats])
        inter_dist = dist.cdist(X,X,'euclidean')
        # Zero distances (the diagonal) must never win the nearest search
        inter_dist[inter_dist==0] = np.inf
        for ii in range(N_segments):
            min_idx = np.argmin(inter_dist[ii,:])
            if not np.isfinite(inter_dist[ii,min_idx]):
                # Row is all inf: segment already paired or without any
                # candidate. argmin would return index 0 here, which could
                # pair a segment with itself.
                continue
            if ii == np.argmin(inter_dist[:,min_idx]):
                leg_pair.append(self._make_leg_pair(segment_stats[ii],segment_stats[min_idx]))
                # Take both segments out of all further searches
                inter_dist[[ii,min_idx],:] = np.inf
                inter_dist[:,[ii,min_idx]] = np.inf

        return leg_pair

    def calc_and_filter_segment_stats(self,x):
        """Split a scan into segments and keep those that look like legs.

        x: 1d array of range readings.
        Returns (leg_segments, leg_stats):
            leg_segments: integer array of shape (n, 2) with the start and
                exclusive end index of every kept segment (empty array when
                nothing is kept).
            leg_stats: list with one entry per kept segment:
                [center angle, length, mean range, range std, width,
                 x coordinate, y coordinate].
        """
        # Find connected segments, where the range diff between two adjacent points is less than MAX_ADJ_DIST
        x = np.r_[0,x]
        x_diff = np.abs(np.diff(x))
        breaks = np.where(x_diff<self.MAX_ADJ_DIST,0.,self.MAX_ADJ_DIST)
        segments = get_zero_runs(breaks)
        # NOTE(review): a run [start, stop) in x_diff links the padded samples
        # start..stop inclusive, while the slices below cover start..stop-1.
        # Left unchanged because the thresholds may be tuned to it - confirm.

        # Calculate the features/statistics for each segment and keep the
        # ones that pass all the leg thresholds
        keep_mask = np.zeros(len(segments),dtype=bool)
        segment_stats = []
        for ii,(start,stop) in enumerate(segments):
            tmp_len = stop-start
            tmp_center = (start+tmp_len/2.) * self.LASER_RES -90.
            readings = x[start:stop]
            tmp_mean = np.mean(readings)
            tmp_std = np.std(readings)
            tmp_angle_span = tmp_len*self.LASER_RES
            # Base of the isosceles triangle with two sides of length
            # tmp_mean enclosing the angular span (law of sines)
            tmp_width = tmp_mean * np.sin(tmp_angle_span*np.pi/180.) / np.sin((180.-tmp_angle_span)/2.*np.pi/180.)
            tmp_x_coord,tmp_y_coord = calc_coordinate(tmp_center,tmp_mean) #Coordinates relative to the laser scanner (robot)
            if tmp_len > self.SEG_LENGTH and tmp_mean < self.SEG_MEAN and tmp_std < self.SEG_STD and tmp_width > self.SEG_WIDTH_MIN and tmp_width < self.SEG_WIDTH_MAX:
                keep_mask[ii] = True
                segment_stats.append([tmp_center,tmp_len,tmp_mean,tmp_std,tmp_width,tmp_x_coord,tmp_y_coord])

        filtered_segments = segments[keep_mask]

        # Remove single segments which are too far from neighboring segments
        if len(segment_stats) > 1:
            coord_ar = np.array([stats[5:7] for stats in segment_stats])
            inter_dist = dist.cdist(coord_ar,coord_ar,'euclidean')
            inter_dist[inter_dist==0] = np.inf
            keep_idx = np.min(inter_dist,axis=1)<self.MAX_LEG_SPACE
            leg_segments = filtered_segments[keep_idx,:]
            leg_stats = [stats for stats,keep in zip(segment_stats,keep_idx) if keep]
        else:
            leg_stats = segment_stats
            leg_segments = filtered_segments

        return leg_segments,leg_stats

    def remove_nan(self,data):
        """Fill the NaN values of a 1d array by linear interpolation, in place.

        NaNs before the first or after the last valid value take that
        value (np.interp behaviour). The array is returned unchanged when
        it holds no NaN or no valid value to interpolate from.
        """
        nans = np.isnan(data)
        if not nans.any() or nans.all():
            # Nothing to fill, or nothing to interpolate from
            return data

        valid = ~nans
        data[nans] = np.interp(nans.nonzero()[0], valid.nonzero()[0], data[valid])
        return data

def main(args):
    """Start the leg_detection_laser node and run until ROS shuts down.

    args: command line arguments; not used.
    """
    rospy.init_node('leg_detection_laser')
    # Constructing the detector sets up its publisher and subscriber
    detector = legDetectionLaser()
    try:
        # Blocks here while scans are handled in the subscriber callback
        rospy.spin()
    except rospy.ROSInterruptException:
        pass

# Run the node only when this file is executed as a script
if __name__ == "__main__":
    main(sys.argv)