from flask import Flask, request, render_template, redirect, url_for import torch from torchvision import transforms from PIL import Image import timm import json import requests import os app = Flask(__name__) UPLOAD_FOLDER = 'static/uploads' app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER # Folder na załadowane obrazki os.makedirs(UPLOAD_FOLDER, exist_ok=True) # Wczytaj model Swin Transformer z biblioteki timm # ta wersja jest wyuczona na zbiorze ImageNet-21k nazwa_modelu = "swin_large_patch4_window7_224" klasyfikator = timm.create_model(nazwa_modelu, pretrained=True) #klasyfikator.eval() # Pobierz zbiór ImageNet (1000 klas) imagenet_labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json" labels = json.loads(requests.get(imagenet_labels_url).text) # Potok przetwarzania obrazu na potrzeby klasyfikatora preprocess = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ]) def recognize_image(img_path): """ Wczytuje obraz z pliku, przetwarza zgodnie z potokiem preprocess a na koniec klasyfikuje """ img = Image.open(img_path).convert("RGB") # Dodatkowy wymiar na wejściu img_tensor = preprocess(img).unsqueeze(0) # Predykcja, czyli przypisanie prawdopodobieństw klas do obrazu wejściowego with torch.no_grad(): logits = klasyfikator(img_tensor) probabilities = torch.nn.functional.softmax(logits, dim=-1) top5_prob, top5_catid = torch.topk(probabilities, 5) # Przeliczenie prawdopodobieństw na etykiety klas results = [(labels[catid], prob.item()) for catid, prob in zip(top5_catid[0], top5_prob[0])] return results @app.route('/') def index(): return render_template("index.html") @app.route('/', methods=['POST']) def classify_image(): if 'image' not in request.files: return redirect(url_for('index')) image = request.files['image'] if image.filename == '': return redirect(url_for('index')) try: # Zapisz wczytany obraz do katalogu UPLOAD_FOLDER image_path = os.path.join(app.config['UPLOAD_FOLDER'], image.filename) image.save(image_path) # Klasyfikacja wyjscie_klasyfikatora = recognize_image(image_path) return render_template("result.html", image_path=image_path, predictions=wyjscie_klasyfikatora) except Exception as e: return str(e), 500 if __name__ == "__main__": app.run(debug=True)