"""
Battle Scene Module
Handles the turn-based combat system
"""
import pygame
import random
import time
import os
from game.ui.scene import Scene
from game.core.game_state import GameState
from game.core.unit_sprite import UnitSprite
from game.core.sprite import Sprite

# Side length, in pixels, of one square cell of the battle grid; render()
# multiplies grid coordinates by this to get on-screen rectangles.
CELL_SIZE = 64

class BattleScene(Scene):
    """
    Battle scene handling turn-based combat
    """
    def __init__(self, game_manager):
        """Create the battle scene and load the sample battle into it.

        Args:
            game_manager: The owning game manager, forwarded to ``Scene``.
                The sample battle set up at the end reads
                ``game_manager.unit_loader`` from it.
        """
        super().__init__(game_manager)
        self.buttons = []
        self.hover_button = None

        # --- Board dimensions and occupancy ---
        columns, rows = 10, 6
        self.grid_size = (columns, rows)  # (width, height) in cells
        # Indexed as grid[x][y]; each entry is the unit covering that cell, or None.
        self.grid = [[None] * rows for _ in range(columns)]

        # --- Turn flow ---
        self.turn = "player"           # whose turn it is: "player" or "enemy"
        self.selected_unit = None
        self.selected_cell = None
        self.move_range = []
        self.attack_range = []
        self.current_phase = "select"  # one of "select", "move", "attack", "enemy"
        self.battle_message = ""
        self.turn_counter = 1          # number of the current turn, starting at 1

        # --- Morale: one value per side, both starting at 100 ---
        self.player_morale = 100
        self.enemy_morale = 100

        # --- Status effects ---
        # "duration" is in turns; "value" is the per-turn damage or healing for
        # effects that have one; "color" is RGB; "icon" is an emoji string.
        self.status_effects = {
            "burning": dict(duration=2, effect="damage", value=3,
                            color=(200, 50, 0), icon="🔥"),
            "stunned": dict(duration=1, effect="skip_turn",
                            color=(150, 150, 0), icon="⚡"),
            "poisoned": dict(duration=3, effect="damage", value=2,
                             color=(0, 150, 0), icon="☠️"),
            "regenerating": dict(duration=3, effect="heal", value=5,
                                 color=(0, 200, 100), icon="❤️"),
        }

        # --- Terrain ---
        # Base cell color (RGB), movement cost and defense modifier per terrain type.
        self.terrain_types = {
            "plain": dict(color=(80, 60, 60), move_cost=1, defense_bonus=0),
            "forest": dict(color=(40, 80, 40), move_cost=2, defense_bonus=2),
            "mountain": dict(color=(100, 100, 100), move_cost=3, defense_bonus=4),
            "water": dict(color=(40, 40, 100), move_cost=2, defense_bonus=-1),
        }
        # Terrain type name per cell, indexed like self.grid; all cells start as "plain".
        self.terrain_grid = [["plain"] * rows for _ in range(columns)]

        # Units that have already acted during the current turn.
        self.acted_units = set()

        # Dark red/brown backdrop for the battle theme.
        self.bg_color = (40, 20, 20)

        # Rosters and sprites; populated by the battle initializers.
        self.player_units = []
        self.enemy_units = []
        self.unit_sprites = {}

        # Debug state
        self.debug_mode = True
        self.last_action = "None"

        # pygame's default font (None) at three sizes.
        font_sizes = (("small", 24), ("medium", 32), ("large", 48))
        self.fonts = {label: pygame.font.Font(None, size) for label, size in font_sizes}

        # Placeholder battle until mission data replaces it.
        self._initialize_sample_battle()
        
    def _initialize_sample_battle(self):
        """Build the hard-coded sample battle using the game's UnitLoader.

        Resets all per-battle state, generates terrain, creates four Space
        Marine units and a mixed Ork/Tau enemy force at fixed positions, and
        writes a reference to each unit into every grid cell its ``size``
        covers, so ``self.grid[x][y]`` yields the unit for any of its cells.

        A unit whose footprint would leave the grid or overlap an already
        placed unit is reported and removed from ``self.player_units`` /
        ``self.enemy_units``. Keeping it would leave a unit that is in a
        roster but on no grid cell, which is also what the mission
        initializer avoids by only keeping placed units.

        Verbose placement and grid dumps are printed only while
        ``self.debug_mode`` is true. Creation and placement failures are
        always reported.
        """
        self.player_units = []
        self.enemy_units = []
        self.unit_sprites = {}
        self.grid = [[None for _ in range(self.grid_size[1])] for _ in range(self.grid_size[0])]
        self.acted_units = set()
        self.unit_loader = self.game_manager.unit_loader  # Get loader instance
        debug = self.debug_mode

        # Generate random terrain
        self._generate_terrain()

        # Player starting setup: unit type and top-left grid position
        player_setup = [
            {"type": "tactical", "pos": (1, 2)},
            {"type": "devastator", "pos": (2, 3)},
            {"type": "scout", "pos": (1, 4)},
            {"type": "battle_suit", "pos": (3, 1)},
        ]

        for i, setup in enumerate(player_setup):
            instance_id = f"player_{i+1}"
            unit_instance = self.unit_loader.create_unit_instance(
                unit_type=setup["type"],
                faction="space_marines",
                instance_id=instance_id,
                position=setup["pos"]
            )
            if not unit_instance:
                print(f"Failed to create player unit instance: space_marines - {setup['type']}")
                continue
            if debug and setup["type"] == "battle_suit":
                # Confirm the loader gave the battle suit the expected size.
                print("\n==== BATTLE SUIT CREATION DEBUG ====")
                print(f"Unit ID: {instance_id}")
                print(f"Unit Type: {unit_instance.get('unit_type', 'unknown')}")
                print(f"Unit Size: {unit_instance.get('size', 'unknown')}")
                print(f"Unit Position: {unit_instance.get('position', 'unknown')}")
                print(f"Unit Name: {unit_instance.get('name', 'unknown')}")
                print("=======================================\n")
            self.player_units.append(unit_instance)

        # Enemy starting setup (mixed factions)
        enemy_setup = [
            {"type": "boy", "faction": "orks", "pos": (8, 1)},
            {"type": "nob", "faction": "orks", "pos": (7, 3)},
            {"type": "devilfish", "faction": "tau", "pos": (5, 1)},
            {"type": "devilfish", "faction": "tau", "pos": (6, 4)},
            {"type": "devilfish", "faction": "tau", "pos": (8, 5)},
        ]

        for i, setup in enumerate(enemy_setup):
            instance_id = f"enemy_{setup['faction']}_{i+1}"
            unit_instance = self.unit_loader.create_unit_instance(
                unit_type=setup["type"],
                faction=setup["faction"],
                instance_id=instance_id,
                position=setup["pos"]
            )
            if unit_instance:
                self.enemy_units.append(unit_instance)
            else:
                print(f"Failed to create enemy unit instance: {setup['faction']} - {setup['type']}")

        # Place units on the grid, handling multi-cell sizes
        grid_w, grid_h = self.grid_size
        unplaced = set()  # id()s of units that could not be placed
        for unit in self.player_units + self.enemy_units:
            start_x, start_y = unit["position"]
            unit_size = unit.get("size", (1, 1))  # default to 1x1
            name = unit.get("name", "Unknown")

            is_battle_suit = debug and unit.get("unit_type") == "battle_suit"
            if is_battle_suit:
                print("\n==== BATTLE SUIT PLACEMENT DEBUG ====")
                print(f"Attempting to place Battle Suit at ({start_x}, {start_y}) with size {unit_size}")

            # Every cell the unit covers, column by column from its top-left cell.
            footprint = [
                (start_x + dx, start_y + dy)
                for dx in range(unit_size[0])
                for dy in range(unit_size[1])
            ]

            # Report the first blocking cell, if any.
            can_place = True
            for cx, cy in footprint:
                if not (0 <= cx < grid_w and 0 <= cy < grid_h):
                    print(f"Warning: Unit '{name}' placement out of bounds at ({cx}, {cy}).")
                    can_place = False
                    break
                if self.grid[cx][cy] is not None:
                    print(f"Warning: Unit '{name}' overlaps with existing unit at ({cx}, {cy}).")
                    can_place = False
                    break

            if not can_place:
                print(f"Error: Could not place unit '{name}' at {unit['position']} size {unit_size}.")
                unplaced.add(id(unit))
                if is_battle_suit:
                    print("Failed to place Battle Suit!")
                    print("=======================================\n")
                continue

            for cx, cy in footprint:
                self.grid[cx][cy] = unit  # every covered cell references the same unit
                if is_battle_suit:
                    print(f"  Placed Battle Suit reference in cell ({cx}, {cy})")

            if debug and (unit_size[0] > 1 or unit_size[1] > 1):
                print(f"Placed multi-tile unit '{name}' size {unit_size} at position {(start_x, start_y)}")
                for cx, cy in footprint:
                    print(f"  Cell ({cx}, {cy}) references this unit")
                if is_battle_suit:
                    print("Battle Suit placed successfully!")
                    print("=======================================\n")

        # Drop units that never reached the grid so the rosters match the board.
        if unplaced:
            self.player_units = [u for u in self.player_units if id(u) not in unplaced]
            self.enemy_units = [u for u in self.enemy_units if id(u) not in unplaced]

        # Start with player turn
        self.turn = "player"
        self.current_phase = "select"
        self.battle_message = "Select a unit to begin your turn."

        if not debug:
            return

        # --- Debug dumps of the resulting rosters and grid ---
        print("\n----- GRID STATE AFTER INITIALIZATION -----")
        for label, units in (("Player", self.player_units), ("Enemy", self.enemy_units)):
            for i, unit in enumerate(units):
                print(f"{label} unit {i+1}: {unit.get('name', 'Unknown')} ({unit.get('unit_type', 'Unknown')}) at position {unit.get('position', 'Unknown')}, size {unit.get('size', (1,1))}")

        print("\n----- CHECKING GRID CELLS -----")
        occupied = [
            (x, y, self.grid[x][y])
            for x in range(grid_w)
            for y in range(grid_h)
            if self.grid[x][y]
        ]
        for x, y, unit in occupied:
            print(f"Grid cell ({x}, {y}) contains: {unit.get('name', 'Unknown')} ({unit.get('unit_type', 'Unknown')})")

        print("\n==== GRID VERIFICATION FOR BATTLE SUIT ====")
        suit_cells = [(x, y, unit) for x, y, unit in occupied if unit.get("unit_type") == "battle_suit"]
        for x, y, unit in suit_cells:
            print(f"Found Battle Suit at grid cell ({x}, {y})")
            print(f"  Unit ID: {unit.get('id', 'unknown')}")
            print(f"  Unit Size: {unit.get('size', 'unknown')}")
            print(f"  Unit Position: {unit.get('position', 'unknown')}")
        if not suit_cells:
            print("WARNING: Battle Suit not found in any grid cell!")
        print("==========================================\n")
        
    def enter(self):
        """Reset per-battle state and set up a new battle when the scene is entered.

        The battle is built from the mission currently selected on the
        mission-select scene. The sample battle is used instead when that
        scene has no ``selected_mission`` attribute, or when no entry in its
        ``mission_list`` has a matching ``"id"``.
        """
        # Reset selection / UI state
        self.selected_unit = None
        self.selected_cell = None
        self.move_range = []
        self.attack_range = []
        self.current_phase = "select"
        self.battle_message = "Select a unit to begin your turn."
        self.last_action = "Entered battle scene"
        self.acted_units = set()  # Reset acted units
        # Turn count and morale are per-battle too. Without this they carried
        # over from the previous battle fought in this scene. Values match __init__.
        self.turn_counter = 1
        self.player_morale = 100
        self.enemy_morale = 100

        # Look up the mission chosen on the mission-select scene.
        mission_scene = self.game_manager.scenes[GameState.MISSION_SELECT]
        selected_mission_data = None
        if hasattr(mission_scene, 'selected_mission'):
            selected_id = mission_scene.selected_mission
            # Tolerate a missing mission list and missions without an "id".
            missions = getattr(mission_scene, 'mission_list', None) or []
            selected_mission_data = next(
                (m for m in missions if m.get("id") == selected_id), None
            )

        if selected_mission_data:
            self._initialize_battle_from_mission(selected_mission_data)
        else:
            # No mission selected, or its data was not found: use the sample battle.
            self._initialize_sample_battle()

    def _initialize_battle_from_mission(self, mission_data):
        """Set up a battle from mission data with random unit placement.

        Steps:
          1. Reset per-battle state and regenerate terrain.
          2. Create the Space Marine roster for the mission difficulty.
          3. Create the mission's enemy roster, enlarged according to difficulty.
          4. Randomly place player units with their top-left cell in the three
             leftmost columns, then enemy units (largest first) in the three
             rightmost columns.

        Only units that were actually placed end up in ``self.player_units``
        and ``self.enemy_units``.

        Args:
            mission_data: Mission dict. Keys read are ``"name"``,
                ``"difficulty"``, ``"enemy_faction"`` and ``"enemy_unit_types"``.
                ``"difficulty"`` is one of ``"Easy"``, ``"Medium"``, ``"Hard"``
                or ``"Super Hard"`` and defaults to ``"Medium"`` when missing.
                An unrecognised value gets the Medium player roster and no
                extra enemies.
        """
        self.player_units = []
        self.enemy_units = []
        self.unit_sprites = {}
        self.grid = [[None for _ in range(self.grid_size[1])] for _ in range(self.grid_size[0])]
        self.acted_units = set()
        self.unit_loader = self.game_manager.unit_loader

        # Fresh terrain for every battle, as _initialize_sample_battle does.
        # Previously every mission reused the terrain generated when the scene
        # was constructed.
        self._generate_terrain()

        # Log the difficulty actually used (the log default used to say "Normal").
        difficulty = mission_data.get("difficulty", "Medium")
        print(f"--- Initializing Battle: {mission_data.get('name', 'Unknown')} ({difficulty}) ---")

        # --- Player roster by difficulty ---
        rosters = {
            "Easy": ["tactical", "scout"],
            "Medium": ["tactical", "devastator", "scout"],
            "Hard": ["tactical", "devastator", "scout", "terminator"],
            "Super Hard": ["tactical", "devastator", "scout", "terminator", "captain"],
        }
        player_unit_types = rosters.get(difficulty, rosters["Medium"])

        temp_player_units = []
        for i, unit_type in enumerate(player_unit_types):
            # Placeholder position; the real one is assigned during placement.
            unit_instance = self.unit_loader.create_unit_instance(
                unit_type=unit_type, faction="space_marines",
                instance_id=f"player_{i+1}", position=(0, 0)
            )
            if unit_instance:
                temp_player_units.append(unit_instance)
            else:
                print(f"Failed to create player unit instance: space_marines - {unit_type}")

        # --- Enemy roster from mission data, scaled by difficulty ---
        enemy_faction = mission_data.get("enemy_faction", None)
        base_enemy_unit_types = mission_data.get("enemy_unit_types", [])
        print(f"  Base enemy types: {base_enemy_unit_types}, Faction: {enemy_faction}")

        # Entries are (unit_type, fallback_faction). The fallback faction is
        # tried only if the mission faction cannot create that unit type.
        scaled_enemies = [(unit_type, None) for unit_type in base_enemy_unit_types]
        if base_enemy_unit_types:
            first_type = base_enemy_unit_types[0]
            if difficulty == "Medium":
                scaled_enemies.append((first_type, None))
                print(f"  Difficulty Medium: Added 1 extra enemy ({first_type})")
            elif difficulty == "Hard":
                # Second base type if there is one, otherwise the first again.
                second_type = base_enemy_unit_types[1] if len(base_enemy_unit_types) > 1 else first_type
                scaled_enemies += [(first_type, None), (second_type, None)]
                print("  Difficulty Hard: Added 2 extra enemies")
            elif difficulty == "Super Hard":
                # The sample battle creates the Devilfish under the "tau"
                # faction. Retry with "tau" if the mission faction cannot
                # create one, so the advertised extra enemy is not lost.
                scaled_enemies += [(first_type, None), (first_type, None), ("devilfish", "tau")]
                print("  Difficulty Super Hard: Added 3 extra enemies (including a Devil Fish)")

        temp_enemy_units = []
        print(f"  Attempting to load scaled enemy types: {[unit_type for unit_type, _ in scaled_enemies]}")
        if enemy_faction and scaled_enemies:
            for i, (unit_type, fallback_faction) in enumerate(scaled_enemies):
                instance_id = f"enemy_{enemy_faction}_{i+1}"
                faction = enemy_faction
                unit_instance = self.unit_loader.create_unit_instance(
                    unit_type=unit_type, faction=faction, instance_id=instance_id, position=(0, 0)
                )
                if not unit_instance and fallback_faction and fallback_faction != enemy_faction:
                    faction = fallback_faction
                    unit_instance = self.unit_loader.create_unit_instance(
                        unit_type=unit_type, faction=faction, instance_id=instance_id, position=(0, 0)
                    )
                if unit_instance:
                    print(f"    Successfully created instance for: {faction} - {unit_type} (ID: {instance_id})")
                    temp_enemy_units.append(unit_instance)
                else:
                    print(f"    *** FAILED to create instance for: {faction} - {unit_type} ***")
        else:
            print("Warning: Mission data missing enemy faction or base unit types!")

        print(f"  Final enemy instances created (before placement): {[u.get('name', 'Unknown') for u in temp_enemy_units]}")

        # --- Random placement ---
        # Shared across both sides so enemies can never land on player cells.
        occupied_cells = set()

        print("Placing player units...")
        self.player_units = self._scatter_units_in_zone(
            temp_player_units, range(0, 3), occupied_cells, "player"
        )

        print("Placing enemy units...")
        # Larger units have fewer legal spots, so place them first.
        temp_enemy_units.sort(
            key=lambda unit: unit.get("size", (1, 1))[0] * unit.get("size", (1, 1))[1],
            reverse=True,
        )
        grid_width = self.grid_size[0]
        self.enemy_units = self._scatter_units_in_zone(
            temp_enemy_units, range(grid_width - 3, grid_width), occupied_cells, "enemy"
        )

        self.turn = "player"
        self.current_phase = "select"
        self.battle_message = "Select a unit to begin your turn."

        print("--- Battle Initialization Complete ---")

    def _scatter_units_in_zone(self, units, zone_cols, occupied_cells, side, max_attempts=50):
        """Randomly place ``units`` with their top-left column inside ``zone_cols``.

        For each unit, up to ``max_attempts`` random top-left cells are tried,
        twice as many for multi-tile units. Rows are restricted to those where
        the unit fits vertically. A spot is accepted only if every covered cell
        is on the grid and not in ``occupied_cells``.

        On success, the unit's ``"position"`` is set, every covered cell of
        ``self.grid`` references the unit, and those cells are added to
        ``occupied_cells``, which is mutated in place.

        Args:
            units: Unit dicts to place, in the order they should be tried.
            zone_cols: Allowed top-left columns (a ``range``).
            occupied_cells: Set of already-occupied ``(x, y)`` cells.
            side: ``"player"`` or ``"enemy"``; used only in log messages.
            max_attempts: Base number of random tries per unit.

        Returns:
            The units that were placed, in the order tried. Units that could
            not be placed are reported and omitted.
        """
        grid_w, grid_h = self.grid_size
        placed_units = []
        for unit in units:
            size = unit.get("size", (1, 1))
            width, height = size
            name = unit.get("name", "Unknown")
            # Multi-tile units have fewer valid spots, so give them more tries.
            attempts = max_attempts * 2 if (width > 1 or height > 1) else max_attempts
            # Rows where the unit fits vertically; empty if it is taller than the grid.
            rows = range(grid_h - height + 1)

            placed = False
            for _ in range(attempts if (zone_cols and rows) else 0):
                x = random.choice(zone_cols)
                y = random.choice(rows)
                cells = [(x + dx, y + dy) for dx in range(width) for dy in range(height)]
                if not all(
                    0 <= cx < grid_w and 0 <= cy < grid_h and (cx, cy) not in occupied_cells
                    for cx, cy in cells
                ):
                    continue
                unit["position"] = (x, y)
                for cx, cy in cells:
                    self.grid[cx][cy] = unit  # every covered cell references the unit
                    occupied_cells.add((cx, cy))
                placed_units.append(unit)
                print(f"  Placed {name} at ({x}, {y}) with size {size}")
                placed = True
                break

            if not placed:
                print(f"Warning: Could not find placement for {side} unit {name} after {attempts} attempts.")
        return placed_units
        
    def handle_event(self, event):
        """Route one pygame event to the matching battle handler.

        Left mouse clicks go to ``_handle_click``. Custom events are matched
        by type and by the ``action`` entry of ``event.dict``:

        - USEREVENT with ``victory`` or ``defeat`` switches to mission select.
        - USEREVENT+1 with ``run_enemy_ai`` runs the enemy AI, then disables
          that timer.
        - USEREVENT+2 with ``end_enemy_turn`` ends the enemy turn, then
          disables that timer.

        Keyboard input is currently ignored.
        """
        kind = event.type

        if kind == pygame.KEYDOWN:
            return  # no keyboard controls yet

        if kind == pygame.MOUSEBUTTONDOWN:
            if event.button == 1:  # left button only
                self._handle_click(event.pos)
            return

        action = event.dict.get('action')

        if kind == pygame.USEREVENT:
            # Both battle outcomes currently return to the mission list.
            # NOTE(review): enter() looks scenes up with GameState.MISSION_SELECT
            # while this passes the string 'mission_select'; confirm
            # change_scene accepts both.
            if action in ('victory', 'defeat'):
                self.game_manager.change_scene('mission_select')
        elif kind == pygame.USEREVENT + 1:
            if action == 'run_enemy_ai':
                self._run_enemy_ai()
                pygame.time.set_timer(pygame.USEREVENT + 1, 0)  # stop the repeat timer
        elif kind == pygame.USEREVENT + 2:
            if action == 'end_enemy_turn':
                self._end_enemy_turn()
                pygame.time.set_timer(pygame.USEREVENT + 2, 0)  # stop the repeat timer
                
    def update(self):
        """Per-frame update: calls ``_run_enemy_ai`` while the enemy is in its "enemy" phase."""
        if self.turn != "enemy" or self.current_phase != "enemy":
            return
        # NOTE(review): handle_event also calls _run_enemy_ai on a USEREVENT+1
        # timer; confirm _run_enemy_ai leaves this phase so it is not re-run
        # on consecutive frames.
        self._run_enemy_ai()
            
    def render(self, screen):
        """Render the battle scene"""
        # Fill the background
        screen.fill(self.bg_color)
        
        screen_width, screen_height = screen.get_size()
        
        # Calculate grid offset to center the grid
        grid_width = self.grid_size[0] * CELL_SIZE
        grid_height = self.grid_size[1] * CELL_SIZE
        grid_offset_x = (screen_width - grid_width) // 2
        grid_offset_y = 100  # Leave space at top for UI
        
        # Draw grid background
        grid_rect = pygame.Rect(grid_offset_x, grid_offset_y, grid_width, grid_height)
        pygame.draw.rect(screen, (60, 40, 40), grid_rect)
        
        # --- Pass 1: Draw grid cells and unit sprites --- 
        units_to_draw_hp_for = [] # Store info for HP bars
        drawn_unit_ids = set() # Keep track of drawn multi-tile units

        # DEBUG: Track Battle Suit rendering
        battle_suit_cells_found = []
        battle_suit_origin_cell = None
        
        for x in range(self.grid_size[0]):
            for y in range(self.grid_size[1]):
                cell_rect = pygame.Rect(
                    grid_offset_x + x * CELL_SIZE, 
                    grid_offset_y + y * CELL_SIZE,
                    CELL_SIZE, 
                    CELL_SIZE
                )
                
                # Determine cell color (simplified for brevity, assuming original logic is fine)
                if (x, y) in self.move_range: 
                    cell_color = (100, 100, 150)
                elif (x, y) in self.attack_range: 
                    cell_color = (150, 100, 100)
                elif self.selected_cell == (x, y): 
                    cell_color = (120, 120, 120)
                else: 
                    # Use terrain type color as base
                    terrain_type = self.terrain_grid[x][y]
                    base_color = self.terrain_types[terrain_type]["color"]
                    # Apply checkerboard pattern by slightly adjusting brightness
                    cell_color = base_color if (x + y) % 2 == 0 else tuple(max(0, c - 10) for c in base_color)
                
                # Draw cell
                pygame.draw.rect(screen, cell_color, cell_rect)
                pygame.draw.rect(screen, (40, 30, 30), cell_rect, 1)  # Grid lines
                
                # Get unit if present in this cell
                unit = self.grid[x][y]
                if unit:
                    unit_id = unit["id"]
                    # --- Multi-tile Handling ---
                    # Check if this cell is part of the unit's occupied area
                    start_x, start_y = unit["position"]
                    unit_size = unit.get("size", (1, 1))
                    
                    # Track any Battle Suit cells we find
                    if unit.get("unit_type") == "battle_suit":
                        battle_suit_cells_found.append((x, y))
                        if x == start_x and y == start_y:
                            battle_suit_origin_cell = (x, y)
                    
                    # Check if this cell is within the unit's occupied area
                    is_part_of_unit = (
                        start_x <= x < start_x + unit_size[0] and 
                        start_y <= y < start_y + unit_size[1]
                    )
                    
                    # Only do full-unit draw from the top-left (origin) cell
                    # This ensures we draw once per unit, not once per cell
                    is_origin_cell = (x == start_x and y == start_y)
                    
                    if is_part_of_unit:
                        # Only calculate sprite once for the unit, at the origin cell
                        if is_origin_cell:
                            sprite_pixel_size = (CELL_SIZE * unit_size[0], CELL_SIZE * unit_size[1])
                            
                            # Calculate center position based on the actual area the unit occupies
                            unit_width_pixels = unit_size[0] * CELL_SIZE
                            unit_height_pixels = unit_size[1] * CELL_SIZE
                            unit_pos = (
                                grid_offset_x + start_x * CELL_SIZE + unit_width_pixels/2,
                                grid_offset_y + start_y * CELL_SIZE + unit_height_pixels/2
                            )
                            
                            # Debug output for multi-tile unit rendering
                            if unit.get("unit_type") == "battle_suit":
                                print(f"\n===== BATTLE SUIT RENDERING DEBUG =====")
                                print(f"Unit Details:")
                                print(f"  Name: {unit.get('name', 'Unknown')}")
                                print(f"  Position: {unit.get('position', 'Unknown')}")
                                print(f"  Size: {unit_size}")
                                print(f"  Grid Position: ({x}, {y})")
                                print(f"  Is Origin Cell: {is_origin_cell}")
                                print(f"  Unit ID: {unit_id}")
                                print(f"  Sprite Pixel Size: {sprite_pixel_size}")
                                print(f"  Unit Pos: {unit_pos}")
                                print("=======================================\n")
                            
                            # Special handling for Battle Suit
                            if unit.get("unit_type") == "battle_suit":
                                # Calculate the full grid area rectangle
                                overlay_rect = pygame.Rect(
                                    grid_offset_x + start_x * CELL_SIZE,
                                    grid_offset_y + start_y * CELL_SIZE,
                                    CELL_SIZE * unit_size[0],
                                    CELL_SIZE * unit_size[1]
                                )
                                
                                # Try loading directly with an absolute path first
                                image_loaded = False
                                absolute_path = os.path.abspath(os.path.join(
                                    os.path.dirname(__file__), 
                                    "..", "assets", "images", "units", "tau", "battle_suit.jpg"
                                ))
                                
                                try:
                                    if os.path.exists(absolute_path):
                                        print(f"Loading Battle Suit image from absolute path: {absolute_path}")
                                        battle_suit_image = pygame.image.load(absolute_path)
                                        print(f"Image loaded! Size: {battle_suit_image.get_size()}")
                                        
                                        # Draw a bright background underneath to help with visibility
                                        pygame.draw.rect(screen, (200, 200, 200), overlay_rect)
                                        
                                        # Scale the image to fit the entire 2x2 area
                                        scaled_image = pygame.transform.scale(
                                            battle_suit_image,
                                            (overlay_rect.width, overlay_rect.height)
                                        )
                                        
                                        # Draw the image
                                        screen.blit(scaled_image, overlay_rect)
                                        print(f"DIRECTLY DREW BATTLE SUIT AT {overlay_rect}")
                                        image_loaded = True
                                        
                                        # No border drawing here
                                except Exception as e:
                                    print(f"Error loading from absolute path: {e}")
                                    image_loaded = False
                                
                                # Try the relative path as fallback
                                if not image_loaded:
                                    try:
                                        relative_path = os.path.join("game", "assets", "images", "units", "tau", "battle_suit.jpg")
                                        if os.path.exists(relative_path):
                                            print(f"Loading Battle Suit image from relative path: {relative_path}")
                                            battle_suit_image = pygame.image.load(relative_path)
                                            
                                            # Draw a background first
                                            pygame.draw.rect(screen, (200, 200, 200), overlay_rect)
                                            
                                            # Scale and draw the image
                                            scaled_image = pygame.transform.scale(
                                                battle_suit_image,
                                                (overlay_rect.width, overlay_rect.height)
                                            )
                                            screen.blit(scaled_image, overlay_rect)
                                            print(f"SUCCESSFULLY DREW BATTLE SUIT FROM RELATIVE PATH")
                                            image_loaded = True
                                            
                                            # Removed cyan border
                                    except Exception as e:
                                        print(f"Error loading from relative path: {e}")
                                        image_loaded = False
                                
                                # If image loading failed, draw a fallback with text
                                if not image_loaded:
                                    # Draw a bright background with pattern
                                    pygame.draw.rect(screen, (255, 100, 100), overlay_rect)
                                    
                                    # Draw crosshatch pattern
                                    for i in range(0, overlay_rect.width, 10):
                                        pygame.draw.line(screen, (255, 255, 100), 
                                                        (overlay_rect.left + i, overlay_rect.top),
                                                        (overlay_rect.left, overlay_rect.top + i), 2)
                                        pygame.draw.line(screen, (255, 255, 100),
                                                        (overlay_rect.right - i, overlay_rect.top),
                                                        (overlay_rect.right, overlay_rect.top + i), 2)
                                    
                                    # Removed thick yellow border
                                    
                                    # Draw text label
                                    font = self.fonts["large"]
                                    text = font.render("BATTLE SUIT", True, (255, 255, 255))
                                    text_rect = text.get_rect(center=overlay_rect.center)
                                    
                                    # Draw text with shadow for better visibility
                                    shadow_rect = text_rect.copy()
                                    shadow_rect.x += 2
                                    shadow_rect.y += 2
                                    shadow_text = font.render("BATTLE SUIT", True, (0, 0, 0))
                                    screen.blit(shadow_text, shadow_rect)
                                    screen.blit(text, text_rect)
                                    
                                    # Add second line of text
                                    font_small = self.fonts["medium"]
                                    text2 = font_small.render("IMAGE FAILED", True, (255, 255, 0))
                                    text2_rect = text2.get_rect(centerx=overlay_rect.centerx, top=text_rect.bottom + 5)
                                    screen.blit(text2, text2_rect)
                                
                                # Keep track that this unit has been drawn
                                drawn_unit_ids.add(unit_id)
                                continue
                            
                            # Get or create unit sprite for normal units
                            if unit_id not in self.unit_sprites:
                                # Create new sprite with correct target_size
                                sprite = UnitSprite(unit, unit_pos, target_size=sprite_pixel_size)
                                self.unit_sprites[unit_id] = sprite
                    else:
                                # Update existing sprite position
                                sprite = self.unit_sprites[unit_id]
                                sprite.set_position(unit_pos[0], unit_pos[1])
                            
                            # Set sprite properties
                            sprite.set_selected(self.selected_unit and unit["id"] == self.selected_unit["id"])
                            sprite.set_hovered(False)
                            
                            # Draw the sprite
                            sprite.draw(screen)
                            drawn_unit_ids.add(unit_id)

                            # Store info needed to draw HP bar later
                            units_to_draw_hp_for.append({
                                'sprite': sprite,
                                'hp': unit["hp"],
                                'max_hp': unit["max_hp"]
                            })
                        # For non-origin cells of multi-tile units, draw an outline to show it's part of the same unit
                        elif unit.get("unit_type") == "battle_suit" and unit_size[0] > 1 and unit_size[1] > 1:
                            # Draw cell outline to show it's part of a multi-tile unit
                            multi_cell_rect = pygame.Rect(
                                grid_offset_x + x * CELL_SIZE,
                                grid_offset_y + y * CELL_SIZE,
                                CELL_SIZE,
                                CELL_SIZE
                            )
                            # Don't draw the outline/border
                            # pygame.draw.rect(screen, (120, 50, 50), multi_cell_rect, 2)

        # After rendering all cells, output debug info about Battle Suit rendering
        print("\n==== BATTLE SUIT RENDERING VERIFICATION ====")
        print(f"Cells containing Battle Suit: {battle_suit_cells_found}")
        print(f"Battle Suit origin cell: {battle_suit_origin_cell}")
        print(f"Drawn unit IDs: {drawn_unit_ids}")
        if not battle_suit_cells_found:
            print("WARNING: No Battle Suit found in any grid cell during rendering!")
        elif not battle_suit_origin_cell:
            print("WARNING: Battle Suit found but no origin cell identified!")
        print("===========================================\n")
        
        # --- Pass 2: Draw HP bars on top of sprites --- 
        for hp_info in units_to_draw_hp_for:
            sprite = hp_info['sprite']
            hp_percent = hp_info['hp'] / hp_info['max_hp']
            
            # Adjust HP bar based on potentially larger sprite width
            bar_width = sprite.rect.width - 10 
            bar_height = 6
            
            # HP bar background - Position below the sprite rect
                    bar_rect = pygame.Rect(
                sprite.rect.left + (sprite.rect.width - bar_width) // 2, # Centered below sprite
                sprite.rect.bottom + 3, 
                        bar_width,
                        bar_height
                    )
                    pygame.draw.rect(screen, (40, 40, 40), bar_rect)
                    
                    # HP bar fill
                    if hp_percent > 0:
                # Determine color based on health percentage
                if hp_percent > 0.6: 
                    hp_color = (0, 200, 0)  # Green for high health
                elif hp_percent > 0.3: 
                    hp_color = (200, 200, 0)  # Yellow for medium health
                else:
                    hp_color = (200, 0, 0)  # Red for low health
                
                fill_width = max(1, int(bar_width * hp_percent))
                fill_rect = pygame.Rect(bar_rect.left, bar_rect.top, fill_width, bar_height)
                pygame.draw.rect(screen, hp_color, fill_rect)
                pygame.draw.rect(screen, (20, 20, 20), fill_rect, 1)  # Border around filled portion
        
        # Draw battle UI
        self._render_battle_ui(screen, grid_offset_x, grid_offset_y, grid_width, grid_height)
        
        # Draw debug info if enabled
        if self.debug_mode:
            debug_info = [
                f"Turn: {self.turn}",
                f"Phase: {self.current_phase}",
                f"Selected Unit: {self.selected_unit['name'] if self.selected_unit else 'None'}",
                f"Last Action: {self.last_action}"
            ]
            
            for i, info in enumerate(debug_info):
                self.draw_text(screen, info, (10, 10 + i * 20), "small", (255, 255, 255))
        
        # FINAL DIRECT DRAWING OF BATTLE SUIT (force display as last step)
        self._force_draw_battle_suit(screen, grid_offset_x, grid_offset_y)
    
    def _force_draw_battle_suit(self, screen, grid_offset_x, grid_offset_y):
        """Draw the Battle Suit on top of everything else rendered this frame.

        Finds the grid cell that is the Battle Suit's origin (the cell equal to
        the unit's ``"position"``), then draws its image scaled to cover the
        unit's whole ``"size"`` footprint (default 2x2 cells), followed by a
        thick HP bar just below it. If the image cannot be loaded, a red
        placeholder labelled "BATTLE SUIT" / "IMAGE FAILED" is drawn instead.
        Does nothing when no Battle Suit origin cell is on the grid.

        Args:
            screen: pygame surface to draw on.
            grid_offset_x: Screen x of the grid's left edge.
            grid_offset_y: Screen y of the grid's top edge.
        """
        # Locate the Battle Suit's origin (top-left) cell
        battle_suit = None
        battle_suit_pos = None
        for x in range(self.grid_size[0]):
            for y in range(self.grid_size[1]):
                unit = self.grid[x][y]
                if not unit or unit.get("unit_type") != "battle_suit":
                    continue
                start_x, start_y = unit["position"]
                # A multi-tile unit occupies several cells; only draw from its origin
                if x == start_x and y == start_y:
                    battle_suit = unit
                    battle_suit_pos = (x, y)
                    break
            if battle_suit:
                break

        if not (battle_suit and battle_suit_pos):
            # Nothing to draw; stay silent since this runs every frame
            return

        x, y = battle_suit_pos
        unit_size = battle_suit.get("size", (2, 2))  # Default to 2x2
        overlay_rect = pygame.Rect(
            grid_offset_x + x * CELL_SIZE,
            grid_offset_y + y * CELL_SIZE,
            CELL_SIZE * unit_size[0],
            CELL_SIZE * unit_size[1],
        )

        battle_suit_image = self._load_battle_suit_image()
        if battle_suit_image is not None:
            # Bright background underneath to help with visibility
            pygame.draw.rect(screen, (200, 200, 200), overlay_rect)
            # Scale the image to fill the unit's entire footprint
            scaled_image = pygame.transform.scale(
                battle_suit_image,
                (overlay_rect.width, overlay_rect.height),
            )
            screen.blit(scaled_image, overlay_rect)
        else:
            # Fallback: red box with a yellow diagonal line pattern and a label
            pygame.draw.rect(screen, (255, 100, 100), overlay_rect)
            for i in range(0, overlay_rect.width, 10):
                pygame.draw.line(screen, (255, 255, 100),
                                 (overlay_rect.left + i, overlay_rect.top),
                                 (overlay_rect.left, overlay_rect.top + i), 2)
                pygame.draw.line(screen, (255, 255, 100),
                                 (overlay_rect.right - i, overlay_rect.top),
                                 (overlay_rect.right, overlay_rect.top + i), 2)

            font = self.fonts["large"]
            text = font.render("BATTLE SUIT", True, (255, 255, 255))
            text_rect = text.get_rect(center=overlay_rect.center)

            # Black copy offset by 2 px acts as a drop shadow for readability
            shadow_text = font.render("BATTLE SUIT", True, (0, 0, 0))
            screen.blit(shadow_text, text_rect.move(2, 2))
            screen.blit(text, text_rect)

            # Second line explaining why the placeholder is shown
            font_small = self.fonts["medium"]
            text2 = font_small.render("IMAGE FAILED", True, (255, 255, 0))
            text2_rect = text2.get_rect(centerx=overlay_rect.centerx, top=text_rect.bottom + 5)
            screen.blit(text2, text2_rect)

        # HP bar below the footprint, thicker than the regular unit bars
        max_hp = battle_suit["max_hp"]
        hp_percent = battle_suit["hp"] / max_hp if max_hp > 0 else 0.0
        # Clamp so an overheal (hp > max_hp) cannot draw past the bar
        hp_percent = min(1.0, max(0.0, hp_percent))

        bar_width = overlay_rect.width - 10
        bar_height = 12
        bar_rect = pygame.Rect(
            overlay_rect.left + 5,
            overlay_rect.bottom + 5,
            bar_width,
            bar_height,
        )
        pygame.draw.rect(screen, (40, 40, 40), bar_rect)

        if hp_percent > 0:
            if hp_percent > 0.6:
                hp_color = (0, 255, 0)
            elif hp_percent > 0.3:
                hp_color = (255, 255, 0)
            else:
                hp_color = (255, 0, 0)
            fill_width = max(1, int(bar_width * hp_percent))
            fill_rect = pygame.Rect(bar_rect.left, bar_rect.top, fill_width, bar_height)
            pygame.draw.rect(screen, hp_color, fill_rect)

    def _load_battle_suit_image(self):
        """Return the Battle Suit image surface, loading it from disk only once.

        Tries the asset path relative to this module first, then the path
        relative to the current working directory. The result is cached on the
        instance, so per-frame rendering does not hit the filesystem again. A
        failed load is cached as ``None`` so it is not retried or re-logged
        every frame.

        Returns:
            The loaded ``pygame.Surface``, or ``None`` if neither path could be
            loaded.
        """
        if hasattr(self, "_battle_suit_image_cache"):
            return self._battle_suit_image_cache

        candidate_paths = (
            os.path.abspath(os.path.join(
                os.path.dirname(__file__),
                "..", "assets", "images", "units", "tau", "battle_suit.jpg"
            )),
            os.path.join("game", "assets", "images", "units", "tau", "battle_suit.jpg"),
        )

        image = None
        for path in candidate_paths:
            if not os.path.exists(path):
                continue
            try:
                image = pygame.image.load(path)
            except (pygame.error, OSError) as e:
                print(f"Error loading Battle Suit image from {path}: {e}")
                continue
            print(f"Loaded Battle Suit image from {path} (size: {image.get_size()})")
            break

        self._battle_suit_image_cache = image
        return image
        
    def _render_battle_ui(self, screen, grid_x, grid_y, grid_width, grid_height):
        """Render the battle HUD and rebuild the list of clickable buttons.

        Draws, in order: the turn and turn-counter indicators, the "units
        available" count (player turn only), both sides' morale, the current
        phase, the battle message, the selected unit's info panel and action
        buttons, and the END TURN button (player turn only).

        ``self.buttons`` is cleared and repopulated on every call, so only the
        buttons drawn this frame are clickable.

        Args:
            screen: pygame surface to draw on.
            grid_x: Screen x of the grid's left edge.
            grid_y: Screen y of the grid's top edge.
            grid_width: Grid width in pixels.
            grid_height: Grid height in pixels.
        """
        screen_width, _ = screen.get_size()
        # Two right-aligned HUD columns above the grid
        left_hud_x = grid_x + grid_width - 300
        right_hud_x = grid_x + grid_width - 150

        def morale_color(morale):
            # Green at full morale, yellow at 50 or more, red below 50
            if morale >= 100:
                return (0, 200, 0)
            if morale >= 50:
                return (200, 200, 0)
            return (200, 0, 0)

        # Reset the buttons list; it is rebuilt from what is drawn this frame
        self.buttons = []

        # Turn indicator and turn counter
        turn_text = f"Turn: {'Player' if self.turn == 'player' else 'Enemy'}"
        self.draw_text(screen, turn_text, (grid_x + 20, 30), "medium", (200, 200, 200))
        self.draw_text(screen, f"Turn #{self.turn_counter}", (grid_x + 20, 60), "small", (180, 180, 200))

        # Number of player units that have not yet acted this turn
        if self.turn == "player":
            available_units = sum(1 for unit in self.player_units if unit["id"] not in self.acted_units)
            units_text = f"Units available: {available_units}/{len(self.player_units)}"
            self.draw_text(screen, units_text, (left_hud_x, 90), "small", (200, 200, 200))

        # Morale indicators for both sides
        self.draw_text(screen, f"Player Morale: {self.player_morale}%", (left_hud_x, 60),
                       "small", morale_color(self.player_morale))
        self.draw_text(screen, f"Enemy Morale: {self.enemy_morale}%", (right_hud_x, 60),
                       "small", morale_color(self.enemy_morale))

        # Phase indicator
        phase_text = f"Phase: {self.current_phase.capitalize()}"
        self.draw_text(screen, phase_text, (right_hud_x, 30), "medium", (200, 200, 200))

        # Battle message at the horizontal middle of the screen
        self.draw_text(screen, self.battle_message, (screen_width // 2, 60), "medium", (255, 255, 255), True)

        # Info panel and action buttons for the selected unit
        if self.selected_unit:
            unit = self.selected_unit
            info_panel_rect = pygame.Rect(grid_x, grid_y + grid_height + 20, grid_width, 100)
            pygame.draw.rect(screen, (60, 40, 40), info_panel_rect, border_radius=5)
            text_x = info_panel_rect.left + 20

            # Unit name and HP
            self.draw_text(
                screen,
                f"{unit['name']} - HP: {unit['hp']}/{unit['max_hp']}",
                (text_x, info_panel_rect.top + 15),
                "medium",
                (200, 200, 200),
            )

            # Unit abilities
            abilities_text = "Abilities: " + ", ".join(unit["abilities"])
            self.draw_text(screen, abilities_text, (text_x, info_panel_rect.top + 45), "small", (200, 200, 200))

            # Combat stats
            stats_text = (
                f"ATK: {unit['attack']} | DEF: {unit['defense']} | "
                f"Move: {unit['move_range']} | Range: {unit['attack_range']}"
            )
            self.draw_text(screen, stats_text, (text_x, info_panel_rect.top + 75), "small", (200, 200, 200))

            # Action buttons only for the player's own units during the player turn
            if self.turn == "player" and unit["id"].startswith("player_"):
                button_width = 120
                button_height = 40
                button_spacing = 20
                button_size = (button_width, button_height)
                button_y = info_panel_rect.bottom + 20

                def button_center_x(slot):
                    # Center x of the slot-th button in the row below the panel
                    return grid_x + slot * (button_width + button_spacing) + button_width // 2 + 20

                if self.current_phase == "select":
                    # MOVE and ATTACK are always clickable in the select phase
                    for slot, (button_id, label) in enumerate((("move", "MOVE"), ("attack", "ATTACK"))):
                        button_rect = self.draw_button(
                            screen, label,
                            (button_center_x(slot), button_y),
                            button_size,
                            hover=self.hover_button == button_id,
                            centered=True,
                        )
                        self.buttons.append({"id": button_id, "rect": button_rect})

                    # ABILITY is only enabled while the unit has an unused frag rocket
                    ability_active = (
                        "frag_rocket" in unit["abilities"]
                        and "frag_rocket" not in unit.get("used_abilities", set())
                    )
                    ability_button_rect = self.draw_button(
                        screen, "ABILITY",
                        (button_center_x(2), button_y),
                        button_size,
                        hover=self.hover_button == "ability" and ability_active,  # Hover only if active
                        centered=True,
                        # None -> default button color when active; dimmed otherwise
                        bg_color=None if ability_active else (40, 40, 60),
                    )
                    # Only add the ability button to the clickable list if active
                    if ability_active:
                        self.buttons.append({"id": "ability", "rect": ability_button_rect})

                elif self.current_phase in ("move", "attack", "ability_target"):
                    # While choosing a target or destination, offer a way back out
                    cancel_button_rect = self.draw_button(
                        screen, "CANCEL",
                        (button_center_x(0), button_y),
                        button_size,
                        hover=self.hover_button == "cancel",
                        centered=True,
                    )
                    self.buttons.append({"id": "cancel", "rect": cancel_button_rect})

        # End turn button (only shown during player turn)
        if self.turn == "player":
            end_turn_rect = self.draw_button(
                screen, "END TURN",
                (grid_x + grid_width - 80, grid_y - 50),
                (140, 40),
                hover=self.hover_button == "end_turn",
                centered=True,
            )
            self.buttons.append({"id": "end_turn", "rect": end_turn_rect})
    
    def _screen_to_grid(self, screen_pos):
        """Translate a screen pixel position into grid cell coordinates.

        The grid is treated as centred horizontally on the current display
        surface, with its top edge a fixed 100 px below the top of the screen.

        Args:
            screen_pos: Indexable ``(x, y)`` pixel position, such as a mouse
                position.

        Returns:
            ``(column, row)`` as ints when ``screen_pos`` lands on a cell,
            otherwise ``None``.
        """
        display_size = pygame.display.get_surface().get_size()
        columns = self.grid_size[0]
        rows = self.grid_size[1]

        pixel_width = columns * CELL_SIZE
        pixel_height = rows * CELL_SIZE
        left = (display_size[0] - pixel_width) // 2
        top = 100

        px = screen_pos[0]
        py = screen_pos[1]

        # Reject anything outside the grid's pixel rectangle (edges inclusive)
        on_grid = left <= px <= left + pixel_width and top <= py <= top + pixel_height
        if not on_grid:
            return None

        column = (px - left) // CELL_SIZE
        row = (py - top) // CELL_SIZE

        # The right/bottom edges pass the check above but fall one past the last cell
        if not (0 <= column < columns and 0 <= row < rows):
            return None
        return (int(column), int(row))
    
    def _handle_grid_click(self, grid_pos):
        """Handle clicks on the battle grid"""
        x, y = grid_pos
        unit = self.grid[x][y]
        
        # Different handling based on current phase
        if self.current_phase == "select":
            # Select a unit
            if unit:
                # In player turn, can only select player units that haven't acted yet
                if self.turn == "player" and unit["id"].startswith("player_"):
                    if unit["id"] in self.acted_units:
                        self.battle_message = f"{unit['name']} has already acted this turn."
                        self.last_action = f"Tried to select unit that already acted: {unit['name']}"
                    else:
                        self.selected_unit = unit
                        self.battle_message = f"Selected {unit['name']}. Choose an action."
                        self.last_action = f"Selected unit: {unit['name']}"
                # In enemy turn, can only select enemy units (for info)
                elif self.turn == "enemy" and unit["id"].startswith("enemy_"):
                    self.selected_unit = unit
                    self.battle_message = f"Enemy unit: {unit['name']}"
                    self.last_action = f"Viewing enemy unit: {unit['name']}"
                else:
                    self.battle_message = "Cannot select this unit now."
                    self.last_action = "Tried to select invalid unit"
            else:
                self.selected_unit = None
                self.battle_message = "No unit selected."
                self.last_action = "Deselected unit"
                
        elif self.current_phase == "move":
            # Move the selected unit
            if (x, y) in self.move_range:
                # If the destination is empty
                if not unit:
                    # Update unit position
                    old_x, old_y = self.selected_unit["position"]
                    self.grid[old_x][old_y] = None
                    self.grid[x][y] = self.selected_unit
                    self.selected_unit["position"] = (x, y)
                    
                    # Mark unit as having acted this turn
                    self.acted_units.add(self.selected_unit["id"])
                    
                    # Clear move range
                    self.move_range = []
                    
                    # Return to select phase
                    self.current_phase = "select"
                    self.battle_message = f"{self.selected_unit['name']} moved. Select another unit."
                    self.last_action = f"Moved {self.selected_unit['name']} to ({x}, {y})"
                else:
                    self.battle_message = "Cannot move to an occupied cell."
                    self.last_action = "Tried to move to occupied cell"
            else:
                self.battle_message = "Invalid move destination."
                self.last_action = "Tried to move to invalid destination"
        
        elif self.current_phase == "attack":
            # Attack a target
            if (x, y) in self.attack_range:
                # If there's an enemy unit at the target location
                target = self.grid[x][y] # Get the primary target
                if target and target["id"].startswith("enemy_" if self.turn == "player" else "player_"):
                    attacker = self.selected_unit
                    
                    # --- Primary Target Damage Calculation --- 
                    primary_damage = max(5, attacker["attack"] - target["defense"] // 2)
                    primary_damage += random.randint(-3, 3)
                    
                    attacker_damage_type = attacker.get("attack_damage_type", "normal")
                    target_is_weak = target.get("is_weak_to_explosives", False)
                    weakness_multiplier = 1.5 
                    
                    primary_damage_message_suffix = ""
                    if attacker_damage_type == "explosive" and target_is_weak:
                        primary_damage = int(primary_damage * weakness_multiplier)
                        primary_damage_message_suffix = " (Weakness!) "
                        print(f"Applying explosive weakness bonus to primary target! Final: {primary_damage}")
                    
                    # Apply primary damage
                    target["hp"] = max(0, target["hp"] - primary_damage)
                    self.battle_message = f"Attacked {target['name']} for {primary_damage} damage!{primary_damage_message_suffix}"
                    self.last_action = f"Attacked {target['name']} for {primary_damage} damage"
                    print(f"{attacker['name']} hit {target['name']} for {primary_damage} damage.{primary_damage_message_suffix}")

                    primary_target_defeated = target["hp"] <= 0
                    defeated_units_this_attack = []
                    if primary_target_defeated:
                        defeated_units_this_attack.append(target)

                    # --- Splash Damage Calculation (if explosive) --- 
                    if attacker_damage_type == "explosive":
                        print(f"Calculating splash damage around ({x}, {y})...")
                        splash_messages = []
                        for dx, dy in [(0, 1), (1, 0), (0, -1), (-1, 0)]: # Adjacent cells
                            adj_x, adj_y = x + dx, y + dy
                            
                            # Check grid bounds
                            if 0 <= adj_x < self.grid_size[0] and 0 <= adj_y < self.grid_size[1]:
                                splash_target = self.grid[adj_x][adj_y]
                                
                                # Check if splash target exists, is not the primary target, and is an enemy
                                if splash_target and splash_target != target and splash_target["id"].startswith("enemy_" if self.turn == "player" else "player_"):
                                    print(f"  ... Found potential splash target: {splash_target['name']} at ({adj_x}, {adj_y})")
                                    # Calculate base splash damage (e.g., 50% of attacker's base attack, min 1)
                                    splash_damage_base = max(1, (attacker["attack"] // 2) - (splash_target["defense"] // 4)) # Reduced defense effect too
                                    # splash_damage_base += random.randint(-1, 1) # Optional randomness?

                                    # Check splash target weakness
                                    splash_target_is_weak = splash_target.get("is_weak_to_explosives", False)
                                    splash_damage_final = splash_damage_base
                                    splash_weakness_suffix = ""
                                    if splash_target_is_weak:
                                        splash_damage_final = int(splash_damage_base * weakness_multiplier)
                                        splash_weakness_suffix = " (Weakness!)"
                                        print(f"    Applying explosive weakness bonus to splash target! Final: {splash_damage_final}")
                                    
                                    # Apply splash damage
                                    splash_target["hp"] = max(0, splash_target["hp"] - splash_damage_final)
                                    splash_msg = f"{splash_target['name']} took {splash_damage_final} splash damage!{splash_weakness_suffix}"
                                    splash_messages.append(splash_msg)
                                    print(f"    {splash_msg}")

                                    # Check if splash target defeated
                                    if splash_target["hp"] <= 0 and splash_target not in defeated_units_this_attack:
                                        defeated_units_this_attack.append(splash_target)
                        
                        # Append splash messages to main battle message if any occurred
                        if splash_messages:
                            self.battle_message += " | " + " | ".join(splash_messages)
                            self.last_action += " (+splash)"
                    # --- End Splash Damage --- 

                    # Mark attacker as having acted this turn
                    self.acted_units.add(attacker["id"])

                    # --- Handle Defeated Units (Primary + Splash) --- 
                    if defeated_units_this_attack:
                        for defeated_unit in defeated_units_this_attack:
                            dx, dy = defeated_unit["position"] # Use the unit's stored position
                            if self.grid[dx][dy] == defeated_unit: # Check if grid still matches (optional safety)
                                self.grid[dx][dy] = None 
                                # If unit was multi-tile, clear other cells too (Handle later if needed)
                            
                            if defeated_unit["id"] == target["id"]:
                                self.battle_message = f"{target['name']} was defeated!{primary_damage_message_suffix}" # Main defeat message
                        else:
                                self.battle_message += f" | {defeated_unit['name']} defeated by splash!" # Append splash defeat

                            if defeated_unit["id"] in self.unit_sprites:
                                del self.unit_sprites[defeated_unit["id"]] # Remove sprite

                            if defeated_unit["id"].startswith("enemy_"):
                                if defeated_unit in self.enemy_units:
                                     self.enemy_units.remove(defeated_unit)
                            else:
                                if defeated_unit in self.player_units:
                                     self.player_units.remove(defeated_unit)
                        
                        # Check win/lose condition after removing units
                        if not self.enemy_units:
                            self.battle_message = "Victory! All enemies defeated."
                            self.last_action = "Victory achieved"
                            pygame.time.set_timer(pygame.USEREVENT, 3000)
                            pygame.event.post(pygame.event.Event(pygame.USEREVENT, {'action': 'victory'}))
                        elif not self.player_units:
                            self.battle_message = "Defeat! All your units are lost."
                            self.last_action = "Defeat - all units lost"
                            pygame.time.set_timer(pygame.USEREVENT, 3000)
                            pygame.event.post(pygame.event.Event(pygame.USEREVENT, {'action': 'defeat'}))
                    # --- End Defeat Handling --- 
                    
                    # Clear attack range
                    self.attack_range = []
                    
                    # Return to select phase
                    self.current_phase = "select"
                else:
                    self.battle_message = "Invalid attack target."
                    self.last_action = "Tried to attack invalid target"
            else:
                self.battle_message = "Target out of range."
                self.last_action = "Tried to attack target out of range"
    
    def _handle_button_click(self, button_id):
        """Handle UI button clicks"""
        if button_id == "move":
            if self.selected_unit and self.turn == "player":
                # Check if unit has already acted
                if self.selected_unit["id"] in self.acted_units:
                    self.battle_message = f"{self.selected_unit['name']} has already acted this turn."
                    self.last_action = f"Tried to use unit that already acted: {self.selected_unit['name']}"
                    return
                
                self.current_phase = "move"
                self.battle_message = "Select a destination to move to."
                self.last_action = "Started move action"
                
                # Calculate move range
                self.move_range = self._calculate_move_range(
                    self.selected_unit["position"], 
                    self.selected_unit["move_range"]
                )
                self.attack_range = []
                
        elif button_id == "attack":
            if self.selected_unit and self.turn == "player":
                # Check if unit has already acted
                if self.selected_unit["id"] in self.acted_units:
                    self.battle_message = f"{self.selected_unit['name']} has already acted this turn."
                    self.last_action = f"Tried to use unit that already acted: {self.selected_unit['name']}"
                    return
                
                self.current_phase = "attack"
                self.battle_message = "Select a target to attack."
                self.last_action = "Started attack action"
                
                # Calculate attack range
                self.attack_range = self._calculate_attack_range(
                    self.selected_unit["position"], 
                    self.selected_unit["attack_range"]
                )
                self.move_range = []
                
        elif button_id == "ability":
            if self.selected_unit and self.turn == "player":
                if self.selected_unit["id"] in self.acted_units:
                     # ... (message: unit already acted) ...
                    return
                
                # Check for specific abilities usable now
                can_use_frag_rocket = ("frag_rocket" in self.selected_unit["abilities"] and 
                                        "frag_rocket" not in self.selected_unit.get("used_abilities", set()))

                if can_use_frag_rocket:
                    self.current_phase = "ability_target"
                    self.current_ability = "frag_rocket" # Store which ability
                    self.battle_message = "Select target for Frag Rocket."
                    self.last_action = "Started Frag Rocket ability"
                    
                    # Calculate range (using attack range for now)
                    self.attack_range = self._calculate_attack_range(
                        self.selected_unit["position"], 
                        self.selected_unit["attack_range"] 
                    )
                    self.move_range = []
                else:
                    self.battle_message = "No usable abilities available now."
                    self.last_action = "Checked abilities (none usable)"
            
        elif button_id == "cancel":
            # Cancel current action (move, attack, OR ability_target)
            self.current_phase = "select"
            self.move_range = []
            self.attack_range = []
            if hasattr(self, 'current_ability'): # Clear ability tracking if cancelling
                delattr(self, 'current_ability')
            self.battle_message = "Action canceled."
            self.last_action = "Canceled action"
            
        elif button_id == "end_turn":
            # End player turn
            self._end_player_turn()
    
    def _end_player_turn(self):
        """Close out the player's turn and hand control to the enemy.

        Every buff on every player unit loses one turn of ``duration``; buffs
        that reach zero (or below) are removed. The turn then passes to
        "enemy", the selection and highlighted ranges are cleared, and a
        ``pygame.USEREVENT + 1`` event with ``action='run_enemy_ai'`` is
        queued to drive the enemy AI.
        """
        for unit in self.player_units:
            if "buffs" not in unit:
                continue

            expired = []
            for buff in unit["buffs"]:
                buff["duration"] -= 1
                if buff["duration"] <= 0:
                    expired.append(buff)

            # Removal happens after the scan so the list isn't mutated while iterated
            for buff in expired:
                unit["buffs"].remove(buff)

        self.turn = "enemy"
        self.selected_unit = None
        self.move_range = []
        self.attack_range = []
        self.battle_message = "Enemy turn"

        # NOTE(review): the event is posted immediately, and set_timer also
        # re-fires it every 500 ms until cancelled (per pygame docs) — confirm
        # the event handler cancels the timer and tolerates repeats.
        ai_event = pygame.USEREVENT + 1
        pygame.time.set_timer(ai_event, 500)
        pygame.event.post(pygame.event.Event(ai_event, {'action': 'run_enemy_ai'}))

    def _start_player_turn(self):
        """Starts the player's turn"""
        # Check win condition first - don't start player turn if there are no enemies
        if not self.enemy_units:
            self.battle_message = "Victory! All enemy units are defeated."
            # Return to mission select after 3 seconds
            pygame.time.set_timer(pygame.USEREVENT, 3000)
            pygame.event.post(pygame.event.Event(pygame.USEREVENT, {'action': 'victory'}))
            return

        # Process buffs for player units
        for player_unit in self.player_units:
            buffs_to_remove = []
            
            if "buffs" in player_unit:
                for buff in player_unit["buffs"]:
                    # Process any start-of-turn effects
                    if buff["type"] == "regeneration":
                        # Regenerate HP
                        player_unit["hp"] = min(player_unit["max_hp"], player_unit["hp"] + buff["value"])
                        self.battle_message = f"{player_unit['name']} regenerated {buff['value']} HP"
            
            # Remove expired buffs
            for buff in buffs_to_remove:
                player_unit["buffs"].remove(buff)
        
        # Set turn to player
        self.turn = "player"
        
        # Reset acted_units for player units at the start of player turn
        self.acted_units = set()
        # self.moved_units = set()
        
        # Set battle message
        self.battle_message = "Player turn - select a unit"

    def _run_enemy_ai(self):
        """Give every enemy unit its action for the current enemy turn.

        Each enemy is dispatched on its ``faction`` entry (missing is treated
        as "ork"): "tau" units use ``_run_tau_ai``; everything else uses
        ``_run_default_ai``. As soon as the player side is empty a defeat
        event is posted and processing stops; if the enemy side is empty
        after all have acted a victory event is posted. Otherwise a
        ``pygame.USEREVENT + 2`` event with ``action='end_enemy_turn'`` is
        queued.
        """
        def announce(message, outcome):
            # Show the result and return to mission select after 3 seconds
            self.battle_message = message
            pygame.time.set_timer(pygame.USEREVENT, 3000)
            pygame.event.post(pygame.event.Event(pygame.USEREVENT, {'action': outcome}))

        defeat_text = "Defeat! All player units are defeated."

        if not self.player_units:
            announce(defeat_text, 'defeat')
            return

        # Walk a snapshot so the AI routines may modify enemy_units safely
        for enemy in list(self.enemy_units):
            if enemy.get("faction", "ork") == "tau":
                ai_routine = self._run_tau_ai
            else:
                ai_routine = self._run_default_ai
            ai_routine(enemy)

            # Stop early once the player side has been wiped out
            if not self.player_units:
                announce(defeat_text, 'defeat')
                return

        if not self.enemy_units:
            announce("Victory! All enemy units are defeated.", 'victory')
            return

        # Hand control back to the player shortly after the enemies finish
        end_turn_event = pygame.USEREVENT + 2
        pygame.time.set_timer(end_turn_event, 500)  # 500ms delay
        pygame.event.post(pygame.event.Event(end_turn_event, {'action': 'end_enemy_turn'}))

    def _end_enemy_turn(self):
        """Ends the enemy turn and returns to the player's turn"""
        # Check win/lose conditions before starting player turn
        if not self.enemy_units:
            self.battle_message = "Victory! All enemy units are defeated."
            # Return to mission select after 3 seconds
            pygame.time.set_timer(pygame.USEREVENT, 3000)
            pygame.event.post(pygame.event.Event(pygame.USEREVENT, {'action': 'victory'}))
            return
        elif not self.player_units:
            self.battle_message = "Defeat! All player units are defeated."
            # Return to mission select after 3 seconds
            pygame.time.set_timer(pygame.USEREVENT, 3000)
            pygame.event.post(pygame.event.Event(pygame.USEREVENT, {'action': 'defeat'}))
            return
        
        # Increment the turn counter when starting a new player turn
        self.turn_counter += 1
            
        # End enemy turn
        self._start_player_turn()
        
        # Clear selected unit, move range, and attack range
        self.selected_unit = None
        self.move_range = []
        self.attack_range = []
        
        # Set battle message
        self.battle_message = "Player turn - select a unit"
    
    def _calculate_move_range(self, start_pos, move_range):
        """Calculate available move positions"""
        result = []
        to_check = [start_pos]
        checked = set()
        
        # Simple breadth-first search for move range
        while to_check and move_range > 0:
            new_to_check = []
            
            for pos in to_check:
                x, y = pos
                
                # Check all four directions
                for dx, dy in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
                    new_x, new_y = x + dx, y + dy
                    
                    # Skip if out of bounds
                    if not (0 <= new_x < self.grid_size[0] and 0 <= new_y < self.grid_size[1]):
                        continue
                    
                    # Skip if already checked
                    if (new_x, new_y) in checked:
                        continue
                    
                    # Skip if occupied
                    if self.grid[new_x][new_y] is not None:
                        continue
                    
                    # Add to result and check in next iteration
                    result.append((new_x, new_y))
                    new_to_check.append((new_x, new_y))
                    checked.add((new_x, new_y))
            
            to_check = new_to_check
            move_range -= 1
        
        return result
    
    def _calculate_attack_range(self, start_pos, attack_range):
        """Calculate available attack positions"""
        result = []
        x, y = start_pos
        
        # Check all cells within attack range (using Manhattan distance)
        for dx in range(-attack_range, attack_range + 1):
            for dy in range(-attack_range, attack_range + 1):
                if abs(dx) + abs(dy) <= attack_range:  # Manhattan distance
                    new_x, new_y = x + dx, y + dy
                    
                    # Skip if out of bounds
                    if not (0 <= new_x < self.grid_size[0] and 0 <= new_y < self.grid_size[1]):
                        continue
                    
                    # Skip the unit's own position
                    if new_x == x and new_y == y:
                        continue
                    
                    # Add to result
                    result.append((new_x, new_y))
        
        return result
    
    def _run_tau_ai(self, tau_unit):
        """Tau AI behavior: tactical, coordinated, prioritizing ranged combat"""
        # Tau are intelligent and strategic
        ex, ey = tau_unit["position"]
        
        # Assess if the Tau unit is in danger (player unit too close)
        in_danger = False
        for player_unit in self.player_units:
            px, py = player_unit["position"]
            distance = abs(ex - px) + abs(ey - py)
            if distance <= 2:  # If a player unit is adjacent or very close
                in_danger = True
                break
        
        if in_danger:
            # If in danger, retreat to a safer position
            # Try to move away from closest threat while maintaining attack capability
            closest_threat = None
            min_distance = float('inf')
            
            for player_unit in self.player_units:
                px, py = player_unit["position"]
                distance = abs(ex - px) + abs(ey - py)
                
                if distance < min_distance:
                    min_distance = distance
                    closest_threat = player_unit
            
            self._tau_tactical_retreat(tau_unit)
        else:
            # Find player units in attack range
            attackable_units = []
            for player_unit in self.player_units:
                px, py = player_unit["position"]
                distance = abs(ex - px) + abs(ey - py)
                if distance <= tau_unit["attack_range"]:
                    attackable_units.append((player_unit, distance))
            
            if attackable_units:
                # Tau are smart - they target the most damaged unit or highest value target
                target = self._tau_select_optimal_target(attackable_units)
                px, py = target["position"]
                
                # Tau have more precise attacks
                damage = max(5, tau_unit["attack"] - target["defense"] // 2)
                damage += random.randint(-2, 4)  # Smaller variance (more precise)
                damage = int(damage * 1.15)  # 15% bonus for superior weapons
                
                # Apply damage
                target["hp"] = max(0, target["hp"] - damage)
                self.battle_message = f"Enemy {tau_unit['name']} precisely targets your {target['name']} for {damage} damage."
                self.last_action = f"Tau {tau_unit['name']} attacked with precision for {damage} damage"
                
                # Check if unit is defeated
                if target["hp"] <= 0:
                    self.grid[px][py] = None
                    self.player_units.remove(target)
                    self.battle_message = f"Your {target['name']} was eliminated by Tau for the Greater Good."
                    self.last_action = f"Tau defeated your {target['name']}"
                    
                    # Check lose condition
                    if not self.player_units:
                        self.battle_message = "Defeat! All your units have been eliminated by the Tau Empire."
                        self.last_action = "Defeat - all units lost"
                        pygame.time.set_timer(pygame.USEREVENT, 3000)
                        pygame.event.post(pygame.event.Event(pygame.USEREVENT, {'action': 'defeat'}))
            else:
                # Find optimal position to move to (maintain range advantage)
                self._tau_find_optimal_position(tau_unit)

    def _tau_select_optimal_target(self, attackable_units):
        """Tau select the optimal target based on damage potential and unit value"""
        highest_value_target = None
        highest_score = -1
        
        for unit, distance in attackable_units:
            # Score based on multiple factors (higher is better target)
            score = 0
            
            # Factor 1: Low HP targets are more valuable (finishing off units)
            hp_percent = unit["hp"] / unit["max_hp"]
            score += (1 - hp_percent) * 10  # Up to 10 points for nearly dead units
            
            # Factor 2: High-value targets (based on unit type)
            unit_type = unit.get("unit_type", "")
            if "terminator" in unit_type or "captain" in unit_type:
                score += 5  # Elite units
            elif "devastator" in unit_type:
                score += 3  # Heavy weapon units
            
            # Factor 3: Closer targets are slightly preferred (more reliable hits)
            score += (5 - min(distance, 5))
            
            if score > highest_score:
                highest_score = score
                highest_value_target = unit
        
        # If no high-value target found, return the first one
        return highest_value_target or attackable_units[0][0]

    def _tau_tactical_retreat(self, tau_unit):
        """Move Tau unit to a safer position while maintaining firing capability"""
        ex, ey = tau_unit["position"]
        
        # Find the closest player unit (the threat)
        closest_threat = None
        min_distance = float('inf')
        
        for player_unit in self.player_units:
            px, py = player_unit["position"]
            distance = abs(ex - px) + abs(ey - py)
            
            if distance < min_distance:
                min_distance = distance
                closest_threat = player_unit
        
        if closest_threat:
            px, py = closest_threat["position"]
            
            # Calculate direction away from threat
            dx = ex - px
            dy = ey - py
            
            # Normalize direction
            if dx != 0:
                dx = dx // abs(dx)
            if dy != 0:
                dy = dy // abs(dy)
            
            # Try to move away from threat
            move_range = self._calculate_move_range(tau_unit["position"], tau_unit["move_range"])
            
            if move_range:
                # Find the best retreat position (away from threat but maintaining attack range)
                best_retreat = None
                best_score = -1
                
                for move_pos in move_range:
                    mx, my = move_pos
                    
                    # Score this position based on:
                    # 1. Distance from threat (higher is better)
                    # 2. Number of player units still in attack range from this position
                    distance_from_threat = abs(mx - px) + abs(my - py)
                    
                    # Count player units in attack range from this position
                    units_in_range = 0
                    for player_unit in self.player_units:
                        px2, py2 = player_unit["position"]
                        if abs(mx - px2) + abs(my - py2) <= tau_unit["attack_range"]:
                            units_in_range += 1
                    
                    # Calculate score (weight both factors)
                    position_score = distance_from_threat * 2 + units_in_range * 3
                    
                    if position_score > best_score:
                        best_score = position_score
                        best_retreat = move_pos
                
                if best_retreat:
                    # Move to best retreat position
                    mx, my = best_retreat
                    self.grid[ex][ey] = None
                    self.grid[mx][my] = tau_unit
                    tau_unit["position"] = (mx, my)
                    self.battle_message = f"Enemy {tau_unit['name']} makes a tactical withdrawal."
                    self.last_action = f"Tau {tau_unit['name']} retreated to ({mx}, {my})"
                    return True
        
        return False

    def _tau_find_optimal_position(self, tau_unit):
        """Find and move to optimal firing position for Tau"""
        ex, ey = tau_unit["position"]
        attack_range = tau_unit["attack_range"]
        move_range = self._calculate_move_range(tau_unit["position"], tau_unit["move_range"])
        
        if not move_range:
            return False
        
        # Find the position that can attack the most player units
        best_position = None
        max_targets = -1
        
        for move_pos in move_range:
            mx, my = move_pos
            
            # Count how many player units would be in range from this position
            targets_in_range = 0
            for player_unit in self.player_units:
                px, py = player_unit["position"]
                if abs(mx - px) + abs(my - py) <= attack_range:
                    targets_in_range += 1
            
            if targets_in_range > max_targets:
                max_targets = targets_in_range
                best_position = move_pos
        
        # If no position can attack, move toward nearest player unit
        if max_targets == 0:
            nearest_player = None
            min_distance = float('inf')
            
            for player_unit in self.player_units:
                px, py = player_unit["position"]
                distance = abs(ex - px) + abs(ey - py)
                
                if distance < min_distance:
                    min_distance = distance
                    nearest_player = player_unit
            
            if nearest_player:
                px, py = nearest_player["position"]
                
                # Find the move that gets closest to the target but maintains some distance
                best_move = None
                best_score = -1
                
                for move_pos in move_range:
                    mx, my = move_pos
                    distance_to_player = abs(mx - px) + abs(my - py)
                    
                    # Tau prefer to stay at medium range (not too close, not too far)
                    distance_score = 0
                    if distance_to_player <= attack_range:
                        # Prefer positions at about 2/3 of attack range
                        optimal_distance = (attack_range * 2) // 3
                        distance_score = 10 - abs(distance_to_player - optimal_distance)
                    else:
                        # If can't reach attack range, get as close as possible
                        distance_score = 20 - distance_to_player
                    
                    if distance_score > best_score:
                        best_score = distance_score
                        best_move = move_pos
                
                if best_move:
                    best_position = best_move
        
        # Move to the best position if found
        if best_position:
            mx, my = best_position
            self.grid[ex][ey] = None
            self.grid[mx][my] = tau_unit
            tau_unit["position"] = (mx, my)
            self.battle_message = f"Enemy {tau_unit['name']} maneuvers to an optimal firing position."
            self.last_action = f"Tau {tau_unit['name']} moved to optimal position ({mx}, {my})"
            return True
        
        return False

    def _run_default_ai(self, enemy):
        """Default AI for enemy units with no specific faction AI"""
            # Find the nearest player unit
            nearest_player = None
            min_distance = float('inf')
            
            for player_unit in self.player_units:
                ex, ey = enemy["position"]
                px, py = player_unit["position"]
                distance = abs(ex - px) + abs(ey - py)  # Manhattan distance
                
                if distance < min_distance:
                    min_distance = distance
                    nearest_player = player_unit
            
            if nearest_player:
                # If in attack range, attack
                ex, ey = enemy["position"]
                px, py = nearest_player["position"]
                
                if min_distance <= enemy["attack_range"]:
                    # Attack
                    damage = max(5, enemy["attack"] - nearest_player["defense"] // 2)
                    damage += random.randint(-3, 3)  # Add some randomness
                    
                    # Apply damage
                    nearest_player["hp"] = max(0, nearest_player["hp"] - damage)
                    
                    # Check if unit is defeated
                    if nearest_player["hp"] <= 0:
                        self.grid[px][py] = None
                        self.player_units.remove(nearest_player)
                        self.battle_message = f"Your {nearest_player['name']} was defeated!"
                        self.last_action = f"Enemy defeated your {nearest_player['name']}"
                        
                        # Check lose condition
                        if not self.player_units:
                            self.battle_message = "Defeat! All your units are lost."
                            self.last_action = "Defeat - all units lost"
                            # Return to mission select after 3 seconds
                            pygame.time.set_timer(pygame.USEREVENT, 3000)
                            pygame.event.post(pygame.event.Event(pygame.USEREVENT, {'action': 'defeat'}))
                    else:
                        self.battle_message = f"Enemy {enemy['name']} attacked your {nearest_player['name']} for {damage} damage!"
                        self.last_action = f"Enemy attacked for {damage} damage"
                
                # If not in attack range, try to move closer
                else:
                    # Calculate move
                    move_range = self._calculate_move_range(enemy["position"], enemy["move_range"])
                    
                    if move_range:
                        # Find the move that gets closest to the target
                        best_move = None
                        best_distance = float('inf')
                        
                        for move_pos in move_range:
                            mx, my = move_pos
                            distance = abs(mx - px) + abs(my - py)
                            
                            if distance < best_distance:
                                best_distance = distance
                                best_move = move_pos
                        
                        if best_move:
                            # Move
                            mx, my = best_move
                            self.grid[ex][ey] = None
                            self.grid[mx][my] = enemy
                            enemy["position"] = (mx, my)
                            self.battle_message = f"Enemy {enemy['name']} moved closer."
                            self.last_action = f"Enemy {enemy['name']} moved to ({mx}, {my})"

    def draw_button(self, screen, text, pos, size, hover=False, centered=False, bg_color=None):
        """Render a rounded button with a centred text label.

        Args:
            screen: Surface the button is drawn onto.
            text: Label text, rendered with ``self.fonts["medium"]``.
            pos: Top-left corner of the button, or its centre when ``centered`` is true.
            size: ``(width, height)`` of the button in pixels.
            hover: When true, fill with the hover colour instead of the idle one.
            centered: Treat ``pos`` as the button centre rather than its top-left corner.
            bg_color: Optional 3-tuple fill colour. When given, the hover colour is
                that colour with every channel raised by 20 (capped at 255).
                Anything that is not a 3-tuple is ignored and the defaults apply.

        Returns:
            The ``pygame.Rect`` covering the button, for hit-testing.
        """
        if isinstance(bg_color, tuple) and len(bg_color) == 3:
            idle_fill = bg_color
            hover_fill = tuple(min(255, channel + 20) for channel in bg_color)
        else:
            idle_fill = (60, 60, 80)      # dark blue-gray
            hover_fill = (80, 80, 100)    # lighter blue-gray
        label_color = (200, 200, 220)     # light gray
        outline_color = (100, 100, 120)

        rect = pygame.Rect(0, 0, size[0], size[1])
        if centered:
            rect.center = pos
        else:
            rect.topleft = pos

        fill = hover_fill if hover else idle_fill
        pygame.draw.rect(screen, fill, rect, border_radius=5)
        pygame.draw.rect(screen, outline_color, rect, border_radius=5, width=2)

        label = self.fonts["medium"].render(text, True, label_color)
        screen.blit(label, label.get_rect(center=rect.center))
        return rect

    def _get_unit_at_pos(self, pos):
        """Get the unit at a given position"""
        for unit in self.player_units + self.enemy_units:
            if unit["position"] == pos:
                return unit
        return None

    def _handle_click(self, pos):
        """Handle a click on the grid"""
        # Get the grid position that was clicked
        grid_pos = self._get_grid_pos(pos)
        
        # If grid position is out of bounds, return
        if not grid_pos:
            return
        
        # Get the unit at the clicked position
        clicked_unit = self._get_unit_at_pos(grid_pos)
        
        # If it's the player's turn
        if self.turn == "player":
            # If no unit is selected and the player clicks one of their units
            if not self.selected_unit and clicked_unit and clicked_unit in self.player_units:
                # Check if the unit has already acted this turn
                if clicked_unit["id"] in self.acted_units:
                    self.battle_message = f"{clicked_unit['name']} has already acted this turn."
                    return
                
                # Select the unit
                self.selected_unit = clicked_unit
                
                # Calculate move range
                self.move_range = self._calculate_move_range(grid_pos, clicked_unit["move_range"])
                
                # Set battle message
                self.battle_message = f"Selected {clicked_unit['name']}. Move or attack."
                
                return
            
            # If a unit is already selected
            if self.selected_unit:
                # If the clicked position is in the move range
                if grid_pos in self.move_range:
                    # Move the selected unit to the clicked position
                    old_pos = self.selected_unit["position"]
                    self.selected_unit["position"] = grid_pos
                    
                    # Clear move range
                    self.move_range = []
                    
                    # Calculate attack range
                    self.attack_range = self._calculate_attack_range(grid_pos, self.selected_unit["attack_range"])
                    
                    # Set battle message
                    self.battle_message = f"Unit moved to ({grid_pos[0]}, {grid_pos[1]}). Select an enemy to attack, or click the unit to end its turn."
                    
                    # Log the action
                    self.last_action = f"Moved {self.selected_unit['name']} from ({old_pos[0]}, {old_pos[1]}) to ({grid_pos[0]}, {grid_pos[1]})"
                    
                    # Track that this unit has moved
                    # self.moved_units.add(self.selected_unit["id"])
                    
                    return
                
                # If the clicked position is in the attack range and there's an enemy there
                if grid_pos in self.attack_range and clicked_unit and clicked_unit in self.enemy_units:
                    # Attack the enemy
                    damage = self._calculate_damage(self.selected_unit, clicked_unit)
                    clicked_unit["hp"] -= damage
                    
                    # Check if the attack was a critical hit
                    critical_text = " CRITICAL HIT!" if self.selected_unit.get("last_hit_critical", False) else ""
                    
                    # Set battle message
                    self.battle_message = f"{self.selected_unit['name']} attacks {clicked_unit['name']} for {damage} damage!{critical_text}"
                    
                    # Log the action
                    self.last_action = f"{self.selected_unit['name']} attacked {clicked_unit['name']} for {damage} damage{' (Critical)' if critical_text else ''}"
                    
                    # Check if the enemy is defeated
                    if clicked_unit["hp"] <= 0:
                        self.enemy_units.remove(clicked_unit)
                        self.battle_message = f"{clicked_unit['name']} was defeated!"
                        
                        # Update morale - enemy morale decreases, player morale increases
                        morale_impact = clicked_unit.get("morale_value", 10)  # Default impact of 10
                        self.enemy_morale = max(0, self.enemy_morale - morale_impact)  # Decrease enemy morale
                        self.player_morale = min(200, self.player_morale + (morale_impact // 2))  # Increase player morale
                        
                        # Add experience to the player unit
                        self.selected_unit["xp"] += clicked_unit.get("xp_value", 10)
                        
                        # Check for level up
                        if self.selected_unit["xp"] >= 100:
                            self.selected_unit["xp"] -= 100
                            self.selected_unit["level"] += 1
                            # Increase stats
                            self.selected_unit["attack"] += 2
                            self.selected_unit["defense"] += 1
                            self.selected_unit["max_hp"] += 5
                            self.selected_unit["hp"] = self.selected_unit["max_hp"]
                            self.battle_message = f"{self.selected_unit['name']} leveled up to level {self.selected_unit['level']}!"
                        
                        # Check win condition
                        if not self.enemy_units:
                            self.battle_message = "Victory! All enemy units are defeated."
                            # Return to mission select after 3 seconds
                            pygame.time.set_timer(pygame.USEREVENT, 3000)
                            pygame.event.post(pygame.event.Event(pygame.USEREVENT, {'action': 'victory'}))
                    
                    # Mark the unit as having acted
                    self.acted_units.add(self.selected_unit["id"])
                    
                    # Clear selected unit, move range, and attack range
                    self.selected_unit = None
                    self.move_range = []
                    self.attack_range = []
                    
                    return
                
                # If the player clicks the selected unit again
                if clicked_unit and clicked_unit == self.selected_unit:
                    # End the unit's turn
                    self.acted_units.add(self.selected_unit["id"])
                    
                    # Clear selected unit, move range, and attack range
                    self.selected_unit = None
                    self.move_range = []
                    self.attack_range = []
                    
                    # Set battle message
                    self.battle_message = "Unit's turn ended. Select another unit."
                    
                    return
                
                # If the player clicks elsewhere, deselect the unit
                self.selected_unit = None
                self.move_range = []
                self.attack_range = []
                self.battle_message = "Select a unit to begin."
                
                return
            
            # If no unit is selected and the player clicks elsewhere
            self.battle_message = "No unit at that position. Select one of your units."
            
            return

    def _get_grid_pos(self, screen_pos):
        """Map a screen pixel position to ``(column, row)`` grid coordinates.

        The grid is centred horizontally on the current display surface, and
        its top edge sits at y = 100 (this must match the render method). Each
        cell is ``CELL_SIZE`` pixels square.

        Returns:
            ``(int(column), int(row))`` for a point on the grid, otherwise None.
        """
        surface_width, _surface_height = pygame.display.get_surface().get_size()
        cols, rows = self.grid_size[0], self.grid_size[1]

        span_x = cols * CELL_SIZE
        span_y = rows * CELL_SIZE
        left = (surface_width - span_x) // 2
        top = 100  # must stay in step with the render method's vertical offset

        px, py = screen_pos[0], screen_pos[1]
        outside = px < left or px >= left + span_x or py < top or py >= top + span_y
        if outside:
            return None

        col = (px - left) // CELL_SIZE
        row = (py - top) // CELL_SIZE

        # Defensive re-check of the computed indices
        if not (0 <= col < cols and 0 <= row < rows):
            return None
        return (int(col), int(row))

    def _calculate_damage(self, attacker, defender):
        """Calculate damage based on attacker and defender stats"""
        # Base damage is attacker's attack stat
        base_damage = attacker.get("attack", 5)
        
        # Get defender's position and terrain type
        defender_x, defender_y = defender["position"]
        terrain_type = self.terrain_grid[defender_x][defender_y]
        terrain_defense_bonus = self.terrain_types[terrain_type]["defense_bonus"]
        
        # Reduce damage based on defender's defense + terrain bonus
        defense_modifier = (defender.get("defense", 0) + terrain_defense_bonus) // 2
        
        # Calculate damage with randomness
        damage = max(1, base_damage - defense_modifier)
        damage += random.randint(-2, 2)  # Add some randomness
        
        # Check for special damage types
        attacker_damage_type = attacker.get("attack_damage_type", "normal")
        defender_weakness = defender.get("is_weak_to_" + attacker_damage_type, False)
        
        # Apply weakness multiplier if applicable
        if defender_weakness:
            damage = int(damage * 1.5)
        
        # Critical hit system (10% chance for 50% more damage)
        is_critical = random.random() < 0.1
        if is_critical:
            damage = int(damage * 1.5)
            # Store the critical hit info to display in battle message
            attacker["last_hit_critical"] = True
        else:
            attacker["last_hit_critical"] = False
        
        # Store terrain info for battle messages
        attacker["last_hit_terrain"] = terrain_type
        attacker["last_hit_terrain_bonus"] = terrain_defense_bonus
        
        # Ensure at least 1 damage
        return max(1, damage)

    def _generate_terrain(self):
        """Generate random terrain across the battlefield"""
        # Reset terrain grid to all plains
        self.terrain_grid = [["plain" for _ in range(self.grid_size[1])] for _ in range(self.grid_size[0])]
        
        # Generate some forests (25% chance)
        for x in range(self.grid_size[0]):
            for y in range(self.grid_size[1]):
                if random.random() < 0.25:
                    self.terrain_grid[x][y] = "forest"
        
        # Add a few mountains (5% chance)
        for x in range(self.grid_size[0]):
            for y in range(self.grid_size[1]):
                if random.random() < 0.05:
                    self.terrain_grid[x][y] = "mountain"
        
        # Add a small lake or river (water)
        water_cells = min(8, self.grid_size[0] * self.grid_size[1] // 10)  # Around 10% of the map
        water_start_x = random.randint(2, self.grid_size[0] - 3)
        water_start_y = random.randint(1, self.grid_size[1] - 2)
        
        # Create the water area starting from the random point
        water_points = [(water_start_x, water_start_y)]
        for _ in range(water_cells - 1):
            if not water_points:
                break
                
            # Pick a random existing water point
            wx, wy = random.choice(water_points)
            
            # Try to expand in a random direction
            directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]
            random.shuffle(directions)
            
            for dx, dy in directions:
                nx, ny = wx + dx, wy + dy
                
                # Check if valid position and not already water
                if (0 <= nx < self.grid_size[0] and 0 <= ny < self.grid_size[1] and 
                    self.terrain_grid[nx][ny] != "water"):
                    self.terrain_grid[nx][ny] = "water"
                    water_points.append((nx, ny))
                    break

    def _apply_morale_to_damage(self, damage, attacker_side):
        """Applies morale bonus to damage calculation"""
        if attacker_side == "player":
            # Player gets bonus/penalty based on player morale
            morale_modifier = (self.player_morale - 100) / 200  # Range: -0.5 to +0.5
            # Cap at 50% bonus/penalty
            morale_modifier = max(-0.5, min(0.5, morale_modifier))
        else:
            # Enemy gets bonus/penalty based on enemy morale
            morale_modifier = (self.enemy_morale - 100) / 200  # Range: -0.5 to +0.5
            # Cap at 50% bonus/penalty
            morale_modifier = max(-0.5, min(0.5, morale_modifier))
        
        # Apply the modifier (could be positive or negative)
        modified_damage = damage * (1 + morale_modifier)
        
        # Return the modified damage, ensuring at least 1 damage
        return max(1, int(modified_damage))