#!/usr/bin/env python3
"""
# Generated by Claude AI

Script to completely regenerate the GGML remoting codebase from YAML configuration.

This script reads ggmlremoting_functions.yaml and regenerates the generated
header files for the GGML remoting layer.

Usage:
  python regenerate_remoting.py

The script will:
1. Read the ggmlremoting_functions.yaml configuration
2. Generate the header files (apir_backend.gen.h, backend-dispatched.gen.h,
   virtgpu-forward.gen.h)
3. Format them with clang-format, when it is available in PATH
4. Log a summary of what was generated
"""

import yaml
from typing import Dict, List, Any
from pathlib import Path
import os
import subprocess
import shutil
import logging

# Newline constant for use inside f-string replacement fields, which cannot
# contain a backslash before Python 3.12.
NL = chr(10)


class RemotingCodebaseGenerator:
    """Generate the GGML remoting C headers from a YAML function table.

    The YAML file must provide the top-level sections ``functions`` (groups
    of remoted functions) and ``naming_patterns`` (prefixes/overrides used to
    derive C identifiers). An optional ``config`` section may supply
    ``base_path``.
    """

    def __init__(self, yaml_path: str = "ggmlremoting_functions.yaml"):
        """Load the YAML configuration.

        Args:
            yaml_path: Path of the YAML configuration file.

        Raises:
            FileNotFoundError: if ``yaml_path`` does not exist.
            ValueError: if the file is empty or its top level is not a mapping.
            KeyError: if ``functions`` or ``naming_patterns`` is missing.
        """
        self.yaml_path = yaml_path

        if not Path(yaml_path).exists():
            raise FileNotFoundError(f"Configuration file {yaml_path} not found")

        # Explicit UTF-8: the platform default encoding is not UTF-8 everywhere.
        with open(yaml_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        # safe_load returns None for an empty document; report that clearly
        # instead of failing with a TypeError on the subscripts below.
        if not isinstance(self.config, dict):
            raise ValueError(f"Configuration file {yaml_path} does not contain a YAML mapping")

        self.functions = self.config['functions']
        self.naming_patterns = self.config['naming_patterns']
        # Only read for optional settings (base_path), so tolerate an empty section.
        self.config_data = self.config.get('config') or {}

        # Check if clang-format is available
        self.clang_format_available = self._check_clang_format_available()

    def _check_clang_format_available(self) -> bool:
        """Return True if a ``clang-format`` executable is found in PATH."""
        return shutil.which("clang-format") is not None

    def _format_file_with_clang_format(self, file_path: Path) -> bool:
        """Format ``file_path`` in place with ``clang-format -i``.

        Best-effort: failures are logged, never raised.

        Returns:
            True if formatting succeeded; False if clang-format is unavailable
            or failed.
        """
        if not self.clang_format_available:
            return False

        try:
            subprocess.run(
                ["clang-format", "-i", str(file_path)],
                check=True,
                capture_output=True,
                text=True,
            )
        except subprocess.CalledProcessError as e:
            # stderr was captured above, so surface it; otherwise the reason is lost.
            details = (e.stderr or "").strip()
            logging.error(f"   ⚠️  clang-format failed for {file_path}: {details}")
            return False
        except Exception as e:
            logging.exception(f"   ⚠️  Unexpected error formatting {file_path}: {e}")
            return False
        return True

    def generate_enum_name(self, group_name: str, function_name: str) -> str:
        """Return the ApirBackendCommandType enumerator name for a function."""
        prefix = self.naming_patterns['enum_prefix']
        return f"{prefix}{group_name.upper()}_{function_name.upper()}"

    def generate_backend_function_name(self, group_name: str, function_name: str) -> str:
        """Return the backend handler name, honoring ``backend_function_overrides``."""
        function_key = f"{group_name}_{function_name}"
        overrides = self.naming_patterns.get('backend_function_overrides') or {}

        if function_key in overrides:
            return overrides[function_key]

        prefix = self.naming_patterns['backend_function_prefix']
        return f"{prefix}{group_name}_{function_name}"

    def generate_frontend_function_name(self, group_name: str, function_name: str) -> str:
        """Return the frontend (guest-side) function name."""
        prefix = self.naming_patterns['frontend_function_prefix']
        return f"{prefix}{group_name}_{function_name}"

    def get_enabled_functions(self) -> List[Dict[str, Any]]:
        """Return metadata dicts for every enabled function, in YAML order.

        Functions are enabled unless their metadata sets ``enabled: false``.
        ``enum_value`` is assigned sequentially over enabled functions only,
        so disabling a function renumbers every later command.
        """
        functions = []
        enum_value = 0

        for group_name, group_data in self.functions.items():
            group_description = group_data['group_description']
            # A group declared with an empty ``functions:`` key parses as None.
            group_functions = group_data.get('functions') or {}

            for function_name, func_metadata in group_functions.items():
                # Entries with no value (only comments in YAML) parse as None.
                func_metadata = func_metadata or {}

                if not func_metadata.get('enabled', True):
                    continue

                functions.append({
                    'group_name': group_name,
                    'function_name': function_name,
                    'enum_name': self.generate_enum_name(group_name, function_name),
                    'enum_value': enum_value,
                    'backend_function': self.generate_backend_function_name(group_name, function_name),
                    'frontend_function': self.generate_frontend_function_name(group_name, function_name),
                    'frontend_return': func_metadata.get('frontend_return', 'void'),
                    # An empty ``frontend_extra_params:`` key parses as None.
                    'frontend_extra_params': func_metadata.get('frontend_extra_params') or [],
                    'group_description': group_description,
                    'newly_added': func_metadata.get('newly_added', False),
                })
                enum_value += 1

        return functions

    def generate_apir_backend_header(self) -> str:
        """Return the text of apir_backend.gen.h (the command-type enum)."""
        functions = self.get_enabled_functions()

        enum_lines = ["typedef enum ApirBackendCommandType {"]
        current_group = None

        for func in functions:
            # Emit a comment header each time a new group starts.
            if func['group_name'] != current_group:
                enum_lines.append("")
                enum_lines.append(f"  /* {func['group_description']} */")
                current_group = func['group_name']

            enum_lines.append(f"  {func['enum_name']} = {func['enum_value']},")

        # The count doubles as the dispatch-table size.
        enum_lines.append("\n  // last command_type index + 1")
        enum_lines.append(f"  APIR_BACKEND_DISPATCH_TABLE_COUNT = {len(functions)},")
        enum_lines.append("} ApirBackendCommandType;")

        return "\n".join(enum_lines) + "\n"

    def generate_backend_dispatched_header(self) -> str:
        """Return the text of backend-dispatched.gen.h.

        It contains the handler declarations, a command-name lookup switch,
        and the positional dispatch table indexed by ApirBackendCommandType.
        """
        functions = self.get_enabled_functions()

        # Every backend handler shares the same signature.
        signature = "uint32_t"
        params = "apir_encoder *enc, apir_decoder *dec, virgl_apir_context *ctx"

        decl_lines = []
        switch_lines = []
        table_lines = []
        current_group = None

        for func in functions:
            if func['group_name'] != current_group:
                description = func['group_description']
                decl_lines.append(f"\n/* {description} */")
                switch_lines.append(f"  /* {description} */")
                table_lines.append(f"\n  /* {description} */")
                table_lines.append("")
                current_group = func['group_name']

            decl_lines.append(f"{signature} {func['backend_function']}({params});")
            switch_lines.append(f"  case {func['enum_name']}: return \"{func['backend_function']}\";")
            # The table is positional, so its entry order must match enum_value order.
            table_lines.append(f"  /* {func['enum_name']}  = */ {func['backend_function']},")

        decls = "\n".join(decl_lines)
        cases = "\n".join(switch_lines)
        table = "\n".join(table_lines)

        return f'''\
#pragma once

{decls}

static inline const char *backend_dispatch_command_name(ApirBackendCommandType type)
{{
  switch (type) {{
{cases}

  default: return "unknown";
  }}
}}

extern "C" {{
static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{
  {table}
}};
}}
'''

    def generate_virtgpu_forward_header(self) -> str:
        """Return the text of virtgpu-forward.gen.h (frontend prototypes)."""
        functions = self.get_enabled_functions()
        base_param = self.naming_patterns['frontend_base_param']

        decl_lines = []
        current_group = None

        for func in functions:
            if func['group_name'] != current_group:
                decl_lines.append("")
                decl_lines.append(f"/* {func['group_description']} */")
                current_group = func['group_name']

            param_str = ', '.join([base_param, *func['frontend_extra_params']])
            decl_lines.append(f"{func['frontend_return']} {func['frontend_function']}({param_str});")

        decls = "\n".join(decl_lines)
        return f'''\
#pragma once
{decls}
'''

    def regenerate_codebase(self) -> None:
        """Write all generated headers, optionally clang-format them, and log a summary.

        When the current directory is named ``ggml-virtgpu``, output paths are
        relative to it. Otherwise they are placed under
        ``<config.base_path, default 'ggml/src'>/ggml-virtgpu``.
        """
        logging.info("🔄 Regenerating GGML Remoting Codebase...")
        logging.info("=" * 50)

        current_dir = Path.cwd()
        # Compare the last path component exactly; a string suffix match would
        # also accept directories such as "my-ggml-virtgpu".
        is_frontend_dir = current_dir.name == 'ggml-virtgpu'

        if is_frontend_dir:
            # Running from ggml/src/ggml-virtgpu
            logging.info("📍 Detected frontend directory execution")
            frontend_base = Path(".")
        else:
            # Running from project root (fallback to original behavior)
            logging.info("📍 Detected project root execution")
            base_path = self.config_data.get('base_path', 'ggml/src')
            frontend_base = Path(base_path) / "ggml-virtgpu"

        backend_base = frontend_base / "backend"
        outputs = [
            (backend_base / "shared" / "apir_backend.gen.h", self.generate_apir_backend_header),
            (backend_base / "backend-dispatched.gen.h", self.generate_backend_dispatched_header),
            (frontend_base / "virtgpu-forward.gen.h", self.generate_virtgpu_forward_header),
        ]
        generated_files = [path for path, _ in outputs]

        # Create output directories for each file
        for path in generated_files:
            path.parent.mkdir(parents=True, exist_ok=True)

        logging.info("📁 Generating header files...")
        for path, generate in outputs:
            # Render before writing so a generation error leaves no partial file.
            content = generate()
            path.write_text(content, encoding='utf-8')
            logging.info(f"   ✅ {path.resolve()}")

        if not self.clang_format_available:
            logging.warning("\n⚠️clang-format not found in PATH. Generated files will not be formatted."
                            "   Install clang-format to enable automatic code formatting.")
        else:
            logging.info("\n🎨 Formatting files with clang-format...")
            for file_path in generated_files:
                if self._format_file_with_clang_format(file_path):
                    logging.info(f"   ✅ Formatted {file_path.name}")
                else:
                    logging.warning(f"   ❌ Failed to format {file_path.name}")

        total_functions = len(self.get_enabled_functions())

        logging.info("\n📊 Generation Summary:")
        logging.info("=" * 50)
        logging.info(f"   Total functions: {total_functions}")
        logging.info(f"   Function groups: {len(self.functions)}")
        logging.info(f"   Header files: {len(generated_files)}")
        logging.info(f"   Working directory: {current_dir}")


def main():
    """Regenerate the remoting headers; exit with status 1 on any failure."""
    # The root logger defaults to WARNING, which would silently drop every
    # progress and summary message emitted via logging.info(). basicConfig is
    # a no-op if logging has already been configured by the caller.
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    try:
        generator = RemotingCodebaseGenerator()
        generator.regenerate_codebase()
    except Exception as e:
        logging.exception(f"❌ Error: {e}")
        # The `exit` builtin is injected by the `site` module and is not
        # guaranteed to exist (e.g. `python -S`); SystemExit always is.
        raise SystemExit(1)


if __name__ == "__main__":
    main()
