import math
from qiskit import QuantumCircuit
from qiskit.circuit.library import grover_operator, MCMTGate, ZGate
from qiskit.visualization import plot_distribution
from qiskit.primitives import StatevectorSampler
from qiskit_algorithms import Grover, AmplificationProblem
import matplotlib.pyplot as plt

def grover_oracle(marked_states):
    """Build a phase oracle that marks one or more computational-basis states.

    The returned circuit multiplies every marked basis state by -1 and leaves
    every other basis state unchanged.

    Args:
        marked_states: A single bitstring (e.g. ``"011"``) or an iterable of
            bitstrings, all of the same length. Bitstrings follow Qiskit's
            little-endian convention: the rightmost character is qubit 0.

    Returns:
        QuantumCircuit: A circuit on ``len(bitstring)`` qubits implementing
        the oracle.

    Raises:
        ValueError: If no states are given, a bitstring is empty, the
            bitstrings differ in length, or a bitstring contains characters
            other than '0' and '1'.
    """
    # Accept a bare string or any iterable (list, tuple, set, ...) of strings.
    if isinstance(marked_states, str):
        states = [marked_states]
    else:
        states = list(marked_states)
    # Drop duplicates while keeping order: phase-flipping the same state twice
    # cancels out and would silently un-mark it.
    states = list(dict.fromkeys(states))

    if not states:
        raise ValueError("marked_states must contain at least one bitstring")
    num_qubits = len(states[0])
    if num_qubits == 0:
        raise ValueError("marked bitstrings must not be empty")
    for state in states:
        if len(state) != num_qubits:
            raise ValueError(
                f"all marked bitstrings must have length {num_qubits}, got {state!r}"
            )
        if set(state) - {"0", "1"}:
            raise ValueError(f"marked bitstring {state!r} must contain only '0'/'1'")

    qc = QuantumCircuit(num_qubits)
    # Multi-controlled Z phases only |1...1>. Build it once and reuse it for
    # every target. MCMTGate needs at least one control, so a 1-qubit oracle
    # uses a plain Z instead.
    phase_flip = (
        MCMTGate(ZGate(), num_qubits - 1, 1).definition if num_qubits > 1 else None
    )

    for target in states:
        # Qubits whose target bit is '0'. reversed() maps the rightmost
        # character to qubit 0 (little-endian).
        zero_qubits = [i for i, bit in enumerate(reversed(target)) if bit == "0"]
        # Map the target state onto |1...1> so the phase flip marks it.
        if zero_qubits:
            qc.x(zero_qubits)
        if phase_flip is None:
            qc.z(0)
        else:
            qc.compose(phase_flip, inplace=True)
        # Undo the basis flips so only the phase of the target changes.
        if zero_qubits:
            qc.x(zero_qubits)
    return qc

# Search for one or more marked basis states. The register width comes from
# the oracle, so changing the bitstring length needs no other edits below.
marked_states = ["011"]
oracle = grover_oracle(marked_states)

# Optimal iterations = floor(π/4 * √(N/M)), with N = 2**num_qubits basis
# states and M distinct solutions.
num_qubits = oracle.num_qubits
num_solutions = len(set(marked_states))
iterations = math.floor(math.pi / 4 * math.sqrt(2**num_qubits / num_solutions))
print(f"Optimal Grover iterations: {iterations}")

# Build and run
problem = AmplificationProblem(oracle, is_good_state=marked_states)
grover = Grover(iterations=iterations, sampler=StatevectorSampler())
result = grover.amplify(problem)

print("Answer found:", result.top_measurement)

# NOTE(review): with a sampler, circuit_results[0] is presumably a
# bitstring -> probability mapping rather than raw shot counts. Confirm
# against the installed qiskit_algorithms version. plot_distribution
# normalizes either way.
counts = result.circuit_results[0]
# Include every basis state (zero-filled) so unmeasured states still appear.
state_width = f"0{num_qubits}b"
all_states = {
    format(i, state_width): counts.get(format(i, state_width), 0)
    for i in range(2**num_qubits)
}
plot_distribution(all_states, title="Grover Search: All Basis States")
plt.show()
