# Source: 774822 / train.py -- commit a16a0bf ("Update train.py") by mse0306, 391 bytes.
import tensorflow as tf
from model import create_model
import numpy as np
def load_data(
    num_samples=1000,
    image_shape=(64, 64, 3),
    num_classes=10,
    seed=None,
):
    """Generate a random placeholder dataset of images and integer labels.

    The defaults reproduce the original hard-coded behaviour: 1000 samples
    of shape (64, 64, 3) with labels drawn from 10 classes.

    Args:
        num_samples: Number of (image, label) pairs to generate.
        image_shape: Shape of one image; it follows the leading sample axis.
        num_classes: Labels are drawn uniformly from ``range(num_classes)``.
        seed: Optional seed for a private ``np.random.RandomState`` so the
            data is reproducible. When ``None`` the global NumPy random
            state is used, exactly as before.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a float32 array of shape
        ``(num_samples, *image_shape)`` holding uniform random values;
        ``labels`` is an integer array of shape ``(num_samples,)``.
    """
    # The module-level functions and RandomState share the rand/randint API,
    # so one code path serves both the seeded and the unseeded case.
    rng = np.random if seed is None else np.random.RandomState(seed)
    images = rng.rand(num_samples, *image_shape).astype(np.float32)
    labels = rng.randint(num_classes, size=num_samples)
    return images, labels
def main():
    """Train a model on the data from ``load_data`` and save it to disk.

    Builds the model with ``create_model`` (imported from ``model``), calls
    its ``fit`` method for 5 epochs, then saves it to ``my_model.keras``.
    """
    images, labels = load_data()
    net = create_model()
    net.fit(images, labels, epochs=5)
    net.save('my_model.keras')
# Run training only when this file is executed as a script, so importing
# the module does not start a training run.
if __name__ == "__main__":
    main()