深度阅读

lightning-flash 预测泰坦尼克号信息

作者
作者
2023年08月22日
更新时间
13.11 分钟
阅读时间
0
阅读量

lightning-flash

PyTorchLightning/lightning-flash:用于快速原型设计、基线、微调和解决深度学习问题的任务集合。

https://github.com/PyTorchLightning/lightning-flash

文档在这里
https://lightning-flash.readthedocs.io/en/latest/reference/tabular_classification.html

https://lightning-flash.readthedocs.io/

还是很不错的

from torchmetrics.classification import Accuracy, Precision, Recall
import flash
from flash.core.data.utils import download\_data
from flash.tabular import TabularClassifier, TabularData</p>

<h1>1. Download the data</h1>

download\_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')

<h1>2. Load the data</h1>

datamodule = TabularData.from\_csv(
 ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
 "Fare",
 target\_fields="Survived",
 train\_file="./data/titanic/titanic.csv",
 test\_file="./data/titanic/test.csv",
 val\_split=0.25,
)

<h1>3. Build the model</h1>

model = TabularClassifier.from\_data(datamodule, metrics=[Accuracy(), Precision(), Recall()])

<h1>4. Create the trainer. Run 10 times on data</h1>

trainer = flash.Trainer(max\_epochs=10)

<h1>5. Train the model</h1>

trainer.fit(model, datamodule=datamodule)

<h1>6. Test model</h1>

trainer.test()

<h1>7. Predict!</h1>

predictions = model.predict("data/titanic/titanic.csv")
print(predictions)

相关标签

博客作者

热爱技术,乐于分享,持续学习。专注于Web开发、系统架构设计和人工智能领域。