from flask import Flask, request, render_template, redirect, url_for
import torch
from torchvision import transforms
from PIL import Image
import timm
import json
import requests
import os
 
app = Flask(__name__)
UPLOAD_FOLDER = 'static/uploads'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
 
# Folder na załadowane obrazki
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 
# Wczytaj model Swin Transformer z biblioteki timm
# ta wersja jest wyuczona na zbiorze ImageNet-21k
nazwa_modelu = "swin_large_patch4_window7_224"
klasyfikator = timm.create_model(nazwa_modelu, pretrained=True)
#klasyfikator.eval&#40;&#41;
 
# Pobierz zbiór ImageNet (1000 klas)
imagenet_labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
labels = json.loads(requests.get(imagenet_labels_url).text)
 
# Potok przetwarzania obrazu na potrzeby klasyfikatora
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
 
def recognize_image(img_path):
    """
   Wczytuje obraz z pliku, przetwarza zgodnie z potokiem preprocess a na koniec klasyfikuje
   """
    img = Image.open(img_path).convert("RGB")
    # Dodatkowy wymiar na wejściu
    img_tensor = preprocess(img).unsqueeze(0)
 
    # Predykcja, czyli przypisanie prawdopodobieństw klas do obrazu wejściowego
    with torch.no_grad():
        logits = klasyfikator(img_tensor)
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        top5_prob, top5_catid = torch.topk(probabilities, 5)
 
    # Przeliczenie prawdopodobieństw na etykiety klas
    results = [(labels[catid], prob.item()) for catid, prob in zip(top5_catid[0], top5_prob[0])]
    return results
 
@app.route('/')
def index():
    return render_template("index.html")
 
@app.route('/', methods=['POST'])
def classify_image():
    if 'image' not in request.files:
        return redirect(url_for('index'))
 
    image = request.files['image']
    if image.filename == '':
        return redirect(url_for('index'))
 
    try:
        # Zapisz wczytany obraz do katalogu UPLOAD_FOLDER
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], image.filename)
        image.save(image_path)
 
        # Klasyfikacja
        wyjscie_klasyfikatora = recognize_image(image_path)
        return render_template("result.html", image_path=image_path, predictions=wyjscie_klasyfikatora)
    except Exception as e:
        return str(e), 500
 
if __name__ == "__main__":
    app.run(debug=True)
 