2. Laboratorio de CNN con FastAPI
En este laboratorio, vamos a utilizar el modelo de clasificación de imágenes de CIFAR-10 que creamos en el laboratorio anterior, y vamos a crear una API REST con FastAPI para poder hacer predicciones con el modelo.
1. Cargar el modelo creado en el laboratorio anterior
Crear el archivo main,py con el siguiente contenido:
``import ioimport os
import tensorflow as tf
import numpy as np
from PIL import Image
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
= FastAPI()
app
# Configurar CORS
= [
origins "http://localhost:5173", # Origen de tu frontend (Vite)
]
app.add_middleware(
CORSMiddleware,=origins,
allow_origins=True,
allow_credentials=["*"],
allow_methods=["*"],
allow_headers
)
# Cargar el modelo
= "cifar10_model.keras" # Actualiza esta ruta según sea necesario
model_path
if not os.path.exists(model_path):
raise FileNotFoundError(f"El archivo {model_path} no se encuentra. Verifica la ruta.")
= tf.keras.models.load_model(model_path)
model
# Lista de clases de CIFAR-10 con emojis
= [
class_names "Avión ✈️",
"Automóvil 🚗",
"Pájaro 🐦",
"Gato 🐱",
"Ciervo 🦌",
"Perro 🐶",
"Rana 🐸",
"Caballo 🐎",
"Barco ⛴️",
"Camión 🚚",
]
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
try:
# Leer el archivo subido
= await file.read()
contents = Image.open(io.BytesIO(contents))
image = image.resize((32, 32)) # Asegúrate de que la imagen tenga el tamaño correcto
image = np.array(image) / 255.0 # Normalizar la imagen
image = np.expand_dims(image, axis=0) # Añadir una dimensión para el batch
image
# Realizar la predicción
= model.predict(image)
predictions = np.argmax(predictions, axis=1)[0]
predicted_class_index
# Obtener el nombre de la clase correspondiente con emoji
= class_names[predicted_class_index]
predicted_class_name
return {"predicted_class": predicted_class_name}
except Exception as e:
return {"error": str(e)}
2. Ejecutar la API
Para ejecutar la API, ejecuta el siguiente comando:
uvicorn main:app --reload
3. Probar la API
Para probar la API vamos a utilizar la herramienta curl. Ejecuta el siguiente comando en una terminal:
curl -X POST "http://127.0.0.1:8000/predict/" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@<ruta de la imagen>"
Por ejemplo:
curl -X POST "http://127.0.0.1:8000/predict/" -F "file=@//home/statick/workspaces/practicas/fast_api/part2/avion.jpg"
En el caso de que la imagen sea un avión, deberías obtener una respuesta similar a la siguiente:
{“predicted_class”:“Avión ✈️”}%
Tip
Recuerda la lista de clases de CIFAR-10:
- “Avión ✈️”,
- “Automóvil 🚗”,
- “Pájaro 🐦”,
- “Gato 🐱”,
- “Ciervo 🦌”,
- “Perro 🐶”,
- “Rana 🐸”,
- “Caballo 🐎”,
- “Barco ⛴️”,
- “Camión 🚚”,