柚子快報邀請碼778899分享:LSTM中文新聞分類源碼詳解
柚子快報邀請碼778899分享:LSTM中文新聞分類源碼詳解
LSTM中文新聞分類
一、導包
二、讀取數據
三、數據預處理
1.分詞、去掉停用詞和數字、字母轉換成小寫等
2.新聞文本標簽數值化
三、創(chuàng)建詞匯表/詞典
1.data.Field()
2.空格切分等
3.構建詞匯表/詞典
使用訓練集構建單詞表,vectors=None:沒有使用預訓練好的詞向量,而是使用的是隨機初始化的詞向量,默認是100維
這里面的20002,多的那兩個應該是
四、構造數據集迭代器,方便批處理
batch.cutword[0]和batch.cutword[1]
batch.cutword[0]:表示的是一批數據也就是64條新聞,每條新聞都會被分詞,分成一個一個的詞語,每個詞語在詞典中的索引,最后面的1表示的是不足400,填充的
batch.cutword[1]:表示的是一批數據也就是64條新聞,每條新聞對應所有新聞中的索引號。
五、搭建LSTM網絡
r_out, (h_n, h_c)分別是:
r_out是最終輸出結果y(根據今天,昨天和日記)
h_n是隱藏層的輸出結果s(根據昨天)
h_c是長期信息的輸出結果c(根據日記)
六、LSTM網絡的訓練
七、LSTM網絡的測試
一、導包
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.font_manager import FontProperties
fonts = FontProperties(fname = "/Library/Fonts/華文細黑.ttf")
import re
import string
import copy
import time
from sklearn.metrics import accuracy_score,confusion_matrix
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
import jieba
jieba.setLogLevel(jieba.logging.INFO)
from torchtext.legacy import data
from torchtext.vocab import Vectors
#從 PyTorch 的拓展庫 torchtext 中導入了 Vectors 類,該類用于處理詞向量(word embeddings)
二、讀取數據
train_df = pd.read_csv("data/lstm/cnews/cnews.train.txt",sep="\t",
header=None,names = ["label","text"])
val_df = pd.read_csv("data/lstm/cnews/cnews.val.txt",sep="\t",
header=None,names = ["label","text"])
test_df = pd.read_csv("data/lstm/cnews/cnews.test.txt",sep="\t",
header=None,names = ["label","text"])
train_df.head(5)
三、數據預處理
stop_words = pd.read_csv("data/lstm/cnews/中文停用詞庫.txt",
header=None,names = ["text"])
1.分詞、去掉停用詞和數字、字母轉換成小寫等
## 對中文文本數據進行預處理,去除一些不需要的字符,分詞,去停用詞,等操作
def chinese_pre(text_data):
## 字母轉化為小寫,去除數字,
text_data = text_data.lower()
text_data = re.sub("\d+", "", text_data)
## 分詞,使用精確模式
text_data = list(jieba.cut(text_data,cut_all=False))
## 去停用詞和多余空格
text_data = [word.strip() for word in text_data if word not in stop_words.text.values]
## 處理后的詞語使用空格連接為字符串
text_data = " ".join(text_data)
return text_data
train_df["cutword"] = train_df.text.apply(chinese_pre)
val_df["cutword"] = val_df.text.apply(chinese_pre)
test_df["cutword"] = test_df.text.apply(chinese_pre)
## 預處理后的結果保存為新的文件
train_df[["label","cutword"]].to_csv("data/lstm/cnews_train.csv",index=False)
val_df[["label","cutword"]].to_csv("data/lstm/cnews_val.csv",index=False)
test_df[["label","cutword"]].to_csv("data/lstm/cnews_test.csv",index=False)
train_df.cutword.head()
train_df = pd.read_csv("data/lstm/cnews_train.csv")
val_df = pd.read_csv("data/lstm/cnews_val.csv")
test_df = pd.read_csv("data/lstm/cnews_test.csv")
2.新聞文本標簽數值化
labelMap = {
"體育": 0,"娛樂": 1,"家居": 2,"房產": 3,"教育": 4,
柚子快報邀請碼778899分享:LSTM中文新聞分類源碼詳解
推薦鏈接
本文內容根據網絡資料整理,出于傳遞更多信息之目的,不代表金鑰匙跨境贊同其觀點和立場。
轉載請注明,如有侵權,聯(lián)系刪除。