使用Rust的Linfa和Polars库进行机器学习:线性回归

 互联网资讯   2024-03-01 12:01   25 人阅读  0 条评论
2024-03-01,

在这篇文章中,我们将使用Rust的Linfa库和Polars库来实现机器学习中的线性回归算法。

Linfa crate旨在提供一个全面的工具包来使用Rust构建机器学习应用程序。

Polars是Rust的一个DataFrame库,它基于Apache Arrow的内存模型。Apache arrow提供了非常高效的列数据结构,并且正在成为列数据结构事实上的标准。

在下面的例子中,我们使用一个糖尿病数据集来训练线性回归算法。

使用以下命令创建一个Rust新项目:

cargo new machine_learning_linfa

在Cargo.toml文件中加入以下依赖项:

[dependencies] linfa = "0.7.0" linfa-linear = "0.7.0" ndarray = "0.15.6" Polars = { version = "0.35.4", features = ["ndarray"]}

在项目根目录下创建一个diabetes_file.csv文件,将数据集写入文件。

AGE SEX BMI BP S1 S2 S3 S4 S5 S6 Y 59 2 32.1 101 157 93.2 38 4 4.8598 87 151 48 1 21.6 87 183 103.2 70 3 3.8918 69 75 72 2 30.5 93 156 93.6 41 4 4.6728 85 141 24 1 25.3 84 198 131.4 40 5 4.8903 89 206 50 1 23 101 192 125.4 52 4 4.2905 80 135 23 1 22.6 89 139 64.8 61 2 4.1897 68 97 36 2 22 90 160 99.6 50 3 3.9512 82 138 66 2 26.2 114 255 185 56 4.55 4.2485 92 63 60 2 32.1 83 179 119.4 42 4 4.4773 94 110 .............

数据集从这里下载:https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt

在src/main.rs文件中写入以下代码:

use linfa::prelude::*; use linfa::traits::Fit; use linfa_linear::LinearRegression; use ndarray::{ArrayBase, OwnedRepr}; use polars::prelude::*; // Import polars fn main() -> Result<(), Box<dyn std::error::Error>> { // 将制表符定义为分隔符 let separator = b'\t'; let df = polars::prelude::CsvReader::from_path("./diabetes_file.csv")? .infer_schema(None) .with_separator(separator) .has_header(true) .finish()?; println!("{:?}", df); // 提取并转换目标列 let age_series = df.column("AGE")?.cast(&DataType::Float64)?; let target = age_series.f64()?; println!("Creating features dataset"); let mut features = df.drop("AGE")?; // 遍历列并将每个列强制转换为Float64 for col_name in features.get_column_names_owned() { let casted_col = df .column(&col_name)? .cast(&DataType::Float64) .expect("Failed to cast column"); features.with_column(casted_col)?; } println!("{:?}", df); let features_ndarray: ArrayBase<OwnedRepr<_>, _> = features.to_ndarray::<Float64Type>(IndexOrder::C)?; let target_ndarray = target.to_ndarray()?.to_owned(); let (dataset_training, dataset_validation) = Dataset::new(features_ndarray, target_ndarray).split_with_ratio(0.80); // 训练模型 let model = LinearRegression::default().fit(&dataset_training)?; // 预测 let pred = model.predict(&dataset_validation); // 评价模型 let r2 = pred.r2(&dataset_validation)?; println!("r2 from prediction: {}", r2); Ok(()) }

使用polar的CSV reader读取CSV文件。 将数据帧打印到控制台以供检查。 从DataFrame中提取“AGE”列作为线性回归的目标变量。将目标列强制转换为Float64(双精度浮点数),这是机器学习中数值数据的常用格式。 将features DataFrame转换为narray::ArrayBase(一个多维数组)以与linfa兼容。将目标序列转换为数组,这些数组与用于机器学习的linfa库兼容。 使用80-20的比例将数据集分割为训练集和验证集,这是机器学习中评估模型在未知数据上的常见做法。 使用linfa的线性回归算法在训练数据集上训练线性回归模型。 使用训练好的模型对验证数据集进行预测。 计算验证数据集上的R²(决定系数)度量,以评估模型的性能。R²值表示回归预测与实际数据点的近似程度。

执行cargo run,运行结果如下:

shape: (442, 11) ┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐ │ AGE ┆ SEX ┆ BMI ┆ BP ┆ … ┆ S4 ┆ S5 ┆ S6 ┆ Y │ │ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ f64 ┆ f64 ┆ ┆ f64 ┆ f64 ┆ i64 ┆ i64 │ ╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡ │ 59 ┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0 ┆ 4.8598 ┆ 87 ┆ 151 │ │ 48 ┆ 1 ┆ 21.6 ┆ 87.0 ┆ … ┆ 3.0 ┆ 3.8918 ┆ 69 ┆ 75 │ │ 72 ┆ 2 ┆ 30.5 ┆ 93.0 ┆ … ┆ 4.0 ┆ 4.6728 ┆ 85 ┆ 141 │ │ 24 ┆ 1 ┆ 25.3 ┆ 84.0 ┆ … ┆ 5.0 ┆ 4.8903 ┆ 89 ┆ 206 │ │ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │ │ 47 ┆ 2 ┆ 24.9 ┆ 75.0 ┆ … ┆ 5.0 ┆ 4.4427 ┆ 102 ┆ 104 │ │ 60 ┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95 ┆ 132 │ │ 36 ┆ 1 ┆ 30.0 ┆ 95.0 ┆ … ┆ 4.79 ┆ 5.1299 ┆ 85 ┆ 220 │ │ 36 ┆ 1 ┆ 19.6 ┆ 71.0 ┆ … ┆ 3.0 ┆ 4.5951 ┆ 92 ┆ 57 │ └─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘ Creating features dataset shape: (442, 11) ┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐ │ AGE ┆ SEX ┆ BMI ┆ BP ┆ … ┆ S4 ┆ S5 ┆ S6 ┆ Y │ │ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ f64 ┆ f64 ┆ ┆ f64 ┆ f64 ┆ i64 ┆ i64 │ ╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡ │ 59 ┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0 ┆ 4.8598 ┆ 87 ┆ 151 │ │ 48 ┆ 1 ┆ 21.6 ┆ 87.0 ┆ … ┆ 3.0 ┆ 3.8918 ┆ 69 ┆ 75 │ │ 72 ┆ 2 ┆ 30.5 ┆ 93.0 ┆ … ┆ 4.0 ┆ 4.6728 ┆ 85 ┆ 141 │ │ 24 ┆ 1 ┆ 25.3 ┆ 84.0 ┆ … ┆ 5.0 ┆ 4.8903 ┆ 89 ┆ 206 │ │ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │ │ 47 ┆ 2 ┆ 24.9 ┆ 75.0 ┆ … ┆ 5.0 ┆ 4.4427 ┆ 102 ┆ 104 │ │ 60 ┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95 ┆ 132 │ │ 36 ┆ 1 ┆ 30.0 ┆ 95.0 ┆ … ┆ 4.79 ┆ 5.1299 ┆ 85 ┆ 220 │ │ 36 ┆ 1 ┆ 19.6 ┆ 71.0 ┆ … ┆ 3.0 ┆ 4.5951 ┆ 92 ┆ 57 │ └─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘ r2 from prediction: 0.15937814745521017

对于优先考虑快速迭代和快速原型的数据科学家来说,Rust的编译时间可能是令人头疼的问题。Rust的强静态类型系统虽然有利于确保类型安全和减少运行时错误,但也会在编码过程中增加一层复杂性。

PS:本文来源:使用Rust的Linfa和Polars库进行机器学习:线性回归,Rust,Polars,机器学习,人工智能,作者:李明

版权声明:本文内容源于互联网搬运整理,仅限于小范围内传播学习和文献参考,不代表本站观点,请在下载后24小时内删除,如果有侵权之处请第一时间联系我们删除。敬请谅解! E-mail:c#seox.cn(#修改为@)


CRM论坛:CRM论坛(CRMBBS.COM)始办于2019年,是致力于✅CRM实施方案✅免费CRM软件✅SCRM系统✅客户管理系统的垂直内容社区网站,CRM论坛持续专注于CRM领域,在不断深化理解CRM系统的同时,进一步利用新型互联网技术,为用户实现企业、客户、合作伙伴与产品之间的无缝连接与交互。

评论已关闭!