import { TranscriptUnit, Rangable, LinearTurnsTranscript, SpeakerTurn } from "./common_db";
import { distance } from "fastest-levenshtein";
import { ModelType, MODEL_INFORMATION } from "./llm_clients";



/**
 * Type guard telling leaf transcript units apart from container rangables.
 * A value counts as a TranscriptUnit when it exposes a `word` property.
 */
function isTUnit(r: Rangable): r is TranscriptUnit {
    const hasWord = 'word' in r;
    return hasWord;
}

/**
 * Formatting rules for one nesting level of a Rangable tree.
 * `iterateRangable` applies the first formatter of its list to the level it is
 * walking and hands the remaining formatters to the nested contents.
 */
type RF = {
    // Text emitted before a container's contents; an empty string emits nothing.
    getHeader: (r: any) => string;
    // Text emitted between sibling items of this level.
    joinDelimiter: string;
}

/**
 * Two-level formatter list: the first entry prints the speaker followed by a
 * newline and separates siblings with a newline; the second entry prints no
 * header and separates siblings with a single space.
 */
export const basicTurnFormatters = [
  {
      getHeader: (r: any) => `${r.speaker}\n`,
      joinDelimiter: '\n'
  },
  {
      getHeader: (r: any) => "",
      joinDelimiter: ' '
  }
]

/**
 * Two-level formatter list: the first entry prints "<speaker>:" followed by a
 * newline and separates siblings with a blank line; the second entry prints no
 * header and separates siblings with a single space.
 */
export const copyTextTurnFormatters = [
  {
      getHeader: (r: any) => `${r.speaker}:\n`,
      joinDelimiter: '\n\n'
  },
  {
      getHeader: (r: any) => "",
      joinDelimiter: ' '
  }
]

/** Formatter that contributes no text at all: empty header, empty delimiter. */
const EMPTY_RF: RF = {
  getHeader(r: any) {
    return '';
  },
  joinDelimiter: ''
}

/**
 * Counts the nesting levels of a rangable by following the first child of each
 * container until a transcript unit is reached. A bare unit has depth 1.
 */
function rangableDepth(r: Rangable): number {
    let levels = 1;
    let cursor: Rangable = r;
    while (!isTUnit(cursor)) {
        levels += 1;
        cursor = cursor.contents[0];
    }
    return levels;
}

/**
 * One piece of text produced while walking a Rangable tree.
 * `transcriptUnit` is set when the text is a unit's word and left unset for
 * headers and delimiters.
 */
export type RangableIteratorResult = {
    text: string;
    transcriptUnit?: TranscriptUnit;
  };
  
/**
 * Walks a list of rangables depth-first and yields the text pieces that make
 * up its formatted form.
 *
 * `formatters[0]` describes the current level; nested contents are walked with
 * the remaining formatters. Transcript units yield their word (tagged with the
 * unit) followed by the level's delimiter when another sibling follows.
 * Containers yield their header (when non-empty), their contents, and then the
 * delimiter (when non-empty and another sibling follows).
 */
export function* iterateRangable(
rangables: Rangable[],
formatters: RF[]
): Generator<RangableIteratorResult, void, unknown> {
    for (let position = 0; position < rangables.length; position++) {
        const item = rangables[position];

        if (!isTUnit(item)) {
            const levelFormatter = formatters[0];

            const headerText = levelFormatter.getHeader(item);
            if (headerText) {
                yield { text: headerText };
            }

            yield* iterateRangable(item.contents, formatters.slice(1));

            const separator = levelFormatter.joinDelimiter;
            if (separator && position < rangables.length - 1) {
                yield { text: separator };
            }
            continue;
        }

        yield { text: item.word, transcriptUnit: item };

        // The formatter is only consulted when a sibling follows, so the last
        // unit of a list never touches formatters[0].
        if (position < rangables.length - 1) {
            yield { text: formatters[0].joinDelimiter };
        }
    }
}

/**
 * Renders a list of rangables to a single string by concatenating every text
 * piece produced by `iterateRangable` for the given formatters.
 */
export function stringifyRangable(rangables: Rangable[], formatters: RF[]): string {
    const pieces: string[] = [];
    for (const { text } of iterateRangable(rangables, formatters)) {
        pieces.push(text);
    }
    return pieces.join('');
}

/**
 * Wraps a list of rangables in an object usable with `for...of`; every
 * iteration request starts a fresh `iterateRangable` walk.
 */
export function makeRangableIterable(rangables: Rangable[], formatters: RF[]) {
    const iterable = {
        [Symbol.iterator]: (): Iterator<RangableIteratorResult> =>
            iterateRangable(rangables, formatters)
    };
    return iterable;
}


// Bookkeeping fields that every Annotation carries in addition to its span and content.
// NOTE(review): the unit of creationTime / lastModifiedTime is not visible in this file — confirm.
export type AnnotationMetadata = {
  id: string;
  authorID: string;
  madeByHuman: boolean
  conversationID: string;
  creationTime: number;
  lastModifiedTime: number;
  tags: string[];
  groupID: string;
}

// String key/value pairs of an annotation; createAnnotationExamples writes each
// entry as an attribute of the opening <<anno ...>> tag.
export type AnnotationContent = Record<string, string>;

// Attribute name/value pairs read from an <<anno ...>> opening tag.
export type AnnotationAttributes = Record<string, string>;

// An annotation over the transcript span that begins at startUnit and ends at
// endUnit (createAnnotationExamples opens the tag before startUnit's text and
// closes it after endUnit's text).
export type Annotation = {
  startUnit: TranscriptUnit, 
  endUnit: TranscriptUnit
  content: AnnotationContent
  attributes?: AnnotationAttributes // only for info llms parsed from annotations
} & AnnotationMetadata; 

// Annotation group that refers to its annotations by ID.
// Presumably the shape persisted in the database, going by the name — confirm.
export type AnnotationGroupDBValue = {
  id: string;
  name: string;
  annotationIDs: string[];
  tags: string[];
  authorID: string;
  conversationID: string;
  creationTime: number;
  madeByHuman: boolean;
  promptDebugValueID?: string;
}

// Same fields as AnnotationGroupDBValue, with the ID list replaced by the annotation objects themselves.
export type AnnotationGroup = Omit<AnnotationGroupDBValue, 'annotationIDs'> & {annotations: Annotation[]};

// An opened <<anno>> tag whose closing tag has not been seen yet; startUnit is
// filled in once a transcript unit is matched after the opening tag.
type PartialAnnotation = {startUnit?: TranscriptUnit, attributes: AnnotationAttributes};
// A fully parsed annotation: opening-tag attributes plus the first and last unit it covers.
type ParsedAnnotation = {startUnit: TranscriptUnit, endUnit: TranscriptUnit, attributes: AnnotationAttributes};

/**
 * Error raised when annotated LLM text cannot be aligned with the transcript
 * or contains malformed / unbalanced annotation tags.
 */
export class AnnotationParseError extends Error {
  /** Position in the transcript text reported by the thrower (some call sites pass -1). */
  transcriptTextIndex: number;
  /** Position in the LLM text reported by the thrower. */
  llmTextIndex: number;

  constructor(message: string, transcriptTextIndex: number, llmTextIndex: number) {
    super(message);
    this.name = 'AnnotationParseError';
    this.transcriptTextIndex = transcriptTextIndex;
    this.llmTextIndex = llmTextIndex;
  }
}


// Plain-object counterpart of AnnotationParseError: same message and index fields,
// without the Error prototype.
// Presumably used where the error has to be serialized — confirm against callers.
export type AnnotationParseErrorData = {
  message: string;
  transcriptTextIndex: number;
  llmTextIndex: number;
}

/**
 * Finds the first token that carries a transcript unit.
 *
 * @returns the token together with its index in `tokens`, or `undefined` when
 *          no token has a transcript unit.
 */
function nextTranscriptUnit(tokens: RangableIteratorResult[]): [RangableIteratorResult, number] | undefined {
    const position = tokens.findIndex(candidate => Boolean(candidate.transcriptUnit));
    return position === -1 ? undefined : [tokens[position], position];
}

/**
 * Strips any run of `. , ! ? ; :` characters from the end of the text.
 * Text without trailing punctuation is returned unchanged.
 */
function removeTrailingPunctuation(text: string): string {
    return text.replace(/[.,!?;:]+$/, '');
}

/**
 * Removes annotation markup from LLM text so that only the transcript words
 * remain for fuzzy comparison.
 *
 * Opening tags are recognised with the same grammar the parsers use, including
 * optional whitespace before the closing `>>`; otherwise a tag such as
 * `<<anno id="1" >>` would be parsed as a tag but still inflate the edit
 * distance. Closing tags are matched leniently: both `<</anno>>` and
 * `</anno>>` are removed.
 */
function removeAnnotationTags(text: string): string {
    //TODO(hs): check for malformed annotation tags, throw error
    text = text.replace(/<<anno(\s+[a-zA-Z][a-zA-Z0-9]*="[^"]*")*\s*>>/g, '');
    text = text.replace(/<<?\/anno>>/g, '');
    return text;
}



/**
 * Handles the annotation tags found inside a stretch of LLM text that was
 * matched to the transcript only approximately.
 *
 * The text itself is skipped character by character; only `<<anno ...>>` and
 * `<</anno>>` tags are interpreted. Opening tags push a new entry onto
 * `partialAnnotations`; closing tags pop the most recent entry and append the
 * finished annotation to `parsedAnnotations`. Both arrays are mutated in place.
 *
 * @param token the transcript token the segment starts at. When it carries a
 *        transcript unit, that unit becomes the current "last unit" and the
 *        start unit of every pending annotation that has none yet.
 * @param inputTextToProcess the LLM text of the segment.
 * @param partialAnnotations stack of annotations that are open on entry.
 * @param parsedAnnotations output list of completed annotations.
 * @param initialLastTranscriptUnit last unit matched before this segment.
 * @throws AnnotationParseError on a malformed opening tag, on a closing tag
 *         without a matching opening tag, or when a tag closes before any
 *         transcript unit is known. The transcript index of these errors is -1
 *         and the LLM index is relative to `inputTextToProcess`.
 */
function processInnerAnnotations(
  token: RangableIteratorResult,
  inputTextToProcess: string,
  partialAnnotations: PartialAnnotation[],
  parsedAnnotations: ParsedAnnotation[],
  initialLastTranscriptUnit: TranscriptUnit | undefined) {

  const OPEN_TAG = "<<anno";
  const CLOSE_TAG = "<</anno>>";

  let currentIndex = 0;
  let lastTranscriptUnit = initialLastTranscriptUnit;
  if (token.transcriptUnit) {
    lastTranscriptUnit = token.transcriptUnit;
    for (const partialAnnotation of partialAnnotations) {
      if (!partialAnnotation.startUnit) {
        partialAnnotation.startUnit = token.transcriptUnit;
      }
    }
  }

  while (currentIndex < inputTextToProcess.length) {
    if (inputTextToProcess.startsWith(OPEN_TAG, currentIndex)) {
      // Only slice when a tag actually starts here, instead of once per character.
      const annotationMatch = inputTextToProcess
        .slice(currentIndex)
        .match(/^<<anno(\s+[a-zA-Z][a-zA-Z0-9]*="[^"]*")*\s*>>/);
      if (!annotationMatch) {
        throw new AnnotationParseError(
          `Invalid annotation start tag in fuzzy match section`,
          -1, currentIndex);
      }
      const fullMatch = annotationMatch[0];

      // Collect every key="value" pair of the tag.
      const attributes: Record<string, string> = {};
      for (const match of fullMatch.matchAll(/([a-zA-Z][a-zA-Z0-9]*)="([^"]*)"/g)) {
        const [_, key, value] = match;
        attributes[key] = value;
      }

      partialAnnotations.push({ attributes });
      currentIndex += fullMatch.length;
    } else if (inputTextToProcess.startsWith(CLOSE_TAG, currentIndex)) {
      const partialAnnotation = partialAnnotations.pop();
      if (!partialAnnotation) {
        throw new AnnotationParseError(
          `Unexpected closing annotation in fuzzy match section`,
          -1, currentIndex);
      }
      if (!lastTranscriptUnit) {
        throw new AnnotationParseError(
          `Invalid annotation state in fuzzy match section`,
          -1, currentIndex);
      }
      parsedAnnotations.push({
        // A tag opened inside this segment never received a start unit; anchor
        // it to the last known unit rather than emitting an undefined start.
        startUnit: partialAnnotation.startUnit ? partialAnnotation.startUnit : lastTranscriptUnit,
        endUnit: lastTranscriptUnit,
        attributes: partialAnnotation.attributes,
      });
      currentIndex += CLOSE_TAG.length;
    } else {
      // Plain text: nothing to interpret.
      currentIndex++;
    }
  }
}

/**
 * Aligns annotated LLM output with the transcript it was produced from and
 * returns the annotations it contains.
 *
 * The transcript is rendered to tokens with `iterateRangable`. The LLM text is
 * then consumed left to right: `<<anno key="value" ...>>` opens an annotation,
 * `<</anno>>` closes the most recently opened one, and everything else has to
 * match the transcript tokens. When the text does not match a token exactly,
 * the text up to the next transcript unit is compared with the corresponding
 * transcript text by edit distance and accepted when it is close enough.
 *
 * @param llmText the annotated text produced by the LLM.
 * @param tr the transcript the LLM was asked to annotate.
 * @param formatters formatters used to render the transcript; they must be the
 *        ones the LLM input was rendered with for the texts to line up.
 * @returns the parsed annotations, in the order their closing tags were seen.
 * @throws AnnotationParseError for malformed or unbalanced tags and for text
 *         that cannot be matched to the transcript. The error carries the
 *         offset into the rendered transcript text and into `llmText`.
 */
export function parseAnnotations(llmText: string, tr: Rangable[], formatters: RF[]) {
  const parsedAnnotations: ParsedAnnotation[] = [];
  const partialAnnotations: PartialAnnotation[] = [];
  let inputTextIndex = 0;
  const transcriptTokens = [...iterateRangable(tr, formatters)];
  let tokenIndex = 0;
  let lastTranscriptUnit: TranscriptUnit | undefined;

  // Length of the transcript text consumed so far. The tokens are objects, so
  // their text has to be extracted before joining; joining the tokens directly
  // would measure "[object Object]" once per token.
  const transcriptOffset = (): number =>
    transcriptTokens.slice(0, tokenIndex).map(t => t.text).join("").length;

  while (tokenIndex < transcriptTokens.length) {
    const token = transcriptTokens[tokenIndex];
    const remainingInput = llmText.slice(inputTextIndex);

    if (remainingInput.startsWith("<<anno")) {
      // Parse annotation start tag
      const annotationMatch = remainingInput.match(
        /^<<anno(\s+[a-zA-Z][a-zA-Z0-9]*="[^"]*")*\s*>>/
      );
      if (!annotationMatch) {
        throw new AnnotationParseError(
          `Invalid annotation start tag at position ${inputTextIndex}: "${remainingInput.slice(
            0,
            50
          )}"`, transcriptOffset(), inputTextIndex)
      }
      const fullMatch = annotationMatch[0];
      // Extract attributes
      const attributes: Record<string, string> = {};
      for (const match of fullMatch.matchAll(/([a-zA-Z][a-zA-Z0-9]*)="([^"]*)"/g)) {
        const [_, key, value] = match;
        attributes[key] = value;
      }

      partialAnnotations.push({ attributes });
      inputTextIndex += fullMatch.length;
      continue;
    }

    if (remainingInput.startsWith("<</anno>>")) {
      const partialAnnotation = partialAnnotations.pop();
      // No open annotation, or no unit matched yet: the closing tag has nothing to close.
      if (!partialAnnotation || !lastTranscriptUnit) {
        throw new AnnotationParseError(
          `Unexpected closing annotation at position ${inputTextIndex}: "${remainingInput.slice(0, 50)}"`,
          transcriptOffset(), inputTextIndex);
      }
      parsedAnnotations.push({
        // An annotation closed before any unit followed its opening tag has no
        // start unit; anchor it to the last matched unit, as is done for
        // annotations closed after the end of the transcript.
        startUnit: partialAnnotation.startUnit ? partialAnnotation.startUnit : lastTranscriptUnit,
        endUnit: lastTranscriptUnit,
        attributes: partialAnnotation.attributes,
      });
      inputTextIndex += "<</anno>>".length;
      continue;
    }

    // Try to match the token text directly
    if (llmText.startsWith(token.text, inputTextIndex)) {
      inputTextIndex += token.text.length;
      tokenIndex++;
      if (token.transcriptUnit) {
        lastTranscriptUnit = token.transcriptUnit;
        for (const partialAnnotation of partialAnnotations) {
          if (!partialAnnotation.startUnit) {
            partialAnnotation.startUnit = token.transcriptUnit;
          }
        }
      }
      continue;
    }

    // at this point, llm text is either slightly malformed or totally off. Attempting a fuzzy match
    // +1 to exclude current token which could be a transcript unit
    const nextUnit = nextTranscriptUnit(transcriptTokens.slice(tokenIndex + 1));
    if (!nextUnit) {
      // at the end of the transcript, just process all final annotations
      processInnerAnnotations(
        token, llmText.slice(inputTextIndex), partialAnnotations, parsedAnnotations, lastTranscriptUnit);
      if (token.transcriptUnit) {
        lastTranscriptUnit = token.transcriptUnit;
      }
      // The rest of the LLM text has been consumed, closing tags included, so
      // the end-of-input check below must not count those tags a second time.
      inputTextIndex = llmText.length;
      break;
    }

    const [nextToken, nextTokenIndex] = nextUnit;
    const nextTokenText = removeTrailingPunctuation(nextToken.text);
    const i = llmText.slice(inputTextIndex).indexOf(nextTokenText);
    if (i === -1) {
      throw new AnnotationParseError(`Incoming llm text doesn't match transcript and next transcript unit doesn't match anything.
                LLM text: "${llmText.slice(inputTextIndex, inputTextIndex + 50)}"
                Next token text: "${nextTokenText}"
                Tokens: ${transcriptTokens.slice(tokenIndex, tokenIndex + 5).map(t => t.text).join('')}`,
        transcriptOffset(), inputTextIndex);
    }
    // +1 to account for earlier indexing
    const transcriptText = transcriptTokens.slice(tokenIndex, tokenIndex + nextTokenIndex + 1).map(t => t.text).join('');
    const inputTextToProcess = llmText.slice(inputTextIndex, inputTextIndex + i);
    const possibleFuzzyMatchText = removeAnnotationTags(inputTextToProcess);
    const editDistance = distance(transcriptText, possibleFuzzyMatchText);
    // arbitrary threshold for now, will see if it works and maybe get a better fuzzy match criteria.
    const relThreshold = Math.min(
      Math.round(transcriptText.length * 0.75),
      Math.round(possibleFuzzyMatchText.length * 0.75));
    if (editDistance >= 8 || editDistance >= relThreshold) {
      throw new AnnotationParseError(`Incoming llm text couldn't be fuzzy matched to transcript text. 
                LLM text: "${inputTextToProcess}", transcript text: "${transcriptText}"
                NextToken: "${nextToken.text}"
                I: ${i}, edit distance: ${editDistance}
                NextToken no punct: "${nextTokenText}"`,
        transcriptOffset(), inputTextIndex);
    }

    // Process the text between current position and next match for annotations
    processInnerAnnotations(
      token, inputTextToProcess, partialAnnotations, parsedAnnotations, lastTranscriptUnit);
    // The segment consumed `token`; when it is a unit it is now the last unit seen.
    if (token.transcriptUnit) {
      lastTranscriptUnit = token.transcriptUnit;
    }
    // Update indices to continue after the fuzzy matched section
    inputTextIndex += i;
    tokenIndex += nextTokenIndex + 1;
  }

  if (partialAnnotations.length > 0) {
    // Closing tags that follow the last transcript token were not seen by the
    // loop above; accept them only when they close exactly the open annotations.
    const trailingClosers = (llmText.slice(inputTextIndex).match(/<<\/anno>>/g) || []).length;
    if (trailingClosers !== partialAnnotations.length) {
      throw new AnnotationParseError(
        `Unclosed annotations at end of input text.`,
        transcriptOffset(), inputTextIndex);
    }
    if (!lastTranscriptUnit) {
      throw new AnnotationParseError(
        'End of transcript, still closing anno tags, but no "lastTranscriptUnit"',
        transcriptOffset(), inputTextIndex);
    }

    for (const partialAnnotation of partialAnnotations) {
      parsedAnnotations.push({
        startUnit: partialAnnotation.startUnit ? partialAnnotation.startUnit : lastTranscriptUnit,
        endUnit: lastTranscriptUnit,
        attributes: partialAnnotation.attributes,
      });
    }
  }
  return parsedAnnotations;
}

/**
 * Looks up the transcript unit with the given id anywhere inside `rangables`.
 *
 * The tree is walked with formatters that emit no text, one per nesting level;
 * the nesting depth is taken from the first rangable.
 *
 * @throws Error when no unit with that id exists, including when `rangables`
 *         is empty.
 */
export function findUnitInRangable(rangables: Rangable[], id: number): TranscriptUnit {
  // An empty list has no first element to measure the depth of.
  if (rangables.length > 0) {
    // Measure once instead of on every pass through a loop condition.
    const depth = rangableDepth(rangables[0]);
    const formatters: RF[] = [];
    for (let level = 0; level < depth; level++) {
      formatters.push(EMPTY_RF);
    }

    for (const token of makeRangableIterable(rangables, formatters)) {
      if (token.transcriptUnit && token.transcriptUnit.id === id) {
        return token.transcriptUnit;
      }
    }
  }
  throw new Error(`Transcript unit with id ${id} not found in rangables`);
}

/**
 * Renders a transcript twice: once as plain text and once with the given
 * annotations written inline as `<<anno id=".." key="value" ...>>` /
 * `<</anno>>` tags.
 *
 * An opening tag is placed directly before the text of the annotation's start
 * unit and a closing tag directly after the text of its end unit. Tags are
 * kept balanced: a closing tag is only written for an annotation that has been
 * opened, and annotations still open at the end of the transcript are closed
 * there.
 *
 * NOTE: attribute values are written verbatim; a value containing a double
 * quote yields a tag the parsers in this file do not accept.
 *
 * @param transcript the rangables to render.
 * @param annotations annotations to mark up; their start/end units are matched
 *        to transcript units by id.
 * @param formatters formatters used to render the transcript.
 * @returns `llmText` (with tags) and `transcriptText` (without).
 */
export function createAnnotationExamples(
  transcript: Rangable[],
  annotations: Annotation[],
  formatters: RF[]
): {llmText: string, transcriptText: string} {
  let curAnnoID = 0;

  // Index the annotations by the unit they start and end at.
  const annosByStartUnit = new Map<number, Annotation[]>();
  const annosByEndUnit = new Map<number, Annotation[]>();
  for (const anno of annotations) {
    const starting = annosByStartUnit.get(anno.startUnit.id);
    if (starting) {
      starting.push(anno);
    } else {
      annosByStartUnit.set(anno.startUnit.id, [anno]);
    }
    const ending = annosByEndUnit.get(anno.endUnit.id);
    if (ending) {
      ending.push(anno);
    } else {
      annosByEndUnit.set(anno.endUnit.id, [anno]);
    }
  }

  // Number of opening tags written per annotation that still await a closing
  // tag. A count (rather than a set) keeps the tags balanced even when the same
  // annotation object appears more than once in `annotations`.
  const openCounts = new Map<Annotation, number>();

  let llmText = '';
  let transcriptText = '';

  for (const token of iterateRangable(transcript, formatters)) {
    const unit = token.transcriptUnit;

    if (unit) {
      // Add opening annotation tags for any annotations starting at this unit
      for (const anno of annosByStartUnit.get(unit.id) || []) {
        openCounts.set(anno, (openCounts.get(anno) || 0) + 1);
        llmText += '<<anno';
        // Add all attributes from the annotation
        if (anno.content) {
          llmText += ` id="${curAnnoID}"`;
          curAnnoID++;
          for (const [key, value] of Object.entries(anno.content)) {
            llmText += ` ${key}="${value}"`;
          }
        }
        llmText += '>>';
      }
    }

    // Add the token text to both output strings
    llmText += token.text;
    transcriptText += token.text;

    if (unit) {
      // Add closing annotation tags for any annotations ending at this unit
      for (const anno of annosByEndUnit.get(unit.id) || []) {
        const stillOpen = openCounts.get(anno) || 0;
        // Skip annotations that were never opened (their start unit is missing
        // from this transcript or comes later): a lone closing tag would make
        // the text unparseable.
        if (stillOpen > 0) {
          openCounts.set(anno, stillOpen - 1);
          llmText += '<</anno>>';
        }
      }
    }
  }

  // Close any remaining open annotations at the end
  for (const stillOpen of openCounts.values()) {
    llmText += '<</anno>>'.repeat(stillOpen);
  }

  return {
    llmText,
    transcriptText
  };
}

/**
 * Rough token estimate for a piece of text: 1.3 tokens per whitespace-separated
 * word, rounded up. Empty or whitespace-only text counts as 0 tokens, and
 * leading/trailing whitespace does not add phantom words.
 */
function estimateTokens(text: string): number {
  const trimmed = text.trim();
  if (trimmed === '') {
    return 0;
  }
  return Math.ceil(trimmed.split(/\s+/).length * 1.3);
}

/**
 * Cuts a speaker turn into consecutive smaller turns.
 *
 * Every unit of the turn ends up in exactly one of the returned turns, in the
 * original order. Each unit listed in `units` (matched by id) is the LAST unit
 * of its piece; the following unit starts the next piece. With no split points
 * the result is a single turn holding all units; a turn without contents
 * yields an empty array.
 *
 * @param turn the turn to split; it is not modified.
 * @param units split points: the units after which a new piece begins.
 * @returns new turns with the original speaker, each spanning from the start
 *          of its first unit to the end of its last unit.
 */
function splitTurn(turn: SpeakerTurn, units: TranscriptUnit[]): SpeakerTurn[] {
  const splitIDs = new Set(units.map(u => u.id));
  const splits: SpeakerTurn[] = [];
  let currentUnits: TranscriptUnit[] = [];

  const flush = (): void => {
    if (currentUnits.length === 0) {
      return;
    }
    splits.push({
      speaker: turn.speaker,
      start: currentUnits[0].start,
      end: currentUnits[currentUnits.length - 1].end,
      contents: currentUnits
    });
    currentUnits = [];
  };

  for (const unit of turn.contents) {
    // The split point itself belongs to the piece it terminates; leaving it
    // out would silently drop words from the transcript.
    currentUnits.push(unit);
    if (splitIDs.has(unit.id)) {
      flush();
    }
  }

  // Add remaining units if any
  flush();

  return splits;
}

/**
 * Groups the turns of a transcript into chunks sized for one model call.
 *
 * The budget per chunk is the model's `maxOutputTokens` divided by 1.3, since
 * the model's output is the input text plus annotation tags. Turns are packed
 * greedily in order; a turn that exceeds the budget on its own is cut into
 * pieces with `splitTurn` and the pieces are packed the same way.
 *
 * @returns the chunks in transcript order; empty chunks are never emitted.
 */
export function chunkTranscriptForModel(transcript: LinearTurnsTranscript, model: ModelType): SpeakerTurn[][] {
  const outputBudget = MODEL_INFORMATION[model].maxOutputTokens;
  // the output is just the annotated text, which, if there's ALOT of annotations,
  // still shouldn't be more than 1.3x the input text.
  const maxChunkTokens = Math.floor(outputBudget / 1.3);

  const chunks: SpeakerTurn[][] = [];
  let pending: SpeakerTurn[] = [];
  let pendingTokens = 0;

  const tokensOf = (t: SpeakerTurn): number =>
    estimateTokens(t.contents.map(unit => unit.word).join(' '));

  // Appends to the pending chunk when the budget allows; otherwise closes the
  // pending chunk (if it holds anything) and starts a new one with `candidate`.
  const place = (candidate: SpeakerTurn, cost: number): void => {
    if (pendingTokens + cost <= maxChunkTokens) {
      pending.push(candidate);
      pendingTokens += cost;
      return;
    }
    if (pending.length > 0) {
      chunks.push(pending);
    }
    pending = [candidate];
    pendingTokens = cost;
  };

  for (const turn of transcript.turns) {
    const turnTokens = tokensOf(turn);

    if (turnTokens <= maxChunkTokens) {
      // The turn fits into a chunk as a whole.
      place(turn, turnTokens);
      continue;
    }

    // The turn is too large for any chunk: close the pending chunk and split.
    if (pending.length > 0) {
      chunks.push(pending);
      pending = [];
      pendingTokens = 0;
    }

    // Walk the units, remembering the unit before each budget overflow.
    const splitPoints: TranscriptUnit[] = [];
    let runningTokens = 0;
    let previousUnit: TranscriptUnit | null = null;
    for (const unit of turn.contents) {
      const unitTokens = estimateTokens(unit.word);
      runningTokens += unitTokens;
      if (runningTokens > maxChunkTokens && previousUnit) {
        splitPoints.push(previousUnit);
        runningTokens = unitTokens;
      }
      previousUnit = unit;
    }

    for (const piece of splitTurn(turn, splitPoints)) {
      place(piece, tokensOf(piece));
    }
  }

  // Add final chunk if not empty
  if (pending.length > 0) {
    chunks.push(pending);
  }

  return chunks;
}

/**
 * Builds the full prompt text for annotating one transcript chunk: system
 * prompt, user prompt, the rendered chunk between the input/output markers.
 *
 * @param chunk the turns to render into the prompt.
 * @param userPrompt task-specific instructions.
 * @param systemPrompt general instructions placed first.
 * @param formatters formatters used to render the chunk; defaults to
 *        `basicTurnFormatters`. The same formatters have to be used when the
 *        model output is parsed back against the transcript.
 */
export function formatAnnotationPrompt(
  chunk: SpeakerTurn[],
  userPrompt: string,
  systemPrompt: string,
  formatters: RF[] = basicTurnFormatters
): string {
  // Render with the caller's formatters; the parameter used to be ignored.
  return `${systemPrompt}\n${userPrompt}\n====input transcript====\n${stringifyRangable(chunk, formatters)}\n====output annotations====\n`
}