import sys,copy,random,math,os,time,getopt,re,traceback,urllib,difflib,types

sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)),"../libraries/"))

import ned_stats as stats
import ned_basic as basic

import ned_conservationScorer as conservationHelper
import ned_disorderHelper as disorderHelper
import ned_uniprotHelper as uniprotHelper
import ned_proteinInfoHelper as proteinInfoHelper

import ned_SLiMPrints_Tester as slimprints_Tester


import rje_seq,rje_blast,rje_uniprot,rje_tree

from sets import Set

#slimprints.py -a ../Datasets/Alignments/Human/ -c0.35 -qT -W30 -PF -fT -m0.1 -FF
version = 3


	
class SLiMPrints:
	"""Build, score and report candidate short linear motifs (SLiMs) for one protein.

	The object is driven by an ``options`` dict and a ``data`` dict that the
	caller fills before running the methods. Later methods depend on state
	produced by earlier ones: initialiseData -> createMotifs -> findChildren ->
	compressMotifs -> createCompressedOutput / createFlatOutput.

	Keys read from ``self.data`` include per-residue sequences such as
	"WCS_W_p", "Disorder", "ignoreList", "gapPercent" and "anchorScore", plus
	"Sequence", "Alignment_Degapped", "Acc", "insertion" and
	"insertionLengths". How they are produced is outside this class.
	"""

	# Per-residue score column used to seed and score motifs.
	_SCORE_KEY = "WCS_W_p"
	# Step used to make colliding rank keys unique (also replaces a key of 0).
	_RANK_KEY_STEP = 0.1 ** 16

	def __init__(self,options):
		"""Store the run ``options``; all other state is created by later methods."""
		self.options = options
		self.data = {}
		self.runnable = True

	def getMotif(self,offsets,sequence):
		"""Return the regex-style motif covering ``offsets`` in ``sequence``.

		Residues at the given offsets (expected ascending) are kept, and every
		skipped position between two consecutive offsets becomes a ``.``.
		For example, offsets [0, 2, 3] over "ABCDEF" give "A.CD".

		Raises IndexError if ``offsets`` is empty.
		"""
		motif = sequence[offsets[0]]
		for previous, current in zip(offsets, offsets[1:]):
			motif += "." * (current - previous - 1) + sequence[current]
		return motif

	def readOffsets(self):
		"""Load known motif instances into ``self.data["offsets"]``.

		If ``options["offsets_file"]`` is non-empty, it is read as a
		whitespace-separated table whose first line is skipped as a header.
		Column index 3 is the motif, index 4 is an identifier whose part after
		"__" names the protein, and index 5 is the integer start. The result
		is a dict mapping protein -> list of [motif, start].

		Otherwise ``options["offsets"]`` is parsed as comma-separated
		"motif:start" pairs into a list of [motif, start].

		Blank lines and empty entries are skipped.
		"""
		offsetsFile = self.options["offsets_file"]
		if len(offsetsFile) > 0:
			byProtein = {}
			with open(offsetsFile) as handle:
				rows = handle.read().strip().split("\n")[1:]
			for row in rows:
				bits = row.split()
				if not bits:
					continue
				protein = bits[4].split("__")[1]
				byProtein.setdefault(protein, []).append([bits[3], int(bits[5])])
			self.data["offsets"] = byProtein
		else:
			parsed = []
			for entry in self.options["offsets"].split(","):
				if not entry.strip():
					continue
				fields = entry.split(":")
				parsed.append([fields[0], int(fields[1])])
			self.data["offsets"] = parsed

	################################
	# Initialise data
	################################

	def initialiseData(self,data=None):
		"""Seed the single-residue hits from the per-residue scores.

		Every residue index from 1 upward whose "ignoreList" entry is "" or
		"MOTIF" is a candidate, keyed by its ``self.data["WCS_W_p"]`` value.
		Candidates are visited from the highest score down. A candidate becomes
		a length-1 hit when all of these hold:

		- its 1-based visit count is below ``options["maxCount"]``;
		- its own "Disorder" value exceeds ``options["iucut"]``;
		- its score is below 0.5.

		Also sets ``self.data["unmasked"]`` to the number of "" entries in
		"ignoreList".

		``data`` is unused and is kept only for backward compatibility.

		Returns [self.hits, self.orderList].
		"""
		self.hits = {1:{}}
		scores = self.data[self._SCORE_KEY]
		ignoreList = self.data["ignoreList"]
		disorder = self.data["Disorder"]

		self.data["unmasked"] = ignoreList.count("")

		# score -> [offset]
		# NOTE(review): residues that share an identical score collapse onto
		# the last such offset, so only one of them can become a hit.
		order = {}
		for offset in range(1, len(scores)):
			if ignoreList[offset] == "" or ignoreList[offset] == "MOTIF":
				order[scores[offset]] = [offset]

		self.orderList = []
		# NOTE(review): scores are visited highest first, so the maxCount cap
		# consumes the least significant candidates first. Confirm this is intended.
		for count, score in enumerate(sorted(order, reverse=True), 1):
			offsets = order[score]
			residueDisorder = disorder[offsets[0]]
			# The disorder cut-off is applied to this candidate residue itself.
			if count < self.options["maxCount"] and residueDisorder > self.options["iucut"]:
				if score < 0.5:#self.options["minScore"]:
					hitId = str(offsets)
					self.hits[1][hitId] = {
						"scores": [score],
						"offsets": offsets,
						"disorder": [residueDisorder],
						"parents": [],
					}
					self.orderList.append(hitId)

		return [self.hits, self.orderList]

	################################
	# create SLiMPrints
	################################

	def createMotifs(self):
		"""Grow multi-residue hits and rank the ones that pass the filters.

		For each length from 2 to ``options["maxLength"]``, every seed residue
		in ``self.orderList`` is added to every (length-1) hit it lies close
		enough to (see _growHits), giving ``self.hits[length]``. Each new hit is
		then scored and, if it passes the filters, stored in ``self.ranks``
		(see _rankHit).
		"""
		self.ranks = {}
		for length in range(2, self.options["maxLength"] + 1):
			self._growHits(length)
			for hit in self.hits[length]:
				self._rankHit(length, hit)

	def _growHits(self, length):
		"""Build ``self.hits[length]`` by extending each (length-1) hit with one seed residue."""
		previousHits = self.hits[length - 1]
		grown = {}
		self.hits[length] = grown

		for residue in self.orderList:
			seed = self.hits[1][residue]
			queryScore = seed["scores"][0]
			query = seed["offsets"][0]

			for parent in previousHits.values():
				parentOffsets = parent["offsets"]
				gap = min(abs(query - min(parentOffsets)), abs(query - max(parentOffsets))) - 1
				# Skip the residue if it is too far from the hit, sits on one of
				# the hit's end offsets (gap == -1), or is already part of the hit.
				if gap > self.options["gap"] or gap == -1 or query in parentOffsets:
					continue

				newOffsets = sorted(parentOffsets + [query])
				key = str(newOffsets)
				if key in grown:
					continue
				grown[key] = {
					"scores": parent["scores"] + [queryScore],
					"offsets": newOffsets,
					"disorder": parent["disorder"] + [self.data["Disorder"][query]],
					"parents": parent["parents"] + [key],
				}

	def _rankHit(self, length, hit):
		"""Score ``self.hits[length][hit]`` and store it in ``self.ranks`` if it passes.

		The hit is dropped in any of these cases:

		- its length-corrected probability exceeds 0.99;
		- the mean absolute deviation of its residue scores exceeds 0.05;
		- within a spanning insertion longer than 2 residues in some aligned
		  sequence, that sequence's degapped text over the hit also matches
		  the motif pattern.
		"""
		entry = self.hits[length][hit]
		offsets = entry["offsets"]
		offsets.sort()
		start = offsets[0]
		stop = offsets[-1]
		probs = list(entry["scores"])

		motifCount = pow(self.options["gap"] + 1, length - 1) * max(1, self.data["unmasked"])
		prob = stats.product(probs)
		sig = float(1 - (1 - prob) ** motifCount)
		probCor = stats.cum_uniform_product(prob, length)
		sigCor = float(1 - (1 - probCor) ** motifCount)
		score = probCor

		meanScore = sum(probs) / len(probs)
		# Mean absolute deviation (stored under the historical "varRelConValues" name).
		varRelConValues = sum([abs(x - meanScore) for x in probs]) / len(probs)

		# Cheap filters first, so rejected hits skip the per-residue work below.
		if score > 0.99 or varRelConValues > 0.05:
			return

		motifPattern = re.compile(self.getMotif(offsets, self.data["Sequence"]))
		for insertion in self.data["insertion"]:
			if not basic.within(start, stop, insertion):
				continue
			for acc, insertionLength in self.data["insertionLengths"][insertion].items():
				if insertionLength > 2 and motifPattern.search(self.data["Alignment_Degapped"][acc][start:stop + 1]) is not None:
					return

		SurfaceAccessibilityMean = ""
		SurfaceAccessibilityMax = ""
		dsspSecondarystructurePercent = ""
		try:
			accessibility = self.data["SurfaceAccessibilityNormalised"]
			SurfaceAccessibilityMean = [str(sum(accessibility[x]) / len(accessibility[x])) if x in accessibility else "N/A" for x in offsets]
			SurfaceAccessibilityMax = [str(max(accessibility[x])) if x in accessibility else "N/A" for x in offsets]
			secondary = self.data["dsspSecondarystructurePercent"]
			dsspSecondarystructurePercent = [secondary[x] if x in secondary else "N/A" for x in offsets]
		except Exception:
			# Structural annotations are optional; fields that fail stay "".
			pass

		saValues = []
		try:
			accessibility = self.data["SurfaceAccessibilityNormalised"]
			for x in offsets:
				if x in accessibility:
					saValues.append(max(accessibility[x]))
		except Exception:
			pass
		if len(saValues) > 0:
			sa = "%1.2g" % (sum(saValues) / len(saValues))
		else:
			sa = ""

		gapPercent = [float(self.data["gapPercent"][x]) for x in range(start, stop + 1)]

		querySequence = self.data["Alignment_Degapped"][self.data["Acc"]]
		# Up to 5 flanking residues each side; motif residues upper case, others lower.
		context = "".join([querySequence[x] if x in offsets else querySequence[x].lower()
			for x in range(max(0, start - 5), min(stop + 6, len(self.data["Sequence"])))])

		anchorScores = [float(self.data["anchorScore"][x]) for x in offsets]
		insertions = self.data["insertion"]

		rankData = {
			"prob": prob,
			"sig": sig,
			"probCor": probCor,
			"sigLenCor": sigCor,
			"probs": probs,
			"varRelConValues": varRelConValues,
			"length": length,
			"id": hit,
			"motif": self.getMotif(offsets, querySequence),
			"context": context,
			"start": start,
			"stop": stop,
			"residues": offsets,
			"scores": entry["scores"],
			"disorder": entry["disorder"],
			"parents": entry["parents"],
			"anchor": sum(anchorScores) / len(anchorScores),
			"anchorScores": anchorScores,
			"gapPercent": float(sum(gapPercent)) / len(gapPercent),
			"SA": sa,
			"SurfaceAccessibilityMean": SurfaceAccessibilityMean,
			"SurfaceAccessibilityMax": SurfaceAccessibilityMax,
			"dsspSecondarystructurePercent": dsspSecondarystructurePercent,
			"insertion": sum([insertions[x] if x in insertions else 0 for x in range(start, stop + 1)]),
		}

		self.ranks[self._uniqueRankKey(score)] = rankData

	def _uniqueRankKey(self, score):
		"""Return a key for ``score`` that is not yet in ``self.ranks``.

		A score of 0 becomes ``_RANK_KEY_STEP``. Colliding keys are nudged up
		by ``_RANK_KEY_STEP`` until they are free. This relies on scores being
		at most 0.99 (enforced in _rankHit), where such a step still changes
		the float.
		"""
		key = score
		if key == 0:
			key = self._RANK_KEY_STEP
		while key in self.ranks:
			key += self._RANK_KEY_STEP
		return key

	def findChildren(self):
		"""Group ranked hits into disjoint parents, each holding its overlapping children.

		Ranks are visited in ascending key order (rank 1 first in the reports).
		A hit sharing any residue with an existing parent is deep-copied in as
		a child of the best-ranked such parent; otherwise it becomes a new
		parent.

		Fills ``self.compressedRanks`` as {key: {"data": rank, "children": {key: rank}}}.
		"""
		self.compressedRanks = {}
		parents = []	# (key, residue set) in ascending key order

		for score in sorted(self.ranks):
			residues = set(self.ranks[score]["residues"])
			owner = None
			for parentKey, parentResidues in parents:
				if residues & parentResidues:
					owner = parentKey
					break

			if owner is None:
				parents.append((score, residues))
				self.compressedRanks[score] = {"data":self.ranks[score],"children":{}}
			else:
				self.compressedRanks[owner]["children"][score] = copy.deepcopy(self.ranks[score])

	def compressMotifs(self):
		"""Fold worse-ranked parents that partially overlap a better parent into it.

		Parents are walked in ascending key order. A worse-ranked parent is
		absorbed when it shares residues with the current one but does not
		contain all of the current one's residues. The absorbed parent is
		removed from ``self.compressedRanks``; its data and children become
		children of the current parent.
		"""
		position = 0
		while position < len(self.compressedRanks):
			rank = sorted(self.compressedRanks)[position]
			rankResidues = set(self.compressedRanks[rank]["data"]["residues"])

			absorbed = []
			for other in self.compressedRanks:
				if float(other) > float(rank):
					otherResidues = set(self.compressedRanks[other]["data"]["residues"])
					if rankResidues & otherResidues and rankResidues - otherResidues:
						absorbed.append(other)

			for other in absorbed:
				children = {other: self.compressedRanks[other]["data"]}
				children.update(self.compressedRanks[other]["children"])
				children.update(self.compressedRanks[rank]["children"])
				self.compressedRanks[rank]["children"] = children
				del self.compressedRanks[other]

			position += 1

	################################
	#Outputs
	################################

	def createCompressedOutput(self):
		"""Write and return the tab-separated report of compressed (parent) hits.

		Only parents whose mean residue disorder exceeds ``options["iucut"]``
		are reported, numbered from 1 in ascending key order. Each row lists
		the ``header`` columns followed by one column per feature type returned
		by ``proteinInfoHelper.overlapsFeature``.

		When ``options["printSub"]`` is "T" and a parent has children, a "D"
		summary line and one "C" line per child follow the row.

		When ``options["quiet"]`` is "F", the first reported row is printed
		field by field if its key is below 0.0001; otherwise a blank line is
		printed. A blank line is also printed when no row qualifies.

		``self.compressed`` records score/motif/start/stop per reported row.
		The report is written to ``<resdir>/<Acc>_comp_SLiMPrints.out``.
		"""
		header = ["score","slim","submot","context","pos","spec","sig","sigCor","p","pCor","dis","anchor","anchorScores","meanSA","meanSAscores","maxSAscores","SS","insertions","gapPercent","gaps","unmasked","pVar"]
		output = "rank\t" + "\t".join(header) + "\t"

		counter = 0
		self.compressed = {}
		headerComplete = False
		proteinInfo = proteinInfoHelper.proteinInfoHelper()

		for i in sorted(self.compressedRanks):
			entry = self.compressedRanks[i]
			rank = entry["data"]
			disorder = sum(rank["disorder"])/len(rank["disorder"])
			if not disorder > self.options["iucut"]:
				continue
			counter += 1

			if i < 0.001:
				scoreField = "%1.1e" % i
			else:
				scoreField = "%1.4f" % i

			fields = [
				scoreField,
				"%-14s" % rank['motif'],
				str(len(entry['children'])),
				"%-14s" % rank['context'],
				str(rank['start']) + ":" + str(rank['stop']),
				str(self.data["SpeciesCount"]),
				"%1.2g" % rank['sig'],
				"%1.2g" % rank['sigLenCor'],
				"%1.2g" % rank["prob"],
				"%1.2g" % rank["probCor"],
				"%1.2g" % disorder,
				"%1.2g" % rank["anchor"],
				",".join(["%1.2g" % x for x in rank["anchorScores"]]),
				rank["SA"],
				",".join(rank["SurfaceAccessibilityMean"]),
				",".join(rank["SurfaceAccessibilityMax"]),
				",".join(rank["dsspSecondarystructurePercent"]),
				str(rank["insertion"]),
				str(rank["gapPercent"]),
				"%1.3g" % (float(self.data["gaps"])/int(self.data["residues"])) + "(" + str(self.data["gaps"]) + "/" + str(self.data["residues"]) + ")",
				"%1.2g" % (float(self.data["unmasked"])/len(self.data["Sequence"])) + "(" + str(self.data["unmasked"]) + "/" + str(len(self.data["Sequence"])) + ")",
				"%1.3f" % rank["varRelConValues"],
			]

			inFeature = proteinInfo.overlapsFeature(int(rank['start']),rank['motif'],self.data['Feature_by_type'])
			featureHeader = sorted(inFeature)

			# Feature columns are only known once the first row is built.
			if not headerComplete:
				output += "\t".join(featureHeader) + "\n"
				headerComplete = True

			fields.extend([",".join([str(x) for x in inFeature[name]]) for name in featureHeader])
			line = "".join([field + "\t" for field in fields])

			output += "(" + str(counter) + ")\t" + line + "\n"

			if counter == 1 and self.options["quiet"] == "F":
				if i < 0.0001:
					print("\n#Top ranking hit")
					for name, value in zip(header + featureHeader, line.split("\t")[:-1]):
						# Same bytes as the Python 2 statement: print "%20s"%name,"\t",joined
						print("%20s" % name + " \t" + ("\n%20s"%"" + "\t").join(value.split(",")))
				else:
					print("")

			self.compressed[counter] = {"score":i,"motif":str(rank['motif']),"start":rank['start'], "stop":rank['stop']}

			if len(entry['children']) > 0 and self.options["printSub"] == "T":
				output += self._describeChildren(counter, entry)

		if counter == 0:
			# Terminate the header line even when no row qualified.
			output += "\n"
			print("")

		with open(os.path.join(self.options["resdir"],self.data["Acc"] + "_comp_SLiMPrints.out"),"w") as handle:
			handle.write(output)

		return output

	def _describeChildren(self, counter, entry):
		"""Return the "D" summary line and the "C" child lines for one reported parent.

		The "D" line describes the union of the parent's and children's
		residues. Its significance uses the product of the residues'
		"WCS_W_p" scores and the same motif-count correction as _rankHit.
		"""
		children = entry['children']
		childScores = sorted(children)

		residues = set(entry["data"]["residues"])
		for child in childScores:
			residues.update(children[child]["residues"])
		offsets = sorted(residues)
		length = len(offsets)
		end = offsets[-1]

		childLines = ""
		for subCounter, child in enumerate(childScores, 1):
			kid = children[child]
			# Right-pad so each child motif lines up with the end of the combined region.
			shift = end - kid['stop']
			childLines += "C\t" + str(counter) + "." + str(subCounter) + "\t" + "%1.2g" % child + "\t" + ("%20s" % (str(kid['motif']) + " " * shift)) + "\t" + str(kid['start']) + ":" + str(kid['stop']) + "\n"

		motifCount = pow(self.options["gap"] + 1, length - 1) * max(1, self.data["unmasked"])
		prob = stats.product([self.data[self._SCORE_KEY][x] for x in offsets])
		sig = 1 - (1 - prob) ** motifCount

		summary = "D\t" + str(counter) + ".0\t" + "%1.2g" % sig + "\t" + "%20s" % self.getMotif(offsets, self.data["Sequence"]) + "\t" + str(offsets[0]) + ":" + str(end) + "\n"
		return summary + childLines

	def createFlatOutput(self):
		"""Write and return all ranked hits, one line per hit in ascending key order.

		Each line is "<n>\\t<key as %1.2g>\\t<motif right-aligned in 20>\\t<start>:<stop>".
		The report is written to ``<resdir>/<Acc>_flat_SLiMPrints.out``.
		"""
		lines = []
		for counter, key in enumerate(sorted(self.ranks), 1):
			rank = self.ranks[key]
			lines.append(str(counter) + "\t" + "%1.2g" % key + "\t" + "%20s" % str(rank['motif']) + "\t" + str(rank['start']) + ":" + str(rank['stop']) + "\n")
		output = "".join(lines)

		with open(os.path.join(self.options["resdir"],self.data["Acc"] + "_flat_SLiMPrints.out"),"w") as handle:
			handle.write(output)

		return output

	def printScores(self):
		"""Print a per-residue table of the score columns.

		Each row holds the residue index, the ``headers`` values, the
		"insertion" value (or an empty field), and a bar of
		int((1 - WCS_W_p) * 100) asterisks.
		"""
		headers = ["Sequence","WCS_W_rStdev","WCS_W_p","WCS","ABS","ABV","ResidueDisorder","ignoreList"]
		insertions = self.data["insertion"]

		rows = ["#\t" + "\t".join(headers) + "\n"]
		for i in range(len(self.data["Sequence"])):
			row = str(i) + "\t" + "".join([str(self.data[key][i]) + "\t" for key in headers])
			if i in insertions:
				row += str(insertions[i]) + "\t"
			else:
				row += "\t"
			row += "*" * int((1 - self.data["WCS_W_p"][i]) * 100)
			rows.append(row + "\n")

		print("".join(rows))

