# NumPy stellt u.a. Matrixen, Vektoren und
# Methoden zum Lösen von LGS (np.linalg.solve)
# und zur Bestimmung von Inversen (np.linalg.inv)
# zur Verfügung.
import numpy as np
import math

# combinations erlaubt über alle k-elementigen Teilmengen
# einer gegebenen Menge zu iterieren
from itertools import combinations
from typing import List


# Diese Methode muss nicht angepasst werden.
def select_columns(A: np.ndarray, columns: List[int]):
    """
    Bestimme die Submatrix, die aus A entsteht, wenn man nur die
    gegebenen Spaltenindizes nimmt.
    :param A: Eine Matrix.
    :param columns: Eine Liste von Spaltenindizes.
    :return:
    """
    return A[:, columns]



# Diese Methode muss nicht angepasst werden.
def add_nonbasic_variables(A: np.ndarray, x: np.ndarray,
                           columns: List[int]):
    """
    Füge zu einer Lösung für die Basisvariablen 
    0-Einträge für die Nichtbasisvariablen hinzu.
    :param A: Die Matrix.
    :param x: Die Lösung für die Basisvariablen.
    :param columns: Die Spaltenindizes der Basisvariablen.
    :return:
    """
    result = np.zeros(A.shape[1], dtype=float)
    result[list(columns)] = x
    return result


# Diese Methode muss nicht angepasst werden.
def solve_linear_equations(A, b):
    """
    Löse das lineare Gleichungssystem Ax = b.
    :param A: Eine Matrix A.
    :param b: Ein Vektor b.
    :return: Den Lösungsvektor x, falls er existiert, sonst None.
    """
    try:
        return np.linalg.solve(A,b)
    except np.linalg.LinAlgError:
        return None

# Diese Methode soll angepasst werden.
# Muss hier eine Inverse bestimmt werden, oder geht es auch schneller?
def compute_basic_solution(A: np.ndarray, columns: List[int],
                           b: np.ndarray):
    """
    Diese Methode soll eine Basislösung zu den Spalten
    der gegebenen Indizes zur gegebenen Matrix A
    bestimmen, falls eine existiert. Andernfalls
    soll die Methode None zurückgeben.
    :param A: Eine zwei-dimensionale Matrix.
    :param columns: Die Spaltenindizes der Basis,
                    zu der die Basislösung bestimmt werden soll.
    :param b: Der b-Vektor.
    :return: None, falls keine Basislösung existiert.
             Andernfalls soll eine Basislösung zurückgegeben werden,
             die auch die Nicht-Basisvariablen enthält.
    """
    pass


# Diese Methode soll angepasst werden.
def basic_solution_is_feasible(x: np.ndarray):
    """
    Diese Methode soll für eine gegebene Basislösung überprüfen,
    ob sie zulässig ist.
    :param x:
    :return:
    """
    pass


# Diese Methode soll angepasst werden.
def feasible_basic_solutions(A, b):
    """
    Gebe eine Sequenz aller zulässigen Basislösungen zurück,
    indem alle potentiellen Basen ausprobiert werden.
    """
    num_rows = A.shape[0]
    all_column_indices = range(A.shape[1])
    for basic_indices in ???:
        x = compute_basic_solution(A, basic_indices, b)
        if x is not None and basic_solution_is_feasible(x):
            yield x


# Diese Methode muss nicht angepasst werden.
def naive_solve_lp(A: np.ndarray, b: np.ndarray, c: np.ndarray):
    """
    Löst das gegebene LP min c^Tx s.t. Ax = b, x >= 0,
    indem alle potentiellen Basen ausprobiert werden.
    :param A: Die Matrix A.
    :param b: Der Vektor b.
    :param c: Der Vektor c.
    :return: (x*, v), wobei x* die beste Lösung und v ihr Lösungswert ist.
    """
    # @ ist der Operator für Matrix-Multiplikation, c.T transponiert c.
    value_of = lambda x: c.T @ x
    best_x = min(feasible_basic_solutions(A, b),
                 default=None, key=value_of)
    return None if best_x is None else\
            (best_x, value_of(best_x))


# Falls die Datei nicht importiert wird, sondern das Hauptprogramm ist,
# starte den Lösungsvorgang für ein einfaches Beispiel.
if __name__ == '__main__':
    c = np.array([1, 0, 0, 0], dtype=np.float)
    A = np.array([[1, 0, 2, 1],
                  [0, 4, 0, 4],
                  [0, 0, 1, 0]], dtype=np.float)
    b = np.array([5, 7, 1], dtype=np.float)
    x, v = naive_solve_lp(A, b, c)
    print("Solution:\t", x)
    print("Value:\t\t", v)

