2. Laboratorio de CNN con FastAPI

En este laboratorio, vamos a utilizar el modelo de clasificación de imágenes de CIFAR-10 que creamos en el laboratorio anterior, y vamos a crear una API REST con FastAPI para poder hacer predicciones con el modelo.

1. Cargar el modelo creado en el laboratorio anterior

Crear el archivo main,py con el siguiente contenido:

``import io
import os
import tensorflow as tf
import numpy as np
from PIL import Image
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# Configurar CORS
origins = [
    "http://localhost:5173",  # Origen de tu frontend (Vite)
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Cargar el modelo
model_path = "cifar10_model.keras"  # Actualiza esta ruta según sea necesario

if not os.path.exists(model_path):
    raise FileNotFoundError(f"El archivo {model_path} no se encuentra. Verifica la ruta.")

model = tf.keras.models.load_model(model_path)

# Lista de clases de CIFAR-10 con emojis
class_names = [
    "Avión ✈️",
    "Automóvil 🚗",
    "Pájaro 🐦",
    "Gato 🐱",
    "Ciervo 🦌",
    "Perro 🐶",
    "Rana 🐸",
    "Caballo 🐎",
    "Barco ⛴️",
    "Camión 🚚",
]

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    try:
        # Leer el archivo subido
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        image = image.resize((32, 32))  # Asegúrate de que la imagen tenga el tamaño correcto
        image = np.array(image) / 255.0  # Normalizar la imagen
        image = np.expand_dims(image, axis=0)  # Añadir una dimensión para el batch

        # Realizar la predicción
        predictions = model.predict(image)
        predicted_class_index = np.argmax(predictions, axis=1)[0]

        # Obtener el nombre de la clase correspondiente con emoji
        predicted_class_name = class_names[predicted_class_index]

        return {"predicted_class": predicted_class_name}
    except Exception as e:
        return {"error": str(e)}

2. Ejecutar la API

Para ejecutar la API, ejecuta el siguiente comando:

uvicorn main:app --reload

3. Probar la API

Para probar la API vamos a utilizar la herramienta curl. Ejecuta el siguiente comando en una terminal:

curl -X POST "http://127.0.0.1:8000/predict/" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@<ruta de la imagen>"

Por ejemplo:

curl -X POST "http://127.0.0.1:8000/predict/" -F "file=@//home/statick/workspaces/practicas/fast_api/part2/avion.jpg"

En el caso de que la imagen sea un avión, deberías obtener una respuesta similar a la siguiente:

{“predicted_class”:“Avión ✈️”}%

Tip

Recuerda la lista de clases de CIFAR-10:

  • “Avión ✈️”,
  • “Automóvil 🚗”,
  • “Pájaro 🐦”,
  • “Gato 🐱”,
  • “Ciervo 🦌”,
  • “Perro 🐶”,
  • “Rana 🐸”,
  • “Caballo 🐎”,
  • “Barco ⛴️”,
  • “Camión 🚚”,