#! /usr/bin/env python
import argparse
import sys
import logging
import json
import subprocess
from PIL import Image, ImageDraw, ImageFont
from typing import List, Dict, Any, Optional
from pathlib import Path
from dataclasses import dataclass
from enum import Enum
from datetime import datetime

# Module-level logger; its level and format are configured by logging.basicConfig() in main().
logger = logging.getLogger(__name__)
# Per-user configuration root. Templates are looked up under "templates/" and
# fonts under "fonts/" inside this directory.
CONFIG_DIR = Path("~/.config/memegen").expanduser()
# Directory generated memes are written to when no --output path is given.
DEFAULT_OUTPUT_DIR = Path("~/Pictures/memes").expanduser()


@dataclass
class MemeTemplate:
    """A loaded meme template: the base image location plus its parsed settings."""

    # Path to the template's base image ("template.png" inside the template directory).
    image_path: Path
    # Parsed contents of the template's "config.json" (zones, font, colours, ...).
    config: Dict[str, Any]


class HorizontalAlignment(Enum):
    """Horizontal text placement; values match a zone's "horizontal_align" setting."""

    LEFT = "left"
    CENTER = "center"
    RIGHT = "right"


class VerticalAlignment(Enum):
    """Vertical text placement; values match a zone's "vertical_align" setting."""

    TOP = "top"
    CENTER = "center"
    BOTTOM = "bottom"


def discover_templates() -> List[str]:
    """Return the names of all available templates, sorted alphabetically.

    A template is any directory found directly inside
    ``CONFIG_DIR / "templates"``; the directory name is the template name.

    Returns:
        The sorted template names, or an empty list if no template
        directories are found.
    """
    templates_dir = CONFIG_DIR / "templates"
    # glob() yields entries in an unspecified, filesystem-dependent order.
    # Sorting keeps the argparse choices (and therefore --help and error
    # messages) stable between runs and machines.
    return sorted(entry.name for entry in templates_dir.glob("*") if entry.is_dir())


def load_template(name: str) -> MemeTemplate:
    """Load the template called *name* from the templates directory.

    Args:
        name: Name of the template directory under ``CONFIG_DIR / "templates"``.

    Returns:
        A ``MemeTemplate`` whose ``image_path`` points at the template's
        ``template.png`` and whose ``config`` is the parsed ``config.json``.

    Raises:
        SystemExit: With status 1 (after logging an error) if the template
            directory does not exist.
        OSError: If ``config.json`` cannot be read.
        json.JSONDecodeError: If ``config.json`` is not valid JSON.
    """
    logger.info(f"Loading template: {name}")

    # Find the template directory
    template_dir = CONFIG_DIR / "templates" / name
    if not template_dir.exists():
        logger.error(f"Template {name} does not exist")
        sys.exit(1)

    # JSON files are UTF-8; decode explicitly instead of relying on the
    # platform's locale encoding, which would garble non-ASCII settings.
    config_text = (template_dir / "config.json").read_text(encoding="utf-8")

    return MemeTemplate(
        image_path=template_dir / "template.png",
        config=json.loads(config_text),
    )


def calc_width_from_image(width_str: str, image_width: int) -> int:
    """Convert a configured width into a pixel count.

    Args:
        width_str: Either an absolute pixel count (e.g. ``"300"``) or a whole
            percentage of the image width (e.g. ``"80%"``).
        image_width: Width of the image in pixels, used for percentages.

    Returns:
        The width in pixels; percentage results are truncated to an integer.
    """
    # Anything without a trailing "%" is already an absolute pixel value.
    if not width_str.endswith("%"):
        return int(width_str)

    percent = int(width_str[:-1])
    return int(image_width * percent / 100)


def render_text_on_image(
    image: Image.Image, text: str, zone: str, config: Dict[str, Any]
):
    """Draw *text* onto *image* in place, sized as large as the zone allows.

    The text is split on newlines. The font size is grown one point at a time
    and the largest size is kept at which the widest line still fits the
    zone's width and, when ``max_line_height`` is set, the tallest line does
    not exceed it. If even size 1 does not fit, size 1 is used.

    Args:
        image: Pillow image to draw on; it is modified in place.
        text: Caption text; may contain newlines. Text with no visible
            characters is ignored.
        zone: Key into ``config["zones"]`` selecting the layout settings.
        config: Template configuration. Reads ``font``, ``fill_color``,
            ``stroke_color`` and ``stroke_width``, plus the zone's
            ``horizontal_align``, ``vertical_align``, ``width``,
            ``max_line_height``, ``line_spacing``, ``horizontal_offset`` and
            ``vertical_offset``.

    Raises:
        KeyError: If a required configuration key is missing.
        ValueError: If an alignment value is not a member of
            ``HorizontalAlignment`` / ``VerticalAlignment``.
    """
    # Get the zone config
    zone_config = config["zones"][zone]
    horizontal_alignment = HorizontalAlignment(zone_config["horizontal_align"])
    vertical_alignment = VerticalAlignment(zone_config["vertical_align"])
    text_width = calc_width_from_image(zone_config["width"], image.width)
    max_line_height = zone_config["max_line_height"]
    line_spacing = zone_config["line_spacing"]
    font_path = str(CONFIG_DIR / "fonts" / config["font"])

    # The text does not change while fitting, so split it once.
    lines = text.splitlines()

    # Nothing visible to draw. This also protects the fitting loop below:
    # with no lines max() would raise, and with only blank lines the measured
    # size could stay zero so the loop would never terminate.
    if not any(line.strip() for line in lines):
        return

    # Find the largest font size that fits. Only a size that actually fits is
    # kept (together with its metrics), so the text is drawn with the last
    # fitting size rather than the first size that overflowed.
    font = None
    line_height = 0
    total_height = 0
    font_size = 1
    while True:
        candidate = ImageFont.truetype(font_path, font_size)
        bounding_boxes = [candidate.getbbox(line) for line in lines]
        candidate_line_height = max(bbox[3] for bbox in bounding_boxes)
        widest_line = max(bbox[2] for bbox in bounding_boxes)

        # A falsy max_line_height (None / 0) means "no height limit".
        too_tall = bool(max_line_height) and candidate_line_height > max_line_height
        too_wide = widest_line > text_width
        fits = not (too_tall or too_wide)

        if fits or font is None:
            # The very first size is accepted even when it does not fit, so
            # there is always a font to draw with.
            font = candidate
            line_height = candidate_line_height
            # NOTE: spacing is counted after every line, including the last.
            total_height = sum(bbox[3] + line_spacing for bbox in bounding_boxes)

        if not fits:
            break

        font_size += 1

    # Determine the starting Y position (TOP uses the offset as-is)
    y = zone_config["vertical_offset"]
    if vertical_alignment == VerticalAlignment.CENTER:
        y += (image.height - total_height) / 2
    elif vertical_alignment == VerticalAlignment.BOTTOM:
        y += image.height - total_height

    # Render each line onto the image
    draw = ImageDraw.Draw(image)
    horizontal_offset = zone_config["horizontal_offset"]
    fill_color = tuple(config["fill_color"])
    stroke_color = tuple(config["stroke_color"])
    stroke_width = config["stroke_width"]
    for line in lines:
        # Calculate the x position; alignment is relative to the whole image.
        line_width = font.getbbox(line)[2]
        if horizontal_alignment == HorizontalAlignment.LEFT:
            x = horizontal_offset
        elif horizontal_alignment == HorizontalAlignment.CENTER:
            x = ((image.width - line_width) / 2) + horizontal_offset
        else:
            # Only RIGHT remains: the enum constructor above already rejected
            # any value that is not a valid alignment.
            x = (image.width - line_width) + horizontal_offset

        # Render the text
        draw.text(
            (x, y),
            line,
            fill=fill_color,
            stroke_fill=stroke_color,
            stroke_width=stroke_width,
            font=font,
        )

        # Every line advances by the tallest line's height plus the spacing
        y += line_height + line_spacing


def _normalize_text(text: Optional[str], keep_case: bool) -> Optional[str]:
    """Prepare a caption from the command line for rendering.

    Upper-cases the text unless *keep_case* is set, then turns the literal
    two-character sequences ``\\n`` and ``\\N`` into real newlines. Returns
    ``None`` when no text (or an empty string) was given.
    """
    if not text:
        return None
    if not keep_case:
        text = text.upper()
    # "\N" is handled as well because upper-casing turns a typed "\n" into it.
    return text.replace("\\n", "\n").replace("\\N", "\n")


def main() -> int:
    """Command-line entry point: render a meme from a template and save it.

    Parses the arguments, loads the chosen template, draws the top and/or
    bottom text onto the template image, saves the result (to ``--output`` or
    a timestamped file in ``DEFAULT_OUTPUT_DIR``) and, unless ``--no-show`` is
    given, opens it with ``xdg-open``.

    Returns:
        0 on success.

    Raises:
        SystemExit: With status 1 if no text was provided or the template
            does not support the requested text zone.
    """
    # Handle program arguments
    ap = argparse.ArgumentParser(prog="memegen", description="Generates memes")
    ap.add_argument(
        "template", help="The template to use", choices=discover_templates()
    )
    ap.add_argument("--top-text", help="Top text (if applicable)")
    ap.add_argument("--bottom-text", help="Bottom text (if applicable)")
    ap.add_argument(
        "--keep-case", help="Keep the case of the text", action="store_true"
    )
    ap.add_argument("--output", "-o", help="Output file path")
    ap.add_argument(
        "--no-show", help="Don't show the image after creation", action="store_true"
    )
    ap.add_argument(
        "-v", "--verbose", help="Enable verbose logging", action="store_true"
    )
    args = ap.parse_args()

    # Configure logging
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(levelname)s:\t%(message)s",
    )

    # Load the template
    template = load_template(args.template)
    template_supports_top_text = "top" in template.config["zones"]
    template_supports_bottom_text = "bottom" in template.config["zones"]

    # Ensure we have text, and only for zones the template defines
    if args.top_text and not template_supports_top_text:
        logger.error(f"Template {args.template} does not support top text")
        sys.exit(1)
    if args.bottom_text and not template_supports_bottom_text:
        logger.error(f"Template {args.template} does not support bottom text")
        sys.exit(1)
    if not args.top_text and not args.bottom_text:
        logger.error("No text provided")
        # When the template has a single zone, tell the user which flag it needs
        if not all([template_supports_top_text, template_supports_bottom_text]):
            required_text = "top" if template_supports_top_text else "bottom"
            logger.error(
                f"Template {args.template} requires the --{required_text}-text argument"
            )
        sys.exit(1)

    # Transform the text
    top_text = _normalize_text(args.top_text, args.keep_case)
    bottom_text = _normalize_text(args.bottom_text, args.keep_case)

    # Load the image; the context manager closes the underlying file handle
    # once the meme has been written.
    with Image.open(template.image_path) as image:
        # Render the text
        if top_text:
            render_text_on_image(image, top_text, "top", template.config)
        if bottom_text:
            render_text_on_image(image, bottom_text, "bottom", template.config)

        # Build the output path
        if args.output:
            output_path = Path(args.output)
        else:
            timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            output_path = DEFAULT_OUTPUT_DIR / f"meme-{timestamp}.{args.template}.png"
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Save the image
        image.save(output_path)

    # Show the image
    if not args.no_show:
        try:
            subprocess.run(["xdg-open", str(output_path)])
        except FileNotFoundError:
            # The meme is already saved; a missing viewer launcher should not
            # turn a successful run into a traceback.
            logger.warning(
                "xdg-open is not available; the meme was saved to %s", output_path
            )

    return 0


# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())