import re,copy,sys,os,re

from sets import Set

import ned_commandLine as commandline
import ned_basicReader as basicReader


# Tokeniser for motif positions: ".{..}" wildcard repeats, "[..]{..}" class
# repeats, "X{..}" residue repeats, "[..]" classes, and single residues,
# ".", "$" or "^".
amb = re.compile(r'\.\{[^\}]*\}|\[[^\]]*\]\{[^\}]*\}|[A-Z]\{[^\}]+\}|\[[^\]]*\]|\.\{[0-9]\}|[A-Za-z\.\$\^]')
# Same as `amb`, but an innermost parenthesised group is also one token.
ambWithMod = re.compile(r'\.\{[^\}]*\}|\[[^\]]*\]\{[^\}]*\}|[A-Z]\{[^\}]+\}|\[[^\]]*\]|\.\{[0-9]\}|[A-Za-z\.\$\^]|\([^\)\(]+\)')

# An innermost parenthesised group (contains no nested parentheses).
roundBrackets = re.compile(r'\([^\)\(]+\)')
# Two groups joined by "|", where neither group itself contains "|".
roundOptionsBrackets = re.compile(r'\([^\|\)\(]+\)\|\([^\|\)\(]+\)')
# Repeat terms: "[..]{..}", ".{..}" and "X{..}".
curlyBrackets = re.compile(r'\[[^\]]*\]\{[^\}]+\}|\.\{[^\}]+\}|[A-Z]\{[^\}]+\}')
# A class with a single-digit exact repeat, e.g. "[ST]{2}".
singleCurlyBracket = re.compile(r'\[[^\]]*\]\{[0-9]\}')
# A group directly preceded or followed by "|" (i.e. an alternation option).
notModBrackets = re.compile(r'\|\([^\)\(]+\)|\([^\)\(]+\)\|')
# A negated residue class, e.g. "[^P]".
negativeOptionBrackets = re.compile(r'\[\^[^\]]+\]')

# Background amino-acid frequency tables (residue -> frequency).
# NOTE(review): "all" presumably covers a whole reference proteome and the
# "discut*" tables residues above a disorder-score cutoff -- confirm source.
AA_Frequency = {}

AA_Frequency["all"] = {
	'A': 0.086167,
	'C': 0.012759,
	'D': 0.052951,
	'E': 0.061395,
	'F': 0.040392,
	'G': 0.071295,
	'H': 0.021925,
	'I': 0.060238,
	'K': 0.052746,
	'L': 0.098334,
	'M': 0.024703,
	'N': 0.041559,
	'P': 0.047353,
	'Q': 0.038562,
	'R': 0.054707,
	'S': 0.066961,
	'T': 0.056144,
	'V': 0.067440,
	'W': 0.013183,
	'Y': 0.030613}


AA_Frequency["discut0.3"] = {
	"A": 0.082767,
	"C": 0.007309,
	"D": 0.058283,
	"E": 0.082903,
	"F": 0.022992,
	"G": 0.073166,
	"H": 0.024454,
	"I": 0.043799,
	"K": 0.064038,
	"L": 0.073925,
	"M": 0.023502,
	"N": 0.043715,
	"P": 0.062209,
	"Q": 0.048512,
	"R": 0.064976,
	"S": 0.075305,
	"T": 0.062569,
	"V": 0.059192,
	"W": 0.007122,
	"Y": 0.019261}

AA_Frequency["discut0.5"] = {
	"A": 0.079609,
	"C": 0.005029,
	"D": 0.057989,
	"E": 0.08815,
	"F": 0.015527,
	"G": 0.075008,
	"H": 0.024831,
	"I": 0.028494,
	"K": 0.06696,
	"L": 0.055605,
	"M": 0.021528,
	"N": 0.047938,
	"P": 0.085282,
	"Q": 0.055837,
	"R": 0.067414,
	"S": 0.0964,
	"T": 0.065369,
	"V": 0.044676,
	"W": 0.004848,
	"Y": 0.013504}

# The 20 residue letters; a real list (not a dict view) on every Python version.
AA_list = list(AA_Frequency["all"].keys())
# Probability factor applied for each "^"/"$" anchor in a motif.
# NOTE(review): the meaning of 321 is not documented here (perhaps a typical
# protein length) -- confirm.
terminiProb = 1.0/321

class motifHelper():
	"""Helpers for ELM-style short linear motif regular expressions.

	Provides motif expansion (alternations and ``{min,max}`` repeats are
	enumerated into concrete variants), per-motif statistics and background
	probabilities, and loaders for the ELM class / instance tables named in
	the settings file.
	"""
	def __init__(self):
		"""Load options from ``../settings/utilities.ini`` next to this file."""
		cmdline = commandline.CommandLine()
		self.options = cmdline.loadIniFile(os.path.join(os.path.dirname(os.path.realpath(__file__)),"../settings/utilities.ini"))
		# Parsed ELM data, filled lazily by the initialise* methods.
		self.data = {}

	def expandMotif(self,motifInput):
		"""Expand a motif regular expression into a list of concrete variants.

		The motif is upper-cased and ``?`` is rewritten as ``{0,1}``. When the
		motif contains ``|`` or ``{``:

		1. Parenthesised groups reported by :meth:`countMods` are replaced by
		   ``m<i>`` placeholders, and the remaining groups by ``r<i>``.
		2. Alternatives are enumerated and the placeholders substituted back.
		3. Every ``X{a,b}`` / ``[..]{a,b}`` / ``.{a,b}`` term is expanded into
		   one variant per repeat count ``a..b`` (``{n}`` means exactly n).

		:param motifInput: motif regular expression.
		:returns: list of expanded motif strings. When nothing was expanded,
			falls back to a single-element list holding the motif string.
		:raises: any error met while expanding is re-raised after the
			offending motif is written to stderr.
		"""
		motif = motifInput.upper()
		motif = motif.replace("?","{0,1}")

		motifs = []
		expanded = []

		# Placeholder tables: "r<i>" -> group text, "m<i>" -> non-option group text.
		roundDict = {}
		roundBitIndexer = []
		roundModDict = {}
		roundModBitIndexer = []

		try:
			if motif.count("|") > 0 or motif.count("{") > 0:
				smallest = ""
				for roundBit in self.countMods(motif):
					roundModBitIndexer.append(roundBit)
					roundModDict["m" + str(roundModBitIndexer.index(roundBit))] = roundBit
					smallest += "m" + str(roundModBitIndexer.index(roundBit)) + "|"
					motif = motif.replace(roundBit,"m" + str(roundModBitIndexer.index(roundBit)))

				plainGroups = roundBrackets.findall(motif)

				if len(plainGroups) > 0:
					smallest = ""

				for roundBit in plainGroups:
					roundBitIndexer.append(roundBit)
					roundDict["r" + str(roundBitIndexer.index(roundBit))] = roundBit
					smallest += "r" + str(roundBitIndexer.index(roundBit)) + "|"
					motif = motif.replace(roundBit,"r" + str(roundBitIndexer.index(roundBit)))

				for roundBits in roundOptionsBrackets.findall(motif):
					for roundBit in roundBits.split("|"):
						roundBitIndexer.append(roundBit)
						roundDict["r" + str(roundBitIndexer.index(roundBit))] = roundBit
						smallest += "r" + str(roundBitIndexer.index(roundBit)) + "|"
						motif = motif.replace(roundBit,"r" + str(roundBitIndexer.index(roundBit)))

				motifs = [motif]

				# A motif that is only placeholder alternatives ("r0|r1|...") is
				# split into its individual alternatives.
				if motif == smallest[:-1]:
					for bit in motif.split("|"):
						motifs.append(bit)

					motifs.remove(motif)

				# NOTE(review): the body appends to and rebinds `motifs` while this
				# loop iterates it; the result is order-dependent and kept as-is.
				for motif in motifs:
					roundOption = roundBrackets.findall(motif)
					removeList = []

					for roundOptionBit in roundOption:
						for option in roundOptionBit[1:-1].split("|"):
							for motifTemp in motifs:
								if roundOptionBit in motifTemp:
									motifs.append(motifTemp.replace(roundOptionBit,option))

					for roundOptionBit in roundOption:
						for motifTemp in motifs:
							if motifTemp.count(roundOptionBit) > 0:
								removeList.append(motifTemp)

					motifs = [i for i in motifs if i not in removeList]

					# Substitute placeholders back; entries that still held a
					# placeholder are scheduled for removal.
					removeList = []
					for motifTemp in motifs:
						add = False

						removeList.append(motifTemp)

						for roundBit in roundDict:
							if roundBit in motifTemp:
								motifTemp = motifTemp.replace(roundBit,roundDict[roundBit]).replace("(","").replace(")","")
								add = True

						for roundBit in roundModDict:
							if roundBit in motifTemp:
								motifTemp = motifTemp.replace(roundBit,roundModDict[roundBit])
								add = True

						if add:
							if motifTemp not in motifs:
								motifs.append(motifTemp)
						else:
							removeList.pop()

				motifs = [i for i in motifs if i not in removeList]

				# Expand {min,max} repeat terms of every remaining variant.
				for submotif in motifs:
					subbits = ambWithMod.findall(submotif)
					variants = [submotif]
					curly = curlyBrackets.findall(submotif)

					if len(curly) > 0:
						rootlist = [""]
						last = 0

						for rb in curly:
							bounds = rb.split("{")[1][0:-1]
							minRepeat = bounds.split(",")[0]
							# Only "{min,max}" carries a second bound; "{n}" (any
							# number of digits) means exactly n.
							if "," in bounds:
								maxRepeat = bounds.split(",")[1]
							else:
								maxRepeat = minRepeat

							repeat = rb.split("{")[0]
							# Tokens between the previous repeat term and this one.
							offset = subbits[last:].index(rb)
							between = "".join(subbits[last:last + offset])

							temprootList = []
							for root in rootlist:
								for i in range(int(minRepeat),int(maxRepeat) + 1):
									temprootList.append(root + between + repeat*i)

							rootlist = temprootList
							last = last + offset + 1

						tail = "".join(subbits[last:])
						temprootList = []
						for root in rootlist:
							# NOTE(review): membership is tested on the dot-stripped
							# string but the unstripped string is stored, so only
							# variants without leading/trailing dots are reliably
							# de-duplicated.
							if (root + tail).strip(".") not in temprootList:
								temprootList.append(root + tail)

						variants = temprootList

					expanded += variants
		except Exception:
			# Report which motif failed before propagating the error.
			sys.stderr.write("Failed to expand motif: %s\n" % motifInput)
			raise

		if len(expanded) == 0:
			expanded.append(motif)

		return expanded

	def countMods(self,motif):
		"""Return the parenthesised groups of *motif* that are not alternation options.

		A group (innermost, no nested parentheses) counts as an alternation
		option when it is immediately preceded or followed by ``|``; all other
		groups are returned in left-to-right order.

		:param motif: motif regular expression.
		:returns: list of group strings such as ``"(S)"``.
		"""
		mods = []
		for match in roundBrackets.finditer(motif):
			precededByBar = match.start() > 0 and motif[match.start() - 1] == "|"
			followedByBar = motif[match.end():match.end() + 1] == "|"
			# Each group is classified once, so a group flanked by "|" on both
			# sides (e.g. the middle of "(A)|(B)|(C)") is handled correctly.
			if not (precededByBar or followedByBar):
				mods.append(match.group(0))
		return mods

	def calculateMotifProb(self,motif,aaProbabilities=None):
		"""Return the summed background probability of the variants of *motif*.

		Each expanded variant contributes the product over its positions:

		- a residue contributes its frequency;
		- ``[...]`` contributes the sum of its members' frequencies;
		- ``[^...]`` contributes the sum over the residues of ``AA_list`` it
		  does not exclude;
		- ``^`` / ``$`` contribute ``terminiProb``;
		- ``.`` contributes 1.

		:param motif: motif regular expression (expanded via :meth:`expandMotif`).
		:param aaProbabilities: residue -> frequency mapping; ``None`` (the
			default) means ``AA_Frequency["all"]``.
		:returns: float probability.
		"""
		if aaProbabilities is None:
			aaProbabilities = AA_Frequency["all"]

		expanded = self.expandMotif(motif)
		motifProb = 0

		for motifExpanded in expanded:
			subMotifProb = 1

			for pos in amb.findall(motifExpanded):
				if pos[0] == "[" and pos[1] == "^":
					excluded = set(pos[2:-1])
					ambProb = 0
					for aa in AA_list:
						if aa not in excluded:
							ambProb += aaProbabilities[aa]

					subMotifProb = subMotifProb*ambProb

				elif pos[0] == "[":
					ambProb = 0
					for aa in pos[1:-1]:
						ambProb += aaProbabilities[aa]

					subMotifProb = subMotifProb*ambProb

				elif pos[0] == "^" or pos[0] == "$":
					subMotifProb = subMotifProb*terminiProb

				elif pos[0] == ".":
					pass
				else:
					subMotifProb = subMotifProb*aaProbabilities[pos]

			motifProb += subMotifProb

		return motifProb

	def getMotifStats(self,motif):
		"""Summarise the composition of *motif*, averaged over its expanded variants.

		:param motif: motif regular expression.
		:returns: dict with:

			- ``len``, ``def``, ``amb``, ``fixed``, ``wild``, ``neg``: per-variant
			  averages formatted ``"%1.2f"``; anchors are excluded from
			  len/def/fixed.
			- ``N-term`` / ``C-term``: whether any variant is anchored with
			  ``^`` / ``$``.
			- ``Mod``: int of the average number of non-option groups.
			- ``var``: the number of variants.
			- ``expanded``: the variant list.
			- ``prob``, ``prob_0.3_discut``, ``prob_0.5_discut``: probabilities
			  formatted ``"%1.2e"``.
		"""
		motif = motif.replace("?","{0,1}")

		expanded = self.expandMotif(motif)

		wildCount = 0
		definedCount = 0
		lengthCount = 0
		fixedCount = 0
		modCount = 0
		negCount = 0
		ambCount = 0
		nTerm = False
		cTerm = False

		for motifExpanded in expanded:
			# A leading "M" right after "^" is not counted (presumably the
			# initiator methionine).
			if motifExpanded.startswith("^M"):
				motifExpanded = "^" + motifExpanded[2:]

			# Number of anchor tokens ("^"/"$") to exclude from the counts below.
			terminalAdj = 0

			if motifExpanded.endswith("$"):
				cTerm = True
				terminalAdj += 1

			if motifExpanded.startswith("^"):
				nTerm = True
				terminalAdj += 1

			negativeOptionCount = len(negativeOptionBrackets.findall(motifExpanded))
			motifExpanded = motifExpanded.strip(".")
			bits = amb.findall(motifExpanded)

			wildcards = bits.count(".")
			# Single-character tokens other than "." (residues and anchors).
			singles = [len(x) == 1 for x in bits].count(True) - wildcards
			# Tokens that are neither wildcards nor negated classes.
			defined = len(bits) - (negativeOptionCount + wildcards)

			fixedCount += singles - terminalAdj
			definedCount += defined - terminalAdj
			lengthCount += len(bits) - terminalAdj

			wildCount += wildcards
			ambCount += defined - singles
			modCount += len(self.countMods(motifExpanded))
			negCount += negativeOptionCount

		fixedCount = float(fixedCount)/len(expanded)
		definedCount = float(definedCount)/len(expanded)
		lengthCount = float(lengthCount)/len(expanded)

		wildCount = float(wildCount)/len(expanded)
		ambCount = float(ambCount)/len(expanded)
		modCount = float(modCount)/len(expanded)
		negCount = float(negCount)/len(expanded)

		return {"len":"%1.2f"%lengthCount,
		"def":"%1.2f"%definedCount,
		"amb":"%1.2f"%ambCount,
		"fixed":"%1.2f"%fixedCount,
		"wild":"%1.2f"%wildCount,
		"neg":"%1.2f"%negCount,
		"N-term":nTerm,
		"C-term":cTerm,
		"Mod":int(modCount),
		"var":len(expanded),
		"expanded":expanded,
		"prob":"%1.2e"%self.calculateMotifProb(motif,AA_Frequency["all"]),
		"prob_0.3_discut":"%1.2e"%self.calculateMotifProb(motif,AA_Frequency["discut0.3"]),
		"prob_0.5_discut":"%1.2e"%self.calculateMotifProb(motif,AA_Frequency["discut0.5"])}

	def parseGOTerms(self):
		"""Print ELM classes grouped by their "GO cellular component" annotations.

		Reads ``self.options["elmClasses"]`` keyed by ``elmclassid`` (replacing
		``self.classes``). It then prints:

		- each class and its cellular-component terms;
		- one line per term with the number and list of classes;
		- the number of classes annotated with both cytosol and nucleus.
		"""
		self.classes = basicReader.readTableFile(self.options["elmClasses"],key="elmclassid",byColumn=False,relationship="binary")

		goTermClasses = {}
		for elmClass in self.classes:
			print("%s \t" % elmClass)
			for goTerm in self.classes[elmClass]['externallinks'].split(","):
				# Strip once so keys do not depend on the term's list position.
				goTerm = goTerm.strip()
				if goTerm.split(":")[0] == "GO cellular component":
					print(goTerm)
					goTermClasses.setdefault(goTerm, []).append(elmClass)

		for goTerm in goTermClasses:
			termClasses = set(goTermClasses[goTerm])
			print("%s \t%d \t%s" % (goTerm, len(termClasses), ",".join(termClasses)))

		# Absent terms count as having no classes instead of raising KeyError.
		cytosol = set(goTermClasses.get("GO cellular component:cytosol (GO:0005829)", []))
		nucleus = set(goTermClasses.get("GO cellular component:nucleus (GO:0005634)", []))
		print(len(cytosol.intersection(nucleus)))

	def initialiseELMDataByID(self,types=("LIG","CLV","TRG","MOD")):
		"""Load ELM classes and instances and index the instances by instance key.

		Populates ``self.data["ID"]`` (instance key -> record, including
		``"Acc"``) and ``self.data["ELM"]`` (accession -> list of records).
		Only instances whose ``ELMType`` column is in *types* are kept. The
		instance key is ``<ELMIdentifier>_<first accession>_<Start>``.

		:param types: ELM type codes to keep.
		"""
		self.classes = basicReader.readTableFile(self.options["elmClasses"],key="ELMIdentifier",byColumn=False,relationship="binary",stripQuotes =True)
		self.types = basicReader.readTableFile(self.options["elmInstances"],key="ELMIdentifier",byColumn=False,relationship="oneToMany",stripQuotes =True)

		self.instances = {}
		for elmIdentifier in self.types:
			for instance in self.types[elmIdentifier]:
				tmpDict = {
				'elmclasstype':instance['ELMType'],
				'elmclass':elmIdentifier,
				'elmregexpr':self.classes[elmIdentifier]["Regex"],
				'uniprotid':instance["Accessions"].split()[0],
				'motifstop':instance['End'],
				'motifstart':instance['Start'],
				'taxon':instance["Organism"],
				}

				self.instances[elmIdentifier + "_" + tmpDict['uniprotid'] + "_" + instance['Start']] = tmpDict

		self.data["ID"] = {}
		self.data["ELM"] = {}

		for instanceId in self.instances:
			instance = self.instances[instanceId]
			if instance['elmclasstype'] not in types:
				continue

			elmClass = instance['elmclass']
			accession = instance['uniprotid']
			href = 'http://elm.eu.org/elmPages/' + instanceId + ".html"
			start = int(instance['motifstart'])
			stop = int(instance['motifstop'])

			# NOTE(review): the instances table read above provides no protein
			# sequence, so "Sequence" is empty unless one is added upstream.
			record = {"Start":start, "End":stop + 1, "Desc":elmClass, "LongDesc":elmClass, "href":href, "RE":instance['elmregexpr'], 'Taxon':instance['taxon'], "ELMID":instanceId, 'Sequence':instance.get('proteinsequence', "")}

			self.data["ELM"].setdefault(accession, []).append(record)

			idRecord = dict(record)
			idRecord["Acc"] = accession
			self.data["ID"][instanceId] = idRecord

	def initialiseELMData(self,types=("LIG","CLV","TRG","MOD")):
		"""Load ELM classes and instances into the ``self.data`` indexes.

		Populates:

		- ``self.data["RE"]``: regex -> description.
		- ``self.data["bindingDomain"]``: class -> binding domain.
		- ``self.data["ELM"]``: accession -> records.
		- ``self.data["Class"]`` / ``self.data["Type"]``: class or type ->
		  accession -> records.

		Only classes / instances whose type code is in *types* are kept.
		Instance keys (``"ELMID"``) are ``<ELMIdentifier>_<first accession>_<Start>``.
		``"End"`` is the table's End value + 1.

		:param types: ELM type codes to keep.
		"""
		self.classes = basicReader.readTableFile(self.options["elmClasses"],key="ELMIdentifier",byColumn=False,relationship="binary",stripQuotes =True)
		self.classInstances = basicReader.readTableFile(self.options["elmInstances"],key="ELMIdentifier",byColumn=False,relationship="oneToMany",stripQuotes =True)

		self.data["ELM"] = {}
		self.data["RE"] = {}
		self.data["Class"] = {}
		self.data["Type"] = {}
		self.data["bindingDomain"] = {}

		for elmClass in self.classes:
			elmType = elmClass[0:3]
			if elmType in types:
				self.data["RE"][self.classes[elmClass]['Regex']] = {"Desc":elmClass}
				self.data["bindingDomain"][elmClass] = {"bindingDomain":self.classes[elmClass]['bindingDomain']}

		self.instances = {}
		for elmClass in self.classInstances:
			for instance in self.classInstances[elmClass]:
				tmpDict = {
				'elmclasstype':instance['ELMType'],
				'elmclass':elmClass,
				'elmregexpr':self.classes[elmClass]["Regex"],
				'uniprotid':instance["Accessions"].split()[0],
				'motifstop':instance['End'],
				'motifstart':instance['Start'],
				'taxon':instance["Organism"],
				'bindingDomain':self.classes[elmClass]['bindingDomain']
				}

				# Key on this instance's own class so instances of different
				# classes at the same position do not overwrite each other.
				self.instances[elmClass + "_" + tmpDict['uniprotid'] + "_" + instance['Start']] = tmpDict

		for instanceId in self.instances:
			instance = self.instances[instanceId]
			elmType = instance['elmclasstype']
			if elmType not in types:
				continue

			elmClass = instance['elmclass']
			elmRE = instance['elmregexpr']
			accession = instance['uniprotid']
			href = 'http://elm.eu.org/elmPages/' + instanceId + ".html"

			self.data["RE"][elmRE] = {"Desc":elmClass,"LongDesc":elmClass,"href":href}

			start = int(instance['motifstart'])
			stop = int(instance['motifstop'])

			self.data["ELM"].setdefault(accession, []).append({"Start":start, "End":stop + 1, "Desc":elmClass,"LongDesc":elmClass,"href":href,"RE":elmRE,'Taxon':instance['taxon'],"ELMID":instanceId})
			self.data["Class"].setdefault(elmClass, {}).setdefault(accession, []).append({"Start":start, "End":stop + 1, "Desc":elmClass,"LongDesc":elmClass,"href":href,"RE":elmRE})
			self.data["Type"].setdefault(elmType, {}).setdefault(accession, []).append({"Start":start, "End":stop + 1, "Desc":elmClass,"LongDesc":elmClass,"href":href,"RE":elmRE})

	def readELMData(self,accession=""):
		"""Return the ELM instance records for *accession*, loading data on first use.

		:param accession: protein accession; ``""`` only triggers loading.
		:returns: list of records for a known accession, ``{}`` for an
			unknown one, ``None`` when *accession* is empty.
		"""
		if "ELM" not in self.data:
			self.initialiseELMData()

		if accession == "":
			return None

		return self.data["ELM"].get(accession, {})

	####
	# Not tested
	####
	def findREHits(self):
		"""Record matches of every loaded motif regex in ``self.data["Sequence"]``.

		Matches accepted by ``self.inFrame`` are appended to
		``self.data["Region"]["RE"]``. The "Start"/"End" values are the match
		span + 1.

		NOTE(review): untested; relies on ``self.data["Region"]``,
		``self.data["Sequence"]`` and ``self.inFrame``, none of which are set
		or defined in this class.
		"""
		if len(self.data["RE"]) > 0:
			self.data["Region"]["RE"] = []

			for motif in self.data["RE"]:
				motifMatcher = re.compile(motif)
				for match in motifMatcher.finditer(self.data["Sequence"]):
					start = int(match.span()[0]) + 1
					stop = int(match.span()[1]) + 1

					if self.inFrame(start,stop):
						reInfo = self.data["RE"][motif]
						try:
							self.data["Region"]["RE"].append({"Start":start, "End":stop, "Desc":reInfo["Desc"],"LongDesc":reInfo["LongDesc"],"href":reInfo["href"]})
						except KeyError as e:
							# Class-only regex entries carry no "LongDesc"/"href".
							print(e)
						
if __name__ == "__main__":
	# Usage: run with no arguments for the GO cellular-component summary, or
	# with "--stats" to print a tab-separated table of motif statistics for
	# every ELM class regular expression.
	motifHelpObj = motifHelper()

	if "--stats" not in sys.argv[1:]:
		motifHelpObj.parseGOTerms()
		sys.exit()

	motifHelpObj.initialiseELMData()

	headers = ["prob","prob_0.3_discut","prob_0.5_discut" ,'fixed','def','amb','wild','neg','len','var','Mod','N-term','C-term']
	headerStr = "\t".join(headers) + "\tMotifClass\tMotifRE" "\n"

	rows = []
	for motifClass in motifHelpObj.classes:
		# Named motifRE (not `re`) so the module-level `re` import is not rebound.
		motifRE = motifHelpObj.classes[motifClass]['Regex']

		motifStats = motifHelpObj.getMotifStats(motifRE)

		row = "".join(str(motifStats[header]) + "\t" for header in headers)
		rows.append(row + motifClass + "\t" + motifRE + "\n")

	print(headerStr + "".join(rows))
	
	