"""
Battle Scene Module
Handles the turn-based combat system
"""
import pygame
import random
import time
import os
from game.ui.scene import Scene
from game.core.game_state import GameState
from game.core.unit_sprite import UnitSprite
from game.core.sprite import Sprite

# Edge length, in pixels, of one square battle-grid cell.
# BattleScene.render draws cells at this size in windowed mode and at 1.25x in fullscreen.
CELL_SIZE = 64

class BattleScene(Scene):
    """
    Battle scene handling turn-based combat
    """
    def __init__(self, game_manager):
        """Create the battle scene and populate it with the sample battle.

        Args:
            game_manager: Owning game manager; forwarded to ``Scene`` and later
                used to reach the unit loader and the other scenes.
        """
        super().__init__(game_manager)

        # UI widgets: clickable buttons and the id of the one under the cursor.
        self.buttons = []
        self.hover_button = None

        # Battlefield: grid[x][y] holds the unit occupying that cell, or None.
        self.grid_size = (10, 6)  # Width, Height
        columns, rows = self.grid_size
        self.grid = [[None] * rows for _ in range(columns)]

        # Turn / selection state.
        self.turn = "player"  # "player" or "enemy"
        self.current_phase = "select"  # "select", "move", "attack", "enemy"
        self.selected_unit = None
        self.selected_cell = None
        self.move_range = []
        self.attack_range = []
        self.battle_message = ""
        self.acted_units = set()  # Units that have already acted this turn

        # Rosters and the sprite cache keyed by unit id (filled in below).
        self.player_units = []
        self.enemy_units = []
        self.unit_sprites = {}

        self.bg_color = (40, 20, 20)  # Dark red/brown battle backdrop
        self.is_fullscreen = False

        # Debugging aids.
        self.debug_mode = True
        self.last_action = "None"

        # Fonts, created in small -> medium -> large order.
        font_sizes = (("small", 24), ("medium", 32), ("large", 48))
        self.fonts = {name: pygame.font.Font(None, size) for name, size in font_sizes}

        # Placeholder battle; enter() rebuilds it from mission data when available.
        self._initialize_sample_battle()
        
    def _initialize_sample_battle(self):
        """Set up a fixed sample battle using the game's UnitLoader.

        Resets all per-battle state (unit lists, sprite cache, grid, acted
        units), creates the hard-coded player and enemy rosters below through
        ``self.game_manager.unit_loader`` and places every unit on the grid at
        its configured position. A multi-tile unit (``size`` other than
        ``(1, 1)``) has a reference to the same unit dict stored in every cell
        it covers.

        Units the loader fails to create, or that cannot be placed (out of
        bounds or overlapping an earlier unit), are reported and left out of
        ``self.player_units`` / ``self.enemy_units`` so the rosters always
        match what is actually on the grid.
        """
        self.player_units = []
        self.enemy_units = []
        self.unit_sprites = {}
        self.grid = [[None for _ in range(self.grid_size[1])] for _ in range(self.grid_size[0])]
        self.acted_units = set()
        self.unit_loader = self.game_manager.unit_loader  # Get loader instance
        # Verbose grid dumps follow debug_mode (which __init__ sets to True).
        verbose = getattr(self, "debug_mode", True)
        grid_w, grid_h = self.grid_size

        # Player starting setup (unit type, faction and grid position)
        player_setup = [
            {"type": "tactical", "faction": "space_marines", "pos": (1, 2)},
            {"type": "devastator", "faction": "space_marines", "pos": (2, 3)},
            {"type": "scout", "faction": "space_marines", "pos": (1, 4)},
            {"type": "battle_suit", "faction": "tau", "pos": (3, 1)},
            {"type": "terminator", "faction": "space_marines", "pos": (2, 1)},
        ]

        # Enemy starting setup
        enemy_setup = [
            {"type": "boy", "faction": "orks", "pos": (8, 1)},
            {"type": "nob", "faction": "orks", "pos": (7, 3)},
            {"type": "devilfish", "faction": "tau", "pos": (5, 1)},  # Example with Tau
            {"type": "boy", "faction": "orks", "pos": (6, 4)},
            {"type": "shoota", "faction": "orks", "pos": (8, 5)},
        ]

        def create_units(setups, make_id, side):
            """Instantiate each setup entry via the loader, reporting failures."""
            units = []
            for i, setup in enumerate(setups):
                unit_instance = self.unit_loader.create_unit_instance(
                    unit_type=setup["type"],
                    faction=setup["faction"],
                    instance_id=make_id(i, setup),
                    position=setup["pos"],
                )
                if unit_instance:
                    units.append(unit_instance)
                else:
                    print(f"Failed to create {side} unit instance: {setup['faction']} - {setup['type']}")
            return units

        self.player_units = create_units(player_setup, lambda i, setup: f"player_{i+1}", "player")
        self.enemy_units = create_units(
            enemy_setup, lambda i, setup: f"enemy_{setup['faction']}_{i+1}", "enemy"
        )

        def place(unit):
            """Put *unit* on the grid at its ``position``; return False if it does not fit."""
            start_x, start_y = unit["position"]
            width, height = unit.get("size", (1, 1))
            cells = [(start_x + dx, start_y + dy) for dx in range(width) for dy in range(height)]
            # Validate every covered cell before writing any, so a failed
            # placement never leaves a partial footprint behind.
            for cell_x, cell_y in cells:
                if not (0 <= cell_x < grid_w and 0 <= cell_y < grid_h):
                    print(f"Warning: Unit '{unit['name']}' placement out of bounds at ({cell_x}, {cell_y}).")
                    return False
                if self.grid[cell_x][cell_y] is not None:
                    print(f"Warning: Unit '{unit['name']}' overlaps with existing unit at ({cell_x}, {cell_y}).")
                    return False
            for cell_x, cell_y in cells:
                self.grid[cell_x][cell_y] = unit  # Same reference in every covered cell
            if verbose and (width > 1 or height > 1):
                print(f"Placed multi-tile unit '{unit['name']}' size {(width, height)} at position {(start_x, start_y)}")
            return True

        unplaced = set()
        for unit in self.player_units + self.enemy_units:
            if not place(unit):
                print(f"Error: Could not place unit '{unit['name']}' at {unit['position']} size {unit.get('size', (1, 1))}.")
                unplaced.add(id(unit))

        if unplaced:
            # A unit that is not on the grid must not stay in the rosters,
            # otherwise the rosters and the grid disagree. This matches
            # _initialize_battle_from_mission, which only lists placed units.
            self.player_units = [u for u in self.player_units if id(u) not in unplaced]
            self.enemy_units = [u for u in self.enemy_units if id(u) not in unplaced]

        # Start with player turn
        self.turn = "player"
        self.current_phase = "select"
        self.battle_message = "Select a unit to begin your turn."

        if verbose:
            print("\n----- GRID STATE AFTER INITIALIZATION -----")
            for label, units in (("Player", self.player_units), ("Enemy", self.enemy_units)):
                for i, unit in enumerate(units):
                    print(f"{label} unit {i+1}: {unit.get('name', 'Unknown')} ({unit.get('unit_type', 'Unknown')}) at position {unit.get('position', 'Unknown')}, size {unit.get('size', (1,1))}")

            print("\n----- CHECKING GRID CELLS -----")
            for x in range(grid_w):
                for y in range(grid_h):
                    unit = self.grid[x][y]
                    if unit:
                        print(f"Grid cell ({x}, {y}) contains: {unit.get('name', 'Unknown')} ({unit.get('unit_type', 'Unknown')})")
        
    def enter(self):
        """Prepare the scene each time it becomes active.

        Clears selection/highlight state left over from a previous visit,
        then builds the battlefield. It uses the mission currently selected on
        the MISSION_SELECT scene when that mission can be found in its
        ``mission_list``; otherwise it falls back to the built-in sample battle.
        """
        # Wipe per-visit interaction state.
        self.selected_unit = self.selected_cell = None
        self.move_range, self.attack_range = [], []
        self.current_phase = "select"
        self.battle_message = "Select a unit to begin your turn."
        self.last_action = "Entered battle scene"
        self.acted_units = set()  # Nobody has acted yet

        # Resolve the chosen mission (if any) by id.
        mission = None
        mission_select = self.game_manager.scenes[GameState.MISSION_SELECT]
        if hasattr(mission_select, 'selected_mission'):
            for candidate in mission_select.mission_list:
                if candidate["id"] == mission_select.selected_mission:
                    mission = candidate
                    break

        if mission:
            self._initialize_battle_from_mission(mission)
        else:
            # No selection, or the selected id is not in the mission list.
            self._initialize_sample_battle()

    def _initialize_battle_from_mission(self, mission_data):
        """Initialize a battle from *mission_data* with randomized deployment.

        Resets all per-battle state, then:

        * picks the player roster from the mission's ``difficulty``
          (missing or unknown values use the ``"Medium"`` roster);
        * builds the enemy roster from ``enemy_faction`` / ``enemy_unit_types``
          and adds extra enemies according to the difficulty;
        * deploys players at random in the first three columns and enemies
          in the last three (largest enemies first). Multi-tile units store a
          reference to the same unit dict in every cell they cover.

        Units that cannot be created or deployed are reported and left out of
        ``self.player_units`` / ``self.enemy_units``.

        Args:
            mission_data: Mission dict. Keys read here: ``name``,
                ``difficulty``, ``enemy_faction``, ``enemy_unit_types``.
        """
        self.player_units = []
        self.enemy_units = []
        self.unit_sprites = {}
        self.grid = [[None for _ in range(self.grid_size[1])] for _ in range(self.grid_size[0])]
        self.acted_units = set()
        self.unit_loader = self.game_manager.unit_loader

        difficulty = mission_data.get("difficulty", "Medium")
        # Report the difficulty that is actually used below.
        print(f"--- Initializing Battle: {mission_data.get('name', 'Unknown')} ({difficulty}) ---")

        # --- Player roster by difficulty ---
        player_rosters = {
            "Easy": ["tactical", "scout"],
            "Medium": ["tactical", "devastator", "scout"],
            "Hard": ["tactical", "tactical", "devastator", "scout", "terminator"],
            "Super Hard": ["tactical", "tactical", "devastator", "devastator",
                           "scout", "terminator", "captain"],
        }
        player_unit_types = player_rosters.get(difficulty, player_rosters["Medium"])
        print(f"  Player units: {player_unit_types} (Difficulty: {difficulty})")

        temp_player_units = []
        for i, unit_type in enumerate(player_unit_types):
            instance_id = f"player_{i+1}"
            # Placeholder position (0, 0); the real one is chosen at deployment.
            unit_instance = self.unit_loader.create_unit_instance(
                unit_type=unit_type, faction="space_marines", instance_id=instance_id, position=(0, 0)
            )
            if unit_instance:
                temp_player_units.append(unit_instance)
            else:
                print(f"Failed to create player unit instance: space_marines - {unit_type}")

        # --- Enemy roster from mission data, scaled by difficulty ---
        enemy_faction = mission_data.get("enemy_faction", None)
        base_enemy_unit_types = mission_data.get("enemy_unit_types", [])
        print(f"  Base enemy types: {base_enemy_unit_types}, Faction: {enemy_faction}")

        # Each entry is (unit_type, fallback_faction). The fallback faction is
        # tried only if the loader cannot create the unit under enemy_faction.
        enemy_specs = [(unit_type, None) for unit_type in base_enemy_unit_types]
        if base_enemy_unit_types:
            first_type = base_enemy_unit_types[0]
            if difficulty == "Medium":
                enemy_specs.append((first_type, None))
                print(f"  Difficulty Medium: Added 1 extra enemy ({first_type})")
            elif difficulty == "Hard":
                second_type = base_enemy_unit_types[1] if len(base_enemy_unit_types) > 1 else first_type
                enemy_specs += [(first_type, None), (second_type, None)]
                print("  Difficulty Hard: Added 2 extra enemies")
            elif difficulty == "Super Hard":
                enemy_specs += [(first_type, None), (first_type, None)]
                # The sample battle loads the Devilfish under "tau"; fall back
                # to that faction when the mission's enemy faction lacks it.
                enemy_specs.append(("devilfish", "tau"))
                print("  Difficulty Super Hard: Added 3 extra enemies (including a Devil Fish)")

        temp_enemy_units = []
        print(f"  Attempting to load scaled enemy types: {[unit_type for unit_type, _ in enemy_specs]}")
        if enemy_faction and enemy_specs:
            for i, (unit_type, fallback_faction) in enumerate(enemy_specs):
                factions = [enemy_faction]
                if fallback_faction and fallback_faction != enemy_faction:
                    factions.append(fallback_faction)
                for faction in factions:
                    instance_id = f"enemy_{faction}_{i+1}"
                    # Placeholder position (0, 0); the real one is chosen at deployment.
                    unit_instance = self.unit_loader.create_unit_instance(
                        unit_type=unit_type, faction=faction, instance_id=instance_id, position=(0, 0)
                    )
                    if unit_instance:
                        print(f"    Successfully created instance for: {faction} - {unit_type} (ID: {instance_id})")
                        temp_enemy_units.append(unit_instance)
                        break
                else:
                    print(f"    *** FAILED to create instance for: {enemy_faction} - {unit_type} ***")
        else:
            print("Warning: Mission data missing enemy faction or base unit types!")

        print(f"  Final enemy instances created (before placement): {[u['name'] for u in temp_enemy_units]}")

        # --- Random deployment ---
        grid_w, grid_h = self.grid_size
        player_zone_cols = range(0, 3)  # Columns 0, 1, 2
        enemy_zone_cols = range(grid_w - 3, grid_w)  # Last 3 columns

        def fits(origin_x, origin_y, width, height):
            """Return True if a width x height footprint at the origin is in bounds and empty."""
            return all(
                0 <= origin_x + dx < grid_w
                and 0 <= origin_y + dy < grid_h
                and self.grid[origin_x + dx][origin_y + dy] is None
                for dx in range(width)
                for dy in range(height)
            )

        def deploy(unit, zone_cols):
            """Place *unit* at a uniformly random valid origin in *zone_cols*.

            Every valid origin is enumerated, rather than sampled with a retry
            cap, so placement only fails when the zone has no room at all.
            """
            width, height = unit.get("size", (1, 1))
            candidates = [
                (x, y) for x in zone_cols for y in range(grid_h) if fits(x, y, width, height)
            ]
            if not candidates:
                return False
            origin_x, origin_y = random.choice(candidates)
            unit["position"] = (origin_x, origin_y)
            for dx in range(width):
                for dy in range(height):
                    self.grid[origin_x + dx][origin_y + dy] = unit
            return True

        print("Placing player units...")
        for unit in temp_player_units:
            if deploy(unit, player_zone_cols):
                self.player_units.append(unit)
                print(f"  Placed {unit['name']} at {unit['position']}")
            else:
                print(f"Warning: Could not find placement for player unit {unit['name']}: no free space in the deployment zone.")

        print("Placing enemy units...")
        # Place larger units first, while the enemy zone still has room for them.
        temp_enemy_units.sort(
            key=lambda u: u.get("size", (1, 1))[0] * u.get("size", (1, 1))[1], reverse=True
        )
        for unit in temp_enemy_units:
            if deploy(unit, enemy_zone_cols):
                self.enemy_units.append(unit)
                print(f"  Placed {unit['name']} at {unit['position']} with size {unit.get('size', (1, 1))}")
            else:
                print(f"Warning: Could not find placement for enemy unit {unit['name']}: no free space in the deployment zone.")

        self.turn = "player"
        self.current_phase = "select"
        self.battle_message = "Select a unit to begin your turn."

        print("--- Battle Initialization Complete ---")
        
    def handle_event(self, event):
        """Handle a single pygame event for the battle scene.

        * Mouse motion: update the hovered button and the highlighted grid cell.
        * Left click: activate the clicked button; otherwise forward the click
          to the grid cell under the cursor (if any).
        * ESC: return to mission select. SPACE: end the player's turn (only in
          the "select" phase of the player's turn). D: toggle debug mode.
          F11: toggle between fullscreen and a 1280x720 window.

        Mouse coordinates come from ``event.pos``, the position recorded with
        the event, instead of polling the cursor. Queued events are therefore
        handled where they actually happened.
        """
        if event.type == pygame.MOUSEMOTION:
            mouse_pos = event.pos
            # First button whose rect contains the cursor, if any.
            self.hover_button = next(
                (button_info["id"] for button_info in self.buttons
                 if self.is_point_inside_rect(mouse_pos, button_info["rect"])),
                None,
            )

            # Highlight grid cell under mouse
            grid_pos = self._screen_to_grid(mouse_pos)
            if grid_pos:
                self.selected_cell = grid_pos

        elif event.type == pygame.MOUSEBUTTONDOWN:
            if event.button != 1:  # Only the left mouse button is handled
                return
            mouse_pos = event.pos

            # Buttons take priority over the grid.
            for button_info in self.buttons:
                if self.is_point_inside_rect(mouse_pos, button_info["rect"]):
                    self._handle_button_click(button_info["id"])
                    return

            grid_pos = self._screen_to_grid(mouse_pos)
            if grid_pos:
                self._handle_grid_click(grid_pos)

        elif event.type == pygame.KEYDOWN:
            if event.key == pygame.K_ESCAPE:
                # Return to mission select
                self.game_manager.change_state(GameState.MISSION_SELECT)
            elif event.key == pygame.K_SPACE:
                # End turn, but only while the player is choosing a unit
                if self.current_phase == "select" and self.turn == "player":
                    self._end_player_turn()
            elif event.key == pygame.K_d:
                self.debug_mode = not self.debug_mode
            elif event.key == pygame.K_F11:
                self.is_fullscreen = not self.is_fullscreen
                if self.is_fullscreen:
                    pygame.display.set_mode((0, 0), pygame.FULLSCREEN)
                    print("Switched to fullscreen mode")
                else:
                    pygame.display.set_mode((1280, 720))
                    print("Switched to windowed mode")
                self.last_action = f"Toggled fullscreen: {self.is_fullscreen}"
                
    def update(self):
        """Advance the scene by one frame.

        The only per-frame work is the enemy's turn. While both the turn owner
        and the current phase are "enemy", the enemy AI routine is run.
        """
        enemy_is_acting = self.turn == "enemy" and self.current_phase == "enemy"
        if not enemy_is_acting:
            return
        self._run_enemy_ai()
            
    def render(self, screen):
        """Render the battle scene"""
        # Fill the background
        screen.fill(self.bg_color)
        
        screen_width, screen_height = screen.get_size()
        
        # Calculate grid offset to center the grid
        grid_width = self.grid_size[0] * CELL_SIZE
        grid_height = self.grid_size[1] * CELL_SIZE
        grid_offset_x = (screen_width - grid_width) // 2
        grid_offset_y = 100  # Leave space at top for UI
        
        # Adjust cell size if in fullscreen
        adjusted_cell_size = CELL_SIZE
        if self.is_fullscreen:
            # Scale up grid for fullscreen
            adjusted_cell_size = int(CELL_SIZE * 1.25)
            grid_width = self.grid_size[0] * adjusted_cell_size
            grid_height = self.grid_size[1] * adjusted_cell_size
            grid_offset_x = (screen_width - grid_width) // 2
        
        # Draw grid background
        grid_rect = pygame.Rect(grid_offset_x, grid_offset_y, grid_width, grid_height)
        pygame.draw.rect(screen, (60, 40, 40), grid_rect)
        
        # Track multi-tile units to draw separately
        multi_tile_units = []
        drawn_unit_ids = set()
        
        # Track Battle Suit cells for debugging
        battle_suit_cells_found = []
        battle_suit_origin_cell = None
        
        # Track units that need HP bars drawn
        units_to_draw_hp_for = []
        
        # First, identify and collect multi-tile units
        for y in range(self.grid_size[1]):
            for x in range(self.grid_size[0]):
                unit = self.grid[x][y]
                
                if unit:
                    unit_id = unit["id"]
                    unit_size = unit.get("size", (1, 1))
                    if unit_size[0] > 1 or unit_size[1] > 1:
                        # Check if this is the origin cell of the multi-tile unit
                        unit_pos = unit["position"]
                        if unit_pos == (x, y):
                            # This is the origin cell, add it to the list
                            multi_tile_units.append((unit, (x, y)))
                    
                    # Track Battle Suit cells specifically for debugging
                    if unit.get("unit_type") == "battle_suit" or unit.get("name") == "Tau Crisis Suit":
                        battle_suit_cells_found.append((x, y))
                        unit_pos = unit["position"]
                        if unit_pos == (x, y):
                            battle_suit_origin_cell = (x, y)
        
        # Draw move range
        for pos in self.move_range:
            x, y = pos
            cell_rect = pygame.Rect(
                grid_offset_x + x * adjusted_cell_size,
                grid_offset_y + y * adjusted_cell_size,
                adjusted_cell_size,
                adjusted_cell_size
            )
            pygame.draw.rect(screen, (100, 150, 250, 128), cell_rect)  # Blue with transparency
        
        # Draw attack range
        for pos in self.attack_range:
            x, y = pos
            cell_rect = pygame.Rect(
                grid_offset_x + x * adjusted_cell_size,
                grid_offset_y + y * adjusted_cell_size,
                adjusted_cell_size,
                adjusted_cell_size
            )
            pygame.draw.rect(screen, (250, 100, 100, 128), cell_rect)  # Red with transparency
        
        # Draw selected cell
        if self.selected_cell:
            x, y = self.selected_cell
            cell_rect = pygame.Rect(
                grid_offset_x + x * adjusted_cell_size,
                grid_offset_y + y * adjusted_cell_size,
                adjusted_cell_size,
                adjusted_cell_size
            )
            pygame.draw.rect(screen, (220, 220, 120), cell_rect, 3)  # Yellow outline
        
        # Draw regular (1x1) units
        for y in range(self.grid_size[1]):
            for x in range(self.grid_size[0]):
                unit = self.grid[x][y]
                
                if unit and unit["id"] not in drawn_unit_ids:
                    unit_id = unit["id"]
                    unit_pos = unit["position"]
                    unit_size = unit.get("size", (1, 1))
                    
                    # Skip multi-tile units - we'll draw them separately
                    if unit_size[0] > 1 or unit_size[1] > 1:
                        continue
                    
                    # Get or create unit sprite - adjust position and size for fullscreen
                    screen_pos = (
                        grid_offset_x + x * adjusted_cell_size,
                        grid_offset_y + y * adjusted_cell_size
                    )
                    
                    sprite_target_size = (adjusted_cell_size, adjusted_cell_size)
                    
                    # Get or create unit sprite
                    if unit_id not in self.unit_sprites or self.is_fullscreen:
                        sprite = UnitSprite(unit, screen_pos, target_size=sprite_target_size)
                        self.unit_sprites[unit_id] = sprite
                    else:
                        sprite = self.unit_sprites[unit_id]
                        sprite.set_position(screen_pos[0], screen_pos[1])
                    
                    # Set sprite properties
                    sprite.set_selected(self.selected_unit and unit["id"] == self.selected_unit["id"])
                    sprite.set_hovered(False)
                    
                    # Draw the sprite
                    sprite.draw(screen)
                    drawn_unit_ids.add(unit_id)
                    
                    # Store info needed to draw HP bar later
                    units_to_draw_hp_for.append({
                        'sprite': sprite,
                        'hp': unit["hp"],
                        'max_hp': unit["max_hp"]
                    })
        
        # Draw multi-tile units
        for unit, origin_pos in multi_tile_units:
            unit_id = unit["id"]
            if unit_id in drawn_unit_ids:
                continue
                
            # Special handling for multi-tile units
            x, y = origin_pos
            unit_size = unit.get("size", (1, 1))
            sprite_pixel_size = (adjusted_cell_size * unit_size[0], adjusted_cell_size * unit_size[1])
            
            # Calculate screen position (top-left corner of the multi-tile area)
            screen_pos = (
                grid_offset_x + x * adjusted_cell_size,
                grid_offset_y + y * adjusted_cell_size
            )
            
            # Special debug for multi-tile units
            multi_tile_name = unit.get("name", "Unknown")
            print(f"Drawing large unit: {unit.get('unit_type', 'unknown')}, Origin: ({x}, {y}), Size: {unit_size}, Screen pos: {screen_pos}")
            
            # Create sprite with appropriate size
            if unit_id not in self.unit_sprites or self.is_fullscreen:
                # Create new sprite using the exact screen position and size
                sprite = UnitSprite(unit, screen_pos, target_size=sprite_pixel_size)
                
                # Explicitly set the rectangle position to ensure correct placement
                if hasattr(sprite, 'rect'):
                    sprite.rect.topleft = screen_pos
                
                self.unit_sprites[unit_id] = sprite
            else:
                sprite = self.unit_sprites[unit_id]
                sprite.set_position(screen_pos[0], screen_pos[1])
                
                # Ensure rect position is correct
                if hasattr(sprite, 'rect'):
                    sprite.rect.topleft = screen_pos
            
            # Set sprite properties
            sprite.set_selected(self.selected_unit and unit["id"] == self.selected_unit["id"])
            sprite.set_hovered(False)
            
            # Draw the sprite
            sprite.draw(screen)
            drawn_unit_ids.add(unit_id)
            
            # Store info needed to draw HP bar later
            units_to_draw_hp_for.append({
                'sprite': sprite,
                'hp': unit["hp"],
                'max_hp': unit["max_hp"]
            })

        # After rendering all cells, output debug info about Battle Suit rendering
        print("\n==== BATTLE SUIT RENDERING VERIFICATION ====")
        print(f"Cells containing Battle Suit: {battle_suit_cells_found}")
        print(f"Battle Suit origin cell: {battle_suit_origin_cell}")
        print(f"Drawn unit IDs: {drawn_unit_ids}")
        if not battle_suit_cells_found:
            print("WARNING: No Battle Suit found in any grid cell during rendering!")
        elif not battle_suit_origin_cell:
            print("WARNING: Battle Suit found but no origin cell identified!")
        print("===========================================\n")
        
        # --- Pass 2: Draw HP bars on top of sprites --- 
        for hp_info in units_to_draw_hp_for:
            sprite = hp_info['sprite']
            hp_percent = hp_info['hp'] / hp_info['max_hp']
            
            # Adjust HP bar based on potentially larger sprite width
            bar_width = sprite.rect.width - 10 
            bar_height = 6
            
            # HP bar background - Position below the sprite rect
            bar_rect = pygame.Rect(
                sprite.rect.left + (sprite.rect.width - bar_width) // 2, # Centered below sprite
                sprite.rect.bottom + 3, 
                bar_width,
                bar_height
            )
            pygame.draw.rect(screen, (40, 40, 40), bar_rect)
            
            # HP bar fill
            if hp_percent > 0:
                # Determine color (assuming original logic is fine)
                if hp_percent > 0.6: 
                    hp_color = (0, 200, 0)
                elif hp_percent > 0.3: 
                    hp_color = (200, 200, 0)
                else: 
                    hp_color = (200, 0, 0)
                    
                fill_width = int(bar_width * hp_percent)
                fill_width = max(1, fill_width)
        
                fill_rect = pygame.Rect(bar_rect.left, bar_rect.top, fill_width, bar_height)
                pygame.draw.rect(screen, hp_color, fill_rect)
                pygame.draw.rect(screen, (20, 20, 20), fill_rect, 1)
        
        # Draw battle UI
        self._render_battle_ui(screen, grid_offset_x, grid_offset_y, grid_width, grid_height)
        
        # Draw debug info if enabled
        if self.debug_mode:
            debug_info = [
                f"Turn: {self.turn}",
                f"Phase: {self.current_phase}",
                f"Selected Unit: {self.selected_unit['name'] if self.selected_unit else 'None'}",
                f"Last Action: {self.last_action}",
                f"Screen Mode: {'Fullscreen' if self.is_fullscreen else 'Windowed'}"
            ]
            
            for i, info in enumerate(debug_info):
                self.draw_text(screen, info, (10, 10 + i * 20), "small", (255, 255, 255))
        
        # FINAL DIRECT DRAWING OF BATTLE SUIT (force display as last step)
        self._force_draw_battle_suit(screen, grid_offset_x, grid_offset_y, grid_width, grid_height, adjusted_cell_size)
    
    def _force_draw_battle_suit(self, screen, grid_x, grid_y, grid_width, grid_height, cell_size):
        """Force-draw the battle suit unit on top of everything else.

        Locates the battle suit's origin (top-left) cell on the grid. It then
        blits the unit image over the unit's whole footprint, or draws a labelled
        placeholder rectangle when the image cannot be loaded. Finally it draws
        the unit's HP bar just below the footprint. Does nothing if no battle
        suit is on the grid.

        Args:
            screen: Target pygame surface.
            grid_x, grid_y: Screen coordinates of the grid's top-left corner.
            grid_width, grid_height: Grid size in pixels (unused here; kept so
                the caller's signature stays valid).
            cell_size: Size of one grid cell in pixels.
        """
        battle_suit, battle_suit_pos = self._find_battle_suit_origin()
        if not battle_suit or battle_suit_pos is None:
            return

        x, y = battle_suit_pos
        top_left_screen_x = grid_x + x * cell_size
        top_left_screen_y = grid_y + y * cell_size

        # Footprint in cells; battle suits default to 2x2 when no size is set.
        unit_width, unit_height = battle_suit.get("size", (2, 2))
        pixel_width = unit_width * cell_size
        pixel_height = unit_height * cell_size

        image = self._get_battle_suit_image((pixel_width, pixel_height))
        if image is not None:
            screen.blit(image, (top_left_screen_x, top_left_screen_y))
        else:
            self._draw_battle_suit_placeholder(
                screen, top_left_screen_x, top_left_screen_y, pixel_width, pixel_height
            )

        self._draw_battle_suit_hp_bar(
            screen, battle_suit, top_left_screen_x, top_left_screen_y + pixel_height + 5, pixel_width
        )

    def _find_battle_suit_origin(self):
        """Return ``(unit, (x, y))`` for the battle suit's origin cell, or ``(None, None)``."""
        for x in range(self.grid_size[0]):
            for y in range(self.grid_size[1]):
                unit = self.grid[x][y]
                if not unit:
                    continue
                if unit.get("unit_type") != "battle_suit" and unit.get("name") != "Tau Crisis Suit":
                    continue
                # A multi-tile unit fills several cells; only its origin cell
                # matches the unit's stored position.
                start_x, start_y = unit["position"]
                if x == start_x and y == start_y:
                    return unit, (x, y)
        return None, None

    def _get_battle_suit_image(self, size):
        """Return the battle suit image scaled to ``size``, or ``None`` if it cannot be loaded.

        The raw image is loaded once and cached on the scene. The scaled copy is
        cached for the last requested size. This avoids re-reading and re-scaling
        the file on every frame. The path next to this module is tried first,
        then a path relative to the current working directory.
        """
        cached = getattr(self, "_battle_suit_image_cache", None)
        if cached is not None and cached[0] == size:
            return cached[1]

        raw = getattr(self, "_battle_suit_raw_image", None)
        if raw is None:
            candidate_paths = (
                os.path.abspath(os.path.join(os.path.dirname(__file__), "../assets/images/units/tau/battle_suit.jpg")),
                "game/assets/images/units/tau/battle_suit.jpg",
            )
            for path in candidate_paths:
                if not os.path.exists(path):
                    continue
                try:
                    raw = pygame.image.load(path)
                except (pygame.error, OSError):
                    # Unreadable/corrupt file: try the next candidate path.
                    raw = None
                    continue
                break
            if raw is None:
                return None
            self._battle_suit_raw_image = raw

        try:
            scaled = pygame.transform.scale(raw, size)
        except (pygame.error, ValueError):
            return None
        self._battle_suit_image_cache = (size, scaled)
        return scaled

    def _draw_battle_suit_placeholder(self, screen, left, top, width, height):
        """Draw a labelled coloured rectangle in place of the missing battle suit image."""
        pygame.draw.rect(screen, (200, 100, 100), pygame.Rect(left, top, width, height))

        # self.fonts["small"] is Font(None, 24), the same font previously
        # recreated on every frame here.
        font = self.fonts["small"]
        text = font.render("Battle Suit", True, (0, 0, 0))
        text_rect = text.get_rect(center=(left + width // 2, top + height // 2))

        # A white copy offset by 2px acts as a shadow so the label stays readable.
        shadow = font.render("Battle Suit", True, (255, 255, 255))
        shadow_rect = shadow.get_rect(center=(text_rect.centerx + 2, text_rect.centery + 2))
        screen.blit(shadow, shadow_rect)
        screen.blit(text, text_rect)

    def _draw_battle_suit_hp_bar(self, screen, unit, left, top, width):
        """Draw ``unit``'s HP bar (dark background plus colour-coded fill) at ``(left, top)``."""
        bar_height = 8
        bar_rect = pygame.Rect(left, top, width, bar_height)
        pygame.draw.rect(screen, (40, 40, 40), bar_rect)

        max_hp = unit.get("max_hp", 0)
        if max_hp <= 0:
            # Malformed unit data: draw only the empty bar instead of dividing by zero.
            return
        # Clamp so HP above max_hp never draws past the end of the bar.
        hp_percent = min(1.0, max(0.0, unit["hp"] / max_hp))
        if hp_percent <= 0:
            return

        if hp_percent > 0.6:
            hp_color = (0, 255, 0)
        elif hp_percent > 0.3:
            hp_color = (255, 255, 0)
        else:
            hp_color = (255, 0, 0)
        fill_width = max(1, int(width * hp_percent))
        pygame.draw.rect(screen, hp_color, pygame.Rect(bar_rect.left, bar_rect.top, fill_width, bar_height))
        
    def _render_battle_ui(self, screen, grid_x, grid_y, grid_width, grid_height):
        """Render the battle HUD and rebuild the list of clickable buttons.

        Draws:
        - the turn and phase indicators;
        - the battle message;
        - the count of player units that can still act;
        - the info panel and action buttons for the selected unit;
        - the END TURN button (player turn only);
        - the fullscreen toggle.

        ``self.buttons`` is cleared and repopulated with ``{"id", "rect"}``
        entries for every button that should respond to clicks this frame.

        Args:
            screen: Target pygame surface.
            grid_x, grid_y: Screen coordinates of the grid's top-left corner.
            grid_width, grid_height: Grid size in pixels.
        """
        screen_width, screen_height = screen.get_size()

        # Buttons are re-registered every frame, so start from an empty list.
        self.buttons = []

        # Turn indicator (above the grid, left side)
        turn_text = f"Turn: {'Player' if self.turn == 'player' else 'Enemy'}"
        self.draw_text(screen, turn_text, (grid_x + 20, 30), "medium", (200, 200, 200))

        # Phase indicator (above the grid, right side)
        phase_text = f"Phase: {self.current_phase.capitalize()}"
        self.draw_text(screen, phase_text, (grid_x + grid_width - 150, 30), "medium", (200, 200, 200))

        # Battle message at the horizontal centre of the screen
        self.draw_text(screen, self.battle_message, (screen_width // 2, 60), "medium", (255, 255, 255), True)

        # How many player units can still act this turn
        if self.turn == "player":
            available_units = sum(1 for unit in self.player_units if unit["id"] not in self.acted_units)
            units_text = f"Units available: {available_units}/{len(self.player_units)}"
            self.draw_text(screen, units_text, (grid_x + grid_width - 150, 60), "small", (200, 200, 200))

        if self.selected_unit:
            self._render_selected_unit_panel(screen, grid_x, grid_y, grid_width, grid_height)
            # Action buttons only for the player's own units during the player's turn.
            if self.turn == "player" and self.selected_unit["id"].startswith("player_"):
                self._render_action_buttons(screen, screen_width, screen_height)

        # End turn button (only shown during player turn)
        if self.turn == "player":
            end_turn_rect = self.draw_button(
                screen, "END TURN",
                (screen_width - 80, screen_height - 60),
                (120, 40),
                hover=self.hover_button == "end_turn",
                centered=True
            )
            self.buttons.append({"id": "end_turn", "rect": end_turn_rect})

        # Fullscreen toggle button - always shown in the top-right corner
        fullscreen_text = "FULLSCREEN" if not self.is_fullscreen else "WINDOWED"
        fullscreen_button_rect = self.draw_button(
            screen, fullscreen_text,
            (screen_width - 80, 20),
            (140, 30),
            hover=self.hover_button == "toggle_fullscreen",
            centered=True,
            bg_color=(60, 50, 60)
        )
        self.buttons.append({"id": "toggle_fullscreen", "rect": fullscreen_button_rect})

    def _render_selected_unit_panel(self, screen, grid_x, grid_y, grid_width, grid_height):
        """Draw the name/HP, abilities and stats panel for ``self.selected_unit`` below the grid."""
        unit = self.selected_unit
        panel_rect = pygame.Rect(grid_x, grid_y + grid_height + 20, grid_width, 100)
        pygame.draw.rect(screen, (60, 40, 40), panel_rect, border_radius=5)

        text_x = panel_rect.left + 20
        text_color = (200, 200, 200)

        self.draw_text(
            screen,
            f"{unit['name']} - HP: {unit['hp']}/{unit['max_hp']}",
            (text_x, panel_rect.top + 15),
            "medium",
            text_color
        )

        # .get() (as in the ability-phase click handling) so units without an
        # "abilities" entry do not raise KeyError.
        abilities_text = "Abilities: " + ", ".join(unit.get("abilities", []))
        self.draw_text(screen, abilities_text, (text_x, panel_rect.top + 45), "small", text_color)

        stats_text = (
            f"ATK: {unit['attack']} | DEF: {unit['defense']} | "
            f"Move: {unit['move_range']} | Range: {unit['attack_range']}"
        )
        self.draw_text(screen, stats_text, (text_x, panel_rect.top + 75), "small", text_color)

    def _render_action_buttons(self, screen, screen_width, screen_height):
        """Draw the phase-dependent action buttons for the selected player unit.

        In the "select" phase, MOVE/ATTACK/ABILITY buttons are drawn. They are
        greyed out and not registered as clickable once the unit has acted. In a
        targeting phase ("move", "attack", "ability"), a CANCEL button is drawn.
        """
        button_size = (120, 40)
        button_spacing = 20
        center_x = screen_width // 2
        button_y = screen_height - 60  # bottom strip of the screen

        unit_has_acted = self.selected_unit["id"] in self.acted_units
        if self.debug_mode:
            print(f"DEBUG: Drawing action buttons for {self.selected_unit['name']} "
                  f"(acted={unit_has_acted}, phase={self.current_phase})")

        if self.current_phase == "select":
            # The buttons are drawn with centered=True, so each x is the button's
            # centre. Placing them one full pitch apart leaves equal gaps of
            # `button_spacing` on both sides of the middle button. The previous
            # offsets (-200 / 0 / +130) gave uneven 80px and 10px gaps.
            pitch = button_size[0] + button_spacing
            actions = (
                ("move", "MOVE", center_x - pitch, (80, 80, 160)),
                ("attack", "ATTACK", center_x, (160, 80, 80)),
                ("ability", "ABILITY", center_x + pitch, (80, 160, 80)),
            )
            for button_id, label, button_x, active_color in actions:
                button_rect = self.draw_button(
                    screen, label,
                    (button_x, button_y),
                    button_size,
                    hover=self.hover_button == button_id and not unit_has_acted,
                    centered=True,
                    bg_color=active_color if not unit_has_acted else (40, 40, 60)
                )
                # Greyed-out buttons of a unit that already acted are not clickable.
                if not unit_has_acted:
                    self.buttons.append({"id": button_id, "rect": button_rect})
                    if self.debug_mode:
                        print(f"DEBUG: {label} button rect: {button_rect}")

        elif self.current_phase in ("move", "attack", "ability"):
            # Every targeting phase (including ability targeting) can be cancelled.
            cancel_button_rect = self.draw_button(
                screen, "CANCEL",
                (center_x, button_y),
                button_size,
                hover=self.hover_button == "cancel",
                centered=True,
                bg_color=(120, 80, 120)
            )
            self.buttons.append({"id": "cancel", "rect": cancel_button_rect})
    
    def _screen_to_grid(self, screen_pos):
        """Map a screen-space point to the grid cell underneath it.

        The grid layout is recomputed here and must stay in sync with the render
        method:
        - cells are CELL_SIZE pixels wide, scaled by 1.25 in fullscreen;
        - the grid is horizontally centred on the display surface;
        - its top edge sits 100 px below the top of the screen.

        Args:
            screen_pos: ``(x, y)`` pixel coordinates, e.g. a mouse position.

        Returns:
            ``(column, row)`` as ints when the point lies on the grid, otherwise ``None``.
        """
        display_width, _display_height = pygame.display.get_surface().get_size()

        cell = int(CELL_SIZE * 1.25) if self.is_fullscreen else CELL_SIZE
        cols, rows = self.grid_size[0], self.grid_size[1]
        span_x, span_y = cols * cell, rows * cell

        # Top-left corner of the grid on screen (must match the render method).
        origin_x = (display_width - span_x) // 2
        origin_y = 100

        local_x = screen_pos[0] - origin_x
        local_y = screen_pos[1] - origin_y
        if not (0 <= local_x < span_x and 0 <= local_y < span_y):
            return None

        col, row = local_x // cell, local_y // cell
        if not (0 <= col < cols and 0 <= row < rows):
            return None
        return int(col), int(row)
    
    def _handle_grid_click(self, grid_pos):
        """Handle a click on the grid."""
        grid_x, grid_y = grid_pos
        
        # Ensure grid_pos is within bounds
        if not (0 <= grid_x < self.grid_size[0] and 0 <= grid_y < self.grid_size[1]):
            return
        
        # Debug click info
        print(f"Grid click: {grid_pos}")
        print(f"Current phase: {self.current_phase}")
        
        # Check what's at the clicked position
        target = self.grid[grid_x][grid_y]
        print(f"Target at {grid_pos}: {target['name'] if target else 'None'}")
        
        # SELECTING PHASE - directly select a unit
        if self.current_phase == "select":
            if target and target in self.player_units and target["id"] not in self.acted_units:
                self.selected_unit = target
                self.battle_message = f"Selected {target['name']}"
                print(f"Selected unit: {target['name']}")
                # Don't calculate move range yet, wait for move button
            else:
                if target and target in self.player_units and target["id"] in self.acted_units:
                    self.battle_message = f"{target['name']} has already acted this turn."
                elif target and target in self.enemy_units:
                    self.battle_message = "That's an enemy unit. Select one of your units."
                else:
                    self.battle_message = "No unit at that position."
                print(f"Cannot select: {self.battle_message}")
                
        elif self.current_phase == "move":
            # Player is selecting a unit or move destination
            if self.selected_unit:
                # Player has a unit selected, try to move it
                if grid_pos in self.move_range:
                    # Valid move destination
                    old_x, old_y = self.selected_unit["position"]
                    self.grid[old_x][old_y] = None
                    
                    # Support for multi-tile units - clear all old positions
                    unit_size = self.selected_unit.get("size", (1, 1))
                    is_multi_tile = unit_size[0] > 1 or unit_size[1] > 1
                    if is_multi_tile:
                        for dx in range(unit_size[0]):
                            for dy in range(unit_size[1]):
                                clear_x, clear_y = old_x + dx, old_y + dy
                                if 0 <= clear_x < self.grid_size[0] and 0 <= clear_y < self.grid_size[1]:
                                    if self.grid[clear_x][clear_y] == self.selected_unit:
                                        self.grid[clear_x][clear_y] = None
                    
                    # Set new position and update grid
                    self.selected_unit["position"] = grid_pos
                    self.grid[grid_x][grid_y] = self.selected_unit
                    
                    # For multi-tile units, fill all their cells
                    if is_multi_tile:
                        for dx in range(unit_size[0]):
                            for dy in range(unit_size[1]):
                                fill_x, fill_y = grid_x + dx, grid_y + dy
                                if 0 <= fill_x < self.grid_size[0] and 0 <= fill_y < self.grid_size[1]:
                                    # Only fill if the cell is empty
                                    if self.grid[fill_x][fill_y] is None:
                                        self.grid[fill_x][fill_y] = self.selected_unit
                    
                    self.battle_message = f"Unit moved to ({grid_x}, {grid_y})"
                    self.current_phase = "attack"  # Switch to attack phase after moving
                    self.selected_unit["has_moved"] = True
                    self.acted_units.add(self.selected_unit["id"])
                    # Clear move range to remove blue indicators
                    self.move_range = []
                    self.attack_range = self._calculate_attack_range(grid_pos, self.selected_unit["attack_range"])
                
                else:
                    # Player is selecting a unit
                    target = self.grid[grid_x][grid_y]
                    if target in self.player_units and target["id"] not in self.acted_units:
                        self.selected_unit = target
                        self.move_range = self._calculate_move_range(target["position"], target["move_range"])
                        self.battle_message = f"Selected {target['name']}"
        
        elif self.current_phase == "attack":
            # Player is selecting an attack target
            # Initialize the defeated_units_this_attack list
            defeated_units_this_attack = []
            
            if grid_pos in self.attack_range:
                target = self.grid[grid_x][grid_y]
                
                # Find the actual unit if it's a multi-tile unit
                if target is None:
                    # Check nearby cells to see if they're part of a multi-tile unit
                    for dx in range(-1, 2):
                        for dy in range(-1, 2):
                            check_x, check_y = grid_x + dx, grid_y + dy
                            if 0 <= check_x < self.grid_size[0] and 0 <= check_y < self.grid_size[1]:
                                check_target = self.grid[check_x][check_y]
                                if check_target in self.enemy_units:
                                    # Check if this is a multi-tile unit
                                    unit_size = check_target.get("size", (1, 1))
                                    if unit_size[0] > 1 or unit_size[1] > 1:
                                        # Verify this is actually part of the same unit
                                        unit_x, unit_y = check_target["position"]
                                        for udx in range(unit_size[0]):
                                            for udy in range(unit_size[1]):
                                                if unit_x + udx == grid_x and unit_y + udy == grid_y:
                                                    target = check_target
                                                    break
                                            if target:
                                                break
                                    if target:
                                        break
                            if target:
                                break
                
                if target in self.enemy_units:
                    damage = max(5, self.selected_unit["attack"] - target["defense"] // 2)
                    damage += random.randint(-3, 3)  # Add some randomness
                    
                    # Check if target is a multi-tile unit
                    unit_size = target.get("size", (1, 1))
                    is_multi_tile = unit_size[0] > 1 or unit_size[1] > 1
                    
                    # Apply damage
                    target["hp"] = max(0, target["hp"] - damage)
                    
                    # Remember defeated units for later clearing
                    defeated_units_this_attack = []
                    
                    # Check if unit is defeated
                    if target["hp"] <= 0:
                        defeated_units_this_attack.append(target)
                        
                        # Remove unit from list
                        if target in self.enemy_units:
                            self.enemy_units.remove(target)
                        
                        # Remove sprite
                        if target["id"] in self.unit_sprites:
                            del self.unit_sprites[target["id"]]
                        
                        # Clear all cells occupied by the unit
                        start_x, start_y = target["position"]
                        
                        if is_multi_tile:
                            # For multi-tile units, clear all cells
                            for dx in range(unit_size[0]):
                                for dy in range(unit_size[1]):
                                    clear_x, clear_y = start_x + dx, start_y + dy
                                    # Make sure we're in bounds
                                    if 0 <= clear_x < self.grid_size[0] and 0 <= clear_y < self.grid_size[1]:
                                        # Only clear if this cell contains our defeated unit
                                        if self.grid[clear_x][clear_y] == target:
                                            self.grid[clear_x][clear_y] = None
                                            print(f"Cleared cell at ({clear_x}, {clear_y}) of defeated multi-tile unit")
                            
                            self.battle_message = f"You defeated the {target['name']} (all blocks)!"
                        else:
                            # For regular units, just clear the one cell
                            self.grid[start_x][start_y] = None
                            self.battle_message = f"You defeated the {target['name']}!"
                        
                        # Check win condition
                        if not self.enemy_units:
                            self.battle_message = "Victory! All enemy units defeated."
                            # Return to mission select after 3 seconds
                            pygame.time.set_timer(pygame.USEREVENT, 3000)
                            pygame.event.post(pygame.event.Event(pygame.USEREVENT, {'action': 'victory'}))
                    else:
                        if is_multi_tile:
                            self.battle_message = f"You attacked the {target['name']} (all blocks) for {damage} damage!"
                        else:
                            self.battle_message = f"You attacked the {target['name']} for {damage} damage!"
                
                # End this unit's turn
                self.selected_unit = None
                self.move_range = []
                self.attack_range = []
                self.current_phase = "select"  # Change to "select" to allow selecting other units
                
                # Process any defeated units
                for defeated_unit in defeated_units_this_attack:
                    unit_size = defeated_unit.get("size", (1, 1))
                    is_multi_tile = unit_size[0] > 1 or unit_size[1] > 1
                    
                    if is_multi_tile:
                        # Double check we've cleared all cells for multi-tile units
                        start_x, start_y = defeated_unit["position"]
                        for dx in range(unit_size[0]):
                            for dy in range(unit_size[1]):
                                clear_x, clear_y = start_x + dx, start_y + dy
                                # Make sure we're in bounds
                                if 0 <= clear_x < self.grid_size[0] and 0 <= clear_y < self.grid_size[1]:
                                    # Only clear if this cell contains our defeated unit
                                    if self.grid[clear_x][clear_y] == defeated_unit:
                                        self.grid[clear_x][clear_y] = None
                                        print(f"Final check: Cleared cell at ({clear_x}, {clear_y}) of defeated multi-tile unit")
                
                # Check if all player units have acted
                all_acted = all(unit["id"] in self.acted_units for unit in self.player_units)
                if all_acted:
                    # End player turn, start enemy turn
                    self._start_enemy_turn()
            
            else:
                # Invalid attack target, switch back to select phase
                self.current_phase = "select"
                self.selected_unit = None
                self.attack_range = []
                self.battle_message = "Attack canceled"
        
        elif self.current_phase == "ability":
            # Player is selecting an ability target
            if grid_pos in self.attack_range:  # Using attack_range for ability targeting
                target = self.grid[grid_x][grid_y]
                
                # Find the actual unit if it's a multi-tile unit
                if target is None:
                    # Check nearby cells for multi-tile units
                    for dx in range(-1, 2):
                        for dy in range(-1, 2):
                            check_x, check_y = grid_x + dx, grid_y + dy
                            if 0 <= check_x < self.grid_size[0] and 0 <= check_y < self.grid_size[1]:
                                check_target = self.grid[check_x][check_y]
                                if check_target in self.enemy_units:
                                    unit_size = check_target.get("size", (1, 1))
                                    if unit_size[0] > 1 or unit_size[1] > 1:
                                        unit_x, unit_y = check_target["position"]
                                        for udx in range(unit_size[0]):
                                            for udy in range(unit_size[1]):
                                                if unit_x + udx == grid_x and unit_y + udy == grid_y:
                                                    target = check_target
                                                    break
                                            if target:
                                                break
                                    if target:
                                        break
                            if target:
                                break
                
                # Get the unit's abilities
                abilities = self.selected_unit.get("abilities", [])
                ability_name = abilities[0] if abilities else "Unknown Ability"
                
                if target in self.enemy_units:
                    # Apply special ability effect (more damage than regular attack)
                    damage = max(8, int(self.selected_unit["attack"] * 1.5) - target["defense"] // 3)
                    damage += random.randint(-2, 5)  # More randomness for abilities
                    
                    # Apply damage
                    target["hp"] = max(0, target["hp"] - damage)
                    
                    # Handle unit defeat if needed
                    if target["hp"] <= 0:
                        # Remove unit from list
                        self.enemy_units.remove(target)
                        
                        # Remove sprite
                        if target["id"] in self.unit_sprites:
                            del self.unit_sprites[target["id"]]
                        
                        # Clear from grid
                        start_x, start_y = target["position"]
                        unit_size = target.get("size", (1, 1))
                        is_multi_tile = unit_size[0] > 1 or unit_size[1] > 1
                        
                        if is_multi_tile:
                            # For multi-tile units, clear all cells
                            for dx in range(unit_size[0]):
                                for dy in range(unit_size[1]):
                                    clear_x, clear_y = start_x + dx, start_y + dy
                                    if 0 <= clear_x < self.grid_size[0] and 0 <= clear_y < self.grid_size[1]:
                                        if self.grid[clear_x][clear_y] == target:
                                            self.grid[clear_x][clear_y] = None
                            
                            self.battle_message = f"You used {ability_name} and defeated the {target['name']}!"
                        else:
                            # For regular units, just clear the one cell
                            self.grid[start_x][start_y] = None
                            self.battle_message = f"You used {ability_name} and defeated the {target['name']}!"
                        
                        # Check win condition
                        if not self.enemy_units:
                            self.battle_message = "Victory! All enemy units defeated."
                            pygame.time.set_timer(pygame.USEREVENT, 3000)
                            pygame.event.post(pygame.event.Event(pygame.USEREVENT, {'action': 'victory'}))
                    else:
                        self.battle_message = f"You used {ability_name} on {target['name']} for {damage} damage!"
                    
                    # End this unit's turn
                    self.acted_units.add(self.selected_unit["id"])
                    self.selected_unit = None
                    self.move_range = []
                    self.attack_range = []
                    self.current_phase = "select"
                    
                    # Check if all player units have acted
                    all_acted = all(unit["id"] in self.acted_units for unit in self.player_units)
                    if all_acted:
                        # End player turn, start enemy turn
                        self._start_enemy_turn()
                else:
                    # Invalid target
                    self.battle_message = f"Cannot use {ability_name} on this target."
            else:
                # Invalid ability target, switch back to select phase
                self.current_phase = "select"
                self.selected_unit = None
                self.attack_range = []
                self.battle_message = "Ability canceled"
    
    def _handle_button_click(self, button_id):
        """Handle clicks on UI buttons"""
        if button_id == "move":
            if self.selected_unit:
                # Check if unit has already moved or acted
                if self.selected_unit["id"] in self.acted_units:
                    self.battle_message = f"{self.selected_unit['name']} has already acted this turn."
                    return
                
                # Calculate move range
                self.current_phase = "move"
                self.move_range = self._calculate_move_range(self.selected_unit["position"], self.selected_unit["move_range"])
                self.battle_message = "Select destination."
                self.last_action = "Selected MOVE"
            else:
                self.battle_message = "No unit selected."
                self.last_action = "Attempted to MOVE with no unit selected"
                
        elif button_id == "attack":
            if self.selected_unit:
                # Check if unit has already acted
                if self.selected_unit["id"] in self.acted_units:
                    self.battle_message = f"{self.selected_unit['name']} has already acted this turn."
                    return
                
                # Calculate attack range
                self.current_phase = "attack"
                self.attack_range = self._calculate_attack_range(self.selected_unit["position"], self.selected_unit["attack_range"])
                self.battle_message = "Select target."
                self.last_action = "Selected ATTACK"
            else:
                self.battle_message = "No unit selected."
                self.last_action = "Attempted to ATTACK with no unit selected"
                
        elif button_id == "ability":
            if self.selected_unit:
                # Check if unit has already acted
                if self.selected_unit["id"] in self.acted_units:
                    self.battle_message = f"{self.selected_unit['name']} has already acted this turn."
                    return
                
                # Check if the unit has any abilities
                if "abilities" in self.selected_unit and self.selected_unit["abilities"]:
                    # Setup for ability selection
                    self.current_phase = "ability_select"
                    
                    # List available abilities (not used this turn)
                    if not hasattr(self.selected_unit, "used_abilities"):
                        self.selected_unit["used_abilities"] = set()
                    
                    # Filter out already used abilities
                    available_abilities = [ability for ability in self.selected_unit["abilities"] 
                        if ability not in self.selected_unit.get("used_abilities", set())]
                    
                    if not available_abilities:
                        self.battle_message = f"{self.selected_unit['name']} has used all abilities this turn."
                        return
                    
                    self.available_abilities = available_abilities
                    self.battle_message = "Select an ability to use."
                    self.last_action = "Selected ABILITY"
                else:
                    self.battle_message = f"{self.selected_unit['name']} has no special abilities."
                    self.last_action = "Unit has no abilities"
            else:
                self.battle_message = "No unit selected."
                self.last_action = "Attempted to use ABILITY with no unit selected"
        
        # Handle specific ability selections
        elif button_id.startswith("ability_"):
            ability_name = button_id[8:]  # Remove "ability_" prefix
            self.current_ability = ability_name
            
            # Special handling for command aura which doesn't need targeting
            if ability_name == "command_aura":
                # Buff all nearby allies
                self._activate_command_aura()
                # Mark ability as used
                if not hasattr(self.selected_unit, "used_abilities"):
                    self.selected_unit["used_abilities"] = set()
                self.selected_unit["used_abilities"].add(ability_name)
                
                # Mark unit as having acted
                self.acted_units.add(self.selected_unit["id"])
                
                # Return to select phase
                self.current_phase = "select"
                self.selected_unit = None
                return
                
            # Set up targeting for other abilities
            self.current_phase = "ability_target"
            
            # Calculate range for ability
            ability_range = self.selected_unit.get("ability_ranges", {}).get(ability_name, 3)  # Default range of 3
            self.attack_range = self._calculate_attack_range(self.selected_unit["position"], ability_range)
            
            self.battle_message = f"Select target for {ability_name.replace('_', ' ').title()}."
            self.last_action = f"Selected ability {ability_name}"
                
        elif button_id == "cancel":
            # Reset current phase
            self.current_phase = "select"
            
            # Clear ranges
            self.move_range = []
            self.attack_range = []
            
            self.battle_message = "Action canceled."
            self.last_action = "Canceled action"
            
        elif button_id == "end_turn":
            self._end_player_turn()
            
        elif button_id == "toggle_fullscreen":
            # Toggle fullscreen mode
            self.is_fullscreen = not self.is_fullscreen
            if self.is_fullscreen:
                screen = pygame.display.set_mode((0, 0), pygame.FULLSCREEN)
                print("Switched to fullscreen mode")
            else:
                screen = pygame.display.set_mode((1280, 720))
                print("Switched to windowed mode")
            self.last_action = f"Toggled fullscreen: {self.is_fullscreen}"

    def _start_enemy_turn(self):
        """Hand control to the enemy side and run its AI.

        Player buff timers are advanced here rather than only on the End Turn
        button, so buffs expire on schedule however the player turn ends,
        including the automatic hand-over after the last player unit acts.
        NOTE(review): if anything calls this outside a player->enemy
        hand-over (e.g. enemy-first battles), buffs would tick there too.
        """
        print("Starting enemy turn...")
        self._tick_player_buffs()
        self.turn = "enemy"
        self.current_phase = "enemy"
        self.battle_message = "Enemy turn..."

        # pygame.time.delay blocks for 300 ms (nothing else runs meanwhile);
        # the enemy AI then runs synchronously.
        pygame.time.delay(300)
        self._run_enemy_ai()

    def _end_enemy_turn(self):
        """End the enemy turn and start player turn"""
        self.turn = "player"
        self.current_phase = "select"
        self.selected_unit = None
        self.battle_message = "Your turn. Select a unit to begin."
        self.last_action = "Enemy turn ended"
        # Reset acted units for next turn
        self.acted_units = set()

    def _end_player_turn(self):
        """End the player's turn (End Turn button) and start the enemy turn.

        Buff expiry happens in ``_start_enemy_turn``, so this path and the
        automatic end-of-turn path treat buffs identically.
        """
        self._start_enemy_turn()

    def _tick_player_buffs(self):
        """Advance every player unit's buff timers by one turn.

        Each buff's ``turns_remaining`` is decremented. Buffs that reach zero
        have their ``attack_bonus``/``defense_bonus`` subtracted from the
        unit's stats and are removed. A unit left with an empty buff list has
        its ``"buffs"`` key deleted.
        """
        for unit in self.player_units:
            buffs = unit.get("buffs")
            if buffs is None:
                continue

            for buff in buffs:
                buff["turns_remaining"] -= 1
                if buff["turns_remaining"] <= 0:
                    # Expired: revert the stat bonus this buff applied.
                    unit["attack"] -= buff["attack_bonus"]
                    unit["defense"] -= buff["defense_bonus"]

            # Drop expired buffs in place (other references see the change).
            buffs[:] = [buff for buff in buffs if buff["turns_remaining"] > 0]
            if not buffs:
                del unit["buffs"]
        
    def _activate_command_aura(self):
        """Activate the Command Aura ability which buffs nearby allies"""
        # Get the position of the unit using the ability
        captain_pos = self.selected_unit["position"]
        
        # Define the aura range (3 tiles)
        aura_range = 3
        
        # Find allied units within range
        buffed_units = []
        for unit in self.player_units:
            # Skip the captain itself
            if unit["id"] == self.selected_unit["id"]:
                continue
                
            # Calculate distance to each unit
            unit_pos = unit["position"]
            distance = self._calculate_distance(captain_pos, unit_pos)
            
            if distance <= aura_range:
                # Initialize buffs list if it doesn't exist
                if "buffs" not in unit:
                    unit["buffs"] = []
                
                # Create the command aura buff
                command_buff = {
                    "name": "Command Aura",
                    "attack_bonus": 5,
                    "defense_bonus": 5,
                    "turns_remaining": 2  # Lasts for 2 turns
                }
                
                # Apply buff effects
                unit["attack"] += command_buff["attack_bonus"]
                unit["defense"] += command_buff["defense_bonus"]
                
                # Add buff to unit's buffs list
                unit["buffs"].append(command_buff)
                
                buffed_units.append(unit["name"])
        
        # Update battle message
        if buffed_units:
            unit_list = ", ".join(buffed_units)
            self.battle_message = f"Command Aura activated! Buffed units: {unit_list} (+5 ATK, +5 DEF for 2 turns)"
        else:
            self.battle_message = "Command Aura activated but no allies in range."

    def _calculate_distance(self, pos1, pos2):
        """Calculate Manhattan distance between two grid positions"""
        x1, y1 = pos1
        x2, y2 = pos2
        return abs(x1 - x2) + abs(y1 - y2)
        
    def _calculate_move_range(self, start_pos, move_range):
        """Calculate available move positions"""
        result = []
        to_check = [start_pos]
        checked = set()
        
        # Simple breadth-first search for move range
        while to_check and move_range > 0:
            new_to_check = []
            
            for pos in to_check:
                x, y = pos
                
                # Check all four directions
                for dx, dy in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
                    new_x, new_y = x + dx, y + dy
                    
                    # Skip if out of bounds
                    if not (0 <= new_x < self.grid_size[0] and 0 <= new_y < self.grid_size[1]):
                        continue
                    
                    # Skip if already checked
                    if (new_x, new_y) in checked:
                        continue
                    
                    # Skip if occupied
                    if self.grid[new_x][new_y] is not None:
                        continue
                    
                    # Add to result and check in next iteration
                    result.append((new_x, new_y))
                    new_to_check.append((new_x, new_y))
                    checked.add((new_x, new_y))
            
            to_check = new_to_check
            move_range -= 1
        
        return result
    
    def _calculate_attack_range(self, start_pos, attack_range):
        """Calculate available attack positions"""
        result = []
        x, y = start_pos
        
        # Check all cells within attack range (using Manhattan distance)
        for dx in range(-attack_range, attack_range + 1):
            for dy in range(-attack_range, attack_range + 1):
                if abs(dx) + abs(dy) <= attack_range:  # Manhattan distance
                    new_x, new_y = x + dx, y + dy
                    
                    # Skip if out of bounds
                    if not (0 <= new_x < self.grid_size[0] and 0 <= new_y < self.grid_size[1]):
                        continue
                    
                    # Skip the unit's own position
                    if new_x == x and new_y == y:
                        continue
                    
                    # Add to result
                    result.append((new_x, new_y))
        
        return result 