import os
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import tensorflow as tf
# 1. PREPARE THE DATA
base_dir = 'archive'  # Folder with the 6 fruit classes (one sub-folder per class)

# Training generator: MobileNetV2 preprocessing plus random augmentation.
# validation_split=0.2 reserves 20% of the images for the validation subset.
datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    validation_split=0.2,
)

# Validation generator: same preprocessing and same validation_split (so the
# 'validation' subset refers to the same held-out files), but NO augmentation.
# Validation/evaluation metrics must be measured on unmodified images.
val_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
    validation_split=0.2,
)

# Data generators
train_generator = datagen.flow_from_directory(
    base_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='training',
)
validation_generator = val_datagen.flow_from_directory(
    base_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='validation',
    shuffle=False,  # deterministic order for evaluation; does not change metrics
)
# 2. BUILD THE MODEL WITH TRANSFER LEARNING
# Load the ImageNet-pretrained MobileNetV2 backbone without its classifier head.
base_model = MobileNetV2(weights='imagenet', include_top=False,
                         input_shape=(224, 224, 3))

# Freeze the whole backbone (recursively freezes every layer) so that only
# the new head below is trained.
base_model.trainable = False

# Take the number of classes from the training data instead of hard-coding it,
# so the softmax head always matches the class folders found under base_dir.
num_classes = train_generator.num_classes

# Full model: frozen backbone -> pooled features -> regularized dense head.
model = Sequential([
    base_model,
    GlobalAveragePooling2D(),
    Dense(256, activation='relu',
          kernel_regularizer=tf.keras.regularizers.l2(0.001)),
    Dropout(0.4),
    Dense(num_classes, activation='softmax'),  # one output per class
])
# 3. COMPILE
# Adam with an explicit learning rate. Categorical cross-entropy pairs with the
# softmax output layer and the class_mode='categorical' labels.
adam_optimizer = Adam(learning_rate=0.001)
model.compile(optimizer=adam_optimizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# 4. CALLBACKS
# Both callbacks track the same metric (val_loss). The best-model file written
# to disk and the weights restored by early stopping are therefore chosen by
# the same criterion instead of potentially different epochs.
checkpoint = ModelCheckpoint(
    'mejor_modelo_frutas_transfer.h5',
    monitor='val_loss',
    save_best_only=True,  # only overwrite the file when val_loss improves
    mode='min',
)
# Stop after 5 epochs without val_loss improvement and roll back to the best weights.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)
# 5. TRAIN
# len(generator) is ceil(samples / batch_size), taken from the generator itself.
# Unlike samples // 32, it does not drop the final partial batch, does not
# duplicate the batch size, and is never 0 when a subset has < 32 images.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    epochs=20,
    callbacks=[early_stopping, checkpoint],
)
# 6. EVALUATE
# With metrics=['accuracy'], evaluate() returns [loss, accuracy]; index 1 is accuracy.
evaluation = model.evaluate(validation_generator)
val_accuracy = evaluation[1]
print(f"Validation Accuracy: {val_accuracy}")