import createOffscreenCanvas from "./createOffscreenCanvas";
import { ImageRegionData } from "./ImageRegionData";
import spiralIterator from "./spiralIterator";

/**
 * Splits an image into connected regions of "non-border" pixels.
 *
 * A pixel counts as a border pixel when its luma (BT.601 weights on a 0-255
 * scale) is below 100; every other pixel is region interior. Regions are
 * 4-connected (left/right/up/down neighbours) and are represented as a map from
 * x column to the set of y rows that belong to the region in that column.
 */
class ImageRegions {
  image: HTMLImageElement;
  imageWidth: number;
  imageHeight: number;
  /** RGBA pixels of `image`, captured lazily on the first pixel lookup. */
  imageData: ImageData | null = null;
  /** Regions already produced by flood fill; reused when a lookup lands inside one. */
  cachedImageRegions: ImageRegionData[] = [];

  /**
   * @param image Source image; its `width`/`height` define the sampled pixel grid.
   */
  constructor(image: HTMLImageElement) {
    this.image = image;
    // NOTE(review): `image.width`/`image.height` can be the rendered size, while
    // `drawImage(image, 0, 0)` draws at the natural size — confirm these agree
    // (or whether `naturalWidth`/`naturalHeight` were intended).
    this.imageWidth = image.width;
    this.imageHeight = image.height;
  }

  /**
   * Tests whether (x, y) is a border (dark) pixel.
   *
   * Coordinates outside the image are reported as border pixels, so neighbours
   * can be probed freely without reading past (or wrapping around) the buffer.
   *
   * @throws Error if a 2d canvas context cannot be created on first use.
   */
  isBorderPixel(x: number, y: number): boolean {
    if (!this.isInBounds(x, y)) {
      return true;
    }

    const { data } = this.ensureImageData();
    const offset = (x + y * this.imageWidth) * 4;
    const r = data[offset];
    const g = data[offset + 1];
    const b = data[offset + 2];
    // Integer-scaled BT.601 luma: 0.299 R + 0.587 G + 0.114 B.
    const luminance = (299 * r + 587 * g + 114 * b) / 1000;
    return luminance < 100;
  }

  /**
   * Collects every non-border pixel of the image into a single region,
   * regardless of connectivity.
   */
  merged(): ImageRegionData {
    const result = new Map<number, Set<number>>();

    for (let x = 0; x < this.imageWidth; x++) {
      for (let y = 0; y < this.imageHeight; y++) {
        if (!this.isBorderPixel(x, y)) {
          ImageRegions.addPixel(result, x, y);
        }
      }
    }

    return result;
  }

  /**
   * Partitions the image into all of its 4-connected non-border regions.
   * Every region found is also added to `cachedImageRegions`.
   */
  findAll(): ImageRegionData[] {
    const result: ImageRegionData[] = [];
    const processedPixels = new Set<number>();

    for (let x = 0; x < this.imageWidth; x++) {
      for (let y = 0; y < this.imageHeight; y++) {
        const pixelIndex = x + y * this.imageWidth;
        if (processedPixels.has(pixelIndex) || this.isBorderPixel(x, y)) {
          continue;
        }

        // (x, y) is already a non-border pixel, so no nearest-pixel snapping
        // is needed before flood filling from it.
        const region = this.findFrom(x, y, processedPixels);

        // A cached region is returned without touching processedPixels, so
        // mark every member explicitly to avoid emitting it twice.
        for (const regionX of region.keys()) {
          const column = region.get(regionX);
          if (!column) continue;
          for (const regionY of column.values()) {
            processedPixels.add(regionX + regionY * this.imageWidth);
          }
        }

        result.push(region);
      }
    }

    return result;
  }

  /**
   * Walks `spiralIterator(x, y, 20)` (presumably outward from (x, y) up to a
   * radius of 20 — confirm against spiralIterator) and returns the first
   * candidate that lies inside the image and is not a border pixel.
   *
   * @returns the candidate found, or null if none qualifies.
   */
  findNearestNonBorderPixel(x: number, y: number) {
    for (const candidate of spiralIterator(x, y, 20)) {
      // Check the candidate's own coordinates: out-of-image candidates must
      // never be returned.
      if (!this.isInBounds(candidate.x, candidate.y)) {
        continue;
      }
      if (!this.isBorderPixel(candidate.x, candidate.y)) {
        return candidate;
      }
    }

    return null;
  }

  /**
   * Returns the region containing (x, y), first snapping to the nearest
   * non-border pixel when one is found.
   *
   * @param processedPixels Linear pixel indices (x + y * width) to treat as
   *   already visited; updated in place with every pixel the fill examines.
   * @returns the (possibly cached) region; empty if no interior pixel is reached.
   */
  find(
    x: number,
    y: number,
    processedPixels: Set<number> = new Set()
  ): ImageRegionData {
    const nearestNonBorderPixel = this.findNearestNonBorderPixel(x, y);
    if (nearestNonBorderPixel) {
      x = nearestNonBorderPixel.x;
      y = nearestNonBorderPixel.y;
    }

    return this.findFrom(x, y, processedPixels);
  }

  /** True when (x, y) addresses a pixel inside the image. */
  private isInBounds(x: number, y: number): boolean {
    return x >= 0 && y >= 0 && x < this.imageWidth && y < this.imageHeight;
  }

  /**
   * Rasterizes `image` onto an offscreen canvas once and caches the pixels.
   *
   * @throws Error if a 2d context cannot be obtained.
   */
  private ensureImageData(): ImageData {
    if (this.imageData) {
      return this.imageData;
    }

    const canvas = createOffscreenCanvas(this.imageWidth, this.imageHeight);
    canvas.width = this.imageWidth;
    canvas.height = this.imageHeight;

    const context = canvas.getContext("2d", {
      // Pixels are read back via getImageData on every lookup.
      willReadFrequently: true,
    }) as CanvasRenderingContext2D;
    if (!context) {
      throw new Error("Could not get 2d context from canvas");
    }
    context.imageSmoothingEnabled = false;
    context.drawImage(this.image, 0, 0);

    this.imageData = context.getImageData(
      0,
      0,
      this.imageWidth,
      this.imageHeight
    );
    return this.imageData;
  }

  /**
   * Returns the cached region containing (x, y), or flood fills a new one from
   * it (4-connected, stopping at border pixels and the image edge).
   */
  private findFrom(
    x: number,
    y: number,
    processedPixels: Set<number>
  ): ImageRegionData {
    for (const imageRegion of this.cachedImageRegions) {
      if (imageRegion.get(x)?.has(y)) {
        return imageRegion;
      }
    }

    const result = new Map<number, Set<number>>();
    const stack: Array<{ x: number; y: number }> = [{ x, y }];

    let next: { x: number; y: number } | undefined;
    while ((next = stack.pop()) !== undefined) {
      // Bounds check must precede indexing: x = -1 or x = imageWidth would
      // otherwise alias a real pixel on the adjacent row and wrongly mark it
      // as processed.
      if (!this.isInBounds(next.x, next.y)) continue;

      const pixelIndex = next.x + next.y * this.imageWidth;
      if (processedPixels.has(pixelIndex)) continue;
      processedPixels.add(pixelIndex);

      if (this.isBorderPixel(next.x, next.y)) continue;

      ImageRegions.addPixel(result, next.x, next.y);

      stack.push({ x: next.x + 1, y: next.y });
      stack.push({ x: next.x - 1, y: next.y });
      stack.push({ x: next.x, y: next.y + 1 });
      stack.push({ x: next.x, y: next.y - 1 });
    }

    // Empty results (start point on a border) can never match a cache lookup,
    // so don't let them accumulate.
    if (result.size > 0) {
      this.cachedImageRegions.push(result);
    }
    return result;
  }

  /** Adds (x, y) to a column-keyed region map, creating the column on demand. */
  private static addPixel(
    region: Map<number, Set<number>>,
    x: number,
    y: number
  ): void {
    let column = region.get(x);
    if (!column) {
      column = new Set<number>();
      region.set(x, column);
    }
    column.add(y);
  }
}

export default ImageRegions;
